@@ -1,30 +1,50 @@
 import os
+import random
 import numpy as np
-def load_from_directory(path):
+def load_from_file(file_path, is_binary=0):
+    """
+    Parse a single graph file into node-id and node-type data lists.
+
+    Graphs are separated by blank lines. A line with three tokens is a
+    graph-label line; any other non-blank line is an edge line with five
+    integer fields.
+
+    :param file_path: path of the graph file to parse
+    :param is_binary: 0: keep the original graph label, 1: force label 1
+                      (binary positive), 2: force label 2 (binary other)
+    :return: (node_id_data_list, node_type_data_list)
+    """
     node_id_data_list = []
     node_type_data_list = []
     node_id = list()
     node_type = list()
     graph_type = list()
-    none_file_path = ""
-    for file_name in os.listdir(path):
-        with open(path + "/" + file_name, 'r') as file:
-            for line in file:
-                if len(line.strip()) == 0:
-                    node_id_data_list.append([node_id, graph_type])
-                    node_type_data_list.append([node_type, graph_type])
-                    node_id = list()
-                    node_type = list()
-                    graph_type = list()
-                elif len(line.split(' ')) == 3:
+    with open(file_path) as file:
+        for line in file:
+            # A blank line marks the end of the current graph.
+            if len(line.strip()) == 0:
+                node_id_data_list.append([node_id, graph_type])
+                node_type_data_list.append([node_type, graph_type])
+                node_id = list()
+                node_type = list()
+                graph_type = list()
+            # A three-token line carries the graph label.
+            elif len(line.split(' ')) == 3:
+                if is_binary == 0:
                     graph_type.append([int(line.split(' ')[1])])
+                elif is_binary == 1:
+                    graph_type.append([1])
                 else:
-                    data = line.split(' ')
-                    node_id.append([int(data[0]), int(data[2]), int(data[3])])
-                    node_type.append([int(data[1]), int(data[2]), int(data[4])])
+                    graph_type.append([2])
+            # Otherwise the line is an edge: five integer fields split across
+            # the id view (fields 0, 2, 3) and the type view (fields 1, 2, 4).
+            else:
+                data = line.split(' ')
+                node_id.append([int(data[0]), int(data[2]), int(data[3])])
+                node_type.append([int(data[1]), int(data[2]), int(data[4])])
+    return node_id_data_list, node_type_data_list
+
+
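+# Thin wrapper: parses every file under `path` and concatenates the results.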
+def load_from_directory(path):
+    node_id_data_list = []
+    node_type_data_list = []
+    for file_name in os.listdir(path):
+        node_id, node_type = load_from_file(path + "/" + file_name)
+        node_id_data_list.extend(node_id)
+        node_type_data_list.extend(node_type)
     return node_id_data_list, node_type_data_list
@@ -100,6 +120,7 @@ class Dataset:
     """
     Load bAbI tasks for GGNN
     """
+
     def __init__(self, path, is_train):
         data_id = list()
         data_type = list()
@@ -126,7 +147,69 @@ class Dataset:

     def __getitem__(self, index):
         am = create_adjacency_matrix(self.node_by_id[index][0], self.n_node_by_id, self.n_edge_types)
-        annotation = create_embedding_matrix(self.node_by_id[index][0], self.node_by_type[index][0], self.n_node_by_id, self.n_node_by_type)
+        annotation = create_embedding_matrix(self.node_by_id[index][0], self.node_by_type[index][0], self.n_node_by_id,
+                                             self.n_node_by_type)
+        target = self.node_by_id[index][2] - 1
+        return am, annotation, target
+
+    def __len__(self):
+        return len(self.node_by_id)
+
+
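+# One-vs-rest loading: <class_type>.txt supplies the positive graphs and every
+# other file in the directory supplies candidate negatives.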
+def load_from_directory_binary(path, class_type):
+    node_id_data_list = []
+    node_type_data_list = []
+
+    # Positive class: every graph in <class_type>.txt, labelled 1.
+    node_id_binary_true, node_type_binary_true = load_from_file(path + "/" + class_type + ".txt", 1)
+    node_id_data_list.extend(node_id_binary_true)
+    node_type_data_list.extend(node_type_binary_true)
+    id_len = len(node_id_data_list)
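+    # id_len caps the negative sample below so the two classes stay balanced.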
+
+    # Negative class: graphs from every other file, labelled 2.
+    node_id_data_list_false = []
+    node_type_data_list_false = []
+    for file_name in os.listdir(path):
+        if file_name != class_type + ".txt":
+            node_id_binary_false, node_type_binary_false = load_from_file(path + "/" + file_name, 2)
+            node_id_data_list_false.extend(node_id_binary_false)
+            node_type_data_list_false.extend(node_type_binary_false)
+    # Shuffle the id and type lists in unison; shuffling each independently
+    # would break the index pairing between the two views of the same graph.
+    paired = list(zip(node_id_data_list_false, node_type_data_list_false))
+    random.shuffle(paired)
+    node_id_data_list_false = [pair[0] for pair in paired]
+    node_type_data_list_false = [pair[1] for pair in paired]
+    node_id_data_list.extend(node_id_data_list_false[:id_len])
+    node_type_data_list.extend(node_type_data_list_false[:id_len])
+    return node_id_data_list, node_type_data_list
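+
+# Example (hypothetical directory layout): balance "Archive" graphs against an
+# equal number of graphs drawn from the other classes:
+#   ids, types = load_from_directory_binary("data/traindatabinary/train", "Archive")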
+
+
+class BinaryDataset:
+    def __init__(self, path, class_type, is_train):
+        data_id = list()
+        data_type = list()
+
+        train_data_id, train_data_type = load_from_directory_binary(path + "/train", class_type)
+        test_data_id, test_data_type = load_from_directory_binary(path + "/test", class_type)
+
+        data_id.extend(train_data_id)
+        data_id.extend(test_data_id)
+        data_type.extend(train_data_type)
+        data_type.extend(test_data_type)
+
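+        # Derive sizes from train and test combined so both splits share the
+        # same node-id space and edge-type count.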
+        self.n_edge_types = find_max_edge_id(data_id)
+        max_node_id = find_max_node_id(data_id)
+        max_node_type = find_max_node_id(data_type)
+
+        self.n_node_by_id = max_node_id
+        self.n_node_by_type = max_node_type
+        if is_train:
+            self.node_by_id = convert_program_data(train_data_id, 1, self.n_node_by_id)
+            self.node_by_type = convert_program_data(train_data_type, 1, self.n_node_by_type)
+        else:
+            self.node_by_id = convert_program_data(test_data_id, 1, self.n_node_by_id)
+            self.node_by_type = convert_program_data(test_data_type, 1, self.n_node_by_type)
+
+    def __getitem__(self, index):
+        am = create_adjacency_matrix(self.node_by_id[index][0], self.n_node_by_id, self.n_edge_types)
+        annotation = create_embedding_matrix(self.node_by_id[index][0], self.node_by_type[index][0], self.n_node_by_id,
+                                             self.n_node_by_type)
         target = self.node_by_id[index][2] - 1
         return am, annotation, target
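+
+    def __len__(self):
+        # Mirrors Dataset.__len__ so len() and DataLoader sizing work here too.
+        return len(self.node_by_id)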
@@ -138,7 +221,8 @@ if __name__ == '__main__':
-    bi = Dataset(
-        "I:\Program\Python\sap\GnnForPrivacyScan\data\\traindata", True)
-    for d in bi:
+
+
+    # Raw string avoids the invalid "\P", "\s", "\G", "\d" escape sequences.
+    binary_dataset = BinaryDataset(r"I:\Program\Python\sap\GnnForPrivacyScan\data\traindatabinary", "Archive", True)
+    for d in binary_dataset:
         a = 5