Python機器學習基礎教程學習筆記(3)——KNN處理forge數(shù)據(jù)集(分類)
1 常規(guī)引入
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mglearn
# 不想看到warnings
import warnings
warnings.filterwarnings("ignore", category=Warning)
from collections import Counter
2 forge數(shù)據(jù)集
- 一個模擬的二分類數(shù)據(jù)集
# 生成數(shù)據(jù)集
X,y = mglearn.datasets.make_forge()
print("X.shape:{}".format(X.shape)) # 26個數(shù)據(jù)點宛徊,2個特征
print("y.shape:{}".format(y.shape)) # 26個目錄值
print("classes of y:\n{}".format(Counter(y)))# 兩個類別0和1花沉,各13個數(shù)據(jù)點
X.shape:(26, 2)
y.shape:(26,)
classes of y:
Counter({1: 13, 0: 13})
# 數(shù)據(jù)集繪圖
mglearn.discrete_scatter(X[:,0],X[:,1],y)
plt.legend(["Class 0","Class 1"],loc=4)
plt.xlabel("first feature")
plt.ylabel("second feature")
plt.show()
output_7_0
3 knn分類
- 適用于二分類和多分類
# knn只考慮1個最近鄰的示例,五角星是新增的3個數(shù)據(jù)點,找到他們最近的1個點的類幔妨,就是預測的類
mglearn.plots.plot_knn_classification(n_neighbors=1)
output_9_0
# knn考慮3個最近鄰的示例宠叼,五角星是新增的3個數(shù)據(jù)點,找到他們最近的3個點的類肉津,用“投票法”找到3個點的類出現(xiàn)次數(shù)更多的類別作儿,就是預測的類
mglearn.plots.plot_knn_classification(n_neighbors=3)
output_10_0
4 用knn算法處理forge數(shù)據(jù)集
# 拆分訓練集與測試集
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=0)
# 引入knn分類器,設置k為3(n_neighbors=3)
from sklearn.neighbors import KNeighborsClassifier
clf = KNeighborsClassifier(n_neighbors=3)
# 進行訓練馋劈,對于knn來說攻锰,就是保存訓練集數(shù)據(jù)晾嘶,以便在測試時計算與鄰居之間的距離
clf.fit(X_train,y_train)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=None, n_neighbors=3, p=2,
weights='uniform')
# 調用predict方法來進行預測
print("Test set predictions:{}".format(clf.predict(X_test)))
Test set predictions:[1 0 1 0 1 0 0]
# 評估模型的泛化能力的好壞,調用score方法
print("Test set accuracy:{:.2f}".format(clf.score(X_test,y_test)))
Test set accuracy:0.86
5 分析knn分類器
查看決策邊界(decision bundary)
# 查看1個娶吞、3個垒迂、9個鄰居三種情況的決策邊界可視化
# plt.subplots()是一個函數(shù),返回一個包含figure和axes對象的元組妒蛇。
# 因此机断,使用fig,ax = plt.subplots()將元組分解為fig和ax兩個變量。
fig,axes = plt.subplots(
1, # nrows=1绣夺,行數(shù)
3, # ncols=3吏奸,列數(shù)
figsize=(10,3) # 設置圖像大小
)
for n_neighbors,ax in zip([1,3,9],axes):
clf = KNeighborsClassifier(n_neighbors=n_neighbors).fit(X,y)
mglearn.plots.plot_2d_separator(clf,X,fill=True,eps=0.5,ax=ax,alpha=.4)
mglearn.discrete_scatter(X[:,0],X[:,1],y,ax=ax)
ax.set_title("{} neighbor(s)".format(n_neighbors))
ax.set_xlabel('feature 0')
ax.set_ylabel('feature 1')
ax.legend(loc=3)
output_18_0
從圖中看出:
- k=1,決策邊界緊跟著訓練數(shù)據(jù)
- k越大陶耍,決策邊界越平滑奋蔚,對應的模型越簡單(模型復雜度越低),泛化能力越強
- 極端情況烈钞,k=訓練集中數(shù)據(jù)點的個數(shù)泊碑,每個測試點的鄰居都完全相同,所有訓練結果也完全相同