[SPARK-7152][SQL] Add a Column expression for partition ID.
Author: Reynold Xin <[email protected]>

Closes apache#5705 from rxin/df-pid and squashes the following commits:

401018f [Reynold Xin] [SPARK-7152][SQL] Add a Column expression for partition ID.
rxin committed Apr 26, 2015
1 parent 9a5bbe0 commit ca55dc9
Showing 5 changed files with 110 additions and 19 deletions.
30 changes: 21 additions & 9 deletions python/pyspark/sql/functions.py
@@ -75,6 +75,20 @@ def _(col):
 __all__.sort()
 
 
+def approxCountDistinct(col, rsd=None):
+    """Returns a new :class:`Column` for approximate distinct count of ``col``.
+
+    >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
+    [Row(c=2)]
+    """
+    sc = SparkContext._active_spark_context
+    if rsd is None:
+        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
+    else:
+        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
+    return Column(jc)
+
+
 def countDistinct(col, *cols):
     """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``.
@@ -89,18 +103,16 @@ def countDistinct(col, *cols):
     return Column(jc)
 
 
-def approxCountDistinct(col, rsd=None):
-    """Returns a new :class:`Column` for approximate distinct count of ``col``.
+def sparkPartitionId():
+    """Returns a column for partition ID of the Spark task.
 
-    >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
-    [Row(c=2)]
+    Note that this is non-deterministic because it depends on data partitioning and task scheduling.
+
+    >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect()
+    [Row(pid=0), Row(pid=0)]
     """
     sc = SparkContext._active_spark_context
-    if rsd is None:
-        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
-    else:
-        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
-    return Column(jc)
+    return Column(sc._jvm.functions.sparkPartitionId())
 
 
 class UserDefinedFunction(object):
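The PySpark wrapper just forwards to the JVM-side function through py4j. For reference, a rough Scala equivalent of the doctest above — a sketch only, assuming a 1.4-era API and an existing two-row DataFrame `df`:

```scala
import org.apache.spark.sql.functions._

// Collapsing to a single partition makes the output deterministic:
// every row then reports partition ID 0, matching the doctest.
val pids = df.repartition(1).select(sparkPartitionId().as("pid")).collect()
// pids: Array([0], [0]) for a two-row df
```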
39 changes: 39 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types.{IntegerType, DataType}


/**
* Expression that returns the current partition id of the Spark task.
*/
case object SparkPartitionID extends Expression with trees.LeafNode[Expression] {
  self: Product =>

  override type EvaluatedType = Int

  override def nullable: Boolean = false

  override def dataType: DataType = IntegerType

  override def eval(input: Row): Int = TaskContext.get().partitionId()
}
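The pattern generalizes: any per-task value reachable from `TaskContext` can be surfaced as a nullary leaf expression this way. A hypothetical sibling (illustration only, not part of this commit) that exposes the stage ID instead:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types.{IntegerType, DataType}

// Hypothetical: surfaces TaskContext.stageId() the same way
// SparkPartitionID surfaces partitionId().
case object SparkStageID extends Expression with trees.LeafNode[Expression] {
  self: Product =>

  override type EvaluatedType = Int
  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType

  // Only valid on an executor inside a running task, like SparkPartitionID.
  override def eval(input: Row): Int = TaskContext.get().stageId()
}
```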
23 changes: 23 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

/**
* Package containing expressions that are specific to Spark runtime.
*/
package object expressions
29 changes: 19 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -276,6 +276,13 @@ object functions {
   // Non-aggregate functions
   //////////////////////////////////////////////////////////////////////////////////////////////
 
+  /**
+   * Computes the absolute value.
+   *
+   * @group normal_funcs
+   */
+  def abs(e: Column): Column = Abs(e.expr)
+
   /**
    * Returns the first column that is not null.
    * {{{
@@ -287,6 +294,13 @@ object functions {
   @scala.annotation.varargs
   def coalesce(e: Column*): Column = Coalesce(e.map(_.expr))
 
+  /**
+   * Converts a string expression to lower case.
+   *
+   * @group normal_funcs
+   */
+  def lower(e: Column): Column = Lower(e.expr)
+
   /**
    * Unary minus, i.e. negate the expression.
    * {{{
@@ -317,18 +331,13 @@ object functions {
   def not(e: Column): Column = !e
 
   /**
-   * Converts a string expression to upper case.
+   * Partition ID of the Spark task.
    *
-   * @group normal_funcs
-   */
-  def upper(e: Column): Column = Upper(e.expr)
-
-  /**
-   * Converts a string exprsesion to lower case.
+   * Note that this is non-deterministic because it depends on data partitioning and task scheduling.
    *
    * @group normal_funcs
    */
-  def lower(e: Column): Column = Lower(e.expr)
+  def sparkPartitionId(): Column = execution.expressions.SparkPartitionID
 
   /**
    * Computes the square root of the specified float value.
@@ -338,11 +347,11 @@ object functions {
   def sqrt(e: Column): Column = Sqrt(e.expr)
 
   /**
-   * Computes the absolutle value.
+   * Converts a string expression to upper case.
    *
    * @group normal_funcs
    */
-  def abs(e: Column): Column = Abs(e.expr)
+  def upper(e: Column): Column = Upper(e.expr)
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////
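A quick way to see the non-determinism note in practice — a spark-shell sketch, assuming a 1.4-era `SQLContext` named `sqlContext`; the concrete IDs vary from run to run, which is the point:

```scala
import org.apache.spark.sql.functions._
import sqlContext.implicits._

val df = sqlContext.sparkContext.parallelize(1 to 4).toDF("x")

// One partition: every row reports pid 0.
df.repartition(1).select(sparkPartitionId().as("pid")).distinct.show()

// Four partitions: up to four distinct pids, and which row lands in
// which partition depends on partitioning and task scheduling.
df.repartition(4).select(sparkPartitionId().as("pid")).distinct.show()
```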
8 changes: 8 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -310,6 +310,14 @@ class ColumnExpressionSuite extends QueryTest {
     )
   }
 
+  test("sparkPartitionId") {
+    val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b")
+    checkAnswer(
+      df.select(sparkPartitionId()),
+      Row(0)
+    )
+  }
+
   test("lift alias out of cast") {
     compareExpressions(
       col("1234").as("name").cast("int").expr,
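The committed test pins the single-partition case. A hypothetical follow-on assertion (not in this commit) for the multi-partition case, using the same suite's fixtures:

```scala
test("sparkPartitionId covers multiple partitions") {
  // Hypothetical: with an explicit two-partition input, every reported
  // pid must be 0 or 1, even though the per-row assignment is not fixed.
  val df = TestSQLContext.sparkContext.parallelize(1 to 10, 2).map(i => (i, i)).toDF("a", "b")
  val pids = df.select(sparkPartitionId()).collect().map(_.getInt(0)).toSet
  assert(pids.subsetOf(Set(0, 1)))
}
```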
