轉(zhuǎn)載:https://zhuanlan.zhihu.com/p/27783097
https://zhuanlan.zhihu.com/XavierLin
0. 本章內(nèi)容
在本次,我們將學習如何自定義一個torch.autograd.Function,下面是本次的主要內(nèi)容
1. 對Function的直觀理解晶姊;
2. Function與Module的差異與應用場景黄刚;
3. 寫一個簡單的ReLU Function;
1.對Function的直觀理解
? ? ?? 在之前的介紹中呆万,我們知道,Pytorch是利用Variable與Function來構建計算圖的⊙钒剩回顧下Variable,Variable就像是計算圖中的節(jié)點岂傲,保存計算結果(包括前向傳播的激活值难裆,反向傳播的梯度),而Function就像計算圖中的邊镊掖,實現(xiàn)Variable的計算乃戈,并輸出新的Variable。Function簡單說就是對Variable的運算亩进,如加減乘除症虑,relu,pool等归薛。但它不僅僅是簡單的運算谍憔。與普通Python或者numpy的運算不同,F(xiàn)unction是針對計算圖主籍,需要計算反向傳播的梯度习贫。因此他不僅需要進行該運算(forward過程),還需要保留前向傳播的輸入(為計算梯度)崇猫,并支持反向傳播計算梯度沈条。如果有做過公開課cs231的作業(yè),記得里面的每個運算都定義了forward诅炉,backward蜡歹,并通過保存cache來進行反向傳播屋厘。這兩者是類似的。在之前Variable的學習中月而,我們知道進行一次運算后汗洒,輸出的Variable對應的creator就是其運行的計算,如y = relu(x), y.creator父款,就是relu這個Function溢谤。我們可以對Function進行拓展,使其滿足我們自己的需要憨攒,而拓展就需要自定義Function的forward運算世杀,以及對應的backward運算,同時在forward中需要通過保存輸入值用于backward肝集≌鞍樱總結,F(xiàn)unction與Variable構成了pytorch的自動求導機制杏瞻,它定義的是各個Variable之間的計算關系所刀。
2. Function與Module的差異與應用場景
Function與Module都可以對pytorch進行自定義拓展,使其滿足網(wǎng)絡的需求捞挥,但這兩者還是有十分重要的不同:
1)Function一般只定義一個操作浮创,因為其無法保存參數(shù),因此適用于激活函數(shù)砌函、pooling等操作斩披;Module是保存了參數(shù),因此適合于定義一層胸嘴,如線性層雏掠,卷積層,也適用于定義一個網(wǎng)絡劣像。
2)Function需要定義三個方法:__init__, forward, backward(需要自己寫求導公式);Module:只需定義__init__和forward摧玫,而backward的計算由自動求導機制構成耳奕。
3)可以不嚴謹?shù)恼J為,Module是由一系列Function組成诬像,因此其在forward的過程中屋群,F(xiàn)unction和Variable組成了計算圖,在backward時坏挠,只需調(diào)用Function的backward就得到結果芍躏,因此Module不需要再定義backward。
4)Module不僅包括了Function降狠,還包括了對應的參數(shù)对竣,以及其他函數(shù)與變量庇楞,這是Function所不具備的
3. 一個ReLU Function
1)首先我們定義一個繼承Function的ReLU類;
2)然后我們來看Variable在進行運算時否纬,其creator是否是對應的Function吕晌;
3)最后我們?yōu)榉奖闶褂眠@個ReLU類,將其wrap成一個函數(shù)临燃,方便調(diào)用睛驳,不必每次顯式都創(chuàng)建一個新對象;
3.1 定義一個ReLU類
import torch
from torch.autograd import Variable
class MyReLU(torch.autograd.Function):?
? ? ? ?? def forward(self, input_):?
? ? ? ? ? ? ? ?? # 在forward中膜廊,需要定義MyReLU這個運算的forward計算過程乏沸;
? ? ? ? ? ? ? ?? # 同時可以保存任何在后向傳播中需要使用的變量值?
? ? ? ? ? ? ? ?? self.save_for_backward(input_) # 將輸入保存起來,在backward時使用?
? ? ? ? ? ? ? ?? output = input_.clamp(min=0) ? ?? # relu就是截斷負數(shù)爪瓜,讓所有負數(shù)等于0
? ? ? ? ? ? ? ?? return output?
? ? ? ?? def backward(self, grad_output):?
? ? ? ? ? ? ? ?? # 根據(jù)BP算法的推導(鏈式法則)蹬跃,dloss / dx = (dloss / doutput) * (doutput / dx)?
? ? ? ? ? ? ? ?? # dloss / doutput就是輸入的參數(shù)grad_output、?
? ? ? ? ? ? ? ? # 因此只需求relu的導數(shù)钥勋,在乘以grad_outpu?
? ? ? ? ? ? ? ? input_, = self.saved_tensors?
? ? ? ? ? ? ? ? grad_input = grad_output.clone()?
? ? ? ? ? ? ? ? grad_input[input < 0] = 0 # 上訴計算的結果就是左式炬转。即ReLU在反向
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?? #傳播中可以看做一個通道選擇函數(shù),所有未達
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?? #到閾值(激活值<0)的單元的梯度都為0?
? ? ? ? ? ? ? ? return grad_input
3.2 驗證Variable與Function的關系
from torch.autograd import Variable
input_=Variable(torch.randn(1))
relu=MyReLU()
output_=relu(input_)# 這個relu對象算灸,就是output_.creator扼劈,即這個relu對象將
? ? ? ? ? ? ? ? ? ? ? ? ? ? ?? # output與input連接起來,形成一個計算圖
print relu
print output_.creator
輸出:
<__main__.MyReLUobjectat0x7fd0b2d08b30>
<__main__.MyReLUobjectat0x7fd0b2d08b30>
可見菲驴,F(xiàn)unction連接了Variable與Variable荐吵,并實現(xiàn)不同計算。
3.3 Wrap一個ReLU函數(shù)
?? 可以直接把剛才自定義的ReLU類封裝成一個函數(shù)赊瞬,方便直接調(diào)用
def? relu(input_):
? ? ? ?? # MyReLU()是創(chuàng)建一個MyReLU對象先煎,
? ? ? ? # Function類利用了Python __call__操作,使得可以直接使用對象調(diào)用__call__制定的方法# ? ? ? ? ? #__call__指定的方法是forward巧涧,因此下面這句MyReLU()(input_)相當于 ? return ? ?
? ? ? ? # MyReLU().forward(input_)
? ? ? ? returnMyReLU()(input_)
input_=Variable(torch.linspace(-3,3,steps=5))
print ? input_
print? relu(input_)
輸出:
Variable containing:
-3.0000
-1.5000
0.0000
1.5000
3.0000
[torch.FloatTensor of size 5]
Variable containing:
0.0000
0.0000
0.0000
1.5000
3.0000
[torch.FloatTensor of size 5]