sleepless.engine.predictor_torch
Prediction script.
Functions
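| Function | Description |
|---|---|
| predict_torch(dataset, model, weight, output_folder, model_parameters) | Prepares data and model to run inference on data. |
| run(model, data_loader, name, device, output_folder) | Runs inference on input data and outputs HDF5 files with predictions. |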
- sleepless.engine.predictor_torch.run(model, data_loader, name, device, output_folder)
Runs inference on input data and outputs HDF5 files with predictions.
- Parameters:
  - model (Module) – fitted neural network model
  - data_loader (DataLoader) – data loader over the dataset
  - name (str) – the local name of this dataset (e.g. train or test), used when saving measure files
  - device (device) – device on which to run inference
  - output_folder (str) – folder where the output predictions (HDF5 files) are stored
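The loop below is a minimal, dependency-free sketch of the kind of work `run` describes: iterate over batches, compute predictions, and write one output file per sample under `output_folder`. It is not the sleepless implementation. The name `run_sketch`, the `(sample_names, inputs)` batch shape, the `<output_folder>/<name>/` layout, and the use of JSON in place of HDF5 are all assumptions made for illustration.

```python
import json
import os
from typing import Callable, Iterable, List, Sequence, Tuple

# Hypothetical stand-ins: a "model" is any callable mapping a batch of inputs
# to a list of predictions, and a "loader" yields (sample_names, inputs) batches.
Batch = Tuple[Sequence[str], Sequence[float]]


def run_sketch(
    model: Callable[[Sequence[float]], List[float]],
    data_loader: Iterable[Batch],
    name: str,
    output_folder: str,
) -> List[str]:
    """Run ``model`` over every batch and save one file per sample.

    JSON is used here only as a standard-library stand-in for HDF5.
    Returns the list of written file paths.
    """
    # Assumed layout: group outputs by dataset name, e.g. <output_folder>/test/
    target_dir = os.path.join(output_folder, name)
    os.makedirs(target_dir, exist_ok=True)
    written = []
    for sample_names, inputs in data_loader:
        predictions = model(inputs)
        # Pair each prediction with the sample it belongs to and save it
        for sample, pred in zip(sample_names, predictions):
            path = os.path.join(target_dir, f"{sample}.json")
            with open(path, "w") as f:
                json.dump({"prediction": pred}, f)
            written.append(path)
    return written
```

In a real PyTorch setting, such a loop would typically call `model.eval()`, wrap inference in `torch.no_grad()`, move each batch to `device`, and write arrays with an HDF5 library such as h5py; those steps are omitted here to keep the sketch self-contained.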
- sleepless.engine.predictor_torch.predict_torch(dataset, model, weight, output_folder, model_parameters)
Prepares data and model to run inference on data.
- Parameters:
  - dataset (dict) – a dictionary containing a torch.utils.data.ConcatDataset per key
  - model (Module) – neural network model, not yet fitted
  - weight (str) – path to the weights used to fit the neural network
  - output_folder (str) – folder where predictions will be saved
  - model_parameters (Mapping) – a dictionary in which the following keys must be defined:
    - batch_size: int
    - parallel: int
    - device: str
    - transform: list (if the data are not transformed yet)
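`predict_torch` expects `model_parameters` to define specific keys. The hypothetical helper below (not part of sleepless) sketches a pre-flight check of those keys against the types listed above, treating `transform` as optional since it is only needed when the data are not transformed yet. The example values are placeholders, not recommended settings.

```python
from typing import Any, Mapping

# Keys documented as required for predict_torch, with their expected types.
REQUIRED_KEYS = {"batch_size": int, "parallel": int, "device": str}
# "transform" is only needed when the data are not transformed yet.
OPTIONAL_KEYS = {"transform": list}


def check_model_parameters(params: Mapping[str, Any]) -> None:
    """Raise ValueError for missing keys, TypeError for wrongly typed values."""
    missing = [key for key in REQUIRED_KEYS if key not in params]
    if missing:
        raise ValueError(f"missing model_parameters keys: {missing}")
    for key, expected in {**REQUIRED_KEYS, **OPTIONAL_KEYS}.items():
        if key in params and not isinstance(params[key], expected):
            raise TypeError(
                f"{key} should be {expected.__name__}, "
                f"got {type(params[key]).__name__}"
            )


# Example: a mapping with the documented keys (values are placeholders).
params = {"batch_size": 8, "parallel": 1, "device": "cpu", "transform": []}
check_model_parameters(params)  # returns None when the mapping is valid
```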