From 8316d5ed8a9e810edbc5a202e5a7e8337cee9934 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Thu, 5 Feb 2015 13:29:56 -0800
Subject: [PATCH] fixes after rebasing on master

---
 .../examples/ml/DeveloperApiExample.scala          |  6 +-
 .../examples/ml/SimpleParamsExample.scala          |  2 +-
 .../spark/ml/classification/Classifier.scala       |  4 +-
 .../classification/LogisticRegression.scala        | 64 -------------------
 4 files changed, 6 insertions(+), 70 deletions(-)

diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 6f68020bf9ee2..7a19bcd3c8d77 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -40,10 +40,10 @@ object DeveloperApiExample {
     val conf = new SparkConf().setAppName("DeveloperApiExample")
     val sc = new SparkContext(conf)
     val sqlContext = new SQLContext(sc)
-    import sqlContext._
+    import sqlContext.implicits._

     // Prepare training data.
-    val training = sparkContext.parallelize(Seq(
+    val training = sc.parallelize(Seq(
       LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
       LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
       LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
@@ -61,7 +61,7 @@ object DeveloperApiExample {
     val model = lr.fit(training)

     // Prepare test data.
-    val test = sparkContext.parallelize(Seq(
+    val test = sc.parallelize(Seq(
       LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
       LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
       LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
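Note: the DeveloperApiExample hunks above replace the old blanket import (import sqlContext._) with import sqlContext.implicits._ and fix the stale sparkContext reference to use the sc val created a few lines earlier. A minimal standalone sketch of the resulting pattern, written against the released Spark 1.3.x API, which may differ slightly from this development snapshot; the Record case class and app name are illustrative only:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    object ImplicitsSketch {
      // Hypothetical record type used only for this sketch.
      case class Record(label: Double, value: Double)

      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("ImplicitsSketch"))
        val sqlContext = new SQLContext(sc)
        // The conversions live on the SQLContext instance, so the import must
        // target that instance rather than a package object.
        import sqlContext.implicits._

        // toDF() on an RDD of case classes and the $"..." column syntax both
        // come from sqlContext.implicits.
        val df = sc.parallelize(Seq(Record(1.0, 0.5), Record(0.0, -0.3))).toDF()
        df.select($"label").show()
        sc.stop()
      }
    }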
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 79ce9fdf7294c..80c9f5ff5781e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -81,7 +81,7 @@ object SimpleParamsExample {
     println("Model 2 was fit using parameters: " + model2.fittingParamMap)

     // Prepare test data.
-    val test = sparkContext.parallelize(Seq(
+    val test = sc.parallelize(Seq(
       LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
       LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
       LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 40b49e37e076d..a4fbf04e03112 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -181,8 +181,8 @@ private[ml] object ClassificationModel {
         val raw2pred: Vector => Double = (rawPred) => {
           rawPred.toArray.zipWithIndex.maxBy(_._1)._2
         }
-        tmpData = tmpData.select($"*",
-          callUDF(raw2pred, col(map(model.rawPredictionCol))).as(map(model.predictionCol)))
+        tmpData = tmpData.select($"*", callUDF(raw2pred, DoubleType,
+          col(map(model.rawPredictionCol))).as(map(model.predictionCol)))
         numColsOutput += 1
       }
     } else if (map(model.predictionCol) != "") {
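The Classifier.scala hunk above switches to the callUDF overload that takes an explicit return DataType (DoubleType here). A rough, self-contained sketch of that overload in isolation; the object and column names are hypothetical, the Dsl._ import mirrors the one visible elsewhere in this patch, and the exact package locations reflect this era of the SQL API, so treat the imports as assumptions:

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.Dsl._
    import org.apache.spark.sql.types.DoubleType

    object CallUdfSketch {
      // Adds a "prediction" column computed from a vector-valued "rawPrediction"
      // column; the input DataFrame is assumed to already contain that column.
      // (The real Classifier code keeps every existing column via $"*"; this
      // sketch keeps only the raw column to stay minimal.)
      def appendPrediction(df: DataFrame): DataFrame = {
        val raw2pred: Vector => Double =
          (rawPred) => rawPred.toArray.zipWithIndex.maxBy(_._1)._2
        // The explicit DoubleType argument supplies the UDF's return type.
        df.select(col("rawPrediction"),
          callUDF(raw2pred, DoubleType, col("rawPrediction")).as("prediction"))
      }
    }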
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 4492c40aa2bfc..3246c9beae241 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -22,7 +22,6 @@ import org.apache.spark.ml.param._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
 import org.apache.spark.storage.StorageLevel


@@ -103,69 +102,6 @@ class LogisticRegressionModel private[ml] (
     1.0 / (1.0 + math.exp(-m))
   }

-  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
-    // Check schema
-    transformSchema(dataset.schema, paramMap, logging = true)
-
-    val map = this.paramMap ++ paramMap
-
-    // Output selected columns only.
-    // This is a bit complicated since it tries to avoid repeated computation.
-    //   rawPrediction (-margin, margin)
-    //   probability (1.0-score, score)
-    //   prediction (max margin)
-    var tmpData = dataset
-    var numColsOutput = 0
-    if (map(rawPredictionCol) != "") {
-      val features2raw: Vector => Vector = predictRaw
-      tmpData = tmpData.select($"*",
-        callUDF(features2raw, col(map(featuresCol))).as(map(rawPredictionCol)))
-      numColsOutput += 1
-    }
-    if (map(probabilityCol) != "") {
-      if (map(rawPredictionCol) != "") {
-        val raw2prob: Vector => Vector = (rawPreds) => {
-          val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
-          Vectors.dense(1.0 - prob1, prob1)
-        }
-        tmpData = tmpData.select($"*",
-          callUDF(raw2prob, col(map(rawPredictionCol))).as(map(probabilityCol)))
-      } else {
-        val features2prob: Vector => Vector = predictProbabilities
-        tmpData = tmpData.select($"*",
-          callUDF(features2prob, col(map(featuresCol))).as(map(probabilityCol)))
-      }
-      numColsOutput += 1
-    }
-    if (map(predictionCol) != "") {
-      val t = map(threshold)
-      if (map(probabilityCol) != "") {
-        val predict: Vector => Double = (probs) => {
-          if (probs(1) > t) 1.0 else 0.0
-        }
-        tmpData = tmpData.select($"*",
-          callUDF(predict, col(map(probabilityCol))).as(map(predictionCol)))
-      } else if (map(rawPredictionCol) != "") {
-        val predict: Vector => Double = (rawPreds) => {
-          val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
-          if (prob1 > t) 1.0 else 0.0
-        }
-        tmpData = tmpData.select($"*",
-          callUDF(predict, col(map(rawPredictionCol))).as(map(predictionCol)))
-      } else {
-        val predict: Vector => Double = this.predict
-        tmpData = tmpData.select($"*",
-          callUDF(predict, col(map(featuresCol))).as(map(predictionCol)))
-      }
-      numColsOutput += 1
-    }
-    if (numColsOutput == 0) {
-      this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" +
-        " since no output columns were set.")
-    }
-    tmpData
-  }
-
   override val numClasses: Int = 2

   /**
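The LogisticRegression.scala diff drops the hand-rolled transform() override along with the now-unused Dsl import, presumably deferring to the shared classification helpers (such as the ClassificationModel code touched in the Classifier.scala hunk above). For reference, a standalone sketch of the per-row math the deleted block computed for the two-class case: margin to rawPrediction, probability via the sigmoid, and a thresholded prediction. No DataFrame plumbing is shown, and the object and method names are illustrative only:

    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    object BinaryLogisticSketch {
      // rawPrediction is the pair (-margin, margin).
      def rawPrediction(margin: Double): Vector = Vectors.dense(-margin, margin)

      // probability is (1 - p1, p1) with p1 = sigmoid(margin).
      def probability(margin: Double): Vector = {
        val p1 = 1.0 / (1.0 + math.exp(-margin))
        Vectors.dense(1.0 - p1, p1)
      }

      // prediction is 1.0 when the positive-class probability clears the threshold.
      def prediction(margin: Double, threshold: Double): Double =
        if (probability(margin)(1) > threshold) 1.0 else 0.0

      def main(args: Array[String]): Unit = {
        val m = 1.2
        println(s"raw=${rawPrediction(m)} prob=${probability(m)} pred=${prediction(m, 0.5)}")
      }
    }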