Skip to content

Commit

Permalink
Fixes test failure, adds more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
liancheng committed Dec 24, 2014
1 parent 5d54349 commit 4ab3a58
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.types.BooleanType
Expand Down Expand Up @@ -48,6 +47,14 @@ trait PredicateHelper {
}
}

protected def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = {
condition match {
case Or(cond1, cond2) =>
splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2)
case other => other :: Nil
}
}

/**
* Returns true if `expr` can be evaluated using only the output of `plan`. This method
* can be used to determine when is is acceptable to move expression evaluation within a query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ object CombineFilters extends Rule[LogicalPlan] {
}

/**
* Normalizes conjuctions and disjunctions to eliminate common factors.
* Normalizes conjunctions and disjunctions to eliminate common factors.
*/
object NormalizeFilters extends Rule[LogicalPlan] with PredicateHelper {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
Expand All @@ -358,17 +358,23 @@ object NormalizeFilters extends Rule[LogicalPlan] with PredicateHelper {
}

def normalizedPredicate(predicate: Expression): Seq[Expression] = predicate match {
// a || a => a
case Or(lhs, rhs) if lhs fastEquals rhs => lhs :: Nil
// a && a => a
case And(lhs, rhs) if lhs fastEquals rhs => lhs :: Nil
// a && a && a ... => a
case p @ And(e, _) if splitConjunctivePredicates(p).distinct.size == 1 => e :: Nil

// a || a || a ... => a
case p @ Or(e, _) if splitDisjunctivePredicates(p).distinct.size == 1 => e :: Nil

// (a && b && c && ...) || (a && b && d && ...) => a && b && (c || d || ...)
case Or(lhs, rhs) =>
val lhsSet = splitConjunctivePredicates(lhs).toSet
val rhsSet = splitConjunctivePredicates(rhs).toSet
val commonPredicates = lhsSet & rhsSet
val otherPredicates = (lhsSet | rhsSet) &~ commonPredicates
otherPredicates.reduceOption(Or).getOrElse(Literal(true)) :: commonPredicates.toList
val common = lhsSet.intersect(rhsSet)

(lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And))
.reduceOption(Or)
.map(_ :: common.toList)
.getOrElse(common.toList)

case _ => predicate :: Nil
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ class NormalizeFiltersSuite extends PlanTest {

test("a && a => a") {
checkExpression('a === 1 && 'a === 1, 'a === 1)
checkExpression('a === 1 && 'a === 1 && 'a === 1, 'a === 1)
}

test("a || a => a") {
checkExpression('a === 1 || 'a === 1, 'a === 1)
checkExpression('a === 1 || 'a === 1 || 'a === 1, 'a === 1)
}

test("(a && b) || (a && c)") {
test("(a && b) || (a && c) => a && (b || c)") {
checkExpression(
('a === 1 && 'a < 10) || ('a > 2 && 'a === 1),
('a === 1) && ('a < 10 || 'a > 2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,20 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be

test(query) {
val schemaRdd = sql(query)
assertResult(expectedQueryResult.toArray, "Wrong query result") {
val queryExecution = schemaRdd.queryExecution

assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
schemaRdd.collect().map(_.head).toArray
}

val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect {
case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
}.head

assert(readBatches === expectedReadBatches, "Wrong number of read batches")
assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions")
assert(readBatches === expectedReadBatches, s"Wrong number of read batches: $queryExecution")
assert(
readPartitions === expectedReadPartitions,
s"Wrong number of read partitions: $queryExecution")
}
}
}

0 comments on commit 4ab3a58

Please sign in to comment.