From 254e03c96e1f2aaa5baa9c3d384adeb117e0b7ab Mon Sep 17 00:00:00 2001
From: Doris Xin
Date: Thu, 3 Jul 2014 13:49:46 -0700
Subject: [PATCH] minor fixes and Java API. punting on python for now. moved aggregateWithContext out of RDD

---
 .../apache/spark/api/java/JavaPairRDD.scala   | 34 +++++++++++++++-
 .../apache/spark/rdd/PairRDDFunctions.scala   | 30 +++++++-------
 .../main/scala/org/apache/spark/rdd/RDD.scala | 21 ----------
 .../spark/util/random/StratifiedSampler.scala | 40 ++++++++++++++++---
 .../spark/rdd/PairRDDFunctionsSuite.scala     |  8 ++--
 .../scala/org/apache/spark/rdd/RDDSuite.scala | 32 ---------------
 6 files changed, 85 insertions(+), 80 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 14fa9d8135afe..e4aa46deb831b 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.api.java
 
-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, List => JList, Map => JMap}
 import java.lang.{Iterable => JIterable}
 
 import scala.collection.JavaConversions._
@@ -129,6 +129,38 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
     new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
 
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      exact: Boolean,
+      seed: Long): JavaPairRDD[K, V] =
+    new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed))
+
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling), using a random seed.
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      exact: Boolean): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong)
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling) with exact sample sizes.
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      seed: Long): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, true, seed)
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   */
+  def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, true, Utils.random.nextLong)
+
   /**
    * Return the union of this RDD and another one. Any identical elements will appear multiple
    * times (use `.distinct()` to eliminate them).
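[Note, not part of the patch] A minimal sketch of how the new sampleByKey overloads are intended to be called. The JavaPairRDD variants above simply delegate to RDD.sampleByKey, filling in exact = true and/or a random seed. The SparkContext setup and the data below are hypothetical and for illustration only:

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._  // implicit conversion to PairRDDFunctions

    // Hypothetical driver code.
    val sc = new SparkContext("local", "sampleByKeySketch")
    val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4), ("b", 5)))

    // Every key occurring in the RDD should have an entry in the fractions map.
    val fractions = Map("a" -> 0.5, "b" -> 0.8)

    // withReplacement = false; exact per-stratum sample sizes (may require extra passes).
    val exactSample = pairs.sampleByKey(false, fractions, exact = true, seed = 42L)

    // Same call relying on the defaults (exact = true, random seed).
    val defaultSample = pairs.sampleByKey(false, fractions)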
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index fe56fe7ba6749..3c563ff032d06 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -19,12 +19,10 @@ package org.apache.spark.rdd
 
 import java.nio.ByteBuffer
 import java.text.SimpleDateFormat
-import java.util.Date
-import java.util.{HashMap => JHashMap}
+import java.util.{Date, HashMap => JHashMap}
 
+import scala.collection.{Map, mutable}
 import scala.collection.JavaConversions._
-import scala.collection.Map
-import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
@@ -34,16 +32,14 @@ import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.CompressionCodec
 import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob,
+import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
   RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
 
 import org.apache.spark._
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.SparkHadoopWriter
 import org.apache.spark.Partitioner.defaultPartitioner
 import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.Utils
@@ -216,24 +212,26 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * need two additional passes.
    *
    * @param withReplacement whether to sample with or without replacement
-   * @param fractionByKey function mapping key to sampling rate
+   * @param fractions map of keys to their sampling rates
    * @param seed seed for the random number generator
    * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per stratum
    * @return RDD containing the sampled subset
    */
   def sampleByKey(withReplacement: Boolean,
-      fractionByKey: Map[K, Double],
-      seed: Long = Utils.random.nextLong,
-      exact: Boolean = true): RDD[(K, V)]= {
-    require(fractionByKey.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.")
+      fractions: Map[K, Double],
+      exact: Boolean = true,
+      seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
+
+    require(fractions.values.forall(v => v >= 0.0), "All sampling rates must be >= 0.")
+
     if (withReplacement) {
       val counts = if (exact) Some(this.countByKey()) else None
       val samplingFunc =
-        StratifiedSampler.getPoissonSamplingFunction(self, fractionByKey, exact, counts, seed)
+        StratifiedSampler.getPoissonSamplingFunction(self, fractions, exact, counts, seed)
       self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
     } else {
       val samplingFunc =
-        StratifiedSampler.getBernoulliSamplingFunction(self, fractionByKey, exact, seed)
+        StratifiedSampler.getBernoulliSamplingFunction(self, fractions, exact, seed)
       self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 9edc73d1c5774..7a44cce273622 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -875,27 +875,6 @@ abstract class RDD[T: ClassTag](
     jobResult
   }
 
-  /**
-   * A version of {@link #aggregate()} that passes the TaskContext to the function that does
-   * aggregation for each partition.
-   */
-  def aggregateWithContext[U: ClassTag](zeroValue: U)(seqOp: ((TaskContext, U), T) => U,
-      combOp: (U, U) => U): U = {
-    // Clone the zero value since we will also be serializing it as part of tasks
-    var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
-    // pad seqOp and combOp with taskContext to conform to aggregate's signature in TraversableOnce
-    val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
-    val paddedcombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
-      (arg1._1, combOp(arg1._2, arg1._2))
-    val cleanSeqOp = sc.clean(paddedSeqOp)
-    val cleanCombOp = sc.clean(paddedcombOp)
-    val aggregatePartition = (tc: TaskContext, it: Iterator[T]) =>
-      (it.aggregate(tc, zeroValue)(cleanSeqOp, cleanCombOp))._2
-    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
-    sc.runJob(this, aggregatePartition, mergeResult)
-    jobResult
-  }
-
   /**
    * Return the number of elements in the RDD.
   */
diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala
index 3535fa8387aff..1dd586d752efc 100644
--- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala
@@ -17,15 +17,43 @@
 
 package org.apache.spark.util.random
 
+import scala.collection.{Map, mutable}
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.{mutable, Map}
+import scala.reflect.ClassTag
+
 import org.apache.commons.math3.random.RandomDataGenerator
-import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.util.random.{PoissonBounds => PB}
-import scala.Some
+import org.apache.spark.{Logging, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.{PoissonBounds => PB}
 
 private[spark] object StratifiedSampler extends Logging {
+
+  /**
+   * A version of {@link RDD#aggregate()} that passes the TaskContext to the function that does
+   * aggregation for each partition. Unlike an implementation based on mapPartitionsWithIndex,
+   * this avoids adding an extra level to the RDD lineage, for a slightly improved run time.
+   */
+  def aggregateWithContext[U: ClassTag, T: ClassTag](zeroValue: U)
+      (rdd: RDD[T],
+       seqOp: ((TaskContext, U), T) => U,
+       combOp: (U, U) => U): U = {
+    val sc: SparkContext = rdd.sparkContext
+    // Clone the zero value since we will also be serializing it as part of tasks
+    var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
+    // pad seqOp and combOp with taskContext to conform to aggregate's signature in TraversableOnce
+    val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
+    val paddedCombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
+      (arg1._1, combOp(arg1._2, arg2._2))
+    val cleanSeqOp = sc.clean(paddedSeqOp)
+    val cleanCombOp = sc.clean(paddedCombOp)
+    val aggregatePartition = (tc: TaskContext, it: Iterator[T]) =>
+      it.aggregate((tc, zeroValue))(cleanSeqOp, cleanCombOp)._2
+    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+    sc.runJob(rdd, aggregatePartition, mergeResult)
+    jobResult
+  }
+
   /**
    * Returns the function used by aggregate to collect sampling statistics for each partition.
   */
@@ -153,7 +181,7 @@ private[spark] object StratifiedSampler extends Logging {
       val seqOp = StratifiedSampler.getSeqOp[K,V](false, fractionByKey, None)
       val combOp = StratifiedSampler.getCombOp[K]()
       val zeroU = new Result[K](Map[K, Stratum](), seed = seed)
-      val finalResult = rdd.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
+      val finalResult = aggregateWithContext(zeroU)(rdd, seqOp, combOp).resultMap
       samplingRateByKey = StratifiedSampler.computeThresholdByKey(finalResult, fractionByKey)
     }
     (idx: Int, iter: Iterator[(K, V)]) => {
@@ -183,7 +211,7 @@ private[spark] object StratifiedSampler extends Logging {
     val seqOp = StratifiedSampler.getSeqOp[K,V](true, fractionByKey, counts)
     val combOp = StratifiedSampler.getCombOp[K]()
     val zeroU = new Result[K](Map[K, Stratum](), seed = seed)
-    val finalResult = rdd.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
+    val finalResult = aggregateWithContext(zeroU)(rdd, seqOp, combOp).resultMap
     val thresholdByKey = StratifiedSampler.computeThresholdByKey(finalResult, fractionByKey)
     (idx: Int, iter: Iterator[(K, V)]) => {
       val random = new RandomDataGenerator()
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 4ac99a9dc6824..06ea6cef67000 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -106,8 +106,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
         n: Long) = {
       val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
         math.ceil(count * samplingRate).toInt)
-      val fractionByKey = Map("1" -> samplingRate, "0" -> samplingRate)
-      val sample = stratifiedData.sampleByKey(false, fractionByKey, seed, exact)
+      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+      val sample = stratifiedData.sampleByKey(false, fractions, exact, seed)
       val sampleCounts = sample.countByKey()
       val takeSample = sample.collect()
       assert(sampleCounts.forall({case(k,v) =>
@@ -124,8 +124,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
         n: Long) = {
       val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
         math.ceil(count * samplingRate).toInt)
-      val fractionByKey = Map("1" -> samplingRate, "0" -> samplingRate)
-      val sample = stratifiedData.sampleByKey(true, fractionByKey, seed, exact)
+      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+      val sample = stratifiedData.sampleByKey(true, fractions, exact, seed)
       val sampleCounts = sample.countByKey()
       val takeSample = sample.collect()
       assert(sampleCounts.forall({case(k,v) =>
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index a3094a2fd6262..0e5625b7645d5 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -141,38 +141,6 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
   }
 
-  test("aggregateWithContext") {
-    val data = Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))
-    val numPartitions = 2
-    val pairs = sc.makeRDD(data, numPartitions)
-    //determine the partitionId for each pair
-    type StringMap = HashMap[String, Int]
-    val partitions = pairs.collectPartitions()
-    val offSets = new StringMap
-    for (i <- 0 to numPartitions - 1) {
-      partitions(i).foreach({ case (k, v) => offSets.put(k, offSets.getOrElse(k, 0) + i)})
-    }
-    val emptyMap = new StringMap {
-      override def default(key: String): Int = 0
-    }
-    val mergeElement: ((TaskContext, StringMap), (String, Int)) => StringMap = (arg1, pair) => {
-      val stringMap = arg1._2
-      val tc = arg1._1
-      stringMap(pair._1) += pair._2 + tc.partitionId
-      stringMap
-    }
-    val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => {
-      for ((key, value) <- map2) {
-        map1(key) += value
-      }
-      map1
-    }
-    val result = pairs.aggregateWithContext(emptyMap)(mergeElement, mergeMaps)
-    val expected = Set(("a", 6), ("b", 2), ("c", 5))
-      .map({ case (k, v) => (k -> (offSets.getOrElse(k, 0) + v))})
-    assert(result.toSet === expected)
-  }
-
   test("basic caching") {
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
     assert(rdd.collect().toList === List(1, 2, 3, 4))
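[Note, not part of the patch] With the RDDSuite test above removed, the relocated helper is only exercised indirectly through the sampling functions. For reference, a minimal sketch of how StratifiedSampler.aggregateWithContext could be called from code inside the org.apache.spark package (the object is private[spark]); the data and expected result below are illustrative only:

    import org.apache.spark.{SparkContext, TaskContext}
    import org.apache.spark.util.random.StratifiedSampler

    // Hypothetical example: count elements per partition, using the partition id
    // carried by the TaskContext that aggregateWithContext threads through.
    val sc = new SparkContext("local[2]", "aggregateWithContextSketch")
    val data = sc.parallelize(1 to 10, 2)

    val seqOp = (arg: (TaskContext, Map[Int, Int]), item: Int) => {
      val (tc, acc) = arg
      acc.updated(tc.partitionId, acc.getOrElse(tc.partitionId, 0) + 1)
    }
    val combOp = (m1: Map[Int, Int], m2: Map[Int, Int]) =>
      m2.foldLeft(m1) { case (acc, (k, v)) => acc.updated(k, acc.getOrElse(k, 0) + v) }

    // For an evenly split RDD of 10 elements this yields Map(0 -> 5, 1 -> 5).
    val perPartitionCounts =
      StratifiedSampler.aggregateWithContext(Map.empty[Int, Int])(data, seqOp, combOp)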