perceptron_and_linear.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from model.perceptron import Perceptron
  2. from model.linearunit import LinearUnit
  3. from exts import *
  4. import matplotlib.pyplot as plt
  5. def perceptron_test():
  6. """感知器"""
  7. # 实现and函数,基于and真值表构建训练数据
  8. x = [[0, 0], [1, 0], [0, 1], [1, 1]]
  9. y = [0, 0, 0, 1]
  10. # 创建感知器,输入参数为2(因为and是二元函数),激活函数为relu
  11. p = Perceptron(2, relu)
  12. # 训练5次,学习速率为0.1
  13. p.fix(x, y, 5, 0.1)
  14. print(p)
  15. # 验证模型
  16. print("0 and 0 = %d" % p.predict([0, 0]))
  17. print("0 and 1 = %d" % p.predict([0, 1]))
  18. print("1 and 0 = %d" % p.predict([1, 0]))
  19. print("1 and 1 = %d" % p.predict([1, 1]))
  20. def linear_test():
  21. """线性单元"""
  22. # 生成5个人的收入数据
  23. x = [[5], [3], [8], [1.4], [10.1], [2]]
  24. y = [5500, 2300, 7600, 1800, 11400, 2000]
  25. # 创建感知器,线性单元
  26. lu = LinearUnit(1)
  27. lu.fix(x, y, 10, 0.01)
  28. print(lu)
  29. print(lu.predict([3.4]))
  30. # 画图
  31. fig = plt.figure()
  32. ax = fig.add_subplot(111)
  33. ax.scatter(list(map(lambda x1: x1[0], x)), y)
  34. weights = lu.w
  35. bias = lu.b
  36. x = range(0, 12, 1)
  37. y = list(map(lambda x: weights[0] * x + bias, x))
  38. ax.plot(x, y)
  39. plt.show()
  40. def main():
  41. # 感知器
  42. perceptron_test()
  43. # 线性单元
  44. linear_test()
  45. if __name__ == '__main__':
  46. main()