diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index d98afc3e5a294..d866f8c9687fb 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -552,14 +552,18 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None
         def reduceFunc(t, a, b):
             b = b.reduceByKey(func, numPartitions)
-            r = a.union(b).reduceByKey(func, numPartitions) if a else b
+            # use the average of number of partitions, or it will keep increasing
+            partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2
+            r = a.union(b).reduceByKey(func, partitions) if a else b
             if filterFunc:
                 r = r.filter(filterFunc)
             return r
 
         def invReduceFunc(t, a, b):
             b = b.reduceByKey(func, numPartitions)
-            joined = a.leftOuterJoin(b, numPartitions)
+            # use the average of number of partitions, or it will keep increasing
+            partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2
+            joined = a.leftOuterJoin(b, partitions)
             return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
 
         jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
@@ -587,7 +591,9 @@ def reduceFunc(t, a, b):
             if a is None:
                 g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
             else:
-                g = a.cogroup(b, numPartitions)
+                # use the average of number of partitions, or it will keep increasing
+                partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2
+                g = a.cogroup(b, partitions)
                 g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None))
             state = g.mapPartitions(lambda x: updateFunc(x))
             return state.filter(lambda (k, v): v is not None)
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 7ffdb145c104e..0dc6b3d675397 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -22,7 +22,7 @@
 import unittest
 import tempfile
 
-from pyspark.context import SparkContext
+from pyspark.context import SparkContext, RDD
 from pyspark.streaming.context import StreamingContext
 
 
@@ -46,8 +46,13 @@ def _test_func(self, input, func, expected, sort=False, input2=None):
         @param func: wrapped function. This function should return PythonDStream object.
         @param expected: expected output for this testcase.
         """
+        if not isinstance(input[0], RDD):
+            input = [self.sc.parallelize(d, 1) for d in input]
         input_stream = self.ssc.queueStream(input)
+        if input2 and not isinstance(input2[0], RDD):
+            input2 = [self.sc.parallelize(d, 1) for d in input2]
         input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
+
         # Apply test function to stream.
         if input2:
             stream = func(input_stream, input_stream2)
@@ -63,6 +68,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None):
             current_time = time.time()
             # Check time out.
             if (current_time - start_time) > self.timeout:
+                print "timeout after", self.timeout
                 break
             # StreamingContext.awaitTermination is not used to wait because
             # if py4j server is called every 50 milliseconds, it gets an error.
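Outside the patch, a minimal standalone sketch of the partition-averaging idea the new comments describe: union() yields the sum of both inputs' partition counts, so merging windowed RDDs batch after batch without an explicit partition count would keep growing it. The SparkContext setup and sample data below are illustrative assumptions, not part of the patch.

# Sketch only (not part of the patch): shows why averaging keeps the
# partition count stable. SparkContext setup and data are assumptions.
from operator import add

from pyspark import SparkContext

sc = SparkContext("local[2]", "partition-average-sketch")

a = sc.parallelize([("k", 1), ("j", 2)], 4)
b = sc.parallelize([("k", 3), ("j", 4)], 4)

# union() concatenates partitions: 4 + 4 = 8. Repeating this every window
# without an explicit numPartitions would keep increasing the count.
print(a.union(b).getNumPartitions())  # 8

# Passing the average (integer division here) keeps it stable across batches.
partitions = (a.getNumPartitions() + b.getNumPartitions()) // 2
merged = a.union(b).reduceByKey(add, partitions)
print(merged.getNumPartitions())  # 4

sc.stop()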