diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index a3d2f587b9052..7e04d641ad96e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -289,6 +289,9 @@ private[spark] class ExternalSorter[K, V, C](
           inMemBuffered.hasNext && inMemBuffered.head._1._1 == p
         }
         override def next(): Product2[K, C] = {
+          if (!hasNext) {
+            throw new NoSuchElementException
+          }
           val elem = inMemBuffered.next()
           (elem._1._2, elem._2)
         }
@@ -314,7 +317,7 @@ private[spark] class ExternalSorter[K, V, C](
   private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
       : Iterator[Product2[K, C]] =
   {
-    val bufferedIters = iterators.map(_.buffered)
+    val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
     type Iter = BufferedIterator[Product2[K, C]]
     val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
       override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
@@ -404,7 +407,7 @@ private[spark] class ExternalSorter[K, V, C](
     // from the iterators, without buffering all the ones that are "equal" to a given key.
     // We do so with code similar to mergeSort, except our Iterator.next combines together all
     // the elements with the given key.
-    val bufferedIters = iterators.map(_.buffered)
+    val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
     type Iter = BufferedIterator[Product2[K, C]]
     val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
       override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index eae72c86c5d72..18791e632e6ff 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -25,6 +25,36 @@ import org.apache.spark._
 import org.apache.spark.SparkContext._
 
 class ExternalSorterSuite extends FunSuite with LocalSparkContext {
+  test("empty data stream") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+
+    // Both aggregator and ordering
+    val sorter = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+    assert(sorter.iterator.toSeq === Seq())
+
+    // Only aggregator
+    val sorter2 = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(3)), None, None)
+    assert(sorter2.iterator.toSeq === Seq())
+
+    // Only ordering
+    val sorter3 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    assert(sorter3.iterator.toSeq === Seq())
+
+    // Neither aggregator nor ordering
+    val sorter4 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), None, None)
+    assert(sorter4.iterator.toSeq === Seq())
+  }
+
   test("spilling in local cluster") {
     val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
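For context on why the `filter(_.hasNext)` guard matters: `mergeSort` seeds a `PriorityQueue` whose ordering compares each buffered iterator's `head`, and calling `head` on an exhausted iterator throws `NoSuchElementException`, so an empty input stream crashed the merge before this patch. Below is a minimal standalone sketch of the same heap-based k-way merge pattern under that assumption; the object and parameter names are illustrative, not Spark's actual code, and `ExternalSorter`'s real version additionally tracks partitions and combiners.

    import scala.collection.mutable

    object MergeSortSketch {
      // K-way merge of pre-sorted iterators, mirroring the shape of
      // ExternalSorter.mergeSort (names here are illustrative only).
      def mergeSort[T](iterators: Seq[Iterator[T]])(implicit ord: Ordering[T]): Iterator[T] = {
        // Without this filter, an empty input reaches the queue and the
        // Ordering below calls .head on it, throwing NoSuchElementException.
        val buffered = iterators.filter(_.hasNext).map(_.buffered)
        // PriorityQueue dequeues the maximum, so reverse the element
        // ordering to get the iterator with the smallest head first.
        val heap = new mutable.PriorityQueue[BufferedIterator[T]]()(
          Ordering.by[BufferedIterator[T], T](_.head)(ord.reverse))
        heap.enqueue(buffered: _*)
        new Iterator[T] {
          override def hasNext: Boolean = heap.nonEmpty
          override def next(): T = {
            if (!hasNext) {
              throw new NoSuchElementException
            }
            val it = heap.dequeue()
            val elem = it.next()
            if (it.hasNext) {
              heap.enqueue(it)  // re-insert only while elements remain
            }
            elem
          }
        }
      }

      def main(args: Array[String]): Unit = {
        // An empty stream mixed in with non-empty ones no longer breaks the merge.
        val merged = mergeSort(Seq(Iterator(1, 4, 7), Iterator.empty, Iterator(2, 3)))
        println(merged.toList)  // List(1, 2, 3, 4, 7)
      }
    }

The same reasoning explains the other two changes in the patch: dropping exhausted iterators before they can be re-enqueued, and having `next()` throw `NoSuchElementException` when empty, as the `Iterator` contract requires.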