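Linear regression is a statistical analysis method that uses regression analysis from mathematical statistics to determine the quantitative relationship of mutual dependence between two or more variables. It assumes that the features satisfy a linear relationship, trains a model on the given training data, and uses that model to make predictions.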
Suppose we have a table of "years of service vs. salary" data, and assume it satisfies the linear relationship y = a + bx, where x is years of service and y is salary.
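| Years of service (x) | Salary (y) |
| --- | --- |
| 0 | 103100 |
| 1 | 104900 |
| 2 | 106800 |
| 3 | 108700 |
| 4 | 110400 |
| 5 | 112300 |
| 6 | 114200 |
| 7 | 116100 |
| 8 | 117800 |
| 9 | 119700 |
| 10 | 121600 |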
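Because years of service increase by exactly 1 from row to row, the year-over-year salary increments give a quick sanity check of the linearity assumption: under y = a + bx they should all be close to b. A minimal NumPy sketch of that check (an addition, not part of the original post):

```python
import numpy as np

# Salary values from the table; x runs 0..10 in steps of 1
y = np.array([103100, 104900, 106800, 108700, 110400, 112300,
              114200, 116100, 117800, 119700, 121600])

# Successive differences y[i+1] - y[i]; with a step of 1 in x these are
# the local slopes, which a linear model predicts to be constant
diffs = np.diff(y)
print(diffs)         # [1800 1900 1900 1700 1900 1900 1900 1700 1900 1900]
print(diffs.mean())  # 1850.0
```

The increments vary between 1700 and 1900 around a mean of 1850, so a straight line is a reasonable, though not exact, description of the data.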
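Define the loss function J(a, b) as the sum of squared errors, take its partial derivatives with respect to a and b, and use them to obtain the gradient descent update rules.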
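A derivation consistent with the `calc_diff_a`, `calc_diff_b` and `cost` functions in the example code is sketched below. It is reconstructed from that code, so it uses the same unscaled sum of squared errors over the 11 data points (no 1/n or 1/2 factor). Here x_i = i is the years of service, y_i the salary, and α the learning rate, called `ratio` in the code. Each partial derivative follows from the chain rule applied to one squared term.

```latex
J(a, b) = \sum_{i=0}^{10} \left(a + b x_i - y_i\right)^2

\frac{\partial J}{\partial a} = \sum_{i=0}^{10} 2\left(a + b x_i - y_i\right)

\frac{\partial J}{\partial b} = \sum_{i=0}^{10} 2 x_i \left(a + b x_i - y_i\right)

a \leftarrow a - \alpha \frac{\partial J}{\partial a}, \qquad
b \leftarrow b - \alpha \frac{\partial J}{\partial b}
```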
Example code:
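```python
import matplotlib.pyplot as plt
import numpy as np

# Salary data; the index x (0..10) doubles as years of service
y = (103100, 104900, 106800, 108700, 110400, 112300, 114200, 116100, 117800, 119700, 121600)


def calc_diff_a(a, b):
    """Partial derivative of the loss J(a, b) with respect to a."""
    total = 0
    for x in range(len(y)):
        total = total + 2 * a + 2 * b * x - 2 * y[x]
    return total


def calc_diff_b(a, b):
    """Partial derivative of the loss J(a, b) with respect to b."""
    total = 0
    for x in range(len(y)):
        total = total + x * (2 * a + 2 * b * x - 2 * y[x])
    return total


def cost(a, b):
    """Loss J(a, b): sum of (a + b*x - y)^2, written out in expanded form."""
    total = 0
    for x in range(len(y)):
        total = total + (a*a + b*b*x*x + 2*a*b*x - 2*a*y[x] - 2*b*x*y[x] + y[x]*y[x])
    return total


if __name__ == "__main__":
    num1 = 100000   # initial guess for a (intercept)
    num2 = 1        # initial guess for b (slope)
    ratio = 0.0001  # learning rate
    itercnt = 0
    while itercnt < 50000:
        # Evaluate both gradients at the current point before updating either parameter
        tmp1 = calc_diff_a(num1, num2)
        tmp2 = calc_diff_b(num1, num2)
        num1 = num1 - ratio * tmp1
        num2 = num2 - ratio * tmp2
        itercnt = itercnt + 1
        # print(tmp1, tmp2, cost(num1, num2))  # uncomment to monitor convergence
    print("a =", num1)
    print("b =", num2)

    # Plot the data points and the fitted line
    listx = np.linspace(0, 10, 11)
    listy = num1 + num2 * listx
    plt.figure()
    plt.plot(listx, y, '*')
    plt.plot(listx, listy)
    plt.show()
```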
Running the script prints:
a = 103086.36363635205
b = 1848.181818183475
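As a cross-check (an addition, not part of the original post), ordinary least squares with a single feature has a closed-form solution, so the gradient descent result can be verified directly. The sketch below uses the standard formulas b = Σ(x - x̄)(y - ȳ) / Σ(x - x̄)² and a = ȳ - b·x̄, then uses the fitted line for a prediction; the choice of x = 11 is purely illustrative.

```python
import numpy as np

# Same data as the table above
x = np.arange(11, dtype=float)
y = np.array([103100, 104900, 106800, 108700, 110400, 112300,
              114200, 116100, 117800, 119700, 121600], dtype=float)

# Closed-form least squares for y = a + b*x
x_mean, y_mean = x.mean(), y.mean()
b = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
a = y_mean - b * x_mean
print("a =", a)
print("b =", b)

# Use the fitted model to predict salary at 11 years of service
# (an extrapolation one step beyond the observed range)
prediction = a + b * 11
print("predicted salary for x = 11:", prediction)
```

For this data the sums work out to Σ(x - x̄)(y - ȳ) = 203300 and Σ(x - x̄)² = 110, so b = 203300 / 110 ≈ 1848.18 and a ≈ 103086.36, agreeing with the gradient descent output above to well within 1e-6. `np.polyfit(x, y, 1)` should return the same pair, in the order [b, a].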