diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 7544ebbfeaad5..e2a942a425e87 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -102,7 +102,7 @@ public UnsafeShuffleWriter(
       UnsafeShuffleHandle<K, V> handle,
       int mapId,
       TaskContext taskContext,
-      SparkConf sparkConf) {
+      SparkConf sparkConf) throws IOException {
     final int numPartitions = handle.dependency().partitioner().numPartitions();
     if (numPartitions > PackedRecordPointer.MAXIMUM_PARTITION_ID) {
       throw new IllegalArgumentException(
@@ -123,27 +123,29 @@ public UnsafeShuffleWriter(
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
+    open();
   }
 
+  /**
+   * This convenience method should only be called in test code.
+   */
+  @VisibleForTesting
   public void write(Iterator<Product2<K, V>> records) throws IOException {
     write(JavaConversions.asScalaIterator(records));
   }
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+    boolean success = false;
     try {
       while (records.hasNext()) {
         insertRecordIntoSorter(records.next());
       }
       closeAndWriteOutput();
-    } catch (Exception e) {
-      // Unfortunately, we have to catch Exception here in order to ensure proper cleanup after
-      // errors because Spark's Scala code, or users' custom Serializers, might throw arbitrary
-      // unchecked exceptions.
-      try {
+      success = true;
+    } finally {
+      if (!success) {
         sorter.cleanupAfterError();
-      } finally {
-        throw new IOException("Error during shuffle write", e);
       }
     }
   }
@@ -165,9 +167,6 @@ private void open() throws IOException {
 
   @VisibleForTesting
   void closeAndWriteOutput() throws IOException {
-    if (sorter == null) {
-      open();
-    }
     serBuffer = null;
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
@@ -187,10 +186,7 @@ void closeAndWriteOutput() throws IOException {
   }
 
   @VisibleForTesting
-  void insertRecordIntoSorter(Product2<K, V> record) throws IOException{
-    if (sorter == null) {
-      open();
-    }
+  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
     final K key = record._1();
     final int partitionId = partitioner.getPartition(key);
     serBuffer.reset();
@@ -275,15 +271,29 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
     }
   }
 
+  /**
+   * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
+   * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
+   * cases where the IO compression codec does not support concatenation of compressed data, or in
+   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
+   * kernel bugs.
+   *
+   * @param spills the spills to merge.
+   * @param outputFile the file to write the merged data to.
+   * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
+   * @return the partition lengths in the merged file.
+   */
   private long[] mergeSpillsWithFileStream(
       SpillInfo[] spills,
       File outputFile,
       @Nullable CompressionCodec compressionCodec) throws IOException {
+    assert (spills.length >= 2);
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final InputStream[] spillInputStreams = new FileInputStream[spills.length];
     OutputStream mergedFileOutputStream = null;
+    boolean threwException = true;
     try {
       for (int i = 0; i < spills.length; i++) {
         spillInputStreams[i] = new FileInputStream(spills[i].file);
@@ -311,22 +321,34 @@ private long[] mergeSpillsWithFileStream(
         mergedFileOutputStream.close();
         partitionLengths[partition] = (outputFile.length() - initialFileLength);
       }
+      threwException = false;
     } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
       for (InputStream stream : spillInputStreams) {
-        Closeables.close(stream, false);
+        Closeables.close(stream, threwException);
       }
-      Closeables.close(mergedFileOutputStream, false);
+      Closeables.close(mergedFileOutputStream, threwException);
     }
     return partitionLengths;
   }
 
+  /**
+   * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
+   * This is only safe when the IO compression codec and serializer support concatenation of
+   * serialized streams.
+   *
+   * @return the partition lengths in the merged file.
+   */
   private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+    assert (spills.length >= 2);
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final FileChannel[] spillInputChannels = new FileChannel[spills.length];
     final long[] spillInputChannelPositions = new long[spills.length];
     FileChannel mergedFileOutputChannel = null;
+    boolean threwException = true;
     try {
       for (int i = 0; i < spills.length; i++) {
         spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
@@ -368,12 +390,15 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
             "to disable this NIO feature."
         );
       }
+      threwException = false;
     } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
       for (int i = 0; i < spills.length; i++) {
         assert(spillInputChannelPositions[i] == spills[i].file.length());
-        Closeables.close(spillInputChannels[i], false);
+        Closeables.close(spillInputChannels[i], threwException);
       }
-      Closeables.close(mergedFileOutputChannel, false);
+      Closeables.close(mergedFileOutputChannel, threwException);
     }
     return partitionLengths;
   }
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 61511de6a5219..730d265c87f88 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -194,7 +194,8 @@ public Tuple2<TempShuffleBlockId, File> answer(
     when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
   }
 
-  private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
+  private UnsafeShuffleWriter<Object, Object> createWriter(
+      boolean transferToEnabled) throws IOException {
     conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
     return new UnsafeShuffleWriter<Object, Object>(
       blockManager,
@@ -242,12 +243,12 @@ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
   }
 
   @Test(expected=IllegalStateException.class)
-  public void mustCallWriteBeforeSuccessfulStop() {
+  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
     createWriter(false).stop(true);
   }
 
   @Test
-  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() {
+  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
     createWriter(false).stop(false);
   }
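Note on the cleanup idiom used in both merge paths above: the threwException flag starts out true and is cleared only once the try block completes, so the finally block knows whether it is on the success path (where a failed close() should propagate) or on the error path (where a close() failure must be swallowed so it does not mask the exception that got us into the finally block). The following standalone sketch illustrates the same idiom outside of Spark; the CopyWithCleanup class, the copy helper, and the file arguments are hypothetical, while com.google.common.io.Closeables.close(closeable, swallowIOException) is the real Guava utility the patch relies on.

    import java.io.FileInputStream;
    import java.io.FileOutputStream;
    import java.io.IOException;

    import com.google.common.io.Closeables;

    // Hypothetical example; not part of the patch above.
    public class CopyWithCleanup {

      // Copies src to dst using the same threwException pattern as the merge methods:
      // the flag is cleared only after the try block succeeds, so close() failures are
      // rethrown on the success path but swallowed (and logged by Guava) on the error path.
      public static void copy(String src, String dst) throws IOException {
        FileInputStream in = null;
        FileOutputStream out = null;
        boolean threwException = true;
        try {
          in = new FileInputStream(src);
          out = new FileOutputStream(dst);
          byte[] buffer = new byte[8192];
          int read;
          while ((read = in.read(buffer)) != -1) {
            out.write(buffer, 0, read);
          }
          threwException = false;
        } finally {
          // Closeables.close(closeable, swallowIOException) ignores a null argument and only
          // rethrows close() failures when swallowIOException is false, i.e. on the success path.
          Closeables.close(in, threwException);
          Closeables.close(out, threwException);
        }
      }

      public static void main(String[] args) throws IOException {
        if (args.length == 2) {
          copy(args[0], args[1]);
        }
      }
    }

Because Closeables.close tolerates a null argument, the patch can call it unconditionally in the finally blocks even on streams and channels that were never successfully opened.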