Commit c6cbb06

Still need mapId for the fetch fail scenario

xuanyuanking committed Aug 29, 2019
1 parent 8c7460d commit c6cbb06
Showing 5 changed files with 99 additions and 82 deletions.
24 changes: 12 additions & 12 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging

// For testing
def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
}

@@ -292,11 +292,11 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
* endPartition is excluded from the range).
*
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
* and the second item is a sequence of (shuffle block id, shuffle block size) tuples
* describing the shuffle blocks that are stored at that block manager.
* and the second item is a sequence of (shuffle block id, shuffle block size, map id)
* tuples describing the shuffle blocks that are stored at that block manager.
*/
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
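A minimal consumer sketch of the new element shape (not part of this commit; `tracker` and the argument values are illustrative assumptions):

```scala
// Each element pairs a BlockManagerId with the (blockId, size, mapId) triples
// stored at that block manager; mapId is the block's index in the map stage.
val mapSizes: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
  tracker.getMapSizesByExecutorId(0, startPartition = 0, endPartition = 2)
for ((bmId, blocks) <- mapSizes; (blockId, size, mapId) <- blocks) {
  println(s"$blockId: $size bytes from map index $mapId at $bmId")
}
```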

/**
* Deletes map output status information for the specified shuffle stage.
@@ -646,7 +646,7 @@ private[spark] class MapOutputTrackerMaster(
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
// This method is only called in local-mode.
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some (shuffleStatus) =>
@@ -683,7 +683,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr

// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
@@ -864,17 +864,17 @@ private[spark] object MapOutputTracker extends Logging {
* @param endPartition End of map output partition ID range (excluded from range)
* @param statuses List of map statuses, indexed by map ID.
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
* and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
* describing the shuffle blocks that are stored at that block manager.
* and the second item is a sequence of (shuffle block id, shuffle block size, map id)
* tuples describing the shuffle blocks that are stored at that block manager.
*/
def convertMapStatuses(
shuffleId: Int,
startPartition: Int,
endPartition: Int,
statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
assert (statuses != null)
val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]]
statuses.foreach { status =>
val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
for ((status, mapId) <- statuses.iterator.zipWithIndex) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
logError(errorMessage)
@@ -884,7 +884,7 @@ private[spark] object MapOutputTracker extends Logging {
val size = status.getSizeForBlock(part)
if (size != 0) {
splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
((ShuffleBlockId(shuffleId, status.mapTaskAttemptId, part), size))
((ShuffleBlockId(shuffleId, status.mapTaskAttemptId, part), size, mapId))
}
}
}
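A standalone sketch of the `zipWithIndex` pattern introduced above; toy `String` locations stand in for `MapStatus`/`BlockManagerId`, so this is illustrative only:

```scala
import scala.collection.mutable.{HashMap, ListBuffer}

// The position of each status in the statuses array becomes its mapId.
val locations = Array("exec-0", "exec-1", "exec-0")
val splitsByAddress = new HashMap[String, ListBuffer[(String, Int)]]
for ((loc, mapId) <- locations.iterator.zipWithIndex) {
  splitsByAddress.getOrElseUpdate(loc, ListBuffer()) += ((s"block-$mapId", mapId))
}
// exec-0 -> ListBuffer((block-0, 0), (block-2, 2)); exec-1 -> ListBuffer((block-1, 1))
```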
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -48,9 +48,10 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
* @param shuffleClient [[BlockStoreClient]] for fetching remote blocks
* @param blockManager [[BlockManager]] for reading local blocks
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage. Note that zero-sized blocks are
* already excluded, which happened in
For each block we also require two pieces of info: 1. the size (in bytes as a
long field) in order to throttle the memory usage; 2. the mapId for this
block, which indicates the block's index in the map stage.
* Note that zero-sized blocks are already excluded, which happened in
* [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
* @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
@@ -66,7 +67,7 @@ final class ShuffleBlockFetcherIterator(
context: TaskContext,
shuffleClient: BlockStoreClient,
blockManager: BlockManager,
blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int,
@@ -96,7 +97,7 @@ final class ShuffleBlockFetcherIterator(
private[this] val startTimeNs = System.nanoTime()

/** Local blocks to fetch, excluding zero-sized blocks. */
private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]()
private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()

/** Remote blocks to fetch, excluding zero-sized blocks. */
private[this] val remoteBlocks = new HashSet[BlockId]()
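For reference, a hedged sketch of the new `blocksByAddress` element shape described in the constructor doc above (executor id, host, port, and sizes are made up):

```scala
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

val blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
  Iterator(
    (BlockManagerId("exec-1", "host-1", 7337),
     Seq((ShuffleBlockId(0, 5, 2), 1024L, 0),   // (blockId, size in bytes, mapId)
         (ShuffleBlockId(0, 7, 2), 2048L, 1))))
```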
@@ -198,7 +199,7 @@ final class ShuffleBlockFetcherIterator(
while (iter.hasNext) {
val result = iter.next()
result match {
case SuccessFetchResult(_, address, _, buf, _) =>
case SuccessFetchResult(_, _, address, _, buf, _) =>
if (address != blockManager.blockManagerId) {
shuffleMetrics.incRemoteBytesRead(buf.size)
if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
@@ -223,9 +224,11 @@
bytesInFlight += req.size
reqsInFlight += 1

// so we can look up the size of each blockID
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
// so we can look up the block info of each blockID
val infoMap = req.blocks.map {
case (blockId, size, mapId) => (blockId.toString, (size, mapId))
}.toMap
val remainingBlocks = new HashSet[String]() ++= infoMap.keys
val blockIds = req.blocks.map(_._1.toString)
val address = req.address
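A toy illustration of the `infoMap` built above; the ids and sizes are invented, and `ShuffleBlockId(0, 1, 2).toString` renders as `shuffle_0_1_2`:

```scala
val blocks = Seq((ShuffleBlockId(0, 1, 2), 1024L, 1), (ShuffleBlockId(0, 3, 2), 512L, 3))
val infoMap = blocks.map { case (blockId, size, mapId) => (blockId.toString, (size, mapId)) }.toMap
val (size, mapId) = infoMap("shuffle_0_1_2")  // size = 1024L, mapId = 1
```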

@@ -239,8 +242,8 @@
// This needs to be released after use.
buf.retain()
remainingBlocks -= blockId
results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
remainingBlocks.isEmpty))
results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
logDebug("remainingBlocks: " + remainingBlocks)
}
}
@@ -249,7 +252,7 @@

override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
results.put(new FailureFetchResult(BlockId(blockId), address, e))
results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
}
}

@@ -282,28 +285,28 @@
for ((address, blockInfos) <- blocksByAddress) {
if (address.executorId == blockManager.blockManagerId.executorId) {
blockInfos.find(_._2 <= 0) match {
case Some((blockId, size)) if size < 0 =>
case Some((blockId, size, _)) if size < 0 =>
throw new BlockException(blockId, "Negative block size " + size)
case Some((blockId, size)) if size == 0 =>
case Some((blockId, size, _)) if size == 0 =>
throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
case None => // do nothing.
}
localBlocks ++= blockInfos.map(_._1)
localBlocks ++= blockInfos.map(info => (info._1, info._3))
localBlockBytes += blockInfos.map(_._2).sum
numBlocksToFetch += localBlocks.size
} else {
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = new ArrayBuffer[(BlockId, Long)]
var curBlocks = new ArrayBuffer[(BlockId, Long, Int)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
val (blockId, size, mapId) = iterator.next()
remoteBlockBytes += size
if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
} else if (size == 0) {
throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
} else {
curBlocks += ((blockId, size))
curBlocks += ((blockId, size, mapId))
remoteBlocks += blockId
numBlocksToFetch += 1
curRequestSize += size
Expand All @@ -314,7 +317,7 @@ final class ShuffleBlockFetcherIterator(
remoteRequests += new FetchRequest(address, curBlocks)
logDebug(s"Creating fetch request of $curRequestSize at $address "
+ s"with ${curBlocks.size} blocks")
curBlocks = new ArrayBuffer[(BlockId, Long)]
curBlocks = new ArrayBuffer[(BlockId, Long, Int)]
curRequestSize = 0
}
}
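The loop above batches `(blockId, size, mapId)` triples into fetch requests once the accumulated size crosses a threshold. A simplified standalone sketch of that batching (threshold and sizes are arbitrary):

```scala
import scala.collection.mutable.ArrayBuffer

val targetRequestSize = 2048L
val blocks = Seq(("b0", 1500L, 0), ("b1", 900L, 1), ("b2", 300L, 2))
val requests = ArrayBuffer[Seq[(String, Long, Int)]]()
var curBlocks = Vector.empty[(String, Long, Int)]
var curSize = 0L
for (b <- blocks) {
  curBlocks :+= b
  curSize += b._2
  if (curSize >= targetRequestSize) {
    requests += curBlocks
    curBlocks = Vector.empty
    curSize = 0L
  }
}
if (curBlocks.nonEmpty) requests += curBlocks
// requests: Vector((b0,1500,0), (b1,900,1)) and Vector((b2,300,2))
```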
@@ -340,19 +343,19 @@
logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
val iter = localBlocks.iterator
while (iter.hasNext) {
val blockId = iter.next()
val (blockId, mapId) = iter.next()
try {
val buf = blockManager.getBlockData(blockId)
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId,
results.put(new SuccessFetchResult(blockId, mapId, blockManager.blockManagerId,
buf.size(), buf, false))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
results.put(new FailureFetchResult(blockId, mapId, blockManager.blockManagerId, e))
return
}
}
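Since `localBlocks` now stores `(BlockId, mapId)` pairs, membership checks must supply both parts, as the `next()` match further down relies on. A minimal sketch (values invented):

```scala
import scala.collection.mutable.LinkedHashSet
import org.apache.spark.storage.{BlockId, ShuffleBlockId}

val localBlocks = LinkedHashSet[(BlockId, Int)]()
localBlocks += ((ShuffleBlockId(0, 1, 2), 1))
localBlocks.contains((ShuffleBlockId(0, 1, 2), 1))  // true
localBlocks.contains((ShuffleBlockId(0, 1, 2), 9))  // false: same block, wrong mapId
```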
@@ -412,7 +415,7 @@
shuffleMetrics.incFetchWaitTime(fetchWaitTime)

result match {
case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
case r @ SuccessFetchResult(blockId, mapId, address, size, buf, isNetworkReqDone) =>
if (address != blockManager.blockManagerId) {
numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
shuffleMetrics.incRemoteBytesRead(buf.size)
@@ -421,7 +424,7 @@
}
shuffleMetrics.incRemoteBlocksFetched(1)
}
if (!localBlocks.contains(blockId)) {
if (!localBlocks.contains((blockId, mapId))) {
bytesInFlight -= size
}
if (isNetworkReqDone) {
@@ -445,7 +448,7 @@
// since the last call.
val msg = s"Received a zero-size buffer for block $blockId from $address " +
s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
throwFetchFailedException(blockId, address, new IOException(msg))
throwFetchFailedException(blockId, mapId, address, new IOException(msg))
}

val in = try {
Expand All @@ -456,7 +459,7 @@ final class ShuffleBlockFetcherIterator(
assert(buf.isInstanceOf[FileSegmentManagedBuffer])
logError("Failed to create input stream from local block", e)
buf.release()
throwFetchFailedException(blockId, address, e)
throwFetchFailedException(blockId, mapId, address, e)
}
try {
input = streamWrapper(blockId, in)
@@ -474,11 +477,11 @@
buf.release()
if (buf.isInstanceOf[FileSegmentManagedBuffer]
|| corruptedBlocks.contains(blockId)) {
throwFetchFailedException(blockId, address, e)
throwFetchFailedException(blockId, mapId, address, e)
} else {
logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
corruptedBlocks += blockId
fetchRequests += FetchRequest(address, Array((blockId, size)))
fetchRequests += FetchRequest(address, Array((blockId, size, mapId)))
result = null
}
} finally {
@@ -490,8 +493,8 @@
}
}

case FailureFetchResult(blockId, address, e) =>
throwFetchFailedException(blockId, address, e)
case FailureFetchResult(blockId, mapId, address, e) =>
throwFetchFailedException(blockId, mapId, address, e)
}

// Send fetch requests up to maxBytesInFlight
@@ -504,6 +507,7 @@
input,
this,
currentResult.blockId,
currentResult.mapId,
currentResult.address,
detectCorrupt && streamCompressedOrEncrypted))
}
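With the extra parameter, fetch failures can report the map index directly instead of parsing it out of the block id. A hedged sketch mirroring (not reproducing) the `throwFetchFailedException` change below; `address` and `cause` are placeholders:

```scala
import org.apache.spark.SparkException
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

def reportFailure(blockId: BlockId, mapId: Int, address: BlockManagerId, cause: Throwable): Nothing =
  blockId match {
    case ShuffleBlockId(shufId, _, reduceId) =>
      // mapId comes from the fetch bookkeeping, not from the block id, whose
      // second field is no longer assumed to be the map index.
      throw new FetchFailedException(address, shufId, mapId, reduceId, cause)
    case _ =>
      throw new SparkException(s"Failed to get block $blockId", cause)
  }
```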
@@ -570,10 +574,11 @@

private[storage] def throwFetchFailedException(
blockId: BlockId,
mapId: Int,
address: BlockManagerId,
e: Throwable) = {
blockId match {
case ShuffleBlockId(shufId, mapId, reduceId) =>
case ShuffleBlockId(shufId, _, reduceId) =>
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
case _ =>
throw new SparkException(
@@ -591,6 +596,7 @@ private class BufferReleasingInputStream(
private[storage] val delegate: InputStream,
private val iterator: ShuffleBlockFetcherIterator,
private val blockId: BlockId,
private val mapId: Int,
private val address: BlockManagerId,
private val detectCorruption: Boolean)
extends InputStream {
@@ -602,7 +608,7 @@
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
iterator.throwFetchFailedException(blockId, address, e)
iterator.throwFetchFailedException(blockId, mapId, address, e)
}
}

@@ -624,7 +630,7 @@
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
iterator.throwFetchFailedException(blockId, address, e)
iterator.throwFetchFailedException(blockId, mapId, address, e)
}
}

@@ -636,7 +642,7 @@
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
iterator.throwFetchFailedException(blockId, address, e)
iterator.throwFetchFailedException(blockId, mapId, address, e)
}
}

Expand All @@ -646,7 +652,7 @@ private class BufferReleasingInputStream(
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
iterator.throwFetchFailedException(blockId, address, e)
iterator.throwFetchFailedException(blockId, mapId, address, e)
}
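`BufferReleasingInputStream` funnels every delegate call that can throw `IOException` through the same corruption handler, which now has the `mapId` at hand. A simplified standalone sketch of that wrapping pattern (not the real class):

```scala
import java.io.{IOException, InputStream}

// Any read failure is routed to a single handler, which in the real class
// calls iterator.throwFetchFailedException(blockId, mapId, address, e).
class RethrowingInputStream(delegate: InputStream)(onError: IOException => Nothing)
    extends InputStream {
  override def read(): Int =
    try delegate.read() catch { case e: IOException => onError(e) }
  override def close(): Unit = delegate.close()
}
```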
}

Expand Down Expand Up @@ -681,9 +687,10 @@ object ShuffleBlockFetcherIterator {
* A request to fetch blocks from a remote BlockManager.
* @param address remote BlockManager to fetch from.
* @param blocks Sequence of tuple, where the first element is the block id,
* and the second element is the estimated size, used to calculate bytesInFlight.
and the second element is the estimated size, used to calculate bytesInFlight,
and the third element is the mapId.
*/
case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long, Int)]) {
val size = blocks.map(_._2).sum
}
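A toy `FetchRequest` showing that `size` still sums the second tuple element even though each block now carries its mapId (all values invented):

```scala
val req = FetchRequest(
  BlockManagerId("exec-1", "host-1", 7337),
  Seq((ShuffleBlockId(0, 1, 2), 1024L, 1), (ShuffleBlockId(0, 3, 2), 512L, 3)))
assert(req.size == 1536L)
```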

@@ -698,6 +705,7 @@
/**
* Result of a fetch from a remote block successfully.
* @param blockId block id
* @param mapId mapId for this block
* @param address BlockManager that the block was fetched from.
* @param size estimated size of the block. Note that this is NOT the exact bytes.
* Size of remote block is used to calculate bytesInFlight.
@@ -706,6 +714,7 @@
*/
private[storage] case class SuccessFetchResult(
blockId: BlockId,
mapId: Int,
address: BlockManagerId,
size: Long,
buf: ManagedBuffer,
@@ -717,11 +726,13 @@
/**
* Result of a fetch from a remote block unsuccessfully.
* @param blockId block id
* @param mapId mapId for this block
* @param address BlockManager that the block was attempted to be fetched from
* @param e the failure exception
*/
private[storage] case class FailureFetchResult(
blockId: BlockId,
mapId: Int,
address: BlockManagerId,
e: Throwable)
extends FetchResult
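Illustrative construction of the two extended result types (arguments are made up; `NioManagedBuffer` is just one concrete `ManagedBuffer` picked for the sketch):

```scala
import java.io.IOException
import java.nio.ByteBuffer
import org.apache.spark.network.buffer.NioManagedBuffer

val ok = SuccessFetchResult(
  blockId = ShuffleBlockId(0, 1, 2),
  mapId = 1,
  address = BlockManagerId("exec-1", "host-1", 7337),
  size = 1024L,
  buf = new NioManagedBuffer(ByteBuffer.allocate(1024)),
  isNetworkReqDone = false)

val failed = FailureFetchResult(
  blockId = ShuffleBlockId(0, 3, 2),
  mapId = 3,
  address = BlockManagerId("exec-2", "host-2", 7337),
  e = new IOException("connection reset"))
```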