diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index a8983df208318..6a32244bd03b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{StructType, NativeType} +import org.apache.spark.sql.types.{DataType, StructType, NativeType} /** @@ -232,3 +232,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { return 0 } } + +object RowOrdering { + def getOrderingFromDataTypes(dataTypes: Seq[DataType]): RowOrdering = + new RowOrdering(dataTypes.zipWithIndex.map { + case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index a1f0805a0ab92..048251c4c1f91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -46,11 +46,8 @@ case class SortMergeJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - private val orders: Seq[SortOrder] = leftKeys.zipWithIndex.map { - case(expr, index) => SortOrder(BoundReference(index, expr.dataType, expr.nullable), Ascending) - } // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = new RowOrdering(orders) + private val keyOrdering: RowOrdering = RowOrdering.getOrderingFromDataTypes(leftKeys.map(_.dataType)) private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Ordering[Row] = newOrdering(keys.map(SortOrder(_, Ascending)), side.output)