diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index 046848875548b..219231dc792cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -26,15 +26,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
    * An [[ExpressionSet]] that contains an additional set of constraints, such as equality
    * constraints and `isNotNull` constraints, etc.
    */
-  lazy val allConstraints: ExpressionSet = {
-    if (conf.constraintPropagationEnabled) {
-      ExpressionSet(validConstraints
-        .union(inferAdditionalConstraints(validConstraints))
-        .union(constructIsNotNullConstraints(validConstraints)))
-    } else {
-      ExpressionSet(Set.empty)
-    }
-  }
+  lazy val allConstraints: ExpressionSet = ExpressionSet(constructAllConstraints)
 
   /**
    * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
@@ -55,6 +47,20 @@ trait QueryPlanConstraints { self: LogicalPlan =>
    */
   protected def validConstraints: Set[Expression] = Set.empty
 
+  /**
+   * Returns the [[Expression]]s representing all the constraints which can be enforced on the
+   * current operator.
+   */
+  protected def constructAllConstraints: Set[Expression] = {
+    if (conf.constraintPropagationEnabled) {
+      validConstraints
+        .union(inferAdditionalConstraints(validConstraints))
+        .union(constructIsNotNullConstraints(validConstraints))
+    } else {
+      Set.empty
+    }
+  }
+
   /**
    * Infers a set of `isNotNull` constraints from null intolerant expressions as well as
    * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
@@ -76,7 +82,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
    * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions
    * of constraints.
    */
-  private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
+  protected def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
     constraint match {
       // When the root is IsNotNull, we can push IsNotNull through the child null intolerant
       // expressions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index a4fca790dd086..e135e8ee04089 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -341,6 +341,41 @@ case class Join(
     case UsingJoin(_, _) => false
     case _ => resolvedExceptNatural
   }
+
+  /**
+   * Returns the constraints enforced on the result of the join plus the additional `IsNotNull`
+   * constraints inferred from the join condition which are not enforced on the result, but can
+   * be enforced on the right side (left anti and left outer joins) or on the left side (right
+   * outer joins).
+   *
+   * Nothing is added when constraint propagation is disabled, so that `allConstraints` stays
+   * empty in that case, consistently with the default implementation.
+   */
+  override protected def constructAllConstraints: Set[Expression] = {
+    val inheritedConstraints = super.constructAllConstraints
+    // the side of the join on which the additional constraints can be enforced, if any
+    val constrainedSide = joinType match {
+      case LeftAnti | LeftOuter => Some(right)
+      case RightOuter => Some(left)
+      case _ => None
+    }
+    (constrainedSide, condition) match {
+      case (Some(side), Some(joinCondition)) if conf.constraintPropagationEnabled =>
+        inheritedConstraints ++ splitConjunctivePredicates(joinCondition)
+          .flatMap(inferIsNotNullConstraints)
+          .filter(_.references.subsetOf(side.outputSet))
+      case _ => inheritedConstraints
+    }
+  }
+
+  /**
+   * The constraints on the rows output by the join. They are computed from the inherited
+   * constraints only, so the additional one-sided constraints above are never included.
+   */
+  override lazy val constraints: ExpressionSet = ExpressionSet(
+    super.constructAllConstraints.filter { c =>
+      c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
+    })
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index f78c2356e35a5..04d5af603ef1e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -40,6 +40,19 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
 
+  private def testConstraintsAfterJoin(
+      x: LogicalPlan,
+      y: LogicalPlan,
+      expectedLeft: LogicalPlan,
+      expectedRight: LogicalPlan,
+      joinType: JoinType) = {
+    val condition = Some("x.a".attr === "y.a".attr)
+    val originalQuery = x.join(y, joinType, condition).analyze
+    val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("filter: filter out constraints in condition") {
     val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
     val correctAnswer = testRelation
@@ -196,12 +209,24 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
   test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
     val x = testRelation.subquery('x)
     val y = testRelation.subquery('y)
-    val condition = Some("x.a".attr === "y.a".attr)
-    val originalQuery = x.join(y, LeftSemi, condition).analyze
-    val left = x.where(IsNotNull('a))
-    val right = y.where(IsNotNull('a))
-    val correctAnswer = left.join(right, LeftSemi, condition).analyze
-    val optimized = Optimize.execute(originalQuery)
-    comparePlans(optimized, correctAnswer)
+    testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi)
+  }
+
+  test("SPARK-23564: left anti join should filter out null join keys on right side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti)
+  }
+
+  test("SPARK-23564: left outer join should filter out null join keys on right side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter)
+  }
+
+  test("SPARK-23564: right outer join should filter out null join keys on left side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index a37e06d922642..b19f5a7fde4ac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -237,23 +237,46 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
   test("propagating constraints in left-outer join") {
     val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
     val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
-    verifyConstraints(tr1
+    val plan = tr1
       .where('a.attr > 10)
       .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
-      .analyze.constraints,
-      ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
-        IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))))
+      .analyze
+    val expectedConstraints = ExpressionSet(Seq(
+      tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
+      IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
+    verifyConstraints(plan.constraints, expectedConstraints)
+    verifyConstraints(plan.allConstraints, expectedConstraints +
+      IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get))
   }
 
   test("propagating constraints in right-outer join") {
     val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
     val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
-    verifyConstraints(tr1
+    val plan = tr1
       .where('a.attr > 10)
       .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
-      .analyze.constraints,
-      ExpressionSet(Seq(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
-        IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))))
+      .analyze
+    val expectedConstraints = ExpressionSet(Seq(
+      tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
+      IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
+    verifyConstraints(plan.constraints, expectedConstraints)
+    verifyConstraints(plan.allConstraints, expectedConstraints +
+      IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))
+  }
+
+  test("propagating constraints in left-anti join") {
+    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+    val plan = tr1
+      .where('a.attr > 10)
+      .join(tr2.where('d.attr < 100), LeftAnti, Some("tr1.a".attr === "tr2.a".attr))
+      .analyze
+    val expectedConstraints = ExpressionSet(Seq(
+      tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
+      IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
+    verifyConstraints(plan.constraints, expectedConstraints)
+    verifyConstraints(plan.allConstraints, expectedConstraints +
+      IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get))
   }
 
   test("propagating constraints in full-outer join") {