diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 6388ef82cc5db..a7342cfef7dd7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -17,10 +17,11 @@ package org.apache.spark.rdd +import scala.language.existentials + import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.language.existentials import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index eec8678fc8b38..f7b8f340584e7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -501,8 +501,13 @@ private[spark] class ExternalSorter[K, V, C]( /** Construct a stream that only reads from the next batch */ def nextBatchStream(): InputStream = { - batchStreamsRead += 1 - ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1)) + if (batchStreamsRead < spill.serializerBatchSizes.length) { + batchStreamsRead += 1 + ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1)) + } else { + // No more batches left; give an empty stream + bufferedStream + } } /**