Li L, Song D, Ma R, et al. KNN-BERT: fine-tuning pre-trained models with KNN classifier[J]. arXiv preprint arXiv:2110.02523, 2021.
摘要導(dǎo)讀
預(yù)訓(xùn)練模型被廣泛應(yīng)用于利用交叉熵?fù)p失優(yōu)化的線性分類器來微調(diào)下游任務(wù)房匆,可能會面臨魯棒性和穩(wěn)定性問題习霹。這些問題可以通過學(xué)習(xí)表示來改進(jìn)霸饲,即在做出預(yù)測時(shí)去關(guān)注在同一個(gè)類簇中表示的相似性厂汗,不同類簇之間的差異性旨剥。在本文中簇搅,作者提出將KNN分類器運(yùn)用到預(yù)訓(xùn)練模型的微調(diào)中庄涡。對于該KNN分類器,作者引入了一個(gè)有監(jiān)督的動量對比學(xué)習(xí)框架來學(xué)習(xí)有監(jiān)督的下游任務(wù)的聚類表示题禀。在大規(guī)模數(shù)據(jù)集和小樣本數(shù)據(jù)機(jī)上的文本分類實(shí)驗(yàn)和魯棒性測試都顯示了結(jié)合將KNN結(jié)合到傳統(tǒng)的微調(diào)過程中會得到很大的提升鞋诗。
模型淺析
本文中提出了KNN-BERT膀捷,利用KNN分類器時(shí)迈嘹,使用以BERT為代表的預(yù)訓(xùn)練模型作為文本表示編碼器。下面將從KNN分類器的效用和如何為KNN分類器設(shè)計(jì)文本表示的訓(xùn)練過程兩個(gè)方面進(jìn)行介紹全庸。
- KNN分類器
作者將一般的線性分類器與KNN分類器相結(jié)合秀仲,并使用加權(quán)平均logits來作為最終的預(yù)測logits。假設(shè)編碼后的文本表示為壶笼,其對應(yīng)的標(biāo)簽為神僵,線性分類器;這里使用由標(biāo)記的樣本來代表由余弦相似度選出的個(gè)近鄰樣本覆劈。
KNN對應(yīng)的logits是一個(gè)投票結(jié)果保礼,記為KNN。給定權(quán)重比重责语,最終的得分可以由如下的形式計(jì)算: - 用于KNN的對比學(xué)習(xí)
為了在預(yù)訓(xùn)練模型的微調(diào)中學(xué)習(xí)適用于KNN的表示胁赢,作者引入了一個(gè)監(jiān)督型對比學(xué)習(xí)框架,該框架使用標(biāo)簽信息來構(gòu)建對比學(xué)習(xí)的正負(fù)例樣本白筹。類似于info-nce損失智末,帶有監(jiān)督信息的對比損失定義為如下的形式:
一般來說它呀,傳統(tǒng)的對比學(xué)習(xí)基本就考慮到這里就可以了。但本文的作者對正例集合的構(gòu)造給出了一種全新的方式。
考慮到正例樣本的多樣化纵穿,即:他們來自同一個(gè)類簇但通過預(yù)訓(xùn)練模型的編碼他們會擁有不同的語義信息下隧。因此,重要的是要確定哪些正例樣本應(yīng)該用于對比損失的計(jì)算谓媒,否則淆院,學(xué)習(xí)到的表示可能不會得到緊密的類簇。因此句惯,作者提出了兩個(gè)學(xué)習(xí)表示的目標(biāo):1)使得同一個(gè)類簇中的樣本盡可能緊湊土辩;2)將那些不在同一個(gè)類簇中的樣本盡可能推遠(yuǎn)。
根據(jù)該目標(biāo)抢野,下圖展示了在對比學(xué)習(xí)中需要重點(diǎn)關(guān)注的兩類正例樣本:
基于這個(gè)出發(fā)點(diǎn),從原始的正例集合中選取個(gè)最相似的正例和個(gè)最不相似的正例恃轩,并且只針對這些選好的正例樣本來進(jìn)行表示的更新结洼。作者給出的理由是:計(jì)算所有的正例樣本可能會破壞與分類表示無關(guān)的語義信息;并且可能會影響分類結(jié)果叉跛,因?yàn)轭惔丶墑e的正例樣本表示可能與錨點(diǎn)樣本有很大的不同松忍。根據(jù)選定的正例樣本,前面的可以被重寫為: - 動量對比優(yōu)化
顯然,在對比學(xué)習(xí)訓(xùn)練過程中酥艳,使用大量的負(fù)例樣本可以幫助更好地采樣編碼表示的底層連續(xù)高維空間摊溶。因此,動量對比框架MoCo被用來以基于隊(duì)列更新策略來考慮大規(guī)模的負(fù)例樣本玖雁。在動量對比框架中更扁,包含兩個(gè)獨(dú)立的編碼器:針對查詢(錨點(diǎn))query的編碼器,針對key的編碼器赫冬。query編碼器由來自查詢樣本的梯度下降來更新浓镜,而key編碼器則由一個(gè)動量的過程來進(jìn)行更新:
首先將負(fù)例表示壓入循環(huán)隊(duì)列,只有在隊(duì)列末尾的樣本會通過key編碼器進(jìn)行編碼來更新补鼻。(注:這種更新是在key編碼器經(jīng)過動量更新之后執(zhí)行哄啄。)通過動量更新過程雅任,對比學(xué)習(xí)過程可以考慮大量的正負(fù)例樣本,因?yàn)樵撨^程不需要計(jì)算所有正負(fù)例的梯度咨跌。 -
雙目標(biāo)訓(xùn)練
最終的訓(xùn)練損失如下:
在訓(xùn)練的過程中沪么,查詢樣本和其對應(yīng)的正例和負(fù)例的編碼都由BERT中[CLS]token的輸出為對應(yīng)的表示。在微調(diào)的過程中锌半,作者將原始的交叉熵?fù)p失和對比損失結(jié)合到一起進(jìn)行表示學(xué)習(xí)禽车。從這里可以看出,用于分類的交叉熵?fù)p失是對標(biāo)簽信息的直接利用刊殉,而在對比學(xué)習(xí)中殉摔,則是利用標(biāo)簽信息進(jìn)行正負(fù)例的構(gòu)造,使得學(xué)習(xí)到的表示更有利于類簇的劃分记焊。
部分實(shí)驗(yàn)
筆者這里主要關(guān)注了最相似正例和最不相似正例選取的數(shù)量以及其對應(yīng)的比例:可以看出的一點(diǎn)是逸月,不同數(shù)量的hard-positives對性能的影響是非常重要的。這表明遍膜,引入適當(dāng)數(shù)量的hard-positives有利于學(xué)習(xí)更好的表示碗硬。
總體來說,對于基于BERT微調(diào)的分類任務(wù)捌归,作者引入KNN分類器來提供更加魯棒的分類預(yù)測結(jié)果肛响;在該目標(biāo)的驅(qū)動下岭粤,為KNN的有效預(yù)測設(shè)計(jì)了對應(yīng)的對比學(xué)習(xí)過程惜索。在該過程中,提出了基于類別標(biāo)簽的正例選擇方式剃浇,并且定義了兩種值得關(guān)注的正例樣本:與查詢樣本最相似的正例和與查詢樣本最不相似的正例巾兆。接著,引入動量對比框架以構(gòu)造更多的標(biāo)簽級別的正負(fù)例樣本對虎囚。環(huán)環(huán)相扣角塑,最終得到了顯著的性能提高。
其實(shí)筆者對基于隊(duì)列的負(fù)例更新策略不太能get到淘讥。可能類似這樣圃伶,將所有的樣本都push進(jìn)循環(huán)隊(duì)列,然后根據(jù)樣本標(biāo)簽來判斷哪些是可用負(fù)例蒲列?反正窒朋,key編碼器也不進(jìn)行參數(shù)的更新,一次用多少也不會增加計(jì)算量蝗岖。(: