123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- #!/usr/bin/python
- # coding=utf-8
- import numpy as np
- import xlrd
- import os
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.autograd import Variable
- from character_processor import DataProcessor
- import re
- from bilstm_attention import BiLSTMModel
- import logging
# --- Logging -----------------------------------------------------------------
# Records at DEBUG level and above are written (UTF-8) to flask.log in the
# current working directory. The FileHandler opens the file immediately;
# basicConfig only installs it if the root logger has no handlers yet.
fileName = './flask.log'
handler = [logging.FileHandler(filename=fileName, encoding="utf-8")]
logging.basicConfig(handlers=handler, level=logging.DEBUG)

# --- Resource locations ----------------------------------------------------------
# Anchored to the real directory of this file (symlinks resolved), not to the
# working directory; per-dataset artefacts live one directory further up.
parent_path = os.path.dirname(os.path.realpath(__file__))
grander_path = os.path.dirname(parent_path)

# LTP model files (cws.model, pos.model) plus the stop-word and synonym lists.
cws_model_path = parent_path + "/ltp_data_v3.4.0/cws.model"
pos_model_path = parent_path + "/ltp_data_v3.4.0/pos.model"
stop_word_path = parent_path + "/ltp_data_v3.4.0/stop_word.txt"
synonym_word_path = parent_path + "/ltp_data_v3.4.0/HIT-IRLab-同义词词林.txt"

# Directories holding the per-dataset classifier models and vocabularies.
classify_model_path_base = grander_path + "/classify_model/"
word2index_path_base = grander_path + "/word2index/"

# --- Sizes -------------------------------------------------------------------------
#   max_len        - number of token indices per sentence (truncated / padded)
#   vocab_size     - vocabulary size
#   embedding_size - word-vector dimension
#   batch_size     - batch size
max_len, vocab_size, embedding_size, batch_size = 64, 5000, 64, 16
def get_testcase_label(json_data):
    """Classify one test-case sentence with the model trained for its dataset.

    Args:
        json_data: mapping with
            ``'datasetName'`` - selects the vocabulary
            ``word2index_path_base + name + ".npy"`` and the model
            ``classify_model_path_base + name + ".pth"``;
            ``'sentence'`` - the raw text to classify.

    Returns:
        dict with ``"categoryId"`` (int, index of the largest model output)
        and ``"accuracy"`` (float, that largest output value -- a score, not
        an accuracy; whether it is a probability depends on the model).

    Raises:
        KeyError: if ``'datasetName'`` or ``'sentence'`` is missing.
        ValueError: if ``datasetName`` is not a plain file name.
    """
    logging.info("正在加载json数据")
    dataset_name = str(json_data['datasetName'])
    # datasetName comes from the caller and is spliced into file paths whose
    # contents are unpickled below; refuse anything that could escape the
    # model/vocabulary directories (e.g. "../x" or "a/b").
    if dataset_name in ("", ".", "..") or os.path.basename(dataset_name) != dataset_name:
        raise ValueError("invalid datasetName: %r" % dataset_name)
    sentence = json_data['sentence']

    logging.info("正在进行ltp处理")
    # text_after_ltp concatenates separate test-case fields; here the whole
    # input is in `sentence`, so pass it alone and leave the other parts empty.
    text = text_after_ltp(sentence, "", "", name="")

    logging.info("正在加载词表")
    word2index_path = word2index_path_base + dataset_name + ".npy"
    # NOTE(review): allow_pickle=True unpickles the file, which can execute
    # code -- only ever load trusted files.
    word2index = np.load(word2index_path, allow_pickle=True).item()
    text_feature = text_to_feature(text, word2index, max_len)

    logging.info("正在加载分类模型")
    classify_model_path = classify_model_path_base + dataset_name + ".pth"
    # Results are read back on the CPU (.numpy()/float()), so map any
    # GPU-saved tensors to the CPU. torch.load also unpickles (see note above).
    classify_model = torch.load(classify_model_path, map_location="cpu")
    # Put the model in inference mode so any dropout / batch-norm layers do
    # not behave as in training.
    classify_model.eval()

    logging.info("正在获取分类结果")
    with torch.no_grad():
        # NOTE(review): this embedding table is randomly initialised on every
        # call and is not loaded from the training run, so the vectors fed to
        # the classifier change between calls. Unless training used the same
        # scheme, the trained embedding weights should be loaded instead.
        # TODO confirm against the training code.
        embed = nn.Embedding(vocab_size + 2, embedding_size)
        # A batch holding just this sentence: (1, max_len, embedding_size).
        inputs = embed(torch.LongTensor([text_feature]))
        _attention, preds = classify_model(inputs)

    # preds is presumably (1, num_classes) -- TODO confirm; take row 0.
    scores, categories = torch.max(preds, 1)
    result = {
        "categoryId": int(categories[0]),
        "accuracy": float(scores[0]),
    }
    logging.info(result)
    return result
def text_after_ltp(test_process, test_requirement="", product_version_module="", name=None):
    """Preprocess test-case text into a list of tokens.

    The parts are concatenated in the order ``name``, ``test_requirement``,
    ``test_process``, ``product_version_module`` (a ``None`` part counts as
    empty) and newlines are replaced by spaces. The text is then run through
    ``DataProcessor``: segmented (``segmentor``), stripped of the stop words
    from ``stop_word_path`` (``clean_word_list``), and synonym-normalised with
    the dictionary from ``synonym_word_path`` (``synonym_replace_sentence``).
    Finally, tokens that start with digits followed by a dot (presumably step
    numbers such as ``"1."``) are dropped. No part-of-speech filtering is done
    here.

    ``test_requirement`` and ``product_version_module`` default to empty so
    the function can also be called with a single piece of text.

    Returns:
        list of the remaining tokens.
    """
    parts = (name, test_requirement, test_process, product_version_module)
    raw_text = "".join("" if part is None else part for part in parts)
    raw_text = raw_text.replace("\n", " ")

    processor = DataProcessor(cws_model_path, pos_model_path)
    stop_word = processor.stop_word_list(stop_word_path)
    synonym_dict = processor.synonym_word_dict(synonym_word_path)

    words = processor.segmentor(raw_text)
    words = processor.clean_word_list(words, stop_word)
    words = processor.synonym_replace_sentence(words, synonym_dict)
    # re.match anchors at the start of the token, like the original "^" pattern.
    return [word for word in words if not re.match(r'\d+\.', word)]
def text_to_feature(text, word2index, max_len):
    """Map tokens to vocabulary indices and fit the result to ``max_len``.

    Each token is lower-cased and looked up in ``word2index``; tokens not in
    the vocabulary map to ``word2index["<unk>"]``. At most ``max_len`` tokens
    are used, and the result is right-padded with ``word2index["<pad>"]``.

    Args:
        text: iterable of string tokens.
        word2index: mapping from token to index; must contain ``"<pad>"``,
            and ``"<unk>"`` whenever an out-of-vocabulary token occurs.
        max_len: target length of the returned list.

    Returns:
        list of exactly ``max(max_len, 0)`` indices.
    """
    feature = []
    for word in text:
        # Check before appending: checking afterwards let the first token
        # through when max_len <= 0 and then never truncated at all.
        if len(feature) >= max_len:
            break
        word = word.lower()
        feature.append(word2index[word] if word in word2index else word2index["<unk>"])
    # Right-pad up to max_len (the multiplier is <= 0, i.e. no padding, once full).
    feature += [word2index["<pad>"]] * (max_len - len(feature))
    return feature
if __name__ == "__main__":
    # Manual smoke test: classify one sample sentence with the 'test'
    # dataset's vocabulary and model (both files must already exist).
    json_data = {
        'datasetName': 'test',
        # get_testcase_label requires a 'sentence'; placeholder demo input.
        'sentence': '测试用例示例',
    }
    print(get_testcase_label(json_data))
|