Skip to content

Commit

Permalink
use _fit/_transform/_evaluate to simplify the impl
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed May 12, 2015
1 parent 02abf13 commit 64a536c
Show file tree
Hide file tree
Showing 10 changed files with 150 additions and 122 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
from pyspark.ml.param import *
from pyspark.ml.pipeline import *

__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"]
__all__ = ["Param", "Params", "Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
2 changes: 1 addition & 1 deletion python/pyspark/ml/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def setMetricName(self, value):
"""
Sets the value of :py:attr:`metricName`.
"""
self.paramMap[self.metricName] = value
self._paramMap[self.metricName] = value
return self

def getMetricName(self):
Expand Down
28 changes: 14 additions & 14 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def setThreshold(self, value):
"""
Sets the value of :py:attr:`threshold`.
"""
self.paramMap[self.threshold] = value
self._paramMap[self.threshold] = value
return self

def getThreshold(self):
Expand Down Expand Up @@ -171,7 +171,7 @@ def setMinDocFreq(self, value):
"""
Sets the value of :py:attr:`minDocFreq`.
"""
self.paramMap[self.minDocFreq] = value
self._paramMap[self.minDocFreq] = value
return self

def getMinDocFreq(self):
Expand Down Expand Up @@ -234,7 +234,7 @@ def setP(self, value):
"""
Sets the value of :py:attr:`p`.
"""
self.paramMap[self.p] = value
self._paramMap[self.p] = value
return self

def getP(self):
Expand Down Expand Up @@ -299,7 +299,7 @@ def setIncludeFirst(self, value):
"""
Sets the value of :py:attr:`includeFirst`.
"""
self.paramMap[self.includeFirst] = value
self._paramMap[self.includeFirst] = value
return self

def getIncludeFirst(self):
Expand Down Expand Up @@ -356,7 +356,7 @@ def setDegree(self, value):
"""
Sets the value of :py:attr:`degree`.
"""
self.paramMap[self.degree] = value
self._paramMap[self.degree] = value
return self

def getDegree(self):
Expand Down Expand Up @@ -430,7 +430,7 @@ def setMinTokenLength(self, value):
"""
Sets the value of :py:attr:`minTokenLength`.
"""
self.paramMap[self.minTokenLength] = value
self._paramMap[self.minTokenLength] = value
return self

def getMinTokenLength(self):
Expand All @@ -443,7 +443,7 @@ def setGaps(self, value):
"""
Sets the value of :py:attr:`gaps`.
"""
self.paramMap[self.gaps] = value
self._paramMap[self.gaps] = value
return self

def getGaps(self):
Expand All @@ -456,7 +456,7 @@ def setPattern(self, value):
"""
Sets the value of :py:attr:`pattern`.
"""
self.paramMap[self.pattern] = value
self._paramMap[self.pattern] = value
return self

def getPattern(self):
Expand Down Expand Up @@ -511,7 +511,7 @@ def setWithMean(self, value):
"""
Sets the value of :py:attr:`withMean`.
"""
self.paramMap[self.withMean] = value
self._paramMap[self.withMean] = value
return self

def getWithMean(self):
Expand All @@ -524,7 +524,7 @@ def setWithStd(self, value):
"""
Sets the value of :py:attr:`withStd`.
"""
self.paramMap[self.withStd] = value
self._paramMap[self.withStd] = value
return self

def getWithStd(self):
Expand Down Expand Up @@ -754,7 +754,7 @@ def setMaxCategories(self, value):
"""
Sets the value of :py:attr:`maxCategories`.
"""
self.paramMap[self.maxCategories] = value
self._paramMap[self.maxCategories] = value
return self

def getMaxCategories(self):
Expand Down Expand Up @@ -823,7 +823,7 @@ def setVectorSize(self, value):
"""
Sets the value of :py:attr:`vectorSize`.
"""
self.paramMap[self.vectorSize] = value
self._paramMap[self.vectorSize] = value
return self

def getVectorSize(self):
Expand All @@ -836,7 +836,7 @@ def setNumPartitions(self, value):
"""
Sets the value of :py:attr:`numPartitions`.
"""
self.paramMap[self.numPartitions] = value
self._paramMap[self.numPartitions] = value
return self

def getNumPartitions(self):
Expand All @@ -849,7 +849,7 @@ def setMinCount(self, value):
"""
Sets the value of :py:attr:`minCount`.
"""
self.paramMap[self.minCount] = value
self._paramMap[self.minCount] = value
return self

def getMinCount(self):
Expand Down
55 changes: 30 additions & 25 deletions python/pyspark/ml/param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#

from abc import ABCMeta

import copy

from pyspark.ml.util import Identifiable
Expand All @@ -31,9 +30,9 @@ class Param(object):
"""

def __init__(self, parent, name, doc):
if not isinstance(parent, Params):
raise TypeError("Parent must be a Params but got type %s." % type(parent))
self.parent = parent
if not isinstance(parent, Identifiable):
raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
self.parent = parent.uid
self.name = str(name)
self.doc = str(doc)

Expand All @@ -43,6 +42,13 @@ def __str__(self):
def __repr__(self):
return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)

def __hash__(self):
    """
    Hashes this param by its string form, i.e. whatever ``str(self)``
    yields.
    """
    text = str(self)
    return hash(text)

def __eq__(self, other):
    """
    Compares this param with another object.

    Two params are equal when both their ``parent`` and ``name``
    attributes are equal.

    :param other: object to compare against
    :return: True if ``other`` is an equal Param, False otherwise
    """
    if not isinstance(other, Param):
        # Return an explicit bool rather than falling through to an
        # implicit None when the operand is not a Param.
        return False
    return self.parent == other.parent and self.name == other.name


class Params(Identifiable):
"""
Expand All @@ -53,10 +59,10 @@ class Params(Identifiable):
__metaclass__ = ABCMeta

#: internal param map for user-supplied values
paramMap = {}
_paramMap = {}

#: internal param map for default values
defaultParamMap = {}
_defaultParamMap = {}

#: value returned by :py:func:`params`
_params = None
Expand All @@ -73,18 +79,18 @@ def params(self):
[getattr(self, x) for x in dir(self) if x != "params"]))
return self._params

def _explain(self, param):
def explainParam(self, param):
"""
Explains a single param and returns its name, doc, and optional
default value and user-supplied value in a string.
"""
param = self._resolveParam(param)
values = []
if self.isDefined(param):
if param in self.defaultParamMap:
values.append("default: %s" % self.defaultParamMap[param])
if param in self.paramMap:
values.append("current: %s" % self.paramMap[param])
if param in self._defaultParamMap:
values.append("default: %s" % self._defaultParamMap[param])
if param in self._paramMap:
values.append("current: %s" % self._paramMap[param])
else:
values.append("undefined")
valueStr = "(" + ", ".join(values) + ")"
Expand All @@ -95,7 +101,7 @@ def explainParams(self):
Returns the documentation of all params with their optional
default values and user-supplied values.
"""
return "\n".join([self._explain(param) for param in self.params])
return "\n".join([self.explainParam(param) for param in self.params])

def getParam(self, paramName):
"""
Expand All @@ -112,14 +118,14 @@ def isSet(self, param):
Checks whether a param is explicitly set by user.
"""
param = self._resolveParam(param)
return param in self.paramMap
return param in self._paramMap

def hasDefault(self, param):
"""
Checks whether a param has a default value.
"""
param = self._resolveParam(param)
return param in self.defaultParamMap
return param in self._defaultParamMap

def isDefined(self, param):
"""
Expand All @@ -141,10 +147,10 @@ def getOrDefault(self, param):
default value. Raises an error if neither is set.
"""
if isinstance(param, Param):
if param in self.paramMap:
return self.paramMap[param]
if param in self._paramMap:
return self._paramMap[param]
else:
return self.defaultParamMap[param]
return self._defaultParamMap[param]
elif isinstance(param, str):
return self.getOrDefault(self.getParam(param))
else:
Expand All @@ -160,8 +166,8 @@ def extractParamMap(self, extra={}):
:param extra: extra param values
:return: merged param map
"""
paramMap = self.defaultParamMap.copy()
paramMap.update(self.paramMap)
paramMap = self._defaultParamMap.copy()
paramMap.update(self._paramMap)
paramMap.update(extra)
return paramMap

Expand All @@ -178,15 +184,14 @@ def copy(self, extra={}):
:return: Copy of this instance
"""
that = copy.copy(self)
that.uid = that._generateUID()
that.paramMap = copy.deepcopy(self.paramMap)
return self._copyValues(that, extra)
that._paramMap = copy.deepcopy(self.extractParamMap(extra))
return that

def _shouldOwn(self, param):
"""
Validates that the input param belongs to this Params instance.
"""
if param.parent is not self:
if not (self.uid == param.parent and self.hasParam(param.name)):
raise ValueError("Param %r does not belong to %r." % (param, self))

def _resolveParam(self, param):
Expand Down Expand Up @@ -219,15 +224,15 @@ def _set(self, **kwargs):
Sets user-supplied params.
"""
for param, value in kwargs.items():
self.paramMap[getattr(self, param)] = value
self._paramMap[getattr(self, param)] = value
return self

def _setDefault(self, **kwargs):
"""
Sets default params.
"""
for param, value in kwargs.items():
self.defaultParamMap[getattr(self, param)] = value
self._defaultParamMap[getattr(self, param)] = value
return self

def _copyValues(self, to, extra={}):
Expand Down
Loading

0 comments on commit 64a536c

Please sign in to comment.