From 3cf6028a63beb798bcf4e23cb34bc8223b606478 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 30 Oct 2014 17:15:01 -0700 Subject: [PATCH 1/2] Moved FaiureSuite.scala to DriverFailureSuite.scala --- .../streaming/{FailureSuite.scala => DriverFailureSuite.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename streaming/src/test/scala/org/apache/spark/streaming/{FailureSuite.scala => DriverFailureSuite.scala} (100%) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala similarity index 100% rename from streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala rename to streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala From dd1436521e228d42db8d4a435b37284d02c35591 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 30 Oct 2014 17:18:18 -0700 Subject: [PATCH 2/2] Added DriverFailureTest and ReceiverBasedDriverFailureTest. --- .../spark/streaming/DriverFailureSuite.scala | 88 +++++++- .../streaming/util/DriverFailureTest.scala | 200 ++++++++++++++++++ .../util/ReceiverBasedDriverFailureTest.scala | 148 +++++++++++++ 3 files changed, 425 insertions(+), 11 deletions(-) create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/util/DriverFailureTest.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/util/ReceiverBasedDriverFailureTest.scala diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala index 40434b1f9b709..d807b9aaada71 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala @@ -17,18 +17,23 @@ package org.apache.spark.streaming -import org.apache.spark.Logging -import org.apache.spark.util.Utils - import java.io.File +import com.google.common.io.Files + +import org.apache.spark.{HashPartitioner, Logging} +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.util.{DriverFailureTestReceiver, ReceiverBasedDriverFailureTest} +import org.apache.spark.util.Utils + /** - * This testsuite tests master failures at random times while the stream is running using - * the real clock. + * This testsuite tests driver failures by explicitly stopping the streaming context at random + * times while the stream is running using the real clock. */ -class FailureSuite extends TestSuiteBase with Logging { +class DriverFailureSuite extends TestSuiteBase with Logging { - var directory = "FailureSuite" + var tempDir: String = null val numBatches = 30 override def batchDuration = Milliseconds(1000) @@ -37,20 +42,81 @@ class FailureSuite extends TestSuiteBase with Logging { override def beforeFunction() { super.beforeFunction() - Utils.deleteRecursively(new File(directory)) + tempDir = Files.createTempDir().toString } override def afterFunction() { super.afterFunction() - Utils.deleteRecursively(new File(directory)) + Utils.deleteRecursively(new File(tempDir)) } test("multiple failures with map") { - MasterFailureTest.testMap(directory, numBatches, batchDuration) + MasterFailureTest.testMap(tempDir, numBatches, batchDuration) } test("multiple failures with updateStateByKey") { - MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration) + MasterFailureTest.testUpdateStateByKey(tempDir, numBatches, batchDuration) + } + + // TODO (TD): Explain how the test works + test("multiple failures with receiver and updateStateByKey") { + + // Define the DStream operation to test under driver failures + val operation = (st: DStream[String]) => { + + val mapPartitionFunc = (iterator: Iterator[String]) => { + Iterator(iterator.flatMap(_.split(" ")).map(_ -> 1L).reduce((x, y) => (x._1, x._2 + y._2))) + } + + val updateFunc = (iterator: Iterator[(String, Seq[Long], Option[Seq[Long]])]) => { + iterator.map { case (key, values, state) => + val combined = (state.getOrElse(Seq.empty) ++ values).sorted + if (state.isEmpty || state.get.max != DriverFailureTestReceiver.maxRecordsPerBlock) { + val oldState = s"[${ state.map { _.max }.getOrElse(-1) }, ${state.map { _.distinct.sum }.getOrElse(0)}]" + val newState = s"[${combined.max}, ${combined.distinct.sum}]" + println(s"Updated state for $key: state = $oldState, new values = $values, new state = $newState") + } + (key, combined) + } + } + + st.mapPartitions(mapPartitionFunc) + .updateStateByKey[Seq[Long]](updateFunc, new HashPartitioner(2), rememberPartitioner = false) + .checkpoint(batchDuration * 5) + } + + val maxValue = DriverFailureTestReceiver.maxRecordsPerBlock + val expectedValues = (1L to maxValue).toSet + + // Define the function to verify the output of the DStream operation + val verify = (time: Time, output: Seq[(String, Seq[Long])]) => { + val outputStr = output.map { x => (x._1, x._2.distinct.sum) }.sortBy(_._1).mkString(", ") + println(s"State at $time: $outputStr") + + val incompletelyReceivedWords = output.filter { _._2.max < maxValue } + if (incompletelyReceivedWords.size > 1) { + val debugStr = incompletelyReceivedWords.map { x => + s"""${x._1}: ${x._2.mkString(",")}, sum = ${x._2.distinct.sum}""" + }.mkString("\n") + throw new Exception(s"Incorrect processing of input, all input not processed:\n$debugStr\n") + } + + output.foreach { case (key, values) => + if (!values.forall(expectedValues.contains)) { + val sum = values.distinct.sum + val debugStr = values.zip(1L to values.size).map { + x => if (x._1 == x._2) x._1 else s"[${x._2}]" + }.mkString(",") + s", sum = $sum" + throw new Exception(s"Incorrect sequence of values in output:\n$debugStr\n") + } + } + } + + val driverTest = new ReceiverBasedDriverFailureTest(200, 50, operation, verify, tempDir) + driverTest.testAndGetError().map { errorMessage => + fail(errorMessage) + } } } + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/DriverFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/DriverFailureTest.scala new file mode 100644 index 0000000000000..72edc2a03119c --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/DriverFailureTest.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.util.UUID + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.Logging +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted} + +abstract class DriverFailureTest( + testDirectory: String, + batchDurationMillis: Int, + numBatchesToRun: Int + ) extends Logging { + + @transient private val checkpointDir = createCheckpointDir() + @transient private val timeoutMillis = batchDurationMillis * numBatchesToRun * 4 + + @transient @volatile private var killed = false + @transient @volatile private var killCount = 0 + @transient @volatile private var lastBatchCompleted = 0L + @transient @volatile private var batchesCompleted = 0 + @transient @volatile private var ssc: StreamingContext = null + + protected def setupContext(checkpointDirector: String): StreamingContext + + //---------------------------------------- + + /** + * Run the test and return an option string containing error message. + * @return None is test succeeded, or Some(errorMessage) if test failed + */ + def testAndGetError(): Option[String] = { + DriverFailureTest.reset() + ssc = setupContext(checkpointDir.toString) + doTest() + } + + /** + * Actually perform the test on the context that has been setup using `setupContext` + * and return any error message. + */ + private def doTest(): Option[String] = { + + val runStartTime = System.currentTimeMillis + var killingThread: Thread = null + + def allBatchesCompleted = batchesCompleted >= numBatchesToRun + def timedOut = (System.currentTimeMillis - runStartTime) > timeoutMillis + def failed = DriverFailureTest.failed + + while(!failed && !allBatchesCompleted && !timedOut) { + // Start the thread to kill the streaming after some time + killed = false + try { + ssc.addStreamingListener(new BatchCompletionListener) + ssc.start() + + killingThread = new KillingThread(ssc, batchDurationMillis * 10) + killingThread.start() + + while (!failed && !killed && !allBatchesCompleted && !timedOut) { + ssc.awaitTermination(1) + } + } catch { + case e: Exception => + logError("Error running streaming context", e) + DriverFailureTest.fail("Error running streaming context: " + e) + } + + logInfo(s"Failed = $failed") + logInfo(s"Killed = $killed") + logInfo(s"All batches completed = $allBatchesCompleted") + logInfo(s"Timed out = $timedOut") + + if (killingThread.isAlive) { + killingThread.interrupt() + ssc.stop() + } + + if (!timedOut) { + val sleepTime = Random.nextInt(batchDurationMillis * 10) + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation in " + sleepTime + " ms " + + "\n-------------------------------------------\n" + ) + Thread.sleep(sleepTime) + + // Recreate the streaming context from checkpoint + System.clearProperty("spark.driver.port") + ssc = StreamingContext.getOrCreate(checkpointDir.toString, () => { + throw new Exception("Trying to create new context when it " + + "should be reading from checkpoint file") + }) + println("Restarted") + } + } + + if (failed) { + Some(s"Failed with message: ${DriverFailureTest.firstFailureMessage}") + } else if (timedOut) { + Some(s"Timed out after $batchesCompleted/$numBatchesToRun batches, and " + + s"${System.currentTimeMillis} ms (time out = $timeoutMillis ms)") + } else if (allBatchesCompleted) { + None + } else { + throw new Exception("Unexpected end of test") + } + } + + private def createCheckpointDir(): Path = { + // Create the directories for this test + val uuid = UUID.randomUUID().toString + val rootDir = new Path(testDirectory, uuid) + val fs = rootDir.getFileSystem(new Configuration()) + val dir = new Path(rootDir, "checkpoint") + fs.mkdirs(dir) + dir + } + + class BatchCompletionListener extends StreamingListener { + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + if (batchCompleted.batchInfo.batchTime.milliseconds > lastBatchCompleted) { + batchesCompleted += 1 + lastBatchCompleted = batchCompleted.batchInfo.batchTime.milliseconds + } + } + } + + class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread with Logging { + override def run() { + try { + // If it is the first killing, then allow the first checkpoint to be created + var minKillWaitTime = if (killCount == 0) 5000 else 2000 + val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime) + logInfo("Kill wait time = " + killWaitTime) + Thread.sleep(killWaitTime) + logInfo( + "\n---------------------------------------\n" + + "Killing streaming context after " + killWaitTime + " ms" + + "\n---------------------------------------\n" + ) + ssc.stop() + killed = true + killCount += 1 + println("Killed") + logInfo("Killing thread finished normally") + } catch { + case ie: InterruptedException => logInfo("Killing thread interrupted") + case e: Exception => logWarning("Exception in killing thread", e) + } + + } + } +} + +/** + * Companion object to [[org.apache.spark.streaming.util.DriverFailureTest]] containing + * global state used while running a driver failure test. + */ +object DriverFailureTest { + @transient @volatile var failed: Boolean = _ + @transient @volatile var firstFailureMessage: String = _ + + /** Mark the currently running test as failed with the given error message */ + def fail(message: String) { + if (!failed) { + failed = true + firstFailureMessage = message + } + } + + /** Reset the state */ + def reset() { + failed = false + firstFailureMessage = "NOT SET" + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/ReceiverBasedDriverFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/ReceiverBasedDriverFailureTest.scala new file mode 100644 index 0000000000000..172531818cb00 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/ReceiverBasedDriverFailureTest.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.storage.StorageLevel +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable.ArrayBuffer + +/** + * Implementation of the [[org.apache.spark.streaming.util.DriverFailureTest]] that uses + * a receiver as the data source and test whether all the received data gets processed with + * the given operation despite recurring driver failures. + */ +class ReceiverBasedDriverFailureTest[T]( + @transient batchDurationMillis: Int, + @transient numBatchesToRun: Int, + @transient operation: DStream[String] => DStream[T], + @transient outputVerifyingFunction: (Time, Seq[T]) => Unit, + @transient testDirectory: String) + extends DriverFailureTest(testDirectory, batchDurationMillis, numBatchesToRun) { + + @transient val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("ReceiverBasedDriverFailureTest") + .set("spark.streaming.receiver.writeAheadLog.enable", "true") // enable write ahead log + .set("spark.streaming.receiver.writeAheadLog.rotationIntervalSecs", "10") + // rotate logs to test cleanup + + override def setupContext(checkpointDirector: String): StreamingContext = { + + val context = StreamingContext.getOrCreate(checkpointDirector, () => { + val newSsc = new StreamingContext(conf, Milliseconds(batchDurationMillis)) + val inputStream = newSsc.receiverStream[String](new DriverFailureTestReceiver) + val operatedStream = operation(inputStream) + val verify = outputVerifyingFunction + operatedStream.foreachRDD((rdd: RDD[T], time: Time) => { + try { + val collected = rdd.collect() + verify(time, collected) + } catch { + case ie: InterruptedException => + // ignore + case e: Exception => + DriverFailureTest.fail(e.toString) + } + }) + newSsc.checkpoint(checkpointDirector) + newSsc + }) + context + } +} + + +/** + * Implementation of [[org.apache.spark.streaming.receiver.Receiver]] that is used by + * [[org.apache.spark.streaming.util.ReceiverBasedDriverFailureTest]] for failure recovery under + * driver failures. + */ +class DriverFailureTestReceiver + extends Receiver[String](StorageLevel.MEMORY_ONLY_SER) with Logging { + + import DriverFailureTestReceiver._ + + @volatile var thread: Thread = null + + class ReceivingThread extends Thread() { + override def run() { + while (!isStopped() && !isInterrupted()) { + try { + val block = getNextBlock() + store(block) + commitBlock() + Thread.sleep(10) + } catch { + case ie: InterruptedException => + case e: Exception => + DriverFailureTestReceiver.this.stop("Error in receiving thread", e) + } + } + } + } + + def onStart() { + if (thread == null) { + thread.start() + } else { + logError("Error starting receiver, previous receiver thread not stopped yet.") + } + } + + def onStop() { + if (thread != null) { + thread.interrupt() + thread = null + } + } +} + + +/** + * Companion object of [[org.apache.spark.streaming.util.DriverFailureTestReceiver]] + * containing global state used while running a driver failure test. + */ +object DriverFailureTestReceiver { + val maxRecordsPerBlock = 1000L + private val currentKey = new AtomicInteger() + private val counter = new AtomicInteger() + + counter.set(1) + currentKey.set(1) + + def getNextBlock(): ArrayBuffer[String] = { + val count = counter.get() + new ArrayBuffer ++= (1 to count).map { + _ => "word%03d".format(currentKey.get()) + } + } + + def commitBlock() { + println(s"Stored ${counter.get()} copies of word${currentKey.get}") + if (counter.incrementAndGet() > maxRecordsPerBlock) { + currentKey.incrementAndGet() + counter.set(1) + } + } +} +