diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 57d25069ef548..bfd24287c9645 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -41,7 +41,6 @@ object DefaultOptimizer extends Optimizer {
     Batch("Operator Optimizations", FixedPoint(100),
       // Operator push down
       UnionPushDown,
-      LimitPushDown,
       PushPredicateThroughJoin,
       PushPredicateThroughProject,
       PushPredicateThroughGenerate,
@@ -112,20 +111,6 @@ object UnionPushDown extends Rule[LogicalPlan] {
   }
 }
 
-object LimitPushDown extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Push down limit when the child is project on limit.
-    case Limit(expr, Project(projectList, l: Limit)) =>
-      Project(projectList, Limit(expr, l))
-
-    // Push down limit when the child is project on sort,
-    // and we cannot push down this project through sort.
-    case Limit(expr, p @ Project(projectList, s: Sort))
-      if !s.references.subsetOf(p.outputSet) =>
-      Project(projectList, Limit(expr, s))
-  }
-}
-
 /**
  * Attempts to eliminate the reading of unneeded columns from the query plan using the following
  * transformations:
@@ -175,7 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] {
       Join(left, prunedChild(right, allReferences), LeftSemi, condition)
 
-    // push down project if possible when the child is sort
+    // Push down project through limit, so that we may have a chance to push it down further.
+    case Project(projectList, Limit(exp, child)) =>
+      Limit(exp, Project(projectList, child))
+
+    // Push down project if possible when the child is sort
     case p @ Project(projectList, s @ Sort(_, _, grandChild))
       if s.references.subsetOf(p.outputSet) =>
       s.copy(child = Project(projectList, grandChild))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index fc78366e0c33f..ffdc673cdc455 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -95,6 +95,22 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("column pruning for Project(ne, Limit)") {
+    val originalQuery =
+      testRelation
+        .select('a, 'b)
+        .limit(2)
+        .select('a)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .limit(2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("simple push down") {
     val originalQuery =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala
deleted file mode 100644
index c8168b32eef94..0000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
-import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.dsl.plans._
-
-class LimitPushDownSuite extends PlanTest {
-
-  object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("Limit PushDown", FixedPoint(10),
-        LimitPushDown,
-        CombineLimits,
-        ConstantFolding,
-        BooleanSimplification) :: Nil
-  }
-
-  val testRelation = LocalRelation('a.int)
-
-  test("push down limit when the child is project on limit") {
-    val originalQuery =
-      testRelation
-        .limit(10)
-        .select('a)
-        .limit(2)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .limit(2)
-        .select('a).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("push down limit when the child is project on sort") {
-    val originalQuery =
-      testRelation
-        .sortBy(SortOrder('a, Ascending))
-        .select('a)
-        .limit(2)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .sortBy(SortOrder('a, Ascending))
-        .limit(2)
-        .select('a).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 04fc798bf3738..5708df82de12f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -858,7 +858,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       experimental.extraStrategies ++ (
       DataSourceStrategy ::
       DDLStrategy ::
-      TakeOrdered ::
+      TakeOrderedAndProject ::
       HashAggregation ::
       LeftSemiJoin ::
       HashJoin ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 1ff1cc224de8c..21912cf24933e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -213,10 +213,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   protected lazy val singleRowRdd =
     sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1)
 
-  object TakeOrdered extends Strategy {
+  object TakeOrderedAndProject extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child)) :: Nil
+        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+      case logical.Limit(
+          IntegerLiteral(limit),
+          logical.Project(projectList, logical.Sort(order, true, child))) =>
+        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
       case _ => Nil
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index fbb9b52d0f92c..c97c2031d33ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -147,12 +147,18 @@ case class Limit(limit: Int, child: SparkPlan)
 
 /**
  * :: DeveloperApi ::
- * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
- * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
- * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ * Take the first limit elements as defined by the sortOrder, and apply the projection if needed.
+ * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering,
+ * so we name it TakeOrderedAndProject to avoid confusion.
  */
 @DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
+case class TakeOrderedAndProject(
+    limit: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Option[Seq[NamedExpression]],
+    child: SparkPlan) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
 
@@ -160,17 +166,22 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
 
-  private def collectData(): Array[InternalRow] =
-    child.execute().map(_.copy()).takeOrdered(limit)(ord)
+  private val projection = projectList.map(newProjection(_, child.output))
+
+  private def collectData(): Iterator[InternalRow] = {
+    val data = child.execute().map(_.copy()).takeOrdered(limit)(ord).toIterator
+    projection.map(data.map(_)).getOrElse(data)
+  }
 
   override def executeCollect(): Array[Row] = {
     val converter = CatalystTypeConverters.createToScalaConverter(schema)
-    collectData().map(converter(_).asInstanceOf[Row])
+    collectData().map(converter(_).asInstanceOf[Row]).toArray
   }
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
+  protected override def doExecute(): RDD[InternalRow] =
+    sparkContext.makeRDD(collectData().toArray[InternalRow], 1)
 
   override def outputOrdering: Seq[SortOrder] = sortOrder
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 5854ab48db552..3dd24130af81a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -141,4 +141,10 @@ class PlannerSuite extends SparkFunSuite {
     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
   }
+
+  test("efficient limit -> project -> sort") {
+    val query = testData.sort('key).select('value).limit(2).logicalPlan
+    val planned = planner.TakeOrderedAndProject(query)
+    assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cf05c6c989655..8021f915bb821 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -442,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       HiveCommandStrategy(self),
      HiveDDLStrategy,
       DDLStrategy,
-      TakeOrdered,
+      TakeOrderedAndProject,
       ParquetOperations,
       InMemoryScans,
       ParquetConversion, // Must be before HiveTableScans
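
How the pieces fit together: the new ColumnPruning case rewrites Project over Limit into Limit over Project, which is exactly the Limit(Project(Sort(...))) shape that the TakeOrderedAndProject strategy above matches. The following is a standalone sketch with a toy plan ADT (these Plan/Scan case classes are illustrative stand-ins, not Catalyst's actual classes), showing the rewrite on its own:

// Toy plan ADT, for illustration only.
sealed trait Plan
case class Scan(columns: Seq[String]) extends Plan
case class Project(columns: Seq[String], child: Plan) extends Plan
case class Limit(n: Int, child: Plan) extends Plan
case class Sort(key: String, child: Plan) extends Plan

object PushProjectThroughLimit {
  // Mirrors the new ColumnPruning case: Project(Limit(child)) => Limit(Project(child)).
  def rewrite(plan: Plan): Plan = plan match {
    case Project(cols, Limit(n, child)) => Limit(n, rewrite(Project(cols, child)))
    case Project(cols, child)           => Project(cols, rewrite(child))
    case Limit(n, child)                => Limit(n, rewrite(child))
    case Sort(k, child)                 => Sort(k, rewrite(child))
    case leaf: Scan                     => leaf
  }

  def main(args: Array[String]): Unit = {
    // SELECT a FROM (SELECT a, b FROM t ORDER BY a LIMIT 2)
    val before = Project(Seq("a"), Limit(2, Sort("a", Scan(Seq("a", "b")))))
    // Prints Limit(2,Project(List(a),Sort(a,Scan(List(a, b))))) -- the
    // Limit(Project(Sort(...))) shape the planner strategy can now pick up.
    println(rewrite(before))
  }
}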
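At execution time, TakeOrderedAndProject computes a top-K under the sort order rather than a total ordering, then applies the pushed-down projection only if one is present. Below is a minimal single-machine sketch of those semantics under stated assumptions: plain Scala collections stand in for an RDD, and Row and the helper names are hypothetical, not Spark API:

// Single-machine sketch of TakeOrderedAndProject's semantics; Row and these
// helpers are illustrative stand-ins, not Spark classes.
object TakeOrderedAndProjectSketch {
  type Row = Seq[Any]

  def takeOrderedAndProject(
      input: Iterator[Row],
      limit: Int,
      ord: Ordering[Row],
      projection: Option[Row => Row]): Iterator[Row] = {
    // RDD.takeOrdered(limit)(ord) returns the `limit` smallest rows under the
    // ordering without totally sorting the data; sorting a local collection
    // and taking `limit` is its single-machine equivalent.
    val topK = input.toSeq.sorted(ord).take(limit).iterator
    // As in the patch: apply the projection only when one was pushed down.
    projection.map(p => topK.map(p)).getOrElse(topK)
  }

  def main(args: Array[String]): Unit = {
    val rows = Iterator[Row](Seq(3, "c"), Seq(1, "a"), Seq(2, "b"))
    val byKey = Ordering.by[Row, Int](_.head.asInstanceOf[Int])
    val keepValue: Row => Row = row => Seq(row(1)) // project away the sort key
    // ORDER BY key LIMIT 2, then SELECT value => List(List(a), List(b))
    println(takeOrderedAndProject(rows, 2, byKey, Some(keepValue)).toList)
  }
}

Fusing the three operators this way means only `limit` rows per partition are ever buffered and shipped to the driver, instead of a full shuffle-based sort followed by a separate limit and projection.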