[SPARK-7321][SQL] Add Column expression for conditional statements (when/otherwise) #6072

Closed · wants to merge 8 commits
Changes from 7 commits
2 changes: 2 additions & 0 deletions python/pyspark/sql/__init__.py
@@ -32,6 +32,8 @@
Aggregation methods, returned by :func:`DataFrame.groupBy`.
- L{DataFrameNaFunctions}
Methods for handling missing data (null values).
- L{DataFrameStatFunctions}
Methods for statistics functionality.
- L{functions}
List of built-in functions available for :class:`DataFrame`.
- L{types}
31 changes: 31 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -1461,6 +1461,37 @@ def between(self, lowerBound, upperBound):
"""
return (self >= lowerBound) & (self <= upperBound)

@ignore_unicode_prefix
def when(self, condition, value):
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

See :func:`pyspark.sql.functions.when` for example usage.

:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.

"""
if not isinstance(condition, Column):
    raise TypeError("condition should be a Column")
v = value._jc if isinstance(value, Column) else value
jc = self._jc.when(condition._jc, v)
return Column(jc)

@ignore_unicode_prefix
def otherwise(self, value):
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

See :func:`pyspark.sql.functions.when` for example usage.

:param value: a literal value, or a :class:`Column` expression.
"""
v = value._jc if isinstance(value, Column) else value
jc = self._jc.otherwise(v)
return Column(jc)

def __repr__(self):
return 'Column<%s>' % self._jc.toString().encode('utf8')

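For readers following the Python side of this change, a minimal end-to-end sketch of how `functions.when()` and the new `Column.when()` / `Column.otherwise()` methods chain together is shown below. It is illustrative only and not part of the patch; the SparkContext/SQLContext setup, data, and column names are assumptions.

```python
# Illustrative sketch only (not part of this patch); assumes a Spark 1.x-era
# SparkContext/SQLContext and made-up data.
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import when

sc = SparkContext(appName="when-otherwise-sketch")
sqlContext = SQLContext(sc)

people = sqlContext.createDataFrame(
    [("Alice", "female"), ("Bob", "male"), ("Chris", "unknown")],
    ["name", "gender"])

# functions.when() starts the CASE expression; the new Column.when() and
# Column.otherwise() methods add further branches and the default value.
encoded = people.select(
    people.name,
    when(people.gender == "male", 0)
    .when(people.gender == "female", 1)
    .otherwise(2)
    .alias("gender_code"))

encoded.show()
```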
26 changes: 24 additions & 2 deletions python/pyspark/sql/functions.py
@@ -32,13 +32,14 @@

__all__ = [
'approxCountDistinct',
'coalesce',
'countDistinct',
'monotonicallyIncreasingId',
'rand',
'randn',
'sparkPartitionId',
'udf',
'when']


def _create_function(name, doc=""):
@@ -291,6 +292,27 @@ def struct(*cols):
return Column(jc)


def when(condition, value):
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.

>>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
[Row(age=3), Row(age=4)]

>>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
[Row(age=3), Row(age=None)]
"""
sc = SparkContext._active_spark_context
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")
v = value._jc if isinstance(value, Column) else value
jc = sc._jvm.functions.when(condition._jc, v)
return Column(jc)


class UserDefinedFunction(object):
"""
User defined function in Python
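As a quick cross-check of what the doctest above builds, the `when()/otherwise()` chain is meant to produce the same CASE expression one could write in SQL through `DataFrame.selectExpr`. A hedged sketch, assuming the doctest `df` with an integer `age` column:

```python
# Hedged sketch (not part of this patch): the two selections below are
# intended to be equivalent; `df` is assumed to have an integer `age`
# column, as in the doctests above.
from pyspark.sql.functions import when

df.select(when(df.age == 2, 3).otherwise(4).alias("age"))

# The same conditional written directly as a SQL expression:
df.selectExpr("CASE WHEN age = 2 THEN 3 ELSE 4 END AS age")
```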
56 changes: 56 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -308,6 +308,62 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*/
def eqNullSafe(other: Any): Column = this <=> other

/**
* Evaluates a list of conditions and returns one of multiple possible result expressions.
* If otherwise is not defined at the end, null is returned for unmatched conditions.
*
* {{{
* // Example: encoding gender string column into integer.
*
* // Scala:
* people.select(when(people("gender") === "male", 0)
* .when(people("gender") === "female", 1)
* .otherwise(2))
*
* // Java:
* people.select(when(col("gender").equalTo("male"), 0)
* .when(col("gender").equalTo("female"), 1)
* .otherwise(2))
* }}}
*
* @group expr_ops
*/
def when(condition: Column, value: Any): Column = this.expr match {
case CaseWhen(branches: Seq[Expression]) =>
CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr))
case _ =>
throw new IllegalArgumentException(
"when() can only be applied on a Column previously generated by when() function")
}

/**
* Evaluates a list of conditions and returns one of multiple possible result expressions.
* If otherwise is not defined at the end, null is returned for unmatched conditions.
*
* {{{
* // Example: encoding gender string column into integer.
*
* // Scala:
* people.select(when(people("gender") === "male", 0)
* .when(people("gender") === "female", 1)
* .otherwise(2))
*
* // Java:
* people.select(when(col("gender").equalTo("male"), 0)
* .when(col("gender").equalTo("female"), 1)
* .otherwise(2))
* }}}
*
* @group expr_ops
*/
def otherwise(value: Any): Column = this.expr match {
case CaseWhen(branches: Seq[Expression]) =>
CaseWhen(branches :+ lit(value).expr)
Contributor (review comment):
What if a user mistakenly calls otherwise twice? Then we would build a wrong CaseWhen expression here. Maybe we should create a helper class like what we did for the NA functions and aggregate functions?

Contributor (author's reply):
Good point. It is too heavyweight to have a whole class dedicated to this, but I will add code to throw an exception if otherwise has already been applied.

case _ =>
throw new IllegalArgumentException(
"otherwise() can only be applied on a Column previously generated by when() function")
}

/**
* True if the current column is between the lower bound and upper bound, inclusive.
*
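The pattern matches above mean `when()`/`otherwise()` on a `Column` only extend a CASE expression started by `functions.when()`; on any other column they throw `IllegalArgumentException`. A hedged sketch of how that guard would surface in PySpark (the `Py4JJavaError` wrapping and the `df` DataFrame are assumptions, not part of the patch):

```python
# Hedged sketch (not part of this patch): calling otherwise() on a plain
# column should be rejected on the JVM side; py4j is expected to surface the
# IllegalArgumentException as a Py4JJavaError. A DataFrame `df` with an
# `age` column is assumed.
from py4j.protocol import Py4JJavaError

try:
    df.age.otherwise(0)  # df.age was not produced by when(), so this fails
except Py4JJavaError as e:
    print("otherwise() must follow when():", e.java_exception.getMessage())
```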
24 changes: 24 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -381,6 +381,30 @@ object functions {
*/
def not(e: Column): Column = !e

/**
* Evaluates a list of conditions and returns one of multiple possible result expressions.
* If otherwise is not defined at the end, null is returned for unmatched conditions.
*
* {{{
* // Example: encoding gender string column into integer.
*
* // Scala:
* people.select(when(people("gender") === "male", 0)
* .when(people("gender") === "female", 1)
* .otherwise(2))
*
* // Java:
* people.select(when(col("gender").equalTo("male"), 0)
* .when(col("gender").equalTo("female"), 1)
* .otherwise(2))
* }}}
*
* @group normal_funcs
*/
def when(condition: Column, value: Any): Column = {
CaseWhen(Seq(condition.expr, lit(value).expr))
}

/**
* Generate a random column with i.i.d. samples from U[0.0, 1.0].
*
@@ -255,6 +255,26 @@ class ColumnExpressionSuite extends QueryTest {
Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
}

test("SPARK-7321 when conditional statements") {
val testData = (1 to 3).map(i => (i, i.toString)).toDF("key", "value")

checkAnswer(
testData.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)),
Seq(Row(-1), Row(-2), Row(0))
)

// Without the trailing otherwise, null is returned for unmatched conditions.
// Also test putting a non-literal value in the expression.
checkAnswer(
testData.select(when($"key" === 1, lit(0) - $"key").when($"key" === 2, -2)),
Seq(Row(-1), Row(-2), Row(null))
)

intercept[IllegalArgumentException] {
$"key".when($"key" === 1, -1)
}
}

test("sqrt") {
checkAnswer(
testData.select(sqrt('key)).orderBy('key.asc),