[SPARK-1022][Streaming] Add Kafka real unit test #1751

Closed · wants to merge 6 commits
external/kafka/pom.xml: 6 additions, 0 deletions
@@ -70,6 +70,12 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>net.sf.jopt-simple</groupId>
<artifactId>jopt-simple</artifactId>
<version>3.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
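The jopt-simple artifact added above is scoped to test, presumably because the Kafka broker and admin classes started by the new embedded-broker suites need it on the test classpath only. As a rough illustration, the equivalent declaration in an sbt build would look like the single line below (a hedged sketch only; Spark's own sbt build wires module dependencies differently):

libraryDependencies += "net.sf.jopt-simple" % "jopt-simple" % "3.2" % "test"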
external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -17,31 +17,118 @@

package org.apache.spark.streaming.kafka;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;

import scala.Predef;
import scala.Tuple2;
import scala.collection.JavaConverters;

import junit.framework.Assert;

import kafka.serializer.StringDecoder;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.LocalJavaStreamingContext;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

import org.junit.Test;
import org.junit.After;
import org.junit.Before;

public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable {
private transient KafkaStreamSuite testSuite = new KafkaStreamSuite();

@Before
@Override
public void setUp() {
testSuite.beforeFunction();
System.clearProperty("spark.driver.port");
//System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock");
ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
}

@After
@Override
public void tearDown() {
ssc.stop();
ssc = null;
System.clearProperty("spark.driver.port");
testSuite.afterFunction();
}

@Test
public void testKafkaStream() throws InterruptedException {
String topic = "topic1";
HashMap<String, Integer> topics = new HashMap<String, Integer>();
topics.put(topic, 1);

HashMap<String, Integer> sent = new HashMap<String, Integer>();
sent.put("a", 5);
sent.put("b", 3);
sent.put("c", 10);

testSuite.createTopic(topic);
HashMap<String, Object> tmp = new HashMap<String, Object>(sent);
testSuite.produceAndSendMessage(topic,
JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap(
Predef.<Tuple2<String, Object>>conforms()));

HashMap<String, String> kafkaParams = new HashMap<String, String>();
kafkaParams.put("zookeeper.connect", testSuite.zkConnect());
kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000));
kafkaParams.put("auto.offset.reset", "smallest");

JavaPairDStream<String, String> stream = KafkaUtils.createStream(ssc,
String.class,
String.class,
StringDecoder.class,
StringDecoder.class,
kafkaParams,
topics,
StorageLevel.MEMORY_ONLY_SER());

final HashMap<String, Long> result = new HashMap<String, Long>();

JavaDStream<String> words = stream.map(
new Function<Tuple2<String, String>, String>() {
@Override
public String call(Tuple2<String, String> tuple2) throws Exception {
return tuple2._2();
}
}
);

words.countByValue().foreachRDD(
new Function<JavaPairRDD<String, Long>, Void>() {
@Override
public Void call(JavaPairRDD<String, Long> rdd) throws Exception {
List<Tuple2<String, Long>> ret = rdd.collect();
for (Tuple2<String, Long> r : ret) {
if (result.containsKey(r._1())) {
result.put(r._1(), result.get(r._1()) + r._2());
} else {
result.put(r._1(), r._2());
}
}

return null;
}
}
);

ssc.start();
ssc.awaitTermination(3000);

Assert.assertEquals(sent.size(), result.size());
for (String k : sent.keySet()) {
Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue());
}
}
}
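The Java suite above produces a known set of messages, reads them back through a Kafka input DStream, and folds each micro-batch's (word, count) pairs into a shared HashMap inside foreachRDD before comparing the totals against what was sent. A minimal, Spark-free Scala sketch of that merge step (the object and method names here are illustrative, not part of the patch):

import scala.collection.mutable

object BatchMergeSketch {
  // Running totals keyed by word, shared across batches, as in the test's `result` map.
  val result = new mutable.HashMap[String, Long]()

  // Add one micro-batch's (word, count) pairs into the running totals.
  def mergeBatch(batch: Seq[(String, Long)]): Unit = {
    batch.foreach { case (word, count) =>
      result.put(word, result.getOrElse(word, 0L) + count)
    }
  }

  def main(args: Array[String]): Unit = {
    mergeBatch(Seq("a" -> 2L, "b" -> 1L))
    mergeBatch(Seq("a" -> 3L, "c" -> 10L))
    assert(result("a") == 5L)   // counts accumulate across batches
    assert(result("c") == 10L)
  }
}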
external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
@@ -17,28 +17,193 @@

package org.apache.spark.streaming.kafka

import java.io.File
import java.net.InetSocketAddress
import java.util.{Properties, Random}

import scala.collection.mutable

import kafka.admin.CreateTopicCommand
import kafka.common.TopicAndPartition
import kafka.producer.{KeyedMessage, ProducerConfig, Producer}
import kafka.utils.ZKStringSerializer
import kafka.serializer.{StringDecoder, StringEncoder}
import kafka.server.{KafkaConfig, KafkaServer}

import org.I0Itec.zkclient.ZkClient

import org.apache.zookeeper.server.ZooKeeperServer
import org.apache.zookeeper.server.NIOServerCnxnFactory

import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.util.Utils

class KafkaStreamSuite extends TestSuiteBase {
import KafkaTestUtils._

val zkConnect = "localhost:2181"
val zkConnectionTimeout = 6000
val zkSessionTimeout = 6000

val brokerPort = 9092
val brokerProps = getBrokerConfig(brokerPort, zkConnect)
val brokerConf = new KafkaConfig(brokerProps)

protected var zookeeper: EmbeddedZookeeper = _
protected var zkClient: ZkClient = _
protected var server: KafkaServer = _
protected var producer: Producer[String, String] = _

override def useManualClock = false

override def beforeFunction() {
// Zookeeper server startup
zookeeper = new EmbeddedZookeeper(zkConnect)
logInfo("==================== 0 ====================")
zkClient = new ZkClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer)
logInfo("==================== 1 ====================")

test("kafka input stream") {
// Kafka broker startup
server = new KafkaServer(brokerConf)
logInfo("==================== 2 ====================")
server.startup()
logInfo("==================== 3 ====================")
Thread.sleep(2000)
logInfo("==================== 4 ====================")
super.beforeFunction()
}

override def afterFunction() {
producer.close()
server.shutdown()
brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) }

zkClient.close()
zookeeper.shutdown()

super.afterFunction()
}

test("Kafka input stream") {
val ssc = new StreamingContext(master, framework, batchDuration)
val topic = "topic1"
val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
createTopic(topic)
produceAndSendMessage(topic, sent)

val kafkaParams = Map("zookeeper.connect" -> zkConnect,
"group.id" -> s"test-consumer-${random.nextInt(10000)}",
"auto.offset.reset" -> "smallest")

val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
ssc,
kafkaParams,
Map(topic -> 1),
StorageLevel.MEMORY_ONLY)
val result = new mutable.HashMap[String, Long]()
stream.map { case (k, v) => v }
.countByValue()
.foreachRDD { r =>
val ret = r.collect()
ret.toMap.foreach { kv =>
val count = result.getOrElseUpdate(kv._1, 0) + kv._2
result.put(kv._1, count)
}
}
ssc.start()
ssc.awaitTermination(3000)

assert(sent.size === result.size)
sent.keys.foreach { k => assert(sent(k) === result(k).toInt) }

ssc.stop()
}

private def createTestMessage(topic: String, sent: Map[String, Int])
: Seq[KeyedMessage[String, String]] = {
val messages = for ((s, freq) <- sent; i <- 0 until freq) yield {
new KeyedMessage[String, String](topic, s)
}
messages.toSeq
}

def createTopic(topic: String) {
CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0")
logInfo("==================== 5 ====================")
// wait until metadata is propagated
waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000)
}

def produceAndSendMessage(topic: String, sent: Map[String, Int]) {
val brokerAddr = brokerConf.hostName + ":" + brokerConf.port
producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr)))
producer.send(createTestMessage(topic, sent): _*)
logInfo("==================== 6 ====================")
}
}

object KafkaTestUtils {
val random = new Random()

def getBrokerConfig(port: Int, zkConnect: String): Properties = {
val props = new Properties()
props.put("broker.id", "0")
props.put("host.name", "localhost")
props.put("port", port.toString)
props.put("log.dir", Utils.createTempDir().getAbsolutePath)
props.put("zookeeper.connect", zkConnect)
props.put("log.flush.interval.messages", "1")
props.put("replica.socket.timeout.ms", "1500")
props
}

def getProducerConfig(brokerList: String): Properties = {
val props = new Properties()
props.put("metadata.broker.list", brokerList)
props.put("serializer.class", classOf[StringEncoder].getName)
props
}

def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = {
val startTime = System.currentTimeMillis()
while (true) {
if (condition())
return true
if (System.currentTimeMillis() > startTime + waitTime)
return false
Thread.sleep(waitTime.min(100L))
}
// Should never reach here
throw new RuntimeException("unexpected error")
}

def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int,
timeout: Long) {
assert(waitUntilTrue(() =>
servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains(
TopicAndPartition(topic, partition))), timeout),
s"Partition [$topic, $partition] metadata not propagated after timeout")
}

class EmbeddedZookeeper(val zkConnect: String) {
val random = new Random()
val snapshotDir = Utils.createTempDir()
val logDir = Utils.createTempDir()

val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500)
val (ip, port) = {
val splits = zkConnect.split(":")
(splits(0), splits(1).toInt)
}
val factory = new NIOServerCnxnFactory()
factory.configure(new InetSocketAddress(ip, port), 16)
factory.startup(zookeeper)

def shutdown() {
factory.shutdown()
Utils.deleteRecursively(snapshotDir)
Utils.deleteRecursively(logDir)
}
}
}
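Outside the test harness, the consumer-side call pattern this suite exercises is the regular KafkaUtils.createStream API. A minimal application-style sketch is shown below; the ZooKeeper address, group id, and topic name are placeholders rather than values taken from this patch, and the local master/batch interval are chosen only for illustration:

import kafka.serializer.StringDecoder

import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

object KafkaStreamSketch {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext("local[2]", "kafka-stream-sketch", Seconds(1))

    // Placeholder connection settings; point these at a real ZooKeeper/Kafka deployment.
    val kafkaParams = Map(
      "zookeeper.connect" -> "localhost:2181",
      "group.id" -> "example-consumer-group",
      "auto.offset.reset" -> "smallest")

    val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, Map("topic1" -> 1), StorageLevel.MEMORY_ONLY)

    // Same shape as the test: keep the message value and count occurrences per batch.
    stream.map(_._2).countByValue().print()

    ssc.start()
    ssc.awaitTermination()
  }
}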