123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- import argparse
- import os
- import random
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from tensorboardX import SummaryWriter
- from torchsummary import summary
- from dataloader.dataloader import PrivacyDataloader
- from dataloader.dataset import Dataset
- from train.model import GGNN
- from train.test import test
- from train.train import train
- parser = argparse.ArgumentParser()
- parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
- parser.add_argument('--train_batch_size', type=int, default=1, help='input batch size')
- parser.add_argument('--test_batch_size', type=int, default=1, help='input batch size')
- parser.add_argument('--state_dim', type=int, default=106, help='GGNN hidden state size')
- parser.add_argument('--n_steps', type=int, default=10, help='propogation steps number of GGNN')
- parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for')
- parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
- parser.add_argument('--cuda', type=bool, default=True, help='enables cuda')
- parser.add_argument('--verbal', type=bool, default=True, help='print training info or not')
- parser.add_argument('--manualSeed', type=int, help='manual seed')
- parser.add_argument('--n_classes', type=int, default=6, help='manual seed')
- parser.add_argument('--directory', default="data/traindata", help='program data')
- parser.add_argument('--model_path', default="model/model.ckpt", help='path to save the model')
- parser.add_argument('--n_hidden', type=int, default=50, help='number of hidden layers')
- parser.add_argument('--size_vocabulary', type=int, default=108, help='maximum number of node types')
- parser.add_argument('--is_training_ggnn', type=bool, default=True, help='Training GGNN or BiGGNN')
- parser.add_argument('--training', type=bool, default=True, help='is training')
- parser.add_argument('--testing', type=bool, default=False, help='is testing')
- parser.add_argument('--training_percentage', type=float, default=1.0, help='percentage of data use for training')
- parser.add_argument('--log_path', default="logs/", help='log path for tensorboard')
- parser.add_argument('--epoch', type=int, default=5, help='epoch to test')
- parser.add_argument('--n_edge_types', type=int, default=1, help='edge types')
- parser.add_argument('--n_node', type=int, help='node types')
# Parse the command line once, at import time; the helper functions in this
# script read the resulting module-level ``opt``.
opt = parser.parse_args()

# Reproducibility: draw a seed in [1, 10000] when --manualSeed was not given,
# report it so the run can be repeated, then seed Python's ``random``, torch's
# CPU generator and, when --cuda is on, the generators of all CUDA devices.
opt.manualSeed = (
    random.randint(1, 10000) if opt.manualSeed is None else opt.manualSeed
)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

# Reference only: a parsed ``opt`` from an earlier configuration. Its values
# differ from the current defaults (e.g. state_dim=30 here vs. 106 above).
"""
opt Namespace(workers=2, train_batch_size=5, test_batch_size=5, state_dim=30, n_steps=10, niter=150, lr=0.01, cuda=False, verbal=True, manualSeed=None, n_classes=10, directory='program_data/github_java_sort_function_babi', model_path='model/model.ckpt', n_hidden=50, size_vocabulary=59, is_training_ggnn=True, training=False, testing=False, training_percentage=1.0, log_path='logs/', epoch=0, pretrained_embeddings='embedding/fast_pretrained_vectors.pkl')
"""
def _latest_checkpoint_epoch(filename, epoch):
    """Return the largest numeric suffix among siblings of *filename*.

    Lists the directory containing *filename* (the current directory when it
    has none) and looks at entries that start with its basename but are not
    the basename itself. The text after the last ``os.extsep`` is read as an
    epoch number. Returns the maximum of those numbers and the starting
    *epoch*. Entries whose suffix is not an integer are skipped instead of
    raising ValueError.
    """
    dirname = os.path.dirname(filename) or os.curdir
    basename = os.path.basename(filename)
    for entry in os.listdir(dirname):
        if not entry.startswith(basename) or entry == basename:
            continue
        suffix = entry.split(os.extsep)[-1]
        try:
            epoch = max(epoch, int(suffix))
        except ValueError:
            # e.g. "model.ckpt.bak" -- not a saved epoch, ignore it.
            continue
    return epoch


def _next_run_number(log_path):
    """Return one more than the highest ``run-N`` index found in *log_path*.

    Creates *log_path* if it does not exist. Entries without ``run-`` followed
    by an integer are ignored, so stray files no longer crash the numbering.
    Returns 1 when no previous run is found.
    """
    os.makedirs(log_path, exist_ok=True)
    run_numbers = []
    for entry in os.listdir(log_path):
        _, sep, suffix = entry.partition("run-")
        if not sep:
            continue
        try:
            run_numbers.append(int(suffix))
        except ValueError:
            continue
    return max(run_numbers, default=0) + 1


def main(opt):
    """Build data loaders and a GGNN, then train and/or test it as *opt* says.

    Args:
        opt: parsed command-line options (``argparse.Namespace``). This
            function also sets ``opt.annotation_dim`` and copies
            ``n_edge_types`` / ``n_node`` onto it. They come from the training
            dataset when ``opt.training`` is set, otherwise from the test
            dataset.

    When ``opt.training`` is set, runs ``opt.niter`` epochs of ``train`` and
    logs to a new ``run-NNN`` directory under ``opt.log_path`` (unless it is
    empty). When ``opt.testing`` is set, reloads ``<model_path>.<epoch>`` if it
    exists and runs ``test``.
    """
    # Both datasets are read from the same directory; the boolean passed to
    # Dataset selects the split (presumably True = train -- confirm in Dataset).
    # NOTE(review): batch size stays fixed at 5 even though --train_batch_size /
    # --test_batch_size exist (CLI default 1); kept to preserve behaviour.
    train_dataset = Dataset(opt.directory, True)
    train_dataloader = PrivacyDataloader(train_dataset, batch_size=5, shuffle=True, num_workers=opt.workers)
    test_dataset = Dataset(opt.directory, False)
    test_dataloader = PrivacyDataloader(test_dataset, batch_size=5, shuffle=True, num_workers=opt.workers)

    opt.annotation_dim = 1
    source = train_dataset if opt.training else test_dataset
    opt.n_edge_types = source.n_edge_types
    opt.n_node = source.n_node_by_id

    # Decide which checkpoint file (and epoch number) to start from.
    if opt.testing:
        filename = "{}.{}".format(opt.model_path, opt.epoch)
        epoch = opt.epoch
    else:
        filename = opt.model_path
        epoch = -1  # no epoch trained yet

    if os.path.exists(filename):
        if opt.testing:
            print("Using No. {} saved model....".format(opt.epoch))
            epoch = _latest_checkpoint_epoch(filename, epoch)
            if epoch != -1:
                print("Using No. {} of the saved models...".format(epoch))
                filename = "{}.{}".format(opt.model_path, epoch)
        if epoch != -1:
            print("Using No. {} saved model....".format(epoch))
        else:
            print("Using saved model....")
        net = torch.load(filename)
    else:
        net = GGNN(opt)
        net.double()
    # NOTE(review): the next two lines unconditionally replace any network
    # loaded above with a fresh GGNN, so a checkpoint found here is discarded
    # (the testing branch below reloads one). Confirm this is intended.
    net = GGNN(opt)
    net.double()

    criterion = nn.CrossEntropyLoss()
    if opt.cuda:
        net.cuda()
        criterion.cuda()
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)

    writer = None
    if opt.training and opt.log_path != "":
        run_number = _next_run_number(opt.log_path)
        writer = SummaryWriter("%s/run-%03d" % (opt.log_path, run_number))
    print(net)

    if opt.training:
        # Train exactly opt.niter epochs; the old upper bound stopped one short.
        first_epoch = epoch + 1
        for epoch in range(first_epoch, first_epoch + opt.niter):
            train(epoch, train_dataloader, net, criterion, optimizer, opt, writer)
        # writer is None when logging is disabled (empty --log_path).
        if writer is not None:
            writer.close()

    if opt.testing:
        # Evaluate the checkpoint of the last epoch reached above.
        filename = "{}.{}".format(opt.model_path, epoch)
        if os.path.exists(filename):
            net = torch.load(filename)
            if opt.cuda:
                net.cuda()
            optimizer = optim.Adam(net.parameters(), lr=opt.lr)
        print(opt)
        test(test_dataloader, net, criterion, optimizer, opt)
def test_gnn():
    """Reconfigure the module-level ``opt`` for evaluation and print it.

    Sets ``directory`` to ``data/traindata``, turns training off and testing
    on, and points ``model_path`` at ``model/model.ckpt``.

    NOTE(review): only ``opt`` is changed; ``main(opt)`` is not called here.
    """
    for attr, value in (
        ("directory", "data/traindata"),
        ("training", False),
        ("testing", True),
        ("model_path", 'model/model.ckpt'),
    ):
        setattr(opt, attr, value)
    print(opt)
-
def train_gnn():
    """Echo the module-level ``opt`` namespace to stdout.

    NOTE(review): despite the name, nothing is trained here -- ``main(opt)``
    is never called. Confirm whether that call was dropped on purpose.
    """
    options = opt
    print(options)
-
def train_binary(directory="data/traindatabinary"):
    """Unimplemented stub.

    *directory* is accepted but unused, and the call always returns ``None``.
    """
    return None
# Script entry point: running this file directly calls train_gnn().
if __name__ == '__main__':
    train_gnn()
# NOTE(review): the string below is a no-op expression statement that looks
# like pasted sample output (seven floats); confirm whether it can be removed.
"""
[-0.5253, -0.7534, 0.6765, -1.0767, 0.7319, -2.6533, 0.5728]
"""
|