[SPARK-23836][PYTHON] Add support for StructType return in Scalar Pandas UDF

This change adds support for returning StructType from a scalar Pandas UDF, where the return value of the function is a pandas.DataFrame. Nested structs are not supported and will raise an error; child types can be any other type currently supported.

Added additional unit tests to `test_pandas_udf_scalar`
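As a quick illustration (a sketch assuming an active SparkSession named `spark` with Arrow enabled, mirroring the doctest added to `functions.py`), a scalar Pandas UDF can now declare a struct return type and produce it from a `pandas.DataFrame`:

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf, col

# The struct fields are declared in the return type string; the UDF returns a
# pandas.DataFrame whose columns provide the values for those fields.
@pandas_udf("first string, last string")
def split_expand(n):
    return n.str.split(expand=True)

df = spark.createDataFrame([("John Doe",)], ["name"])
df.select(split_expand(col("name"))).show()
# +------------------+
# |split_expand(name)|
# +------------------+
# |       [John, Doe]|
# +------------------+
```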

Closes apache#23900 from BryanCutler/pyspark-support-scalar_udf-StructType-SPARK-23836.

Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
BryanCutler authored and simon-slowik committed Jun 26, 2019
1 parent 02f03b7 commit e8193ed
Showing 8 changed files with 149 additions and 12 deletions.
39 changes: 35 additions & 4 deletions python/pyspark/serializers.py
@@ -66,6 +66,7 @@
else:
    import pickle
    protocol = 3
    basestring = unicode = str
    xrange = range

from pyspark import cloudpickle
@@ -245,7 +246,7 @@ def __repr__(self):
        return "ArrowStreamSerializer"


def _create_batch(series, timezone, safecheck):
def _create_batch(series, timezone, safecheck, assign_cols_by_name):
    """
    Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
@@ -255,6 +256,7 @@ def _create_batch(series, timezone, safecheck):
    """
    import decimal
    from distutils.version import LooseVersion
    import pandas as pd
    import pyarrow as pa
    from pyspark.sql.types import _check_series_convert_timestamps_internal
    # Make input conform to [(series1, type1), (series2, type2), ...]
@@ -296,7 +298,34 @@ def create_array(s, t):
            raise RuntimeError(error_msg % (s.dtype, t), e)
        return array

    arrs = [create_array(s, t) for s, t in series]
    arrs = []
    for s, t in series:
        if t is not None and pa.types.is_struct(t):
            if not isinstance(s, pd.DataFrame):
                raise ValueError("A field of type StructType expects a pandas.DataFrame, "
                                 "but got: %s" % str(type(s)))

            # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
            if len(s) == 0 and len(s.columns) == 0:
                arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
            # Assign result columns by schema name if user labeled with strings
            elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
                arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t]
            # Assign result columns by position
            else:
                arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
                              for i, field in enumerate(t)]

            struct_arrs, struct_names = zip(*arrs_names)

            # TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version
            if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
                arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
            else:
                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
        else:
            arrs.append(create_array(s, t))

    return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
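To make the assembly above concrete, here is a small standalone sketch (assuming pyarrow >= 0.9.0, with illustrative data and field names; it is not the serializer's exact helper, which also handles masks, timestamps, and older pyarrow versions) of how a `pandas.DataFrame` becomes an Arrow `StructArray`, matched by name or by position:

```python
import pandas as pd
import pyarrow as pa

pdf = pd.DataFrame({"id": [1, 2], "str": ["a", "b"]})
struct_type = pa.struct([pa.field("id", pa.int64()), pa.field("str", pa.string())])

# Columns labeled with strings: match result columns to struct fields by name.
by_name = pa.StructArray.from_arrays(
    [pa.array(pdf[field.name], type=field.type) for field in struct_type],
    [field.name for field in struct_type])

# Or match purely by position, which is what happens when the result columns
# are not labeled with strings (e.g. 0, 1 from str.split(expand=True)).
by_position = pa.StructArray.from_arrays(
    [pa.array(pdf[pdf.columns[i]], type=field.type) for i, field in enumerate(struct_type)],
    [field.name for field in struct_type])

assert by_name.equals(by_position)  # identical for this toy example
```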


@@ -305,10 +334,11 @@ class ArrowStreamPandasSerializer(Serializer):
    Serializes Pandas.Series as Arrow data with Arrow streaming format.
    """

    def __init__(self, timezone, safecheck):
    def __init__(self, timezone, safecheck, assign_cols_by_name):
        super(ArrowStreamPandasSerializer, self).__init__()
        self._timezone = timezone
        self._safecheck = safecheck
        self._assign_cols_by_name = assign_cols_by_name

    def arrow_to_pandas(self, arrow_column):
        from pyspark.sql.types import from_arrow_type, \
@@ -327,7 +357,8 @@ def dump_stream(self, iterator, stream):
        writer = None
        try:
            for series in iterator:
                batch = _create_batch(series, self._timezone, self._safecheck)
                batch = _create_batch(series, self._timezone, self._safecheck,
                                      self._assign_cols_by_name)
                if writer is None:
                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
12 changes: 11 additions & 1 deletion python/pyspark/sql/functions.py
@@ -2872,8 +2872,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
       A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
       The length of the returned `pandas.Series` must be the same as that of the input `pandas.Series`.
       If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`.

       :class:`MapType`, :class:`StructType` are currently not supported as output types.
       :class:`MapType`, nested :class:`StructType` are currently not supported as output types.

       Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
       :meth:`pyspark.sql.DataFrame.select`.
@@ -2898,6 +2899,15 @@ def pandas_udf(f=None, returnType=None, functionType=None):
       +----------+--------------+------------+
       |         8|      JOHN DOE|          22|
       +----------+--------------+------------+

       >>> @pandas_udf("first string, last string")  # doctest: +SKIP
       ... def split_expand(n):
       ...     return n.str.split(expand=True)
       >>> df.select(split_expand("name")).show()  # doctest: +SKIP
       +------------------+
       |split_expand(name)|
       +------------------+
       |       [John, Doe]|
       +------------------+

       .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input
           column, but is the length of an internal batch used for each call to the function.
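Complementing the positional example above, here is a sketch (assuming an active SparkSession `spark`, and modeled on the new unit tests rather than on a documented doctest) where the result columns are labeled with strings and are therefore matched to the struct fields by name:

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import StructType, StructField, LongType, StringType

return_type = StructType([StructField('id', LongType()), StructField('str', StringType())])

@pandas_udf(returnType=return_type)
def id_and_str(id):
    # String column labels, so values are assigned to struct fields by name.
    return pd.DataFrame({'id': id, 'str': id.apply(str)})

df = spark.range(10)
df.select(id_and_str(col('id')).alias('struct')).show(3)
```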
3 changes: 2 additions & 1 deletion python/pyspark/sql/session.py
@@ -557,8 +557,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):

        # Create Arrow record batches
        safecheck = self._wrapped._conf.arrowSafeTypeConversion()
        col_by_name = True  # col by name only applies to StructType columns, can't happen here
        batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
                                 timezone, safecheck)
                                 timezone, safecheck, col_by_name)
                   for pdf_slice in pdf_slices]

        # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
1 change: 1 addition & 0 deletions python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -273,6 +273,7 @@ def test_unsupported_types(self):
            StructField('map', MapType(StringType(), IntegerType())),
            StructField('arr_ts', ArrayType(TimestampType())),
            StructField('null', NullType()),
            StructField('struct', StructType([StructField('l', LongType())])),
        ]

        # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
81 changes: 80 additions & 1 deletion python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -23,13 +23,16 @@
import time
import unittest

if sys.version >= '3':
    unicode = str

from datetime import date, datetime
from decimal import Decimal
from distutils.version import LooseVersion

from pyspark.rdd import PythonEvalType
from pyspark.sql import Column
from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf
from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf
from pyspark.sql.types import Row
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
@@ -265,6 +268,64 @@ def test_vectorized_udf_null_array(self):
        result = df.select(array_f(col('array')))
        self.assertEquals(df.collect(), result.collect())

    def test_vectorized_udf_struct_type(self):
        import pandas as pd

        df = self.spark.range(10)
        return_type = StructType([
            StructField('id', LongType()),
            StructField('str', StringType())])

        def func(id):
            return pd.DataFrame({'id': id, 'str': id.apply(unicode)})

        f = pandas_udf(func, returnType=return_type)

        expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
                             .alias('struct')).collect()

        actual = df.select(f(col('id')).alias('struct')).collect()
        self.assertEqual(expected, actual)

        g = pandas_udf(func, 'id: long, str: string')
        actual = df.select(g(col('id')).alias('struct')).collect()
        self.assertEqual(expected, actual)

    def test_vectorized_udf_struct_complex(self):
        import pandas as pd

        df = self.spark.range(10)
        return_type = StructType([
            StructField('ts', TimestampType()),
            StructField('arr', ArrayType(LongType()))])

        @pandas_udf(returnType=return_type)
        def f(id):
            return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)),
                                 'arr': id.apply(lambda i: [i, i + 1])})

        actual = df.withColumn('f', f(col('id'))).collect()
        for i, row in enumerate(actual):
            id, f = row
            self.assertEqual(i, id)
            self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0])
            self.assertListEqual([i, i + 1], f[1])

    def test_vectorized_udf_nested_struct(self):
        nested_type = StructType([
            StructField('id', IntegerType()),
            StructField('nested', StructType([
                StructField('foo', StringType()),
                StructField('bar', FloatType())
            ]))
        ])

        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    Exception,
                    'Invalid returnType with scalar Pandas UDFs'):
                pandas_udf(lambda x: x, returnType=nested_type)

    def test_vectorized_udf_complex(self):
        df = self.spark.range(10).select(
            col('id').cast('int').alias('a'),
@@ -331,6 +392,20 @@ def test_vectorized_udf_empty_partition(self):
        res = df.select(f(col('id')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_struct_with_empty_partition(self):
        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\
            .withColumn('name', lit('John Doe'))

        @pandas_udf("first string, last string")
        def split_expand(n):
            return n.str.split(expand=True)

        result = df.select(split_expand('name')).collect()
        self.assertEqual(1, len(result))
        row = result[0]
        self.assertEqual('John', row[0]['first'])
        self.assertEqual('Doe', row[0]['last'])

    def test_vectorized_udf_varargs(self):
        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
        f = pandas_udf(lambda *v: v[0], LongType())
@@ -343,6 +418,10 @@ def test_vectorized_udf_unsupported_types(self):
                    NotImplementedError,
                    'Invalid returnType.*scalar Pandas UDF.*MapType'):
                pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
            with self.assertRaisesRegexp(
                    NotImplementedError,
                    'Invalid returnType.*scalar Pandas UDF.*ArrayType.StructType'):
                pandas_udf(lambda x: x, ArrayType(StructType([StructField('a', IntegerType())])))

    def test_vectorized_udf_dates(self):
        schema = StructType().add("idx", LongType()).add("date", DateType())
8 changes: 7 additions & 1 deletion python/pyspark/sql/types.py
@@ -1616,9 +1616,15 @@ def to_arrow_type(dt):
    elif type(dt) == TimestampType:
        # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
        arrow_type = pa.timestamp('us', tz='UTC')
    elif type(dt) == ArrayType:
        if type(dt.elementType) == TimestampType:
        if type(dt.elementType) in [StructType, TimestampType]:
            raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
        arrow_type = pa.list_(to_arrow_type(dt.elementType))
    elif type(dt) == StructType:
        if any(type(field.dataType) == StructType for field in dt):
            raise TypeError("Nested StructType not supported in conversion to Arrow")
        fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
                  for field in dt]
        arrow_type = pa.struct(fields)
    else:
        raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
    return arrow_type
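A short sketch of the behavior added above (assuming `pyspark` and a compatible `pyarrow` are importable; output shown in comments is approximate):

```python
from pyspark.sql.types import (LongType, StringType, StructField, StructType,
                               to_arrow_type)

flat = StructType([StructField('id', LongType()), StructField('str', StringType())])
print(to_arrow_type(flat))      # struct<id: int64, str: string>

nested = StructType([StructField('inner', flat)])
try:
    to_arrow_type(nested)
except TypeError as e:
    print(e)                    # Nested StructType not supported in conversion to Arrow
```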
5 changes: 4 additions & 1 deletion python/pyspark/sql/udf.py
@@ -124,7 +124,7 @@ def returnType(self):
        elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
            if isinstance(self._returnType_placeholder, StructType):
                try:
                    to_arrow_schema(self._returnType_placeholder)
                    to_arrow_type(self._returnType_placeholder)
                except TypeError:
                    raise NotImplementedError(
                        "Invalid returnType with grouped map Pandas UDFs: "
@@ -134,6 +134,9 @@
"UDFs: returnType must be a StructType.")
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
try:
# StructType is not yet allowed as a return type, explicitly check here to fail fast
if isinstance(self._returnType_placeholder, StructType):
raise TypeError
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
12 changes: 9 additions & 3 deletions python/pyspark/worker.py
@@ -39,7 +39,7 @@
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
    BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
from pyspark.sql.types import to_arrow_type, StructType
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle

@@ -90,8 +90,9 @@ def wrap_scalar_pandas_udf(f, return_type):
    def verify_result_length(*a):
        result = f(*a)
        if not hasattr(result, "__len__"):
            pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series"
            raise TypeError("Return type of the user-defined function should be "
                            "Pandas.Series, but is {}".format(type(result)))
                            "{}, but is {}".format(pd_type, type(result)))
        if len(result) != len(a[0]):
            raise RuntimeError("Result vector from pandas_udf was not the required length: "
                               "expected %d, got %d" % (len(a[0]), len(result)))
@@ -254,7 +255,12 @@ def read_udfs(pickleSer, infile, eval_type):
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
                                    "false").lower() == 'true'
        ser = ArrowStreamPandasSerializer(timezone, safecheck)
        # NOTE: this is duplicated from wrap_grouped_map_pandas_udf
        assign_cols_by_name = runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
            .lower() == "true"

        ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)
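The column-assignment behavior is driven by an existing SQL conf shared with grouped map UDFs; a sketch of toggling it (assuming an active SparkSession `spark` and a Spark version that allows setting this conf at runtime):

```python
# Default ("true"): result pandas.DataFrame columns labeled with strings are
# matched to the struct fields by name.
spark.conf.set("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")

# Legacy behavior ("false"): columns are matched to struct fields purely by position.
spark.conf.set("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "false")
```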

