PyTorch nn.Module中的self.register_buffer()解析

PyTorch中定義模型時蚊惯,有時候會遇到self.register_buffer('name', Tensor)的操作魂仍,該方法的作用是定義一組參數(shù),該組參數(shù)的特別之處在于:模型訓(xùn)練時不會更新(即調(diào)用 optimizer.step() 后該組參數(shù)不會變化拣挪,只可人為地改變它們的值)擦酌,但是保存模型時,該組參數(shù)又作為模型參數(shù)不可或缺的一部分被保存菠劝。

為了更好地理解這句話赊舶,按照慣例,我們通過一個例子實(shí)驗(yàn)來解釋:

首先,定義一個模型并實(shí)例化:

import torch 
import torch.nn as nn
from collections import OrderedDict

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # (1)常見定義模型時的操作
        self.param_nn = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(1, 1, 3, bias=False)),
            ('fc', nn.Linear(1, 2, bias=False))
        ]))

        # (2)使用register_buffer()定義一組參數(shù)
        self.register_buffer('param_buf', torch.randn(1, 2))

        # (3)使用形式類似的register_parameter()定義一組參數(shù)
        self.register_parameter('param_reg', nn.Parameter(torch.randn(1, 2)))

        # (4)按照類的屬性形式定義一組變量
        self.param_attr = torch.randn(1, 2) 

    def forward(self, x):
        return x

net = Model()

上例中笼平,我們通過繼承nn.Module類定義了一個模型园骆,在模型參數(shù)的定義中,我們分別以(1)常見的nn.Module類形式寓调、(2)self.register_buffer()形式锌唾、(3)self.register_parameter()形式,以及(4)python類的屬性形式定義了4組參數(shù)夺英。

(1)哪些參數(shù)可以在模型訓(xùn)練時被更新晌涕?

這可以通過net.parameters()查看,因?yàn)槎x優(yōu)化器時是這樣的:optimizer = SGD(net.parameters(), lr=0.1)痛悯。為了方便查看余黎,我們使用 net.named_parameters():

In [8]: list(net.named_parameters())
Out[8]:
[('param_reg',
  Parameter containing:
  tensor([[-0.0617, -0.8984]], requires_grad=True)),
 ('param_nn.conv.weight',
  Parameter containing:
  tensor([[[[-0.3183, -0.0426, -0.2984],
            [-0.1451,  0.2686,  0.0556],
            [-0.3155,  0.0451,  0.0702]]]], requires_grad=True)),
 ('param_nn.fc.weight',
  Parameter containing:
  tensor([[-0.4647],
          [ 0.7753]], requires_grad=True))]

可以看到,我們定義的4組參數(shù)中载萌,只有(1)和(3)定義的參數(shù)可以被更新惧财,而self.register_buffer()和以python類的屬性形式定義的參數(shù)都不能被更新。也就是說扭仁,modules和parameters可以被更新垮衷,而buffers和普通類屬性不行。

那既然這兩種形式定義的參數(shù)都不能被更新乖坠,二者可以互相替代嗎帘靡?答案是不可以,原因看下一節(jié):

(2)這其中哪些才算是模型的參數(shù)呢瓤帚?

模型的所有參數(shù)都裝在 state_dict 中描姚,因?yàn)楸4婺P蛥?shù)時直接保存 net.state_dict()。我們看一下其中究竟是哪些參數(shù):

In [9]: net.state_dict()
Out[9]:
OrderedDict([('param_reg', tensor([[-0.0617, -0.8984]])),
             ('param_buf', tensor([[-1.0517,  0.7663]])),
             ('param_nn.conv.weight',
              tensor([[[[-0.3183, -0.0426, -0.2984],
                        [-0.1451,  0.2686,  0.0556],
                        [-0.3155,  0.0451,  0.0702]]]])),
             ('param_nn.fc.weight',
              tensor([[-0.4647],
                      [ 0.7753]]))])

可以看到戈次,通過 nn.Module 類轩勘、self.register_buffer() 以及 self.register_parameter() 定義的參數(shù)都在 state-dict 中,只有用python類的屬性形式定義的參數(shù)不包含其中怯邪。也就是說绊寻,保存模型時,buffers悬秉,modules和parameters都可以被保存澄步,但普通屬性不行。

(3)self.register_buffer() 的使用方法

在用self.register_buffer('name', tensor) 定義模型參數(shù)時和泌,其有兩個形參需要傳入村缸。第一個是字符串,表示這組參數(shù)的名字武氓;第二個就是tensor 形式的參數(shù)梯皿。

在模型定義中調(diào)用這個參數(shù)時(比如改變這組參數(shù)的值)仇箱,可以使用self.name 獲取。本文例中东羹,就可用self.param_buf 引用剂桥。這和類屬性的引用方法是一樣的。

在實(shí)例化模型后属提,獲取這組參數(shù)的值時权逗,可以用 net.buffers() 方法獲取,該方法返回一個生成器(可迭代變量):

In [10]: net.buffers()
Out[10]: <generator object Module.buffers at 0x00000289CA0032E0>

In [11]: list(net.buffers())
Out[11]: [tensor([[-1.0517,  0.7663]])]

# 也可以用named_buffers() 方法同時獲取名字
In [12]: list(net.named_buffers())
Out[12]: [('param_buf', tensor([[-1.0517,  0.7663]]))]

(4)modules, parameters 和 buffers

實(shí)際上冤议,PyTorch 定義的模型用OrderedDict() 的方式記錄這三種類型斟薇,分別保存在self._modules, self._parameters 和 self._buffers 三個私有屬性中求类。調(diào)試模式時就可以看到每個模型都有這幾個私有屬性:


調(diào)試模式 變量窗口

由于是私有屬性,我們無法在實(shí)例化的變量上調(diào)用這些屬性屹耐,可以在模型定義中調(diào)用它們:

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # 常見定義模型時的操作
        self.param_nn = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(1, 1, 3, bias=False)),
            ('fc', nn.Linear(1, 2, bias=False))
        ]))

        # 使用register_buffer()定義一組參數(shù)
        self.register_buffer('param_buf', torch.randn(1, 2))

        # 使用形式類似的register_parameter()定義一組參數(shù)
        self.register_parameter('param_reg', nn.Parameter(torch.randn(1, 2)))

        # 按照類的屬性形式定義一組變量
        self.param_attr = torch.randn(1, 2) 

        print('self._modules: ', self._modules)
        print('self._parameters: ', self._modules)
        print('self._buffers: ', self._modules)

    def forward(self, x):
        return x

模型實(shí)例化時尸疆,調(diào)用了 init() 方法,我們就可以看到調(diào)用輸出結(jié)果:

In [21]: net = Model()
self._modules:  OrderedDict([('param_nn', Sequential(
  (conv): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (fc): Linear(in_features=1, out_features=2, bias=False)
))])

self._parameters:  OrderedDict([('param_reg', Parameter containing:
tensor([[-0.5666, -0.2624]], requires_grad=True))])

self._buffers:  OrderedDict([('param_buf', tensor([[-0.4005, -0.8199]]))])

在模型的實(shí)例化變量上調(diào)用時惶岭,三者有著相似的方法:

net.modules()
net.named_modules()

net.parameters()
net.named_parameters()

net.buffers()
net.named_buffers()

細(xì)心的讀著可能會發(fā)現(xiàn)寿弱,self._parameters 和 net.parameters() 的返回值并不相同。這里self._parameters 只記錄了使用 self.register_parameter() 定義的參數(shù)按灶,而net.parameters() 返回所有可學(xué)習(xí)參數(shù)症革,包括self._modules 中的參數(shù)和self._parameters 參數(shù)的并集。

實(shí)際上鸯旁,由nn.Module類定義的參數(shù)和self.register_parameter() 定義的參數(shù)性質(zhì)是一樣的噪矛,都是nn.Parameter 類型。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末铺罢,一起剝皮案震驚了整個濱河市艇挨,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌韭赘,老刑警劉巖缩滨,帶你破解...
    沈念sama閱讀 211,265評論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異泉瞻,居然都是意外死亡脉漏,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,078評論 2 385
  • 文/潘曉璐 我一進(jìn)店門袖牙,熙熙樓的掌柜王于貴愁眉苦臉地迎上來侧巨,“玉大人,你說我怎么就攤上這事鞭达∪信荩” “怎么了巧娱?”我有些...
    開封第一講書人閱讀 156,852評論 0 347
  • 文/不壞的土叔 我叫張陵,是天一觀的道長烘贴。 經(jīng)常有香客問我禁添,道長讥电,這世上最難降的妖魔是什么巡验? 我笑而不...
    開封第一講書人閱讀 56,408評論 1 283
  • 正文 為了忘掉前任如迟,我火速辦了婚禮菜谣,結(jié)果婚禮上酪耕,老公的妹妹穿的比我還像新娘怀大。我一直安慰自己肤粱,他們只是感情好捉蚤,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,445評論 5 384
  • 文/花漫 我一把揭開白布汽纠。 她就那樣靜靜地躺著卫键,像睡著了一般。 火紅的嫁衣襯著肌膚如雪虱朵。 梳的紋絲不亂的頭發(fā)上莉炉,一...
    開封第一講書人閱讀 49,772評論 1 290
  • 那天,我揣著相機(jī)與錄音碴犬,去河邊找鬼絮宁。 笑死,一個胖子當(dāng)著我的面吹牛服协,可吹牛的內(nèi)容都是我干的绍昂。 我是一名探鬼主播,決...
    沈念sama閱讀 38,921評論 3 406
  • 文/蒼蘭香墨 我猛地睜開眼偿荷,長吁一口氣:“原來是場噩夢啊……” “哼窘游!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起跳纳,我...
    開封第一講書人閱讀 37,688評論 0 266
  • 序言:老撾萬榮一對情侶失蹤张峰,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后棒旗,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體喘批,經(jīng)...
    沈念sama閱讀 44,130評論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,467評論 2 325
  • 正文 我和宋清朗相戀三年铣揉,在試婚紗的時候發(fā)現(xiàn)自己被綠了饶深。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,617評論 1 340
  • 序言:一個原本活蹦亂跳的男人離奇死亡逛拱,死狀恐怖敌厘,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情朽合,我是刑警寧澤俱两,帶...
    沈念sama閱讀 34,276評論 4 329
  • 正文 年R本政府宣布饱狂,位于F島的核電站,受9級特大地震影響宪彩,放射性物質(zhì)發(fā)生泄漏休讳。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,882評論 3 312
  • 文/蒙蒙 一尿孔、第九天 我趴在偏房一處隱蔽的房頂上張望俊柔。 院中可真熱鬧,春花似錦活合、人聲如沸雏婶。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,740評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽留晚。三九已至,卻和暖如春告嘲,著一層夾襖步出監(jiān)牢的瞬間错维,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,967評論 1 265
  • 我被黑心中介騙來泰國打工状蜗, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留需五,地道東北人鹉动。 一個月前我還...
    沈念sama閱讀 46,315評論 2 360
  • 正文 我出身青樓轧坎,卻偏偏與公主長得像,于是被迫代替她去往敵國和親泽示。 傳聞我的和親對象是個殘疾皇子缸血,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,486評論 2 348

推薦閱讀更多精彩內(nèi)容