6-5-RandomForest-detect-FTP-attack.py 1.1 KB

1234567891011121314151617181920212223242526272829303132
  1. from sklearn.ensemble import RandomForestClassifier
  2. from sklearn.model_selection import cross_val_score
  3. from sklearn.feature_extraction.text import CountVectorizer
  4. import pydotplus
  5. from datasets import Datasets
  def main(attack_pattern=r"Hydra_FTP_\d+/UAD-Hydra-FTP*", n_estimators=10, folds=10):
      """Cross-validate a random forest separating ADFA-LD normal and attack traces.

      Loads the normal traces plus the attack traces selected by
      ``attack_pattern``, turns each trace into a bag-of-words count vector,
      and prints the mean ``folds``-fold cross-validation score of a
      ``RandomForestClassifier``.

      Args:
          attack_pattern: Pattern handed to ``Datasets.load_adfa_attack`` to
              choose the attack traces (default: the Hydra FTP traces).
          n_estimators: Number of trees in the forest.
          folds: Number of cross-validation folds passed as ``cv``.

      Returns:
          The per-fold score array returned by ``cross_val_score``.
      """
      # Load ADFA-LD data: normal traces and the selected attack traces.
      x_normal, y_normal = Datasets.load_adfa_normal()
      x_attack, y_attack = Datasets.load_adfa_attack(attack_pattern)
      # Wrap in list() so `+` always concatenates; on numpy arrays it would
      # add element-wise instead. Identical result when the loaders return lists.
      x = list(x_normal) + list(x_attack)
      y = list(y_normal) + list(y_attack)

      # Bag-of-words features. CountVectorizer's default token_pattern
      # r"(?u)\b\w\w+\b" ignores single-character tokens. Presumably each trace
      # is a whitespace-separated string of syscall numbers (TODO confirm
      # against Datasets). If so, every single-digit syscall ID would be
      # silently dropped, so accept tokens of one or more word characters.
      vectorizer = CountVectorizer(token_pattern=r"(?u)\b\w+\b")
      features = vectorizer.fit_transform(x).toarray()

      # Random forest, evaluated with cross-validation:
      #   n_estimators      - number of trees in the forest
      #   max_depth         - maximum tree depth; None expands nodes until all
      #                       leaves are pure or hold fewer than
      #                       min_samples_split samples
      #   min_samples_split - minimum samples required to split an internal node
      #   random_state      - seed for the random number generator
      forest = RandomForestClassifier(
          n_estimators=n_estimators,
          max_depth=None,
          min_samples_split=2,
          random_state=0,
      )
      scores = cross_val_score(forest, features, y, n_jobs=-1, cv=folds)
      # The original script recorded 0.9829090909090908 here using the default
      # tokenizer; the value is expected to change with the token_pattern fix.
      print(scores.mean())
      return scores
  if __name__ == "__main__":
      # Run the experiment only when executed as a script, not on import.
      main()