class Checkpointer:
    """A simple pytorch checkpointer.

    Saves and restores the state dictionaries of a model and, when given, of
    its optimizer and learning-rate scheduler.

    :param model: Network model, eventually loaded from a checkpointed file
    :param optimizer: Optimizer (optional; its state is saved/loaded only
        when given)
    :param scheduler: Learning rate scheduler (optional; its state is
        saved/loaded only when given)
    :param path: Directory where to save checkpoints. Stored as an absolute
        path with symbolic links resolved (``os.path.realpath``).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer = None,
        # The scheduler base class lives in ``torch.optim.lr_scheduler``; the
        # former ``torch.optim._LRScheduler`` spelling is not an attribute of
        # ``torch.optim`` and raised AttributeError when this signature was
        # evaluated.
        scheduler: torch.optim.lr_scheduler._LRScheduler = None,
        path: str = ".",
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.path = os.path.realpath(path)

    def save(self, name, **kwargs):
        """Saves the current state to ``<self.path>/<name>.pth``.

        The saved dictionary holds the model state dict under ``"model"``,
        the optimizer and scheduler state dicts under ``"optimizer"`` and
        ``"scheduler"`` (only when those objects were given), plus every
        extra keyword argument. After writing, the checkpoint file name
        (``<name>.pth``, relative to ``self.path``) is written to the
        last-checkpoint pointer file.

        :param name: Checkpoint base name; ``.pth`` is appended to it.
        :param kwargs: Additional entries to store alongside the states
            (e.g. the current epoch). They must be serializable by
            ``torch.save``.

        :raises ValueError: If a keyword argument uses a key already holding
            a state collected by the checkpointer (``model`` and, when set,
            ``optimizer`` / ``scheduler``).
        """
        data = {"model": self.model.state_dict()}
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()

        # Letting extra entries silently replace the states collected above
        # would produce a checkpoint that ``load`` restores incorrectly.
        clashes = data.keys() & kwargs.keys()
        if clashes:
            raise ValueError(
                f"Extra checkpoint entries {sorted(clashes)} would overwrite "
                f"the state saved by the checkpointer itself"
            )
        data.update(kwargs)

        name = f"{name}.pth"
        outf = os.path.join(self.path, name)
        # The checkpoint directory may not exist yet (e.g. on a first run).
        os.makedirs(self.path, exist_ok=True)
        logger.info("Saving checkpoint to %s", outf)
        torch.save(data, outf)

        # Record the relative name of the newest checkpoint. The pointer
        # file's location comes from ``_last_checkpoint_filename``, which is
        # defined elsewhere in this class.
        with open(self._last_checkpoint_filename, "w") as f:
            f.write(name)

    def load(self, f: str = None):
        """Loads model, optimizer and scheduler from file.

        :param f: Name of a file (absolute or relative to ``self.path``) that
            contains the checkpoint data to load into the model, and
            optionally into the optimizer and the scheduler. A relative name
            is looked up under ``self.path`` first; if no such file exists
            there, it is used as given (relative to the current working
            directory), as in earlier versions. If not specified, the file
            returned by ``self.last_checkpoint()`` is used.

        :returns: The checkpoint entries that were not consumed by the model,
            optimizer or scheduler (e.g. extra values passed to ``save``), or
            an empty dict when no checkpoint could be found.
        """
        if f is None:
            f = self.last_checkpoint()
        if f is None:
            # no checkpoint could be found
            logger.warning("No checkpoint found (and none passed)")
            return {}

        # Honour the documented "relative to self.path" contract, falling
        # back to the caller's path for backward compatibility.
        if not os.path.isabs(f):
            candidate = os.path.join(self.path, f)
            if os.path.exists(candidate):
                f = candidate

        # loads file data into memory
        logger.info("Loading checkpoint from %s...", f)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(f, map_location=torch.device("cpu"))

        # converts model entry to model parameters
        self.model.load_state_dict(checkpoint.pop("model"))

        # A checkpoint saved without optimizer/scheduler (e.g. model-only)
        # leaves the corresponding object untouched instead of raising a
        # bare KeyError.
        if self.optimizer is not None:
            if "optimizer" in checkpoint:
                self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
            else:
                logger.warning(
                    "Checkpoint %s has no optimizer state; optimizer left "
                    "unchanged",
                    f,
                )
        if self.scheduler is not None:
            if "scheduler" in checkpoint:
                self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
            else:
                logger.warning(
                    "Checkpoint %s has no scheduler state; scheduler left "
                    "unchanged",
                    f,
                )

        return checkpoint