
Commit c4b0a53

Move callsite of ReturnAnswer to fix caching / .rdd().
JoshRosen committed Feb 5, 2016
1 parent 55e27af commit c4b0a53
Showing 4 changed files with 9 additions and 12 deletions.
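
Taken together, the four changes move the ReturnAnswer marker from the planner's input (QueryExecution.sparkPlan) to the one callsite that actually returns rows to the driver (DataFrame.collect). Previously every plan, including the ones built for df.rdd() and for caching, was wrapped in ReturnAnswer, so a terminal Limit was always planned as the driver-only CollectLimit operator, and CollectLimit.doExecute() had to fake a distributed result with makeRDD. Now only collect() opts into the marker, and doExecute() on CollectLimit can become a hard error.
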
sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala (4 additions, 3 deletions)

@@ -1510,12 +1510,13 @@ class DataFrame private[sql](
   }
 
   private def collect(needCallback: Boolean): Array[Row] = {
-    def execute(): Array[Row] = withNewExecutionId {
-      queryExecution.executedPlan.executeCollectPublic()
+    val dfToExecute = withPlan(ReturnAnswer(logicalPlan))
+    def execute(): Array[Row] = dfToExecute.withNewExecutionId {
+      dfToExecute.queryExecution.executedPlan.executeCollectPublic()
     }
 
     if (needCallback) {
-      withCallback("collect", this)(_ => execute())
+      withCallback("collect", dfToExecute)(_ => execute())
     } else {
       execute()
     }
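
This hunk is the marker's new home: collect() wraps its own logical plan in ReturnAnswer before planning, so only the collect path can ever see CollectLimit. Below is a self-contained sketch of the marker-node pattern, with simplified stand-ins for the Catalyst classes (only the names ReturnAnswer and CollectLimit come from the real codebase; everything else is hypothetical):

// Sketch only: simplified stand-ins, not the real Catalyst types.
object ReturnAnswerSketch {
  sealed trait LogicalPlan
  case class Leaf(table: String) extends LogicalPlan
  case class Limit(n: Int, child: LogicalPlan) extends LogicalPlan
  // The marker: "the caller will pull these rows back to the driver".
  case class ReturnAnswer(child: LogicalPlan) extends LogicalPlan

  // A planner can pattern-match on the marker to pick a driver-only strategy
  // for a terminal Limit, and a distributable one everywhere else.
  def plan(p: LogicalPlan): String = p match {
    case ReturnAnswer(Limit(n, _)) => s"CollectLimit($n)" // driver-side executeTake
    case ReturnAnswer(child)       => plan(child)
    case Limit(n, child)           => s"Limit($n, ${plan(child)})" // still runnable as an RDD
    case Leaf(table)               => s"Scan($table)"
  }

  def main(args: Array[String]): Unit = {
    val query = Limit(2, Leaf("t"))
    println(plan(ReturnAnswer(query))) // CollectLimit(2): the collect() path
    println(plan(query))               // Limit(2, Scan(t)): the .rdd()/cache path
  }
}

The point of the marker is that the same logical Limit plans differently depending on whether the caller has promised to consume the result on the driver.
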
@@ -44,7 +44,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
   lazy val sparkPlan: SparkPlan = {
     SQLContext.setActive(sqlContext)
-    sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next()
+    sqlContext.planner.plan(optimizedPlan).next()
   }
 
   // executedPlan should not be used to initialize any SparkPlan. It should be
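
With ReturnAnswer no longer applied inside sparkPlan, plans built for RDD conversion and caching come out as ordinary distributable plans. A sketch of the two paths after this commit, assuming a Spark 2.0-snapshot sqlContext is in scope (for example in spark-shell); range, limit, rdd and collect are existing APIs of that era:

val df = sqlContext.range(100).limit(10)

// queryExecution.toRdd calls executedPlan.execute(), i.e. doExecute().
// The plan is built without ReturnAnswer, so no CollectLimit appears
// and RDD execution is safe:
df.rdd.count()

// DataFrame.collect() now wraps the plan in ReturnAnswer itself, so the
// terminal Limit plans to CollectLimit and runs via executeTake on the driver:
df.collect()
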
@@ -27,13 +27,15 @@ import org.apache.spark.sql.catalyst.plans.physical._
  * Take the first `limit` elements and collect them to a single partition.
  *
  * This operator will be used when a logical `Limit` operation is the final operator in a
- * logical plan, which happens when the user is collecting results back to the driver.
+ * logical plan and the user is collecting results back to the driver.
  */
 case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
-  protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(executeCollect(), 1)
+  protected override def doExecute(): RDD[InternalRow] = {
+    throw new UnsupportedOperationException("doExecute() should not be called on CollectLimit()")
+  }
 }
 
 /**
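
CollectLimit now lives entirely on the driver-collection path: executeCollect() short-circuits to child.executeTake(limit), while doExecute(), the entry point reached through execute() by parent operators, .rdd() and caching, is deliberately unreachable. A simplified sketch of the two entry points it distinguishes (a hypothetical class, not the real SparkPlan):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow

// Hypothetical, stripped-down version of the SparkPlan contract CollectLimit
// relies on; the real class has many more members.
abstract class PhysicalOperatorSketch {
  // Driver-side fast path: collect()-style callers may use this directly,
  // skipping RDD materialization entirely.
  def executeCollect(): Array[InternalRow]

  // Distributed path: parent operators, .rdd() and caching reach this
  // through execute(). CollectLimit now throws here instead of faking a
  // one-partition RDD with makeRDD.
  protected def doExecute(): RDD[InternalRow]
}
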
@@ -182,12 +182,6 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
-  test("terminal limits use CollectLimit") {
-    val query = testData.select('value).limit(2)
-    val planned = query.queryExecution.sparkPlan
-    assert(planned.isInstanceOf[CollectLimit])
-  }
-
   test("PartitioningCollection") {
     withTempTable("normal", "small", "tiny") {
       testData.registerTempTable("normal")
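
The removed test asserted that query.queryExecution.sparkPlan is a CollectLimit, which can no longer hold: sparkPlan no longer sees ReturnAnswer. A hypothetical replacement, not part of this commit, would apply the marker explicitly, the way DataFrame.collect() now does; this assumes ReturnAnswer is imported from org.apache.spark.sql.catalyst.plans.logical and that the planner is reachable from the suite's package:

test("terminal limits plan to CollectLimit when the answer returns to the driver") {
  val query = testData.select('value).limit(2)
  val planned = sqlContext.planner.plan(
    ReturnAnswer(query.queryExecution.optimizedPlan)).next()
  assert(planned.isInstanceOf[CollectLimit])
}
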
