梯度下降訓(xùn)練線性回歸(最優(yōu)化2)

實(shí)驗(yàn)?zāi)康?/h1>

梯度下降法是一個(gè)最優(yōu)化算法亡鼠,通常也稱為最速下降法飞蛹。最速下降法是求解無(wú)約束優(yōu)化問(wèn)題最簡(jiǎn)單和最古老的方法之一汛骂,雖然現(xiàn)已不具有實(shí)用性榆苞,但是許多有效算法都是以它為基礎(chǔ)進(jìn)行改進(jìn)和修正而得到的猛拴。最速下降法是用負(fù)梯度方向?yàn)樗阉鞣较虻模钏傧陆捣ㄔ浇咏繕?biāo)值蚀狰,步長(zhǎng)越小愉昆,前進(jìn)越慢。

線性回歸是利用數(shù)理統(tǒng)計(jì)中回歸分析麻蹋,來(lái)確定兩種或兩種以上變量間相互依賴的定量關(guān)系的一種統(tǒng)計(jì)分析方法跛溉,運(yùn)用十分廣泛。其表達(dá)形式為y = w'*x+e扮授,e為誤差服從均值為0的正態(tài)分布芳室。[1]

回歸分析中,只包括一個(gè)自變量和一個(gè)因變量刹勃,且二者的關(guān)系可用一條直線近似表示堪侯,這種回歸分析稱為一元線性回歸分析。如果回歸分析中包括兩個(gè)或兩個(gè)以上的自變量荔仁,且因變量和自變量之間是線性關(guān)系伍宦,則稱為多元線性回歸分析。

本次實(shí)驗(yàn)將驗(yàn)證梯度下降法可以求解線性回歸的參數(shù)乏梁。

實(shí)驗(yàn)環(huán)境

本次實(shí)驗(yàn)次洼,使用的程序語(yǔ)言為Python3.5,依賴的工具包為:

import matplotlib.pyplot as plt

import csv

import numpy as np

import pandas as pd

from mpl_toolkits.mplot3d import Axes3D

實(shí)驗(yàn)內(nèi)容及步驟

ü 畫出樣本分布圖遇骑;

ü 畫出線性回歸假設(shè)模型;

ü 畫出成本函數(shù)收斂曲線

ü 使用梯度下降法更新參數(shù)

在詳細(xì)了解梯度下降的算法之前卖毁,我們先看看相關(guān)的一些概念。

步長(zhǎng)(Learning rate):步長(zhǎng)決定了在梯度下降迭代的過(guò)程中落萎,每一步沿梯度負(fù)方向前進(jìn)的長(zhǎng)度亥啦。用上面下山的例子,步長(zhǎng)就是在當(dāng)前這一步所在位置沿著最陡峭最易下山的位置走的那一步的長(zhǎng)度练链。

特征(feature):指的是樣本中輸入部分翔脱,比如樣本(x0,y0),(x1,y1),則樣本特征為x,樣本輸出為y兑宇。

假設(shè)函數(shù)(hypothesis function):在監(jiān)督學(xué)習(xí)中碍侦,為了擬合輸入樣本粱坤,而使用的假設(shè)函數(shù)隶糕,記為hθ(x)。

損失函數(shù)(loss function):為了評(píng)估模型擬合的好壞站玄,通常用損失函數(shù)來(lái)度量擬合的程度枚驻。損失函數(shù)極小化,意味著擬合程度最好株旷,對(duì)應(yīng)的模型參數(shù)即為最優(yōu)參數(shù)再登。在線性回歸中尔邓,損失函數(shù)通常為樣本輸出和假設(shè)函數(shù)的差取平方。
預(yù)測(cè)值:
h_θ (x)=x^T θ
誤差:
J(θ)=1/2m ∑i=1~m〖[(h]θ (x(i) )-y(i) 〗 )^2
優(yōu)化目標(biāo):
min┬θ?〖J
θ 〗
參數(shù)更新:
θ=θ-α 1/m 〖((h-y)x)〗^T

實(shí)驗(yàn)結(jié)果及分析

對(duì)數(shù)據(jù)進(jìn)行了歸一化處理锉矢。

對(duì)1維數(shù)據(jù)梯嗽,迭代了300次0.02步長(zhǎng),500次0.01步長(zhǎng)沽损。最終結(jié)果如圖:


對(duì)2維數(shù)據(jù)灯节,迭代了700次0.01步長(zhǎng)。最終結(jié)果如圖:



如何調(diào)優(yōu)绵估?步長(zhǎng)炎疆,初始參數(shù),歸一化国裳。
方法變種:批量梯度下降形入,隨機(jī)梯度下降,小批量缝左。在每次選取樣本的數(shù)量上有區(qū)別亿遂。
橫向比較:梯度下降法和最小二乘法相比,梯度下降法需要選擇步長(zhǎng)盒使,而最小二乘法不需要崩掘。梯度下降法是迭代求解,最小二乘法是計(jì)算解析解少办。如果樣本量不算很大苞慢,且存在解析解,最小二乘法比起梯度下降法要有優(yōu)勢(shì)英妓,計(jì)算速度很快挽放。但是如果樣本量很大,用最小二乘法由于需要求一個(gè)超級(jí)大的逆矩陣蔓纠,這時(shí)就很難或者很慢才能求解解析解了辑畦,使用迭代的梯度下降法比較有優(yōu)勢(shì)。梯度下降法和牛頓法/擬牛頓法相比腿倚,兩者都是迭代求解纯出,不過(guò)梯度下降法是梯度求解,而牛頓法/擬牛頓法是用二階的海森矩陣的逆矩陣或偽逆矩陣求解敷燎。相對(duì)而言暂筝,使用牛頓法/擬牛頓法收斂更快。但是每次迭代的時(shí)間比梯度下降法長(zhǎng)硬贯。

程序

1.ipynb

%matplotlib inline
import matplotlib.pyplot as plt
import csv
import numpy as np
import pandas as pd

因?yàn)橐婚_始收斂效果不好焕襟,然后做了數(shù)據(jù)整理,方案是全部映射到10以內(nèi)一個(gè)數(shù)量級(jí)饭豹,歸一化鸵赖。

data=pd.read_csv('./1D.csv',header=None)
data=np.array(data)
div=np.array([max(data[:,0]),max(data[:,1])])
data=data/div*10
print(data.shape)
(97, 2)
num=10
val_x=np.ones((num,data.shape[1]))
sam_x=np.ones((data.shape[0]-num,data.shape[1]))

sam_x[:,1:]=data[:-num,:-1]
sam_y=data[:-num,-1:]
print("sx.shape",sam_x.shape,"sy.shape",sam_y.shape)
val_x[:,1:]=data[-num:,:-1]
val_y=data[-num:,-1:]
print("vx.shape",val_x.shape,"vy.shape",val_y.shape)
sx.shape (87, 2) sy.shape (87, 1)
vx.shape (10, 2) vy.shape (10, 1)
plt.scatter(data[:,0],data[:,-1],marker='x')
plt.grid()
# 預(yù)測(cè)值
def predict(theta,x):
    '''
    theta:(2,1)
    x:(n,2)
    '''
    return x.dot(theta).reshape((-1,1))
np.random.seed(1502520028)
# print(data.shape)
theta=np.random.normal(size=(2,1))
print(theta)

epoch=300
alpha=0.02
eps=0.1
# 定義成常量計(jì)算快
def const_error(h,y):
    return h-y
# 均方誤差
# h:pred_y
def cost(h,y,con):
    return (np.mean(con**2))/2
def grad(x,con):
    return np.mean(con*x,axis=0,keepdims=True).transpose()
[[-0.40838329]
 [ 1.31489742]]
sc=[]#訓(xùn)練集誤差
vc=[]#交叉集預(yù)測(cè)誤差
for i in range(epoch):
    h=predict(theta,sam_x)
    con=const_error(h,sam_y)
    g=grad(sam_x,con)
    theta=theta-alpha*g
    if i%(epoch//50)==0:
        sc.append(cost(h,sam_y,con))
        pre_y=predict(theta,val_x)
        vc.append(cost(pre_y,val_y,const_error(pre_y,val_y)))
plt.plot(np.arange(len(sc)),np.array(sc),label="training_cost")
plt.legend()
plt.plot(np.arange(len(vc)),np.array(vc),label="validating_cost")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("cost")
<matplotlib.text.Text at 0x85a3588>
plt.scatter(data[:,0],data[:,-1],marker='x')
plt.grid()
temp_x=np.linspace(-1,11,100)
plt.plot(temp_x,(theta[1]*temp_x+theta[0]))
plt.xlabel("x");plt.ylabel("y")
print("y=%-.2f%+.2f*x1" % (theta[0],theta[1]))
y=-1.20+1.02*x1
output_11_1.png

在上一階段务漩,兩個(gè)損失的收斂速度已經(jīng)變慢,那么我們又增加500次訓(xùn)練它褪。
為了避免在極值點(diǎn)處震蕩饵骨,同時(shí)減小步長(zhǎng)到0.01

epoch=500
alpha=0.01
print("我真的還想再活500年")
for i in range(epoch):
    h=predict(theta,sam_x)
    con=const_error(h,sam_y)
    g=grad(sam_x,con)
    theta=theta-alpha*g
    if i%(epoch//50)==0:
        sc.append(cost(h,sam_y,con))
        pre_y=predict(theta,val_x)
        vc.append(cost(pre_y,val_y,const_error(pre_y,val_y)))
plt.plot(np.arange(len(sc)),np.array(sc),label="training_cost")
plt.legend()
plt.plot(np.arange(len(vc)),np.array(vc),label="validating_cost")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("cost")


我真的還想再活500年





<matplotlib.text.Text at 0x882bac8>
plt.scatter(data[:,0],data[:,-1],marker='x')
plt.grid()
temp_x=np.linspace(-1,11,100)
plt.plot(temp_x,(theta[1]*temp_x+theta[0]))
plt.xlabel("x");plt.ylabel("y")
print("y=%-.2f%+.2f*x1" % (theta[0],theta[1]))
y=-1.41+1.07*x1

2.ipynb

%matplotlib inline
import matplotlib.pyplot as plt
import csv
import numpy as np
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D

數(shù)據(jù)整理方案是全部映射到10以內(nèi)一個(gè)數(shù)量級(jí),歸一化茫打。

data=pd.read_csv('./2D.csv',header=None)
data=np.array(data)
div=np.array([max(data[:,0]),max(data[:,1]),max(data[:,2])])
data=data/div*10
print(data.shape)
(47, 3)
num=8
val_x=np.ones((num,data.shape[1]))
sam_x=np.ones((data.shape[0]-num,data.shape[1]))
sam_x[:,1:]=data[:-num,:-1]
sam_y=data[:-num,-1:]
print("sx.shape",sam_x.shape,"sy.shape",sam_y.shape)
val_x[:,1:]=data[-num:,:-1]
val_y=data[-num:,-1:]
print("vx.shape",val_x.shape,"vy.shape",val_y.shape)
sx.shape (39, 3) sy.shape (39, 1)
vx.shape (8, 3) vy.shape (8, 1)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(data[:,0],data[:,1],data[:,-1],c='r',marker='^',depthshade=True)
<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x82e5cf8>
# 預(yù)測(cè)值
def predict(theta,x):
    '''
    theta:(3,1)
    x:(n,3)
    '''
    return x.dot(theta).reshape((-1,1))
np.random.seed(1502520028)
# print(data.shape)
theta=np.random.normal(size=(3,1))
print(theta)

epoch=700
alpha=0.01
eps=0.1
# 定義成常量計(jì)算快
def const_error(h,y):
    return h-y
# 均方誤差
# h:pred_y
def cost(h,y,con):
    return (np.mean(con**2))/2
def grad(x,con):
    return np.mean(con*x,axis=0,keepdims=True).transpose()
[[-0.40838329]
 [ 1.31489742]
 [-1.82310129]]
sc=[]#訓(xùn)練集誤差
vc=[]#交叉集預(yù)測(cè)誤差
for i in range(epoch):
    h=predict(theta,sam_x)
    con=const_error(h,sam_y)
    g=grad(sam_x,con)
    theta=theta-alpha*g
    if i%(epoch//100)==0:
        sc.append(cost(h,sam_y,con))
        pre_y=predict(theta,val_x)
        vc.append(cost(pre_y,val_y,const_error(pre_y,val_y)))
plt.plot(np.arange(len(sc)),np.array(sc),label="training_cost")
plt.legend()
plt.plot(np.arange(len(vc)),np.array(vc),label="validating_cost")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("cost")
<matplotlib.text.Text at 0x8583dd8>
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(data[:,0],data[:,1],data[:,-1],c='r',marker='^',depthshade=True)
temp_x=np.linspace(-1,12,100)
ax.plot(temp_x,temp_x,(theta[2]*temp_x+theta[1]*temp_x+theta[0]))
print("y=%-.2f%+.2f*x1%+.2f*x2" % (theta[0],theta[1],theta[2]))
y=0.17+0.90*x1+0.11*x2

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末宏悦,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子包吝,更是在濱河造成了極大的恐慌饼煞,老刑警劉巖,帶你破解...
    沈念sama閱讀 219,110評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件诗越,死亡現(xiàn)場(chǎng)離奇詭異砖瞧,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)嚷狞,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,443評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門块促,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人床未,你說(shuō)我怎么就攤上這事竭翠。” “怎么了薇搁?”我有些...
    開封第一講書人閱讀 165,474評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵斋扰,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我啃洋,道長(zhǎng)传货,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,881評(píng)論 1 295
  • 正文 為了忘掉前任宏娄,我火速辦了婚禮问裕,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘孵坚。我一直安慰自己粮宛,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,902評(píng)論 6 392
  • 文/花漫 我一把揭開白布卖宠。 她就那樣靜靜地躺著巍杈,像睡著了一般。 火紅的嫁衣襯著肌膚如雪逗堵。 梳的紋絲不亂的頭發(fā)上眷昆,一...
    開封第一講書人閱讀 51,698評(píng)論 1 305
  • 那天蜒秤,我揣著相機(jī)與錄音汁咏,去河邊找鬼。 笑死纸泡,一個(gè)胖子當(dāng)著我的面吹牛蚤假,可吹牛的內(nèi)容都是我干的磷仰。 我是一名探鬼主播,決...
    沈念sama閱讀 40,418評(píng)論 3 419
  • 文/蒼蘭香墨 我猛地睜開眼箍土,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼逢享!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起吴藻,我...
    開封第一講書人閱讀 39,332評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤瞒爬,失蹤者是張志新(化名)和其女友劉穎,沒(méi)想到半個(gè)月后沟堡,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體疮鲫,經(jīng)...
    沈念sama閱讀 45,796評(píng)論 1 316
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,968評(píng)論 3 337
  • 正文 我和宋清朗相戀三年弦叶,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了俊犯。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 40,110評(píng)論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡伤哺,死狀恐怖燕侠,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情立莉,我是刑警寧澤绢彤,帶...
    沈念sama閱讀 35,792評(píng)論 5 346
  • 正文 年R本政府宣布,位于F島的核電站蜓耻,受9級(jí)特大地震影響茫舶,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜刹淌,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,455評(píng)論 3 331
  • 文/蒙蒙 一饶氏、第九天 我趴在偏房一處隱蔽的房頂上張望讥耗。 院中可真熱鬧,春花似錦疹启、人聲如沸古程。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,003評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)挣磨。三九已至,卻和暖如春荤懂,著一層夾襖步出監(jiān)牢的瞬間茁裙,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,130評(píng)論 1 272
  • 我被黑心中介騙來(lái)泰國(guó)打工节仿, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留呜达,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,348評(píng)論 3 373
  • 正文 我出身青樓粟耻,卻偏偏與公主長(zhǎng)得像查近,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子挤忙,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,047評(píng)論 2 355