# dataset.py
  1. import os
  2. import random
  3. import numpy as np
  4. def load_from_file(file_path, is_binary=0):
  5. """
  6. :param file_path:
  7. :param is_binary: 0:not binary, 1:binary type, 2:binary other
  8. :return:
  9. """
  10. node_id_data_list = []
  11. node_type_data_list = []
  12. node_id = list()
  13. node_type = list()
  14. graph_type = list()
  15. with open(file_path) as file:
  16. for line in file:
  17. if len(line.strip()) == 0:
  18. node_id_data_list.append([node_id, graph_type])
  19. node_type_data_list.append([node_type, graph_type])
  20. node_id = list()
  21. node_type = list()
  22. graph_type = list()
  23. elif len(line.split(' ')) == 3:
  24. if is_binary == 0:
  25. graph_type.append([int(line.split(' ')[1])])
  26. elif is_binary == 1:
  27. graph_type.append([1])
  28. else:
  29. graph_type.append([2])
  30. else:
  31. data = line.split(' ')
  32. node_id.append([int(data[0]), int(data[2]), int(data[3])])
  33. node_type.append([int(data[1]), int(data[2]), int(data[4])])
  34. return node_id_data_list, node_type_data_list
  35. def load_from_directory(path):
  36. node_id_data_list = []
  37. node_type_data_list = []
  38. for file_name in os.listdir(path):
  39. node_id, node_type = load_from_file(path + "/" + file_name)
  40. node_id_data_list.extend(node_id)
  41. node_type_data_list.extend(node_type)
  42. return node_id_data_list, node_type_data_list
  43. def find_max_edge_id(data_list):
  44. max_edge_id = 0
  45. for data in data_list:
  46. edges = data[0]
  47. for item in edges:
  48. if item[1] > max_edge_id:
  49. max_edge_id = item[1]
  50. return max_edge_id
  51. def find_max_node_id(data_list):
  52. max_node_id = 0
  53. for data in data_list:
  54. edges = data[0]
  55. for item in edges:
  56. if item[0] > max_node_id:
  57. max_node_id = item[0]
  58. if item[2] > max_node_id:
  59. max_node_id = item[2]
  60. return max_node_id
  61. def convert_program_data(data_list, n_annotation_dim, n_nodes):
  62. # n_nodes = find_max_node_id(data_list)
  63. class_data_list = []
  64. for item in data_list:
  65. edge_list = item[0]
  66. target_list = item[1]
  67. for target in target_list:
  68. task_type = target[0]
  69. task_output = target[-1]
  70. annotation = np.zeros([n_nodes, n_annotation_dim])
  71. for edge in edge_list:
  72. src_idx = edge[0]
  73. if src_idx < len(annotation):
  74. annotation[src_idx - 1][0] = 1
  75. class_data_list.append([edge_list, annotation, task_output])
  76. return class_data_list
  77. def create_adjacency_matrix(edges, n_nodes, n_edge_types):
  78. a = np.zeros([n_nodes, n_nodes * n_edge_types * 2])
  79. for edge in edges:
  80. src_idx = edge[0]
  81. e_type = edge[1]
  82. tgt_idx = edge[2]
  83. if tgt_idx < len(a):
  84. a[tgt_idx - 1][(e_type - 1) * n_nodes + src_idx - 1] = 1
  85. if src_idx < len(a):
  86. a[src_idx - 1][(e_type - 1 + n_edge_types) * n_nodes + tgt_idx - 1] = 1
  87. return a
  88. def create_embedding_matrix(node_id_edges, node_type_edges, n_nodes, n_types):
  89. anno = np.zeros([n_nodes, n_types])
  90. for i in range(len(node_id_edges)):
  91. node_type = node_type_edges[i][0]
  92. # print(node_type)
  93. src_idx = node_id_edges[i][0]
  94. anno[src_idx - 1][node_type - 1] = 1.0
  95. return anno
  96. class Dataset:
  97. """
  98. Load bAbI tasks for GGNN
  99. """
  100. def __init__(self, path, is_train):
  101. data_id = list()
  102. data_type = list()
  103. train_data_id, train_data_type = load_from_directory(path + "/train")
  104. test_data_id, test_data_type = load_from_directory(path + "/test")
  105. data_id.extend(train_data_id)
  106. data_id.extend(test_data_id)
  107. data_type.extend(train_data_type)
  108. data_type.extend(test_data_type)
  109. self.n_edge_types = find_max_edge_id(data_id)
  110. max_node_id = find_max_node_id(data_id)
  111. max_node_type = find_max_node_id(data_type)
  112. self.n_node_by_id = max_node_id
  113. self.n_node_by_type = max_node_type
  114. if is_train:
  115. self.node_by_id = convert_program_data(train_data_id, 1, self.n_node_by_id)
  116. self.node_by_type = convert_program_data(train_data_type, 1, self.n_node_by_type)
  117. else:
  118. self.node_by_id = convert_program_data(test_data_id, 1, self.n_node_by_id)
  119. self.node_by_type = convert_program_data(test_data_type, 1, self.n_node_by_type)
  120. def __getitem__(self, index):
  121. am = create_adjacency_matrix(self.node_by_id[index][0], self.n_node_by_id, self.n_edge_types)
  122. annotation = create_embedding_matrix(self.node_by_id[index][0], self.node_by_type[index][0], self.n_node_by_id,
  123. self.n_node_by_type)
  124. target = self.node_by_id[index][2] - 1
  125. return am, annotation, target
  126. def __len__(self):
  127. return len(self.node_by_id)
  128. def load_from_directory_binary(path, class_type):
  129. node_id_data_list = []
  130. node_type_data_list = []
  131. # binary true
  132. node_id_binary_true, node_type_binary_true = load_from_file(path + "/" + class_type + ".txt", 1)
  133. node_id_data_list.extend(node_id_binary_true)
  134. node_type_data_list.extend(node_type_binary_true)
  135. id_len = len(node_id_data_list)
  136. # binary false
  137. node_id_data_list_false = []
  138. node_type_data_list_false = []
  139. for file_name in os.listdir(path):
  140. if file_name != class_type + ".txt":
  141. node_id_binary_false, node_type_binary_false = load_from_file(path + "/" + file_name)
  142. node_id_data_list_false.extend(node_id_binary_false)
  143. node_type_data_list_false.extend(node_type_binary_false)
  144. random.shuffle(node_id_data_list_false)
  145. random.shuffle(node_type_data_list_false)
  146. node_id_data_list.extend(node_id_data_list_false[:id_len])
  147. node_type_data_list.extend(node_type_data_list_false[:id_len])
  148. return node_id_data_list, node_type_data_list
  149. class BinaryDataset:
  150. def __init__(self, path, class_type, is_train):
  151. data_id = list()
  152. data_type = list()
  153. train_data_id, train_data_type = load_from_directory_binary(path + "/train", class_type)
  154. test_data_id, test_data_type = load_from_directory_binary(path + "/test", class_type)
  155. data_id.extend(train_data_id)
  156. data_id.extend(test_data_id)
  157. data_type.extend(train_data_type)
  158. data_type.extend(test_data_type)
  159. self.n_edge_types = find_max_edge_id(data_id)
  160. max_node_id = find_max_node_id(data_id)
  161. max_node_type = find_max_node_id(data_type)
  162. self.n_node_by_id = max_node_id
  163. self.n_node_by_type = max_node_type
  164. if is_train:
  165. self.node_by_id = convert_program_data(train_data_id, 1, self.n_node_by_id)
  166. self.node_by_type = convert_program_data(train_data_type, 1, self.n_node_by_type)
  167. else:
  168. self.node_by_id = convert_program_data(test_data_id, 1, self.n_node_by_id)
  169. self.node_by_type = convert_program_data(test_data_type, 1, self.n_node_by_type)
  170. def __getitem__(self, index):
  171. am = create_adjacency_matrix(self.node_by_id[index][0], self.n_node_by_id, self.n_edge_types)
  172. annotation = create_embedding_matrix(self.node_by_id[index][0], self.node_by_type[index][0], self.n_node_by_id,
  173. self.n_node_by_type)
  174. target = self.node_by_id[index][2] - 1
  175. return am, annotation, target
  176. def __len__(self):
  177. return len(self.node_by_id)
  178. if __name__ == '__main__':
  179. # data = load_graphs_from_file(
  180. # "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindata/train/Directory.txt")
  181. # a = 5
  182. # bi = Dataset(
  183. # "I:\Program\Python\sap\GnnForPrivacyScan\data\\traindata", True)
  184. binary_dataset = BinaryDataset("I:\Program\Python\sap\GnnForPrivacyScan\data\\traindatabinary", "Archive", True)
  185. for d in binary_dataset:
  186. a = 5