sleepless.engine.trainer_torch#
Training script.
Functions

check_exist_logfile(logfile_name, arguments) – Check for an existing logfile (trainlog.csv); if it exists and the epoch number is still 0, it is replaced.
check_gpu(device) – Check the device type and the availability of a GPU.
checkpointer_process(...) – Process the checkpointer, save the final model and keep track of the best model.
create_logfile_fields(...) – Create the fields that will appear in the logfile.
run(...) – Fit a CNN model using supervised learning and save it to disk.
save_model_summary(output_folder, model) – Save a short summary of the model in a text file.
static_information_to_csv(...) – Save static information in a CSV file.
torch_evaluation(model) – Context manager to turn model evaluation mode ON/OFF.
train_epoch(...) – Train the model for a single epoch (through all batches).
train_torch(...) – Fit a CNN model using supervised learning and save it to disk.
validate_epoch(...) – Process input samples and return the loss (a scalar).
write_log_info(...) – Write log info to trainlog.csv.
- sleepless.engine.trainer_torch.torch_evaluation(model)[source]#
Context manager to turn model evaluation ON/OFF. This context manager turns evaluation mode ON on entry and OFF again when exiting the with statement block.
- Parameters:
model (Module) – pytorch network
- Yields:
model (pytorch network)
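The behaviour can be sketched as a small context manager (a hypothetical re-implementation for illustration, not the package's actual code):

```python
import contextlib


@contextlib.contextmanager
def torch_evaluation(model):
    # turn evaluation mode ON on entry (on a real pytorch network this
    # disables dropout and freezes batch-norm statistics)
    model.eval()
    try:
        yield model
    finally:
        # turn evaluation mode OFF again when leaving the with-block
        model.train()
```

Any object exposing eval() and train() methods works with this sketch, which makes the on-entry/on-exit behaviour easy to verify in isolation.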
- sleepless.engine.trainer_torch.check_gpu(device)[source]#
Check the device type and the availability of a GPU.
- Parameters:
device (device) – device to use
- sleepless.engine.trainer_torch.save_model_summary(output_folder, model)[source]#
Save a short summary of the model in a text file.
- sleepless.engine.trainer_torch.static_information_to_csv(static_logfile_name, device, n)[source]#
Save static information in a CSV file.
- sleepless.engine.trainer_torch.check_exist_logfile(logfile_name, arguments)[source]#
Check for an existing logfile (trainlog.csv). If the logfile exists and the epoch number is still 0, the logfile is replaced.
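A minimal sketch of this check (assuming, hypothetically, that the arguments mapping carries the current epoch under an "epoch" key):

```python
import os


def check_exist_logfile(logfile_name, arguments):
    # if a stale trainlog.csv exists but training (re)starts at epoch 0,
    # discard it so the new run writes a fresh log
    # (the "epoch" key in `arguments` is an assumption of this sketch)
    if os.path.exists(logfile_name) and arguments.get("epoch", 0) == 0:
        os.unlink(logfile_name)
```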
- sleepless.engine.trainer_torch.create_logfile_fields(valid_loader, extra_valid_loaders, device)[source]#
Create the logfile fields that will appear in the logfile.
- Parameters:
valid_loader (DataLoader) – To be used to validate the model and enable automatic checkpointing. If set to None, then do not validate it.
extra_valid_loaders (list[DataLoader]) – To be used to validate the model; however, it does not affect automatic checkpointing. If set to None or empty, then nothing else is logged. Otherwise, an extra column with the loss of every dataset in this list is kept in the final training log.
device (device) – device to use
- Return type:
- Returns:
The fields that will appear in trainlog.csv
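A plausible sketch of how such a field list could be assembled (the column names here are illustrative, not the package's actual ones):

```python
def create_logfile_fields(valid_loader, extra_valid_loaders):
    # base columns every training log would carry (illustrative names)
    fields = ["epoch", "total_time", "eta", "loss", "learning_rate"]
    if valid_loader is not None:
        # validation enabled: log its loss as well
        fields.append("validation_loss")
    if extra_valid_loaders:
        # one extra column per additional validation dataset
        fields += [
            "extra_validation_loss_%d" % i
            for i in range(len(extra_valid_loaders))
        ]
    return fields
```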
- sleepless.engine.trainer_torch.train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count)[source]#
Trains the model for a single epoch (through all batches).
- Parameters:
loader (torch.utils.data.DataLoader) – To be used to train the model
model – pytorch network
optimizer – pytorch optimizer
device – device to use
criterion – pytorch loss function
batch_chunk_count – If this number is different from 1, then each batch will be divided into this number of chunks, and gradients will be accumulated over the chunks to perform each mini-batch step. This is particularly interesting when one has limited RAM on the GPU but would like to keep training with larger batches; the trade-off is longer processing times. To better understand gradient accumulation, read https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch.
- Returns:
A floating-point value corresponding to the weighted average of this epoch's loss
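The batch-chunking idea can be illustrated without pytorch: splitting a batch into chunks and accumulating per-chunk gradients, weighted by chunk size, reproduces the gradient of a single full-batch pass. In this simplified sketch, grad_fn stands in for a real backward pass and returns a single number:

```python
def accumulated_gradient(batch, batch_chunk_count, grad_fn):
    # split the batch into batch_chunk_count roughly equal chunks
    chunk_size = (len(batch) + batch_chunk_count - 1) // batch_chunk_count
    chunks = [batch[i:i + chunk_size] for i in range(0, len(batch), chunk_size)]
    # accumulate each chunk's mean gradient, weighted by chunk size,
    # so the result equals the full-batch mean gradient
    total = sum(grad_fn(chunk) * len(chunk) for chunk in chunks)
    return total / len(batch)
```

Because each chunk is weighted by its size, the accumulated result is identical to processing the whole batch at once; only peak memory (and wall-clock time) differ.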
- sleepless.engine.trainer_torch.validate_epoch(loader, model, device, criterion, pbar_desc)[source]#
Processes input samples and returns the loss (a scalar).
- Parameters:
loader – To be used to validate the model
model – pytorch network
device – device to use
criterion – loss function
pbar_desc – A string for the progress bar descriptor
- Returns:
A floating-point value corresponding to the weighted average of this epoch's loss
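The weighted average mentioned in the return value can be sketched as follows (a minimal illustration: each batch's mean loss is weighted by the number of samples in that batch):

```python
def weighted_average_loss(batch_losses, batch_sizes):
    # weight each batch's mean loss by the number of samples it contained,
    # so incomplete final batches do not skew the epoch average
    total = sum(l * n for l, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)
```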
- sleepless.engine.trainer_torch.checkpointer_process(checkpointer, checkpoint_period, valid_loss, lowest_validation_loss, arguments, epoch, max_epoch)[source]#
Process the checkpointer, save the final model and keep track of the best model.
- Parameters:
checkpointer (Checkpointer) – checkpointer implementation
checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints
valid_loss (float) – Current epoch validation loss
lowest_validation_loss (float) – Keeps track of the best (lowest) validation loss
arguments (dict) – start and end epochs
epoch (int) – current epoch
max_epoch (int) – end epoch
- Return type:
- Returns:
The lowest validation loss currently observed
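The best-model bookkeeping behind the return value can be sketched as follows (hypothetical; the real function also drives the checkpointer and saves models to disk):

```python
def track_best_model(valid_loss, lowest_validation_loss):
    # return the updated lowest loss, plus whether the current model
    # is the new best one (and should therefore be saved)
    if valid_loss is not None and valid_loss < lowest_validation_loss:
        return valid_loss, True
    return lowest_validation_loss, False
```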
- sleepless.engine.trainer_torch.write_log_info(epoch, current_time, eta_seconds, loss, valid_loss, extra_valid_losses, optimizer, logwriter, logfile, resource_data)[source]#
Write log info to trainlog.csv.
- Parameters:
epoch (int) – Current epoch
current_time (float) – Current training time
eta_seconds (float) – Estimated time of arrival, taking previous epoch performance into consideration
loss (float) – Current epoch's training loss
valid_loss (Optional[float]) – Current epoch's validation loss
extra_valid_losses (Optional[list[float]]) – Validation losses from other validation datasets currently being tracked
optimizer (Optimizer) – pytorch optimizer
logwriter (DictWriter) – Dictionary writer that gives the ability to write to trainlog.csv
logfile (TextIOWrapper) – text file containing the log
resource_data (tuple) – Monitored resources on the machine (CPU and GPU)
- sleepless.engine.trainer_torch.run(model, data_loader, valid_loader, extra_valid_loaders, optimizer, scheduler, criterion, checkpointer, checkpoint_period, device, arguments, output_folder, monitoring_interval, batch_chunk_count, criterion_valid, patience)[source]#
Fits a CNN model using supervised learning and saves it to disk. This method supports periodic checkpointing and the output of a CSV-formatted log with the evolution of some figures during training.
- Parameters:
model – pytorch network
data_loader – To be used to train the model
valid_loader – To be used to validate the model and enable automatic checkpointing. If None, then do not validate it.
extra_valid_loaders – To be used to validate the model; however, it does not affect automatic checkpointing. If empty, then nothing else is logged. Otherwise, an extra column with the loss of every dataset in this list is kept in the final training log.
optimizer – pytorch optimizer
scheduler – pytorch scheduler
criterion – loss function
checkpointer – checkpointer implementation
checkpoint_period – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints
device – device to use
arguments – start and end epochs
output_folder – output path
monitoring_interval – interval, in seconds (or fractions), at which resources should be monitored during training.
batch_chunk_count – If this number is different from 1, then each batch will be divided into this number of chunks, and gradients will be accumulated over the chunks to perform each mini-batch step. This is particularly interesting when one has limited RAM on the GPU but would like to keep training with larger batches; the trade-off is longer processing times.
criterion_valid – specific loss function for the validation set
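The patience argument suggests early stopping on the validation loss. A hypothetical sketch of such bookkeeping (the exact semantics in the package are an assumption here):

```python
def early_stopping_update(valid_loss, best_loss, bad_epochs, patience):
    # reset the counter on improvement; otherwise count the epoch as
    # "bad" and signal a stop once `patience` bad epochs accumulate
    if valid_loss < best_loss:
        return valid_loss, 0, False
    bad_epochs += 1
    return best_loss, bad_epochs, bad_epochs >= patience
```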
- sleepless.engine.trainer_torch.train_torch(model, training_set, validation_set, output_folder, model_parameters)[source]#
Fits a CNN model using supervised learning and saves it to disk. This method supports periodic checkpointing and the output of a CSV-formatted log with the evolution of some figures during training.
- Parameters:
model (Module) – pytorch network
training_set – To be used to train the model
validation_set – To be used to validate the model and enable automatic checkpointing. If None, then do not validate it.
output_folder (str) – path to save the model and parameters
model_parameters (Mapping) – a dictionary where the following keys need to be defined:
optimizer: torch.optim.Optimizer
epochs: int
batch_size: int
valid_batch_size: int
batch_chunk_count: int
drop_incomplete_batch: bool
criterion: pytorch loss function
scheduler: torch.optim
checkpoint_period: int
device: str
seed: int
parallel: int
monitoring_interval: int | float
and optionally:
criterion_valid: pytorch loss function
patience: int