Reinventing the Wheel in Machine Learning: Linear Regression
Please credit the source when reposting: http://www.reibang.com/p/8d5021339830
I recently studied linear regression and brushed up on some calculus and linear algebra, and wondered whether I could put the theory to use and implement it by hand. So I gave it a try.
Linear regression is a fairly basic algorithm and is the foundation for logistic regression later on. The idea is simply to fit the samples with a straight line. On its own it is mostly of pedagogical value.
A word on notation. The parameters of a linear regression are the weights and the intercept; here W denotes the weights (the slope) and b the intercept. The uppercase W indicates a vector: with n_feature_num features, W has shape (n_feature_num, 1), while the intercept b is a scalar. The target is computed as Y = X·W + b. By convention the ground-truth values are written Y and the predictions Y_hat. The concrete implementation steps are:
- construct the data
- construct the loss function (cost function)
- compute the gradients with respect to W and b (i.e. differentiate the cost function with respect to W and b; the formulas are written out right after this list)
- compute Y_hat
- iterate the gradient updates until convergence or until the iteration budget is exhausted
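To make the cost and gradient steps concrete: with m samples, X of shape (m, n_feature_num), and predictions Y_hat = X·W + b, the mean-squared-error cost, its gradients, and the update rule (learning rate α) that the code below implements are:

$$ J(W, b) = \frac{1}{2m}\sum_{i=1}^{m}\left(\hat{y}^{(i)} - y^{(i)}\right)^2 $$

$$ \frac{\partial J}{\partial W} = \frac{1}{m} X^{\top}(\hat{Y} - Y), \qquad \frac{\partial J}{\partial b} = \frac{1}{m}\sum_{i=1}^{m}\left(\hat{y}^{(i)} - y^{(i)}\right) $$

$$ W \leftarrow W - \alpha \frac{\partial J}{\partial W}, \qquad b \leftarrow b - \alpha \frac{\partial J}{\partial b} $$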
Below is the full Python implementation. The code is generic: the loss and the gradients are computed in vectorized form, so the dimension of W can be increased without changing the code (a small demonstration of this appears at the end of the post).
import matplotlib.pyplot as plt
import numpy as np


def f(X):
    """Generate ground-truth targets with the preset parameters w = [1, 3, 2], b = 10."""
    w = np.array([1, 3, 2])
    b = 10
    return np.dot(X, w) + b


def cost(X, Y, W, b):
    """Mean-squared-error cost J = sum((Y_hat - Y)^2) / (2m)."""
    m = X.shape[0]
    Y_hat = (np.dot(X, W) + b).reshape(m, 1)
    return np.sum(np.square(Y_hat - Y)) / (2 * m)


def gradient_descent(X, Y, W, b, learning_rate):
    """One gradient-descent step; W and b are updated from the same residual."""
    m = X.shape[0]
    residual = np.dot(X, W) + b - Y  # Y_hat - Y, shape (m, 1)
    W = W - learning_rate * (1 / m) * X.T.dot(residual)
    b = b - learning_rate * (1 / m) * np.sum(residual)
    return W, b


def main():
    # sample number
    m = 5
    # feature number
    n = 3
    # construct data
    X = np.random.rand(m, n)
    Y = f(X).reshape(m, 1)
    # Alternatively, use a real data set such as iris (requires: from sklearn import datasets):
    # iris = datasets.load_iris()
    # X, Y = iris.data, iris.target.reshape(150, 1)
    # X = X[Y[:, 0] < 2]
    # Y = Y[Y[:, 0] < 2]
    # m = X.shape[0]
    # n = X.shape[1]
    # initialize parameters
    W = np.ones((n, 1), dtype=float)
    b = 0.0
    # training hyperparameters
    learning_rate = 0.1
    iter_num = 10000
    J = []
    for i in range(1, iter_num + 1):
        W, b = gradient_descent(X, Y, W, b, learning_rate)
        j = cost(X, Y, W, b)
        J.append(j)
        print('step:', i, 'loss:', j)
        print(W, b)
    # plot the loss curve
    plt.plot(J)
    plt.show()


if __name__ == '__main__':
    main()
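As a quick sanity check on what gradient descent should converge to, the same least-squares problem can also be solved in closed form. Here is a minimal, self-contained sketch using np.linalg.lstsq, with the data rebuilt the same way main() builds it (the variable names are only illustrative):

import numpy as np

# Rebuild a data set the same way main() does: 5 samples, 3 features, noiseless targets.
m, n = 5, 3
X = np.random.rand(m, n)
Y = (np.dot(X, np.array([1, 3, 2])) + 10).reshape(m, 1)

# Append a column of ones so the intercept is learned as the last coefficient.
X_aug = np.hstack([X, np.ones((m, 1))])
theta, residuals, rank, sv = np.linalg.lstsq(X_aug, Y, rcond=None)
print(theta[:3].ravel(), theta[3, 0])  # should recover [1, 3, 2] and 10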
Looking at the training log from the gradient-descent run itself, the learned parameters are very close to the preset values [1, 3, 2] and 10. Pretty easy, right?
step: 4998 loss: 3.46349593719e-07
[[ 1.00286704]
[ 3.00463459]
[ 2.00173473]] 9.99528287088
step: 4999 loss: 3.45443124835e-07
[[ 1.00286329]
[ 3.00462853]
[ 2.00173246]] 9.99528904819
step: 5000 loss: 3.44539028368e-07
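And because cost and gradient_descent are fully vectorized, they work unchanged when the number of features grows. A minimal sketch, reusing the gradient_descent function defined above with a hypothetical 5-feature data set and arbitrarily chosen true parameters:

import numpy as np

# Hypothetical data set: 100 samples, 5 features, arbitrary true parameters.
m, n = 100, 5
true_W = np.array([0.5, -1.0, 2.0, 3.0, -0.5]).reshape(n, 1)
true_b = 4.0
X = np.random.rand(m, n)
Y = np.dot(X, true_W) + true_b

# Same training loop as before; gradient_descent is the function from the listing above.
W, b = np.ones((n, 1)), 0.0
for _ in range(10000):
    W, b = gradient_descent(X, Y, W, b, learning_rate=0.1)
print(W.ravel(), b)  # should approach true_W and true_b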