From 76d63461d2c512f5a6519d25dcaa14cfa8ec6468 Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Wed, 6 May 2015 11:20:01 +0800
Subject: [PATCH 1/4] [SPARK-7321][SQL] Add Column expression for conditional
 statements (if, case)

---
 python/pyspark/sql/dataframe.py                  | 18 +++++++++++++
 .../scala/org/apache/spark/sql/Column.scala      | 27 +++++++++++++++++++
 .../spark/sql/ColumnExpressionSuite.scala        | 18 +++++++++++++
 3 files changed, 63 insertions(+)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 24f370543def4..17eef7070eab2 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1412,6 +1412,24 @@ def between(self, lowerBound, upperBound):
         """
         return (self >= lowerBound) & (self <= upperBound)
 
+    @ignore_unicode_prefix
+    def when(self, whenExpr, thenExpr):
+        """A case-when-otherwise expression.
+        >>> df.select(df.age.when(2, 3).otherwise(4).alias("age")).collect()
+        [Row(age=3), Row(age=4)]
+        >>> df.select(df.age.when(2, 3).alias("age")).collect()
+        [Row(age=3), Row(age=None)]
+        >>> df.select(df.age.otherwise(4).alias("age")).collect()
+        [Row(age=4), Row(age=4)]
+        """
+        jc = self._jc.when(whenExpr, thenExpr)
+        return Column(jc)
+
+    @ignore_unicode_prefix
+    def otherwise(self, elseExpr):
+        jc = self._jc.otherwise(elseExpr)
+        return Column(jc)
+
     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index c0503bf047052..afe0193a56f1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -295,6 +295,33 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def eqNullSafe(other: Any): Column = this <=> other
 
+  /**
+   * A conditional expression, equivalent to SQL's CASE WHEN ... THEN ... ELSE ... END.
+   * {{{
+   * people.select( people("age").when(18, "SELECTED").otherwise("IGNORED") )
+   * }}}
+   *
+   * @group expr_ops
+   */
+  def when(whenExpr: Any, thenExpr: Any): Column = {
+    this.expr match {
+      case CaseWhen(branches: Seq[Expression]) =>
+        val caseExpr = branches.head.asInstanceOf[EqualNullSafe].left
+        CaseWhen(branches ++ Seq((caseExpr <=> whenExpr).expr, lit(thenExpr).expr))
+      case _ =>
+        CaseWhen(Seq((this <=> whenExpr).expr, lit(thenExpr).expr))
+    }
+  }
+
+  def otherwise(elseExpr: Any): Column = {
+    this.expr match {
+      case CaseWhen(branches: Seq[Expression]) =>
+        CaseWhen(branches :+ lit(elseExpr).expr)
+      case _ =>
+        CaseWhen(Seq(lit(true).expr, lit(elseExpr).expr))
+    }
+  }
+
   /**
    * True if the current column is between the lower bound and upper bound, inclusive.
    *
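For reference: the `CaseWhen(branches: Seq[Expression])` constructor used above encodes the whole CASE expression as one flat sequence of alternating (condition, value) pairs, optionally ending with a lone ELSE value. That is why `when` appends two expressions and `otherwise` appends one. A minimal Python model of that encoding (illustrative only, not Spark code; `eval_case_when` is a made-up helper and conditions are modeled as plain callables):

```python
def eval_case_when(branches, row):
    """Evaluate a flat [cond1, val1, cond2, val2, ..., else?] branch list."""
    i = 0
    while i + 1 < len(branches):
        cond, value = branches[i], branches[i + 1]
        if cond(row):
            return value
        i += 2
    # An odd-length list ends with an ELSE value; with no ELSE branch the
    # result is None, matching the Row(age=None) doctest output above.
    return branches[i] if i < len(branches) else None

# CASE WHEN key = 1 THEN -1 WHEN key = 2 THEN -2 ELSE 0 END
branches = [lambda r: r == 1, -1, lambda r: r == 2, -2, 0]
assert [eval_case_when(branches, k) for k in (1, 2, 3)] == [-1, -2, 0]
assert eval_case_when(branches[:4], 3) is None  # no ELSE branch
```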
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 3c1ad656fc855..26997c39224c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -257,6 +257,24 @@ class ColumnExpressionSuite extends QueryTest {
       Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
   }
 
+  test("SPARK-7321 case") {
+    val testData = (1 to 3).map(i => TestData(i, i.toString)).toDF()
+    checkAnswer(
+      testData.select($"key".when(1, -1).when(2, -2).otherwise(0)),
+      Seq(Row(-1), Row(-2), Row(0))
+    )
+
+    checkAnswer(
+      testData.select($"key".when(1, -1).when(2, -2)),
+      Seq(Row(-1), Row(-2), Row(null))
+    )
+
+    checkAnswer(
+      testData.select($"key".otherwise(0)),
+      Seq(Row(0), Row(0), Row(0))
+    )
+  }
+
   test("sqrt") {
     checkAnswer(
       testData.select(sqrt('key)).orderBy('key.asc),

From 801009e798dc3c82f549241e2764e300ad1295da Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Wed, 6 May 2015 22:55:29 +0800
Subject: [PATCH 2/4] Move when() into functions and accept boolean predicate
 conditions

---
 python/pyspark/sql/dataframe.py                  | 14 ++------------
 python/pyspark/sql/functions.py                  | 10 ++++++++++
 .../main/scala/org/apache/spark/sql/Column.scala |  7 +++----
 .../scala/org/apache/spark/sql/functions.scala   | 12 ++++++++++++
 .../apache/spark/sql/ColumnExpressionSuite.scala | 11 +++--------
 5 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 17eef7070eab2..bbfda34a045f3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1414,21 +1414,11 @@ def between(self, lowerBound, upperBound):
 
     @ignore_unicode_prefix
     def when(self, whenExpr, thenExpr):
-        """A case-when-otherwise expression.
-        >>> df.select(df.age.when(2, 3).otherwise(4).alias("age")).collect()
-        [Row(age=3), Row(age=4)]
-        >>> df.select(df.age.when(2, 3).alias("age")).collect()
-        [Row(age=3), Row(age=None)]
-        >>> df.select(df.age.otherwise(4).alias("age")).collect()
-        [Row(age=4), Row(age=4)]
-        """
-        jc = self._jc.when(whenExpr, thenExpr)
-        return Column(jc)
+        return self._jc.when(whenExpr, thenExpr)
 
     @ignore_unicode_prefix
     def otherwise(self, elseExpr):
-        jc = self._jc.otherwise(elseExpr)
-        return Column(jc)
+        return self._jc.otherwise(elseExpr)
 
     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 641220a264295..a2ba9375cf9be 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -146,6 +146,16 @@ def monotonicallyIncreasingId():
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.monotonicallyIncreasingId())
 
+def when(whenExpr, thenExpr):
+    """A case-when-otherwise expression.
+    >>> df.select(when(df.age == 2, 3).otherwise(4).alias("age")).collect()
+    [Row(age=3), Row(age=4)]
+    >>> df.select(when(df.age == 2, 3).alias("age")).collect()
+    [Row(age=3), Row(age=None)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.when(whenExpr, thenExpr)
+    return Column(jc)
 
 def rand(seed=None):
     """Generates a random column with i.i.d. samples from U[0.0, 1.0].
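A sketch of how the new function-style entry point reads once the whole series is applied (a hedged usage example, not part of the patch; the 1.x SparkContext/SQLContext setup and the `graded` alias are illustrative):

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import when

sc = SparkContext("local", "when-demo")
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame([(1,), (2,), (3,)], ["key"])

# Equivalent SQL:
#   SELECT CASE WHEN key = 1 THEN -1 WHEN key = 2 THEN -2 ELSE 0 END FROM df
df.select(
    when(df.key == 1, -1).when(df.key == 2, -2).otherwise(0).alias("graded")
).show()
```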
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index afe0193a56f1e..402e346ae030f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -298,7 +298,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   /**
    * A conditional expression, equivalent to SQL's CASE WHEN ... THEN ... ELSE ... END.
    * {{{
-   * people.select( people("age").when(18, "SELECTED").otherwise("IGNORED") )
+   * people.select( when(people("age") === 18, "SELECTED").otherwise("IGNORED") )
    * }}}
    *
    * @group expr_ops
@@ -306,10 +306,9 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def when(whenExpr: Any, thenExpr: Any): Column = {
     this.expr match {
       case CaseWhen(branches: Seq[Expression]) =>
-        val caseExpr = branches.head.asInstanceOf[EqualNullSafe].left
-        CaseWhen(branches ++ Seq((caseExpr <=> whenExpr).expr, lit(thenExpr).expr))
+        CaseWhen(branches ++ Seq(lit(whenExpr).expr, lit(thenExpr).expr))
       case _ =>
-        CaseWhen(Seq((this <=> whenExpr).expr, lit(thenExpr).expr))
+        CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 7e283393d0563..951a4c09b1e8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -363,6 +363,18 @@ object functions {
    */
   def not(e: Column): Column = !e
 
+  /**
+   * A conditional expression, equivalent to SQL's CASE WHEN ... THEN ... ELSE ... END.
+   * {{{
+   * people.select( when(people("age") === 18, "SELECTED").otherwise("IGNORED") )
+   * }}}
+   *
+   * @group normal_funcs
+   */
+  def when(whenExpr: Any, thenExpr: Any): Column = {
+    CaseWhen(Seq(lit(whenExpr).expr, lit(thenExpr).expr))
+  }
+
   /**
    * Generate a random column with i.i.d. samples from U[0.0, 1.0].
    *
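The switch from `Column.when(value, result)` to a predicate-based `when(condition, result)` is what lets branches test arbitrary boolean expressions, rather than only null-safe equality against the one receiving column. A hedged sketch of what the predicate form enables, assuming a DataFrame `people` with an `age` column:

```python
from pyspark.sql.functions import when

# Any boolean Column can drive a branch, not just (null-safe) equality
# against a single column, which was all the PATCH 1/4 API could express.
people.select(
    when(people.age < 13, "child")
    .when((people.age >= 13) & (people.age < 20), "teen")
    .otherwise("adult")
    .alias("age_group"))
```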
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 26997c39224c8..69a6bc4aebb41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -257,22 +257,17 @@ class ColumnExpressionSuite extends QueryTest {
       Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
   }
 
-  test("SPARK-7321 case") {
+  test("SPARK-7321 case when otherwise") {
     val testData = (1 to 3).map(i => TestData(i, i.toString)).toDF()
     checkAnswer(
-      testData.select($"key".when(1, -1).when(2, -2).otherwise(0)),
+      testData.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)),
       Seq(Row(-1), Row(-2), Row(0))
     )
 
     checkAnswer(
-      testData.select($"key".when(1, -1).when(2, -2)),
+      testData.select(when($"key" === 1, -1).when($"key" === 2, -2)),
       Seq(Row(-1), Row(-2), Row(null))
     )
-
-    checkAnswer(
-      testData.select($"key".otherwise(0)),
-      Seq(Row(0), Row(0), Row(0))
-    )
   }
 
   test("sqrt") {

From 8218d0acc287565a62259691803fd13c84f651ba Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Tue, 12 May 2015 10:28:49 +0800
Subject: [PATCH 3/4] Unwrap the Java Column when calling into the JVM

---
 python/pyspark/sql/dataframe.py | 7 +++++--
 python/pyspark/sql/functions.py | 4 ++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index bbfda34a045f3..b64cff0fbd550 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1414,11 +1414,14 @@ def between(self, lowerBound, upperBound):
 
     @ignore_unicode_prefix
     def when(self, whenExpr, thenExpr):
-        return self._jc.when(whenExpr, thenExpr)
+        if isinstance(whenExpr, Column):
+            jc = self._jc.when(whenExpr._jc, thenExpr)
+        return Column(jc)
 
     @ignore_unicode_prefix
     def otherwise(self, elseExpr):
-        return self._jc.otherwise(elseExpr)
+        jc = self._jc.otherwise(elseExpr)
+        return Column(jc)
 
     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index a2ba9375cf9be..b70005d6ed4a2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
+from pyspark.sql.dataframe import Column, _to_java_column, _to_seq, _create_column_from_literal
 
 
 __all__ = [
@@ -154,7 +154,7 @@ def when(whenExpr, thenExpr):
     [Row(age=3), Row(age=None)]
     """
     sc = SparkContext._active_spark_context
-    jc = sc._jvm.functions.when(whenExpr, thenExpr)
+    jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
     return Column(jc)
 
 def rand(seed=None):
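PATCH 3/4 above is about the Py4J boundary: a PySpark `Column` is a thin wrapper whose `_jc` field holds the JVM-side Column, and only that Java handle can be passed into the JVM. A simplified model of the wrapper pattern (illustrative, not the actual pyspark source):

```python
class Column(object):
    """Simplified model of pyspark's Column: a thin wrapper over a Java Column."""

    def __init__(self, jc):
        self._jc = jc  # Py4J handle to the JVM-side Column

    def when(self, whenExpr, thenExpr):
        # Unwrap to the Java object before crossing into the JVM; Py4J cannot
        # marshal the Python wrapper itself.
        jc = self._jc.when(whenExpr._jc, thenExpr)
        # Re-wrap the returned Java Column so further calls chain in Python.
        return Column(jc)
```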
From 95724c6375e3f0fda4bef4f2d8c6a62811a196cc Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Tue, 12 May 2015 10:38:21 +0800
Subject: [PATCH 4/4] Raise TypeError when the condition passed to when() is
 not a Column

---
 python/pyspark/sql/dataframe.py | 2 ++
 python/pyspark/sql/functions.py | 9 +++++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b64cff0fbd550..272d05a911b6f 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1416,6 +1416,8 @@ def between(self, lowerBound, upperBound):
     def when(self, whenExpr, thenExpr):
         if isinstance(whenExpr, Column):
             jc = self._jc.when(whenExpr._jc, thenExpr)
+        else:
+            raise TypeError("whenExpr should be a Column")
         return Column(jc)
 
     @ignore_unicode_prefix
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b70005d6ed4a2..e8dbbe1ddb30c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column, _to_seq, _create_column_from_literal
+from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
 
 
 __all__ = [
@@ -152,9 +152,14 @@ def when(whenExpr, thenExpr):
     [Row(age=3), Row(age=4)]
     >>> df.select(when(df.age == 2, 3).alias("age")).collect()
     [Row(age=3), Row(age=None)]
+    >>> df.select(when(df.age == 2, 3 == 3).alias("age")).collect()
+    [Row(age=True), Row(age=None)]
     """
     sc = SparkContext._active_spark_context
-    jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
+    if isinstance(whenExpr, Column):
+        jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
+    else:
+        raise TypeError("whenExpr should be a Column")
    return Column(jc)
 
 def rand(seed=None):
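Taken together, the four patches leave `when` accepting only `Column` conditions and failing fast otherwise. A short sketch of the resulting behavior, assuming the `df` with ages 2 and 5 used in the doctests above:

```python
from pyspark.sql.functions import when

df.select(when(df.age == 2, 3).otherwise(4).alias("age")).collect()
# [Row(age=3), Row(age=4)]

try:
    when(2, 3)      # condition is a plain int, not a Column
except TypeError as e:
    print(e)        # whenExpr should be a Column
```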