
Training update

wendaojidian committed 2 years ago
commit 72aee24e4a
4 changed files with 114 additions and 24 deletions
  1. dataloader/dataset.py (+103, -19)
  2. test.py (+1, -1)
  3. traingnn.py (+10, -4)
  4. 数据.md (+0, -0)

+ 103 - 19
dataloader/dataset.py

@@ -1,30 +1,50 @@
 import os
+import random
 
 import numpy as np
 
 
-def load_from_directory(path):
+def load_from_file(file_path, is_binary=0):
+    """
+
+    :param file_path:
+    :param is_binary: 0:not binary, 1:binary type, 2:binary other
+    :return:
+    """
     node_id_data_list = []
     node_type_data_list = []
     node_id = list()
     node_type = list()
     graph_type = list()
-    none_file_path = ""
-    for file_name in os.listdir(path):
-        with open(path + "/" + file_name, 'r') as file:
-            for line in file:
-                if len(line.strip()) == 0:
-                    node_id_data_list.append([node_id, graph_type])
-                    node_type_data_list.append([node_type, graph_type])
-                    node_id = list()
-                    node_type = list()
-                    graph_type = list()
-                elif len(line.split(' ')) == 3:
+    with open(file_path) as file:
+        for line in file:
+            if len(line.strip()) == 0:
+                node_id_data_list.append([node_id, graph_type])
+                node_type_data_list.append([node_type, graph_type])
+                node_id = list()
+                node_type = list()
+                graph_type = list()
+            elif len(line.split(' ')) == 3:
+                if is_binary == 0:
                     graph_type.append([int(line.split(' ')[1])])
+                elif is_binary == 1:
+                    graph_type.append([1])
                 else:
-                    data = line.split(' ')
-                    node_id.append([int(data[0]), int(data[2]), int(data[3])])
-                    node_type.append([int(data[1]), int(data[2]), int(data[4])])
+                    graph_type.append([2])
+            else:
+                data = line.split(' ')
+                node_id.append([int(data[0]), int(data[2]), int(data[3])])
+                node_type.append([int(data[1]), int(data[2]), int(data[4])])
+    return node_id_data_list, node_type_data_list
+
+
+def load_from_directory(path):
+    node_id_data_list = []
+    node_type_data_list = []
+    for file_name in os.listdir(path):
+        node_id, node_type = load_from_file(path + "/" + file_name)
+        node_id_data_list.extend(node_id)
+        node_type_data_list.extend(node_type)
     return node_id_data_list, node_type_data_list
 
 
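For reference, the parser in load_from_file implies the following plain-text graph format (reconstructed from the indexing above; the field roles are an inference, not documented in the repo): graphs are separated by blank lines, a 3-token line carries the graph label as its second field, and every other line is an edge of five integers read as src_id src_type edge_type dst_id dst_type. An illustrative, made-up file:

    1 3 1 2 5
    2 5 2 3 7
    0 4 0

    1 2 1 3 3
    0 6 0

Here the first graph has two edges and label 4, and the blank line starts a second graph with label 6.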
@@ -100,6 +120,7 @@ class Dataset:
     """
     Load bAbI tasks for GGNN
     """
+
     def __init__(self, path, is_train):
         data_id = list()
         data_type = list()
@@ -126,7 +147,69 @@ class Dataset:
 
     def __getitem__(self, index):
         am = create_adjacency_matrix(self.node_by_id[index][0], self.n_node_by_id, self.n_edge_types)
-        annotation = create_embedding_matrix(self.node_by_id[index][0], self.node_by_type[index][0], self.n_node_by_id, self.n_node_by_type)
+        annotation = create_embedding_matrix(self.node_by_id[index][0], self.node_by_type[index][0], self.n_node_by_id,
+                                             self.n_node_by_type)
+        target = self.node_by_id[index][2] - 1
+        return am, annotation, target
+
+    def __len__(self):
+        return len(self.node_by_id)
+
+
+def load_from_directory_binary(path, class_type):
+    node_id_data_list = []
+    node_type_data_list = []
+    # binary true
+    node_id_binary_true, node_type_binary_true = load_from_file(path + "/" + class_type + ".txt", 1)
+    node_id_data_list.extend(node_id_binary_true)
+    node_type_data_list.extend(node_type_binary_true)
+    id_len = len(node_id_data_list)
+
+    # binary false
+    node_id_data_list_false = []
+    node_type_data_list_false = []
+    for file_name in os.listdir(path):
+        if file_name != class_type + ".txt":
+            node_id_binary_false, node_type_binary_false = load_from_file(path + "/" + file_name)
+            node_id_data_list_false.extend(node_id_binary_false)
+            node_type_data_list_false.extend(node_type_binary_false)
+    paired = list(zip(node_id_data_list_false, node_type_data_list_false))
+    random.shuffle(paired)  # shuffle ids and types together so each graph's entries stay aligned
+    node_id_data_list.extend(p[0] for p in paired[:id_len])
+    node_type_data_list.extend(p[1] for p in paired[:id_len])
+    return node_id_data_list, node_type_data_list
+
+
+class BinaryDataset:
+    def __init__(self, path, class_type, is_train):
+        data_id = list()
+        data_type = list()
+
+        train_data_id, train_data_type = load_from_directory_binary(path + "/train", class_type)
+        test_data_id, test_data_type = load_from_directory_binary(path + "/test", class_type)
+
+        data_id.extend(train_data_id)
+        data_id.extend(test_data_id)
+        data_type.extend(train_data_type)
+        data_type.extend(test_data_type)
+
+        self.n_edge_types = find_max_edge_id(data_id)
+        max_node_id = find_max_node_id(data_id)
+        max_node_type = find_max_node_id(data_type)
+
+        self.n_node_by_id = max_node_id
+        self.n_node_by_type = max_node_type
+        if is_train:
+            self.node_by_id = convert_program_data(train_data_id, 1, self.n_node_by_id)
+            self.node_by_type = convert_program_data(train_data_type, 1, self.n_node_by_type)
+        else:
+            self.node_by_id = convert_program_data(test_data_id, 1, self.n_node_by_id)
+            self.node_by_type = convert_program_data(test_data_type, 1, self.n_node_by_type)
+
+    def __getitem__(self, index):
+        am = create_adjacency_matrix(self.node_by_id[index][0], self.n_node_by_id, self.n_edge_types)
+        annotation = create_embedding_matrix(self.node_by_id[index][0], self.node_by_type[index][0], self.n_node_by_id,
+                                             self.n_node_by_type)
         target = self.node_by_id[index][2] - 1
         return am, annotation, target
 
@@ -138,7 +221,8 @@ if __name__ == '__main__':
     # data = load_graphs_from_file(
     #     "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindata/train/Directory.txt")
     # a = 5
-    bi = Dataset(
-        "I:\Program\Python\sap\GnnForPrivacyScan\data\\traindata", True)
-    for d in bi:
+    # bi = Dataset(
+    #     "I:\Program\Python\sap\GnnForPrivacyScan\data\\traindata", True)
+    binary_dataset = BinaryDataset(r"I:\Program\Python\sap\GnnForPrivacyScan\data\traindatabinary", "Archive", True)
+    for d in binary_dataset:
         a = 5
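Taken together, BinaryDataset recasts the multi-class data as a balanced one-vs-rest task: every graph in <class_type>.txt gets label 1, an equal-sized random sample of graphs from the other files gets label 2, and __getitem__ returns target = label - 1, i.e. 0 for the positive class and 1 for everything else. A minimal usage sketch, assuming the repo root is on sys.path and graphs are fed to the GGNN one at a time; the DataLoader wrapping is an illustration, not part of this commit:

    from torch.utils.data import DataLoader
    from dataloader.dataset import BinaryDataset

    # "Archive" vs. every other class; path and class name as in the diff above
    train_set = BinaryDataset("data/traindatabinary", "Archive", True)
    loader = DataLoader(train_set, batch_size=1, shuffle=True)
    for am, annotation, target in loader:
        # am: adjacency matrix, annotation: initial node features,
        # target: 0 (Archive) or 1 (any other class)
        pass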

+ 1 - 1
test.py

@@ -20,4 +20,4 @@ CUR_PATH = os.path.dirname(__file__)
 file_list = "data/purposeSplit.bk"
 for directory in os.listdir(file_list):
     directory2 = file_list + "/" + directory
-    print(len(walk_files(directory2)))
+    print(directory, len(walk_files(directory2)))

+ 10 - 4
traingnn.py

@@ -25,7 +25,7 @@ parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
 parser.add_argument('--cuda', type=bool, default=True, help='enables cuda')
 parser.add_argument('--verbal', type=bool, default=True, help='print training info or not')
 parser.add_argument('--manualSeed', type=int, help='manual seed')
-parser.add_argument('--n_classes', type=int, default=7, help='manual seed')
+parser.add_argument('--n_classes', type=int, default=6, help='number of target classes')
 parser.add_argument('--directory', default="data/traindata", help='program data')
 parser.add_argument('--model_path', default="model/model.ckpt", help='path to save the model')
 parser.add_argument('--n_hidden', type=int, default=50, help='number of hidden layers')
@@ -146,12 +146,18 @@ def test_gnn():
     opt.directory = "data/traindata"
     opt.training = False
     opt.testing = True
-    opt.model_path = 'model/model_bk/model.ckpt'
-    main(opt)
+    opt.model_path = 'model/model.ckpt'
+    print(opt)
+    # main(opt)
 
 
 def train_gnn():
-    main(opt)
+    print(opt)
+    # main(opt)
+
+
+def train_binary(directory="data/traindatabinary"):
+    pass
 
 
 if __name__ == '__main__':
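train_binary is only a stub in this commit. One plausible shape, mirroring the existing train_gnn flow but pointing main(opt) at the new balanced directory with two classes; everything here, including the opt.class_type field, is an assumption rather than what the commit implements:

    def train_binary(directory="data/traindatabinary"):
        opt.directory = directory
        opt.n_classes = 2                    # positive class vs. the rest
        for class_type in ["Archive"]:       # extend with the remaining class names
            opt.class_type = class_type      # hypothetical field consumed by main()
            opt.model_path = "model/model_%s.ckpt" % class_type
            main(opt)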

+ 0 - 0
数据.md