#!/usr/bin/python
# coding=utf-8
"""Classify a test-case sentence with a per-dataset PyTorch text classifier.

Pipeline: LTP-based segmentation/cleaning (``character_processor.DataProcessor``)
-> vocabulary indices (``word2index/<datasetName>.npy``) -> embedding ->
classifier loaded from ``classify_model/<datasetName>.pth``.
"""
import itertools
import logging
import os
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import xlrd
from torch.autograd import Variable

from character_processor import DataProcessor

fileName = './flask.log'
handler = [logging.FileHandler(filename=fileName, encoding="utf-8")]
logging.basicConfig(level=logging.DEBUG, handlers=handler)

parent_path = os.path.dirname(os.path.realpath(__file__))
cws_model_path = parent_path + "/ltp_data_v3.4.0/cws.model"
pos_model_path = parent_path + "/ltp_data_v3.4.0/pos.model"
stop_word_path = parent_path + "/ltp_data_v3.4.0/stop_word.txt"
synonym_word_path = parent_path + "/ltp_data_v3.4.0/HIT-IRLab-同义词词林.txt"
grander_path = os.path.dirname(parent_path)
classify_model_path_base = grander_path + "/classify_model/"
word2index_path_base = grander_path + "/word2index/"

max_len = 64
vocab_size = 5000  # vocabulary size
embedding_size = 64  # word-vector (embedding) dimension
batch_size = 16

# Tokens that begin with digits followed by a dot (e.g. step numbers "1.").
_LEADING_NUMBER_RE = re.compile(r'^\d+\.')


def _safe_dataset_name(dataset_name):
    """Return *dataset_name* as ``str``, rejecting values that could escape the data folders.

    The name is spliced into file paths that are later unpickled
    (``np.load(allow_pickle=True)`` / ``torch.load``); it presumably comes from
    the request payload, so path separators and ``.``/``..`` are refused.

    Raises:
        ValueError: if the name is empty, ``.``/``..`` or contains a path
            separator or NUL byte.
    """
    name = str(dataset_name)
    if name in ("", ".", "..") or any(ch in name for ch in ("/", "\\", "\0")):
        raise ValueError("invalid datasetName: %r" % (name,))
    return name


def get_testcase_label(json_data):
    """Predict the category of ``json_data['sentence']`` for ``json_data['datasetName']``.

    Args:
        json_data: mapping with ``'datasetName'`` (selects the vocabulary
            ``word2index/<name>.npy`` and the model ``classify_model/<name>.pth``)
            and ``'sentence'`` (the text to classify).

    Returns:
        dict with ``"categoryId"`` (int, index of the highest model score) and
        ``"accuracy"`` (float, that highest score itself — a probability only
        if the model's second output is softmax-normalised; TODO confirm).

    Raises:
        KeyError: if ``'datasetName'`` or ``'sentence'`` is missing.
        ValueError: if ``datasetName`` is not a plain file-name component.
    """
    logging.info("正在加载json数据")
    dataset_name = _safe_dataset_name(json_data['datasetName'])
    sentence = json_data['sentence']

    logging.info("正在进行ltp处理")
    # The whole sentence is processed as a single text fragment.
    text = text_after_ltp(sentence)

    logging.info("正在加载词表")
    word2index_path = word2index_path_base + dataset_name + ".npy"
    # The .npy file stores a pickled dict, hence allow_pickle and .item().
    word2index = np.load(word2index_path, allow_pickle=True).item()
    text_feature = text_to_feature(text, word2index, max_len)

    logging.info("正在加载分类模型")
    classify_model_path = classify_model_path_base + dataset_name + ".pth"
    # The .pth file holds a whole pickled model object, so its class must be
    # importable in this process (presumably bilstm_attention.BiLSTMModel).
    # map_location="cpu": results are read back with CPU-only conversions and
    # the input tensor below is created on the CPU.
    classify_model = torch.load(classify_model_path, map_location="cpu")
    # Inference mode: disables dropout / training-time batch-norm behaviour.
    classify_model.eval()

    # NOTE(review): this embedding is freshly and randomly initialised on every
    # call (it is not loaded from disk), so identical sentences can receive
    # different features. Confirm how the model was trained and load the
    # matching embedding weights instead.
    embed = nn.Embedding(vocab_size + 2, embedding_size)

    logging.info("正在获取分类结果")
    with torch.no_grad():
        # Batch of one sample: (1, max_len) indices -> (1, max_len, embedding_size).
        data = embed(torch.LongTensor([text_feature]))
        # The model returns a pair; only the second element (scores) is used.
        _, preds = classify_model(data)
        scores, category_ids = torch.max(preds, 1)

    result = {
        "categoryId": int(category_ids[0]),
        "accuracy": float(scores[0])
    }
    logging.info(result)
    return result


def text_after_ltp(test_process, test_requirement="", product_version_module="", name=None):
    """Turn raw test-case text into a cleaned list of keyword tokens.

    The fragments are concatenated as ``name + test_requirement + test_process
    + product_version_module`` (``None`` fragments are skipped) and newlines
    are replaced by spaces. ``DataProcessor`` then segments the text, filters
    it against the stop-word list and replaces synonyms using the HIT-IRLab
    thesaurus (exact behaviour is defined in ``character_processor``). Finally,
    tokens starting with digits followed by a dot (e.g. ``"1."``) are dropped.

    Args:
        test_process: main text fragment.
        test_requirement: optional extra fragment (default ``""``).
        product_version_module: optional extra fragment (default ``""``).
        name: optional leading fragment (default ``None``, treated as empty).

    Returns:
        list of token strings.
    """
    fragments = (name, test_requirement, test_process, product_version_module)
    raw_text = "".join(part for part in fragments if part is not None)
    raw_text = raw_text.replace("\n", " ")

    processor = DataProcessor(cws_model_path, pos_model_path)
    stop_word = processor.stop_word_list(stop_word_path)
    synonym_dict = processor.synonym_word_dict(synonym_word_path)

    words = processor.segmentor(raw_text)
    words = processor.clean_word_list(words, stop_word)
    words = processor.synonym_replace_sentence(words, synonym_dict)
    return [word for word in words if not _LEADING_NUMBER_RE.match(word)]


def text_to_feature(text, word2index, max_len):
    """Map tokens to vocabulary indices, truncated / padded to ``max_len``.

    Each token is lower-cased before lookup. Tokens missing from *word2index*
    map to ``word2index[""]``, which is also the padding value.

    Args:
        text: iterable of token strings.
        word2index: mapping from token to integer index; must contain ``""``.
        max_len: target length; values <= 0 yield an empty list.

    Returns:
        list of exactly ``max(max_len, 0)`` indices.

    Raises:
        KeyError: if ``""`` is not in *word2index*.
    """
    # NOTE(review): unknown words and padding share the "" key here, while the
    # embedding reserves vocab_size + 2 rows (suggesting two special tokens).
    # Confirm against how word2index was built.
    fallback = word2index[""]
    limit = max(max_len, 0)
    feature = [word2index.get(word.lower(), fallback)
               for word in itertools.islice(text, limit)]
    feature.extend([fallback] * (limit - len(feature)))
    return feature


if __name__ == "__main__":
    json_data = {
        'datasetName': 'test',
        # Placeholder input for a manual smoke run; replace with real text.
        'sentence': '测试用例描述',
    }
    print(get_testcase_label(json_data))