def main(examples, args): print('Total Number of Examples:', len(examples)) weights_file = os.path.expanduser(args.weights) print('Weights loading from:', weights_file) viz = Visualizer(padding=args.padding, input_vocab=args.human_vocab, output_vocab=args.machine_vocab) print('Loading models') pred_model = simpleNMT(trainable=False, pad_length=args.padding, n_chars=viz.input_vocab.size(), n_labels=viz.output_vocab.size()) pred_model.load_weights(weights_file, by_name=True) pred_model.compile(optimizer='adam', loss='categorical_crossentropy') proba_model = simpleNMT(trainable=False, pad_length=args.padding, n_chars=viz.input_vocab.size(), n_labels=viz.output_vocab.size(), return_probabilities=True) proba_model.load_weights(weights_file, by_name=True) proba_model.compile(optimizer='adam', loss='categorical_crossentropy') viz.set_models(pred_model, proba_model) print('Models loaded') for example in examples: viz.attention_map(example) print('Completed visualizations') if __name__ == '__main__': parser = argparse.ArgumentParser() named_args = parser.add_argument_group('named arguments') named_args.add_argument('-e', '--examples', metavar='|', help="""Example string/file to visualize attention map for If file, it must end with '.txt'""", required=True) named_args.add_argument('-w', '--weights', metavar='|', help="""Location of weights""", required=False, default=SAMPLE_WEIGHTS) named_args.add_argument('-p', '--padding', metavar='|', help="""Length of padding""", required=False, default=50, type=int) named_args.add_argument('-hv', '--human-vocab', metavar='|', help="""Path to the human vocabulary""", required=False, default=SAMPLE_HUMAN_VOCAB, type=str) named_args.add_argument('-mv', '--machine-vocab', metavar='|', help="""Path to the machine vocabulary""", required=False, default=SAMPLE_MACHINE_VOCAB, type=str) args = parser.parse_args() if '.txt' in args.examples: examples = load_examples(args.examples) else: examples = [args.examples] main(examples, args)