utils_8.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  def save(self, model, iter_nb, train_metrics_values, test_metrics_values, tasks_weights=None, optimizer=None):
      """Record one iteration's metrics, then write logs, config and checkpoints.

      Metrics are stored in ``self.logs_dict['train']`` and
      ``self.logs_dict['val']`` under the key ``str(iter_nb)``, and the whole
      logs dict is rewritten to ``self.logs_file`` as JSON.

      A checkpoint dict is then built. It holds the model state dict and
      ``iter_nb``, plus the optimizer state dict when an optimizer is given.
      The checkpoint is saved:

      * to ``self.best_miou_weights_file`` when ``'MEAN_IOU'`` is tracked and
        its validation value beats ``self.best_miou`` (higher is better);
      * to ``self.best_rel_error_weights_file`` when ``'REL_ERR'`` is tracked
        and its validation value beats ``self.best_rel_error`` (lower is
        better);
      * always to ``self.last_checkpoint_weights_file``.

      Finally ``self.iter_nb`` and ``self.config_dict['iter']`` are updated, and
      ``self.config_dict`` is rewritten to ``self.config_file`` as JSON.

      Args:
          model: Object whose ``state_dict()`` is stored in every checkpoint.
          iter_nb: Iteration number. ``str(iter_nb)`` is the log key; the raw
              value is stored in the checkpoint and the config.
          train_metrics_values: Sequence aligned index-by-index with
              ``self.metrics``. Each entry must be accepted by ``float()``.
          test_metrics_values: Same layout as ``train_metrics_values``, for the
              validation split (logged under ``'val'``).
          tasks_weights: Optional sequence of per-task weights, logged under
              ``'val'`` as ``'weight_<k>'``. ``None`` (the default) logs no
              weights.
          optimizer: Optional object whose ``state_dict()`` is added to the
              checkpoint when it is not ``None``.

      Raises:
          IndexError: If a metrics-values sequence is shorter than
              ``self.metrics``.
      """
      it_key = str(iter_nb)
      train_log = {}
      val_log = {}
      self.logs_dict['train'][it_key] = train_log
      self.logs_dict['val'][it_key] = val_log
      # Explicit indexing (rather than zip) keeps the IndexError on a short
      # values sequence instead of silently truncating it.
      for k, metric_name in enumerate(self.metrics):
          train_log[metric_name] = float(train_metrics_values[k])
          val_log[metric_name] = float(test_metrics_values[k])
      # float() mirrors how metrics are logged above. It keeps tensor / NumPy
      # scalar weights JSON-serializable, so json.dump below cannot fail
      # half-way through rewriting the logs file.
      if tasks_weights is not None:
          for k, weight in enumerate(tasks_weights):
              val_log['weight_' + str(k)] = float(weight)
      with open(self.logs_file, 'w') as f:
          json.dump(self.logs_dict, f)

      ckpt = {
          'model_state_dict': model.state_dict(),
          'iter_nb': iter_nb,
      }
      if optimizer is not None:
          ckpt['optimizer_state_dict'] = optimizer.state_dict()

      # NOTE(review): the `iter_nb > 0` guards presumably skip an evaluation
      # run before any training step -- confirm against the training loop.
      # Saves best miou score if reached (higher is better).
      if 'MEAN_IOU' in self.metrics:
          miou = float(test_metrics_values[self.metrics.index('MEAN_IOU')])
          if miou > self.best_miou and iter_nb > 0:
              print('Best miou. Saving it.')
              torch.save(ckpt, self.best_miou_weights_file)
              self.best_miou = miou
              self.config_dict['best_miou'] = self.best_miou
      # Saves best relative error if reached (lower is better).
      if 'REL_ERR' in self.metrics:
          rel_error = float(test_metrics_values[self.metrics.index('REL_ERR')])
          if rel_error < self.best_rel_error and iter_nb > 0:
              print('Best rel error. Saving it.')
              torch.save(ckpt, self.best_rel_error_weights_file)
              self.best_rel_error = rel_error
              self.config_dict['best_rel_error'] = self.best_rel_error

      # Saves last checkpoint unconditionally.
      torch.save(ckpt, self.last_checkpoint_weights_file)
      self.iter_nb = iter_nb
      self.config_dict['iter'] = self.iter_nb
      with open(self.config_file, 'w') as f:
          json.dump(self.config_dict, f)