ast_visualizer.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. #!/usr/bin/python3
  2. import ast
  3. import graphviz as gv
  4. import subprocess
  5. import numbers
  6. import re
  7. from uuid import uuid4 as uuid
  8. import optparse
  9. import sys
  10. def main(args):
  11. parser = optparse.OptionParser(usage="astvisualizer.py [options] [string]")
  12. parser.add_option("-f", "--file", action="store",
  13. help="Read a code snippet from the specified file")
  14. parser.add_option("-l", "--label", action="store",
  15. help="The label for the visualization")
  16. options, args = parser.parse_args(args)
  17. if options.file:
  18. with open(options.file) as instream:
  19. code = instream.read()
  20. label = options.file
  21. elif len(args) == 2:
  22. code = args[1]
  23. label = "<code read from command line parameter>"
  24. else:
  25. print("Expecting Python code on stdin...")
  26. code = sys.stdin.read()
  27. label = "<code read from stdin>"
  28. if options.label:
  29. label = options.label
  30. code_ast = ast.parse(code)
  31. transformed_ast = transform_ast(code_ast)
  32. renderer = GraphRenderer()
  33. renderer.render(transformed_ast, label=label)
  34. def transform_ast(code_ast):
  35. if isinstance(code_ast, ast.AST):
  36. node = {to_camelcase(k): transform_ast(getattr(code_ast, k)) for k in code_ast._fields}
  37. node['node_type'] = to_camelcase(code_ast.__class__.__name__)
  38. return node
  39. elif isinstance(code_ast, list):
  40. return [transform_ast(el) for el in code_ast]
  41. else:
  42. return code_ast
  43. def to_camelcase(string):
  44. return re.sub('([a-z0-9])([A-Z])', r'\1_\2', string).lower()
  45. class GraphRenderer:
  46. """
  47. this class is capable of rendering data structures consisting of
  48. dicts and lists as a graph using graphviz
  49. """
  50. graphattrs = {
  51. 'labelloc': 't',
  52. 'fontcolor': 'white',
  53. 'bgcolor': '#333333',
  54. 'margin': '0',
  55. }
  56. nodeattrs = {
  57. 'color': 'white',
  58. 'fontcolor': 'white',
  59. 'style': 'filled',
  60. 'fillcolor': '#006699',
  61. }
  62. edgeattrs = {
  63. 'color': 'white',
  64. 'fontcolor': 'white',
  65. }
  66. _graph = None
  67. _rendered_nodes = None
  68. @staticmethod
  69. def _escape_dot_label(str):
  70. return str.replace("\\", "\\\\").replace("|", "\\|").replace("<", "\\<").replace(">", "\\>")
  71. def _render_node(self, node):
  72. if isinstance(node, (str, numbers.Number)) or node is None:
  73. node_id = uuid()
  74. else:
  75. node_id = id(node)
  76. node_id = str(node_id)
  77. if node_id not in self._rendered_nodes:
  78. self._rendered_nodes.add(node_id)
  79. if isinstance(node, dict):
  80. self._render_dict(node, node_id)
  81. elif isinstance(node, list):
  82. self._render_list(node, node_id)
  83. else:
  84. self._graph.node(node_id, label=self._escape_dot_label(str(node)))
  85. return node_id
  86. def _render_dict(self, node, node_id):
  87. self._graph.node(node_id, label=node.get("node_type", "[dict]"))
  88. for key, value in node.items():
  89. if key == "node_type":
  90. continue
  91. child_node_id = self._render_node(value)
  92. self._graph.edge(node_id, child_node_id, label=self._escape_dot_label(key))
  93. def _render_list(self, node, node_id):
  94. self._graph.node(node_id, label="[list]")
  95. for idx, value in enumerate(node):
  96. child_node_id = self._render_node(value)
  97. self._graph.edge(node_id, child_node_id, label=self._escape_dot_label(str(idx)))
  98. def render(self, data, *, label=None):
  99. # create the graph
  100. graphattrs = self.graphattrs.copy()
  101. if label is not None:
  102. graphattrs['label'] = self._escape_dot_label(label)
  103. graph = gv.Digraph(graph_attr = graphattrs, node_attr = self.nodeattrs, edge_attr = self.edgeattrs)
  104. # recursively draw all the nodes and edges
  105. self._graph = graph
  106. self._rendered_nodes = set()
  107. self._render_node(data)
  108. self._graph = None
  109. self._rendered_nodes = None
  110. # display the graph
  111. graph.format = "pdf"
  112. graph.view()
  113. subprocess.Popen(['xdg-open', "test.pdf"])
  114. if __name__ == '__main__':
  115. main(sys.argv)