Source code for sleepless.models.ninel_1d_cnn_ssc

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

from torch import nn


class _CnnBlock(nn.Module):
    """Convolutional neural network block, composed of 1d-cnn layer, one batch
    normalization layer and one ReLu activation layer.

    Parameters
    ----------
    n_channels : int
        Number of input channels.

    n_out : int
        Number of output channels.

    kernel_size_cnn : int
        kernel size

    stride_cnn : int
        stride size

    padding_cnn : int
        padding size
    """

    def __init__(
        self,
        n_channels: int,
        n_out: int,
        kernel_size_cnn: int | tuple,
        stride_cnn: int,
        padding_cnn: int | str = 0,
    ):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv1d(
                in_channels=n_channels,
                out_channels=n_out,
                kernel_size=kernel_size_cnn,
                stride=stride_cnn,
                padding=padding_cnn,
            ),
            nn.BatchNorm1d(num_features=n_out),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.cnn(x)
        return x


[docs] class NineL_1DCNN_SSC(nn.Module): """Sleep staging architecture from [Satapathy-2023]_. Convolutional neural network for sleep staging described in [Satapathy-2023]_. Parameters ---------- n_channels : int Number of EEG channels. n_classes : int Number of classes. n_hidden_state : int output size of the feature extractor References ---------- .. [Satapathy-2023]_ """ def __init__(self, n_channels: int, n_hidden_state: int, n_classes: int): super().__init__() self.cnn1 = _CnnBlock(n_channels, n_hidden_state * 16, 8, 4, 2) self.mp1 = nn.MaxPool1d(2, 2) self.feature_extractor = nn.Sequential( _CnnBlock(n_hidden_state * 16, n_hidden_state * 32, 3, 1, 1), nn.MaxPool1d(2, 2), _CnnBlock(n_hidden_state * 32, n_hidden_state * 64, 3, 1, 1), nn.MaxPool1d(2, 2), _CnnBlock(n_hidden_state * 64, n_hidden_state * 64, 3, 1, 1), nn.MaxPool1d(2, 2), _CnnBlock(n_hidden_state * 64, n_hidden_state * 64, 3, 1, 1), nn.MaxPool1d(2, 2), _CnnBlock(n_hidden_state * 64, n_hidden_state * 64, 3, 1, 1), nn.MaxPool1d(2, 2), _CnnBlock(n_hidden_state * 64, n_hidden_state * 64, 3, 1, 1), nn.MaxPool1d(2, 2), _CnnBlock(n_hidden_state * 64, n_hidden_state * 64, 3, 1, 1), nn.MaxPool1d(2, 2), _CnnBlock(n_hidden_state * 64, n_hidden_state * 64, 3, 1, 1), nn.MaxPool1d(2, 2), ) self.classifier = nn.Sequential( nn.Linear(n_hidden_state * 64, 100), nn.ReLU(), nn.Linear(100, n_classes), )
[docs] def forward(self, x): """Forward pass. :param x: Batch of signals windows of shape (batch_size, n_channels, n_times). """ if len(x) == 2: x = x.unsqueeze(1) x = self.cnn1(x) x = self.mp1(x) features = self.feature_extractor(x).flatten(start_dim=1) x = self.classifier(features) return x, features