call gc.collect() after data.clear() to release memory as much as possible.
davies committed Jul 25, 2014
1 parent 37d71f7 commit cad91bf
Showing 1 changed file with 36 additions and 35 deletions.
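
The change itself is small: after the in-memory buffer is cleared, an explicit gc.collect() pass is added so memory is returned as promptly as possible. Below is a minimal, self-contained sketch of that pattern, not code from python/pyspark/shuffle.py; the buffer here is just a throwaway dict standing in for the merger's data.

import gc


def release_buffer(buf):
    """Clear a large in-memory buffer and reclaim its memory right away.

    Mirrors the data.clear() + gc.collect() pattern this commit adds:
    clear() drops the references held by the dict, and an explicit
    collect() frees any objects kept alive by reference cycles now,
    instead of waiting for the next automatic collection.
    """
    buf.clear()
    return gc.collect()


if __name__ == "__main__":
    # Throwaway buffer roughly shaped like the merger's data dict.
    buf = {i: [i] * 8 for i in range(100000)}
    print("unreachable objects collected:", release_buffer(buf))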
71 changes: 36 additions & 35 deletions python/pyspark/shuffle.py
@@ -20,6 +20,7 @@
import platform
import shutil
import warnings
+import gc

from pyspark.serializers import BatchedSerializer, PickleSerializer

@@ -242,7 +243,7 @@ def mergeValues(self, iterator):

c += 1
if c % batch == 0 and get_used_memory() > self.memory_limit:
-                self._first_spill()
+                self._spill()
self._partitioned_mergeValues(iterator, self._next_limit())
break

@@ -280,7 +281,7 @@ def mergeCombiners(self, iterator, check=True):

c += 1
if c % batch == 0 and get_used_memory() > self.memory_limit:
-                self._first_spill()
+                self._spill()
self._partitioned_mergeCombiners(iterator, self._next_limit())
break

@@ -299,33 +300,6 @@ def _partitioned_mergeCombiners(self, iterator, limit=0):
self._spill()
limit = self._next_limit()

-    def _first_spill(self):
-        """
-        Dump all the data into disks partition by partition.
-        The data has not been partitioned, it will iterator the
-        dataset once, write them into different files, has no
-        additional memory. It only called when the memory goes
-        above limit at the first time.
-        """
-        path = self._get_spill_dir(self.spills)
-        if not os.path.exists(path):
-            os.makedirs(path)
-        # open all the files for writing
-        streams = [open(os.path.join(path, str(i)), 'w')
-                   for i in range(self.partitions)]
-
-        for k, v in self.data.iteritems():
-            h = self._partition(k)
-            # put one item in batch, make it compatitable with load_stream
-            # it will increase the memory if dump them in batch
-            self.serializer.dump_stream([(k, v)], streams[h])
-        for s in streams:
-            s.close()
-        self.data.clear()
-        self.pdata = [{} for i in range(self.partitions)]
-        self.spills += 1

def _spill(self):
"""
dump already partitioned data into disks.
@@ -336,13 +310,38 @@ def _spill(self):
if not os.path.exists(path):
os.makedirs(path)

-        for i in range(self.partitions):
-            p = os.path.join(path, str(i))
-            with open(p, "w") as f:
-                # dump items in batch
-                self.serializer.dump_stream(self.pdata[i].iteritems(), f)
-            self.pdata[i].clear()
+        if not self.pdata:
+            # The data has not been partitioned yet; iterate the
+            # dataset once and write the items into per-partition
+            # files, using no additional memory. This branch is only
+            # taken the first time memory goes above the limit.
+
+            # open all the files for writing
+            streams = [open(os.path.join(path, str(i)), 'w')
+                       for i in range(self.partitions)]
+
+            for k, v in self.data.iteritems():
+                h = self._partition(k)
+                # put one item per batch to keep it compatible with
+                # load_stream; dumping in larger batches uses more memory
+                self.serializer.dump_stream([(k, v)], streams[h])
+
+            for s in streams:
+                s.close()
+
+            self.data.clear()
+            self.pdata = [{} for i in range(self.partitions)]
+
+        else:
+            for i in range(self.partitions):
+                p = os.path.join(path, str(i))
+                with open(p, "w") as f:
+                    # dump items in batch
+                    self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+                self.pdata[i].clear()
+
self.spills += 1
+        gc.collect()  # release the memory as much as possible

def iteritems(self):
""" Return all merged items as iterator """
@@ -372,13 +371,15 @@ def _external_items(self):
and j < self.spills - 1
and get_used_memory() > hard_limit):
self.data.clear() # will read from disk again
+                        gc.collect()  # release the memory as much as possible
for v in self._recursive_merged_items(i):
yield v
return

for v in self.data.iteritems():
yield v
self.data.clear()
+                gc.collect()

# remove the merged partition
for j in range(self.spills):
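
Taken together, the calling pattern these methods support is: feed records in, let the merger spill to disk on its own when memory crosses the limit, then iterate the merged result. The sketch below is schematic only; it assumes nothing about ExternalMerger beyond the method names visible in this diff, and merger and records are hypothetical placeholders.

def aggregate(merger, records):
    # merger stands for an object exposing the methods shown in this diff.
    # mergeValues() buffers combiners in memory and calls _spill() on its
    # own whenever get_used_memory() crosses the configured limit.
    merger.mergeValues(records)
    # iteritems() then yields the merged (key, combiner) pairs, reading
    # spilled partition files back from disk when necessary.
    for key, combined in merger.iteritems():
        yield key, combined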
