From f63105c7faddc79ccd624c9234b56916efec3569 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 17 Jan 2018 11:49:08 +0900 Subject: [PATCH 01/11] Deprecate register* for UDFs in SQLContext and Catalog in PySpark --- python/pyspark/sql/catalog.py | 97 +++---------------- python/pyspark/sql/context.py | 167 +++++++++----------------------- python/pyspark/sql/functions.py | 4 +- python/pyspark/sql/group.py | 3 +- python/pyspark/sql/session.py | 150 +++++++++++++++++++++++++++- python/pyspark/sql/tests.py | 16 +++ 6 files changed, 221 insertions(+), 216 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 35fbe9e669adb..b01bc0b8e268b 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -21,6 +21,7 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.dataframe import DataFrame +from pyspark.sql.session import UDFRegistration from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import IntegerType, StringType, StructType @@ -224,92 +225,18 @@ def dropGlobalTempView(self, viewName): """ self._jcatalog.dropGlobalTempView(viewName) - @ignore_unicode_prefix - @since(2.0) def registerFunction(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statements. - - :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`. - - In addition to a name and the function itself, `returnType` can be optionally specified. - 1) When f is a Python function, `returnType` defaults to a string. The produced object must - match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return - type of the given UDF as the return type of the registered UDF. The input parameter - `returnType` is None by default. If given by users, the value must be None. - - :param name: name of the UDF in SQL statements. - :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either - row-at-a-time or vectorized. - :param returnType: the return type of the registered UDF. 
- :return: a wrapped/native :class:`UserDefinedFunction` - - >>> strlen = spark.catalog.registerFunction("stringLengthString", len) - >>> spark.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = spark.udf.register("slen", slen) - >>> spark.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf) - >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - >>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP - [Row(()=26)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP - >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] - """ - - # This is to check whether the input function is a wrapped/native UserDefinedFunction - if hasattr(f, 'asNondeterministic'): - if returnType is not None: - raise TypeError( - "Invalid returnType: None is expected when f is a UserDefinedFunction, " - "but got %s." % returnType) - if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF]: - raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") - register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, - evalType=f.evalType, - deterministic=f.deterministic) - return_udf = f - else: - if returnType is None: - returnType = StringType() - register_udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - return_udf = register_udf._wrapped() - self._jsparkSession.udf().registerPython(name, register_udf._judf) - return return_udf + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self._sparkSession.udf.register(name, f, returnType) + # Reuse the docstring from UDFRegistration but with few notes. + registerFunction.__doc__ = """%s + .. note:: :func:`spark.catalog.registerFunction` is an alias + for :func:`spark.udf.register`. + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. + .. 
versionadded:: 2.0 + """ % UDFRegistration.register.__doc__ @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 85479095af594..c4097b0dac578 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -24,14 +24,14 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix -from pyspark.sql.session import _monkey_patch_RDD, SparkSession +from pyspark.sql.session import _monkey_patch_RDD, SparkSession, UDFRegistration from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader from pyspark.sql.types import IntegerType, Row, StringType from pyspark.sql.utils import install_exception_handler -__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] +__all__ = ["SQLContext", "HiveContext"] class SQLContext(object): @@ -47,6 +47,8 @@ class SQLContext(object): :param sparkSession: The :class:`SparkSession` around which this SQLContext wraps. :param jsqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new SQLContext in the JVM, instead we make all calls to this object. + + .. note:: Deprecated in 2.3.0. Use SparkSession.builder.getOrCreate(). """ _instantiatedContext = None @@ -69,6 +71,10 @@ def __init__(self, sparkContext, sparkSession=None, jsqlContext=None): >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ + warnings.warn( + "Deprecated in 2.0.0. Use SparkSession.builder.getOrCreate() instead.", + DeprecationWarning) + self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm @@ -147,7 +153,8 @@ def udf(self): :return: :class:`UDFRegistration` """ - return UDFRegistration(self) + from pyspark.sql.session import UDFRegistration + return UDFRegistration(self.sparkSession) @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): @@ -172,113 +179,42 @@ def range(self, start, end=None, step=1, numPartitions=None): """ return self.sparkSession.range(start, end, step, numPartitions) - @ignore_unicode_prefix - @since(1.2) def registerFunction(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statements. - - :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`. - - In addition to a name and the function itself, `returnType` can be optionally specified. - 1) When f is a Python function, `returnType` defaults to a string. The produced object must - match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return - type of the given UDF as the return type of the registered UDF. The input parameter - `returnType` is None by default. If given by users, the value must be None. - - :param name: name of the UDF in SQL statements. - :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either - row-at-a-time or vectorized. - :param returnType: the return type of the registered UDF. 
- :return: a wrapped/native :class:`UserDefinedFunction` - - >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = sqlContext.udf.register("slen", slen) - >>> sqlContext.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf) - >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - >>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP - [Row(()=26)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP - >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] - """ - return self.sparkSession.catalog.registerFunction(name, f, returnType) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self.sparkSession.udf.register(name, f, returnType) + # Reuse the docstring from UDFRegistration but with few notes. + registerFunction.__doc__ = """%s + .. note:: :func:`sqlContext.registerFunction` is an alias for + :func:`spark.udf.register`. + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. + .. versionadded:: 1.2 + """ % UDFRegistration.register.__doc__ - @ignore_unicode_prefix - @since(2.1) def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a java UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not specified we would infer it via reflection. - :param name: name of the UDF - :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object - - >>> sqlContext.registerJavaFunction("javaStringLength", - ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) - >>> sqlContext.sql("SELECT javaStringLength('test')").collect() - [Row(UDF:javaStringLength(test)=4)] - >>> sqlContext.registerJavaFunction("javaStringLength2", - ... 
"test.org.apache.spark.sql.JavaStringLength") - >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() - [Row(UDF:javaStringLength2(test)=4)] - - """ - jdt = None - if returnType is not None: - jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) - self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.", + DeprecationWarning) + return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType) + registerJavaFunction.__doc__ = """%s + .. note:: :func:`sqlContext.registerJavaFunction` is an alias for + :func:`spark.udf.registerJavaFunction` + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead. + .. versionadded:: 2.1 + """ % UDFRegistration.registerJavaFunction.__doc__ - @ignore_unicode_prefix - @since(2.3) def registerJavaUDAF(self, name, javaClassName): - """Register a java UDAF so it can be used in SQL statements. - - :param name: name of the UDAF - :param javaClassName: fully qualified name of java class - - >>> sqlContext.registerJavaUDAF("javaUDAF", - ... "test.org.apache.spark.sql.MyDoubleAvg") - >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") - >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() - [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] - """ - self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.registerJavaUDAF instead.", + DeprecationWarning) + return self.sparkSession.udf.registerJavaUDAF(name, javaClassName) + registerJavaUDAF.__doc__ = """%s + .. note:: :func:`sqlContext.registerJavaUDAF` is an alias for + :func:`spark.udf.registerJavaUDAF`. + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaUDAF` instead. + .. versionadded:: 2.3 + """ % UDFRegistration.registerJavaUDAF.__doc__ # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): @@ -590,24 +526,6 @@ def refreshTable(self, tableName): self._ssql_ctx.refreshTable(tableName) -class UDFRegistration(object): - """Wrapper for user-defined function registration.""" - - def __init__(self, sqlContext): - self.sqlContext = sqlContext - - def register(self, name, f, returnType=None): - return self.sqlContext.registerFunction(name, f, returnType) - - def registerJavaFunction(self, name, javaClassName, returnType=None): - self.sqlContext.registerJavaFunction(name, javaClassName, returnType) - - def registerJavaUDAF(self, name, javaClassName): - self.sqlContext.registerJavaUDAF(name, javaClassName) - - register.__doc__ = SQLContext.registerFunction.__doc__ - - def _test(): import os import doctest @@ -624,6 +542,9 @@ def _test(): globs['os'] = os globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) + # 'spark' alias is a small hack for reusing doctests. Please see the reassignment + # of docstrings above. 
+ globs['spark'] = globs['sqlContext'] globs['rdd'] = rdd = sc.parallelize( [Row(field1=1, field2="row1"), Row(field1=2, field2="row2"), diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f7b3f29764040..988c1d25259bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2103,7 +2103,7 @@ def udf(f=None, returnType=StringType()): >>> import random >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. @@ -2231,7 +2231,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): ... return pd.Series(np.random.randn(len(v)) >>> random = random.asNondeterministic() # doctest: +SKIP - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. """ diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 09fae46adf014..22061b83eb78c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -212,7 +212,8 @@ def apply(self, udf): This function does not support partial aggregation, and requires shuffling all the data in the :class:`DataFrame`. - :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + :param udf: a group map user-defined function returned by + :meth:`pyspark.sql.functions.pandas_udf`. 
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 604021c1f45cc..e71552726aea2 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -28,8 +28,7 @@ from itertools import izip as zip, imap as map from pyspark import since -from pyspark.rdd import RDD, ignore_unicode_prefix -from pyspark.sql.catalog import Catalog +from pyspark.rdd import RDD, ignore_unicode_prefix, PythonEvalType from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader @@ -37,9 +36,10 @@ from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \ _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \ _parse_datatype_string +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.utils import install_exception_handler -__all__ = ["SparkSession"] +__all__ = ["SparkSession", "UDFRegistration"] def _monkey_patch_RDD(sparkSession): @@ -280,6 +280,7 @@ def catalog(self): :return: :class:`Catalog` """ + from pyspark.sql.catalog import Catalog if not hasattr(self, "_catalog"): self._catalog = Catalog(self) return self._catalog @@ -291,8 +292,7 @@ def udf(self): :return: :class:`UDFRegistration` """ - from pyspark.sql.context import UDFRegistration - return UDFRegistration(self._wrapped) + return UDFRegistration(self) @since(2.0) def range(self, start, end=None, step=1, numPartitions=None): @@ -778,6 +778,146 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.stop() +class UDFRegistration(object): + """Wrapper for user-defined function registration.""" + + def __init__(self, sparkSession): + self.sparkSession = sparkSession + + @ignore_unicode_prefix + def register(self, name, f, returnType=None): + """Registers a Python function (including lambda function) or a user-defined function + in SQL statements. + + :param name: name of the user-defined function in SQL statements. + :param f: a Python function, or a user-defined function. The user-defined function can + be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and + :meth:`pyspark.sql.functions.pandas_udf`. + :param returnType: the return type of the registered user-defined function. + :return: a user-defined function. + + `returnType` can be optionally specified when `f` is a Python function but not + when `f` is a user-defined function. See below: + + 1. When `f` is a Python function, `returnType` defaults to string type and can be + optionally specified. The produced object must match the specified type. In this case, + this API works as if `register(name, f, returnType=StringType())`. + + >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) + >>> spark.sql("SELECT stringLengthString('test')").collect() + [Row(stringLengthString(test)=u'4')] + + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + 2. 
When `f` is a user-defined function, Spark uses the return type of the given a + user-defined function as the return type of the registered a user-defined function. + `returnType` should not be specified. In this case, this API works as if + `register(name, f)`. + + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> new_random_udf = spark.udf.register("random_udf", random_udf) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=82)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + """ + + # This is to check whether the input function is from a user-defined function or + # Python function. + if hasattr(f, 'asNondeterministic'): + if returnType is not None: + raise TypeError( + "Invalid returnType: data type can not be specified when f is" + "a user-defined function, but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f + else: + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf + + @ignore_unicode_prefix + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a Java user-defined function so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + + :param name: name of the user-defined function + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> from pyspark.sql.types import IntegerType + >>> spark.udf.registerJavaFunction("javaStringLength", + ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> spark.sql("SELECT javaStringLength('test')").collect() + [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction("javaStringLength2", + ... 
"test.org.apache.spark.sql.JavaStringLength") + >>> spark.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF:javaStringLength2(test)=4)] + + """ + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + + @ignore_unicode_prefix + def registerJavaUDAF(self, name, javaClassName): + """Register a Java user-defined aggregate function so it can be used in SQL statements. + + :param name: name of the user-defined aggregate function + :param javaClassName: fully qualified name of java class + + >>> spark.udf.registerJavaUDAF("javaUDAF", + ... "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + + def _test(): import os import doctest diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8906618666b14..d30e1888695c6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -372,6 +372,12 @@ def test_udf(self): [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. + sqlContext = self.spark._wrapped + sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) + [row] = sqlContext.sql("SELECT oneArg('test')").collect() + self.assertEqual(row[0], 4) + def test_udf2(self): self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ @@ -582,11 +588,21 @@ def test_non_existed_udf(self): self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. + sqlContext = spark._wrapped + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) + def test_non_existed_udaf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) + # This is to check if a deprecated 'SQLContext.registerJavaUDAF' can call its alias. 
+ sqlContext = spark._wrapped + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", + lambda: sqlContext.registerJavaUDAF("udaf1", "non_existed_udaf")) + def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", From 08438ee7d8c209a2dcb3eb4efeeef77451feb8d7 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 17 Jan 2018 14:46:38 +0900 Subject: [PATCH 02/11] Focus on the issue itself --- python/pyspark/sql/context.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c4097b0dac578..9f7c7c72e3c2e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -47,8 +47,6 @@ class SQLContext(object): :param sparkSession: The :class:`SparkSession` around which this SQLContext wraps. :param jsqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new SQLContext in the JVM, instead we make all calls to this object. - - .. note:: Deprecated in 2.3.0. Use SparkSession.builder.getOrCreate(). """ _instantiatedContext = None @@ -71,10 +69,6 @@ def __init__(self, sparkContext, sparkSession=None, jsqlContext=None): >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ - warnings.warn( - "Deprecated in 2.0.0. Use SparkSession.builder.getOrCreate() instead.", - DeprecationWarning) - self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm From 6b9b9c44ea7cafa7e1fb607bcf5a2d19336f31f4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 17 Jan 2018 20:58:54 +0900 Subject: [PATCH 03/11] Address comments and clean up --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/catalog.py | 4 +- python/pyspark/sql/context.py | 27 ++--- python/pyspark/sql/session.py | 144 +------------------------ python/pyspark/sql/tests.py | 14 ++- python/pyspark/sql/udf.py | 183 +++++++++++++++++++++++++++++++- 6 files changed, 205 insertions(+), 168 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 7164180a6a7b0..b900f0bd913c3 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -400,6 +400,7 @@ def __hash__(self): "pyspark.sql.functions", "pyspark.sql.readwriter", "pyspark.sql.streaming", + "pyspark.sql.udf", "pyspark.sql.window", "pyspark.sql.tests", ] diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index b01bc0b8e268b..b64eb8acd3b91 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -231,12 +231,14 @@ def registerFunction(self, name, f, returnType=None): DeprecationWarning) return self._sparkSession.udf.register(name, f, returnType) # Reuse the docstring from UDFRegistration but with few notes. + _register_doc = UDFRegistration.register.__doc__.strip() registerFunction.__doc__ = """%s + .. note:: :func:`spark.catalog.registerFunction` is an alias for :func:`spark.udf.register`. .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. .. 
versionadded:: 2.0 - """ % UDFRegistration.register.__doc__ + """ % _register_doc[:_register_doc.rfind('versionadded::')] @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 9f7c7c72e3c2e..aaf28889274d0 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -147,8 +147,7 @@ def udf(self): :return: :class:`UDFRegistration` """ - from pyspark.sql.session import UDFRegistration - return UDFRegistration(self.sparkSession) + return self.sparkSession.udf @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): @@ -179,36 +178,28 @@ def registerFunction(self, name, f, returnType=None): DeprecationWarning) return self.sparkSession.udf.register(name, f, returnType) # Reuse the docstring from UDFRegistration but with few notes. + _register_doc = UDFRegistration.register.__doc__.strip() registerFunction.__doc__ = """%s + .. note:: :func:`sqlContext.registerFunction` is an alias for :func:`spark.udf.register`. .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. .. versionadded:: 1.2 - """ % UDFRegistration.register.__doc__ + """ % _register_doc[:_register_doc.rfind('versionadded::')] def registerJavaFunction(self, name, javaClassName, returnType=None): warnings.warn( "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.", DeprecationWarning) return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType) + _registerJavaFunction_doc = UDFRegistration.registerJavaFunction.__doc__.strip() registerJavaFunction.__doc__ = """%s + .. note:: :func:`sqlContext.registerJavaFunction` is an alias for :func:`spark.udf.registerJavaFunction` .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead. .. versionadded:: 2.1 - """ % UDFRegistration.registerJavaFunction.__doc__ - - def registerJavaUDAF(self, name, javaClassName): - warnings.warn( - "Deprecated in 2.3.0. Use spark.udf.registerJavaUDAF instead.", - DeprecationWarning) - return self.sparkSession.udf.registerJavaUDAF(name, javaClassName) - registerJavaUDAF.__doc__ = """%s - .. note:: :func:`sqlContext.registerJavaUDAF` is an alias for - :func:`spark.udf.registerJavaUDAF`. - .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaUDAF` instead. - .. versionadded:: 2.3 - """ % UDFRegistration.registerJavaUDAF.__doc__ + """ % _registerJavaFunction_doc[:_registerJavaFunction_doc.rfind('versionadded::')] # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): @@ -536,9 +527,9 @@ def _test(): globs['os'] = os globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - # 'spark' alias is a small hack for reusing doctests. Please see the reassignment + # 'spark' is used for reusing doctests. Please see the reassignment # of docstrings above. 
- globs['spark'] = globs['sqlContext'] + globs['spark'] = globs['sqlContext'].sparkSession globs['rdd'] = rdd = sc.parallelize( [Row(field1=1, field2="row1"), Row(field1=2, field2="row2"), diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e71552726aea2..0e126452c2879 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -36,10 +36,10 @@ from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \ _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \ _parse_datatype_string -from pyspark.sql.udf import UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction, UDFRegistration from pyspark.sql.utils import install_exception_handler -__all__ = ["SparkSession", "UDFRegistration"] +__all__ = ["SparkSession"] def _monkey_patch_RDD(sparkSession): @@ -778,146 +778,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.stop() -class UDFRegistration(object): - """Wrapper for user-defined function registration.""" - - def __init__(self, sparkSession): - self.sparkSession = sparkSession - - @ignore_unicode_prefix - def register(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a user-defined function - in SQL statements. - - :param name: name of the user-defined function in SQL statements. - :param f: a Python function, or a user-defined function. The user-defined function can - be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and - :meth:`pyspark.sql.functions.pandas_udf`. - :param returnType: the return type of the registered user-defined function. - :return: a user-defined function. - - `returnType` can be optionally specified when `f` is a Python function but not - when `f` is a user-defined function. See below: - - 1. When `f` is a Python function, `returnType` defaults to string type and can be - optionally specified. The produced object must match the specified type. In this case, - this API works as if `register(name, f, returnType=StringType())`. - - >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) - >>> spark.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - 2. When `f` is a user-defined function, Spark uses the return type of the given a - user-defined function as the return type of the registered a user-defined function. - `returnType` should not be specified. In this case, this API works as if - `register(name, f)`. 
- - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = spark.udf.register("slen", slen) - >>> spark.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = spark.udf.register("random_udf", random_udf) - >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP - >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] - """ - - # This is to check whether the input function is from a user-defined function or - # Python function. - if hasattr(f, 'asNondeterministic'): - if returnType is not None: - raise TypeError( - "Invalid returnType: data type can not be specified when f is" - "a user-defined function, but got %s." % returnType) - if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF]: - raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") - register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, - evalType=f.evalType, - deterministic=f.deterministic) - return_udf = f - else: - if returnType is None: - returnType = StringType() - register_udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - return_udf = register_udf._wrapped() - self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) - return return_udf - - @ignore_unicode_prefix - def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a Java user-defined function so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not specified we would infer it via reflection. - - :param name: name of the user-defined function - :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object - - >>> from pyspark.sql.types import IntegerType - >>> spark.udf.registerJavaFunction("javaStringLength", - ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) - >>> spark.sql("SELECT javaStringLength('test')").collect() - [Row(UDF:javaStringLength(test)=4)] - >>> spark.udf.registerJavaFunction("javaStringLength2", - ... "test.org.apache.spark.sql.JavaStringLength") - >>> spark.sql("SELECT javaStringLength2('test')").collect() - [Row(UDF:javaStringLength2(test)=4)] - - """ - jdt = None - if returnType is not None: - jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) - self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) - - @ignore_unicode_prefix - def registerJavaUDAF(self, name, javaClassName): - """Register a Java user-defined aggregate function so it can be used in SQL statements. - - :param name: name of the user-defined aggregate function - :param javaClassName: fully qualified name of java class - - >>> spark.udf.registerJavaUDAF("javaUDAF", - ... 
"test.org.apache.spark.sql.MyDoubleAvg") - >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") - >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() - [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] - """ - self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) - - def _test(): import os import doctest diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d30e1888695c6..f84aa3d68b808 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -583,6 +583,15 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + # This is to check if a 'SQLContext.udf' can call its alias. + sqlContext = self.spark._wrapped + add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_four(id) AS plus_four").collect(), + df.select(add_four("id").alias("plus_four")).collect() + ) + def test_non_existed_udf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", @@ -598,11 +607,6 @@ def test_non_existed_udaf(self): self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - # This is to check if a deprecated 'SQLContext.registerJavaUDAF' can call its alias. - sqlContext = spark._wrapped - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", - lambda: sqlContext.registerJavaUDAF("udaf1", "non_existed_udaf")) - def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e80ab9165867..ef477001a398d 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -19,11 +19,13 @@ """ import functools -from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark import SparkContext, since +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +__all__ = ["UDFRegistration"] + def _wrap_function(sc, func, returnType): command = (func, returnType) @@ -181,3 +183,180 @@ def asNondeterministic(self): """ self.deterministic = False return self + + +class UDFRegistration(object): + """ + Wrapper for user-defined function registration. + + .. versionadded:: 1.3.1 + """ + + def __init__(self, sparkSession): + self.sparkSession = sparkSession + + @ignore_unicode_prefix + @since(1.3) + def register(self, name, f, returnType=None): + """Registers a Python function (including lambda function) or a user-defined function + in SQL statements. + + :param name: name of the user-defined function in SQL statements. + :param f: a Python function, or a user-defined function. The user-defined function can + be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and + :meth:`pyspark.sql.functions.pandas_udf`. + :param returnType: the return type of the registered user-defined function. + :return: a user-defined function. + + `returnType` can be optionally specified when `f` is a Python function but not + when `f` is a user-defined function. 
See below: + + 1. When `f` is a Python function: + + `returnType` defaults to string type and can be optionally specified. The produced + object must match the specified type. In this case, this API works as if + `register(name, f, returnType=StringType())`. + + >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) + >>> spark.sql("SELECT stringLengthString('test')").collect() + [Row(stringLengthString(test)=u'4')] + + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + + 2. When `f` is a user-defined function: + + Spark uses the return type of the given user-defined function as the return type of + the registered user-defined function. `returnType` should not be specified. + In this case, this API works as if `register(name, f)`. + + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> new_random_udf = spark.udf.register("random_udf", random_udf) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=82)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + + .. note:: Registration for a user-defined function (case 2.) was added from + Spark 2.3.0. + """ + + # This is to check whether the input function is from a user-defined function or + # Python function. + if hasattr(f, 'asNondeterministic'): + if returnType is not None: + raise TypeError( + "Invalid returnType: data type can not be specified when f is" + "a user-defined function, but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f + else: + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf + + @ignore_unicode_prefix + @since(2.3) + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a Java user-defined function so it can be used in SQL statements. 
+ + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + + :param name: name of the user-defined function + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> from pyspark.sql.types import IntegerType + >>> spark.udf.registerJavaFunction("javaStringLength", + ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> spark.sql("SELECT javaStringLength('test')").collect() + [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction("javaStringLength2", + ... "test.org.apache.spark.sql.JavaStringLength") + >>> spark.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF:javaStringLength2(test)=4)] + """ + + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a Java user-defined aggregate function so it can be used in SQL statements. + + :param name: name of the user-defined aggregate function + :param javaClassName: fully qualified name of java class + + >>> spark.udf.registerJavaUDAF("javaUDAF", + ... "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.sql.udf + globs = pyspark.sql.udf.__dict__.copy() + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.udf tests")\ + .getOrCreate() + globs['spark'] = spark + (failure_count, test_count) = doctest.testmod( + pyspark.sql.udf, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From f1fe40a5afe876cf3b81208af7bc1cd379bcb732 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 17 Jan 2018 21:07:38 +0900 Subject: [PATCH 04/11] Clean up imports --- python/pyspark/sql/catalog.py | 3 +-- python/pyspark/sql/context.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index b64eb8acd3b91..2b362760ecca7 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -21,8 +21,7 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.dataframe import DataFrame -from pyspark.sql.session import UDFRegistration -from pyspark.sql.udf import UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction, UDFRegistration from pyspark.sql.types import IntegerType, StringType, StructType diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index aaf28889274d0..af0031951ae7e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -24,11 +24,12 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix -from pyspark.sql.session import _monkey_patch_RDD, SparkSession, UDFRegistration +from pyspark.sql.session import _monkey_patch_RDD, 
SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader from pyspark.sql.types import IntegerType, Row, StringType +from pyspark.sql.udf import UDFRegistration from pyspark.sql.utils import install_exception_handler __all__ = ["SQLContext", "HiveContext"] From c6ed44a7e125ff5e86b9734b753c07e7dc82f5a9 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 17 Jan 2018 21:23:43 +0900 Subject: [PATCH 05/11] Minor doc fix --- python/pyspark/sql/udf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index ef477001a398d..444b77689ef51 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -209,7 +209,7 @@ def register(self, name, f, returnType=None): :return: a user-defined function. `returnType` can be optionally specified when `f` is a Python function but not - when `f` is a user-defined function. See below: + when `f` is a user-defined function. Please see below. 1. When `f` is a Python function: @@ -234,7 +234,6 @@ def register(self, name, f, returnType=None): >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] - 2. When `f` is a user-defined function: Spark uses the return type of the given user-defined function as the return type of From 08ffa1ca2c332205eea370e4d3ce0489eb97424a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 17 Jan 2018 23:05:45 +0900 Subject: [PATCH 06/11] Fix minor nits found --- python/pyspark/sql/catalog.py | 2 +- python/pyspark/sql/context.py | 4 ++-- python/pyspark/sql/udf.py | 16 ++++++++-------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 2b362760ecca7..f91ebd7775926 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -237,7 +237,7 @@ def registerFunction(self, name, f, returnType=None): for :func:`spark.udf.register`. .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. .. versionadded:: 2.0 - """ % _register_doc[:_register_doc.rfind('versionadded::')] + """ % _register_doc[:_register_doc.rfind('.. versionadded::')] @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index af0031951ae7e..07048992f12af 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -186,7 +186,7 @@ def registerFunction(self, name, f, returnType=None): :func:`spark.udf.register`. .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. .. versionadded:: 1.2 - """ % _register_doc[:_register_doc.rfind('versionadded::')] + """ % _register_doc[:_register_doc.rfind('.. versionadded::')] def registerJavaFunction(self, name, javaClassName, returnType=None): warnings.warn( @@ -200,7 +200,7 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): :func:`spark.udf.registerJavaFunction` .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead. .. versionadded:: 2.1 - """ % _registerJavaFunction_doc[:_registerJavaFunction_doc.rfind('versionadded::')] + """ % _registerJavaFunction_doc[:_registerJavaFunction_doc.rfind('.. 
versionadded::')] # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 444b77689ef51..332c821210afb 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -187,7 +187,8 @@ def asNondeterministic(self): class UDFRegistration(object): """ - Wrapper for user-defined function registration. + Wrapper for user-defined function registration. This instance can be accessed by + `spark.udf` or `sqlContext.udf`. .. versionadded:: 1.3.1 """ @@ -300,17 +301,17 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): In addition to a name and the function itself, the return type can be optionally specified. When the return type is not specified we would infer it via reflection. - :param name: name of the user-defined function + :param name: name of the user-defined function :param javaClassName: fully qualified name of java class :param returnType: a :class:`pyspark.sql.types.DataType` object >>> from pyspark.sql.types import IntegerType - >>> spark.udf.registerJavaFunction("javaStringLength", - ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> spark.udf.registerJavaFunction( + ... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) >>> spark.sql("SELECT javaStringLength('test')").collect() [Row(UDF:javaStringLength(test)=4)] - >>> spark.udf.registerJavaFunction("javaStringLength2", - ... "test.org.apache.spark.sql.JavaStringLength") + >>> spark.udf.registerJavaFunction( + ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") >>> spark.sql("SELECT javaStringLength2('test')").collect() [Row(UDF:javaStringLength2(test)=4)] """ @@ -328,8 +329,7 @@ def registerJavaUDAF(self, name, javaClassName): :param name: name of the user-defined aggregate function :param javaClassName: fully qualified name of java class - >>> spark.udf.registerJavaUDAF("javaUDAF", - ... "test.org.apache.spark.sql.MyDoubleAvg") + >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg") >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) >>> df.registerTempTable("df") >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() From 4367beb7f165328d2b7357c27ba1e34ddf112825 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 17 Jan 2018 23:07:30 +0900 Subject: [PATCH 07/11] one more space --- python/pyspark/sql/udf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 332c821210afb..5b6cba9c732b7 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -311,7 +311,7 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): >>> spark.sql("SELECT javaStringLength('test')").collect() [Row(UDF:javaStringLength(test)=4)] >>> spark.udf.registerJavaFunction( - ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") + ... 
"javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") >>> spark.sql("SELECT javaStringLength2('test')").collect() [Row(UDF:javaStringLength2(test)=4)] """ From 3e0147bd11b980d91a2b628b85c5d6a05391b28e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 18 Jan 2018 10:40:42 +0900 Subject: [PATCH 08/11] Use link instead of doc copy --- python/pyspark/sql/catalog.py | 17 +++++++---------- python/pyspark/sql/context.py | 32 ++++++++++++-------------------- python/pyspark/sql/udf.py | 4 ++-- 3 files changed, 21 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index f91ebd7775926..6aef0f22340be 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -21,7 +21,7 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.dataframe import DataFrame -from pyspark.sql.udf import UserDefinedFunction, UDFRegistration +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import IntegerType, StringType, StructType @@ -224,20 +224,17 @@ def dropGlobalTempView(self, viewName): """ self._jcatalog.dropGlobalTempView(viewName) + @since(2.0) def registerFunction(self, name, f, returnType=None): + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. + + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. + """ warnings.warn( "Deprecated in 2.3.0. Use spark.udf.register instead.", DeprecationWarning) return self._sparkSession.udf.register(name, f, returnType) - # Reuse the docstring from UDFRegistration but with few notes. - _register_doc = UDFRegistration.register.__doc__.strip() - registerFunction.__doc__ = """%s - - .. note:: :func:`spark.catalog.registerFunction` is an alias - for :func:`spark.udf.register`. - .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. - .. versionadded:: 2.0 - """ % _register_doc[:_register_doc.rfind('.. versionadded::')] @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 07048992f12af..cc1cd1a5842d9 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -173,34 +173,29 @@ def range(self, start, end=None, step=1, numPartitions=None): """ return self.sparkSession.range(start, end, step, numPartitions) + @since(1.2) def registerFunction(self, name, f, returnType=None): + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. + + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. + """ warnings.warn( "Deprecated in 2.3.0. Use spark.udf.register instead.", DeprecationWarning) return self.sparkSession.udf.register(name, f, returnType) - # Reuse the docstring from UDFRegistration but with few notes. - _register_doc = UDFRegistration.register.__doc__.strip() - registerFunction.__doc__ = """%s - - .. note:: :func:`sqlContext.registerFunction` is an alias for - :func:`spark.udf.register`. - .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. - .. versionadded:: 1.2 - """ % _register_doc[:_register_doc.rfind('.. versionadded::')] + @since(2.1) def registerJavaFunction(self, name, javaClassName, returnType=None): + """An alias for :func:`spark.udf.registerJavaFunction`. + See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`. + + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead. + """ warnings.warn( "Deprecated in 2.3.0. 
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 07048992f12af..cc1cd1a5842d9 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -173,34 +173,29 @@ def range(self, start, end=None, step=1, numPartitions=None):
         """
         return self.sparkSession.range(start, end, step, numPartitions)
 
+    @since(1.2)
     def registerFunction(self, name, f, returnType=None):
+        """An alias for :func:`spark.udf.register`.
+        See :meth:`pyspark.sql.UDFRegistration.register`.
+
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
+        """
         warnings.warn(
             "Deprecated in 2.3.0. Use spark.udf.register instead.",
             DeprecationWarning)
         return self.sparkSession.udf.register(name, f, returnType)
-    # Reuse the docstring from UDFRegistration but with few notes.
-    _register_doc = UDFRegistration.register.__doc__.strip()
-    registerFunction.__doc__ = """%s
-
-    .. note:: :func:`sqlContext.registerFunction` is an alias for
-        :func:`spark.udf.register`.
-    .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
-    .. versionadded:: 1.2
-    """ % _register_doc[:_register_doc.rfind('.. versionadded::')]
 
+    @since(2.1)
     def registerJavaFunction(self, name, javaClassName, returnType=None):
+        """An alias for :func:`spark.udf.registerJavaFunction`.
+        See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`.
+
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead.
+        """
         warnings.warn(
             "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.",
             DeprecationWarning)
         return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType)
-    _registerJavaFunction_doc = UDFRegistration.registerJavaFunction.__doc__.strip()
-    registerJavaFunction.__doc__ = """%s
-
-    .. note:: :func:`sqlContext.registerJavaFunction` is an alias for
-        :func:`spark.udf.registerJavaFunction`
-    .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead.
-    .. versionadded:: 2.1
-    """ % _registerJavaFunction_doc[:_registerJavaFunction_doc.rfind('.. versionadded::')]
 
     # TODO(andrew): delete this once we refactor things to take in SparkSession
     def _inferSchema(self, rdd, samplingRatio=None):
@@ -528,9 +523,6 @@ def _test():
     globs['os'] = os
     globs['sc'] = sc
     globs['sqlContext'] = SQLContext(sc)
-    # 'spark' is used for reusing doctests. Please see the reassignment
-    # of docstrings above.
-    globs['spark'] = globs['sqlContext'].sparkSession
     globs['rdd'] = rdd = sc.parallelize(
         [Row(field1=1, field2="row1"),
          Row(field1=2, field2="row2"),
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 5b6cba9c732b7..1943bb73f9ac2 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -188,7 +188,7 @@ def asNondeterministic(self):
 class UDFRegistration(object):
     """
     Wrapper for user-defined function registration. This instance can be accessed by
-    `spark.udf` or `sqlContext.udf`.
+    :attr:`spark.udf` or :attr:`sqlContext.udf`.
 
     .. versionadded:: 1.3.1
     """
@@ -197,7 +197,7 @@ def __init__(self, sparkSession):
         self.sparkSession = sparkSession
 
     @ignore_unicode_prefix
-    @since(1.3)
+    @since("1.3.1")
     def register(self, name, f, returnType=None):
         """Registers a Python function (including lambda function) or a user-defined function
         in SQL statements.
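One detail in the `udf.py` hunk above: `register` is now tagged `@since("1.3.1")` with a string because the method first appeared in 1.3.1, and a float literal could only ever render as `1.3`. Roughly, a `since`-style decorator just appends a `.. versionadded::` line to the docstring; the snippet below is a simplified sketch of that idea, not PySpark's actual implementation (which also normalizes docstring indentation):

    def since(version):
        """Simplified sketch: append a ``.. versionadded::`` note to the docstring."""
        def deco(f):
            f.__doc__ = (f.__doc__ or "").rstrip() + "\n\n.. versionadded:: %s" % version
            return f
        return deco

    @since("1.3.1")   # the string keeps the full patch version...
    def register(name, f, returnType=None):
        """Registers a Python function as a SQL function."""

    @since(1.3)       # ...while a float can only carry "1.3"
    def old_register(name, f, returnType=None):
        """Old annotation, for comparison."""

    print(register.__doc__.splitlines()[-1])      # .. versionadded:: 1.3.1
    print(old_register.__doc__.splitlines()[-1])  # .. versionadded:: 1.3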
From c9512a66800709417425c0d348c9327ed681420d Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 18 Jan 2018 10:45:29 +0900
Subject: [PATCH 09/11] Clean up imports

---
 python/pyspark/sql/session.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 0e126452c2879..4c573534908cc 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -36,7 +36,7 @@
 from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \
     _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
     _parse_datatype_string
-from pyspark.sql.udf import UserDefinedFunction, UDFRegistration
+from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.utils import install_exception_handler
 
 __all__ = ["SparkSession"]
@@ -292,6 +292,7 @@ def udf(self):
 
         :return: :class:`UDFRegistration`
         """
+        from pyspark.sql.udf import UDFRegistration
         return UDFRegistration(self)
 
     @since(2.0)

From 00f5d1973dae4b7135b4de727e8a2c447ffb65cf Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 18 Jan 2018 10:47:46 +0900
Subject: [PATCH 10/11] Clean up imports

---
 python/pyspark/sql/session.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 4c573534908cc..22ffa6fa33513 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -36,7 +36,6 @@
 from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \
     _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
     _parse_datatype_string
-from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.utils import install_exception_handler
 
 __all__ = ["SparkSession"]

From e121273972d0ec0d94cc01e4426358b4e5fb7e2c Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 18 Jan 2018 10:48:36 +0900
Subject: [PATCH 11/11] Clean up imports

---
 python/pyspark/sql/session.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 22ffa6fa33513..6c84023c43fb6 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -28,7 +28,7 @@
     from itertools import izip as zip, imap as map
 
 from pyspark import since
-from pyspark.rdd import RDD, ignore_unicode_prefix, PythonEvalType
+from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.sql.conf import RuntimeConfig
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader