From 6b9b9c44ea7cafa7e1fb607bcf5a2d19336f31f4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 17 Jan 2018 20:58:54 +0900 Subject: [PATCH] Address comments and clean up --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/catalog.py | 4 +- python/pyspark/sql/context.py | 27 ++--- python/pyspark/sql/session.py | 144 +------------------------ python/pyspark/sql/tests.py | 14 ++- python/pyspark/sql/udf.py | 183 +++++++++++++++++++++++++++++++- 6 files changed, 205 insertions(+), 168 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 7164180a6a7b0..b900f0bd913c3 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -400,6 +400,7 @@ def __hash__(self): "pyspark.sql.functions", "pyspark.sql.readwriter", "pyspark.sql.streaming", + "pyspark.sql.udf", "pyspark.sql.window", "pyspark.sql.tests", ] diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index b01bc0b8e268b..b64eb8acd3b91 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -231,12 +231,14 @@ def registerFunction(self, name, f, returnType=None): DeprecationWarning) return self._sparkSession.udf.register(name, f, returnType) # Reuse the docstring from UDFRegistration but with few notes. + _register_doc = UDFRegistration.register.__doc__.strip() registerFunction.__doc__ = """%s + .. note:: :func:`spark.catalog.registerFunction` is an alias for :func:`spark.udf.register`. .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. .. versionadded:: 2.0 - """ % UDFRegistration.register.__doc__ + """ % _register_doc[:_register_doc.rfind('versionadded::')] @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 9f7c7c72e3c2e..aaf28889274d0 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -147,8 +147,7 @@ def udf(self): :return: :class:`UDFRegistration` """ - from pyspark.sql.session import UDFRegistration - return UDFRegistration(self.sparkSession) + return self.sparkSession.udf @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): @@ -179,36 +178,28 @@ def registerFunction(self, name, f, returnType=None): DeprecationWarning) return self.sparkSession.udf.register(name, f, returnType) # Reuse the docstring from UDFRegistration but with few notes. + _register_doc = UDFRegistration.register.__doc__.strip() registerFunction.__doc__ = """%s + .. note:: :func:`sqlContext.registerFunction` is an alias for :func:`spark.udf.register`. .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. .. versionadded:: 1.2 - """ % UDFRegistration.register.__doc__ + """ % _register_doc[:_register_doc.rfind('versionadded::')] def registerJavaFunction(self, name, javaClassName, returnType=None): warnings.warn( "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.", DeprecationWarning) return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType) + _registerJavaFunction_doc = UDFRegistration.registerJavaFunction.__doc__.strip() registerJavaFunction.__doc__ = """%s + .. note:: :func:`sqlContext.registerJavaFunction` is an alias for :func:`spark.udf.registerJavaFunction` .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead. .. versionadded:: 2.1 - """ % UDFRegistration.registerJavaFunction.__doc__ - - def registerJavaUDAF(self, name, javaClassName): - warnings.warn( - "Deprecated in 2.3.0. 
Use spark.udf.registerJavaUDAF instead.", - DeprecationWarning) - return self.sparkSession.udf.registerJavaUDAF(name, javaClassName) - registerJavaUDAF.__doc__ = """%s - .. note:: :func:`sqlContext.registerJavaUDAF` is an alias for - :func:`spark.udf.registerJavaUDAF`. - .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaUDAF` instead. - .. versionadded:: 2.3 - """ % UDFRegistration.registerJavaUDAF.__doc__ + """ % _registerJavaFunction_doc[:_registerJavaFunction_doc.rfind('versionadded::')] # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): @@ -536,9 +527,9 @@ def _test(): globs['os'] = os globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - # 'spark' alias is a small hack for reusing doctests. Please see the reassignment + # 'spark' is used for reusing doctests. Please see the reassignment # of docstrings above. - globs['spark'] = globs['sqlContext'] + globs['spark'] = globs['sqlContext'].sparkSession globs['rdd'] = rdd = sc.parallelize( [Row(field1=1, field2="row1"), Row(field1=2, field2="row2"), diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e71552726aea2..0e126452c2879 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -36,10 +36,10 @@ from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \ _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \ _parse_datatype_string -from pyspark.sql.udf import UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction, UDFRegistration from pyspark.sql.utils import install_exception_handler -__all__ = ["SparkSession", "UDFRegistration"] +__all__ = ["SparkSession"] def _monkey_patch_RDD(sparkSession): @@ -778,146 +778,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.stop() -class UDFRegistration(object): - """Wrapper for user-defined function registration.""" - - def __init__(self, sparkSession): - self.sparkSession = sparkSession - - @ignore_unicode_prefix - def register(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a user-defined function - in SQL statements. - - :param name: name of the user-defined function in SQL statements. - :param f: a Python function, or a user-defined function. The user-defined function can - be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and - :meth:`pyspark.sql.functions.pandas_udf`. - :param returnType: the return type of the registered user-defined function. - :return: a user-defined function. - - `returnType` can be optionally specified when `f` is a Python function but not - when `f` is a user-defined function. See below: - - 1. When `f` is a Python function, `returnType` defaults to string type and can be - optionally specified. The produced object must match the specified type. In this case, - this API works as if `register(name, f, returnType=StringType())`. 
- - >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) - >>> spark.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - 2. When `f` is a user-defined function, Spark uses the return type of the given a - user-defined function as the return type of the registered a user-defined function. - `returnType` should not be specified. In this case, this API works as if - `register(name, f)`. - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = spark.udf.register("slen", slen) - >>> spark.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = spark.udf.register("random_udf", random_udf) - >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP - >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] - """ - - # This is to check whether the input function is from a user-defined function or - # Python function. - if hasattr(f, 'asNondeterministic'): - if returnType is not None: - raise TypeError( - "Invalid returnType: data type can not be specified when f is" - "a user-defined function, but got %s." % returnType) - if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF]: - raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") - register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, - evalType=f.evalType, - deterministic=f.deterministic) - return_udf = f - else: - if returnType is None: - returnType = StringType() - register_udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - return_udf = register_udf._wrapped() - self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) - return return_udf - - @ignore_unicode_prefix - def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a Java user-defined function so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not specified we would infer it via reflection. 
- - :param name: name of the user-defined function - :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object - - >>> from pyspark.sql.types import IntegerType - >>> spark.udf.registerJavaFunction("javaStringLength", - ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) - >>> spark.sql("SELECT javaStringLength('test')").collect() - [Row(UDF:javaStringLength(test)=4)] - >>> spark.udf.registerJavaFunction("javaStringLength2", - ... "test.org.apache.spark.sql.JavaStringLength") - >>> spark.sql("SELECT javaStringLength2('test')").collect() - [Row(UDF:javaStringLength2(test)=4)] - - """ - jdt = None - if returnType is not None: - jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) - self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) - - @ignore_unicode_prefix - def registerJavaUDAF(self, name, javaClassName): - """Register a Java user-defined aggregate function so it can be used in SQL statements. - - :param name: name of the user-defined aggregate function - :param javaClassName: fully qualified name of java class - - >>> spark.udf.registerJavaUDAF("javaUDAF", - ... "test.org.apache.spark.sql.MyDoubleAvg") - >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") - >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() - [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] - """ - self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) - - def _test(): import os import doctest diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d30e1888695c6..f84aa3d68b808 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -583,6 +583,15 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + # This is to check if a 'SQLContext.udf' can call its alias. + sqlContext = self.spark._wrapped + add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_four(id) AS plus_four").collect(), + df.select(add_four("id").alias("plus_four")).collect() + ) + def test_non_existed_udf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", @@ -598,11 +607,6 @@ def test_non_existed_udaf(self): self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - # This is to check if a deprecated 'SQLContext.registerJavaUDAF' can call its alias. 
- sqlContext = spark._wrapped - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", - lambda: sqlContext.registerJavaUDAF("udaf1", "non_existed_udaf")) - def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e80ab9165867..ef477001a398d 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -19,11 +19,13 @@ """ import functools -from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark import SparkContext, since +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +__all__ = ["UDFRegistration"] + def _wrap_function(sc, func, returnType): command = (func, returnType) @@ -181,3 +183,180 @@ def asNondeterministic(self): """ self.deterministic = False return self + + +class UDFRegistration(object): + """ + Wrapper for user-defined function registration. + + .. versionadded:: 1.3.1 + """ + + def __init__(self, sparkSession): + self.sparkSession = sparkSession + + @ignore_unicode_prefix + @since(1.3) + def register(self, name, f, returnType=None): + """Registers a Python function (including lambda function) or a user-defined function + in SQL statements. + + :param name: name of the user-defined function in SQL statements. + :param f: a Python function, or a user-defined function. The user-defined function can + be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and + :meth:`pyspark.sql.functions.pandas_udf`. + :param returnType: the return type of the registered user-defined function. + :return: a user-defined function. + + `returnType` can be optionally specified when `f` is a Python function but not + when `f` is a user-defined function. See below: + + 1. When `f` is a Python function: + + `returnType` defaults to string type and can be optionally specified. The produced + object must match the specified type. In this case, this API works as if + `register(name, f, returnType=StringType())`. + + >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) + >>> spark.sql("SELECT stringLengthString('test')").collect() + [Row(stringLengthString(test)=u'4')] + + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + + 2. When `f` is a user-defined function: + + Spark uses the return type of the given user-defined function as the return type of + the registered user-defined function. `returnType` should not be specified. + In this case, this API works as if `register(name, f)`. 
+ + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> new_random_udf = spark.udf.register("random_udf", random_udf) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=82)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + + .. note:: Registration for a user-defined function (case 2.) was added from + Spark 2.3.0. + """ + + # This is to check whether the input function is from a user-defined function or + # Python function. + if hasattr(f, 'asNondeterministic'): + if returnType is not None: + raise TypeError( + "Invalid returnType: data type can not be specified when f is" + "a user-defined function, but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f + else: + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf + + @ignore_unicode_prefix + @since(2.3) + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a Java user-defined function so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + + :param name: name of the user-defined function + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> from pyspark.sql.types import IntegerType + >>> spark.udf.registerJavaFunction("javaStringLength", + ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> spark.sql("SELECT javaStringLength('test')").collect() + [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction("javaStringLength2", + ... "test.org.apache.spark.sql.JavaStringLength") + >>> spark.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF:javaStringLength2(test)=4)] + """ + + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a Java user-defined aggregate function so it can be used in SQL statements. 
+ + :param name: name of the user-defined aggregate function + :param javaClassName: fully qualified name of java class + + >>> spark.udf.registerJavaUDAF("javaUDAF", + ... "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.sql.udf + globs = pyspark.sql.udf.__dict__.copy() + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.udf tests")\ + .getOrCreate() + globs['spark'] = spark + (failure_count, test_count) = doctest.testmod( + pyspark.sql.udf, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test()
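A minimal end-to-end sketch of the registration behaviour described in the new `register` docstring above, assuming only a local SparkSession; it restates what the doctests in the patch already show, and the app name and column values are illustrative, not part of the change.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.master("local[2]").appName("udf-register-sketch").getOrCreate()

    # Case 1: a plain Python function. returnType defaults to StringType(), so the
    # SQL result below is the string u'4', not the integer 4.
    strlen = spark.udf.register("stringLengthString", lambda x: len(x))
    spark.sql("SELECT stringLengthString('test')").collect()

    # The returned value is itself a UDF, so it can also be used from the DataFrame API.
    spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()

    # Case 2: an already-built user-defined function. Its return type (and its
    # deterministic flag) carry over, so returnType must not be passed again.
    slen = udf(lambda s: len(s), IntegerType())
    spark.udf.register("slen", slen)
    spark.sql("SELECT slen('test')").collect()   # [Row(slen(test)=4)]

    spark.stop()

The same calls work through `sqlContext.registerFunction` and `spark.catalog.registerFunction`, which the patch keeps only as deprecated aliases of `spark.udf.register`.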
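For the Java registration paths, a sketch under the same assumptions as the doctests: the class names below come from Spark's test sources, so they are only resolvable when those test jars are on the driver classpath.

    from pyspark.sql import SparkSession
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.getOrCreate()

    # Java UDF with an explicit return type.
    spark.udf.registerJavaFunction("javaStringLength",
                                   "test.org.apache.spark.sql.JavaStringLength", IntegerType())
    spark.sql("SELECT javaStringLength('test')").collect()

    # Java UDF without a return type: the type is inferred via reflection.
    spark.udf.registerJavaFunction("javaStringLength2",
                                   "test.org.apache.spark.sql.JavaStringLength")
    spark.sql("SELECT javaStringLength2('test')").collect()

    # Java UDAF: only a SQL name and the fully qualified class name are needed.
    spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
    spark.sql("SELECT javaUDAF(id) AS avg FROM range(10)").collect()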