Skip to content

Commit

Permalink
Fix Monotonic Constraint Performance (#75)
Browse files Browse the repository at this point in the history
* Update midpoint calculation
* Update unit test
  • Loading branch information
reidjohnson authored Aug 15, 2024
1 parent 0f8b7f4 commit 10847e7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
23 changes: 12 additions & 11 deletions quantile_forest/_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,28 +390,28 @@ def _get_y_bound_leaves(self, y, y_train_leaves):
for i, estimator in enumerate(self.estimators_):
tree = estimator.tree_

min_values = np.full(tree.node_count, np.inf)
max_values = np.full(tree.node_count, -np.inf)
num_values = np.zeros(tree.node_count)
sum_values = np.zeros(tree.node_count)

# Iterate over all nodes in reverse order (leaves to root).
for node_idx in range(tree.node_count - 1, -1, -1):
if tree.children_left[node_idx] == tree.children_right[node_idx]: # leaf node
# Populate the leaf nodes with actual target values.
# Populate the leaf nodes.
leaf_indices = y_train_leaves[i, node_idx]
leaf_indices = leaf_indices[leaf_indices != 0]
if leaf_indices.size > 0:
leaf_targets = y[leaf_indices - 1]
min_values[node_idx] = min(leaf_targets)
max_values[node_idx] = max(leaf_targets)
num_values[node_idx] = leaf_indices.size
sum_values[node_idx] = leaf_targets.sum()
else: # non-leaf node
# The min and max of the parent is the min/max of its children.
# Populate the non-leaf nodes based on the children.
left_child = tree.children_left[node_idx]
right_child = tree.children_right[node_idx]
min_values[node_idx] = min(min_values[left_child], min_values[right_child])
max_values[node_idx] = max(max_values[left_child], max_values[right_child])
num_values[node_idx] = num_values[left_child] + num_values[right_child]
sum_values[node_idx] = sum_values[left_child] + sum_values[right_child]

# Traverse from root to leaves to enforce monotonicity.
stack = [(0, min_values[0], max_values[0])] # start with root node (node 0)
stack = [(0, -np.inf, np.inf)] # start with root node (node 0)

while stack:
node_idx, min_bound, max_bound = stack.pop()
Expand All @@ -429,8 +429,9 @@ def _get_y_bound_leaves(self, y, y_train_leaves):
stack.append((right_child, min_bound, max_bound))
else:
# Calculate midpoint that respects the current node's bounds.
mid = (max_values[left_child] + min_values[right_child]) / 2
mid = max(min(mid, max_bound), min_bound)
left_avg = sum_values[left_child] / (2 * num_values[left_child])
right_avg = sum_values[right_child] / (2 * num_values[right_child])
mid = max(min(left_avg + right_avg, max_bound), min_bound)

if self.monotonic_cst[feature_idx] == 1: # increasing monotonicity
stack.append((left_child, min_bound, mid))
Expand Down
3 changes: 3 additions & 0 deletions quantile_forest/tests/test_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,9 @@ def check_monotonic_constraints(name, max_samples_leaf):

y = est.predict(X_test, oob_score=oob_score)

score = est.score(X_test, y_test, quantiles=0.5)
assert score > 0.75

# Check the monotonic increase constraint.
y_incr = est.predict(X_test_incr, oob_score=oob_score)
assert np.all(y_incr >= y)
Expand Down

0 comments on commit 10847e7

Please sign in to comment.