main_30.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  def __next__(self):
      """Return the next parsed example built from a sufficiently long note.

      Rows are pulled from ``self.note_iterrows`` until one is found whose
      stripped ``TEXT`` has at least 2000 characters and for which
      ``self._random_word_context`` succeeds. The selected word is
      lower-cased and, when a uniform draw in [0, 1] is at least
      ``self.no_corruption_prob``, replaced by
      ``self.word_corrupter.corrupt_word(word)``; otherwise it is left
      unchanged. Each context side is pseudonymized, passed through
      ``self._process_note`` and trimmed to at most 128 space-separated
      tokens (the last 128 of the left side, the first 128 of the right).

      Returns:
          Whatever ``self._parse_row`` returns for the row
          ``[-1, note_id, typo, left, right, correct]``.

      Raises:
          StopIteration: If ``self.note_iterrows`` is still exhausted
              immediately after ``self._load_random_csv()`` was called.
      """
      min_note_chars = 2000      # notes shorter than this are skipped
      max_context_tokens = 128   # tokens kept on each side of the word

      # Select the next usable note.
      while True:
          try:
              _, row = next(self.note_iterrows)
          except StopIteration:
              # Current source exhausted; presumably _load_random_csv()
              # rebinds self.note_iterrows -- confirm against its definition.
              self._load_random_csv()
              _, row = next(self.note_iterrows)

          note_id = int(row.ROW_ID)
          note = row.TEXT.strip()
          if len(note) < min_note_chars:
              continue

          try:
              correct, left, right = self._random_word_context(note)
          except Exception:
              # Best effort: skip notes from which no word/context can be
              # drawn. Only Exception subclasses are caught so that
              # KeyboardInterrupt / SystemExit can still stop this
              # otherwise endless loop.
              continue
          break

      # Corrupt and pseudonymize.
      correct = correct.lower()
      if random.uniform(0, 1) >= self.no_corruption_prob:
          typo = self.word_corrupter.corrupt_word(correct)
      else:
          typo = correct

      left = self.mimic_pseudo.pseudonymize(left)
      left = self._process_note(left)
      left = ' '.join(left.split(' ')[-max_context_tokens:])

      right = self.mimic_pseudo.pseudonymize(right)
      right = self._process_note(right)
      right = ' '.join(right.split(' ')[:max_context_tokens])

      # Parse.
      # NOTE(review): the leading -1 looks like a placeholder for a row
      # index expected by _parse_row -- confirm against its definition.
      temp_csv_row = [-1, note_id, typo, left, right, correct]
      return self._parse_row(temp_csv_row)