
Networks Module Overview

The models implemented in the PhysioEx library are:

  • Chambon2018 model for sleep stage classification (raw time series as input).
  • TinySleepNet model for sleep stage classification (raw time series as input).
  • SeqSleepNet model for sleep stage classification (time-frequency images as input).
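
| Model        | Model Name   | Input Transform | Target Transform |
|--------------|--------------|-----------------|------------------|
| Chambon2018  | chambon2018  | raw             | get_mid_label    |
| TinySleepNet | tinysleepnet | raw             | None             |
| SeqSleepNet  | seqsleepnet  | xsleepnet       | None             |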

The models in PhysioEx take sequences of 30-second sleep epochs as input. Depending on the model, these sequences contain either raw signals or a preprocessed representation such as time-frequency images. The "Input Transform" attribute states which preprocessing a model expects. It must match the "preprocessing" argument of the dataset so that the model receives the input it was designed for.
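As a minimal sketch of this matching rule, the table above can be turned into a lookup that fails early on a mismatch. MODEL_INPUT_TRANSFORMS and check_preprocessing are hypothetical names written for illustration. They are not part of the PhysioEx API, which may perform this check differently or not at all.

```python
# Hypothetical lookup built from the table above: model name -> required input transform.
MODEL_INPUT_TRANSFORMS = {
    "chambon2018": "raw",
    "tinysleepnet": "raw",
    "seqsleepnet": "xsleepnet",
}


def check_preprocessing(model_name: str, preprocessing: str) -> None:
    """Raise if the dataset preprocessing does not match the model's input transform.

    Hypothetical helper for illustration only; not a PhysioEx function.
    """
    expected = MODEL_INPUT_TRANSFORMS.get(model_name)
    if expected is None:
        raise KeyError(f"unknown model: {model_name!r}")
    if preprocessing != expected:
        raise ValueError(
            f"{model_name} expects preprocessing={expected!r}, got {preprocessing!r}"
        )


check_preprocessing("seqsleepnet", "xsleepnet")  # matches, so nothing is raised
```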

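Models can also be either sequence-to-sequence (the default) or sequence-to-epoch. A sequence-to-sequence model predicts one label per epoch. A sequence-to-epoch model predicts a single label per sequence. For a sequence-to-epoch model, a function that selects one epoch from each target sequence must be added to the PhysioExDataModule pipeline so that the targets match the model output. This function is the "Target Transform" (for example, get_mid_label for Chambon2018). These functions are implemented in the physioex.train.networks.utils.target_transform module.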
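The sketch below illustrates what a sequence-to-epoch target transform does. It keeps only the label of the central epoch of each target sequence, which is the behavior the name get_mid_label suggests. select_mid_label is a hypothetical re-implementation that assumes targets of shape (batch_size, sequence_length). It is not the library function, whose signature may differ.

```python
import torch


def select_mid_label(targets: torch.Tensor) -> torch.Tensor:
    """Keep only the label of the central epoch of each sequence.

    Hypothetical re-implementation for illustration; PhysioEx's own
    get_mid_label may have a different signature.
    targets: tensor of shape (batch_size, sequence_length).
    Returns a tensor of shape (batch_size,).
    """
    mid = targets.size(1) // 2
    return targets[:, mid]


# Two sequences of 5 epochs each (sleep stages encoded as 0-4).
targets = torch.tensor([[0, 1, 2, 2, 3],
                        [4, 4, 1, 0, 0]])
print(select_mid_label(targets))  # tensor([2, 1])
```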

When implementing your own SleepModule, configure its Input Transform and Target Transform properly. The best practice is to set them in a .yaml file, as discussed on the train module documentation page.

Extending the SleepModule

All models compatible with PhysioEx are PyTorch Lightning modules that extend physioex.train.networks.base.SleepModule.

By extending this module you can implement your own custom sleep staging deep learning network. First, give your module a constructor that takes a single dictionary, module_config: dict, so that it stays compatible with the rest of the library. Then define your custom torch.nn.Module and give its constructor the same module_config: dict argument.

Example

import torch
from physioex.train.networks.base import SleepModule

class CustomNet(torch.nn.Module):
    def __init__(self, module_config: dict):
        # nn.Module must be initialized before any submodule is assigned
        super().__init__()

        # typically here you have an epoch_encoder and a sequence_encoder;
        # replace the placeholders with your own layers
        self.epoch_encoder = ...
        self.sequence_encoder = ...

    def forward(self, x: torch.Tensor):
        # the forward pass returns only the predictions
        encodings, preds = self.encode(x)
        return preds

    def encode(self, x: torch.Tensor):
        # get your latent-space encodings
        encodings = ...

        # get your predictions out of the encodings
        preds = ...

        return encodings, preds

class CustomModule(SleepModule):
    def __init__(self, module_config: dict):
        # wrap the network and pass the configuration on to SleepModule
        super().__init__(CustomNet(module_config), module_config)
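The template above leaves the layers as placeholders. The following sketch shows one possible concrete CustomNet, under these assumptions:

  • raw input of shape (batch, sequence_length, channels, samples);
  • a single convolutional epoch encoder;
  • a bidirectional GRU sequence encoder.

The module_config keys in_channels and hidden_size are hypothetical; only n_classes comes from this page. The sketch depends only on PyTorch, so it can be tried without PhysioEx installed.

```python
import torch
import torch.nn as nn


class CustomNet(nn.Module):
    """Toy epoch-encoder + sequence-encoder network (illustrative sketch).

    Assumes inputs of shape (batch, seq_len, in_channels, n_samples), i.e. raw
    30-second epochs. The keys "in_channels" and "hidden_size" are hypothetical;
    only "n_classes" comes from the PhysioEx documentation.
    """

    def __init__(self, module_config: dict):
        super().__init__()
        hidden = module_config["hidden_size"]
        # Epoch encoder: one conv layer + global average pooling -> one vector per epoch.
        self.epoch_encoder = nn.Sequential(
            nn.Conv1d(module_config["in_channels"], hidden, kernel_size=50, stride=6),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        )
        # Sequence encoder: bidirectional GRU over the per-epoch vectors.
        self.sequence_encoder = nn.GRU(hidden, hidden, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * hidden, module_config["n_classes"])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, preds = self.encode(x)
        return preds

    def encode(self, x: torch.Tensor):
        batch, seq_len, channels, samples = x.shape
        # Fold the sequence into the batch so every epoch is encoded independently.
        epochs = x.reshape(batch * seq_len, channels, samples)
        epoch_emb = self.epoch_encoder(epochs).squeeze(-1)  # (batch*seq_len, hidden)
        epoch_emb = epoch_emb.reshape(batch, seq_len, -1)   # (batch, seq_len, hidden)
        encodings, _ = self.sequence_encoder(epoch_emb)     # (batch, seq_len, 2*hidden)
        preds = self.classifier(encodings)                  # (batch, seq_len, n_classes)
        return encodings, preds


config = {"in_channels": 1, "hidden_size": 32, "n_classes": 5}
net = CustomNet(config)
x = torch.randn(2, 7, 1, 3000)  # 2 sequences of 7 epochs, 3000 samples each (100 Hz assumed)
encodings, preds = net.encode(x)
print(encodings.shape, preds.shape)  # torch.Size([2, 7, 64]) torch.Size([2, 7, 5])
```

Because encode returns one prediction per epoch, with shape (batch, sequence_length, n_classes), this is a sequence-to-sequence model and would need no target transform. To wrap it in a SleepModule subclass, follow the CustomModule template above.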

The SleepModule needs to know n_classes (for sleep staging this is typically 5) and the loss to compute during training. The loss functions in PhysioEx (see physioex.train.networks.utils.loss) take a Python dict in their constructor. Your module_config should therefore always specify n_classes, loss_call, and loss_params.
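The sketch below shows a module_config that meets these requirements, together with a hypothetical validation helper. MyLoss stands in for a PhysioEx loss class, and validate_module_config is not a library function. The last line assumes the loss is built as loss_call(loss_params), following the paragraph above's statement that PhysioEx losses take a dict in their constructor.

```python
REQUIRED_KEYS = ("n_classes", "loss_call", "loss_params")


class MyLoss:
    """Hypothetical stand-in for a PhysioEx loss: takes a params dict in its constructor."""

    def __init__(self, params: dict):
        self.params = params


def validate_module_config(module_config: dict) -> None:
    """Hypothetical helper: fail early if a key SleepModule needs is missing or malformed."""
    missing = [key for key in REQUIRED_KEYS if key not in module_config]
    if missing:
        raise KeyError(f"module_config is missing: {', '.join(missing)}")
    if not callable(module_config["loss_call"]):
        raise TypeError("loss_call must be callable")
    if not isinstance(module_config["loss_params"], dict):
        raise TypeError("loss_params must be a dict")


module_config = {
    "n_classes": 5,       # five sleep stages
    "loss_call": MyLoss,  # placeholder loss class
    "loss_params": {},    # forwarded to the loss constructor
}
validate_module_config(module_config)
# Assumed construction pattern: the loss class receives the params dict.
loss = module_config["loss_call"](module_config["loss_params"])
```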

SleepModule

Bases: LightningModule

A PyTorch Lightning module for sleep stage classification and regression tasks.

This module is designed to handle both classification and regression experiments for sleep stage analysis. It leverages PyTorch Lightning for training, validation, and testing, and integrates various metrics for performance evaluation.

Parameters:

  • `nn` (Module, required): The neural network model to be used for sleep stage analysis.
  • `config` (Dict, required): A dictionary containing configuration parameters for the module. It must include:
      ◦ n_classes (int): The number of classes for classification tasks. If n_classes is 1, the module performs regression.
      ◦ loss_call (callable): A callable that returns the loss function.
      ◦ loss_params (dict): A dictionary of parameters to be passed to the loss function.

Attributes:

  • `nn` (Module): The neural network model.
  • `n_classes` (int): The number of classes for classification tasks.
  • `loss` (callable): The loss function.
  • `module_config` (Dict): The configuration dictionary.
  • `learning_rate` (float): The learning rate for the optimizer. Default is 1e-4.
  • `weight_decay` (float): The weight decay for the optimizer. Default is 1e-6.
  • `val_loss` (float): The best validation loss observed during training.

Example
import torch.nn as nn
from physioex.train.networks.base import SleepModule

config = {
    "n_classes": 5,
    "loss_call": nn.CrossEntropyLoss,
    "loss_params": {}
}

# YourNeuralNetwork is a placeholder for your own torch.nn.Module
model = SleepModule(nn=YourNeuralNetwork(), config=config)
Notes
  • This module supports both classification and regression tasks. The behavior is determined by the n_classes entry of the config dictionary (n_classes equal to 1 selects regression).
  • Various metrics are logged during training, validation, and testing to monitor performance.
  • The learning rate scheduler is configured to reduce the learning rate when the validation loss plateaus.

configure_optimizers()

Configures the optimizer and learning rate scheduler.

Returns:

  • Tuple[List[Optimizer], List[Dict]]: A tuple containing the optimizer and the learning rate scheduler.
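The sketch below expresses, in plain PyTorch, the scheduling behavior described in the notes (reduce the learning rate when the validation loss plateaus) together with the default learning_rate and weight_decay attributes. The choice of Adam, the factor and patience values, and the monitored metric name "val_loss" are assumptions for illustration. The actual SleepModule implementation may differ.

```python
import torch
import torch.nn as nn


def configure_optimizers_sketch(model: nn.Module, learning_rate: float = 1e-4,
                                weight_decay: float = 1e-6):
    """Illustrative sketch; the optimizer choice (Adam) and scheduler settings are assumptions."""
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Reduce the learning rate when the monitored validation loss stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10
    )
    # Lightning accepts ([optimizers], [lr_scheduler_config]); "monitor" names the logged metric.
    return [optimizer], [{"scheduler": scheduler, "monitor": "val_loss"}]


optimizers, schedulers = configure_optimizers_sketch(nn.Linear(8, 5))
```

ReduceLROnPlateau needs a monitored metric to decide when to step. That is why the scheduler is returned inside a dict with a "monitor" key rather than on its own.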

forward(x)

Defines the forward pass of the neural network.

Parameters:

  • `x` (Tensor, required): The input tensor.

Returns:

  • torch.Tensor: The output tensor.

encode(x)

Encodes the input data using the neural network.

Parameters:

  • `x` (Tensor, required): The input tensor.

Returns:

  • torch.Tensor: The encoded tensor.

compute_loss(embeddings, outputs, targets, log="train", log_metrics=False)

Computes the loss and logs metrics.

Parameters:

  • `embeddings` (Tensor, required): The embeddings tensor.
  • `outputs` (Tensor, required): The outputs tensor.
  • `targets` (Tensor, required): The targets tensor.
  • `log` (str, optional): The log prefix. Defaults to "train".
  • `log_metrics` (bool, optional): Whether to log additional metrics. Defaults to False.

Returns:

  • torch.Tensor: The computed loss.
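For a sequence-to-sequence model, a common way to compute the loss is to treat every epoch of every sequence as an independent sample by flattening the batch and sequence dimensions. The sketch below illustrates that idea. sequence_loss is hypothetical, not PhysioEx's compute_loss. Using mean squared error for the regression case (n_classes equal to 1) is an assumption.

```python
import torch
import torch.nn.functional as F


def sequence_loss(outputs: torch.Tensor, targets: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Hypothetical sketch, not PhysioEx's compute_loss.

    outputs: (batch, seq_len, n_classes); targets: (batch, seq_len).
    """
    if n_classes == 1:
        # Regression (n_classes == 1): MSE is an assumed choice of loss.
        return F.mse_loss(outputs.reshape(-1), targets.reshape(-1).float())
    # Classification: every epoch of every sequence counts as one sample.
    return F.cross_entropy(outputs.reshape(-1, n_classes), targets.reshape(-1))


outputs = torch.randn(2, 7, 5)          # logits for 2 sequences of 7 epochs
targets = torch.randint(0, 5, (2, 7))   # one sleep stage label per epoch
print(sequence_loss(outputs, targets, n_classes=5))
```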

training_step(batch, batch_idx)

Defines a single training step.

Parameters:

  • `batch` (Tuple[Tensor, Tensor], required): The input batch.
  • `batch_idx` (int, required): The index of the batch.

Returns:

  • torch.Tensor: The computed loss.

validation_step(batch, batch_idx)

Defines a single validation step.

Parameters:

  • `batch` (Tuple[Tensor, Tensor], required): The input batch.
  • `batch_idx` (int, required): The index of the batch.

Returns:

  • torch.Tensor: The computed loss.

test_step(batch, batch_idx)

Defines a single test step.

Parameters:

  • `batch` (required): The input batch.
  • `batch_idx` (int, required): The index of the batch.

Returns:

  • torch.Tensor: The computed loss.
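The three step methods share the same structure: unpack the (inputs, targets) batch, encode the inputs, and compute the loss. The sketch below shows that shared logic in plain PyTorch. TinyNet and shared_step are hypothetical stand-ins, and cross-entropy replaces the configured PhysioEx loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Minimal stand-in network exposing encode() -> (embeddings, outputs)."""

    def __init__(self, in_features: int = 16, n_classes: int = 5):
        super().__init__()
        self.encoder = nn.Linear(in_features, 8)
        self.head = nn.Linear(8, n_classes)

    def encode(self, x: torch.Tensor):
        embeddings = torch.relu(self.encoder(x))
        return embeddings, self.head(embeddings)


def shared_step(model: nn.Module, batch, n_classes: int = 5) -> torch.Tensor:
    """Hypothetical sketch of the logic the training/validation/test steps have in common."""
    inputs, targets = batch                     # batch is (inputs, targets)
    embeddings, outputs = model.encode(inputs)  # embeddings would also go to compute_loss
    return F.cross_entropy(outputs.reshape(-1, n_classes), targets.reshape(-1))


model = TinyNet()
batch = (torch.randn(4, 7, 16), torch.randint(0, 5, (4, 7)))
loss = shared_step(model, batch)
loss.backward()  # in training_step, Lightning runs this backward pass for you
```

In a LightningModule, training_step would return this loss so that Lightning can run the backward pass and optimizer step. validation_step and test_step compute the same quantity, but Lightning runs them without gradient updates.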