|
@@ -6,6 +6,7 @@ import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
from tensorboardX import SummaryWriter
|
|
|
+from torchsummary import summary
|
|
|
|
|
|
from dataloader.dataloader import PrivacyDataloader
|
|
|
from dataloader.dataset import Dataset
|
|
@@ -15,17 +16,17 @@ from train.train import train
|
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
|
|
|
-parser.add_argument('--train_batch_size', type=int, default=5, help='input batch size')
|
|
|
-parser.add_argument('--test_batch_size', type=int, default=5, help='input batch size')
|
|
|
+parser.add_argument('--train_batch_size', type=int, default=1, help='input batch size')
|
|
|
+parser.add_argument('--test_batch_size', type=int, default=1, help='input batch size')
|
|
|
parser.add_argument('--state_dim', type=int, default=106, help='GGNN hidden state size')
|
|
|
parser.add_argument('--n_steps', type=int, default=10, help='propogation steps number of GGNN')
|
|
|
parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for')
|
|
|
-parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
|
|
|
+parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
|
|
|
parser.add_argument('--cuda', type=bool, default=True, help='enables cuda')
|
|
|
parser.add_argument('--verbal', type=bool, default=True, help='print training info or not')
|
|
|
parser.add_argument('--manualSeed', type=int, help='manual seed')
|
|
|
parser.add_argument('--n_classes', type=int, default=7, help='manual seed')
|
|
|
-parser.add_argument('--directory', default="data/traindata/train", help='program data')
|
|
|
+parser.add_argument('--directory', default="data/traindata", help='program data')
|
|
|
parser.add_argument('--model_path', default="model/model.ckpt", help='path to save the model')
|
|
|
parser.add_argument('--n_hidden', type=int, default=50, help='number of hidden layers')
|
|
|
parser.add_argument('--size_vocabulary', type=int, default=108, help='maximum number of node types')
|
|
@@ -56,14 +57,13 @@ opt Namespace(workers=2, train_batch_size=5, test_batch_size=5, state_dim=30, n_
|
|
|
def main(opt):
|
|
|
|
|
|
train_dataset = Dataset(
|
|
|
- "data/traindata/train", True)
|
|
|
+ "data/traindata", True)
|
|
|
train_dataloader = PrivacyDataloader(train_dataset, batch_size=5, shuffle=True, num_workers=2)
|
|
|
|
|
|
test_dataset = Dataset(
|
|
|
- "data/traindata/test", True)
|
|
|
+ "data/traindata", False)
|
|
|
test_dataloader = PrivacyDataloader(test_dataset, batch_size=5, shuffle=True, num_workers=2)
|
|
|
|
|
|
-
|
|
|
opt.annotation_dim = 1 # for bAbI
|
|
|
if opt.training:
|
|
|
opt.n_edge_types = train_dataset.n_edge_types
|
|
@@ -105,7 +105,7 @@ def main(opt):
|
|
|
|
|
|
net = GGNN(opt)
|
|
|
net.double()
|
|
|
- print(net)
|
|
|
+ # print(net)
|
|
|
|
|
|
criterion = nn.CrossEntropyLoss()
|
|
|
|
|
@@ -125,10 +125,8 @@ def main(opt):
|
|
|
# writer = SummaryWriter(opt.log_path)
|
|
|
else:
|
|
|
writer = None
|
|
|
- opt.training = True
|
|
|
- print(opt)
|
|
|
+ print(net)
|
|
|
|
|
|
- # embedding_matrix = train_dataset.embedding_matrix
|
|
|
if opt.training:
|
|
|
for epoch in range(epoch + 1, epoch + opt.niter):
|
|
|
train(epoch, train_dataloader, net, criterion, optimizer, opt, writer)
|
|
@@ -140,8 +138,25 @@ def main(opt):
|
|
|
net = torch.load(filename)
|
|
|
net.cuda()
|
|
|
optimizer = optim.Adam(net.parameters(), lr=opt.lr)
|
|
|
+ print(opt)
|
|
|
test(test_dataloader, net, criterion, optimizer, opt)
|
|
|
|
|
|
|
|
|
-if __name__ == '__main__':
|
|
|
+def test_gnn():
|
|
|
+ opt.directory = "data/traindata"
|
|
|
+ opt.training = False
|
|
|
+ opt.testing = True
|
|
|
+ opt.model_path = 'model/model_bk/model.ckpt'
|
|
|
+ main(opt)
|
|
|
+
|
|
|
+
|
|
|
+def train_gnn():
|
|
|
main(opt)
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ train_gnn()
|
|
|
+
|
|
|
+"""
|
|
|
+[-0.5253, -0.7534, 0.6765, -1.0767, 0.7319, -2.6533, 0.5728]
|
|
|
+"""
|