sleepless.engine.predictor_torch
Prediction script.
Functions
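predict_torch(dataset, model, weight, output_folder, model_parameters) | Prepares data and model to run inference on data.
run(model, data_loader, name, device, output_folder) | Runs inference on input data and writes the predictions to HDF5 files.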
- sleepless.engine.predictor_torch.run(model, data_loader, name, device, output_folder)
Runs inference on input data and writes the predictions to HDF5 files.
- Parameters:
model (Module) – the fitted neural network model
data_loader (DataLoader) – the dataset to run inference on
name (str) – the local name of this dataset (e.g. train or test), used when saving measure files
device (device) – the device to use
output_folder (str) – the folder in which to store the output predictions (HDF5 files)
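The body of `run` is not documented here. As a minimal sketch of the loop it describes, assuming each batch from the loader yields a list of sample names together with their inputs and that the model is a plain callable, the following applies the model batch by batch and writes one file per sample under `<output_folder>/<name>/`. JSON stands in for HDF5 so that only the standard library is needed; `run_sketch` and the per-sample file layout are hypothetical. A PyTorch version would typically also move the model to `device`, call `model.eval()`, and wrap the loop in `torch.no_grad()`.

```python
import json
import os
from typing import Callable, Iterable, List, Sequence, Tuple

# Hypothetical batch layout: (sample_names, inputs), one entry per sample.
Batch = Tuple[Sequence[str], Sequence[List[float]]]


def run_sketch(
    model: Callable[[Sequence[List[float]]], Sequence[float]],
    data_loader: Iterable[Batch],
    name: str,
    output_folder: str,
) -> List[str]:
    """Apply ``model`` to every batch and write one file per sample.

    Hypothetical stand-in for ``run``: JSON replaces HDF5 so the sketch
    needs only the standard library. Returns the paths of the files written.
    """
    # The dataset name selects the subfolder, e.g. <output_folder>/test
    target_dir = os.path.join(output_folder, name)
    os.makedirs(target_dir, exist_ok=True)
    written = []
    for sample_names, inputs in data_loader:
        predictions = model(inputs)  # one prediction per sample in the batch
        for sample, pred in zip(sample_names, predictions):
            path = os.path.join(target_dir, f"{sample}.json")
            with open(path, "w") as f:
                json.dump({"prediction": pred}, f)
            written.append(path)
    return written
```

In this sketch `name` only chooses the output subfolder, in line with the docstring's note that it is used when saving files.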
- sleepless.engine.predictor_torch.predict_torch(dataset, model, weight, output_folder, model_parameters)
Prepares data and model to run inference on data.
- Parameters:
dataset (dict) – a dictionary containing one torch.utils.data.ConcatDataset per key
model (Module) – the neural network model, not yet fitted
weight (str) – path to the weights used to fit the neural network
output_folder (str) – the folder in which the predictions will be saved
model_parameters (Mapping) – a dictionary in which the following keys must be defined:
  - batch_size: int
  - parallel: int
  - device: str
  - transform: list (only if the data are not transformed yet)
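The internals of `predict_torch` are not shown on this page. The following is a minimal sketch of one plausible flow, assuming it checks `model_parameters` against the keys listed above and then hands each key of `dataset` to an inference routine as the dataset name, split into batches of `batch_size`. The helpers `check_model_parameters`, `batches`, `predict_sketch`, and `run_fn` are hypothetical. Loading `weight` into the model and building real PyTorch `DataLoader`s are left out so the sketch runs with only the standard library, and `parallel` is only type-checked here.

```python
from typing import Any, Callable, Dict, List, Mapping, Sequence

# Required keys and types, as listed in the predict_torch docstring.
REQUIRED_KEYS = {"batch_size": int, "parallel": int, "device": str}


def check_model_parameters(params: Mapping[str, Any]) -> None:
    """Raise if a documented key is missing or has the wrong type (hypothetical helper)."""
    for key, expected in REQUIRED_KEYS.items():
        if key not in params:
            raise KeyError(f"model_parameters is missing {key!r}")
        if not isinstance(params[key], expected):
            raise TypeError(
                f"{key!r} should be {expected.__name__}, "
                f"got {type(params[key]).__name__}"
            )
    # 'transform' is optional: only needed when the data are not transformed yet.
    if "transform" in params and not isinstance(params["transform"], list):
        raise TypeError("'transform' should be a list")


def batches(samples: Sequence[Any], batch_size: int) -> List[Sequence[Any]]:
    """Split samples into consecutive chunks, like a non-shuffling loader."""
    return [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]


def predict_sketch(
    dataset: Dict[str, Sequence[Any]],
    params: Mapping[str, Any],
    run_fn: Callable[[str, List[Sequence[Any]]], None],
) -> List[str]:
    """Validate params, then pass each named subset to run_fn in batches."""
    check_model_parameters(params)
    processed = []
    for name, subset in dataset.items():
        run_fn(name, batches(subset, params["batch_size"]))
        processed.append(name)
    return processed
```

For example, `{"batch_size": 8, "parallel": 0, "device": "cpu"}` passes the check, while omitting `device` raises a `KeyError` before any inference starts.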