diff --git a/streaming/pom.xml b/streaming/pom.xml index 12f900c91eb98..002c35f251390 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -68,6 +68,11 @@ junit-interface test + + org.mockito + mockito-all + test + target/scala-${scala.binary.version}/classes diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index f33c0ceafdf42..544cf57587e57 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -26,7 +26,7 @@ import org.apache.spark.Logging import org.apache.spark.streaming.Time private[streaming] -class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) +class DStreamCheckpointData[T: ClassTag](dstream: DStream[T]) extends Serializable with Logging { protected val data = new HashMap[Time, AnyRef]() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 391e40924f38a..f1645066cc815 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -21,10 +21,11 @@ import scala.collection.mutable.HashMap import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.{StorageLevel, BlockId} import org.apache.spark.streaming._ import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.storage.rdd.HDFSBackedBlockRDD /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -39,9 +40,6 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - /** Keeps all received blocks information */ - private lazy val receivedBlockInfo = new HashMap[Time, Array[ReceivedBlockInfo]] - /** This is an unique identifier for the network input stream. */ val id = ssc.getNewReceiverStreamId() @@ -62,19 +60,27 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont // If this is called for any time before the start time of the context, // then this returns an empty RDD. 
This may happen when recovering from a // master failure - if (validTime >= graph.startTime) { - val blockInfo = ssc.scheduler.receiverTracker.getReceivedBlockInfo(id) - receivedBlockInfo(validTime) = blockInfo - val blockIds = blockInfo.map(_.blockId.asInstanceOf[BlockId]) - Some(new BlockRDD[T](ssc.sc, blockIds)) + val blockRDD = if (validTime >= graph.startTime) { + val blockInfo = getReceivedBlockInfo(validTime) + val blockIds = blockInfo.map(_.blockId).map { _.asInstanceOf[BlockId] } toArray + val fileSegments = blockInfo.flatMap(_.fileSegmentOption).toArray + logInfo("Stream " + id + ": allocated " + blockInfo.map(_.blockId).mkString(", ")) + + if (fileSegments.nonEmpty) { + new HDFSBackedBlockRDD[T](ssc.sparkContext, ssc.sparkContext.hadoopConfiguration, + blockIds, fileSegments, storeInBlockManager = false, StorageLevel.MEMORY_ONLY_SER) + } else { + new BlockRDD[T](ssc.sc, blockIds) + } } else { - Some(new BlockRDD[T](ssc.sc, Array[BlockId]())) + new BlockRDD[T](ssc.sc, Array[BlockId]()) } + Some(blockRDD) } /** Get information on received blocks. */ - private[streaming] def getReceivedBlockInfo(time: Time) = { - receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo]) + private[streaming] def getReceivedBlockInfo(time: Time): Seq[ReceivedBlockInfo] = { + ssc.scheduler.receiverTracker.getReceivedBlocks(time, id) } /** @@ -85,10 +91,6 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont */ private[streaming] override def clearMetadata(time: Time) { super.clearMetadata(time) - val oldReceivedBlocks = receivedBlockInfo.filter(_._1 <= (time - rememberDuration)) - receivedBlockInfo --= oldReceivedBlocks.keys - logDebug("Cleared " + oldReceivedBlocks.size + " RDDs that were older than " + - (time - rememberDuration) + ": " + oldReceivedBlocks.keys.mkString(", ")) + ssc.scheduler.receiverTracker.cleanupOldInfo(time - rememberDuration) } -} - +} \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/rdd/HDFSBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/HDFSBackedBlockRDD.scala similarity index 97% rename from streaming/src/main/scala/org/apache/spark/streaming/storage/rdd/HDFSBackedBlockRDD.scala rename to streaming/src/main/scala/org/apache/spark/streaming/rdd/HDFSBackedBlockRDD.scala index c672574ee2ed4..25f8363454394 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/storage/rdd/HDFSBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/HDFSBackedBlockRDD.scala @@ -68,6 +68,7 @@ class HDFSBackedBlockRDD[T: ClassTag]( block.data.asInstanceOf[Iterator[T]] // Data not found in Block Manager, grab it from HDFS case None => + logInfo("Reading partition data from write ahead log " + partition.segment.path) val reader = new WriteAheadLogRandomReader(partition.segment.path, hadoopConf) val dataRead = reader.read(partition.segment) reader.close() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 53a3e6200e340..97677491cd0b2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -25,16 +25,13 @@ import scala.concurrent.Await import akka.actor.{Actor, Props} import akka.pattern.ask - import com.google.common.base.Throwables - -import 
org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.hadoop.conf.Configuration +import org.apache.spark.{SparkException, Logging, SparkEnv} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.streaming.scheduler.DeregisterReceiver -import org.apache.spark.streaming.scheduler.AddBlock -import org.apache.spark.streaming.scheduler.RegisterReceiver +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.storage._ +import org.apache.spark.util.{AkkaUtils, Utils} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -44,12 +41,26 @@ import org.apache.spark.streaming.scheduler.RegisterReceiver */ private[streaming] class ReceiverSupervisorImpl( receiver: Receiver[_], - env: SparkEnv + env: SparkEnv, + hadoopConf: Configuration, + checkpointDirOption: Option[String] ) extends ReceiverSupervisor(receiver, env.conf) with Logging { - private val blockManager = env.blockManager + private val receivedBlockHandler: ReceivedBlockHandler = { + if (env.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + if (checkpointDirOption.isEmpty) { + throw new SparkException( + "Cannot enable receiver write-ahead log without checkpoint directory set. " + + "Please use streamingContext.checkpoint() to set the checkpoint directory. " + + "See documentation for more details.") + } + new WriteAheadLogBasedBlockHandler(env.blockManager, receiver.streamId, + receiver.storageLevel, env.conf, hadoopConf, checkpointDirOption.get) + } else { + new BlockManagerBasedBlockHandler(env.blockManager, receiver.streamId, receiver.storageLevel) + } + } - private val storageLevel = receiver.storageLevel /** Remote Akka actor for the ReceiverTracker */ private val trackerActor = { @@ -108,11 +119,7 @@ private[streaming] class ReceiverSupervisorImpl( optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] ) { - val blockId = optionalBlockId.getOrElse(nextBlockId) - val time = System.currentTimeMillis - blockManager.putArray(blockId, arrayBuffer.toArray[Any], storageLevel, tellMaster = true) - logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, arrayBuffer.size, optionalMetadata) + pushAndReportBlock(ArrayBufferBlock(arrayBuffer), optionalMetadata, optionalBlockId) } /** Store a iterator of received data as a data block into Spark's memory. */ @@ -121,11 +128,7 @@ private[streaming] class ReceiverSupervisorImpl( optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] ) { - val blockId = optionalBlockId.getOrElse(nextBlockId) - val time = System.currentTimeMillis - blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) - logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, -1, optionalMetadata) + pushAndReportBlock(IteratorBlock(iterator), optionalMetadata, optionalBlockId) } /** Store the bytes of received data as a data block into Spark's memory. 
*/ @@ -134,17 +137,32 @@ private[streaming] class ReceiverSupervisorImpl( optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] ) { + pushAndReportBlock(ByteBufferBlock(bytes), optionalMetadata, optionalBlockId) + } + + /** Store block and report it to driver */ + def pushAndReportBlock( + receivedBlock: ReceivedBlock, + optionalMetadata: Option[Any], + optionalBlockId: Option[StreamBlockId] + ) { val blockId = optionalBlockId.getOrElse(nextBlockId) + val numRecords = receivedBlock match { + case ArrayBufferBlock(arrayBuffer) => arrayBuffer.size + case _ => -1 + } + val time = System.currentTimeMillis - blockManager.putBytes(blockId, bytes, storageLevel, tellMaster = true) - logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, -1, optionalMetadata) - } + val fileSegmentOption = receivedBlockHandler.storeBlock(blockId, receivedBlock) match { + case Some(f: FileSegment) => Some(f) + case _ => None + } + logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - /** Report pushed block */ - def reportPushedBlock(blockId: StreamBlockId, numRecords: Long, optionalMetadata: Option[Any]) { - val blockInfo = ReceivedBlockInfo(streamId, blockId, numRecords, optionalMetadata.orNull) - trackerActor ! AddBlock(blockInfo) + val blockInfo = ReceivedBlockInfo(streamId, + blockId, numRecords, optionalMetadata.orNull, fileSegmentOption) + val future = trackerActor.ask(AddBlock(blockInfo))(askTimeout) + Await.result(future, askTimeout) logDebug("Reported block " + blockId) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 374848358e700..ee00abdf3f4f0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -222,10 +222,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { case Success(jobs) => val receivedBlockInfo = graph.getReceiverInputStreams.map { stream => val streamId = stream.id - val receivedBlockInfo = stream.getReceivedBlockInfo(time) + val receivedBlockInfo = stream.getReceivedBlockInfo(time).toArray (streamId, receivedBlockInfo) }.toMap - jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo)) + jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo.toMap)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 5307fe189d717..f2aa02d065222 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,23 +17,30 @@ package org.apache.spark.streaming.scheduler +import java.nio.ByteBuffer + import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue} import scala.language.existentials import akka.actor._ -import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.SparkContext._ +import org.apache.spark.SerializableWritable import 
org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} -import org.apache.spark.util.AkkaUtils +import org.apache.spark.streaming.storage.{ReceivedBlockTracker, FileSegment, WriteAheadLogManager} +import org.apache.spark.util.Utils /** Information about blocks received by the receiver */ private[streaming] case class ReceivedBlockInfo( streamId: Int, blockId: StreamBlockId, numRecords: Long, - metadata: Any + metadata: Any, + fileSegmentOption: Option[FileSegment] ) /** @@ -57,23 +64,28 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err * This class manages the execution of the receivers of NetworkInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() * has been called because it needs the final set of input streams at the time of instantiation. + * + * @param skipReceiverLaunch Do not launch the receiver. This is useful for testing. */ private[streaming] -class ReceiverTracker(ssc: StreamingContext) extends Logging { +class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false) extends Logging { - val receiverInputStreams = ssc.graph.getReceiverInputStreams() - val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*) - val receiverExecutor = new ReceiverLauncher() - val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] - val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - val timeout = AkkaUtils.askTimeout(ssc.conf) - val listenerBus = ssc.scheduler.listenerBus + private val receiverInputStreams = ssc.graph.getReceiverInputStreams() + private val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*) + private val receiverExecutor = new ReceiverLauncher() + private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] + private val receivedBlockTracker = new ReceivedBlockTracker( + ssc.sparkContext.conf, + ssc.sparkContext.hadoopConfiguration, + receiverInputStreams.map { _.id }, + ssc.scheduler.clock, + Option(ssc.checkpointDir) + ) + private val listenerBus = ssc.scheduler.listenerBus // actor is created when generator starts. // This not being null means the tracker has been started and not stopped - var actor: ActorRef = null - var currentTime: Time = null + private var actor: ActorRef = null /** Start the actor and receiver execution thread. */ def start() = synchronized { @@ -84,7 +96,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { if (!receiverInputStreams.isEmpty) { actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor), "ReceiverTracker") - receiverExecutor.start() + if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") } } @@ -93,28 +105,27 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { def stop() = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers - receiverExecutor.stop() + if (!skipReceiverLaunch) receiverExecutor.stop() // Finally, stop the actor ssc.env.actorSystem.stop(actor) actor = null + receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") } } /** Return all the blocks received from a receiver. 
*/ - def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = { - val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true) - logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks") - receivedBlockInfo.toArray + def getReceivedBlocks(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + receivedBlockTracker.getOrAllocateBlocksToBatch(batchTime, streamId) } - private def getReceivedBlockInfoQueue(streamId: Int) = { - receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo]) + def cleanupOldInfo(cleanupThreshTime: Time) { + receivedBlockTracker.cleanupOldBatches(cleanupThreshTime) } /** Register a receiver */ - def registerReceiver( + private def registerReceiver( streamId: Int, typ: String, host: String, @@ -126,12 +137,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } receiverInfo(streamId) = ReceiverInfo( streamId, s"${typ}-${streamId}", receiverActor, true, host) - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address) } /** Deregister a receiver */ - def deregisterReceiver(streamId: Int, message: String, error: String) { + private def deregisterReceiver(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error) @@ -140,7 +151,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -150,14 +161,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Add new blocks for the given stream */ - def addBlocks(receivedBlockInfo: ReceivedBlockInfo) { - getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " + - receivedBlockInfo.blockId) + private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { + receivedBlockTracker.addBlock(receivedBlockInfo) } /** Report error sent by a receiver */ - def reportError(streamId: Int, message: String, error: String) { + private def reportError(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(lastErrorMessage = message, lastError = error) @@ -166,7 +175,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -177,7 +186,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { /** Check if any blocks are left to be processed */ def hasMoreReceivedBlockIds: Boolean = { - 
!receivedBlockInfo.values.forall(_.isEmpty) + receivedBlockTracker.hasUnallocatedReceivedBlocks() } /** Actor to receive messages from the receivers. */ @@ -187,7 +196,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { registerReceiver(streamId, typ, host, receiverActor, sender) sender ! true case AddBlock(receivedBlockInfo) => - addBlocks(receivedBlockInfo) + sender ! addBlock(receivedBlockInfo) case ReportError(streamId, message, error) => reportError(streamId, message, error) case DeregisterReceiver(streamId, message, error) => @@ -253,6 +262,9 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ssc.sc.makeRDD(receivers, receivers.size) } + val checkpointDirOption = Option(ssc.checkpointDir) + val serializableHadoopConf = new SerializableWritable(ssc.sparkContext.hadoopConfiguration) + // Function to start the receiver on the worker node val startReceiver = (iterator: Iterator[Receiver[_]]) => { if (!iterator.hasNext) { @@ -260,9 +272,10 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { "Could not start receiver as object not found.") } val receiver = iterator.next() - val executor = new ReceiverSupervisorImpl(receiver, SparkEnv.get) - executor.start() - executor.awaitTermination() + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() } // Run the dummy Spark job to ensure that all slaves have registered. // This avoids all the receivers to be scheduled on the same node. @@ -272,7 +285,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") - ssc.sparkContext.runJob(tempRDD, startReceiver) + ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) logInfo("All of the receivers have been terminated") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockHandler.scala new file mode 100644 index 0000000000000..d0d2370201a8e --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockHandler.scala @@ -0,0 +1,120 @@ +package org.apache.spark.streaming.storage + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps + +import WriteAheadLogBasedBlockHandler._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.storage.{BlockManager, StorageLevel, StreamBlockId} +import org.apache.spark.streaming.util.{Clock, SystemClock} +import org.apache.spark.util.Utils + +private[streaming] sealed trait ReceivedBlock +private[streaming] case class ArrayBufferBlock(arrayBuffer: ArrayBuffer[_]) extends ReceivedBlock +private[streaming] case class IteratorBlock(iterator: Iterator[_]) extends ReceivedBlock +private[streaming] case class ByteBufferBlock(byteBuffer: ByteBuffer) extends ReceivedBlock + +private[streaming] trait ReceivedBlockHandler { + def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): Option[AnyRef] + def cleanupOldBlock(threshTime: Long) +} + +private[streaming] class BlockManagerBasedBlockHandler( + blockManager: BlockManager, + streamId: Int, + storageLevel: StorageLevel + ) extends 
ReceivedBlockHandler { + + def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): Option[AnyRef] = { + receivedBlock match { + case ArrayBufferBlock(arrayBuffer) => + blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true) + case IteratorBlock(iterator) => + blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) + case ByteBufferBlock(byteBuffer) => + blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) + case _ => + throw new Exception(s"Could not push $blockId to block manager, unexpected block type") + } + None + } + + def cleanupOldBlock(threshTime: Long) { + // this is not used as blocks inserted into the BlockManager are cleared by DStream's clearing + // of BlockRDDs. + } +} + +private[streaming] class WriteAheadLogBasedBlockHandler( + blockManager: BlockManager, + streamId: Int, + storageLevel: StorageLevel, + conf: SparkConf, + hadoopConf: Configuration, + checkpointDir: String, + clock: Clock = new SystemClock + ) extends ReceivedBlockHandler with Logging { + + private val blockStoreTimeout = conf.getInt( + "spark.streaming.receiver.blockStoreTimeout", 30).seconds + private val rollingInterval = conf.getInt( + "spark.streaming.receiver.writeAheadLog.rollingInterval", 60) + private val maxFailures = conf.getInt( + "spark.streaming.receiver.writeAheadLog.maxFailures", 3) + + private val logManager = new WriteAheadLogManager( + checkpointDirToLogDir(checkpointDir, streamId), + hadoopConf, rollingInterval, maxFailures, + callerName = "WriteAheadLogBasedBlockHandler", + clock = clock + ) + + implicit private val executionContext = ExecutionContext.fromExecutorService( + Utils.newDaemonFixedThreadPool(1, "WriteAheadLogBasedBlockHandler")) + + def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): Option[AnyRef] = { + val serializedBlock = receivedBlock match { + case ArrayBufferBlock(arrayBuffer) => + blockManager.dataSerialize(blockId, arrayBuffer.iterator) + case IteratorBlock(iterator) => + blockManager.dataSerialize(blockId, iterator) + case ByteBufferBlock(byteBuffer) => + byteBuffer + case _ => + throw new Exception(s"Could not push $blockId to block manager, unexpected block type") + } + + val pushToBlockManagerFuture = Future { + blockManager.putBytes(blockId, serializedBlock, storageLevel, tellMaster = true) + } + val pushToLogFuture = Future { + logManager.writeToLog(serializedBlock) + } + val combinedFuture = for { + _ <- pushToBlockManagerFuture + fileSegment <- pushToLogFuture + } yield fileSegment + + Some(Await.result(combinedFuture, blockStoreTimeout)) + } + + def cleanupOldBlock(threshTime: Long) { + logManager.cleanupOldLogs(threshTime) + } + + def stop() { + logManager.stop() + } +} + +private[streaming] object WriteAheadLogBasedBlockHandler { + def checkpointDirToLogDir(checkpointDir: String, streamId: Int): String = { + new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockTracker.scala new file mode 100644 index 0000000000000..866d77dd25489 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockTracker.scala @@ -0,0 +1,173 @@ +package org.apache.spark.streaming.storage + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.language.implicitConversions + +import 
org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.storage.WriteAheadLogManager +import org.apache.spark.util.Utils +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.util.Clock +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo + +private[streaming] sealed trait ReceivedBlockTrackerRecord + +private[streaming] case class BlockAddition(receivedBlockInfo: ReceivedBlockInfo) + extends ReceivedBlockTrackerRecord +private[streaming] case class BatchAllocations(time: Time, allocatedBlocks: AllocatedBlocks) + extends ReceivedBlockTrackerRecord +private[streaming] case class BatchCleanup(times: Seq[Time]) + extends ReceivedBlockTrackerRecord + +case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { + def apply(streamId: Int) = streamIdToAllocatedBlocks(streamId) +} + +private[streaming] +class ReceivedBlockTracker( + conf: SparkConf, hadoopConf: Configuration, streamIds: Seq[Int], clock: Clock, + checkpointDirOption: Option[String]) extends Logging { + + private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo] + + private val streamIdToUnallocatedBlockInfo = new mutable.HashMap[Int, ReceivedBlockQueue] + private val timeToAllocatedBlockInfo = new mutable.HashMap[Time, AllocatedBlocks] + + private val logManagerRollingIntervalSecs = conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) + private val logManagerOption = checkpointDirOption.map { checkpointDir => + new WriteAheadLogManager( + ReceivedBlockTracker.checkpointDirToLogDir(checkpointDir), + hadoopConf, + rollingIntervalSecs = logManagerRollingIntervalSecs, + callerName = "ReceivedBlockHandlerMaster", + clock = clock + ) + } + + // Recover block information from write ahead logs + recoverFromWriteAheadLogs() + + /** Add received block */ + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + try { + writeToLog(BlockAddition(receivedBlockInfo)) + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + logDebug(s"Stream ${receivedBlockInfo.streamId} received block ${receivedBlockInfo.blockId}") + true + } catch { + case e: Exception => + logError("Error adding block " + receivedBlockInfo, e) + false + } + } + + /** Get blocks that have been added but not yet allocated to any batch */ + def getUnallocatedBlocks(streamId: Int): Seq[ReceivedBlockInfo] = synchronized { + getReceivedBlockQueue(streamId).toSeq + } + + /** Get the blocks allocated to a batch, or allocate blocks to the batch and then get them */ + def getOrAllocateBlocksToBatch(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + if (!timeToAllocatedBlockInfo.contains(batchTime)) { + allocateAllUnallocatedBlocksToBatch(batchTime) + } + timeToAllocatedBlockInfo(batchTime)(streamId) + } + } + + /** Check if any blocks are left to be allocated to batches */ + def hasUnallocatedReceivedBlocks(): Boolean = synchronized { + !streamIdToUnallocatedBlockInfo.values.forall(_.isEmpty) + } + + /** Clean up block information of old batches */ + def cleanupOldBatches(cleanupThreshTime: Time): Unit = synchronized { + assert(cleanupThreshTime.milliseconds < clock.currentTime()) + val timesToCleanup = timeToAllocatedBlockInfo.keys.filter { _ < cleanupThreshTime }.toSeq + logInfo("Deleting batches " + timesToCleanup) + writeToLog(BatchCleanup(timesToCleanup)) + 
timeToAllocatedBlockInfo --= timesToCleanup + logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds)) + } + + /** Stop the block tracker */ + def stop() { + logManagerOption.foreach { _.stop() } + } + + /** Allocate all unallocated blocks to the given batch */ + private def allocateAllUnallocatedBlocksToBatch(batchTime: Time): AllocatedBlocks = synchronized { + val allocatedBlockInfos = AllocatedBlocks(streamIds.map { streamId => + (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + }.toMap) + writeToLog(BatchAllocations(batchTime, allocatedBlockInfos)) + timeToAllocatedBlockInfo(batchTime) = allocatedBlockInfos + allocatedBlockInfos + } + + private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = { + streamIdToUnallocatedBlockInfo.getOrElseUpdate(streamId, new ReceivedBlockQueue) + } + + /** Write an update to the tracker to the write ahead log */ + private def writeToLog(record: ReceivedBlockTrackerRecord) { + logDebug("Writing to log " + record) + logManagerOption.foreach { logManager => + logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record))) + } + } + + private def recoverFromWriteAheadLogs(): Unit = synchronized { + logInfo("Recovering from checkpoint") + val blockIdToBlockInfo = new mutable.HashMap[StreamBlockId, ReceivedBlockInfo] + + def insertAddedBlock(receivedBlockInfo: ReceivedBlockInfo) { + logTrace(s"Recovery: Inserting added block $receivedBlockInfo") + blockIdToBlockInfo.put(receivedBlockInfo.blockId, receivedBlockInfo) + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + + def insertAllocatedBatch(time: Time, allocatedBlocks: AllocatedBlocks) { + logTrace(s"Recovery: Inserting allocated batch for time $time to ${allocatedBlocks.streamIdToAllocatedBlocks}") + streamIdToUnallocatedBlockInfo.values.foreach { _.clear() } + timeToAllocatedBlockInfo.put(time, allocatedBlocks) + } + + def cleanupBatches(batchTimes: Seq[Time]) { + logTrace(s"Recovery: Cleaning up batches ${batchTimes}") + timeToAllocatedBlockInfo --= batchTimes + } + + logManagerOption.foreach { logManager => + logManager.readFromLog().foreach { byteBuffer => + logTrace("Recovering record " + byteBuffer) + Utils.deserialize[ReceivedBlockTrackerRecord](byteBuffer.array) match { + case BlockAddition(receivedBlockInfo) => + insertAddedBlock(receivedBlockInfo) + case BatchAllocations(time, allocatedBlocks) => + insertAllocatedBatch(time, allocatedBlocks) + case BatchCleanup(batchTimes) => + cleanupBatches(batchTimes) + } + } + } + blockIdToBlockInfo.clear() + } +} + +private[streaming] object ReceivedBlockTracker { + def checkpointDirToLogDir(checkpointDir: String): String = { + new Path(checkpointDir, "receivedBlockMetadata").toString + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogManager.scala new file mode 100644 index 0000000000000..c70ecb0da4e54 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogManager.scala @@ -0,0 +1,176 @@ +package org.apache.spark.streaming.storage + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{ExecutionContext, Future} + +import 
org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.permission.FsPermission +import org.apache.spark.Logging +import org.apache.spark.streaming.storage.WriteAheadLogManager._ +import org.apache.spark.streaming.util.{Clock, SystemClock} +import org.apache.spark.util.Utils + +private[streaming] class WriteAheadLogManager( + logDirectory: String, + hadoopConf: Configuration, + rollingIntervalSecs: Int = 60, + maxFailures: Int = 3, + callerName: String = "", + clock: Clock = new SystemClock + ) extends Logging { + + private val pastLogs = new ArrayBuffer[LogInfo] + private val callerNameTag = + if (callerName != null && callerName.nonEmpty) s" for $callerName" else "" + private val threadpoolName = s"WriteAheadLogManager $callerNameTag" + implicit private val executionContext = ExecutionContext.fromExecutorService( + Utils.newDaemonFixedThreadPool(1, threadpoolName)) + override protected val logName = s"WriteAheadLogManager $callerNameTag" + + private var currentLogPath: String = null + private var currentLogWriter: WriteAheadLogWriter = null + private var currentLogWriterStartTime: Long = -1L + private var currentLogWriterStopTime: Long = -1L + + initializeOrRecover() + + def writeToLog(byteBuffer: ByteBuffer): FileSegment = synchronized { + var fileSegment: FileSegment = null + var failures = 0 + var lastException: Exception = null + var succeeded = false + while (!succeeded && failures < maxFailures) { + try { + fileSegment = getLogWriter(clock.currentTime).write(byteBuffer) + succeeded = true + } catch { + case ex: Exception => + lastException = ex + logWarning("Failed to write to write ahead log", ex) + resetWriter() + failures += 1 + } + } + if (fileSegment == null) { + throw lastException + } + fileSegment + } + + def readFromLog(): Iterator[ByteBuffer] = synchronized { + val logFilesToRead = pastLogs.map { _.path } ++ Option(currentLogPath) + logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) + logFilesToRead.iterator.map { file => + logDebug(s"Creating log reader with $file") + new WriteAheadLogReader(file, hadoopConf) + }.flatMap { x => x } + } + + /** + * Delete the log files that are older than the threshold time. + * + * It's important to note that the threshold time is based on the time stamps used in the log + * files, and is therefore based on the local system time. So if there is coordination necessary + * between the node calculating the threshTime (say, driver node), and the local system time + * (say, worker node), the caller has to take into account possible time skew. 
+ */ + def cleanupOldLogs(threshTime: Long): Unit = { + val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } } + logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " + + s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}") + + def deleteFiles() { + oldLogFiles.foreach { logInfo => + try { + val path = new Path(logInfo.path) + val fs = hadoopConf.synchronized { path.getFileSystem(hadoopConf) } + fs.delete(path, true) + synchronized { pastLogs -= logInfo } + logDebug(s"Cleared log file $logInfo") + } catch { + case ex: Exception => + logWarning(s"Error clearing log file $logInfo", ex) + } + } + logInfo(s"Cleared log files in $logDirectory older than $threshTime") + } + if (!executionContext.isShutdown) { + Future { deleteFiles() } + } + } + + def stop(): Unit = synchronized { + if (currentLogWriter != null) { + currentLogWriter.close() + } + executionContext.shutdown() + logInfo("Stopped log manager") + } + + private def getLogWriter(currentTime: Long): WriteAheadLogWriter = synchronized { + if (currentLogWriter == null || currentTime > currentLogWriterStopTime) { + resetWriter() + if (currentLogPath != null) { + pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, currentLogPath) + } + currentLogWriterStartTime = currentTime + currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) + val newLogPath = new Path(logDirectory, + timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) + currentLogPath = newLogPath.toString + currentLogWriter = new WriteAheadLogWriter(currentLogPath, hadoopConf) + } + currentLogWriter + } + + private def initializeOrRecover(): Unit = synchronized { + val logDirectoryPath = new Path(logDirectory) + val fileSystem = logDirectoryPath.getFileSystem(hadoopConf) + + if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { + val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) + pastLogs.clear() + pastLogs ++= logFileInfo + logInfo(s"Recovered ${logFileInfo.size} log files from $logDirectory") + logDebug(s"Recovered files are:\n${logFileInfo.map(_.path).mkString("\n")}") + } else { + fileSystem.mkdirs(logDirectoryPath, + FsPermission.createImmutable(Integer.parseInt("770", 8).toShort)) + logInfo(s"Created ${logDirectory} for log files") + } + } + + private def resetWriter(): Unit = synchronized { + if (currentLogWriter != null) { + currentLogWriter.close() + currentLogWriter = null + } + } +} + +private[storage] object WriteAheadLogManager { + + case class LogInfo(startTime: Long, endTime: Long, path: String) + + val logFileRegex = """log-(\d+)-(\d+)""".r + + def timeToLogFile(startTime: Long, stopTime: Long): String = { + s"log-$startTime-$stopTime" + } + + def logFilesTologInfo(files: Seq[Path]): Seq[LogInfo] = { + files.flatMap { file => + logFileRegex.findFirstIn(file.getName()) match { + case Some(logFileRegex(startTimeStr, stopTimeStr)) => + val startTime = startTimeStr.toLong + val stopTime = stopTimeStr.toLong + Some(LogInfo(startTime, stopTime, file.toString)) + case None => + None + } + }.sortBy { _.startTime } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala index 724549e216e93..5e0dc1d49a89a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala @@ -20,16 +20,20 @@ import java.io.{EOFException, Closeable} import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration +import org.apache.spark.Logging private[streaming] class WriteAheadLogReader(path: String, conf: Configuration) - extends Iterator[ByteBuffer] with Closeable { + extends Iterator[ByteBuffer] with Closeable with Logging { private val instream = HdfsUtils.getInputStream(path, conf) private var closed = false private var nextItem: Option[ByteBuffer] = None override def hasNext: Boolean = synchronized { - assertOpen() + if (closed) { + return false + } + + if (nextItem.isDefined) { // handle the case where hasNext is called without calling next true } else { @@ -38,33 +42,35 @@ private[streaming] class WriteAheadLogReader(path: String, conf: Configuration) val buffer = new Array[Byte](length) instream.readFully(buffer) nextItem = Some(ByteBuffer.wrap(buffer)) + logTrace("Read next item " + nextItem.get) true } catch { - case e: EOFException => false - case e: Exception => throw e + case e: EOFException => + logDebug("Error reading next item, EOF reached", e) + close() + false + case e: Exception => + logDebug("Error reading next item", e) + close() + throw e } } } override def next(): ByteBuffer = synchronized { - // TODO: Possible error case where there are not enough bytes in the stream - // TODO: How to handle that? val data = nextItem.getOrElse { - throw new IllegalStateException("next called without calling hasNext or after hasNext " + - "returned false") + close() + throw new IllegalStateException( + "next called without calling hasNext or after hasNext returned false") } nextItem = None // Ensure the next hasNext call loads new data. data } override def close(): Unit = synchronized { + if (!closed) { + instream.close() + } closed = true - instream.close() } - - private def assertOpen() { - HdfsUtils.checkState(!closed, "Stream is closed. 
Create a new Reader to read from the " + - "file.") - } - } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala index 8a2db8305a7e2..68a1172d7d282 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala @@ -16,20 +16,38 @@ */ package org.apache.spark.streaming.storage -import java.io.Closeable -import java.lang.reflect.Method +import java.io._ +import java.net.URI import java.nio.ByteBuffer import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FSDataOutputStream +import org.apache.hadoop.fs.{FSDataOutputStream, FileSystem} +import org.apache.spark.streaming.storage.FileSegment private[streaming] class WriteAheadLogWriter(path: String, conf: Configuration) extends Closeable { - private val stream = HdfsUtils.getOutputStream(path, conf) - private var nextOffset = stream.getPos + private val underlyingStream: Either[DataOutputStream, FSDataOutputStream] = { + val uri = new URI(path) + val defaultFs = FileSystem.getDefaultUri(conf).getScheme + val isDefaultLocal = defaultFs == null || defaultFs == "file" + + if ((isDefaultLocal && uri.getScheme == null) || uri.getScheme == "file") { + assert(!new File(uri.getPath).exists) + Left(new DataOutputStream(new BufferedOutputStream(new FileOutputStream(uri.getPath)))) + } else { + Right(HdfsUtils.getOutputStream(path, conf)) + } + } + + private lazy val hadoopFlushMethod = { + val cls = classOf[FSDataOutputStream] + Try(cls.getMethod("hflush")).orElse(Try(cls.getMethod("sync"))).toOption + } + + private var nextOffset = getPosition() private var closed = false - private val hflushMethod = getHflushOrSync() + // Data is always written as: // - Length - Long @@ -48,8 +66,8 @@ private[streaming] class WriteAheadLogWriter(path: String, conf: Configuration) stream.write(data.get()) } } - hflushOrSync() - nextOffset = stream.getPos + flush() + nextOffset = getPosition() segment } @@ -58,17 +76,22 @@ private[streaming] class WriteAheadLogWriter(path: String, conf: Configuration) stream.close() } - private def hflushOrSync() { - hflushMethod.foreach(_.invoke(stream)) + private def stream(): DataOutputStream = { + underlyingStream.fold(x => x, x => x) + } + + private def getPosition(): Long = { + underlyingStream match { + case Left(localStream) => localStream.size + case Right(dfsStream) => dfsStream.getPos() + } } - private def getHflushOrSync(): Option[Method] = { - Try { - Some(classOf[FSDataOutputStream].getMethod("hflush")) - }.recover { - case e: NoSuchMethodException => - Some(classOf[FSDataOutputStream].getMethod("sync")) - }.getOrElse(None) + private def flush() { + underlyingStream match { + case Left(localStream) => localStream.flush + case Right(dfsStream) => hadoopFlushMethod.foreach { _.invoke(dfsStream) } + } } private def assertOpen() { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala new file mode 100644 index 0000000000000..802b24e47e21c --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.File +import java.util.UUID +import java.util.concurrent.atomic.AtomicInteger + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.{HashPartitioner, Logging, SparkConf} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted} +import org.apache.spark.util.Utils +import scala.collection.mutable.ArrayBuffer + +/** + * This testsuite tests master failures at random times while the stream is running using + * the real clock. + */ +class DriverFailureSuite extends TestSuiteBase with Logging { + + var directory = "FailureSuite" + val numBatches = 30 + + override def batchDuration = Milliseconds(1000) + + override def useManualClock = false + + override def beforeFunction() { + super.beforeFunction() + Utils.deleteRecursively(new File(directory)) + } + + override def afterFunction() { + super.afterFunction() + Utils.deleteRecursively(new File(directory)) + } +/* + test("multiple failures with map") { + MasterFailureTest.testMap(directory, numBatches, batchDuration) + } + + test("multiple failures with updateStateByKey") { + MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration) + } +*/ + test("multiple failures with receiver and updateStateByKey") { + + + val operation = (st: DStream[String]) => { + + val mapPartitionFunc = (iterator: Iterator[String]) => { + Iterator(iterator.flatMap(_.split(" ")).map(_ -> 1L).reduce((x, y) => (x._1, x._2 + y._2))) + } + + val updateFunc = (iterator: Iterator[(String, Seq[Long], Option[Seq[Long]])]) => { + iterator.map { case (key, values, state) => + val combined = (state.getOrElse(Seq.empty) ++ values).sorted + if (state.isEmpty || state.get.max != DriverFailureTestReceiver.maxRecordsPerBlock) { + val oldState = s"[${ state.map { _.max }.getOrElse(-1) }, ${state.map { _.distinct.sum }.getOrElse(0)}]" + val newState = s"[${combined.max}, ${combined.distinct.sum}]" + println(s"Updated state for $key: state = $oldState, new values = $values, new state = $newState") + } + (key, combined) + } + } + + st.mapPartitions(mapPartitionFunc) + .updateStateByKey[Seq[Long]](updateFunc, new HashPartitioner(2), rememberPartitioner = false) + .checkpoint(batchDuration * 5) + } + + val maxValue = DriverFailureTestReceiver.maxRecordsPerBlock + val expectedValues = (1L to maxValue).toSet + + val verify = (time: Time, output: Seq[(String, Seq[Long])]) => { + val outputStr = output.map { x => (x._1, x._2.distinct.sum) }.sortBy(_._1).mkString(", ") + println(s"State 
at $time: $outputStr") + + val incompletelyReceivedWords = output.filter { _._2.max < maxValue } + if (incompletelyReceivedWords.size > 1) { + val debugStr = incompletelyReceivedWords.map { x => + s"""${x._1}: ${x._2.mkString(",")}, sum = ${x._2.distinct.sum}""" + }.mkString("\n") + throw new Exception(s"Incorrect processing of input, all input not processed:\n$debugStr\n") + } + + output.foreach { case (key, values) => + if (!values.forall(expectedValues.contains)) { + val sum = values.distinct.sum + val debugStr = values.zip(1L to values.size).map { + x => if (x._1 == x._2) x._1 else s"[${x._2}]" + }.mkString(",") + s", sum = $sum" + throw new Exception(s"Incorrect sequence of values in output:\n$debugStr\n") + } + } + } + + val driverTest = new ReceiverBasedDriverFailureTest[(String, Seq[Long])]( + "./driver-test/", 200, 50, operation, verify) + driverTest.testAndGetError().map { errorMessage => + fail(errorMessage) + } + } +} + + +abstract class DriverFailureTest( + testDirectory: String, + batchDurationMillis: Int, + numBatchesToRun: Int + ) extends Logging { + + @transient private val checkpointDir = createCheckpointDir() + @transient private val timeoutMillis = batchDurationMillis * numBatchesToRun * 4 + + @transient @volatile private var killed = false + @transient @volatile private var killCount = 0 + @transient @volatile private var lastBatchCompleted = 0L + @transient @volatile private var batchesCompleted = 0 + @transient @volatile private var ssc: StreamingContext = null + + protected def setupContext(checkpointDirector: String): StreamingContext + + //---------------------------------------- + + /** + * Run the test and return an option string containing error message. + * @return None is test succeeded, or Some(errorMessage) if test failed + */ + def testAndGetError(): Option[String] = { + DriverFailureTest.reset() + ssc = setupContext(checkpointDir.toString) + run() + } + + private def run(): Option[String] = { + + val runStartTime = System.currentTimeMillis + var killingThread: Thread = null + + def allBatchesCompleted = batchesCompleted >= numBatchesToRun + def timedOut = (System.currentTimeMillis - runStartTime) > timeoutMillis + def failed = DriverFailureTest.failed + + while(!failed && !allBatchesCompleted && !timedOut) { + // Start the thread to kill the streaming after some time + killed = false + try { + ssc.addStreamingListener(new BatchCompletionListener) + ssc.start() + + killingThread = new KillingThread(ssc, batchDurationMillis * 10) + killingThread.start() + + while (!failed && !killed && !allBatchesCompleted && !timedOut) { + ssc.awaitTermination(1) + } + } catch { + case e: Exception => + logError("Error running streaming context", e) + DriverFailureTest.fail("Error running streaming context: " + e) + } + + logInfo(s"Failed = $failed") + logInfo(s"Killed = $killed") + logInfo(s"All batches completed = $allBatchesCompleted") + logInfo(s"Timed out = $timedOut") + + if (killingThread.isAlive) { + killingThread.interrupt() + ssc.stop() + } + + if (!timedOut) { + val sleepTime = Random.nextInt(batchDurationMillis * 10) + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation in " + sleepTime + " ms " + + "\n-------------------------------------------\n" + ) + Thread.sleep(sleepTime) + + // Recreate the streaming context from checkpoint + System.clearProperty("spark.driver.port") + ssc = StreamingContext.getOrCreate(checkpointDir.toString, () => { + throw new Exception("Trying to create new context when it " + + "should be 
reading from checkpoint file") + }) + println("Restarted") + } + } + + if (failed) { + Some(s"Failed with message: ${DriverFailureTest.firstFailureMessage}") + } else if (timedOut) { + Some(s"Timed out after $batchesCompleted/$numBatchesToRun batches, and " + + s"${System.currentTimeMillis} ms (time out = $timeoutMillis ms)") + } else if (allBatchesCompleted) { + None + } else { + throw new Exception("Unexpected end of test") + } + } + + private def createCheckpointDir(): Path = { + // Create the directories for this test + val uuid = UUID.randomUUID().toString + val rootDir = new Path(testDirectory, uuid) + val fs = rootDir.getFileSystem(new Configuration()) + val dir = new Path(rootDir, "checkpoint") + fs.mkdirs(dir) + dir + } + + class BatchCompletionListener extends StreamingListener { + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + if (batchCompleted.batchInfo.batchTime.milliseconds > lastBatchCompleted) { + batchesCompleted += 1 + lastBatchCompleted = batchCompleted.batchInfo.batchTime.milliseconds + } + } + } + + class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread with Logging { + override def run() { + try { + // If it is the first killing, then allow the first checkpoint to be created + var minKillWaitTime = if (killCount == 0) 5000 else 2000 + val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime) + logInfo("Kill wait time = " + killWaitTime) + Thread.sleep(killWaitTime) + logInfo( + "\n---------------------------------------\n" + + "Killing streaming context after " + killWaitTime + " ms" + + "\n---------------------------------------\n" + ) + ssc.stop() + killed = true + killCount += 1 + println("Killed") + logInfo("Killing thread finished normally") + } catch { + case ie: InterruptedException => logInfo("Killing thread interrupted") + case e: Exception => logWarning("Exception in killing thread", e) + } + + } + } +} + +object DriverFailureTest { + @transient @volatile var failed: Boolean = _ + @transient @volatile var firstFailureMessage: String = _ + + def fail(message: String) { + if (!failed) { + failed = true + firstFailureMessage = message + } + } + + def reset() { + failed = false + firstFailureMessage = "NOT SET" + } +} + +class ReceiverBasedDriverFailureTest[T]( + @transient testDirectory: String, + @transient batchDurationMillis: Int, + @transient numBatchesToRun: Int, + @transient operation: DStream[String] => DStream[T], + outputVerifyingFunction: (Time, Seq[T]) => Unit + ) extends DriverFailureTest( + testDirectory, batchDurationMillis, numBatchesToRun + ) { + + @transient val conf = new SparkConf() + conf.setMaster("local[4]") + .setAppName("ReceiverBasedDriverFailureTest") + .set("spark.streaming.receiver.writeAheadLog.enable", "true") // enable write ahead log + .set("spark.streaming.receiver.writeAheadLog.rotationIntervalSecs", "10") // rotate logs to test cleanup + + override def setupContext(checkpointDirector: String): StreamingContext = { + + val context = StreamingContext.getOrCreate(checkpointDirector, () => { + val newSsc = new StreamingContext(conf, Milliseconds(batchDurationMillis)) + val inputStream = newSsc.receiverStream[String](new DriverFailureTestReceiver) + /*inputStream.mapPartitions(iter => { + val sum = iter.map { _.split(" ").size }.fold(0)(_ + _) + Iterator(sum) + }).foreachRDD ((rdd: RDD[Int], time: Time) => { + try { + val collected = rdd.collect().sorted + println(s"# in partitions at $time = ${collected.mkString(", ")}") + } catch { + case ie: 
InterruptedException => + // ignore + case e: Exception => + DriverFailureTest.fail(e.toString) + } + + })*/ + val operatedStream = operation(inputStream) + + val verify = outputVerifyingFunction + operatedStream.foreachRDD((rdd: RDD[T], time: Time) => { + try { + val collected = rdd.collect() + verify(time, collected) + } catch { + case ie: InterruptedException => + // ignore + case e: Exception => + DriverFailureTest.fail(e.toString) + } + }) + newSsc.checkpoint(checkpointDirector) + newSsc + }) + context + } +} + + + +class DriverFailureTestReceiver extends Receiver[String](StorageLevel.MEMORY_ONLY_SER) with Logging { + + import DriverFailureTestReceiver._ + @volatile var thread: Thread = null + + class ReceivingThread extends Thread() { + override def run() { + while (!isStopped() && !isInterrupted()) { + try { + val block = getNextBlock() + store(block) + commitBlock() + Thread.sleep(10) + } catch { + case ie: InterruptedException => + case e: Exception => + DriverFailureTestReceiver.this.stop("Error in receiving thread", e) + } + } + } + } + + def onStart() { + if (thread == null) { + thread = new ReceivingThread() + thread.start() + } else { + logError("Error starting receiver, previous receiver thread not stopped yet.") + } + } + + def onStop() { + if (thread != null) { + thread.interrupt() + thread = null + } + } +} + +object DriverFailureTestReceiver { + val maxRecordsPerBlock = 1000L + private val currentKey = new AtomicInteger() + private val counter = new AtomicInteger() + + counter.set(1) + currentKey.set(1) + + def getNextBlock(): ArrayBuffer[String] = { + val count = counter.get() + new ArrayBuffer ++= (1 to count).map { _ => "word%03d".format(currentKey.get()) } + } + + def commitBlock() { + println(s"Stored ${counter.get()} copies of word${currentKey.get}") + if (counter.incrementAndGet() > maxRecordsPerBlock) { + currentKey.incrementAndGet() + counter.set(1) + } + } +} + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala deleted file mode 100644 index 40434b1f9b709..0000000000000 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -import org.apache.spark.Logging -import org.apache.spark.util.Utils - -import java.io.File - -/** - * This testsuite tests master failures at random times while the stream is running using - * the real clock. 
- */ -class FailureSuite extends TestSuiteBase with Logging { - - var directory = "FailureSuite" - val numBatches = 30 - - override def batchDuration = Milliseconds(1000) - - override def useManualClock = false - - override def beforeFunction() { - super.beforeFunction() - Utils.deleteRecursively(new File(directory)) - } - - override def afterFunction() { - super.afterFunction() - Utils.deleteRecursively(new File(directory)) - } - - test("multiple failures with map") { - MasterFailureTest.testMap(directory, numBatches, batchDuration) - } - - test("multiple failures with updateStateByKey") { - MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration) - } -} - diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala new file mode 100644 index 0000000000000..6b791df092f64 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -0,0 +1,214 @@ +package org.apache.spark.streaming + +import java.io.File + +import scala.Some +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.{implicitConversions, postfixOps} +import scala.util.Random + +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.storage._ +import org.apache.spark.streaming.storage.WriteAheadLogSuite._ +import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock} +import org.apache.spark.util.Utils +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.storage.StreamBlockId +import scala.Some +import org.apache.spark.streaming.storage.BlockAddition + +class ReceivedBlockTrackerSuite + extends FunSuite with BeforeAndAfter with Matchers with Logging { + + val conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") + conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") + + val hadoopConf = new Configuration() + val akkaTimeout = 10 seconds + val streamId = 1 + + var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]() + var checkpointDirectory: File = null + + before { + checkpointDirectory = Files.createTempDir() + } + + after { + allReceivedBlockTrackers.foreach { _.stop() } + if (checkpointDirectory != null && checkpointDirectory.exists()) { + FileUtils.deleteDirectory(checkpointDirectory) + checkpointDirectory = null + } + } + + test("block addition, and block to batch allocation") { + val receivedBlockTracker = createTracker(enableCheckpoint = false) + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + val blockInfos = generateBlockInfos() + blockInfos.map(receivedBlockTracker.addBlock) + + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos + receivedBlockTracker.getOrAllocateBlocksToBatch(1, streamId) shouldEqual blockInfos + receivedBlockTracker.getUnallocatedBlocks(streamId) should have size 0 + receivedBlockTracker.getOrAllocateBlocksToBatch(1, streamId) shouldEqual blockInfos + receivedBlockTracker.getOrAllocateBlocksToBatch(2, streamId) should have size 0 + } + + test("block 
addition, block to batch allocation and cleanup with write ahead log") { + val manualClock = new ManualClock + conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", -1) should be (1) + + // Set the time increment to twice the rotation interval so that every increment creates + // a new log file + val timeIncrementMillis = 2000L + def incrementTime() { + manualClock.addToTime(timeIncrementMillis) + } + + // Generate and add blocks to the given tracker + def addBlockInfos(tracker: ReceivedBlockTracker): Seq[ReceivedBlockInfo] = { + val blockInfos = generateBlockInfos() + blockInfos.map(tracker.addBlock) + blockInfos + } + + // Print the data present in the write ahead log files in the log directory + def printLogFiles(message: String) { + val fileContents = getWriteAheadLogFiles().map { file => + (s"\n>>>>> $file: <<<<<\n${getWrittenLogData(file).mkString("\n")}") + }.mkString("\n") + logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n") + } + + // Start tracker and add blocks + val tracker1 = createTracker(enableCheckpoint = true, clock = manualClock) + val blockInfos1 = addBlockInfos(tracker1) + tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Verify whether write ahead log has correct contents + val expectedWrittenData1 = blockInfos1.map(BlockAddition) + getWrittenLogData() shouldEqual expectedWrittenData1 + getWriteAheadLogFiles() should have size 1 + + // Restart tracker and verify recovered list of unallocated blocks + incrementTime() + val tracker2 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Allocate blocks to batch and verify whether the unallocated blocks got allocated + val batchTime1 = manualClock.currentTime + tracker2.getOrAllocateBlocksToBatch(batchTime1, streamId) shouldEqual blockInfos1 + + // Add more blocks and allocate to another batch + incrementTime() + val batchTime2 = manualClock.currentTime + val blockInfos2 = addBlockInfos(tracker2) + tracker2.getOrAllocateBlocksToBatch(batchTime2, streamId) shouldEqual blockInfos2 + + // Verify whether log has correct contents + val expectedWrittenData2 = expectedWrittenData1 ++ + Seq(createBatchAllocation(batchTime1, blockInfos1)) ++ + blockInfos2.map(BlockAddition) ++ + Seq(createBatchAllocation(batchTime2, blockInfos2)) + getWrittenLogData() shouldEqual expectedWrittenData2 + + // Restart tracker and verify recovered state + incrementTime() + val tracker3 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker3.getOrAllocateBlocksToBatch(batchTime1, streamId) shouldEqual blockInfos1 + tracker3.getOrAllocateBlocksToBatch(batchTime2, streamId) shouldEqual blockInfos2 + tracker3.getUnallocatedBlocks(streamId) shouldBe empty + + // Clean up the first batch but not the second batch + val oldestLogFile = getWriteAheadLogFiles().head + incrementTime() + tracker3.cleanupOldBatches(batchTime2) + + // Verify that the batch allocations have been cleaned, and the cleanup has been written to the log + tracker3.getOrAllocateBlocksToBatch(batchTime1, streamId) shouldEqual Seq.empty + getWrittenLogData(getWriteAheadLogFiles().last) should contain(createBatchCleanup(batchTime1)) + + // Verify that at least one log file gets deleted + eventually(timeout(10 seconds), interval(10 milliseconds)) { + getWriteAheadLogFiles() should not contain oldestLogFile + } + printLogFiles("After cleanup") + + // Restart tracker and verify recovered state, specifically 
whether info about the first + // batch has been removed, but not the second batch + incrementTime() + val tracker4 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker4.getUnallocatedBlocks(streamId) shouldBe empty + tracker4.getOrAllocateBlocksToBatch(batchTime1, streamId) shouldBe empty // should be cleaned + tracker4.getOrAllocateBlocksToBatch(batchTime2, streamId) shouldEqual blockInfos2 + } + + /** + * Create tracker object with the optional provided clock. Use fake clock if you + * want to control time by manually incrementing it to test log cleanup. + */ + def createTracker(enableCheckpoint: Boolean, clock: Clock = new SystemClock): ReceivedBlockTracker = { + val cpDirOption = if (enableCheckpoint) Some(checkpointDirectory.toString) else None + val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption) + allReceivedBlockTrackers += tracker + tracker + } + + /** Generate blocks infos using random ids */ + def generateBlockInfos(): Seq[ReceivedBlockInfo] = { + (1 to 5).map { id => + new ReceivedBlockInfo(streamId, + StreamBlockId(streamId, math.abs(Random.nextInt())), 0, null, None) + }.toList + } + + /** Get all the data written in the given write ahead log file. */ + def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerRecord] = { + getWrittenLogData(Seq(logFile)) + } + + /** + * Get all the data written in the given write ahead log files. By default, it will read all + * files in the test log directory. + */ + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerRecord] = { + logFiles.flatMap { + file => new WriteAheadLogReader(file, hadoopConf).toSeq + }.map { byteBuffer => + Utils.deserialize[ReceivedBlockTrackerRecord](byteBuffer.array) + }.toList + } + + /** Get all the write ahead log files in the test directory */ + def getWriteAheadLogFiles(): Seq[String] = { + import ReceivedBlockTracker._ + val logDir = checkpointDirToLogDir(checkpointDirectory.toString) + getLogFilesInDirectory(new File(logDir)).map { _.toString } + } + + /** Create batch allocation object from the given info */ + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocations = { + BatchAllocations(time, AllocatedBlocks(Map((streamId -> blockInfos)))) + } + + /** Create batch cleanup object from the given info */ + def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanup = { + BatchCleanup((Seq(time) ++ moreTimes).map(Time.apply)) + } + + implicit def millisToTime(milliseconds: Long): Time = Time(milliseconds) + + implicit def timeToMillis(time: Time): Long = time.milliseconds + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/storage/rdd/HDFSBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/HDFSBackedBlockRDDSuite.scala similarity index 100% rename from streaming/src/test/scala/org/apache/spark/streaming/storage/rdd/HDFSBackedBlockRDDSuite.scala rename to streaming/src/test/scala/org/apache/spark/streaming/rdd/HDFSBackedBlockRDDSuite.scala diff --git a/streaming/src/test/scala/org/apache/spark/streaming/storage/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/storage/ReceivedBlockHandlerSuite.scala new file mode 100644 index 0000000000000..654bc7bad92b6 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/storage/ReceivedBlockHandlerSuite.scala @@ -0,0 +1,138 @@ +package org.apache.spark.streaming.storage + +import java.io.File + +import 
scala.concurrent.duration._ +import scala.language.postfixOps + +import akka.actor.{ActorSystem, Props} +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} +import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.storage._ +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.util.AkkaUtils +import WriteAheadLogBasedBlockHandler._ +import WriteAheadLogSuite._ + +class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers { + + val conf = new SparkConf() + .set("spark.authenticate", "false") + .set("spark.kryoserializer.buffer.mb", "1") + .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val hadoopConf = new Configuration() + val storageLevel = StorageLevel.MEMORY_ONLY_SER + val streamId = 1 + val securityMgr = new SecurityManager(conf) + val mapOutputTracker = new MapOutputTrackerMaster(conf) + val shuffleManager = new HashShuffleManager(conf) + val serializer = new KryoSerializer(conf) + val manualClock = new ManualClock + + var actorSystem: ActorSystem = null + var blockManagerMaster: BlockManagerMaster = null + var blockManager: BlockManager = null + var receivedBlockHandler: ReceivedBlockHandler = null + var tempDirectory: File = null + + before { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + "test", "localhost", 0, conf = conf, securityManager = securityMgr) + this.actorSystem = actorSystem + + conf.set("spark.driver.port", boundPort.toString) + + blockManagerMaster = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + conf, true) + blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, 100000, + conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr)) + tempDirectory = Files.createTempDir() + manualClock.setTime(0) + } + + after { + actorSystem.shutdown() + actorSystem.awaitTermination() + actorSystem = null + blockManagerMaster = null + + if (blockManager != null) { + blockManager.stop() + blockManager = null + } + + if (tempDirectory != null && tempDirectory.exists()) { + FileUtils.deleteDirectory(tempDirectory) + tempDirectory = null + } + } + + test("WriteAheadLogBasedBlockHandler - store block") { + createWriteAheadLogBasedBlockHandler() + val (data, blockIds) = generateAndStoreData(receivedBlockHandler) + receivedBlockHandler.asInstanceOf[WriteAheadLogBasedBlockHandler].stop() + + // Check whether blocks inserted in the block manager are correct + val blockManagerData = blockIds.flatMap { blockId => + blockManager.getLocal(blockId).map { _.data }.getOrElse(Iterator.empty) + } + blockManagerData.toList shouldEqual data.toList + + // Check whether the blocks written to the write ahead log are correct + val logFiles = getWriteAheadLogFiles() + logFiles.size should be > 1 + + val logData = logFiles.flatMap { + file => new WriteAheadLogReader(file.toString, hadoopConf).toSeq + }.flatMap { blockManager.dataDeserialize(StreamBlockId(streamId, 0), _ )} + logData.toList shouldEqual data.toList + } + + test("WriteAheadLogBasedBlockHandler - clear old 
blocks") { + createWriteAheadLogBasedBlockHandler() + generateAndStoreData(receivedBlockHandler) + val preCleanupLogFiles = getWriteAheadLogFiles() + preCleanupLogFiles.size should be > 1 + + // this depends on the number of blocks inserted using generateAndStoreData() + manualClock.currentTime() shouldEqual 5000L + + val cleanupThreshTime = 3000L + receivedBlockHandler.cleanupOldBlock(cleanupThreshTime) + eventually(timeout(10000 millis), interval(10 millis)) { + getWriteAheadLogFiles().size should be < preCleanupLogFiles.size + } + } + + def createWriteAheadLogBasedBlockHandler() { + receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1, + storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) + } + + def generateAndStoreData( + receivedBlockHandler: ReceivedBlockHandler): (Seq[String], Seq[StreamBlockId]) = { + val data = (1 to 100).map { _.toString } + val blocks = data.grouped(10).map { _.toIterator }.toSeq + val blockIds = (0 until blocks.size).map { i => StreamBlockId(streamId, i) } + blocks.zip(blockIds).foreach { case (block, id) => + manualClock.addToTime(500) // log rolling interval set to 1000 ms through SparkConf + receivedBlockHandler.storeBlock(id, IteratorBlock(block)) + } + (data, blockIds) + } + + def getWriteAheadLogFiles(): Seq[File] = { + getLogFilesInDirectory( + new File(checkpointDirToLogDir(tempDirectory.toString, streamId))) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala index ed21bdbb399fd..88b2b5095ceb6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala @@ -16,145 +16,262 @@ */ package org.apache.spark.streaming.storage -import java.io.{RandomAccessFile, File} +import java.io.{DataInputStream, FileInputStream, File, RandomAccessFile} import java.nio.ByteBuffer -import java.util.Random + +import scala.util.Random import scala.collection.mutable.ArrayBuffer +import scala.language.implicitConversions import com.google.common.io.Files import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.commons.io.FileUtils +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.util.Utils +import WriteAheadLogSuite._ -import org.apache.spark.streaming.TestSuiteBase +class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { -class WriteAheadLogSuite extends TestSuiteBase { - val hadoopConf = new Configuration() - val random = new Random() + var tempDirectory: File = null - test("Test successful writes") { - val dir = Files.createTempDir() - val file = new File(dir, "TestWriter") - try { - val dataToWrite = for (i <- 1 to 50) yield generateRandomData() - val writer = new WriteAheadLogWriter("file:///" + file.toString, hadoopConf) - val segments = dataToWrite.map(writer.write) - writer.close() - val writtenData = readData(segments, file) - assert(writtenData.toArray === dataToWrite.toArray) - } finally { - file.delete() - dir.delete() + before { + tempDirectory = Files.createTempDir() + } + + after { + if (tempDirectory != null && tempDirectory.exists()) { + FileUtils.deleteDirectory(tempDirectory) + tempDirectory = null } } - test("Test successful reads using random reader") { - val file = File.createTempFile("TestRandomReads", "") - file.deleteOnExit() - val writtenData = writeData(50, 
file) - val reader = new WriteAheadLogRandomReader("file:///" + file.toString, hadoopConf) - writtenData.foreach { - x => - val length = x._1.remaining() - assert(x._1 === reader.read(new FileSegment(file.toString, x._2, length))) + test("WriteAheadLogWriter - writing data") { + val file = new File(tempDirectory, Random.nextString(10)) + val dataToWrite = generateRandomData() + val writer = new WriteAheadLogWriter("file:///" + file, hadoopConf) + val segments = dataToWrite.map(data => writer.write(data)) + writer.close() + val writtenData = readDataManually(file, segments) + assert(writtenData.toArray === dataToWrite.toArray) + } + + test("WriteAheadLogWriter - syncing of data by writing and reading immediately") { + val file = new File(tempDirectory, Random.nextString(10)) + val dataToWrite = generateRandomData() + val writer = new WriteAheadLogWriter("file:///" + file, hadoopConf) + dataToWrite.foreach { data => + val segment = writer.write(data) + assert(readDataManually(file, Seq(segment)).head === data) } - reader.close() + writer.close() } - test("Test reading data using random reader written with writer") { - val dir = Files.createTempDir() - val file = new File(dir, "TestRandomReads") - try { - val dataToWrite = for (i <- 1 to 50) yield generateRandomData() - val segments = writeUsingWriter(file, dataToWrite) - val iter = dataToWrite.iterator - val reader = new WriteAheadLogRandomReader("file:///" + file.toString, hadoopConf) - val writtenData = segments.map { x => - reader.read(x) - } - assert(dataToWrite.toArray === writtenData.toArray) - } finally { - file.delete() - dir.delete() + test("WriteAheadLogReader - sequentially reading data") { + // Write data manually for testing the sequential reader + val file = File.createTempFile("TestSequentialReads", "", tempDirectory) + val writtenData = generateRandomData() + writeDataManually(writtenData, file) + val reader = new WriteAheadLogReader("file:///" + file.toString, hadoopConf) + val readData = reader.toSeq.map(byteBufferToString) + assert(readData.toList === writtenData.toList) + assert(reader.hasNext === false) + intercept[Exception] { + reader.next() } + reader.close() } - test("Test successful reads using sequential reader") { - val file = File.createTempFile("TestSequentialReads", "") - file.deleteOnExit() - val writtenData = writeData(50, file) + test("WriteAheadLogReader - sequentially reading data written with writer") { + // Write data using the writer for testing the sequential reader + val file = new File(tempDirectory, "TestWriter") + val dataToWrite = generateRandomData() + writeDataUsingWriter(file, dataToWrite) + val iter = dataToWrite.iterator val reader = new WriteAheadLogReader("file:///" + file.toString, hadoopConf) - val iter = writtenData.iterator - iter.foreach { x => - assert(reader.hasNext === true) - assert(reader.next() === x._1) + reader.foreach { byteBuffer => + assert(byteBufferToString(byteBuffer) === iter.next()) } reader.close() } + test("WriteAheadLogRandomReader - reading data using random reader") { + // Write data manually for testing the random reader + val file = File.createTempFile("TestRandomReads", "", tempDirectory) + val writtenData = generateRandomData() + val segments = writeDataManually(writtenData, file) - - test("Test reading data using sequential reader written with writer") { - val dir = Files.createTempDir() - val file = new File(dir, "TestWriter") - try { - val dataToWrite = for (i <- 1 to 50) yield generateRandomData() - val segments = writeUsingWriter(file, dataToWrite) - val iter = 
dataToWrite.iterator - val reader = new WriteAheadLogReader("file:///" + file.toString, hadoopConf) - reader.foreach { x => - assert(x === iter.next()) - } - } finally { - file.delete() - dir.delete() + + // Get a random order of these segments and read them back + val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten + val reader = new WriteAheadLogRandomReader("file:///" + file.toString, hadoopConf) + writtenDataAndSegments.foreach { case (data, segment) => + assert(data === byteBufferToString(reader.read(segment))) } + reader.close() + } + + test("WriteAheadLogRandomReader - reading data using random reader written with writer") { + // Write data using writer for testing the random reader + val file = new File(tempDirectory, "TestRandomReads") + val data = generateRandomData() + val segments = writeDataUsingWriter(file, data) + + // Read a random sequence of segments and verify read data + val dataAndSegments = data.zip(segments).toSeq.permutations.take(10).flatten + val reader = new WriteAheadLogRandomReader("file:///" + file.toString, hadoopConf) + dataAndSegments.foreach { case(data, segment) => + assert(data === byteBufferToString(reader.read(segment))) + } + reader.close() + } + + test("WriteAheadLogManager - write rotating logs") { + // Write data using manager + val dataToWrite = generateRandomData(10) + writeDataUsingManager(tempDirectory, dataToWrite) + + // Read data manually to verify the written data + val logFiles = getLogFilesInDirectory(tempDirectory) + assert(logFiles.size > 1) + val writtenData = logFiles.flatMap { file => readDataManually(file) } + assert(writtenData.toList === dataToWrite.toList) + } + + test("WriteAheadLogManager - read rotating logs") { + // Write data manually for testing reading through manager + val writtenData = (1 to 10).map { i => + val data = generateRandomData(10) + val file = new File(tempDirectory, s"log-$i-${i + 1}") + writeDataManually(data, file) + data + }.flatten + + // Read data using manager and verify + val readData = readDataUsingManager(tempDirectory) + assert(readData.toList === writtenData.toList) + } + + test("WriteAheadLogManager - recover past logs when creating new manager") { + // Write data with manager, recover with new manager and verify + val dataToWrite = generateRandomData(100) + writeDataUsingManager(tempDirectory, dataToWrite) + val logFiles = getLogFilesInDirectory(tempDirectory) + assert(logFiles.size > 1) + val readData = readDataUsingManager(tempDirectory) + assert(dataToWrite.toList === readData.toList) } + // TODO (Hari, TD): Test different failure conditions of writers and readers. + // - Failure in the middle of write + // - Failure while reading incomplete/corrupt file +} + +object WriteAheadLogSuite { + + private val hadoopConf = new Configuration() + /** - * Writes data to the file and returns the an array of the bytes written. - * @param count - * @return + * Write data to the file and return the file segments that were written. + * This is used to test the WAL reader independently of the WAL writer. */ - // We don't want to be using the WAL writer to test the reader - it will be painful to figure - // out where the bug is. Instead generate the file by hand and see if the WAL reader can - // handle it. 
- def writeData(count: Int, file: File): ArrayBuffer[(ByteBuffer, Long)] = { - val writtenData = new ArrayBuffer[(ByteBuffer, Long)]() + def writeDataManually(data: Seq[String], file: File): Seq[FileSegment] = { + val segments = new ArrayBuffer[FileSegment]() val writer = new RandomAccessFile(file, "rw") - var i = 0 - while (i < count) { - val data = generateRandomData() - writtenData += ((data, writer.getFilePointer)) - data.rewind() - writer.writeInt(data.remaining()) - writer.write(data.array()) - i += 1 + data.foreach { item => + val offset = writer.getFilePointer() + val bytes = Utils.serialize(item) + writer.writeInt(bytes.size) + writer.write(bytes) + segments += FileSegment(file.toString, offset, bytes.size) } writer.close() - writtenData + segments } - def readData(segments: Seq[FileSegment], file: File): Seq[ByteBuffer] = { + def writeDataUsingWriter(file: File, data: Seq[String]): Seq[FileSegment] = { + val writer = new WriteAheadLogWriter(file.toString, hadoopConf) + val segments = data.map { + item => writer.write(item) + } + writer.close() + segments + } + + def writeDataUsingManager(logDirectory: File, data: Seq[String]) { + val fakeClock = new ManualClock + val manager = new WriteAheadLogManager(logDirectory.toString, hadoopConf, + rollingIntervalSecs = 1, callerName = "WriteAheadLogSuite", clock = fakeClock) + data.foreach { item => + fakeClock.addToTime(500) + manager.writeToLog(item) + } + manager.stop() + } + + /** + * Read data from the given segments of log file and returns the read list of byte buffers. + * This is used to test the WAL writer independently of the WAL reader. + */ + def readDataManually(file: File, segments: Seq[FileSegment]): Seq[String] = { val reader = new RandomAccessFile(file, "r") segments.map { x => reader.seek(x.offset) val data = new Array[Byte](x.length) reader.readInt() reader.readFully(data) - ByteBuffer.wrap(data) + Utils.deserialize[String](data) + } + } + + def readDataManually(file: File): Seq[String] = { + val reader = new DataInputStream(new FileInputStream(file)) + val buffer = new ArrayBuffer[String] + try { + while (reader.available > 0) { + val length = reader.readInt() + val bytes = new Array[Byte](length) + reader.read(bytes) + buffer += Utils.deserialize[String](bytes) + } + } finally { + reader.close() } + buffer } - def generateRandomData(): ByteBuffer = { - val data = new Array[Byte](random.nextInt(50)) - random.nextBytes(data) - ByteBuffer.wrap(data) + def readDataUsingManager(logDirectory: File): Seq[String] = { + val manager = new WriteAheadLogManager(logDirectory.toString, hadoopConf, + callerName = "WriteAheadLogSuite") + val data = manager.readFromLog().map(byteBufferToString).toSeq + manager.stop() + data } - def writeUsingWriter(file: File, input: Seq[ByteBuffer]): Seq[FileSegment] = { - val writer = new WriteAheadLogWriter(file.toString, hadoopConf) - val segments = input.map(writer.write) - writer.close() - segments + def generateRandomData(numItems: Int = 50, itemSize: Int = 50): Seq[String] = { + (1 to numItems).map { _.toString } + } + + def getLogFilesInDirectory(directory: File): Seq[File] = { + if (directory.exists) { + directory.listFiles().filter(_.getName().startsWith("log-")) + .sortBy(_.getName.split("-")(1).toLong) + } else { + Seq.empty + } + } + + def printData(data: Seq[String]) { + println("# items in data = " + data.size) + println(data.mkString("\n")) + } + + implicit def stringToByteBuffer(str: String): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize(str)) + } + + implicit def 
byteBufferToString(byteBuffer: ByteBuffer): String = { + Utils.deserialize[String](byteBuffer.array) } }
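
The suites above exercise the new write ahead log classes one at a time. The Scala sketch below is only an illustration, inferred from the test code in this patch, of how the pieces fit together: records are written through WriteAheadLogWriter (each write returning a FileSegment), read back in order through WriteAheadLogReader, and read back individually through WriteAheadLogRandomReader. The example object name, the log path, the payload bytes, and the assumption that these classes are reachable from the org.apache.spark.streaming.storage package are illustrative only and not part of the patch.

// Illustrative sketch (not part of the patch): round-trips a few records through the WAL
// writer and both readers, mirroring the usage shown in WriteAheadLogSuite.
package org.apache.spark.streaming.storage

import java.nio.ByteBuffer

import org.apache.hadoop.conf.Configuration

object WriteAheadLogRoundTripExample {
  def main(args: Array[String]): Unit = {
    val hadoopConf = new Configuration()
    val logPath = "file:///tmp/wal-example/log-1-2" // hypothetical log file location

    // Each write returns a FileSegment (path, offset, length) locating the record in the log.
    val writer = new WriteAheadLogWriter(logPath, hadoopConf)
    val segments = (1 to 3).map { i =>
      writer.write(ByteBuffer.wrap(s"record-$i".getBytes("UTF-8")))
    }
    writer.close()

    // Sequential reader: iterates over every record in the log file, in write order.
    val sequentialReader = new WriteAheadLogReader(logPath, hadoopConf)
    sequentialReader.foreach { buffer =>
      println(new String(buffer.array(), "UTF-8"))
    }
    sequentialReader.close()

    // Random reader: fetches a single record directly through its FileSegment.
    val randomReader = new WriteAheadLogRandomReader(logPath, hadoopConf)
    val second = randomReader.read(segments(1))
    println(new String(second.array(), "UTF-8"))
    randomReader.close()
  }
}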