自動求導(dǎo)機(jī)制
PyTorch 中所有神經(jīng)網(wǎng)絡(luò)的核心是 autograd 包柒凉。 我們先簡單介紹一下這個包动雹,然后訓(xùn)練第一個簡單的神經(jīng)網(wǎng)絡(luò)塞琼。
autograd包為張量上的所有操作提供了自動求導(dǎo)叠洗。 它是一個在運(yùn)行時定義的框架凫海,這意味著反向傳播是根據(jù)你的代碼來確定如何運(yùn)行呛凶,并且每次迭代可以是不同的。
張量(Tensor)
torch.Tensor是這個包的核心類行贪,如果設(shè)置.requires_grad為True, 那么就會追蹤所有對于該張量的操作漾稀。當(dāng)完成計算后通過調(diào)用.backward(),自動計算所有梯度建瘫,這個張量的所有梯度都會自動累計到.grad屬性崭捍。
要阻止張量跟蹤歷史記錄,可以調(diào)用.detach()方法將其與計算歷史記錄分離啰脚,并禁止跟蹤它將來的計算記錄殷蛇。
推斷時 為了防止跟蹤歷史記錄(和使用內(nèi)存),可以將代碼塊包裝在with torch.no_grad():中。 在評估模型時特別有用粒梦,因?yàn)槟P涂赡芫哂衦equires_grad = True的可訓(xùn)練參數(shù)亮航,但是我們不需要梯度計算。
在自動梯度計算中還有另外一個重要的類Function.
Tensor和Function互相連接并生成一個非循環(huán)圖匀们,它表示和存儲了完整的計算歷史缴淋,每個張量都由.grad_fn屬性,這個屬性引用了一個創(chuàng)建了Tensor的Function(除非這個Tensor是由用戶手動創(chuàng)建昼蛀,即宴猾,該張量的.grad_fn是None)
如果需要計算導(dǎo)數(shù),你可以再Tensor上調(diào)用.backward()叼旋。如果Tensor是一個標(biāo)量(即它包含一個元素數(shù)據(jù))則不需要為backward()指定任何參數(shù),但是如果有更多元素沦辙,你需要指定一個gradient參數(shù)來匹配張量形狀夫植。
import torch
x = torch.ones(2,2,requires_grad=True)
print(x)
# tensor([[1., 1.],
# [1., 1.]], requires_grad=True)
y = x+2
print(y)
#tensor([[3., 3.],
# [3., 3.]], grad_fn=<AddBackward0>)
x.requires_grad_(True)
# 可以用于改變requires_grad屬性
梯度
反向傳播因?yàn)閛ut是一個純量(scalar), out.backward()等于out.backward(torch.tensor(1))
out.backward()
# 打印 d(out)/dx 梯度
print(x.grad)
如果.requires_grad = True但是又不希望進(jìn)行autograd計算,可以將變量包裹在with torch.no_grad()中
總結(jié)流程:
- 當(dāng)我們執(zhí)行z.backward()的時候油讯,這個操作將調(diào)用z里面的grad_fn屬性详民,執(zhí)行求導(dǎo)的操作
- 這個操作將遍歷grad_fn的next_functions, 然后分別取出里面的Function(AccumulateGrad)陌兑, 執(zhí)行求導(dǎo)操作沈跨,這個部分是一個遞歸的過程直到最后類型為葉子節(jié)點(diǎn)。
- 計算出結(jié)果以后兔综,將結(jié)果保存在他們對應(yīng)的variable這個變量所引用的對象(x和y)的grad這個屬性里面
- 求導(dǎo)結(jié)束饿凛。所有葉節(jié)點(diǎn)的grad變量都得到相應(yīng)的更新
最終當(dāng)我們執(zhí)行完c.backward()之后,a和b里面的grad值就得到了更新软驰。
擴(kuò)展Autograd
如果需要自定義autograd擴(kuò)展新的功能涧窒,需要擴(kuò)展Function類,因?yàn)镕unction使用autograd來計算結(jié)果和梯度锭亏,并對操作歷史進(jìn)行編碼纠吴。在Function類中最主要的方法是forward()和backward()他們分別代表前向傳播和后向傳播。
一個自定義的Function需要以下三個方法:
init(optional):如果該操作需要額外的參數(shù)慧瘤,則需要定義該function的構(gòu)造函數(shù)戴已,不需要可以省略
forward(): 執(zhí)行前向傳播的計算代碼
backward():執(zhí)行后向傳播的計算代碼
# 引入Function便于擴(kuò)展
from torch.autograd.function import Function
# 定義一個乘以常數(shù)的操作(輸入?yún)?shù)是張量)
# 方法必須是靜態(tài)方法,所以要加上@staticmethod
class MulConstant(Function):
@staticmethod
def forward(ctx, tensor, constant):
# ctx 用來保存信息這里類似self锅减,并且ctx的屬性可以在backward中調(diào)用
ctx.constant=constant
return tensor *constant
@staticmethod
def backward(ctx, grad_output):
# 返回的參數(shù)要與輸入的參數(shù)一樣.
# 第一個輸入為3x3的張量糖儡,第二個為一個常數(shù)
# 常數(shù)的梯度必須是 None.
return grad_output, None
# 測試定義的Function
a=torch.rand(3,3,requires_grad=True)
b=MulConstant.apply(a,5)
print("a:"+str(a))
print("b:"+str(b)) # b為a的元素乘以5
#a:tensor([[0.0118, 0.1434, 0.8669],
# [0.1817, 0.8904, 0.5852],
# [0.7364, 0.5234, 0.9677]], #requires_grad=True)
#b:tensor([[0.0588, 0.7169, 4.3347],
# [0.9084, 4.4520, 2.9259],
# [3.6820, 2.6171, 4.8386]], grad_fn=<MulConstantBackward>)
# 反向傳播,返回值不是標(biāo)量上煤,所以backward要參數(shù)
b.backward(torch.ones_like(a))
a.grad
#tensor([[1., 1., 1.],
# [1., 1., 1.],
# [1., 1., 1.]])