交叉熵損失計算示例
交叉熵損失公式
其中y為label,p^為預測的正類別概率,即在二分類中通過sigmoid函數(shù)得出的正類別概率大小。
舉例:
criterion = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(input, target)
上述代碼即為求交叉熵損失的示例碎赢,輸入為隨機生成的三個具有5個標簽的樣本,目標標簽為同樣隨機生成的3個具有唯一正確標簽的樣本速梗,loss即為函數(shù)torch.nn.CrossEntropyLoss()求得的交叉熵損失肮塞。
torch.nn.CrossEntropyLoss()內的核心代碼如下:
```
batch_loss = 0.
for i in range(input.shape[0]):? # 遍歷樣本數(shù)
? ? numerator = np.exp(input[i, target[i]]) ? # 對每個目標標簽求指數(shù)
? ? denominator = np.sum(np.exp(input[i, :])) # 每個目標標簽所在樣本總體指數(shù)之和 ?
? ? loss = -np.log(numerator / denominator)? ? # 損失函數(shù)對數(shù)公式
? ? batch_loss += loss? # 所有樣本損失之和
```
其中,公式中損失函數(shù)只有對數(shù)log(numerator / denominator)原因是镀琉,其他標簽默認系數(shù)為0峦嗤,因此代碼中省略蕊唐,因此代碼也可寫成如下:
```
for i in range(input.shape[0]):
Loss[i]=-input[i, target[i]+sum(math.exp(input[i][:]))
loss = sum(Loss)
```