Pytorch Hook 函數(shù)

Pytorch中帶了Hook函數(shù)恭金,Hook的中文意思是’鉤子‘蜓肆,剛開(kāi)始看到這個(gè)詞語(yǔ)就有點(diǎn)害怕,一是不認(rèn)識(shí)這個(gè)詞罢防,翻譯成中文也不了解這是什么意思艘虎;二是常規(guī)調(diào)庫(kù)搭積木時(shí)也沒(méi)有用到過(guò)這個(gè)函數(shù)唉侄;直到讀到下面文章,https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca 我對(duì)hook有了初步的理解

1. 為什么需要 hook 函數(shù)

  • 當(dāng)我們的神經(jīng)網(wǎng)絡(luò)出現(xiàn) bug 時(shí)野建,沒(méi)法產(chǎn)生我們所期望的輸出時(shí)属划,我們通常需要進(jìn)行debug,一般的做法是在 forward 函數(shù)中寫(xiě) print函數(shù)候生,輸出某些層的輸出同眯;或者通過(guò)添加斷點(diǎn)來(lái)進(jìn)行單步調(diào)試,以觀(guān)察中間層的輸出唯鸭。這在 pytorch 中就可以通過(guò) hook 函數(shù)來(lái)實(shí)現(xiàn)须蜗。
  • 由于pytorhc的自動(dòng)求導(dǎo)機(jī)制,即當(dāng)設(shè)置參數(shù)的 requires_grad=True時(shí),那么涉及這組參數(shù)的一系列操作將會(huì)被autograd記錄用以反向求導(dǎo)明肮。但是在自動(dòng)求導(dǎo)機(jī)制中只保存葉子節(jié)點(diǎn)菱农,也就是中間變量在計(jì)算完成梯度后會(huì)自動(dòng)釋放以節(jié)省空間
x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
y = x * 2
z = torch.mean(y)
z.backward()
print("x.grad =", x.grad)
print("y.grad =", y.grad)
print("z.grad =", z.grad)

輸出

x.grad = tensor([1., 1.])
y.grad = None
z.grad = None

因此,如果我們想知道 y 和 z 的梯度柿估,就需要用到 hook 函數(shù)循未。
也就是說(shuō),hook 函數(shù)用以獲取我們不方便獲得的一些中間變量秫舌。

2. 什么是hook函數(shù)

  • hook 其實(shí)就是一個(gè)普通的函數(shù)或類(lèi)的妖,準(zhǔn)確的說(shuō)是一個(gè)可調(diào)用的對(duì)象,callable object. 需要什么樣的功能我們可根據(jù)自己的需求自己寫(xiě)足陨∩┧冢總之,hook 和我們常規(guī)寫(xiě)的函數(shù)和類(lèi)沒(méi)有區(qū)別墨缘。但是 pytorch 有一個(gè)機(jī)制赋元,我們可以把寫(xiě)好的函數(shù)或者類(lèi)注冊(cè)到某些 layer (nn.Module)上,這樣子當(dāng)這些 layer 在執(zhí)行 forward 或者 backward時(shí)其輸入或輸出就會(huì)自動(dòng)傳到我們寫(xiě)好的hook函數(shù)中執(zhí)行飒房。因此搁凸,這些函數(shù)就像一個(gè)鉤子一樣,可以?huà)斓侥承﹍ayer上或者從這些 layer 上解掛狠毯。這就是名字叫 hook 的原因护糖。

3. Pytorch 提供的 Hook

  • 一般來(lái)說(shuō),我們?cè)?debug 時(shí)想知道的內(nèi)容有三種
    • 某個(gè)模塊的輸入是什么嚼松,即 在跑 forward前模塊的輸入
    • 某個(gè)模塊的輸出是什么嫡良,即 在跑 forward后模塊的輸出
    • 某個(gè)模塊的梯度反傳后是什么,即 在跑 backward后模塊的狀態(tài)
  • 將這三個(gè)狀態(tài)的數(shù)據(jù)與我們所期望的數(shù)據(jù)進(jìn)行比較献酗,我們就可以知道哪里出現(xiàn)了問(wèn)題寝受;Pytorch 就提供了這三種鉤子,把這三種鉤子掛到指定的layer上罕偎,這些layer的輸入輸出就會(huì)對(duì)應(yīng)的作為參數(shù)傳到hook函數(shù)中運(yùn)行hook函數(shù)很澄。下圖引用自
    image.png
  • pytorch nn.Module源碼中就提供了這三個(gè)屬性
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
  • 同時(shí)提供了三個(gè)注冊(cè)方法,也就是往上面三個(gè)dict中填值的方法
    • forward prehook (executing before the forward pass),
    • forward hook (executing after the forward pass),
    • backward hook (executing after the backward pass).

register_forward_pre_hookforward前運(yùn)行颜及,獲取這一個(gè) module 的輸入

    def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        r"""Registers a forward pre-hook on the module.

        The hook will be called every time before :func:`forward` is invoked.
        It should have the following signature::

            hook(module, input) -> None or modified input

        The input contains only the positional arguments given to the module.
        Keyword arguments won't be passed to the hooks and only to the ``forward``.
        The hook can modify the input. User can either return a tuple or a
        single modified value in the hook. We will wrap the value into a tuple
        if a single value is returned(unless that value is already a tuple).

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_pre_hooks)
        self._forward_pre_hooks[handle.id] = hook
        return handle

register_forward_hook在forward后運(yùn)行甩苛,獲取這個(gè)module的input和output信息

    def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        r"""Registers a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::

            hook(module, input, output) -> None or modified output

        The input contains only the positional arguments given to the module.
        Keyword arguments won't be passed to the hooks and only to the ``forward``.
        The hook can modify the output. It can modify the input inplace but
        it will not have effect on forward since this is called after
        :func:`forward` is called.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

register_backward_hook,獲取反向傳播中module的grad_in, grad_out信息

    def register_backward_hook(
        self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
    ) -> RemovableHandle:
        r"""Registers a backward hook on the module.

        This function is deprecated in favor of :meth:`nn.Module.register_full_backward_hook` and
        the behavior of this function will change in future versions.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        """
        if self._is_full_backward_hook is True:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")

        self._is_full_backward_hook = False

        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

4.hook 實(shí)例

這里我們通過(guò)在ResNet34的每一層插入一個(gè)鉤子,來(lái)獲取ResNet34每一層的輸出俏站,即這里我們使用 register_forward_hook
使用下面圖片作為輸入

image.png

import torch
from torchvision.models import resnet34

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = resnet34(pretrained=True)
model = model.to(device)
class SaveOutput:
    def __init__(self):
        self.outputs = []
        self.inputs = []
        
    def __call__(self, module, module_in, module_out):
        print(module)
        self.inputs.append(module_in)
        self.outputs.append(module_out)
        
    def clear(self):
        self.outputs = []
        self.inputs = []
        

save_output = SaveOutput()

hook_handles = []

for layer in model.modules():
    if isinstance(layer, torch.nn.modules.conv.Conv2d):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)
        
        
from PIL import Image
from torchvision import transforms as T

img = Image.open('./cat.jpeg')
transform = T.Compose([T.Resize((224,224)),
                       T.ToTensor(),
                       T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.485, 0.456, 0.406],)
                      ])
x = transform(img).unsqueeze(0).to(device)
out = model(x)

輸出

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

> save_output.outputs[0].size()
torch.Size([1, 64, 112, 112])
> save_output.inputs[0][0].size()
torch.Size([1, 3, 224, 224])

可以看到模塊讯蒲,模塊的輸入輸出會(huì)自動(dòng)作為參數(shù)傳入到我們寫(xiě)的SaveOutput實(shí)例中并調(diào)用該實(shí)例。
下面是每一層的輸出可視化

image.png


對(duì)于 Tensor的 hook

x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
y = x * 2
y.register_hook(print)
z = torch.mean(y)
z.backward()

輸出:

tensor([0.5000, 0.5000])

hook 應(yīng)用于 模型剪枝 model pruning
https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末肄扎,一起剝皮案震驚了整個(gè)濱河市墨林,隨后出現(xiàn)的幾起案子赁酝,更是在濱河造成了極大的恐慌,老刑警劉巖旭等,帶你破解...
    沈念sama閱讀 219,490評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件赞哗,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡辆雾,警方通過(guò)查閱死者的電腦和手機(jī)肪笋,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,581評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門(mén),熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)度迂,“玉大人藤乙,你說(shuō)我怎么就攤上這事〔涯梗” “怎么了坛梁?”我有些...
    開(kāi)封第一講書(shū)人閱讀 165,830評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀(guān)的道長(zhǎng)腊凶。 經(jīng)常有香客問(wèn)我划咐,道長(zhǎng),這世上最難降的妖魔是什么钧萍? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,957評(píng)論 1 295
  • 正文 為了忘掉前任褐缠,我火速辦了婚禮,結(jié)果婚禮上风瘦,老公的妹妹穿的比我還像新娘队魏。我一直安慰自己,他們只是感情好万搔,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,974評(píng)論 6 393
  • 文/花漫 我一把揭開(kāi)白布胡桨。 她就那樣靜靜地躺著,像睡著了一般瞬雹。 火紅的嫁衣襯著肌膚如雪昧谊。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書(shū)人閱讀 51,754評(píng)論 1 307
  • 那天酗捌,我揣著相機(jī)與錄音呢诬,去河邊找鬼。 笑死意敛,一個(gè)胖子當(dāng)著我的面吹牛馅巷,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播草姻,決...
    沈念sama閱讀 40,464評(píng)論 3 420
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼稍刀!你這毒婦竟也來(lái)了撩独?” 一聲冷哼從身側(cè)響起敞曹,我...
    開(kāi)封第一講書(shū)人閱讀 39,357評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎综膀,沒(méi)想到半個(gè)月后澳迫,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,847評(píng)論 1 317
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡剧劝,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,995評(píng)論 3 338
  • 正文 我和宋清朗相戀三年橄登,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片讥此。...
    茶點(diǎn)故事閱讀 40,137評(píng)論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡拢锹,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出萄喳,到底是詐尸還是另有隱情卒稳,我是刑警寧澤,帶...
    沈念sama閱讀 35,819評(píng)論 5 346
  • 正文 年R本政府宣布他巨,位于F島的核電站充坑,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏染突。R本人自食惡果不足惜捻爷,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,482評(píng)論 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望份企。 院中可真熱鬧役衡,春花似錦、人聲如沸薪棒。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 32,023評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)俐芯。三九已至棵介,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間吧史,已是汗流浹背邮辽。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 33,149評(píng)論 1 272
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留贸营,地道東北人吨述。 一個(gè)月前我還...
    沈念sama閱讀 48,409評(píng)論 3 373
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像钞脂,于是被迫代替她去往敵國(guó)和親揣云。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,086評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容