diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 72180f6d05fbc..605b9e44e1d93 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1461,6 +1461,19 @@ def between(self, lowerBound, upperBound):
         """
         return (self >= lowerBound) & (self <= upperBound)
 
+    @ignore_unicode_prefix
+    def when(self, whenExpr, thenExpr):
+        if isinstance(whenExpr, Column):
+            jc = self._jc.when(whenExpr._jc, thenExpr)
+        else:
+            raise TypeError("whenExpr should be Column")
+        return Column(jc)
+
+    @ignore_unicode_prefix
+    def otherwise(self, elseExpr):
+        jc = self._jc.otherwise(elseExpr)
+        return Column(jc)
+
     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')
 
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 38a043a3c59d7..b603143062387 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -237,6 +237,21 @@ def monotonicallyIncreasingId():
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.monotonicallyIncreasingId())
 
+def when(whenExpr, thenExpr):
+    """A case when otherwise expression.
+    >>> df.select(when(df.age == 2, 3).otherwise(4).alias("age")).collect()
+    [Row(age=3), Row(age=4)]
+    >>> df.select(when(df.age == 2, 3).alias("age")).collect()
+    [Row(age=3), Row(age=None)]
+    >>> df.select(when(df.age == 2, 3 == 3).alias("age")).collect()
+    [Row(age=True), Row(age=None)]
+    """
+    sc = SparkContext._active_spark_context
+    if isinstance(whenExpr, Column):
+        jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
+    else:
+        raise TypeError("whenExpr should be Column")
+    return Column(jc)
 
 def rand(seed=None):
     """Generates a random column with i.i.d. samples from U[0.0, 1.0].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e6e475bb82f82..8fbd78b70b4a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -308,6 +308,32 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def eqNullSafe(other: Any): Column = this <=> other
 
+  /**
+   * Case When Otherwise expression, like SQL's CASE WHEN ... THEN ... ELSE ... END.
+   * {{{
+   *   people.select( when(people("age") === 18, "SELECTED").otherwise("IGNORED") )
+   * }}}
+   *
+   * @group expr_ops
+   */
+  def when(whenExpr: Any, thenExpr: Any): Column = {
+    this.expr match {
+      case CaseWhen(branches: Seq[Expression]) =>
+        CaseWhen(branches ++ Seq(lit(whenExpr).expr, lit(thenExpr).expr))
+      case _ =>
+        CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
+    }
+  }
+
+  def otherwise(elseExpr: Any): Column = {
+    this.expr match {
+      case CaseWhen(branches: Seq[Expression]) =>
+        CaseWhen(branches :+ lit(elseExpr).expr)
+      case _ =>
+        CaseWhen(Seq(lit(true).expr, lit(elseExpr).expr))
+    }
+  }
+
   /**
    * True if the current column is between the lower bound and upper bound, inclusive.
    *
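
For reference (not part of the patch): a minimal Scala usage sketch of the chained `when`/`otherwise` API added above, combined with the `when` function added to functions.scala below; the `people` DataFrame with `name` and `age` columns and the second branch are assumptions for illustration.

    // Sketch only: each .when(cond, value) appends a (condition, value) pair to the
    // underlying CaseWhen branches; .otherwise(value) appends the trailing else value.
    import org.apache.spark.sql.functions.when

    val status = when(people("age") === 18, "SELECTED")
      .when(people("age") === 21, "REVIEW")    // hypothetical second branch
      .otherwise("IGNORED")
    // equivalent SQL: CASE WHEN age = 18 THEN 'SELECTED'
    //                      WHEN age = 21 THEN 'REVIEW'
    //                      ELSE 'IGNORED' END

    people.select(people("name"), status.as("status"))
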
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index fae4bd0fd2994..5cccf62d755b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -381,6 +381,18 @@ object functions {
    */
   def not(e: Column): Column = !e
 
+  /**
+   * Case When Otherwise expression, like SQL's CASE WHEN ... THEN ... ELSE ... END.
+   * {{{
+   *   people.select( when(people("age") === 18, "SELECTED").otherwise("IGNORED") )
+   * }}}
+   *
+   * @group normal_funcs
+   */
+  def when(whenExpr: Any, thenExpr: Any): Column = {
+    CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
+  }
+
   /**
    * Generate a random column with i.i.d. samples from U[0.0, 1.0].
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index d96186c268720..8d79f46396247 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -255,6 +255,19 @@ class ColumnExpressionSuite extends QueryTest {
       Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
   }
 
+  test("SPARK-7321 case when otherwise") {
+    val testData = (1 to 3).map(i => TestData(i, i.toString)).toDF()
+    checkAnswer(
+      testData.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)),
+      Seq(Row(-1), Row(-2), Row(0))
+    )
+
+    checkAnswer(
+      testData.select(when($"key" === 1, -1).when($"key" === 2, -2)),
+      Seq(Row(-1), Row(-2), Row(null))
+    )
+  }
+
   test("sqrt") {
     checkAnswer(
       testData.select(sqrt('key)).orderBy('key.asc),
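
The behaviour asserted by the test above can also be checked interactively. A minimal, self-contained sketch for a Spark 1.4-era spark-shell (where `sc` is predefined) follows; rows matching no condition evaluate to null when no `otherwise` is given.

    // Sketch only: reproduces the two checkAnswer cases from ColumnExpressionSuite.
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions.when

    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = Seq(1, 2, 3).map(Tuple1(_)).toDF("key")

    // With otherwise(): every row is covered -> -1, -2, 0
    df.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)).show()

    // Without otherwise(): key = 3 matches no condition -> -1, -2, null
    df.select(when($"key" === 1, -1).when($"key" === 2, -2)).show()
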