前言
梯度下降法(Gradient Descent)是機器學(xué)習(xí)中最常用的優(yōu)化方法之一,常用來求解目標(biāo)函數(shù)的極值膜眠。
其基本原理非常簡單:沿著目標(biāo)函數(shù)梯度下降的方向搜索極小值(也可以沿著梯度上升的方向搜索極大值)拗踢。
但是如何調(diào)整搜索的步長(也叫學(xué)習(xí)率脚牍,Learning Rate)、如何加快收斂速度以及如何防止搜索時發(fā)生震蕩卻是一門值得深究的學(xué)問巢墅。
在上篇博客《【梯度下降法】一:快速教程诸狭、Python簡易實現(xiàn)以及對學(xué)習(xí)率的探討》中我們簡單分析了學(xué)習(xí)率大小對搜索過程的影響券膀,發(fā)現(xiàn):
- 學(xué)習(xí)率較小時,收斂到極值的速度較慢驯遇。
- 學(xué)習(xí)率較大時芹彬,容易在搜索過程中發(fā)生震蕩。
因此本篇博客將簡單講解“沖量”的原理以及如何用“沖量”來解決上述兩個問題叉庐。
全部源代碼可在本人的GitHub:monitor1379中下載舒帮。
沖量:momentum
“沖量”這個概念源自于物理中的力學(xué),表示力對時間的積累效應(yīng)陡叠。
在普通的梯度下降法x += v
中玩郊,每次x
的更新量v
為v = - dx * lr
,其中dx
為目標(biāo)函數(shù)func(x)
對x
的一階導(dǎo)數(shù)匾竿,瓦宜。
當(dāng)使用沖量時,則把每次x
的更新量v
考慮為本次的梯度下降量- dx * lr
與上次x
的更新量v
乘上一個介于[0, 1]
的因子momentum
的和岭妖,即v = - dx * lr + v * momemtum
。
從公式上可看出:
- 當(dāng)本次梯度下降
- dx * lr
的方向與上次更新量v
的方向相同時反璃,上次的更新量能夠?qū)Ρ敬蔚乃阉髌鸬揭粋€正向加速的作用昵慌。 - 當(dāng)本次梯度下降
- dx * lr
的方向與上次更新量v
的方向相反時,上次的更新量能夠?qū)Ρ敬蔚乃阉髌鸬揭粋€減速的作用淮蜈。
使用沖量的梯度下降法的Python代碼如下:
import numpy as np
import matplotlib.pyplot as plt
# 目標(biāo)函數(shù):y=x^2
def func(x):
return np.square(x)
# 目標(biāo)函數(shù)一階導(dǎo)數(shù):dy/dx=2*x
def dfunc(x):
return 2 * x
def GD_momentum(x_start, df, epochs, lr, momentum):
"""
帶有沖量的梯度下降法斋攀。
:param x_start: x的起始點
:param df: 目標(biāo)函數(shù)的一階導(dǎo)函數(shù)
:param epochs: 迭代周期
:param lr: 學(xué)習(xí)率
:param momentum: 沖量
:return: x在每次迭代后的位置(包括起始點),長度為epochs+1
"""
xs = np.zeros(epochs+1)
x = x_start
xs[0] = x
v = 0
for i in range(epochs):
dx = df(x)
# v表示x要改變的幅度
v = - dx * lr + momentum * v
x += v
xs[i+1] = x
return xs
為了查看momentum大小對不同學(xué)習(xí)率的影響梧田,此處設(shè)置學(xué)習(xí)率為lr = [0.01, 0.1, 0.6, 0.9]
淳蔼,沖量依次為momentum = [0.0, 0.1, 0.5, 0.9]
,起始位置為x_start = -5
裁眯,迭代周期為6鹉梨。測試以及繪圖代碼如下:
def demo2_GD_momentum():
line_x = np.linspace(-5, 5, 100)
line_y = func(line_x)
plt.figure('Gradient Desent: Learning Rate, Momentum')
x_start = -5
epochs = 6
lr = [0.01, 0.1, 0.6, 0.9]
momentum = [0.0, 0.1, 0.5, 0.9]
color = ['k', 'r', 'g', 'y']
row = len(lr)
col = len(momentum)
size = np.ones(epochs+1) * 10
size[-1] = 70
for i in range(row):
for j in range(col):
x = GD_momentum(x_start, dfunc, epochs, lr=lr[i], momentum=momentum[j])
plt.subplot(row, col, i * col + j + 1)
plt.plot(line_x, line_y, c='b')
plt.plot(x, func(x), c=color[i], label='lr={}, mo={}'.format(lr[i], momentum[j]))
plt.scatter(x, func(x), c=color[i], s=size)
plt.legend(loc=0)
plt.show()
運行結(jié)果如下圖所示,每一行的圖的學(xué)習(xí)率lr一樣穿稳,每一列的momentum一樣存皂,最左列為不使用momentum時的收斂情況:
簡單分析一下運行結(jié)果:
- 從第一行可看出:在學(xué)習(xí)率較小的時候,適當(dāng)?shù)膍omentum能夠起到一個加速收斂速度的作用逢艘。
- 從第四行可看出:在學(xué)習(xí)率較大的時候旦袋,適當(dāng)?shù)膍omentum能夠起到一個減小收斂時震蕩幅度的作用。
從上述兩點來看它改,momentum確實能夠解決在篇頭提到的兩個問題疤孕。
然而在第二行與第三行的最后一列圖片中也發(fā)現(xiàn)了一個問題,當(dāng)momentum較大時央拖,原本能夠正確收斂的時候卻因為剎不住車跑過頭了祭阀。那么怎么繼續(xù)解決這個新出現(xiàn)的問題呢鹉戚?下一篇博客《【梯度下降法】三:學(xué)習(xí)率衰減因子(decay)的原理與Python實現(xiàn)》將介紹如何使用學(xué)習(xí)率衰減因子decay來讓學(xué)習(xí)率隨著迭代周期不斷變小,讓梯度下降法收斂時的“震蕩”與“跑偏”進(jìn)一步減少的方法柬讨。