diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index e50a9ae043a40..546e635700739 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -31,8 +31,30 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]]
    * Extracts the relevant constraints from a given set of constraints based on the attributes that
    * appear in the [[outputSet]].
    */
-  private def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
-    constraints.filter(_.references.subsetOf(outputSet))
+  protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
+    constraints
+      .union(constructIsNotNullConstraints(constraints))
+      .filter(constraint =>
+        constraint.references.nonEmpty && constraint.references.subsetOf(outputSet))
+  }
+
+  private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
+    // Currently we only propagate constraints if the condition consists of equality
+    // and range comparisons. For all other cases, we return an empty set of constraints.
+    constraints.map {
+      case EqualTo(l, r) =>
+        Set(IsNotNull(l), IsNotNull(r))
+      case GreaterThan(l, r) =>
+        Set(IsNotNull(l), IsNotNull(r))
+      case GreaterThanOrEqual(l, r) =>
+        Set(IsNotNull(l), IsNotNull(r))
+      case LessThan(l, r) =>
+        Set(IsNotNull(l), IsNotNull(r))
+      case LessThanOrEqual(l, r) =>
+        Set(IsNotNull(l), IsNotNull(r))
+      case _ =>
+        Set.empty[Expression]
+    }.foldLeft(Set.empty[Expression])(_ union _.toSet)
   }
 
   /**
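The new `constructIsNotNullConstraints` encodes the rule that null-intolerant comparisons imply `IsNotNull` on both operands: any row that satisfies `a = b`, `a < b`, and so on must have non-null `a` and `b`. A minimal standalone sketch of the idea, using simplified stand-in types (`Expr`, `Cmp`, `NotNull` are hypothetical, not Catalyst's `Expression` hierarchy):

```scala
object NotNullSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(value: Int) extends Expr
  // One stand-in node for =, <, <=, >, >=; Catalyst uses a separate class per operator.
  case class Cmp(op: String, left: Expr, right: Expr) extends Expr
  case class NotNull(child: Expr) extends Expr

  // Each comparison constraint implies that both of its operands are non-null.
  def notNullConstraints(constraints: Set[Expr]): Set[Expr] =
    constraints.flatMap {
      case Cmp(_, l, r) => Set[Expr](NotNull(l), NotNull(r))
      case _ => Set.empty[Expr]
    }

  def main(args: Array[String]): Unit = {
    // 'a > 10' yields not-null facts for both operands.
    val derived = notNullConstraints(Set[Expr](Cmp(">", Attr("a"), Lit(10))))
    assert(derived == Set[Expr](NotNull(Attr("a")), NotNull(Lit(10))))
  }
}
```

One reason the real method matches each operator explicitly rather than all of `BinaryComparison`: the null-safe equality `<=>` also belongs to that family but can be satisfied by null operands, so it must not generate `IsNotNull` constraints.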
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 65c135322c04f..83551325fd5ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -91,10 +91,7 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
   override protected def validConstraints: Set[Expression] = {
-    val newConstraint = splitConjunctivePredicates(condition)
-      .filter(_.references.subsetOf(outputSet))
-      .toSet
-    newConstraint.union(child.constraints)
+    child.constraints.union(splitConjunctivePredicates(condition).toSet)
   }
 }
 
@@ -221,35 +218,20 @@ case class Join(
     }
   }
 
-  private def constructIsNotNullConstraints(condition: Expression): Set[Expression] = {
-    // Currently we only propagate constraints if the condition consists of equality
-    // and ranges. For all other cases, we return an empty set of constraints
-    splitConjunctivePredicates(condition).map {
-      case EqualTo(l, r) =>
-        Set(IsNotNull(l), IsNotNull(r))
-      case GreaterThan(l, r) =>
-        Set(IsNotNull(l), IsNotNull(r))
-      case GreaterThanOrEqual(l, r) =>
-        Set(IsNotNull(l), IsNotNull(r))
-      case LessThan(l, r) =>
-        Set(IsNotNull(l), IsNotNull(r))
-      case LessThanOrEqual(l, r) =>
-        Set(IsNotNull(l), IsNotNull(r))
-      case _ =>
-        Set.empty[Expression]
-    }.foldLeft(Set.empty[Expression])(_ union _.toSet)
-  }
-
   override protected def validConstraints: Set[Expression] = {
     joinType match {
       case Inner if condition.isDefined =>
         left.constraints
           .union(right.constraints)
-          .union(constructIsNotNullConstraints(condition.get))
+          .union(splitConjunctivePredicates(condition.get).toSet)
       case LeftSemi if condition.isDefined =>
         left.constraints
           .union(right.constraints)
-          .union(constructIsNotNullConstraints(condition.get))
+          .union(splitConjunctivePredicates(condition.get).toSet)
+      case Inner =>
+        left.constraints.union(right.constraints)
+      case LeftSemi =>
+        left.constraints.union(right.constraints)
       case LeftOuter =>
         left.constraints
      case RightOuter =>
@@ -259,8 +241,6 @@ case class Join(
     }
   }
 
-  def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
-
   def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
 
   // Joins are only resolved if they don't introduce ambiguous expression ids.
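With `getRelevantConstraints` now doing both the pruning and the `IsNotNull` inference centrally, `Filter` and `Join` simply hand over every top-level conjunct of their condition. A minimal sketch of that conjunct splitting, assuming simplified stand-in types (`Pred`, `AndP`, `Leaf` rather than Catalyst's `Expression`/`And`):

```scala
object SplitSketch {
  sealed trait Pred
  case class AndP(left: Pred, right: Pred) extends Pred
  case class Leaf(repr: String) extends Pred

  // Flattens nested ANDs: (a AND (b AND c)) becomes Seq(a, b, c), so each
  // conjunct can be recorded as an independent constraint.
  def splitConjuncts(p: Pred): Seq[Pred] = p match {
    case AndP(l, r) => splitConjuncts(l) ++ splitConjuncts(r)
    case other => Seq(other)
  }

  def main(args: Array[String]): Unit = {
    val conjuncts = splitConjuncts(AndP(Leaf("a > 10"), AndP(Leaf("b = 5"), Leaf("c < 1"))))
    assert(conjuncts == Seq(Leaf("a > 10"), Leaf("b = 5"), Leaf("c < 1")))
  }
}
```

The design consequence is visible in both operators above: `Filter` no longer filters predicates against its own `outputSet`, and `Join` no longer derives `IsNotNull` itself; each just unions the raw conjuncts into `validConstraints` and relies on the shared `getRelevantConstraints` to discard anything referencing attributes outside the output.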
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 31995c3c8ad08..b5cf91394d910 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -29,49 +29,78 @@ class ConstraintPropagationSuite extends SparkFunSuite {
   private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
     tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
 
-  private def verifyConstraints(a: Set[Expression], b: Set[Expression]): Unit = {
-    assert(a.forall(i => b.map(_.semanticEquals(i)).reduce(_ || _)))
-    assert(b.forall(i => a.map(_.semanticEquals(i)).reduce(_ || _)))
+  private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = {
+    val missing = expected.filterNot(i => found.exists(_.semanticEquals(i)))
+    val extra = found.filterNot(i => expected.exists(_.semanticEquals(i)))
+    if (missing.nonEmpty || extra.nonEmpty) {
+      fail(
+        s"""
+           |== FAIL: Constraints do not match ==
+           |Found: ${found.mkString(",")}
+           |Expected: ${expected.mkString(",")}
+           |== Result ==
+           |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")}
+           |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")}
+         """.stripMargin)
+    }
   }
 
   test("propagating constraints in filters") {
     val tr = LocalRelation('a.int, 'b.string, 'c.int)
+
+    assert(tr.analyze.constraints.isEmpty)
+
+    assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)
-    verifyConstraints(tr.where('a.attr > 10).analyze.constraints, Set(resolveColumn(tr, "a") > 10))
+
+    verifyConstraints(tr
+      .where('a.attr > 10)
+      .analyze.constraints,
+      Set(resolveColumn(tr, "a") > 10,
+        IsNotNull(resolveColumn(tr, "a"))))
+
     verifyConstraints(tr
       .where('a.attr > 10)
       .select('c.attr, 'a.attr)
       .where('c.attr < 100)
       .analyze.constraints,
-      Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100))
+      Set(resolveColumn(tr, "a") > 10,
+        resolveColumn(tr, "c") < 100,
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "c"))))
   }
 
   test("propagating constraints in union") {
     val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
     val tr2 = LocalRelation('d.int, 'e.int, 'f.int)
     val tr3 = LocalRelation('g.int, 'h.int, 'i.int)
+
     assert(tr1
       .where('a.attr > 10)
       .unionAll(tr2.where('e.attr > 10)
       .unionAll(tr3.where('i.attr > 10)))
       .analyze.constraints.isEmpty)
+
     verifyConstraints(tr1
       .where('a.attr > 10)
       .unionAll(tr2.where('d.attr > 10)
       .unionAll(tr3.where('g.attr > 10)))
       .analyze.constraints,
-      Set(resolveColumn(tr1, "a") > 10))
+      Set(resolveColumn(tr1, "a") > 10,
+        IsNotNull(resolveColumn(tr1, "a"))))
   }
 
   test("propagating constraints in intersect") {
     val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
     val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
+
     verifyConstraints(tr1
       .where('a.attr > 10)
       .intersect(tr2.where('b.attr < 100))
       .analyze.constraints,
-      Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100))
+      Set(resolveColumn(tr1, "a") > 10,
+        resolveColumn(tr1, "b") < 100,
+        IsNotNull(resolveColumn(tr1, "a")),
+        IsNotNull(resolveColumn(tr1, "b"))))
   }
 
   test("propagating constraints in except") {
@@ -81,7 +110,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
       .where('a.attr > 10)
       .except(tr2.where('b.attr < 100))
       .analyze.constraints,
-      Set(resolveColumn(tr1, "a") > 10))
+      Set(resolveColumn(tr1, "a") > 10,
+        IsNotNull(resolveColumn(tr1, "a"))))
   }
 
   test("propagating constraints in inner join") {
@@ -93,8 +123,11 @@ class ConstraintPropagationSuite extends SparkFunSuite {
       .analyze.constraints,
       Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
         tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
+        tr1.resolveQuoted("a", caseInsensitiveResolution).get ===
+          tr2.resolveQuoted("a", caseInsensitiveResolution).get,
         IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
-        IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
+        IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
+        IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
   }
 
   test("propagating constraints in left-semi join") {
@@ -115,7 +148,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
       .where('a.attr > 10)
       .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
       .analyze.constraints,
-      Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10))
+      Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
+        IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
   }
 
   test("propagating constraints in right-outer join") {
@@ -125,7 +159,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
       .where('a.attr > 10)
       .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
       .analyze.constraints,
-      Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100))
+      Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
+        IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
   }
 
   test("propagating constraints in full-outer join") {
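The rewritten `verifyConstraints` reports which expected constraints are missing and which found constraints are unexpected, instead of failing with a bare `assert`. Its bookkeeping can be exercised in isolation; the following self-contained sketch (with a hypothetical `semEq` standing in for `Expression.semanticEquals`) shows the same missing/extra computation:

```scala
object VerifySketch {
  // Compares two sets under a caller-supplied semantic equality, returning
  // (missing, extra): expected-but-not-found and found-but-not-expected.
  def diff[A](found: Set[A], expected: Set[A])(semEq: (A, A) => Boolean): (Set[A], Set[A]) = {
    // exists is safe on empty sets, unlike map(...).reduce(_ || _).
    val missing = expected.filterNot(e => found.exists(semEq(_, e)))
    val extra = found.filterNot(f => expected.exists(semEq(_, f)))
    (missing, extra)
  }

  def main(args: Array[String]): Unit = {
    val (missing, extra) = diff(Set("a > 10"), Set("a > 10", "IsNotNull(a)"))(_ == _)
    assert(missing == Set("IsNotNull(a)") && extra.isEmpty)
  }
}
```

Plain `Set` equality would not work in these tests because semantically identical Catalyst expressions can differ structurally (for example in cosmetic details such as qualifiers), which is why the helper compares via `semanticEquals` rather than `==`.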