diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 9aacdc99e9be0..1f6f3d3e82108 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner}
+import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.{SQLContext, Row}
@@ -81,13 +81,17 @@ case class Exchange(
    *
    * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue.
    *
-   * @param numPartitions the number of output partitions produced by the shuffle
+   * @param partitioner the partitioner for the shuffle
    * @param serializer the serializer that will be used to write rows
    * @return true if rows should be copied before being shuffled, false otherwise
    */
   private def needToCopyObjectsBeforeShuffle(
-      numPartitions: Int,
+      partitioner: Partitioner,
       serializer: Serializer): Boolean = {
+    // Note: even though we only use the partitioner's `numPartitions` field, we require it to be
+    // passed instead of directly passing the number of partitions in order to guard against
+    // corner-cases where a partitioner constructed with `numPartitions` partitions may output
+    // fewer partitions (like RangePartitioner, for example).
     if (newOrdering.nonEmpty) {
       // If a new ordering is required, then records will be sorted with Spark's `ExternalSorter`,
       // which requires a defensive copy.
@@ -95,7 +99,7 @@ case class Exchange(
     } else if (sortBasedShuffleOn) {
       // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
       // However, there are two special cases where we can avoid the copy, described below:
-      if (numPartitions <= bypassMergeThreshold) {
+      if (partitioner.numPartitions <= bypassMergeThreshold) {
         // If the number of output partitions is sufficiently small, then Spark will fall back to
         // the old hash-based shuffle write path which doesn't buffer deserialized records.
         // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
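
A hedged, self-contained sketch (not part of the patch; the local-mode setup and example data are invented for illustration) of the corner case the new comment describes: a `RangePartitioner` asked for more partitions than the sampled key range can support reports a smaller `numPartitions`, which is why the method now takes the partitioner itself.

```scala
import org.apache.spark.{RangePartitioner, SparkConf, SparkContext}

// Illustration only: a RangePartitioner built with 8 requested partitions over data
// with just two distinct keys ends up with fewer output partitions, so the patched
// needToCopyObjectsBeforeShuffle reads partitioner.numPartitions instead of trusting
// the requested partition count.
object RangePartitionerCornerCase {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    val pairs = sc.parallelize(Seq((1, "a"), (1, "b"), (2, "c"), (2, "d")))
    val part = new RangePartitioner(8, pairs)
    println(part.numPartitions)  // prints a value <= 8 (typically 2 here), not 8
    sc.stop()
  }
}
```
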
@@ -177,8 +181,9 @@ case class Exchange(
         val keySchema = expressions.map(_.dataType).toArray
         val valueSchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(keySchema, valueSchema, numPartitions)
+        val part = new HashPartitioner(numPartitions)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
+        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
           child.execute().mapPartitions { iter =>
             val hashExpressions = newMutableProjection(expressions, child.output)()
             iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -190,13 +195,10 @@ case class Exchange(
             iter.map(r => mutablePair.update(hashExpressions(r), r))
           }
         }
-        val part = new HashPartitioner(numPartitions)
-        val shuffled =
-          if (newOrdering.nonEmpty) {
-            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
-          } else {
-            new ShuffledRDD[Row, Row, Row](rdd, part)
-          }
+        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+        if (newOrdering.nonEmpty) {
+          shuffled.setKeyOrdering(keyOrdering)
+        }
         shuffled.setSerializer(serializer)
         shuffled.map(_._2)
 
@@ -204,33 +206,41 @@ case class Exchange(
         val keySchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(keySchema, null, numPartitions)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
-          child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
+        val childRdd = child.execute()
+        val part: Partitioner = {
+          // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+          // partition bounds. To get accurate samples, we need to copy the mutable keys.
+          val rddForSampling = childRdd.mapPartitions { iter =>
+            val mutablePair = new MutablePair[Row, Null]()
+            iter.map(row => mutablePair.update(row.copy(), null))
+          }
+          // TODO: RangePartitioner should take an Ordering.
+          implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+          new RangePartitioner(numPartitions, rddForSampling, ascending = true)
+        }
+
+        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+          childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
         } else {
-          child.execute().mapPartitions { iter =>
-            val mutablePair = new MutablePair[Row, Null](null, null)
+          childRdd.mapPartitions { iter =>
+            val mutablePair = new MutablePair[Row, Null]()
             iter.map(row => mutablePair.update(row, null))
           }
         }
-        // TODO: RangePartitioner should take an Ordering.
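
The `rddForSampling` copy above exists because `RangePartitioner` buffers sampled keys while the exchange reuses a single `MutablePair` (and mutable rows) per partition. A small plain-Scala sketch, with an invented `ReusedPair` class standing in for the reused pair/row objects, of the aliasing problem that `row.copy()` guards against:

```scala
// Illustration only: when a container object is reused across iterator elements,
// anything that buffers references (as range-partition sampling does) ends up holding
// many pointers to the same, last-written value.
object MutableReuseSketch {
  final class ReusedPair(var key: Int, var value: String)

  def main(args: Array[String]): Unit = {
    val reused = new ReusedPair(0, "")
    val rows = Seq(1 -> "a", 2 -> "b", 3 -> "c")
    // Buffering without copying: every buffered element aliases the same object.
    val withoutCopy = rows.map { case (k, v) => reused.key = k; reused.value = v; reused }.toArray
    println(withoutCopy.map(_.key).toSeq)  // all 3s -- the "sample" is corrupted
    // Buffering with a defensive copy, as the patch does via row.copy().
    val withCopy = rows.map { case (k, v) => new ReusedPair(k, v) }.toArray
    println(withCopy.map(_.key).toSeq)     // 1, 2, 3 as intended
  }
}
```
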
-        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
-
-        val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled =
-          if (newOrdering.nonEmpty) {
-            new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
-          } else {
-            new ShuffledRDD[Row, Null, Null](rdd, part)
-          }
+        val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+        if (newOrdering.nonEmpty) {
+          shuffled.setKeyOrdering(keyOrdering)
+        }
         shuffled.setSerializer(serializer)
         shuffled.map(_._1)
 
       case SinglePartition =>
         val valueSchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(null, valueSchema, 1)
+        val partitioner = new HashPartitioner(1)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions = 1, serializer)) {
+        val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
           child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) }
         } else {
           child.execute().mapPartitions { iter =>
@@ -238,7 +248,6 @@ case class Exchange(
             iter.map(r => mutablePair.update(null, r))
           }
         }
-        val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
         shuffled.setSerializer(serializer)
         shuffled.map(_._2)
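
Taken together, all three cases now follow the same shape: build the partitioner first, then let the copy decision read its actual `numPartitions`. A simplified, hedged sketch of that control flow; the threshold value and the `needsCopy` body are stand-ins, not the real `needToCopyObjectsBeforeShuffle`, which also consults the serializer before deciding.

```scala
import org.apache.spark.{HashPartitioner, Partitioner}

// Simplified stand-in for the copy decision: it omits the serializer-relocation check
// the real method performs, and only shows how the decision is driven by the
// partitioner's actual numPartitions rather than a requested partition count.
object CopyDecisionSketch {
  val bypassMergeThreshold = 200  // stand-in for spark.shuffle.sort.bypassMergeThreshold

  def needsCopy(
      partitioner: Partitioner,
      newOrderingNonEmpty: Boolean,
      sortBasedShuffleOn: Boolean): Boolean = {
    if (newOrderingNonEmpty) {
      true  // ExternalSorter will buffer rows, so defensive copies are required
    } else if (sortBasedShuffleOn) {
      partitioner.numPartitions > bypassMergeThreshold
    } else {
      false  // hash-based shuffle does not buffer deserialized records
    }
  }

  def main(args: Array[String]): Unit = {
    // SinglePartition: one output partition is always under the bypass threshold,
    // so no copy is needed when no new ordering is required.
    println(needsCopy(new HashPartitioner(1), newOrderingNonEmpty = false, sortBasedShuffleOn = true))
  }
}
```
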