Pay Attention to MLPs

Pay Attention to MLPs

Ref: https://arxiv.org/pdf/2105.08050.pdf
code:https://github.com/lucidrains/g-mlp-pytorch/blob/54209f0fb2a52557a1c64409f26df9ebd8d5c257/g_mlp_pytorch/g_mlp_pytorch.py

背景

Transformers 自橫空出世以來骗灶,在NLP領(lǐng)域惨恭,大規(guī)模了取代了LSTM-RNN模型,在CV上耙旦,ConvNets也不再是唯一選擇脱羡。它有2個重要特性:

  1. recurrent-free結(jié)構(gòu),可以并行化計(jì)算每個token的表達(dá)免都;
  2. multi-head self-attention blocks锉罐, 可以聚合token之間的空間信息。

其中的attention mechanism一直被認(rèn)為transformers取得優(yōu)秀成績的重要因素绕娘。和MLP相比脓规,attention可以根據(jù)模型輸入判哥,調(diào)整參數(shù)凉敲,而MLP的參數(shù)是固定的。那么問題來了栗恩,transformers效果那么好绢陌,是self-attention起的決定性作用嗎挨下,self-attention是必要的嗎

本文提出了gMLPs脐湾,一種attention-free, 以MLP為基礎(chǔ)的由channel projections臭笆, spatial projections 和gating組成的網(wǎng)絡(luò)結(jié)構(gòu)。

實(shí)驗(yàn)顯示:

  1. 在CV上秤掌,可以達(dá)到和vision transformers差不多的準(zhǔn)確率愁铺;和MLP-Mixer相比,參數(shù)減少66%机杜,準(zhǔn)確率還提升了3%帜讲;
  2. 在NLP上,將gMLPs應(yīng)用到BERT的MLM椒拗,和transformers一樣,在預(yù)訓(xùn)練實(shí)時能最小化perplexity获黔。同時蚀苛,實(shí)驗(yàn)也顯示,perplexity和模型規(guī)模有關(guān)玷氏,而對attention不敏感堵未;
    2.1 隨著模型的capacity上升,gMLPs的預(yù)訓(xùn)練和finetuning指標(biāo)會快速接近Transformers盏触,這意味著渗蟹,只要擴(kuò)大模型規(guī)模块饺,那么無需self-attention,gMLPs和Transformers的差距會不斷縮写蒲俊授艰;
    2.2 batch-size為256,進(jìn)過1Mstep世落,gMLPs相比Bert淮腾,在MNLI達(dá)到了86.4%的準(zhǔn)確率,在SQuAD達(dá)到了89.5%的F1屉佳;
    2.3 在finetuning階段谷朝,模型規(guī)模和perplexity接近的情況下, Transformers在cross-sentence alignment任務(wù)上比gMLPs效果好[MNLI任務(wù)高1.8%]武花。但是圆凰,當(dāng)gMLPs的參數(shù)量是transformers的3倍時,模型效果就很接近体箕;
    2.4 同時专钉,文中提出一個trick,在gMLPs后接一個single-head 128d 的attention干旁,在NLP的各項(xiàng)任務(wù)上驶沼,就超過了transformers。

因此争群,本文覺得回怜,提高數(shù)據(jù)量和算力,無需self-attention换薄,gMLPs玉雾,就可以和transformers媲美。

Model

輸入:序列長度為n轻要,embedding維度為d:
X\in R^{n\times d}

使用L個block复旬,每個block進(jìn)行如下操作:

Z = \sigma (XU) = GeLU(XU)
\tilde Z = s(Z)
Y = \tilde ZV

其中:
U,V為沿著channel[可理解為hidden維度]的線性投影冲泥,同Transformers的FFN驹碍;
s(\cdot)為空間上的交互,用于獲取tokens之間的關(guān)系凡恍。本文認(rèn)為s(\cdot)可以學(xué)習(xí)到位置信息志秃,因此,沒有使用positional embedding嚼酝。

gMLPs

Spatial Gating Unit

為了實(shí)現(xiàn)token之間的交互浮还,在s(\cdot)層,就要包含一個空間維度的交叉操作闽巩。

文中主要介紹了2種SGU:

  1. 比較直觀的钧舌,就是使用線性投影:
    f_{W,b} (Z) = WZ + b

其中:
W\in R^{n\times n}, n為序列長度担汤;b可以是一個矩陣,也可以是一個常量洼冻。
空間交互通過element-wise實(shí)現(xiàn):
s(Z) = Z \odot f_{W,b} (Z)

為確保訓(xùn)練的穩(wěn)定性崭歧,W初始化值接近于0, b為1碘赖。這相當(dāng)于初始化的FFN驾荣,開始每個token相互獨(dú)立,隨著訓(xùn)練逐漸考慮token之間的交互信息普泡。

  1. 除了線性投影的gatef_{W,b} (\cdot), 文中還將Z沿著channel分解成(Z_1,Z_2)播掷,借鑒GLUs的思路:
    s(Z) = Z_1 \odot f_{W,b} (Z_2)

源代碼分析

class SpatialGatingUnit(nn.Module):
    def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
        """dim: embedding size 
            dim_seq: sequence length """
        super().__init__()
        dim_out = dim // 2
        self.causal = causal

        self.norm = nn.LayerNorm(dim_out)
        self.proj = nn.Conv1d(dim_seq, dim_seq, 1) 
        # 常規(guī)卷積,卷積的是詞向量的維度撼班。本文是空間上的信息交互歧匈,因此輸入/輸出通道是序列長度,卷積核尺寸為1砰嘁。

        self.act = act

        init_eps /= dim_seq
        nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
        nn.init.constant_(self.proj.bias, 1.)

    def forward(self, x, gate_res = None):
        device, n = x.device, x.shape[1]

        res, gate = x.chunk(2, dim = -1) #沿著詞向量維度件炉,分成2個矩陣。
        gate = self.norm(gate)

        weight, bias = self.proj.weight, self.proj.bias
        if self.causal:
            weight, bias = weight[:n, :n], bias[:n]
            mask = torch.ones(weight.shape[:2], device = device).triu_(1).bool()
            weight = weight.masked_fill(mask[..., None], 0.)

        gate = F.conv1d(gate, weight, bias)

        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res

GLUs(Gated linear units)補(bǔ)充:

由Language model with gated convolutional network提出矮湘,使用CNN學(xué)習(xí)長文本斟冕,為緩解梯度消散,并保留非線性能力缅阳,使用門控機(jī)制磕蛇。即:
沒有經(jīng)過非線性轉(zhuǎn)換的卷積層輸出*經(jīng)過非線性轉(zhuǎn)換的卷積層輸出
h(x) = (X*W+b)\odot \sigma(X*V + b)

其中:
\odot:element-wise product
X\in R^{N \times m}
W,V \in R^{k \times m \times n}

注意,GLUs是沿著channel維度[per token]的處理十办,而SGU是沿著空間維度[cross-token]的處理秀撇。

Image Classification

在圖片分類ImageNet數(shù)據(jù)集上,無需添加外部數(shù)據(jù)向族,訓(xùn)練gMLPs呵燕。
模型配置如下,輸入和輸出沿用的ViT(Vision Transformer)格式件相,模型的深度和寬度配置也和ViT/DeiT模型相似再扭。
結(jié)果:和Transformer一樣,gMLPs在訓(xùn)練集上過擬合夜矗,因此采用了DeiT的正則化處理(mixup, cutmix)霍衫;同時,對模型的維度做了調(diào)整侯养。


CV gMLPs
ImageNet模型結(jié)果
圖片分類準(zhǔn)確率和模型規(guī)模關(guān)系

Masked Language Modeling with BERT

DepthWise convolution補(bǔ)充

一個卷積核負(fù)責(zé)一個通道,卷積核數(shù)量要和圖片通道數(shù)相同澄干。
f_{W,b}( \cdot)好比一個寬的depthwise convolution逛揩,接收整個句子的信息柠傍。但是depthwise convolution面向的是通道的filter,而gMLPs只使用一個W共享交叉通道辩稽。

在NLP上惧笛,gMLPs進(jìn)行了多個ablation實(shí)驗(yàn)。

1. Ablation:the importance of gating in gMLP for BERT's Pretraining

  1. 使用Bert的absolute position embeddings;
  2. Bert框架 + T5-stype的relative position biases逞泄;
  3. 同1患整,2,但只保留relative positional biases喷众,去掉content-dependent terms inside the softmax各谚。

困惑度:交叉熵的指數(shù)形式。語言模型越好到千,句子概率越大昌渤,熵越小,困惑度越低憔四。

各種模型的perplexity比較

使用SGU可以讓gMLPs得到與Bert差不多的perplexity膀息。

2. Case Study: The Behavior of gMLPs as Model Size Increases

模型規(guī)模和finetuing結(jié)果比較

Transformer中的6+6:self-attention使用6層,F(xiàn)FN使用6層了赵。
finetuning任務(wù)用GLUE表示模型效果潜支。
結(jié)果顯示:

  1. gMLPs越深,pretraining perplexity越小柿汛,和transformer的模型效果越逼近冗酿;
  2. pretraining的perplexity越小,不意味著finetuning結(jié)果越好苛茂,比如gMLPs的perplexity比transformer小的時候已烤,在SST-2的模型結(jié)果更好,但是MNLI-m的模型結(jié)果更差妓羊;

3. Ablation: The Usefulness of Tiny Attention in BERT's Finetuning

文中還做了個測試胯究,在一些下游任務(wù)上,主要是設(shè)計(jì)到句子對的任務(wù)上躁绸,gMLPs表現(xiàn)比Transformers差裕循。 那就再加一個tiny attention,來加強(qiáng)模型對cross-sentence alignment的學(xué)習(xí)净刮。

Hybrid

這種混個gMLPs和attention的模型剥哑,稱為aMLPs。結(jié)果顯示淹父,aMLPs的效果比gMLPs和transformer都要好株婴。


模型比較

4.Main Results for MLM in the BERT Setup

模型效果總結(jié)
  1. 以SQuADv2.0任務(wù)為例,base模型,Bert模型的f1達(dá)到了78.6困介,gMLPs只有70.1大审, 差距8.5%;到了large模型座哩,差距只有81.0-78.3=2.7徒扶;
  2. aMLPs使用128d的attention size,在SQuADv2.0任務(wù)根穷,比Bert還要高4.4%的F1.

前面做的幾個實(shí)驗(yàn)的總結(jié):

  1. 在finetuning階段姜骡,gMLPs不如transformer,但是屿良,隨著模型變大圈澈,和transformer的差距會不斷縮小管引;
  2. aMLPs 不同的attention size(64士败,128),足夠使得模型效果優(yōu)于其他2個褥伴。
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末谅将,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子重慢,更是在濱河造成了極大的恐慌饥臂,老刑警劉巖,帶你破解...
    沈念sama閱讀 211,376評論 6 491
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件似踱,死亡現(xiàn)場離奇詭異隅熙,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)核芽,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,126評論 2 385
  • 文/潘曉璐 我一進(jìn)店門囚戚,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人轧简,你說我怎么就攤上這事驰坊。” “怎么了哮独?”我有些...
    開封第一講書人閱讀 156,966評論 0 347
  • 文/不壞的土叔 我叫張陵拳芙,是天一觀的道長。 經(jīng)常有香客問我皮璧,道長舟扎,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 56,432評論 1 283
  • 正文 為了忘掉前任悴务,我火速辦了婚禮睹限,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己邦泄,他們只是感情好删窒,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,519評論 6 385
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著顺囊,像睡著了一般。 火紅的嫁衣襯著肌膚如雪蕉拢。 梳的紋絲不亂的頭發(fā)上特碳,一...
    開封第一講書人閱讀 49,792評論 1 290
  • 那天,我揣著相機(jī)與錄音晕换,去河邊找鬼午乓。 笑死,一個胖子當(dāng)著我的面吹牛闸准,可吹牛的內(nèi)容都是我干的益愈。 我是一名探鬼主播,決...
    沈念sama閱讀 38,933評論 3 406
  • 文/蒼蘭香墨 我猛地睜開眼夷家,長吁一口氣:“原來是場噩夢啊……” “哼蒸其!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起库快,我...
    開封第一講書人閱讀 37,701評論 0 266
  • 序言:老撾萬榮一對情侶失蹤摸袁,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后义屏,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體靠汁,經(jīng)...
    沈念sama閱讀 44,143評論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,488評論 2 327
  • 正文 我和宋清朗相戀三年闽铐,在試婚紗的時候發(fā)現(xiàn)自己被綠了蝶怔。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,626評論 1 340
  • 序言:一個原本活蹦亂跳的男人離奇死亡兄墅,死狀恐怖踢星,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情察迟,我是刑警寧澤斩狱,帶...
    沈念sama閱讀 34,292評論 4 329
  • 正文 年R本政府宣布,位于F島的核電站扎瓶,受9級特大地震影響所踊,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜概荷,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,896評論 3 313
  • 文/蒙蒙 一秕岛、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦继薛、人聲如沸修壕。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,742評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽慈鸠。三九已至,卻和暖如春灌具,著一層夾襖步出監(jiān)牢的瞬間青团,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,977評論 1 265
  • 我被黑心中介騙來泰國打工咖楣, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留督笆,地道東北人。 一個月前我還...
    沈念sama閱讀 46,324評論 2 360
  • 正文 我出身青樓诱贿,卻偏偏與公主長得像娃肿,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子珠十,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,494評論 2 348

推薦閱讀更多精彩內(nèi)容