Skip to content

Commit

Permalink
[SPARK-12912][SQL] Add a test suite for EliminateSubQueries
Browse files Browse the repository at this point in the history
Also updated documentation to explain why ComputeCurrentTime and EliminateSubQueries are in the optimizer rather than analyzer.

Author: Reynold Xin <[email protected]>

Closes #10837 from rxin/optimizer-analyzer-comment.
  • Loading branch information
rxin committed Jan 20, 2016
1 parent 6844d36 commit 753b194
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ class Analyzer(
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
object ResolveRelations extends Rule[LogicalPlan] {
def getTable(u: UnresolvedRelation): LogicalPlan = {
private def getTable(u: UnresolvedRelation): LogicalPlan = {
try {
catalog.lookupRelation(u.tableIdentifier, u.alias)
} catch {
Expand Down Expand Up @@ -1165,7 +1165,7 @@ class Analyzer(
* scoping information for attributes and can be removed once analysis is complete.
*/
object EliminateSubQueries extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case Subquery(_, child) => child
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,16 @@ import org.apache.spark.sql.types._
*/
abstract class Optimizer extends RuleExecutor[LogicalPlan] {
def batches: Seq[Batch] = {
// SubQueries are only needed for analysis and can be removed before execution.
Batch("Remove SubQueries", FixedPoint(100),
EliminateSubQueries) ::
Batch("Compute Current Time", Once,
// Technically some of the rules in Finish Analysis are not optimizer rules and belong more
// in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime).
// However, because we also use the analyzer to canonicalized queries (for view definition),
// we do not eliminate subqueries or compute current time in the analyzer.
Batch("Finish Analysis", Once,
EliminateSubQueries,
ComputeCurrentTime) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
Batch("Aggregate", FixedPoint(100),
ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
Expand All @@ -57,7 +62,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ProjectCollapsing,
CombineFilters,
CombineLimits,
// Constant folding
// Constant folding and strength reduction
NullPropagation,
OptimizeIn,
ConstantFolding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* When `rule` does not apply to a given node it is left unchanged.
* Users should not expect a specific directionality. If a specific directionality is needed,
* transformDown or transformUp should be used.
*
* @param rule the function use to transform this nodes children
*/
def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = {
Expand All @@ -253,6 +254,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
/**
* Returns a copy of this node where `rule` has been recursively applied to it and all of its
* children (pre-order). When `rule` does not apply to a given node it is left unchanged.
*
* @param rule the function used to transform this nodes children
*/
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
Expand All @@ -268,6 +270,26 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
}

/**
* Returns a copy of this node where `rule` has been recursively applied first to all of its
* children and then itself (post-order). When `rule` does not apply to a given node, it is left
* unchanged.
*
* @param rule the function use to transform this nodes children
*/
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r))
if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
}
} else {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
}
}
}

/**
* Returns a copy of this node where `rule` has been recursively applied to all the children of
* this node. When `rule` does not apply to a given node it is left unchanged.
Expand Down Expand Up @@ -332,25 +354,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
if (changed) makeCopy(newArgs) else this
}

/**
* Returns a copy of this node where `rule` has been recursively applied first to all of its
* children and then itself (post-order). When `rule` does not apply to a given node, it is left
* unchanged.
* @param rule the function use to transform this nodes children
*/
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r))
if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
}
} else {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
}
}
}

/**
* Args to the constructor that should be copied, but not transformed.
* These are appended to the transformed args automatically by makeCopy
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class EliminateSubQueriesSuite extends PlanTest with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("EliminateSubQueries", Once, EliminateSubQueries) :: Nil
}

private def assertEquivalent(e1: Expression, e2: Expression): Unit = {
val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
comparePlans(actual, correctAnswer)
}

private def afterOptimization(plan: LogicalPlan): LogicalPlan = {
Optimize.execute(analysis.SimpleAnalyzer.execute(plan))
}

test("eliminate top level subquery") {
val input = LocalRelation('a.int, 'b.int)
val query = Subquery("a", input)
comparePlans(afterOptimization(query), input)
}

test("eliminate mid-tree subquery") {
val input = LocalRelation('a.int, 'b.int)
val query = Filter(TrueLiteral, Subquery("a", input))
comparePlans(
afterOptimization(query),
Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
}

test("eliminate multiple subqueries") {
val input = LocalRelation('a.int, 'b.int)
val query = Filter(TrueLiteral, Subquery("c", Subquery("b", Subquery("a", input))))
comparePlans(
afterOptimization(query),
Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
}

}

0 comments on commit 753b194

Please sign in to comment.