1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
def main(examples, args):
    """Load two model variants from one weights file and visualize each example.

    Two ``simpleNMT`` models are built with the same settings; the second
    is additionally constructed with ``return_probabilities=True``. Both
    load the same weights file (matched by layer name) and are handed to a
    ``Visualizer``, which is then asked for an attention map per example.

    Args:
        examples: Sized iterable of examples; each item is passed to
            ``Visualizer.attention_map`` unchanged.
        args: Object exposing ``weights``, ``padding``, ``human_vocab``
            and ``machine_vocab`` attributes (the parsed command line).
    """
    print('Total Number of Examples:', len(examples))
    weights_path = os.path.expanduser(args.weights)
    print('Weights loading from:', weights_path)

    viz = Visualizer(padding=args.padding,
                     input_vocab=args.human_vocab,
                     output_vocab=args.machine_vocab)

    print('Loading models')

    def _load_model(**extra):
        # Build one model, restore its weights by layer name, and compile it.
        net = simpleNMT(trainable=False,
                        pad_length=args.padding,
                        n_chars=viz.input_vocab.size(),
                        n_labels=viz.output_vocab.size(),
                        **extra)
        net.load_weights(weights_path, by_name=True)
        net.compile(optimizer='adam', loss='categorical_crossentropy')
        return net

    pred_model = _load_model()
    proba_model = _load_model(return_probabilities=True)
    viz.set_models(pred_model, proba_model)
    print('Models loaded')

    for text in examples:
        viz.attention_map(text)
    print('Completed visualizations')
if __name__ == '__main__':
    # Command-line entry point: parse the options, resolve the examples
    # (either one literal string or the contents of a '.txt' file), and
    # run the visualization.
    parser = argparse.ArgumentParser()
    named_args = parser.add_argument_group('named arguments')

    named_args.add_argument('-e', '--examples', metavar='|',
                            help="""Example string/file to visualize attention map for
                                    If file, it must end with '.txt'""",
                            required=True)

    named_args.add_argument('-w', '--weights', metavar='|',
                            help="""Location of weights""",
                            required=False,
                            default=SAMPLE_WEIGHTS)

    named_args.add_argument('-p', '--padding', metavar='|',
                            help="""Length of padding""",
                            required=False, default=50, type=int)

    named_args.add_argument('-hv', '--human-vocab', metavar='|',
                            help="""Path to the human vocabulary""",
                            required=False,
                            default=SAMPLE_HUMAN_VOCAB,
                            type=str)

    named_args.add_argument('-mv', '--machine-vocab', metavar='|',
                            help="""Path to the machine vocabulary""",
                            required=False,
                            default=SAMPLE_MACHINE_VOCAB,
                            type=str)

    args = parser.parse_args()

    # The help text promises that a file argument ends with '.txt'. Test the
    # suffix rather than substring membership, so a literal example that
    # merely contains ".txt" is not mistaken for a file path.
    if args.examples.endswith('.txt'):
        examples = load_examples(args.examples)
    else:
        examples = [args.examples]

    main(examples, args)
|