[SPARK-28306][SQL] Make NormalizeFloatingNumbers rule idempotent
## What changes were proposed in this pull request?
The optimizer rule `NormalizeFloatingNumbers` is not idempotent: each run wraps expressions in additional `NormalizeNaNAndZero` and `ArrayTransform` nodes. This patch fixes the non-idempotence by adding a tagging expression on top of already-normalized expressions, so subsequent runs leave them untouched. It also adds the missing unit tests for `NormalizeFloatingNumbers`.
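
For illustration only, the sketch below (plain Scala, not Catalyst code; all names are made up) shows why an always-wrap rewrite keeps growing the tree on every run, and how a tag node makes any later run a no-op:

```scala
// Minimal sketch of the idempotence problem and the tagging fix (illustrative only).
object IdempotenceSketch extends App {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Normalize(child: Expr) extends Expr // stand-in for NormalizeNaNAndZero
  case class Tagged(child: Expr) extends Expr    // stand-in for KnownFloatingPointNormalized

  // Non-idempotent: a second run wraps the already-wrapped result again.
  def normalizeNaive(e: Expr): Expr = Normalize(e)

  // Idempotent: the tag short-circuits every later run.
  def normalizeTagged(e: Expr): Expr = e match {
    case t: Tagged => t
    case other     => Tagged(Normalize(other))
  }

  println(normalizeNaive(normalizeNaive(Attr("a")))) // Normalize(Normalize(Attr(a)))
  val once = normalizeTagged(Attr("a"))
  assert(normalizeTagged(once) == once)              // stays Tagged(Normalize(Attr(a)))
}
```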

## How was this patch tested?
New UTs.

Closes #25080 from yeshengm/spark-28306.

Authored-by: Yesheng Ma <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
yeshengm authored and cloud-fan committed Jul 11, 2019
1 parent 0197628 commit 7021588
Showing 3 changed files with 108 additions and 12 deletions.
@@ -21,15 +21,21 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
 import org.apache.spark.sql.types.DataType
 
-case class KnownNotNull(child: Expression) extends UnaryExpression {
-  override def nullable: Boolean = false
+trait TaggingExpression extends UnaryExpression {
+  override def nullable: Boolean = child.nullable
   override def dataType: DataType = child.dataType
 
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.genCode(ctx)
+
+  override def eval(input: InternalRow): Any = child.eval(input)
+}
+
+case class KnownNotNull(child: Expression) extends TaggingExpression {
+  override def nullable: Boolean = false
+
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     child.genCode(ctx).copy(isNull = FalseLiteral)
   }
-
-  override def eval(input: InternalRow): Any = {
-    child.eval(input)
-  }
 }
+
+case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression
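
A note on overhead: the new `TaggingExpression` trait simply forwards `eval` and `doGenCode` to its child, so the tag is free at run time and exists only for the optimizer to pattern-match on. A minimal sketch of that pass-through behaviour (plain Scala with made-up names, not the Catalyst API):

```scala
// Sketch of a pass-through "tagging" wrapper (illustrative only, not Catalyst).
object TaggingSketch extends App {
  trait Expr { def eval(row: Map[String, Any]): Any }

  case class Col(name: String) extends Expr {
    def eval(row: Map[String, Any]): Any = row(name)
  }

  // Adds no behaviour of its own: evaluation is delegated to the child,
  // mirroring how TaggingExpression forwards eval/doGenCode in the diff above.
  case class KnownNormalized(child: Expr) extends Expr {
    def eval(row: Map[String, Any]): Any = child.eval(row)
  }

  val row = Map[String, Any]("x" -> -0.0)
  assert(Col("x").eval(row) == KnownNormalized(Col("x")).eval(row))
}
```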
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, LambdaFunction, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -61,7 +61,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case _: Subquery => plan
 
     case _ => plan transform {
-      case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) =>
+      case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
         // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need
         // to normalize the `windowExpressions`, as they are executed per input row and should take
         // the input row as it is.
@@ -73,7 +73,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
           // The analyzer guarantees left and right joins keys are of the same data type. Here we
           // only need to check join keys of one side.
-          if leftKeys.exists(k => needNormalize(k.dataType)) =>
+          if leftKeys.exists(k => needNormalize(k)) =>
         val newLeftJoinKeys = leftKeys.map(normalize)
         val newRightJoinKeys = rightKeys.map(normalize)
         val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
@@ -87,6 +87,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Short circuit if the underlying expression is already normalized
+   */
+  private def needNormalize(expr: Expression): Boolean = expr match {
+    case KnownFloatingPointNormalized(_) => false
+    case _ => needNormalize(expr.dataType)
+  }
+
   private def needNormalize(dt: DataType): Boolean = dt match {
     case FloatType | DoubleType => true
     case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
@@ -98,7 +106,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
   }
 
   private[sql] def normalize(expr: Expression): Expression = expr match {
-    case _ if !needNormalize(expr.dataType) => expr
+    case _ if !needNormalize(expr) => expr
 
     case a: Alias =>
       a.withNewChildren(Seq(normalize(a.child)))
@@ -116,7 +124,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       CreateMap(children.map(normalize))
 
     case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
-      NormalizeNaNAndZero(expr)
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
 
     case _ if expr.dataType.isInstanceOf[StructType] =>
       val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
@@ -128,7 +136,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       val ArrayType(et, containsNull) = expr.dataType
       val lv = NamedLambdaVariable("arg", et, containsNull)
       val function = normalize(lv)
-      ArrayTransform(expr, LambdaFunction(function, Seq(lv)))
+      KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv))))
 
     case _ => throw new IllegalStateException(s"fail to normalize $expr")
   }
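
To make the short-circuit concrete, here is a simplified sketch (plain Scala with illustrative types, not Catalyst; it models only the wrap-and-tag step, not the per-element `ArrayTransform` rewrite) of the two `needNormalize` overloads added above — one stops at the tag, the other recurses into the data type — and of why `normalize` then converges after a single run:

```scala
// Simplified sketch of the needNormalize/normalize dispatch (illustrative only).
object NeedNormalizeSketch extends App {
  sealed trait DType
  case object DoubleT extends DType
  case object IntT extends DType
  case class ArrayT(elem: DType) extends DType

  sealed trait Expr { def dataType: DType }
  case class Attr(name: String, dataType: DType) extends Expr
  case class Tagged(child: Expr) extends Expr { def dataType: DType = child.dataType }
  case class Normalize(child: Expr) extends Expr { def dataType: DType = child.dataType }

  // Expression-level check: stop as soon as we see the tag from an earlier run.
  def needNormalize(e: Expr): Boolean = e match {
    case Tagged(_) => false
    case _         => needNormalize(e.dataType)
  }

  // Type-level check: doubles, and doubles nested inside arrays, need normalization.
  def needNormalize(dt: DType): Boolean = dt match {
    case DoubleT      => true
    case ArrayT(elem) => needNormalize(elem)
    case _            => false
  }

  def normalize(e: Expr): Expr =
    if (!needNormalize(e)) e else Tagged(Normalize(e))

  val a = Attr("a", ArrayT(DoubleT))
  assert(normalize(normalize(a)) == normalize(a))       // second run is a no-op
  assert(normalize(Attr("i", IntT)) == Attr("i", IntT)) // non-floating types untouched
}
```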
@@ -0,0 +1,82 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class NormalizeFloatingPointNumbersSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("NormalizeFloatingPointNumbers", Once, NormalizeFloatingNumbers) :: Nil
  }

  val testRelation1 = LocalRelation('a.double)
  val a = testRelation1.output(0)
  val testRelation2 = LocalRelation('a.double)
  val b = testRelation2.output(0)

  test("normalize floating points in window function expressions") {
    val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))

    val optimized = Optimize.execute(query)
    val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")),
      Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc))

    comparePlans(optimized, correctAnswer)
  }

  test("normalize floating points in window function expressions - idempotence") {
    val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))

    val optimized = Optimize.execute(query)
    val doubleOptimized = Optimize.execute(optimized)
    val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")),
      Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc))

    comparePlans(doubleOptimized, correctAnswer)
  }

  test("normalize floating points in join keys") {
    val query = testRelation1.join(testRelation2, condition = Some(a === b))

    val optimized = Optimize.execute(query)
    val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))
      === KnownFloatingPointNormalized(NormalizeNaNAndZero(b)))
    val correctAnswer = testRelation1.join(testRelation2, condition = joinCond)

    comparePlans(optimized, correctAnswer)
  }

  test("normalize floating points in join keys - idempotence") {
    val query = testRelation1.join(testRelation2, condition = Some(a === b))

    val optimized = Optimize.execute(query)
    val doubleOptimized = Optimize.execute(optimized)
    val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))
      === KnownFloatingPointNormalized(NormalizeNaNAndZero(b)))
    val correctAnswer = testRelation1.join(testRelation2, condition = joinCond)

    comparePlans(doubleOptimized, correctAnswer)
  }
}
