123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- class Apriori:
- """ Apriori算法 """
- def apriori(self, data, min_support=0.5):
- """apriori算法实现
- :param data: 数据集
- :param min_support: 最小支持度
- :return: 频繁项集,频繁项集的支持度
- """
-
- itemset_1 = self.create_itemset(data)
-
- data = list(map(set, data))
-
- frequently_itemset_1, support_data = self.scan_data(data, itemset_1, min_support)
-
- frequently_itemset = [frequently_itemset_1]
- k = 2
-
- while len(frequently_itemset[k - 2]) > 0:
-
- itemset_k = self.create_new_itemset(frequently_itemset[k - 2], k)
- frequently_itemset_k, support_k = self.scan_data(data, itemset_k, min_support)
- support_data.update(support_k)
- frequently_itemset.append(frequently_itemset_k)
- k += 1
- return frequently_itemset, support_data
- def create_itemset(self, data):
- """创建元素为1的项集
- :param data: 原始数据集
- :return:创建元素为1的项集
- """
- itemset = []
-
- for items in data:
- for item in items:
- if [item] not in itemset:
- itemset.append([item])
- itemset.sort()
-
- return list(map(frozenset, itemset))
- def scan_data(self, data, k, min_support):
- """找出候选集中的频繁项集
- :param data: 全部数据集
- :param k: 为大小为包含k个元素的候选项集
- :param min_support: 设定的最小支持度
- :return: frequently_itemset为在k中找出的频繁项集(支持度大于min_support的),support_data记录各频繁项集的支持度
- """
- scan_itemset = {}
- for i in data:
- for j in k:
- if j.issubset(i):
- scan_itemset[j] = scan_itemset.get(j, 0) + 1
- items_num = float(len(list(data)))
- frequently_itemset = []
- support_data = {}
- for key in scan_itemset:
- support = scan_itemset[key] / items_num
- if support >= min_support:
- frequently_itemset.insert(0, key)
- support_data[key] = support
- return frequently_itemset, support_data
- def create_new_itemset(self, frequently_itemset, k):
- """通过频繁项集列表frequently_itemset和项集个数k生成候选项集
- :param frequently_itemset: 频繁项集列表
- :param k: 项集个数
- :return: 候选项集
- """
- new_frequently_itemset = []
- frequently_itemset_len = len(frequently_itemset)
- for i in range(frequently_itemset_len):
- for j in range(i + 1, frequently_itemset_len):
-
- l1 = list(frequently_itemset[i])[: k - 2]
- l2 = list(frequently_itemset[j])[: k - 2]
- l1.sort()
- l2.sort()
- if l1 == l2:
- new_frequently_itemset.append(frequently_itemset[i] | frequently_itemset[j])
- return new_frequently_itemset
- def calc_reliability(self, freq_set, h, support_data, brl, min_reliability=0.7):
- """对候选规则集进行评估
- :param freq_set: 频繁项集
- :param h: 元素列表
- :param support_data: 项集的支持度
- :param brl: 生成的关联规则
- :param min_reliability: 最小置信度
- :return: 规则列表的右部, candidate_rule_set()中用到
- """
- pruned = []
- for conseq in h:
- conf = support_data[freq_set] / support_data[freq_set - conseq]
- if conf >= min_reliability:
- brl.append((freq_set - conseq, conseq, conf))
- pruned.append(conseq)
- return pruned
- def candidate_rule_set(self, freq_set, h, support_data, brl, min_reliability=0.7):
- """生成候选规则集
- :param freq_set: 频繁项集
- :param h: 元素列表
- :param support_data: 项集的支持度
- :param brl: 生成的关联规则
- :param min_reliability: 最小置信度
- :return: 下一层候选规则集
- """
- m = len(h[0])
- if len(freq_set) > m + 1:
- hmp1 = self.create_new_itemset(h, m + 1)
- hmp1 = self.calc_reliability(freq_set, hmp1, support_data, brl, min_reliability)
- if len(hmp1) > 1:
- self.candidate_rule_set(freq_set, hmp1, support_data, brl, min_reliability)
- def generate_rules(self, frequently_itemset, support_data, min_reliability=0.7):
- """关联规则生成
- :param frequently_itemset: 频繁项集
- :param support_data: 频繁项集的支持度
- :param min_reliability: 最小置信度
- :return: 包含可信度的规则列表
- """
- big_rule_list = []
- for i in range(1, len(frequently_itemset)):
- for freq_set in frequently_itemset[i]:
- h1 = [frozenset([item]) for item in freq_set]
- if i > 1:
-
- self.candidate_rule_set(freq_set, h1, support_data, big_rule_list, min_reliability)
- else:
-
- self.calc_reliability(freq_set, h1, support_data, big_rule_list, min_reliability)
- return big_rule_list
|