Networks Module Overview
The models implemented in the PhysioEx library are:
- Chambon2018 model for sleep stage classification (raw time series as input).
- TinySleepNet model for sleep stage classification (raw time series as input).
- SeqSleepNet model for sleep stage classification (time-frequency images as input).
Model | Model Name | Input Transform | Target Transform |
---|---|---|---|
Chambon2018 | chambon2018 | raw | get_mid_label |
TinySleepNet | tinysleepnet | raw | None |
SeqSleepNet | seqsleepnet | xsleepnet | None |
The models in PhysioEx are designed to receive input sequences of 30-second sleep epochs. These sequences can be either preprocessed or raw, depending on the requirements of the model. The preprocessing expected by each model is indicated by its "Input Transform" attribute. This attribute must match the "preprocessing" argument of the dataset so that the model receives the correct input.
Similarly, models can be sequence-to-sequence (default) or sequence-to-epoch. In the latter case, a function that selects one epoch in the sequence must be added to the PhysioExDataModule pipeline so that the target data matches the output of the model. These functions are implemented in the physioex.train.networks.utils.target_transform module.
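To illustrate what a sequence-to-epoch target transform does, here is a minimal sketch that keeps only the label of the central epoch of a sequence. The helper name `select_mid_label` is hypothetical and is not the library's own `get_mid_label`, whose details (for example the handling of even-length sequences) may differ.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def select_mid_label(labels: Sequence[T]) -> T:
    """Return the label of the central epoch of a sequence.

    Hypothetical helper for illustration only. For even-length sequences
    this picks the epoch just after the midpoint, which is an assumption.
    """
    return labels[len(labels) // 2]


# A sequence of 5 epoch labels collapses to the label of the third epoch.
sequence_labels = [0, 1, 2, 2, 3]
mid_label = select_mid_label(sequence_labels)
```

Because the function only relies on `len` and integer indexing, the same sketch would also work on a 1-D array or tensor of labels.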
When implementing your own SleepModule, the Input Transform and Target Transform must be configured properly. The best practice is to set them in a .yaml file, as discussed in the train module documentation page.
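The pairing of model name, Input Transform and Target Transform from the table above can be kept in one place and checked before training. The sketch below is illustrative only: `MODEL_TRANSFORMS` and `check_preprocessing` are hypothetical names, not part of the PhysioEx API, and the values are copied from the table.

```python
# Hypothetical lookup mirroring the table above; not part of the PhysioEx API.
MODEL_TRANSFORMS = {
    "chambon2018": {"input_transform": "raw", "target_transform": "get_mid_label"},
    "tinysleepnet": {"input_transform": "raw", "target_transform": None},
    "seqsleepnet": {"input_transform": "xsleepnet", "target_transform": None},
}


def check_preprocessing(model_name: str, dataset_preprocessing: str) -> None:
    """Raise ValueError if the dataset preprocessing does not match the model."""
    expected = MODEL_TRANSFORMS[model_name]["input_transform"]
    if dataset_preprocessing != expected:
        raise ValueError(
            f"{model_name} expects '{expected}' input, "
            f"but the dataset provides '{dataset_preprocessing}'"
        )


# Matching pair: returns without raising.
check_preprocessing("seqsleepnet", "xsleepnet")
```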
Extending the SleepModule
All the models compatible with PhysioEx are PyTorch Lightning modules that extend physioex.train.networks.base.SleepModule.
By extending the module you can implement your own custom deep learning network for sleep staging. When extending the module, use a dictionary module_config: dict as the argument of the constructor to keep it compatible with the rest of the library. Then define your custom torch.nn.Module and use module_config: dict as its constructor argument too.
Example
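A minimal sketch of the pattern described above, under several assumptions. The network `TinyEpochNet`, the config keys `in_channels` and `hidden_size`, the input layout `(batch, seq_len, channels, samples)`, and the choice to have `encode` return both embeddings and outputs are all hypothetical. Only the use of `module_config: dict` as the constructor argument and the `n_classes` key come from this page.

```python
import torch
import torch.nn as nn


class TinyEpochNet(nn.Module):
    """Hypothetical sequence-to-sequence network; not part of PhysioEx."""

    def __init__(self, module_config: dict):
        super().__init__()
        n_classes = module_config["n_classes"]
        in_channels = module_config["in_channels"]  # hypothetical key
        hidden = module_config.get("hidden_size", 16)  # hypothetical key
        # One small feature extractor shared by every epoch of a sequence.
        self.epoch_encoder = nn.Sequential(
            nn.Conv1d(in_channels, hidden, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        )
        self.classifier = nn.Linear(hidden, n_classes)

    def encode(self, x: torch.Tensor):
        # Assumed input layout: (batch, seq_len, channels, samples).
        batch, seq_len, channels, samples = x.shape
        feats = self.epoch_encoder(x.reshape(batch * seq_len, channels, samples))
        embeddings = feats.reshape(batch, seq_len, -1)
        return embeddings, self.classifier(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, outputs = self.encode(x)
        return outputs


module_config = {"n_classes": 5, "in_channels": 1}
net = TinyEpochNet(module_config)
outputs = net(torch.randn(2, 3, 1, 300))  # one score vector per epoch

# Wrapping the network for training would look roughly like this
# (requires physioex to be installed; shown as comments for that reason):
#
# from physioex.train.networks.base import SleepModule
#
# class TinyEpochModule(SleepModule):
#     def __init__(self, module_config: dict):
#         super().__init__(TinyEpochNet(module_config), module_config)
```

The wrapper passes the network and the configuration dictionary to the base class, matching the `nn` and `config` parameters documented for SleepModule on this page.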
The SleepModule needs to know n_classes (for sleep staging this is typically 5) and the loss to be computed during training. By default, the loss functions in PhysioEx (see physioex.train.networks.utils.loss) take a Python dict in their constructor, so you should always specify n_classes, loss_call and loss_params in your module_config.
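The three required entries can be sketched as follows. `PlaceholderLoss` is a hypothetical stand-in for a class from physioex.train.networks.utils.loss, and `build_loss` is a hypothetical helper. The assumption that the loss is built as `loss_call(loss_params)` is inferred from the sentence above rather than taken from the library source.

```python
class PlaceholderLoss:
    """Hypothetical stand-in for a PhysioEx loss: its constructor takes a dict."""

    def __init__(self, params: dict):
        self.params = dict(params)


REQUIRED_KEYS = ("n_classes", "loss_call", "loss_params")


def build_loss(module_config: dict):
    """Check the required keys, then instantiate the loss from its params dict."""
    missing = [key for key in REQUIRED_KEYS if key not in module_config]
    if missing:
        raise KeyError(f"module_config is missing: {missing}")
    # Assumption: the loss is constructed by calling loss_call on loss_params.
    return module_config["loss_call"](module_config["loss_params"])


module_config = {
    "n_classes": 5,                # five sleep stages
    "loss_call": PlaceholderLoss,  # a class, not an instance
    "loss_params": {},             # forwarded to the loss constructor
}
loss = build_loss(module_config)
```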
SleepModule
Bases: LightningModule
A PyTorch Lightning module for sleep stage classification and regression tasks.
This module is designed to handle both classification and regression experiments for sleep stage analysis. It leverages PyTorch Lightning for training, validation, and testing, and integrates various metrics for performance evaluation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`nn` | Module | The neural network model to be used for sleep stage analysis. | required |
`config` | Dict | A dictionary containing configuration parameters for the module. Must include: - | required |
Attributes:
Name | Type | Description |
---|---|---|
`nn` | Module | The neural network model. |
`n_classes` | int | The number of classes for classification tasks. |
`loss` | callable | The loss function. |
`module_config` | Dict | The configuration dictionary. |
`learning_rate` | float | The learning rate for the optimizer. Default is 1e-4. |
`weight_decay` | float | The weight decay for the optimizer. Default is 1e-6. |
`val_loss` | float | The best validation loss observed during training. |
Notes
- This module supports both classification and regression tasks. The behavior is determined by the n_classes argument of the config dictionary.
- Various metrics are logged during training, validation, and testing to monitor performance.
- The learning rate scheduler is configured to reduce the learning rate when the validation loss plateaus.
configure_optimizers()
Configures the optimizer and learning rate scheduler.
Returns:
Type | Description |
---|---|
Tuple[List[Optimizer], List[Dict]] | A tuple containing the optimizer and the learning rate scheduler. |
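For orientation, a return value of this shape can be built as in the sketch below. The defaults for `learning_rate` and `weight_decay` come from the attributes table. The choice of Adam, the scheduler's `factor` and `patience`, and the monitored metric name `"val_loss"` are assumptions made for illustration and are not confirmed by this page.

```python
import torch


def configure_optimizers_sketch(parameters, learning_rate=1e-4, weight_decay=1e-6):
    """Return ([optimizer], [scheduler_config]) in the two-list Lightning format."""
    # Assumption: Adam. The page does not name the optimizer that is used.
    optimizer = torch.optim.Adam(
        parameters, lr=learning_rate, weight_decay=weight_decay
    )
    scheduler_config = {
        # Lowers the learning rate when the monitored value stops improving.
        "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=10
        ),
        # Assumed metric name; Lightning needs a "monitor" key for this scheduler.
        "monitor": "val_loss",
    }
    return [optimizer], [scheduler_config]


model = torch.nn.Linear(4, 2)
optimizers, schedulers = configure_optimizers_sketch(model.parameters())
```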
forward(x)
Defines the forward pass of the neural network.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | The input tensor. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The output tensor. |
encode(x)
Encodes the input data using the neural network.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | The input tensor. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The encoded tensor. |
compute_loss()
Computes the loss and logs metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embeddings | Tensor | The embeddings tensor. | required |
outputs | Tensor | The outputs tensor. | required |
targets | Tensor | The targets tensor. | required |
log | str | The log prefix. | "train" |
log_metrics | bool | Whether to log additional metrics. | False |
Returns:
Type | Description |
---|---|
torch.Tensor | The computed loss. |
training_step(batch, batch_idx)
Defines a single training step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Tuple[Tensor, Tensor] | The input batch. | required |
batch_idx | int | The index of the batch. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The computed loss. |
validation_step(batch, batch_idx)
Defines a single validation step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Tuple[Tensor, Tensor] | The input batch. | required |
batch_idx | int | The index of the batch. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The computed loss. |
test_step(batch, batch_idx)
Defines a single test step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | | The input batch. | required |
batch_idx | int | The index of the batch. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The computed loss. |