# visualize_tree.py (16 KB)
  1. """Script that visualizes dependencies of Nix packages"""
  2. import argparse
  3. import configparser
  4. import itertools
  5. import os
  6. import random
  7. import shlex
  8. import subprocess
  9. import sys
  10. import tempfile
  11. import logging
  12. import networkx as nx
  13. import pygraphviz as pgv
  14. import matplotlib
  15. matplotlib.use('Agg')
  16. import matplotlib.pyplot as plt
  17. import warnings
  18. warnings.filterwarnings("ignore")
  19. from . import util
  20. from .graph_objects import Node, Edge
  21. logger = logging.getLogger(__name__)
  22. #: Default values for things we expect in the config file
  23. CONFIG_OPTIONS = {
  24. "aspect_ratio": (2, float),
  25. "dpi": (300, int),
  26. "font_scale": (1.0, float),
  27. "color_scatter": (1.0, float),
  28. "edge_color": ("#888888", str),
  29. "font_color": ("#888888", str),
  30. "color_map": ("rainbow", str),
  31. "img_y_height_inches": (24, float),
  32. "y_sublevels": (5, int),
  33. "y_sublevel_spacing": (0.2, float),
  34. "num_iterations": (100, int),
  35. "edge_alpha": (0.3, float),
  36. "edge_width_scale": (1.0, float),
  37. "max_displacement": (2.5, float),
  38. "top_level_spacing": (100, float),
  39. "repulsive_force_normalization": (2.0, float),
  40. "attractive_force_normalization": (1.0, float),
  41. "add_size_per_out_link": (200, int),
  42. "max_node_size_over_min_node_size": (5.0, float),
  43. "min_node_size": (100.0, float),
  44. "tmax": (30.0, float),
  45. "show_labels": (1, int)
  46. }
  47. class Graph(object):
  48. """Class representing a dependency tree"""
  49. def __init__(self, packages, config, output_file, do_write=True):
  50. """Initialize a graph from the result of a nix-store command"""
  51. self.config = self._parse_config(config)
  52. self.nodes = []
  53. self.edges = []
  54. self.root_package_names = [util.remove_nix_hash(os.path.basename(x)) for
  55. x in packages]
  56. for package in packages:
  57. # Run nix-store -q --graph <package>. This generates a graphviz
  58. # file with package dependencies
  59. cmd = ("nix-store -q --graph {}".format(package))
  60. res = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE,
  61. stderr=subprocess.PIPE)
  62. stdout, stderr = res.communicate()
  63. if res.returncode != 0:
  64. raise util.TreeCLIError("nix-store call failed, message "
  65. "{}".format(stderr))
  66. package_nodes, package_edges = self._get_edges_and_nodes(stdout)
  67. self.nodes.extend(package_nodes)
  68. self.edges.extend(package_edges)
  69. self.nodes = list(set(self.nodes))
  70. self._add_edges_to_nodes()
  71. # The package itself is level 0, its direct dependencies are
  72. # level 1, their direct dependencies are level 2, etc.
  73. for n in self.nodes:
  74. n.add_level()
  75. self.depth = max([x.level for x in self.nodes]) + 1
  76. logger.info("Graph has {} nodes, {} edges and a depth of {}".format(
  77. len(self.nodes), len(self.edges), self.depth))
  78. # Transform the Nodes and Edges into a networkx graph
  79. self.G = nx.DiGraph()
  80. for node in self.nodes:
  81. self.G.add_node(node)
  82. for parent in node.parents:
  83. self.G.add_edge(node, parent)
  84. self._add_pos_to_nodes()
  85. if do_write is True:
  86. self.write_frame_png(filename=output_file)
  87. def _parse_config(self, config, verbose=True):
  88. """Load visualization parameters from config file or take defaults
  89. if they are not in there
  90. """
  91. configfile = config[0]
  92. configsection = config[1]
  93. return_configs = {}
  94. if configfile is not None:
  95. configs = configparser.ConfigParser()
  96. configs.read(configfile)
  97. if len(configs.sections()) > 1:
  98. if configsection is None:
  99. raise util.TreeCLIError("Config file {} contains more than "
  100. "one section, so -s must be set".format(
  101. configfile))
  102. elif configsection not in configs.sections():
  103. raise util.TreeCLIError("Config file {} does not contain a "
  104. "section named {}".format(
  105. configfile, configsection))
  106. else:
  107. # There is only one section in the file, just read it
  108. configsection = configs.sections()[0]
  109. else:
  110. logger.info("--configfile not set, using all defaults")
  111. return {k: v[0] for k, v in CONFIG_OPTIONS.items()}
  112. logger.info("Reading section [{}] of file {}".format(configsection,
  113. configfile))
  114. # Loop through config options. If there is a corresponding key in the
  115. # config file, overwrite, else take the value from the defaults
  116. for param, (p_default, p_dtype) in CONFIG_OPTIONS.items():
  117. try:
  118. return_configs[param] = p_dtype(
  119. configs.get(configsection, param))
  120. logger.debug("Setting {} to {}".format(param,
  121. return_configs[param]))
  122. except (ConfigParser.NoOptionError, ValueError):
  123. return_configs[param] = p_dtype(p_default)
  124. logger.info( "Adding default of {} for {}".format(
  125. p_dtype(p_default), param))
  126. return return_configs
  127. def write_frame_png(self, filename="nix-tree.png"):
  128. """Dump the graph to a png file"""
  129. try:
  130. cmap = getattr(matplotlib.cm, self.config["color_map"])
  131. except AttributeError:
  132. raise util.TreeCLIError("Colormap {} does not exist".format(
  133. self.config["color_map"]))
  134. pos = {n: (n.x, n.y) for n in self.nodes}
  135. col_scale = 255.0/(self.depth+1.0)
  136. col = [(x.level+random.random()*self.config["color_scatter"])*col_scale
  137. for x in self.G.nodes()]
  138. col = [min([x,255]) for x in col]
  139. img_y_height=self.config["img_y_height_inches"]
  140. size_min = self.config["min_node_size"]
  141. size_max = self.config["max_node_size_over_min_node_size"] * size_min
  142. plt.figure(1, figsize=(img_y_height*self.config["aspect_ratio"],
  143. img_y_height))
  144. node_size = [min(size_min + (x.out_degree-1)*
  145. self.config["add_size_per_out_link"],
  146. size_max) if x.level > 0 else size_max for
  147. x in self.G.nodes()]
  148. # Draw edges
  149. nx.draw(self.G, pos, node_size=node_size, arrows=False,
  150. with_labels=self.config["show_labels"],
  151. edge_color=self.config["edge_color"],
  152. font_size=12*self.config["font_scale"],
  153. node_color=col, vmin=0, vmax=256,
  154. width=self.config["edge_width_scale"],
  155. alpha=self.config["edge_alpha"], nodelist=[])
  156. # Draw nodes
  157. nx.draw(self.G, pos, node_size=node_size, arrows=False,
  158. with_labels=self.config["show_labels"],
  159. font_size=12*self.config["font_scale"],
  160. node_color=col, vmin=0, vmax=255, edgelist=[],
  161. font_weight="light", cmap=cmap,
  162. font_color=self.config["font_color"])
  163. logger.info("Writing png file: {}".format(filename))
  164. plt.savefig(filename, dpi=self.config["dpi"])
  165. plt.close()
  166. def _add_pos_to_nodes(self):
  167. """Populates every node with an x an y position using the following
  168. iterative algorithm:
  169. * start at t=0
  170. * Apply an x force to each node that is proportional to the offset
  171. between its x position and the average position of its parents
  172. * Apply an x force to each node that pushes it away from its siblings
  173. with a force proportional to 1/d, where d is the distance between
  174. the node and its neighbor
  175. * advance time forward by dt=tmax/num_iterations, displace particles
  176. by F*dt
  177. * repeat until the number of iterations has been exhausted
  178. """
  179. logger.info("Adding positions to nodes")
  180. #: The distance between levels in arbitrary units. Used to set a
  181. #: scale on the diagram
  182. level_height = 10
  183. #: Maximum displacement of a point on a single iteration
  184. max_displacement = level_height * self.config["max_displacement"]
  185. #: The timestep to take on each iteration
  186. dt = self.config["tmax"]/self.config["num_iterations"]
  187. number_top_level = len([x for x in self.nodes if x.level == 0])
  188. count_top_level = 0
  189. # Initialize x with a random position unless you're the top level
  190. # package, then space nodes evenly
  191. for n in self.nodes:
  192. if n.level == 0:
  193. n.x = float(count_top_level)*self.config["top_level_spacing"]
  194. count_top_level += 1
  195. n.y = self.depth * level_height
  196. else:
  197. n.x = ((number_top_level + 1) *
  198. self.config["top_level_spacing"] * random.random())
  199. for iternum in range(self.config["num_iterations"]):
  200. if iternum in range(0,self.config["num_iterations"],
  201. int(self.config["num_iterations"]/10)):
  202. logger.debug("Completed iteration {} of {}".format(iternum,
  203. self.config["num_iterations"]))
  204. total_abs_displacement = 0.0
  205. for level in range(1, self.depth):
  206. # Get the y-offset by cycling with other nodes in the
  207. # same level
  208. xpos = [(x.name, x.x) for x in self.level(level)]
  209. xpos = sorted(xpos, key=lambda x:x[1])
  210. xpos = zip(xpos,
  211. itertools.cycle(range(self.config["y_sublevels"])))
  212. pos_sorter = {x[0][0]: x[1] for x in xpos}
  213. for n in self.level(level):
  214. n.y = ((self.depth - n.level) * level_height +
  215. pos_sorter[n.name] *
  216. self.config["y_sublevel_spacing"]*level_height)
  217. for lev_node in self.level(level):
  218. # We pull nodes toward their parents
  219. dis = [parent.x - lev_node.x for
  220. parent in lev_node.parents]
  221. # And push nodes away from their siblings with force 1/r
  222. sibs = self.level(level)
  223. sdis = [1.0/(sib.x - lev_node.x) for
  224. sib in sibs if abs(sib.x-lev_node.x) > 1e-3]
  225. total_sdis = (
  226. sum(sdis) *
  227. self.config["repulsive_force_normalization"])
  228. total_displacement = (
  229. self.config["attractive_force_normalization"] *
  230. float(sum(dis)) / len(dis))
  231. # Limit each of the displacements to the max displacement
  232. dx_parent = util.clamp(total_displacement, max_displacement)
  233. lev_node.dx_parent = dx_parent
  234. dx_sibling = util.clamp(total_sdis, max_displacement)
  235. lev_node.dx_sibling = -dx_sibling
  236. for lev_node in self.level(level):
  237. lev_node.x += lev_node.dx_parent * dt
  238. lev_node.x += lev_node.dx_sibling * dt
  239. total_abs_displacement += (abs(lev_node.dx_parent * dt) +
  240. abs(lev_node.dx_sibling * dt))
  241. def level(self, level):
  242. """Return a list of all nodes on a given level
  243. """
  244. return [x for x in self.nodes if x.level == level]
  245. def levels(self, min_level=0):
  246. """An iterator over levels, yields all the nodes in each level"""
  247. for i in range(min_level,self.depth):
  248. yield self.level(i)
  249. def nodes_by_prefix(self, name):
  250. """Return a list of all nodes whose names begin with a given prefix
  251. """
  252. return [x for x in self.nodes if x.name.startswith(name)]
  253. def _get_edges_and_nodes(self, raw_lines):
  254. """Transform a raw GraphViz file into Node and Edge objects. Note
  255. that at this point the nodes and edges are not linked into a graph
  256. they are simply two lists of items."""
  257. tempf = tempfile.NamedTemporaryFile(delete=False)
  258. tempf.write(raw_lines)
  259. tempf.close()
  260. G = pgv.AGraph(tempf.name)
  261. all_edges = []
  262. all_nodes = []
  263. for node in G.nodes():
  264. if (util.remove_nix_hash(node.name) not
  265. in [n.name for n in all_nodes]):
  266. all_nodes.append(Node(node.name))
  267. for edge in G.edges():
  268. all_edges.append(Edge(edge[0], edge[1]))
  269. return all_nodes, all_edges
  270. def _add_edges_to_nodes(self):
  271. """Given the lists of Edges and Nodes, add parents and children to
  272. nodes by following each edge
  273. """
  274. for edge in self.edges:
  275. nfrom = [n for n in self.nodes if n.name == edge.nfrom]
  276. nto = [n for n in self.nodes if n.name == edge.nto]
  277. nfrom = nfrom[0]
  278. nto = nto[0]
  279. if nfrom.name == nto.name:
  280. # Disallow self-references
  281. continue
  282. if nto not in nfrom.parents:
  283. nfrom.add_parent(nfrom, nto)
  284. if nfrom not in nto.children:
  285. nto.add_child(nfrom, nto)
  286. def __repr__(self):
  287. """Basic print of Graph, show the package name and the number of
  288. dependencies on each level
  289. """
  290. head = self.level(0)
  291. ret_str = "Graph of package: {}".format(head[0].name)
  292. for ilevel, level in enumerate(self.levels(min_level=1)):
  293. ret_str += "\n\tOn level {} there are {} packages".format(
  294. ilevel+1, len(level))
  295. return ret_str
  296. def init_logger(debug=False):
  297. """Sets up logging for this cli"""
  298. log_level = logging.DEBUG if debug else logging.INFO
  299. logging.basicConfig(format="%(levelname)s %(message)s\033[1;0m",
  300. stream=sys.stderr, level=log_level)
  301. logging.addLevelName(logging.CRITICAL,
  302. "\033[1;37m[\033[1;31mCRIT\033[1;37m]\033[0;31m")
  303. logging.addLevelName(logging.ERROR,
  304. "\033[1;37m[\033[1;33mERR \033[1;37m]\033[0;33m")
  305. logging.addLevelName(logging.WARNING,
  306. "\033[1;37m[\033[1;33mWARN\033[1;37m]\033[0;33m")
  307. logging.addLevelName(logging.INFO,
  308. "\033[1;37m[\033[1;32mINFO\033[1;37m]\033[0;37m")
  309. logging.addLevelName(logging.DEBUG,
  310. "\033[1;37m[\033[1;34mDBUG\033[1;37m]\033[0;34m")
  311. def main():
  312. """Parse command line arguments, instantiate graph and dump image"""
  313. parser = argparse.ArgumentParser()
  314. parser.add_argument("packages",
  315. help="Full path to a package in the Nix store. "
  316. "This package will be diagrammed", nargs='+')
  317. parser.add_argument("--configfile", "-c", help="ini file with layout and "
  318. "style configuration", required=False)
  319. parser.add_argument("--configsection", "-s", help="section from ini file "
  320. "to read")
  321. parser.add_argument("--output", "-o", help="output filename, will be "
  322. "a png", default="frame.png", required=False)
  323. parser.add_argument('--verbose', dest='verbose', action='store_true')
  324. parser.add_argument('--no-verbose', dest='verbose', action='store_false')
  325. parser.set_defaults(verbose=False)
  326. args = parser.parse_args()
  327. init_logger(debug=args.verbose)
  328. try:
  329. graph = Graph(args.packages, (args.configfile, args.configsection),
  330. args.output)
  331. except util.TreeCLIError as e:
  332. sys.stderr.write("ERROR: {}\n".format(e.message))
  333. sys.exit(1)
  334. if __name__ == "__main__":
  335. main()