第一步
類的定義
import numpy as np
def as_array(x):
    """Return ``x`` wrapped in a NumPy array when it is a scalar.

    Anything that ``np.isscalar`` does not classify as a scalar (for example
    an existing ``ndarray``) is handed back unchanged, as the same object.
    """
    return np.array(x) if np.isscalar(x) else x
class Variable:
    """A node of the computational graph: a value plus its gradient.

    Attributes:
        data: The wrapped value.
        grad: Gradient with respect to this variable; ``None`` until it has
            been computed or assigned.
        creator: The function object that produced this variable, or ``None``
            for a variable created directly by the user.
    """

    def __init__(self, data):
        self.data = data
        self.grad = None  # gradient, filled in during back-propagation
        self.creator = None  # function that produced this variable, if any

    def set_creator(self, func):
        """Record ``func`` as the creator of this variable."""
        self.creator = func

    def backward(self):
        """Propagate gradients from this variable back through its creators.

        If ``self.grad`` has not been set, it is seeded with ones shaped like
        ``self.data`` (d(self)/d(self) = 1), so callers do not have to assign
        ``y.grad = np.array(1.0)`` themselves. An explicitly assigned gradient
        is left untouched.

        The chain of creators is walked with an explicit stack instead of
        recursion: each function's ``backward`` receives its output's gradient
        and the result is stored on its input variable.
        """
        if self.grad is None:
            # The seed must be ones; a zero seed would make every propagated
            # gradient zero.
            self.grad = np.ones_like(self.data)

        if self.creator is None:
            # A user-created variable has nothing upstream to propagate to.
            return

        funcs = [self.creator]
        while funcs:
            f = funcs.pop()
            x, y = f.input, f.output
            x.grad = f.backward(y.grad)
            if x.creator is not None:
                funcs.append(x.creator)
class Function:
    """Base class for functions that map one Variable to one Variable.

    Subclasses implement ``forward`` (computation on the raw data) and
    ``backward`` (gradient of the input given the gradient of the output).
    """

    def __call__(self, input2):
        """Apply the function to ``input2`` and link the result into the graph.

        Args:
            input2: The input variable; its ``data`` attribute is passed to
                ``forward``.

        Returns:
            A new ``Variable`` holding the forward result, with this function
            recorded as its creator.
        """
        x = input2.data
        y = self.forward(x)
        # forward may return a scalar; as_array turns it back into an array.
        output = Variable(as_array(y))
        # Storing the creator on the output is what builds the graph
        # dynamically as the computation runs.
        output.set_creator(self)
        self.input = input2  # kept so backward can read the input data
        self.output = output  # kept so back-propagation can read its gradient
        return output

    def forward(self, x):
        """Compute the output from raw input data; subclasses must override."""
        raise NotImplementedError

    def backward(self, gy):
        """Compute the input gradient from output gradient ``gy``; subclasses must override."""
        raise NotImplementedError
class SquareFunction(Function):
    """Square function: y = x ** 2."""

    def forward(self, x):
        """Return ``x`` raised to the power of two."""
        return x ** 2

    def backward(self, gy):
        """Return the input gradient ``2 * x * gy`` for output gradient ``gy``."""
        saved_input = self.input.data
        return 2 * saved_input * gy
class ExpFunction(Function):
    """Exponential function: y = exp(x)."""

    def forward(self, x):
        """Return ``np.exp`` of the raw input ``x``."""
        result = np.exp(x)
        return result

    def backward(self, gy):
        """Return the input gradient ``exp(x) * gy`` for output gradient ``gy``."""
        saved_input = self.input.data
        return np.exp(saved_input) * gy
測試
import unittest
import numpy as np
from step01 import Variable, Function, ExpFunction, SquareFunction
# class MyTestCase(unittest.TestCase):
# def test_something(self):
# self.assertEqual(True, False) # add assertion here
#
#
# if __name__ == '__main__':
# unittest.main()
# Forward pass: x -> square -> a -> exp -> b -> square -> y.
f1 = SquareFunction()
f2 = ExpFunction()
f3 = SquareFunction()
inputArr = np.array(0.5)
x = Variable(inputArr)
a = f1(x)
print(a.data)
b = f2(a)
print(b.data)
y = f3(b)
print(y.data)

# Manual back-propagation: seed dy/dy = 1, then call each function's
# backward in reverse order of the forward pass.
y.grad = np.array(1.0)
b.grad = f3.backward(y.grad)
print(b.grad)
a.grad = f2.backward(b.grad)
print(a.grad)  # show the gradient that was just computed for a
x.grad = f1.backward(a.grad)
print(x.grad)

# The creator/input links recorded during the forward pass let us walk the
# graph from y all the way back to x.
assert y.creator is f3
assert y.creator.input is b
assert y.creator.input.creator is f2
assert y.creator.input.creator.input is a
assert y.creator.input.creator.input.creator is f1
assert y.creator.input.creator.input.creator.input is x

# Back-propagation procedure, one step per function:
# 1. get the function that created the variable
# 2. get that function's input
# 3. call the function's backward method

# Back-propagate from y to b.
f3 = y.creator
b = f3.input
b.grad = f3.backward(y.grad)

# Back-propagate from b to a.
f2 = b.creator
a = f2.input
a.grad = f2.backward(b.grad)

# Back-propagate from a to x.
f1 = a.creator
x = f1.input
x.grad = f1.backward(a.grad)
print(x.grad)

# Let Variable.backward() walk the whole graph automatically.
y.grad = np.array(1.0)
y.backward()
print(x.grad)
def square(x):
    """Apply a fresh ``SquareFunction`` to ``x`` and return its output."""
    func = SquareFunction()
    return func(x)
def exp(x):
    """Apply a fresh ``ExpFunction`` to ``x`` and return its output."""
    func = ExpFunction()
    return func(x)
# Same computation as before, built with the square/exp helper functions.
start_value = np.array(0.5)
x = Variable(start_value)
a = square(x)
b = exp(a)
y = square(b)

# Seed the output gradient with 1 and back-propagate down to x.
seed_grad = np.array(1.0)
y.grad = seed_grad
y.backward()

print("------------------")
print(x.grad)