Commit 54bd92b: improve tests

davies committed Oct 2, 2014
1 parent c2b31cb commit 54bd92b
Showing 5 changed files with 151 additions and 108 deletions.
1 change: 0 additions & 1 deletion python/pyspark/streaming/context.py
@@ -118,7 +118,6 @@ def _ensure_initialized(cls):
# it happens before creating SparkContext when loading from checkpointing
cls._transformerSerializer = TransformFunctionSerializer(
SparkContext._active_spark_context, CloudPickleSerializer(), gw)
gw.jvm.PythonDStream.registerSerializer(cls._transformerSerializer)

@classmethod
def getOrCreate(cls, path, setupFunc):
29 changes: 21 additions & 8 deletions python/pyspark/streaming/dstream.py
@@ -20,6 +20,8 @@
import time
from datetime import datetime

from py4j.protocol import Py4JJavaError

from pyspark import RDD
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.util import rddToFileName, TransformFunction
@@ -249,19 +251,31 @@ def saveAsTextFiles(self, prefix, suffix=None):
Save each RDD in this DStream as a text file, using the string
representation of elements.
"""
def saveAsTextFile(time, rdd):
path = rddToFileName(prefix, suffix, time)
rdd.saveAsTextFile(path)
def saveAsTextFile(t, rdd):
path = rddToFileName(prefix, suffix, t)
try:
rdd.saveAsTextFile(path)
except Py4JJavaError as e:
# after recovery from a checkpoint, foreachRDD may
# be called twice
if 'FileAlreadyExistsException' not in str(e):
raise
return self.foreachRDD(saveAsTextFile)

def _saveAsPickleFiles(self, prefix, suffix=None):
"""
Save each RDD in this DStream as a binary file whose elements are
serialized by pickle.
"""
def saveAsPickleFile(time, rdd):
path = rddToFileName(prefix, suffix, time)
rdd.saveAsPickleFile(path)
def saveAsPickleFile(t, rdd):
path = rddToFileName(prefix, suffix, t)
try:
rdd.saveAsPickleFile(path)
except Py4JJavaError as e:
# after recovery from a checkpoint, foreachRDD may
# be called twice
if 'FileAlreadyExistsException' not in str(e):
raise
return self.foreachRDD(saveAsPickleFile)

def transform(self, func):
@@ -608,8 +622,7 @@ def _jdstream(self):
if self._jdstream_val is not None:
return self._jdstream_val

func = self.func
jfunc = TransformFunction(self.ctx, func, self.prev._jrdd_deserializer)
jfunc = TransformFunction(self.ctx, self.func, self.prev._jrdd_deserializer)
jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(),
jfunc, self.reuse).asJavaDStream()
self._jdstream_val = jdstream
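Note: the try/except added around saveAsTextFile and saveAsPickleFile above guards against the case where, after a StreamingContext is recovered from a checkpoint, the last batch is replayed and foreachRDD runs a second time for a batch whose output directory already exists. A minimal sketch of the same guard as a standalone helper (the save_batch name and the path argument are illustrative, not part of this commit):

from py4j.protocol import Py4JJavaError

def save_batch(rdd, path):
    """Save an RDD to `path`, tolerating a retry of an already-saved batch."""
    try:
        rdd.saveAsTextFile(path)
    except Py4JJavaError as e:
        # A replayed batch hits the directory created by the first attempt;
        # any other Java-side failure is still raised.
        if 'FileAlreadyExistsException' not in str(e):
            raise
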
150 changes: 85 additions & 65 deletions python/pyspark/streaming/tests.py
@@ -42,6 +42,13 @@ def setUp(self):
def tearDown(self):
self.ssc.stop()

def wait_for(self, result, n):
start_time = time.time()
while len(result) < n and time.time() - start_time < self.timeout:
time.sleep(0.01)
if len(result) < n:
print "timeout after", self.timeout

def _take(self, dstream, n):
"""
Return the first `n` elements in the stream (will start and stop).
@@ -55,12 +62,10 @@ def take(_, rdd):
dstream.foreachRDD(take)

self.ssc.start()
while len(results) < n:
time.sleep(0.01)
self.ssc.stop(False, True)
self.wait_for(results, n)
return results

def _collect(self, dstream):
def _collect(self, dstream, n, block=True):
"""
Collect each RDD into the returned list.
@@ -69,10 +74,18 @@ def _collect(self, dstream):
result = []

def get_output(_, rdd):
r = rdd.collect()
if r:
result.append(r)
if rdd and len(result) < n:
r = rdd.collect()
if r:
result.append(r)

dstream.foreachRDD(get_output)

if not block:
return result

self.ssc.start()
self.wait_for(result, n)
return result

def _test_func(self, input, func, expected, sort=False, input2=None):
@@ -94,23 +107,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None):
else:
stream = func(input_stream)

result = self._collect(stream)
self.ssc.start()

start_time = time.time()
# Loop until we get the expected number of results from the stream.
while True:
current_time = time.time()
# Check time out.
if (current_time - start_time) > self.timeout:
print "timeout after", self.timeout
break
# StreamingContext.awaitTermination is not used to wait here because
# calling the py4j server every 50 milliseconds causes an error.
time.sleep(0.05)
# Check if the output has the same length as the expected output.
if len(expected) == len(result):
break
result = self._collect(stream, len(expected))
if sort:
self._sort_result_based_on_key(result)
self._sort_result_based_on_key(expected)
@@ -424,55 +421,50 @@ class TestStreamingContext(PySparkStreamingTestCase):

duration = 0.1

def _add_input_stream(self):
inputs = map(lambda x: range(1, x), range(101))
stream = self.ssc.queueStream(inputs)
self._collect(stream, 1, block=False)

def test_stop_only_streaming_context(self):
self._addInputStream()
self._add_input_stream()
self.ssc.start()
self.ssc.stop(False)
self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)

def test_stop_multiple_times(self):
self._addInputStream()
self._add_input_stream()
self.ssc.start()
self.ssc.stop()
self.ssc.stop()

def _addInputStream(self):
# Make sure the length of each input is over 3
inputs = map(lambda x: range(1, x), range(5, 101))
stream = self.ssc.queueStream(inputs)
self._collect(stream)

def test_queueStream(self):
input = [range(i) for i in range(3)]
def test_queue_stream(self):
input = [range(i + 1) for i in range(3)]
dstream = self.ssc.queueStream(input)
result = self._collect(dstream)
self.ssc.start()
time.sleep(1)
self.assertEqual(input, result[:3])
result = self._collect(dstream, 3)
self.assertEqual(input, result)

def test_textFileStream(self):
def test_text_file_stream(self):
d = tempfile.mkdtemp()
self.ssc = StreamingContext(self.sc, self.duration)
dstream2 = self.ssc.textFileStream(d).map(int)
result = self._collect(dstream2)
result = self._collect(dstream2, 2, block=False)
self.ssc.start()
time.sleep(1)
for name in ('a', 'b'):
time.sleep(1)
with open(os.path.join(d, name), "w") as f:
f.writelines(["%d\n" % i for i in range(10)])
time.sleep(2)
self.assertEqual([range(10) * 2], result[:3])
self.wait_for(result, 2)
self.assertEqual([range(10), range(10)], result)

def test_union(self):
input = [range(i) for i in range(3)]
input = [range(i + 1) for i in range(3)]
dstream = self.ssc.queueStream(input)
dstream2 = self.ssc.queueStream(input)
dstream3 = self.ssc.union(dstream, dstream2)
result = self._collect(dstream3)
self.ssc.start()
time.sleep(1)
result = self._collect(dstream3, 3)
expected = [i * 2 for i in input]
self.assertEqual(expected, result[:3])
self.assertEqual(expected, result)

def test_transform(self):
dstream1 = self.ssc.queueStream([[1]])
@@ -497,34 +489,62 @@ def tearDown(self):
pass

def test_get_or_create(self):
result = [0]
inputd = tempfile.mkdtemp()
outputd = tempfile.mkdtemp() + "/"

def updater(it):
for k, vs, s in it:
yield (k, sum(vs, s or 0))

def setup():
conf = SparkConf().set("spark.default.parallelism", 1)
sc = SparkContext(conf=conf)
ssc = StreamingContext(sc, .2)
dstream = ssc.textFileStream(inputd)
result[0] = self._collect(dstream.count())
ssc = StreamingContext(sc, 0.2)
dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
wc = dstream.updateStateByKey(updater)
wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
wc.checkpoint(.2)
return ssc

tmpd = tempfile.mkdtemp("test_streaming_cps")
ssc = StreamingContext.getOrCreate(tmpd, setup)
cpd = tempfile.mkdtemp("test_streaming_cps")
ssc = StreamingContext.getOrCreate(cpd, setup)
ssc.start()
time.sleep(1)
with open(os.path.join(inputd, "1"), 'w') as f:
f.writelines(["%d\n" % i for i in range(10)])
ssc.awaitTermination(4)

def check_output(n):
while not os.listdir(outputd):
time.sleep(0.1)
time.sleep(1) # make sure mtime is larger than the previous one
with open(os.path.join(inputd, str(n)), 'w') as f:
f.writelines(["%d\n" % i for i in range(10)])

while True:
p = os.path.join(outputd, max(os.listdir(outputd)))
if '_SUCCESS' not in os.listdir(p):
# not finished
time.sleep(0.01)
continue
ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
d = ordd.values().map(int).collect()
if not d:
time.sleep(0.01)
continue
self.assertEqual(10, len(d))
s = set(d)
self.assertEqual(1, len(s))
m = s.pop()
if n > m:
continue
self.assertEqual(n, m)
break

check_output(1)
check_output(2)
ssc.stop(True, True)
expected = [[i * 1 + 1] for i in range(5)] + [[5]] * 5
self.assertEqual([[10]], result[0][:1])

ssc = StreamingContext.getOrCreate(tmpd, setup)
ssc.start()
time.sleep(1)
with open(os.path.join(inputd, "1"), 'w') as f:
f.writelines(["%d\n" % i for i in range(10)])
ssc.awaitTermination(2)
ssc = StreamingContext.getOrCreate(cpd, setup)
ssc.start()
check_output(3)
ssc.stop(True, True)


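For context, test_get_or_create above exercises the StreamingContext.getOrCreate checkpoint-recovery path: the setup function is invoked only when no usable checkpoint exists, otherwise the context is rebuilt from the checkpoint data. A minimal sketch of that usage pattern (checkpoint_dir and create_context are illustrative names, and the batch interval is arbitrary):

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

checkpoint_dir = "/tmp/streaming-checkpoint"  # illustrative path

def create_context():
    # Only called when checkpoint_dir holds no existing checkpoint.
    sc = SparkContext(appName="checkpoint-example")
    ssc = StreamingContext(sc, 0.2)
    ssc.checkpoint(checkpoint_dir)
    # ... build the DStream graph here, checkpointing any stateful streams ...
    return ssc

ssc = StreamingContext.getOrCreate(checkpoint_dir, create_context)
ssc.start()
ssc.awaitTermination()
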
23 changes: 12 additions & 11 deletions python/pyspark/streaming/util.py
@@ -15,6 +15,7 @@
# limitations under the License.
#

import time
from datetime import datetime
import traceback

@@ -32,23 +33,20 @@ def __init__(self, ctx, func, *deserializers):
self.func = func
self.deserializers = deserializers

@property
def emptyRDD(self):
if self._emptyRDD is None and self.ctx:
self._emptyRDD = self.ctx.parallelize([]).cache()
return self._emptyRDD

def call(self, milliseconds, jrdds):
try:
if self.ctx is None:
self.ctx = SparkContext._active_spark_context
if not self.ctx or not self.ctx._jsc:
# stopped
return

# extend deserializers with the first one
sers = self.deserializers
if len(sers) < len(jrdds):
sers += (sers[0],) * (len(jrdds) - len(sers))

rdds = [RDD(jrdd, self.ctx, ser) if jrdd else self.emptyRDD
rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
for jrdd, ser in zip(jrdds, sers)]
t = datetime.fromtimestamp(milliseconds / 1000.0)
r = self.func(t, *rdds)
@@ -69,6 +67,7 @@ def __init__(self, ctx, serializer, gateway=None):
self.ctx = ctx
self.serializer = serializer
self.gateway = gateway or self.ctx._gateway
self.gateway.jvm.PythonDStream.registerSerializer(self)

def dumps(self, id):
try:
@@ -91,20 +90,22 @@ class Java:
implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer']


def rddToFileName(prefix, suffix, time):
def rddToFileName(prefix, suffix, timestamp):
"""
Return string prefix-time(.suffix)
>>> rddToFileName("spark", None, 12345678910)
'spark-12345678910'
>>> rddToFileName("spark", "tmp", 12345678910)
'spark-12345678910.tmp'
"""
if isinstance(timestamp, datetime):
seconds = time.mktime(timestamp.timetuple())
timestamp = long(seconds * 1000) + timestamp.microsecond / 1000
if suffix is None:
return prefix + "-" + str(time)
return prefix + "-" + str(timestamp)
else:
return prefix + "-" + str(time) + "." + suffix
return prefix + "-" + str(timestamp) + "." + suffix


if __name__ == "__main__":
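With the change above, rddToFileName accepts either a numeric timestamp in milliseconds or a datetime, which it first converts to epoch milliseconds via time.mktime. A small illustration (the datetime value is made up, and the resulting millisecond count depends on the local timezone):

from datetime import datetime
from pyspark.streaming.util import rddToFileName

print rddToFileName("spark", None, 12345678910)             # spark-12345678910
print rddToFileName("spark", "tmp", datetime(2014, 10, 2))  # spark-<epoch millis>.tmp
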