前幾天學(xué)了K近鄰的分類問(wèn)題杀赢,這幾天開(kāi)始學(xué)習(xí)K近鄰的回歸問(wèn)題。
先通過(guò)簡(jiǎn)單的三組圖對(duì)比楞件,K分別取1和敬,3盾饮,5時(shí)的預(yù)測(cè)結(jié)果
K=1
mglearn.plots.plot_knn_regression(n_neighbors=1)
輸出結(jié)果:
K=3
mglearn.plots.plot_knn_regression(n_neighbors=3)
輸出結(jié)果:
image.png
K=5
mglearn.plots.plot_knn_regression(n_neighbors=5)
輸出結(jié)果:
image.png
與分類問(wèn)題的K近鄰算法類似采桃,回歸問(wèn)題的K近鄰,也是通過(guò)選擇測(cè)試值附近的K個(gè)數(shù)據(jù)的平均值作為預(yù)測(cè)值丘损。
KNeighborsRegressor
接下來(lái)詳細(xì)看看scikit-learn中實(shí)現(xiàn)的回歸KNN算法吧普办,與之前看的KNeighborsClassifier類似的。
from sklearn.neighbors import KNeighborsRegressor
X,y= mglearn.datasets.make_wave(n_samples=40)
Xtrain,Xtest,ytrain,ytest=train_test_split(X,y,random_state=0)
reg=KNeighborsRegressor(n_neighbors=3)
reg.fit(Xtrain,ytrain)
print("test result is :\n{}".format(reg.predict(Xtest)))
print("test R^2 is {:.2f}".format(reg.score(Xtest,ytest)))
輸出結(jié)果:
test result is :
[-0.05396539 0.35686046 1.13671923 -1.89415682 -1.13881398 -1.63113382
0.35686046 0.91241374 -0.44680446 -1.13881398]
test R^2 is 0.83
- 在回歸問(wèn)題中号俐,使用R平方來(lái)度量準(zhǔn)確性泌豆,0.83則相對(duì)擬合效果較好定庵。