visualize_2_3.py

import os

import matplotlib.pyplot as plt
import numpy as np

# NOTE: `run_example` is assumed to be defined elsewhere in this module
# (it runs the prediction model on a single input string). HERE is
# assumed to point at the directory of this script:
HERE = os.path.abspath(os.path.dirname(__file__))


def attention_map(self, text):
    """
    Plot and save the attention map for `text`.

    text -- input string to visualize the attention map for.
    """
    # encode the input string as a sequence of integer ids
    d = self.input_vocab.string_to_int(text)
    # run the model to get the predicted output sequence
    predicted_text = run_example(
        self.pred_model, self.input_vocab, self.output_vocab, text)
    text_ = list(text) + ['<eot>'] + ['<unk>'] * self.input_vocab.padding

    # lengths of the input and output sequences, including the <eot> token
    input_length = len(text) + 1
    output_length = predicted_text.index('<eot>') + 1

    # attention weights, cropped to the actual sequence lengths
    activation_map = np.squeeze(self.proba_model.predict(np.array([d])))[
        0:output_length, 0:input_length]

    plt.clf()
    f = plt.figure(figsize=(8, 8.5))
    ax = f.add_subplot(1, 1, 1)

    # draw the attention weights as a grayscale image
    i = ax.imshow(activation_map, interpolation='nearest', cmap='gray')

    # add a horizontal colorbar below the image
    cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
    cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Probability', labelpad=2)

    # label the axes with the output and input tokens
    ax.set_yticks(range(output_length))
    ax.set_yticklabels(predicted_text[:output_length])
    ax.set_xticks(range(input_length))
    ax.set_xticklabels(text_[:input_length], rotation=45)
    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')

    # add a grid for readability
    ax.grid()

    f.savefig(os.path.join(HERE, 'attention_maps',
                           text.replace('/', '') + '.pdf'),
              bbox_inches='tight')
    f.show()
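
For context, here is a minimal sketch of the class this method would live on. The class name AttentionVisualizer and its constructor are illustrative assumptions, not part of the original file; they only make explicit which attributes attention_map() reads from self.

class AttentionVisualizer:
    """Hypothetical host class showing the attributes attention_map() expects."""

    def __init__(self, pred_model, proba_model, input_vocab, output_vocab):
        self.pred_model = pred_model      # model producing the output tokens
        self.proba_model = proba_model    # model producing the attention weights
        self.input_vocab = input_vocab    # must provide string_to_int() and .padding
        self.output_vocab = output_vocab

    # bind the module-level function above as a method
    attention_map = attention_map


# Example call (the input string is a placeholder; use text your models
# were trained on):
# viz = AttentionVisualizer(pred_model, proba_model, input_vocab, output_vocab)
# viz.attention_map('some input string')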