traingnn.py

import argparse
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter

from dataloader.dataloader import PrivacyDataloader
from dataloader.dataset import Dataset
from train.model import GGNN
from train.test import test
from train.train import train
parser = argparse.ArgumentParser()
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
parser.add_argument('--train_batch_size', type=int, default=5, help='input batch size for training')
parser.add_argument('--test_batch_size', type=int, default=5, help='input batch size for testing')
parser.add_argument('--state_dim', type=int, default=106, help='GGNN hidden state size')
parser.add_argument('--n_steps', type=int, default=10, help='number of GGNN propagation steps')
parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--cuda', type=bool, default=True, help='enables cuda')
parser.add_argument('--verbal', type=bool, default=True, help='print training info or not')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--n_classes', type=int, default=7, help='number of output classes')
parser.add_argument('--directory', default="data/traindata/train", help='program data directory')
parser.add_argument('--model_path', default="model/model.ckpt", help='path to save the model')
parser.add_argument('--n_hidden', type=int, default=50, help='number of hidden layers')
parser.add_argument('--size_vocabulary', type=int, default=108, help='maximum number of node types')
parser.add_argument('--is_training_ggnn', type=bool, default=True, help='training GGNN or BiGGNN')
parser.add_argument('--training', type=bool, default=True, help='is training')
parser.add_argument('--testing', type=bool, default=False, help='is testing')
parser.add_argument('--training_percentage', type=float, default=1.0, help='percentage of data used for training')
parser.add_argument('--log_path', default="logs/", help='log path for tensorboard')
parser.add_argument('--epoch', type=int, default=5, help='epoch to test')
parser.add_argument('--n_edge_types', type=int, default=1, help='number of edge types')
parser.add_argument('--n_node', type=int, help='number of node types')
opt = parser.parse_args()
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

"""
opt = Namespace(workers=2, train_batch_size=5, test_batch_size=5, state_dim=30, n_steps=10, niter=150,
                lr=0.01, cuda=False, verbal=True, manualSeed=None, n_classes=10,
                directory='program_data/github_java_sort_function_babi', model_path='model/model.ckpt',
                n_hidden=50, size_vocabulary=59, is_training_ggnn=True, training=False, testing=False,
                training_percentage=1.0, log_path='logs/', epoch=0,
                pretrained_embeddings='embedding/fast_pretrained_vectors.pkl')
"""

def main(opt):
    # Build datasets and dataloaders for training and testing.
    train_dataset = Dataset(opt.directory, True)
    train_dataloader = PrivacyDataloader(train_dataset, batch_size=opt.train_batch_size,
                                         shuffle=True, num_workers=opt.workers)
    test_dataset = Dataset("data/traindata/test", True)
    test_dataloader = PrivacyDataloader(test_dataset, batch_size=opt.test_batch_size,
                                        shuffle=True, num_workers=opt.workers)

    opt.annotation_dim = 1  # for bAbI
    if opt.training:
        opt.n_edge_types = train_dataset.n_edge_types
        opt.n_node = train_dataset.n_node_by_id
    else:
        opt.n_edge_types = test_dataset.n_edge_types
        opt.n_node = test_dataset.n_node_by_id

    if opt.testing:
        filename = "{}.{}".format(opt.model_path, opt.epoch)
        epoch = opt.epoch
    else:
        filename = opt.model_path
        epoch = -1
    if os.path.exists(filename):
        if not opt.testing:
            # Resume from the checkpoint with the highest epoch suffix, e.g. model.ckpt.12.
            dirname = os.path.dirname(filename)
            basename = os.path.basename(filename)
            for s in os.listdir(dirname):
                if s.startswith(basename) and basename != s:
                    x = s.split(os.extsep)
                    e = x[len(x) - 1]
                    epoch = max(epoch, int(e))
            if epoch != -1:
                filename = "{}.{}".format(opt.model_path, epoch)
        if epoch != -1:
            print("Using No. {} saved model....".format(epoch))
        else:
            print("Using saved model....")
        net = torch.load(filename)
    else:
        net = GGNN(opt)
        net.double()
    print(net)
    criterion = nn.CrossEntropyLoss()
    if opt.cuda:
        net.cuda()
        criterion.cuda()
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    if opt.training and opt.log_path != "":
        # Give each run its own numbered subdirectory under the log path.
        previous_runs = os.listdir(opt.log_path)
        if len(previous_runs) == 0:
            run_number = 1
        else:
            run_number = max([int(s.split("run-")[1]) for s in previous_runs]) + 1
        writer = SummaryWriter("%s/run-%03d" % (opt.log_path, run_number))
        # writer = SummaryWriter(opt.log_path)
    else:
        writer = None

    opt.training = True
    print(opt)
    # embedding_matrix = train_dataset.embedding_matrix
    if opt.training:
        # Continue epoch numbering from the last loaded checkpoint (epoch is -1 for a fresh run).
        for epoch in range(epoch + 1, epoch + 1 + opt.niter):
            train(epoch, train_dataloader, net, criterion, optimizer, opt, writer)
        if writer is not None:
            writer.close()

    if opt.testing:
        filename = "{}.{}".format(opt.model_path, epoch)
        if os.path.exists(filename):
            net = torch.load(filename)
            if opt.cuda:
                net.cuda()
        optimizer = optim.Adam(net.parameters(), lr=opt.lr)
        test(test_dataloader, net, criterion, optimizer, opt)

if __name__ == '__main__':
    main(opt)
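
# Example invocation (a sketch; it assumes the dataloader/ and train/ packages and the
# data/traindata/{train,test} directories from this repository are in place):
#
#   python traingnn.py --niter 10 --lr 0.01 --train_batch_size 5 --test_batch_size 5
#
# Note: the boolean options above are declared with type=bool, so argparse converts any
# non-empty string (including "False") to True; omit a flag to keep its default rather
# than passing, e.g., --training False.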