From 2f756088d8b13438c393bb5426076793a50c471b Mon Sep 17 00:00:00 2001
From: mcheah
Date: Mon, 26 Aug 2019 10:39:29 -0700
Subject: [PATCH] [SPARK-28607][CORE][SHUFFLE] Don't store partition lengths twice

The shuffle writer API introduced in SPARK-28209 has a flaw that leads to a
memory usage regression - we ended up tracking the partition lengths in two
places. Here, we modify the API slightly to avoid redundant tracking. The
implementation of the shuffle writer plugin is now responsible for tracking
the lengths of partitions, and propagating this back up to the higher shuffle
writer as part of the commitAllPartitions API.

Existing unit tests.

Closes #25341 from mccheah/dont-redundantly-store-part-lengths.

Authored-by: mcheah
Signed-off-by: Marcelo Vanzin
---
 .../api/MapOutputWriterCommitMessage.java     | 35 +++++++++
 .../shuffle/api/ShuffleMapOutputWriter.java   | 13 +++-
 .../sort/BypassMergeSortShuffleWriter.java    | 78 +++++++++----------
 .../shuffle/sort/UnsafeShuffleWriter.java     | 43 ++++------
 .../io/LocalDiskShuffleMapOutputWriter.java   |  3 +-
 .../shuffle/sort/SortShuffleWriter.scala      |  9 ++-
 .../util/collection/ExternalSorter.scala      | 12 +--
 ...LocalDiskShuffleMapOutputWriterSuite.scala |  6 +-
 8 files changed, 110 insertions(+), 89 deletions(-)
 create mode 100644 core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java

diff --git a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java
new file mode 100644
index 0000000000000..e07efd57cc07f
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java
@@ -0,0 +1,35 @@
+package org.apache.spark.shuffle.api;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.storage.BlockManagerId;
+
+@Experimental
+public final class MapOutputWriterCommitMessage {
+
+  private final long[] partitionLengths;
+  private final Optional<BlockManagerId> location;
+
+  private MapOutputWriterCommitMessage(long[] partitionLengths, Optional<BlockManagerId> location) {
+    this.partitionLengths = partitionLengths;
+    this.location = location;
+  }
+
+  public static MapOutputWriterCommitMessage of(long[] partitionLengths) {
+    return new MapOutputWriterCommitMessage(partitionLengths, Optional.empty());
+  }
+
+  public static MapOutputWriterCommitMessage of(
+      long[] partitionLengths, java.util.Optional<BlockManagerId> location) {
+    return new MapOutputWriterCommitMessage(partitionLengths, location);
+  }
+
+  public long[] getPartitionLengths() {
+    return partitionLengths;
+  }
+
+  public Optional<BlockManagerId> getLocation() {
+    return location;
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
index 9135293636e90..8fcc73ba3c9b2 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
@@ -51,15 +51,24 @@ public interface ShuffleMapOutputWriter {
 
   /**
    * Commits the writes done by all partition writers returned by all calls to this object's
-   * {@link #getPartitionWriter(int)}.
+   * {@link #getPartitionWriter(int)}, and returns a bundle of metadata associated with the
+   * behavior of the write.
    * <p>
    * This should ensure that the writes conducted by this module's partition writers are
    * available to downstream reduce tasks. If this method throws any exception, this module's
    * {@link #abort(Throwable)} method will be invoked before propagating the exception.
    * <p>
    * This can also close any resources and clean up temporary state if necessary.
+   * <p>
+   * The returned commit message should contain two sets of metadata:
+   *
+   * 1. For each partition from (0) to (numPartitions - 1), the number of bytes written by
+   * the partition writer for that partition id.
+   *
+   * 2. If the partition data was stored on the local disk of this executor, also provide
+   * the block manager id where these bytes can be fetched from.
    */
-  Optional<BlockManagerId> commitAllPartitions() throws IOException;
+  MapOutputWriterCommitMessage commitAllPartitions() throws IOException;
 
   /**
    * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}.
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index d6cc1d500e3d1..94ad5fc66185b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -21,13 +21,10 @@
 import java.io.FileInputStream;
 import java.io.IOException;
 import java.io.OutputStream;
-import java.nio.channels.Channels;
 import java.nio.channels.FileChannel;
 import java.util.Optional;
 import javax.annotation.Nullable;
 
-import org.apache.spark.api.java.Optional;
-import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
 import scala.None$;
 import scala.Option;
 import scala.Product2;
@@ -42,6 +39,7 @@
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
 import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
@@ -97,7 +95,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private DiskBlockObjectWriter[] partitionWriters;
   private FileSegment[] partitionWriterSegments;
   @Nullable private MapStatus mapStatus;
-  private long[] partitionLengths;
+  private MapOutputWriterCommitMessage commitMessage;
 
   /**
   * Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -122,7 +120,6 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.mapId = mapId;
     this.mapTaskAttemptId = mapTaskAttemptId;
     this.shuffleId = dep.shuffleId();
-    this.mapTaskAttemptId = mapTaskAttemptId;
     this.partitioner = dep.partitioner();
     this.numPartitions = partitioner.numPartitions();
     this.writeMetrics = writeMetrics;
@@ -137,11 +134,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
         .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions);
     try {
       if (!records.hasNext()) {
-        partitionLengths = new long[numPartitions];
-        mapOutputWriter.commitAllPartitions();
+        commitMessage = mapOutputWriter.commitAllPartitions();
         mapStatus = MapStatus$.MODULE$.apply(
-          blockManager.shuffleServerId(),
-          partitionLengths);
+          commitMessage.getLocation().orElse(null),
+          commitMessage.getPartitionLengths(),
+          mapTaskAttemptId);
         return;
       }
       final SerializerInstance serInstance = serializer.newInstance();
@@ -173,9 +170,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
         }
       }
 
-      partitionLengths = writePartitionedData(mapOutputWriter);
-      mapOutputWriter.commitAllPartitions();
-      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+      commitMessage = writePartitionedData(mapOutputWriter);
+      mapStatus = MapStatus$.MODULE$.apply(
+        commitMessage.getLocation().orElse(null),
+        commitMessage.getPartitionLengths(),
+        mapTaskAttemptId);
     } catch (Exception e) {
       try {
         mapOutputWriter.abort(e);
@@ -189,7 +188,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
 
   @VisibleForTesting
   long[] getPartitionLengths() {
-    return partitionLengths;
+    return commitMessage.getPartitionLengths();
   }
 
   /**
@@ -197,42 +196,39 @@ long[] getPartitionLengths() {
    *
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
    */
-  private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
+  private MapOutputWriterCommitMessage writePartitionedData(
+      ShuffleMapOutputWriter mapOutputWriter) throws IOException {
     // Track location of the partition starts in the output file
-    final long[] lengths = new long[numPartitions];
-    if (partitionWriters == null) {
-      // We were passed an empty iterator
-      return lengths;
-    }
-    final long writeStartTime = System.nanoTime();
-    try {
-      for (int i = 0; i < numPartitions; i++) {
-        final File file = partitionWriterSegments[i].file();
-        ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
-        if (file.exists()) {
-          if (transferToEnabled) {
-            // Using WritableByteChannelWrapper to make resource closing consistent between
-            // this implementation and UnsafeShuffleWriter.
-            Optional<WritableByteChannelWrapper> maybeOutputChannel = writer.openChannelWrapper();
-            if (maybeOutputChannel.isPresent()) {
-              writePartitionedDataWithChannel(file, maybeOutputChannel.get());
+    if (partitionWriters != null) {
+      final long writeStartTime = System.nanoTime();
+      try {
+        for (int i = 0; i < numPartitions; i++) {
+          final File file = partitionWriterSegments[i].file();
+          ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
+          if (file.exists()) {
+            if (transferToEnabled) {
+              // Using WritableByteChannelWrapper to make resource closing consistent between
+              // this implementation and UnsafeShuffleWriter.
+              Optional<WritableByteChannelWrapper> maybeOutputChannel = writer.openChannelWrapper();
+              if (maybeOutputChannel.isPresent()) {
+                writePartitionedDataWithChannel(file, maybeOutputChannel.get());
+              } else {
+                writePartitionedDataWithStream(file, writer);
+              }
             } else {
               writePartitionedDataWithStream(file, writer);
             }
-          } else {
-            writePartitionedDataWithStream(file, writer);
-          }
-          if (!file.delete()) {
-            logger.error("Unable to delete file for partition {}", i);
+            if (!file.delete()) {
+              logger.error("Unable to delete file for partition {}", i);
+            }
           }
         }
-        lengths[i] = writer.getNumBytesWritten();
+      } finally {
+        writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
       }
-    } finally {
-      writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
+      partitionWriters = null;
     }
-    partitionWriters = null;
-    return lengths;
+    return mapOutputWriter.commitAllPartitions();
   }
 
   private void writePartitionedDataWithChannel(
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 441718126bc92..745f4785ce01b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -23,9 +23,6 @@
 import java.nio.channels.FileChannel;
 import java.util.Iterator;
 
-import org.apache.spark.api.java.Optional;
-import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
-import org.apache.spark.storage.BlockManagerId;
 import scala.Option;
 import scala.Product2;
 import scala.collection.JavaConverters;
@@ -40,10 +37,12 @@
 import org.apache.spark.*;
 import org.apache.spark.annotation.Private;
-import org.apache.spark.shuffle.api.TransferrableWritableByteChannel;
+import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
 import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
 import org.apache.spark.shuffle.api.SupportsTransferTo;
+import org.apache.spark.shuffle.api.TransferrableWritableByteChannel;
 import org.apache.spark.internal.config.package$;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
@@ -222,11 +221,10 @@ void closeAndWriteOutput() throws IOException {
         mapId,
         taskContext.taskAttemptId(),
         partitioner.numPartitions());
-    final long[] partitionLengths;
-    Optional<BlockManagerId> location;
+    MapOutputWriterCommitMessage commitMessage;
     try {
       try {
-        partitionLengths = mergeSpills(spills, mapWriter);
+        mergeSpills(spills, mapWriter);
       } finally {
         for (SpillInfo spill : spills) {
           if (spill.file.exists() && !spill.file.delete()) {
          }
        }
      }
-      location = mapWriter.commitAllPartitions();
+      commitMessage = mapWriter.commitAllPartitions();
     } catch (Exception e) {
       try {
         mapWriter.abort(e);
@@ -244,7 +242,9 @@ void closeAndWriteOutput() throws IOException {
       throw e;
     }
     mapStatus = MapStatus$.MODULE$.apply(
-        location.orNull(), partitionLengths, taskContext.attemptNumber());
+        commitMessage.getLocation().orElse(null),
+        commitMessage.getPartitionLengths(),
+        taskContext.attemptNumber());
   }
 
   @VisibleForTesting
@@ -276,7 +276,7 @@ void forceSorterToSpill() throws IOException {
    *
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpills(SpillInfo[] spills,
+  private void mergeSpills(SpillInfo[] spills,
       ShuffleMapOutputWriter mapWriter) throws IOException {
     final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
     final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
@@ -285,12 +285,8 @@ private long[] mergeSpills(SpillInfo[] spills,
     final boolean fastMergeIsSupported = !compressionEnabled ||
       CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
     final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
-    final int numPartitions = partitioner.numPartitions();
-    long[] partitionLengths = new long[numPartitions];
     try {
-      if (spills.length == 0) {
-        return partitionLengths;
-      } else {
+      if (spills.length > 0) {
         // There are multiple spills to merge, so none of these spill files' lengths were counted
         // towards our shuffle write count or shuffle write time. If we use the slow merge path,
         // then the final output file's size won't necessarily be equal to the sum of the spill
@@ -307,14 +303,14 @@ private long[] mergeSpills(SpillInfo[] spills,
           // that doesn't need to interpret the spilled bytes.
           if (transferToEnabled && !encryptionEnabled) {
             logger.debug("Using transferTo-based fast merge");
-            partitionLengths = mergeSpillsWithTransferTo(spills, mapWriter);
+            mergeSpillsWithTransferTo(spills, mapWriter);
           } else {
             logger.debug("Using fileStream-based fast merge");
-            partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, null);
+            mergeSpillsWithFileStream(spills, mapWriter, null);
           }
         } else {
           logger.debug("Using slow merge");
-          partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
+          mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
         }
         // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
         // in-memory records, we write out the in-memory records to a file but do not count that
         // to be counted as shuffle write, but this will lead to double-counting of the final
         // SpillInfo's bytes.
         writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
-        return partitionLengths;
       }
     } catch (IOException e) {
       throw e;
@@ -345,12 +340,11 @@ private long[] mergeSpills(SpillInfo[] spills,
    * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpillsWithFileStream(
+  private void mergeSpillsWithFileStream(
       SpillInfo[] spills,
       ShuffleMapOutputWriter mapWriter,
       @Nullable CompressionCodec compressionCodec) throws IOException {
     final int numPartitions = partitioner.numPartitions();
-    final long[] partitionLengths = new long[numPartitions];
     final InputStream[] spillInputStreams = new InputStream[spills.length];
 
     boolean threwException = true;
@@ -395,7 +389,6 @@ private long[] mergeSpillsWithFileStream(
           Closeables.close(partitionOutput, copyThrewExecption);
         }
         long numBytesWritten = writer.getNumBytesWritten();
-        partitionLengths[partition] = numBytesWritten;
         writeMetrics.incBytesWritten(numBytesWritten);
       }
       threwException = false;
@@ -406,7 +399,6 @@ private long[] mergeSpillsWithFileStream(
         Closeables.close(stream, threwException);
       }
     }
-    return partitionLengths;
   }
 
   /**
@@ -418,11 +410,10 @@ private long[] mergeSpillsWithFileStream(
    * @param mapWriter the map output writer to use for output.
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpillsWithTransferTo(
+  private void mergeSpillsWithTransferTo(
       SpillInfo[] spills,
       ShuffleMapOutputWriter mapWriter) throws IOException {
     final int numPartitions = partitioner.numPartitions();
-    final long[] partitionLengths = new long[numPartitions];
     final FileChannel[] spillInputChannels = new FileChannel[spills.length];
     final long[] spillInputChannelPositions = new long[spills.length];
@@ -455,7 +446,6 @@ private long[] mergeSpillsWithTransferTo(
           Closeables.close(partitionChannel, copyThrewExecption);
         }
         long numBytes = writer.getNumBytesWritten();
-        partitionLengths[partition] = numBytes;
         writeMetrics.incBytesWritten(numBytes);
       }
       threwException = false;
@@ -467,7 +457,6 @@ private long[] mergeSpillsWithTransferTo(
         Closeables.close(spillInputChannels[i], threwException);
       }
     }
-    return partitionLengths;
   }
 
   @Override
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
index add4634a61fb5..7fc19b1270a46 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
@@ -96,10 +96,11 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I
   }
 
   @Override
-  public void commitAllPartitions() throws IOException {
+  public long[] commitAllPartitions() throws IOException {
     cleanUp();
     File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
     blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
+    return partitionLengths;
   }
 
   @Override
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index f0d3368d0a58d..626f5fd91c291 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -66,9 +66,12 @@ private[spark] class SortShuffleWriter[K, V, C](
       // (see SPARK-3570).
     val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
       dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions)
-    val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
-    val location = mapOutputWriter.commitAllPartitions
-    mapStatus = MapStatus(location.orNull, partitionLengths, context.taskAttemptId())
+    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
+    val commitMessage = mapOutputWriter.commitAllPartitions
+    mapStatus = MapStatus(
+      commitMessage.getLocation().orElse(null),
+      commitMessage.getPartitionLengths,
+      context.taskAttemptId())
   }
 
   /** Close this writer, passing along whether the map completed */
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 0c1af50e73fcf..2f967a3cdfae0 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -729,9 +729,7 @@ private[spark] class ExternalSorter[K, V, C](
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
    */
   def writePartitionedMapOutput(
-      shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = {
-    // Track location of each range in the map output
-    val lengths = new Array[Long](numPartitions)
+      shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Unit = {
     if (spills.isEmpty) {
       // Case where we only have in-memory data
       val collection = if (aggregator.isDefined) map else buffer
@@ -757,9 +755,6 @@ private[spark] class ExternalSorter[K, V, C](
             partitionPairsWriter.close()
           }
         }
-        if (partitionWriter != null) {
-          lengths(partitionId) = partitionWriter.getNumBytesWritten
-        }
       }
     } else {
       // We must perform merge-sort; get an iterator by partition and write everything directly.
@@ -791,17 +786,12 @@ private[spark] class ExternalSorter[K, V, C](
             partitionPairsWriter.close()
           }
         }
-        if (partitionWriter != null) {
-          lengths(id) = partitionWriter.getNumBytesWritten
-        }
       }
     }
 
     context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
     context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
     context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
-
-    lengths
   }
 
   def stop(): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
index 5693b9824523a..5156cc2cc47a6 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
@@ -102,7 +102,6 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
       intercept[IllegalStateException] {
         stream.write(p)
       }
-      assert(writer.getNumBytesWritten === data(p).length)
     }
     verifyWrittenRecords()
   }
@@ -122,8 +121,6 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
             tempFileInput.getChannel, channelWrapper.channel(), 0L, data(p).length)
         }
       }
-      assert(writer.getNumBytesWritten === data(p).length,
-        s"Partition $p does not have the correct number of bytes.")
     }
     verifyWrittenRecords()
   }
@@ -139,8 +136,9 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
   }
 
   private def verifyWrittenRecords(): Unit = {
-    mapOutputWriter.commitAllPartitions()
+    val committedLengths = mapOutputWriter.commitAllPartitions()
     assert(partitionSizesInMergedFile === partitionLengths)
+    assert(committedLengths === partitionLengths)
     assert(mergedOutputFile.length() === partitionLengths.sum)
     assert(data === readRecordsFromFile())
   }
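For illustration only - not part of the patch: every hunk above follows one pattern: the ShuffleMapOutputWriter implementation becomes the single owner of the per-partition lengths and hands them back through MapOutputWriterCommitMessage from commitAllPartitions(), instead of the calling shuffle writer keeping a second long[] copy. The sketch below shows what that handoff might look like for a hypothetical plugin. The class name and its fields are invented, and it assumes ShufflePartitionWriter exposes the stream-based openStream() and getNumBytesWritten() methods from SPARK-28209.

// Hypothetical sketch, not from the patch: a minimal in-memory plugin writer that
// tracks its own partition lengths and reports them via MapOutputWriterCommitMessage.
import java.io.ByteArrayOutputStream;
import java.io.OutputStream;

import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;

public final class InMemoryShuffleMapOutputWriter implements ShuffleMapOutputWriter {

  // The plugin, not the calling ShuffleWriter, owns the per-partition lengths.
  private final long[] partitionLengths;
  private final ByteArrayOutputStream[] partitionBytes;

  public InMemoryShuffleMapOutputWriter(int numPartitions) {
    this.partitionLengths = new long[numPartitions];
    this.partitionBytes = new ByteArrayOutputStream[numPartitions];
  }

  @Override
  public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) {
    partitionBytes[reducePartitionId] = new ByteArrayOutputStream();
    return new ShufflePartitionWriter() {
      @Override
      public OutputStream openStream() {
        return partitionBytes[reducePartitionId];
      }

      @Override
      public long getNumBytesWritten() {
        // Record the length here so commitAllPartitions() can report it later.
        partitionLengths[reducePartitionId] = partitionBytes[reducePartitionId].size();
        return partitionLengths[reducePartitionId];
      }
    };
  }

  @Override
  public MapOutputWriterCommitMessage commitAllPartitions() {
    // An in-memory writer has no local-disk location to advertise, so use the
    // single-argument factory; a disk-backed writer would pass a BlockManagerId
    // through the two-argument MapOutputWriterCommitMessage.of(...) overload.
    return MapOutputWriterCommitMessage.of(partitionLengths);
  }

  @Override
  public void abort(Throwable error) {
    // Nothing to clean up in this sketch.
  }
}

Against such a writer, the calling code (as in the modified BypassMergeSortShuffleWriter above) builds its MapStatus solely from commitMessage.getPartitionLengths() and commitMessage.getLocation(), with no separate lengths array of its own.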