bootstrap_utils.py

import string
import bcolz
import numpy as np
import os
import pickle
from .nlp_utils import compact_text
import nltk
import torch
nltk.download('stopwords')
nltk.download('punkt')  # required by word_tokenize / sent_tokenize
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
import gensim

stop_words = set(stopwords.words('english'))
eps = 10e-8  # small constant (1e-7) to avoid division by zero when normalising vectors
def get_init_hidden(bsz, hidden_size, use_cuda):
    # Zero-initialised hidden and cell states of shape (bsz, hidden_size).
    # (torch.autograd.Variable is deprecated; plain tensors behave the same.)
    h_0 = torch.zeros(bsz, hidden_size)
    c_0 = torch.zeros(bsz, hidden_size)
    if use_cuda:
        h_0, c_0 = h_0.cuda(), c_0.cuda()
    return h_0, c_0
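
# Illustrative usage sketch (not part of the original module): get_init_hidden is the
# kind of helper typically paired with an nn.LSTMCell. The sizes and the cell below
# are assumptions made purely for demonstration.
def _demo_get_init_hidden():
    cell = torch.nn.LSTMCell(input_size=300, hidden_size=128)
    h, c = get_init_hidden(bsz=4, hidden_size=128, use_cuda=False)
    x = torch.randn(4, 300)   # dummy batch of 4 input vectors
    h, c = cell(x, (h, c))    # one recurrent step
    return h.shape, c.shape   # both torch.Size([4, 128])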
def similarity(w1, w2, w2v):
    # Cosine similarity between the embeddings of w1 and w2. Out-of-vocabulary
    # words fall back to a zero vector, truncated to the other vector's length
    # so the dot product is still defined.
    try:
        vec1 = w2v[w1]
    except KeyError:
        vec1 = np.zeros((300,))
    try:
        vec2 = w2v[w2]
    except KeyError:
        vec2 = np.zeros((300,))
    if np.sum(vec1) == 0 and np.sum(vec2) != 0:
        vec1 = vec1[:len(vec2)]
    elif np.sum(vec1) != 0 and np.sum(vec2) == 0:
        vec2 = vec2[:len(vec1)]
    unit_vec1 = vec1 / (np.linalg.norm(vec1) + eps)
    unit_vec2 = vec2 / (np.linalg.norm(vec2) + eps)
    return np.dot(unit_vec1, unit_vec2)
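
# Illustrative sketch: similarity() against a tiny hand-made embedding dictionary.
# The words and vectors are invented for demonstration only (they are not ConceptNet,
# GloVe, or word2vec entries).
def _demo_similarity():
    toy_w2v = {
        'sword': np.array([1.0, 0.0, 0.0]),
        'blade': np.array([0.9, 0.1, 0.0]),
    }
    print(similarity('sword', 'blade', toy_w2v))   # close to 1.0
    print(similarity('sword', 'dragon', toy_w2v))  # 'dragon' is OOV -> 0.0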
def normalize(state, remove_articles=True):
    # Lower-case, strip punctuation, tokenize, and (optionally) drop stopwords.
    state = state.lower()
    out = state.translate(str.maketrans('', '', string.punctuation))
    if remove_articles:
        out = word_tokenize(out)
        s_ws = [w for w in out if w not in stop_words]
    else:
        s_ws = word_tokenize(out)
    return s_ws
def get_thresholded(sim_dict, t=0.2):
    final_words = []
    for k, v in sim_dict.items():
        if v >= t:
            final_words.append(k)
    return final_words
def statistics_score(sim_dict, kind='avg'):
    scores = list(sim_dict.values())
    return np.mean(scores) if kind == 'avg' else np.max(scores)
def correlate_state(s_ws, object_list, w2v, mean=True):
    # Score every state token against the expert words. object_list may be a
    # plain list of words, or a dict mapping words to weights.
    sim_dict = {}
    if isinstance(object_list, list):
        for w in s_ws:
            sim_list = []
            for w_obj in object_list:
                sim_list.append(similarity(w_obj, w, w2v))
            sim_dict[w] = np.mean(sim_list) if mean else np.max(sim_list)
    elif isinstance(object_list, dict):
        for w in s_ws:
            sim_list = []
            for w_obj, val in object_list.items():
                sim_list.append(val * similarity(w_obj, w, w2v))
            sim_dict[w] = np.mean(sim_list) if mean else np.max(sim_list)
    return sim_dict
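
# Illustrative end-to-end sketch of the token-pruning pipeline with a toy embedding
# dictionary. The sentence, "expert" words, vectors, and threshold are all assumptions
# chosen only to show how normalize / correlate_state / get_thresholded fit together.
def _demo_pruning_pipeline():
    toy_w2v = {
        'chest': np.array([1.0, 0.0]),
        'key':   np.array([0.0, 1.0]),
    }
    tokens = normalize('You see a wooden box and a shiny key.')
    sim_dict = correlate_state(tokens, object_list=['chest', 'key'], w2v=toy_w2v, mean=False)
    return get_thresholded(sim_dict, t=0.5)  # only tokens close to an expert word survive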

class BootstrapFilter:
    def __init__(self, threshold=0.3, filter_sentence=False):
        self.threshold = threshold
        self.load_cc_embeddings()
        # NOTE: load_bs_action_token is not defined in this file; it is assumed to
        # be provided elsewhere in the repository (or by a subclass).
        self.load_bs_action_token()
        self.filter_sent = filter_sentence

    def load_cc_embeddings(self):
        # ConceptNet Numberbatch vectors, pre-processed into a bcolz array with
        # pickled word / index lookups.
        embed_size = 300
        cc_path = os.path.expanduser('~/Data/nlp/conceptNet')
        filename = '{0}/numberbatch-en-19.08.txt'.format(cc_path)  # raw source file (unused here)
        rootdir = '{0}/glove.dat'.format(cc_path)
        words_file = '{0}/CC_words.pkl'.format(cc_path)
        idx_file = '{0}/CC_idx.pkl'.format(cc_path)
        words = pickle.load(open(words_file, 'rb'))
        word2idx = pickle.load(open(idx_file, 'rb'))
        vectors = bcolz.open(rootdir)
        self.w2v = {w: vectors[word2idx[w]] for w in words}

# Domain relevant episodic state pruning
class CREST(BootstrapFilter):
    def __init__(self, threshold=0.3, embeddings='cnet'):  # 'cnet' | 'glove' | 'word2vec' | 'bert'
        self.threshold = threshold
        print('##' * 30)
        print('Using embedding : ', embeddings)
        print('##' * 30)
        if embeddings == 'cnet':
            self.load_cc_embeddings()
        elif embeddings == 'glove':
            self.load_glove_embeddings()
        elif embeddings == 'word2vec':
            self.load_w2v_embeddings()
        else:
            # 'bert' is listed above but has no loader in this file; fail fast
            # rather than leaving self.w2v unset.
            raise ValueError('Unsupported embeddings option: {}'.format(embeddings))
    def load_glove_embeddings(self):
        embed_size = 100
        glove_path = os.path.expanduser('~/Data/nlp/glove/glove.6B')
        rootdir = '{0}/6B.{1}.dat'.format(glove_path, embed_size)
        words_file = '{0}/6B.{1}_words.pkl'.format(glove_path, embed_size)
        idx_file = '{0}/6B.{1}_idx.pkl'.format(glove_path, embed_size)
        vectors = bcolz.open(rootdir)[:]
        words = pickle.load(open(words_file, 'rb'))
        word2idx = pickle.load(open(idx_file, 'rb'))
        self.w2v = {w: vectors[word2idx[w]] for w in words}

    def load_w2v_embeddings(self):
        self.w2v = gensim.models.KeyedVectors.load_word2vec_format('data/Googlemodel.bin', binary=True)
    def prune_state(self, noisy_string, expert_words, return_details=False, add_prefix=True):
        # Score every token of the noisy observation against the expert words and
        # keep only those whose (max) similarity clears the threshold.
        def get_scores(noisy_string_x, mean=True):
            s_wsx = normalize(noisy_string_x)
            sim_dictx = correlate_state(s_wsx, object_list=expert_words, w2v=self.w2v, mean=mean)
            return sim_dictx

        sentence_pruned_str_joined = noisy_string
        sim_dict = get_scores(sentence_pruned_str_joined, mean=False)
        final_str = get_thresholded(sim_dict, t=self.threshold)
        if add_prefix:
            final_str = '-= Unknown =- ' + ' '.join(final_str)
        else:
            final_str = ' '.join(final_str)
        if return_details:
            return final_str, sim_dict
        else:
            return final_str
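
# Illustrative usage sketch for CREST.prune_state. Everything below is an assumption
# for demonstration: it presumes the ConceptNet Numberbatch files expected by
# load_cc_embeddings() are present under ~/Data/nlp/conceptNet, and the observation
# and expert words are invented.
if __name__ == '__main__':
    crest = CREST(threshold=0.3, embeddings='cnet')
    obs = 'You are in the kitchen. There is a fridge, a counter, and a red apple here.'
    pruned = crest.prune_state(obs, expert_words=['apple', 'fridge', 'take', 'open'])
    print(pruned)  # prints the pruned observation, prefixed with '-= Unknown =-'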