Source code for sleepless.models.chambon2018

# SPDX-FileCopyrightText: Copyright (c) 2017-2020 Braindecode Developers
#
# SPDX-FileContributor: Hubert Banville <hubert.jbanville@gmail.com>
#
# SPDX-License-Identifier: BSD-3-Clause

from torch import nn


class SleepStagerChambon2018(nn.Module):
    """Sleep staging architecture from [Chambon-2018]_.

    This class was copied with minor modifications from
    https://github.com/braindecode/braindecode/blob/master/braindecode/models/sleep_stager_chambon_2018.py
    v0.7.0

    Modification: the condition to return the features extracted before
    classification was removed; they are now always returned.

    Convolutional neural network for sleep staging described in
    [Chambon-2018]_.

    Parameters
    ----------
    n_channels : int
        Number of EEG channels.
    sfreq : float
        EEG sampling frequency.
    n_conv_chs : int
        Number of convolutional channels. Set to 8 in [Chambon-2018]_.
    time_conv_size_s : float
        Size of filters in temporal convolution layers, in seconds. Set to 0.5
        in [Chambon-2018]_ (64 samples at sfreq=128).
    max_pool_size_s : float
        Max pooling size, in seconds. Set to 0.125 in [Chambon-2018]_ (16
        samples at sfreq=128).
    n_classes : int
        Number of classes.
    input_size_s : float
        Size of the input, in seconds.
    dropout : float
        Dropout rate before the output dense layer.

    References
    ----------
    [Chambon-2018]_
    """

    def __init__(
        self,
        n_channels,
        sfreq,
        n_conv_chs=8,
        time_conv_size_s=0.5,
        max_pool_size_s=0.125,
        n_classes=5,
        input_size_s=30,
        dropout=0.25,
    ):
        super().__init__()

        # Durations in seconds are converted to sample counts (truncated).
        time_conv_size = int(time_conv_size_s * sfreq)
        max_pool_size = int(max_pool_size_s * sfreq)
        input_size = int(input_size_s * sfreq)
        pad_size = time_conv_size // 2

        self.n_channels = n_channels

        # Pass the kernel size so that the flattened feature length is derived
        # from the actual conv/pool output sizes rather than the shortcut
        # ``input_size // max_pool_size**2``, which is only correct for some
        # combinations of input length, kernel size and pooling size.
        len_last_layer = self._len_last_layer(
            n_channels,
            input_size,
            max_pool_size,
            n_conv_chs,
            time_conv_size=time_conv_size,
        )

        # The spatial convolution mixes the input channels into ``n_channels``
        # virtual channels; it is skipped for single-channel inputs.
        if n_channels > 1:
            self.spatial_conv = nn.Conv2d(1, n_channels, (n_channels, 1))

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(
                1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)
            ),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size)),
            nn.Conv2d(
                n_conv_chs,
                n_conv_chs,
                (1, time_conv_size),
                padding=(0, pad_size),
            ),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size)),
        )
        self.fc = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(len_last_layer, n_classes)
        )

    @staticmethod
    def _len_last_layer(
        n_channels,
        input_size,
        max_pool_size,
        n_conv_chs,
        time_conv_size=None,
    ):
        """Return the length of the flattened feature vector.

        Parameters
        ----------
        n_channels : int
            Number of EEG channels.
        input_size : int
            Number of time samples in one input window.
        max_pool_size : int
            Max pooling size (and stride), in samples.
        n_conv_chs : int
            Number of convolutional channels.
        time_conv_size : int, optional
            Temporal kernel size, in samples. When given, the length is
            computed stage by stage from the convolution and pooling output
            sizes. When omitted, the legacy approximation
            ``input_size // max_pool_size**2`` is used.

        Returns
        -------
        int
            ``n_channels * n_times_out * n_conv_chs``.
        """
        if time_conv_size is None:
            # Legacy shortcut, kept for callers using the 4-argument form.
            n_times = input_size // (max_pool_size**2)
        else:
            pad_size = time_conv_size // 2
            n_times = input_size
            # Two identical (conv -> max-pool) stages.
            for _ in range(2):
                # Stride-1 convolution with symmetric padding: an even kernel
                # size lengthens the signal by one sample.
                n_times = n_times + 2 * pad_size - time_conv_size + 1
                # Non-overlapping max pooling floors the length.
                n_times = n_times // max_pool_size
        return n_channels * n_times * n_conv_chs

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).
            A 3-D input gets a singleton dimension inserted at position 1;
            inputs of any other rank are passed on unchanged.

        Returns
        -------
        tuple of torch.Tensor
            The output of the classification layer, followed by the flattened
            features extracted before classification.
        """
        if x.ndim == 3:
            # (batch, n_channels, n_times) -> (batch, 1, n_channels, n_times)
            x = x.unsqueeze(1)

        if self.n_channels > 1:
            x = self.spatial_conv(x)
            # Move the virtual channels back to the height axis so that the
            # temporal convolutions see a single input plane.
            x = x.transpose(1, 2)

        x = self.feature_extractor(x)
        x = x.flatten(start_dim=1)

        # we are always returning the features extracted before classification
        return self.fc(x), x