From 5153cff555d5188dc9a2844fe558e73af38fbd01 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 27 Jan 2015 14:12:21 -0800 Subject: [PATCH] simplify java models --- .../org/apache/spark/ml/param/params.scala | 1 - python/pyspark/ml/__init__.py | 50 ++++++++++++++++--- python/pyspark/ml/classification.py | 14 +++--- 3 files changed, 49 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 33f7a3900a98e..5fb4379e23c2f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -293,7 +293,6 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten new ParamMap(this.map ++ other.map) } - /** * Adds all parameters from the input param map into this param map. */ diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index e16214ae18431..a193442841f65 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -101,6 +101,18 @@ def transform(self, dataset, params={}): raise NotImplementedError() +@inherit_doc +class Model(Transformer): + """ + Abstract class for models fitted by :py:class:`Estimator`s. + """ + + __metaclass__ = ABCMeta + + def __init__(self): + super(Model, self).__init__() + + @inherit_doc class Pipeline(Estimator): """ @@ -169,7 +181,7 @@ def fit(self, dataset, params={}): @inherit_doc -class PipelineModel(Transformer): +class PipelineModel(Model): """ Represents a compiled pipeline with transformers and fitted models. """ @@ -204,9 +216,9 @@ def _java_class(self): """ raise NotImplementedError - def _create_java_obj(self): + def _java_obj(self): """ - Creates a new Java object and returns its reference. + Returns or creates a Java object. 
""" java_obj = _jvm() for name in self._java_class.split("."): java_obj = getattr(java_obj, name) return java_obj() @@ -231,6 +243,13 @@ def _empty_java_param_map(self): """ return _jvm().org.apache.spark.ml.param.ParamMap() + def _create_java_param_map(self, params, java_obj): + paramMap = self._empty_java_param_map() + for param, value in params.items(): + if param.parent is self: + paramMap.put(java_obj.getParam(param.name), value) + return paramMap + @inherit_doc class JavaEstimator(Estimator, JavaWrapper): @@ -259,7 +278,7 @@ def _fit_java(self, dataset, params={}): :param params: additional params (overwriting embedded values) :return: fitted Java model """ - java_obj = self._create_java_obj() + java_obj = self._java_obj() self._transfer_params_to_java(params, java_obj) return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map()) @@ -281,7 +300,24 @@ def __init__(self): super(JavaTransformer, self).__init__() def transform(self, dataset, params={}): - java_obj = self._create_java_obj() - self._transfer_params_to_java(params, java_obj) - return SchemaRDD(java_obj.transform(dataset._jschema_rdd, self._empty_java_param_map()), + java_obj = self._java_obj() + self._transfer_params_to_java({}, java_obj) + java_param_map = self._create_java_param_map(params, java_obj) + return SchemaRDD(java_obj.transform(dataset._jschema_rdd, java_param_map), dataset.sql_ctx) + + +@inherit_doc +class JavaModel(JavaTransformer): + """ + Base class for :py:class:`Model`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def __init__(self): + super(JavaModel, self).__init__() + + def _java_obj(self): + return self._java_model diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ab29fb3220a63..4628cef6e255c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -15,8 +15,8 @@ # limitations under the License. 
# -from pyspark.sql import SchemaRDD, inherit_doc -from pyspark.ml import JavaEstimator, Transformer, _jvm +from pyspark.sql import inherit_doc +from pyspark.ml import JavaEstimator, JavaModel from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ HasRegParam @@ -40,7 +40,7 @@ def _create_model(self, java_model): @inherit_doc -class LogisticRegressionModel(Transformer): +class LogisticRegressionModel(JavaModel): """ Model fitted by LogisticRegression. """ @@ -49,8 +49,6 @@ def __init__(self, java_model): super(LogisticRegressionModel, self).__init__() self._java_model = java_model - def transform(self, dataset, params={}): - # TODO: handle params here. - return SchemaRDD(self._java_model.transform( - dataset._jschema_rdd, - _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx) + @property + def _java_class(self): + return "org.apache.spark.ml.classification.LogisticRegressionModel"