diff --git a/pom.xml b/pom.xml index e4db1393ba9cf..052ce02e125de 100644 --- a/pom.xml +++ b/pom.xml @@ -824,6 +824,11 @@ jackson-mapper-asl 1.8.8 + + org.spire-math + spire_2.10 + 0.9.0 + diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 1caa297e24e37..43f12a6b9364c 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -44,7 +44,10 @@ org.scala-lang scala-reflect - + + org.spire-math + spire_2.10 + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 94b6fb084d38a..c0fddf2af425b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -48,6 +48,14 @@ trait PredicateHelper { } } + protected def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = { + condition match { + case Or(cond1, cond2) => + splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2) + case other => other :: Nil + } + } + /** * Returns true if `expr` can be evaluated using only the output of `plan`. This method * can be used to determine when is is acceptable to move expression evaluation within a query diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 806c1394eb151..5fb6641686bdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet +import spire.implicits._ +import spire.math._ + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -293,6 +296,116 @@ object OptimizeIn extends Rule[LogicalPlan] { } } +/** + * Simplifies Conditions(And, Or) expressions when the conditions can by optimized. + */ +object ConditionSimplification extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + // a && a => a + case And(left, right) if left.fastEquals(right) => + left + + // a || a => a + case Or(left, right) if left.fastEquals(right) => + left + + // a < 2 && a > 2 => false, a > 3 && a > 5 => a > 5 + case and @ And( + e1 @ NumericLiteralBinaryComparison(n1, i1), + e2 @ NumericLiteralBinaryComparison(n2, i2)) if n1 == n2 => + if (!i1.intersects(i2)) Literal(false) + else if (i1.isSubsetOf(i2)) e1 + else if (i1.isSupersetOf(i2)) e2 + else and + + // a < 2 || a >= 2 => true, a > 3 || a > 5 => a > 3 + case or @ Or( + e1 @ NumericLiteralBinaryComparison(n1, i1), + e2 @ NumericLiteralBinaryComparison(n2, i2)) if n1 == n2 => + if (i1.intersects(i2)) Literal(true) + else if (i1.isSubsetOf(i2)) e2 + else if (i1.isSupersetOf(i2)) e1 + else or + + // (a < 3 && b > 5) || a > 2 => b > 5 || a > 2 + case Or(left1 @ And(left2, right2), right1) => + And(Or(left2, right1), Or(right2, right1)) + + // (a < 3 || b > 5) || a > 2 => true, (b > 5 || a < 3) || a > 2 => true + case Or( Or( + e1 @ NumericLiteralBinaryComparison(n1, i1), e2 @ NumericLiteralBinaryComparison(n2, i2)), + right @ NumericLiteralBinaryComparison(n3, i3)) => + if (n3 fastEquals n1) { + Or(Or(e1, right), e2) + } else { + Or(Or(e2, right), e1) + } + + // (b > 5 && a < 2) && a > 3 => false, (a < 2 && b > 5) && a > 3 => false + case And(And( + e1 @ NumericLiteralBinaryComparison(n1, i1), e2 @ NumericLiteralBinaryComparison(n2, i2)), + right @ NumericLiteralBinaryComparison(n3, i3)) => + if (n3 fastEquals n1) { + And(And(e1, right), e2) + } else { + And(And(e2, right), e1) + } + + // (a < 2 || b > 5) && a > 3 => b > 5 && a > 3 + case And(left1@Or(left2, right2), right1) => + Or(And(left2, right1), And(right2, right1)) + + // (a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ... => + // a && b && ((c && ...) || (d && ...) || (e && ...) || ...) + case or @ Or(left, right) => + val lhsSet = splitConjunctivePredicates(left).toSet + val rhsSet = splitConjunctivePredicates(right).toSet + val common = lhsSet.intersect(rhsSet) + (lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And)) + .reduceOption(Or) + .map(_ :: common.toList) + .getOrElse(common.toList) + .reduce(And) + + // (a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ... => + // (a || b) || ((c || ...) && (f || ...) && (e || ...) && ...) + case and @ And(left, right) => + val lhsSet = splitDisjunctivePredicates(left).toSet + val rhsSet = splitDisjunctivePredicates(right).toSet + val common = lhsSet.intersect(rhsSet) + (lhsSet.diff(common).reduceOption(Or) ++ rhsSet.diff(common).reduceOption(Or)) + .reduceOption(And) + .map(_ :: common.toList) + .getOrElse(common.toList) + .reduce(Or) + } + } + + private implicit class NumericLiteral(e: Literal) { + def toDouble = Cast(e, DoubleType).eval().asInstanceOf[Double] + } + + object NumericLiteralBinaryComparison { + def unapply(e: Expression): Option[(NamedExpression, Interval[Double])] = e match { + case LessThan(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.below(l.toDouble))) + case LessThan(l @ Literal(_, _: NumericType), n: NamedExpression) => Some((n, Interval.atOrAbove(l.toDouble))) + + case GreaterThan(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.above(l.toDouble))) + case GreaterThan(l @ Literal(_, dt: NumericType), n: NamedExpression) => Some((n, Interval.atOrBelow(l.toDouble))) + + case LessThanOrEqual(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.atOrBelow(l.toDouble))) + case LessThanOrEqual(l @ Literal(_, _: NumericType), n: NamedExpression) => Some((n, Interval.above(l.toDouble))) + + case GreaterThanOrEqual(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.atOrAbove(l.toDouble))) + case GreaterThanOrEqual(l @ Literal(_, _: NumericType), n: NamedExpression) => Some((n, Interval.below(l.toDouble))) + + case EqualTo(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.point(l.toDouble))) + } + } +} + /** * Simplifies boolean expressions where the answer can be determined without evaluating both sides. * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConditionSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConditionSimplificationSuite.scala new file mode 100644 index 0000000000000..743b0338fee88 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConditionSimplificationSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.expressions.{Literal, Expression} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class ConditionSimplificationSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateAnalysisOperators) :: + Batch("Constant Folding", FixedPoint(10), + NullPropagation, + ConstantFolding, + ConditionSimplification, + BooleanSimplification, + SimplifyFilters) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) + + def checkCondition(originCondition: Expression, optimizedCondition: Expression): Unit = { + val originQuery = testRelation.where(originCondition).analyze + val optimized = Optimize(originQuery) + val expected = testRelation.where(optimizedCondition).analyze + comparePlans(optimized, expected) + } + + def checkCondition(originCondition: Expression): Unit = { + val originQuery = testRelation.where(originCondition).analyze + val optimized = Optimize(originQuery) + val expected = testRelation + comparePlans(optimized, expected) + } + + test("literal in front of attribute") { + checkCondition(Literal(1) < 'a || Literal(2) < 'a, 'a > 1) + } + + test("combine the same condition") { + checkCondition('a < 1 || 'a < 1, 'a < 1) + checkCondition('a < 1 || 'a < 1 || 'a < 1 || 'a < 1, 'a < 1) + checkCondition('a > 2 && 'a > 2, 'a > 2) + checkCondition('a > 2 && 'a > 2 && 'a > 2 && 'a > 2, 'a > 2) + checkCondition(('a < 1 && 'a < 2) || ('a < 1 && 'a < 2), 'a < 1) + } + + test("combine literal binary comparison") { + checkCondition('a === 1 && 'a < 1) + checkCondition('a === 1 || 'a < 1, 'a <= 1) + + checkCondition('a === 1 && 'a === 2) + checkCondition('a === 1 || 'a === 2, 'a === 1 || 'a === 2) + + checkCondition('a <= 1 && 'a > 1) + checkCondition('a <= 1 || 'a > 1) + + checkCondition('a < 1 && 'a >= 1) + checkCondition('a < 1 || 'a >= 1) + + checkCondition('a > 3 && 'a > 2, 'a > 3) + checkCondition('a > 3 || 'a > 2, 'a > 2) + + checkCondition('a >= 1 && 'a <= 1, 'a === 1) + + } + + test("different data type comparison") { + checkCondition('a > "abc") + checkCondition('a > "a" && 'a < "b") + + checkCondition('a > "a" || 'a < "b") + + checkCondition('a > "9" || 'a < "0", 'a > 9.0 || 'a < 0.0) + checkCondition('d > 9 && 'd < 1, 'd > 9.0 && 'd < 1.0 ) + + checkCondition('a > "9" || 'a < "0", 'a > 9.0 || 'a < 0.0) + } + + test("combine predicate : 2 same combine") { + checkCondition('a < 1 || 'b > 2 || 'a >= 1) + checkCondition('a < 1 && 'b > 2 && 'a >= 1) + + checkCondition('a < 2 || 'b > 3 || 'b > 2, 'a < 2 || 'b > 2) + checkCondition('a < 2 && 'b > 3 && 'b > 2, 'a < 2 && 'b > 3) + + checkCondition('a < 2 || ('b > 3 || 'b > 2), 'b > 2 || 'a < 2) + checkCondition('a < 2 && ('b > 3 && 'b > 2), 'b > 3 && 'a < 2) + + checkCondition('a < 2 || 'a === 3 || 'a > 5, 'a < 2 || 'a === 3 || 'a > 5) + } + + test("combine predicate : 2 difference combine") { + checkCondition(('a < 2 || 'a > 3) && 'a > 4, 'a > 4) + checkCondition(('a < 2 || 'b > 3) && 'a < 2, 'a < 2) + + checkCondition('a < 2 || ('a >= 2 && 'b > 1), 'b > 1 || 'a < 2) + checkCondition('a < 2 || ('a === 2 && 'b > 1), 'a < 2 || ('a === 2 && 'b > 1)) + + checkCondition('a > 3 || ('a > 2 && 'a < 4), 'a > 2) + } + + test("multi left, single right") { + checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2) + } + + test("multi left, multi right") { + checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5)) + + var input: Expression = ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5) + var expected: Expression = 'a === 'b || ('b > 3 && 'a > 3 && 'a < 5) + checkCondition(input, expected) + + input = ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a > 1) + expected = 'a === 'b || ('b > 3 && 'a > 3) + checkCondition(input, expected) + + input = ('a === 'b && 'b > 3 && 'c > 2) || + ('a === 'b && 'c < 1 && 'a === 5) || + ('a === 'b && 'b < 5 && 'a > 1) + + expected = ('a === 'b) && + (((('b > 3) && ('c > 2)) || + (('c < 1) && ('a === 5))) || + (('b < 5) && ('a > 1))) + checkCondition(input, expected) + + input = ('a < 2 || 'b > 5 || 'a < 2 || 'b > 1) && ('a < 2 || 'b > 1) + expected = 'a < 2 || 'b > 1 + checkCondition(input, expected) + + input = ('a === 'b || 'b > 5) && ('a === 'b || 'c > 3) && ('a === 'b || 'b > 1) + expected = ('a === 'b) || ('c > 3 && 'b > 5) + checkCondition(input, expected) + } +} \ No newline at end of file