Tidymodels: tidy machine learning in R
在處理數(shù)據(jù)時(shí)蛤高,有簡(jiǎn)潔的工具包,tidyverse應(yīng)運(yùn)而生,極大地簡(jiǎn)化數(shù)據(jù)處理流程钓株,讓數(shù)據(jù)處理變得簡(jiǎn)潔,清晰陌僵。
但是在處理完數(shù)據(jù)后轴合,需要對(duì)數(shù)據(jù)進(jìn)行建模分析,預(yù)測(cè)與擬合碗短,這個(gè)過(guò)程隨著模型的不同而變的多元化受葛,尤其是機(jī)器學(xué)習(xí)應(yīng)用。加速了模型構(gòu)建的流程化與簡(jiǎn)潔化豪椿。
Caret的出現(xiàn)奔坟,讓此項(xiàng)工作變得簡(jiǎn)潔明了。但是還是有些缺點(diǎn)搭盾。
上圖基于Wickham和Grolemund撰寫(xiě)的《 R for Data Science》一書(shū)咳秉。
本文中的版本詳細(xì)解釋了tidymodels每個(gè)程序包涵蓋的步驟。在模型構(gòu)建及預(yù)測(cè)過(guò)程中鸯隅,tidymodels的流暢與簡(jiǎn)潔澜建,讓你體驗(yàn)縱享絲滑般的感受。
在模型構(gòu)建過(guò)程中蝌以,需要涉及的數(shù)據(jù)預(yù)處理及模型參數(shù)調(diào)整炕舵,這些步驟都含括在以下程序包中:
- rsample - 數(shù)據(jù)分離重采樣
- recipes - 數(shù)據(jù)轉(zhuǎn)換處理
- parnip - 模型構(gòu)建框架
- yardstick - 模型效果評(píng)估
下圖說(shuō)明了tidymodels建模步驟:
數(shù)據(jù)iris
下面我們將通過(guò)iris數(shù)據(jù)來(lái)舉例說(shuō)明。
首先跟畅,我們將iris數(shù)據(jù)分成訓(xùn)練和測(cè)試集咽筋,通過(guò)initial_split()函數(shù)實(shí)現(xiàn)數(shù)據(jù)拆分,可以根據(jù)prop參數(shù)徊件,指定分離比例奸攻。分離數(shù)據(jù)后,我們可以通過(guò)training() 與testing() 函數(shù)虱痕,獲取訓(xùn)練集和測(cè)試集的數(shù)據(jù)睹耐。
library(tidymodels)
# split
iris_split <- initial_split(iris, prop = 0.6)
iris_split
# get training data
iris_split %>%
training() %>%
glimpse()
## Observations: 90
## Variables: 5
## $ Sepal.Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.9, 5.4, 4…
## $ Sepal.Width <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 3.1, 3.7, 3…
## $ Petal.Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.5, 1.5, 1…
## $ Petal.Width <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.1, 0.2, 0…
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, set…
數(shù)據(jù)預(yù)處理
recipes
包提供了多種函數(shù),可以對(duì)數(shù)據(jù)進(jìn)行預(yù)處理部翘。包括數(shù)據(jù)的標(biāo)準(zhǔn)化硝训,數(shù)據(jù)的相關(guān)性重復(fù),變成亞分類變量等。
- step_corr() - 消除相關(guān)性較高的影響
- step_center() - 以0為中心標(biāo)準(zhǔn)化
- step_scale() - 以1為中心標(biāo)準(zhǔn)化
recipe還有一個(gè)好處就是窖梁,在指定數(shù)據(jù)處理時(shí)赘风,可以用all_predictors()
來(lái)指定對(duì)所有協(xié)變量進(jìn)行歸一化。然后all_outcomes()
可以指定y窄绒。
可以打印recipe
的詳細(xì)信息贝次。里面記錄了驟刪除了Petal.Length變量。
在處理完train數(shù)據(jù)后彰导,test數(shù)據(jù)可以用bake函數(shù)進(jìn)行相似的處理蛔翅。然后輸出為dataframe。train數(shù)據(jù)從iris_recipe
輸出為dataframe位谋,可以用juice()
山析。
# train data
iris_recipe <- training(iris_split) %>%
recipe(Species ~.) %>%
step_corr(all_predictors()) %>%
step_center(all_predictors(), -all_outcomes()) %>%
step_scale(all_predictors(), -all_outcomes()) %>%
prep()
iris_recipe
## Data Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 4
##
## Training data contained 90 data points and no missing data.
##
## Operations:
##
## Correlation filter removed Petal.Length [trained]
## Centering for Sepal.Length, Sepal.Width, Petal.Width [trained]
## Scaling for Sepal.Length, Sepal.Width, Petal.Width [trained]
# test data
iris_testing <- iris_recipe %>%
bake(testing(iris_split))
glimpse(iris_testing)
## Observations: 60
## Variables: 4
## $ Sepal.Length <dbl> -1.597601746, -1.138960096, 0.007644027, -0.7949788…
## $ Sepal.Width <dbl> -0.41010139, 0.71517681, 2.06551064, 1.61539936, 0.…
## $ Petal.Width <dbl> -1.2085003, -1.2085003, -1.2085003, -1.0796318, -1.…
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, set…
數(shù)據(jù)建模
在R里面,有很多關(guān)于機(jī)器學(xué)習(xí)的包掏父,ranger
笋轨,randomForest
都有針對(duì)各自包的定義的參數(shù)及說(shuō)明,很不方便赊淑,沒(méi)有統(tǒng)一標(biāo)準(zhǔn)爵政。
tidymodels的出現(xiàn),將這些機(jī)器學(xué)習(xí)的包整合到一在接口陶缺,而不是重新開(kāi)發(fā)機(jī)器學(xué)習(xí)的包钾挟。更準(zhǔn)確的說(shuō),tidymodels提供了一組用于定義模型的函數(shù)和參數(shù)饱岸。然后根據(jù)請(qǐng)求的建模包對(duì)模型進(jìn)行擬合掺出。
現(xiàn)在我們準(zhǔn)備根據(jù)我們的數(shù)據(jù),建一個(gè)隨機(jī)森林模型苫费。rand_forest()
函數(shù)來(lái)定義汤锨,我們的模型然后mode參數(shù)定義分類還是回歸問(wèn)題。mode = "classification"
因?yàn)楸狙芯渴欠诸悊?wèn)題百框。trees可以設(shè)定節(jié)點(diǎn)的數(shù)闲礼。然后set_engine()
很重要,可以指定我們運(yùn)行的模型的引擎铐维,可以是glm柬泽、rf等。然后用fit()
函數(shù)方椎,加載我們要擬合的數(shù)據(jù)聂抢。
# ranger
iris_ranger <- rand_forest(trees = 100, mode = "classification") %>%
set_engine("ranger") %>%
fit(Species ~ ., data = iris_training)
# randomForest
iris_rf <- rand_forest(trees = 100, mode = "classification") %>%
set_engine("randomForest") %>%
fit(Species ~ ., data = iris_training)
總的來(lái)說(shuō)钧嘶,模型構(gòu)建的步驟分為三部棠众,選定模型, set_engine 然后 fit數(shù)據(jù)。流水線式操作闸拿。
預(yù)測(cè)
針對(duì)arsnip的predict()函數(shù)空盼,可以返回tibble數(shù)據(jù)格式。默認(rèn)情況下新荤,預(yù)測(cè)變量稱為.pred_class揽趾。在示例中,test的數(shù)據(jù)是bake以后的--數(shù)據(jù)預(yù)處理后的testing data苛骨。然后我們將其合并入test數(shù)據(jù)集中篱瞎。
predict(iris_ranger, iris_testing)
iris_ranger %>%
predict(iris_testing) %>%
bind_cols(iris_testing)
iris_ranger
## Observations: 60
## Variables: 5
## $ .pred_class <fct> setosa, setosa, setosa, setosa, setosa, setosa, set…
## $ Sepal.Length <dbl> -1.597601746, -1.138960096, 0.007644027, -0.7949788…
## $ Sepal.Width <dbl> -0.41010139, 0.71517681, 2.06551064, 1.61539936, 0.…
## $ Petal.Width <dbl> -1.2085003, -1.2085003, -1.2085003, -1.0796318, -1.…
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, set…
iris_ranger %>%
predict(iris_testing, type = "prob") %>%
glimpse()
## Observations: 60
## Variables: 3
## $ .pred_setosa <dbl> 0.677480159, 0.978293651, 0.783250000, 0.983972…
## $ .pred_versicolor <dbl> 0.295507937, 0.011706349, 0.150833333, 0.001111…
## $ .pred_virginica <dbl> 0.02701190, 0.01000000, 0.06591667, 0.01491667,…
該模型預(yù)測(cè)的結(jié)果為分類變量,當(dāng)然有時(shí)候會(huì)根據(jù)需要痒芝,預(yù)測(cè)每個(gè)類別的概率俐筋,所以可以通過(guò)predict函數(shù)中的 type參數(shù)來(lái)輸出為概率。
模型評(píng)估
使用metrics()函數(shù)來(lái)衡量模型的性能严衬。它將自動(dòng)選擇適合給定模型類型的指標(biāo)澄者。
該函數(shù)需要一個(gè)包含實(shí)際結(jié)果(真相)和模型預(yù)測(cè)值(估計(jì)值)的tibble數(shù)據(jù)。
iris_ranger %>%
predict(iris_testing) %>%
bind_cols(iris_testing) %>%
metrics(truth = Species, estimate = .pred_class)
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy multiclass 0.917
## 2 kap multiclass 0.874
iris_rf %>%
predict(iris_testing) %>%
bind_cols(iris_testing) %>%
metrics(truth = Species, estimate = .pred_class)
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy multiclass 0.883
## 2 kap multiclass 0.824
繪制分類結(jié)果的圖
iris_probs%>%
gain_curve(Species, .pred_setosa:.pred_virginica) %>%
autoplot()
iris_probs%>%
roc_curve(Species, .pred_setosa:.pred_virginica) %>%
autoplot()