diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 8ade4350087bf..21f92382043cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -252,6 +252,7 @@ class StreamExecution(
    */
   private def runBatches(): Unit = {
     try {
+      sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString)
       if (sparkSession.sessionState.conf.streamingMetricsEnabled) {
         sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics)
       }
@@ -308,6 +309,7 @@ class StreamExecution(
             logDebug(s"batch ${currentBatchId} committed")
             // We'll increase currentBatchId after we complete processing current batch's data
             currentBatchId += 1
+            sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
           } else {
             currentStatus = currentStatus.copy(isDataAvailable = false)
             updateStatusMessage("Waiting for data to arrive")
@@ -418,6 +420,7 @@ class StreamExecution(
         /* First assume that we are re-executing the latest known batch
          * in the offset log */
         currentBatchId = latestBatchId
+        sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
          * is the second latest batch id in the offset log. */
@@ -463,6 +466,7 @@ class StreamExecution(
           }
         }
         currentBatchId = latestCommittedBatchId + 1
+        sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
         committedOffsets ++= availableOffsets
         // Construct a new batch be recomputing availableOffsets
         constructNextBatch()
@@ -478,6 +482,7 @@ class StreamExecution(
       case None => // We are starting this stream for the first time.
         logInfo(s"Starting new streaming query.")
         currentBatchId = 0
+        sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
         constructNextBatch()
     }
   }
@@ -590,8 +595,6 @@ class StreamExecution(
    * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with.
    */
   private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = {
-    sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString)
-
     // Request unprocessed data from all sources.
     newData = reportTimeTaken("getBatch") {
       availableOffsets.flatMap {
@@ -686,8 +689,11 @@ class StreamExecution(
     // intentionally
     state.set(TERMINATED)
     if (microBatchThread.isAlive) {
+      sparkSession.sparkContext.cancelJobGroup(runId.toString)
       microBatchThread.interrupt()
       microBatchThread.join()
+      // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak
+      sparkSession.sparkContext.cancelJobGroup(runId.toString)
     }
     logInfo(s"Query $prettyIdString was stopped")
   }
@@ -828,7 +834,9 @@ class StreamExecution(
   }

   private def getBatchDescriptionString: String = {
-    Option(name).map(_ + " ").getOrElse("") + s"[batch = $currentBatchId, id = $id, runId = $runId]"
+    val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString
+    Option(name).map(_ + " ").getOrElse("") +
+      s"[batch = $batchDescription, id = $id, runId = $runId]"
   }
 }