From 45fbb851e76eeaa45c9926571059274efca2441a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 2 Mar 2018 17:27:18 +0100 Subject: [PATCH 1/4] [SPARK-23564][SQL] Add isNotNull check for left anti and outer joins --- .../sql/catalyst/expressions/predicates.scala | 26 ++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 21 +++++++++-- .../plans/logical/QueryPlanConstraints.scala | 27 +------------- .../InferFiltersFromConstraintsSuite.scala | 36 +++++++++++++++++++ 4 files changed, 82 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a6d41ea7d00d4..398d06122c591 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -45,6 +45,32 @@ trait Predicate extends Expression { override def dataType: DataType = BooleanType } +trait NotNullConstraintHelper { + /** + * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions + * of constraints. + */ + protected def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = + constraint match { + // When the root is IsNotNull, we can push IsNotNull through the child null intolerant + // expressions + case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) + // Constraints always return true for all the inputs. That means, null will never be returned. + // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child + // null intolerant expressions. + case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_)) + } + + /** + * Recursively explores the expressions which are null intolerant and returns all attributes + * in these expressions. 
+ */ + protected def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { + case a: Attribute => Seq(a) + case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) + case _ => Seq.empty[Attribute] + } +} trait PredicateHelper { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 91208479be03b..72eaffb25fd12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -638,7 +638,8 @@ object CollapseWindow extends Rule[LogicalPlan] { * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and * LeftSemi joins. */ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { +object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper + with NotNullConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (SQLConf.get.constraintPropagationEnabled) { @@ -663,7 +664,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe // right child val constraints = join.allConstraints.filter { c => c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) - } + } ++ extraJoinConstraints(join).toSet // Remove those constraints that are already enforced by either the left or the right child val additionalConstraints = constraints -- (left.constraints ++ right.constraints) val newConditionOpt = conditionOpt match { @@ -675,6 +676,22 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe } if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join } + + /** + * Returns additional constraints which 
are not enforced on the result of join operations, but + * which can be enforced either on the left or the right side + */ + def extraJoinConstraints(join: Join): Seq[Expression] = { + join match { + case Join(_, right, LeftAnti | LeftOuter, condition) if condition.isDefined => + splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter( + _.references.subsetOf(right.outputSet)) + case Join(left, _, RightOuter, condition) if condition.isDefined => + splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter( + _.references.subsetOf(left.outputSet)) + case _ => Seq.empty[Expression] + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 046848875548b..a82dc7be509fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints { self: LogicalPlan => +trait QueryPlanConstraints extends NotNullConstraintHelper { self: LogicalPlan => /** * An [[ExpressionSet]] that contains an additional set of constraints, such as equality @@ -72,31 +72,6 @@ trait QueryPlanConstraints { self: LogicalPlan => isNotNullConstraints -- constraints } - /** - * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions - * of constraints. 
- */ - private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = - constraint match { - // When the root is IsNotNull, we can push IsNotNull through the child null intolerant - // expressions - case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) - // Constraints always return true for all the inputs. That means, null will never be returned. - // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child - // null intolerant expressions. - case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_)) - } - - /** - * Recursively explores the expressions which are null intolerant and returns all attributes - * in these expressions. - */ - private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { - case a: Attribute => Seq(a) - case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) - case _ => Seq.empty[Attribute] - } - /** * Infers an additional set of constraints from a given set of equality constraints. 
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index f78c2356e35a5..4ad6ad02108c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -204,4 +204,40 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("SPARK-23564: left anti join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftAnti, condition).analyze + val left = x + val right = y.where(IsNotNull('a)) + val correctAnswer = left.join(right, LeftAnti, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-23564: left outer join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftOuter, condition).analyze + val left = x + val right = y.where(IsNotNull('a)) + val correctAnswer = left.join(right, LeftOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-23564: right outer join should filter out null join keys on left side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, RightOuter, 
condition).analyze + val left = x.where(IsNotNull('a)) + val right = y + val correctAnswer = left.join(right, RightOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } From d8a11901bb2785739caa593b3048df420419d35b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 10 Mar 2018 11:20:06 +0100 Subject: [PATCH 2/4] use allConstraints --- .../sql/catalyst/expressions/predicates.scala | 27 ---------- .../sql/catalyst/optimizer/Optimizer.scala | 21 +------- .../plans/logical/QueryPlanConstraints.scala | 51 +++++++++++++++---- .../plans/logical/basicLogicalOperators.scala | 22 ++++++++ 4 files changed, 65 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 398d06122c591..a8f812b1371f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -45,33 +45,6 @@ trait Predicate extends Expression { override def dataType: DataType = BooleanType } -trait NotNullConstraintHelper { - /** - * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions - * of constraints. - */ - protected def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = - constraint match { - // When the root is IsNotNull, we can push IsNotNull through the child null intolerant - // expressions - case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) - // Constraints always return true for all the inputs. That means, null will never be returned. - // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child - // null intolerant expressions. 
- case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_)) - } - - /** - * Recursively explores the expressions which are null intolerant and returns all attributes - * in these expressions. - */ - protected def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { - case a: Attribute => Seq(a) - case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) - case _ => Seq.empty[Attribute] - } -} - trait PredicateHelper { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { condition match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 72eaffb25fd12..91208479be03b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -638,8 +638,7 @@ object CollapseWindow extends Rule[LogicalPlan] { * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and * LeftSemi joins. 
*/ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper - with NotNullConstraintHelper { +object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (SQLConf.get.constraintPropagationEnabled) { @@ -664,7 +663,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe // right child val constraints = join.allConstraints.filter { c => c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) - } ++ extraJoinConstraints(join).toSet + } // Remove those constraints that are already enforced by either the left or the right child val additionalConstraints = constraints -- (left.constraints ++ right.constraints) val newConditionOpt = conditionOpt match { @@ -676,22 +675,6 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe } if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join } - - /** - * Returns additional constraints which are not enforced on the result of join operations, but - * which can be enforced either on the left or the right side - */ - def extraJoinConstraints(join: Join): Seq[Expression] = { - join match { - case Join(_, right, LeftAnti | LeftOuter, condition) if condition.isDefined => - splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter( - _.references.subsetOf(right.outputSet)) - case Join(left, _, RightOuter, condition) if condition.isDefined => - splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter( - _.references.subsetOf(left.outputSet)) - case _ => Seq.empty[Expression] - } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index a82dc7be509fb..19279db0f7bc7 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -20,21 +20,13 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints extends NotNullConstraintHelper { self: LogicalPlan => +trait QueryPlanConstraints { self: LogicalPlan => /** * An [[ExpressionSet]] that contains an additional set of constraints, such as equality * constraints and `isNotNull` constraints, etc. */ - lazy val allConstraints: ExpressionSet = { - if (conf.constraintPropagationEnabled) { - ExpressionSet(validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints))) - } else { - ExpressionSet(Set.empty) - } - } + lazy val allConstraints: ExpressionSet = ExpressionSet(constructAllConstraints) /** * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For @@ -55,6 +47,20 @@ trait QueryPlanConstraints extends NotNullConstraintHelper { self: LogicalPlan = */ protected def validConstraints: Set[Expression] = Set.empty + /** + * Returns the [[Expression]]s representing all the constraints which can be enforced on the + * current operator. + */ + protected def constructAllConstraints: Set[Expression] = { + if (conf.constraintPropagationEnabled) { + validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints)) + } else { + Set.empty + } + } + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. 
For e.g., if an expression is of the form (`a > 5`), this @@ -72,6 +78,31 @@ trait QueryPlanConstraints extends NotNullConstraintHelper { self: LogicalPlan = isNotNullConstraints -- constraints } + /** + * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions + * of constraints. + */ + protected def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = + constraint match { + // When the root is IsNotNull, we can push IsNotNull thro0ugh the child null intolerant + // expressions + case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) + // Constraints always return true for all the inputs. That means, null will never be returned. + // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child + // null intolerant expressions. + case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_)) + } + + /** + * Recursively explores the expressions which are null intolerant and returns all attributes + * in these expressions. + */ + protected def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { + case a: Attribute => Seq(a) + case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) + case _ => Seq.empty[Attribute] + } + /** * Infers an additional set of constraints from a given set of equality constraints. 
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a4fca790dd086..a3ae09e473eab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -341,6 +341,28 @@ case class Join( case UsingJoin(_, _) => false case _ => resolvedExceptNatural } + + /** + * Returns additional constraints which are not enforced on the result of join operations, but + * which can be enforced either on the left or the right side + */ + override protected def constructAllConstraints: Set[Expression] = { + val additionalConstraints = joinType match { + case LeftAnti | LeftOuter if condition.isDefined => + splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter( + _.references.subsetOf(right.outputSet)) + case RightOuter if condition.isDefined => + splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter( + _.references.subsetOf(left.outputSet)) + case _ => Seq.empty[Expression] + } + super.constructAllConstraints ++ additionalConstraints + } + + override lazy val constraints: ExpressionSet = ExpressionSet( + super.constructAllConstraints.filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + }) } /** From 9e2d993d691ad37b230c9e14d16148b9dc9727e6 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 10 Mar 2018 11:41:03 +0100 Subject: [PATCH 3/4] fix typos --- .../apache/spark/sql/catalyst/expressions/predicates.scala | 1 + .../sql/catalyst/plans/logical/QueryPlanConstraints.scala | 2 +- .../sql/catalyst/plans/logical/basicLogicalOperators.scala | 6 ++---- 3 files changed, 4 
insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a8f812b1371f9..a6d41ea7d00d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -45,6 +45,7 @@ trait Predicate extends Expression { override def dataType: DataType = BooleanType } + trait PredicateHelper { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { condition match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 19279db0f7bc7..857b4d4060618 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -84,7 +84,7 @@ trait QueryPlanConstraints { self: LogicalPlan => */ protected def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = constraint match { - // When the root is IsNotNull, we can push IsNotNull thro0ugh the child null intolerant + // When the root is IsNotNull, we can push IsNotNull through the child null intolerant // expressions case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) // Constraints always return true for all the inputs. That means, null will never be returned. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a3ae09e473eab..e135e8ee04089 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -342,11 +342,9 @@ case class Join( case _ => resolvedExceptNatural } - /** - * Returns additional constraints which are not enforced on the result of join operations, but - * which can be enforced either on the left or the right side - */ override protected def constructAllConstraints: Set[Expression] = { + // additional constraints which are not enforced on the result of join operations, but can be + // enforced either on the left or the right side val additionalConstraints = joinType match { case LeftAnti | LeftOuter if condition.isDefined => splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter( From 5cadd86ec4fae40c8d2606f0c00aed99a96d0027 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 23 Mar 2018 15:23:33 +0100 Subject: [PATCH 4/4] added more tests, refactored existing ones, made method private again --- .../plans/logical/QueryPlanConstraints.scala | 2 +- .../InferFiltersFromConstraintsSuite.scala | 45 +++++++------------ .../plans/ConstraintPropagationSuite.scala | 39 ++++++++++++---- 3 files changed, 49 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 857b4d4060618..219231dc792cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -97,7 +97,7 @@ trait QueryPlanConstraints { self: LogicalPlan => * Recursively explores the expressions which are null intolerant and returns all attributes * in these expressions. */ - protected def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { + private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { case a: Attribute => Seq(a) case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 4ad6ad02108c7..04d5af603ef1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -40,6 +40,19 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private def testConstraintsAfterJoin( + x: LogicalPlan, + y: LogicalPlan, + expectedLeft: LogicalPlan, + expectedRight: LogicalPlan, + joinType: JoinType) = { + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, joinType, condition).analyze + val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + test("filter: filter out constraints in condition") { val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze val correctAnswer = testRelation @@ -196,48 +209,24 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-23405: left-semi equal-join should filter out null join 
keys on both sides") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftSemi, condition).analyze - val left = x.where(IsNotNull('a)) - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftSemi, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi) } test("SPARK-23564: left anti join should filter out null join keys on right side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftAnti, condition).analyze - val left = x - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftAnti, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti) } test("SPARK-23564: left outer join should filter out null join keys on right side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftOuter, condition).analyze - val left = x - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftOuter, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter) } test("SPARK-23564: right outer join should filter out null join keys on left side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, RightOuter, condition).analyze - val left = x.where(IsNotNull('a)) - val right = y - val correctAnswer = left.join(right, RightOuter, 
condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a37e06d922642..b19f5a7fde4ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -237,23 +237,46 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { test("propagating constraints in left-outer join") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) - verifyConstraints(tr1 + val plan = tr1 .where('a.attr > 10) .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) - .analyze.constraints, - ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, - IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))) + .analyze + val expectedConstraints = ExpressionSet(Seq( + tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + verifyConstraints(plan.constraints, expectedConstraints) + verifyConstraints(plan.allConstraints, expectedConstraints + + IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get)) } test("propagating constraints in right-outer join") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) - verifyConstraints(tr1 + val plan = tr1 .where('a.attr > 10) .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) - .analyze.constraints, - 
ExpressionSet(Seq(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, - IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))) + .analyze + val expectedConstraints = ExpressionSet(Seq( + tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + verifyConstraints(plan.constraints, expectedConstraints) + verifyConstraints(plan.allConstraints, expectedConstraints + + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)) + } + + test("propagating constraints in left-anti join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + val plan = tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftAnti, Some("tr1.a".attr === "tr2.a".attr)) + .analyze + val expectedConstraints = ExpressionSet(Seq( + tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + verifyConstraints(plan.constraints, expectedConstraints) + verifyConstraints(plan.allConstraints, expectedConstraints + + IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get)) } test("propagating constraints in full-outer join") {