MAML的核心思想是利用元學(xué)習(xí)來(lái)找到一個(gè)好的模型初始化,從而能夠在新任務(wù)上進(jìn)行快速適應(yīng)瓶竭。這種方法旨在處理“少樣本學(xué)習(xí)”的挑戰(zhàn)视乐,即當(dāng)新任務(wù)的數(shù)據(jù)量非常有限時(shí)如何有效地學(xué)習(xí)。傳統(tǒng)學(xué)習(xí)的數(shù)據(jù)點(diǎn)是一個(gè)樣本,而元學(xué)習(xí)的數(shù)據(jù)點(diǎn)是一個(gè)小數(shù)據(jù)集(任務(wù))译打,任務(wù)包含了很多樣本耗拓。元學(xué)習(xí)對(duì)每個(gè)任務(wù)中的每個(gè)樣本進(jìn)行訓(xùn)練得到每個(gè)任務(wù)的loss,并得到任務(wù)的損失和losses奏司。對(duì)losses進(jìn)行優(yōu)化來(lái)更新元學(xué)習(xí)模型的參數(shù)乔询。
MAML:
摘要:提出一個(gè)模型無(wú)關(guān)的元學(xué)習(xí)算法,它與任何由梯度下降訓(xùn)練的模型兼容并且可以應(yīng)用到各種不同的學(xué)習(xí)問(wèn)題韵洋,包括分類(lèi)竿刁,回歸,強(qiáng)化學(xué)習(xí)搪缨。元學(xué)習(xí)的目標(biāo)是在各種學(xué)習(xí)任務(wù)上訓(xùn)練一個(gè)模型食拜,它可以?xún)H僅使用小數(shù)量的訓(xùn)練樣本來(lái)解決新的學(xué)習(xí)任務(wù)。在我們的方法中副编,模型的參數(shù)被明確地訓(xùn)練负甸,這樣少量的梯度步長(zhǎng)和來(lái)自新任務(wù)的少量訓(xùn)練數(shù)據(jù)將在該任務(wù)上產(chǎn)生良好的泛化性能。該方法訓(xùn)練模型更容易去微調(diào)痹届。在兩個(gè)小樣本的圖像分類(lèi)上得到了sota的性能呻待,在小樣本回歸上也得好的結(jié)果,并且加速了使用神經(jīng)網(wǎng)絡(luò)策略的策略梯度強(qiáng)化的微調(diào)短纵。
引言:
問(wèn)題:
從小樣本得到認(rèn)知目標(biāo)或者快速的學(xué)習(xí)新技能屬于人類(lèi)擅長(zhǎng)的事带污。而智能機(jī)器學(xué)習(xí)這方面的能力存在挑戰(zhàn)。因?yàn)榇肀仨殞⑵湟郧暗慕?jīng)驗(yàn)與少量的新信息集成起來(lái)香到,同時(shí)避免對(duì)新數(shù)據(jù)進(jìn)行過(guò)擬合(只學(xué)會(huì)了這幾個(gè)樣本鱼冀,并沒(méi)有學(xué)習(xí)到能力)。此外悠就,先前的經(jīng)驗(yàn)和新數(shù)據(jù)的形式將取決于任務(wù)本身千绪。
重要性:
因此,提出的方法應(yīng)該對(duì)任務(wù)和實(shí)現(xiàn)任務(wù)的方法通用梗脾。
難點(diǎn):
創(chuàng)思:
在這項(xiàng)工作中荸型,提出了一個(gè)元學(xué)習(xí)算法MAML,與特定模型無(wú)關(guān)炸茧,即它可以直接應(yīng)用于任何可微的模型瑞妇。MAML聚焦在深度神經(jīng)網(wǎng)絡(luò),闡釋了如何用一個(gè)最小步數(shù)的微調(diào)梭冠,便可以更容易處理不同的網(wǎng)絡(luò)結(jié)構(gòu)和不同的問(wèn)題辕狰,包括分類(lèi),回歸控漠,策略梯度強(qiáng)化學(xué)習(xí)蔓倍。
提出的方法關(guān)注學(xué)習(xí)模型的初始化參數(shù)悬钳。以便新任務(wù)再模型上通過(guò)少量的樣本和迭代可以進(jìn)行快速適應(yīng)。與先驗(yàn)元學(xué)習(xí)方法和學(xué)習(xí)更新函數(shù)或者更新規(guī)則不同偶翅,算法沒(méi)有擴(kuò)展到學(xué)習(xí)參數(shù)或模型結(jié)構(gòu)的數(shù)量上(有論文已經(jīng)做了結(jié)構(gòu)和數(shù)量的了)默勾。MAML可以組合全連接,卷積聚谁,RNN母剥,不同的損失函數(shù),包括可微分的監(jiān)督損失和不可微分的強(qiáng)化學(xué)習(xí)目標(biāo)垦巴。
模型參數(shù)的訓(xùn)練過(guò)程媳搪,通過(guò)幾個(gè)或者一個(gè)梯度更新步驟,簡(jiǎn)單的微調(diào)參數(shù)可以得到好的結(jié)果骤宣。事實(shí)上,模型的優(yōu)化是容易且快速的序愚,允許在正確的空間快速學(xué)習(xí)憔披。學(xué)習(xí)的過(guò)程可以被看作最大化新任務(wù)損失函數(shù)對(duì)參數(shù)的敏感性。當(dāng)敏感性高的時(shí)候爸吮,對(duì)于參數(shù)的小的局部的改變可以導(dǎo)致在任務(wù)損失上的提升芬膝。
結(jié)果:
評(píng)估MAML相比流行的SOTA的專(zhuān)門(mén)為監(jiān)督分類(lèi)設(shè)計(jì)的one-shot 學(xué)習(xí)方法。方法使用小的參數(shù)形娇,但也可以容易的應(yīng)用到回歸以及強(qiáng)化學(xué)習(xí)锰霜,歸功于直接預(yù)訓(xùn)練初始參數(shù)使得性能提升。
假設(shè):
模型:
MAML:隨機(jī)初始化模型參數(shù)桐早,通過(guò)訓(xùn)練來(lái)學(xué)習(xí)最優(yōu)的初始化參數(shù)癣缅。初始化參數(shù)的訓(xùn)練主要分為兩步,第一步是任務(wù)內(nèi)的參數(shù)更新哄酝,第二步是任務(wù)間的參數(shù)更新
其中:
Require 給出所有任務(wù)的分布以及參數(shù)更新的學(xué)習(xí)率
1 友存、隨機(jī)初始化模型參數(shù);2陶衅、 循環(huán)訓(xùn)練更新參數(shù)屡立,直到訓(xùn)練截止;3 搀军、采樣一個(gè)batch膨俐,包含多個(gè)任務(wù),每個(gè)任務(wù)K個(gè)樣本罩句;4焚刺、遍歷所有任務(wù);5的止、計(jì)算第i個(gè)任務(wù)在lossL下的梯度檩坚;6 、任務(wù)內(nèi)的參數(shù)更新;7匾委、batch中的任務(wù)內(nèi)參數(shù)更新完成拖叙;8、任務(wù)間的參數(shù)更新赂乐。
不同的任務(wù)需要選擇不同的loss薯鳍,在回歸和分類(lèi)的算法上的應(yīng)用時(shí),loss的選擇為均方誤差和交叉熵挨措;在算法1中具體化任務(wù)和問(wèn)題得到算法2:
在強(qiáng)化學(xué)習(xí)上的MAML,loss為獎(jiǎng)勵(lì)函數(shù)浅役,模型輸出為決策斩松,
實(shí)驗(yàn):
實(shí)驗(yàn)回答論文2個(gè)問(wèn)題(這種先描述問(wèn)題的方法可以借鑒到寫(xiě)作上):
1)MAML可以在新任務(wù)上快速的學(xué)習(xí)嗎???
2)模型用MAML觉既,在額外的更新次數(shù)和樣本個(gè)數(shù)上可以連續(xù)的提升性能惧盹?
回歸任務(wù),用樣本做sin函數(shù)回歸
pretrained的方法只做一次參數(shù)更新瞪讼,而MAML做兩次參數(shù)更新钧椰,第一次更新為下一次更新確定方向。不同的梯度次數(shù)訓(xùn)練得到的預(yù)測(cè)結(jié)果不同符欠,從圖中可以看到K=5和K=10時(shí)10次更新結(jié)果最好嫡霞,1次梯度下降有不錯(cuò)的效果,能夠得到快速的適應(yīng)希柿,回答了任務(wù)1诊沪。隨著更新次數(shù)(grad step)和樣本個(gè)數(shù)K的提高,性能得到了提升狡汉,回答了問(wèn)題2娄徊。預(yù)訓(xùn)練的方法沒(méi)有元參數(shù)更新的步驟,效果都很差盾戴,很難擬合寄锐。
通過(guò)loss值可以看出MAML在步數(shù)增加的情況沒(méi)有過(guò)擬合,loss更低尖啡,性能持續(xù)提高橄仆,回答了問(wèn)題2。
分類(lèi)實(shí)驗(yàn):
Datasets:Omniglot衅斩,MiniImagenet
Omniglot:來(lái)自50個(gè)不同的字母(類(lèi))盆顾,1623個(gè)樣本,選擇20個(gè)類(lèi)畏梆。1200個(gè)作為訓(xùn)練集您宪,剩下的做測(cè)試集奈懒。
MiniImagenet:64個(gè)訓(xùn)練類(lèi),12個(gè)驗(yàn)證類(lèi)宪巨,24個(gè)測(cè)試類(lèi)
baseline:
MANN:Memory-Augmented Neural Networks 記憶增強(qiáng)的神經(jīng)網(wǎng)絡(luò)
Siamese nets 孿生網(wǎng)絡(luò)磷杏,共享encoder權(quán)重
matching nets 匹配網(wǎng)絡(luò),few-shot learning方法捏卓,用目標(biāo)樣本和支持集一起做嵌入极祸,后計(jì)算二者的相似度作為權(quán)重,為支持集賦予權(quán)重預(yù)測(cè)標(biāo)簽怠晴。
neural statistician 神經(jīng)統(tǒng)計(jì)師模型遥金,包括encoder,統(tǒng)計(jì)網(wǎng)絡(luò)(有很多不同的統(tǒng)計(jì)方式)蒜田,decoder稿械。統(tǒng)計(jì)網(wǎng)絡(luò)的任務(wù)是將所有樣本的特征整合,輸出一個(gè)集合表示物邑,即統(tǒng)計(jì)信息【加一些額外的設(shè)計(jì)和策略溜哮,神經(jīng)統(tǒng)計(jì)師是否可以被擴(kuò)展并應(yīng)用于演化聚類(lèi)任務(wù)?】
memory mod. 記憶增強(qiáng)的神經(jīng)網(wǎng)絡(luò)的一種色解,原文提到運(yùn)用到life-long中受限。
meta-learner LSTM 在元學(xué)習(xí)場(chǎng)景中使用的LSTM餐茵,LSTM接受梯度信息科阎,輸出應(yīng)該應(yīng)用于模型權(quán)重的更新。LSTM被看作一個(gè)優(yōu)化器忿族。
MAML first order approx 代表的是梯度之考慮一次微分锣笨,二次微分因?yàn)闀?huì)帶來(lái)計(jì)算開(kāi)銷(xiāo)被忽略。
分類(lèi)code:
maml pytorch代碼:https://github.com/dragen1860/MAML-Pytorch/blob/master/meta.py
代碼里的實(shí)現(xiàn)道批,對(duì)每個(gè)任務(wù)错英,先初始化參數(shù),對(duì)初始化的模型參數(shù)進(jìn)行訓(xùn)練得到第一次參數(shù)隆豹,在第一次參數(shù)的更新方向上更新了初始參數(shù)椭岩。也就是第一次參數(shù)的更新決定了更新方向,第二次更新更新了實(shí)際參數(shù)璃赡。
對(duì)batch判哥,batch中每個(gè)任務(wù)學(xué)習(xí)對(duì)應(yīng)的任務(wù)loss,將每個(gè)loss求和得到整體losses碉考,并對(duì)losses進(jìn)行優(yōu)化塌计。
微調(diào)過(guò)程:copy訓(xùn)練好的模型,在模型上進(jìn)行微調(diào)和驗(yàn)證侯谁。在測(cè)試集學(xué)習(xí)每個(gè)任務(wù)的loss锌仅,并得到losses和更新權(quán)重。分別對(duì)任務(wù)中的樣本在新權(quán)重下進(jìn)行測(cè)試热芹。
強(qiáng)化學(xué)習(xí)(實(shí)驗(yàn)部分很難看懂贱傀,以后補(bǔ)充)
? ? ? ?討論和未來(lái)工作:介紹了一種基于元學(xué)習(xí)的方法,該方法基于通過(guò)梯度下降學(xué)習(xí)易于適應(yīng)的模型參數(shù)剿吻。方法有很多好處窍箍,它很簡(jiǎn)單,并且沒(méi)有為元學(xué)習(xí)引入任何學(xué)習(xí)參數(shù)丽旅。它可以組合任何可以用基于梯度訓(xùn)練的模型椰棘,任何可以微分的目標(biāo),包括分類(lèi)榄笙,回歸邪狞,強(qiáng)化學(xué)習(xí)。模型僅僅產(chǎn)生權(quán)重的初始化茅撞,適應(yīng)任何數(shù)據(jù)數(shù)K和梯度步驟數(shù)setp grad帆卓,通過(guò)SOTA的分類(lèi)結(jié)果,也在RL上使用了策略梯度米丘。從過(guò)去的任務(wù)中重用知識(shí)可能是制作高容量可擴(kuò)展模型(例如深度神經(jīng)網(wǎng)絡(luò))的關(guān)鍵因素剑令,可以使用小數(shù)據(jù)集進(jìn)行快速訓(xùn)練。這項(xiàng)工作是邁向簡(jiǎn)單通用元學(xué)習(xí)技術(shù)的第一步拄查,可應(yīng)用于任何問(wèn)題和任何模型吁津。該領(lǐng)域的進(jìn)一步研究可以使多任務(wù)初始化成為深度學(xué)習(xí)和強(qiáng)化學(xué)習(xí)的標(biāo)準(zhǔn)成分。非常有用的工作堕扶!