Skip to content

Commit

Permalink
Fix Reproduction of scikit-learn Trees with MAE Criterion (#77)
Browse files (browse the repository at this point in the history)
* Update prediction aggregation and tests
  • Loading branch information
reidjohnson authored Aug 23, 2024
1 parent c93668e commit 3ee9736
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 31 deletions.
24 changes: 5 additions & 19 deletions quantile_forest/_quantile_forest_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,6 @@ cdef class QuantileForest:
preds : array-like of shape (n_samples, n_outputs, n_quantiles)
Quantiles or means for samples as floats.
"""
cdef vector[double] median = [0.5]

cdef intp_t n_quantiles, n_samples, n_trees, n_outputs, n_train
cdef intp_t i, j, k, l
cdef bint use_mean
Expand Down Expand Up @@ -854,23 +852,11 @@ cdef class QuantileForest:
leaf_preds[0].push_back(pred[0])

# Average the quantile predictions across accumulations.
if not use_mean:
for k in range(<intp_t>(leaf_preds.size())):
if leaf_preds[k].size() == 1:
preds_view[i, j, k] = leaf_preds[k][0]
elif leaf_preds[k].size() > 1:
pred = calc_quantile(
leaf_preds[k],
median,
interpolation,
issorted=False,
)
preds_view[i, j, k] = pred[0]
else:
if leaf_preds[0].size() == 1:
preds_view[i, j, 0] = leaf_preds[0][0]
elif leaf_preds[0].size() > 1:
preds_view[i, j, 0] = calc_mean(leaf_preds[0])
for k in range(<intp_t>(leaf_preds.size())):
if leaf_preds[k].size() == 1:
preds_view[i, j, k] = leaf_preds[k][0]
elif leaf_preds[k].size() > 1:
preds_view[i, j, k] = calc_mean(leaf_preds[k])

return np.asarray(preds_view)

Expand Down
40 changes: 28 additions & 12 deletions quantile_forest/tests/test_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,23 +458,17 @@ def check_predict_quantiles(
n_estimators=1 if max_samples_leaf == 1 else 10,
random_state=0,
)
est2 = ExtraTreesQuantileRegressor(
n_estimators=1 if max_samples_leaf == 1 else 10,
default_quantiles=None,
max_samples_leaf=max_samples_leaf,
random_state=0,
)
else:
est1 = RandomForestRegressor(
n_estimators=1 if max_samples_leaf == 1 else 10,
random_state=0,
)
est2 = RandomForestQuantileRegressor(
n_estimators=1 if max_samples_leaf == 1 else 10,
default_quantiles=None,
max_samples_leaf=max_samples_leaf,
random_state=0,
)
est2 = ForestRegressor(
n_estimators=1 if max_samples_leaf == 1 else 10,
default_quantiles=None,
max_samples_leaf=max_samples_leaf,
random_state=0,
)
y_pred_1 = est1.fit(X_train, y_train).predict(X_test)
y_pred_2 = est2.fit(X_train, y_train).predict(
X_test,
Expand Down Expand Up @@ -516,6 +510,28 @@ def check_predict_quantiles(
assert np.any(y_pred[:, 0, ...] != y_pred[:, 1, ...])
assert score > 0.95

# Check unaggregated predictions with absolute error criterion.
if quantiles == 0.5:
X_train_mae = np.array([[1], [1], [3], [1]])
y_train_mae = np.arange(4)

X_test_mae = np.array([[5], [4], [2], [0]])

params = {"criterion": "absolute_error", "max_depth": 1, "random_state": 0}

if name == "ExtraTreesQuantileRegressor":
est1 = ExtraTreesRegressor(**params)
else:
est1 = RandomForestRegressor(**params)
est2 = ForestRegressor(max_samples_leaf=None, **params)

est1.fit(X_train_mae, y_train_mae)
est2.fit(X_train_mae, y_train_mae)

y_pred1 = est1.predict(X_test_mae)
y_pred2 = est2.predict(X_test_mae, quantiles=0.5, aggregate_leaves_first=False)
assert_allclose(y_pred1, y_pred2)

# Check that specifying `quantiles` overwrites `default_quantiles`.
est1 = ForestRegressor(
n_estimators=1,
Expand Down

0 comments on commit 3ee9736

Please sign in to comment.