深度学习 | 您所在的位置:网站首页 › 摄影模型数学公式 › 深度学习 |
深度学习——LSTM原理与公式推导
1、 RNN回顾
1.1 RNN神经网络回顾
1.1.1 RNN概述
循环神经网络(RNN),主要用于出来序列式问题,通过隐藏节点之间的相互连接,赋予了整个神经网络的记忆能力。对于RNN中的每一隐藏状态而言,其输入主要包括两个部分,一部分是正常接受输入数据的输入,另外一个输是将前一个隐藏状态节点作为下一个节点的输入。 1.1.2 RNN的网络构成图
对于某一个隐藏层的状态节点 S t S_t St而言,其前一个隐藏层的状态节点为 S t − 1 S_{t-1} St−1,t时刻的输入为 X t X_t Xt,则计算出来的 S t S_t St的状态值为: S t = f w ( S t − 1 , X t ) S_t = f_w(S_{t-1},X_t) St=fw(St−1,Xt) 则 s t s_t st时刻的输出 o t o_t ot为: O t = F v ( s t ) O_t = F_v(s_t) Ot=Fv(st) 其中 f w f_w fw对应的是激活函数tanh, f v f_v fv对应的softmax函数。 1.1.4 RNN的局限性距离问题:根据上面的介绍可以知道,RNN中的记忆性是通过隐藏状态节点之间的连接实现的。前一个状态节点作为当前隐藏状态的输入,就将之前的信息输入到了当前的节点之中。但是,根据上面的传播公式,RNN采用的激活函数为tanh,而tanh的将节点的计算结果压缩到了(-1,1)之间,当节点不断的向前传播的时候,这使得从前面传来的信息越来越少。也就是说越远节点的信息对当前节点的贡献度越小。如果切换成其他的大于1的激活函数,通过节点的不断前向传播,也可能造成梯度爆炸的问题。 1.1.6 RNN的改进多层RNN网络:我们以两层的RNN网络为例,其基本的构成图如下: 根据上面的介绍我们可以知道,单层的RNN网络单元的记忆能力是有限的,即每一个神经单元,离它越近,对它的贡献度越大。为了解决这种短记忆力的局限性,我们上面提出来多层RNN的概念,下面,我们来介绍另外一个解决距离问题的神经网络。长短时记忆网络(LSTM)网络。 2.2 LSTM神经网络结构 2.2.1 LSTM的神经单元结构
其基本结构如下图: 其基本结构如图所示: 所谓的细胞状态,我们可以将其理解为一个存储信息的容器,通过输入门,遗忘门,输出门的过程控制,逐步对容器中的信息进行增变化和输出。其具体结构为:
输入: C t − 1 , S t − 1 , X t C_{t-1},S_{t-1},X_t Ct−1,St−1,Xt 遗忘门: n e t F ( t ) = W f T S t − 1 + U f T X t + B f net_F(t)=W_f^TS_{t-1}+U_f^TX_t+B_f netF(t)=WfTSt−1+UfTXt+Bf F ( t ) = s i g m o i d ( n e t F ( t ) ) F(t)=sigmoid(net_F(t)) F(t)=sigmoid(netF(t)) 细胞状态第一个改变: C t 1 = C t − 1 ∗ F ( t ) C_{t1}=C_{t-1}*F(t) Ct1=Ct−1∗F(t) 输入门: n e t I ( t ) = W i T S t − 1 + U i T X t + B i net_I(t)=W_i^TS_{t-1}+U_i^TX_t+B_i netI(t)=WiTSt−1+UiTXt+Bi I ( t ) = s i g m o i d ( n e t I ( t ) ) I(t)=sigmoid(net_I(t)) I(t)=sigmoid(netI(t)) n e t R ( t ) = W r T S t − 1 + U r T X t + B r net_R(t)=W_r^TS_{t-1}+U_r^TX_t+B_r netR(t)=WrTSt−1+UrTXt+Br R ( t ) = t a n h ( n e t R ( t ) ) R(t)=tanh(net_R(t)) R(t)=tanh(netR(t)) 细胞状态第二次改变: C t = C t 1 + I ( t ) ∗ R ( t ) C_t = C_{t1} + I(t)*R(t) Ct=Ct1+I(t)∗R(t) 输出门: n e t O ( t ) = W o T S t − 1 + U o T X t + B o net_O(t)=W_o^TS_{t-1}+U_o^TX_{t}+B_o netO(t)=WoTSt−1+UoTXt+Bo O ( t ) = s i g m o i d ( ) O(t)=sigmoid() O(t)=sigmoid() S t = t a n h ( C t ) ∗ O ( t ) S_t=tanh(C_t)*O(t) St=tanh(Ct)∗O(t) 2.4 LSTM的反向传播过程 2.4.1 误差计算现在,我们假设St时刻的总的误差为 δ S t δ_{S_t} δSt,我们来计算各个门的相关误差 首先计算输出门的误差 ∂ δ S t ∂ O ( t ) = t a n h ( C t ) \frac{∂δ_{S_t}}{∂O(t)}=tanh(C_t) ∂O(t)∂δSt=tanh(Ct) 且有: ∂ O ( t ) ∂ n e t O ( t ) = O ( t ) ∗ ( 1 − O ( t ) ) \frac{∂O(t)}{∂net_O(t)}=O(t)*(1-O(t)) ∂netO(t)∂O(t)=O(t)∗(1−O(t)) 则有: δ o ( t ) = ∂ δ S t ∂ n e t O ( t ) = t a n h ( C t ) ∗ O ( t ) ∗ ( 1 − O ( t ) ) δ_o(t)=\frac{∂δ_{S_t}}{∂net_O(t)}=tanh(C_t)*O(t)*(1-O(t)) δo(t)=∂netO(t)∂δSt=tanh(Ct)∗O(t)∗(1−O(t)) 然后计算输入门的误差: ∂ δ S t ∂ R ( t ) = ∂ δ S t ∂ t a n h ( C t ) ∗ ∂ t a n h ( C t ) ∂ C t ∗ ∂ C t ∂ R ( t ) = O ( t ) ∗ ( 1 − t a n h 2 ( C t ) ) ∗ I ( t ) \frac{∂δ_{S_t}}{∂R(t)}=\frac{∂δ_{S_t}}{∂tanh(C_t)}*\frac{∂tanh(C_t)}{∂C_t}*\frac{∂C_t}{∂R(t)}=\\ \frac{}{}\\ O(t)*(1-tanh^2(C_t) )*I(t) ∂R(t)∂δSt=∂tanh(Ct)∂δSt∗∂Ct∂tanh(Ct)∗∂R(t)∂Ct=O(t)∗(1−tanh2(Ct))∗I(t) 则有: δ R ( t ) = ∂ δ S t ∂ n e t R ( t ) = ∂ δ S t ∂ R ( t ) ∗ ∂ R ( t ) ∂ n e t R ( t ) = O ( t ) ∗ ( 1 − t a n h 2 ( C t ) ) ∗ I ( t ) ∗ ( 1 − R 2 ( t ) ) δ_R(t)=\frac{∂δ_{S_t}}{∂net_R(t)}=\frac{∂δ_{S_t}}{∂R(t)}*\frac{∂R(t)}{∂net_R(t)}=\\ \frac{}{}\\ O(t)*(1-tanh^2(C_t) )*I(t)*(1-R^2(t)) δR(t)=∂netR(t)∂δSt=∂R(t)∂δSt∗∂netR(t)∂R(t)=O(t)∗(1−tanh2(Ct))∗I(t)∗(1−R2(t)) 同理有: ∂ δ S t ∂ I ( t ) = ∂ δ S t ∂ t a n h ( C t ) ∗ ∂ t a n h ( C t ) ∂ C t ∗ ∂ C t ∂ I ( t ) = O ( t ) ∗ ( 1 − t a n h 2 ( C t ) ) ∗ R ( t ) \frac{∂δ_{S_t}}{∂I(t)}=\frac{∂δ_{S_t}}{∂tanh(C_t)}*\frac{∂tanh(C_t)}{∂C_t}*\frac{∂C_t}{∂I(t)}=\\ \frac{}{}\\ O(t)*(1-tanh^2(C_t) )*R(t) ∂I(t)∂δSt=∂tanh(Ct)∂δSt∗∂Ct∂tanh(Ct)∗∂I(t)∂Ct=O(t)∗(1−tanh2(Ct))∗R(t) 则有: δ I ( t ) = ∂ δ S t ∂ n e t I ( t ) = ∂ δ S t ∂ R ( t ) ∗ ∂ I ( t ) ∂ n e t I ( t ) = O ( t ) ∗ ( 1 − t a n h 2 ( C t ) ) ∗ R ( t ) ∗ I ( t ) ∗ ( 1 − I ( t ) ) δ_I(t)=\frac{∂δ_{S_t}}{∂net_I(t)}=\frac{∂δ_{S_t}}{∂R(t)}*\frac{∂I(t)}{∂net_I(t)}=\\ \frac{}{}\\ O(t)*(1-tanh^2(C_t) )*R(t)*I(t)*(1-I(t)) δI(t)=∂netI(t)∂δSt=∂R(t)∂δSt∗∂netI(t)∂I(t)=O(t)∗(1−tanh2(Ct))∗R(t)∗I(t)∗(1−I(t)) 然后是遗忘门的误差: ∂ δ S t ∂ F ( t ) = ∂ δ S t ∂ t a n h ( C t ) ∗ ∂ t a n h ( C t ) ∂ C t ∗ ∂ C t ∂ C ( t 1 ) ∗ ∂ C ( t 1 ) ∂ F ( t ) = O ( t ) ∗ ( 1 − t a n h 2 ( C t ) ) ∗ C t − 1 \frac{∂δ_{S_t}}{∂F(t)}=\frac{∂δ_{S_t}}{∂tanh(C_t)}*\frac{∂tanh(C_t)}{∂C_t}*\frac{∂C_t}{∂C_(t1)}*\frac{∂C_(t1)}{∂F(t)}=\\ \frac{}{}\\ O(t)*(1-tanh^2(C_t) )*C_{t-1} ∂F(t)∂δSt=∂tanh(Ct)∂δSt∗∂Ct∂tanh(Ct)∗∂C(t1)∂Ct∗∂F(t)∂C(t1)=O(t)∗(1−tanh2(Ct))∗Ct−1 则有: δ F ( t ) = ∂ δ S t ∂ n e t F ( t ) = ∂ δ S t ∂ F ( t ) ∗ ∂ F ( t ) ∂ n e t F ( t ) = O ( t ) ∗ ( 1 − t a n h 2 ( C t ) ) ∗ C t − 1 ∗ F ( t ) ∗ ( 1 − F ( t ) ) δ_F(t)=\frac{∂δ_{S_t}}{∂net_F(t)}=\frac{∂δ_{S_t}}{∂F(t)}*\frac{∂F(t)}{∂net_F(t)}=\\ \frac{}{}\\ O(t)*(1-tanh^2(C_t) )*C_{t-1}*F(t)*(1-F(t)) δF(t)=∂netF(t)∂δSt=∂F(t)∂δSt∗∂netF(t)∂F(t)=O(t)∗(1−tanh2(Ct))∗Ct−1∗F(t)∗(1−F(t)) 最后,我们要计算的是关于前一个时刻的误差: δ S t − 1 = ∂ δ S t ∂ S t − 1 = ∂ δ S t ∂ t a n h ( C t ) ∗ ∂ t a n h ( C t ) ∂ S t − 1 ∗ O ( t ) + ∂ δ S t ∂ O ( t ) ∗ ∂ O ( t ) ∂ S t − 1 ∗ t a n h ( C t ) = ∂ δ S t ∂ t a n h ( C t ) ∗ ∂ t a n h ( C t ) ∂ C t ∗ ∂ C t ∂ S t − 1 ∗ O ( t ) + t a n h ( C t ) ∗ ∂ δ S t ∂ O ( t ) ∗ ∂ O ( t ) ∂ n e t O ( t ) ∗ ∂ n e t O ( t ) ∂ S t − 1 = W f δ F ( t ) + W I δ I ( t ) + W r δ R ( t ) + W o δ O ( t ) δ_{S_{t-1}}=\frac{∂δ_{S_t}}{∂S_{t-1}}=\frac{∂δ_{S_t}}{∂tanh(C_t)}*\frac{∂tanh(C_t)}{∂S_{t-1}}*O(t)+\frac{∂δ_{S_t}}{∂O(t)}*\frac{{∂O(t)}}{∂S_{t-1}}*tanh(C_t)\\ \frac{}{}\\ =\frac{∂δ_{S_t}}{∂tanh(C_t)}*\frac{∂tanh(C_t)}{∂C_t}*\frac{∂C_t}{∂S_{t-1}}*O(t)+tanh(C_t)*\frac{∂δ_{S_t}}{∂O(t)}*\frac{{∂O(t)}}{∂net_O(t)}*\frac{∂net_O(t)}{∂S_{t-1}}\\ \frac{}{}\\=W_fδ_F(t)+W_Iδ_I(t)+W_rδ_R(t)+W_oδ_O(t) δSt−1=∂St−1∂δSt=∂tanh(Ct)∂δSt∗∂St−1∂tanh(Ct)∗O(t)+∂O(t)∂δSt∗∂St−1∂O(t)∗tanh(Ct)=∂tanh(Ct)∂δSt∗∂Ct∂tanh(Ct)∗∂St−1∂Ct∗O(t)+tanh(Ct)∗∂O(t)∂δSt∗∂netO(t)∂O(t)∗∂St−1∂netO(t)=WfδF(t)+WIδI(t)+WrδR(t)+WoδO(t) 2.4.2 梯度计算我们来分别计算该误差对于 W o , W i , W r , W f , U o , U i , U r , U f , B o , B i , B r , B f , S t − 1 W_o,W_i,W_r,W_f,U_o,U_i,U_r,U_f,B_o,B_i,B_r,B_f,S_{t-1} Wo,Wi,Wr,Wf,Uo,Ui,Ur,Uf,Bo,Bi,Br,Bf,St−1的相关梯度。 ∂ δ S t ∂ W o = S t − 1 ∂ δ S t ∂ n e t O ( t ) T = S t − 1 δ O T ( t ) \frac{∂δ_{S_t}}{∂W_o}=S_{t-1}\frac{∂δ_{S_t}}{∂net_O(t)}^T=S_{t-1}δ_O^T(t) ∂Wo∂δSt=St−1∂netO(t)∂δStT=St−1δOT(t) ∂ δ S t ∂ W r = S t − 1 ∂ δ S t ∂ n e t R ( t ) T = S t − 1 δ R T ( t ) \frac{∂δ_{S_t}}{∂W_r}=S_{t-1}\frac{∂δ_{S_t}}{∂net_R(t)}^T=S_{t-1}δ_R^T(t) ∂Wr∂δSt=St−1∂netR(t)∂δStT=St−1δRT(t) ∂ δ S t ∂ W i = S t − 1 ∂ δ S t ∂ n e t R ( t ) T = S t − 1 δ R T ( t ) \frac{∂δ_{S_t}}{∂W_i}=S_{t-1}\frac{∂δ_{S_t}}{∂net_R(t)}^T=S_{t-1}δ_R^T(t) ∂Wi∂δSt=St−1∂netR(t)∂δStT=St−1δRT(t) ∂ δ S t ∂ W f = S t − 1 ∂ δ S t ∂ n e t F ( t ) T = S t − 1 δ F T ( t ) \frac{∂δ_{S_t}}{∂W_f}=S_{t-1}\frac{∂δ_{S_t}}{∂net_F(t)}^T=S_{t-1}δ_F^T(t) ∂Wf∂δSt=St−1∂netF(t)∂δStT=St−1δFT(t) ∂ δ S t ∂ U o = X t ∂ δ S t ∂ n e t O ( t ) T = X t δ O T ( t ) \frac{∂δ_{S_t}}{∂U_o}=X_t\frac{∂δ_{S_t}}{∂net_O(t)}^T=X_tδ_O^T(t) ∂Uo∂δSt=Xt∂netO(t)∂δStT=XtδOT(t) ∂ δ S t ∂ U r = X t ∂ δ S t ∂ n e t R ( t ) T = X t δ R T ( t ) \frac{∂δ_{S_t}}{∂U_r}=X_t\frac{∂δ_{S_t}}{∂net_R(t)}^T=X_tδ_R^T(t) ∂Ur∂δSt=Xt∂netR(t)∂δStT=XtδRT(t) ∂ δ S t ∂ U i = X t ∂ δ S t ∂ n e t R ( t ) T = X t δ R T ( t ) \frac{∂δ_{S_t}}{∂U_i}=X_t\frac{∂δ_{S_t}}{∂net_R(t)}^T=X_tδ_R^T(t) ∂Ui∂δSt=Xt∂netR(t)∂δStT=XtδRT(t) ∂ δ S t ∂ U f = X t ∂ δ S t ∂ n e t F ( t ) T = X t δ F T ( t ) \frac{∂δ_{S_t}}{∂U_f}=X_t\frac{∂δ_{S_t}}{∂net_F(t)}^T=X_tδ_F^T(t) ∂Uf∂δSt=Xt∂netF(t)∂δStT=XtδFT(t) ∂ δ S t ∂ B o = δ O ( t ) \frac{∂δ_{S_t}}{∂B_o}=δ_O(t) ∂Bo∂δSt=δO(t) ∂ δ S t ∂ B I = δ I ( t ) \frac{∂δ_{S_t}}{∂B_I}=δ_I(t) ∂BI∂δSt=δI(t) ∂ δ S t ∂ B R = δ R ( t ) \frac{∂δ_{S_t}}{∂B_R}=δ_R(t) ∂BR∂δSt=δR(t) ∂ δ S t ∂ B F = δ F ( t ) \frac{∂δ_{S_t}}{∂B_F}=δ_F(t) ∂BF∂δSt=δF(t) |
CopyRight 2018-2019 实验室设备网 版权所有 |