Skip to content

Commit

Permalink
Merge branch '1.1.X'
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Dec 19, 2022
2 parents e707d8b + bc8308d commit b397a49
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 29 deletions.
4 changes: 2 additions & 2 deletions docs/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ The predictions of a standard random forest can also be recovered from a quantil
>>> X, y = datasets.load_diabetes(return_X_y=True)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
>>> rf = RandomForestRegressor(random_state=0)
>>> qrf = RandomForestQuantileRegressor(random_state=0)
>>> qrf = RandomForestQuantileRegressor(max_samples_leaf=None, random_state=0)
>>> rf.fit(X_train, y_train), qrf.fit(X_train, y_train)
(RandomForestRegressor(random_state=0), RandomForestQuantileRegressor(random_state=0))
(RandomForestRegressor(random_state=0), RandomForestQuantileRegressor(max_samples_leaf=None, random_state=0))
>>> y_pred_rf = rf.predict(X_test)
>>> y_pred_qrf = qrf.predict(X_test, quantiles=None, aggregate_leaves_first=False)
>>> np.allclose(y_pred_rf, y_pred_qrf)
Expand Down
53 changes: 29 additions & 24 deletions quantile_forest/_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a

import joblib
import numpy as np
import sklearn

from sklearn.ensemble._forest import ForestRegressor
from sklearn.ensemble._forest import _generate_sample_indices
Expand All @@ -42,6 +43,8 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a
from ._quantile_forest_fast import QuantileForest
from ._quantile_forest_fast import generate_unsampled_indices

sklearn_version = tuple(map(int, (sklearn.__version__.split('.'))))


def _generate_unsampled_indices(sample_indices, duplicates=None):
"""Private function used by forest._get_unsampled_indices function."""
Expand Down Expand Up @@ -980,10 +983,10 @@ def __init__(
ccp_alpha=0.0,
max_samples=None,
):
super(RandomForestQuantileRegressor, self).__init__(
base_estimator=DecisionTreeRegressor(),
n_estimators=n_estimators,
estimator_params=(
init_dict = {
'base_estimator' if sklearn_version < (1, 2) else 'estimator': DecisionTreeRegressor(),
'n_estimators': n_estimators,
'estimator_params': (
"criterion",
"max_depth",
"min_samples_split",
Expand All @@ -995,14 +998,15 @@ def __init__(
"random_state",
"ccp_alpha",
),
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
max_samples=max_samples,
)
'bootstrap': bootstrap,
'oob_score': oob_score,
'n_jobs': n_jobs,
'random_state': random_state,
'verbose': verbose,
'warm_start': warm_start,
'max_samples': max_samples,
}
super(RandomForestQuantileRegressor, self).__init__(**init_dict)

self.criterion = criterion
self.max_depth = max_depth
Expand Down Expand Up @@ -1253,10 +1257,10 @@ def __init__(
ccp_alpha=0.0,
max_samples=None,
):
super(ExtraTreesQuantileRegressor, self).__init__(
base_estimator=ExtraTreeRegressor(),
n_estimators=n_estimators,
estimator_params=(
init_dict = {
'base_estimator' if sklearn_version < (1, 2) else 'estimator': ExtraTreeRegressor(),
'n_estimators': n_estimators,
'estimator_params': (
"criterion",
"max_depth",
"min_samples_split",
Expand All @@ -1268,14 +1272,15 @@ def __init__(
"random_state",
"ccp_alpha",
),
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
max_samples=max_samples,
)
'bootstrap': bootstrap,
'oob_score': oob_score,
'n_jobs': n_jobs,
'random_state': random_state,
'verbose': verbose,
'warm_start': warm_start,
'max_samples': max_samples,
}
super(ExtraTreesQuantileRegressor, self).__init__(**init_dict)

self.criterion = criterion
self.max_depth = max_depth
Expand Down
2 changes: 1 addition & 1 deletion quantile_forest/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.1.0
1.1.1
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cython >= 3.0a4
numpy
scipy
numpy >= 1.23
scipy >= 1.4
scikit-learn >= 1.0

0 comments on commit b397a49

Please sign in to comment.