plotly_3.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. def build_plotly_figure(self, pos=None):
  2. if pos is None:
  3. # The kamada kawai layout produces a really nice graph but it's
  4. # a O(N^2) algorithm. It seems only reasonable to draw the graph
  5. # with fewer than ~1000 nodes.
  6. if len(self.graph) < 1000:
  7. pos = nx.layout.kamada_kawai_layout(self.graph)
  8. else:
  9. pos = nx.layout.random_layout(self.graph)
  10. # Create scatter plot of the position of all notes
  11. node_trace = go.Scatter(
  12. x=[],
  13. y=[],
  14. text=[],
  15. mode="markers",
  16. hoverinfo="text",
  17. marker=dict(
  18. showscale=True,
  19. # colorscale options
  20. colorscale="YlGnBu",
  21. reversescale=True,
  22. color=[],
  23. size=10,
  24. colorbar=dict(
  25. thickness=15, title="Centrality", xanchor="left", titleside="right"
  26. ),
  27. line=dict(width=2),
  28. ),
  29. )
  30. for node in self.graph.nodes():
  31. x, y = pos[node]
  32. text = "<br>".join([node, self.graph.nodes[node].get("title", "")])
  33. node_trace["x"] += tuple([x])
  34. node_trace["y"] += tuple([y])
  35. node_trace["text"] += tuple([text])
  36. # Color nodes based on the centrality
  37. for node, centrality in nx.degree_centrality(self.graph).items():
  38. node_trace["marker"]["color"] += tuple([centrality])
  39. # Draw the edges as annotations because it's only sane way to draw arrows.
  40. edges = []
  41. for from_node, to_node in self.graph.edges():
  42. edges.append(
  43. dict(
  44. # Tail coordinates
  45. ax=pos[from_node][0],
  46. ay=pos[from_node][1],
  47. axref="x",
  48. ayref="y",
  49. # Head coordinates
  50. x=pos[to_node][0],
  51. y=pos[to_node][1],
  52. xref="x",
  53. yref="y",
  54. # Aesthetics
  55. arrowwidth=2,
  56. arrowcolor="#666",
  57. arrowhead=2,
  58. # Have the head stop short 5 px for the center point,
  59. # i.e., depends on the node marker size.
  60. standoff=5,
  61. )
  62. )
  63. fig = go.Figure(
  64. data=[node_trace],
  65. layout=go.Layout(
  66. showlegend=False,
  67. hovermode="closest",
  68. margin=dict(b=20, l=5, r=5, t=40),
  69. annotations=edges,
  70. xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
  71. yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
  72. ),
  73. )
  74. return fig