# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Common utilities."""
from __future__ import annotations
import mne
import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import ConcatDataset, Dataset
from .sample import DelayedSample, Sample
[docs]
def plot_PSG_and_annotation(raw_obj, annot_obj):
    """Plot raw signal with annotation using MNE plot function.

    The annotations are attached to ``raw_obj`` in place (without emitting a
    warning) and a 30 second window starting at second 1200 is plotted.

    :param raw_obj: object exposing ``set_annotations`` and ``plot``
        (presumably an MNE ``Raw`` instance -- confirm against callers)
    :param annot_obj: annotations to attach to ``raw_obj``
    :return: whatever ``raw_obj.plot`` returns
    """
    raw_obj.set_annotations(annot_obj, emit_warning=False)
    # NOTE(review): the EMG scaling used to be written ``10 - 7`` which
    # evaluates to the integer 3; ``1e-7`` is assumed to be the intended value,
    # in line with the scientific notation used for the other channel types.
    scalings = {"eeg": 2e-4, "eog": 1e-4, "emg": 1e-7}
    return raw_obj.plot(
        start=1200,
        duration=30,
        scalings=scalings,
    )
[docs]
def saving_preprocess(dataset):
    """Access the ``data`` attribute of every sample held by ``dataset``.

    ``dataset`` is walked through ``items()``; each value is iterated and the
    ``data`` attribute of every element is read once. The values themselves
    are discarded, only the attribute access matters (presumably it triggers
    loading/caching of delayed samples -- confirm against the sample classes).

    :param dataset: mapping whose values are iterables of samples
    :return: None
    """
    for _subset_name, samples in dataset.items():
        for sample in samples:
            # Only the side effect of the attribute access is wanted.
            _ = sample.data
    return None
[docs]
class ListSampleDataset(Dataset):
    """PyTorch dataset wrapper around the epochs of one sample.

    This class takes a :py:class:`DelayedSample` or :py:class:`Sample` object
    and exposes it as a :py:class:`torch.utils.data.dataset.Dataset` where
    every item of the dataset is a list. It supports indexing such that
    ``dataset[i]`` can be used to get the i-th item.

    :param sample: sample to wrap into the dataset object; its ``data`` entry
        must provide the ``"label"`` and ``"data"`` keys
    :param normalize: if set to True, the mean of the data is removed and the
        result is divided by the standard deviation, both computed along the
        last axis (i.e. epoch (window) wise); positions where the standard
        deviation is zero are set to zero
    :param pick_chan: keyword arguments forwarded to
        :py:func:`mne.pick_types` to select the channels (channel types set
        to True and/or an ``"include"`` list of channel names); if None, the
        ``"eeg"`` picks are used
    :param n_past_epochs: it indicates how many past epochs to
        concatenate to the current sample epoch. E.g. if n_past_epochs =
        1, then one past epoch is concatenated to the current.
    """

    def __init__(
        self,
        sample: DelayedSample | Sample,
        normalize: bool,
        pick_chan: dict[str, bool | list[str]] | None,
        n_past_epochs: int = 0,
    ):
        self.key = sample.key
        # Fetch the payload a single time and reuse it for labels and epochs.
        payload = sample.data
        self.labels = payload["label"]
        epochs = payload["data"]

        pick_chan_idx = "eeg"
        if pick_chan is not None:
            pick_chan_idx = mne.pick_types(info=epochs.info, **pick_chan)

        # Indexed below as (epoch, channel, time).
        data = epochs.get_data(picks=pick_chan_idx)
        self.data = data
        self.n_past_epochs = n_past_epochs

        # If normalize is true remove the mean of the data and divide by the
        # standard deviation; the result is set to zero wherever the standard
        # deviation is equal to zero.
        if normalize:
            data_tensor = torch.from_numpy(data)
            mean_data = torch.mean(data_tensor, dim=-1, keepdim=True)
            std_data = torch.std(data_tensor, dim=-1, keepdim=True)
            self.data = torch.where(
                std_data != 0,
                (data_tensor - mean_data) / std_data,
                # zeros_like keeps dtype/device aligned with the data so that
                # torch.where never has to reconcile mismatched operands.
                torch.zeros_like(data_tensor),
            )

    def __len__(self) -> int:
        """
        :return: The size of the dataset
        """
        return len(self.labels)

    def __getitem__(self, idx: int) -> list:
        """Return the epoch ``idx`` preceded by its ``n_past_epochs`` epochs.

        Past epochs that would fall before the first epoch are zero padded.
        When ``n_past_epochs`` is greater than zero, the epochs are
        concatenated along the last (time) axis, channel by channel, giving a
        tensor of shape ``(n_channels, (n_past_epochs + 1) * n_times)``. When
        it is zero the single epoch is returned with every dimension of size
        one squeezed out.

        :param idx: int
        :return: sample data with following structure
            [key of the sample, data, label, epoch index]
        """
        n_context = self.n_past_epochs + 1
        n_chan, n_times = self.data.shape[1], self.data.shape[2]

        # Zero-initialised so that slots located before the first epoch stay
        # as padding.
        window = torch.zeros((n_context, n_chan, n_times))
        first_epoch = idx - self.n_past_epochs
        for slot in range(n_context):
            epoch_idx = first_epoch + slot
            if epoch_idx >= 0:
                window[slot] = torch.as_tensor(self.data[epoch_idx])

        if n_context > 1:
            # Bring the channel axis first *before* flattening; reshaping the
            # (epoch, channel, time) tensor directly would interleave the
            # channels as soon as there is more than one of them.
            window = window.permute(1, 0, 2).reshape(
                n_chan, n_context * n_times
            )
            return [self.key, window, self.labels[idx], idx]

        # Single epoch: drop the leading context axis (and, as before, any
        # other axis of size one).
        return [self.key, window.squeeze(), self.labels[idx], idx]
[docs]
def get_samples_weights(dataset: ConcatDataset | Dataset):
    """Compute class weights to balance the cross-entropy criterion.

    The labels of the dataset (or of every member dataset when a
    :py:class:`torch.utils.data.ConcatDataset` is given) are gathered and
    passed to :py:func:`sklearn.utils.class_weight.compute_class_weight` with
    the ``"balanced"`` strategy.

    :param dataset: torch.utils.data.dataset.Dataset
        A dataset exposing a ``labels`` attribute, or a ConcatDataset whose
        members each expose a ``labels`` attribute
    :return: :py:class:`torch.Tensor`
        float32 tensor holding one weight per class, ordered like
        ``np.unique`` of the gathered labels
    """
    # A ConcatDataset lists its members in ``datasets``; a plain dataset has
    # no such attribute and is handled as a single member.
    members = getattr(dataset, "datasets", None)
    if members is None:
        members = [dataset]

    train_y = np.concatenate([ds.labels for ds in members])
    class_weights = compute_class_weight(
        "balanced", classes=np.unique(train_y), y=train_y
    )
    return torch.tensor(class_weights, dtype=torch.float32)
[docs]
def chan_list_to_dict(chan_list: list[str]) -> dict[str, bool | list[str]]:
    """Split a mixed list of channel types and channel names.

    The result is meant to be passed as keyword arguments to
    :py:func:`mne.pick_types`.

    :param chan_list: list of channels, each entry being either a channel
        type or a channel name
    :return: dictionary in which every recognized channel type (matched case
        insensitively) is a key set to True; every other entry is kept
        unchanged, as a channel name, in the list stored under the
        ``"include"`` key. The ``"include"`` key is always present, even when
        its list is empty.
    """
    known_types = (
        "eeg",
        "seeg",
        "ecog",
        "dbs",
        "eog",
        "ecg",
        "emg",
        "bio",
        "resp",
        "temp",
        "misc",
        "sao2",
    )
    # Lower-case every entry once, keeping the original spelling for names.
    entries = [(name, name.lower()) for name in chan_list]

    picks: dict[str, bool | list[str]] = {
        lowered: True for _, lowered in entries if lowered in known_types
    }
    picks["include"] = [
        name for name, lowered in entries if lowered not in known_types
    ]
    return picks