plotly.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import networkx as nx
  2. import plotly.graph_objs as go
  3. class NetworkPlotly:
  4. def __init__(self, name="Zettelkasten"):
  5. """
  6. Build network to visualize with Plotly
  7. Parameters
  8. ----------
  9. name : str
  10. The network name.
  11. """
  12. self.graph = nx.Graph()
  13. def add_node(self, node_id, title):
  14. """
  15. Add a node to the network.
  16. Parameters
  17. ----------
  18. node_id : str, or int
  19. A unique identifier for the node, typically the zettel ID.
  20. title : str
  21. The text label for each node, typically the zettel title.
  22. """
  23. self.graph.add_node(node_id, title=title)
  24. def add_edge(self, source, target):
  25. """
  26. Add a node (a zettel) to the network.
  27. Parameters
  28. ----------
  29. source : str or int
  30. The ID of the source zettel.
  31. target : str or int
  32. The ID of the target (cited) zettel.
  33. """
  34. self.graph.add_edge(source, target)
  35. def build_plotly_figure(self, pos=None):
  36. """
  37. Creates a Plot.ly Figure that can be view online or offline.
  38. Parameters
  39. ----------
  40. graph : nx.Graph
  41. The network of zettels to visualize
  42. pos : dict
  43. Dictionay of zettel_id : (x, y) coordinates where to draw nodes. If
  44. None, the Kamada Kawai layout will be used.
  45. Returns
  46. -------
  47. fig : plotly Figure
  48. """
  49. if pos is None:
  50. # The kamada kawai layout produces a really nice graph but it's
  51. # a O(N^2) algorithm. It seems only reasonable to draw the graph
  52. # with fewer than ~1000 nodes.
  53. if len(self.graph) < 1000:
  54. pos = nx.layout.kamada_kawai_layout(self.graph)
  55. else:
  56. pos = nx.layout.random_layout(self.graph)
  57. # Create scatter plot of the position of all notes
  58. node_trace = go.Scatter(
  59. x=[],
  60. y=[],
  61. text=[],
  62. mode="markers",
  63. hoverinfo="text",
  64. marker=dict(
  65. showscale=True,
  66. # colorscale options
  67. #'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
  68. #'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
  69. #'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
  70. colorscale="YlGnBu",
  71. reversescale=True,
  72. color=[],
  73. size=10,
  74. colorbar=dict(
  75. thickness=15, title="Centrality", xanchor="left", titleside="right"
  76. ),
  77. line=dict(width=2),
  78. ),
  79. )
  80. for node in self.graph.nodes():
  81. x, y = pos[node]
  82. text = "<br>".join([node, self.graph.nodes[node].get("title", "")])
  83. node_trace["x"] += tuple([x])
  84. node_trace["y"] += tuple([y])
  85. node_trace["text"] += tuple([text])
  86. # Color nodes based on the centrality
  87. for node, centrality in nx.degree_centrality(self.graph).items():
  88. node_trace["marker"]["color"] += tuple([centrality])
  89. # Draw the edges as annotations because it's only sane way to draw arrows.
  90. edges = []
  91. for from_node, to_node in self.graph.edges():
  92. edges.append(
  93. dict(
  94. # Tail coordinates
  95. ax=pos[from_node][0],
  96. ay=pos[from_node][1],
  97. axref="x",
  98. ayref="y",
  99. # Head coordinates
  100. x=pos[to_node][0],
  101. y=pos[to_node][1],
  102. xref="x",
  103. yref="y",
  104. # Aesthetics
  105. arrowwidth=2,
  106. arrowcolor="#666",
  107. arrowhead=2,
  108. # Have the head stop short 5 px for the center point,
  109. # i.e., depends on the node marker size.
  110. standoff=5,
  111. )
  112. )
  113. fig = go.Figure(
  114. data=[node_trace],
  115. layout=go.Layout(
  116. showlegend=False,
  117. hovermode="closest",
  118. margin=dict(b=20, l=5, r=5, t=40),
  119. annotations=edges,
  120. xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
  121. yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
  122. ),
  123. )
  124. return fig
  125. def render(self, output, view=True):
  126. """
  127. Render the network to disk.
  128. Parameters
  129. ----------
  130. output : str
  131. Name of the output file.
  132. view : bool
  133. Open the rendered network using the default browser. Default is
  134. True.
  135. """
  136. fig = self.build_plotly_figure()
  137. if not output.endswith(".html"):
  138. output += ".html"
  139. fig.write_html(output, auto_open=view)