CART構(gòu)建與剪枝

上周實(shí)現(xiàn)了離散變量的決策樹(shù)的構(gòu)建(ID3算法)比规,它的做法是每次選取當(dāng)前最佳的特征來(lái)分割數(shù)據(jù)调衰,并按照該特征所有的可能值來(lái)切分养葵。也就是說(shuō)心剥,如果一個(gè)特征有4種取值邦尊,那么數(shù)據(jù)被切分成4份,一旦按某特征切分后优烧,便固定死了蝉揍,該特征在之后的算法執(zhí)行過(guò)程中將不會(huì)再起作用,顯然畦娄,這種切分方式過(guò)于迅速又沾。而此外弊仪,ID3算法不能直接處理連續(xù)型特征。
再補(bǔ)充一下用ID3算法生成決策樹(shù)的圖例杖刷。
我們的例子是李航的《統(tǒng)計(jì)學(xué)習(xí)方法》第五章的表5.1撼短,根據(jù)該表生成決策樹(shù),在已知年齡挺勿、有工作曲横、有自己房子、信貸情況的情況下判斷是否給貸款.


圖1 貸款申請(qǐng)樣本數(shù)據(jù)表

用ID3算法生成的決策樹(shù)如下(畫(huà)圖的程序?qū)崿F(xiàn)在最后不瓶,參照的是Peter Harrington的《機(jī)器學(xué)習(xí)實(shí)戰(zhàn)》):


圖2 ID3算法生成的貸款決策樹(shù)

效果很明顯禾嫉,從雜亂無(wú)章的15條記錄中提取出這么精辟的決策樹(shù),有了這棵決策樹(shù)便很輕易的可以判斷該不該給某人貸款蚊丐,如果他有房子熙参,就給貸,如果沒(méi)有麦备,但他有工作孽椰,也給貸,如果都沒(méi)有凛篙,就不給貸黍匾。比表5.1精簡(jiǎn)有效多了昼激。
再來(lái)看一個(gè)例子转锈,周志華的《機(jī)器學(xué)習(xí)》的判斷是否為好瓜的數(shù)據(jù):
圖3 判斷是否為好瓜

判斷一個(gè)西瓜可以從色澤,根蒂,敲聲,紋理,臍部,觸感6個(gè)特征去判斷,每個(gè)特征都有2-3個(gè)值嗦篱,用ID3算法生成的決策樹(shù)如下:
圖4 ID3算法生成是否為好瓜的決策樹(shù)

這里一個(gè)節(jié)點(diǎn)可以有2個(gè)以上的分支填物,取決于每個(gè)特征的所有可能值纹腌。這樣也使一團(tuán)雜亂無(wú)章的數(shù)據(jù)有了個(gè)很清晰的決策樹(shù)。

**總結(jié):
ID3算法可以使離散的問(wèn)題清晰簡(jiǎn)單化滞磺,但也有兩點(diǎn)局限:

  1. 切分過(guò)于迅速
  2. 不能直接處理連續(xù)型特征**
    如遇到連續(xù)變化的特征或者特征可能值很多的情況下升薯,算法得出的效果并不理想而且沒(méi)有多大用處。大多數(shù)情況下击困,生成決策樹(shù)的目的是用來(lái)分類(lèi)的涎劈。

這周,生成決策樹(shù)的算法是CART算法沛励,不像ID3算法责语,它是一種二元切分法,具體處理方法:如果特征值大于給定值就走左子樹(shù)目派,否則就走右子樹(shù)坤候。解決了ID3算法的局限,但同時(shí)企蹭,如果用來(lái)分類(lèi)白筹,生成的決策樹(shù)容易太貪心智末,滿(mǎn)足了大部分訓(xùn)練數(shù)據(jù),出現(xiàn)過(guò)擬合徒河。為提高泛化能力系馆,需對(duì)其
進(jìn)行剪枝,把某些節(jié)點(diǎn)塌陷成一類(lèi)顽照。
在本文由蘑,構(gòu)建CART的實(shí)現(xiàn)算法有兩種(程序在最后)
一種是Peter Harrington的《機(jī)器學(xué)習(xí)實(shí)戰(zhàn)》的對(duì)連續(xù)數(shù)據(jù)的構(gòu)建算法,核心方法(選取最優(yōu)特征)的偽代碼如下:
遍歷每個(gè)特征:
遍歷每個(gè)特征值:
將數(shù)據(jù)切分成兩份
計(jì)算切分的誤差
如果當(dāng)前誤差小于當(dāng)前最小誤差:
更新當(dāng)前最小誤差
更新當(dāng)前最優(yōu)特征和最優(yōu)切分點(diǎn)
返回最優(yōu)切分特征和最優(yōu)切分點(diǎn)

一種是李航的《統(tǒng)計(jì)學(xué)習(xí)方法》的用基尼指數(shù)構(gòu)建的算法代兵,程序是自己實(shí)現(xiàn)的尼酿,目前只能針對(duì)離散性數(shù)據(jù),核心方法的偽代碼如下:
遍歷每個(gè)特征:
遍歷每個(gè)特征值:
將數(shù)據(jù)切分成兩份
計(jì)算切分的基尼指數(shù)
如果基尼指數(shù)小于當(dāng)前基尼指數(shù):
更新當(dāng)前基尼指數(shù)
更新當(dāng)前最優(yōu)特征和最優(yōu)切分點(diǎn)
返回最優(yōu)切分特征和最優(yōu)切分點(diǎn)

只是把誤差計(jì)算方式變成了基尼指數(shù)植影,其他基本類(lèi)似裳擎。

對(duì)前面兩例用CART算法生成的決策樹(shù)如下:


圖5 CART算法生成的貸款決策樹(shù)

圖6 CART算法生成的是否好瓜決策樹(shù)

圖5和圖2是一樣的,因?yàn)橛脕?lái)切分的特征都只有兩類(lèi)
但圖6和圖4便不一樣思币。

再來(lái)對(duì)連續(xù)的數(shù)據(jù)構(gòu)建決策樹(shù)鹿响,數(shù)據(jù)來(lái)自于Peter Harrington的《機(jī)器學(xué)習(xí)實(shí)戰(zhàn)》的第九章ex0.txt


圖7 ex0.txt

肉眼可以分辨,整段數(shù)據(jù)可分為5段惶我,用CART算法生成的結(jié)果如下:

{'spInd': 0, 'spVal': 0.39434999999999998, 'left': {'spInd': 0, 'spVal': 0.58200200000000002, 'left': {'spInd': 0, 'spVal': 0.79758300000000004, 'left': 3.9871631999999999, 'right': 2.9836209534883724}, 'right': 1.980035071428571}, 'right': {'spInd': 0, 'spVal': 0.19783400000000001, 'left': 1.0289583666666666, 'right': -0.023838155555555553}}

(實(shí)在不想畫(huà)圖了,就用dict表示吧,spInd表示當(dāng)前分割特征,spVal表示當(dāng)前分割值,left表示坐子節(jié)點(diǎn)爬骤,right表示右子節(jié)點(diǎn))
從dict中也明顯可以看到,它將數(shù)據(jù)分成5段,但這個(gè)前提是ops=(1,4)選的好,對(duì)樹(shù)進(jìn)行預(yù)剪枝了。

如果ops=(0.1,0.4)會(huì)發(fā)生什么呢?

{'spInd': 0, 'spVal': 0.39434999999999998, 'left': {'spInd': 0, 'spVal': 0.58200200000000002, 'left': {'spInd': 0, 'spVal': 0.79758300000000004, 'left': {'spInd': 0, 'spVal': 0.81900600000000001, 'left': {'spInd': 0, 'spVal': 0.83269300000000002, 'left': 3.9814298333333347, 'right': {'spInd': 0, 'spVal': 0.81913599999999998, 'left': 4.5692899999999996, 'right': 4.048082}}, 'right': 3.7688410000000001}, 'right': {'spInd': 0, 'spVal': 0.62039299999999997, 'left': {'spInd': 0, 'spVal': 0.62261599999999995, 'left': 2.9787170277777779, 'right': 2.6702779999999997}, 'right': {'spInd': 0, 'spVal': 0.61605100000000002, 'left': 3.5225040000000001, 'right': 3.0497069999999997}}}, 'right': {'spInd': 0, 'spVal': 0.48669800000000002, 'left': {'spInd': 0, 'spVal': 0.53324099999999997, 'left': {'spInd': 0, 'spVal': 0.55900899999999998, 'left': 2.0720909999999999, 'right': 1.8145387500000001}, 'right': 2.0843065555555551}, 'right': 1.8810897500000001}}, 'right': {'spInd': 0, 'spVal': 0.19783400000000001, 'left': {'spInd': 0, 'spVal': 0.21054200000000001, 'left': {'spInd': 0, 'spVal': 0.37526999999999999, 'left': 1.2040690000000001, 'right': {'spInd': 0, 'spVal': 0.316465, 'left': 0.86561450000000006, 'right': {'spInd': 0, 'spVal': 0.23417499999999999, 'left': 1.1113766363636364, 'right': 0.90613224999999997}}}, 'right': 1.3753635000000002}, 'right': {'spInd': 0, 'spVal': 0.14865400000000001, 'left': 0.071894545454545447, 'right': {'spInd': 0, 'spVal': 0.14314299999999999, 'left': -0.27792149999999999, 'right': -0.040866062499999994}}}}

顯然州胳,過(guò)擬合了瓤湘。生成了很多不必要的節(jié)點(diǎn)。在實(shí)際應(yīng)用中,根本不能控制數(shù)據(jù)值得大小,所以ops也很難選好,而ops的選擇對(duì)結(jié)果的影響很大窒朋。所以?xún)H僅預(yù)剪枝是遠(yuǎn)遠(yuǎn)不夠的抵赢。

于是需要后剪枝。簡(jiǎn)單來(lái)說(shuō),就是選擇ops,使得構(gòu)建出的樹(shù)足夠大,接下來(lái)從上而下找到葉節(jié)點(diǎn)赠尾,用測(cè)試集的數(shù)據(jù)來(lái)判斷這些葉節(jié)點(diǎn)是否能降低測(cè)試誤差杉编,如果能光酣,就合并救军,偽代碼如下:
基于已有的樹(shù)切分測(cè)試數(shù)據(jù):
如果存在任一子集是一棵樹(shù)财异,則在該子集遞歸剪枝過(guò)程
計(jì)算當(dāng)前兩個(gè)葉節(jié)點(diǎn)合并后的誤差
計(jì)算合并前的誤差
如果合并后的誤差小于合并前的誤差:
將兩個(gè)葉節(jié)點(diǎn)合并

對(duì)上述決策樹(shù)進(jìn)行剪枝,由于沒(méi)有測(cè)試數(shù)據(jù)唱遭,便拿前150當(dāng)作訓(xùn)練數(shù)據(jù)戳寸,后50當(dāng)作測(cè)試數(shù)據(jù),圖如下:


圖8 ex0.txt訓(xùn)練數(shù)據(jù)和測(cè)試數(shù)據(jù)

同樣拷泽,ops=(0.1,0.4),剪枝后的樹(shù)為:
{'spInd': 0, 'spVal': 0.39434999999999998, 'left': {'spInd': 0, 'spVal': 0.58028299999999999, 'left': {'spInd': 0, 'spVal': 0.79758300000000004, 'left': 3.9739993000000005, 'right': 3.0065657575757574}, 'right': 1.9667640539772728}, 'right': {'spInd': 0, 'spVal': 0.19783400000000001, 'left': 1.0753531944444445, 'right': -0.028014558823529413}}

由那么復(fù)雜的樹(shù)剪枝剪成只有五個(gè)類(lèi)別疫鹊。效果不錯(cuò)

實(shí)現(xiàn)代碼如下:

treePlotter.py

'''
Created on 2017年7月30日

@author: fujianfei
'''

import matplotlib.pyplot as plt


plt.rcParams['font.sans-serif']=['SimHei']#解約matplotlib畫(huà)圖,中文亂碼問(wèn)題

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):
    numLeafs = 0
    firstSides = list(myTree.keys()) 
    firstStr = firstSides[0]#找到輸入的第一個(gè)元素
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstSides = list(myTree.keys()) 
    firstStr = firstSides[0]#找到輸入的第一個(gè)元素
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
    
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    #depth = getTreeDepth(myTree)
    firstSides = list(myTree.keys()) 
    firstStr = firstSides[0]#找到輸入的第一個(gè)元素
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes   
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

#def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()

# def retrieveTree(i):
#     listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
#                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
#                   ]
#     return listOfTrees[i]

#createPlot(thisTree)

CARTTree.py

'''
Created on 2017年8月2日

@author: Administrator
'''
import operator


class TreeNode(object):
    '''
    .樹(shù)節(jié)點(diǎn)的定義:
    '''


    def __init__(self, feat=None, val=None, left=None, right=None):
        '''
        featureToSpliton:該節(jié)點(diǎn)對(duì)應(yīng)的特征司致,比如'年齡'
        ValToSplit:由特征分類(lèi)后的值拆吆,比如'青年','中年','老年'
        leftBranch:左分支
        rightBranch:右分支
        '''
        self.feat = feat
        self.val = val
        self.left = left
        self.right = right
        
      
def calcGini(dataSet):
    '''
    .計(jì)算訓(xùn)練數(shù)據(jù)的預(yù)測(cè)誤差,在這里用基尼指數(shù)
    '''  
    num = len(dataSet)#數(shù)據(jù)集的行數(shù)脂矫,即有幾個(gè)樣本點(diǎn)
    labelCounts = {}#提取總共有多少標(biāo)簽并計(jì)數(shù)
    for featVec in dataSet:#遍歷數(shù)據(jù)集
        label = featVec[-1]#提取標(biāo)簽
        if label not in labelCounts.keys():#如果標(biāo)簽不再labelCount里
            labelCounts[label] = 0#那么在字典labelCount里建一對(duì)字典 key=label,value=0
        labelCounts[label] += 1#對(duì)key=label的字典 的 value加1枣耀,計(jì)數(shù)
    
    gini = 0.0 #定義1-基尼指數(shù)
    
    for key in labelCounts.keys():
        prop = float(labelCounts[key])/num #計(jì)算每個(gè)類(lèi)別的概率
        gini += prop ** 2 #每個(gè)類(lèi)別概率的平方相加,賦值給gini
    return 1-gini#1-概率平方之和庭再,即為基尼指數(shù)

def splitDataSet(dataSet, featAndVal):
    '''
    .分割數(shù)據(jù)集捞奕,根據(jù)特征feat(比如年齡)和特征對(duì)應(yīng)的某個(gè)值val(比如青年)
    .將數(shù)據(jù)dataSet分割為兩部分:青年的數(shù)據(jù)集sub_dateSet1,非青年的數(shù)據(jù)集sub_dateSet2,并返回兩個(gè)子數(shù)據(jù)集
    .返回的子數(shù)據(jù)集可用來(lái)計(jì)算條件基尼指數(shù)佩微,Gini(D,A)
    '''
    sub_dateSet1 = []
    sub_dateSet2 = []
    for featVec in dataSet:
        if featVec[featAndVal[0]] == featAndVal[1]:
            reduceDataSet = featVec[:featAndVal[0]]
            reduceDataSet.extend(featVec[featAndVal[0]+1:])
            sub_dateSet1.append(reduceDataSet)
        else:
            reduceDataSet = featVec[:featAndVal[0]]
            reduceDataSet.extend(featVec[featAndVal[0]+1:])
            sub_dateSet2.append(reduceDataSet)    
    return sub_dateSet1,sub_dateSet2   

def chooseBestFeatAndCuttingpoint(dataSet):
    '''
    .遍歷數(shù)據(jù)集找到最小的基尼指數(shù)缝彬,選擇最優(yōu)特征與最優(yōu)切分點(diǎn)
    '''
    bestFeatAndCuttingpoint = [-1,-1]#定義優(yōu)特征和最優(yōu)切分點(diǎn)
    min_gini = float("inf")#定義最小基尼指數(shù)
    numFeat = len(dataSet[0]) - 1#特征數(shù)
    numData = len(dataSet)#樣本數(shù)
    for i in range(numFeat):#遍歷所有特征
        featList = [example[i] for example in dataSet]
        uniqueFeat = set(featList)
        for value in uniqueFeat:#遍歷所有可能的切分點(diǎn)
            #把樣本集合D根據(jù)特征A是否取某一可能值a被分割成D1和D2兩部分
            subdata1,subdata2 =  splitDataSet(dataSet, [i,value])
            #計(jì)算在特征A,切分點(diǎn)a的條件下,集合D的基尼指數(shù)
            tmp_gini = (float(len(subdata1))/numData) * calcGini(subdata1) + (float(len(subdata2))/numData) * calcGini(subdata2)
            if tmp_gini < min_gini:
                min_gini = tmp_gini
                bestFeatAndCuttingpoint = [i,value]

    return bestFeatAndCuttingpoint

def majorityCnt(classList):
    '''
    .多數(shù)投票表決哺眯,有時(shí)候會(huì)遇到數(shù)據(jù)集已經(jīng)處理了所有的屬性
    .但是類(lèi)標(biāo)簽還不是唯一的谷浅,這個(gè)時(shí)候便用該方法確定該葉子節(jié)點(diǎn)的分類(lèi)
    '''
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reserve=True)
    return sortedClassCount[0][0]

def creatCART(dataSet, labels):
    '''
    .用數(shù)據(jù)字典結(jié)構(gòu)存儲(chǔ)樹(shù)
    .后續(xù)的CART樹(shù)剪枝就用這種結(jié)構(gòu)
    '''
    classList = [example[-1] for example in dataSet]#dataSet的最后一列,類(lèi)別列
    #結(jié)束遞歸的條件:
    #1.類(lèi)別完全相同
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    #2.分類(lèi)到了最后一個(gè)節(jié)點(diǎn)奶卓,用多數(shù)投票決定類(lèi)別
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    #選擇最優(yōu)特征和最優(yōu)切分點(diǎn)
    bestFeatAndCuttingpoint = chooseBestFeatAndCuttingpoint(dataSet)
    bestFeatLabel = labels[bestFeatAndCuttingpoint[0]]#特征對(duì)應(yīng)的標(biāo)簽
    mytree = {bestFeatLabel:{}}#定義樹(shù)一疯,用字典類(lèi)型的結(jié)構(gòu)就足以表示決策樹(shù)全部的信息
    del(labels[bestFeatAndCuttingpoint[0]])#將用過(guò)的標(biāo)簽刪除
    sub_dataSet1,sub_dataSet2 = splitDataSet(dataSet, bestFeatAndCuttingpoint)#分割成D1和D2
    subLabels = labels[:]#去掉用過(guò)后的標(biāo)簽
    mytree[bestFeatLabel]['是'] = creatCART(sub_dataSet1, subLabels)#符合val的Branch,即D1
    mytree[bestFeatLabel]['否'] = creatCART(sub_dataSet2, subLabels)#不符合val的Branch夺姑,即D2
    return mytree 

class CART(object):
    '''
    .用特殊類(lèi)型結(jié)構(gòu)存儲(chǔ)樹(shù)墩邀,自己建的TreeNode,樹(shù)節(jié)點(diǎn)形式的結(jié)構(gòu)
    .這種結(jié)構(gòu)還不完善盏浙,沒(méi)有去調(diào)式
    '''
    def __init__(self,data=None):
        def creatNode(dataSet=None, bestFeatAndCuttingpoint=None):
            gini = calcGini(dataSet)
            #遞歸停止條件:樣本個(gè)數(shù)小于預(yù)定閾值眉睹,或樣本集的基尼指數(shù)小于預(yù)定閾值,或這沒(méi)有更多特征
            if len(dataSet) <=0 or gini <=0.0001 or len(dataSet[0]) <=0:
                return None
            #選擇最優(yōu)特征和最優(yōu)切分點(diǎn)
            sub_dataSet1,sub_dataSet2 = splitDataSet(dataSet, bestFeatAndCuttingpoint)#分割成D1和D2
            return TreeNode(bestFeatAndCuttingpoint[0], bestFeatAndCuttingpoint[1], creatNode(sub_dataSet1,chooseBestFeatAndCuttingpoint(sub_dataSet1)), creatNode(sub_dataSet2,chooseBestFeatAndCuttingpoint(sub_dataSet2)))
        self.root = creatNode(data, chooseBestFeatAndCuttingpoint(data))  
        
             
def preOrder(root):
    '''
    .樹(shù)的前序遍歷
    '''
    print(root.feat)
    if root.left:
        preOrder(root.left)
    if root.right:
        preOrder(root.right)

regTrees.py

'''
Created on 2017年8月5日

@author: fujianfei
'''
import numpy as np
from os.path import os 
import matplotlib.pyplot as plt


def loadDataSet(fileName):
    '''
    .導(dǎo)入數(shù)據(jù)
    '''
    data_path = os.getcwd()+'\\data\\'
    dataMat = np.loadtxt(data_path+fileName,delimiter='\t')
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    '''
    .將數(shù)據(jù)根據(jù)特征和值分成兩部分废膘,一部分為大于value的數(shù)據(jù)集mat0竹海,一部分為小于等于Value的數(shù)據(jù)集mat1
    '''
#     print(np.nonzero((dataSet[:,feature] > value)))
    mat0 = dataSet[np.nonzero((dataSet[:,feature] > value)),:][0]
    mat1 = dataSet[np.nonzero((dataSet[:,feature] <= value)),:][0]
    return mat0,mat1

def regLeaf(dataSet):
    return np.mean(dataSet[:,-1])

def regErr(dataSet):
    return np.var(dataSet[:,-1]) * len(dataSet)


def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0];tolN = ops[1]
    if len(set(dataSet[:,-1].tolist())) == 1:
        return None, leafType(dataSet)
    n = len(dataSet[0])
    S = errType(dataSet)
    bestS = float('inf'); bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):#遍歷所有特征
        for splitVal in set(dataSet[:,featIndex]):#遍歷所有確定特征的值
            mat0,mat1 = binSplitDataSet(dataSet, featIndex, splitVal)#將數(shù)據(jù)分成兩部分
            if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):continue
            newS = errType(mat0) + errType(mat1)#計(jì)算分成兩部分后的數(shù)據(jù)的方差之和
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if(S-bestS) < tolS:
        return None, leafType(dataSet)
    mat0,mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue
        

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat == None: return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

def istree(obj):
    return (type(obj).__name__ == 'dict');

def getMean(tree):
    '''
    .計(jì)算樹(shù)的平均值
    '''
    if istree(tree['left']) : return getMean(tree['left'])
    if istree(tree['right']) : return getMean(tree['right'])
    return (tree['left']+tree['right'])/2.0
    
def prune(tree, testDate):
    '''
    .剪枝
    '''
    if len(testDate) == 0 : return getMean(tree)
    if istree(tree['left']) or istree(tree['right']):
        lSet, rSet = binSplitDataSet(testDate, tree['spInd'], tree['spVal'])
    if istree(tree['left']) : tree['left'] = prune(tree['left'], lSet)
    if istree(tree['right']) : tree['right'] = prune(tree['right'], rSet)
    if (not istree(tree['left'])) and (not istree(tree['right'])):
        lSet, rSet = binSplitDataSet(testDate, tree['spInd'], tree['spVal'])
        #剪枝前的誤差
        erroNoMerge = np.sum(np.power(lSet[:,-1]-tree['left'],2)) + np.sum(np.power(rSet[:,-1]-tree['right'],2))
        #剪枝后的誤差
        treeMean = (tree['left'] + tree['right'])/2.0
        erroMerge = np.sum(np.power(testDate[:,-1]-treeMean,2))
        #如果剪枝后的誤差小于剪枝前的誤差,則進(jìn)行剪枝
        if erroMerge < erroNoMerge:
            print('merging')
            return treeMean
        else : return tree
    else : return tree
    
        
dataSet = loadDataSet('ex0.txt')
dataSet1 = loadDataSet('ex0test.txt')
plt.subplot(121)
plt.scatter(dataSet[:,0], dataSet[:,1])
plt.subplot(122)
plt.scatter(dataSet1[:,0], dataSet1[:,1])

plt.show()
tree_ = createTree(dataSet,ops=(0.1,0.4))
tree_ = prune(tree_,dataSet1)
print(tree_)


init.py

from DecisionTree import trees,CARTTree,treePlotter,regTrees
from os.path import os 

# def createDataSet():
#     dataSet = [[1,1,'yes'],
#                [1,1,'yes'],
#                [1,0,'no'],
#                [0,1,'no'],
#                [0,1,'no']]
#     labels = ['no surfacing','flippers']
#     return dataSet,labels
 
# def createDataSet():
#     dataSet = [[1,2,2,3,'no'],
#                [1,2,2,2,'no'],
#                [1,1,2,2,'yes'],
#                [1,1,1,3,'yes'],
#                [1,2,2,3,'no'],
#                [2,2,2,3,'no'],
#                [2,2,2,2,'no'],
#                [2,1,1,2,'yes'],
#                [2,2,1,1,'yes'],
#                [2,2,1,1,'yes'],
#                [3,2,1,1,'yes'],
#                [3,2,1,2,'yes'],
#                [3,1,2,2,'yes'],
#                [3,1,2,1,'yes'],
#                [3,2,2,3,'no']]
#     labels = ['年齡','有工作','有自己房子','信貸情況']
#     return dataSet,labels


def loadData(fileName):
    dataSet = []
    data_path = os.getcwd()+'\\data\\'
    fr = open(data_path+fileName)
    for line in fr.readlines():
        curLine = line.strip().split(',')
        dataSet.append(curLine)
    return dataSet


dataSet = loadData('watermelon1.txt')
labels = ['色澤', '根蒂', '敲聲', '紋理', '臍部', '觸感']

# dataSet,labels = createDataSet()



mytree = trees.createTree(dataSet, labels)
# mytree = regTrees.createTree(dataSet)
# mytree = CARTTree.creatCART(dataSet, labels)
print(mytree)

# treePlotter.createPlot(mytree)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末丐黄,一起剝皮案震驚了整個(gè)濱河市斋配,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖艰争,帶你破解...
    沈念sama閱讀 216,744評(píng)論 6 502
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件坏瞄,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡甩卓,警方通過(guò)查閱死者的電腦和手機(jī)鸠匀,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,505評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門(mén),熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)猛频,“玉大人狮崩,你說(shuō)我怎么就攤上這事÷寡埃” “怎么了睦柴?”我有些...
    開(kāi)封第一講書(shū)人閱讀 163,105評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)毡熏。 經(jīng)常有香客問(wèn)我坦敌,道長(zhǎng),這世上最難降的妖魔是什么痢法? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,242評(píng)論 1 292
  • 正文 為了忘掉前任狱窘,我火速辦了婚禮,結(jié)果婚禮上财搁,老公的妹妹穿的比我還像新娘蘸炸。我一直安慰自己,他們只是感情好尖奔,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,269評(píng)論 6 389
  • 文/花漫 我一把揭開(kāi)白布搭儒。 她就那樣靜靜地躺著,像睡著了一般提茁。 火紅的嫁衣襯著肌膚如雪淹禾。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書(shū)人閱讀 51,215評(píng)論 1 299
  • 那天茴扁,我揣著相機(jī)與錄音铃岔,去河邊找鬼。 笑死峭火,一個(gè)胖子當(dāng)著我的面吹牛毁习,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播卖丸,決...
    沈念sama閱讀 40,096評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼蜓洪,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了坯苹?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書(shū)人閱讀 38,939評(píng)論 0 274
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤摇天,失蹤者是張志新(化名)和其女友劉穎粹湃,沒(méi)想到半個(gè)月后恐仑,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,354評(píng)論 1 311
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡为鳄,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,573評(píng)論 2 333
  • 正文 我和宋清朗相戀三年裳仆,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片孤钦。...
    茶點(diǎn)故事閱讀 39,745評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡歧斟,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出偏形,到底是詐尸還是另有隱情静袖,我是刑警寧澤,帶...
    沈念sama閱讀 35,448評(píng)論 5 344
  • 正文 年R本政府宣布俊扭,位于F島的核電站队橙,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏萨惑。R本人自食惡果不足惜捐康,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,048評(píng)論 3 327
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望庸蔼。 院中可真熱鬧解总,春花似錦、人聲如沸姐仅。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 31,683評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)萍嬉。三九已至乌昔,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間壤追,已是汗流浹背磕道。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 32,838評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留行冰,地道東北人溺蕉。 一個(gè)月前我還...
    沈念sama閱讀 47,776評(píng)論 2 369
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像悼做,于是被迫代替她去往敵國(guó)和親疯特。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,652評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容