本文為三星發(fā)表在 ECCV 2020 的基于二值網(wǎng)絡(luò)搜索的 NAS 工作(BATS)高诺,論文題目:BATS: Binary ArchitecTure Search。通過結(jié)合神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索翻擒,大大縮小了二值模型與實(shí)值之間的精度差距,并在CIFAR 和 ImageNet 數(shù)據(jù)集上的實(shí)驗(yàn)和分析證明了所提出的方法的有效性。
摘要
本文提出了二進(jìn)制架構(gòu)搜索(BATS),這是一個通過神經(jīng)架構(gòu)搜索(NAS)大幅縮小二進(jìn)制神經(jīng)網(wǎng)絡(luò)與其實(shí)值對應(yīng)的精度差距的框架君仆。實(shí)驗(yàn)表明,直接將NAS 應(yīng)用于二進(jìn)制領(lǐng)域的結(jié)果非常糟糕牲距。為了緩解這種情況返咱,本文描述了將 NAS 成功應(yīng)用于二進(jìn)制領(lǐng)域的 3 個關(guān)鍵要素:
- (1) 引入并設(shè)計了一個新的面向二進(jìn)制的搜索空間。
- (2) 提出了一個新的控制和穩(wěn)定搜索拓?fù)浣Y(jié)構(gòu)的機(jī)制嗅虏。
- (3) 提出并驗(yàn)證了一系列新的二進(jìn)制網(wǎng)絡(luò)搜索策略,以實(shí)現(xiàn)更快的收斂和更低的搜索時間上沐。
實(shí)驗(yàn)重新結(jié)果證明了所提出的方法的有效性和直接在二進(jìn)制空間中搜索的必要性皮服。并且,在CIFAR10参咙、CIFAR100 和 ImageNet 數(shù)據(jù)集上設(shè)計了 SOTA 的二元神經(jīng)網(wǎng)絡(luò)架構(gòu)龄广。
方法
搜索空間重定義
標(biāo)準(zhǔn) DARTS 搜索空間的問題
標(biāo)準(zhǔn) DARTS 搜索空間下搜索得到的網(wǎng)絡(luò)結(jié)構(gòu)二值化訓(xùn)練是無法收斂的,原因如下:
- 深度可分離卷積(SepConv)二值化難蕴侧。首先择同,實(shí)數(shù)深度可分離卷積本身就是普通標(biāo)準(zhǔn)卷積的“壓縮”版本,其次净宵,經(jīng)過二值化后進(jìn)一步對深度可分離卷積進(jìn)行了近似操作敲才。因此,深度可分離卷積二值化存在“雙重近似問題”择葡。
- 1x1卷積 & bottlneck 塊 二值化難紧武。因?yàn)?strong>關(guān)鍵的FeatureMap信息,由于1x1卷積權(quán)重值少和bottlneck所處的重要位置被二值化后無法有效傳遞下去敏储。
- DilConv & SepConv 二值化難阻星。標(biāo)準(zhǔn)的 DARTS 搜索空間定義的
DilConv
和SepConv
操作包含的卷積序列個數(shù)不同。DilConv
包含兩個卷積序列已添,SepConv
包含四個卷積序列妥箕。導(dǎo)致訓(xùn)練過程中兩者的收斂速度不同,并且會因此放大二值化過程中的梯度衰減現(xiàn)象(論文是這樣描述的更舞,不過具體原因不清楚)
二值神經(jīng)網(wǎng)絡(luò)搜索空間
二值神經(jīng)網(wǎng)絡(luò)搜索空間與標(biāo)準(zhǔn) DARTS 搜索空間對比如下圖所示:
主要存在以下幾方面的修改:
- 刪除了 1x1 卷積畦幢。
-
重新分配了深度可分離卷積中 Group Size 與 Channel 的關(guān)系。標(biāo)準(zhǔn)的深度可分離卷積中
#groups = #in_channel
缆蝉。本文中呛讲,CIFAR數(shù)據(jù)集上預(yù)定義Group卷積為12 Groups x 3 Channels = 36 Channels
禾怠;ImageNet 數(shù)據(jù)集上預(yù)定義Group卷積為16 Groups x 5 Channels = 80 Channels
。 - 每個opetation只包含一個卷積序列贝搁,便于學(xué)習(xí)和實(shí)現(xiàn)低延時吗氏。
-
每個卷積操作都加上
Skip-Connect
操作,有利于保持FeatureMap信息的傳遞和保留雷逆。
搜索的正則化和穩(wěn)定性
DARTS 搜索的不穩(wěn)定分析
盡管 DARTS 取得了成功弦讽,但根據(jù)隨機(jī)種子的不同,DARTS 的精度在運(yùn)行之間可能會有很大的差異膀哲。事實(shí)上往产,在有些情況下,隨機(jī)搜索獲得的架構(gòu)甚至比搜索得到的架構(gòu)通過表現(xiàn)的更好某宪。此外仿村,特別是當(dāng)訓(xùn)練時間較長或在較大的數(shù)據(jù)集上進(jìn)行搜索時,DARTS可能會出現(xiàn) Skip-Connect
富集的問題兴喂。常用的解決方法包括:
- 在架構(gòu)搜索過程中對跳連應(yīng)用dropout
- 通過保留每個單元最多2個跳連作為后處理步驟蔼囊,簡單地促成概率第二高的操作
但是,這種機(jī)制仍然會導(dǎo)致大量的隨機(jī)性衣迷,而且并不總是有效的:例如畏鼓,它可能會用池化層(沒有學(xué)習(xí)能力)取代跳過連接,或者搜索的架構(gòu)跳連包含的太少壶谒。當(dāng)搜索是在二進(jìn)制域中直截了當(dāng)?shù)剡M(jìn)行時云矫,這樣的問題就更加明顯了。鑒于在搜索過程中汗菜,節(jié)點(diǎn) j 的輸入是通過對所有輸入邊的加權(quán)和來獲得的让禀,為了最大限度地提高信息流,架構(gòu)參數(shù) α 傾向于收斂到相同的值陨界,使得最終架構(gòu)的選擇存在問題堆缘,并且容易受到噪聲的影響,導(dǎo)致拓?fù)浣Y(jié)構(gòu)的性能可能比隨機(jī)選擇更差普碎。此外吼肥,搜索高度偏向于實(shí)值操作(池化和跳連),使得搜索在早期階段可以提供更大的收益麻车。
溫度正則(temperature regularization)
為了緩解上述問題缀皱,并使得搜索程序更具辨別力,迫使其做出 "harder" 的決策动猬,本文借鑒知識蒸餾的思路啤斗,建議使用溫度因子 T<1 的正則策略,定義從節(jié)點(diǎn)i到j(luò)的流程如下公式所示:
采用溫度正則方法可以使架構(gòu)參數(shù)的分布不那么均勻赁咙,更加尖銳(即更有辨別力)钮莲。在搜索過程中免钻,由于信息流是使用加權(quán)進(jìn)行聚合的,所以網(wǎng)絡(luò)不能從所有信息流中提取信息崔拥,來平等地(或接近平等地)依賴所有可能的操作极舔。相反,為了確保收斂到一個滿意的解決方案链瓦,它必須將最高的概率分配給一個非0操作的路徑拆魏,由一個次元溫度(T <1)強(qiáng)制執(zhí)行。這種行為也更接近評估過程慈俯,從而減少搜索(網(wǎng)絡(luò)從所有路徑中提取信息)和評估之間的性能差異渤刃。
上圖中,圖1描述了給定單元在不同溫度下的架構(gòu)參數(shù)分布贴膘。對于低溫(T=0.2)卖子,網(wǎng)絡(luò)被迫做出更多的判別性決策,這反過來又使它減少了對 Skip-Connect 的依賴刑峡。圖2 進(jìn)一步證實(shí)了這一點(diǎn)洋闽,它描述了在不同溫度下搜索過程結(jié)束時,在 Normal Cell 中各操作被選擇的概率氛琢。
二值搜索策略
盡管二值網(wǎng)絡(luò)具有加速和節(jié)省空間的特點(diǎn)喊递,但與實(shí)值網(wǎng)絡(luò)相比随闪,二值網(wǎng)絡(luò)的訓(xùn)練仍然比較困難阳似,其方法通常需要一個預(yù)訓(xùn)練階段或仔細(xì)調(diào)整超參數(shù)和優(yōu)化器。對于搜索二值網(wǎng)絡(luò)的情況铐伴,直接實(shí)現(xiàn)二值權(quán)重和激活的架構(gòu)搜索撮奏,在大多數(shù)嘗試中,要么導(dǎo)致退化的拓?fù)浣Y(jié)構(gòu)当宴,要么訓(xùn)練簡單地收斂到極低的精度值畜吊。此外,直接在實(shí)域中執(zhí)行搜索户矢,然后對網(wǎng)絡(luò)進(jìn)行二值化是次優(yōu)的玲献。
為了緩解這個問題,本文提出了一個兩階段的優(yōu)化過程梯浪,在搜索過程中捌年,激活是二值化的,而權(quán)重是實(shí)值化的挂洛,一旦發(fā)現(xiàn)了最佳架構(gòu)礼预,我們在評估階段也要對權(quán)重進(jìn)行二值化÷簿ⅲ【更具體地說托酸,在評估過程中褒颈,首先從頭開始訓(xùn)練一個具有二值激活和實(shí)值權(quán)重的新網(wǎng)絡(luò),然后對權(quán)重進(jìn)行二值化励堡。最后谷丸,在測試集上對完全二值化的網(wǎng)絡(luò)進(jìn)行評估∧钛恚】這是因?yàn)閷?shí)值網(wǎng)絡(luò)的權(quán)重通秤倬可以被二值化,而不會顯著降低精度摊趾,但激活的二值化就不一樣了币狠,由于可能的狀態(tài)數(shù)量有限,網(wǎng)絡(luò)內(nèi)部的信息流急劇下降砾层。因此漩绵,本文提出將問題有效地分成兩個子問題:權(quán)重和特征二值化,在搜索過程中肛炮,嘗試解決最難的一個問題止吐,即激活的二值化。一旦完成了這一點(diǎn)侨糟,權(quán)重的二值化以下總是會導(dǎo)致精度的小幅下降(~1%)碍扔。