1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
def save(self, model, iter_nb, train_metrics_values, test_metrics_values, tasks_weights=None, optimizer=None):
    """Log the metrics of one iteration and write the checkpoints.

    The method records the metric values in ``self.logs_dict``, dumps
    that dict to ``self.logs_file``, writes the "last" checkpoint (and
    the "best" ones when a tracked metric improved), then dumps
    ``self.config_dict`` to ``self.config_file``.

    Args:
        model: Object exposing ``state_dict()``; stored in the checkpoint.
        iter_nb: Iteration number. Used (as a string) as the key in the
            logs and stored in the checkpoint and the config.
        train_metrics_values: Training metric values, indexed in the same
            order as ``self.metrics``.
        test_metrics_values: Validation metric values, indexed in the same
            order as ``self.metrics``.
        tasks_weights: Optional sequence of per-task weights, logged in
            the 'val' entry under 'weight_<index>'. ``None`` or an empty
            sequence logs nothing.
        optimizer: Optional object exposing ``state_dict()``; stored in
            the checkpoint when given.
    """
    # A None sentinel replaces the former mutable `[]` default.
    if tasks_weights is None:
        tasks_weights = []

    iter_key = str(iter_nb)
    train_entry = {}
    val_entry = {}
    self.logs_dict['train'][iter_key] = train_entry
    self.logs_dict['val'][iter_key] = val_entry

    # Index (rather than zip) so a too-short values sequence still raises
    # instead of being silently truncated.
    for idx, metric_name in enumerate(self.metrics):
        train_entry[metric_name] = float(train_metrics_values[idx])
        val_entry[metric_name] = float(test_metrics_values[idx])

    # len() is used instead of truthiness so array-like weights keep working.
    # NOTE(review): weights are stored as-is (no float() conversion), so
    # they must already be JSON-serializable -- confirm against callers.
    if len(tasks_weights) > 0:
        for idx, weight in enumerate(tasks_weights):
            val_entry['weight_' + str(idx)] = weight

    with open(self.logs_file, 'w') as f:
        json.dump(self.logs_dict, f)

    ckpt = {
        'model_state_dict': model.state_dict(),
        'iter_nb': iter_nb,
    }
    if optimizer is not None:
        ckpt['optimizer_state_dict'] = optimizer.state_dict()

    # Saves best miou score if reached (never at iteration 0)
    if 'MEAN_IOU' in self.metrics:
        miou = float(test_metrics_values[self.metrics.index('MEAN_IOU')])
        if miou > self.best_miou and iter_nb > 0:
            print('Best miou. Saving it.')
            torch.save(ckpt, self.best_miou_weights_file)
            self.best_miou = miou
            self.config_dict['best_miou'] = self.best_miou

    # Saves best relative error if reached (lower is better; never at iteration 0)
    if 'REL_ERR' in self.metrics:
        rel_error = float(test_metrics_values[self.metrics.index('REL_ERR')])
        if rel_error < self.best_rel_error and iter_nb > 0:
            print('Best rel error. Saving it.')
            torch.save(ckpt, self.best_rel_error_weights_file)
            self.best_rel_error = rel_error
            self.config_dict['best_rel_error'] = self.best_rel_error

    # Saves last checkpoint
    torch.save(ckpt, self.last_checkpoint_weights_file)
    self.iter_nb = iter_nb
    self.config_dict['iter'] = self.iter_nb
    with open(self.config_file, 'w') as f:
        json.dump(self.config_dict, f)
|