# -*- coding: utf-8 -*-
"""
Created on Fri Jul 13 16:00:57 2018
"""
#coding:utf-8
from math import log
class DecisonTree:
? ? trainData = []
? ? trainLabel = []
? ? featureValus = {} #每個(gè)特征所有可能的取值
? ? def __init__(self, trainData, trainLabel, threshold):
? ? ? ? self.loadData(trainData, trainLabel)
? ? ? ? self.threshold = threshold
? ? ? ? self.tree = self.createTree(range(0,len(trainLabel)), range(0,len(trainData[0])))
? ? #加載數(shù)據(jù)
? ? def loadData(self, trainData, trainLabel):
? ? ? ? if len(trainData) != len(trainLabel):
? ? ? ? ? ? raise ValueError('input error')
? ? ? ? self.trainData = trainData
? ? ? ? self.trainLabel = trainLabel
? ? ? ? #計(jì)算 featureValus
? ? ? ? for data in trainData:
? ? ? ? ? ? for index, value in enumerate(data):
? ? ? ? ? ? ? ? if not index in self.featureValus.keys():
? ? ? ? ? ? ? ? ? ? self.featureValus[index] = [value]
? ? ? ? ? ? ? ? if not value in self.featureValus[index]:
? ? ? ? ? ? ? ? ? ? self.featureValus[index].append(value)
? ? #計(jì)算信息熵
? ? def caculateEntropy(self, dataset):
? ? ? ? labelCount = self.labelCount(dataset)
? ? ? ? size = len(dataset)
? ? ? ? result = 0
? ? ? ? for i in labelCount.values():
? ? ? ? ? ? pi = i / float(size)
? ? ? ? ? ? result -= pi * (log(pi) /log(2))
? ? ? ? return result
? ? #計(jì)算信息增益
? ? def caculateGain(self, dataset, feature):
? ? ? ? values = self.featureValus[feature] #特征feature 所有可能的取值
? ? ? ? result = 0
? ? ? ? for v in values:
? ? ? ? ? ? subDataset = self.splitDataset(dataset=dataset, feature=feature, value=v)
? ? ? ? ? ? result += len(subDataset) / float(len(dataset)) * self.caculateEntropy(subDataset)
? ? ? ? return self.caculateEntropy(dataset=dataset) - result
? ? #計(jì)算數(shù)據(jù)集中缠犀,每個(gè)標(biāo)簽出現(xiàn)的次數(shù)
? ? def labelCount(self, dataset):
? ? ? ? labelCount = {}
? ? ? ? for i in dataset:
? ? ? ? ? ? if trainLabel[i] in labelCount.keys():
? ? ? ? ? ? ? ? labelCount[trainLabel[i]] += 1
? ? ? ? ? ? else:
? ? ? ? ? ? ? ? labelCount[trainLabel[i]] = 1
? ? ? ? return labelCount
? ? '''
? ? dataset:數(shù)據(jù)集
? ? features:特征集
? ? '''
? ? def createTree(self, dataset, features):
? ? ? ? labelCount = self.labelCount(dataset)
? ? ? ? #如果特征集為空敷待,則該樹為單節(jié)點(diǎn)樹
? ? ? ? #計(jì)算數(shù)據(jù)集中出現(xiàn)次數(shù)最多的標(biāo)簽
? ? ? ? if not features:
? ? ? ? ? ? return max(list(labelCount.items()),key = lambda x:x[1])[0]
? ? ? ? #如果數(shù)據(jù)集中,只包同一種標(biāo)簽壶硅,則該樹為單節(jié)點(diǎn)樹
? ? ? ? if len(labelCount) == 1:
? ? ? ? ? ? return labelCount.keys()[0]
? ? ? ? #計(jì)算特征集中每個(gè)特征的信息增益
? ? ? ? l = map(lambda x : [x, self.caculateGain(dataset=dataset, feature=x)], features)
? ? ? ? #選取信息增益最大的特征
? ? ? ? feature, gain = max(l, key = lambda x: x[1])
? ? ? ? #如果最大信息增益小于閾值威兜,則該樹為單節(jié)點(diǎn)樹
? ? ? ? #
? ? ? ? if self.threshold > gain:
? ? ? ? ? ? return max(list(labelCount.items()),key = lambda x:x[1])[0]
? ? ? ? tree = {}
? ? ? ? #選取特征子集
? ? ? ? subFeatures = filter(lambda x : x != feature, features)
? ? ? ? tree['feature'] = feature
? ? ? ? #構(gòu)建子樹
? ? ? ? for value in self.featureValus[feature]:
? ? ? ? ? ? subDataset = self.splitDataset(dataset=dataset, feature=feature, value=value)
? ? ? ? ? ? #保證子數(shù)據(jù)集非空
? ? ? ? ? ? if not subDataset:
? ? ? ? ? ? ? ? continue
? ? ? ? ? ? tree[value] = self.createTree(dataset=subDataset, features=subFeatures)
? ? ? ? return tree
? ? def splitDataset(self, dataset, feature, value):
? ? ? ? reslut = []
? ? ? ? for index in dataset:
? ? ? ? ? ? if self.trainData[index][feature] == value:
? ? ? ? ? ? ? ? reslut.append(index)
? ? ? ? return reslut
? ? def classify(self, data):
? ? ? ? def f(tree, data):
? ? ? ? ? ? if type(tree) != dict:
? ? ? ? ? ? ? ? return tree
? ? ? ? ? ? else:
? ? ? ? ? ? ? ? return f(tree[data[tree['feature']]], data)
? ? ? ? return f(self.tree, data)
? ? if __name__ == '__main__':
? ? trainData = [
? ? ? ? [0, 0, 0, 0],
? ? ? ? [0, 0, 0, 1],
? ? ? ? [0, 1, 0, 1],
? ? ? ? [0, 1, 1, 0],
? ? ? ? [0, 0, 0, 0],
? ? ? ? [1, 0, 0, 0],
? ? ? ? [1, 0, 0, 1],
? ? ? ? [1, 1, 1, 1],
? ? ? ? [1, 0, 1, 2],
? ? ? ? [1, 0, 1, 2],
? ? ? ? [2, 0, 1, 2],
? ? ? ? [2, 0, 1, 1],
? ? ? ? [2, 1, 0, 1],
? ? ? ? [2, 1, 0, 2],
? ? ? ? [2, 0, 0, 0],
? ? ]
? ? trainLabel = [0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0]
? ? tree = DecisonTree(trainData=trainData, trainLabel=trainLabel, threshold=0)
? ? print tree.tree