Skip to content

Commit

Permalink
more docs
Browse files Browse the repository at this point in the history
optimize pipeline.fit impl
  • Loading branch information
mengxr committed Jan 26, 2015
1 parent 56de571 commit d3e8dbe
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 38 deletions.
117 changes: 85 additions & 32 deletions python/pyspark/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,15 @@


def _jvm():
return SparkContext._jvm


def _inherit_doc(cls):
for name, func in vars(cls).items():
# only inherit docstring for public functions
if name.startswith("_"):
continue
if not func.__doc__:
for parent in cls.__bases__:
parent_func = getattr(parent, name, None)
if parent_func and getattr(parent_func, "__doc__", None):
func.__doc__ = parent_func.__doc__
break
return cls
"""
Returns the JVM view associated with SparkContext. Must be called
after SparkContext is initialized.
"""
jvm = SparkContext._jvm
if jvm:
return jvm
else:
raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")


@inherit_doc
Expand All @@ -50,6 +44,8 @@ class PipelineStage(Params):
:py:class:`Transformer`.
"""

__metaclass__ = ABCMeta

def __init__(self):
super(PipelineStage, self).__init__()

Expand Down Expand Up @@ -147,38 +143,54 @@ def getStages(self):
return self.paramMap[self.stages]

def fit(self, dataset, params={}):
    """
    Fits the pipeline to the input dataset.

    Every stage up to and including the last :py:class:`Estimator`
    is applied in order: a :py:class:`Transformer` transforms the
    running dataset, an :py:class:`Estimator` is fitted on it and
    the resulting model then transforms it. Stages after the last
    estimator are only collected, since nothing downstream needs
    their output for fitting.

    :param dataset: input dataset
    :param params: additional params (overwriting embedded values)
    :return: a :py:class:`PipelineModel` holding the transformers
             and fitted models, in stage order
    :raises ValueError: if a stage is neither an Estimator nor a
                        Transformer
    """
    paramMap = self._merge_params(params)
    # The merged param map is a dict keyed by Param, so the stages must be
    # looked up by subscript; calling the map would raise TypeError.
    stages = paramMap[self.stages]
    # Validate all stages up front so nothing is fitted before a bad stage
    # is reported.
    for stage in stages:
        if not isinstance(stage, (Estimator, Transformer)):
            raise ValueError(
                "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
    # Remember where the last estimator sits; stages past it need no data.
    indexOfLastEstimator = -1
    for i, stage in enumerate(stages):
        if isinstance(stage, Estimator):
            indexOfLastEstimator = i
    transformers = []
    for i, stage in enumerate(stages):
        if i <= indexOfLastEstimator:
            if isinstance(stage, Transformer):
                transformers.append(stage)
                dataset = stage.transform(dataset, paramMap)
            else:  # must be an Estimator
                model = stage.fit(dataset, paramMap)
                transformers.append(model)
                dataset = model.transform(dataset, paramMap)
        else:
            transformers.append(stage)
    return PipelineModel(transformers)


@inherit_doc
class PipelineModel(Transformer):
    """
    Represents a compiled pipeline with transformers and fitted models.
    """

    def __init__(self, transformers):
        """
        :param transformers: transformers and fitted models to apply,
                             in the order they should run
        """
        super(PipelineModel, self).__init__()
        self.transformers = transformers

    def transform(self, dataset, params={}):
        """
        Applies every stored transformer to the dataset in order.

        :param dataset: input dataset
        :param params: additional params (overwriting embedded values)
        :return: the dataset returned by the last transformer, or the
                 input dataset unchanged if there are no transformers
        """
        # Merge once and hand the same map to every stage; the name avoids
        # shadowing the builtin ``map``.
        paramMap = self._merge_params(params)
        for t in self.transformers:
            dataset = t.transform(dataset, paramMap)
        return dataset


@inherit_doc
class JavaWrapper(object):
class JavaWrapper(Params):
"""
Utility class to help create wrapper classes from Java/Scala
implementations of pipeline components.
"""

__metaclass__ = ABCMeta

Expand All @@ -187,17 +199,45 @@ def __init__(self):

@abstractproperty
def _java_class(self):
"""
Fully-qualified class name of the wrapped Java component.
"""
raise NotImplementedError

def _create_java_obj(self):
"""
Creates a new Java object and returns its reference.
"""
java_obj = _jvm()
for name in self._java_class.split("."):
java_obj = getattr(java_obj, name)
return java_obj()

def _transfer_params_to_java(self, params, java_obj):
"""
Transfers the embedded params and additional params to the
input Java object.
:param params: additional params (overwriting embedded values)
:param java_obj: Java object to receive the params
"""
paramMap = self._merge_params(params)
for param in self.params():
if param in paramMap:
java_obj.set(param.name, paramMap[param])

def _empty_java_param_map(self):
"""
Returns an empty Java ParamMap reference.
"""
return _jvm().org.apache.spark.ml.param.ParamMap()


@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
"""
Base class for every :py:class:`Estimator` that wraps a Java/Scala
implementation.
"""

__metaclass__ = ABCMeta

Expand All @@ -206,12 +246,22 @@ def __init__(self):

@abstractmethod
def _create_model(self, java_model):
"""
Creates a model from the input Java model reference.
"""
raise NotImplementedError

def _fit_java(self, dataset, params={}):
"""
Fits a Java model to the input dataset.
:param dataset: input dataset, which is an instance of
:py:class:`pyspark.sql.SchemaRDD`
:param params: additional params (overwriting embedded values)
:return: fitted Java model
"""
java_obj = self._create_java_obj()
self._transfer_params_to_java(params, java_obj)
return java_obj.fit(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map())

def fit(self, dataset, params={}):
java_model = self._fit_java(dataset, params)
Expand All @@ -220,6 +270,10 @@ def fit(self, dataset, params={}):

@inherit_doc
class JavaTransformer(Transformer, JavaWrapper):
"""
Base class for every :py:class:`Transformer` that wraps a Java/Scala
implementation.
"""

__metaclass__ = ABCMeta

Expand All @@ -229,6 +283,5 @@ def __init__(self):
def transform(self, dataset, params={}):
java_obj = self._create_java_obj()
self._transfer_params_to_java(params, java_obj)
return SchemaRDD(java_obj.transform(dataset._jschema_rdd,
_jvm().org.apache.spark.ml.param.ParamMap()),
return SchemaRDD(java_obj.transform(dataset._jschema_rdd, self._empty_java_param_map()),
dataset.sql_ctx)
1 change: 1 addition & 0 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class LogisticRegressionModel(Transformer):
"""

def __init__(self, java_model):
super(LogisticRegressionModel, self).__init__()
self._java_model = java_model

def transform(self, dataset, params={}):
Expand Down
6 changes: 0 additions & 6 deletions python/pyspark/ml/param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,3 @@ def _merge_params(self, params):
map = self.paramMap.copy()
map.update(params)
return map

def _transfer_params_to_java(self, params, java_obj):
map = self._merge_params(params)
for param in self.params():
if param in map:
java_obj.set(param.name, map[param])

0 comments on commit d3e8dbe

Please sign in to comment.