# visualize_1.py (9.7 KB)

  1. def ann_viz(model, view=True, filename="network.gv", title="My Neural Network"):
  2. """Vizualizez a Sequential model.
  3. # Arguments
  4. model: A Keras model instance.
  5. view: whether to display the model after generation.
  6. filename: where to save the vizualization. (a .gv file)
  7. title: A title for the graph
  8. """
  9. from graphviz import Digraph;
  10. import keras;
  11. from keras.models import Sequential;
  12. from keras.layers import Dense, Conv2D, MaxPooling2D, Dropout, Flatten;
  13. import json;
  14. input_layer = 0;
  15. hidden_layers_nr = 0;
  16. layer_types = [];
  17. hidden_layers = [];
  18. output_layer = 0;
  19. for layer in model.layers:
  20. if(layer == model.layers[0]):
  21. input_layer = int(str(layer.input_shape).split(",")[1][1:-1]);
  22. hidden_layers_nr += 1;
  23. if (type(layer) == keras.layers.core.Dense):
  24. hidden_layers.append(int(str(layer.output_shape).split(",")[1][1:-1]));
  25. layer_types.append("Dense");
  26. else:
  27. hidden_layers.append(1);
  28. if (type(layer) == keras.layers.convolutional.Conv2D):
  29. layer_types.append("Conv2D");
  30. elif (type(layer) == keras.layers.pooling.MaxPooling2D):
  31. layer_types.append("MaxPooling2D");
  32. elif (type(layer) == keras.layers.core.Dropout):
  33. layer_types.append("Dropout");
  34. elif (type(layer) == keras.layers.core.Flatten):
  35. layer_types.append("Flatten");
  36. elif (type(layer) == keras.layers.core.Activation):
  37. layer_types.append("Activation");
  38. else:
  39. if(layer == model.layers[-1]):
  40. output_layer = int(str(layer.output_shape).split(",")[1][1:-1]);
  41. else:
  42. hidden_layers_nr += 1;
  43. if (type(layer) == keras.layers.core.Dense):
  44. hidden_layers.append(int(str(layer.output_shape).split(",")[1][1:-1]));
  45. layer_types.append("Dense");
  46. else:
  47. hidden_layers.append(1);
  48. if (type(layer) == keras.layers.convolutional.Conv2D):
  49. layer_types.append("Conv2D");
  50. elif (type(layer) == keras.layers.pooling.MaxPooling2D):
  51. layer_types.append("MaxPooling2D");
  52. elif (type(layer) == keras.layers.core.Dropout):
  53. layer_types.append("Dropout");
  54. elif (type(layer) == keras.layers.core.Flatten):
  55. layer_types.append("Flatten");
  56. elif (type(layer) == keras.layers.core.Activation):
  57. layer_types.append("Activation");
  58. last_layer_nodes = input_layer;
  59. nodes_up = input_layer;
  60. if(type(model.layers[0]) != keras.layers.core.Dense):
  61. last_layer_nodes = 1;
  62. nodes_up = 1;
  63. input_layer = 1;
  64. g = Digraph('g', filename=filename);
  65. n = 0;
  66. g.graph_attr.update(splines="false", nodesep='1', ranksep='2');
  67. #Input Layer
  68. with g.subgraph(name='cluster_input') as c:
  69. if(type(model.layers[0]) == keras.layers.core.Dense):
  70. the_label = title+'\n\n\n\nInput Layer';
  71. if (int(str(model.layers[0].input_shape).split(",")[1][1:-1]) > 10):
  72. the_label += " (+"+str(int(str(model.layers[0].input_shape).split(",")[1][1:-1]) - 10)+")";
  73. input_layer = 10;
  74. c.attr(color='white')
  75. for i in range(0, input_layer):
  76. n += 1;
  77. c.node(str(n));
  78. c.attr(label=the_label)
  79. c.attr(rank='same');
  80. c.node_attr.update(color="#2ecc71", style="filled", fontcolor="#2ecc71", shape="circle");
  81. elif(type(model.layers[0]) == keras.layers.convolutional.Conv2D):
  82. #Conv2D Input visualizing
  83. the_label = title+'\n\n\n\nInput Layer';
  84. c.attr(color="white", label=the_label);
  85. c.node_attr.update(shape="square");
  86. pxls = str(model.layers[0].input_shape).split(',');
  87. clr = int(pxls[3][1:-1]);
  88. if (clr == 1):
  89. clrmap = "Grayscale";
  90. the_color = "black:white";
  91. elif (clr == 3):
  92. clrmap = "RGB";
  93. the_color = "#e74c3c:#3498db";
  94. else:
  95. clrmap = "";
  96. c.node_attr.update(fontcolor="white", fillcolor=the_color, style="filled");
  97. n += 1;
  98. c.node(str(n), label="Image\n"+pxls[1]+" x"+pxls[2]+" pixels\n"+clrmap, fontcolor="white");
  99. else:
  100. raise ValueError("ANN Visualizer: Layer not supported for visualizing");
  101. for i in range(0, hidden_layers_nr):
  102. with g.subgraph(name="cluster_"+str(i+1)) as c:
  103. if (layer_types[i] == "Dense"):
  104. c.attr(color='white');
  105. c.attr(rank='same');
  106. #If hidden_layers[i] > 10, dont include all
  107. the_label = "";
  108. if (int(str(model.layers[i].output_shape).split(",")[1][1:-1]) > 10):
  109. the_label += " (+"+str(int(str(model.layers[i].output_shape).split(",")[1][1:-1]) - 10)+")";
  110. hidden_layers[i] = 10;
  111. c.attr(labeljust="right", labelloc="b", label=the_label);
  112. for j in range(0, hidden_layers[i]):
  113. n += 1;
  114. c.node(str(n), shape="circle", style="filled", color="#3498db", fontcolor="#3498db");
  115. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  116. g.edge(str(h), str(n));
  117. last_layer_nodes = hidden_layers[i];
  118. nodes_up += hidden_layers[i];
  119. elif (layer_types[i] == "Conv2D"):
  120. c.attr(style='filled', color='#5faad0');
  121. n += 1;
  122. kernel_size = str(model.layers[i].get_config()['kernel_size']).split(',')[0][1] + "x" + str(model.layers[i].get_config()['kernel_size']).split(',')[1][1 : -1];
  123. filters = str(model.layers[i].get_config()['filters']);
  124. c.node("conv_"+str(n), label="Convolutional Layer\nKernel Size: "+kernel_size+"\nFilters: "+filters, shape="square");
  125. c.node(str(n), label=filters+"\nFeature Maps", shape="square");
  126. g.edge("conv_"+str(n), str(n));
  127. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  128. g.edge(str(h), "conv_"+str(n));
  129. last_layer_nodes = 1;
  130. nodes_up += 1;
  131. elif (layer_types[i] == "MaxPooling2D"):
  132. c.attr(color="white");
  133. n += 1;
  134. pool_size = str(model.layers[i].get_config()['pool_size']).split(',')[0][1] + "x" + str(model.layers[i].get_config()['pool_size']).split(',')[1][1 : -1];
  135. c.node(str(n), label="Max Pooling\nPool Size: "+pool_size, style="filled", fillcolor="#8e44ad", fontcolor="white");
  136. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  137. g.edge(str(h), str(n));
  138. last_layer_nodes = 1;
  139. nodes_up += 1;
  140. elif (layer_types[i] == "Flatten"):
  141. n += 1;
  142. c.attr(color="white");
  143. c.node(str(n), label="Flattening", shape="invtriangle", style="filled", fillcolor="#2c3e50", fontcolor="white");
  144. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  145. g.edge(str(h), str(n));
  146. last_layer_nodes = 1;
  147. nodes_up += 1;
  148. elif (layer_types[i] == "Dropout"):
  149. n += 1;
  150. c.attr(color="white");
  151. c.node(str(n), label="Dropout Layer", style="filled", fontcolor="white", fillcolor="#f39c12");
  152. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  153. g.edge(str(h), str(n));
  154. last_layer_nodes = 1;
  155. nodes_up += 1;
  156. elif (layer_types[i] == "Activation"):
  157. n += 1;
  158. c.attr(color="white");
  159. fnc = model.layers[i].get_config()['activation'];
  160. c.node(str(n), shape="octagon", label="Activation Layer\nFunction: "+fnc, style="filled", fontcolor="white", fillcolor="#00b894");
  161. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  162. g.edge(str(h), str(n));
  163. last_layer_nodes = 1;
  164. nodes_up += 1;
  165. with g.subgraph(name='cluster_output') as c:
  166. if (type(model.layers[-1]) == keras.layers.core.Dense):
  167. c.attr(color='white')
  168. c.attr(rank='same');
  169. c.attr(labeljust="1");
  170. for i in range(1, output_layer+1):
  171. n += 1;
  172. c.node(str(n), shape="circle", style="filled", color="#e74c3c", fontcolor="#e74c3c");
  173. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  174. g.edge(str(h), str(n));
  175. c.attr(label='Output Layer', labelloc="bottom")
  176. c.node_attr.update(color="#2ecc71", style="filled", fontcolor="#2ecc71", shape="circle");
  177. g.attr(arrowShape="none");
  178. g.edge_attr.update(arrowhead="none", color="#707070");
  179. if view == True:
  180. g.view();