Skip to content

Commit

Permalink
Merge pull request #2 from neurodata/master
Browse files Browse the repository at this point in the history
0918 Merge
  • Loading branch information
david-z-shi authored Sep 18, 2020
2 parents 1613caa + 04071c6 commit aa3ecf1
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 29 deletions.
58 changes: 29 additions & 29 deletions proglearn/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class UncertaintyForest:
A lifelong classification forest object
n_estimators : int
The number of estimaters used in the
LifelongClassificationForest
LifelongClassificationForest
finite_sample_correction : bool
Boolean indicating whether this learner
will have finite sample correction used
Expand All @@ -72,52 +72,52 @@ class UncertaintyForest:
Methods
---
fit(X, y)
fits forest to data X with labels y
fits forest to data X with labels y
predict(X)
predicts class labels given data, X
predicts class labels given data, X
predict_proba(X)
predicts posterior probabilities given data, X, of each class label
predicts posterior probabilities given data, X, of each class label
"""
def __init__(self, n_estimators=100, finite_sample_correction=False):
self.n_estimators = n_estimators
self.finite_sample_correction = finite_sample_correction

def fit(self, X, y):
"""
fits data X given class labels y
"""
fits data X given class labels y
Attributes
---
X : array of shape [n_samples, n_features]
The data that will be trained on
y : array of shape [n_samples]
The label for cluster membership of the given data
"""
Attributes
---
X : array of shape [n_samples, n_features]
The data that will be trained on
y : array of shape [n_samples]
The label for cluster membership of the given data
"""
self.lf = LifelongClassificationForest(
n_estimators=self.n_estimators,
finite_sample_correction=self.finite_sample_correction,
n_estimators = self.n_estimators,
finite_sample_correction = self.finite_sample_correction
)
self.lf.add_task(X, y, task_id=0)
return self

def predict(self, X):
"""
predicts the class labels given data X
"""
predicts the class labels given data X
Attributes
---
X : array of shape [n_samples, n_features]
The data on which we are performing inference.
"""
Attributes
---
X : array of shape [n_samples, n_features]
The data on which we are performing inference.
"""
return self.lf.predict(X, 0)

def predict_proba(self, X):
"""
returns the posterior probabilities of each class for data X
"""
returns the posterior probabilities of each class for data X
Attributes
---
X : array of shape [n_samples, n_features]
The data whose posteriors we are estimating.
"""
Attributes
---
X : array of shape [n_samples, n_features]
The data whose posteriors we are estimating.
"""
return self.lf.predict_proba(X, 0)
50 changes: 50 additions & 0 deletions proglearn/tests/test_lifelongclassificationforest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest
import pytest
import numpy as np
import random

from proglearn.forest import LifelongClassificationForest
from proglearn.transformers import TreeClassificationTransformer
from proglearn.voters import TreeClassificationVoter
from proglearn.deciders import SimpleArgmaxAverage

class test_LifelongClassificationForest(unittest.TestCase):

def setUp(self):
self.l2f = LifelongClassificationForest()

def test_initialize(self):
self.assertTrue(True)

def test_correct_default_transformer(self):
self.assertIs(self.l2f.pl.default_transformer_class, TreeClassificationTransformer)

def test_correct_default_voter(self):
self.assertIs(self.l2f.pl.default_voter_class, TreeClassificationVoter)

def test_correct_default_decider(self):
self.assertIs(self.l2f.pl.default_decider_class, SimpleArgmaxAverage)

def test_correct_default_kwargs_transformer_decider_empty(self):
self.assertFalse(self.l2f.pl.default_transformer_kwargs)
self.assertFalse(self.l2f.pl.default_decider_kwargs)

def test_correct_default_estimators(self):
self.assertIs(self.l2f.n_estimators, 100)

def test_correct_estimator(self):
rand = random.randint(0, 100)
l2f = LifelongClassificationForest(n_estimators=rand)
self.assertIs(l2f.n_estimators, rand)

def test_correct_default_finite_sample_correction(self):
tmp_dict = {"finite_sample_correction": False}
self.assertEqual(self.l2f.pl.default_voter_kwargs, tmp_dict)

def test_correct_true_initilization_finite_sample_correction(self):
tmp_dict = {"finite_sample_correction": True}
l2f = LifelongClassificationForest(finite_sample_correction=True)
self.assertEqual(l2f.pl.default_voter_kwargs, tmp_dict)

if __name__ == '__main__':
unittest.main()

0 comments on commit aa3ecf1

Please sign in to comment.