Skip to content

Commit

Permalink
AL-9228 Avoid reordering decimal Add for canonicalization if data typ…
Browse files Browse the repository at this point in the history
…e is changed (apache#728)

* [SPARK-40362][SQL][3.3] Fix BinaryComparison canonicalization

* [SPARK-40903][SQL] Avoid reordering decimal Add for canonicalization if data type is changed

* [SPARK-42162] Introduce MultiCommutativeOp expression as a memory optimization for canonicalizing large trees of commutative expressions

* [MINOR][DOCS] Remove Canonicalize in docs

---------

Signed-off-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
Co-authored-by: Peter Toth <[email protected]>
Co-authored-by: Peter Toth <[email protected]>
Co-authored-by: Gengliang Wang <[email protected]>
Co-authored-by: Supun Nakandala <[email protected]>
  • Loading branch information
5 people authored Feb 1, 2024
1 parent c43b45f commit c7ed0dc
Show file tree
Hide file tree
Showing 22 changed files with 411 additions and 159 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ case object UnresolvedSeed extends LeafExpression with Unevaluable {
*/
case class TempResolvedColumn(child: Expression, nameParts: Seq[String]) extends UnaryExpression
with Unevaluable {
override lazy val preCanonicalized = child.preCanonicalized
override lazy val canonicalized = child.canonicalized
override def dataType: DataType = child.dataType
override protected def withNewChildInternal(newChild: Expression): Expression =
copy(child = newChild)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ abstract class CastBase extends UnaryExpression
override lazy val resolved: Boolean =
childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)

override lazy val preCanonicalized: Expression = {
val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[CastBase]
override lazy val canonicalized: Expression = {
val basic = withNewChildren(Seq(child.canonicalized)).asInstanceOf[CastBase]
if (timeZoneId.isDefined && !needsTimeZone) {
basic.withTimeZone(null)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ case class DynamicPruningSubquery(

override def toString: String = s"dynamicpruning#${exprId.id} $conditionString"

override lazy val preCanonicalized: DynamicPruning = {
override lazy val canonicalized: DynamicPruning = {
copy(
pruningKey = pruningKey.preCanonicalized,
pruningKey = pruningKey.canonicalized,
buildQuery = buildQuery.canonicalized,
buildKeys = buildKeys.map(_.preCanonicalized),
buildKeys = buildKeys.map(_.canonicalized),
exprId = ExprId(0))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, Tre
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.MULTI_COMMUTATIVE_OP_OPT_THRESHOLD
import org.apache.spark.sql.types._

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -222,49 +223,40 @@ abstract class Expression extends TreeNode[Expression] {
*/
def childrenResolved: Boolean = children.forall(_.resolved)

// Expression canonicalization is done in 2 phases:
// 1. Recursively canonicalize each node in the expression tree. This does not change the tree
// structure and is more like "node-local" canonicalization.
// 2. Find adjacent commutative operators in the expression tree, reorder them to get a
// static order and remove cosmetic variations. This may change the tree structure
// dramatically and is more like a "global" canonicalization.
//
// The first phase is done by `preCanonicalized`. It's a `lazy val` which recursively calls
// `preCanonicalized` on the children. This means that almost every node in the expression tree
// will instantiate the `preCanonicalized` variable, which is good for performance as you can
// reuse the canonicalization result of the children when you construct a new expression node.
//
// The second phase is done by `canonicalized`, which simply calls `Canonicalize` and is kind of
// the actual "user-facing API" of expression canonicalization. Only the root node of the
// expression tree will instantiate the `canonicalized` variable. This is different from
// `preCanonicalized`, because `canonicalized` does "global" canonicalization and most of the time
// you cannot reuse the canonicalization result of the children.

/**
* An internal lazy val to implement expression canonicalization. It should only be called in
* `canonicalized`, or in subclass's `preCanonicalized` when the subclass overrides this lazy val
* to provide custom canonicalization logic.
*/
lazy val preCanonicalized: Expression = {
val canonicalizedChildren = children.map(_.preCanonicalized)
withNewChildren(canonicalizedChildren)
}

/**
* Returns an expression where a best effort attempt has been made to transform `this` in a way
* that preserves the result but removes cosmetic variations (case sensitivity, ordering for
* commutative operations, etc.) See [[Canonicalize]] for more details.
* commutative operations, etc.).
*
* `deterministic` expressions where `this.canonicalized == other.canonicalized` will always
* evaluate to the same result.
*
* The process of canonicalization is a one pass, bottom-up expression tree computation based on
* canonicalizing children before canonicalizing the current node. There is one exception though,
* as the canonicalization of adjacent, same class [[CommutativeExpression]]s happens in a way
* that calling `canonicalized` on the root:
* 1. Gathers and canonicalizes the non-commutative (or commutative but not same class) child
* expressions of the adjacent expressions.
* 2. Reorders the canonicalized child expressions by their hashcode.
* This means that the lazy `canonicalized` is called and computed only on the root of the
* adjacent expressions.
*/
lazy val canonicalized: Expression = Canonicalize.reorderCommutativeOperators(preCanonicalized)
lazy val canonicalized: Expression = withCanonicalizedChildren

/**
 * The default canonicalization step: a single bottom-up pass that first canonicalizes every
 * child and then rebuilds the current node on top of the canonicalized children.
 */
final protected def withCanonicalizedChildren: Expression =
  withNewChildren(children.map(_.canonicalized))

/**
* Returns true when two expressions will always compute the same result, even if they differ
* cosmetically (i.e. capitalization of names in attributes may be different).
*
* See [[Canonicalize]] for more details.
* See [[Expression#canonicalized]] for more details.
*/
final def semanticEquals(other: Expression): Boolean = {
  // Determinism of both sides is checked first (and short-circuits) so that the canonical
  // forms are only compared when both expressions are deterministic.
  val bothDeterministic = deterministic && other.deterministic
  bothDeterministic && canonicalized == other.canonicalized
}
Expand All @@ -273,7 +265,7 @@ abstract class Expression extends TreeNode[Expression] {
* Returns a `hashCode` for the calculation performed by this expression. Unlike the standard
* `hashCode`, an attempt has been made to eliminate cosmetic differences.
*
* See [[Canonicalize]] for more details.
* See [[Expression#canonicalized]] for more details.
*/
def semanticHash(): Int = {
  // Hash the canonical form rather than `this`, so cosmetic differences do not affect the hash.
  val canonicalForm = canonicalized
  canonicalForm.hashCode()
}

Expand Down Expand Up @@ -362,7 +354,7 @@ trait RuntimeReplaceable extends Expression {
// As this expression gets replaced at optimization with its "child" expression,
// two `RuntimeReplaceable` are considered to be semantically equal if their "child" expressions
// are semantically equal.
override lazy val preCanonicalized: Expression = replacement.preCanonicalized
override lazy val canonicalized: Expression = replacement.canonicalized

final override def eval(input: InternalRow = null): Any =
throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
Expand Down Expand Up @@ -1156,3 +1148,94 @@ trait ComplexTypeMergingExpression extends Expression {
trait UserDefinedExpression {
def name: String
}

/**
 * Mix-in for expressions whose operands may be reordered without changing the result. It
 * provides the helpers used to canonicalize a run of adjacent commutative operators as a unit.
 */
trait CommutativeExpression extends Expression {
  /**
   * Collects the operands of adjacent commutative operations.
   *
   * While `f` is defined for a commutative node, the node is expanded into the operands returned
   * by `f` and the walk continues into each of them. Every other expression is a leaf of the
   * walk and is returned in its canonicalized form.
   */
  private def gatherCommutative(
      e: Expression,
      f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = e match {
    case c: CommutativeExpression if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
    case other => other.canonicalized :: Nil
  }

  /**
   * Reorders adjacent commutative operators such as [[And]] in the expression tree, according to
   * the `hashCode` of non-commutative nodes, to remove cosmetic variations.
   */
  protected def orderCommutative(
      f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] =
    gatherCommutative(this, f).sortBy(_.hashCode())

  /**
   * Helper method to generate a canonicalized plan.
   *
   * The operands gathered through `collectOperands` are ordered by hash code. If their number is
   * below MULTI_COMMUTATIVE_OP_OPT_THRESHOLD they are folded back into a tree with
   * `buildBinaryOp`; otherwise (count greater than or equal to the threshold) a single
   * [[MultiCommutativeOp]] holding all operands is returned as the canonicalized plan.
   *
   * @param collectOperands returns the operands of a node that belongs to the commutative run
   * @param buildBinaryOp combines two operands into one binary operator node
   * @param failOnError optional evaluation mode, passed on to [[MultiCommutativeOp]]
   */
  protected def buildCanonicalizedPlan(
      collectOperands: PartialFunction[Expression, Seq[Expression]],
      buildBinaryOp: (Expression, Expression) => Expression,
      failOnError: Option[Boolean] = None): Expression = {
    val operands = orderCommutative(collectOperands)
    if (operands.length < SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)) {
      operands.reduce(buildBinaryOp)
    } else {
      MultiCommutativeOp(operands, this.getClass, failOnError)(this)
    }
  }
}

/**
 * A helper class used by the commutative expressions during canonicalization. During
 * canonicalization, when we have a long tree of commutative operations, we use the
 * MultiCommutativeOp expression to represent that tree instead of creating new commutative
 * objects.
 * This class is added as a memory optimization for processing large commutative operation trees
 * without creating a large number of new intermediate objects.
 * The MultiCommutativeOp memory optimization is applied to the following commutative
 * expressions:
 *      Add, Multiply, And, Or, BitwiseAnd, BitwiseOr, BitwiseXor.
 *
 * `originalRoot` is declared in a second parameter list, so it is not part of the case class's
 * generated `equals`/`hashCode`; two instances compare equal based on `operands`, `opCls` and
 * `failOnError` only.
 *
 * @param operands A sequence of operands that produces a commutative expression tree.
 * @param opCls The class of the root operator of the expression tree.
 * @param failOnError The optional expression evaluation mode.
 * @param originalRoot Root operator of the commutative expression tree before canonicalization.
 *                     This object reference is used to deduce the return dataType of Add and
 *                     Multiply operations when the input datatype is decimal.
 */
case class MultiCommutativeOp(
    operands: Seq[Expression],
    opCls: Class[_],
    failOnError: Option[Boolean])(originalRoot: Expression) extends Unevaluable {
  // Helper method to deduce the data type of a single operation.
  // Only called from `dataType` when `originalRoot` is an Add or a Multiply, which is why the
  // match below has no other cases.
  private def singleOpDataType(lType: DataType, rType: DataType): DataType = {
    originalRoot match {
      case add: Add =>
        (lType, rType) match {
          case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
            add.resultDecimalType(p1, s1, p2, s2)
          case _ => lType
        }
      case multiply: Multiply =>
        (lType, rType) match {
          case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
            multiply.resultDecimalType(p1, s1, p2, s2)
          case _ => lType
        }
    }
  }

  // For Add/Multiply the result type is derived by folding the operand types pairwise, so that
  // decimal precision/scale follows the operand sequence; every other operator simply reports
  // the data type of the original root.
  override def dataType: DataType = {
    originalRoot match {
      case _: Add | _: Multiply =>
        operands.map(_.dataType).reduce((l, r) => singleOpDataType(l, r))
      case other => other.dataType
    }
  }

  override def nullable: Boolean = operands.exists(_.nullable)

  override def children: Seq[Expression] = operands

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
    this.copy(operands = newChildren)(originalRoot)

  // `originalRoot` lives in the second parameter list and is therefore not among the product
  // elements. It has to be exposed here so that reflective copies made through
  // `TreeNode.makeCopy` can supply it to the constructor.
  override protected final def otherCopyArgs: Seq[AnyRef] = originalRoot :: Nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import scala.collection.{mutable, GenTraversableOnce}
import org.apache.spark.sql.catalyst.util.BigArrayBuffer

object ExpressionSet {
/** Constructs a new [[ExpressionSet]] by applying [[Canonicalize]] to `expressions`. */
/**
* Constructs a new [[ExpressionSet]] by applying [[Expression#canonicalized]] to `expressions`.
*/
def apply(expressions: TraversableOnce[Expression]): ExpressionSet = {
val set = new ExpressionSet()
expressions.foreach(set.add)
Expand All @@ -37,7 +39,7 @@ object ExpressionSet {
/**
* A [[Set]] where membership is determined based on determinacy and a canonical representation of
* an [[Expression]] (i.e. one that attempts to ignore cosmetic differences).
* See [[Canonicalize]] for more details.
* See [[Expression#canonicalized]] for more details.
*
* Internally this set uses the canonical representation, but keeps also track of the original
* expressions to ease debugging. Since different expressions can share the same canonical
Expand Down Expand Up @@ -156,8 +158,8 @@ class ExpressionSet protected(
override def clone(): ExpressionSet = new ExpressionSet(baseSet.clone(), originals.clone())

/**
* Returns a string containing both the post [[Canonicalize]] expressions and the original
* expressions in this set.
* Returns a string containing both the post [[Expression#canonicalized]] expressions
* and the original expressions in this set.
*/
def toDebugString: String =
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ case class PythonUDF(

override def nullable: Boolean = true

override lazy val preCanonicalized: Expression = {
val canonicalizedChildren = children.map(_.preCanonicalized)
override lazy val canonicalized: Expression = {
val canonicalizedChildren = children.map(_.canonicalized)
// `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result.
this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ case class ScalaUDF(

override def name: String = udfName.getOrElse("UDF")

override lazy val preCanonicalized: Expression = {
override lazy val canonicalized: Expression = {
// SPARK-32307: `ExpressionEncoder` can't be canonicalized, and technically we don't
// need it to identify a `ScalaUDF`.
copy(children = children.map(_.preCanonicalized), inputEncoders = Nil, outputEncoder = None)
copy(children = children.map(_.canonicalized), inputEncoders = Nil, outputEncoder = None)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ case class AggregateExpression(
def filterAttributes: AttributeSet = filter.map(_.references).getOrElse(AttributeSet.empty)

// We compute the same thing regardless of our final result.
override lazy val preCanonicalized: Expression = {
override lazy val canonicalized: Expression = {
val normalizedAggFunc = mode match {
// For PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate buffers,
// and the actual children of `aggregateFunction` is not used, here we normalize the expr id.
Expand All @@ -134,10 +134,10 @@ case class AggregateExpression(
}

AggregateExpression(
normalizedAggFunc.preCanonicalized.asInstanceOf[AggregateFunction],
normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction],
mode,
isDistinct,
filter.map(_.preCanonicalized),
filter.map(_.canonicalized),
ExprId(0))
}

Expand Down
Loading

0 comments on commit c7ed0dc

Please sign in to comment.