[SPARK-9021][PySpark] Change RDD.aggregate() to do reduce(mapPartitions()) instead of mapPartitions.fold() #7378

Closed
wants to merge 8 commits
10 changes: 8 additions & 2 deletions python/pyspark/rdd.py
@@ -862,6 +862,9 @@ def func(iterator):
for obj in iterator:
acc = op(obj, acc)
yield acc
# collecting result of mapPartitions here ensures that the copy of
# zeroValue provided to each partition is unique from the one provided
# to the final reduce call
vals = self.mapPartitions(func).collect()
return reduce(op, vals, zeroValue)
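
Aside (not part of the patch): a minimal, Spark-free sketch of the failure mode the new comment describes. The partition list and helper names below are hypothetical; the point is that when seqOp/combOp mutate their accumulator in place and every partition plus the final reduce share one mutable zeroValue object, counts get merged into the same dict repeatedly and come out wrong. The deepcopy below merely stands in for the per-task copy that serialization (and the collect() boundary) provides in Spark itself.

from collections import defaultdict
from copy import deepcopy
from functools import reduce

def seqOp(acc, obj):
    acc[obj] += 1                       # mutates the accumulator in place
    return acc

def combOp(x, y):
    for key, val in y.items():
        x[key] += val
    return x

partitions = [[0, 1, 2], [3, 4]]        # stand-in for an RDD with two partitions
zero = defaultdict(int)

# Broken: every partition folds into the very same zero object, and the final
# reduce starts from it too, so each count ends up quadrupled in this example.
shared = [reduce(seqOp, part, zero) for part in partitions]
broken = reduce(combOp, shared, zero)   # {0: 4, 1: 4, ...} instead of {0: 1, ...}

# Safe: each partition works on its own copy; the driver keeps a separate one
# for the final reduce.
zero = defaultdict(int)
vals = [reduce(seqOp, part, deepcopy(zero)) for part in partitions]
correct = reduce(combOp, vals, zero)    # {0: 1, 1: 1, 2: 1, 3: 1, 4: 1}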

@@ -891,8 +894,11 @@ def func(iterator):
for obj in iterator:
acc = seqOp(acc, obj)
yield acc

return self.mapPartitions(func).fold(zeroValue, combOp)
# collecting result of mapPartitions here ensures that the copy of
# zeroValue provided to each partition is unique from the one provided
# to the final reduce call
vals = self.mapPartitions(func).collect()
return reduce(combOp, vals, zeroValue)

def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
"""
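
A hedged usage sketch of the patched behavior, mirroring the tests added below (it assumes an existing SparkContext named sc): with the collect-then-reduce implementation, a mutable defaultdict zeroValue can be mutated freely inside seqOp/combOp and each element is still counted exactly once.

from collections import defaultdict

def seqOp(x, y):
    x[y] += 1
    return x

def combOp(x, y):
    for key, val in y.items():
        x[key] += val
    return x

# sc is assumed to be a live SparkContext
counts = sc.range(10, numSlices=2).aggregate(defaultdict(int), seqOp, combOp)
# With this change the result is {0: 1, 1: 1, ..., 9: 1}; previously the shared
# zeroValue could produce double-counted entries (SPARK-9021).
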
141 changes: 129 additions & 12 deletions python/pyspark/tests.py
@@ -529,10 +529,127 @@ def test_deleting_input_files(self):

def test_sampling_default_seed(self):
# Test for SPARK-3995 (default seed setting)
data = self.sc.parallelize(range(1000), 1)
data = self.sc.parallelize(xrange(1000), 1)
subset = data.takeSample(False, 10)
self.assertEqual(len(subset), 10)

def test_aggregate_mutable_zero_value(self):
# Test for SPARK-9021; uses aggregate and treeAggregate to build dict
# representing a counter of ints
# NOTE: dict is used instead of collections.Counter for Python 2.6
# compatibility
from collections import defaultdict

# Show that single or multiple partitions work
data1 = self.sc.range(10, numSlices=1)
data2 = self.sc.range(10, numSlices=2)

def seqOp(x, y):
x[y] += 1
return x

def comboOp(x, y):
for key, val in y.items():
x[key] += val
return x

counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp)
counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp)
counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2)

ground_truth = defaultdict(int, dict((i, 1) for i in range(10)))
self.assertEqual(counts1, ground_truth)
self.assertEqual(counts2, ground_truth)
self.assertEqual(counts3, ground_truth)
self.assertEqual(counts4, ground_truth)

def test_aggregate_by_key_mutable_zero_value(self):
# Test for SPARK-9021; uses aggregateByKey to make a pair RDD that
# contains lists of all values for each key in the original RDD

# list(range(...)) for Python 3.x compatibility (can't use * operator
# on a range object)
# list(zip(...)) for Python 3.x compatibility (want to parallelize a
# collection, not a zip object)
tuples = list(zip(list(range(10))*2, [1]*20))
# Show that single or multiple partitions work
data1 = self.sc.parallelize(tuples, 1)
data2 = self.sc.parallelize(tuples, 2)

def seqOp(x, y):
x.append(y)
return x

def comboOp(x, y):
x.extend(y)
return x

values1 = data1.aggregateByKey([], seqOp, comboOp).collect()
values2 = data2.aggregateByKey([], seqOp, comboOp).collect()
# Sort lists to ensure clean comparison with ground_truth
values1.sort()
values2.sort()

ground_truth = [(i, [1]*2) for i in range(10)]
self.assertEqual(values1, ground_truth)
self.assertEqual(values2, ground_truth)

def test_fold_mutable_zero_value(self):
# Test for SPARK-9021; uses fold to merge an RDD of dict counters into
# a single dict
# NOTE: dict is used instead of collections.Counter for Python 2.6
# compatibility
from collections import defaultdict

counts1 = defaultdict(int, dict((i, 1) for i in range(10)))
counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8)))
counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7)))
counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6)))
all_counts = [counts1, counts2, counts3, counts4]
# Show that single or multiple partitions work
data1 = self.sc.parallelize(all_counts, 1)
data2 = self.sc.parallelize(all_counts, 2)

def comboOp(x, y):
for key, val in y.items():
x[key] += val
return x

fold1 = data1.fold(defaultdict(int), comboOp)
fold2 = data2.fold(defaultdict(int), comboOp)

ground_truth = defaultdict(int)
for counts in all_counts:
for key, val in counts.items():
ground_truth[key] += val
self.assertEqual(fold1, ground_truth)
self.assertEqual(fold2, ground_truth)

def test_fold_by_key_mutable_zero_value(self):
# Test for SPARK-9021; uses foldByKey to make a pair RDD that contains
# lists of all values for each key in the original RDD

tuples = [(i, range(i)) for i in range(10)]*2
# Show that single or multiple partitions work
data1 = self.sc.parallelize(tuples, 1)
data2 = self.sc.parallelize(tuples, 2)

def comboOp(x, y):
x.extend(y)
return x

values1 = data1.foldByKey([], comboOp).collect()
values2 = data2.foldByKey([], comboOp).collect()
# Sort lists to ensure clean comparison with ground_truth
values1.sort()
values2.sort()

# list(range(...)) for Python 3.x compatibility
ground_truth = [(i, list(range(i))*2) for i in range(10)]
self.assertEqual(values1, ground_truth)
self.assertEqual(values2, ground_truth)

def test_aggregate_by_key(self):
data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)

@@ -624,8 +741,8 @@ def test_zip_with_different_serializers(self):

def test_zip_with_different_object_sizes(self):
# regress test for SPARK-5973
a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i)
b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i)
a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i)
b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i)
self.assertEqual(10000, a.zip(b).count())

def test_zip_with_different_number_of_items(self):
@@ -647,7 +764,7 @@ def test_zip_with_different_number_of_items(self):
self.assertRaises(Exception, lambda: a.zip(b).count())

def test_count_approx_distinct(self):
rdd = self.sc.parallelize(range(1000))
rdd = self.sc.parallelize(xrange(1000))
self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
@@ -777,7 +894,7 @@ def test_distinct(self):
def test_external_group_by_key(self):
self.sc._conf.set("spark.python.worker.memory", "1m")
N = 200001
kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
gkv = kv.groupByKey().cache()
self.assertEqual(3, gkv.count())
filtered = gkv.filter(lambda kv: kv[0] == 1)
@@ -871,7 +988,7 @@ def test_narrow_dependency_in_join(self):

# Regression test for SPARK-6294
def test_take_on_jrdd(self):
rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x))
rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x))
rdd._jrdd.first()

def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
@@ -1516,13 +1633,13 @@ def run():
self.fail("daemon had been killed")

# run a normal job
rdd = self.sc.parallelize(range(100), 1)
rdd = self.sc.parallelize(xrange(100), 1)
self.assertEqual(100, rdd.map(str).count())

def test_after_exception(self):
def raise_exception(_):
raise Exception()
rdd = self.sc.parallelize(range(100), 1)
rdd = self.sc.parallelize(xrange(100), 1)
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
self.assertEqual(100, rdd.map(str).count())
@@ -1538,22 +1655,22 @@ def test_after_jvm_exception(self):
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: filtered_data.count())

rdd = self.sc.parallelize(range(100), 1)
rdd = self.sc.parallelize(xrange(100), 1)
self.assertEqual(100, rdd.map(str).count())

def test_accumulator_when_reuse_worker(self):
from pyspark.accumulators import INT_ACCUMULATOR_PARAM
acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x))
self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
self.assertEqual(sum(range(100)), acc1.value)

acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x))
self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
self.assertEqual(sum(range(100)), acc2.value)
self.assertEqual(sum(range(100)), acc1.value)

def test_reuse_worker_after_take(self):
rdd = self.sc.parallelize(range(100000), 1)
rdd = self.sc.parallelize(xrange(100000), 1)
self.assertEqual(0, rdd.first())

def count():