知識(shí)蒸餾-Distilling the knowledge in a neural network
作者:支廣達(dá)
1. 概念介紹
“很多昆蟲在幼蟲形態(tài)的時(shí)候是最擅長(zhǎng)從環(huán)境中吸取能量和養(yǎng)分的定庵,而當(dāng)他們成長(zhǎng)為成蟲的時(shí)候則需要擅長(zhǎng)完全不同能力比如遷移和繁殖侠坎∮畲校”在2014年Hinton發(fā)表的知識(shí)蒸餾的論文中用了這樣一個(gè)很形象的比喻來說明知識(shí)蒸餾的目的懈息。在大型的機(jī)器學(xué)習(xí)任務(wù)中俺陋,我們也用兩個(gè)不同的階段 training stage 和 deployment stage 來表達(dá)兩種不同的需求国拇。training stage(訓(xùn)練階段)可以利用大量的計(jì)算資源不需要實(shí)時(shí)響應(yīng),利用大量的數(shù)據(jù)進(jìn)行訓(xùn)練上忍。但是在deployment stage (部署階段)則會(huì)有很多限制舍悯,比如計(jì)算資源,計(jì)算速度要求等睡雇。知識(shí)蒸餾就是為了滿足這種需求而設(shè)計(jì)的一種模型壓縮的方法萌衬。
知識(shí)蒸餾的概念最早是在2006年由Bulica提出的,在2014年Hinton對(duì)知識(shí)蒸餾做了歸納和發(fā)展它抱。知識(shí)蒸餾的主要思想是訓(xùn)練一個(gè)小的網(wǎng)絡(luò)模型來模仿一個(gè)預(yù)先訓(xùn)練好的大型網(wǎng)絡(luò)或者集成的網(wǎng)絡(luò)秕豫。這種訓(xùn)練模式又被稱為 "teacher-student",大型的網(wǎng)絡(luò)是“老師”观蓄,小型的網(wǎng)絡(luò)是“學(xué)生”混移。
在知識(shí)蒸餾中,老師將知識(shí)傳授給學(xué)生的方法是:在訓(xùn)練學(xué)生的過程中最小化一個(gè)以老師預(yù)測(cè)結(jié)果的概率分布為目標(biāo)的損失函數(shù)侮穿。老師預(yù)測(cè)的概率分布就是老師模型的最后的softmax函數(shù)層的輸出歌径,然而,在很多情況下傳統(tǒng)的softmax層的輸出亲茅,正確的分類的概率值非常大回铛,而其他分類的概率值幾乎接近于0。因此克锣,這樣并不會(huì)比原始的數(shù)據(jù)集提供更多有用的信息茵肃,沒有利用到老師強(qiáng)大的泛化性能,比如袭祟,訓(xùn)練MNIST任務(wù)中數(shù)字‘3’相對(duì)于數(shù)字‘5’與數(shù)字‘8’的關(guān)系更加緊密验残。為了解決這個(gè)問題,Hinton在2015年發(fā)表的論文中提出了‘softmax temperature’的概念巾乳,對(duì)softmax函數(shù)做了改進(jìn):
這里的 就是指 temperature 參數(shù)您没。當(dāng)
等于1 時(shí)就是標(biāo)準(zhǔn)的softmax函數(shù)鸟召。當(dāng)
增大時(shí),softmax輸出的概率分布就會(huì)變得更加 soft(平滑)氨鹏,這樣就可以利用到老師模型的更多信息(老師覺得哪些類別更接近于要預(yù)測(cè)的類別)欧募。Hinton將這樣的蘊(yùn)含在老師模型中的信息稱之為 "dark knowledge",蒸餾的方法就是要將這些 "dark knowledge" 傳給學(xué)生模型喻犁。在訓(xùn)練學(xué)生的時(shí)候槽片,學(xué)生的softmax函數(shù)使用與老師的相同的
何缓,損失函數(shù)以老師輸出的軟標(biāo)簽為目標(biāo)肢础。這樣的損失函數(shù)我們稱為"distillation loss"。
在Hinton的論文中碌廓,還發(fā)現(xiàn)了在訓(xùn)練過程加上正確的數(shù)據(jù)標(biāo)簽(hard label)會(huì)使效果更好传轰。具體方法是,在計(jì)算distillation loss的同時(shí)谷婆,我利用hard label 把標(biāo)準(zhǔn)的損失()也計(jì)算出來慨蛙,這個(gè)損失我們稱之為 "student loss"。將兩種 loss 整合的公式如下:
這里的 是輸入纪挎,
是學(xué)生模型的參數(shù)期贫,
是交叉熵?fù)p失函數(shù),
是 hard label 异袄,
是參數(shù)有
的函數(shù)通砍,
是系數(shù),
分別是學(xué)生和老師的logits輸出烤蜕。模型的具體結(jié)構(gòu)如下圖所示:
2.超參數(shù)的調(diào)整
在上述公式中封孙, 是作為超參數(shù)人為設(shè)置的,Hinton的論文中使用的
的范圍為1到20讽营,他們通過實(shí)驗(yàn)發(fā)現(xiàn)虎忌,當(dāng)學(xué)生模型相對(duì)于老師模型非常小的時(shí)候,
的值相對(duì)小一點(diǎn)效果更好橱鹏。這樣的結(jié)果直觀的理解就是膜蠢,如果增加
的值,軟標(biāo)簽的分布蘊(yùn)含的信息越多導(dǎo)致一個(gè)小的模型無法"捕捉"所有信息但是這也只是一種假設(shè)莉兰,還沒有明確的方法來衡量一個(gè)網(wǎng)絡(luò)“捕捉”信息的能力狡蝶。關(guān)于
,Hinton的論文中對(duì)兩個(gè)loss用了加權(quán)平均:
贮勃。他們實(shí)驗(yàn)發(fā)現(xiàn)贪惹,在普通情況下
相對(duì)于
非常小的情況下能得到最好的效果。其他人也做了一些實(shí)驗(yàn)沒用加權(quán)平均寂嘉,將
設(shè)置為1奏瞬,而對(duì)
進(jìn)行調(diào)整枫绅。
3.實(shí)驗(yàn)
Hinton的論文中做了三個(gè)實(shí)驗(yàn),前兩個(gè)是MNIST和語音識(shí)別硼端,在這兩個(gè)實(shí)驗(yàn)中通過知識(shí)蒸餾得到的學(xué)生模型都達(dá)到了與老師模型相近的效果并淋,相對(duì)于直接在原始數(shù)據(jù)集上訓(xùn)練的相同的模型在準(zhǔn)確率上都有很大的提高。下面主要講述第三個(gè)比較創(chuàng)新的實(shí)驗(yàn):將知識(shí)蒸餾應(yīng)用在訓(xùn)練集成模型中珍昨。
3.1模型介紹
訓(xùn)練集成模型(訓(xùn)練多個(gè)同樣的模型然后集成得到更好的泛化效果)是利用并行計(jì)算的非常簡(jiǎn)單的方法县耽,但是當(dāng)數(shù)據(jù)集很大種類很多的時(shí)候就會(huì)產(chǎn)生巨大的計(jì)算量而且效果也不好。Hinton在論文中利用soft label的技巧設(shè)計(jì)了一種集成模型降低了計(jì)算量又取得了很好的效果镣典。這個(gè)模型包含兩種小模型:generalist model 和 specialist model(網(wǎng)絡(luò)模型相同兔毙,分工不同)整個(gè)模型由很多個(gè)specialist model 和一個(gè)generalist model 集成。顧名思義generalist model 是負(fù)責(zé)將數(shù)據(jù)進(jìn)行粗略的區(qū)分(將相似的圖片歸為一類)兄春,而specialist model(專家模型)則負(fù)責(zé)將相似的圖片進(jìn)行更細(xì)致的分類澎剥。這樣的操作也非常符合人類的大腦的思維方式先進(jìn)行大類的區(qū)分再進(jìn)行具體分類,下面我們看這個(gè)實(shí)驗(yàn)的具體細(xì)節(jié)赶舆。
實(shí)驗(yàn)所用的數(shù)據(jù)集是谷歌內(nèi)部的JFT數(shù)據(jù)集哑姚,JFT數(shù)據(jù)集非常大,有一億張圖片和15000個(gè)類別芜茵。實(shí)驗(yàn)中 generalist model 是用所有數(shù)據(jù)集進(jìn)行訓(xùn)練的叙量,有15000個(gè)輸出,也就是每個(gè)類別都有一個(gè)輸出概率九串。將數(shù)據(jù)集進(jìn)行分類則是用Online k-means聚類的方法對(duì)每張圖片輸入generalist model后得到的軟標(biāo)簽進(jìn)行聚類绞佩,最終將3%的數(shù)據(jù)為一組分發(fā)給各個(gè)specialist,每個(gè)小數(shù)據(jù)集包含一些聚集的圖片蒸辆,也就是generalist認(rèn)為相近的圖片征炼。
在specialist model的訓(xùn)練階段,模型的參數(shù)在初始化的時(shí)候是完全復(fù)制的generalist中的數(shù)值(specialist和generalist的結(jié)構(gòu)是一模一樣的)躬贡,這樣可以保留generalist模型的所有知識(shí)谆奥,然后specialist對(duì)分配的數(shù)據(jù)集進(jìn)行hard label訓(xùn)練。但是問題是拂玻,specialist如果只專注于分配的數(shù)據(jù)集(只對(duì)分配的數(shù)據(jù)集訓(xùn)練)整個(gè)網(wǎng)絡(luò)很快就會(huì)過擬合于分配的數(shù)據(jù)集上酸些,所以Hinton提出的方法是用一半的時(shí)間進(jìn)行hard label訓(xùn)練,另一半的時(shí)間用知識(shí)蒸餾的方法學(xué)習(xí)generalist生成的soft label檐蚜。這樣specialist就是花一半的時(shí)間在進(jìn)行小分類的學(xué)習(xí)魄懂,另一半的時(shí)間是在模仿generalist的行為。
整個(gè)模型的預(yù)測(cè)也與往常不同闯第。在做top-1分類的時(shí)候分為以下兩步:
第一步:將圖片輸入generalist model 得到輸出的概率分布市栗,取概率最大的類別k。
第二步:取出數(shù)據(jù)集包含類別k的所有specialists,為集合(各個(gè)數(shù)據(jù)集之間是有類別重合的)填帽。然后求解能使如下公式最小化的概率分布q作為預(yù)測(cè)分布蛛淋。
這里的KL是指KL散度(用于刻畫兩個(gè)概率分布之間的差距)和
分別是測(cè)試圖片輸入generalist 和specialists(m)之后輸出的概率分布,累加就是考慮所有屬于
集合的specialist的“意見”篡腌。
3.2實(shí)驗(yàn)結(jié)果
由于Specialist model的訓(xùn)練數(shù)據(jù)集很小褐荷,所以需要訓(xùn)練的時(shí)間很短,從傳統(tǒng)方法需要的幾周時(shí)間減少到幾天嘹悼。下圖是在訓(xùn)練好generalist模型之后逐個(gè)增加specialist進(jìn)行訓(xùn)練的測(cè)試結(jié)果:
從圖中可以看出叛甫,specialist個(gè)數(shù)的增加使top1準(zhǔn)確個(gè)數(shù)有明顯的提高。
4.總結(jié)
本文結(jié)合Hinton在2014年發(fā)表的論文對(duì)知識(shí)蒸餾和相關(guān)實(shí)驗(yàn)做了一個(gè)簡(jiǎn)單的介紹杨伙,如今很多模型都用到了知識(shí)蒸餾的方法其监,但知識(shí)蒸餾在深度學(xué)習(xí)中還是非常新的方向,還有非常多的應(yīng)用場(chǎng)景等待研究缀台。
項(xiàng)目地址:https://momodel.cn/explore/5dc3b1223752d662e35925a3?type=app
參考文獻(xiàn)
[1]Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015.
[2]https://nervanasystems.github.io/distiller/knowledge_distillation.html
[3]https://www.youtube.com/watch?v=EK61htlw8hY&t=3323s