引言
機器學習有一個重要假設(shè):IID屏歹,就是假設(shè)訓練數(shù)據(jù)和測試數(shù)據(jù)是滿足相同分布的气破,BatchNorm就是在深度神經(jīng)網(wǎng)絡(luò)訓練過程中使得每一層神經(jīng)網(wǎng)絡(luò)的輸入保持相同分布的埠偿。為什么對輸入數(shù)據(jù)做BN辆它,原因在于神經(jīng)網(wǎng)絡(luò)學習過程本質(zhì)上是為了學習數(shù)據(jù)的分布。
“Internal Covariate Shift”問題:
內(nèi)部協(xié)變量偏移堡赔,Internal指的是網(wǎng)絡(luò)深層的隱層,Covariate(協(xié)變量:不可控设联,但對結(jié)果有重要影響)指的是網(wǎng)絡(luò)的參數(shù)善已。在訓練過程中灼捂,因為各層參數(shù)不停在變化,導致隱層的輸入分布老是變來變?nèi)ァ?/p>
BN的基本思想:
每個隱層節(jié)點的激活輸入分布固定下來换团,避免了“Internal Covariate Shift”問題了悉稠,順帶解決反向傳播中梯度消失問題。BN思路來源于:如果對圖像做白化操作(0均值艘包,1方差的正態(tài)分布)的猛,神經(jīng)網(wǎng)絡(luò)收斂較快,深度神經(jīng)網(wǎng)絡(luò)的每一個隱層都是輸入層想虎,不過是相對下一層來說而已卦尊,BN可以理解為對深層神經(jīng)網(wǎng)絡(luò)每個隱層神經(jīng)元的激活值做簡化版本的白化操作。
一句話:對于每個隱層神經(jīng)元舌厨,把逐漸向非線性函數(shù)映射后向取值區(qū)間極限飽和區(qū)靠攏的輸入分布強制拉回到均值為0方差為1的比較標準的正態(tài)分布岂却,使得非線性變換函數(shù)的輸入值落入對輸入比較敏感的區(qū)域,以此避免梯度消失問題裙椭。經(jīng)過BN后躏哩,目前大部分Activation的值落入非線性函數(shù)的線性區(qū)內(nèi),其對應(yīng)的導數(shù)遠離導數(shù)飽和區(qū)骇陈,這樣來加速訓練收斂過程震庭。
疑點:BN操作之后,非線性激活函數(shù)變成了和線性函數(shù)一樣的效果你雌,顯然是不行的器联,為了保證非線性的獲得,對變換后的滿足均值為0方差為1的x又進行了scale加上shift操作(y=scale*x+shift)婿崭,每個神經(jīng)元增加了兩個參數(shù)scale和shift參數(shù)拨拓,這兩個參數(shù)是通過訓練學習到的,意思是通過scale和shift把這個值從標準正態(tài)分布左移或者右移一點并長胖一點或者變瘦一點氓栈,每個實例挪動的程度不一樣渣磷,這樣等價于非線性函數(shù)的值從正中心周圍的線性區(qū)往非線性區(qū)動了動。這樣找到一個線性和非線性的較好的平衡點授瘦,既能享受非線性的較強表達能力的好處醋界,又避免太靠非線性區(qū)兩頭使得網(wǎng)絡(luò)收斂速度太慢。這里理想狀態(tài)的scale和shift操作會不會又把x變換成未變換之前的狀態(tài)提完,又回到Internal Covariate Shift問題哪里形纺?應(yīng)該不會哈哈哈,否則BN完全沒用了啊徒欣,事實證明逐样。
Inference時的BN操作:
一個實例是沒法算實例集合求出的均值和方差,既然沒有從Mini-Batch數(shù)據(jù)里可以得到的統(tǒng)計量,那就想其它辦法來獲得這個統(tǒng)計量脂新,就是均值和方差挪捕。可以用從所有訓練實例中獲得的統(tǒng)計量來代替Mini-Batch里面m個訓練實例獲得的均值和方差統(tǒng)計量,因為本來就打算用全局的統(tǒng)計量争便,只是因為計算量等太大所以才會用Mini-Batch這種簡化方式的级零,那么在推理的時候直接用全局統(tǒng)計量即可。把每個Mini-Batch的均值和方差統(tǒng)計量記住始花,然后對這些均值和方差求其對應(yīng)的數(shù)學期望即可得出全局統(tǒng)計量妄讯。設(shè)置model.eval()的一個作用就是固定BN層,不像在訓練階段去求每個mini-batch的均值方差,而是直接取出之前記錄在網(wǎng)絡(luò)里面的每個mini-batch的方差,去求期望.
個人理解
為什么bs越大越好,因為bs越大酷宵,每個bs的分布就越趨近于同分布亥贸,這樣網(wǎng)絡(luò)比較容易學習數(shù)據(jù)的分布規(guī)律,梯度更新方向比較一致浇垦,收斂更快炕置。
BN中的參數(shù)
看一個例子
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(6)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
return x
model = Net()
for name, para in model.named_parameters():
print(name, para)
print('************************************************************')
for name, buffer in model.named_buffers():
print(name, buffer)
輸出為
OrderedDict([('conv1.weight', tensor([[[[ 0.0108, 0.1240, 0.0641],
[ 0.0838, 0.0657, 0.0785],
[ 0.0755, -0.1763, -0.0934]],
[[-0.1210, -0.1455, -0.1416],
[ 0.0903, 0.0632, 0.0489],
[-0.0614, -0.1614, 0.1625]],
[[ 0.1661, -0.0992, -0.1398],
[ 0.1170, 0.1084, 0.1536],
[ 0.0179, 0.1310, -0.0289]]],
[[[ 0.1363, 0.1840, 0.1140],
[ 0.0471, 0.0555, 0.1758],
[-0.0386, 0.1077, 0.1612]],
[[ 0.1177, 0.1799, -0.0495],
[-0.0314, -0.1714, 0.1125],
[-0.0723, -0.0770, 0.1663]],
[[-0.1474, 0.0866, -0.0111],
[ 0.1476, -0.0468, -0.0683],
[ 0.0535, 0.1440, 0.1900]]],
[[[-0.0954, 0.0743, -0.0975],
[ 0.0741, 0.1436, -0.1203],
[-0.0047, 0.1317, -0.1513]],
[[-0.1422, 0.1404, 0.1614],
[ 0.0025, -0.1499, 0.1647],
[ 0.0192, 0.0324, 0.0593]],
[[-0.0041, 0.1813, -0.1696],
[ 0.0822, 0.1765, -0.1627],
[ 0.0262, 0.1857, -0.0359]]],
[[[-0.1816, -0.1198, -0.1289],
[-0.0138, 0.1118, -0.0687],
[-0.0078, -0.0975, -0.0646]],
[[ 0.1763, -0.0490, -0.1117],
[ 0.0976, -0.0156, 0.1104],
[-0.0755, 0.0067, 0.0637]],
[[-0.0131, -0.1783, 0.0628],
[ 0.1020, 0.1713, -0.0764],
[-0.1752, 0.0589, -0.0661]]],
[[[-0.0292, 0.1491, 0.1690],
[-0.1483, 0.1089, -0.1463],
[-0.1159, 0.0097, 0.1525]],
[[-0.0439, -0.0683, -0.0691],
[-0.0465, -0.0289, 0.1653],
[ 0.1307, -0.0170, -0.1875]],
[[-0.0941, 0.1616, 0.0168],
[ 0.1385, 0.1919, 0.0238],
[-0.0705, 0.1550, 0.1585]]],
[[[ 0.1091, 0.0602, -0.1886],
[ 0.0663, 0.1151, -0.1629],
[ 0.0955, -0.1370, -0.1030]],
[[-0.1690, 0.1786, 0.0723],
[-0.0280, -0.0451, -0.0303],
[-0.0342, -0.0909, -0.1883]],
[[ 0.1072, 0.1869, 0.0249],
[ 0.1028, -0.1043, 0.0852],
[-0.0532, -0.1132, -0.0372]]]])), ('conv1.bias', tensor([-0.0907, 0.1700, -0.0342, 0.1511, 0.0931, 0.0797])), ('bn1.weight', tensor([1., 1., 1., 1., 1., 1.])), ('bn1.bias', tensor([0., 0., 0., 0., 0., 0.])), ('bn1.running_mean', tensor([0., 0., 0., 0., 0., 0.])), ('bn1.running_var', tensor([1., 1., 1., 1., 1., 1.])), ('bn1.num_batches_tracked', tensor(0))])
conv1.weight Parameter containing:
tensor([[[[-0.1410, 0.0936, -0.0152],
[-0.1397, -0.1212, -0.1048],
[-0.1421, -0.0171, 0.0640]],
[[ 0.1423, -0.1203, -0.0369],
[-0.0067, 0.0966, 0.1195],
[ 0.0143, 0.0839, -0.0283]],
[[-0.1537, -0.1123, -0.1345],
[ 0.0886, 0.1017, 0.0533],
[-0.0084, -0.1251, 0.1744]]],
[[[ 0.1859, -0.1693, -0.1616],
[ 0.0567, 0.1256, 0.0887],
[-0.0761, -0.1245, -0.0764]],
[[ 0.1298, -0.1307, -0.0978],
[ 0.0780, 0.0860, -0.0598],
[-0.0295, -0.1884, 0.0191]],
[[-0.1898, -0.0489, 0.1485],
[-0.1887, -0.0618, -0.1429],
[ 0.1066, -0.0593, 0.0559]]],
[[[ 0.0189, 0.0575, 0.1358],
[-0.1079, -0.0591, -0.1221],
[ 0.0100, -0.0392, 0.0423]],
[[ 0.1072, 0.1461, -0.1267],
[-0.1478, 0.1647, 0.1149],
[ 0.0258, -0.1862, -0.0070]],
[[ 0.1138, -0.0968, 0.0016],
[-0.0955, 0.1802, 0.0822],
[-0.1311, 0.0945, -0.0038]]],
[[[ 0.1647, -0.0404, 0.0610],
[-0.1558, 0.1357, 0.1779],
[-0.0070, 0.1030, -0.0585]],
[[ 0.1592, 0.0970, 0.0614],
[-0.0068, -0.0732, 0.1352],
[ 0.0447, 0.0769, -0.0384]],
[[-0.0589, -0.0711, -0.0543],
[ 0.0926, -0.0984, -0.0573],
[ 0.0687, 0.1849, 0.0993]]],
[[[ 0.0730, 0.0036, 0.0584],
[ 0.0568, 0.0311, -0.1742],
[ 0.1582, -0.0496, -0.0620]],
[[ 0.0348, -0.1415, 0.0212],
[-0.1688, 0.0436, -0.1485],
[ 0.0154, -0.1302, 0.1255]],
[[ 0.1393, 0.0575, -0.1821],
[ 0.0244, -0.1584, 0.0886],
[-0.0158, -0.1907, -0.1038]]],
[[[ 0.0019, -0.0077, -0.0073],
[ 0.0667, 0.1904, 0.1622],
[-0.1315, 0.1265, 0.0110]],
[[ 0.0979, 0.0211, -0.1126],
[ 0.1260, 0.1614, 0.0309],
[-0.0724, -0.1381, 0.1275]],
[[-0.0206, -0.0674, -0.0358],
[-0.0800, -0.0408, 0.1636],
[ 0.0082, -0.0014, -0.0292]]]], requires_grad=True)
conv1.bias Parameter containing:
tensor([-0.0660, 0.0184, 0.0102, 0.1804, -0.0702, 0.0977],
requires_grad=True)
bn1.weight Parameter containing:
tensor([1., 1., 1., 1., 1., 1.], requires_grad=True)
bn1.bias Parameter containing:
tensor([0., 0., 0., 0., 0., 0.], requires_grad=True)
************************************************************
bn1.running_mean tensor([0., 0., 0., 0., 0., 0.])
bn1.running_var tensor([1., 1., 1., 1., 1., 1.])
bn1.num_batches_tracked tensor(0)
可以看到,網(wǎng)絡(luò)中的參數(shù)除了parameters男韧,還有一些不用更新的參數(shù)朴摊,主要是BN中的'bn1.running_mean'
和bn1.running_var
,這些參數(shù)只在forward時進行統(tǒng)計計算此虑,backward時并不會被更新甚纲,這些參數(shù)也稱為buffer,可以用model.buffers()
獲取朦前。順便提一下介杆,在進行推理時設(shè)置model.val()
,會固定這些參數(shù)韭寸,不會計算春哨,而是采用記錄的全局統(tǒng)計量,如上所述恩伺。
創(chuàng)建于2020.11.26