# dataset.py
  1. import os
  2. import numpy as np
  3. def load_from_directory(path):
  4. node_id_data_list = []
  5. node_type_data_list = []
  6. node_id = list()
  7. node_type = list()
  8. graph_type = list()
  9. none_file_path = ""
  10. for file_name in os.listdir(path):
  11. with open(path + "/" + file_name, 'r') as file:
  12. for line in file:
  13. if len(line.strip()) == 0:
  14. node_id_data_list.append([node_id, graph_type])
  15. node_type_data_list.append([node_type, graph_type])
  16. node_id = list()
  17. node_type = list()
  18. graph_type = list()
  19. elif len(line.split(' ')) == 3:
  20. graph_type.append([int(line.split(' ')[1])])
  21. else:
  22. data = line.split(' ')
  23. node_id.append([int(data[0]), int(data[2]), int(data[3])])
  24. node_type.append([int(data[1]), int(data[2]), int(data[4])])
  25. return node_id_data_list, node_type_data_list
  26. def find_max_edge_id(data_list):
  27. max_edge_id = 0
  28. for data in data_list:
  29. edges = data[0]
  30. for item in edges:
  31. if item[1] > max_edge_id:
  32. max_edge_id = item[1]
  33. return max_edge_id
  34. def find_max_node_id(data_list):
  35. max_node_id = 0
  36. for data in data_list:
  37. edges = data[0]
  38. for item in edges:
  39. if item[0] > max_node_id:
  40. max_node_id = item[0]
  41. if item[2] > max_node_id:
  42. max_node_id = item[2]
  43. return max_node_id
  44. def convert_program_data(data_list, n_annotation_dim, n_nodes):
  45. # n_nodes = find_max_node_id(data_list)
  46. class_data_list = []
  47. for item in data_list:
  48. edge_list = item[0]
  49. target_list = item[1]
  50. for target in target_list:
  51. task_type = target[0]
  52. task_output = target[-1]
  53. annotation = np.zeros([n_nodes, n_annotation_dim])
  54. for edge in edge_list:
  55. src_idx = edge[0]
  56. if src_idx < len(annotation):
  57. annotation[src_idx - 1][0] = 1
  58. class_data_list.append([edge_list, annotation, task_output])
  59. return class_data_list
  60. def create_adjacency_matrix(edges, n_nodes, n_edge_types):
  61. a = np.zeros([n_nodes, n_nodes * n_edge_types * 2])
  62. for edge in edges:
  63. src_idx = edge[0]
  64. e_type = edge[1]
  65. tgt_idx = edge[2]
  66. if tgt_idx < len(a):
  67. a[tgt_idx - 1][(e_type - 1) * n_nodes + src_idx - 1] = 1
  68. if src_idx < len(a):
  69. a[src_idx - 1][(e_type - 1 + n_edge_types) * n_nodes + tgt_idx - 1] = 1
  70. return a
  71. def create_embedding_matrix(node_id_edges, node_type_edges, n_nodes, n_types):
  72. anno = np.zeros([n_nodes, n_types])
  73. for i in range(len(node_id_edges)):
  74. node_type = node_type_edges[i][0]
  75. # print(node_type)
  76. src_idx = node_id_edges[i][0]
  77. anno[src_idx - 1][node_type - 1] = 1.0
  78. return anno
  79. class Dataset:
  80. """
  81. Load bAbI tasks for GGNN
  82. """
  83. def __init__(self, path, is_train):
  84. data_id = list()
  85. data_type = list()
  86. train_data_id, train_data_type = load_from_directory(path + "/train")
  87. test_data_id, test_data_type = load_from_directory(path + "/test")
  88. data_id.extend(train_data_id)
  89. data_id.extend(test_data_id)
  90. data_type.extend(train_data_type)
  91. data_type.extend(test_data_type)
  92. self.n_edge_types = find_max_edge_id(data_id)
  93. max_node_id = find_max_node_id(data_id)
  94. max_node_type = find_max_node_id(data_type)
  95. self.n_node_by_id = max_node_id
  96. self.n_node_by_type = max_node_type
  97. if is_train:
  98. self.node_by_id = convert_program_data(train_data_id, 1, self.n_node_by_id)
  99. self.node_by_type = convert_program_data(train_data_type, 1, self.n_node_by_type)
  100. else:
  101. self.node_by_id = convert_program_data(test_data_id, 1, self.n_node_by_id)
  102. self.node_by_type = convert_program_data(test_data_type, 1, self.n_node_by_type)
  103. def __getitem__(self, index):
  104. am = create_adjacency_matrix(self.node_by_id[index][0], self.n_node_by_id, self.n_edge_types)
  105. annotation = create_embedding_matrix(self.node_by_id[index][0], self.node_by_type[index][0], self.n_node_by_id, self.n_node_by_type)
  106. target = self.node_by_id[index][2] - 1
  107. return am, annotation, target
  108. def __len__(self):
  109. return len(self.node_by_id)
# Smoke test: load the training split and iterate every sample once,
# exercising __getitem__ for the whole dataset.
if __name__ == '__main__':
    # data = load_graphs_from_file(
    #     "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindata/train/Directory.txt")
    # a = 5
    # NOTE(review): hard-coded Windows path; the single backslashes ("\P",
    # "\s", ...) survive only because they are not recognised escape
    # sequences — confirm the path on the target machine.
    bi = Dataset(
        "I:\Program\Python\sap\GnnForPrivacyScan\data\\traindata", True)
    for d in bi:
        a = 5  # touch each sample; value is discarded