From 13bd70adada9f49d7927fa4eaa8c958325a2ae06 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Tue, 12 May 2015 11:05:52 -0700
Subject: [PATCH] update ml/tests.py

---
 python/pyspark/ml/param/__init__.py | 15 ++++++---------
 python/pyspark/ml/pipeline.py       | 15 ++++++++++++---
 python/pyspark/ml/tests.py          | 47 +++++++++++++++++++----------------------------
 3 files changed, 37 insertions(+), 40 deletions(-)

diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index d30139b88767c..537965e835881 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -139,22 +139,19 @@ def hasParam(self, paramName):
         Tests whether this instance contains a param with a given
         (string) name.
         """
-        return self.params.count(paramName) != 0
+        param = self._resolveParam(paramName)
+        return param in self.params

     def getOrDefault(self, param):
         """
         Gets the value of a param in the user-supplied param map or its
         default value. Raises an error if neither is set.
         """
-        if isinstance(param, Param):
-            if param in self._paramMap:
-                return self._paramMap[param]
-            else:
-                return self._defaultParamMap[param]
-        elif isinstance(param, str):
-            return self.getOrDefault(self.getParam(param))
+        param = self._resolveParam(param)
+        if param in self._paramMap:
+            return self._paramMap[param]
         else:
-            raise KeyError("Cannot recognize %r as a param." % param)
+            return self._defaultParamMap[param]

     def extractParamMap(self, extra={}):
         """
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 343db7aa29d73..114281354d393 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -60,7 +60,10 @@ def fit(self, dataset, params={}):
         if isinstance(params, (list, tuple)):
             return [self.fit(dataset, paramMap) for paramMap in params]
         elif isinstance(params, dict):
-            return self.copy(params)._fit(dataset)
+            if params:
+                return self.copy(params)._fit(dataset)
+            else:
+                return self._fit(dataset)
         else:
             raise ValueError("Params must be either a param map or a list/tuple of param maps, "
                              "but got %s." % type(params))
@@ -97,7 +100,10 @@ def transform(self, dataset, params={}):
         :returns: transformed dataset
         """
         if isinstance(params, dict):
-            return self.copy(params,)._transform(dataset)
+            if params:
+                return self.copy(params)._transform(dataset)
+            else:
+                return self._transform(dataset)
         else:
             raise ValueError("Params must be a param map but got %s." % type(params))

@@ -263,6 +269,9 @@ def evaluate(self, dataset, params={}):
         :return: metric
         """
         if isinstance(params, dict):
-            return self.copy(params)._evaluate(dataset)
+            if params:
+                return self.copy(params)._evaluate(dataset)
+            else:
+                return self._evaluate(dataset)
         else:
             raise ValueError("Params must be a param map but got %s." % type(params))
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 1edfee60689f0..26bdd407ccc2b 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -32,7 +32,7 @@

 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
 from pyspark.sql import DataFrame
-from pyspark.ml.param import Param
+from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasMaxIter, HasInputCol
 from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer

@@ -43,44 +43,38 @@ def __init__(self):
         self.index = 0


-class MockTransformer(Transformer):
+class HasFake(Params):
+
+    def __init__(self):
+        super(HasFake, self).__init__()
+        self.fake = Param(self, "fake", "fake param")
+
+
+class MockTransformer(Transformer, HasFake):

     def __init__(self):
         super(MockTransformer, self).__init__()
-        self.fake = Param(self, "fake", "fake")
         self.dataset_index = None
-        self.fake_param_value = None

-    def transform(self, dataset, params={}):
+    def _transform(self, dataset):
         self.dataset_index = dataset.index
-        if self.fake in params:
-            self.fake_param_value = params[self.fake]
         dataset.index += 1
         return dataset


-class MockEstimator(Estimator):
+class MockEstimator(Estimator, HasFake):

     def __init__(self):
         super(MockEstimator, self).__init__()
-        self.fake = Param(self, "fake", "fake")
         self.dataset_index = None
-        self.fake_param_value = None
-        self.model = None

-    def fit(self, dataset, params={}):
+    def _fit(self, dataset):
         self.dataset_index = dataset.index
-        if self.fake in params:
-            self.fake_param_value = params[self.fake]
         model = MockModel()
-        self.model = model
         return model


-class MockModel(MockTransformer, Model):
-
-    def __init__(self):
-        super(MockModel, self).__init__()
+class MockModel(MockTransformer, Model, HasFake):
     pass

 class PipelineTests(PySparkTestCase):
@@ -94,16 +88,13 @@ def test_pipeline(self):
         pipeline = Pipeline() \
             .setStages([estimator0, transformer1, estimator2, transformer3])
         pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
-        self.assertEqual(0, estimator0.dataset_index)
-        self.assertEqual(0, estimator0.fake_param_value)
-        model0 = estimator0.model
+        model0, transformer1, model2, transformer3 = pipeline_model.stages
         self.assertEqual(0, model0.dataset_index)
         self.assertEqual(1, transformer1.dataset_index)
-        self.assertEqual(1, transformer1.fake_param_value)
-        self.assertEqual(2, estimator2.dataset_index)
-        model2 = estimator2.model
-        self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should "
-                                                "not be called during fit.")
+        self.assertEqual(2, dataset.index)
+        self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.")
+        self.assertIsNone(transformer3.dataset_index,
+                          "The last transformer shouldn't be called in fit.")
         dataset = pipeline_model.transform(dataset)
         self.assertEqual(2, model0.dataset_index)
         self.assertEqual(3, transformer1.dataset_index)
@@ -129,7 +120,7 @@ def test_param(self):
         maxIter = testParams.maxIter
         self.assertEqual(maxIter.name, "maxIter")
         self.assertEqual(maxIter.doc, "max number of iterations (>= 0)")
-        self.assertTrue(maxIter.parent is testParams)
+        self.assertTrue(maxIter.parent == testParams.uid)

     def test_params(self):
         testParams = TestParams()
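
Note (reviewer sketch, not part of the patch): with the pipeline.py change above, fit/transform/evaluate copy the stage only when a non-empty param map is passed; an empty map short-circuits straight to _fit/_transform/_evaluate on the stage itself, and a list/tuple of maps fits one model per map. Roughly, assuming the MockEstimator and MockDataset fixtures from tests.py:

    est = MockEstimator()
    dataset = MockDataset()

    # Empty param map: no copy is made; _fit runs on `est` itself.
    model = est.fit(dataset)

    # Non-empty param map: _fit runs on a copy with `fake` set, so the
    # param map of `est` itself is left untouched.
    model = est.fit(dataset, {est.fake: 1})

    # A list/tuple of param maps returns one model per map.
    models = est.fit(dataset, [{est.fake: 0}, {est.fake: 1}])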