Many people jokingly call the tuning process "alchemy", and honestly that is not far off: quite often the result after all your tuning is worse than the defaults! It is like playing a video game: a furious flurry of moves, and the scoreboard still reads 0:5.
Model tuning must be grounded in an understanding of the algorithm and the data; it is not something to do at random.
We will demonstrate with the well-known Pima diabetes dataset. First, create the task:
library(mlr3verse)
## Loading required package: mlr3
task <- tsk("pima")
print(task)
## <TaskClassif:pima> (768 x 9)
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
## - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
## triceps
Choose a learner and inspect the hyperparameters it supports:
learner <- lrn("classif.rpart")
learner$param_set
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0 1 Inf 0.01
## 2: keep_model ParamLgl NA NA 2 FALSE
## 3: maxcompete ParamInt 0 Inf Inf 4
## 4: maxdepth ParamInt 1 30 30 30
## 5: maxsurrogate ParamInt 0 Inf Inf 5
## 6: minbucket ParamInt 1 Inf Inf <NoDefault[3]>
## 7: minsplit ParamInt 1 Inf Inf 20
## 8: surrogatestyle ParamInt 0 1 2 0
## 9: usesurrogate ParamInt 0 2 3 2
## 10: xval ParamInt 0 Inf Inf 10 0
Here we choose to tune the complexity parameter cp and the minimum-split parameter minsplit, and define the search range for each:
search_space <- ps(
cp = p_dbl(lower = 0.001, upper = 0.1),
minsplit = p_int(lower = 1, upper = 10)
)
search_space
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0.001 0.1 Inf <NoDefault[3]>
## 2: minsplit ParamInt 1.000 10.0 10 <NoDefault[3]>
Then choose a resampling strategy and a performance measure (a few common alternatives are sketched after this chunk):
hout <- rsmp("holdout", ratio = 0.7)
measure <- msr("classif.ce")
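Other resampling strategies and measures plug in the same way. As a quick sketch, a few common alternatives from the standard rsmp()/msr() dictionaries:
rsmp("cv", folds = 5)           # 5-fold cross-validation
rsmp("bootstrap", repeats = 30) # bootstrap resampling
msr("classif.acc")              # accuracy
msr("classif.auc")              # AUC (requires predict_type = "prob" on the learner)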
There are two ways to carry out the tuning.
Method 1: train the model via TuningInstanceSingleCrit and a tuner
library(mlr3tuning)
## Loading required package: paradox
evals20 <- trm("evals", n_evals = 20) # define when to stop tuning
# bundle everything into a tuning instance
instance <- TuningInstanceSingleCrit$new(
task = task,
learner = learner,
resampling = hout,
measure = measure,
terminator = evals20,
search_space = search_space
)
instance
## <TuningInstanceSingleCrit>
## * State: Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0.001 0.1 Inf <NoDefault[3]>
## 2: minsplit ParamInt 1.000 10.0 10 <NoDefault[3]>
## * Terminator: <TerminatorEvals>
## * Terminated: FALSE
## * Archive:
## <ArchiveTuning>
## Null data.table (0 rows and 0 cols)
As for when to stop tuning, mlr3 offers five kinds of terminators (constructor sketches follow this list):
- Terminate after a given time: stop after a fixed amount of time
- Terminate after a given number of iterations: stop after a fixed number of evaluations
- Terminate after a specific performance has been reached: stop once a target performance level is reached
- Terminate when tuning does not find a better configuration for a given number of iterations: stop when no better configuration has turned up within a given number of evaluations (stagnation)
- A combination of the above in ALL or ANY fashion: combine several of the terminators above
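Their trm() constructors look roughly like this (a sketch from memory; keys and argument names may vary slightly between versions):
trm("run_time", secs = 300)            # stop after a fixed wall-clock time (seconds)
trm("evals", n_evals = 20)             # stop after a fixed number of evaluations
trm("perf_reached", level = 0.2)       # stop once the measure reaches a target level
trm("stagnation", iters = 10, threshold = 0.001) # stop when no sufficient improvement for `iters` evaluations
# combine several terminators; any = TRUE stops as soon as one of them fires
trm("combo", terminators = list(trm("evals", n_evals = 50), trm("run_time", secs = 300)), any = TRUE)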
You also need to choose a search strategy. mlr3tuning currently supports the following tuners (constructor sketches follow this list):
- Grid search
- Random search
- Generalized simulated annealing
- Non-linear optimization
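The corresponding tnr() keys are roughly the following (a sketch; exact keys and required arguments depend on your mlr3tuning version, and the nloptr algorithm name shown is only an example):
tnr("grid_search", resolution = 5)            # grid search
tnr("random_search")                          # random search
tnr("gensa")                                  # generalized simulated annealing (GenSA)
tnr("nloptr", algorithm = "NLOPT_LN_BOBYQA")  # non-linear optimization via nloptr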
# here we choose grid search
tuner <- tnr("grid_search", resolution = 5) # grid search
Now we can run the tuning. We set the grid-search resolution to 5 and are tuning 2 hyperparameters, so in theory there are 5 * 5 = 25 combinations; but because the terminator above was set to n_evals = 20, the search actually stops after 20 configurations have been evaluated! (A quick way to inspect the grid is sketched below.)
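If you want to see the candidate grid yourself, paradox (loaded along with mlr3tuning) can enumerate it; a small sketch:
design <- generate_design_grid(search_space, resolution = 5)
nrow(design$data) # 25 candidate combinations in total
head(design$data)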
#lgr::get_logger("mlr3")$set_threshold("warn")
#lgr::get_logger("bbotk")$set_threshold("warn") # 減少屏幕打印內(nèi)容
tuner$optimize(instance)
## INFO [20:51:28.312] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]'
## INFO [20:51:28.331] [bbotk] Evaluating 1 configuration(s)
## INFO [20:51:29.309] [bbotk] Finished optimizing after 20 evaluation(s)
## INFO [20:51:29.310] [bbotk] Result:
## INFO [20:51:29.310] [bbotk] cp minsplit learner_param_vals x_domain classif.ce
## INFO [20:51:29.310] [bbotk] 0.02575 3 <list[3]> <list[2]> 0.2130435
## cp minsplit learner_param_vals x_domain classif.ce
## 1: 0.02575 3 <list[3]> <list[2]> 0.2130435
Inspect the tuned hyperparameters:
instance$result_learner_param_vals
## $xval
## [1] 0
##
## $cp
## [1] 0.02575
##
## $minsplit
## [1] 3
Inspect the model performance:
instance$result_y
## classif.ce
## 0.2130435
Inspect the result of every evaluation; there are exactly 20:
instance$archive
## <ArchiveTuning>
## cp minsplit classif.ce runtime_learners timestamp batch_nr
## 1: 0.026 3 0.21 0.02 2022-02-27 20:51:28 1
## 2: 0.075 8 0.21 0.00 2022-02-27 20:51:28 2
## 3: 0.050 5 0.21 0.00 2022-02-27 20:51:28 3
## 4: 0.001 1 0.30 0.00 2022-02-27 20:51:28 4
## 5: 0.100 3 0.21 0.02 2022-02-27 20:51:28 5
## 6: 0.026 5 0.21 0.02 2022-02-27 20:51:28 6
## 7: 0.100 8 0.21 0.01 2022-02-27 20:51:28 7
## 8: 0.001 8 0.27 0.00 2022-02-27 20:51:28 8
## 9: 0.001 5 0.28 0.00 2022-02-27 20:51:28 9
## 10: 0.100 5 0.21 0.02 2022-02-27 20:51:28 10
## 11: 0.075 10 0.21 0.00 2022-02-27 20:51:28 11
## 12: 0.050 10 0.21 0.01 2022-02-27 20:51:28 12
## 13: 0.075 5 0.21 0.00 2022-02-27 20:51:28 13
## 14: 0.050 8 0.21 0.01 2022-02-27 20:51:29 14
## 15: 0.001 10 0.26 0.00 2022-02-27 20:51:29 15
## 16: 0.050 3 0.21 0.00 2022-02-27 20:51:29 16
## 17: 0.050 1 0.21 0.02 2022-02-27 20:51:29 17
## 18: 0.100 10 0.21 0.00 2022-02-27 20:51:29 18
## 19: 0.075 1 0.21 0.01 2022-02-27 20:51:29 19
## 20: 0.026 1 0.21 0.00 2022-02-27 20:51:29 20
## warnings errors resample_result
## 1: 0 0 <ResampleResult[22]>
## 2: 0 0 <ResampleResult[22]>
## 3: 0 0 <ResampleResult[22]>
## 4: 0 0 <ResampleResult[22]>
## 5: 0 0 <ResampleResult[22]>
## 6: 0 0 <ResampleResult[22]>
## 7: 0 0 <ResampleResult[22]>
## 8: 0 0 <ResampleResult[22]>
## 9: 0 0 <ResampleResult[22]>
## 10: 0 0 <ResampleResult[22]>
## 11: 0 0 <ResampleResult[22]>
## 12: 0 0 <ResampleResult[22]>
## 13: 0 0 <ResampleResult[22]>
## 14: 0 0 <ResampleResult[22]>
## 15: 0 0 <ResampleResult[22]>
## 16: 0 0 <ResampleResult[22]>
## 17: 0 0 <ResampleResult[22]>
## 18: 0 0 <ResampleResult[22]>
## 19: 0 0 <ResampleResult[22]>
## 20: 0 0 <ResampleResult[22]>
Now apply the tuned hyperparameters to the learner and retrain it on the data:
learner$param_set$values <- instance$result_learner_param_vals
learner$train(task)
The trained model can now be used for prediction via learner$predict().
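A minimal sketch of that last step (everything here uses the standard Prediction API; classif.acc is just one example measure):
prediction <- learner$predict(task)   # predict back on the training task
prediction$confusion                  # confusion matrix
prediction$score(msr("classif.acc"))  # score with any classification measure
# for genuinely new data, predict_newdata() accepts a data.frame with the same feature columns
# learner$predict_newdata(newdata)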
The steps above are somewhat verbose and, compared with tidymodels, not as concise or easy to follow; when I first started learning I could never remember them. Fortunately, a later version added a shorthand:
instance <- tune(
task = task,
learner = learner,
resampling = hout,
measure = measure,
search_space = search_space,
method = "grid_search",
resolution = 5,
term_evals = 25
)
## INFO [20:51:29.402] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=25, k=0]'
## INFO [20:51:29.403] [bbotk] Evaluating 1 configuration(s)
## INFO [20:51:30.534] [bbotk] Finished optimizing after 25 evaluation(s)
## INFO [20:51:30.534] [bbotk] Result:
## INFO [20:51:30.535] [bbotk] cp minsplit learner_param_vals x_domain classif.ce
## INFO [20:51:30.535] [bbotk] 0.02575 10 <list[3]> <list[2]> 0.2347826
instance$result_learner_param_vals
## $xval
## [1] 0
##
## $cp
## [1] 0.02575
##
## $minsplit
## [1] 10
instance$result_y
## classif.ce
## 0.2347826
learner$param_set$values <- instance$result_learner_param_vals
learner$train(task)
mlr3 also supports setting several performance measures at the same time:
measures <- msrs(c("classif.ce","time_train")) # set multiple evaluation measures
evals20 <- trm("evals", n_evals = 20)
instance <- TuningInstanceMultiCrit$new(
task = task,
learner = learner,
resampling = hout,
measures = measures,
search_space = search_space,
terminator = evals20
)
tuner$optimize(instance)
## INFO [20:51:30.595] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]'
## INFO [20:51:30.597] [bbotk] Evaluating 1 configuration(s)
## INFO [20:51:30.605] [mlr3] Running benchmark with 1 resampling iterations
## INFO [20:51:30.608] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [20:51:30.620] [mlr3] Finished benchmark
## INFO [20:51:30.642] [bbotk] Result of batch 1:
## INFO [20:51:30.643] [bbotk] cp minsplit classif.ce time_train warnings errors runtime_learners
## INFO [20:51:30.643] [bbotk] 0.0505 1 0.2347826 0 0 0 0.02
## cp minsplit learner_param_vals x_domain classif.ce time_train
## 1: 0.05050 1 <list[3]> <list[2]> 0.2347826 0
## 2: 0.07525 1 <list[3]> <list[2]> 0.2347826 0
## 3: 0.07525 10 <list[3]> <list[2]> 0.2347826 0
## 4: 0.10000 8 <list[3]> <list[2]> 0.2347826 0
## 5: 0.02575 3 <list[3]> <list[2]> 0.2347826 0
## 6: 0.07525 8 <list[3]> <list[2]> 0.2347826 0
## 7: 0.10000 3 <list[3]> <list[2]> 0.2347826 0
## 8: 0.10000 5 <list[3]> <list[2]> 0.2347826 0
## 9: 0.02575 5 <list[3]> <list[2]> 0.2347826 0
## 10: 0.07525 5 <list[3]> <list[2]> 0.2347826 0
## 11: 0.05050 8 <list[3]> <list[2]> 0.2347826 0
## 12: 0.05050 3 <list[3]> <list[2]> 0.2347826 0
## 13: 0.07525 3 <list[3]> <list[2]> 0.2347826 0
## 14: 0.05050 5 <list[3]> <list[2]> 0.2347826 0
## 15: 0.02575 1 <list[3]> <list[2]> 0.2347826 0
Inspect the results:
instance$result_learner_param_vals
## [[1]]
## [[1]]$xval
## [1] 0
##
## [[1]]$cp
## [1] 0.0505
##
## [[1]]$minsplit
## [1] 1
##
##
## [[2]]
## [[2]]$xval
## [1] 0
##
## [[2]]$cp
## [1] 0.07525
##
## [[2]]$minsplit
## [1] 1
##
##
## [[3]]
## [[3]]$xval
## [1] 0
##
## [[3]]$cp
## [1] 0.07525
##
## [[3]]$minsplit
## [1] 10
##
##
## [[4]]
## [[4]]$xval
## [1] 0
##
## [[4]]$cp
## [1] 0.1
##
## [[4]]$minsplit
## [1] 8
##
##
## [[5]]
## [[5]]$xval
## [1] 0
##
## [[5]]$cp
## [1] 0.02575
##
## [[5]]$minsplit
## [1] 3
##
##
## [[6]]
## [[6]]$xval
## [1] 0
##
## [[6]]$cp
## [1] 0.07525
##
## [[6]]$minsplit
## [1] 8
##
##
## [[7]]
## [[7]]$xval
## [1] 0
##
## [[7]]$cp
## [1] 0.1
##
## [[7]]$minsplit
## [1] 3
##
##
## [[8]]
## [[8]]$xval
## [1] 0
##
## [[8]]$cp
## [1] 0.1
##
## [[8]]$minsplit
## [1] 5
##
##
## [[9]]
## [[9]]$xval
## [1] 0
##
## [[9]]$cp
## [1] 0.02575
##
## [[9]]$minsplit
## [1] 5
##
##
## [[10]]
## [[10]]$xval
## [1] 0
##
## [[10]]$cp
## [1] 0.07525
##
## [[10]]$minsplit
## [1] 5
##
##
## [[11]]
## [[11]]$xval
## [1] 0
##
## [[11]]$cp
## [1] 0.0505
##
## [[11]]$minsplit
## [1] 8
##
##
## [[12]]
## [[12]]$xval
## [1] 0
##
## [[12]]$cp
## [1] 0.0505
##
## [[12]]$minsplit
## [1] 3
##
##
## [[13]]
## [[13]]$xval
## [1] 0
##
## [[13]]$cp
## [1] 0.07525
##
## [[13]]$minsplit
## [1] 3
##
##
## [[14]]
## [[14]]$xval
## [1] 0
##
## [[14]]$cp
## [1] 0.0505
##
## [[14]]$minsplit
## [1] 5
##
##
## [[15]]
## [[15]]$xval
## [1] 0
##
## [[15]]$cp
## [1] 0.02575
##
## [[15]]$minsplit
## [1] 1
instance$result_y # measure values for each Pareto-optimal configuration
That covers the first method; now for the second.
Method 2: train the model via AutoTuner
This approach combines tuning the hyperparameters and applying the tuned values to the model in a single workflow, but all of the required pieces still need to be set up in advance.
task <- tsk("pima") # 創(chuàng)建任務(wù)
leanrer <- lrn("classif.rpart") # 選擇學(xué)習(xí)器
search_space <- ps(
cp = p_dbl(0.001, 0.1),
minsplit = p_int(1,10)
) # 設(shè)定搜索范圍
terminator <- trm("evals", n_evals = 10) # 設(shè)定停止標(biāo)志
tuner <- tnr("random_search") # 選擇搜索方法
resampling <- rsmp("holdout") # 選擇重抽樣方法
measure <- msr("classif.acc") # 選擇評(píng)價(jià)指標(biāo)
# 訓(xùn)練
at <- AutoTuner$new(
learner = learner,
resampling = resampling,
search_space = search_space,
measure = measure,
tuner = tuner,
terminator = terminator
)
Automatically select the best hyperparameters and apply them to the data:
at$train(task)
## INFO [20:51:31.873] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]'
## INFO [20:51:32.332] [bbotk] 0.02278977 3 <list[3]> <list[2]> 0.7695312
at$predict(task)
## <PredictionClassif> for 768 observations:
## row_ids truth response
## 1 pos pos
## 2 neg neg
## 3 pos neg
## ---
## 766 neg neg
## 767 pos neg
## 768 neg neg
This approach also has a shorthand:
auto_learner <- auto_tuner(
learner = learner,
resampling = resampling,
measure = measure,
search_space = search_space,
method = "random_search",
term_evals = 10
)
auto_learner$train(task)
## INFO [20:51:32.407] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]'
## INFO [20:51:32.858] [bbotk] Finished optimizing after 10 evaluation(s)
## INFO [20:51:32.859] [bbotk] Result:
## INFO [20:51:32.859] [bbotk] cp minsplit learner_param_vals x_domain classif.acc
## INFO [20:51:32.859] [bbotk] 0.02922122 8 <list[3]> <list[2]> 0.7539062
auto_learner$predict(task)
## <PredictionClassif> for 768 observations:
## row_ids truth response
## 1 pos pos
## 2 neg neg
## 3 pos neg
## ---
## 766 neg neg
## 767 pos neg
## 768 neg neg
Ways to specify hyperparameters
Setting the ranges separately each time can feel a bit clumsy and tedious, so mlr3 also offers a way to specify the hyperparameters to tune at the moment you create the learner.
# set hyperparameter ranges when choosing the learner
learner <- lrn("classif.svm")
learner$param_set$values$kernel <- "polynomial"
learner$param_set$values$degree <- to_tune(lower = 1, upper = 3)
print(learner$param_set$search_space())
## <ParamSet>
## id class lower upper nlevels default value
## 1: degree ParamInt 1 3 3 <NoDefault[3]>
This approach has a drawback of its own, though: it requires you to know the algorithm well enough to remember every hyperparameter and how it is spelled in mlr3. That is clearly a bit much to ask, so I still recommend the first approach of defining the search space separately each time; if you forget a name you can always look up the available hyperparameters.
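For completeness, the same idea applied to the earlier rpart example would look roughly like this (a sketch; to_tune() tokens are passed directly when constructing the learner):
learner_rpart <- lrn("classif.rpart",
  cp       = to_tune(0.001, 0.1),
  minsplit = to_tune(1, 10)
)
learner_rpart$param_set$search_space()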
Parameter dependencies
Some hyperparameters only take effect under certain conditions. For a support vector machine (SVM), for example, the degree parameter is only used when kernel is polynomial. This kind of dependency can also be declared in mlr3:
library(data.table)
search_space = ps(
cost = p_dbl(-1, 1, trafo = function(x) 10^x), # values can be transformed via trafo
kernel = p_fct(c("polynomial", "radial")),
degree = p_int(1, 3, depends = kernel == "polynomial") # declare the parameter dependency
)
rbindlist(generate_design_grid(search_space, 3)$transpose(), fill = TRUE)
## cost kernel degree
## 1: 0.1 polynomial 1
## 2: 0.1 polynomial 2
## 3: 0.1 polynomial 3
## 4: 0.1 radial NA
## 5: 1.0 polynomial 1
## 6: 1.0 polynomial 2
## 7: 1.0 polynomial 3
## 8: 1.0 radial NA
## 9: 10.0 polynomial 1
## 10: 10.0 polynomial 2
## 11: 10.0 polynomial 3
## 12: 10.0 radial NA
With this in place, the downstream steps will run without errors; the dependency is handled automatically.
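As a final sketch, this dependency-aware search space can be dropped straight into the tune() shorthand. Assumptions here: the e1071 package behind classif.svm is installed, type is fixed to "C-classification" so that cost applies, and the sonar task is used because classif.svm cannot handle the missing values in pima:
svm_learner <- lrn("classif.svm", type = "C-classification")
instance_svm <- tune(
  task = tsk("sonar"),
  learner = svm_learner,
  resampling = rsmp("holdout", ratio = 0.7),
  measure = msr("classif.ce"),
  search_space = search_space,
  method = "grid_search",
  resolution = 3,
  term_evals = 12 # 3 cost values * (3 polynomial degrees + 1 radial) = 12 combinations
)
instance_svm$result_learner_param_vals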