CatBoost vs. LightGBM vs. XGBoost

258.png

提升算法是一類機(jī)器學(xué)習(xí)算法募闲,通過迭代地訓(xùn)練一系列弱分類器(通常是決策樹)來構(gòu)建一個(gè)強(qiáng)分類器股毫。在每一輪迭代中述寡,新的分類器被設(shè)計(jì)為修正前一輪分類器的錯(cuò)誤嫉晶,從而逐步提高整體的分類性能。

盡管神經(jīng)網(wǎng)絡(luò)興起并流行起來敛滋,但提升算法仍然相當(dāng)實(shí)用许布。因?yàn)樗鼈冊谟?xùn)練數(shù)據(jù)有限、訓(xùn)練時(shí)間短绎晃、缺乏參數(shù)調(diào)優(yōu)專業(yè)知識等的情況下蜜唾,仍然有良好的表現(xiàn)。

提升算法有AdaBoost庶艾、CatBoost袁余、LightGBM、XGBoost等咱揍。

本文颖榜,將重點(diǎn)關(guān)注CatBoost、LightGBM煤裙、XGBoost掩完。將包括:

  • 結(jié)構(gòu)上的區(qū)別;
  • 每個(gè)算法對分類變量的處理方式硼砰;
  • 理解參數(shù)且蓬;
  • 在數(shù)據(jù)集上的實(shí)踐;
  • 每個(gè)算法的性能题翰。

由于 XGBoost(通常被稱為 GBM Killer)在機(jī)器學(xué)習(xí)領(lǐng)域已經(jīng)存在了很長時(shí)間恶阴,并且有很多文章專門介紹它,因此本文將更多地關(guān)注 CatBoost 和 LGBM遍愿。

1. LightGBM和XGBoost的結(jié)構(gòu)差異

LightGBM使用一種新穎的梯度單邊采樣(Gradient-based One-Side Sampling,GOSS)技術(shù)存淫,在查找分裂值時(shí)過濾數(shù)據(jù)實(shí)例,而XGBoost使用預(yù)排序算法(pre-sorted algorithm)和基于直方圖的算法(Histogram-based algorithm)來計(jì)算最佳分裂沼填。

上面的實(shí)例指的是觀測/樣本桅咆。

首先,讓我們了解一下XGBoost的預(yù)排序分裂是如何工作的:

  • 對于每個(gè)節(jié)點(diǎn)坞笙,枚舉所有特征岩饼;
  • 對于每個(gè)特征,按特征值對實(shí)例進(jìn)行排序薛夜;
  • 使用線性掃描來根據(jù)信息增益(information gain)決定該特征上的最佳分裂籍茧;
  • 選擇所有特征中的最佳分裂解決方案。

簡單來說梯澜,基于直方圖的算法將特征的所有數(shù)據(jù)點(diǎn)分成離散的箱子寞冯,并使用這些箱子來找到直方圖的分裂值。雖然在訓(xùn)練速度上比預(yù)排序算法高效,后者需要枚舉預(yù)排序的特征值上的所有可能分裂點(diǎn)吮龄,但在速度方面仍然落后于GOSS俭茧。

那么,是什么使得GOSS方法高效呢漓帚?

在AdaBoost中母债,樣本權(quán)重可以作為樣本重要性的良好指標(biāo)。然而尝抖,在梯度提升決策樹(GBDT)中毡们,沒有原生的樣本權(quán)重,因此無法直接應(yīng)用于AdaBoost提出的采樣方法昧辽。這就引入了基于梯度的采樣方法衙熔。

梯度代表損失函數(shù)切線的斜率,因此在某種意義上奴迅,如果數(shù)據(jù)點(diǎn)的梯度較大青责,這些點(diǎn)對于找到最佳分裂點(diǎn)是重要的,因?yàn)樗鼈兙哂懈叩恼`差取具。

GOSS保留所有具有較大梯度的實(shí)例脖隶,并對具有較小梯度的實(shí)例進(jìn)行隨機(jī)采樣。例如暇检,假設(shè)我有50萬行的數(shù)據(jù)产阱,其中1萬行具有較大的梯度。因此块仆,我的算法將選擇(10k行具有較大梯度 + 剩余的490k行的x%隨機(jī)選擇)构蹬。假設(shè)x為10%,則選擇的總行數(shù)是59k悔据,基于這些行找到了分裂值庄敛。

這里的基本假設(shè)是,具有較小梯度的訓(xùn)練實(shí)例具有較小的訓(xùn)練誤差科汗,并且已經(jīng)訓(xùn)練得很好藻烤。為了保持相同的數(shù)據(jù)分布,在計(jì)算信息增益時(shí)头滔,GOSS引入了一個(gè)常數(shù)乘數(shù)怖亭,用于具有較小梯度的數(shù)據(jù)實(shí)例。因此坤检,GOSS在減少數(shù)據(jù)實(shí)例數(shù)量和保持學(xué)習(xí)決策樹的準(zhǔn)確性之間取得了良好的平衡兴猩。

LGBM在梯度/誤差較大的葉子上進(jìn)一步生長

2. 每個(gè)模型如何處理分類變量?

2.1 CatBoost

CatBoost具有靈活性早歇,可以提供分類列的索引倾芝,以便可以使用one-hot編碼進(jìn)行編碼讨勤,使用one_hot_max_size參數(shù)(對于具有不同值數(shù)量小于或等于給定參數(shù)值的所有特征使用one-hot編碼)。

如果在cat_features參數(shù)中未傳遞任何內(nèi)容蛀醉,則CatBoost將將所有列視為數(shù)值變量悬襟。

注意:如果一個(gè)包含字符串值的列沒有在cat_features中提供衅码,CatBoost會拋出錯(cuò)誤拯刁。另外,默認(rèn)為int類型的列將默認(rèn)視為數(shù)值型逝段,如果要將其視為分類變量垛玻,必須在cat_features中指定。

對于剩余的分類列奶躯,其中唯一類別數(shù)大于one_hot_max_size的列帚桩,CatBoost使用一種類似于均值編碼但減少過擬合的高效編碼方法。該過程如下:

  • 隨機(jī)以隨機(jī)順序?qū)斎胗^測集進(jìn)行排列嘹黔,生成多個(gè)隨機(jī)排列账嚎;
  • 將標(biāo)簽值從浮點(diǎn)數(shù)或類別轉(zhuǎn)換為整數(shù);
  • 使用以下公式將所有分類特征值轉(zhuǎn)換為數(shù)值:

其中儡蔓,countInClass表示標(biāo)簽值等于“1”的對象中當(dāng)前分類特征值的出現(xiàn)次數(shù)郭蕉,prior是分子的初步值,由起始參數(shù)確定喂江,totalCount是具有與當(dāng)前分類特征值匹配的當(dāng)前對象之前的總對象數(shù)召锈。

數(shù)學(xué)上,可以用以下方程表示:

2.2 LightGBM

與CatBoost類似获询,LightGBM也可以通過輸入特征名稱來處理分類特征涨岁。它不會轉(zhuǎn)換為獨(dú)熱編碼,而且比獨(dú)熱編碼快得多吉嚣。LGBM使用一種特殊的算法來找到分類特征的分裂值梢薪。


注意:在構(gòu)建LGBM數(shù)據(jù)集之前,您應(yīng)該將分類特征轉(zhuǎn)換為整數(shù)類型尝哆。即使通過categorical_feature參數(shù)傳遞了字符串值秉撇,它也不接受字符串值。

2.3 XGBoost

與CatBoost或LGBM不同较解,XGBoost本身不能處理分類特征畜疾,它只接受類似于隨機(jī)森林的數(shù)值型數(shù)據(jù)。因此印衔,在將分類數(shù)據(jù)提供給XGBoost之前肛循,需要執(zhí)行各種編碼,如標(biāo)簽編碼砂沛、均值編碼或獨(dú)熱編碼。

3. 理解參數(shù)

所有這些模型都有很多要調(diào)整的參數(shù)彤敛,但我們只討論其中重要的參數(shù)。下面是這些參數(shù)的列表了赌,根據(jù)它們的功能以及在不同模型中的對應(yīng)參數(shù)墨榄。


4. 在數(shù)據(jù)集上的實(shí)現(xiàn)

我使用了2015年航班延誤的Kaggle數(shù)據(jù)集,因?yàn)樗劝诸愄卣饔职瑪?shù)值特征勿她。由于大約有500萬行數(shù)據(jù)袄秩,這個(gè)數(shù)據(jù)集對于評估每種類型的提升模型在速度和準(zhǔn)確性方面的性能是很好的。我將使用這個(gè)數(shù)據(jù)的10%子集逢并,約50萬行之剧。

以下是用于建模的特征:

  • MONTH,DAY砍聊,DAY_OF_WEEK:數(shù)據(jù)類型int
  • AIRLINE和FLIGHT_NUMBER:數(shù)據(jù)類型int
  • ORIGIN_AIRPORT和DESTINATION_AIRPORT:數(shù)據(jù)類型字符串
  • DEPARTURE_TIME:數(shù)據(jù)類型float
  • ARRIVAL_DELAY:這將是目標(biāo)變量背稼,并轉(zhuǎn)換為表示超過10分鐘延誤的布爾變量
  • DISTANCE和AIR_TIME:數(shù)據(jù)類型float
import pandas as pd, numpy as np, time
from sklearn.model_selection import train_test_split

data = pd.read_csv("./data/flights.csv")
data = data.sample(frac = 0.1, random_state=10)

data = data[["MONTH","DAY","DAY_OF_WEEK","AIRLINE","FLIGHT_NUMBER","DESTINATION_AIRPORT",
                 "ORIGIN_AIRPORT","AIR_TIME", "DEPARTURE_TIME","DISTANCE","ARRIVAL_DELAY"]]
data.dropna(inplace=True)

data["ARRIVAL_DELAY"] = (data["ARRIVAL_DELAY"]>10)*1

cols = ["AIRLINE","FLIGHT_NUMBER","DESTINATION_AIRPORT","ORIGIN_AIRPORT"]
for item in cols:
    data[item] = data[item].astype("category").cat.codes + 1

train, test, y_train, y_test = train_test_split(data.drop(["ARRIVAL_DELAY"], axis=1 ), data["ARRIVAL_DELAY"],random_state=10, test_size=0.25)

4.1 XGBoost

import xgboost as xgb
from sklearn import metrics
from sklearn.model_selection import GridSearchCV

def auc(m, train, test): 
    return (metrics.roc_auc_score(y_train,m.predict_proba(train)[:,1]),
                            metrics.roc_auc_score(y_test,m.predict_proba(test)[:,1]))

# Parameter Tuning
model = xgb.XGBClassifier()
param_dist = {"max_depth": [10,30,50],
              "min_child_weight" : [1,3,6],
              "n_estimators": [200],
              "learning_rate": [0.05, 0.1,0.16],}
grid_search = GridSearchCV(model, param_grid=param_dist, cv = 3, 
                                   verbose=10, n_jobs=-1)
grid_search.fit(train, y_train)

grid_search.best_estimator_

model = xgb.XGBClassifier(max_depth=50, min_child_weight=1,  n_estimators=200,\
                          n_jobs=-1 , verbose=1,learning_rate=0.16)
model.fit(train,y_train)
auc(model, train, test)

4.2 LightGBM

import lightgbm as lgb
from sklearn import metrics

def auc2(m, train, test): 
    return (metrics.roc_auc_score(y_train,m.predict(train)),
                            metrics.roc_auc_score(y_test,m.predict(test)))

lg = lgb.LGBMClassifier(verbose=0)
param_dist = {"max_depth": [25,50, 75],
              "learning_rate" : [0.01,0.05,0.1],
              "num_leaves": [300,900,1200],
              "n_estimators": [200]
             }
grid_search = GridSearchCV(lg, n_jobs=-1, param_grid=param_dist, cv = 3, scoring="roc_auc", verbose=5)
grid_search.fit(train,y_train)
grid_search.best_estimator_

d_train = lgb.Dataset(train, label=y_train)
params = {"max_depth": 50, "learning_rate" : 0.1, "num_leaves": 900,  "n_estimators": 300}

# Without Categorical Features
model2 = lgb.train(params, d_train)
auc2(model2, train, test)

# With Catgeorical Features
cate_features_name = ["MONTH","DAY","DAY_OF_WEEK","AIRLINE","DESTINATION_AIRPORT",
                 "ORIGIN_AIRPORT"]
model2 = lgb.train(params, d_train, categorical_feature = cate_features_name)
auc2(model2, train, test)

4.3 CatBoost

在調(diào)整CatBoost的參數(shù)時(shí),很難傳遞分類特征的索引玻蝌。因此蟹肘,我在沒有傳遞分類特征的情況下調(diào)整了參數(shù),并評估了兩個(gè)模型——一個(gè)使用分類特征俯树,另一個(gè)不使用分類特征帘腹。我單獨(dú)調(diào)整了one_hot_max_size,因?yàn)樗粫绊懫渌麉?shù)聘萨。

import catboost as cb
cat_features_index = [0,1,2,3,4,5,6]

def auc(m, train, test): 
    return (metrics.roc_auc_score(y_train,m.predict_proba(train)[:,1]),
                            metrics.roc_auc_score(y_test,m.predict_proba(test)[:,1]))

params = {'depth': [4, 7, 10],
          'learning_rate' : [0.03, 0.1, 0.15],
         'l2_leaf_reg': [1,4,9],
         'iterations': [300]}
cb = cb.CatBoostClassifier()
cb_model = GridSearchCV(cb, params, scoring="roc_auc", cv = 3)
cb_model.fit(train, y_train)

With Categorical features
clf = cb.CatBoostClassifier(eval_metric="AUC", depth=10, iterations= 500, l2_leaf_reg= 9, learning_rate= 0.15)
clf.fit(train,y_train)
auc(clf, train, test)

With Categorical features
clf = cb.CatBoostClassifier(eval_metric="AUC",one_hot_max_size=31, \
                            depth=10, iterations= 500, l2_leaf_reg= 9, learning_rate= 0.15)
clf.fit(train,y_train, cat_features= cat_features_index)
auc(clf, train, test)

5. 結(jié)論

在評估模型時(shí)竹椒,我們應(yīng)該從速度和準(zhǔn)確性兩個(gè)方面考慮模型的性能。

考慮到這一點(diǎn)米辐,CatBoost是贏家胸完,測試集上的準(zhǔn)確率最高(0.816),過擬合最星讨(訓(xùn)練集和測試集的準(zhǔn)確率接近)且預(yù)測時(shí)間和調(diào)優(yōu)時(shí)間最短赊窥。但這僅僅是因?yàn)槲覀兛紤]了分類變量并調(diào)整了one_hot_max_size。如果我們不利用CatBoost的這些特性狸页,它的準(zhǔn)確率只有0.752锨能,表現(xiàn)最差。因此芍耘,我們得出結(jié)論址遇,CatBoost僅在數(shù)據(jù)中存在分類變量且我們正確調(diào)整它們時(shí)表現(xiàn)良好。

我們的下一個(gè)表現(xiàn)良好的模型是XGBoost斋竞。即使忽略了我們在數(shù)據(jù)中有分類變量并將其轉(zhuǎn)換為數(shù)值變量供XGBoost使用的事實(shí)倔约,它的準(zhǔn)確率仍與CatBoost相當(dāng)接近。然而坝初,XGBoost唯一的問題是速度太慢浸剩。調(diào)整其參數(shù)真的很令人沮喪钾军,特別是使用GridSearchCV(運(yùn)行GridSearchCV花費(fèi)了我6個(gè)小時(shí),非常糟糕的主意>钜)吏恭。更好的方法是單獨(dú)調(diào)整參數(shù),而不是使用GridSearchCV重罪。閱讀這篇博文樱哼,了解如何巧妙地調(diào)整參數(shù)。

最后蛆封,LightGBM排名最后唇礁。這里需要注意的一點(diǎn)是,當(dāng)使用cat_features時(shí)惨篱,它在速度和準(zhǔn)確性方面表現(xiàn)不佳。我認(rèn)為它表現(xiàn)糟糕的原因是它對分類數(shù)據(jù)使用了某種修改過的均值編碼围俘,導(dǎo)致過擬合(訓(xùn)練準(zhǔn)確率非常高——0.999砸讳,相比之下測試準(zhǔn)確率較低)。然而界牡,如果像XGBoost那樣正常使用它簿寂,它可以以比XGBoost快得多的速度實(shí)現(xiàn)類似(甚至更高)的準(zhǔn)確性(LGBM——0.785,XGBoost——0.789)宿亡。

最后常遂,我必須說這些觀察結(jié)果適用于這個(gè)特定的數(shù)據(jù)集,對于其他數(shù)據(jù)集可能有效也可能無效挽荠。然而克胳,一般來說,一個(gè)真實(shí)的情況是XGBoost比其他兩種算法更慢圈匆。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末漠另,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子跃赚,更是在濱河造成了極大的恐慌笆搓,老刑警劉巖,帶你破解...
    沈念sama閱讀 219,539評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件纬傲,死亡現(xiàn)場離奇詭異满败,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)叹括,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,594評論 3 396
  • 文/潘曉璐 我一進(jìn)店門算墨,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人领猾,你說我怎么就攤上這事米同『龋” “怎么了?”我有些...
    開封第一講書人閱讀 165,871評論 0 356
  • 文/不壞的土叔 我叫張陵面粮,是天一觀的道長少孝。 經(jīng)常有香客問我,道長熬苍,這世上最難降的妖魔是什么稍走? 我笑而不...
    開封第一講書人閱讀 58,963評論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮柴底,結(jié)果婚禮上婿脸,老公的妹妹穿的比我還像新娘。我一直安慰自己柄驻,他們只是感情好狐树,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,984評論 6 393
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著鸿脓,像睡著了一般抑钟。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上野哭,一...
    開封第一講書人閱讀 51,763評論 1 307
  • 那天在塔,我揣著相機(jī)與錄音,去河邊找鬼拨黔。 笑死蛔溃,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的篱蝇。 我是一名探鬼主播贺待,決...
    沈念sama閱讀 40,468評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼态兴!你這毒婦竟也來了狠持?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,357評論 0 276
  • 序言:老撾萬榮一對情侶失蹤瞻润,失蹤者是張志新(化名)和其女友劉穎喘垂,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體绍撞,經(jīng)...
    沈念sama閱讀 45,850評論 1 317
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡正勒,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,002評論 3 338
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了傻铣。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片章贞。...
    茶點(diǎn)故事閱讀 40,144評論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖非洲,靈堂內(nèi)的尸體忽然破棺而出鸭限,到底是詐尸還是另有隱情蜕径,我是刑警寧澤,帶...
    沈念sama閱讀 35,823評論 5 346
  • 正文 年R本政府宣布败京,位于F島的核電站兜喻,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏赡麦。R本人自食惡果不足惜朴皆,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,483評論 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望泛粹。 院中可真熱鬧遂铡,春花似錦、人聲如沸晶姊。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,026評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽帽借。三九已至珠增,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間砍艾,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,150評論 1 272
  • 我被黑心中介騙來泰國打工巍举, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留脆荷,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,415評論 3 373
  • 正文 我出身青樓懊悯,卻偏偏與公主長得像蜓谋,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個(gè)殘疾皇子炭分,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,092評論 2 355

推薦閱讀更多精彩內(nèi)容