Ref: https://arxiv.org/pdf/2105.08050.pdf
code:https://github.com/lucidrains/g-mlp-pytorch/blob/54209f0fb2a52557a1c64409f26df9ebd8d5c257/g_mlp_pytorch/g_mlp_pytorch.py
背景
Transformers 自橫空出世以來骗灶,在NLP領(lǐng)域惨恭,大規(guī)模了取代了LSTM-RNN模型,在CV上耙旦,ConvNets也不再是唯一選擇脱羡。它有2個重要特性:
- recurrent-free結(jié)構(gòu),可以并行化計(jì)算每個token的表達(dá)免都;
- multi-head self-attention blocks锉罐, 可以聚合token之間的空間信息。
其中的attention mechanism一直被認(rèn)為transformers取得優(yōu)秀成績的重要因素绕娘。和MLP相比脓规,attention可以根據(jù)模型輸入判哥,調(diào)整參數(shù)凉敲,而MLP的參數(shù)是固定的。那么問題來了栗恩,transformers效果那么好绢陌,是self-attention起的決定性作用嗎挨下,self-attention是必要的嗎?
本文提出了gMLPs脐湾,一種attention-free, 以MLP為基礎(chǔ)的由channel projections臭笆, spatial projections 和gating組成的網(wǎng)絡(luò)結(jié)構(gòu)。
實(shí)驗(yàn)顯示:
- 在CV上秤掌,可以達(dá)到和vision transformers差不多的準(zhǔn)確率愁铺;和MLP-Mixer相比,參數(shù)減少66%机杜,準(zhǔn)確率還提升了3%帜讲;
- 在NLP上,將gMLPs應(yīng)用到BERT的MLM椒拗,和transformers一樣,在預(yù)訓(xùn)練實(shí)時能最小化perplexity获黔。同時蚀苛,實(shí)驗(yàn)也顯示,perplexity和模型規(guī)模有關(guān)玷氏,而對attention不敏感堵未;
2.1 隨著模型的capacity上升,gMLPs的預(yù)訓(xùn)練和finetuning指標(biāo)會快速接近Transformers盏触,這意味著渗蟹,只要擴(kuò)大模型規(guī)模块饺,那么無需self-attention,gMLPs和Transformers的差距會不斷縮写蒲俊授艰;
2.2 batch-size為256,進(jìn)過1Mstep世落,gMLPs相比Bert淮腾,在MNLI達(dá)到了86.4%的準(zhǔn)確率,在SQuAD達(dá)到了89.5%的F1屉佳;
2.3 在finetuning階段谷朝,模型規(guī)模和perplexity接近的情況下, Transformers在cross-sentence alignment任務(wù)上比gMLPs效果好[MNLI任務(wù)高1.8%]武花。但是圆凰,當(dāng)gMLPs的參數(shù)量是transformers的3倍時,模型效果就很接近体箕;
2.4 同時专钉,文中提出一個trick,在gMLPs后接一個single-head 128d 的attention干旁,在NLP的各項(xiàng)任務(wù)上驶沼,就超過了transformers。
因此争群,本文覺得回怜,提高數(shù)據(jù)量和算力,無需self-attention换薄,gMLPs玉雾,就可以和transformers媲美。
Model
輸入:序列長度為n轻要,embedding維度為d:
使用L個block复旬,每個block進(jìn)行如下操作:
其中:
U,V為沿著channel[可理解為hidden維度]的線性投影冲泥,同Transformers的FFN驹碍;
為空間上的交互,用于獲取tokens之間的關(guān)系凡恍。本文認(rèn)為可以學(xué)習(xí)到位置信息志秃,因此,沒有使用positional embedding嚼酝。
Spatial Gating Unit
為了實(shí)現(xiàn)token之間的交互浮还,在層,就要包含一個空間維度的交叉操作闽巩。
文中主要介紹了2種SGU:
- 比較直觀的钧舌,就是使用線性投影:
其中:
, n為序列長度担汤;b可以是一個矩陣,也可以是一個常量洼冻。
空間交互通過element-wise實(shí)現(xiàn):
為確保訓(xùn)練的穩(wěn)定性崭歧,W初始化值接近于0, b為1碘赖。這相當(dāng)于初始化的FFN驾荣,開始每個token相互獨(dú)立,隨著訓(xùn)練逐漸考慮token之間的交互信息普泡。
- 除了線性投影的gate, 文中還將Z沿著channel分解成播掷,借鑒GLUs的思路:
源代碼分析
class SpatialGatingUnit(nn.Module):
def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
"""dim: embedding size
dim_seq: sequence length """
super().__init__()
dim_out = dim // 2
self.causal = causal
self.norm = nn.LayerNorm(dim_out)
self.proj = nn.Conv1d(dim_seq, dim_seq, 1)
# 常規(guī)卷積,卷積的是詞向量的維度撼班。本文是空間上的信息交互歧匈,因此輸入/輸出通道是序列長度,卷積核尺寸為1砰嘁。
self.act = act
init_eps /= dim_seq
nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
nn.init.constant_(self.proj.bias, 1.)
def forward(self, x, gate_res = None):
device, n = x.device, x.shape[1]
res, gate = x.chunk(2, dim = -1) #沿著詞向量維度件炉,分成2個矩陣。
gate = self.norm(gate)
weight, bias = self.proj.weight, self.proj.bias
if self.causal:
weight, bias = weight[:n, :n], bias[:n]
mask = torch.ones(weight.shape[:2], device = device).triu_(1).bool()
weight = weight.masked_fill(mask[..., None], 0.)
gate = F.conv1d(gate, weight, bias)
if exists(gate_res):
gate = gate + gate_res
return self.act(gate) * res
GLUs(Gated linear units)補(bǔ)充:
由Language model with gated convolutional network提出矮湘,使用CNN學(xué)習(xí)長文本斟冕,為緩解梯度消散,并保留非線性能力缅阳,使用門控機(jī)制磕蛇。即:
沒有經(jīng)過非線性轉(zhuǎn)換的卷積層輸出*經(jīng)過非線性轉(zhuǎn)換的卷積層輸出
其中:
:element-wise product
注意,GLUs是沿著channel維度[per token]的處理十办,而SGU是沿著空間維度[cross-token]的處理秀撇。
Image Classification
在圖片分類ImageNet數(shù)據(jù)集上,無需添加外部數(shù)據(jù)向族,訓(xùn)練gMLPs呵燕。
模型配置如下,輸入和輸出沿用的ViT(Vision Transformer)格式件相,模型的深度和寬度配置也和ViT/DeiT模型相似再扭。
結(jié)果:和Transformer一樣,gMLPs在訓(xùn)練集上過擬合夜矗,因此采用了DeiT的正則化處理(mixup, cutmix)霍衫;同時,對模型的維度做了調(diào)整侯养。
Masked Language Modeling with BERT
DepthWise convolution補(bǔ)充
一個卷積核負(fù)責(zé)一個通道,卷積核數(shù)量要和圖片通道數(shù)相同澄干。
好比一個寬的depthwise convolution逛揩,接收整個句子的信息柠傍。但是depthwise convolution面向的是通道的filter,而gMLPs只使用一個W共享交叉通道辩稽。
在NLP上惧笛,gMLPs進(jìn)行了多個ablation實(shí)驗(yàn)。
1. Ablation:the importance of gating in gMLP for BERT's Pretraining
- 使用Bert的absolute position embeddings;
- Bert框架 + T5-stype的relative position biases逞泄;
- 同1患整,2,但只保留relative positional biases喷众,去掉content-dependent terms inside the softmax各谚。
困惑度:交叉熵的指數(shù)形式。語言模型越好到千,句子概率越大昌渤,熵越小,困惑度越低憔四。
使用SGU可以讓gMLPs得到與Bert差不多的perplexity膀息。
2. Case Study: The Behavior of gMLPs as Model Size Increases
Transformer中的6+6:self-attention使用6層,F(xiàn)FN使用6層了赵。
finetuning任務(wù)用GLUE表示模型效果潜支。
結(jié)果顯示:
- gMLPs越深,pretraining perplexity越小柿汛,和transformer的模型效果越逼近冗酿;
- pretraining的perplexity越小,不意味著finetuning結(jié)果越好苛茂,比如gMLPs的perplexity比transformer小的時候已烤,在SST-2的模型結(jié)果更好,但是MNLI-m的模型結(jié)果更差妓羊;
3. Ablation: The Usefulness of Tiny Attention in BERT's Finetuning
文中還做了個測試胯究,在一些下游任務(wù)上,主要是設(shè)計(jì)到句子對的任務(wù)上躁绸,gMLPs表現(xiàn)比Transformers差裕循。 那就再加一個tiny attention,來加強(qiáng)模型對cross-sentence alignment的學(xué)習(xí)净刮。
這種混個gMLPs和attention的模型剥哑,稱為aMLPs。結(jié)果顯示淹父,aMLPs的效果比gMLPs和transformer都要好株婴。
4.Main Results for MLM in the BERT Setup
- 以SQuADv2.0任務(wù)為例,base模型,Bert模型的f1達(dá)到了78.6困介,gMLPs只有70.1大审, 差距8.5%;到了large模型座哩,差距只有81.0-78.3=2.7徒扶;
- aMLPs使用128d的attention size,在SQuADv2.0任務(wù)根穷,比Bert還要高4.4%的F1.
前面做的幾個實(shí)驗(yàn)的總結(jié):
- 在finetuning階段姜骡,gMLPs不如transformer,但是屿良,隨著模型變大圈澈,和transformer的差距會不斷縮小管引;
- aMLPs 不同的attention size(64士败,128),足夠使得模型效果優(yōu)于其他2個褥伴。