model.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. class Apriori:
  2. """ Apriori算法 """
  3. def apriori(self, data, min_support=0.5):
  4. """apriori算法实现
  5. :param data: 数据集
  6. :param min_support: 最小支持度
  7. :return: 频繁项集,频繁项集的支持度
  8. """
  9. # 获取非频繁项集
  10. itemset_1 = self.create_itemset(data)
  11. # 转化事务集的形式,每个元素都转化为集合。
  12. data = list(map(set, data))
  13. # 获取频繁1项集和对应的支持度
  14. frequently_itemset_1, support_data = self.scan_data(data, itemset_1, min_support)
  15. # frequently_itemset用来存储所有的频繁项集
  16. frequently_itemset = [frequently_itemset_1]
  17. k = 2
  18. # 一直迭代到项集数目过大而在事务集中不存在这种n项集
  19. while len(frequently_itemset[k - 2]) > 0:
  20. # 根据频繁项集生成新的候选项集
  21. itemset_k = self.create_new_itemset(frequently_itemset[k - 2], k)
  22. frequently_itemset_k, support_k = self.scan_data(data, itemset_k, min_support)
  23. support_data.update(support_k)
  24. frequently_itemset.append(frequently_itemset_k)
  25. k += 1
  26. return frequently_itemset, support_data
  27. def create_itemset(self, data):
  28. """创建元素为1的项集
  29. :param data: 原始数据集
  30. :return:创建元素为1的项集
  31. """
  32. itemset = []
  33. # 元素个数为1的项集(非频繁项集,因为还没有同最小支持度比较)
  34. for items in data:
  35. for item in items:
  36. if [item] not in itemset:
  37. itemset.append([item])
  38. itemset.sort() # 这里排序是为了,生成新的候选集时可以直接认为两个n项候选集前面的部分相同
  39. # 因为除了候选1项集外其他的候选n项集都是以二维列表的形式存在,所以要将候选1项集的每一个元素都转化为一个单独的集合
  40. return list(map(frozenset, itemset)) # list(map(frozenset, itemset))的语义是将C1由列表转换为不变集合
  41. def scan_data(self, data, k, min_support):
  42. """找出候选集中的频繁项集
  43. :param data: 全部数据集
  44. :param k: 为大小为包含k个元素的候选项集
  45. :param min_support: 设定的最小支持度
  46. :return: frequently_itemset为在k中找出的频繁项集(支持度大于min_support的),support_data记录各频繁项集的支持度
  47. """
  48. scan_itemset = {}
  49. for i in data:
  50. for j in k:
  51. if j.issubset(i):
  52. scan_itemset[j] = scan_itemset.get(j, 0) + 1 # 计算每一个项集出现的频率
  53. items_num = float(len(list(data)))
  54. frequently_itemset = []
  55. support_data = {}
  56. for key in scan_itemset:
  57. support = scan_itemset[key] / items_num
  58. if support >= min_support:
  59. frequently_itemset.insert(0, key) # 将频繁项集插入返回列表的首部
  60. support_data[key] = support
  61. return frequently_itemset, support_data
  62. def create_new_itemset(self, frequently_itemset, k):
  63. """通过频繁项集列表frequently_itemset和项集个数k生成候选项集
  64. :param frequently_itemset: 频繁项集列表
  65. :param k: 项集个数
  66. :return: 候选项集
  67. """
  68. new_frequently_itemset = []
  69. frequently_itemset_len = len(frequently_itemset)
  70. for i in range(frequently_itemset_len):
  71. for j in range(i + 1, frequently_itemset_len):
  72. # 前k-1项相同时,才将两个集合合并,合并后才能生成k+1项
  73. l1 = list(frequently_itemset[i])[: k - 2]
  74. l2 = list(frequently_itemset[j])[: k - 2] # 取出两个集合的前k-1个元素
  75. l1.sort()
  76. l2.sort()
  77. if l1 == l2:
  78. new_frequently_itemset.append(frequently_itemset[i] | frequently_itemset[j])
  79. return new_frequently_itemset
  80. def calc_reliability(self, freq_set, h, support_data, brl, min_reliability=0.7):
  81. """对候选规则集进行评估
  82. :param freq_set: 频繁项集
  83. :param h: 元素列表
  84. :param support_data: 项集的支持度
  85. :param brl: 生成的关联规则
  86. :param min_reliability: 最小置信度
  87. :return: 规则列表的右部, candidate_rule_set()中用到
  88. """
  89. pruned = []
  90. for conseq in h:
  91. conf = support_data[freq_set] / support_data[freq_set - conseq]
  92. if conf >= min_reliability:
  93. brl.append((freq_set - conseq, conseq, conf))
  94. pruned.append(conseq)
  95. return pruned
  96. def candidate_rule_set(self, freq_set, h, support_data, brl, min_reliability=0.7):
  97. """生成候选规则集
  98. :param freq_set: 频繁项集
  99. :param h: 元素列表
  100. :param support_data: 项集的支持度
  101. :param brl: 生成的关联规则
  102. :param min_reliability: 最小置信度
  103. :return: 下一层候选规则集
  104. """
  105. m = len(h[0])
  106. if len(freq_set) > m + 1:
  107. hmp1 = self.create_new_itemset(h, m + 1)
  108. hmp1 = self.calc_reliability(freq_set, hmp1, support_data, brl, min_reliability)
  109. if len(hmp1) > 1:
  110. self.candidate_rule_set(freq_set, hmp1, support_data, brl, min_reliability)
  111. def generate_rules(self, frequently_itemset, support_data, min_reliability=0.7):
  112. """关联规则生成
  113. :param frequently_itemset: 频繁项集
  114. :param support_data: 频繁项集的支持度
  115. :param min_reliability: 最小置信度
  116. :return: 包含可信度的规则列表
  117. """
  118. big_rule_list = []
  119. for i in range(1, len(frequently_itemset)):
  120. for freq_set in frequently_itemset[i]:
  121. h1 = [frozenset([item]) for item in freq_set]
  122. if i > 1:
  123. # 三个及以上元素的集合
  124. self.candidate_rule_set(freq_set, h1, support_data, big_rule_list, min_reliability)
  125. else:
  126. # 两个元素的集合
  127. self.calc_reliability(freq_set, h1, support_data, big_rule_list, min_reliability)
  128. return big_rule_list