對(duì)于深度學(xué)習(xí)而言,數(shù)據(jù)集非常重要值戳,但在實(shí)際項(xiàng)目中西篓,或多或少會(huì)碰見(jiàn)數(shù)據(jù)不平衡問(wèn)題愈腾。什么是數(shù)據(jù)不平衡呢?舉例來(lái)說(shuō),現(xiàn)在有一個(gè)任務(wù)是判斷西瓜是否成熟岂津,這是一個(gè)二分類問(wèn)題——西瓜是生的還是熟的虱黄,該任務(wù)的數(shù)據(jù)集由兩部分?jǐn)?shù)據(jù)組成,成熟西瓜與生西瓜吮成,假設(shè)生西瓜的樣本數(shù)量遠(yuǎn)遠(yuǎn)大于成熟西瓜樣本的數(shù)量橱乱,針對(duì)這樣的數(shù)據(jù)集訓(xùn)練出來(lái)的算法“偏向”于識(shí)別新樣本為生西瓜,存心讓你買不到甜的西瓜以解夏天之苦粱甫,這就是一個(gè)數(shù)據(jù)不平衡問(wèn)題泳叠。
針對(duì)數(shù)據(jù)不平衡問(wèn)題有相應(yīng)的處理辦法,比如對(duì)多數(shù)樣本進(jìn)行采樣使得其樣本數(shù)量級(jí)與少樣本數(shù)相近茶宵,或者是對(duì)少數(shù)樣本重復(fù)使用等危纫。最近恰好在面試中遇到一個(gè)數(shù)據(jù)不平衡問(wèn)題,這也是面試中經(jīng)常會(huì)出現(xiàn)的問(wèn)題之一乌庶,現(xiàn)向讀者分享此次解決問(wèn)題的心得种蝶。
數(shù)據(jù)集
訓(xùn)練數(shù)據(jù)中有三個(gè)標(biāo)簽,分別標(biāo)記為[1瞒大、2螃征、3],這意味著該問(wèn)題是一個(gè)多分類問(wèn)題透敌。訓(xùn)練數(shù)據(jù)集有17個(gè)特征以及38829個(gè)獨(dú)立數(shù)據(jù)點(diǎn)会傲。而在測(cè)試數(shù)據(jù)中锅棕,有16個(gè)沒(méi)有標(biāo)簽的特征和16641個(gè)數(shù)據(jù)點(diǎn)。該訓(xùn)練數(shù)據(jù)集非常不平衡淌山,大部分?jǐn)?shù)據(jù)是1類(95%)裸燎,而2類和3類分別有3.0%和0.87%的數(shù)據(jù),如下圖所示泼疑。
算法
經(jīng)過(guò)初步觀察德绿,決定采用隨機(jī)森林(RF)算法,因?yàn)樗鼉?yōu)于支持向量機(jī)退渗、Xgboost以及LightGBM算法移稳。在這個(gè)項(xiàng)目中選擇RF還有幾個(gè)原因:
機(jī)森林對(duì)過(guò)擬合具有很強(qiáng)的魯棒性;
參數(shù)化仍然非常直觀;
在這個(gè)項(xiàng)目中,有許多成功的用例將隨機(jī)森林算法用于高度不平衡的數(shù)據(jù)集;
個(gè)人有先前的算法實(shí)施經(jīng)驗(yàn);
為了找到最佳參數(shù)会油,使用scikit-sklearn實(shí)現(xiàn)的GridSearchCV對(duì)指定的參數(shù)值執(zhí)行網(wǎng)格搜索个粱,更多細(xì)節(jié)可以在本人的Github上找到。
為了處理數(shù)據(jù)不平衡問(wèn)題翻翩,使用了以下三種技術(shù):
A.使用集成交叉驗(yàn)證(CV):
在這個(gè)項(xiàng)目中都许,使用交叉驗(yàn)證來(lái)驗(yàn)證模型的魯棒性。整個(gè)數(shù)據(jù)集被分成五個(gè)子集嫂冻。在每個(gè)交叉驗(yàn)證中胶征,使用其中的四個(gè)子集用于訓(xùn)練,剩余的子集用于驗(yàn)證模型桨仿,此外模型還對(duì)測(cè)試數(shù)據(jù)進(jìn)行了預(yù)測(cè)睛低。在交叉驗(yàn)證結(jié)束時(shí),會(huì)得到五個(gè)測(cè)試預(yù)測(cè)概率服傍。最后钱雷,對(duì)所有類別的概率取平均值。模型的訓(xùn)練表現(xiàn)穩(wěn)定吹零,每個(gè)交叉驗(yàn)證上具有穩(wěn)定的召回率和f1分?jǐn)?shù)罩抗。這項(xiàng)技術(shù)也幫助我在Kaggle比賽中取得了很好的成績(jī)(前1%)。以下部分代碼片段顯示了集成交叉驗(yàn)證的實(shí)現(xiàn):
B.設(shè)置類別權(quán)重/重要性:
代價(jià)敏感學(xué)習(xí)是使隨機(jī)森林更適合從非常不平衡的數(shù)據(jù)中學(xué)習(xí)的方法之一瘪校。隨機(jī)森林有傾向于偏向大多數(shù)類別。因此名段,對(duì)少數(shù)群體錯(cuò)誤分類施加昂貴的懲罰可能是有作用的阱扬。由于這種技術(shù)可以改善模型性能,所以我給少數(shù)群體分配了很高的權(quán)重(即更高的錯(cuò)誤分類成本)伸辟。然后將類別權(quán)重合并到隨機(jī)森林算法中麻惶。我根據(jù)類別1中數(shù)據(jù)集的數(shù)量與其它數(shù)據(jù)集的數(shù)量之間的比率來(lái)確定類別權(quán)重。例如信夫,類別1和類別3數(shù)據(jù)集的數(shù)目之間的比率約為110窃蹋,而類別1和類別2的比例約為26】▎現(xiàn)在我稍微對(duì)數(shù)量進(jìn)行修改以改善模型的性能,以下代碼片段顯示了不同類權(quán)重的實(shí)現(xiàn):
C.過(guò)大預(yù)測(cè)標(biāo)簽而不是過(guò)小預(yù)測(cè)(Over-Predict a Label than Under-Predict):
這項(xiàng)技術(shù)是可選的警没,通過(guò)實(shí)踐發(fā)現(xiàn)匈辱,這種方法對(duì)提高少數(shù)類別的表現(xiàn)非常有效。簡(jiǎn)而言之杀迹,如果將模型錯(cuò)誤分類為類別3亡脸,則該技術(shù)能最大限度地懲罰該模型,對(duì)于類別2和類別1懲罰力度稍差一些树酪。 為了實(shí)施該方法浅碾,我改變了每個(gè)類別的概率閾值,將類別3续语、類別2和類別1的概率設(shè)置為遞增順序(即垂谢,P3= 0.25,P2= 0.35疮茄,P1= 0.50)滥朱,以便模型被迫過(guò)度預(yù)測(cè)類別。該算法的詳細(xì)實(shí)現(xiàn)可以在Github上找到娃豹。
最終結(jié)果
以下結(jié)果表明焚虱,上述三種技術(shù)如何幫助改善模型性能:
1.使用集成交叉驗(yàn)證的結(jié)果:
2.使用集成交叉驗(yàn)證+類別權(quán)重的結(jié)果:
3.使用集成交叉驗(yàn)證+類別權(quán)重+過(guò)大預(yù)測(cè)標(biāo)簽的結(jié)果:
結(jié)論
由于在實(shí)施過(guò)大預(yù)測(cè)技術(shù)方面的經(jīng)驗(yàn)很少,因此最初的時(shí)候處理起來(lái)非常棘手懂版。但是鹃栽,研究該問(wèn)題有助于提升我解決問(wèn)題的能力。對(duì)于每個(gè)任務(wù)而言躯畴,起初可能確實(shí)是陌生的民鼓,這個(gè)時(shí)候不要害怕,一次次嘗試就好蓬抄。由于時(shí)間的限制(48小時(shí))丰嘉,無(wú)法將精力分散于模型的微調(diào)以及特征工程,存在改進(jìn)的地方還有很多嚷缭,比如刪除不必要的功能并添加一些額外功能饮亏。此外,也嘗試過(guò)LightGBM和XgBoost算法阅爽,但在實(shí)踐過(guò)程中發(fā)現(xiàn)路幸,隨機(jī)森林的效果優(yōu)于這兩個(gè)算法。在后面的研究中付翁,可以進(jìn)一步嘗試一些其他算法简肴,比如神經(jīng)網(wǎng)絡(luò)、稀疏編碼等百侧。