[SPARK-1065] [PySpark] improve support for large broadcasts
Passing large objects through Py4J is very slow (and costs a lot of memory), so pass broadcast objects via files instead (similar to parallelize()).

Add an option to keep the object in the driver (False by default) to save memory in the driver.
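
In outline, the driver now spills the value to a compressed local temp file and asks the JVM to broadcast the file's bytes. The sketch below is pieced together from the diff that follows, not a verbatim excerpt; the helper name `broadcast_via_file` is illustrative, and an active SparkContext `sc` is assumed.

```python
# Sketch of the new driver-side broadcast path (illustrative, assembled from the diff below).
from tempfile import NamedTemporaryFile

from pyspark.broadcast import Broadcast
from pyspark.serializers import CompressedSerializer, PickleSerializer

def broadcast_via_file(sc, value):
    ser = CompressedSerializer(PickleSerializer())
    # 1. dump the value to a local temp file as one compressed, length-prefixed record
    tmp = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
    ser.dump_stream([value], tmp)
    tmp.close()
    # 2. let the JVM read the bytes from disk and create the Java-side broadcast
    jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, tmp.name)
    # 3. keep only the id, the Java handle and the file path on the Python side
    return Broadcast(jbroadcast.id(), None, jbroadcast,
                     sc._pickled_broadcast_vars, tmp.name)
```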

Author: Davies Liu <[email protected]>

Closes #1912 from davies/broadcast and squashes the following commits:

e06df4a [Davies Liu] load broadcast from disk in driver automatically
db3f232 [Davies Liu] fix serialization of accumulator
631a827 [Davies Liu] Merge branch 'master' into broadcast
c7baa8c [Davies Liu] compress serialized broadcast and command
9a7161f [Davies Liu] fix doc tests
e93cf4b [Davies Liu] address comments: add test
6226189 [Davies Liu] improve large broadcast
davies authored and JoshRosen committed Aug 16, 2014
1 parent 379e758 commit 2fc8aca
Showing 7 changed files with 81 additions and 21 deletions.
8 changes: 8 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -315,6 +315,14 @@ private[spark] object PythonRDD extends Logging {
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}

def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val length = file.readInt()
val obj = new Array[Byte](length)
file.readFully(obj)
sc.broadcast(obj)
}

def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
// The right way to implement this would be to use TypeTags to get the full
// type of T. Since I don't want to introduce breaking changes throughout the
37 changes: 28 additions & 9 deletions python/pyspark/broadcast.py
@@ -21,18 +21,16 @@
>>> b = sc.broadcast([1, 2, 3, 4, 5])
>>> b.value
[1, 2, 3, 4, 5]
>>> from pyspark.broadcast import _broadcastRegistry
>>> _broadcastRegistry[b.bid] = b
>>> from cPickle import dumps, loads
>>> loads(dumps(b)).value
[1, 2, 3, 4, 5]
>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
>>> b.unpersist()
>>> large_broadcast = sc.broadcast(list(range(10000)))
"""
import os

from pyspark.serializers import CompressedSerializer, PickleSerializer

# Holds broadcasted data received from Java, keyed by its id.
_broadcastRegistry = {}

@@ -52,17 +50,38 @@ class Broadcast(object):
Access its value through C{.value}.
"""

def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
def __init__(self, bid, value, java_broadcast=None,
pickle_registry=None, path=None):
"""
Should not be called directly by users -- use
L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}
instead.
"""
self.value = value
self.bid = bid
if path is None:
self.value = value
self._jbroadcast = java_broadcast
self._pickle_registry = pickle_registry
self.path = path

def unpersist(self, blocking=False):
self._jbroadcast.unpersist(blocking)
os.unlink(self.path)

def __reduce__(self):
self._pickle_registry.add(self)
return (_from_id, (self.bid, ))

def __getattr__(self, item):
if item == 'value' and self.path is not None:
ser = CompressedSerializer(PickleSerializer())
value = ser.load_stream(open(self.path)).next()
self.value = value
return value

raise AttributeError(item)


if __name__ == "__main__":
import doctest
doctest.testmod()
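
The `__getattr__` hook above is what makes the spilled value lazy: `.value` is only read back from disk (and then cached) the first time it is accessed. A minimal, Spark-independent illustration of the same pattern, with plain `pickle` standing in for CompressedSerializer(PickleSerializer()) and a made-up class name:

```python
# Minimal illustration of the lazy-load-and-cache pattern used by Broadcast.__getattr__.
import pickle

class LazyFromFile(object):
    def __init__(self, path):
        self.path = path                 # note: no `value` attribute is set here

    def __getattr__(self, item):
        # __getattr__ fires only when normal attribute lookup fails,
        # i.e. the first time `.value` is requested
        if item == 'value' and self.path is not None:
            with open(self.path, 'rb') as f:
                value = pickle.load(f)
            self.value = value           # cache it, so later accesses skip __getattr__
            return value
        raise AttributeError(item)
```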
20 changes: 13 additions & 7 deletions python/pyspark/context.py
@@ -29,7 +29,7 @@
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer
PairDeserializer, CompressedSerializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
@@ -566,13 +566,19 @@ def broadcast(self, value):
"""
Broadcast a read-only variable to the cluster, returning a
L{Broadcast<pyspark.broadcast.Broadcast>}
object for reading it in distributed functions. The variable will be
sent to each cluster only once.
object for reading it in distributed functions. The variable will
be sent to each cluster only once.
:keep: Keep the `value` in the driver or not.
"""
pickleSer = PickleSerializer()
pickled = pickleSer.dumps(value)
jbroadcast = self._jsc.broadcast(bytearray(pickled))
return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars)
ser = CompressedSerializer(PickleSerializer())
# passing a large object through py4j is very slow and needs a lot of memory
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
ser.dump_stream([value], tempFile)
tempFile.close()
jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name)
return Broadcast(jbroadcast.id(), None, jbroadcast,
self._pickled_broadcast_vars, tempFile.name)

def accumulator(self, value, accum_param=None):
"""
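
From the user's point of view only the transport changes. A small usage sketch, consistent with the doctest in broadcast.py and the new test in tests.py; it assumes an active SparkContext `sc`:

```python
# Usage is unchanged: only how the value travels to executors differs.
b = sc.broadcast(list(range(10000)))              # value is spilled to a compressed temp file
counts = sc.parallelize(range(4), 4) \
           .map(lambda _: len(b.value)).collect()  # each task loads the broadcast lazily
assert counts == [10000, 10000, 10000, 10000]
b.unpersist()                                      # drops the blocks and deletes the temp file
```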
5 changes: 3 additions & 2 deletions python/pyspark/rdd.py
@@ -36,7 +36,7 @@

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long
PickleSerializer, pack_long, CompressedSerializer
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -1810,7 +1810,8 @@ def _jrdd(self):
self._jrdd_deserializer = NoOpSerializer()
command = (self.func, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
pickled_command = CloudPickleSerializer().dumps(command)
ser = CompressedSerializer(CloudPickleSerializer())
pickled_command = ser.dumps(command)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
17 changes: 17 additions & 0 deletions python/pyspark/serializers.py
@@ -67,6 +67,7 @@
import sys
import types
import collections
import zlib

from pyspark import cloudpickle

@@ -403,6 +404,22 @@ def loads(self, obj):
raise ValueError("invalid serialization type: %s" % _type)


class CompressedSerializer(FramedSerializer):
"""
compress the serialized data
"""

def __init__(self, serializer):
FramedSerializer.__init__(self)
self.serializer = serializer

def dumps(self, obj):
return zlib.compress(self.serializer.dumps(obj), 1)

def loads(self, obj):
return self.serializer.loads(zlib.decompress(obj))


class UTF8Deserializer(Serializer):

"""
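
A quick round trip with the new serializer (a sketch, assuming `CompressedSerializer` and `PickleSerializer` are importable from `pyspark.serializers` as in the diff above):

```python
# Each record is pickled, zlib-compressed (level 1), and written as a length-prefixed frame.
from io import BytesIO

from pyspark.serializers import CompressedSerializer, PickleSerializer

ser = CompressedSerializer(PickleSerializer())
buf = BytesIO()
ser.dump_stream([{'a': 1}, list(range(1000))], buf)   # two compressed frames
buf.seek(0)
records = list(ser.load_stream(buf))
assert records[0] == {'a': 1} and len(records[1]) == 1000
```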
7 changes: 7 additions & 0 deletions python/pyspark/tests.py
@@ -323,6 +323,13 @@ def test_namedtuple_in_rdd(self):
theDoes = self.sc.parallelize([jon, jane])
self.assertEquals([jon, jane], theDoes.collect())

def test_large_broadcast(self):
N = 100000
data = [[float(i) for i in range(300)] for i in range(N)]
bdata = self.sc.broadcast(data) # 270MB
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEquals(N, m)


class TestIO(PySparkTestCase):

8 changes: 5 additions & 3 deletions python/pyspark/worker.py
@@ -30,7 +30,8 @@
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
CompressedSerializer


pickleSer = PickleSerializer()
@@ -65,12 +66,13 @@ def main(infile, outfile):

# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
ser = CompressedSerializer(pickleSer)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
value = pickleSer._read_with_length(infile)
value = ser._read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, value)

command = pickleSer._read_with_length(infile)
command = ser._read_with_length(infile)
(func, deserializer, serializer) = command
init_time = time.time()
iterator = deserializer.load_stream(infile)
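
On the wire, each broadcast value and the command itself arrive as length-prefixed, zlib-compressed pickles; roughly what `_read_with_length` does for this serializer is shown below (a sketch under that assumption, not the actual FramedSerializer code):

```python
# Rough sketch of reading one compressed, length-prefixed frame, as the worker does
# for each broadcast and for the command.
import pickle
import struct
import zlib

def read_compressed_frame(stream):
    length = struct.unpack("!i", stream.read(4))[0]   # 4-byte big-endian length prefix
    payload = stream.read(length)
    return pickle.loads(zlib.decompress(payload))
```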
