Decision tree(決策樹)
(注:本文并非原創(chuàng),但修改了原文中幾處代碼錯(cuò)誤以及部分概念描述的模糊之處荠卷,新加了一些算式證明等)
決策樹是廣泛用于分類和回歸任務(wù)的模型豪筝。本質(zhì)上愚墓,它從一層層if/else問題中進(jìn)行學(xué)習(xí)输虱,并得出結(jié)論
import mglearn
mglearn.plots.plot_animal_tree()
上圖就是一顆決策樹些楣,樹的每個(gè)結(jié)點(diǎn)代表一個(gè)問題或包含答案的終結(jié)點(diǎn)(也叫葉結(jié)點(diǎn))
下面較為詳細(xì)的解釋下這個(gè)算法:首先,決策樹是一種基本的分類與回歸方法
在分類中,定義為:
分類決策樹模型是一種描述對(duì)實(shí)例進(jìn)行分類的樹形結(jié)構(gòu)戈毒。決策樹由結(jié)點(diǎn)和有向邊組成。結(jié)點(diǎn)有兩種類型:內(nèi)部結(jié)點(diǎn)和葉結(jié)點(diǎn)横堡,內(nèi)部結(jié)點(diǎn)表示一個(gè)特征或?qū)傩月袷校~結(jié)點(diǎn)表示一個(gè)類。
分類的時(shí)候命贴,從根結(jié)點(diǎn)開始道宅,對(duì)實(shí)例的某一個(gè)特征進(jìn)行測(cè)試,根據(jù)測(cè)試結(jié)果胸蛛,將實(shí)例分配到其子結(jié)點(diǎn)污茵;此時(shí),每一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)著該特征的一個(gè)取值葬项。如此遞歸向下移動(dòng)泞当,直至達(dá)到葉結(jié)點(diǎn),最后將實(shí)例分配到葉結(jié)點(diǎn)的類中民珍。
決策樹的學(xué)習(xí)
決策樹學(xué)習(xí)算法包含特征選擇襟士、決策樹的生成與剪枝過程。決策樹的學(xué)習(xí)算法通常是遞歸地選擇最優(yōu)特征嚷量,并用最優(yōu)特征對(duì)數(shù)據(jù)集進(jìn)行分割陋桂。開始時(shí),構(gòu)建根結(jié)點(diǎn)蝶溶,選擇最優(yōu)特征嗜历,該特征有幾種值就分割為幾個(gè)子集,每個(gè)子集分別遞歸調(diào)用此方法抖所,返回結(jié)點(diǎn)梨州,返回的結(jié)點(diǎn)就是上一層的子結(jié)點(diǎn)。直到所有特征都已經(jīng)用完部蛇,或者數(shù)據(jù)集只有一維特征為止摊唇。
特征選擇
特征選擇問題希望選取對(duì)訓(xùn)練數(shù)據(jù)具有良好分類能力的特征,這樣可以提高決策樹學(xué)習(xí)的效率涯鲁。如果利用一個(gè)特征進(jìn)行分類的結(jié)果與隨機(jī)分類的結(jié)果沒有很大差別巷查,則稱這個(gè)特征是沒有分類能力的。為了更好的選擇特征抹腿,使用了一些熵的概念(在另外的文章中已經(jīng)詳細(xì)推導(dǎo)過)岛请,這里用代碼實(shí)現(xiàn)一下之前的結(jié)論
import numpy as np
def calcuInfoEnt(dataSet, i=-1):
'''
計(jì)算信息熵
dataSet:數(shù)據(jù)集
return:數(shù)據(jù)集的信息熵
'''
numElements = len(dataSet)
labelCounts = {}
infoEnt = 0.0
for elementVec in dataSet: #遍歷數(shù)據(jù)集,統(tǒng)計(jì)元素向量中具有相同標(biāo)簽的頻率
currLabel = elementVec[i]
if currLabel not in labelCounts.keys():
labelCounts[currLabel] = 0
labelCounts[currLabel] += 1
for key in labelCounts:
prob = float(labelCounts[key]) / numElements
infoEnt -= prob * np.log2(prob)
return infoEnt
def splitDataSet(dataSet, axis, featVal):
'''
按照給定特征值劃分?jǐn)?shù)據(jù)集
dataSet:待劃分?jǐn)?shù)據(jù)集
axis:劃分?jǐn)?shù)據(jù)集特征的維度
featVal:特征的值
return:劃分的子數(shù)據(jù)集
'''
subDataSet = []
for elementVec in dataSet:
if elementVec[axis] == featVal:
reduceElemVec = elementVec[:axis] #提取特征前的vec
reduceElemVec.extend(elementVec[axis+1:]) #提取特征后的vec
subDataSet.append(reduceElemVec)
return subDataSet
def calcuConditionEnt(dataSet, i, featList, featSet):
'''
計(jì)算在指定特征i的條件下警绩,Y的條件熵
dataSet:數(shù)據(jù)集
i:維度i
featList:數(shù)據(jù)集特征值列表
featSet:數(shù)據(jù)集特征值集合
'''
conditionEnt = 0.0
for featVal in featSet:
subDataSet = splitDataSet(dataSet, i, featVal)
prob = float(len(subDataSet))/len(dataSet) #指定特征的概率
conditionEnt += prob * calcuInfoEnt(subDataSet) #條件熵的定義計(jì)算
return conditionEnt
最一開始我們使用信息增益(Information gain)來構(gòu)建決策樹崇败,被稱為ID3,這種算法本身缺陷很大
Information gain(信息增益)
信息增益表示得知特征X的信息而使得類Y的信息的不確定性減少的程度。特征A對(duì)訓(xùn)練數(shù)據(jù)集D的信息增益
后室,定義為集合D的經(jīng)驗(yàn)熵
與特征A給定條件下D的經(jīng)驗(yàn)條件熵H(D|A)之差缩膝,即
不難發(fā)現(xiàn),信息增益大的特征具有更強(qiáng)的分類能力岸霹。那么疾层,根據(jù)信息增益準(zhǔn)則的特征選擇方法就是:對(duì)訓(xùn)練數(shù)據(jù)集計(jì)算其每個(gè)特征的信息增益,選擇信息增益最大的特征贡避。
假設(shè)樣本有k個(gè)類別痛黎,表示類別k的樣本個(gè)數(shù),
表示樣本總數(shù)刮吧,那么每個(gè)類別的概率就是
![]()
那么
特征A對(duì)數(shù)據(jù)集D的經(jīng)驗(yàn)條件熵H(D|A):
根據(jù)特征A將D劃分為n個(gè)子集湖饱,
為
的樣本個(gè)數(shù),
之和為
杀捻,記
中屬于
的樣本集合為
井厌,即交集,
為
的樣本個(gè)數(shù)
def calcuInfoGain(dataSet, baseEnt, i):
'''
計(jì)算信息增益
dataSet:數(shù)據(jù)集
baseEnt:數(shù)據(jù)集的信息熵
i:特征維度
return:特征i對(duì)數(shù)據(jù)集的信息增益g(D|A)
'''
featList = [example[i] for example in dataSet] #第i維特征列表
featSet = set(featList) #轉(zhuǎn)換為特征集合
conditionEnt = calcuConditionEnt(dataSet, i, featList, featSet)
infoGain = baseEnt - conditionEnt
return infoGain
后面改進(jìn)為使用信息增益比(Information gain ratio)生成決策樹致讥,它對(duì)ID3算法進(jìn)行了以下改進(jìn):
1)使用信息增益比選擇特征旗笔,克服了用信息增益選擇特征時(shí)偏向選擇取值多的特征的不足
2)在樹構(gòu)造的過程中進(jìn)行剪枝
3)能夠完成對(duì)連續(xù)屬性的離散化處理
4)能夠?qū)Σ煌暾麛?shù)據(jù)進(jìn)行處理
針對(duì)上面四點(diǎn)下面的介紹對(duì)后兩點(diǎn)并未進(jìn)行優(yōu)化
Information gain ratio (信息增益比)
特征A對(duì)訓(xùn)練數(shù)據(jù)集D的信息增益比
公式為:
特別地,其中為對(duì)于數(shù)據(jù)集D拄踪,將當(dāng)前特征A作為隨機(jī)變量(取值為特征A的各個(gè)特征值)蝇恶,求得的經(jīng)驗(yàn)熵
懲罰參數(shù)(penalty parameter):數(shù)據(jù)集D以特征A作為隨機(jī)變量的熵的倒數(shù),即:將特征A取值相同的樣本劃分到同一個(gè)子集中
def calcuInfoGainRatio(dataSet, baseEnt, i):
'''
計(jì)算信息增益比
dataSet:數(shù)據(jù)集
baseEnt:數(shù)據(jù)集的信息熵
i:特征維度
return:特征i對(duì)數(shù)據(jù)集的信息增益比gR(D,A)
'''
return calcuInfoGain(dataSet, baseEnt, i) / calcuInfoEnt(dataSet, i)
決策樹的生成
ID3
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
cancer.data, cancer.target, stratify=cancer.target, random_state=42)
tree = DecisionTreeClassifier(max_depth=3, random_state=0)
tree.fit(X_train, y_train)
export_graphviz(tree, out_file="tree.dot", class_names=["malignant", "benign"],
feature_names=cancer.feature_names, impurity=False, filled=True)
import graphviz
with open("tree.dot") as f:
dot_graph = f.read()
graphviz.Source(dot_graph)
上面了生成了一棵樹惶桐,對(duì)照著說一下算法
輸入:訓(xùn)練數(shù)據(jù)集D, 特征A撮弧,閾值
輸出:決策樹T
(1)若D中所有實(shí)例屬于同一類,則T單結(jié)點(diǎn)樹姚糊,并將類
作為該結(jié)點(diǎn)的類標(biāo)記贿衍,返回T;
(2)若A=,則T為單結(jié)點(diǎn)樹救恨,并將D中實(shí)例數(shù)最大的類
作為該結(jié)點(diǎn)的類標(biāo)記贸辈,返回T;
(3)否則肠槽,計(jì)算A中各特征對(duì)D的信息增益擎淤,選擇信息增益最大的特征;
(4)如果的信息增益小于閾值
秸仙,則置T為單結(jié)點(diǎn)樹嘴拢,并將D中實(shí)例樹最大的類
作為該結(jié)點(diǎn)的類標(biāo)記,返回T寂纪;
(5)否則席吴,對(duì)的每一可能值
赌结,依
將D分割為若干非空子集
,將
中實(shí)例數(shù)最大的類作為標(biāo)記孝冒,構(gòu)建子結(jié)點(diǎn)柬姚,由結(jié)點(diǎn)及其子結(jié)點(diǎn)構(gòu)成樹T,返回T庄涡;
(6)對(duì)第i個(gè)子結(jié)點(diǎn)伤靠,以為訓(xùn)練集,以
為特征集啼染,遞歸地調(diào)用(1)~(5),得到子樹
焕梅,返回
迹鹅。
import operator
def chooseBestFeatSplitID3(dataSet):
'''
選擇最好的數(shù)據(jù)集劃分方式
dataSet:數(shù)據(jù)集
return:劃分結(jié)果
'''
numFeatures = len(dataSet[0]) - 1
baseEnt = calcuInfoEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
infoGain = calcuInfoGain(dataSet, baseEnt, i) #計(jì)算信息增益
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature #返回最優(yōu)特征維度
def majorityClassify(classList):
'''
采用多數(shù)表決的方法決定結(jié)點(diǎn)的分類
classList:所有的類標(biāo)簽列表
return:出現(xiàn)次數(shù)最多的類
'''
classCount = {}
for cla in classList:
if cla not in classCount.keys():
classCount[cla] = 0
classCount[cla] += 1
sortClassCount = sorted(classCount.items(), key=operator.itemgetter(1),
reverse=True)
return sortClassCount[0][0]
def crtDecisionTree(dataSet, featLabels):
'''
創(chuàng)建決策樹
dataSet:訓(xùn)練數(shù)據(jù)集
featLabels:所有特征標(biāo)簽
return:返回決策樹字典
'''
classList = [element[-1] for element in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0] #所有的類標(biāo)簽都相同
if len(dataSet[0]) == 1:
return majorityClassify(classList) #用完所有特征
bestFeat = chooseBestFeatSplitID3(dataSet)
bestFeatLabel = featLabels[bestFeat]
deTree = {bestFeatLabel:{}}
subFeatLabels = featLabels[:] #復(fù)制所有類標(biāo)簽,保證每次遞歸調(diào)用時(shí)不改變?cè)瓉淼? del(subFeatLabels[bestFeat])
featValues = [element[bestFeat] for element in dataSet]
featValSet = set(featValues)
#####
for value in featValSet:
#subFeatLabels = featLabels[:]
deTree[bestFeatLabel][value] = \
crtDecisionTree(splitDataSet(dataSet, bestFeat, value),subFeatLabels)
return deTree
# 導(dǎo)入數(shù)據(jù)
def createDataSet():
dataSet = [['youth', 'no', 'no', 1, 'refuse'],
['youth', 'no', 'no', '2', 'refuse'],
['youth', 'yes', 'no', '2', 'agree'],
['youth', 'yes', 'yes', 1, 'agree'],
['youth', 'no', 'no', 1, 'refuse'],
['mid', 'no', 'no', 1, 'refuse'],
['mid', 'no', 'no', '2', 'refuse'],
['mid', 'yes', 'yes', '2', 'agree'],
['mid', 'no', 'yes', '3', 'agree'],
['mid', 'no', 'yes', '3', 'agree'],
['elder', 'no', 'yes', '3', 'agree'],
['elder', 'no', 'yes', '2', 'agree'],
['elder', 'yes', 'no', '2', 'agree'],
['elder', 'yes', 'no', '3', 'agree'],
['elder', 'no', 'no', 1, 'refuse'],
]
#print(type(dataSet))
labels = ['age', 'working', 'house', 'credit_situation']
return dataSet, labels
下面我們來更直觀的展示一下分類的結(jié)果
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="round", color='#3366FF') # 定義判斷結(jié)點(diǎn)形態(tài)
leafNode = dict(boxstyle="circle", color='#FF6633') # 定義葉結(jié)點(diǎn)形態(tài)
arrow_args = dict(arrowstyle="<-", color='g') # 定義箭頭
#計(jì)算葉子結(jié)點(diǎn)個(gè)數(shù)
def getNumLeafs(deTree):
numLeafs = 0
firstCondition = list(deTree.keys())[0]
secondDict = deTree[firstCondition]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':#測(cè)試結(jié)點(diǎn)的數(shù)據(jù)類型是否為字典
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs += 1
return numLeafs
#計(jì)算樹的深度
def getTreeDepth(deTree):
maxDepth = 0
firstCondition = list(deTree.keys())[0]
secondDict = deTree[firstCondition]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth : maxDepth = thisDepth
return maxDepth
# 繪制帶箭頭的注釋
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt,
xytext=centerPt, xycoords='axes fraction',
textcoords='axes fraction',va="center",
ha="center", bbox=nodeType, arrowprops=arrow_args )
# 在父子結(jié)點(diǎn)間填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center",
ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) # 計(jì)算寬與高
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))
/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode) # 標(biāo)記子結(jié)點(diǎn)屬性值
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 減少y偏移
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key)) #recursion
else: #it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key],
(plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree,
#and the first element will be another dict
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show()
# 測(cè)試代碼
if __name__ == "__main__":
dataSet, featLabels = createDataSet()
deTree = crtDecisionTree(dataSet, featLabels)
createPlot(deTree)
C4.5
C4.5算法使用信息增益比來選擇屬性贞言,繼承了ID3算法的優(yōu)點(diǎn)斜棚,并在一下幾個(gè)方面對(duì)ID3的算法進(jìn)行改進(jìn):
??克服了用信息增益選擇屬性時(shí)偏向選擇取值多的屬性的不足性;
??在樹構(gòu)造過程中進(jìn)行剪枝该窗;
??能夠完成對(duì)連續(xù)屬性的離散化處理弟蚀;
??能夠?qū)Σ煌暾麛?shù)據(jù)進(jìn)行處理。
在算法描述上酗失,僅對(duì)上面ID3里的第三步中改為信息增益比即可
def chooseBestFeatSplitC45(dataSet):
'''
選擇最好的數(shù)據(jù)集劃分方式
dataSet:數(shù)據(jù)集
return:劃分結(jié)果
'''
numFeatures = len(dataSet[0]) - 1
baseEnt = calcuInfoEnt(dataSet)
bestInfoGainRate = 0.0
bestFeature = -1
for i in range(numFeatures):
infoGainRate = calcuInfoGainRatio(dataSet, baseEnt, i) #計(jì)算信息增益比
if(infoGainRate > bestInfoGainRate):
bestInfoGainRate = infoGainRate
bestFeature = i
return bestFeature
然后我們考慮用分類好的決策樹模型進(jìn)行預(yù)測(cè)分類
def classify(inputTree, featLabels, testData):
'''
利用決策樹進(jìn)行分類
inputTree:構(gòu)造好的決策樹模型
featLabels:所有的特征標(biāo)簽
testData:測(cè)試數(shù)據(jù)
return:返回分類的決策結(jié)果
'''
firstCondition = list(inputTree.keys())[0]
secondDict = inputTree[firstCondition]
#拿到第一個(gè)分類條件在labels里面的索引
featIndex = featLabels.index(firstCondition)
featVal = testData[featIndex]
result = secondDict[featVal]
if isinstance(result, dict):
classLabel = classify(result, featLabels, testData)
else: classLabel = result
return classLabel
dataSet, featLabels = createDataSet()
deTree = crtDecisionTree(dataSet, featLabels)
print('預(yù)測(cè)結(jié)果是:' +
classify(deTree, featLabels, ['youth','no','yes',1]))
預(yù)測(cè)結(jié)果是:agree
CART(分類與回歸樹)
CART:分類與回歸樹义钉,也是一種應(yīng)用廣泛的決策樹學(xué)習(xí)方法。但是CART算法比較強(qiáng)大规肴,既可以用來作分類樹捶闸,也可以用來作回歸樹。在作為分類樹的時(shí)候拖刃,與ID3,C4.5差別不是很大删壮,只是選擇特征的根據(jù)不同。在通常情況下兑牡,決策樹是二叉樹央碟,也就是說它的特征值都是二分類的。當(dāng)用CART作回歸樹時(shí)均函,以最小平方誤差作為劃分樣本的依據(jù)亿虽。
在分類樹中采用基尼指數(shù)用來選擇最優(yōu)特征。假設(shè)有
個(gè)類苞也,樣本點(diǎn)屬于第
類的概率為
经柴,則概率的基尼指數(shù)定義為
對(duì)于給定樣本集合,
為樣本個(gè)數(shù),
是
中屬于第
類的樣本子集墩朦,則此時(shí)的基尼指數(shù)為
def calcuGini(dataSet):
'''
計(jì)算基尼指數(shù)
dataSet:數(shù)據(jù)集
return:基尼指數(shù)的計(jì)算結(jié)果
'''
numElements = len(dataSet)
Gini = 1.0
labelCounts = {}
for eleVec in dataSet: #遍歷每個(gè)實(shí)例坯认,統(tǒng)計(jì)標(biāo)簽的頻數(shù)
curLabel = eleVec[-1]
if curLabel not in labelCounts.keys():
labelCounts[curLabel] = 0
labelCounts[curLabel] += 1
for key in labelCounts:
prob = float(labelCounts[key]) / numElements
Gini -= prob * prob
return Gini
那么在給定特征A的條件下,集合D的基尼指數(shù)定義為
因?yàn)樘卣鞯姆诸悅€(gè)數(shù)會(huì)決策樹的分支個(gè)數(shù),CART是二叉樹牛哺,那么在給定特征A的時(shí)候陋气,集合D就會(huì)被分為兩類和
基尼指數(shù)表示集合D的不確定性,和熵類似引润,當(dāng)經(jīng)過
的分類后巩趁,
的數(shù)值越大,樣本集合的不確定性也就越大淳附。
def calcuGiniBaseFeat(dataSet, featI, featVal):
'''
計(jì)算給定特征下的基尼指數(shù)
dataSet:數(shù)據(jù)集
featI:特征維度
featVal:特征維度下的特征值
return:計(jì)算結(jié)果
'''
D0 = []
D1 = []
for eleVec in dataSet:
if eleVec[featI] == featVal:
D0.append(eleVec)
else:
D1.append(eleVec)
Gini = float(len(D0)) / len(dataSet) * calcuGini(D0) + \
float(len(D1)) / len(dataSet) * calcuGini(D1)
return Gini
下面描述下CART分類樹的算法步驟:
輸入:訓(xùn)練數(shù)據(jù)集D议慰,停止計(jì)算的條件
輸出:CART決策樹
根據(jù)訓(xùn)練數(shù)據(jù)集,從根結(jié)點(diǎn)開始奴曙,遞歸地對(duì)每個(gè)結(jié)點(diǎn)進(jìn)行以下操作别凹,構(gòu)建二叉決策樹
(1)設(shè)結(jié)點(diǎn)的訓(xùn)練數(shù)據(jù)集為, 計(jì)算現(xiàn)有特征對(duì)該數(shù)據(jù)集的基尼指數(shù)。此時(shí)洽糟,對(duì)每一個(gè)特征
炉菲,對(duì)其可能取的每個(gè)值
,根據(jù)樣本點(diǎn)對(duì)
的測(cè)試為“是”或“否”將
分割成
和
兩部分坤溃,計(jì)算
時(shí)的基尼指數(shù)拍霜;
(2)在所喲可能的特征以及它們所有可能的切分點(diǎn)
中,選擇基尼指數(shù)最小的特征及其對(duì)應(yīng)的切分點(diǎn)作為最優(yōu)特征與最優(yōu)切分點(diǎn)薪介。依據(jù)最優(yōu)特征和最優(yōu)切分點(diǎn)祠饺,從現(xiàn)結(jié)點(diǎn)生成兩個(gè)子結(jié)點(diǎn),將訓(xùn)練集依特征分配到兩個(gè)子結(jié)點(diǎn)中汁政;
(3)對(duì)兩個(gè)子結(jié)點(diǎn)遞歸地調(diào)用(1),(2),直至滿足停止條件為止吠裆;
(4)生成CART決策樹
算法停止計(jì)算的條件是結(jié)點(diǎn)中的樣本個(gè)數(shù)小于預(yù)定閾值,或者樣本集的基尼指數(shù)小于預(yù)定閾值(代表樣本基本屬于同一類)烂完,或者沒有更多特征试疙。
def chooseBestFeatSplitGini(dataSet):
bestGini = float("inf")
bestFeatI = 0
condiGini = 0.0
numFeatures = len(dataSet[0]) - 1
for i in range(numFeatures):
featList = [element[i] for element in dataSet]
featSet = set(featList)
for splitVal in featSet:
condiGini = calcuGiniBaseFeat(dataSet, i, splitVal)
if condiGini < bestGini:
bestFeatI = i
bestGini = condiGini
return bestFeatI
def crtDecisionTreeCART(dataSet, featLabels):
'''
創(chuàng)建決策樹
dataSet:訓(xùn)練數(shù)據(jù)集
featLabels:所有特征標(biāo)簽
return:返回決策樹字典
'''
classList = [element[-1] for element in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0] #所有的類標(biāo)簽都相同
if len(dataSet[0]) == 1:
return majorityClassify(classList) #用完所有特征
bestFeat = chooseBestFeatSplitGini(dataSet)
bestFeatLabel = featLabels[bestFeat]
deTree = {bestFeatLabel:{}}
subFeatLabels = featLabels[:] #復(fù)制所有類標(biāo)簽,保證每次遞歸調(diào)用時(shí)不改變?cè)瓉淼? del(subFeatLabels[bestFeat])
featValues = [element[bestFeat] for element in dataSet]
featValSet = set(featValues)
#####
for value in featValSet:
#subFeatLabels = featLabels[:]
deTree[bestFeatLabel][value] = \
crtDecisionTreeCART(splitDataSet(dataSet, bestFeat, value),
subFeatLabels)
return deTree
def createDataSetCART():
import numpy as np
dataSet = np.loadtxt("C:\\Users\\MAIBENBEN\\Desktop\\lenses.txt", dtype=str)
#print(type(dataSet))
labels = ['age', 'prescript', 'astigmatic', 'tearRate']
return dataSet.tolist(), labels
dataSet, featLabels = createDataSetCART()
deTree = crtDecisionTreeCART(dataSet, featLabels)
createPlot(deTree)
這棵樹看起來還是比較復(fù)雜的抠蚣,我們可以測(cè)試一下它的泛化能力祝旷。
def calcuError(tree, testData, labels):
errCount = 0.0
for i in range(len(testData)):
if classify(tree, labels, testData[i]) != testData[i][-1]:
errCount += 1
return float(errCount)
testData = np.loadtxt("C:\\Users\\MAIBENBEN\\Desktop\\testData.txt", dtype=str)
dataSet, featLabels = createDataSetCART()
deTree = crtDecisionTreeCART(dataSet, featLabels)
testErr = calcuError(deTree, testData.tolist(), featLabels)
print(testErr)
0.0
通過上面的簡(jiǎn)單預(yù)測(cè),我們的模型沒有預(yù)測(cè)誤差(謝天謝地嘶窄,這是件好事)怀跛。不過當(dāng)在更大的數(shù)據(jù)集訓(xùn)練得出的決策樹中,樹將會(huì)變得非常復(fù)雜柄冲,這大概率會(huì)造成過擬合的現(xiàn)象吻谋,即泛化能力就會(huì)差些。
所以我們有必要了解下決策樹的剪枝现横。
pruning(剪枝)
在決策樹學(xué)習(xí)中將已經(jīng)生成的樹進(jìn)行簡(jiǎn)化的過程稱為剪枝漓拾。決策樹的剪枝往往通過極小化決策樹的損失函數(shù)或代價(jià)函數(shù)來實(shí)現(xiàn)阁最。實(shí)際上剪枝的過程就是一個(gè)動(dòng)態(tài)規(guī)劃的過程:從葉結(jié)點(diǎn)開始,自底向上的對(duì)內(nèi)部結(jié)點(diǎn)計(jì)算預(yù)測(cè)誤差以及剪枝后的預(yù)測(cè)誤差骇两,如果兩者的預(yù)測(cè)誤差是相等或剪枝后預(yù)測(cè)誤差更小速种,那么就是剪掉的好。但如果剪枝后的預(yù)測(cè)誤差更大低千,就不要剪了配阵。剪枝后,原內(nèi)部結(jié)點(diǎn)會(huì)變成新的葉結(jié)點(diǎn)示血,其決策類別由多數(shù)表決決定棋傍。不斷重復(fù)上述的過程,直到預(yù)測(cè)誤差最小為止难审。
實(shí)現(xiàn)代碼如下:
import copy
def isTree(obj):
return (type(obj).__name__=='dict')
#計(jì)算剪枝后的預(yù)測(cè)誤差
def calcuPruErr(major, testData):
errCount = 0.0
for i in range(len(testData)):
if major != testData[i][-1]:
errCount += 1
return float(errCount)
#對(duì)決策樹進(jìn)行剪枝
def pruningTree(inputTree, dataSet, testData, featLabels):
labels = featLabels[:]
firstFeat = list(inputTree.keys())[0]
secondDict = inputTree[firstFeat]
classList = [element[-1] for element in dataSet]
featIndex = labels.index(firstFeat)
subLabels = copy.deepcopy(labels)
del(labels[featIndex])
for key in list(secondDict.keys()):
if isTree(secondDict[key]):
#深度優(yōu)先搜索瘫拣,遞歸剪枝
#key是特征值
subDataSet = splitDataSet(dataSet, featIndex, key)
subTestSet = splitDataSet(testData, featIndex, key)
if len(subDataSet) > 0 and len(subTestSet) > 0:
inputTree[firstFeat][key] = \
pruningTree(secondDict[key],subDataSet,
subTestSet,copy.deepcopy(labels))
if calcuError(inputTree, testData, subLabels) < \
calcuPruErr(majorityClassify(classList), testData):
#剪枝后的誤差反而變大,不做處理剔宪,直接返回
return inputTree
else:
#剪枝,原父結(jié)點(diǎn)變成子結(jié)點(diǎn)壹无,其類別由多數(shù)表決決定
print(majorityClassify(classList))
return majorityClassify(classList)
newTree = pruningTree(deTree, dataSet, testData.tolist(), featLabels)
createPlot(newTree)
讓我們來看下剪枝好的樹的泛化能力
tErr = calcuError(newTree, testData.tolist(), featLabels)
print(tErr)
0.0
當(dāng)然CART決策樹還可以用來做回歸任務(wù)葱绒,這里就不進(jìn)行詳細(xì)說明了。斗锭。地淀。