一 寫在前面
未經(jīng)允許法牲,不得轉(zhuǎn)載矢否,謝謝~~~
今天這篇paper是NeurIPS2019的一篇paper邦马,雖然時間有點久了腥刹,但是看完paper還是有覺得值得借鑒的地方马胧,還是簡單記錄一下??。
- 出處:NeurIPS2019
- title: Meta-Weight-Net: Learning an Explicit Mapping For SampleWeighting
- link: https://arxiv.org/pdf/1902.07379.pdf](https://arxiv.org/pdf/1902.07379.pdf
二 主要內(nèi)容
2.1 backgrounds
deep learning容易對biased data產(chǎn)生過擬合的現(xiàn)象衔峰。
這里作者重點歸納了兩種biased data情況:
- noisy data 標(biāo)簽有噪聲數(shù)據(jù)
- long-tail data 長尾分布數(shù)據(jù)
這種過擬合自然會導(dǎo)致模型的生成泛化能力受到影響佩脊,而為了解決這個問題的一個思路就是進(jìn)行sample reweighting,也就是對不同的樣本設(shè)置不同的權(quán)重垫卤。 那reweighting的方法本質(zhì)要學(xué)習(xí)的就是從不同樣本到權(quán)重之間的映射關(guān)系威彰,然后通過最小化加權(quán)之后的損失函數(shù)來優(yōu)化模型參數(shù)。
2.2 related work
目前主要的sample reweighting方法可以分為兩大類:
- 以focal loss為代表:
- 單樣本的loss越大 --> 認(rèn)為這個樣本更難分辨 --> 增加這個樣本的loss權(quán)重穴肘;
- 經(jīng)典方法包括focal loss歇盼,AdaBoost,hard negative mining梢褐;
- 這類方法主要適合用于解決long-tail數(shù)據(jù)旺遮,使得分布少的類別能擁有更高的權(quán)重;
- 以SPL為代表:
- 單樣本的loss越小 --> 認(rèn)為該樣本的標(biāo)簽可信度更高 --> 增加這個樣本的loss權(quán)重盈咳;
- 經(jīng)典方法包括SPL耿眉,iterative reweighting,以及其他變種方法鱼响;
- 這類方法適合用于解決noisy data問題鸣剪,使得標(biāo)簽正確的樣本擁有更高的權(quán)重;
下圖以focal loss和SPL為例,直觀給出了兩類方法的差別筐骇,focal loss遞增债鸡,SPL遞減。
2.3 motivation
作者首先總結(jié)了現(xiàn)有方法的兩大缺點:
1) 在現(xiàn)實無法預(yù)知data具體分布(long-tail還是noisy)的情況下铛纬,不知道要選遞增型還是遞減型厌均。更何況,現(xiàn)實中可能出現(xiàn)的是long-tail并且noisy的數(shù)據(jù)分布告唆;
2) 不管是哪一類方法棺弊,都需要超參數(shù)。
針對以上兩點擒悬,該文的motivation就是能否設(shè)計一個自適應(yīng)的且不需要超參數(shù)的reweighting方法模她,即找到一種從loss到weight的映射關(guān)系。
三 文章方法 Meta-Weighting-Net (MW-Net)
3.1 key idea
為了提出這樣一個自適應(yīng)的且不需要超參數(shù)的reweighting方法懂牧,文章的主要想法是用MLP來充當(dāng)weight fucntion的作用侈净,即讓MLP自動學(xué)習(xí)從loss到weight之間的映射關(guān)系。然后用unbiased meta data來引導(dǎo)MLP的參數(shù)學(xué)習(xí)僧凤。
如下圖所示畜侦,文章確實可以做到可以同時處理不同分布的數(shù)據(jù)(long-tail/noisy)。
3.2 具體方法
記整個分類網(wǎng)絡(luò)為, 用于預(yù)測樣本loss權(quán)重的MLP網(wǎng)絡(luò)為
拼弃, 網(wǎng)絡(luò)的整體訓(xùn)練過程如下圖:
可以重點關(guān)注箭頭的顏色夏伊,紅色的表示的是meta-weight-net的參數(shù)更新過程,而黑色的表示的整體分類網(wǎng)絡(luò)的參數(shù)更新過程吻氧。對于時間t而言,最重要的幾個步驟如下:
1) 對于分類網(wǎng)絡(luò)的參數(shù)
咏连, 用從訓(xùn)練集中采出的minibatch data進(jìn)行網(wǎng)絡(luò)參數(shù)的更新盯孙,得到
, 注意這里是暫時更新的
,并沒有替換原來
的參數(shù)
祟滴,可以理解為是一個臨時變量振惰。(圖中step5)
2) 對于MLP網(wǎng)絡(luò)為,用當(dāng)前的
預(yù)測得到的loss作為MLP網(wǎng)絡(luò)的輸入垄懂,得到輸出的loss weights骑晶,用meta-dataset構(gòu)建出來的minibatch data更新參數(shù)
, 得到t+1時刻的
草慧,替換原來
中的網(wǎng)絡(luò)參數(shù)桶蛔。(圖中step6)
3) 用t+1時刻的和t時刻的
, 再次用訓(xùn)練集中采出的minibatch data進(jìn)行網(wǎng)絡(luò)參數(shù)的更新漫谷,得到
仔雷,這次的
才真正作為t+1時刻的
, 替換原來
中的網(wǎng)絡(luò)參數(shù)。(圖中step7)
具體的公式可能看起來稍微有點復(fù)雜碟婆,但其實就是SGD在mini-batch上的優(yōu)化电抚。
最終的偽代碼如下所示:
四 寫在最后
整個思路還是比較巧妙的,而且之前的實驗結(jié)果圖也確實驗證了方法能對不同分布的數(shù)據(jù)都有效竖共。
目前還存在兩點問題:
1) meta-dataset具體什么怎么構(gòu)造的蝙叛,為什么在更新MLP的時候不能用正常的mini-batch,而要用meta-dataset公给;
2) 參數(shù)更新為什么一定要分3步借帘,直接a)更新;2)更新
是不可以的嗎
太細(xì)節(jié)的地方可能沒有g(shù)et到妓布,歡迎知道的小伙伴多多交流姻蚓。