1 關(guān)于Focal Loss
Focal Loss 是一個(gè)在交叉熵(CE)基礎(chǔ)上改進(jìn)的損失函數(shù),來自ICCV2017的Best student paper——Focal Loss for Dense Object Detection匣椰。論文下載鏈接為:https://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf裆熙。Focal Loss的提出源自圖像領(lǐng)域中目標(biāo)檢測任務(wù)中樣本數(shù)量不平衡性的問題,并且這里所謂的不平衡性跟平常理解的是有所區(qū)別的禽笑,它還強(qiáng)調(diào)了樣本的難易性入录。盡管Focal Loss 始于目標(biāo)檢測場景,其實(shí)它可以應(yīng)用到很多其他任務(wù)場景佳镜,只要符合它的問題背景僚稿,就可以試試,會(huì)有意想不到的效果蟀伸。
2 Focal Loss 原理
在引入Focal Loss公式前蚀同,我們以源paper中目標(biāo)檢測的任務(wù)來說:目標(biāo)檢測器通常會(huì)產(chǎn)生高達(dá)100k的候選目標(biāo),只有極少數(shù)是正樣本啊掏,正負(fù)樣本數(shù)量非常不平衡蠢络。
在計(jì)算分類的時(shí)候常用的損失——交叉熵(CE)的公式如下:
其中取值{1,-1}代表正負(fù)樣本迟蜜,為模型預(yù)測的label概率刹孔,通常>0.5就判斷為正樣本,否則為負(fù)樣本娜睛。論文中為了方便展示髓霞,重新定義了:
這樣CE函數(shù)就可以表達(dá)為:.
在CE基礎(chǔ)上,為了解決正負(fù)樣本不平衡性微姊,有人提出一種帶權(quán)重的CE函數(shù):
其中當(dāng):酸茴。 參數(shù) 為控制正負(fù)樣本的權(quán)重分预,取值范圍[0,1]兢交。 盡管這是一種很簡單的解決正負(fù)樣本不平衡的方案,但它還沒真正達(dá)到paper中作者想解決的問題:因?yàn)檎?fù)樣本中也有難易之分笼痹,認(rèn)為模型應(yīng)該更聚焦在難樣本的學(xué)習(xí)上配喳。如下圖酪穿,按正負(fù),難易可將樣本分為四個(gè)維度晴裹,其實(shí)上面帶權(quán)重的CE函數(shù)被济,只是解決了正負(fù)問題,并沒有解決難易問題涧团。
在這里可能有人疑問只磷,怎么來衡量一個(gè)樣本的難易程度,更何況真實(shí)數(shù)據(jù)也沒有這個(gè)標(biāo)記泌绣。其實(shí)钮追,這里的樣本難易是用模型來判斷的,就正樣本集合來說阿迈,如果一個(gè)樣本預(yù)測的元媚,一個(gè)樣本預(yù)測的,明顯前一個(gè)樣本更容易學(xué)習(xí)苗沧,或者說特征更明顯刊棕,是易樣本。這樣也就是說待逞,預(yù)測的概率越接近1或0的樣本甥角,就越是容易學(xué)習(xí)的樣本,相反识樱,越是集中0.5左右的樣本蜈膨,就是難樣本。在sigomid函數(shù)上牺荠,可以按下圖的方式展示樣本的難易之分翁巍。
既然問題已梳理清楚,怎么讓模型對(duì)難易樣本也有區(qū)分性的學(xué)習(xí)休雌,也是說聚焦程度不同灶壶。模型應(yīng)該花更多精力在難樣本的學(xué)習(xí)上,而減少精力在易樣本的學(xué)習(xí)杈曲,之前的CE函數(shù)驰凛,以及帶權(quán)重的CE函數(shù),都是將難樣本担扑、易樣本等同看待的恰响。這樣就引出Focal Loss 的表達(dá)形式:
其中為調(diào)節(jié)因子,取值為[0,5]涌献,當(dāng)胚宦,就等同于CE函數(shù);值越大,表示模型在難易樣本上聚焦的更厲害枢劝。下圖是不同參數(shù)下表現(xiàn)形式介评。
結(jié)合上圖與公式果港,可以看出最岗,當(dāng)趨近1時(shí)磷蜀,權(quán)重趨近0,對(duì)總損失貢獻(xiàn)幾乎沒有影響鹤盒,意味模型較少對(duì)這類樣本的學(xué)習(xí)蚕脏;比如, 在正樣本集合中侦锯,蝗锥,當(dāng)一樣本, 當(dāng)一樣本率触,二者相對(duì)來說终议,前者是難樣本,后者是易樣本葱蝗,反映在Focal Losss上穴张,前者的對(duì)總損失貢獻(xiàn)權(quán)重為0.16,后者0.09两曼,明顯難樣本貢獻(xiàn)權(quán)重更大皂甘,模型也就會(huì)更聚焦對(duì)其學(xué)習(xí)。同理悼凑,負(fù)樣本中一樣偿枕。
但是上面的Focal Loss公式只是體現(xiàn)了難易樣本的區(qū)分,沒有區(qū)分正負(fù)户辫。這樣就引出了完整版的Focal Loss表達(dá)形式:
這樣Focal Loss既能調(diào)整正負(fù)樣本的權(quán)重渐夸,又能控制難易分類樣本的權(quán)重。paper中通過實(shí)驗(yàn)驗(yàn)證渔欢,默認(rèn)墓塌,。在這里取值上可能會(huì)有疑問奥额,理論上正樣本權(quán)重更大些苫幢,取0.75,而paper實(shí)驗(yàn)結(jié)果給的是0.25垫挨。這里結(jié)合其他人的解釋韩肝,說下我的理解:主要原因是,而大部分負(fù)樣本的九榔,導(dǎo)致負(fù)樣本的貢獻(xiàn)權(quán)重還小于正樣本貢獻(xiàn)的權(quán)重哀峻,本意是想調(diào)高正樣本的貢獻(xiàn)權(quán)重涡相,但這樣就有點(diǎn)調(diào)的過大了,所以就有點(diǎn)反過來提高下負(fù)樣本的權(quán)重谜诫。所以在最終版中,不能理解就是完全來調(diào)節(jié)正負(fù)樣本的權(quán)重的攻旦,而是要結(jié)合一起來看喻旷。
3 Focal Loss 實(shí)踐
基于上面的介紹,我們對(duì)Focal Loss進(jìn)行一下實(shí)驗(yàn)驗(yàn)證牢屋。這里選擇MNIST數(shù)據(jù)集進(jìn)行實(shí)驗(yàn):只識(shí)別數(shù)字3且预,這樣將數(shù)據(jù)集的label轉(zhuǎn)變?yōu)閇0,1],1代表是數(shù)字3烙无,0為其他數(shù)字锋谐,這樣就構(gòu)建一個(gè)不平衡的樣本數(shù)據(jù)集。模型最后一層選擇sigmod作為激活函數(shù)進(jìn)行回歸預(yù)測截酷,然后選擇CE與FL兩種損失函數(shù)涮拗,看看訓(xùn)練情況如何。下面為對(duì)應(yīng)的代碼迂苛。
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import tensorflow.keras.backend as K
# load dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') /255
x_test = x_test.reshape(10000, 784).astype('float32') /255
y_train=np.array([1 if d==2 else 0 for d in y_train])
y_test=np.array([1 if d==2 else 0 for d in y_test])
#定義focal loss
def focal_loss(gamma=2., alpha=.25):
def focal_loss_fixed(y_true, y_pred):
pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(K.epsilon()+pt_1))-K.sum((1-alpha) * K.pow( pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
return focal_loss_fixed
#build model
inputs = keras.Input(shape=(784,), name='mnist_input')
h1 = layers.Dense(64, activation='relu')(inputs)
outputs = layers.Dense(1, activation='sigmoid')(h1)
model = tf.keras.Model(inputs, outputs)
#以平方差損失函數(shù)來編譯模型進(jìn)行訓(xùn)練
model.compile(optimizer=keras.optimizers.RMSprop(),
loss=keras.losses.BinaryCrossentropy(),
metrics=['accuracy'])
#以Focal Loss損失函數(shù)來編譯模型進(jìn)行訓(xùn)練
model.compile(optimizer=keras.optimizers.RMSprop(),
loss=[focal_loss(alpha=.25, gamma=2)],
metrics=['accuracy'])
#training
history = model.fit(x_train, y_train, batch_size=64, epochs=5,
validation_data=(x_test, y_test))
訓(xùn)練結(jié)果如下:
從結(jié)果可以看出三热,雖然在該數(shù)據(jù)集上二者提升效果并不大,但Focal Loss在每輪上都優(yōu)于CE的訓(xùn)練效果三幻,所以還是能體現(xiàn)Focal Loss的優(yōu)勢就漾,如果在其他更不平衡的數(shù)據(jù)集上,應(yīng)該效果更好念搬。不管在CV抑堡,還是NLP領(lǐng)域,該損失函數(shù)值得大家去嘗試朗徊。在AAAI2019會(huì)議上提出一種基于Focal loss的改進(jìn)版GHM(Gradient Harmonized Single-stage Detector)首妖,有興趣的也可以去讀讀。
更多文章可關(guān)注筆者公眾號(hào):自然語言處理算法與實(shí)踐