visualize_caffe_1.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. def visualize_weights(net, layer_name, padding=4, filename=''):
  2. # The parameters are a list of [weights, biases]
  3. data = np.copy(net.params[layer_name][0].data)
  4. # N is the total number of convolutions
  5. N = data.shape[0]*data.shape[1]
  6. # Ensure the resulting image is square
  7. filters_per_row = int(np.ceil(np.sqrt(N)))
  8. # Assume the filters are square
  9. filter_size = data.shape[2]
  10. # Size of the result image including padding
  11. result_size = filters_per_row*(filter_size + padding) - padding
  12. # Initialize result image to all zeros
  13. result = np.zeros((result_size, result_size))
  14. # Tile the filters into the result image
  15. filter_x = 0
  16. filter_y = 0
  17. for n in range(data.shape[0]):
  18. for c in range(data.shape[1]):
  19. if filter_x == filters_per_row:
  20. filter_y += 1
  21. filter_x = 0
  22. for i in range(filter_size):
  23. for j in range(filter_size):
  24. result[filter_y*(filter_size + padding) + i, filter_x*(filter_size + padding) + j] = data[n, c, i, j]
  25. filter_x += 1
  26. # Normalize image to 0-1
  27. min = result.min()
  28. max = result.max()
  29. result = (result - min) / (max - min)
  30. # Plot figure
  31. plt.figure(figsize=(10, 10))
  32. plt.axis('off')
  33. plt.imshow(result, cmap='gray', interpolation='nearest')
  34. # Save plot if filename is set
  35. if filename != '':
  36. plt.savefig(filename, bbox_inches='tight', pad_inches=0)
  37. plt.show()