深度學(xué)習(xí)在這兩年的發(fā)展可謂是突飛猛進(jìn)辛孵,為了提升模型性能丛肮,模型的參數(shù)量變得越來(lái)越多,模型自身也變得越來(lái)越大魄缚。在圖像領(lǐng)域中基于Resnet的卷積神經(jīng)網(wǎng)絡(luò)模型宝与,不斷延伸著網(wǎng)絡(luò)深度。而在自然語(yǔ)言處理領(lǐng)域(NLP)領(lǐng)域冶匹,BERT习劫,GPT等超大模型的誕生也緊隨其后。這些巨型模型在準(zhǔn)確性上大部分時(shí)候都吊打其他一眾小參數(shù)量模型嚼隘,可是它們?cè)诓渴痣A段诽里,往往需要占用巨大內(nèi)存資源,同時(shí)運(yùn)行起來(lái)也極其耗時(shí)飞蛹,這與工業(yè)界對(duì)模型吃資源少谤狡,低延時(shí)的要求完全背道而馳。所以很多在學(xué)術(shù)界呼風(fēng)喚雨的強(qiáng)大模型在企業(yè)的運(yùn)用過(guò)程中卻沒(méi)有那么順風(fēng)順?biāo)?/p>
知識(shí)蒸餾
為解決上述問(wèn)題卧檐,我們需要將參數(shù)量巨大的模型墓懂,壓縮成小參數(shù)量模型,這樣就可以在不失精度的情況下霉囚,使得模型占用資源少捕仔,運(yùn)行快,所以如何將這些大模型壓縮佛嬉,同時(shí)保持住頂尖的準(zhǔn)確率逻澳,成了學(xué)術(shù)界一個(gè)專門的研究領(lǐng)域。2015年Geoffrey Hinton 發(fā)表的Distilling the Knowledge in a Neural Network的論文中提出了知識(shí)蒸餾技術(shù)暖呕,就是為了解決模型壓而生的斜做。至于文章的細(xì)節(jié)這里筆者不做過(guò)多介紹,想了解的同學(xué)們可以點(diǎn)擊上方鏈接好好研讀原文湾揽。不過(guò)這篇文章的主要思想就如下方圖片所示:用一個(gè)老師模型(大參數(shù)模型)去教一個(gè)學(xué)生模型(小參數(shù)模型)瓤逼,在實(shí)做上就是用讓學(xué)生模型去學(xué)習(xí)已經(jīng)在目標(biāo)數(shù)據(jù)集上訓(xùn)練過(guò)的老師模型。盡管學(xué)生模型最終依然達(dá)不到老師模型的準(zhǔn)確性库物,但是被老師教過(guò)的學(xué)生模型會(huì)比自己?jiǎn)为?dú)訓(xùn)練的學(xué)生模型更加強(qiáng)大霸旗。
這里大家可能會(huì)產(chǎn)生疑惑,為什么讓學(xué)生模型去學(xué)習(xí)目標(biāo)數(shù)據(jù)集會(huì)比被老師模型教出來(lái)的差戚揭。產(chǎn)生這種結(jié)果可能原因是因?yàn)?strong>老師模型的輸出提供了比目標(biāo)數(shù)據(jù)集更加豐富的信息诱告,如下圖所示,老師模型的輸出民晒,不僅提供了輸入圖片上的數(shù)字是數(shù)字1的信息精居,而且還附帶著數(shù)字1和數(shù)字7和9比較像等額外信息锄禽。
知識(shí)蒸餾具體流程
接下來(lái)筆者介紹一下知識(shí)蒸餾在實(shí)做上的具體流程。
- (1)定義一個(gè)參數(shù)量較大(強(qiáng)大的)的老師模型靴姿,和一個(gè)參數(shù)量較形值(弱小的)的學(xué)生模型,
- (2)讓老師模型在目標(biāo)數(shù)據(jù)集上訓(xùn)練到最佳佛吓,
- (3)將目標(biāo)數(shù)據(jù)的label替換成老師模型最后一個(gè)全連接層的輸出宵晚,讓學(xué)生模型學(xué)習(xí)老師模型的輸出,希望學(xué)生模型的輸出和老師模型輸出之間的交叉熵越小越好维雇。
了解到知識(shí)蒸餾的具體步驟之后淤刃,我們采用keras在mnist數(shù)據(jù)集上進(jìn)行一次簡(jiǎn)單的實(shí)驗(yàn)。
知識(shí)蒸餾實(shí)戰(zhàn)
導(dǎo)入一下必要的python 包吱型,同時(shí)載入數(shù)據(jù)钝凶。
from keras.datasets import mnist
from keras.layers import *
from keras import Model
from sklearn.metrics import accuracy_score
import numpy as np
(data_train,label_train),(data_test,label_test )= mnist.load_data()
data_train = np.expand_dims(data_train,axis=3)
data_test = np.expand_dims(data_test,axis=3)
定義老師模型和學(xué)生模型
在下方代碼中,筆者定義了一個(gè)包含3層卷積層的CNN模型作為老師模型(參數(shù)量6萬(wàn))唁影,定義了一個(gè)包含512個(gè)神經(jīng)元的全連接層作為學(xué)生模型(參數(shù)量4萬(wàn)耕陷,比老師模型少了2萬(wàn))。
#####定義老師模型——包含三層卷積層的CNN模型
def teacher_model():
input_ = Input(shape=(28,28,1))
x = Conv2D(32,(3,3),padding = "same")(input_)
x = Activation("relu")(x)
print(x)
x = MaxPool2D((2,2))(x)
x = Conv2D(64,(3,3),padding= "same")(x)
x = Activation("relu")(x)
x = MaxPool2D((2,2))(x)
x = Conv2D(64,(3,3),padding= "same")(x)
x = Activation("relu")(x)
x = MaxPool2D((2,2))(x)
x = Flatten()(x)
out = Dense(10,activation = "softmax")(x)
model = Model(inputs=input_,outputs=out)
model.compile(loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"])
model.summary()
return model
###定義學(xué)生模型——— 一層含512個(gè)神經(jīng)元的全連接層
def student_model():
input_ = Input(shape=(28,28,1))
x = Flatten()(input_)
x = Dense(512,activation="sigmoid")(x)
out = Dense(10,activation = "softmax")(x)
model = Model(inputs=input_,outputs=out)
model.compile(loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"])
model.summary()
return model
訓(xùn)練老師模型
接下來(lái)開(kāi)始訓(xùn)練老師模型据沈,由于mnist數(shù)據(jù)集較為簡(jiǎn)單哟沫,在三層的CNN模型上,我設(shè)定只訓(xùn)練2個(gè)epoch锌介。這里需要注意的是嗜诀,如下圖所示:三層卷積的CNN的有6萬(wàn)多個(gè)參數(shù)。
t_model = teacher_model()
t_model.fit(data_train,label_train,batch_size=64,epochs=2,validation_data=(data_test,label_test))
訓(xùn)練結(jié)果如下圖所示:兩個(gè)epoch孔祸,CNN模型就在測(cè)試集上做到了98%的準(zhǔn)確性隆敢。
訓(xùn)練學(xué)生模型
在512個(gè)神經(jīng)元的全連接層上訓(xùn)練mnist數(shù)據(jù)集,學(xué)生模型的參數(shù)量如下圖所示:參數(shù)量只有4萬(wàn)個(gè)崔慧,參數(shù)量比老師模型少了2萬(wàn)個(gè)
s_model = student_model()
s_model.fit(data_train,label_train,batch_size=64,epochs=10,validation_data=(data_test,label_test))
在學(xué)生模型上訓(xùn)練了10個(gè)epoch之后拂蝎,測(cè)試機(jī)準(zhǔn)確率最高也才達(dá)到0.9460,遠(yuǎn)低于CNN老師模型的0.98
老師模型教學(xué)生模型
最后我們用老師模型教學(xué)生模型惶室,進(jìn)行知識(shí)蒸餾温自。
首先我們采用下方代碼將目標(biāo)數(shù)據(jù)集的label替換成老師模型的輸出。
t_out = t_model.predict(data_train)
然后用學(xué)生模型去學(xué)習(xí)老師模型的輸出皇钞。
def teach_student(teacher_out, student_model,data_train,data_test,label_test):
t_out = teacher_out
s_model = student_model
for l in s_model.layers:
l.trainable = True
label_test = keras.utils.to_categorical(label_test)
model = Model(s_model.input,s_model.output)
model.compile(loss="categorical_crossentropy",
optimizer="adam")
model.fit(data_train,t_out,batch_size= 64,epochs = 5)
s_predict = np.argmax(model.predict(data_test),axis=1)
s_label = np.argmax(label_test,axis=1)
print(accuracy_score(s_predict,s_label))
最終得到的實(shí)驗(yàn)結(jié)果如下圖所示:學(xué)生模型的性能提升到了0.9511悼泌,相比于學(xué)生模型在目標(biāo)數(shù)據(jù)集上的最好成績(jī)0.9460提升了千分之6個(gè)點(diǎn)。這也證明我們知識(shí)蒸餾確實(shí)起作用了夹界。
結(jié)語(yǔ)
當(dāng)然我們也發(fā)現(xiàn)馆里,我們的實(shí)驗(yàn)提升的幅度并不大,離老師模型的準(zhǔn)確度還有巨大的差距,而要想優(yōu)化知識(shí)蒸餾的性能鸠踪,我們可以采取升溫技術(shù)以舒,升溫技術(shù)的原理圖如下圖所示:將老師模型的輸出在softmax激活函數(shù)之前初上一個(gè)數(shù)值大于1的數(shù)字T,這樣會(huì)使得老師模型輸出的個(gè)類別概率值變得較為接近慢哈。
確實(shí)升溫技術(shù)的主要目的就是將老師模型輸出的各類型的概率,變得較為接近永票,這樣老師模型的輸出信息將變得更加豐富卵贱,得學(xué)生模型學(xué)會(huì)分辨出個(gè)類別之間細(xì)微的區(qū)別。當(dāng)然知識(shí)蒸餾的優(yōu)化方法并不只上述的升溫技術(shù)這一種侣集,這里筆者只是拋磚引玉键俱,知識(shí)蒸餾還有更多的奧秘等著大家去探索,去學(xué)習(xí)世分。希望讀者能夠有所收獲的同時(shí)编振,心中的好奇心也能夠被激發(fā),主動(dòng)的學(xué)習(xí)知識(shí)蒸餾這門技術(shù)臭埋。
參考
https://arxiv.org/pdf/1503.02531.pdf
https://github.com/johnkorn/distillation
https://www.bilibili.com/video/av46561029/?p=54