16-5-RNN-recognise-WebShell-LSTM.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from datasets import Datasets
  2. import tensorflow as tf
  3. from sklearn.model_selection import train_test_split
  4. import matplotlib.pyplot as plt
  def main():
      """Train a bidirectional LSTM classifier on ADFA-LD traces and plot accuracy.

      Pipeline: load the normal and attack samples, split them 70/30, turn each
      sample into a fixed-length integer sequence, train a
      Embedding -> BiLSTM -> Dropout -> Dense(softmax) model, then plot the
      training and validation accuracy per epoch.

      Takes no arguments and returns nothing; it opens a matplotlib window.
      """
      max_len = 300      # every sequence is padded/truncated to this length
      num_classes = 2    # assumes labels are integer class ids 0/1 -- TODO confirm

      # Load the ADFA-LD data: normal traces plus the attack traces whose path
      # matches the Web_Shell pattern below.
      x1, y1 = Datasets.load_adfa_normal()
      x2, y2 = Datasets.load_adfa_attack(r"Web_Shell_\d+/UAD-W*")
      x = x1 + x2
      y = y1 + y2
      x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3)

      # Preprocessing: map tokens to integer indices. The vocabulary is built
      # from the training split only so the held-out split does not leak into
      # preprocessing; tokens unseen during fitting map to the OOV index instead
      # of being silently dropped.
      tokenizer = tf.keras.preprocessing.text.Tokenizer(oov_token="<OOV>")
      tokenizer.fit_on_texts(x_train)
      x_train = tokenizer.texts_to_sequences(x_train)
      x_test = tokenizer.texts_to_sequences(x_test)
      num_words = len(tokenizer.word_index)

      # Pad/truncate sequences to a fixed length and one-hot encode the labels.
      x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_len)
      x_test = tf.keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_len)
      y_train = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)
      y_test = tf.keras.utils.to_categorical(y_test, num_classes=num_classes)

      # Sequential model (layers listed inline instead of repeated add() calls).
      model = tf.keras.Sequential([
          tf.keras.layers.Embedding(
              # Vocabulary size + 1: tokenizer indices start at 1 and index 0
              # is used for padding, so the table needs one extra row.
              input_dim=num_words + 1,
              output_dim=128,
              input_length=max_len,  # fixed input sequence length
          ),
          tf.keras.layers.Bidirectional(
              tf.keras.layers.LSTM(128, implementation=2)
          ),  # bidirectional LSTM
          tf.keras.layers.Dropout(0.5),  # drop 50% of units to limit overfitting
          tf.keras.layers.Dense(num_classes, activation="softmax"),
      ])

      # Compile the model.
      model.compile(
          optimizer="adam",
          loss="categorical_crossentropy",
          metrics=["acc"],  # track accuracy under the history key "acc"
      )

      # Train; the held-out split doubles as the validation set.
      history = model.fit(
          x_train, y_train,
          batch_size=32,  # samples per gradient step
          epochs=10,
          validation_data=(x_test, y_test),
      )
      # Result noted by the original author for one earlier run:
      # loss: 0.0901 - acc: 0.9744 - val_loss: 0.0897 - val_acc: 0.9755

      # Plot training vs. validation accuracy to spot overfitting. Index the
      # history directly so a missing metric fails with a clear KeyError.
      plt.plot(history.epoch, history.history["acc"], label="train acc")
      plt.plot(history.epoch, history.history["val_acc"], label="val acc")
      plt.xlabel("epoch")
      plt.ylabel("accuracy")
      plt.legend()
      plt.show()
  # Run the training pipeline only when this file is executed as a script,
  # not when it is imported as a module.
  if __name__ == "__main__":
      main()