Batch Normalization (hereafter BN) was proposed in the paper behind Google's Inception Net V2. The method eases much of the headache caused by the tricky problem of how to properly initialize a neural network.
Another blog post discusses Batch Normalization mainly from the angles of why it is needed, how it is performed, and what it actually does; the two posts can be read together to build a fuller understanding of Batch Normalization.
1. Principle
BN is a very effective regularization method that can speed up the training of large convolutional networks many times over, while also noticeably improving the classification accuracy after convergence. When BN is applied to a layer of a neural network, it standardizes the data inside each mini-batch so that the layer's output follows an N(0, 1) normal distribution, reducing Internal Covariate Shift (the changing distribution of internal activations). The BN paper points out that when training a traditional deep neural network, the input distribution of every layer keeps shifting, which makes training difficult and forces the use of a very small learning rate. Applying BN to each layer effectively removes this obstacle.
2. Practical Details
At the implementation level, applying this technique usually means inserting a BN layer between the fully connected (or convolutional) layer and the activation function, transforming the data so that it follows a standard Gaussian distribution. Because normalization is a simple, differentiable operation, this approach is feasible.
fully connected layer fc / convolutional layer conv ---> Batch Normalization ---> activation function
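As a minimal sketch of this ordering (the function name fc_bn_relu_forward and the weight names W and b are illustrative only; batchnorm_forward is the NumPy implementation given in Section 4):

```python
import numpy as np

def fc_bn_relu_forward(x, W, b, gamma, beta, bn_param):
    """One block of: fully connected -> batch normalization -> ReLU."""
    a = x.dot(W) + b                                              # fc (or conv) output
    a_bn, bn_cache = batchnorm_forward(a, gamma, beta, bn_param)  # normalize, then scale/shift
    out = np.maximum(0, a_bn)                                     # activation function (ReLU)
    return out, bn_cache
```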
Simply using BN does not by itself yield an obvious gain; a few accompanying adjustments are also needed:
- Increase the learning rate and speed up the learning-rate decay to suit the data normalized by BN;
- Remove Dropout and reduce L2 regularization (because BN already provides a regularizing effect);
- Shuffle the training samples more thoroughly, and reduce the photometric distortions applied during data augmentation (because BN trains faster, each sample is seen fewer times, so more realistic samples are more helpful for training).
3. Formula Derivation
Forward pass
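For reference, over a mini-batch of size $m$ the forward pass computes the mini-batch mean and variance, normalizes the data, and then scales and shifts it with the learnable parameters $\gamma$ and $\beta$ (these are the standard equations from the BN paper, and they correspond line by line to the code in Section 4):

$$
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i,\qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i-\mu_B)^2
$$

$$
\hat{x}_i = \frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}},\qquad
y_i = \gamma\,\hat{x}_i + \beta
$$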
Backward pass
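Applying the chain rule to the forward pass gives the gradients below (again in the standard form from the BN paper; the code in Section 4 computes the same quantities through the intermediate variables xc, std, and xn):

$$
\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\,\gamma,\qquad
\frac{\partial L}{\partial \sigma_B^2} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{x}_i}\,(x_i-\mu_B)\cdot\left(-\tfrac{1}{2}\right)(\sigma_B^2+\epsilon)^{-3/2}
$$

$$
\frac{\partial L}{\partial \mu_B} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma_B^2+\epsilon}} \;+\; \frac{\partial L}{\partial \sigma_B^2}\cdot\frac{1}{m}\sum_{i=1}^{m}\bigl(-2(x_i-\mu_B)\bigr)
$$

$$
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma_B^2+\epsilon}} \;+\; \frac{\partial L}{\partial \sigma_B^2}\cdot\frac{2(x_i-\mu_B)}{m} \;+\; \frac{\partial L}{\partial \mu_B}\cdot\frac{1}{m}
$$

$$
\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m}\frac{\partial L}{\partial y_i}\,\hat{x}_i,\qquad
\frac{\partial L}{\partial \beta} = \sum_{i=1}^{m}\frac{\partial L}{\partial y_i}
$$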
四痊夭、代碼實(shí)現(xiàn)
Forward pass
import numpy as np


def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.

    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running mean of the
    mean and variance of each feature, and these averages are used to normalize
    data at test-time.

    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:

        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sample_var

    Note that the batch normalization paper suggests a different test-time
    behavior: they compute sample mean and variance for each feature using a
    large number of training images rather than using a running average. For
    this implementation we have chosen to use running averages instead since
    they do not require an additional estimation step; the torch7 implementation
    of batch normalization also uses running averages.

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance.
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var: Array of shape (D,) giving running variance of features

    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    out, cache = None, None
    if mode == 'train':
        # Compute output
        mu = x.mean(axis=0)                 # per-feature mini-batch mean
        xc = x - mu                         # centered data
        var = np.mean(xc ** 2, axis=0)      # per-feature (uncorrected) variance
        std = np.sqrt(var + eps)
        xn = xc / std                       # normalized data, approximately N(0, 1)
        out = gamma * xn + beta             # scale and shift
        cache = (mode, x, gamma, xc, std, xn, out)

        # Update running average of mean
        running_mean *= momentum
        running_mean += (1 - momentum) * mu

        # Update running average of variance
        running_var *= momentum
        running_var += (1 - momentum) * var
    elif mode == 'test':
        # Using running mean and variance to normalize
        std = np.sqrt(running_var + eps)
        xn = (x - running_mean) / std
        out = gamma * xn + beta
        cache = (mode, x, xn, gamma, beta, std)
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running means back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache
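A quick sanity check (the shapes and data below are purely illustrative): with gamma = 1 and beta = 0, the training-mode output of every feature should have roughly zero mean and unit standard deviation, and the running statistics accumulated in bn_param are what 'test' mode later relies on.

```python
np.random.seed(0)
x = 10 * np.random.randn(200, 3) + 5          # N=200 samples, D=3 features
gamma, beta = np.ones(3), np.zeros(3)
bn_param = {'mode': 'train'}

out, cache = batchnorm_forward(x, gamma, beta, bn_param)
print(out.mean(axis=0))   # roughly [0, 0, 0]
print(out.std(axis=0))    # roughly [1, 1, 1]

# Test mode reuses the running mean/variance stored in bn_param.
# (After a single training batch these running averages are still far from the
# true statistics, so this only demonstrates the API, not well-calibrated output.)
bn_param['mode'] = 'test'
out_test, _ = batchnorm_forward(10 * np.random.randn(50, 3) + 5, gamma, beta, bn_param)
```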
Backward pass
def batchnorm_backward(dout, cache):
    """
    Backward pass for batch normalization.

    For this implementation, you should write out a computation graph for
    batch normalization on paper and propagate gradients backward through
    intermediate nodes.

    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Variable of intermediates from batchnorm_forward.

    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    mode = cache[0]
    if mode == 'train':
        mode, x, gamma, xc, std, xn, out = cache

        N = x.shape[0]
        dbeta = dout.sum(axis=0)                          # out = gamma * xn + beta
        dgamma = np.sum(xn * dout, axis=0)
        dxn = gamma * dout                                # gradient w.r.t. normalized data
        dxc = dxn / std                                   # direct path through xn = xc / std
        dstd = -np.sum((dxn * xc) / (std * std), axis=0)  # path through std in the denominator
        dvar = 0.5 * dstd / std                           # std = sqrt(var + eps)
        dxc += (2.0 / N) * xc * dvar                      # var = mean(xc ** 2)
        dmu = np.sum(dxc, axis=0)                         # total gradient through the centered data
        dx = dxc - dmu / N                                # each x_i also shifts mu by 1/N
    elif mode == 'test':
        mode, x, xn, gamma, beta, std = cache

        dbeta = dout.sum(axis=0)
        dgamma = np.sum(xn * dout, axis=0)
        dxn = gamma * dout
        dx = dxn / std
    else:
        raise ValueError(mode)

    return dx, dgamma, dbeta
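One way to convince yourself that the backward pass is correct is a numerical gradient check. The finite-difference helper below is written here purely for illustration (it is not part of the implementation above); it compares the analytic dgamma against a centered difference of the forward pass:

```python
def eval_numerical_gradient_array(f, x, df, h=1e-5):
    """Centered finite-difference gradient of sum(f(x) * df) with respect to x."""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        ix = it.multi_index
        old = x[ix]
        x[ix] = old + h
        pos = f(x).copy()
        x[ix] = old - h
        neg = f(x).copy()
        x[ix] = old
        grad[ix] = np.sum((pos - neg) * df) / (2 * h)
        it.iternext()
    return grad

np.random.seed(1)
N, D = 4, 5
x = np.random.randn(N, D)
gamma, beta = np.random.randn(D), np.random.randn(D)
dout = np.random.randn(N, D)

out, cache = batchnorm_forward(x, gamma, beta, {'mode': 'train'})
dx, dgamma, dbeta = batchnorm_backward(dout, cache)

fg = lambda g: batchnorm_forward(x, g, beta, {'mode': 'train'})[0]
num_dgamma = eval_numerical_gradient_array(fg, gamma, dout)
print(np.max(np.abs(dgamma - num_dgamma)))   # should be tiny, roughly 1e-10 or smaller
```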
In practice, networks that use batch normalization are significantly more robust to bad initialization. In summary, batch normalization can be understood as performing preprocessing before every layer of the network, except that this operation is integrated into the network itself in a differentiable way.