# checkpoint_2.py
  1. def save(self, name, **kwargs):
  2. if not self.save_dir:
  3. return
  4. if not self.save_to_disk:
  5. return
  6. data = {}
  7. data["model"] = self.model.state_dict()
  8. if self.optimizer is not None:
  9. data["optimizer"] = self.optimizer.state_dict()
  10. if self.scheduler is not None:
  11. data["scheduler"] = self.scheduler.state_dict()
  12. data.update(kwargs)
  13. save_file = os.path.join(self.save_dir, "{}.pth".format(name))
  14. self.logger.info("Saving checkpoint to {}".format(save_file))
  15. torch.save(data, save_file)
  16. self.tag_last_checkpoint(save_file)