From 9dc35182725c8dca5293cee7ab7dccca9a258c06 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 8 Apr 2014 19:16:52 -0700 Subject: [PATCH] add tests for BinaryClassificationEvaluator --- .../mllib/evaluation/AreaUnderCurve.scala | 2 +- .../BinaryClassificationEvaluator.scala | 44 +++++++---------- .../binary/BinaryClassificationMetrics.scala | 4 +- .../evaluation/AreaUnderCurveSuite.scala | 1 - .../BinaryClassificationEvaluationSuite.scala | 13 ----- .../BinaryClassificationEvaluatorSuite.scala | 49 +++++++++++++++++++ 6 files changed, 71 insertions(+), 42 deletions(-) delete mode 100644 mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationEvaluationSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 5fdd8d8cb2480..7858ec602483f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._ /** * Computes the area under the curve (AUC) using the trapezoidal rule. */ -private[mllib] object AreaUnderCurve { +private[evaluation] object AreaUnderCurve { /** * Uses the trapezoidal rule to compute the area under the line connecting the two input points. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluator.scala index 4f25b524716cb..290d8fe127ec7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluator.scala @@ -29,8 +29,8 @@ import org.apache.spark.Logging * @param totalCount label counter for all labels */ private case class BinaryConfusionMatrixImpl( - private val count: LabelCounter, - private val totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable { + count: LabelCounter, + totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable { /** number of true positives */ override def tp: Long = count.numPositives @@ -54,16 +54,16 @@ private case class BinaryConfusionMatrixImpl( /** * Evaluator for binary classification. * - * @param scoreAndlabels an RDD of (score, label) pairs. + * @param scoreAndLabels an RDD of (score, label) pairs. */ -class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) extends Serializable with Logging { +class BinaryClassificationEvaluator(scoreAndLabels: RDD[(Double, Double)]) extends Serializable with Logging { private lazy val ( cumCounts: RDD[(Double, LabelCounter)], - confusionByThreshold: RDD[(Double, BinaryConfusionMatrix)]) = { + confusions: RDD[(Double, BinaryConfusionMatrix)]) = { // Create a bin for each distinct score value, count positives and negatives within each bin, // and then sort by score values in descending order. - val counts = scoreAndlabels.combineByKey( + val counts = scoreAndLabels.combineByKey( createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label, mergeValue = (c: LabelCounter, label: Double) => c += label, mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2 @@ -73,21 +73,21 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten iter.foreach(agg += _) Iterator(agg) }, preservesPartitioning = true).collect() - val cum = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg + c) - val totalCount = cum.last - logInfo(s"Total counts: totalCount") + val partitionwiseCumCounts = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg + c) + val totalCount = partitionwiseCumCounts.last + logInfo(s"Total counts: $totalCount") val cumCounts = counts.mapPartitionsWithIndex((index: Int, iter: Iterator[(Double, LabelCounter)]) => { - val cumCount = cum(index) + val cumCount = partitionwiseCumCounts(index) iter.map { case (score, c) => cumCount += c (score, cumCount.clone()) } }, preservesPartitioning = true) cumCounts.persist() - val scoreAndConfusion = cumCounts.map { case (score, cumCount) => - (score, BinaryConfusionMatrixImpl(cumCount, totalCount)) + val confusions = cumCounts.map { case (score, cumCount) => + (score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix]) } - (cumCounts, totalCount, scoreAndConfusion) + (cumCounts, confusions) } /** Unpersist intermediate RDDs used in the computation. */ @@ -126,18 +126,18 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) /** Returns the (threshold, F-Measure) curve with beta = 1.0. */ - def fMeasureByThreshold() = fMeasureByThreshold(1.0) + def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) /** Creates a curve of (threshold, metric). */ private def createCurve(y: BinaryClassificationMetric): RDD[(Double, Double)] = { - confusionByThreshold.map { case (s, c) => + confusions.map { case (s, c) => (s, y(c)) } } /** Creates a curve of (metricX, metricY). */ private def createCurve(x: BinaryClassificationMetric, y: BinaryClassificationMetric): RDD[(Double, Double)] = { - confusionByThreshold.map { case (_, c) => + confusions.map { case (_, c) => (x(c), y(c)) } } @@ -151,7 +151,7 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten */ private class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable { - /** Process a label. */ + /** Processes a label. */ def +=(label: Double): LabelCounter = { // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle // -1.0 for negative as well. @@ -159,27 +159,21 @@ private class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = this } - /** Merge another counter. */ + /** Merges another counter. */ def +=(other: LabelCounter): LabelCounter = { numPositives += other.numPositives numNegatives += other.numNegatives this } - def +(label: Double): LabelCounter = { - this.clone() += label - } - + /** Sums this counter and another counter and returns the result in a new counter. */ def +(other: LabelCounter): LabelCounter = { this.clone() += other } - def sum: Long = numPositives + numNegatives - override def clone: LabelCounter = { new LabelCounter(numPositives, numNegatives) } override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}" } - diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala index 09581bcc75c2c..11581586de817 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation.binary /** * Trait for a binary classification evaluation metric. */ -private[evaluation] trait BinaryClassificationMetric { +private[evaluation] trait BinaryClassificationMetric extends Serializable { def apply(c: BinaryConfusionMatrix): Double } @@ -37,7 +37,7 @@ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetric } /** Recall. */ -private[evalution] object Recall extends BinaryClassificationMetric { +private[evaluation] object Recall extends BinaryClassificationMetric { override def apply(c: BinaryConfusionMatrix): Double = c.tp.toDouble / c.p } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 78dd65c1721b6..1c9844f289fe0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -22,7 +22,6 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { - test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationEvaluationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationEvaluationSuite.scala deleted file mode 100644 index db5cffe280f60..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationEvaluationSuite.scala +++ /dev/null @@ -1,13 +0,0 @@ -package org.apache.spark.mllib.evaluation - -import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext - -class BinaryClassificationEvaluationSuite extends FunSuite with LocalSparkContext { - test("test") { - val data = Seq((0.0, 0.0), (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0), (0.9, 1.0)) - BinaryClassificationEvaluator.get(data) - val rdd = sc.parallelize(data, 3) - BinaryClassificationEvaluator.get(rdd) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluatorSuite.scala new file mode 100644 index 0000000000000..8a3ef5a8713a3 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluatorSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.evaluation.AreaUnderCurve + +class BinaryClassificationEvaluatorSuite extends FunSuite with LocalSparkContext { + test("binary evaluation metrics") { + val scoreAndLabels = sc.parallelize( + Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) + val evaluator = new BinaryClassificationEvaluator(scoreAndLabels) + val score = Seq(0.8, 0.6, 0.4, 0.1) + val tp = Seq(1, 3, 3, 4) + val fp = Seq(0, 1, 2, 3) + val p = 4 + val n = 3 + val precision = tp.zip(fp).map { case (t, f) => t.toDouble / (t + f) } + val recall = tp.map(t => t.toDouble / p) + val fpr = fp.map(f => f.toDouble / n) + val roc = fpr.zip(recall) + val pr = recall.zip(precision) + val f1 = pr.map { case (re, prec) => 2.0 * (prec * re) / (prec + re) } + val f2 = pr.map { case (re, prec) => 5.0 * (prec * re) / (4.0 * prec + re)} + assert(evaluator.rocCurve().collect().toSeq === roc) + assert(evaluator.rocAUC() === AreaUnderCurve.of(roc)) + assert(evaluator.prCurve().collect().toSeq === pr) + assert(evaluator.prAUC() === AreaUnderCurve.of(pr)) + assert(evaluator.fMeasureByThreshold().collect().toSeq === score.zip(f1)) + assert(evaluator.fMeasureByThreshold(2.0).collect().toSeq === score.zip(f2)) + } +}