NLLLoss
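Negative log-likelihood loss, used for multi-class classification. Its input is log-probabilities.
For a batch of $N$ samples, $x$ is the network output after normalization and taking the logarithm, and $y$ holds the class label of each sample, where each sample belongs to one of $C$ classes.
$l_n$ is the loss of the $n$-th sample, with $0 \le y_n \le C-1$:
$$l_n = -w_{y_n} \, x_{n, y_n}$$
The per-class weight $w$ addresses sample imbalance between classes:
$$w_c = \text{weight}[c] \cdot \mathbb{1}\{c \ne \text{ignore\_index}\}$$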
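The per-sample definition can be checked by hand. The following is a minimal sketch, assuming random log-probabilities and made-up class weights (the values in `weight` are hypothetical, chosen only for illustration): it picks the log-probability of the target class for each sample, multiplies by the negative class weight, and compares against `F.nll_loss` with `reduction="none"`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 4, 3

# x: log-probabilities, i.e. network output after LogSoftmax
log_probs = F.log_softmax(torch.randn(N, C), dim=1)
# y: one class index per sample
target = torch.tensor([1, 1, 0, 2])
# Hypothetical per-class weights, for illustration only
weight = torch.tensor([1.0, 2.0, 0.5])

# l_n = -w_{y_n} * x_{n, y_n}
picked = log_probs[torch.arange(N), target]
manual = -weight[target] * picked

# Per-sample losses as computed by PyTorch
reference = F.nll_loss(log_probs, target, weight=weight, reduction="none")

nll_matches = torch.allclose(manual, reference)
print(nll_matches)
```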
from typing import Optional
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.loss import _WeightedLoss

class NLLLoss(_WeightedLoss):
    __constants__ = ['ignore_index', 'reduction']
    ignore_index: int

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super(NLLLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # The module only stores the settings; the computation happens in F.nll_loss
        assert self.weight is None or isinstance(self.weight, Tensor)
        return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
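In PyTorch this loss is implemented by the torch.nn.NLLLoss class, or you can call the F.nll_loss function directly. The weight argument in the code is the per-class weight $w$. size_average and reduce are deprecated. reduction takes one of three values, mean, sum, or none, which select what is returned: none returns the per-sample losses, sum returns their sum, and mean returns their weighted average (the sum divided by the total weight of the targets). The default is mean, which is the usual overall loss over the batch.
The ignore_index parameter names a class to ignore: targets with this value contribute nothing to the loss. The default is -100. For example, the label at padding positions can be set to ignore_index.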
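How reduction and ignore_index interact can be illustrated with a small example. This is a sketch under the assumption that the last two samples are padding (the name `PAD` is introduced here only for illustration): ignored positions get a per-sample loss of zero, and the mean is taken over the non-ignored samples only.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.8, 0.9, 0.3],
                       [0.8, 0.9, 0.3],
                       [0.8, 0.9, 0.3],
                       [0.8, 0.9, 0.3]])
log_probs = F.log_softmax(logits, dim=1)

PAD = -100  # illustrative name; equals the default ignore_index
# The last two samples are treated as padding
target = torch.tensor([1, 0, PAD, PAD])

per_sample = F.nll_loss(log_probs, target, ignore_index=PAD, reduction="none")
total = F.nll_loss(log_probs, target, ignore_index=PAD, reduction="sum")
mean = F.nll_loss(log_probs, target, ignore_index=PAD, reduction="mean")

print(per_sample)  # ignored positions contribute 0
print(total.item(), mean.item())
```

With no class weights, `mean` here equals `total` divided by the two non-ignored samples, not by the batch size of four.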
LogSoftmax
In PyTorch, the torch.nn.LogSoftmax module normalizes the network output and takes its logarithm: $\text{LogSoftmax}(x_i) = \log \frac{\exp(x_i)}{\sum_j \exp(x_j)}$
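As a quick check of what LogSoftmax computes, the sketch below (with arbitrary example logits) compares the module's output against the equivalent expression `x - logsumexp(x)` and confirms that exponentiating each row gives probabilities that sum to 1.

```python
import torch
import torch.nn as nn

# Arbitrary example logits: 2 samples, 3 classes
logits = torch.tensor([[0.8, 0.9, 0.3],
                       [2.0, -1.0, 0.5]])

log_softmax = nn.LogSoftmax(dim=1)  # normalize across the class dimension
out = log_softmax(logits)

# log(exp(x_i) / sum_j exp(x_j)) = x_i - log(sum_j exp(x_j))
manual = logits - torch.logsumexp(logits, dim=1, keepdim=True)

# Exponentiating the log-probabilities recovers a distribution per row
row_sums = out.exp().sum(dim=1)

print(torch.allclose(out, manual))
print(row_sums)
```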
CrossEntropyLoss
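Cross-entropy loss, used for multi-class classification. Its input is the unnormalized network output (logits).
For a batch of $N$ samples, $x$ is the unnormalized network output and $y$ holds the class label of each sample, where each sample belongs to one of $C$ classes.
$l_n$ is the loss of the $n$-th sample, with $0 \le y_n \le C-1$ and $w_{y_n}$ the weight of class $y_n$:
$$l_n = -w_{y_n} \log \frac{\exp(x_{n, y_n})}{\sum_{c=0}^{C-1} \exp(x_{n, c})}$$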
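The per-sample cross-entropy can be written out directly from the logits. A minimal sketch, assuming no class weights and arbitrary example logits: the loss for each sample is the log-sum-exp over its logits minus the logit of its target class, which should agree with `F.cross_entropy` using `reduction="none"`.

```python
import torch
import torch.nn.functional as F

# Arbitrary example logits: 2 samples, 3 classes
logits = torch.tensor([[0.8, 0.9, 0.3],
                       [2.0, -1.0, 0.5]])
target = torch.tensor([1, 0])

# l_n = -log(exp(x_{n,y_n}) / sum_c exp(x_{n,c}))
#     = logsumexp_c(x_{n,c}) - x_{n,y_n}
target_logits = logits[torch.arange(logits.size(0)), target]
manual = torch.logsumexp(logits, dim=1) - target_logits

reference = F.cross_entropy(logits, target, reduction="none")

ce_matches = torch.allclose(manual, reference)
print(ce_matches)
```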
from typing import Optional
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.loss import _WeightedLoss

class CrossEntropyLoss(_WeightedLoss):
    __constants__ = ['ignore_index', 'reduction']
    ignore_index: int

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # The module only stores the settings; the computation happens in F.cross_entropy
        assert self.weight is None or isinstance(self.weight, Tensor)
        return F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)
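In PyTorch this loss is implemented by the torch.nn.CrossEntropyLoss class, or you can call the F.cross_entropy function directly. The weight argument in the code is the per-class weight $w$. size_average and reduce are deprecated. reduction takes one of three values, mean, sum, or none, which select what is returned: none returns the per-sample losses, sum returns their sum, and mean returns their weighted average (the sum divided by the total weight of the targets). The default is mean, which is the usual overall loss over the batch.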
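One detail of the mean reduction is easy to miss when weight is set: the sum of per-sample losses is divided by the sum of the weights of the target classes, not by the batch size. The sketch below illustrates this with made-up class weights (the values in `weight` are hypothetical).

```python
import torch
import torch.nn as nn

# Arbitrary example logits: 3 samples, 3 classes
logits = torch.tensor([[0.8, 0.9, 0.3],
                       [2.0, -1.0, 0.5],
                       [0.1, 0.2, 0.3]])
target = torch.tensor([1, 0, 2])
# Hypothetical per-class weights, for illustration only
weight = torch.tensor([1.0, 3.0, 0.5])

# Per-sample losses already include the class weight w_{y_n}
per_sample = nn.CrossEntropyLoss(weight=weight, reduction="none")(logits, target)
weighted_mean = nn.CrossEntropyLoss(weight=weight, reduction="mean")(logits, target)

# mean = sum_n l_n / sum_n w_{y_n}
manual_mean = per_sample.sum() / weight[target].sum()
# Dividing by the batch size gives a different number in general
plain_mean = per_sample.mean()

print(weighted_mean.item(), manual_mean.item(), plain_mean.item())
```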
Verification that CrossEntropyLoss gives the same result as LogSoftmax followed by NLLLoss:
import torch
import torch.nn as nn

# Multi-class classification
m = nn.LogSoftmax(dim=1)
loss_nll_fct = nn.NLLLoss(reduction="mean")
loss_ce_fct = nn.CrossEntropyLoss(reduction="mean")
input_src = torch.tensor([[0.8, 0.9, 0.3], [0.8, 0.9, 0.3], [0.8, 0.9, 0.3], [0.8, 0.9, 0.3]])
target = torch.tensor([1, 1, 0, 0])
# 4 samples, 3 classes
print(input_src.size())
print(target.size())
# NLLLoss on the LogSoftmax output
output = m(input_src)
loss_nll = loss_nll_fct(output, target)
print(loss_nll.item())
# Check that CrossEntropyLoss on the raw input gives the same value
loss_ce = loss_ce_fct(input_src, target)
print(loss_ce.item())
Output:
torch.Size([4, 3])
torch.Size([4])
0.9475762844085693
0.9475762844085693