# AstGraph.py (3.7 KB)
  1. import ast
  2. import os
  3. import random
  4. from graph.common.nodetype import *
  5. from graph.common.graphtype import *
  6. from utils.fileio import *
  7. CUR_PATH = os.path.dirname(__file__)
  8. i = 0
  9. def ast_visit(node, graph=None, id_list=None):
  10. if graph is None:
  11. graph = list()
  12. if id_list is None:
  13. id_list = [node]
  14. iter_field = ast.iter_fields(node)
  15. for field, value in iter_field:
  16. # print(type(node).__name__, field, value)
  17. if isinstance(value, list):
  18. for item in value:
  19. if isinstance(item, ast.AST):
  20. id_list.append(item)
  21. graph.append(
  22. (id_list.index(node) + 1, NodeType.node_map[type(node).__name__], 1,
  23. id_list.index(item) + 1, NodeType.node_map[type(item).__name__])
  24. )
  25. ast_visit(item, graph, id_list)
  26. elif isinstance(value, ast.AST):
  27. id_list.append(value)
  28. graph.append(
  29. (id_list.index(node) + 1, NodeType.node_map[type(node).__name__], 1,
  30. id_list.index(value) + 1, NodeType.node_map[type(value).__name__])
  31. )
  32. ast_visit(value, graph, id_list)
  33. def gen_graph_from_file(file_path):
  34. with open(file_path, 'r') as file:
  35. content = file.read()
  36. func_graph = list()
  37. try:
  38. root = ast.parse(content)
  39. for node in root.body:
  40. if isinstance(node, ast.FunctionDef):
  41. ast_visit(node, func_graph)
  42. except IndentationError:
  43. print("IndentationError: ", file_path)
  44. except SyntaxError:
  45. print("SyntaxError: ", file_path)
  46. except:
  47. print("other: ", file_path)
  48. return func_graph
  49. def gen_graph_to_txt(input_path, train_path, test_path):
  50. kinds = os.listdir(input_path)
  51. kinds.remove('.DS_Store')
  52. for kind in kinds:
  53. graph_type = GraphType.type[kind]
  54. path_out_train = train_path + "/" + str(kind) + ".txt"
  55. path_out_test = test_path + "/" + str(kind) + ".txt"
  56. file_path_list = walk_files(input_path + "/" + kind)
  57. random.shuffle(file_path_list)
  58. lens = int(len(file_path_list) / 4 * 3)
  59. with open(path_out_train, 'w') as file:
  60. for file_path in file_path_list[:lens]:
  61. func_graph = gen_graph_from_file(file_path)
  62. for edge in func_graph:
  63. file.write(
  64. "" + str(edge[0]) + " " + str(edge[1]) + " " + str(edge[2]) + " " + str(edge[3]) + " " + str(
  65. edge[
  66. 4]) + "\n")
  67. file.write("? " + str(graph_type) + " " + file_path.replace(input_path + "/", "") + "\n\n")
  68. with open(path_out_test, 'w') as file:
  69. for file_path in file_path_list[lens: len(file_path_list)]:
  70. func_graph = gen_graph_from_file(file_path)
  71. for edge in func_graph:
  72. file.write(
  73. "" + str(edge[0]) + " " + str(edge[1]) + " " + str(edge[2]) + " " + str(edge[3]) + " " + str(
  74. edge[
  75. 4]) + "\n")
  76. file.write("? " + str(graph_type) + " " + file_path.replace(input_path + "/", "") + "\n\n")
  77. if __name__ == '__main__':
  78. gen_graph_to_txt("/Users/liufan/program/PYTHON/sap2nd/Data/target",
  79. "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindatabinary/train",
  80. "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindatabinary/test")
  81. # graph = gen_graph_from_file("/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/purposeSplit.bk/Directory/advance_touch_1.py", )
  82. # a = 5