Background
At work I deal with binary classification problems where the data are mostly long-tailed, i.e. positive samples are far fewer than negative ones. Usually, adjusting the decision threshold (confidence) is enough to meet deployment requirements. But there are always some positive samples that score low, and I would like to find ways to raise the scores of these low-scoring positives without pushing the scores of negative samples up too much.
A model is trained by gradient updates. In practice most samples are easy to classify, and these samples contribute the bulk of the loss, so the model is biased toward them and performs poorly on the hard-to-classify samples.
So, to improve the model, two problems need to be addressed:
- How to handle class imbalance?
- How to handle hard positives and hard negatives effectively?
Focal Loss
It was proposed mainly for object detection, but it applies far more broadly.
In classification, the standard loss is cross-entropy:

CE(p, y) = -y·log(p) - (1 - y)·log(1 - p)

Writing p_t = p when y = 1 and p_t = 1 - p when y = 0, this is simply CE(p_t) = -log(p_t).
To address the positive/negative imbalance, the loss is multiplied by a class weight α_t:

CE(p_t) = -α_t·log(p_t)
α_t is usually set according to the class proportions, e.g. when class_1 makes up 30% of the data, take α = 0.7, so that the rarer class receives the larger weight.
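As a quick illustration (a sketch with made-up numbers, not code from the original post), the α-weighted binary cross-entropy can be written out directly:

import torch

# alpha-weighted binary cross-entropy:
# CE = -alpha * y * log(p) - (1 - alpha) * (1 - y) * log(1 - p)
alpha = 0.7                         # class_1 makes up 30% of the data
p = torch.tensor([0.8, 0.3])        # predicted probability of class 1
y = torch.tensor([1.0, 0.0])        # ground-truth labels
ce = -alpha * y * torch.log(p) - (1 - alpha) * (1 - y) * torch.log(1 - p)
print(ce)  # the positive (rarer) class gets the larger weight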
We want the model to focus more on the samples it tends to misclassify; put the other way round, we want it to pay less attention to the samples it already classifies easily. The idea of Focal Loss is therefore to down-weight the loss of high-confidence samples:

FL(p_t) = -(1 - p_t)^γ · log(p_t)

For multi-class samples, p_t is the predicted (softmax) probability of the true class, and the α-weighted form is:

FL(p_t) = -α_t · (1 - p_t)^γ · log(p_t)
The loss curves for different values of γ are shown in the figure from the Focal Loss paper (figure not reproduced here).
How does γ control the attenuation of the loss?
When a sample is misclassified, p_t is small, (1 - p_t)^γ is close to 1, and the loss is barely affected. When a sample is correctly classified, p_t is large, (1 - p_t)^γ becomes small, and the loss is attenuated.
For example: with γ = 2 and p_t = 0.9, (1 - 0.9)^2 = 0.01, so compared with cross-entropy the loss of this easy sample is attenuated by a factor of 100.
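A quick numerical check of this attenuation (a sketch, not from the original post): the ratio FL/CE is just the factor (1 - p_t)^γ.

import torch

p = torch.tensor([0.1, 0.5, 0.9, 0.99])   # predicted probability of the true class
for gamma in [0, 0.5, 1, 2, 5]:
    factor = (1 - p) ** gamma             # how much the CE loss is scaled down
    print(gamma, factor.tolist())
# with gamma = 2 and p = 0.9 the factor is 0.01, i.e. a 100x attenuation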
Code
# Binary classification
import torch
import torch.nn as nn
import torch.nn.functional as F

class BCEFocalLoss(torch.nn.Module):
    """
    https://github.com/louis-she/focal-loss.pytorch/blob/master/focal_loss.py
    Binary Focal Loss with a fixed alpha
    """
    def __init__(self, gamma=2, alpha=0.25, reduction='sum'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, preds, targets):
        "preds: [B, C], targets: [B]"
        pt = torch.sigmoid(preds)
        # clamp away from 0 and 1, otherwise log(pt) / log(1 - pt) returns nan
        pt = pt.clamp(min=0.0001, max=0.9999)
        # one-hot encode the targets; .to(targets.device) is required when running on GPU
        targets = torch.zeros(targets.size(0), 2).to(targets.device).scatter_(1, targets.view(-1, 1), 1)
        loss = - self.alpha * (1 - pt) ** self.gamma * targets * torch.log(pt) - \
               (1 - self.alpha) * pt ** self.gamma * (1 - targets) * torch.log(1 - pt)
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
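A minimal usage sketch for the class above (made-up logits and labels; assumes the imports at the top of the code section):

criterion = BCEFocalLoss(gamma=2, alpha=0.25, reduction='sum')
preds = torch.randn(4, 2)             # raw logits, shape [B, C] with C = 2
targets = torch.tensor([0, 1, 1, 0])  # class indices, shape [B]
loss = criterion(preds, targets)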
# Multi-class
class FocalLoss(nn.Module):
    """
    Ref: https://github.com/yatengLG/Focal-Loss-Pytorch/blob/master/Focal_Loss.py
    FL(pt) = -alpha_t (1 - pt)^gamma * log(pt)
    alpha: class weights; a scalar gives the weights [alpha, 1-alpha, 1-alpha, ...],
           a list gives one weight per class
    gamma: focusing parameter; larger gamma makes the model focus more on hard samples
    Benefit: helps with hard-to-classify, imbalanced data
    """
    def __init__(self, num_classes, alpha=0.25, gamma=2, reduce=True):
        super(FocalLoss, self).__init__()
        self.num_classes = num_classes
        self.gamma = gamma
        self.reduce = reduce
        if alpha is None:
            self.alpha = torch.ones(self.num_classes, 1)
        elif isinstance(alpha, (list, tuple)):
            self.alpha = torch.tensor(alpha, dtype=torch.float)
        else:
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] = alpha
            self.alpha[1:] += (1 - alpha)

    def forward(self, preds, targets):
        "preds: [B, C], targets: [B]"
        preds = preds.view(-1, preds.size(-1))  # [B, C]
        alpha = self.alpha.to(preds.device)
        logpt = F.log_softmax(preds, dim=1)
        pt = F.softmax(preds, dim=1).clamp(min=0.0001, max=1.0)
        logpt = logpt.gather(1, targets.view(-1, 1))  # log-probability of the true class
        pt = pt.gather(1, targets.view(-1, 1))        # probability of the true class
        # use a local copy so repeated forward calls do not overwrite self.alpha
        alpha = alpha.gather(0, targets.view(-1))     # per-sample class weight
        loss = -(1 - pt) ** self.gamma * logpt
        loss = alpha * loss.t()
        if self.reduce:
            return loss.mean()
        else:
            return loss.sum()
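And a similar sketch for the multi-class version (made-up data; 3 classes chosen only for illustration):

criterion = FocalLoss(num_classes=3, alpha=0.25, gamma=2)
preds = torch.randn(8, 3)             # raw logits, shape [B, C]
targets = torch.randint(0, 3, (8,))   # class indices, shape [B]
loss = criterion(preds, targets)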
GHM - gradient harmonizing mechanism
Focal Loss attenuates the loss of easily classified samples so that the model focuses more on the hard ones, and it is tuned through the hyperparameters α and γ.
GHM makes two further points:
- some of the hard samples are simply outliers, and the model should not pay too much attention to them;
- the gradient density can be computed directly from statistics of the samples, so there is no extra hyperparameter to tune.
GHM argues that class imbalance can be recast as an imbalance between easy and hard samples, and that this in turn can be viewed as an imbalance in the distribution of gradient density. If a positive sample is correctly classified, it is an easy positive; its loss is small and the model gains little from it. A misclassified sample does more to drive model updates. In practice, however, the vast majority of samples are easy to classify; each one contributes little, but because there are so many of them they dominate the gradient update and pull the model toward themselves.
- Left panel of the figure in the paper: the distribution of sample gradients. The density is high both where the gradient norm is very small and where it is very large. The former corresponds to the huge number of easily classified samples, whose gradients are small; the latter the paper treats as outliers, whose loss stays large even after the model has converged.
- Middle panel: the gradient distribution after rectification. Compared with CE and FL, GHM-C uses the gradient density to down-weight the accumulated gradient of both the mass of easy samples and the outliers, balancing the samples and making the model more effective and stable.
- Right panel: the cumulative gradient contribution of the sample set. After GHM-C's gradient-density adjustment, the contribution from samples of different difficulty is much smoother.
In short: Focal Loss adjusts the loss based on the confidence p itself, while GHM adjusts it based on the number of samples whose confidence falls within a given range of p.
Gradient norm
Gradient norm: the original paper writes the ground-truth label as p*; to keep the notation consistent, y is used here:

g = |p - y| = 1 - p if y = 1, and p if y = 0

Derivation: for a sigmoid output p = sigmoid(x) and the cross-entropy loss

L_CE = -y·log(p) - (1 - y)·log(1 - p)

then:

∂L_CE/∂x = p - y,  so  g = |∂L_CE/∂x| = |p - y|
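This can be checked quickly with autograd (a sketch with made-up logits):

import torch
import torch.nn.functional as F

x = torch.randn(5, requires_grad=True)   # logits
y = torch.tensor([1., 0., 1., 0., 1.])   # binary labels
loss = F.binary_cross_entropy_with_logits(x, y, reduction='sum')
loss.backward()
# dL_CE/dx = sigmoid(x) - y, so the gradient norm is |sigmoid(x) - y|
print(torch.allclose(x.grad, torch.sigmoid(x).detach() - y))  # True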
Gradient Density
Because the gradient norms are unevenly distributed, a gradient density is introduced:

GD(g) = (1 / l_ε(g)) · Σ_{k=1}^{N} δ_ε(g_k, g)

δ_ε(g_k, g) counts, among the N samples, those whose gradient norm falls in the interval [g - ε/2, g + ε/2):

δ_ε(g_k, g) = 1 if g - ε/2 ≤ g_k < g + ε/2, else 0

The interval length (clipped to [0, 1]):

l_ε(g) = min(g + ε/2, 1) - max(g - ε/2, 0)

The gradient density harmonizing parameter:

β_i = N / GD(g_i)

The denominator GD(g_i) can be viewed as a normalization over the samples whose gradient norm lies near g_i. If the gradients were uniformly distributed, β_i = 1; if the density around g_i is too high, β_i < 1, which means that sample gets down-weighted.
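A minimal sketch of these definitions using evenly spaced bins (so ε is the bin width and gradient norms lie in [0, 1]); the gradient_density_weights helper below is only illustrative and is not part of the GHM code that follows:

import torch

def gradient_density_weights(g, bins=10):
    """Histogram approximation of GD(g); returns beta_i = N / GD(g_i)."""
    N = g.numel()
    edges = torch.linspace(0, 1, bins + 1)
    edges[-1] += 1e-6                 # make sure g = 1 falls into the last bin
    eps = 1.0 / bins                  # interval length l_eps(g)
    beta = torch.zeros_like(g)
    for i in range(bins):
        inds = (g >= edges[i]) & (g < edges[i + 1])
        num_in_bin = inds.sum().item()
        if num_in_bin > 0:
            gd = num_in_bin / eps     # GD(g) = count in interval / interval length
            beta[inds] = N / gd       # beta_i = N / GD(g_i)
    return beta

# Example: most samples are easy (g near 0), a few are hard or outliers (g near 1)
g = torch.cat([torch.rand(90) * 0.1, torch.rand(10) * 0.5 + 0.5])
print(gradient_density_weights(g)[:5])

Note that the GHMC implementation below divides by the number of non-empty bins rather than multiplying by the interval length ε = 1/bins; the two coincide when every bin is non-empty.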
GHM loss
With these weights, the GHM-C classification loss is:

L_GHM-C = (1/N) · Σ_{i=1}^{N} β_i · L_CE(p_i, y_i) = Σ_{i=1}^{N} L_CE(p_i, y_i) / GD(g_i)
Code
def _expand_binary_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds]] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights


class GHMC(nn.Module):
    """GHM Classification Loss.
    Ref: https://github.com/libuyu/mmdetection/blob/master/mmdet/models/losses/ghm_loss.py
    Details of the theorem can be viewed in the paper
    "Gradient Harmonized Single-stage Detector".
    https://arxiv.org/abs/1811.05181
    Args:
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average.
        use_sigmoid (bool): Can only be true for BCE based loss now.
        loss_weight (float): The weight of the total GHM-C loss.
    """
    def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0, alpha=None):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        self.edges[-1] += 1e-6
        if momentum > 0:
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.use_sigmoid = use_sigmoid
        if not self.use_sigmoid:
            raise NotImplementedError
        self.loss_weight = loss_weight
        self.label_weight = alpha

    def forward(self, pred, target, label_weight=None, *args, **kwargs):
        """Calculate the GHM-C loss.
        Args:
            pred (float tensor of size [batch_num, class_num]):
                The direct prediction of classification fc layer.
            target (float tensor of size [batch_num, class_num]):
                Binary class target for each sample.
            label_weight (float tensor of size [batch_num, class_num]):
                the value is 1 if the sample is valid and 0 if ignored.
        Returns:
            The gradient harmonized loss.
        """
        # the target should be a binary class label
        # if pred.dim() != target.dim():
        #     target, label_weight = _expand_binary_labels(
        #         target, label_weight, pred.size(-1))
        # here pred is [B, C] and target is [B], so one-hot encode the target
        target = torch.zeros(target.size(0), 2).to(target.device).scatter_(1, target.view(-1, 1), 1)
        # if no label_weight is given, treat every sample as valid
        if label_weight is None:
            label_weight = torch.ones([pred.size(0), pred.size(-1)]).to(target.device)
        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # gradient length: |sigmoid(pred) - target|
        g = torch.abs(pred.sigmoid().detach() - target)
        # positions of valid labels
        valid = label_weight > 0
        # number of valid labels
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n valid bins
        for i in range(self.bins):
            # assign each gradient norm (in [0, 1]) to its bin
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            # number of samples in this bin
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    # exponential moving average of the bin count
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    # weight = total count / (smoothed) bin count
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            # scale by the number of non-empty bins
            weights = weights / n
        loss = F.binary_cross_entropy_with_logits(
            pred, target, weights, reduction='sum') / tot
        return loss * self.loss_weight
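Finally, a minimal usage sketch for GHMC (made-up logits and labels, binary case with 2 output logits; assumes the imports at the top of the code section):

criterion = GHMC(bins=10, momentum=0.75)
pred = torch.randn(8, 2)              # raw logits, shape [B, C]
target = torch.randint(0, 2, (8,))    # class indices, shape [B]
loss = criterion(pred, target)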