介紹
歸一化層裆熙,主要有這幾種方法枫攀,BatchNorm(2015年)、LayerNorm(2016年)囊咏、InstanceNorm(2016年)恕洲、GroupNorm(2018年);
將輸入的圖像shape記為[N,C,H,W]梅割,這幾個方法主要區(qū)別是:
BatchNorm:batch方向做歸一化霜第,計算NHW的均值,對小batchsize效果不好户辞;
(BN主要缺點是對batchsize的大小比較敏感泌类,由于每次計算均值和方差是在一個batch上,所以如果batchsize太小底燎,則計算的均值刃榨、方差不足以代表整個數(shù)據(jù)分布)
LayerNorm:channel方向做歸一化,計算CHW的均值双仍;
(對RNN作用明顯)
InstanceNorm:一個batch枢希,一個channel內(nèi)做歸一化。計算HW的均值朱沃,用在風(fēng)格化遷移苞轿;
(因為在圖像風(fēng)格化中,生成結(jié)果主要依賴于某個圖像實例为流,所以對整個batch歸一化不適合圖像風(fēng)格化中呕屎,因而對HW做歸一化【床欤可以加速模型收斂,并且保持每個圖像實例之間的獨立尔当。)
GroupNorm:將channel方向分group莲祸,然后每個group內(nèi)做歸一化,算(C//G)HW的均值椭迎;這樣與batchsize無關(guān)锐帜,不受其約束。
1. BatchNorm詳解
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
參數(shù):
num_features:輸入的特征數(shù)畜号,該期望輸入的大小為’N x C [x L]’
eps: 為保證數(shù)值穩(wěn)定性(分母不能趨近或取0),給分母加上的值缴阎。默認(rèn)為1e-5。
momentum: 動態(tài)均值和動態(tài)方差所使用的動量简软。默認(rèn)為0.1蛮拔。
affine: 布爾值述暂,當(dāng)設(shè)為true,給該層添加可學(xué)習(xí)的仿射變換參數(shù)建炫。
track_running_stats:布爾值畦韭,當(dāng)設(shè)為true,記錄訓(xùn)練過程中的均值和方差肛跌;
實現(xiàn)公式:
# 示例代碼
import torch
import torch.nn as nn
"""
BatchNorm1d(時域)
Input: (N, C) or (N, C, L)
Output: (N, C) or (N, C, L)(same shape as input)
"""
# input = torch.randn(2, 10, 100)
input = torch.randn(2, 10)
# with learnable parameters
m1 = nn.BatchNorm1d(10)
# without learnable parameters
m2 = nn.BatchNorm1d(10, affine=False)
output1 = m1(input)
print(output1.shape)
output2 = m2(input)
print(output2.shape)
"""
BatchNorm2d(空域)
Input: (N, C, H, W)
Output: (N, C, H, W)(same shape as input)
"""
input = torch.randn(2, 10, 35, 45)
# with learnable parameters
m1 = nn.BatchNorm2d(10)
# without learnable parameters
m2 = nn.BatchNorm2d(10)
output1 = m1(input)
print(output1.shape)
output2 = m2(input)
print(output2.shape)
"""
BatchNorm3d(時空域)
Input: (N, C, D, H, W)
Output: (N, C, D, H, W)(same shape as input)
"""
input = torch.randn(2, 10, 20, 35, 45)
# with leanable parameters
m1 = nn.BatchNorm3d(10)
# without learnable parameters
m2 = nn.BatchNorm3d(10)
output1 = m1(input)
print(output1.shape)
output2 = m2(input)
print(output2.shape)
# 結(jié)果
torch.Size([2, 10])
torch.Size([2, 10])
torch.Size([2, 10, 35, 45])
torch.Size([2, 10, 35, 45])
torch.Size([2, 10, 20, 35, 45])
torch.Size([2, 10, 20, 35, 45])
2. GroupNorm詳解
torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True)
參數(shù):
num_groups:需要劃分的groups
num_features:輸入的特征數(shù)艺配,輸入的大小為’N x C x *’
eps: 為保證數(shù)值穩(wěn)定性(分母不能趨近或取0),給分母加上的值。默認(rèn)為1e-5
momentum: 動態(tài)均值和動態(tài)方差所使用的動量衍慎。默認(rèn)為0.1
affine: 布爾值转唉,當(dāng)設(shè)為true,給該層添加可學(xué)習(xí)的仿射變換參數(shù)
實現(xiàn)公式:
# 示例代碼
"""
GroupNorm
Input: (N, C, *)where C=num_channels
Output: (N, C, *)(same shape as input)
"""
input = torch.randn(2, 6, 10, 10)
# separate 6 channels into 3 groups
m1 = nn.GroupNorm(3, 6)
# Separate 6 channels into 6 groups (equivalent with InstanceNorm)
m2 = nn.GroupNorm(6, 6)
# Put all 6 channels into a single group (equivalent with LayerNorm)
m3 = nn.GroupNorm(1, 6)
output1 = m1(input)
print(output1.shape)
output2 = m2(input)
print(output2.shape)
output3 = m3(input)
print(output3.shape)
# 結(jié)果
torch.Size([2, 6, 10, 10])
torch.Size([2, 6, 10, 10])
torch.Size([2, 6, 10, 10])
3. InstanceNorm詳解
torch.nn.InstanceNorm1d(num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
torch.nn.InstanceNorm2d(num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
torch.nn.InstanceNorm3d(num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
參數(shù):
num_features:輸入的特征數(shù)稳捆,輸入的大小為’N x C [x L]’
eps:為保證數(shù)值穩(wěn)定性(分母不能趨近或取0),給分母加上的值赠法。默認(rèn)為1e-5
momentum: 動態(tài)均值和動態(tài)方差所使用的動量。默認(rèn)為0.1
affine: 布爾值眷柔,當(dāng)設(shè)為true期虾,給該層添加可學(xué)習(xí)的仿射變換參數(shù)
track_running_stats:布爾值,當(dāng)設(shè)為true驯嘱,記錄訓(xùn)練過程中的均值和方差镶苞;
實現(xiàn)公式:
# 示例代碼
"""
InstanceNorm1d
Input: (N, C, L)
Output: (N, C, L)(same shape as input)
"""
input = torch.randn(20, 100, 40)
# without learnable parameters
m1 = nn.InstanceNorm1d(100)
# with learnable parameters
m2 = nn.InstanceNorm1d(100, affine=True)
output1 = m1(input)
print(output1.shape)
output2 = m2(input)
print(output2.shape)
"""
InstanceNorm2d
Input: (N, C, H, W)
Output: (N, C, H, W)(same shape as input)
"""
input = torch.randn(20, 100, 35, 45)
# without learnable parameters
m1 = nn.InstanceNorm2d(100)
# with learnable parameters
m2 = nn.InstanceNorm2d(100, affine=True)
output1 = m1(input)
print(output1.shape)
output2 = m2(input)
print(output2.shape)
"""
InstanceNorm3d
Input: (N, C, D, H, W)
Output: (N, C, D, H, W)(same shape as input)
"""
input = torch.randn(20, 100, 35, 45)
# without learnable parameters
m1 = nn.InstanceNorm2d(100)
# with learnable parameters
m2 = nn.InstanceNorm2d(100, affine=True)
output1 = m1(input)
print(output1.shape)
output2 = m2(input)
print(output2.shape)
# 結(jié)果
torch.Size([20, 100, 40])
torch.Size([20, 100, 40])
torch.Size([20, 100, 35, 45])
torch.Size([20, 100, 35, 45])
torch.Size([20, 100, 35, 45])
torch.Size([20, 100, 35, 45])
4. LayerNorm詳解
torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True)
參數(shù):
normalized_shape:輸入尺寸
[?×normalized_shape[0]×normalized_shape[1]×…×normalized_shape[?1]]
eps:為保證數(shù)值穩(wěn)定性(分母不能趨近或取0),給分母加上的值。默認(rèn)為1e-5鞠评。
elementwise_affine:布爾值茂蚓,當(dāng)設(shè)為true,給該層添加可學(xué)習(xí)的仿射變換參數(shù)
實現(xiàn)公式:
# 示例代碼
"""
LayerNorm
Input: (N, *)
Output: (N, *)(same shape as input)
"""
input = torch.randn(20, 5, 10, 10)
# with learnable parameters
m1 = nn.LayerNorm(input.size()[1:])
# without learnable parameters
m2 = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
# normalize over last two dimensions
m3 = nn.LayerNorm([10, 10])
# normalize over last dimension of size 10
m4 = nn.LayerNorm(10)
output1 = m1(input)
print(output1.shape)
output2 = m2(input)
print(output2.shape)
output3 = m3(input)
print(output3.shape)
output4 = m4(input)
print(output4.shape)
# 結(jié)果
torch.Size([20, 5, 10, 10])
torch.Size([20, 5, 10, 10])
torch.Size([20, 5, 10, 10])
torch.Size([20, 5, 10, 10])
論文鏈接
- BatchNorm
https://arxiv.org/pdf/1502.03167.pdf- LayerNorm
https://arxiv.org/pdf/1607.06450v1.pdf- InstanceNorm
https://arxiv.org/pdf/1607.08022.pdf- GroupNorm
https://arxiv.org/pdf/1803.08494.pdf- SwitchableNorm
https://arxiv.org/pdf/1806.10779.pdf