Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[SPARK-28985][PYTHON][ML] Add common classes (JavaPredictor/JavaClassificationModel/JavaProbabilisticClassifier) in PYTHON #25776

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 141 additions & 35 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
DecisionTreeRegressionModel, GBTParams, HasVarianceImpurity, RandomForestParams, \
TreeEnsembleModel
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
JavaPredictor, JavaPredictorParams, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc, _java2py, _py2java
from pyspark.ml.linalg import Vectors
from pyspark.sql import DataFrame
Expand All @@ -47,14 +47,43 @@
'OneVsRest', 'OneVsRestModel']


class JavaClassifierParams(HasRawPredictionCol, JavaPredictorParams):
    """
    (Private) Parameter mixin shared by Java-backed classifiers.

    Adds the ``rawPredictionCol`` parameter on top of the common
    predictor parameters.
    """
    pass


@inherit_doc
class JavaClassifier(JavaPredictor, JavaClassifierParams):
    """
    Base class for Java-backed classifiers.
    Class labels are indexed {0, 1, ..., numClasses - 1}.
    """

    @since("3.0.0")
    def setRawPredictionCol(self, value):
        """
        Sets the value of :py:attr:`rawPredictionCol`.

        Returns this estimator to allow chained parameter setting.
        """
        return self._set(rawPredictionCol=value)


@inherit_doc
class JavaClassificationModel(JavaPredictionModel, JavaClassifierParams):
    """
    Java Model produced by a ``Classifier``.
    Classes are indexed {0, 1, ..., numClasses - 1}.
    To be mixed in with :class:`pyspark.ml.JavaModel`
    """

    @since("3.0.0")
    def setRawPredictionCol(self, value):
        """
        Sets the value of :py:attr:`rawPredictionCol`.
        """
        return self._set(rawPredictionCol=value)

    @property
    @since("2.1.0")
    def numClasses(self):
        """
        Number of classes (values which the label can take).

        Delegates to the underlying Java model.
        """
        return self._call_java("numClasses")


class JavaProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, JavaClassifierParams):
    """
    (Private) Parameter mixin for Java-backed probabilistic classifiers.

    Adds the ``probabilityCol`` and ``thresholds`` parameters on top of
    the common classifier parameters.
    """
    pass


@inherit_doc
class JavaProbabilisticClassifier(JavaClassifier, JavaProbabilisticClassifierParams):
    """
    Base class for Java-backed probabilistic classifiers.
    """

    @since("3.0.0")
    def setProbabilityCol(self, value):
        """
        Sets the value of :py:attr:`probabilityCol`.

        Returns this estimator to allow chained parameter setting.
        """
        return self._set(probabilityCol=value)

    @since("3.0.0")
    def setThresholds(self, value):
        """
        Sets the value of :py:attr:`thresholds`.

        Returns this estimator to allow chained parameter setting.
        """
        return self._set(thresholds=value)


@inherit_doc
class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasRawPredictionCol, HasFitIntercept, HasStandardization,
HasWeightCol, HasAggregationDepth, HasThreshold, JavaMLWritable, JavaMLReadable):
class JavaProbabilisticClassificationModel(JavaClassificationModel,
                                           JavaProbabilisticClassifierParams):
    """
    Java Model produced by a ``ProbabilisticClassifier``.
    """

    @since("3.0.0")
    def setProbabilityCol(self, value):
        """
        Sets the value of :py:attr:`probabilityCol`.

        Returns this model to allow chained parameter setting.
        """
        return self._set(probabilityCol=value)

    @since("3.0.0")
    def setThresholds(self, value):
        """
        Sets the value of :py:attr:`thresholds`.

        Returns this model to allow chained parameter setting.
        """
        return self._set(thresholds=value)


@inherit_doc
class LinearSVC(JavaClassifier, HasMaxIter, HasRegParam, HasTol,
HasFitIntercept, HasStandardization, HasWeightCol, HasAggregationDepth,
HasThreshold, JavaMLWritable, JavaMLReadable):
"""
`Linear SVM Classifier <https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM>`_

Expand All @@ -81,6 +160,8 @@ class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, Ha
... Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
>>> svm = LinearSVC(maxIter=5, regParam=0.01)
>>> model = svm.fit(df)
>>> model.setPredictionCol("prediction")
Copy link
Contributor

@zhengruifeng zhengruifeng Sep 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about changing the value to a non-default value like "newPrediction", and making sure that the ouput dataframe/row has changed column name?

LinearSVC...
>>> model.coefficients
DenseVector([0.0, -0.2792, -0.1833])
>>> model.intercept
Expand All @@ -90,6 +171,8 @@ class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, Ha
>>> model.numFeatures
3
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()
>>> model.predict(test0.head().features)
1.0
>>> result = model.transform(test0).head()
>>> result.prediction
1.0
Expand Down Expand Up @@ -156,7 +239,7 @@ def _create_model(self, java_model):
return LinearSVCModel(java_model)


class LinearSVCModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
class LinearSVCModel(JavaClassificationModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LinearSVC.

Expand All @@ -181,8 +264,7 @@ def intercept(self):


@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
class LogisticRegression(JavaProbabilisticClassifier, HasMaxIter, HasRegParam, HasTol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
"""
Expand All @@ -198,6 +280,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()
>>> blor = LogisticRegression(regParam=0.01, weightCol="weight")
>>> blorModel = blor.fit(bdf)
>>> blorModel.setFeaturesCol("features")
LogisticRegressionModel...
>>> blorModel.coefficients
DenseVector([-1.080..., -0.646...])
>>> blorModel.intercept
Expand All @@ -211,6 +295,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> mlorModel.interceptVector
DenseVector([0.04..., -0.42..., 0.37...])
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()
>>> blorModel.predict(test0.head().features)
1.0
>>> result = blorModel.transform(test0).head()
>>> result.prediction
1.0
Expand Down Expand Up @@ -481,7 +567,7 @@ def getUpperBoundsOnIntercepts(self):
return self.getOrDefault(self.upperBoundsOnIntercepts)


class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable,
class LogisticRegressionModel(JavaProbabilisticClassificationModel, JavaMLWritable, JavaMLReadable,
HasTrainingSummary):
"""
Model fitted by LogisticRegression.
Expand Down Expand Up @@ -872,8 +958,7 @@ def getImpurity(self):


@inherit_doc
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasWeightCol,
HasPredictionCol, HasProbabilityCol, HasRawPredictionCol,
class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
DecisionTreeParams, TreeClassifierParams, HasCheckpointInterval,
HasSeed, JavaMLWritable, JavaMLReadable):
"""
Expand All @@ -892,6 +977,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasWeig
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed", leafCol="leafId")
>>> model = dt.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
DecisionTreeClassificationModel...
>>> model.numNodes
3
>>> model.depth
Expand All @@ -905,6 +994,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasWeig
>>> print(model.toDebugString)
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
Expand Down Expand Up @@ -1031,8 +1122,8 @@ def setImpurity(self, value):


@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
class DecisionTreeClassificationModel(DecisionTreeModel, JavaProbabilisticClassificationModel,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by DecisionTreeClassifier.

Expand Down Expand Up @@ -1062,9 +1153,8 @@ def featureImportances(self):


@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
HasRawPredictionCol, HasProbabilityCol,
RandomForestParams, TreeClassifierParams, HasCheckpointInterval,
class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestParams,
TreeClassifierParams, HasCheckpointInterval,
JavaMLWritable, JavaMLReadable):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about adding some simple tests for it?

"""
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
Expand All @@ -1085,11 +1175,17 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42,
... leafCol="leafId")
>>> model = rf.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
RandomForestClassificationModel...
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
Expand Down Expand Up @@ -1231,8 +1327,8 @@ def setFeatureSubsetStrategy(self, value):
return self._set(featureSubsetStrategy=value)


class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
class RandomForestClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by RandomForestClassifier.

Expand Down Expand Up @@ -1284,9 +1380,8 @@ def getLossType(self):


@inherit_doc
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
GBTClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
JavaMLReadable):
class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpointInterval,
HasSeed, JavaMLWritable, JavaMLReadable):
"""
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
learning algorithm for classification.
Expand Down Expand Up @@ -1318,11 +1413,17 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> gbt.getFeatureSubsetStrategy()
'all'
>>> model = gbt.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
GBTClassificationModel...
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
Expand Down Expand Up @@ -1485,8 +1586,8 @@ def setValidationIndicatorCol(self, value):
return self._set(validationIndicatorCol=value)


class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
class GBTClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by GBTClassifier.

Expand Down Expand Up @@ -1527,8 +1628,8 @@ def evaluateEachIteration(self, dataset):


@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
HasRawPredictionCol, HasThresholds, HasWeightCol, JavaMLWritable, JavaMLReadable):
class NaiveBayes(JavaProbabilisticClassifier, HasThresholds, HasWeightCol,
JavaMLWritable, JavaMLReadable):
"""
Naive Bayes Classifiers.
It supports both Multinomial and Bernoulli NB. `Multinomial NB
Expand All @@ -1547,11 +1648,15 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
... Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))])
>>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
>>> model = nb.fit(df)
>>> model.setFeaturesCol("features")
NaiveBayes_...
>>> model.pi
DenseVector([-0.81..., -0.58...])
>>> model.theta
DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)
>>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
>>> model.predict(test0.head().features)
1.0
>>> result = model.transform(test0).head()
>>> result.prediction
1.0
Expand Down Expand Up @@ -1651,7 +1756,7 @@ def getModelType(self):
return self.getOrDefault(self.modelType)


class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
class NaiveBayesModel(JavaProbabilisticClassificationModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by NaiveBayes.

Expand All @@ -1676,10 +1781,8 @@ def theta(self):


@inherit_doc
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver,
JavaMLWritable, JavaMLReadable, HasProbabilityCol,
HasRawPredictionCol):
class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, HasMaxIter, HasTol, HasSeed,
HasStepSize, HasSolver, JavaMLWritable, JavaMLReadable):
"""
Classifier trainer based on the Multilayer Perceptron.
Each layer has sigmoid activation function, output layer has softmax.
Expand All @@ -1694,13 +1797,17 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
>>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 2, 2], blockSize=1, seed=123)
>>> model = mlp.fit(df)
>>> model.setFeaturesCol("features")
MultilayerPerceptronClassifier...
>>> model.layers
[2, 2, 2]
>>> model.weights.size
12
>>> testDF = spark.createDataFrame([
... (Vectors.dense([1.0, 0.0]),),
... (Vectors.dense([0.0, 0.0]),)], ["features"])
>>> model.predict(testDF.head().features)
1.0
>>> model.transform(testDF).select("features", "prediction").show()
+---------+----------+
| features|prediction|
Expand Down Expand Up @@ -1839,7 +1946,7 @@ def getInitialWeights(self):
return self.getOrDefault(self.initialWeights)


class MultilayerPerceptronClassificationModel(JavaModel, JavaClassificationModel, JavaMLWritable,
class MultilayerPerceptronClassificationModel(JavaProbabilisticClassificationModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by MultilayerPerceptronClassifier.
Expand All @@ -1864,8 +1971,7 @@ def weights(self):
return self._call_java("weights")


class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol,
HasRawPredictionCol):
class OneVsRestParams(JavaClassifierParams, HasWeightCol):
"""
Parameters for OneVsRest and OneVsRestModel.
"""
Expand Down
Loading