幸存者預(yù)測(cè)咽白??聽起來(lái)是不是很有意思鸟缕;沒錯(cuò)>Э颉!更有意思的還在后面懂从;本期給大家詳細(xì)介紹如果通過(guò)隨機(jī)森林算法預(yù)測(cè)泰坦尼克號(hào)幸存者的全過(guò)程授段;工具采用R語(yǔ)言,案例來(lái)自于Kaggle番甩。
案例背景
泰坦尼克號(hào)沉船事故是世界上最著名的沉船事故之一侵贵。1912年4月15日,在她的處女航期間缘薛,泰坦尼克號(hào)撞上冰山后沉沒模燥,造成2224名乘客和機(jī)組人員中超過(guò)1502人的死亡。這一轟動(dòng)的悲劇震驚了國(guó)際社會(huì)掩宜,并導(dǎo)致更好的船舶安全法規(guī)蔫骂。
事故中導(dǎo)致死亡的一個(gè)原因是許多船員和乘客沒有足夠的救生艇。然而在被獲救群體中也有一些比較幸運(yùn)的因素牺汤;一些人群在事故中被救的幾率高于其他人辽旋,比如婦女、兒童和上層階級(jí)檐迟。
這個(gè)Case里补胚,我們需要分析和判斷出什么樣的人更容易獲救。最重要的是追迟,要利用機(jī)器學(xué)習(xí)來(lái)預(yù)測(cè)出在這場(chǎng)災(zāi)難中哪些人會(huì)最終獲救溶其;
數(shù)據(jù)樣本
數(shù)據(jù)挖掘流程
-
1 加載和檢查數(shù)據(jù)
<pre>#加載包
library('ggplot2') #可視化
library('ggthemes') # 可視化
library('scales') # 可視化
library('dplyr') # 數(shù)據(jù)處理
library('mice') # 可視化
library('randomForest') # 分類算法
包安裝好了,繼續(xù)加載數(shù)據(jù)
setwd("A:\\...")
train <- read.table("train.csv",stringsAsFactors = F,header = T,sep=",",na.strings = "")
test <- read.table("test.csv",stringsAsFactors = F,header = T,sep=",",na.strings = "")
full <- bind_rows(train, test) # 將訓(xùn)練和測(cè)試集合并
查看下數(shù)據(jù)結(jié)構(gòu):
# check data
str(full)
Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 1309 obs. of 12 variables:
$ PassengerId: int 1 2 3 4 5 6 7 8 9 10 ...
$ Survived : int 0 1 1 1 0 0 0 0 1 1 ...
$ Pclass : int 3 1 3 1 3 3 1 3 3 2 ...
$ Name : chr "Braund, Mr. Owen Harris" "Cumings, Mrs. John Bradley (Florence Briggs Thayer)" "Heikkinen, Miss. Laina" "Futrelle, Mrs. Jacques Heath (Lily May Peel)" ...
$ Sex : chr "male" "female" "female" "female" ...
$ Age : num 22 38 26 35 35 NA 54 2 27 14 ...
$ SibSp : int 1 1 0 1 0 0 0 3 0 1 ...
$ Parch : int 0 0 0 0 0 0 0 1 2 0 ...
$ Ticket : chr "A/5 21171" "PC 17599" "STON/O2. 3101282" "113803" ...
$ Fare : num 7.25 71.28 7.92 53.1 8.05 ...
$ Cabin : chr NA "C85" NA "C123" ...
$ Embarked : chr "S" "C" "S" "S" ...
以上敦间,我們知道我們處理的樣本擁有1309個(gè)觀測(cè)值瓶逃,每個(gè)觀測(cè)值含有12個(gè)變量束铭。為了更方便的理解樣本,下面列出12個(gè)變量的釋義:
Survived : 取值0和1,0表示死亡厢绝,1表示獲救
Pclass :乘客的船倉(cāng)等級(jí)
Name :乘客姓名
Sex :乘客性別
Age :乘客年齡
SibSp :船上配偶兄妹的人數(shù)
Parch :船上父母孩子的人數(shù)
Ticket :票號(hào)
Fare :票價(jià)
Cabin :乘客船艙號(hào)
Embarked :出發(fā)港口
-
2 特征工程
乘客姓名變量包含了許多信息契沫,比如性別;另外還可以用姓氏來(lái)尋找一個(gè)家庭中的諸多成員昔汉。
#從乘客名稱中提取頭銜信息
full$Title <- gsub('(., )|(\..)', '', full$Name)
# 統(tǒng)計(jì)不同頭銜中對(duì)應(yīng)不同性別的人數(shù)
table(full$Sex, full$Title)
## Capt Col Don Dona Dr Jonkheer Lady MajorMasterMiss Mlle Mme
## female 0 0 0 1 1 0 1 0 0 260 2 1
## male 1 4 1 0 7 1 0 2 61 0 0 0
## Mr Mrs Ms Rev Sir the Countess
## female 0 197 2 0 0 1
## male 757 0 0 8 1 0
容易看到 Dona, Lady, the Countess,Capt, Col, Don, Dr, Major, Rev, Sir, Jonkheer 這些頭銜出現(xiàn)的次數(shù)很少懈万;而MIle,Mme,Ms都是女性,此處懷疑工作人員筆誤將Miss寫錯(cuò)MIle,Mme,將Mrs寫成Ms靶病;于是:
# 出現(xiàn)頻次很低的頭銜統(tǒng)一替換成'Rale Title'
rare_title <- c('Dona', 'Lady', 'the Countess','Capt', 'Col', 'Don', 'Dr', 'Major', 'Rev', 'Sir', 'Jonkheer')
# 將Mlle,Ms,Mme替換成Miss,Miss,Mrs
full$Title[full$Title == 'Mlle'] <- 'Miss'
full$Title[full$Title == 'Ms'] <- 'Miss'
full$Title[full$Title == 'Mme'] <- 'Mrs'
full$Title[full$Title %in% rare_title] <- 'Rare Title'
# 輸出頭銜和性別的交叉列聯(lián)表
table(full$Sex, full$Title)
輸出結(jié)果如下:
## Master Miss Mr Mrs Rare Title
## female 0 264 0 198 4
## male 61 0 757 0 25
最后会通,我們從用戶全民中提取用戶的姓氏
# 提取姓氏
namefull$Surname <- sapply(full$Name, function(x) strsplit(x, split = '[,.]')[[1]][1])
我們已經(jīng)處理了用戶的姓名;下面我們將創(chuàng)建一個(gè)家庭規(guī)模變量娄周,以反映出用戶有多少家庭成員在船上涕侈。
# 創(chuàng)建家庭規(guī)模變量
full$Fsize <- full$SibSp + full$Parch + 1# Create a family variable
full$Family <- paste(full$Surname, full$Fsize, sep='_')
用戶的家庭規(guī)模能反映出什么呢?為了幫助大家理解家庭規(guī)模和是否被救之間有什么影響,我們用一張圖來(lái)展示:
# 用ggplot2包畫出家庭規(guī)模與用戶被救之間的關(guān)系
ggplot(full[1:891,], aes(x = Fsize, fill = factor(Survived))) +
geom_bar(stat='count', position='dodge') +
scale_x_continuous(breaks=c(1:11)) +
labs(x = 'Family Size')
容易看到孤身一人和家庭規(guī)模大于4的用戶中被救的人數(shù)偏少:家庭規(guī)模在2~4之間的用戶被救的人數(shù)偏多昆咽。于是我們?cè)俳⒁粋€(gè)表征家庭規(guī)模大小的變量FsizeD:
full$FsizeD[full$Fsize == 1] <- 'singleton'
full$FsizeD[full$Fsize < 5 & full$Fsize > 1] <- 'small'
full$FsizeD[full$Fsize > 4] <- 'large'
我們?cè)偻ㄟ^(guò)馬賽克圖展現(xiàn)不同規(guī)模用戶的獲救概率:
-
3 缺失值處理
處理樣本缺失值的方法有很多,但對(duì)于只有1300多個(gè)觀測(cè)值的小樣本而言牙甫,我們不會(huì)通過(guò)刪除含有缺失值的觀測(cè)值來(lái)處理缺失數(shù)據(jù)掷酗;我們可以用特定值(比如均值)來(lái)填補(bǔ)缺失值,也可以通過(guò)預(yù)測(cè)來(lái)填補(bǔ)缺失值窟哺;
首先我們發(fā)現(xiàn)第62和830位乘客缺失了出發(fā)港口的指標(biāo)數(shù)據(jù)(Embarked)
full[c(62, 830), 'Embarked']
## Source: local data frame [2 x 1]
##
## Embarked
## (chr)
## 1
## 2
通過(guò)觀察泻轰,62和830號(hào)乘客的票價(jià)均為80美元,而且船艙等級(jí)為一等艙且轨;可以推測(cè)用戶從哪個(gè)港口出發(fā)可能會(huì)影響不同等級(jí)船艙的票價(jià)浮声;
# 先剔除62和830號(hào)乘客的信息
embark_fare <- full %>% filter(PassengerId != 62 & PassengerId != 830)
# 通過(guò)箱線圖展示出發(fā)港口、船艙等級(jí)旋奢、票價(jià)三者關(guān)系
ggplot(embark_fare, aes(x = Embarked, y = Fare, fill = factor(Pclass))) +
geom_boxplot() +
geom_hline(aes(yintercept=80), colour='red', linetype='dashed', lwd=2) +
scale_y_continuous(labels=dollar_format())
從圖上看出泳挥,票價(jià)在80美元而且船艙等級(jí)是一等艙的乘客只有可能從C地出發(fā);因此第62和830號(hào)乘客的出發(fā)地是C至朗。
full$Embarked[c(62, 830)] <- 'C'
第1044位乘客缺失了船票價(jià)格屉符;
full[1044, ]
## Source: local data frame [1 x 18]
##
## PassengerId Survived Pclass Name Sex Age SibSp Parch
## (int) (int) (int) (chr) (chr) (dbl) (int) (int)
## 1 1044 NA 3 Storey, Mr. Thomas male 60.5 0 0
## Variables not shown: Ticket (chr), Fare (dbl), Cabin (chr), Embarked
## (chr), Title (chr), Surname (chr), Fsize (dbl), Family (chr), FsizeD
## (chr), Deck (fctr)
由于這位乘客從S港口出而且船艙等級(jí)是3級(jí),所以我們看看從S港口出發(fā)且船艙等級(jí)是3級(jí)的所有乘客的票價(jià)是如何分布的:
#通過(guò)密度圖展示不同票價(jià)的的分布趨勢(shì)
ggplot(full[full$Pclass == '3' & full$Embarked == 'S', ],
aes(x = Fare)) +
geom_density(fill = '#99d6ff', alpha=0.4) +
geom_vline(aes(xintercept=median(Fare, na.rm=T)),
colour='red', linetype='dashed', lwd=1) +
scale_x_continuous(labels=dollar_format())
上圖看出锹引,我們用票價(jià)均值(紅虛線)來(lái)替代1044號(hào)乘客缺失的票價(jià)數(shù)據(jù)是相對(duì)合理的矗钟;于是:
full$Fare[1044] <- median(full[full$Pclass == '3' & full$Embarked == 'S', ]$Fare, na.rm = TRUE)
以上我們簡(jiǎn)單處理了一些缺失數(shù)據(jù),但是整個(gè)數(shù)據(jù)集中年齡字段仍有較多缺失值存在嫌变,因?yàn)槟挲g是數(shù)值型變量吨艇;所以我們可以結(jié)合其他的變量數(shù)據(jù)通過(guò)模型來(lái)預(yù)測(cè)出缺失數(shù)據(jù)。
# 統(tǒng)計(jì)缺失數(shù)據(jù)數(shù)量
sum(is.na(full$Age))
## [1] 263
這里我們采用常用的缺失值處理包mice腾啥。
# 將一些字符型輸入變量轉(zhuǎn)化成因子類型
factor_vars <- c('PassengerId','Pclass','Sex','Embarked', 'Title','Surname','Family','FsizeD')
# 設(shè)定隨機(jī)數(shù)種子
seedset.seed(520)
#調(diào)用mice包东涡,輸入變量中剔除一些價(jià)值很低的變量
mice_mod <- mice(full[, !names(full) %in% c('PassengerId','Name','Ticket','Cabin','Family','Surname','Survived')], method='rf')
##
## iter imp variable
## 1 1 Age Deck
## 1 2 Age Deck
## 1 3 Age Deck
## 1 4 Age Deck
## 1 5 Age Deck
## 2 1 Age Deck
## 2 2 Age Deck
## 2 3 Age Deck
## 2 4 Age Deck
## 2 5 Age Deck
## 3 1 Age Deck
## 3 2 Age Deck
## 3 3 Age Deck
## 3 4 Age Deck
## 3 5 Age Deck
## 4 1 Age Deck
## 4 2 Age Deck
## 4 3 Age Deck
## 4 4 Age Deck
## 4 5 Age Deck
## 5 1 Age Deck
## 5 2 Age Deck
## 5 3 Age Deck
## 5 4 Age Deck
## 5 5 Age Deck
# 保存輸出值
mice_output <- complete(mice_mod)
下面我們比較一下預(yù)測(cè)出的age分布和原始數(shù)據(jù)中的age分布有沒有較大差異:
# 畫出年齡密度分布圖
par(mfrow=c(1,2))
hist(full$Age, freq=F, main='Age: Original Data', col='darkgreen', ylim=c(0,0.04))
hist(mice_output$Age, freq=F, main='Age: MICE Output', col='lightgreen', ylim=c(0,0.04))
非常好冯吓;預(yù)測(cè)前后的年齡分布并沒有明顯差異;下面將預(yù)測(cè)后的年齡數(shù)據(jù)替換到原始數(shù)據(jù)當(dāng)中:
full$Age <- mice_output$Age
#統(tǒng)計(jì)新數(shù)據(jù)集中的缺失值數(shù)量
sum(is.na(full$Age))
## [1] 0
既然我們補(bǔ)全了age字段的缺失值數(shù)據(jù)软啼,那么下面我們繼續(xù)利用age字段做一些特征工程桑谍。例如我們可以通過(guò)age來(lái)大致確定哪些人是孩子、哪些人是母親祸挪;孩子的age一般都是小于18的锣披;而母親這可能滿足:1.age大于18;2.至少擁有一個(gè)孩子盎咛酢雹仿;3.全名中不帶有Miss字符;4.性別是女性整以。
# 首選我們觀察下不同性別當(dāng)中年齡與是否被救之間的關(guān)系
ggplot(full[1:891,], aes(Age, fill = factor(Survived))) +
geom_histogram() +
# 性別對(duì)于預(yù)測(cè)有明顯意義胧辽,因?yàn)槲覀冾A(yù)先知道女性獲救的幾率更大(這是一個(gè)先驗(yàn)概率)
facet_grid(.~Sex)
于是我們?cè)跇颖炯行略黾右涣衏hild:
full$Child[full$Age < 18] <- 'Child'
full$Child[full$Age >= 18] <- 'Adult'
# Show counts
table(full$Mother, full$Survived)
##
## 0 1
## Adult 484 274
## Child 65 68
數(shù)據(jù)上顯示如果你是一個(gè)孩子,那么在這場(chǎng)災(zāi)難中你被救的概率約有1/2公黑;下面我們繼續(xù)創(chuàng)建新變量mother邑商,我們期待能夠在數(shù)據(jù)用印證母親被救的可能性更大這一先驗(yàn)假設(shè)。
#增加mother變量
full$Mother[full$Sex == 'female' & full$Parch > 0 & full$Age > 18 & full$Title != 'Miss'] <- 'Mother'
table(full$Mother, full$Survived)
##
## 0 1
## Mother 16 39
## Not Mother 533 303
# 將child和mother變量轉(zhuǎn)化成因子類型
full$Child <- factor(full$Child)
full$Mother <- factor(full$Mother)
到這里凡蚜,我們完成了所有的數(shù)據(jù)處理工作人断。(** 是不是覺得超級(jí)枯燥、超級(jí)繁瑣朝蜘;沒錯(cuò)恶迈!這就是數(shù)據(jù)分析的現(xiàn)實(shí),80%的工作都集中在數(shù)據(jù)處理環(huán)節(jié)谱醇!**)
-
4 模型建立
預(yù)測(cè)環(huán)節(jié)中輸入變量除了最開始數(shù)據(jù)集中包含的之外暇仲,我們還陸續(xù)添加了一些新的變量;比如child副渴、mother奈附、Fsize、FsizeD等煮剧;這里我們選用隨機(jī)森林算法(RandomForest桅狠,關(guān)于隨機(jī)森林算法)進(jìn)行分類預(yù)測(cè);
第一步轿秧,我們首先從原始樣本full中剝離出訓(xùn)練集和測(cè)試集中跌;
train <- full[1:891,]
test <- full[892:1309,]
第二步、帶入訓(xùn)練集進(jìn)行樣本訓(xùn)練:
#設(shè)立隨機(jī)數(shù)種子
set.seed(9999)
# 模型建立
rf_model <- randomForest(factor(Survived) ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked + Title + FsizeD + Child + Mother, data = train)
#展示模型誤差(包含袋外誤差菇篡、正例誤差和負(fù)例誤差)
plot(rf_model, ylim=c(0,0.36))legend('topright', colnames(rf_model$err.rate), col=1:3, fill=1:3)
黑線代表總體誤差(oob袋外誤差)漩符,保持在20%左右;
藍(lán)線代表正例誤差率(對(duì)被救概率的預(yù)測(cè))驱还,保持在30%左右嗜暴;
紅線代表負(fù)例誤差率(對(duì)死亡概率的預(yù)測(cè))凸克,保持在10%左右;
因此闷沥,容易看出萎战,該模型對(duì)負(fù)例的預(yù)測(cè)精度明顯高于正例預(yù)測(cè)精度。
下面我們通過(guò)Gini系數(shù)來(lái)了解下模型的每個(gè)輸入變量對(duì)模型的重要性程度有什么不同:
importance <- importance(rf_model)
varImportance <- data.frame(Variables = row.names(importance), Importance = round(importance[ ,'MeanDecreaseGini'],2))
rankImportance <- varImportance %>%
mutate(Rank = paste0('#',dense_rank(desc(Importance))))
# 作圖
ggplot(rankImportance, aes(x = reorder(Variables, Importance), y = Importance, fill = Importance)) +
geom_bar(stat='identity') +
geom_text(aes(x = Variables, y = 0.5, label = Rank), hjust=0, vjust=0.55, size = 4, colour = 'red') + labs(x = 'Variables') +
coord_flip()
-
5 結(jié)果預(yù)測(cè)
# 將模型帶入測(cè)試集
prediction <- predict(rf_model, test)
# 保存結(jié)果
solution <- data.frame(PassengerID = test$PassengerId, Survived = prediction)
# 輸出結(jié)果到CSV文件格式
write.csv(solution, file = 'rf_mod_Solution.csv', row.names = F) -
6 結(jié)語(yǔ)
災(zāi)難預(yù)測(cè)是Kaggle上比較熱門和基礎(chǔ)的算法競(jìng)賽題目舆逃;這篇文章主要給大家展示一整套數(shù)據(jù)挖掘流程和機(jī)器學(xué)習(xí)算法建模實(shí)例以及如何將數(shù)據(jù)結(jié)果可視化展示蚂维;當(dāng)然,如果你在該賽題中應(yīng)用本文思路路狮,提交結(jié)果可直接排名500+左右虫啥;文中的數(shù)據(jù)處理思路來(lái)自于Megan Risdal 。