從應(yīng)用落地的角度來(lái)說(shuō)儿礼,bert雖然效果好,但有一個(gè)短板就是預(yù)訓(xùn)練模型太大乞巧,預(yù)測(cè)時(shí)間在平均在300ms以上(一條數(shù)據(jù))口蝠,無(wú)法滿足業(yè)務(wù)需求器钟。知識(shí)蒸餾是在較低成本下有效提升預(yù)測(cè)速度的方法。最近在看知識(shí)蒸餾方面的內(nèi)容妙蔗,對(duì)《DistillBert》做個(gè)簡(jiǎn)單的介紹傲霸。
提綱
1. Bert后演化的趨勢(shì)
2.知識(shí)蒸餾基本原理
3.《DistillBert》詳解
4. 后話
一、Bert后演化的趨勢(shì)
Bert后眉反,語(yǔ)義表示的基本框架已確定昙啄,后續(xù)大多模型以提升精度、提升速度來(lái)做寸五∈崃荩基本以知識(shí)蒸餾、提升算力梳杏、多任務(wù)學(xué)習(xí)韧拒、網(wǎng)絡(luò)結(jié)構(gòu)優(yōu)化四個(gè)方向來(lái)做。
如何提升速度?
invida發(fā)布transformer op秘狞,底層算子做fuse叭莫。
知識(shí)蒸餾,以distillBert和tinyBert為代表蹈集。
神經(jīng)網(wǎng)絡(luò)優(yōu)化技巧烁试。prune來(lái)裁剪多余的網(wǎng)絡(luò)節(jié)點(diǎn),混合精度(fp32和fp16混合來(lái)降低計(jì)算精度從而實(shí)現(xiàn)速度的提升)
如何提升精度拢肆?
增強(qiáng)算力减响。roberta
改進(jìn)網(wǎng)絡(luò)。xlnet郭怪,利用transformer-xl支示。
多任務(wù)學(xué)習(xí)(ensemble)。微軟發(fā)布的mk-dnn
二鄙才、知識(shí)蒸餾的基本原理
? ? 知識(shí)蒸餾是從算法層面提速的有效方式颂鸿,是趨勢(shì)之一。知識(shí)蒸餾從hinton大神14年《Distilling the Knowledge in a Neural Network》這篇paper而來(lái)攒庵。
? ? 定義兩個(gè)網(wǎng)絡(luò)嘴纺,一個(gè)teacher model败晴,一個(gè)student model。teacher model是預(yù)訓(xùn)練出來(lái)的大模型栽渴,teacher model eval結(jié)果出來(lái)的softlabel作為student model學(xué)習(xí)的一部分尖坤。student model的學(xué)習(xí)目標(biāo)由soft label和hard label組成。
? ? 其中有個(gè)核心的問(wèn)題闲擦,為什么要用soft label呢慢味?因?yàn)樽髡哒J(rèn)為softlabel中包含有hard label中沒(méi)有信息,也就是樣本的概率信息墅冷,可以達(dá)到泛化的效果纯路。
? ? 細(xì)節(jié)參考這篇博文:https://blog.csdn.net/nature553863/article/details/80568658
三、DistillBert
DistillBert的網(wǎng)絡(luò)結(jié)構(gòu):
student model的網(wǎng)絡(luò)結(jié)果與teacher model也就是bert的網(wǎng)絡(luò)結(jié)構(gòu)基本一致寞忿。主要包含如下改動(dòng):
每2層中去掉一層感昼。。作者調(diào)研后結(jié)果是隱藏層維度的變化比層數(shù)的變化對(duì)計(jì)算性能的影響較小罐脊,所以只改變了層數(shù)定嗓,把計(jì)算層數(shù)減小到原來(lái)的一半。
去掉了token type embedding和pooler萍桌。
每一層加了初始化宵溅,每一層的初始化為teacher model的參數(shù)。
2. 三個(gè)損失函數(shù):
(1)Lce損失函數(shù)
? ? ? Lce損失函數(shù)為T(mén)eacher model的soft label的損失函數(shù)上炎,Teacher model的logits ti/T(T 為溫度),通過(guò)softmax計(jì)算輸出得到teacher的概率分布恃逻,與student model logits si/T(T為溫度),通過(guò)softmax計(jì)算輸出得到student的概率分布藕施,最后計(jì)算兩個(gè)概率分布的KL散度寇损。
(2)Lmlm損失函數(shù)
? ? ? Lmlm損失函數(shù)為hard label的損失函數(shù),是bert 的masked language model的損失函數(shù)裳食。
(3)Lcos損失函數(shù)
? ? ? 計(jì)算teacher hidden state和student hidden state的余弦相似度矛市。官方代碼用的是:nn.CosineEmbeddingLoss。
整體計(jì)算公式為:? ?
Loss= 5.0*Lce+2.0* Lmlm+1.0* Lcos
3. 參數(shù)配置
training階段:計(jì)算8個(gè)卡诲祸,16GB浊吏,V100的GPU機(jī)器,90個(gè)小時(shí)
性能: DistilBERT 比Bert快71%救氯,訓(xùn)練參數(shù)為207 MB 找田。
四、實(shí)驗(yàn)結(jié)果
DistillBert在GLUE數(shù)據(jù)集上的表現(xiàn)
下圖為Ablation test的結(jié)果着憨,可以看出Lce墩衙、Lcos、參數(shù)初始化為結(jié)果影響較大。
五漆改、后話
? ? 知識(shí)蒸餾本質(zhì)是什么植袍?? 個(gè)人理解,其實(shí)知識(shí)蒸餾實(shí)際相當(dāng)于引入先驗(yàn)概率(prior knowledge)籽懦, soft label即是網(wǎng)絡(luò)輸入的先驗(yàn)概率于个,soft label與真實(shí)世界的事物類(lèi)似,呈各種概率分布暮顺。