Cleanup in UnsafeShuffleWriter

JoshRosen committed May 12, 2015
1 parent 4a2c785 commit e3b8855

Showing 2 changed files with 48 additions and 22 deletions.

UnsafeShuffleWriter.java

@@ -102,7 +102,7 @@ public UnsafeShuffleWriter(
       UnsafeShuffleHandle<K, V> handle,
       int mapId,
       TaskContext taskContext,
-      SparkConf sparkConf) {
+      SparkConf sparkConf) throws IOException {
     final int numPartitions = handle.dependency().partitioner().numPartitions();
     if (numPartitions > PackedRecordPointer.MAXIMUM_PARTITION_ID) {
       throw new IllegalArgumentException(
@@ -123,27 +123,29 @@ public UnsafeShuffleWriter(
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
+    open();
   }
 
+  /**
+   * This convenience method should only be called in test code.
+   */
+  @VisibleForTesting
   public void write(Iterator<Product2<K, V>> records) throws IOException {
     write(JavaConversions.asScalaIterator(records));
   }
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+    boolean success = false;
     try {
       while (records.hasNext()) {
         insertRecordIntoSorter(records.next());
       }
       closeAndWriteOutput();
-    } catch (Exception e) {
-      // Unfortunately, we have to catch Exception here in order to ensure proper cleanup after
-      // errors because Spark's Scala code, or users' custom Serializers, might throw arbitrary
-      // unchecked exceptions.
-      try {
+      success = true;
+    } finally {
+      if (!success) {
         sorter.cleanupAfterError();
-      } finally {
-        throw new IOException("Error during shuffle write", e);
       }
     }
   }
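
The rewrite above drops the catch-Exception wrapper in favor of a success flag checked in finally: cleanup still runs on any failure, but the original exception now propagates unwrapped instead of being masked by a new IOException. A minimal standalone sketch of the idiom, with a hypothetical Resource standing in for the writer's sorter:

    import java.io.IOException;

    // Hypothetical stand-in for the writer's sorter.
    class Resource {
      void insert(int record) { /* buffer the record */ }
      void commit() throws IOException { /* flush buffered output */ }
      void cleanupAfterError() { /* best-effort cleanup; must not throw */ }
    }

    class Writer {
      private final Resource resource = new Resource();

      public void write(int[] records) throws IOException {
        boolean success = false;
        try {
          for (int record : records) {
            resource.insert(record);
          }
          resource.commit();
          success = true;
        } finally {
          // Runs for checked and unchecked exceptions alike; the original
          // exception propagates instead of being wrapped.
          if (!success) {
            resource.cleanupAfterError();
          }
        }
      }
    }
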
@@ -165,9 +167,6 @@ private void open() throws IOException {
 
   @VisibleForTesting
   void closeAndWriteOutput() throws IOException {
-    if (sorter == null) {
-      open();
-    }
     serBuffer = null;
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
@@ -187,10 +186,7 @@ void closeAndWriteOutput() throws IOException {
   }
 
   @VisibleForTesting
-  void insertRecordIntoSorter(Product2<K, V> record) throws IOException{
-    if (sorter == null) {
-      open();
-    }
+  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
     final K key = record._1();
     final int partitionId = partitioner.getPartition(key);
     serBuffer.reset();
@@ -275,15 +271,29 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
     }
   }
 
+  /**
+   * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
+   * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
+   * cases where the IO compression codec does not support concatenation of compressed data, or in
+   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
+   * kernel bugs.
+   *
+   * @param spills the spills to merge.
+   * @param outputFile the file to write the merged data to.
+   * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
+   * @return the partition lengths in the merged file.
+   */
   private long[] mergeSpillsWithFileStream(
       SpillInfo[] spills,
       File outputFile,
       @Nullable CompressionCodec compressionCodec) throws IOException {
     assert (spills.length >= 2);
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final InputStream[] spillInputStreams = new FileInputStream[spills.length];
     OutputStream mergedFileOutputStream = null;
+
+    boolean threwException = true;
     try {
       for (int i = 0; i < spills.length; i++) {
         spillInputStreams[i] = new FileInputStream(spills[i].file);
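
The new javadoc above mentions that users can explicitly disable transferTo; per the constructor change earlier in this diff, that is governed by the spark.file.transferTo flag (default true). A one-line illustration of turning it off:

    // Forces the FileStream-based merge by disabling the NIO transferTo path.
    SparkConf conf = new SparkConf().set("spark.file.transferTo", "false");
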
@@ -311,22 +321,34 @@ private long[] mergeSpillsWithFileStream(
         mergedFileOutputStream.close();
         partitionLengths[partition] = (outputFile.length() - initialFileLength);
       }
+      threwException = false;
     } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
       for (InputStream stream : spillInputStreams) {
-        Closeables.close(stream, false);
+        Closeables.close(stream, threwException);
       }
-      Closeables.close(mergedFileOutputStream, false);
+      Closeables.close(mergedFileOutputStream, threwException);
     }
     return partitionLengths;
   }
 
+  /**
+   * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
+   * This is only safe when the IO compression codec and serializer support concatenation of
+   * serialized streams.
+   *
+   * @return the partition lengths in the merged file.
+   */
   private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
     assert (spills.length >= 2);
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final FileChannel[] spillInputChannels = new FileChannel[spills.length];
     final long[] spillInputChannelPositions = new long[spills.length];
     FileChannel mergedFileOutputChannel = null;
+
+    boolean threwException = true;
     try {
       for (int i = 0; i < spills.length; i++) {
         spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
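
Both merge paths now guard their cleanup with a threwException flag passed to Closeables.close (presumably Guava's com.google.common.io.Closeables), so a failing close() is swallowed only when it would otherwise mask the exception that aborted the merge. A minimal sketch of the pattern, with hypothetical in/out resources:

    import java.io.Closeable;
    import java.io.IOException;

    import com.google.common.io.Closeables;

    final class GuardedCopy {
      static void copy(Closeable in, Closeable out) throws IOException {
        boolean threwException = true;
        try {
          // ... perform the actual copy between in and out ...
          threwException = false;
        } finally {
          // If the try block failed (threwException == true), any IOException
          // from close() is logged and swallowed so it cannot mask the
          // original error; otherwise it propagates normally.
          Closeables.close(in, threwException);
          Closeables.close(out, threwException);
        }
      }
    }
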
@@ -368,12 +390,15 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
           "to disable this NIO feature."
         );
       }
+      threwException = false;
     } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
       for (int i = 0; i < spills.length; i++) {
         assert(spillInputChannelPositions[i] == spills[i].file.length());
-        Closeables.close(spillInputChannels[i], false);
+        Closeables.close(spillInputChannels[i], threwException);
       }
-      Closeables.close(mergedFileOutputChannel, false);
+      Closeables.close(mergedFileOutputChannel, threwException);
     }
     return partitionLengths;
   }
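
mergeSpillsWithTransferTo concatenates the spill partitions' bytes with FileChannel.transferTo, which can move data between channels without copying it through user space. A self-contained sketch of such a concatenation loop, with hypothetical file arguments; note that transferTo may transfer fewer bytes than requested, so the position must be advanced in a loop:

    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.nio.channels.FileChannel;

    final class ConcatExample {
      static void concatenate(File[] parts, File merged) throws IOException {
        // Open the destination in append mode so parts land back to back.
        try (FileChannel out = new FileOutputStream(merged, true).getChannel()) {
          for (File part : parts) {
            try (FileChannel in = new FileInputStream(part).getChannel()) {
              long position = 0;
              final long size = in.size();
              while (position < size) {
                position += in.transferTo(position, size - position, out);
              }
            }
          }
        }
      }
    }
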
UnsafeShuffleWriterSuite.java

@@ -194,7 +194,8 @@ public Tuple2<TempShuffleBlockId, File> answer(
     when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
   }
 
-  private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
+  private UnsafeShuffleWriter<Object, Object> createWriter(
+      boolean transferToEnabled) throws IOException {
     conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
     return new UnsafeShuffleWriter<Object, Object>(
       blockManager,
@@ -242,12 +243,12 @@ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
   }
 
   @Test(expected=IllegalStateException.class)
-  public void mustCallWriteBeforeSuccessfulStop() {
+  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
     createWriter(false).stop(true);
   }
 
   @Test
-  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() {
+  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
     createWriter(false).stop(false);
   }
 