Train Module Overview

The physioex.train module provides comprehensive tools for training, fine-tuning, and evaluating deep learning models on sleep staging datasets.

Usage Example

Below is an example of how to use the training module:

import os

from physioex.train.utils import train, test, finetune
from physioex.data import PhysioExDataModule

from physioex.train.networks.utils.loss import config as loss_config
from physioex.train.networks import config as network_config

checkpoint_path = "/path/to/your/checkpoint/dir/"

# first configure the model

# set the model configuration dictionary

# in case your model is from physioex.train.networks
# you can load its configuration

your_model_config = network_config["tinysleepnet"] 

your_model_config["loss_call"] = loss_config["cel"] # CrossEntropy Loss
your_model_config["loss_params"] = dict()
your_model_config["seq_len"] = 21 # needs to be the same as the DataModule
your_model_config["in_channels"] = 1 # needs to be the same as the DataModule

# your_model_config["n_classes"] = 5 # needs to be set if you are loading a custom SleepModule

Here we loaded the configuration needed to train a physioex.train.networks.TinySleepNet model. The best practice is to store the model configuration in a .yaml file, both when you are using a custom SleepModule and when you are using one of the models provided by PhysioEx.

Here is an example of a possible .yaml configuration file and how to read it properly:

module_config:
    loss_call: physioex.train.networks.utils.loss:CrossEntropyLoss
    loss_params: {}
    seq_len: 21
    in_channels: 1
    n_classes: 5
    ... # your additional model configuration parameters should be provided here
module: physioex.train.networks:TinySleepNet # can be any model that extends SleepModule
model_name: "tinysleepnet" # can be omitted if you are loading your custom SleepModule
input_transform: "raw"
target_transform: null
checkpoint_path: "/path/to/your/checkpoint/dir/"
The configuration file can then be read as follows:

import yaml

with open("my_network_config.yaml", "r") as file:
    config = yaml.safe_load(file)

your_module_config = config["module_config"]

# load the loss function class from its "module:class" path
import importlib
loss_package, loss_class = your_module_config["loss_call"].split(":")
your_module_config["loss_call"] = getattr(importlib.import_module(loss_package), loss_class)

# in case you provide model_name, the system loads the additional model parameters from the library
if "model_name" in config:
    model_name = config["model_name"]
    module_config = network_config[model_name]["module_config"]
    your_module_config.update(module_config)

    config["input_transform"] = network_config[model_name]["input_transform"]
    config["target_transform"] = network_config[model_name]["target_transform"]

# load the model class
model_package, model_class = config["module"].split(":")
model_class = getattr(importlib.import_module(model_package), model_class)
Note

If you are using a model provided by the library, "input_transform" and "target_transform" can be read directly from network_config, as done in the "model_name" branch above.

Now we need to set up the datamodule arguments, and then we can start training the model:

datamodule_kwargs = {
    "selected_channels" : ["EEG"], # needs to match in_channels
    "sequence_length" : your_module_config["seq_len"],
    "target_transform" : config["target_transform"],
    "preprocessing" : config["input_transform"],
    "data_folder" : "/your/data/folder",
}

model_config = your_module_config

# Train the model
best_checkpoint = train(
    datasets = "hmc", # can be a list or a PhysioExDataModule
    datamodule_kwargs = datamodule_kwargs,
    model_class = model_class,
    model_config = model_config,
    checkpoint_path = checkpoint_path,
    batch_size = 128,
    max_epochs = 10
)

# Test the model
results_dataframe = test(
    datasets = "hmc",
    datamodule_kwargs = datamodule_kwargs,
    model_class = model_class,
    model_config = model_config,
    checkpoint_path = os.path.join( checkpoint_path, best_checkpoint ),
    batch_size = 128,
    results_path = checkpoint_path,  # if you want to save the test results
                                     # in your checkpoint directory
)
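
The returned DataFrame can be inspected or saved like any pandas object. A minimal sketch (plain pandas; no assumptions are made here about the exact column layout):

# take a quick look at the per-dataset results
print(results_dataframe.head())

# persist them manually, in addition to (or instead of) results_path
results_dataframe.to_csv(os.path.join(checkpoint_path, "test_results.csv"), index=False)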
Even in this case, the best practice is to store the datamodule_kwargs in the .yaml configuration file, at least the non-dynamic ones.
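
For instance, the configuration file above could be extended with a datamodule_kwargs section (the section name and keys below are our own convention, not something the library requires), leaving only the machine-specific entries to be filled in at runtime:

# hypothetical extension of my_network_config.yaml:
#
# datamodule_kwargs:
#     selected_channels: ["EEG"]
#     sequence_length: 21

datamodule_kwargs = config.get("datamodule_kwargs", {})
datamodule_kwargs["preprocessing"] = config["input_transform"]
datamodule_kwargs["target_transform"] = config["target_transform"]
datamodule_kwargs["data_folder"] = "/your/data/folder" # machine-specific, set at runtime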

Now imagine that we want to fine-tune the trained model on a new dataset.

train_kwargs = {
    "datasets" : "dcsm",
    "datamodule_kwargs" : datamodule_kwargs,
    "batch_size" : 128,
    "max_epochs" : 10,
    "checkpoint_path" : checkpoint_path,
}

new_best_checkpoint = finetune(
    model_class = model_class,
    model_config = model_config,
    model_checkpoint = os.path.join( checkpoint_path, best_checkpoint),
    learning_rate = 1e-7, # lower the learning rate to avoid losing the knowledge acquired during pre-training
    train_kwargs = train_kwargs,
) 
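
Once fine-tuning is complete, the new checkpoint can be evaluated on the new dataset with the same test utility, reusing the objects defined above:

# test the fine-tuned model on the new dataset
results_dataframe = test(
    datasets = "dcsm",
    datamodule_kwargs = datamodule_kwargs,
    model_class = model_class,
    model_config = model_config,
    checkpoint_path = os.path.join( checkpoint_path, new_best_checkpoint ),
    batch_size = 128,
)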

Documentation

train

Trains a model using the provided datasets and configuration.

Parameters:

datasets : Union[List[str], str, PhysioExDataModule]
    The datasets to be used for training. Can be a list of dataset names, a single dataset name, or a PhysioExDataModule instance. Required.
datamodule_kwargs : dict, optional
    Additional keyword arguments to be passed to the PhysioExDataModule. Defaults to {}.
model : SleepModule, optional
    The model to be trained. If provided, model_class, model_config, and resume are ignored. Defaults to None.
model_class : type, optional
    The class of the model to be trained. Required if model is not provided. Defaults to None.
model_config : dict, optional
    The configuration dictionary for the model. Required if model is not provided. Defaults to None.
batch_size : int, optional
    The batch size to be used for training. Defaults to 128.
fold : int, optional
    The fold index for cross-validation. Defaults to -1.
hpc : bool, optional
    Whether to use high-performance computing (HPC) settings. Defaults to False.
num_validations : int, optional
    The number of validation steps per epoch. Defaults to 10.
checkpoint_path : str, optional
    The path to save the model checkpoints. If None, a new path is generated. Defaults to None.
max_epochs : int, optional
    The maximum number of epochs for training. Defaults to 10.
num_nodes : int, optional
    The number of nodes to be used for distributed training. Defaults to 1.
resume : bool, optional
    Whether to resume training from the last checkpoint. Defaults to True.

Returns:

str
    The path to the best model checkpoint.

Raises:

ValueError
    If datasets is not a list, a string, or a PhysioExDataModule instance.
ValueError
    If model is None and any of model_class or model_config are also None.

Notes
  • The function sets up the data module, model, and trainer, and then starts the training process.
  • If resume is True and a checkpoint is found, training resumes from the last checkpoint.
  • The function returns the path to the best model checkpoint based on validation accuracy.
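
For example, the resume behaviour can be used to pick up an interrupted run; a minimal sketch reusing the names from the usage example above:

# with resume=True (the default) and an existing checkpoint in
# checkpoint_path, training restarts from the last checkpoint
best_checkpoint = train(
    datasets = "hmc",
    datamodule_kwargs = datamodule_kwargs,
    model_class = model_class,
    model_config = model_config,
    checkpoint_path = checkpoint_path,
    resume = True,
)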

test

Tests a model using the provided datasets and configuration.

Parameters:

datasets : Union[List[str], str, PhysioExDataModule]
    The datasets to be used for testing. Can be a list of dataset names, a single dataset name, or a PhysioExDataModule instance. Required.
datamodule_kwargs : dict, optional
    Additional keyword arguments to be passed to the PhysioExDataModule. Defaults to {}.
model : SleepModule, optional
    The model to be tested. If provided, model_class, model_config, and resume are ignored. Defaults to None.
model_class : type, optional
    The class of the model to be tested. Required if model is not provided. Defaults to None.
model_config : dict, optional
    The configuration dictionary for the model. Required if model is not provided. Defaults to None.
batch_size : int, optional
    The batch size to be used for testing. Defaults to 128.
fold : int, optional
    The fold index for cross-validation. Defaults to -1.
hpc : bool, optional
    Whether to use high-performance computing (HPC) settings. Defaults to False.
checkpoint_path : str, optional
    The path to the checkpoint from which to load the model. Required if model is not provided. Defaults to None.
results_path : str, optional
    The path to save the test results. If None, results are not saved. Defaults to None.
num_nodes : int, optional
    The number of nodes to be used for distributed testing. Defaults to 1.
aggregate_datasets : bool, optional
    Whether to aggregate the datasets for testing. Defaults to False.

Returns:

pd.DataFrame
    A DataFrame containing the test results.

Raises:

ValueError
    If datasets is not a list, a string, or a PhysioExDataModule instance.
ValueError
    If model is None and any of model_class or model_config are also None.

Notes
  • The function sets up the data module, model, and trainer, and then starts the testing process.
  • The function returns a DataFrame containing the test results for each dataset.
  • If results_path is provided, the results are saved as a CSV file in the specified path.
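
For example, to evaluate one checkpoint on several datasets at once and save the per-dataset results as CSV, a sketch based on the parameters above:

results = test(
    datasets = ["hmc", "dcsm"],
    datamodule_kwargs = datamodule_kwargs,
    model_class = model_class,
    model_config = model_config,
    checkpoint_path = os.path.join( checkpoint_path, best_checkpoint ),
    aggregate_datasets = False, # keep one set of results per dataset
    results_path = checkpoint_path, # results are saved as a CSV file here
)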

finetune

Fine-tunes a pre-trained model using the provided datasets and configuration.

Parameters:

datasets : Union[List[str], str, PhysioExDataModule]
    The datasets to be used for fine-tuning. Can be a list of dataset names, a single dataset name, or a PhysioExDataModule instance. Required.
datamodule_kwargs : dict, optional
    Additional keyword arguments to be passed to the PhysioExDataModule. Defaults to {}.
model : Union[dict, SleepModule], optional
    The model to be fine-tuned. If provided, model_class, model_config, and model_checkpoint are ignored. Defaults to None.
model_class : type, optional
    The class of the model to be fine-tuned. Required if model is not provided. Defaults to None.
model_config : dict, optional
    The configuration dictionary for the model. Required if model is not provided. Defaults to None.
model_checkpoint : str, optional
    The path to the checkpoint from which to load the model. Required if model is not provided. Defaults to None.
learning_rate : float, optional
    The learning rate to be set for fine-tuning. If None, the learning rate is not updated. Defaults to 1e-7.
weight_decay : Union[str, float], optional
    The weight decay to be set for fine-tuning. If None, the weight decay is not updated. If "auto", it is set to 10% of the learning rate. Defaults to "auto".
train_kwargs : Dict, optional
    Additional keyword arguments to be passed to the train function. Defaults to {}.

Returns:

str
    The path of the best model checkpoint.

Raises:

ValueError
    If model is None and any of model_class, model_config, or model_checkpoint are also None.
ValueError
    If model is not a dictionary or a SleepModule.

Notes
  • Models cannot be fine-tuned from scratch; they must be loaded from a checkpoint or be a pre-trained model from physioex.models.
  • Typically, when fine-tuning a model, you want to adjust the learning rate (usually lowering it).
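
A minimal sketch combining these points, loading the weights from an explicit checkpoint and lowering the learning rate (the checkpoint path is a placeholder):

new_best_checkpoint = finetune(
    model_class = model_class,
    model_config = model_config,
    model_checkpoint = "/path/to/pretrained/checkpoint.ckpt", # placeholder path
    learning_rate = 1e-7, # lowered for fine-tuning
    weight_decay = "auto", # set to 10% of the learning rate
    train_kwargs = {
        "datasets" : "dcsm",
        "datamodule_kwargs" : datamodule_kwargs,
        "checkpoint_path" : checkpoint_path,
    },
)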