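Decision Tree Visualization and Pruning
1. Visualizing the decision tree
In the previous section (Implementing a Decision Tree by Hand (Part 1)), we used print_tree to print the decision tree as text: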
3:21?
T->0:google?
  T->{'Premium': 3}
  F->{'Basic': 3}
F->2:yes?
  T->0:slashdot?
    T->{'None': 2}
    F->{'Basic': 3}
  F->{'None': 4}
Next, we show how to render the decision tree as an image.
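from PIL import Image, ImageDraw

def draw_tree(tree, jpeg='tree.jpeg'):
    # Canvas size: 100 px per leaf horizontally, 100 px per level plus a margin
    w = get_width(tree) * 100
    h = get_depth(tree) * 100 + 120
    img = Image.new('RGB', (w, h), color=(255, 255, 255))
    draw = ImageDraw.Draw(img)
    # Draw from the root, centered horizontally near the top of the canvas
    draw_node(draw, tree, w / 2, 20)
    img.save(jpeg, 'JPEG')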
The helper function get_width, which counts the leaves under a node, is as follows:
def get_width(tree):
    if tree.tb is None and tree.fb is None:
        return 1
    return get_width(tree.tb) + get_width(tree.fb)
The helper function get_depth, which returns the length of the longest path from a node down to a leaf, is as follows:
def get_depth(tree):
    if tree.tb is None and tree.fb is None:
        return 0
    # Depth is one more than the deeper of the two branches
    return max(get_depth(tree.tb), get_depth(tree.fb)) + 1
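As a quick check of the two helpers, the sketch below rebuilds the tree from the text output above and computes its width, depth, and the canvas size that draw_tree would derive from them. The Node class is a hypothetical stand-in for the node type from the previous section; it is only assumed to have the attributes col, value, results, tb, and fb that the code in this section relies on.

```python
class Node:
    """Hypothetical stand-in for the tree node class from the previous section."""
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col = col
        self.value = value
        self.results = results  # dict of label counts for leaves, None otherwise
        self.tb = tb            # branch taken when the condition is true
        self.fb = fb            # branch taken when the condition is false


def get_width(tree):
    # A leaf occupies one unit of width; an inner node spans both branches
    if tree.tb is None and tree.fb is None:
        return 1
    return get_width(tree.tb) + get_width(tree.fb)


def get_depth(tree):
    # A leaf has depth 0; an inner node is one level above its deeper branch
    if tree.tb is None and tree.fb is None:
        return 0
    return max(get_depth(tree.tb), get_depth(tree.fb)) + 1


# The tree printed by print_tree at the start of this section
tree = Node(col=3, value=21,
            tb=Node(col=0, value='google',
                    tb=Node(results={'Premium': 3}),
                    fb=Node(results={'Basic': 3})),
            fb=Node(col=2, value='yes',
                    tb=Node(col=0, value='slashdot',
                            tb=Node(results={'None': 2}),
                            fb=Node(results={'Basic': 3})),
                    fb=Node(results={'None': 4})))

width = get_width(tree)
depth = get_depth(tree)
canvas = (width * 100, depth * 100 + 120)  # same arithmetic as draw_tree
print(width, depth, canvas)
```

The tree has five leaves and its longest path has three edges, so the canvas comes out at 500 by 420 pixels.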
The helper function draw_node, which draws a node and then recurses into its branches, is as follows:
def draw_node(draw, tree, x, y):
    if tree.results is None:
        # Get the width of each branch
        w1 = get_width(tree.fb) * 100
        w2 = get_width(tree.tb) * 100
        # Determine the total horizontal space this node occupies
        left = x - (w1 + w2) / 2
        right = x + (w1 + w2) / 2
        # Draw the condition string
        draw.text((x - 20, y - 10), str(tree.col) + ":" + str(tree.value), (0, 0, 0))
        # Draw the lines to the two branches
        draw.line((x, y, left + w1 / 2, y + 100), fill=(255, 0, 0))
        draw.line((x, y, right - w2 / 2, y + 100), fill=(255, 0, 0))
        # Draw the branch nodes (false branch on the left, true branch on the right)
        draw_node(draw, tree.fb, left + w1 / 2, y + 100)
        draw_node(draw, tree.tb, right - w2 / 2, y + 100)
    else:
        # Leaf node: draw one "label:count" line per result
        txt = ' \n'.join(['%s:%d' % v for v in tree.results.items()])
        draw.text((x - 20, y), txt, (0, 0, 0))
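To see where draw_node places each node without producing an image, the following sketch repeats its coordinate arithmetic and collects the positions in a dictionary instead of drawing. The Node class and the layout function are hypothetical names introduced only for this illustration; the tree is the one from the text output at the start of the section.

```python
class Node:
    """Hypothetical stand-in for the tree node class from the previous section."""
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col, self.value, self.results = col, value, results
        self.tb, self.fb = tb, fb


def get_width(tree):
    if tree.tb is None and tree.fb is None:
        return 1
    return get_width(tree.tb) + get_width(tree.fb)


def layout(tree, x, y, out=None):
    """Return {label: (x, y)} using the same arithmetic as draw_node."""
    if out is None:
        out = {}
    if tree.results is None:
        w1 = get_width(tree.fb) * 100   # width reserved for the false branch
        w2 = get_width(tree.tb) * 100   # width reserved for the true branch
        left = x - (w1 + w2) / 2
        right = x + (w1 + w2) / 2
        out['%s:%s' % (tree.col, tree.value)] = (x, y)
        layout(tree.fb, left + w1 / 2, y + 100, out)    # false branch, left side
        layout(tree.tb, right - w2 / 2, y + 100, out)   # true branch, right side
    else:
        out[str(tree.results)] = (x, y)
    return out


tree = Node(col=3, value=21,
            tb=Node(col=0, value='google',
                    tb=Node(results={'Premium': 3}),
                    fb=Node(results={'Basic': 3})),
            fb=Node(col=2, value='yes',
                    tb=Node(col=0, value='slashdot',
                            tb=Node(results={'None': 2}),
                            fb=Node(results={'Basic': 3})),
                    fb=Node(results={'None': 4})))

# draw_tree starts the root at (w / 2, 20), where w = get_width(tree) * 100
positions = layout(tree, get_width(tree) * 100 / 2, 20)
for label, (px, py) in positions.items():
    print(label, px, py)
```

Each subtree is centered within a horizontal strip proportional to its number of leaves, so sibling subtrees never overlap, and every level sits 100 pixels below its parent.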
The resulting image is:
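2. Pruning the decision tree
To avoid overfitting, the decision tree needs to be pruned: if splitting a node into its two leaf children yields an information gain smaller than a given threshold, the split is not kept and the children are merged back into a single leaf.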
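The gain test can be tried on its own with the leaf pairs from the tree above. This is a minimal sketch: unique_counts and entropy are assumed to behave like the functions from the previous section (label counts over the last column, Shannon entropy in bits), and merge_gain is a hypothetical helper that applies the same formula as the pruning code.

```python
from math import log2


def unique_counts(rows):
    """Count how often each label (the last column of a row) occurs."""
    counts = {}
    for row in rows:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    return counts


def entropy(rows):
    """Shannon entropy, in bits, of the label distribution in rows."""
    total = len(rows)
    ent = 0.0
    for count in unique_counts(rows).values():
        p = count / total
        ent -= p * log2(p)
    return ent


def merge_gain(tb_results, fb_results):
    """Entropy reduction obtained by keeping two leaves separate."""
    # Expand the label counts back into one-column rows
    tb = [[label] for label, n in tb_results.items() for _ in range(n)]
    fb = [[label] for label, n in fb_results.items() for _ in range(n)]
    return entropy(tb + fb) - (entropy(tb) + entropy(fb)) / 2


gain_google = merge_gain({'Premium': 3}, {'Basic': 3})
gain_slashdot = merge_gain({'None': 2}, {'Basic': 3})
print(gain_google, gain_slashdot)
```

Both leaf pairs are pure, so the gain equals the entropy of the merged set: exactly 1 bit for the 3/3 split and about 0.971 bits for the 2/3 split. With a threshold of 1.0, only the second pair would be merged.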
def prune(tree, min_gain):
    # If a branch is not a leaf, prune it recursively
    if tree.tb.results is None:
        prune(tree.tb, min_gain)
    if tree.fb.results is None:
        prune(tree.fb, min_gain)
    # If both sub-branches are now leaves, decide whether to merge them
    if tree.tb.results is not None and tree.fb.results is not None:
        # Build the combined dataset
        tb, fb = [], []
        for v, c in tree.tb.results.items():
            tb += [[v]] * c
        for v, c in tree.fb.results.items():
            fb += [[v]] * c
        # Test the reduction in entropy
        delta = entropy(tb + fb) - (entropy(tb) + entropy(fb)) / 2
        if delta < min_gain:
            # Merge the branches
            tree.tb, tree.fb = None, None
            tree.results = unique_counts(tb + fb)
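The sketch below runs the pruning procedure on the tree from section 1 with min_gain = 1.0. It is self-contained for illustration: Node is a hypothetical stand-in for the node class from the previous section, and unique_counts and entropy are assumed to match the earlier definitions.

```python
from math import log2


class Node:
    """Hypothetical stand-in for the tree node class from the previous section."""
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col, self.value, self.results = col, value, results
        self.tb, self.fb = tb, fb


def unique_counts(rows):
    counts = {}
    for row in rows:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    return counts


def entropy(rows):
    total = len(rows)
    return -sum(c / total * log2(c / total) for c in unique_counts(rows).values())


def prune(tree, min_gain):
    # Prune non-leaf branches first, so merges propagate bottom-up
    if tree.tb.results is None:
        prune(tree.tb, min_gain)
    if tree.fb.results is None:
        prune(tree.fb, min_gain)
    # Merge two leaf children when keeping them apart gains too little
    if tree.tb.results is not None and tree.fb.results is not None:
        tb = [[v] for v, c in tree.tb.results.items() for _ in range(c)]
        fb = [[v] for v, c in tree.fb.results.items() for _ in range(c)]
        delta = entropy(tb + fb) - (entropy(tb) + entropy(fb)) / 2
        if delta < min_gain:
            tree.tb, tree.fb = None, None
            tree.results = unique_counts(tb + fb)


tree = Node(col=3, value=21,
            tb=Node(col=0, value='google',
                    tb=Node(results={'Premium': 3}),
                    fb=Node(results={'Basic': 3})),
            fb=Node(col=2, value='yes',
                    tb=Node(col=0, value='slashdot',
                            tb=Node(results={'None': 2}),
                            fb=Node(results={'Basic': 3})),
                    fb=Node(results={'None': 4})))

prune(tree, 1.0)
print(tree.fb.results)   # the whole false branch is now a single leaf
print(tree.tb.results)   # the google split is kept, so this is still None
```

With this threshold the slashdot split (gain of about 0.971) is merged first, which turns both children of the 2:yes? node into leaves; their gain of about 0.433 is also below the threshold, so they are merged as well. The 0:google? split has a gain of exactly 1.0, which is not less than min_gain, so it survives. Note that prune assumes the node passed in is not itself a leaf.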