diff --git a/streaming/pom.xml b/streaming/pom.xml
index 12f900c91eb98..002c35f251390 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -68,6 +68,11 @@
      <artifactId>junit-interface</artifactId>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>
  </dependencies>
  <build>
    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
index f33c0ceafdf42..544cf57587e57 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
@@ -26,7 +26,7 @@ import org.apache.spark.Logging
import org.apache.spark.streaming.Time
private[streaming]
-class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
+class DStreamCheckpointData[T: ClassTag](dstream: DStream[T])
extends Serializable with Logging {
protected val data = new HashMap[Time, AnyRef]()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index 391e40924f38a..f1645066cc815 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -21,10 +21,11 @@ import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
import org.apache.spark.rdd.{BlockRDD, RDD}
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{StorageLevel, BlockId}
import org.apache.spark.streaming._
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
+import org.apache.spark.streaming.storage.rdd.HDFSBackedBlockRDD
/**
* Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]]
@@ -39,9 +40,6 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext)
extends InputDStream[T](ssc_) {
- /** Keeps all received blocks information */
- private lazy val receivedBlockInfo = new HashMap[Time, Array[ReceivedBlockInfo]]
-
/** This is an unique identifier for the network input stream. */
val id = ssc.getNewReceiverStreamId()
@@ -62,19 +60,27 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
// If this is called for any time before the start time of the context,
// then this returns an empty RDD. This may happen when recovering from a
// master failure
- if (validTime >= graph.startTime) {
- val blockInfo = ssc.scheduler.receiverTracker.getReceivedBlockInfo(id)
- receivedBlockInfo(validTime) = blockInfo
- val blockIds = blockInfo.map(_.blockId.asInstanceOf[BlockId])
- Some(new BlockRDD[T](ssc.sc, blockIds))
+ val blockRDD = if (validTime >= graph.startTime) {
+ val blockInfo = getReceivedBlockInfo(validTime)
+ val blockIds = blockInfo.map { _.blockId.asInstanceOf[BlockId] }.toArray
+ val fileSegments = blockInfo.flatMap(_.fileSegmentOption).toArray
+ logInfo("Stream " + id + " allocated blocks: " + blockInfo.map(_.blockId).mkString(", "))
+
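+ // If the blocks have associated write ahead log segments, create a WAL-backed RDD that can
+ // re-read the data from the log when a block is not found in the BlockManager; otherwise
+ // fall back to a plain BlockRDD.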
+ if (fileSegments.nonEmpty) {
+ new HDFSBackedBlockRDD[T](ssc.sparkContext, ssc.sparkContext.hadoopConfiguration,
+ blockIds, fileSegments, storeInBlockManager = false, StorageLevel.MEMORY_ONLY_SER)
+ } else {
+ new BlockRDD[T](ssc.sc, blockIds)
+ }
} else {
- Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
+ new BlockRDD[T](ssc.sc, Array[BlockId]())
}
+ Some(blockRDD)
}
/** Get information on received blocks. */
- private[streaming] def getReceivedBlockInfo(time: Time) = {
- receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo])
+ private[streaming] def getReceivedBlockInfo(time: Time): Seq[ReceivedBlockInfo] = {
+ ssc.scheduler.receiverTracker.getReceivedBlocks(time, id)
}
/**
@@ -85,10 +91,6 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
*/
private[streaming] override def clearMetadata(time: Time) {
super.clearMetadata(time)
- val oldReceivedBlocks = receivedBlockInfo.filter(_._1 <= (time - rememberDuration))
- receivedBlockInfo --= oldReceivedBlocks.keys
- logDebug("Cleared " + oldReceivedBlocks.size + " RDDs that were older than " +
- (time - rememberDuration) + ": " + oldReceivedBlocks.keys.mkString(", "))
+ ssc.scheduler.receiverTracker.cleanupOldInfo(time - rememberDuration)
}
-}
-
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/rdd/HDFSBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/HDFSBackedBlockRDD.scala
similarity index 97%
rename from streaming/src/main/scala/org/apache/spark/streaming/storage/rdd/HDFSBackedBlockRDD.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/rdd/HDFSBackedBlockRDD.scala
index c672574ee2ed4..25f8363454394 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/storage/rdd/HDFSBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/HDFSBackedBlockRDD.scala
@@ -68,6 +68,7 @@ class HDFSBackedBlockRDD[T: ClassTag](
block.data.asInstanceOf[Iterator[T]]
// Data not found in Block Manager, grab it from HDFS
case None =>
+ logInfo("Reading partition data from write ahead log " + partition.segment.path)
val reader = new WriteAheadLogRandomReader(partition.segment.path, hadoopConf)
val dataRead = reader.read(partition.segment)
reader.close()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index 53a3e6200e340..97677491cd0b2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -25,16 +25,13 @@ import scala.concurrent.Await
import akka.actor.{Actor, Props}
import akka.pattern.ask
-
import com.google.common.base.Throwables
-
-import org.apache.spark.{Logging, SparkEnv}
-import org.apache.spark.streaming.scheduler._
-import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.hadoop.conf.Configuration
+import org.apache.spark.{SparkException, Logging, SparkEnv}
import org.apache.spark.storage.StreamBlockId
-import org.apache.spark.streaming.scheduler.DeregisterReceiver
-import org.apache.spark.streaming.scheduler.AddBlock
-import org.apache.spark.streaming.scheduler.RegisterReceiver
+import org.apache.spark.streaming.scheduler._
+import org.apache.spark.streaming.storage._
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]]
@@ -44,12 +41,26 @@ import org.apache.spark.streaming.scheduler.RegisterReceiver
*/
private[streaming] class ReceiverSupervisorImpl(
receiver: Receiver[_],
- env: SparkEnv
+ env: SparkEnv,
+ hadoopConf: Configuration,
+ checkpointDirOption: Option[String]
) extends ReceiverSupervisor(receiver, env.conf) with Logging {
- private val blockManager = env.blockManager
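+ // Use a WriteAheadLogBasedBlockHandler if the write ahead log is enabled (this requires the
+ // checkpoint directory to be set), otherwise store received blocks only in the BlockManager.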
+ private val receivedBlockHandler: ReceivedBlockHandler = {
+ if (env.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) {
+ if (checkpointDirOption.isEmpty) {
+ throw new SparkException(
+ "Cannot enable receiver write-ahead log without checkpoint directory set. " +
+ "Please use streamingContext.checkpoint() to set the checkpoint directory. " +
+ "See documentation for more details.")
+ }
+ new WriteAheadLogBasedBlockHandler(env.blockManager, receiver.streamId,
+ receiver.storageLevel, env.conf, hadoopConf, checkpointDirOption.get)
+ } else {
+ new BlockManagerBasedBlockHandler(env.blockManager, receiver.streamId, receiver.storageLevel)
+ }
+ }
- private val storageLevel = receiver.storageLevel
/** Remote Akka actor for the ReceiverTracker */
private val trackerActor = {
@@ -108,11 +119,7 @@ private[streaming] class ReceiverSupervisorImpl(
optionalMetadata: Option[Any],
optionalBlockId: Option[StreamBlockId]
) {
- val blockId = optionalBlockId.getOrElse(nextBlockId)
- val time = System.currentTimeMillis
- blockManager.putArray(blockId, arrayBuffer.toArray[Any], storageLevel, tellMaster = true)
- logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms")
- reportPushedBlock(blockId, arrayBuffer.size, optionalMetadata)
+ pushAndReportBlock(ArrayBufferBlock(arrayBuffer), optionalMetadata, optionalBlockId)
}
/** Store a iterator of received data as a data block into Spark's memory. */
@@ -121,11 +128,7 @@ private[streaming] class ReceiverSupervisorImpl(
optionalMetadata: Option[Any],
optionalBlockId: Option[StreamBlockId]
) {
- val blockId = optionalBlockId.getOrElse(nextBlockId)
- val time = System.currentTimeMillis
- blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true)
- logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms")
- reportPushedBlock(blockId, -1, optionalMetadata)
+ pushAndReportBlock(IteratorBlock(iterator), optionalMetadata, optionalBlockId)
}
/** Store the bytes of received data as a data block into Spark's memory. */
@@ -134,17 +137,32 @@ private[streaming] class ReceiverSupervisorImpl(
optionalMetadata: Option[Any],
optionalBlockId: Option[StreamBlockId]
) {
+ pushAndReportBlock(ByteBufferBlock(bytes), optionalMetadata, optionalBlockId)
+ }
+
+ /** Store a received block and report it to the driver. */
+ def pushAndReportBlock(
+ receivedBlock: ReceivedBlock,
+ optionalMetadata: Option[Any],
+ optionalBlockId: Option[StreamBlockId]
+ ) {
val blockId = optionalBlockId.getOrElse(nextBlockId)
+ val numRecords = receivedBlock match {
+ case ArrayBufferBlock(arrayBuffer) => arrayBuffer.size
+ case _ => -1
+ }
+
val time = System.currentTimeMillis
- blockManager.putBytes(blockId, bytes, storageLevel, tellMaster = true)
- logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms")
- reportPushedBlock(blockId, -1, optionalMetadata)
- }
+ val fileSegmentOption = receivedBlockHandler.storeBlock(blockId, receivedBlock) match {
+ case Some(f: FileSegment) => Some(f)
+ case _ => None
+ }
+ logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms")
- /** Report pushed block */
- def reportPushedBlock(blockId: StreamBlockId, numRecords: Long, optionalMetadata: Option[Any]) {
- val blockInfo = ReceivedBlockInfo(streamId, blockId, numRecords, optionalMetadata.orNull)
- trackerActor ! AddBlock(blockInfo)
+ val blockInfo = ReceivedBlockInfo(streamId,
+ blockId, numRecords, optionalMetadata.orNull, fileSegmentOption)
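+ // Report the block to the driver and wait for the acknowledgement, so that the block is
+ // reliably tracked by the ReceiverTracker before the store() call returns.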
+ val future = trackerActor.ask(AddBlock(blockInfo))(askTimeout)
+ Await.result(future, askTimeout)
logDebug("Reported block " + blockId)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 374848358e700..ee00abdf3f4f0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -222,10 +222,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
case Success(jobs) =>
val receivedBlockInfo = graph.getReceiverInputStreams.map { stream =>
val streamId = stream.id
- val receivedBlockInfo = stream.getReceivedBlockInfo(time)
+ val receivedBlockInfo = stream.getReceivedBlockInfo(time).toArray
(streamId, receivedBlockInfo)
}.toMap
- jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo))
+ jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo.toMap))
case Failure(e) =>
jobScheduler.reportError("Error generating jobs for time " + time, e)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 5307fe189d717..f2aa02d065222 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -17,23 +17,30 @@
package org.apache.spark.streaming.scheduler
+import java.nio.ByteBuffer
+
import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue}
import scala.language.existentials
import akka.actor._
-import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.SparkContext._
+import org.apache.spark.SerializableWritable
import org.apache.spark.storage.StreamBlockId
import org.apache.spark.streaming.{StreamingContext, Time}
import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver}
-import org.apache.spark.util.AkkaUtils
+import org.apache.spark.streaming.storage.{ReceivedBlockTracker, FileSegment, WriteAheadLogManager}
+import org.apache.spark.util.Utils
/** Information about blocks received by the receiver */
private[streaming] case class ReceivedBlockInfo(
streamId: Int,
blockId: StreamBlockId,
numRecords: Long,
- metadata: Any
+ metadata: Any,
+ fileSegmentOption: Option[FileSegment]
)
/**
@@ -57,23 +64,28 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err
* This class manages the execution of the receivers of NetworkInputDStreams. Instance of
* this class must be created after all input streams have been added and StreamingContext.start()
* has been called because it needs the final set of input streams at the time of instantiation.
+ *
+ * @param skipReceiverLaunch Do not launch the receiver. This is useful for testing.
*/
private[streaming]
-class ReceiverTracker(ssc: StreamingContext) extends Logging {
+class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false) extends Logging {
- val receiverInputStreams = ssc.graph.getReceiverInputStreams()
- val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*)
- val receiverExecutor = new ReceiverLauncher()
- val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo]
- val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]]
- with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]]
- val timeout = AkkaUtils.askTimeout(ssc.conf)
- val listenerBus = ssc.scheduler.listenerBus
+ private val receiverInputStreams = ssc.graph.getReceiverInputStreams()
+ private val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*)
+ private val receiverExecutor = new ReceiverLauncher()
+ private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo]
+ private val receivedBlockTracker = new ReceivedBlockTracker(
+ ssc.sparkContext.conf,
+ ssc.sparkContext.hadoopConfiguration,
+ receiverInputStreams.map { _.id },
+ ssc.scheduler.clock,
+ Option(ssc.checkpointDir)
+ )
+ private val listenerBus = ssc.scheduler.listenerBus
// actor is created when generator starts.
// This not being null means the tracker has been started and not stopped
- var actor: ActorRef = null
- var currentTime: Time = null
+ private var actor: ActorRef = null
/** Start the actor and receiver execution thread. */
def start() = synchronized {
@@ -84,7 +96,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
if (!receiverInputStreams.isEmpty) {
actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor),
"ReceiverTracker")
- receiverExecutor.start()
+ if (!skipReceiverLaunch) receiverExecutor.start()
logInfo("ReceiverTracker started")
}
}
@@ -93,28 +105,27 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
def stop() = synchronized {
if (!receiverInputStreams.isEmpty && actor != null) {
// First, stop the receivers
- receiverExecutor.stop()
+ if (!skipReceiverLaunch) receiverExecutor.stop()
// Finally, stop the actor
ssc.env.actorSystem.stop(actor)
actor = null
+ receivedBlockTracker.stop()
logInfo("ReceiverTracker stopped")
}
}
-  /** Return all the blocks received from a receiver. */
+  /** Get the blocks allocated to the given batch and stream. */
- def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = {
- val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true)
- logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks")
- receivedBlockInfo.toArray
+ def getReceivedBlocks(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = {
+ receivedBlockTracker.getOrAllocateBlocksToBatch(batchTime, streamId)
}
- private def getReceivedBlockInfoQueue(streamId: Int) = {
- receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo])
+ def cleanupOldInfo(cleanupThreshTime: Time) {
+ receivedBlockTracker.cleanupOldBatches(cleanupThreshTime)
}
/** Register a receiver */
- def registerReceiver(
+ private def registerReceiver(
streamId: Int,
typ: String,
host: String,
@@ -126,12 +137,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
}
receiverInfo(streamId) = ReceiverInfo(
streamId, s"${typ}-${streamId}", receiverActor, true, host)
- ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
+ listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address)
}
/** Deregister a receiver */
- def deregisterReceiver(streamId: Int, message: String, error: String) {
+ private def deregisterReceiver(streamId: Int, message: String, error: String) {
val newReceiverInfo = receiverInfo.get(streamId) match {
case Some(oldInfo) =>
oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error)
@@ -140,7 +151,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error)
}
receiverInfo(streamId) = newReceiverInfo
- ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId)))
+ listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId)))
val messageWithError = if (error != null && !error.isEmpty) {
s"$message - $error"
} else {
@@ -150,14 +161,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
}
/** Add new blocks for the given stream */
- def addBlocks(receivedBlockInfo: ReceivedBlockInfo) {
- getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo
- logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " +
- receivedBlockInfo.blockId)
+ private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = {
+ receivedBlockTracker.addBlock(receivedBlockInfo)
}
/** Report error sent by a receiver */
- def reportError(streamId: Int, message: String, error: String) {
+ private def reportError(streamId: Int, message: String, error: String) {
val newReceiverInfo = receiverInfo.get(streamId) match {
case Some(oldInfo) =>
oldInfo.copy(lastErrorMessage = message, lastError = error)
@@ -166,7 +175,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error)
}
receiverInfo(streamId) = newReceiverInfo
- ssc.scheduler.listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId)))
+ listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId)))
val messageWithError = if (error != null && !error.isEmpty) {
s"$message - $error"
} else {
@@ -177,7 +186,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
/** Check if any blocks are left to be processed */
def hasMoreReceivedBlockIds: Boolean = {
- !receivedBlockInfo.values.forall(_.isEmpty)
+ receivedBlockTracker.hasUnallocatedReceivedBlocks()
}
/** Actor to receive messages from the receivers. */
@@ -187,7 +196,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
registerReceiver(streamId, typ, host, receiverActor, sender)
sender ! true
case AddBlock(receivedBlockInfo) =>
- addBlocks(receivedBlockInfo)
+ sender ! addBlock(receivedBlockInfo)
case ReportError(streamId, message, error) =>
reportError(streamId, message, error)
case DeregisterReceiver(streamId, message, error) =>
@@ -253,6 +262,9 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
ssc.sc.makeRDD(receivers, receivers.size)
}
+ val checkpointDirOption = Option(ssc.checkpointDir)
+ val serializableHadoopConf = new SerializableWritable(ssc.sparkContext.hadoopConfiguration)
+
// Function to start the receiver on the worker node
val startReceiver = (iterator: Iterator[Receiver[_]]) => {
if (!iterator.hasNext) {
@@ -260,9 +272,10 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
"Could not start receiver as object not found.")
}
val receiver = iterator.next()
- val executor = new ReceiverSupervisorImpl(receiver, SparkEnv.get)
- executor.start()
- executor.awaitTermination()
+ val supervisor = new ReceiverSupervisorImpl(
+ receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption)
+ supervisor.start()
+ supervisor.awaitTermination()
}
// Run the dummy Spark job to ensure that all slaves have registered.
// This avoids all the receivers to be scheduled on the same node.
@@ -272,7 +285,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
// Distribute the receivers and start them
logInfo("Starting " + receivers.length + " receivers")
- ssc.sparkContext.runJob(tempRDD, startReceiver)
+ ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
logInfo("All of the receivers have been terminated")
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockHandler.scala
new file mode 100644
index 0000000000000..d0d2370201a8e
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockHandler.scala
@@ -0,0 +1,120 @@
+package org.apache.spark.streaming.storage
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import WriteAheadLogBasedBlockHandler._
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.storage.{BlockManager, StorageLevel, StreamBlockId}
+import org.apache.spark.streaming.util.{Clock, SystemClock}
+import org.apache.spark.util.Utils
+
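+/** Trait representing a block of data received by a receiver, in one of several forms. */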
+private[streaming] sealed trait ReceivedBlock
+private[streaming] case class ArrayBufferBlock(arrayBuffer: ArrayBuffer[_]) extends ReceivedBlock
+private[streaming] case class IteratorBlock(iterator: Iterator[_]) extends ReceivedBlock
+private[streaming] case class ByteBufferBlock(byteBuffer: ByteBuffer) extends ReceivedBlock
+
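+/**
+ * Trait that handles the storing of blocks received by a receiver. The `storeBlock` method
+ * optionally returns a handle (e.g. a [[FileSegment]] in the write ahead log) describing
+ * where the block has been persisted.
+ */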
+private[streaming] trait ReceivedBlockHandler {
+ def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): Option[AnyRef]
+ def cleanupOldBlock(threshTime: Long)
+}
+
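+/**
+ * A [[ReceivedBlockHandler]] that stores received blocks only in the BlockManager.
+ * No handle is returned, as the BlockManager itself keeps track of the stored blocks.
+ */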
+private[streaming] class BlockManagerBasedBlockHandler(
+ blockManager: BlockManager,
+ streamId: Int,
+ storageLevel: StorageLevel
+ ) extends ReceivedBlockHandler {
+
+ def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): Option[AnyRef] = {
+ receivedBlock match {
+ case ArrayBufferBlock(arrayBuffer) =>
+ blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true)
+ case IteratorBlock(iterator) =>
+ blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true)
+ case ByteBufferBlock(byteBuffer) =>
+ blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true)
+ case _ =>
+ throw new SparkException(s"Could not store $blockId to block manager, unexpected block type")
+ }
+ None
+ }
+
+ def cleanupOldBlock(threshTime: Long) {
+ // this is not used as blocks inserted into the BlockManager are cleared by DStream's clearing
+ // of BlockRDDs.
+ }
+}
+
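+/**
+ * A [[ReceivedBlockHandler]] that stores received blocks in the BlockManager and, in parallel,
+ * writes their serialized form to a write ahead log in the checkpoint directory. It returns
+ * the [[FileSegment]] of the log entry, so that the block can be re-read after a failure.
+ */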
+private[streaming] class WriteAheadLogBasedBlockHandler(
+ blockManager: BlockManager,
+ streamId: Int,
+ storageLevel: StorageLevel,
+ conf: SparkConf,
+ hadoopConf: Configuration,
+ checkpointDir: String,
+ clock: Clock = new SystemClock
+ ) extends ReceivedBlockHandler with Logging {
+
+ private val blockStoreTimeout = conf.getInt(
+ "spark.streaming.receiver.blockStoreTimeout", 30).seconds
+ private val rollingInterval = conf.getInt(
+ "spark.streaming.receiver.writeAheadLog.rollingInterval", 60)
+ private val maxFailures = conf.getInt(
+ "spark.streaming.receiver.writeAheadLog.maxFailures", 3)
+
+ private val logManager = new WriteAheadLogManager(
+ checkpointDirToLogDir(checkpointDir, streamId),
+ hadoopConf, rollingInterval, maxFailures,
+ callerName = "WriteAheadLogBasedBlockHandler",
+ clock = clock
+ )
+
+ implicit private val executionContext = ExecutionContext.fromExecutorService(
+ Utils.newDaemonFixedThreadPool(1, "WriteAheadLogBasedBlockHandler"))
+
+ def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): Option[AnyRef] = {
+ val serializedBlock = receivedBlock match {
+ case ArrayBufferBlock(arrayBuffer) =>
+ blockManager.dataSerialize(blockId, arrayBuffer.iterator)
+ case IteratorBlock(iterator) =>
+ blockManager.dataSerialize(blockId, iterator)
+ case ByteBufferBlock(byteBuffer) =>
+ byteBuffer
+ case _ =>
+ throw new SparkException(s"Could not store $blockId to the write ahead log, unexpected block type")
+ }
+
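+ // Store the serialized block in the block manager and the write ahead log in parallel,
+ // and wait for both to complete (or for the configured timeout to expire).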
+ val pushToBlockManagerFuture = Future {
+ blockManager.putBytes(blockId, serializedBlock, storageLevel, tellMaster = true)
+ }
+ val pushToLogFuture = Future {
+ logManager.writeToLog(serializedBlock)
+ }
+ val combinedFuture = for {
+ _ <- pushToBlockManagerFuture
+ fileSegment <- pushToLogFuture
+ } yield fileSegment
+
+ Some(Await.result(combinedFuture, blockStoreTimeout))
+ }
+
+ def cleanupOldBlock(threshTime: Long) {
+ logManager.cleanupOldLogs(threshTime)
+ }
+
+ def stop() {
+ logManager.stop()
+ }
+}
+
+private[streaming] object WriteAheadLogBasedBlockHandler {
+ def checkpointDirToLogDir(checkpointDir: String, streamId: Int): String = {
+ new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockTracker.scala
new file mode 100644
index 0000000000000..866d77dd25489
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/ReceivedBlockTracker.scala
@@ -0,0 +1,173 @@
+package org.apache.spark.streaming.storage
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable
+import scala.language.implicitConversions
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.streaming.Time
+import org.apache.spark.util.Utils
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.util.Clock
+import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
+
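+/** Trait representing a record written by the ReceivedBlockTracker to its write ahead log. */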
+private[streaming] sealed trait ReceivedBlockTrackerRecord
+
+private[streaming] case class BlockAddition(receivedBlockInfo: ReceivedBlockInfo)
+ extends ReceivedBlockTrackerRecord
+private[streaming] case class BatchAllocations(time: Time, allocatedBlocks: AllocatedBlocks)
+ extends ReceivedBlockTrackerRecord
+private[streaming] case class BatchCleanup(times: Seq[Time])
+ extends ReceivedBlockTrackerRecord
+
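+/** Class representing the blocks of all the streams allocated to a batch. */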
+private[streaming]
+case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) {
+ def apply(streamId: Int) = streamIdToAllocatedBlocks(streamId)
+}
+
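+/**
+ * Class that keeps track of the blocks received by the receivers and allocates them to batches
+ * when requested. If a checkpoint directory is provided, every state change (block added,
+ * blocks allocated to a batch, old batches cleaned up) is also written to a write ahead log,
+ * so that the tracker's state can be recovered after a driver failure.
+ */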
+private[streaming]
+class ReceivedBlockTracker(
+ conf: SparkConf,
+ hadoopConf: Configuration,
+ streamIds: Seq[Int],
+ clock: Clock,
+ checkpointDirOption: Option[String]
+ ) extends Logging {
+
+ private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo]
+
+ private val streamIdToUnallocatedBlockInfo = new mutable.HashMap[Int, ReceivedBlockQueue]
+ private val timeToAllocatedBlockInfo = new mutable.HashMap[Time, AllocatedBlocks]
+
+ private val logManagerRollingIntervalSecs = conf.getInt(
+ "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60)
+ private val logManagerOption = checkpointDirOption.map { checkpointDir =>
+ new WriteAheadLogManager(
+ ReceivedBlockTracker.checkpointDirToLogDir(checkpointDir),
+ hadoopConf,
+ rollingIntervalSecs = logManagerRollingIntervalSecs,
+ callerName = "ReceivedBlockHandlerMaster",
+ clock = clock
+ )
+ }
+
+ // Recover block information from write ahead logs
+ recoverFromWriteAheadLogs()
+
+ /** Add received block */
+ def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized {
+ try {
+ writeToLog(BlockAddition(receivedBlockInfo))
+ getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo
+ logDebug(s"Stream ${receivedBlockInfo.streamId} received block ${receivedBlockInfo.blockId}")
+ true
+ } catch {
+ case e: Exception =>
+ logError("Error adding block " + receivedBlockInfo, e)
+ false
+ }
+ }
+
+ /** Get blocks that have been added but not yet allocated to any batch */
+ def getUnallocatedBlocks(streamId: Int): Seq[ReceivedBlockInfo] = synchronized {
+ getReceivedBlockQueue(streamId).toSeq
+ }
+
+ /** Get the blocks allocated to a batch, or allocate blocks to the batch and then get them */
+ def getOrAllocateBlocksToBatch(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = {
+ synchronized {
+ if (!timeToAllocatedBlockInfo.contains(batchTime)) {
+ allocateAllUnallocatedBlocksToBatch(batchTime)
+ }
+ timeToAllocatedBlockInfo(batchTime)(streamId)
+ }
+ }
+
+ /** Check if any blocks are left to be allocated to batches */
+ def hasUnallocatedReceivedBlocks(): Boolean = synchronized {
+ !streamIdToUnallocatedBlockInfo.values.forall(_.isEmpty)
+ }
+
+ /** Clean up block information of old batches */
+ def cleanupOldBatches(cleanupThreshTime: Time): Unit = synchronized {
+ assert(cleanupThreshTime.milliseconds < clock.currentTime())
+ val timesToCleanup = timeToAllocatedBlockInfo.keys.filter { _ < cleanupThreshTime }.toSeq
+ logInfo("Deleting batches " + timesToCleanup)
+ writeToLog(BatchCleanup(timesToCleanup))
+ timeToAllocatedBlockInfo --= timesToCleanup
+ logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds))
+ }
+
+ /** Stop the block tracker */
+ def stop() {
+ logManagerOption.foreach { _.stop() }
+ }
+
+ /** Allocate all unallocated blocks to the given batch */
+ private def allocateAllUnallocatedBlocksToBatch(batchTime: Time): AllocatedBlocks = synchronized {
+ val allocatedBlockInfos = AllocatedBlocks(streamIds.map { streamId =>
+ (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true))
+ }.toMap)
+ writeToLog(BatchAllocations(batchTime, allocatedBlockInfos))
+ timeToAllocatedBlockInfo(batchTime) = allocatedBlockInfos
+ allocatedBlockInfos
+ }
+
+ private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = {
+ streamIdToUnallocatedBlockInfo.getOrElseUpdate(streamId, new ReceivedBlockQueue)
+ }
+
+ /** Write an update to the tracker to the write ahead log */
+ private def writeToLog(record: ReceivedBlockTrackerRecord) {
+ logDebug("Writing to log " + record)
+ logManagerOption.foreach { logManager =>
+ logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record)))
+ }
+ }
+
+ private def recoverFromWriteAheadLogs(): Unit = synchronized {
+ logInfo("Recovering from write ahead logs")
+ val blockIdToBlockInfo = new mutable.HashMap[StreamBlockId, ReceivedBlockInfo]
+
+ def insertAddedBlock(receivedBlockInfo: ReceivedBlockInfo) {
+ logTrace(s"Recovery: Inserting added block $receivedBlockInfo")
+ blockIdToBlockInfo.put(receivedBlockInfo.blockId, receivedBlockInfo)
+ getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo
+ }
+
+ def insertAllocatedBatch(time: Time, allocatedBlocks: AllocatedBlocks) {
+ logTrace(s"Recovery: Inserting allocated batch for time $time to " +
+ s"${allocatedBlocks.streamIdToAllocatedBlocks}")
+ streamIdToUnallocatedBlockInfo.values.foreach { _.clear() }
+ timeToAllocatedBlockInfo.put(time, allocatedBlocks)
+ }
+
+ def cleanupBatches(batchTimes: Seq[Time]) {
+ logTrace(s"Recovery: Cleaning up batches $batchTimes")
+ timeToAllocatedBlockInfo --= batchTimes
+ }
+
+ logManagerOption.foreach { logManager =>
+ logManager.readFromLog().foreach { byteBuffer =>
+ logTrace("Recovering record " + byteBuffer)
+ Utils.deserialize[ReceivedBlockTrackerRecord](byteBuffer.array) match {
+ case BlockAddition(receivedBlockInfo) =>
+ insertAddedBlock(receivedBlockInfo)
+ case BatchAllocations(time, allocatedBlocks) =>
+ insertAllocatedBatch(time, allocatedBlocks)
+ case BatchCleanup(batchTimes) =>
+ cleanupBatches(batchTimes)
+ }
+ }
+ }
+ blockIdToBlockInfo.clear()
+ }
+}
+
+private[streaming] object ReceivedBlockTracker {
+ def checkpointDirToLogDir(checkpointDir: String): String = {
+ new Path(checkpointDir, "receivedBlockMetadata").toString
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogManager.scala
new file mode 100644
index 0000000000000..c70ecb0da4e54
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogManager.scala
@@ -0,0 +1,176 @@
+package org.apache.spark.streaming.storage
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.{ExecutionContext, Future}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsPermission
+import org.apache.spark.Logging
+import org.apache.spark.streaming.storage.WriteAheadLogManager._
+import org.apache.spark.streaming.util.{Clock, SystemClock}
+import org.apache.spark.util.Utils
+
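+/**
+ * Class that manages a set of rolling write ahead log files in a Hadoop-compatible directory.
+ * Records are appended to the current log file, a new log file is started every
+ * rollingIntervalSecs seconds, all recovered and current files can be read back as a single
+ * iterator, and files older than a threshold time can be deleted asynchronously.
+ */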
+private[streaming] class WriteAheadLogManager(
+ logDirectory: String,
+ hadoopConf: Configuration,
+ rollingIntervalSecs: Int = 60,
+ maxFailures: Int = 3,
+ callerName: String = "",
+ clock: Clock = new SystemClock
+ ) extends Logging {
+
+ private val pastLogs = new ArrayBuffer[LogInfo]
+ private val callerNameTag =
+ if (callerName != null && callerName.nonEmpty) s" for $callerName" else ""
+ private val threadpoolName = s"WriteAheadLogManager $callerNameTag"
+ implicit private val executionContext = ExecutionContext.fromExecutorService(
+ Utils.newDaemonFixedThreadPool(1, threadpoolName))
+ override protected val logName = s"WriteAheadLogManager $callerNameTag"
+
+ private var currentLogPath: String = null
+ private var currentLogWriter: WriteAheadLogWriter = null
+ private var currentLogWriterStartTime: Long = -1L
+ private var currentLogWriterStopTime: Long = -1L
+
+ initializeOrRecover()
+
+ def writeToLog(byteBuffer: ByteBuffer): FileSegment = synchronized {
+ var fileSegment: FileSegment = null
+ var failures = 0
+ var lastException: Exception = null
+ var succeeded = false
+ while (!succeeded && failures < maxFailures) {
+ try {
+ fileSegment = getLogWriter(clock.currentTime).write(byteBuffer)
+ succeeded = true
+ } catch {
+ case ex: Exception =>
+ lastException = ex
+ logWarning("Failed to ...")
+ resetWriter()
+ failures += 1
+ }
+ }
+ if (fileSegment == null) {
+ throw lastException
+ }
+ fileSegment
+ }
+
+ def readFromLog(): Iterator[ByteBuffer] = synchronized {
+ val logFilesToRead = pastLogs.map { _.path } ++ Option(currentLogPath)
+ logInfo("Reading from the logs: " + logFilesToRead.mkString("\n"))
+ logFilesToRead.iterator.map { file =>
+ logDebug(s"Creating log reader with $file")
+ new WriteAheadLogReader(file, hadoopConf)
+ }.flatMap { x => x }
+ }
+
+ /**
+ * Delete the log files that are older than the threshold time.
+ *
+ * It's important to note that the threshold time is based on the time stamps used in the log
+ * files, and is therefore based on the local system time. So if there is coordination necessary
+ * between the node calculating the threshTime (say, driver node), and the local system time
+ * (say, worker node), the caller has to take account of possible time skew.
+ */
+ def cleanupOldLogs(threshTime: Long): Unit = {
+ val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } }
+ logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " +
+ s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}")
+
+ def deleteFiles() {
+ oldLogFiles.foreach { logInfo =>
+ try {
+ val path = new Path(logInfo.path)
+ val fs = hadoopConf.synchronized { path.getFileSystem(hadoopConf) }
+ fs.delete(path, true)
+ synchronized { pastLogs -= logInfo }
+ logDebug(s"Cleared log file $logInfo")
+ } catch {
+ case ex: Exception =>
+ logWarning(s"Error clearing log file $logInfo", ex)
+ }
+ }
+ logInfo(s"Cleared log files in $logDirectory older than $threshTime")
+ }
+ if (!executionContext.isShutdown) {
+ Future { deleteFiles() }
+ }
+ }
+
+ def stop(): Unit = synchronized {
+ if (currentLogWriter != null) {
+ currentLogWriter.close()
+ }
+ executionContext.shutdown()
+ logInfo("Stopped log manager")
+ }
+
+ private def getLogWriter(currentTime: Long): WriteAheadLogWriter = synchronized {
+ if (currentLogWriter == null || currentTime > currentLogWriterStopTime) {
+ resetWriter()
+ if (currentLogPath != null) {
+ pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, currentLogPath)
+ }
+ currentLogWriterStartTime = currentTime
+ currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000)
+ val newLogPath = new Path(logDirectory,
+ timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime))
+ currentLogPath = newLogPath.toString
+ currentLogWriter = new WriteAheadLogWriter(currentLogPath, hadoopConf)
+ }
+ currentLogWriter
+ }
+
+ private def initializeOrRecover(): Unit = synchronized {
+ val logDirectoryPath = new Path(logDirectory)
+ val fileSystem = logDirectoryPath.getFileSystem(hadoopConf)
+
+ if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) {
+ val logFileInfo = logFilesToLogInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath })
+ pastLogs.clear()
+ pastLogs ++= logFileInfo
+ logInfo(s"Recovered ${logFileInfo.size} log files from $logDirectory")
+ logDebug(s"Recovered files are:\n${logFileInfo.map(_.path).mkString("\n")}")
+ } else {
+ fileSystem.mkdirs(logDirectoryPath,
+ FsPermission.createImmutable(Integer.parseInt("770", 8).toShort))
+ logInfo(s"Created ${logDirectory} for log files")
+ }
+ }
+
+ private def resetWriter(): Unit = synchronized {
+ if (currentLogWriter != null) {
+ currentLogWriter.close()
+ currentLogWriter = null
+ }
+ }
+}
+
+private[storage] object WriteAheadLogManager {
+
+ case class LogInfo(startTime: Long, endTime: Long, path: String)
+
+ val logFileRegex = """log-(\d+)-(\d+)""".r
+
+ def timeToLogFile(startTime: Long, stopTime: Long): String = {
+ s"log-$startTime-$stopTime"
+ }
+
+ def logFilesToLogInfo(files: Seq[Path]): Seq[LogInfo] = {
+ files.flatMap { file =>
+ logFileRegex.findFirstIn(file.getName()) match {
+ case Some(logFileRegex(startTimeStr, stopTimeStr)) =>
+ val startTime = startTimeStr.toLong
+ val stopTime = stopTimeStr.toLong
+ Some(LogInfo(startTime, stopTime, file.toString))
+ case None =>
+ None
+ }
+ }.sortBy { _.startTime }
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala
index 724549e216e93..5e0dc1d49a89a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala
@@ -20,16 +20,20 @@ import java.io.{EOFException, Closeable}
import java.nio.ByteBuffer
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.Logging
private[streaming] class WriteAheadLogReader(path: String, conf: Configuration)
- extends Iterator[ByteBuffer] with Closeable {
+ extends Iterator[ByteBuffer] with Closeable with Logging {
private val instream = HdfsUtils.getInputStream(path, conf)
private var closed = false
private var nextItem: Option[ByteBuffer] = None
override def hasNext: Boolean = synchronized {
- assertOpen()
+ if (closed) {
+ return false
+ }
+
if (nextItem.isDefined) { // handle the case where hasNext is called without calling next
true
} else {
@@ -38,33 +42,35 @@ private[streaming] class WriteAheadLogReader(path: String, conf: Configuration)
val buffer = new Array[Byte](length)
instream.readFully(buffer)
nextItem = Some(ByteBuffer.wrap(buffer))
+ logTrace("Read next item " + nextItem.get)
true
} catch {
- case e: EOFException => false
- case e: Exception => throw e
+ case e: EOFException =>
+ logDebug("Error reading next item, EOF reached", e)
+ close()
+ false
+ case e: Exception =>
+ logDebug("Error reading next item", e)
+ close()
+ throw e
}
}
}
override def next(): ByteBuffer = synchronized {
- // TODO: Possible error case where there are not enough bytes in the stream
- // TODO: How to handle that?
val data = nextItem.getOrElse {
- throw new IllegalStateException("next called without calling hasNext or after hasNext " +
- "returned false")
+ close()
+ throw new IllegalStateException(
+ "next called without calling hasNext or after hasNext returned false")
}
nextItem = None // Ensure the next hasNext call loads new data.
data
}
override def close(): Unit = synchronized {
+ if (!closed) {
+ instream.close()
+ }
closed = true
- instream.close()
}
-
- private def assertOpen() {
- HdfsUtils.checkState(!closed, "Stream is closed. Create a new Reader to read from the " +
- "file.")
- }
-
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala
index 8a2db8305a7e2..68a1172d7d282 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala
@@ -16,20 +16,38 @@
*/
package org.apache.spark.streaming.storage
-import java.io.Closeable
-import java.lang.reflect.Method
+import java.io._
+import java.net.URI
import java.nio.ByteBuffer
import scala.util.Try
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.FSDataOutputStream
+import org.apache.hadoop.fs.{FSDataOutputStream, FileSystem}
+import org.apache.spark.streaming.storage.FileSegment
private[streaming] class WriteAheadLogWriter(path: String, conf: Configuration) extends Closeable {
- private val stream = HdfsUtils.getOutputStream(path, conf)
- private var nextOffset = stream.getPos
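+ // Write to a plain local file stream when the path is on the local file system, and to an
+ // HDFS output stream (flushed via hflush/sync when available) otherwise.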
+ private val underlyingStream: Either[DataOutputStream, FSDataOutputStream] = {
+ val uri = new URI(path)
+ val defaultFs = FileSystem.getDefaultUri(conf).getScheme
+ val isDefaultLocal = defaultFs == null || defaultFs == "file"
+
+ if ((isDefaultLocal && uri.getScheme == null) || uri.getScheme == "file") {
+ assert(!new File(uri.getPath).exists)
+ Left(new DataOutputStream(new BufferedOutputStream(new FileOutputStream(uri.getPath))))
+ } else {
+ Right(HdfsUtils.getOutputStream(path, conf))
+ }
+ }
+
+ private lazy val hadoopFlushMethod = {
+ val cls = classOf[FSDataOutputStream]
+ Try(cls.getMethod("hflush")).orElse(Try(cls.getMethod("sync"))).toOption
+ }
+
+ private var nextOffset = getPosition()
private var closed = false
- private val hflushMethod = getHflushOrSync()
+
// Data is always written as:
// - Length - Long
@@ -48,8 +66,8 @@ private[streaming] class WriteAheadLogWriter(path: String, conf: Configuration)
stream.write(data.get())
}
}
- hflushOrSync()
- nextOffset = stream.getPos
+ flush()
+ nextOffset = getPosition()
segment
}
@@ -58,17 +76,22 @@ private[streaming] class WriteAheadLogWriter(path: String, conf: Configuration)
stream.close()
}
- private def hflushOrSync() {
- hflushMethod.foreach(_.invoke(stream))
+ private def stream(): DataOutputStream = {
+ underlyingStream.fold(x => x, x => x)
+ }
+
+ private def getPosition(): Long = {
+ underlyingStream match {
+ case Left(localStream) => localStream.size
+ case Right(dfsStream) => dfsStream.getPos()
+ }
}
- private def getHflushOrSync(): Option[Method] = {
- Try {
- Some(classOf[FSDataOutputStream].getMethod("hflush"))
- }.recover {
- case e: NoSuchMethodException =>
- Some(classOf[FSDataOutputStream].getMethod("sync"))
- }.getOrElse(None)
+ private def flush() {
+ underlyingStream match {
+ case Left(localStream) => localStream.flush()
+ case Right(dfsStream) => hadoopFlushMethod.foreach { _.invoke(dfsStream) }
+ }
}
private def assertOpen() {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala
new file mode 100644
index 0000000000000..802b24e47e21c
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DriverFailureSuite.scala
@@ -0,0 +1,410 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.io.File
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.spark.{HashPartitioner, Logging, SparkConf}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.receiver.Receiver
+import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted}
+import org.apache.spark.util.Utils
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * This test suite tests driver failures at random times while the stream is running using
+ * the real clock.
+ */
+class DriverFailureSuite extends TestSuiteBase with Logging {
+
+ var directory = "FailureSuite"
+ val numBatches = 30
+
+ override def batchDuration = Milliseconds(1000)
+
+ override def useManualClock = false
+
+ override def beforeFunction() {
+ super.beforeFunction()
+ Utils.deleteRecursively(new File(directory))
+ }
+
+ override def afterFunction() {
+ super.afterFunction()
+ Utils.deleteRecursively(new File(directory))
+ }
+/*
+ test("multiple failures with map") {
+ MasterFailureTest.testMap(directory, numBatches, batchDuration)
+ }
+
+ test("multiple failures with updateStateByKey") {
+ MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration)
+ }
+*/
+ test("multiple failures with receiver and updateStateByKey") {
+
+
+ val operation = (st: DStream[String]) => {
+
+ val mapPartitionFunc = (iterator: Iterator[String]) => {
+ Iterator(iterator.flatMap(_.split(" ")).map(_ -> 1L).reduce((x, y) => (x._1, x._2 + y._2)))
+ }
+
+ val updateFunc = (iterator: Iterator[(String, Seq[Long], Option[Seq[Long]])]) => {
+ iterator.map { case (key, values, state) =>
+ val combined = (state.getOrElse(Seq.empty) ++ values).sorted
+ if (state.isEmpty || state.get.max != DriverFailureTestReceiver.maxRecordsPerBlock) {
+ val oldState = s"[${ state.map { _.max }.getOrElse(-1) }, ${state.map { _.distinct.sum }.getOrElse(0)}]"
+ val newState = s"[${combined.max}, ${combined.distinct.sum}]"
+ println(s"Updated state for $key: state = $oldState, new values = $values, new state = $newState")
+ }
+ (key, combined)
+ }
+ }
+
+ st.mapPartitions(mapPartitionFunc)
+ .updateStateByKey[Seq[Long]](updateFunc, new HashPartitioner(2), rememberPartitioner = false)
+ .checkpoint(batchDuration * 5)
+ }
+
+ val maxValue = DriverFailureTestReceiver.maxRecordsPerBlock
+ val expectedValues = (1L to maxValue).toSet
+
+ val verify = (time: Time, output: Seq[(String, Seq[Long])]) => {
+ val outputStr = output.map { x => (x._1, x._2.distinct.sum) }.sortBy(_._1).mkString(", ")
+ println(s"State at $time: $outputStr")
+
+ val incompletelyReceivedWords = output.filter { _._2.max < maxValue }
+ if (incompletelyReceivedWords.size > 1) {
+ val debugStr = incompletelyReceivedWords.map { x =>
+ s"""${x._1}: ${x._2.mkString(",")}, sum = ${x._2.distinct.sum}"""
+ }.mkString("\n")
+ throw new Exception(s"Incorrect processing of input, all input not processed:\n$debugStr\n")
+ }
+
+ output.foreach { case (key, values) =>
+ if (!values.forall(expectedValues.contains)) {
+ val sum = values.distinct.sum
+ val debugStr = values.zip(1L to values.size).map {
+ x => if (x._1 == x._2) x._1 else s"[${x._2}]"
+ }.mkString(",") + s", sum = $sum"
+ throw new Exception(s"Incorrect sequence of values in output:\n$debugStr\n")
+ }
+ }
+ }
+
+ val driverTest = new ReceiverBasedDriverFailureTest[(String, Seq[Long])](
+ "./driver-test/", 200, 50, operation, verify)
+ driverTest.testAndGetError().map { errorMessage =>
+ fail(errorMessage)
+ }
+ }
+}
+
+
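+/**
+ * Abstract class for testing driver failure and recovery. It repeatedly starts the streaming
+ * context, kills it after a random delay, restarts it from the checkpoint, and keeps going
+ * until the required number of batches have been completed or the test times out.
+ */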
+abstract class DriverFailureTest(
+ testDirectory: String,
+ batchDurationMillis: Int,
+ numBatchesToRun: Int
+ ) extends Logging {
+
+ @transient private val checkpointDir = createCheckpointDir()
+ @transient private val timeoutMillis = batchDurationMillis * numBatchesToRun * 4
+
+ @transient @volatile private var killed = false
+ @transient @volatile private var killCount = 0
+ @transient @volatile private var lastBatchCompleted = 0L
+ @transient @volatile private var batchesCompleted = 0
+ @transient @volatile private var ssc: StreamingContext = null
+
+ protected def setupContext(checkpointDirector: String): StreamingContext
+
+ //----------------------------------------
+
+ /**
+ * Run the test and return an optional error message.
+ * @return None if the test succeeded, or Some(errorMessage) if the test failed
+ */
+ def testAndGetError(): Option[String] = {
+ DriverFailureTest.reset()
+ ssc = setupContext(checkpointDir.toString)
+ run()
+ }
+
+ private def run(): Option[String] = {
+
+ val runStartTime = System.currentTimeMillis
+ var killingThread: Thread = null
+
+ def allBatchesCompleted = batchesCompleted >= numBatchesToRun
+ def timedOut = (System.currentTimeMillis - runStartTime) > timeoutMillis
+ def failed = DriverFailureTest.failed
+
+ while(!failed && !allBatchesCompleted && !timedOut) {
+ // Start the thread to kill the streaming after some time
+ killed = false
+ try {
+ ssc.addStreamingListener(new BatchCompletionListener)
+ ssc.start()
+
+ killingThread = new KillingThread(ssc, batchDurationMillis * 10)
+ killingThread.start()
+
+ while (!failed && !killed && !allBatchesCompleted && !timedOut) {
+ ssc.awaitTermination(1)
+ }
+ } catch {
+ case e: Exception =>
+ logError("Error running streaming context", e)
+ DriverFailureTest.fail("Error running streaming context: " + e)
+ }
+
+ logInfo(s"Failed = $failed")
+ logInfo(s"Killed = $killed")
+ logInfo(s"All batches completed = $allBatchesCompleted")
+ logInfo(s"Timed out = $timedOut")
+
+ if (killingThread.isAlive) {
+ killingThread.interrupt()
+ ssc.stop()
+ }
+
+ if (!timedOut) {
+ val sleepTime = Random.nextInt(batchDurationMillis * 10)
+ logInfo(
+ "\n-------------------------------------------\n" +
+ " Restarting stream computation in " + sleepTime + " ms " +
+ "\n-------------------------------------------\n"
+ )
+ Thread.sleep(sleepTime)
+
+ // Recreate the streaming context from checkpoint
+ System.clearProperty("spark.driver.port")
+ ssc = StreamingContext.getOrCreate(checkpointDir.toString, () => {
+ throw new Exception("Trying to create new context when it " +
+ "should be reading from checkpoint file")
+ })
+ println("Restarted")
+ }
+ }
+
+ if (failed) {
+ Some(s"Failed with message: ${DriverFailureTest.firstFailureMessage}")
+ } else if (timedOut) {
+ Some(s"Timed out after $batchesCompleted/$numBatchesToRun batches, and " +
+ s"${System.currentTimeMillis} ms (time out = $timeoutMillis ms)")
+ } else if (allBatchesCompleted) {
+ None
+ } else {
+ throw new Exception("Unexpected end of test")
+ }
+ }
+
+ private def createCheckpointDir(): Path = {
+ // Create the directories for this test
+ val uuid = UUID.randomUUID().toString
+ val rootDir = new Path(testDirectory, uuid)
+ val fs = rootDir.getFileSystem(new Configuration())
+ val dir = new Path(rootDir, "checkpoint")
+ fs.mkdirs(dir)
+ dir
+ }
+
+ class BatchCompletionListener extends StreamingListener {
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+ if (batchCompleted.batchInfo.batchTime.milliseconds > lastBatchCompleted) {
+ batchesCompleted += 1
+ lastBatchCompleted = batchCompleted.batchInfo.batchTime.milliseconds
+ }
+ }
+ }
+
+ class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread with Logging {
+ override def run() {
+ try {
+ // If it is the first killing, then allow the first checkpoint to be created
+ val minKillWaitTime = if (killCount == 0) 5000 else 2000
+ val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime)
+ logInfo("Kill wait time = " + killWaitTime)
+ Thread.sleep(killWaitTime)
+ logInfo(
+ "\n---------------------------------------\n" +
+ "Killing streaming context after " + killWaitTime + " ms" +
+ "\n---------------------------------------\n"
+ )
+ ssc.stop()
+ killed = true
+ killCount += 1
+ println("Killed")
+ logInfo("Killing thread finished normally")
+ } catch {
+ case ie: InterruptedException => logInfo("Killing thread interrupted")
+ case e: Exception => logWarning("Exception in killing thread", e)
+ }
+
+ }
+ }
+}
+
+object DriverFailureTest {
+ @transient @volatile var failed: Boolean = _
+ @transient @volatile var firstFailureMessage: String = _
+
+ def fail(message: String) {
+ if (!failed) {
+ failed = true
+ firstFailureMessage = message
+ }
+ }
+
+ def reset() {
+ failed = false
+ firstFailureMessage = "NOT SET"
+ }
+}
+
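+/**
+ * Driver failure test that uses a [[DriverFailureTestReceiver]] with the write ahead log
+ * enabled, applies the given operation on the received stream, and verifies the collected
+ * output of every batch with the given verification function.
+ */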
+class ReceiverBasedDriverFailureTest[T](
+ @transient testDirectory: String,
+ @transient batchDurationMillis: Int,
+ @transient numBatchesToRun: Int,
+ @transient operation: DStream[String] => DStream[T],
+ outputVerifyingFunction: (Time, Seq[T]) => Unit
+ ) extends DriverFailureTest(
+ testDirectory, batchDurationMillis, numBatchesToRun
+ ) {
+
+ @transient val conf = new SparkConf()
+ conf.setMaster("local[4]")
+ .setAppName("ReceiverBasedDriverFailureTest")
+ .set("spark.streaming.receiver.writeAheadLog.enable", "true") // enable write ahead log
+ .set("spark.streaming.receiver.writeAheadLog.rotationIntervalSecs", "10") // rotate logs to test cleanup
+
+ override def setupContext(checkpointDirector: String): StreamingContext = {
+
+ val context = StreamingContext.getOrCreate(checkpointDirector, () => {
+ val newSsc = new StreamingContext(conf, Milliseconds(batchDurationMillis))
+ val inputStream = newSsc.receiverStream[String](new DriverFailureTestReceiver)
+ /*inputStream.mapPartitions(iter => {
+ val sum = iter.map { _.split(" ").size }.fold(0)(_ + _)
+ Iterator(sum)
+ }).foreachRDD ((rdd: RDD[Int], time: Time) => {
+ try {
+ val collected = rdd.collect().sorted
+ println(s"# in partitions at $time = ${collected.mkString(", ")}")
+ } catch {
+ case ie: InterruptedException =>
+ // ignore
+ case e: Exception =>
+ DriverFailureTest.fail(e.toString)
+ }
+
+ })*/
+ val operatedStream = operation(inputStream)
+
+ val verify = outputVerifyingFunction
+ operatedStream.foreachRDD((rdd: RDD[T], time: Time) => {
+ try {
+ val collected = rdd.collect()
+ verify(time, collected)
+ } catch {
+ case ie: InterruptedException =>
+ // ignore
+ case e: Exception =>
+ DriverFailureTest.fail(e.toString)
+ }
+ })
+ newSsc.checkpoint(checkpointDirector)
+ newSsc
+ })
+ context
+ }
+}
+
+
+
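+/**
+ * Receiver that repeatedly stores blocks containing copies of the current word ("word001",
+ * "word002", ...), increasing the number of copies after every stored block until
+ * maxRecordsPerBlock is reached, at which point it moves on to the next word. This pattern
+ * lets the test detect any stored block that was lost across a driver failure.
+ */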
+class DriverFailureTestReceiver extends Receiver[String](StorageLevel.MEMORY_ONLY_SER) with Logging {
+
+ import DriverFailureTestReceiver._
+ @volatile var thread: Thread = null
+
+ class ReceivingThread extends Thread() {
+ override def run() {
+ while (!isStopped() && !isInterrupted()) {
+ try {
+ val block = getNextBlock()
+ store(block)
+ commitBlock()
+ Thread.sleep(10)
+ } catch {
+ case ie: InterruptedException =>
+ case e: Exception =>
+ DriverFailureTestReceiver.this.stop("Error in receiving thread", e)
+ }
+ }
+ }
+ }
+
+ def onStart() {
+ if (thread == null) {
+ thread = new ReceivingThread()
+ thread.start()
+ } else {
+ logError("Error starting receiver, previous receiver thread not stopped yet.")
+ }
+ }
+
+ def onStop() {
+ if (thread != null) {
+ thread.interrupt()
+ thread = null
+ }
+ }
+}
+
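+/**
+ * Generates the receiver's data: each block holds `counter` copies of the current word,
+ * and the word advances once the counter exceeds maxRecordsPerBlock.
+ */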
+object DriverFailureTestReceiver {
+ val maxRecordsPerBlock = 1000L
+ private val currentKey = new AtomicInteger()
+ private val counter = new AtomicInteger()
+
+ counter.set(1)
+ currentKey.set(1)
+
+ def getNextBlock(): ArrayBuffer[String] = {
+ val count = counter.get()
+ new ArrayBuffer ++= (1 to count).map { _ => "word%03d".format(currentKey.get()) }
+ }
+
+ def commitBlock() {
+ println(s"Stored ${counter.get()} copies of word${currentKey.get}")
+ if (counter.incrementAndGet() > maxRecordsPerBlock) {
+ currentKey.incrementAndGet()
+ counter.set(1)
+ }
+ }
+}
+
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
deleted file mode 100644
index 40434b1f9b709..0000000000000
--- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming
-
-import org.apache.spark.Logging
-import org.apache.spark.util.Utils
-
-import java.io.File
-
-/**
- * This testsuite tests master failures at random times while the stream is running using
- * the real clock.
- */
-class FailureSuite extends TestSuiteBase with Logging {
-
- var directory = "FailureSuite"
- val numBatches = 30
-
- override def batchDuration = Milliseconds(1000)
-
- override def useManualClock = false
-
- override def beforeFunction() {
- super.beforeFunction()
- Utils.deleteRecursively(new File(directory))
- }
-
- override def afterFunction() {
- super.afterFunction()
- Utils.deleteRecursively(new File(directory))
- }
-
- test("multiple failures with map") {
- MasterFailureTest.testMap(directory, numBatches, batchDuration)
- }
-
- test("multiple failures with updateStateByKey") {
- MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration)
- }
-}
-
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
new file mode 100644
index 0000000000000..6b791df092f64
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -0,0 +1,214 @@
+package org.apache.spark.streaming
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.language.{implicitConversions, postfixOps}
+import scala.util.Random
+
+import com.google.common.io.Files
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.scheduler._
+import org.apache.spark.streaming.storage._
+import org.apache.spark.streaming.storage.WriteAheadLogSuite._
+import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock}
+import org.apache.spark.util.Utils
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.concurrent.Eventually._
+
+class ReceivedBlockTrackerSuite
+ extends FunSuite with BeforeAndAfter with Matchers with Logging {
+
+ val conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite")
+ conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1")
+
+ val hadoopConf = new Configuration()
+ val akkaTimeout = 10 seconds
+ val streamId = 1
+
+ var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]()
+ var checkpointDirectory: File = null
+
+ before {
+ checkpointDirectory = Files.createTempDir()
+ }
+
+ after {
+ allReceivedBlockTrackers.foreach { _.stop() }
+ if (checkpointDirectory != null && checkpointDirectory.exists()) {
+ FileUtils.deleteDirectory(checkpointDirectory)
+ checkpointDirectory = null
+ }
+ }
+
+ test("block addition, and block to batch allocation") {
+ val receivedBlockTracker = createTracker(enableCheckpoint = false)
+ receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty
+
+ val blockInfos = generateBlockInfos()
+ blockInfos.map(receivedBlockTracker.addBlock)
+
+ receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos
+ receivedBlockTracker.getOrAllocateBlocksToBatch(1, streamId) shouldEqual blockInfos
+ receivedBlockTracker.getUnallocatedBlocks(streamId) should have size 0
+ receivedBlockTracker.getOrAllocateBlocksToBatch(1, streamId) shouldEqual blockInfos
+ receivedBlockTracker.getOrAllocateBlocksToBatch(2, streamId) should have size 0
+ }
+
+ test("block addition, block to batch allocation and cleanup with write ahead log") {
+ val manualClock = new ManualClock
+ conf.getInt(
+ "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", -1) should be (1)
+
+ // Set the time increment level to twice the rotation interval so that every increment creates
+ // a new log file
+ val timeIncrementMillis = 2000L
+ def incrementTime() {
+ manualClock.addToTime(timeIncrementMillis)
+ }
+
+ // Generate and add blocks to the given tracker
+ def addBlockInfos(tracker: ReceivedBlockTracker): Seq[ReceivedBlockInfo] = {
+ val blockInfos = generateBlockInfos()
+ blockInfos.map(tracker.addBlock)
+ blockInfos
+ }
+
+    // Print the data present in the write ahead log files in the log directory
+ def printLogFiles(message: String) {
+ val fileContents = getWriteAheadLogFiles().map { file =>
+ (s"\n>>>>> $file: <<<<<\n${getWrittenLogData(file).mkString("\n")}")
+ }.mkString("\n")
+ logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n")
+ }
+
+ // Start tracker and add blocks
+ val tracker1 = createTracker(enableCheckpoint = true, clock = manualClock)
+ val blockInfos1 = addBlockInfos(tracker1)
+ tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1
+
+ // Verify whether write ahead log has correct contents
+ val expectedWrittenData1 = blockInfos1.map(BlockAddition)
+ getWrittenLogData() shouldEqual expectedWrittenData1
+ getWriteAheadLogFiles() should have size 1
+
+ // Restart tracker and verify recovered list of unallocated blocks
+ incrementTime()
+ val tracker2 = createTracker(enableCheckpoint = true, clock = manualClock)
+ tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1
+
+ // Allocate blocks to batch and verify whether the unallocated blocks got allocated
+ val batchTime1 = manualClock.currentTime
+ tracker2.getOrAllocateBlocksToBatch(batchTime1, streamId) shouldEqual blockInfos1
+
+ // Add more blocks and allocate to another batch
+ incrementTime()
+ val batchTime2 = manualClock.currentTime
+ val blockInfos2 = addBlockInfos(tracker2)
+ tracker2.getOrAllocateBlocksToBatch(batchTime2, streamId) shouldEqual blockInfos2
+
+ // Verify whether log has correct contents
+ val expectedWrittenData2 = expectedWrittenData1 ++
+ Seq(createBatchAllocation(batchTime1, blockInfos1)) ++
+ blockInfos2.map(BlockAddition) ++
+ Seq(createBatchAllocation(batchTime2, blockInfos2))
+ getWrittenLogData() shouldEqual expectedWrittenData2
+
+ // Restart tracker and verify recovered state
+ incrementTime()
+ val tracker3 = createTracker(enableCheckpoint = true, clock = manualClock)
+ tracker3.getOrAllocateBlocksToBatch(batchTime1, streamId) shouldEqual blockInfos1
+ tracker3.getOrAllocateBlocksToBatch(batchTime2, streamId) shouldEqual blockInfos2
+ tracker3.getUnallocatedBlocks(streamId) shouldBe empty
+
+ // Cleanup first batch but not second batch
+ val oldestLogFile = getWriteAheadLogFiles().head
+ incrementTime()
+ tracker3.cleanupOldBatches(batchTime2)
+
+    // Verify that the batch allocations have been cleaned, and the cleanup has been written to the log
+ tracker3.getOrAllocateBlocksToBatch(batchTime1, streamId) shouldEqual Seq.empty
+ getWrittenLogData(getWriteAheadLogFiles().last) should contain(createBatchCleanup(batchTime1))
+
+ // Verify that at least one log file gets deleted
+    eventually(timeout(10 seconds), interval(10 milliseconds)) {
+ getWriteAheadLogFiles() should not contain oldestLogFile
+ }
+ printLogFiles("After cleanup")
+
+ // Restart tracker and verify recovered state, specifically whether info about the first
+ // batch has been removed, but not the second batch
+ incrementTime()
+ val tracker4 = createTracker(enableCheckpoint = true, clock = manualClock)
+ tracker4.getUnallocatedBlocks(streamId) shouldBe empty
+ tracker4.getOrAllocateBlocksToBatch(batchTime1, streamId) shouldBe empty // should be cleaned
+ tracker4.getOrAllocateBlocksToBatch(batchTime2, streamId) shouldEqual blockInfos2
+ }
+
+ /**
+   * Create a tracker object with the given clock. Use a manual clock if you want to
+   * control time by incrementing it manually, for example to test log cleanup.
+ */
+  def createTracker(
+      enableCheckpoint: Boolean,
+      clock: Clock = new SystemClock): ReceivedBlockTracker = {
+ val cpDirOption = if (enableCheckpoint) Some(checkpointDirectory.toString) else None
+ val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption)
+ allReceivedBlockTrackers += tracker
+ tracker
+ }
+
+  /** Generate block infos using random ids */
+ def generateBlockInfos(): Seq[ReceivedBlockInfo] = {
+ (1 to 5).map { id =>
+ new ReceivedBlockInfo(streamId,
+ StreamBlockId(streamId, math.abs(Random.nextInt())), 0, null, None)
+ }.toList
+ }
+
+ /** Get all the data written in the given write ahead log file. */
+ def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerRecord] = {
+ getWrittenLogData(Seq(logFile))
+ }
+
+ /**
+ * Get all the data written in the given write ahead log files. By default, it will read all
+ * files in the test log directory.
+ */
+  def getWrittenLogData(
+      logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerRecord] = {
+ logFiles.flatMap {
+ file => new WriteAheadLogReader(file, hadoopConf).toSeq
+ }.map { byteBuffer =>
+ Utils.deserialize[ReceivedBlockTrackerRecord](byteBuffer.array)
+ }.toList
+ }
+
+ /** Get all the write ahead log files in the test directory */
+ def getWriteAheadLogFiles(): Seq[String] = {
+ import ReceivedBlockTracker._
+ val logDir = checkpointDirToLogDir(checkpointDirectory.toString)
+ getLogFilesInDirectory(new File(logDir)).map { _.toString }
+ }
+
+ /** Create batch allocation object from the given info */
+ def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocations = {
+    BatchAllocations(time, AllocatedBlocks(Map(streamId -> blockInfos)))
+ }
+
+ /** Create batch cleanup object from the given info */
+ def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanup = {
+ BatchCleanup((Seq(time) ++ moreTimes).map(Time.apply))
+ }
+
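+  /** Implicitly convert between Long (milliseconds) and Time so tests can use either interchangeably. */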
+ implicit def millisToTime(milliseconds: Long): Time = Time(milliseconds)
+
+ implicit def timeToMillis(time: Time): Long = time.milliseconds
+
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/storage/rdd/HDFSBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/HDFSBackedBlockRDDSuite.scala
similarity index 100%
rename from streaming/src/test/scala/org/apache/spark/streaming/storage/rdd/HDFSBackedBlockRDDSuite.scala
rename to streaming/src/test/scala/org/apache/spark/streaming/rdd/HDFSBackedBlockRDDSuite.scala
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/storage/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/storage/ReceivedBlockHandlerSuite.scala
new file mode 100644
index 0000000000000..654bc7bad92b6
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/storage/ReceivedBlockHandlerSuite.scala
@@ -0,0 +1,138 @@
+package org.apache.spark.streaming.storage
+
+import java.io.File
+
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import akka.actor.{ActorSystem, Props}
+import com.google.common.io.Files
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
+import org.apache.spark.network.nio.NioBlockTransferService
+import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.storage._
+import org.apache.spark.streaming.util.ManualClock
+import org.apache.spark.util.AkkaUtils
+import WriteAheadLogBasedBlockHandler._
+import WriteAheadLogSuite._
+
+class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers {
+
+ val conf = new SparkConf()
+ .set("spark.authenticate", "false")
+ .set("spark.kryoserializer.buffer.mb", "1")
+ .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1")
+ val hadoopConf = new Configuration()
+ val storageLevel = StorageLevel.MEMORY_ONLY_SER
+ val streamId = 1
+ val securityMgr = new SecurityManager(conf)
+ val mapOutputTracker = new MapOutputTrackerMaster(conf)
+ val shuffleManager = new HashShuffleManager(conf)
+ val serializer = new KryoSerializer(conf)
+ val manualClock = new ManualClock
+
+ var actorSystem: ActorSystem = null
+ var blockManagerMaster: BlockManagerMaster = null
+ var blockManager: BlockManager = null
+ var receivedBlockHandler: ReceivedBlockHandler = null
+ var tempDirectory: File = null
+
+ before {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
+ "test", "localhost", 0, conf = conf, securityManager = securityMgr)
+ this.actorSystem = actorSystem
+
+ conf.set("spark.driver.port", boundPort.toString)
+
+ blockManagerMaster = new BlockManagerMaster(
+ actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
+ conf, true)
+ blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, 100000,
+ conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr))
+ tempDirectory = Files.createTempDir()
+ manualClock.setTime(0)
+ }
+
+ after {
+ actorSystem.shutdown()
+ actorSystem.awaitTermination()
+ actorSystem = null
+ blockManagerMaster = null
+
+ if (blockManager != null) {
+ blockManager.stop()
+ blockManager = null
+ }
+
+ if (tempDirectory != null && tempDirectory.exists()) {
+ FileUtils.deleteDirectory(tempDirectory)
+ tempDirectory = null
+ }
+ }
+
+ test("WriteAheadLogBasedBlockHandler - store block") {
+ createWriteAheadLogBasedBlockHandler()
+ val (data, blockIds) = generateAndStoreData(receivedBlockHandler)
+ receivedBlockHandler.asInstanceOf[WriteAheadLogBasedBlockHandler].stop()
+
+ // Check whether blocks inserted in the block manager are correct
+ val blockManagerData = blockIds.flatMap { blockId =>
+ blockManager.getLocal(blockId).map { _.data }.getOrElse(Iterator.empty)
+ }
+ blockManagerData.toList shouldEqual data.toList
+
+ // Check whether the blocks written to the write ahead log are correct
+ val logFiles = getWriteAheadLogFiles()
+ logFiles.size should be > 1
+
+ val logData = logFiles.flatMap {
+ file => new WriteAheadLogReader(file.toString, hadoopConf).toSeq
+    }.flatMap { blockManager.dataDeserialize(StreamBlockId(streamId, 0), _) }
+ logData.toList shouldEqual data.toList
+ }
+
+ test("WriteAheadLogBasedBlockHandler - clear old blocks") {
+ createWriteAheadLogBasedBlockHandler()
+ generateAndStoreData(receivedBlockHandler)
+ val preCleanupLogFiles = getWriteAheadLogFiles()
+ preCleanupLogFiles.size should be > 1
+
+ // this depends on the number of blocks inserted using generateAndStoreData()
+ manualClock.currentTime() shouldEqual 5000L
+
+ val cleanupThreshTime = 3000L
+ receivedBlockHandler.cleanupOldBlock(cleanupThreshTime)
+ eventually(timeout(10000 millis), interval(10 millis)) {
+ getWriteAheadLogFiles().size should be < preCleanupLogFiles.size
+ }
+ }
+
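+  /** Create the WriteAheadLogBasedBlockHandler under test, backed by the block manager and manual clock. */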
+ def createWriteAheadLogBasedBlockHandler() {
+ receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1,
+ storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock)
+ }
+
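+  /**
+   * Store ten blocks of ten strings each through the given handler, advancing the manual
+   * clock by 500 ms per block so that the write ahead log rolls over during the test.
+   */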
+ def generateAndStoreData(
+ receivedBlockHandler: ReceivedBlockHandler): (Seq[String], Seq[StreamBlockId]) = {
+ val data = (1 to 100).map { _.toString }
+ val blocks = data.grouped(10).map { _.toIterator }.toSeq
+ val blockIds = (0 until blocks.size).map { i => StreamBlockId(streamId, i) }
+ blocks.zip(blockIds).foreach { case (block, id) =>
+ manualClock.addToTime(500) // log rolling interval set to 1000 ms through SparkConf
+ receivedBlockHandler.storeBlock(id, IteratorBlock(block))
+ }
+ (data, blockIds)
+ }
+
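+  /** List the write ahead log files created for this stream under the temporary directory. */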
+ def getWriteAheadLogFiles(): Seq[File] = {
+ getLogFilesInDirectory(
+ new File(checkpointDirToLogDir(tempDirectory.toString, streamId)))
+ }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala
index ed21bdbb399fd..88b2b5095ceb6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala
@@ -16,145 +16,262 @@
*/
package org.apache.spark.streaming.storage
-import java.io.{RandomAccessFile, File}
+import java.io.{DataInputStream, FileInputStream, File, RandomAccessFile}
import java.nio.ByteBuffer
-import java.util.Random
+
+import scala.util.Random
import scala.collection.mutable.ArrayBuffer
+import scala.language.implicitConversions
import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.apache.commons.io.FileUtils
+import org.apache.spark.streaming.util.ManualClock
+import org.apache.spark.util.Utils
+import WriteAheadLogSuite._
-import org.apache.spark.streaming.TestSuiteBase
+class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
-class WriteAheadLogSuite extends TestSuiteBase {
-
val hadoopConf = new Configuration()
- val random = new Random()
+ var tempDirectory: File = null
- test("Test successful writes") {
- val dir = Files.createTempDir()
- val file = new File(dir, "TestWriter")
- try {
- val dataToWrite = for (i <- 1 to 50) yield generateRandomData()
- val writer = new WriteAheadLogWriter("file:///" + file.toString, hadoopConf)
- val segments = dataToWrite.map(writer.write)
- writer.close()
- val writtenData = readData(segments, file)
- assert(writtenData.toArray === dataToWrite.toArray)
- } finally {
- file.delete()
- dir.delete()
+ before {
+ tempDirectory = Files.createTempDir()
+ }
+
+ after {
+ if (tempDirectory != null && tempDirectory.exists()) {
+ FileUtils.deleteDirectory(tempDirectory)
+ tempDirectory = null
}
}
- test("Test successful reads using random reader") {
- val file = File.createTempFile("TestRandomReads", "")
- file.deleteOnExit()
- val writtenData = writeData(50, file)
- val reader = new WriteAheadLogRandomReader("file:///" + file.toString, hadoopConf)
- writtenData.foreach {
- x =>
- val length = x._1.remaining()
- assert(x._1 === reader.read(new FileSegment(file.toString, x._2, length)))
+ test("WriteAheadLogWriter - writing data") {
+    val file = new File(tempDirectory, Random.alphanumeric.take(10).mkString)
+ val dataToWrite = generateRandomData()
+ val writer = new WriteAheadLogWriter("file:///" + file, hadoopConf)
+ val segments = dataToWrite.map(data => writer.write(data))
+ writer.close()
+ val writtenData = readDataManually(file, segments)
+ assert(writtenData.toArray === dataToWrite.toArray)
+ }
+
+ test("WriteAheadLogWriter - syncing of data by writing and reading immediately") {
+    val file = new File(tempDirectory, Random.alphanumeric.take(10).mkString)
+ val dataToWrite = generateRandomData()
+ val writer = new WriteAheadLogWriter("file:///" + file, hadoopConf)
+ dataToWrite.foreach { data =>
+ val segment = writer.write(data)
+ assert(readDataManually(file, Seq(segment)).head === data)
}
- reader.close()
+ writer.close()
}
- test("Test reading data using random reader written with writer") {
- val dir = Files.createTempDir()
- val file = new File(dir, "TestRandomReads")
- try {
- val dataToWrite = for (i <- 1 to 50) yield generateRandomData()
- val segments = writeUsingWriter(file, dataToWrite)
- val iter = dataToWrite.iterator
- val reader = new WriteAheadLogRandomReader("file:///" + file.toString, hadoopConf)
- val writtenData = segments.map { x =>
- reader.read(x)
- }
- assert(dataToWrite.toArray === writtenData.toArray)
- } finally {
- file.delete()
- dir.delete()
+ test("WriteAheadLogReader - sequentially reading data") {
+ // Write data manually for testing the sequential reader
+ val file = File.createTempFile("TestSequentialReads", "", tempDirectory)
+ val writtenData = generateRandomData()
+ writeDataManually(writtenData, file)
+ val reader = new WriteAheadLogReader("file:///" + file.toString, hadoopConf)
+ val readData = reader.toSeq.map(byteBufferToString)
+ assert(readData.toList === writtenData.toList)
+ assert(reader.hasNext === false)
+ intercept[Exception] {
+ reader.next()
}
+ reader.close()
}
- test("Test successful reads using sequential reader") {
- val file = File.createTempFile("TestSequentialReads", "")
- file.deleteOnExit()
- val writtenData = writeData(50, file)
+ test("WriteAheadLogReader - sequentially reading data written with writer") {
+    // Write data using the writer and verify that the sequential reader can read it back
+ val file = new File(tempDirectory, "TestWriter")
+ val dataToWrite = generateRandomData()
+ writeDataUsingWriter(file, dataToWrite)
+ val iter = dataToWrite.iterator
val reader = new WriteAheadLogReader("file:///" + file.toString, hadoopConf)
- val iter = writtenData.iterator
- iter.foreach { x =>
- assert(reader.hasNext === true)
- assert(reader.next() === x._1)
+ reader.foreach { byteBuffer =>
+ assert(byteBufferToString(byteBuffer) === iter.next())
}
reader.close()
}
+ test("WriteAheadLogRandomReader - reading data using random reader") {
+ // Write data manually for testing the random reader
+ val file = File.createTempFile("TestRandomReads", "", tempDirectory)
+ val writtenData = generateRandomData()
+ val segments = writeDataManually(writtenData, file)
- test("Test reading data using sequential reader written with writer") {
- val dir = Files.createTempDir()
- val file = new File(dir, "TestWriter")
- try {
- val dataToWrite = for (i <- 1 to 50) yield generateRandomData()
- val segments = writeUsingWriter(file, dataToWrite)
- val iter = dataToWrite.iterator
- val reader = new WriteAheadLogReader("file:///" + file.toString, hadoopConf)
- reader.foreach { x =>
- assert(x === iter.next())
- }
- } finally {
- file.delete()
- dir.delete()
+ // Get a random order of these segments and read them back
+ val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten
+ val reader = new WriteAheadLogRandomReader("file:///" + file.toString, hadoopConf)
+ writtenDataAndSegments.foreach { case (data, segment) =>
+ assert(data === byteBufferToString(reader.read(segment)))
}
+ reader.close()
+ }
+
+ test("WriteAheadLogRandomReader - reading data using random reader written with writer") {
+ // Write data using writer for testing the random reader
+ val file = new File(tempDirectory, "TestRandomReads")
+ val data = generateRandomData()
+ val segments = writeDataUsingWriter(file, data)
+
+ // Read a random sequence of segments and verify read data
+ val dataAndSegments = data.zip(segments).toSeq.permutations.take(10).flatten
+ val reader = new WriteAheadLogRandomReader("file:///" + file.toString, hadoopConf)
+    dataAndSegments.foreach { case (data, segment) =>
+ assert(data === byteBufferToString(reader.read(segment)))
+ }
+ reader.close()
+ }
+
+ test("WriteAheadLogManager - write rotating logs") {
+ // Write data using manager
+ val dataToWrite = generateRandomData(10)
+ writeDataUsingManager(tempDirectory, dataToWrite)
+
+ // Read data manually to verify the written data
+ val logFiles = getLogFilesInDirectory(tempDirectory)
+ assert(logFiles.size > 1)
+ val writtenData = logFiles.flatMap { file => readDataManually(file) }
+ assert(writtenData.toList === dataToWrite.toList)
+ }
+
+ test("WriteAheadLogManager - read rotating logs") {
+ // Write data manually for testing reading through manager
+ val writtenData = (1 to 10).map { i =>
+ val data = generateRandomData(10)
+ val file = new File(tempDirectory, s"log-$i-${i + 1}")
+ writeDataManually(data, file)
+ data
+ }.flatten
+
+ // Read data using manager and verify
+ val readData = readDataUsingManager(tempDirectory)
+ assert(readData.toList === writtenData.toList)
+ }
+
+ test("WriteAheadLogManager - recover past logs when creating new manager") {
+ // Write data with manager, recover with new manager and verify
+ val dataToWrite = generateRandomData(100)
+ writeDataUsingManager(tempDirectory, dataToWrite)
+ val logFiles = getLogFilesInDirectory(tempDirectory)
+ assert(logFiles.size > 1)
+ val readData = readDataUsingManager(tempDirectory)
+ assert(dataToWrite.toList === readData.toList)
}
+ // TODO (Hari, TD): Test different failure conditions of writers and readers.
+ // - Failure in the middle of write
+ // - Failure while reading incomplete/corrupt file
+}
+
+object WriteAheadLogSuite {
+
+ private val hadoopConf = new Configuration()
+
/**
- * Writes data to the file and returns the an array of the bytes written.
- * @param count
- * @return
+   * Write data manually to the given file and return the file segments written.
+ * This is used to test the WAL reader independently of the WAL writer.
*/
- // We don't want to be using the WAL writer to test the reader - it will be painful to figure
- // out where the bug is. Instead generate the file by hand and see if the WAL reader can
- // handle it.
- def writeData(count: Int, file: File): ArrayBuffer[(ByteBuffer, Long)] = {
- val writtenData = new ArrayBuffer[(ByteBuffer, Long)]()
+ def writeDataManually(data: Seq[String], file: File): Seq[FileSegment] = {
+ val segments = new ArrayBuffer[FileSegment]()
val writer = new RandomAccessFile(file, "rw")
- var i = 0
- while (i < count) {
- val data = generateRandomData()
- writtenData += ((data, writer.getFilePointer))
- data.rewind()
- writer.writeInt(data.remaining())
- writer.write(data.array())
- i += 1
+ data.foreach { item =>
+ val offset = writer.getFilePointer()
+ val bytes = Utils.serialize(item)
+ writer.writeInt(bytes.size)
+ writer.write(bytes)
+ segments += FileSegment(file.toString, offset, bytes.size)
}
writer.close()
- writtenData
+ segments
}
- def readData(segments: Seq[FileSegment], file: File): Seq[ByteBuffer] = {
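+  /** Write data to a log file using the WriteAheadLogWriter and return the written segments. */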
+ def writeDataUsingWriter(file: File, data: Seq[String]): Seq[FileSegment] = {
+ val writer = new WriteAheadLogWriter(file.toString, hadoopConf)
+ val segments = data.map {
+ item => writer.write(item)
+ }
+ writer.close()
+ segments
+ }
+
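+  /**
+   * Write data using a WriteAheadLogManager configured with a manual clock and a one second
+   * rolling interval, advancing the clock between writes so that multiple log files are created.
+   */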
+ def writeDataUsingManager(logDirectory: File, data: Seq[String]) {
+ val fakeClock = new ManualClock
+ val manager = new WriteAheadLogManager(logDirectory.toString, hadoopConf,
+ rollingIntervalSecs = 1, callerName = "WriteAheadLogSuite", clock = fakeClock)
+ data.foreach { item =>
+ fakeClock.addToTime(500)
+ manager.writeToLog(item)
+ }
+ manager.stop()
+ }
+
+ /**
+   * Read data from the given segments of a log file and return the deserialized strings.
+ * This is used to test the WAL writer independently of the WAL reader.
+ */
+ def readDataManually(file: File, segments: Seq[FileSegment]): Seq[String] = {
val reader = new RandomAccessFile(file, "r")
segments.map { x =>
reader.seek(x.offset)
val data = new Array[Byte](x.length)
reader.readInt()
reader.readFully(data)
- ByteBuffer.wrap(data)
+ Utils.deserialize[String](data)
+ }
+ }
+
+ def readDataManually(file: File): Seq[String] = {
+ val reader = new DataInputStream(new FileInputStream(file))
+ val buffer = new ArrayBuffer[String]
+ try {
+ while (reader.available > 0) {
+ val length = reader.readInt()
+ val bytes = new Array[Byte](length)
+        reader.readFully(bytes)
+ buffer += Utils.deserialize[String](bytes)
+ }
+ } finally {
+ reader.close()
}
+ buffer
}
- def generateRandomData(): ByteBuffer = {
- val data = new Array[Byte](random.nextInt(50))
- random.nextBytes(data)
- ByteBuffer.wrap(data)
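+  /** Read all the data in the log directory back through a WriteAheadLogManager, as strings. */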
+ def readDataUsingManager(logDirectory: File): Seq[String] = {
+ val manager = new WriteAheadLogManager(logDirectory.toString, hadoopConf,
+ callerName = "WriteAheadLogSuite")
+ val data = manager.readFromLog().map(byteBufferToString).toSeq
+ manager.stop()
+ data
}
- def writeUsingWriter(file: File, input: Seq[ByteBuffer]): Seq[FileSegment] = {
- val writer = new WriteAheadLogWriter(file.toString, hadoopConf)
- val segments = input.map(writer.write)
- writer.close()
- segments
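+  /** Generate a simple sequence of test strings ("1" through numItems); itemSize is currently unused. */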
+ def generateRandomData(numItems: Int = 50, itemSize: Int = 50): Seq[String] = {
+ (1 to numItems).map { _.toString }
+ }
+
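+  /** List the "log-" files in the given directory, sorted by the first number in their file names. */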
+ def getLogFilesInDirectory(directory: File): Seq[File] = {
+ if (directory.exists) {
+ directory.listFiles().filter(_.getName().startsWith("log-"))
+ .sortBy(_.getName.split("-")(1).toLong)
+ } else {
+ Seq.empty
+ }
+ }
+
+ def printData(data: Seq[String]) {
+ println("# items in data = " + data.size)
+ println(data.mkString("\n"))
+ }
+
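+  /** Implicitly convert between String and ByteBuffer via Utils.serialize / Utils.deserialize. */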
+ implicit def stringToByteBuffer(str: String): ByteBuffer = {
+ ByteBuffer.wrap(Utils.serialize(str))
+ }
+
+ implicit def byteBufferToString(byteBuffer: ByteBuffer): String = {
+ Utils.deserialize[String](byteBuffer.array)
}
}