12345678910111213141516171819 |
def save(self, name, **kwargs):
    """Serialize the model (and optimizer/scheduler, if set) to ``<save_dir>/<name>.pth``.

    Does nothing (returns ``None``) when ``self.save_dir`` is falsy or
    ``self.save_to_disk`` is falsy.

    The checkpoint is a dict with:
      * ``"model"``: ``self.model.state_dict()``
      * ``"optimizer"``: ``self.optimizer.state_dict()`` (only if an optimizer is set)
      * ``"scheduler"``: ``self.scheduler.state_dict()`` (only if a scheduler is set)
      * every entry of ``kwargs``. These are merged last, so a kwarg with one of
        the keys above replaces that entry.

    The file is first written to a temporary sibling path and then moved onto
    the final path with ``os.replace``. An interrupted or failed save therefore
    never leaves a truncated ``<name>.pth``, and never corrupts an existing
    checkpoint of the same name. After a successful save the path is passed to
    ``self.tag_last_checkpoint``.

    Args:
        name: Base file name of the checkpoint, without the ``.pth`` extension.
        **kwargs: Extra picklable entries to store alongside the state dicts.

    Raises:
        Whatever ``torch.save`` or ``os.replace`` raise. The temporary file is
        removed before the exception propagates.
    """
    if not self.save_dir or not self.save_to_disk:
        return

    data = {"model": self.model.state_dict()}
    if self.optimizer is not None:
        data["optimizer"] = self.optimizer.state_dict()
    if self.scheduler is not None:
        data["scheduler"] = self.scheduler.state_dict()
    data.update(kwargs)

    save_file = os.path.join(self.save_dir, "{}.pth".format(name))
    self.logger.info("Saving checkpoint to %s", save_file)

    # Write-then-rename: os.replace swaps the fully written file into place,
    # so readers never observe a half-written checkpoint at save_file.
    tmp_file = save_file + ".tmp"
    try:
        torch.save(data, tmp_file)
        os.replace(tmp_file, save_file)
    except BaseException:
        # Don't leave a partial temp file behind on failure or interrupt.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise

    self.tag_last_checkpoint(save_file)
|