9-3-SVM-recognise-XSS.py

from sklearn.svm import SVC
from datasets import Datasets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, cross_val_score
import re

# Selected features: URL length, whether the URL embeds a third-party domain,
# number of suspicious characters, number of suspicious keywords

# URL length
def url_len(url):
    return len(url)

# Whether the URL contains a third-party domain
def url_has_domain(url):
    return 1 if re.search('(http://)|(https://)', url, re.IGNORECASE) else 0

# Number of suspicious characters
def evil_str_count(url):
    return len(re.findall("[<>,\'\"/]", url, re.IGNORECASE))

# Number of suspicious keywords
def evil_keywords_count(url):
    blacklist = "(alert)|(script=)|(%3c)|(%3e)|(%20)|(onerror)|(onload)|(eval)|(src=)|(prompt)"
    return len(re.findall(blacklist, url, re.IGNORECASE))

# Feature extraction: build the feature vector for one URL
def get_feature(url):
    return [url_len(url), url_has_domain(url), evil_str_count(url), evil_keywords_count(url)]

def main():
    data, y = Datasets.load_xss()
    x = []
    for url in data:
        x.append(get_feature(url))
    # Standardize the features
    std = StandardScaler()
    x = std.fit_transform(x)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3)
    # Train an SVM classifier with a linear kernel
    clf = SVC(kernel='linear')
    clf.fit(x_train, y_train)
    print(clf.score(x_test, y_test))
    # Cross-validation (10 folds, which is fairly slow)
    scores = cross_val_score(clf, x, y, cv=10, scoring='accuracy')
    print(scores.mean())

if __name__ == "__main__":
    main()
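
As a quick sanity check of the feature extraction, the helpers above can be exercised on a single URL before training. The sample payload below is a made-up illustration, not taken from the script's dataset:

# Illustrative check of get_feature (the URL is a hypothetical example)
sample = "http://example.com/search?q=<script>alert(1)</script>"
print(get_feature(sample))
# -> [53, 1, 8, 1]
#    (URL length, third-party domain flag, suspicious chars, suspicious keywords)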