In multiple linear regression, more features are not always better: choosing a small, well-suited set of features both guards against overfitting and makes the model easier to interpret. This post covers three ways to select features: best subset selection, forward/backward stepwise selection, and cross-validation.
Best Subset Selection
The idea is simple: try every possible combination of features, fit a model for each, and pick the best. The basic procedure:
- For p features, for each k = 1 to p —
- choose every possible subset of k of the p features, fit the resulting C(p,k) models, and keep the best one (smallest RSS or largest R2);
- from the p per-size winners, pick the overall best model (by cross-validation error, Cp, BIC, Adjusted R2, or a similar criterion).
The advantage is clear: every possibility is tried, so the chosen model is guaranteed to be the best available. The drawback is just as clear: the number of models grows as 2^p, so the computation blows up quickly as p increases (with the 19 predictors below, that is already 2^19 = 524,288 fits). The method is therefore only practical for small p. A minimal hand-rolled sketch of the idea follows.
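To make the enumeration concrete, here is a toy sketch (not the leaps implementation used in this post) that searches all subsets of a small candidate pool; the five-predictor pool is an arbitrary choice purely to keep the 2^p blow-up at bay:

library(ISLR)
Hitters <- na.omit(Hitters)

# Toy best subset search over a small, arbitrary candidate pool.
pool <- c("AtBat", "Hits", "Walks", "CRBI", "PutOuts")
best.by.size <- list()
for (k in seq_along(pool)) {
  combos <- combn(pool, k)  # all C(5,k) subsets of size k, one per column
  rss <- apply(combos, 2, function(vars) {
    fit <- lm(reformulate(vars, response = "Salary"), data = Hitters)
    sum(resid(fit)^2)  # RSS of this subset's model
  })
  best.by.size[[k]] <- combos[, which.min(rss)]  # best subset of each size
}
best.by.size  # per-size winners; a criterion like BIC would pick among them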
The following uses the Hitters dataset from the R package ISLR as the example, building a multiple linear model for baseball players' salaries.
> library(ISLR)
> Hitters <- na.omit(Hitters)
> dim(Hitters) # besides Salary as the response, 19 features remain
[1] 263 20
> library(leaps)
> regfit.full = regsubsets(Salary~.,Hitters,nvmax = 19) # best subset selection over models of up to 19 features
> reg.summary = summary(regfit.full) # shows which features are selected at each model size
> plot(reg.summary$rss,xlab="Number of Variables",ylab="RSS",type = "l") # the more features, the smaller the RSS
> plot(reg.summary$adjr2,xlab="Number of Variables",ylab="Adjusted RSq",type = "l")
> points(which.max(reg.summary$adjr2),reg.summary$adjr2[11],col="red",cex=2,pch=20) # Adjusted R2 peaks at 11 features
> plot(reg.summary$cp,xlab="Number of Variables",ylab="Cp",type = "l")
> points(which.min(reg.summary$cp),reg.summary$cp[10],col="red",cex=2,pch=20) # Cp is smallest at 10 features
> plot(reg.summary$bic,xlab="Number of Variables",ylab="BIC",type = "l")
> points(which.min(reg.summary$bic),reg.summary$bic[6],col="red",cex=2,pch=20) # BIC is smallest at 6 features
> plot(regfit.full,scale = "r2") # the more features, the larger the R2 -- no surprise
> plot(regfit.full,scale = "adjr2")
> plot(regfit.full,scale = "Cp")
> plot(regfit.full,scale = "bic")
Adjusted R2, Cp, and BIC are three statistics for evaluating a model; an Adjusted R2 closer to 1 indicates a better fit, while for the other two, smaller is better. (Their standard definitions are reproduced below.)
Note that the three criteria select different feature sets. Taking Adjusted R2 as the yardstick, 11 features are chosen: the plot shows that when Adjusted R2 peaks (at only a little over 0.5, so nothing spectacular), the selected features are AtBat, Hits, Walks, CAtBat, CRuns, CRBI, CWalks, LeagueN, DivisionW, PutOuts, and Assists.
The model's coefficients can be read off directly:
> coef(regfit.full,11)
(Intercept) AtBat Hits Walks CAtBat
135.7512195 -2.1277482 6.9236994 5.6202755 -0.1389914
CRuns CRBI CWalks LeagueN DivisionW
1.4553310 0.7852528 -0.8228559 43.1116152 -111.1460252
PutOuts Assists
0.2894087 0.2688277
These 11 features match the plot: with the features selected and their coefficients estimated, the model is fully specified.
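As a quick sanity check (using only the objects already created above), predictions from this model are just a matrix product of the relevant design-matrix columns with these coefficients:

# Predicted salaries from the 11-feature model: design matrix columns
# for the selected terms multiplied by their coefficients.
coefi <- coef(regfit.full, 11)
X <- model.matrix(Salary ~ ., data = Hitters)
head(X[, names(coefi)] %*% coefi)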
Stepwise Regression
The idea behind this method can be summed up as "never look back": each iteration can only build on the result of the previous one, and a choice once made is never undone. Taking forward stepwise regression as the example, the basic procedure is (a hand-rolled sketch follows the comparison below):
- For p features, starting from k = 1:
- fit the p one-feature models and keep the best one (smallest RSS or largest R2);
- given the best k-feature model, add each of the remaining p-k features in turn, fit those p-k models, and keep the best, giving the best (k+1)-feature model;
- repeat until k = p;
- finally, pick the overall best of the resulting p models (by cross-validation error, Cp, BIC, Adjusted R2, etc.).
Backward stepwise regression works analogously, except it starts from the model with all p features and, at each iteration, drops the single feature whose removal improves the model the most.
The difference from best subset selection is that best subset may choose any k+1 features in round k+1, whereas stepwise regression must keep the k features already chosen. Stepwise regression therefore cannot guarantee the optimum: a not-so-important feature picked early must be carried through every later iteration, which can put the truly best combination out of reach. The payoff is a much smaller computation, about p(p+1)/2 model fits (190 for p = 19, versus 2^19 for best subset), which makes it far more practical.
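To illustrate the greedy loop above, here is a minimal hand-rolled forward selection sketch (greedy on RSS; note it treats each factor such as League as a single term, so its path can differ slightly from regsubsets, which works on individual dummy columns):

library(ISLR)
Hitters <- na.omit(Hitters)

# Hand-rolled forward stepwise selection, greedy on RSS.
predictors <- setdiff(names(Hitters), "Salary")
chosen <- character(0)
path <- list()
for (k in seq_along(predictors)) {
  remaining <- setdiff(predictors, chosen)
  # fit one model per candidate addition to the current set
  rss <- sapply(remaining, function(v) {
    fit <- lm(reformulate(c(chosen, v), response = "Salary"), data = Hitters)
    sum(resid(fit)^2)
  })
  chosen <- c(chosen, remaining[which.min(rss)])  # commit: never undone
  path[[k]] <- chosen  # best k-feature model found by the greedy search
}
path[1:4]  # the first few steps of the selection path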
> regfit.fwd = regsubsets(Salary~.,data=Hitters,nvmax = 19,method = "forward")
> summary(regfit.fwd) # show the forward selection path
Subset selection object
Call: regsubsets.formula(Salary ~ ., data = Hitters, nvmax = 19, method = "forward")
Selection Algorithm: forward
AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits
1 ( 1 ) " " " " " " " " " " " " " " " " " "
2 ( 1 ) " " "*" " " " " " " " " " " " " " "
3 ( 1 ) " " "*" " " " " " " " " " " " " " "
4 ( 1 ) " " "*" " " " " " " " " " " " " " "
5 ( 1 ) "*" "*" " " " " " " " " " " " " " "
6 ( 1 ) "*" "*" " " " " " " "*" " " " " " "
7 ( 1 ) "*" "*" " " " " " " "*" " " " " " "
8 ( 1 ) "*" "*" " " " " " " "*" " " " " " "
9 ( 1 ) "*" "*" " " " " " " "*" " " "*" " "
10 ( 1 ) "*" "*" " " " " " " "*" " " "*" " "
11 ( 1 ) "*" "*" " " " " " " "*" " " "*" " "
12 ( 1 ) "*" "*" " " "*" " " "*" " " "*" " "
13 ( 1 ) "*" "*" " " "*" " " "*" " " "*" " "
14 ( 1 ) "*" "*" "*" "*" " " "*" " " "*" " "
15 ( 1 ) "*" "*" "*" "*" " " "*" " " "*" "*"
16 ( 1 ) "*" "*" "*" "*" "*" "*" " " "*" "*"
17 ( 1 ) "*" "*" "*" "*" "*" "*" " " "*" "*"
18 ( 1 ) "*" "*" "*" "*" "*" "*" "*" "*" "*"
19 ( 1 ) "*" "*" "*" "*" "*" "*" "*" "*" "*"
CHmRun CRuns CRBI CWalks LeagueN DivisionW PutOuts
1 ( 1 ) " " " " "*" " " " " " " " "
2 ( 1 ) " " " " "*" " " " " " " " "
3 ( 1 ) " " " " "*" " " " " " " "*"
4 ( 1 ) " " " " "*" " " " " "*" "*"
5 ( 1 ) " " " " "*" " " " " "*" "*"
6 ( 1 ) " " " " "*" " " " " "*" "*"
7 ( 1 ) " " " " "*" "*" " " "*" "*"
8 ( 1 ) " " "*" "*" "*" " " "*" "*"
9 ( 1 ) " " "*" "*" "*" " " "*" "*"
10 ( 1 ) " " "*" "*" "*" " " "*" "*"
11 ( 1 ) " " "*" "*" "*" "*" "*" "*"
12 ( 1 ) " " "*" "*" "*" "*" "*" "*"
13 ( 1 ) " " "*" "*" "*" "*" "*" "*"
14 ( 1 ) " " "*" "*" "*" "*" "*" "*"
15 ( 1 ) " " "*" "*" "*" "*" "*" "*"
16 ( 1 ) " " "*" "*" "*" "*" "*" "*"
17 ( 1 ) " " "*" "*" "*" "*" "*" "*"
18 ( 1 ) " " "*" "*" "*" "*" "*" "*"
19 ( 1 ) "*" "*" "*" "*" "*" "*" "*"
Assists Errors NewLeagueN
1 ( 1 ) " " " " " "
2 ( 1 ) " " " " " "
3 ( 1 ) " " " " " "
4 ( 1 ) " " " " " "
5 ( 1 ) " " " " " "
6 ( 1 ) " " " " " "
7 ( 1 ) " " " " " "
8 ( 1 ) " " " " " "
9 ( 1 ) " " " " " "
10 ( 1 ) "*" " " " "
11 ( 1 ) "*" " " " "
12 ( 1 ) "*" " " " "
13 ( 1 ) "*" "*" " "
14 ( 1 ) "*" "*" " "
15 ( 1 ) "*" "*" " "
16 ( 1 ) "*" "*" " "
17 ( 1 ) "*" "*" "*"
18 ( 1 ) "*" "*" "*"
19 ( 1 ) "*" "*" "*"
> regfit.bwd = regsubsets(Salary~.,data=Hitters,nvmax = 19,method = "backward")
> summary(regfit.bwd) # show the backward selection path
Subset selection object
Call: regsubsets.formula(Salary ~ ., data = Hitters, nvmax = 19, method = "backward")
Selection Algorithm: backward
AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits
1 ( 1 ) " " " " " " " " " " " " " " " " " "
2 ( 1 ) " " "*" " " " " " " " " " " " " " "
3 ( 1 ) " " "*" " " " " " " " " " " " " " "
4 ( 1 ) "*" "*" " " " " " " " " " " " " " "
5 ( 1 ) "*" "*" " " " " " " "*" " " " " " "
6 ( 1 ) "*" "*" " " " " " " "*" " " " " " "
7 ( 1 ) "*" "*" " " " " " " "*" " " " " " "
8 ( 1 ) "*" "*" " " " " " " "*" " " " " " "
9 ( 1 ) "*" "*" " " " " " " "*" " " "*" " "
10 ( 1 ) "*" "*" " " " " " " "*" " " "*" " "
11 ( 1 ) "*" "*" " " " " " " "*" " " "*" " "
12 ( 1 ) "*" "*" " " "*" " " "*" " " "*" " "
13 ( 1 ) "*" "*" " " "*" " " "*" " " "*" " "
14 ( 1 ) "*" "*" "*" "*" " " "*" " " "*" " "
15 ( 1 ) "*" "*" "*" "*" " " "*" " " "*" "*"
16 ( 1 ) "*" "*" "*" "*" "*" "*" " " "*" "*"
17 ( 1 ) "*" "*" "*" "*" "*" "*" " " "*" "*"
18 ( 1 ) "*" "*" "*" "*" "*" "*" "*" "*" "*"
19 ( 1 ) "*" "*" "*" "*" "*" "*" "*" "*" "*"
CHmRun CRuns CRBI CWalks LeagueN DivisionW PutOuts
1 ( 1 ) " " "*" " " " " " " " " " "
2 ( 1 ) " " "*" " " " " " " " " " "
3 ( 1 ) " " "*" " " " " " " " " "*"
4 ( 1 ) " " "*" " " " " " " " " "*"
5 ( 1 ) " " "*" " " " " " " " " "*"
6 ( 1 ) " " "*" " " " " " " "*" "*"
7 ( 1 ) " " "*" " " "*" " " "*" "*"
8 ( 1 ) " " "*" "*" "*" " " "*" "*"
9 ( 1 ) " " "*" "*" "*" " " "*" "*"
10 ( 1 ) " " "*" "*" "*" " " "*" "*"
11 ( 1 ) " " "*" "*" "*" "*" "*" "*"
12 ( 1 ) " " "*" "*" "*" "*" "*" "*"
13 ( 1 ) " " "*" "*" "*" "*" "*" "*"
14 ( 1 ) " " "*" "*" "*" "*" "*" "*"
15 ( 1 ) " " "*" "*" "*" "*" "*" "*"
16 ( 1 ) " " "*" "*" "*" "*" "*" "*"
17 ( 1 ) " " "*" "*" "*" "*" "*" "*"
18 ( 1 ) " " "*" "*" "*" "*" "*" "*"
19 ( 1 ) "*" "*" "*" "*" "*" "*" "*"
Assists Errors NewLeagueN
1 ( 1 ) " " " " " "
2 ( 1 ) " " " " " "
3 ( 1 ) " " " " " "
4 ( 1 ) " " " " " "
5 ( 1 ) " " " " " "
6 ( 1 ) " " " " " "
7 ( 1 ) " " " " " "
8 ( 1 ) " " " " " "
9 ( 1 ) " " " " " "
10 ( 1 ) "*" " " " "
11 ( 1 ) "*" " " " "
12 ( 1 ) "*" " " " "
13 ( 1 ) "*" "*" " "
14 ( 1 ) "*" "*" " "
15 ( 1 ) "*" "*" " "
16 ( 1 ) "*" "*" " "
17 ( 1 ) "*" "*" "*"
18 ( 1 ) "*" "*" "*"
19 ( 1 ) "*" "*" "*"
Note that best subset, forward stepwise, and backward stepwise can select different features; compare the three 7-feature models:
> coef(regfit.full,7)
(Intercept) Hits Walks CAtBat CHits
79.4509472 1.2833513 3.2274264 -0.3752350 1.4957073
CHmRun DivisionW PutOuts
1.4420538 -129.9866432 0.2366813
> coef(regfit.fwd,7)
(Intercept) AtBat Hits Walks CRBI
109.7873062 -1.9588851 7.4498772 4.9131401 0.8537622
CWalks DivisionW PutOuts
-0.3053070 -127.1223928 0.2533404
> coef(regfit.bwd,7)
(Intercept) AtBat Hits Walks CRuns
105.6487488 -1.9762838 6.7574914 6.0558691 1.1293095
CWalks DivisionW PutOuts
-0.7163346 -116.1692169 0.3028847
Cross-Validation
Cross-validation is a general-purpose technique in machine learning for assessing a model's bias and variance; it is not tied to any particular model class. Here we use the common compromise, k-fold cross-validation, which works as follows:
- Randomly assign the samples to k folds of roughly equal size (k is typically 10)
- For each i (1 <= i <= k), hold out fold i as the validation set and train the model on the remaining folds
- The mean of the k validation errors serves as the model's overall validation error
k-fold CV has two advantages over leave-one-out cross-validation (LOOCV). First, it is cheaper: LOOCV requires n fits, k-fold only k. Second, each LOOCV training set leaves out a single sample, so all n training sets are nearly the full data and nearly identical; the n fitted models are therefore highly correlated, and averaging their highly correlated errors gives a high-variance estimate of the test error. The training sets in k-fold CV overlap much less, so its estimate has lower variance (at the cost of a little bias). A generic sketch follows.
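A minimal sketch of the generic k-fold procedure, here with an arbitrary fixed formula purely for illustration:

library(ISLR)
Hitters <- na.omit(Hitters)

set.seed(1)
k <- 10
# balanced random fold assignment (the code below uses sampling with replacement instead)
folds <- sample(rep(1:k, length.out = nrow(Hitters)))
fold.mse <- numeric(k)
for (i in 1:k) {
  # train on everything outside fold i, validate on fold i
  fit <- lm(Salary ~ AtBat + Hits + Walks, data = Hitters[folds != i, ])
  pred <- predict(fit, newdata = Hitters[folds == i, ])
  fold.mse[i] <- mean((Hitters$Salary[folds == i] - pred)^2)
}
mean(fold.mse)  # the k-fold CV estimate of test MSE for this one model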
So for each candidate number of features we can compute a validation error and then examine how that error varies with the number of features (again, this idea is not limited to linear models). Before the k-fold version, here is the simpler warm-up: a single random split into training and test sets.
> set.seed(1)
> # randomly split samples into training and test sets
> train = sample(c(T,F),nrow(Hitters),rep=T)
> test = !train
>
> # best subset selection on the training set
> regfit.best = regsubsets(Salary~.,data = Hitters[train,],nvmax = 19)
> test.mat = model.matrix(Salary~.,data = Hitters[test,])
>
> val.errors = rep(NA,19)
>
> for(i in 1:19){
+ coefi = coef(regfit.best,id=i)
+ pred = test.mat[,names(coefi)]%*%coefi # matrix multiplication gives the test-set predictions
+ val.errors[i] = mean((Hitters$Salary[test]-pred)^2) # test MSE
+ }
>
> val.errors
[1] 220968.0 169157.1 178518.2 163426.1 168418.1 171270.6
[7] 162377.1 157909.3 154055.7 148162.1 151156.4 151742.5
[13] 152214.5 157358.7 158541.4 158743.3 159972.7 159859.8
[19] 160105.6
> which.min(val.errors)
[1] 10
> coef(regfit.best,10)
(Intercept) AtBat Hits Walks CAtBat
-80.2751499 -1.4683816 7.1625314 3.6430345 -0.1855698
CHits CHmRun CWalks LeagueN DivisionW
1.1053238 1.3844863 -0.7483170 84.5576103 -53.0289658
PutOuts
0.2381662
The example above randomly splits the samples into training and test sets, runs best subset selection on the training set, and computes the test MSE for each number of features; the MSE is smallest at 10 features.
Now let's select the features with k-fold cross-validation.
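One wrinkle first: leaps provides no predict() method for regsubsets objects, so the predict() call inside the loop below needs a helper; this is the standard one from the ISLR lab:

# predict() method for regsubsets objects (as defined in the ISLR lab)
predict.regsubsets <- function(object, newdata, id, ...) {
  form <- as.formula(object$call[[2]])  # recover the formula from the call
  mat <- model.matrix(form, newdata)    # design matrix for the new data
  coefi <- coef(object, id = id)        # coefficients of the id-variable model
  mat[, names(coefi)] %*% coefi         # predictions
}

With that in place, the k-fold loop: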
> k = 10
> set.seed(1)
> folds = sample(1:k,nrow(Hitters),replace = T) # randomly assign each sample to one of the 10 folds (with replacement, so fold sizes vary)
> table(folds) # fold sizes are roughly comparable
folds
1 2 3 4 5 6 7 8 9 10
13 25 31 32 33 27 26 30 22 24
> cv.errors = matrix(NA,k,19,dimnames = list(NULL,paste(1:19))) # a k x 19 matrix for the test errors: one row per fold, one column per number of features
>
> for(j in 1:k){
+ best.fit = regsubsets(Salary~.,data = Hitters[folds!=j,],nvmax = 19) # best subset selection on all folds except fold j
+ for(i in 1:19){ # MSE for each model size from 1 to 19 features
+ pred = predict(best.fit,Hitters[folds==j,],id=i)
+ cv.errors[j,i] = mean((Hitters$Salary[folds==j]-pred)^2)
+ }
+ }
>
> cv.errors
1 2 3 4 5 6
[1,] 187479.08 141652.61 163000.36 169584.40 141745.39 151086.36
[2,] 96953.41 63783.33 85037.65 76643.17 64943.58 56414.96
[3,] 165455.17 167628.28 166950.43 152446.17 156473.24 135551.12
[4,] 124448.91 110672.67 107993.98 113989.64 108523.54 92925.54
[5,] 136168.29 79595.09 86881.88 94404.06 89153.27 83111.09
[6,] 171886.20 120892.96 120879.58 106957.31 100767.73 89494.38
[7,] 56375.90 74835.19 72726.96 59493.96 64024.85 59914.20
[8,] 93744.51 85579.47 98227.05 109847.35 100709.25 88934.97
[9,] 421669.62 454728.90 437024.28 419721.20 427986.39 401473.33
[10,] 146753.76 102599.22 192447.51 208506.12 214085.78 224120.38
7 8 9 10 11 12
[1,] 193584.17 144806.44 159388.10 138585.25 140047.07 158928.92
[2,] 63233.49 63054.88 60503.10 60213.51 58210.21 57939.91
[3,] 137609.30 146028.36 131999.41 122733.87 127967.69 129804.19
[4,] 104522.24 96227.18 93363.36 96084.53 99397.85 100151.19
[5,] 86412.18 77319.95 80439.75 75912.55 81680.13 83861.19
[6,] 94093.52 86104.48 84884.10 80575.26 80155.27 75768.73
[7,] 62942.94 60371.85 61436.77 62082.63 66155.09 65960.47
[8,] 90779.58 77151.69 75016.23 71782.40 76971.60 77696.55
[9,] 396247.58 381851.15 369574.22 376137.45 373544.77 382668.48
[10,] 214037.26 169160.95 177991.11 169239.17 147408.48 149955.85
13 14 15 16 17 18
[1,] 161322.76 155152.28 153394.07 153336.85 153069.00 152838.76
[2,] 59975.07 58629.57 58961.90 58757.55 58570.71 58890.03
[3,] 133746.86 135748.87 137937.17 140321.51 141302.29 140985.80
[4,] 103073.96 106622.46 106211.72 107797.54 106288.67 106913.18
[5,] 85111.01 84901.63 82829.44 84923.57 83994.95 84184.48
[6,] 76927.44 76529.74 78219.76 78256.23 77973.40 79151.81
[7,] 66310.58 70079.10 69553.50 68242.10 68114.27 67961.32
[8,] 78460.91 81107.16 82431.25 82213.66 81958.75 81893.97
[9,] 375284.60 376527.06 374706.25 372917.91 371622.53 373745.20
[10,] 194397.12 194448.21 174012.18 172060.78 184614.12 184397.75
19
[1,] 153197.11
[2,] 58949.25
[3,] 140392.48
[4,] 106919.66
[5,] 84284.62
[6,] 78988.92
[7,] 67943.62
[8,] 81848.89
[9,] 372365.67
[10,] 183156.97
>
> mean.cv.errors = apply(cv.errors,2,mean) # average MSE across the 10 folds for each number of features
> mean.cv.errors
1 2 3 4 5 6 7
160093.5 140196.8 153117.0 151159.3 146841.3 138302.6 144346.2
8 9 10 11 12 13 14
130207.7 129459.6 125334.7 125153.8 128273.5 133461.0 133974.6
15 16 17 18 19
131825.7 131882.8 132750.9 133096.2 132804.7
> plot(mean.cv.errors,type = "b")
Cross-validation thus selects 11 features. We can now run best subset selection on the full dataset and keep the 11-variable result:
> reg.best = regsubsets(Salary~.,data = Hitters,nvmax = 19)
> coef(reg.best,11)
(Intercept) AtBat Hits Walks CAtBat
135.7512195 -2.1277482 6.9236994 5.6202755 -0.1389914
CRuns CRBI CWalks LeagueN DivisionW
1.4553310 0.7852528 -0.8228559 43.1116152 -111.1460252
PutOuts Assists
0.2894087 0.2688277