diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala index 02e6625b2028c..20f13c280c12c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.adaptive +import java.util.Properties + import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration -import org.apache.spark.MapOutputStatistics -import org.apache.spark.broadcast +import org.apache.spark.{broadcast, MapOutputStatistics, SparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -48,12 +49,22 @@ abstract class QueryStage extends UnaryExecNode { override def outputOrdering: Seq[SortOrder] = child.outputOrdering + def withLocalProperties[T](sc: SparkContext, properties: Properties)(body: => T): T = { + val oldProperties = sc.getLocalProperties + try { + sc.setLocalProperties(properties) + body + } finally { + sc.setLocalProperties(oldProperties) + } + } + /** * Execute childStages and wait until all stages are completed. Use a thread pool to avoid * blocking on one child stage. */ def executeChildStages(): Unit = { - val executionId = sqlContext.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + val localProperties = sqlContext.sparkContext.getLocalProperties // Handle broadcast stages val broadcastQueryStages: Seq[BroadcastQueryStage] = child.collect { @@ -61,7 +72,7 @@ abstract class QueryStage extends UnaryExecNode { } val broadcastFutures = broadcastQueryStages.map { queryStage => Future { - SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { + withLocalProperties(sqlContext.sparkContext, localProperties) { queryStage.prepareBroadcast() } }(QueryStage.executionContext) @@ -73,7 +84,7 @@ abstract class QueryStage extends UnaryExecNode { } val shuffleStageFutures = shuffleQueryStages.map { queryStage => Future { - SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { + withLocalProperties(sqlContext.sparkContext, localProperties) { queryStage.execute() } }(QueryStage.executionContext)