LSTM长短期记忆网络

LSTM长短期记忆网络

什么是LSTM

LSTM网络如下图所示,虽然看起来太复杂了,第一次看到的人都是一脸懵逼,这是啥?有用吗?但是原理不难,我们一点点分析。

LSTM的结构

前面讲的SimpleRNN只是RNN中最简单的版本,里面的memory是最简单的,可以随时把值写入,也可随时把值读出。但现在常用的memory称为Long Short-term Memory(LSTM)。

这个Long Short-term Memory是比较复杂的,它有3个Gate。

  • Input Gate

    当某个网络(外界)的输出值想要写入到LSTM里面的时候,必须先通过一个闸门(Input Gate),并且只有当Input Gate打开的时候,才能把值写入到memory中去,当Input Gate关起来的时候,就无法写入值。至于什么时候把Input Gate打开还是关闭,这是网络自己学到的

  • Output Gate

    输出的地方有一个Output Gate,可以决定外界是否可以把值从memory里面读出。只有当Output Gate打开的时候,外界才能读取。同样的,Output Gate什么时候打开还是关闭,也是Output Gate自己学到的。

  • Forget Gate

    Forget Gate决定什么时候把过去记得的东西(存在memory中的值)忘掉,什么时候记住过去学的东西。同样的,Output Gate什么时候把存在memory中的值清除掉或者保留下来,也是Forget Gate自己学到的。

整个LSTM单元可以看成有四个Input,一个Output。

一个有意思的知识点:LSTM的名字应该叫下面哪一个呢?

Long Short-term Memory
Long-Short term Memory

应该是第一个,因为LSTM和前面讲的SimpleRNN一样,还是个Short-term的Memory,只不过SimpleRNN的Memory每一步都会被清除并更新,它的Short-term是非常Short的。相对的,LSTM的Memory不是每一步都会被更新,可以记得比较长一些,只要Forget Gate不Forget的话,它的Memory里的值就会被存起来,所以是比较长的Short-term Memory,即称之为Long Short-term Memory。

LSTM的流程

LSTM这个Cell的具体结构和数据流程如下所示。

LSTM流程举例

假设该神经网络只有一个LSTM,如下图所示。再假设该网络的权值已知(权值是学习出来的),然后输入x序列。三个门的输入值,都是序列x乘上一个权值矩阵得到的结果

我们来依次输入x序列,看整个流程是怎么运行的。为了简单起见,假设输入门和输出门的激活函数都是线性的,memory的初始值为0,具体如下图所示。

现在输入x的第一个向量[3, 1, 0],计算流程如下图所示,具体就不细讲了,图里面的流程非常清楚,memory中的值会从0变为3+1x0=3,最后的输出是0。

接下来,输入x的第二个向量[4, 1, 0],计算流程如下图所示,memory中的值会从0变为4+1x3=7,由于输出门关闭,最后的输出是0。

接下来,输入x的第三个向量[2, 0, 0],计算流程如下图所示,memory中的值会从7变为0+1x7=7,由于输出门关闭,最后的输出是0。

接下来,输入x的第四个向量[1, 0, 1],计算流程如下图所示,memory中的值会从7变为0+1x7=7,最后的输出是7。

接下来,输入x的最后一个向量[3, -1, 0],计算流程如下图所示,memory中的值会从7变为0+0x7=0,最后的输出是0。

将LSTM的cell作为神经元

看到这里,你可能会疑惑,LSTM一个cell的结构我理解了,但是和神经网络有什么关系呢?其实,直接把神经网络中的神经元替换成LSTM的cell就可以了。

下图是一般的神经网络和里面的神经元。

然后输入x乘上不同的权值作为LSTM单元的四个输入(输入和三个门)。就好像一般的机器插一个电源线就能运行,但LSTM这个机器需要插四个电压不同的电源线才能跑。所以同样的neuron数目下,LSTM的参数是一般的四倍。

为什么LSTM属于RNN

只看上图的话,我们就会很疑惑,这个和RNN的关系是什么呢?怎么看起来不太像RNN。所以要画另外一个图来表示LSTM。

这四个vector合起来就会操控这些LSTM单元的运作。

接下来,我们向量化的并行看一下LSTM单元的运作,不再去分开单独看。注意里面的乘号是逐元素运算。

虽然上图已经很复杂了,但这并不是LSTM的最终形态,真正的LSTM怎么做呢?

上图只是LSTM只有一层的情况,但通常LSTM不会只有一层,再叠加个五六层才是我们要的样子。

每一次看到这个东西的人,它的反应都是这个样子:

每一个人第一次都看到这个图,都在想这应该是不work的吧,但其实它确实是work的。现在说自己在用RNN的时候,其实都是在用LSTM了,用Keras的时候,这些都帮你是写好的,你只要输入LSTM四个字就好了。GRU是LSTM的稍微简化版本,只有两个gate,但据说少了一个gate,但表现和LSTM差不多,所以少了三分之一的参数,比较不容易过拟合。所以,我们之前讲的那种最简单的RNN,要称其为SimpleRNN才行,即RNN包含了LSTM、GRU和SimpleRNN。

参考资料

本文参考了此视频。

===

完全图解RNN、RNN变体、Seq2Seq、Attention机制

Understanding LSTM Networks

Understanding LSTM Networks翻译:如何简单的理解LSTM——其实没有那么复杂

谁能用比较通俗有趣的语言解释RNN和LSTM?

Last updated