From 5d80d8c6a54b2113022eff31187e6d97521bd2cf Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Dec 2015 11:03:59 -0800 Subject: [PATCH] [SPARK-11932][STREAMING] Partition previous TrackStateRDD if partitioner not present The reason is that TrackStateRDDs generated by trackStateByKey expect the previous batch's TrackStateRDDs to have a partitioner. However, when recovery from DStream checkpoints, the RDDs recovered from RDD checkpoints do not have a partitioner attached to it. This is because RDD checkpoints do not preserve the partitioner (SPARK-12004). While #9983 solves SPARK-12004 by preserving the partitioner through RDD checkpoints, there may be a non-zero chance that the saving and recovery fails. To be resilient, this PR repartitions the previous state RDD if the partitioner is not detected. Author: Tathagata Das Closes #9988 from tdas/SPARK-11932. --- .../apache/spark/streaming/Checkpoint.scala | 2 +- .../streaming/dstream/TrackStateDStream.scala | 39 ++-- .../spark/streaming/rdd/TrackStateRDD.scala | 29 ++- .../spark/streaming/CheckpointSuite.scala | 189 +++++++++++++----- .../spark/streaming/TestSuiteBase.scala | 6 + .../streaming/TrackStateByKeySuite.scala | 77 +++++-- 6 files changed, 258 insertions(+), 84 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index fd0e8d5d690b6..d0046afdeb447 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -277,7 +277,7 @@ class CheckpointWriter( val bytes = Checkpoint.serialize(checkpoint, conf) executor.execute(new CheckpointWriteHandler( checkpoint.checkpointTime, bytes, clearCheckpointDataLater)) - logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") + logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => logError("Could not submit checkpoint task to the thread pool executor", rej) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 0ada1111ce30a..ea6213420e7ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -132,22 +132,37 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT /** Method that generates a RDD for the given time */ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD - val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { - TrackStateRDD.createFromPairRDD[K, V, S, E]( - spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), - partitioner, validTime - ) + val prevStateRDD = getOrCompute(validTime - slideDuration) match { + case Some(rdd) => + if (rdd.partitioner != Some(partitioner)) { + // If the RDD is not partitioned the right way, let us repartition it using the + // partition index as the key. This is to ensure that state RDD is always partitioned + // before creating another state RDD using it + TrackStateRDD.createFromRDD[K, V, S, E]( + rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime) + } else { + rdd + } + case None => + TrackStateRDD.createFromPairRDD[K, V, S, E]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, + validTime + ) } + // Compute the new state RDD with previous state RDD and partitioned data RDD - parent.getOrCompute(validTime).map { dataRDD => - val partitionedDataRDD = dataRDD.partitionBy(partitioner) - val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => - (validTime - interval).milliseconds - } - new TrackStateRDD( - prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime) + // Even if there is no data RDD, use an empty one to create a new state RDD + val dataRDD = parent.getOrCompute(validTime).getOrElse { + context.sparkContext.emptyRDD[(K, V)] + } + val partitionedDataRDD = dataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds } + Some(new TrackStateRDD( + prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index 7050378d0feb0..30aafcf1460e3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: private[streaming] object TrackStateRDD { - def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( pairRDD: RDD[(K, S)], partitioner: Partitioner, - updateTime: Time): TrackStateRDD[K, V, S, T] = { + updateTime: Time): TrackStateRDD[K, V, S, E] = { val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } - Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) }, preservesPartitioning = true) val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None - new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + } + + def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + rdd: RDD[(K, S, Long)], + partitioner: Partitioner, + updateTime: Time): TrackStateRDD[K, V, S, E] = { + + val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) } + val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator => + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, (state, updateTime)) => + stateMap.put(key, state, updateTime) + } + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) + }, preservesPartitioning = true) + + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) + + val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None + + new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index b1cbc7163bee3..cd28d3cf408d5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -33,17 +33,149 @@ import org.mockito.Mockito.mock import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.TestUtils +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} +/** + * A trait of that can be mixed in to get methods for testing DStream operations under + * DStream checkpointing. Note that the implementations of this trait has to implement + * the `setupCheckpointOperation` + */ +trait DStreamCheckpointTester { self: SparkFunSuite => + + /** + * Tests a streaming operation under checkpointing, by restarting the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + * + * NOTE: This takes into consideration that the last batch processed before + * master failure will be re-processed after restart/recovery. + */ + protected def testCheckpointedOperation[U: ClassTag, V: ClassTag]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatchesBeforeRestart: Int, + batchDuration: Duration = Milliseconds(500), + stopSparkContextAfterTest: Boolean = true + ) { + require(numBatchesBeforeRestart < expectedOutput.size, + "Number of batches before context restart less than number of expected output " + + "(i.e. number of total batches to run)") + require(StreamingContext.getActive().isEmpty, + "Cannot run test with already active streaming context") + + // Current code assumes that number of batches to be run = number of inputs + val totalNumBatches = input.size + val batchDurationMillis = batchDuration.milliseconds + + // Setup the stream computation + val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + logDebug(s"Using checkpoint directory $checkpointDir") + val ssc = createContextForCheckpointOperation(batchDuration) + require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, + "Cannot run test without manual clock in the conf") + + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val operatedStream = operation(inputStream) + operatedStream.print() + val outputStream = new TestOutputStreamWithPartitions(operatedStream, + new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]]) + outputStream.register() + ssc.checkpoint(checkpointDir) + + // Do the computation for initial number of batches, create checkpoint file and quit + val beforeRestartOutput = generateOutput[V](ssc, + Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest) + assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true) + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + + val restartedSsc = new StreamingContext(checkpointDir) + val afterRestartOutput = generateOutput[V](restartedSsc, + Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest) + assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false) + } + + protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = { + val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) + } + + private def generateOutput[V: ClassTag]( + ssc: StreamingContext, + targetBatchTime: Time, + checkpointDir: String, + stopSparkContext: Boolean + ): Seq[Seq[V]] = { + try { + val batchDuration = ssc.graph.batchDuration + val batchCounter = new BatchCounter(ssc) + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val currentTime = clock.getTimeMillis() + + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) + clock.setTime(targetBatchTime.milliseconds) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) + + val outputStream = ssc.graph.getOutputStreams().filter { dstream => + dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] + }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + + eventually(timeout(10 seconds)) { + ssc.awaitTerminationOrTimeout(10) + assert(batchCounter.getLastCompletedBatchTime === targetBatchTime) + } + + eventually(timeout(10 seconds)) { + val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter { + _.toString.contains(clock.getTimeMillis.toString) + } + // Checkpoint files are written twice for every batch interval. So assert that both + // are written to make sure that both of them have been written. + assert(checkpointFilesOfLatestTime.size === 2) + } + outputStream.output.map(_.flatten) + + } finally { + ssc.stop(stopSparkContext = stopSparkContext) + } + } + + private def assertOutput[V: ClassTag]( + output: Seq[Seq[V]], + expectedOutput: Seq[Seq[V]], + beforeRestart: Boolean): Unit = { + val expectedPartialOutput = if (beforeRestart) { + expectedOutput.take(output.size) + } else { + expectedOutput.takeRight(output.size) + } + val setComparison = output.zip(expectedPartialOutput).forall { + case (o, e) => o.toSet === e.toSet + } + assert(setComparison, s"set comparison failed\n" + + s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" + + s"Generated output items: ${output.mkString("\n")}" + ) + } +} + /** * This test suites tests the checkpointing functionality of DStreams - * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. */ -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { var ssc: StreamingContext = null @@ -56,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase { override def afterFunction() { super.afterFunction() - if (ssc != null) ssc.stop() + if (ssc != null) { ssc.stop() } Utils.deleteRecursively(new File(checkpointDir)) } @@ -251,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase { Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), - Seq(("", 2)), Seq() ), + Seq(("", 2)), + Seq() + ), 3 ) } @@ -634,53 +768,6 @@ class CheckpointSuite extends TestSuiteBase { checkpointWriter.stop() } - /** - * Tests a streaming operation under checkpointing, by restarting the operation - * from checkpoint file and verifying whether the final output is correct. - * The output is assumed to have come from a reliable queue which an replay - * data as required. - * - * NOTE: This takes into consideration that the last batch processed before - * master failure will be re-processed after restart/recovery. - */ - def testCheckpointedOperation[U: ClassTag, V: ClassTag]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { - - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 - // because the last batch will be processed again - - // Do the computation for initial number of batches, create checkpoint file and quit - ssc = setupStreams[U, V](input, operation) - ssc.start() - val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) - ssc.stop() - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) - - // Restart and complete the computation from checkpoint file - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation " + - "\n-------------------------------------------\n" - ) - ssc = new StreamingContext(checkpointDir) - ssc.start() - val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) - // the first element will be re-processed data of the last batch before restart - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) - ssc.stop() - ssc = null - } - /** * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index a45c92d9c7bc8..be0f4636a6cb8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -142,6 +142,7 @@ class BatchCounter(ssc: StreamingContext) { // All access to this state should be guarded by `BatchCounter.this.synchronized` private var numCompletedBatches = 0 private var numStartedBatches = 0 + private var lastCompletedBatchTime: Time = null private val listener = new StreamingListener { override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = @@ -152,6 +153,7 @@ class BatchCounter(ssc: StreamingContext) { override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = BatchCounter.this.synchronized { numCompletedBatches += 1 + lastCompletedBatchTime = batchCompleted.batchInfo.batchTime BatchCounter.this.notifyAll() } } @@ -165,6 +167,10 @@ class BatchCounter(ssc: StreamingContext) { numStartedBatches } + def getLastCompletedBatchTime: Time = this.synchronized { + lastCompletedBatchTime + } + /** * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if * `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index 58aef74c0040f..1fc320d31b18b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -25,31 +25,27 @@ import scala.reflect.ClassTag import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { +class TrackStateByKeySuite extends SparkFunSuite + with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { private var sc: SparkContext = null - private var ssc: StreamingContext = null - private var checkpointDir: File = null - private val batchDuration = Seconds(1) + protected var checkpointDir: File = null + protected val batchDuration = Seconds(1) before { - StreamingContext.getActive().foreach { - _.stop(stopSparkContext = false) - } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } checkpointDir = Utils.createTempDir("checkpoint") - - ssc = new StreamingContext(sc, batchDuration) - ssc.checkpoint(checkpointDir.toString) } after { - StreamingContext.getActive().foreach { - _.stop(stopSparkContext = false) + if (checkpointDir != null) { + Utils.deleteRecursively(checkpointDir) } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } } override def beforeAll(): Unit = { @@ -242,7 +238,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(dstreamImpl.stateClass === classOf[Double]) assert(dstreamImpl.emittedClass === classOf[Long]) } - + val ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types @@ -451,8 +447,9 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef expectedCheckpointDuration: Duration, explicitCheckpointDuration: Option[Duration] = None ): Unit = { + val ssc = new StreamingContext(sc, batchDuration) + try { - ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) val dummyFunc = (value: Option[Int], state: State[Int]) => 0 val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) @@ -462,11 +459,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef trackStateStream.checkpoint(d) } trackStateStream.register() + ssc.checkpoint(checkpointDir.toString) ssc.start() // should initialize all the checkpoint durations assert(trackStateStream.checkpointDuration === null) assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) } finally { - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + ssc.stop(stopSparkContext = false) } } @@ -479,6 +477,50 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) } + + test("trackStateByKey - driver failure recovery") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + def operation(dstream: DStream[String]): DStream[(String, Int)] = { + + val checkpointDuration = batchDuration * (stateData.size / 2) + + val runningCount = (value: Option[Int], state: State[Int]) => { + state.update(state.getOption().getOrElse(0) + value.getOrElse(0)) + state.get() + } + + val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey( + StateSpec.function(runningCount)) + // Set internval make sure there is one RDD checkpointing + trackStateStream.checkpoint(checkpointDuration) + trackStateStream.stateSnapshots() + } + + testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2, + batchDuration = batchDuration, stopSparkContextAfterTest = false) + } + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], trackStateSpec: StateSpec[K, Int, S, T], @@ -500,6 +542,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { // Setup the stream computation + val ssc = new StreamingContext(sc, Seconds(1)) val inputStream = new TestInputStream(ssc, input, numPartitions = 2) val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] @@ -511,12 +554,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef stateSnapshotStream.register() val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.toString) ssc.start() val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] clock.advance(batchDuration.milliseconds * numBatches) batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + ssc.stop(stopSparkContext = false) (collectedOutputs, collectedStateSnapshots) }