參考鏈接: input()函數中的漏洞– Python2.x
前言?
最近在研究深度學習中圖像數據處理的細節(jié)宾巍,基于的平臺是PyTorch锡搜。心血來潮栅贴,總結一下汹买,好記性不如爛筆頭。?
Batch Normalization?
對于2015年出現的Batch Normalization1,2018年的文章Group Normalization2在Abstract中總結得言簡意賅,我直接copy過來端仰。?
?Batch Normalization (BN) is a milestone technique in the development of deep learning, enabling various networks to train. However, normalizing along the batch dimension introduces problems — BN’s error increases rapidly when the batch size becomes smaller, caused by inaccurate batch statistics estimation. This limits BN’s usage for training larger models and transferring features to computer vision tasks including detection, segmentation, and video, which require small batches constrained by memory consumption.?
機器學習中,進行模型訓練之前田藐,需對數據做歸一化處理荔烧,使其分布一致。在深度神經網絡訓練過程中汽久,通常一次訓練是一個batch鹤竭,而非全體數據。每個batch具有不同的分布產生了internal covarivate shift問題——在訓練過程中景醇,數據分布會發(fā)生變化臀稚,對下一層網絡的學習帶來困難。Batch Normalization強行將數據拉回到均值為0三痰,方差為1的正太分布上吧寺,一方面使得數據分布一致,另一方面避免梯度消失酒觅。?
結合圖1撮执,說明Batch Normalization的原理。假設在網絡中間經過某些卷積操作之后的輸出的feature maps的尺寸為N×C×W×H舷丹,5為batch size(N)抒钱,3為channel(C),W×H為feature map的寬高颜凯,則Batch Normalization的計算過程如下谋币。??
?圖 1
1.每個batch計算同一通道的均值
? ? ? ? ?μ
? ? ? ? \mu
? ? ?μ,如圖取channel 0症概,即
? ? ? ? ?c
? ? ? ? ?=
? ? ? ? ?0
? ? ? ? c=0
? ? ?c=0(紅色表示)?
? ? ? ? ? μ
? ? ? ? ? =
? ? ? ? ? ? ?∑
? ? ? ? ? ? ? n
? ? ? ? ? ? ? =
? ? ? ? ? ? ? 0
? ? ? ? ? ? ? N
? ? ? ? ? ? ? ?
? ? ? ? ? ? ? 1
? ? ? ? ? ? ?∑
? ? ? ? ? ? ? w
? ? ? ? ? ? ? =
? ? ? ? ? ? ? 0
? ? ? ? ? ? ? W
? ? ? ? ? ? ? ?
? ? ? ? ? ? ? 1
? ? ? ? ? ? ?∑
? ? ? ? ? ? ? h
? ? ? ? ? ? ? =
? ? ? ? ? ? ? 0
? ? ? ? ? ? ? H
? ? ? ? ? ? ? ?
? ? ? ? ? ? ? 1
? ? ? ? ? ? X
? ? ? ? ? ? [
? ? ? ? ? ? n
? ? ? ? ? ? ,
? ? ? ? ? ? c
? ? ? ? ? ? ,
? ? ? ? ? ? w
? ? ? ? ? ? ,
? ? ? ? ? ? h
? ? ? ? ? ? ]
? ? ? ? ? ? N
? ? ? ? ? ? ×
? ? ? ? ? ? W
? ? ? ? ? ? ×
? ? ? ? ? ? H
? ? ? ? ?\mu = \frac{\sum\limits_{n=0}^{N-1}\sum\limits_{w=0}^{W-1} \sum\limits_{h=0}^{H-1} X[n, c, w, h]}{N×W×H}
? ? ? μ=N×W×Hn=0∑N?1?w=0∑W?1?h=0∑H?1?X[n,c,w,h]?2.每個batch計算同一通道的方差
? ? ? ? ? σ
? ? ? ? ? 2
? ? ? ? σ^2
? ? ?σ2?
? ? ? ? ? ?σ
? ? ? ? ? ?2
? ? ? ? ? =
? ? ? ? ? ? ?∑
? ? ? ? ? ? ? n
? ? ? ? ? ? ? =
? ? ? ? ? ? ? 0
? ? ? ? ? ? ? N
? ? ? ? ? ? ? ?
? ? ? ? ? ? ? 1
? ? ? ? ? ? ?∑
? ? ? ? ? ? ? w
? ? ? ? ? ? ? =
? ? ? ? ? ? ? 0
? ? ? ? ? ? ? W
? ? ? ? ? ? ? ?
? ? ? ? ? ? ? 1
? ? ? ? ? ? ?∑
? ? ? ? ? ? ? h
? ? ? ? ? ? ? =
? ? ? ? ? ? ? 0
? ? ? ? ? ? ? H
? ? ? ? ? ? ? ?
? ? ? ? ? ? ? 1
? ? ? ? ? ? (
? ? ? ? ? ? X
? ? ? ? ? ? [
? ? ? ? ? ? n
? ? ? ? ? ? ,
? ? ? ? ? ? c
? ? ? ? ? ? ,
? ? ? ? ? ? w
? ? ? ? ? ? ,
? ? ? ? ? ? h
? ? ? ? ? ? ]
? ? ? ? ? ? ?
? ? ? ? ? ? μ
? ? ? ? ? ? ?)
? ? ? ? ? ? ?2
? ? ? ? ? ? N
? ? ? ? ? ? ×
? ? ? ? ? ? W
? ? ? ? ? ? ×
? ? ? ? ? ? H
? ? ? ? ?σ^2 = \frac{\sum\limits_{n=0}^{N-1}\sum\limits_{w=0}^{W-1} \sum\limits_{h=0}^{H-1} (X[n, c, w, h]-\mu)^2}{N×W×H}
? ? ? σ2=N×W×Hn=0∑N?1?w=0∑W?1?h=0∑H?1?(X[n,c,w,h]?μ)2?3.對當前channel下feature map中每個點
? ? ? ? ?x
? ? ? ? x
? ? ?x蕾额,索引形式
? ? ? ? ?X
? ? ? ? ?[
? ? ? ? ?n
? ? ? ? ?,
? ? ? ? ?c
? ? ? ? ?,
? ? ? ? ?w
? ? ? ? ?,
? ? ? ? ?h
? ? ? ? ?]
? ? ? ? X[n, c, w, h]
? ? ?X[n,c,w,h],做歸一化?
? ? ? ? ? ?x
? ? ? ? ? ? ′
? ? ? ? ? =
? ? ? ? ? ? (
? ? ? ? ? ? x
? ? ? ? ? ? ?
? ? ? ? ? ? μ
? ? ? ? ? ? )
? ? ? ? ? ? ? σ
? ? ? ? ? ? ? 2
? ? ? ? ? ? ?+
? ? ? ? ? ? ??
? ? ? ? ?x^{'}=\frac{(x-\mu)}{\sqrt{σ^2+\epsilon}}
? ? ? x′=σ2+?
? ? ? ? ? ? ? ? ? ? ?(x?μ)?4.增加縮放和平移變量
? ? ? ? ?γ
? ? ? ? \gamma
? ? ?γ和
? ? ? ? ?β
? ? ? ? \beta
? ? ?β(可學習的仿射變換參數)彼城,歸一化后的值?
? ? ? ? ? y
? ? ? ? ? =
? ? ? ? ? γ
? ? ? ? ? ?x
? ? ? ? ? ? ′
? ? ? ? ? +
? ? ? ? ? β
? ? ? ? ?y=\gamma x^{'}+\beta
? ? ? y=γx′+β 簡化公式:?
? ? ? ? ? y
? ? ? ? ? =
? ? ? ? ? ? x
? ? ? ? ? ? ?
? ? ? ? ? ? μ
? ? ? ? ? ? ? σ
? ? ? ? ? ? ? 2
? ? ? ? ? ? ?+
? ? ? ? ? ? ??
? ? ? ? ? γ
? ? ? ? ? +
? ? ? ? ? β
? ? ? ? ?y=\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}}\gamma +\beta
? ? ? y=σ2+?
? ? ? ? ? ? ? ? ? ? ?x?μ?γ+β 原文中的算法描述如下诅蝶,? 注:上圖1所示
? ? ? ? ? m
? ? ? ? ?m
? ? ? m就是
? ? ? ? ? N
? ? ? ? ? ?
? ? ? ? ? W
? ? ? ? ? ?
? ? ? ? ? H
? ? ? ? ?N*W*H
? ? ? N?W?H?
PyTorch的nn.BatchNorm2d()函數?
理解了Batch Normalization的過程退个,PyTorch里面的函數就參考其文檔3用就好。 BatchNorm2d()內部的參數如下:?
num_features:一般情況下輸入的數據格式為batch_size * num_features * height * width调炬,即為特征數语盈,channel數eps:分母中添加的一個值,目的是為了計算的穩(wěn)定性缰泡,默認:1e-5momentum:一個用于運行過程中均值和方差的一個估計參數刀荒,默認值為
? ? ? ? ?0.1
? ? ? ? 0.1
? ? ?0.1;
? ? ? ? ? ?x
? ? ? ? ? ?^
? ? ? ? ? ?n
? ? ? ? ? ?e
? ? ? ? ? ?w
? ? ? ? ?=
? ? ? ? ?(
? ? ? ? ?1
? ? ? ? ??
? ? ? ? ?m
? ? ? ? ?o
? ? ? ? ?m
? ? ? ? ?e
? ? ? ? ?n
? ? ? ? ?t
? ? ? ? ?u
? ? ? ? ?m
? ? ? ? ?)
? ? ? ? ?×
? ? ? ? ? x
? ? ? ? ? ^
? ? ? ? ?+
? ? ? ? ?m
? ? ? ? ?o
? ? ? ? ?m
? ? ? ? ?e
? ? ? ? ?n
? ? ? ? ?t
? ? ? ? ?u
? ? ? ? ?m
? ? ? ? ?×
? ? ? ? ? x
? ? ? ? ? t
? ? ? ? \hat{x}_{new} =(1?momentum) × \hat{x} +momentum×x_t
? ? ?x^new?=(1?momentum)×x^+momentum×xt?棘钞,其中
? ? ? ? ? x
? ? ? ? ? ^
? ? ? ? \hat{x}
? ? ?x^是估計值缠借,
? ? ? ? ? x
? ? ? ? ? t
? ? ? ? x_t
? ? ?xt?是新的觀測值affine:當設為true時,給定可以學習的系數矩陣
? ? ? ? ?γ
? ? ? ? \gamma
? ? ?γ和
? ? ? ? ?β
? ? ? ? \beta
? ? ?β?
Show me the codes?
import torch
import torch.nn as nn
def checkBN(debug = False):
? ? # parameters
? ? N = 5 # batch size
? ? C = 3 # channel
? ? W = 2 # width of feature map
? ? H = 2 # height of feature map
? ? # batch normalization layer
? ? BN = nn.BatchNorm2d(C,affine=True) #gamma和beta, 其維度與channel數相同
? ? # input and output
? ? featuremaps = torch.randn(N,C,W,H)
? ? output = BN(featuremaps)
? ? # checkout
? ? ###########################################
? ? if debug:
? ? ? ? print("input feature maps:\n",featuremaps)
? ? ? ? print("normalized feature maps: \n",output)
? ? ###########################################
? ? # manually operation, the first channel
? ? X = featuremaps[:,0,:,:]
? ? firstDimenMean = torch.Tensor.mean(X)
? ? firstDimenVar = torch.Tensor.var(X,False) #Bessel's Correction貝塞爾校正不被使用
? ? BN_one = ((input[0,0,0,0] - firstDimenMean)/(torch.pow(firstDimenVar+BN.eps,0.5) )) * BN.weight[0] + BN.bias[0]
? ? print('+++'*15,'\n','manually operation: ', BN_one)
? ? print('==='*15,'\n','pytorch result: ', output[0,0,0,0])
if __name__=="__main__":
? ? checkBN()
可以看出手算的結果和PyTorch的nn.BatchNorm2d的計算結果一致宜猜。?
+++++++++++++++++++++++++++++++++++++++++++++
?manually operation:? tensor(-0.0327, grad_fn=<AddBackward0>)
=============================================
?pytorch result:? tensor(-0.0327, grad_fn=<SelectBackward>)
貝塞爾校正?
代碼中出現泼返,求方差時是否需要貝塞爾校正,即從樣本方差到總體方差的校正姨拥。 方差公式從符隙,?
? ? ? ? ? σ
? ? ? ? ? 2
? ? ? ? ?=
? ? ? ? ? ? ∑
? ? ? ? ? ? ?i
? ? ? ? ? ? ?=
? ? ? ? ? ? ?0
? ? ? ? ? ? ?N
? ? ? ? ? ? ??
? ? ? ? ? ? ?1
? ? ? ? ? ?(
? ? ? ? ? ? x
? ? ? ? ? ? i
? ? ? ? ? ??
? ? ? ? ? ?m
? ? ? ? ? ?e
? ? ? ? ? ?a
? ? ? ? ? ?n
? ? ? ? ? ?(
? ? ? ? ? ?x
? ? ? ? ? ?)
? ? ? ? ? ? )
? ? ? ? ? ? 2
? ? ? ? ? N
? ? ? ? \sigma^2 = \frac{\sum\limits_{i=0}^{N-1} (x_i-mean(x))^2}{N}
? ? ?σ2=Ni=0∑N?1?(xi??mean(x))2? 變成(基于樣本的總體方差的有偏估計),?
? ? ? ? ? σ
? ? ? ? ? 2
? ? ? ? ?=
? ? ? ? ? ? ∑
? ? ? ? ? ? ?i
? ? ? ? ? ? ?=
? ? ? ? ? ? ?0
? ? ? ? ? ? ?N
? ? ? ? ? ? ??
? ? ? ? ? ? ?1
? ? ? ? ? ?(
? ? ? ? ? ? x
? ? ? ? ? ? i
? ? ? ? ? ??
? ? ? ? ? ?m
? ? ? ? ? ?e
? ? ? ? ? ?a
? ? ? ? ? ?n
? ? ? ? ? ?(
? ? ? ? ? ?x
? ? ? ? ? ?)
? ? ? ? ? ? )
? ? ? ? ? ? 2
? ? ? ? ? ?N
? ? ? ? ? ??
? ? ? ? ? ?1
? ? ? ? \sigma^2 = \frac{\sum\limits_{i=0}^{N-1} (x_i-mean(x))^2}{N-1}
? ? ?σ2=N?1i=0∑N?1?(xi??mean(x))2??
Reference?
?Ioffe, Sergey, and Christian Szegedy. “Batch normalization: Accelerating deep network training by reducing internal covariate shift.” arXiv preprint arXiv:1502.03167 (2015). ?? ?? Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European Conference on Computer Vision (ECCV). 2018. ?? BatchNorm2d ??