diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
index f20453e624296..f4c1a6406e020 100644
--- a/python/docs/pyspark.ml.rst
+++ b/python/docs/pyspark.ml.rst
@@ -34,4 +34,5 @@ pyspark.ml.classification module
 .. automodule:: pyspark.ml.classification
     :members:
     :undoc-members:
+    :inherited-members:
     :show-inheritance:
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 83351dc523ce1..4666ce7bc2499 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -29,6 +29,20 @@ def _jvm():
     return SparkContext._jvm
 
 
+def _inherit_doc(cls):
+    for name, func in vars(cls).items():
+        # only inherit docstring for public functions
+        if name.startswith("_"):
+            continue
+        if not func.__doc__:
+            for parent in cls.__bases__:
+                parent_func = getattr(parent, name, None)
+                if parent_func and getattr(parent_func, "__doc__", None):
+                    func.__doc__ = parent_func.__doc__
+                    break
+    return cls
+
+
 @inherit_doc
 class PipelineStage(Params):
     """
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index f3cea0958c897..fd1fb906ca5c1 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -17,11 +17,13 @@
 
 from pyspark.sql import SchemaRDD, inherit_doc
 from pyspark.ml import Estimator, Transformer, _jvm
-from pyspark.ml.param import Param
+from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
+    HasRegParam
 
 
 @inherit_doc
-class LogisticRegression(Estimator):
+class LogisticRegression(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
+                         HasRegParam):
     """
     Logistic regression.
     """
@@ -29,31 +31,8 @@ class LogisticRegression(Estimator):
     # _java_class = "org.apache.spark.ml.classification.LogisticRegression"
 
     def __init__(self):
+        super(LogisticRegression, self).__init__()
         self._java_obj = _jvm().org.apache.spark.ml.classification.LogisticRegression()
-        self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
-        self.regParam = Param(self, "regParam", "regularization constant", 0.1)
-        self.featuresCol = Param(self, "featuresCol", "features column name", "features")
-
-    def setMaxIter(self, value):
-        self._java_obj.setMaxIter(value)
-        return self
-
-    def getMaxIter(self):
-        return self._java_obj.getMaxIter()
-
-    def setRegParam(self, value):
-        self._java_obj.setRegParam(value)
-        return self
-
-    def getRegParam(self):
-        return self._java_obj.getRegParam()
-
-    def setFeaturesCol(self, value):
-        self._java_obj.setFeaturesCol(value)
-        return self
-
-    def getFeaturesCol(self):
-        return self._java_obj.getFeaturesCol()
 
     def fit(self, dataset, params=None):
         """
diff --git a/python/pyspark/ml/param.py b/python/pyspark/ml/param/__init__.py
similarity index 78%
rename from python/pyspark/ml/param.py
rename to python/pyspark/ml/param/__init__.py
index ffe58a6ee69d7..89e5d732f7586 100644
--- a/python/pyspark/ml/param.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -19,7 +19,6 @@
 
 from pyspark.ml.util import Identifiable
 
-
 __all__ = ["Param"]
 
 
@@ -29,16 +28,18 @@ class Param(object):
     """
 
     def __init__(self, parent, name, doc, defaultValue=None):
+        if not isinstance(parent, Identifiable):
+            raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
         self.parent = parent
-        self.name = name
-        self.doc = doc
+        self.name = str(name)
+        self.doc = str(doc)
         self.defaultValue = defaultValue
 
     def __str__(self):
-        return self.parent + "_" + self.name
+        return str(self.parent) + "_" + self.name
 
     def __repr__(self):
-        return self.parent + "_" + self.name
+        return str(self.parent) + "_" + self.name
 
 
 class Params(Identifiable):
@@ -49,10 +50,11 @@ class Params(Identifiable):
 
     __metaclass__ = ABCMeta
 
+    #: Internal param map.
+    paramMap = {}
+
     def __init__(self):
         super(Params, self).__init__()
-        #: Internal param map.
-        self.paramMap = {}
 
     def params(self):
         """
@@ -60,4 +62,4 @@ def params(self):
         :py:func:`dir` to get all attributes of type
         :py:class:`Param`.
         """
-        return [attr for attr in dir(self) if isinstance(attr, Param)]
+        return filter(lambda x: isinstance(x, Param), map(lambda x: getattr(self, x), dir(self)))
diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_gen_shared_params.py
new file mode 100644
index 0000000000000..8c3aa7eba9483
--- /dev/null
+++ b/python/pyspark/ml/param/_gen_shared_params.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+header = """#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#"""
+
+
+def _gen_param_code(name, doc, defaultValue):
+    """
+    Generates Python code for a shared param class.
+
+    :param name: param name
+    :param doc: param doc
+    :param defaultValue: string representation of the default value
+    :return: code string
+    """
+    upperCamelName = name[0].upper() + name[1:]
+    return """class Has%s(Params):
+
+    def __init__(self):
+        super(Has%s, self).__init__()
+        #: %s
+        self.%s = Param(self, "%s", "%s", %s)
+
+    def set%s(self, value):
+        self.paramMap[self.%s] = value
+        return self
+
+    def get%s(self):
+        if self.%s in self.paramMap:
+            return self.paramMap[self.%s]
+        else:
+            return self.%s.defaultValue""" % (
+        upperCamelName, upperCamelName, doc, name, name, doc, defaultValue, upperCamelName, name,
+        upperCamelName, name, name, name)
+
+if __name__ == "__main__":
+    print header
+    print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n"
+    print "from pyspark.ml.param import Param, Params\n\n"
+    shared = [
+        ("maxIter", "max number of iterations", "100"),
+        ("regParam", "regularization constant", "0.1"),
+        ("featuresCol", "features column name", "'features'"),
+        ("labelCol", "label column name", "'label'"),
+        ("predictionCol", "prediction column name", "'prediction'"),
+        ("inputCol", "input column name", "'input'"),
+        ("outputCol", "output column name", "'output'")]
+    code = []
+    for name, doc, defaultValue in shared:
+        code.append(_gen_param_code(name, doc, defaultValue))
+    print "\n\n\n".join(code)
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
new file mode 100644
index 0000000000000..88afb5481f7b8
--- /dev/null
+++ b/python/pyspark/ml/param/shared.py
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# DO NOT MODIFY. The code is generated by _gen_shared_params.py.
+
+from pyspark.ml.param import Param, Params
+
+
+class HasMaxIter(Params):
+
+    def __init__(self):
+        super(HasMaxIter, self).__init__()
+        #: max number of iterations
+        self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
+
+    def setMaxIter(self, value):
+        self.paramMap[self.maxIter] = value
+        return self
+
+    def getMaxIter(self):
+        if self.maxIter in self.paramMap:
+            return self.paramMap[self.maxIter]
+        else:
+            return self.maxIter.defaultValue
+
+
+class HasRegParam(Params):
+
+    def __init__(self):
+        super(HasRegParam, self).__init__()
+        #: regularization constant
+        self.regParam = Param(self, "regParam", "regularization constant", 0.1)
+
+    def setRegParam(self, value):
+        self.paramMap[self.regParam] = value
+        return self
+
+    def getRegParam(self):
+        if self.regParam in self.paramMap:
+            return self.paramMap[self.regParam]
+        else:
+            return self.regParam.defaultValue
+
+
+class HasFeaturesCol(Params):
+
+    def __init__(self):
+        super(HasFeaturesCol, self).__init__()
+        #: features column name
+        self.featuresCol = Param(self, "featuresCol", "features column name", 'features')
+
+    def setFeaturesCol(self, value):
+        self.paramMap[self.featuresCol] = value
+        return self
+
+    def getFeaturesCol(self):
+        if self.featuresCol in self.paramMap:
+            return self.paramMap[self.featuresCol]
+        else:
+            return self.featuresCol.defaultValue
+
+
+class HasLabelCol(Params):
+
+    def __init__(self):
+        super(HasLabelCol, self).__init__()
+        #: label column name
+        self.labelCol = Param(self, "labelCol", "label column name", 'label')
+
+    def setLabelCol(self, value):
+        self.paramMap[self.labelCol] = value
+        return self
+
+    def getLabelCol(self):
+        if self.labelCol in self.paramMap:
+            return self.paramMap[self.labelCol]
+        else:
+            return self.labelCol.defaultValue
+
+
+class HasPredictionCol(Params):
+
+    def __init__(self):
+        super(HasPredictionCol, self).__init__()
+        #: prediction column name
+        self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction')
+
+    def setPredictionCol(self, value):
+        self.paramMap[self.predictionCol] = value
+        return self
+
+    def getPredictionCol(self):
+        if self.predictionCol in self.paramMap:
+            return self.paramMap[self.predictionCol]
+        else:
+            return self.predictionCol.defaultValue
+
+
+class HasInputCol(Params):
+
+    def __init__(self):
+        super(HasInputCol, self).__init__()
+        #: input column name
+        self.inputCol = Param(self, "inputCol", "input column name", 'input')
+
+    def setInputCol(self, value):
+        self.paramMap[self.inputCol] = value
+        return self
+
+    def getInputCol(self):
+        if self.inputCol in self.paramMap:
+            return self.paramMap[self.inputCol]
+        else:
+            return self.inputCol.defaultValue
+
+
+class HasOutputCol(Params):
+
+    def __init__(self):
+        super(HasOutputCol, self).__init__()
+        #: output column name
+        self.outputCol = Param(self, "outputCol", "output column name", 'output')
+
+    def setOutputCol(self, value):
+        self.paramMap[self.outputCol] = value
+        return self
+
+    def getOutputCol(self):
+        if self.outputCol in self.paramMap:
+            return self.paramMap[self.outputCol]
+        else:
+            return self.outputCol.defaultValue
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 801a5eeaa3249..5d74088b0b13e 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -27,3 +27,9 @@ def __init__(self):
     #: A unique id for the object. The default implementation
     #: concatenates the class name, "-", and 8 random hex chars.
         self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
+
+    def __str__(self):
+        return self.uid
+
+    def __repr__(self):
+        return str(self)
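
With this patch applied, a minimal end-to-end use of the shared params looks like the sketch below. It is illustrative only, not part of the patch: it assumes an active SparkContext (LogisticRegression.__init__ instantiates the JVM-side estimator through _jvm()), the appName is hypothetical, and "dataset" stands in for a SchemaRDD with features/label columns. Each Has* mixin registers its Param through the cooperative super() chain; setters write into paramMap and return self for chaining; getters fall back to the Param's declared default when nothing was set.

    from pyspark import SparkContext
    from pyspark.ml.classification import LogisticRegression

    sc = SparkContext(appName="shared-params-sketch")  # hypothetical app name

    lr = LogisticRegression()            # cooperative super() runs every Has*.__init__
    lr.setMaxIter(10).setRegParam(0.01)  # setters store values in paramMap, return self

    print lr.getMaxIter()                # 10, read back from paramMap
    print lr.getFeaturesCol()            # 'features', the Param's declared default
    # model = lr.fit(dataset)            # dataset: a SchemaRDD with features/label columns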