visualize_2_4.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. def main(examples, args):
  2. print('Total Number of Examples:', len(examples))
  3. weights_file = os.path.expanduser(args.weights)
  4. print('Weights loading from:', weights_file)
  5. viz = Visualizer(padding=args.padding,
  6. input_vocab=args.human_vocab,
  7. output_vocab=args.machine_vocab)
  8. print('Loading models')
  9. pred_model = simpleNMT(trainable=False,
  10. pad_length=args.padding,
  11. n_chars=viz.input_vocab.size(),
  12. n_labels=viz.output_vocab.size())
  13. pred_model.load_weights(weights_file, by_name=True)
  14. pred_model.compile(optimizer='adam', loss='categorical_crossentropy')
  15. proba_model = simpleNMT(trainable=False,
  16. pad_length=args.padding,
  17. n_chars=viz.input_vocab.size(),
  18. n_labels=viz.output_vocab.size(),
  19. return_probabilities=True)
  20. proba_model.load_weights(weights_file, by_name=True)
  21. proba_model.compile(optimizer='adam', loss='categorical_crossentropy')
  22. viz.set_models(pred_model, proba_model)
  23. print('Models loaded')
  24. for example in examples:
  25. viz.attention_map(example)
  26. print('Completed visualizations')
  27. if __name__ == '__main__':
  28. parser = argparse.ArgumentParser()
  29. named_args = parser.add_argument_group('named arguments')
  30. named_args.add_argument('-e', '--examples', metavar='|',
  31. help="""Example string/file to visualize attention map for
  32. If file, it must end with '.txt'""",
  33. required=True)
  34. named_args.add_argument('-w', '--weights', metavar='|',
  35. help="""Location of weights""",
  36. required=False,
  37. default=SAMPLE_WEIGHTS)
  38. named_args.add_argument('-p', '--padding', metavar='|',
  39. help="""Length of padding""",
  40. required=False, default=50, type=int)
  41. named_args.add_argument('-hv', '--human-vocab', metavar='|',
  42. help="""Path to the human vocabulary""",
  43. required=False,
  44. default=SAMPLE_HUMAN_VOCAB,
  45. type=str)
  46. named_args.add_argument('-mv', '--machine-vocab', metavar='|',
  47. help="""Path to the machine vocabulary""",
  48. required=False,
  49. default=SAMPLE_MACHINE_VOCAB,
  50. type=str)
  51. args = parser.parse_args()
  52. if '.txt' in args.examples:
  53. examples = load_examples(args.examples)
  54. else:
  55. examples = [args.examples]
  56. main(examples, args)