contrast_experiment.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. import os
  4. import numpy as np
  5. import logging
  6. import sklearn
  7. from sklearn.model_selection import train_test_split #导入切分训练集、测试集模块
  8. from sklearn.neighbors import KNeighborsClassifier
  9. from sklearn import svm
  10. from sklearn.naive_bayes import GaussianNB
  11. fileName = './constract.log'
  12. handler = [logging.FileHandler(filename=fileName,encoding="utf-8")]
  13. logging.basicConfig(level = logging.DEBUG, handlers = handler)
  14. parent_path = os.path.dirname(os.path.realpath(__file__))
  15. grander_path = os.path.dirname(parent_path)
  16. word_list_data_path_base = parent_path + "/word_list_data/"
  17. word2index_path_base = grander_path + "/word2index/"
  18. dataset_name = "航天中认自主可控众包测试练习赛"
  19. max_len = 64
  20. vocab_size = 5000
  21. embedding_size = 64
  22. batch_size = 16
  23. random_state = 15
  24. def contrast():
  25. logging.info("正在加载初始数据")
  26. txts = np.load(word_list_data_path_base + str(dataset_name) + ".npy", allow_pickle=True)
  27. labels = np.load(word_list_data_path_base + str(dataset_name) + "_label.npy", allow_pickle=True)
  28. labels_new = []
  29. for label in labels:
  30. label_new = 0
  31. for i in range(len(label)):
  32. label_new += i * label[i]
  33. labels_new.append(label_new)
  34. labels_new = np.array(labels_new)
  35. logging.info("正在加载词表")
  36. word2index_path = word2index_path_base + str(dataset_name) + ".npy"
  37. word2index = np.load(word2index_path, allow_pickle=True).item()
  38. features = []
  39. for txt in txts:
  40. text_feature = text_to_feature(txt, word2index, max_len)
  41. features.append(text_feature)
  42. #np.save(, features)
  43. score_knn_lowest = 100
  44. score_svm_lowest = 100
  45. score_nb_lowest = 100
  46. score_knn_all = 0
  47. score_svm_all = 0
  48. score_nb_all = 0
  49. for i in range(random_state):
  50. train_data, test_data, train_label, test_label = sklearn.model_selection.train_test_split(features, labels_new, random_state = i, train_size = 0.2,test_size = 0.8)
  51. logging.info("正在训练k最近邻分类器")
  52. knn_classifier = KNeighborsClassifier()
  53. knn_classifier.fit(train_data, train_label)
  54. score_knn = knn_classifier.score(test_data, test_label)
  55. if score_knn < score_knn_lowest:
  56. score_knn_lowest = score_knn
  57. score_knn_all = score_knn_all + score_knn
  58. logging.info("k最近邻分类器准确率为{}".format(score_knn))
  59. logging.info("正在训练SVM分类器")
  60. svm_classifier = svm.SVC(C=2,kernel='rbf',gamma=10,decision_function_shape='ovr')
  61. svm_classifier.fit(train_data, train_label)
  62. score_svm = svm_classifier.score(test_data, test_label)
  63. if score_svm < score_svm_lowest:
  64. score_svm_lowest = score_svm
  65. score_svm_all = score_svm_all + score_svm
  66. logging.info("SVM分类器准确率为{}".format(score_svm))
  67. logging.info("正在训练朴素贝叶斯分类器")
  68. muNB_classifier = GaussianNB()
  69. muNB_classifier.fit(train_data, train_label)
  70. score_nb = muNB_classifier.score(test_data, test_label)
  71. if score_nb < score_nb_lowest:
  72. score_nb_lowest = score_nb
  73. score_nb_all = score_nb_all + score_nb
  74. logging.info("朴素贝叶斯分类器准确率为{}".format(score_nb))
  75. logging.info("k最近邻分类器最低准确率为{}".format(score_knn_lowest))
  76. logging.info("SVM分类器最低准确率为{}".format(score_svm_lowest))
  77. logging.info("朴素贝叶斯分类器最低准确率为{}".format(score_nb_lowest))
  78. logging.info("k最近邻分类器平均准确率为{}".format(score_knn_all / random_state))
  79. logging.info("SVM分类器平均准确率为{}".format(score_svm_all / random_state))
  80. logging.info("朴素贝叶斯分类器平均准确率为{}".format(score_nb_all / random_state))
  81. def text_to_feature(text, word2index, max_len):
  82. feature = []
  83. for word in text:
  84. if word in word2index:
  85. feature.append(word2index[word])
  86. else:
  87. feature.append(word2index["<unk>"])
  88. if(len(feature) == max_len):
  89. break
  90. feature = feature + [word2index["<pad>"]] * (max_len - len(feature))
  91. return feature
  92. if __name__ == "__main__":
  93. contrast()