17-3-CNN-recognise-malicious-comments.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. from datasets import Datasets
  4. from sklearn.model_selection import train_test_split
  5. def main():
  6. # 导入数据
  7. x, y = Datasets.load_movie_review()
  8. x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3)
  9. # 数据预处理,词袋
  10. tokenizer = tf.keras.preprocessing.text.Tokenizer()
  11. tokenizer.fit_on_texts(x)
  12. x_train = tokenizer.texts_to_sequences(x_train)
  13. x_test = tokenizer.texts_to_sequences(x_test)
  14. num_words = len(tokenizer.word_index)
  15. # 序列编码one-hot
  16. x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, maxlen=200)
  17. x_test = tf.keras.preprocessing.sequence.pad_sequences(x_test, maxlen=200)
  18. y_train = tf.keras.utils.to_categorical(y_train, num_classes=2)
  19. y_test = tf.keras.utils.to_categorical(y_test, num_classes=2)
  20. # 顺序模型(层直接写在里面,省写add)
  21. model = tf.keras.Sequential([
  22. tf.keras.layers.Embedding(
  23. input_dim=num_words + 1, # 字典长度 加1 不然会报错
  24. output_dim=300, # 全连接嵌入的维度,常用256或300
  25. input_length=200, # 当输入序列的长度固定时,该值为其长度
  26. trainable=True, # 代表词向量作为参数进行更新
  27. ),
  28. # 卷积层
  29. tf.keras.layers.Conv1D(
  30. filters=64, # 64个卷积核
  31. kernel_size=3, # 大小3
  32. padding='valid', # 卷积模式
  33. activation="relu", # 激活函数
  34. ),
  35. tf.keras.layers.MaxPool1D(pool_size=2), # 池化
  36. tf.keras.layers.Dropout(.25), # 丢弃25% 防止过拟合
  37. tf.keras.layers.Flatten(),
  38. tf.keras.layers.Dense(2, activation="softmax"),
  39. ])
  40. # 编译模型
  41. model.compile(
  42. optimizer="adam", # 优化器
  43. loss="categorical_crossentropy", # 损失函数
  44. metrics=["acc"], # 观察值, acc正确率
  45. )
  46. # 训练
  47. history = model.fit(
  48. x_train, y_train,
  49. batch_size=32, # 一次放入多少样本
  50. epochs=10,
  51. validation_data=(x_test, y_test),
  52. )
  53. # loss: 0.0013 - acc: 1.0000 - val_loss: 0.5421 - val_acc: 0.7350
  54. # 画图 正确率(是否过拟合)
  55. plt.plot(history.epoch, history.history.get("acc"))
  56. plt.plot(history.epoch, history.history.get("val_acc"))
  57. plt.show()
  58. if __name__ == "__main__":
  59. main()