
Commit 6b9b9c4
Address comments and clean up
HyukjinKwon committed Jan 17, 2018
1 parent 08438ee commit 6b9b9c4
Showing 6 changed files with 205 additions and 168 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -400,6 +400,7 @@ def __hash__(self):
"pyspark.sql.functions",
"pyspark.sql.readwriter",
"pyspark.sql.streaming",
"pyspark.sql.udf",
"pyspark.sql.window",
"pyspark.sql.tests",
]
4 changes: 3 additions & 1 deletion python/pyspark/sql/catalog.py
@@ -231,12 +231,14 @@ def registerFunction(self, name, f, returnType=None):
DeprecationWarning)
return self._sparkSession.udf.register(name, f, returnType)
# Reuse the docstring from UDFRegistration but with a few notes.
_register_doc = UDFRegistration.register.__doc__.strip()
registerFunction.__doc__ = """%s
.. note:: :func:`spark.catalog.registerFunction` is an alias
for :func:`spark.udf.register`.
.. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
.. versionadded:: 2.0
""" % UDFRegistration.register.__doc__
""" % _register_doc[:_register_doc.rfind('versionadded::')]

@since(2.0)
def isCached(self, tableName):
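For reference, the docstring-reuse trick applied in catalog.py above (and in context.py below) can be shown in isolation: strip the parent docstring, cut it just before its last `versionadded::` directive, and append the alias note plus the alias's own version. A minimal, runnable sketch of that slicing pattern; the names Base and register_alias are made up for illustration, not from the commit:

# Illustrative names (Base, register_alias); not part of the commit.
class Base(object):
    def register(self, name, f):
        """Registers `f` under `name`.

        :param name: name of the function.
        :param f: the function itself.

        .. versionadded:: 2.0
        """


def register_alias(name, f):
    return Base().register(name, f)


# Strip the parent docstring, cut it just before the last 'versionadded::'
# marker, then append the alias note and the alias's own version directive,
# the same slicing done above via rfind('versionadded::').
_register_doc = Base.register.__doc__.strip()
register_alias.__doc__ = """%s
    .. note:: :func:`register_alias` is an alias for :func:`Base.register`.

    .. versionadded:: 2.0
    """ % _register_doc[:_register_doc.rfind('versionadded::')]

if __name__ == "__main__":
    print(register_alias.__doc__)
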
27 changes: 9 additions & 18 deletions python/pyspark/sql/context.py
Expand Up @@ -147,8 +147,7 @@ def udf(self):
:return: :class:`UDFRegistration`
"""
from pyspark.sql.session import UDFRegistration
return UDFRegistration(self.sparkSession)
return self.sparkSession.udf

@since(1.4)
def range(self, start, end=None, step=1, numPartitions=None):
@@ -179,36 +178,28 @@ def registerFunction(self, name, f, returnType=None):
DeprecationWarning)
return self.sparkSession.udf.register(name, f, returnType)
# Reuse the docstring from UDFRegistration but with a few notes.
_register_doc = UDFRegistration.register.__doc__.strip()
registerFunction.__doc__ = """%s
.. note:: :func:`sqlContext.registerFunction` is an alias for
:func:`spark.udf.register`.
.. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
.. versionadded:: 1.2
""" % UDFRegistration.register.__doc__
""" % _register_doc[:_register_doc.rfind('versionadded::')]

def registerJavaFunction(self, name, javaClassName, returnType=None):
warnings.warn(
"Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.",
DeprecationWarning)
return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType)
_registerJavaFunction_doc = UDFRegistration.registerJavaFunction.__doc__.strip()
registerJavaFunction.__doc__ = """%s
.. note:: :func:`sqlContext.registerJavaFunction` is an alias for
:func:`spark.udf.registerJavaFunction`
.. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead.
.. versionadded:: 2.1
""" % UDFRegistration.registerJavaFunction.__doc__

def registerJavaUDAF(self, name, javaClassName):
warnings.warn(
"Deprecated in 2.3.0. Use spark.udf.registerJavaUDAF instead.",
DeprecationWarning)
return self.sparkSession.udf.registerJavaUDAF(name, javaClassName)
registerJavaUDAF.__doc__ = """%s
.. note:: :func:`sqlContext.registerJavaUDAF` is an alias for
:func:`spark.udf.registerJavaUDAF`.
.. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaUDAF` instead.
.. versionadded:: 2.3
""" % UDFRegistration.registerJavaUDAF.__doc__
""" % _registerJavaFunction_doc[:_registerJavaFunction_doc.rfind('versionadded::')]

# TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
@@ -536,9 +527,9 @@ def _test():
globs['os'] = os
globs['sc'] = sc
globs['sqlContext'] = SQLContext(sc)
# 'spark' alias is a small hack for reusing doctests. Please see the reassignment
# 'spark' is used for reusing doctests. Please see the reassignment
# of docstrings above.
globs['spark'] = globs['sqlContext']
globs['spark'] = globs['sqlContext'].sparkSession
globs['rdd'] = rdd = sc.parallelize(
[Row(field1=1, field2="row1"),
Row(field1=2, field2="row2"),
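The context.py changes above keep SQLContext.registerFunction as a thin deprecated alias that warns and forwards to spark.udf.register. A minimal sketch of that path from the user's side; it assumes a local SparkSession, and the UDF name plus_one is illustrative rather than part of the commit:

import warnings

from pyspark.sql import SparkSession, SQLContext
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").getOrCreate()
sqlContext = SQLContext(spark.sparkContext)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The deprecated entry point still registers the UDF, but warns.
    sqlContext.registerFunction("plus_one", lambda x: x + 1, IntegerType())
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# The UDF lands in the shared registry, so plain spark.sql sees it.
spark.sql("SELECT plus_one(41)").show()

spark.stop()
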
144 changes: 2 additions & 142 deletions python/pyspark/sql/session.py
@@ -36,10 +36,10 @@
from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \
_make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
_parse_datatype_string
from pyspark.sql.udf import UserDefinedFunction
from pyspark.sql.udf import UserDefinedFunction, UDFRegistration
from pyspark.sql.utils import install_exception_handler

__all__ = ["SparkSession", "UDFRegistration"]
__all__ = ["SparkSession"]


def _monkey_patch_RDD(sparkSession):
@@ -778,146 +778,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()


class UDFRegistration(object):
"""Wrapper for user-defined function registration."""

def __init__(self, sparkSession):
self.sparkSession = sparkSession

@ignore_unicode_prefix
def register(self, name, f, returnType=None):
"""Registers a Python function (including a lambda function) or a user-defined function
in SQL statements.
:param name: name of the user-defined function in SQL statements.
:param f: a Python function, or a user-defined function. The user-defined function can
be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
:meth:`pyspark.sql.functions.pandas_udf`.
:param returnType: the return type of the registered user-defined function.
:return: a user-defined function.
`returnType` can be optionally specified when `f` is a Python function but not
when `f` is a user-defined function. See below:
1. When `f` is a Python function, `returnType` defaults to string type and can be
optionally specified. The produced object must match the specified type. In this case,
this API works as if `register(name, f, returnType=StringType())`.
>>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
>>> spark.sql("SELECT stringLengthString('test')").collect()
[Row(stringLengthString(test)=u'4')]
>>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
[Row(stringLengthString(text)=u'3')]
>>> from pyspark.sql.types import IntegerType
>>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
2. When `f` is a user-defined function, Spark uses the return type of the given
user-defined function as the return type of the registered user-defined function.
`returnType` should not be specified. In this case, this API works as if
`register(name, f)`.
>>> from pyspark.sql.types import IntegerType
>>> from pyspark.sql.functions import udf
>>> slen = udf(lambda s: len(s), IntegerType())
>>> _ = spark.udf.register("slen", slen)
>>> spark.sql("SELECT slen('test')").collect()
[Row(slen(test)=4)]
>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> new_random_udf = spark.udf.register("random_udf", random_udf)
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=82)]
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
... def add_one(x):
... return x + 1
...
>>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP
>>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
"""

# This is to check whether the input function is from a user-defined function or
# Python function.
if hasattr(f, 'asNondeterministic'):
if returnType is not None:
raise TypeError(
"Invalid returnType: data type can not be specified when f is "
"a user-defined function, but got %s." % returnType)
if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
raise ValueError(
"Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
evalType=f.evalType,
deterministic=f.deterministic)
return_udf = f
else:
if returnType is None:
returnType = StringType()
register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF)
return_udf = register_udf._wrapped()
self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
return return_udf

@ignore_unicode_prefix
def registerJavaFunction(self, name, javaClassName, returnType=None):
"""Register a Java user-defined function so it can be used in SQL statements.
In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not specified, it is inferred via reflection.
:param name: name of the user-defined function
:param javaClassName: fully qualified name of java class
:param returnType: a :class:`pyspark.sql.types.DataType` object
>>> from pyspark.sql.types import IntegerType
>>> spark.udf.registerJavaFunction("javaStringLength",
... "test.org.apache.spark.sql.JavaStringLength", IntegerType())
>>> spark.sql("SELECT javaStringLength('test')").collect()
[Row(UDF:javaStringLength(test)=4)]
>>> spark.udf.registerJavaFunction("javaStringLength2",
... "test.org.apache.spark.sql.JavaStringLength")
>>> spark.sql("SELECT javaStringLength2('test')").collect()
[Row(UDF:javaStringLength2(test)=4)]
"""
jdt = None
if returnType is not None:
jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

@ignore_unicode_prefix
def registerJavaUDAF(self, name, javaClassName):
"""Register a Java user-defined aggregate function so it can be used in SQL statements.
:param name: name of the user-defined aggregate function
:param javaClassName: fully qualified name of java class
>>> spark.udf.registerJavaUDAF("javaUDAF",
... "test.org.apache.spark.sql.MyDoubleAvg")
>>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
>>> df.registerTempTable("df")
>>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
[Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
"""
self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)


def _test():
import os
import doctest
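With the UDFRegistration class moved out of session.py and into pyspark.sql.udf, registration is still reached through spark.udf exactly as the removed doctests show. A minimal usage sketch under that assumption; it needs a local SparkSession, and the name strlen is illustrative, not from the commit:

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").getOrCreate()

# spark.udf is a pyspark.sql.udf.UDFRegistration after this change;
# SQLContext.udf and Catalog.registerFunction above delegate to it.
strlen = spark.udf.register("strlen", lambda s: len(s), IntegerType())

# The registered function is callable from SQL ...
spark.sql("SELECT strlen('test')").show()

# ... and the returned wrapper works directly in the DataFrame API.
spark.range(1).selectExpr("'spark' AS word").select(strlen("word")).show()

spark.stop()
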
14 changes: 9 additions & 5 deletions python/pyspark/sql/tests.py
@@ -583,6 +583,15 @@ def test_udf_registration_returns_udf(self):
df.select(add_three("id").alias("plus_three")).collect()
)

# This is to check if a UDF can be registered through the 'SQLContext.udf' alias.
sqlContext = self.spark._wrapped
add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())

self.assertListEqual(
df.selectExpr("add_four(id) AS plus_four").collect(),
df.select(add_four("id").alias("plus_four")).collect()
)

def test_non_existed_udf(self):
spark = self.spark
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
@@ -598,11 +607,6 @@ def test_non_existed_udaf(self):
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))

# This is to check if a deprecated 'SQLContext.registerJavaUDAF' can call its alias.
sqlContext = spark._wrapped
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
lambda: sqlContext.registerJavaUDAF("udaf1", "non_existed_udaf"))

def test_multiLine_json(self):
people1 = self.spark.read.json("python/test_support/sql/people.json")
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
