SPARK-1162: Implemented takeOrdered in PySpark.
Since Python does not have a library for a max heap, and the usual tricks like inverting values do not work in all cases, we have our own implementation of a max heap.
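Python's heapq provides only a min-heap, and the common workaround of negating values assumes the elements support unary minus. A minimal sketch of that limitation (illustrative only, not part of the patch):

import heapq

nums = [5, 1, 9, 3, 7]
# Keeping the 3 largest by negating values works, but only for numbers.
largest = [-v for v in heapq.nsmallest(3, [-n for n in nums])]
assert largest == [9, 7, 5]

words = ["pear", "apple", "fig"]
# [-w for w in words]  # TypeError: bad operand type for unary -: 'str'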

Author: Prashant Sharma <[email protected]>

Closes #97 from ScrapCodes/SPARK-1162/pyspark-top-takeOrdered2 and squashes the following commits:

35f86ba [Prashant Sharma] code review
2b1124d [Prashant Sharma] fixed tests
e8a08e2 [Prashant Sharma] Code review comments.
49e6ba7 [Prashant Sharma] SPARK-1162 added takeOrdered to pyspark
ScrapCodes authored and mateiz committed Apr 3, 2014
1 parent 5d1feda commit c1ea3af
Showing 1 changed file with 102 additions and 5 deletions: python/pyspark/rdd.py
@@ -29,7 +29,7 @@
from tempfile import NamedTemporaryFile
from threading import Thread
import warnings
- from heapq import heappush, heappop, heappushpop
+ import heapq

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
    BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -41,9 +41,9 @@

from py4j.java_collections import ListConverter, MapConverter


__all__ = ["RDD"]


def _extract_concise_traceback():
"""
This function returns the traceback info for a callsite, returns a dict
@@ -91,6 +91,73 @@ def __exit__(self, type, value, tb):
        if _spark_stack_depth == 0:
            self._context._jsc.setCallSite(None)

class MaxHeapQ(object):
    """
    An implementation of a max heap of bounded size; it retains the
    maxsize smallest elements inserted into it.

    >>> import pyspark.rdd
    >>> heap = pyspark.rdd.MaxHeapQ(5)
    >>> [heap.insert(i) for i in range(10)]
    [None, None, None, None, None, None, None, None, None, None]
    >>> sorted(heap.getElements())
    [0, 1, 2, 3, 4]
    >>> heap = pyspark.rdd.MaxHeapQ(5)
    >>> [heap.insert(i) for i in range(9, -1, -1)]
    [None, None, None, None, None, None, None, None, None, None]
    >>> sorted(heap.getElements())
    [0, 1, 2, 3, 4]
    >>> heap = pyspark.rdd.MaxHeapQ(1)
    >>> [heap.insert(i) for i in range(9, -1, -1)]
    [None, None, None, None, None, None, None, None, None, None]
    >>> heap.getElements()
    [0]
    """

    def __init__(self, maxsize):
        # We index from q[1]; a node at index k then has children
        # at 2 * k and 2 * k + 1.
        self.q = [0]
        self.maxsize = maxsize

    def _swim(self, k):
        # Restore the heap property upward: while a node is larger
        # than its parent, swap the two.
        while k > 1 and self.q[k // 2] < self.q[k]:
            self._swap(k, k // 2)
            k = k // 2

    def _swap(self, i, j):
        t = self.q[i]
        self.q[i] = self.q[j]
        self.q[j] = t

    def _sink(self, k):
        # Restore the heap property downward: swap a node with its
        # larger child until neither child is greater than the node.
        N = self.size()
        while 2 * k <= N:
            j = 2 * k
            # Pick the larger of the two children.
            if j < N and self.q[j] < self.q[j + 1]:
                j = j + 1
            if self.q[k] > self.q[j]:
                break
            self._swap(k, j)
            k = j

    def size(self):
        return len(self.q) - 1

    def insert(self, value):
        # Grow the heap until it holds maxsize elements; after that, a new
        # value is admitted only if it is smaller than the current maximum.
        if self.size() < self.maxsize:
            self.q.append(value)
            self._swim(self.size())
        else:
            self._replaceRoot(value)

    def getElements(self):
        # Note: elements are returned in heap order, not sorted order.
        return self.q[1:]

    def _replaceRoot(self, value):
        if self.q[1] > value:
            self.q[1] = value
            self._sink(1)

class RDD(object):
"""
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
@@ -696,23 +763,53 @@ def top(self, num):
        Note: It returns the list sorted in descending order.

        >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
        [12]
-       >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2)
+       >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2)
        [6, 5]
        """
        def topIterator(iterator):
            q = []
            for k in iterator:
                if len(q) < num:
-                   heappush(q, k)
+                   heapq.heappush(q, k)
                else:
-                   heappushpop(q, k)
+                   heapq.heappushpop(q, k)
            yield q

        def merge(a, b):
            return next(topIterator(a + b))

        return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
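As a point of comparison outside the patch, the heap-per-partition pattern above matches what heapq's built-in helper computes on a single iterator:

import heapq

# Equivalent of the top(2) doctest above, on one plain iterator:
assert heapq.nlargest(2, [2, 3, 4, 5, 6]) == [6, 5]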

    def takeOrdered(self, num, key=None):
        """
        Get the num smallest elements from an RDD, sorted in ascending order
        or as specified by the optional key function.

        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
        [1, 2, 3, 4, 5, 6]
        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
        [10, 9, 7, 6, 5, 4]
        """

        def topNKeyedElems(iterator, key_=None):
            # Keep the num smallest elements seen so far in a bounded max
            # heap; when a key function is given, heap entries are
            # (key(x), x) pairs so that comparisons follow the key.
            q = MaxHeapQ(num)
            for k in iterator:
                if key_ is not None:
                    k = (key_(k), k)
                q.insert(k)
            yield q.getElements()

        def unKey(x, key_=None):
            # Strip the key off (key, value) pairs again.
            if key_ is not None:
                x = [i[1] for i in x]
            return x

        def merge(a, b):
            # Partition results are already keyed, so no key function here.
            return next(topNKeyedElems(a + b))

        result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge)
        return sorted(unKey(result, key), key=key)

    def take(self, num):
        """
        Take the first num elements of the RDD.
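In outline, takeOrdered keeps a bounded heap of the num smallest elements per partition, then merges the per-partition results and sorts. A minimal standalone sketch of that pattern, with plain Python lists standing in for RDD partitions (the names below are illustrative, not from the patch):

num = 3
partitions = [[10, 1, 2, 9], [3, 4, 5, 6, 7]]

def smallest_n(iterator):
    # Stand-in for the bounded MaxHeapQ: keep only the num smallest seen.
    return sorted(iterator)[:num]

per_partition = [smallest_n(p) for p in partitions]
merged = smallest_n(per_partition[0] + per_partition[1])
assert merged == [1, 2, 3]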
