Support for partial aggregation even without an Ordering
Groups together keys by hash code if there's no Ordering on them, similar to ExternalAppendOnlyMap.
Commit ef4e397 (parent 4b7a5ce), committed by mateiz on Jul 30, 2014
Showing 3 changed files with 131 additions and 47 deletions.
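The core idea is easiest to see in isolation. Below is a minimal standalone sketch (not part of the commit) showing why ordering keys by hash code is only a partial ordering: distinct keys can collide, so keys that compare as equal still need a real equality check before they can be combined.

object HashOrderSketch {
  def main(args: Array[String]): Unit = {
    val hashComparator = new java.util.Comparator[String] {
      override def compare(a: String, b: String): Int = a.hashCode() - b.hashCode()
    }
    // "Aa" and "BB" are a classic JVM hash collision: both strings hash to 2112.
    println(hashComparator.compare("Aa", "BB")) // 0 -- the comparator sees them as equal
    println("Aa" == "BB")                       // false -- they are not truly equal
  }
}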
@@ -60,8 +60,11 @@ private[spark] class SortShuffleWriter[K, V, C](
       sorter.write(records)
       sorter.partitionedIterator
     } else {
+      // In this case we pass neither an aggregator nor an ordering to the sorter, because we
+      // don't care whether the keys get sorted in each partition; that will be done on the
+      // reduce side if the operation being run is sortByKey.
       sorter = new ExternalSorter[K, V, V](
-        None, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+        None, Some(dep.partitioner), None, dep.serializer)
       sorter.write(records)
       sorter.partitionedIterator
     }
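For reference, the four positional arguments in the call above line up with ExternalSorter's Option-typed constructor parameters, as also exercised by the test suite further down. A hedged summary of this call site:

// ExternalSorter arguments as used at this call site (all four are Options):
//   aggregator  -- map-side combine functions; None here, so no combining
//   partitioner -- Some(dep.partitioner), to bucket records by reduce partition
//   ordering    -- None after this commit: the map side no longer sorts keys for sortByKey
//   serializer  -- dep.serializer, used when writing spill files
new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer)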
@@ -89,24 +89,33 @@ private[spark] class ExternalSorter[K, V, C](
     (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
   }
 
-  // A comparator for ((Int, K), C) elements that orders them by partition and then possibly
-  // by key if we want to sort data within each partition
+  // A comparator for keys K that orders them within a partition to allow partial aggregation.
+  // Can be a partial ordering by hash code if a total ordering is not provided by the
+  // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
+  // non-equal keys also have this, so we need to do a later pass to find truly equal keys).
+  // Note that we ignore this if no aggregator is given.
+  private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
+    override def compare(a: K, b: K): Int = a.hashCode() - b.hashCode()
+  })
+
+  private val sortWithinPartitions = ordering.isDefined || aggregator.isDefined
+
+  // A comparator for ((Int, K), C) elements that orders them by partition and then possibly by key
   private val partitionKeyComparator: Comparator[((Int, K), C)] = {
-    if (ordering.isDefined) {
-      // We want to sort the data by key
-      val ord = ordering.get
+    if (sortWithinPartitions) {
+      // Sort by partition ID then key comparator
       new Comparator[((Int, K), C)] {
         override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
           val partitionDiff = a._1._1 - b._1._1
           if (partitionDiff != 0) {
             partitionDiff
           } else {
-            ord.compare(a._1._2, b._1._2)
+            keyComparator.compare(a._1._2, b._1._2)
           }
         }
       }
     } else {
-      // Just sort it by partition
+      // Just sort it by partition ID
       new Comparator[((Int, K), C)] {
         override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
           a._1._1 - b._1._1
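To make the new comparator's behavior concrete, here is a small self-contained sketch (illustrative names, not commit code) that orders ((partition, key), value) records by partition ID first and key hash code second, the same scheme partitionKeyComparator uses when sortWithinPartitions is true:

import java.util.Comparator

object PartitionKeyOrderSketch {
  def main(args: Array[String]): Unit = {
    val cmp = new Comparator[((Int, String), Int)] {
      override def compare(a: ((Int, String), Int), b: ((Int, String), Int)): Int = {
        val partitionDiff = a._1._1 - b._1._1
        if (partitionDiff != 0) partitionDiff
        else a._1._2.hashCode() - b._1._2.hashCode() // only a partial order on keys
      }
    }
    val records = Array(((2, "b"), 1), ((0, "x"), 2), ((2, "a"), 3), ((0, "x"), 4))
    val sorted = records.sortWith((x, y) => cmp.compare(x, y) < 0)
    // Partition 0 records come first; within a partition, equal keys end up adjacent.
    sorted.foreach(println)
  }
}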
@@ -282,10 +291,13 @@ private[spark] class ExternalSorter[K, V, C](
         }
       }
       val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
-      if (aggregator.isDefined && ordering.isDefined) {
+      if (aggregator.isDefined) {
         // Perform partial aggregation across partitions
         (p, mergeWithAggregation(
-          iterators, aggregator.get.mergeCombiners, ordering.get, totalOrder = true))
+          iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
+      } else if (ordering.isDefined) {
+        // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
+        // sort the elements without trying to merge them
+        (p, mergeSort(iterators, ordering.get))
       } else {
         (p, iterators.iterator.flatten)
       }
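Restating the dispatch above (an illustrative summary, not commit code): with this change, aggregation no longer requires an ordering, and an ordering without an aggregator still gets a plain merge-sort:

def mergePath(hasAggregator: Boolean, hasOrdering: Boolean): String =
  (hasAggregator, hasOrdering) match {
    case (true, total)  => s"mergeWithAggregation(keyComparator, totalOrder = $total)"
    case (false, true)  => "mergeSort(ordering)"    // e.g. reduce tasks in sortByKey
    case (false, false) => "flatten the iterators"  // no sorting or combining needed
  }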
@@ -294,8 +306,7 @@
   }
 
   /**
-   * Merge-sort a bunch of (K, C) iterators using a given ordering on the keys. Assumed to
-   * be a total ordering.
+   * Merge-sort a sequence of (K, C) iterators using a given comparator for the keys.
    */
   private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
       : Iterator[Product2[K, C]] =
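For readers unfamiliar with the heap-based merge that mergeSort implements, here is a simplified self-contained version of the same technique. Assumptions: record type (K, Int) for brevity, and empty input iterators are filtered before being enqueued (a guard this demo needs but the real call sites may not):

import java.util.Comparator
import scala.collection.mutable

object MergeSortSketch {
  def mergeSort[K](iterators: Seq[Iterator[(K, Int)]], comparator: Comparator[K]): Iterator[(K, Int)] = {
    type Iter = BufferedIterator[(K, Int)]
    val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
      // Negate so the iterator with the smallest head key is dequeued first
      override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
    })
    heap.enqueue(iterators.map(_.buffered).filter(_.hasNext): _*)
    new Iterator[(K, Int)] {
      override def hasNext: Boolean = heap.nonEmpty
      override def next(): (K, Int) = {
        val it = heap.dequeue()
        val elem = it.next()
        if (it.hasNext) heap.enqueue(it) // re-enqueue so its new head competes again
        elem
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val cmp = new Comparator[Int] { override def compare(a: Int, b: Int): Int = a - b }
    val merged = mergeSort(Seq(Iterator((1, 10), (3, 30)), Iterator((2, 20), (4, 40))), cmp)
    println(merged.toList) // List((1,10), (2,20), (3,30), (4,40))
  }
}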
@@ -324,8 +335,8 @@
   }
 
   /**
-   * Merge a bunch of (K, C) iterators by aggregating values for the same key, assuming that each
-   * iterator is sorted by key using our comparator. If the comparator is not a total ordering
+   * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each
+   * iterator is sorted by key with a given comparator. If the comparator is not a total ordering
    * (e.g. when we sort objects by hash code and different keys may compare as equal although
    * they're not), we still merge them by doing equality tests for all keys that compare as equal.
    */
@@ -336,41 +347,95 @@
       totalOrder: Boolean)
       : Iterator[Product2[K, C]] =
   {
-    require(totalOrder, "non-total order not yet supported")
-    val bufferedIters = iterators.map(_.buffered)
-    type Iter = BufferedIterator[Product2[K, C]]
-    val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
-      override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
-    })
-    heap.enqueue(bufferedIters: _*)
-    new Iterator[Product2[K, C]] {
-      override def hasNext: Boolean = !heap.isEmpty
-
-      override def next(): Product2[K, C] = {
-        if (!hasNext) {
-          throw new NoSuchElementException
-        }
-        val firstBuf = heap.dequeue()
-        val firstPair = firstBuf.next()
-        val k = firstPair._1
-        var c = firstPair._2
-        if (firstBuf.hasNext) {
-          heap.enqueue(firstBuf)
-        }
-        var shouldStop = false
-        while (!heap.isEmpty && !shouldStop) {
-          shouldStop = true
-          val newBuf = heap.dequeue()
-          while (newBuf.hasNext && newBuf.head._1 == k) {
-            val elem = newBuf.next()
-            c = mergeCombiners(c, elem._2)
-            shouldStop = false
-          }
-          if (newBuf.hasNext) {
-            heap.enqueue(newBuf)
-          }
-        }
-        (k, c)
-      }
-    }
+    if (!totalOrder) {
+      // We only have a partial ordering, e.g. comparing the keys by hash code, which means that
+      // multiple distinct keys might be treated as equal by the ordering. To deal with this, we
+      // need to buffer every set of keys considered equal by the comparator in memory, then do
+      // another pass through them to find the truly equal ones.
+      val sorted = mergeSort(iterators, comparator).buffered
+      // Buffers reused across keys to decrease memory allocation
+      val buf = new ArrayBuffer[(K, C)]
+      val toReturn = new ArrayBuffer[(K, C)]
+      new Iterator[Iterator[Product2[K, C]]] {
+        override def hasNext: Boolean = sorted.hasNext
+
+        override def next(): Iterator[Product2[K, C]] = {
+          if (!hasNext) {
+            throw new NoSuchElementException
+          }
+          val firstPair = sorted.next()
+          buf += ((firstPair._1, firstPair._2)) // Copy it in case the Product2 object is reused
+          val key = firstPair._1
+          while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
+            val n = sorted.next()
+            buf += ((n._1, n._2))
+          }
+          // buf now contains all the elements with keys equal to our first one according to the
+          // partial ordering. Now we need to find which keys were "really" equal, which we do
+          // through linear scans through the buffer.
+          toReturn.clear()
+          while (!buf.isEmpty) {
+            val last = buf(buf.size - 1)
+            buf.reduceToSize(buf.size - 1)
+            val k = last._1
+            var c = last._2
+            var i = 0
+            while (i < buf.size) {
+              while (i < buf.size && buf(i)._1 == k) {
+                c = mergeCombiners(c, buf(i)._2)
+                // Replace this element with the last one in the buffer
+                buf(i) = buf(buf.size - 1)
+                buf.reduceToSize(buf.size - 1)
+              }
+              i += 1
+            }
+            toReturn += ((k, c))
+          }
+          // Note that we return a *sequence* of elements since we could've had many keys marked
+          // equal by the partial order; we flatten this below to get a flat iterator of (K, C).
+          toReturn.iterator
+        }
+      }.flatMap(i => i)
+    } else {
+      // We have a total ordering. This means we can merge objects one by one as we read them
+      // from the iterators, without buffering all the ones that are "equal" to a given key.
+      // We do so with code similar to mergeSort, except our Iterator.next combines together all
+      // the elements with the given key.
+      val bufferedIters = iterators.map(_.buffered)
+      type Iter = BufferedIterator[Product2[K, C]]
+      val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
+        override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
+      })
+      heap.enqueue(bufferedIters: _*)
+      new Iterator[Product2[K, C]] {
+        override def hasNext: Boolean = !heap.isEmpty
+
+        override def next(): Product2[K, C] = {
+          if (!hasNext) {
+            throw new NoSuchElementException
+          }
+          val firstBuf = heap.dequeue()
+          val firstPair = firstBuf.next()
+          val k = firstPair._1
+          var c = firstPair._2
+          if (firstBuf.hasNext) {
+            heap.enqueue(firstBuf)
+          }
+          var shouldStop = false
+          while (!heap.isEmpty && !shouldStop) {
+            shouldStop = true // Stop unless we find another element with the same key
+            val newBuf = heap.dequeue()
+            while (newBuf.hasNext && newBuf.head._1 == k) {
+              val elem = newBuf.next()
+              c = mergeCombiners(c, elem._2)
+              shouldStop = false
+            }
+            if (newBuf.hasNext) {
+              heap.enqueue(newBuf)
+            }
+          }
+          (k, c)
+        }
+      }
+    }
   }
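The partial-ordering branch above is the heart of the commit. Here is a compact standalone walk-through (demo code, not the commit's) of its second pass: given one run of records whose keys all compare equal under the hash comparator, scan the buffer and combine only the truly equal keys. It uses plain removal where the real code swap-removes with reduceToSize, but the effect is the same:

import scala.collection.mutable.ArrayBuffer

object PartialOrderMergeSketch {
  def main(args: Array[String]): Unit = {
    // "Aa" and "BB" collide on hashCode, so a hash comparator buffers them together.
    val bufferedRun = ArrayBuffer(("Aa", 1), ("BB", 2), ("Aa", 3))
    val combined = ArrayBuffer[(String, Int)]()
    while (bufferedRun.nonEmpty) {
      val (k, v) = bufferedRun.remove(bufferedRun.size - 1)
      var c = v
      var i = 0
      while (i < bufferedRun.size) {
        if (bufferedRun(i)._1 == k) {
          c += bufferedRun(i)._2      // combine a truly equal key
          bufferedRun.remove(i)
        } else {
          i += 1                      // hash-equal but distinct; leave it for a later round
        }
      }
      combined += ((k, c))
    }
    println(combined) // ArrayBuffer((Aa,4), (BB,2)) -- order among distinct keys may vary
  }
}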
@@ -235,7 +235,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     assert(diskBlockManager.getAllFiles().length === 2)
   }
 
-  test("no partial aggregation") {
+  test("no partial aggregation or sorting") {
     val conf = new SparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
@@ -266,7 +266,23 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     assert(results === expected)
   }
 
-  test("partial aggregation with spill") {
+  test("partial aggregation with spill, no ordering") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
+    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+    }).toSet
+    assert(results === expected)
+  }
+
+  test("partial aggregation with spill, with ordering") {
     val conf = new SparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")