utils_1.py 681 B

12345678910111213141516
  1. def save_checkpoint(state, is_best, file_path, file_name='checkpoint.pth.tar'):
  2. """
  3. Saves the current state of the model. Does a copy of the file
  4. in case the model performed better than previously.
  5. Parameters:
  6. state (dict): Includes optimizer and model state dictionaries.
  7. is_best (bool): True if model is best performing model.
  8. file_path (str): Path to save the file.
  9. file_name (str): File name with extension (default: checkpoint.pth.tar).
  10. """
  11. save_path = os.path.join(file_path, file_name)
  12. torch.save(state, save_path)
  13. if is_best:
  14. shutil.copyfile(save_path, os.path.join(file_path, 'model_best.pth.tar'))