一扮匠、Focal loss損失函數(shù)
Focal Loss的引入主要是為了解決**難易樣本數(shù)量不平衡****(注意,有區(qū)別于正負(fù)樣本數(shù)量不平衡)的問(wèn)題,實(shí)際可以使用的范圍非常廣泛。
本文的作者認(rèn)為,易分樣本(即贫贝,置信度高的樣本)對(duì)模型的提升效果非常小,模型應(yīng)該主要關(guān)注與那些難分樣本蛉谜。一個(gè)簡(jiǎn)單的思想:把高置信度(p)樣本的損失再降低一些不就好了嗎稚晚!
focal loss函數(shù)公式:
其中,(1)為類(lèi)別權(quán)重型诚,用來(lái)權(quán)衡正負(fù)樣本不均衡問(wèn)題客燕,倘若負(fù)樣本越多,給負(fù)樣本的
權(quán)重就越小狰贯,這樣就可以降低負(fù)樣本的影響也搓。加一個(gè)小于1的超參數(shù),相當(dāng)于把Loss曲線(xiàn)整體往下拉一些涵紊,使得當(dāng)樣本概率較大的時(shí)候影響減小傍妒。;
(2) 表示難分樣本權(quán)重摸柄,用來(lái)衡量難分樣本和易分樣本颤练,對(duì)于正類(lèi)樣本而言,預(yù)測(cè)結(jié)果為0.95肯定是簡(jiǎn)單樣本塘幅,所以(1-0.95)的gamma次方就會(huì)很小昔案,這時(shí)損失函數(shù)值就變得更小。而預(yù)測(cè)概率為0.3的樣本其損失相對(duì)很大电媳。即正樣本:概率越小踏揣,表示hard example,損失越大匾乓; 負(fù)樣本:概率越大捞稿,表示hard example,損失越大拼缝。γ 起到了平滑的作用娱局,作者的實(shí)驗(yàn)中,論文采用α=0.25咧七,γ=2效果最好衰齐。。針對(duì)hard example继阻,Pt比較小耻涛,則權(quán)重比較大,讓網(wǎng)絡(luò)傾向于利用這樣的樣本來(lái)進(jìn)行參數(shù)的更新
Focal loss缺點(diǎn)(騰訊面試):
(1) 對(duì)異常樣本敏感: 假如訓(xùn)練集中有個(gè)樣本label標(biāo)錯(cuò)了瘟檩,那么focal loss會(huì)一直放大這個(gè)樣本的loss(模型想矯正回來(lái))抹缕,導(dǎo)致網(wǎng)絡(luò)往錯(cuò)誤方向?qū)W習(xí)。
(2)對(duì)分類(lèi)邊界異常點(diǎn)處理不理想:由于邊界樣本表示相似性較高墨辛,對(duì)于不同異常值表示卓研,每次損失更新時(shí),都會(huì)有反復(fù)在分類(lèi)決策面(0.5)上反復(fù)橫跳的點(diǎn)睹簇,導(dǎo)致模型收斂速度下降奏赘,退化成交叉熵?fù)p失。
二带膀、Focal loss損失函數(shù)代碼
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
def __init__(self, class_num, alpha=0.20, gamma=1.5, use_alpha=False, size_average=True):
super(FocalLoss, self).__init__()
self.class_num = class_num
self.alpha = alpha
self.gamma = gamma
if use_alpha:
self.alpha = torch.tensor(alpha).cuda()
# self.alpha = torch.tensor(alpha)
self.softmax = nn.Softmax(dim=1)
self.use_alpha = use_alpha
self.size_average = size_average
def forward(self, pred, target):
prob = self.softmax(pred.view(-1,self.class_num))
prob = prob.clamp(min=0.0001,max=1.0)
target_ = torch.zeros(target.size(0),self.class_num).cuda()
# target_ = torch.zeros(target.size(0),self.class_num)
target_.scatter_(1, target.view(-1, 1).long(), 1.)
if self.use_alpha:
batch_loss = - self.alpha.double() * torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
else:
batch_loss = - torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
batch_loss = batch_loss.sum(dim=1)
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
三志珍、Focal loss損失函數(shù)引用及使用
# 函數(shù)引用(focal_loss為模型文件名)
from focal_loss import FocalLoss
#...
# 損失函數(shù)初始化
criterion = FocalLoss(class_num=3)
#...
# 獲得損失函數(shù)
loss = criterion(outputs, targets)