Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-29116][PYTHON][ML] Refactor py classes related to DecisionTree #25929

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 53 additions & 52 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.param.shared import *
from pyspark.ml.regression import DecisionTreeModel, DecisionTreeParams, \
DecisionTreeRegressionModel, GBTParams, HasVarianceImpurity, RandomForestParams, \
TreeEnsembleModel
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
_TreeEnsembleModel, _RandomForestParams, _GBTParams, \
_HasVarianceImpurity, _TreeClassifierParams, _TreeEnsembleParams
from pyspark.ml.regression import DecisionTreeRegressionModel
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
JavaPredictor, JavaPredictorParams, JavaPredictionModel, JavaWrapper
Expand Down Expand Up @@ -939,34 +940,17 @@ class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
pass


@inherit_doc
class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
    """
    Params for :py:class:`DecisionTreeClassifier` and :py:class:`DecisionTreeClassificationModel`.

    All tree/impurity params (``impurity``, ``maxDepth``, ``minWeightFractionPerNode``, ...)
    come from the two mixin bases imported from ``pyspark.ml.tree``; this class only ties
    them together so estimator and model share one param set.
    """
    # NOTE(review): this span of the rendered diff interleaved the removed
    # TreeClassifierParams class (impurity Param + getImpurity) with this new
    # mixin; the removed class's functionality now lives in _TreeClassifierParams.
    pass


@inherit_doc
class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
DecisionTreeParams, TreeClassifierParams, HasCheckpointInterval,
HasSeed, JavaMLWritable, JavaMLReadable):
class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifierParams,
JavaMLWritable, JavaMLReadable):
"""
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
learning algorithm for classification.
Expand Down Expand Up @@ -1045,20 +1029,20 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
seed=None, weightCol=None, leafCol=""):
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
seed=None, weightCol=None, leafCol="")
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
"""
super(DecisionTreeClassifier, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", leafCol="")
impurity="gini", leafCol="", minWeightFractionPerNode=0.0)
kwargs = self._input_kwargs
self.setParams(**kwargs)

Expand All @@ -1068,13 +1052,14 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", seed=None, weightCol=None, leafCol=""):
impurity="gini", seed=None, weightCol=None, leafCol="",
minWeightFractionPerNode=0.0):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
seed=None, weightCol=None, leafCol="")
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
Sets params for the DecisionTreeClassifier.
"""
kwargs = self._input_kwargs
Expand All @@ -1101,6 +1086,13 @@ def setMinInstancesPerNode(self, value):
"""
return self._set(minInstancesPerNode=value)

@since("3.0.0")
def setMinWeightFractionPerNode(self, value):
"""
Sets the value of :py:attr:`minWeightFractionPerNode`.
"""
return self._set(minWeightFractionPerNode=value)

def setMinInfoGain(self, value):
"""
Sets the value of :py:attr:`minInfoGain`.
Expand Down Expand Up @@ -1128,8 +1120,9 @@ def setImpurity(self, value):


@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaProbabilisticClassificationModel,
JavaMLWritable, JavaMLReadable):
class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClassificationModel,
_DecisionTreeClassifierParams, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by DecisionTreeClassifier.

Expand Down Expand Up @@ -1159,8 +1152,15 @@ def featureImportances(self):


@inherit_doc
class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestParams,
TreeClassifierParams, HasCheckpointInterval,
class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
    """
    Params for :py:class:`RandomForestClassifier` and
    :py:class:`RandomForestClassificationModel`.

    Purely a mixin combining the forest params with the classifier impurity
    params; it declares no params of its own.
    """
    pass


@inherit_doc
class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifierParams,
JavaMLWritable, JavaMLReadable):
"""
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
Expand Down Expand Up @@ -1230,22 +1230,22 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
leafCol=""):
leafCol="", minWeightFractionPerNode=0.0):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
leafCol="")
leafCol="", minWeightFractionPerNode=0.0)
"""
super(RandomForestClassifier, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", numTrees=20, featureSubsetStrategy="auto",
subsamplingRate=1.0, leafCol="")
subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0)
kwargs = self._input_kwargs
self.setParams(**kwargs)

Expand All @@ -1256,14 +1256,14 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
leafCol=""):
leafCol="", minWeightFractionPerNode=0.0):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
leafCol="")
leafCol="", minWeightFractionPerNode=0.0)
Sets params for linear classification.
"""
kwargs = self._input_kwargs
Expand Down Expand Up @@ -1337,8 +1337,9 @@ def setFeatureSubsetStrategy(self, value):
return self._set(featureSubsetStrategy=value)


class RandomForestClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel,
JavaMLWritable, JavaMLReadable):
class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
_RandomForestClassifierParams, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by RandomForestClassifier.

Expand Down Expand Up @@ -1367,7 +1368,7 @@ def trees(self):
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]


class GBTClassifierParams(GBTParams, HasVarianceImpurity):
class GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
"""
Private class to track supported GBTClassifier params.

Expand All @@ -1390,8 +1391,8 @@ def getLossType(self):


@inherit_doc
class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpointInterval,
HasSeed, JavaMLWritable, JavaMLReadable):
class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams,
JavaMLWritable, JavaMLReadable):
"""
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
learning algorithm for classification.
Expand Down Expand Up @@ -1485,14 +1486,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None,
leafCol=""):
leafCol="", minWeightFractionPerNode=0.0):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
validationIndicatorCol=None, leafCol="")
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0)
"""
super(GBTClassifier, self).__init__()
self._java_obj = self._new_java_obj(
Expand All @@ -1501,7 +1502,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
leafCol="")
leafCol="", minWeightFractionPerNode=0.0)
kwargs = self._input_kwargs
self.setParams(**kwargs)

Expand All @@ -1512,14 +1513,14 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
validationIndicatorCol=None, leafCol=""):
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
validationIndicatorCol=None, leafCol="")
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0)
Sets params for Gradient Boosted Tree Classification.
"""
kwargs = self._input_kwargs
Expand Down Expand Up @@ -1600,8 +1601,8 @@ def setValidationIndicatorCol(self, value):
return self._set(validationIndicatorCol=value)


class GBTClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel,
JavaMLWritable, JavaMLReadable):
class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
GBTClassifierParams, JavaMLWritable, JavaMLReadable):
"""
Model fitted by GBTClassifier.

Expand Down
Loading