背景
神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索之前主流的方法主要包括:強(qiáng)化學(xué)習(xí),進(jìn)化學(xué)習(xí)抠蚣。他們的搜索空間都是不可微的连躏,Differentiable Architecture Search 這篇文章提出了一種可微的方法,可以用梯度下降來解決架構(gòu)搜索的問題坎缭,所以在搜索效率上比之前不可微的方法快幾個數(shù)量級白群∩邪可以這樣通俗的理解:之前不可微的方法,相當(dāng)于是你定義了一個搜索空間(比如3x3和5x5的卷積核)帜慢,然后神經(jīng)網(wǎng)絡(luò)的每一層你可以從搜索空間中選一種構(gòu)成一個神經(jīng)網(wǎng)絡(luò)笼裳,跑一下這個神經(jīng)網(wǎng)絡(luò)的訓(xùn)練結(jié)果,然后不斷測試其他的神經(jīng)網(wǎng)絡(luò)組合粱玲。這種方法躬柬,本質(zhì)上是從很多的組合當(dāng)中盡快的搜索到效果很好的一種,但是這個過程是黑盒密幔,需要有大量的驗證過程楔脯,所以會很耗時。而這篇文章把架構(gòu)搜索融合到模型當(dāng)中一起訓(xùn)練胯甩。
算法核心思想
由上圖可分析:
(a) 定義了一個cell單元昧廷,可看成有向無環(huán)圖,里面4個node偎箫,node之間的edge代表可能的操作(如:3x3 sep 卷積)木柬,初始化時unknown。
(b) 把搜索空間連續(xù)松弛化淹办,每個edge看成是所有子操作的混合(softmax權(quán)值疊加)眉枕。
(c) 聯(lián)合優(yōu)化,更新子操作混合概率上的edge超參(即架構(gòu)搜索任務(wù))和 架構(gòu)無關(guān)的網(wǎng)絡(luò)參數(shù)怜森。
(d) 優(yōu)化完畢后速挑,inference 直接取概率最大的子操作即可。
搜索空間
DARTS要做的事情副硅,是訓(xùn)練出來兩個Cell(Norm-Cell和Reduce-Cell)姥宝,然后把Cell相連構(gòu)成一個大網(wǎng)絡(luò),而超參數(shù)layers可以控制有多少個cell相連恐疲,例如layers = 20表示有20個cell前后相連腊满。
- Norm-Cell: [輸入與輸出的FeatureMap尺寸保持一致]
- Reduce-Cell: [輸出的FeatureMap尺寸減小一半]
Cell的組成
Cell由輸入節(jié)點(diǎn)套么,中間節(jié)點(diǎn),輸出節(jié)點(diǎn)碳蛋,邊四部分構(gòu)成胚泌,我們規(guī)定每一個cell有兩個輸入節(jié)點(diǎn)和一個輸出節(jié)點(diǎn),Norm-Cell和Reduce-Cell的結(jié)構(gòu)相同肃弟,不過操作不同玷室。
輸入節(jié)點(diǎn):對于卷積網(wǎng)絡(luò)來說,兩個輸入節(jié)點(diǎn)分別是前兩層(layers)cell的輸出愕乎,對于循環(huán)網(wǎng)絡(luò)(Recurrent)來說阵苇,輸入時當(dāng)前層的輸入和前一層的狀態(tài)。
中間節(jié)點(diǎn):每一個中間節(jié)點(diǎn)都由它的前繼通過邊再求和得來感论。
輸出節(jié)點(diǎn):由每一個中間節(jié)點(diǎn)concat起來。
邊:邊代表的是operation(比如33的卷積)紊册,在收斂得到結(jié)構(gòu)的過程中比肄,兩兩節(jié)點(diǎn)中間所有的邊(DARTS預(yù)定義了8中不同的操作*)都會存在并參與訓(xùn)練,最后加權(quán)平均囊陡,這個權(quán)就是我們要訓(xùn)練的東西芳绩,我們希望得到的結(jié)果是效果最好的邊它的權(quán)重最大。
DARTS實(shí)際預(yù)定義的Cell結(jié)構(gòu)與論文中示意圖的表示略有不同撞反,完整的Cell結(jié)構(gòu)包含兩個輸入節(jié)點(diǎn)妥色,四個中間節(jié)點(diǎn)和一個輸出節(jié)點(diǎn),如下圖所示:
全連接的情況下遏片,N0中間節(jié)點(diǎn)有兩個前繼節(jié)點(diǎn)嘹害;N1,N2,N3分別有3吮便,4笔呀,5個前繼節(jié)點(diǎn)。每個節(jié)點(diǎn)之間有對應(yīng)8個不同的預(yù)定義操作髓需,共同構(gòu)成一組邊许师。
首先我們定義如下公式用softmax歸一化alpha處理一組邊:
通過公式可知每個操作對應(yīng)一個權(quán)值(即alpha),這就是我們要訓(xùn)練的參數(shù)僚匆,我們把這些alpha稱作一個權(quán)值矩陣微渠,alpha值越大代表的操作在這組邊中越重要。
然后每組中間節(jié)點(diǎn)公式表示如下咧擂,即所有前繼節(jié)點(diǎn)累加作為當(dāng)前節(jié)點(diǎn)的輸出:
我們收斂到最后希望得到一個權(quán)值矩陣逞盆,這個矩陣當(dāng)中權(quán)值越大的邊,留下來之后效果越好屋确。
優(yōu)化策略
通過前面定義的搜索空間纳击,我們的目的是通過梯度下降優(yōu)化alpha矩陣续扔。我們把神經(jīng)網(wǎng)絡(luò)原有的權(quán)重稱為W矩陣。為了實(shí)現(xiàn)端到端的優(yōu)化焕数,我們希望同時優(yōu)化兩個矩陣使得結(jié)果變好纱昧。上述兩層優(yōu)化是有嚴(yán)格層次的,為了使兩者都能同時達(dá)到優(yōu)化的策略堡赔,一個樸素的想法是:在訓(xùn)練集上固定alpha矩陣的值识脆,然后梯度下降W矩陣的值,在驗證集上固定W矩陣的值善已,然后梯度下降alpha的值灼捂,循環(huán)往復(fù)直到這兩個值都比較理想。這個過程有點(diǎn)像k-means的過程换团,先定了中心悉稠,再求均值,再換中心艘包,再求均值的猛。需要注意的是驗證集和訓(xùn)練集的劃分比例是1:1的,因為對于alpha矩陣來說想虎,驗證集就是它的訓(xùn)練集卦尊。
但是這個方法雖然可以工作,但是效果不是很好舌厨,由于這種雙優(yōu)化的問題很難求得精確解(因為需要反復(fù)迭代求解兩個參數(shù))岂却,所以采用一種近似的迭代優(yōu)化步驟來交替更新兩個參數(shù),算法如下:
具體的公式推導(dǎo)流程可參考(DARTS公式推導(dǎo) https://zhuanlan.zhihu.com/p/73037439)
生成最終Cell結(jié)構(gòu)
根據(jù)前面所述裙椭,我們要訓(xùn)練出來一個alpha矩陣躏哩,使得權(quán)重大的邊保留下來,所以在這個結(jié)構(gòu)收斂了之后還需要做一個生成最終Cell的過程骇陈。那這個時候你可能會問震庭,為什么不把之前的結(jié)構(gòu)直接用上呢?因為邊太多你雌,結(jié)構(gòu)太復(fù)雜器联,參數(shù)太多不好訓(xùn)練,所以作者希望能生成一個更簡單的網(wǎng)絡(luò)結(jié)構(gòu)婿崭,接下來我們說生成的方法拨拓。
對于每一個中間節(jié)點(diǎn)來說,我們最多保留兩個最強(qiáng)壯的前繼氓栈;對于兩兩節(jié)點(diǎn)之間的邊渣磷,我們只保留權(quán)重最大的一條邊,我們定義一下什么是最強(qiáng)壯的前繼授瘦。假設(shè)一個節(jié)點(diǎn)有三個前繼醋界,那我們選哪兩個呢竟宋?把前繼和當(dāng)前節(jié)點(diǎn)之間權(quán)重最高的那條邊代表那個前繼的強(qiáng)壯程度,我們選最強(qiáng)壯的兩個前繼形纺。節(jié)點(diǎn)之間只保留權(quán)重最大的那條邊丘侠。
normal cell search
reduce cell search
網(wǎng)絡(luò)結(jié)構(gòu)堆疊
下圖,展示了Normal-Cell與Reduce-Cell的連接方式逐样,代碼描述是在1/3處和2/3處添加兩個Reduce-Cell蜗字。比如,在CIFAR-10數(shù)據(jù)集上的網(wǎng)絡(luò)結(jié)構(gòu)需要20個Cell串聯(lián)脂新。NetWork=6*Normal-Cell+Reduce-Cell+6*Normal-Cell+Reduce-Cell+6*Normal-Cell
由于挪捕,Cell結(jié)構(gòu)是兩個輸入的,因此詳細(xì)的Cell連接方式如下所示:
結(jié)果
CIFAR-10
ImageNet
參考
Liu, H., Simonyan, K., & Yang, Y. (2019). DARTS: Differentiable Architecture Search. ArXiv, abs/1806.09055.