2017年陷舅,Google發(fā)表論文《Attention is All You Need》伐弹,提出經(jīng)典網(wǎng)絡(luò)結(jié)構(gòu)Transformer披摄,全部采用Attention結(jié)構(gòu)的方式馆铁,代替了傳統(tǒng)的Encoder-Decoder框架必須結(jié)合CNN或RNN的固有模式跑揉。并在兩項機器翻譯任務(wù)中取得了顯著效果。該論文一經(jīng)發(fā)出埠巨,便引起了業(yè)界的廣泛關(guān)注历谍,同時,Google于2018年發(fā)布的劃時代模型BERT也是在Transformer架構(gòu)上發(fā)展而來辣垒。所以望侈,為了之后學習的必要,本文將詳細介紹Transformer模型的網(wǎng)絡(luò)結(jié)構(gòu)勋桶。
1脱衙、整體架構(gòu)
Transformer作為seq2seq,也是由經(jīng)典的Encoder-Decoder模型組成例驹。在上圖中捐韩,整個Encoder層由6個左邊Nx部分的結(jié)構(gòu)組成。整個Decoder由6個右邊Nx部分的框架組成鹃锈,Decoder輸出的結(jié)果經(jīng)過一個線性層變換后荤胁,經(jīng)過softmax層計算,輸出最終的預(yù)測結(jié)果仪召。
(1)寨蹋、Encoder結(jié)構(gòu):
輸入序列X經(jīng)過word embedding和positional encoding做直接加和后,作為Encoder部分的輸入扔茅。輸入向量經(jīng)過一個multi-head self-attention層后已旧,做一次residual connection(殘差連接)和Layer Normalization(層歸一化,下文中簡稱LN)召娜,輸入到下一層position-wise feed-forward network中运褪。之后再進行一次殘差連接+LN,輸出到Decoder部分玖瘸,這里所涉及到的相關(guān)知識會在下文中詳細介紹秸讹。
(2)、Decoder結(jié)構(gòu):
輸出序列Y經(jīng)過word embedding和positional encoding做直接加和后雅倒,作為Decoder部分的輸入璃诀。很多對seq2seq不了解的朋友看到這里可能有些糊涂,簡單說明以下蔑匣。以翻譯任務(wù)為例劣欢,假設(shè)我們要進行一個中譯英任務(wù)棕诵。我們現(xiàn)在有一段中文序列X,對應(yīng)的英文序列Y凿将。我們在翻譯出某個單詞Yt時校套,并非只是用中文序列X翻譯,而是用中文序列X加已經(jīng)翻譯出來的英文序列(y1,y2,……yt-1)進行翻譯牧抵,所以也要將已經(jīng)翻譯出來的英文序列輸入其中笛匙。這也就解釋了為什么會將輸出序列Y作為Decoder的輸入。在論文中犀变,在訓練過程中為了處理方便同時不引入未來信息妹孙,采用了一種sequence masking機制,具體的實現(xiàn)下文再詳細介紹获枝。
Decoder部分的輸入向量首先經(jīng)過一層multi-head self-attention涕蜂,進行一次殘差連接+LN,再經(jīng)過一層multi-head context-attention映琳,進行一次殘差連接+LN,最后再經(jīng)過一層position-wise feed-forward network蜘拉,進行一次殘差連接+LN后萨西,輸出至線性層。
以上介紹了Encoder和Decoder的基本流程旭旭,相信大家對其中具體的實現(xiàn)還有不明白的細節(jié)谎脯,下面我就來為大家一一闡述。
2持寄、Attention機制
上文中提到了兩個Attenton結(jié)構(gòu)源梭,multi-head self-attention和multi-head context-attention可以說是本文中最重要的概念,這里來解釋下兩者的實現(xiàn)稍味,首先废麻,我們來回顧以下基礎(chǔ)的Attention機制。
(1)模庐、基礎(chǔ)Attention機制
之前曾經(jīng)寫過一篇詳細介紹Attention的文章烛愧,感興趣的朋友可以關(guān)注我的公眾號查找,這里主要使用論文中描述的方式來簡單介紹以下基礎(chǔ)Attention掂碱。
在自然語言處理中怜姿,Attention的本質(zhì)可以理解為一個查詢(query)到一些列(key - value)對的映射。以基礎(chǔ)的Attention計算公式為例:
計算attention時:第一步疼燥,將query和每個key進行相似度計算得到權(quán)重沧卢,即上圖中的第三個公式。第二步醉者,一般使用一個softmax函數(shù)將這些權(quán)重進行歸一化但狭,即上圖中的第二個公式披诗,最后將權(quán)重和相應(yīng)的鍵值value進行加權(quán)求和,得到最終的attention熟空,即第一個公式藤巢。通常key和value取值相同,例如上圖中息罗,key=value=hj, query=si-1掂咒。
其實,Google所用到的基本attention思路是與上面一致的迈喉,只是在計算Attention分數(shù)時绍刮,采用了另一種計算機制:Scaled dot-product attention
(2)、Scaled dot-product attention
Scaled dot-product attention的計算公式如下:
其實基本元素還是Q挨摸,K孩革,V三項,無非就是公式變了下得运。具體的計算圖結(jié)構(gòu)文章中也給了圖膝蜈,公式很清晰這里就不列了。
(3)Self-attention 和Context-attention
Self-attention:自己跟自己做Attention熔掺,輸入序列=輸出序列饱搏。Q=K=V。
Context-attention:Encoder輸出結(jié)果跟Decoder第一部分輸出結(jié)果之間做Attention置逻。
具體到網(wǎng)絡(luò)結(jié)構(gòu)中:
Encoder中的self-attention推沸,Q,K券坞,V均為Encoder的輸入鬓催。
Decoder中的self-attention,Q恨锚,K宇驾,V均為Decoder的輸入,也就是上一層Decoder的輸入猴伶,具體原因見Decoder的介紹飞苇。
Decoder中context-attention,Q為decoder中第一部分的輸出蜗顽,K布卡,V均為encoder的輸出。
(4)雇盖、Multi-head attention
論文中采用的Multi-head attention忿等,就是將Q, K, V先經(jīng)過一個線性映射,再在在輸入維度dk崔挖,dq贸街,dv上切分成h份庵寞,再對每一份進行Scaled dot-product attention,之后將每部分結(jié)果合并起來薛匪,經(jīng)過線性映射捐川,得到最終的輸出,結(jié)構(gòu)圖如下:
說的有些繞逸尖,舉個例子古沥,原文中d=512(即詞向量和位置向量的維度),h=8娇跟。那么假設(shè)原始輸入為[batch_size*seq_len*512]的三維表岩齿,處理后共分成8份[batch_size*seq_len* 64]的三維表,每份分別做Scaled dot-product苞俘,就是Multi-head attention了盹沈。這樣進行了h次運算,可以允許模型在不同的表示子空間中學習到相關(guān)信息吃谣。
以上就是Attention部分的全部講解乞封,說清楚這一部分,其他的都是一些零碎的細節(jié)岗憋。
3歌亲、Position-wise Feed-Forward network
一個全聯(lián)接神經(jīng)網(wǎng)絡(luò),先進行一次線性變換澜驮,再通過一次ReLU激活函數(shù),最后再進行一次線性變化惋鸥。公式如下:
4杂穷、Positional encoding
位置編碼,顧名思義卦绣,對序列中詞語的位置進行編碼耐量,公式如下:
即奇數(shù)位置用余弦編碼,偶數(shù)位置用正弦編碼滤港,最終得到一個512維的位置向量廊蜒。
5、Residual connection
殘差連接其實在很多網(wǎng)絡(luò)機構(gòu)中都有用到溅漾。原理很簡單山叮,假設(shè)一個輸入向量x,經(jīng)過一個網(wǎng)絡(luò)結(jié)構(gòu)添履,得到輸出向量f(x)屁倔,加上殘差連接,相當于在輸出向量中加入輸入向量暮胧,即輸出結(jié)構(gòu)變?yōu)閒(x)+x锐借,這樣做的好處是在對x求偏導時问麸,加入一項常數(shù)項1,避免了梯度消失的問題钞翔。
6严卖、Layer Normalization
歸一化的本質(zhì)都是將數(shù)據(jù)轉(zhuǎn)化為均值為0,方差為1的數(shù)據(jù)布轿。這樣可以減小數(shù)據(jù)的偏差哮笆,規(guī)避訓練過程中梯度消失或爆炸的情況。我們在訓練網(wǎng)絡(luò)中比較常見的歸一化方法是Batch Normalization驮捍,即在每一層輸出的每一批數(shù)據(jù)上進行歸一化疟呐。而Layer Normalization與BN稍有不同,即在每一層輸出的每一個樣本上進行歸一化东且。
7启具、Mask
mask的思想非常簡單:就是對輸入序列中沒某些值進行掩蓋,使其不起作用珊泳。在論文中鲁冯,做multi-head attention的地方用到了padding mask,在decode輸入數(shù)據(jù)中用到了sequence mask色查。
(1)薯演、padding mask
在我們輸入的數(shù)據(jù)中,因為每句話的長度不同秧了,所以要對較短的數(shù)據(jù)進行填充補齊長度跨扮。而這些填充值并沒有什么作用,為了減少填充數(shù)據(jù)對attention計算的影響验毡,采用padding mask的機制衡创,即在填充物的位置上加上一個趨緊于負無窮的負數(shù),這樣經(jīng)過softmax計算后這些位置的概率會趨近于0
(2)晶通、sequence mask
在上文中我們提到璃氢,預(yù)測t時刻的輸出值yt,應(yīng)該使用全部的輸入序列X狮辽,和t時刻之前的輸出序列(y1,y2,……,yt-1)進行預(yù)測一也。所以在訓練時,應(yīng)該將t-1時刻之后的信息全部隱藏掉喉脖。所以需要用到sequence mask椰苟。
實現(xiàn)也很簡單,就是用一個上三角矩陣树叽,上三角值均為1尊剔,下三角值均為0,對角線值為0,與輸入序列相乘须误,就達到了目的挨稿。
以上就是Transformer框架的全部知識點,BERT模型也是在此基礎(chǔ)上發(fā)展而來京痢。