在XGBoost中通過(guò)Early Stop避免過(guò)擬合

本文翻譯自Avoid Overfitting By Early Stopping With XGBoost In Python氯材,講述如何在使用XGBoost建模時(shí)通過(guò)Early Stop手段來(lái)避免過(guò)擬合肌幽。全文系作者原創(chuàng),僅供學(xué)習(xí)參考使用期虾,轉(zhuǎn)載授權(quán)請(qǐng)私信聯(lián)系,否則將視為侵權(quán)行為。碼字不易碴犬,感謝支持题暖。以下為全文內(nèi)容:


過(guò)擬合問(wèn)題是在使用復(fù)雜的非線性學(xué)習(xí)算法時(shí)會(huì)經(jīng)常碰到的按傅,比如gradient boosting算法。

在這篇博客中你將發(fā)現(xiàn)如何通過(guò)Early Stop方法使得我們?cè)谑褂肞ython中的XGBoost模型時(shí)可以盡可能地避免過(guò)擬合問(wèn)題:

讀完這篇博客后胧卤,你將學(xué)到:

  • Early Stop可以減少訓(xùn)練集上的過(guò)擬合
  • 在使用XGBoost模型時(shí)如何監(jiān)控訓(xùn)練過(guò)程中模型的表現(xiàn)唯绍,如何繪制學(xué)習(xí)曲線
  • 如何使用Early Stop方法在模型表現(xiàn)最好的時(shí)候停止訓(xùn)練

讓我們開(kāi)始吧。

使用Early Stop避免過(guò)擬合

Early Stop是訓(xùn)練復(fù)雜機(jī)器學(xué)習(xí)模型以避免其過(guò)擬合的一種方法灌侣。

它通過(guò)監(jiān)控模型在一個(gè)額外的測(cè)試集上的表現(xiàn)來(lái)工作推捐,當(dāng)模型在測(cè)試集上的表現(xiàn)在連續(xù)的若干次(提前指定好的)迭代中都不再提升時(shí)它將終止訓(xùn)練過(guò)程。

它通過(guò)嘗試自動(dòng)選擇拐點(diǎn)來(lái)避免過(guò)擬合侧啼,在拐點(diǎn)處牛柒,測(cè)試數(shù)據(jù)集的性能開(kāi)始下降,而訓(xùn)練數(shù)據(jù)集的性能隨著模型開(kāi)始過(guò)擬合而繼續(xù)改善痊乾。

性能的度量可以是訓(xùn)練模型時(shí)正在使用的損失函數(shù)(例如對(duì)數(shù)損失)皮壁,或通常意義上用戶感興趣的外部度量(例如分類精度)。

在XGBoost中監(jiān)控模型的表現(xiàn)

XGBoost模型在訓(xùn)練時(shí)可以計(jì)算并輸入在某個(gè)指定的測(cè)試數(shù)據(jù)集的性能表現(xiàn)哪审。

在調(diào)用model.fit()函數(shù)時(shí)蛾魄,可以指定測(cè)試數(shù)據(jù)集和評(píng)價(jià)指標(biāo),同時(shí)設(shè)置verbose參數(shù)為True湿滓,這樣就可以在訓(xùn)練過(guò)程中輸出模型在測(cè)試集的表現(xiàn)滴须。

例如,我們可以通過(guò)下面的方法在使用XGBoost訓(xùn)練二分類任務(wù)時(shí)輸出分類錯(cuò)誤率(通過(guò)“error”指定):

eval_set = [(X_test, y_test)]
model.fit(X_train, y_train, eval_metric="error", eval_set=eval_set, verbose=True)

XGBoost提供了一系列的模型評(píng)價(jià)指標(biāo)叽奥,包括但不限于:

  • “rmse” 代表均方根誤差
  • “mae” 代表平均絕對(duì)誤差
  • “l(fā)ogloss” 代表二元對(duì)數(shù)損失
  • “mlogloss” 代表m-元對(duì)數(shù)損失
  • “error” 代表分類錯(cuò)誤率
  • “auc” 代表ROC曲線下面積

完整的列表見(jiàn)XGBoost文檔中的“Learning Task Parameters””章節(jié)扔水。

例如,我們可以演示如何監(jiān)控使用UCI機(jī)器學(xué)習(xí)存儲(chǔ)庫(kù)(更新:從這里下載)的關(guān)于Pima糖尿病發(fā)病數(shù)據(jù)集的XGBoost模型在訓(xùn)練過(guò)程中的性能指標(biāo)朝氓。

完整代碼清單如下:

# monitor training performance
from numpy import loadtxt
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# load data
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=",")
# split data into X and y
X = dataset[:,0:8]
Y = dataset[:,8]
# split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=7)
# fit model no training data
model = XGBClassifier()
eval_set = [(X_test, y_test)]
model.fit(X_train, y_train, eval_metric="error", eval_set=eval_set, verbose=True)
# make predictions for test data
y_pred = model.predict(X_test)
predictions = [round(value) for value in y_pred]
# evaluate predictions
accuracy = accuracy_score(y_test, predictions)
print("Accuracy: %.2f%%" % (accuracy * 100.0))

運(yùn)行這段代碼將會(huì)在67%的數(shù)據(jù)集上訓(xùn)練模型魔市,并且在每一輪迭代中使用剩下的33%數(shù)據(jù)來(lái)評(píng)估模型的性能主届。

每次迭代都會(huì)輸出分類錯(cuò)誤,最終將會(huì)輸出最后的分類準(zhǔn)確率待德。

...
[89]    validation_0-error:0.204724
[90]    validation_0-error:0.208661
[91]    validation_0-error:0.208661
[92]    validation_0-error:0.208661
[93]    validation_0-error:0.208661
[94]    validation_0-error:0.208661
[95]    validation_0-error:0.212598
[96]    validation_0-error:0.204724
[97]    validation_0-error:0.212598
[98]    validation_0-error:0.216535
[99]    validation_0-error:0.220472
Accuracy: 77.95%

觀察所有的輸出君丁,我們可以看到,在訓(xùn)練快要結(jié)束時(shí)測(cè)試集上的模型性能的變化是平緩的将宪,甚至變得更差绘闷。

使用學(xué)習(xí)曲線來(lái)評(píng)估XGBoost模型

我們可以提取出模型在測(cè)試數(shù)據(jù)集上的表現(xiàn)并繪制成圖案,從而更好地洞察到在整個(gè)訓(xùn)練過(guò)程中學(xué)習(xí)曲線是如何變化的涧偷。

在調(diào)用XGBoost模型時(shí)我們提供了一個(gè)數(shù)組簸喂,數(shù)組的每一項(xiàng)是一個(gè)X和y的配對(duì)。在測(cè)試集之外燎潮,我們同時(shí)將訓(xùn)練集也作為輸入喻鳄,從而觀察在訓(xùn)練過(guò)程中模型在訓(xùn)練集和測(cè)試集上各自的表現(xiàn)。

例如:

eval_set = [(X_train, y_train), (X_test, y_test)]
model.fit(X_train, y_train, eval_metric="error", eval_set=eval_set, verbose=True)

模型在各個(gè)數(shù)據(jù)集上的表現(xiàn)可以在訓(xùn)練結(jié)束后通過(guò)model.evals_result()函數(shù)獲取确封,這個(gè)函數(shù)返回一個(gè)dict包含了評(píng)估數(shù)據(jù)集的代碼和對(duì)應(yīng)的分?jǐn)?shù)列表除呵,例如:

results = model.evals_result()
print(results)

這將輸出如下的結(jié)果:

{
    'validation_0': {'error': [0.259843, 0.26378, 0.26378, ...]},
    'validation_1': {'error': [0.22179, 0.202335, 0.196498, ...]}
}

“validation_0”和“validation_1”代表了在調(diào)用fit()函數(shù)時(shí)傳給eval_set參數(shù)的數(shù)組中數(shù)據(jù)集的順序。

一個(gè)特定的結(jié)果爪喘,比如第一個(gè)數(shù)據(jù)集上的分類錯(cuò)誤率颜曾,可以通過(guò)如下方法獲取:

results['validation_0']['error']

另外我們可以指定更多的評(píng)價(jià)指標(biāo)秉剑,從而同時(shí)獲取多種評(píng)價(jià)指標(biāo)的變化情況泛豪。

接著我們可以使用收集到的數(shù)據(jù)繪制曲線,從而更直觀地了解在整個(gè)訓(xùn)練過(guò)程中模型在訓(xùn)練集和測(cè)試集上的表現(xiàn)究竟如何侦鹏。

下面是一段完整的代碼诡曙,展示了如何將收集到的數(shù)據(jù)繪制成學(xué)習(xí)曲線:

# plot learning curve
from numpy import loadtxt
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from matplotlib import pyplot
# load data
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=",")
# split data into X and y
X = dataset[:,0:8]
Y = dataset[:,8]
# split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=7)
# fit model no training data
model = XGBClassifier()
eval_set = [(X_train, y_train), (X_test, y_test)]
model.fit(X_train, y_train, eval_metric=["error", "logloss"], eval_set=eval_set, verbose=True)
# make predictions for test data
y_pred = model.predict(X_test)
predictions = [round(value) for value in y_pred]
# evaluate predictions
accuracy = accuracy_score(y_test, predictions)
print("Accuracy: %.2f%%" % (accuracy * 100.0))
# retrieve performance metrics
results = model.evals_result()
epochs = len(results['validation_0']['error'])
x_axis = range(0, epochs)
# plot log loss
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['logloss'], label='Train')
ax.plot(x_axis, results['validation_1']['logloss'], label='Test')
ax.legend()
pyplot.ylabel('Log Loss')
pyplot.title('XGBoost Log Loss')
pyplot.show()
# plot classification error
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['error'], label='Train')
ax.plot(x_axis, results['validation_1']['error'], label='Test')
ax.legend()
pyplot.ylabel('Classification Error')
pyplot.title('XGBoost Classification Error')
pyplot.show()

運(yùn)行這段代碼將會(huì)在每一次訓(xùn)練迭代中輸出模型在訓(xùn)練集和測(cè)試集上的分類錯(cuò)誤率。我們可以通過(guò)設(shè)置verbose=False來(lái)關(guān)閉輸出略水。

我們繪制了兩張圖价卤,第一張圖表示的是模型在每一輪迭代中在兩個(gè)數(shù)據(jù)集上的對(duì)數(shù)損失:

XGBoost Learning Curve Log Loss

第二張圖表示分類錯(cuò)誤率:

XGBoost Learning Curve Classification Error

從第一張圖來(lái)看,似乎有機(jī)會(huì)可以進(jìn)行Early Stop渊涝,大約在20到40輪迭代時(shí)比較合適慎璧。

從第二張圖可以得到相似的結(jié)果,大概在40輪迭代時(shí)效果比較理想跨释。

在XGBoost中進(jìn)行Early Stop

XGBoost提供了在指定輪數(shù)完成后提前停止訓(xùn)練的功能胸私。

除了提供用于評(píng)估每輪迭代中的評(píng)價(jià)指標(biāo)和數(shù)據(jù)集之外,還需要指定一個(gè)窗口大小鳖谈,意味著連續(xù)這么多輪迭代中模型的效果沒(méi)有提升岁疼。這是通過(guò)early_stopping_rounds參數(shù)來(lái)設(shè)置的。

例如蚯姆,我們可以像下面這樣設(shè)置連續(xù)10輪中對(duì)數(shù)損失都沒(méi)有提升:

eval_set = [(X_test, y_test)]
model.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="logloss", eval_set=eval_set, verbose=True)

如果同時(shí)指定了多個(gè)評(píng)估數(shù)據(jù)集和多個(gè)評(píng)價(jià)指標(biāo)五续,early_stopping_rounds將會(huì)使用數(shù)組中的最后一個(gè)作為依據(jù)。

下面提供了一個(gè)使用early_stopping_rounds的詳細(xì)例子:

# early stopping
from numpy import loadtxt
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# load data
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=",")
# split data into X and y
X = dataset[:,0:8]
Y = dataset[:,8]
# split data into train and test sets
seed = 7
test_size = 0.33
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=test_size, random_state=seed)
# fit model no training data
model = XGBClassifier()
eval_set = [(X_test, y_test)]
model.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="logloss", eval_set=eval_set, verbose=True)
# make predictions for test data
y_pred = model.predict(X_test)
predictions = [round(value) for value in y_pred]
# evaluate predictions
accuracy = accuracy_score(y_test, predictions)
print("Accuracy: %.2f%%" % (accuracy * 100.0))

運(yùn)行這段代碼將得到如下的輸出(部分):

...
[35]    validation_0-logloss:0.487962
[36]    validation_0-logloss:0.488218
[37]    validation_0-logloss:0.489582
[38]    validation_0-logloss:0.489334
[39]    validation_0-logloss:0.490969
[40]    validation_0-logloss:0.48978
[41]    validation_0-logloss:0.490704
[42]    validation_0-logloss:0.492369
Stopping. Best iteration:
[32]    validation_0-logloss:0.487297

我們可以看到模型在迭代到42輪時(shí)停止了訓(xùn)練龄恋,在32輪迭代后觀察到了最好的效果疙驾。

通常將early_stopping_rounds設(shè)置為一個(gè)與總訓(xùn)練輪數(shù)相關(guān)的函數(shù)(本例中是10%),或者通過(guò)觀察學(xué)習(xí)曲線來(lái)設(shè)置使得訓(xùn)練過(guò)程包含拐點(diǎn)郭毕,這兩種方法都是不錯(cuò)的選擇它碎。

總結(jié)

在這篇博客中你發(fā)現(xiàn)了如何監(jiān)控模型的表現(xiàn)以及怎么做Early Stop。

你學(xué)會(huì)了:

  • 使用Early Stop手段在模型過(guò)擬合之前停止訓(xùn)練
  • 在使用XGBoost模型時(shí)如何監(jiān)控模型的表現(xiàn)并繪制出模型的學(xué)習(xí)曲線
  • 在訓(xùn)練XGBoost模型時(shí)如何設(shè)置Early Stop參數(shù)

關(guān)于Early Stop或者這篇博客你還有什么想問(wèn)的問(wèn)題嗎显押?歡迎在下方的評(píng)論區(qū)留言扳肛,我將盡我最大的努力來(lái)解答。


以上就是本文的全部?jī)?nèi)容乘碑,如果您喜歡這篇文章挖息,歡迎將它分享給朋友們。

感謝您的閱讀兽肤,祝您生活愉快套腹!

作者:小美哥
2018-12-15

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市资铡,隨后出現(xiàn)的幾起案子电禀,更是在濱河造成了極大的恐慌,老刑警劉巖笤休,帶你破解...
    沈念sama閱讀 222,378評(píng)論 6 516
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件尖飞,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡店雅,警方通過(guò)查閱死者的電腦和手機(jī)政基,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,970評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)底洗,“玉大人腋么,你說(shuō)我怎么就攤上這事『ヒ荆” “怎么了珊擂?”我有些...
    開(kāi)封第一講書人閱讀 168,983評(píng)論 0 362
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)费变。 經(jīng)常有香客問(wèn)我摧扇,道長(zhǎng),這世上最難降的妖魔是什么挚歧? 我笑而不...
    開(kāi)封第一講書人閱讀 59,938評(píng)論 1 299
  • 正文 為了忘掉前任扛稽,我火速辦了婚禮,結(jié)果婚禮上滑负,老公的妹妹穿的比我還像新娘在张。我一直安慰自己用含,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,955評(píng)論 6 398
  • 文/花漫 我一把揭開(kāi)白布帮匾。 她就那樣靜靜地躺著啄骇,像睡著了一般。 火紅的嫁衣襯著肌膚如雪瘟斜。 梳的紋絲不亂的頭發(fā)上缸夹,一...
    開(kāi)封第一講書人閱讀 52,549評(píng)論 1 312
  • 那天,我揣著相機(jī)與錄音螺句,去河邊找鬼虽惭。 笑死,一個(gè)胖子當(dāng)著我的面吹牛蛇尚,可吹牛的內(nèi)容都是我干的芽唇。 我是一名探鬼主播,決...
    沈念sama閱讀 41,063評(píng)論 3 422
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼佣蓉,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼披摄!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起勇凭,我...
    開(kāi)封第一講書人閱讀 39,991評(píng)論 0 277
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤疚膊,失蹤者是張志新(化名)和其女友劉穎,沒(méi)想到半個(gè)月后虾标,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體寓盗,經(jīng)...
    沈念sama閱讀 46,522評(píng)論 1 319
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,604評(píng)論 3 342
  • 正文 我和宋清朗相戀三年璧函,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了傀蚌。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 40,742評(píng)論 1 353
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡蘸吓,死狀恐怖善炫,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情库继,我是刑警寧澤箩艺,帶...
    沈念sama閱讀 36,413評(píng)論 5 351
  • 正文 年R本政府宣布,位于F島的核電站宪萄,受9級(jí)特大地震影響艺谆,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜拜英,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 42,094評(píng)論 3 335
  • 文/蒙蒙 一静汤、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦虫给、人聲如沸藤抡。這莊子的主人今日做“春日...
    開(kāi)封第一講書人閱讀 32,572評(píng)論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)杰捂。三九已至,卻和暖如春棋蚌,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背挨队。 一陣腳步聲響...
    開(kāi)封第一講書人閱讀 33,671評(píng)論 1 274
  • 我被黑心中介騙來(lái)泰國(guó)打工谷暮, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人盛垦。 一個(gè)月前我還...
    沈念sama閱讀 49,159評(píng)論 3 378
  • 正文 我出身青樓湿弦,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親腾夯。 傳聞我的和親對(duì)象是個(gè)殘疾皇子颊埃,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,747評(píng)論 2 361

推薦閱讀更多精彩內(nèi)容