# 16-7-RNN-recognise-abnormal-operation-LSTM.py

  1. from datasets import Datasets
  2. import tensorflow as tf
  3. from sklearn.model_selection import train_test_split
  4. import matplotlib.pyplot as plt
  5. # 特征提取,使用词集将操作命令向量化,根据操作统计命令词集来判断
  6. def get_feature(cmd, fdist):
  7. feature = []
  8. for block in cmd:
  9. v = [0] * len(fdist)
  10. for i in range(0, len(fdist)):
  11. if fdist[i] in block:
  12. v[i] += 1
  13. feature.append(v)
  14. return feature
  15. def main():
  16. # 导入数据
  17. data, y, fdist = Datasets.load_Schonlau('User3')
  18. x = get_feature(data, fdist)
  19. x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3)
  20. num_words = len(x)
  21. # 序列编码one-hot
  22. x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, maxlen=100)
  23. x_test = tf.keras.preprocessing.sequence.pad_sequences(x_test, maxlen=100)
  24. y_train = tf.keras.utils.to_categorical(y_train, num_classes=2)
  25. y_test = tf.keras.utils.to_categorical(y_test, num_classes=2)
  26. # 顺序模型(层直接写在里面,省写add)
  27. model = tf.keras.Sequential([
  28. tf.keras.layers.Embedding(
  29. input_dim=num_words + 1, # 字典长度 加1 不然会报错
  30. output_dim=128,
  31. input_length=100, # 当输入序列的长度固定时,该值为其长度
  32. ),
  33. tf.keras.layers.LSTM(64),
  34. tf.keras.layers.Dense(2, activation="softmax"),
  35. ])
  36. # 编译模型
  37. model.compile(
  38. optimizer="adam", # 优化器
  39. loss="categorical_crossentropy", # 损失函数
  40. metrics=["acc"], # 观察值, acc正确率
  41. )
  42. # 训练
  43. history = model.fit(
  44. x_train, y_train,
  45. batch_size=32, # 一次放入多少样本
  46. epochs=10,
  47. validation_data=(x_test, y_test),
  48. )
  49. # loss: 0.2557 - acc: 0.9143 - val_loss: 0.1812 - val_acc: 0.9556
  50. # 画图 正确率(是否过拟合)
  51. plt.plot(history.epoch, history.history.get("acc"))
  52. plt.plot(history.epoch, history.history.get("val_acc"))
  53. plt.show()
  54. if __name__ == "__main__":
  55. main()