PyTorch中定義模型時蚊惯,有時候會遇到self.register_buffer('name', Tensor)的操作魂仍,該方法的作用是定義一組參數(shù),該組參數(shù)的特別之處在于:模型訓(xùn)練時不會更新(即調(diào)用 optimizer.step() 后該組參數(shù)不會變化拣挪,只可人為地改變它們的值)擦酌,但是保存模型時,該組參數(shù)又作為模型參數(shù)不可或缺的一部分被保存菠劝。
為了更好地理解這句話赊舶,按照慣例,我們通過一個例子實(shí)驗(yàn)來解釋:
首先,定義一個模型并實(shí)例化:
import torch
import torch.nn as nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# (1)常見定義模型時的操作
self.param_nn = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(1, 1, 3, bias=False)),
('fc', nn.Linear(1, 2, bias=False))
]))
# (2)使用register_buffer()定義一組參數(shù)
self.register_buffer('param_buf', torch.randn(1, 2))
# (3)使用形式類似的register_parameter()定義一組參數(shù)
self.register_parameter('param_reg', nn.Parameter(torch.randn(1, 2)))
# (4)按照類的屬性形式定義一組變量
self.param_attr = torch.randn(1, 2)
def forward(self, x):
return x
net = Model()
上例中笼平,我們通過繼承nn.Module類定義了一個模型园骆,在模型參數(shù)的定義中,我們分別以(1)常見的nn.Module類形式寓调、(2)self.register_buffer()形式锌唾、(3)self.register_parameter()形式,以及(4)python類的屬性形式定義了4組參數(shù)夺英。
(1)哪些參數(shù)可以在模型訓(xùn)練時被更新晌涕?
這可以通過net.parameters()查看,因?yàn)槎x優(yōu)化器時是這樣的:optimizer = SGD(net.parameters(), lr=0.1)痛悯。為了方便查看余黎,我們使用 net.named_parameters():
In [8]: list(net.named_parameters())
Out[8]:
[('param_reg',
Parameter containing:
tensor([[-0.0617, -0.8984]], requires_grad=True)),
('param_nn.conv.weight',
Parameter containing:
tensor([[[[-0.3183, -0.0426, -0.2984],
[-0.1451, 0.2686, 0.0556],
[-0.3155, 0.0451, 0.0702]]]], requires_grad=True)),
('param_nn.fc.weight',
Parameter containing:
tensor([[-0.4647],
[ 0.7753]], requires_grad=True))]
可以看到,我們定義的4組參數(shù)中载萌,只有(1)和(3)定義的參數(shù)可以被更新惧财,而self.register_buffer()和以python類的屬性形式定義的參數(shù)都不能被更新。也就是說扭仁,modules和parameters可以被更新垮衷,而buffers和普通類屬性不行。
那既然這兩種形式定義的參數(shù)都不能被更新乖坠,二者可以互相替代嗎帘靡?答案是不可以,原因看下一節(jié):
(2)這其中哪些才算是模型的參數(shù)呢瓤帚?
模型的所有參數(shù)都裝在 state_dict 中描姚,因?yàn)楸4婺P蛥?shù)時直接保存 net.state_dict()。我們看一下其中究竟是哪些參數(shù):
In [9]: net.state_dict()
Out[9]:
OrderedDict([('param_reg', tensor([[-0.0617, -0.8984]])),
('param_buf', tensor([[-1.0517, 0.7663]])),
('param_nn.conv.weight',
tensor([[[[-0.3183, -0.0426, -0.2984],
[-0.1451, 0.2686, 0.0556],
[-0.3155, 0.0451, 0.0702]]]])),
('param_nn.fc.weight',
tensor([[-0.4647],
[ 0.7753]]))])
可以看到戈次,通過 nn.Module 類轩勘、self.register_buffer() 以及 self.register_parameter() 定義的參數(shù)都在 state-dict 中,只有用python類的屬性形式定義的參數(shù)不包含其中怯邪。也就是說绊寻,保存模型時,buffers悬秉,modules和parameters都可以被保存澄步,但普通屬性不行。
(3)self.register_buffer() 的使用方法
在用self.register_buffer('name', tensor) 定義模型參數(shù)時和泌,其有兩個形參需要傳入村缸。第一個是字符串,表示這組參數(shù)的名字武氓;第二個就是tensor 形式的參數(shù)梯皿。
在模型定義中調(diào)用這個參數(shù)時(比如改變這組參數(shù)的值)仇箱,可以使用self.name 獲取。本文例中东羹,就可用self.param_buf 引用剂桥。這和類屬性的引用方法是一樣的。
在實(shí)例化模型后属提,獲取這組參數(shù)的值時权逗,可以用 net.buffers() 方法獲取,該方法返回一個生成器(可迭代變量):
In [10]: net.buffers()
Out[10]: <generator object Module.buffers at 0x00000289CA0032E0>
In [11]: list(net.buffers())
Out[11]: [tensor([[-1.0517, 0.7663]])]
# 也可以用named_buffers() 方法同時獲取名字
In [12]: list(net.named_buffers())
Out[12]: [('param_buf', tensor([[-1.0517, 0.7663]]))]
(4)modules, parameters 和 buffers
實(shí)際上冤议,PyTorch 定義的模型用OrderedDict() 的方式記錄這三種類型斟薇,分別保存在self._modules, self._parameters 和 self._buffers 三個私有屬性中求类。調(diào)試模式時就可以看到每個模型都有這幾個私有屬性:
由于是私有屬性,我們無法在實(shí)例化的變量上調(diào)用這些屬性屹耐,可以在模型定義中調(diào)用它們:
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# 常見定義模型時的操作
self.param_nn = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(1, 1, 3, bias=False)),
('fc', nn.Linear(1, 2, bias=False))
]))
# 使用register_buffer()定義一組參數(shù)
self.register_buffer('param_buf', torch.randn(1, 2))
# 使用形式類似的register_parameter()定義一組參數(shù)
self.register_parameter('param_reg', nn.Parameter(torch.randn(1, 2)))
# 按照類的屬性形式定義一組變量
self.param_attr = torch.randn(1, 2)
print('self._modules: ', self._modules)
print('self._parameters: ', self._modules)
print('self._buffers: ', self._modules)
def forward(self, x):
return x
模型實(shí)例化時尸疆,調(diào)用了 init() 方法,我們就可以看到調(diào)用輸出結(jié)果:
In [21]: net = Model()
self._modules: OrderedDict([('param_nn', Sequential(
(conv): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), bias=False)
(fc): Linear(in_features=1, out_features=2, bias=False)
))])
self._parameters: OrderedDict([('param_reg', Parameter containing:
tensor([[-0.5666, -0.2624]], requires_grad=True))])
self._buffers: OrderedDict([('param_buf', tensor([[-0.4005, -0.8199]]))])
在模型的實(shí)例化變量上調(diào)用時惶岭,三者有著相似的方法:
net.modules()
net.named_modules()
net.parameters()
net.named_parameters()
net.buffers()
net.named_buffers()
細(xì)心的讀著可能會發(fā)現(xiàn)寿弱,self._parameters 和 net.parameters() 的返回值并不相同。這里self._parameters 只記錄了使用 self.register_parameter() 定義的參數(shù)按灶,而net.parameters() 返回所有可學(xué)習(xí)參數(shù)症革,包括self._modules 中的參數(shù)和self._parameters 參數(shù)的并集。
實(shí)際上鸯旁,由nn.Module類定義的參數(shù)和self.register_parameter() 定義的參數(shù)性質(zhì)是一樣的噪矛,都是nn.Parameter 類型。