code gen for shared params
mengxr committed Jan 21, 2015
1 parent d9ea77c commit 17ecfb9
Showing 7 changed files with 262 additions and 34 deletions.
1 change: 1 addition & 0 deletions python/docs/pyspark.ml.rst
@@ -34,4 +34,5 @@ pyspark.ml.classification module
.. automodule:: pyspark.ml.classification
    :members:
    :undoc-members:
    :inherited-members:
    :show-inheritance:
14 changes: 14 additions & 0 deletions python/pyspark/ml/__init__.py
@@ -29,6 +29,20 @@ def _jvm():
    return SparkContext._jvm


def _inherit_doc(cls):
    for name, func in vars(cls).items():
        # only inherit docstring for public functions
        if name.startswith("_"):
            continue
        if not func.__doc__:
            for parent in cls.__bases__:
                parent_func = getattr(parent, name, None)
                if parent_func and getattr(parent_func, "__doc__", None):
                    func.__doc__ = parent_func.__doc__
                    break
    return cls


@inherit_doc
class PipelineStage(Params):
"""
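The new helper copies docstrings down the class hierarchy, which pairs with the `:inherited-members:` switch added to the Sphinx config above. A minimal sketch of the effect, assuming the `_inherit_doc` defined in this hunk is in scope (Base and Impl are illustrative names, not part of the commit):

class Base(object):
    def transform(self, dataset):
        """Transforms the input dataset."""
        raise NotImplementedError


@_inherit_doc
class Impl(Base):
    def transform(self, dataset):  # defined without a docstring of its own
        return dataset


print(Impl.transform.__doc__)  # -> Transforms the input dataset.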
31 changes: 5 additions & 26 deletions python/pyspark/ml/classification.py
@@ -17,43 +17,22 @@

from pyspark.sql import SchemaRDD, inherit_doc
from pyspark.ml import Estimator, Transformer, _jvm
from pyspark.ml.param import Param
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
    HasRegParam


@inherit_doc
class LogisticRegression(Estimator):
class LogisticRegression(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                         HasRegParam):
"""
Logistic regression.
"""

    # _java_class = "org.apache.spark.ml.classification.LogisticRegression"

    def __init__(self):
        super(LogisticRegression, self).__init__()
        self._java_obj = _jvm().org.apache.spark.ml.classification.LogisticRegression()
        self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
        self.regParam = Param(self, "regParam", "regularization constant", 0.1)
        self.featuresCol = Param(self, "featuresCol", "features column name", "features")

    def setMaxIter(self, value):
        self._java_obj.setMaxIter(value)
        return self

    def getMaxIter(self):
        return self._java_obj.getMaxIter()

    def setRegParam(self, value):
        self._java_obj.setRegParam(value)
        return self

    def getRegParam(self):
        return self._java_obj.getRegParam()

    def setFeaturesCol(self, value):
        self._java_obj.setFeaturesCol(value)
        return self

    def getFeaturesCol(self):
        return self._java_obj.getFeaturesCol()

    def fit(self, dataset, params=None):
        """
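After this change the per-param boilerplate above is inherited instead of hand-written: each Has* mixin contributes one Param plus a chained setter and a getter backed by the shared paramMap. A rough sketch of the resulting call pattern (assumes a live SparkContext, since __init__ instantiates the Java estimator):

lr = LogisticRegression()
lr.setMaxIter(10).setRegParam(0.01)  # setters return self, so calls chain
print(lr.paramMap[lr.maxIter])       # -> 10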
18 changes: 10 additions & 8 deletions python/pyspark/ml/param.py → python/pyspark/ml/param/__init__.py
@@ -19,7 +19,6 @@

from pyspark.ml.util import Identifiable


__all__ = ["Param"]


@@ -29,16 +28,18 @@ class Param(object):
"""

def __init__(self, parent, name, doc, defaultValue=None):
if not isinstance(parent, Identifiable):
raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
self.parent = parent
self.name = name
self.doc = doc
self.name = str(name)
self.doc = str(doc)
self.defaultValue = defaultValue

def __str__(self):
return self.parent + "_" + self.name
return str(self.parent) + "_" + self.name

def __repr__(self):
return self.parent + "_" + self.name
return str(self.parent) + "_" + self.name


class Params(Identifiable):
@@ -49,15 +50,16 @@ class Params(Identifiable):

    __metaclass__ = ABCMeta

    #: Internal param map.
    paramMap = {}

    def __init__(self):
        super(Params, self).__init__()
        #: Internal param map.
        self.paramMap = {}

    def params(self):
        """
        Returns all params. The default implementation uses
        :py:func:`dir` to get all attributes of type
        :py:class:`Param`.
        """
        return [attr for attr in dir(self) if isinstance(attr, Param)]
        return filter(lambda x: isinstance(x, Param), map(lambda x: getattr(self, x), dir(self)))
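Moving paramMap from a class attribute into __init__ is the substantive fix in this hunk: a mutable dict declared at class level is shared by every instance, so setting a param on one Params object would leak into all others. A standalone illustration of the pitfall (Shared and PerInstance are made-up names):

class Shared(object):
    paramMap = {}  # one dict, shared across all instances


class PerInstance(object):
    def __init__(self):
        self.paramMap = {}  # a fresh dict per instance


a, b = Shared(), Shared()
a.paramMap["maxIter"] = 10
print(b.paramMap)  # -> {'maxIter': 10}, the write leaked into b

c, d = PerInstance(), PerInstance()
c.paramMap["maxIter"] = 10
print(d.paramMap)  # -> {}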
80 changes: 80 additions & 0 deletions python/pyspark/ml/param/_gen_shared_params.py
@@ -0,0 +1,80 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

header = """#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#"""


def _gen_param_code(name, doc, defaultValue):
    """
    Generates Python code for a shared param class.

    :param name: param name
    :param doc: param doc
    :param defaultValue: string representation of the default value
    :return: code string
    """
    upperCamelName = name[0].upper() + name[1:]
    return """class Has%s(Params):

    def __init__(self):
        super(Has%s, self).__init__()
        #: %s
        self.%s = Param(self, "%s", "%s", %s)

    def set%s(self, value):
        self.paramMap[self.%s] = value
        return self

    def get%s(self, value):
        if self.%s in self.paramMap:
            return self.paramMap[self.%s]
        else:
            return self.defaultValue""" % (
        upperCamelName, upperCamelName, doc, name, name, doc, defaultValue, upperCamelName, name,
        upperCamelName, name, name)

if __name__ == "__main__":
    print header
    print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n"
    print "from pyspark.ml.param import Param, Params\n\n"
    shared = [
        ("maxIter", "max number of iterations", "100"),
        ("regParam", "regularization constant", "0.1"),
        ("featuresCol", "features column name", "'features'"),
        ("labelCol", "label column name", "'label'"),
        ("predictionCol", "prediction column name", "'prediction'"),
        ("inputCol", "input column name", "'input'"),
        ("outputCol", "output column name", "'output'")]
    code = []
    for name, doc, defaultValue in shared:
        code.append(_gen_param_code(name, doc, defaultValue))
    print "\n\n\n".join(code)
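The script prints the generated module to stdout, so shared.py below is presumably refreshed with a redirect along these lines (the exact invocation is an assumption, not part of the commit):

python _gen_shared_params.py > shared.py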
146 changes: 146 additions & 0 deletions python/pyspark/ml/param/shared.py
@@ -0,0 +1,146 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# DO NOT MODIFY. The code is generated by _gen_shared_params.py.

from pyspark.ml.param import Param, Params


class HasMaxIter(Params):

    def __init__(self):
        super(HasMaxIter, self).__init__()
        #: max number of iterations
        self.maxIter = Param(self, "maxIter", "max number of iterations", 100)

    def setMaxIter(self, value):
        self.paramMap[self.maxIter] = value
        return self

    def getMaxIter(self, value):
        if self.maxIter in self.paramMap:
            return self.paramMap[self.maxIter]
        else:
            return self.defaultValue


class HasRegParam(Params):

    def __init__(self):
        super(HasRegParam, self).__init__()
        #: regularization constant
        self.regParam = Param(self, "regParam", "regularization constant", 0.1)

    def setRegParam(self, value):
        self.paramMap[self.regParam] = value
        return self

    def getRegParam(self, value):
        if self.regParam in self.paramMap:
            return self.paramMap[self.regParam]
        else:
            return self.defaultValue


class HasFeaturesCol(Params):

    def __init__(self):
        super(HasFeaturesCol, self).__init__()
        #: features column name
        self.featuresCol = Param(self, "featuresCol", "features column name", 'features')

    def setFeaturesCol(self, value):
        self.paramMap[self.featuresCol] = value
        return self

    def getFeaturesCol(self, value):
        if self.featuresCol in self.paramMap:
            return self.paramMap[self.featuresCol]
        else:
            return self.defaultValue


class HasLabelCol(Params):

    def __init__(self):
        super(HasLabelCol, self).__init__()
        #: label column name
        self.labelCol = Param(self, "labelCol", "label column name", 'label')

    def setLabelCol(self, value):
        self.paramMap[self.labelCol] = value
        return self

    def getLabelCol(self, value):
        if self.labelCol in self.paramMap:
            return self.paramMap[self.labelCol]
        else:
            return self.defaultValue


class HasPredictionCol(Params):

    def __init__(self):
        super(HasPredictionCol, self).__init__()
        #: prediction column name
        self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction')

    def setPredictionCol(self, value):
        self.paramMap[self.predictionCol] = value
        return self

    def getPredictionCol(self, value):
        if self.predictionCol in self.paramMap:
            return self.paramMap[self.predictionCol]
        else:
            return self.defaultValue


class HasInputCol(Params):

    def __init__(self):
        super(HasInputCol, self).__init__()
        #: input column name
        self.inputCol = Param(self, "inputCol", "input column name", 'input')

    def setInputCol(self, value):
        self.paramMap[self.inputCol] = value
        return self

    def getInputCol(self, value):
        if self.inputCol in self.paramMap:
            return self.paramMap[self.inputCol]
        else:
            return self.defaultValue


class HasOutputCol(Params):

    def __init__(self):
        super(HasOutputCol, self).__init__()
        #: output column name
        self.outputCol = Param(self, "outputCol", "output column name", 'output')

    def setOutputCol(self, value):
        self.paramMap[self.outputCol] = value
        return self

    def getOutputCol(self, value):
        if self.outputCol in self.paramMap:
            return self.paramMap[self.outputCol]
        else:
            return self.defaultValue
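Because each generated class is an independent mixin over Params, a pipeline component opts into exactly the shared params it needs by listing them as base classes; the cooperative super() calls then initialize every mixin's Param. A small hypothetical example (MyTransformer is illustrative only):

from pyspark.ml.param.shared import HasInputCol, HasOutputCol


class MyTransformer(HasInputCol, HasOutputCol):
    def __init__(self):
        super(MyTransformer, self).__init__()  # runs both mixin initializers


t = MyTransformer().setInputCol("text").setOutputCol("tokens")
print(t.paramMap[t.inputCol])  # -> text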
6 changes: 6 additions & 0 deletions python/pyspark/ml/util.py
@@ -27,3 +27,9 @@ def __init__(self):
        #: A unique id for the object. The default implementation
        #: concatenates the class name, "-", and 8 random hex chars.
        self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]

    def __str__(self):
        return self.uid

    def __repr__(self):
        return str(self)
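These methods are what Param.__str__ now leans on: str(param.parent) resolves to the parent's uid, so a param renders with a per-instance prefix. A quick sketch (the hex suffix is random for each object):

from pyspark.ml.param.shared import HasMaxIter

p = HasMaxIter()
print(str(p.maxIter))  # e.g. -> HasMaxIter-5bf9c866_maxIter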
