標(biāo)簽:代碼實(shí)戰(zhàn)楔脯,經(jīng)過(guò)驗(yàn)證,sklearn.tree可視化老玛,機(jī)器學(xué)習(xí)淤年,決策樹,cart蜡豹,開箱即用
利用sklearn.treeimport DecisionTreeClassifier創(chuàng)建數(shù)據(jù)的決策樹,并可視化結(jié)果
[TOC]
前提
python包:pydotplus溉苛、numpy镜廉、sklearn∮拚剑可通過(guò)pip install安裝娇唯。
Graphviz,安裝參見(jiàn)“可視化樹”一節(jié)
TODO:實(shí)例化
from sklearn.treeimport DecisionTreeClassifier
dt = DecisionTreeClassifier(criterion='gini',# 分類用基尼寂玲,回歸用'entropy'
splitter='best',# 結(jié)點(diǎn)分裂方式塔插,'best' 與'random'
min_samples_leaf=20,# 葉子節(jié)點(diǎn)包含的最少樣本數(shù)
random_state=2020,# 隨機(jī)策略
class_weight='balanced',# balanced根據(jù)樣本數(shù)量自動(dòng)調(diào)整權(quán)重
presort=True,# 數(shù)據(jù)集小用True,否則用false拓哟,可以提高速度
)
參數(shù)可以參考:http://www.reibang.com/p/f0f41ad72e5f
其他cart相關(guān)http://d0evi1.com/sklearn/cart/
TODO: 訓(xùn)練數(shù)據(jù)準(zhǔn)備與訓(xùn)練/生成決策樹
- 訓(xùn)練數(shù)據(jù)要求
--feature是numpy.array想许,數(shù)據(jù)類型為numpy.float64
,第i行是第i條數(shù)據(jù),第j列是對(duì)應(yīng)數(shù)據(jù)的第j個(gè)特征
--label是numpy.array流纹,數(shù)據(jù)類型為numpy.int16
糜烹,只有一列,第i行是第i條數(shù)據(jù)漱凝。若想使用float的label疮蹦,參見(jiàn)本節(jié)尾部的”真實(shí)輸出y是float應(yīng)該怎么改“
# # 加載文件示例
# # loadDataSet和相應(yīng)數(shù)據(jù)參見(jiàn)https://github.com/Jack-Cherish/Machine-Learning
# train_filename = 'cart_train.txt'
# train_Data = loadDataSet(train_filename) # 前幾列是參數(shù),最后一列是目標(biāo)
# train_Mat = np.mat(train_Data)
# feature = train_Mat[:, list(range(train_Mat.shape[1] - 1))]
# label = train_Mat[:, -1].astype(np.int16) # float轉(zhuǎn)int
# 使用隨機(jī)數(shù)做演示
feature = np.random.rand(100, 4) # 100行 4列茸炒,0-1的隨機(jī)浮點(diǎn)數(shù)
label = np.random.randint(0, 3, size=(100, 1))# [0,3)的隨機(jī)整數(shù)愕乎,100行 1列
# TODO:訓(xùn)練
dt = dt.fit(feature, label)
- 真實(shí)輸出y是float應(yīng)該怎么改
當(dāng)前代碼中真實(shí)輸出要求int類型,即訓(xùn)練時(shí)的y必須是int類型壁公,因?yàn)閷?shí)例化時(shí)criterion用的是'gini'感论,如果是float,修改兩處代碼贮尖,一處是示例化時(shí)的criterion笛粘,一處是訓(xùn)練數(shù)據(jù)中不調(diào)用astype(np.int16) - 異常: ValueError: Unknown label type: 'continuous'
大概率是因?yàn)槭褂昧嘶嵯禂?shù)但y卻是float類型的,如果數(shù)據(jù)的y本就是float湿硝,則調(diào)整實(shí)例化的criterion為'entropy'薪前;如果y本是int,注意要轉(zhuǎn)換為int
TODO:可視化樹
- 【前提】安裝Graphviz
- 安裝
-- 安裝參考 https://www.cnblogs.com/shuodehaoa/p/8667045.html
-- dot語(yǔ)法关斜、安裝參考示括、中文配置 https://blog.csdn.net/codingstandards/article/details/83778386
-- 官網(wǎng)下載地址 https://graphviz.org/download/
-- 我的下載:win10, 64位,stable - dot -c :可能需要命令行執(zhí)行一下
"Graphviz安裝目錄/bin/dot.exe -c"
- 安裝
- 可視化方法參考自博客痢畜,但做不到開箱即用
from sklearn.treeimport export_graphviz
from sklearn.externals.siximport StringIO
import pydotplus
import os
# 配置Graphviz的dot的地址挽牢。可能配置好環(huán)境變量之后重啟os酌毡,就不需要這部分了话原。
path_graphviz ="D:/software/graphviz/Graphviz2.44.1/bin" # replace by your Graphviz bin path
os.environ["PATH"] += os.pathsep + path_graphviz
# 轉(zhuǎn)換格式,從而保存本地
dot_data = StringIO()
dot_data = export_graphviz(dt,out_file=None)
graph = pydotplus.graph_from_dot_data(dot_data)
# graph = pydotplus.graph_from_dot_data(dot_data.getvalue())# 見(jiàn)到這個(gè)寫法线衫,但是我本地嘗試是失敗的
# graph.write_pdf("tmp/tree.pdf")
graph.write_jpg("tmp/tree.jpg")
print('Visible tree plot saved.')
之后可以查看"tmp/tree.jpg"
【exception】pydotplus.graphviz.InvocationException: GraphViz's executables not found
- 確定graphviz已安裝凿可。
注意不是pip安裝的python包,而是官網(wǎng)下載資源后按官網(wǎng)說(shuō)明來(lái)安裝的程序 - 檢查代碼中是否path_graphviz配置有誤授账。
注意"D:/software/graphviz"
是安裝時(shí)選擇的安裝路徑枯跑,"D:/software/graphviz/Graphviz2.44.1"
是graphviz的安裝目錄,它有"bin","include","lib","share"四個(gè)文件夾白热,"D:/software/graphviz/Graphviz2.44.1/bin"
才是我們要配的路徑地址敛助。
參考:https://blog.csdn.net/weixin_36407399/article/details/87890230
【exception】Format: "jpg" not recognized. Use one of:
Use one of:
后面是空的,說(shuō)明需要cmd執(zhí)行一下dot -c
屋确。
“dot -c“ means: Configure plugins (Writes $prefix/lib/graphviz/config with available plugin information. Needs write privilege.)
如果是linux系統(tǒng)纳击,可能需要root權(quán)限sudo dot -c
如果沒(méi)有配環(huán)境變量续扔,就是D:\software\graphviz\Graphviz2.44.1\bin\dot.exe -c
參考:https://blog.csdn.net/qq_43166422/article/details/105540575
TODO:預(yù)測(cè)
pre=dt.predict(feature)
#TODO:評(píng)估
print(estimate(label.T.tolist()[0],pre.T.tolist())) # 參數(shù)必須是一維數(shù)組
評(píng)估結(jié)果(由于訓(xùn)練數(shù)據(jù)是隨機(jī)生成的,所以這里只是格式參考)
混淆矩陣:
true\pre 0 1 2
0 14 3 19 36
1 4 14 14 32
2 5 4 23 32
23 21 56
report:
precision recall f1-score support
0 0.61 0.39 0.47 36
1 0.67 0.44 0.53 32
2 0.41 0.72 0.52 32
micro avg 0.51 0.51 0.51 100
macro avg 0.56 0.52 0.51 100
weighted avg 0.56 0.51 0.51 100
評(píng)估函數(shù)
import numpy
from sklearn.metrics import classification_report
def estimate(true_label:[], predict_label:[],confu_row_prefix='\t',confu_coltitle_sep='\t'):
'''輸入true_label和predict_label 輸出評(píng)估結(jié)果'''
# TODO: 函數(shù)進(jìn)入條件
assert type(true_label)==type([])
assert type(predict_label)==type([])
assert len(true_label) == len(predict_label)
# TODO:初始化混淆矩陣
class_k=set(true_label)
class_k=list(class_k)
class_num = len(class_k)
assert class_num>=1
confu = numpy.zeros([class_num, class_num]) # confu[實(shí)際][預(yù)測(cè)]
k_v={class_k[i]:i for i in range(class_num)}
# TODO:遍歷每一行統(tǒng)計(jì)準(zhǔn)確率
for ix in range(0, len(true_label)): # 遍歷訓(xùn)練集的每行
confu[k_v[true_label[ix]]][k_v[predict_label[ix]]] += 1 # confu[實(shí)際][預(yù)測(cè)]
# TODO:混淆矩陣打印字符串
string='混淆矩陣:\n%10s'%'true\pre'
sum_col=numpy.sum(confu,axis=0)#按列相加
sum_row=numpy.sum(confu,axis=1)#按行相加
# title
for i in range(class_num):
string+='%10s'%(class_k[i])
string+='\n'
# data
for i in range(class_num):
string+='%10s'%(class_k[i])
for j in range(class_num):
string+='%10s'%(int(confu[i][j]))
string+='%10s'%(int(sum_row[i]))# support
string+='\n'
# support
string+='%10s'%''
for i in range(class_num):
string+='%10s'%(int(sum_col[i]))
string+='\n'
# TODO:準(zhǔn)確率评疗、召回率等評(píng)估結(jié)果
cr = classification_report(true_label,
predict_label)
string+='report:\n'
string+=str(cr)
return string
決策樹的一些認(rèn)知
- 一些方法:
- ID3:
- 增熵测砂。
- 會(huì)越分越細(xì),容易過(guò)擬合百匆,所以有C4.5
- C4.5
- 信息增益率(增熵要除以屬性熵)
- 需要對(duì)數(shù)據(jù)集進(jìn)行多次掃描砌些,算法效率相對(duì)較低
- CART:
- GINI指數(shù)。
- 同樣容易過(guò)擬合加匈。需剪枝存璃,對(duì)特別長(zhǎng)的樹直接剪掉。
- ID3:
一些要點(diǎn):
- 節(jié)點(diǎn)的分裂:一般當(dāng)一個(gè)節(jié)點(diǎn)所代表的屬性無(wú)法給出判斷時(shí)雕拼,則選擇將這一節(jié)點(diǎn)分成2個(gè)子節(jié)點(diǎn)(如不是二叉樹的情況會(huì)分成n個(gè)子節(jié)點(diǎn))
- 閾值的確定:選擇適當(dāng)?shù)拈撝凳沟梅诸愬e(cuò)誤率最小 (Training Error)纵东。
- 剪枝:預(yù)剪枝、后剪枝
基尼系數(shù)的計(jì)算: