之前或多或少都有去关注以及该方面的paper阅读,但是并没有去好好的整理该技术的整体发展,今天闲来无事,想从代码以及paper的核心思想梳理一遍。供自己后续方便查看吧。
Pointer Networks
传统的seq2seq模型是无法解决输出序列的词汇表会随着输入序列长度的改变而改变的问题,所以一个很简单的想法就是为什么输出不能直接从输入端直接拿过来呢?那直接从输入端拿过来,具体要怎么操作呢?这就是该篇文章给出的一个思路,该篇文章给出的形象解释是凸包问题,如图所示
对上图的一个简单解释:给定p1到p4四个二维坐标,找到一个凸包。答案是p1->p4->p2->p1,图a就是传统的seq2seq做法,就是把四个点的坐标作为输入序列输入进去,然后提供一个词汇表:[start, 1, 2, 3, 4, end],最后依据词汇表预测出序列[start, 1, 4, 2, 1, end],缺点作者也提到过了,对于图a的传统seq2seq模型来说,它的输出词汇表已经限定,当输入序列的长度变化的时候(如变为10个点)它根本无法预测大于4的数字。因为你的词汇表限定了最大就是4。图b是作者提出的Pointer Networks,它预测的时候每一步都找当前输入序列中权重最大的那个元素,而由于输出序列完全来自输入序列,它可以适应输入序列的长度变化。那具体的是怎么处理的呢?下面就直接从代码实现层面来简单说一下。
还是以解凸包问题说起
每一个batch5个坐标点,那最开始的输入就是:(假设batch 256)
inputs.shape (256,5,2)
假设embedding是128,那inputs 经过embedding后的shape就是:
embedded_inputs(256,5,128)
然后进行encode,假设用了LSTM,(uints假设为512)那它会输出 encoder_outputs 和 encoder_hidden,shape分别是:
encoder_outputs(256, 5, 512)
encoder_hidden(256, 512)
接下来我们就要开始decode了,重点就是decode端去实现如何直接拿输入的信息了,其实对于这种seq2seq现在都会做一个attention的操作,那该paper其实就是在attention上做了简化,通过attention的操作得到一个alpha,通过alpha间接去拿输入端embedded_inputs 的具体某一个坐标的embedding。下面看一下decode端的一个操作吧:decode的输入主要是这四个值
embedded_inputs (256, 5, 128)就是encode端的embedding
decoder_input0 (256, 128)因为是t0时刻,所以这个值最开始是随机初始化的
decoder_hidden0 (256, 512)就是拿了encode端最后一个时刻的隐状态作为decode端的开始状态
encoder_outputs (256, 5, 512)
将 decoder_input0 和 decoder_hidden0 经过一个时刻的LSTM操作得到t1 的 h_t, 然后将 encoder_outputs 和 t_1 时刻的h_t 输入到attention,此时attention操作就是计算出一个alpha。具体如何计算呢,继续往下看:
我们知道上面操作得到的 h_t维度是(256,512), encoder_outputs 维度是(256, 5, 512) ,我们将h_t 进行repeat操作,维度变成 (256, 5, 512),h_t 和 encoder_outputs做一下维度变换,变成(256,512,5),然后这个encoder_outputs进行一次Conv1d操作,其实做不做这个操作我觉得影响也不是很大,该操作是不改变维度的,所以经过Conv1d操作后维度还是(256, 5, 512),在attention这里呢我们会一开始初始化一个变量,假设是V吧,,他的维度呢就是(256,512)的矩阵。(该矩阵呢其实是为了后面计算得到alpha的一个中间变量吧)为了矩阵操作方便,我们会将V进行维度扩展,变成(256,1,512),然后做一个这样的操作
1 | att = torch.bmm(V, self.tanh(h_t+ encoder_outputs)).squeeze(1) |
所以此时得到的 att 的维度就是 (256,5),此时呢,直接对这个att 进行 softmax操作,得到alpha,然后将 alpha 和 encoder_outputs做一个计算,得到一个 hidden_state
1 | hidden_state = torch.bmm(encoder_outputs, alpha.unsqueeze(2)).squeeze(2) |
至此,attention操作就结束了,最后返回的就是 alpha 和 hidden_state
那到此呢,我们只是拿到了alpha而已,那怎么通过这个alpha直接到embedded_inputs去拿对应索引的embeding呢?接着往下看:
拿到的alpha会做一个max操作,如下:
1 | max_probs, indices = alpha.max(1) |
到现在为止,已经解释了如何直接从输入端来取embedding来做为decode端的输入了,但是最终我们要拿到这个凸包的输出还没有解释,下面就简单来看一下吧,其实很简单了:
1 | outputs.append(alpha.unsqueeze(0)) |
最终,返回的outputs 会参与模型计算loss,至此,整个 Pointer Network 的代码实现就解释完了。或许有点懵。直接阅读代码吧,配合这个解释会特别清晰
阅读code
阅读code
阅读code
重要的事情说三遍,Down
Get To The Point: Summarization with Pointer-Generator Networks
在这篇论文中,作者认为,用于文本摘要的seq2seq模型往往存在两大缺陷:
- 模型容易不准确地再现事实细节,也就是说模型生成的摘要不准确;
- 往往会重复,也就是会重复生成一些词或者句子。而针对这两种缺陷,作者分别使用Pointer Networks和Coverage技术来解决
- 作者给了一张效果图如下:
在这张图中,基础的seq2seq模型的预测结果存在许多谬误的句子,同时如nigeria这样的单词反复出现(红色部分)。这也就印证了作者提出的基础seq2seq在文本摘要时存在的问题;Pointer-Generator模型,也就是在seq2seq基础上加上Pointer Networks的模型基本可以做到不出现事实性的错误,但是重复预测句子的问题仍然存在(绿色部分);最后,在Pointer-Generator模型上增加Coverage机制,可以看出,这次模型预测出的摘要不仅做到了事实正确,同时避免了重复某些句子的问题(摘要结果来自原文中的蓝色部分)
那么,Pointer-Generator模型以及变体Pointer-Generator+Coverage模型是怎么做的呢,我们具体从代码层面来分析一下
既然是Pointer Network的进一步改进,那首先想到的就是,如何像Pointer Network那样输出能跟着输入的改变而改变吧?为什么要有这样的操作,其实就是为了解决OOV的问题嘛,假设词表是10000,当你输入的某一个词不在词表中的时候,是不是就要用UNK来代替了,而且这个词也出现在decode端,那么decode端也是UNK了。所以为了解决这个问题,就有了,词表随着输入的扩大而扩大,具体代码体现如下:
1 | def article2ids(article_words, vocab): |
从上面代码可以很清晰的看到,但你输入的词超过词表时,问题也不大,词表跟着扩大就行了。
接着,现在是可以做到词表跟着输入的变化而变化了,但是接下来要怎么做呢?我们知道正常的seq2seq,encode端将词变成索引只要做这样一个操作就是了
1 | self.enc_input = [vocab.word2id(w) for w in article_words] |
这种操作当出现一个词不在词表中时,就会出现UNK对应的索引了。如果是这样来实现的话:
1 | if config.pointer_gen: |
因为词表跟着输入的扩充变化而变化,所以可以知道 self.enc_input_extend_vocab 列表里是不会出现UNK对应的索引的,至此输入端的情况就应该很清楚了。接着就是encode decode 的一些操作了
encode端其实是很简单的一些操作,核心代码如下:
1 | def forward(self, input, seq_lens): |
decode端稍微复杂一些,但是其实和Pointer Network 没什么的大的区别,也是在Attention操作的时候,直接拿到一个 (batch_size,seq_length)的概率分布矩阵,直接softmax操作,作为概率返回,看一下attention的核心代码吧
1 | def forward(self, s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage): |
可以看到attn_dist其实和Pointer Network的实现或者说操作是一样的,没有什么大的区别,coverage是为了惩罚重复出现问题而计算的一个矩阵
在attention计算完毕,decode端是如何来计算整体的概率分布呢?还是直接看代码吧
1 | p_gen = None |
至此,我觉得我自己应该大差不差的可以搞清楚了。如果你没有搞明白,还是那句话,代码面前没有秘密。看代码去吧(插一句,这份代码中有好几个地方我觉得是有待考虑的,在实现上,但是主体还是OK的)
CopyNet
该篇文章开篇作者提到要解决的问题就是赋予seq2seq复制的能力,如下所示:
从这个例子中我们可以看到,针对绿色的这部分词汇其实是不需要去理解语意的,直接从输入端copy到输出端就可以了,那我们该如何去实现这个功能呢?,下面我们直接从代码上进行解释:
encode
1 | class CopyEncoder(nn.Module): |
encod部分代码其实不需要进行什么解释了,很容易就理解了
decode
1 | # 1. input_idx 就是decode端第一次的输入,weighted第一次也是初始化的 |
其实代码后面针对weighted 和 state 还做了一部分操作,这都是次要的,只要理解了out的全部计算过程,可以说CopyNet 的核心思想你也就掌握了,其实这里和上一篇文章没有什么大的差别,这里是直接相加,上一篇文章弄了个软概率来合并,还有一个覆盖操作,大差不差吧。重要的事情说三遍。
阅读code
阅读code
阅读code
结论
Seq2seq 的copy机制可以暂时告一段落啦。。。