From ef4e397fe494ac6bbfcbddb16ce5c5d8fffb75ce Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Sun, 20 Jul 2014 17:31:53 -0700
Subject: [PATCH] Support for partial aggregation even without an Ordering

Groups together keys by hash code if there's no Ordering on them, similar
to ExternalAppendOnlyMap
---
 .../shuffle/sort/SortShuffleWriter.scala      |   5 +-
 .../util/collection/ExternalSorter.scala      | 153 +++++++++++++-----
 .../util/collection/ExternalSorterSuite.scala |  20 ++-
 3 files changed, 131 insertions(+), 47 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index b418323814b84..bc051a52d0839 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -60,8 +60,11 @@ private[spark] class SortShuffleWriter[K, V, C](
       sorter.write(records)
       sorter.partitionedIterator
     } else {
+      // In this case we pass neither an aggregator nor an ordering to the sorter, because we
+      // don't care whether the keys get sorted in each partition; that will be done on the
+      // reduce side if the operation being run is sortByKey.
       sorter = new ExternalSorter[K, V, V](
-        None, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+        None, Some(dep.partitioner), None, dep.serializer)
       sorter.write(records)
       sorter.partitionedIterator
     }
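Why it is safe to pass None for the ordering above: the map side only needs to bucket
records by partition; if the job is a sortByKey, its reduce tasks sort each partition's
keys themselves. A toy, self-contained sketch of that division of labor in plain Scala
collections (not Spark APIs; the records and the modulo partitioner are illustrative
stand-ins):

object PartitionOnlySketch {
  def main(args: Array[String]): Unit = {
    val records = Seq("banana" -> 1, "apple" -> 2, "cherry" -> 3, "apple" -> 4)
    val numPartitions = 2

    // "Map side": group by partition ID only; keys stay unsorted within each bucket.
    val byPartition =
      records.groupBy { case (k, _) => (k.hashCode & Int.MaxValue) % numPartitions }

    // "Reduce side": each partition sorts its own keys, as sortByKey's reduce tasks do.
    for ((p, recs) <- byPartition.toSeq.sortBy(_._1)) {
      println(s"partition $p: ${recs.sortBy(_._1)}")
    }
  }
}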
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index fa2ab6cea8d32..78f71717984ae 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -89,24 +89,33 @@ private[spark] class ExternalSorter[K, V, C](
     (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
   }

-  // A comparator for ((Int, K), C) elements that orders them by partition and then possibly
-  // by key if we want to sort data within each partition
+  // A comparator for keys K that orders them within a partition to allow partial aggregation.
+  // Can be a partial ordering by hash code if a total ordering is not provided by the user.
+  // (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
+  // non-equal keys also have this, so we need to do a later pass to find truly equal keys.)
+  // Note that we ignore this if no aggregator is given.
+  private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
+    override def compare(a: K, b: K): Int = a.hashCode().compareTo(b.hashCode())
+  })
+
+  private val sortWithinPartitions = ordering.isDefined || aggregator.isDefined
+
+  // A comparator for ((Int, K), C) elements that orders them by partition and then possibly by key
   private val partitionKeyComparator: Comparator[((Int, K), C)] = {
-    if (ordering.isDefined) {
-      // We want to sort the data by key
-      val ord = ordering.get
+    if (sortWithinPartitions) {
+      // Sort by partition ID then key comparator
       new Comparator[((Int, K), C)] {
         override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
           val partitionDiff = a._1._1 - b._1._1
           if (partitionDiff != 0) {
             partitionDiff
           } else {
-            ord.compare(a._1._2, b._1._2)
+            keyComparator.compare(a._1._2, b._1._2)
           }
         }
       }
     } else {
-      // Just sort it by partition
+      // Just sort it by partition ID
       new Comparator[((Int, K), C)] {
         override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
           a._1._1 - b._1._1
@@ -282,10 +291,13 @@ private[spark] class ExternalSorter[K, V, C](
        }
      }
      val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
-     if (aggregator.isDefined && ordering.isDefined) {
+     if (aggregator.isDefined) {
+       // Perform partial aggregation across partitions
        (p, mergeWithAggregation(
-         iterators, aggregator.get.mergeCombiners, ordering.get, totalOrder = true))
+         iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
      } else if (ordering.isDefined) {
+       // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
+       // sort the elements without trying to merge them
        (p, mergeSort(iterators, ordering.get))
      } else {
        (p, iterators.iterator.flatten)
@@ -294,8 +306,7 @@ private[spark] class ExternalSorter[K, V, C](
   }

   /**
-   * Merge-sort a bunch of (K, C) iterators using a given ordering on the keys. Assumed to
-   * be a total ordering.
+   * Merge-sort a sequence of (K, C) iterators using a given comparator for the keys.
    */
   private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
       : Iterator[Product2[K, C]] =
@@ -324,8 +335,8 @@ private[spark] class ExternalSorter[K, V, C](
   }

   /**
-   * Merge a bunch of (K, C) iterators by aggregating values for the same key, assuming that each
-   * iterator is sorted by key using our comparator. If the comparator is not a total ordering
+   * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each
+   * iterator is sorted by key with a given comparator. If the comparator is not a total ordering
    * (e.g. when we sort objects by hash code and different keys may compare as equal although
    * they're not), we still merge them by doing equality tests for all keys that compare as equal.
    */
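The caveat in the comment above ("different keys may compare as equal although they're
not") is easy to reproduce on the JVM: the distinct strings "Aa" and "BB" both hash to
2112. A minimal, self-contained demonstration (not part of the patch) of why the merge
must re-test real key equality inside each run of comparator-equal elements. Note also
that comparing via compareTo on the hash codes, as keyComparator above does, avoids the
Int overflow that a raw h1 - h2 subtraction could hit for hash codes of opposite sign:

object PartialOrderingSketch {
  def main(args: Array[String]): Unit = {
    // Same shape as keyComparator above, specialized to String.
    val byHash = new java.util.Comparator[String] {
      override def compare(a: String, b: String): Int = a.hashCode.compareTo(b.hashCode)
    }
    println("Aa".hashCode == "BB".hashCode)  // true: both are 2112
    println(byHash.compare("Aa", "BB"))      // 0, yet "Aa" != "BB"
    // So compare(...) == 0 only narrows down candidates; actual equality (==) still
    // has to be checked before merging combiners for a key.
  }
}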
@@ -336,41 +347,95 @@ private[spark] class ExternalSorter[K, V, C](
       totalOrder: Boolean)
       : Iterator[Product2[K, C]] =
   {
-    require(totalOrder, "non-total order not yet supported")
-    val bufferedIters = iterators.map(_.buffered)
-    type Iter = BufferedIterator[Product2[K, C]]
-    val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
-      override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
-    })
-    heap.enqueue(bufferedIters: _*)
-    new Iterator[Product2[K, C]] {
-      override def hasNext: Boolean = !heap.isEmpty
-
-      override def next(): Product2[K, C] = {
-        if (!hasNext) {
-          throw new NoSuchElementException
-        }
-        val firstBuf = heap.dequeue()
-        val firstPair = firstBuf.next()
-        val k = firstPair._1
-        var c = firstPair._2
-        if (firstBuf.hasNext) {
-          heap.enqueue(firstBuf)
+    if (!totalOrder) {
+      // We only have a partial ordering, e.g. comparing the keys by hash code, which means that
+      // multiple distinct keys might be treated as equal by the ordering. To deal with this, we
+      // need to buffer every set of keys considered equal by the comparator in memory, then do
+      // another pass through them to find the truly equal ones.
+      val sorted = mergeSort(iterators, comparator).buffered
+      // Buffers reused across keys to decrease memory allocation
+      val buf = new ArrayBuffer[(K, C)]
+      val toReturn = new ArrayBuffer[(K, C)]
+      new Iterator[Iterator[Product2[K, C]]] {
+        override def hasNext: Boolean = sorted.hasNext
+
+        override def next(): Iterator[Product2[K, C]] = {
+          if (!hasNext) {
+            throw new NoSuchElementException
+          }
+          val firstPair = sorted.next()
+          buf += ((firstPair._1, firstPair._2))  // Copy it in case the Product2 object is reused
+          val key = firstPair._1
+          while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
+            val n = sorted.next()
+            buf += ((n._1, n._2))
+          }
+          // buf now contains all the elements with keys equal to our first one according to the
+          // partial ordering. Now we need to find which keys were "really" equal, which we do
+          // through linear scans through the buffer.
+          toReturn.clear()
+          while (!buf.isEmpty) {
+            val last = buf(buf.size - 1)
+            buf.reduceToSize(buf.size - 1)
+            val k = last._1
+            var c = last._2
+            var i = 0
+            while (i < buf.size) {
+              while (i < buf.size && buf(i)._1 == k) {
+                c = mergeCombiners(c, buf(i)._2)
+                // Replace this element with the last one in the buffer
+                buf(i) = buf(buf.size - 1)
+                buf.reduceToSize(buf.size - 1)
+              }
+              i += 1
+            }
+            toReturn += ((k, c))
+          }
+          // Note that we return a *sequence* of elements since we could've had many keys marked
+          // equal by the partial order; we flatten this below to get a flat iterator of (K, C).
+          toReturn.iterator
         }
-        var shouldStop = false
-        while (!heap.isEmpty && !shouldStop) {
-          shouldStop = true
-          val newBuf = heap.dequeue()
-          while (newBuf.hasNext && newBuf.head._1 == k) {
-            val elem = newBuf.next()
-            c = mergeCombiners(c, elem._2)
-            shouldStop = false
+      }.flatMap(i => i)
+    } else {
+      // We have a total ordering. This means we can merge objects one by one as we read them
+      // from the iterators, without buffering all the ones that are "equal" to a given key.
+      // We do so with code similar to mergeSort, except our Iterator.next combines together all
+      // the elements with the given key.
+      val bufferedIters = iterators.map(_.buffered)
+      type Iter = BufferedIterator[Product2[K, C]]
+      val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
+        override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
+      })
+      heap.enqueue(bufferedIters: _*)
+      new Iterator[Product2[K, C]] {
+        override def hasNext: Boolean = !heap.isEmpty
+
+        override def next(): Product2[K, C] = {
+          if (!hasNext) {
+            throw new NoSuchElementException
+          }
+          val firstBuf = heap.dequeue()
+          val firstPair = firstBuf.next()
+          val k = firstPair._1
+          var c = firstPair._2
+          if (firstBuf.hasNext) {
+            heap.enqueue(firstBuf)
           }
-          if (newBuf.hasNext) {
-            heap.enqueue(newBuf)
+          var shouldStop = false
+          while (!heap.isEmpty && !shouldStop) {
+            shouldStop = true  // Stop unless we find another element with the same key
+            val newBuf = heap.dequeue()
+            while (newBuf.hasNext && newBuf.head._1 == k) {
+              val elem = newBuf.next()
+              c = mergeCombiners(c, elem._2)
+              shouldStop = false
+            }
+            if (newBuf.hasNext) {
+              heap.enqueue(newBuf)
+            }
           }
+          (k, c)
         }
-        (k, c)
       }
     }
   }
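The !totalOrder branch above is the heart of this patch and can be read in isolation:
first gather a run of elements whose keys compare equal under the partial (hash-code)
ordering, then use linear scans with real equality to fold together the truly equal
keys. A standalone sketch of the same two-pass idea (mergeRuns and every name in it
are hypothetical, not Spark APIs; the input is assumed to be pre-sorted by key hash
code, as mergeSort guarantees above):

import scala.collection.mutable.ArrayBuffer

object GroupByHashMergeSketch {
  def mergeRuns[K, C](sorted: Seq[(K, C)], mergeCombiners: (C, C) => C): Seq[(K, C)] = {
    val out = ArrayBuffer[(K, C)]()
    var i = 0
    while (i < sorted.length) {
      // Pass 1: collect one run of pairs whose keys share a hash code.
      val h = sorted(i)._1.hashCode
      val buf = ArrayBuffer[(K, C)]()
      while (i < sorted.length && sorted(i)._1.hashCode == h) {
        buf += sorted(i)
        i += 1
      }
      // Pass 2: repeatedly take one key and fold in every truly equal key.
      while (buf.nonEmpty) {
        val (k, c0) = buf.remove(buf.length - 1)
        var c = c0
        var j = 0
        while (j < buf.length) {
          if (buf(j)._1 == k) {
            c = mergeCombiners(c, buf(j)._2)
            buf.remove(j)  // Don't advance j; the next element shifted into slot j.
          } else {
            j += 1
          }
        }
        out += ((k, c))
      }
    }
    out.toSeq
  }

  def main(args: Array[String]): Unit = {
    // "Aa" and "BB" collide on hashCode, so they land in one run but merge separately.
    val data = Seq("Aa" -> 1, "BB" -> 10, "Aa" -> 2).sortBy(_._1.hashCode)
    println(mergeRuns[String, Int](data, _ + _))  // contains (Aa,3) and (BB,10)
  }
}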
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index cd86b7b7263c9..7a11b75126d6e 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -235,7 +235,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     assert(diskBlockManager.getAllFiles().length === 2)
   }

-  test("no partial aggregation") {
+  test("no partial aggregation or sorting") {
     val conf = new SparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
@@ -266,7 +266,23 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     assert(results === expected)
   }

-  test("partial aggregation with spill") {
+  test("partial aggregation with spill, no ordering") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
+    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    val results = sorter.partitionedIterator.map { case (p, vs) => (p, vs.toSet) }.toSet
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+    }).toSet
+    assert(results === expected)
+  }
+
+  test("partial aggregation with spill, with ordering") {
     val conf = new SparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
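A quick sanity check on the expected values in the new "no ordering" test, as plain
Scala independent of Spark (a hypothetical sketch, not part of the patch): writing
(i / 2, i) for i in 0 until 100000 sends exactly the values 2k and 2k + 1 to each key k,
and the sum aggregator combines them into 4k + 1, which is the i * 4 + 1 in the
expectation above.

object ExpectedValuesSketch {
  def main(args: Array[String]): Unit = {
    // Mirror the test input, then sum values per key as the Aggregator does.
    val merged = (0 until 100000).map(i => (i / 2, i))
      .groupBy(_._1)
      .map { case (k, pairs) => (k, pairs.map(_._2).sum) }
    // Each key k holds 2k + (2k + 1) = 4k + 1.
    assert((0 until 50000).forall(k => merged(k) == 4 * k + 1))
    println("all 50000 keys match 4k + 1")
  }
}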