1. Multi-label classification loss functions
PyTorch offers several ways to compute the loss for a multi-label classification task.
`binary_cross_entropy` and `binary_cross_entropy_with_logits` are functions from `torch.nn.functional`, while `BCELoss` and `BCEWithLogitsLoss` are classes from `torch.nn`. Their differences:
Name | Description |
---|---|
`binary_cross_entropy` | Function that measures the Binary Cross Entropy between the target and the output |
`binary_cross_entropy_with_logits` | Function that measures Binary Cross Entropy between target and output logits |
`BCELoss` | Criterion (class) that measures the Binary Cross Entropy between the target and the output |
`BCEWithLogitsLoss` | Criterion (class) that measures Binary Cross Entropy between target and output logits |
The only difference lies in the `logits` part. A loss whose name contains `with_logits` applies the sigmoid internally (fused with the log in a numerically stable way), so you pass it the raw network outputs (logits) directly; there is no need to map them into [0, 1] with a sigmoid before calling the loss. The variants without `with_logits` expect probabilities, i.e. inputs already mapped into [0, 1].
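The internal fusion is not just a convenience: for large logits, `sigmoid` saturates to exactly 1.0 in float32, and the separate `binary_cross_entropy` call can only clamp the resulting `log(0)` (PyTorch clamps its log outputs at -100), while the `with_logits` version works in log-space and returns the true value. A small sketch of the difference:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([200.0])   # a very large raw score
y = torch.tensor([0.0])     # target says "negative"

# sigmoid(200) rounds to exactly 1.0 in float32, so log(1 - p) would be
# log(0); binary_cross_entropy clamps its log outputs at -100, capping
# the per-element loss at 100.
saturated = F.binary_cross_entropy(torch.sigmoid(x), y)

# The with_logits version computes log(1 + exp(x)) stably in log-space
# and returns the true loss, 200.
stable = F.binary_cross_entropy_with_logits(x, y)
```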
`nn.functional.xxx` is the function interface, while `nn.Xxx` is a class wrapper around `nn.functional.xxx`; every `nn.Xxx` class inherits from the common ancestor `nn.Module`.
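To make the wrapper relationship concrete, here is a minimal sketch (the class name `MyBCEWithLogitsLoss` is hypothetical) of how such a class form typically stores configuration in `__init__` and delegates the math to the functional form in `forward`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyBCEWithLogitsLoss(nn.Module):
    """Hypothetical sketch: the nn.Xxx class form holds configuration
    and calls the nn.functional.xxx form to do the actual computation."""
    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional interface with the stored config.
        return F.binary_cross_entropy_with_logits(
            input, target, reduction=self.reduction
        )
```

On the same inputs this sketch produces the same value as `nn.BCEWithLogitsLoss()`; the class form exists so that configuration (reduction, weights) travels with the module rather than with every call site.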
```python
In [257]: import torch
In [258]: import torch.nn as nn
In [259]: import torch.nn.functional as F
In [260]: true = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
In [261]: pred = torch.rand((2, 3))  # raw scores, treated as logits below
In [262]: true
Out[262]:
tensor([[1., 0., 1.],
        [1., 0., 0.]])
In [263]: pred
Out[263]:
tensor([[0.0391, 0.7691, 0.1190],
        [0.8846, 0.1628, 0.2641]])
In [264]: F.binary_cross_entropy(torch.sigmoid(pred), true)
Out[264]: tensor(0.7361)
In [265]: F.binary_cross_entropy_with_logits(pred, true)
Out[265]: tensor(0.7361)
In [266]: lf2 = nn.BCELoss()
In [267]: lf2(torch.sigmoid(pred), true)
Out[267]: tensor(0.7361)
In [268]: lf = nn.BCEWithLogitsLoss()
In [269]: lf(pred, true)
Out[269]: tensor(0.7361)
# Manual check: mean of -(y*log(p) + (1-y)*log(1-p)) over all 6 elements
In [270]: torch.sum(-(true*torch.log(torch.sigmoid(pred)) + (1-true)*torch.log(1-torch.sigmoid(pred)))) / 6
Out[270]: tensor(0.7361)
```
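In practice, multi-label datasets are often imbalanced: most samples are negative for most labels. `nn.BCEWithLogitsLoss` exposes a `pos_weight` argument (one weight per label) that scales the positive term of the loss. A sketch, where the weight values are purely illustrative (a common heuristic is #negatives / #positives per label):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)                       # batch of 4 samples, 3 labels
targets = torch.randint(0, 2, (4, 3)).float()

# One weight per label; a value > 1 makes missed positives for that
# label cost more. The numbers here are illustrative assumptions.
pos_weight = torch.tensor([1.0, 2.0, 5.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, targets)
```

With `pos_weight` all ones, this reduces to the plain `BCEWithLogitsLoss` used above.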