更多干貨就在我的個(gè)人博客 http://blackblog.tech 歡迎關(guān)注晶伦!
決策樹(shù):構(gòu)建一個(gè)基于屬性的樹(shù)形分類(lèi)器渠鸽。
1.每個(gè)非葉節(jié)點(diǎn)表示一個(gè)特征屬性上的測(cè)試(分割)窟坐,
2.每個(gè)分支代表這個(gè)特征屬性在某個(gè)值域上的輸出丁逝,
3.每個(gè)葉節(jié)點(diǎn)存放一個(gè)類(lèi)別椰于。
使用決策樹(shù)進(jìn)行決策的過(guò)程就是從根節(jié)點(diǎn)開(kāi)始惊搏,測(cè)試待分類(lèi)項(xiàng)中相應(yīng)的特征屬性贮乳,并按照其值選擇輸出分支,直到到達(dá)葉子節(jié)點(diǎn)恬惯,將葉子節(jié)點(diǎn)存放的類(lèi)別作為決策結(jié)果向拆。
采用遞歸的方法進(jìn)行建樹(shù)
遞歸的結(jié)束條件
1.當(dāng)前結(jié)點(diǎn)樣本均屬于同一類(lèi)別,無(wú)需劃分酪耳。
2.當(dāng)前屬性集為空浓恳。
3.所有樣本在當(dāng)前屬性集上取值相同,無(wú)法劃分碗暗。
4.當(dāng)前結(jié)點(diǎn)包含的樣本集合為空颈将,不能劃分。
決策樹(shù)的核心
經(jīng)過(guò)屬性劃分后言疗,不同類(lèi)樣本被更好的分離
理想情況:劃分后樣本被完美分類(lèi)晴圾。即每個(gè)分支的樣本都屬性同一類(lèi)。
實(shí)際情況:不可能完美劃分洲守!盡量使得每個(gè)分支某一類(lèi)樣本比例盡量高疑务!即盡量提高劃分后子集的純度。
劃分的目標(biāo):提升劃分后子集的純度梗醇,降低劃分后子集的不純度
決策樹(shù)算法分類(lèi)
決策樹(shù)算法的區(qū)別主要在于所采用的純度判別標(biāo)準(zhǔn)
ID3算法:
使用信息增益作為判別標(biāo)準(zhǔn)
信息熵計(jì)算公式:
假設(shè)屬性??有??可能取值{??1,??2,??,??^??}, ????對(duì)應(yīng)劃分后的數(shù)據(jù)子集為????.
信息增益:
信息增益越大知允,說(shuō)明當(dāng)前的劃分效果越好
C4.5算法
使用信息增益率作為判別準(zhǔn)則
????(??)稱(chēng)為屬性??的“固有值”(Intrinsic Value)
信息增益率越大,說(shuō)明當(dāng)前劃分效果越好
CART算法
使用基尼系數(shù)作為判別準(zhǔn)則
實(shí)驗(yàn)環(huán)境
python3.6
macOS 10.12
代碼思路
BuildTree函數(shù):在該函數(shù)中完成遞歸建樹(shù)叙谨,遞歸返回條件的判斷温鸽,建立存儲(chǔ)樹(shù)所用的字典,打印各類(lèi)信息
ChooseAttr函數(shù):在該函數(shù)中完成選出最佳特征的功能,根據(jù)Ent函數(shù)計(jì)算出的所有樣本的信息熵和加權(quán)的信息熵計(jì)算信息增益涤垫,信息增越大的意味著該屬性的純度越高姑尺,選取信息增益最大的屬性為最佳屬性。
Ent函數(shù):計(jì)算輸入樣本的信息熵蝠猬,通過(guò)輸入Sample的最后一列統(tǒng)計(jì)出該正例與反例出現(xiàn)的概率切蟋,根據(jù)信息熵公式計(jì)算信息熵
SpiltData函數(shù):該函數(shù)用于對(duì)數(shù)據(jù)進(jìn)行拆分,去掉已經(jīng)判斷過(guò)的屬性對(duì)應(yīng)的樣本
CreatePlot函數(shù):用于決策樹(shù)的可視化
數(shù)據(jù)集
使用西瓜書(shū)上的西瓜數(shù)據(jù)集2.0
為了方便計(jì)算榆芦,將西瓜數(shù)據(jù)集的內(nèi)容轉(zhuǎn)換為數(shù)字
色澤: 0:青綠 1:烏黑 2:淺白
根底: 0:蜷縮 1:少蜷 2:硬挺
敲聲: 0:濁響 1:沉悶 2:清脆
紋理: 0:清晰 1:稍糊 2:模糊
臍部: 0:凹陷 1:稍凹 2:平坦
觸感: 0:硬滑 1:軟黏
好瓜: 0:不是 1:是
上代碼
import math
import numpy
import DrawTree
數(shù)據(jù)集,屬性列表
#初始化一個(gè)屬性列表
AttrArr=["色澤","根蒂","敲聲","紋理","臍部","觸感","好瓜"]
#此處使用西瓜數(shù)據(jù)集2.0
data = numpy.array(
[[0,0,0,0,0,0,1],
[1,0,1,0,0,0,1],
[1,0,0,0,0,0,1],
[0,0,1,0,0,0,1],
[2,0,0,0,0,0,1],
[0,1,0,0,1,1,1],
[1,1,0,1,1,1,1],
[1,1,0,0,1,0,1],
[1,1,1,1,1,0,0],
[0,2,2,0,2,1,0],
[2,2,2,2,2,0,0],
[2,0,0,2,2,1,0],
[0,1,0,1,0,0,0],
[2,1,1,1,0,0,0],
[1,1,0,0,1,1,0],
[2,0,0,2,2,0,0],
[0,0,1,1,1,0,0]]
)
BuildTree函數(shù):在該函數(shù)中完成遞歸建樹(shù)柄粹,遞歸返回條件的判斷,建立存儲(chǔ)樹(shù)所用的字典匆绣,打印各類(lèi)信息
#建樹(shù)的函數(shù)
def BuildTree(Sample,Label):
#Sample 為輸入的數(shù)據(jù)
#Label 為對(duì)應(yīng)的標(biāo)簽
#獲取輸入數(shù)據(jù)的的大小
[Count, Attr] = Sample.shape;
n = Attr - 1;
m = Count - 1;
print("Sample:")
print(Sample)
#使用classlist存儲(chǔ)表示正例與反例所在的列
classList = Sample[:, n];
# 記錄第一個(gè)類(lèi)中的個(gè)數(shù)
classOne = 1;
for i in range(1, Count):
if (classList[i] == classList[0]):
classOne = classOne + 1;
#如果當(dāng)前結(jié)點(diǎn)包含的樣本全屬于同一個(gè)樣本驻右,則停止劃分
if (classOne == Count):
print("Final")
print(Sample)
if(classList[0]==0):return "no" #通過(guò) classlist 的 0 1 判斷 最終的結(jié)果
if (classList[0] == 1): return "yes" #通過(guò) classlist 的 0 1 判斷 最終的結(jié)果
#如果當(dāng)前屬性集為空,無(wú)法劃分
if (Attr == 0):
print("Final")
print(Sample)
return classList[0]
#使用ChooseAttr函數(shù)獲取最佳的特征對(duì)應(yīng)編號(hào)
bestAttr = ChooseAttr(Sample)
#通過(guò)最佳特征的編號(hào)獲得標(biāo)簽名
name=Label[bestAttr]
#新建一個(gè)字典用于存儲(chǔ)樹(shù)
Tree = {name: {}}
#打印出最佳特征
print("最佳特征:", name);
#取出對(duì)最佳屬性對(duì)應(yīng)的一列 并去掉重復(fù)值 用于得出一個(gè)屬性下所包含的取值
featValue = numpy.unique(Sample[:, bestAttr])
#計(jì)算出一個(gè)屬性下包含的取值
numOfFeatValue = len(featValue);
#最佳屬性的每一個(gè)評(píng)級(jí)都打印出來(lái)
for i in range(0, numOfFeatValue):
print(name, "評(píng)級(jí):", featValue[I])
subLabels = Label[:]
#對(duì)現(xiàn)有的樹(shù)執(zhí)行 SpiltData 去掉計(jì)算過(guò)的屬性所對(duì)應(yīng)的樣本 遞歸調(diào)用buildTree
Tree[name][i] = BuildTree(SpiltData(Sample, bestAttr, featValue[i]),subLabels)
print('-------------------------');
return Tree
ChooseAttr函數(shù):在該函數(shù)中完成選出最佳特征的功能崎淳,根據(jù)Ent函數(shù)計(jì)算出的所有樣本的信息熵和加權(quán)的信息熵計(jì)算信息增益堪夭,信息增越大的意味著該屬性的純度越高,選取信息增益最大的屬性為最佳屬性拣凹。
#Choose函數(shù)用于 選出最佳的屬性
def ChooseAttr(Sample):
#Sample 為輸入的數(shù)據(jù)
#獲取輸入數(shù)據(jù)的大小
[Count, Attr] = Sample.shape
numOfFeature = Attr - 1
#計(jì)算整個(gè)數(shù)據(jù)的信息熵
baseEnt = Ent(Sample)
#初始信息增益
bestInfoGain = 0.0
#初始的最佳屬性為 -1
bestFeature = -1
#遍歷當(dāng)前所有屬性
for j in range(0, numOfFeature):
#記錄出每一個(gè)屬性中的取值 并去掉重復(fù)值
featureTemp = numpy.unique(Sample[:, j])
#記錄下屬性取值的個(gè)數(shù)
numF = len(featureTemp)
newEnt = 0.0;
#遍歷所有的取值
for i in range(0, numF):
#去除掉當(dāng)前已經(jīng)判斷的樣本
subSet = SpiltData(Sample, j, featureTemp[I])
#得到每一個(gè)取值的個(gè)數(shù)
[newCount, newAttr] = subSet.shape
#計(jì)算每一個(gè)取值出現(xiàn)的概率
prob = newCount / Count
#計(jì)算新的信息熵
newEnt = newEnt + prob * Ent(subSet)
#計(jì)算信息增益
infoGain = baseEnt - newEnt
#找到信息增益最大的屬性 作為當(dāng)前最佳屬性
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = j
return bestFeature
Ent函數(shù):計(jì)算輸入樣本的信息熵森爽,通過(guò)輸入Sample的最后一列統(tǒng)計(jì)出該正例與反例出現(xiàn)的概率,根據(jù)信息熵公式計(jì)算信息熵
#Ent函數(shù)用于計(jì)算信息熵
def Ent(Sample):
#Sample為輸入的數(shù)據(jù)
#得到輸入數(shù)據(jù)的大小
[Count, Attr] = Sample.shape
n = Attr - 1
m = Count - 1
#獲取正例與反例所在的列
label = Sample[:, n]
#去掉重復(fù)的數(shù)據(jù)
deal = numpy.unique(label)
#得到最后判別情況的個(gè)數(shù)
numOfLabel = len(deal)
#新建一個(gè)概率list 用于存儲(chǔ)概率
prob = numpy.zeros([numOfLabel, 2])
for i in range(0, numOfLabel):
#獲取正例與反例
prob[i, 0] = deal[I]
for j in range(0, Count):
#對(duì)正例 與 反例 進(jìn)行計(jì)數(shù)
if (label[j] == deal[I]):
prob[i, 1] = prob[i, 1] + 1
#計(jì)算出概率
prob[:, 1] = prob[:, 1] / Count
ent = 0
#根據(jù)信息熵公示計(jì)算出信息熵
for i in range(0, numOfLabel):
ent = ent - prob[i, 1] * math.log2(prob[i, 1])
return ent
SpiltData函數(shù):該函數(shù)用于對(duì)數(shù)據(jù)進(jìn)行拆分咐鹤,去掉已經(jīng)判斷過(guò)的屬性對(duì)應(yīng)的樣本
#對(duì)數(shù)據(jù)進(jìn)行拆分 去掉已經(jīng)判斷過(guò)的屬性所對(duì)應(yīng)的樣本
def SpiltData(Sample, axis, value):
#Sample 代表輸入的數(shù)據(jù)
#axis 表示要?jiǎng)h除值所在的行
#value表示要?jiǎng)h除的值
[Count, Attr] = Sample.shape
subSet = Sample
k = 0
#對(duì)每一個(gè)樣本都做判斷 把已經(jīng)做過(guò)判斷的樣本刪掉
for i in range(0, Count):
if (Sample[i, axis]) != value:
subSet=numpy.delete(subSet,i-k,0)
k = k + 1
return subSet
主函數(shù)部分
TreeDict=BuildTree(data,AttrArr)
實(shí)驗(yàn)結(jié)果
存儲(chǔ)樹(shù)的字典:
{'紋理': {0: {'根蒂': {0: 'yes', 1: {'色澤': {0: 'yes', 1: {'觸感': {0: 'yes', 1: 'no'}}}}, 2: 'no'}}, 1: {'觸感': {0: 'no', 1: 'yes'}}, 2: 'no'}}
遞歸的過(guò)程:
Sample:
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]
[0 1 0 0 1 1 1]
[1 1 0 1 1 1 1]
[1 1 0 0 1 0 1]
[1 1 1 1 1 0 0]
[0 2 2 0 2 1 0]
[2 2 2 2 2 0 0]
[2 0 0 2 2 1 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[1 1 0 0 1 1 0]
[2 0 0 2 2 0 0]
[0 0 1 1 1 0 0]]
最佳特征: 紋理
紋理 評(píng)級(jí): 0
Sample:
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]
[0 1 0 0 1 1 1]
[1 1 0 0 1 0 1]
[0 2 2 0 2 1 0]
[1 1 0 0 1 1 0]]
最佳特征: 根蒂
根蒂 評(píng)級(jí): 0
Sample:
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]]
Final
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]]
根蒂 評(píng)級(jí): 1
Sample:
[[0 1 0 0 1 1 1]
[1 1 0 0 1 0 1]
[1 1 0 0 1 1 0]]
最佳特征: 色澤
色澤 評(píng)級(jí): 0
Sample:
[[0 1 0 0 1 1 1]]
Final
[[0 1 0 0 1 1 1]]
色澤 評(píng)級(jí): 1
Sample:
[[1 1 0 0 1 0 1]
[1 1 0 0 1 1 0]]
最佳特征: 觸感
觸感 評(píng)級(jí): 0
Sample:
[[1 1 0 0 1 0 1]]
Final
[[1 1 0 0 1 0 1]]
觸感 評(píng)級(jí): 1
Sample:
[[1 1 0 0 1 1 0]]
Final
[[1 1 0 0 1 1 0]]
根蒂 評(píng)級(jí): 2
Sample:
[[0 2 2 0 2 1 0]]
Final
[[0 2 2 0 2 1 0]]
紋理 評(píng)級(jí): 1
Sample:
[[1 1 0 1 1 1 1]
[1 1 1 1 1 0 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[0 0 1 1 1 0 0]]
最佳特征: 觸感
觸感 評(píng)級(jí): 0
Sample:
[[1 1 1 1 1 0 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[0 0 1 1 1 0 0]]
Final
[[1 1 1 1 1 0 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[0 0 1 1 1 0 0]]
觸感 評(píng)級(jí): 1
Sample:
[[1 1 0 1 1 1 1]]
Final
[[1 1 0 1 1 1 1]]
紋理 評(píng)級(jí): 2
Sample:
[[2 2 2 2 2 0 0]
[2 0 0 2 2 1 0]
[2 0 0 2 2 0 0]]
Final
[[2 2 2 2 2 0 0]
[2 0 0 2 2 1 0]
[2 0 0 2 2 0 0]]
樹(shù)的可視化
此處使用的是網(wǎng)上一個(gè)常見(jiàn)的可視化代碼
可視化函數(shù)
#繪制樹(shù)形圖
import matplotlib
# matplotlib.use('qt4agg')
from matplotlib.font_manager import *
import matplotlib.pyplot as plt
myfont = FontProperties(fname='/Users/zhangxuancheng/Library/Fonts/simhei.ttf')
decision_node = dict(boxstyle="sawtooth",fc="0.8")
leaf_node = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
plt.rcParams['font.sans-serif'] = ['SimHei']
#獲取樹(shù)的葉子結(jié)點(diǎn)個(gè)數(shù)(確定圖的寬度)
def get_leaf_num(tree):
leaf_num = 0
first_key = list(tree.keys())[0]
next_dict = tree[first_key]
for key in next_dict.keys():
if type(next_dict[key]).__name__=="dict":
leaf_num +=get_leaf_num(next_dict[key])
else:
leaf_num +=1
return leaf_num
#獲取數(shù)的深度(確定圖的高度)
def get_tree_depth(tree):
depth = 0
first_key = list(tree.keys())[0]
next_dict = tree[first_key]
for key in next_dict.keys():
if type(next_dict[key]).__name__ == "dict":
thisdepth = 1+ get_tree_depth(next_dict[key])
else:
thisdepth = 1
if thisdepth>depth: depth = thisdepth
return depth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
#在父子節(jié)點(diǎn)間填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = get_leaf_num(myTree)
depth = get_tree_depth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decision_node)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leaf_node)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(get_leaf_num(inTree))
plotTree.totalD = float(get_tree_depth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
調(diào)用該函數(shù)
DrawTree.createPlot(TreeDict)