From 95724c6375e3f0fda4bef4f2d8c6a62811a196cc Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Tue, 12 May 2015 10:38:21 +0800
Subject: [PATCH] Update

---
 python/pyspark/sql/dataframe.py | 2 ++
 python/pyspark/sql/functions.py | 9 +++++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b64cff0fbd550..272d05a911b6f 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1416,6 +1416,8 @@ def between(self, lowerBound, upperBound):
     def when(self, whenExpr, thenExpr):
         if isinstance(whenExpr, Column):
             jc = self._jc.when(whenExpr._jc, thenExpr)
+        else:
+            raise TypeError("whenExpr should be Column")
         return Column(jc)
 
     @ignore_unicode_prefix
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b70005d6ed4a2..e8dbbe1ddb30c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column, _to_seq, _create_column_from_literal
+from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
 
 
 __all__ = [
@@ -152,9 +152,14 @@ def when(whenExpr, thenExpr):
     [Row(age=3), Row(age=4)]
     >>> df.select(when(df.age == 2, 3).alias("age")).collect()
     [Row(age=3), Row(age=None)]
+    >>> df.select(when(df.age == 2, 3==3).alias("age")).collect()
+    [Row(age=True), Row(age=None)]
     """
     sc = SparkContext._active_spark_context
-    jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
+    if isinstance(whenExpr, Column):
+        jc = sc._jvm.functions.when(whenExpr._jc, thenExpr)
+    else:
+        raise TypeError("whenExpr should be Column")
     return Column(jc)
 
 def rand(seed=None):
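
A minimal usage sketch of the behavior the patch enforces, assuming an active `sqlContext` and a two-row DataFrame like the one the module doctests rely on: a Column condition is accepted as before, while a non-Column whenExpr is now rejected with a TypeError up front.

    >>> from pyspark.sql import Row
    >>> from pyspark.sql.functions import when
    >>> df = sqlContext.createDataFrame([Row(age=2), Row(age=5)])
    >>> df.select(when(df.age == 2, 3).alias("age")).collect()
    [Row(age=3), Row(age=None)]
    >>> when("age = 2", 3)  # plain string, not a Column: the new type check fires
    Traceback (most recent call last):
        ...
    TypeError: whenExpr should be Column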