Microsoft在2020年提出了TwinBERT: Distilling Knowledge to Twin-Structured Compressed BERT Models for Large-Scale Retrieval這篇論文萎战。今天有幸看了看届吁,簡單的跟大家分享下擒悬。
解決問題
論文主要解決的問題是:性能~ 性能~ 性能~~~
Online Server需要快速處理,尤其是在召回階段巡球,面對上億級Doc,為此減少在線計(jì)算大勢所趨。
架構(gòu)原理
TwinBert就是在這種背景下應(yīng)運(yùn)而生的厅篓,如下圖結(jié)構(gòu):
主要講下上面這張圖:
- 整體:
- 兩個(gè)對稱的Bert, 左邊的Bert用于Query建模捶码,右邊的Bert用于Title keyword建模(或者Doc Context keyword建模)羽氮。
- 兩個(gè)Bert走完后,再各自經(jīng)過一個(gè)Pooling Layer惫恼,池化層档押,聽起來很高大上,其實(shí)很簡單祈纯,主要是將序列中每個(gè)token的向量搞在一起令宿,做成一個(gè)向量。 Query做成一個(gè)向量腕窥, keyword做成一個(gè)向量粒没,以方便進(jìn)行后面的Cross Layer的交互。 池化層有兩個(gè)操作二選一簇爆,【用CLS】 或者 【所有tokens向量平均加權(quán)起來】癞松,其中后者權(quán)重是學(xué)出來的倾贰。
- 輸入 : 均為Word Embeding + Position Embeding。 因?yàn)閮蛇叾际且痪湓捓雇铮跃蜎]有了Segment Embeding了匆浙。
值得提一下是,論文中是訓(xùn)練的英文的模型厕妖,所對輸入進(jìn)行了Word Hashing首尼,具體說是使用了Tri-letter, 至于什么是Word Hashing ,見本人的另外一文章Word Hashing。
*Transformer Encoder
這里不多說言秸,其中L用的是6層软能。
- 池化層
見整體部分,已說明举畸。
*Cross Layer
Query做成一個(gè)向量q, keyword做成一個(gè)向量k查排,二者進(jìn)行距離計(jì)算,有兩種方式抄沮,一種是余弦相似度跋核,如下圖:
另一種是Residual network, 這里不多講叛买,有興趣砂代,自身翻閱。
如何訓(xùn)練率挣?
蒸餾方法訓(xùn)練刻伊。
teacher model
所以要搞一個(gè)teacher model,文章用的12層的 query和title關(guān)鍵詞的訓(xùn)練的椒功。二分類捶箱,分為相關(guān)和不相關(guān)。最后輸出一個(gè)概率动漾。student model
有了teacher model丁屎, 現(xiàn)在就開始teach學(xué)生把,將上面講的Cross layer做的輸出通過LR壓縮到區(qū)間(0,1)谦炬, 因?yàn)橛嘞业闹涤蚴荹-1悦屏,1].
然后做一個(gè)做交叉熵 cross entropy节沦。如下面公式:
優(yōu)點(diǎn)
節(jié)省性能键思,Query在線用Bert預(yù)測, Doc提前離線算好刷到索引甫贯。在線只需要做一次Query Bert預(yù)測吼鳞,以及與Doc的向量計(jì)算。