受干擾學(xué)生論文里提到6篇知識蒸餾及偽標(biāo)簽暗知識的論文腊凶,除去較早發(fā)表的3篇狼渊,包括以下3篇:
1.????Distilling the knowledge in a neural network(2015),神經(jīng)網(wǎng)絡(luò)知識蒸餾的開山之作馁害,也是最為知名的
2.????Bayesian dark knowledge(2015)
3.????Born again neural networks(2018)窄俏,訓(xùn)練一個和teacher參數(shù)一樣多的student網(wǎng)絡(luò),并且準(zhǔn)確率超過了teacher網(wǎng)絡(luò)
自糾正論文里提到2篇用于半監(jiān)督學(xué)習(xí)的偽標(biāo)簽論文碘菜,包括:
1.????Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks(2013)(有趣的是SCHP論文里不涉及到半監(jiān)督凹蜈,偽標(biāo)簽最初是用于蒸餾還是用于半監(jiān)督?)
2.????Training deep neural networks on noisy labels with bootstrapping(2014)
先從開山之作開始吧:https://arxiv.org/pdf/1503.02531.pdf
提高幾乎所有機(jī)器學(xué)習(xí)算法性能的一個非常簡單的方法是在相同的數(shù)據(jù)上訓(xùn)練多個不同的模型忍啸,然后對輸出取平均[3]仰坦。不幸的是,使用整個模型集合進(jìn)行預(yù)測是很麻煩的计雌,而且計(jì)算成本可能太高悄晃,以至于在很多情況下無法部署,尤其是在單個模型就是大網(wǎng)絡(luò)的情況下凿滤。[1]已經(jīng)證明妈橄,可以將集成模型的知識壓縮到一個更容易部署的單一模型中,而我們使用了不同的壓縮技術(shù)進(jìn)一步發(fā)展了這一方法翁脆。我們在MNIST上取得了一些令人驚訝的結(jié)果眷蚓,并且證明了通過將模型集合中的知識提取到單個模型中,我們可以顯著地改進(jìn)大量使用的商業(yè)系統(tǒng)的聲學(xué)模型反番。我們還介紹了一種由一個或多個完整模型和多個專家模型組成的新型集成方法沙热,這些模型學(xué)習(xí)區(qū)分完整模型混淆的細(xì)粒度類別。與混合專家不同罢缸,這些專家模型可以快速并行地訓(xùn)練篙贸。
1 引言
許多昆蟲的幼蟲形態(tài)和成蟲形態(tài)不同,幼蟲形態(tài)可以從環(huán)境中提取能量和營養(yǎng)枫疆,成蟲形態(tài)適合旅行和繁殖歉秫。早大規(guī)模機(jī)器學(xué)習(xí)中,我們常常在訓(xùn)練階段和部署階段使用非常相似的模型养铸,盡管訓(xùn)練和部署的要求非常不同:對于語音和物體識別這樣的任務(wù)雁芙,訓(xùn)練必須從非常大規(guī)模的、高度冗余的數(shù)據(jù)集中提取結(jié)構(gòu)钞螟,但訓(xùn)練是不需要實(shí)時的兔甘,并且可以使用大量的算力。然而在部署的時候鳞滨,對延遲和計(jì)算資源有更嚴(yán)格的要求洞焙。和昆蟲的類比提示我們,如果復(fù)雜模型能夠更容易地從數(shù)據(jù)中提取結(jié)構(gòu)的話,我們應(yīng)該很愿意取訓(xùn)練復(fù)雜的模型澡匪。復(fù)雜模型可以是多個單獨(dú)訓(xùn)練的模型的集成熔任,也可以是一個使用很強(qiáng)的正則化(如dropout)訓(xùn)練的單個非常大的模型[9]。復(fù)雜模型一旦訓(xùn)練完了唁情,我們就可以開始另一種訓(xùn)練疑苔,稱之為“蒸餾”,將知識從復(fù)雜模型轉(zhuǎn)移到更適合部署的小模型上甸鸟。[1]已經(jīng)開創(chuàng)了該方法的第一個版本惦费,在他們的重要論文中,他們令人信服地證明了一個大的集成模型的知識可以轉(zhuǎn)移到小模型上抢韭。
一個認(rèn)知上的障礙可能使得這種有前景的方法沒有更多的研究薪贫,就是我們傾向于把一個模型訓(xùn)練習(xí)得的參數(shù)取值看做是該模型學(xué)到的知識,這使得我們很難看到刻恭,如何改變網(wǎng)絡(luò)結(jié)構(gòu)卻又保持學(xué)到的知識瞧省。對知識的一個更抽象的視角是,它是從輸入向量到輸出向量的學(xué)習(xí)到的映射鳍贾。
一個可能阻礙了對這種非常有前途的方法進(jìn)行更多研究的一個概念性障礙是鞍匾,我們傾向于用所學(xué)的參數(shù)值來識別一個經(jīng)過訓(xùn)練的模型中的知識,這使得我們很難看到我們?nèi)绾胃淖兡P偷男问蕉3窒嗤闹R贾漏。知識的一個更抽象的觀點(diǎn)是候学,它是從輸入向量到輸出向量的學(xué)習(xí)映射藕筋。對于學(xué)習(xí)區(qū)分大量類的笨重模型纵散,正常的訓(xùn)練目標(biāo)是最大化正確答案的平均對數(shù)概率,但學(xué)習(xí)的副作用是訓(xùn)練模型為所有錯誤答案分配概率隐圾,即使這些概率很小伍掀,有些比其他的大得多。不正確答案的相對概率告訴我們很多關(guān)于這個繁瑣的模型是如何被推廣的暇藏。例如蜜笤,一個寶馬的形象,可能只有很小的機(jī)會被誤認(rèn)為垃圾車盐碱,但這種錯誤的可能性仍然比把它誤認(rèn)為胡蘿卜高出許多倍把兔。
一般認(rèn)為,用于培訓(xùn)的目標(biāo)函數(shù)應(yīng)盡可能地反映用戶的真實(shí)目標(biāo)瓮顽。盡管如此县好,當(dāng)真正的目標(biāo)是很好地推廣到新的數(shù)據(jù)時,通常對模型進(jìn)行訓(xùn)練以優(yōu)化訓(xùn)練數(shù)據(jù)的性能暖混。顯然缕贡,訓(xùn)練模型使其能夠很好地進(jìn)行概括會更好,但這需要關(guān)于正確的泛化方法的信息,而這些信息通常不可用晾咪。然而收擦,當(dāng)我們從一個大模型提取知識到一個小模型時,我們可以訓(xùn)練小模型像大模型一樣進(jìn)行泛化谍倦。如果繁瑣的模型能夠很好地進(jìn)行泛化塞赂,例如,它是不同模型的一個大集合的平均值剂跟,那么以同樣的方式訓(xùn)練的小模型在測試數(shù)據(jù)上的表現(xiàn)通常要比在訓(xùn)練集合時使用的同一訓(xùn)練集上以正常方式訓(xùn)練的小模型要好得多减途。
將笨重模型的泛化能力轉(zhuǎn)化為小模型的一個明顯的方法是利用笨重模型產(chǎn)生的類概率作為訓(xùn)練小模型的“軟目標(biāo)”。對于這個轉(zhuǎn)移階段曹洽,我們可以使用相同的訓(xùn)練集或單獨(dú)的“轉(zhuǎn)移”集鳍置。當(dāng)復(fù)雜模型是一個簡單模型的大集合時,我們可以使用它們各自的預(yù)測分布的算術(shù)或幾何平均值作為軟目標(biāo)送淆。當(dāng)軟目標(biāo)具有較高的熵時税产,每個訓(xùn)練實(shí)例提供的信息比硬目標(biāo)大得多,訓(xùn)練樣本之間的梯度方差也小得多偷崩,因此小模型通潮倏剑可以用比原來繁瑣的模型少得多的數(shù)據(jù)進(jìn)行訓(xùn)練,并使用更高的學(xué)習(xí)率阐斜。
對于像MNIST這樣的任務(wù)衫冻,笨重的模型幾乎總是以非常高的置信度產(chǎn)生正確的答案,關(guān)于學(xué)習(xí)函數(shù)的大部分信息都存在于軟目標(biāo)中非常小的概率比率中谒出。例如隅俘,一個版本的a 2可能被賦予10?6的概率是3,10?9是7的概率笤喳,而另一個版本的概率可能是相反的为居。這是一個有價值的信息,它定義了數(shù)據(jù)上豐富的相似性結(jié)構(gòu)(例如杀狡,它說哪個2看起來像3蒙畴,哪個看起來像7),但是在傳輸階段呜象,它對交叉熵成本函數(shù)的影響非常小膳凝,因?yàn)楦怕史浅=咏诹恪aruana和他的合作者通過使用logits(最終的softmax的輸入)而不是softmax產(chǎn)生的概率作為學(xué)習(xí)小模型的目標(biāo)恭陡,從而避免了這個問題蹬音,并使笨重模型產(chǎn)生的logit與小模型產(chǎn)生的logit之間的平方差最小化。我們更普遍的解決方案子姜,稱為“蒸餾”祟绊,是提高最終softmax的溫度楼入,直到繁瑣的模型產(chǎn)生一組適當(dāng)?shù)能浤繕?biāo)。然后我們在訓(xùn)練小模型時使用相同的高溫來匹配這些軟目標(biāo)牧抽。稍后我們將說明嘉熊,匹配笨重模型的logits實(shí)際上是蒸餾的一個特例。
用于訓(xùn)練小模型的傳輸集可以完全由未標(biāo)記的數(shù)據(jù)組成[1]扬舒,也可以使用原始訓(xùn)練集阐肤。我們發(fā)現(xiàn),使用原始的訓(xùn)練集效果很好讲坎,特別是如果我們在目標(biāo)函數(shù)中加入一個小項(xiàng)孕惜,鼓勵小模型預(yù)測真實(shí)目標(biāo),并匹配笨重模型提供的軟目標(biāo)晨炕。通常衫画,小模型不能精確匹配軟目標(biāo),在正確答案的方向上出錯是有幫助的瓮栗。
2蒸餾
神經(jīng)網(wǎng)絡(luò)通常通過使用一個“softmax”輸出層來產(chǎn)生類概率削罩,該層通過比較zi和其他logit,將為每個類計(jì)算的logit轉(zhuǎn)換為概率qi费奸。
式中弥激,T是通常設(shè)置為1的溫度。使用較高的T值會在類上產(chǎn)生較軟的概率分布愿阐。
在最簡單的蒸餾形式中微服,通過在傳遞集上訓(xùn)練知識并對傳遞集中的每種情況使用軟目標(biāo)分布將知識轉(zhuǎn)移到蒸餾模型中,該傳遞集是通過使用在其softmax中具有高溫的笨重模型生成的缨历。訓(xùn)練蒸餾模型時使用相同的高溫以蕴,但訓(xùn)練后使用的溫度為1。
當(dāng)所有或部分傳輸集的正確標(biāo)簽已知時戈二,通過訓(xùn)練蒸餾模型來生成正確的標(biāo)簽舒裤,這種方法可以得到顯著改進(jìn)喳资。一種方法是使用正確的標(biāo)簽來修改軟目標(biāo)觉吭,但是我們發(fā)現(xiàn)更好的方法是簡單地使用兩個不同目標(biāo)函數(shù)的加權(quán)平均值。第一個目標(biāo)函數(shù)是與軟目標(biāo)的交叉熵仆邓,該交叉熵的計(jì)算使用與從繁瑣模型生成軟目標(biāo)時使用的相同的高溫鲜滩。第二個目標(biāo)函數(shù)是帶有正確標(biāo)簽的交叉熵。這是使用與蒸餾模型的softmax中完全相同的logits計(jì)算的节值,但溫度為1徙硅。我們發(fā)現(xiàn),在第二個目標(biāo)函數(shù)上使用一個條件較低的權(quán)重通掣懔疲可以獲得最佳結(jié)果嗓蘑。由于軟目標(biāo)產(chǎn)生的梯度大小為1/t2,因此在使用硬目標(biāo)和軟目標(biāo)時,將其乘以t2是很重要的桩皿。這就保證了在實(shí)驗(yàn)meta參數(shù)時改變蒸餾溫度時豌汇,硬靶和軟靶的相對貢獻(xiàn)基本保持不變。
2.1匹配邏輯是蒸餾的特例
轉(zhuǎn)移集中的每一種情況都提供了一個交叉熵梯度dC/dzi泄隔,相對于蒸餾模型的每個logit zi拒贱。如果笨重模型的logits vi產(chǎn)生軟目標(biāo)概率pi,并且轉(zhuǎn)移訓(xùn)練是在溫度T下進(jìn)行的佛嬉,則該梯度由以下公式給出:
如果溫度與logits的大小相比較高逻澳,我們可以近似:
因此,在高溫極限下暖呕,蒸餾相當(dāng)于最小化1/2(zi?vi)2斜做,前提是每個分動器的logit分別為零。在較低的溫度下湾揽,蒸餾對匹配比平均值負(fù)得多的邏輯函數(shù)的關(guān)注要少得多陨享。這是潛在的優(yōu)勢,因?yàn)檫@些邏輯幾乎完全不受用于訓(xùn)練笨重模型的成本函數(shù)的約束钝腺,因此它們可能非常嘈雜抛姑。另一方面,非常消極的邏輯可能傳達(dá)有用的信息艳狐,關(guān)于知識獲得的繁瑣的模型定硝。這些影響中哪一個占主導(dǎo)地位是一個經(jīng)驗(yàn)問題。我們發(fā)現(xiàn)毫目,當(dāng)蒸餾模型太小而無法在繁瑣的模型中捕獲所有知識時蔬啡,中間溫度最有效,這強(qiáng)烈表明忽略大的負(fù)邏輯可能會有幫助镀虐。
3關(guān)于MNIST的初步實(shí)驗(yàn)
為了了解蒸餾的工作原理箱蟆,我們在所有60000個訓(xùn)練案例中訓(xùn)練了一個包含兩個隱藏層的大型神經(jīng)網(wǎng)絡(luò),共有1200個校正的線性隱藏單元刮便。如[5]中所述空猜,該網(wǎng)絡(luò)使用輟學(xué)和權(quán)重約束進(jìn)行了嚴(yán)格的正則化。輟學(xué)可以看作是訓(xùn)練一個指數(shù)級的共享權(quán)重的模型集合的一種方法恨旱。此外辈毯,輸入圖像在任何方向上都會受到最多兩個像素的抖動。該網(wǎng)絡(luò)實(shí)現(xiàn)了67個測試錯誤搜贤,而一個較小的網(wǎng)絡(luò)谆沃,有兩個隱藏層,800個校正線性隱藏單元仪芒,沒有正則化唁影,達(dá)到146個錯誤耕陷。但如果僅通過增加在溫度為20℃時對大網(wǎng)產(chǎn)生的軟目標(biāo)進(jìn)行匹配的附加任務(wù)對較小的網(wǎng)絡(luò)進(jìn)行正則化,則可以獲得74個測試誤差据沈。這說明軟目標(biāo)可以將大量的知識轉(zhuǎn)移到提取的模型中啃炸,包括從翻譯的訓(xùn)練數(shù)據(jù)中學(xué)習(xí)到的知識,即使轉(zhuǎn)移集不包含任何翻譯卓舵。
當(dāng)蒸餾網(wǎng)在其兩個隱藏層中各有300個或更多單元時南用,所有高于8的溫度都會產(chǎn)生相當(dāng)相似的結(jié)果。但當(dāng)這一數(shù)值從根本上減少到每層30個單位時掏湾,2.5到4之間的溫度比較高或較低的溫度效果要好得多裹虫。然后_我們_試_著_從_轉(zhuǎn)移_集中_省略_數(shù)字_3_的_所有_例子_ 。_因此融击,從蒸餾模型的角度來看筑公,3是一個從未見過的神話數(shù)字。盡管如此尊浪,蒸餾模型只會產(chǎn)生206個測試錯誤匣屡,其中133個在測試集中的1010個三分之一上拇涤。大多數(shù)的錯誤是由于三個班的學(xué)習(xí)偏差太低而造成的捣作。如果這個偏差增加3.5(這優(yōu)化了測試集的整體性能),則蒸餾模型會產(chǎn)生109個錯誤鹅士,其中14個在3s上券躁。因此,在正確的偏差下掉盅,盡管在訓(xùn)練過程中從未見過3也拜,但蒸餾模型在測試3s中的正確率為98.6%。如果傳遞集只包含來自訓(xùn)練集的7和8趾痘,則蒸餾模型會產(chǎn)生47.3%的測試誤差慢哈,但當(dāng)7和8的偏差減小7.6以優(yōu)化測試性能時,測試誤差將降至13.2%永票。
4語音識別實(shí)驗(yàn)
在這一節(jié)中卵贱,我們將研究用于自動語音識別(ASR)的深層神經(jīng)網(wǎng)絡(luò)(DNN)聲學(xué)模型的感知效果。我們證明瓦侮,本文提出的蒸餾策略能達(dá)到預(yù)期的效果艰赞,即將一組模型提取成一個單一的模型佣谐,其效果明顯優(yōu)于直接從相同訓(xùn)練數(shù)據(jù)中學(xué)習(xí)的相同大小的模型肚吏。
目前最先進(jìn)的ASR系統(tǒng)使用DNNs將波形特征的(短)時間上下文映射為隱馬爾可夫模型(HMM)離散狀態(tài)上的概率分布[4]。更具體地說狭魂,DNN在每一時刻在三個電話狀態(tài)的簇上產(chǎn)生一個概率分布罚攀,然后解碼器找到一條通過HMM狀態(tài)的路徑党觅,這是使用高概率狀態(tài)和生成語言模型下可能的轉(zhuǎn)錄之間的最佳折衷。
雖然有可能(也希望)以這樣一種方式訓(xùn)練DNN斋泄,即通過在所有可能的路徑上邊緣化來考慮解碼器(因此杯瞻,語言模型),通常炫掐,訓(xùn)練DNN進(jìn)行逐幀分類魁莉,方法是(局部地)最小化網(wǎng)絡(luò)預(yù)測和標(biāo)簽之間的交叉熵,強(qiáng)制對齊每個觀測的地面真實(shí)狀態(tài)序列:
其中θ是聲學(xué)模型P的參數(shù)募胃,該模型將時間t旗唁,st的聲學(xué)觀測值映射為“正確”HMM狀態(tài)ht的概率P(ht | st;θ′)痹束,該概率通過與正確的單詞序列進(jìn)行強(qiáng)制對齊來確定检疫。模型采用分布隨機(jī)梯度下降法訓(xùn)練。
我們使用的架構(gòu)有8個隱藏層祷嘶,每個層包含2560個校正的線性單元屎媳,最后一個softmax層有14000個標(biāo)簽(HMM目標(biāo)ht)。以40個Mel尺度濾波器組系數(shù)的26幀為輸入论巍,每幀提前10ms烛谊,預(yù)測第21幀的HMM狀態(tài)。參數(shù)總數(shù)約為85M嘉汰,這是Android語音搜索使用的聲學(xué)模型的一個稍微過時的版本晒来,應(yīng)該被認(rèn)為是一個非常強(qiáng)大的基線。為了訓(xùn)練DNN聲學(xué)模型郑现,我們使用了大約2000小時的英語口語數(shù)據(jù)湃崩,產(chǎn)生了大約700M的訓(xùn)練實(shí)例。在我們的開發(fā)平臺上接箫,該系統(tǒng)的幀頻準(zhǔn)確率為58.9%攒读,字錯誤率為10.9%。
4.1結(jié)果
我們訓(xùn)練了10個獨(dú)立的模型來預(yù)測P(ht | st辛友;θ)薄扁,使用與基線完全相同的體系結(jié)構(gòu)和訓(xùn)練過程。模型被隨機(jī)初始化為不同的初始參數(shù)值废累,我們發(fā)現(xiàn)這在訓(xùn)練的模型中創(chuàng)造了足夠的多樣性邓梅,使得集合的平均預(yù)測顯著優(yōu)于單個模型。我們已經(jīng)探索過通過改變每個模型所看到的數(shù)據(jù)集來增加模型的多樣性邑滨,但是我們發(fā)現(xiàn)這并不會顯著改變我們的結(jié)果日缨,所以我們選擇了更簡單的方法。對于蒸餾掖看,我們嘗試了[1,2,5,10]的溫度匣距,并在硬靶的交叉熵上使用了0.5的相對權(quán)重面哥,其中粗體字體表示表1中使用的最佳值。
表1顯示毅待,實(shí)際上尚卫,我們的蒸餾方法能夠從訓(xùn)練集中提取更多有用的信息,而不是簡單地使用硬標(biāo)簽來訓(xùn)練單個模型尸红。通過使用10個模型的集合實(shí)現(xiàn)的幀分類精度的80%以上的改進(jìn)被轉(zhuǎn)移到蒸餾模型上吱涉,這與我們在MNIST上的初步實(shí)驗(yàn)中觀察到的改進(jìn)相似。由于目標(biāo)函數(shù)的不匹配外里,集成對WER的最終目標(biāo)(在23K字的測試集上)的改善較小邑飒,但同樣,集成實(shí)現(xiàn)的WER改進(jìn)被轉(zhuǎn)移到蒸餾模型中级乐。
我們最近意識到了通過匹配已經(jīng)訓(xùn)練過的較大模型的類概率來學(xué)習(xí)小聲學(xué)模型的相關(guān)工作[8]疙咸。然而,他們使用一個大的未標(biāo)記數(shù)據(jù)集在1的溫度下進(jìn)行蒸餾风科,當(dāng)他們都用硬標(biāo)簽訓(xùn)練時撒轮,他們的最佳蒸餾模型只會將小模型的錯誤率降低28%,而大模型和小模型的錯誤率之間的差距是小模型的28%贼穆。
5大數(shù)據(jù)集專家培訓(xùn)團(tuán)
訓(xùn)練一個模型集合是利用并行計(jì)算的一個非常簡單的方法题山,而一個集成在測試時需要太多計(jì)算的常見問題可以用蒸餾來解決。然而故痊,對于集成還有另一個重要的反對意見:如果單個模型是大型的神經(jīng)網(wǎng)絡(luò)顶瞳,并且數(shù)據(jù)集非常大,那么即使很容易并行化愕秫,訓(xùn)練時所需的計(jì)算量也過大慨菱。
在這一節(jié)中,我們給出了這樣一個數(shù)據(jù)集的例子戴甩,并且我們展示了學(xué)習(xí)專家模型(每個模型都關(guān)注于類的不同可混淆子集)如何減少學(xué)習(xí)集成所需的總計(jì)算量符喝。專注于進(jìn)行細(xì)粒度區(qū)分的專家的主要問題是他們很容易過度擬合,我們描述了如何通過使用軟目標(biāo)來防止這種過度擬合甜孤。
5.1 JFT數(shù)據(jù)集
JFT是谷歌內(nèi)部的一個數(shù)據(jù)集协饲,有1億張標(biāo)簽圖片,有15000個標(biāo)簽缴川。當(dāng)我們做這項(xiàng)工作時茉稠,Google的JFT的基線模型是一個深卷積神經(jīng)網(wǎng)絡(luò)[7],它在大量核心上使用異步隨機(jī)梯度下降訓(xùn)練了大約6個月把夸。這個訓(xùn)練使用了兩種類型的并行[2]而线。首先,在不同的核集合上運(yùn)行神經(jīng)網(wǎng)絡(luò)的許多副本,并從訓(xùn)練集中處理不同的小批量吞获。每個副本計(jì)算其當(dāng)前小批量的平均漸變况凉,并將此漸變發(fā)送到分片參數(shù)服務(wù)器谚鄙,該服務(wù)器將返回參數(shù)的新值各拷。這些新值反映了參數(shù)服務(wù)器自上次向副本發(fā)送參數(shù)以來收到的所有漸變。第二闷营,每一個復(fù)制品通過在每個核心上放置不同的神經(jīng)元子集而分布在多個核心上烤黍。集成訓(xùn)練是第三種類型的并行,它可以圍繞其他兩種類型進(jìn)行傻盟,但前提是有更多的核心可用速蕊。等待數(shù)年來訓(xùn)練一個模型集合不是一個選擇,所以我們需要一個更快的方法來改進(jìn)基線模型娘赴。
5.2專業(yè)模型
當(dāng)類的數(shù)量非常大時规哲,笨重的模型應(yīng)該是一個集合,它包含一個針對所有數(shù)據(jù)訓(xùn)練的泛型模型和許多“專家”模型诽表,每個模型都是基于數(shù)據(jù)進(jìn)行訓(xùn)練的唉锌,這些數(shù)據(jù)在非常容易混淆的類子集(如不同類型的蘑菇)中高度豐富。這類專家的softmax可以通過將它不關(guān)心的所有類合并到一個垃圾箱類中而變得更小竿奏。為了減少過度擬合和分擔(dān)學(xué)習(xí)低層特征檢測器的工作袄简,每個專家模型都用廣義模型的權(quán)重進(jìn)行初始化。然后泛啸,通過訓(xùn)練專家绿语,其中一半的例子來自其特殊子集,一半隨機(jī)抽樣自訓(xùn)練集的其余部分候址,對這些權(quán)重稍作修改吕粹。訓(xùn)練結(jié)束后,我們可以通過增加垃圾桶類的logit值乘以專家類被過度抽樣的比例的對數(shù)來修正訓(xùn)練集的偏差岗仑。
5.3為專家分配課程
為了獲得專家的對象類別分組昂芜,我們決定將重點(diǎn)放在我們整個網(wǎng)絡(luò)經(jīng)常混淆的類別上赔蒲。盡管我們可以計(jì)算混淆矩陣并將其用作查找此類集群的一種方法泌神,但我們選擇了一種更簡單的方法,即不需要真正的標(biāo)簽來構(gòu)造集群舞虱。
特別是欢际,我們將聚類算法應(yīng)用于我們的廣義模型預(yù)測的協(xié)方差矩陣,這樣一組經(jīng)常一起預(yù)測的類sm將被用作我們的一個專家模型m的目標(biāo)矾兜。我們對協(xié)方差矩陣的列應(yīng)用了在線版本的K-means算法损趋,得到合理的聚類結(jié)果(見表2)。我們嘗試了幾種聚類算法椅寺,結(jié)果相似
5.4與專家一起進(jìn)行推理
在調(diào)查專家模型被提煉出來后會發(fā)生什么之前浑槽,我們想看看包含專家的組合表現(xiàn)如何蒋失。除了專家模型,我們總是有一個通才模型桐玻,這樣我們就可以處理沒有專家的類篙挽,這樣我們就可以決定使用哪些專家。給定一個輸入圖像x镊靴,我們分兩個步驟對一個分類進(jìn)行排序:
步驟1:對于每個測試用例铣卡,我們根據(jù)通才模型找到n個最有可能的類。將這組類稱為k偏竟。在我們的實(shí)驗(yàn)中煮落,我們使用n=1。
第二步:我們?nèi)∷械膶<夷P陀荒保琺蝉仇,其特殊的可混淆類子集sm與k有一個非空交集,并稱之為專家活動集合Ak(注意殖蚕,該集合可能是空的)轿衔。然后我們找到所有最小化類的全概率分布q:其中KL表示KL散度,pmpg表示專家模型或廣義全模型的概率分布嫌褪。分布pm是m的所有專業(yè)類加上一個垃圾箱類的分布呀枢,所以當(dāng)計(jì)算其與全q分布的KL散度時,我們求出全q分布分配給m垃圾箱中所有類的所有概率笼痛。
5.5結(jié)果
從訓(xùn)練有素的全網(wǎng)訓(xùn)練開始裙秋,專家們的訓(xùn)練速度極快(對于JFT來說,幾天而不是幾周)缨伊。而且摘刑,所有的專家都是完全獨(dú)立地接受培訓(xùn)的。表3顯示了基線系統(tǒng)和基線系統(tǒng)與專家模型相結(jié)合的絕對測試精度刻坊。有61個專業(yè)模型枷恕,總體測試準(zhǔn)確率相對提高了4.4%。我們還報告了條件測試的準(zhǔn)確性谭胚,即只考慮屬于專業(yè)類的示例徐块,并將我們的預(yù)測限制在該類的子集上。
在我們的JFT專家實(shí)驗(yàn)中灾而,我們培訓(xùn)了61名專家模型胡控,每個模型有300個班(加上垃圾箱班)。我們常常有一個特殊的專家組旁趟,因?yàn)槲覀儧]有一個特殊的類的專家組昼激。表4顯示了測試集示例的數(shù)量、使用專家時在位置1處正確的示例數(shù)量的變化,以及JFT數(shù)據(jù)集top1精確度的相對百分比改進(jìn)(按涵蓋該類的專家數(shù)量細(xì)分)橙困。當(dāng)我們有更多的專家覆蓋一個特定的班級時瞧掺,準(zhǔn)確性的提高會更大,這一趨勢令我們深受鼓舞凡傅,因?yàn)榕嘤?xùn)獨(dú)立的專家模型非常容易并行化辟狈。
6個軟目標(biāo)作為正則化器
關(guān)于使用軟目標(biāo)而不是硬目標(biāo),我們的一個主要主張是像捶,在軟目標(biāo)中可以攜帶許多有用的信息上陕,這些信息不可能用單個硬目標(biāo)進(jìn)行編碼桩砰。在本節(jié)中拓春,我們通過使用少得多的數(shù)據(jù)來擬合前面描述的基線語音模型的85M參數(shù)來證明這是一個非常大的影響。表5顯示亚隅,只有3%的數(shù)據(jù)(大約20M的例子)硼莽,用硬目標(biāo)訓(xùn)練基線模型會導(dǎo)致嚴(yán)重的過度擬合(我們提前停止,因?yàn)樵谶_(dá)到44.5%后煮纵,準(zhǔn)確度急劇下降)懂鸵,而用軟目標(biāo)訓(xùn)練的同一模型能夠恢復(fù)整個訓(xùn)練集中幾乎所有的信息(大約2%害羞)。更值得注意的是行疏,我們不必提前停止:具有軟目標(biāo)的系統(tǒng)簡單地“收斂”到57%匆光。這表明,軟目標(biāo)是一種非常有效的方法酿联,可以將一個模型根據(jù)所有數(shù)據(jù)訓(xùn)練出的規(guī)律性傳遞給另一個模型终息。
6.1使用軟靶防止專家過度擬合
我們在JFT數(shù)據(jù)集上的實(shí)驗(yàn)中使用的專家將他們所有的非專家類都分解成一個垃圾箱類。如果我們允許專家在所有課程中都有一個完整的softmax贞让,那么可能有一個比使用早期停止更好的方法來防止他們過度適應(yīng)周崭。一名專家接受的是在其特殊課程中高度豐富的數(shù)據(jù)方面的培訓(xùn)。這意味著它的訓(xùn)練集的有效規(guī)模要小得多喳张,而且它很容易過度適應(yīng)其特殊的課程续镇。這個問題不能通過使專家變得更小來解決,因?yàn)檫@樣我們就失去了從所有非專家類建模中獲得的非常有用的傳遞效果销部。
我們使用3%的語音數(shù)據(jù)進(jìn)行的實(shí)驗(yàn)強(qiáng)烈地表明摸航,如果一個專家以通才的權(quán)重初始化,那么除了用硬目標(biāo)訓(xùn)練專家外舅桩,我們還可以用非特殊類的軟目標(biāo)訓(xùn)練它酱虎,從而使它保留幾乎所有關(guān)于非特殊類的知識。軟目標(biāo)可以由通才提供江咳。我們目前正在探索這種方法逢净。
7與混合專家的關(guān)系
使用經(jīng)過數(shù)據(jù)子集培訓(xùn)的專家的使用與專家的混合[6]有些相似,后者使用門控網(wǎng)絡(luò)計(jì)算將每個示例分配給每個專家的概率。在專家學(xué)習(xí)處理分配給他們的示例的同時爹土,門控網(wǎng)絡(luò)也在學(xué)習(xí)根據(jù)專家對該示例的相對辨別性能來選擇將每個示例分配給哪些專家甥雕。利用專家的分辨能力來確定學(xué)習(xí)的任務(wù)比簡單地將輸入向量聚類并為每個簇分配一個專家要好得多,但這使得訓(xùn)練難以并行化:首先胀茵,每個專家的加權(quán)訓(xùn)練集不斷變化社露,依賴于所有其他專家;其次琼娘,選通網(wǎng)絡(luò)需要比較不同專家在同一個例子中的表現(xiàn)峭弟,從而知道如何修正其指派概率。這些困難意味著混合專家很少用于他們可能最有益的領(lǐng)域:包含大量數(shù)據(jù)集脱拼、包含明顯不同子集的任務(wù)
同時培養(yǎng)多個專家要容易得多瞒瘸。我們首先訓(xùn)練一個通才模型,然后使用混淆矩陣來定義訓(xùn)練專家的子集熄浓。一旦確定了這些子集情臭,專家就可以完全獨(dú)立地接受培訓(xùn)。在測試時赌蔑,我們可以使用通用模型的預(yù)測來決定哪些專家是相關(guān)的俯在,并且只需要運(yùn)行這些專家。
8討論
我們已經(jīng)證明娃惯,提取對于將知識從一個集合或從一個大的高度正則化的模型轉(zhuǎn)移到一個更小的跷乐、經(jīng)過提煉的模型非常有效。On-MNIST蒸餾非常有效趾浅,即使用于訓(xùn)練蒸餾模型的傳遞集缺少一個或多個類的任何示例愕提。對于Android語音搜索所使用的深度聲學(xué)模型,我們已經(jīng)表明潮孽,通過訓(xùn)練一組深層神經(jīng)網(wǎng)絡(luò)所實(shí)現(xiàn)的幾乎所有改進(jìn)都可以被提煉成一個大小相同的神經(jīng)網(wǎng)絡(luò)揪荣,這樣就更容易部署了。
對于真正大的神經(jīng)網(wǎng)絡(luò)往史,即使訓(xùn)練一個完整的集合也是不可行的仗颈,但是我們已經(jīng)證明,一個訓(xùn)練了很長時間的單一的非常大的網(wǎng)絡(luò)的性能可以通過學(xué)習(xí)大量的專家網(wǎng)絡(luò)得到顯著提高椎例,在一個高度易混淆的集群中挨决,每個類都學(xué)會了區(qū)分類。我們還沒有證明我們可以將專家的知識提煉回單一的大網(wǎng)絡(luò)中订歪。