1匕荸、提出背景
1.1 模型訓(xùn)練中的困擾
在深度學(xué)習(xí)中汁雷,由于問題的復(fù)雜性介时,我們往往會(huì)使用較深層數(shù)的網(wǎng)絡(luò)進(jìn)行訓(xùn)練,相信很多煉丹的朋友都對(duì)調(diào)參的困難有所體會(huì)蒜埋,尤其是對(duì)深層神經(jīng)網(wǎng)絡(luò)的訓(xùn)練調(diào)參更是困難且復(fù)雜淫痰。
在這個(gè)過程中,我們需要去嘗試不同的學(xué)習(xí)率整份、初始化參數(shù)方法(例如Xavier初始化)等方式來幫助我們的模型加速收斂待错。深度神經(jīng)網(wǎng)絡(luò)之所以如此難訓(xùn)練,其中一個(gè)重要原因就是網(wǎng)絡(luò)中層與層之間存在高度的關(guān)聯(lián)性與耦合性烈评。下圖是一個(gè)多層的神經(jīng)網(wǎng)絡(luò)火俄,層與層之間采用全連接的方式進(jìn)行連接。
我們規(guī)定左側(cè)為神經(jīng)網(wǎng)絡(luò)的底層讲冠,右側(cè)為神經(jīng)網(wǎng)絡(luò)的上層瓜客。那么網(wǎng)絡(luò)中層與層之間的關(guān)聯(lián)性會(huì)導(dǎo)致如下的狀況:隨著訓(xùn)練的進(jìn)行,網(wǎng)絡(luò)中的參數(shù)也隨著梯度下降在不停更新竿开。
- 一方面谱仪,當(dāng)?shù)讓泳W(wǎng)絡(luò)中參數(shù)發(fā)生微弱變化時(shí),由于每一層中的線性變換與非線性激活映射否彩,這些微弱變化隨著網(wǎng)絡(luò)層數(shù)的加深而被放大(類似蝴蝶效應(yīng))疯攒;
- 另一方面,參數(shù)的變化導(dǎo)致每一層的輸入分布會(huì)發(fā)生改變列荔,進(jìn)而上層的網(wǎng)絡(luò)需要不停地去適應(yīng)這些分布變化敬尺,使得我們的模型訓(xùn)練變得困難枚尼。上述這一現(xiàn)象叫做Internal Covariate Shift。
1.2 什么是Internal Covariate Shift
Batch Normalization的原論文作者給了Internal Covariate Shift一個(gè)較規(guī)范的定義:在深層網(wǎng)絡(luò)訓(xùn)練的過程中砂吞,由于網(wǎng)絡(luò)中參數(shù)變化而引起內(nèi)部結(jié)點(diǎn)數(shù)據(jù)分布發(fā)生變化的這一過程被稱作Internal Covariate Shift署恍。
這句話怎么理解呢?我們定義每一層的線性變換為呜舒,其中代表層數(shù)锭汛;非線性變換為,其中袭蝗,為 第層的激活函數(shù)唤殴。
醉著梯度下降的進(jìn)行,每一層的參數(shù)與都會(huì)被更新到腥,那么的分布也就發(fā)生了改變朵逝,進(jìn)而也同樣出現(xiàn)分布的改變。而作為第層的輸入乡范,意味著層就需要去不停適應(yīng)這種數(shù)據(jù)分布的變化配名,這一過程就被叫做Internal Covariate Shift。
1.3 Internal Covariate Shift會(huì)帶來什么問題晋辆?
(1)上層網(wǎng)絡(luò)需要不停地調(diào)整來適應(yīng)輸入數(shù)據(jù)分布的變化渠脉,導(dǎo)致網(wǎng)絡(luò)學(xué)習(xí)速度的降低
我們?cè)谏厦嫣岬教荻认陆档倪^程會(huì)讓每一層的參數(shù)和發(fā)生變化,進(jìn)而使得每一層的線性與非線性計(jì)算結(jié)果分布產(chǎn)生變化瓶佳。后層網(wǎng)絡(luò)就要不停地適應(yīng)這種分布的變化芋膘,這個(gè)時(shí)候就會(huì)使得整個(gè)網(wǎng)絡(luò)的學(xué)習(xí)速率變慢。
(2)網(wǎng)絡(luò)的訓(xùn)練過程容易陷入梯度飽和區(qū)霸饲,減緩網(wǎng)絡(luò)收斂的速度
當(dāng)我們?cè)谏窠?jīng)網(wǎng)絡(luò)中采用飽和激活函數(shù)(saturated activation function)時(shí)为朋,例如sigmoid,tanh激活函數(shù)厚脉,很容易使得模型訓(xùn)練陷入梯度飽和區(qū)(saturated regime)习寸。隨著模型訓(xùn)練的進(jìn)行,我們的參數(shù) 會(huì)逐漸更新并變大傻工,此時(shí)就會(huì)隨著變大霞溪,并且還要收到更底層網(wǎng)絡(luò)參數(shù)的影響,隨著網(wǎng)絡(luò)層數(shù)的增加中捆,很容易陷入梯度飽和區(qū)威鹿,此時(shí)梯度會(huì)變得很小甚至接近與0,參數(shù)的更新速度就會(huì)變慢轨香,進(jìn)而就會(huì)放慢網(wǎng)絡(luò)的收斂速度。
對(duì)于激活函數(shù)梯度飽和的問題幼东,有兩種解決思路:第一種就是使用ReLU等非線性激活函數(shù)臂容,可以一定程度上解決訓(xùn)練陷入梯度飽和區(qū)的問題科雳。另一種就是,我們可以讓激活函數(shù)的分布保持在一個(gè)穩(wěn)定的狀態(tài)脓杉,來盡可能避免它們陷入梯度飽和區(qū)糟秘,也就是Normalization的思路。
1.4 我們?nèi)绾螠p緩Internal Covariate Shift球散?
要緩解ICS的問題尿赚,就要明白它產(chǎn)生的原因。ICS產(chǎn)生的原因是由于參數(shù)更新帶來的網(wǎng)絡(luò)中每一層輸入值分布的改變蕉堰,并且隨著網(wǎng)絡(luò)層數(shù)的加深而變得更加嚴(yán)重凌净,因此我們可以通過固定每一層網(wǎng)絡(luò)輸入值的分布來對(duì)減緩ICS問題。
(1)白化
白化(Whitening)是機(jī)器學(xué)習(xí)里面常用的一種規(guī)范化數(shù)據(jù)分布的方法屋讶,主要是PCA白化與ZCA白化冰寻。白化是對(duì)輸入數(shù)據(jù)分布進(jìn)行變換,進(jìn)而達(dá)到以下兩個(gè)目的:
- 使得輸入特征分布具有相同的均值與方差皿渗。其中PCA白化保證了所有特征分布均值為0斩芭,方差為1;而ZCA白化則保證了所有特征分布均值為0乐疆,方差相同划乖;
- 去除特征之間的相關(guān)性。
通過白化操作挤土,我們可以減緩ICS的問題琴庵,進(jìn)而固定了每一層網(wǎng)絡(luò)輸入分布,加速網(wǎng)絡(luò)訓(xùn)練過程的收斂(LeCun et al.,1998b耕挨;Wiesler&Ney,2011)细卧。
(2)Batch Normalization提出
既然白化可以解決這個(gè)問題,為什么我們還要提出別的解決辦法筒占?當(dāng)然是現(xiàn)有的方法具有一定的缺陷贪庙,白化主要有以下兩個(gè)問題:
- 白化過程計(jì)算成本太高,并且在每一輪訓(xùn)練中的每一層我們都需要做如此高成本計(jì)算的白化操作翰苫;
- 白化過程由于改變了網(wǎng)絡(luò)每一層的分布止邮,因而改變了網(wǎng)絡(luò)層中本身數(shù)據(jù)的表達(dá)能力。底層網(wǎng)絡(luò)學(xué)習(xí)到的參數(shù)信息會(huì)被白化操作丟失掉奏窑。
既然有了上面兩個(gè)問題导披,那我們的解決思路就很簡單,一方面埃唯,我們提出的normalization方法要能夠簡化計(jì)算過程撩匕;另一方面又需要經(jīng)過規(guī)范化處理后讓數(shù)據(jù)盡可能保留原始的表達(dá)能力。于是就有了簡化+改進(jìn)版的白化——Batch Normalization墨叛。
2止毕、Batch Normalization
2.1 思路
既然白化計(jì)算過程比較復(fù)雜模蜡,那我們就簡化一點(diǎn),比如我們可以嘗試單獨(dú)對(duì)每個(gè)特征進(jìn)行normalizaiton就可以了扁凛,讓每個(gè)特征都有均值為0忍疾,方差為1的分布就OK。
另一個(gè)問題谨朝,既然白化操作減弱了網(wǎng)絡(luò)中每一層輸入數(shù)據(jù)表達(dá)能力卤妒,那我就再加個(gè)線性變換操作,讓這些數(shù)據(jù)再能夠盡可能恢復(fù)本身的表達(dá)能力就好了字币。
因此则披,基于上面兩個(gè)解決問題的思路,作者提出了Batch Normalization纬朝,下一部分來具體講解這個(gè)算法步驟收叶。
2.2 公式
舉例計(jì)算:
上圖展示了一個(gè)batch size為2(兩張圖片)的Batch Normalization的計(jì)算過程。
假設(shè)feature1共苛、feature2分別是由image1判没、image2經(jīng)過一系列卷積池化后得到的特征矩陣,feature的channel數(shù)均為2隅茎,那么代表該batch的所有feature的channel1的數(shù)據(jù)澄峰,同理代表該batch的所有feature的channel2的數(shù)據(jù)。
然后分別計(jì)算和的均值和方差辟犀,得到我們的和兩個(gè)向量俏竞。
然后在根據(jù)標(biāo)準(zhǔn)差計(jì)算公式分別計(jì)算每個(gè)channel的值(公式中的是一個(gè)很小的常量,防止分母為零的情況)
在我們訓(xùn)練網(wǎng)絡(luò)的過程中堂竟,我們是通過一個(gè)batch一個(gè)batch的數(shù)據(jù)進(jìn)行訓(xùn)練的魂毁,但是我們?cè)陬A(yù)測過程中通常都是輸入一張圖片進(jìn)行預(yù)測,此時(shí)batch size為1出嘹,如果在通過上述方法計(jì)算均值和方差就沒有意義了席楚。所以我們?cè)谟?xùn)練過程中要去不斷的計(jì)算每個(gè)batch的均值和方差,并使用移動(dòng)平均(moving average)的方法記錄統(tǒng)計(jì)的均值和方差税稼,在我們訓(xùn)練完后我們可以近似認(rèn)為我們所統(tǒng)計(jì)的均值和方差就等于我們整個(gè)訓(xùn)練集的均值和方差烦秩。然后在我們驗(yàn)證以及預(yù)測過程中,就使用我們統(tǒng)計(jì)得到的均值和方差進(jìn)行標(biāo)準(zhǔn)化處理郎仆。
2.3 代碼
在訓(xùn)練過程中只祠,均值和方差通過計(jì)算當(dāng)前批次數(shù)據(jù)得到的記為和扰肌,而我們?cè)陬A(yù)測過程中所使用的均值和方差是一個(gè)訓(xùn)練過程中保存的統(tǒng)計(jì)量抛寝,記和,和的具體更新策略如下,momentum默認(rèn)取值為0.1:
需要注意的是:
- 在pytorch中對(duì)當(dāng)前批次feature進(jìn)行BN處理時(shí)使用的是總體標(biāo)準(zhǔn)差墩剖,計(jì)算公式是:
- 在更新統(tǒng)計(jì)量時(shí)采用的是樣本標(biāo)準(zhǔn)差猴凹,計(jì)算公式是:
下面是使用pytorch做的測試:
(1)bn_process函數(shù)是自定義的bn處理方法驗(yàn)證是否和使用官方bn處理方法結(jié)果一致。在bn_process中計(jì)算輸入batch數(shù)據(jù)的每個(gè)維度(這里的維度是channel維度)的均值和標(biāo)準(zhǔn)差(標(biāo)準(zhǔn)差等于方差開平方)岭皂,然后通過計(jì)算得到的均值和總體標(biāo)準(zhǔn)差對(duì)feature每個(gè)維度進(jìn)行標(biāo)準(zhǔn)化,然后使用均值和樣本標(biāo)準(zhǔn)差更新統(tǒng)計(jì)均值和標(biāo)準(zhǔn)差沼头。
(2)初始化統(tǒng)計(jì)均值是一個(gè)元素為0的向量爷绘,元素個(gè)數(shù)等于channel深度;初始化統(tǒng)計(jì)方差是一個(gè)元素為1的向量进倍,元素個(gè)數(shù)等于channel深度土至,初始化,猾昆。
import numpy as np
import torch.nn as nn
import torch
def bn_process(feature, mean, var):
feature_shape = feature.shape
for i in range(feature_shape[1]):
# [batch, channel, height, width]
feature_t = feature[:, i, :, :]
mean_t = feature_t.mean()
# 總體標(biāo)準(zhǔn)差
std_t1 = feature_t.std()
# 樣本標(biāo)準(zhǔn)差
std_t2 = feature_t.std(ddof=1)
# bn process
feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / std_t1
# update calculating mean and var
mean[i] = mean[i]*0.9 + mean_t*0.1
var[i] = var[i]*0.9 + (std_t2**2)*0.1
print(feature)
# 隨機(jī)生成一個(gè)batch為2陶因,channel為2,height=width=2的特征向量
# [batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
# 初始化統(tǒng)計(jì)均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
# print(feature1.numpy())
# 注意要使用copy()深拷貝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)
bn = nn.BatchNorm2d(2)
output = bn(feature1)
print(output)
設(shè)置一個(gè)斷點(diǎn)進(jìn)行調(diào)試垂蜗,查看下官方bn對(duì)feature處理后得到的統(tǒng)計(jì)均值和方差楷扬。我們可以發(fā)現(xiàn)官方提供的bn的running_mean和running_var和我們自己計(jì)算的calculate_mean和calculate_var是一模一樣的(只是精度不同):
輸出結(jié)果如下:
從結(jié)果可以看出:通過自定義bn_process函數(shù)得到的輸出以及使用官方bn處理得到輸出,明顯結(jié)果是一樣的(只是精度不同)贴见。
2.4 優(yōu)勢
Batch Normalization在實(shí)際工程中被證明了能夠緩解神經(jīng)網(wǎng)絡(luò)難以訓(xùn)練的問題烘苹,BN具有的有事可以總結(jié)為以下四點(diǎn):
(1)BN使得網(wǎng)絡(luò)中每層輸入數(shù)據(jù)的分布相對(duì)穩(wěn)定,加速模型學(xué)習(xí)速度
BN通過規(guī)范化與線性變換使得每一層網(wǎng)絡(luò)的輸入數(shù)據(jù)的均值與方差都在一定范圍內(nèi)片部,使得后一層網(wǎng)絡(luò)不必不斷去適應(yīng)底層網(wǎng)絡(luò)中輸入的變化镣衡,從而實(shí)現(xiàn)了網(wǎng)絡(luò)中層與層之間的解耦,允許每一層進(jìn)行獨(dú)立學(xué)習(xí)档悠,有利于提高整個(gè)神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)速度廊鸥。
(2)BN使得模型對(duì)網(wǎng)絡(luò)中的參數(shù)不那么敏感,簡化調(diào)參過程辖所,使得網(wǎng)絡(luò)學(xué)習(xí)更加穩(wěn)定
在神經(jīng)網(wǎng)絡(luò)中惰说,我們經(jīng)常會(huì)謹(jǐn)慎地采用一些權(quán)重初始化方法(例如Xavier)或者合適的學(xué)習(xí)率來保證網(wǎng)絡(luò)穩(wěn)定訓(xùn)練。
當(dāng)學(xué)習(xí)率設(shè)置太高時(shí)奴烙,會(huì)使得參數(shù)更新步伐過大助被,容易出現(xiàn)震蕩和不收斂。但是使用BN的網(wǎng)絡(luò)將不會(huì)受到參數(shù)數(shù)值大小的影響切诀。
因此揩环,在使用Batch Normalization之后,抑制了參數(shù)微小變化隨著網(wǎng)絡(luò)層數(shù)加深被放大的問題幅虑,使得網(wǎng)絡(luò)對(duì)參數(shù)大小的適應(yīng)能力更強(qiáng)丰滑,此時(shí)我們可以設(shè)置較大的學(xué)習(xí)率而不用過于擔(dān)心模型divergence的風(fēng)險(xiǎn)。
3)BN允許網(wǎng)絡(luò)使用飽和性激活函數(shù)(例如sigmoid,tanh等)褒墨,緩解梯度消失問題
在不使用BN層的時(shí)候炫刷,由于網(wǎng)絡(luò)的深度與復(fù)雜性,很容易使得底層網(wǎng)絡(luò)變化累積到上層網(wǎng)絡(luò)中郁妈,導(dǎo)致模型的訓(xùn)練很容易進(jìn)入到激活函數(shù)的梯度飽和區(qū)浑玛;通過normalize操作可以讓激活函數(shù)的輸入數(shù)據(jù)落在梯度非飽和區(qū),緩解梯度消失的問題噩咪;另外通過自適應(yīng)學(xué)習(xí)與又讓數(shù)據(jù)保留更多的原始信息顾彰。
(4)BN具有一定的正則化效果
在Batch Normalization中,由于我們使用mini-batch的均值與方差作為對(duì)整體訓(xùn)練樣本均值與方差的估計(jì)胃碾,盡管每一個(gè)batch中的數(shù)據(jù)都是從總體樣本中抽樣得到涨享,但不同mini-batch的均值與方差會(huì)有所不同,這就為網(wǎng)絡(luò)的學(xué)習(xí)過程中增加了隨機(jī)噪音仆百,與Dropout通過關(guān)閉神經(jīng)元給網(wǎng)絡(luò)訓(xùn)練帶來噪音類似厕隧,在一定程度上對(duì)模型起到了正則化的效果。
2.5 使用BN時(shí)需要注意的問題
訓(xùn)練時(shí)要將traning參數(shù)設(shè)置為True俄周,在驗(yàn)證時(shí)將trainning參數(shù)設(shè)置為False吁讨。在pytorch中可通過創(chuàng)建模型的model.train()和model.eval()方法控制。
batch size盡可能設(shè)置大點(diǎn)栈源,設(shè)置小后表現(xiàn)可能很糟糕挡爵,設(shè)置的越大求的均值和方差越接近整個(gè)訓(xùn)練集的均值和方差。
建議將bn層放在卷積層(Conv)和激活層(例如Relu)之間甚垦,且卷積層不要使用偏置bias茶鹃,因?yàn)闆]有用,參考下圖推理艰亮,即使使用了偏置bias求出的結(jié)果也是一樣的
Batch Normalization原理與實(shí)戰(zhàn)
Batch Normalization詳解以及pytorch實(shí)驗(yàn)