Skip to content

Commit

Permalink
move constructIsNotNullConstraints in QueryPlan
Browse files Browse the repository at this point in the history
  • Loading branch information
sameeragarwal committed Feb 2, 2016
1 parent 302444f commit b52742a
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,30 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]]
* Extracts the relevant constraints from a given set of constraints based on the attributes that
* appear in the [[outputSet]].
*/
private def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
constraints.filter(_.references.subsetOf(outputSet))
protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
constraints
.union(constructIsNotNullConstraints(constraints))
.filter(constraint =>
constraint.references.nonEmpty && constraint.references.subsetOf(outputSet))
}

private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
// Currently we only propagate constraints if the condition consists of equality
// and ranges. For all other cases, we return an empty set of constraints
constraints.map {
case EqualTo(l, r) =>
Set(IsNotNull(l), IsNotNull(r))
case GreaterThan(l, r) =>
Set(IsNotNull(l), IsNotNull(r))
case GreaterThanOrEqual(l, r) =>
Set(IsNotNull(l), IsNotNull(r))
case LessThan(l, r) =>
Set(IsNotNull(l), IsNotNull(r))
case LessThanOrEqual(l, r) =>
Set(IsNotNull(l), IsNotNull(r))
case _ =>
Set.empty[Expression]
}.foldLeft(Set.empty[Expression])(_ union _.toSet)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,7 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output

override protected def validConstraints: Set[Expression] = {
val newConstraint = splitConjunctivePredicates(condition)
.filter(_.references.subsetOf(outputSet))
.toSet
newConstraint.union(child.constraints)
child.constraints.union(splitConjunctivePredicates(condition).toSet)
}
}

Expand Down Expand Up @@ -221,35 +218,20 @@ case class Join(
}
}

private def constructIsNotNullConstraints(condition: Expression): Set[Expression] = {
// Currently we only propagate constraints if the condition consists of equality
// and ranges. For all other cases, we return an empty set of constraints
splitConjunctivePredicates(condition).map {
case EqualTo(l, r) =>
Set(IsNotNull(l), IsNotNull(r))
case GreaterThan(l, r) =>
Set(IsNotNull(l), IsNotNull(r))
case GreaterThanOrEqual(l, r) =>
Set(IsNotNull(l), IsNotNull(r))
case LessThan(l, r) =>
Set(IsNotNull(l), IsNotNull(r))
case LessThanOrEqual(l, r) =>
Set(IsNotNull(l), IsNotNull(r))
case _ =>
Set.empty[Expression]
}.foldLeft(Set.empty[Expression])(_ union _.toSet)
}

override protected def validConstraints: Set[Expression] = {
joinType match {
case Inner if condition.isDefined =>
left.constraints
.union(right.constraints)
.union(constructIsNotNullConstraints(condition.get))
.union(splitConjunctivePredicates(condition.get).toSet)
case LeftSemi if condition.isDefined =>
left.constraints
.union(right.constraints)
.union(constructIsNotNullConstraints(condition.get))
.union(splitConjunctivePredicates(condition.get).toSet)
case Inner =>
left.constraints.union(right.constraints)
case LeftSemi =>
left.constraints.union(right.constraints)
case LeftOuter =>
left.constraints
case RightOuter =>
Expand All @@ -259,8 +241,6 @@ case class Join(
}
}

def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty

def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty

// Joins are only resolved if they don't introduce ambiguous expression ids.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,49 +29,78 @@ class ConstraintPropagationSuite extends SparkFunSuite {
private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get

private def verifyConstraints(a: Set[Expression], b: Set[Expression]): Unit = {
assert(a.forall(i => b.map(_.semanticEquals(i)).reduce(_ || _)))
assert(b.forall(i => a.map(_.semanticEquals(i)).reduce(_ || _)))
private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = {
val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _))
val extra = found.filterNot(i => expected.map(_.semanticEquals(i)).reduce(_ || _))
if (missing.nonEmpty || extra.nonEmpty) {
fail(
s"""
|== FAIL: Constraints do not match ===
|Found: ${found.mkString(",")}
|Expected: ${expected.mkString(",")}
|== Result ==
|Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")}
|Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")}
""".stripMargin)
}
}

test("propagating constraints in filters") {
val tr = LocalRelation('a.int, 'b.string, 'c.int)

assert(tr.analyze.constraints.isEmpty)

assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)
verifyConstraints(tr.where('a.attr > 10).analyze.constraints, Set(resolveColumn(tr, "a") > 10))

verifyConstraints(tr
.where('a.attr > 10)
.analyze.constraints,
Set(resolveColumn(tr, "a") > 10,
IsNotNull(resolveColumn(tr, "a"))))

verifyConstraints(tr
.where('a.attr > 10)
.select('c.attr, 'a.attr)
.where('c.attr < 100)
.analyze.constraints,
Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100))
Set(resolveColumn(tr, "a") > 10,
resolveColumn(tr, "c") < 100,
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "c"))))
}

test("propagating constraints in union") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
val tr2 = LocalRelation('d.int, 'e.int, 'f.int)
val tr3 = LocalRelation('g.int, 'h.int, 'i.int)

assert(tr1
.where('a.attr > 10)
.unionAll(tr2.where('e.attr > 10)
.unionAll(tr3.where('i.attr > 10)))
.analyze.constraints.isEmpty)

verifyConstraints(tr1
.where('a.attr > 10)
.unionAll(tr2.where('d.attr > 10)
.unionAll(tr3.where('g.attr > 10)))
.analyze.constraints,
Set(resolveColumn(tr1, "a") > 10))
Set(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a"))))
}

test("propagating constraints in intersect") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
val tr2 = LocalRelation('a.int, 'b.int, 'c.int)

verifyConstraints(tr1
.where('a.attr > 10)
.intersect(tr2.where('b.attr < 100))
.analyze.constraints,
Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100))
Set(resolveColumn(tr1, "a") > 10,
resolveColumn(tr1, "b") < 100,
IsNotNull(resolveColumn(tr1, "a")),
IsNotNull(resolveColumn(tr1, "b"))))
}

test("propagating constraints in except") {
Expand All @@ -81,7 +110,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.except(tr2.where('b.attr < 100))
.analyze.constraints,
Set(resolveColumn(tr1, "a") > 10))
Set(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a"))))
}

test("propagating constraints in inner join") {
Expand All @@ -93,8 +123,11 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.analyze.constraints,
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
tr1.resolveQuoted("a", caseInsensitiveResolution).get ===
tr2.resolveQuoted("a", caseInsensitiveResolution).get,
IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
}

test("propagating constraints in left-semi join") {
Expand All @@ -115,7 +148,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10))
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
}

test("propagating constraints in right-outer join") {
Expand All @@ -125,7 +159,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100))
Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
}

test("propagating constraints in full-outer join") {
Expand Down

0 comments on commit b52742a

Please sign in to comment.