diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 31919741e9d73..2d80fad796957 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -147,76 +147,6 @@ def __new__(cls, mean, confidence, low, high):
         return obj
 
 
-class MaxHeapQ(object):
-
-    """
-    An implementation of MaxHeap.
-
-    >>> import pyspark.rdd
-    >>> heap = pyspark.rdd.MaxHeapQ(5)
-    >>> [heap.insert(i) for i in range(10)]
-    [None, None, None, None, None, None, None, None, None, None]
-    >>> sorted(heap.getElements())
-    [0, 1, 2, 3, 4]
-    >>> heap = pyspark.rdd.MaxHeapQ(5)
-    >>> [heap.insert(i) for i in range(9, -1, -1)]
-    [None, None, None, None, None, None, None, None, None, None]
-    >>> sorted(heap.getElements())
-    [0, 1, 2, 3, 4]
-    >>> heap = pyspark.rdd.MaxHeapQ(1)
-    >>> [heap.insert(i) for i in range(9, -1, -1)]
-    [None, None, None, None, None, None, None, None, None, None]
-    >>> heap.getElements()
-    [0]
-    """
-
-    def __init__(self, maxsize):
-        # We start from q[1], so its children are always 2 * k
-        self.q = [0]
-        self.maxsize = maxsize
-
-    def _swim(self, k):
-        while (k > 1) and (self.q[k / 2] < self.q[k]):
-            self._swap(k, k / 2)
-            k = k / 2
-
-    def _swap(self, i, j):
-        t = self.q[i]
-        self.q[i] = self.q[j]
-        self.q[j] = t
-
-    def _sink(self, k):
-        N = self.size()
-        while 2 * k <= N:
-            j = 2 * k
-            # Here we test if both children are greater than parent
-            # if not swap with larger one.
-            if j < N and self.q[j] < self.q[j + 1]:
-                j = j + 1
-            if(self.q[k] > self.q[j]):
-                break
-            self._swap(k, j)
-            k = j
-
-    def size(self):
-        return len(self.q) - 1
-
-    def insert(self, value):
-        if (self.size()) < self.maxsize:
-            self.q.append(value)
-            self._swim(self.size())
-        else:
-            self._replaceRoot(value)
-
-    def getElements(self):
-        return self.q[1:]
-
-    def _replaceRoot(self, value):
-        if(self.q[1] > value):
-            self.q[1] = value
-            self._sink(1)
-
-
 def _parse_memory(s):
     """
     Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
@@ -248,6 +178,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer):
         self.ctx = ctx
         self._jrdd_deserializer = jrdd_deserializer
         self._id = jrdd.id()
+        self._partitionFunc = None
 
     def _toPickleSerialization(self):
         if (self._jrdd_deserializer == PickleSerializer() or
@@ -325,8 +256,6 @@ def getCheckpointFile(self):
         checkpointFile = self._jrdd.rdd().getCheckpointFile()
         if checkpointFile.isDefined():
             return checkpointFile.get()
-        else:
-            return None
 
     def map(self, f, preservesPartitioning=False):
         """
@@ -366,7 +295,7 @@ def mapPartitions(self, f, preservesPartitioning=False):
         """
         def func(s, iterator):
             return f(iterator)
-        return self.mapPartitionsWithIndex(func)
+        return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
     def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
         """
@@ -416,7 +345,7 @@ def filter(self, f):
         """
         def func(iterator):
             return ifilter(f, iterator)
-        return self.mapPartitions(func)
+        return self.mapPartitions(func, True)
 
     def distinct(self):
         """
@@ -561,7 +490,7 @@ def intersection(self, other):
         """
         return self.map(lambda v: (v, None)) \
             .cogroup(other.map(lambda v: (v, None))) \
-            .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \
+            .filter(lambda (k, vs): all(vs)) \
             .keys()
 
     def _reserialize(self, serializer=None):
@@ -616,7 +545,7 @@ def sortPartition(iterator):
         if numPartitions == 1:
             if self.getNumPartitions() > 1:
                 self = self.coalesce(1)
-            return self.mapPartitions(sortPartition)
+            return self.mapPartitions(sortPartition, True)
 
         # first compute the boundary of each part via sampling: we want to partition
         # the key-space into bins such that the bins have roughly the same
@@ -721,8 +650,8 @@ def foreach(self, f):
         def processPartition(iterator):
             for x in iterator:
                 f(x)
-            yield None
-        self.mapPartitions(processPartition).collect()  # Force evaluation
+            return iter([])
+        self.mapPartitions(processPartition).count()  # Force evaluation
 
     def foreachPartition(self, f):
         """
@@ -731,10 +660,15 @@ def foreachPartition(self, f):
         >>> def f(iterator):
         ...     for x in iterator:
         ...          print x
-        ...     yield None
        >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f)
         """
-        self.mapPartitions(f).collect()  # Force evaluation
+        def func(it):
+            r = f(it)
+            try:
+                return iter(r)
+            except TypeError:
+                return iter([])
+        self.mapPartitions(func).count()  # Force evaluation
 
     def collect(self):
         """
@@ -767,18 +701,23 @@ def reduce(self, f):
         15
         >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
         10
+        >>> sc.parallelize([]).reduce(add)
+        Traceback (most recent call last):
+            ...
+        ValueError: Can not reduce() empty RDD
         """
         def func(iterator):
-            acc = None
-            for obj in iterator:
-                if acc is None:
-                    acc = obj
-                else:
-                    acc = f(obj, acc)
-            if acc is not None:
-                yield acc
+            iterator = iter(iterator)
+            try:
+                initial = next(iterator)
+            except StopIteration:
+                return
+            yield reduce(f, iterator, initial)
+
         vals = self.mapPartitions(func).collect()
-        return reduce(f, vals)
+        if vals:
+            return reduce(f, vals)
+        raise ValueError("Can not reduce() empty RDD")
 
     def fold(self, zeroValue, op):
         """
@@ -1081,7 +1020,7 @@ def countPartition(iterator):
             yield counts
 
         def mergeMaps(m1, m2):
-            for (k, v) in m2.iteritems():
+            for k, v in m2.iteritems():
                 m1[k] += v
             return m1
         return self.mapPartitions(countPartition).reduce(mergeMaps)
@@ -1117,24 +1056,10 @@ def takeOrdered(self, num, key=None):
         [10, 9, 7, 6, 5, 4]
         """
-        def topNKeyedElems(iterator, key_=None):
-            q = MaxHeapQ(num)
-            for k in iterator:
-                if key_ is not None:
-                    k = (key_(k), k)
-                q.insert(k)
-            yield q.getElements()
-
-        def unKey(x, key_=None):
-            if key_ is not None:
-                x = [i[1] for i in x]
-            return x
-
         def merge(a, b):
-            return next(topNKeyedElems(a + b))
-        result = self.mapPartitions(
-            lambda i: topNKeyedElems(i, key)).reduce(merge)
-        return sorted(unKey(result, key), key=key)
+            return heapq.nsmallest(num, a + b, key)
+
+        return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge)
 
     def take(self, num):
         """
@@ -1174,13 +1099,13 @@ def take(self, num):
             left = num - len(items)
 
             def takeUpToNumLeft(iterator):
+                iterator = iter(iterator)
                 taken = 0
                 while taken < left:
                     yield next(iterator)
                     taken += 1
 
-            p = range(
-                partsScanned, min(partsScanned + numPartsToTry, totalParts))
+            p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
             res = self.context.runJob(self, takeUpToNumLeft, p, True)
 
             items += res
@@ -1194,8 +1119,15 @@ def first(self):
         >>> sc.parallelize([2, 3, 4]).first()
         2
+        >>> sc.parallelize([]).first()
+        Traceback (most recent call last):
+            ...
+        ValueError: RDD is empty
         """
-        return self.take(1)[0]
+        rs = self.take(1)
+        if rs:
+            return rs[0]
+        raise ValueError("RDD is empty")
 
     def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
         """
@@ -1420,13 +1352,13 @@ def reduceByKeyLocally(self, func):
         """
         def reducePartition(iterator):
             m = {}
-            for (k, v) in iterator:
-                m[k] = v if k not in m else func(m[k], v)
+            for k, v in iterator:
+                m[k] = func(m[k], v) if k in m else v
             yield m
 
         def mergeMaps(m1, m2):
-            for (k, v) in m2.iteritems():
-                m1[k] = v if k not in m1 else func(m1[k], v)
+            for k, v in m2.iteritems():
+                m1[k] = func(m1[k], v) if k in m1 else v
             return m1
         return self.mapPartitions(reducePartition).reduce(mergeMaps)
@@ -1523,7 +1455,7 @@ def add_shuffle_key(split, iterator):
             buckets = defaultdict(list)
             c, batch = 0, min(10 * numPartitions, 1000)
 
-            for (k, v) in iterator:
+            for k, v in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
                 c += 1
 
@@ -1546,7 +1478,7 @@ def add_shuffle_key(split, iterator):
                     batch = max(batch / 1.5, 1)
                     c = 0
 
-            for (split, items) in buckets.iteritems():
+            for split, items in buckets.iteritems():
                 yield pack_long(split)
                 yield outputSerializer.dumps(items)
 
@@ -1616,7 +1548,7 @@ def _mergeCombiners(iterator):
             merger.mergeCombiners(iterator)
             return merger.iteritems()
 
-        return shuffled.mapPartitions(_mergeCombiners)
+        return shuffled.mapPartitions(_mergeCombiners, True)
 
     def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
         """
@@ -1680,7 +1612,6 @@ def mergeCombiners(a, b):
         return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
                                  numPartitions).mapValues(lambda x: ResultIterable(x))
 
-    # TODO: add tests
     def flatMapValues(self, f):
         """
         Pass each value in the key-value pair RDD through a flatMap function
@@ -1770,9 +1701,8 @@ def subtractByKey(self, other, numPartitions=None):
         [('b', 4), ('b', 5)]
         """
         def filter_func((key, vals)):
-            return len(vals[0]) > 0 and len(vals[1]) == 0
-        map_func = lambda (key, vals): [(key, val) for val in vals[0]]
-        return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
+            return vals[0] and not vals[1]
+        return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0])
 
     def subtract(self, other, numPartitions=None):
         """
@@ -1785,7 +1715,7 @@ def subtract(self, other, numPartitions=None):
         """
         # note: here 'True' is just a placeholder
         rdd = other.map(lambda x: (x, True))
-        return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0])
+        return self.map(lambda x: (x, True)).subtractByKey(rdd, numPartitions).keys()
 
     def keyBy(self, f):
         """
@@ -1925,9 +1855,8 @@ def name(self):
         Return the name of this RDD.
         """
         name_ = self._jrdd.name()
-        if not name_:
-            return None
-        return name_.encode('utf-8')
+        if name_:
+            return name_.encode('utf-8')
 
     def setName(self, name):
         """
@@ -1945,9 +1874,8 @@ def toDebugString(self):
         A description of this RDD and its recursive dependencies for debugging.
         """
         debug_string = self._jrdd.toDebugString()
-        if not debug_string:
-            return None
-        return debug_string.encode('utf-8')
+        if debug_string:
+            return debug_string.encode('utf-8')
 
     def getStorageLevel(self):
         """
@@ -1982,10 +1910,28 @@ def _defaultReducePartitions(self):
         else:
             return self.getNumPartitions()
 
-    # TODO: `lookup` is disabled because we can't make direct comparisons based
-    # on the key; we need to compare the hash of the key to the hash of the
-    # keys in the pairs.  This could be an expensive operation, since those
-    # hashes aren't retained.
+    def lookup(self, key):
+        """
+        Return the list of values in the RDD for key `key`. This operation
+        is done efficiently if the RDD has a known partitioner by only
+        searching the partition that the key maps to.
+
+        >>> l = range(1000)
+        >>> rdd = sc.parallelize(zip(l, l), 10)
+        >>> rdd.lookup(42)  # slow
+        [42]
+        >>> sorted = rdd.sortByKey()
+        >>> sorted.lookup(42)  # fast
+        [42]
+        >>> sorted.lookup(1024)
+        []
+        """
+        values = self.filter(lambda (k, v): k == key).values()
+
+        if self._partitionFunc is not None:
+            return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False)
+
+        return values.collect()
 
     def _is_pickled(self):
         """ Return this RDD is serialized by Pickle or not. """
@@ -2096,6 +2042,7 @@ def pipeline_func(split, iterator):
         self._jrdd_val = None
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
+        self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
 
     @property
     def _jrdd(self):
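
# Illustrative sketch, not part of the patch: the takeOrdered change above drops
# the hand-rolled MaxHeapQ and instead keeps only `num` candidates per partition
# with heapq.nsmallest, then merges the per-partition lists pairwise with reduce().
# The same pattern is shown here on plain Python lists (no SparkContext needed);
# `take_ordered` and the sample partitions are made up for the example.
import heapq
from functools import reduce


def take_ordered(partitions, num, key=None):
    # Keep the num smallest elements of each "partition", then merge the lists,
    # mirroring the merge() helper added in the patch.
    candidates = [heapq.nsmallest(num, part, key=key) for part in partitions]
    return reduce(lambda a, b: heapq.nsmallest(num, a + b, key=key), candidates)


if __name__ == "__main__":
    parts = [[10, 1, 2, 9], [3, 4, 5, 6, 7]]
    print(take_ordered(parts, 6))                    # [1, 2, 3, 4, 5, 6]
    print(take_ordered(parts, 6, key=lambda x: -x))  # [10, 9, 7, 6, 5, 4]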
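
# Illustrative sketch, not part of the patch: the new RDD.reduce() contract.
# Each partition contributes at most one partial result (an empty partition
# contributes nothing), and an entirely empty dataset raises ValueError instead
# of failing inside reduce(). Plain lists stand in for partitions; the helper
# name `reduce_partitions` is made up for the example.
from functools import reduce
from operator import add


def reduce_partitions(partitions, f):
    def func(iterator):
        iterator = iter(iterator)
        try:
            initial = next(iterator)
        except StopIteration:
            return  # empty partition: yield nothing
        yield reduce(f, iterator, initial)

    vals = [v for part in partitions for v in func(part)]
    if vals:
        return reduce(f, vals)
    raise ValueError("Can not reduce() empty RDD")


if __name__ == "__main__":
    print(reduce_partitions([[1, 2, 3], [], [4, 5]], add))  # 15
    try:
        reduce_partitions([[], []], add)
    except ValueError as e:
        print(e)  # Can not reduce() empty RDD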
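
# Illustrative sketch, not part of the patch: the wrapper introduced in
# foreachPartition(). A user function may return an iterable (the old style,
# e.g. a generator that yields None) or return nothing at all; calling iter()
# on the result and falling back to an empty iterator on TypeError accepts
# both styles. `run_foreach_partition` is a made-up stand-in for the RDD method.
def run_foreach_partition(partitions, f):
    def func(it):
        r = f(it)
        try:
            return iter(r)
        except TypeError:
            return iter([])

    for part in partitions:
        for _ in func(part):  # drain whatever the user function yields
            pass


if __name__ == "__main__":
    def plain(it):       # returns None
        for x in it:
            print(x)

    def generator(it):   # yields values, old style
        for x in it:
            print(x)
            yield None

    run_foreach_partition([[1, 2], [3]], plain)
    run_foreach_partition([[4, 5]], generator)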
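
# Illustrative sketch, not part of the patch: why the new lookup() can skip
# partitions. When an RDD has a known partitioner, key k can only live in
# partition partitionFunc(k) % numPartitions, so only that partition is scanned.
# The helpers below (build_partitions, lookup) are made up to show the idea
# without a SparkContext.
def build_partitions(pairs, num_partitions, partition_func=hash):
    parts = [[] for _ in range(num_partitions)]
    for k, v in pairs:
        parts[partition_func(k) % num_partitions].append((k, v))
    return parts


def lookup(parts, key, partition_func=hash):
    # Fast path: scan only the one partition the key maps to.
    target = parts[partition_func(key) % len(parts)]
    return [v for k, v in target if k == key]


if __name__ == "__main__":
    data = zip(range(1000), range(1000))
    parts = build_partitions(data, 10)
    print(lookup(parts, 42))    # [42]
    print(lookup(parts, 1024))  # []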