主要介紹scikit-learn中的交叉驗證
sklearn.learning_curve 中的 learning curve 可以很直觀的看出我們的 model 學(xué)習(xí)的進(jìn)度, 對比發(fā)現(xiàn)有沒有 overfitting 的問題. 然后我們可以對我們的 model 進(jìn)行調(diào)整, 克服 overfitting 的問題.
Demo1.py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.svm import SVC
from sklearn.learning_curve import learning_curve
from sklearn.cross_validation import cross_val_score
# 加載數(shù)據(jù)集
digits = load_digits()
X = digits.data
y = digits.target
# 用SVM進(jìn)行學(xué)習(xí)并記錄loss
train_sizes, train_loss, test_loss = learning_curve(SVC(gamma = 0.001),
X, y, cv = 10, scoring = 'mean_squared_error',
train_sizes = [0.1, 0.25, 0.5, 0.75, 1])
# 訓(xùn)練誤差均值
train_loss_mean = -np.mean(train_loss, axis = 1)
# 測試誤差均值
test_loss_mean = -np.mean(test_loss, axis = 1)
# 繪制誤差曲線
plt.plot(train_sizes, train_loss_mean, 'o-', color = 'r', label = 'Training')
plt.plot(train_sizes, test_loss_mean, 'o-', color = 'g', label = 'Cross-Validation')
plt.xlabel('Training data size')
plt.ylabel('Loss')
plt.legend(loc = 'best')
plt.show()
結(jié)果:
Paste_Image.png
sklearn.learning_curve.learning_curve
-
調(diào)用格式是:
learning_curve(estimator, X, y, train_sizes=array([ 0.1 , 0.325, 0.55 , 0.775, 1. ]), cv=None, scoring=None, exploit_incremental_learning=False, n_jobs=1, pre_dispatch=‘a(chǎn)ll‘, verbose=0) -
函數(shù)的作用為:
對于不同大小的訓(xùn)練集禽拔,確定交叉驗證訓(xùn)練和測試的分?jǐn)?shù)兴溜。一個交叉驗證發(fā)生器將整個數(shù)據(jù)集分割k次,分割成訓(xùn)練集和測試集趋箩。不同大小的訓(xùn)練集的子集將會被用來訓(xùn)練評估器并且對于每一個大小的訓(xùn)練子集都會產(chǎn)生一個分?jǐn)?shù)泰偿,然后測試集的分?jǐn)?shù)也會計算俩垃。然后,對于每一個訓(xùn)練子集慷荔,運行k次之后的所有這些分?jǐn)?shù)將會被平均雕什。
estimator:所使用的分類器
X:array-like, shape (n_samples, n_features) 訓(xùn)練向量,n_samples是樣本的數(shù)量显晶,n_features是特征的數(shù)量
y:array-like, shape (n_samples) or (n_samples, n_features), optional目標(biāo)相對于X分類或者回歸
train_sizes:array-like, shape (n_ticks,), dtype float or int訓(xùn)練樣本的相對的或絕對的數(shù)字贷岸,這些量的樣本將會生成learning curve。如果dtype是float磷雇,他將會被視為最大數(shù)量訓(xùn)練集的一部分(這個由所選擇的驗證方法所決定)偿警。否則,他將會被視為訓(xùn)練集的絕對尺寸唯笙。要注意的是螟蒸,對于分類而言,樣本的大小必須要充分大崩掘,達(dá)到對于每一個分類都至少包含一個樣本的情況尿庐。
cv:int, cross-validation generator or an iterable, optional確定交叉驗證的分離策略
--None,使用默認(rèn)的3-fold cross-validation,
--integer,確定是幾折交叉驗證
--一個作為交叉驗證生成器的對象
--一個被應(yīng)用于訓(xùn)練/測試分離的迭代器
verbose : integer, optional控制冗余:越高呢堰,有越多的信息 -
返回值:
train_sizes_abs:array, shape = (n_unique_ticks,), dtype int用于生成learning curve的訓(xùn)練集的樣本數(shù)抄瑟。由于重復(fù)的輸入將會被刪除,所以ticks可能會少于n_ticks.
**train_scores **: array, shape (n_ticks, n_cv_folds)在訓(xùn)練集上的分?jǐn)?shù)
test_scores : array, shape (n_ticks, n_cv_folds)在測試集上的分?jǐn)?shù)