[SPARK-11932][STREAMING] Partition previous TrackStateRDD if partitioner not present

The reason is that TrackStateRDDs generated by trackStateByKey expect the previous batch's TrackStateRDDs to have a partitioner. However, when recovering from DStream checkpoints, the RDDs recovered from RDD checkpoints do not have a partitioner attached to them. This is because RDD checkpoints do not preserve the partitioner (SPARK-12004).

While #9983 solves SPARK-12004 by preserving the partitioner through RDD checkpoints, there is still a non-zero chance that saving and recovering the partitioner fails. To be resilient, this PR repartitions the previous state RDD if a partitioner is not detected.
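
For context, here is a rough standalone sketch (my own illustration, not part of this patch) of the trackStateByKey-plus-checkpointing path this change hardens. The socket source, port number, and checkpoint directory are made up, and the API is used as it stood around this commit:

import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext, Time}

// Running-count tracking function in the (Time, key, value, state) => Option[emitted]
// form that InternalTrackStateDStream invokes.
def trackingFunc(time: Time, key: String, value: Option[Int], state: State[Int]): Option[(String, Int)] = {
  val sum = value.getOrElse(0) + state.getOption().getOrElse(0)
  state.update(sum)
  Some((key, sum))
}

def createContext(): StreamingContext = {
  val ssc = new StreamingContext("local[2]", "TrackStateSketch", Seconds(1))
  ssc.checkpoint("/tmp/track-state-checkpoint")  // hypothetical checkpoint directory
  val wordCounts = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" ")).map((_, 1))
  // The state DStream below is what gets recovered from RDD checkpoints on restart;
  // before this patch the recovered state RDD could come back without a partitioner.
  wordCounts.trackStateByKey(StateSpec.function(trackingFunc _)).print()
  ssc
}

// On restart, getOrCreate rebuilds the DStream graph (and its state RDDs) from the checkpoint.
val ssc = StreamingContext.getOrCreate("/tmp/track-state-checkpoint", createContext _)
ssc.start()
ssc.awaitTermination()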

Author: Tathagata Das <[email protected]>

Closes #9988 from tdas/SPARK-11932.

(cherry picked from commit 5d80d8c)
Signed-off-by: Tathagata Das <[email protected]>
tdas committed Dec 7, 2015
1 parent fed4538 commit 539914f
Showing 6 changed files with 258 additions and 84 deletions.
@@ -277,7 +277,7 @@ class CheckpointWriter(
val bytes = Checkpoint.serialize(checkpoint, conf)
executor.execute(new CheckpointWriteHandler(
checkpoint.checkpointTime, bytes, clearCheckpointDataLater))
logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
} catch {
case rej: RejectedExecutionException =>
logError("Could not submit checkpoint task to the thread pool executor", rej)
@@ -132,22 +132,37 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
/** Method that generates a RDD for the given time */
override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
// Get the previous state or create a new empty state RDD
val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
TrackStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner, validTime
)
val prevStateRDD = getOrCompute(validTime - slideDuration) match {
case Some(rdd) =>
if (rdd.partitioner != Some(partitioner)) {
// If the RDD is not partitioned the right way, let us repartition it using the
// partition index as the key. This is to ensure that state RDD is always partitioned
// before creating another state RDD using it
TrackStateRDD.createFromRDD[K, V, S, E](
rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
} else {
rdd
}
case None =>
TrackStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner,
validTime
)
}


// Compute the new state RDD with previous state RDD and partitioned data RDD
parent.getOrCompute(validTime).map { dataRDD =>
val partitionedDataRDD = dataRDD.partitionBy(partitioner)
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
new TrackStateRDD(
prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)
// Even if there is no data RDD, use an empty one to create a new state RDD
val dataRDD = parent.getOrCompute(validTime).getOrElse {
context.sparkContext.emptyRDD[(K, V)]
}
val partitionedDataRDD = dataRDD.partitionBy(partitioner)
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
Some(new TrackStateRDD(
prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
}
}
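
To make the partitioner handling above concrete, here is a minimal standalone sketch (my own, assuming a local two-partition setup, not part of the patch) of the check the new compute() branch performs before reusing a previous or recovered state RDD:

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("PartitionerCheckSketch"))
val partitioner = new HashPartitioner(2)

// A freshly partitioned pair RDD carries its partitioner with it.
val partitioned = sc.parallelize(Seq(("a", 1), ("b", 2))).partitionBy(partitioner)
assert(partitioned.partitioner == Some(partitioner))

// An RDD rebuilt from a checkpoint file may come back with no partitioner at all,
// which is what compute() now guards against by re-keying and repartitioning.
val recovered = sc.parallelize(Seq(("a", 1), ("b", 2)))  // stands in for a recovered state RDD
val usable = if (recovered.partitioner != Some(partitioner)) {
  recovered.partitionBy(partitioner)  // rebuild the expected partitioning
} else {
  recovered
}
assert(usable.partitioner == Some(partitioner))
sc.stop()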

@@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:

private[streaming] object TrackStateRDD {

def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
pairRDD: RDD[(K, S)],
partitioner: Partitioner,
updateTime: Time): TrackStateRDD[K, V, S, T] = {
updateTime: Time): TrackStateRDD[K, V, S, E] = {

val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)

val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)

val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None

new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
}

def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
rdd: RDD[(K, S, Long)],
partitioner: Partitioner,
updateTime: Time): TrackStateRDD[K, V, S, E] = {

val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, (state, updateTime)) =>
stateMap.put(key, state, updateTime)
}
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)

val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)

val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None

new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
}
}
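
The new createFromRDD uses the same partitionBy-then-mapPartitions(preservesPartitioning = true) idiom as createFromPairRDD. A small standalone sketch (my own simplification, with a plain Map standing in for the StateMap) of why the resulting RDD keeps the partitioner:

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("PreservePartitioningSketch"))
val partitioner = new HashPartitioner(2)

// (key, state, updateTimeMs) triples, analogous to the input of createFromRDD.
val triples = sc.parallelize(Seq(("a", 1, 100L), ("b", 2, 100L), ("c", 3, 200L)))
val pairs = triples.map { x => (x._1, (x._2, x._3)) }

// Build one aggregate per partition, just as the real code builds one StateMap per partition.
// preservesPartitioning = true tells Spark the keys were not changed, so the child RDD keeps
// the HashPartitioner instead of dropping it.
val perPartitionState = pairs.partitionBy(partitioner).mapPartitions({ iter =>
  Iterator(iter.toMap)
}, preservesPartitioning = true)

assert(perPartitionState.partitioner == Some(partitioner))
sc.stop()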

@@ -33,17 +33,149 @@ import org.mockito.Mockito.mock
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

import org.apache.spark.TestUtils
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}

/**
* A trait that can be mixed in to get methods for testing DStream operations under
* DStream checkpointing. Note that implementations of this trait have to implement
* the `setupCheckpointOperation`
*/
trait DStreamCheckpointTester { self: SparkFunSuite =>

/**
* Tests a streaming operation under checkpointing, by restarting the operation
* from checkpoint file and verifying whether the final output is correct.
* The output is assumed to have come from a reliable queue which can replay
* data as required.
*
* NOTE: This takes into consideration that the last batch processed before
* master failure will be re-processed after restart/recovery.
*/
protected def testCheckpointedOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
numBatchesBeforeRestart: Int,
batchDuration: Duration = Milliseconds(500),
stopSparkContextAfterTest: Boolean = true
) {
require(numBatchesBeforeRestart < expectedOutput.size,
"Number of batches before context restart less than number of expected output " +
"(i.e. number of total batches to run)")
require(StreamingContext.getActive().isEmpty,
"Cannot run test with already active streaming context")

// Current code assumes that number of batches to be run = number of inputs
val totalNumBatches = input.size
val batchDurationMillis = batchDuration.milliseconds

// Setup the stream computation
val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString
logDebug(s"Using checkpoint directory $checkpointDir")
val ssc = createContextForCheckpointOperation(batchDuration)
require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName,
"Cannot run test without manual clock in the conf")

val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
val operatedStream = operation(inputStream)
operatedStream.print()
val outputStream = new TestOutputStreamWithPartitions(operatedStream,
new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
outputStream.register()
ssc.checkpoint(checkpointDir)

// Do the computation for initial number of batches, create checkpoint file and quit
val beforeRestartOutput = generateOutput[V](ssc,
Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest)
assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true)
// Restart and complete the computation from checkpoint file
logInfo(
"\n-------------------------------------------\n" +
" Restarting stream computation " +
"\n-------------------------------------------\n"
)

val restartedSsc = new StreamingContext(checkpointDir)
val afterRestartOutput = generateOutput[V](restartedSsc,
Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest)
assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false)
}

protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = {
val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
conf.set("spark.streaming.clock", classOf[ManualClock].getName())
new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
}

private def generateOutput[V: ClassTag](
ssc: StreamingContext,
targetBatchTime: Time,
checkpointDir: String,
stopSparkContext: Boolean
): Seq[Seq[V]] = {
try {
val batchDuration = ssc.graph.batchDuration
val batchCounter = new BatchCounter(ssc)
ssc.start()
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val currentTime = clock.getTimeMillis()

logInfo("Manual clock before advancing = " + clock.getTimeMillis())
clock.setTime(targetBatchTime.milliseconds)
logInfo("Manual clock after advancing = " + clock.getTimeMillis())

val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
}.head.asInstanceOf[TestOutputStreamWithPartitions[V]]

eventually(timeout(10 seconds)) {
ssc.awaitTerminationOrTimeout(10)
assert(batchCounter.getLastCompletedBatchTime === targetBatchTime)
}

eventually(timeout(10 seconds)) {
val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter {
_.toString.contains(clock.getTimeMillis.toString)
}
// Checkpoint files are written twice for every batch interval. So assert that both
// of them have been written.
assert(checkpointFilesOfLatestTime.size === 2)
}
outputStream.output.map(_.flatten)

} finally {
ssc.stop(stopSparkContext = stopSparkContext)
}
}

private def assertOutput[V: ClassTag](
output: Seq[Seq[V]],
expectedOutput: Seq[Seq[V]],
beforeRestart: Boolean): Unit = {
val expectedPartialOutput = if (beforeRestart) {
expectedOutput.take(output.size)
} else {
expectedOutput.takeRight(output.size)
}
val setComparison = output.zip(expectedPartialOutput).forall {
case (o, e) => o.toSet === e.toSet
}
assert(setComparison, s"set comparison failed\n" +
s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" +
s"Generated output items: ${output.mkString("\n")}"
)
}
}
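
Below is a hypothetical usage sketch (not taken from this patch) of how a suite mixing in DStreamCheckpointTester might call testCheckpointedOperation; the operation, input, and expected output are invented for illustration, and the class would live alongside CheckpointSuite in the streaming test sources:

import org.apache.spark.SparkFunSuite
import org.apache.spark.streaming.dstream.DStream

class MapAcrossRestartSuite extends SparkFunSuite with DStreamCheckpointTester {
  test("map operation recovers correctly from a checkpoint") {
    testCheckpointedOperation(
      input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)),
      operation = (s: DStream[Int]) => s.map(_ * 10),
      expectedOutput = Seq(Seq(10), Seq(20), Seq(30), Seq(40)),
      numBatchesBeforeRestart = 2  // run two batches, checkpoint, restart, then run the rest
    )
  }
}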

/**
* This test suite tests the checkpointing functionality of DStreams -
* the checkpointing of a DStream's RDDs as well as the checkpointing of
* the whole DStream graph.
*/
class CheckpointSuite extends TestSuiteBase {
class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {

var ssc: StreamingContext = null

@@ -56,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase {

override def afterFunction() {
super.afterFunction()
if (ssc != null) ssc.stop()
if (ssc != null) { ssc.stop() }
Utils.deleteRecursively(new File(checkpointDir))
}

@@ -251,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase {
Seq(("", 2)),
Seq(),
Seq(("a", 2), ("b", 1)),
Seq(("", 2)), Seq() ),
Seq(("", 2)),
Seq()
),
3
)
}
@@ -634,53 +768,6 @@ class CheckpointSuite extends TestSuiteBase {
checkpointWriter.stop()
}

/**
* Tests a streaming operation under checkpointing, by restarting the operation
* from checkpoint file and verifying whether the final output is correct.
* The output is assumed to have come from a reliable queue which can replay
* data as required.
*
* NOTE: This takes into consideration that the last batch processed before
* master failure will be re-processed after restart/recovery.
*/
def testCheckpointedOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
initialNumBatches: Int
) {

// Current code assumes that:
// number of inputs = number of outputs = number of batches to be run
val totalNumBatches = input.size
val nextNumBatches = totalNumBatches - initialNumBatches
val initialNumExpectedOutputs = initialNumBatches
val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
// because the last batch will be processed again

// Do the computation for initial number of batches, create checkpoint file and quit
ssc = setupStreams[U, V](input, operation)
ssc.start()
val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches)
ssc.stop()
verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
Thread.sleep(1000)

// Restart and complete the computation from checkpoint file
logInfo(
"\n-------------------------------------------\n" +
" Restarting stream computation " +
"\n-------------------------------------------\n"
)
ssc = new StreamingContext(checkpointDir)
ssc.start()
val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
// the first element will be re-processed data of the last batch before restart
verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
ssc.stop()
ssc = null
}

/**
* Advances the manual clock on the streaming scheduler by the given number of batches.
* It also waits for the expected amount of time for each batch.
@@ -142,6 +142,7 @@ class BatchCounter(ssc: StreamingContext) {
// All access to this state should be guarded by `BatchCounter.this.synchronized`
private var numCompletedBatches = 0
private var numStartedBatches = 0
private var lastCompletedBatchTime: Time = null

private val listener = new StreamingListener {
override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit =
@@ -152,6 +153,7 @@
override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit =
BatchCounter.this.synchronized {
numCompletedBatches += 1
lastCompletedBatchTime = batchCompleted.batchInfo.batchTime
BatchCounter.this.notifyAll()
}
}
@@ -165,6 +167,10 @@
numStartedBatches
}

def getLastCompletedBatchTime: Time = this.synchronized {
lastCompletedBatchTime
}

/**
* Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if
* `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's