diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c77e0f82dc253..e5849ce0b3fb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 /**
  * Abstract class all optimizers should inherit of, containing the standard batches (extending
@@ -46,7 +47,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
 
   protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)
 
-  def batches: Seq[Batch] = {
+  def defaultBatches: Seq[Batch] = {
     val operatorOptimizationRuleSet =
       Seq(
         // Operator push down
@@ -158,6 +159,22 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       RemoveRedundantProject)
   }
 
+  def nonExcludableRules: Seq[String] =
+    EliminateDistinct.ruleName ::
+      EliminateSubqueryAliases.ruleName ::
+      EliminateView.ruleName ::
+      ReplaceExpressions.ruleName ::
+      ComputeCurrentTime.ruleName ::
+      GetCurrentDatabase(sessionCatalog).ruleName ::
+      RewriteDistinctAggregates.ruleName ::
+      ReplaceDeduplicateWithAggregate.ruleName ::
+      ReplaceIntersectWithSemiJoin.ruleName ::
+      ReplaceExceptWithFilter.ruleName ::
+      ReplaceExceptWithAntiJoin.ruleName ::
+      ReplaceDistinctWithAggregate.ruleName ::
+      PullupCorrelatedPredicates.ruleName ::
+      RewritePredicateSubquery.ruleName :: Nil
+
   /**
    * Optimize all the subqueries inside expression.
    */
@@ -173,6 +190,41 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
    * Override to provide additional rules for the operator optimization batch.
    */
   def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
+
+  override def batches: Seq[Batch] = {
+    val excludedRulesConf =
+      SQLConf.get.optimizerExcludedRules.toSeq.flatMap(Utils.stringToSeq)
+    val excludedRules = excludedRulesConf.filter { ruleName =>
+      val nonExcludable = nonExcludableRules.contains(ruleName)
+      if (nonExcludable) {
+        logWarning(s"Optimization rule '${ruleName}' was not excluded from the optimizer " +
+          "because it is non-excludable.")
+      }
+      !nonExcludable
+    }
+    if (excludedRules.isEmpty) {
+      defaultBatches
+    } else {
+      defaultBatches.flatMap { batch =>
+        val filteredRules = batch.rules.filter { rule =>
+          val exclude = excludedRules.contains(rule.ruleName)
+          if (exclude) {
+            logInfo(s"Optimization rule '${rule.ruleName}' is excluded from the optimizer.")
+          }
+          !exclude
+        }
+        if (batch.rules == filteredRules) {
+          Some(batch)
+        } else if (filteredRules.nonEmpty) {
+          Some(Batch(batch.name, batch.strategy, filteredRules: _*))
+        } else {
+          logInfo(s"Optimization batch '${batch.name}' is excluded from the optimizer " +
+            "as all of its rules have been excluded.")
+          None
+        }
+      }
+    }
+  }
 }
 
 /**
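To make the new control flow visible at a glance, here is a minimal standalone sketch of what the `batches` override above does. It is illustrative only, not part of the patch; `Rule` and `Batch` are simplified stand-ins for the Catalyst classes. Each batch is kept intact when none of its rules are excluded, rebuilt without the excluded rules, or dropped entirely when every rule in it is excluded:

```scala
case class Rule(ruleName: String)
case class Batch(name: String, rules: Seq[Rule])

object ExclusionSketch {
  // Mirrors the flatMap in the override: keep, rebuild, or drop each batch.
  def filterBatches(defaultBatches: Seq[Batch], excluded: Set[String]): Seq[Batch] =
    defaultBatches.flatMap { batch =>
      val kept = batch.rules.filterNot(rule => excluded.contains(rule.ruleName))
      if (kept == batch.rules) Some(batch)                   // nothing excluded
      else if (kept.nonEmpty) Some(batch.copy(rules = kept)) // partially excluded
      else None                                              // fully excluded: drop the batch
    }

  def main(args: Array[String]): Unit = {
    val batches = Seq(
      Batch("Union", Seq(Rule("CombineUnions"))),
      Batch("Operator Optimization", Seq(Rule("CombineUnions"), Rule("CombineFilters"))))
    // "Union" disappears because its only rule is excluded; "Operator
    // Optimization" survives with just CombineFilters.
    filterBatches(batches, Set("CombineUnions")).foreach(println)
  }
}
```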
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index bcf8ded6f6344..362d56d9c4701 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -109,6 +109,14 @@ object SQLConf {
    */
   def get: SQLConf = confGetter.get()()
 
+  val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules")
+    .doc("Configures a list of rules to be disabled in the optimizer, where the rules are " +
+      "specified by their rule names and separated by commas. It is not guaranteed that all " +
+      "the rules in this configuration will eventually be excluded, as some rules are " +
+      "necessary for correctness. The optimizer will log the rules that have indeed been excluded.")
+    .stringConf
+    .createOptional
+
   val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
     .internal()
     .doc("The max number of iterations the optimizer and analyzer runs.")
@@ -1226,6 +1234,8 @@ class SQLConf extends Serializable with Logging {
 
   /** ************************ Spark SQL Params/Hints ******************* */
 
+  def optimizerExcludedRules: Option[String] = getConf(OPTIMIZER_EXCLUDED_RULES)
+
   def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS)
 
   def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)
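A quick usage sketch of the new configuration follows; it is not part of the patch, and the local-mode session is just for illustration. Because `ruleName` is derived from the rule's class name, the entries in the comma-separated list are fully qualified class names; the two rules below are, as of this version of Spark, ordinary excludable rules:

```scala
import org.apache.spark.sql.SparkSession

object ExcludedRulesDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("excluded-rules-demo")
      .getOrCreate()

    // Exclude two excludable rules; unknown or non-excludable names are
    // ignored (the latter with a warning in the log) rather than an error.
    spark.conf.set(
      "spark.sql.optimizer.excludedRules",
      "org.apache.spark.sql.catalyst.optimizer.PushDownPredicate," +
        "org.apache.spark.sql.catalyst.optimizer.CombineUnions")

    // Queries planned from here on are optimized without those two rules.
    spark.range(10).union(spark.range(10)).explain(true)

    spark.stop()
  }
}
```

Since `batches` is a method that re-reads `SQLConf.get` on every optimizer run, the setting can be changed at runtime without restarting the session.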
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala
new file mode 100644
index 0000000000000..5a5396e6f58b0
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_EXCLUDED_RULES
+
+
+class OptimizerRuleExclusionSuite extends PlanTest {
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+  private def verifyExcludedRules(excludedRuleNames: Seq[String]): Unit = {
+    val optimizer = new SimpleTestOptimizer()
+    // Batches whose rules are all excluded should be removed as a whole.
+    val excludedBatchNames = optimizer.batches
+      .filter(batch => batch.rules.forall(rule => excludedRuleNames.contains(rule.ruleName)))
+      .map(_.name)
+
+    withSQLConf(
+      OPTIMIZER_EXCLUDED_RULES.key -> excludedRuleNames.mkString(",")) {
+      val batches = optimizer.batches
+      assert(batches.forall(batch => !excludedBatchNames.contains(batch.name)))
+      assert(
+        batches
+          .forall(batch => batch.rules.forall(rule => !excludedRuleNames.contains(rule.ruleName))))
+    }
+  }
+
+  test("Exclude a single rule from multiple batches") {
+    verifyExcludedRules(
+      Seq(
+        PushPredicateThroughJoin.ruleName))
+  }
+
+  test("Exclude multiple rules from single or multiple batches") {
+    verifyExcludedRules(
+      Seq(
+        CombineUnions.ruleName,
+        RemoveLiteralFromGroupExpressions.ruleName,
+        RemoveRepetitionFromGroupExpressions.ruleName))
+  }
+
+  test("Exclude a non-existent rule together with other valid rules") {
+    verifyExcludedRules(
+      Seq(
+        LimitPushDown.ruleName,
+        InferFiltersFromConstraints.ruleName,
+        "DummyRuleName"))
+  }
+
+  test("Try to exclude non-excludable rules") {
+    val excludedRules = Seq(
+      ReplaceIntersectWithSemiJoin.ruleName,
+      PullupCorrelatedPredicates.ruleName)
+
+    val optimizer = new SimpleTestOptimizer()
+
+    withSQLConf(
+      OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.mkString(",")) {
+      // Non-excludable rules must survive even when listed in the config.
+      excludedRules.foreach { excludedRule =>
+        assert(
+          optimizer.batches
+            .exists(batch => batch.rules.exists(rule => rule.ruleName == excludedRule)))
+      }
+    }
+  }
+
+  test("Verify optimized plan after excluding CombineUnions rule") {
+    val excludedRules = Seq(
+      ConvertToLocalRelation.ruleName,
+      PropagateEmptyRelation.ruleName,
+      CombineUnions.ruleName)
+
+    withSQLConf(
+      OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.mkString(",")) {
+      val optimizer = new SimpleTestOptimizer()
+      // With CombineUnions disabled, the nested Union must survive optimization.
+      val originalQuery = testRelation.union(testRelation.union(testRelation)).analyze
+      val optimized = optimizer.execute(originalQuery)
+      comparePlans(originalQuery, optimized)
+    }
+  }
+}
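One last detail worth calling out: the configured value is parsed with `Utils.stringToSeq`, which makes the list format forgiving. The snippet below is a self-contained sketch of that behavior (the split-trim-drop-empty semantics are stated here as an assumption about `org.apache.spark.util.Utils`): stray commas and whitespace are harmless, and a name that matches no rule, such as "DummyRuleName" in the test above, simply filters nothing instead of raising an error.

```scala
object StringToSeqSketch {
  // Sketch of Utils.stringToSeq: split on commas, trim whitespace,
  // and drop empty entries.
  def stringToSeq(str: String): Seq[String] =
    str.split(",").map(_.trim).filter(_.nonEmpty)

  def main(args: Array[String]): Unit = {
    assert(stringToSeq(",CombineUnions, LimitPushDown ,") ==
      Seq("CombineUnions", "LimitPushDown"))
    println("tolerant parsing confirmed")
  }
}
```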