Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sameeragarwal committed Jan 20, 2016
1 parent 67e138d commit 04ff99a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]]
/**
* Extracts the output property from a given child.
*/
def extractConstraintsFromChild(child: QueryPlan[PlanType]): Seq[Expression] = {
def extractConstraintsFromChild(child: QueryPlan[PlanType]): Set[Expression] = {
child.constraints.filter(_.references.subsetOf(outputSet))
}

/**
* An sequence of expressions that describes the data property of the output rows of this
* operator. For example, if the output of this operator is column `a`, an example `constraints`
* can be `Seq(a > 10, a < 20)`.
* can be `Set(a > 10, a < 20)`.
*/
def constraints: Seq[Expression] = Nil
def constraints: Set[Expression] = Set.empty

/**
* Returns the set of attributes that are output by this node.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ abstract class UnaryNode extends LogicalPlan with PredicateHelper {

override def children: Seq[LogicalPlan] = child :: Nil

override def constraints: Seq[Expression] = {
override def constraints: Set[Expression] = {
extractConstraintsFromChild(child)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ case class Generate(
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output

override def constraints: Seq[Expression] = {
override def constraints: Set[Expression] = {
val newConstraint = splitConjunctivePredicates(condition).filter(
_.references.subsetOf(outputSet))
_.references.subsetOf(outputSet)).toSet
newConstraint.union(extractConstraintsFromChild(child))
}
}
Expand All @@ -103,9 +103,9 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar
leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
}

protected def leftConstraints: Seq[Expression] = extractConstraintsFromChild(left)
protected def leftConstraints: Set[Expression] = extractConstraintsFromChild(left)

protected def rightConstraints: Seq[Expression] = {
protected def rightConstraints: Set[Expression] = {
require(left.output.size == right.output.size)
val attributeRewrites = AttributeMap(left.output.zip(right.output))
extractConstraintsFromChild(right).map(_ transform {
Expand Down Expand Up @@ -135,7 +135,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef
Statistics(sizeInBytes = sizeInBytes)
}

override def constraints: Seq[Expression] = {
override def constraints: Set[Expression] = {
leftConstraints.intersect(rightConstraints)
}
}
Expand All @@ -147,7 +147,7 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}

override def constraints: Seq[Expression] = {
override def constraints: Set[Expression] = {
leftConstraints.union(rightConstraints)
}
}
Expand All @@ -156,7 +156,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
/** We don't use right.output because those rows get excluded from the set. */
override def output: Seq[Attribute] = left.output

override def constraints: Seq[Expression] = leftConstraints
override def constraints: Set[Expression] = leftConstraints
}

case class Join(
Expand All @@ -180,7 +180,7 @@ case class Join(
}
}

override def constraints: Seq[Expression] = {
override def constraints: Set[Expression] = {
joinType match {
case LeftSemi =>
extractConstraintsFromChild(left)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._

/**
* This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly
Expand Down Expand Up @@ -75,27 +74,14 @@ class LogicalPlanSuite extends SparkFunSuite {
}

test("propagating constraint in filter") {

def resolve(plan: LogicalPlan, constraints: Seq[String]): Seq[Expression] = {
Seq(plan.resolve(constraints.map(_.toString), caseInsensitiveResolution).get)
}

val tr = LocalRelation('a.int, 'b.string, 'c.int)
def resolveColumn(columnName: String): Expression =
tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
assert(tr.analyze.constraints.isEmpty)
assert(tr.select('a.attr).analyze.constraints.isEmpty)
assert(tr.where('a.attr > 10).analyze.constraints.zip(Seq('a.attr > 10))
.forall(e => e._1.semanticEquals(e._2)))
/*
assert(tr.where('a.attr > 10).analyze.constraints == resolve(tr.where('a.attr > 10).analyze,
Seq("a > 10")))
*/
/*
assert(logicalPlan.constraints ==
Seq(logicalPlan.resolve(Seq('a > 10), caseInsensitiveResolution))
assert(tr.where('a.attr > 10).select('c.attr).analyze.constraints.get == ('a > 10))
assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100)
.analyze.constraints.get == And('a > 10, 'c < 100))
assert(tr.where('a.attr > 10).analyze.constraints == Set(resolveColumn("a") > 10))
assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)
*/
assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100)
.analyze.constraints == Set(resolveColumn("a") > 10, resolveColumn("c") < 100))
}
}

0 comments on commit 04ff99a

Please sign in to comment.