LearnFromPapers系列——用“模型想象出來的target”來訓(xùn)練可以提高分類的效果
<center>作者:郭必?fù)P</center>
<center>時(shí)間:2020年最后一天</center>
前言:今天是2020年最后一天樟插,這篇文章也是我的SimpleAI公眾號(hào)2020年的最后一篇推文吮播,感謝大家一直以來的陪伴和支持,希望SimpleAI曾帶給各位可愛的讀者們一點(diǎn)點(diǎn)的收獲吧~這么特殊的一天,我也來介紹一篇特殊的論文他挎,那就是今年我和組里幾位老師合作的一篇AAAI論文:“Label Confusion Learning to Enhance Text Classification Models”。這篇文章的主要思想是通過構(gòu)造一個(gè)“標(biāo)簽混淆模型”來實(shí)時(shí)地“想象”一個(gè)比one-hot更好的標(biāo)簽分布牺陶,從而使得各種深度學(xué)習(xí)模型(LSTM、CNN辣之、BERT)在分類問題上都能得到更好的效果掰伸。個(gè)人感覺,還是有召烂、意思的碱工。
- 論文標(biāo)題:Label Confusion Learning to Enhance Text Classification Models
- 會(huì)議/期刊:AAAI-21
- 團(tuán)隊(duì):上海財(cái)經(jīng)大學(xué) 信息管理與工程學(xué)院 AI Lab
一、主要貢獻(xiàn)
本文的主要貢獻(xiàn)有這么幾點(diǎn):
- 構(gòu)造了一個(gè)插件--"Label Confusion Model(LCM)"奏夫,可以在模型訓(xùn)練的時(shí)候?qū)崟r(shí)計(jì)算樣本和標(biāo)簽間的關(guān)系怕篷,從而生成一個(gè)標(biāo)簽分布,作為訓(xùn)練的target酗昼,實(shí)驗(yàn)證明廊谓,這個(gè)新的target比one-hot標(biāo)簽更好;
- 這個(gè)插件不需要任何外部的知識(shí)麻削,也僅僅在訓(xùn)練的時(shí)候才需要蒸痹,不會(huì)增加模型預(yù)測時(shí)的時(shí)間,不改變?cè)P偷慕Y(jié)構(gòu)呛哟。所以LCM的應(yīng)用范圍很廣叠荠;
- 實(shí)驗(yàn)發(fā)現(xiàn)LCM還具有出色的抗噪性和抗干擾能力,對(duì)于有錯(cuò)標(biāo)的數(shù)據(jù)集扫责,或者標(biāo)簽間相似度很高的數(shù)據(jù)集榛鼎,有更好的表現(xiàn)。
二鳖孤、問題背景者娱、相關(guān)工作
1. 用one-hot來訓(xùn)練不夠好
本文主要是從文本分類的角度出發(fā)的,但文本分類和圖像分類實(shí)際上在訓(xùn)練模式上是類似的苏揣,基本都遵循這樣的一個(gè)流程:
step 1. 一個(gè)深度網(wǎng)絡(luò)(DNN黄鳍,諸如LSTM、CNN平匈、BERT等)來得到向量表示
step 2. 一個(gè)softmax分類器來輸出預(yù)測的標(biāo)簽概率分布p
step 3. 使用Cross-entropy來計(jì)算真實(shí)標(biāo)簽(one-hot表示)與p之間的損失框沟,從而優(yōu)化
這里使用cross-entropy loss(簡稱CE-loss)基本上成了大家訓(xùn)練模型的默認(rèn)方法,但它實(shí)際上存在一些問題增炭。下面我舉個(gè)例子:
比如有一個(gè)六個(gè)類別的分類任務(wù)街望,CE-loss是如何計(jì)算當(dāng)前某個(gè)預(yù)測概率p相對(duì)于y的損失呢:
可以看出,根據(jù)CE-loss的公式弟跑,只有y中為1的那一維度參與了loss的計(jì)算灾前,其他的都忽略了。這樣就會(huì)造成一些后果:
- 真實(shí)標(biāo)簽跟其他標(biāo)簽之間的關(guān)系被忽略了孟辑,很多有用的知識(shí)無法學(xué)到哎甲;比如:“鳥”和“飛機(jī)”本來也比較像蔫敲,因此如果模型預(yù)測覺得二者更接近,那么應(yīng)該給予更小的loss
- 傾向于讓模型更加“武斷”炭玫,成為一個(gè)“非黑即白”的模型奈嘿,導(dǎo)致泛化性能差;
- 面對(duì)易混淆的分類任務(wù)吞加、有噪音(誤打標(biāo))的數(shù)據(jù)集時(shí)裙犹,更容易受影響
總之,這都是由one-hot的不合理表示造成的衔憨,因?yàn)閛ne-hot只是對(duì)真實(shí)情況的一種簡化叶圃。
2. 一些可能的解決辦法
LDL:
既然one-hot不合理,那我們就使用更合理的標(biāo)簽分布來訓(xùn)練嘛践图。比如下圖所示:
如果我們能獲取真實(shí)的標(biāo)簽分布來訓(xùn)練掺冠,那該多好啊。
這種使用標(biāo)簽的分布來學(xué)習(xí)模型的方法码党,稱為LDL(Label Distribution Learning)德崭,東南大學(xué)耿新團(tuán)隊(duì)專門研究這個(gè)方面,大家可以去了解一下揖盘。
但是眉厨,真實(shí)的標(biāo)簽分布,往往很難獲取兽狭,甚至不可獲取缺猛,只能模擬。比如找很多人來投票椭符,或者通過觀察進(jìn)行統(tǒng)計(jì)。比如在耿新他們最初的LDL論文中耻姥,提出了很多生物數(shù)據(jù)集销钝,是通過實(shí)驗(yàn)觀察來得到的標(biāo)簽分布。然而琐簇,大多數(shù)的現(xiàn)有的數(shù)據(jù)集蒸健,尤其是文本、圖像分類婉商,幾乎都是one-hot的似忧,所以LDL并無法直接使用。
Label Enhancement:
Label Enhancement丈秩,機(jī)標(biāo)簽增強(qiáng)技術(shù)盯捌,則是一類從通過樣本特征空間來生成標(biāo)簽分布的方法,我在前面的論文解讀中有介紹蘑秽,這些方法都很有趣饺著。
然而箫攀,使用這些方法來訓(xùn)練模型,都比較麻煩幼衰,因?yàn)槲覀冃枰ㄟ^“兩步走”來訓(xùn)練靴跛,第一步使用LE的方法來構(gòu)造標(biāo)簽分布,第二步再使用標(biāo)簽分布來訓(xùn)練渡嚣。
Loss Correction:
面對(duì)one-hot可能帶來的容易過擬合的問題梢睛,有研究提出了Label Smoothing方法:
label smoothing就是把原來的one-hot表示,在每一維上都添加了一個(gè)隨機(jī)噪音识椰。這是一種簡單粗暴绝葡,但又十分有效的方法,目前已經(jīng)使用在很多的圖像分類模型中了裤唠。
這種方法挤牛,一定程度上,可以緩解模型過于武斷的問題种蘸,也有一定的抗噪能力墓赴。但是單純地添加隨機(jī)噪音,也無法反映標(biāo)簽之間的關(guān)系航瞭,因此對(duì)模型的提升有限诫硕,甚至有欠擬合的風(fēng)險(xiǎn)。
當(dāng)然還有一些其他的Loss Correction方法刊侯,可以參考我前面的一個(gè)介紹章办。
三、我們的思想&模型設(shè)計(jì)
我們最終的目標(biāo)滨彻,是能夠使用更加合理的標(biāo)簽分布來代替one-hot分布訓(xùn)練模型藕届,最好這個(gè)過程能夠和模型的訓(xùn)練同步進(jìn)行。
首先我們思考亭饵,一個(gè)合理的標(biāo)簽分布休偶,應(yīng)該有什么樣的性質(zhì)。
① 很自然地辜羊,標(biāo)簽分布應(yīng)該可以反映標(biāo)簽之間的相似性踏兜。
比方下面這個(gè)例子:
② 標(biāo)簽間的相似性是相對(duì)的,要根據(jù)具體的樣本內(nèi)容來看八秃。
比方下面這個(gè)例子碱妆,同樣的標(biāo)簽,對(duì)于不同的句子昔驱,標(biāo)簽之間的相似度也是不一樣的:
③ 構(gòu)造得到的標(biāo)簽分布疹尾,在01化之后應(yīng)該跟原one-hot表示相同。
啥意思呢,就是我們不能構(gòu)造出了一個(gè)標(biāo)簽分布航棱,最大值對(duì)應(yīng)的標(biāo)簽跟原本的one-hot標(biāo)簽還不一致睡雇,我們最終的標(biāo)簽分布,還是要以one-hot為標(biāo)桿來構(gòu)造饮醇。
根據(jù)上面的思考它抱,我們這樣來設(shè)計(jì)模型:
使用一個(gè)Label Encoder來學(xué)習(xí)各個(gè)label的表示,與input sample的向量表示計(jì)算相似度朴艰,從而得到一個(gè)反映標(biāo)簽之間的混淆/相似程度的分布观蓄。最后,使用該混淆分布來調(diào)整原來的one-hot分布祠墅,從而得到一個(gè)更好的標(biāo)簽分布侮穿。
設(shè)計(jì)出來的模型結(jié)構(gòu)如圖:
這個(gè)結(jié)構(gòu)分兩部分,左邊是一個(gè)Basic Predictor毁嗦,就是各種我們常用的分類模型亲茅。右邊的則是LCM的模型。注意LCM是一個(gè)插件狗准,所以左側(cè)可以更換成任何深度學(xué)習(xí)模型克锣。
Basic Predictor的過程可以用如下公式表達(dá):
其中就是輸入的文本的通過Input Decoder得到的表示。
則是predicted label distribution(PLD)腔长。
LCM的過程可以表達(dá)為:
其中代表label通過Label Encoder得到的標(biāo)簽表示矩陣袭祟,
是標(biāo)簽和輸入文本的相似度得到的標(biāo)簽混淆分布,
是真實(shí)的one-hot表示捞附,二者通過一個(gè)超參數(shù)結(jié)合再歸一化巾乳,得到最終的
,即模擬標(biāo)簽分布鸟召,simulated label distribution(SLD)胆绊。
最后,我們使用KL散度來計(jì)算loss:
總體來說還是比較簡單的欧募,很好復(fù)現(xiàn)压状,其實(shí)也存在更優(yōu)的模型結(jié)構(gòu),我們還在探究槽片。
四、實(shí)驗(yàn)&結(jié)果分析
1. Benchmark數(shù)據(jù)集上的測試
我們使用了2個(gè)中文數(shù)據(jù)集和3個(gè)英文數(shù)據(jù)集肢础,在LSTM还栓、CNN、BERT三種模型架構(gòu)上進(jìn)行測試传轰,實(shí)驗(yàn)表明LCM可以在絕大多數(shù)情況下剩盒,提升主流模型的分類效果。
下面這個(gè)圖展示了不同水平的α超參數(shù)對(duì)模型的影響:
從圖中可以看出慨蛙,不管α水平如何辽聊,LCM加成的模型纪挎,都可以顯著提高收斂速度,最終的準(zhǔn)確率也更高跟匆。針對(duì)不同的數(shù)據(jù)集特征异袄,我們可以使用不同的α(比如數(shù)據(jù)集混淆程度大,可以使用較小的α)玛臂,另外烤蜕,論文中我們還介紹了在使用較小α的時(shí)候,可以使用early-stop策略來防止過擬合迹冤。
而下面這個(gè)圖則展示了LCM確實(shí)可以學(xué)習(xí)到label之間的一些相似性關(guān)系讽营,而且是從完全隨機(jī)的初始狀態(tài)開始學(xué)到的:
2. 難以區(qū)分的數(shù)據(jù)集(標(biāo)簽易混淆)
我們構(gòu)造了幾個(gè)“簡單的”和“困難的”數(shù)據(jù)集,通過實(shí)驗(yàn)標(biāo)簽泡徙,LCM更適合那些容易混淆的數(shù)據(jù)集:
3. 有噪音的數(shù)據(jù)集
我們還測試了在不同噪音水平下的數(shù)據(jù)集上的效果橱鹏,并跟Label Smoothing方法做了對(duì)比,發(fā)現(xiàn)是顯著好于LS方法的堪藐。
下面這個(gè)圖展示了另外一組更細(xì)致的實(shí)驗(yàn)結(jié)果:
4. 在圖像分類上也有效果
最后莉兰,我們?cè)趫D像任務(wù)上也簡單測試了一下,發(fā)現(xiàn)也有效果:
總結(jié):