5.2 Scikit-Learn 簡介
譯者:飛龍
協(xié)議:CC BY-NC-SA 4.0
譯文沒有得到原作者授權(quán)贰盗,不保證與原文的意思嚴格一致塑荒。
有幾個 Python 庫提供一系列機器學(xué)習(xí)算法的實現(xiàn)嗤形。最著名的是 Scikit-Learn捺宗,一個提供大量常見算法的高效版本的軟件包鹦付。 Scikit-Learn 的特點是簡潔被因,統(tǒng)一茬缩,流線型的 API赤惊,以及非常實用和完整的在線文檔。這種一致性的好處是凰锡,一旦了解了 Scikit-Learn 中一種類型的模型的基本用法和語法未舟,切換到新的模型或算法就非常簡單圈暗。
本節(jié)提供了 Scikit-Learn API 的概述;對這些API元素的了解裕膀,會成為理解以下章節(jié)中機器學(xué)習(xí)算法和方法的更深入的實際討論的基礎(chǔ)员串。
我們將首先介紹 Scikit-Learn 中的數(shù)據(jù)表示形式,然后設(shè)計 Estimator API昼扛,最后通過一個更有趣的例子寸齐,使用這些工具來探索一組手寫數(shù)字圖像。
Scikit-Learn 中的數(shù)據(jù)表示
機器學(xué)習(xí)是從數(shù)據(jù)創(chuàng)建模型:因此抄谐,我們將首先討論如何表示數(shù)據(jù)渺鹦,以便計算機理解。 在 Scikit-Learn 中考慮數(shù)據(jù)的最佳方式就是數(shù)據(jù)表蛹含。
數(shù)據(jù)作為表
一個基本表格是二維數(shù)據(jù)網(wǎng)格毅厚,其中行表示數(shù)據(jù)集的各個元素,列表示與這些元素中的每一個相關(guān)的數(shù)量浦箱。 例如吸耿,考慮 Iris 數(shù)據(jù)集,由 Ronald Fisher 于 1936 年進行了著名的分析酷窥。我們可以使用 Seaborn 庫以 Pandas DataFrame 的形式下載此數(shù)據(jù)集:
import seaborn as sns
iris = sns.load_dataset('iris')
iris.head()
這里的每一行數(shù)據(jù)指代一個觀察到的花珍语,行數(shù)是數(shù)據(jù)集中花的總數(shù)。 一般來說竖幔,我們將把矩陣行作為樣本板乙,將行數(shù)稱為n_samples
。
類似地拳氢,數(shù)據(jù)的每一列都是描述每個樣本的特定的定量信息募逞。 一般來說,我們將把矩陣的列稱為特征馋评,列數(shù)稱為n_features
放接。
特征矩陣
該表的布局清楚地表明,信息可以當做二維數(shù)組或矩陣留特,我們稱之為特征矩陣纠脾。 按照慣例,這個特征矩陣通常被存儲在一個名為X
的變量中蜕青。特征矩陣被假設(shè)為二維的苟蹈,形狀為[n_samples,n_features]
右核,并且最常使用NumPy
數(shù)組或Pandas DataFrame
來存放慧脱,盡管有些 Scikit-Learn 模型也接受 SciPy 稀疏矩陣。
樣本(即行)總是指代由數(shù)據(jù)集描述的各個對象贺喝。 例如菱鸥,樣本可能是一朵花宗兼,一個人,一個文檔氮采,一個圖像殷绍,一個聲音文件,一個視頻鹊漠,一個天文物體篡帕,或者你可以用一組定量測量來描述的任何東西。
特征(即列)總是指以定量方式描述每個樣本的不同觀察結(jié)果贸呢。 特征通常是實值镰烧,但在某些情況下可能是布爾值或離散值。
目標數(shù)組
除了特征矩陣X
之外楞陷,我們還通常使用標簽或目標數(shù)組怔鳖,按照慣例,我們通常稱為y
固蛾。目標數(shù)組通常是一維结执,長度為n_samples
,通常包含在 NumPy 數(shù)組或 Pandas Series 中艾凯。目標數(shù)組可以具有連續(xù)的數(shù)值或離散分類/標簽献幔。雖然一些 Scikit-Learn 估計器確實以二維[n_samples,n_targets]
目標數(shù)組的形式處理多個目標值趾诗,但我們將主要處理一維目標數(shù)組的常見情況蜡感。
常常有一點令人困惑的是,目標數(shù)組與其他特征列的不同之處恃泪。目標數(shù)組的特征在于郑兴,它通常是我們要從數(shù)據(jù)中預(yù)測的數(shù)量:在統(tǒng)計學(xué)上,它是因變量贝乎。例如情连,在上述數(shù)據(jù)中,我們可能希望構(gòu)建一個模型览效,可以基于其他度量來預(yù)測花的種類却舀;在這種情況下,物種列將被視為目標數(shù)組锤灿。
考慮到這個目標數(shù)組挽拔,我們可以使用 Seaborn(參見可視化與 Seaborn)來方便地顯示數(shù)據(jù):
%matplotlib inline
import seaborn as sns; sns.set()
sns.pairplot(iris, hue='species', size=1.5);
對于 Scikit-Learn 中的使用,我們會從DataFrame
提取特征矩陣和目標數(shù)組衡招。我們可以用一些第三章中的 PandasDataFrame
操作實現(xiàn)篱昔。
X_iris = iris.drop('species', axis=1)
X_iris.shape
# (150, 4)
y_iris = iris['species']
y_iris.shape
# (150,)
總之,特征和目標值的預(yù)期布局始腾,可以在下圖中顯示:
將這個數(shù)據(jù)合理格式化之后州刽,我們可以轉(zhuǎn)而思考 Scikit-Learn 的估計器 API 了。
Scikit-Learn 的估計器 API
Scikit-Learn API 的設(shè)計思想浪箭,是 Scikit-Learn API 的說明書所述的以下指導(dǎo)原則:
一致性:所有對象共享一個通用接口穗椅,從一組有限方法抽取,具有一致的文檔奶栖。
檢查:所有指定的參數(shù)值都公開為公共屬性匹表。
有限對象層次:只有算法由 Python 類表示;數(shù)據(jù)集以標準格式(NumPy 數(shù)組宣鄙,Pandas DataFrames袍镀,SciPy 稀疏矩陣)表示,參數(shù)名稱使用標準 Python 字符串冻晤。
組成:許多機器學(xué)習(xí)任務(wù)可以表達為更基礎(chǔ)的算法的序列苇羡,而 Scikit-Learn 可以盡可能地利用這一點。
敏感默認值:當模型需要用戶指定的參數(shù)時鼻弧,庫定義了一個適當?shù)哪J值设江。
在實踐中,一旦理解了基本原理攘轩,這些原則使 Scikit-Learn 非常容易使用叉存。 Scikit-Learn 中的每個機器學(xué)習(xí)算法都通過 Estimator API 實現(xiàn),該 API 為廣泛的機器學(xué)習(xí)應(yīng)用提供了一致的接口度帮。
API 基礎(chǔ)
通常歼捏,使用 Scikit-Learn 估計器 API 的步驟如下(我們將在以下部分中詳細介紹一些詳細示例)。
- 通過從 Scikit-Learn 導(dǎo)入適當?shù)墓烙嬵惐颗瘢瑏磉x擇一類模型甫菠。
- 通過使用所需的值實例化此類,來選擇模型超參數(shù)冕屯。
- 在上述討論之后寂诱,將數(shù)據(jù)排列成特征矩陣和目標向量。
- 通過調(diào)用模型實例的
fit
方法安聘,使用模型來擬合數(shù)據(jù)痰洒。 - 將模型應(yīng)用于新數(shù)據(jù):
- 對于監(jiān)督學(xué)習(xí),我們通常使用
predict()
方法預(yù)測未知數(shù)據(jù)的標簽浴韭。 - 對于無監(jiān)督學(xué)習(xí)丘喻,我們經(jīng)常使用
transform()
或predict()
方法來轉(zhuǎn)換或推斷數(shù)據(jù)的屬性。
- 對于監(jiān)督學(xué)習(xí),我們通常使用
我們現(xiàn)在將逐步介紹幾個簡單示例念颈,應(yīng)用監(jiān)督和無監(jiān)督學(xué)習(xí)方法泉粉。
監(jiān)督學(xué)習(xí)示例:簡單線性回歸
作為這個過程的一個例子,讓我們考慮一個簡單的線性回歸,也就是說嗡靡,一種常見情況,使用直線來擬合(x,y)
數(shù)據(jù)讨彼。 我們將以下簡單數(shù)據(jù)用于回歸示例:
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.RandomState(42)
x = 10 * rng.rand(50)
y = 2 * x - 1 + rng.randn(50)
plt.scatter(x, y);
有了這些數(shù)據(jù)歉井,我們可以使用前面提到的秘籍。 我們來看一下這個過程:
1. 選擇一個模型類
在 Scikit-Learn 中哈误,每個模型類都由 Python 類表示哩至。 所以,例如蜜自,如果我們想要計算一個簡單的線性回歸模型菩貌,我們可以導(dǎo)入線性回歸類:
from sklearn.linear_model import LinearRegression
要注意也存在更通用的線性回歸模型,你可以在 sklearn.linear_model 模型文檔中了解更多重荠。
2. 選擇模型超參數(shù)
一個重點是箭阶,模型類與模型實例不一樣。
一旦我們決定了我們的模型類晚缩,我們?nèi)匀挥幸恍┻x擇尾膊。根據(jù)我們正在使用的模型類,我們可能需要回答以下一個或多個問題:
- 我們希望擬合偏移(即縱截距)嗎荞彼?
- 我們是否希望將模型歸一化冈敛?
- 我們是否希望預(yù)處理我們的特征,來增加模型的靈活性鸣皂?
- 我們想在我們的模型中使用什么程度的正則化抓谴?
- 我們想要使用多少個模型組件?
這些是重要選擇的示例寞缝,在選擇模型類后必須做出癌压。這些選擇通常表示為超參數(shù),或在模型擬合數(shù)據(jù)之前必須設(shè)置的參數(shù)荆陆。在 Scikit-Learn 中滩届,通過在模型實例化下傳遞值來選擇超參數(shù)。我們將在超參數(shù)和模型驗證中被啼,探討如何定量地改進超參數(shù)的選擇帜消。
對于我們的線性回歸示例,我們可以實例化LinearRegression
類浓体,并指定我們想使用fit_intercept
超參數(shù)擬合截距:
model = LinearRegression(fit_intercept=True)
model
# LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
請記住泡挺,當模型被實例化時,唯一的操作是存儲這些超參數(shù)值命浴。 特別是娄猫,我們還沒有將模型應(yīng)用于任何數(shù)據(jù):Scikit-Learn API 非常清楚模型選擇和模型對數(shù)據(jù)應(yīng)用之間的區(qū)別贱除。
3. 將數(shù)據(jù)排列為特征矩陣和目標向量
以前,我們詳細介紹了 Scikit-Learn 數(shù)據(jù)表示媳溺,它需要二維特征矩陣和一維目標數(shù)組月幌。 這里我們的目標變量y
已經(jīng)是正確的形式(長度為n_samples
的數(shù)組),但是我們需要調(diào)整數(shù)據(jù)x
褂删,使其成為大小為[n_samples飞醉,n_features]
的矩陣冲茸。 在這種情況下屯阀,這相當于一維數(shù)組的簡單重塑:
X = x[:, np.newaxis]
X.shape
# (50, 1)
4. 使用模型來擬合數(shù)據(jù)
現(xiàn)在是時候?qū)⒛P蛻?yīng)用于數(shù)據(jù)了。這可以使用fit
方法來完成轴术。
model.fit(X, y)
# LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
這個fit
命令會導(dǎo)致一些模型相關(guān)的內(nèi)部計算难衰,這些計算的結(jié)果存儲在特定于模型的屬性中,用戶可以探索逗栽。 在 Scikit-Learn 中盖袭,按照慣例,在fit
過程中學(xué)習(xí)的所有模型參數(shù)彼宠,都有尾隨的下劃線鳄虱;例如在這個線性模型中,我們有以下這些東西:
model.coef_
# array([ 1.9776566])
model.intercept_
# -0.90331072553111635
這兩個參數(shù)表示對數(shù)據(jù)的簡單線性擬合的斜率和截距凭峡。 與數(shù)據(jù)定義相比拙已,我們看到它們非常接近輸入斜率 2 和截距 -1。
經(jīng)常出現(xiàn)的一個問題是摧冀,這些內(nèi)部模型參數(shù)的不確定性倍踪。 一般來說,Scikit-Learn 不提供從內(nèi)部模型參數(shù)本身得出結(jié)論的工具:模型參數(shù)的解釋更多是統(tǒng)計建模問題索昂,而不是機器學(xué)習(xí)問題建车。 機器學(xué)習(xí)的重點是模型預(yù)測。 如果你希望深入了解模型中參數(shù)的含義椒惨,則可以使用其他工具缤至,包括 Python Statsmodels 包。
5. 預(yù)測未知數(shù)據(jù)的標簽
一旦模型訓(xùn)練完成康谆,監(jiān)督機器學(xué)習(xí)的主要任務(wù)是领斥,根據(jù)對不是訓(xùn)練集的一部分的新數(shù)據(jù)做出評估。 在 Scikit-Learn 中秉宿,可以使用predict
方法來完成戒突。 對于這個例子,我們的“新數(shù)據(jù)”將是一個x
值的網(wǎng)格描睦,我們將詢問模型預(yù)測的y
值:
xfit = np.linspace(-1, 11)
像之前一樣膊存,我們需要將這些x
值調(diào)整為[n_samples, n_features]
的特征矩陣,之后我們可以將它扔給模型了。
Xfit = xfit[:, np.newaxis]
yfit = model.predict(Xfit)
最后隔崎,讓我們通過首先繪制原始數(shù)據(jù)今艺,之后是這個模型,來展示結(jié)果爵卒。
plt.scatter(x, y)
plt.plot(xfit, yfit);
通常虚缎,通過將其結(jié)果與某些已知基準進行比較,來評估模型的功效钓株,如下例所示实牡。
監(jiān)督學(xué)習(xí)示例,鳶尾花分類
我們來看看這個過程的另一個例子轴合,使用我們前面討論過的 Iris 數(shù)據(jù)集创坞。 我們的問題是這樣的:給出一個模型,使用 Iris 數(shù)據(jù)的一部分進行培訓(xùn)受葛,我們?nèi)绾文軌蝾A(yù)測剩余的標簽题涨?
對于這個任務(wù),我們將使用一個非常簡單的生成模型总滩,稱為高斯樸素貝葉斯纲堵,它們通過假設(shè)每個類別服從軸對齊的高斯分布(更多細節(jié)參見樸素貝葉斯分類)。 因為高斯樸素貝葉斯如此之快闰渔,沒有超參數(shù)可供選擇席函。在探索是否可以通過更復(fù)雜的模型做出改進之前,它通常是一個用作基準分類的良好模型澜建。
我們想對之前沒有看到的數(shù)據(jù)進行評估向挖,因此我們將數(shù)據(jù)分成訓(xùn)練集和測試集。 這可以手工完成炕舵,但是使用train_test_split
工具更方便:
from sklearn.cross_validation import train_test_split
Xtrain, Xtest, ytrain, ytest = train_test_split(X_iris, y_iris,
random_state=1)
排列好數(shù)據(jù)之后何之,我們可以遵循秘籍來預(yù)測標簽了。
from sklearn.naive_bayes import GaussianNB # 1. choose model class
model = GaussianNB() # 2. instantiate model
model.fit(Xtrain, ytrain) # 3. fit model to data
y_model = model.predict(Xtest) # 4. predict on new data
最后咽筋,我們可以使用accuracy_score
工具來查看匹配真實值的預(yù)測標簽的百分比溶推。
from sklearn.metrics import accuracy_score
accuracy_score(ytest, y_model)
# 0.97368421052631582
準確率超過了 97%,我們可以看到奸攻,即使是這個非常樸素的分類算法蒜危,對于特定數(shù)據(jù)集也是高效的。
無監(jiān)督學(xué)習(xí)示例:Iris 降維
作為無監(jiān)督學(xué)習(xí)問題的一個例子睹耐,我們來看一下 Iris 數(shù)據(jù)的降維辐赞,以便更容易地將其視覺化。 回想一下硝训,Iris 數(shù)據(jù)是四維的:每個樣本都記錄了四個特征响委。
降維的任務(wù)是詢問是否存在合適的低維表示新思,保留數(shù)據(jù)的基本特征。 降維通常用于來輔助數(shù)據(jù)可視化:畢竟赘风,繪制二維數(shù)據(jù)比四維或更高維度中更容易夹囚!
在這里,我們將使用主成分分析(PCA邀窃,參見主成分分析)荸哟,這是一種快速線性降維技術(shù)。 我們要求模型返回兩個組件 - 即數(shù)據(jù)的二維表示瞬捕。
按照先前列出的步驟鞍历,我們有:
from sklearn.decomposition import PCA # 1. Choose the model class
model = PCA(n_components=2) # 2. Instantiate the model with hyperparameters
model.fit(X_iris) # 3. Fit to data. Notice y is not specified!
X_2D = model.transform(X_iris) # 4. Transform the data to two dimensions
現(xiàn)在我們來繪制結(jié)果。 一個快速的方法是山析,將結(jié)果插入到原始的 Iris DataFrame 中堰燎,并使用 Seaborn 的lmplot
來顯示結(jié)果:
iris['PCA1'] = X_2D[:, 0]
iris['PCA2'] = X_2D[:, 1]
sns.lmplot("PCA1", "PCA2", hue='species', data=iris, fit_reg=False);
我們看到掏父,在二維表示中笋轨,物種的分隔相當良好,盡管 PCA 算法不知道物種標簽赊淑! 這對我們來說爵政,正如我們以前看到的那樣,相對簡單的分類可能對數(shù)據(jù)集有效陶缺。
無監(jiān)督學(xué)習(xí)示例:Iris 聚類
接下來我們來看一下 Iris 數(shù)據(jù)的聚類應(yīng)用钾挟。 聚類算法嘗試找到不同的數(shù)據(jù)分析,而不參考任何標簽饱岸。 在這里掺出,我們將使用一種稱為高斯混合模型(GMM)的強大的聚類方法,在高斯混合模型中有更詳細的討論苫费。 GMM 嘗試將數(shù)據(jù)建模為高斯數(shù)據(jù)塊的集合汤锨。
我們可以這樣擬合高斯混合模型:
from sklearn.mixture import GMM # 1. Choose the model class
model = GMM(n_components=3,
covariance_type='full') # 2. Instantiate the model with hyperparameters
model.fit(X_iris) # 3. Fit to data. Notice y is not specified!
y_gmm = model.predict(X_iris) # 4. Determine cluster labels
像之前一樣,我們向 IrisDataFrame
添加簇的標簽百框,并使用 Seaborn 繪制結(jié)果:
iris['cluster'] = y_gmm
sns.lmplot("PCA1", "PCA2", data=iris, hue='species',
col='cluster', fit_reg=False);
通過按照簇號分割數(shù)據(jù)闲礼,我們看到 GMM 算法已經(jīng)恢復(fù)了潛在的標簽:組 0 已經(jīng)完全分離了 setosa 物種,而在 versicolor 和 virginica 之間仍然存在少量的混合铐维。 這意味著即使沒有專家告訴我們個別花朵的物種標簽柬泽,這些花朵的度量是非常明顯的,我們可以用簡單的聚類算法嫁蛇,自動識別這些不同種類的物種的存在锨并! 這種算法可能會進一步向?qū)<姨峁┈F(xiàn)場線索,關(guān)于他們正在觀察的樣本之間關(guān)系睬棚。
應(yīng)用:手寫體數(shù)字探索
為了在一個更有趣的問題上演示這些原理第煮,我們來考慮一個光學(xué)字符識別問題:識別手寫數(shù)字有决。 粗略來說,這個問題涉及定位和識別圖像中的字符空盼。 在這里书幕,我們將使用捷徑,并使用 Scikit-Learn 的一組預(yù)格式化數(shù)字揽趾,這是內(nèi)置在庫中的台汇。
加載和展示數(shù)字的數(shù)據(jù)
我們使用 Scikit-Learn 的數(shù)據(jù)訪問接口,并看一看這個數(shù)據(jù):
from sklearn.datasets import load_digits
digits = load_digits()
digits.images.shape
# (1797, 8, 8)
這個圖像的數(shù)據(jù)是三維數(shù)組:1797 個樣本篱瞎,每個包含 8x8 的像素苟呐。讓我們先展示前一百個。
import matplotlib.pyplot as plt
fig, axes = plt.subplots(10, 10, figsize=(8, 8),
subplot_kw={'xticks':[], 'yticks':[]},
gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i, ax in enumerate(axes.flat):
ax.imshow(digits.images[i], cmap='binary', interpolation='nearest')
ax.text(0.05, 0.05, str(digits.target[i]),
transform=ax.transAxes, color='green')
為了在 Scikit-Learn 中處理這些數(shù)據(jù)俐筋,我們需要一個二維的[n_samples牵素,n_features]
表示。 我們可以將圖像中的每個像素視為一個特征:即通過展開像素陣列澄者,使得我們具有長度為 64 的數(shù)組笆呆,代表每個數(shù)字的像素值。 另外粱挡,我們需要目標數(shù)組赠幕,它為每個數(shù)字給出了先前確定的標簽。 這兩個數(shù)量分別內(nèi)置在數(shù)字數(shù)據(jù)集的data
和target
屬性中:
X = digits.data
X.shape
# (1797, 64)
y = digits.target
y.shape
# (1797,)
我們可以看到询筏,有 1797 個樣本和 64 個特征榕堰。
無監(jiān)督學(xué)習(xí):降維
我們希望在 64 維參數(shù)空間內(nèi)可視化我們的點,但很難有效地在這樣一個高維空間中可視化點嫌套。 相反逆屡,我們將使用無監(jiān)督的方法將維度減小到 2。 在這里踱讨,我們將利用一種稱為 Isomap 的流形學(xué)習(xí)算法(參見流形學(xué)習(xí))魏蔗,并將數(shù)據(jù)轉(zhuǎn)換為兩個維度:
from sklearn.manifold import Isomap
iso = Isomap(n_components=2)
iso.fit(digits.data)
data_projected = iso.transform(digits.data)
data_projected.shape
# (1797, 2)
我們可以看到,投影的數(shù)據(jù)現(xiàn)在是二維了勇蝙。讓我們繪制數(shù)據(jù)沫勿,來看看是否可以從結(jié)構(gòu)中學(xué)到什么東西。
plt.scatter(data_projected[:, 0], data_projected[:, 1], c=digits.target,
edgecolor='none', alpha=0.5,
cmap=plt.cm.get_cmap('spectral', 10))
plt.colorbar(label='digit label', ticks=range(10))
plt.clim(-0.5, 9.5);
這個繪圖給了我們很好的直覺味混,在更大的 64 維空間中各種數(shù)字的分離程度如何产雹。 例如,零(黑色)和一(紫色)在參數(shù)空間中幾乎沒有重疊翁锡。 直觀上來說蔓挖,這是有道理的:零的圖像中間是空的,而一的中間通常會有墨跡馆衔。 另一方面瘟判,一和四之間似乎有一個或多或少的連續(xù)頻譜:我們可以理解怨绣,有些人在一上畫了個“帽子”,從而使他們看起來像四拷获。
然而篮撑,總的來說,不同的組的似乎在參數(shù)空間中分離良好的:這告訴我們匆瓜,即使是一個非常簡單的監(jiān)督分類算法赢笨,應(yīng)該也適合于這些數(shù)據(jù)。 讓我們試試看吧驮吱。
對數(shù)字分類
讓我們對數(shù)字應(yīng)用分類算法茧妒。就像之前的 Iris 數(shù)據(jù)那樣,我們將數(shù)據(jù)分為訓(xùn)練和測試集左冬,之后擬合高斯樸素貝葉斯模型桐筏。
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, random_state=0)
from sklearn.naive_bayes import GaussianNB
model = GaussianNB()
model.fit(Xtrain, ytrain)
y_model = model.predict(Xtest)
既然我們預(yù)測了我們的模型,我們可以通過比較測試集和預(yù)測拇砰,來看看它的準確度梅忌。
from sklearn.metrics import accuracy_score
accuracy_score(ytest, y_model)
# 0.83333333333333337
即使是這個非常簡單的模型,我們發(fā)現(xiàn)數(shù)字分類的準確率約為 80%毕匀! 然而铸鹰,這個單一的數(shù)字并沒有告訴我們哪里不對 - 一個很好的方式是使用混淆矩陣,我們可以用 Scikit-Learn 和 Seaborn 進行計算:
from sklearn.metrics import confusion_matrix
mat = confusion_matrix(ytest, y_model)
sns.heatmap(mat, square=True, annot=True, cbar=False)
plt.xlabel('predicted value')
plt.ylabel('true value');
這顯示了錯誤標記的點往往是什么:例如皂岔,這里的大量二被錯誤分類為一或者八。 獲取模型特征的直覺的另一種方法展姐,是用預(yù)測的標簽再次繪制輸入躁垛。 我們將使用綠色標簽表示正確,紅色標簽表示不正確:
fig, axes = plt.subplots(10, 10, figsize=(8, 8),
subplot_kw={'xticks':[], 'yticks':[]},
gridspec_kw=dict(hspace=0.1, wspace=0.1))
test_images = Xtest.reshape(-1, 8, 8)
for i, ax in enumerate(axes.flat):
ax.imshow(test_images[i], cmap='binary', interpolation='nearest')
ax.text(0.05, 0.05, str(y_model[i]),
transform=ax.transAxes,
color='green' if (ytest[i] == y_model[i]) else 'red')
檢查這個數(shù)據(jù)的這個子集圾笨,我們可以深入了解教馆,算法在哪里可能表現(xiàn)不是最好。 為了超過我們 80% 的分類準確率擂达,我們可能會采用更復(fù)雜的算法土铺,如支持向量機(參見支持向量機),隨機森林(參見決策樹和隨機森林)或其他分類方式板鬓。
總結(jié)
在本節(jié)中悲敷,我們已經(jīng)介紹了 Scikit-Learn 數(shù)據(jù)表示的基本特征和估計器 API。 不管估計類型如何俭令,都需要相同的導(dǎo)入/實例化/擬合/預(yù)測模式后德。 為了掌握有關(guān)估計 API 的信息,你可以瀏覽 Scikit-Learn 文檔抄腔,并開始在數(shù)據(jù)上嘗試各種模型瓢湃。
在下一節(jié)中理张,我們將探討機器學(xué)習(xí)中最重要的主題:如何選擇和驗證你的模型。