Skip to content

Commit

Permalink
fix serialization of accumulator
Browse files Browse the repository at this point in the history
add better message when trying to access Broadcast.value in driver.
  • Loading branch information
davies committed Aug 13, 2014
1 parent 631a827 commit db3f232
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
13 changes: 11 additions & 2 deletions python/pyspark/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,30 @@ class Broadcast(object):
Access its value through C{.value}.
"""

def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
def __init__(self, bid, value, java_broadcast=None, pickle_registry=None, keep=True):
"""
Should not be called directly by users -- use
L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}
instead.
"""
self.value = value
self.bid = bid
if keep:
self.value = value
self._jbroadcast = java_broadcast
self._pickle_registry = pickle_registry
self.keep = keep

def __reduce__(self):
self._pickle_registry.add(self)
return (_from_id, (self.bid, ))

def __getattr__(self, item):
if item == 'value' and not self.keep:
raise Exception("please create broadcast with keep=True to make"
" it accessable in driver")

raise AttributeError(item)


if __name__ == "__main__":
import doctest
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,8 @@ def broadcast(self, value, keep=False):
tempFile.close()
jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name)
os.unlink(tempFile.name)
return Broadcast(jbroadcast.id(), value if keep else None,
jbroadcast, self._pickled_broadcast_vars)
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars, keep)

def accumulator(self, value, accum_param=None):
"""
Expand Down
7 changes: 4 additions & 3 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
CompressedSerializer


pickleSer = CompressedSerializer(PickleSerializer())
pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()


Expand Down Expand Up @@ -66,12 +66,13 @@ def main(infile, outfile):

# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
ser = CompressedSerializer(pickleSer)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
value = pickleSer._read_with_length(infile)
value = ser._read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, value)

command = pickleSer._read_with_length(infile)
command = ser._read_with_length(infile)
(func, deserializer, serializer) = command
init_time = time.time()
iterator = deserializer.load_stream(infile)
Expand Down

0 comments on commit db3f232

Please sign in to comment.