這篇文章適合對Keras和深度學(xué)習(xí)有一定基礎(chǔ)的讀者
BatchNormalization 是我們在訓(xùn)練深度神經(jīng)網(wǎng)絡(luò)的時候常用方法搬葬,由Google在2015年提出:https://arxiv.org/pdf/1502.03167.pdf.
總結(jié)來說使用BatchNormalization有以下有點:
- 可以減少過擬合棠赛,一定成都上減少Dropout的使用
- 加速訓(xùn)練
- 使用更好的學(xué)習(xí)率
BatchNormalization原理
我們都知道在訓(xùn)練深度學(xué)習(xí)模型的時候是使用一個一個batch來進(jìn)行隨機(jī)梯度更新的熊赖,這樣不用每次更新都需要計算所有數(shù)據(jù)的參數(shù)浪听,同樣對于batchnormalization:
假設(shè)輸入的batch中有m個數(shù)據(jù)钉蒲,對輸入的m個數(shù)據(jù)計算均值和均方差浓利,使用統(tǒng)計數(shù)據(jù)對輸入進(jìn)行normalization宫仗,然后再使用 和
對歸一化的輸入
進(jìn)行 scale 和 shift够挂,其中scale和shift是可以學(xué)習(xí)的參數(shù),也就是經(jīng)過batchnormalization處理的batch數(shù)據(jù)不僅僅受到整個batch的mean和variance參數(shù)影響藕夫,也受到前面訓(xùn)練的數(shù)據(jù)集的影響(前面的數(shù)據(jù)訓(xùn)練影響
和
)
原文里有這樣一句話孽糖,也是相同的意思:
The BN transform can be added to a network to manip- ulate any activation. In the notation y = BNγ ,β (x), we
indicate that the parameters γ and β are to be learned,
but it should be noted that the BN transform does not
independently process the activation in each training ex-
ample. Rather, BNγ,β(x) depends both on the training
example and the other examples in the mini-batch
Keras中BatchNormalization的參數(shù):
keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', moving_mean_initializer='zeros', moving_variance_initializer='ones', beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None)
Arguments
axis: Integer, the axis that should be normalized (typically the features axis). For instance, after a Conv2D layer with data_format="channels_first", set axis=1 in BatchNormalization.
momentum: Momentum for the moving mean and the moving variance.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of beta to normalized tensor. If False, beta is ignored.
scale: If True, multiply by gamma. If False, gamma is not used. When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
moving_mean_initializer: Initializer for the moving mean.
moving_variance_initializer: Initializer for the moving variance.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
通過使用BatchNormalization對網(wǎng)絡(luò)內(nèi)部的輸入輸出進(jìn)行歸一化,可以避免梯度消失或者爆炸的問題毅贮,而且可以增加網(wǎng)絡(luò)的魯棒性办悟,可以參考對網(wǎng)絡(luò)的輸入進(jìn)行歸一化。
Keras 的BatchNormalization實現(xiàn)
從這里回歸題目滩褥,Keras里面的BatchNormalization有什么不一樣?
我們知道在訓(xùn)練的時候使用batch normalization來對輸入進(jìn)行歸一化病蛉,在測試的時候使用的是一個樣本如何獲取mean和variance呢?
在測試的時候使用的是前面的訓(xùn)練的所有的min-batches的指數(shù)平均瑰煎,具體這里不展開铺然,可以參考這里:Ng的課程
,可看作前面所有的數(shù)據(jù)的mean和variance對當(dāng)前測試樣本的一個估計丢间。
在Keras里面inference或者predict mode里面采用的也是這種方法探熔。這個PR提出的問題是在進(jìn)行遷移學(xué)習(xí)的時候Keras提供的這個接口有很大的問題,很多人在訓(xùn)練集和測試集上的準(zhǔn)確度差異太大烘挫。
遷移學(xué)習(xí)一般在我們自己的樣本數(shù)據(jù)過少诀艰,在別人訓(xùn)練好的模型基礎(chǔ)上,使用我們自己的模型進(jìn)行參數(shù)微調(diào)整饮六。別人的模型解決的問題不是完全一樣二是類似的問題其垄,因為訓(xùn)練好的模型前面幾層可能都會識別邊緣和角點等信息。
在遷移學(xué)習(xí)的時候卤橄,通過frozen前面已經(jīng)訓(xùn)練好的layer绿满,然后在新加的layer上進(jìn)行參數(shù)更新。Keras里面一般我們通過如下代碼來fronzen一些層:
for layer in base_model.layers:
layer.trainable=False
問題就出在這里窟扑,在進(jìn)行finetune的時候喇颁,trainable=False的層計算mean和variance參數(shù)的時候使用的是新數(shù)據(jù)的min-batch計算得到的mean和variance進(jìn)行參數(shù)更新,而在模型finetune好之后嚎货,在inference的時候使用的是原始的數(shù)據(jù)加權(quán)平均的mean和variance橘霎。總而言之殖属,在finetune的時候trainable=False的batch normalization 統(tǒng)計參數(shù)來自于新數(shù)據(jù)(你現(xiàn)有的樣本)姐叁,而finetune完成之后進(jìn)行inference的時候統(tǒng)計參數(shù)來自于別人訓(xùn)練模型用的樣本特性。在這之間就有一個gap, 導(dǎo)致在finetune的訓(xùn)練準(zhǔn)確度和測試準(zhǔn)確度差異較大外潜,Github上也有人提過issue原环。
那么正確的解決方案是怎樣的?
在finetune的時候使用原始數(shù)據(jù)計算的統(tǒng)計參數(shù)對trainable=False的BatchNormalization參數(shù)進(jìn)行更新处窥,這樣就可以保證訓(xùn)練和測試的時候行為一致嘱吗。現(xiàn)在的Keras應(yīng)該是不支持這一行為的。 這也導(dǎo)致了在pull request page的論戰(zhàn)碧库。
提出PR的人柜与,在他的博客里面也做了對比實驗,想仔細(xì)了解的人可以去參考的博客嵌灰。
這個問題我之前也沒有注意過弄匕,通過這個問題即更加深入的了解了BatchNormalization也對Keras的使用方法有所注意,也是Keras封裝太多帶來的問題沽瞭,未來可能考慮轉(zhuǎn)戰(zhàn)Tensorflow或者Pytorch迁匠。希望對讀者有所裨益。