PyTorch之HOOK——獲取神經網絡特征和梯度的有效工具

本文首發(fā)于簡書 西北小生_ 的博客:http://www.reibang.com/u/898c7641f6ea丧裁,未經允許掉蔬,禁止轉載勾习!

為了更深入地理解神經網絡模型,有時候我們需要觀察它訓練得到的卷積核懈玻、特征圖或者梯度等信息巧婶,這在CNN可視化研究中經常用到。其中涂乌,卷積核最易獲取艺栈,將模型參數保存即可得到;特征圖是中間變量湾盒,所對應的圖像處理完即會被系統(tǒng)清除湿右,否則將嚴重占用內存;梯度跟特征圖類似罚勾,除了葉子結點外毅人,其它中間變量的梯度都被會內存釋放,因而不能直接獲取尖殃。
最容易想到的獲取方法就是改變模型結構丈莺,在forward的最后不但返回模型的預測輸出,還返回所需要的特征圖等信息送丰。

如何在不改變模型結構的基礎上獲取特征圖缔俄、梯度等信息呢?

Pytorch的hook編程可以在不改變網絡結構的基礎上有效獲取器躏、改變模型中間變量以及梯度等信息俐载。
hook可以提取或改變Tensor的梯度,也可以獲取nn.Module的輸出和梯度(這里不能改變)登失。因此有3個hook函數用于實現以上功能:

Tensor.register_hook(hook_fn)遏佣,
nn.Module.register_forward_hook(hook_fn),
nn.Module.register_backward_hook(hook_fn).

下面對其用法進行一一介紹壁畸。

1.Tensor.register_hook(hook_fn)

功能:注冊一個反向傳播hook函數贼急,用于自動記錄Tensor的梯度茅茂。
PyTorch對中間變量和非葉子節(jié)點的梯度運行完后會自動釋放,以減緩內存占用太抓。什么是中間變量空闲?什么是非葉子節(jié)點?


Tensor計算

上圖中走敌,a碴倾,b,d就是葉子節(jié)點掉丽,c跌榔,e,o是非葉子節(jié)點捶障,也是中間變量僧须。

In [18]: a = torch.Tensor([1,2]).requires_grad_() 
    ...: b = torch.Tensor([3,4]).requires_grad_() 
    ...: d = torch.Tensor([2]).requires_grad_() 
    ...: c = a + b 
    ...: e = c * d 
    ...: o = e.sum()     

In [19]: o.backward()

In [20]: print(a.grad)
tensor([2., 2.])

In [21]: print(b.grad)
tensor([2., 2.])

In [22]: print(c.grad)
None

In [23]: print(d.grad)
tensor([10.])

In [24]: print(e.grad)
None

In [25]: print(o.grad)
None

可以從程序的輸出中看到,a项炼,b担平,d作為葉子節(jié)點,經過反向傳播后梯度值仍然保留锭部,而其它非葉子節(jié)點的梯度已經被自動釋放了暂论,要想得到它們的梯度值,就需要使用hook了拌禾。

我們首先自定義一個hook_fn函數取胎,用于記錄對Tensor梯度的操作,然后用Tensor.register_hook(hook_fn)對要獲取梯度的非葉子結點的Tensor進行注冊湃窍,然后重新反向傳播一次:

In [44]: def hook_fn(grad):
    ...:     print(grad)
    ...:

In [45]: e.register_hook(hook_fn)
Out[45]: <torch.utils.hooks.RemovableHandle at 0x1d139cf0a88>

In [46]: o.backward()
tensor([1., 1.])

這時就自動輸出了e的梯度闻蛀。

自定義的hook_fn函數的函數名可以是任取的,它的參數是grad坝咐,表示Tensor的梯度循榆。這個自定義函數主要是用于描述對Tensor梯度值的操作,上例中我們是對梯度直接進行輸出墨坚,所以是print(grad)秧饮。我們也可以把梯度裝在一個列表或字典里,甚至可以修改梯度泽篮,這樣如果梯度很小的時候將其變大一點就可以防止梯度消失的問題了:

In [28]: a = torch.Tensor([1,2]).requires_grad_() 
    ...: b = torch.Tensor([3,4]).requires_grad_() 
    ...: d = torch.Tensor([2]).requires_grad_() 
    ...: c = a + b 
    ...: e = c * d 
    ...: o = e.sum()                                                            

In [29]: grad_list = []                                                         

In [30]: def hook(grad): 
    ...:     grad_list.append(grad)    # 將梯度裝在列表里
    ...:     return 2 * grad    # 將梯度放大兩倍
    ...:                                                                        

In [31]: c.register_hook(hook)                                                  
Out[31]: <torch.utils.hooks.RemovableHandle at 0x7f009b713208>

In [32]: o.backward()                                                           

In [33]: grad_list                                                              
Out[33]: [tensor([2., 2.])]

In [34]: a.grad                                                                 
Out[34]: tensor([4., 4.])

In [35]: b.grad                                                                 
Out[35]: tensor([4., 4.])

上例中盗尸,我們定義的hook函數執(zhí)行了兩個操作:一是將梯度裝進列表grad_list中,二是把梯度放大兩倍帽撑。從輸出中我們可以看到泼各,執(zhí)行反向傳播后,我們注冊的非葉子節(jié)點c的梯度保存在了列表grad_list中亏拉,并且a和b的梯度都變?yōu)樵瓉淼膬杀犊垓摺_@里需要注意的是逆巍,如果要將梯度值裝在一個列表或字典里,那么首先要定義一個同名的全局變量的列表或字典莽使,即使是局部變量锐极,也要在自定義的hook函數外面。另一個需要注意的點就是如果要改變梯度值芳肌,hook函數要有返回值灵再,返回改變后的梯度。

這里總結一下亿笤,如果要獲取非葉子節(jié)點Tensor的梯度值翎迁,我們需要在反向傳播前
1)自定義一個hook函數,描述對梯度的操作净薛,函數名自擬汪榔,參數只有grad,表示Tensor的梯度肃拜;
2)對要獲取梯度的Tensor用方法Tensor.register_hook(hook)進行注冊揍异。
3)執(zhí)行反向傳播。

2.nn.Module.register_forward_hook(hook_fn)和nn.Module.register_backward_hook(hook_fn)

這兩個的操作對象都是nn.Module類爆班,如神經網絡中的卷積層(nn.Conv2d),全連接層(nn.Linear)辱姨,池化層(nn.MaxPool2d, nn.AvgPool2d)柿菩,激活層(nn.ReLU)或者nn.Sequential定義的小模塊等,所以放在一起講雨涛。

對于模型的中間模塊枢舶,也可以視作中間節(jié)點(非葉子節(jié)點),它的輸出為特征圖或激活值替久,反向傳播的梯度值都會被系統(tǒng)自動釋放凉泄,如果想要獲取它們,就要用到hook功能蚯根。

有名字即可看出后众,register_forward_hook是獲取前向傳播的輸出的,即特征圖或激活值颅拦;register_backward_hook是獲取反向傳播的輸出的蒂誉,即梯度值。它們的用法和上面介紹的register_hook類似距帅。我們先看一下hook_fn的定義:

對于register_forward_hook(hook_fn)右锨,其hook_fn函數定義如下:

def forward_hook(module, input, output):
    operations

這里有3個參數,分別表示:模塊碌秸,模塊的輸入绍移,模塊的輸出悄窃。函數用于描述對這些參數的操作,一般我們都是為了獲取特征圖蹂窖,即只描述對output的操作即可轧抗。

對于register_backward_hook(hook_fn),其hook_fn函數定義如下:

def backward_hook(module, grad_in, grad_out):
    operations

這里也有3個參數恼策,分別表示:模塊鸦致,模塊輸入端的梯度,模塊輸出端的梯度涣楷。這里需要特別注意的是分唾,此處的輸入端和輸出端,是前向傳播時的輸入端和輸出端狮斗,也就是說绽乔,上面的output的梯度對應這里的grad_out。例如線性模塊:o=W*x+b碳褒,其輸入端為 W折砸,x 和 b,輸出端為 o沙峻。

如果模塊有多個輸入或者輸出的話睦授,grad_in和grad_out可以是 tuple 類型。對于線性模塊:o=W*x+b 摔寨,它的輸入端包括了W去枷、x 和 b 三部分,因此 grad_input 就是一個包含三個元素的 tuple是复。

這里注意和 forward hook 的不同:

  1. 在 forward hook 中删顶,input 是 x,而不包括 W 和 b淑廊。
  2. 返回 Tensor 或者 None逗余,backward hook 函數不能直接改變它的輸入變量,但是可以返回新的 grad_in季惩,反向傳播到它上一個模塊录粱。

此處的自定義的函數hook_fn也可以自擬名稱,但如果兩個hook函數同時使用的時候注意名稱的區(qū)別画拾,一般在函數名里添加對應的forward和backward就不易搞混了关摇。

下面看一個具體用例:

#-*- utf-8 -*-

'''本程序用于驗證hook編程獲取卷積層的輸出特征圖和特征圖的梯度'''

__author__ = 'puxitong from UESTC'

import torch
import torch.nn as nn
import numpy as np 
import torchvision.transforms as transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,6,3,1,1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,9,3,1,1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(8*8*9, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120,10)

    def forward(self, x):
        out = self.pool1(self.relu1(self.conv1(x)))
        out = self.pool2(self.relu2(self.conv2(out)))
        out = out.view(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.fc2(out)

        return out


def backward_hook(module, grad_in, grad_out):
    grad_block['grad_in'] = grad_in
    grad_block['grad_out'] = grad_out


def farward_hook(module, inp, outp):
    fmap_block['input'] = inp
    fmap_block['output'] = outp


loss_func = nn.CrossEntropyLoss()

# 生成一個假標簽以便演示
label = torch.empty(1, dtype=torch.long).random_(3)

# 生成一副假圖像以便演示
input_img = torch.randn(1,3,32,32).requires_grad_()  

fmap_block = dict()  # 裝feature map
grad_block = dict()  # 裝梯度

net = Net()

# 注冊hook
net.conv2.register_forward_hook(farward_hook)
net.conv2.register_backward_hook(backward_hook)

outs = net(input_img)
loss = loss_func(outs, label)
loss.backward()

print('End.')

上面的程序中,我們先定義了一個簡單的卷積神經網絡模型碾阁,我們對第二層卷積模塊進行hook注冊输虱,既獲取它的輸入輸出,又獲取輸入輸出的梯度脂凶,并將它們分別裝在字典里宪睹。為了達到驗證效果愁茁,我們隨機生成一個假圖像,它的尺寸和cifar-10數據集的圖像尺寸一致亭病,并且給這個假圖像定義一個類別標簽鹅很,用損失函數進行反向傳播,模擬神經網絡的訓練過程罪帖。

在IPython中運行程序后促煮,相應的特征圖和梯度就會出現在兩個列表fmap_block和grad_block中了。我們看一下它們的輸入和輸出的維度:

In [17]: len(fmap_block['input'])                                               
Out[17]: 1

In [18]: len(fmap_block['output'])                                              
Out[18]: 1

In [19]: len(grad_block['grad_in'])                                             
Out[19]: 3

In [20]: len(grad_block['grad_out'])                                            
Out[20]: 1

可以看出整袁,第二層卷積模塊的輸入和輸出都只有一個菠齿,即相應的特征圖。而輸入端的梯度值有3個坐昙,分別為權重的梯度绳匀,偏差的梯度,以及輸入特征圖的梯度炸客。輸出端的梯度值只有一個疾棵,即輸出特征圖的梯度。正如上面強調的痹仙,輸入端即使有W, X和b三個是尔,對于前項傳播來說只有X是其輸入,而對于反向傳播來說开仰,3個都是輸入嗜历。輸出端3項的梯度值排列的順序是什么呢,我們來看一下3項梯度的具體維度:

In [21]: grad_block['grad_in'][0].shape                                         
Out[21]: torch.Size([1, 6, 16, 16])

In [22]: grad_block['grad_in'][1].shape                                         
Out[22]: torch.Size([9, 6, 3, 3])

In [23]: grad_block['grad_in'][2].shape                                         
Out[23]: torch.Size([9])

從輸出端梯度的維度可以判斷抖所,第一個顯然是特征圖的梯度,第二個則是權重(卷積核/濾波器)的梯度痕囱,第三個是偏置的梯度田轧。為了驗證梯度和這些參數具有同樣的維度,我們再來看看這三個值前向傳播時的維度:

In [24]: fmap_block['input'][0].shape                                           
Out[24]: torch.Size([1, 6, 16, 16])

In [25]: net.conv2.weight.shape         
Out[25]: torch.Size([9, 6, 3, 3])

In [26]: net.conv2.bias.shape                                                   
Out[26]: torch.Size([9])

可以看到鞍恢,我們的判斷是正確的傻粘。

最后需要注意的一點是,如果需要獲取輸入圖像的梯度帮掉,一定要將輸入Tensor的requires_grad屬性設為True弦悉。

原創(chuàng)不易,有用請點贊支持~

?著作權歸作者所有,轉載或內容合作請聯系作者
  • 序言:七十年代末蟆炊,一起剝皮案震驚了整個濱河市稽莉,隨后出現的幾起案子,更是在濱河造成了極大的恐慌涩搓,老刑警劉巖污秆,帶你破解...
    沈念sama閱讀 219,188評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件劈猪,死亡現場離奇詭異,居然都是意外死亡良拼,警方通過查閱死者的電腦和手機战得,發(fā)現死者居然都...
    沈念sama閱讀 93,464評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來庸推,“玉大人常侦,你說我怎么就攤上這事”崦剑” “怎么了聋亡?”我有些...
    開封第一講書人閱讀 165,562評論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長掖蛤。 經常有香客問我杀捻,道長,這世上最難降的妖魔是什么蚓庭? 我笑而不...
    開封第一講書人閱讀 58,893評論 1 295
  • 正文 為了忘掉前任致讥,我火速辦了婚禮,結果婚禮上器赞,老公的妹妹穿的比我還像新娘垢袱。我一直安慰自己,他們只是感情好港柜,可當我...
    茶點故事閱讀 67,917評論 6 392
  • 文/花漫 我一把揭開白布请契。 她就那樣靜靜地躺著,像睡著了一般夏醉。 火紅的嫁衣襯著肌膚如雪爽锥。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,708評論 1 305
  • 那天畔柔,我揣著相機與錄音氯夷,去河邊找鬼。 笑死靶擦,一個胖子當著我的面吹牛腮考,可吹牛的內容都是我干的。 我是一名探鬼主播玄捕,決...
    沈念sama閱讀 40,430評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼踩蔚,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了枚粘?” 一聲冷哼從身側響起馅闽,我...
    開封第一講書人閱讀 39,342評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后捞蛋,有當地人在樹林里發(fā)現了一具尸體孝冒,經...
    沈念sama閱讀 45,801評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 37,976評論 3 337
  • 正文 我和宋清朗相戀三年拟杉,在試婚紗的時候發(fā)現自己被綠了庄涡。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,115評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡搬设,死狀恐怖穴店,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情拿穴,我是刑警寧澤泣洞,帶...
    沈念sama閱讀 35,804評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站默色,受9級特大地震影響球凰,放射性物質發(fā)生泄漏。R本人自食惡果不足惜腿宰,卻給世界環(huán)境...
    茶點故事閱讀 41,458評論 3 331
  • 文/蒙蒙 一呕诉、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧吃度,春花似錦甩挫、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,008評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至间护,卻和暖如春亦渗,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背汁尺。 一陣腳步聲響...
    開封第一講書人閱讀 33,135評論 1 272
  • 我被黑心中介騙來泰國打工法精, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人均函。 一個月前我還...
    沈念sama閱讀 48,365評論 3 373
  • 正文 我出身青樓,卻偏偏與公主長得像菱涤,于是被迫代替她去往敵國和親苞也。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 45,055評論 2 355

推薦閱讀更多精彩內容