[SPARK-28985][PYTHON][ML][FOLLOW-UP] Add _AFTSurvivalRegressionParams
### What changes were proposed in this pull request?

Adds

```python
class _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol,
                                   HasMaxIter, HasTol, HasFitIntercept,
                                   HasAggregationDepth): ...
```

with related Params, and uses it to replace `HasFitIntercept`, `HasMaxIter`, `HasTol` and `HasAggregationDepth` in the `AFTSurvivalRegression` base classes and `JavaPredictionModel` in the `AFTSurvivalRegressionModel` base classes.

### Why are the changes needed?

Previous work (#25776) on [SPARK-28985](https://issues.apache.org/jira/browse/SPARK-28985) replaced `JavaEstimator`, `HasFeaturesCol`, `HasLabelCol` and `HasPredictionCol` in `AFTSurvivalRegression`, and `JavaModel` in `AFTSurvivalRegressionModel`, with the newly added `JavaPredictor`:

https://github.com/apache/spark/blob/e97b55d32285052a1f76cca35377c4b21eb2e7d7/python/pyspark/ml/wrapper.py#L377

and `JavaPredictionModel`

https://github.com/apache/spark/blob/e97b55d32285052a1f76cca35377c4b21eb2e7d7/python/pyspark/ml/wrapper.py#L405

respectively.

This, however, is inconsistent with the Scala counterpart, where both classes extend private `AFTSurvivalRegressionBase`

https://github.com/apache/spark/blob/eb037a8180be4ab7570eda1fa9cbf3c84b92c3f7/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala#L48-L50

That approach preserves some of the existing inconsistencies (variables as defined in [the official example](https://github.com/apache/spark/blob/master/examples/src/main/python/ml/aft_survival_regression.p))

```python
from pyspark.ml.regression import AFTSurvivalRegression, AFTSurvivalRegressionModel
from pyspark.ml.param.shared import HasMaxIter, HasTol, HasFitIntercept, HasAggregationDepth
from pyspark.ml.param import Param

issubclass(AFTSurvivalRegressionModel, HasMaxIter)
# False
hasattr(model, "maxIter") and isinstance(model.maxIter, Param)
# True

issubclass(AFTSurvivalRegressionModel, HasTol)
# False
hasattr(model, "tol") and isinstance(model.tol, Param)
# True
```

and can cause problems in the future if the Predictor / PredictionModel API changes (unlike [`IsotonicRegression`](#26023), the current implementation is, technically speaking, correct, though incomplete).
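For illustration only, the difference between exposing a Param attribute and inheriting from the shared mixin can be sketched in plain Python, without Spark. Every class below is a simplified, hypothetical stand-in (none of them is the actual `pyspark` class); the sketch only shows why a common Params base class makes `issubclass` checks agree with the attributes a model exposes.

```python
class Param:
    """Simplified stand-in for pyspark.ml.param.Param (hypothetical)."""

    def __init__(self, name, doc):
        self.name = name
        self.doc = doc


class HasMaxIter:
    """Stand-in for the shared HasMaxIter mixin."""
    maxIter = Param("maxIter", "max number of iterations (>= 0).")


class HasTol:
    """Stand-in for the shared HasTol mixin."""
    tol = Param("tol", "convergence tolerance for iterative algorithms.")


class _SurvivalParams(HasMaxIter, HasTol):
    """Hypothetical shared base, analogous in role to _AFTSurvivalRegressionParams."""
    censorCol = Param("censorCol", "censor column name.")


class ModelWithMixin(_SurvivalParams):
    """Model that inherits the shared Params base."""


class ModelWithoutMixin:
    """Model that exposes the same attributes without inheriting the mixins."""
    maxIter = HasMaxIter.maxIter
    tol = HasTol.tol


# Attribute checks pass for both models ...
both_expose_params = all(
    isinstance(getattr(cls, name), Param)
    for cls in (ModelWithMixin, ModelWithoutMixin)
    for name in ("maxIter", "tol")
)

# ... but only the mixin-based model passes the subclass checks.
with_mixin_is_subclass = issubclass(ModelWithMixin, (HasMaxIter, HasTol))
without_mixin_is_subclass = issubclass(ModelWithoutMixin, HasMaxIter)

print(both_expose_params, with_mixin_is_subclass, without_mixin_is_subclass)
```

Code that dispatches on `isinstance` or `issubclass` against the shared mixins would treat the two stand-in models differently, even though both expose the same Params.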

### Does this PR introduce any user-facing change?

Yes, it adds a number of base classes to `AFTSurvivalRegressionModel`. These changes are purely additive and have negligible potential for breaking existing code (and none, compared to the changes already made in #25776). Additionally, the affected API hasn't been released in its current form yet.

### How was this patch tested?

- Existing unit tests.
- Manual testing.

CC huaxingao, zhengruifeng

Closes #26024 from zero323/SPARK-28985-FOLLOW-UP-aftsurival-regression.

Authored-by: zero323 <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
zero323 authored and srowen committed Oct 4, 2019
1 parent 228b1ea commit df22535
Showing 1 changed file with 51 additions and 38 deletions.
89 changes: 51 additions & 38 deletions python/pyspark/ml/regression.py
@@ -1480,9 +1480,56 @@ def evaluateEachIteration(self, dataset, loss):
         return self._call_java("evaluateEachIteration", dataset, loss)


+class _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol,
+                                   HasMaxIter, HasTol, HasFitIntercept,
+                                   HasAggregationDepth):
+    """
+    Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`.
+
+    .. versionadded:: 3.0.0
+    """
+
+    censorCol = Param(
+        Params._dummy(), "censorCol",
+        "censor column name. The value of this column could be 0 or 1. " +
+        "If the value is 1, it means the event has occurred i.e. " +
+        "uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
+    quantileProbabilities = Param(
+        Params._dummy(), "quantileProbabilities",
+        "quantile probabilities array. Values of the quantile probabilities array " +
+        "should be in the range (0, 1) and the array should be non-empty.",
+        typeConverter=TypeConverters.toListFloat)
+    quantilesCol = Param(
+        Params._dummy(), "quantilesCol",
+        "quantiles column name. This column will output quantiles of " +
+        "corresponding quantileProbabilities if it is set.",
+        typeConverter=TypeConverters.toString)
+
+    @since("1.6.0")
+    def getCensorCol(self):
+        """
+        Gets the value of censorCol or its default value.
+        """
+        return self.getOrDefault(self.censorCol)
+
+    @since("1.6.0")
+    def getQuantileProbabilities(self):
+        """
+        Gets the value of quantileProbabilities or its default value.
+        """
+        return self.getOrDefault(self.quantileProbabilities)
+
+    @since("1.6.0")
+    def getQuantilesCol(self):
+        """
+        Gets the value of quantilesCol or its default value.
+        """
+        return self.getOrDefault(self.quantilesCol)
+
+
 @inherit_doc
-class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
-                            HasAggregationDepth, JavaMLWritable, JavaMLReadable):
+class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams,
+                            JavaMLWritable, JavaMLReadable):
     """
     Accelerated Failure Time (AFT) Model Survival Regression
@@ -1529,20 +1576,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
     .. versionadded:: 1.6.0
     """

-    censorCol = Param(Params._dummy(), "censorCol",
-                      "censor column name. The value of this column could be 0 or 1. " +
-                      "If the value is 1, it means the event has occurred i.e. " +
-                      "uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
-    quantileProbabilities = \
-        Param(Params._dummy(), "quantileProbabilities",
-              "quantile probabilities array. Values of the quantile probabilities array " +
-              "should be in the range (0, 1) and the array should be non-empty.",
-              typeConverter=TypeConverters.toListFloat)
-    quantilesCol = Param(Params._dummy(), "quantilesCol",
-                         "quantiles column name. This column will output quantiles of " +
-                         "corresponding quantileProbabilities if it is set.",
-                         typeConverter=TypeConverters.toString)
-
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
@@ -1588,43 +1621,23 @@ def setCensorCol(self, value):
         """
         return self._set(censorCol=value)

-    @since("1.6.0")
-    def getCensorCol(self):
-        """
-        Gets the value of censorCol or its default value.
-        """
-        return self.getOrDefault(self.censorCol)
-
     @since("1.6.0")
     def setQuantileProbabilities(self, value):
         """
         Sets the value of :py:attr:`quantileProbabilities`.
         """
         return self._set(quantileProbabilities=value)

-    @since("1.6.0")
-    def getQuantileProbabilities(self):
-        """
-        Gets the value of quantileProbabilities or its default value.
-        """
-        return self.getOrDefault(self.quantileProbabilities)
-
     @since("1.6.0")
     def setQuantilesCol(self, value):
         """
         Sets the value of :py:attr:`quantilesCol`.
         """
         return self._set(quantilesCol=value)

-    @since("1.6.0")
-    def getQuantilesCol(self):
-        """
-        Gets the value of quantilesCol or its default value.
-        """
-        return self.getOrDefault(self.quantilesCol)
-

-class AFTSurvivalRegressionModel(JavaPredictionModel, JavaMLWritable, JavaMLReadable):
+class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams,
+                                 JavaMLWritable, JavaMLReadable):
     """
     Model fitted by :class:`AFTSurvivalRegression`.