決策樹
優(yōu)點(diǎn): 計(jì)算復(fù)雜度不高而叼,輸出結(jié)果易于理解塔粒,對(duì)中間值的缺失不敏感,可以處理不相關(guān)特征數(shù)據(jù)腻脏。
缺點(diǎn): 可能會(huì)產(chǎn)生過度匹配問題。
適用數(shù)據(jù)類型: 數(shù)值型和標(biāo)稱型
1.計(jì)算香農(nóng)熵
from math import log
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
??這段代碼主要是用于計(jì)算數(shù)據(jù)的香農(nóng)熵银锻。首先創(chuàng)建以數(shù)據(jù)各個(gè)標(biāo)簽為鍵的哈希表迹卢,值初始化為0。然后根據(jù)出現(xiàn)次數(shù)進(jìn)行計(jì)數(shù)徒仓。
??各個(gè)標(biāo)簽的出現(xiàn)次數(shù)除以數(shù)據(jù)總數(shù)就是該標(biāo)簽的出現(xiàn)概率腐碱。有了概率我們就可以計(jì)算香農(nóng)熵。
??計(jì)算每個(gè)標(biāo)簽的熵并累加掉弛,就可得到數(shù)據(jù)整體的香農(nóng)熵症见。
2.建立測(cè)試用數(shù)據(jù)
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
??建立一個(gè)簡(jiǎn)單的數(shù)據(jù)用于測(cè)試算法的各個(gè)部分是否正確。這是一組關(guān)于海洋生物的數(shù)據(jù)殃饿,no surfacing對(duì)應(yīng)dataSet的第1列表示不浮出水面是否可以生存谋作,flippers對(duì)應(yīng)第2列表示是否有腳蹼,第3列的yes和no為標(biāo)簽乎芳,表示是否為魚類遵蚜。接下來我們會(huì)對(duì)這組數(shù)據(jù)進(jìn)行分類。
3.劃分?jǐn)?shù)據(jù)
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
??劃分?jǐn)?shù)據(jù)的函數(shù)需要3個(gè)參數(shù)奈惑。其中axis表示劃分軸吭净,value表示某個(gè)特定的值。
??這里我們以數(shù)據(jù)矩陣的某1列(某一種特征)為軸進(jìn)行劃分肴甸。遍歷矩陣每行寂殉,如果軸上的元素等于value,那么我們就用軸左邊的所有元素加上軸右邊的所有元素原在,來創(chuàng)建一個(gè)新向量(也就是去除了軸上元素的該行所有元素)友扰,把新向量添加到新的矩陣中。
4.計(jì)算最佳的特征劃分
def chooseBestFeatureToSplit(dataSet):
numFeatrues = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatrues):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = I
return bestFeature
??在第3節(jié)中我們實(shí)現(xiàn)了劃分?jǐn)?shù)據(jù)的函數(shù)庶柿,但是我們需要知道如何劃分?jǐn)?shù)據(jù)才是最好的村怪。我們可以用香農(nóng)熵進(jìn)行選擇,如果某種劃分下浮庐,數(shù)據(jù)整體的香農(nóng)熵最小(也就是原始未劃分狀態(tài)下的香農(nóng)熵減去劃分后的香農(nóng)熵的值最大)甚负,那么該劃分就是最佳劃分。
??在代碼中,首先獲得特征數(shù)量(最后1列是標(biāo)簽腊敲,所以要減1)似踱,并計(jì)算數(shù)據(jù)的原始未劃分狀態(tài)下的香農(nóng)熵清蚀。
??初始化最佳信息增益bestInfoGain和最佳劃分特征bestFeature。按列遍歷特征千所,利用集合無重復(fù)特性介时,將當(dāng)前列所有特征保存在集合中没宾。
??初始化新的香農(nóng)熵為newEntropy。遍歷特征集合沸柔,累加以當(dāng)前列為軸循衰,各個(gè)特征(無重復(fù))為指定value劃分下的香農(nóng)熵,值賦給newEntropy褐澎。遍歷完后的newEntropy就是當(dāng)前軸劃分下的數(shù)據(jù)整體香農(nóng)熵会钝。原始熵減去新熵就是得到了信息增益。
??這里的if語句就是傳統(tǒng)的求最大值的方法工三。在遍歷完所有軸迁酸,計(jì)算完所有劃分后,bestInfoGain保存的就是最大的信息增益俭正,bestFeature就是最大信息增益劃分的對(duì)應(yīng)軸奸鬓。最后把該軸返回。
5.多數(shù)表決制決定分類
import operator
def majorityCnt(classList):
classCount={} # 初始化哈希表
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0 # 為所有可能的分類新建鍵
classCount[vote] += 1 # 對(duì)分類的出現(xiàn)次數(shù)進(jìn)行計(jì)數(shù)
sortedClassCount = sorted(classCount.items, \
key=operator.itemgetter(1), reverse=True) # 根據(jù)哈希表第一個(gè)域(值)進(jìn)行逆(從大到小)排序
return sortedClassCount[0][0] # 返回出現(xiàn)次數(shù)最多的分類
??在實(shí)際建樹之前掸读,我們需要處理一種可能發(fā)生的情況串远。就是在處理完所有特征劃分后,剩下的向量中的元素仍然不屬于同一個(gè)標(biāo)簽儿惫。顯然只有1列的向量無法劃分澡罚,所以我們采用多數(shù)表決制。這段代碼的作用就是讓我們就對(duì)每個(gè)標(biāo)簽進(jìn)行計(jì)數(shù)肾请,然后返回次數(shù)最多的標(biāo)簽始苇。
6.建樹
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet\
(dataSet, bestFeat, value), subLabels)
return myTree
??建立樹這種數(shù)據(jù)類型的時(shí)候,通常會(huì)采用遞歸的方法筐喳。采用遞歸方法時(shí)催式,我們需要一個(gè)基本條件終止遞歸。這段代碼中避归,首先建立最后1列標(biāo)簽向量組成的列表荣月。然后判斷兩種情況,第1種情況是如果列表里只有一種標(biāo)簽梳毙,也就是剩下的數(shù)據(jù)全都是同一個(gè)分類時(shí)哺窄,直接返回該標(biāo)簽。第2種情況是剩余數(shù)據(jù)為列向量,也就是處理完了所有劃分的時(shí)候萌业,采用多數(shù)表決制坷襟,返回次數(shù)最多的標(biāo)簽。
??接下來是函數(shù)的主要部分生年,首先使用第4節(jié)的chooseBestFeatureToSplit函數(shù)獲得最佳劃分特征婴程,保存該特征的標(biāo)簽。創(chuàng)建以該標(biāo)簽為鍵的哈希表myTree抱婉,其值為1個(gè)空的哈希表档叔。然后將最佳標(biāo)簽從labels列表中刪除。
??通過列表推導(dǎo)和集合蒸绩,保存最佳特征軸的所有不同值衙四。這些值作為myTree的子哈希表中的鍵,鍵的值通過遞歸建樹獲得患亿。注意传蹈,遞歸函數(shù)中的第1個(gè)參數(shù)是通過最佳特征軸劃分后剩余的數(shù)據(jù)(去除了該軸),第2個(gè)參數(shù)是去除了最佳特征標(biāo)簽的labels的拷貝subLabels步藕。
7.繪制樹結(jié)點(diǎn)
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, \
xycoords='axes fraction', \
xytext=centerPt, textcoords='axes fraction', \
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
??決策樹的測(cè)試是通過繪制圖形來實(shí)現(xiàn)的卡睦。這里我們要利用matplotlib模組進(jìn)行圖形繪制。
??首先建立treePlotter.py漱抓。定義決策結(jié)點(diǎn)表锻、葉結(jié)點(diǎn),樹枝的樣式乞娄,然后定義plotNode函數(shù)繪制樹結(jié)點(diǎn)瞬逊。
8.獲取葉節(jié)點(diǎn)的數(shù)目和樹的層數(shù)
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key]) # 遞歸求得子字典層度
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
??遍歷樹時(shí),和建樹一樣仪或,使用遞歸方法會(huì)更加高效确镊。在getNumLeafs函數(shù)中,我們先獲得根節(jié)點(diǎn)范删,然后獲得根節(jié)點(diǎn)的值(哈希表)蕾域。遍歷這個(gè)哈希表的所有鍵,如果鍵是字典屬性(哈希表)到旦,那么遞歸求得這個(gè)子哈希表的葉旨巷。如果不是字典屬性,那么就讓numLeafs變量自增添忘。
??getTreeDepth函數(shù)也是一樣的方法采呐,先獲得根節(jié)點(diǎn),然后獲得根節(jié)點(diǎn)的值(哈希表)搁骑,遍歷其所有鍵斧吐,如果鍵是字典屬性又固,那么深度為子哈希表遞歸函數(shù)的返回值+1。如果不是字典屬性的話煤率,深度就等于1仰冠。這里比getNumLeafs函數(shù)多了一個(gè)求最大值的步驟,保證最后返回的是最大深度蝶糯。
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': \
{0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': \
{0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[I]
??這里定義一個(gè)測(cè)試用的函數(shù)洋只,返回的列表中第1個(gè)元素為之前我們建的樹。這個(gè)函數(shù)用來測(cè)試getNumLeafs函數(shù)和getTreeDepth函數(shù)是否正常工作裳涛。
9.繪制樹
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
??plotMidText函數(shù)用來在兩個(gè)節(jié)點(diǎn)坐標(biāo)的中點(diǎn)繪制文本信息木张。rotation=30可以讓文本信息繪制時(shí)逆時(shí)針旋轉(zhuǎn)30度众辨,顯示文字信息時(shí)更加美觀端三。
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, \
plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), \
cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
??plotTree函數(shù)是繪制樹的主要函數(shù)。和之前的getNumLeafs函數(shù)和getTreeDepth函數(shù)一樣鹃彻,plotTree函數(shù)采用遞歸方法遍歷樹進(jìn)行繪制郊闯。
??cntrPt變量保證每次繪制子結(jié)點(diǎn)時(shí),動(dòng)態(tài)平分坐標(biāo)蛛株,對(duì)稱繪制团赁。計(jì)算式中的plotTree.xOff,plotTree.yOff為全局變量谨履,plotTree.xOff在遍歷到葉時(shí)更新欢摄,向右增加一個(gè)葉結(jié)點(diǎn)的平均寬度,plotTree.yOff在每深入一層后更新笋粟,向下遞減1個(gè)深度怀挠,并且在函數(shù)最后還原,這樣可以保證在遞歸到某個(gè)不是最深的子樹時(shí)害捕,遍歷完該子樹所有葉后y坐標(biāo)可以返回子樹的根繼續(xù)繪制另一方向绿淋。
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
??createPlot是運(yùn)行繪制程序的函數(shù),在函數(shù)中獲得樹的總寬度和總深度尝盼,然后根據(jù)這兩個(gè)值計(jì)算xOff和yOff吞滞。
10.實(shí)現(xiàn)決策樹分類器
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else: classLabel = secondDict[key]
return classLabel
??我們需要實(shí)現(xiàn)分類器才能對(duì)外部數(shù)據(jù)進(jìn)行分類預(yù)測(cè)。這里的代碼還是老樣子盾沫,遞歸進(jìn)行匹配裁赠。
??在使用海洋生物數(shù)據(jù)建的樹進(jìn)行分類時(shí),輸入數(shù)據(jù)testVec應(yīng)為一個(gè)包含2個(gè)元素列表赴精,分別為不浮出水面是否可以生存{0,1}和是否有腳蹼{0,1}组贺。程序會(huì)在根結(jié)點(diǎn)進(jìn)行第1列特征的匹配,如果第1列特征值為1的話祖娘,進(jìn)入右側(cè)子樹根結(jié)點(diǎn)失尖,然后進(jìn)行第2列特征的匹配啊奄,最后返回分類結(jié)果。
11.決策樹的存儲(chǔ)
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb') # 以二進(jìn)制格式寫入
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename, 'rb') # 以二進(jìn)制格式讀取
return pickle.load(fr)
??和k-近鄰算法不同掀潮,決策樹建成后的模型可以作為數(shù)據(jù)存儲(chǔ)菇夸,不用每次重新計(jì)算。這里利用pickle模組仪吧,可以以二進(jìn)制格式存儲(chǔ)決策樹庄新。
12.使用決策樹預(yù)測(cè)隱形眼鏡類型
def createLensesTree:
import treePlotter
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabel = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabel)
print(lensesTree)
storeTree(lensesTree, 'lensesTree_classifierStorage.txt')
treePlotter.createPlot(lensesTree)
??首先導(dǎo)入繪制樹的程序,因?yàn)閿?shù)據(jù)中每個(gè)特征用制表符分割薯鼠,所以用split方法提取每個(gè)特征择诈,并用strip方法去除行首尾的空格,推導(dǎo)出二維列表出皇。
??手動(dòng)定義標(biāo)簽列表羞芍,然后使用createTree函數(shù)建樹,輸出樹后可以看到郊艘,哈希表層數(shù)過多已經(jīng)很難分清邏輯關(guān)系荷科。
??用treePlotter.createPlot函數(shù)繪制樹后如下圖。
??可以看到繪制圖形后纱注,邏輯關(guān)系可讀性大大提升畏浆。這里我們還是用了storeTree函數(shù)將樹保存為2進(jìn)制文件。
lensesTree = grabTree('lensesTree_classifierStorage.txt')
??使用grabTree函數(shù)就可以從文件中讀取樹模型狞贱。
參考