DeiT:注意力也能蒸餾
《Training data-ef?cient image transformers & distillation through attention》
ViT 在大數(shù)據(jù)集 ImageNet-21k(14million)或者 JFT-300M(300million) 上進(jìn)行訓(xùn)練,Batch Size 128 下 NVIDIA A100 32G GPU 的計算資源加持下預(yù)訓(xùn)練 ViT-Base/32 需要3天時間。
Facebook 與索邦大學(xué) Matthieu Cord 教授合作發(fā)表 Training data-efficient image transformers(DeiT) & distillation through attention朱盐,DeiT 模型(8600萬參數(shù))僅用一臺 GPU 服務(wù)器在 53 hours train台囱,20 hours finetune扑庞,僅使用 ImageNet 就達(dá)到了 84.2 top-1 準(zhǔn)確性,而無需使用任何外部數(shù)據(jù)進(jìn)行訓(xùn)練拒逮。性能與最先進(jìn)的卷積神經(jīng)網(wǎng)絡(luò)(CNN)可以抗衡罐氨。所以呢,很有必要講講這個 DeiT 網(wǎng)絡(luò)模型的相關(guān)內(nèi)容滩援。
下面來簡單總結(jié) DeiT:
DeiT 是一個全 Transformer 的架構(gòu)栅隐。其核心是提出了針對 ViT 的教師-學(xué)生蒸餾訓(xùn)練策略,并提出了 token-based distillation 方法玩徊,使得 Transformer 在視覺領(lǐng)域訓(xùn)練得又快又好租悄。
DeiT 相關(guān)背景
ViT 文中表示數(shù)據(jù)量不足會導(dǎo)致 ViT 效果變差。針對以上問題恩袱,DeiT 核心共享是使用了蒸餾策略恰矩,能夠僅使用 ImageNet-1K 數(shù)據(jù)集就就可以達(dá)到 83.1% 的 Top1。
那么文章主要貢獻(xiàn)可以總結(jié)為三點(diǎn):
- 僅使用 Transformer憎蛤,不引入 Conv 的情況下也能達(dá)到 SOTA 效果外傅。
- 提出了基于 token 蒸餾的策略,針對 Transformer 蒸餾方法超越傳統(tǒng)蒸餾方法俩檬。
- DeiT 發(fā)現(xiàn)使用 Convnet 作為教師網(wǎng)絡(luò)能夠比使用 Transformer 架構(gòu)效果更好萎胰。
正式了解 DeiT 算法之前呢,有幾個問題需要去了解的:ViT的缺點(diǎn)和局限性棚辽,為什么訓(xùn)練ViT要準(zhǔn)備這么多數(shù)據(jù)技竟,就不能簡單快速訓(xùn)練一個模型出來嗎?另外 Transformer 視覺模型又怎么玩蒸餾呢屈藐?
ViT 的缺點(diǎn)和局限性
Transformer的輸入是一個序列(Sequence)榔组,ViT 所采用的思路是把圖像分塊(patches),然后把每一塊視為一個向量(vector)联逻,所有的向量并在一起就成為了一個序列(Sequence)搓扯,ViT 使用的數(shù)據(jù)集包括了一個巨大的包含了 300 million images的 JFT-300,這個數(shù)據(jù)集是私有的包归,即外部研究者無法復(fù)現(xiàn)實(shí)驗(yàn)锨推。而且在ViT的實(shí)驗(yàn)中作者明確地提到:
"That transformers do not generalize well when trained on insufficient amounts of data."
意思是當(dāng)不使用 JFT-300 大數(shù)據(jù)集時,效果不如CNN模型公壤。也就反映出Transformer結(jié)構(gòu)若想取得理想的性能和泛化能力就需要這樣大的數(shù)據(jù)集换可。DeiT 作者通過所提出的蒸餾的訓(xùn)練方案,只在 Imagenet 上進(jìn)行訓(xùn)練厦幅,就產(chǎn)生了一個有競爭力的無卷積 Transformer沾鳄。
ViT 相關(guān)技術(shù)點(diǎn)
Multi-head Self Attention layers (MSA):
首先有一個 Query 矩陣 Q 和一個 Key 矩陣 K,把二者矩陣乘在一起并進(jìn)行歸一化以后得到 attention 矩陣确憨,它再與Value矩陣 V 相乘得到最終的輸出得到 Z译荞。最后經(jīng)過 linear transformation 得到 NxD 的輸出結(jié)果瓤的。
Feed-Forward Network (FFN):
Multi-head Self Attention layers 之后往往會跟上一個 Feed-Forward Network (FFN) ,它一般是由2個linear layer構(gòu)成磁椒,第1個linear layer把維度從 D 維變換到 ND 維堤瘤,第2個linear layer把維度從 ND 維再變換到 D 維。
此時 Transformer block 是不考慮位置信息的浆熔,基于此 ViT 加入了位置編碼 (Positional Encoding)本辐,這些編碼在第一個 block 之前被添加到 input token 中代表位置信息,作為額外可學(xué)習(xí)的embedding(Extra learnable class embedding)医增。
Class token:
Class token 與 input token 并在一起輸入 Transformer block 中慎皱,最后的輸出結(jié)果用來預(yù)測類別。這樣一來叶骨,Transformer 相當(dāng)于一共處理了 N+1 個維度為 D 的token茫多,并且只有第一個 token 的輸出用來預(yù)測類別。
知識蒸餾介紹
Knowledge Distillation(KD)最初被 Hinton 提出 “Distilling the Knowledge in a Neural Network”忽刽,與 Label smoothing 動機(jī)類似天揖,但是 KD 生成 soft label 的方式是通過教師網(wǎng)絡(luò)得到的。
KD 可以視為將教師網(wǎng)絡(luò)學(xué)到的信息壓縮到學(xué)生網(wǎng)絡(luò)中跪帝。還有一些工作 “Circumventing outlier of autoaugment with knowledge distillation” 則將 KD 視為數(shù)據(jù)增強(qiáng)方法的一種今膊。
提出背景
雖然在一般情況下,我們不會去區(qū)分訓(xùn)練和部署使用的模型伞剑,但是訓(xùn)練和部署之間存在著一定的不一致性斑唬。在訓(xùn)練過程中,我們需要使用復(fù)雜的模型黎泣,大量的計算資源恕刘,以便從非常大照雁、高度冗余的數(shù)據(jù)集中提取出信息矮湘。在實(shí)驗(yàn)中,效果最好的模型往往規(guī)模很大豆挽,甚至由多個模型集成得到衡便。而大模型不方便部署到服務(wù)中去献起,常見的瓶頸如下:
- 推理速度和性能慢
- 對部署資源要求高(內(nèi)存,顯存等)
在部署時镣陕,對延遲以及計算資源都有著嚴(yán)格的限制。因此姻政,模型壓縮(在保證性能的前提下減少模型的參數(shù)量)成為了一個重要的問題呆抑,而“模型蒸餾”屬于模型壓縮的一種方法。
理論原理
知識蒸餾使用的是 Teacher—Student 模型汁展,其中 Teacher 是“知識”的輸出者鹊碍,Student 是“知識”的接受者厌殉。知識蒸餾的過程分為2個階段:
- 原始模型訓(xùn)練: 訓(xùn)練 "Teacher模型", 簡稱為Net-T,它的特點(diǎn)是模型相對復(fù)雜侈咕,也可以由多個分別訓(xùn)練的模型集成而成公罕。我們對"Teacher模型"不作任何關(guān)于模型架構(gòu)、參數(shù)量耀销、是否集成方面的限制楼眷,唯一的要求就是,對于輸入X, 其都能輸出Y熊尉,其中Y經(jīng)過softmax的映射罐柳,輸出值對應(yīng)相應(yīng)類別的概率值。
- 精簡模型訓(xùn)練: 訓(xùn)練"Student模型", 簡稱為Net-S狰住,它是參數(shù)量較小张吉、模型結(jié)構(gòu)相對簡單的單模型。同樣的催植,對于輸入X肮蛹,其都能輸出Y,Y經(jīng)過softmax映射后同樣能輸出對應(yīng)相應(yīng)類別的概率值创南。
論文中伦忠,Hinton 將問題限定在分類問題下,或者其他本質(zhì)上屬于分類問題的問題扰藕,該類問題的共同點(diǎn)是模型最后會有一個softmax層缓苛,其輸出值對應(yīng)了相應(yīng)類別的概率值。知識蒸餾時邓深,由于已經(jīng)有了一個泛化能力較強(qiáng)的Net-T未桥,我們在利用Net-T來蒸餾訓(xùn)練Net-S時,可以直接讓Net-S去學(xué)習(xí)Net-T的泛化能力芥备。
其中KD的訓(xùn)練過程和傳統(tǒng)的訓(xùn)練過程的對比:
- 傳統(tǒng)training過程 Hard Targets: 對 ground truth 求極大似然 Softmax 值冬耿。
- KD的training過程 Soft Targets: 用 Teacher 模型的 class probabilities作為soft targets。
這就解釋了為什么通過蒸餾的方法訓(xùn)練出的 Net-S 相比使用完全相同的模型結(jié)構(gòu)和訓(xùn)練數(shù)據(jù)只使用Hard Targets的訓(xùn)練方法得到的模型萌壳,擁有更好的泛化能力亦镶。
具體方法
第一步是訓(xùn)練Net-T;第二步是在高溫 T 下袱瓮,蒸餾 Net-T 的知識到 Net-S缤骨。
訓(xùn)練 Net-T 的過程很簡單,而高溫蒸餾過程的目標(biāo)函數(shù)由distill loss(對應(yīng)soft target)和student loss(對應(yīng)hard target)加權(quán)得到:
Deit 中使用 Conv-Based 架構(gòu)作為教師網(wǎng)絡(luò)尺借,以 soft 的方式將歸納偏置傳遞給學(xué)生模型绊起,將局部性的假設(shè)通過蒸餾方式引入 Transformer 中,取得了不錯的效果燎斩。
DeiT 具體方法
為什么DeiT能在大幅減少 1. 訓(xùn)練所需的數(shù)據(jù)集 和 2. 訓(xùn)練時長 的情況下依舊能夠取得很不錯的性能呢虱歪?我們可以把這個原因歸結(jié)為DeiT的訓(xùn)練策略蜂绎。ViT 在小數(shù)據(jù)集上的性能不如使用CNN網(wǎng)絡(luò) EfficientNet,但是跟ViT結(jié)構(gòu)相同笋鄙,僅僅是使用更好的訓(xùn)練策略的DeiT比ViT的性能已經(jīng)有了很大的提升师枣,在此基礎(chǔ)上,再加上蒸餾 (distillation) 操作萧落,性能超過了 EfficientNet践美。
假設(shè)有一個性能很好的分類器作為teacher model,通過引入了一個 Distillation Token铐尚,然后在 self-attention layers 中跟 class token拨脉,patch token 在 Transformer 結(jié)構(gòu)中不斷學(xué)習(xí)。
Class token的目標(biāo)是跟真實(shí)的label一致宣增,而Distillation Token是要跟teacher model預(yù)測的label一致玫膀。
對比 ViT 的輸出是一個 softmax,它代表著預(yù)測結(jié)果屬于各個類別的概率的分布爹脾。ViT的做法是直接將 softmax 與 GT label取 CE Loss帖旨。
而在 DeiT 中,除了 CE Loss 以外灵妨,還要 1)定義蒸餾損失解阅;2)加上 Distillation Token。
- 定義蒸餾損失
蒸餾分兩種泌霍,一種是軟蒸餾(soft distillation)货抄,另一種是硬蒸餾(hard distillation)。軟蒸餾如下式所示朱转,Z_s 和 Z_t 分別是 student model 和 teacher model 的輸出蟹地,KL 表示 KL 散度,psi 表示softmax函數(shù)藤为,lambda 和 tau 是超參數(shù):
硬蒸餾如下式所示怪与,其中 CE 表示交叉熵:
學(xué)生網(wǎng)絡(luò)的輸出 Z_s 與真實(shí)標(biāo)簽之間計算 CE Loss 。如果是硬蒸餾缅疟,就再與教師網(wǎng)絡(luò)的標(biāo)簽取 CE Loss分别。如果是軟蒸餾,就再與教師網(wǎng)絡(luò)的 softmax 輸出結(jié)果取 KL Loss 存淫。
值得注意的是耘斩,Hard Label 也可以通過標(biāo)簽平滑技術(shù) (Label smoothing) 轉(zhuǎn)換成Soft Labe,其中真值對應(yīng)的標(biāo)簽被認(rèn)為具有 1- esilon 的概率桅咆,剩余的 esilon 由剩余的類別共享煌往。
- 加入 Distillation Token
Distillation Token 和 ViT 中的 class token 一起加入 Transformer 中,和class token 一樣通過 self-attention 與其它的 embedding 一起計算轧邪,并且在最后一層之后由網(wǎng)絡(luò)輸出刽脖。
而 Distillation Token 對應(yīng)的這個輸出的目標(biāo)函數(shù)就是蒸餾損失。Distillation Token 允許模型從教師網(wǎng)絡(luò)的輸出中學(xué)習(xí)忌愚,就像在常規(guī)的蒸餾中一樣曲管,同時也作為一種對class token的補(bǔ)充。
DeiT 具體實(shí)驗(yàn)
實(shí)驗(yàn)參數(shù)的設(shè)置:圖中表示不同大小的 DeiT 結(jié)構(gòu)的超參數(shù)設(shè)置硕糊,最大的結(jié)構(gòu)是 DeiT-B院水,與 ViT-B 結(jié)構(gòu)是相同,唯一不同的是 embedding 的 hidden dimension 和 head 數(shù)量简十。作者保持了每個head的隱變量維度為64檬某,throughput是一個衡量DeiT模型處理圖片速度的變量,代表每秒能夠處理圖片的數(shù)目螟蝙。
- Teacher model對比
作者首先觀察到使用 CNN 作為 teacher 比 transformer 作為 teacher 的性能更優(yōu)恢恼。下圖中對比了 teacher 網(wǎng)絡(luò)使用 DeiT-B 和幾個 CNN 模型 RegNetY 時,得到的 student 網(wǎng)絡(luò)的預(yù)訓(xùn)練性能以及 finetune 之后的性能胰默。
其中场斑,DeiT-B 384 代表使用分辨率為 384×384 的圖像 finetune 得到的模型,最后的那個小蒸餾符號 alembic sign 代表蒸餾以后得到的模型牵署。
- 蒸餾方法對比
下圖是不同蒸餾策略的性能對比漏隐,label 代表有監(jiān)督學(xué)習(xí),前3行分別是不使用蒸餾奴迅,使用soft蒸餾和使用hard蒸餾的性能對比青责。前3行不使用 Distillation Token 進(jìn)行訓(xùn)練,只是相當(dāng)于在原來 ViT 的基礎(chǔ)上給損失函數(shù)加上了蒸餾部分取具。
對于Transformer來講脖隶,硬蒸餾的性能明顯優(yōu)于軟蒸餾,即使只使用 class token者填,不使用 distill token浩村,硬蒸餾達(dá)到 83.0%,而軟蒸餾的精度為 81.8%占哟。
從最后兩列 B224 和 B384 看出心墅,以更高的分辨率進(jìn)行微調(diào)有助于減少方法之間的差異。這可能是因?yàn)樵谖⒄{(diào)時榨乎,作者不使用教師信息怎燥。隨著微調(diào),class token 和 Distillation Token 之間的相關(guān)性略有增加蜜暑。
除此之外铐姚,蒸餾模型在 accuracy 和 throughput 之間的 trade-off 甚至優(yōu)于 teacher 模型,這也反映了蒸餾的有趣之處。
- 性能對比
下面是不同模型性能的數(shù)值比較隐绵≈冢可以發(fā)現(xiàn)在參數(shù)量相當(dāng)?shù)那闆r下,卷積網(wǎng)絡(luò)的速度更慢依许,這是因?yàn)榇蟮木仃嚦朔ū刃【矸e提供了更多的優(yōu)化機(jī)會棺禾。EffcientNet-B4和DeiT-B alembic sign的速度相似,在3個數(shù)據(jù)集的性能也比較接近峭跳。
- 對比實(shí)驗(yàn)
作者還做了一些關(guān)于數(shù)據(jù)增強(qiáng)方法和優(yōu)化器的對比實(shí)驗(yàn)膘婶。Transformer的訓(xùn)練需要大量的數(shù)據(jù),想要在不太大的數(shù)據(jù)集上取得好性能蛀醉,就需要大量的數(shù)據(jù)增強(qiáng)悬襟,以實(shí)現(xiàn)data-efficient training。幾乎所有評測過的數(shù)據(jù)增強(qiáng)的方法都能提升性能拯刁。對于優(yōu)化器來說脊岳,AdamW比SGD性能更好。
此外筛璧,發(fā)現(xiàn)Transformer對優(yōu)化器的超參數(shù)很敏感逸绎,試了多組 lr 和 weight+decay。stochastic depth有利于收斂夭谤。Mixup 和 CutMix 都能提高性能棺牧。Exp.+Moving+Avg. 表示參數(shù)平滑后的模型,對性能提升只是略有幫助朗儒。最后就是 Repeated augmentation 的數(shù)據(jù)增強(qiáng)方式對于性能提升幫助很大颊乘。
小結(jié)
DeiT 模型(8600萬參數(shù))僅用一臺 GPU 服務(wù)器在 53 hours train,20 hours finetune醉锄,僅使用 ImageNet 就達(dá)到了 84.2 top-1 準(zhǔn)確性乏悄,而無需使用任何外部數(shù)據(jù)進(jìn)行訓(xùn)練,性能與最先進(jìn)的卷積神經(jīng)網(wǎng)絡(luò)(CNN)可以抗衡恳不。其核心是提出了針對 ViT 的教師-學(xué)生蒸餾訓(xùn)練策略檩小,并提出了 token-based distillation 方法,使得 Transformer 在視覺領(lǐng)域訓(xùn)練得又快又好烟勋。
引用
[1] https://zhuanlan.zhihu.com/p/349315675
[2] DeiT:使用Attention蒸餾Transformer
[3] https://zhuanlan.zhihu.com/p/102038521
[4] Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. "Distilling the knowledge in a neural network." arXiv preprint arXiv:1503.02531 2.7 (2015).
[5] Touvron, Hugo, et al. "Training data-efficient image transformers & distillation through attention." International Conference on Machine Learning. PMLR, 2021.
[6] Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).
[7] Wei, Longhui, et al. "Circumventing outliers of autoaugment with knowledge distillation." European Conference on Computer Vision. Springer, Cham, 2020.