From 4b19f49dd01168c006bc5d8a506a1ef3c36c721d Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 21 Dec 2020 04:15:29 -0800 Subject: [PATCH] [SPARK-33845][SQL] Remove unnecessary if when trueValue and falseValue are foldable boolean types ### What changes were proposed in this pull request? Improve `SimplifyConditionals`. Simplify `If(cond, TrueLiteral, FalseLiteral)` to `cond`. Simplify `If(cond, FalseLiteral, TrueLiteral)` to `Not(cond)`. The use case is: ```sql create table t1 using parquet as select id from range(10); select if (id > 2, false, true) from t1; ``` Before this pr: ``` == Physical Plan == *(1) Project [if ((id#1L > 2)) false else true AS (IF((id > CAST(2 AS BIGINT)), false, true))#2] +- *(1) ColumnarToRow +- FileScan parquet default.t1[id#1L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/spark-warehouse/org.apache.spark.sql.DataF..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` After this pr: ``` == Physical Plan == *(1) Project [(id#1L <= 2) AS (IF((id > CAST(2 AS BIGINT)), false, true))#2] +- *(1) ColumnarToRow +- FileScan parquet default.t1[id#1L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/spark-warehouse/org.apache.spark.sql.DataF..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` ### Why are the changes needed? Improve query performance. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #30849 from wangyum/SPARK-33798-2. Authored-by: Yuming Wang Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/optimizer/expressions.scala | 2 ++ .../PushFoldableIntoBranchesSuite.scala | 7 ++--- ...ReplaceNullWithFalseInPredicateSuite.scala | 31 +++++++++++-------- .../optimizer/SimplifyConditionalSuite.scala | 16 ++++++++++ 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e6730c9275a1e..ac2caaeb15357 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -475,6 +475,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue + case If(cond, TrueLiteral, FalseLiteral) => cond + case If(cond, FalseLiteral, TrueLiteral) => Not(cond) case If(cond, trueValue, falseValue) if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 43360af46ffb3..de4f4be8ec333 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -53,7 +53,7 @@ class PushFoldableIntoBranchesSuite test("Push down EqualTo through If") { assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) - assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) + assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a)) // Push down at most one not foldable expressions. assertEquivalent( @@ -67,7 +67,7 @@ class PushFoldableIntoBranchesSuite val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(2)) assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(2)), - If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, TrueLiteral)) + GreaterThanOrEqual(Rand(1), Literal(0.5))) assertEquivalent(EqualTo(nonDeterministic, Literal(3)), If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, FalseLiteral)) @@ -102,8 +102,7 @@ class PushFoldableIntoBranchesSuite assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3))) assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)), If(a, Literal(2.0), Literal(3.0))) - assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), - If(a, FalseLiteral, TrueLiteral)) + assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a)) assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 00433a5490574..5da71c31e1990 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable} @@ -236,12 +236,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(2) === nestedCaseWhen, TrueLiteral, FalseLiteral) - val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) - val condition = CaseWhen(branches) - testFilter(originalCond = condition, expectedCond = condition) - testJoin(originalCond = condition, expectedCond = condition) - testDelete(originalCond = condition, expectedCond = condition) - testUpdate(originalCond = condition, expectedCond = condition) + val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)) + val expectedCond = + CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen))) + testFilter(originalCond = condition, expectedCond = expectedCond) + testJoin(originalCond = condition, expectedCond = expectedCond) + testDelete(originalCond = condition, expectedCond = expectedCond) + testUpdate(originalCond = condition, expectedCond = expectedCond) } test("inability to replace null in non-boolean branches of If inside another If") { @@ -252,10 +253,14 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(3)), TrueLiteral, FalseLiteral) - testFilter(originalCond = condition, expectedCond = condition) - testJoin(originalCond = condition, expectedCond = condition) - testDelete(originalCond = condition, expectedCond = condition) - testUpdate(originalCond = condition, expectedCond = condition) + val expectedCond = Literal(5) > If( + UnresolvedAttribute("i") === Literal(15), + Literal(null, IntegerType), + Literal(3)) + testFilter(originalCond = condition, expectedCond = expectedCond) + testJoin(originalCond = condition, expectedCond = expectedCond) + testDelete(originalCond = condition, expectedCond = expectedCond) + testUpdate(originalCond = condition, expectedCond = expectedCond) } test("replace null in If used as a join condition") { @@ -405,9 +410,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val lambda1 = LambdaFunction( function = If(cond, Literal(null, BooleanType), TrueLiteral), arguments = lambdaArgs) - // the optimized lambda body is: if(arg > 0, false, true) + // the optimized lambda body is: if(arg > 0, false, true) => arg <= 0 val lambda2 = LambdaFunction( - function = If(cond, FalseLiteral, TrueLiteral), + function = LessThanOrEqual(condArg, Literal(0)), arguments = lambdaArgs) testProjection( originalExpr = createExpr(argument, lambda1) as 'x, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index bac962ced4618..328fc107e1c1b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -199,4 +199,20 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow)) } } + + test("SPARK-33845: remove unnecessary if when the outputs are boolean type") { + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), + IsNotNull(UnresolvedAttribute("a"))) + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), + IsNull(UnresolvedAttribute("a"))) + + assertEquivalent( + If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), + GreaterThan(Rand(0), UnresolvedAttribute("a"))) + assertEquivalent( + If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), + LessThanOrEqual(Rand(0), UnresolvedAttribute("a"))) + } }