Gradient Descent
Definition of the Gradient
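The gradient is, by definition, a vector: the directional derivative of a function at a given point attains its maximum along the direction of this vector. In other words, at that point the function changes fastest along the direction of the gradient, and the maximum rate of change equals the magnitude (norm) of the gradient.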
<p align="right">--------Baidu Baike</p>
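The definition above can be checked numerically. The following is a minimal sketch, assuming an illustrative function `f(x, y) = x**2 + 3*y**2` and a sample point (1, 2), neither of which comes from the article: it estimates the directional derivative in many unit directions and compares the steepest one with the analytic gradient.

```python
import numpy as np

def f(x, y):
    # Illustrative function (an assumption, not taken from the article)
    return x**2 + 3.0 * y**2

def grad_f(x, y):
    # Analytic gradient of f
    return np.array([2.0 * x, 6.0 * y])

def directional_derivative(point, direction, h=1e-6):
    # Central finite-difference estimate of the rate of change of f along a unit direction
    d = direction / np.linalg.norm(direction)
    p_plus = point + h * d
    p_minus = point - h * d
    return (f(*p_plus) - f(*p_minus)) / (2.0 * h)

point = np.array([1.0, 2.0])
g = grad_f(*point)

# Sample unit directions around the circle and pick the one with the largest rate
angles = np.linspace(0.0, 2.0 * np.pi, 3600, endpoint=False)
directions = np.stack([np.cos(angles), np.sin(angles)], axis=1)
rates = np.array([directional_derivative(point, d) for d in directions])
best = directions[np.argmax(rates)]

print("gradient direction:        ", g / np.linalg.norm(g))
print("steepest sampled direction:", best)
print("max rate:", rates.max(), "gradient norm:", np.linalg.norm(g))
```

Up to the sampling resolution, the steepest sampled direction should agree with the normalized gradient, and the largest rate of change should agree with the gradient's norm.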
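For a function of one variable $y = f(x)$, the gradient is:
$$\nabla f = \frac{\mathrm{d}f}{\mathrm{d}x}$$
For a function of two variables $z = f(x, y)$, the gradient is:
$$\nabla f = \left(\frac{\partial f}{\partial x},\ \frac{\partial f}{\partial y}\right)$$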
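As a small illustration, the gradients of the two functions used later in the code appendix can be computed symbolically. This sketch uses `sympy.diff`, which the appendix code also relies on; the variable names `f1`, `f2`, `grad_f1` and `grad_f2` are chosen here for illustration.

```python
import sympy

x, y = sympy.symbols("x y")

# One variable: the gradient reduces to the ordinary derivative
f1 = x**2 + 2*x + 5
grad_f1 = sympy.diff(f1, x)

# Two variables: the gradient collects both partial derivatives
f2 = sympy.sqrt(x**2 + y**2)
grad_f2 = [sympy.diff(f2, v) for v in (x, y)]

print(grad_f1)  # 2*x + 2
print(grad_f2)  # each component is the variable divided by sqrt(x**2 + y**2)
```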
The Idea Behind Gradient Descent
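Because the gradient points in the direction in which a function increases fastest, we can search for the function's minimum by moving in the direction opposite to the gradient. Taking a function $f(x)$ as an example, the main steps of gradient descent are:
- Choose an initial point $x_0$ for the variable, and approach the function's minimum step by step from there.
- Compute the gradient of the function, negate it, and substitute the current point to obtain the direction in which the variable should change: $-\nabla f(x_k)$. Multiply this direction vector by the learning rate $\alpha$ (the step size of the variable's change) to obtain a new vector, then add this new vector to the current point to reach the next point: $x_{k+1} = x_k - \alpha \nabla f(x_k)$.
- Check whether the change in the function value meets the precision requirement. Define a precision $\varepsilon$ that we consider acceptable; subtract the function value at the current point from the value at the previous point to obtain the change $p = f(x_k) - f(x_{k+1})$ (p can loosely be regarded as the loss function); then check whether $p \le \varepsilon$ holds. If it does not, repeat the second and third steps.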
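The steps above can be sketched in a few lines of NumPy. This is a minimal sketch, not the appendix implementation: the function name `gradient_descent`, the `max_iter` safety cap and the bowl-shaped test function with its minimum at (1, -2) are all assumptions made for illustration.

```python
import numpy as np

def gradient_descent(f, grad, start, learning_rate=0.1, precision=1e-6, max_iter=10_000):
    """Minimize f by repeatedly stepping against its gradient.

    max_iter is a safety cap added for this sketch; it is not part of the steps above.
    """
    point = np.asarray(start, dtype=float)
    path = [point.copy()]
    for _ in range(max_iter):
        # Step 2: move along the negative gradient, scaled by the learning rate
        new_point = point - learning_rate * grad(point)
        # Step 3: change in function value between the previous and the current point
        p = f(point) - f(new_point)
        point = new_point
        path.append(point.copy())
        if p <= precision:
            break
    return point, np.array(path)

# Hypothetical test function: a bowl with its minimum at (1, -2)
f = lambda v: (v[0] - 1.0)**2 + (v[1] + 2.0)**2
grad = lambda v: np.array([2.0 * (v[0] - 1.0), 2.0 * (v[1] + 2.0)])

minimum, path = gradient_descent(f, grad, start=[5.0, 5.0])
print(minimum, len(path))
```

Note that the stopping test uses the signed decrease `p`, so a step that overshoots and increases the function value also ends the loop.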
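However, gradient descent is sensitive to the choice of the initial point: a poor choice can leave it stuck in a local minimum (a locally optimal solution) instead of the global one.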
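To make the dependence on the initial point concrete, here is a small sketch using a hypothetical double-well function `f(x) = x**4 - 3*x**2 + x` (not from the article), which has a global minimum near x = -1.30 and a shallower local minimum near x = 1.13. The fixed iteration count and learning rate are assumptions chosen so that both runs settle.

```python
def f(x):
    # Hypothetical double-well function with one local and one global minimum
    return x**4 - 3.0 * x**2 + x

def df(x):
    # Derivative of f
    return 4.0 * x**3 - 6.0 * x + 1.0

def descend(start, learning_rate=0.01, iterations=2000):
    # Plain gradient descent with a fixed number of steps
    x = start
    for _ in range(iterations):
        x = x - learning_rate * df(x)
    return x

x_left = descend(-2.0)   # starts in the basin of the global minimum
x_right = descend(2.0)   # starts in the basin of the local minimum
print(f"start=-2.0 -> x={x_left:.4f}, f(x)={f(x_left):.4f}")
print(f"start=+2.0 -> x={x_right:.4f}, f(x)={f(x_right):.4f}")
```

Both runs use the same update rule and differ only in where they start, yet the run started at 2.0 settles in the shallower well with the larger function value.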
A Simple Application of Gradient Descent
Finding the Minimum of a 2D Curve with Gradient Descent
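The figure below shows the result of using gradient descent to find the minimum of the curve $y=x^2+2x+5$. The red points in the left panel are the intermediate points visited during the search, and the right panel shows how the precision value (the change in the loss function value) evolves during the search. The code is in the appendix.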
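For this curve the iteration can also be written out by hand. The derivation below assumes the settings used in the appendix code (learning rate 0.2, starting point $x_0=-100$); the index $k$ is notation introduced here for the iteration count.

$$
\begin{aligned}
y &= x^2 + 2x + 5 = (x+1)^2 + 4, \qquad \frac{\mathrm{d}y}{\mathrm{d}x} = 2x + 2 \\
x_{k+1} &= x_k - 0.2\,(2x_k + 2) = 0.6\,x_k - 0.4 \\
x_{k+1} + 1 &= 0.6\,(x_k + 1) \;\Rightarrow\; x_k = -1 - 99 \cdot 0.6^{k} \\
y_k - y_{k+1} &= (x_k+1)^2 - (x_{k+1}+1)^2 = 0.64\,(x_k+1)^2
\end{aligned}
$$

So the distance to the minimizer $x=-1$ shrinks by a factor of 0.6 per step, the function value approaches the minimum $y=4$, and the change in function value decays by a factor of 0.36 per step.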
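Finding the Minimum of a 3D Surface with Gradient Descent
The figure below shows the result of using gradient descent to find the minimum of the surface $z=\sqrt{x^2+y^2}$. The red points in the figure are the intermediate points visited during the search. The code is in the appendix.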
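One property of this particular surface is worth spelling out: away from the origin the gradient of $z=\sqrt{x^2+y^2}$ always has norm 1, so with a fixed learning rate every step has the same length. The sketch below assumes the appendix settings (learning rate 0.2, start (100, -100), precision 1e-5); the helper names `cone` and `cone_grad` are chosen here for illustration.

```python
import numpy as np

def cone(p):
    # z = sqrt(x^2 + y^2)
    return np.sqrt(p[0]**2 + p[1]**2)

def cone_grad(p):
    # Gradient of the cone: a unit vector pointing away from the origin
    # (undefined exactly at the origin)
    return p / cone(p)

learning_rate = 0.2
precision = 1e-5
p = np.array([100.0, -100.0])
first_grad_norm = np.linalg.norm(cone_grad(p))

steps = 0
while True:
    new_p = p - learning_rate * cone_grad(p)
    drop = cone(p) - cone(new_p)  # decrease in z for this step
    p = new_p
    steps += 1
    if drop <= precision:
        break

print("gradient norm at start:", first_grad_norm)
print("steps:", steps, "final z:", cone(p))
```

Each step lowers z by about 0.2 until a step jumps across the origin; the decrease then turns negative and the loop stops, so the final point lies within one step length of the minimum rather than within the requested precision.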
Code Appendix
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import sympy


class gradientDescent(object):
    def init2D(self, vector: float, precision: float, startPoint: float):
        """
        vector: learning rate
        precision: precision
        startPoint: starting point
        """
        self.vector = vector
        self.precision = precision
        self.startPoint = startPoint
        # Start above the threshold so the loop runs at least once
        self.startPrecision = precision + 1

    def init3D(self, vector: float, precision: float, startVar1Point: float, startVar2Point: float):
        """
        vector: learning rate
        precision: precision
        startVar1Point: starting position of variable 1
        startVar2Point: starting position of variable 2
        """
        self.vector = vector
        self.precision = precision
        self.startVar1Point = startVar1Point
        self.startVar2Point = startVar2Point
        # Start above the threshold so the loop runs at least once
        self.startPrecision = precision + 1

    def singleVar2D(self, func: str, var: sympy.Symbol):
        # Turn the expression and its derivative into numeric functions
        expr = sympy.sympify(func)
        f = sympy.lambdify(var, expr, "numpy")
        grad = sympy.lambdify(var, sympy.diff(expr, var), "numpy")
        xpoint = []
        ypoint = []
        errors = []
        x = self.startPoint
        y = f(x)
        while self.startPrecision > self.precision:
            xpoint.append(x)
            ypoint.append(y)
            # Step against the gradient
            x = x - self.vector * grad(x)
            y1 = f(x)
            # Decrease in function value; a negative value (overshoot) also ends the loop
            self.startPrecision = y - y1
            errors.append(self.startPrecision)
            y = y1
        # Record the final point together with its own function value
        xpoint.append(x)
        ypoint.append(y)
        xlen = len(xpoint)
        return [xpoint, ypoint, errors, xlen]

    def doubleVar3D(self, func: str, var1: sympy.Symbol, var2: sympy.Symbol):
        # Turn the expression and both partial derivatives into numeric functions
        expr = sympy.sympify(func)
        f = sympy.lambdify((var1, var2), expr, "numpy")
        var1Grad = sympy.lambdify((var1, var2), sympy.diff(expr, var1), "numpy")
        var2Grad = sympy.lambdify((var1, var2), sympy.diff(expr, var2), "numpy")
        xpoint = []
        ypoint = []
        zpoint = []
        errors = []
        x = self.startVar1Point
        y = self.startVar2Point
        z = f(x, y)
        while self.startPrecision > self.precision:
            xpoint.append(x)
            ypoint.append(y)
            zpoint.append(z)
            # Evaluate both partial derivatives at the current point before updating
            x1 = x - self.vector * var1Grad(x, y)
            y1 = y - self.vector * var2Grad(x, y)
            x = x1
            y = y1
            z1 = f(x, y)
            # Decrease in function value; a negative value (overshoot) also ends the loop
            self.startPrecision = z - z1
            errors.append(self.startPrecision)
            z = z1
        # Record the final point together with its own function value
        xpoint.append(x)
        ypoint.append(y)
        zpoint.append(z)
        xlen = len(xpoint)
        return [xpoint, ypoint, zpoint, errors, xlen]


if __name__ == '__main__':
    xData = np.arange(-100, 100, 0.1)
    yData = xData**2 + 2*xData + 5
    vector = 0.2
    precision = 10e-6  # equal to 1e-5
    startPoint = -100
    x = sympy.symbols("x")
    func = "x**2+2*x+5"
    gradient_descent = gradientDescent()
    gradient_descent.init2D(vector, precision, startPoint)
    [xpoint, ypoint, errors, xlen] = gradient_descent.singleVar2D(func, x)
    fig, ax = plt.subplots(figsize=(12, 8), ncols=2, nrows=1)
    # Animate the intermediate points one by one on the left panel
    for i in range(xlen):
        ax[0].cla()
        ax[0].plot(xData, yData, color="green", label="$y=x^2+2x+5$")
        ax[0].scatter(xpoint[i], ypoint[i], color="red", label="process point")
        ax[0].legend(loc="best")
        plt.pause(0.1)
    ax[1].plot(errors, label="Loss curve")
    ax[1].legend(loc="best")
    plt.pause(0.1)
    plt.show()
    # =======================================================================
    xData = np.arange(-100, 100, 0.1)
    yData = np.arange(-100, 100, 0.1)
    X, Y = np.meshgrid(xData, yData)
    # z = sqrt(x^2+y^2)
    Z = np.sqrt(X**2 + Y**2)
    x = sympy.symbols("x")
    y = sympy.symbols("y")
    func = "sqrt(x**2+y**2)"
    vector = 0.2
    precision = 10e-6  # equal to 1e-5
    startVar1Point = 100
    startVar2Point = -100
    gradient_descent = gradientDescent()
    gradient_descent.init3D(vector, precision, startVar1Point, startVar2Point)
    [xpoint, ypoint, zpoint, errors, xlen] = gradient_descent.doubleVar3D(func, x, y)
    fig = plt.figure()
    # Create the 3D axes through the figure so they are actually attached to it
    ax = fig.add_subplot(projection="3d")
    ax.plot_surface(X, Y, Z)
    points = ax.scatter(xpoint, ypoint, zpoint, color="red", label="process point")
    # Work around the surface label raising an error or not showing in the legend:
    # use a proxy patch in the surface's default color instead of private attributes
    surface_proxy = Patch(facecolor="C0", label=r"$z=\sqrt{x^2+y^2}$")
    ax.legend(handles=[surface_proxy, points])
    plt.show()