diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 25337f8cb663b..fa2ab6cea8d32 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -21,6 +21,7 @@ import java.io._
 import java.util.Comparator
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
 
 import com.google.common.io.ByteStreams
 
@@ -88,17 +89,36 @@ private[spark] class ExternalSorter[K, V, C](
     (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
   }
 
-  // For now, just compare them by partition; later we can compare by key as well
-  private val comparator = new Comparator[((Int, K), C)] {
-    override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
-      a._1._1 - b._1._1
+  // A comparator for ((Int, K), C) elements that orders them by partition and then possibly
+  // by key if we want to sort data within each partition
+  private val partitionKeyComparator: Comparator[((Int, K), C)] = {
+    if (ordering.isDefined) {
+      // We want to sort the data by key
+      val ord = ordering.get
+      new Comparator[((Int, K), C)] {
+        override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
+          val partitionDiff = a._1._1 - b._1._1
+          if (partitionDiff != 0) {
+            partitionDiff
+          } else {
+            ord.compare(a._1._2, b._1._2)
+          }
+        }
+      }
+    } else {
+      // Just sort it by partition
+      new Comparator[((Int, K), C)] {
+        override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
+          a._1._1 - b._1._1
+        }
+      }
     }
   }
 
   // Information about a spilled file. Includes sizes in bytes of "batches" written by the
   // serializer as we periodically reset its stream, as well as number of elements in each
   // partition, used to efficiently keep track of partitions when merging.
-  private case class SpilledFile(
+  private[this] case class SpilledFile(
     file: File,
     blockId: BlockId,
     serializerBatchSizes: ArrayBuffer[Long],
@@ -171,7 +191,7 @@ private[spark] class ExternalSorter[K, V, C](
    *
    * @param usingMap whether we're using a map or buffer as our current in-memory collection
    */
-  def spill(memorySize: Long, usingMap: Boolean): Unit = {
+  private def spill(memorySize: Long, usingMap: Boolean): Unit = {
     val collection: SizeTrackingCollection[((Int, K), C)] = if (usingMap) map else buffer
     val memorySize = collection.estimateSize()
 
@@ -199,7 +219,7 @@ private[spark] class ExternalSorter[K, V, C](
     }
 
     try {
-      val it = collection.destructiveSortedIterator(comparator)
+      val it = collection.destructiveSortedIterator(partitionKeyComparator)
       while (it.hasNext) {
         val elem = it.next()
         val partitionId = elem._1._1
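
Illustration (not part of the patch): a REPL-sized sketch of the order that partitionKeyComparator produces when a key ordering is supplied. Records are grouped by partition ID first and sorted by key within each partition, which is what lets spilled runs be merged partition by partition later; plain sortBy stands in for the comparator here, and the sample records are made up.

val records = Seq(((1, "b"), 1), ((0, "z"), 1), ((1, "a"), 1), ((0, "c"), 1))
// Partition ID first, then key within the partition, mirroring partitionKeyComparator
val sorted = records.sortBy { case ((partition, key), _) => (partition, key) }
// sorted == List(((0, "c"), 1), ((0, "z"), 1), ((1, "a"), 1), ((1, "b"), 1))
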
@@ -240,23 +260,118 @@ private[spark] class ExternalSorter[K, V, C](
    * Merge a sequence of sorted files, giving an iterator over partitions and then over elements
    * inside each partition. This can be used to either write out a new file or return data to
    * the user.
+   *
+   * Returns an iterator over all the data written to this object, grouped by partition. For each
+   * partition we then have an iterator over its contents, and these are expected to be accessed
+   * in order (you can't "skip ahead" to one partition without reading the previous one).
+   * Guaranteed to return a key-value pair for each partition, in order of partition ID.
    */
-  def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
+  private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
       : Iterator[(Int, Iterator[Product2[K, C]])] = {
     // TODO: merge intermediate results if they are sorted by the comparator
     val readers = spills.map(new SpillReader(_))
     val inMemBuffered = inMemory.buffered
     (0 until numPartitions).iterator.map { p =>
-      val inMemIterator = new Iterator[(K, C)] {
+      val inMemIterator = new Iterator[Product2[K, C]] {
         override def hasNext: Boolean = {
           inMemBuffered.hasNext && inMemBuffered.head._1._1 == p
         }
-        override def next(): (K, C) = {
+        override def next(): Product2[K, C] = {
           val elem = inMemBuffered.next()
           (elem._1._2, elem._2)
         }
       }
-      (p, readers.iterator.flatMap(_.readNextPartition()) ++ inMemIterator)
+      val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
+      if (aggregator.isDefined && ordering.isDefined) {
+        (p, mergeWithAggregation(
+          iterators, aggregator.get.mergeCombiners, ordering.get, totalOrder = true))
+      } else if (ordering.isDefined) {
+        (p, mergeSort(iterators, ordering.get))
+      } else {
+        (p, iterators.iterator.flatten)
+      }
+    }
+  }
+
+  /**
+   * Merge-sort a bunch of (K, C) iterators using a given ordering on the keys. Assumed to
+   * be a total ordering.
+   */
+  private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
+      : Iterator[Product2[K, C]] =
+  {
+    val bufferedIters = iterators.map(_.buffered)
+    type Iter = BufferedIterator[Product2[K, C]]
+    val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
+      override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
+    })
+    heap.enqueue(bufferedIters: _*)
+    new Iterator[Product2[K, C]] {
+      override def hasNext: Boolean = !heap.isEmpty
+
+      override def next(): Product2[K, C] = {
+        if (!hasNext) {
+          throw new NoSuchElementException
+        }
+        val firstBuf = heap.dequeue()
+        val firstPair = firstBuf.next()
+        if (firstBuf.hasNext) {
+          heap.enqueue(firstBuf)
+        }
+        firstPair
+      }
+    }
+  }
+
+  /**
+   * Merge a bunch of (K, C) iterators by aggregating values for the same key, assuming that each
+   * iterator is sorted by key using our comparator. If the comparator is not a total ordering
+   * (e.g. when we sort objects by hash code and different keys may compare as equal although
+   * they're not), we still merge them by doing equality tests for all keys that compare as equal.
+   */
+  private def mergeWithAggregation(
+      iterators: Seq[Iterator[Product2[K, C]]],
+      mergeCombiners: (C, C) => C,
+      comparator: Comparator[K],
+      totalOrder: Boolean)
+      : Iterator[Product2[K, C]] =
+  {
+    require(totalOrder, "non-total order not yet supported")
+    val bufferedIters = iterators.map(_.buffered)
+    type Iter = BufferedIterator[Product2[K, C]]
+    val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
+      override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
+    })
+    heap.enqueue(bufferedIters: _*)
+    new Iterator[Product2[K, C]] {
+      override def hasNext: Boolean = !heap.isEmpty
+
+      override def next(): Product2[K, C] = {
+        if (!hasNext) {
+          throw new NoSuchElementException
+        }
+        val firstBuf = heap.dequeue()
+        val firstPair = firstBuf.next()
+        val k = firstPair._1
+        var c = firstPair._2
+        if (firstBuf.hasNext) {
+          heap.enqueue(firstBuf)
+        }
+        var shouldStop = false
+        while (!heap.isEmpty && !shouldStop) {
+          shouldStop = true
+          val newBuf = heap.dequeue()
+          while (newBuf.hasNext && newBuf.head._1 == k) {
+            val elem = newBuf.next()
+            c = mergeCombiners(c, elem._2)
+            shouldStop = false
+          }
+          if (newBuf.hasNext) {
+            heap.enqueue(newBuf)
+          }
+        }
+        (k, c)
+      }
     }
   }
 
@@ -264,7 +379,7 @@ private[spark] class ExternalSorter[K, V, C](
    * An internal class for reading a spilled file partition by partition. Expects all the
    * partitions to be requested in order.
    */
-  private class SpillReader(spill: SpilledFile) {
+  private[this] class SpillReader(spill: SpilledFile) {
     val fileStream = new FileInputStream(spill.file)
     val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
 
@@ -371,7 +486,7 @@ private[spark] class ExternalSorter[K, V, C](
   def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
     val usingMap = aggregator.isDefined
     val collection: SizeTrackingCollection[((Int, K), C)] = if (usingMap) map else buffer
-    merge(spills, collection.destructiveSortedIterator(comparator))
+    merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
   }
 
   /**
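
Illustration (not part of the patch): a self-contained sketch of the heap-based k-way merge that mergeSort and mergeWithAggregation above implement, written against plain Scala collections. It assumes a total ordering on keys, and it filters out empty inputs before they reach the priority queue because the queue's ordering reads each iterator's head; the object name and the sum combiner in main are made up for the example.

import java.util.Comparator
import scala.collection.mutable

object MergeSketch {
  // Merge key-sorted iterators, combining values that share a key (cf. mergeWithAggregation)
  def mergeWithAggregation[K, C](
      iterators: Seq[Iterator[(K, C)]],
      mergeCombiners: (C, C) => C,
      comparator: Comparator[K]): Iterator[(K, C)] = {
    type Iter = BufferedIterator[(K, C)]
    val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
      // Negated so the iterator with the smallest head key is dequeued first
      override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
    })
    // Only non-empty iterators may enter the heap, since the ordering above reads head
    heap.enqueue(iterators.filter(_.hasNext).map(_.buffered): _*)
    new Iterator[(K, C)] {
      override def hasNext: Boolean = heap.nonEmpty
      override def next(): (K, C) = {
        val first = heap.dequeue()
        val (k, v) = first.next()
        var combined = v
        if (first.hasNext) heap.enqueue(first)
        // Drain every iterator whose head key equals k, combining values as we go
        while (heap.nonEmpty && comparator.compare(heap.head.head._1, k) == 0) {
          val buf = heap.dequeue()
          while (buf.hasNext && comparator.compare(buf.head._1, k) == 0) {
            combined = mergeCombiners(combined, buf.next()._2)
          }
          if (buf.hasNext) heap.enqueue(buf)
        }
        (k, combined)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val left = Iterator((1, 1), (3, 1), (5, 1))
    val right = Iterator((1, 10), (2, 10), (5, 10))
    val merged = mergeWithAggregation(Seq(left, right), (a: Int, b: Int) => a + b, Ordering.Int)
    println(merged.toList) // List((1,11), (2,10), (3,1), (5,11))
  }
}
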
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index e2ee62b2b54a8..7de5df6e1c8bd 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -208,11 +208,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
     val resultA = rddA.reduceByKey(math.max).collect()
     assert(resultA.length == 50000)
     resultA.foreach { case(k, v) =>
-      k match {
-        case 0 => assert(v == 1)
-        case 25000 => assert(v == 50001)
-        case 49999 => assert(v == 99999)
-        case _ =>
+      if (v != k * 2 + 1) {
+        fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
       }
     }
 
@@ -221,11 +218,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
     val resultB = rddB.groupByKey().collect()
     assert(resultB.length == 25000)
     resultB.foreach { case(i, seq) =>
-      i match {
-        case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
-        case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003))
-        case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999))
-        case _ =>
+      val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+      if (seq.toSet != expected) {
+        fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
       }
     }
 
@@ -239,6 +234,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
       case 0 =>
         assert(seq1.toSet == Set[Int](0))
         assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+      case 1 =>
+        assert(seq1.toSet == Set[Int](1))
+        assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
       case 5000 =>
         assert(seq1.toSet == Set[Int](5000))
         assert(seq2.toSet == Set[Int]())
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index c95aa9c125825..cd86b7b7263c9 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -35,11 +35,8 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val resultA = rddA.reduceByKey(math.max).collect()
     assert(resultA.length == 50000)
     resultA.foreach { case(k, v) =>
-      k match {
-        case 0 => assert(v == 1)
-        case 25000 => assert(v == 50001)
-        case 49999 => assert(v == 99999)
-        case _ =>
+      if (v != k * 2 + 1) {
+        fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
       }
     }
 
@@ -48,11 +45,9 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val resultB = rddB.groupByKey().collect()
     assert(resultB.length == 25000)
     resultB.foreach { case(i, seq) =>
-      i match {
-        case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
-        case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003))
-        case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999))
-        case _ =>
+      val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+      if (seq.toSet != expected) {
+        fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
       }
     }
 
@@ -66,6 +61,9 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
       case 0 =>
         assert(seq1.toSet == Set[Int](0))
         assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+      case 1 =>
+        assert(seq1.toSet == Set[Int](1))
+        assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
       case 5000 =>
         assert(seq1.toSet == Set[Int](5000))
         assert(seq2.toSet == Set[Int]())
@@ -75,41 +73,51 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
         case _ =>
       }
     }
+
+    // larger cogroup - should spill ~7 times
+    val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+    val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+    val resultD = rddD1.cogroup(rddD2).collect()
+    assert(resultD.length == 5000)
+    resultD.foreach { case(i, (seq1, seq2)) =>
+      val expected = Set(i * 2, i * 2 + 1)
+      if (seq1.toSet != expected) {
+        fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}")
+      }
+      if (seq2.toSet != expected) {
+        fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
+      }
+    }
   }
 
   test("spilling in local cluster with many reduce tasks") {
     val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+    sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
 
-    // reduceByKey - should spill ~8 times
+    // reduceByKey - should spill ~4 times per executor
     val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
     val resultA = rddA.reduceByKey(math.max _, 100).collect()
     assert(resultA.length == 50000)
     resultA.foreach { case(k, v) =>
-      k match {
-        case 0 => assert(v == 1)
-        case 25000 => assert(v == 50001)
-        case 49999 => assert(v == 99999)
-        case _ =>
+      if (v != k * 2 + 1) {
+        fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
       }
     }
 
-    // groupByKey - should spill ~17 times
+    // groupByKey - should spill ~8 times per executor
     val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
     val resultB = rddB.groupByKey(100).collect()
     assert(resultB.length == 25000)
     resultB.foreach { case(i, seq) =>
-      i match {
-        case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
-        case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003))
-        case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999))
-        case _ =>
+      val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+      if (seq.toSet != expected) {
+        fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
       }
     }
 
-    // cogroup - should spill ~7 times
+    // cogroup - should spill ~4 times per executor
     val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
     val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
     val resultC = rddC1.cogroup(rddC2, 100).collect()
@@ -119,6 +127,9 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
       case 0 =>
         assert(seq1.toSet == Set[Int](0))
         assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+      case 1 =>
+        assert(seq1.toSet == Set[Int](1))
+        assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
       case 5000 =>
         assert(seq1.toSet == Set[Int](5000))
         assert(seq2.toSet == Set[Int]())
@@ -128,6 +139,21 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
         case _ =>
       }
     }
+
+    // larger cogroup - should spill ~4 times per executor
+    val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+    val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+    val resultD = rddD1.cogroup(rddD2).collect()
+    assert(resultD.length == 5000)
+    resultD.foreach { case(i, (seq1, seq2)) =>
+      val expected = Set(i * 2, i * 2 + 1)
+      if (seq1.toSet != expected) {
+        fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}")
+      }
+      if (seq2.toSet != expected) {
+        fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
+      }
+    }
   }
 
   test("cleanup of intermediate files in sorter") {
@@ -173,7 +199,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("cleanup of intermediate files in shuffle") {
-    val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    val conf = new SparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -188,7 +214,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("cleanup of intermediate files in shuffle with errors") {
-    val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    val conf = new SparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -208,4 +234,86 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     // All other files (map 2's output and intermediate merge files) should've been deleted.
     assert(diskBlockManager.getAllFiles().length === 2)
   }
+
+  test("no partial aggregation") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+    sorter.write((0 until 100000).iterator.map(i => (i / 4, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet)
+    }).toSet
+    assert(results === expected)
+  }
+
+  test("partial aggregation without spill") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
+    sorter.write((0 until 100).iterator.map(i => (i / 2, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+    }).toSet
+    assert(results === expected)
+  }
+
+  test("partial aggregation with spill") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+    val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+    }).toSet
+    assert(results === expected)
+  }
+
+  test("sorting without aggregation, no spill") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val ord = implicitly[Ordering[Int]]
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    sorter.write((0 until 100).iterator.map(i => (i, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
+    }).toSeq
+    assert(results === expected)
+  }
+
+  test("sorting without aggregation, with spill") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val ord = implicitly[Ordering[Int]]
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    sorter.write((0 until 100000).iterator.map(i => (i, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
+    }).toSeq
+    assert(results === expected)
+  }
 }
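
Illustration (not part of the patch): the closed forms asserted by the rewritten tests, checked directly on local collections (runnable in a Scala REPL). For pairs (i/2, i), key k sees exactly 2k and 2k+1, so math.max yields 2k+1 and a sum combiner yields 4k+1; for pairs (i/4, i), key k collects {4k, 4k+1, 4k+2, 4k+3}.

val byHalf = (0 until 100000).map(i => (i / 2, i)).groupBy(_._1)
assert(byHalf.forall { case (k, vs) => vs.map(_._2).max == k * 2 + 1 })  // reduceByKey(math.max)
assert(byHalf.forall { case (k, vs) => vs.map(_._2).sum == k * 4 + 1 })  // sum aggregation
val byQuarter = (0 until 100000).map(i => (i / 4, i)).groupBy(_._1)
assert(byQuarter.forall { case (k, vs) =>
  vs.map(_._2).toSet == Set(k * 4, k * 4 + 1, k * 4 + 2, k * 4 + 3)      // groupByKey
})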