一、簡(jiǎn)介
基于假設(shè):一個(gè)詞在句子中的意思瑰艘,與上下文(語(yǔ)境)有關(guān)泛源。與哪些詞有關(guān)呢?Transformer就是:利用點(diǎn)積將句子中所有詞的影響當(dāng)成權(quán)重都考慮了進(jìn)去漾月。
Transform模型是與RNN和CNN都完全不同的思路。相比Transformer侵蒙,RNN/CNN的問(wèn)題:
- RNN序列化處理效率提不上去赦肋。理論上,RNN效果上問(wèn)題不大谦疾。
- CNN感受野小南蹂。CNN只考慮卷積核大小區(qū)域,核內(nèi)參數(shù)共享念恍,并行/計(jì)算效率不是問(wèn)題六剥,但受限于核的大小,不能考慮整個(gè)上下文峰伙。
在并行方面疗疟,多頭attention和CNN一樣不依賴于前一時(shí)刻的計(jì)算,可以很好的并行瞳氓,優(yōu)于RNN策彤。在長(zhǎng)距離依賴上,由于self-attention是每個(gè)詞和所有詞都要計(jì)算attention匣摘,所以不管他們中間有多長(zhǎng)距離店诗,最大的路徑長(zhǎng)度也都只是1×滴郑可以捕獲長(zhǎng)距離依賴關(guān)系必搞。
二、注意力機(jī)制
注意力實(shí)際就是加權(quán)囊咏。
2.1 NLP中的注意力
以RNN做機(jī)器翻譯為例恕洲,下兩圖[1]分別是有沒(méi)有注意力:
沒(méi)有注意力機(jī)制的機(jī)器翻譯塔橡,翻譯下一詞時(shí),只考慮源語(yǔ)言經(jīng)過(guò)網(wǎng)絡(luò)后最終的表達(dá)(編碼/向量)霜第;而注意力機(jī)制是要考慮源語(yǔ)言中每(多)個(gè)詞的表達(dá)(編碼/向量)葛家。
NLP中有個(gè)非常常見(jiàn)的一個(gè)三元組概念:Query、Key泌类、Value,其中絕大部分情況Key=Value弹砚。在機(jī)器翻譯中,Query是已經(jīng)翻譯出來(lái)的部分桌吃,Key和Value是源語(yǔ)言中每個(gè)詞的表達(dá)(編碼/向量)苞轿,沒(méi)有注意力時(shí)直接拿Query就去預(yù)測(cè)下一個(gè)詞,注意力機(jī)制的計(jì)算就是用Query和Key計(jì)算出一組權(quán)重搬卒,賦權(quán)到Value上,拿Value去預(yù)測(cè)下一詞契邀。
翻譯編碼解碼模型[2]
計(jì)算權(quán)重[2]
加權(quán)[2]
2.2 自注意力
自注意力模型就是Query“=”Key“=”Value蹂安,挖掘一個(gè)句子內(nèi)部的聯(lián)系。計(jì)算句子中每個(gè)字之間的互相影響/權(quán)重田盈,再加權(quán)到句子中每個(gè)字的向量上。這個(gè)計(jì)算就是用了點(diǎn)積允瞧。
Query、Key痹升、Value都來(lái)自同一個(gè)輸入畦韭,但是經(jīng)過(guò)3個(gè)不同線性映射(全連接層)得到疼蛾,所以未必完全相等艺配。
公式中是Query向量和Key向量做點(diǎn)積衍慎,為了防止點(diǎn)積結(jié)果數(shù)值過(guò)大皮钠,做了一個(gè)放縮(
是Key向量的長(zhǎng)度),結(jié)果再經(jīng)過(guò)一個(gè)softmax歸一化成一個(gè)和為1的權(quán)重乔夯,乘到Value向量上。
attention可視化的效果(這里不同顏色代表attention不同頭的結(jié)果末荐,顏色越深attention值越大)【掀溃可以看到self-attention在這里可以學(xué)習(xí)到句子內(nèi)部長(zhǎng)距離依賴"making…….more difficult"這個(gè)短語(yǔ)壕鹉。
2.2.1 點(diǎn)積(Dot-Product)
- 兩向量點(diǎn)積表示兩個(gè)向量的相似度聋涨。
- 點(diǎn)積還有一個(gè)重要的特點(diǎn)是沒(méi)有參數(shù)。
點(diǎn)積也叫點(diǎn)乘牍白,一維點(diǎn)積用幾何表示是: 。與我們常用的余弦相識(shí)度/夾角作用一樣帕胆,與兩向量的相似程度成正比。
2.2.2 具體計(jì)算過(guò)程:
假設(shè)我們句子長(zhǎng)度設(shè)為512懒豹,每個(gè)單詞embedding成256維驯用。
-
Query與Key點(diǎn)積。
Pytorch代碼:
attn = torch.bmm(q, k.transpose(1, 2))
- scale放縮蝴乔、softmax歸一化、dropout隨機(jī)失活/置零
Pytorch代碼:
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
attn = self.dropout(attn)
- 將權(quán)重矩陣加權(quán)到Value上薇正,維度未變化巩剖。
Pytorch代碼:
output = torch.bmm(attn, v)
2.3 多頭注意力
并不是將長(zhǎng)度是512的句子整個(gè)做點(diǎn)積自注意力钠怯,而是將其“拆”成h份,沒(méi)份長(zhǎng)度為512/h鞠鲜,然后每份單獨(dú)去加權(quán)注意力再拼接到一起,Q贤姆、K、V分別拆分霞捡。
“拆”的過(guò)程是一個(gè)獨(dú)立的(different)薄疚、可學(xué)習(xí)的(learned)線性映射。實(shí)際實(shí)現(xiàn)可以是h個(gè)全連接層街夭,每個(gè)全連接層輸入維度是512,輸出512/h板丽;也可以用一個(gè)全連接,輸入輸出均為512埃碱,輸出之后再切成h份。
多頭能夠從不同的表示子空間里學(xué)習(xí)相關(guān)信息啃憎。
在兩個(gè)頭和單頭的比較中瓮具,可以看到單頭"its"這個(gè)詞只能學(xué)習(xí)到"law"的依賴關(guān)系荧飞,而兩個(gè)頭"its"不僅學(xué)習(xí)到了"law"還學(xué)習(xí)到了"application"依賴關(guān)系名党。
Pytorch實(shí)現(xiàn):
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
...
def forward(self, q, k, v, mask=None):
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
三传睹、位置編碼(Positional Encoding)
因?yàn)閠ransformer沒(méi)有RNN和CNN耳幢,為了考慮位置信息,論文中直接將全局位置編號(hào)加到Embedding向量每個(gè)維度上启上。
Pytorch代碼:
# -- Forward
enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)
另外,論文中位置編碼還利用了sin/cos正余弦函數(shù)考慮周期性和歸一化冈在。
四按摘、殘差和前饋(Feed Forward)
4.1 為什么殘差[3]
網(wǎng)絡(luò)的深度為什么重要包券?
因?yàn)镃NN能夠提取low/mid/high-level的特征炫贤,網(wǎng)絡(luò)的層數(shù)越多,意味著能夠提取到不同level的特征越豐富侍郭。并且,越深的網(wǎng)絡(luò)提取的特征越抽象亮元,越具有語(yǔ)義信息口柳。
為什么不能簡(jiǎn)單地增加網(wǎng)絡(luò)層數(shù)苹粟?
對(duì)于原來(lái)的網(wǎng)絡(luò)跃闹,如果簡(jiǎn)單地增加深度毛好,會(huì)導(dǎo)致梯度彌散或梯度爆炸。
對(duì)于該問(wèn)題的解決方法是正則化初始化和中間的正則化層(Batch Normalization)肌访,這樣的話可以訓(xùn)練幾十層的網(wǎng)絡(luò)。
雖然通過(guò)上述方法能夠訓(xùn)練了惩激,但是又會(huì)出現(xiàn)另一個(gè)問(wèn)題,就是退化問(wèn)題风钻,網(wǎng)絡(luò)層數(shù)增加酒请,但是在訓(xùn)練集上的準(zhǔn)確率卻飽和甚至下降了骡技。這個(gè)不能解釋為overfitting,因?yàn)閛verfit應(yīng)該表現(xiàn)為在訓(xùn)練集上表現(xiàn)更好才對(duì)囤萤。
退化問(wèn)題說(shuō)明了深度網(wǎng)絡(luò)不能很簡(jiǎn)單地被很好地優(yōu)化是趴。
作者通過(guò)實(shí)驗(yàn):通過(guò)淺層網(wǎng)絡(luò)+ y=x 等同映射構(gòu)造深層模型涛舍,結(jié)果深層模型并沒(méi)有比淺層網(wǎng)絡(luò)有等同或更低的錯(cuò)誤率唆途,推斷退化問(wèn)題可能是因?yàn)樯顚拥木W(wǎng)絡(luò)并不是那么好訓(xùn)練,也就是求解器很難去利用多層網(wǎng)絡(luò)擬合同等函數(shù)吹榴。
怎么解決退化問(wèn)題?
深度殘差網(wǎng)絡(luò)图筹。如果深層網(wǎng)絡(luò)的后面那些層是恒等映射让腹,那么模型就退化為一個(gè)淺層網(wǎng)絡(luò)远剩。那現(xiàn)在要解決的就是學(xué)習(xí)恒等映射函數(shù)了骇窍。 但是直接讓一些層去擬合一個(gè)潛在的恒等映射函數(shù)H(x) = x,比較困難痢掠,這可能就是深層網(wǎng)絡(luò)難以訓(xùn)練的原因。但是足画,如果把網(wǎng)絡(luò)設(shè)計(jì)為H(x) = F(x) + x,如下圖佃牛。我們可以轉(zhuǎn)換為學(xué)習(xí)一個(gè)殘差函數(shù)F(x) = H(x) - x. 只要F(x)=0,就構(gòu)成了一個(gè)恒等映射H(x) = x. 而且俘侠,擬合殘差肯定更加容易。
4.2 前饋
每個(gè)attention模塊后面會(huì)跟兩個(gè)全連接央星,中間加了一個(gè)Relu激活函數(shù),公式表示:
也可用兩個(gè)核為1的CNN層代替等曼。
兩個(gè)全連接是512->2048->512的操作。原因未詳細(xì)介紹禁谦。
五、訓(xùn)練-模型的參數(shù)在哪里
transformer的核心點(diǎn)積是沒(méi)有參數(shù)丧蘸,transform結(jié)構(gòu)的訓(xùn)練遥皂,會(huì)優(yōu)化的參數(shù)主要在:
- 嵌入層-Word Embedding
- 前饋(Feed Forward)層
- 多頭注意力中的“切片”操作(映射成多個(gè)/頭小向量)實(shí)際是一個(gè)全連接層(線性映射矩陣),以及多頭輸出拼接結(jié)果(Concat)后會(huì)經(jīng)過(guò)一個(gè)Linear全連接層演训。這兩個(gè)全連接層也是殘差塊有意義的地方,如果沒(méi)有這一層样悟,那這個(gè)注意力機(jī)制中就沒(méi)有參數(shù),殘差就沒(méi)有意義了窟她。
六、參考文獻(xiàn)
[1]. Neural Machine Translation by Jointly Learning to Align and Translate
[2]. 殘差的解讀