來源:https://www.researchgate.net/publication/364419868_The_Devil_in_Linear_Transformer
代碼:https://github.com/OpenNLPLab/Transnormer
這篇文章的目的是優(yōu)化線性transformer,線性transformer相對于標(biāo)準(zhǔn)transformer能夠?qū)⒂?jì)算復(fù)雜度從 降到
. 但線性transformer 相對于標(biāo)準(zhǔn)transformer 往往存在著較明顯的指標(biāo)gap程储。作者分析認(rèn)為原因有兩點(diǎn):
- unbounded gradients。無邊界梯度摊灭,會導(dǎo)致模型在訓(xùn)練時(shí)不穩(wěn)定败徊,收斂不好;
- attention dilution萝挤。注意力稀釋根欧,transformer在lower level時(shí)應(yīng)該更關(guān)注局部特征端蛆,而higher level更關(guān)注全局特征,但線性transformer中的attention往往weight 更均勻化今豆,不能聚焦在local區(qū)域上,因此稱為attention稀釋异逐。
針對于上述兩點(diǎn)插掂,作者提出了NormAttention和DiagAttention兩個(gè)模塊,形成NormFormer的結(jié)構(gòu)酝润。
1.The devil in linear attention
我們首先來看一下作者分析的線性transformer存在的兩點(diǎn)缺陷的結(jié)論是怎么來的璃弄。
1.1 Unbounded gradients
在標(biāo)準(zhǔn)的attention結(jié)構(gòu)中
正是這里的 帶來的
的計(jì)算復(fù)雜度夏块。而為了解決這個(gè)問題目前主要包含兩類: 基于pattern的方法和基于kernel的方法。
基于pattern的方式主要是通過一些先驗(yàn)篩選key或query浑塞,降低計(jì)算復(fù)雜度患民;而基于kernel的方法則是本文提到的線性transformer,通過核函數(shù)去取代softmax仅孩,從而能夠通過矩陣乘法結(jié)合律降低計(jì)算復(fù)雜度。
那么來看一下計(jì)算attention時(shí)京腥,vanilla和linear transformer的統(tǒng)一形式:
對于vanilla transformer而言溅蛉, , 對于linear transformer可以表示為
. 于是可以比較一下兩者的梯度:
vanilla attention: , 這里推理的時(shí)候注意湊
這里推理的時(shí)候只有 時(shí)邊界值成立欠气,所以最終
linear attention: 線性attention的關(guān)鍵在于, 因此
即预柒,
.
因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=s_%7Bik%7D%20%3D%20%5Cphi(q_i)%5Cphi(q_k)%5ET" alt="s_{ik} = \phi(q_i)\phi(q_k)^T" mathimg="1"> 大小是不確定的袁梗,所以相當(dāng)于linear attention的梯度是無邊界的。這就會導(dǎo)致收斂不穩(wěn)定淋袖,收斂難度大等問題锯梁。
1.2 Attention dilution
注意力稀釋方面,作者直接評估了不同level上拜姿,每一個(gè)query在鄰域內(nèi)的其他query上的attention的權(quán)重占比冯遂,這里需要注意的是蛤肌,query之間是有序的,即對于NLP或者featmap而言裸准,是有固定結(jié)構(gòu)的,才可以這么評估盐肃。表示第i個(gè)query在其
個(gè)鄰域query上的attention之和,可以看下圖推盛,a圖中transformer和linear transformer相比谦铃,顯然linear transformer的聚集度要小很多驹闰。這就是所謂的注意力稀釋。
2. architecture
針對于1中的兩個(gè)問題师妙,有針對性的設(shè)計(jì)了兩個(gè)模塊骡显。
2.1 NormAttention.
作者提出的解決方案
,
這里的XNorm 可以是Layernorm,也可以是 RMSNorm。注意這里的Q溜歪,和K是有激活函數(shù)的许蓖,公式?jīng)]寫,但圖中畫了自阱。
文章證明這個(gè)做法梯度是有上界的米酬。附錄的證明過程有點(diǎn)復(fù)雜。
2.2 DiagAttention
這個(gè)模塊其實(shí)就是一種基于pattern的attention加派,將query按距離劃分不重疊的window跳芳,每個(gè)window內(nèi)進(jìn)行 attention的計(jì)算。奇怪的是 這里的attention使用的都是vanilla attention娄琉。
下圖是文章方法TransNormer的結(jié)構(gòu):
3. 實(shí)驗(yàn)
實(shí)驗(yàn)都是在NLP上做的孽水,不大了解,因此不做分析丧慈,這里只看下消融實(shí)驗(yàn)的結(jié)論主卫。
table8. 表明早期的stage應(yīng)當(dāng)更關(guān)注局部特征,而后期的stage則應(yīng)該更關(guān)注全局信息完域。
table9. 早期適合使用blockattn瘩将,后期適合使用normattn
table10. FFN中作者對比了FFN和GLU的結(jié)果,發(fā)現(xiàn)GLU效果會更好一些肠仪。
table11.表明diagattn中的window的大小异旧,這個(gè)其實(shí)有有點(diǎn)說不通提佣,如果DiagAttn使用的linear attention, block size越大不是attention 稀釋的越嚴(yán)重嗎潮针? 這個(gè)地方DiagAttn使用的應(yīng)該都是vanilla attention倚喂,包括softmax attention和ReLA attention.
4. 結(jié)論
本文提出的norm attention其實(shí)在很多其他方法中都見過,而且所謂的diag attention使用的還是vanilla attention雳攘,并沒有把linear attention應(yīng)用到diag block里,感覺不是很充實(shí)吨灭。值得學(xué)習(xí)的是本文中提出的梯度分析的方法刑巧。