visualize_caffe.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. def visualize_weights(net, layer_name, padding=4, filename=''):
  4. # The parameters are a list of [weights, biases]
  5. data = np.copy(net.params[layer_name][0].data)
  6. # N is the total number of convolutions
  7. N = data.shape[0]*data.shape[1]
  8. # Ensure the resulting image is square
  9. filters_per_row = int(np.ceil(np.sqrt(N)))
  10. # Assume the filters are square
  11. filter_size = data.shape[2]
  12. # Size of the result image including padding
  13. result_size = filters_per_row*(filter_size + padding) - padding
  14. # Initialize result image to all zeros
  15. result = np.zeros((result_size, result_size))
  16. # Tile the filters into the result image
  17. filter_x = 0
  18. filter_y = 0
  19. for n in range(data.shape[0]):
  20. for c in range(data.shape[1]):
  21. if filter_x == filters_per_row:
  22. filter_y += 1
  23. filter_x = 0
  24. for i in range(filter_size):
  25. for j in range(filter_size):
  26. result[filter_y*(filter_size + padding) + i, filter_x*(filter_size + padding) + j] = data[n, c, i, j]
  27. filter_x += 1
  28. # Normalize image to 0-1
  29. min = result.min()
  30. max = result.max()
  31. result = (result - min) / (max - min)
  32. # Plot figure
  33. plt.figure(figsize=(10, 10))
  34. plt.axis('off')
  35. plt.imshow(result, cmap='gray', interpolation='nearest')
  36. # Save plot if filename is set
  37. if filename != '':
  38. plt.savefig(filename, bbox_inches='tight', pad_inches=0)
  39. plt.show()