From 801009e798dc3c82f549241e2764e300ad1295da Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Wed, 6 May 2015 22:55:29 +0800
Subject: [PATCH] Update

---
 python/pyspark/sql/dataframe.py                    | 14 ++------------
 python/pyspark/sql/functions.py                    | 10 ++++++++++
 .../main/scala/org/apache/spark/sql/Column.scala   |  7 +++----
 .../scala/org/apache/spark/sql/functions.scala     | 12 ++++++++++++
 .../apache/spark/sql/ColumnExpressionSuite.scala   | 11 +++--------
 5 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 17eef7070eab2..bbfda34a045f3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1414,21 +1414,11 @@ def between(self, lowerBound, upperBound):
 
     @ignore_unicode_prefix
     def when(self, whenExpr, thenExpr):
-        """ A case when otherwise expression..
-        >>> df.select(df.age.when(2, 3).otherwise(4).alias("age")).collect()
-        [Row(age=3), Row(age=4)]
-        >>> df.select(df.age.when(2, 3).alias("age")).collect()
-        [Row(age=3), Row(age=None)]
-        >>> df.select(df.age.otherwise(4).alias("age")).collect()
-        [Row(age=4), Row(age=4)]
-        """
-        jc = self._jc.when(whenExpr, thenExpr)
-        return Column(jc)
+        return Column(self._jc.when(whenExpr, thenExpr))
 
     @ignore_unicode_prefix
     def otherwise(self, elseExpr):
-        jc = self._jc.otherwise(elseExpr)
-        return Column(jc)
+        return Column(self._jc.otherwise(elseExpr))
 
     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 641220a264295..a2ba9375cf9be 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -146,6 +146,16 @@ def monotonicallyIncreasingId():
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.monotonicallyIncreasingId())
 
+def when(whenExpr, thenExpr):
+    """A case-when-otherwise expression.
+
+    >>> df.select(when(df.age == 2, 3).otherwise(4).alias("age")).collect()
+    [Row(age=3), Row(age=4)]
+    >>> df.select(when(df.age == 2, 3).alias("age")).collect()
+    [Row(age=3), Row(age=None)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.when(whenExpr, thenExpr))
 
 def rand(seed=None):
     """Generates a random column with i.i.d. samples from U[0.0, 1.0].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index afe0193a56f1e..402e346ae030f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -298,7 +298,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   /**
    * Case When Otherwise.
    * {{{
-   *   people.select( people("age").when(18, "SELECTED").other("IGNORED") )
+   *   people.select( when(people("age") === 18, "SELECTED").otherwise("IGNORED") )
    * }}}
    *
    * @group expr_ops
@@ -306,10 +306,9 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def when(whenExpr: Any, thenExpr: Any):Column = {
     this.expr match {
       case CaseWhen(branches: Seq[Expression]) =>
-        val caseExpr = branches.head.asInstanceOf[EqualNullSafe].left
-        CaseWhen(branches ++ Seq((caseExpr <=> whenExpr).expr, lit(thenExpr).expr))
+        CaseWhen(branches ++ Seq(lit(whenExpr).expr, lit(thenExpr).expr))
       case _ =>
-        CaseWhen(Seq((this <=> whenExpr).expr, lit(thenExpr).expr))
+        CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 7e283393d0563..951a4c09b1e8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -363,6 +363,18 @@ object functions {
    */
   def not(e: Column): Column = !e
 
+  /**
+   * Case When Otherwise.
+   * {{{
+   *   people.select( when(people("age") === 18, "SELECTED").otherwise("IGNORED") )
+   * }}}
+   *
+   * @group normal_funcs
+   */
+  def when(whenExpr: Any, thenExpr: Any): Column = {
+    CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
+  }
+
   /**
    * Generate a random column with i.i.d. samples from U[0.0, 1.0].
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 26997c39224c8..69a6bc4aebb41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -257,22 +257,17 @@ class ColumnExpressionSuite extends QueryTest {
       Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
   }
 
-  test("SPARK-7321 case") {
+  test("SPARK-7321 case when otherwise") {
     val testData = (1 to 3).map(i => TestData(i, i.toString)).toDF()
     checkAnswer(
-      testData.select($"key".when(1, -1).when(2, -2).otherwise(0)),
+      testData.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)),
      Seq(Row(-1), Row(-2), Row(0))
     )
 
     checkAnswer(
-      testData.select($"key".when(1, -1).when(2, -2)),
+      testData.select(when($"key" === 1, -1).when($"key" === 2, -2)),
       Seq(Row(-1), Row(-2), Row(null))
     )
-
-    checkAnswer(
-      testData.select($"key".otherwise(0)),
-      Seq(Row(0), Row(0), Row(0))
-    )
   }
 
   test("sqrt") {
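
For reference, a minimal usage sketch of the standalone when()/otherwise() API this patch introduces. This is not part of the patch: it assumes a Spark 1.4-era spark-shell session (so sc already exists), and the Person case class and its rows are invented for illustration.

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions._

    // Hypothetical schema and data, just to exercise the new API.
    case class Person(name: String, age: Int)

    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val people = Seq(Person("alice", 18), Person("bob", 30)).toDF()

    // Equivalent to SQL: CASE WHEN age = 18 THEN 'SELECTED' ELSE 'IGNORED' END
    people.select(when($"age" === 18, "SELECTED").otherwise("IGNORED")).show()

    // Chained when() calls fold into a single CaseWhen; omitting otherwise()
    // leaves the ELSE branch NULL, matching the updated test expectations.
    people.select(when($"age" === 18, -1).when($"age" === 30, -2)).show()

The design shift the patch makes: the condition is now an arbitrary Column predicate passed to a standalone when() in functions, rather than the old Column.when(value, result) form, which could only null-safe-compare the receiving column against a literal.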