lstm.py

import numpy as np
from functools import reduce
from model import element_wise_op, ReluActivator, IdentityActivator, SigmoidActivator, TanhActivator


class LstmLayer(object):
  5. """ 实现LSTM层 """
  6. def __init__(self, input_width, state_width, learning_rate):
  7. """设置超参数,初始化LSTM层
  8. :param input_width: 输入维度
  9. :param state_width: 保存状态的向量维度
  10. :param learning_rate: 学习速率
  11. """
  12. self.input_width = input_width
  13. self.state_width = state_width
  14. self.learning_rate = learning_rate
  15. # 门的激活函数
  16. self.gate_activator = SigmoidActivator()
  17. # 输出的激活函数
  18. self.output_activator = TanhActivator()
  19. # 当前时刻初始化为t0
  20. self.times = 0
  21. # 各个时刻的单元状态向量c
  22. self.c_list = self.init_state_vec()
  23. # 各个时刻的输出向量h
  24. self.h_list = self.init_state_vec()
  25. # 各个时刻的遗忘门f
  26. self.f_list = self.init_state_vec()
  27. # 各个时刻的输入门i
  28. self.i_list = self.init_state_vec()
  29. # 各个时刻的输出门o
  30. self.o_list = self.init_state_vec()
  31. # 各个时刻的即时状态c
  32. self.ct_list = self.init_state_vec()
  33. # 遗忘门权重矩阵Wfh, Wfx, 偏置项bf
  34. self.Wfh, self.Wfx, self.bf = (self.init_weight_mat())
  35. # 输入门权重矩阵Wfh, Wfx, 偏置项bf
  36. self.Wih, self.Wix, self.bi = (self.init_weight_mat())
  37. # 输出门权重矩阵Wfh, Wfx, 偏置项bf
  38. self.Woh, self.Wox, self.bo = (self.init_weight_mat())
  39. # 单元状态权重矩阵Wfh, Wfx, 偏置项bf
  40. self.Wch, self.Wcx, self.bc = (self.init_weight_mat())

    def init_state_vec(self):
        """Initialize the list that stores one state vector per time step.
        :return: list seeded with the zero vector for t0
        """
        state_vec_list = [np.zeros((self.state_width, 1))]
        return state_vec_list

    def init_weight_mat(self):
        """Initialize one set of weight matrices.
        :return: recurrent weights Wh, input weights Wx and bias b
        """
        Wh = np.random.uniform(-1e-4, 1e-4, (self.state_width, self.state_width))
        Wx = np.random.uniform(-1e-4, 1e-4, (self.state_width, self.input_width))
        b = np.zeros((self.state_width, 1))
        return Wh, Wx, b

    def forward(self, x):
        """Forward pass for one time step.
        :param x: input vector at the current time step
        :return:
        """
        self.times += 1
        # forget gate: f_t = sigmoid(Wfh.h_{t-1} + Wfx.x_t + bf)
        fg = self.calc_gate(x, self.Wfx, self.Wfh, self.bf, self.gate_activator)
        self.f_list.append(fg)
        # input gate: i_t = sigmoid(Wih.h_{t-1} + Wix.x_t + bi)
        ig = self.calc_gate(x, self.Wix, self.Wih, self.bi, self.gate_activator)
        self.i_list.append(ig)
        # output gate: o_t = sigmoid(Woh.h_{t-1} + Wox.x_t + bo)
        og = self.calc_gate(x, self.Wox, self.Woh, self.bo, self.gate_activator)
        self.o_list.append(og)
        # candidate state: c~_t = tanh(Wch.h_{t-1} + Wcx.x_t + bc)
        ct = self.calc_gate(x, self.Wcx, self.Wch, self.bc, self.output_activator)
        self.ct_list.append(ct)
        # cell state: c_t = f_t * c_{t-1} + i_t * c~_t
        c = fg * self.c_list[self.times - 1] + ig * ct
        self.c_list.append(c)
        # output: h_t = o_t * tanh(c_t)
        h = og * self.output_activator.forward(c)
        self.h_list.append(h)
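
    # Shape conventions, for reference: x is (input_width, 1); every gate,
    # cell-state and output vector is (state_width, 1); each Wh is
    # (state_width, state_width) and each Wx is (state_width, input_width).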

    def calc_gate(self, x, Wx, Wh, b, activator):
        """Compute one gate activation.
        :param x: input vector
        :param Wx: input weight matrix Wx
        :param Wh: recurrent weight matrix Wh
        :param b: bias vector
        :param activator: activation function
        :return: gate vector
        """
        h = self.h_list[self.times - 1]  # LSTM output of the previous step
        net = np.dot(Wh, h) + np.dot(Wx, x) + b
        gate = activator.forward(net)
        return gate

    def backward(self, x, delta_h, activator):
        """LSTM training step: backpropagate errors, then compute gradients.
        :param x: input vector
        :param delta_h: error term passed down from the layer above
        :param activator: activation function
        :return:
        """
        self.calc_delta(delta_h, activator)
        self.calc_gradient(x)

    def calc_delta(self, delta_h, activator):
        """Compute the error terms.
        :param delta_h: error term passed down from the layer above
        :param activator: activation function
        :return:
        """
        # initialize the error terms for each time step
        self.delta_h_list = self.init_delta()   # output error terms
        self.delta_o_list = self.init_delta()   # output gate error terms
        self.delta_i_list = self.init_delta()   # input gate error terms
        self.delta_f_list = self.init_delta()   # forget gate error terms
        self.delta_ct_list = self.init_delta()  # candidate state error terms
        # store the error term for the current time step, received from above
        self.delta_h_list[-1] = delta_h
        # iterate backwards over the time steps
        for k in range(self.times, 0, -1):
            self.calc_delta_k(k)

    def init_delta(self):
        """Initialize a list of error terms, one per time step."""
        delta_list = []
        for i in range(self.times + 1):
            delta_list.append(np.zeros((self.state_width, 1)))
        return delta_list

    def calc_delta_k(self, k):
        """From delta_h at step k, compute delta_f, delta_i, delta_o and
        delta_ct at step k, as well as delta_h at step k-1.
        :param k: time step
        :return:
        """
        # fetch the values computed in the forward pass at step k
        ig = self.i_list[k]
        og = self.o_list[k]
        fg = self.f_list[k]
        ct = self.ct_list[k]
        c = self.c_list[k]
        c_prev = self.c_list[k - 1]
        tanh_c = self.output_activator.forward(c)
        delta_k = self.delta_h_list[k]
        # compute delta_o according to Eq. 9 of the accompanying derivation
        delta_o = (delta_k * tanh_c * self.gate_activator.backward(og))
        delta_f = (delta_k * og * (1 - tanh_c * tanh_c) * c_prev * self.gate_activator.backward(fg))
        delta_i = (delta_k * og * (1 - tanh_c * tanh_c) * ct * self.gate_activator.backward(ig))
        delta_ct = (delta_k * og * (1 - tanh_c * tanh_c) * ig * self.output_activator.backward(ct))
        delta_h_prev = (np.dot(delta_o.transpose(), self.Woh) +
                        np.dot(delta_i.transpose(), self.Wih) +
                        np.dot(delta_f.transpose(), self.Wfh) +
                        np.dot(delta_ct.transpose(), self.Wch)).transpose()
        # store all delta values
        self.delta_h_list[k - 1] = delta_h_prev
        self.delta_f_list[k] = delta_f
        self.delta_i_list[k] = delta_i
        self.delta_o_list[k] = delta_o
        self.delta_ct_list[k] = delta_ct
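
    # Note on the common factor og * (1 - tanh_c**2) above: with
    # h_t = o_t * tanh(c_t), the error reaching the cell state is
    # delta_h * o_t * (1 - tanh^2(c_t)), since d tanh(c)/dc = 1 - tanh^2(c).
    # It is then split among the forget gate (scaled by c_{t-1}), the input
    # gate (scaled by c~_t) and the candidate state (scaled by i_t), per
    # c_t = f_t * c_{t-1} + i_t * c~_t.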

    def calc_gradient(self, x):
        """Compute the weight gradients.
        :param x: input vector
        :return:
        """
        # initialize the forget gate weight and bias gradients
        self.Wfh_grad, self.Wfx_grad, self.bf_grad = self.init_weight_gradient_mat()
        # initialize the input gate weight and bias gradients
        self.Wih_grad, self.Wix_grad, self.bi_grad = self.init_weight_gradient_mat()
        # initialize the output gate weight and bias gradients
        self.Woh_grad, self.Wox_grad, self.bo_grad = self.init_weight_gradient_mat()
        # initialize the cell state weight and bias gradients
        self.Wch_grad, self.Wcx_grad, self.bc_grad = self.init_weight_gradient_mat()
        # gradients of the weights applied to the previous output h
        for t in range(self.times, 0, -1):
            # per-time-step gradients
            (Wfh_grad, bf_grad,
             Wih_grad, bi_grad,
             Woh_grad, bo_grad,
             Wch_grad, bc_grad) = self.calc_gradient_t(t)
            # the actual gradient is the sum over all time steps
            self.Wfh_grad += Wfh_grad
            self.bf_grad += bf_grad
            self.Wih_grad += Wih_grad
            self.bi_grad += bi_grad
            self.Woh_grad += Woh_grad
            self.bo_grad += bo_grad
            self.Wch_grad += Wch_grad
            self.bc_grad += bc_grad
        # gradients of the weights applied to the current input x
        # (only the current input is available, so use the last step's deltas)
        xt = x.transpose()
        self.Wfx_grad = np.dot(self.delta_f_list[-1], xt)
        self.Wix_grad = np.dot(self.delta_i_list[-1], xt)
        self.Wox_grad = np.dot(self.delta_o_list[-1], xt)
        self.Wcx_grad = np.dot(self.delta_ct_list[-1], xt)

    def init_weight_gradient_mat(self):
        """Initialize zero-valued gradient matrices."""
        Wh_grad = np.zeros((self.state_width, self.state_width))
        Wx_grad = np.zeros((self.state_width, self.input_width))
        b_grad = np.zeros((self.state_width, 1))
        return Wh_grad, Wx_grad, b_grad

    def calc_gradient_t(self, t):
        """Compute the weight gradients at time step t.
        :param t: time step
        :return:
        """
        h_prev = self.h_list[t - 1].transpose()
        Wfh_grad = np.dot(self.delta_f_list[t], h_prev)
        bf_grad = self.delta_f_list[t]
        Wih_grad = np.dot(self.delta_i_list[t], h_prev)
        bi_grad = self.delta_i_list[t]
        Woh_grad = np.dot(self.delta_o_list[t], h_prev)
        bo_grad = self.delta_o_list[t]
        Wch_grad = np.dot(self.delta_ct_list[t], h_prev)
        bc_grad = self.delta_ct_list[t]
        return Wfh_grad, bf_grad, Wih_grad, bi_grad, Woh_grad, bo_grad, Wch_grad, bc_grad

    def update(self):
        """Update the weights by gradient descent."""
        self.Wfh -= self.learning_rate * self.Wfh_grad
        self.Wfx -= self.learning_rate * self.Wfx_grad
        self.bf -= self.learning_rate * self.bf_grad
        self.Wih -= self.learning_rate * self.Wih_grad
        self.Wix -= self.learning_rate * self.Wix_grad
        self.bi -= self.learning_rate * self.bi_grad
        self.Woh -= self.learning_rate * self.Woh_grad
        self.Wox -= self.learning_rate * self.Wox_grad
        self.bo -= self.learning_rate * self.bo_grad
        self.Wch -= self.learning_rate * self.Wch_grad
        self.Wcx -= self.learning_rate * self.Wcx_grad
        self.bc -= self.learning_rate * self.bc_grad

    def reset_state(self):
        """Reset the internal state."""
        # current time step, initialized to t0
        self.times = 0
        # cell state vectors c for each time step
        self.c_list = self.init_state_vec()
        # output vectors h for each time step
        self.h_list = self.init_state_vec()
        # forget gate f for each time step
        self.f_list = self.init_state_vec()
        # input gate i for each time step
        self.i_list = self.init_state_vec()
        # output gate o for each time step
        self.o_list = self.init_state_vec()
        # candidate state c~ for each time step
        self.ct_list = self.init_state_vec()
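

# ---------------------------------------------------------------------------
# A minimal gradient-check sketch, added for illustration; it is not part of
# the original file. It assumes model.py provides SigmoidActivator,
# TanhActivator and IdentityActivator with the forward/backward interface
# used above; gradient_check itself is a hypothetical helper name.
# ---------------------------------------------------------------------------
def gradient_check():
    """Compare the analytic gradient of Wfh with a numerical estimate."""
    # use the sum of all output entries as the error, so dE/dh is all ones
    error_function = lambda o: o.sum()
    lstm = LstmLayer(3, 2, 1e-3)
    # forward over two time steps with random inputs
    x = [np.random.random((3, 1)), np.random.random((3, 1))]
    lstm.forward(x[0])
    lstm.forward(x[1])
    # backward with dE/dh = 1 for every output entry
    sensitivity_array = np.ones(lstm.h_list[-1].shape, dtype=np.float64)
    lstm.backward(x[1], sensitivity_array, IdentityActivator())
    # perturb each entry of Wfh and compare against a central difference
    epsilon = 1e-4
    for i in range(lstm.Wfh.shape[0]):
        for j in range(lstm.Wfh.shape[1]):
            lstm.Wfh[i, j] += epsilon
            lstm.reset_state()
            lstm.forward(x[0])
            lstm.forward(x[1])
            err1 = error_function(lstm.h_list[-1])
            lstm.Wfh[i, j] -= 2 * epsilon
            lstm.reset_state()
            lstm.forward(x[0])
            lstm.forward(x[1])
            err2 = error_function(lstm.h_list[-1])
            expect_grad = (err1 - err2) / (2 * epsilon)
            lstm.Wfh[i, j] += epsilon
            print('weights(%d,%d): expected %.4e, actual %.4e' % (
                i, j, expect_grad, lstm.Wfh_grad[i, j]))


if __name__ == '__main__':
    gradient_check()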