visualize.py 10 KB


  1. """
  2. Copyright (C) 2018 by Tudor Gheorghiu
  3. Permission is hereby granted, free of charge,
  4. to any person obtaining a copy of this software and associated
  5. documentation files (the "Software"),
  6. to deal in the Software without restriction,
  7. including without limitation the rights to
  8. use, copy, modify, merge, publish, distribute,
  9. sublicense, and/or sell copies of the Software,
  10. and to permit persons to whom the Software is furnished to do so,
  11. subject to the following conditions:
  12. The above copyright notice and this permission notice
  13. shall be included in all copies or substantial portions of the Software.
  14. """
  15. def ann_viz(model, view=True, filename="network.gv", title="My Neural Network"):
  16. """Vizualizez a Sequential model.
  17. # Arguments
  18. model: A Keras model instance.
  19. view: whether to display the model after generation.
  20. filename: where to save the vizualization. (a .gv file)
  21. title: A title for the graph
  22. """
  23. from graphviz import Digraph;
  24. import keras;
  25. from keras.models import Sequential;
  26. from keras.layers import Dense, Conv2D, MaxPooling2D, Dropout, Flatten;
  27. import json;
  28. input_layer = 0;
  29. hidden_layers_nr = 0;
  30. layer_types = [];
  31. hidden_layers = [];
  32. output_layer = 0;
  33. for layer in model.layers:
  34. if(layer == model.layers[0]):
  35. input_layer = int(str(layer.input_shape).split(",")[1][1:-1]);
  36. hidden_layers_nr += 1;
  37. if (type(layer) == keras.layers.core.Dense):
  38. hidden_layers.append(int(str(layer.output_shape).split(",")[1][1:-1]));
  39. layer_types.append("Dense");
  40. else:
  41. hidden_layers.append(1);
  42. if (type(layer) == keras.layers.convolutional.Conv2D):
  43. layer_types.append("Conv2D");
  44. elif (type(layer) == keras.layers.pooling.MaxPooling2D):
  45. layer_types.append("MaxPooling2D");
  46. elif (type(layer) == keras.layers.core.Dropout):
  47. layer_types.append("Dropout");
  48. elif (type(layer) == keras.layers.core.Flatten):
  49. layer_types.append("Flatten");
  50. elif (type(layer) == keras.layers.core.Activation):
  51. layer_types.append("Activation");
  52. else:
  53. if(layer == model.layers[-1]):
  54. output_layer = int(str(layer.output_shape).split(",")[1][1:-1]);
  55. else:
  56. hidden_layers_nr += 1;
  57. if (type(layer) == keras.layers.core.Dense):
  58. hidden_layers.append(int(str(layer.output_shape).split(",")[1][1:-1]));
  59. layer_types.append("Dense");
  60. else:
  61. hidden_layers.append(1);
  62. if (type(layer) == keras.layers.convolutional.Conv2D):
  63. layer_types.append("Conv2D");
  64. elif (type(layer) == keras.layers.pooling.MaxPooling2D):
  65. layer_types.append("MaxPooling2D");
  66. elif (type(layer) == keras.layers.core.Dropout):
  67. layer_types.append("Dropout");
  68. elif (type(layer) == keras.layers.core.Flatten):
  69. layer_types.append("Flatten");
  70. elif (type(layer) == keras.layers.core.Activation):
  71. layer_types.append("Activation");
  72. last_layer_nodes = input_layer;
  73. nodes_up = input_layer;
  74. if(type(model.layers[0]) != keras.layers.core.Dense):
  75. last_layer_nodes = 1;
  76. nodes_up = 1;
  77. input_layer = 1;
  78. g = Digraph('g', filename=filename);
  79. n = 0;
  80. g.graph_attr.update(splines="false", nodesep='1', ranksep='2');
  81. #Input Layer
  82. with g.subgraph(name='cluster_input') as c:
  83. if(type(model.layers[0]) == keras.layers.core.Dense):
  84. the_label = title+'\n\n\n\nInput Layer';
  85. if (int(str(model.layers[0].input_shape).split(",")[1][1:-1]) > 10):
  86. the_label += " (+"+str(int(str(model.layers[0].input_shape).split(",")[1][1:-1]) - 10)+")";
  87. input_layer = 10;
  88. c.attr(color='white')
  89. for i in range(0, input_layer):
  90. n += 1;
  91. c.node(str(n));
  92. c.attr(label=the_label)
  93. c.attr(rank='same');
  94. c.node_attr.update(color="#2ecc71", style="filled", fontcolor="#2ecc71", shape="circle");
  95. elif(type(model.layers[0]) == keras.layers.convolutional.Conv2D):
  96. #Conv2D Input visualizing
  97. the_label = title+'\n\n\n\nInput Layer';
  98. c.attr(color="white", label=the_label);
  99. c.node_attr.update(shape="square");
  100. pxls = str(model.layers[0].input_shape).split(',');
  101. clr = int(pxls[3][1:-1]);
  102. if (clr == 1):
  103. clrmap = "Grayscale";
  104. the_color = "black:white";
  105. elif (clr == 3):
  106. clrmap = "RGB";
  107. the_color = "#e74c3c:#3498db";
  108. else:
  109. clrmap = "";
  110. c.node_attr.update(fontcolor="white", fillcolor=the_color, style="filled");
  111. n += 1;
  112. c.node(str(n), label="Image\n"+pxls[1]+" x"+pxls[2]+" pixels\n"+clrmap, fontcolor="white");
  113. else:
  114. raise ValueError("ANN Visualizer: Layer not supported for visualizing");
  115. for i in range(0, hidden_layers_nr):
  116. with g.subgraph(name="cluster_"+str(i+1)) as c:
  117. if (layer_types[i] == "Dense"):
  118. c.attr(color='white');
  119. c.attr(rank='same');
  120. #If hidden_layers[i] > 10, dont include all
  121. the_label = "";
  122. if (int(str(model.layers[i].output_shape).split(",")[1][1:-1]) > 10):
  123. the_label += " (+"+str(int(str(model.layers[i].output_shape).split(",")[1][1:-1]) - 10)+")";
  124. hidden_layers[i] = 10;
  125. c.attr(labeljust="right", labelloc="b", label=the_label);
  126. for j in range(0, hidden_layers[i]):
  127. n += 1;
  128. c.node(str(n), shape="circle", style="filled", color="#3498db", fontcolor="#3498db");
  129. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  130. g.edge(str(h), str(n));
  131. last_layer_nodes = hidden_layers[i];
  132. nodes_up += hidden_layers[i];
  133. elif (layer_types[i] == "Conv2D"):
  134. c.attr(style='filled', color='#5faad0');
  135. n += 1;
  136. kernel_size = str(model.layers[i].get_config()['kernel_size']).split(',')[0][1] + "x" + str(model.layers[i].get_config()['kernel_size']).split(',')[1][1 : -1];
  137. filters = str(model.layers[i].get_config()['filters']);
  138. c.node("conv_"+str(n), label="Convolutional Layer\nKernel Size: "+kernel_size+"\nFilters: "+filters, shape="square");
  139. c.node(str(n), label=filters+"\nFeature Maps", shape="square");
  140. g.edge("conv_"+str(n), str(n));
  141. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  142. g.edge(str(h), "conv_"+str(n));
  143. last_layer_nodes = 1;
  144. nodes_up += 1;
  145. elif (layer_types[i] == "MaxPooling2D"):
  146. c.attr(color="white");
  147. n += 1;
  148. pool_size = str(model.layers[i].get_config()['pool_size']).split(',')[0][1] + "x" + str(model.layers[i].get_config()['pool_size']).split(',')[1][1 : -1];
  149. c.node(str(n), label="Max Pooling\nPool Size: "+pool_size, style="filled", fillcolor="#8e44ad", fontcolor="white");
  150. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  151. g.edge(str(h), str(n));
  152. last_layer_nodes = 1;
  153. nodes_up += 1;
  154. elif (layer_types[i] == "Flatten"):
  155. n += 1;
  156. c.attr(color="white");
  157. c.node(str(n), label="Flattening", shape="invtriangle", style="filled", fillcolor="#2c3e50", fontcolor="white");
  158. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  159. g.edge(str(h), str(n));
  160. last_layer_nodes = 1;
  161. nodes_up += 1;
  162. elif (layer_types[i] == "Dropout"):
  163. n += 1;
  164. c.attr(color="white");
  165. c.node(str(n), label="Dropout Layer", style="filled", fontcolor="white", fillcolor="#f39c12");
  166. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  167. g.edge(str(h), str(n));
  168. last_layer_nodes = 1;
  169. nodes_up += 1;
  170. elif (layer_types[i] == "Activation"):
  171. n += 1;
  172. c.attr(color="white");
  173. fnc = model.layers[i].get_config()['activation'];
  174. c.node(str(n), shape="octagon", label="Activation Layer\nFunction: "+fnc, style="filled", fontcolor="white", fillcolor="#00b894");
  175. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  176. g.edge(str(h), str(n));
  177. last_layer_nodes = 1;
  178. nodes_up += 1;
  179. with g.subgraph(name='cluster_output') as c:
  180. if (type(model.layers[-1]) == keras.layers.core.Dense):
  181. c.attr(color='white')
  182. c.attr(rank='same');
  183. c.attr(labeljust="1");
  184. for i in range(1, output_layer+1):
  185. n += 1;
  186. c.node(str(n), shape="circle", style="filled", color="#e74c3c", fontcolor="#e74c3c");
  187. for h in range(nodes_up - last_layer_nodes + 1 , nodes_up + 1):
  188. g.edge(str(h), str(n));
  189. c.attr(label='Output Layer', labelloc="bottom")
  190. c.node_attr.update(color="#2ecc71", style="filled", fontcolor="#2ecc71", shape="circle");
  191. g.attr(arrowShape="none");
  192. g.edge_attr.update(arrowhead="none", color="#707070");
  193. if view == True:
  194. g.view();