diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index adb1350adc261..3c264eb8586b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -46,6 +46,13 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) + /** + * Defines the default rule batches in the Optimizer. + * + * Implementations of this class should override this method, and [[nonExcludableRules]] if + * necessary, instead of [[batches]]. The rule batches that eventually run in the Optimizer, + * i.e., returned by [[batches]], will be (defaultBatches - (excludedRules - nonExcludableRules)). + */ def defaultBatches: Seq[Batch] = { val operatorOptimizationRuleSet = Seq( @@ -160,6 +167,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) UpdateNullabilityInAttributeReferences) } + /** + * Defines rules that cannot be excluded from the Optimizer even if they are specified in + * SQL config "excludedRules". + * + * Implementations of this class can override this method if necessary. The rule batches + * that eventually run in the Optimizer, i.e., returned by [[batches]], will be + * (defaultBatches - (excludedRules - nonExcludableRules)). + */ def nonExcludableRules: Seq[String] = EliminateDistinct.ruleName :: EliminateSubqueryAliases.ruleName :: @@ -202,7 +217,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) */ def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil - override def batches: Seq[Batch] = { + /** + * Returns (defaultBatches - (excludedRules - nonExcludableRules)), the rule batches that + * eventually run in the Optimizer. + * + * Implementations of this class should override [[defaultBatches]], and [[nonExcludableRules]] + * if necessary, instead of this method. + */ + final override def batches: Seq[Batch] = { val excludedRulesConf = SQLConf.get.optimizerExcludedRules.toSeq.flatMap(Utils.stringToSeq) val excludedRules = excludedRulesConf.filter { ruleName => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala index 7112c033eabce..36b083a540c3c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala @@ -47,7 +47,7 @@ class OptimizerExtendableSuite extends SparkFunSuite { DummyRule) :: Nil } - override def batches: Seq[Batch] = super.batches ++ myBatches + override def defaultBatches: Seq[Batch] = super.defaultBatches ++ myBatches } test("Extending batches possible") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala index 5a5396e6f58b0..30c80d26b67a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala @@ -28,8 +28,10 @@ class OptimizerRuleExclusionSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - private def verifyExcludedRules(excludedRuleNames: Seq[String]) { - val optimizer = new SimpleTestOptimizer() + private def verifyExcludedRules(optimizer: Optimizer, rulesToExclude: Seq[String]) { + val nonExcludableRules = optimizer.nonExcludableRules + + val excludedRuleNames = rulesToExclude.filter(!nonExcludableRules.contains(_)) // Batches whose rules are all to be excluded should be removed as a whole. val excludedBatchNames = optimizer.batches .filter(batch => batch.rules.forall(rule => excludedRuleNames.contains(rule.ruleName))) @@ -38,21 +40,31 @@ class OptimizerRuleExclusionSuite extends PlanTest { withSQLConf( OPTIMIZER_EXCLUDED_RULES.key -> excludedRuleNames.foldLeft("")((l, r) => l + "," + r)) { val batches = optimizer.batches + // Verify removed batches. assert(batches.forall(batch => !excludedBatchNames.contains(batch.name))) + // Verify removed rules. assert( batches .forall(batch => batch.rules.forall(rule => !excludedRuleNames.contains(rule.ruleName)))) + // Verify non-excludable rules retained. + nonExcludableRules.foreach { nonExcludableRule => + assert( + optimizer.batches + .exists(batch => batch.rules.exists(rule => rule.ruleName == nonExcludableRule))) + } } } test("Exclude a single rule from multiple batches") { verifyExcludedRules( + new SimpleTestOptimizer(), Seq( PushPredicateThroughJoin.ruleName)) } test("Exclude multiple rules from single or multiple batches") { verifyExcludedRules( + new SimpleTestOptimizer(), Seq( CombineUnions.ruleName, RemoveLiteralFromGroupExpressions.ruleName, @@ -61,6 +73,7 @@ class OptimizerRuleExclusionSuite extends PlanTest { test("Exclude non-existent rule with other valid rules") { verifyExcludedRules( + new SimpleTestOptimizer(), Seq( LimitPushDown.ruleName, InferFiltersFromConstraints.ruleName, @@ -68,20 +81,34 @@ class OptimizerRuleExclusionSuite extends PlanTest { } test("Try to exclude a non-excludable rule") { - val excludedRules = Seq( - ReplaceIntersectWithSemiJoin.ruleName, - PullupCorrelatedPredicates.ruleName) + verifyExcludedRules( + new SimpleTestOptimizer(), + Seq( + ReplaceIntersectWithSemiJoin.ruleName, + PullupCorrelatedPredicates.ruleName)) + } - val optimizer = new SimpleTestOptimizer() + test("Custom optimizer") { + val optimizer = new SimpleTestOptimizer() { + override def defaultBatches: Seq[Batch] = + Batch("push", Once, + PushDownPredicate, + PushPredicateThroughJoin, + PushProjectionThroughUnion) :: + Batch("pull", Once, + PullupCorrelatedPredicates) :: Nil - withSQLConf( - OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.foldLeft("")((l, r) => l + "," + r)) { - excludedRules.foreach { excludedRule => - assert( - optimizer.batches - .exists(batch => batch.rules.exists(rule => rule.ruleName == excludedRule))) - } + override def nonExcludableRules: Seq[String] = + PushDownPredicate.ruleName :: + PullupCorrelatedPredicates.ruleName :: Nil } + + verifyExcludedRules( + optimizer, + Seq( + PushDownPredicate.ruleName, + PushProjectionThroughUnion.ruleName, + PullupCorrelatedPredicates.ruleName)) } test("Verify optimized plan after excluding CombineUnions rule") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala index 6e183d81b7265..a22a81e9844d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala @@ -44,7 +44,7 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest { EmptyFunctionRegistry, new SQLConf())) { val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI) - override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches + override def defaultBatches: Seq[Batch] = Seq(newBatch) ++ super.defaultBatches } test("check for invalid plan after execution of rule") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 00ff4c8ac310b..64d3f2cdbfa82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -28,13 +28,16 @@ class SparkOptimizer( experimentalMethods: ExperimentalMethods) extends Optimizer(catalog) { - override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ + override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + override def nonExcludableRules: Seq[String] = + super.nonExcludableRules :+ ExtractPythonUDFFromAggregate.ruleName + /** * Optimization batches that are executed before the regular optimization batches (also before * the finish analysis batch).