Skip to content

Commit

Permalink
[SPARK-2479 (partial)][MLLIB] fix binary metrics unit tests
Browse files Browse the repository at this point in the history
Allow small errors in comparison.

@dbtsai, this unit test blocks apache#1562, so I may need to merge this one first. We can change it to use the tools in apache#1425 after that PR gets merged.

Author: Xiangrui Meng <[email protected]>

Closes apache#1576 from mengxr/fix-binary-metrics-unit-tests and squashes the following commits:

5076a7f [Xiangrui Meng] fix binary metrics unit tests
  • Loading branch information
mengxr authored and conviva-zz committed Sep 4, 2014
1 parent 6407d01 commit 6f92119
Showing 1 changed file with 27 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,26 @@ package org.apache.spark.mllib.evaluation
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals

class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {

// TODO: move utility functions to TestingUtils.

/**
 * Checks that two sequences of doubles agree element by element, using the
 * `almostEquals` comparison brought in by `TestingUtils.DoubleWithAlmostEquals`.
 *
 * The sizes are compared first because `zip` silently truncates to the shorter
 * input; without that check a missing or extra element would go unnoticed and
 * an empty `actual` would match anything.
 *
 * @param actual values produced by the code under test
 * @param expected reference values
 * @return true only if both sequences have the same size and every pair of
 *         corresponding elements is almost equal
 */
def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = {
  actual.size == expected.size &&
    actual.zip(expected).forall { case (x1, x2) =>
      x1.almostEquals(x2)
    }
}

/**
 * Checks that two sequences of (x, y) pairs agree element by element, comparing
 * both coordinates with the `almostEquals` comparison brought in by
 * `TestingUtils.DoubleWithAlmostEquals`.
 *
 * The sizes are compared first because `zip` silently truncates to the shorter
 * input; without that check a missing or extra point would go unnoticed and an
 * empty `actual` would match anything.
 *
 * @param actual pairs produced by the code under test
 * @param expected reference pairs
 * @param dummy unused; it only keeps this overload's erased signature distinct
 *              from the `Seq[Double]` overload of the same name
 * @return true only if both sequences have the same size and both coordinates
 *         of every corresponding pair are almost equal
 */
def elementsAlmostEqual(
    actual: Seq[(Double, Double)],
    expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = {
  actual.size == expected.size &&
    actual.zip(expected).forall { case ((x1, y1), (x2, y2)) =>
      x1.almostEquals(x2) && y1.almostEquals(y2)
    }
}

test("binary evaluation metrics") {
val scoreAndLabels = sc.parallelize(
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
Expand All @@ -41,14 +59,14 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
assert(metrics.thresholds().collect().toSeq === threshold)
assert(metrics.roc().collect().toSeq === rocCurve)
assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve))
assert(metrics.pr().collect().toSeq === prCurve)
assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve))
assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1))
assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2))
assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision))
assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall))
assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold))
assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve))
assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve)))
assert(elementsAlmostEqual(metrics.pr().collect(), prCurve))
assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve)))
assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1)))
assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2)))
assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision)))
assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall)))
}
}

0 comments on commit 6f92119

Please sign in to comment.