Skip to content

Commit

Permalink
[SPARK-29212][ML][PYSPARK] Add common classes without using JVM backend
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Implement common base ML classes (`Predictor`, `PredictionModel`, `Classifier`, `ClasssificationModel` `ProbabilisticClassifier`, `ProbabilisticClasssificationModel`, `Regressor`, `RegrssionModel`) for non-Java backends.

Note

- `Predictor` and `JavaClassifier` should be abstract as `_fit` method is not implemented.
- `PredictionModel` should be abstract as `_transform` is not implemented.

### Why are the changes needed?

To provide extensions points for non-JVM algorithms, as well as a public (as opposed to `Java*` variants, which are commonly described in docstrings as private) hierarchy which can be used to distinguish between different classes of predictors.

For longer discussion see [SPARK-29212](https://issues.apache.org/jira/browse/SPARK-29212) and / or #25776.

### Does this PR introduce any user-facing change?

It adds new base classes as listed above, but effective interfaces (method resolution order notwithstanding) stay the same.

Additionally "private" `Java*` classes in`ml.regression` and `ml.classification` have been renamed to follow PEP-8 conventions (added leading underscore).

It is for discussion if the same should be done to equivalent classes from `ml.wrapper`.

If we take `JavaClassifier` as an example, type hierarchy will change from

![old pyspark ml classification JavaClassifier](https://user-images.githubusercontent.com/1554276/72657093-5c0b0c80-39a0-11ea-9069-a897d75de483.png)

to

![new pyspark ml classification _JavaClassifier](https://user-images.githubusercontent.com/1554276/72657098-64fbde00-39a0-11ea-8f80-01187a5ea5a6.png)

Similarly the old model

![old pyspark ml classification JavaClassificationModel](https://user-images.githubusercontent.com/1554276/72657103-7513bd80-39a0-11ea-9ffc-59eb6ab61fde.png)

will become

![new pyspark ml classification _JavaClassificationModel](https://user-images.githubusercontent.com/1554276/72657110-80ff7f80-39a0-11ea-9f5c-fe408664e827.png)

### How was this patch tested?

Existing unit tests.

Closes #27245 from zero323/SPARK-29212.

Authored-by: zero323 <[email protected]>
Signed-off-by: zhengruifeng <[email protected]>
  • Loading branch information
zero323 authored and zhengruifeng committed Mar 4, 2020
1 parent 111e903 commit e1b3e9a
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 116 deletions.
6 changes: 4 additions & 2 deletions python/pyspark/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
DataFrame-based machine learning APIs to let users quickly assemble and configure practical
machine learning pipelines.
"""
from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer
from pyspark.ml.base import Estimator, Model, Predictor, PredictionModel, \
Transformer, UnaryTransformer
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml import classification, clustering, evaluation, feature, fpm, \
image, pipeline, recommendation, regression, stat, tuning, util, linalg, param

__all__ = [
"Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel",
"Transformer", "UnaryTransformer", "Estimator", "Model",
"Predictor", "PredictionModel", "Pipeline", "PipelineModel",
"classification", "clustering", "evaluation", "feature", "fpm", "image",
"recommendation", "regression", "stat", "tuning", "util", "linalg", "param",
]
81 changes: 80 additions & 1 deletion python/pyspark/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from abc import ABCMeta, abstractmethod
from abc import ABCMeta, abstractmethod, abstractproperty

import copy
import threading
Expand Down Expand Up @@ -246,3 +246,82 @@ def _transform(self, dataset):
transformedDataset = dataset.withColumn(self.getOutputCol(),
transformUDF(dataset[self.getInputCol()]))
return transformedDataset


@inherit_doc
class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
"""
Params for :py:class:`Predictor` and :py:class:`PredictorModel`.
.. versionadded:: 3.0.0
"""
pass


@inherit_doc
class Predictor(Estimator, _PredictorParams):
"""
Estimator for prediction tasks (regression and classification).
"""

__metaclass__ = ABCMeta

@since("3.0.0")
def setLabelCol(self, value):
"""
Sets the value of :py:attr:`labelCol`.
"""
return self._set(labelCol=value)

@since("3.0.0")
def setFeaturesCol(self, value):
"""
Sets the value of :py:attr:`featuresCol`.
"""
return self._set(featuresCol=value)

@since("3.0.0")
def setPredictionCol(self, value):
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self._set(predictionCol=value)


@inherit_doc
class PredictionModel(Model, _PredictorParams):
"""
Model for prediction tasks (regression and classification).
"""

__metaclass__ = ABCMeta

@since("3.0.0")
def setFeaturesCol(self, value):
"""
Sets the value of :py:attr:`featuresCol`.
"""
return self._set(featuresCol=value)

@since("3.0.0")
def setPredictionCol(self, value):
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self._set(predictionCol=value)

@abstractproperty
@since("2.1.0")
def numFeatures(self):
"""
Returns the number of features the model was trained on. If unknown, returns -1
"""
raise NotImplementedError()

@abstractmethod
@since("3.0.0")
def predict(self, value):
"""
Predict label for the given features.
"""
raise NotImplementedError()
Loading

0 comments on commit e1b3e9a

Please sign in to comment.