Skip to content

Commit

Permalink
Merge pull request neurodata#327 from neurodata/add_n_estimators_to_f…
Browse files Browse the repository at this point in the history
…orest

Add n_estimators parameter to forest add_{task, transformer}
  • Loading branch information
jdey4 authored Oct 19, 2020
2 parents 5e89bd1 + 4ef53d6 commit 2f30a75
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
35 changes: 27 additions & 8 deletions proglearn/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ class LifelongClassificationForest(ClassificationProgressiveLearner):
Parameters
----------
n_estimators : int, default=100
The number of estimators used in the Lifelong Classification Forest
default_n_estimators : int, default=100
The number of trees used in the Lifelong Classification Forest
used if 'n_estimators' is not fed to add_{task, transformer}.
default_tree_construction_proportion : int, default=0.67
The proportions of the input data set aside to train each decision
Expand All @@ -40,12 +41,12 @@ class LifelongClassificationForest(ClassificationProgressiveLearner):

def __init__(
self,
n_estimators=100,
default_n_estimators=100,
default_tree_construction_proportion=0.67,
default_finite_sample_correction=False,
default_max_depth=30,
):
self.n_estimators = n_estimators
self.default_n_estimators = default_n_estimators
self.default_tree_construction_proportion = default_tree_construction_proportion
self.default_finite_sample_correction = default_finite_sample_correction
self.default_max_depth = default_max_depth
Expand All @@ -65,6 +66,7 @@ def add_task(
X,
y,
task_id=None,
n_estimators="default",
tree_construction_proportion="default",
finite_sample_correction="default",
max_depth="default",
Expand All @@ -87,6 +89,9 @@ def add_task(
task_id : obj, default=None
The id corresponding to the task being added.
n_estimators : int or str, default='default'
The number of trees used for the given task.
tree_construction_proportion : int or str, default='default'
The proportions of the input data set aside to train each decision
tree. The remainder of the data is used to fill in voting posteriors.
Expand All @@ -105,6 +110,8 @@ def add_task(
self : LifelongClassificationForest
The object itself.
"""
if n_estimators == "default":
n_estimators = self.default_n_estimators
if tree_construction_proportion == "default":
tree_construction_proportion = self.default_tree_construction_proportion
if finite_sample_correction == "default":
Expand All @@ -121,7 +128,7 @@ def add_task(
1 - tree_construction_proportion,
0,
],
num_transformers=self.n_estimators,
num_transformers=n_estimators,
transformer_kwargs={"kwargs": {"max_depth": max_depth}},
voter_kwargs={
"classes": np.unique(y),
Expand All @@ -131,7 +138,14 @@ def add_task(
)
return self

def add_transformer(self, X, y, transformer_id=None, max_depth="default"):
def add_transformer(
self,
X,
y,
transformer_id=None,
n_estimators="default",
max_depth="default",
):
"""
adds a transformer with id transformer_id and max tree depth max_depth, trained on
given input data matrix, X, and output data matrix, y, to the Lifelong Classification Forest.
Expand All @@ -149,6 +163,9 @@ def add_transformer(self, X, y, transformer_id=None, max_depth="default"):
transformer_id : obj, default=None
The id corresponding to the transformer being added.
n_estimators : int or str, default='default'
The number of trees used for the given task.
max_depth : int or str, default='default'
The maximum depth of a tree in the UncertaintyForest.
The default is used if 'default' is provided.
Expand All @@ -158,6 +175,8 @@ def add_transformer(self, X, y, transformer_id=None, max_depth="default"):
self : LifelongClassificationForest
The object itself.
"""
if n_estimators == "default":
n_estimators = self.default_n_estimators
if max_depth == "default":
max_depth = self.default_max_depth

Expand All @@ -166,7 +185,7 @@ def add_transformer(self, X, y, transformer_id=None, max_depth="default"):
y,
transformer_kwargs={"kwargs": {"max_depth": max_depth}},
transformer_id=transformer_id,
num_transformers=self.n_estimators,
num_transformers=n_estimators,
)

return self
Expand Down Expand Up @@ -256,7 +275,7 @@ def fit(self, X, y):
The object itself.
"""
self.lf_ = LifelongClassificationForest(
n_estimators=self.n_estimators,
default_n_estimators=self.n_estimators,
default_finite_sample_correction=self.finite_sample_correction,
default_max_depth=self.max_depth,
)
Expand Down
2 changes: 1 addition & 1 deletion proglearn/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_correct_default_kwargs(self):

def test_correct_default_n_estimators(self):
l2f = LifelongClassificationForest()
assert l2f.n_estimators == 100
assert l2f.default_n_estimators == 100

def test_correct_true_initilization_finite_sample_correction(self):
l2f = LifelongClassificationForest(default_finite_sample_correction=True)
Expand Down

0 comments on commit 2f30a75

Please sign in to comment.