Skip to content

Commit

Permalink
[SPARK-28969][PYTHON][ML] OneVsRestParams parity between scala and py…
Browse files Browse the repository at this point in the history
…thon

### What changes were proposed in this pull request?
Follow the scala ```OneVsRestParams``` implementation, move ```setClassifier``` from ```OneVsRestParams``` to ```OneVsRest``` in Pyspark

### Why are the changes needed?
1. Maintain the parity between scala and python code.
2. ```Classifier``` can only be set in the estimator.

### Does this PR introduce any user-facing change?
Yes.
Previous behavior: ```OneVsRestModel``` has method ```setClassifier```
Current behavior:  ```setClassifier``` is removed from ```OneVsRestModel```. ```classifier``` can only be set in ```OneVsRest```.

### How was this patch tested?
Use existing tests

Closes #25715 from huaxingao/spark-28969.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
  • Loading branch information
huaxingao authored and srowen committed Sep 13, 2019
1 parent fcf9b41 commit 77e9b58
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1872,15 +1872,6 @@ class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCo

classifier = Param(Params._dummy(), "classifier", "base binary classifier")

@since("2.0.0")
def setClassifier(self, value):
"""
Sets the value of :py:attr:`classifier`.
.. note:: Only LogisticRegression and NaiveBayes are supported now.
"""
return self._set(classifier=value)

@since("2.0.0")
def getClassifier(self):
"""
Expand Down Expand Up @@ -1959,6 +1950,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
kwargs = self._input_kwargs
return self._set(**kwargs)

@since("2.0.0")
def setClassifier(self, value):
"""
Sets the value of :py:attr:`classifier`.
"""
return self._set(classifier=value)

def _fit(self, dataset):
labelCol = self.getLabelCol()
featuresCol = self.getFeaturesCol()
Expand Down Expand Up @@ -2212,7 +2210,8 @@ def _from_java(cls, java_stage):
classifier = JavaParams._from_java(java_stage.getClassifier())
models = [JavaParams._from_java(model) for model in java_stage.models()]
py_stage = cls(models=models).setPredictionCol(predictionCol).setLabelCol(labelCol)\
.setFeaturesCol(featuresCol).setClassifier(classifier)
.setFeaturesCol(featuresCol)
py_stage._set(classifier=classifier)
py_stage._resetUid(java_stage.uid())
return py_stage

Expand Down

0 comments on commit 77e9b58

Please sign in to comment.