For multi-label classification, the loss function is usually BCELoss or BCEWithLogitsLoss. The difference between them:
- BCELoss expects probabilities, i.e. outputs that have already been passed through Sigmoid
- BCEWithLogitsLoss fuses the two steps into one, applying Sigmoid and then BCELoss internally
A concrete worked example:
- Prepare the input tensor `input`:
```python
import torch
import torch.nn as nn

input = torch.tensor([[-0.4089, -1.2471,  0.5907],
                      [-0.4897, -0.8267, -0.7349],
                      [ 0.5241, -0.1246, -0.4751]])
print(input)
# tensor([[-0.4089, -1.2471,  0.5907],
#         [-0.4897, -0.8267, -0.7349],
#         [ 0.5241, -0.1246, -0.4751]])
```
- Sigmoid squashes every value into the range (0, 1):
```python
m = nn.Sigmoid()
S_input = m(input)
print(S_input)
# tensor([[0.3992, 0.2232, 0.6435],
#         [0.3800, 0.3043, 0.3241],
#         [0.6281, 0.4689, 0.3834]])
```
- Prepare the target tensor `target` (BCELoss requires a float target):
```python
target = torch.FloatTensor([[0, 1, 1],
                            [0, 0, 1],
                            [1, 0, 1]])
print(target)
# tensor([[0., 1., 1.],
#         [0., 0., 1.],
#         [1., 0., 1.]])
```
- Now compute the loss with BCELoss:
```python
BCELoss = nn.BCELoss()
loss = BCELoss(S_input, target)
print(loss)
# tensor(0.7193)
```
- The figure below shows how BCELoss computes the multi-label loss element by element; we can verify that it reproduces the result above:
- The following implementation verifies the calculation illustrated in the figure:
```python
loss = 0.0
for i in range(S_input.shape[0]):
    for j in range(S_input.shape[1]):
        # per-element binary cross entropy: -(y*log(p) + (1-y)*log(1-p))
        loss += -(target[i][j] * torch.log(S_input[i][j])
                  + (1 - target[i][j]) * torch.log(1 - S_input[i][j]))
print(loss / (S_input.shape[0] * S_input.shape[1]))  # the default reduction is the mean
# tensor(0.7193)
```
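- The double loop can also be written in vectorized form (a sketch I added, using the `S_input` and `target` values from above):

```python
import torch

S_input = torch.tensor([[0.3992, 0.2232, 0.6435],
                        [0.3800, 0.3043, 0.3241],
                        [0.6281, 0.4689, 0.3834]])
target = torch.tensor([[0., 1., 1.],
                       [0., 0., 1.],
                       [1., 0., 1.]])

# elementwise BCE, then the mean over all elements (PyTorch's default reduction)
elementwise = -(target * torch.log(S_input) + (1 - target) * torch.log(1 - S_input))
print(elementwise.mean())  # ≈ tensor(0.7193)
```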
- BCEWithLogitsLoss folds the Sigmoid and the log-based loss computation shown above into a single operation:
```python
BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
loss = BCEWithLogitsLoss(input, target)  # note: raw logits, not S_input
print(loss)
# tensor(0.7193)
```
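- A practical note I am adding beyond the original: BCEWithLogitsLoss is the preferred choice in practice because it combines the two steps using the log-sum-exp trick, so it stays numerically accurate for extreme logits, whereas a separate Sigmoid saturates to exactly 0 or 1 and BCELoss then has to clamp log(0) (PyTorch clamps its log terms at -100). A minimal sketch:

```python
import torch
import torch.nn as nn

# an extreme negative logit for a positive label
logits = torch.tensor([-1000.0])
target = torch.tensor([1.0])

unstable = nn.BCELoss()(torch.sigmoid(logits), target)  # sigmoid underflows to 0
stable = nn.BCEWithLogitsLoss()(logits, target)         # computed via log-sum-exp

print(unstable)  # tensor(100.) -- BCELoss clamps log(0) to -100, capping the loss
print(stable)    # tensor(1000.) -- the mathematically correct value
```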