Understanding:
To address class imbalance, the authors propose a new loss function, focal loss, obtained by modifying the standard cross-entropy loss. By down-weighting easily classified examples, it makes the model focus on hard examples during training. Focal loss mainly targets the severe positive/negative imbalance in one-stage object detection: it reduces the weight of the huge number of easy negatives in training, and can also be viewed as a form of hard example mining.
Focal loss is a modification of the cross-entropy loss. First, recall the binary cross-entropy loss:
$$ L_{ce} = \begin{cases} -\log y' & y = 1 \\ -\log(1 - y') & y = 0 \end{cases} $$
Here y' is the output after the activation function, so it lies in (0, 1). For plain cross entropy, the larger the predicted probability of a positive example, the smaller its loss; for a negative example, the smaller the predicted probability, the smaller the loss. When training is dominated by a large number of easy examples, this loss makes iteration slow and may fail to reach a good optimum. So how does focal loss improve on it?
First, a modulating factor is added to the original loss, where gamma > 0 reduces the loss of easily classified examples so that training focuses more on hard, misclassified examples:

$$ L_{fl} = \begin{cases} -(1 - y')^{\gamma} \log y' & y = 1 \\ -(y')^{\gamma} \log(1 - y') & y = 0 \end{cases} $$

For example, with gamma = 2: for a positive example, a prediction of 0.95 is clearly easy, so (1 - 0.95)^gamma is tiny and its loss becomes much smaller, whereas an example predicted at 0.3 keeps a relatively large loss. Likewise for negatives: a negative predicted at 0.1 should incur a far smaller loss than one predicted at 0.7. For a prediction of 0.5 the loss is only scaled down by a factor of 0.25, so these hard-to-separate examples receive comparatively more attention. In this way the influence of easy examples is reduced, and only the accumulated effect of a large number of them, each with a very small loss, can still matter.
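The numbers above are easy to verify. A minimal sketch (gamma = 2, illustrative values only, not from the paper) comparing the plain cross-entropy term with the focal term for a positive example:

import math

gamma = 2
for p in (0.95, 0.7, 0.5, 0.3):
    ce = -math.log(p)                   # plain cross entropy for y = 1
    fl = (1 - p) ** gamma * ce          # focal term, alpha omitted
    print(f"p={p:.2f}  ce={ce:.4f}  fl={fl:.4f}  scale={(1 - p) ** gamma:.4f}")
# p = 0.95 is scaled by 0.0025 while p = 0.5 is scaled only by 0.25,
# so easy examples are suppressed far more strongly than hard ones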
In addition, a balancing factor alpha is introduced to balance the inherently uneven ratio of positive to negative examples. The paper uses alpha = 0.25, i.e. positives are weighted less than negatives, because the negatives are mostly easy to classify.
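Combining the modulating factor with the class-balancing weight gives the full focal loss from the paper, written with p_t for the model's estimated probability of the ground-truth class and alpha_t for the class weight (alpha for positives, 1 - alpha for negatives):

$$ \mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t) $$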
Adding alpha alone balances the importance of positive and negative examples, but it does not address the easy-versus-hard example problem.
gamma controls how quickly easy examples are down-weighted: when gamma = 0 the loss reduces to the cross-entropy loss, and as gamma increases the effect of the modulating factor grows. The paper's experiments find gamma = 2 to work best.
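To see how gamma controls that rate, here is a tiny sketch (illustrative values only) of the modulating factor (1 - p_t)^gamma for an easy example with p_t = 0.9:

# modulating factor for an easy example (p_t = 0.9) at different gamma values
for gamma in (0, 0.5, 1, 2, 5):
    print(gamma, (1 - 0.9) ** gamma)
# gamma = 0 leaves the loss unchanged; gamma = 2 scales it by 0.01; gamma = 5 by 1e-5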
Focal loss implementation
- Simple two-class case:
import torch
import torch.nn.functional as F


class FocalLoss:
    def __init__(self, alpha_t=None, gamma=0):
        """
        :param alpha_t: a list of per-class weights, e.g. [alpha, 1 - alpha]
        :param gamma: focusing parameter; gamma = 0 recovers plain cross entropy
        """
        self.alpha_t = torch.tensor(alpha_t) if alpha_t is not None else None
        self.gamma = gamma

    def __call__(self, outputs, targets):
        if self.alpha_t is None and self.gamma == 0:
            # no class weights, no focusing: plain cross entropy
            focal_loss = F.cross_entropy(outputs, targets)
        elif self.alpha_t is not None and self.gamma == 0:
            # class-balanced cross entropy only
            if self.alpha_t.device != outputs.device:
                self.alpha_t = self.alpha_t.to(outputs)
            focal_loss = F.cross_entropy(outputs, targets, weight=self.alpha_t)
        elif self.alpha_t is None and self.gamma != 0:
            # focal term only: down-weight easy examples by (1 - p_t) ** gamma
            ce_loss = F.cross_entropy(outputs, targets, reduction='none')
            p_t = torch.exp(-ce_loss)  # probability of the true class
            focal_loss = ((1 - p_t) ** self.gamma * ce_loss).mean()
        elif self.alpha_t is not None and self.gamma != 0:
            # both class weights and the focal term
            if self.alpha_t.device != outputs.device:
                self.alpha_t = self.alpha_t.to(outputs)
            ce_loss = F.cross_entropy(outputs, targets, reduction='none')
            p_t = torch.exp(-ce_loss)
            ce_loss = F.cross_entropy(outputs, targets, weight=self.alpha_t, reduction='none')
            focal_loss = ((1 - p_t) ** self.gamma * ce_loss).mean()  # mean over the batch
        return focal_loss


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    outputs = torch.tensor([[2, 1.],
                            [2.5, 1]], device=device)
    targets = torch.tensor([0, 1], device=device)
    print(F.softmax(outputs, dim=1))
    fl = FocalLoss([0.5, 0.5], 2)
    loss = F.cross_entropy(outputs, targets)
    print(loss)
    print(fl(outputs, targets))
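For the binary case, recent torchvision releases also provide a sigmoid-based focal loss. If it is available in your installation it can serve as a cross-check (the snippet below assumes torchvision.ops.sigmoid_focal_loss with the usual alpha/gamma/reduction arguments; logits and float binary labels share the same shape):

import torch
from torchvision.ops import sigmoid_focal_loss  # needs a reasonably recent torchvision

logits = torch.tensor([2.0, -1.5, 0.3])   # raw scores for three independent binary predictions
targets = torch.tensor([1.0, 0.0, 1.0])   # binary ground truth as floats
print(sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean"))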
- The multi-class case is similar:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """
    Focal loss (https://arxiv.org/pdf/1708.02002.pdf)
    Shape:
        - input: (N, C)
        - target: (N)
        - Output: Scalar loss
    Examples:
        >>> loss = FocalLoss(gamma=2, alpha=[1.0] * 7, reduction="mean")
        >>> input = torch.randn(3, 7, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(7)
        >>> output = loss(input, target)
        >>> output.backward()
    """
    def __init__(self, gamma=0, alpha: List[float] = None, reduction="none"):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if alpha is not None:
            self.alpha = torch.FloatTensor(alpha)
        self.reduction = reduction

    def forward(self, input, target):
        # for reference, the unweighted cross entropy would be nn.CrossEntropyLoss()(input, target)
        # target: [N] -> [N, 1] so it can be used as an index for gather
        target = target.unsqueeze(-1)
        # [N, C]: class probabilities and log-probabilities
        pt = F.softmax(input, dim=-1)
        logpt = F.log_softmax(input, dim=-1)
        # [N]: probability / log-probability of the true class
        pt = pt.gather(1, target).squeeze(-1)
        logpt = logpt.gather(1, target).squeeze(-1)
        if self.alpha is not None:
            # [N] at[i] = alpha[target[i]]: per-example class weight
            if self.alpha.device != input.device:
                self.alpha = self.alpha.to(input.device)
            at = self.alpha.gather(0, target.squeeze(-1))
            logpt = logpt * at
        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.reduction == "none":
            return loss
        if self.reduction == "mean":
            return loss.mean()
        return loss.sum()

    @staticmethod
    def convert_binary_pred_to_two_dimension(x, is_logits=True):
        """
        Args:
            x: (*): logits (or probabilities if is_logits=False) that an instance has label 1
            is_logits: if True, x is a logit and sigmoid is applied; otherwise x is already a probability
        Returns:
            y: (*, 2), where y[*, 0] is the log prob of label 0 and y[*, 1] is the log prob of label 1
        """
        probs = torch.sigmoid(x) if is_logits else x
        probs = probs.unsqueeze(-1)
        probs = torch.cat([1 - probs, probs], dim=-1)
        logprob = torch.log(probs + 1e-4)  # 1e-4 to prevent being rounded to 0 in fp16
        return logprob

    def __str__(self):
        return f"Focal Loss gamma:{self.gamma}"

    def __repr__(self):
        return str(self)
if __name__ == '__main__':
    loss = FocalLoss(gamma=2, alpha=[1.0] * 7)
    input = torch.randn(3, 7, requires_grad=True)
    target = torch.empty(3, dtype=torch.long).random_(7)
    print(input)
    print(target)
    output = loss(input, target)
    print(output)
    # output.backward()  # with the default reduction="none" the output is per-example,
    #                    # so backward() would need an explicit gradient argument
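As a usage sketch (a hypothetical toy model and random data, assuming the FocalLoss class above), the loss plugs into a training step like any other criterion; with reduction="mean" it returns a scalar that can be backpropagated directly:

import torch
import torch.nn as nn

model = nn.Linear(10, 7)                          # toy classifier over 7 classes
criterion = FocalLoss(gamma=2, alpha=[1.0] * 7, reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(16, 10)                           # a fake mini-batch of 16 examples
y = torch.randint(0, 7, (16,))                    # random integer class labels
logits = model(x)
loss = criterion(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())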
Paper:
https://arxiv.org/pdf/1708.02002.pdf