diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b43b7ee71e7aa..05f5bdbfc0769 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} @@ -26,6 +26,56 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def output: Seq[Attribute] + /** + * Extracts the relevant constraints from a given set of constraints based on the attributes that + * appear in the [[outputSet]]. + */ + protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { + constraints + .union(constructIsNotNullConstraints(constraints)) + .filter(constraint => + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + } + + /** + * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions. + * For e.g., if an expression is of the form (`a > 5`), this returns a constraint of the form + * `isNotNull(a)` + */ + private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + // Currently we only propagate constraints if the condition consists of equality + // and ranges. For all other cases, we return an empty set of constraints + constraints.map { + case EqualTo(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case _ => + Set.empty[Expression] + }.foldLeft(Set.empty[Expression])(_ union _.toSet) + } + + /** + * A sequence of expressions that describes the data property of the output rows of this + * operator. For example, if the output of this operator is column `a`, an example `constraints` + * can be `Set(a > 10, a < 20)`. + */ + lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints) + + /** + * This method can be overridden by any child class of QueryPlan to specify a set of constraints + * based on the given operator's constraint propagation logic. These constraints are then + * canonicalized and filtered automatically to contain only those attributes that appear in the + * [[outputSet]] + */ + protected def validConstraints: Set[Expression] = Set.empty + /** * Returns the set of attributes that are output by this node. */ @@ -59,6 +109,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy * Runs [[transform]] with `rule` on all expressions present in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, * transformExpressionsDown or transformExpressionsUp should be used. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { @@ -67,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { @@ -99,6 +151,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. * @return */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6d859551f8c52..d8944a424156e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -305,6 +305,8 @@ abstract class UnaryNode extends LogicalPlan { def child: LogicalPlan override def children: Seq[LogicalPlan] = child :: Nil + + override protected def validConstraints: Set[Expression] = child.constraints } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 16f4b355b1b6c..8150ff8434762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -87,11 +87,27 @@ case class Generate( } } -case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { +case class Filter(condition: Expression, child: LogicalPlan) + extends UnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output + + override protected def validConstraints: Set[Expression] = { + child.constraints.union(splitConjunctivePredicates(condition).toSet) + } } -abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + + protected def leftConstraints: Set[Expression] = left.constraints + + protected def rightConstraints: Set[Expression] = { + require(left.output.size == right.output.size) + val attributeRewrites = AttributeMap(right.output.zip(left.output)) + right.constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } +} private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) @@ -106,6 +122,10 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } + override protected def validConstraints: Set[Expression] = { + leftConstraints.union(rightConstraints) + } + // Intersect are only resolved if they don't introduce ambiguous expression ids, // since the Optimizer will convert Intersect to Join. override lazy val resolved: Boolean = @@ -119,6 +139,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output + override protected def validConstraints: Set[Expression] = leftConstraints + override lazy val resolved: Boolean = childrenResolved && left.output.length == right.output.length && @@ -157,13 +179,36 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { val sizeInBytes = children.map(_.statistics.sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } + + /** + * Maps the constraints containing a given (original) sequence of attributes to those with a + * given (reference) sequence of attributes. Given the nature of union, we expect that the + * mapping between the original and reference sequences are symmetric. + */ + private def rewriteConstraints( + reference: Seq[Attribute], + original: Seq[Attribute], + constraints: Set[Expression]): Set[Expression] = { + require(reference.size == original.size) + val attributeRewrites = AttributeMap(original.zip(reference)) + constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } + + override protected def validConstraints: Set[Expression] = { + children + .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) + .reduce(_ intersect _) + } } case class Join( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) + extends BinaryNode with PredicateHelper { override def output: Seq[Attribute] = { joinType match { @@ -180,6 +225,28 @@ case class Join( } } + override protected def validConstraints: Set[Expression] = { + joinType match { + case Inner if condition.isDefined => + left.constraints + .union(right.constraints) + .union(splitConjunctivePredicates(condition.get).toSet) + case LeftSemi if condition.isDefined => + left.constraints + .union(splitConjunctivePredicates(condition.get).toSet) + case Inner => + left.constraints.union(right.constraints) + case LeftSemi => + left.constraints + case LeftOuter => + left.constraints + case RightOuter => + right.constraints + case FullOuter => + Set.empty[Expression] + } + } + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala new file mode 100644 index 0000000000000..b5cf91394d910 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ + +class ConstraintPropagationSuite extends SparkFunSuite { + + private def resolveColumn(tr: LocalRelation, columnName: String): Expression = + tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get + + private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = { + val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _)) + val extra = found.filterNot(i => expected.map(_.semanticEquals(i)).reduce(_ || _)) + if (missing.nonEmpty || extra.nonEmpty) { + fail( + s""" + |== FAIL: Constraints do not match === + |Found: ${found.mkString(",")} + |Expected: ${expected.mkString(",")} + |== Result == + |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")} + |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")} + """.stripMargin) + } + } + + test("propagating constraints in filters") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) + + verifyConstraints(tr + .where('a.attr > 10) + .analyze.constraints, + Set(resolveColumn(tr, "a") > 10, + IsNotNull(resolveColumn(tr, "a")))) + + verifyConstraints(tr + .where('a.attr > 10) + .select('c.attr, 'a.attr) + .where('c.attr < 100) + .analyze.constraints, + Set(resolveColumn(tr, "a") > 10, + resolveColumn(tr, "c") < 100, + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c")))) + } + + test("propagating constraints in union") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('d.int, 'e.int, 'f.int) + val tr3 = LocalRelation('g.int, 'h.int, 'i.int) + + assert(tr1 + .where('a.attr > 10) + .unionAll(tr2.where('e.attr > 10) + .unionAll(tr3.where('i.attr > 10))) + .analyze.constraints.isEmpty) + + verifyConstraints(tr1 + .where('a.attr > 10) + .unionAll(tr2.where('d.attr > 10) + .unionAll(tr3.where('g.attr > 10))) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a")))) + } + + test("propagating constraints in intersect") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + + verifyConstraints(tr1 + .where('a.attr > 10) + .intersect(tr2.where('b.attr < 100)) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + resolveColumn(tr1, "b") < 100, + IsNotNull(resolveColumn(tr1, "a")), + IsNotNull(resolveColumn(tr1, "b")))) + } + + test("propagating constraints in except") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + verifyConstraints(tr1 + .where('a.attr > 10) + .except(tr2.where('b.attr < 100)) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a")))) + } + + test("propagating constraints in inner join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + tr1.resolveQuoted("a", caseInsensitiveResolution).get === + tr2.resolveQuoted("a", caseInsensitiveResolution).get, + IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + } + + test("propagating constraints in left-semi join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + } + + test("propagating constraints in left-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + } + + test("propagating constraints in right-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + } + + test("propagating constraints in full-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + assert(tr1.where('a.attr > 10) + .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints.isEmpty) + } +}