recursive.py

import numpy as np
from model import IdentityActivator


class TreeNode(object):
    """Tree node; the whole tree built by the recursive neural network is stored in it."""

    def __init__(self, data, children=None, children_data=None):
        """Initialize a tree node.
        :param data: vector stored at this node
        :param children: child nodes
        :param children_data: concatenated data of the child nodes
        """
        if children_data is None:
            children_data = []
        if children is None:
            children = []
        self.parent = None
        self.children = children
        self.children_data = children_data
        self.data = data
        # Link every child back to this node.
        for child in children:
            child.parent = self

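
# A sketch of the maths the layer below implements, written exactly as the code
# computes it in forward() and calc_delta(); "f" is the activator and "concat"
# stacks the child vectors c_1 ... c_k into one (node_width * k, 1) column:
#
#     parent         = f(W . concat(c_1, ..., c_k) + b)
#     children_delta = (W^T . parent_delta) * activator.backward(concat(c_1, ..., c_k))
#
# children_delta is then cut into per-child slices of length node_width.
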
class RecursiveLayer(object):
    """Recursive (tree-structured) neural network."""

    def __init__(self, node_width, child_count, activator, learning_rate):
        """Recursive network constructor.
        :param node_width: dimension of the vector held by each node
        :param child_count: number of children per parent node
        :param activator: activation function
        :param learning_rate: learning rate
        """
        self.node_width = node_width
        self.child_count = child_count
        self.activator = activator
        self.learning_rate = learning_rate
        # weight matrix W
        self.W = np.random.uniform(-1e-4, 1e-4,
                                   (node_width, node_width * child_count))
        # bias vector b
        self.b = np.zeros((node_width, 1))
        # root of the tree built by the recursive network
        self.root = None
    def forward(self, *children):
        """Forward pass.
        Takes the given tree nodes as children, computes their parent node and
        stores the result in self.root.
        :param children: a sequence of TreeNode objects
        """
        children_data = self.concatenate(children)
        parent_data = self.activator.forward(
            np.dot(self.W, children_data) + self.b)
        self.root = TreeNode(parent_data, children, children_data)
    def concatenate(self, tree_nodes):
        """Concatenate the data of the given tree nodes into one long column vector.
        :param tree_nodes: tree nodes whose data is concatenated
        :return: column vector of shape (node_width * child_count, 1)
        """
        concat = np.zeros((0, 1))
        for node in tree_nodes:
            concat = np.concatenate((concat, node.data))
        return concat
    def backward(self, parent_delta):
        """BPTS (back-propagation through structure).
        :param parent_delta: error term of the root node
        """
        # error term of every node in the tree
        self.calc_delta(parent_delta, self.root)
        # gradients of W and b, summed over the whole tree
        self.W_grad, self.b_grad = self.calc_gradient(self.root)
    def calc_delta(self, parent_delta, parent):
        """Compute the delta (error term) of every node.
        :param parent_delta: error term of the parent node
        :param parent: parent node
        """
        parent.delta = parent_delta
        if parent.children:
            # compute each child's delta according to Eq. 2
            children_delta = np.dot(self.W.T, parent_delta) * (
                self.activator.backward(parent.children_data))
            # slices = [(child index, delta start offset, delta end offset)]
            slices = [(i, i * self.node_width, (i + 1) * self.node_width)
                      for i in range(self.child_count)]
            # recurse into each child with its slice of the delta vector
            for s in slices:
                self.calc_delta(children_delta[s[1]:s[2]], parent.children[s[0]])
    def calc_gradient(self, parent):
        """Compute the weight gradient at every node and sum them up to get the
        final gradient.
        :param parent: parent node
        :return: (W_grad, b_grad)
        """
        # initialize the accumulated gradients
        W_grad = np.zeros((self.node_width, self.node_width * self.child_count))
        b_grad = np.zeros((self.node_width, 1))
        if not parent.children:
            return W_grad, b_grad
        parent.W_grad = np.dot(parent.delta, parent.children_data.T)
        parent.b_grad = parent.delta
        W_grad += parent.W_grad
        b_grad += parent.b_grad
        # sum up the gradients of all nodes
        for child in parent.children:
            W, b = self.calc_gradient(child)
            W_grad += W
            b_grad += b
        return W_grad, b_grad
    def update(self):
        """Update the weights with one SGD step."""
        self.W -= self.learning_rate * self.W_grad
        self.b -= self.learning_rate * self.b_grad

    def reset_state(self):
        """Reset the root node."""
        self.root = None

    def dump(self, dump_grad=False):
        """Print the state of the recursive network."""
        print('root.data: %s' % self.root.data)
        print('root.children_data: %s' % self.root.children_data)
        if dump_grad:
            print('W_grad: %s' % self.W_grad)
            print('b_grad: %s' % self.b_grad)
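

if __name__ == '__main__':
    # A minimal usage sketch: the input vectors and the error vector below are
    # made-up illustration values, and it assumes the imported
    # model.IdentityActivator exposes forward(x) and backward(output) with the
    # conventions used by forward()/calc_delta() above.
    rnn = RecursiveLayer(node_width=2, child_count=2,
                         activator=IdentityActivator(), learning_rate=1e-3)

    # Leaf nodes, each holding a (node_width, 1) column vector.
    x = TreeNode(np.array([[1.0], [2.0]]))
    y = TreeNode(np.array([[3.0], [4.0]]))
    z = TreeNode(np.array([[5.0], [6.0]]))

    # Combine two leaves into a parent, then combine that parent with the third leaf.
    rnn.forward(x, y)
    rnn.forward(rnn.root, z)
    rnn.dump()

    # Back-propagate an illustrative error vector through the tree (BPTS),
    # take one SGD step and print the gradients.
    sensitivity = np.array([[0.1], [0.2]])
    rnn.backward(sensitivity)
    rnn.update()
    rnn.dump(dump_grad=True)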