diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 7e04d641ad96e..163ef2a75f3d2 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -38,9 +38,34 @@ import org.apache.spark.storage.BlockId
  * If combining is disabled, the type C must equal V -- we'll cast the objects at the end.
  *
  * @param aggregator optional Aggregator with combine functions to use for merging data
- * @param partitioner optional partitioner; if given, sort by partition ID and then key
- * @param ordering optional ordering to sort keys within each partition
+ * @param partitioner optional Partitioner; if given, sort by partition ID and then key
+ * @param ordering optional Ordering to sort keys within each partition; should be a total ordering
  * @param serializer serializer to use when spilling to disk
+ *
+ * Note that if an Ordering is given, we'll always sort using it, so only provide it if you really
+ * want the output keys to be sorted. In a map task without map-side combine for example, you
+ * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do
+ * want to do combining, having an Ordering is more efficient than not having it.
+ *
+ * At a high level, this class works as follows:
+ *
+ * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if
+ *   we want to combine by key, or a simple SizeTrackingBuffer if we don't. Inside these buffers,
+ *   we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to
+ *   avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner).
+ *
+ * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
+ *   by partition ID and possibly second by key or by hash code of the key, if we want to do
+ *   aggregation. For each file, we track how many objects were in each partition in memory, so we
+ *   don't have to write out the partition ID for every element.
+ *
+ * - When the user requests an iterator, the spilled files are merged, along with any remaining
+ *   in-memory data, using the same sort order defined above (unless both sorting and aggregation
+ *   are disabled). If we need to aggregate by key, we either use a total ordering from the
+ *   ordering parameter, or read the keys with the same hash code and compare them with each other
+ *   for equality to merge values.
+ *
+ * - Users are expected to call stop() at the end to delete all the intermediate files.
  */
 private[spark] class ExternalSorter[K, V, C](
     aggregator: Option[Aggregator[K, V, C]] = None,
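To make the new class comment above concrete, here is a minimal, hypothetical usage sketch (not part of the patch). It mirrors the "few elements per partition" test added to ExternalSorterSuite below; the Int key/value types, the 7-partition HashPartitioner, and a running SparkContext/SparkEnv are illustrative assumptions, and since the class is private[spark], callers live inside Spark itself.

```scala
// Hypothetical usage sketch of ExternalSorter, modeled on the new test below.
// Assumes it runs inside the org.apache.spark package with an active SparkEnv.
import org.apache.spark.{Aggregator, HashPartitioner}
import org.apache.spark.util.collection.ExternalSorter

val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
val sorter = new ExternalSorter[Int, Int, Int](
  aggregator = Some(agg),
  partitioner = Some(new HashPartitioner(7)),
  ordering = Some(implicitly[Ordering[Int]]),  // pass None if sorted keys aren't needed
  serializer = None)

sorter.write(Iterator((1, 1), (2, 2), (5, 5)))  // may spill to disk under memory pressure

// Contents grouped by partition ID; partitions must be consumed in increasing order.
sorter.partitionedIterator.foreach { case (partitionId, elements) =>
  println(s"partition $partitionId: ${elements.toList}")
}

sorter.stop()  // clean up any intermediate spill files
```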
@@ -213,7 +238,7 @@ private[spark] class ExternalSorter[K, V, C](
       .format(memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
     val (blockId, file) = diskBlockManager.createTempBlock()
     var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
-    var objectsWritten = 0
+    var objectsWritten = 0  // Objects written since the last flush

     // List of batch sizes (bytes) in the order they are written to disk
     val batchSizes = new ArrayBuffer[Long]
@@ -227,7 +252,6 @@ private[spark] class ExternalSorter[K, V, C](
       val bytesWritten = writer.bytesWritten
       batchSizes.append(bytesWritten)
       _diskBytesSpilled += bytesWritten
-      objectsWritten = 0
     }

     try {
@@ -244,6 +268,7 @@ private[spark] class ExternalSorter[K, V, C](

         if (objectsWritten == serializerBatchSize) {
           flush()
+          objectsWritten = 0
           writer.close()
           writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
         }
@@ -256,6 +281,7 @@ private[spark] class ExternalSorter[K, V, C](
       case e: Exception =>
         writer.close()
         file.delete()
+        throw e
     }

     if (usingMap) {
@@ -280,7 +306,6 @@ private[spark] class ExternalSorter[K, V, C](
    */
   private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
       : Iterator[(Int, Iterator[Product2[K, C]])] = {
-    // TODO: merge intermediate results if they are sorted by the comparator
     val readers = spills.map(new SpillReader(_))
     val inMemBuffered = inMemory.buffered
     (0 until numPartitions).iterator.map { p =>
@@ -315,7 +340,7 @@ private[spark] class ExternalSorter[K, V, C](
    * Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys.
    */
   private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
-      : Iterator[Product2[K, C]] =
+      : Iterator[Product2[K, C]] = {
     val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
     type Iter = BufferedIterator[Product2[K, C]]

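The two spill hunks above fix a subtle counter bug: objectsWritten is now reset at the call site right after flush(), so it counts only the objects written since the last flush, and failures now re-throw after cleaning up the temp file. A stripped-down sketch of that batching pattern is shown below; BatchWriter is a hypothetical stand-in for the real block writer, not the patch's API.

```scala
// Illustrative sketch of the spill write loop above. Only the counter/flush pattern matters here.
import scala.collection.mutable.ArrayBuffer

trait BatchWriter[T] {
  def write(elem: T): Unit
  def flushBatch(): Long  // commits the current batch and returns its size in bytes
  def close(): Unit
}

def spillAll[T](elements: Iterator[T], writer: BatchWriter[T], batchSize: Int): Seq[Long] = {
  val batchSizes = new ArrayBuffer[Long]  // bytes per batch, in write order
  var objectsWritten = 0                  // objects written since the last flush
  for (elem <- elements) {
    writer.write(elem)
    objectsWritten += 1
    if (objectsWritten == batchSize) {
      batchSizes += writer.flushBatch()
      objectsWritten = 0                  // reset here, as the patched code now does
    }
  }
  if (objectsWritten > 0) {
    batchSizes += writer.flushBatch()     // flush the final partial batch
  }
  writer.close()
  batchSizes
}
```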
@@ -356,50 +381,46 @@ private[spark] class ExternalSorter[K, V, C](
     if (!totalOrder) {
       // We only have a partial ordering, e.g. comparing the keys by hash code, which means that
       // multiple distinct keys might be treated as equal by the ordering. To deal with this, we
-      // need to buffer every set of keys considered equal by the comparator in memory, then do
-      // another pass through them to find the truly equal ones.
-      val sorted = mergeSort(iterators, comparator).buffered
-      // Buffers reused across keys to decrease memory allocation
-      val buf = new ArrayBuffer[(K, C)]
-      val toReturn = new ArrayBuffer[(K, C)]
+      // need to read all keys considered equal by the ordering at once and compare them.
       new Iterator[Iterator[Product2[K, C]]] {
+        val sorted = mergeSort(iterators, comparator).buffered
+
+        // Buffers reused across elements to decrease memory allocation
+        val keys = new ArrayBuffer[K]
+        val combiners = new ArrayBuffer[C]
+
         override def hasNext: Boolean = sorted.hasNext

         override def next(): Iterator[Product2[K, C]] = {
           if (!hasNext) {
             throw new NoSuchElementException
           }
+          keys.clear()
+          combiners.clear()
           val firstPair = sorted.next()
-          buf += ((firstPair._1, firstPair._2))  // Copy it in case the Product2 object is reused
+          keys += firstPair._1
+          combiners += firstPair._2
           val key = firstPair._1
           while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
-            val n = sorted.next()
-            buf += ((n._1, n._2))
-          }
-          // buf now contains all the elements with keys equal to our first one according to the
-          // partial ordering. Now we need to find which keys were "really" equal, which we do
-          // through linear scans through the buffer.
-          toReturn.clear()
-          while (!buf.isEmpty) {
-            val last = buf(buf.size - 1)
-            buf.reduceToSize(buf.size - 1)
-            val k = last._1
-            var c = last._2
+            val pair = sorted.next()
             var i = 0
-            while (i < buf.size) {
-              while (i < buf.size && buf(i)._1 == k) {
-                c = mergeCombiners(c, buf(i)._2)
-                // Replace this element with the last one in the buffer
-                buf(i) = buf(buf.size - 1)
-                buf.reduceToSize(buf.size - 1)
+            var foundKey = false
+            while (i < keys.size && !foundKey) {
+              if (keys(i) == pair._1) {
+                combiners(i) = mergeCombiners(combiners(i), pair._2)
+                foundKey = true
               }
               i += 1
             }
-            toReturn += ((k, c))
+            if (!foundKey) {
+              keys += pair._1
+              combiners += pair._2
+            }
           }
-          // Note that we return a *sequence* of elements since we could've had many keys marked
+
+          // Note that we return an iterator of elements since we could've had many keys marked
           // equal by the partial order; we flatten this below to get a flat iterator of (K, C).
-          toReturn.iterator
+          keys.iterator.zip(combiners.iterator)
         }
       }.flatMap(i => i)
     } else {
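The rewritten block above keeps two parallel buffers, keys and combiners, and does one linear scan per incoming element instead of the old buffer-and-rescan pass. The same idea in isolation, over a single run of comparator-equal elements, looks like the hypothetical helper below (simplified from the code above; not the patch's API).

```scala
// Sketch of the partial-ordering merge above: the comparator may call two distinct keys
// "equal" (e.g. same hash code), so within a comparator-equal run we still compare keys
// with == before merging combiners. Names are illustrative only.
import scala.collection.mutable.ArrayBuffer

def mergeRun[K, C](run: Iterator[(K, C)], mergeCombiners: (C, C) => C): Iterator[(K, C)] = {
  val keys = new ArrayBuffer[K]
  val combiners = new ArrayBuffer[C]
  for ((k, c) <- run) {
    val i = keys.indexOf(k)  // linear scan, like the foundKey loop above
    if (i >= 0) {
      combiners(i) = mergeCombiners(combiners(i), c)
    } else {
      keys += k
      combiners += c
    }
  }
  keys.iterator.zip(combiners.iterator)
}

// Example: "a" and "b" might share a hash code, yet only true duplicates are combined:
// mergeRun(Iterator(("a", 1), ("b", 2), ("a", 3)), (x: Int, y: Int) => x + y).toList
//   == List(("a", 4), ("b", 2))
```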
@@ -482,39 +503,33 @@ private[spark] class ExternalSorter[K, V, C](
      * If no more pairs are left, return null.
      */
     private def readNextItem(): (K, C) = {
-      try {
-        if (finished) {
-          return null
-        }
-        val k = deserStream.readObject().asInstanceOf[K]
-        val c = deserStream.readObject().asInstanceOf[C]
-        // Start reading the next batch if we're done with this one
-        indexInBatch += 1
-        if (indexInBatch == serializerBatchSize) {
-          batchStream = nextBatchStream()
-          compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
-          deserStream = serInstance.deserializeStream(compressedStream)
-          indexInBatch = 0
-        }
-        // Update the partition location of the element we're reading
-        indexInPartition += 1
-        while (indexInPartition == spill.elementsPerPartition(partitionId)) {
-          partitionId += 1
-          indexInPartition = 0
-        }
-        if (partitionId == numPartitions - 1 &&
-            indexInPartition == spill.elementsPerPartition(partitionId) - 1) {
-          // This is the last element, remember that we're done
-          finished = true
-          deserStream.close()
-        }
-        (k, c)
-      } catch {
-        case e: EOFException =>
-          finished = true
-          deserStream.close()
-          null
+      if (finished) {
+        return null
+      }
+      val k = deserStream.readObject().asInstanceOf[K]
+      val c = deserStream.readObject().asInstanceOf[C]
+      // Start reading the next batch if we're done with this one
+      indexInBatch += 1
+      if (indexInBatch == serializerBatchSize) {
+        batchStream = nextBatchStream()
+        compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+        deserStream = serInstance.deserializeStream(compressedStream)
+        indexInBatch = 0
       }
+      // Update the partition location of the element we're reading, possibly skipping zero-length
+      // partitions until we get to the next non-empty one or to EOF.
+      indexInPartition += 1
+      while (indexInPartition == spill.elementsPerPartition(partitionId)) {
+        partitionId += 1
+        indexInPartition = 0
+      }
+      if (partitionId == numPartitions - 1 &&
+          indexInPartition == spill.elementsPerPartition(partitionId) - 1) {
+        // This is the last element, remember that we're done
+        finished = true
+        deserStream.close()
+      }
+      (k, c)
     }

     var nextPartitionToRead = 0
@@ -530,7 +545,9 @@ private[spark] class ExternalSorter[K, V, C](
             return false
           }
         }
-        // Check that we're still in the right partition; will be numPartitions at EOF
+        assert(partitionId >= myPartition)
+        // Check that we're still in the right partition; note that readNextItem will have returned
+        // null at EOF above so we would've returned false there
         partitionId == myPartition
       }

@@ -546,10 +563,11 @@ private[spark] class ExternalSorter[K, V, C](
   }

   /**
-   * Return an iterator over all the data written to this object, grouped by partition. For each
-   * partition we then have an iterator over its contents, and these are expected to be accessed
-   * in order (you can't "skip ahead" to one partition without reading the previous one).
-   * Guaranteed to return a key-value pair for each partition, in order of partition ID.
+   * Return an iterator over all the data written to this object, grouped by partition and
+   * aggregated by the requested aggregator. For each partition we then have an iterator over its
+   * contents, and these are expected to be accessed in order (you can't "skip ahead" to one
+   * partition without reading the previous one). Guaranteed to return a key-value pair for each
+   * partition, in order of partition ID.
    *
    * For now, we just merge all the spilled files in once pass, but this can be modified to
    * support hierarchical merging.
@@ -561,7 +579,7 @@ private[spark] class ExternalSorter[K, V, C](
   }

   /**
-   * Return an iterator over all the data written to this object.
+   * Return an iterator over all the data written to this object, aggregated by our aggregator.
    */
   def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)

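The readNextItem change above drops the EOFException-based termination and instead trusts the per-partition element counts recorded at spill time, skipping zero-length partitions as it advances. That bookkeeping can be sketched on its own, with the deserialization stream replaced by a plain iterator; the helper below is hypothetical and illustrative only.

```scala
// Sketch of the partition bookkeeping in SpillReader.readNextItem above. The real reader
// also re-creates its deserialization stream every serializerBatchSize elements, which
// this sketch omits; names are illustrative only.
def withPartitionIds[T](elements: Iterator[T],
                        elementsPerPartition: Array[Long]): Iterator[(Int, T)] = {
  var partitionId = 0
  var indexInPartition = 0L
  elements.map { elem =>
    // Skip zero-length partitions until we reach the one this element belongs to.
    while (partitionId < elementsPerPartition.length &&
        indexInPartition == elementsPerPartition(partitionId)) {
      partitionId += 1
      indexInPartition = 0L
    }
    indexInPartition += 1
    (partitionId, elem)
  }
}

// Example: partitions 0 and 2 are empty, partition 1 has two elements, partition 3 has one:
// withPartitionIds(Iterator("x", "y", "z"), Array(0L, 2L, 0L, 1L)).toList
//   == List((1, "x"), (1, "y"), (3, "z"))
```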
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 18791e632e6ff..62c3dc782bbf3 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -55,6 +55,44 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     assert(sorter4.iterator.toSeq === Seq())
   }

+  test("few elements per partition") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+    val elements = Set((1, 1), (2, 2), (5, 5))
+    val expected = Set(
+      (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()),
+      (5, Set((5, 5))), (6, Set()))
+
+    // Both aggregator and ordering
+    val sorter = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
+    sorter.write(elements.iterator)
+    assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+
+    // Only aggregator
+    val sorter2 = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(7)), None, None)
+    sorter2.write(elements.iterator)
+    assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+
+    // Only ordering
+    val sorter3 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), Some(ord), None)
+    sorter3.write(elements.iterator)
+    assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+
+    // Neither aggregator nor ordering
+    val sorter4 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), None, None)
+    sorter4.write(elements.iterator)
+    assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+  }
+
   test("spilling in local cluster") {
     val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")