# Source code for sleepless.utils.misclassification

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Compute misclassified Epochs and function to visualize them."""

from __future__ import annotations

import matplotlib.pyplot as plt
import mne.epochs
import numpy as np

from matplotlib.figure import Figure

from ..data.sample import DelayedSample
from .matplotlib_utils import PointBrowser, get_sleep_stage_labels, stage_colors


def compute_misclassification(
    sample: DelayedSample, return_well_classified: bool = False
) -> tuple[mne.Epochs, dict[str, np.ndarray]]:
    """Compute the misclassified (or well classified) epochs of a sample.

    The predicted class of each epoch is the arg-max of
    ``sample.output_prob`` along axis 1; it is compared element-wise with
    ``sample.label``.

    :param sample: a sample of a dataset. It must expose ``output_prob``
        (indexed as ``[epoch, class]``), ``label`` (one true class per
        epoch, any array-like) and ``data["data"]`` (the epochs, indexed
        along the first axis).
    :param return_well_classified: if True, the function returns the well
        classified epochs instead of the misclassified ones.

    :return: a tuple ``(epochs, info)``. ``epochs`` is ``sample.data["data"]``
        restricted to the selected epochs. ``info`` is a dictionary with the
        keys ``"output_probability"`` (string form of each selected epoch's
        probability row), ``"prediction"``, ``"label"`` and
        ``"Epochs index"`` (positions of the selected epochs in the sample).
    """
    prob = sample.output_prob
    y_pred = np.argmax(prob, axis=1)
    # Normalise the labels so that fancy indexing with ``index`` below also
    # works when ``sample.label`` is a plain Python sequence.
    y_label = np.asarray(sample.label)

    # Build the selection mask once, according to the requested group.
    if return_well_classified:
        mask = y_pred == y_label
    else:
        mask = y_pred != y_label
    index = np.argwhere(mask).flatten()

    info = {
        "output_probability": [str(prob[i, :]) for i in index],
        "prediction": y_pred[index],
        "label": y_label[index],
        "Epochs index": index,
    }
    return sample.data["data"][index], info
def plot_misclassified_epochs(
    sample: DelayedSample,
    vizu_epochs: bool = False,
    return_well_classified: bool = False,
) -> tuple[Figure, dict[str, np.ndarray]]:
    """Generate a plot and a dictionary of the misclassified epochs.

    Each selected epoch is drawn at ``(epoch index, true label)``. For
    misclassified epochs, the points are grouped (and coloured) by their
    prediction. For well classified epochs, they are grouped by their label,
    which equals the prediction for those epochs.

    :param sample: a sample of a dataset
    :param vizu_epochs: if True, add a title with the sample metadata, attach
        an interactive ``PointBrowser`` to the figure and call
        ``plt.show()``
    :param return_well_classified: if True, plot the well classified epochs
        instead of the misclassified ones

    :return: the figure and the dictionary produced by
        ``compute_misclassification``
    """
    misclassified_epochs, dic_miss = compute_misclassification(
        sample, return_well_classified
    )

    # Create a fresh figure on every call. A fixed ``plt.figure(1)`` would
    # reuse the same figure, so repeated calls kept drawing on top of it.
    fig = plt.figure()

    x = dic_miss["Epochs index"]
    y_pred = dic_miss["prediction"]
    y = dic_miss["label"]

    event_dic_sample_label = get_sleep_stage_labels([y])
    event_dic_sample_pred = get_sleep_stage_labels([y_pred])

    # Well classified epochs satisfy y == y_pred, so grouping by ``y`` is
    # equivalent there; misclassified epochs are grouped by prediction.
    group_by = y if return_well_classified else y_pred
    predicted_classes = np.unique(y_pred)
    for stage, color in zip(
        predicted_classes, stage_colors[: len(predicted_classes)]
    ):
        selected = group_by == stage
        plt.scatter(x[selected], y[selected], color=color, picker=True)

    if return_well_classified:
        plt.title("Well classified Epochs")
    else:
        plt.title("Misclassified Epochs")

    plt.legend(
        event_dic_sample_pred.keys(),
        title="Prediction",
        loc="center left",
        bbox_to_anchor=(1, 0.5),
    )
    plt.yticks(np.unique(y), event_dic_sample_label.keys())
    plt.ylabel(ylabel="True label")
    plt.xlabel("Epochs index")
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    if vizu_epochs:
        fig.suptitle(
            "Sample (Key="
            + str(sample.key)
            + ", Age ="
            + str(sample.age)
            + ", Gender ="
            + str(sample.gender)
            + " )"
        )
        # NOTE(review): indexing the key lists by the integer class assumes
        # get_sleep_stage_labels orders its keys so that position == class
        # index -- TODO confirm against that helper.
        pred_stage_names = list(event_dic_sample_pred.keys())
        label_stage_names = list(event_dic_sample_label.keys())
        pred_names = [pred_stage_names[int(p)] for p in y_pred]
        label_names = [label_stage_names[int(t)] for t in y]

        browser = PointBrowser(
            x, y, label_names, pred_names, misclassified_epochs, fig
        )
        fig.canvas.mpl_connect("pick_event", browser.on_pick)
        fig.canvas.mpl_connect("key_press_event", browser.on_press)
        plt.show()

    return fig, dic_miss