Train Module Overview
The `physioex.train` module provides comprehensive tools for training, fine-tuning, and evaluating deep learning models on sleep staging datasets.
Usage Example
Below is an example of how to use the training module:
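For instance, a training script could begin by loading the model configuration (a sketch: the file name, the configuration keys, and the assumption that the model class is importable directly from `physioex.train.networks` are ours):

```python
import yaml  # PyYAML

from physioex.train.networks import TinySleepNet

model_class = TinySleepNet

# Read the model hyperparameters from a .yaml file
# ("tinysleepnet.yaml" and the "model_config" key are placeholders).
with open("tinysleepnet.yaml") as f:
    model_config = yaml.safe_load(f)["model_config"]
```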
Here we load the configuration needed to train a `physioex.train.networks.TinySleepNet`
model. Best practice is to store the model configuration in a `.yaml` file, whether you are using a custom `SleepModule`
or one of the models provided by PhysioEx.
Here is an example of a possible .yaml configuration file and how to read it properly:
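A sketch of what such a file might contain, and of how to parse it with PyYAML (the keys and values below are illustrative, not the exact PhysioEx schema):

```python
import yaml  # PyYAML

# Contents of a hypothetical tinysleepnet.yaml
CONFIG = """
model_class: physioex.train.networks.TinySleepNet
model_config:
  n_classes: 5          # Wake, N1, N2, N3, REM
  learning_rate: 0.0001
input_transform: raw
target_transform: null
"""

config = yaml.safe_load(CONFIG)
model_config = config["model_config"]
print(model_config["n_classes"])  # -> 5
```

`yaml.safe_load` returns a plain dictionary, so the nested `model_config` block can be passed on directly as the model's configuration.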
Note
If you are using a model provided by the library, `input_transform` and `target_transform` can be loaded from the network's default configuration (`network_config`).
Now we need to set up the datamodule arguments, after which we can start training the model:
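A hedged sketch of this step (the dataset name, channel name, folder path, and configuration values are placeholders; check the valid `PhysioExDataModule` keys in your installed version):

```python
from physioex.train import train
from physioex.train.networks import TinySleepNet

# Placeholder datamodule arguments.
datamodule_kwargs = {
    "selected_channels": ["EEG"],    # assumed key and channel name
    "data_folder": "/path/to/data",  # assumed key
}

best_checkpoint = train(
    datasets="sleepedf",             # assumed dataset name
    datamodule_kwargs=datamodule_kwargs,
    model_class=TinySleepNet,
    model_config={"n_classes": 5, "learning_rate": 1e-4},  # illustrative
    batch_size=128,
    max_epochs=10,
)
# train() returns the path to the best checkpoint by validation accuracy.
```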
Now imagine that we want to fine-tune the trained model on a new dataset.
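Fine-tuning starts from a checkpoint rather than from scratch; in this sketch the checkpoint path, dataset name, and configuration values are placeholders:

```python
from physioex.train import finetune
from physioex.train.networks import TinySleepNet

finetuned_checkpoint = finetune(
    datasets="mass",                                  # assumed dataset name
    datamodule_kwargs={"data_folder": "/path/to/data"},
    model_class=TinySleepNet,
    model_config={"n_classes": 5},                    # illustrative
    model_checkpoint="path/to/best_checkpoint.ckpt",  # from a previous train() run
    learning_rate=1e-5,   # usually lowered for fine-tuning
)
```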
Documentation
train
Trains a model using the provided datasets and configuration.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`datasets` | `Union[List[str], str, PhysioExDataModule]` | The datasets to be used for training. Can be a list of dataset names, a single dataset name, or a `PhysioExDataModule` instance. | *required* |
`datamodule_kwargs` | `dict` | Additional keyword arguments to be passed to the `PhysioExDataModule`. | `{}` |
`model` | `SleepModule` | The model to be trained. If provided, `model_class` and `model_config` are ignored. | `None` |
`model_class` | `type` | The class of the model to be trained. Required if `model` is not provided. | `None` |
`model_config` | `dict` | The configuration dictionary for the model. Required if `model` is not provided. | `None` |
`batch_size` | `int` | The batch size to be used for training. | `128` |
`fold` | `int` | The fold index for cross-validation. | `-1` |
`hpc` | `bool` | Whether to use high-performance computing (HPC) settings. | `False` |
`num_validations` | `int` | The number of validation steps per epoch. | `10` |
`checkpoint_path` | `str` | The path to save the model checkpoints. If `None`, a new path is generated. | `None` |
`max_epochs` | `int` | The maximum number of epochs for training. | `10` |
`num_nodes` | `int` | The number of nodes to be used for distributed training. | `1` |
`resume` | `bool` | Whether to resume training from the last checkpoint. | `True` |
Returns:

Type | Description |
---|---|
`str` | The path to the best model checkpoint. |
Raises:

Type | Description |
---|---|
`ValueError` | If `datasets` is not a list of dataset names, a single dataset name, or a `PhysioExDataModule` instance. |
`ValueError` | If neither `model` nor both `model_class` and `model_config` are provided. |
Notes

- The function sets up the data module, model, and trainer, and then starts the training process.
- If `resume` is `True` and a checkpoint is found, training resumes from the last checkpoint.
- The function returns the path to the best model checkpoint, selected by validation accuracy.
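For example (a sketch; the dataset names and configuration values are placeholders):

```python
from physioex.train import train
from physioex.train.networks import TinySleepNet

checkpoint = train(
    datasets=["sleepedf", "mass"],   # assumed dataset names
    model_class=TinySleepNet,
    model_config={"n_classes": 5},   # illustrative
    batch_size=128,
    max_epochs=10,
    resume=True,  # pick up from the last checkpoint if one exists
)
```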
test
Tests a model using the provided datasets and configuration.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`datasets` | `Union[List[str], str, PhysioExDataModule]` | The datasets to be used for testing. Can be a list of dataset names, a single dataset name, or a `PhysioExDataModule` instance. | *required* |
`datamodule_kwargs` | `dict` | Additional keyword arguments to be passed to the `PhysioExDataModule`. | `{}` |
`model` | `SleepModule` | The model to be tested. If provided, `model_class` and `model_config` are ignored. | `None` |
`model_class` | `type` | The class of the model to be tested. Required if `model` is not provided. | `None` |
`model_config` | `dict` | The configuration dictionary for the model. Required if `model` is not provided. | `None` |
`batch_size` | `int` | The batch size to be used for testing. | `128` |
`fold` | `int` | The fold index for cross-validation. | `-1` |
`hpc` | `bool` | Whether to use high-performance computing (HPC) settings. | `False` |
`checkpoint_path` | `str` | The path to the checkpoint from which to load the model. Required if `model` is not provided. | `None` |
`results_path` | `str` | The path to save the test results. If `None`, results are not saved. | `None` |
`num_nodes` | `int` | The number of nodes to be used for distributed testing. | `1` |
`aggregate_datasets` | `bool` | Whether to aggregate the datasets for testing. | `False` |
Returns:

Type | Description |
---|---|
`pd.DataFrame` | A DataFrame containing the test results. |
Raises:

Type | Description |
---|---|
`ValueError` | If `datasets` is not a list of dataset names, a single dataset name, or a `PhysioExDataModule` instance. |
`ValueError` | If `model` is not provided and `model_class`, `model_config`, or `checkpoint_path` is missing. |
Notes

- The function sets up the data module, model, and trainer, and then starts the testing process.
- The function returns a DataFrame containing the test results for each dataset.
- If `results_path` is provided, the results are saved as a CSV file in the specified path.
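For example (a sketch; the dataset names, checkpoint path, and configuration values are placeholders):

```python
from physioex.train import test
from physioex.train.networks import TinySleepNet

results = test(
    datasets=["sleepedf", "mass"],   # assumed dataset names
    model_class=TinySleepNet,
    model_config={"n_classes": 5},   # illustrative
    checkpoint_path="path/to/best_checkpoint.ckpt",
    results_path="results/",         # metrics are saved here as CSV
    aggregate_datasets=False,        # report each dataset separately
)
print(results)  # a pd.DataFrame of test metrics
```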
finetune
Fine-tunes a pre-trained model using the provided datasets and configuration.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`datasets` | `Union[List[str], str, PhysioExDataModule]` | The datasets to be used for fine-tuning. Can be a list of dataset names, a single dataset name, or a `PhysioExDataModule` instance. | *required* |
`datamodule_kwargs` | `dict` | Additional keyword arguments to be passed to the `PhysioExDataModule`. | `{}` |
`model` | `Union[dict, SleepModule]` | The model to be fine-tuned. If provided, `model_class`, `model_config`, and `model_checkpoint` are ignored. | `None` |
`model_class` | `type` | The class of the model to be fine-tuned. Required if `model` is not provided. | `None` |
`model_config` | `dict` | The configuration dictionary for the model. Required if `model` is not provided. | `None` |
`model_checkpoint` | `str` | The path to the checkpoint from which to load the model. Required if `model` is not provided. | `None` |
`learning_rate` | `float` | The learning rate to be set for fine-tuning. | `None` |
`weight_decay` | `Union[str, float]` | The weight decay to be set for fine-tuning. | `None` |
`train_kwargs` | `Dict` | Additional keyword arguments to be passed to the `train` function. | `{}` |
Returns:

Type | Description |
---|---|
`str` | The path of the best model checkpoint. |
Raises:

Type | Description |
---|---|
`ValueError` | If `datasets` is not a list of dataset names, a single dataset name, or a `PhysioExDataModule` instance. |
`ValueError` | If neither `model` nor `model_checkpoint` (together with `model_class` and `model_config`) is provided. |
Notes

- Models cannot be fine-tuned from scratch; they must be loaded from a checkpoint or be a pre-trained model from `physioex.models`.
- Typically, when fine-tuning a model, you will want to set the learning rate explicitly.
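Both notes can be seen in one call (a sketch; the dataset name, checkpoint path, and configuration values are placeholders):

```python
from physioex.train import finetune
from physioex.train.networks import TinySleepNet

best = finetune(
    datasets="hmc",                              # assumed dataset name
    model_class=TinySleepNet,
    model_config={"n_classes": 5},               # illustrative
    model_checkpoint="path/to/pretrained.ckpt",  # fine-tuning requires a checkpoint
    learning_rate=1e-5,                          # set explicitly for fine-tuning
)
```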