Skip to content

Commit

Permalink
Refactor And/Or optimization to make it more readable and clean
Browse files Browse the repository at this point in the history
  • Loading branch information
scwf committed Dec 26, 2014
1 parent ac82785 commit e99a26c
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 1 deletion.
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,11 @@
<artifactId>jackson-mapper-asl</artifactId>
<version>1.8.8</version>
</dependency>
<!-- Spire provides the numeric Interval type imported by the Catalyst optimizer. The artifact
     suffix follows the configured Scala binary version instead of a hard-coded 2.10, in line
     with how the Spark module artifacts are named. -->
<dependency>
<groupId>org.spire-math</groupId>
<artifactId>spire_${scala.binary.version}</artifactId>
<version>0.9.0</version>
</dependency>
</dependencies>
</dependencyManagement>

Expand Down
5 changes: 4 additions & 1 deletion sql/catalyst/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
</dependency>

<!-- Spire provides the numeric Interval type imported by the Catalyst optimizer. The artifact
     suffix follows the configured Scala binary version, like spark-core below. -->
<dependency>
<groupId>org.spire-math</groupId>
<artifactId>spire_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ trait PredicateHelper {
}
}

/**
 * Flattens a tree of nested `Or` expressions into the sequence of its operands, in left-to-right
 * order. An expression that is not an `Or` is returned as a single-element sequence.
 */
protected def splitDisjunctivePredicates(condition: Expression): Seq[Expression] =
  condition match {
    case Or(lhs, rhs) => Seq(lhs, rhs).flatMap(splitDisjunctivePredicates)
    case leaf => Seq(leaf)
  }

/**
* Returns true if `expr` can be evaluated using only the output of `plan`. This method
* can be used to determine when it is acceptable to move expression evaluation within a query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import spire.implicits._
import spire.math._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
Expand Down Expand Up @@ -293,6 +296,116 @@ object OptimizeIn extends Rule[LogicalPlan] {
}
}

/**
 * Simplifies `And`/`Or` conditions when they can be optimized.
 *
 * Two kinds of rewrites are applied, top-down, to every `And` and `Or` node:
 *  - comparisons of the same named expression with numeric literals are combined by reasoning
 *    about the value ranges they accept (`a > 3 && a > 5 => a > 5`, `a > 3 || a > 5 => a > 3`);
 *  - predicates shared by both sides are factored out
 *    (`(a && b) || (a && c) => a && (b || c)`, `a || (a && b) => a`).
 *
 * Every rewrite strictly reduces the number of leaf predicates of the condition, so repeated
 * application always reaches a fixed point.
 *
 * Rewrites that fold comparisons into a literal `true`/`false` (`a < 2 || a >= 2 => true`,
 * `a < 2 && a > 2 => false`) are only applied to non-nullable expressions: for a NULL input
 * such conditions evaluate to NULL, not to the literal.
 *
 * NOTE(review): literals are compared as `Double`s, so integral or decimal literals that are not
 * exactly representable as a `Double` may be compared imprecisely.
 */
object ConditionSimplification extends Rule[LogicalPlan] with PredicateHelper {

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case q: LogicalPlan => q transformExpressionsDown {
      case and @ And(left, right) => simplifyAnd(and, left, right)
      case or @ Or(left, right) => simplifyOr(or, left, right)
    }
  }

  /** Applies the first rewrite that simplifies `left && right`, or returns `and` unchanged. */
  private def simplifyAnd(and: And, left: Expression, right: Expression): Expression =
    mergeAnd(left, right)
      .orElse(factor(
        splitDisjunctivePredicates(left), splitDisjunctivePredicates(right), Or, And))
      .orElse(regroupAnd(left, right))
      .orElse(regroupAnd(right, left))
      .orElse(pruneAnd(left, right))
      .orElse(pruneAnd(right, left))
      .getOrElse(and)

  /** Applies the first rewrite that simplifies `left || right`, or returns `or` unchanged. */
  private def simplifyOr(or: Or, left: Expression, right: Expression): Expression =
    mergeOr(left, right)
      .orElse(factor(
        splitConjunctivePredicates(left), splitConjunctivePredicates(right), And, Or))
      .orElse(regroupOr(left, right))
      .orElse(regroupOr(right, left))
      .orElse(pruneOr(left, right))
      .orElse(pruneOr(right, left))
      .getOrElse(or)

  /**
   * Combines `e1 && e2` into a single expression:
   * `a < 2 && a > 2 => false` (non-nullable `a`), `a > 3 && a > 5 => a > 5`, `a && a => a`.
   */
  private def mergeAnd(e1: Expression, e2: Expression): Option[Expression] =
    if (exclusive(e1, e2)) Some(Literal(false))
    else if (implies(e1, e2)) Some(e1)
    else if (implies(e2, e1)) Some(e2)
    else None

  /**
   * Combines `e1 || e2` into a single expression:
   * `a > 3 || a > 5 => a > 3`, `a < 2 || a >= 2 => true` (non-nullable `a`), `a || a => a`.
   */
  private def mergeOr(e1: Expression, e2: Expression): Option[Expression] =
    if (implies(e1, e2)) Some(e2)
    else if (implies(e2, e1)) Some(e1)
    else if (exhaustive(e1, e2)) Some(Literal(true))
    else None

  /**
   * Factors out the operands shared by both sides of a binary condition. For an `Or` the
   * operands are the conjuncts of each side (`outer` is `And`, `inner` is `Or`):
   * `(a && b && c) || (a && b && d) => (a && b) && (c || d)`, and dually for an `And`.
   *
   * When one side consists only of shared operands it absorbs the other one:
   * `(a && b) || a => a`. Returns `None` when the two sides share nothing.
   */
  private def factor(
      leftOperands: Seq[Expression],
      rightOperands: Seq[Expression],
      outer: (Expression, Expression) => Expression,
      inner: (Expression, Expression) => Expression): Option[Expression] = {
    val lhs = leftOperands.distinct
    val rhs = rightOperands.distinct
    val common = lhs.filter(e => rhs.contains(e))
    if (common.isEmpty) {
      None
    } else {
      val shared = common.reduce(outer)
      val leftRest = lhs.filterNot(e => common.contains(e))
      val rightRest = rhs.filterNot(e => common.contains(e))
      if (leftRest.isEmpty || rightRest.isEmpty) {
        Some(shared)
      } else {
        Some(outer(shared, inner(leftRest.reduce(outer), rightRest.reduce(outer))))
      }
    }
  }

  /**
   * `(b > 5 && a < 2) && a > 3 => (a < 2 && a > 3) && b > 5`: regroups only when `other` really
   * merges with one operand of `nested`, so that the rewrite cannot cycle.
   */
  private def regroupAnd(nested: Expression, other: Expression): Option[Expression] =
    nested match {
      case And(e1, e2) =>
        mergeAnd(e1, other).map(And(_, e2)).orElse(mergeAnd(e2, other).map(And(e1, _)))
      case _ => None
    }

  /** `(a < 3 || b > 5) || a > 2 => (a < 3 || a > 2) || b > 5`, see [[regroupAnd]]. */
  private def regroupOr(nested: Expression, other: Expression): Option[Expression] =
    nested match {
      case Or(e1, e2) =>
        mergeOr(e1, other).map(Or(_, e2)).orElse(mergeOr(e2, other).map(Or(e1, _)))
      case _ => None
    }

  /**
   * Simplifies `(x || y) && other` in the cases where distributing `other` pays off:
   *  - `other` implies `x` or `y`, hence the whole disjunction: the result is `other`;
   *  - `x` (or `y`) can never hold together with `other`: that operand is dropped,
   *    e.g. `(a < 2 || b > 5) && a > 3 => b > 5 && a > 3`.
   */
  private def pruneAnd(disjunction: Expression, other: Expression): Option[Expression] =
    disjunction match {
      case Or(x, y) =>
        if (implies(other, x) || implies(other, y)) Some(other)
        else if (exclusive(x, other)) Some(And(y, other))
        else if (exclusive(y, other)) Some(And(x, other))
        else None
      case _ => None
    }

  /**
   * Simplifies `(x && y) || other` in the cases where distributing `other` pays off:
   *  - `x` or `y`, hence the whole conjunction, implies `other`: the result is `other`;
   *  - `x` (or `y`) together with `other` covers every value: that operand is dropped,
   *    e.g. `(a < 3 && b > 5) || a > 2 => b > 5 || a > 2`.
   */
  private def pruneOr(conjunction: Expression, other: Expression): Option[Expression] =
    conjunction match {
      case And(x, y) =>
        if (implies(x, other) || implies(y, other)) Some(other)
        else if (exhaustive(x, other)) Some(Or(y, other))
        else if (exhaustive(y, other)) Some(Or(x, other))
        else None
      case _ => None
    }

  /**
   * True when `e1` and `e2` are equal, or when both compare the same expression and every value
   * accepted by `e1` is accepted by `e2`. For a NULL input both comparisons are NULL, so
   * replacing `e1 && e2` by `e1` (or `e1 || e2` by `e2`) is valid for nullable expressions too.
   */
  private def implies(e1: Expression, e2: Expression): Boolean =
    e1.fastEquals(e2) || ((valueRange(e1), valueRange(e2)) match {
      case (Some(r1), Some(r2)) =>
        r1.attribute == r2.attribute && r1.accepted.isSubsetOf(r2.accepted)
      case _ => false
    })

  /** True when no value of the same non-nullable expression satisfies both comparisons. */
  private def exclusive(e1: Expression, e2: Expression): Boolean =
    (valueRange(e1), valueRange(e2)) match {
      case (Some(r1), Some(r2)) =>
        r1.attribute == r2.attribute && !r1.attribute.nullable &&
          !r1.accepted.intersects(r2.accepted)
      case _ => false
    }

  /**
   * True when every value of the same non-nullable expression satisfies at least one of the
   * two comparisons, i.e. when everything rejected by `e1` is accepted by `e2`.
   */
  private def exhaustive(e1: Expression, e2: Expression): Boolean =
    (valueRange(e1), valueRange(e2)) match {
      case (Some(r1), Some(r2)) =>
        r1.attribute == r2.attribute && !r1.attribute.nullable &&
          r1.rejected.exists(_.isSubsetOf(r2.accepted))
      case _ => false
    }

  private implicit class NumericLiteral(e: Literal) {
    def toDouble = Cast(e, DoubleType).eval().asInstanceOf[Double]
  }

  /** Matches a numeric literal that is neither NULL nor NaN and extracts it as a `Double`. */
  private object NumericValue {
    def unapply(e: Expression): Option[Double] = e match {
      // A NULL literal must not match: unboxing it would silently yield 0.0.
      case l @ Literal(v, _: NumericType) if v != null =>
        val d = l.toDouble
        if (java.lang.Double.isNaN(d)) None else Some(d)
      case _ => None
    }
  }

  /**
   * The values of `attribute` accepted by a comparison with a literal and, when they form a
   * single interval, the values rejected by it.
   */
  private case class ValueRange(
      attribute: NamedExpression,
      accepted: Interval[Double],
      rejected: Option[Interval[Double]])

  /** Range of `n < v` or, when `inclusive`, of `n <= v`. */
  private def upperBounded(n: NamedExpression, v: Double, inclusive: Boolean): ValueRange =
    if (inclusive) ValueRange(n, Interval.atOrBelow(v), Some(Interval.above(v)))
    else ValueRange(n, Interval.below(v), Some(Interval.atOrAbove(v)))

  /** Range of `n > v` or, when `inclusive`, of `n >= v`. */
  private def lowerBounded(n: NamedExpression, v: Double, inclusive: Boolean): ValueRange =
    if (inclusive) ValueRange(n, Interval.atOrAbove(v), Some(Interval.below(v)))
    else ValueRange(n, Interval.above(v), Some(Interval.atOrBelow(v)))

  /** Describes `e` as a [[ValueRange]] when it compares a named expression with a literal. */
  private def valueRange(e: Expression): Option[ValueRange] = e match {
    case LessThan(n: NamedExpression, NumericValue(v)) =>
      Some(upperBounded(n, v, inclusive = false))
    case LessThanOrEqual(n: NamedExpression, NumericValue(v)) =>
      Some(upperBounded(n, v, inclusive = true))
    case GreaterThan(n: NamedExpression, NumericValue(v)) =>
      Some(lowerBounded(n, v, inclusive = false))
    case GreaterThanOrEqual(n: NamedExpression, NumericValue(v)) =>
      Some(lowerBounded(n, v, inclusive = true))

    // Literal first: the comparison is mirrored, its strictness is kept (`v < n` is `n > v`).
    case LessThan(NumericValue(v), n: NamedExpression) =>
      Some(lowerBounded(n, v, inclusive = false))
    case LessThanOrEqual(NumericValue(v), n: NamedExpression) =>
      Some(lowerBounded(n, v, inclusive = true))
    case GreaterThan(NumericValue(v), n: NamedExpression) =>
      Some(upperBounded(n, v, inclusive = false))
    case GreaterThanOrEqual(NumericValue(v), n: NamedExpression) =>
      Some(upperBounded(n, v, inclusive = true))

    // The complement of a single point is not an interval, hence no `rejected` range.
    case EqualTo(n: NamedExpression, NumericValue(v)) =>
      Some(ValueRange(n, Interval.point(v), None))
    case EqualTo(NumericValue(v), n: NamedExpression) =>
      Some(ValueRange(n, Interval.point(v), None))

    case _ => None
  }

  /**
   * Extracts the compared expression and the interval of accepted values from a comparison
   * between a named expression and a numeric literal. Does not match any other expression.
   */
  object NumericLiteralBinaryComparison {
    def unapply(e: Expression): Option[(NamedExpression, Interval[Double])] =
      valueRange(e).map(r => (r.attribute, r.accepted))
  }
}

/**
* Simplifies boolean expressions where the answer can be determined without evaluating both sides.
* Note that this rule can eliminate expressions that might otherwise have been evaluated and thus
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.expressions.{Literal, Expression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

/**
 * Checks the plans produced for filter conditions when `ConditionSimplification` runs together
 * with the constant folding and boolean simplification rules.
 */
class ConditionSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateAnalysisOperators) ::
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        ConditionSimplification,
        BooleanSimplification,
        SimplifyFilters) :: Nil
  }

  // The integer columns are declared NOT NULL: several expectations below fold a condition such
  // as `a <= 1 || a > 1` away entirely, which is only a valid rewrite when `a` cannot be NULL.
  val testRelation = LocalRelation('a.int.notNull, 'b.int.notNull, 'c.int.notNull, 'd.string)

  /** Runs the optimizer batches over `testRelation` filtered by `condition`. */
  private def optimize(condition: Expression): LogicalPlan =
    Optimize(testRelation.where(condition).analyze)

  /** Asserts that filtering by `originCondition` is optimized to `optimizedCondition`. */
  def checkCondition(originCondition: Expression, optimizedCondition: Expression): Unit = {
    comparePlans(optimize(originCondition), testRelation.where(optimizedCondition).analyze)
  }

  /** Asserts that the filter on `originCondition` is optimized away, leaving `testRelation`. */
  def checkCondition(originCondition: Expression): Unit = {
    comparePlans(optimize(originCondition), testRelation)
  }

  test("literal in front of attribute") {
    checkCondition(Literal(1) < 'a || Literal(2) < 'a, 'a > 1)
  }

  test("combine the same condition") {
    checkCondition('a < 1 || 'a < 1, 'a < 1)
    checkCondition('a < 1 || 'a < 1 || 'a < 1 || 'a < 1, 'a < 1)
    checkCondition('a > 2 && 'a > 2, 'a > 2)
    checkCondition('a > 2 && 'a > 2 && 'a > 2 && 'a > 2, 'a > 2)
    checkCondition(('a < 1 && 'a < 2) || ('a < 1 && 'a < 2), 'a < 1)
  }

  test("combine literal binary comparison") {
    checkCondition('a === 1 && 'a < 1)
    checkCondition('a === 1 || 'a < 1, 'a <= 1)

    checkCondition('a === 1 && 'a === 2)
    checkCondition('a === 1 || 'a === 2, 'a === 1 || 'a === 2)

    checkCondition('a <= 1 && 'a > 1)
    checkCondition('a <= 1 || 'a > 1)

    checkCondition('a < 1 && 'a >= 1)
    checkCondition('a < 1 || 'a >= 1)

    checkCondition('a > 3 && 'a > 2, 'a > 3)
    checkCondition('a > 3 || 'a > 2, 'a > 2)

    checkCondition('a >= 1 && 'a <= 1, 'a === 1)
  }

  test("different data type comparison") {
    checkCondition('a > "abc")
    checkCondition('a > "a" && 'a < "b")

    checkCondition('a > "a" || 'a < "b")

    checkCondition('a > "9" || 'a < "0", 'a > 9.0 || 'a < 0.0)
    checkCondition('d > 9 && 'd < 1, 'd > 9.0 && 'd < 1.0)
  }

  test("combine predicate : 2 same combine") {
    checkCondition('a < 1 || 'b > 2 || 'a >= 1)
    checkCondition('a < 1 && 'b > 2 && 'a >= 1)

    checkCondition('a < 2 || 'b > 3 || 'b > 2, 'a < 2 || 'b > 2)
    checkCondition('a < 2 && 'b > 3 && 'b > 2, 'a < 2 && 'b > 3)

    checkCondition('a < 2 || ('b > 3 || 'b > 2), 'b > 2 || 'a < 2)
    checkCondition('a < 2 && ('b > 3 && 'b > 2), 'b > 3 && 'a < 2)

    checkCondition('a < 2 || 'a === 3 || 'a > 5, 'a < 2 || 'a === 3 || 'a > 5)
  }

  test("combine predicate : 2 difference combine") {
    checkCondition(('a < 2 || 'a > 3) && 'a > 4, 'a > 4)
    checkCondition(('a < 2 || 'b > 3) && 'a < 2, 'a < 2)

    checkCondition('a < 2 || ('a >= 2 && 'b > 1), 'b > 1 || 'a < 2)
    checkCondition('a < 2 || ('a === 2 && 'b > 1), 'a < 2 || ('a === 2 && 'b > 1))

    checkCondition('a > 3 || ('a > 2 && 'a < 4), 'a > 2)
  }

  test("multi left, single right") {
    checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2)
  }

  test("multi left, multi right") {
    checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5))

    checkCondition(
      ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
      'a === 'b || ('b > 3 && 'a > 3 && 'a < 5))

    checkCondition(
      ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a > 1),
      'a === 'b || ('b > 3 && 'a > 3))

    checkCondition(
      ('a === 'b && 'b > 3 && 'c > 2) ||
        ('a === 'b && 'c < 1 && 'a === 5) ||
        ('a === 'b && 'b < 5 && 'a > 1),
      ('a === 'b) &&
        (((('b > 3) && ('c > 2)) ||
          (('c < 1) && ('a === 5))) ||
          (('b < 5) && ('a > 1))))

    checkCondition(
      ('a < 2 || 'b > 5 || 'a < 2 || 'b > 1) && ('a < 2 || 'b > 1),
      'a < 2 || 'b > 1)

    checkCondition(
      ('a === 'b || 'b > 5) && ('a === 'b || 'c > 3) && ('a === 'b || 'b > 1),
      ('a === 'b) || ('c > 3 && 'b > 5))
  }
}

0 comments on commit e99a26c

Please sign in to comment.