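Weight initialization is crucial for training neural networks: well-chosen initial weights help avoid problems such as vanishing gradients.
Below are several ways to initialize weights in PyTorch.
Note: the first method is not recommended; prefer one of the last two.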
```python
# not recommended: layers are matched by a substring of their class name
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
```
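The post does not say why the name-based version is discouraged. One plausible reason is that `find('Conv')` matches any class whose name merely contains the substring. That includes `nn.ConvTranspose2d` and any user-defined wrapper. A minimal sketch, where `ConvBlock` is a hypothetical container module with no `weight` of its own:

```python
import torch.nn as nn

def name_based_init(m):
    # Conv branch of the name-based version: matches on a class-name substring
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)

class ConvBlock(nn.Module):
    # Hypothetical wrapper: its class name contains "Conv" but it has no .weight
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

block = ConvBlock()
try:
    # apply() visits the inner Conv2d first (fine), then ConvBlock itself
    block.apply(name_based_init)
    failed = False
except AttributeError as err:
    failed = True
    print(err)  # complains that ConvBlock has no attribute 'weight'
```

An `isinstance` check only matches the layer types you list, so wrappers like this are skipped instead of crashing.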
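```python
# recommended
import torch.nn as nn

def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        m.weight.data.normal_(0, 0.02)
        if m.bias is not None:  # bias is None when the layer is built with bias=False
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.02)
        if m.bias is not None:
            m.bias.data.zero_()
```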
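The same N(0, 0.02) scheme can also be written with `torch.nn.init` instead of writing to `.data` directly. The `nn.init` functions modify the parameters in place without recording the operation for autograd. This is a sketch. The toy `nn.Sequential` model is hypothetical and assumes 32x32 RGB input.

```python
import torch
import torch.nn as nn

def initialize_weights(m):
    # N(0, 0.02) weights and zero bias for Conv2d and Linear layers
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:  # skip layers built with bias=False
            nn.init.zeros_(m.bias)

torch.manual_seed(0)
# Hypothetical toy model (a 32x32 input becomes 30x30 after the 3x3 conv)
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, bias=False),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 30 * 30, 10),
)
model.apply(initialize_weights)

# The sample std of the Linear weights should be close to 0.02
print(model[3].weight.std().item())
```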
```python
# recommended
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            # Xavier needs a tensor with at least 2 dims, so the 1-D bias is zeroed instead
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
```
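Xavier (Glorot) normal initialization draws weights from N(0, std²) with std = gain · sqrt(2 / (fan_in + fan_out)). For a convolution, both fans include the kernel area. Because the fans are computed from the tensor's shape, the method needs at least two dimensions and rejects a 1-D bias. The sketch below checks both points on an arbitrarily sized `Conv2d`, using the default gain of 1.0.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(64, 128, kernel_size=3)

# Xavier/Glorot normal with the default gain of 1.0
nn.init.xavier_normal_(conv.weight)

# For Conv2d, both fans include the kernel area (3 * 3)
fan_in = 64 * 3 * 3    # in_channels * kh * kw
fan_out = 128 * 3 * 3  # out_channels * kh * kw
expected_std = math.sqrt(2.0 / (fan_in + fan_out))
print(expected_std, conv.weight.std().item())  # the two values should be close

# A 1-D tensor such as the bias has no fan_in/fan_out
try:
    nn.init.xavier_normal_(conv.bias)
    bias_error = False
except ValueError:
    bias_error = True
nn.init.zeros_(conv.bias)
```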
Once weights_init is written, call the model's apply method to initialize the model's weights.
```python
net = Residual()         # create an instance of the network (Residual is the author's model class, defined elsewhere)
net.apply(weights_init)  # run weights_init on every submodule
```
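`Module.apply` calls the function on every submodule recursively (children first, then the module itself). A single call therefore reaches layers nested inside blocks. Below is a self-contained sketch. `SmallNet` is a hypothetical stand-in for `Residual`, which the post does not define. Layers that `weights_init` does not match, such as the `Linear` here, keep PyTorch's default initialization.

```python
import torch
import torch.nn as nn

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

class SmallNet(nn.Module):
    """Hypothetical model used only for illustration."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.head = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4))

    def forward(self, x):
        return self.head(self.features(x))

torch.manual_seed(0)
net = SmallNet()
with torch.no_grad():
    net.features[1].weight.fill_(5.0)  # perturb a nested BN layer so the reset is visible
net.apply(weights_init)                # reaches layers inside both Sequential blocks
out = net(torch.randn(2, 3, 16, 16))
print(out.shape)  # torch.Size([2, 4])
```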