Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Apr 15, 2015
1 parent 5294500 commit 4d6b07a
Showing 1 changed file with 45 additions and 2 deletions.
47 changes: 45 additions & 2 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.sql import DataFrame
from pyspark.ml.param import Param
from pyspark.ml.param.shared import HasMaxIter, HasInputCol
from pyspark.ml.pipeline import Transformer, Estimator, Pipeline


Expand All @@ -46,7 +47,7 @@ class MockTransformer(Transformer):

def __init__(self):
super(MockTransformer, self).__init__()
self.fake = Param(self, "fake", "fake", None)
self.fake = Param(self, "fake", "fake")
self.dataset_index = None
self.fake_param_value = None

Expand All @@ -62,7 +63,7 @@ class MockEstimator(Estimator):

def __init__(self):
super(MockEstimator, self).__init__()
self.fake = Param(self, "fake", "fake", None)
self.fake = Param(self, "fake", "fake")
self.dataset_index = None
self.fake_param_value = None
self.model = None
Expand Down Expand Up @@ -111,5 +112,47 @@ def test_pipeline(self):
self.assertEqual(6, dataset.index)


class TestParams(HasMaxIter, HasInputCol):
"""
A subclass of Params mixed with HasMaxIter and HasInputCol.
"""

def __init__(self):
super(TestParams, self).__init__()
self._setDefault(maxIter=10)


class ParamTests(PySparkTestCase):

def test_param(self):
testParams = TestParams()
maxIter = testParams.maxIter
self.assertEqual(maxIter.name, "maxIter")
self.assertEqual(maxIter.doc, "max number of iterations")
self.assertTrue(maxIter.parent is testParams)

def test_params(self):
testParams = TestParams()
maxIter = testParams.maxIter
inputCol = testParams.inputCol

params = testParams.params
self.assertEqual(params, [inputCol, maxIter])

self.assertTrue(testParams.hasDefault(maxIter))
self.assertFalse(testParams.isSet(maxIter))
self.assertTrue(testParams.isDefined(maxIter))
self.assertEqual(testParams.getMaxIter(), 10)
testParams.setMaxIter(100)
self.assertTrue(testParams.isSet(maxIter))
self.assertEquals(testParams.getMaxIter(), 100)

self.assertFalse(testParams.hasDefault(inputCol))
self.assertFalse(testParams.isSet(inputCol))
self.assertFalse(testParams.isDefined(inputCol))
with self.assertRaises(KeyError):
testParams.getInputCol()


if __name__ == "__main__":
unittest.main()

0 comments on commit 4d6b07a

Please sign in to comment.