dataset.py

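"""Loads graph data for the GGNN privacy-scan model.

The on-disk format described here is inferred from the parser in
load_from_directory, so treat it as an approximation:

* each file in the data directory contains one or more graphs;
* a blank line ends the current graph;
* a line with exactly three space-separated tokens is a target record,
  whose second token is the graph-level label;
* every other line is an edge record of at least five integers, read as
  src_id src_type edge_type tgt_id tgt_type (ids appear to be 1-based).
"""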
import os

import numpy as np


def load_from_directory(path):
    """Parse every file in `path` into per-graph (edges, targets) records,
    once keyed by node id and once by node type."""
    node_id_data_list = []
    node_type_data_list = []
    node_id = list()
    node_type = list()
    graph_type = list()
    for file_name in os.listdir(path):
        with open(os.path.join(path, file_name), 'r') as file:
            for line in file:
                # print(line)
                if len(line.strip()) == 0:
                    # A blank line closes the current graph.
                    node_id_data_list.append([node_id, graph_type])
                    node_type_data_list.append([node_type, graph_type])
                    node_id = list()
                    node_type = list()
                    graph_type = list()
                elif len(line.split(' ')) == 3:
                    # Three-token lines carry the graph-level label in their second field.
                    graph_type.append([int(line.split(' ')[1])])
                else:
                    # Edge record: src_id src_type edge_type tgt_id tgt_type
                    data = line.split(' ')
                    node_id.append([int(data[0]), int(data[2]), int(data[3])])
                    node_type.append([int(data[1]), int(data[2]), int(data[4])])
    return node_id_data_list, node_type_data_list


def find_max_edge_id(data_list):
    """Return the largest edge id found in the parsed edge records."""
    max_edge_id = 0
    for data in data_list:
        edges = data[0]
        for item in edges:
            if item[1] > max_edge_id:
                max_edge_id = item[1]
    return max_edge_id


def find_max_node_id(data_list):
    """Return the largest source/target node id found in the parsed edge records."""
    max_node_id = 0
    for data in data_list:
        edges = data[0]
        for item in edges:
            if item[0] > max_node_id:
                max_node_id = item[0]
            if item[2] > max_node_id:
                max_node_id = item[2]
    return max_node_id


def convert_program_data(data_list, n_annotation_dim, n_nodes):
    """Expand each (edges, targets) pair into [edge_list, annotation, target] samples."""
    # n_nodes = find_max_node_id(data_list)
    class_data_list = []
    for item in data_list:
        edge_list = item[0]
        target_list = item[1]
        for target in target_list:
            task_type = target[0]
            task_output = target[-1]
            # Flag every node that appears as an edge source (ids are 1-based).
            annotation = np.zeros([n_nodes, n_annotation_dim])
            for edge in edge_list:
                src_idx = edge[0]
                if src_idx < len(annotation):
                    annotation[src_idx - 1][0] = 1
            class_data_list.append([edge_list, annotation, task_output])
    return class_data_list


def create_adjacency_matrix(edges, n_nodes, n_edge_types):
    # Concatenated [incoming | outgoing] adjacency with one n_nodes-wide block
    # per edge type, as used for GGNN propagation.
    a = np.zeros([n_nodes, n_nodes * n_edge_types * 2])
    for edge in edges:
        src_idx = edge[0]
        e_type = edge[1]
        tgt_idx = edge[2]
        if tgt_idx < len(a):
            a[tgt_idx - 1][(e_type - 1) * n_nodes + src_idx - 1] = 1
        if src_idx < len(a):
            a[src_idx - 1][(e_type - 1 + n_edge_types) * n_nodes + tgt_idx - 1] = 1
    return a


def create_embedding_matrix(node_id_edges, node_type_edges, n_nodes, n_types):
    # One-hot node-type annotation: row = node id, column = node type (both 1-based).
    anno = np.zeros([n_nodes, n_types])
    for i in range(len(node_id_edges)):
        node_type = node_type_edges[i][0]
        # print(node_type)
        src_idx = node_id_edges[i][0]
        anno[src_idx - 1][node_type - 1] = 1.0
    return anno


class Dataset:
    """
    Load bAbI-style graph data for GGNN training.
    """

    def __init__(self, path, is_train):
        data_id, data_type = load_from_directory(path)
        self.n_edge_types = find_max_edge_id(data_id)
        max_node_id = find_max_node_id(data_id)
        max_node_type = find_max_node_id(data_type)
        self.n_node_by_id = max_node_id
        self.n_node_by_type = max_node_type
        self.node_by_id = convert_program_data(data_id, 1, self.n_node_by_id)
        self.node_by_type = convert_program_data(data_type, 1, self.n_node_by_type)

    def __getitem__(self, index):
        am = create_adjacency_matrix(self.node_by_id[index][0], self.n_node_by_id, self.n_edge_types)
        annotation = create_embedding_matrix(self.node_by_id[index][0], self.node_by_type[index][0],
                                             self.n_node_by_id, self.n_node_by_type)
        # Labels in the files appear to be 1-based; shift to 0-based class indices.
        target = self.node_by_id[index][2] - 1
        return am, annotation, target

    def __len__(self):
        return len(self.node_by_id)


if __name__ == '__main__':
    # data = load_graphs_from_file(
    #     "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindata/train/Directory.txt")
    # a = 5
    bi = Dataset(
        r"I:\Program\Python\sap\GnnForPrivacyScan\data\traindata\train", True)
    for data in bi:
        a = 5
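    # Hypothetical extra check, not part of the original script: unpack one
    # sample to confirm the shapes __getitem__ returns. Assumes the directory
    # above contains at least one graph. am is (n_nodes, n_nodes * n_edge_types * 2)
    # and annotation is (n_nodes, n_types).
    am, annotation, target = bi[0]
    print(am.shape, annotation.shape, target)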