本文首發(fā)于簡書 西北小生_ 的博客:http://www.reibang.com/u/898c7641f6ea丧裁,未經允許掉蔬,禁止轉載勾习!
為了更深入地理解神經網絡模型,有時候我們需要觀察它訓練得到的卷積核懈玻、特征圖或者梯度等信息巧婶,這在CNN可視化研究中經常用到。其中涂乌,卷積核最易獲取艺栈,將模型參數保存即可得到;特征圖是中間變量湾盒,所對應的圖像處理完即會被系統(tǒng)清除湿右,否則將嚴重占用內存;梯度跟特征圖類似罚勾,除了葉子結點外毅人,其它中間變量的梯度都被會內存釋放,因而不能直接獲取尖殃。
最容易想到的獲取方法就是改變模型結構丈莺,在forward的最后不但返回模型的預測輸出,還返回所需要的特征圖等信息送丰。
如何在不改變模型結構的基礎上獲取特征圖缔俄、梯度等信息呢?
Pytorch的hook編程可以在不改變網絡結構的基礎上有效獲取器躏、改變模型中間變量以及梯度等信息俐载。
hook可以提取或改變Tensor的梯度,也可以獲取nn.Module的輸出和梯度(這里不能改變)登失。因此有3個hook函數用于實現以上功能:
Tensor.register_hook(hook_fn)遏佣,
nn.Module.register_forward_hook(hook_fn),
nn.Module.register_backward_hook(hook_fn).
下面對其用法進行一一介紹壁畸。
1.Tensor.register_hook(hook_fn)
功能:注冊一個反向傳播hook函數贼急,用于自動記錄Tensor的梯度茅茂。
PyTorch對中間變量和非葉子節(jié)點的梯度運行完后會自動釋放,以減緩內存占用太抓。什么是中間變量空闲?什么是非葉子節(jié)點?
上圖中走敌,a碴倾,b,d就是葉子節(jié)點掉丽,c跌榔,e,o是非葉子節(jié)點捶障,也是中間變量僧须。
In [18]: a = torch.Tensor([1,2]).requires_grad_()
...: b = torch.Tensor([3,4]).requires_grad_()
...: d = torch.Tensor([2]).requires_grad_()
...: c = a + b
...: e = c * d
...: o = e.sum()
In [19]: o.backward()
In [20]: print(a.grad)
tensor([2., 2.])
In [21]: print(b.grad)
tensor([2., 2.])
In [22]: print(c.grad)
None
In [23]: print(d.grad)
tensor([10.])
In [24]: print(e.grad)
None
In [25]: print(o.grad)
None
可以從程序的輸出中看到,a项炼,b担平,d作為葉子節(jié)點,經過反向傳播后梯度值仍然保留锭部,而其它非葉子節(jié)點的梯度已經被自動釋放了暂论,要想得到它們的梯度值,就需要使用hook了拌禾。
我們首先自定義一個hook_fn函數取胎,用于記錄對Tensor梯度的操作,然后用Tensor.register_hook(hook_fn)對要獲取梯度的非葉子結點的Tensor進行注冊湃窍,然后重新反向傳播一次:
In [44]: def hook_fn(grad):
...: print(grad)
...:
In [45]: e.register_hook(hook_fn)
Out[45]: <torch.utils.hooks.RemovableHandle at 0x1d139cf0a88>
In [46]: o.backward()
tensor([1., 1.])
這時就自動輸出了e的梯度闻蛀。
自定義的hook_fn函數的函數名可以是任取的,它的參數是grad坝咐,表示Tensor的梯度循榆。這個自定義函數主要是用于描述對Tensor梯度值的操作,上例中我們是對梯度直接進行輸出墨坚,所以是print(grad)秧饮。我們也可以把梯度裝在一個列表或字典里,甚至可以修改梯度泽篮,這樣如果梯度很小的時候將其變大一點就可以防止梯度消失的問題了:
In [28]: a = torch.Tensor([1,2]).requires_grad_()
...: b = torch.Tensor([3,4]).requires_grad_()
...: d = torch.Tensor([2]).requires_grad_()
...: c = a + b
...: e = c * d
...: o = e.sum()
In [29]: grad_list = []
In [30]: def hook(grad):
...: grad_list.append(grad) # 將梯度裝在列表里
...: return 2 * grad # 將梯度放大兩倍
...:
In [31]: c.register_hook(hook)
Out[31]: <torch.utils.hooks.RemovableHandle at 0x7f009b713208>
In [32]: o.backward()
In [33]: grad_list
Out[33]: [tensor([2., 2.])]
In [34]: a.grad
Out[34]: tensor([4., 4.])
In [35]: b.grad
Out[35]: tensor([4., 4.])
上例中盗尸,我們定義的hook函數執(zhí)行了兩個操作:一是將梯度裝進列表grad_list中,二是把梯度放大兩倍帽撑。從輸出中我們可以看到泼各,執(zhí)行反向傳播后,我們注冊的非葉子節(jié)點c的梯度保存在了列表grad_list中亏拉,并且a和b的梯度都變?yōu)樵瓉淼膬杀犊垓摺_@里需要注意的是逆巍,如果要將梯度值裝在一個列表或字典里,那么首先要定義一個同名的全局變量的列表或字典莽使,即使是局部變量锐极,也要在自定義的hook函數外面。另一個需要注意的點就是如果要改變梯度值芳肌,hook函數要有返回值灵再,返回改變后的梯度。
這里總結一下亿笤,如果要獲取非葉子節(jié)點Tensor的梯度值翎迁,我們需要在反向傳播前:
1)自定義一個hook函數,描述對梯度的操作净薛,函數名自擬汪榔,參數只有grad,表示Tensor的梯度肃拜;
2)對要獲取梯度的Tensor用方法Tensor.register_hook(hook)進行注冊揍异。
3)執(zhí)行反向傳播。
2.nn.Module.register_forward_hook(hook_fn)和nn.Module.register_backward_hook(hook_fn)
這兩個的操作對象都是nn.Module類爆班,如神經網絡中的卷積層(nn.Conv2d),全連接層(nn.Linear)辱姨,池化層(nn.MaxPool2d, nn.AvgPool2d)柿菩,激活層(nn.ReLU)或者nn.Sequential定義的小模塊等,所以放在一起講雨涛。
對于模型的中間模塊枢舶,也可以視作中間節(jié)點(非葉子節(jié)點),它的輸出為特征圖或激活值替久,反向傳播的梯度值都會被系統(tǒng)自動釋放凉泄,如果想要獲取它們,就要用到hook功能蚯根。
有名字即可看出后众,register_forward_hook是獲取前向傳播的輸出的,即特征圖或激活值颅拦;register_backward_hook是獲取反向傳播的輸出的蒂誉,即梯度值。它們的用法和上面介紹的register_hook類似距帅。我們先看一下hook_fn的定義:
對于register_forward_hook(hook_fn)右锨,其hook_fn函數定義如下:
def forward_hook(module, input, output):
operations
這里有3個參數,分別表示:模塊碌秸,模塊的輸入绍移,模塊的輸出悄窃。函數用于描述對這些參數的操作,一般我們都是為了獲取特征圖蹂窖,即只描述對output的操作即可轧抗。
對于register_backward_hook(hook_fn),其hook_fn函數定義如下:
def backward_hook(module, grad_in, grad_out):
operations
這里也有3個參數恼策,分別表示:模塊鸦致,模塊輸入端的梯度,模塊輸出端的梯度涣楷。這里需要特別注意的是分唾,此處的輸入端和輸出端,是前向傳播時的輸入端和輸出端狮斗,也就是說绽乔,上面的output的梯度對應這里的grad_out。例如線性模塊:o=W*x+b碳褒,其輸入端為 W折砸,x 和 b,輸出端為 o沙峻。
如果模塊有多個輸入或者輸出的話睦授,grad_in和grad_out可以是 tuple 類型。對于線性模塊:o=W*x+b 摔寨,它的輸入端包括了W去枷、x 和 b 三部分,因此 grad_input 就是一個包含三個元素的 tuple是复。
這里注意和 forward hook 的不同:
- 在 forward hook 中删顶,input 是 x,而不包括 W 和 b淑廊。
- 返回 Tensor 或者 None逗余,backward hook 函數不能直接改變它的輸入變量,但是可以返回新的 grad_in季惩,反向傳播到它上一個模塊录粱。
此處的自定義的函數hook_fn也可以自擬名稱,但如果兩個hook函數同時使用的時候注意名稱的區(qū)別画拾,一般在函數名里添加對應的forward和backward就不易搞混了关摇。
下面看一個具體用例:
#-*- utf-8 -*-
'''本程序用于驗證hook編程獲取卷積層的輸出特征圖和特征圖的梯度'''
__author__ = 'puxitong from UESTC'
import torch
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3,6,3,1,1)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,9,3,1,1)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(2,2)
self.fc1 = nn.Linear(8*8*9, 120)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120,10)
def forward(self, x):
out = self.pool1(self.relu1(self.conv1(x)))
out = self.pool2(self.relu2(self.conv2(out)))
out = out.view(out.shape[0], -1)
out = self.relu3(self.fc1(out))
out = self.fc2(out)
return out
def backward_hook(module, grad_in, grad_out):
grad_block['grad_in'] = grad_in
grad_block['grad_out'] = grad_out
def farward_hook(module, inp, outp):
fmap_block['input'] = inp
fmap_block['output'] = outp
loss_func = nn.CrossEntropyLoss()
# 生成一個假標簽以便演示
label = torch.empty(1, dtype=torch.long).random_(3)
# 生成一副假圖像以便演示
input_img = torch.randn(1,3,32,32).requires_grad_()
fmap_block = dict() # 裝feature map
grad_block = dict() # 裝梯度
net = Net()
# 注冊hook
net.conv2.register_forward_hook(farward_hook)
net.conv2.register_backward_hook(backward_hook)
outs = net(input_img)
loss = loss_func(outs, label)
loss.backward()
print('End.')
上面的程序中,我們先定義了一個簡單的卷積神經網絡模型碾阁,我們對第二層卷積模塊進行hook注冊输虱,既獲取它的輸入輸出,又獲取輸入輸出的梯度脂凶,并將它們分別裝在字典里宪睹。為了達到驗證效果愁茁,我們隨機生成一個假圖像,它的尺寸和cifar-10數據集的圖像尺寸一致亭病,并且給這個假圖像定義一個類別標簽鹅很,用損失函數進行反向傳播,模擬神經網絡的訓練過程罪帖。
在IPython中運行程序后促煮,相應的特征圖和梯度就會出現在兩個列表fmap_block和grad_block中了。我們看一下它們的輸入和輸出的維度:
In [17]: len(fmap_block['input'])
Out[17]: 1
In [18]: len(fmap_block['output'])
Out[18]: 1
In [19]: len(grad_block['grad_in'])
Out[19]: 3
In [20]: len(grad_block['grad_out'])
Out[20]: 1
可以看出整袁,第二層卷積模塊的輸入和輸出都只有一個菠齿,即相應的特征圖。而輸入端的梯度值有3個坐昙,分別為權重的梯度绳匀,偏差的梯度,以及輸入特征圖的梯度炸客。輸出端的梯度值只有一個疾棵,即輸出特征圖的梯度。正如上面強調的痹仙,輸入端即使有W, X和b三個是尔,對于前項傳播來說只有X是其輸入,而對于反向傳播來說开仰,3個都是輸入嗜历。輸出端3項的梯度值排列的順序是什么呢,我們來看一下3項梯度的具體維度:
In [21]: grad_block['grad_in'][0].shape
Out[21]: torch.Size([1, 6, 16, 16])
In [22]: grad_block['grad_in'][1].shape
Out[22]: torch.Size([9, 6, 3, 3])
In [23]: grad_block['grad_in'][2].shape
Out[23]: torch.Size([9])
從輸出端梯度的維度可以判斷抖所,第一個顯然是特征圖的梯度,第二個則是權重(卷積核/濾波器)的梯度痕囱,第三個是偏置的梯度田轧。為了驗證梯度和這些參數具有同樣的維度,我們再來看看這三個值前向傳播時的維度:
In [24]: fmap_block['input'][0].shape
Out[24]: torch.Size([1, 6, 16, 16])
In [25]: net.conv2.weight.shape
Out[25]: torch.Size([9, 6, 3, 3])
In [26]: net.conv2.bias.shape
Out[26]: torch.Size([9])
可以看到鞍恢,我們的判斷是正確的傻粘。
最后需要注意的一點是,如果需要獲取輸入圖像的梯度帮掉,一定要將輸入Tensor的requires_grad屬性設為True弦悉。
原創(chuàng)不易,有用請點贊支持~