From f0871c921285a05602cf566c9f2c23901224d73e Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 6 May 2016 15:39:43 +0200 Subject: [PATCH] Fix TPC-DS 41 - normalize predicates before pulling them out. --- .../sql/catalyst/analysis/Analyzer.scala | 4 ++- .../sql/catalyst/optimizer/Optimizer.scala | 28 +++++++++++++++++++ .../org/apache/spark/sql/SubquerySuite.scala | 12 ++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 527d5b635a7f9..9e9a856286533 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.planning.IntegerIndex import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} @@ -958,7 +959,8 @@ class Analyzer( localPredicateReferences -- p.outputSet } - val transformed = sub transformUp { + // Simplify the predicates before pulling them out. + val transformed = BooleanSimplification(sub) transformUp { case f @ Filter(cond, child) => // Find all predicates with an outer reference. 
val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a3ab89dc71145..3a4a428a78129 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -101,6 +101,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, + RewriteScalarSubqueriesInFilter, RewriteCorrelatedScalarSubquery, EliminateSerialization) :: Batch("Decimal Optimizations", fixedPoint, @@ -1645,3 +1646,30 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { } } } + +/** + * Rewrite [[Filter]] plans that contain correlated [[ScalarSubquery]] expressions. When these + * correlated [[ScalarSubquery]] expressions are wrapped in some Predicate expression, we rewrite + * them into [[PredicateSubquery]] expressions. 
+ */ +object RewriteScalarSubqueriesInFilter extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(condition, child) => + val newCond = f.condition.transformUp { + case e if e.dataType == BooleanType => + val scalars = ArrayBuffer.empty[ScalarSubquery] + val newExpr = e.transform { + case s: ScalarSubquery if s.children.nonEmpty => + scalars += s + s.query.output.head + } + scalars match { + case Seq(ScalarSubquery(query, conditions, exprId)) => + PredicateSubquery(query, conditions :+ newExpr, nullAware = false, exprId) + case _ => + e + } + } + Filter(newCond, f.child) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 80bb4e05385f4..17ac0c8c6e496 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -281,4 +281,16 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(msg1.getMessage.contains( "The correlated scalar subquery can only contain equality predicates")) } + + test("disjunctive correlated scalar subquery") { + checkAnswer( + sql(""" + |select a + |from l + |where (select count(*) + | from r + | where (a = c and d = 2.0) or (a = c and d = 1.0)) > 0 + """.stripMargin), + Row(3) :: Nil) + } }