[SPARK-30938][ML][MLLIB] BinaryClassificationMetrics optimization
### What changes were proposed in this pull request?
1. Avoid `Iterator.grouped(size: Int)`, which has to maintain an `ArrayBuffer` of `size` entries.
2. Keep the number of partitions unchanged in curve computation (see the sketch below).
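
As a rough illustration of the second point, the partition-preserving pattern can be written as a standalone helper (a minimal sketch; `withEndpoints` is an illustrative name, not code from the patch):

```scala
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

// Prepend `first` to the first partition and append `last` to the last one,
// keeping the partition count unchanged instead of unioning extra
// single-element RDDs.
def withEndpoints[T: ClassTag](rdd: RDD[T], first: T, last: T): RDD[T] = {
  val numParts = rdd.getNumPartitions
  rdd.mapPartitionsWithIndex { case (pid, iter) =>
    val head = if (pid == 0) Iterator.single(first) else Iterator.empty
    val tail = if (pid == numParts - 1) Iterator.single(last) else Iterator.empty
    head ++ iter ++ tail
  }
}
```

The previous `UnionRDD`-based `roc()` returned an RDD with two extra single-element partitions; this pattern leaves the partitioning of the curve RDD intact.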

### Why are the changes needed?
1. `BinaryClassificationMetrics` tends to fail (OOM) when `grouping = count / numBins` is too large, because `Iterator.grouped(size: Int)` has to maintain an `ArrayBuffer` with `size` entries; `BinaryClassificationMetrics` does not need to keep such a big array in memory (see the sketch below).
2. Partition sizes become more even.

This PR makes the metric computation more stable and a little faster.
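
To illustrate the first point, here is a simplified sketch with a running `Long` sum standing in for `BinaryLabelCounter` (function names are illustrative, not code from the patch):

```scala
// Buffered: grouped(size) materializes each chunk as a collection of `size`
// elements before it can be aggregated; with size = count / numBins this
// buffer can be huge.
def chunkSumsBuffered(iter: Iterator[Long], size: Int): Iterator[Long] =
  iter.grouped(size).map(_.sum)

// Streaming: fold elements into a single running accumulator and emit it
// every `size` elements, so only one accumulator is alive at any time.
def chunkSumsStreaming(iter: Iterator[Long], size: Long): Iterator[Long] = {
  var acc = 0L
  var cnt = 0L
  iter.flatMap { v =>
    acc += v
    cnt += 1
    if (cnt == size) { val out = acc; acc = 0L; cnt = 0L; Some(out) } else None
  } ++ (if (cnt > 0) Iterator.single(acc) else Iterator.empty)
}
```

The right-hand operand of `++` is evaluated lazily, after the main iterator is exhausted, so the trailing partial chunk is emitted with its final value; the new grouping code in this patch relies on the same behavior.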

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Existing test suites.

Closes apache#27682 from zhengruifeng/grouped_opt.

Authored-by: zhengruifeng <[email protected]>
Signed-off-by: zhengruifeng <[email protected]>
zhengruifeng committed Feb 28, 2020
1 parent 1383bd4 commit 14bb639
Showing 1 changed file with 49 additions and 22 deletions.
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.evaluation.binary._
-import org.apache.spark.rdd.{RDD, UnionRDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
 /**
@@ -101,10 +101,19 @@ class BinaryClassificationMetrics @Since("3.0.0") (
   @Since("1.0.0")
   def roc(): RDD[(Double, Double)] = {
     val rocCurve = createCurve(FalsePositiveRate, Recall)
-    val sc = confusions.context
-    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
-    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
-    new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
+    val numParts = rocCurve.getNumPartitions
+    rocCurve.mapPartitionsWithIndex { case (pid, iter) =>
+      if (numParts == 1) {
+        require(pid == 0)
+        Iterator.single((0.0, 0.0)) ++ iter ++ Iterator.single((1.0, 1.0))
+      } else if (pid == 0) {
+        Iterator.single((0.0, 0.0)) ++ iter
+      } else if (pid == numParts - 1) {
+        iter ++ Iterator.single((1.0, 1.0))
+      } else {
+        iter
+      }
+    }
   }
 
   /**
@@ -124,7 +133,13 @@
   def pr(): RDD[(Double, Double)] = {
     val prCurve = createCurve(Recall, Precision)
     val (_, firstPrecision) = prCurve.first()
-    confusions.context.parallelize(Seq((0.0, firstPrecision)), 1).union(prCurve)
+    prCurve.mapPartitionsWithIndex { case (pid, iter) =>
+      if (pid == 0) {
+        Iterator.single((0.0, firstPrecision)) ++ iter
+      } else {
+        iter
+      }
+    }
   }
 
   /**
@@ -182,28 +197,40 @@
         val countsSize = counts.count()
         // Group the iterator into chunks of about countsSize / numBins points,
         // so that the resulting number of bins is about numBins
-        var grouping = countsSize / numBins
+        val grouping = countsSize / numBins
         if (grouping < 2) {
           // numBins was more than half of the size; no real point in down-sampling to bins
           logInfo(s"Curve is too small ($countsSize) for $numBins bins to be useful")
           counts
         } else {
-          if (grouping >= Int.MaxValue) {
-            logWarning(
-              s"Curve too large ($countsSize) for $numBins bins; capping at ${Int.MaxValue}")
-            grouping = Int.MaxValue
-          }
-          counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
-            // The score of the combined point will be just the last one's score, which is also
-            // the minimal in each chunk since all scores are already sorted in descending.
-            val lastScore = pairs.last._1
-            // The combined point will contain all counts in this chunk. Thus, calculated
-            // metrics (like precision, recall, etc.) on its score (or so-called threshold) are
-            // the same as those without sampling.
-            val agg = new BinaryLabelCounter()
-            pairs.foreach(pair => agg += pair._2)
-            (lastScore, agg)
-          })
+          counts.mapPartitions { iter =>
+            if (iter.hasNext) {
+              var score = Double.NaN
+              var agg = new BinaryLabelCounter()
+              var cnt = 0L
+              iter.flatMap { pair =>
+                score = pair._1
+                agg += pair._2
+                cnt += 1
+                if (cnt == grouping) {
+                  // The score of the combined point will be just the last one's score,
+                  // which is also the minimal in each chunk since all scores are already
+                  // sorted in descending.
+                  // The combined point will contain all counts in this chunk. Thus, calculated
+                  // metrics (like precision, recall, etc.) on its score (or so-called threshold)
+                  // are the same as those without sampling.
+                  val ret = (score, agg)
+                  agg = new BinaryLabelCounter()
+                  cnt = 0
+                  Some(ret)
+                } else None
+              } ++ {
+                if (cnt > 0) {
+                  Iterator.single((score, agg))
+                } else Iterator.empty
+              }
+            } else Iterator.empty
+          }
         }
       }
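
For reference, a minimal usage sketch of the affected API (the input data and `numBins` value are illustrative):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

val sc = SparkContext.getOrCreate()
// Toy (score, label) pairs; in practice these come from a model's predictions.
val scoreAndLabels = sc.parallelize(Seq(
  (0.9, 1.0), (0.8, 1.0), (0.6, 0.0), (0.4, 1.0), (0.2, 0.0)))

val metrics = new BinaryClassificationMetrics(scoreAndLabels, 2) // numBins = 2
val rocPoints = metrics.roc().collect() // first point (0.0, 0.0), last (1.0, 1.0)
val auROC = metrics.areaUnderROC()
```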
