From bbf359df3a9ad37ff399314ddde19eae24536d30 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 18 Jul 2014 21:18:17 -0700 Subject: [PATCH] More work --- .../shuffle/sort/SortShuffleManager.scala | 20 ++++ .../shuffle/sort/SortShuffleWriter.scala | 91 +++++++++++++++++-- .../spark/storage/DiskBlockManager.scala | 14 ++- .../util/collection/ExternalSorter.scala | 44 ++++++++- .../ExternalAppendOnlyMapSuite.scala | 7 -- .../util/collection/ExternalSorterSuite.scala | 80 ++++++++++++++++ .../util/collection/FixedHashObject.scala | 25 +++++ 7 files changed, 258 insertions(+), 23 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d869b870920ce..87d081ec51ee1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -17,9 +17,12 @@ package org.apache.spark.shuffle.sort +import java.io.{DataInputStream, FileInputStream} + import org.apache.spark.shuffle._ import org.apache.spark.{TaskContext, ShuffleDependency} import org.apache.spark.shuffle.hash.HashShuffleReader +import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId} private[spark] class SortShuffleManager extends ShuffleManager { /** @@ -57,4 +60,21 @@ private[spark] class SortShuffleManager extends ShuffleManager { /** Shut down this ShuffleManager. */ override def stop(): Unit = {} + + /** Get the location of a block in a map output file. Uses the index file we create for it. 
*/
+  def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = {
+    // The block is actually going to be a range of a single map output file for this map,
+    // so look up our range within that file using the index file written alongside it
+    val realId = ShuffleBlockId(blockId.shuffleId, blockId.mapId, 0)
+    val indexFile = diskManager.getFile(realId.name + ".index")
+    val in = new DataInputStream(new FileInputStream(indexFile))
+    try {
+      in.skip(blockId.reduceId * 8)
+      val offset = in.readLong()
+      val nextOffset = in.readLong()
+      new FileSegment(diskManager.getFile(realId), offset, nextOffset - offset)
+    } finally {
+      in.close()
+    }
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 9a8835ee210b7..afd12e65e54a7 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -18,10 +18,14 @@
 package org.apache.spark.shuffle.sort
 
 import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
-import org.apache.spark.{SparkEnv, Logging, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.storage.ShuffleBlockId
+import java.util.concurrent.atomic.AtomicInteger
+import org.apache.spark.executor.ShuffleWriteMetrics
+import java.io.{BufferedOutputStream, FileOutputStream, DataOutputStream}
 
 private[spark] class SortShuffleWriter[K, V, C](
     handle: BaseShuffleHandle[K, V, C],
@@ -30,17 +34,24 @@ private[spark] class SortShuffleWriter[K, V, C](
   extends ShuffleWriter[K, V] with Logging {
 
   private val dep = handle.dependency
-  private val numOutputPartitions = dep.partitioner.numPartitions
+  private val numPartitions = dep.partitioner.numPartitions
 
   private val metrics = context.taskMetrics
 
   private val blockManager = SparkEnv.get.blockManager
+  private val shuffleBlockManager = blockManager.shuffleBlockManager
   private val diskBlockManager = blockManager.diskBlockManager
   private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
 
+  private val conf = SparkEnv.get.conf
+  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+
+  private var sorter: ExternalSorter[K, V, _] = null
+
+  private var stopping = false
+
+  private var mapStatus: MapStatus = null
+
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    var sorter: ExternalSorter[K, V, _] = null
-
     val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
       if (dep.mapSideCombine) {
         if (!dep.aggregator.isDefined) {
@@ -58,13 +69,81 @@ private[spark] class SortShuffleWriter[K, V, C](
       }
     }
 
+    // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
+    // serve different ranges of this file using an index file that we create at the end.
+ val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0) + val shuffleFile = blockManager.diskBlockManager.getFile(blockId) + + // Track location of each range in the output file + val offsets = new Array[Long](numPartitions + 1) + val lengths = new Array[Long](numPartitions) + + // Statistics + var totalBytes = 0L + var totalTime = 0L + for ((id, elements) <- partitions) { + if (elements.hasNext) { + val writer = blockManager.getDiskWriter(blockId, shuffleFile, ser, fileBufferSize) + for (elem <- elements) { + writer.write(elem) + } + writer.commit() + writer.close() + val segment = writer.fileSegment() + offsets(id + 1) = segment.offset + segment.length + lengths(id) = segment.length + totalTime += writer.timeWriting() + totalBytes += segment.length + } else { + // Don't create a new writer to avoid writing any headers and things like that + offsets(id + 1) = offsets(id) + } + } + + val shuffleMetrics = new ShuffleWriteMetrics + shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime + context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics) + + // Write an index file with the offsets of each block, plus a final offset at the end for the + // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure + // out where each block begins and ends. + val diskBlockManager = blockManager.diskBlockManager + val indexFile = diskBlockManager.getFile(blockId.name + ".index") + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + try { + var i = 0 + while (i < numPartitions + 1) { + out.writeLong(offsets(i)) + i += 1 + } + } finally { + out.close() } - ??? + mapStatus = new MapStatus(blockManager.blockManagerId, + lengths.map(MapOutputTracker.compressSize)) + + // TODO: keep track of our file in a way that can be cleaned up later } /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = ??? + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + return Option(mapStatus) + } else { + // TODO: clean up our file + return None + } + } finally { + // TODO: sorter.stop() + } + } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2e7ed7538e6e5..6f82805cd8f3f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -21,10 +21,11 @@ import java.io.File import java.text.SimpleDateFormat import java.util.{Date, Random, UUID} -import org.apache.spark.Logging +import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.network.netty.{PathResolver, ShuffleSender} import org.apache.spark.util.Utils +import org.apache.spark.shuffle.sort.SortShuffleManager /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -54,12 +55,15 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD addShutdownHook() /** - * Returns the physical file segment in which the given BlockId is located. - * If the BlockId has been mapped to a specific FileSegment, that will be returned. - * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly. 
+ * Returns the physical file segment in which the given BlockId is located. If the BlockId has + * been mapped to a specific FileSegment by the shuffle layer, that will be returned. + * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId. */ def getBlockLocation(blockId: BlockId): FileSegment = { - if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) { + if (blockId.isShuffle && SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]) { + val sortShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[SortShuffleManager] + sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this) + } else if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) { shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]) } else { val file = getFile(blockId.name) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 6d96ec1fb3aba..97a82cbd745b3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -18,6 +18,7 @@ package org.apache.spark.util.collection import java.io._ +import java.util.Comparator import scala.collection.mutable.ArrayBuffer @@ -88,6 +89,13 @@ private[spark] class ExternalSorter[K, V, C]( (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } + // For now, just compare them by partition; later we can compare by key as well + private val comparator = new Comparator[((Int, K), C)] { + override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = { + a._1._1 - b._1._1 + } + } + // Information about a spilled file. Includes sizes in bytes of "batches" written by the // serializer as we periodically reset its stream, as well as number of elements in each // partition, used to efficiently keep track of partitions when merging. @@ -192,7 +200,7 @@ private[spark] class ExternalSorter[K, V, C]( } try { - val it = collection.iterator // TODO: destructiveSortedIterator(comparator) + val it = collection.destructiveSortedIterator(comparator) while (it.hasNext) { val elem = it.next() val partitionId = elem._1._1 @@ -232,11 +240,22 @@ private[spark] class ExternalSorter[K, V, C]( * inside each partition. This can be used to either write out a new file or return data to * the user. 
*/ - def merge(spills: Seq[SpilledFile]): Iterator[(Int, Iterator[Product2[K, C]])] = { + def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)]) + : Iterator[(Int, Iterator[Product2[K, C]])] = { // TODO: merge intermediate results if they are sorted by the comparator val readers = spills.map(new SpillReader(_)) + val inMemBuffered = inMemory.buffered (0 until numPartitions).iterator.map { p => - (p, readers.iterator.flatMap(_.readNextPartition())) + val inMemIterator = new Iterator[(K, C)] { + override def hasNext: Boolean = { + inMemBuffered.hasNext && inMemBuffered.head._1._1 == p + } + override def next(): (K, C) = { + val elem = inMemBuffered.next() + (elem._1._2, elem._2) + } + } + (p, readers.iterator.flatMap(_.readNextPartition()) ++ inMemIterator) } } @@ -301,6 +320,11 @@ private[spark] class ExternalSorter[K, V, C]( } val k = deserStream.readObject().asInstanceOf[K] val c = deserStream.readObject().asInstanceOf[C] + if (partitionId == numPartitions - 1 && + indexInPartition == spill.elementsPerPartition(partitionId) - 1) { + finished = true + deserStream.close() + } (k, c) } catch { case e: EOFException => @@ -319,6 +343,9 @@ private[spark] class ExternalSorter[K, V, C]( override def hasNext: Boolean = { if (nextItem == null) { nextItem = readNextItem() + if (nextItem == null) { + return false + } } // Check that we're still in the right partition; will be numPartitions at EOF partitionId == myPartition @@ -328,7 +355,9 @@ private[spark] class ExternalSorter[K, V, C]( if (!hasNext) { throw new NoSuchElementException } - nextItem + val item = nextItem + nextItem = null + item } } } @@ -337,11 +366,16 @@ private[spark] class ExternalSorter[K, V, C]( * Return an iterator over all the data written to this object, grouped by partition. For each * partition we then have an iterator over its contents, and these are expected to be accessed * in order (you can't "skip ahead" to one partition without reading the previous one). + * Guaranteed to return a key-value pair for each partition, in order of partition ID. * * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. */ - def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = merge(spills) + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + val usingMap = aggregator.isDefined + val collection: SizeTrackingCollection[((Int, K), C)] = if (usingMap) map else buffer + merge(spills, collection.destructiveSortedIterator(comparator)) + } /** * Return an iterator over all the data written to this object. 
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 0b7ad184a46d2..e2ee62b2b54a8 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -369,10 +369,3 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } } - -/** - * A dummy class that always returns the same hash code, to easily test hash collisions - */ -case class FixedHashObject(v: Int, h: Int) extends Serializable { - override def hashCode(): Int = h -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala new file mode 100644 index 0000000000000..1253963087bd7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection + +import org.scalatest.FunSuite + +import org.apache.spark.{SparkContext, SparkConf, LocalSparkContext} +import org.apache.spark.SparkContext._ +import scala.collection.mutable.ArrayBuffer + +class ExternalSorterSuite extends FunSuite with LocalSparkContext { + + test("spilling in local cluster") { + val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + // reduceByKey - should spill ~8 times + val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) + val resultA = rddA.reduceByKey(math.max).collect() + assert(resultA.length == 50000) + resultA.foreach { case(k, v) => + k match { + case 0 => assert(v == 1) + case 25000 => assert(v == 50001) + case 49999 => assert(v == 99999) + case _ => + } + } + + // groupByKey - should spill ~17 times + val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultB = rddB.groupByKey().collect() + assert(resultB.length == 25000) + resultB.foreach { case(i, seq) => + i match { + case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3)) + case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003)) + case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999)) + case _ => + } + } + + // cogroup - should spill ~7 times + val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) + val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) + val resultC = rddC1.cogroup(rddC2).collect() + assert(resultC.length == 10000) + resultC.foreach { case(i, (seq1, seq2)) => + i match { + case 0 => + assert(seq1.toSet == Set[Int](0)) + assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) + case 5000 => + assert(seq1.toSet == Set[Int](5000)) + assert(seq2.toSet == Set[Int]()) + case 9999 => + assert(seq1.toSet == Set[Int](9999)) + assert(seq2.toSet == Set[Int]()) + case _ => + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala new file mode 100644 index 0000000000000..c787b5f066e00 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +/** + * A dummy class that always returns the same hash code, to easily test hash collisions + */ +case class FixedHashObject(v: Int, h: Int) extends Serializable { + override def hashCode(): Int = h +}