前言
目前提高機(jī)器學(xué)習(xí)算法性能的方法幾乎都用多模型ensamble,在計(jì)算上非常昂貴且難以部署趣兄,尤其是大型神經(jīng)網(wǎng)絡(luò),比如bert。
知識(shí)蒸餾就是希望用小型模型得到跟大型復(fù)雜模型一樣的性能弦撩。人們普遍認(rèn)為,用于訓(xùn)練的目標(biāo)函數(shù)應(yīng)盡可能地反映任務(wù)的真實(shí)目標(biāo)论皆。在訓(xùn)練過(guò)程中益楼,往往以最優(yōu)化訓(xùn)練集的準(zhǔn)確率作為訓(xùn)練目標(biāo),但真實(shí)目標(biāo)其實(shí)應(yīng)該是最優(yōu)化模型的泛化能力点晴。顯然如果能直接以提升模型的泛化能力為目標(biāo)進(jìn)行訓(xùn)練是最好的偏形,但這需要正確的關(guān)于泛化能力的信息,而這些信息通常不可用觉鼻。
那么怎么能獲得可以利用的泛化能力的信息呢俊扭?
在一般的分類訓(xùn)練任務(wù)中,我們以softmax層輸出各個(gè)類別的概率坠陈,然后以跟one-hot lables的交叉熵作為loss function萨惑,這個(gè)loss function丟失了在其他類別上的概率,只把正確label上的概率值考慮進(jìn)來(lái)仇矾。但其實(shí)在錯(cuò)誤label上的概率分布也是具有價(jià)值的庸蔼。比如狗的樣本識(shí)別為老虎上的錯(cuò)誤概率會(huì)比識(shí)別為螞蟻的錯(cuò)誤概率大,因?yàn)??更像??贮匕,而不像??姐仅。這在一定程度上反映了該分類器的泛化機(jī)制,反映了它“腦袋里的知識(shí)”刻盐。
知識(shí)蒸餾
如果我們使用由大型模型產(chǎn)生的所有類概率作為訓(xùn)練小模型的目標(biāo)掏膏,是否就是直接以“泛化能力”作為目標(biāo)函數(shù)呢?研究證明敦锌,這樣的方法確實(shí)可以讓小模型得到不輸大模型的性能馒疹,而且有時(shí)甚至青出于藍(lán)勝于藍(lán)。這種把大模型的“知識(shí)”遷移到小模型的方式乙墙,我們稱之為“蒸餾”(濃縮就是精華)颖变。有人用單層BiLSTM對(duì)bert進(jìn)行蒸餾生均,效果不輸ELMo。(詳細(xì)可看論文 Distilling Task-Specific Knowledge from BERT into Simple Neural Networks )
這里先定義兩個(gè)概念腥刹÷黼剩“硬目標(biāo)(softmax)”:正確標(biāo)簽的交叉熵∠畏澹“軟目標(biāo)”(soft_softmax):大模型產(chǎn)生的類概率的交叉熵漓雅。
soft_softmax公式如下:
可以看到,它跟softmax比起來(lái)就是在指數(shù)項(xiàng)里多了一個(gè)“T”朽色,這個(gè)T稱為蒸餾溫度邻吞。為什么要加T呢?假如我們分三類葫男,然后網(wǎng)絡(luò)最后的輸出是[1.0 2.0 3.0]抱冷,我們可以很容易的計(jì)算出,傳統(tǒng)的softmax(即T=1)對(duì)此進(jìn)行處理后得到的概率為[0.09 0.24 0.67]梢褐,而當(dāng)T=4的時(shí)候旺遮,得到的概率則為[0.25 0.33 0.42]∮龋可以看出耿眉,當(dāng)T變大的時(shí)候輸出的概率分布變得平緩了,這就蒸餾溫度的作用鱼响。這時(shí)候得到的概率分布我們稱之為“soft target label”鸣剪。我們?cè)谟?xùn)練小模型的時(shí)候需要用到“soft target label”。
在訓(xùn)練小模型時(shí)丈积,目標(biāo)函數(shù)為:
其中
為soft target lable筐骇,這里T要跟蒸餾復(fù)雜模型時(shí)的T大小一致,也就是保持同樣的蒸餾溫度江滨,避免改變“知識(shí)”分布铛纬。注意:小模型在做預(yù)測(cè)時(shí)蒸餾溫度要還原為1,也就是用原始概率分布做預(yù)測(cè),因?yàn)樵兕A(yù)測(cè)時(shí)希望正確標(biāo)簽與錯(cuò)誤標(biāo)簽的概率差距盡量大,與蒸餾時(shí)的希望平緩區(qū)別開(kāi)來(lái)唬滑。
實(shí)際上可以這么理解,知識(shí)蒸餾是在本來(lái)的目標(biāo)函數(shù)上加上了正則項(xiàng)告唆,正則項(xiàng)可以提高模型的泛化能力,把軟目標(biāo)當(dāng)作正則項(xiàng)就是讓小模型的泛化能力盡量接近復(fù)雜模型的泛化能力晶密。軟目標(biāo)具有高熵值時(shí)擒悬,它們?yōu)槊總€(gè)訓(xùn)練案例提供比硬目標(biāo)更多的信息,并且在訓(xùn)練案例之間梯度的變化更小惹挟,因此小模型通城洋Γ可以在比原始繁瑣模型少得多的數(shù)據(jù)上訓(xùn)練并可以使用更高的學(xué)習(xí)率加快訓(xùn)練過(guò)程缝驳。
總結(jié)一下:
知識(shí)蒸餾就是
1.從復(fù)雜模型中得到“soft target label”连锯。
2.在訓(xùn)練小模型時(shí)同時(shí)訓(xùn)練硬目標(biāo)和軟目標(biāo)归苍。