Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmark with full data: Figure 1 notebook #328

Merged
merged 12 commits into from
Oct 21, 2020
187 changes: 187 additions & 0 deletions benchmarks/uf_posterior_visualization/uncertaintyforest_fig1.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/tutorials/functions/unc_forest_tutorials_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,5 @@ def plot_fig1(algos, num_plotted_trials, X_eval):
axes[3].set_ylabel(r"Var($\hat P(Y = 1|X = x)$)")

fig.tight_layout()
plt.savefig("fig1.pdf")
# plt.savefig("fig1.pdf")
plt.show()
14 changes: 13 additions & 1 deletion proglearn/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,17 +245,28 @@ class UncertaintyForest:
max_depth : int, default=30
The maximum depth of a tree in the UncertaintyForest

tree_construction_proportion : float, default = 0.67
The proportions of the input data set aside to train each decision
tree. The remainder of the data is used to fill in voting posteriors.

Attributes
----------
lf_ : LifelongClassificationForest
Internal LifelongClassificationForest used to train and make
inference.
"""

def __init__(self, n_estimators=100, finite_sample_correction=False, max_depth=30):
def __init__(
self,
n_estimators=100,
finite_sample_correction=False,
max_depth=30,
tree_construction_proportion=0.67,
):
self.n_estimators = n_estimators
self.finite_sample_correction = finite_sample_correction
self.max_depth = max_depth
self.tree_construction_proportion = tree_construction_proportion

def fit(self, X, y):
"""
Expand All @@ -278,6 +289,7 @@ def fit(self, X, y):
default_n_estimators=self.n_estimators,
default_finite_sample_correction=self.finite_sample_correction,
default_max_depth=self.max_depth,
default_tree_construction_proportion=self.tree_construction_proportion,
)
self.lf_.add_task(X, y, task_id=0)
return self
Expand Down