From c69226578573753eeba9468d9dff113a1c113c3d Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Tue, 20 Aug 2019 19:14:44 +0800
Subject: [PATCH] Optimization by using radix sort if possible

---
 .../exchange/ShuffleExchangeExec.scala        | 75 ++++++++++++-------
 1 file changed, 48 insertions(+), 27 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 2f4c5734469f8..34c41cdce51f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -29,7 +29,7 @@ import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateOrdering, LazilyGeneratedOrdering}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
@@ -237,36 +237,57 @@ object ShuffleExchangeExec {
       // that case all output rows go to the same partition.
       val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
         rdd.mapPartitionsInternal { iter =>
-          val recordComparatorSupplier = new Supplier[RecordComparator] {
-            override def get: RecordComparator = new RecordBinaryComparator()
-          }
-          // The comparator for comparing row hashcode, which should always be Integer.
-          val prefixComparator = PrefixComparators.LONG
+          val schema = StructType.fromAttributes(outputAttributes)
+          val canUseRadixSort = SQLConf.get.enableRadixSort && schema.length == 1 &&
+            SortPrefixUtils.canSortFullyWithPrefix(schema.head)
+          val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
 
-          // The prefix computer generates row hashcode as the prefix, so we may decrease the
-          // probability that the prefixes are equal when input rows choose column values from a
-          // limited range.
-          val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
-            private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
-            override def computePrefix(row: InternalRow):
-            UnsafeExternalRowSorter.PrefixComputer.Prefix = {
-              // The hashcode generated from the binary form of a [[UnsafeRow]] should not be null.
-              result.isNull = false
-              result.value = row.hashCode()
-              result
+          val sorter = if (canUseRadixSort) {
+            // For better performance, enable radix sort if possible.
+            val prefixComputer = SortPrefixUtils.createPrefixGenerator(schema)
+            val prefixComparator = SortPrefixUtils.getPrefixComparator(schema)
+            val ordering = GenerateOrdering.create(schema)
+
+            UnsafeExternalRowSorter.create(
+              schema,
+              ordering,
+              prefixComparator,
+              prefixComputer,
+              pageSize,
+              true)
+          } else {
+            val recordComparatorSupplier = new Supplier[RecordComparator] {
+              override def get: RecordComparator = new RecordBinaryComparator()
             }
+            // The comparator for comparing row hashcode, which should always be Integer.
+            val prefixComparator = PrefixComparators.LONG
+
+            // The prefix computer generates row hashcode as the prefix, so we may decrease the
+            // probability that the prefixes are equal when input rows choose column values from a
+            // limited range.
+            val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+              private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+              override def computePrefix(row: InternalRow):
+              UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+                // The hashcode generated from the binary form of a [[UnsafeRow]] should not
+                // be null.
+                result.isNull = false
+                result.value = row.hashCode()
+                result
+              }
+            }
+
+            UnsafeExternalRowSorter.createWithRecordComparator(
+              schema,
+              recordComparatorSupplier,
+              prefixComparator,
+              prefixComputer,
+              pageSize,
+              // We are comparing binary here, which does not support radix sort.
+              // See more details in SPARK-28699.
+              false)
           }
-          val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-          val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
-            StructType.fromAttributes(outputAttributes),
-            recordComparatorSupplier,
-            prefixComparator,
-            prefixComputer,
-            pageSize,
-            // We are comparing binary here, which does not support radix sort.
-            // See more details in SPARK-28699.
-            false)
           sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
         }
       } else {
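
Note for reviewers (not part of the patch): the heart of this change is the eligibility gate. Radix sort is only safe when the 8-byte sort prefix alone decides the full ordering, which in practice means a single fully-prefix-sortable column; any other schema keeps the hashcode-prefix plus RecordBinaryComparator path, because binary comparison does not support radix sort (see SPARK-28699). Below is a minimal, self-contained Scala sketch of that gate. The DataType/Field types and the body of canSortFullyWithPrefix are hypothetical stand-ins for Spark's internal SortPrefixUtils and StructType machinery, not the actual implementation.

// Sketch of the radix-sort eligibility gate used in the hunk above.
// All types here are simplified stand-ins, not Spark's real classes.
sealed trait DataType
case object LongType extends DataType
case object StringType extends DataType

final case class Field(name: String, dataType: DataType)

object RadixSortEligibility {
  // Assumed to mirror SortPrefixUtils.canSortFullyWithPrefix: true only
  // when the 8-byte prefix decides the complete ordering (e.g. a
  // fixed-width integral type). For strings the prefix is a truncation,
  // so equal prefixes still require a full record comparison.
  def canSortFullyWithPrefix(field: Field): Boolean = field.dataType match {
    case LongType => true
    case _ => false
  }

  // The gate from the patch: radix sort only when the config flag is on
  // and the schema is a single fully-prefix-sortable column.
  def canUseRadixSort(enableRadixSort: Boolean, schema: Seq[Field]): Boolean =
    enableRadixSort && schema.length == 1 && canSortFullyWithPrefix(schema.head)
}

object Demo extends App {
  val single = Seq(Field("id", LongType))
  val string = Seq(Field("name", StringType))
  val multi = Seq(Field("id", LongType), Field("name", StringType))

  println(RadixSortEligibility.canUseRadixSort(enableRadixSort = true, single)) // true
  println(RadixSortEligibility.canUseRadixSort(enableRadixSort = true, string)) // false
  println(RadixSortEligibility.canUseRadixSort(enableRadixSort = true, multi))  // false
}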