Source code for sleepless.utils.checkpointer
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from __future__ import annotations
import logging
import os
import torch
logger = logging.getLogger(__name__)
class Checkpointer:
    """A simple pytorch checkpointer.

    Saves and restores the state of a model (and, when set, of its optimizer
    and learning-rate scheduler) to/from ``.pth`` files stored under a single
    directory.  The file name of the most recently saved checkpoint is
    recorded in a plain-text file called ``last_checkpoint`` in that same
    directory, so that :py:meth:`load` can resume from it without being told
    which file to use.

    :param model: Network model, eventually loaded from a checkpointed
        file
    :param optimizer: Optimizer
    :param scheduler: Learning rate scheduler
    :param path: Directory where to save checkpoints.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer | None = None,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None = None,
        path: str = ".",
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        # resolved once, so later changes of the working directory do not
        # move the place where checkpoints are read from/written to
        self.path = os.path.realpath(path)

    def save(self, name, **kwargs):
        """Saves model, optimizer and scheduler states to a checkpoint file.

        The checkpoint is written to ``<self.path>/<name>.pth`` and its file
        name is then recorded in the ``last_checkpoint`` file.

        :param name: Name of the checkpoint file, without the ``.pth``
            extension
        :param kwargs: Extra entries to store alongside the states.  They are
            merged last, so an entry named ``model``, ``optimizer`` or
            ``scheduler`` replaces the corresponding state.
        """
        data = {"model": self.model.state_dict()}
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        name = f"{name}.pth"
        outf = os.path.join(self.path, name)

        # the checkpoint directory may not exist yet on the very first save
        os.makedirs(os.path.dirname(outf), exist_ok=True)

        logger.info("Saving checkpoint to %s", outf)
        torch.save(data, outf)

        # only the file name is recorded (not the full path), so that the
        # checkpoint directory remains relocatable
        with open(self._last_checkpoint_filename, "w") as f:
            f.write(name)

    def load(self, f: str | None = None):
        """Loads model, optimizer and scheduler from file.

        :param f: Name of a file (absolute or relative to ``self.path``), that
            contains the checkpoint data to load into the model, and optionally
            into the optimizer and the scheduler. If not specified, loads data
            from the last checkpoint recorded in ``self.path``.
        :return: Whatever remains in the checkpoint once the consumed states
            were removed from it (e.g. extra entries passed to :py:meth:`save`),
            or an empty dictionary if no checkpoint could be found.
        """
        if f is None:
            f = self.last_checkpoint()

        if f is None:
            # no checkpoint could be found
            logger.warning("No checkpoint found (and none passed)")
            return {}

        # A name that does not exist as given is looked up inside self.path.
        # os.path.join() leaves absolute paths untouched.
        if not os.path.exists(f):
            candidate = os.path.join(self.path, f)
            if os.path.exists(candidate):
                f = candidate

        # loads file data into memory
        # NOTE(review): depending on the torch version/settings, torch.load
        # may unpickle arbitrary objects - only load checkpoints from trusted
        # sources.
        logger.info("Loading checkpoint from %s...", f)
        checkpoint = torch.load(f, map_location=torch.device("cpu"))

        # converts model entry to model parameters
        self.model.load_state_dict(checkpoint.pop("model"))

        # Optimizer and scheduler states are optional in the file: a state is
        # only restored when both the object and the saved entry are there.
        # Entries without a matching object are left in the returned mapping.
        if self.optimizer is not None:
            if "optimizer" in checkpoint:
                self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
            else:
                logger.warning(
                    "Checkpoint %s has no optimizer state; optimizer left "
                    "untouched",
                    f,
                )
        if self.scheduler is not None:
            if "scheduler" in checkpoint:
                self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
            else:
                logger.warning(
                    "Checkpoint %s has no scheduler state; scheduler left "
                    "untouched",
                    f,
                )

        return checkpoint

    @property
    def _last_checkpoint_filename(self):
        """Path of the text file recording the last saved checkpoint name."""
        return os.path.join(self.path, "last_checkpoint")

    def has_checkpoint(self):
        """Tells if a ``last_checkpoint`` record exists in ``self.path``.

        :return: ``True`` if the record file exists, ``False`` otherwise
        """
        return os.path.exists(self._last_checkpoint_filename)

    def last_checkpoint(self):
        """Returns the path of the last saved checkpoint.

        :return: The recorded checkpoint file name, joined to ``self.path``,
            or ``None`` if no checkpoint was recorded so far
        """
        if self.has_checkpoint():
            with open(self._last_checkpoint_filename) as fobj:
                return os.path.join(self.path, fobj.read().strip())
        return None