在Attention Is All You Need-谷歌的"自注意力"中提到了為什么不在NLP中的原因娘汞。今天再分析一下由facebook提出采用卷積完成的seq2seq模型变姨。這篇論文早于google的attention模型涛碑,其證明了CNN在NLP領(lǐng)域也能有較好的效果夺饲。
基本卷積結(jié)構(gòu)
由于是對自然語言進(jìn)行處理,因此模型的卷積與普通的圖像卷積模型有一點(diǎn)點(diǎn)不同峻呕。從形式上看范嘱,更像是一維卷積,即卷積核只在句子的方向移動(dòng)普碎。這種卷積形式最早應(yīng)該是[1]中提出吼肥。卷積的過程如圖:
這里假設(shè)詞向量為,則卷積的核大小為
即卷積核的寬度始終等于詞向量的維度麻车,
為卷積窗口一次覆蓋的詞語個(gè)數(shù)潜沦。設(shè)卷積核個(gè)數(shù)為
,輸入
绪氛。則卷積操作:
此外,論文中還使用了一種新的非線性單元GLU(gated linear units)
[2]涝影,其方法是
表示逐個(gè)元素相乘枣察。將輸入進(jìn)行兩次不共享參數(shù)的卷積,其中一個(gè)輸出通過
sigmoid
函數(shù)作為另一個(gè)輸入的gate
燃逻,這里類似于LSTM
的gate
機(jī)制序目,因此叫做gated linear units。在[2]中已經(jīng)證明了這種非線性機(jī)制在NLP任務(wù)中要優(yōu)于其他的激活函數(shù)伯襟。
總體結(jié)構(gòu)
總體結(jié)構(gòu)還是sequence to sequence猿涨。source 和 target都會(huì)先被編碼,然后做attention姆怪。Attention輸出再加上decoder輸出用來做預(yù)測叛赚,不過這里編碼的模型從RNN變?yōu)榱薈NN。但是從細(xì)節(jié)上來看稽揭,這里仍然有一些不同于傳統(tǒng)RNN語言模型的東西:
殘差連接
文中提到為了使得模型能夠堆疊的更深使用了殘差連接俺附。對于decoder來說也就是:
其中為decoder第
層的輸出。
Multi-step Attention
這也論文對attention機(jī)制的一個(gè)改變溪掀,RNN中只會(huì)對decoder與encoder做一次attention事镣。這里可以分解為3個(gè)步驟:
- 首先并不直接使用encoder state與decoder state計(jì)算attention,而是先對decoder state做一個(gè)變換再加上target的embeding:
這里的也就是target的embeding揪胃。前面做變換是為了使得卷積輸出維度變換與embedding一致璃哟。
- 然后進(jìn)行attention:
這里做了一個(gè)點(diǎn)乘attention氛琢,是encoder堆疊的層數(shù)。
也就是encoder最后一層的輸出随闪。
- 計(jì)算出attention權(quán)重后阳似,基于
計(jì)算attention輸出:
這里計(jì)算attention輸出的時(shí)候也并不是直接基于還加上了source的embedding
。這和傳統(tǒng)的attention有一點(diǎn)差別蕴掏。論文中給出的解釋是:
Encoder outputs
represent potentially large input contexts and
provides point information about a specific input element that is useful when making a prediction
當(dāng)計(jì)算出后障般,會(huì)直接與
相加,作為下一層的輸入盛杰。這也就是Multi-step Attention挽荡。
- 最終的預(yù)測方式為:
一點(diǎn)分析
論文中提到通過不斷堆疊CNN層的方式同樣可以capture long-range dependencies compared to the chain structure modeled by recurrent networks。而在Attention Is All You Need中的解決辦法更為直接即供,既然要capture long-range dependencies定拟,直接進(jìn)行attention就對了。這樣根本就不存在long-range dependencies逗嫡,每個(gè)單詞之間的依賴都變得shorter青自。
參考
[1] Convolutional Neural Networks for Sentence Classification
[2] Language modeling with gated linear units