Skip to content

Commit

Permalink
[SPARK-14712][ML] LogisticRegressionModel.toString should summarize m…
Browse files Browse the repository at this point in the history
…odel

## What changes were proposed in this pull request?

[SPARK-14712](https://issues.apache.org/jira/browse/SPARK-14712)
spark.mllib LogisticRegressionModel overrides toString to print a little model info. We should do the same in spark.ml and override repr in pyspark.

## How was this patch tested?

LogisticRegressionSuite.scala
Python doctest in pyspark.ml.classification.py

Author: bravo-zhang <[email protected]>

Closes #18826 from bravo-zhang/spark-14712.
  • Loading branch information
jiayue-zhang authored and holdenk committed Jun 28, 2018
1 parent 5b05966 commit 524827f
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,11 @@ class LogisticRegressionModel private[spark] (
*/
@Since("1.6.0")
override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)

override def toString: String = {
s"LogisticRegressionModel: " +
s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures"
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2751,6 +2751,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(model.getFamily === family)
}
}

test("toString") {
val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0)
val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3"
assert(model.toString === expected)
}
}

object LogisticRegressionSuite {
Expand Down
5 changes: 5 additions & 0 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
True
>>> blorModel.intercept == model2.intercept
True
>>> model2
LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2
.. versionadded:: 1.3.0
"""
Expand Down Expand Up @@ -562,6 +564,9 @@ def evaluate(self, dataset):
java_blr_summary = self._call_java("evaluate", dataset)
return BinaryLogisticRegressionSummary(java_blr_summary)

def __repr__(self):
return self._call_java("toString")


class LogisticRegressionSummary(JavaWrapper):
"""
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/mllib/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ def load(cls, sc, path):
model.setThreshold(threshold)
return model

def __repr__(self):
return self._call_java("toString")


class LogisticRegressionWithSGD(object):
"""
Expand Down

0 comments on commit 524827f

Please sign in to comment.