utils.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334
  1. import torch
  2. import shutil
  3. import os
  4. def save_checkpoint(state, is_best, file_path, file_name='checkpoint.pth.tar'):
  5. """
  6. Saves the current state of the model. Does a copy of the file
  7. in case the model performed better than previously.
  8. Parameters:
  9. state (dict): Includes optimizer and model state dictionaries.
  10. is_best (bool): True if model is best performing model.
  11. file_path (str): Path to save the file.
  12. file_name (str): File name with extension (default: checkpoint.pth.tar).
  13. """
  14. save_path = os.path.join(file_path, file_name)
  15. torch.save(state, save_path)
  16. if is_best:
  17. shutil.copyfile(save_path, os.path.join(file_path, 'model_best.pth.tar'))
  18. def save_task_checkpoint(file_path, task_num):
  19. """
  20. Saves the current state of the model for a given task by copying existing checkpoint created by the
  21. save_checkpoint function.
  22. Parameters:
  23. file_path (str): Path to save the file,
  24. task_num (int): Number of task increment.
  25. """
  26. save_path = os.path.join(file_path, 'checkpoint_task_' + str(task_num) + '.pth.tar')
  27. shutil.copyfile(os.path.join(file_path, 'checkpoint.pth.tar'), save_path)