Machine_Learning_2019_Task 9 繪製樹圖形
要求
利用 Python 結合 Matplotlib 繪製樹圖形
import matplotlib.pyplot as plt
# Box and arrow styles handed to matplotlib's annotate() by plotNode.
# fc="0.8" is a grey fill level for the node boxes.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}  # internal (decision) nodes
leafNode = {"boxstyle": "round4", "fc": "0.8"}  # leaf nodes
arrow_args = {"arrowstyle": "<-"}  # arrow between a node and its parent
# Count the leaf nodes of a nested-dict tree.
def getNumLeafs(myTree):
    """Return the number of leaf nodes in *myTree*.

    Only the first key of *myTree* is examined (the root). Its value is a
    dict mapping branch labels to either a subtree (another dict of the same
    shape) or a leaf value (anything that is not a dict).
    """
    numLeafs = 0
    # next(iter(...)) instead of keys()[0]: on Python 3 dict.keys() is a view
    # and cannot be indexed, so the old form raised TypeError.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        # A dict (including subclasses such as OrderedDict) is a subtree.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        # Anything else is a leaf.
        else:
            numLeafs += 1
    return numLeafs
# Number of levels in a nested-dict tree.
def getTreeDepth(myTree):
    """Return the depth of *myTree*, counted in branch levels below the root.

    Only the first key of *myTree* is examined (the root). Each step from a
    node to one of its children (subtree or leaf) adds one level; a root
    with no children has depth 0.
    """
    maxDepth = 0
    # next(iter(...)) instead of keys()[0], which raises TypeError on Python 3.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        # A dict (including subclasses) is a subtree: recurse into it.
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
# Draw a single node box.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw *nodeTxt* in a box styled by *nodeType*, centred at *centerPt*.

    An arrow styled by the module-level ``arrow_args`` joins the box to
    *parentPt*. Both points are in axes-fraction coordinates. Drawing goes to
    ``createPlot.ax1``, which createPlot must have set up beforehand.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
# Label the edge between a parent and a child node.
def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* at the midpoint between *cntrPt* and *parentPt*.

    The label is centred on that point and rotated 30 degrees. Drawing goes
    to ``createPlot.ax1``.
    """
    childX, childY = cntrPt[0], cntrPt[1]
    parentX, parentY = parentPt[0], parentPt[1]
    halfwayX = (parentX - childX) / 2.0 + childX
    halfwayY = (parentY - childY) / 2.0 + childY
    createPlot.ax1.text(halfwayX, halfwayY, txtString,
                        va="center", ha="center", rotation=30)
# Recursively draw a tree (or subtree).
def plotTree(myTree, parentPt, nodeTxt):
    """Draw *myTree* below *parentPt*, labelling the connecting edge *nodeTxt*.

    The layout state lives in function attributes that createPlot initialises:

    * ``plotTree.totalW``: leaf count of the whole tree.
    * ``plotTree.totalD``: depth of the whole tree.
    * ``plotTree.xOff``: x of the most recently placed leaf.
    * ``plotTree.yOff``: y of the level currently being drawn.

    All coordinates are axes fractions. The root key of *myTree* is drawn as
    a decision node. Each child is either recursed into (dict) or drawn as a
    leaf, and its branch key is written as the edge label.
    """
    # Width of this subtree, measured in leaves.
    numLeafs = getNumLeafs(myTree)
    # next(iter(...)) instead of keys()[0], which raises TypeError on Python 3.
    firstStr = next(iter(myTree))
    # Centre this decision node horizontally over the numLeafs leaf slots
    # that follow the last placed leaf.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move down one level for this node's children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # Subtree: recurse, with this node as the parent.
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf: advance one leaf slot to the right and draw it there.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # Restore the level so that later siblings of this subtree use the right y.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# Set up the figure and draw the whole tree.
def createPlot(inTree):
    """Draw *inTree* on a cleared, tick-less axes in figure 1, then show it.

    Stores the axes as ``createPlot.ax1`` and initialises the layout state
    that plotTree reads and updates (``totalW``, ``totalD``, ``xOff``,
    ``yOff``).
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Frameless axes with no ticks on either axis.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of 0, so the first leaf lands at 0.5 / totalW.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()