train.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import torch
  2. from torch.autograd import Variable
  3. from shutil import copyfile
  4. from torchsummary import summary
  5. from dataloader.dataloader import PrivacyDataloader
  6. from dataloader.dataset import Dataset
  7. def train(epoch, dataloader, net, criterion, optimizer, opt, writer):
  8. print("------------training_epoch: ", epoch, "----------------------------")
  9. for i, (adj_matrix, embedding_matrix, target) in enumerate(dataloader, 0):
  10. # print("---------", i, "-----------")
  11. net.zero_grad()
  12. # print(embedding_matrix)
  13. # padding = torch.zeros(len(annotation), opt.n_node, opt.state_dim - opt.annotation_dim).double()
  14. # init_input = torch.cat((annotation, padding), 2)
  15. # init_input = torch.zeros(len(adj_matrix), opt.n_node, opt.state_dim).double()
  16. # init_input = torch.from_numpy(embedding_matrix).double()
  17. init_input = embedding_matrix
  18. print("input_shape", init_input.shape)
  19. # print(init_input)
  20. if opt.cuda:
  21. init_input = init_input.cuda()
  22. adj_matrix = adj_matrix.cuda()
  23. # annotation = annotation.cuda()
  24. target = target.cuda()
  25. init_input = Variable(init_input)
  26. adj_matrix = Variable(adj_matrix)
  27. # annotation = Variable(annotation)
  28. target = Variable(target)
  29. # print(adj_matrix.shape)
  30. output = net(init_input, adj_matrix)
  31. # print("ouput_shape", output.shape)
  32. # print("target_shape", target.shape)
  33. print("output", output)
  34. print(target)
  35. loss = criterion(output, target)
  36. loss.backward()
  37. optimizer.step()
  38. print("loss", loss)
  39. writer.add_scalar('loss', loss.data.item(), int(epoch))
  40. if i % int(len(dataloader) / 10 + 1) == 0 and opt.verbal:
  41. print('[%d/%d][%d/%d] Loss: %.4f' % (epoch, opt.niter, i, len(dataloader), loss.item()))
  42. torch.save(net, opt.model_path)
  43. copyfile(opt.model_path, "{}.{}".format(opt.model_path, epoch))
  44. if __name__ == '__main__':
  45. # traindata = Dataset(
  46. # "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindata/train", True)
  47. # dataloader = PrivacyDataloader(traindata, batch_size=5, shuffle=True, num_workers=2)
  48. pass