traingnn.py

import argparse
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torchsummary import summary

from dataloader.dataloader import PrivacyDataloader
from dataloader.dataset import Dataset
from train.model import GGNN
from train.test import test
from train.train import train
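
# traingnn.py drives training and evaluation of a gated graph neural network (GGNN, see
# train.model) on the program graphs under data/traindata: main(opt) wires together the
# dataloaders, model, loss, optimizer and an optional TensorBoard writer.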

parser = argparse.ArgumentParser()
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--train_batch_size', type=int, default=1, help='input batch size for training')
parser.add_argument('--test_batch_size', type=int, default=1, help='input batch size for testing')
parser.add_argument('--state_dim', type=int, default=106, help='GGNN hidden state size')
parser.add_argument('--n_steps', type=int, default=10, help='number of propagation steps of the GGNN')
parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
parser.add_argument('--cuda', type=bool, default=True, help='enables cuda')
parser.add_argument('--verbal', type=bool, default=True, help='print training info or not')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--n_classes', type=int, default=6, help='number of output classes')
parser.add_argument('--directory', default="data/traindata", help='program data directory')
parser.add_argument('--model_path', default="model/model.ckpt", help='path to save the model')
parser.add_argument('--n_hidden', type=int, default=50, help='number of hidden layers')
parser.add_argument('--size_vocabulary', type=int, default=108, help='maximum number of node types')
parser.add_argument('--is_training_ggnn', type=bool, default=True, help='training GGNN or BiGGNN')
parser.add_argument('--training', type=bool, default=True, help='is training')
parser.add_argument('--testing', type=bool, default=False, help='is testing')
parser.add_argument('--training_percentage', type=float, default=1.0, help='percentage of data used for training')
parser.add_argument('--log_path', default="logs/", help='log path for tensorboard')
parser.add_argument('--epoch', type=int, default=5, help='epoch to test')
parser.add_argument('--n_edge_types', type=int, default=1, help='number of edge types')
parser.add_argument('--n_node', type=int, help='number of node types')
opt = parser.parse_args()
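# Note: argparse's type=bool converts any non-empty string (including "False") to True,
# so flags such as --cuda, --training and --testing cannot be switched off from the command
# line without passing an empty string; action='store_true'/'store_false' is the usual fix.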

if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)
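# Seeding Python's and PyTorch's RNGs (and every CUDA device when --cuda is set) makes a run
# reproducible for a given seed.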
  46. """
  47. opt Namespace(workers=2, train_batch_size=5, test_batch_size=5, state_dim=30, n_steps=10, niter=150, lr=0.01, cuda=False, verbal=True, manualSeed=None, n_classes=10, directory='program_data/github_java_sort_function_babi', model_path='model/model.ckpt', n_hidden=50, size_vocabulary=59, is_training_ggnn=True, training=False, testing=False, training_percentage=1.0, log_path='logs/', epoch=0, pretrained_embeddings='embedding/fast_pretrained_vectors.pkl')
  48. """


def main(opt):
    train_dataset = Dataset("data/traindata", True)
    train_dataloader = PrivacyDataloader(train_dataset, batch_size=5, shuffle=True, num_workers=2)
    test_dataset = Dataset("data/traindata", False)
    test_dataloader = PrivacyDataloader(test_dataset, batch_size=5, shuffle=True, num_workers=2)
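    # Both loaders read the same directory (the second Dataset argument appears to select the
    # train/test split) and use hard-coded batch_size=5 / num_workers=2 rather than
    # opt.train_batch_size / opt.workers.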
    opt.annotation_dim = 1  # for bAbI
    if opt.training:
        opt.n_edge_types = train_dataset.n_edge_types
        opt.n_node = train_dataset.n_node_by_id
    else:
        opt.n_edge_types = test_dataset.n_edge_types
        opt.n_node = test_dataset.n_node_by_id
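    # Graph dimensions (edge-type count and node count) are read from whichever split will
    # actually be used, so the GGNN is sized to match the data.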
    if opt.testing:
        filename = "{}.{}".format(opt.model_path, opt.epoch)
        epoch = opt.epoch
    else:
        filename = opt.model_path
        epoch = -1
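    # When testing, load the checkpoint saved for the requested epoch; otherwise look for the
    # newest "model.ckpt.<epoch>" file next to opt.model_path and resume from it, falling back
    # to a freshly initialised GGNN when nothing has been saved yet.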
    if os.path.exists(filename):
        if opt.testing:
            print("Using No. {} saved model....".format(opt.epoch))
        dirname = os.path.dirname(filename)
        basename = os.path.basename(filename)
        epochs = os.listdir(dirname)
        if len(epochs) > 0:
            for s in epochs:
                if s.startswith(basename) and basename != s:
                    x = s.split(os.extsep)
                    e = x[len(x) - 1]
                    epoch = max(epoch, int(e))
            if epoch != -1:
                print("Using No. {} of the saved models...".format(epoch))
                filename = "{}.{}".format(opt.model_path, epoch)
        if epoch != -1:
            print("Using No. {} saved model....".format(epoch))
        else:
            print("Using saved model....")
        net = torch.load(filename)
    else:
        net = GGNN(opt)
        net.double()
    # print(net)
    criterion = nn.CrossEntropyLoss()
    if opt.cuda:
        net.cuda()
        criterion.cuda()
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
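    # Each training run gets its own TensorBoard directory: run numbers continue from whatever
    # "run-NNN" folders already exist under opt.log_path.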
    if opt.training and opt.log_path != "":
        previous_runs = os.listdir(opt.log_path)
        if len(previous_runs) == 0:
            run_number = 1
        else:
            run_number = max([int(s.split("run-")[1]) for s in previous_runs]) + 1
        writer = SummaryWriter("%s/run-%03d" % (opt.log_path, run_number))
        # writer = SummaryWriter(opt.log_path)
    else:
        writer = None
    print(net)
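    # Epoch numbering continues after the newest checkpoint found above (epoch is -1 for a
    # fresh run, so training then starts at epoch 0).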
    if opt.training:
        for epoch in range(epoch + 1, epoch + 1 + opt.niter):
            train(epoch, train_dataloader, net, criterion, optimizer, opt, writer)
        if writer is not None:
            writer.close()
    if opt.testing:
        filename = "{}.{}".format(opt.model_path, epoch)
        if os.path.exists(filename):
            net = torch.load(filename)
            if opt.cuda:
                net.cuda()
        optimizer = optim.Adam(net.parameters(), lr=opt.lr)
        print(opt)
        test(test_dataloader, net, criterion, optimizer, opt)
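

# Convenience entry points; the calls to main(opt) are currently commented out, so running
# this file as a script only prints the parsed options.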
def test_gnn():
    opt.directory = "data/traindata"
    opt.training = False
    opt.testing = True
    opt.model_path = 'model/model.ckpt'
    print(opt)
    # main(opt)


def train_gnn():
    print(opt)
    # main(opt)


def train_binary(directory="data/traindatabinary"):
    pass


if __name__ == '__main__':
    train_gnn()
  137. """
  138. [-0.5253, -0.7534, 0.6765, -1.0767, 0.7319, -2.6533, 0.5728]
  139. """