Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-11932][STREAMING] Partition previous TrackStateRDD if partitioner not present #9988

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class CheckpointWriter(
val bytes = Checkpoint.serialize(checkpoint, conf)
executor.execute(new CheckpointWriteHandler(
checkpoint.checkpointTime, bytes, clearCheckpointDataLater))
logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self: revert this

} catch {
case rej: RejectedExecutionException =>
logError("Could not submit checkpoint task to the thread pool executor", rej)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,37 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
/** Method that generates a RDD for the given time */
override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
// Get the previous state or create a new empty state RDD
val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
TrackStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner, validTime
)
val prevStateRDD = getOrCompute(validTime - slideDuration) match {
case Some(rdd) =>
if (rdd.partitioner != Some(partitioner)) {
// If the RDD is not partitioned the right way, let us repartition it using the
// partition index as the key. This is to ensure that state RDD is always partitioned
// before creating another state RDD using it
TrackStateRDD.createFromRDD[K, V, S, E](
rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
} else {
rdd
}
case None =>
TrackStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner,
validTime
)
}


// Compute the new state RDD with previous state RDD and partitioned data RDD
parent.getOrCompute(validTime).map { dataRDD =>
val partitionedDataRDD = dataRDD.partitionBy(partitioner)
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
new TrackStateRDD(
prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)
// Even if there is no data RDD, use an empty one to create a new state RDD
val dataRDD = parent.getOrCompute(validTime).getOrElse {
context.sparkContext.emptyRDD[(K, V)]
}
val partitionedDataRDD = dataRDD.partitionBy(partitioner)
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
Some(new TrackStateRDD(
prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:

private[streaming] object TrackStateRDD {

def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
pairRDD: RDD[(K, S)],
partitioner: Partitioner,
updateTime: Time): TrackStateRDD[K, V, S, T] = {
updateTime: Time): TrackStateRDD[K, V, S, E] = {

val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)

val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)

val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None

new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
}

def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
rdd: RDD[(K, S, Long)],
partitioner: Partitioner,
updateTime: Time): TrackStateRDD[K, V, S, E] = {

val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, (state, updateTime)) =>
stateMap.put(key, state, updateTime)
}
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)

val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)

val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None

new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
}
}

Expand Down
1 change: 1 addition & 0 deletions streaming/src/test/resources/log4j.properties
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{

# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.spark-project.jetty=WARN
log4j.appender.org.apache.spark.streaming=DEBUG
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should revert this


Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,149 @@ import org.mockito.Mockito.mock
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

import org.apache.spark.TestUtils
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}

/**
* A trait of that can be mixed in to get methods for testing DStream operations under
* DStream checkpointing. Note that the implementations of this trait has to implement
* the `setupCheckpointOperation`
*/
trait DStreamCheckpointTester { self: SparkFunSuite =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is refactoring where I extract out the testCheckpointedOperation so that it can be used in other unit tests.


/**
* Tests a streaming operation under checkpointing, by restarting the operation
* from checkpoint file and verifying whether the final output is correct.
* The output is assumed to have come from a reliable queue which an replay
* data as required.
*
* NOTE: This takes into consideration that the last batch processed before
* master failure will be re-processed after restart/recovery.
*/
protected def testCheckpointedOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
numBatchesBeforeRestart: Int,
batchDuration: Duration = Milliseconds(500),
stopSparkContextAfterTest: Boolean = true
) {
require(numBatchesBeforeRestart < expectedOutput.size,
"Number of batches before context restart less than number of expected output " +
"(i.e. number of total batches to run)")
require(StreamingContext.getActive().isEmpty,
"Cannot run test with already active streaming context")

// Current code assumes that number of batches to be run = number of inputs
val totalNumBatches = input.size
val batchDurationMillis = batchDuration.milliseconds

// Setup the stream computation
val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString
logDebug(s"Using checkpoint directory $checkpointDir")
val ssc = createContextForCheckpointOperation(batchDuration)
require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName,
"Cannot run test without manual clock in the conf")

val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
val operatedStream = operation(inputStream)
operatedStream.print()
val outputStream = new TestOutputStreamWithPartitions(operatedStream,
new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
outputStream.register()
ssc.checkpoint(checkpointDir)

// Do the computation for initial number of batches, create checkpoint file and quit
val beforeRestartOutput = generateOutput[V](ssc,
Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest)
assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true)
// Restart and complete the computation from checkpoint file
logInfo(
"\n-------------------------------------------\n" +
" Restarting stream computation " +
"\n-------------------------------------------\n"
)

val restartedSsc = new StreamingContext(checkpointDir)
val afterRestartOutput = generateOutput[V](restartedSsc,
Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest)
assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false)
}

protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = {
val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
conf.set("spark.streaming.clock", classOf[ManualClock].getName())
new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
}

private def generateOutput[V: ClassTag](
ssc: StreamingContext,
targetBatchTime: Time,
checkpointDir: String,
stopSparkContext: Boolean
): Seq[Seq[V]] = {
try {
val batchDuration = ssc.graph.batchDuration
val batchCounter = new BatchCounter(ssc)
ssc.start()
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val currentTime = clock.getTimeMillis()

logInfo("Manual clock before advancing = " + clock.getTimeMillis())
clock.setTime(targetBatchTime.milliseconds)
logInfo("Manual clock after advancing = " + clock.getTimeMillis())

val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
}.head.asInstanceOf[TestOutputStreamWithPartitions[V]]

eventually(timeout(10 seconds)) {
ssc.awaitTerminationOrTimeout(10)
assert(batchCounter.getLastCompletedBatchTime === targetBatchTime)
}

eventually(timeout(10 seconds)) {
val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter {
_.toString.contains(clock.getTimeMillis.toString)
}
// Checkpoint files are written twice for every batch interval. So assert that both
// are written to make sure that both of them have been written.
assert(checkpointFilesOfLatestTime.size === 2)
}
outputStream.output.map(_.flatten)

} finally {
ssc.stop(stopSparkContext = stopSparkContext)
}
}

private def assertOutput[V: ClassTag](
output: Seq[Seq[V]],
expectedOutput: Seq[Seq[V]],
beforeRestart: Boolean): Unit = {
val expectedPartialOutput = if (beforeRestart) {
expectedOutput.take(output.size)
} else {
expectedOutput.takeRight(output.size)
}
val setComparison = output.zip(expectedPartialOutput).forall {
case (o, e) => o.toSet === e.toSet
}
assert(setComparison, s"set comparison failed\n" +
s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" +
s"Generated output items: ${output.mkString("\n")}"
)
}
}

/**
* This test suites tests the checkpointing functionality of DStreams -
* the checkpointing of a DStream's RDDs as well as the checkpointing of
* the whole DStream graph.
*/
class CheckpointSuite extends TestSuiteBase {
class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {

var ssc: StreamingContext = null

Expand All @@ -56,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase {

override def afterFunction() {
super.afterFunction()
if (ssc != null) ssc.stop()
if (ssc != null) { ssc.stop() }
Utils.deleteRecursively(new File(checkpointDir))
}

Expand Down Expand Up @@ -251,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase {
Seq(("", 2)),
Seq(),
Seq(("a", 2), ("b", 1)),
Seq(("", 2)), Seq() ),
Seq(("", 2)),
Seq()
),
3
)
}
Expand Down Expand Up @@ -634,53 +768,6 @@ class CheckpointSuite extends TestSuiteBase {
checkpointWriter.stop()
}

/**
* Tests a streaming operation under checkpointing, by restarting the operation
* from checkpoint file and verifying whether the final output is correct.
* The output is assumed to have come from a reliable queue which an replay
* data as required.
*
* NOTE: This takes into consideration that the last batch processed before
* master failure will be re-processed after restart/recovery.
*/
def testCheckpointedOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
initialNumBatches: Int
) {

// Current code assumes that:
// number of inputs = number of outputs = number of batches to be run
val totalNumBatches = input.size
val nextNumBatches = totalNumBatches - initialNumBatches
val initialNumExpectedOutputs = initialNumBatches
val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
// because the last batch will be processed again

// Do the computation for initial number of batches, create checkpoint file and quit
ssc = setupStreams[U, V](input, operation)
ssc.start()
val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches)
ssc.stop()
verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
Thread.sleep(1000)

// Restart and complete the computation from checkpoint file
logInfo(
"\n-------------------------------------------\n" +
" Restarting stream computation " +
"\n-------------------------------------------\n"
)
ssc = new StreamingContext(checkpointDir)
ssc.start()
val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
// the first element will be re-processed data of the last batch before restart
verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
ssc.stop()
ssc = null
}

/**
* Advances the manual clock on the streaming scheduler by given number of batches.
* It also waits for the expected amount of time for each batch.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class BatchCounter(ssc: StreamingContext) {
// All access to this state should be guarded by `BatchCounter.this.synchronized`
private var numCompletedBatches = 0
private var numStartedBatches = 0
private var lastCompletedBatchTime: Time = null

private val listener = new StreamingListener {
override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit =
Expand All @@ -152,6 +153,7 @@ class BatchCounter(ssc: StreamingContext) {
override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit =
BatchCounter.this.synchronized {
numCompletedBatches += 1
lastCompletedBatchTime = batchCompleted.batchInfo.batchTime
BatchCounter.this.notifyAll()
}
}
Expand All @@ -165,6 +167,10 @@ class BatchCounter(ssc: StreamingContext) {
numStartedBatches
}

def getLastCompletedBatchTime: Time = this.synchronized {
lastCompletedBatchTime
}

/**
* Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if
* `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's
Expand Down
Loading