知識(shí)蒸餾源自Hinton et al.于2014年發(fā)表在NIPS的一篇文章:Distilling the Knowledge in a Neural Network。
1. 背景
一般情況下,我們?cè)谟?xùn)練模型的時(shí)候使用了大量訓(xùn)練數(shù)據(jù)和計(jì)算資源來提取知識(shí),但這不方便在工業(yè)中部署美浦,原因有二:
(1)大模型推理速度慢
(2)對(duì)設(shè)備的資源要求高(大內(nèi)存)
因此我們希望對(duì)訓(xùn)練好的模型進(jìn)行壓縮,在保證推理效果的前提下減小模型的體量属拾,知識(shí)蒸餾(Knownledge Distillation)屬于模型壓縮的一種方法 [1]舞萄。
2. 知識(shí)蒸餾
名詞解釋:
cumbersome model:原始模型或者說大模型见转,但在后續(xù)的論文中一般稱它為teacher model;
distilled model:蒸餾后的小模型旗笔,在后續(xù)的論文中一般稱它為stududent model彪置;
hard targets:像[1, 0, 0]這樣的標(biāo)簽,也叫做ground-truth label换团;
soft targets:像[0.7, 0.2, 0.1]這樣的標(biāo)簽悉稠;
transfer set:訓(xùn)練student model的數(shù)據(jù)
好模型的目標(biāo)不是擬合訓(xùn)練數(shù)據(jù)宫蛆,而是學(xué)習(xí)如何泛化到新的數(shù)據(jù)艘包。所以蒸餾的目標(biāo)是讓student學(xué)習(xí)到teacher的泛化能力,理論上得到的結(jié)果會(huì)比單純擬合訓(xùn)練數(shù)據(jù)的student要好 [3]耀盗。顯然想虎,soft target可以提供更大的信息熵,所以studetn model可以學(xué)習(xí)到更多的信息叛拷。
通俗的來講舌厨,粗暴的使用one-hot編碼把原本有幫助的類內(nèi)variance和類間distance都忽略了,比如貓和狗的相似性要比貓與摩托車的相似性要多忿薇,狗的某些特征可能對(duì)識(shí)別貓也會(huì)有幫助(比如毛發(fā))裙椭,因此使用soft target可以恢復(fù)被one-hot編碼丟棄的信息 [2]。
在Hinton et al. 發(fā)表的這篇論文中署浩,作者提出了"softmax temperature"的概念揉燃,其公式為:
Python代碼:
import numpy as np
def softmax_t(x,t):
x_exp = np.exp(x / t)
return x_exp / np.sum(x_exp)
代表第
類的輸出概率,
和
為softmax的輸入筋栋,即上一層神經(jīng)元的輸出(logits)炊汤,T表示temperature參數(shù)。通常情況下弊攘,我們使用的softmax函數(shù)T為1抢腐,但
可以控制輸出soft的程度。比如對(duì)于
襟交,我們分別取
迈倍,然后畫出softmax函數(shù)的輸出可以看到,
越小捣域,輸出的預(yù)測(cè)結(jié)果越“硬”(曲線更加曲折)啼染,T越大輸出的結(jié)果越“軟”(曲線更加平和)。
插一句題外話竟宋,為什么這里的參數(shù)是叫溫度(temperature)呢提完?這和蒸餾(distillation)這一熱力學(xué)工藝有關(guān)。在蒸餾工藝中丘侠,溫度越高提取到的物質(zhì)越純?cè)綕饪s徒欣。而在知識(shí)蒸餾中,參數(shù)T越大(溫度越高)蜗字,teacher model產(chǎn)生的label越"soft"打肝,信息熵就越高脂新,提煉的知識(shí)更具有一般性(generalization)。所以說作者將這一參數(shù)取名temperature十分有趣粗梭。
知識(shí)蒸餾的實(shí)現(xiàn)過程可以概括為:
- 訓(xùn)練teacher model;
- 使用高溫T將teacher model中的知識(shí)蒸餾到student model(在測(cè)試時(shí)溫度T設(shè)為1)断医。
student modeld的目標(biāo)函數(shù)由一下兩項(xiàng)的加權(quán)平均組成:
- distillation loss:soft targets(由teacher model產(chǎn)生) 和student model的soft predictions的交叉熵滞乙,這里的T使用的是和訓(xùn)練teacher model相同的值。(保證student model和teacher model的結(jié)果盡可能一致)
- student loss:hard targets 和student model的輸出數(shù)據(jù)的交叉熵鉴嗤,但T設(shè)置為1斩启。(保證student model的結(jié)果和實(shí)際類別標(biāo)簽盡可能一致)
總體的損失函數(shù)可以寫作:
其中,表示輸入醉锅,
表示student model的參數(shù)兔簇,
是ground-truth label,
是交叉熵?fù)p失函數(shù)硬耍,
是剛剛提到的softmax temperature激活函數(shù)垄琐,
和
分別表示student和teacher model神經(jīng)元的輸出(logits),
和
表示兩個(gè)權(quán)重參數(shù) [4].
原論文指出,要比
相對(duì)小一些可以取得更好的結(jié)果经柴,因?yàn)樵谇筇荻葧r(shí)soft targets被縮放了
狸窘,所以第2項(xiàng)要乘以一個(gè)更小的權(quán)值來平衡二者在優(yōu)化時(shí)的比重 [1].
換一個(gè)角度來想,這里的知識(shí)蒸餾其實(shí)是相對(duì)于對(duì)于原始交叉熵添加了一個(gè)正則項(xiàng):
利用teacher model的先驗(yàn)知識(shí)對(duì)student model進(jìn)行正則化 [5]口锭。
References:
[1] Distilling the Knowledge in a Neural Network.
[2] # Distilling the Knowledge in a Neural Network 論文筆記
[3] 深度神經(jīng)網(wǎng)絡(luò)模型蒸餾Distillation
[4] Knowledge Distillation
[5] 神經(jīng)網(wǎng)絡(luò)知識(shí)蒸餾 Knowledge Distillation