希望以后可以一直注意到這個問題,就是PyTorch的圖在進(jìn)行backward的時候是不保存中間變量的grad的,因此在之后用.grad
去查看梯度來檢查梯度傳播是無效的。這個問題,也可參見why-cant-i-see-grad-of-an-intermediate-variable中提出的詳細(xì)例子陷虎。
那如何能夠查看中間變量的梯度呢帕膜?Adam Paszke在這個問題底下又給出了一個簡短而有用的例子,具體如下:
grads = {}
def save_grad(name):
def hook(grad):
grads[name] = grad
return hook
x = Variable(torch.randn(1,1), requires_grad=True)
y = 3*x
z = y**2
# In here, save_grad('y') returns a hook (a function) that keeps 'y' as name
y.register_hook(save_grad('y'))
z.register_hook(save_grad('z'))
z.backward()
print(grads['y'])
print(grads['z'])
主要是通過hook機(jī)制,使得PyTorch圖在進(jìn)行backward的時候觸發(fā)保存下中間變量的grad养盗。