定義
標(biāo)簽平滑(Label smoothing),像L1扇调、L2和dropout一樣瘩例,是機(jī)器學(xué)習(xí)領(lǐng)域的一種正則化方法,通常用于分類(lèi)問(wèn)題赦拘,目的是防止模型在訓(xùn)練時(shí)過(guò)于自信地預(yù)測(cè)標(biāo)簽慌随,改善泛化能力差的問(wèn)題。
背景
對(duì)于分類(lèi)問(wèn)題,我們通常認(rèn)為訓(xùn)練數(shù)據(jù)中標(biāo)簽向量的目標(biāo)類(lèi)別概率應(yīng)為1阁猜,非目標(biāo)類(lèi)別概率應(yīng)為0丸逸。傳統(tǒng)的one-hot編碼的標(biāo)簽向量yi為,
yi={1,i=target0,i≠target
在訓(xùn)練網(wǎng)絡(luò)時(shí)剃袍,最小化損失函數(shù)H(y,p)=?K∑iyilogpi黄刚,其中pi由對(duì)模型倒數(shù)第二層輸出的logits向量z應(yīng)用Softmax函數(shù)計(jì)算得到,
pi=exp(zi)∑Kjexp(zj)
傳統(tǒng)one-hot編碼標(biāo)簽的網(wǎng)絡(luò)學(xué)習(xí)過(guò)程中民效,鼓勵(lì)模型預(yù)測(cè)為目標(biāo)類(lèi)別的概率趨近1憔维,非目標(biāo)類(lèi)別的概率趨近0,即最終預(yù)測(cè)的logits向量(logits向量經(jīng)過(guò)softmax后輸出的就是預(yù)測(cè)的所有類(lèi)別的概率分布)中目標(biāo)類(lèi)別zi的值會(huì)趨于無(wú)窮大畏邢,使得模型向預(yù)測(cè)正確與錯(cuò)誤標(biāo)簽的logit差值無(wú)限增大的方向?qū)W習(xí)业扒,而過(guò)大的logit差值會(huì)使模型缺乏適應(yīng)性,對(duì)它的預(yù)測(cè)過(guò)于自信舒萎。
在訓(xùn)練數(shù)據(jù)不足以覆蓋所有情況下程储,這就會(huì)導(dǎo)致網(wǎng)絡(luò)過(guò)擬合,泛化能力差臂寝,而且實(shí)際上有些標(biāo)注數(shù)據(jù)不一定準(zhǔn)確章鲤,這時(shí)候使用交叉熵?fù)p失函數(shù)作為目標(biāo)函數(shù)也不一定是最優(yōu)的了。
數(shù)學(xué)定義
label smoothing結(jié)合了均勻分布咆贬,用更新的標(biāo)簽向量^yi來(lái)替換傳統(tǒng)的ont-hot編碼的標(biāo)簽向量yhat
^yi=yhot(1?α)+α/K
其中K為多分類(lèi)的類(lèi)別總個(gè)數(shù)败徊,αα是一個(gè)較小的超參數(shù)(一般取0.1),即
^yi={1?α,i=targetα/K,i≠target
這樣掏缎,標(biāo)簽平滑后的分布就相當(dāng)于往真實(shí)分布中加入了噪聲皱蹦,避免模型對(duì)于正確標(biāo)簽過(guò)于自信,使得預(yù)測(cè)正負(fù)樣本的輸出值差別不那么大御毅,從而避免過(guò)擬合根欧,提高模型的泛化能力。
效果
NIPS 2019上的這篇論文<u style="box-sizing: border-box; list-style: inherit;">When Does Label Smoothing Help?</u>用實(shí)驗(yàn)說(shuō)明了為什么Label smoothing可以work端蛆,指出標(biāo)簽平滑可以讓分類(lèi)之間的cluster更加緊湊凤粗,增加類(lèi)間距離,減少類(lèi)內(nèi)距離今豆,提高泛化性嫌拣,同時(shí)還能提高M(jìn)odel Calibration(模型對(duì)于預(yù)測(cè)值的confidences和accuracies之間aligned的程度)。但是在模型蒸餾中使用Label smoothing會(huì)導(dǎo)致性能下降呆躲。
從標(biāo)簽平滑的定義我們可以看出异逐,它鼓勵(lì)神經(jīng)網(wǎng)絡(luò)選擇正確的類(lèi),并且正確類(lèi)和其余錯(cuò)誤的類(lèi)的差別是一致的插掂。與之不同的是灰瞻,如果我們使用硬目標(biāo)腥例,則會(huì)允許不同的錯(cuò)誤類(lèi)之間有很大不同≡腿螅基于此論文作者提出了一個(gè)結(jié)論:標(biāo)簽平滑鼓勵(lì)倒數(shù)第二層激活函數(shù)之后的結(jié)果靠近正確的類(lèi)的模板燎竖,并且同樣的遠(yuǎn)離錯(cuò)誤類(lèi)的模板。
作者設(shè)計(jì)了一個(gè)可視化的方案來(lái)證明這件事情要销,具體方案為:(1)挑選3個(gè)類(lèi)构回;(2)選取通過(guò)這三個(gè)類(lèi)的模板的標(biāo)準(zhǔn)正交基的平面;(3)將倒數(shù)第二層激活函數(shù)之后的結(jié)果映射到該平面疏咐。作者做了4組實(shí)驗(yàn)纤掸,第一組實(shí)驗(yàn)為在CIFAR-10/AlexNet(數(shù)據(jù)集/模型)上面“飛機(jī)”、“汽車(chē)”和“鳥(niǎo)”三類(lèi)的結(jié)果浑塞,可視化結(jié)果如下所示:
[圖片上傳失敗...(image-aef0e1-1644756672006)]
從中我們可以看出借跪,加了標(biāo)簽平滑之后(后兩張圖),每個(gè)類(lèi)聚的更緊了缩举,而且和其余類(lèi)的距離大致一致垦梆。第二組實(shí)驗(yàn)為在CIFAR-100/ResNet-56(數(shù)據(jù)集/模型)上的實(shí)驗(yàn)結(jié)果匹颤,三個(gè)類(lèi)分別為“河貍”仅孩、“海豚”與“水獺”,我們可以得到類(lèi)似的結(jié)果:
[圖片上傳失敗...(image-d61e63-1644756672006)]
在第三組實(shí)驗(yàn)中印蓖,作者測(cè)試了在ImageNet/Inception-v4(數(shù)據(jù)集/模型)上的表現(xiàn)辽慕,三個(gè)類(lèi)分別為“貓鼬”、“鯉魚(yú)”和“切刀肉”赦肃,結(jié)果如下:
[圖片上傳失敗...(image-c18b24-1644756672006)]
因?yàn)镮mageNet有很多細(xì)粒度的分類(lèi)溅蛉,可以用來(lái)測(cè)試比較相似的類(lèi)之間的關(guān)系。作者在第四組實(shí)驗(yàn)中選擇的三個(gè)類(lèi)分別為“玩具貴賓犬”他宛、“ 迷你貴賓犬”和“鯉魚(yú)”船侧,可以看出前兩個(gè)類(lèi)是很相似的,最后一個(gè)差別比較大的類(lèi)在圖中用藍(lán)色表示厅各,結(jié)果如下:
[圖片上傳失敗...(image-fb65ba-1644756672006)]
可以看出在使用硬目標(biāo)的情況下镜撩,兩個(gè)相似的類(lèi)彼此比較靠近。但是標(biāo)簽平滑強(qiáng)制要求每個(gè)示例與所有剩余類(lèi)的模板之間的距離相等队塘,這就導(dǎo)致了后兩張圖中兩個(gè)類(lèi)距離較遠(yuǎn)袁梗,這在一定程度上造成了信息的損失。
代碼實(shí)現(xiàn)
pytorch部分代碼
class LabelSmoothing(nn.Module):
def __init__(self, size, smoothing=0.0):
super(LabelSmoothing, self).__init__()
self.criterion = nn.KLDivLoss(size_average=False)
#self.padding_idx = padding_idx
self.confidence = 1.0 - smoothing#if i=y的公式
self.smoothing = smoothing
self.size = size
self.true_dist = None
def forward(self, x, target):
"""
x表示輸入 (N憔古,M)N個(gè)樣本遮怜,M表示總類(lèi)數(shù),每一個(gè)類(lèi)的概率log P
target表示label(M鸿市,)
"""
assert x.size(1) == self.size
true_dist = x.data.clone()#先深復(fù)制過(guò)來(lái)
#print true_dist
true_dist.fill_(self.smoothing / (self.size - 1))#otherwise的公式
#print true_dist
#變成one-hot編碼锯梁,1表示按列填充即碗,
#target.data.unsqueeze(1)表示索引,confidence表示填充的數(shù)字
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
self.true_dist = true_dist
return self.criterion(x, Variable(true_dist, requires_grad=False))
loss_function = LabelSmoothing(num_labels, 0.1)
tensorflow代碼實(shí)現(xiàn)
def smoothing_cross_entropy(logits,labels,vocab_size,confidence):
with tf.name_scope("smoothing_cross_entropy", values=[logits, labels]):
# Low confidence is given to all non-true labels, uniformly.
low_confidence = (1.0 - confidence) / to_float(vocab_size - 1)
# Normalizing constant is the best cross-entropy value with soft targets.
# We subtract it just for readability, makes no difference on learning.
normalizing = -(
confidence * tf.log(confidence) + to_float(vocab_size - 1) *
low_confidence * tf.log(low_confidence + 1e-20))
soft_targets = tf.one_hot(
tf.cast(labels, tf.int32),
depth=vocab_size,
on_value=confidence,
off_value=low_confidence)
xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
logits=logits, labels=soft_targets)
return xentropy - normalizing