# train.py
  1. import torch
  2. from torch.autograd import Variable
  3. from shutil import copyfile
  4. from dataloader.dataloader import PrivacyDataloader
  5. from dataloader.dataset import Dataset
  6. def train(epoch, dataloader, net, criterion, optimizer, opt, writer):
  7. print("------------training_epoch: ", epoch, "----------------------------")
  8. for i, (adj_matrix, embedding_matrix, target) in enumerate(dataloader, 0):
  9. # print("---------", i, "-----------")
  10. net.zero_grad()
  11. # print(embedding_matrix)
  12. # padding = torch.zeros(len(annotation), opt.n_node, opt.state_dim - opt.annotation_dim).double()
  13. # init_input = torch.cat((annotation, padding), 2)
  14. # init_input = torch.zeros(len(adj_matrix), opt.n_node, opt.state_dim).double()
  15. # init_input = torch.from_numpy(embedding_matrix).double()
  16. init_input = embedding_matrix
  17. # print("input_shape", init_input.shape)
  18. # print(init_input)
  19. if opt.cuda:
  20. init_input = init_input.cuda()
  21. adj_matrix = adj_matrix.cuda()
  22. # annotation = annotation.cuda()
  23. target = target.cuda()
  24. init_input = Variable(init_input)
  25. adj_matrix = Variable(adj_matrix)
  26. # annotation = Variable(annotation)
  27. target = Variable(target)
  28. output = net(init_input, adj_matrix)
  29. # print("ouput_shape", output.shape)
  30. # print("target_shape", target.shape)
  31. # print(output)
  32. # print(target)
  33. loss = criterion(output, target)
  34. loss.backward()
  35. optimizer.step()
  36. print("loss", loss)
  37. writer.add_scalar('loss', loss.data.item(), int(epoch))
  38. if i % int(len(dataloader) / 10 + 1) == 0 and opt.verbal:
  39. print('[%d/%d][%d/%d] Loss: %.4f' % (epoch, opt.niter, i, len(dataloader), loss.item()))
  40. torch.save(net, opt.model_path)
  41. copyfile(opt.model_path, "{}.{}".format(opt.model_path, epoch))
if __name__ == '__main__':
    # Example wiring kept for reference (dataset path is machine-specific):
    # traindata = Dataset(
    #     "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindata/train", True)
    # dataloader = PrivacyDataloader(traindata, batch_size=5, shuffle=True, num_workers=2)
    pass