From 629d402569f33cb8ccf703ee0b2779bc975193b8 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 9 May 2015 10:28:10 -0700 Subject: [PATCH] fix LRSuite --- .../classification/LogisticRegression.scala | 2 +- .../LogisticRegressionSuite.scala | 35 +++++++++---------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index edd2252eda218..e607c24a7c61a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -216,7 +216,7 @@ class LogisticRegression(override val uid: String) (weightsWithIntercept, 0.0) } - new LogisticRegressionModel(this, weights.compressed, intercept) + new LogisticRegressionModel(uid, weights.compressed, intercept) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index b03dd0991f021..9119745eb6f60 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.ml.classification import org.scalatest.FunSuite -import org.apache.spark.mllib.classification.LogisticRegressionSuite +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} - class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ @@ -37,8 +36,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { super.beforeAll() sqlContext = new SQLContext(sc) - dataset = sqlContext.createDataFrame(sc.parallelize(LogisticRegressionSuite - .generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 4)) + dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) /** * Here is the instruction describing how to export the test data into CSV format @@ -60,31 +58,30 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - val testData = LogisticRegressionSuite.generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42) + val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) - sqlContext.createDataFrame(sc.parallelize(LogisticRegressionSuite - .generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42), 4)) + sqlContext.createDataFrame( + generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)) } } test("logistic regression: default params") { val lr = new LogisticRegression - assert(lr.getLabelCol == "label") - assert(lr.getFeaturesCol == "features") - assert(lr.getPredictionCol == "prediction") - assert(lr.getRawPredictionCol == "rawPrediction") - assert(lr.getProbabilityCol == "probability") - assert(lr.getFitIntercept == true) + assert(lr.getLabelCol === "label") + assert(lr.getFeaturesCol === "features") + assert(lr.getPredictionCol === "prediction") + assert(lr.getRawPredictionCol === "rawPrediction") + assert(lr.getProbabilityCol === "probability") + assert(lr.getFitIntercept) val model = lr.fit(dataset) model.transform(dataset) .select("label", "probability", "prediction", "rawPrediction") .collect() assert(model.getThreshold === 0.5) - assert(model.getFeaturesCol == "features") - assert(model.getPredictionCol == "prediction") - assert(model.getRawPredictionCol == "rawPrediction") - assert(model.getProbabilityCol == "probability") + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.getRawPredictionCol === "rawPrediction") + assert(model.getProbabilityCol === "probability") assert(model.intercept !== 0.0) } @@ -134,7 +131,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(parent2.getRegParam === 0.1) assert(parent2.getThreshold === 0.4) assert(model2.getThreshold === 0.4) - assert(model2.getProbabilityCol == "theProb") + assert(model2.getProbabilityCol === "theProb") } test("logistic regression: Predictor, Classifier methods") {