# Source code for sleepless.engine.trainer_scikit

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Training script."""

from __future__ import annotations

import logging
import os
import typing

from collections.abc import Mapping
from datetime import datetime

import joblib
import numpy as np
import pandas as pd

from sklearn.model_selection import PredefinedSplit

from ..data.sample import DelayedSample
from ..utils.utils_fig_table_df import make_rst_tabulate, save_tables

logger = logging.getLogger(__name__)


def _stack_samples(
    samples: list[DelayedSample],
) -> tuple[np.ndarray, np.ndarray]:
    """Stack the features and labels of a list of samples into two arrays.

    Every sample's ``features`` array is concatenated along axis 0.  Every
    sample's ``label`` is promoted to a row (``label[np.newaxis, :]``) and the
    rows are concatenated along axis 1, so the returned label array has a
    leading dimension of size 1.

    :param samples: the (already transformed) samples to stack.
    :return: a tuple ``(X, y)`` of stacked features and stacked labels.
    :raises ValueError: if ``samples`` is empty.
    """
    if not samples:
        raise ValueError("cannot stack an empty list of samples")
    # Read ``features`` and ``label`` together, sample by sample, so each
    # attribute is accessed exactly once per sample.  Concatenating the pieces
    # directly (instead of packing them into an object ndarray first) avoids
    # NumPy trying to build a deeper array when both pieces happen to share
    # their leading dimension.
    features, labels = zip(
        *((sample.features, sample.label[np.newaxis, :]) for sample in samples)
    )
    return np.concatenate(features), np.concatenate(labels, axis=1)


def train_scikit(
    model: typing.Any,
    training_set: list[DelayedSample],
    validation_set: list[list[DelayedSample]],
    output_folder: str,
    model_parameters: Mapping,
) -> None:
    """Train script for the scikit-learn pipeline.

    Fits ``model`` on the stacked training samples and dumps the fitted model
    with :func:`joblib.dump` to ``<output_folder>/fit_model``.

    Three modes are supported, checked in this order:

    * ``"grid-search"`` in ``model_parameters``: delegates to
      ``train_scikit_grid_search`` with the evaluation set, the output folder,
      and the contents of ``model_parameters["grid-search"]`` as keyword
      arguments.
    * a truthy ``model_parameters["early_stop"]``: calls
      ``model.fit(X, y, eval_set=eval_set)``.
    * otherwise: calls ``model.fit(X, y)``.

    In the first two modes the evaluation set is the training data followed
    by one ``(X, y)`` pair per entry of ``validation_set``.

    :param model: The scikit learn model to be fit
    :param training_set: the training samples, which need to already be
        transformed.
    :param validation_set: a list of validation sets (each a list of samples),
        which need to already be transformed.  Only used in grid-search or
        early-stop mode.
    :param output_folder: A path where the trained model will be saved
    :param model_parameters: The parameters to train the model; the keys
        ``"grid-search"`` and ``"early_stop"`` are interpreted here.
    :raises ValueError: if the training set, or a validation set that is
        needed, is empty.
    """
    early_stop = model_parameters.get("early_stop", False)
    grid_search = "grid-search" in model_parameters

    X_train, y_train = _stack_samples(training_set)

    eval_set = None
    if grid_search or early_stop:
        # Labels keep the (1, n_samples) layout produced by _stack_samples;
        # train_scikit_grid_search receives exactly this structure.
        eval_set = [(X_train, y_train)] + [
            _stack_samples(valid_subset) for valid_subset in validation_set
        ]

    if grid_search:
        fit_model = train_scikit_grid_search(
            model,
            eval_set,
            output_folder,
            **model_parameters["grid-search"],
        )
    elif early_stop:
        logger.info("early stop option activated")
        # NOTE(review): the training labels are raveled here, but the labels
        # inside eval_set keep their (1, n) shape; confirm the estimator in use
        # accepts that for its evaluation sets.
        fit_model = model.fit(X_train, y_train.ravel(), eval_set=eval_set)
    else:
        fit_model = model.fit(X_train, y_train.ravel())

    # os.path.join avoids writing to "/fit_model" when output_folder is "" and
    # also accepts pathlib.Path folders.
    joblib.dump(fit_model, os.path.join(output_folder, "fit_model"))