# Source code for sleepless.data.transforms
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
# For Class RawToEpochs, Class EEGPowerBand and Class FeatureExtractorChambon:
# SPDX-FileCopyrightText: Copyright © 2011-2022, authors of MNE-Python
#
# SPDX-FileContributor: Alexandre Gramfort <alexandre.gramfort@inria.fr>
# SPDX-FileContributor: Stanislas Chambon <stan.chambon@gmail.com>
# SPDX-FileContributor: Joan Massich <mailsik@gmail.com>
#
# SPDX-License-Identifier: BSD-3-Clause
"""Signal transformations for our pipelines."""
from __future__ import annotations
import logging
import mne
import numpy as np
import scipy
import torch
from mne import Annotations, Epochs, events_from_annotations
from mne.io.edf.edf import RawEDF
from torch.utils.data import ConcatDataset
from .sample import DelayedSample
from .utils import ListSampleDataset, chan_list_to_dict
logger = logging.getLogger(__name__)
class RawToEpochs:
    """Transform Raw and Annotation objects to an Epochs object, with some
    preprocessing options.

    Data are not loaded in memory until ``mne.Epochs.get_data()`` is called
    (the epochs are built with ``preload=False``).

    This class was inspired from version 1.4 of
    https://mne.tools/stable/auto_tutorials/clinical/60_sleep.html

    :param event_id: Map stage names (str keys) to integer event codes
        (int values)

    :param chunk_duration: The window duration (in seconds) that was used to
        annotate the recording

    :param picks_chan: Channels to keep (e.g. ``['eeg', 'eog', 'emg']``).
        The value is forwarded unchanged as the ``picks`` argument of
        ``mne.Epochs``; defaults to ``None``

    :param no_overlapping: If True (default), each epoch is shortened by one
        sampling period so that consecutive epochs do not share a sample

    :param crop_wake_time: Wake time (in minutes) to keep at the beginning
        and at the end of the recording. It can be useful to crop part of
        the wake time when it is too long compared to the other stages or
        not useful (e.g. walking time). ``0`` (default) disables cropping

    :param wake_stage_name: Annotation description of the wake stage (e.g.
        ``"Sleep stage W"``). Only needed when ``crop_wake_time > 0``
    """

    def __init__(
        self,
        event_id: dict[str, int],
        chunk_duration: float,
        picks_chan: str | list[str] | None = None,
        no_overlapping: bool = True,
        crop_wake_time: float = 0.0,
        wake_stage_name: str | None = None,
    ):
        self.event_id = event_id
        self.chunk_duration = chunk_duration
        self.picks_chan = picks_chan
        self.no_overlapping = no_overlapping
        self.crop_wake_time = crop_wake_time
        self.wake_stage_name = wake_stage_name

    def __call__(self, samples: list[DelayedSample]) -> list[DelayedSample]:
        """Compute the epochs of every sample and store them on the sample.

        :param samples: Samples whose ``data`` mapping provides the raw
            recording under ``"data"`` and its annotations under ``"label"``

        :return: The same list, each sample carrying the computed
            ``mne.Epochs`` in its ``epochs`` attribute
        """
        for index, sample in enumerate(samples, 1):
            logger.info(
                f"computing epochs for sample {index} on {len(samples)} samples"
            )
            sample.epochs = self.transform(
                sample.data["data"], sample.data["label"]
            )
        return samples

    def transform(self, raw_obj: RawEDF, annot_obj: Annotations) -> Epochs:
        """Return the mne.Epochs object from raw data and label of a sample.

        ``raw_obj.set_annotations`` is called on the given object and, when
        cropping is enabled, so is ``raw_obj.crop``.

        :param raw_obj: A mne raw object

        :param annot_obj: A mne object annotating segments of the raw data

        :return: A mne Epochs object

        :raises ValueError: If cropping is requested but ``wake_stage_name``
            is not set, or if no annotation matches ``wake_stage_name``
        """
        cropping = self.crop_wake_time > 0

        # Fail before touching raw_obj: without a wake label the crop
        # boundaries cannot be located at all.
        if cropping and self.wake_stage_name is None:
            raise ValueError(
                "wake_stage_name need to be defined for cropping process"
            )

        raw_obj.set_annotations(annot_obj, emit_warning=True, verbose=True)

        if cropping:
            self._crop_wake_time(raw_obj, annot_obj)

        events, map_event_id = events_from_annotations(
            raw_obj,
            event_id=self.event_id,
            chunk_duration=self.chunk_duration,
            verbose=True,
        )

        epoch_tmax = self.chunk_duration
        if self.no_overlapping:
            # Drop the last sample so consecutive epochs do not overlap.
            epoch_tmax = self.chunk_duration - 1.0 / raw_obj.info["sfreq"]

        return Epochs(
            raw=raw_obj,
            events=events,
            event_id=map_event_id,
            picks=self.picks_chan,
            tmin=0.0,
            tmax=epoch_tmax,
            preload=False,
            baseline=None,
            verbose=True,
            on_missing="warn",
        )

    def _crop_wake_time(self, raw_obj: RawEDF, annot_obj: Annotations) -> None:
        """Crop ``raw_obj`` to keep ``crop_wake_time`` minutes of wake time
        on each side of the period delimited by the wake annotations."""
        mask = [x == self.wake_stage_name for x in annot_obj.description]
        wake_event_inds = np.where(mask)[0]
        if wake_event_inds.size == 0:
            raise ValueError(
                f"no annotation matches wake_stage_name "
                f"{self.wake_stage_name!r}, cannot crop wake time"
            )

        margin = self.crop_wake_time * 60  # minutes -> seconds

        # Start: onset of the annotation following the first wake
        # annotation, minus the margin, clamped to the recording start.
        tmin = annot_obj[int(wake_event_inds[0] + 1)]["onset"] - margin
        tmin = max(raw_obj.times[0], tmin)

        # End: onset of the last wake annotation, plus the margin, clamped
        # to the recording end.
        tmax = annot_obj[int(wake_event_inds[-1])]["onset"] + margin
        tmax = min(tmax, raw_obj.times[-1])

        raw_obj.crop(tmin=tmin, tmax=tmax)
class EEGPowerBand:
    """Extract relative band-power features from a
    :py:class:`DelayedSample` list.

    This class was copied and modified from
    https://mne.tools/stable/auto_tutorials/clinical/60_sleep.html v1.4

    Modification: change to class, adding pick_chan attribute, management
    of division by zero

    :param pick_chan: the channel types (e.g. "eeg", "eog") or names (e.g.
        "Fpz-Cz") from which to extract the features. If None (default),
        the features are computed with ``picks="eeg"``.

    :return: :py:class:`DelayedSample` list where features have been
        extracted
    """

    def __init__(self, pick_chan: list[str] = None):
        # The channel list is converted once into the keyword mapping later
        # expanded into ``mne.pick_types``.
        self.pick_chan = (
            None if pick_chan is None else chan_list_to_dict(pick_chan)
        )

    def __call__(self, samples: list[DelayedSample]) -> list[DelayedSample]:
        """Attach ``features`` and ``label`` to every sample.

        :param samples: Samples whose ``data["data"]`` is an ``mne.Epochs``

        :return: The same list, each sample carrying its feature matrix in
            ``features`` and the third column of the epochs' ``events``
            array in ``label``
        """
        for index, sample in enumerate(samples, 1):
            logger.info(
                f"computing frequency decomposition for sample {index} on {len(samples)} samples"
            )
            # Features are computed first; the labels are read afterwards.
            features = self._transform(sample.data["data"])
            labels = sample.data["data"].events[:, 2]
            sample.features = features
            sample.label = labels
        return samples

    def _transform(self, epochs: Epochs) -> np.ndarray:
        """EEG relative power band feature extraction.

        Takes an ``mne.Epochs`` object and builds features from the relative
        power found in five fixed frequency bands.

        :param epochs: mne.Epochs of a sample

        :return: Transformed data of shape [n_epochs, n_channel*5], the
            columns being grouped band by band
        """
        # (low, high) edges in Hz; a frequency belongs to a band when
        # low <= f < high.
        band_edges = (
            (0.5, 4.5),  # delta
            (4.5, 8.5),  # theta
            (8.5, 11.5),  # alpha
            (11.5, 15.5),  # sigma
            (15.5, 30),  # beta
        )

        if self.pick_chan is None:
            picks = "eeg"
        else:
            picks = mne.pick_types(info=epochs.info, **self.pick_chan)

        spectrum = epochs.compute_psd(fmin=0.5, fmax=30.0, picks=picks)
        psds, freqs = spectrum.get_data(return_freqs=True, picks="all")

        # Normalise each spectrum by its total power; spectra whose total is
        # not positive keep the zeros they were initialised with.
        total_power = np.sum(psds, axis=-1, keepdims=True)
        relative_psds = np.zeros(psds.shape)
        np.divide(psds, total_power, out=relative_psds, where=total_power > 0)

        n_epochs = len(psds)
        band_columns = []
        for low, high in band_edges:
            in_band = (freqs >= low) & (freqs < high)
            band_mean = relative_psds[:, :, in_band].mean(axis=-1)
            band_columns.append(band_mean.reshape(n_epochs, -1))

        return np.concatenate(band_columns, axis=1)
class ToTorchDataset:
    """Build a Torch dataset from a :py:class:`DelayedSample` list.

    :param normalize: If True, normalize the samples (the flag is forwarded
        to ``ListSampleDataset``)

    :param pick_chan: the channel types (e.g. "eeg", "eog") or names (e.g.
        "Fpz-Cz") to use. The list is converted with ``chan_list_to_dict``
        and forwarded to ``ListSampleDataset``; None (default) is forwarded
        as is.

    :param n_past_epochs: number of preceding epochs to include in the
        ``ListSampleDataset`` object (by concatenation).

    :return: :py:class:`torch.utils.data.dataset.ConcatDataset` of all
        samples
    """

    def __init__(
        self,
        normalize: bool = False,
        pick_chan: list[str] | None = None,
        n_past_epochs: int = 0,
    ):
        self.normalize = normalize
        self.pick_chan = None
        self.n_past_epochs = n_past_epochs
        if pick_chan is not None:
            self.pick_chan = chan_list_to_dict(pick_chan)

    def __call__(self, samples: list[DelayedSample]) -> ConcatDataset:
        """Wrap every sample in a ``ListSampleDataset`` and concatenate them.

        :param samples: Samples to convert

        :return: A ``ConcatDataset`` holding one ``ListSampleDataset`` per
            sample, in the order of ``samples``
        """
        samples_list = []
        for index, sample in enumerate(samples, 1):
            logger.info(
                f"making torch dataset for sample {index} on {len(samples)} samples"
            )
            samples_list.append(
                ListSampleDataset(
                    sample, self.normalize, self.pick_chan, self.n_past_epochs
                )
            )
        return ConcatDataset(samples_list)
class ResampleEpochs:
    """Resample the ``mne.Epochs`` object of each sample of a
    :py:class:`DelayedSample` list.

    :param sampling_freq: sampling frequency to resample to

    :return: :py:class:`DelayedSample` list with resampled data
    """

    def __init__(self, sampling_freq: int):
        self.sampling_freq = sampling_freq

    def __call__(self, samples: list[DelayedSample]) -> list[DelayedSample]:
        """Resample every sample whose sampling frequency differs from the
        requested one.

        :param samples: Samples whose ``data["data"]`` is an ``mne.Epochs``

        :return: The same list; ``data["data"]`` is replaced by the
            resampled epochs where resampling was needed
        """
        for index, sample in enumerate(samples, 1):
            logger.info(
                f"resampling for sample {index} on {len(samples)} samples"
            )
            freq_sample = sample.data["data"].info["sfreq"]

            if freq_sample == self.sampling_freq:
                # Already at the requested rate: leave the sample untouched.
                logger.info(
                    f" no need of resample as sampling frequency is already {freq_sample}"
                )
                continue

            logger.info(
                f"resampled as sampling frequency was {freq_sample} and {self.sampling_freq} is needed"
            )
            sample.data["data"] = self._transform(sample.data["data"])
        return samples

    def _transform(self, epochs: Epochs) -> Epochs:
        """Resample the epochs at the configured frequency.

        The epochs are loaded with ``load_data()`` before ``resample`` is
        called on the result.

        :param epochs: ``mne.Epochs`` object

        :return: ``mne.Epochs`` resampled
        """
        loaded = epochs.load_data()
        return loaded.resample(sfreq=self.sampling_freq)
class FeatureExtractorChambon:
    """Extract features from a :py:class:`DelayedSample` list.

    26 manually chosen features per channel: total power (5), relative
    power (5), power ratio (10), spectral entropy, mean, standard deviation,
    skewness, kurtosis and 75% quantile.

    This class was copied and modified from
    https://mne.tools/stable/auto_tutorials/clinical/60_sleep.html v1.4

    Modification: change to class, adding pick_chan attribute, management
    of division by zero, adding computation of more features

    Reference: [Chambon-2018]_

    :param pick_chan: the channel types (e.g. "eeg", "eog") or names (e.g.
        "Fpz-Cz") from which to extract the features. If None (default),
        the features are computed with ``picks="eeg"``.

    :return: :py:class:`DelayedSample` list where features have been
        extracted
    """

    def __init__(self, pick_chan: list[str] = None):
        # The channel list is converted once into the keyword mapping later
        # expanded into ``mne.pick_types``.
        self.pick_chan = (
            None if pick_chan is None else chan_list_to_dict(pick_chan)
        )

    def __call__(self, samples: list[DelayedSample]) -> list[DelayedSample]:
        """Attach ``features`` and ``label`` to every sample.

        :param samples: Samples whose ``data["data"]`` is an ``mne.Epochs``

        :return: The same list, each sample carrying its feature matrix in
            ``features`` and the third column of the epochs' ``events``
            array in ``label``
        """
        for index, sample in enumerate(samples, 1):
            logger.info(
                f"computing frequency decomposition for sample {index} on {len(samples)} samples"
            )
            # Features are computed first; the labels are read afterwards.
            features = self._transform(sample.data["data"])
            labels = sample.data["data"].events[:, 2]
            sample.features = features
            sample.label = labels
        return samples

    @staticmethod
    def _band_means(spectra: np.ndarray, band_masks: list) -> list:
        """Average ``spectra`` over the frequencies selected by each mask,
        returning one [n_epochs, n_channel] array per band."""
        n_epochs = len(spectra)
        return [
            spectra[:, :, in_band].mean(axis=-1).reshape(n_epochs, -1)
            for in_band in band_masks
        ]

    def _transform(self, epochs: Epochs) -> np.ndarray:
        """Spectral and temporal feature extraction.

        Takes an ``mne.Epochs`` object and stacks, in this order: the mean
        power in each band, the mean relative power in each band, the 10
        pairwise ratios of relative band powers, the spectral entropy, then
        the mean, standard deviation, skewness, kurtosis and 75% quantile of
        the time signal.

        :param epochs: mne.Epochs of a sample

        :return: Transformed data of shape [n_epochs, n_channel*26] (21
            spectral features and 5 temporal ones)
        """
        # (low, high) edges in Hz; a frequency belongs to a band when
        # low <= f < high.
        band_edges = (
            (0.5, 4.5),  # delta
            (4.5, 8.5),  # theta
            (8.5, 11.5),  # alpha
            (11.5, 15.5),  # sigma
            (15.5, 30),  # beta
        )

        if self.pick_chan is None:
            picks = "eeg"
        else:
            picks = mne.pick_types(info=epochs.info, **self.pick_chan)

        spectrum = epochs.compute_psd(fmin=0.5, fmax=30.0, picks=picks)
        psds, freqs = spectrum.get_data(return_freqs=True, picks="all")
        band_masks = [
            (freqs >= low) & (freqs < high) for low, high in band_edges
        ]

        # Total (spectral) power of each band.
        total_power_vec = np.concatenate(
            self._band_means(psds, band_masks), axis=1
        )

        # Relative power: spectra divided by their total power. Spectra
        # whose total is not positive keep the zeros they start with.
        psds_sum = np.sum(psds, axis=-1, keepdims=True)
        psds_norm = np.zeros(psds.shape)
        np.divide(psds, psds_sum, out=psds_norm, where=psds_sum > 0)
        relative_bands = self._band_means(psds_norm, band_masks)
        relative_power_vec = np.concatenate(relative_bands, axis=1)

        # Ratio of relative powers for each of the 10 unordered band pairs
        # (earlier band over later band); a zero denominator yields 0.
        ratio_bands = []
        for position, numerator in enumerate(relative_bands):
            for denominator in relative_bands[position + 1:]:
                ratio = np.zeros(numerator.shape)
                np.divide(
                    numerator, denominator, out=ratio, where=denominator != 0
                )
                ratio_bands.append(ratio)
        ratio_power_vec = np.concatenate(ratio_bands, axis=1)

        # Spectral entropy; log(0) terms are left at 0.
        # NOTE(review): the sum uses the natural logarithm while the
        # normalisation uses log2 -- confirm the mixed bases are intended.
        log_psds_norm = np.zeros(psds_norm.shape)
        np.log(psds_norm, out=log_psds_norm, where=psds_norm > 0)
        entropy_scale = -(1 / np.log2(psds_norm.shape[2]))
        spectral_entropy = entropy_scale * (psds_norm * log_psds_norm).sum(
            axis=-1
        )

        # Release the spectra before the time-domain data are loaded.
        del psds

        data = epochs.get_data(picks=picks)
        # NOTE(review): the reference feature list names a "variance", but
        # the standard deviation is what is computed here -- confirm.
        temporal_features = (
            np.mean(data, axis=-1),
            np.std(data, axis=-1),
            scipy.stats.mstats.skew(data, axis=-1).data,
            scipy.stats.mstats.kurtosis(data, axis=-1).data,
            np.quantile(data, 0.75, axis=-1),
        )
        del data

        return np.hstack(
            (
                total_power_vec,
                relative_power_vec,
                ratio_power_vec,
                spectral_entropy,
                *temporal_features,
            )
        )