-
Notifications
You must be signed in to change notification settings - Fork 0
/
treeprint.py
104 lines (94 loc) · 3.76 KB
/
treeprint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 10:02:37 2017
@author: ltostrams
"""
from __future__ import print_function
"""
NOTE: You don't have to change this code! (but you can..)
A simple(?) tree visualizer for sklearn DecisionTreeClassifiers.
Based on suggestions from this thread: https://github.com/scikit-learn/scikit-learn/issues/6261
Adaptations by Lisa Tostrams, September 2017
"""
import operator
def tree_print(clf, attributeNames, classNames):
    """
    Print a fitted sklearn DecisionTreeClassifier as a left-to-right text tree.

    The tree itself is printed by ``_tree_rprint``; afterwards a horizontal
    ruler ``<---...--->`` and the tree depth are printed.

    Parameters
    ----------
    clf : DecisionTreeClassifier - A tree that has already been fit.
    attributeNames: names for the variables
    classNames: names for the leafs
    """
    depth = _tree_rprint('', clf, attributeNames, classNames)
    # The ruler holds 5 * depth - 2 dashes; a non-positive count yields an
    # empty run of dashes, exactly like iterating over an empty range.
    ruler = '-' * (5 * depth - 2)
    print('<' + ruler + '>')
    print('Tree Depth: ', depth)
def _tree_rprint(kword, clf, features, labels, node_index=0, tlevel_index=0, parent=0, left=True):
    """
    Recursively print the subtree rooted at ``node_index``, drawn left to right.

    For an internal node the left subtree is printed first, then the node's
    own rule, then the right subtree, so the output reads top to bottom as an
    in-order traversal.

    Parameters
    ----------
    kword : prefix string for this node's line (indentation plus the
        ' |->N then' / ' |->N else' arrow built by the parent call)
    clf : fitted classifier exposing ``tree_`` with ``children_left``,
        ``children_right``, ``feature``, ``threshold`` and ``value``
    features : names for the variables, indexed by ``tree_.feature``
    labels : names for the leafs, indexed by the majority-class position
    node_index : index of the node to print (0 is treated as the root)
    tlevel_index : depth of this node; used for padding and returned at leaves
    parent : index of the parent node (passed along, not used for printing)
    left : True when this node is the left child of its parent

    Returns
    -------
    The largest ``tlevel_index`` reached by any leaf below this node.
    """
    # Note: The DecisionTreeClassifier uses the Tree structure defined in:
    # github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx
    # it is an array based tree implementation.
    #
    # TODO:
    #  clean up ugly string formatting code
    #  the leaf test should use the TREE_LEAF constant defined in _tree.pyx
    #  instead of -1, not quite sure how to get at it from the tree user level
    tree = clf.tree_

    if tree.children_left[node_index] == -1:  # indicates leaf
        # Drop the trailing 'then'/'else' keyword from the prefix.
        print(kword[:-4], end=' ' if kword else '')
        # get the majority label
        count_list = tree.value[node_index, 0]
        max_index, _ = max(enumerate(count_list), key=operator.itemgetter(1))
        print(labels[max_index])
        return tlevel_index

    # compute node label
    feature = features[tree.feature[node_index]]
    threshold = tree.threshold[node_index]

    left_index = tree.children_left[node_index]
    right_index = tree.children_right[node_index]

    # Value comparison (not identity): node indices may be NumPy integers,
    # for which ``is 0`` is never true.
    is_root = node_index == 0

    # Strip the arrow part of the prefix to recover the indentation.
    # NOTE(review): the 8-vs-9 cut-off keyed on ``left_index < 11`` presumably
    # compensates for the printed width of the node number - confirm.
    string = kword[:-8] if left_index < 11 else kword[:-9]
    padding = ' ' * (tlevel_index + 1 - len(string))  # empty when negative

    # print left rule
    if left and not is_root:
        left_base = string[:-1] + ' '
    else:
        left_base = string
    ltlevel_index = _tree_rprint(
        left_base + padding + ' |->{} then'.format(left_index),
        clf, features, labels, left_index, tlevel_index + 1,
        parent=node_index,
    )
    if is_root:
        print(' ', end='')
    print(left_base + ' |')

    # print current rule
    print(kword, end=' ' if kword else '')
    print('if {} =< {}: go to {}, else go to {}'.format(feature, threshold, left_index, right_index))

    # print right rule
    if not left:
        right_base = string[:-1] + ' '
    else:
        right_base = string
    if is_root:
        print(' ', end='')
    print(right_base + ' |')
    rtlevel_index = _tree_rprint(
        right_base + padding + ' |->{} else'.format(right_index),
        clf, features, labels, right_index, tlevel_index + 1,
        parent=node_index, left=False,
    )

    # return the maximum depth of either one of the children
    return max(ltlevel_index, rtlevel_index)