Skip to content

Commit

Permalink
[branch-1.1][SPARK-4148][PySpark] fix seed distribution and add some …
Browse files Browse the repository at this point in the history
…tests for rdd.sample

Port #3010 to branch-1.1.

Author: Xiangrui Meng <[email protected]>

Closes #3104 from mengxr/SPARK-4148-1.1 and squashes the following commits:

684c002 [Xiangrui Meng] apply SPARK-4148 to branch-1.1
  • Loading branch information
mengxr committed Nov 5, 2014
1 parent 1b282cd commit 44751af
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
3 changes: 0 additions & 3 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,6 @@ def sample(self, withReplacement, fraction, seed=None):
"""
Return a sampled subset of this RDD (relies on numpy and falls back
on default random generator if numpy is unavailable).
>>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
[2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
Expand Down
11 changes: 5 additions & 6 deletions python/pyspark/rddsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,13 @@ def __init__(self, withReplacement, seed=None):
def initRandomGenerator(self, split):
if self._use_numpy:
import numpy
self._random = numpy.random.RandomState(self._seed)
self._random = numpy.random.RandomState(self._seed ^ split)
else:
self._random = random.Random(self._seed)
self._random = random.Random(self._seed ^ split)

for _ in range(0, split):
# discard the next few values in the sequence to have a
# different seed for the different splits
self._random.randint(0, sys.maxint)
# mixing because the initial seeds are close to each other
for _ in xrange(10):
self._random.randint(0, 1)

self._split = split
self._rand_initialized = True
Expand Down
15 changes: 15 additions & 0 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,21 @@ def test_histogram(self):
self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
self.assertRaises(TypeError, lambda: rdd.histogram(2))

def test_sample(self):
rdd = self.sc.parallelize(range(0, 100), 4)
wo = rdd.sample(False, 0.1, 2).collect()
wo_dup = rdd.sample(False, 0.1, 2).collect()
self.assertSetEqual(set(wo), set(wo_dup))
wr = rdd.sample(True, 0.2, 5).collect()
wr_dup = rdd.sample(True, 0.2, 5).collect()
self.assertSetEqual(set(wr), set(wr_dup))
wo_s10 = rdd.sample(False, 0.3, 10).collect()
wo_s20 = rdd.sample(False, 0.3, 20).collect()
self.assertNotEqual(set(wo_s10), set(wo_s20))
wr_s11 = rdd.sample(True, 0.4, 11).collect()
wr_s21 = rdd.sample(True, 0.4, 21).collect()
self.assertNotEqual(set(wr_s11), set(wr_s21))


class TestSQL(PySparkTestCase):

Expand Down

0 comments on commit 44751af

Please sign in to comment.