visualize_2.py

import argparse
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from models.NMT import simpleNMT
from utils.examples import run_example
from data.reader import Vocabulary

HERE = os.path.realpath(os.path.join(os.path.realpath(__file__), '..'))


def load_examples(file_name):
    with open(file_name) as f:
        return [s.replace('\n', '') for s in f.readlines()]
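
# NOTE: when --examples points to a '.txt' file, load_examples() above reads
# it as one input string per line (trailing newlines stripped); otherwise the
# argument itself is treated as a single example string (see __main__ below).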

# create a directory if it doesn't already exist
if not os.path.exists(os.path.join(HERE, 'attention_maps')):
    os.makedirs(os.path.join(HERE, 'attention_maps'))

SAMPLE_HUMAN_VOCAB = os.path.join(HERE, 'data', 'sample_human_vocab.json')
SAMPLE_MACHINE_VOCAB = os.path.join(HERE, 'data', 'sample_machine_vocab.json')
SAMPLE_WEIGHTS = os.path.join(HERE, 'weights', 'sample_NMT.49.0.01.hdf5')


class Visualizer(object):

    def __init__(self,
                 padding=None,
                 input_vocab=SAMPLE_HUMAN_VOCAB,
                 output_vocab=SAMPLE_MACHINE_VOCAB):
        """
        Visualizes attention maps.

        :param padding: the padding to use for the sequences
        :param input_vocab: the location of the input (human)
                            vocabulary file
        :param output_vocab: the location of the output (machine)
                             vocabulary file
        """
        self.padding = padding
        self.input_vocab = Vocabulary(
            input_vocab, padding=padding)
        self.output_vocab = Vocabulary(
            output_vocab, padding=padding)

    def set_models(self, pred_model, proba_model):
        """
        Sets the models to use.

        :param pred_model: the prediction model
        :param proba_model: the model that outputs the activation maps
        """
        self.pred_model = pred_model
        self.proba_model = proba_model

    def attention_map(self, text):
        """
        Plots and saves the attention map for the given input text.

        :param text: the text to visualize the attention map for
        """
        # encode the string
        d = self.input_vocab.string_to_int(text)

        # get the output sequence
        predicted_text = run_example(
            self.pred_model, self.input_vocab, self.output_vocab, text)

        text_ = list(text) + ['<eot>'] + ['<unk>'] * self.input_vocab.padding

        # get the lengths of the string
        input_length = len(text) + 1
        output_length = predicted_text.index('<eot>') + 1

        # get the activation map and crop it to the unpadded lengths
        activation_map = np.squeeze(self.proba_model.predict(np.array([d])))[
            0:output_length, 0:input_length]

        plt.clf()
        f = plt.figure(figsize=(8, 8.5))
        ax = f.add_subplot(1, 1, 1)

        # add image
        i = ax.imshow(activation_map, interpolation='nearest', cmap='gray')

        # add colorbar
        cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
        cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
        cbar.ax.set_xlabel('Probability', labelpad=2)

        # add labels
        ax.set_yticks(range(output_length))
        ax.set_yticklabels(predicted_text[:output_length])
        ax.set_xticks(range(input_length))
        ax.set_xticklabels(text_[:input_length], rotation=45)

        ax.set_xlabel('Input Sequence')
        ax.set_ylabel('Output Sequence')

        # add grid and legend
        ax.grid()
        # ax.legend(loc='best')

        f.savefig(os.path.join(HERE, 'attention_maps',
                               text.replace('/', '') + '.pdf'),
                  bbox_inches='tight')
        f.show()
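

# Sketch of driving Visualizer directly from a Python session (assumes the
# models have been built and loaded as in main() below):
#
#   viz = Visualizer(padding=50)
#   viz.set_models(pred_model, proba_model)
#   viz.attention_map('some input string')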


def main(examples, args):
    print('Total Number of Examples:', len(examples))

    weights_file = os.path.expanduser(args.weights)
    print('Weights loading from:', weights_file)

    viz = Visualizer(padding=args.padding,
                     input_vocab=args.human_vocab,
                     output_vocab=args.machine_vocab)
    print('Loading models')

    # prediction model: returns the decoded output sequence
    pred_model = simpleNMT(trainable=False,
                           pad_length=args.padding,
                           n_chars=viz.input_vocab.size(),
                           n_labels=viz.output_vocab.size())
    pred_model.load_weights(weights_file, by_name=True)
    pred_model.compile(optimizer='adam', loss='categorical_crossentropy')

    # probability model: same weights, but returns the attention activations
    proba_model = simpleNMT(trainable=False,
                            pad_length=args.padding,
                            n_chars=viz.input_vocab.size(),
                            n_labels=viz.output_vocab.size(),
                            return_probabilities=True)
    proba_model.load_weights(weights_file, by_name=True)
    proba_model.compile(optimizer='adam', loss='categorical_crossentropy')

    viz.set_models(pred_model, proba_model)
    print('Models loaded')

    for example in examples:
        viz.attention_map(example)

    print('Completed visualizations')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    named_args = parser.add_argument_group('named arguments')

    named_args.add_argument('-e', '--examples', metavar='|',
                            help="""Example string/file to visualize the
                            attention map for. If a file, it must end
                            with '.txt'""",
                            required=True)

    named_args.add_argument('-w', '--weights', metavar='|',
                            help="""Location of weights""",
                            required=False,
                            default=SAMPLE_WEIGHTS)

    named_args.add_argument('-p', '--padding', metavar='|',
                            help="""Length of padding""",
                            required=False, default=50, type=int)

    named_args.add_argument('-hv', '--human-vocab', metavar='|',
                            help="""Path to the human vocabulary""",
                            required=False,
                            default=SAMPLE_HUMAN_VOCAB,
                            type=str)

    named_args.add_argument('-mv', '--machine-vocab', metavar='|',
                            help="""Path to the machine vocabulary""",
                            required=False,
                            default=SAMPLE_MACHINE_VOCAB,
                            type=str)

    args = parser.parse_args()

    # treat the argument as a file of examples if it ends in '.txt',
    # otherwise visualize the single string that was passed in
    if args.examples.endswith('.txt'):
        examples = load_examples(args.examples)
    else:
        examples = [args.examples]

    main(examples, args)
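
# Example invocations (sketch; assumes the sample weights and vocabularies
# referenced above exist, and the quoted input string is a placeholder for
# whatever the trained model expects):
#
#   python visualize_2.py -e 'some input string'
#   python visualize_2.py -e my_examples.txt -p 50 -w weights/sample_NMT.49.0.01.hdf5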