[SPARK-23836][PYTHON] Add support for StructType return in Scalar Pandas UDF #23900
```diff
@@ -64,6 +64,7 @@
     from itertools import izip as zip, imap as map
 else:
     import pickle
+    basestring = unicode = str
     xrange = range

 pickle_protocol = pickle.HIGHEST_PROTOCOL

@@ -244,7 +245,7 @@ def __repr__(self):
         return "ArrowStreamSerializer"


-def _create_batch(series, timezone, safecheck):
+def _create_batch(series, timezone, safecheck, assign_cols_by_name):
     """
     Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.

@@ -254,6 +255,7 @@ def _create_batch(series, timezone, safecheck):
     """
     import decimal
     from distutils.version import LooseVersion
+    import pandas as pd
     import pyarrow as pa
     from pyspark.sql.types import _check_series_convert_timestamps_internal
     # Make input conform to [(series1, type1), (series2, type2), ...]

@@ -295,7 +297,33 @@ def create_array(s, t):
             raise RuntimeError(error_msg % (s.dtype, t), e)
         return array

-    arrs = [create_array(s, t) for s, t in series]
+    arrs = []
+    for s, t in series:
+        if t is not None and pa.types.is_struct(t):
+            if not isinstance(s, pd.DataFrame):
```
> Looks good @BryanCutler, and cc @ueshin FYI. Just out of curiosity, WDYT about putting these PySpark-specific conversion logics together somewhere, in a separate PR and JIRA of course? It's getting difficult to read (to me ..).
```diff
+                raise ValueError("A field of type StructType expects a pandas.DataFrame, "
+                                 "but got: %s" % str(type(s)))
+
+            # Assign result columns by schema name if user labeled with strings, else use position
+            struct_arrs = []
+            struct_names = []
+            if assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
```
> Not all columns are labeled with string, but any one?

> Yeah, this is copied from the grouped map wrap. It didn't seem necessary to check all columns to be strings. The only case that ends up weird is if the columns have a mix of strings and other types. I think that would be a little strange, and I'm not sure that assigning by position is the right thing to do. So this would probably end up raising a […]
```diff
+                for field in t:
+                    struct_arrs.append(create_array(s[field.name], field.type))
+                    struct_names.append(field.name)
+            else:
+                for i, field in enumerate(t):
+                    struct_arrs.append(create_array(s[s.columns[i]], field.type))
+                    struct_names.append(field.name)
+
+            # TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version
+            if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
+                arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
+            else:
+                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
+        else:
+            arrs.append(create_array(s, t))

     return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
```
```diff
@@ -304,10 +332,11 @@ class ArrowStreamPandasSerializer(Serializer):
     Serializes Pandas.Series as Arrow data with Arrow streaming format.
     """

-    def __init__(self, timezone, safecheck):
+    def __init__(self, timezone, safecheck, assign_cols_by_name):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
         self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name

     def arrow_to_pandas(self, arrow_column):
         from pyspark.sql.types import from_arrow_type, \

@@ -326,7 +355,8 @@ def dump_stream(self, iterator, stream):
         writer = None
         try:
             for series in iterator:
-                batch = _create_batch(series, self._timezone, self._safecheck)
+                batch = _create_batch(series, self._timezone, self._safecheck,
+                                      self._assign_cols_by_name)
                 if writer is None:
                     write_int(SpecialLengths.START_ARROW_STREAM, stream)
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
```
```diff
@@ -557,8 +557,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):

         # Create Arrow record batches
         safecheck = self._wrapped._conf.arrowSafeTypeConversion()
+        col_by_name = True  # col by name only applies to StructType columns, can't happen here
```
> Meaning the value of […]?

> Yes, that's right. This will be removed when we take out that conf.
```diff
         batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
-                                 timezone, safecheck)
+                                 timezone, safecheck, col_by_name)
                    for pdf_slice in pdf_slices]

         # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
```
```diff
@@ -23,13 +23,16 @@
 import time
 import unittest

+if sys.version >= '3':
+    unicode = str
```
> ditto
```diff
 from datetime import date, datetime
 from decimal import Decimal
 from distutils.version import LooseVersion

 from pyspark.rdd import PythonEvalType
 from pyspark.sql import Column
-from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf
+from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf
 from pyspark.sql.types import Row
 from pyspark.sql.types import *
 from pyspark.sql.utils import AnalysisException
```
```diff
@@ -265,6 +268,64 @@ def test_vectorized_udf_null_array(self):
         result = df.select(array_f(col('array')))
         self.assertEquals(df.collect(), result.collect())

+    def test_vectorized_udf_struct_type(self):
+        import pandas as pd
+
+        df = self.spark.range(10)
+        return_type = StructType([
+            StructField('id', LongType()),
+            StructField('str', StringType())])
+
+        def func(id):
+            return pd.DataFrame({'id': id, 'str': id.apply(unicode)})
+
+        f = pandas_udf(func, returnType=return_type)
+
+        expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
+                             .alias('struct')).collect()
+
+        actual = df.select(f(col('id')).alias('struct')).collect()
+        self.assertEqual(expected, actual)
+
+        g = pandas_udf(func, 'id: long, str: string')
+        actual = df.select(g(col('id')).alias('struct')).collect()
+        self.assertEqual(expected, actual)
+
+    def test_vectorized_udf_struct_complex(self):
+        import pandas as pd
+
+        df = self.spark.range(10)
+        return_type = StructType([
+            StructField('ts', TimestampType()),
+            StructField('arr', ArrayType(LongType()))])
+
+        @pandas_udf(returnType=return_type)
+        def f(id):
+            return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)),
+                                 'arr': id.apply(lambda i: [i, i + 1])})
+
+        actual = df.withColumn('f', f(col('id'))).collect()
+        for i, row in enumerate(actual):
+            id, f = row
+            self.assertEqual(i, id)
+            self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0])
+            self.assertListEqual([i, i + 1], f[1])
+
+    def test_vectorized_udf_nested_struct(self):
+        nested_type = StructType([
+            StructField('id', IntegerType()),
+            StructField('nested', StructType([
+                StructField('foo', StringType()),
+                StructField('bar', FloatType())
+            ]))
+        ])
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    Exception,
+                    'Invalid returnType with scalar Pandas UDFs'):
+                pandas_udf(lambda x: x, returnType=nested_type)
+
     def test_vectorized_udf_complex(self):
         df = self.spark.range(10).select(
             col('id').cast('int').alias('a'),
```
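The essential contract these tests exercise can be seen without Spark at all: a scalar Pandas UDF declared with a `StructType` return receives each input column as a `pandas.Series` batch and must hand back a `pandas.DataFrame` with one column per struct field. A minimal sketch (hypothetical values, no Spark session needed):

```python
# The UDF body from test_vectorized_udf_struct_type, exercised directly:
# a Series batch in, a DataFrame with columns 'id' and 'str' out.
import pandas as pd

def func(id):
    # 'id' arrives as a pandas.Series holding one batch of input values
    return pd.DataFrame({'id': id, 'str': id.apply(str)})

batch = pd.Series([0, 1, 2])
result = func(batch)
```

On the Spark side, `_create_batch` then converts each such DataFrame into a single Arrow struct column, which is why a non-DataFrame return raises the `ValueError` added earlier in this PR.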
```diff
@@ -343,6 +404,10 @@ def test_vectorized_udf_unsupported_types(self):
                 NotImplementedError,
                 'Invalid returnType.*scalar Pandas UDF.*MapType'):
             pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
+        with self.assertRaisesRegexp(
+                NotImplementedError,
+                'Invalid returnType.*scalar Pandas UDF.*ArrayType.StructType'):
+            pandas_udf(lambda x: x, ArrayType(StructType([StructField('a', IntegerType())])))

     def test_vectorized_udf_dates(self):
         schema = StructType().add("idx", LongType()).add("date", DateType())
```
```diff
@@ -1613,9 +1613,15 @@ def to_arrow_type(dt):
         # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
         arrow_type = pa.timestamp('us', tz='UTC')
     elif type(dt) == ArrayType:
-        if type(dt.elementType) == TimestampType:
+        if type(dt.elementType) in [StructType, TimestampType]:
             raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
         arrow_type = pa.list_(to_arrow_type(dt.elementType))
+    elif type(dt) == StructType:
+        if any(type(field.dataType) == StructType for field in dt):
+            raise TypeError("Nested StructType not supported in conversion to Arrow")
```
> Is ArrayType(elementType = StructType) supported?

> Good catch, that should not be supported right now. I added a check and put that type in a test.

> I am curious why […]

> @cfmcgrady support wasn't removed, it was never allowed to have […]

> Thank you for your reply.
```diff
+        fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+                  for field in dt]
+        arrow_type = pa.struct(fields)
     else:
         raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
     return arrow_type
```
```diff
@@ -133,6 +133,8 @@ def returnType(self):
                     "UDFs: returnType must be a StructType.")
         elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
             try:
+                if isinstance(self._returnType_placeholder, StructType):
+                    raise TypeError
```
> Hm, @BryanCutler, sorry if I missed something, but why do we throw a type error here?

> Grouped Agg UDFs don't allow a […]

> I see. Can you add some message while we're here? If this is going to be fixed soon, I am okay as is as well.

> Sure. I will try it as a followup, but a message for now will be good. I just noticed that grouped map wasn't catching a nested struct type, so I need to fix that anyway.

> Yea, either way sounds good to me. I'll leave it to you.
```diff
                 to_arrow_type(self._returnType_placeholder)
             except TypeError:
                 raise NotImplementedError(
```
```diff
@@ -254,7 +254,12 @@ def read_udfs(pickleSer, infile, eval_type):
         timezone = runner_conf.get("spark.sql.session.timeZone", None)
         safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
                                     "false").lower() == 'true'
-        ser = ArrowStreamPandasSerializer(timezone, safecheck)
+        # NOTE: this is duplicated from wrap_grouped_map_pandas_udf
+        assign_cols_by_name = runner_conf.get(
```
> @BryanCutler, BTW, would you be willing to work on removing […]? Also, would you mind working on upgrading the minimum Arrow to 0.12.0 as well, as we discussed? (Probably it had better be asked on the dev mailing list first to be 100% sure.) If you're currently busy, I will take one or both.

> Of course those should be separate JIRAs.

> Yeah, definitely! I could take those 2 tasks. I was thinking of holding off a little while on bumping the minimum Arrow version, just to see if anything major comes up in the meantime releases. 0.12.1 will be out in a couple of days, but I don't think it has major bug fixes for us. Maybe wait just a little bit longer?
```diff
+            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
```
> Based on the above comment (https://github.com/apache/spark/pull/23900/files#r260874304), if we are going to remove […]

> I left it in to be consistent. I'd rather remove both of them in a separate PR in case there is some discussion about it.
```diff
+            .lower() == "true"
+
+        ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
```
> @HyukjinKwon I thought we have other places for this kind of thing (or is it your new PR for cloudpickle)?

> Yes .. there are some places that use this here and there. IIRC, we discussed the Python 2 drop on the dev mailing list. I could get rid of it soon anyway ..

> Yeah, this and below are just for Python 2 support. Are we dropping that for Spark 3.0?

> I don't think we'll drop it in Spark 3.0. I will cc you in the related PRs later in the future.