More tests, and ability to sort data if a total ordering is given
mateiz committed Jul 30, 2014
1 parent e1f84be · commit 4b7a5ce
Showing 3 changed files with 270 additions and 49 deletions.
@@ -21,6 +21,7 @@ import java.io._
import java.util.Comparator

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable

import com.google.common.io.ByteStreams

@@ -88,17 +89,36 @@ private[spark] class ExternalSorter[K, V, C](
(Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}

// For now, just compare them by partition; later we can compare by key as well
private val comparator = new Comparator[((Int, K), C)] {
override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
a._1._1 - b._1._1
// A comparator for ((Int, K), C) elements that orders them by partition and then possibly
// by key if we want to sort data within each partition
private val partitionKeyComparator: Comparator[((Int, K), C)] = {
if (ordering.isDefined) {
// We want to sort the data by key
val ord = ordering.get
new Comparator[((Int, K), C)] {
override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
val partitionDiff = a._1._1 - b._1._1
if (partitionDiff != 0) {
partitionDiff
} else {
ord.compare(a._1._2, b._1._2)
}
}
}
} else {
// Just sort it by partition
new Comparator[((Int, K), C)] {
override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
a._1._1 - b._1._1
}
}
}
}
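
For readers skimming the diff: the new partitionKeyComparator simply composes two orderings, partition ID first and then the optional key ordering. A minimal standalone sketch of the same idea (illustrative only; the String keys and sample data below are assumptions for the example, not code from this commit):

    import java.util.Comparator

    // Order ((partitionId, key), value) records by partition, then by key within a partition.
    val byPartitionThenKey = new Comparator[((Int, String), Int)] {
      override def compare(a: ((Int, String), Int), b: ((Int, String), Int)): Int = {
        val partitionDiff = a._1._1 - b._1._1
        if (partitionDiff != 0) partitionDiff else a._1._2.compareTo(b._1._2)
      }
    }

    val elems = Seq(((1, "b"), 10), ((0, "z"), 20), ((0, "a"), 30))
    val sorted = elems.sortWith((x, y) => byPartitionThenKey.compare(x, y) < 0)
    // sorted == List(((0, "a"), 30), ((0, "z"), 20), ((1, "b"), 10))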

// Information about a spilled file. Includes sizes in bytes of "batches" written by the
// serializer as we periodically reset its stream, as well as number of elements in each
// partition, used to efficiently keep track of partitions when merging.
private case class SpilledFile(
private[this] case class SpilledFile(
file: File,
blockId: BlockId,
serializerBatchSizes: ArrayBuffer[Long],
@@ -171,7 +191,7 @@ private[spark] class ExternalSorter[K, V, C](
*
* @param usingMap whether we're using a map or buffer as our current in-memory collection
*/
def spill(memorySize: Long, usingMap: Boolean): Unit = {
private def spill(memorySize: Long, usingMap: Boolean): Unit = {
val collection: SizeTrackingCollection[((Int, K), C)] = if (usingMap) map else buffer
val memorySize = collection.estimateSize()

@@ -199,7 +219,7 @@ private[spark] class ExternalSorter[K, V, C](
}

try {
val it = collection.destructiveSortedIterator(comparator)
val it = collection.destructiveSortedIterator(partitionKeyComparator)
while (it.hasNext) {
val elem = it.next()
val partitionId = elem._1._1
@@ -240,31 +260,126 @@ private[spark] class ExternalSorter[K, V, C](
* Merge a sequence of sorted files, giving an iterator over partitions and then over elements
* inside each partition. This can be used to either write out a new file or return data to
* the user.
*
* Returns an iterator over all the data written to this object, grouped by partition. For each
* partition we then have an iterator over its contents, and these are expected to be accessed
* in order (you can't "skip ahead" to one partition without reading the previous one).
* Guaranteed to return a key-value pair for each partition, in order of partition ID.
*/
def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] = {
// TODO: merge intermediate results if they are sorted by the comparator
val readers = spills.map(new SpillReader(_))
val inMemBuffered = inMemory.buffered
(0 until numPartitions).iterator.map { p =>
val inMemIterator = new Iterator[(K, C)] {
val inMemIterator = new Iterator[Product2[K, C]] {
override def hasNext: Boolean = {
inMemBuffered.hasNext && inMemBuffered.head._1._1 == p
}
override def next(): (K, C) = {
override def next(): Product2[K, C] = {
val elem = inMemBuffered.next()
(elem._1._2, elem._2)
}
}
(p, readers.iterator.flatMap(_.readNextPartition()) ++ inMemIterator)
val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
if (aggregator.isDefined && ordering.isDefined) {
(p, mergeWithAggregation(
iterators, aggregator.get.mergeCombiners, ordering.get, totalOrder = true))
} else if (ordering.isDefined) {
(p, mergeSort(iterators, ordering.get))
} else {
(p, iterators.iterator.flatten)
}
}
}
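
The doc comment above spells out the contract: one (partitionId, iterator) pair per partition, in partition order, with each inner iterator consumed before moving on. A hedged caller-side sketch (hypothetical helper, not part of this commit) of how such an iterator is drained:

    // Drain a partition-grouped iterator like the one merge() returns.
    def drainPartitions[K, C](partitions: Iterator[(Int, Iterator[Product2[K, C]])]): Unit = {
      for ((partitionId, elements) <- partitions) {
        // Each partition's inner iterator must be exhausted before advancing, because the
        // underlying spill readers only move forward one partition at a time.
        elements.foreach { kv =>
          println(s"partition $partitionId: key=${kv._1} value=${kv._2}")
        }
      }
    }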

/**
* Merge-sort a bunch of (K, C) iterators using a given ordering on the keys. Assumed to
* be a total ordering.
*/
private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
: Iterator[Product2[K, C]] =
{
val bufferedIters = iterators.map(_.buffered)
type Iter = BufferedIterator[Product2[K, C]]
val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
})
heap.enqueue(bufferedIters: _*)
new Iterator[Product2[K, C]] {
override def hasNext: Boolean = !heap.isEmpty

override def next(): Product2[K, C] = {
if (!hasNext) {
throw new NoSuchElementException
}
val firstBuf = heap.dequeue()
val firstPair = firstBuf.next()
if (firstBuf.hasNext) {
heap.enqueue(firstBuf)
}
firstPair
}
}
}
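
mergeSort above is a standard k-way merge driven by a priority queue of buffered iterators, ordered by each iterator's current head. A self-contained sketch of the same technique over plain Int iterators (illustrative only; the empty-iterator filter is a defensive choice for the example, not something this commit does):

    import scala.collection.mutable

    def kWayMerge(iterators: Seq[Iterator[Int]]): Iterator[Int] = {
      type Iter = BufferedIterator[Int]
      // Reverse the ordering so the queue yields the iterator with the smallest head first.
      val heap = new mutable.PriorityQueue[Iter]()(Ordering.by[Iter, Int](_.head).reverse)
      iterators.map(_.buffered).filter(_.hasNext).foreach(heap.enqueue(_))
      new Iterator[Int] {
        override def hasNext: Boolean = heap.nonEmpty
        override def next(): Int = {
          val buf = heap.dequeue()
          val elem = buf.next()
          if (buf.hasNext) heap.enqueue(buf)  // put the iterator back if it still has elements
          elem
        }
      }
    }

    // kWayMerge(Seq(Iterator(1, 4, 7), Iterator(2, 5), Iterator(3, 6))).toList
    //   == List(1, 2, 3, 4, 5, 6, 7)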

/**
* Merge a bunch of (K, C) iterators by aggregating values for the same key, assuming that each
* iterator is sorted by key using our comparator. If the comparator is not a total ordering
* (e.g. when we sort objects by hash code and different keys may compare as equal although
* they're not), we still merge them by doing equality tests for all keys that compare as equal.
*/
private def mergeWithAggregation(
iterators: Seq[Iterator[Product2[K, C]]],
mergeCombiners: (C, C) => C,
comparator: Comparator[K],
totalOrder: Boolean)
: Iterator[Product2[K, C]] =
{
require(totalOrder, "non-total order not yet supported")
val bufferedIters = iterators.map(_.buffered)
type Iter = BufferedIterator[Product2[K, C]]
val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
})
heap.enqueue(bufferedIters: _*)
new Iterator[Product2[K, C]] {
override def hasNext: Boolean = !heap.isEmpty

override def next(): Product2[K, C] = {
if (!hasNext) {
throw new NoSuchElementException
}
val firstBuf = heap.dequeue()
val firstPair = firstBuf.next()
val k = firstPair._1
var c = firstPair._2
if (firstBuf.hasNext) {
heap.enqueue(firstBuf)
}
var shouldStop = false
while (!heap.isEmpty && !shouldStop) {
shouldStop = true
val newBuf = heap.dequeue()
while (newBuf.hasNext && newBuf.head._1 == k) {
val elem = newBuf.next()
c = mergeCombiners(c, elem._2)
shouldStop = false
}
if (newBuf.hasNext) {
heap.enqueue(newBuf)
}
}
(k, c)
}
}
}
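
With totalOrder = true, keys that compare as equal really are equal, so the loop above folds every run of equal keys into a single combined value. The result matches this naive, non-streaming sketch, shown only to illustrate the semantics (String keys and integer sums are assumed for the example):

    def mergeAggregate(streams: Seq[Seq[(String, Int)]])(combine: (Int, Int) => Int): Seq[(String, Int)] = {
      // Group every record by key and reduce each group with the combiner; the streaming
      // version above produces the same pairs, one run of equal keys at a time.
      streams.flatten
        .groupBy(_._1)
        .map { case (k, kvs) => (k, kvs.map(_._2).reduce(combine)) }
        .toSeq
        .sortBy(_._1)
    }

    // mergeAggregate(Seq(Seq(("a", 1), ("c", 3)), Seq(("a", 10), ("b", 2))))(_ + _)
    //   == List(("a", 11), ("b", 2), ("c", 3))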

/**
* An internal class for reading a spilled file partition by partition. Expects all the
* partitions to be requested in order.
*/
private class SpillReader(spill: SpilledFile) {
private[this] class SpillReader(spill: SpilledFile) {
val fileStream = new FileInputStream(spill.file)
val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)

@@ -371,7 +486,7 @@ private[spark] class ExternalSorter[K, V, C](
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
val usingMap = aggregator.isDefined
val collection: SizeTrackingCollection[((Int, K), C)] = if (usingMap) map else buffer
merge(spills, collection.destructiveSortedIterator(comparator))
merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
}

/**
@@ -208,11 +208,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val resultA = rddA.reduceByKey(math.max).collect()
assert(resultA.length == 50000)
resultA.foreach { case(k, v) =>
k match {
case 0 => assert(v == 1)
case 25000 => assert(v == 50001)
case 49999 => assert(v == 99999)
case _ =>
if (v != k * 2 + 1) {
fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
}
}

@@ -221,11 +218,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val resultB = rddB.groupByKey().collect()
assert(resultB.length == 25000)
resultB.foreach { case(i, seq) =>
i match {
case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003))
case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999))
case _ =>
val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
if (seq.toSet != expected) {
fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
}
}

@@ -239,6 +234,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
case 0 =>
assert(seq1.toSet == Set[Int](0))
assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
case 1 =>
assert(seq1.toSet == Set[Int](1))
assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
case 5000 =>
assert(seq1.toSet == Set[Int](5000))
assert(seq2.toSet == Set[Int]())
(Diff for the third changed file not shown.)
