def top(self, num):
    """
    Get the top N elements from an RDD.

    Each partition is first reduced to its own ``num`` largest elements,
    and the per-partition results are then merged pairwise, so no more
    than ``2 * num`` candidates are ever compared in a single merge step.

    Note: It returns the list sorted in ascending order.

    :param num: maximum number of elements to return; a value of zero
                or less produces an empty list.
    :return: a list holding at most ``num`` of the largest elements,
             smallest first.

    >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
    [12]
    >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2)
    [5, 6]
    """
    # Imported locally so this method does not depend on which heapq
    # names the surrounding module happens to import.
    from heapq import nlargest

    def topIterator(iterator):
        # Emit a single list per partition: that partition's candidates.
        yield nlargest(num, iterator)

    def merge(a, b):
        # Both inputs are plain lists, so concatenation is safe; keep only
        # the overall best ``num`` out of the two candidate lists.
        return nlargest(num, a + b)

    # nlargest yields descending order; callers are promised ascending.
    return sorted(self.mapPartitions(topIterator).reduce(merge))