
Commit

refactor and improve docs
davies committed Jul 23, 2014
1 parent fdd0a49 commit 6178844
Showing 4 changed files with 94 additions and 58 deletions.
2 changes: 1 addition & 1 deletion python/epydoc.conf
@@ -35,4 +35,4 @@ private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join
pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
pyspark.rddsampler pyspark.daemon pyspark.mllib._common
pyspark.mllib.tests
pyspark.mllib.tests pyspark.shuffle
4 changes: 2 additions & 2 deletions python/pyspark/rdd.py
@@ -1317,7 +1317,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
def combineLocally(iterator):
merger = ExternalMerger(agg, memory, serializer) \
if spill else InMemoryMerger(agg)
merger.combine(iterator)
merger.mergeValues(iterator)
return merger.iteritems()

locally_combined = self.mapPartitions(combineLocally)
@@ -1326,7 +1326,7 @@ def combineLocally(iterator):
def _mergeCombiners(iterator):
merger = ExternalMerger(agg, memory, serializer) \
if spill else InMemoryMerger(agg)
merger.merge(iterator)
merger.mergeCombiners(iterator)
return merger.iteritems()

return shuffled.mapPartitions(_mergeCombiners)
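The rdd.py change above only renames the calls: combineLocally now feeds raw values through mergeValues() on the map side, and _mergeCombiners feeds the already-combined values through mergeCombiners() after the shuffle. A minimal sketch of that two-phase flow, assuming the non-spilling InMemoryMerger branch (the keys and values below are illustrative only):

    from pyspark.shuffle import Aggregator, InMemoryMerger

    agg = Aggregator(createCombiner=lambda v: [v],
                     mergeValue=lambda c, v: c + [v],
                     mergeCombiners=lambda c1, c2: c1 + c2)

    # Map side: merge raw values into combiners within one partition.
    mapper = InMemoryMerger(agg)
    mapper.mergeValues([("a", 1), ("b", 2), ("a", 3)])

    # Reduce side: merge the already-combined values after the shuffle.
    reducer = InMemoryMerger(agg)
    reducer.mergeCombiners(mapper.iteritems())
    print(sorted(reducer.iteritems()))   # [('a', [1, 3]), ('b', [2])]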
130 changes: 84 additions & 46 deletions python/pyspark/shuffle.py
@@ -27,31 +27,54 @@
import psutil

def get_used_memory():
""" return the used memory in MB """
self = psutil.Process(os.getpid())
return self.memory_info().rss >> 20

except ImportError:

def get_used_memory():
""" return the used memory in MB """
if platform.system() == 'Linux':
for line in open('/proc/self/status'):
if line.startswith('VmRSS:'):
return int(line.split()[1]) >> 10
else:
warnings.warn("please install psutil to get accurate memory usage")
warnings.warn("please install psutil to have better "
"support with spilling")
if platform.system() == "Darwin":
import resource
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss >> 20
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
return rss >> 20
# TODO: support windows
return 0


class Aggregator(object):

def __init__(self, creator, combiner, mergeCombiner=None):
self.creator = creator
self.combiner = combiner
self.mergeCombiner = mergeCombiner or combiner
"""
Aggregator has three functions to merge values into a combiner:
createCombiner: (value) -> combiner
mergeValue: (combiner, value) -> combiner
mergeCombiners: (combiner, combiner) -> combiner
"""

def __init__(self, createCombiner, mergeValue, mergeCombiners):
self.createCombiner = createCombiner
self.mergeValue = mergeValue
self.mergeCombiners = mergeCombiners


class SimpleAggregator(Aggregator):

"""
SimpleAggregator is useful for the cases in which the combiners
have the same type as the values
"""

def __init__(self, combiner):
Aggregator.__init__(self, lambda x: x, combiner, combiner)
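As the docstrings above describe, an Aggregator bundles the three functions, while SimpleAggregator covers the common case where values and combiners share a type. A hedged sketch of the two flavors; the list-building aggregator mirrors the one in pyspark/tests.py below, and the summing one mirrors the ExternalMerger doctest:

    from pyspark.shuffle import Aggregator, SimpleAggregator

    # General case: the combiner type (list) differs from the value type (int).
    list_agg = Aggregator(
        lambda v: [v],                       # createCombiner: value -> combiner
        lambda c, v: c.append(v) or c,       # mergeValue: (combiner, value) -> combiner
        lambda c1, c2: c1.extend(c2) or c1)  # mergeCombiners: (combiner, combiner) -> combiner

    # Special case: values and combiners share a type, so one binary function
    # serves as both mergeValue and mergeCombiners (createCombiner is identity).
    sum_agg = SimpleAggregator(lambda x, y: x + y)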


class Merger(object):
@@ -63,11 +86,11 @@ class Merger(object):
def __init__(self, aggregator):
self.agg = aggregator

def combine(self, iterator):
def mergeValues(self, iterator):
""" combine the items by creator and combiner """
raise NotImplementedError

def merge(self, iterator):
def mergeCombiners(self, iterator):
""" merge the combined items by mergeCombiner """
raise NotImplementedError

@@ -86,17 +109,18 @@ def __init__(self, aggregator):
Merger.__init__(self, aggregator)
self.data = {}

def combine(self, iterator):
def mergeValues(self, iterator):
""" combine the items by creator and combiner """
# speed up attributes lookup
d, creator, comb = self.data, self.agg.creator, self.agg.combiner
d, creator = self.data, self.agg.createCombiner
comb = self.agg.mergeValue
for k, v in iterator:
d[k] = comb(d[k], v) if k in d else creator(v)

def merge(self, iterator):
def mergeCombiners(self, iterator):
""" merge the combined items by mergeCombiner """
# speed up attributes lookup
d, comb = self.data, self.agg.mergeCombiner
d, comb = self.data, self.agg.mergeCombiners
for k, v in iterator:
d[k] = comb(d[k], v) if k in d else v

@@ -133,32 +157,43 @@ class ExternalMerger(Merger):
it will partition the loaded data and dump them into disks
and load them partition by partition again.
>>> agg = Aggregator(lambda x: x, lambda x, y: x + y)
`data` and `pdata` are used to hold the merged items in memory.
At first, all the data are merged into `data`. Once the used
memory goes over the limit, the items in `data` are dumped to
disk and `data` is cleared; the remaining items are then merged
into `pdata` and dumped to disk as well. Before returning, any
items left in `pdata` are also dumped to disk.
Finally, if any items were spilled to disk, each partition is
merged into `data`, yielded, and then cleared.
>>> agg = SimpleAggregator(lambda x, y: x + y)
>>> merger = ExternalMerger(agg, 10)
>>> N = 10000
>>> merger.combine(zip(xrange(N), xrange(N)) * 10)
>>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10)
>>> assert merger.spills > 0
>>> sum(v for k,v in merger.iteritems())
499950000
>>> merger = ExternalMerger(agg, 10)
>>> merger.merge(zip(xrange(N), xrange(N)) * 10)
>>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10)
>>> assert merger.spills > 0
>>> sum(v for k,v in merger.iteritems())
499950000
"""

PARTITIONS = 64 # number of partitions when spill data into disks
BATCH = 10000 # check the memory after # of items merged

def __init__(self, aggregator, memory_limit=512, serializer=None,
localdirs=None, scale=1):
localdirs=None, scale=1, partitions=64, batch=10000):
Merger.__init__(self, aggregator)
self.memory_limit = memory_limit
# default serializer is only used for tests
self.serializer = serializer or \
BatchedSerializer(PickleSerializer(), 1024)
self.localdirs = localdirs or self._get_dirs()
# number of partitions when spill data into disks
self.partitions = partitions
# check the memory after # of items merged
self.batch = batch
# scale is used to scale down the hash of key for recursive hash map,
self.scale = scale
# unpartitioned merged data
@@ -187,31 +222,31 @@ def next_limit(self):
"""
return max(self.memory_limit, get_used_memory() * 1.05)

def combine(self, iterator):
def mergeValues(self, iterator):
""" combine the items by creator and combiner """
iterator = iter(iterator)
# speedup attribute lookup
d, creator, comb = self.data, self.agg.creator, self.agg.combiner
c, batch = 0, self.BATCH
creator, comb = self.agg.createCombiner, self.agg.mergeValue
d, c, batch = self.data, 0, self.batch

for k, v in iterator:
d[k] = comb(d[k], v) if k in d else creator(v)

c += 1
if c % batch == 0 and get_used_memory() > self.memory_limit:
self._first_spill()
self._partitioned_combine(iterator, self.next_limit())
self._partitioned_mergeValues(iterator, self.next_limit())
break

def _partition(self, key):
""" return the partition for key """
return (hash(key) / self.scale) % self.PARTITIONS
return (hash(key) / self.scale) % self.partitions

def _partitioned_combine(self, iterator, limit=0):
def _partitioned_mergeValues(self, iterator, limit=0):
""" partition the items by key, then combine them """
# speedup attribute lookup
creator, comb, pdata = self.agg.creator, self.agg.combiner, self.pdata
c, hfun, batch = 0, self._partition, self.BATCH
creator, comb = self.agg.createCombiner, self.agg.mergeValue
c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch

for k, v in iterator:
d = pdata[hfun(k)]
@@ -224,11 +259,11 @@ def _partitioned_combine(self, iterator, limit=0):
self._spill()
limit = self.next_limit()

def merge(self, iterator, check=True):
def mergeCombiners(self, iterator, check=True):
""" merge (K,V) pair by mergeCombiner """
iterator = iter(iterator)
# speedup attribute lookup
d, comb, batch = self.data, self.agg.mergeCombiner, self.BATCH
d, comb, batch = self.data, self.agg.mergeCombiners, self.batch
c = 0
for k, v in iterator:
d[k] = comb(d[k], v) if k in d else v
@@ -238,38 +273,39 @@ def merge(self, iterator, check=True):
c += 1
if c % batch == 0 and get_used_memory() > self.memory_limit:
self._first_spill()
self._partitioned_merge(iterator, self.next_limit())
self._partitioned_mergeCombiners(iterator, self.next_limit())
break

def _partitioned_merge(self, iterator, limit=0):
def _partitioned_mergeCombiners(self, iterator, limit=0):
""" partition the items by key, then merge them """
comb, pdata, hfun = self.agg.mergeCombiner, self.pdata, self._partition
c = 0
comb, pdata = self.agg.mergeCombiners, self.pdata
c, hfun = 0, self._partition
for k, v in iterator:
d = pdata[hfun(k)]
d[k] = comb(d[k], v) if k in d else v
if not limit:
continue

c += 1
if c % self.BATCH == 0 and get_used_memory() > limit:
if c % self.batch == 0 and get_used_memory() > limit:
self._spill()
limit = self.next_limit()

def _first_spill(self):
"""
dump all the data into disks partition by partition.
The data has not been partitioned, it will iterator the dataset once,
write them into different files, has no additional memory. It only
called when the memory goes above limit at the first time.
The data has not been partitioned yet, so this iterates over
the dataset once, writing the items into different files
without using additional memory. It is only called the first
time memory usage goes above the limit.
"""
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)
# open all the files for writing
streams = [open(os.path.join(path, str(i)), 'w')
for i in range(self.PARTITIONS)]
for i in range(self.partitions)]

for k, v in self.data.iteritems():
h = self._partition(k)
@@ -279,7 +315,7 @@ def _first_spill(self):
for s in streams:
s.close()
self.data.clear()
self.pdata = [{} for i in range(self.PARTITIONS)]
self.pdata = [{} for i in range(self.partitions)]
self.spills += 1

def _spill(self):
@@ -292,7 +328,7 @@ def _spill(self):
if not os.path.exists(path):
os.makedirs(path)

for i in range(self.PARTITIONS):
for i in range(self.partitions):
p = os.path.join(path, str(i))
with open(p, "w") as f:
# dump items in batch
@@ -314,13 +350,14 @@ def _external_items(self):
hard_limit = self.next_limit()

try:
for i in range(self.PARTITIONS):
for i in range(self.partitions):
self.data = {}
for j in range(self.spills):
path = self._get_spill_dir(j)
p = os.path.join(path, str(i))
# do not check memory during merging
self.merge(self.serializer.load_stream(open(p)), False)
self.mergeCombiners(self.serializer.load_stream(open(p)),
False)

if get_used_memory() > hard_limit and j < self.spills - 1:
self.data.clear() # will read from disk again
@@ -352,18 +389,19 @@ def _recursive_merged_items(self, start):
self._spill()
assert self.spills > 0

for i in range(start, self.PARTITIONS):
for i in range(start, self.partitions):
subdirs = [os.path.join(d, "parts", str(i))
for d in self.localdirs]
m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
subdirs, self.scale * self.PARTITIONS)
m.pdata = [{} for _ in range(self.PARTITIONS)]
subdirs, self.scale * self.partitions)
m.pdata = [{} for _ in range(self.partitions)]
limit = self.next_limit()

for j in range(self.spills):
path = self._get_spill_dir(j)
p = os.path.join(path, str(i))
m._partitioned_merge(self.serializer.load_stream(open(p)))
m._partitioned_mergeCombiners(
self.serializer.load_stream(open(p)))

if get_used_memory() > limit:
m._spill()
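Note that the class-level PARTITIONS and BATCH constants are gone: the number of spill partitions and the memory-check interval are now the constructor arguments partitions (default 64) and batch (default 10000). A small, hedged sketch of the new knobs; the values below simply mirror the overrides that tests.py used to set globally and the sum expected by the doctest:

    from pyspark.shuffle import SimpleAggregator, ExternalMerger

    agg = SimpleAggregator(lambda x, y: x + y)
    # 8 partitions and a 16384-item check interval, passed per instance
    # instead of patching ExternalMerger.PARTITIONS / ExternalMerger.BATCH.
    merger = ExternalMerger(agg, memory_limit=10, partitions=8, batch=1 << 14)
    merger.mergeValues(zip(xrange(10000), xrange(10000)) * 10)
    assert merger.spills > 0                  # a 10 MB limit forces spilling
    assert sum(v for _, v in merger.iteritems()) == 499950000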
16 changes: 7 additions & 9 deletions python/pyspark/tests.py
@@ -57,49 +57,47 @@ def setUp(self):
self.agg = Aggregator(lambda x: [x],
lambda x, y: x.append(y) or x,
lambda x, y: x.extend(y) or x)
ExternalMerger.PARTITIONS = 8
ExternalMerger.BATCH = 1 << 14

def test_in_memory(self):
m = InMemoryMerger(self.agg)
m.combine(self.data)
m.mergeValues(self.data)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
sum(xrange(self.N)))

m = InMemoryMerger(self.agg)
m.merge(map(lambda (x, y): (x, [y]), self.data))
m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
sum(xrange(self.N)))

def test_small_dataset(self):
m = ExternalMerger(self.agg, 1000)
m.combine(self.data)
m.mergeValues(self.data)
self.assertEqual(m.spills, 0)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
sum(xrange(self.N)))

m = ExternalMerger(self.agg, 1000)
m.merge(map(lambda (x, y): (x, [y]), self.data))
m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
self.assertEqual(m.spills, 0)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
sum(xrange(self.N)))

def test_medium_dataset(self):
m = ExternalMerger(self.agg, 10)
m.combine(self.data)
m.mergeValues(self.data)
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
sum(xrange(self.N)))

m = ExternalMerger(self.agg, 10)
m.merge(map(lambda (x, y): (x, [y]), self.data * 3))
m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3))
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
sum(xrange(self.N)) * 3)

def test_huge_dataset(self):
m = ExternalMerger(self.agg, 10)
m.merge(map(lambda (k, v): (k, [str(v)]), self.data * 10))
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
self.N * 10)
