From f17156a4fd82e8049b1a37c55b966fa7e8c2841b Mon Sep 17 00:00:00 2001 From: mccheah Date: Fri, 13 Sep 2019 12:53:51 -0700 Subject: [PATCH] [SPARK-25299] SPARK-25299 upstream updates (#605) * Bring implementation into closer alignment with upstream. Step to ease merge conflict resolution and build failure problems when we pull in changes from upstream. * Cherry-pick BypassMergeSortShuffleWriter changes and shuffle writer API changes * [SPARK-28607][CORE][SHUFFLE] Don't store partition lengths twice The shuffle writer API introduced in SPARK-28209 has a flaw that leads to a memory usage regression - we ended up tracking the partition lengths in two places. Here, we modify the API slightly to avoid redundant tracking. The implementation of the shuffle writer plugin is now responsible for tracking the lengths of partitions, and propagating this back up to the higher shuffle writer as part of the commitAllPartitions API. Existing unit tests. Closes #25341 from mccheah/dont-redundantly-store-part-lengths. Authored-by: mcheah Signed-off-by: Marcelo Vanzin * [SPARK-28571][CORE][SHUFFLE] Use the shuffle writer plugin for the SortShuffleWriter Use the shuffle writer APIs introduced in SPARK-28209 in the sort shuffle writer. Existing unit tests were changed to use the plugin instead, and they used the local disk version to ensure that there were no regressions. Closes #25342 from mccheah/shuffle-writer-refactor-sort-shuffle-writer. Lead-authored-by: mcheah Co-authored-by: mccheah Signed-off-by: Marcelo Vanzin * [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. * Resolve build issues and remaining semantic conflicts * More build fixes * More build fixes * Attempt to fix build * More build fixes * [SPARK-29072] Put back usage of TimeTrackingOutputStream for UnsafeShuffleWriter and ShufflePartitionPairsWriter. * Address comments * Import ordering * Fix stream reference --- .../spark/api/shuffle/ShuffleDataIO.java | 34 --- .../shuffle/ShuffleExecutorComponents.java | 37 --- .../api/shuffle/ShuffleMapOutputWriter.java | 39 --- .../spark/api/shuffle/ShuffleReadSupport.java | 42 --- .../spark/api/shuffle/SupportsTransferTo.java | 53 ---- .../TransferrableWritableByteChannel.java | 54 ---- .../api/MapOutputWriterCommitMessage.java | 53 ++++ .../api}/ShuffleBlockInfo.java | 6 +- .../spark/shuffle/api/ShuffleDataIO.java | 53 ++++ .../api}/ShuffleDriverComponents.java | 2 +- .../api/ShuffleExecutorComponents.java | 91 +++++++ .../shuffle/api/ShuffleMapOutputWriter.java | 80 ++++++ .../shuffle/api/ShufflePartitionWriter.java | 98 +++++++ .../SingleSpillShuffleMapOutputWriter.java} | 26 +- .../api/WritableByteChannelWrapper.java} | 30 +-- .../sort/BypassMergeSortShuffleWriter.java | 147 ++++++----- ...faultTransferrableWritableByteChannel.java | 51 ---- .../shuffle/sort/UnsafeShuffleWriter.java | 246 +++++++++++------- .../io/DefaultShuffleExecutorComponents.java | 74 ------ ...ataIO.java => LocalDiskShuffleDataIO.java} | 25 +- .../LocalDiskShuffleExecutorComponents.java | 126 +++++++++ ...a => LocalDiskShuffleMapOutputWriter.java} | 100 +++---- ... LocalDiskSingleSpillMapOutputWriter.java} | 43 +-- ... 
=> LocalDiskShuffleDriverComponents.java} | 4 +- .../org/apache/spark/ContextCleaner.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/internal/config/package.scala | 8 +- .../shuffle/BlockStoreShuffleReader.scala | 19 +- .../shuffle/ShufflePartitionPairsWriter.scala | 135 ++++++++++ ...cala => LocalDiskShuffleReadSupport.scala} | 12 +- .../shuffle/sort/SortShuffleManager.scala | 10 +- .../shuffle/sort/SortShuffleWriter.scala | 15 +- .../spark/storage/DiskBlockObjectWriter.scala | 2 +- .../util/collection/ExternalSorter.scala | 40 +-- .../spark/util/collection/PairsWriter.scala | 5 + .../ShufflePartitionPairsWriter.scala | 91 ------- .../sort/UnsafeShuffleWriterSuite.java | 41 +-- .../scala/org/apache/spark/ShuffleSuite.scala | 18 +- .../DAGSchedulerShufflePluginSuite.scala | 10 +- .../BlockStoreShuffleReaderSuite.scala | 21 +- .../ShuffleDriverComponentsSuite.scala | 43 +-- .../BlockStoreShuffleReaderBenchmark.scala | 14 +- ...ypassMergeSortShuffleWriterBenchmark.scala | 15 +- .../BypassMergeSortShuffleWriterSuite.scala | 187 ++++++------- .../sort/ShuffleWriterBenchmarkBase.scala | 6 +- .../sort/SortShuffleWriterBenchmark.scala | 15 +- .../shuffle/sort/SortShuffleWriterSuite.scala | 117 +++++++++ .../sort/UnsafeShuffleWriterBenchmark.scala | 15 +- .../DefaultShuffleMapOutputWriterSuite.scala | 230 ---------------- ...LocalDiskShuffleMapOutputWriterSuite.scala | 161 ++++++++++++ 50 files changed, 1527 insertions(+), 1221 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java rename core/src/main/java/org/apache/spark/{api/shuffle => shuffle/api}/ShuffleBlockInfo.java (98%) create mode 100644 core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java rename core/src/main/java/org/apache/spark/{api/shuffle => shuffle/api}/ShuffleDriverComponents.java (97%) create mode 100644 core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java rename core/src/main/java/org/apache/spark/{api/shuffle/ShuffleWriteSupport.java => shuffle/api/SingleSpillShuffleMapOutputWriter.java} (62%) rename core/src/main/java/org/apache/spark/{api/shuffle/ShufflePartitionWriter.java => shuffle/api/WritableByteChannelWrapper.java} (59%) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java rename core/src/main/java/org/apache/spark/shuffle/sort/io/{DefaultShuffleDataIO.java => LocalDiskShuffleDataIO.java} (61%) create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java rename 
core/src/main/java/org/apache/spark/shuffle/sort/io/{DefaultShuffleMapOutputWriter.java => LocalDiskShuffleMapOutputWriter.java} (68%) rename core/src/main/java/org/apache/spark/shuffle/sort/io/{DefaultShuffleWriteSupport.java => LocalDiskSingleSpillMapOutputWriter.java} (52%) rename core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/{DefaultShuffleDriverComponents.java => LocalDiskShuffleDriverComponents.java} (93%) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala rename core/src/main/scala/org/apache/spark/shuffle/io/{DefaultShuffleReadSupport.scala => LocalDiskShuffleReadSupport.scala} (90%) delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java deleted file mode 100644 index dd7c0ac7320cb..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import org.apache.spark.annotation.Experimental; - -/** - * :: Experimental :: - * An interface for launching Shuffle related components - * - * @since 3.0.0 - */ -@Experimental -public interface ShuffleDataIO { - String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin."; - - ShuffleDriverComponents driver(); - ShuffleExecutorComponents executor(); -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java deleted file mode 100644 index a5fa032bf651d..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import org.apache.spark.annotation.Experimental; - -import java.util.Map; - -/** - * :: Experimental :: - * An interface for building shuffle support for Executors - * - * @since 3.0.0 - */ -@Experimental -public interface ShuffleExecutorComponents { - void initializeExecutor(String appId, String execId, Map extraConfigs); - - ShuffleWriteSupport writes(); - - ShuffleReadSupport reads(); -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java deleted file mode 100644 index 025fc096faaad..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import java.io.IOException; - -import org.apache.spark.annotation.Experimental; -import org.apache.spark.api.java.Optional; -import org.apache.spark.storage.BlockManagerId; - -/** - * :: Experimental :: - * An interface for creating and managing shuffle partition writers - * - * @since 3.0.0 - */ -@Experimental -public interface ShuffleMapOutputWriter { - ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException; - - Optional commitAllPartitions() throws IOException; - - void abort(Throwable error) throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java deleted file mode 100644 index 83947bd4d6fa4..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import org.apache.spark.annotation.Experimental; - -import java.io.IOException; -import java.io.InputStream; - -/** - * :: Experimental :: - * An interface for reading shuffle records. - * @since 3.0.0 - */ -@Experimental -public interface ShuffleReadSupport { - /** - * Returns an underlying {@link Iterable} that will iterate - * through shuffle data, given an iterable for the shuffle blocks to fetch. - */ - Iterable getPartitionReaders(Iterable blockMetadata) - throws IOException; - - default boolean shouldWrapStream() { - return true; - } -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java deleted file mode 100644 index 866b61d0bafd9..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import java.io.IOException; - -import org.apache.spark.annotation.Experimental; - -/** - * :: Experimental :: - * Indicates that partition writers can transfer bytes directly from input byte channels to - * output channels that stream data to the underlying shuffle partition storage medium. - *
- * This API is separated out for advanced users because it only needs to be used for - * specific low-level optimizations. The idea is that the returned channel can transfer bytes - * from the input file channel out to the backing storage system without copying data into - * memory. - *
- * Most shuffle plugin implementations should use {@link ShufflePartitionWriter} instead. - * - * @since 3.0.0 - */ -@Experimental -public interface SupportsTransferTo extends ShufflePartitionWriter { - - /** - * Opens and returns a {@link TransferrableWritableByteChannel} for transferring bytes from - * input byte channels to the underlying shuffle data store. - */ - TransferrableWritableByteChannel openTransferrableChannel() throws IOException; - - /** - * Returns the number of bytes written either by this writer's output stream opened by - * {@link #openStream()} or the byte channel opened by {@link #openTransferrableChannel()}. - */ - @Override - long getNumBytesWritten(); -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java deleted file mode 100644 index 18234d7c4c944..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import java.io.Closeable; -import java.io.IOException; - -import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; -import org.apache.spark.annotation.Experimental; - -/** - * :: Experimental :: - * Represents an output byte channel that can copy bytes from input file channels to some - * arbitrary storage system. - *
- * This API is provided for advanced users who can transfer bytes from a file channel to - * some output sink without copying data into memory. Most users should not need to use - * this functionality; this is primarily provided for the built-in shuffle storage backends - * that persist shuffle files on local disk. - *
- * For a simpler alternative, see {@link ShufflePartitionWriter}. - * - * @since 3.0.0 - */ -@Experimental -public interface TransferrableWritableByteChannel extends Closeable { - - /** - * Copy all bytes from the source readable byte channel into this byte channel. - * - * @param source File to transfer bytes from. Do not call anything on this channel other than - * {@link FileChannel#transferTo(long, long, WritableByteChannel)}. - * @param transferStartPosition Start position of the input file to transfer from. - * @param numBytesToTransfer Number of bytes to transfer from the given source. - */ - void transferFrom(FileChannel source, long transferStartPosition, long numBytesToTransfer) - throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java new file mode 100644 index 0000000000000..5a1c82499b715 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.util.Optional; + +import org.apache.spark.annotation.Private; +import org.apache.spark.storage.BlockManagerId; + +@Private +public final class MapOutputWriterCommitMessage { + + private final long[] partitionLengths; + private final Optional location; + + private MapOutputWriterCommitMessage( + long[] partitionLengths, Optional location) { + this.partitionLengths = partitionLengths; + this.location = location; + } + + public static MapOutputWriterCommitMessage of(long[] partitionLengths) { + return new MapOutputWriterCommitMessage(partitionLengths, Optional.empty()); + } + + public static MapOutputWriterCommitMessage of( + long[] partitionLengths, BlockManagerId location) { + return new MapOutputWriterCommitMessage(partitionLengths, Optional.of(location)); + } + + public long[] getPartitionLengths() { + return partitionLengths; + } + + public Optional getLocation() { + return location; + } +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java similarity index 98% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java rename to core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java index 34daf2c137a12..72a67c76f28b5 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; + +import java.util.Objects; import org.apache.spark.api.java.Optional; import org.apache.spark.storage.BlockManagerId; -import java.util.Objects; - /** * :: Experimental :: * An object defining the shuffle block and length metadata associated with the block. diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java new file mode 100644 index 0000000000000..5126f0c3577f8 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * An interface for plugging in modules for storing and reading temporary shuffle data. + *
+ * This is the root of a plugin system for storing shuffle bytes to arbitrary storage + * backends in the sort-based shuffle algorithm implemented by the + * {@link org.apache.spark.shuffle.sort.SortShuffleManager}. If another shuffle algorithm is + * needed instead of sort-based shuffle, one should implement + * {@link org.apache.spark.shuffle.ShuffleManager} instead. + *
+ * A single instance of this module is loaded per process in the Spark application. + * The default implementation reads and writes shuffle data from the local disks of + * the executor, and is the implementation of shuffle file storage that has remained + * consistent throughout most of Spark's history. + *
+ * Alternative implementations of shuffle data storage can be loaded via setting + * spark.shuffle.sort.io.plugin.class. + * @since 3.0.0 + */ +@Private +public interface ShuffleDataIO { + + String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin."; + + ShuffleDriverComponents driver(); + + /** + * Called once on executor processes to bootstrap the shuffle data storage modules that + * are only invoked on the executors. + */ + ShuffleExecutorComponents executor(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java similarity index 97% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java rename to core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java index 8b54968f9b134..cbc59bc7b6a05 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; import java.io.IOException; import java.util.Map; diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java new file mode 100644 index 0000000000000..94c07009f3180 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Map; +import java.util.Optional; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * An interface for building shuffle support for Executors. + * + * @since 3.0.0 + */ +@Private +public interface ShuffleExecutorComponents { + + /** + * Called once per executor to bootstrap this module with state that is specific to + * that executor, specifically the application ID and executor ID. + */ + void initializeExecutor(String appId, String execId, Map extraConfigs); + + /** + * Called once per map task to create a writer that will be responsible for persisting all the + * partitioned bytes written by that map task. + * + * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param mapId Within the shuffle, the identifier of the map task + * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task + * with the same (shuffleId, mapId) pair can be distinguished by the + * different values of mapTaskAttemptId. 
+ * @param numPartitions The number of partitions that will be written by the map task. Some of + * these partitions may be empty. + */ + ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId, + int numPartitions) throws IOException; + + /** + * Returns an underlying {@link Iterable} that will iterate + * through shuffle data, given an iterable for the shuffle blocks to fetch. + */ + Iterable getPartitionReaders(Iterable blockMetadata) + throws IOException; + + default boolean shouldWrapPartitionReaderStream() { + return true; + } + + /** + * An optional extension for creating a map output writer that can optimize the transfer of a + * single partition file, as the entire result of a map task, to the backing store. + *
+ * Most implementations should return the default {@link Optional#empty()} to indicate that + * they do not support this optimization. This primarily is for backwards-compatibility in + * preserving an optimization in the local disk shuffle storage implementation. + * + * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param mapId Within the shuffle, the identifier of the map task + * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task + * with the same (shuffleId, mapId) pair can be distinguished by the + * different values of mapTaskAttemptId. + */ + default Optional createSingleFileMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId) throws IOException { + return Optional.empty(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java new file mode 100644 index 0000000000000..8fcc73ba3c9b2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.IOException; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * A top-level writer that returns child writers for persisting the output of a map task, + * and then commits all of the writes as one atomic operation. + * + * @since 3.0.0 + */ +@Private +public interface ShuffleMapOutputWriter { + + /** + * Creates a writer that can open an output stream to persist bytes targeted for a given reduce + * partition id. + *
+ * The chunk corresponds to bytes in the given reduce partition. This will not be called twice + * for the same partition within any given map task. The partition identifier will be in the + * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was + * provided upon the creation of this map output writer via + * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}. + *
+ * Calls to this method will be invoked with monotonically increasing reducePartitionIds; each + * call to this method will be called with a reducePartitionId that is strictly greater than + * the reducePartitionIds given to any previous call to this method. This method is not + * guaranteed to be called for every partition id in the above described range. In particular, + * no guarantees are made as to whether or not this method will be called for empty partitions. + */ + ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException; + + /** + * Commits the writes done by all partition writers returned by all calls to this object's + * {@link #getPartitionWriter(int)}, and returns a bundle of metadata associated with the + * behavior of the write. + *
+ * This should ensure that the writes conducted by this module's partition writers are + * available to downstream reduce tasks. If this method throws any exception, this module's + * {@link #abort(Throwable)} method will be invoked before propagating the exception. + *
+ * This can also close any resources and clean up temporary state if necessary. + *
+ * The returned array should contain two sets of metadata: + * + * 1. For each partition from (0) to (numPartitions - 1), the number of bytes written by + * the partition writer for that partition id. + * + * 2. If the partition data was stored on the local disk of this executor, also provide + * the block manager id where these bytes can be fetched from. + */ + MapOutputWriterCommitMessage commitAllPartitions() throws IOException; + + /** + * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}. + *
+ * This should invalidate the results of writing bytes. This can also close any resources and + * clean up temporary state if necessary. + */ + void abort(Throwable error) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java new file mode 100644 index 0000000000000..928875156a70f --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.IOException; +import java.util.Optional; +import java.io.OutputStream; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * An interface for opening streams to persist partition bytes to a backing data store. + *
+ * This writer stores bytes for one (mapper, reducer) pair, corresponding to one shuffle + * block. + * + * @since 3.0.0 + */ +@Private +public interface ShufflePartitionWriter { + + /** + * Open and return an {@link OutputStream} that can write bytes to the underlying + * data store. + *
+ * This method will only be called once on this partition writer in the map task, to write the + * bytes to the partition. The output stream will only be used to write the bytes for this + * partition. The map task closes this output stream upon writing all the bytes for this + * block, or if the write fails for any reason. + *
+ * Implementations that intend on combining the bytes for all the partitions written by this + * map task should reuse the same OutputStream instance across all the partition writers provided + * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that + * {@link OutputStream#close()} does not close the resource, since it will be reused across + * partition writes. The underlying resources should be cleaned up in + * {@link ShuffleMapOutputWriter#commitAllPartitions()} and + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + */ + OutputStream openStream() throws IOException; + + /** + * Opens and returns a {@link WritableByteChannelWrapper} for transferring bytes from + * input byte channels to the underlying shuffle data store. + *
+ * This method will only be called once on this partition writer in the map task, to write the + * bytes to the partition. The channel will only be used to write the bytes for this + * partition. The map task closes this channel upon writing all the bytes for this + * block, or if the write fails for any reason. + *
+ * Implementations that intend on combining the bytes for all the partitions written by this + * map task should reuse the same channel instance across all the partition writers provided + * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that + * {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel + * will be reused across partition writes. The underlying resources should be cleaned up in + * {@link ShuffleMapOutputWriter#commitAllPartitions()} and + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + *
+ * This method is primarily for advanced optimizations where bytes can be copied from the input + * spill files to the output channel without copying data into memory. If such optimizations are + * not supported, the implementation should return {@link Optional#empty()}. By default, the + * implementation returns {@link Optional#empty()}. + *
+ * Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the + * underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure + * that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()}, + * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + */ + default Optional openChannelWrapper() throws IOException { + return Optional.empty(); + } + + /** + * Returns the number of bytes written either by this writer's output stream opened by + * {@link #openStream()} or the byte channel opened by {@link #openChannelWrapper()}. + *
+ * This can be different from the number of bytes given by the caller. For example, the + * stream might compress or encrypt the bytes before persisting the data to the backing + * data store. + */ + long getNumBytesWritten(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java similarity index 62% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java rename to core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java index 7ee1d8a554073..bddb97bdf0d7e 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java @@ -15,23 +15,23 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; +import java.io.File; import java.io.IOException; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Private; /** - * :: Experimental :: - * An interface for deploying a shuffle map output writer - * - * @since 3.0.0 + * Optional extension for partition writing that is optimized for transferring a single + * file to the backing store. */ -@Experimental -public interface ShuffleWriteSupport { - ShuffleMapOutputWriter createMapOutputWriter( - int shuffleId, - int mapId, - long mapTaskAttemptId, - int numPartitions) throws IOException; +@Private +public interface SingleSpillShuffleMapOutputWriter { + + /** + * Transfer a file that contains the bytes of all the partitions written by this map task. + */ + MapOutputWriterCommitMessage transferMapSpillFile( + File mapOutputFile, long[] partitionLengths) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java similarity index 59% rename from core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java rename to core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java index 74c928b0b9c8f..a204903008a51 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java @@ -15,30 +15,28 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; -import java.io.IOException; -import java.io.OutputStream; +import java.io.Closeable; +import java.nio.channels.WritableByteChannel; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Private; /** - * :: Experimental :: - * An interface for giving streams / channels for shuffle writes. + * :: Private :: + * + * A thin wrapper around a {@link WritableByteChannel}. + *
+ * This is primarily provided for the local disk shuffle implementation to provide a + * {@link java.nio.channels.FileChannel} that keeps the channel open across partition writes. * * @since 3.0.0 */ -@Experimental -public interface ShufflePartitionWriter { - - /** - * Opens and returns an underlying {@link OutputStream} that can write bytes to the underlying - * data store. - */ - OutputStream openStream() throws IOException; +@Private +public interface WritableByteChannelWrapper extends Closeable { /** - * Get the number of bytes written by this writer's stream returned by {@link #openStream()}. + * The underlying channel to write bytes into. */ - long getNumBytesWritten(); + WritableByteChannel channel(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 63aee8ad50da3..94ad5fc66185b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,11 +21,10 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.OutputStream; -import java.nio.channels.Channels; import java.nio.channels.FileChannel; +import java.util.Optional; import javax.annotation.Nullable; -import org.apache.spark.api.java.Optional; import scala.None$; import scala.Option; import scala.Product2; @@ -40,11 +39,11 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; -import org.apache.spark.api.shuffle.SupportsTransferTo; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.internal.config.package$; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -90,13 +89,13 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int mapId; private final long mapTaskAttemptId; private final Serializer serializer; - private final ShuffleWriteSupport shuffleWriteSupport; + private final ShuffleExecutorComponents shuffleExecutorComponents; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; - private long[] partitionLengths; + private MapOutputWriterCommitMessage commitMessage; /** * Are we in the process of stopping? 
Because map tasks can call stop() with success = true @@ -112,34 +111,33 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { long mapTaskAttemptId, SparkConf conf, ShuffleWriteMetricsReporter writeMetrics, - ShuffleWriteSupport shuffleWriteSupport) { + ShuffleExecutorComponents shuffleExecutorComponents) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); this.blockManager = blockManager; final ShuffleDependency dep = handle.dependency(); this.mapId = mapId; - this.shuffleId = dep.shuffleId(); this.mapTaskAttemptId = mapTaskAttemptId; + this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); - this.shuffleWriteSupport = shuffleWriteSupport; + this.shuffleExecutorComponents = shuffleExecutorComponents; } @Override public void write(Iterator> records) throws IOException { assert (partitionWriters == null); - ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport + ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions); try { if (!records.hasNext()) { - partitionLengths = new long[numPartitions]; - Optional location = mapOutputWriter.commitAllPartitions(); + commitMessage = mapOutputWriter.commitAllPartitions(); mapStatus = MapStatus$.MODULE$.apply( - location.orNull(), - partitionLengths, + commitMessage.getLocation().orElse(null), + commitMessage.getPartitionLengths(), mapTaskAttemptId); return; } @@ -172,14 +170,17 @@ public void write(Iterator> records) throws IOException { } } - partitionLengths = writePartitionedData(mapOutputWriter); - Optional location = mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply(location.orNull(), partitionLengths, mapTaskAttemptId); + commitMessage = writePartitionedData(mapOutputWriter); + mapStatus = MapStatus$.MODULE$.apply( + commitMessage.getLocation().orElse(null), + commitMessage.getPartitionLengths(), + mapTaskAttemptId); } catch (Exception e) { try { mapOutputWriter.abort(e); } catch (Exception e2) { logger.error("Failed to abort the writer after failing to write map output.", e2); + e.addSuppressed(e2); } throw e; } @@ -187,7 +188,7 @@ public void write(Iterator> records) throws IOException { @VisibleForTesting long[] getPartitionLengths() { - return partitionLengths; + return commitMessage.getPartitionLengths(); } /** @@ -195,61 +196,75 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). 
*/ - private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException { + private MapOutputWriterCommitMessage writePartitionedData( + ShuffleMapOutputWriter mapOutputWriter) throws IOException { // Track location of the partition starts in the output file - final long[] lengths = new long[numPartitions]; - if (partitionWriters == null) { - // We were passed an empty iterator - return lengths; - } - final long writeStartTime = System.nanoTime(); - try { - for (int i = 0; i < numPartitions; i++) { - final File file = partitionWriterSegments[i].file(); - ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); - if (file.exists()) { - boolean copyThrewException = true; - if (transferToEnabled) { - FileInputStream in = new FileInputStream(file); - TransferrableWritableByteChannel outputChannel = null; - try (FileChannel inputChannel = in.getChannel()) { - if (writer instanceof SupportsTransferTo) { - outputChannel = ((SupportsTransferTo) writer).openTransferrableChannel(); + if (partitionWriters != null) { + final long writeStartTime = System.nanoTime(); + try { + for (int i = 0; i < numPartitions; i++) { + final File file = partitionWriterSegments[i].file(); + ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); + if (file.exists()) { + if (transferToEnabled) { + // Using WritableByteChannelWrapper to make resource closing consistent between + // this implementation and UnsafeShuffleWriter. + Optional maybeOutputChannel = writer.openChannelWrapper(); + if (maybeOutputChannel.isPresent()) { + writePartitionedDataWithChannel(file, maybeOutputChannel.get()); } else { - // Use default transferrable writable channel anyways in order to have parity with - // UnsafeShuffleWriter. - outputChannel = new DefaultTransferrableWritableByteChannel( - Channels.newChannel(writer.openStream())); + writePartitionedDataWithStream(file, writer); } - outputChannel.transferFrom(inputChannel, 0L, inputChannel.size()); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - Closeables.close(outputChannel, copyThrewException); + } else { + writePartitionedDataWithStream(file, writer); } - } else { - FileInputStream in = new FileInputStream(file); - OutputStream outputStream = null; - try { - outputStream = writer.openStream(); - Utils.copyStream(in, outputStream, false, false); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - Closeables.close(outputStream, copyThrewException); + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); } } - if (!file.delete()) { - logger.error("Unable to delete file for partition {}", i); - } } - lengths[i] = writer.getNumBytesWritten(); + } finally { + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + } + return mapOutputWriter.commitAllPartitions(); + } + + private void writePartitionedDataWithChannel( + File file, + WritableByteChannelWrapper outputChannel) throws IOException { + boolean copyThrewException = true; + try { + FileInputStream in = new FileInputStream(file); + try (FileChannel inputChannel = in.getChannel()) { + Utils.copyFileStreamNIO( + inputChannel, outputChannel.channel(), 0L, inputChannel.size()); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + } finally { + Closeables.close(outputChannel, copyThrewException); + } + } + + private void writePartitionedDataWithStream(File file, ShufflePartitionWriter writer) + 
throws IOException { + boolean copyThrewException = true; + FileInputStream in = new FileInputStream(file); + OutputStream outputStream; + try { + outputStream = writer.openStream(); + try { + Utils.copyStream(in, outputStream, false, false); + copyThrewException = false; + } finally { + Closeables.close(outputStream, copyThrewException); } } finally { - writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + Closeables.close(in, copyThrewException); } - partitionWriters = null; - return lengths; } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java deleted file mode 100644 index 64ce851e392d2..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort; - -import java.io.IOException; -import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; -import org.apache.spark.util.Utils; - -/** - * This is used when transferTo is enabled but the shuffle plugin hasn't implemented - * {@link org.apache.spark.api.shuffle.SupportsTransferTo}. - *
- * This default implementation exists as a convenience to the unsafe shuffle writer and - * the bypass merge sort shuffle writers. - */ -public class DefaultTransferrableWritableByteChannel implements TransferrableWritableByteChannel { - - private final WritableByteChannel delegate; - - public DefaultTransferrableWritableByteChannel(WritableByteChannel delegate) { - this.delegate = delegate; - } - - @Override - public void transferFrom( - FileChannel source, long transferStartPosition, long numBytesToTransfer) { - Utils.copyFileStreamNIO(source, delegate, transferStartPosition, numBytesToTransfer); - } - - @Override - public void close() throws IOException { - delegate.close(); - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9627f1151f837..acb86616066a8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -18,13 +18,13 @@ package org.apache.spark.shuffle.sort; import java.nio.channels.Channels; +import java.util.Optional; import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; import java.util.Iterator; -import org.apache.spark.api.java.Optional; -import org.apache.spark.storage.BlockManagerId; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -39,11 +39,6 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; -import org.apache.spark.api.shuffle.SupportsTransferTo; import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; @@ -56,8 +51,16 @@ import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.Utils; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -74,7 +77,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; - private final ShuffleWriteSupport shuffleWriteSupport; + private final ShuffleExecutorComponents shuffleExecutorComponents; private final int shuffleId; private final int mapId; private final TaskContext taskContext; @@ -111,7 +114,7 @@ public UnsafeShuffleWriter( TaskContext taskContext, SparkConf sparkConf, ShuffleWriteMetricsReporter writeMetrics, - ShuffleWriteSupport shuffleWriteSupport) throws IOException { + ShuffleExecutorComponents 
shuffleExecutorComponents) { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -127,7 +130,7 @@ public UnsafeShuffleWriter( this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); this.writeMetrics = writeMetrics; - this.shuffleWriteSupport = shuffleWriteSupport; + this.shuffleExecutorComponents = shuffleExecutorComponents; this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); @@ -216,34 +219,20 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; - final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport - .createMapOutputWriter(shuffleId, - mapId, - taskContext.taskAttemptId(), - partitioner.numPartitions()); - final long[] partitionLengths; - Optional location; + final MapOutputWriterCommitMessage commitMessage; try { - try { - partitionLengths = mergeSpills(spills, mapWriter); - } finally { - for (SpillInfo spill : spills) { - if (spill.file.exists() && !spill.file.delete()) { - logger.error("Error while deleting spill file {}", spill.file.getPath()); - } + commitMessage = mergeSpills(spills); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); } } - location = mapWriter.commitAllPartitions(); - } catch (Exception e) { - try { - mapWriter.abort(e); - } catch (Exception innerE) { - logger.error("Failed to abort the Map Output Writer", innerE); - } - throw e; } mapStatus = MapStatus$.MODULE$.apply( - location.orNull(), partitionLengths, taskContext.attemptNumber()); + commitMessage.getLocation().orElse(null), + commitMessage.getPartitionLengths(), + taskContext.attemptNumber()); } @VisibleForTesting @@ -275,57 +264,94 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. */ - private long[] mergeSpills(SpillInfo[] spills, - ShuffleMapOutputWriter mapWriter) throws IOException { + private MapOutputWriterCommitMessage mergeSpills(SpillInfo[] spills) throws IOException { + MapOutputWriterCommitMessage commitMessage; + if (spills.length == 0) { + final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents + .createMapOutputWriter( + shuffleId, + mapId, + taskContext.taskAttemptId(), + partitioner.numPartitions()); + return mapWriter.commitAllPartitions(); + } else if (spills.length == 1) { + Optional maybeSingleFileWriter = + shuffleExecutorComponents.createSingleFileMapOutputWriter( + shuffleId, mapId, taskContext.taskAttemptId()); + if (maybeSingleFileWriter.isPresent()) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. 
+ long[] partitionLengths = spills[0].partitionLengths; + return maybeSingleFileWriter.get().transferMapSpillFile( + spills[0].file, partitionLengths); + } else { + commitMessage = mergeSpillsUsingStandardWriter(spills); + } + } else { + commitMessage = mergeSpillsUsingStandardWriter(spills); + } + return commitMessage; + } + + private MapOutputWriterCommitMessage mergeSpillsUsingStandardWriter( + SpillInfo[] spills) throws IOException { + MapOutputWriterCommitMessage commitMessage; final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS()); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = - (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE()); + (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE()); final boolean fastMergeIsSupported = !compressionEnabled || - CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); - final int numPartitions = partitioner.numPartitions(); - long[] partitionLengths = new long[numPartitions]; + final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents + .createMapOutputWriter( + shuffleId, + mapId, + taskContext.taskAttemptId(), + partitioner.numPartitions()); try { - if (spills.length == 0) { - return partitionLengths; - } else { - // There are multiple spills to merge, so none of these spill files' lengths were counted - // towards our shuffle write count or shuffle write time. If we use the slow merge path, - // then the final output file's size won't necessarily be equal to the sum of the spill - // files' sizes. To guard against this case, we look at the output file's actual size when - // computing shuffle bytes written. - // - // We allow the individual merge methods to report their own IO times since different merge - // strategies use different IO techniques. We count IO during merge towards the shuffle - // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" - // branch in ExternalSorter. - if (fastMergeEnabled && fastMergeIsSupported) { - // Compression is disabled or we are using an IO compression codec that supports - // decompression of concatenated compressed streams, so we can perform a fast spill merge - // that doesn't need to interpret the spilled bytes. - if (transferToEnabled && !encryptionEnabled) { - logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, mapWriter); - } else { - logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, null); - } + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. 
We count IO during merge towards the shuffle + // write time, which appears to be consistent with the "not bypassing merge-sort" branch in + // ExternalSorter. + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. + if (transferToEnabled && !encryptionEnabled) { + logger.debug("Using transferTo-based fast merge"); + mergeSpillsWithTransferTo(spills, mapWriter); } else { - logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); + logger.debug("Using fileStream-based fast merge"); + mergeSpillsWithFileStream(spills, mapWriter, null); } - // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has - // in-memory records, we write out the in-memory records to a file but do not count that - // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs - // to be counted as shuffle write, but this will lead to double-counting of the final - // SpillInfo's bytes. - writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); - return partitionLengths; + } else { + logger.debug("Using slow merge"); + mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); + } + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. + writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); + commitMessage = mapWriter.commitAllPartitions(); + } catch (Exception e) { + try { + mapWriter.abort(e); + } catch (Exception e2) { + logger.warn("Failed to abort writing the map output.", e2); + e.addSuppressed(e2); } - } catch (IOException e) { throw e; } + return commitMessage; } /** @@ -344,12 +370,11 @@ private long[] mergeSpills(SpillInfo[] spills, * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. * @return the partition lengths in the merged file. 
*/ - private long[] mergeSpillsWithFileStream( + private void mergeSpillsWithFileStream( SpillInfo[] spills, ShuffleMapOutputWriter mapWriter, @Nullable CompressionCodec compressionCodec) throws IOException { final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; boolean threwException = true; @@ -360,11 +385,11 @@ private long[] mergeSpillsWithFileStream( inputBufferSizeInBytes); } for (int partition = 0; partition < numPartitions; partition++) { - boolean copyThrewExecption = true; + boolean copyThrewException = true; ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); - OutputStream partitionOutput = null; + OutputStream partitionOutput = writer.openStream(); try { - partitionOutput = writer.openStream(); + partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); @@ -374,6 +399,7 @@ private long[] mergeSpillsWithFileStream( if (partitionLengthInSpill > 0) { InputStream partitionInputStream = null; + boolean copySpillThrewException = true; try { partitionInputStream = new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); @@ -384,17 +410,18 @@ private long[] mergeSpillsWithFileStream( partitionInputStream); } ByteStreams.copy(partitionInputStream, partitionOutput); + copySpillThrewException = false; } finally { - partitionInputStream.close(); + Closeables.close(partitionInputStream, copySpillThrewException); } } - copyThrewExecption = false; + copyThrewException = false; } + copyThrewException = false; } finally { - Closeables.close(partitionOutput, copyThrewExecption); + Closeables.close(partitionOutput, copyThrewException); } long numBytesWritten = writer.getNumBytesWritten(); - partitionLengths[partition] = numBytesWritten; writeMetrics.incBytesWritten(numBytesWritten); } threwException = false; @@ -405,7 +432,6 @@ private long[] mergeSpillsWithFileStream( Closeables.close(stream, threwException); } } - return partitionLengths; } /** @@ -417,11 +443,10 @@ private long[] mergeSpillsWithFileStream( * @param mapWriter the map output writer to use for output. * @return the partition lengths in the merged file. */ - private long[] mergeSpillsWithTransferTo( + private void mergeSpillsWithTransferTo( SpillInfo[] spills, ShuffleMapOutputWriter mapWriter) throws IOException { final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; final FileChannel[] spillInputChannels = new FileChannel[spills.length]; final long[] spillInputChannelPositions = new long[spills.length]; @@ -431,30 +456,28 @@ private long[] mergeSpillsWithTransferTo( spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); } for (int partition = 0; partition < numPartitions; partition++) { - boolean copyThrewExecption = true; + boolean copyThrewException = true; ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); - TransferrableWritableByteChannel partitionChannel = null; + WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper() + .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer))); try { - partitionChannel = writer instanceof SupportsTransferTo ? 
- ((SupportsTransferTo) writer).openTransferrableChannel() - : new DefaultTransferrableWritableByteChannel( - Channels.newChannel(writer.openStream())); for (int i = 0; i < spills.length; i++) { - long partitionLengthInSpill = 0L; - partitionLengthInSpill += spills[i].partitionLengths[partition]; + long partitionLengthInSpill = spills[i].partitionLengths[partition]; final FileChannel spillInputChannel = spillInputChannels[i]; final long writeStartTime = System.nanoTime(); - partitionChannel.transferFrom( - spillInputChannel, spillInputChannelPositions[i], partitionLengthInSpill); + Utils.copyFileStreamNIO( + spillInputChannel, + resolvedChannel.channel(), + spillInputChannelPositions[i], + partitionLengthInSpill); + copyThrewException = false; spillInputChannelPositions[i] += partitionLengthInSpill; writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } - copyThrewExecption = false; } finally { - Closeables.close(partitionChannel, copyThrewExecption); + Closeables.close(resolvedChannel, copyThrewException); } long numBytes = writer.getNumBytesWritten(); - partitionLengths[partition] = numBytes; writeMetrics.incBytesWritten(numBytes); } threwException = false; @@ -466,7 +489,6 @@ private long[] mergeSpillsWithTransferTo( Closeables.close(spillInputChannels[i], threwException); } } - return partitionLengths; } @Override @@ -495,4 +517,30 @@ public Option stop(boolean success) { } } } + + private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) { + try { + return writer.openStream(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper { + private final WritableByteChannel channel; + + StreamFallbackChannelWrapper(OutputStream fallbackStream) { + this.channel = Channels.newChannel(fallbackStream); + } + + @Override + public WritableByteChannel channel() { + return channel; + } + + @Override + public void close() throws IOException { + channel.close(); + } + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java deleted file mode 100644 index 3b5f9670d64d2..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort.io; - -import org.apache.spark.MapOutputTracker; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkEnv; -import org.apache.spark.api.shuffle.ShuffleExecutorComponents; -import org.apache.spark.api.shuffle.ShuffleReadSupport; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; -import org.apache.spark.serializer.SerializerManager; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport; -import org.apache.spark.storage.BlockManager; - -import java.util.Map; - -public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { - - private final SparkConf sparkConf; - private BlockManager blockManager; - private IndexShuffleBlockResolver blockResolver; - private MapOutputTracker mapOutputTracker; - private SerializerManager serializerManager; - - public DefaultShuffleExecutorComponents(SparkConf sparkConf) { - this.sparkConf = sparkConf; - } - - @Override - public void initializeExecutor(String appId, String execId, Map extraConfigs) { - blockManager = SparkEnv.get().blockManager(); - mapOutputTracker = SparkEnv.get().mapOutputTracker(); - serializerManager = SparkEnv.get().serializerManager(); - blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); - } - - @Override - public ShuffleWriteSupport writes() { - checkInitialized(); - return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); - } - - @Override - public ShuffleReadSupport reads() { - checkInitialized(); - return new DefaultShuffleReadSupport(blockManager, - mapOutputTracker, - serializerManager, - sparkConf); - } - - private void checkInitialized() { - if (blockResolver == null) { - throw new IllegalStateException( - "Executor components must be initialized before getting writers."); - } - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java similarity index 61% rename from core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java rename to core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java index 7c124c1fe68bc..77fcd34f962bf 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java @@ -18,26 +18,31 @@ package org.apache.spark.shuffle.sort.io; import org.apache.spark.SparkConf; -import org.apache.spark.api.shuffle.ShuffleDriverComponents; -import org.apache.spark.api.shuffle.ShuffleExecutorComponents; -import org.apache.spark.api.shuffle.ShuffleDataIO; -import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents; +import org.apache.spark.shuffle.api.ShuffleDriverComponents; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleDataIO; +import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents; -public class DefaultShuffleDataIO implements ShuffleDataIO { +/** + * Implementation of the {@link ShuffleDataIO} plugin system that replicates the local shuffle + * storage and index file functionality that has historically been used from Spark 2.4 and earlier. 
+ */ +public class LocalDiskShuffleDataIO implements ShuffleDataIO { private final SparkConf sparkConf; - public DefaultShuffleDataIO(SparkConf sparkConf) { + public LocalDiskShuffleDataIO(SparkConf sparkConf) { this.sparkConf = sparkConf; } @Override - public ShuffleExecutorComponents executor() { - return new DefaultShuffleExecutorComponents(sparkConf); + public ShuffleDriverComponents driver() { + return new LocalDiskShuffleDriverComponents(); } @Override - public ShuffleDriverComponents driver() { - return new DefaultShuffleDriverComponents(); + public ShuffleExecutorComponents executor() { + return new LocalDiskShuffleExecutorComponents(sparkConf); } + } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java new file mode 100644 index 0000000000000..c8d70d72eb02e --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io; + +import java.io.InputStream; +import java.util.Map; +import java.util.Optional; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.MapOutputTracker; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.shuffle.api.ShuffleBlockInfo; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.shuffle.io.LocalDiskShuffleReadSupport; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockManagerId; + +public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents { + + private final SparkConf sparkConf; + private LocalDiskShuffleReadSupport shuffleReadSupport; + private BlockManagerId shuffleServerId; + private BlockManager blockManager; + private IndexShuffleBlockResolver blockResolver; + + public LocalDiskShuffleExecutorComponents(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @VisibleForTesting + public LocalDiskShuffleExecutorComponents( + SparkConf sparkConf, + BlockManager blockManager, + MapOutputTracker mapOutputTracker, + SerializerManager serializerManager, + IndexShuffleBlockResolver blockResolver, + BlockManagerId shuffleServerId) { + this.sparkConf = sparkConf; + this.blockManager = blockManager; + this.blockResolver = blockResolver; + this.shuffleServerId = shuffleServerId; + this.shuffleReadSupport = new LocalDiskShuffleReadSupport( + blockManager, mapOutputTracker, serializerManager, sparkConf); + } + + @Override + public void initializeExecutor(String appId, String execId, Map extraConfigs) { + blockManager = SparkEnv.get().blockManager(); + if (blockManager == null) { + throw new IllegalStateException("No blockManager available from the SparkEnv."); + } + shuffleServerId = blockManager.shuffleServerId(); + blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker(); + SerializerManager serializerManager = SparkEnv.get().serializerManager(); + shuffleReadSupport = new LocalDiskShuffleReadSupport( + blockManager, mapOutputTracker, serializerManager, sparkConf); + } + + @Override + public ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId, + int numPartitions) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + return new LocalDiskShuffleMapOutputWriter( + shuffleId, + mapId, + numPartitions, + blockResolver, + shuffleServerId, + sparkConf); + } + + @Override + public Optional createSingleFileMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + return Optional.of(new LocalDiskSingleSpillMapOutputWriter( + shuffleId, mapId, blockResolver, shuffleServerId)); + } + + @Override + public Iterable getPartitionReaders(Iterable blockMetadata) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting readers."); + } + return 
shuffleReadSupport.getPartitionReaders(blockMetadata); + } + + @Override + public boolean shouldWrapPartitionReaderStream() { + return false; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java similarity index 68% rename from core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index ad55b3db377f6..064875420c473 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -23,73 +23,73 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Optional; -import org.apache.spark.api.java.Optional; -import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.SparkConf; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.api.shuffle.SupportsTransferTo; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; -import org.apache.spark.internal.config.package$; -import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel; -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.storage.BlockManagerId; +import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; -public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter { +/** + * Implementation of {@link ShuffleMapOutputWriter} that replicates the functionality of shuffle + * persisting shuffle data to local disk alongside index files, identical to Spark's historic + * canonical shuffle storage mechanism. 
+ */ +public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { private static final Logger log = - LoggerFactory.getLogger(DefaultShuffleMapOutputWriter.class); + LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class); private final int shuffleId; private final int mapId; - private final ShuffleWriteMetricsReporter metrics; private final IndexShuffleBlockResolver blockResolver; private final long[] partitionLengths; private final int bufferSize; + private final BlockManagerId shuffleServerId; private int lastPartitionId = -1; private long currChannelPosition; - private final BlockManagerId shuffleServerId; + private long bytesWrittenToMergedFile = 0L; private final File outputFile; private File outputTempFile; private FileOutputStream outputFileStream; private FileChannel outputFileChannel; - private TimeTrackingOutputStream ts; private BufferedOutputStream outputBufferedFileStream; - public DefaultShuffleMapOutputWriter( + public LocalDiskShuffleMapOutputWriter( int shuffleId, int mapId, int numPartitions, - BlockManagerId shuffleServerId, - ShuffleWriteMetricsReporter metrics, IndexShuffleBlockResolver blockResolver, + BlockManagerId shuffleServerId, SparkConf sparkConf) { this.shuffleId = shuffleId; this.mapId = mapId; - this.shuffleServerId = shuffleServerId; - this.metrics = metrics; this.blockResolver = blockResolver; this.bufferSize = (int) (long) sparkConf.get( package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; + this.shuffleServerId = shuffleServerId; this.partitionLengths = new long[numPartitions]; this.outputFile = blockResolver.getDataFile(shuffleId, mapId); this.outputTempFile = null; } @Override - public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException { - if (partitionId <= lastPartitionId) { + public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException { + if (reducePartitionId <= lastPartitionId) { throw new IllegalArgumentException("Partitions should be requested in increasing order."); } - lastPartitionId = partitionId; + lastPartitionId = reducePartitionId; if (outputTempFile == null) { outputTempFile = Utils.tempFileWith(outputFile); } @@ -98,24 +98,32 @@ public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOExcep } else { currChannelPosition = 0L; } - return new DefaultShufflePartitionWriter(partitionId); + return new LocalDiskShufflePartitionWriter(reducePartitionId); } @Override - public Optional commitAllPartitions() throws IOException { + public MapOutputWriterCommitMessage commitAllPartitions() throws IOException { + // Check the position after transferTo loop to see if it is in the right position and raise a + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + outputFileChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your " + + " kernel version to see if it is 2.6.32, as there is a kernel bug which will lead " + + "to unexpected behavior when using transferTo. You can set " + + "spark.file.transferTo=false to disable this NIO feature."); + } cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); - return Optional.of(shuffleServerId); + return MapOutputWriterCommitMessage.of(partitionLengths, shuffleServerId); } @Override - public void abort(Throwable error) { - try { - cleanUp(); - } catch (Exception e) { - log.error("Unable to close appropriate underlying file stream", e); - } + public void abort(Throwable error) throws IOException { + cleanUp(); if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) { log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath()); } @@ -136,29 +144,27 @@ private void cleanUp() throws IOException { private void initStream() throws IOException { if (outputFileStream == null) { outputFileStream = new FileOutputStream(outputTempFile, true); - ts = new TimeTrackingOutputStream(metrics, outputFileStream); } if (outputBufferedFileStream == null) { - outputBufferedFileStream = new BufferedOutputStream(ts, bufferSize); + outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize); } } private void initChannel() throws IOException { - if (outputFileStream == null) { - outputFileStream = new FileOutputStream(outputTempFile, true); - } + // This file needs to opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. if (outputFileChannel == null) { - outputFileChannel = outputFileStream.getChannel(); + outputFileChannel = new FileOutputStream(outputTempFile, true).getChannel(); } } - private class DefaultShufflePartitionWriter implements SupportsTransferTo { + private class LocalDiskShufflePartitionWriter implements ShufflePartitionWriter { private final int partitionId; private PartitionWriterStream partStream = null; private PartitionWriterChannel partChannel = null; - private DefaultShufflePartitionWriter(int partitionId) { + private LocalDiskShufflePartitionWriter(int partitionId) { this.partitionId = partitionId; } @@ -177,7 +183,7 @@ public OutputStream openStream() throws IOException { } @Override - public TransferrableWritableByteChannel openTransferrableChannel() throws IOException { + public Optional openChannelWrapper() throws IOException { if (partChannel == null) { if (partStream != null) { throw new IllegalStateException("Requested an output stream for a previous write but" + @@ -187,7 +193,7 @@ public TransferrableWritableByteChannel openTransferrableChannel() throws IOExce initChannel(); partChannel = new PartitionWriterChannel(partitionId); } - return partChannel; + return Optional.of(partChannel); } @Override @@ -238,6 +244,7 @@ public void write(byte[] buf, int pos, int length) throws IOException { public void close() { isClosed = true; partitionLengths[partitionId] = count; + bytesWrittenToMergedFile += count; } private void verifyNotClosed() { @@ -247,12 +254,11 @@ private void verifyNotClosed() { } } - private class PartitionWriterChannel extends DefaultTransferrableWritableByteChannel { + private class PartitionWriterChannel implements WritableByteChannelWrapper { private final int partitionId; PartitionWriterChannel(int partitionId) { - super(outputFileChannel); this.partitionId = partitionId; } @@ -261,9 +267,15 @@ public long getCount() throws IOException { return writtenPosition - currChannelPosition; } + @Override + public WritableByteChannel channel() { + return outputFileChannel; + } + @Override public void close() throws IOException { partitionLengths[partitionId] = 
getCount(); + bytesWrittenToMergedFile += partitionLengths[partitionId]; } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java similarity index 52% rename from core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java rename to core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java index d6210f045840b..219f9ee1296dd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java @@ -17,36 +17,45 @@ package org.apache.spark.shuffle.sort.io; -import org.apache.spark.SparkConf; -import org.apache.spark.TaskContext; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; + import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; import org.apache.spark.storage.BlockManagerId; +import org.apache.spark.util.Utils; -public class DefaultShuffleWriteSupport implements ShuffleWriteSupport { +public class LocalDiskSingleSpillMapOutputWriter + implements SingleSpillShuffleMapOutputWriter { - private final SparkConf sparkConf; + private final int shuffleId; + private final int mapId; private final IndexShuffleBlockResolver blockResolver; private final BlockManagerId shuffleServerId; - public DefaultShuffleWriteSupport( - SparkConf sparkConf, + public LocalDiskSingleSpillMapOutputWriter( + int shuffleId, + int mapId, IndexShuffleBlockResolver blockResolver, BlockManagerId shuffleServerId) { - this.sparkConf = sparkConf; + this.shuffleId = shuffleId; + this.mapId = mapId; this.blockResolver = blockResolver; this.shuffleServerId = shuffleServerId; } @Override - public ShuffleMapOutputWriter createMapOutputWriter( - int shuffleId, - int mapId, - long mapTaskAttemptId, - int numPartitions) { - return new DefaultShuffleMapOutputWriter( - shuffleId, mapId, numPartitions, shuffleServerId, - TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf); + public MapOutputWriterCommitMessage transferMapSpillFile( + File mapSpillFile, + long[] partitionLengths) throws IOException { + // The map spill file already has the proper format, and it contains all of the partition data. + // So just transfer it directly to the destination without any merging. 
+ File outputFile = blockResolver.getDataFile(shuffleId, mapId); + File tempFile = Utils.tempFileWith(outputFile); + Files.move(mapSpillFile.toPath(), tempFile.toPath()); + blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); + return MapOutputWriterCommitMessage.of(partitionLengths, shuffleServerId); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java similarity index 93% rename from core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java rename to core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java index c6f43b91f90a0..183769274841c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java @@ -22,11 +22,11 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.SparkEnv; -import org.apache.spark.api.shuffle.ShuffleDriverComponents; +import org.apache.spark.shuffle.api.ShuffleDriverComponents; import org.apache.spark.internal.config.package$; import org.apache.spark.storage.BlockManagerMaster; -public class DefaultShuffleDriverComponents implements ShuffleDriverComponents { +public class LocalDiskShuffleDriverComponents implements ShuffleDriverComponents { private BlockManagerMaster blockManagerMaster; private boolean shouldUnregisterOutputOnHostOnFetchFailure; diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index bcd47ba0c29c1..98232380cc266 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -23,11 +23,11 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Scheduled import scala.collection.JavaConverters._ -import org.apache.spark.api.shuffle.ShuffleDriverComponents import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} +import org.apache.spark.shuffle.api.ShuffleDriverComponents import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, ThreadUtils, Utils} /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f359022716571..c84bc82b9a29f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -43,7 +43,6 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.conda.CondaEnvironment import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{CondaRunner, LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} @@ -58,6 +57,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend import 
org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents} import org.apache.spark.status.{AppStatusSource, AppStatusStore} import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.storage._ diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a852a06be9125..833db06420d4d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode} -import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils @@ -770,10 +770,10 @@ package object config { .createWithDefault(false) private[spark] val SHUFFLE_IO_PLUGIN_CLASS = - ConfigBuilder("spark.shuffle.io.plugin.class") + ConfigBuilder("spark.shuffle.sort.io.plugin.class") .doc("Name of the class to use for shuffle IO.") .stringConf - .createWithDefault(classOf[DefaultShuffleDataIO].getName) + .createWithDefault(classOf[LocalDiskShuffleDataIO].getName) private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") @@ -951,7 +951,7 @@ package object config { .booleanConf .createWithDefault(false) - private[spark] val SHUFFLE_UNDAFE_FAST_MERGE_ENABLE = + private[spark] val SHUFFLE_UNSAFE_FAST_MERGE_ENABLE = ConfigBuilder("spark.shuffle.unsafe.fastMergeEnabled") .doc("Whether to perform a fast spill merge.") .booleanConf diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index a20a849cc6421..e614dbc8c9542 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,17 +17,14 @@ package org.apache.spark.shuffle -import java.io.InputStream - import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.api.java.Optional -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.api.{ShuffleBlockInfo, ShuffleExecutorComponents} import org.apache.spark.storage.ShuffleBlockAttemptId import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -42,7 +39,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( endPartition: Int, context: TaskContext, readMetrics: ShuffleReadMetricsReporter, - shuffleReadSupport: ShuffleReadSupport, + shuffleExecutorComponents: ShuffleExecutorComponents, serializerManager: SerializerManager = SparkEnv.get.serializerManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, sparkConf: SparkConf = SparkEnv.get.conf) @@ -57,7 +54,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the 
combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val streamsIterator = - shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] { + shuffleExecutorComponents.getPartitionReaders(new Iterable[ShuffleBlockInfo] { override def iterator: Iterator[ShuffleBlockInfo] = { mapOutputTracker .getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) @@ -76,18 +73,18 @@ private[spark] class BlockStoreShuffleReader[K, C]( } }.asJava).iterator() - val retryingWrappedStreams = streamsIterator.asScala.map(readSupportStream => { - if (shuffleReadSupport.shouldWrapStream()) { + val retryingWrappedStreams = streamsIterator.asScala.map(rawReaderStream => { + if (shuffleExecutorComponents.shouldWrapPartitionReaderStream()) { if (compressShuffle) { compressionCodec.compressedInputStream( - serializerManager.wrapForEncryption(readSupportStream)) + serializerManager.wrapForEncryption(rawReaderStream)) } else { - serializerManager.wrapForEncryption(readSupportStream) + serializerManager.wrapForEncryption(rawReaderStream) } } else { // The default implementation checks for corrupt streams, so it will already have // decompressed/decrypted the bytes - readSupportStream + rawReaderStream } }) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala new file mode 100644 index 0000000000000..e0affb858c359 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.{Closeable, IOException, OutputStream} + +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.api.ShufflePartitionWriter +import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream} +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.PairsWriter + +/** + * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an + * arbitrary partition writer instead of writing to local disk through the block manager. 
+ */ +private[spark] class ShufflePartitionPairsWriter( + partitionWriter: ShufflePartitionWriter, + serializerManager: SerializerManager, + serializerInstance: SerializerInstance, + blockId: BlockId, + writeMetrics: ShuffleWriteMetricsReporter) + extends PairsWriter with Closeable { + + private var isClosed = false + private var partitionStream: OutputStream = _ + private var timeTrackingStream: OutputStream = _ + private var wrappedStream: OutputStream = _ + private var objOut: SerializationStream = _ + private var numRecordsWritten = 0 + private var curNumBytesWritten = 0L + + override def write(key: Any, value: Any): Unit = { + if (isClosed) { + throw new IOException("Partition pairs writer is already closed.") + } + if (objOut == null) { + open() + } + objOut.writeKey(key) + objOut.writeValue(value) + recordWritten() + } + + private def open(): Unit = { + try { + partitionStream = partitionWriter.openStream + timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream) + wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream) + objOut = serializerInstance.serializeStream(wrappedStream) + } catch { + case e: Exception => + Utils.tryLogNonFatalError { + close() + } + throw e + } + } + + override def close(): Unit = { + if (!isClosed) { + Utils.tryWithSafeFinally { + Utils.tryWithSafeFinally { + objOut = closeIfNonNull(objOut) + // Setting these to null will prevent the underlying streams from being closed twice + // just in case any stream's close() implementation is not idempotent. + wrappedStream = null + timeTrackingStream = null + partitionStream = null + } { + // Normally closing objOut would close the inner streams as well, but just in case there + // was an error in initialization etc. we make sure we clean the other streams up too. + Utils.tryWithSafeFinally { + wrappedStream = closeIfNonNull(wrappedStream) + // Same as above - if wrappedStream closes then assume it closes underlying + // partitionStream and don't close again in the finally + timeTrackingStream = null + partitionStream = null + } { + Utils.tryWithSafeFinally { + timeTrackingStream = closeIfNonNull(timeTrackingStream) + partitionStream = null + } { + partitionStream = closeIfNonNull(partitionStream) + } + } + } + updateBytesWritten() + } { + isClosed = true + } + } + } + + private def closeIfNonNull[T <: Closeable](closeable: T): T = { + if (closeable != null) { + closeable.close() + } + null.asInstanceOf[T] + } + + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. 
+ */ + private def recordWritten(): Unit = { + numRecordsWritten += 1 + writeMetrics.incRecordsWritten(1) + + if (numRecordsWritten % 16384 == 0) { + updateBytesWritten() + } + } + + private def updateBytesWritten(): Unit = { + val numBytesWritten = partitionWriter.getNumBytesWritten + val bytesWrittenDiff = numBytesWritten - curNumBytesWritten + writeMetrics.incBytesWritten(bytesWrittenDiff) + curNumBytesWritten = numBytesWritten + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala similarity index 90% rename from core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala rename to core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala index e18097c2c590a..9e1c1816d306c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala @@ -22,17 +22,17 @@ import java.io.InputStream import scala.collection.JavaConverters._ import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.config import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.ShuffleReadMetricsReporter -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockAttemptId, ShuffleBlockFetcherIterator, ShuffleBlockId} +import org.apache.spark.shuffle.api.ShuffleBlockInfo +import org.apache.spark.storage.{BlockManager, ShuffleBlockAttemptId, ShuffleBlockFetcherIterator, ShuffleBlockId} -class DefaultShuffleReadSupport( +class LocalDiskShuffleReadSupport( blockManager: BlockManager, mapOutputTracker: MapOutputTracker, serializerManager: SerializerManager, - conf: SparkConf) extends ShuffleReadSupport { + conf: SparkConf) { private val maxBytesInFlight = conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 private val maxReqsInFlight = conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) @@ -41,7 +41,7 @@ class DefaultShuffleReadSupport( private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT) - override def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): + def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = { val iterableToReturn = if (blockMetadata.asScala.isEmpty) { @@ -70,8 +70,6 @@ class DefaultShuffleReadSupport( } iterableToReturn.asJava } - - override def shouldWrapStream(): Boolean = false } private class ShuffleBlockFetcherIterable( diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index c364c8d08db20..610c04ace3b6f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -22,9 +22,9 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import org.apache.spark._ -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents} import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents} import org.apache.spark.util.Utils 
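In the SortShuffleManager hunks that follow, the reader and all three writer paths now receive the ShuffleExecutorComponents instance directly instead of the removed reads()/writes() sub-interfaces, and the earlier config hunk renames the selector to spark.shuffle.sort.io.plugin.class with LocalDiskShuffleDataIO as the default. A minimal, hypothetical plugin wired through that setting could look like the sketch below; MyShuffleDataIO is an illustrative name, and the SparkConf constructor simply mirrors LocalDiskShuffleDataIO rather than a documented contract.

import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.api.ShuffleDataIO;
import org.apache.spark.shuffle.api.ShuffleDriverComponents;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents;
import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents;

// Hypothetical plugin that delegates to the local-disk components added by this patch.
public final class MyShuffleDataIO implements ShuffleDataIO {
  private final SparkConf sparkConf;

  public MyShuffleDataIO(SparkConf sparkConf) {
    this.sparkConf = sparkConf;
  }

  @Override
  public ShuffleDriverComponents driver() {
    return new LocalDiskShuffleDriverComponents();
  }

  @Override
  public ShuffleExecutorComponents executor() {
    // A real plugin would return its own executor components; reusing the
    // local-disk implementation keeps the sketch concrete.
    return new LocalDiskShuffleExecutorComponents(sparkConf);
  }
}

Selecting it would then be a one-line configuration change, for example conf.set("spark.shuffle.sort.io.plugin.class", MyShuffleDataIO.class.getName()).
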
/** @@ -130,7 +130,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition, context, metrics, - shuffleExecutorComponents.reads()) + shuffleExecutorComponents) } /** Get a writer for a given partition. Called on executors by map tasks. */ @@ -152,7 +152,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager context, env.conf, metrics, - shuffleExecutorComponents.writes()) + shuffleExecutorComponents) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, @@ -161,10 +161,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager context.taskAttemptId(), env.conf, metrics, - shuffleExecutorComponents.writes()) + shuffleExecutorComponents) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter( - shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents.writes()) + shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 26f3f2267d44d..0082b4c9c6b24 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -18,10 +18,10 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ -import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} +import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( @@ -29,7 +29,7 @@ private[spark] class SortShuffleWriter[K, V, C]( handle: BaseShuffleHandle[K, V, C], mapId: Int, context: TaskContext, - writeSupport: ShuffleWriteSupport) + shuffleExecutorComponents: ShuffleExecutorComponents) extends ShuffleWriter[K, V] with Logging { private val dep = handle.dependency @@ -64,11 +64,14 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). 
- val mapOutputWriter = writeSupport.createMapOutputWriter( + val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions) - val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - val location = mapOutputWriter.commitAllPartitions - mapStatus = MapStatus(location.orNull, partitionLengths, context.taskAttemptId()) + sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) + val commitMessage = mapOutputWriter.commitAllPartitions() + mapStatus = MapStatus( + commitMessage.getLocation.orElse(null), + commitMessage.getPartitionLengths, + context.taskAttemptId()) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index f9f4e3594e4f9..758621c52495b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -234,7 +234,7 @@ private[spark] class DiskBlockObjectWriter( /** * Writes a key-value pair. */ - def write(key: Any, value: Any) { + override def write(key: Any, value: Any) { if (!streamOpen) { open() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 14d34e1c47c8e..1216a45415a74 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,11 +26,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams import org.apache.spark._ -import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ +import org.apache.spark.shuffle.ShufflePartitionPairsWriter +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} +import org.apache.spark.util.{Utils => TryUtils} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -675,9 +677,9 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project. - * We should figure out an alternative way to test that so that we can remove this otherwise - * unused code path. + * TODO(SPARK-28764): remove this, as this is only used by UnsafeRowSerializerSuite in the SQL + * project. We should figure out an alternative way to test that so that we can remove this + * otherwise unused code path. 
*/ def writePartitionedFile( blockId: BlockId, @@ -728,9 +730,9 @@ private[spark] class ExternalSorter[K, V, C]( * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ def writePartitionedMapOutput( - shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = { - // Track location of each range in the map output - val lengths = new Array[Long](numPartitions) + shuffleId: Int, + mapId: Int, + mapOutputWriter: ShuffleMapOutputWriter): Unit = { if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer @@ -739,7 +741,7 @@ private[spark] class ExternalSorter[K, V, C]( val partitionId = it.nextPartition() var partitionWriter: ShufflePartitionWriter = null var partitionPairsWriter: ShufflePartitionPairsWriter = null - try { + TryUtils.tryWithSafeFinally { partitionWriter = mapOutputWriter.getPartitionWriter(partitionId) val blockId = ShuffleBlockId(shuffleId, mapId, partitionId) partitionPairsWriter = new ShufflePartitionPairsWriter( @@ -751,28 +753,19 @@ private[spark] class ExternalSorter[K, V, C]( while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(partitionPairsWriter) } - } finally { + } { if (partitionPairsWriter != null) { partitionPairsWriter.close() } } - if (partitionWriter != null) { - lengths(partitionId) = partitionWriter.getNumBytesWritten - } } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. for ((id, elements) <- this.partitionedIterator) { - // The contract for the plugin is that we will ask for a writer for every partition - // even if it's empty. However, the external sorter will return non-contiguous - // partition ids. So this loop "backfills" the empty partitions that form the gaps. - - // The algorithm as a whole is correct because the partition ids are returned by the - // iterator in ascending order. val blockId = ShuffleBlockId(shuffleId, mapId, id) var partitionWriter: ShufflePartitionWriter = null var partitionPairsWriter: ShufflePartitionPairsWriter = null - try { + TryUtils.tryWithSafeFinally { partitionWriter = mapOutputWriter.getPartitionWriter(id) partitionPairsWriter = new ShufflePartitionPairsWriter( partitionWriter, @@ -785,22 +778,17 @@ private[spark] class ExternalSorter[K, V, C]( partitionPairsWriter.write(elem._1, elem._2) } } - } finally { - if (partitionPairsWriter!= null) { + } { + if (partitionPairsWriter != null) { partitionPairsWriter.close() } } - if (partitionWriter != null) { - lengths(id) = partitionWriter.getNumBytesWritten - } } } context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - - lengths } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala index 9d7c209f242e1..05ed72c3e3778 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala @@ -17,6 +17,11 @@ package org.apache.spark.util.collection +/** + * An abstraction of a consumer of key-value pairs, primarily used when + * persisting partitioned data, either through the shuffle writer plugins + * or via DiskBlockObjectWriter. 
+ */ private[spark] trait PairsWriter { def write(key: Any, value: Any): Unit diff --git a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala deleted file mode 100644 index 8538a78b377c8..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.io.{Closeable, FilterOutputStream, OutputStream} - -import org.apache.spark.api.shuffle.ShufflePartitionWriter -import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter -import org.apache.spark.storage.BlockId - -/** - * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an - * arbitrary partition writer instead of writing to local disk through the block manager. - */ -private[spark] class ShufflePartitionPairsWriter( - partitionWriter: ShufflePartitionWriter, - serializerManager: SerializerManager, - serializerInstance: SerializerInstance, - blockId: BlockId, - writeMetrics: ShuffleWriteMetricsReporter) - extends PairsWriter with Closeable { - - private var isOpen = false - private var partitionStream: OutputStream = _ - private var wrappedStream: OutputStream = _ - private var objOut: SerializationStream = _ - private var numRecordsWritten = 0 - private var curNumBytesWritten = 0L - - override def write(key: Any, value: Any): Unit = { - if (!isOpen) { - open() - isOpen = true - } - objOut.writeKey(key) - objOut.writeValue(value) - writeMetrics.incRecordsWritten(1) - } - - private def open(): Unit = { - partitionStream = partitionWriter.openStream - wrappedStream = serializerManager.wrapStream(blockId, partitionStream) - objOut = serializerInstance.serializeStream(wrappedStream) - } - - override def close(): Unit = { - if (isOpen) { - objOut.close() - objOut = null - wrappedStream = null - partitionStream = null - isOpen = false - updateBytesWritten() - } - } - - /** - * Notify the writer that a record worth of bytes has been written with OutputStream#write. 
- */ - private def recordWritten(): Unit = { - numRecordsWritten += 1 - writeMetrics.incRecordsWritten(1) - - if (numRecordsWritten % 16384 == 0) { - updateBytesWritten() - } - } - - private def updateBytesWritten(): Unit = { - val numBytesWritten = partitionWriter.getNumBytesWritten - val bytesWrittenDiff = numBytesWritten - curNumBytesWritten - writeMetrics.incBytesWritten(bytesWrittenDiff) - curNumBytesWritten = numBytesWritten - } -} diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 4c2e6ac6474da..18f3a339e246c 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -38,6 +38,7 @@ import org.mockito.MockitoAnnotations; import org.apache.spark.HashPartitioner; +import org.apache.spark.MapOutputTracker; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; @@ -56,7 +57,7 @@ import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport; +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -87,6 +88,8 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; + @Mock(answer = RETURNS_SMART_NULLS) MapOutputTracker mapOutputTracker; + @Mock(answer = RETURNS_SMART_NULLS) SerializerManager serializerManager; @After public void tearDown() { @@ -138,8 +141,7 @@ public void setUp() throws IOException { }); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); - - Answer renameTempAnswer = invocationOnMock -> { + Answer renameTempAnswer = invocationOnMock -> { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; File tmp = (File) invocationOnMock.getArguments()[3]; if (!mergedOutputFile.delete()) { @@ -172,23 +174,25 @@ public void setUp() throws IOException { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); - - TaskContext$.MODULE$.setTaskContext(taskContext); } - private UnsafeShuffleWriter createWriter( - boolean transferToEnabled) throws IOException { + private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); - return new UnsafeShuffleWriter<>( + return new UnsafeShuffleWriter( blockManager, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId()) - ); + new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + shuffleBlockResolver, + BlockManagerId.apply("localhost", 7077))); } private void assertSpillFilesWereCleanedUp() { @@ -414,7 +418,7 @@ public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Except @Test public 
void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception { - conf.set(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE(), false); + conf.set(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE(), false); testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); } @@ -539,16 +543,21 @@ public void testPeakMemoryUsed() throws Exception { final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; taskMemoryManager = spy(taskMemoryManager); when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); - final UnsafeShuffleWriter writer = - new UnsafeShuffleWriter<>( + final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( blockManager, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId())); + new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + shuffleBlockResolver, + BlockManagerId.apply("localhost", 7077))); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 1cd7296e9de53..6eb8251ec4002 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -383,14 +383,15 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // simultaneously, and everything is still OK def writeAndClose( - writer: ShuffleWriter[Int, Int], - taskContext: TaskContext)( - iter: Iterator[(Int, Int)]): Option[MapStatus] = { - TaskContext.setTaskContext(taskContext) - val files = writer.write(iter) - val status = writer.stop(true) - TaskContext.unset - status + writer: ShuffleWriter[Int, Int], + taskContext: TaskContext)( + iter: Iterator[(Int, Int)]): Option[MapStatus] = { + try { + val files = writer.write(iter) + writer.stop(true) + } finally { + TaskContext.unset() + } } val interleaver = new InterleaveIterators( data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2)) @@ -413,6 +414,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) val readData = reader.read().toIndexedSeq + TaskContext.unset() assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) manager.unregisterShuffle(0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala index 68bc5c2961e2d..9d3a52a237cbe 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala @@ -19,18 +19,18 @@ package org.apache.spark.scheduler import java.util import org.apache.spark.{FetchFailed, HashPartitioner, ShuffleDependency, SparkConf, Success} -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents} import org.apache.spark.internal.config import org.apache.spark.rdd.RDD -import 
org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents} +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO import org.apache.spark.storage.BlockManagerId class PluginShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { - val defaultShuffleDataIO = new DefaultShuffleDataIO(sparkConf) + val localDiskShuffleDataIO = new LocalDiskShuffleDataIO(sparkConf) override def driver(): ShuffleDriverComponents = - new PluginShuffleDriverComponents(defaultShuffleDataIO.driver()) + new PluginShuffleDriverComponents(localDiskShuffleDataIO.driver()) - override def executor(): ShuffleExecutorComponents = defaultShuffleDataIO.executor() + override def executor(): ShuffleExecutorComponents = localDiskShuffleDataIO.executor() } class PluginShuffleDriverComponents(delegate: ShuffleDriverComponents) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 64f8cbc970d54..966a6fa9d005f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer -import org.mockito.Mockito.{mock, when} +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -29,7 +31,7 @@ import org.apache.spark.internal.config import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId} /** @@ -59,11 +61,14 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + /** * This test makes sure that, when data is read from a HashShuffleReader, the underlying * ManagedBuffers that contain the data are eventually released. */ test("read() releases resources on completion") { + MockitoAnnotations.initMocks(this) val testConf = new SparkConf(false) // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the // shuffle code calls SparkEnv.get()). 
@@ -142,15 +147,21 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val shuffleReadSupport = - new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, testConf) + val shuffleExecutorComponents = + new LocalDiskShuffleExecutorComponents( + testConf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + localBlockManagerId) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, taskContext, metrics, - shuffleReadSupport, + shuffleExecutorComponents, serializerManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala index 0abfa4d8d8413..b571565cf4336 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.shuffle -import java.util +import java.io.InputStream +import java.lang.{Iterable => JIterable} +import java.util.{Map => JMap} import com.google.common.collect.ImmutableMap -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleReadSupport, ShuffleWriteSupport} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.shuffle.api.{ShuffleBlockInfo, ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleMapOutputWriter} +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext { test(s"test serialization of shuffle initialization conf to executors") { @@ -43,7 +44,7 @@ class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext } class TestShuffleDriverComponents extends ShuffleDriverComponents { - override def initializeApplication(): util.Map[String, String] = + override def initializeApplication(): JMap[String, String] = ImmutableMap.of("test-key", "test-value") } @@ -55,21 +56,29 @@ class TestShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { } class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents { - override def initializeExecutor(appId: String, execId: String, - extraConfigs: util.Map[String, String]): Unit = { + + private var delegate = new LocalDiskShuffleExecutorComponents(sparkConf) + + override def initializeExecutor( + appId: String, execId: String, extraConfigs: JMap[String, String]): Unit = { assert(extraConfigs.get("test-key") == "test-value") + delegate.initializeExecutor(appId, execId, extraConfigs) + } + + override def createMapOutputWriter( + shuffleId: Int, + mapId: Int, + mapTaskAttemptId: Long, + numPartitions: Int): ShuffleMapOutputWriter = { + delegate.createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions) } - override def writes(): ShuffleWriteSupport = { - val blockManager = SparkEnv.get.blockManager - val blockResolver = new 
IndexShuffleBlockResolver(sparkConf, blockManager) - new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId) + override def getPartitionReaders( + blockMetadata: JIterable[ShuffleBlockInfo]): JIterable[InputStream] = { + delegate.getPartitionReaders(blockMetadata) } - override def reads(): ShuffleReadSupport = { - val blockManager = SparkEnv.get.blockManager - val mapOutputTracker = SparkEnv.get.mapOutputTracker - val serializerManager = SparkEnv.get.serializerManager - new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, sparkConf) + override def shouldWrapPartitionReaderStream(): Boolean = { + delegate.shouldWrapPartitionReaderStream() } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 0a77b9f0686ac..a8246aca20baa 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -39,8 +39,9 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer.{KryoSerializer, SerializerManager} -import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException} -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException, IndexShuffleBlockResolver} +import org.apache.spark.shuffle.io.LocalDiskShuffleReadSupport +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster, ShuffleBlockAttemptId, ShuffleBlockId} import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener, Utils} @@ -67,6 +68,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { // this is only used when initiating the BlockManager, for comms between master and executor @Mock(answer = RETURNS_SMART_NULLS) private var rpcEnv: RpcEnv = _ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ private var tempDir: File = _ @@ -212,11 +214,13 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(sorter) - val readSupport = new DefaultShuffleReadSupport( + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + defaultConf, blockManager, mapOutputTracker, serializerManager, - defaultConf) + blockResolver, + blockManager.shuffleServerId) new BlockStoreShuffleReader[String, String]( shuffleHandle, @@ -224,7 +228,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { 1, taskContext, taskContext.taskMetrics().createTempShuffleReadMetrics(), - readSupport, + shuffleExecutorComponents, serializerManager, mapOutputTracker ) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index dbcf09400c97e..46888259206a9 100644 --- 
a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -19,8 +19,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. @@ -49,9 +48,13 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) conf.set("spark.shuffle.file.buffer", "32k") - val shuffleWriteSupport = - new DefaultShuffleWriteSupport( - conf, blockResolver, BlockManagerId("0", "localhost", 7077, None)) + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + blockManager.shuffleServerId) val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( blockManager, @@ -60,7 +63,7 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase taskContext.taskAttemptId(), conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleWriteSupport + shuffleExecutorComponents ) shuffleWriter diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index bd241b5ebfaef..da1630e67a485 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.shuffle.sort import java.io.File -import java.util.{Properties, UUID} +import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -30,15 +31,14 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach -import scala.util.Random import org.apache.spark._ -import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -49,11 +49,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ + @Mock(answer = RETURNS_SMART_NULLS) private var serializerManager: 
SerializerManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var mapOutputTracker: MapOutputTracker = _ private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ private var outputFile: File = _ - private var writeSupport: ShuffleWriteSupport = _ + private var shuffleExecutorComponents: ShuffleExecutorComponents = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) .set("spark.app.id", "sampleApp") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() @@ -62,39 +64,42 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte override def beforeEach(): Unit = { super.beforeEach() + MockitoAnnotations.initMocks(this) tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) taskMetrics = new TaskMetrics - MockitoAnnotations.initMocks(this) shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int]( shuffleId = 0, numMaps = 2, dependency = dependency ) + val memoryManager = new TestMemoryManager(conf) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) when(dependency.partitioner).thenReturn(new HashPartitioner(7)) when(dependency.serializer).thenReturn(new JavaSerializer(conf)) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) - doAnswer(new Answer[Void] { - def answer(invocationOnMock: InvocationOnMock): Void = { - val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + + when(blockResolver.writeIndexFileAndCommit( + anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + .thenAnswer { (invocationOnMock: InvocationOnMock) => + val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] if (tmp != null) { outputFile.delete tmp.renameTo(outputFile) } null } - }).when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) - when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(blockManager.getDiskWriter( any[BlockId], any[File], any[SerializerInstance], anyInt(), - any[ShuffleWriteMetrics] - )).thenAnswer(new Answer[DiskBlockObjectWriter] { - override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { + any[ShuffleWriteMetrics])) + .thenAnswer { (invocation: InvocationOnMock) => val args = invocation.getArguments val manager = new SerializerManager(new JavaSerializer(conf), conf) new DiskBlockObjectWriter( @@ -104,44 +109,29 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte args(3).asInstanceOf[Int], syncWrites = false, args(4).asInstanceOf[ShuffleWriteMetrics], - blockId = args(0).asInstanceOf[BlockId] - ) + blockId = args(0).asInstanceOf[BlockId]) } - }) - when(diskBlockManager.createTempShuffleBlock()).thenAnswer( - new Answer[(TempShuffleBlockId, File)] { - override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = { - val blockId = new TempShuffleBlockId(UUID.randomUUID) - val file = new File(tempDir, blockId.name) - blockIdToFileMap.put(blockId, file) - temporaryFilesCreated += file - (blockId, file) - } - }) - when(diskBlockManager.getFile(any[BlockId])).thenAnswer( - new Answer[File] { - override def answer(invocation: InvocationOnMock): File = { - blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get - } - }) - val memoryManager = new TestMemoryManager(conf) - 
val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) - when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + when(diskBlockManager.createTempShuffleBlock()) + .thenAnswer { (invocationOnMock: InvocationOnMock) => + val blockId = new TempShuffleBlockId(UUID.randomUUID) + val file = new File(tempDir, blockId.name) + blockIdToFileMap.put(blockId, file) + temporaryFilesCreated += file + (blockId, file) + } - TaskContext.setTaskContext(new TaskContextImpl( - stageId = 0, - stageAttemptNumber = 0, - partitionId = 0, - taskAttemptId = Random.nextInt(10000), - attemptNumber = 0, - taskMemoryManager = taskMemoryManager, - localProperties = new Properties, - metricsSystem = null, - taskMetrics = taskMetrics)) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) => + blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId]) + } - writeSupport = - new DefaultShuffleWriteSupport(conf, blockResolver, blockManager.shuffleServerId) + shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + BlockManagerId("localhost", 7077)) } override def afterEach(): Unit = { @@ -160,11 +150,11 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - taskContext.taskAttemptId(), + 0L, // MapTaskAttemptId conf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport - ) + shuffleExecutorComponents) + writer.write(Iterator.empty) writer.stop( /* success = */ true) assert(writer.getPartitionLengths.sum === 0) @@ -178,55 +168,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(taskMetrics.memoryBytesSpilled === 0) } - test("write with some empty partitions") { - val transferConf = conf.clone.set("spark.file.transferTo", "false") - def records: Iterator[(Int, Int)] = - Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - val writer = new BypassMergeSortShuffleWriter[Int, Int]( - blockManager, - shuffleHandle, - 0, // MapId - taskContext.taskAttemptId(), - transferConf, - taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport - ) - writer.write(records) - writer.stop( /* success = */ true) - assert(temporaryFilesCreated.nonEmpty) - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files - assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - } - - test("write with some empty partitions with transferTo") { - def records: Iterator[(Int, Int)] = - Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - val writer = new BypassMergeSortShuffleWriter[Int, Int]( - blockManager, - shuffleHandle, - 0, // MapId - taskContext.taskAttemptId(), - conf, - taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport - ) - writer.write(records) - writer.stop( /* success = */ true) - assert(temporaryFilesCreated.nonEmpty) - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files - 
assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) + Seq(true, false).foreach { transferTo => + test(s"write with some empty partitions - transferTo $transferTo") { + val transferConf = conf.clone.set("spark.file.transferTo", transferTo.toString) + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + shuffleHandle, + 0, // MapId + 0L, + transferConf, + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleExecutorComponents) + writer.write(records) + writer.stop( /* success = */ true) + assert(temporaryFilesCreated.nonEmpty) + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } } test("only generate temp shuffle file for non-empty partition") { @@ -247,11 +213,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - taskContext.taskAttemptId(), + 0L, conf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport - ) + shuffleExecutorComponents) intercept[SparkException] { writer.write(records) @@ -270,11 +235,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - taskContext.taskAttemptId(), + 0L, conf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport - ) + shuffleExecutorComponents) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { if (i == 99990) { @@ -287,4 +251,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } + + /** + * This won't be necessary with Scala 2.12 + */ + private implicit def functionToAnswer[T](func: InvocationOnMock => T): Answer[T] = { + new Answer[T] { + override def answer(invocationOnMock: InvocationOnMock): T = func(invocationOnMock) + } + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala index 26b92e5203b50..6decc9d4e2c84 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala @@ -29,7 +29,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random -import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkConf, TaskContext} +import org.apache.spark.{HashPartitioner, MapOutputTracker, ShuffleDependency, SparkConf, TaskContext} import org.apache.spark.benchmark.{Benchmark, 
BenchmarkBase} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} @@ -50,9 +50,11 @@ abstract class ShuffleWriterBenchmarkBase extends BenchmarkBase { @Mock(answer = RETURNS_SMART_NULLS) protected var taskContext: TaskContext = _ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEnv: RpcEnv = _ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _ + // only used to retrieve info about the maps at the beginning, doesn't affect perf + @Mock(answer = RETURNS_SMART_NULLS) protected var mapOutputTracker: MapOutputTracker = _ protected val defaultConf: SparkConf = new SparkConf(loadDefaults = false) - protected val serializer: Serializer = new KryoSerializer(defaultConf) + protected val serializer: Serializer = new KryoSerializer(defaultConf) protected val partitioner: HashPartitioner = new HashPartitioner(10) protected val serializerManager: SerializerManager = new SerializerManager(serializer, defaultConf) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index 7e7a86b3e6b2a..2cb53e4bac224 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -22,8 +22,7 @@ import org.mockito.Mockito.when import org.apache.spark.{Aggregator, SparkEnv, TaskContext} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.BaseShuffleHandle -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. @@ -78,16 +77,20 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(taskContext) - val writeSupport = - new DefaultShuffleWriteSupport( - defaultConf, blockResolver, BlockManagerId("0", "localhost", 7077, None)) + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + defaultConf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + blockManager.shuffleServerId) val shuffleWriter = new SortShuffleWriter[String, String, String]( blockResolver, shuffleHandle, 0, taskContext, - writeSupport) + shuffleExecutorComponents) shuffleWriter } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala new file mode 100644 index 0000000000000..326831749ce09 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Mockito._ +import org.scalatest.Matchers + +import org.apache.spark.{MapOutputTracker, Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite} +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} +import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver} +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents +import org.apache.spark.storage.{BlockManager, BlockManagerId} +import org.apache.spark.util.Utils + + +class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with Matchers { + + @Mock(answer = RETURNS_SMART_NULLS) + private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) + private var mapOutputTracker: MapOutputTracker = _ + @Mock(answer = RETURNS_SMART_NULLS) + private var serializerManager: SerializerManager = _ + + private val shuffleId = 0 + private val numMaps = 5 + private var shuffleHandle: BaseShuffleHandle[Int, Int, Int] = _ + private val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + private val serializer = new JavaSerializer(conf) + private var shuffleExecutorComponents: ShuffleExecutorComponents = _ + + override def beforeEach(): Unit = { + super.beforeEach() + MockitoAnnotations.initMocks(this) + val partitioner = new Partitioner() { + def numPartitions = numMaps + def getPartition(key: Any) = Utils.nonNegativeMod(key.hashCode, numPartitions) + } + shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.partitioner).thenReturn(partitioner) + when(dependency.serializer).thenReturn(serializer) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps = numMaps, dependency) + } + shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + shuffleBlockResolver, + BlockManagerId("localhost", 7077)) + } + + override def afterAll(): Unit = { + try { + shuffleBlockResolver.stop() + } finally { + super.afterAll() + } + } + + test("write empty iterator") { + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val writer = new SortShuffleWriter[Int, Int, Int]( + shuffleBlockResolver, + shuffleHandle, + mapId = 1, + context, + shuffleExecutorComponents) + writer.write(Iterator.empty) + writer.stop(success = true) + val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 1) + val writeMetrics = context.taskMetrics().shuffleWriteMetrics + assert(!dataFile.exists()) + assert(writeMetrics.bytesWritten === 0) + assert(writeMetrics.recordsWritten === 0) + } + + test("write with some records") { + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val records = List[(Int, Int)]((1, 2), (2, 3), (4, 4), (6, 5)) + val writer = new SortShuffleWriter[Int, Int, 
Int]( + shuffleBlockResolver, + shuffleHandle, + mapId = 2, + context, + shuffleExecutorComponents) + writer.write(records.toIterator) + writer.stop(success = true) + val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 2) + val writeMetrics = context.taskMetrics().shuffleWriteMetrics + assert(dataFile.exists()) + assert(dataFile.length() === writeMetrics.bytesWritten) + assert(records.size === writeMetrics.recordsWritten) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala index b09ccb334e4f1..d012bda0ffede 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -18,8 +18,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.benchmark.Benchmark -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. @@ -44,9 +43,13 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) - val shuffleWriteSupport = - new DefaultShuffleWriteSupport( - conf, blockResolver, BlockManagerId("0", "localhost", 7077, None)) + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + blockManager.shuffleServerId) TaskContext.setTaskContext(taskContext) new UnsafeShuffleWriter[String, String]( @@ -57,7 +60,7 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleWriteSupport) + shuffleExecutorComponents) } def writeBenchmarkWithSmallDataset(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala deleted file mode 100644 index 3ccb549912782..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort.io - -import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} -import java.math.BigInteger -import java.nio.ByteBuffer -import java.nio.channels.{Channels, WritableByteChannel} - -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} -import org.mockito.Mock -import org.mockito.Mockito.{doAnswer, doNothing, when} -import org.mockito.MockitoAnnotations -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.api.shuffle.SupportsTransferTo -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.shuffle.IndexShuffleBlockResolver -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{ByteBufferInputStream, Utils} - -class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { - - @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ - @Mock(answer = RETURNS_SMART_NULLS) private var shuffleWriteMetrics: ShuffleWriteMetrics = _ - - private val NUM_PARTITIONS = 4 - private val D_LEN = 10 - private val data: Array[Array[Int]] = (0 until NUM_PARTITIONS).map { - p => (1 to D_LEN).map(_ + p).toArray }.toArray - - private var tempFile: File = _ - private var mergedOutputFile: File = _ - private var tempDir: File = _ - private var partitionSizesInMergedFile: Array[Long] = _ - private var conf: SparkConf = _ - private var mapOutputWriter: DefaultShuffleMapOutputWriter = _ - - override def afterEach(): Unit = { - try { - Utils.deleteRecursively(tempDir) - } finally { - super.afterEach() - } - } - - override def beforeEach(): Unit = { - MockitoAnnotations.initMocks(this) - tempDir = Utils.createTempDir(null, "test") - mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) - tempFile = File.createTempFile("tempfile", "", tempDir) - partitionSizesInMergedFile = null - conf = new SparkConf() - .set("spark.app.id", "example.spark.app") - .set("spark.shuffle.unsafe.file.output.buffer", "16k") - when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) - - doNothing().when(shuffleWriteMetrics).incWriteTime(anyLong) - - doAnswer(new Answer[Void] { - def answer(invocationOnMock: InvocationOnMock): Void = { - partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] - val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] - if (tmp != null) { - mergedOutputFile.delete - tmp.renameTo(mergedOutputFile) - } - null - } - }).when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) - mapOutputWriter = new DefaultShuffleMapOutputWriter( - 0, - 0, - NUM_PARTITIONS, - BlockManagerId("0", "localhost", 9099), - shuffleWriteMetrics, - blockResolver, - conf) - } - - private def readRecordsFromFile(fromByte: Boolean): Array[Array[Int]] = { - var startOffset = 0L - val result = new Array[Array[Int]](NUM_PARTITIONS) - (0 until NUM_PARTITIONS).foreach { p => - val partitionSize = partitionSizesInMergedFile(p).toInt - lazy val inner = new Array[Int](partitionSize) - lazy val innerBytebuffer = ByteBuffer.allocate(partitionSize) - if (partitionSize > 0) { - val in = new FileInputStream(mergedOutputFile) - in.getChannel.position(startOffset) - val lin = new 
LimitedInputStream(in, partitionSize) - var nonEmpty = true - var count = 0 - while (nonEmpty) { - try { - val readBit = lin.read() - if (fromByte) { - innerBytebuffer.put(readBit.toByte) - } else { - inner(count) = readBit - } - count += 1 - } catch { - case _: Exception => - nonEmpty = false - } - } - in.close() - } - if (fromByte) { - result(p) = innerBytebuffer.array().sliding(4, 4).map { b => - new BigInteger(b).intValue() - }.toArray - } else { - result(p) = inner - } - startOffset += partitionSize - } - result - } - - test("writing to an outputstream") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.openStream() - data(p).foreach { i => stream.write(i)} - stream.close() - intercept[IllegalStateException] { - stream.write(p) - } - assert(writer.getNumBytesWritten() == D_LEN) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(false)) - } - - test("writing to a channel") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() - val byteBuffer = ByteBuffer.allocate(D_LEN * 4) - val intBuffer = byteBuffer.asIntBuffer() - intBuffer.put(data(p)) - val numBytes = byteBuffer.remaining() - val outputTempFile = File.createTempFile("channelTemp", "", tempDir) - val outputTempFileStream = new FileOutputStream(outputTempFile) - Utils.copyStream( - new ByteBufferInputStream(byteBuffer), - outputTempFileStream, - closeStreams = true) - val tempFileInput = new FileInputStream(outputTempFile) - channel.transferFrom(tempFileInput.getChannel, 0L, numBytes) - // Bytes require * 4 - channel.close() - tempFileInput.close() - assert(writer.getNumBytesWritten == D_LEN * 4) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(true)) - } - - test("copyStreams with an outputstream") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.openStream() - val byteBuffer = ByteBuffer.allocate(D_LEN * 4) - val intBuffer = byteBuffer.asIntBuffer() - intBuffer.put(data(p)) - val in = new ByteArrayInputStream(byteBuffer.array()) - Utils.copyStream(in, stream, false, false) - in.close() - stream.close() - assert(writer.getNumBytesWritten == D_LEN * 4) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(true)) - } - - test("copyStreamsWithNIO with a channel") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() - val byteBuffer = ByteBuffer.allocate(D_LEN * 4) - val intBuffer = byteBuffer.asIntBuffer() - intBuffer.put(data(p)) - val out = new FileOutputStream(tempFile) - out.write(byteBuffer.array()) - out.close() - val in = new 
FileInputStream(tempFile) - channel.transferFrom(in.getChannel, 0L, byteBuffer.remaining()) - channel.close() - assert(writer.getNumBytesWritten == D_LEN * 4) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(true)) - } -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala new file mode 100644 index 0000000000000..8aa9f51e09494 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io + +import java.io.{File, FileInputStream} +import java.nio.channels.FileChannel +import java.nio.file.Files +import java.util.Arrays + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.{any, anyInt} +import org.mockito.Mock +import org.mockito.Mockito.when +import org.mockito.MockitoAnnotations +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils + +class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) + private var blockResolver: IndexShuffleBlockResolver = _ + + private val NUM_PARTITIONS = 4 + private val BLOCK_MANAGER_ID = BlockManagerId("localhost", 7077) + private val data: Array[Array[Byte]] = (0 until NUM_PARTITIONS).map { p => + if (p == 3) { + Array.emptyByteArray + } else { + (0 to p * 10).map(_ + p).map(_.toByte).toArray + } + }.toArray + + private val partitionLengths = data.map(_.length) + + private var tempFile: File = _ + private var mergedOutputFile: File = _ + private var tempDir: File = _ + private var partitionSizesInMergedFile: Array[Long] = _ + private var conf: SparkConf = _ + private var mapOutputWriter: LocalDiskShuffleMapOutputWriter = _ + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + override def beforeEach(): Unit = { + MockitoAnnotations.initMocks(this) + tempDir = Utils.createTempDir() + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) + tempFile = File.createTempFile("tempfile", "", tempDir) + 
partitionSizesInMergedFile = null + conf = new SparkConf() + .set("spark.app.id", "example.spark.app") + .set("spark.shuffle.unsafe.file.output.buffer", "16k") + when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) + when(blockResolver.writeIndexFileAndCommit( + anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + .thenAnswer { (invocationOnMock: InvocationOnMock) => + partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + mergedOutputFile.delete() + tmp.renameTo(mergedOutputFile) + } + null + } + mapOutputWriter = new LocalDiskShuffleMapOutputWriter( + 0, + 0, + NUM_PARTITIONS, + blockResolver, + BLOCK_MANAGER_ID, + conf) + } + + test("writing to an outputstream") { + (0 until NUM_PARTITIONS).foreach { p => + val writer = mapOutputWriter.getPartitionWriter(p) + val stream = writer.openStream() + data(p).foreach { i => stream.write(i) } + stream.close() + intercept[IllegalStateException] { + stream.write(p) + } + } + verifyWrittenRecords() + } + + test("writing to a channel") { + (0 until NUM_PARTITIONS).foreach { p => + val writer = mapOutputWriter.getPartitionWriter(p) + val outputTempFile = File.createTempFile("channelTemp", "", tempDir) + Files.write(outputTempFile.toPath, data(p)) + val tempFileInput = new FileInputStream(outputTempFile) + val channel = writer.openChannelWrapper() + Utils.tryWithResource(new FileInputStream(outputTempFile)) { tempFileInput => + Utils.tryWithResource(writer.openChannelWrapper().get) { channelWrapper => + assert(channelWrapper.channel().isInstanceOf[FileChannel], + "Underlying channel should be a file channel") + Utils.copyFileStreamNIO( + tempFileInput.getChannel, channelWrapper.channel(), 0L, data(p).length) + } + } + } + verifyWrittenRecords() + } + + private def readRecordsFromFile() = { + val mergedOutputBytes = Files.readAllBytes(mergedOutputFile.toPath) + val result = (0 until NUM_PARTITIONS).map { part => + val startOffset = data.slice(0, part).map(_.length).sum + val partitionSize = data(part).length + Arrays.copyOfRange(mergedOutputBytes, startOffset, startOffset + partitionSize) + }.toArray + result + } + + private def verifyWrittenRecords(): Unit = { + val committedLengths = mapOutputWriter.commitAllPartitions() + assert(partitionSizesInMergedFile === partitionLengths) + assert(committedLengths.getPartitionLengths === partitionLengths) + assert(committedLengths.getLocation.isPresent) + assert(committedLengths.getLocation.get === BLOCK_MANAGER_ID) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile()) + } + + /** + * This won't be necessary with Scala 2.12 + */ + private implicit def functionToAnswer[T](func: InvocationOnMock => T): Answer[T] = { + new Answer[T] { + override def answer(invocationOnMock: InvocationOnMock): T = func(invocationOnMock) + } + } +}
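
The hunks above all follow the same pattern: callers stop tracking partition lengths themselves and instead read them from the single commit message returned by `commitAllPartitions()` (see the `SortShuffleWriter` and `ExternalSorter` changes at the top of this patch). Below is a minimal, self-contained Scala sketch of that flow. It uses simplified stand-ins (the `*Sketch` names are hypothetical and are NOT the real `org.apache.spark.shuffle.api` interfaces, which also carry shuffle/map ids, an optional `BlockManagerId` location, channel-based writing, and exception handling); it only illustrates the contract that the plugin tracks lengths and the caller consumes them from the commit message.

```scala
// Sketch of the post-SPARK-28607 commit flow, under the assumptions stated above.
import java.io.ByteArrayOutputStream

// Stand-in for ShufflePartitionWriter (simplified). The caller no longer reads
// numBytesWritten back to build its own lengths array; the writer implementation
// tracks lengths and reports them in the commit message instead.
trait PartitionWriterSketch {
  def openStream(): ByteArrayOutputStream
  def numBytesWritten: Long
}

// Stand-in for MapOutputWriterCommitMessage (location omitted for brevity).
final case class CommitMessageSketch(partitionLengths: Array[Long])

// Stand-in for ShuffleMapOutputWriter: keeps one buffer per partition and
// derives the lengths itself when asked to commit.
final class MapOutputWriterSketch(numPartitions: Int) {
  private val streams = Array.fill(numPartitions)(new ByteArrayOutputStream())

  def getPartitionWriter(partitionId: Int): PartitionWriterSketch =
    new PartitionWriterSketch {
      override def openStream(): ByteArrayOutputStream = streams(partitionId)
      override def numBytesWritten: Long = streams(partitionId).size().toLong
    }

  def commitAllPartitions(): CommitMessageSketch =
    CommitMessageSketch(streams.map(_.size().toLong))
}

object CommitFlowExample {
  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    val mapOutputWriter = new MapOutputWriterSketch(numPartitions)
    val recordsByPartition = Map(0 -> "aa", 2 -> "cccc") // partitions 1 and 3 are empty

    // Contract mirrored from the ExternalSorter change: ask for a writer for every
    // partition id, in ascending order, even when a partition has no records.
    (0 until numPartitions).foreach { p =>
      val out = mapOutputWriter.getPartitionWriter(p).openStream()
      recordsByPartition.get(p).foreach(s => out.write(s.getBytes("UTF-8")))
    }

    // The caller builds its map status from the commit message rather than from
    // lengths it tracked on the side.
    val commit = mapOutputWriter.commitAllPartitions()
    println(commit.partitionLengths.mkString("lengths = [", ", ", "]")) // lengths = [2, 0, 4, 0]
  }
}
```

The design choice this sketch is meant to highlight: because only the plugin records byte counts, lengths are stored once (fixing the double-tracking regression), and alternative backends can report a location other than the local block manager through the same commit message.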