Dataset download: http://archive.ics.uci.edu/ml/datasets/Adult
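I ran into a problem while writing this algorithm:
When building the decision tree, two ways of looping over the branches both run without errors, yet the trees they produce differ somewhat.
With the first approach, which iterates directly over branch_dict, traversal always starts from the right branch, even when branch_dict is rewritten as {'right': right_split, 'left': left_split}.
With the second approach, traversal works as expected (left branch first, then right branch). So far I have not found the cause; if you know it, please leave a comment.
The code walkthrough follows: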
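One plausible cause, offered as an assumption because the two snippets themselves are not reproduced here: before Python 3.7 the language did not guarantee the iteration order of a plain dict, so looping over a dict could visit 'right' before 'left' no matter how the literal was written. A minimal sketch of an order-safe loop follows; the string values are placeholders standing in for the real DataFrame subsets.

```python
# Placeholder values stand in for the left/right DataFrame subsets.
left_split = "rows <= median"
right_split = "rows > median"

# A list of (name, subset) pairs has a fixed order on every Python version,
# unlike a plain dict on interpreters older than Python 3.7.
branches = [("left", left_split), ("right", right_split)]

visited = []
for name, subset in branches:
    visited.append(name)

print(visited)  # ['left', 'right']
```

Keeping the dict and looping over an explicit tuple of keys such as ('left', 'right') would give the same guarantee.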
Reading the data
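import numpy as np  # used by the entropy and median calculations further down
import pandas as pd
columns = ['age', 'workclass', 'fnlwgt', 'education', 'education_num',
           'marital_status', 'occupation', 'relationship', 'race', 'sex',
           'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
           'high_income']
# The file has no header row, so the column names are supplied explicitly
data = pd.read_table('./data/income.data', delimiter=',', names=columns)
data.head()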
Before building the decision tree, the categorical columns in the dataset need to be converted to numeric values. pandas.Categorical converts a string column to the Categorical type; once converted, each category in the column is automatically mapped to an integer code.
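# Named categorical_columns rather than list, to avoid shadowing the built-in
categorical_columns = ['workclass', 'education', 'marital_status', 'occupation',
                       'relationship', 'race', 'sex', 'native_country', 'high_income']
for name in categorical_columns:
    # Categorical.from_array is gone from recent pandas; the constructor does the same job
    col = pd.Categorical(data[name])
    data[name] = col.codes
data.head()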
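To see what the conversion does, here is a small sketch on a made-up Series (the values are illustrative, not rows from the dataset). It shows that the categories are sorted and that each code is an index into that sorted list, which is also how a code can be mapped back to its original string.

```python
import pandas as pd

# Illustrative values only; not taken from the Adult dataset.
toy = pd.Series(["Private", "State-gov", "Private", "Self-emp"])
cat = pd.Categorical(toy)

# Categories are stored in sorted order; each code indexes into them.
mapping = dict(enumerate(cat.categories))
codes = list(cat.codes)

print(mapping)
print(codes)
```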
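Computing entropy and information gain:
Entropy: $H(T) = -\sum_{i} p_i \log_2 p_i$, where $p_i$ is the fraction of rows in $T$ that belong to class $i$.
Information gain: $IG(T, A) = H(T) - \sum_{v} \frac{|T_v|}{|T|} H(T_v)$, where $T_v$ is the subset of $T$ on side $v$ of the split on column $A$ (here: at most the median, or above it).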
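def calc_entropy(target):
    # Count the occurrences of each class label (labels are non-negative integers)
    counts = np.bincount(target)
    # Drop empty classes so that 0 * log2(0) does not turn the result into NaN
    probabilities = counts[counts > 0] / len(target)
    return -np.sum(probabilities * np.log2(probabilities))
def calc_information_gain(data, split_name, target):
    entropy = calc_entropy(data[target])
    # Split the rows at the median of the candidate column
    median = np.median(data[split_name])
    left_split = data[data[split_name] <= median]
    right_split = data[data[split_name] > median]
    to_subtract = 0
    for subset in [left_split, right_split]:
        # Weight each subset's entropy by its share of the rows
        prob = subset.shape[0] / data.shape[0]
        to_subtract += prob * calc_entropy(subset[target])
    return entropy - to_subtract
# Compute the information gain of every column and return the best splitting attribute (the one with the largest gain)
def find_best_column(data, columns, target):
    information_gains = []
    for name in columns:
        # Use the target argument instead of a hard-coded 'high_income'
        information_gains.append(calc_information_gain(data, name, target))
    information_index = information_gains.index(max(information_gains))
    best_column = columns[information_index]
    return best_column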
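As a quick sanity check of these definitions, the sketch below works through a four-row toy table in which a median split on age separates the two classes perfectly. The table and the helper name `entropy` are made up for illustration; the helper follows the same formula as calc_entropy.

```python
import numpy as np
import pandas as pd

def entropy(labels):
    """Shannon entropy (base 2) of a column of non-negative integer labels."""
    counts = np.bincount(labels)
    probs = counts[counts > 0] / len(labels)
    return float(-np.sum(probs * np.log2(probs)))

# Toy data: the two youngest rows are class 0, the two oldest are class 1.
toy = pd.DataFrame({
    "age": [20, 25, 40, 50],
    "high_income": [0, 0, 1, 1],
})

parent = entropy(toy["high_income"])   # labels are split 50/50
median = toy["age"].median()
left = toy[toy["age"] <= median]
right = toy[toy["age"] > median]

# Weighted entropy of the two subsets after the split.
weighted = sum(len(s) / len(toy) * entropy(s["high_income"]) for s in (left, right))
gain = parent - weighted

print(parent, gain)
```

Both subsets are pure after the split, so the weighted entropy is zero and the gain equals the parent entropy of one bit.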
ID3 with tree storage:
To store the tree, each node can be kept in a dictionary with the following keys:
- left/right: the left and right child nodes
- column: the best splitting attribute
- median: the median of the splitting attribute
- number: the node number
- label: the class label
A leaf node contains only the label key (value 0 or 1) and the number key.
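For concreteness, here is a hand-written sketch of what such a dictionary might look like for a tree with one split and two leaves. The values and the helper `is_leaf` are made up for illustration and are not output of the algorithm.

```python
# A hand-written tree with one internal node and two leaves (values are made up).
example_tree = {
    "number": 1,
    "column": "age",
    "median": 37.0,
    "left": {"number": 2, "label": 0},
    "right": {"number": 3, "label": 1},
}

def is_leaf(node):
    """A node is a leaf exactly when it carries a 'label' key."""
    return "label" in node

print(is_leaf(example_tree), is_leaf(example_tree["left"]))
```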
Pseudocode:
def id3(data, columns, target, tree)
    Create a node for the tree
    Number the node
    If all of the values of the target attribute are 1, assign 1 to the label key in tree
    If all of the values of the target attribute are 0, assign 0 to the label key in tree
    Using information gain, find A, the column that splits the data best
    Find the median value in column A
    Assign the column and median keys in tree
    Split A into values less than or equal to the median (0), and values above the median (1)
    For each possible value (0 or 1), vi, of A,
        Add a new tree branch below Root that corresponds to rows of data where A = vi
        Let Examples(vi) be the subset of examples that have the value vi for A
        Create a new key with the name corresponding to the side of the split (0=left, 1=right). The value of this key should be an empty dictionary.
        Below this new branch, add the subtree id3(data[A==vi], columns, target, tree[split_side])
    Return Root
Implementation:
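tree = {}
nodes = []  # Important: a plain int cannot be incremented across recursive calls without `global`, so a list serves as the counter.
def id3(data, columns, target, tree):
    nodes.append(len(nodes) + 1)
    tree['number'] = nodes[-1]
    unique_targets = pd.unique(data[target])
    if len(unique_targets) == 1:
        tree['label'] = unique_targets[0]
        return  # do not forget to return
    # More than one unique value (both 0 and 1 are present), so the node must be split:
    best_column = find_best_column(data, columns, target)
    median = np.median(data[best_column])
    left_split = data[data[best_column] <= median]
    right_split = data[data[best_column] > median]
    if right_split.empty:
        # Added guard: the split separates nothing, so stop with the majority label instead of recursing forever
        tree['label'] = int(np.bincount(data[target]).argmax())
        return
    tree['column'] = best_column  # splitting attribute key
    tree['median'] = median  # median key
    branch_dict = {'left': left_split, 'right': right_split}
    # Loop over explicit key names: plain dict order is not guaranteed before Python 3.7
    for branch in ('left', 'right'):
        tree[branch] = {}
        id3(branch_dict[branch], columns, target, tree[branch])
id3(data, ["age", "marital_status"], "high_income", tree)
print(tree)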
Running this prints the tree as a nested dictionary (output not shown).
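The list works as a counter because the function mutates it instead of rebinding a name. As an alternative sketch (not the method used in this post), `itertools.count` can hand out the same left-before-right numbering without a module-level list. The function `number_nodes` and the toy tree are hypothetical names used only for illustration.

```python
import itertools

def number_nodes(tree, counter=None):
    """Assign preorder numbers (node, then left, then right) to a tree dict."""
    if counter is None:
        counter = itertools.count(1)
    tree["number"] = next(counter)
    for side in ("left", "right"):
        if side in tree:
            number_nodes(tree[side], counter)
    return tree

# A made-up tree with two internal nodes and three leaves.
toy_tree = {
    "column": "age", "median": 37.0,
    "left": {"label": 0},
    "right": {
        "column": "age", "median": 50.0,
        "left": {"label": 1},
        "right": {"label": 0},
    },
}

number_nodes(toy_tree)
print(toy_tree["number"], toy_tree["left"]["number"], toy_tree["right"]["number"])
```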
To make the resulting tree easier to inspect, we can write a node-printing function that outputs the tree in a structured, indented form:
def print_with_depth(string, depth):
    # Indent each line by four spaces per tree level
    prefix = "    " * depth
    print("{0}{1}".format(prefix, string))
def print_node(tree, depth):
    if 'label' in tree:
        # Leaf node: print its label and stop descending
        print_with_depth('Leaf: Label {0}'.format(tree['label']), depth)
        return
    # Internal node: print the split condition, then both subtrees
    print_with_depth('{0} > {1}'.format(tree['column'], tree['median']), depth)
    branches = [tree["left"], tree["right"]]
    for branch in branches:
        print_node(branch, depth + 1)
print_node(tree, 0)
This prints one line per node, indented by depth (output not shown).
Implementing prediction:
# Prediction function
def predict(tree, row):
    if 'label' in tree:
        return tree['label']
    column = tree['column']
    median = tree['median']
    # Look up the row's value in the split column (the variable, not the literal string 'columns')
    if row[column] <= median:
        return predict(tree['left'], row)
    else:
        return predict(tree['right'], row)
print(predict(tree, data.iloc[0]))
# Predict every row of the DataFrame
predictions = data.apply(lambda x: predict(tree, x), axis=1)
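Once there is one prediction per row, a natural next step is to compare the predictions with the known labels. The sketch below does this on a made-up tree and a four-row table; `predict_row` is a hypothetical iterative equivalent of the prediction function, included only so the example runs on its own.

```python
import pandas as pd

def predict_row(tree, row):
    """Walk down the tree until a leaf is reached, then return its label."""
    while "label" not in tree:
        side = "left" if row[tree["column"]] <= tree["median"] else "right"
        tree = tree[side]
    return tree["label"]

# Made-up tree and data, for illustration only.
toy_tree = {
    "column": "age", "median": 37.0,
    "left": {"label": 0},
    "right": {"label": 1},
}
toy = pd.DataFrame({"age": [20, 30, 45, 60], "high_income": [0, 0, 1, 0]})

preds = toy.apply(lambda r: predict_row(toy_tree, r), axis=1)
# Fraction of rows where the prediction matches the true label.
accuracy = (preds == toy["high_income"]).mean()
print(accuracy)
```

Accuracy measured on the same rows that built the tree is optimistic; a held-out split gives a fairer estimate.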
The end.