一灰瞻、二分類focal loss
1、一句話概括:
focal loss婚被,這個損失函數(shù)是在標準交叉熵損失基礎(chǔ)上修改得到的。這個函數(shù)可以通過減少易分類樣本的權(quán)重梳虽,使得模型在訓(xùn)練時更專注于難分類的樣本址芯。
2、出發(fā)點:
希望one-stage detector可以達到two-stage detector的準確率窜觉,同時不影響原有的速度
在object detection領(lǐng)域谷炸,一張圖像可能生成成千上萬的candidate locations,但是其中只有很少一部分是包含object的禀挫,這就帶來了類別不均衡旬陡。
·類別不均衡帶來的問題:
負樣本數(shù)量太大,占總的loss的大部分语婴,而且多是容易分類的描孟,因此使得模型的優(yōu)化方向并不是我們所希望的那樣。
3砰左、目的:
通過減少易分類樣本的權(quán)重匿醒,從而使得模型在訓(xùn)練時更專注于難分類的樣本
4、推導(dǎo)過程:
(1)缠导、原始二分類交叉熵:
(2)廉羔、平衡交叉熵
CE在某種程度上不能處理正/負例子的重要性,這里引入了一個權(quán)重因子“α”僻造,其范圍為[0,1]憋他,正類為α,負類為“1 -α”嫡意,這兩個定義合并在一個名為“α”的名稱下举瑰,可以定義為
這個損失函數(shù)稍微解決了類不平衡的問題捣辆,但是仍然無法區(qū)分簡單和困難的例子蔬螟。為了解決這個問題,我們定義了焦損失汽畴。
(3)旧巾、focal loss
(1-pt)^γ為調(diào)變因子,這里γ≥0忍些,稱為聚焦參數(shù)鲁猩。
從上述定義中可以提取出Focal Loss的兩個性質(zhì):
1、當(dāng)樣本分類錯誤時罢坝,pt趨于0廓握,調(diào)變因子趨于1,使得損失函數(shù)幾乎不受影響。另一方面隙券,如果示例被正確分類男应,pt將趨于1,調(diào)變因子將趨向于0娱仔,使得損耗非常接近于0沐飘,從而降低了該特定示例的權(quán)重。
2牲迫、聚焦參數(shù)(γ)平滑地調(diào)整易于分類的示例向下加權(quán)的速率耐朴。
FL(Focal Loss)和CE(交叉熵損失)的比較
當(dāng)γ=2時,與概率為0.9的示例相比盹憎,概率為0.9的示例的損失比CE和0.968低100倍筛峭,損失將降低1000倍。
下面的描述了不同γ值下的FL脚乡。當(dāng)γ=0時蜒滩,F(xiàn)L等于CE損耗。這里我們可以看到奶稠,對于γ=0(CE損失)俯艰,即使是容易分類的例子也會產(chǎn)生非平凡的損失震級。這些求和的損失可以壓倒稀有類(很難分類的類)锌订。
二竹握、多分類focal loss
#來源:https://github.com/HeyLynne/FocalLoss_for_multiclass
class FocalLoss(nn.Module):
def __init__(self, gamma = 2, alpha = 1, size_average = True):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
self.size_average = size_average
self.elipson = 0.000001
def forward(self, logits, labels):
"""
cal culates loss
logits: batch_size * labels_length * seq_length
labels: batch_size * seq_length
"""
if labels.dim() > 2:
labels = labels.contiguous().view(labels.size(0), labels.size(1), -1)
labels = labels.transpose(1, 2)
labels = labels.contiguous().view(-1, labels.size(2)).squeeze()
if logits.dim() > 3:
logits = logits.contiguous().view(logits.size(0), logits.size(1), logits.size(2), -1)
logits = logits.transpose(2, 3)
logits = logits.contiguous().view(-1, logits.size(1), logits.size(3)).squeeze()
assert(logits.size(0) == labels.size(0))
assert(logits.size(2) == labels.size(1))
batch_size = logits.size(0)
labels_length = logits.size(1)
seq_length = logits.size(2)
# transpose labels into labels onehot
new_label = labels.unsqueeze(1)
label_onehot = torch.zeros([batch_size, labels_length, seq_length]).scatter_(1, new_label, 1)
# calculate log
log_p = F.log_softmax(logits)
pt = label_onehot * log_p
sub_pt = 1 - pt
fl = -self.alpha * (sub_pt)**self.gamma * log_p
if self.size_average:
return fl.mean()
else:
return fl.sum()