Merge pull request alteryx#164 from tdas/kafka-fix
Made block generator thread safe to fix Kafka bug.

This is a very important bug fix: because of this race, data could be, and was being, lost in the Kafka input stream.

(cherry picked from commit dfd1ebc)
Signed-off-by: Reynold Xin <[email protected]>
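
For readers unfamiliar with the bug: the block generator's timer thread periodically swaps currentBuffer for a fresh buffer while receiver threads keep appending to it. Without a shared lock, an append can interleave with the swap and land in a buffer that has already been handed off, so the record is silently dropped. Below is a minimal sketch of the pattern the fix enforces (illustrative only; ToyBlockGenerator and drainBlock are hypothetical names, not the Spark API):

import scala.collection.mutable.ArrayBuffer

// Illustrative sketch, not the Spark source. The producer path (+=) and the
// swap path (drainBlock) synchronize on the same object, so an append can
// never land in a buffer that has already been handed off as a block.
class ToyBlockGenerator[T] {
  private var currentBuffer = new ArrayBuffer[T]

  // Called concurrently by many receiver threads.
  def +=(obj: T): Unit = synchronized {
    currentBuffer += obj
  }

  // Called periodically by a timer thread: swap in a fresh buffer and
  // return the old one as a completed "block".
  def drainBlock(): Seq[T] = synchronized {
    val block = currentBuffer
    currentBuffer = new ArrayBuffer[T]
    block
  }
}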
mateiz authored and rxin committed Nov 12, 2013
1 parent 30786c6 commit c856651
Showing 2 changed files with 80 additions and 7 deletions.
@@ -232,11 +232,11 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
logInfo("Data handler stopped")
}

- def += (obj: T) {
+ def += (obj: T): Unit = synchronized {
currentBuffer += obj
}

- private def updateCurrentBuffer(time: Long) {
+ private def updateCurrentBuffer(time: Long): Unit = synchronized {
try {
val newBlockBuffer = currentBuffer
currentBuffer = new ArrayBuffer[T]
@@ -23,15 +23,15 @@ import akka.actor.IOManager
import akka.actor.Props
import akka.util.ByteString

- import dstream.SparkFlumeEvent
+ import org.apache.spark.streaming.dstream.{NetworkReceiver, SparkFlumeEvent}
import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket}
import java.io.{File, BufferedWriter, OutputStreamWriter}
- import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
+ import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import util.ManualClock
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receivers.Receiver
- import org.apache.spark.Logging
+ import org.apache.spark.{SparkContext, Logging}
import scala.util.Random
import org.apache.commons.io.FileUtils
import org.scalatest.BeforeAndAfter
@@ -44,6 +44,7 @@ import java.nio.ByteBuffer
import collection.JavaConversions._
import java.nio.charset.Charset
import com.google.common.io.Files
import java.util.concurrent.atomic.AtomicInteger

class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {

@@ -61,7 +62,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
System.clearProperty("spark.hostPort")
}


test("socket input stream") {
// Start the server
val testServer = new TestServer()
@@ -271,10 +271,49 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group")
val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
}

test("multi-thread receiver") {
// set up the test receiver
val numThreads = 10
val numRecordsPerThread = 1000
val numTotalRecords = numThreads * numRecordsPerThread
val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread)
MultiThreadTestReceiver.haveAllThreadsFinished = false

// set up the network stream using the test receiver
val ssc = new StreamingContext(master, framework, batchDuration)
val networkStream = ssc.networkStream[Int](testReceiver)
val countStream = networkStream.count
val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]]
val outputStream = new TestOutputStream(countStream, outputBuffer)
def output = outputBuffer.flatMap(x => x)
ssc.registerOutputStream(outputStream)
ssc.start()

// Let the data from the receiver be received
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val startTime = System.currentTimeMillis()
while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) &&
System.currentTimeMillis() - startTime < 5000) {
Thread.sleep(100)
clock.addToTime(batchDuration.milliseconds)
}
Thread.sleep(1000)
logInfo("Stopping context")
ssc.stop()

// Verify whether data received was as expected
logInfo("--------------------------------")
logInfo("output.size = " + outputBuffer.size)
logInfo("output")
outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
logInfo("--------------------------------")
assert(output.sum === numTotalRecords)
}
}


- /** This is server to test the network input stream */
+ /** This is a server to test the network input stream */
class TestServer() extends Logging {

val queue = new ArrayBlockingQueue[String](100)
@@ -336,6 +375,7 @@ object TestServer {
}
}

/** This is an actor for testing actor input stream */
class TestActor(port: Int) extends Actor with Receiver {

def bytesToString(byteString: ByteString) = byteString.utf8String
@@ -347,3 +387,36 @@ class TestActor(port: Int) extends Actor with Receiver {
pushBlock(bytesToString(bytes))
}
}

/** This is a receiver to test multiple threads inserting data using block generator */
class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int)
extends NetworkReceiver[Int] {
lazy val executorPool = Executors.newFixedThreadPool(numThreads)
lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
lazy val finishCount = new AtomicInteger(0)

protected def onStart() {
blockGenerator.start()
(1 to numThreads).map(threadId => {
val runnable = new Runnable {
def run() {
(1 to numRecordsPerThread).foreach(i =>
blockGenerator += (threadId * numRecordsPerThread + i) )
if (finishCount.incrementAndGet == numThreads) {
MultiThreadTestReceiver.haveAllThreadsFinished = true
}
logInfo("Finished thread " + threadId)
}
}
executorPool.submit(runnable)
})
}

protected def onStop() {
executorPool.shutdown()
}
}

object MultiThreadTestReceiver {
var haveAllThreadsFinished = false
}
