前言
梯度下降法(Gradient Descent)是機(jī)器學(xué)習(xí)中最常用的優(yōu)化方法之一窘疮,常用來(lái)求解目標(biāo)函數(shù)的極值袋哼。
其基本原理非常簡(jiǎn)單:沿著目標(biāo)函數(shù)梯度下降的方向搜索極小值(也可以沿著梯度上升的方向搜索極大值)。
但是如何調(diào)整搜索的步長(zhǎng)(也叫學(xué)習(xí)率闸衫,Learning Rate)涛贯、如何加快收斂速度以及如何防止搜索時(shí)發(fā)生震蕩卻是一門值得深究的學(xué)問。
上兩篇博客《【梯度下降法】一:快速教程蔚出、Python簡(jiǎn)易實(shí)現(xiàn)以及對(duì)學(xué)習(xí)率的探討》與《【梯度下降法】二:沖量(momentum)的原理與Python實(shí)現(xiàn)》分別介紹了學(xué)習(xí)率大小對(duì)搜索過(guò)程的影響以及“沖量”的原理以及如何用“沖量”來(lái)解決收斂速度慢與收斂時(shí)發(fā)生震蕩的問題弟翘。接下來(lái)本篇博客將介紹梯度下降法中的第三個(gè)超參數(shù):decay虫腋。
PS:本系列博客全部源代碼可在本人的GitHub:monitor1379中下載。
學(xué)習(xí)率衰減因子:decay
首先先回顧一下不同學(xué)習(xí)率下梯度下降法的收斂過(guò)程(示例代碼在GitHub上可下載):
從上圖可看出稀余,學(xué)習(xí)率較大時(shí)悦冀,容易在搜索過(guò)程中發(fā)生震蕩,而發(fā)生震蕩的根本原因無(wú)非就是搜索的步長(zhǎng)邁的太大了睛琳。
回顧一下問題本身盒蟆,在使用梯度下降法求解目標(biāo)函數(shù)func(x) = x * x
的極小值時(shí),更新公式為x += v
师骗,其中每次x
的更新量v
為v = - dx * lr
历等,dx
為目標(biāo)函數(shù)func(x)
對(duì)x
的一階導(dǎo)數(shù)”侔可以想到寒屯,如果能夠讓lr
隨著迭代周期不斷衰減變小,那么搜索時(shí)邁的步長(zhǎng)就能不斷減少以減緩震蕩黍少。學(xué)習(xí)率衰減因子由此誕生:
lr_i = lr_start * 1.0 / (1.0 + decay * i)
上面的公式即為學(xué)習(xí)率衰減公式寡夹,其中lr_i
為第i
次迭代時(shí)的學(xué)習(xí)率,lr_start
為原始學(xué)習(xí)率厂置,decay
為一個(gè)介于[0.0, 1.0]
的小數(shù)菩掏。
從公式上可看出:
-
decay
越小,學(xué)習(xí)率衰減地越慢农渊,當(dāng)decay = 0
時(shí)患蹂,學(xué)習(xí)率保持不變或颊。 -
decay
越大砸紊,學(xué)習(xí)率衰減地越快,當(dāng)decay = 1
時(shí)囱挑,學(xué)習(xí)率衰減最快醉顽。
使用decay的梯度下降法Python實(shí)現(xiàn)代碼如下:
import numpy as np
import matplotlib.pyplot as plt
# 目標(biāo)函數(shù):y=x^2
def func(x):
return np.square(x)
# 目標(biāo)函數(shù)一階導(dǎo)數(shù):dy/dx=2*x
def dfunc(x):
return 2 * x
def GD_decay(x_start, df, epochs, lr, decay):
"""
帶有學(xué)習(xí)率衰減因子的梯度下降法。
:param x_start: x的起始點(diǎn)
:param df: 目標(biāo)函數(shù)的一階導(dǎo)函數(shù)
:param epochs: 迭代周期
:param lr: 學(xué)習(xí)率
:param decay: 學(xué)習(xí)率衰減因子
:return: x在每次迭代后的位置(包括起始點(diǎn))平挑,長(zhǎng)度為epochs+1
"""
xs = np.zeros(epochs+1)
x = x_start
xs[0] = x
v = 0
for i in range(epochs):
dx = df(x)
# 學(xué)習(xí)率衰減
lr_i = lr * 1.0 / (1.0 + decay * i)
# v表示x要改變的幅度
v = - dx * lr_i
x += v
xs[i+1] = x
return xs
使用以下測(cè)試與繪圖代碼demo3_GD_decay
來(lái)看一下當(dāng)學(xué)習(xí)率依次為lr = [0.1, 0.3, 0.9, 0.99]
與decay = [0.0, 0.01, 0.5, 0.9]
時(shí)的效果如何:
def demo3_GD_decay():
line_x = np.linspace(-5, 5, 100)
line_y = func(line_x)
plt.figure('Gradient Desent: Decay')
x_start = -5
epochs = 10
lr = [0.1, 0.3, 0.9, 0.99]
decay = [0.0, 0.01, 0.5, 0.9]
color = ['k', 'r', 'g', 'y']
row = len(lr)
col = len(decay)
size = np.ones(epochs + 1) * 10
size[-1] = 70
for i in range(row):
for j in range(col):
x = GD_decay(x_start, dfunc, epochs, lr=lr[i], decay=decay[j])
plt.subplot(row, col, i * col + j + 1)
plt.plot(line_x, line_y, c='b')
plt.plot(x, func(x), c=color[i], label='lr={}, de={}'.format(lr[i], decay[j]))
plt.scatter(x, func(x), c=color[i], s=size)
plt.legend(loc=0)
plt.show()
運(yùn)行結(jié)果如下圖所示游添,其中每行圖片的學(xué)習(xí)率一樣、decay依次增加通熄,每列圖片decay一樣唆涝,學(xué)習(xí)率依次增加:
簡(jiǎn)單分析一下結(jié)果:
- 在所有行中均可以看出,decay越大唇辨,學(xué)習(xí)率衰減地越快廊酣。
- 在第三行與第四行可看到,decay確實(shí)能夠?qū)φ鹗幤鸬綔p緩的作用赏枚。
那么亡驰,不同decay下學(xué)習(xí)率的衰減速度到底有多大的區(qū)別呢晓猛?接下來(lái)設(shè)置起始學(xué)習(xí)率為1.0,decay依次為[0.0, 0.001, 0.1, 0.5, 0.9, 0.99]
凡辱,迭代周期為300時(shí)學(xué)習(xí)率衰減的情況戒职,測(cè)試與繪圖代碼如下:
def demo4_how_to_chose_decay():
lr = 1.0
iterations = np.arange(300)
decay = [0.0, 0.001, 0.1, 0.5, 0.9, 0.99]
for i in range(len(decay)):
decay_lr = lr * (1.0 / (1.0 + decay[i] * iterations))
plt.plot(iterations, decay_lr, label='decay={}'.format(decay[i]))
plt.ylim([0, 1.1])
plt.legend(loc='best')
plt.show()
運(yùn)行結(jié)果如下圖所示⊥盖可以看到洪燥,當(dāng)decay為0.1時(shí),50次迭代后學(xué)習(xí)率已從1.0急劇降低到了0.2乳乌。如果decay設(shè)置得太大蚓曼,則可能會(huì)收斂到一個(gè)不是極值的地方呢∏张ぃ看來(lái)調(diào)參真是任重而道遠(yuǎn):
后記
關(guān)于【梯度下降法】的三個(gè)超參數(shù)的原理纫版、實(shí)現(xiàn)以及優(yōu)缺點(diǎn)已經(jīng)介紹完畢。對(duì)機(jī)器學(xué)習(xí)客情、深度學(xué)習(xí)與計(jì)算機(jī)視覺感興趣的童鞋可以關(guān)注本博主的簡(jiǎn)書博客以及GitHub:monitor1379哦~后續(xù)將繼續(xù)上更多的硬干貨其弊,謝謝大家的支持:P