Source code for sleepless.utils.stats_protocol

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Compute Statistics for protocols."""

from __future__ import annotations

import matplotlib.pyplot as plt
import pandas as pd

from matplotlib.figure import Figure

from ..data.sample import DelayedSample
from .utils_fig_table_df import create_df, save_fig


def _plot_subset_attribute(
    frame: pd.DataFrame,
    attribute: str,
    ax: plt.Axes,
    title: str,
) -> None:
    """Draw the distribution of one attribute of one subset onto ``ax``.

    ``gender`` and ``medication`` are drawn as pie charts of the group sizes,
    ``age`` as a bar chart of counts per ``ageGroup`` split by ``gender``.
    Any other attribute is ignored and ``ax`` is left empty.

    :param frame: the table describing one subset (as built by ``create_df``)
    :param attribute: name of the attribute to plot
    :param ax: the axes to draw on
    :param title: title displayed above the axes
    """
    if attribute in ("gender", "medication"):
        class_count = frame.groupby([attribute]).size()
        total = class_count.sum()
        class_count.plot(
            ax=ax,
            kind="pie",
            title=title,
            # Label every wedge with its percentage and the absolute count.
            autopct=lambda pct: f"{pct:.1f}%({pct / 100 * total:.0f})",
            ylabel="",
        )
    elif attribute == "age":
        # NOTE(review): relies on "ageGroup" and "gender" columns being present
        # in the frame; presumably create_df derives "ageGroup" from ``bins``
        # -- confirm against create_df.
        class_count = (
            frame.groupby(["ageGroup", "gender"]).age.count().unstack()
        )
        class_count.plot(
            ax=ax,
            kind="bar",
            stacked=False,
            ylabel="",
            title=title,
        )


def make_stats(
    dataset: dict[str, list[DelayedSample]],
    out_path: str | None = None,
    attributes: list[str] = ["age", "gender"],
    bins: list[int] = [0, 18, 60, 70, 80, 90, 100, 110],
) -> list[Figure]:
    """Compute statistics on the subsets of a dataset and plot them.

    One figure is created per analysed attribute. Each figure holds one plot
    per subset of ``dataset`` plus a last plot, titled "Protocol", that
    aggregates all the subsets together.

    :param dataset: A dictionary containing different sets (e.g. train, test).
    :param out_path: the path location where files will be saved; nothing is
        saved when it is ``None``
    :param attributes: a list of attributes on which we perform the analysis;
        only those actually present on the samples are kept
    :param bins: definition of the age categories
    :return: the list of created figures, one per analysed attribute (empty
        when the dataset contains no sample at all)
    """
    # NOTE: the list defaults are shared between calls. This function only
    # reads them; ``bins`` is forwarded to create_df, which is assumed not to
    # mutate it -- TODO confirm.

    # The available attributes are discovered on one sample: take it from the
    # first non-empty subset, and bail out when there is nothing to analyse.
    first_sample = next(
        (samples[0] for samples in dataset.values() if samples), None
    )
    if first_sample is None:
        return []

    aggregate_key = "protocol"
    keys = [*dataset, aggregate_key]
    # dir() yields names alphabetically, which fixes the order of the figures.
    _attributes = [att for att in dir(first_sample) if att in attributes]

    df_subset = {}
    frames = []
    for key in keys:
        if key != aggregate_key:
            df_subset[key] = create_df(dataset[key], _attributes, bins)
            frames.append(df_subset[key])
        else:
            # Aggregated view: all the subsets stacked together.
            df_subset[key] = pd.concat(frames)

    list_fig = []
    for attribute in _attributes:
        # squeeze=False always returns a 2D array of axes, whatever ncols is.
        fig, axes = plt.subplots(nrows=1, ncols=len(keys), squeeze=False)
        for ax, key in zip(axes[0], keys):
            _plot_subset_attribute(
                df_subset[key], attribute, ax, str(key).capitalize()
            )
        # The rect leaves room at the top of the figure for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.suptitle(str(attribute).capitalize())
        # Close this very figure (not whichever one is current) so pyplot
        # stops tracking it; the Figure object itself remains usable.
        plt.close(fig)
        list_fig.append(fig)

    if out_path is not None:
        save_fig(out_path, list_fig)

    return list_fig