diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 3966980a11ed0..1e5f3f7719977 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -141,14 +141,7 @@ final class ShuffleBlockFetcherIterator(
 
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
-   * longer place fetched blocks into [[results]] and the iterator is marked as fully consumed.
-   *
-   * When the iterator is inactive, [[hasNext]] and [[next]] calls will honor that as there are
-   * cases the iterator is still being consumed. For example, ShuffledRDD + PipedRDD if the
-   * subprocess command is failed. The task will be marked as failed, then the iterator will be
-   * cleaned up at task completion, the [[next]] call (called in the stdin writer thread of
-   * PipedRDD if not exited yet) may hang at [[results.take]]. The defensive check in [[hasNext]]
-   * and [[next]] reduces the possibility of such race conditions.
+   * longer place fetched blocks into [[results]].
    */
   @GuardedBy("this")
   private[this] var isZombie = false
@@ -387,7 +380,7 @@ final class ShuffleBlockFetcherIterator(
     logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
   }
 
-  override def hasNext: Boolean = !isZombie && (numBlocksProcessed < numBlocksToFetch)
+  override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
 
   /**
    * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
@@ -399,7 +392,7 @@ final class ShuffleBlockFetcherIterator(
    */
   override def next(): (BlockId, InputStream) = {
     if (!hasNext) {
-      throw new NoSuchElementException()
+      throw new NoSuchElementException
     }
 
     numBlocksProcessed += 1
@@ -410,7 +403,7 @@ final class ShuffleBlockFetcherIterator(
     // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
     // is also corrupt, so the previous stage could be retried.
     // For local shuffle block, throw FailureFetchResult for the first IOException.
-    while (!isZombie && result == null) {
+    while (result == null) {
       val startFetchWait = System.currentTimeMillis()
       result = results.take()
       val stopFetchWait = System.currentTimeMillis()
@@ -504,9 +497,6 @@ final class ShuffleBlockFetcherIterator(
       fetchUpToMaxBytes()
     }
 
-    if (result == null) { // the iterator is already closed/cleaned up.
-      throw new NoSuchElementException()
-    }
     currentResult = result.asInstanceOf[SuccessFetchResult]
     (currentResult.blockId, new BufferReleasingInputStream(input, this))
   }
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 98fe9663b6211..6b83243fe496c 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -217,65 +217,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
   }
 
-  test("iterator is all consumed if task completes early") {
-    val blockManager = mock(classOf[BlockManager])
-    val localBmId = BlockManagerId("test-client", "test-client", 1)
-    doReturn(localBmId).when(blockManager).blockManagerId
-
-    // Make sure remote blocks would return
-    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
-    val blocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
-      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
-      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
-
-    // Semaphore to coordinate event sequence in two different threads.
-    val sem = new Semaphore(0)
-
-    val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
-      .thenAnswer(new Answer[Unit] {
-        override def answer(invocation: InvocationOnMock): Unit = {
-          val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-          Future {
-            // Return the first two blocks, and wait till task completion before returning the last
-            listener.onBlockFetchSuccess(
-              ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
-            listener.onBlockFetchSuccess(
-              ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
-            sem.acquire()
-            listener.onBlockFetchSuccess(
-              ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
-          }
-        }
-      })
-
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
-
-    val taskContext = TaskContext.empty()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics())
-
-
-    assert(iterator.hasNext)
-    iterator.next()
-
-    taskContext.markTaskCompleted(None)
-    sem.release()
-    assert(iterator.hasNext === false)
-  }
-
   test("fail all blocks if any of the remote request fails") {
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)
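
Note: the deleted scaladoc above documents the hang that the reverted !isZombie
guards defended against: with ShuffledRDD + PipedRDD, a failed subprocess marks
the task failed, task completion cleans up the iterator, and a late next() call
from PipedRDD's stdin writer thread can then block forever in results.take().
Below is a minimal, self-contained Scala sketch of that mechanism. It is not
Spark code; ZombieRaceSketch and its simplified fields are invented for
illustration, with names chosen to mirror the real class.

import java.util.concurrent.LinkedBlockingQueue

// Toy stand-in for ShuffleBlockFetcherIterator: a results queue fed by fetch
// callbacks, plus a zombie flag flipped at task completion.
object ZombieRaceSketch {
  private val results = new LinkedBlockingQueue[String]()
  @volatile private var isZombie = false
  private var numBlocksProcessed = 0
  private val numBlocksToFetch = 3

  // The defensive guard: once cleanup() has run, hasNext short-circuits, so a
  // late caller never reaches the blocking results.take() below.
  def hasNext: Boolean = !isZombie && numBlocksProcessed < numBlocksToFetch

  def next(): String = {
    if (!hasNext) throw new NoSuchElementException
    numBlocksProcessed += 1
    results.take() // without the guard, blocks forever once the producer stops
  }

  // Simulates task-completion cleanup after a subprocess failure.
  def cleanup(): Unit = { isZombie = true }

  def main(args: Array[String]): Unit = {
    results.put("block-0")
    println(next())  // "block-0": one result was available
    cleanup()        // task completes; no further results will arrive
    println(hasNext) // false: a late caller exits instead of hanging in take()
  }
}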