[SPARK-9021] [PYSPARK] Change RDD.aggregate() to do reduce(mapPartitions()) instead of mapPartitions.fold()

I'm relatively new to Spark and functional programming, so forgive me if this pull request is just a result of my misunderstanding of how Spark should be used.

Currently, if one happens to use a mutable object as `zeroValue` for `RDD.aggregate()`, unexpected behavior can occur.

This is because pyspark's current implementation of `RDD.aggregate()` does not serialize or make a copy of `zeroValue` before handing it off to `RDD.mapPartitions(...).fold(...)`. As a result, a single reference to `zeroValue` is used for both `RDD.mapPartitions()` and `RDD.fold()` on each partition. This can lead to strange accumulator values being fed into each partition's call to `RDD.fold()`, because `zeroValue` may have been changed in place during the `RDD.mapPartitions()` call.

As an illustrative example, submit the following to `spark-submit`:
```
from pyspark import SparkConf, SparkContext
import collections

def updateCounter(acc, val):
    print 'update acc:', acc
    print 'update val:', val
    acc[val] += 1
    return acc

def comboCounter(acc1, acc2):
    print 'combo acc1:', acc1
    print 'combo acc2:', acc2
    acc1.update(acc2)
    return acc1

def main():
    conf = SparkConf().setMaster("local").setAppName("Aggregate with Counter")
    sc = SparkContext(conf = conf)

    print '======= AGGREGATING with ONE PARTITION ======='
    print sc.parallelize(range(1,10), 1).aggregate(collections.Counter(), updateCounter, comboCounter)

    print '======= AGGREGATING with TWO PARTITIONS ======='
    print sc.parallelize(range(1,10), 2).aggregate(collections.Counter(), updateCounter, comboCounter)

if __name__ == "__main__":
    main()
```

One probably expects this to output the following:
```
Counter({1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1})
```

But it instead outputs this (regardless of the number of partitions):
```
Counter({1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2, 8: 2, 9: 2})
```

This is because (I believe) `zeroValue` gets passed correctly to each partition, but after `RDD.mapPartitions()` completes, the `zeroValue` object has been updated and is then passed to `RDD.fold()`, which results in all items being double-counted within each partition before being finally reduced at the calling node.
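
To show the mechanism with Spark out of the picture, here is a minimal plain-Python sketch of the same failure mode (single partition, no copying of `zeroValue`; the operator functions mirror the example above):
```
import collections

def updateCounter(acc, val):
    acc[val] += 1      # mutates the accumulator in place
    return acc

def comboCounter(acc1, acc2):
    acc1.update(acc2)  # Counter.update() adds counts, also in place
    return acc1

zeroValue = collections.Counter()

# Per-partition pass (what mapPartitions(func) does): since zeroValue is never
# copied, acc and zeroValue stay the same object throughout.
acc = zeroValue
for val in range(1, 10):
    acc = updateCounter(acc, val)

# Final combine (what the old .fold(zeroValue, combOp) did): zeroValue has
# already absorbed the partition's counts, so merging the partition result
# into it counts every element twice.
print(comboCounter(zeroValue, acc))
# Counter({1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2, 8: 2, 9: 2})
```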

I realize that this type of calculation is typically done by `RDD.mapPartitions(...).reduceByKey(...)`, but hopefully this illustrates some potentially confusing behavior. I also noticed that other `RDD` methods (e.g., `RDD.aggregateByKey()` and `RDD.foldByKey()`) use a `deepcopy` approach to create unique copies of `zeroValue`, and that the Scala implementations do seem to serialize the `zeroValue` object appropriately to prevent this type of behavior.
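
For comparison, the `deepcopy` protection referred to above looks roughly like this (a simplified sketch of the pattern, not a verbatim excerpt of pyspark; `aggregate_by_key_sketch` and `make_zero` are illustrative names):
```
from copy import deepcopy

def aggregate_by_key_sketch(rdd, zeroValue, seqFunc, combFunc, numPartitions=None):
    # Every accumulator starts from a fresh deep copy of zeroValue, so
    # in-place mutation by seqFunc never leaks between keys or partitions.
    def make_zero():
        return deepcopy(zeroValue)
    return rdd.combineByKey(
        lambda v: seqFunc(make_zero(), v), seqFunc, combFunc, numPartitions)
```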

Author: Nicholas Hwang <[email protected]>

Closes apache#7378 from njhwang/master and squashes the following commits:

659bb27 [Nicholas Hwang] Fixed RDD.aggregate() to perform a reduce operation on collected mapPartitions results, similar to how fold is currently implemented. This prevents an initial combOp from being performed on each partition with zeroValue (which leads to unexpected behavior if zeroValue is a mutable object) before being combOp'ed with other partition results.
8d8d694 [Nicholas Hwang] Changed dict construction to be compatible with Python 2.6 (cannot use list comprehensions to make dicts)
56eb2ab [Nicholas Hwang] Fixed whitespace after colon to conform with PEP8
391de4a [Nicholas Hwang] Removed use of collections.Counter from RDD tests for Python 2.6 compatibility; used defaultdict(int) instead. Merged treeAggregate test with mutable zero value into aggregate test to reduce code duplication.
2fa4e4b [Nicholas Hwang] Merge branch 'master' of https://github.com/njhwang/spark
ba528bd [Nicholas Hwang] Updated comments regarding protection of zeroValue from mutation in RDD.aggregate(). Added regression tests for aggregate(), fold(), aggregateByKey(), foldByKey(), and treeAggregate(), all with both 1 and 2 partition RDDs. Confirmed that aggregate() is the only problematic implementation as of commit 257236c. Also replaced some parallelizations of ranges with xranges, per the documentation's recommendations of preferring xrange over range.
7820391 [Nicholas Hwang] Updated comments regarding protection of zeroValue from mutation in RDD.aggregate(). Added regression tests for aggregate(), fold(), aggregateByKey(), foldByKey(), and treeAggregate(), all with both 1 and 2 partition RDDs. Confirmed that aggregate() is the only problematic implementation as of commit 257236c.
90d1544 [Nicholas Hwang] Made sure RDD.aggregate() makes a deepcopy of zeroValue for all partitions; this ensures that the mapPartitions call works with unique copies of zeroValue in each partition, and prevents a single reference to zeroValue being used for both map and fold calls on each partition (resulting in possibly unexpected behavior).

(cherry picked from commit a803ac3)
Signed-off-by: Davies Liu <[email protected]>
shanghaiclown authored and davies committed Jul 19, 2015
1 parent 6834d96 commit a3c853c
Showing 2 changed files with 137 additions and 14 deletions.
python/pyspark/rdd.py (10 changes: 8 additions & 2 deletions)
@@ -849,6 +849,9 @@ def func(iterator):
             for obj in iterator:
                 acc = op(obj, acc)
             yield acc
+        # collecting result of mapPartitions here ensures that the copy of
+        # zeroValue provided to each partition is unique from the one provided
+        # to the final reduce call
         vals = self.mapPartitions(func).collect()
         return reduce(op, vals, zeroValue)

@@ -878,8 +881,11 @@ def func(iterator):
             for obj in iterator:
                 acc = seqOp(acc, obj)
             yield acc
-
-        return self.mapPartitions(func).fold(zeroValue, combOp)
+        # collecting result of mapPartitions here ensures that the copy of
+        # zeroValue provided to each partition is unique from the one provided
+        # to the final reduce call
+        vals = self.mapPartitions(func).collect()
+        return reduce(combOp, vals, zeroValue)

     def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
         """
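
For context, the patched `aggregate()` reads roughly as follows when the hunk above is folded back into the surrounding method (a sketch: the docstring is omitted, and the `acc = zeroValue` initialization and the `reduce` import are assumed from the surrounding, unchanged code rather than shown in this diff):
```
from functools import reduce  # shown only to keep this sketch self-contained

def aggregate(self, zeroValue, seqOp, combOp):
    # Per-partition pass: each task deserializes its own copy of zeroValue,
    # so in-place mutation inside seqOp stays local to that task.
    def func(iterator):
        acc = zeroValue  # assumed initialization, outside the hunk above
        for obj in iterator:
            acc = seqOp(acc, obj)
        yield acc
    # collecting result of mapPartitions here ensures that the copy of
    # zeroValue provided to each partition is unique from the one provided
    # to the final reduce call
    vals = self.mapPartitions(func).collect()
    return reduce(combOp, vals, zeroValue)
```
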
python/pyspark/tests.py (141 changes: 129 additions & 12 deletions)
@@ -529,10 +529,127 @@ def test_deleting_input_files(self):

     def test_sampling_default_seed(self):
         # Test for SPARK-3995 (default seed setting)
-        data = self.sc.parallelize(range(1000), 1)
+        data = self.sc.parallelize(xrange(1000), 1)
         subset = data.takeSample(False, 10)
         self.assertEqual(len(subset), 10)

+    def test_aggregate_mutable_zero_value(self):
+        # Test for SPARK-9021; uses aggregate and treeAggregate to build dict
+        # representing a counter of ints
+        # NOTE: dict is used instead of collections.Counter for Python 2.6
+        # compatibility
+        from collections import defaultdict
+
+        # Show that single or multiple partitions work
+        data1 = self.sc.range(10, numSlices=1)
+        data2 = self.sc.range(10, numSlices=2)
+
+        def seqOp(x, y):
+            x[y] += 1
+            return x
+
+        def comboOp(x, y):
+            for key, val in y.items():
+                x[key] += val
+            return x
+
+        counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp)
+        counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp)
+        counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
+        counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
+
+        ground_truth = defaultdict(int, dict((i, 1) for i in range(10)))
+        self.assertEqual(counts1, ground_truth)
+        self.assertEqual(counts2, ground_truth)
+        self.assertEqual(counts3, ground_truth)
+        self.assertEqual(counts4, ground_truth)
+
+    def test_aggregate_by_key_mutable_zero_value(self):
+        # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that
+        # contains lists of all values for each key in the original RDD
+
+        # list(range(...)) for Python 3.x compatibility (can't use * operator
+        # on a range object)
+        # list(zip(...)) for Python 3.x compatibility (want to parallelize a
+        # collection, not a zip object)
+        tuples = list(zip(list(range(10))*2, [1]*20))
+        # Show that single or multiple partitions work
+        data1 = self.sc.parallelize(tuples, 1)
+        data2 = self.sc.parallelize(tuples, 2)
+
+        def seqOp(x, y):
+            x.append(y)
+            return x
+
+        def comboOp(x, y):
+            x.extend(y)
+            return x
+
+        values1 = data1.aggregateByKey([], seqOp, comboOp).collect()
+        values2 = data2.aggregateByKey([], seqOp, comboOp).collect()
+        # Sort lists to ensure clean comparison with ground_truth
+        values1.sort()
+        values2.sort()
+
+        ground_truth = [(i, [1]*2) for i in range(10)]
+        self.assertEqual(values1, ground_truth)
+        self.assertEqual(values2, ground_truth)
+
+    def test_fold_mutable_zero_value(self):
+        # Test for SPARK-9021; uses fold to merge an RDD of dict counters into
+        # a single dict
+        # NOTE: dict is used instead of collections.Counter for Python 2.6
+        # compatibility
+        from collections import defaultdict
+
+        counts1 = defaultdict(int, dict((i, 1) for i in range(10)))
+        counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8)))
+        counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7)))
+        counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6)))
+        all_counts = [counts1, counts2, counts3, counts4]
+        # Show that single or multiple partitions work
+        data1 = self.sc.parallelize(all_counts, 1)
+        data2 = self.sc.parallelize(all_counts, 2)
+
+        def comboOp(x, y):
+            for key, val in y.items():
+                x[key] += val
+            return x
+
+        fold1 = data1.fold(defaultdict(int), comboOp)
+        fold2 = data2.fold(defaultdict(int), comboOp)
+
+        ground_truth = defaultdict(int)
+        for counts in all_counts:
+            for key, val in counts.items():
+                ground_truth[key] += val
+        self.assertEqual(fold1, ground_truth)
+        self.assertEqual(fold2, ground_truth)
+
+    def test_fold_by_key_mutable_zero_value(self):
+        # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains
+        # lists of all values for each key in the original RDD
+
+        tuples = [(i, range(i)) for i in range(10)]*2
+        # Show that single or multiple partitions work
+        data1 = self.sc.parallelize(tuples, 1)
+        data2 = self.sc.parallelize(tuples, 2)
+
+        def comboOp(x, y):
+            x.extend(y)
+            return x
+
+        values1 = data1.foldByKey([], comboOp).collect()
+        values2 = data2.foldByKey([], comboOp).collect()
+        # Sort lists to ensure clean comparison with ground_truth
+        values1.sort()
+        values2.sort()
+
+        # list(range(...)) for Python 3.x compatibility
+        ground_truth = [(i, list(range(i))*2) for i in range(10)]
+        self.assertEqual(values1, ground_truth)
+        self.assertEqual(values2, ground_truth)
+
     def test_aggregate_by_key(self):
         data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)

@@ -624,8 +741,8 @@ def test_zip_with_different_serializers(self):

     def test_zip_with_different_object_sizes(self):
         # regress test for SPARK-5973
-        a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i)
-        b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i)
+        a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i)
+        b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i)
         self.assertEqual(10000, a.zip(b).count())

     def test_zip_with_different_number_of_items(self):
@@ -647,7 +764,7 @@ def test_zip_with_different_number_of_items(self):
             self.assertRaises(Exception, lambda: a.zip(b).count())

     def test_count_approx_distinct(self):
-        rdd = self.sc.parallelize(range(1000))
+        rdd = self.sc.parallelize(xrange(1000))
         self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
         self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
         self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
@@ -777,7 +894,7 @@ def test_distinct(self):
     def test_external_group_by_key(self):
         self.sc._conf.set("spark.python.worker.memory", "1m")
         N = 200001
-        kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
+        kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
         gkv = kv.groupByKey().cache()
         self.assertEqual(3, gkv.count())
         filtered = gkv.filter(lambda kv: kv[0] == 1)
@@ -871,7 +988,7 @@ def test_narrow_dependency_in_join(self):

     # Regression test for SPARK-6294
     def test_take_on_jrdd(self):
-        rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x))
+        rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x))
         rdd._jrdd.first()

     def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
@@ -1503,13 +1620,13 @@ def run():
             self.fail("daemon had been killed")

         # run a normal job
-        rdd = self.sc.parallelize(range(100), 1)
+        rdd = self.sc.parallelize(xrange(100), 1)
         self.assertEqual(100, rdd.map(str).count())

     def test_after_exception(self):
         def raise_exception(_):
             raise Exception()
-        rdd = self.sc.parallelize(range(100), 1)
+        rdd = self.sc.parallelize(xrange(100), 1)
         with QuietTest(self.sc):
             self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
         self.assertEqual(100, rdd.map(str).count())
@@ -1525,22 +1642,22 @@ def test_after_jvm_exception(self):
         with QuietTest(self.sc):
             self.assertRaises(Exception, lambda: filtered_data.count())

-        rdd = self.sc.parallelize(range(100), 1)
+        rdd = self.sc.parallelize(xrange(100), 1)
         self.assertEqual(100, rdd.map(str).count())

     def test_accumulator_when_reuse_worker(self):
         from pyspark.accumulators import INT_ACCUMULATOR_PARAM
         acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
-        self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x))
+        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
         self.assertEqual(sum(range(100)), acc1.value)

         acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
-        self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x))
+        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
         self.assertEqual(sum(range(100)), acc2.value)
         self.assertEqual(sum(range(100)), acc1.value)

     def test_reuse_worker_after_take(self):
-        rdd = self.sc.parallelize(range(100000), 1)
+        rdd = self.sc.parallelize(xrange(100000), 1)
         self.assertEqual(0, rdd.first())

         def count():
