|
@@ -10,6 +10,7 @@ from tensorboardX import SummaryWriter
|
|
|
from dataloader.dataloader import PrivacyDataloader
|
|
|
from dataloader.dataset import Dataset
|
|
|
from train.model import GGNN
|
|
|
+from train.test import test
|
|
|
from train.train import train
|
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
@@ -18,9 +19,9 @@ parser.add_argument('--train_batch_size', type=int, default=5, help='input batch
|
|
|
parser.add_argument('--test_batch_size', type=int, default=5, help='input batch size')
|
|
|
parser.add_argument('--state_dim', type=int, default=106, help='GGNN hidden state size')
|
|
|
parser.add_argument('--n_steps', type=int, default=10, help='propogation steps number of GGNN')
|
|
|
-parser.add_argument('--niter', type=int, default=150, help='number of epochs to train for')
|
|
|
+parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for')
|
|
|
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
|
|
|
-parser.add_argument('--cuda', action='store_true', help='enables cuda')
|
|
|
+parser.add_argument('--cuda', type=bool, default=True, help='enables cuda')
|
|
|
parser.add_argument('--verbal', type=bool, default=True, help='print training info or not')
|
|
|
parser.add_argument('--manualSeed', type=int, help='manual seed')
|
|
|
parser.add_argument('--n_classes', type=int, default=7, help='manual seed')
|
|
@@ -29,13 +30,13 @@ parser.add_argument('--model_path', default="model/model.ckpt", help='path to sa
|
|
|
parser.add_argument('--n_hidden', type=int, default=50, help='number of hidden layers')
|
|
|
parser.add_argument('--size_vocabulary', type=int, default=108, help='maximum number of node types')
|
|
|
parser.add_argument('--is_training_ggnn', type=bool, default=True, help='Training GGNN or BiGGNN')
|
|
|
-parser.add_argument('--training', action="store_true", help='is training')
|
|
|
-parser.add_argument('--testing', action="store_true", help='is testing')
|
|
|
+parser.add_argument('--training', type=bool, default=True, help='is training')
|
|
|
+parser.add_argument('--testing', type=bool, default=False, help='is testing')
|
|
|
parser.add_argument('--training_percentage', type=float, default=1.0, help='percentage of data use for training')
|
|
|
parser.add_argument('--log_path', default="logs/", help='log path for tensorboard')
|
|
|
parser.add_argument('--epoch', type=int, default=5, help='epoch to test')
|
|
|
-parser.add_argument('--n_edge_types', type=int, default=65, help='edge types')
|
|
|
-parser.add_argument('--n_node', type=int, default=100, help='node types')
|
|
|
+parser.add_argument('--n_edge_types', type=int, default=1, help='edge types')
|
|
|
+parser.add_argument('--n_node', type=int, help='node types')
|
|
|
|
|
|
opt = parser.parse_args()
|
|
|
|
|
@@ -55,16 +56,28 @@ opt Namespace(workers=2, train_batch_size=5, test_batch_size=5, state_dim=30, n_
|
|
|
def main(opt):
|
|
|
|
|
|
train_dataset = Dataset(
|
|
|
- "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindata/train", True)
|
|
|
+ "data/traindata/train", True)
|
|
|
train_dataloader = PrivacyDataloader(train_dataset, batch_size=5, shuffle=True, num_workers=2)
|
|
|
|
|
|
+ test_dataset = Dataset(
|
|
|
+ "data/traindata/test", True)
|
|
|
+ test_dataloader = PrivacyDataloader(test_dataset, batch_size=5, shuffle=True, num_workers=2)
|
|
|
+
|
|
|
+
|
|
|
opt.annotation_dim = 1 # for bAbI
|
|
|
if opt.training:
|
|
|
opt.n_edge_types = train_dataset.n_edge_types
|
|
|
opt.n_node = train_dataset.n_node_by_id
|
|
|
+ else:
|
|
|
+ opt.n_edge_types = test_dataset.n_edge_types
|
|
|
+ opt.n_node = test_dataset.n_node_by_id
|
|
|
|
|
|
- filename = opt.model_path
|
|
|
- epoch = -1
|
|
|
+ if opt.testing:
|
|
|
+ filename = "{}.{}".format(opt.model_path, opt.epoch)
|
|
|
+ epoch = opt.epoch
|
|
|
+ else:
|
|
|
+ filename = opt.model_path
|
|
|
+ epoch = -1
|
|
|
|
|
|
if os.path.exists(filename):
|
|
|
if opt.testing:
|
|
@@ -121,6 +134,14 @@ def main(opt):
|
|
|
train(epoch, train_dataloader, net, criterion, optimizer, opt, writer)
|
|
|
writer.close()
|
|
|
|
|
|
+ if opt.testing:
|
|
|
+ filename = "{}.{}".format(opt.model_path, epoch)
|
|
|
+ if os.path.exists(filename):
|
|
|
+ net = torch.load(filename)
|
|
|
+ net.cuda()
|
|
|
+ optimizer = optim.Adam(net.parameters(), lr=opt.lr)
|
|
|
+ test(test_dataloader, net, criterion, optimizer, opt)
|
|
|
+
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
main(opt)
|