-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-29377][PYTHON][ML] Parity between Scala ML tuning and Python ML tuning #26057
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -122,7 +122,7 @@ def to_key_value_pairs(keys, values): | |
return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)] | ||
|
||
|
||
class ValidatorParams(HasSeed): | ||
class _ValidatorParams(HasSeed): | ||
""" | ||
Common params for TrainValidationSplit and CrossValidator. | ||
""" | ||
|
@@ -133,36 +133,21 @@ class ValidatorParams(HasSeed): | |
Params._dummy(), "evaluator", | ||
"evaluator used to select hyper-parameters that maximize the validator metric") | ||
|
||
def setEstimator(self, value): | ||
""" | ||
Sets the value of :py:attr:`estimator`. | ||
""" | ||
return self._set(estimator=value) | ||
|
||
@since("2.0.0") | ||
def getEstimator(self): | ||
""" | ||
Gets the value of estimator or its default value. | ||
""" | ||
return self.getOrDefault(self.estimator) | ||
|
||
def setEstimatorParamMaps(self, value): | ||
""" | ||
Sets the value of :py:attr:`estimatorParamMaps`. | ||
""" | ||
return self._set(estimatorParamMaps=value) | ||
|
||
@since("2.0.0") | ||
def getEstimatorParamMaps(self): | ||
""" | ||
Gets the value of estimatorParamMaps or its default value. | ||
""" | ||
return self.getOrDefault(self.estimatorParamMaps) | ||
|
||
def setEvaluator(self, value): | ||
""" | ||
Sets the value of :py:attr:`evaluator`. | ||
""" | ||
return self._set(evaluator=value) | ||
|
||
@since("2.0.0") | ||
def getEvaluator(self): | ||
""" | ||
Gets the value of evaluator or its default value. | ||
|
@@ -199,7 +184,25 @@ def _to_java_impl(self): | |
return java_estimator, java_epms, java_evaluator | ||
|
||
|
||
class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, | ||
class _CrossValidatorParams(_ValidatorParams): | ||
""" | ||
Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`. | ||
|
||
.. versionadded:: 3.0.0 | ||
""" | ||
|
||
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation", | ||
typeConverter=TypeConverters.toInt) | ||
|
||
@since("1.4.0") | ||
def getNumFolds(self): | ||
""" | ||
Gets the value of numFolds or its default value. | ||
""" | ||
return self.getOrDefault(self.numFolds) | ||
|
||
|
||
class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollectSubModels, | ||
MLReadable, MLWritable): | ||
""" | ||
|
||
|
@@ -226,6 +229,8 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubMo | |
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, | ||
... parallelism=2) | ||
>>> cvModel = cv.fit(dataset) | ||
>>> cvModel.getNumFolds() | ||
3 | ||
>>> cvModel.avgMetrics[0] | ||
0.5 | ||
>>> evaluator.evaluate(cvModel.transform(dataset)) | ||
|
@@ -234,9 +239,6 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubMo | |
.. versionadded:: 1.4.0 | ||
""" | ||
|
||
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation", | ||
typeConverter=TypeConverters.toInt) | ||
|
||
@keyword_only | ||
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, | ||
seed=None, parallelism=1, collectSubModels=False): | ||
|
@@ -261,19 +263,33 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, num | |
kwargs = self._input_kwargs | ||
return self._set(**kwargs) | ||
|
||
@since("1.4.0") | ||
def setNumFolds(self, value): | ||
@since("2.0.0") | ||
def setEstimator(self, value): | ||
""" | ||
Sets the value of :py:attr:`numFolds`. | ||
Sets the value of :py:attr:`estimator`. | ||
""" | ||
return self._set(numFolds=value) | ||
return self._set(estimator=value) | ||
|
||
@since("2.0.0") | ||
def setEstimatorParamMaps(self, value): | ||
""" | ||
Sets the value of :py:attr:`estimatorParamMaps`. | ||
""" | ||
return self._set(estimatorParamMaps=value) | ||
|
||
@since("2.0.0") | ||
def setEvaluator(self, value): | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to add `@since` for the setters? Some, like `setNumFolds`, have it, but others do not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked the history, the
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
||
""" | ||
Sets the value of :py:attr:`evaluator`. | ||
""" | ||
return self._set(evaluator=value) | ||
|
||
@since("1.4.0") | ||
def getNumFolds(self): | ||
def setNumFolds(self, value): | ||
""" | ||
Gets the value of numFolds or its default value. | ||
Sets the value of :py:attr:`numFolds`. | ||
""" | ||
return self.getOrDefault(self.numFolds) | ||
return self._set(numFolds=value) | ||
|
||
def _fit(self, dataset): | ||
est = self.getOrDefault(self.estimator) | ||
|
@@ -387,7 +403,7 @@ def _to_java(self): | |
return _java_obj | ||
|
||
|
||
class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): | ||
class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable): | ||
""" | ||
|
||
CrossValidatorModel contains the model with the highest average cross-validation | ||
|
@@ -407,6 +423,27 @@ def __init__(self, bestModel, avgMetrics=[], subModels=None): | |
#: sub model list from cross validation | ||
self.subModels = subModels | ||
|
||
@since("2.0.0") | ||
def setEstimator(self, value): | ||
""" | ||
Sets the value of :py:attr:`estimator`. | ||
""" | ||
return self._set(estimator=value) | ||
|
||
@since("2.0.0") | ||
def setEstimatorParamMaps(self, value): | ||
""" | ||
Sets the value of :py:attr:`estimatorParamMaps`. | ||
""" | ||
return self._set(estimatorParamMaps=value) | ||
|
||
@since("2.0.0") | ||
def setEvaluator(self, value): | ||
""" | ||
Sets the value of :py:attr:`evaluator`. | ||
""" | ||
return self._set(evaluator=value) | ||
|
||
def _transform(self, dataset): | ||
return self.bestModel.transform(dataset) | ||
|
||
|
@@ -486,8 +523,26 @@ def _to_java(self): | |
return _java_obj | ||
|
||
|
||
class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, | ||
MLReadable, MLWritable): | ||
class _TrainValidationSplitParams(_ValidatorParams): | ||
""" | ||
Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`. | ||
|
||
.. versionadded:: 3.0.0 | ||
""" | ||
|
||
trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\ | ||
validation data. Must be between 0 and 1.", typeConverter=TypeConverters.toFloat) | ||
|
||
@since("2.0.0") | ||
def getTrainRatio(self): | ||
""" | ||
Gets the value of trainRatio or its default value. | ||
""" | ||
return self.getOrDefault(self.trainRatio) | ||
|
||
|
||
class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelism, | ||
HasCollectSubModels, MLReadable, MLWritable): | ||
""" | ||
Validation for hyper-parameter tuning. Randomly splits the input dataset into train and | ||
validation sets, and uses evaluation metric on the validation set to select the best model. | ||
|
@@ -509,15 +564,14 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollec | |
>>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, | ||
... parallelism=1, seed=42) | ||
>>> tvsModel = tvs.fit(dataset) | ||
>>> tvsModel.getTrainRatio() | ||
0.75 | ||
>>> evaluator.evaluate(tvsModel.transform(dataset)) | ||
0.833... | ||
|
||
.. versionadded:: 2.0.0 | ||
""" | ||
|
||
trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\ | ||
validation data. Must be between 0 and 1.", typeConverter=TypeConverters.toFloat) | ||
|
||
@keyword_only | ||
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, | ||
parallelism=1, collectSubModels=False, seed=None): | ||
|
@@ -543,18 +597,32 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, tra | |
return self._set(**kwargs) | ||
|
||
@since("2.0.0") | ||
def setTrainRatio(self, value): | ||
def setEstimator(self, value): | ||
""" | ||
Sets the value of :py:attr:`trainRatio`. | ||
Sets the value of :py:attr:`estimator`. | ||
""" | ||
return self._set(trainRatio=value) | ||
return self._set(estimator=value) | ||
|
||
@since("2.0.0") | ||
def getTrainRatio(self): | ||
def setEstimatorParamMaps(self, value): | ||
""" | ||
Gets the value of trainRatio or its default value. | ||
Sets the value of :py:attr:`estimatorParamMaps`. | ||
""" | ||
return self.getOrDefault(self.trainRatio) | ||
return self._set(estimatorParamMaps=value) | ||
|
||
@since("2.0.0") | ||
def setEvaluator(self, value): | ||
""" | ||
Sets the value of :py:attr:`evaluator`. | ||
""" | ||
return self._set(evaluator=value) | ||
|
||
@since("2.0.0") | ||
def setTrainRatio(self, value): | ||
""" | ||
Sets the value of :py:attr:`trainRatio`. | ||
""" | ||
return self._set(trainRatio=value) | ||
|
||
def _fit(self, dataset): | ||
est = self.getOrDefault(self.estimator) | ||
|
@@ -662,7 +730,7 @@ def _to_java(self): | |
return _java_obj | ||
|
||
|
||
class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): | ||
class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable, MLWritable): | ||
""" | ||
Model from train validation split. | ||
|
||
|
@@ -678,6 +746,27 @@ def __init__(self, bestModel, validationMetrics=[], subModels=None): | |
#: sub models from train validation split | ||
self.subModels = subModels | ||
|
||
@since("2.0.0") | ||
def setEstimator(self, value): | ||
""" | ||
Sets the value of :py:attr:`estimator`. | ||
""" | ||
return self._set(estimator=value) | ||
|
||
@since("2.0.0") | ||
def setEstimatorParamMaps(self, value): | ||
""" | ||
Sets the value of :py:attr:`estimatorParamMaps`. | ||
""" | ||
return self._set(estimatorParamMaps=value) | ||
|
||
@since("2.0.0") | ||
def setEvaluator(self, value): | ||
""" | ||
Sets the value of :py:attr:`evaluator`. | ||
""" | ||
return self._set(evaluator=value) | ||
|
||
def _transform(self, dataset): | ||
return self.bestModel.transform(dataset) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will these changes affect users who try to extend these classes much or at all?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think adding a leading underscore will affect users who extend these classes. A single leading underscore before a class name is only a weak indicator of internal use; it doesn't enforce privacy.