# -*- coding: utf-8 -*-
"""
Created on Thu May 21 19:19:01 2020
读取数据并对数据做预处理
统计出训练数据中出现频次最多的5k个单词,用这出现最多的5k个单词创建词表(词向量)
对于测试数据,直接用训练数据构建的词表
@author: 
"""
import os
import copy
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data  # newly added
import sklearn
from sklearn import model_selection
import numpy as np
import pymysql
import classifyer
from nlpcda import Simbert
import logging

fileName = './model_train.log'
handler = [logging.FileHandler(filename=fileName, encoding="utf-8")]
logging.basicConfig(level=logging.DEBUG, handlers=handler)

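# Configuration for nlpcda's Simbert augmenter: path to the pretrained Chinese SimBERT
# weights, the GPUs to make visible, the maximum sentence length and the random seed.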
simbert_config = {
    'model_path': './chinese_simbert_L-12_H-768_A-12',
    'CUDA_VISIBLE_DEVICES': '0,1',
    'max_len': 64,
    'seed': 1
}

bug_type = ["不正常退出", "功能不完整", "用户体验", "页面布局缺陷", "性能", "安全"]
num_classes = len(bug_type)
word2index_path_base = "../word2index/"
torch.manual_seed(123)

datas = []
labels = []


def read_file(file_path):
    # Read the whole file as a UTF-8 string; the file handle is closed automatically.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()


def process_file(root_path):
    # List all files and sub-directories under root_path
    dir_or_files = os.listdir(root_path)
    for dir_file in dir_or_files:
        # Build the full path of the directory or file
        dir_file_path = os.path.join(root_path, dir_file)
        # Check whether the path is a directory or a file
        if os.path.isdir(dir_file_path):
            # Recurse to collect all files in sub-directories
            process_file(dir_file_path)
        else:
            if "after" in dir_file_path:
                description = read_file(dir_file_path)
                datas.append(description)
                # The bug type name is taken from a fixed position in the path split on "/"
                # (depends on the directory layout)
                label_china = dir_file_path.split("/")[8]
                label = []
                for i in bug_type:
                    if i == label_china:
                        label.append(1)
                    else:
                        label.append(0)
                labels.append(label)
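                # e.g. with the bug_type list above, a file located under a
                # ".../性能/..." directory gets the one-hot label [0, 0, 0, 0, 1, 0]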


class DataProcessor(object):
    def __init__(self, dataset_name=None, host=None, user=None, password=None):
        self.dataset_name = dataset_name
        self.datas_path = "./word_list_data/" + str(self.dataset_name) + ".npy"
        self.labels_path = "./word_list_data/" + str(self.dataset_name) + "_label.npy"
        self.datas_increase_path = "./word_list_data/" + str(self.dataset_name) + "_increase.npy"
        self.labels_increase_path = "./word_list_data/" + str(self.dataset_name) + "_label_increase.npy"
        self.host = host
        self.user = user
        self.password = password
        if user is None or password is None:
            # fall back to the default local MySQL credentials
            self.host = "127.0.0.1"
            self.user = "root"
            self.password = "123456"

    def read_text_from_db(self):
        datas = []
        labels = []
        conn = pymysql.connect(host=self.host, user=self.user, password=self.password, database="mt_clerk_test",
                               charset="utf8")
        cursor = conn.cursor()
        try:
            sql = "select id from dataset where name = %s"
            cursor.execute(sql, str(self.dataset_name))
            dataset_id = cursor.fetchall()[0][0]
            sql = "select test_process,test_requirement,product_version_module,tccategory_id,name from test_case where dataset_id = %s and tccategory_id is not null"
            cursor.execute(sql, str(dataset_id))
            results = cursor.fetchall()
            for row in results:
                test_process = row[0]
                test_requirement = row[1]
                product_version_module = row[2]
                tccategory_id = int(row[3])
                name = row[4]
                text = classifyer.text_after_ltp(test_process, test_requirement, product_version_module, name)
                datas.append(text)
                label = []
                for i in range(num_classes):
                    if i == tccategory_id:
                        label.append(1)
                    else:
                        label.append(0)
                labels.append(label)
        finally:
            cursor.close()
            conn.close()
        # Cache the preprocessed samples so later runs can load them from disk
        # instead of querying the database again.
        np.save(self.datas_path, datas)
        np.save(self.labels_path, labels)
        return datas, labels

    def read_text_from_file_system(self):
        global datas, labels
        process_file("/Users/tanghaojie/Desktop/final/手动标记后的数据/决赛自主可控众测web自主可控运维管理系统")
        return datas, labels

    def increase_data(self):
        simbert = Simbert(config=simbert_config)
        datas_pre = np.load(self.datas_path, allow_pickle=True)
        labels_pre = np.load(self.labels_path, allow_pickle=True)

        datas = []
        labels = []
        num = 5
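        # Every original sample is kept and augmented with `num` SimBERT paraphrases,
        # so the augmented dataset is (num + 1) times the size of the original one.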
        if (len(datas_pre) == len(labels_pre)):
            for i in range(len(datas_pre)):
                datas.append(datas_pre[i])
                labels.append(labels_pre[i])
                synonyms = simbert.replace(sent=datas_pre[i], create_num=num)
                for j in range(num):
                    datas.append(synonyms[j][0])
                    labels.append(labels_pre[i])
        np.save(self.datas_increase_path, datas)
        np.save(self.labels_increase_path, labels)
        return datas, labels

    def word_count(self, datas):
        # Count how often each word occurs and sort in descending order of frequency.
        dic = {}
        for data in datas:
            for word in data:
                word = word.lower()  # lower-case every word (no effect on Chinese characters)  # todo
                dic[word] = dic.get(word, 0) + 1
        word_count_sorted = sorted(dic.items(), key=lambda item: item[1], reverse=True)
        return word_count_sorted  # list of (word, count) pairs, most frequent first

    def word_index(self, datas, vocab_size):
        # Build the vocabulary (word -> index table).
        word_count_sorted = self.word_count(datas)
        word2index = {}
        # <unk>: stands for words missing from the vocabulary (the vocabulary size is
        # limited, so some words in the sentences will not be in it)
        word2index["<unk>"] = 0
        # <pad>: padding token appended to sentences shorter than the maximum length
        word2index["<pad>"] = 1

        # The actual vocabulary size is bounded by the number of distinct words.
        vocab_size = min(len(word_count_sorted), vocab_size)
        for i in range(vocab_size):
            word = word_count_sorted[i][0]
            word2index[word] = i + 2  # key: the word, value: its index in the word2index table
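        # Hypothetical resulting table (the actual words depend on the training data):
        # {"<unk>": 0, "<pad>": 1, "登录": 2, "页面": 3, ...}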

        word2index_path = word2index_path_base + self.dataset_name + ".npy"
        np.save(word2index_path, word2index)
        return word2index, vocab_size

    def get_datasets_origin(self, vocab_size, max_len):
        # Note: nn.Embedding is randomly initialised on every run, so the embeddings of the
        # training, development and test data must all be produced in the same run.
        # The test data reuses the vocabulary built from the training data.

        # logging.info('Reading the raw data from the database')
        # txt_origin, label_origin = self.read_text_from_db()
        # txt_origin = np.load(self.datas_path, allow_pickle=True).tolist()
        # label_origin = np.load(self.labels_path, allow_pickle=True).tolist()
        logging.info('Loading the augmented data')
        # txt_origin, label_origin = self.increase_data()
        txt_origin = np.load(self.datas_increase_path, allow_pickle=True).tolist()
        label_origin = np.load(self.labels_increase_path, allow_pickle=True).tolist()

        # Count how many samples fall into each class (position of the 1 in the one-hot label).
        label_count = {i: 0 for i in range(num_classes)}
        for one_hot in label_origin:
            class_index = 0
            for j in range(len(one_hot)):
                if one_hot[j] == 1:
                    class_index = j
            label_count[class_index] = label_count[class_index] + 1
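        # label_count maps class index -> number of samples,
        # e.g. (hypothetical counts) {0: 120, 1: 80, 2: 95, 3: 60, 4: 40, 5: 25}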

        logging.info('Label distribution of the raw data: %s', label_count)
        train_datas, test_datas, train_labels, test_labels = sklearn.model_selection.train_test_split(txt_origin,
                                                                                                      label_origin,
                                                                                                      random_state=2,
                                                                                                      train_size=0.2,
                                                                                                      test_size=0.8)
        test_datas, develop_datas, test_labels, develop_labels = sklearn.model_selection.train_test_split(test_datas,
                                                                                                          test_labels,
                                                                                                          random_state=2,
                                                                                                          train_size=0.25,
                                                                                                          test_size=0.75)
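        # Resulting split: 20% train, 20% test (0.25 of the remaining 80%) and 60% develop.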

        logging.info('Building the vocabulary')
        word_datas = copy.deepcopy(train_datas)
        word_datas.extend(develop_datas)
        word_datas.extend(test_datas)
        word2index, vocab_size = self.word_index(word_datas, vocab_size)  # word2index table and the actual vocabulary size

        logging.info('Converting sentences to word-index features')

        def to_features(sentences):
            # Map every word to its vocabulary index; words missing from the vocabulary
            # map to <unk>, sentences longer than max_len are truncated and shorter ones
            # are padded with <pad>.
            features = []
            for sentence in sentences:
                feature = []
                for word in sentence:
                    word = word.lower()  # the vocabulary only contains lower-cased words
                    feature.append(word2index.get(word, word2index["<unk>"]))
                    if len(feature) == max_len:
                        break
                feature = feature + [word2index["<pad>"]] * (max_len - len(feature))
                features.append(feature)
            return features

        train_features = to_features(train_datas)
        develop_features = to_features(develop_datas)
        test_features = to_features(test_datas)
        return train_features, develop_features, test_features, train_labels, develop_labels, test_labels, word2index

    def get_datasets(self, train_features, develop_features, test_features, train_labels, develop_labels, test_labels,
                     vocab_size, embedding_size):
        # Convert the word indices to tensors; every row of train_features must have the
        # same length, otherwise this raises an error.
        train_features = torch.LongTensor(train_features)
        train_labels = torch.FloatTensor(train_labels)

        develop_features = torch.LongTensor(develop_features)
        develop_labels = torch.FloatTensor(develop_labels)

        test_features = torch.LongTensor(test_features)
        test_labels = torch.FloatTensor(test_labels)

        # Convert the word indices to embeddings.
        # The vocabulary contains two special tokens, <unk> and <pad>, so the embedding
        # table needs vocab_size + 2 rows.
        embed = nn.Embedding(vocab_size + 2, embedding_size)  # https://www.jianshu.com/p/63e7acc5e890
        train_features = embed(train_features)
        develop_features = embed(develop_features)
        test_features = embed(test_features)

        # The input features should not carry gradients, so detach them from the
        # embedding's autograd graph (replaces the deprecated
        # torch.autograd.Variable(..., requires_grad=False) wrapper).
        train_features = train_features.detach()  # https://www.cnblogs.com/henuliulei/p/11363121.html
        train_datasets = torch.utils.data.TensorDataset(train_features, train_labels)

        develop_features = develop_features.detach()
        develop_datasets = torch.utils.data.TensorDataset(develop_features, develop_labels)

        test_features = test_features.detach()
        test_datasets = torch.utils.data.TensorDataset(test_features,
                                                       test_labels)  # https://www.cnblogs.com/hahaah/p/14914603.html
        return train_datasets, develop_datasets, test_datasets
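
    # Minimal usage sketch (comments only; "demo", 5000, 64, 128 and 16 are hypothetical
    # values, and the augmented .npy files are assumed to exist already):
    #   processor = DataProcessor(dataset_name="demo")
    #   (train_f, dev_f, test_f,
    #    train_l, dev_l, test_l, word2index) = processor.get_datasets_origin(vocab_size=5000, max_len=64)
    #   train_ds, dev_ds, test_ds = processor.get_datasets(
    #       train_f, dev_f, test_f, train_l, dev_l, test_l,
    #       vocab_size=len(word2index) - 2, embedding_size=128)
    #   train_loader = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True)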