[SPARK-29116][PYTHON][ML] Refactor py classes related to DecisionTree #25929

Closed
wants to merge 6 commits into from

huaxingao
Contributor

What changes were proposed in this pull request?

  • Move tree-related classes to a separate file, tree.py
  • Add method predictLeaf to DecisionTreeModel and TreeEnsembleModel

Why are the changes needed?

  • Keep parity between Scala and Python
  • Make the code easier to maintain

Does this PR introduce any user-facing change?

Yes

  • Adds method predictLeaf to DecisionTreeModel and TreeEnsembleModel
  • Adds setMinWeightFractionPerNode to DecisionTreeClassifier and DecisionTreeRegressor

How was this patch tested?

Added some doctests.

@SparkQA

SparkQA commented Sep 25, 2019

Test build #111364 has finished for PR 25929 at commit b1c418a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
      • class DecisionTreeClassifierParams(DecisionTreeParams, TreeClassifierParams):

@SparkQA

SparkQA commented Sep 26, 2019

Test build #111440 has finished for PR 25929 at commit b722aae.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
      • class _DecisionTreeClassifierParams(DecisionTreeParams, TreeClassifierParams):
      • class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifierParams,
      • class _RandomForestClassifierParams(RandomForestParams, TreeClassifierParams):
      • class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifierParams,
      • class _DecisionTreeRegressorParams(DecisionTreeParams, TreeRegressorParams, HasVarianceCol):
      • class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLWritable,
      • class DecisionTreeRegressionModel(DecisionTreeModel, _DecisionTreeRegressorParams,
      • class _RandomForestRegressorParams(RandomForestParams, TreeRegressorParams):
      • class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLWritable,
      • class RandomForestRegressionModel(TreeEnsembleModel, _RandomForestRegressorParams,

@SparkQA

SparkQA commented Oct 3, 2019

Test build #111756 has finished for PR 25929 at commit 61f318b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Oct 3, 2019

Test build #111759 has finished for PR 25929 at commit 7034a5e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng
Contributor

retest this please

@SparkQA

SparkQA commented Oct 8, 2019

Test build #111863 has finished for PR 25929 at commit 7034a5e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -1252,9 +999,29 @@ def featureImportances(self):
return self._call_java("featureImportances")


class GBTRegressorParams(GBTParams, TreeRegressorParams):

What about renaming it to _GBT..., like _RandomForestClassifierParams above?

return self._call_java("toString")


class DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):

ditto



@inherit_doc
class DecisionTreeModel(JavaPredictionModel):

ditto



@inherit_doc
class TreeEnsembleModel(JavaPredictionModel):

ditto

return self._call_java("toString")


class TreeEnsembleParams(DecisionTreeParams):

ditto

return self.getOrDefault(self.featureSubsetStrategy)


class RandomForestParams(TreeEnsembleParams):

ditto

return self.getOrDefault(self.impurity)


class TreeClassifierParams(object):

ditto

return self.getOrDefault(self.impurity)


class TreeRegressorParams(HasVarianceImpurity):

ditto

@huaxingao
Contributor Author

@zhengruifeng
Thanks for your comments.
I didn't add a _single_leading_underscore to classes that are used in other packages.
I am a little fuzzy about the _single_leading_underscore convention.
https://pep8.org/#descriptive-naming-styles says:
_single_leading_underscore: weak “internal use” indicator. E.g. from M import * does not import objects whose name starts with an underscore.
This makes me think that a class with a _single_leading_underscore is for internal use only and is not intended to be used in other packages. However, if I explicitly import a _single_leading_underscore class, it works OK.
For example, if I do
from pyspark.ml.tree import *, the _single_leading_underscore classes are not imported.
If I do
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, these classes are imported OK.
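
The import behaviour described here can be reproduced without Spark. The following is a minimal sketch, not the PR's code: it builds a throwaway in-memory module named demo_tree (a hypothetical stand-in for pyspark.ml.tree) containing one underscore-prefixed class and one public class, TreeHelper (also hypothetical), and assumes the module defines no __all__.

```python
import sys
import types

# Build a throwaway module in memory. "demo_tree" and "TreeHelper" are
# hypothetical names used only for this illustration.
demo_tree = types.ModuleType("demo_tree")
exec(
    "class _DecisionTreeParams(object):\n"
    "    pass\n"
    "\n"
    "class TreeHelper(object):\n"
    "    pass\n",
    demo_tree.__dict__,
)
sys.modules["demo_tree"] = demo_tree

# Star import: with no __all__ defined, names that start with an
# underscore are skipped.
star_ns = {}
exec("from demo_tree import *", star_ns)
star_names = sorted(n for n in star_ns if n != "__builtins__")

# Explicit import: the underscore-prefixed class is imported normally.
explicit_ns = {}
exec("from demo_tree import _DecisionTreeParams", explicit_ns)
explicit_ok = "_DecisionTreeParams" in explicit_ns

print(star_names)   # only the public class
print(explicit_ok)  # True
```

The underscore is a naming convention that only affects star imports; it does not block access, which is why the explicit import succeeds.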

@zhengruifeng
Contributor

@huaxingao Yes, I can reproduce your case.
The 'private' classes can only be imported explicitly. I guess that is why it is a weak “internal use” indicator.
I think we can add _single_leading_underscore to match the Scala side.

@huaxingao
Contributor Author

OK. I will add _single_leading_underscore to the classes you mentioned in the comments. Thanks!

@zhengruifeng
Contributor

retest this please

@SparkQA

SparkQA commented Oct 12, 2019

Test build #111947 has finished for PR 25929 at commit 302e98e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
      • class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
      • class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClassificationModel,
      • class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
      • class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
      • class GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
      • class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
      • class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, HasVarianceCol):
      • class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorParams,
      • class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):
      • class RandomForestRegressionModel(_TreeEnsembleModel, _RandomForestRegressorParams,
      • class _GBTRegressorParams(_GBTParams, _TreeRegressorParams):
      • class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
      • class GBTRegressionModel(_TreeEnsembleModel, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
      • class _DecisionTreeModel(JavaPredictionModel):
      • class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
      • class _TreeEnsembleModel(JavaPredictionModel):
      • class _TreeEnsembleParams(_DecisionTreeParams):
      • class _RandomForestParams(_TreeEnsembleParams):
      • class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
      • class _HasVarianceImpurity(Params):
      • class _TreeClassifierParams(object):
      • class _TreeRegressorParams(_HasVarianceImpurity):

@zhengruifeng
Contributor

Merged to master, thanks @huaxingao

@huaxingao
Contributor Author

Thanks a lot for your help! @zhengruifeng
