refactor #2

Merged
merged 5 commits on Jan 28, 2015
Changes from 3 commits
1 change: 1 addition & 0 deletions python/docs/pyspark.rst
@@ -9,6 +9,7 @@ Subpackages

    pyspark.sql
    pyspark.streaming
    pyspark.ml
    pyspark.mllib

Contents
309 changes: 3 additions & 306 deletions python/pyspark/ml/__init__.py
@@ -15,310 +15,7 @@
# limitations under the License.
#

from abc import ABCMeta, abstractmethod, abstractproperty
from pyspark.ml.param import *
from pyspark.ml.pipeline import *

from pyspark import SparkContext
from pyspark.sql import SchemaRDD, inherit_doc # TODO: move inherit_doc to Spark Core
from pyspark.ml.param import Param, Params
from pyspark.ml.util import Identifiable

__all__ = ["Pipeline", "Transformer", "Estimator", "param", "feature", "classification"]


def _jvm():
    """
    Returns the JVM view associated with SparkContext. Must be called
    after SparkContext is initialized.
    """
    jvm = SparkContext._jvm
    if jvm:
        return jvm
    else:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")


@inherit_doc
class PipelineStage(Params):
    """
    A stage in a pipeline, either an :py:class:`Estimator` or a
    :py:class:`Transformer`.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        super(PipelineStage, self).__init__()


@inherit_doc
class Estimator(PipelineStage):
    """
    Abstract class for estimators that fit models to data.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        super(Estimator, self).__init__()

    @abstractmethod
    def fit(self, dataset, params={}):
        """
        Fits a model to the input dataset with optional parameters.

        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.SchemaRDD`
        :param params: an optional param map that overwrites embedded
                       params
        :returns: fitted model
        """
        raise NotImplementedError()


@inherit_doc
class Transformer(PipelineStage):
    """
    Abstract class for transformers that transform one dataset into
    another.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        super(Transformer, self).__init__()

    @abstractmethod
    def transform(self, dataset, params={}):
        """
        Transforms the input dataset with optional parameters.

        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.SchemaRDD`
        :param params: an optional param map that overwrites embedded
                       params
        :returns: transformed dataset
        """
        raise NotImplementedError()


@inherit_doc
class Model(Transformer):
    """
    Abstract class for models fitted by :py:class:`Estimator`s.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        super(Model, self).__init__()


@inherit_doc
class Pipeline(Estimator):
    """
    A simple pipeline, which acts as an estimator. A Pipeline consists
    of a sequence of stages, each of which is either an
    :py:class:`Estimator` or a :py:class:`Transformer`. When
    :py:meth:`Pipeline.fit` is called, the stages are executed in
    order. If a stage is an :py:class:`Estimator`, its
    :py:meth:`Estimator.fit` method will be called on the input
    dataset to fit a model. Then the model, which is a transformer,
    will be used to transform the dataset as the input to the next
    stage. If a stage is a :py:class:`Transformer`, its
    :py:meth:`Transformer.transform` method will be called to produce
    the dataset for the next stage. The fitted model from a
    :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
    consists of fitted models and transformers, corresponding to the
    pipeline stages. If there are no stages, the pipeline acts as an
    identity transformer.
    """

    def __init__(self):
        super(Pipeline, self).__init__()
        #: Param for pipeline stages.
        self.stages = Param(self, "stages", "pipeline stages")

    def setStages(self, value):
        """
        Set pipeline stages.

        :param value: a list of transformers or estimators
        :return: the pipeline instance
        """
        self.paramMap[self.stages] = value
        return self

    def getStages(self):
        """
        Get pipeline stages.
        """
        if self.stages in self.paramMap:
            return self.paramMap[self.stages]

    def fit(self, dataset, params={}):
        paramMap = self._merge_params(params)
        stages = paramMap[self.stages]
        for stage in stages:
            if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
                raise ValueError(
                    "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
        indexOfLastEstimator = -1
        for i, stage in enumerate(stages):
            if isinstance(stage, Estimator):
                indexOfLastEstimator = i
        transformers = []
        for i, stage in enumerate(stages):
            if i <= indexOfLastEstimator:
                if isinstance(stage, Transformer):
                    transformers.append(stage)
                    dataset = stage.transform(dataset, paramMap)
                else:  # must be an Estimator
                    model = stage.fit(dataset, paramMap)
                    transformers.append(model)
                    if i < indexOfLastEstimator:
                        dataset = model.transform(dataset, paramMap)
            else:
                transformers.append(stage)
        return PipelineModel(transformers)


@inherit_doc
class PipelineModel(Model):
    """
    Represents a compiled pipeline with transformers and fitted models.
    """

    def __init__(self, transformers):
        super(PipelineModel, self).__init__()
        self.transformers = transformers

    def transform(self, dataset, params={}):
        paramMap = self._merge_params(params)
        for t in self.transformers:
            dataset = t.transform(dataset, paramMap)
        return dataset
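
For intuition, here is a self-contained, pure-Python sketch of the control flow in Pipeline.fit above: stages up to and including the last Estimator execute during fit, while any Transformer after the last Estimator is carried into the PipelineModel untouched. AddOne and MeanFitter are hypothetical stand-ins for illustration, not pyspark classes.

# Pure-Python sketch of Pipeline.fit's control flow (no Spark required).
# AddOne and MeanFitter are hypothetical stand-ins, not pyspark classes.
class AddOne(object):                      # plays the role of a Transformer
    def transform(self, dataset):
        return [x + 1 for x in dataset]


class MeanFitter(object):                  # plays the role of an Estimator
    def fit(self, dataset):
        mean = float(sum(dataset)) / len(dataset)

        class MeanSubtractor(object):      # the fitted model, itself a transformer
            def transform(self, dataset):
                return [x - mean for x in dataset]
        return MeanSubtractor()


stages = [AddOne(), MeanFitter(), AddOne()]
dataset = [0.0, 2.0, 4.0]
indexOfLastEstimator = max(i for i, s in enumerate(stages) if hasattr(s, "fit"))
transformers = []
for i, stage in enumerate(stages):
    if i <= indexOfLastEstimator:
        if hasattr(stage, "transform"):    # Transformer: transform and keep
            transformers.append(stage)
            dataset = stage.transform(dataset)
        else:                              # Estimator: fit and keep the model
            model = stage.fit(dataset)
            transformers.append(model)
            if i < indexOfLastEstimator:
                dataset = model.transform(dataset)
    else:                                  # past the last Estimator: keep as-is
        transformers.append(stage)
# transformers now holds [AddOne, MeanSubtractor, AddOne], exactly what
# PipelineModel(transformers) would receive.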


@inherit_doc
class JavaWrapper(Params):
    """
    Utility class to help create wrapper classes from Java/Scala
    implementations of pipeline components.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        super(JavaWrapper, self).__init__()

    @abstractproperty
    def _java_class(self):
        """
        Fully-qualified class name of the wrapped Java component.
        """
        raise NotImplementedError

    def _java_obj(self):
        """
        Returns or creates a Java object.
        """
        java_obj = _jvm()
        for name in self._java_class.split("."):
            java_obj = getattr(java_obj, name)
        return java_obj()

    def _transfer_params_to_java(self, params, java_obj):
        """
        Transfers the embedded params and additional params to the
        input Java object.

        :param params: additional params (overwriting embedded values)
        :param java_obj: Java object to receive the params
        """
        paramMap = self._merge_params(params)
        for param in self.params:
            if param in paramMap:
                java_obj.set(param.name, paramMap[param])

    def _empty_java_param_map(self):
        """
        Returns an empty Java ParamMap reference.
        """
        return _jvm().org.apache.spark.ml.param.ParamMap()

    def _create_java_param_map(self, params, java_obj):
        paramMap = self._empty_java_param_map()
        for param, value in params.items():
            if param.parent is self:
                paramMap.put(java_obj.getParam(param.name), value)
        return paramMap


@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
    """
    Base class for :py:class:`Estimator`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        super(JavaEstimator, self).__init__()

    @abstractmethod
    def _create_model(self, java_model):
        """
        Creates a model from the input Java model reference.
        """
        raise NotImplementedError

    def _fit_java(self, dataset, params={}):
        """
        Fits a Java model to the input dataset.

        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.SchemaRDD`
        :param params: additional params (overwriting embedded values)
        :return: fitted Java model
        """
        java_obj = self._java_obj()
        self._transfer_params_to_java(params, java_obj)
        return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map())

    def fit(self, dataset, params={}):
        java_model = self._fit_java(dataset, params)
        return self._create_model(java_model)


@inherit_doc
class JavaTransformer(Transformer, JavaWrapper):
    """
    Base class for :py:class:`Transformer`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        super(JavaTransformer, self).__init__()

    def transform(self, dataset, params={}):
        java_obj = self._java_obj()
        self._transfer_params_to_java({}, java_obj)
        java_param_map = self._create_java_param_map(params, java_obj)
        return SchemaRDD(java_obj.transform(dataset._jschema_rdd, java_param_map),
                         dataset.sql_ctx)


@inherit_doc
class JavaModel(JavaTransformer):
    """
    Base class for :py:class:`Model`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        super(JavaModel, self).__init__()

    def _java_obj(self):
        return self._java_model
__all__ = ["Pipeline", "Transformer", "Estimator"]
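
For reference, a minimal end-to-end sketch of the refactored API, modeled on the doctest in classification.py below. It assumes a live SparkContext `sc` and SQLContext `sqlCtx`, and that the shared-param mixins provide chainable setters such as setMaxIter and setRegParam; none of this is guaranteed by the diff itself.

# Hedged usage sketch (not part of this diff); assumes a running SparkContext
# `sc` and SQLContext `sqlCtx`, per the classification.py doctest below.
from pyspark.sql import Row
from pyspark.mllib.linalg import Vectors
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression

training = sqlCtx.inferSchema(sc.parallelize([
    Row(label=1.0, features=Vectors.dense(1.0)),
    Row(label=0.0, features=Vectors.dense(-1.0))]))
lr = LogisticRegression().setMaxIter(5).setRegParam(0.01)  # setters assumed from the mixins
pipeline = Pipeline().setStages([lr])    # a one-stage pipeline, for illustration
model = pipeline.fit(training)           # fits lr, returns a PipelineModel
test = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
print model.transform(test).first().prediction  # expected: 0.0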
23 changes: 5 additions & 18 deletions python/pyspark/ml/classification.py
@@ -15,12 +15,14 @@
# limitations under the License.
#

from pyspark.sql import inherit_doc
from pyspark.ml import JavaEstimator, JavaModel
from pyspark.ml.util import JavaEstimator, JavaModel, inherit_doc
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
    HasRegParam


__all__ = ['LogisticRegression', 'LogisticRegressionModel']


@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                         HasRegParam):
@@ -43,32 +45,17 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
    >>> print model.transform(test1).first().prediction
    1.0
    """

    def __init__(self):
        super(LogisticRegression, self).__init__()

    @property
    def _java_class(self):
        return "org.apache.spark.ml.classification.LogisticRegression"
    _java_class = "org.apache.spark.ml.classification.LogisticRegression"

    def _create_model(self, java_model):
        return LogisticRegressionModel(java_model)


@inherit_doc
class LogisticRegressionModel(JavaModel):
    """
    Model fitted by LogisticRegression.
    """

    def __init__(self, java_model):
        super(LogisticRegressionModel, self).__init__()
        self._java_model = java_model

    @property
    def _java_class(self):
        return "org.apache.spark.ml.classification.LogisticRegressionModel"


if __name__ == "__main__":
    import doctest
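
Taken together, the classification.py change shows the pattern this refactor establishes: a Java-backed estimator now declares `_java_class` as a plain class attribute and implements `_create_model`, instead of overriding a property plus boilerplate `__init__`. A hypothetical new wrapper would follow the same shape; the Scala class name below is illustrative only and is not claimed to ship with this PR.

# Hypothetical wrapper following the pattern established above. The class
# "org.apache.spark.ml.classification.NaiveBayes" is illustrative only; this
# sketch shows the shape of a wrapper, not a component shipped by this PR.
from pyspark.ml.util import JavaEstimator, JavaModel, inherit_doc


@inherit_doc
class NaiveBayes(JavaEstimator):
    """
    Sketch of a Java-backed estimator under the refactored API.
    """

    _java_class = "org.apache.spark.ml.classification.NaiveBayes"

    def _create_model(self, java_model):
        return NaiveBayesModel(java_model)


@inherit_doc
class NaiveBayesModel(JavaModel):
    """
    Model fitted by the hypothetical NaiveBayes estimator.
    """

    def __init__(self, java_model):
        super(NaiveBayesModel, self).__init__()
        self._java_model = java_model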