sleepless.engine.predictor_torch

Prediction script.

Functions

predict_torch(dataset, model, weight, ...)

Prepares data and model to run inference on data.

run(model, data_loader, name, device, ...)

Runs inference on input data, outputs HDF5 files with predictions.

sleepless.engine.predictor_torch.run(model, data_loader, name, device, output_folder)

Runs inference on input data, outputs HDF5 files with predictions.

Parameters:
  • model (Module) – fitted neural network model

  • data_loader (DataLoader) – data loader over the dataset on which to run inference

  • name (str) – the local name of this dataset (e.g. train or test), used when saving the output files.

  • device (device) – device on which to run inference

  • output_folder (str) – folder in which to store the output predictions (HDF5 files)
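The loop that a function like run performs can be sketched as follows. This is a minimal, hypothetical sketch, not the sleepless implementation: it assumes each batch is an (inputs, targets) tuple, that the model returns one tensor of scores per batch, and that all predictions for a split go into a single file named <name>.hdf5 under a dataset called "predictions". The function name run_sketch and this file layout are illustrative assumptions.

```python
import os
import tempfile

import h5py
import torch
from torch.utils.data import DataLoader, TensorDataset


def run_sketch(model, data_loader, name, device, output_folder):
    """Hypothetical inference loop in the spirit of ``run``.

    Assumptions (not taken from sleepless): batches are ``(inputs, targets)``
    tuples, and predictions are written to ``<output_folder>/<name>.hdf5``
    under the dataset key ``"predictions"``.
    """
    os.makedirs(output_folder, exist_ok=True)
    model = model.to(device)
    model.eval()  # disable dropout, use running batch-norm statistics

    batches = []
    with torch.no_grad():  # gradients are not needed for inference
        for inputs, _targets in data_loader:
            outputs = model(inputs.to(device))
            batches.append(outputs.cpu())  # move back to host memory

    # Stack per-batch outputs into one (num_samples, num_outputs) array
    predictions = torch.cat(batches).numpy()
    path = os.path.join(output_folder, f"{name}.hdf5")
    with h5py.File(path, "w") as f:
        f.create_dataset("predictions", data=predictions)
    return path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        model = torch.nn.Linear(8, 5)  # stand-in for a fitted network
        data = TensorDataset(torch.randn(10, 8), torch.zeros(10))
        loader = DataLoader(data, batch_size=4)
        path = run_sketch(model, loader, "test", torch.device("cpu"), tmp)
        with h5py.File(path, "r") as f:
            print(f["predictions"].shape)  # (10, 5)
```

Calling model.eval() together with torch.no_grad() is the usual PyTorch inference pattern: the first switches layers such as dropout to evaluation behaviour, and the second avoids building the autograd graph.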

sleepless.engine.predictor_torch.predict_torch(dataset, model, weight, output_folder, model_parameters)

Prepares data and model to run inference on data.

Parameters:
  • dataset (dict) – a dictionary containing one torch.utils.data.ConcatDataset per key

  • model (Module) – neural network model (not yet fitted)

  • weight (str) – path to the weights used to fit the neural network

  • output_folder (str) – folder where predictions will be saved

  • model_parameters (Mapping) – a dictionary in which the following keys must be defined:

      • batch_size: int

      • parallel: int

      • device: str

      • transform: list (only if the data are not transformed yet)
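The preparation step described for predict_torch can be sketched roughly as below. This is a hypothetical sketch under several assumptions not confirmed by the documentation: the weight file holds a plain state_dict, parallel is a non-negative number of DataLoader worker processes, the data are already transformed (so transform is ignored), and the inference function is passed in as run_fn so the sketch is self-contained (in the module itself this would presumably be run). The names predict_torch_sketch and run_fn are illustrative only.

```python
import os
import tempfile

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset


def predict_torch_sketch(dataset, model, weight, output_folder,
                         model_parameters, run_fn):
    """Hypothetical counterpart of ``predict_torch``.

    Assumes ``weight`` points to a saved ``state_dict`` and that
    ``model_parameters["parallel"]`` is a non-negative worker count.
    """
    device = torch.device(model_parameters["device"])

    # Load the fitted weights into the (not yet fitted) model
    state = torch.load(weight, map_location=device)
    model.load_state_dict(state)

    # One DataLoader per dataset key; keep order for reproducible outputs
    for name, subset in dataset.items():
        loader = DataLoader(
            subset,
            batch_size=model_parameters["batch_size"],
            shuffle=False,
            num_workers=model_parameters["parallel"],
            pin_memory=device.type == "cuda",
        )
        run_fn(model, loader, name, device, output_folder)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        trained = torch.nn.Linear(4, 3)
        weight_path = os.path.join(tmp, "model.pth")
        torch.save(trained.state_dict(), weight_path)

        datasets = {"test": ConcatDataset([TensorDataset(torch.randn(6, 4))])}
        params = {"batch_size": 2, "parallel": 0, "device": "cpu"}

        def report(model, loader, name, device, output_folder):
            # Placeholder for the real inference step
            print(name, len(loader.dataset), device)

        predict_torch_sketch(datasets, torch.nn.Linear(4, 3), weight_path,
                             tmp, params, report)
```

Passing map_location to torch.load lets weights saved on a GPU be loaded on a CPU-only machine, and shuffle=False keeps the prediction order aligned with the dataset order.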