樣本不均衡-Focal loss典奉,GHM

Ref:

  1. https://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf
  2. https://zhuanlan.zhihu.com/p/80594704
  3. https://arxiv.org/pdf/1811.05181.pdf

背景

工作中處理二分類問題变勇,數(shù)據(jù)大多是長(zhǎng)尾分布龙优,即正樣本遠(yuǎn)小于負(fù)樣本。一般來說秦陋,通過調(diào)整閾值(置信度)惠呼,就可以滿足上線需求导俘。但總是有一些正樣本,得分較低剔蹋,希望找到一些辦法旅薄,提高這些得分很低的正例分?jǐn)?shù),且負(fù)樣本得分不被拉高太多泣崩。

模型通過梯度更新進(jìn)行訓(xùn)練少梁,實(shí)際應(yīng)用中,大部分的樣本是容易區(qū)分的矫付,而這些樣本貢獻(xiàn)了主要的loss凯沪,模型偏向于這些樣本,在部分難區(qū)分的樣本上效果不好买优。

所以妨马,為提高模型效果,要解決兩個(gè)問題:

  1. 如何處理樣本不均衡問題杀赢?
  2. 如何有效處理{正難烘跺,負(fù)難}的樣本?

Focal Loss

主要應(yīng)用在目標(biāo)檢測(cè)葵陵,實(shí)際應(yīng)用范圍很廣液荸。
分類問題中瞻佛,常見的loss是cross-entropy:
L_{CE} = \begin{cases} -log(p), & y = 1 \\ -log(1 - p), & y = otherwise \end{cases}

為了解決正負(fù)樣本不均衡,乘以權(quán)重\alpha
L_{FL} = \begin{cases}-\alpha log(p), & y = 1 \\ -(1-\alpha)log(1 - p), & y = 0 \end{cases}

一般根據(jù)各類別數(shù)據(jù)占比脱篙,對(duì)\alpha進(jìn)行取值娇钱,即當(dāng)class_1占比為30%時(shí),\alpha = 0.3绊困。

我們希望模型能更關(guān)注容易錯(cuò)分的數(shù)據(jù)文搂,反向思考,就是讓模型別那么關(guān)注容易分類的樣本秤朗。因此煤蹭,F(xiàn)ocal Loss的思路就是,把高置信度的樣本損失降低取视。
L_{FL} = \begin{cases} -\alpha(1-p)^{\gamma} log(p), & y = 1 \\ -(1-\alpha)p^{\gamma} log(1 - p), & y = 0\\ \end{cases}

多分類樣本:
L_{FL} = -\alpha(1-p)^{\gamma}log(p)

\gamma不同取值情況如下圖:

from paper

模型是如何通過(1-p)^{\gamma}控制損失的衰減的呢硝皂?

當(dāng)樣本被誤分類時(shí),p很小作谭,(1-p)^{\gamma}很大稽物,loss不怎么受影響。當(dāng)樣本被正確分類折欠,p很大贝或,(1-p)^{\gamma}變小,loss衰減锐秦。
比如:當(dāng)\alpha = 1咪奖,\gamma=2,p為0.9時(shí)酱床,L_{FL} = -(1-0.9)^2 * log(0.9) = 0.01*L_{CE}羊赵,這個(gè)容易分類的樣本,損失和cross-entropy相比斤葱,衰減了100倍慷垮。

代碼

# 二分類
class BCEFocalLoss(torch.nn.Module):
    """
    https://github.com/louis-she/focal-loss.pytorch/blob/master/focal_loss.py
    二分類的Focalloss alpha 固定
    """
    def __init__(self, gamma=2, alpha=0.25, reduction='sum'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
 
    def forward(self, preds, targets):
        "preds:[B,C],targets:[B]"
        pt = torch.sigmoid(preds)
        pt = pt.clamp(min=0.0001,max = 1.0) # 概率過低,logpt后揍堕,loss返回nan
        # 我在gpu上使用時(shí)料身,不加.to(targets.device),報(bào)錯(cuò)
        targets = torch.zeros(targets.size(0),2).to(targets.device).scatter_(1,targets.view(-1,1),1) 
        loss = - self.alpha * (1 - pt) ** self.gamma * targets * torch.log(pt) - \
               (1 - self.alpha) * pt ** self.gamma * (1 - targets) * torch.log(1 - pt)
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss

# 多分類
class FocalLoss(nn.Module):
    """ 
        Ref: https://github.com/yatengLG/Focal-Loss-Pytorch/blob/master/Focal_Loss.py
        FL(pt) = -alpha_t(1-pt)^gamma log(pt)
        alpha: 類別權(quán)重,常數(shù)時(shí)衩茸,類別權(quán)重為:[alpha,1-alpha,1-alpha,...]芹血;列表時(shí),表示對(duì)應(yīng)類別權(quán)重
        gamma: 難易分類的樣本權(quán)重楞慈,使得模型更關(guān)注難分類的樣本
        優(yōu)點(diǎn):幫助區(qū)分難分類的不均衡樣本數(shù)據(jù)
    """
    def __init__(self, num_classes, alpha=0.25,gamma=2,reduce=True):

        super(FocalLoss,self).__init__()

        self.num_classes = num_classes
        self.gamma = gamma
        self.reduce = reduce 

        if alpha is None:
            self.alpha = torch.ones(self.num_classes,1)
        else:
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] = alpha 
            self.alpha[1:] += (1-alpha)
    
    def forward(self,preds,targets):
        "preds:[B,C],targets:[B]"
        preds = preds.view(-1,preds.size(-1)) #[B,C]
        self.alpha = self.alpha.to(preds.device)
        logpt = F.log_softmax(preds,dim=1) 
        pt = F.softmax(preds).clamp(min=0.0001,max=1.0) 

        logpt = logpt.gather(1,targets.view(-1,1)) # 對(duì)應(yīng)類別值
        pt = pt.gather(1,targets.view(-1,1)) 
        self.alpha = self.alpha.gather(0,targets.view(-1))

        loss = -(1-pt) **self.gamma *logpt
        loss = self.alpha*loss.t()

        if self.reduce:
            return loss.mean()
        else:
            return loss.sum()

GHM - gradient harmonizing mechanism

Focal Loss對(duì)容易分類的樣本進(jìn)行了損失衰減幔烛,讓模型更關(guān)注難分樣本,并通過\alpha\gamma進(jìn)行調(diào)參囊蓝。

GHM提到:

  1. 有一部分難分樣本就是離群點(diǎn)饿悬,不應(yīng)該給他太多關(guān)注;
  2. 梯度密度可以直接統(tǒng)計(jì)得到聚霜,不需要調(diào)參狡恬。

GHM認(rèn)為珠叔,類別不均衡可總結(jié)為難易分類樣本的不均衡,而這種難分樣本的不均衡又可視為梯度密度分布的不均衡弟劲。假設(shè)一個(gè)正樣本被正確分類祷安,它就是正易樣本,損失不大兔乞,模型不能從中獲益汇鞭。而一個(gè)錯(cuò)誤分類的樣本,更能促進(jìn)模型迭代庸追。實(shí)際應(yīng)用中霍骄,大量的樣本都是屬于容易分類的類型,這種樣本一個(gè)起不了太大作用淡溯,但量級(jí)過大腕巡,在模型進(jìn)行梯度更新時(shí),起主要作用血筑,使得模型朝這類數(shù)據(jù)更新绘沉。

from paper
  • 圖示左,樣本梯度分布豺总。
    梯度模長(zhǎng)(gradient norm)在很小和很大時(shí)车伞,密度較大。前者喻喳,表示了大量容易分類的樣本另玖,所以梯度很低。而后者表伦,文中認(rèn)為是離群點(diǎn)谦去,即便模型收斂,損失仍然很大蹦哼。
  • 圖示中渠退,經(jīng)過修正后的梯度分布沦零。
    和CE,FL相比闻坚,GHM-C根據(jù)梯度密度棚辽,大量容易分類的樣本和離群點(diǎn)的累計(jì)梯度被降級(jí),達(dá)到樣本均衡局劲,使得模型更加有效穩(wěn)定勺拣。
  • 圖示右,樣本集梯度貢獻(xiàn)鱼填。
    經(jīng)過GHM-C的梯度密度調(diào)整药有,各種難易分類的樣本分布更加平滑。

簡(jiǎn)而言之:Focal Loss是從置信度p來調(diào)整loss苹丸,GHM通過一定范圍置信度p的樣本數(shù)來調(diào)整loss愤惰。

梯度模長(zhǎng)

梯度模長(zhǎng):原文中用p^*表示真實(shí)標(biāo)簽竹祷,這里統(tǒng)一符號(hào),用y表示:
g = |p-y|= \begin{cases} 1-p, & y = 1 \\ p, & y = 0\\ \end{cases}

推理:
p = sigmoid(x)
\frac { \partial p}{ \partial x} = p(1-p)
\frac { \partial L_{CE}}{ \partial p} = \begin{cases} -\frac {\partial logp}{\partial p}= -\frac{1}{p} , & y = 1 \\ -\frac {\partial log(1-p)}{\partial p}= \frac{1}{1 - p} , &y = 0 \end{cases}
則:
\frac {\partial L_{CE}}{\partial x} = \frac {\partial L_{CE}}{\partial p} \frac {\partial p}{\partial x} = \begin{cases} p-1 , & y = 1 \\ p, & y = 0 \end{cases} = p-y

g = |p-y| = |\frac {\partial L_{CE}}{\partial x} |

梯度密度(Gradient Density)

梯度模長(zhǎng)分布不均羊苟,引入梯度密度:
GD(g)=\frac{1}{l_{ \epsilon} (g)} \sum_k^N \delta_{ \epsilon}(g_k,g)

在N個(gè)樣本中,梯度模長(zhǎng)分布在(g-\epsilon/2,g+\epsilon/2)范圍的個(gè)數(shù):
\delta_{ \epsilon}(x,y) = \begin{cases} 1, if&y-\frac{\epsilon} {2} \leq x <y + \frac{\epsilon} {2}\\ 0, &otherwise \end{cases}
區(qū)間長(zhǎng)度: l_{ \epsilon} (g) = min(g+\epsilon/2,1) - max(g-\epsilon/2,0)
梯度密度協(xié)調(diào)參數(shù):\beta_i = \frac {N}{GD(g_i)} = \frac {1}{GD(g_i)/N}
上式分母感憾,可視為對(duì)g_i附近樣本進(jìn)行歸一化蜡励。如果梯度分布均勻,則\beta_i = 1阻桅,如果密度過高凉倚,則意味著要降級(jí)處理。

GHM loss計(jì)算

L_{GHM-C} = \frac{1}{N}\sum_i^N \beta_i{L_{CE}(p_i,y_i)} = \sum_i^N \frac{L_{CE}(p_i,y_i)}{GD(g_i)}

代碼

def _expand_binary_labels(labels,label_weights,label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels),0)
    inds = torch.nonzero(labels>=1).squeeze()
    if inds.numel() >0:
        bin_labels[inds,labels[inds]] = 1
    bin_label_weights = label_weights.view(-1,1).expand(label_weights.size(0),label_channels)
    return bin_labels, bin_label_weights
class GHMC(nn.Module):
    """GHM Classification Loss.
    Ref:https://github.com/libuyu/mmdetection/blob/master/mmdet/models/losses/ghm_loss.py
    Details of the theorem can be viewed in the paper
    "Gradient Harmonized Single-stage Detector".
    https://arxiv.org/abs/1811.05181

    Args:
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average.
        use_sigmoid (bool): Can only be true for BCE based loss now.
        loss_weight (float): The weight of the total GHM-C loss.
    """

    def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0,alpha=None):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        self.edges[-1] += 1e-6
        if momentum > 0:
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.use_sigmoid = use_sigmoid
        if not self.use_sigmoid:
            raise NotImplementedError
        self.loss_weight = loss_weight

        self.label_weight = alpha

    def forward(self, pred, target, label_weight =None, *args, **kwargs):
        """Calculate the GHM-C loss.
          
        Args:
            pred (float tensor of size [batch_num, class_num]):
                The direct prediction of classification fc layer.
            target (float tensor of size [batch_num, class_num]):
                Binary class target for each sample.
            label_weight (float tensor of size [batch_num, class_num]):
                the value is 1 if the sample is valid and 0 if ignored.
        Returns:
            The gradient harmonized loss.
        """
        # the target should be binary class label

        # if pred.dim() != target.dim():
        #     target, label_weight = _expand_binary_labels(
        #     target, label_weight, pred.size(-1))

        # 我的pred輸入為[B,C]嫂沉,target輸入為[B]
        target = torch.zeros(target.size(0),2).to(target.device).scatter_(1,target.view(-1,1),1)
        
        # 暫時(shí)不清楚這個(gè)label_weight輸入形式稽寒,默認(rèn)都為1
        if label_weight is None:
            label_weight = torch.ones([pred.size(0),pred.size(-1)]).to(target.device)

        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # gradient length
        # sigmoid梯度計(jì)算
        g = torch.abs(pred.sigmoid().detach() - target)
        # 有效的label的位置
        valid = label_weight > 0
        # 有效的label的數(shù)量
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n valid bins
        for i in range(self.bins):
            # 將對(duì)應(yīng)的梯度值劃分到對(duì)應(yīng)的bin中, 0-1
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            # 該bin中存在多少個(gè)樣本
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    # moment計(jì)算num bin
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    # 權(quán)重等于總數(shù)/num bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            # scale系數(shù)
            weights = weights / n

        loss = F.binary_cross_entropy_with_logits(
            pred, target, weights, reduction='sum') / tot
        return loss * self.loss_weight
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末趟章,一起剝皮案震驚了整個(gè)濱河市杏糙,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌蚓土,老刑警劉巖宏侍,帶你破解...
    沈念sama閱讀 207,248評(píng)論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異蜀漆,居然都是意外死亡谅河,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,681評(píng)論 2 381
  • 文/潘曉璐 我一進(jìn)店門确丢,熙熙樓的掌柜王于貴愁眉苦臉地迎上來绷耍,“玉大人,你說我怎么就攤上這事鲜侥」邮迹” “怎么了?”我有些...
    開封第一講書人閱讀 153,443評(píng)論 0 344
  • 文/不壞的土叔 我叫張陵描函,是天一觀的道長(zhǎng)病袄。 經(jīng)常有香客問我,道長(zhǎng)赘阀,這世上最難降的妖魔是什么益缠? 我笑而不...
    開封第一講書人閱讀 55,475評(píng)論 1 279
  • 正文 為了忘掉前任,我火速辦了婚禮基公,結(jié)果婚禮上幅慌,老公的妹妹穿的比我還像新娘。我一直安慰自己轰豆,他們只是感情好胰伍,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,458評(píng)論 5 374
  • 文/花漫 我一把揭開白布齿诞。 她就那樣靜靜地躺著,像睡著了一般骂租。 火紅的嫁衣襯著肌膚如雪祷杈。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,185評(píng)論 1 284
  • 那天渗饮,我揣著相機(jī)與錄音但汞,去河邊找鬼。 笑死互站,一個(gè)胖子當(dāng)著我的面吹牛私蕾,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播胡桃,決...
    沈念sama閱讀 38,451評(píng)論 3 401
  • 文/蒼蘭香墨 我猛地睜開眼踩叭,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了翠胰?” 一聲冷哼從身側(cè)響起容贝,我...
    開封第一講書人閱讀 37,112評(píng)論 0 261
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎之景,沒想到半個(gè)月后嗤疯,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 43,609評(píng)論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡闺兢,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,083評(píng)論 2 325
  • 正文 我和宋清朗相戀三年茂缚,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片屋谭。...
    茶點(diǎn)故事閱讀 38,163評(píng)論 1 334
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡脚囊,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出桐磁,到底是詐尸還是另有隱情悔耘,我是刑警寧澤,帶...
    沈念sama閱讀 33,803評(píng)論 4 323
  • 正文 年R本政府宣布我擂,位于F島的核電站衬以,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏校摩。R本人自食惡果不足惜看峻,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,357評(píng)論 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望衙吩。 院中可真熱鬧互妓,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,357評(píng)論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至灼狰,卻和暖如春宛瞄,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背交胚。 一陣腳步聲響...
    開封第一講書人閱讀 31,590評(píng)論 1 261
  • 我被黑心中介騙來泰國(guó)打工份汗, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人承绸。 一個(gè)月前我還...
    沈念sama閱讀 45,636評(píng)論 2 355
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像挣轨,于是被迫代替她去往敵國(guó)和親军熏。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,925評(píng)論 2 344

推薦閱讀更多精彩內(nèi)容