[SPARK-9054] [SQL] Rename RowOrdering to InterpretedOrdering; use newOrdering in SMJ

This patch renames `RowOrdering` to `InterpretedOrdering` and updates SortMergeJoin to use the `SparkPlan` methods for constructing its ordering so that it may benefit from codegen.

This is an updated version of #7408.

Author: Josh Rosen <[email protected]>

Closes #7973 from JoshRosen/SPARK-9054 and squashes the following commits:

e610655 [Josh Rosen] Add comment RE: Ascending ordering
34b8e0c [Josh Rosen] Import ordering
be19a0f [Josh Rosen] [SPARK-9054] [SQL] Rename RowOrdering to InterpretedOrdering; use newOrdering in more places.

(cherry picked from commit 9c87892)
Signed-off-by: Josh Rosen <[email protected]>
JoshRosen committed Aug 5, 2015
1 parent 30e9fcf commit 618dc63
Showing 12 changed files with 55 additions and 36 deletions.
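The heart of the rename: an interpreted row ordering compares two rows by walking the sort keys left to right and returning the first non-zero field comparison, with explicit null handling. For illustration only (not part of this commit), a minimal self-contained sketch of that idea in plain Scala, with `Seq[Any]` standing in for `InternalRow` and all names invented:

object InterpretedOrderingSketch {
  // Ascending-only sketch of InterpretedOrdering.compare: nulls sort first,
  // and the first differing sort key decides the result.
  def compareRows(fieldOrderings: Seq[Ordering[Any]])(a: Seq[Any], b: Seq[Any]): Int = {
    var i = 0
    while (i < fieldOrderings.length) {
      val left = a(i)
      val right = b(i)
      val cmp =
        if (left == null && right == null) 0
        else if (left == null) -1
        else if (right == null) 1
        else fieldOrderings(i).compare(left, right)
      if (cmp != 0) return cmp
      i += 1
    }
    0 // all sort keys compared equal
  }

  def main(args: Array[String]): Unit = {
    val fields = Seq[Ordering[Any]](
      Ordering.Int.asInstanceOf[Ordering[Any]],
      Ordering.String.asInstanceOf[Ordering[Any]])
    println(compareRows(fields)(Seq(1, "b"), Seq(1, "a"))) // > 0: second key decides
  }
}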
@@ -320,7 +320,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {

override def nullable: Boolean = left.nullable && right.nullable

- private lazy val ordering = TypeUtils.getOrdering(dataType)
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
@@ -374,7 +374,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {

override def nullable: Boolean = left.nullable && right.nullable

- private lazy val ordering = TypeUtils.getOrdering(dataType)
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
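MaxOf (and MinOf, symmetrically) consults the ordering only when both inputs are non-null; a null on either side simply yields the other value, which is why `nullable` above requires both children to be nullable. A hedged sketch of that evaluation rule, not Spark's actual code:

// Null-tolerant max of two already-evaluated values under a data-type
// ordering: a null operand loses to any non-null one.
def nullTolerantMax(ordering: Ordering[Any])(input1: Any, input2: Any): Any =
  if (input1 == null) input2
  else if (input2 == null) input1
  else if (ordering.gt(input1, input2)) input1
  else input2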
@@ -319,7 +319,7 @@ case class Least(children: Seq[Expression]) extends Expression {
override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

- private lazy val ordering = TypeUtils.getOrdering(dataType)
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def checkInputDataTypes(): TypeCheckResult = {
if (children.length <= 1) {
@@ -374,7 +374,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

- private lazy val ordering = TypeUtils.getOrdering(dataType)
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def checkInputDataTypes(): TypeCheckResult = {
if (children.length <= 1) {
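Least and Greatest apply the same rule across N children: evaluate each one and keep the extreme non-null value under the interpreted ordering, going null only when every child is null. A sketch of the Least case under those assumptions (not Spark's code):

// Fold over already-evaluated children, keeping the smallest non-null value.
def least(ordering: Ordering[Any])(values: Seq[Any]): Any =
  values.foldLeft(null: Any) { (acc, v) =>
    if (v == null) acc
    else if (acc == null || ordering.lt(v, acc)) v
    else acc
  }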
@@ -24,7 +24,7 @@ import org.apache.spark.sql.types._
/**
* An interpreted row ordering comparator.
*/
- class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
+ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {

def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
this(ordering.map(BindReferences.bindReference(_, inputSchema)))
@@ -49,9 +49,9 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
case dt: AtomicType if order.direction == Descending =>
dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
case s: StructType if order.direction == Ascending =>
- s.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
+ s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case s: StructType if order.direction == Descending =>
- s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+ s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
case other =>
throw new IllegalArgumentException(s"Type $other does not support ordered operations")
}
@@ -65,6 +65,18 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
}
}

+ object InterpretedOrdering {
+
+ /**
+  * Creates an [[InterpretedOrdering]] for the given schema, in natural ascending order.
+  */
+ def forSchema(dataTypes: Seq[DataType]): InterpretedOrdering = {
+   new InterpretedOrdering(dataTypes.zipWithIndex.map {
+     case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+   })
+ }
+ }

object RowOrdering {

/**
@@ -81,13 +93,4 @@ object RowOrdering {
* Returns true iff outputs from the expressions can be ordered.
*/
def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType))

- /**
-  * Creates a [[RowOrdering]] for the given schema, in natural ascending order.
-  */
- def forSchema(dataTypes: Seq[DataType]): RowOrdering = {
-   new RowOrdering(dataTypes.zipWithIndex.map {
-     case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
-   })
- }
}
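The new `forSchema` factory is exercised directly by the test changes at the bottom of this diff. A REPL-style usage sketch, assuming Spark 1.5's catalyst classes are on the classpath:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.InterpretedOrdering
import org.apache.spark.sql.types.{IntegerType, LongType}

// Natural ascending ordering over a two-column (int, long) schema.
val ordering = InterpretedOrdering.forSchema(Seq(IntegerType, LongType))
val a = InternalRow(1, 10L)
val b = InternalRow(2, 0L)
assert(ordering.compare(a, b) < 0) // the first column decides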
@@ -376,7 +376,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryComparison {

override def symbol: String = "<"

- private lazy val ordering = TypeUtils.getOrdering(left.dataType)
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
}
@@ -388,7 +388,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {

override def symbol: String = "<="

- private lazy val ordering = TypeUtils.getOrdering(left.dataType)
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
}
@@ -400,7 +400,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {

override def symbol: String = ">"

- private lazy val ordering = TypeUtils.getOrdering(left.dataType)
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
}
@@ -412,7 +412,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {

override def symbol: String = ">="

- private lazy val ordering = TypeUtils.getOrdering(left.dataType)
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
}
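All four comparison operators delegate to the same interpreted ordering through `nullSafeEval`, which only runs once both operands are known to be non-null; under SQL's three-valued logic a null operand makes the whole comparison null. Sketched generically (not Spark's code):

// The null check happens before the ordering is ever consulted.
def evalLessThan(ordering: Ordering[Any])(input1: Any, input2: Any): Any =
  if (input1 == null || input2 == null) null
  else ordering.lt(input1, input2)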
@@ -54,10 +54,10 @@ object TypeUtils {
def getNumeric(t: DataType): Numeric[Any] =
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]

- def getOrdering(t: DataType): Ordering[Any] = {
+ def getInterpretedOrdering(t: DataType): Ordering[Any] = {
t match {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
- case s: StructType => s.ordering.asInstanceOf[Ordering[Any]]
+ case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}
}

@@ -24,7 +24,7 @@ import org.json4s.JsonDSL._

import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
- import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, RowOrdering}
+ import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute}


/**
@@ -301,7 +301,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
StructType(newFields)
}

- private[sql] val ordering = RowOrdering.forSchema(this.fields.map(_.dataType))
+ private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType))
}

object StructType extends AbstractDataType {
@@ -54,7 +54,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
// GenerateOrdering agrees with RowOrdering.
(DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
test(s"GenerateOrdering with $dataType") {
- val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType))
+ val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType))
val genOrdering = GenerateOrdering.generate(
BoundReference(0, dataType, nullable = true).asc ::
BoundReference(1, dataType, nullable = true).asc :: Nil)
@@ -156,7 +156,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
val mutablePair = new MutablePair[InternalRow, Null]()
iter.map(row => mutablePair.update(row.copy(), null))
}
- implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+ // We need to use an interpreted ordering here because generated orderings cannot be
+ // serialized and this ordering needs to be created on the driver in order to be passed into
+ // Spark core code.
+ implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output)
new RangePartitioner(numPartitions, rddForSampling, ascending = true)
case SinglePartition =>
new Partitioner {
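The comment added above is the crux: `RangePartitioner` captures the implicit `Ordering[K]`, and that ordering object is serialized with the partitioner when Spark ships it to executors, so it must be a plain serializable class rather than one that exists only in the driver's generated-code classloader. The shape of the call, sketched with a hypothetical helper around the real `RangePartitioner` API:

import scala.reflect.ClassTag
import org.apache.spark.RangePartitioner
import org.apache.spark.rdd.RDD

// The Ordering[K] resolved here travels inside the partitioner to executors.
def rangePartitioner[K: Ordering: ClassTag, V](
    rdd: RDD[(K, V)], numPartitions: Int): RangePartitioner[K, V] =
  new RangePartitioner(numPartitions, rdd, ascending = true)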
@@ -32,6 +32,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
+ import org.apache.spark.sql.types.DataType

object SparkPlan {
protected[sql] val currentContext = new ThreadLocal[SQLContext]()
@@ -309,13 +310,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
throw e
} else {
log.error("Failed to generate ordering, fallback to interpreted", e)
- new RowOrdering(order, inputSchema)
+ new InterpretedOrdering(order, inputSchema)
}
}
} else {
- new RowOrdering(order, inputSchema)
+ new InterpretedOrdering(order, inputSchema)
}
}
+ /**
+  * Creates a row ordering for the given schema, in natural ascending order.
+  */
+ protected def newNaturalAscendingOrdering(dataTypes: Seq[DataType]): Ordering[InternalRow] = {
+   val order: Seq[SortOrder] = dataTypes.zipWithIndex.map {
+     case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+   }
+   newOrdering(order, Seq.empty)
+ }
}

private[sql] trait LeafNode extends SparkPlan {
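`newOrdering` (context above) already implements try-codegen-then-fall-back, and the new `newNaturalAscendingOrdering` simply routes through it, so callers such as SortMergeJoin get a generated ordering when codegen succeeds and the interpreted one otherwise. The pattern in generic form, with invented names (`generate` standing in for `GenerateOrdering.generate`, `interpret` for `new InterpretedOrdering(...)`):

import scala.util.control.NonFatal

// Try the code-generated ordering first; fall back to interpretation.
def orderingWithFallback[T](generate: => Ordering[T], interpret: => Ordering[T]): Ordering[T] =
  try generate
  catch {
    case NonFatal(_) =>
      // SparkPlan logs this failure, and rethrows instead when codegen
      // fallback is disabled (as it is under testing).
      interpret
  }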
@@ -212,7 +212,9 @@ case class TakeOrderedAndProject(

override def outputPartitioning: Partitioning = SinglePartition

- private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
+ // We need to use an interpreted ordering here because generated orderings cannot be serialized
+ // and this ordering needs to be created on the driver in order to be passed into Spark core code.
+ private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)

// TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
@transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
@@ -48,9 +48,6 @@ case class SortMergeJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

- // this is to manually construct an ordering that can be used to compare keys from both sides
- private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))
-
override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)

override def requiredChildOrdering: Seq[Seq[SortOrder]] =
@@ -59,15 +56,19 @@
@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)

- private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
+ private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
+   // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
keys.map(SortOrder(_, Ascending))
+ }

protected override def doExecute(): RDD[InternalRow] = {
val leftResults = left.execute().map(_.copy())
val rightResults = right.execute().map(_.copy())

leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
new Iterator[InternalRow] {
+ // An ordering that can be used to compare keys from both sides.
+ private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
// Mutable per row objects.
private[this] val joinRow = new JoinedRow
private[this] var leftElement: InternalRow = _
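With both children sorted ascending on the join keys (per `requiredOrders`), the join iterator advances whichever side holds the smaller key under `keyOrdering` and emits matches on equality. A condensed, eager sketch of that merge step over plain sequences, not Spark's streaming iterator:

// Equi-join two key-sorted sequences: the key ordering decides which side to
// advance; runs of equal keys produce their cross product.
def mergeJoin[K, L, R](left: Seq[(K, L)], right: Seq[(K, R)])(
    implicit ord: Ordering[K]): Seq[(L, R)] = {
  val out = Seq.newBuilder[(L, R)]
  var i, j = 0
  while (i < left.length && j < right.length) {
    val cmp = ord.compare(left(i)._1, right(j)._1)
    if (cmp < 0) i += 1
    else if (cmp > 0) j += 1
    else {
      val key = left(i)._1
      val iEnd = left.indexWhere(_._1 != key, i) match { case -1 => left.length; case n => n }
      val jEnd = right.indexWhere(_._1 != key, j) match { case -1 => right.length; case n => n }
      for (x <- i until iEnd; y <- j until jEnd) out += ((left(x)._2, right(y)._2))
      i = iEnd
      j = jEnd
    }
  }
  out.result()
}

For example, mergeJoin(Seq(1 -> "a", 2 -> "b"), Seq(2 -> "x")) yields Seq(("b", "x")).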
@@ -22,7 +22,7 @@ import scala.util.Random
import org.apache.spark._
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
- import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection}
+ import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
@@ -144,8 +144,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
}
sorter.cleanupResources()

- val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType))
- val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType))
+ val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType))
+ val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType))
val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
keyOrdering.compare(x._1, y._1) match {
