diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 0029178ec4f2b..bb137d09211bf 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -25,6 +25,8 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.streaming.util import rddToFileName, RDDFunction
+from pyspark.rdd import portable_hash, _parse_memory
+from pyspark.shuffle import get_used_memory
 from pyspark.traceback_utils import SCCallSiteSync
 
 from py4j.java_collections import ListConverter, MapConverter
@@ -40,6 +42,7 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
         self._jrdd_deserializer = jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
+        self._partitionFunc = None
 
     def context(self):
         """
@@ -161,32 +164,55 @@ def _mergeCombiners(iterator):
 
         return shuffled.mapPartitions(_mergeCombiners)
 
-    def partitionBy(self, numPartitions, partitionFunc=None):
+    def partitionBy(self, numPartitions, partitionFunc=portable_hash):
         """
         Return a copy of the DStream partitioned using the specified partitioner.
         """
         if numPartitions is None:
             numPartitions = self.ctx._defaultReducePartitions()
 
-        if partitionFunc is None:
-            partitionFunc = lambda x: 0 if x is None else hash(x)
-
         # Transferring O(n) objects to Java is too expensive. Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java. Each object is a (splitNumber, [objects]) pair.
         outputSerializer = self.ctx._unbatched_serializer
+
+        limit = (_parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m")) / 2)
 
         def add_shuffle_key(split, iterator):
             buckets = defaultdict(list)
+            c, batch = 0, min(10 * numPartitions, 1000)
 
-            for (k, v) in iterator:
+            for k, v in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
-            for (split, items) in buckets.iteritems():
+                c += 1
+
+                # check used memory and the average chunk size
+                if (c % 1000 == 0 and get_used_memory() > limit
+                        or c > batch):
+                    n, size = len(buckets), 0
+                    for split in buckets.keys():
+                        yield pack_long(split)
+                        d = outputSerializer.dumps(buckets[split])
+                        del buckets[split]
+                        yield d
+                        size += len(d)
+
+                    # tune batch so that 1MB < average chunk size < 10MB
+                    avg = (size / n) >> 20
+                    if avg < 1:
+                        batch *= 1.5
+                    elif avg > 10:
+                        batch = max(batch / 1.5, 1)
+                    c = 0
+
+            for split, items in buckets.iteritems():
                 yield pack_long(split)
                 yield outputSerializer.dumps(items)
-        keyed = PipelinedDStream(self, add_shuffle_key)
+
+        keyed = self._mapPartitionsWithIndex(add_shuffle_key)
         keyed._bypass_serializer = True
-        with SCCallSiteSync(self.context) as css:
+        with SCCallSiteSync(self.ctx) as css:
             partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                           id(partitionFunc))
             jdstream = self.ctx._jvm.PythonPairwiseDStream(keyed._jdstream.dstream(),
@@ -428,6 +454,10 @@ def get_output(rdd, time):
 
 
 class PipelinedDStream(DStream):
+    """
+    Since PipelinedDStream is the same as PipelinedRDD, if PipelinedRDD
+    is changed, this code should be changed in the same way.
+    """
     def __init__(self, prev, func, preservesPartitioning=False):
         if not isinstance(prev, PipelinedDStream) or not prev._is_pipelinable():
             # This transformation is the first in its stage:
@@ -453,19 +483,22 @@ def pipeline_func(split, iterator):
         self._jdstream_val = None
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
+        self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
 
     @property
     def _jdstream(self):
         if self._jdstream_val:
             return self._jdstream_val
         if self._bypass_serializer:
-            serializer = NoOpSerializer()
-        else:
-            serializer = self.ctx.serializer
-
-        command = (self.func, self._prev_jrdd_deserializer, serializer)
-        ser = CompressedSerializer(CloudPickleSerializer())
+            self._jrdd_deserializer = NoOpSerializer()
+        command = (self.func, self._prev_jrdd_deserializer,
+                   self._jrdd_deserializer)
+        # the serialized command will be compressed by the broadcast machinery
+        ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1MB
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)