Skip to content

Commit

Permalink
Review comments and more tests (e.g. tests with 1 element per partition)
Browse files Browse the repository at this point in the history
  • Loading branch information
mateiz committed Jul 30, 2014
1 parent e9ad356 commit 5461cbb
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,34 @@ import org.apache.spark.storage.BlockId
* If combining is disabled, the type C must equal V -- we'll cast the objects at the end.
*
* @param aggregator optional Aggregator with combine functions to use for merging data
* @param partitioner optional partitioner; if given, sort by partition ID and then key
* @param ordering optional ordering to sort keys within each partition
* @param partitioner optional Partitioner; if given, sort by partition ID and then key
* @param ordering optional Ordering to sort keys within each partition; should be a total ordering
* @param serializer serializer to use when spilling to disk
*
* Note that if an Ordering is given, we'll always sort using it, so only provide it if you really
* want the output keys to be sorted. In a map task without map-side combine for example, you
* probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do
* want to do combining, having an Ordering is more efficient than not having it.
*
* At a high level, this class works as follows:
*
* - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if
* we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers,
* we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to
* avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner).
*
* - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
* by partition ID and possibly second by key or by hash code of the key, if we want to do
* aggregation. For each file, we track how many objects were in each partition in memory, so we
* don't have to write out the partition ID for every element.
*
* - When the user requests an iterator, the spilled files are merged, along with any remaining
* in-memory data, using the same sort order defined above (unless both sorting and aggregation
* are disabled). If we need to aggregate by key, we either use a total ordering from the
* ordering parameter, or read the keys with the same hash code and compare them with each other
* for equality to merge values.
*
* - Users are expected to call stop() at the end to delete all the intermediate files.
*/
private[spark] class ExternalSorter[K, V, C](
aggregator: Option[Aggregator[K, V, C]] = None,
Expand Down Expand Up @@ -213,7 +238,7 @@ private[spark] class ExternalSorter[K, V, C](
.format(memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
val (blockId, file) = diskBlockManager.createTempBlock()
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
var objectsWritten = 0
var objectsWritten = 0 // Objects written since the last flush

// List of batch sizes (bytes) in the order they are written to disk
val batchSizes = new ArrayBuffer[Long]
Expand All @@ -227,7 +252,6 @@ private[spark] class ExternalSorter[K, V, C](
val bytesWritten = writer.bytesWritten
batchSizes.append(bytesWritten)
_diskBytesSpilled += bytesWritten
objectsWritten = 0
}

try {
Expand All @@ -244,6 +268,7 @@ private[spark] class ExternalSorter[K, V, C](

if (objectsWritten == serializerBatchSize) {
flush()
objectsWritten = 0
writer.close()
writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
}
Expand All @@ -256,6 +281,7 @@ private[spark] class ExternalSorter[K, V, C](
case e: Exception =>
writer.close()
file.delete()
throw e
}

if (usingMap) {
Expand All @@ -280,7 +306,6 @@ private[spark] class ExternalSorter[K, V, C](
*/
private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] = {
// TODO: merge intermediate results if they are sorted by the comparator
val readers = spills.map(new SpillReader(_))
val inMemBuffered = inMemory.buffered
(0 until numPartitions).iterator.map { p =>
Expand Down Expand Up @@ -315,7 +340,7 @@ private[spark] class ExternalSorter[K, V, C](
* Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys.
*/
private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
: Iterator[Product2[K, C]] =
: Iterator[Product2[K, C]] =
{
val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
type Iter = BufferedIterator[Product2[K, C]]
Expand Down Expand Up @@ -356,50 +381,46 @@ private[spark] class ExternalSorter[K, V, C](
if (!totalOrder) {
// We only have a partial ordering, e.g. comparing the keys by hash code, which means that
// multiple distinct keys might be treated as equal by the ordering. To deal with this, we
// need to buffer every set of keys considered equal by the comparator in memory, then do
// another pass through them to find the truly equal ones.
val sorted = mergeSort(iterators, comparator).buffered
// Buffers reused across keys to decrease memory allocation
val buf = new ArrayBuffer[(K, C)]
val toReturn = new ArrayBuffer[(K, C)]
// need to read all keys considered equal by the ordering at once and compare them.
new Iterator[Iterator[Product2[K, C]]] {
val sorted = mergeSort(iterators, comparator).buffered

// Buffers reused across elements to decrease memory allocation
val keys = new ArrayBuffer[K]
val combiners = new ArrayBuffer[C]

override def hasNext: Boolean = sorted.hasNext

override def next(): Iterator[Product2[K, C]] = {
if (!hasNext) {
throw new NoSuchElementException
}
keys.clear()
combiners.clear()
val firstPair = sorted.next()
buf += ((firstPair._1, firstPair._2)) // Copy it in case the Product2 object is reused
keys += firstPair._1
combiners += firstPair._2
val key = firstPair._1
while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
val n = sorted.next()
buf += ((n._1, n._2))
}
// buf now contains all the elements with keys equal to our first one according to the
// partial ordering. Now we need to find which keys were "really" equal, which we do
// through linear scans through the buffer.
toReturn.clear()
while (!buf.isEmpty) {
val last = buf(buf.size - 1)
buf.reduceToSize(buf.size - 1)
val k = last._1
var c = last._2
val pair = sorted.next()
var i = 0
while (i < buf.size) {
while (i < buf.size && buf(i)._1 == k) {
c = mergeCombiners(c, buf(i)._2)
// Replace this element with the last one in the buffer
buf(i) = buf(buf.size - 1)
buf.reduceToSize(buf.size - 1)
var foundKey = false
while (i < keys.size && !foundKey) {
if (keys(i) == pair._1) {
combiners(i) = mergeCombiners(combiners(i), pair._2)
foundKey = true
}
i += 1
}
toReturn += ((k, c))
if (!foundKey) {
keys += pair._1
combiners += pair._2
}
}
// Note that we return a *sequence* of elements since we could've had many keys marked

// Note that we return an iterator of elements since we could've had many keys marked
// equal by the partial order; we flatten this below to get a flat iterator of (K, C).
toReturn.iterator
keys.iterator.zip(combiners.iterator)
}
}.flatMap(i => i)
} else {
Expand Down Expand Up @@ -482,39 +503,33 @@ private[spark] class ExternalSorter[K, V, C](
* If no more pairs are left, return null.
*/
private def readNextItem(): (K, C) = {
try {
if (finished) {
return null
}
val k = deserStream.readObject().asInstanceOf[K]
val c = deserStream.readObject().asInstanceOf[C]
// Start reading the next batch if we're done with this one
indexInBatch += 1
if (indexInBatch == serializerBatchSize) {
batchStream = nextBatchStream()
compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
deserStream = serInstance.deserializeStream(compressedStream)
indexInBatch = 0
}
// Update the partition location of the element we're reading
indexInPartition += 1
while (indexInPartition == spill.elementsPerPartition(partitionId)) {
partitionId += 1
indexInPartition = 0
}
if (partitionId == numPartitions - 1 &&
indexInPartition == spill.elementsPerPartition(partitionId) - 1) {
// This is the last element, remember that we're done
finished = true
deserStream.close()
}
(k, c)
} catch {
case e: EOFException =>
finished = true
deserStream.close()
null
if (finished) {
return null
}
val k = deserStream.readObject().asInstanceOf[K]
val c = deserStream.readObject().asInstanceOf[C]
// Start reading the next batch if we're done with this one
indexInBatch += 1
if (indexInBatch == serializerBatchSize) {
batchStream = nextBatchStream()
compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
deserStream = serInstance.deserializeStream(compressedStream)
indexInBatch = 0
}
// Update the partition location of the element we're reading, possibly skipping zero-length
// partitions until we get to the next non-empty one or to EOF.
indexInPartition += 1
while (indexInPartition == spill.elementsPerPartition(partitionId)) {
partitionId += 1
indexInPartition = 0
}
if (partitionId == numPartitions - 1 &&
indexInPartition == spill.elementsPerPartition(partitionId) - 1) {
// This is the last element, remember that we're done
finished = true
deserStream.close()
}
(k, c)
}

var nextPartitionToRead = 0
Expand All @@ -530,7 +545,9 @@ private[spark] class ExternalSorter[K, V, C](
return false
}
}
// Check that we're still in the right partition; will be numPartitions at EOF
assert(partitionId >= myPartition)
// Check that we're still in the right partition; note that readNextItem will have returned
// null at EOF above so we would've returned false there
partitionId == myPartition
}

Expand All @@ -546,10 +563,11 @@ private[spark] class ExternalSorter[K, V, C](
}

/**
* Return an iterator over all the data written to this object, grouped by partition. For each
* partition we then have an iterator over its contents, and these are expected to be accessed
* in order (you can't "skip ahead" to one partition without reading the previous one).
* Guaranteed to return a key-value pair for each partition, in order of partition ID.
* Return an iterator over all the data written to this object, grouped by partition and
* aggregated by the requested aggregator. For each partition we then have an iterator over its
* contents, and these are expected to be accessed in order (you can't "skip ahead" to one
* partition without reading the previous one). Guaranteed to return a key-value pair for each
* partition, in order of partition ID.
*
* For now, we just merge all the spilled files in once pass, but this can be modified to
* support hierarchical merging.
Expand All @@ -561,7 +579,7 @@ private[spark] class ExternalSorter[K, V, C](
}

/**
* Return an iterator over all the data written to this object.
* Return an iterator over all the data written to this object, aggregated by our aggregator.
*/
def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,44 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
assert(sorter4.iterator.toSeq === Seq())
}

test("few elements per partition") {
val conf = new SparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)

val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
val ord = implicitly[Ordering[Int]]
val elements = Set((1, 1), (2, 2), (5, 5))
val expected = Set(
(0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()),
(5, Set((5, 5))), (6, Set()))

// Both aggregator and ordering
val sorter = new ExternalSorter[Int, Int, Int](
Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
sorter.write(elements.iterator)
assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)

// Only aggregator
val sorter2 = new ExternalSorter[Int, Int, Int](
Some(agg), Some(new HashPartitioner(7)), None, None)
sorter2.write(elements.iterator)
assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)

// Only ordering
val sorter3 = new ExternalSorter[Int, Int, Int](
None, Some(new HashPartitioner(7)), Some(ord), None)
sorter3.write(elements.iterator)
assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)

// Neither aggregator nor ordering
val sorter4 = new ExternalSorter[Int, Int, Int](
None, Some(new HashPartitioner(7)), None, None)
sorter4.write(elements.iterator)
assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
}

test("spilling in local cluster") {
val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
Expand Down

0 comments on commit 5461cbb

Please sign in to comment.