支持向量機(jī)Support vector machine,SVM是一種有監(jiān)督的機(jī)器學(xué)習(xí)算法竖般,可用于分類或者回歸硼补。本次筆記以分類任務(wù)為例主要學(xué)習(xí)。
1拇勃、簡(jiǎn)單理解
-
假設(shè)對(duì)于特征空間的樣本坐標(biāo)分四苇,以樣本標(biāo)簽為目的,存在一個(gè)超平面方咆,將樣本分為兩大類月腋,從而最大化區(qū)分兩類樣本。該超平面即為決策邊界,而間隔就是指訓(xùn)練數(shù)據(jù)中最接近決策邊界的樣本點(diǎn)與決策邊界之間的距離罗售。
SVM過程:以間隔最大化為目的辜窑,讓決策邊界盡可能遠(yuǎn)離樣本。然后根據(jù)這個(gè)決策邊界進(jìn)行樣本分類預(yù)測(cè)寨躁。
根據(jù)樣本數(shù)據(jù)的復(fù)雜程度穆碎,SVM建模的過程也有所不同,具體可分為如下三類
1.1 線性可分--硬間隔
hard margin classifier职恳,HML:此類樣本數(shù)據(jù)是最簡(jiǎn)單的情況所禀,即數(shù)據(jù)集樣本可以明顯地使用線性邊界區(qū)分開》徘眨可參考下圖色徘,分為3個(gè)步驟
- step1:為每個(gè)類別的分布情況繪制出外輪廓多邊形;
- step2:找出連接兩個(gè)輪廓最近的兩個(gè)數(shù)據(jù)點(diǎn)操禀,連接褂策;
- step3:繪制出該連線的垂直平分線,即為最優(yōu)的決策邊界颓屑。而連線長度的一半即為間隔(M)
HML的特點(diǎn)是不允許有樣本點(diǎn)位于間隔區(qū)內(nèi)斤寂,即必須干凈的劃分;即不能有樣本距決策邊界的距離比間隔的長度還短揪惦,甚至在決策邊界的另一端(誤分類)
1.2 線性不可完全分--軟間隔
-
當(dāng)兩個(gè)類別的邊界不是很明顯遍搞,或者存在離群點(diǎn)時(shí),如嚴(yán)格按照HMC器腋,會(huì)使模型過擬合溪猿,泛化能力降低。
- soft margin classifier纫塌,SML通過設(shè)置超參數(shù) C表示允許間隔內(nèi)存在樣本的數(shù)目诊县,優(yōu)化決策邊界,提高模型的泛化能力措左。如上所示翎冲,左圖為C=0的硬間隔方法;右圖為C取最大值的分類結(jié)果媳荒;
- 針對(duì)當(dāng)前數(shù)據(jù)集,可交叉驗(yàn)證不同C取值下的模型性能驹饺,從而確定最佳的指標(biāo)钳枕。
1.3 線性不可分-核技巧
上述例子數(shù)據(jù)集的決策平面都是線性的。當(dāng)決策邊界是一個(gè)曲線等復(fù)雜模式時(shí)赏壹,使用上述的方法無法得到滿意的分類效果鱼炒;
-
針對(duì)此類情況,支持向量機(jī)可通過核方法(kernel method)尋找到合適的非線性決策邊界蝌借∥羟疲可分為兩個(gè)步驟:
(1)將線性不可分的數(shù)據(jù)(n個(gè)特征向量指蚁,n維)增加一個(gè)維度(enlarged kernel-induced feature space),從而成為線性可分?jǐn)?shù)據(jù)(n+1)維自晰;
(2)在n+1 維的空間里確定合適的決策邊界(線性)凝化;然后投射到原來n維空間中,即得到我們真正需要的決策平面(非線性) 常見的核方法有:Radial basis function(徑向基函數(shù))酬荞、d-th degree polynomial搓劫、Hyperbolic tangent等,但通常建議使用徑向基函數(shù)(extremely flexible)試一試混巧。該方法的超參數(shù)枪向,除了
C
值外,還有一個(gè)sigma咧党,在下面代碼實(shí)操中會(huì)介紹到秘蛔。
2、代碼實(shí)操
(1)示例數(shù)據(jù):預(yù)測(cè)員工是否離職
library(modeldata)
data(attrition)
# initial dimension
dim(attrition)
## [1] 1470 31
library(dplyr)
df <- attrition %>%
mutate_if(is.ordered, factor, ordered = FALSE)
# Create training (70%) and test (30%) sets
set.seed(123) # for reproducibility
library(rsample)
churn_split <- initial_split(df, prop = 0.7, strata = "Attrition")
churn_train <- training(churn_split)
churn_test <- testing(churn_split)
(2)caret包建模
- 如上傍衡,我們使用徑向基函數(shù)的核方法尋找非線性決策邊界深员,從而建立支持向量機(jī)模型。
- 對(duì)于超參數(shù)
σ
,會(huì)自動(dòng)根據(jù)樣本數(shù)據(jù)尋找最合適的值聪舒;對(duì)超參數(shù)C
辨液,可以通過交叉驗(yàn)證選擇,一般備選方案為2的指數(shù)
系列值(2e-2, 2e-1,2e0,2e1,2e2...)
library(caret)
set.seed(1111) # for reproducibility
# Control params for SVM
ctrl <- trainControl(
method = "cv",
number = 10,
classProbs = TRUE, #表示返回分類概率箱残,而不是直接分類標(biāo)簽結(jié)果
summaryFunction = twoClassSummary # also needed for AUC/ROC
)
churn_svm <- train(
Attrition ~ .,
data = churn_train,
method = "svmRadial",
preProcess = c("center", "scale"),
trControl = ctrl,
metric = "ROC", # area under ROC curve (AUC)
tuneLength = 10) #遍歷C的10次取值滔迈,即從2的-2次方到2的7次方
#如下 C取4時(shí),模型最優(yōu)
churn_svm$results %>% arrange(desc(ROC)) %>% head(1)
# sigma C ROC Sens Spec ROCSD SensSD SpecSD
# 1 0.009522278 4 0.8234039 0.9791767 0.2738971 0.07462533 0.02019714 0.08679811
# Plot results
ggplot(churn_svm)
(3)測(cè)試集驗(yàn)證
pred = predict(churn_svm, churn_test)
table(pred)
# pred
# No Yes
# 415 27
table(churn_test$Attrition)
# No Yes
# 370 72
caret::confusionMatrix(pred, churn_test$Attrition, positive="Yes")
# Accuracy : 0.871
# 95% CI : (0.8362, 0.9008)
# No Information Rate : 0.8371
# P-Value [Acc > NIR] : 0.02819
#
# Kappa : 0.3681
#
# Mcnemar's Test P-Value : 5.611e-09
#
# Sensitivity : 0.29167
# Specificity : 0.98378
# Pos Pred Value : 0.77778
# Neg Pred Value : 0.87711
# Prevalence : 0.16290
# Detection Rate : 0.04751
# Detection Prevalence : 0.06109
# Balanced Accuracy : 0.63773
#
# 'Positive' Class : Yes
(4)衡量特征重要性
- SVM算法本身不提供有衡量特征重要性的計(jì)算方法被辑;
- 可使用
vip
包提供的permutation test置換檢驗(yàn)的方法燎悍,隨機(jī)調(diào)整某一列(特征)值的順序,觀察預(yù)測(cè)準(zhǔn)確率是否明顯下降盼理,從而判斷特征變量的重要性谈山。
library(vip)
prob_yes <- function(object, newdata) {
predict(object, newdata = newdata, type = "prob")[, "Yes"]
}
# Variable importance plot
set.seed(2827) # for reproducibility
vip(churn_svm, method = "permute", nsim = 5, train = churn_train,
target = "Attrition", metric = "auc", reference_class = "Yes",
pred_wrapper = prob_yes)
-
pdp
包觀察具體某一個(gè)特征變量對(duì)于預(yù)測(cè)結(jié)果的影響
library(pdp)
features <- c("OverTime", "JobRole")
pdps <- lapply(features, function(x) {
partial(churn_svm, pred.var = x, which.class = 2,
prob = TRUE, plot = TRUE, plot.engine = "ggplot2") +
coord_flip()
})
grid.arrange(grobs = pdps, nrow = 1)