前言
梯度下降法(Gradient Descent)是機(jī)器學(xué)習(xí)中最常用的優(yōu)化方法之一年局,常用來(lái)求解目標(biāo)函數(shù)的極值。
其基本原理非常簡(jiǎn)單:沿著目標(biāo)函數(shù)梯度下降的方向搜索極小值(也可以沿著梯度上升的方向搜索極大值)刷晋。
但是如何調(diào)整搜索的步長(zhǎng)(也叫學(xué)習(xí)率盖高,Learning Rate)、如何加快收斂速度以及如何防止搜索時(shí)發(fā)生震蕩卻是一門(mén)值得深究的學(xué)問(wèn)眼虱。接下來(lái)本文將分析第一個(gè)問(wèn)題:學(xué)習(xí)率的大小對(duì)搜索過(guò)程的影響喻奥。全部源代碼可在本人的GitHub:monitor1379中下載。
快速教程
前言啰嗦完了捏悬,接下來(lái)直接上干貨:如何編寫(xiě)梯度下降法撞蚕。代碼運(yùn)行環(huán)境為Python 2.7.11
+ NumPy 1.11.0
+ Matplotlib 1.5.1
。
首先先假設(shè)現(xiàn)在我們需要求解目標(biāo)函數(shù)func(x) = x * x
的極小值过牙,由于func
是一個(gè)凸函數(shù)诈豌,因此它唯一的極小值同時(shí)也是它的最小值仆救,其一階導(dǎo)函數(shù) 為dfunc(x) = 2 * x
。
import numpy as np
import matplotlib.pyplot as plt
# 目標(biāo)函數(shù):y=x^2
def func(x):
return np.square(x)
# 目標(biāo)函數(shù)一階導(dǎo)數(shù):dy/dx=2*x
def dfunc(x):
return 2 * x
接下來(lái)編寫(xiě)梯度下降法函數(shù):
# Gradient Descent
def GD(x_start, df, epochs, lr):
"""
梯度下降法矫渔。給定起始點(diǎn)與目標(biāo)函數(shù)的一階導(dǎo)函數(shù)彤蔽,求在epochs次迭代中x的更新值
:param x_start: x的起始點(diǎn)
:param df: 目標(biāo)函數(shù)的一階導(dǎo)函數(shù)
:param epochs: 迭代周期
:param lr: 學(xué)習(xí)率
:return: x在每次迭代后的位置(包括起始點(diǎn)),長(zhǎng)度為epochs+1
"""
xs = np.zeros(epochs+1)
x = x_start
xs[0] = x
for i in range(epochs):
dx = df(x)
# v表示x要改變的幅度
v = - dx * lr
x += v
xs[i+1] = x
return xs
需要注意的是參數(shù)df
是一個(gè)函數(shù)指針庙洼,即需要傳進(jìn)我們的目標(biāo)函數(shù)一階導(dǎo)函數(shù)顿痪。
測(cè)試代碼如下,假設(shè)起始搜索點(diǎn)為-5油够,迭代周期為5蚁袭,學(xué)習(xí)率為0.3:
def demo0_GD():
x_start = -5
epochs = 5
lr = 0.3
x = GD(x_start, dfunc, epochs, lr=lr)
print x
# 輸出:[-5. -2. -0.8 -0.32 -0.128 -0.0512]
繼續(xù)修改一下demo0_GD
函數(shù)以更加直觀地查看梯度下降法的搜索過(guò)程:
def demo0_GD():
"""演示如何使用梯度下降法GD()"""
line_x = np.linspace(-5, 5, 100)
line_y = func(line_x)
x_start = -5
epochs = 5
lr = 0.3
x = GD(x_start, dfunc, epochs, lr=lr)
color = 'r'
plt.plot(line_x, line_y, c='b')
plt.plot(x, func(x), c=color, label='lr={}'.format(lr))
plt.scatter(x, func(x), c=color, )
plt.legend()
plt.show()
從運(yùn)行結(jié)果來(lái)看,當(dāng)學(xué)習(xí)率為0.3的時(shí)候石咬,迭代5個(gè)周期似乎便能得到蠻不錯(cuò)的結(jié)果了揩悄。
梯度下降法確實(shí)是求解非線性方程極值的利器之一,但是如果學(xué)習(xí)率沒(méi)有調(diào)整好的話會(huì)發(fā)生什么樣的事情呢鬼悠?
學(xué)習(xí)率對(duì)梯度下降法的影響
在上節(jié)代碼的基礎(chǔ)上編寫(xiě)新的測(cè)試代碼demo1_GD_lr
删性,設(shè)置學(xué)習(xí)率分別為0.1、0.3與0.9:
def demo1_GD_lr():
# 函數(shù)圖像
line_x = np.linspace(-5, 5, 100)
line_y = func(line_x)
plt.figure('Gradient Desent: Learning Rate')
x_start = -5
epochs = 5
lr = [0.1, 0.3, 0.9]
color = ['r', 'g', 'y']
size = np.ones(epochs+1) * 10
size[-1] = 70
for i in range(len(lr)):
x = GD(x_start, dfunc, epochs, lr=lr[i])
plt.subplot(1, 3, i+1)
plt.plot(line_x, line_y, c='b')
plt.plot(x, func(x), c=color[i], label='lr={}'.format(lr[i]))
plt.scatter(x, func(x), c=color[i])
plt.legend()
plt.show()
從下圖輸出結(jié)果可以看出兩點(diǎn)焕窝,在迭代周期不變的情況下:
- 學(xué)習(xí)率較小時(shí)蹬挺,收斂到正確結(jié)果的速度較慢。
- 學(xué)習(xí)率較大時(shí)它掂,容易在搜索過(guò)程中發(fā)生震蕩巴帮。
綜上可以發(fā)現(xiàn),學(xué)習(xí)率大小對(duì)梯度下降法的搜索過(guò)程起著非常大的影響虐秋,為了解決上述的兩個(gè)問(wèn)題榕茧,接下來(lái)的博客《【梯度下降法】二:沖量(momentum)的原理與Python實(shí)現(xiàn)》將講解沖量(momentum)參數(shù)是如何在梯度下降法中起到加速收斂與減少震蕩的作用。