本人學(xué)習(xí)pytorch主要參考官方文檔和 莫煩Python中的pytorch視頻教程。
后文主要是對pytorch官網(wǎng)的文檔的總結(jié)。
代碼來自pytorch官網(wǎng)
import torch
# 通過繼承torch.autograd.Function類到逊,并實現(xiàn)forward 和 backward函數(shù)
class MyReLU(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
"""
在forward函數(shù)中,接收包含輸入的Tensor并返回包含輸出的Tensor禽拔。
ctx是環(huán)境變量恨诱,用于提供反向傳播是需要的信息≌秤牛可通過ctx.save_for_backward方法緩存數(shù)據(jù)仇味。
"""
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def backward(ctx, grad_output):
"""
在backward函數(shù)中,接收包含了損失梯度的Tensor雹顺,
我們需要根據(jù)輸入計算損失的梯度丹墨。
"""
input, = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0
return grad_input
dtype = torch.float
device = torch.device("cpu")
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)
learning_rate = 1e-6
for t in range(500):
relu = MyReLU.apply
y_pred = relu(x.mm(w1)).mm(w2)
loss = (y_pred - y).pow(2).sum()
print(t, loss.item())
loss.backward()
with torch.no_grad():
w1 -= learning_rate * w1.grad
w2 -= learning_rate * w2.grad
w1.grad.zero_()
w2.grad.zero_()