上一期介紹了Batch Normalization的前向傳播忠聚,然而想法美好,然而能否計算、如何計算這些新參數(shù)才是重點。
系列目錄
理解Batch Normalization系列1——原理
理解Batch Normalization系列2——訓(xùn)練及評估
理解Batch Normalization系列3——為什么有效及若干討論
理解Batch Normalization系列4——實踐
本文目錄
1 訓(xùn)練階段
1.1 反向傳播
1.2 參數(shù)的初始化及更新
2 評估階段
2.1 來自訓(xùn)練集的均值和方差
2.2 評估階段的計算
3 總結(jié)
參考文獻(xiàn)
先放出這張圖坠陈,幫助記住。
? 圖 1. BN的結(jié)構(gòu)
1 訓(xùn)練階段
引入BN喘帚,增加了畅姊、
咒钟、
吹由、
四個參數(shù)。
這四個參數(shù)的引入朱嘴,能否計算梯度倾鲫?它們分別是如何初始化與更新?
1.1 反向傳播
神經(jīng)網(wǎng)絡(luò)的訓(xùn)練萍嬉,離不開反向傳播乌昔,必須保證BN的標(biāo)準(zhǔn)化、縮放平移兩個操作必須可導(dǎo)壤追。
縮放平移就是一個線性公式磕道,求導(dǎo)很簡單。而對于標(biāo)準(zhǔn)化時的統(tǒng)計量行冰,看起來有點無從下手溺蕉。其實是憑借圖1的變量關(guān)系伶丐,可以繪制計算圖,如圖2所示疯特。Frederik Kratzert 在這篇博文中有詳細(xì)的計算哗魂,對每一個環(huán)節(jié)都進(jìn)行了詳細(xì)的描述。
? 圖 2. 求解BN反向傳播的計算圖 (來源: 這篇博文)
由圖2可見:
- 每個環(huán)節(jié)都可導(dǎo)
- 只要求出各個環(huán)節(jié)的導(dǎo)數(shù)
- 用鏈?zhǔn)椒▌t(串聯(lián)關(guān)系就相乘漓雅,并聯(lián)關(guān)系就相加)求出總梯度录别。
狗尾續(xù)貂,對這個反傳大致做了一個流程圖邻吞,如圖3所示组题,幫助理解。
? 圖 3. BN層反傳的流程圖 (來源: 這篇博文)
注意吃衅,均值的梯度往踢、方差的梯度的計算,只是為了保證梯度的反向傳播鏈路的通暢徘层,而不是為了更新自己(沒明白下文還會解釋)峻呕;縮放因子和j和平移因子
的梯度傳播則和權(quán)重W一樣,不影響反向傳播鏈路的通暢趣效,只是為了更新自己瘦癌。
最后的結(jié)果就是原論文中表述:
? 圖4. BN的反向傳播. (來源: Batch Normalization Paper)
? 如果是從事學(xué)術(shù),不妨練練手跷敬。
1.2 參數(shù)的初始化及更新
討論一下圖1中的6個參數(shù)的初始化及更新問題讯私。
-
W
初始化用標(biāo)準(zhǔn)正態(tài)分布,更新用梯度下降西傀。
與經(jīng)典網(wǎng)絡(luò)的初始化相同斤寇,初始化一個標(biāo)準(zhǔn)正態(tài)分布(即Xavier方法)。
-
b
省略掉該參數(shù)拥褂。
在經(jīng)典的神經(jīng)網(wǎng)絡(luò)里娘锁,b作為偏置,用于解決那些W無法通過與x相乘搞定的"損失減少要求"饺鹃,即對于本層所有神經(jīng)元的加權(quán)和進(jìn)行各自的平移莫秆。而加入BN后,
的作用正是進(jìn)行平移悔详。b的作用被
所完全替代了镊屎,因此省略掉b。
了解過ResNet結(jié)構(gòu)的朋友會發(fā)現(xiàn)該網(wǎng)絡(luò)中的卷積茄螃,都沒有偏置缝驳,為什么?下面截圖是Kaiming He在github上回答原話。(踩坑無數(shù)必須體會深刻)
? 圖5. BN的加入導(dǎo)致本層的偏置b失效
-
和
初始化取決于統(tǒng)計量用狱,僅更新梯度萎庭,但不更新值本身。
在訓(xùn)練階段齿拂,每個mini-batch上進(jìn)行前向傳播時驳规,通過對本batch上的m個樣本進(jìn)行統(tǒng)計得到;
在反向傳播時署海,計算出它們的梯度
對
的梯度吗购、
對
的梯度,用于進(jìn)行梯度傳播砸狞。
但是和
這兩個值本身不必進(jìn)行更新捻勉,因為在下一個mini-batch會計算自己的統(tǒng)計量,所以前一個mini-batch獲得的
和
沒意義刀森。
-
和
初始化為1踱启、0,更新用梯度下降研底。
根據(jù)我們在《理解Batch Normalization系列1——原理》的解讀埠偿,
作為“準(zhǔn)方差”,初始化為一個全1向量榜晦;而
作為"準(zhǔn)均值”冠蒋,初始化為一個全0向量,他倆的初始值對于剛剛完成標(biāo)準(zhǔn)正態(tài)化的
來說乾胶,沒起任何作用抖剿。
至于將要變成什么值,起多大作用识窿,那就交給后續(xù)的訓(xùn)練斩郎。即采用梯度下降進(jìn)行更新,方式同
喻频。
2 評估階段
缩宜、
是在整個訓(xùn)練集上訓(xùn)練出來的,與
一樣半抱,訓(xùn)練結(jié)束就可獲得脓恕。
然而膜宋,和
是靠每一個mini-batch的統(tǒng)計得到窿侈,因為評估時只有一條樣本,batch_size相當(dāng)于是1秋茫,在只有1個向量的數(shù)據(jù)組上進(jìn)行標(biāo)準(zhǔn)化后史简,成了一個全0向量,這可咋辦?
2.1 來自訓(xùn)練集的均值和方差
做法是用訓(xùn)練集來估計總體均值和總體標(biāo)準(zhǔn)差
圆兵。
-
簡單平均法
把每個mini-batch的均值和方差都保存下來跺讯,然后訓(xùn)練完了求均值的均值,方差的均值即可殉农。
-
移動指數(shù)平均(Exponential Moving Average)
這是對均值的近似刀脏。
僅以
舉例:
? 其中decay是衰減系數(shù)。即總均值是前一個mini-batch統(tǒng)計的總均值和本次mini-batch的
加權(quán)求和超凳。至于衰減率 decay在區(qū)間
之間愈污,decay越接近1,結(jié)果
越穩(wěn)定轮傍,越受較遠(yuǎn)的大范圍的樣本影響暂雹;decay越接近0,結(jié)果
越波動创夜,越受較近的小范圍的樣本影響杭跪。
事實上,簡單平均可能更好驰吓,簡單平均本質(zhì)上是平均權(quán)重,但是簡單平均需要保存所有BN層在所有mini-batch上的均值向量和方差向量檬贰,如果訓(xùn)練數(shù)據(jù)量很大现斋,會有較可觀的存儲代價。移動指數(shù)平均在實際的框架中更常見(例如tensorflow)偎蘸,可能的好處是EMA不需要存儲每一個mini-batch的值庄蹋,永遠(yuǎn)只保存著三個值:總統(tǒng)計值、本batch的統(tǒng)計值迷雪,decay系數(shù)限书。
在訓(xùn)練階段同步獲得了和
后,在評估時即可對樣本進(jìn)行BN操作章咧。
2.2 評估階段的計算
為避免分母不為0倦西,增加一個非常小的常數(shù),并為了計算優(yōu)化赁严,被轉(zhuǎn)換為:
這樣扰柠,只要訓(xùn)練結(jié)束,就已知了,1個BN層對一條測試樣本的前向傳播只是增加了一層線性計算而已程剥。
3 總結(jié)
用圖6做個總結(jié)劝枣。
? 圖6. BN層相關(guān)參數(shù)的學(xué)習(xí)方法
鬼斧神工的構(gòu)造,鬼斧神工的參數(shù)獲取方法,這么多鬼斧神工舔腾,需要好好消化消化溪胶。
請見下一期《理解Batch Normalization系列3——為什么有效及若干討論》
參考文獻(xiàn)
[1] https://arxiv.org/pdf/1502.03167v3.pdf
[2] https://r2rt.com/implementing-batch-normalization-in-tensorflow.html
[3] Adjusting for Dropout Variance in Batch Normalization and Weight Initialization
[4] http://www.reibang.com/p/05f3e7ddf1e1
[8] https://panxiaoxie.cn/2018/07/28/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0-Batch-Normalization/
[9] https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization