Source code for sleepless.engine.predictor_scikit
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Predcition script."""
from __future__ import annotations
import logging
import os
from collections.abc import Mapping
import numpy as np
from ..data.sample import DelayedSample
from ..data.utils import ComposeTransform
from .utils import save_hdf5
logger = logging.getLogger(__name__)
def predict_scikit(
    dataset: dict[str, list[DelayedSample]],
    model: object,
    output_folder: str,
    model_parameters: Mapping,
):
    """Compute and save the model predictions for all samples of all keys.

    Class probabilities are computed with ``model.predict_proba`` when the
    model exposes it; otherwise ``model.predict`` is used and its output gets
    a trailing axis added. For every sample, ``save_hdf5`` is called with the
    sample key, the prediction, the sample label, an index range covering the
    label, the sub-folder ``output_folder/<dataset key>`` and the sample
    features.

    :param dataset: A dictionary containing a list of DelayedSample. Keys
        mapped to an empty list are skipped.
    :param model: A scikit learn model already fitted.
    :param output_folder: A path where prediction will be saved
    :param model_parameters: a dictionary where the following key need to be
        defined, ``transform``: list (if data are not transformed yet)
    :raises TypeError: if ``model`` has neither ``predict_proba`` nor
        ``predict``.
    :raises ValueError: if the samples of a key have no ``features``
        attribute and ``model_parameters`` has no ``transform`` entry.
    """
    # Resolve the prediction callable once: the model does not change between
    # samples, so there is no need to probe it again for each of them.
    if hasattr(model, "predict_proba"):
        predict = model.predict_proba
        add_trailing_axis = False
    elif hasattr(model, "predict"):
        predict = model.predict
        add_trailing_axis = True
    else:
        # Fail here instead of reaching save_hdf5 with an unbound (or stale)
        # prediction.
        logger.error("Model can not predict")
        raise TypeError(
            f"{type(model).__name__} has neither 'predict_proba' nor "
            "'predict'"
        )

    compose_transform = None
    if "transform" in model_parameters:
        compose_transform = ComposeTransform(model_parameters["transform"])

    for name, samples in dataset.items():
        if not samples:
            # Nothing to predict for this key (and no first sample to probe).
            continue

        # The first sample tells whether this list still has to be
        # transformed.
        if not hasattr(samples[0], "features"):
            if compose_transform is None:
                raise ValueError(
                    f"Samples of '{name}' have no 'features' and "
                    "model_parameters defines no 'transform'"
                )
            samples = compose_transform(samples)

        # One output sub-folder per dataset key.
        output_folder_pred = os.path.join(output_folder, name)

        for sample in samples:
            # Read once and reuse for both the prediction and the saved file.
            features = sample.features
            label = sample.label

            output_prob = predict(features)
            if add_trailing_axis:
                # NOTE(review): presumably predict() returns a 1-D array and
                # this turns it into a single column - confirm.
                output_prob = output_prob[:, None]

            save_hdf5(
                sample.key,
                output_prob,
                label,
                np.arange(0, len(label)),
                output_folder_pred,
                features,
            )