From 0214a7659c62e4ff0f68f6e09cd7846547cd3bcb Mon Sep 17 00:00:00 2001
From: Doris Xin
Date: Tue, 17 Jun 2014 19:22:32 -0700
Subject: [PATCH] cleanUp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Addressed reviewer comments and added better documentation of the code.
Added commons-math3 as a dependency of Spark (okay’ed by Matei).
“mvn clean install” compiled.
Recovered files that were reverted by accident in the merge.

TODOs: figure out the API for sampleByKeyExact and update Java, Python, and the
markdown file accordingly.
---
 core/pom.xml                                       |  2 -
 .../main/scala/org/apache/spark/rdd/RDD.scala      | 21 +++++++
 .../spark/util/random/SamplingUtils.scala          | 36 ++++++++++--
 .../spark/util/random/StratifiedSampler.scala      | 58 ++++++++++++++-----
 .../scala/org/apache/spark/rdd/RDDSuite.scala      | 32 ++++++++++
 5 files changed, 129 insertions(+), 20 deletions(-)

diff --git a/core/pom.xml b/core/pom.xml
index bd6767e03bb9d..70b6674c83bdc 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -70,8 +70,6 @@
     <dependency>
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-math3</artifactId>
-      <version>3.3</version>
-      <scope>test</scope>
     </dependency>
     <dependency>
       <groupId>com.google.code.findbugs</groupId>
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1633b185861b9..2b49241c579a3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -874,6 +874,27 @@ abstract class RDD[T: ClassTag](
     jobResult
   }
 
+  /**
+   * A version of {@link #aggregate()} that passes the TaskContext to the function that does
+   * aggregation for each partition.
+   */
+  def aggregateWithContext[U: ClassTag](zeroValue: U)(seqOp: ((TaskContext, U), T) => U,
+      combOp: (U, U) => U): U = {
+    // Clone the zero value since we will also be serializing it as part of tasks
+    var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
+    // pad seqOp and combOp with taskContext to conform to aggregate's signature in TraversableOnce
+    val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
+    val paddedCombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
+      (arg1._1, combOp(arg1._2, arg2._2))
+    val cleanSeqOp = sc.clean(paddedSeqOp)
+    val cleanCombOp = sc.clean(paddedCombOp)
+    val aggregatePartition = (tc: TaskContext, it: Iterator[T]) =>
+      it.aggregate((tc, zeroValue))(cleanSeqOp, cleanCombOp)._2
+    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+    sc.runJob(this, aggregatePartition, mergeResult)
+    jobResult
+  }
+
   /**
    * Return the number of elements in the RDD.
   */
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index fcfb6c97e8932..24168fe2c6cf1 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.util.random
 
-import org.apache.commons.math3.distribution.{PoissonDistribution, NormalDistribution}
+import org.apache.commons.math3.distribution.PoissonDistribution
 
 private[spark] object SamplingUtils {
 
@@ -43,7 +43,7 @@ private[spark] object SamplingUtils {
    * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
    */
   def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
-    withReplacement: Boolean): Double = {
+      withReplacement: Boolean): Double = {
     val fraction = sampleSizeLowerBound.toDouble / total
     if (withReplacement) {
       val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
@@ -56,12 +56,29 @@ private[spark] object SamplingUtils {
   }
 }
 
+/**
+ * Utility functions that help us determine bounds on the adjusted sampling rate to guarantee
+ * exact sample sizes with high confidence when sampling with replacement.
+ *
+ * The algorithm for guaranteeing sample size instantly accepts items whose associated value drawn
+ * from Pois(s) is less than the lower bound and puts items whose value is between the lower and
+ * upper bound in a waitlist. The final sample consists of all items accepted on the fly plus the
+ * portion of the waitlist needed to reach the exact sample size.
+ */
 private[spark] object PoissonBounds {
 
   val delta = 1e-4 / 3.0
-  val phi = new NormalDistribution().cumulativeProbability(1.0 - delta)
 
-  def getLambda1(s: Double): Double = {
+  /**
+   * Compute the threshold for accepting items on the fly. The threshold is a fairly small value,
+   * so an item whose associated value is below it is highly likely to end up in the final sample.
+   * Hence we instantly accept items whose values are less than the value returned by this
+   * function.
+   *
+   * @param s sample size
+   * @return threshold for accepting items on the fly
+   */
+  def getLowerBound(s: Double): Double = {
     var lb = math.max(0.0, s - math.sqrt(s / delta)) // Chebyshev's inequality
     var ub = s
     while (lb < ub - 1.0) {
@@ -79,7 +96,16 @@ private[spark] object PoissonBounds {
     poisson.inverseCumulativeProbability(delta)
   }
 
-  def getLambda2(s: Double): Double = {
+  /**
+   * Compute the threshold for waitlisting items. An item is waitlisted if its associated value is
+   * greater than the lower bound determined above but below the upper bound computed here.
+   * The value is computed such that we only need to keep log(s) items in the waitlist and still be
+   * able to guarantee the sample size with high confidence.
+   *
+   * @param s sample size
+   * @return threshold for waitlisting the item
+   */
+  def getUpperBound(s: Double): Double = {
     var lb = s
     var ub = s + math.sqrt(s / delta) // Chebyshev's inequality
     while (lb < ub - 1.0) {
diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala
index 5ddd78d715174..3535fa8387aff 100644
--- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala
@@ -26,6 +26,9 @@
 import scala.Some
 import org.apache.spark.rdd.RDD
 
 private[spark] object StratifiedSampler extends Logging {
+  /**
+   * Returns the function used by aggregate to collect sampling statistics for each partition.
+   */
   def getSeqOp[K, V](withReplacement: Boolean, fractionByKey: (K => Double),
       counts: Option[Map[K, Long]]): ((TaskContext, Result[K]),(K, V)) => Result[K] = {
@@ -43,9 +46,9 @@ private[spark] object StratifiedSampler extends Logging {
         if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
           val n = counts.get(item._1)
           val s = math.ceil(n * fraction).toLong
-          val lmbd1 = PB.getLambda1(s)
+          val lmbd1 = PB.getLowerBound(s)
           val minCount = PB.getMinCount(lmbd1)
-          val lmbd2 = if (lmbd1 == 0) PB.getLambda2(s) else PB.getLambda2(s - minCount)
+          val lmbd2 = if (lmbd1 == 0) PB.getUpperBound(s) else PB.getUpperBound(s - minCount)
           val q1 = lmbd1 / n
           val q2 = lmbd2 / n
           stratum.q1 = Some(q1)
@@ -60,6 +63,8 @@ private[spark] object StratifiedSampler extends Logging {
            stratum.addToWaitList(ArrayBuffer.fill(x2)(rng.nextUniform(0.0, 1.0)))
          }
        } else {
+          // We use the streaming version of the algorithm for sampling without replacement.
+          // Hence, q1 and q2 change on every iteration.
          val g1 = - math.log(delta) / stratum.numItems
          val g2 = (2.0 / 3.0) * g1
          val q1 = math.max(0, fraction + g2 - math.sqrt((g2 * g2 + 3 * g2 * fraction)))
@@ -79,7 +84,11 @@ private[spark] object StratifiedSampler extends Logging {
     }
   }
 
-  def getCombOp[K](): (Result[K], Result[K]) => Result[K] = {
+  /**
+   * Returns the function used by aggregate to combine results from different partitions, as
+   * returned by seqOp.
+   */
+  def getCombOp[K](): (Result[K], Result[K]) => Result[K] = {
     (r1: Result[K], r2: Result[K]) => {
       // take union of both key sets in case one partition doesn't contain all keys
       val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)
@@ -100,6 +109,10 @@ private[spark] object StratifiedSampler extends Logging {
     }
   }
 
+  /**
+   * Given the result returned by the aggregate function, determine the per-key threshold used to
+   * accept items so that the exact sample size is reached.
+   */
   def computeThresholdByKey[K](finalResult: Map[K, Stratum],
       fractionByKey: (K => Double)): (K => Double) = {
     val thresholdByKey = new mutable.HashMap[K, Double]()
@@ -122,11 +135,15 @@ private[spark] object StratifiedSampler extends Logging {
     thresholdByKey
   }
 
-  def computeThresholdByKey[K](finalResult: Map[K, String]): (K => String) = {
-    finalResult
-  }
-
-  def getBernoulliSamplingFunction[K, V](rdd:RDD[(K, V)],
+  /**
+   * Return the per-partition sampling function used for sampling without replacement.
+   *
+   * When exact sample size is required, we make an additional pass over the RDD to determine the
+   * exact sampling rate that guarantees sample size with high confidence.
+   *
+   * The sampling function has a unique seed per partition.
+   */
+  def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)],
       fractionByKey: K => Double,
       exact: Boolean,
       seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = {
@@ -146,6 +163,16 @@ private[spark] object StratifiedSampler extends Logging {
     }
   }
 
+  /**
+   * Return the per-partition sampling function used for sampling with replacement.
+   *
+   * When exact sample size is required, we make two additional passes over the RDD to determine
+   * the exact sampling rate that guarantees sample size with high confidence. The first pass
+   * counts the number of items in each stratum (group of items with the same key) in the RDD, and
+   * the second pass uses the counts to determine exact sampling rates.
+   *
+   * The sampling function has a unique seed per partition.
+   */
   def getPoissonSamplingFunction[K, V](rdd:RDD[(K, V)],
       fractionByKey: K => Double,
       exact: Boolean,
@@ -191,6 +218,10 @@ private[spark] object StratifiedSampler extends Logging {
   }
 }
 
+/**
+ * Object used by seqOp to keep track of the number of items accepted and items waitlisted per
+ * stratum, as well as the bounds for accepting and waitlisting items.
+ */
 private[random] class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0L)
   extends Serializable {
 
@@ -205,13 +236,14 @@ private[random] class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0
   def addToWaitList(elem: Double) = waitList += elem
 
   def addToWaitList(elems: ArrayBuffer[Double]) = waitList ++= elems
-
-  override def toString() = {
-    "numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
-      " waitListSize:" + waitList.size
-  }
 }
 
+/**
+ * Object used by seqOp and combOp to keep track of the sampling statistics for all strata.
+ *
+ * When used by seqOp for each partition, we also keep track of the partition ID in this object
+ * to make sure a single random number generator with a unique seed is used for each partition.
+ */
 private[random] class Result[K](var resultMap: Map[K, Stratum],
     var cachedPartitionId: Option[Int] = None,
     val seed: Long)
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 0e5625b7645d5..a3094a2fd6262 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -141,6 +141,38 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
   }
 
+  test("aggregateWithContext") {
+    val data = Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))
+    val numPartitions = 2
+    val pairs = sc.makeRDD(data, numPartitions)
+    // compute the partitionId offsets that mergeElement will add to each key's sum
+    type StringMap = HashMap[String, Int]
+    val partitions = pairs.collectPartitions()
+    val offSets = new StringMap
+    for (i <- 0 until numPartitions) {
+      partitions(i).foreach({ case (k, v) => offSets.put(k, offSets.getOrElse(k, 0) + i)})
+    }
+    val emptyMap = new StringMap {
+      override def default(key: String): Int = 0
+    }
+    val mergeElement: ((TaskContext, StringMap), (String, Int)) => StringMap = (arg1, pair) => {
+      val stringMap = arg1._2
+      val tc = arg1._1
+      stringMap(pair._1) += pair._2 + tc.partitionId
+      stringMap
+    }
+    val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => {
+      for ((key, value) <- map2) {
+        map1(key) += value
+      }
+      map1
+    }
+    val result = pairs.aggregateWithContext(emptyMap)(mergeElement, mergeMaps)
+    val expected = Set(("a", 6), ("b", 2), ("c", 5))
+      .map({ case (k, v) => (k -> (offSets.getOrElse(k, 0) + v))})
+    assert(result.toSet === expected)
+  }
+
   test("basic caching") {
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
     assert(rdd.collect().toList === List(1, 2, 3, 4))
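The aggregateWithContext method added to RDD.scala in this patch behaves like aggregate, except that seqOp also receives the TaskContext of the partition it runs in. Below is a minimal usage sketch of that API; the application name, master URL, and the per-partition counting logic are illustrative assumptions rather than anything prescribed by the patch.

    import org.apache.spark.{SparkConf, SparkContext, TaskContext}

    object AggregateWithContextExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("awc-sketch").setMaster("local[2]"))
        val rdd = sc.parallelize(1 to 100, numSlices = 4)

        // seqOp receives (TaskContext, accumulator) instead of just the accumulator, so it can
        // see which partition it is processing. Here we build a per-partition element count.
        val seqOp = (ctxAndAcc: (TaskContext, Map[Int, Long]), item: Int) => {
          val (ctx, acc) = ctxAndAcc
          acc.updated(ctx.partitionId, acc.getOrElse(ctx.partitionId, 0L) + 1L)
        }
        // combOp merges two per-partition count maps, exactly as in plain aggregate.
        val combOp = (m1: Map[Int, Long], m2: Map[Int, Long]) =>
          (m1.keySet ++ m2.keySet).map(k => k -> (m1.getOrElse(k, 0L) + m2.getOrElse(k, 0L))).toMap

        val perPartitionCounts = rdd.aggregateWithContext(Map.empty[Int, Long])(seqOp, combOp)
        println(perPartitionCounts) // e.g. Map(0 -> 25, 1 -> 25, 2 -> 25, 3 -> 25)
        sc.stop()
      }
    }

This is the same pattern the StratifiedSampler uses internally: the TaskContext lets each partition seed its own random number generator and tag its statistics with the partition ID.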
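The PoissonBounds Scaladoc above describes a two-threshold scheme: items whose Poisson draw falls below a lower bound are accepted immediately, items between the bounds are waitlisted, and anything above the upper bound is dropped. The standalone sketch below illustrates the intent of the lower bound only, using the same delta as the patch and commons-math3's PoissonDistribution. It walks a candidate mean upward instead of doing the patch's binary search, so roughLowerBound is a simplified, hypothetical stand-in for getLowerBound, not the patch's implementation.

    import org.apache.commons.math3.distribution.PoissonDistribution

    object PoissonBoundsSketch {
      // Per-bound failure probability, mirroring the delta defined in PoissonBounds.
      val delta = 1e-4 / 3.0

      // Largest mean lambda such that a Pois(lambda) draw stays at or below the target size s
      // with probability at least 1 - delta: accepting items at this rate almost never overshoots s.
      def roughLowerBound(s: Double): Double = {
        // Start from the Chebyshev-based guess the patch also uses, then walk upward in steps of 1.
        var lambda = math.max(0.0, s - math.sqrt(s / delta))
        while (lambda + 1.0 <= s &&
          new PoissonDistribution(lambda + 1.0).inverseCumulativeProbability(1.0 - delta) <= s) {
          lambda += 1.0
        }
        lambda
      }

      def main(args: Array[String]): Unit = {
        val s = 1000.0
        println(s"target size $s: accept-on-the-fly rate (lambda) ~= ${roughLowerBound(s)}")
      }
    }

The upper bound is the mirror image: the smallest mean whose delta-quantile is still at least s, so that accepted plus waitlisted items almost surely suffice to fill the sample.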
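computeThresholdByKey, documented above, turns the per-stratum statistics gathered by seqOp/combOp into a final acceptance threshold: everything accepted on the fly is kept, and just enough of the smallest waitlisted scores are admitted to reach the target size. A self-contained sketch of that selection step follows; StratumResult and thresholdFor are hypothetical names invented for the illustration, not types from the patch.

    import scala.util.Random

    object WaitlistThresholdSketch {
      // Hypothetical per-stratum summary: how many items were accepted outright, plus the
      // uniform scores of the waitlisted items (smaller score = accepted earlier).
      case class StratumResult(numAccepted: Long, waitList: Seq[Double])

      // Pick the score threshold that admits exactly enough waitlisted items to reach sampleSize.
      def thresholdFor(result: StratumResult, sampleSize: Long): Double = {
        val stillNeeded = (sampleSize - result.numAccepted).toInt
        if (stillNeeded <= 0) 0.0                          // already have enough; admit nothing more
        else if (stillNeeded >= result.waitList.size) 1.0  // admit the whole waitlist
        else result.waitList.sorted.apply(stillNeeded - 1) // the stillNeeded-th smallest score
      }

      def main(args: Array[String]): Unit = {
        val rng = new Random(42)
        val stratum = StratumResult(numAccepted = 90, waitList = Seq.fill(30)(rng.nextDouble()))
        val t = thresholdFor(stratum, sampleSize = 100)
        println(s"admit waitlisted items with score <= $t to reach 100 samples")
      }
    }

Keeping only the waitlist (roughly log(s) items per stratum, per the PoissonBounds comment) rather than all candidate items is what makes this exact-size selection cheap.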
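Finally, the with-replacement path described in getPoissonSamplingFunction's Scaladoc needs per-stratum counts before it can pick exact rates. The sketch below shows the shape of that two-pass flow using only public RDD operations: countByKey for the counting pass, then mapPartitionsWithIndex with a per-partition seed for the sampling pass. It is a schematic stand-in for the private StratifiedSampler plumbing, with sampleByKeySketch a hypothetical name and a plain coin flip in place of the Poisson accept/waitlist logic.

    import scala.reflect.ClassTag
    import scala.util.Random
    import org.apache.spark.SparkContext._
    import org.apache.spark.rdd.RDD

    object ExactSampleFlowSketch {
      def sampleByKeySketch[K: ClassTag, V: ClassTag](
          rdd: RDD[(K, V)],
          fractionByKey: K => Double,
          seed: Long): RDD[(K, V)] = {
        // Pass 1: count each stratum (all items sharing a key) and derive the exact targets.
        val counts: Map[K, Long] = rdd.countByKey().toMap
        val targetSizeByKey = counts.map { case (k, n) => k -> math.ceil(n * fractionByKey(k)).toLong }
        targetSizeByKey.foreach { case (k, t) => println(s"key $k: want exactly $t of ${counts(k)}") }

        // Pass 2: sample every partition with a partition-specific seed so the job is reproducible.
        rdd.mapPartitionsWithIndex({ (partitionId, iter) =>
          val rng = new Random(seed + partitionId)
          iter.filter { case (k, _) => rng.nextDouble() < fractionByKey(k) }
        }, preservesPartitioning = true)
      }
    }

In the patch itself the counting pass feeds PoissonBounds-derived accept/waitlist thresholds rather than a plain Bernoulli trial, which is what turns the expected sample size into an exact one.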