diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 6de17e5924d04..a5102b1018a8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1510,12 +1510,13 @@ class DataFrame private[sql](
   }
 
   private def collect(needCallback: Boolean): Array[Row] = {
-    def execute(): Array[Row] = withNewExecutionId {
-      queryExecution.executedPlan.executeCollectPublic()
+    val dfToExecute = withPlan(ReturnAnswer(logicalPlan))
+    def execute(): Array[Row] = dfToExecute.withNewExecutionId {
+      dfToExecute.queryExecution.executedPlan.executeCollectPublic()
     }
 
     if (needCallback) {
-      withCallback("collect", this)(_ => execute())
+      withCallback("collect", dfToExecute)(_ => execute())
     } else {
       execute()
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 8616fe317034f..8e9f1d6aaaff8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -44,7 +44,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
   lazy val sparkPlan: SparkPlan = {
     SQLContext.setActive(sqlContext)
-    sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next()
+    sqlContext.planner.plan(optimizedPlan).next()
   }
 
   // executedPlan should not be used to initialize any SparkPlan. It should be
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 5acde5996b543..615a5a94445ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -27,13 +27,15 @@ import org.apache.spark.sql.catalyst.plans.physical._
  * Take the first `limit` elements and collect them to a single partition.
  *
  * This operator will be used when a logical `Limit` operation is the final operator in an
- * logical plan, which happens when the user is collecting results back to the driver.
+ * logical plan and the user is collecting results back to the driver.
  */
 case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
-  protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(executeCollect(), 1)
+  protected override def doExecute(): RDD[InternalRow] = {
+    throw new UnsupportedOperationException("doExecute() should not be called on CollectLimit()")
+  }
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 2ff1bfaf95153..f62f8a2df81ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -182,12 +182,6 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
-  test("terminal limits use CollectLimit") {
-    val query = testData.select('value).limit(2)
-    val planned = query.queryExecution.sparkPlan
-    assert(planned.isInstanceOf[CollectLimit])
-  }
-
   test("PartitioningCollection") {
     withTempTable("normal", "small", "tiny") {
       testData.registerTempTable("normal")