首先,要明白 CART 生成算法。CART 生成算法的核心是以基尼指数(Gini Index)最小化为准则生成分类树。先理解 Gini Index:它用来衡量节点的纯度(purity),即一个节点中所包含的因变量 y 取值的差异程度。Gini Index 越小,说明 y 的取值越一致,分类效果越好;选择这样的特征作为节点,树的效率才高。
CART 算法的基本思路(递归过程):
Step 1: 选定 training data,遍历每一个特征 A;对特征 A 可取的每个值 a,按 A=a 是否成立将数据划分为两部分,并计算相应的 Gini Index。
Step 2: 在 Step 1 计算得到的所有 Gini Index 中,选择最小值对应的 A=a 作为最优特征与最优切分点,由此 training data 被分配到两个子节点中。
Step 3: 对两个子节点递归地重复以上步骤,直到满足停止条件。
R 中的 rpart package 能够实现 CART 算法。
R code:
# Raw data: 4521 rows x 17 columns; the last column is the response y.
# The file is ';'-separated, hence the explicit sep argument.
bank <- read.csv("C:/working/summer/機器學習/決策樹/bank/bank.csv",header=TRUE,sep=';')
# Separate into a training set and a validation set:
# rows 1..4000 train, rows 4001..4521 validate.
bank_train <- bank[1:4000,]
# bank_test drops column 17 (the label) so predict() only sees features.
bank_test <- bank[4001:4521,1:16]
# bank_test1 keeps all 17 columns; its last column supplies the true labels
# for the accuracy computation below.
bank_test1 <- bank[4001:4521,]
# Build the classification tree (CART) on the training split.
library(rpart)
# The formula lists all 16 predictors explicitly; method = "class" requests
# a classification (rather than regression/anova) tree.
fit <- rpart(y~age+job+marital+education+default+balance+housing+loan+contact
+day+month+duration+campaign+pdays+previous+poutcome,method="class",
data=bank_train) # method = "class" => classification tree
# Base-graphics rendering: uniform vertical spacing between levels;
# text() annotates nodes, use.n shows observation counts, all labels all nodes.
plot(fit, uniform = TRUE,main="Classification Tree for Bank")
text(fit,use.n = TRUE,all=TRUE)
#######################################################################################################
# Use the validation data to estimate out-of-sample accuracy.
result <- predict(fit, bank_test,type = "class")
# count_result() (defined in the sourced helper file, and listed later in
# this document) compares predictions against column 17 of bank_test1 and
# prints the accuracy rate.
source("C:/working/summer/機器學習/決策樹/accurate rate.r")
count_result(result,bank_test1)
#######################################################################################################
# Deal with missing values.
# na.action = na.rpart keeps observations with missing *predictor* values
# (rpart handles them via surrogate splits) and drops only observations
# with a missing *response*.
summary(bank) # columns 4, 9 and 16 contain the placeholder level "unknown"
# Recode "unknown" as NA. Vectorised per column instead of the original
# row-by-row loop over all 4521 rows; which() guards the assignment index.
for (j in c(4, 9, 16)) {
  bank[which(bank[, j] == "unknown"), j] <- NA
}
# FIX(review): the train/validation splits were taken BEFORE the NA
# recoding, so fit2 previously trained on data containing no NA at all and
# na.rpart had nothing to do. Re-derive the splits from the recoded data.
bank_train <- bank[1:4000,]
bank_test <- bank[4001:4521,1:16]
bank_test1 <- bank[4001:4521,]
fit2 <- rpart(y~.,method = "class", data=bank_train,na.action=na.rpart)
# FIX(review): the original plotted the old `fit` (with a stray empty
# argument) instead of fit2, and passed use.n/all to plot() although they
# are text() arguments.
plot(fit2, uniform = TRUE)
text(fit2, use.n = TRUE, all = TRUE)
result2 <- predict(fit2,bank_test,type="class")
count_result(result2,bank_test1)
########################################################################################################
# A third tree with explicit complexity control.
# minsplit = 40: a node must contain at least 40 observations before a
# split is attempted -- the larger minsplit is, the simpler the tree.
# cp = 0.001: minimum complexity-parameter improvement required per split.
fit3 <- rpart(y~age+job+marital+education+default+balance+housing+loan+contact+day+month+duration+campaign+
pdays+previous+poutcome,method="class",data=bank_train,na.action=na.rpart,
control=rpart.control(minsplit=40,cp=0.001)) # see ?rpart.control for details
result3 <- predict(fit3,bank_test,type="class")
count_result(result3,bank_test1)
# NOTE(review): use.n/all are text() arguments, not plot() arguments -- they
# are silently forwarded through `...` here and have no effect on plot().
plot(fit3,use.n=TRUE,all=TRUE)
count_result 函数用来计算分类的正确率:
#' Compute and print the classification accuracy rate.
#'
#' @param result    Vector (or factor) of predicted class labels.
#' @param data_test Data frame whose `label_col` column holds the true
#'                  labels, row-aligned with `result`.
#' @param label_col Column index of the true-label column. Defaults to 17
#'                  (the bank data's response column) so existing two-argument
#'                  calls keep working.
#' @return The accuracy rate (proportion of correct predictions), printed
#'         and returned.
count_result <- function(result, data_test, label_col = 17) {
  # FIX(review): the original definition was missing its closing brace,
  # carried a dead `i <- 1`, and compared element-by-element in a loop.
  n <- length(result)
  # Vectorised comparison; NA comparisons count as incorrect rather than
  # crashing the original scalar `if`.
  count_right <- sum(result == data_test[, label_col], na.rm = TRUE)
  print(count_right / n)
}
剪枝:
# Pruning demo.
library(rpart)
# Grow a deliberately large tree first: minsplit = 140 with a small
# cp = 0.001 leaves something worth pruning back.
# NOTE(review): this overwrites the `fit` object built earlier in the script.
fit <- rpart(y~age+job+marital+education+default+balance+housing+loan+contact
+day+month+duration+campaign+pdays+previous+poutcome,method="class",
data=bank_train,control=rpart.control(minsplit=140,cp=0.001)) # classification tree
plot(fit, uniform = TRUE,main="Classification Tree for Bank")
text(fit,use.n = TRUE,all=TRUE)
# A prettier rendering via rpart.plot.
library(rpart.plot)
# FIX(review): the plot titles said "Kyphosis" -- copied from the rpart
# manual's kyphosis example -- although the data are the bank records.
rpart.plot(fit, branch=1, branch.type=2, type=1, extra=102,
           shadow.col="gray", box.col="green",
           border.col="blue", split.col="red",
           split.cex=1.2, main="Bank 決策樹")
# Prune at the cp value whose cross-validated error (xerror) is minimal.
printcp(fit)
fit$cptable
fit2 <- prune(fit, cp= fit$cptable[which.min(fit$cptable[,"xerror"]),"CP"])
rpart.plot(fit2, branch=1, branch.type=2, type=1, extra=102,
           shadow.col="gray", box.col="green",
           border.col="blue", split.col="red",
           split.cex=1.2, main="Bank 決策樹 (pruned)")
剪枝前:4 层
剪枝后:3 层