#!/usr/bin/python
# coding=utf-8
import logging
import os

import numpy as np
import sklearn
import torch
from sklearn import metrics  # explicit import: `import sklearn` alone does not guarantee sklearn.metrics is loadable as an attribute
from sklearn import svm
from sklearn.model_selection import train_test_split  # train/test split helper
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier

# Route all DEBUG-and-up logging to a UTF-8 log file next to the script.
fileName = './constract.log'  # NOTE(review): looks like a typo for "contrast.log" — kept unchanged for compatibility
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(module)s: %(message)s',
                              datefmt='%m/%d/%Y %H:%M:%S')
handler = logging.FileHandler(filename=fileName, encoding="utf-8")
handler.setFormatter(formatter)
logging.basicConfig(level=logging.DEBUG, handlers=[handler])

# Directory layout: word-list data sits beside this file; the vocabulary
# (word -> index) lives one directory level up, under ./word2index/.
parent_path = os.path.dirname(os.path.realpath(__file__))
grander_path = os.path.dirname(parent_path)
word_list_data_path_base = parent_path + "/word_list_data/"
word2index_path_base = grander_path + "/word2index/"
data_path = './word_list_data/'

# Experiment configuration.
dataset_name = "决赛自主可控众测web自主可控运维管理系统"
max_len = 64          # fixed length every text is truncated/padded to
vocab_size = 5000
embedding_size = 64
batch_size = 16
random_state = 15     # number of distinct train/test splits averaged over
  30. def contrast():
  31. logging.info("正在加载初始数据")
  32. direc = "./splited_data/"
  33. # txts = np.load(word_list_data_path_base + str(dataset_name) + ".npy", allow_pickle=True)
  34. # labels = np.load(word_list_data_path_base + str(dataset_name) + "_label.npy", allow_pickle=True)
  35. txts = np.load(direc + str(dataset_name) + "_train.npy", allow_pickle=True).tolist()
  36. labels = np.load(direc + str(dataset_name) + "_label_train.npy", allow_pickle=True).tolist()
  37. labels_new = []
  38. for label in labels:
  39. label_new = 0
  40. for i in range(len(label)):
  41. label_new += i * label[i]
  42. labels_new.append(label_new)
  43. labels_new = np.array(labels_new)
  44. logging.info("正在加载词表")
  45. word2index_path = word2index_path_base + str(dataset_name) + ".npy"
  46. word2index = np.load(word2index_path, allow_pickle=True).item()
  47. features = []
  48. for txt in txts:
  49. text_feature = text_to_feature(txt, word2index, max_len)
  50. features.append(text_feature)
  51. # np.save(, features)
  52. score_knn_lowest = 100
  53. score_svm_lowest = 100
  54. score_nb_lowest = 100
  55. score_bpnn_lowest = 100
  56. score_knn_all = 0
  57. recall_knn_all = 0
  58. f1_knn_all = 0
  59. pre_knn_all = 0
  60. score_svm_all = 0
  61. recall_svm_all = 0
  62. f1_svm_all = 0
  63. pre_svm_all = 0
  64. score_nb_all = 0
  65. recall_nb_all = 0
  66. f1_nb_all = 0
  67. pre_nb_all = 0
  68. score_bpnn_all = 0
  69. recall_bpnn_all = 0
  70. f1_bpnn_all = 0
  71. pre_bpnn_all = 0
  72. for i in range(random_state):
  73. train_data, test_data, train_label, test_label = sklearn.model_selection.train_test_split(features, labels_new,
  74. random_state=i,
  75. train_size=0.6,
  76. test_size=0.2)
  77. logging.info("正在训练k最近邻分类器")
  78. knn_classifier = KNeighborsClassifier()
  79. knn_classifier.fit(train_data, train_label)
  80. knn_predict = knn_classifier.predict(test_data)
  81. recall_knn = sklearn.metrics.recall_score(test_label, knn_predict, average="macro")
  82. f1_knn = sklearn.metrics.f1_score(test_label, knn_predict, average="macro")
  83. score_knn = knn_classifier.score(test_data, test_label)
  84. pre_knn = sklearn.metrics.precision_score(test_label, knn_predict, average="macro")
  85. if score_knn < score_knn_lowest:
  86. score_knn_lowest = score_knn
  87. score_knn_all = score_knn_all + score_knn
  88. recall_knn_all += recall_knn
  89. f1_knn_all += f1_knn
  90. pre_knn_all += pre_knn
  91. logging.info("k最近邻分类器Acc为{}".format(score_knn))
  92. logging.info("k最近邻分类器召回率为{}".format(recall_knn))
  93. logging.info("k最近邻分类器f1_score为{}".format(f1_knn))
  94. logging.info("正在训练SVM分类器")
  95. svm_classifier = svm.SVC(C=2, kernel='rbf', gamma=10, decision_function_shape='ovr')
  96. svm_classifier.fit(train_data, train_label)
  97. svm_predict = svm_classifier.predict(test_data)
  98. recall_svm = sklearn.metrics.recall_score(test_label, svm_predict, average="macro")
  99. f1_svm = sklearn.metrics.f1_score(test_label, svm_predict, average="macro")
  100. score_svm = svm_classifier.score(test_data, test_label)
  101. pre_svm = sklearn.metrics.precision_score(test_label, svm_predict, average="macro")
  102. if score_svm < score_svm_lowest:
  103. score_svm_lowest = score_svm
  104. score_svm_all = score_svm_all + score_svm
  105. recall_svm_all += recall_svm
  106. f1_svm_all += f1_svm
  107. pre_svm_all += pre_svm
  108. logging.info("SVM分类器Acc为{}".format(score_svm))
  109. logging.info("SVM分类器召回率为{}".format(recall_svm))
  110. logging.info("SVM分类器f1_score为{}".format(f1_svm))
  111. #
  112. logging.info("正在训练朴素贝叶斯分类器")
  113. muNB_classifier = GaussianNB()
  114. muNB_classifier.fit(train_data, train_label)
  115. muNB_predict = muNB_classifier.predict(test_data)
  116. recall_nb = sklearn.metrics.recall_score(test_label, muNB_predict, average="macro")
  117. f1_nb = sklearn.metrics.f1_score(test_label, muNB_predict, average="macro")
  118. score_nb = muNB_classifier.score(test_data, test_label)
  119. pre_nb = sklearn.metrics.precision_score(test_label, muNB_predict, average="macro")
  120. if score_nb < score_nb_lowest:
  121. score_nb_lowest = score_nb
  122. score_nb_all = score_nb_all + score_nb
  123. recall_nb_all += recall_nb
  124. f1_nb_all += f1_nb
  125. pre_nb_all += pre_nb
  126. logging.info("朴素贝叶斯分类器Acc为{}".format(score_nb))
  127. logging.info("朴素贝叶斯分类器召回率为{}".format(recall_nb))
  128. logging.info("朴素贝叶斯分类器f1_score为{}".format(f1_nb))
  129. logging.info("正在训练bpnn分类器")
  130. bpnn_classifier = MLPClassifier(solver='lbfgs', random_state=0, hidden_layer_sizes=[10, 10])
  131. bpnn_classifier.fit(train_data, train_label)
  132. bpnn_predict = bpnn_classifier.predict(test_data)
  133. recall_bpnn = sklearn.metrics.recall_score(test_label, bpnn_predict, average="macro")
  134. f1_bpnn = sklearn.metrics.f1_score(test_label, bpnn_predict, average="macro")
  135. score_bpnn = bpnn_classifier.score(test_data, test_label)
  136. pre_bpnn = sklearn.metrics.precision_score(test_label, bpnn_predict, average="macro")
  137. if score_bpnn < score_bpnn_lowest:
  138. score_bpnn_lowest = score_bpnn
  139. score_bpnn_all = score_bpnn_all + score_bpnn
  140. recall_bpnn_all += recall_bpnn
  141. f1_bpnn_all += f1_bpnn
  142. pre_bpnn_all += pre_bpnn
  143. logging.info("bpnn分类器Acc为{}".format(score_bpnn))
  144. logging.info("bpnn分类器召回率为{}".format(recall_bpnn))
  145. logging.info("bpnn分类器f1_score为{}".format(f1_bpnn))
  146. logging.info("数据集 " + dataset_name + " 结果:")
  147. logging.info("k最近邻分类器最低准确率为{}".format(score_knn_lowest))
  148. logging.info("SVM分类器最低准确率为{}".format(score_svm_lowest))
  149. logging.info("朴素贝叶斯分类器最低准确率为{}".format(score_nb_lowest))
  150. logging.info("k最近邻分类器平均Acc为{}".format(score_knn_all / random_state))
  151. logging.info("SVM分类器平均Acc为{}".format(score_svm_all / random_state))
  152. logging.info("朴素贝叶斯分类器平均Acc为{}".format(score_nb_all / random_state))
  153. logging.info("k最近邻分类器平均召回率为{}".format(recall_knn_all / random_state))
  154. logging.info("SVM分类器平均召回率为{}".format(recall_svm_all / random_state))
  155. logging.info("朴素贝叶斯分类器平均召回率为{}".format(recall_nb_all / random_state))
  156. logging.info("k最近邻分类器平均f1_score为{}".format(f1_knn_all / random_state))
  157. logging.info("SVM分类器平均f1_score为{}".format(f1_svm_all / random_state))
  158. logging.info("朴素贝叶斯分类器平均f1_score为{}".format(f1_nb_all / random_state))
  159. logging.info("k最近邻分类器平均precision为{}".format(pre_knn_all / random_state))
  160. logging.info("SVM分类器平均precision为{}".format(pre_svm_all / random_state))
  161. logging.info("朴素贝叶斯分类器平均precision为{}".format(pre_nb_all / random_state))
  162. logging.info("bpnn分类器平均Acc为{}".format(score_bpnn_all / random_state))
  163. logging.info("bpnn分类器平均召回率为{}".format(recall_bpnn_all / random_state))
  164. logging.info("bpnn分类器平均f1_score为{}".format(f1_bpnn_all / random_state))
  165. logging.info("bpnn分类器平均precision为{}".format(pre_bpnn_all / random_state))
  166. def text_to_feature(text, word2index, max_len):
  167. feature = []
  168. for word in text:
  169. if word in word2index:
  170. feature.append(word2index[word])
  171. else:
  172. feature.append(word2index["<unk>"])
  173. if len(feature) == max_len:
  174. break
  175. feature = feature + [word2index["<pad>"]] * (max_len - len(feature))
  176. return feature
  177. def calculate_bi_standards(name):
  178. model = torch.load(name)
  179. pass
  180. if __name__ == "__main__":
  181. contrast()