From c6cbb06a0f13cce0a13bb34c0d07d50e7a4ffe77 Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Thu, 29 Aug 2019 22:09:58 +0800
Subject: [PATCH] Still need mapId for the fetch fail scenario
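
ShuffleBlockId now carries the map task attempt id in its second field, so
the map index can no longer be recovered from the block id alone. Thread the
map index ("mapId") through with each block so that the fetch fail path can
still report the correct map output via FetchFailedException. Illustration
only (the real signatures are in the diff below):

    Iterator[(BlockManagerId, Seq[(BlockId, Long)])]       // before: id, size
    Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]  // after: id, size, mapId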
---
 .../org/apache/spark/MapOutputTracker.scala   | 24 +++---
 .../storage/ShuffleBlockFetcherIterator.scala | 85 +++++++++++--------
 .../apache/spark/MapOutputTrackerSuite.scala  | 16 ++--
 .../BlockStoreShuffleReaderSuite.scala        |  2 +-
 .../ShuffleBlockFetcherIteratorSuite.scala    | 54 ++++++------
 5 files changed, 99 insertions(+), 82 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index ce8d43f07dace..176733b43af4e 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
 
   // For testing
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
   }
 
@@ -292,11 +292,11 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
    * endPartition is excluded from the range).
    *
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
-   *         and the second item is a sequence of (shuffle block id, shuffle block size) tuples
-   *         describing the shuffle blocks that are stored at that block manager.
+   *         and the second item is a sequence of (shuffle block id, shuffle block size, map id)
+   *         tuples describing the shuffle blocks that are stored at that block manager.
    */
   def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long)])]
+      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
 
   /**
    * Deletes map output status information for the specified shuffle stage.
@@ -646,7 +646,7 @@ private[spark] class MapOutputTrackerMaster(
   // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
   // This method is only called in local-mode.
   def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     shuffleStatuses.get(shuffleId) match {
       case Some (shuffleStatus) =>
@@ -683,7 +683,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
 
   // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
   override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     val statuses = getStatuses(shuffleId)
     try {
@@ -864,17 +864,17 @@ private[spark] object MapOutputTracker extends Logging {
    * @param endPartition End of map output partition ID range (excluded from range)
    * @param statuses List of map statuses, indexed by map ID.
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
-   *         and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
-   *         describing the shuffle blocks that are stored at that block manager.
+   *         and the second item is a sequence of (shuffle block id, shuffle block size, map id)
+   *         tuples describing the shuffle blocks that are stored at that block manager.
    */
   def convertMapStatuses(
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+      statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     assert (statuses != null)
-    val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]]
-    statuses.foreach { status =>
+    val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
+    for ((status, mapId) <- statuses.iterator.zipWithIndex) {
       if (status == null) {
         val errorMessage = s"Missing an output location for shuffle $shuffleId"
         logError(errorMessage)
@@ -884,7 +884,7 @@ private[spark] object MapOutputTracker extends Logging {
         val size = status.getSizeForBlock(part)
         if (size != 0) {
           splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
-            ((ShuffleBlockId(shuffleId, status.mapTaskAttemptId, part), size))
+            ((ShuffleBlockId(shuffleId, status.mapTaskAttemptId, part), size, mapId))
         }
       }
     }
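
[Annotation, not part of the patch: convertMapStatuses derives mapId from each
status's position in the statuses array, which is indexed by map partition. A
minimal runnable sketch of the indexing scheme, with strings standing in for
MapStatus:

    val statuses = Array("status-of-map-0", "status-of-map-1")
    for ((status, mapId) <- statuses.iterator.zipWithIndex) {
      // every block emitted for `status` carries mapId = its array index
      println(s"$status -> mapId=$mapId")
    }
]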
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index a5b7ee5762c49..9199f158c665d 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -48,9 +48,10 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
 * @param shuffleClient [[BlockStoreClient]] for fetching remote blocks
 * @param blockManager [[BlockManager]] for reading local blocks
 * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
- *                        For each block we also require the size (in bytes as a long field) in
- *                        order to throttle the memory usage. Note that zero-sized blocks are
- *                        already excluded, which happened in
+ *                        For each block we also require two pieces of info: 1. the size (in
+ *                        bytes as a long field) to throttle the memory usage; 2. the mapId for
+ *                        this block, which indicates the block's index in the map stage.
+ *                        Note that zero-sized blocks are already excluded, which happened in
 *                        [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
 * @param streamWrapper A function to wrap the returned input stream.
 * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
@@ -66,7 +67,7 @@ final class ShuffleBlockFetcherIterator(
     context: TaskContext,
     shuffleClient: BlockStoreClient,
     blockManager: BlockManager,
-    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
+    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
     streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
     maxReqsInFlight: Int,
@@ -96,7 +97,7 @@ final class ShuffleBlockFetcherIterator(
   private[this] val startTimeNs = System.nanoTime()
 
   /** Local blocks to fetch, excluding zero-sized blocks. */
-  private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]()
+  private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
 
   /** Remote blocks to fetch, excluding zero-sized blocks. */
   private[this] val remoteBlocks = new HashSet[BlockId]()
@@ -198,7 +199,7 @@ final class ShuffleBlockFetcherIterator(
     while (iter.hasNext) {
       val result = iter.next()
       result match {
-        case SuccessFetchResult(_, address, _, buf, _) =>
+        case SuccessFetchResult(_, _, address, _, buf, _) =>
           if (address != blockManager.blockManagerId) {
             shuffleMetrics.incRemoteBytesRead(buf.size)
             if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
@@ -223,9 +224,11 @@ final class ShuffleBlockFetcherIterator(
     bytesInFlight += req.size
     reqsInFlight += 1
 
-    // so we can look up the size of each blockID
-    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
-    val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
+    // so we can look up the block info of each blockID
+    val infoMap = req.blocks.map {
+      case (blockId, size, mapId) => (blockId.toString, (size, mapId))
+    }.toMap
+    val remainingBlocks = new HashSet[String]() ++= infoMap.keys
     val blockIds = req.blocks.map(_._1.toString)
     val address = req.address
 
@@ -239,8 +242,8 @@ final class ShuffleBlockFetcherIterator(
           // This needs to be released after use.
           buf.retain()
           remainingBlocks -= blockId
-          results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
-            remainingBlocks.isEmpty))
+          results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
+            address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
           logDebug("remainingBlocks: " + remainingBlocks)
         }
       }
@@ -249,7 +252,7 @@ final class ShuffleBlockFetcherIterator(
 
       override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
         logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
-        results.put(new FailureFetchResult(BlockId(blockId), address, e))
+        results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
       }
     }
 
@@ -282,28 +285,28 @@ final class ShuffleBlockFetcherIterator(
     for ((address, blockInfos) <- blocksByAddress) {
       if (address.executorId == blockManager.blockManagerId.executorId) {
         blockInfos.find(_._2 <= 0) match {
-          case Some((blockId, size)) if size < 0 =>
+          case Some((blockId, size, _)) if size < 0 =>
             throw new BlockException(blockId, "Negative block size " + size)
-          case Some((blockId, size)) if size == 0 =>
+          case Some((blockId, size, _)) if size == 0 =>
             throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
           case None => // do nothing.
         }
-        localBlocks ++= blockInfos.map(_._1)
+        localBlocks ++= blockInfos.map(info => (info._1, info._3))
         localBlockBytes += blockInfos.map(_._2).sum
         numBlocksToFetch += localBlocks.size
       } else {
         val iterator = blockInfos.iterator
         var curRequestSize = 0L
-        var curBlocks = new ArrayBuffer[(BlockId, Long)]
+        var curBlocks = new ArrayBuffer[(BlockId, Long, Int)]
         while (iterator.hasNext) {
-          val (blockId, size) = iterator.next()
+          val (blockId, size, mapId) = iterator.next()
           remoteBlockBytes += size
           if (size < 0) {
             throw new BlockException(blockId, "Negative block size " + size)
           } else if (size == 0) {
             throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
           } else {
-            curBlocks += ((blockId, size))
+            curBlocks += ((blockId, size, mapId))
             remoteBlocks += blockId
             numBlocksToFetch += 1
             curRequestSize += size
@@ -314,7 +317,7 @@ final class ShuffleBlockFetcherIterator(
             remoteRequests += new FetchRequest(address, curBlocks)
             logDebug(s"Creating fetch request of $curRequestSize at $address " +
               s"with ${curBlocks.size} blocks")
-            curBlocks = new ArrayBuffer[(BlockId, Long)]
+            curBlocks = new ArrayBuffer[(BlockId, Long, Int)]
             curRequestSize = 0
           }
         }
@@ -340,19 +343,19 @@ final class ShuffleBlockFetcherIterator(
     logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
     val iter = localBlocks.iterator
     while (iter.hasNext) {
-      val blockId = iter.next()
+      val (blockId, mapId) = iter.next()
       try {
         val buf = blockManager.getBlockData(blockId)
         shuffleMetrics.incLocalBlocksFetched(1)
         shuffleMetrics.incLocalBytesRead(buf.size)
         buf.retain()
-        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId,
+        results.put(new SuccessFetchResult(blockId, mapId, blockManager.blockManagerId,
           buf.size(), buf, false))
       } catch {
         case e: Exception =>
           // If we see an exception, stop immediately.
          logError(s"Error occurred while fetching local blocks", e)
-          results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
+          results.put(new FailureFetchResult(blockId, mapId, blockManager.blockManagerId, e))
          return
       }
     }
@@ -412,7 +415,7 @@ final class ShuffleBlockFetcherIterator(
       shuffleMetrics.incFetchWaitTime(fetchWaitTime)
 
       result match {
-        case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+        case r @ SuccessFetchResult(blockId, mapId, address, size, buf, isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
             numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
             shuffleMetrics.incRemoteBytesRead(buf.size)
@@ -421,7 +424,7 @@ final class ShuffleBlockFetcherIterator(
             }
             shuffleMetrics.incRemoteBlocksFetched(1)
           }
-          if (!localBlocks.contains(blockId)) {
+          if (!localBlocks.contains((blockId, mapId))) {
             bytesInFlight -= size
           }
           if (isNetworkReqDone) {
@@ -445,7 +448,7 @@ final class ShuffleBlockFetcherIterator(
            // since the last call.
            val msg = s"Received a zero-size buffer for block $blockId from $address " +
              s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
-            throwFetchFailedException(blockId, address, new IOException(msg))
+            throwFetchFailedException(blockId, mapId, address, new IOException(msg))
           }
 
           val in = try {
@@ -456,7 +459,7 @@ final class ShuffleBlockFetcherIterator(
              assert(buf.isInstanceOf[FileSegmentManagedBuffer])
              logError("Failed to create input stream from local block", e)
              buf.release()
-              throwFetchFailedException(blockId, address, e)
+              throwFetchFailedException(blockId, mapId, address, e)
           }
           try {
             input = streamWrapper(blockId, in)
@@ -474,11 +477,11 @@ final class ShuffleBlockFetcherIterator(
              buf.release()
              if (buf.isInstanceOf[FileSegmentManagedBuffer]
                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, address, e)
+                throwFetchFailedException(blockId, mapId, address, e)
              } else {
                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(address, Array((blockId, size)))
+                fetchRequests += FetchRequest(address, Array((blockId, size, mapId)))
                result = null
              }
            } finally {
@@ -490,8 +493,8 @@ final class ShuffleBlockFetcherIterator(
            }
          }
 
-        case FailureFetchResult(blockId, address, e) =>
-          throwFetchFailedException(blockId, address, e)
+        case FailureFetchResult(blockId, mapId, address, e) =>
+          throwFetchFailedException(blockId, mapId, address, e)
      }
 
      // Send fetch requests up to maxBytesInFlight
@@ -504,6 +507,7 @@ final class ShuffleBlockFetcherIterator(
         input,
         this,
         currentResult.blockId,
+        currentResult.mapId,
         currentResult.address,
         detectCorrupt && streamCompressedOrEncrypted))
   }
val msg = s"Received a zero-size buffer for block $blockId from $address " + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" - throwFetchFailedException(blockId, address, new IOException(msg)) + throwFetchFailedException(blockId, mapId, address, new IOException(msg)) } val in = try { @@ -456,7 +459,7 @@ final class ShuffleBlockFetcherIterator( assert(buf.isInstanceOf[FileSegmentManagedBuffer]) logError("Failed to create input stream from local block", e) buf.release() - throwFetchFailedException(blockId, address, e) + throwFetchFailedException(blockId, mapId, address, e) } try { input = streamWrapper(blockId, in) @@ -474,11 +477,11 @@ final class ShuffleBlockFetcherIterator( buf.release() if (buf.isInstanceOf[FileSegmentManagedBuffer] || corruptedBlocks.contains(blockId)) { - throwFetchFailedException(blockId, address, e) + throwFetchFailedException(blockId, mapId, address, e) } else { logWarning(s"got an corrupted block $blockId from $address, fetch again", e) corruptedBlocks += blockId - fetchRequests += FetchRequest(address, Array((blockId, size))) + fetchRequests += FetchRequest(address, Array((blockId, size, mapId))) result = null } } finally { @@ -490,8 +493,8 @@ final class ShuffleBlockFetcherIterator( } } - case FailureFetchResult(blockId, address, e) => - throwFetchFailedException(blockId, address, e) + case FailureFetchResult(blockId, mapId, address, e) => + throwFetchFailedException(blockId, mapId, address, e) } // Send fetch requests up to maxBytesInFlight @@ -504,6 +507,7 @@ final class ShuffleBlockFetcherIterator( input, this, currentResult.blockId, + currentResult.mapId, currentResult.address, detectCorrupt && streamCompressedOrEncrypted)) } @@ -570,10 +574,11 @@ final class ShuffleBlockFetcherIterator( private[storage] def throwFetchFailedException( blockId: BlockId, + mapId: Int, address: BlockManagerId, e: Throwable) = { blockId match { - case ShuffleBlockId(shufId, mapId, reduceId) => + case ShuffleBlockId(shufId, _, reduceId) => throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) case _ => throw new SparkException( @@ -591,6 +596,7 @@ private class BufferReleasingInputStream( private[storage] val delegate: InputStream, private val iterator: ShuffleBlockFetcherIterator, private val blockId: BlockId, + private val mapId: Int, private val address: BlockManagerId, private val detectCorruption: Boolean) extends InputStream { @@ -602,7 +608,7 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(blockId, address, e) + iterator.throwFetchFailedException(blockId, mapId, address, e) } } @@ -624,7 +630,7 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(blockId, address, e) + iterator.throwFetchFailedException(blockId, mapId, address, e) } } @@ -636,7 +642,7 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(blockId, address, e) + iterator.throwFetchFailedException(blockId, mapId, address, e) } } @@ -646,7 +652,7 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(blockId, address, e) + iterator.throwFetchFailedException(blockId, mapId, address, e) } } @@ -681,9 +687,10 @@ object 
@@ -681,9 +687,10 @@ object ShuffleBlockFetcherIterator {
    * A request to fetch blocks from a remote BlockManager.
    * @param address remote BlockManager to fetch from.
    * @param blocks Sequence of tuple, where the first element is the block id,
-   *               and the second element is the estimated size, used to calculate bytesInFlight.
+   *               the second element is the estimated size, used to calculate bytesInFlight,
+   *               and the third element is the mapId.
    */
-  case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
+  case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long, Int)]) {
     val size = blocks.map(_._2).sum
   }
 
@@ -698,6 +705,7 @@ object ShuffleBlockFetcherIterator {
   /**
    * Result of a fetch from a remote block successfully.
    * @param blockId block id
+   * @param mapId the index of this block's map task in the map stage
    * @param address BlockManager that the block was fetched from.
    * @param size estimated size of the block. Note that this is NOT the exact bytes.
    *             Size of remote block is used to calculate bytesInFlight.
@@ -706,6 +714,7 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage] case class SuccessFetchResult(
       blockId: BlockId,
+      mapId: Int,
       address: BlockManagerId,
       size: Long,
       buf: ManagedBuffer,
@@ -717,11 +726,13 @@ object ShuffleBlockFetcherIterator {
   /**
    * Result of a fetch from a remote block unsuccessfully.
    * @param blockId block id
+   * @param mapId the index of this block's map task in the map stage
    * @param address BlockManager that the block was attempted to be fetched from
    * @param e the failure exception
    */
   private[storage] case class FailureFetchResult(
       blockId: BlockId,
+      mapId: Int,
       address: BlockManagerId,
       e: Throwable)
     extends FetchResult
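
[Annotation, not part of the patch: throwFetchFailedException now takes the
mapId explicitly and ignores the second field of ShuffleBlockId, since that
field no longer holds the map index. Sketch of the resulting mapping, with a
hypothetical stand-in for the real block id class:

    case class StubShuffleBlockId(shuffleId: Int, mapTaskAttemptId: Long, reduceId: Int)
    def describeFetchFailure(id: StubShuffleBlockId, mapId: Int): String =
      // mapId comes from the fetch result, not from id.mapTaskAttemptId
      s"FetchFailed(shuffle=${id.shuffleId}, map=$mapId, reduce=${id.reduceId})"
]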
(BlockManagerId("b", "hostB", 1000), - Seq((ShuffleBlockId(10, 6, 0), size10000), (ShuffleBlockId(10, 6, 2), size1000))) + Seq((ShuffleBlockId(10, 6, 0), size10000, 1), + (ShuffleBlockId(10, 6, 2), size1000, 1))) ) ) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index bcd1dd105db9b..1a576d82a0a08 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -107,7 +107,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // for the code to read data over the network. val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - (shuffleBlockId, byteOutputStream.size().toLong) + (shuffleBlockId, byteOutputStream.size().toLong, mapId) } Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index ed402440e74f1..05c21492b93e9 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -98,9 +98,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val transfer = createMockTransfer(remoteBlocks) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 0)).toSeq), + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 1)).toSeq) ).toIterator val taskContext = TaskContext.empty() @@ -179,8 +179,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 0)).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -247,8 +247,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 0)).toSeq)) + .toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -336,8 +337,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, 
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index ed402440e74f1..05c21492b93e9 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -98,9 +98,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
     val transfer = createMockTransfer(remoteBlocks)
 
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
-      (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq),
-      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 0)).toSeq),
+      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 1)).toSeq)
     ).toIterator
 
     val taskContext = TaskContext.empty()
@@ -179,8 +179,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       }
     })
 
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 0)).toSeq)).toIterator
 
     val taskContext = TaskContext.empty()
     val iterator = new ShuffleBlockFetcherIterator(
@@ -247,8 +247,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       }
     })
 
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 0)).toSeq))
+      .toIterator
 
     val taskContext = TaskContext.empty()
     val iterator = new ShuffleBlockFetcherIterator(
@@ -336,8 +337,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       }
     })
 
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 0)).toSeq)).toIterator
 
     val taskContext = TaskContext.empty()
     val iterator = new ShuffleBlockFetcherIterator(
@@ -389,8 +390,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val corruptBuffer1 = mockCorruptBuffer(streamLength, 0)
     val blockManagerId1 = BlockManagerId("remote-client-1", "remote-client-1", 1)
     val shuffleBlockId1 = ShuffleBlockId(0, 1, 0)
-    val blockLengths1 = Seq[Tuple2[BlockId, Long]](
-      shuffleBlockId1 -> corruptBuffer1.size()
+    val blockLengths1 = Seq[Tuple3[BlockId, Long, Int]](
+      (shuffleBlockId1, corruptBuffer1.size(), 1)
     )
 
     val streamNotCorruptTill = 8 * 1024
@@ -398,13 +399,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val corruptBuffer2 = mockCorruptBuffer(streamLength, streamNotCorruptTill)
     val blockManagerId2 = BlockManagerId("remote-client-2", "remote-client-2", 2)
     val shuffleBlockId2 = ShuffleBlockId(0, 2, 0)
-    val blockLengths2 = Seq[Tuple2[BlockId, Long]](
-      shuffleBlockId2 -> corruptBuffer2.size()
+    val blockLengths2 = Seq[Tuple3[BlockId, Long, Int]](
+      (shuffleBlockId2, corruptBuffer2.size(), 2)
     )
 
     val transfer = createMockTransfer(
       Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2))
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
       (blockManagerId1, blockLengths1),
       (blockManagerId2, blockLengths2)
     ).toIterator
@@ -465,11 +466,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val localBmId = BlockManagerId("test-client", "test-client", 1)
     doReturn(localBmId).when(blockManager).blockManagerId
     doReturn(managedBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0))
-    val localBlockLengths = Seq[Tuple2[BlockId, Long]](
-      ShuffleBlockId(0, 0, 0) -> 10000
+    val localBlockLengths = Seq[Tuple3[BlockId, Long, Int]](
+      (ShuffleBlockId(0, 0, 0), 10000, 0)
     )
     val transfer = createMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer))
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
       (localBmId, localBlockLengths)
     ).toIterator
 
@@ -531,8 +532,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       }
     })
 
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 0)).toSeq))
+      .toIterator
 
     val taskContext = TaskContext.empty()
     val iterator = new ShuffleBlockFetcherIterator(
@@ -591,7 +593,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     })
 
     def fetchShuffleBlock(
-        blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = {
+        blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
       // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the
      // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks
      // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here.
@@ -611,15 +613,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         taskContext.taskMetrics.createTempShuffleReadMetrics())
     }
 
-    val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
-      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator
+    val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L, 0)).toSeq)).toIterator
     fetchShuffleBlock(blocksByAddress1)
     // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch
     // shuffle block to disk.
     assert(tempFileManager == null)
 
-    val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
-      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator
+    val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L, 0)).toSeq)).toIterator
     fetchShuffleBlock(blocksByAddress2)
     // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch
     // shuffle block to disk.
@@ -640,8 +642,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
     val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)))
 
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long], 0)).toSeq))
 
     val taskContext = TaskContext.empty()
     val iterator = new ShuffleBlockFetcherIterator(
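
[Annotation, not part of the patch: in sendRequest the size and mapId are keyed
by the block id string so the fetch listener can recover both. Minimal runnable
sketch of that lookup, with plain strings standing in for BlockId:

    val blocks = Seq(("shuffle_0_5_0", 100L, 0), ("shuffle_0_6_0", 200L, 1))
    val infoMap = blocks.map { case (id, size, mapId) => (id, (size, mapId)) }.toMap
    val (size, mapId) = infoMap("shuffle_0_6_0")  // (200L, 1)
]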