[損失函數(shù)]——交叉熵

在了解交叉熵之前我們需要關(guān)于熵的一些基本知識(shí)磺芭,可以參考我的上一篇博客[1]侧漓。

1.信息熵

信息熵的定義為離散隨機(jī)事件的出現(xiàn)概率[2]。當(dāng)一個(gè)事件出現(xiàn)的概率更高的時(shí)候驾霜,我們認(rèn)為該事件會(huì)傳播的更廣赶盔,因此可以使用信息熵來衡量信息的價(jià)值斤程。
當(dāng)一個(gè)信源具有多種不同的結(jié)果荸恕,記為:U1,U2,...,Un乖酬,每個(gè)事件相互獨(dú)立,對(duì)應(yīng)的概率記為:P1,P2,...,Pn融求。信息熵為各個(gè)事件方式概率的期望咬像,公式為:
H(U)=E[-\log p_{i}]=-\sum_{i=1}^{n}p_{i}\log p_{i}

對(duì)于二分類問題,當(dāng)一種事件發(fā)生的概率為p時(shí)生宛,另一種事件發(fā)生的概率就為(1-p)县昂,因此,對(duì)于二分類問題的信息熵計(jì)算公式為:


2.相對(duì)熵(KL散度)

相對(duì)熵(relative entropy)陷舅,又被稱為Kullback-Leibler散度(Kullback-leibler divergence)倒彰,是兩個(gè)概率分布間差異的一種度量[3]。在信息論中莱睁,相對(duì)熵等于兩個(gè)概率分布的信息熵的差值待讳。
相對(duì)熵的計(jì)算公式為:
\begin{align} \text{KL}(P||Q) & = \sum_{i = 1}^{n} [p(x_{i})\log p(x_{i})-p(x_{i})\log q(x_{i})] \\ & = \sum_{i = 1}^{n}p(x_{i})\log \frac{p(x_{i})}{q(x_{i})} \end{align}
其中p(x)代表事件的真實(shí)概率,q(x)代表事件的預(yù)測(cè)概率缩赛。例如三分類問題的標(biāo)簽為(1,0,0)耙箍,預(yù)測(cè)標(biāo)簽為(0.7,0.1,0.2)

因此該公式的字面上含義就是真實(shí)事件的信息熵與理論擬合的事件的香農(nóng)信息量與真實(shí)事件的概率的乘積的差的累加酥馍。[4]

當(dāng)p(x)和q(x)相等時(shí)相對(duì)熵為0辩昆,其它情況下大于0。證明如下:


KL散度在Pytorch中的使用方法為:

torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)

在使用過程中旨袒,reduction一般設(shè)置為batchmean這樣才符合數(shù)學(xué)公式而不是mean汁针,在以后的版本中mean會(huì)被替換掉。
此外砚尽,還要注意log_target參數(shù)施无,因?yàn)樵谟?jì)算的過程中我們往往使用的是log softmax函數(shù)而不是softmax函數(shù)來避免underflow和overflow問題,因此我們要提前了解target是否經(jīng)過了log運(yùn)算必孤。

torch.nn.KLDivLoss()會(huì)傳入兩個(gè)參數(shù)(input, target), input是模型的預(yù)測(cè)輸出猾骡,target是樣本的觀測(cè)標(biāo)簽。

kl_loss = nn.KLDivLoss(reduction="batchmean")
output = kl_loss(input, target)

下面我們用一個(gè)例子來看看torch.nn.KLDivLoss()是如何使用的:

import torch
import torch.nn as nn
import torch.nn.functional as F

input = torch.randn(3, 5, requires_grad=True)
input = F.log_softmax(input, dim=1)  # dim=1 每一行為一個(gè)樣本

target = torch.rand(3,5)

# target使用softmax
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=False)
output = kl_loss(input, F.softmax(target, dim=1))
print(output)

# target使用log_softmax
kl_loss_log = nn.KLDivLoss(reduction="batchmean", log_target=True)
output = kl_loss_log(input, F.log_softmax(target, dim=1))
print(output)

輸出結(jié)果如下:

tensor(0.3026, grad_fn=<DivBackward0>)
tensor(0.3026, grad_fn=<DivBackward0>)

3.交叉熵

相對(duì)熵可以寫成如下形式:
D_{KL}(p||q)=\sum_{i=1}^{n}p(x_{i})\log p(x_{i})-\sum_{i=1}^{n}p(x_{i})\log q(x_{i})=-H(p(x)) +[-\sum_{i=1}^{n}p(x_{i})\log q(x_{i})]
等式的前一項(xiàng)為真實(shí)事件的熵敷搪,后一部分為交叉熵[4]
H(p,q)=-\sum_{i=1}^{n}p(x_{i})\log q(x_{i})
在機(jī)器學(xué)習(xí)中兴想,使用KL散度就可以評(píng)價(jià)真實(shí)標(biāo)簽與預(yù)測(cè)標(biāo)簽間的差異,但由于KL散度的第一項(xiàng)是個(gè)定值赡勘,故在優(yōu)化過程中只關(guān)注交叉熵就可以了嫂便。一般大多數(shù)機(jī)器學(xué)習(xí)算法會(huì)選擇交叉熵作為損失函數(shù)。

交叉熵在pytorch中可以調(diào)用如下函數(shù)實(shí)現(xiàn):

torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

其計(jì)算方法如下所示[5]
假設(shè)batch size為4闸与,待分類標(biāo)簽有3個(gè)毙替,隱藏層的輸出為:

input = torch.tensor([[ 0.8082,  1.3686, -0.6107],
                      [ 1.2787,  0.1579,  0.6178],
                      [-0.6033, -1.1306,  0.0672],
                      [-0.7814,  0.1185, -0.2945]])

經(jīng)過softmax激活函數(shù)之后得到預(yù)測(cè)值:

output = nn.Softmax(dim=1)(input)
output:
tensor([[0.3341, 0.5851, 0.0808],
        [0.5428, 0.1770, 0.2803],
        [0.2821, 0.1665, 0.5515],
        [0.1966, 0.4835, 0.3199]])

softmax函數(shù)的輸出結(jié)果每一行相加為1岸售。
假設(shè)這一個(gè)mini batch的標(biāo)簽為

[1,0,2,1]

根據(jù)交叉熵的公式:
H(p,q)=-\sum_{i=1}^{n}p(x_{i})\log q(x_{i})
p(x_{i})代表真實(shí)標(biāo)簽,在真實(shí)標(biāo)簽中厂画,除了對(duì)應(yīng)類別其它類別的概率都為0凸丸,實(shí)際上,交叉熵可以簡(jiǎn)寫為:
H(p,q)=-\log q(x_{class})
所以該mini batch的loss的計(jì)算公式為(別忘了除以batch size木羹,我們最后求得的是mini batch的平均loss):
Loss = - [log(0.5851) + log(0.5428) + log(0.5515) + log(0.4835)] / 4
因此甲雅,我們還需要計(jì)算一次對(duì)數(shù):

output_log = torch.log(output)
output_log

計(jì)算結(jié)果為:

tensor([[-1.0964, -0.5360, -2.5153],
        [-0.6111, -1.7319, -1.2720],
        [-1.2657, -1.7930, -0.5952],
        [-1.6266, -0.7267, -1.1397]])

根據(jù)交叉熵的計(jì)算公式解孙,loss的最終計(jì)算等式為:
loss = - (-0.5360 - 0.6111 - 0.5952 - 0.7267) / 4 = 0.61725
運(yùn)算結(jié)果和pytorch內(nèi)置的交叉熵函數(shù)相同:


import torch
import torch.nn as nn

input = torch.tensor([[ 0.8082,  1.3686, -0.6107],
        [ 1.2787,  0.1579,  0.6178],
        [-0.6033, -1.1306,  0.0672],
        [-0.7814,  0.1185, -0.2945]])
target = torch.tensor([1,0,2,1])

loss = nn.CrossEntropyLoss()
output = loss(input, target)
output.backward()

結(jié)果為:

tensor(0.6172)

除了torch.nn.CrosEntropyLoss()函數(shù)外還有一個(gè)計(jì)算交叉熵的函數(shù)torch.nn.BCELoss()坑填。與前者不同,該函數(shù)是用來計(jì)算二項(xiàng)分布(0-1分布)的交叉熵弛姜,因此輸出層只有一個(gè)神經(jīng)元(只能輸出0或者1)脐瑰。其公式為:
loss = -[y·logx+(1-y)·log(1-x)]
在pytorch中的函數(shù)為:

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')

用一個(gè)實(shí)例來看看如何使用該函數(shù):

input = torch.tensor([-0.7001, -0.7231, -0.2049])
target = torch.tensor([0,0,1]).float()
m = nn.Sigmoid()
loss = nn.BCELoss()
output = loss(m(input), target)
output.backward()

輸出結(jié)果為:

tensor([0.5332])

它是如何計(jì)算的呢,我們接下來一步步分析:

首先輸入是:

input = [-0.7001, -0.7231, -0.2049]

需要經(jīng)過sigmoid函數(shù)得到一個(gè)輸出

output_mid = m(input)

輸出結(jié)果為:

[0.3318, 0.3267, 0.4490]

然后我們根據(jù)二項(xiàng)分布交叉熵的公式:
loss = -[y·logx+(1-y)·log(1-x)]
得到loss的如下計(jì)算公式:

loss = - [1*\log (1-0.3318) + 1*\log (1-0.3267) + 1*\log (0.4490)]/3=0.5312

和pytorch的內(nèi)置函數(shù)計(jì)算結(jié)果相同廷臼。
另外苍在,需要注意的是,當(dāng)使用交叉熵作為損失函數(shù)的時(shí)候荠商,標(biāo)簽不能為onehot形式寂恬,只能是一維的向量,例如莱没,當(dāng)batch size是5時(shí)初肉,這一個(gè)batch的標(biāo)簽只能時(shí)[0,1,4,2,6]這樣的形式。


  1. 什么是熵饰躲,如何計(jì)算牙咏? ?

  2. 百度百科-信息熵 ?

  3. 百度百科-相對(duì)熵 ?

  4. 一文搞懂交叉熵在機(jī)器學(xué)習(xí)中的使用,透徹理解交叉熵背后的直覺 ?

  5. NLL_Loss & CrossEntropyLoss(交叉熵) ?

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末嘹裂,一起剝皮案震驚了整個(gè)濱河市妄壶,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌寄狼,老刑警劉巖丁寄,帶你破解...
    沈念sama閱讀 219,366評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異泊愧,居然都是意外死亡伊磺,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,521評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門拼卵,熙熙樓的掌柜王于貴愁眉苦臉地迎上來奢浑,“玉大人,你說我怎么就攤上這事腋腮∪副耍” “怎么了壤蚜?”我有些...
    開封第一講書人閱讀 165,689評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)徊哑。 經(jīng)常有香客問我袜刷,道長(zhǎng),這世上最難降的妖魔是什么莺丑? 我笑而不...
    開封第一講書人閱讀 58,925評(píng)論 1 295
  • 正文 為了忘掉前任著蟹,我火速辦了婚禮,結(jié)果婚禮上梢莽,老公的妹妹穿的比我還像新娘萧豆。我一直安慰自己,他們只是感情好昏名,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,942評(píng)論 6 392
  • 文/花漫 我一把揭開白布涮雷。 她就那樣靜靜地躺著,像睡著了一般轻局。 火紅的嫁衣襯著肌膚如雪洪鸭。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,727評(píng)論 1 305
  • 那天仑扑,我揣著相機(jī)與錄音览爵,去河邊找鬼。 笑死镇饮,一個(gè)胖子當(dāng)著我的面吹牛蜓竹,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播盒让,決...
    沈念sama閱讀 40,447評(píng)論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼梅肤,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了邑茄?” 一聲冷哼從身側(cè)響起姨蝴,我...
    開封第一講書人閱讀 39,349評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎肺缕,沒想到半個(gè)月后左医,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,820評(píng)論 1 317
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡同木,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,990評(píng)論 3 337
  • 正文 我和宋清朗相戀三年浮梢,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片彤路。...
    茶點(diǎn)故事閱讀 40,127評(píng)論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡秕硝,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出洲尊,到底是詐尸還是另有隱情远豺,我是刑警寧澤奈偏,帶...
    沈念sama閱讀 35,812評(píng)論 5 346
  • 正文 年R本政府宣布,位于F島的核電站躯护,受9級(jí)特大地震影響惊来,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜棺滞,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,471評(píng)論 3 331
  • 文/蒙蒙 一裁蚁、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧继准,春花似錦枉证、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,017評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)昂灵。三九已至避凝,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間眨补,已是汗流浹背管削。 一陣腳步聲響...
    開封第一講書人閱讀 33,142評(píng)論 1 272
  • 我被黑心中介騙來泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留撑螺,地道東北人含思。 一個(gè)月前我還...
    沈念sama閱讀 48,388評(píng)論 3 373
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像甘晤,于是被迫代替她去往敵國(guó)和親含潘。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,066評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容