-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtreeutils.py
90 lines (71 loc) · 3.53 KB
/
treeutils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from __future__ import print_function
from sklearn.metrics import make_scorer, f1_score
from sklearn.tree import _tree
import numpy as np, random as rnd
import sys
try:
from cStringIO import StringIO
except:
from io import StringIO
def simplify_tree(decision_tree, X, y, scorer=make_scorer(f1_score, greater_is_better=True), acceptable_score_drop=0.0, verbose=1):
    """Prune a fitted decision tree in place, collapsing splits the score does not miss.

    Each pass visits every node in random order.  A node that still has a
    child is temporarily turned into a leaf (both child links set to -1) and
    the estimator is re-scored; the change is kept when the new score is at
    least ``running score - acceptable_score_drop``, otherwise the child
    links are restored.  Passes repeat until one finishes with the same
    score it started with.

    Parameters
    ----------
    decision_tree : estimator exposing a ``tree_`` attribute
        Its ``tree_.children_left`` / ``tree_.children_right`` arrays are
        modified directly.
    X, y : passed through unchanged to ``scorer``.
    scorer : callable invoked as ``scorer(decision_tree, X, y)``
        Larger return values are treated as better.
    acceptable_score_drop : float
        Tolerance applied against the *running* score, which is updated
        after every accepted prune, so successive small drops can add up.
    verbose : truthy to print a one-line summary after every pass.

    Returns
    -------
    The same ``decision_tree`` object that was passed in.
    """
    while True:
        baseline_score = scorer(decision_tree, X, y)
        running_score = baseline_score
        tree = decision_tree.tree_
        pruned_nodes = []

        for node in np.random.permutation(np.arange(tree.node_count)):
            saved_left = tree.children_left[node]
            saved_right = tree.children_right[node]
            if saved_left < 0 and saved_right < 0:
                # Nothing below this node, so there is nothing to collapse.
                continue

            # Tentatively make the node a leaf and see what that costs.
            tree.children_left[node], tree.children_right[node] = -1, -1
            trial_score = scorer(decision_tree, X, y)

            if trial_score >= running_score - acceptable_score_drop:
                running_score = trial_score
                pruned_nodes.append(node)
            else:
                # Too expensive: put the original child links back.
                tree.children_left[node] = saved_left
                tree.children_right[node] = saved_right

        if verbose:
            print("Removed",len(pruned_nodes)," branches. current score: ", running_score)

        # A pass that left the score untouched means pruning has converged.
        if running_score == baseline_score:
            break

    return decision_tree
def tree_to_code(tree, feature_names, decimals=4, transform_to_probabilities=True):
    """Render a fitted decision tree as the source code of a Python function.

    The generated function is called ``probability_of_class_one``.  Its body
    is a nested ``if <feature> <= <threshold>: ... else: ...`` chain with one
    ``return`` per leaf, indented by one space per tree level.  Its
    parameters are the features that the emitted conditions actually test,
    listed in ``feature_names`` order.

    A split whose two children are both leaves holding identical values is
    emitted as a single ``return`` of the parent node's value.

    Parameters
    ----------
    tree : estimator exposing a ``tree_`` attribute
        ``tree_.feature``, ``tree_.threshold``, ``tree_.value``,
        ``tree_.children_left`` and ``tree_.children_right`` are read.
    feature_names : sequence of str
        ``feature_names[i]`` names feature index ``i``.
    decimals : int
        Number of decimal places kept for thresholds and probabilities.
    transform_to_probabilities : bool
        When true, each leaf returns the rounded share of class index 1 in
        the first output row of the node value.  When false, the raw first
        row of the node value is formatted as-is.

    Returns
    -------
    str
        The function source, terminated by a newline.

    Raises
    ------
    IndexError
        If ``transform_to_probabilities`` is true and a node value has fewer
        than two entries per row (there is no class index 1 to report).
    """
    tree_ = tree.tree_
    node_feature_names = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    # Float scale so that large or negative ``decimals`` do not hit integer
    # power limitations.
    scale = 10.0 ** decimals

    def round_to_decimals(x):
        return np.round(x * scale) / scale

    def is_leaf(node):
        # A node is terminal if it was built as a leaf (no split feature) or
        # if its child links were cut afterwards, as simplify_tree does,
        # which leaves the split feature in place.
        return (tree_.feature[node] == _tree.TREE_UNDEFINED
                or tree_.children_left[node] == _tree.TREE_LEAF
                or tree_.children_right[node] == _tree.TREE_LEAF)

    def leaf_value(node):
        first_row = tree_.value[node][0]
        if not transform_to_probabilities:
            return first_row
        # Normalise by the row total instead of the node sample weight: the
        # result is the class share whether the row holds weighted counts or
        # already-normalised fractions.
        total = first_row.sum()
        if total:
            first_row = first_row / total
        return round_to_decimals(first_row)[1]

    # Lines are collected in a list rather than captured from sys.stdout, so
    # an error while rendering cannot leave the process's stdout redirected.
    lines = []
    used_features = set()

    def recurse(node, depth):
        indent = " " * depth
        if is_leaf(node):
            lines.append("{}return {}".format(indent, leaf_value(node)))
            return

        left = tree_.children_left[node]
        right = tree_.children_right[node]
        if (is_leaf(left) and is_leaf(right)
                and np.array_equal(tree_.value[left], tree_.value[right])):
            # Both outcomes are identical, so the test itself is redundant.
            lines.append("{}return {}".format(indent, leaf_value(node)))
            return

        name = node_feature_names[node]
        used_features.add(name)
        threshold = round_to_decimals(tree_.threshold[node])
        lines.append("{}if {} <= {}:".format(indent, name, threshold))
        recurse(left, depth + 1)
        lines.append("{}else:".format(indent))
        recurse(right, depth + 1)

    recurse(0, 1)

    # Only features that appear in an emitted condition become parameters;
    # exact matching avoids picking up names that merely occur as substrings
    # of keywords or of other feature names.
    parameters = [f for f in feature_names if f in used_features]
    header = "def probability_of_class_one({}):".format(", ".join(parameters))
    return header + "\n" + "\n".join(lines) + "\n"