寫在前面:
為了對(duì)比二分類和多分類對(duì)于模型訓(xùn)練結(jié)果的影響虱疏,首先需要對(duì)這兩種損失函數(shù)有比較清晰的了解。
有空需要做一版整理分析哦。鉴分。。
機(jī)器學(xué)習(xí)分類問題中常用到Cross Entropy損失函數(shù)(交叉熵?fù)p失函數(shù))带膀,但為什么它會(huì)在分類問題中這么有效呢志珍?我們先從一個(gè)簡單的分類例子來入手。
1. 預(yù)測政治傾向例子
我們希望根據(jù)一個(gè)人的年齡垛叨、性別伦糯、年收入等相互獨(dú)立的特征,來預(yù)測一個(gè)人的政治傾向嗽元,有三種可預(yù)測結(jié)果:民主黨敛纲、共和黨、其他黨剂癌。假設(shè)我們當(dāng)前有兩個(gè)邏輯回歸模型(參數(shù)不同)淤翔,這兩個(gè)模型都是通過softmax的方式得到對(duì)于每個(gè)預(yù)測結(jié)果的概率值:
好了,有了模型之后佩谷,我們需要通過定義損失函數(shù)來判斷模型在樣本上的表現(xiàn)了旁壮,那么我們可以定義哪些損失函數(shù)呢?
有了上面的直觀分析谐檀,我們可以清楚的看到抡谐,對(duì)于分類問題的損失函數(shù)來說,分類錯(cuò)誤率和均方誤差損失都不是很好的損失函數(shù)稚补,下面我們來看一下交叉熵?fù)p失函數(shù)的表現(xiàn)情況童叠。
3. 學(xué)習(xí)過程
交叉熵?fù)p失函數(shù)經(jīng)常用于分類問題中框喳,特別是在神經(jīng)網(wǎng)絡(luò)做分類問題時(shí)课幕,也經(jīng)常使用交叉熵作為損失函數(shù)厦坛,此外,由于交叉熵涉及到計(jì)算每個(gè)類別的概率乍惊,所以交叉熵幾乎每次都和sigmoid(或softmax)函數(shù)一起出現(xiàn)杜秸。
我們用神經(jīng)網(wǎng)絡(luò)最后一層輸出的情況,來看一眼整個(gè)模型預(yù)測润绎、獲得損失和學(xué)習(xí)的流程:
1撬碟、神經(jīng)網(wǎng)絡(luò)最后一層得到每個(gè)類別的得分scores;
2莉撇、該得分經(jīng)過sigmoid(或softmax)函數(shù)獲得概率輸出呢蛤;
3、模型預(yù)測的類別概率輸出與真實(shí)類別的one hot形式進(jìn)行交叉熵?fù)p失函數(shù)的計(jì)算棍郎。
學(xué)習(xí)任務(wù)分為二分類和多分類情況其障,我們分別討論這兩種情況的學(xué)習(xí)過程。
3.2 多分類情況
4.1 優(yōu)點(diǎn)
在用梯度下降法做參數(shù)更新的時(shí)候涂佃,模型學(xué)習(xí)的速度取決于兩個(gè)值:一励翼、學(xué)習(xí)率;二辜荠、偏導(dǎo)值汽抚。其中,學(xué)習(xí)率是我們需要設(shè)置的超參數(shù)伯病,所以我們重點(diǎn)關(guān)注偏導(dǎo)值造烁。從上面的式子中,我們發(fā)現(xiàn)狱从,偏導(dǎo)值的大小取決于 x和 sigmoid(s)-y膨蛮,我們重點(diǎn)關(guān)注后者,后者的大小值反映了我們模型的錯(cuò)誤程度季研,該值越大敞葛,說明模型效果越差,但是該值越大同時(shí)也會(huì)使得偏導(dǎo)值越大与涡,從而模型學(xué)習(xí)速度更快惹谐。所以,使用邏輯函數(shù)得到概率驼卖,并結(jié)合交叉熵當(dāng)損失函數(shù)時(shí)氨肌,在模型效果差的時(shí)候?qū)W習(xí)速度比較快,在模型效果好的時(shí)候?qū)W習(xí)速度變慢酌畜。
4.2 缺點(diǎn)
sigmoid(softmax)+cross-entropy loss 擅長于學(xué)習(xí)類間的信息怎囚,因?yàn)樗捎昧祟愰g競爭機(jī)制,它只關(guān)心對(duì)于正確標(biāo)簽預(yù)測概率的準(zhǔn)確性,忽略了其他非正確標(biāo)簽的差異恳守,導(dǎo)致學(xué)習(xí)到的特征比較散考婴。基于這個(gè)問題的優(yōu)化有很多催烘,比如對(duì)softmax進(jìn)行改進(jìn)沥阱,如L-Softmax、SM-Softmax伊群、AM-Softmax等考杉。
參考資料:
1、損失函數(shù) - 交叉熵?fù)p失函數(shù)
https://zhuanlan.zhihu.com/p/35709485
2舰始、詳解softmax函數(shù)以及相關(guān)求導(dǎo)過程
https://zhuanlan.zhihu.com/p/25723112
3崇棠、損失函數(shù) - MSE