對BN的理解

BN在網(wǎng)絡(luò)中的位置和操作流程

引言

機器學習有一個重要假設(shè):IID屏歹,就是假設(shè)訓練數(shù)據(jù)和測試數(shù)據(jù)是滿足相同分布的气破,BatchNorm就是在深度神經(jīng)網(wǎng)絡(luò)訓練過程中使得每一層神經(jīng)網(wǎng)絡(luò)的輸入保持相同分布的埠偿。為什么對輸入數(shù)據(jù)做BN辆它,原因在于神經(jīng)網(wǎng)絡(luò)學習過程本質(zhì)上是為了學習數(shù)據(jù)的分布。

“Internal Covariate Shift”問題:

內(nèi)部協(xié)變量偏移堡赔,Internal指的是網(wǎng)絡(luò)深層的隱層,Covariate(協(xié)變量:不可控设联,但對結(jié)果有重要影響)指的是網(wǎng)絡(luò)的參數(shù)善已。在訓練過程中灼捂,因為各層參數(shù)不停在變化,導致隱層的輸入分布老是變來變?nèi)ァ?/p>

BN的基本思想:

每個隱層節(jié)點的激活輸入分布固定下來换团,避免了“Internal Covariate Shift”問題了悉稠,順帶解決反向傳播中梯度消失問題。BN思路來源于:如果對圖像做白化操作(0均值艘包,1方差的正態(tài)分布)的猛,神經(jīng)網(wǎng)絡(luò)收斂較快,深度神經(jīng)網(wǎng)絡(luò)的每一個隱層都是輸入層想虎,不過是相對下一層來說而已卦尊,BN可以理解為對深層神經(jīng)網(wǎng)絡(luò)每個隱層神經(jīng)元的激活值做簡化版本的白化操作。

一句話:對于每個隱層神經(jīng)元舌厨,把逐漸向非線性函數(shù)映射后向取值區(qū)間極限飽和區(qū)靠攏的輸入分布強制拉回到均值為0方差為1的比較標準的正態(tài)分布岂却,使得非線性變換函數(shù)的輸入值落入對輸入比較敏感的區(qū)域,以此避免梯度消失問題裙椭。經(jīng)過BN后躏哩,目前大部分Activation的值落入非線性函數(shù)的線性區(qū)內(nèi),其對應(yīng)的導數(shù)遠離導數(shù)飽和區(qū)骇陈,這樣來加速訓練收斂過程震庭。

疑點:BN操作之后,非線性激活函數(shù)變成了和線性函數(shù)一樣的效果你雌,顯然是不行的器联,為了保證非線性的獲得,對變換后的滿足均值為0方差為1的x又進行了scale加上shift操作(y=scale*x+shift)婿崭,每個神經(jīng)元增加了兩個參數(shù)scale和shift參數(shù)拨拓,這兩個參數(shù)是通過訓練學習到的,意思是通過scale和shift把這個值從標準正態(tài)分布左移或者右移一點并長胖一點或者變瘦一點氓栈,每個實例挪動的程度不一樣渣磷,這樣等價于非線性函數(shù)的值從正中心周圍的線性區(qū)往非線性區(qū)動了動。這樣找到一個線性和非線性的較好的平衡點授瘦,既能享受非線性的較強表達能力的好處醋界,又避免太靠非線性區(qū)兩頭使得網(wǎng)絡(luò)收斂速度太慢。這里理想狀態(tài)的scale和shift操作會不會又把x變換成未變換之前的狀態(tài)提完,又回到Internal Covariate Shift問題哪里形纺?應(yīng)該不會哈哈哈,否則BN完全沒用了啊徒欣,事實證明逐样。

Inference時的BN操作:

一個實例是沒法算實例集合求出的均值和方差,既然沒有從Mini-Batch數(shù)據(jù)里可以得到的統(tǒng)計量,那就想其它辦法來獲得這個統(tǒng)計量脂新,就是均值和方差挪捕。可以用從所有訓練實例中獲得的統(tǒng)計量來代替Mini-Batch里面m個訓練實例獲得的均值和方差統(tǒng)計量,因為本來就打算用全局的統(tǒng)計量争便,只是因為計算量等太大所以才會用Mini-Batch這種簡化方式的级零,那么在推理的時候直接用全局統(tǒng)計量即可。把每個Mini-Batch的均值和方差統(tǒng)計量記住始花,然后對這些均值和方差求其對應(yīng)的數(shù)學期望即可得出全局統(tǒng)計量妄讯。設(shè)置model.eval()的一個作用就是固定BN層,不像在訓練階段去求每個mini-batch的均值方差,而是直接取出之前記錄在網(wǎng)絡(luò)里面的每個mini-batch的方差,去求期望.

個人理解

為什么bs越大越好,因為bs越大酷宵,每個bs的分布就越趨近于同分布亥贸,這樣網(wǎng)絡(luò)比較容易學習數(shù)據(jù)的分布規(guī)律,梯度更新方向比較一致浇垦,收斂更快炕置。

BN中的參數(shù)

看一個例子

import torch
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(6)


    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)

        return x

model = Net()
for name, para in model.named_parameters():
    print(name, para)

print('************************************************************')
for name, buffer in model.named_buffers():
    print(name, buffer)

輸出為

OrderedDict([('conv1.weight', tensor([[[[ 0.0108,  0.1240,  0.0641],
          [ 0.0838,  0.0657,  0.0785],
          [ 0.0755, -0.1763, -0.0934]],

         [[-0.1210, -0.1455, -0.1416],
          [ 0.0903,  0.0632,  0.0489],
          [-0.0614, -0.1614,  0.1625]],

         [[ 0.1661, -0.0992, -0.1398],
          [ 0.1170,  0.1084,  0.1536],
          [ 0.0179,  0.1310, -0.0289]]],


        [[[ 0.1363,  0.1840,  0.1140],
          [ 0.0471,  0.0555,  0.1758],
          [-0.0386,  0.1077,  0.1612]],

         [[ 0.1177,  0.1799, -0.0495],
          [-0.0314, -0.1714,  0.1125],
          [-0.0723, -0.0770,  0.1663]],

         [[-0.1474,  0.0866, -0.0111],
          [ 0.1476, -0.0468, -0.0683],
          [ 0.0535,  0.1440,  0.1900]]],


        [[[-0.0954,  0.0743, -0.0975],
          [ 0.0741,  0.1436, -0.1203],
          [-0.0047,  0.1317, -0.1513]],

         [[-0.1422,  0.1404,  0.1614],
          [ 0.0025, -0.1499,  0.1647],
          [ 0.0192,  0.0324,  0.0593]],

         [[-0.0041,  0.1813, -0.1696],
          [ 0.0822,  0.1765, -0.1627],
          [ 0.0262,  0.1857, -0.0359]]],


        [[[-0.1816, -0.1198, -0.1289],
          [-0.0138,  0.1118, -0.0687],
          [-0.0078, -0.0975, -0.0646]],

         [[ 0.1763, -0.0490, -0.1117],
          [ 0.0976, -0.0156,  0.1104],
          [-0.0755,  0.0067,  0.0637]],

         [[-0.0131, -0.1783,  0.0628],
          [ 0.1020,  0.1713, -0.0764],
          [-0.1752,  0.0589, -0.0661]]],


        [[[-0.0292,  0.1491,  0.1690],
          [-0.1483,  0.1089, -0.1463],
          [-0.1159,  0.0097,  0.1525]],

         [[-0.0439, -0.0683, -0.0691],
          [-0.0465, -0.0289,  0.1653],
          [ 0.1307, -0.0170, -0.1875]],

         [[-0.0941,  0.1616,  0.0168],
          [ 0.1385,  0.1919,  0.0238],
          [-0.0705,  0.1550,  0.1585]]],


        [[[ 0.1091,  0.0602, -0.1886],
          [ 0.0663,  0.1151, -0.1629],
          [ 0.0955, -0.1370, -0.1030]],

         [[-0.1690,  0.1786,  0.0723],
          [-0.0280, -0.0451, -0.0303],
          [-0.0342, -0.0909, -0.1883]],

         [[ 0.1072,  0.1869,  0.0249],
          [ 0.1028, -0.1043,  0.0852],
          [-0.0532, -0.1132, -0.0372]]]])), ('conv1.bias', tensor([-0.0907,  0.1700, -0.0342,  0.1511,  0.0931,  0.0797])), ('bn1.weight', tensor([1., 1., 1., 1., 1., 1.])), ('bn1.bias', tensor([0., 0., 0., 0., 0., 0.])), ('bn1.running_mean', tensor([0., 0., 0., 0., 0., 0.])), ('bn1.running_var', tensor([1., 1., 1., 1., 1., 1.])), ('bn1.num_batches_tracked', tensor(0))])

conv1.weight Parameter containing:
tensor([[[[-0.1410,  0.0936, -0.0152],
          [-0.1397, -0.1212, -0.1048],
          [-0.1421, -0.0171,  0.0640]],

         [[ 0.1423, -0.1203, -0.0369],
          [-0.0067,  0.0966,  0.1195],
          [ 0.0143,  0.0839, -0.0283]],

         [[-0.1537, -0.1123, -0.1345],
          [ 0.0886,  0.1017,  0.0533],
          [-0.0084, -0.1251,  0.1744]]],


        [[[ 0.1859, -0.1693, -0.1616],
          [ 0.0567,  0.1256,  0.0887],
          [-0.0761, -0.1245, -0.0764]],

         [[ 0.1298, -0.1307, -0.0978],
          [ 0.0780,  0.0860, -0.0598],
          [-0.0295, -0.1884,  0.0191]],

         [[-0.1898, -0.0489,  0.1485],
          [-0.1887, -0.0618, -0.1429],
          [ 0.1066, -0.0593,  0.0559]]],


        [[[ 0.0189,  0.0575,  0.1358],
          [-0.1079, -0.0591, -0.1221],
          [ 0.0100, -0.0392,  0.0423]],

         [[ 0.1072,  0.1461, -0.1267],
          [-0.1478,  0.1647,  0.1149],
          [ 0.0258, -0.1862, -0.0070]],

         [[ 0.1138, -0.0968,  0.0016],
          [-0.0955,  0.1802,  0.0822],
          [-0.1311,  0.0945, -0.0038]]],


        [[[ 0.1647, -0.0404,  0.0610],
          [-0.1558,  0.1357,  0.1779],
          [-0.0070,  0.1030, -0.0585]],

         [[ 0.1592,  0.0970,  0.0614],
          [-0.0068, -0.0732,  0.1352],
          [ 0.0447,  0.0769, -0.0384]],

         [[-0.0589, -0.0711, -0.0543],
          [ 0.0926, -0.0984, -0.0573],
          [ 0.0687,  0.1849,  0.0993]]],


        [[[ 0.0730,  0.0036,  0.0584],
          [ 0.0568,  0.0311, -0.1742],
          [ 0.1582, -0.0496, -0.0620]],

         [[ 0.0348, -0.1415,  0.0212],
          [-0.1688,  0.0436, -0.1485],
          [ 0.0154, -0.1302,  0.1255]],

         [[ 0.1393,  0.0575, -0.1821],
          [ 0.0244, -0.1584,  0.0886],
          [-0.0158, -0.1907, -0.1038]]],


        [[[ 0.0019, -0.0077, -0.0073],
          [ 0.0667,  0.1904,  0.1622],
          [-0.1315,  0.1265,  0.0110]],

         [[ 0.0979,  0.0211, -0.1126],
          [ 0.1260,  0.1614,  0.0309],
          [-0.0724, -0.1381,  0.1275]],

         [[-0.0206, -0.0674, -0.0358],
          [-0.0800, -0.0408,  0.1636],
          [ 0.0082, -0.0014, -0.0292]]]], requires_grad=True)
conv1.bias Parameter containing:
tensor([-0.0660,  0.0184,  0.0102,  0.1804, -0.0702,  0.0977],
       requires_grad=True)
bn1.weight Parameter containing:
tensor([1., 1., 1., 1., 1., 1.], requires_grad=True)
bn1.bias Parameter containing:
tensor([0., 0., 0., 0., 0., 0.], requires_grad=True)
************************************************************
bn1.running_mean tensor([0., 0., 0., 0., 0., 0.])
bn1.running_var tensor([1., 1., 1., 1., 1., 1.])
bn1.num_batches_tracked tensor(0)

可以看到,網(wǎng)絡(luò)中的參數(shù)除了parameters男韧,還有一些不用更新的參數(shù)朴摊,主要是BN中的'bn1.running_mean'bn1.running_var,這些參數(shù)只在forward時進行統(tǒng)計計算此虑,backward時并不會被更新甚纲,這些參數(shù)也稱為buffer,可以用model.buffers()獲取朦前。順便提一下介杆,在進行推理時設(shè)置model.val(),會固定這些參數(shù)韭寸,不會計算春哨,而是采用記錄的全局統(tǒng)計量,如上所述恩伺。

創(chuàng)建于2020.11.26

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末赴背,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子晶渠,更是在濱河造成了極大的恐慌凰荚,老刑警劉巖,帶你破解...
    沈念sama閱讀 219,270評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件褒脯,死亡現(xiàn)場離奇詭異浇揩,居然都是意外死亡,警方通過查閱死者的電腦和手機憨颠,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,489評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人爽彤,你說我怎么就攤上這事养盗。” “怎么了适篙?”我有些...
    開封第一講書人閱讀 165,630評論 0 356
  • 文/不壞的土叔 我叫張陵往核,是天一觀的道長。 經(jīng)常有香客問我嚷节,道長聂儒,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,906評論 1 295
  • 正文 為了忘掉前任硫痰,我火速辦了婚禮衩婚,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘效斑。我一直安慰自己非春,他們只是感情好,可當我...
    茶點故事閱讀 67,928評論 6 392
  • 文/花漫 我一把揭開白布缓屠。 她就那樣靜靜地躺著奇昙,像睡著了一般。 火紅的嫁衣襯著肌膚如雪敌完。 梳的紋絲不亂的頭發(fā)上储耐,一...
    開封第一講書人閱讀 51,718評論 1 305
  • 那天,我揣著相機與錄音滨溉,去河邊找鬼什湘。 笑死,一個胖子當著我的面吹牛业踏,可吹牛的內(nèi)容都是我干的禽炬。 我是一名探鬼主播,決...
    沈念sama閱讀 40,442評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼勤家,長吁一口氣:“原來是場噩夢啊……” “哼腹尖!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起伐脖,我...
    開封第一講書人閱讀 39,345評論 0 276
  • 序言:老撾萬榮一對情侶失蹤热幔,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后讼庇,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體绎巨,經(jīng)...
    沈念sama閱讀 45,802評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,984評論 3 337
  • 正文 我和宋清朗相戀三年蠕啄,在試婚紗的時候發(fā)現(xiàn)自己被綠了场勤。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片戈锻。...
    茶點故事閱讀 40,117評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖和媳,靈堂內(nèi)的尸體忽然破棺而出格遭,到底是詐尸還是另有隱情,我是刑警寧澤留瞳,帶...
    沈念sama閱讀 35,810評論 5 346
  • 正文 年R本政府宣布拒迅,位于F島的核電站,受9級特大地震影響她倘,放射性物質(zhì)發(fā)生泄漏璧微。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,462評論 3 331
  • 文/蒙蒙 一硬梁、第九天 我趴在偏房一處隱蔽的房頂上張望前硫。 院中可真熱鬧,春花似錦靶溜、人聲如沸开瞭。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,011評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽嗤详。三九已至,卻和暖如春瓷炮,著一層夾襖步出監(jiān)牢的瞬間葱色,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,139評論 1 272
  • 我被黑心中介騙來泰國打工娘香, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留苍狰,地道東北人。 一個月前我還...
    沈念sama閱讀 48,377評論 3 373
  • 正文 我出身青樓烘绽,卻偏偏與公主長得像淋昭,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子安接,可洞房花燭夜當晚...
    茶點故事閱讀 45,060評論 2 355

推薦閱讀更多精彩內(nèi)容