def save(self, model, iter_nb, train_metrics_values, test_metrics_values, tasks_weights=None, optimizer=None):
    """Log metrics for one iteration and write the model checkpoints.

    Records the train/val metric values for ``iter_nb`` in ``self.logs_dict``
    and rewrites ``self.logs_file`` as JSON. It then builds a checkpoint dict
    containing the model state, the iteration number and, if given, the
    optimizer state. That checkpoint is written with ``torch.save``:

    - to ``self.best_miou_weights_file`` when ``MEAN_IOU`` improves,
    - to ``self.best_rel_error_weights_file`` when ``REL_ERR`` decreases,
    - and always to ``self.last_checkpoint_weights_file``.

    Finally it updates ``self.config_dict`` and rewrites ``self.config_file``.

    Args:
        model: object exposing ``state_dict()`` (presumably a
            ``torch.nn.Module`` -- confirm against callers).
        iter_nb: iteration number; its ``str()`` is used as the log key.
        train_metrics_values: values aligned index-by-index with
            ``self.metrics``; each is converted with ``float()``.
        test_metrics_values: values aligned index-by-index with
            ``self.metrics``; logged under ``'val'`` and used for the
            best-checkpoint decisions.
        tasks_weights: optional sequence of per-task weights. They are logged
            under ``'val'`` as ``weight_<k>``. Defaults to no weights.
        optimizer: optional object exposing ``state_dict()``. Its state is
            added to the checkpoint when provided.

    Raises:
        IndexError: if a metric-values sequence is shorter than
            ``self.metrics``.
        TypeError / ValueError: if a metric value or task weight cannot be
            converted with ``float()``.
    """
    # None sentinel instead of a shared mutable default list.
    if tasks_weights is None:
        tasks_weights = []
    key = str(iter_nb)

    train_log = {}
    val_log = {}
    for k, metric in enumerate(self.metrics):
        # Index (rather than zip) so a too-short values sequence still raises.
        train_log[metric] = float(train_metrics_values[k])
        val_log[metric] = float(test_metrics_values[k])
    for k, weight in enumerate(tasks_weights):
        # float() so tensor / numpy scalar weights are JSON-serializable,
        # matching the treatment of the metric values above.
        val_log['weight_' + str(k)] = float(weight)
    self.logs_dict['train'][key] = train_log
    self.logs_dict['val'][key] = val_log

    with open(self.logs_file, 'w') as f:
        json.dump(self.logs_dict, f)

    ckpt = {
        'model_state_dict': model.state_dict(),
        'iter_nb': iter_nb,
    }
    if optimizer is not None:
        ckpt['optimizer_state_dict'] = optimizer.state_dict()

    # Best checkpoints are never written at iteration 0 (iter_nb > 0 guards).

    # Saves best miou score if reached (higher is better).
    if 'MEAN_IOU' in self.metrics:
        miou = float(test_metrics_values[self.metrics.index('MEAN_IOU')])
        if miou > self.best_miou and iter_nb > 0:
            print('Best miou. Saving it.')
            torch.save(ckpt, self.best_miou_weights_file)
            self.best_miou = miou
            self.config_dict['best_miou'] = self.best_miou

    # Saves best relative error if reached (lower is better).
    if 'REL_ERR' in self.metrics:
        rel_error = float(test_metrics_values[self.metrics.index('REL_ERR')])
        if rel_error < self.best_rel_error and iter_nb > 0:
            print('Best rel error. Saving it.')
            torch.save(ckpt, self.best_rel_error_weights_file)
            self.best_rel_error = rel_error
            self.config_dict['best_rel_error'] = self.best_rel_error

    # Saves last checkpoint unconditionally.
    torch.save(ckpt, self.last_checkpoint_weights_file)
    self.iter_nb = iter_nb
    self.config_dict['iter'] = self.iter_nb
    with open(self.config_file, 'w') as f:
        json.dump(self.config_dict, f)