Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
kaka1992 committed May 6, 2015
1 parent 76d6346 commit 801009e
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 24 deletions.
14 changes: 2 additions & 12 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,21 +1414,11 @@ def between(self, lowerBound, upperBound):

@ignore_unicode_prefix
def when(self, whenExpr, thenExpr):
""" A case when otherwise expression..
>>> df.select(df.age.when(2, 3).otherwise(4).alias("age")).collect()
[Row(age=3), Row(age=4)]
>>> df.select(df.age.when(2, 3).alias("age")).collect()
[Row(age=3), Row(age=None)]
>>> df.select(df.age.otherwise(4).alias("age")).collect()
[Row(age=4), Row(age=4)]
"""
jc = self._jc.when(whenExpr, thenExpr)
return Column(jc)
return self._jc.when(whenExpr, thenExpr)

@ignore_unicode_prefix
def otherwise(self, elseExpr):
jc = self._jc.otherwise(elseExpr)
return Column(jc)
return self._jc.otherwise(elseExpr)

def __repr__(self):
return 'Column<%s>' % self._jc.toString().encode('utf8')
Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ def monotonicallyIncreasingId():
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.monotonicallyIncreasingId())

def when(whenExpr, thenExpr):
""" A case when otherwise expression.
>>> df.select(when(df.age == 2, 3).otherwise(4).alias("age")).collect()
[Row(age=3), Row(age=4)]
>>> df.select(when(df.age == 2, 3).alias("age")).collect()
[Row(age=3), Row(age=None)]
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.when(whenExpr, thenExpr)
return Column(jc)

def rand(seed=None):
"""Generates a random column with i.i.d. samples from U[0.0, 1.0].
Expand Down
7 changes: 3 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -298,18 +298,17 @@ class Column(protected[sql] val expr: Expression) extends Logging {
/**
* Case When Otherwise.
* {{{
* people.select( people("age").when(18, "SELECTED").other("IGNORED") )
* people.select( when(people("age") === 18, "SELECTED").other("IGNORED") )
* }}}
*
* @group expr_ops
*/
def when(whenExpr: Any, thenExpr: Any):Column = {
this.expr match {
case CaseWhen(branches: Seq[Expression]) =>
val caseExpr = branches.head.asInstanceOf[EqualNullSafe].left
CaseWhen(branches ++ Seq((caseExpr <=> whenExpr).expr, lit(thenExpr).expr))
CaseWhen(branches ++ Seq(lit(whenExpr).expr, lit(thenExpr).expr))
case _ =>
CaseWhen(Seq((this <=> whenExpr).expr, lit(thenExpr).expr))
CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
}
}

Expand Down
12 changes: 12 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,18 @@ object functions {
*/
def not(e: Column): Column = !e

/**
* Case When Otherwise.
* {{{
* people.select( when(people("age") === 18, "SELECTED").other("IGNORED") )
* }}}
*
* @group normal_funcs
*/
def when(whenExpr: Any, thenExpr: Any): Column = {
CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
}

/**
* Generate a random column with i.i.d. samples from U[0.0, 1.0].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,22 +257,17 @@ class ColumnExpressionSuite extends QueryTest {
Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
}

test("SPARK-7321 case") {
test("SPARK-7321 case when otherwise") {
val testData = (1 to 3).map(i => TestData(i, i.toString)).toDF()
checkAnswer(
testData.select($"key".when(1, -1).when(2, -2).otherwise(0)),
testData.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)),
Seq(Row(-1), Row(-2), Row(0))
)

checkAnswer(
testData.select($"key".when(1, -1).when(2, -2)),
testData.select(when($"key" === 1, -1).when($"key" === 2, -2)),
Seq(Row(-1), Row(-2), Row(null))
)

checkAnswer(
testData.select($"key".otherwise(0)),
Seq(Row(0), Row(0), Row(0))
)
}

test("sqrt") {
Expand Down

0 comments on commit 801009e

Please sign in to comment.