DeiT:注意力也能蒸餾

DeiT:注意力也能蒸餾

《Training data-ef?cient image transformers & distillation through attention》

ViT 在大數(shù)據(jù)集 ImageNet-21k(14million)或者 JFT-300M(300million) 上進(jìn)行訓(xùn)練,Batch Size 128 下 NVIDIA A100 32G GPU 的計算資源加持下預(yù)訓(xùn)練 ViT-Base/32 需要3天時間。

Facebook 與索邦大學(xué) Matthieu Cord 教授合作發(fā)表 Training data-efficient image transformers(DeiT) & distillation through attention朱盐,DeiT 模型(8600萬參數(shù))僅用一臺 GPU 服務(wù)器在 53 hours train台囱,20 hours finetune扑庞,僅使用 ImageNet 就達(dá)到了 84.2 top-1 準(zhǔn)確性,而無需使用任何外部數(shù)據(jù)進(jìn)行訓(xùn)練拒逮。性能與最先進(jìn)的卷積神經(jīng)網(wǎng)絡(luò)(CNN)可以抗衡罐氨。所以呢,很有必要講講這個 DeiT 網(wǎng)絡(luò)模型的相關(guān)內(nèi)容滩援。

下面來簡單總結(jié) DeiT:

DeiT 是一個全 Transformer 的架構(gòu)栅隐。其核心是提出了針對 ViT 的教師-學(xué)生蒸餾訓(xùn)練策略,并提出了 token-based distillation 方法玩徊,使得 Transformer 在視覺領(lǐng)域訓(xùn)練得又快又好租悄。

image.png

DeiT 相關(guān)背景

ViT 文中表示數(shù)據(jù)量不足會導(dǎo)致 ViT 效果變差。針對以上問題恩袱,DeiT 核心共享是使用了蒸餾策略恰矩,能夠僅使用 ImageNet-1K 數(shù)據(jù)集就就可以達(dá)到 83.1% 的 Top1。

那么文章主要貢獻(xiàn)可以總結(jié)為三點(diǎn):

  1. 僅使用 Transformer憎蛤,不引入 Conv 的情況下也能達(dá)到 SOTA 效果外傅。
  2. 提出了基于 token 蒸餾的策略,針對 Transformer 蒸餾方法超越傳統(tǒng)蒸餾方法俩檬。
  3. DeiT 發(fā)現(xiàn)使用 Convnet 作為教師網(wǎng)絡(luò)能夠比使用 Transformer 架構(gòu)效果更好萎胰。

正式了解 DeiT 算法之前呢,有幾個問題需要去了解的:ViT的缺點(diǎn)和局限性棚辽,為什么訓(xùn)練ViT要準(zhǔn)備這么多數(shù)據(jù)技竟,就不能簡單快速訓(xùn)練一個模型出來嗎?另外 Transformer 視覺模型又怎么玩蒸餾呢屈藐?

ViT 的缺點(diǎn)和局限性

Transformer的輸入是一個序列(Sequence)榔组,ViT 所采用的思路是把圖像分塊(patches),然后把每一塊視為一個向量(vector)联逻,所有的向量并在一起就成為了一個序列(Sequence)搓扯,ViT 使用的數(shù)據(jù)集包括了一個巨大的包含了 300 million images的 JFT-300,這個數(shù)據(jù)集是私有的包归,即外部研究者無法復(fù)現(xiàn)實(shí)驗(yàn)锨推。而且在ViT的實(shí)驗(yàn)中作者明確地提到:

"That transformers do not generalize well when trained on insufficient amounts of data."

image.png

意思是當(dāng)不使用 JFT-300 大數(shù)據(jù)集時,效果不如CNN模型公壤。也就反映出Transformer結(jié)構(gòu)若想取得理想的性能和泛化能力就需要這樣大的數(shù)據(jù)集换可。DeiT 作者通過所提出的蒸餾的訓(xùn)練方案,只在 Imagenet 上進(jìn)行訓(xùn)練厦幅,就產(chǎn)生了一個有競爭力的無卷積 Transformer沾鳄。

ViT 相關(guān)技術(shù)點(diǎn)

Multi-head Self Attention layers (MSA):

首先有一個 Query 矩陣 Q 和一個 Key 矩陣 K,把二者矩陣乘在一起并進(jìn)行歸一化以后得到 attention 矩陣确憨,它再與Value矩陣 V 相乘得到最終的輸出得到 Z译荞。最后經(jīng)過 linear transformation 得到 NxD 的輸出結(jié)果瓤的。

image.png

Feed-Forward Network (FFN):

Multi-head Self Attention layers 之后往往會跟上一個 Feed-Forward Network (FFN) ,它一般是由2個linear layer構(gòu)成磁椒,第1個linear layer把維度從 D 維變換到 ND 維堤瘤,第2個linear layer把維度從 ND 維再變換到 D 維。

此時 Transformer block 是不考慮位置信息的浆熔,基于此 ViT 加入了位置編碼 (Positional Encoding)本辐,這些編碼在第一個 block 之前被添加到 input token 中代表位置信息,作為額外可學(xué)習(xí)的embedding(Extra learnable class embedding)医增。

Class token:

Class token 與 input token 并在一起輸入 Transformer block 中慎皱,最后的輸出結(jié)果用來預(yù)測類別。這樣一來叶骨,Transformer 相當(dāng)于一共處理了 N+1 個維度為 D 的token茫多,并且只有第一個 token 的輸出用來預(yù)測類別。

知識蒸餾介紹

Knowledge Distillation(KD)最初被 Hinton 提出 “Distilling the Knowledge in a Neural Network”忽刽,與 Label smoothing 動機(jī)類似天揖,但是 KD 生成 soft label 的方式是通過教師網(wǎng)絡(luò)得到的。

KD 可以視為將教師網(wǎng)絡(luò)學(xué)到的信息壓縮到學(xué)生網(wǎng)絡(luò)中跪帝。還有一些工作 “Circumventing outlier of autoaugment with knowledge distillation” 則將 KD 視為數(shù)據(jù)增強(qiáng)方法的一種今膊。

提出背景

雖然在一般情況下,我們不會去區(qū)分訓(xùn)練和部署使用的模型伞剑,但是訓(xùn)練和部署之間存在著一定的不一致性斑唬。在訓(xùn)練過程中,我們需要使用復(fù)雜的模型黎泣,大量的計算資源恕刘,以便從非常大照雁、高度冗余的數(shù)據(jù)集中提取出信息矮湘。在實(shí)驗(yàn)中,效果最好的模型往往規(guī)模很大豆挽,甚至由多個模型集成得到衡便。而大模型不方便部署到服務(wù)中去献起,常見的瓶頸如下:

  • 推理速度和性能慢
  • 對部署資源要求高(內(nèi)存,顯存等)

在部署時镣陕,對延遲以及計算資源都有著嚴(yán)格的限制。因此姻政,模型壓縮(在保證性能的前提下減少模型的參數(shù)量)成為了一個重要的問題呆抑,而“模型蒸餾”屬于模型壓縮的一種方法。

理論原理

知識蒸餾使用的是 Teacher—Student 模型汁展,其中 Teacher 是“知識”的輸出者鹊碍,Student 是“知識”的接受者厌殉。知識蒸餾的過程分為2個階段:

  1. 原始模型訓(xùn)練: 訓(xùn)練 "Teacher模型", 簡稱為Net-T,它的特點(diǎn)是模型相對復(fù)雜侈咕,也可以由多個分別訓(xùn)練的模型集成而成公罕。我們對"Teacher模型"不作任何關(guān)于模型架構(gòu)、參數(shù)量耀销、是否集成方面的限制楼眷,唯一的要求就是,對于輸入X, 其都能輸出Y熊尉,其中Y經(jīng)過softmax的映射罐柳,輸出值對應(yīng)相應(yīng)類別的概率值。
  2. 精簡模型訓(xùn)練: 訓(xùn)練"Student模型", 簡稱為Net-S狰住,它是參數(shù)量較小张吉、模型結(jié)構(gòu)相對簡單的單模型。同樣的催植,對于輸入X肮蛹,其都能輸出Y,Y經(jīng)過softmax映射后同樣能輸出對應(yīng)相應(yīng)類別的概率值创南。

論文中伦忠,Hinton 將問題限定在分類問題下,或者其他本質(zhì)上屬于分類問題的問題扰藕,該類問題的共同點(diǎn)是模型最后會有一個softmax層缓苛,其輸出值對應(yīng)了相應(yīng)類別的概率值。知識蒸餾時邓深,由于已經(jīng)有了一個泛化能力較強(qiáng)的Net-T未桥,我們在利用Net-T來蒸餾訓(xùn)練Net-S時,可以直接讓Net-S去學(xué)習(xí)Net-T的泛化能力芥备。

其中KD的訓(xùn)練過程和傳統(tǒng)的訓(xùn)練過程的對比:

  1. 傳統(tǒng)training過程 Hard Targets: 對 ground truth 求極大似然 Softmax 值冬耿。
  2. KD的training過程 Soft Targets: 用 Teacher 模型的 class probabilities作為soft targets。
image.png

這就解釋了為什么通過蒸餾的方法訓(xùn)練出的 Net-S 相比使用完全相同的模型結(jié)構(gòu)和訓(xùn)練數(shù)據(jù)只使用Hard Targets的訓(xùn)練方法得到的模型萌壳,擁有更好的泛化能力亦镶。

具體方法

第一步是訓(xùn)練Net-T;第二步是在高溫 T 下袱瓮,蒸餾 Net-T 的知識到 Net-S缤骨。

image.png

訓(xùn)練 Net-T 的過程很簡單,而高溫蒸餾過程的目標(biāo)函數(shù)由distill loss(對應(yīng)soft target)和student loss(對應(yīng)hard target)加權(quán)得到:

image.png

Deit 中使用 Conv-Based 架構(gòu)作為教師網(wǎng)絡(luò)尺借,以 soft 的方式將歸納偏置傳遞給學(xué)生模型绊起,將局部性的假設(shè)通過蒸餾方式引入 Transformer 中,取得了不錯的效果燎斩。

DeiT 具體方法

為什么DeiT能在大幅減少 1. 訓(xùn)練所需的數(shù)據(jù)集2. 訓(xùn)練時長 的情況下依舊能夠取得很不錯的性能呢虱歪?我們可以把這個原因歸結(jié)為DeiT的訓(xùn)練策略蜂绎。ViT 在小數(shù)據(jù)集上的性能不如使用CNN網(wǎng)絡(luò) EfficientNet,但是跟ViT結(jié)構(gòu)相同笋鄙,僅僅是使用更好的訓(xùn)練策略的DeiT比ViT的性能已經(jīng)有了很大的提升师枣,在此基礎(chǔ)上,再加上蒸餾 (distillation) 操作萧落,性能超過了 EfficientNet践美。

假設(shè)有一個性能很好的分類器作為teacher model,通過引入了一個 Distillation Token铐尚,然后在 self-attention layers 中跟 class token拨脉,patch token 在 Transformer 結(jié)構(gòu)中不斷學(xué)習(xí)。

Class token的目標(biāo)是跟真實(shí)的label一致宣增,而Distillation Token是要跟teacher model預(yù)測的label一致玫膀。

image.png

對比 ViT 的輸出是一個 softmax,它代表著預(yù)測結(jié)果屬于各個類別的概率的分布爹脾。ViT的做法是直接將 softmax 與 GT label取 CE Loss帖旨。

image.png

而在 DeiT 中,除了 CE Loss 以外灵妨,還要 1)定義蒸餾損失解阅;2)加上 Distillation Token。

  1. 定義蒸餾損失

蒸餾分兩種泌霍,一種是軟蒸餾(soft distillation)货抄,另一種是硬蒸餾(hard distillation)。軟蒸餾如下式所示朱转,Z_s 和 Z_t 分別是 student model 和 teacher model 的輸出蟹地,KL 表示 KL 散度,psi 表示softmax函數(shù)藤为,lambda 和 tau 是超參數(shù):

image.png

硬蒸餾如下式所示怪与,其中 CE 表示交叉熵:

image.png

學(xué)生網(wǎng)絡(luò)的輸出 Z_s 與真實(shí)標(biāo)簽之間計算 CE Loss 。如果是硬蒸餾缅疟,就再與教師網(wǎng)絡(luò)的標(biāo)簽取 CE Loss分别。如果是軟蒸餾,就再與教師網(wǎng)絡(luò)的 softmax 輸出結(jié)果取 KL Loss 存淫。

值得注意的是耘斩,Hard Label 也可以通過標(biāo)簽平滑技術(shù) (Label smoothing) 轉(zhuǎn)換成Soft Labe,其中真值對應(yīng)的標(biāo)簽被認(rèn)為具有 1- esilon 的概率桅咆,剩余的 esilon 由剩余的類別共享煌往。

  1. 加入 Distillation Token

Distillation Token 和 ViT 中的 class token 一起加入 Transformer 中,和class token 一樣通過 self-attention 與其它的 embedding 一起計算轧邪,并且在最后一層之后由網(wǎng)絡(luò)輸出刽脖。

而 Distillation Token 對應(yīng)的這個輸出的目標(biāo)函數(shù)就是蒸餾損失。Distillation Token 允許模型從教師網(wǎng)絡(luò)的輸出中學(xué)習(xí)忌愚,就像在常規(guī)的蒸餾中一樣曲管,同時也作為一種對class token的補(bǔ)充。

image.png

DeiT 具體實(shí)驗(yàn)

實(shí)驗(yàn)參數(shù)的設(shè)置:圖中表示不同大小的 DeiT 結(jié)構(gòu)的超參數(shù)設(shè)置硕糊,最大的結(jié)構(gòu)是 DeiT-B院水,與 ViT-B 結(jié)構(gòu)是相同,唯一不同的是 embedding 的 hidden dimension 和 head 數(shù)量简十。作者保持了每個head的隱變量維度為64檬某,throughput是一個衡量DeiT模型處理圖片速度的變量,代表每秒能夠處理圖片的數(shù)目螟蝙。

image.png
  1. Teacher model對比

作者首先觀察到使用 CNN 作為 teacher 比 transformer 作為 teacher 的性能更優(yōu)恢恼。下圖中對比了 teacher 網(wǎng)絡(luò)使用 DeiT-B 和幾個 CNN 模型 RegNetY 時,得到的 student 網(wǎng)絡(luò)的預(yù)訓(xùn)練性能以及 finetune 之后的性能胰默。

其中场斑,DeiT-B 384 代表使用分辨率為 384×384 的圖像 finetune 得到的模型,最后的那個小蒸餾符號 alembic sign 代表蒸餾以后得到的模型牵署。

image.png
  1. 蒸餾方法對比

下圖是不同蒸餾策略的性能對比漏隐,label 代表有監(jiān)督學(xué)習(xí),前3行分別是不使用蒸餾奴迅,使用soft蒸餾和使用hard蒸餾的性能對比青责。前3行不使用 Distillation Token 進(jìn)行訓(xùn)練,只是相當(dāng)于在原來 ViT 的基礎(chǔ)上給損失函數(shù)加上了蒸餾部分取具。

對于Transformer來講脖隶,硬蒸餾的性能明顯優(yōu)于軟蒸餾,即使只使用 class token者填,不使用 distill token浩村,硬蒸餾達(dá)到 83.0%,而軟蒸餾的精度為 81.8%占哟。

image.png

從最后兩列 B224 和 B384 看出心墅,以更高的分辨率進(jìn)行微調(diào)有助于減少方法之間的差異。這可能是因?yàn)樵谖⒄{(diào)時榨乎,作者不使用教師信息怎燥。隨著微調(diào),class token 和 Distillation Token 之間的相關(guān)性略有增加蜜暑。

除此之外铐姚,蒸餾模型在 accuracy 和 throughput 之間的 trade-off 甚至優(yōu)于 teacher 模型,這也反映了蒸餾的有趣之處。

  1. 性能對比

下面是不同模型性能的數(shù)值比較隐绵≈冢可以發(fā)現(xiàn)在參數(shù)量相當(dāng)?shù)那闆r下,卷積網(wǎng)絡(luò)的速度更慢依许,這是因?yàn)榇蟮木仃嚦朔ū刃【矸e提供了更多的優(yōu)化機(jī)會棺禾。EffcientNet-B4和DeiT-B alembic sign的速度相似,在3個數(shù)據(jù)集的性能也比較接近峭跳。

image.png
  1. 對比實(shí)驗(yàn)

作者還做了一些關(guān)于數(shù)據(jù)增強(qiáng)方法和優(yōu)化器的對比實(shí)驗(yàn)膘婶。Transformer的訓(xùn)練需要大量的數(shù)據(jù),想要在不太大的數(shù)據(jù)集上取得好性能蛀醉,就需要大量的數(shù)據(jù)增強(qiáng)悬襟,以實(shí)現(xiàn)data-efficient training。幾乎所有評測過的數(shù)據(jù)增強(qiáng)的方法都能提升性能拯刁。對于優(yōu)化器來說脊岳,AdamW比SGD性能更好。

此外筛璧,發(fā)現(xiàn)Transformer對優(yōu)化器的超參數(shù)很敏感逸绎,試了多組 lr 和 weight+decay。stochastic depth有利于收斂夭谤。Mixup 和 CutMix 都能提高性能棺牧。Exp.+Moving+Avg. 表示參數(shù)平滑后的模型,對性能提升只是略有幫助朗儒。最后就是 Repeated augmentation 的數(shù)據(jù)增強(qiáng)方式對于性能提升幫助很大颊乘。

image.png

小結(jié)

DeiT 模型(8600萬參數(shù))僅用一臺 GPU 服務(wù)器在 53 hours train,20 hours finetune醉锄,僅使用 ImageNet 就達(dá)到了 84.2 top-1 準(zhǔn)確性乏悄,而無需使用任何外部數(shù)據(jù)進(jìn)行訓(xùn)練,性能與最先進(jìn)的卷積神經(jīng)網(wǎng)絡(luò)(CNN)可以抗衡恳不。其核心是提出了針對 ViT 的教師-學(xué)生蒸餾訓(xùn)練策略檩小,并提出了 token-based distillation 方法,使得 Transformer 在視覺領(lǐng)域訓(xùn)練得又快又好烟勋。

引用

[1] https://zhuanlan.zhihu.com/p/349315675

[2] DeiT:使用Attention蒸餾Transformer

[3] https://zhuanlan.zhihu.com/p/102038521

[4] Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. "Distilling the knowledge in a neural network." arXiv preprint arXiv:1503.02531 2.7 (2015).

[5] Touvron, Hugo, et al. "Training data-efficient image transformers & distillation through attention." International Conference on Machine Learning. PMLR, 2021.

[6] Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).

[7] Wei, Longhui, et al. "Circumventing outliers of autoaugment with knowledge distillation." European Conference on Computer Vision. Springer, Cham, 2020.

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末规求,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子卵惦,更是在濱河造成了極大的恐慌阻肿,老刑警劉巖,帶你破解...
    沈念sama閱讀 216,496評論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件沮尿,死亡現(xiàn)場離奇詭異丛塌,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,407評論 3 392
  • 文/潘曉璐 我一進(jìn)店門赴邻,熙熙樓的掌柜王于貴愁眉苦臉地迎上來印衔,“玉大人,你說我怎么就攤上這事乍楚〉北啵” “怎么了?”我有些...
    開封第一講書人閱讀 162,632評論 0 353
  • 文/不壞的土叔 我叫張陵徒溪,是天一觀的道長。 經(jīng)常有香客問我金顿,道長臊泌,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,180評論 1 292
  • 正文 為了忘掉前任揍拆,我火速辦了婚禮渠概,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘嫂拴。我一直安慰自己播揪,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,198評論 6 388
  • 文/花漫 我一把揭開白布筒狠。 她就那樣靜靜地躺著猪狈,像睡著了一般。 火紅的嫁衣襯著肌膚如雪辩恼。 梳的紋絲不亂的頭發(fā)上雇庙,一...
    開封第一講書人閱讀 51,165評論 1 299
  • 那天,我揣著相機(jī)與錄音灶伊,去河邊找鬼疆前。 笑死,一個胖子當(dāng)著我的面吹牛聘萨,可吹牛的內(nèi)容都是我干的竹椒。 我是一名探鬼主播,決...
    沈念sama閱讀 40,052評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼米辐,長吁一口氣:“原來是場噩夢啊……” “哼胸完!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起儡循,我...
    開封第一講書人閱讀 38,910評論 0 274
  • 序言:老撾萬榮一對情侶失蹤舶吗,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后择膝,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體誓琼,經(jīng)...
    沈念sama閱讀 45,324評論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,542評論 2 332
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了腹侣。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片叔收。...
    茶點(diǎn)故事閱讀 39,711評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖傲隶,靈堂內(nèi)的尸體忽然破棺而出饺律,到底是詐尸還是另有隱情,我是刑警寧澤跺株,帶...
    沈念sama閱讀 35,424評論 5 343
  • 正文 年R本政府宣布复濒,位于F島的核電站,受9級特大地震影響乒省,放射性物質(zhì)發(fā)生泄漏巧颈。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,017評論 3 326
  • 文/蒙蒙 一袖扛、第九天 我趴在偏房一處隱蔽的房頂上張望砸泛。 院中可真熱鬧,春花似錦蛆封、人聲如沸唇礁。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,668評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽盏筐。三九已至,卻和暖如春妒蛇,著一層夾襖步出監(jiān)牢的瞬間机断,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,823評論 1 269
  • 我被黑心中介騙來泰國打工绣夺, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留吏奸,地道東北人。 一個月前我還...
    沈念sama閱讀 47,722評論 2 368
  • 正文 我出身青樓陶耍,卻偏偏與公主長得像奋蔚,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子烈钞,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,611評論 2 353

推薦閱讀更多精彩內(nèi)容