diff --git a/pom.xml b/pom.xml
index e4db1393ba9cf..052ce02e125de 100644
--- a/pom.xml
+++ b/pom.xml
@@ -824,6 +824,11 @@
jackson-mapper-asl
1.8.8
+
+ org.spire-math
+ spire_2.10
+ 0.9.0
+
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 1caa297e24e37..43f12a6b9364c 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -44,7 +44,10 @@
org.scala-lang
scala-reflect
-
+
+ org.spire-math
+ spire_2.10
+
org.apache.spark
spark-core_${scala.binary.version}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 94b6fb084d38a..c0fddf2af425b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -48,6 +48,14 @@ trait PredicateHelper {
}
}
+ protected def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = {
+ condition match {
+ case Or(cond1, cond2) =>
+ splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2)
+ case other => other :: Nil
+ }
+ }
+
/**
* Returns true if `expr` can be evaluated using only the output of `plan`. This method
* can be used to determine when is is acceptable to move expression evaluation within a query
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 806c1394eb151..5fb6641686bdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -18,6 +18,9 @@
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.immutable.HashSet
+import spire.implicits._
+import spire.math._
+
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -293,6 +296,116 @@ object OptimizeIn extends Rule[LogicalPlan] {
}
}
+/**
+ * Simplifies Conditions(And, Or) expressions when the conditions can by optimized.
+ */
+object ConditionSimplification extends Rule[LogicalPlan] with PredicateHelper {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsDown {
+ // a && a => a
+ case And(left, right) if left.fastEquals(right) =>
+ left
+
+ // a || a => a
+ case Or(left, right) if left.fastEquals(right) =>
+ left
+
+ // a < 2 && a > 2 => false, a > 3 && a > 5 => a > 5
+ case and @ And(
+ e1 @ NumericLiteralBinaryComparison(n1, i1),
+ e2 @ NumericLiteralBinaryComparison(n2, i2)) if n1 == n2 =>
+ if (!i1.intersects(i2)) Literal(false)
+ else if (i1.isSubsetOf(i2)) e1
+ else if (i1.isSupersetOf(i2)) e2
+ else and
+
+ // a < 2 || a >= 2 => true, a > 3 || a > 5 => a > 3
+ case or @ Or(
+ e1 @ NumericLiteralBinaryComparison(n1, i1),
+ e2 @ NumericLiteralBinaryComparison(n2, i2)) if n1 == n2 =>
+ if (i1.intersects(i2)) Literal(true)
+ else if (i1.isSubsetOf(i2)) e2
+ else if (i1.isSupersetOf(i2)) e1
+ else or
+
+ // (a < 3 && b > 5) || a > 2 => b > 5 || a > 2
+ case Or(left1 @ And(left2, right2), right1) =>
+ And(Or(left2, right1), Or(right2, right1))
+
+ // (a < 3 || b > 5) || a > 2 => true, (b > 5 || a < 3) || a > 2 => true
+ case Or( Or(
+ e1 @ NumericLiteralBinaryComparison(n1, i1), e2 @ NumericLiteralBinaryComparison(n2, i2)),
+ right @ NumericLiteralBinaryComparison(n3, i3)) =>
+ if (n3 fastEquals n1) {
+ Or(Or(e1, right), e2)
+ } else {
+ Or(Or(e2, right), e1)
+ }
+
+ // (b > 5 && a < 2) && a > 3 => false, (a < 2 && b > 5) && a > 3 => false
+ case And(And(
+ e1 @ NumericLiteralBinaryComparison(n1, i1), e2 @ NumericLiteralBinaryComparison(n2, i2)),
+ right @ NumericLiteralBinaryComparison(n3, i3)) =>
+ if (n3 fastEquals n1) {
+ And(And(e1, right), e2)
+ } else {
+ And(And(e2, right), e1)
+ }
+
+ // (a < 2 || b > 5) && a > 3 => b > 5 && a > 3
+ case And(left1@Or(left2, right2), right1) =>
+ Or(And(left2, right1), And(right2, right1))
+
+ // (a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ... =>
+ // a && b && ((c && ...) || (d && ...) || (e && ...) || ...)
+ case or @ Or(left, right) =>
+ val lhsSet = splitConjunctivePredicates(left).toSet
+ val rhsSet = splitConjunctivePredicates(right).toSet
+ val common = lhsSet.intersect(rhsSet)
+ (lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And))
+ .reduceOption(Or)
+ .map(_ :: common.toList)
+ .getOrElse(common.toList)
+ .reduce(And)
+
+ // (a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ... =>
+ // (a || b) || ((c || ...) && (f || ...) && (e || ...) && ...)
+ case and @ And(left, right) =>
+ val lhsSet = splitDisjunctivePredicates(left).toSet
+ val rhsSet = splitDisjunctivePredicates(right).toSet
+ val common = lhsSet.intersect(rhsSet)
+ (lhsSet.diff(common).reduceOption(Or) ++ rhsSet.diff(common).reduceOption(Or))
+ .reduceOption(And)
+ .map(_ :: common.toList)
+ .getOrElse(common.toList)
+ .reduce(Or)
+ }
+ }
+
+ private implicit class NumericLiteral(e: Literal) {
+ def toDouble = Cast(e, DoubleType).eval().asInstanceOf[Double]
+ }
+
+ object NumericLiteralBinaryComparison {
+ def unapply(e: Expression): Option[(NamedExpression, Interval[Double])] = e match {
+ case LessThan(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.below(l.toDouble)))
+ case LessThan(l @ Literal(_, _: NumericType), n: NamedExpression) => Some((n, Interval.atOrAbove(l.toDouble)))
+
+ case GreaterThan(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.above(l.toDouble)))
+ case GreaterThan(l @ Literal(_, dt: NumericType), n: NamedExpression) => Some((n, Interval.atOrBelow(l.toDouble)))
+
+ case LessThanOrEqual(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.atOrBelow(l.toDouble)))
+ case LessThanOrEqual(l @ Literal(_, _: NumericType), n: NamedExpression) => Some((n, Interval.above(l.toDouble)))
+
+ case GreaterThanOrEqual(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.atOrAbove(l.toDouble)))
+ case GreaterThanOrEqual(l @ Literal(_, _: NumericType), n: NamedExpression) => Some((n, Interval.below(l.toDouble)))
+
+ case EqualTo(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.point(l.toDouble)))
+ }
+ }
+}
+
/**
* Simplifies boolean expressions where the answer can be determined without evaluating both sides.
* Note that this rule can eliminate expressions that might otherwise have been evaluated and thus
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConditionSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConditionSimplificationSuite.scala
new file mode 100644
index 0000000000000..743b0338fee88
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConditionSimplificationSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.expressions.{Literal, Expression}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class ConditionSimplificationSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("AnalysisNodes", Once,
+ EliminateAnalysisOperators) ::
+ Batch("Constant Folding", FixedPoint(10),
+ NullPropagation,
+ ConstantFolding,
+ ConditionSimplification,
+ BooleanSimplification,
+ SimplifyFilters) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)
+
+ def checkCondition(originCondition: Expression, optimizedCondition: Expression): Unit = {
+ val originQuery = testRelation.where(originCondition).analyze
+ val optimized = Optimize(originQuery)
+ val expected = testRelation.where(optimizedCondition).analyze
+ comparePlans(optimized, expected)
+ }
+
+ def checkCondition(originCondition: Expression): Unit = {
+ val originQuery = testRelation.where(originCondition).analyze
+ val optimized = Optimize(originQuery)
+ val expected = testRelation
+ comparePlans(optimized, expected)
+ }
+
+ test("literal in front of attribute") {
+ checkCondition(Literal(1) < 'a || Literal(2) < 'a, 'a > 1)
+ }
+
+ test("combine the same condition") {
+ checkCondition('a < 1 || 'a < 1, 'a < 1)
+ checkCondition('a < 1 || 'a < 1 || 'a < 1 || 'a < 1, 'a < 1)
+ checkCondition('a > 2 && 'a > 2, 'a > 2)
+ checkCondition('a > 2 && 'a > 2 && 'a > 2 && 'a > 2, 'a > 2)
+ checkCondition(('a < 1 && 'a < 2) || ('a < 1 && 'a < 2), 'a < 1)
+ }
+
+ test("combine literal binary comparison") {
+ checkCondition('a === 1 && 'a < 1)
+ checkCondition('a === 1 || 'a < 1, 'a <= 1)
+
+ checkCondition('a === 1 && 'a === 2)
+ checkCondition('a === 1 || 'a === 2, 'a === 1 || 'a === 2)
+
+ checkCondition('a <= 1 && 'a > 1)
+ checkCondition('a <= 1 || 'a > 1)
+
+ checkCondition('a < 1 && 'a >= 1)
+ checkCondition('a < 1 || 'a >= 1)
+
+ checkCondition('a > 3 && 'a > 2, 'a > 3)
+ checkCondition('a > 3 || 'a > 2, 'a > 2)
+
+ checkCondition('a >= 1 && 'a <= 1, 'a === 1)
+
+ }
+
+ test("different data type comparison") {
+ checkCondition('a > "abc")
+ checkCondition('a > "a" && 'a < "b")
+
+ checkCondition('a > "a" || 'a < "b")
+
+ checkCondition('a > "9" || 'a < "0", 'a > 9.0 || 'a < 0.0)
+ checkCondition('d > 9 && 'd < 1, 'd > 9.0 && 'd < 1.0 )
+
+ checkCondition('a > "9" || 'a < "0", 'a > 9.0 || 'a < 0.0)
+ }
+
+ test("combine predicate : 2 same combine") {
+ checkCondition('a < 1 || 'b > 2 || 'a >= 1)
+ checkCondition('a < 1 && 'b > 2 && 'a >= 1)
+
+ checkCondition('a < 2 || 'b > 3 || 'b > 2, 'a < 2 || 'b > 2)
+ checkCondition('a < 2 && 'b > 3 && 'b > 2, 'a < 2 && 'b > 3)
+
+ checkCondition('a < 2 || ('b > 3 || 'b > 2), 'b > 2 || 'a < 2)
+ checkCondition('a < 2 && ('b > 3 && 'b > 2), 'b > 3 && 'a < 2)
+
+ checkCondition('a < 2 || 'a === 3 || 'a > 5, 'a < 2 || 'a === 3 || 'a > 5)
+ }
+
+ test("combine predicate : 2 difference combine") {
+ checkCondition(('a < 2 || 'a > 3) && 'a > 4, 'a > 4)
+ checkCondition(('a < 2 || 'b > 3) && 'a < 2, 'a < 2)
+
+ checkCondition('a < 2 || ('a >= 2 && 'b > 1), 'b > 1 || 'a < 2)
+ checkCondition('a < 2 || ('a === 2 && 'b > 1), 'a < 2 || ('a === 2 && 'b > 1))
+
+ checkCondition('a > 3 || ('a > 2 && 'a < 4), 'a > 2)
+ }
+
+ test("multi left, single right") {
+ checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2)
+ }
+
+ test("multi left, multi right") {
+ checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5))
+
+ var input: Expression = ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5)
+ var expected: Expression = 'a === 'b || ('b > 3 && 'a > 3 && 'a < 5)
+ checkCondition(input, expected)
+
+ input = ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a > 1)
+ expected = 'a === 'b || ('b > 3 && 'a > 3)
+ checkCondition(input, expected)
+
+ input = ('a === 'b && 'b > 3 && 'c > 2) ||
+ ('a === 'b && 'c < 1 && 'a === 5) ||
+ ('a === 'b && 'b < 5 && 'a > 1)
+
+ expected = ('a === 'b) &&
+ (((('b > 3) && ('c > 2)) ||
+ (('c < 1) && ('a === 5))) ||
+ (('b < 5) && ('a > 1)))
+ checkCondition(input, expected)
+
+ input = ('a < 2 || 'b > 5 || 'a < 2 || 'b > 1) && ('a < 2 || 'b > 1)
+ expected = 'a < 2 || 'b > 1
+ checkCondition(input, expected)
+
+ input = ('a === 'b || 'b > 5) && ('a === 'b || 'c > 3) && ('a === 'b || 'b > 1)
+ expected = ('a === 'b) || ('c > 3 && 'b > 5)
+ checkCondition(input, expected)
+ }
+}
\ No newline at end of file