Commit 3f0fb4b: refactor fix tests

davies committed Sep 27, 2014
1 parent c28f520 commit 3f0fb4b
Showing 6 changed files with 288 additions and 146 deletions.
3 changes: 3 additions & 0 deletions python/pyspark/serializers.py
@@ -114,6 +114,9 @@ def __ne__(self, other):
def __repr__(self):
return "<%s object>" % self.__class__.__name__

def __hash__(self):
return hash(str(self))


class FramedSerializer(Serializer):

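
For reference, the class-based __repr__ means two instances of the same serializer class now hash identically, which lets them collapse in a set; that is what the serializer check in context.py below relies on. A minimal sketch, assuming Serializer.__eq__ also treats same-class instances as equal:

    from pyspark.serializers import PickleSerializer, UTF8Deserializer

    # Same-class serializers hash (and compare) the same, so the set keeps
    # one entry per class rather than one per instance.
    serializers = {PickleSerializer(), PickleSerializer(), UTF8Deserializer()}
    print(len(serializers))  # expected: 2
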
129 changes: 111 additions & 18 deletions python/pyspark/streaming/context.py
@@ -15,16 +15,51 @@
# limitations under the License.
#

from pyspark.serializers import UTF8Deserializer
from pyspark import RDD
from pyspark.serializers import UTF8Deserializer, BatchedSerializer
from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.dstream import DStream
from pyspark.streaming.duration import Duration, Seconds
from pyspark.streaming.duration import Seconds

from py4j.java_collections import ListConverter

__all__ = ["StreamingContext"]


def _daemonize_callback_server():
"""
Hack Py4J to daemonize callback server
"""
# TODO: create a patch for Py4J
import socket
import py4j.java_gateway
logger = py4j.java_gateway.logger
from py4j.java_gateway import Py4JNetworkError
from threading import Thread

def start(self):
"""Starts the CallbackServer. This method should be called by the
client instead of run()."""
self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
1)
try:
self.server_socket.bind((self.address, self.port))
# self.port = self.server_socket.getsockname()[1]
except Exception:
msg = 'An error occurred while trying to start the callback server'
logger.exception(msg)
raise Py4JNetworkError(msg)

# Maybe the thread needs to be cleaned up?
self.thread = Thread(target=self.run)
self.thread.daemon = True
self.thread.start()

py4j.java_gateway.CallbackServer.start = start


class StreamingContext(object):
"""
Main entry point for Spark Streaming functionality. A StreamingContext represents the
@@ -53,7 +88,9 @@ def _start_callback_server(self):
gw = self._sc._gateway
# getattr would fall back to the JVM, so check __dict__ directly
if "_callback_server" not in gw.__dict__:
_daemonize_callback_server()
gw._start_callback_server(gw._python_proxy_port)
gw._python_proxy_port = gw._callback_server.port # update port with real port

def _initialize_context(self, sc, duration):
return self._jvm.JavaStreamingContext(sc._jsc, duration._jduration)
@@ -92,26 +129,44 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):

def remember(self, duration):
"""
Set each DStreams in this context to remember RDDs it generated in the last given duration.
DStreams remember RDDs only for a limited duration of time and releases them for garbage
collection. This method allows the developer to specify how to long to remember the RDDs (
if the developer wishes to query old data outside the DStream computation).
@param duration pyspark.streaming.duration.Duration object or seconds.
Minimum duration that each DStream should remember its RDDs
Set each DStream in this context to remember the RDDs it generated
in the last given duration. DStreams remember RDDs only for a
limited duration of time and release them for garbage collection.
This method allows the developer to specify how long to remember
the RDDs (if the developer wishes to query old data outside the
DStream computation).
@param duration Minimum duration (in seconds) that each DStream
should remember its RDDs
"""
if isinstance(duration, (int, long, float)):
duration = Seconds(duration)

self._jssc.remember(duration._jduration)
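
For reference, a minimal usage sketch of remember(), assuming the StreamingContext is built the same way as in tests.py (the app name is illustrative):

    from pyspark import SparkContext
    from pyspark.streaming.context import StreamingContext
    from pyspark.streaming.duration import Seconds

    sc = SparkContext(appName="remember_example")
    ssc = StreamingContext(sc, duration=Seconds(1))

    # Keep generated RDDs around for 60 seconds; a bare number is wrapped
    # into Seconds() by the code above, a Duration is passed through as-is.
    ssc.remember(60)
    ssc.remember(Seconds(60))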

# TODO: add storageLevel
def socketTextStream(self, hostname, port):
def checkpoint(self, directory):
"""
Sets the context to periodically checkpoint the DStream operations for master
fault-tolerance. The graph will be checkpointed every batch interval.
@param directory HDFS-compatible directory where the checkpoint data
will be reliably stored
"""
self._jssc.checkpoint(directory)
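
Continuing the sketch above, enabling checkpointing is one call; the directory below is illustrative and must be HDFS-compatible:

    # ssc: the StreamingContext from the remember() sketch above.
    # The DStream graph is then checkpointed every batch interval.
    ssc.checkpoint("hdfs:///tmp/streaming_checkpoint")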

def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
"""
Create an input stream from a TCP source hostname:port. Data is received
using a TCP socket, and the received bytes are interpreted as UTF8 encoded,
'\n' delimited lines.
@param hostname Hostname to connect to for receiving data
@param port Port to connect to for receiving data
@param storageLevel Storage level to use for storing the received objects
"""
return DStream(self._jssc.socketTextStream(hostname, port), self, UTF8Deserializer())
jlevel = self._sc._getJavaStorageLevel(storageLevel)
return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
UTF8Deserializer())
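
A hedged sketch of consuming a socket text source with the new storageLevel parameter, assuming a line-oriented server on localhost:9999 (host and port are illustrative):

    from pyspark.storagelevel import StorageLevel

    # Default storage level is MEMORY_AND_DISK_SER_2; each '\n'-delimited
    # UTF8 line becomes one record.
    lines = ssc.socketTextStream("localhost", 9999)
    # Override the storage level explicitly if replication is not needed:
    lines2 = ssc.socketTextStream("localhost", 9999,
                                  storageLevel=StorageLevel.MEMORY_ONLY)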

def textFileStream(self, directory):
"""
Expand All @@ -122,14 +177,52 @@ def textFileStream(self, directory):
"""
return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
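
Similarly for a directory-monitoring stream; the path below is illustrative (a hedged sketch, not part of this commit):

    # New files dropped into the directory become one RDD of UTF8 lines per batch.
    logs = ssc.textFileStream("hdfs:///data/incoming")
    line_counts = logs.count()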

def _makeStream(self, inputs, numSlices=None):
def _check_serializers(self, rdds):
# make sure they have the same serializer
if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
for i in range(len(rdds)):
# reset them to sc.serializer
rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)

def queueStream(self, queue, oneAtATime=False, default=None):
"""
This function is only for unittest.
It requires a list as input, and returns the i_th element at the i_th batch
under manual clock.
Create an input stream from a queue of RDDs or a list. In each batch,
it will process either one or all of the RDDs returned by the queue.
NOTE: changes to the queue after the stream is created will not be recognized.
@param queue Queue of RDDs
@tparam T Type of objects in the RDD
"""
rdds = [self._sc.parallelize(input, numSlices) for input in inputs]
if queue and not isinstance(queue[0], RDD):
rdds = [self._sc.parallelize(input) for input in queue]
else:
rdds = queue
self._check_serializers(rdds)
jrdds = ListConverter().convert([r._jrdd for r in rdds],
SparkContext._gateway._gateway_client)
jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds).asJavaDStream()
return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds, oneAtATime,
default and default._jrdd)
return DStream(jdstream.asJavaDStream(), self, rdds[0]._jrdd_deserializer)
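
For reference, a minimal way to drive queueStream(), mirroring test_queueStream in tests.py (plain lists are parallelized into RDDs by the code above):

    import time

    input = [range(i) for i in range(3)]
    dstream = ssc.queueStream(input)
    result = dstream.collect()

    ssc.start()
    time.sleep(1)
    print(result)  # expected to equal input, one list per batch, per the test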

def transform(self, dstreams, transformFunc):
"""
Create a new DStream in which each RDD is generated by applying a function on RDDs of
the DStreams. The order of the JavaRDDs in the transform function parameter will be the
same as the order of corresponding DStreams in the list.
"""
# TODO

def union(self, *dstreams):
"""
Create a unified DStream from multiple DStreams of the same
type and same slide duration.
"""
if not dstreams:
raise ValueError("should have at least one DStream to union")
if len(dstreams) == 1:
return dstreams[0]
self._check_serializers(dstreams)
first = dstreams[0]
jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
SparkContext._gateway._gateway_client)
return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)
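
And a sketch of StreamingContext.union(), following test_union in the DStream tests; the unioned streams must share a slide duration and serializer:

    import time

    d1 = ssc.queueStream([range(3), range(5)])
    d2 = ssc.queueStream([range(3, 6), range(5, 6)])
    merged = ssc.union(d1, d2)
    result = merged.collect()

    ssc.start()
    time.sleep(1)
    print(result)  # expected [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]], per the tests
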
8 changes: 4 additions & 4 deletions python/pyspark/streaming/dstream.py
@@ -315,16 +315,16 @@ def repartitions(self, numPartitions):
return self.transform(lambda rdd: rdd.repartition(numPartitions))

def union(self, other):
return self.transformWith(lambda a, b: a.union(b), other, True)
return self.transformWith(lambda a, b, t: a.union(b), other, True)

def cogroup(self, other):
return self.transformWith(lambda a, b: a.cogroup(b), other)
return self.transformWith(lambda a, b, t: a.cogroup(b), other)

def leftOuterJoin(self, other):
return self.transformWith(lambda a, b: a.leftOuterJoin(b), other)
return self.transformWith(lambda a, b, t: a.leftOuterJoin(b), other)

def rightOuterJoin(self, other):
return self.transformWith(lambda a, b: a.rightOuterJoin(b), other)
return self.transformWith(lambda a, b, t: a.rightOuterJoin(b), other)
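
The extra lambda parameter suggests the function handed to transformWith now also receives the batch time (milliseconds, judging by util.py below) as a third argument. A hedged sketch under that assumption; tag_with_time, stream1 and stream2 are illustrative names:

    # Pair up two streams and tag every element with the batch time.
    def tag_with_time(a, b, t):
        return a.union(b).map(lambda x: (t, x))

    tagged = stream1.transformWith(tag_with_time, stream2)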

def _jtime(self, milliseconds):
return self.ctx._jvm.Time(milliseconds)
62 changes: 38 additions & 24 deletions python/pyspark/streaming/tests.py
@@ -40,27 +40,25 @@ def setUp(self):
class_name = self.__class__.__name__
self.sc = SparkContext(appName=class_name)
self.sc.setCheckpointDir("/tmp")
# TODO: decrease duration to speed up tests
self.ssc = StreamingContext(self.sc, duration=Seconds(1))

def tearDown(self):
self.ssc.stop()
self.sc.stop()

@classmethod
def tearDownClass(cls):
# Make sure to shut down the callback server
SparkContext._gateway._shutdown_callback_server()

def _test_func(self, input, func, expected, numSlices=None, sort=False):
def _test_func(self, input, func, expected, sort=False):
"""
Start stream and return the result.
@param input: dataset for the test. This should be list of lists.
@param func: wrapped function. This function should return PythonDStream object.
@param expected: expected output for this testcase.
@param numSlices: the number of slices in the rdd in the dstream.
"""
# Generate input stream with user-defined input.
input_stream = self.ssc._makeStream(input, numSlices)
input_stream = self.ssc.queueStream(input)
# Apply test function to stream.
stream = func(input_stream)
result = stream.collect()
@@ -121,7 +119,7 @@ def func(dstream):

def test_count(self):
"""Basic operation test for DStream.count."""
input = [range(1, 5), range(1, 10), range(1, 20)]
input = [range(5), range(10), range(20)]

def func(dstream):
return dstream.count()
@@ -178,24 +176,24 @@ def func(dstream):
def test_glom(self):
"""Basic operation test for DStream.glom."""
input = [range(1, 5), range(5, 9), range(9, 13)]
numSlices = 2
rdds = [self.sc.parallelize(r, 2) for r in input]

def func(dstream):
return dstream.glom()
expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
self._test_func(input, func, expected, numSlices)
self._test_func(rdds, func, expected)

def test_mapPartitions(self):
"""Basic operation test for DStream.mapPartitions."""
input = [range(1, 5), range(5, 9), range(9, 13)]
numSlices = 2
rdds = [self.sc.parallelize(r, 2) for r in input]

def func(dstream):
def f(iterator):
yield sum(iterator)
return dstream.mapPartitions(f)
expected = [[3, 7], [11, 15], [19, 23]]
self._test_func(input, func, expected, numSlices)
self._test_func(rdds, func, expected)

def test_countByValue(self):
"""Basic operation test for DStream.countByValue."""
@@ -236,14 +234,14 @@ def add(a, b):
self._test_func(input, func, expected, sort=True)

def test_union(self):
input1 = [range(3), range(5), range(1)]
input1 = [range(3), range(5), range(1), range(6)]
input2 = [range(3, 6), range(5, 6), range(1, 6)]

d1 = self.ssc._makeStream(input1)
d2 = self.ssc._makeStream(input2)
d1 = self.ssc.queueStream(input1)
d2 = self.ssc.queueStream(input2)
d = d1.union(d2)
result = d.collect()
expected = [range(6), range(6), range(6)]
expected = [range(6), range(6), range(6), range(6)]

self.ssc.start()
start_time = time.time()
@@ -317,33 +315,49 @@ def func(dstream):
class TestStreamingContext(unittest.TestCase):
def setUp(self):
self.sc = SparkContext(master="local[2]", appName=self.__class__.__name__)
self.batachDuration = Seconds(1)
self.ssc = None
self.batchDuration = Seconds(0.1)
self.ssc = StreamingContext(self.sc, self.batchDuration)

def tearDown(self):
if self.ssc is not None:
self.ssc.stop()
self.ssc.stop()
self.sc.stop()

def test_stop_only_streaming_context(self):
self.ssc = StreamingContext(self.sc, self.batachDuration)
self._addInputStream(self.ssc)
self._addInputStream()
self.ssc.start()
self.ssc.stop(False)
self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)

def test_stop_multiple_times(self):
self.ssc = StreamingContext(self.sc, self.batachDuration)
self._addInputStream(self.ssc)
self._addInputStream()
self.ssc.start()
self.ssc.stop()
self.ssc.stop()

def _addInputStream(self, s):
def _addInputStream(self):
# Make sure each input has length greater than 3
inputs = map(lambda x: range(1, x), range(5, 101))
stream = s._makeStream(inputs)
stream = self.ssc.queueStream(inputs)
stream.collect()

def test_queueStream(self):
input = [range(i) for i in range(3)]
dstream = self.ssc.queueStream(input)
result = dstream.collect()
self.ssc.start()
time.sleep(1)
self.assertEqual(input, result)

def test_union(self):
input = [range(i) for i in range(3)]
dstream = self.ssc.queueStream(input)
dstream2 = self.ssc.union(dstream, dstream)
result = dstream2.collect()
self.ssc.start()
time.sleep(1)
expected = [i * 2 for i in input]
self.assertEqual(expected, result)


if __name__ == "__main__":
unittest.main()
13 changes: 10 additions & 3 deletions python/pyspark/streaming/util.py
@@ -30,7 +30,10 @@ def __init__(self, ctx, func, jrdd_deserializer):

def call(self, jrdd, milliseconds):
try:
rdd = RDD(jrdd, self.ctx, self.deserializer)
emptyRDD = getattr(self.ctx, "_emptyRDD", None)
if emptyRDD is None:
self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD
r = self.func(rdd, milliseconds)
if r:
return r._jrdd
@@ -58,8 +61,12 @@ def __init__(self, ctx, func, jrdd_deserializer, jrdd_deserializer2=None):

def call(self, jrdd, jrdd2, milliseconds):
try:
rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else None
other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else None
emptyRDD = getattr(self.ctx, "_emptyRDD", None)
if emptyRDD is None:
self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()

rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else emptyRDD
other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else emptyRDD
r = self.func(rdd, other, milliseconds)
if r:
return r._jrdd
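
The change above swaps None for a cached empty RDD so the wrapped function always receives a real RDD. The same memoization idiom in isolation, as a minimal sketch (get_empty_rdd is an illustrative name, not part of this commit):

    def get_empty_rdd(sc):
        # Lazily create one empty RDD per SparkContext and reuse it,
        # mirroring the getattr/_emptyRDD pattern in RDDFunction.call above.
        empty = getattr(sc, "_emptyRDD", None)
        if empty is None:
            sc._emptyRDD = empty = sc.parallelize([]).cache()
        return empty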