Updated comments regarding protection of zeroValue from mutation in RDD.aggregate(). Added regression tests for aggregate(), fold(), aggregateByKey(), foldByKey(), and treeAggregate(), all with both 1- and 2-partition RDDs. Confirmed that aggregate() is the only problematic implementation as of commit 257236c. Also replaced some parallelizations of ranges with xranges, per the documentation's recommendation to prefer xrange over range.
shanghaiclown committed Jul 15, 2015
1 parent 90d1544 commit ba528bd
Showing 2 changed files with 139 additions and 12 deletions.
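
For context, SPARK-9021 reports that RDD.aggregate() can return wrong results when zeroValue is a mutable object, because the same object ends up being mutated by the user-supplied in-place operators instead of each use getting its own copy. Below is a minimal, Spark-free sketch of the failure mode and of a copy-based fix; naive_aggregate and safe_aggregate are illustrative helpers, not PySpark APIs.

```python
import copy
from collections import Counter
from functools import reduce


def naive_aggregate(partitions, zero, seq_op, comb_op):
    # Buggy: every partition folds into the same `zero` object, and the
    # final merge starts from it too, so in-place mutations leak between
    # uses and values get counted more than once.
    partials = [reduce(seq_op, part, zero) for part in partitions]
    return reduce(comb_op, partials, zero)


def safe_aggregate(partitions, zero, seq_op, comb_op):
    # Fixed: each partition and the final merge work on a private copy.
    partials = [reduce(seq_op, part, copy.deepcopy(zero)) for part in partitions]
    return reduce(comb_op, partials, copy.deepcopy(zero))


def seq_op(acc, x):
    acc[x] += 1       # mutates the accumulator in place
    return acc


def comb_op(a, b):
    a.update(b)       # mutates the left operand in place
    return a


parts = [range(5), range(5, 10)]
zero = Counter()
print(naive_aggregate(parts, zero, seq_op, comb_op))      # over-counted result
print(zero)                                               # caller's zeroValue was mutated
print(safe_aggregate(parts, Counter(), seq_op, comb_op))  # Counter with 0..9 -> 1
```

With the naive version the shared zero is merged into itself and the caller's Counter is left non-empty; the safe version returns exactly one count per element of 0..9 and leaves the caller's object untouched.
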
2 changes: 2 additions & 0 deletions python/pyspark/rdd.py
@@ -901,6 +901,8 @@ def func(iterator):
acc = seqOp(acc, obj)
yield acc

# fold() properly protects zeroValue from mutation, so it is
# unnecessary to make another copy in the fold() call
return self.mapPartitions(func).fold(zeroValue, combOp)

def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
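
The comment added above states that fold() already protects zeroValue from mutation, so aggregate() does not need to copy it again when delegating its merge phase to fold(). The underlying discipline is copy-before-mutate: whichever layer hands a mutable zero value to an in-place operator works on a private copy. A minimal sketch of that idea with a hypothetical helper name (an illustration, not the actual PySpark internals):

```python
import copy


def fold_partition(items, zero_value, op):
    # Work on a private copy so an in-place `op` cannot leak changes back
    # into the zero_value object held by the caller or by other closures.
    acc = copy.deepcopy(zero_value)
    for item in items:
        acc = op(acc, item)
    return acc
```
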
149 changes: 137 additions & 12 deletions python/pyspark/tests.py
@@ -33,6 +33,7 @@
import random
import threading
import hashlib
import collections

from py4j.protocol import Py4JJavaError
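
The range→xrange substitutions in the hunks below rely on the Python 2 built-in; under Python 3, where xrange was removed, the usual shim is a module-level alias. A sketch of the common pattern (not necessarily the exact code in tests.py):

```python
import sys

if sys.version_info[0] >= 3:
    # Python 3's range is already lazy, so reuse it under the Python 2 name.
    xrange = range
```
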

@@ -529,10 +530,134 @@ def test_deleting_input_files(self):

def test_sampling_default_seed(self):
# Test for SPARK-3995 (default seed setting)
data = self.sc.parallelize(range(1000), 1)
data = self.sc.parallelize(xrange(1000), 1)
subset = data.takeSample(False, 10)
self.assertEqual(len(subset), 10)

def test_aggregate_mutable_zero_value(self):
# Test for SPARK-9021; uses aggregate to build Counter representing an
# RDD of ints

# Show that single or multiple partitions work
data1 = self.sc.range(10, numSlices=1)
data2 = self.sc.range(10, numSlices=2)

def seqOp(x, y):
x[y] += 1
return x

def comboOp(x, y):
x.update(y)
return x

counts1 = data1.aggregate(collections.Counter(), seqOp, comboOp)
counts2 = data2.aggregate(collections.Counter(), seqOp, comboOp)

self.assertEqual(counts1, collections.Counter(range(10)))
self.assertEqual(counts2, collections.Counter(range(10)))

def test_aggregate_by_key_mutable_zero_value(self):
# Test for SPARK-9021; uses aggregateByKey to make a pair RDD that
# contains lists of all values for each key in the original RDD

# list(range(...)) for Python 3.x compatibility (can't use * operator
# on a range object)
# list(zip(...)) for Python 3.x compatibility (want to parallelize a
# collection, not a zip object)
tuples = list(zip(list(range(10))*2, [1]*20))
# Show that single or multiple partitions work
data1 = self.sc.parallelize(tuples, 1)
data2 = self.sc.parallelize(tuples, 2)

def seqOp(x, y):
x.append(y)
return x

def comboOp(x, y):
x.extend(y)
return x

values1 = data1.aggregateByKey([], seqOp, comboOp).collect()
values2 = data2.aggregateByKey([], seqOp, comboOp).collect()
# Sort lists to ensure clean comparison with ground_truth
values1.sort()
values2.sort()

ground_truth = [(i, [1]*2) for i in range(10)]
self.assertEqual(values1, ground_truth)
self.assertEqual(values2, ground_truth)

def test_fold_mutable_zero_value(self):
# Test for SPARK-9021; uses fold to merge an RDD of Counters into a
# single Counter
counts1 = collections.Counter(range(10))
counts2 = collections.Counter(range(3, 8))
counts3 = collections.Counter(range(4, 7))
counts4 = collections.Counter(range(5, 6))
all_counts = [counts1, counts2, counts3, counts4]
# Show that single or multiple partitions work
data1 = self.sc.parallelize(all_counts, 1)
data2 = self.sc.parallelize(all_counts, 2)

def comboOp(x, y):
x.update(y)
return x

fold1 = data1.fold(collections.Counter(), comboOp)
fold2 = data2.fold(collections.Counter(), comboOp)

ground_truth = collections.Counter()
for counts in all_counts:
ground_truth.update(counts)
self.assertEqual(fold1, ground_truth)
self.assertEqual(fold2, ground_truth)

def test_fold_by_key_mutable_zero_value(self):
# Test for SPARK-9021; uses foldByKey to make a pair RDD that contains
# lists of all values for each key in the original RDD

tuples = [(i, range(i)) for i in range(10)]*2
# Show that single or multiple partitions work
data1 = self.sc.parallelize(tuples, 1)
data2 = self.sc.parallelize(tuples, 2)

def comboOp(x, y):
x.extend(y)
return x

values1 = data1.foldByKey([], comboOp).collect()
values2 = data2.foldByKey([], comboOp).collect()
# Sort lists to ensure clean comparison with ground_truth
values1.sort()
values2.sort()

# list(range(...)) for Python 3.x compatibility
ground_truth = [(i, list(range(i))*2) for i in range(10)]
self.assertEqual(values1, ground_truth)
self.assertEqual(values2, ground_truth)

def test_tree_aggregate_mutable_zero_value(self):
# Test for SPARK-9021; uses treeAggregate to build Counter representing
# an RDD of ints

# Show that single or multiple partitions work
data1 = self.sc.range(10, numSlices=1)
data2 = self.sc.range(10, numSlices=2)

def seqOp(x, y):
x[y] += 1
return x

def comboOp(x, y):
x.update(y)
return x

counts1 = data1.treeAggregate(collections.Counter(), seqOp, comboOp, 2)
counts2 = data2.treeAggregate(collections.Counter(), seqOp, comboOp, 2)

self.assertEqual(counts1, collections.Counter(range(10)))
self.assertEqual(counts2, collections.Counter(range(10)))

def test_aggregate_by_key(self):
data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)

@@ -624,8 +749,8 @@ def test_zip_with_different_serializers(self):

def test_zip_with_different_object_sizes(self):
# regression test for SPARK-5973
a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i)
b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i)
a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i)
b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i)
self.assertEqual(10000, a.zip(b).count())

def test_zip_with_different_number_of_items(self):
Expand All @@ -647,7 +772,7 @@ def test_zip_with_different_number_of_items(self):
self.assertRaises(Exception, lambda: a.zip(b).count())

def test_count_approx_distinct(self):
rdd = self.sc.parallelize(range(1000))
rdd = self.sc.parallelize(xrange(1000))
self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
@@ -777,7 +902,7 @@ def test_distinct(self):
def test_external_group_by_key(self):
self.sc._conf.set("spark.python.worker.memory", "1m")
N = 200001
kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
gkv = kv.groupByKey().cache()
self.assertEqual(3, gkv.count())
filtered = gkv.filter(lambda kv: kv[0] == 1)
@@ -871,7 +996,7 @@ def test_narrow_dependency_in_join(self):

# Regression test for SPARK-6294
def test_take_on_jrdd(self):
rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x))
rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x))
rdd._jrdd.first()

def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
@@ -1516,13 +1641,13 @@ def run():
self.fail("daemon had been killed")

# run a normal job
rdd = self.sc.parallelize(range(100), 1)
rdd = self.sc.parallelize(xrange(100), 1)
self.assertEqual(100, rdd.map(str).count())

def test_after_exception(self):
def raise_exception(_):
raise Exception()
rdd = self.sc.parallelize(range(100), 1)
rdd = self.sc.parallelize(xrange(100), 1)
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
self.assertEqual(100, rdd.map(str).count())
@@ -1538,22 +1663,22 @@ def test_after_jvm_exception(self):
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: filtered_data.count())

rdd = self.sc.parallelize(range(100), 1)
rdd = self.sc.parallelize(xrange(100), 1)
self.assertEqual(100, rdd.map(str).count())

def test_accumulator_when_reuse_worker(self):
from pyspark.accumulators import INT_ACCUMULATOR_PARAM
acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x))
self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
self.assertEqual(sum(range(100)), acc1.value)

acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x))
self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
self.assertEqual(sum(range(100)), acc2.value)
self.assertEqual(sum(range(100)), acc1.value)

def test_reuse_worker_after_take(self):
rdd = self.sc.parallelize(range(100000), 1)
rdd = self.sc.parallelize(xrange(100000), 1)
self.assertEqual(0, rdd.first())

def count():
