classifyer.py

#!/usr/bin/python
# coding=utf-8
import numpy as np
import os
import re
import logging
import torch
import torch.nn as nn
import torch.utils.data
from torch.autograd import Variable
from character_processor import DataProcessor
# BiLSTMModel is not referenced directly, but importing it keeps the class
# importable so torch.load() can unpickle the saved model below.
from bilstm_attention import BiLSTMModel
fileName = './flask.log'
handler = [logging.FileHandler(filename=fileName, encoding="utf-8")]
logging.basicConfig(level=logging.DEBUG, handlers=handler)
parent_path = os.path.dirname(os.path.realpath(__file__))
cws_model_path = parent_path + "/ltp_data_v3.4.0/cws.model"
pos_model_path = parent_path + "/ltp_data_v3.4.0/pos.model"
stop_word_path = parent_path + "/ltp_data_v3.4.0/stop_word.txt"
synonym_word_path = parent_path + "/ltp_data_v3.4.0/HIT-IRLab-同义词词林.txt"
grander_path = os.path.dirname(parent_path)
classify_model_path_base = grander_path + "/classify_model/"
word2index_path_base = grander_path + "/word2index/"
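# Expected on-disk layout (assumed from the paths above):
#   <grander_path>/classify_model/<datasetName>.pth   -- trained BiLSTM + attention model
#   <grander_path>/word2index/<datasetName>.npy       -- pickled dict mapping token -> index
#   <parent_path>/ltp_data_v3.4.0/                    -- LTP 3.4.0 segmentation/POS models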
max_len = 64
vocab_size = 5000  # vocabulary size
embedding_size = 64  # word-embedding dimension
batch_size = 16
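# A minimal sketch (hypothetical tokens) of how a vocabulary file of the shape
# consumed below could be produced; np.save pickles the dict and
# np.load(..., allow_pickle=True).item() recovers it:
#
#   vocab = {"<pad>": 0, "<unk>": 1, "登录": 2}
#   np.save(word2index_path_base + "test.npy", vocab)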
def get_testcase_label(json_data):
    logging.info("Loading JSON data")
    dataset_name = json_data['datasetName']
    # test_process = json_data['testProcess']
    # test_requirement = json_data['testRequirement']
    # product_version_module = json_data['productVersionModule']
    sentence = json_data['sentence']
    logging.info("Running LTP preprocessing")
    text = text_after_ltp(sentence)
    logging.info("Loading vocabulary")
    word2index_path = word2index_path_base + str(dataset_name) + ".npy"
    word2index = np.load(word2index_path, allow_pickle=True).item()
    text_feature = text_to_feature(text, word2index, max_len)
    logging.info("Loading classification model")
    classify_model_path = classify_model_path_base + str(dataset_name) + ".pth"
    classify_model = torch.load(classify_model_path)
    text_dataset = [text_feature]
    label = [[0, 0, 0, 0, 0, 0]]  # placeholder label; only the prediction is used
    embed = nn.Embedding(vocab_size + 2, embedding_size)
    text_dataset = Variable(embed(torch.LongTensor(text_dataset)), requires_grad=False)
    label = torch.FloatTensor(label)
    text_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(text_dataset, label),
                                              batch_size=batch_size, shuffle=False)
    logging.info("Fetching classification result")
    for data, label in text_loader:
        no_attention, preds = classify_model(data)
        category_id = torch.max(preds, 1)[1].detach().numpy()[0]  # index of the top class
        accuracy = torch.max(preds, 1)[0].detach().numpy()[0]     # score of the top class
        break  # a single sample fits in one batch
    result = {
        "categoryId": int(category_id),
        "accuracy": float(accuracy)
    }
    logging.info(result)
    return result
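# Example call (hypothetical payload and output), assuming the "test" dataset's
# vocabulary and model files exist:
#
#   get_testcase_label({'datasetName': 'test', 'sentence': '1.打开登录页面 2.输入用户名'})
#   # -> {"categoryId": 3, "accuracy": 0.87}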
# Input preprocessing: word segmentation, stop-word removal, synonym replacement, POS filtering
def text_after_ltp(sentence):
    text = sentence.replace("\n", " ")
    processor = DataProcessor(cws_model_path, pos_model_path)
    stop_word = processor.stop_word_list(stop_word_path)
    synonym_dict = processor.synonym_word_dict(synonym_word_path)
    input_split_word = processor.segmentor(text)
    input_clean_word = processor.clean_word_list(input_split_word, stop_word)
    input_no_synonym = processor.synonym_replace_sentence(input_clean_word, synonym_dict)
    input_no_num = [i for i in input_no_synonym if not re.findall(r'^\d+\.', i)]
    return input_no_num
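# For reference, the trailing regex filter drops tokens that are numbered step
# prefixes and keeps everything else:
#   re.findall(r'^\d+\.', '1.')  -> ['1.']  (token dropped)
#   re.findall(r'^\d+\.', '登录') -> []      (token kept)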
# Map each token to its vocabulary index, truncating or padding to max_len
def text_to_feature(text, word2index, max_len):
    feature = []
    for word in text:
        word = word.lower()
        if word in word2index:
            feature.append(word2index[word])
        else:
            feature.append(word2index["<unk>"])
        if len(feature) == max_len:
            break
    feature = feature + [word2index["<pad>"]] * (max_len - len(feature))
    return feature
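# Example with a toy vocabulary (hypothetical indices): known tokens are looked
# up, unknown words map to <unk>, and the result is padded to max_len:
#   word2index = {"<pad>": 0, "<unk>": 1, "login": 2}
#   text_to_feature(["login", "page"], word2index, 4)  -> [2, 1, 0, 0]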
if __name__ == "__main__":
    json_data = {
        'datasetName': 'test',
        'sentence': "",  # get_testcase_label also reads this field; fill in the test-case text
    }
    get_testcase_label(json_data)