# train.py
from shutil import copyfile

import torch
from torch.autograd import Variable

from dataloader.dataloader import PrivacyDataloader
from dataloader.dataset import Dataset
  6. def train(epoch, dataloader, net, criterion, optimizer, opt, writer):
  7. for i, (adj_matrix, embedding_matrix, target) in enumerate(dataloader, 0):
  8. print("----------------")
  9. net.zero_grad()
  10. print(embedding_matrix)
  11. # padding = torch.zeros(len(annotation), opt.n_node, opt.state_dim - opt.annotation_dim).double()
  12. # init_input = torch.cat((annotation, padding), 2)
  13. # init_input = torch.zeros(len(adj_matrix), opt.n_node, opt.state_dim).double()
  14. # init_input = torch.from_numpy(embedding_matrix).double()
  15. init_input = embedding_matrix
  16. print(init_input.shape)
  17. print(init_input)
  18. if opt.cuda:
  19. init_input = init_input.cuda()
  20. adj_matrix = adj_matrix.cuda()
  21. # annotation = annotation.cuda()
  22. target = target.cuda()
  23. init_input = Variable(init_input)
  24. adj_matrix = Variable(adj_matrix)
  25. # annotation = Variable(annotation)
  26. target = Variable(target)
  27. output = net(init_input, adj_matrix)
  28. print(output.shape)
  29. print(target.shape)
  30. # print(output)
  31. # print(target)
  32. loss = criterion(output, target)
  33. loss.backward()
  34. optimizer.step()
  35. print(loss)
  36. print(epoch)
  37. writer.add_scalar('loss', loss.data.item(), int(epoch))
  38. if i % int(len(dataloader) / 10 + 1) == 0 and opt.verbal:
  39. print('[%d/%d][%d/%d] Loss: %.4f' % (epoch, opt.niter, i, len(dataloader), loss.item()))
  40. torch.save(net, opt.model_path)
  41. copyfile(opt.model_path, "{}.{}".format(opt.model_path, epoch))
  42. if __name__ == '__main__':
  43. # traindata = Dataset(
  44. # "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindata/train", True)
  45. # dataloader = PrivacyDataloader(traindata, batch_size=5, shuffle=True, num_workers=2)
  46. pass