[SPARK-26713][CORE][followup] revert the partial fix in ShuffleBlockFetcherIterator #741

Merged (2 commits) on Mar 15, 2021

@@ -141,14 +141,7 @@ final class ShuffleBlockFetcherIterator(

/**
* Whether the iterator is still active. If isZombie is true, the callback interface will no
* longer place fetched blocks into [[results]] and the iterator is marked as fully consumed.
*
* When the iterator is inactive, [[hasNext]] and [[next]] calls will honor that as there are
* cases the iterator is still being consumed. For example, ShuffledRDD + PipedRDD if the
* subprocess command is failed. The task will be marked as failed, then the iterator will be
* cleaned up at task completion, the [[next]] call (called in the stdin writer thread of
* PipedRDD if not exited yet) may hang at [[results.take]]. The defensive check in [[hasNext]]
* and [[next]] reduces the possibility of such race conditions.
* longer place fetched blocks into [[results]].
*/
@GuardedBy("this")
private[this] var isZombie = false
@@ -387,7 +380,7 @@ final class ShuffleBlockFetcherIterator(
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}

override def hasNext: Boolean = !isZombie && (numBlocksProcessed < numBlocksToFetch)
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch

/**
* Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
@@ -399,7 +392,7 @@
*/
override def next(): (BlockId, InputStream) = {
if (!hasNext) {
throw new NoSuchElementException()
throw new NoSuchElementException
}

numBlocksProcessed += 1
@@ -410,7 +403,7 @@
// then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
// is also corrupt, so the previous stage could be retried.
// For local shuffle block, throw FailureFetchResult for the first IOException.
while (!isZombie && result == null) {
while (result == null) {
Author comment on lines -413 to +406:
IIUC this thread correctly, we can get race conditions on the isZombie flag.
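One way to read that race, as a stand-alone sketch (names are illustrative, not the real Spark classes): the consuming thread checks the flag, but task-completion cleanup can flip it and drain the queue between that check and the blocking take(), so the guard narrows the window without closing it.

```scala
import java.util.concurrent.LinkedBlockingQueue

// Illustrative stand-in only -- not the real ShuffleBlockFetcherIterator.
object ZombieRaceSketch {
  private val results = new LinkedBlockingQueue[String]()
  @volatile private var isZombie = false

  // Runs on the task thread at task completion.
  def cleanup(): Unit = {
    isZombie = true   // mark the iterator dead ...
    results.clear()   // ... and drop anything still queued
  }

  // Runs on the consuming thread (e.g. a writer thread still draining the iterator).
  def nextResult(): String = {
    if (isZombie) throw new NoSuchElementException
    // Race window: cleanup() may run right here, after the check above, in which
    // case take() below parks forever because nothing will be offered again.
    results.take()
  }
}
```

If that reading is right, a plain !isZombie check can only shrink the window; closing it would also need coordination with the queue itself (for example a sentinel result offered on cleanup, or an interrupt of the blocked thread).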

val startFetchWait = System.currentTimeMillis()
result = results.take()
Author comment:
This is where we hang.
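To make that concrete, a self-contained toy (illustrative names, not Spark code) with the same shape as the hang: once cleanup has run and nothing will ever be offered to the queue again, a consumer parked in take() stays parked, which matches the PipedRDD stdin-writer scenario described in the doc comment this PR removes.

```scala
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

// Toy with the same shape as the hang: a thread parked in take() never wakes up
// once the producing side has shut down without offering anything (or a sentinel).
object TakeHangSketch {
  def main(args: Array[String]): Unit = {
    val results = new LinkedBlockingQueue[String]()

    val consumer = new Thread(() => {
      // Like a writer thread still pulling from the iterator after the task failed:
      // blocks until an element arrives, which after cleanup never happens.
      println(s"got ${results.take()}")
    })
    consumer.setDaemon(true) // so the parked thread does not keep this toy JVM alive
    consumer.start()

    // Simulate task failure + cleanup: nothing is ever offered to `results`.
    consumer.join(TimeUnit.SECONDS.toMillis(2))
    println(s"consumer still parked in take(): ${consumer.isAlive}") // prints true
  }
}
```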

val stopFetchWait = System.currentTimeMillis()
@@ -504,9 +497,6 @@ final class ShuffleBlockFetcherIterator(
fetchUpToMaxBytes()
}

if (result == null) { // the iterator is already closed/cleaned up.
throw new NoSuchElementException()
}
currentResult = result.asInstanceOf[SuccessFetchResult]
(currentResult.blockId, new BufferReleasingInputStream(input, this))
}

@@ -217,65 +217,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
}

test("iterator is all consumed if task completes early") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
doReturn(localBmId).when(blockManager).blockManagerId

// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())

// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)

val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
// Return the first two blocks, and wait till task completion before returning the last
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
sem.acquire()
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
}
}
})

val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator

val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
blockManager,
blocksByAddress,
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
Int.MaxValue,
true,
taskContext.taskMetrics.createTempShuffleReadMetrics())


assert(iterator.hasNext)
iterator.next()

taskContext.markTaskCompleted(None)
sem.release()
assert(iterator.hasNext === false)
}

test("fail all blocks if any of the remote request fails") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)