[SPARK-18617][BACKPORT] Follow up PR to Close "kryo auto pick" feature for Spark Streaming

## What changes were proposed in this pull request?

This is a follow-up PR that backports #16052 to branch-2.0, together with the incremental update from #16091.
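For readers skimming the diff, here is a minimal, self-contained sketch of the behaviour this change introduces. It is not Spark code: the stub block-ID classes, `pickSerializer`, and `autoPickFor` are illustrative stand-ins for `org.apache.spark.storage.BlockId`/`StreamBlockId` and `SerializerManager.getSerializer`.

```scala
import scala.reflect.{classTag, ClassTag}

// Illustrative stand-ins for org.apache.spark.storage.{BlockId, StreamBlockId}.
sealed trait BlockId
final case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId
final case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId

object KryoAutoPickSketch {
  // Simplified version of SerializerManager.canUseKryo: only "simple" types qualify.
  private def canUseKryo(ct: ClassTag[_]): Boolean =
    ct == classTag[String] || ct == classTag[Int] || ct == classTag[Long]

  // Mirrors the new getSerializer(ct, autoPick) signature: Kryo is only auto-picked
  // when the caller explicitly allows it.
  def pickSerializer(ct: ClassTag[_], autoPick: Boolean): String =
    if (autoPick && canUseKryo(ct)) "kryo" else "default (Java)"

  // Call sites derive the flag from the block ID: receiver-generated StreamBlockId
  // blocks never auto-pick Kryo.
  def autoPickFor(blockId: BlockId): Boolean = !blockId.isInstanceOf[StreamBlockId]

  def main(args: Array[String]): Unit = {
    println(pickSerializer(classTag[String], autoPickFor(RDDBlockId(0, 0))))      // kryo
    println(pickSerializer(classTag[String], autoPickFor(StreamBlockId(0, 0L))))  // default (Java)
  }
}
```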

## How was this patch tested?

A new unit test, added to `StreamingContextSuite` below.

cc zsxwing rxin

Author: uncleGen <[email protected]>

Closes #16096 from uncleGen/branch-2.0-SPARK-18617.
uncleGen authored and rxin committed Dec 1, 2016
1 parent 5ecd3c2 commit 6e3fd2b
Showing 4 changed files with 47 additions and 9 deletions.
@@ -72,8 +72,11 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
     primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
   }

-  def getSerializer(ct: ClassTag[_]): Serializer = {
-    if (canUseKryo(ct)) {
+  // SPARK-18617: As feature in SPARK-13990 can not be applied to Spark Streaming now. The worst
+  // result is streaming job based on `Receiver` mode can not run on Spark 2.x properly. It may be
+  // a rational choice to close `kryo auto pick` feature for streaming in the first step.
+  def getSerializer(ct: ClassTag[_], autoPick: Boolean): Serializer = {
+    if (autoPick && canUseKryo(ct)) {
       kryoSerializer
     } else {
       defaultSerializer
@@ -122,7 +125,8 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
       outputStream: OutputStream,
       values: Iterator[T]): Unit = {
     val byteStream = new BufferedOutputStream(outputStream)
-    val ser = getSerializer(implicitly[ClassTag[T]]).newInstance()
+    val autoPick = !blockId.isInstanceOf[StreamBlockId]
+    val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance()
     ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
   }

@@ -138,7 +142,8 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
       classTag: ClassTag[_]): ChunkedByteBuffer = {
     val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
     val byteStream = new BufferedOutputStream(bbos)
-    val ser = getSerializer(classTag).newInstance()
+    val autoPick = !blockId.isInstanceOf[StreamBlockId]
+    val ser = getSerializer(classTag, autoPick).newInstance()
     ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
     bbos.toChunkedByteBuffer
   }
@@ -152,7 +157,8 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
       inputStream: InputStream)
       (classTag: ClassTag[T]): Iterator[T] = {
     val stream = new BufferedInputStream(inputStream)
-    getSerializer(classTag)
+    val autoPick = !blockId.isInstanceOf[StreamBlockId]
+    getSerializer(classTag, autoPick)
       .newInstance()
       .deserializeStream(wrapForCompression(blockId, stream))
       .asIterator.asInstanceOf[Iterator[T]]
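Why key the flag on the block ID rather than on the class tag alone? A plausible reading of SPARK-18560 (exercised by the new test at the end of this commit) is that the element type of receiver blocks is not precisely known on the write path, so a purely type-based auto-pick can choose different serializers on the write and read sides of the same block. The sketch below illustrates that failure mode in isolation; it is a simplified interpretation rather than Spark code, and `chooseByType` is a made-up stand-in for the pre-patch behaviour.

```scala
import scala.reflect.{classTag, ClassTag}

object AutoPickMismatchSketch {
  // Pre-patch behaviour, simplified: the serializer is chosen from the element type only.
  def chooseByType(ct: ClassTag[_]): String =
    if (ct == classTag[String]) "kryo" else "default (Java)"

  def main(args: Array[String]): Unit = {
    // Receiver write path: elements are typically handled as Any, so the choice is Java.
    val writerChoice = chooseByType(classTag[Any])
    // Read path in a task: the consumer knows it wants Strings, so the choice is Kryo.
    val readerChoice = chooseByType(classTag[String])
    // The two sides disagree about the wire format -- the kind of mismatch that deriving
    // the decision from the BlockId (as in this patch) rules out.
    println(s"writer=$writerChoice, reader=$readerChoice, agree=${writerChoice == readerChoice}")
  }
}
```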
@@ -31,7 +31,7 @@ import org.apache.spark.{SparkConf, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.serializer.{SerializationStream, SerializerManager}
-import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
+import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel, StreamBlockId}
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.util.{SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
@@ -332,7 +332,8 @@ private[spark] class MemoryStore(
     val bbos = new ChunkedByteBufferOutputStream(initialMemoryThreshold.toInt, allocator)
     redirectableStream.setOutputStream(bbos)
     val serializationStream: SerializationStream = {
-      val ser = serializerManager.getSerializer(classTag).newInstance()
+      val autoPick = !blockId.isInstanceOf[StreamBlockId]
+      val ser = serializerManager.getSerializer(classTag, autoPick).newInstance()
       ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
     }

@@ -67,7 +67,8 @@ class PartiallySerializedBlockSuite
       spy
     }

-    val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+    val serializer = serializerManager
+      .getSerializer(implicitly[ClassTag[T]], autoPick = true).newInstance()
     val redirectableOutputStream = Mockito.spy(new RedirectableOutputStream)
     redirectableOutputStream.setOutputStream(bbos)
     val serializationStream = Mockito.spy(serializer.serializeStream(redirectableOutputStream))
@@ -182,7 +183,8 @@ class PartiallySerializedBlockSuite
     Mockito.verifyNoMoreInteractions(memoryStore)
     Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()

-    val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+    val serializer = serializerManager
+      .getSerializer(implicitly[ClassTag[T]], autoPick = true).newInstance()
     val deserialized =
       serializer.deserializeStream(new ByteBufferInputStream(bbos.toByteBuffer)).asIterator.toSeq
     assert(deserialized === items)
@@ -18,6 +18,7 @@
 package org.apache.spark.streaming

 import java.io.{File, NotSerializableException}
+import java.util.concurrent.{CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger

 import scala.collection.mutable.ArrayBuffer
@@ -806,6 +807,34 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
     ssc.stop()
   }

+  test("SPARK-18560 Receiver data should be deserialized properly.") {
+    // Start a two nodes cluster, so receiver will use one node, and Spark jobs will use the
+    // other one. Then Spark jobs need to fetch remote blocks and it will trigger SPARK-18560.
+    val conf = new SparkConf().setMaster("local-cluster[2,1,1024]").setAppName(appName)
+    ssc = new StreamingContext(conf, Milliseconds(100))
+    val input = ssc.receiverStream(new TestReceiver)
+    val latch = new CountDownLatch(1)
+    input.count().foreachRDD { rdd =>
+      // Make sure we can read from BlockRDD
+      if (rdd.collect().headOption.getOrElse(0L) > 0) {
+        // Stop StreamingContext to unblock "awaitTerminationOrTimeout"
+        new Thread() {
+          setDaemon(true)
+          override def run(): Unit = {
+            ssc.stop(stopSparkContext = true, stopGracefully = false)
+            latch.countDown()
+          }
+        }.start()
+      }
+    }
+    ssc.start()
+    ssc.awaitTerminationOrTimeout(60000)
+    // Wait until `ssc.stop` returns. Otherwise, we may finish this test too fast and leak an active
+    // SparkContext. Note: the stop codes in `after` will just do nothing if `ssc.stop` in this test
+    // is running.
+    assert(latch.await(60, TimeUnit.SECONDS))
+  }
+
   def addInputStream(s: StreamingContext): DStream[Int] = {
     val input = (1 to 100).map(i => 1 to i)
     val inputStream = new TestInputStream(s, input, 1)
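One detail of the test above worth decoding: the `local-cluster[...]` master URL is Spark's test-only cluster mode, used throughout Spark's own suites, and the bracketed values are the number of workers, cores per worker, and memory per worker in MB. A minimal reference snippet (illustrative only, not part of this commit):

```scala
object LocalClusterMasterSketch extends App {
  // local-cluster[numWorkers, coresPerWorker, memoryPerWorkerMB]: the test asks for two
  // workers with one core and 1024 MB each, so the receiver and the jobs land on different
  // executors and block reads have to fetch remote blocks, as the test comment explains.
  val master = "local-cluster[2,1,1024]"
  println(master)
}
```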
