[SPARK-28607][CORE][SHUFFLE] Don't store partition lengths twice
The shuffle writer API introduced in SPARK-28209 has a flaw that leads to a memory usage regression: we ended up tracking the partition lengths in two places. Here, we modify the API slightly to avoid the redundant tracking. The shuffle writer plugin implementation is now responsible for tracking the lengths of its partitions and for propagating them back up to the higher-level shuffle writer as part of the commitAllPartitions API.

Existing unit tests.

Closes apache#25341 from mccheah/dont-redundantly-store-part-lengths.

Authored-by: mcheah <[email protected]>
Signed-off-by: Marcelo Vanzin <[email protected]>
mccheah committed Sep 10, 2019
1 parent 32d9b69 commit 2f75608
Showing 8 changed files with 110 additions and 89 deletions.
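
For context, here is a minimal caller-side sketch (illustrative only, not code from this commit) of what the reworked contract looks like: the plugin reports the per-partition lengths it already tracks through the commit message, so the calling shuffle writer no longer keeps its own long[] copy. Only API types that appear in the diff below are used; the helper class name is invented.

// Illustrative sketch only, not part of this commit.
import java.io.IOException;

import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;

final class CommitCallPattern {

  // Before this change, the caller filled in its own long[] of partition lengths and
  // commitAllPartitions() returned only an optional location.
  // After this change, the plugin returns the lengths inside the commit message.
  static long[] commitAndGetLengths(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
    MapOutputWriterCommitMessage commitMessage = mapOutputWriter.commitAllPartitions();
    return commitMessage.getPartitionLengths();
  }
}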

MapOutputWriterCommitMessage.java (new file)
@@ -0,0 +1,35 @@
package org.apache.spark.shuffle.api;

import java.util.Optional;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.storage.BlockManagerId;

@Experimental
public final class MapOutputWriterCommitMessage {

  private final long[] partitionLengths;
  private final Optional<BlockManagerId> location;

  private MapOutputWriterCommitMessage(long[] partitionLengths, Optional<BlockManagerId> location) {
    this.partitionLengths = partitionLengths;
    this.location = location;
  }

  public static MapOutputWriterCommitMessage of(long[] partitionLengths) {
    return new MapOutputWriterCommitMessage(partitionLengths, Optional.empty());
  }

  public static MapOutputWriterCommitMessage of(
      long[] partitionLengths, java.util.Optional<BlockManagerId> location) {
    return new MapOutputWriterCommitMessage(partitionLengths, location);
  }

  public long[] getPartitionLengths() {
    return partitionLengths;
  }

  public Optional<BlockManagerId> getLocation() {
    return location;
  }
}
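
A brief usage note, illustrative rather than part of the commit: the two of(...) factories let an implementation report just the lengths, or the lengths together with the BlockManagerId that reducers should fetch the bytes from. The helper class below is invented for illustration.

// Illustrative only; this helper class is not part of the commit.
import java.util.Optional;

import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
import org.apache.spark.storage.BlockManagerId;

final class CommitMessageExamples {

  // No explicit location: getLocation() is Optional.empty(), so callers fall back
  // to their own notion of where the data lives (e.g. the local block manager).
  static MapOutputWriterCommitMessage lengthsOnly(long[] partitionLengths) {
    return MapOutputWriterCommitMessage.of(partitionLengths);
  }

  // With a location: callers can build a MapStatus that points reducers at it.
  static MapOutputWriterCommitMessage withLocation(
      long[] partitionLengths, BlockManagerId location) {
    return MapOutputWriterCommitMessage.of(partitionLengths, Optional.of(location));
  }
}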

ShuffleMapOutputWriter.java
@@ -51,15 +51,24 @@ public interface ShuffleMapOutputWriter {

/**
* Commits the writes done by all partition writers returned by all calls to this object's
* {@link #getPartitionWriter(int)}.
* {@link #getPartitionWriter(int)}, and returns a bundle of metadata associated with the
* behavior of the write.
* <p>
* This should ensure that the writes conducted by this module's partition writers are
* available to downstream reduce tasks. If this method throws any exception, this module's
* {@link #abort(Throwable)} method will be invoked before propagating the exception.
* <p>
* This can also close any resources and clean up temporary state if necessary.
* <p>
* The returned commit message should contain two pieces of metadata:
*
* 1. For each partition from (0) to (numPartitions - 1), the number of bytes written by
* the partition writer for that partition id.
*
* 2. If the partition data was stored on the local disk of this executor, also provide
* the block manager id where these bytes can be fetched from.
*/
Optional<BlockManagerId> commitAllPartitions() throws IOException;
MapOutputWriterCommitMessage commitAllPartitions() throws IOException;

/**
* Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}.
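
To make that contract concrete, here is a hedged sketch of how a hypothetical plugin might satisfy it; the class, constructor, and recordPartitionLength helper below are invented for illustration and are not part of this change. The idea is that the implementation records each partition's byte count itself and hands both the lengths and an optional location back from commitAllPartitions().

// Hypothetical plugin sketch, not from the diff. Assumes the concrete subclass
// records per-partition byte counts as its partition writers are closed.
import java.io.IOException;
import java.util.Optional;

import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.storage.BlockManagerId;

abstract class ExampleMapOutputWriter implements ShuffleMapOutputWriter {

  private final long[] partitionLengths;           // one slot per reduce partition
  private final Optional<BlockManagerId> location; // where the bytes can be fetched, if known

  ExampleMapOutputWriter(int numPartitions, Optional<BlockManagerId> location) {
    this.partitionLengths = new long[numPartitions];
    this.location = location;
  }

  // The concrete getPartitionWriter(int) implementation would call this once a
  // partition has been fully written.
  protected void recordPartitionLength(int reducePartitionId, long numBytes) {
    partitionLengths[reducePartitionId] = numBytes;
  }

  @Override
  public MapOutputWriterCommitMessage commitAllPartitions() throws IOException {
    // Make the written data visible to reduce tasks here (flush, rename, upload, ...),
    // then report the lengths the plugin tracked instead of having the caller track them too.
    return MapOutputWriterCommitMessage.of(partitionLengths, location);
  }
}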

BypassMergeSortShuffleWriter.java
@@ -21,13 +21,10 @@
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.util.Optional;
import javax.annotation.Nullable;

import org.apache.spark.api.java.Optional;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import scala.None$;
import scala.Option;
import scala.Product2;
@@ -42,6 +39,7 @@
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
@@ -97,7 +95,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private DiskBlockObjectWriter[] partitionWriters;
private FileSegment[] partitionWriterSegments;
@Nullable private MapStatus mapStatus;
private long[] partitionLengths;
private MapOutputWriterCommitMessage commitMessage;

/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -122,7 +120,6 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.mapId = mapId;
this.mapTaskAttemptId = mapTaskAttemptId;
this.shuffleId = dep.shuffleId();
this.mapTaskAttemptId = mapTaskAttemptId;
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = writeMetrics;
@@ -137,11 +134,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
.createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions);
try {
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
mapOutputWriter.commitAllPartitions();
commitMessage = mapOutputWriter.commitAllPartitions();
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(),
partitionLengths);
commitMessage.getLocation().orElse(null),
commitMessage.getPartitionLengths(),
mapTaskAttemptId);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -173,9 +170,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
}
}

partitionLengths = writePartitionedData(mapOutputWriter);
mapOutputWriter.commitAllPartitions();
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
commitMessage = writePartitionedData(mapOutputWriter);
mapStatus = MapStatus$.MODULE$.apply(
commitMessage.getLocation().orElse(null),
commitMessage.getPartitionLengths(),
mapTaskAttemptId);
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
@@ -189,50 +188,47 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {

@VisibleForTesting
long[] getPartitionLengths() {
return partitionLengths;
return commitMessage.getPartitionLengths();
}

/**
* Concatenate all of the per-partition files into a single combined file.
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
private MapOutputWriterCommitMessage writePartitionedData(
ShuffleMapOutputWriter mapOutputWriter) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
if (partitionWriters == null) {
// We were passed an empty iterator
return lengths;
}
final long writeStartTime = System.nanoTime();
try {
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
if (file.exists()) {
if (transferToEnabled) {
// Using WritableByteChannelWrapper to make resource closing consistent between
// this implementation and UnsafeShuffleWriter.
Optional<WritableByteChannelWrapper> maybeOutputChannel = writer.openChannelWrapper();
if (maybeOutputChannel.isPresent()) {
writePartitionedDataWithChannel(file, maybeOutputChannel.get());
if (partitionWriters != null) {
final long writeStartTime = System.nanoTime();
try {
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
if (file.exists()) {
if (transferToEnabled) {
// Using WritableByteChannelWrapper to make resource closing consistent between
// this implementation and UnsafeShuffleWriter.
Optional<WritableByteChannelWrapper> maybeOutputChannel = writer.openChannelWrapper();
if (maybeOutputChannel.isPresent()) {
writePartitionedDataWithChannel(file, maybeOutputChannel.get());
} else {
writePartitionedDataWithStream(file, writer);
}
} else {
writePartitionedDataWithStream(file, writer);
}
} else {
writePartitionedDataWithStream(file, writer);
}
if (!file.delete()) {
logger.error("Unable to delete file for partition {}", i);
if (!file.delete()) {
logger.error("Unable to delete file for partition {}", i);
}
}
}
lengths[i] = writer.getNumBytesWritten();
} finally {
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
} finally {
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
partitionWriters = null;
}
partitionWriters = null;
return lengths;
return mapOutputWriter.commitAllPartitions();
}

private void writePartitionedDataWithChannel(

UnsafeShuffleWriter.java
@@ -23,9 +23,6 @@
import java.nio.channels.FileChannel;
import java.util.Iterator;

import org.apache.spark.api.java.Optional;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.storage.BlockManagerId;
import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
@@ -40,10 +37,12 @@

import org.apache.spark.*;
import org.apache.spark.annotation.Private;
import org.apache.spark.shuffle.api.TransferrableWritableByteChannel;
import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SupportsTransferTo;
import org.apache.spark.shuffle.api.TransferrableWritableByteChannel;
import org.apache.spark.internal.config.package$;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
@@ -222,19 +221,18 @@ void closeAndWriteOutput() throws IOException {
mapId,
taskContext.taskAttemptId(),
partitioner.numPartitions());
final long[] partitionLengths;
Optional<BlockManagerId> location;
MapOutputWriterCommitMessage commitMessage;
try {
try {
partitionLengths = mergeSpills(spills, mapWriter);
mergeSpills(spills, mapWriter);
} finally {
for (SpillInfo spill : spills) {
if (spill.file.exists() && !spill.file.delete()) {
logger.error("Error while deleting spill file {}", spill.file.getPath());
}
}
}
location = mapWriter.commitAllPartitions();
commitMessage = mapWriter.commitAllPartitions();
} catch (Exception e) {
try {
mapWriter.abort(e);
@@ -244,7 +242,9 @@ void closeAndWriteOutput() throws IOException {
throw e;
}
mapStatus = MapStatus$.MODULE$.apply(
location.orNull(), partitionLengths, taskContext.attemptNumber());
commitMessage.getLocation().orElse(null),
commitMessage.getPartitionLengths(),
taskContext.attemptNumber());
}

@VisibleForTesting
@@ -276,7 +276,7 @@ void forceSorterToSpill() throws IOException {
*
* @return the partition lengths in the merged file.
*/
private long[] mergeSpills(SpillInfo[] spills,
private void mergeSpills(SpillInfo[] spills,
ShuffleMapOutputWriter mapWriter) throws IOException {
final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
@@ -285,12 +285,8 @@ private long[] mergeSpills(SpillInfo[] spills,
final boolean fastMergeIsSupported = !compressionEnabled ||
CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
final int numPartitions = partitioner.numPartitions();
long[] partitionLengths = new long[numPartitions];
try {
if (spills.length == 0) {
return partitionLengths;
} else {
if (spills.length > 0) {
// There are multiple spills to merge, so none of these spill files' lengths were counted
// towards our shuffle write count or shuffle write time. If we use the slow merge path,
// then the final output file's size won't necessarily be equal to the sum of the spill
@@ -307,22 +303,21 @@ private long[] mergeSpills(SpillInfo[] spills,
// that doesn't need to interpret the spilled bytes.
if (transferToEnabled && !encryptionEnabled) {
logger.debug("Using transferTo-based fast merge");
partitionLengths = mergeSpillsWithTransferTo(spills, mapWriter);
mergeSpillsWithTransferTo(spills, mapWriter);
} else {
logger.debug("Using fileStream-based fast merge");
partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, null);
mergeSpillsWithFileStream(spills, mapWriter, null);
}
} else {
logger.debug("Using slow merge");
partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
}
// When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
// in-memory records, we write out the in-memory records to a file but do not count that
// final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
// to be counted as shuffle write, but this will lead to double-counting of the final
// SpillInfo's bytes.
writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
return partitionLengths;
}
} catch (IOException e) {
throw e;
@@ -345,12 +340,11 @@ private long[] mergeSpills(SpillInfo[] spills,
* @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
* @return the partition lengths in the merged file.
*/
private long[] mergeSpillsWithFileStream(
private void mergeSpillsWithFileStream(
SpillInfo[] spills,
ShuffleMapOutputWriter mapWriter,
@Nullable CompressionCodec compressionCodec) throws IOException {
final int numPartitions = partitioner.numPartitions();
final long[] partitionLengths = new long[numPartitions];
final InputStream[] spillInputStreams = new InputStream[spills.length];

boolean threwException = true;
@@ -395,7 +389,6 @@ private long[] mergeSpillsWithFileStream(
Closeables.close(partitionOutput, copyThrewExecption);
}
long numBytesWritten = writer.getNumBytesWritten();
partitionLengths[partition] = numBytesWritten;
writeMetrics.incBytesWritten(numBytesWritten);
}
threwException = false;
@@ -406,7 +399,6 @@ private long[] mergeSpillsWithFileStream(
Closeables.close(stream, threwException);
}
}
return partitionLengths;
}

/**
@@ -418,11 +410,10 @@ private long[] mergeSpillsWithFileStream(
* @param mapWriter the map output writer to use for output.
* @return the partition lengths in the merged file.
*/
private long[] mergeSpillsWithTransferTo(
private void mergeSpillsWithTransferTo(
SpillInfo[] spills,
ShuffleMapOutputWriter mapWriter) throws IOException {
final int numPartitions = partitioner.numPartitions();
final long[] partitionLengths = new long[numPartitions];
final FileChannel[] spillInputChannels = new FileChannel[spills.length];
final long[] spillInputChannelPositions = new long[spills.length];

@@ -455,7 +446,6 @@ private long[] mergeSpillsWithTransferTo(
Closeables.close(partitionChannel, copyThrewExecption);
}
long numBytes = writer.getNumBytesWritten();
partitionLengths[partition] = numBytes;
writeMetrics.incBytesWritten(numBytes);
}
threwException = false;
@@ -467,7 +457,6 @@ private long[] mergeSpillsWithTransferTo(
Closeables.close(spillInputChannels[i], threwException);
}
}
return partitionLengths;
}

@Override

LocalDiskShuffleMapOutputWriter.java
@@ -96,10 +96,11 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException {
}

@Override
public void commitAllPartitions() throws IOException {
public long[] commitAllPartitions() throws IOException {
cleanUp();
File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
return partitionLengths;
}

@Override
(Diff truncated: the remaining changed files are not shown here.)
