Batch Normalization 學(xué)習(xí)筆記
原文地址:http://blog.csdn.net/hjimce/article/details/50866313
一侯养、背景意義
本篇博文主要講解2015年深度學(xué)習(xí)領(lǐng)域剩檀,非常值得學(xué)習(xí)的一篇文獻(xiàn):《Batch Normalization: Accelerating Deep Network Training by ?Reducing Internal Covariate Shift》贴谎,這個(gè)算法目前已經(jīng)被大量的應(yīng)用汞扎,最新的文獻(xiàn)算法很多都會(huì)引用這個(gè)算法,進(jìn)行網(wǎng)絡(luò)訓(xùn)練擅这,可見(jiàn)其強(qiáng)大之處非同一般啊澈魄。
近年來(lái)深度學(xué)習(xí)捷報(bào)連連、聲名鵲起仲翎,隨機(jī)梯度下架成了訓(xùn)練深度網(wǎng)絡(luò)的主流方法痹扇。盡管隨機(jī)梯度下降法對(duì)于訓(xùn)練深度網(wǎng)絡(luò)簡(jiǎn)單高效,但是它有個(gè)毛病谭确,就是需要我們?nèi)藶榈娜ミx擇參數(shù)帘营,比如學(xué)習(xí)率、參數(shù)初始化逐哈、權(quán)重衰減系數(shù)芬迄、Drop out比例等。這些參數(shù)的選擇對(duì)訓(xùn)練結(jié)果至關(guān)重要昂秃,以至于我們很多時(shí)間都浪費(fèi)在這些的調(diào)參上禀梳。那么學(xué)完這篇文獻(xiàn)之后,你可以不需要那么刻意的慢慢調(diào)整參數(shù)肠骆。BN算法(Batch Normalization)其強(qiáng)大之處如下:
(1)你可以選擇比較大的初始學(xué)習(xí)率算途,讓你的訓(xùn)練速度飆漲。以前還需要慢慢調(diào)整學(xué)習(xí)率蚀腿,甚至在網(wǎng)絡(luò)訓(xùn)練到一半的時(shí)候嘴瓤,還需要想著學(xué)習(xí)率進(jìn)一步調(diào)小的比例選擇多少比較合適,現(xiàn)在我們可以采用初始很大的學(xué)習(xí)率莉钙,然后學(xué)習(xí)率的衰減速度也很大廓脆,因?yàn)檫@個(gè)算法收斂很快。當(dāng)然這個(gè)算法即使你選擇了較小的學(xué)習(xí)率磁玉,也比以前的收斂速度快停忿,因?yàn)樗哂锌焖儆?xùn)練收斂的特性;
(2)你再也不用去理會(huì)過(guò)擬合中drop out蚊伞、L2正則項(xiàng)參數(shù)的選擇問(wèn)題席赂,采用BN算法后,你可以移除這兩項(xiàng)了參數(shù)时迫,或者可以選擇更小的L2正則約束參數(shù)了颅停,因?yàn)锽N具有提高網(wǎng)絡(luò)泛化能力的特性;
(3)再也不需要使用使用局部響應(yīng)歸一化層了(局部響應(yīng)歸一化是Alexnet網(wǎng)絡(luò)用到的方法掠拳,搞視覺(jué)的估計(jì)比較熟悉)癞揉,因?yàn)锽N本身就是一個(gè)歸一化網(wǎng)絡(luò)層;
(4)可以把訓(xùn)練數(shù)據(jù)徹底打亂(防止每批訓(xùn)練的時(shí)候,某一個(gè)樣本都經(jīng)常被挑選到烧董,文獻(xiàn)說(shuō)這個(gè)可以提高1%的精度,這句話(huà)我也是百思不得其解半时肌)逊移。
開(kāi)始講解算法前,先來(lái)思考一個(gè)問(wèn)題:我們知道在神經(jīng)網(wǎng)絡(luò)訓(xùn)練開(kāi)始前龙填,都要對(duì)輸入數(shù)據(jù)做一個(gè)歸一化處理胳泉,那么具體為什么需要?dú)w一化呢?歸一化后有什么好處呢岩遗?原因在于神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)過(guò)程本質(zhì)就是為了學(xué)習(xí)數(shù)據(jù)分布扇商,一旦訓(xùn)練數(shù)據(jù)與測(cè)試數(shù)據(jù)的分布不同,那么網(wǎng)絡(luò)的泛化能力也大大降低宿礁;另外一方面案铺,一旦每批訓(xùn)練數(shù)據(jù)的分布各不相同(batch 梯度下降),那么網(wǎng)絡(luò)就要在每次迭代都去學(xué)習(xí)適應(yīng)不同的分布梆靖,這樣將會(huì)大大降低網(wǎng)絡(luò)的訓(xùn)練速度控汉,這也正是為什么我們需要對(duì)數(shù)據(jù)都要做一個(gè)歸一化預(yù)處理的原因。
對(duì)于深度網(wǎng)絡(luò)的訓(xùn)練是一個(gè)復(fù)雜的過(guò)程返吻,只要網(wǎng)絡(luò)的前面幾層發(fā)生微小的改變姑子,那么后面幾層就會(huì)被累積放大下去。一旦網(wǎng)絡(luò)某一層的輸入數(shù)據(jù)的分布發(fā)生改變测僵,那么這一層網(wǎng)絡(luò)就需要去適應(yīng)學(xué)習(xí)這個(gè)新的數(shù)據(jù)分布街佑,所以如果訓(xùn)練過(guò)程中,訓(xùn)練數(shù)據(jù)的分布一直在發(fā)生變化捍靠,那么將會(huì)影響網(wǎng)絡(luò)的訓(xùn)練速度沐旨。
我們知道網(wǎng)絡(luò)一旦train起來(lái),那么參數(shù)就要發(fā)生更新剂公,除了輸入層的數(shù)據(jù)外(因?yàn)檩斎雽訑?shù)據(jù)希俩,我們已經(jīng)人為的為每個(gè)樣本歸一化),后面網(wǎng)絡(luò)每一層的輸入數(shù)據(jù)分布是一直在發(fā)生變化的纲辽,因?yàn)樵谟?xùn)練的時(shí)候颜武,前面層訓(xùn)練參數(shù)的更新將導(dǎo)致后面層輸入數(shù)據(jù)分布的變化。以網(wǎng)絡(luò)第二層為例:網(wǎng)絡(luò)的第二層輸入拖吼,是由第一層的參數(shù)和input計(jì)算得到的鳞上,而第一層的參數(shù)在整個(gè)訓(xùn)練過(guò)程中一直在變化,因此必然會(huì)引起后面每一層輸入數(shù)據(jù)分布的改變吊档。我們把網(wǎng)絡(luò)中間層在訓(xùn)練過(guò)程中篙议,數(shù)據(jù)分布的改變稱(chēng)之為:“Internal ?Covariate?Shift”。Paper所提出的算法,就是要解決在訓(xùn)練過(guò)程中鬼贱,中間層數(shù)據(jù)分布發(fā)生改變的情況移怯,于是就有了Batch??Normalization,這個(gè)牛逼算法的誕生这难。
二舟误、初識(shí)BN(Batch??Normalization)
1、BN概述
就像激活函數(shù)層姻乓、卷積層嵌溢、全連接層、池化層一樣蹋岩,BN(Batch Normalization)也屬于網(wǎng)絡(luò)的一層赖草。在前面我們提到網(wǎng)絡(luò)除了輸出層外,其它層因?yàn)榈蛯泳W(wǎng)絡(luò)在訓(xùn)練的時(shí)候更新了參數(shù)剪个,而引起后面層輸入數(shù)據(jù)分布的變化秧骑。這個(gè)時(shí)候我們可能就會(huì)想,如果在每一層輸入的時(shí)候扣囊,再加個(gè)預(yù)處理操作那該有多好啊腿堤,比如網(wǎng)絡(luò)第三層輸入數(shù)據(jù)X3(X3表示網(wǎng)絡(luò)第三層的輸入數(shù)據(jù))把它歸一化至:均值0、方差為1如暖,然后再輸入第三層計(jì)算笆檀,這樣我們就可以解決前面所提到的“Internal?Covariate?Shift”的問(wèn)題了。
而事實(shí)上盒至,paper的算法本質(zhì)原理就是這樣:在網(wǎng)絡(luò)的每一層輸入的時(shí)候酗洒,又插入了一個(gè)歸一化層,也就是先做一個(gè)歸一化處理枷遂,然后再進(jìn)入網(wǎng)絡(luò)的下一層樱衷。不過(guò)文獻(xiàn)歸一化層,可不像我們想象的那么簡(jiǎn)單酒唉,它是一個(gè)可學(xué)習(xí)矩桂、有參數(shù)的網(wǎng)絡(luò)層。既然說(shuō)到數(shù)據(jù)預(yù)處理痪伦,下面就先來(lái)復(fù)習(xí)一下最強(qiáng)的預(yù)處理方法:白化侄榴。
2、預(yù)處理操作選擇
說(shuō)到神經(jīng)網(wǎng)絡(luò)輸入數(shù)據(jù)預(yù)處理网沾,最好的算法莫過(guò)于白化預(yù)處理癞蚕。然而白化計(jì)算量太大了,很不劃算辉哥,還有就是白化不是處處可微的桦山,所以在深度學(xué)習(xí)中攒射,其實(shí)很少用到白化。經(jīng)過(guò)白化預(yù)處理后恒水,數(shù)據(jù)滿(mǎn)足條件:a会放、特征之間的相關(guān)性降低,這個(gè)就相當(dāng)于pca钉凌;b鸦概、數(shù)據(jù)均值、標(biāo)準(zhǔn)差歸一化甩骏,也就是使得每一維特征均值為0,標(biāo)準(zhǔn)差為1先慷。如果數(shù)據(jù)特征維數(shù)比較大饮笛,要進(jìn)行PCA,也就是實(shí)現(xiàn)白化的第1個(gè)要求论熙,是需要計(jì)算特征向量福青,計(jì)算量非常大,于是為了簡(jiǎn)化計(jì)算脓诡,作者忽略了第1個(gè)要求无午,僅僅使用了下面的公式進(jìn)行預(yù)處理,也就是近似白化預(yù)處理:
公式簡(jiǎn)單粗糙祝谚,但是依舊很牛逼宪迟。因此后面我們也將用這個(gè)公式,對(duì)某一個(gè)層網(wǎng)絡(luò)的輸入數(shù)據(jù)做一個(gè)歸一化處理交惯。需要注意的是次泽,我們訓(xùn)練過(guò)程中采用batch 隨機(jī)梯度下降,上面的E(xk)指的是每一批訓(xùn)練數(shù)據(jù)神經(jīng)元xk的平均值席爽;然后分母就是每一批數(shù)據(jù)神經(jīng)元xk激活度的一個(gè)標(biāo)準(zhǔn)差了意荤。
三、BN算法實(shí)現(xiàn)
1只锻、BN算法概述
經(jīng)過(guò)前面簡(jiǎn)單介紹玖像,這個(gè)時(shí)候可能我們會(huì)想當(dāng)然的以為:好像很簡(jiǎn)單的樣子,不就是在網(wǎng)絡(luò)中間層數(shù)據(jù)做一個(gè)歸一化處理嘛齐饮,這么簡(jiǎn)單的想法捐寥,為什么之前沒(méi)人用呢?然而其實(shí)實(shí)現(xiàn)起來(lái)并不是那么簡(jiǎn)單的祖驱。其實(shí)如果是僅僅使用上面的歸一化公式上真,對(duì)網(wǎng)絡(luò)某一層A的輸出數(shù)據(jù)做歸一化,然后送入網(wǎng)絡(luò)下一層B羹膳,這樣是會(huì)影響到本層網(wǎng)絡(luò)A所學(xué)習(xí)到的特征的睡互。打個(gè)比方,比如我網(wǎng)絡(luò)中間某一層學(xué)習(xí)到特征數(shù)據(jù)本身就分布在S型激活函數(shù)的兩側(cè),你強(qiáng)制把它給我歸一化處理就珠、標(biāo)準(zhǔn)差也限制在了1寇壳,把數(shù)據(jù)變換成分布于s函數(shù)的中間部分,這樣就相當(dāng)于我這一層網(wǎng)絡(luò)所學(xué)習(xí)到的特征分布被你搞壞了妻怎,這可怎么辦壳炎?于是文獻(xiàn)使出了一招驚天地泣鬼神的招式:變換重構(gòu),引入了可學(xué)習(xí)參數(shù)γ逼侦、β匿辩,這就是算法關(guān)鍵之處:
每一個(gè)神經(jīng)元xk都會(huì)有一對(duì)這樣的參數(shù)γ、β榛丢。這樣其實(shí)當(dāng):
铲球、
是可以恢復(fù)出原始的某一層所學(xué)到的特征的。因此我們引入了這個(gè)可學(xué)習(xí)重構(gòu)參數(shù)γ晰赞、β稼病,讓我們的網(wǎng)絡(luò)可以學(xué)習(xí)恢復(fù)出原始網(wǎng)絡(luò)所要學(xué)習(xí)的特征分布。最后Batch?Normalization網(wǎng)絡(luò)層的前向傳導(dǎo)過(guò)程公式就是:
上面的公式中m指的是mini-batch?size掖鱼。
2然走、源碼實(shí)現(xiàn)
? ? ? ? ? ? m = K.mean(X, axis=-1, keepdims=True)#計(jì)算均值
? ? ? ? ? ? std = K.std(X, axis=-1, keepdims=True)#計(jì)算標(biāo)準(zhǔn)差
? ? ? ? ? ? X_normed = (X - m) / (std + self.epsilon)#歸一化
? ? ? ? ? ? out = self.gamma * X_normed + self.beta#重構(gòu)變換
上面的x是一個(gè)二維矩陣,對(duì)于源碼的實(shí)現(xiàn)就幾行代碼而已戏挡,輕輕松松芍瑞。
3、實(shí)戰(zhàn)使用
(1)可能學(xué)完了上面的算法褐墅,你只是知道它的一個(gè)訓(xùn)練過(guò)程啄巧,一個(gè)網(wǎng)絡(luò)一旦訓(xùn)練完了,就沒(méi)有了min-batch這個(gè)概念了掌栅。測(cè)試階段我們一般只輸入一個(gè)測(cè)試樣本秩仆,看看結(jié)果而已。因此測(cè)試樣本猾封,前向傳導(dǎo)的時(shí)候澄耍,上面的均值u、標(biāo)準(zhǔn)差σ?要哪里來(lái)晌缘?其實(shí)網(wǎng)絡(luò)一旦訓(xùn)練完畢齐莲,參數(shù)都是固定的,這個(gè)時(shí)候即使是每批訓(xùn)練樣本進(jìn)入網(wǎng)絡(luò)磷箕,那么BN層計(jì)算的均值u选酗、和標(biāo)準(zhǔn)差都是固定不變的。我們可以采用這些數(shù)值來(lái)作為測(cè)試樣本所需要的均值岳枷、標(biāo)準(zhǔn)差芒填,于是最后測(cè)試階段的u和σ 計(jì)算公式如下:
上面簡(jiǎn)單理解就是:對(duì)于均值來(lái)說(shuō)直接計(jì)算所有batch u值的平均值呜叫;然后對(duì)于標(biāo)準(zhǔn)偏差采用每個(gè)batch?σB的無(wú)偏估計(jì)。最后測(cè)試階段殿衰,BN的使用公式就是:
(2)根據(jù)文獻(xiàn)說(shuō)朱庆,BN可以應(yīng)用于一個(gè)神經(jīng)網(wǎng)絡(luò)的任何神經(jīng)元上。文獻(xiàn)主要是把BN變換闷祥,置于網(wǎng)絡(luò)激活函數(shù)層的前面娱颊。在沒(méi)有采用BN的時(shí)候,激活函數(shù)層是這樣的:
z=g(Wu+b)
也就是我們希望一個(gè)激活函數(shù)凯砍,比如s型函數(shù)s(x)的自變量x是經(jīng)過(guò)BN處理后的結(jié)果箱硕。因此前向傳導(dǎo)的計(jì)算公式就應(yīng)該是:
z=g(BN(Wu+b))
其實(shí)因?yàn)槠脜?shù)b經(jīng)過(guò)BN層后其實(shí)是沒(méi)有用的,最后也會(huì)被均值歸一化悟衩,當(dāng)然BN層后面還有個(gè)β參數(shù)作為偏置項(xiàng)剧罩,所以b這個(gè)參數(shù)就可以不用了。因此最后把BN層+激活函數(shù)層就變成了:
z=g(BN(Wu))
四局待、Batch Normalization在CNN中的使用
通過(guò)上面的學(xué)習(xí),我們知道BN層是對(duì)于每個(gè)神經(jīng)元做歸一化處理菱属,甚至只需要對(duì)某一個(gè)神經(jīng)元進(jìn)行歸一化钳榨,而不是對(duì)一整層網(wǎng)絡(luò)的神經(jīng)元進(jìn)行歸一化。既然BN是對(duì)單個(gè)神經(jīng)元的運(yùn)算纽门,那么在CNN中卷積層上要怎么搞薛耻?假如某一層卷積層有6個(gè)特征圖,每個(gè)特征圖的大小是100*100赏陵,這樣就相當(dāng)于這一層網(wǎng)絡(luò)有6*100*100個(gè)神經(jīng)元饼齿,如果采用BN,就會(huì)有6*100*100個(gè)參數(shù)γ蝙搔、β缕溉,這樣豈不是太恐怖了。因此卷積層上的BN使用吃型,其實(shí)也是使用了類(lèi)似權(quán)值共享的策略证鸥,把一整張?zhí)卣鲌D當(dāng)做一個(gè)神經(jīng)元進(jìn)行處理。
卷積神經(jīng)網(wǎng)絡(luò)經(jīng)過(guò)卷積后得到的是一系列的特征圖勤晚,如果min-batch?sizes為m枉层,那么網(wǎng)絡(luò)某一層輸入數(shù)據(jù)可以表示為四維矩陣(m,f,p,q),m為min-batch?sizes赐写,f為特征圖個(gè)數(shù)鸟蜡,p、q分別為特征圖的寬高挺邀。在cnn中我們可以把每個(gè)特征圖看成是一個(gè)特征處理(一個(gè)神經(jīng)元)揉忘,因此在使用Batch?Normalization跳座,mini-batch?size 的大小就是:m*p*q,于是對(duì)于每個(gè)特征圖都只有一對(duì)可學(xué)習(xí)參數(shù):γ癌淮、β躺坟。說(shuō)白了吧,這就是相當(dāng)于求取所有樣本所對(duì)應(yīng)的一個(gè)特征圖的所有神經(jīng)元的平均值乳蓄、方差咪橙,然后對(duì)這個(gè)特征圖神經(jīng)元做歸一化。下面是來(lái)自于keras卷積層的BN實(shí)現(xiàn)一小段主要源碼:
? ? ? ? ? input_shape = self.input_shape
? ? ? ? ? ? reduction_axes = list(range(len(input_shape)))
? ? ? ? ? ? del reduction_axes[self.axis]
? ? ? ? ? ? broadcast_shape = [1] * len(input_shape)
? ? ? ? ? ? broadcast_shape[self.axis] = input_shape[self.axis]
? ? ? ? ? ? if train:
? ? ? ? ? ? ? ? m = K.mean(X, axis=reduction_axes)
? ? ? ? ? ? ? ? brodcast_m = K.reshape(m, broadcast_shape)
? ? ? ? ? ? ? ? std = K.mean(K.square(X - brodcast_m) + self.epsilon, axis=reduction_axes)
? ? ? ? ? ? ? ? std = K.sqrt(std)
? ? ? ? ? ? ? ? brodcast_std = K.reshape(std, broadcast_shape)
? ? ? ? ? ? ? ? mean_update = self.momentum * self.running_mean + (1-self.momentum) * m
? ? ? ? ? ? ? ? std_update = self.momentum * self.running_std + (1-self.momentum) * std
? ? ? ? ? ? ? ? self.updates = [(self.running_mean, mean_update),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (self.running_std, std_update)]
? ? ? ? ? ? ? ? X_normed = (X - brodcast_m) / (brodcast_std + self.epsilon)
? ? ? ? ? ? else:
? ? ? ? ? ? ? ? brodcast_m = K.reshape(self.running_mean, broadcast_shape)
? ? ? ? ? ? ? ? brodcast_std = K.reshape(self.running_std, broadcast_shape)
? ? ? ? ? ? ? ? X_normed = ((X - brodcast_m) /
? ? ? ? ? ? ? ? ? ? ? ? ? ? (brodcast_std + self.epsilon))
? ? ? ? ? ? out = K.reshape(self.gamma, broadcast_shape) * X_normed + K.reshape(self.beta, broadcast_shape)
個(gè)人總結(jié):2015年個(gè)人最喜歡深度學(xué)習(xí)的一篇paper就是Batch Normalization這篇文獻(xiàn)虚倒,采用這個(gè)方法網(wǎng)絡(luò)的訓(xùn)練速度快到驚人啊美侦,感覺(jué)訓(xùn)練速度是以前的十倍以上,再也不用擔(dān)心自己這破電腦每次運(yùn)行一下魂奥,訓(xùn)練一下都要跑個(gè)兩三天的時(shí)間菠剩。另外這篇文獻(xiàn)跟空間變換網(wǎng)絡(luò)《Spatial Transformer Networks》的思想神似啊,都是一個(gè)變換網(wǎng)絡(luò)層耻煤。
參考文獻(xiàn):
1具壮、《Batch Normalization: Accelerating Deep Network Training by ?Reducing Internal Covariate Shift》
2、《Spatial Transformer Networks》
3哈蝇、https://github.com/fchollet/keras