Source code for sleepless.engine.analyze

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Analyze script."""

from __future__ import annotations

import logging
import typing

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    balanced_accuracy_score,
    cohen_kappa_score,
    hamming_loss,
    matthews_corrcoef,
)

from ..data.sample import DelayedSample
from ..utils.matplotlib_utils import get_sleep_stage_labels
from ..utils.stats_protocol import make_stats
from ..utils.utils_fig_table_df import create_df, make_rst_tabulate

logger = logging.getLogger(__name__)


def metric_stats(
    dataset: dict[str, list[DelayedSample]],
    bins: list[int] | None = None,
) -> tuple[list, list, typing.Iterable[tuple]]:
    """Compute metrics for every subset of a dataset and per attribute value.

    For each subset, one set of metrics is computed on all of its samples,
    then one more for every distinct value of each supported attribute
    (``age``, grouped into categories, and ``gender``).

    :param dataset: A dictionary mapping subset names (e.g. train, test) to
        their list of samples. Each sample must expose ``output_prob`` and
        ``label``. The first sample of the first subset is inspected to
        decide which of the supported attributes are available.

    :param bins: Edges of the age categories, forwarded to ``create_df``.
        When ``None``, ``[0, 18, 60, 70, 80, 90, 100, 110]`` is used.

    :return: A tuple ``(figures, tables, named_dataframes)`` where
        ``figures`` is the list of figures returned by
        ``metrics_computation`` followed by those returned by
        ``make_stats``, ``tables`` is the list of tables built by
        ``make_rst_tabulate`` (one global table plus one per attribute), and
        ``named_dataframes`` is a one-shot ``zip`` iterator of
        ``(table name, DataFrame)`` pairs in the same order as ``tables``.
    """
    if bins is None:
        # Built on every call so that the default list is never shared
        # between calls (a list literal in the signature would be).
        bins = [0, 18, 60, 70, 80, 90, 100, 110]

    # For now only stats for these attributes can be computed.
    # ATTENTION: the attribute names are hard-coded; sample attributes have
    # to match these names exactly.
    keep = ["age", "gender"]

    keys = list(dataset)
    attributes = [att for att in dir(dataset[keys[0]][0]) if att in keep]

    figures = []
    # Slot 0 holds the per-subset metrics, slot i (i >= 1) holds the metrics
    # of attributes[i - 1].
    list_df = [pd.DataFrame() for _ in range(len(attributes) + 1)]
    table_names = ["metrics table"]
    table_names += [attr + " metrics table" for attr in attributes]

    for key in keys:
        name_subset = str(key)
        df_subset = create_df(dataset[key], attributes, bins)
        df_subset["output_prob"] = [_id.output_prob for _id in dataset[key]]
        df_subset["y_label"] = [sample.label for sample in dataset[key]]

        return_metric_key, return_figures_key = metrics_computation(
            df_subset, name_subset
        )
        list_df[0][name_subset] = return_metric_key
        figures += return_figures_key

        for index, attribute in enumerate(attributes, 1):
            # Ages are analysed through the "ageGroup" column rather than
            # the raw "age" values.
            # NOTE(review): "ageGroup" is presumably a categorical column
            # added by create_df from `bins` - confirm.
            column = "ageGroup" if attribute == "age" else str(attribute)
            df_index = df_subset[column].unique()
            if attribute == "age":
                # Order the age groups by their category codes.
                df_index = df_index[np.argsort(df_index.codes)]

            name_attribute = name_subset + " " + column
            for value in df_index:
                df_value = df_subset[df_subset[column] == value]
                name_value = name_attribute + " " + str(value)
                (
                    return_metric_attrib,
                    return_figures_attrib,
                ) = metrics_computation(df_value, name_value)
                list_df[index][name_value] = return_metric_attrib
                figures += return_figures_attrib

    tables_metrics = [
        make_rst_tabulate(table_name, df)
        for table_name, df in zip(table_names, list_df)
    ]

    figures += make_stats(dataset, attributes=attributes)

    return figures, tables_metrics, zip(table_names, list_df)
def metrics_computation(
    data: pd.DataFrame, name: str
) -> tuple[pd.DataFrame, list[plt.Figure]]:
    """Compute classification metrics and confusion matrices for a dataframe.

    The entries of the ``y_label`` column are concatenated into one vector
    of true labels. The entries of the ``output_prob`` column are
    concatenated along axis 0 and reduced with ``argmax`` along axis 1 to
    obtain one predicted label per epoch.

    :param data: a dataframe with one row per sample and, at least, the
        columns ``output_prob`` and ``y_label``.

    :param name: text used to title the confusion matrices; underscores are
        replaced by spaces and the result is capitalized.

    :return: a tuple ``(df_metrics, figures)``. ``df_metrics`` is a
        single-column dataframe indexed by metric name: number of samples,
        total number of epochs, balanced accuracy, accuracy, linear and
        quadratic weighted Kappa, Matthews correlation coefficient and
        Hamming loss. ``figures`` holds two confusion-matrix figures, the
        row-normalized one first, then the one with raw counts.
    """
    per_sample_probs = data["output_prob"].to_numpy()
    truth = np.concatenate(data["y_label"].to_numpy(), axis=0)
    predicted = np.argmax(np.concatenate(per_sample_probs, axis=0), axis=1)

    stage_labels = get_sleep_stage_labels([truth, predicted])

    accuracy = accuracy_score(truth, predicted)

    # Both matrices share the same tick labels and orientation; only the
    # normalization (and the matching number format) differs.
    shared_display_options = {
        "display_labels": stage_labels.keys(),
        "xticks_rotation": "vertical",
    }
    normalized_display = ConfusionMatrixDisplay.from_predictions(
        truth,
        predicted,
        normalize="true",
        values_format=".2f",
        **shared_display_options,
    )
    counts_display = ConfusionMatrixDisplay.from_predictions(
        truth,
        predicted,
        normalize=None,
        **shared_display_options,
    )

    # Insertion order defines the row order of the resulting table.
    summary = {
        "Samples number": int(per_sample_probs.shape[0]),
        "Total number of epochs": int(predicted.shape[0]),
        "Balanced accuracy": balanced_accuracy_score(truth, predicted),
        "Accuracy score": accuracy,
        "Linear Weighted Kappa": cohen_kappa_score(
            truth, predicted, weights="linear"
        ),
        "Quadratic Weighted Kappa": cohen_kappa_score(
            truth, predicted, weights="quadratic"
        ),
        "Mcc": matthews_corrcoef(truth, predicted),
        "Hamming Loss": hamming_loss(truth, predicted),
    }
    df_metrics = pd.DataFrame(list(summary.values()), list(summary.keys()))

    title = name.replace("_", " ").capitalize()

    normalized_display.ax_.set_title(title + " normalized")
    normalized_display.figure_.set_tight_layout(True)
    plt.close()

    counts_display.ax_.set_title(title)
    counts_display.figure_.set_tight_layout(True)
    plt.close()

    return df_metrics, [normalized_display.figure_, counts_display.figure_]
def misclassified_analyze(
    dataset: dict[str, list[DelayedSample]],
) -> tuple[list, dict[str, dict]]:
    """Plot and tabulate well classified and misclassified epochs per sample.

    Every sample of every subset is passed twice to
    ``plot_misclassified_epochs``: first asking for the well classified
    epochs, then for the misclassified ones. Each call yields one figure and
    one table entry.

    :param dataset: a dictionary mapping subset names to their list of
        samples; each sample must expose a ``key`` attribute, which is used
        in the figure titles and table names.

    :return: a tuple ``(figures, tables)``. ``figures`` lists the generated
        figures in processing order and ``tables`` maps a descriptive name
        (category, subset, sample index and sample key) to the second value
        returned by ``plot_misclassified_epochs`` for that figure.
    """
    from ..utils.misclassification import plot_misclassified_epochs

    # (value of return_well_classified, label used in the table name)
    passes = ((True, "Well classified"), (False, "Misclassified"))

    figures = []
    tables = {}

    for subset_name, samples in dataset.items():
        subset_title = str(subset_name).capitalize()

        for position, sample in enumerate(samples):
            for well_classified, category in passes:
                fig, details = plot_misclassified_epochs(
                    sample, return_well_classified=well_classified
                )

                caption = (
                    f"{subset_title} set, Sample {position!s}"
                    f" ( {sample.key!s} )"
                )
                fig.suptitle(caption)
                plt.close()

                figures.append(fig)
                tables[f"{category} Epochs for : {caption}"] = details

    return figures, tables