[SPARK-22624][PYSPARK] Expose range partitioning shuffle introduced by SPARK-22614

## What changes were proposed in this pull request?

Expose the range partitioning shuffle introduced by SPARK-22614 to the PySpark DataFrame API.

## How was this patch tested?

Doctests in dataframe.py and a unit test in tests.py.

Please review http://spark.apache.org/contributing.html before opening a pull request.
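
For context, a minimal usage sketch of the new API (not part of this diff; assumes a local `SparkSession`):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

    # Unlike hash-based repartition(), rows are assigned to partitions by
    # sampled ranges of the sort key, so nearby ages land together.
    parts = df.repartitionByRange(2, "age")
    print(parts.rdd.getNumPartitions())  # 2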

Author: xubo245 <[email protected]>

Closes apache#20456 from xubo245/SPARK22624_PysparkRangePartition.
xubo245 authored and Robert Kruszewski committed Feb 12, 2018
1 parent 3683374 commit 29c970a
Showing 2 changed files with 73 additions and 0 deletions.
45 changes: 45 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -667,6 +667,51 @@ def repartition(self, numPartitions, *cols):
        else:
            raise TypeError("numPartitions should be an int or Column")

@since("2.4.0")
def repartitionByRange(self, numPartitions, *cols):
"""
Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
resulting DataFrame is range partitioned.
``numPartitions`` can be an int to specify the target number of partitions or a Column.
If it is a Column, it will be used as the first partitioning column. If not specified,
the default number of partitions is used.
At least one partition-by expression must be specified.
When no explicit sort order is specified, "ascending nulls first" is assumed.
>>> df.repartitionByRange(2, "age").rdd.getNumPartitions()
2
>>> df.show()
+---+-----+
|age| name|
+---+-----+
| 2|Alice|
| 5| Bob|
+---+-----+
>>> df.repartitionByRange(1, "age").rdd.getNumPartitions()
1
>>> data = df.repartitionByRange("age")
>>> df.show()
+---+-----+
|age| name|
+---+-----+
| 2|Alice|
| 5| Bob|
+---+-----+
"""
        if isinstance(numPartitions, int):
            if len(cols) == 0:
                raise ValueError("At least one partition-by expression must be specified.")
            else:
                return DataFrame(
                    self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx)
        elif isinstance(numPartitions, (basestring, Column)):
            # No explicit partition count: treat the first argument as a partitioning column.
            cols = (numPartitions,) + cols
            return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx)
        else:
            raise TypeError("numPartitions should be an int, string or Column")

    @since(1.3)
    def distinct(self):
        """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
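
The dispatch in `repartitionByRange` above accepts three call shapes. A short sketch of each, plus a way to see where rows land (not part of this diff; reuses `spark` and `df` from the sketch above):

    from pyspark.sql.functions import col, spark_partition_id

    df.repartitionByRange(2, "age")      # explicit partition count
    df.repartitionByRange("age")         # default partition count, string column
    df.repartitionByRange(col("age"))    # a Column as the first range column

    # Tag each row with the id of the partition it was assigned to; low ages
    # map to low partition ids because the implied sort order is ascending.
    df.repartitionByRange(2, "age").withColumn("pid", spark_partition_id()).show()

    try:
        df.repartitionByRange(2)         # an int with no partition-by columns
    except ValueError as e:
        print(e)
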
28 changes: 28 additions & 0 deletions python/pyspark/sql/tests.py
@@ -2148,6 +2148,34 @@ def test_expr(self):
        result = df.select(functions.expr("length(a)")).collect()[0].asDict()
        self.assertEqual(13, result["length(a)"])

    def test_repartitionByRange_dataframe(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        df1 = self.spark.createDataFrame(
            [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema)
        df2 = self.spark.createDataFrame(
            [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema)

        # test repartitionByRange(numPartitions, *cols)
        df3 = df1.repartitionByRange(2, "name", "age")
        self.assertEqual(df3.rdd.getNumPartitions(), 2)
        self.assertEqual(df3.rdd.first(), df2.rdd.first())
        self.assertEqual(df3.rdd.take(3), df2.rdd.take(3))

        # test repartitionByRange(numPartitions, *cols) with a different partition count
        df4 = df1.repartitionByRange(3, "name", "age")
        self.assertEqual(df4.rdd.getNumPartitions(), 3)
        self.assertEqual(df4.rdd.first(), df2.rdd.first())
        self.assertEqual(df4.rdd.take(3), df2.rdd.take(3))

        # test repartitionByRange(*cols)
        df5 = df1.repartitionByRange("name", "age")
        self.assertEqual(df5.rdd.first(), df2.rdd.first())
        self.assertEqual(df5.rdd.take(3), df2.rdd.take(3))

    def test_replace(self):
        schema = StructType([
            StructField("name", StringType(), True),
