# apriori.py
# author: Justin Cui
# date: 2019/10/23
# email: 321923502@qq.com
from numpy import *
  5. def load_data():
  6. dataSet = [['bread', 'milk', 'vegetable', 'fruit', 'eggs'],
  7. ['noodle', 'beef', 'pork', 'water', 'socks', 'gloves', 'shoes', 'rice'],
  8. ['socks', 'gloves'],
  9. ['bread', 'milk', 'shoes', 'socks', 'eggs'],
  10. ['socks', 'shoes', 'sweater', 'cap', 'milk', 'vegetable', 'gloves'],
  11. ['eggs', 'bread', 'milk', 'fish', 'crab', 'shrimp', 'rice']]
  12. return dataSet
  13. # 扫描全部数据,产生c1
  14. def create_c1(data):
  15. c1 = []
  16. for transaction in data:
  17. for item in transaction:
  18. if [item] not in c1:
  19. c1.append([item])
  20. c1.sort()
  21. return list(map(frozenset, c1))
  22. # 由c(i)生成对应的l(i)
  23. def c2l(data, ck, min_support):
  24. dict_sup = {}
  25. for i in data:
  26. for j in ck:
  27. if j.issubset(i):
  28. if j not in dict_sup:
  29. dict_sup[j] = 1
  30. else:
  31. dict_sup[j] += 1
  32. support_data = {}
  33. result_list = []
  34. for i in dict_sup:
  35. temp_sup = dict_sup[i] / len(data)
  36. if temp_sup >= min_support:
  37. result_list.append(i)
  38. support_data[i] = temp_sup
  39. return result_list, support_data
  40. # 由l(k-1)生成c(k)
  41. def get_next_c(Lk, k):
  42. result_list = []
  43. len_lk = len(Lk)
  44. for i in range(len_lk):
  45. for j in range(i + 1, len_lk):
  46. l1 = list(Lk[i])[:k]
  47. l2 = list(Lk[j])[:k]
  48. if l1 == l2:
  49. a = Lk[i] | Lk[j]
  50. a1 = list(a)
  51. b = []
  52. for q in range(len(a1)):
  53. t = [a1[q]]
  54. tt = frozenset(set(a1) - set(t))
  55. b.append(tt)
  56. t = 0
  57. for w in b:
  58. if w in Lk:
  59. t += 1
  60. if t == len(b):
  61. result_list.append(b[0] | b[1])
  62. return result_list
  63. # 得到所有的l集
  64. def get_all_l(data_set, min_support):
  65. c1 = create_c1(data_set)
  66. data = list(map(set, data_set))
  67. L1, support_data = c2l(data, c1, min_support)
  68. L = [L1]
  69. k = 2
  70. while (len(L[k - 2]) > 0):
  71. Ck = get_next_c(L[k - 2], k - 2)
  72. Lk, sup = c2l(data, Ck, min_support)
  73. support_data.update(sup)
  74. L.append(Lk)
  75. k += 1
  76. del L[-1]
  77. return L, support_data
  78. # 得到所有L集的子集
  79. def get_subset(from_list, result_list):
  80. for i in range(len(from_list)):
  81. t = [from_list[i]]
  82. tt = frozenset(set(from_list) - set(t))
  83. if tt not in result_list:
  84. result_list.append(tt)
  85. tt = list(tt)
  86. if len(tt) > 1:
  87. get_subset(tt, result_list)
  88. # 计算置信度
  89. def calc_conf(freqSet, H, supportData, min_conf):
  90. for conseq in H:
  91. conf = supportData[freqSet] / supportData[freqSet - conseq]
  92. lift = supportData[freqSet] / (supportData[conseq] * supportData[freqSet - conseq])
  93. if conf >= min_conf and lift > 1:
  94. print(set(freqSet - conseq), '-->', set(conseq), '支持度', round(supportData[freqSet - conseq], 2), '置信度:',
  95. conf)
  96. # 生成规则
  97. def gen_rule(L, support_data, min_conf=0.7):
  98. for i in range(len(L)):
  99. print("\n", i + 1, "-频繁项集为:")
  100. for freqSet in L[i]:
  101. print(set(freqSet), end=" ")
  102. print("\n")
  103. for i in range(1, len(L)):
  104. for freqSet in L[i]:
  105. H1 = list(freqSet)
  106. all_subset = []
  107. get_subset(H1, all_subset)
  108. calc_conf(freqSet, all_subset, support_data, min_conf)
  109. if __name__ == '__main__':
  110. dataSet = load_data()
  111. L, supportData = get_all_l(dataSet, 0.5)
  112. gen_rule(L, supportData, 0.6)