
use broadcast automatically for large closure
davies committed Sep 16, 2014
1 parent febafef commit aefd508
Showing 4 changed files with 18 additions and 2 deletions.
4 changes: 4 additions & 0 deletions python/pyspark/rdd.py
@@ -2061,8 +2061,12 @@ def _jrdd(self):
             self._jrdd_deserializer = NoOpSerializer()
         command = (self.func, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
+        # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
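The driver-side change above only swaps what travels with the task: once the pickled (func, deserializer, serializer) tuple exceeds 1 MB, it is stored in a broadcast variable and only the pickled Broadcast handle is sent, so the large payload is transferred (and cached) per executor and compressed by the broadcast machinery instead of being shipped inline with every task. A minimal sketch of that decision, with an illustrative helper name and threshold constant that are not part of the commit:

    from pyspark.serializers import CloudPickleSerializer

    ONE_MB = 1 << 20  # same cut-off used in the diff

    def wrap_large_command(ctx, command):
        # Pickle the closure; if it is big, replace it with a pickled
        # Broadcast handle so executors fetch the payload via broadcast.
        ser = CloudPickleSerializer()
        pickled = ser.dumps(command)
        if len(pickled) > ONE_MB:
            pickled = ser.dumps(ctx.broadcast(pickled))
        return pickled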
8 changes: 6 additions & 2 deletions python/pyspark/sql.py
@@ -27,7 +27,7 @@
 from array import array
 from operator import itemgetter
 
-from pyspark.rdd import RDD, PipelinedRDD
+from pyspark.rdd import RDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel

@@ -974,7 +974,11 @@ def registerFunction(self, name, f, returnType=StringType()):
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
-        pickled_command = CloudPickleSerializer().dumps(command)
+        ser = CloudPickleSerializer()
+        pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self._sc.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
             self._sc._gateway._gateway_client)
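registerFunction gets the same treatment, which matters because Python UDFs often close over sizeable lookup structures. A hypothetical usage sketch against the 1.x pyspark.sql API (the app name, lookup table, and UDF are illustrative, not from the commit):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, IntegerType

    sc = SparkContext("local", "udf-closure-demo")
    sqlCtx = SQLContext(sc)

    # a few MB once cloudpickled, so the UDF command is shipped as a broadcast
    lookup = {i: i * i for i in range(300000)}
    sqlCtx.registerFunction("square", lambda v: lookup.get(int(v), -1), IntegerType())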
6 changes: 6 additions & 0 deletions python/pyspark/tests.py
@@ -434,6 +434,12 @@ def test_large_broadcast(self):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)
 
+    def test_large_closure(self):
+        N = 1000000
+        data = [float(i) for i in xrange(N)]
+        m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
+        self.assertEquals(N, m)
+
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
         b = self.sc.parallelize(range(100, 105))
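test_large_closure mirrors test_large_broadcast but keeps the data in the closure rather than in an explicit broadcast. A rough sanity check, using plain pickle as a stand-in for cloudpickle, shows that a list of one million floats is far over the 1 MB cut-off, so the test really does exercise the new broadcast path:

    import pickle

    data = [float(i) for i in range(1000000)]
    print(len(pickle.dumps(data, -1)))  # roughly 9 MB, well above 1 << 20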
2 changes: 2 additions & 0 deletions python/pyspark/worker.py
@@ -81,6 +81,8 @@ def main(infile, outfile):

     _accumulatorRegistry.clear()
     command = pickleSer._read_with_length(infile)
+    if isinstance(command, Broadcast):
+        command = pickleSer.loads(command.value)
     (func, deserializer, serializer) = command
     init_time = time.time()
     iterator = deserializer.load_stream(infile)
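On the worker side the only change is the unwrap step: if the unpickled command turns out to be a Broadcast handle, its value holds the real pickled command. A standalone sketch of that pattern with an illustrative helper name (the diff adds no import, so Broadcast is presumably already imported in worker.py):

    from pyspark.broadcast import Broadcast

    def resolve_command(pickle_ser, payload):
        command = pickle_ser.loads(payload)
        if isinstance(command, Broadcast):
            # large closure: the broadcast's value is the pickled command tuple
            command = pickle_ser.loads(command.value)
        return command  # (func, deserializer, serializer)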
