在了解交叉熵之前我們需要關(guān)于熵的一些基本知識(shí)磺芭,可以參考我的上一篇博客[1]侧漓。
1.信息熵
信息熵的定義為離散隨機(jī)事件的出現(xiàn)概率[2]。當(dāng)一個(gè)事件出現(xiàn)的概率更高的時(shí)候驾霜,我們認(rèn)為該事件會(huì)傳播的更廣赶盔,因此可以使用信息熵來衡量信息的價(jià)值斤程。
當(dāng)一個(gè)信源具有多種不同的結(jié)果荸恕,記為:U1,U2,...,Un乖酬,每個(gè)事件相互獨(dú)立,對(duì)應(yīng)的概率記為:P1,P2,...,Pn融求。信息熵為各個(gè)事件方式概率的期望咬像,公式為:
對(duì)于二分類問題,當(dāng)一種事件發(fā)生的概率為p時(shí)生宛,另一種事件發(fā)生的概率就為(1-p)县昂,因此,對(duì)于二分類問題的信息熵計(jì)算公式為:
2.相對(duì)熵(KL散度)
相對(duì)熵(relative entropy)陷舅,又被稱為Kullback-Leibler散度(Kullback-leibler divergence)倒彰,是兩個(gè)概率分布間差異的一種度量[3]。在信息論中莱睁,相對(duì)熵等于兩個(gè)概率分布的信息熵的差值待讳。
相對(duì)熵的計(jì)算公式為:
其中代表事件的真實(shí)概率,
代表事件的預(yù)測(cè)概率缩赛。例如三分類問題的標(biāo)簽為
耙箍,預(yù)測(cè)標(biāo)簽為
。
因此該公式的字面上含義就是真實(shí)事件的信息熵與理論擬合的事件的香農(nóng)信息量與真實(shí)事件的概率的乘積的差的累加酥馍。[4]
當(dāng)p(x)和q(x)相等時(shí)相對(duì)熵為0辩昆,其它情況下大于0。證明如下:
KL散度在Pytorch中的使用方法為:
torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)
在使用過程中旨袒,reduction
一般設(shè)置為batchmean
這樣才符合數(shù)學(xué)公式而不是mean
汁针,在以后的版本中mean
會(huì)被替換掉。
此外砚尽,還要注意log_target
參數(shù)施无,因?yàn)樵谟?jì)算的過程中我們往往使用的是log softmax函數(shù)而不是softmax函數(shù)來避免underflow和overflow問題,因此我們要提前了解target是否經(jīng)過了log運(yùn)算必孤。
torch.nn.KLDivLoss()
會(huì)傳入兩個(gè)參數(shù)(input, target)
, input
是模型的預(yù)測(cè)輸出猾骡,target
是樣本的觀測(cè)標(biāo)簽。
kl_loss = nn.KLDivLoss(reduction="batchmean")
output = kl_loss(input, target)
下面我們用一個(gè)例子來看看torch.nn.KLDivLoss()
是如何使用的:
import torch
import torch.nn as nn
import torch.nn.functional as F
input = torch.randn(3, 5, requires_grad=True)
input = F.log_softmax(input, dim=1) # dim=1 每一行為一個(gè)樣本
target = torch.rand(3,5)
# target使用softmax
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=False)
output = kl_loss(input, F.softmax(target, dim=1))
print(output)
# target使用log_softmax
kl_loss_log = nn.KLDivLoss(reduction="batchmean", log_target=True)
output = kl_loss_log(input, F.log_softmax(target, dim=1))
print(output)
輸出結(jié)果如下:
tensor(0.3026, grad_fn=<DivBackward0>)
tensor(0.3026, grad_fn=<DivBackward0>)
3.交叉熵
相對(duì)熵可以寫成如下形式:
等式的前一項(xiàng)為真實(shí)事件的熵敷搪,后一部分為交叉熵[4]:
在機(jī)器學(xué)習(xí)中兴想,使用KL散度就可以評(píng)價(jià)真實(shí)標(biāo)簽與預(yù)測(cè)標(biāo)簽間的差異,但由于KL散度的第一項(xiàng)是個(gè)定值赡勘,故在優(yōu)化過程中只關(guān)注交叉熵就可以了嫂便。一般大多數(shù)機(jī)器學(xué)習(xí)算法會(huì)選擇交叉熵作為損失函數(shù)。
交叉熵在pytorch中可以調(diào)用如下函數(shù)實(shí)現(xiàn):
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
其計(jì)算方法如下所示[5]:
假設(shè)batch size為4闸与,待分類標(biāo)簽有3個(gè)毙替,隱藏層的輸出為:
input = torch.tensor([[ 0.8082, 1.3686, -0.6107],
[ 1.2787, 0.1579, 0.6178],
[-0.6033, -1.1306, 0.0672],
[-0.7814, 0.1185, -0.2945]])
經(jīng)過softmax
激活函數(shù)之后得到預(yù)測(cè)值:
output = nn.Softmax(dim=1)(input)
output:
tensor([[0.3341, 0.5851, 0.0808],
[0.5428, 0.1770, 0.2803],
[0.2821, 0.1665, 0.5515],
[0.1966, 0.4835, 0.3199]])
softmax函數(shù)的輸出結(jié)果每一行相加為1岸售。
假設(shè)這一個(gè)mini batch的標(biāo)簽為
[1,0,2,1]
根據(jù)交叉熵的公式:
代表真實(shí)標(biāo)簽,在真實(shí)標(biāo)簽中厂画,除了對(duì)應(yīng)類別其它類別的概率都為0凸丸,實(shí)際上,交叉熵可以簡(jiǎn)寫為:
所以該mini batch的loss的計(jì)算公式為(別忘了除以batch size木羹,我們最后求得的是mini batch的平均loss):
因此甲雅,我們還需要計(jì)算一次對(duì)數(shù):
output_log = torch.log(output)
output_log
計(jì)算結(jié)果為:
tensor([[-1.0964, -0.5360, -2.5153],
[-0.6111, -1.7319, -1.2720],
[-1.2657, -1.7930, -0.5952],
[-1.6266, -0.7267, -1.1397]])
根據(jù)交叉熵的計(jì)算公式解孙,loss的最終計(jì)算等式為:
運(yùn)算結(jié)果和pytorch內(nèi)置的交叉熵函數(shù)相同:
import torch
import torch.nn as nn
input = torch.tensor([[ 0.8082, 1.3686, -0.6107],
[ 1.2787, 0.1579, 0.6178],
[-0.6033, -1.1306, 0.0672],
[-0.7814, 0.1185, -0.2945]])
target = torch.tensor([1,0,2,1])
loss = nn.CrossEntropyLoss()
output = loss(input, target)
output.backward()
結(jié)果為:
tensor(0.6172)
除了torch.nn.CrosEntropyLoss()
函數(shù)外還有一個(gè)計(jì)算交叉熵的函數(shù)torch.nn.BCELoss()
坑填。與前者不同,該函數(shù)是用來計(jì)算二項(xiàng)分布(0-1分布)的交叉熵弛姜,因此輸出層只有一個(gè)神經(jīng)元(只能輸出0或者1)脐瑰。其公式為:
在pytorch中的函數(shù)為:
torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')
用一個(gè)實(shí)例來看看如何使用該函數(shù):
input = torch.tensor([-0.7001, -0.7231, -0.2049])
target = torch.tensor([0,0,1]).float()
m = nn.Sigmoid()
loss = nn.BCELoss()
output = loss(m(input), target)
output.backward()
輸出結(jié)果為:
tensor([0.5332])
它是如何計(jì)算的呢,我們接下來一步步分析:
首先輸入是:
input = [-0.7001, -0.7231, -0.2049]
需要經(jīng)過sigmoid
函數(shù)得到一個(gè)輸出
output_mid = m(input)
輸出結(jié)果為:
[0.3318, 0.3267, 0.4490]
然后我們根據(jù)二項(xiàng)分布交叉熵的公式:
得到loss
的如下計(jì)算公式:
和pytorch的內(nèi)置函數(shù)計(jì)算結(jié)果相同廷臼。
另外苍在,需要注意的是,當(dāng)使用交叉熵作為損失函數(shù)的時(shí)候荠商,標(biāo)簽不能為onehot形式寂恬,只能是一維的向量,例如莱没,當(dāng)batch size是5時(shí)初肉,這一個(gè)batch的標(biāo)簽只能時(shí)[0,1,4,2,6]這樣的形式。