diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a2c59fedfc8cd..0c3c68ec0bd95 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -64,6 +64,7 @@
     from itertools import izip as zip, imap as map
 else:
     import pickle
+    basestring = unicode = str
     xrange = range
     pickle_protocol = pickle.HIGHEST_PROTOCOL
@@ -244,7 +245,7 @@ def __repr__(self):
         return "ArrowStreamSerializer"
 
 
-def _create_batch(series, timezone, safecheck):
+def _create_batch(series, timezone, safecheck, assign_cols_by_name):
     """
     Create an Arrow record batch from the given pandas.Series or list of Series,
     with optional type.
@@ -254,6 +255,7 @@ def _create_batch(series, timezone, safecheck):
     """
     import decimal
     from distutils.version import LooseVersion
+    import pandas as pd
     import pyarrow as pa
     from pyspark.sql.types import _check_series_convert_timestamps_internal
     # Make input conform to [(series1, type1), (series2, type2), ...]
@@ -295,7 +297,34 @@ def create_array(s, t):
                 raise RuntimeError(error_msg % (s.dtype, t), e)
         return array
 
-    arrs = [create_array(s, t) for s, t in series]
+    arrs = []
+    for s, t in series:
+        if t is not None and pa.types.is_struct(t):
+            if not isinstance(s, pd.DataFrame):
+                raise ValueError("A field of type StructType expects a pandas.DataFrame, "
+                                 "but got: %s" % str(type(s)))
+
+            # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
+            if len(s) == 0 and len(s.columns) == 0:
+                arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
+            # Assign result columns by schema name if user labeled with strings
+            elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
+                arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t]
+            # Assign result columns by position
+            else:
+                arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
+                              for i, field in enumerate(t)]
+
+            struct_arrs, struct_names = zip(*arrs_names)
+
+            # TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version
+            if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
+                arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
+            else:
+                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
+        else:
+            arrs.append(create_array(s, t))
+
     return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
 
 
@@ -304,10 +333,11 @@ class ArrowStreamPandasSerializer(Serializer):
     Serializes Pandas.Series as Arrow data with Arrow streaming format.
     """
 
-    def __init__(self, timezone, safecheck):
+    def __init__(self, timezone, safecheck, assign_cols_by_name):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
         self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
 
     def arrow_to_pandas(self, arrow_column):
         from pyspark.sql.types import from_arrow_type, \
@@ -326,7 +356,8 @@ def dump_stream(self, iterator, stream):
         writer = None
         try:
             for series in iterator:
-                batch = _create_batch(series, self._timezone, self._safecheck)
+                batch = _create_batch(series, self._timezone, self._safecheck,
+                                      self._assign_cols_by_name)
                 if writer is None:
                     write_int(SpecialLengths.START_ARROW_STREAM, stream)
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
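
For reviewers, a minimal standalone sketch of the struct-column path that `_create_batch` now takes: the result `pandas.DataFrame` is split into one Arrow array per struct field and recombined with `pa.StructArray.from_arrays`. The column/field names and data are made up for illustration, and the snippet assumes pyarrow >= 0.9.0 (the newer argument order).

```python
import pandas as pd
import pyarrow as pa

# Hypothetical UDF result: a pandas.DataFrame whose string-labeled columns
# match the struct fields, so assignment happens by name.
pdf = pd.DataFrame({"first": ["John"], "last": ["Doe"]})

# Arrow type the StructType return schema would map to (see the types.py change below).
struct_type = pa.struct([pa.field("first", pa.string()), pa.field("last", pa.string())])

arrays = [pa.Array.from_pandas(pdf[field.name], type=field.type) for field in struct_type]
names = [field.name for field in struct_type]

# pyarrow >= 0.9.0 takes (arrays, names); older releases used the reversed order.
struct_array = pa.StructArray.from_arrays(arrays, names)
batch = pa.RecordBatch.from_arrays([struct_array], ["_0"])
print(batch.schema)  # _0: struct<first: string, last: string>
```
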
""" - def __init__(self, timezone, safecheck): + def __init__(self, timezone, safecheck, assign_cols_by_name): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone self._safecheck = safecheck + self._assign_cols_by_name = assign_cols_by_name def arrow_to_pandas(self, arrow_column): from pyspark.sql.types import from_arrow_type, \ @@ -326,7 +356,8 @@ def dump_stream(self, iterator, stream): writer = None try: for series in iterator: - batch = _create_batch(series, self._timezone, self._safecheck) + batch = _create_batch(series, self._timezone, self._safecheck, + self._assign_cols_by_name) if writer is None: write_int(SpecialLengths.START_ARROW_STREAM, stream) writer = pa.RecordBatchStreamWriter(stream, batch.schema) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3c33e2bed92d9..a36423e67d750 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2842,8 +2842,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`. - :class:`MapType`, :class:`StructType` are currently not supported as output types. + :class:`MapType`, nested :class:`StructType` are currently not supported as output types. Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and :meth:`pyspark.sql.DataFrame.select`. @@ -2868,6 +2869,15 @@ def pandas_udf(f=None, returnType=None, functionType=None): +----------+--------------+------------+ | 8| JOHN DOE| 22| +----------+--------------+------------+ + >>> @pandas_udf("first string, last string") # doctest: +SKIP + ... def split_expand(n): + ... return n.str.split(expand=True) + >>> df.select(split_expand("name")).show() # doctest: +SKIP + +------------------+ + |split_expand(name)| + +------------------+ + | [John, Doe]| + +------------------+ .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input column, but is the length of an internal batch used for each call to the function. 
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index bdf1701a58959..32a2c8a67252d 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -557,8 +557,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
 
         # Create Arrow record batches
         safecheck = self._wrapped._conf.arrowSafeTypeConversion()
+        col_by_name = True  # col by name only applies to StructType columns, can't happen here
         batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
-                                 timezone, safecheck)
+                                 timezone, safecheck, col_by_name)
                    for pdf_slice in pdf_slices]
 
         # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
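
For context, the hard-coded `col_by_name = True` above sits on the Arrow path of `createDataFrame(pandas_df)`, where each top-level pandas column becomes its own Arrow column, so the struct-specific by-name logic in `_create_batch` is never exercised. A rough sketch of how that path is hit; the Arrow config key is the pre-3.0 name and the data is illustrative:

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Opt in to the Arrow-based conversion used by _create_from_pandas_with_arrow.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
df = spark.createDataFrame(pdf)  # goes through _create_batch with col_by_name=True
df.show()
```
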
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
index a0a25359d1e01..f7684d3fbcff0 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -273,6 +273,7 @@ def test_unsupported_types(self):
             StructField('map', MapType(StringType(), IntegerType())),
             StructField('arr_ts', ArrayType(TimestampType())),
             StructField('null', NullType()),
+            StructField('struct', StructType([StructField('l', LongType())])),
         ]
 
         # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 28ef98d7b3f1e..28b6db216d00a 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -23,13 +23,16 @@
 import time
 import unittest
 
+if sys.version >= '3':
+    unicode = str
+
 from datetime import date, datetime
 from decimal import Decimal
 from distutils.version import LooseVersion
 
 from pyspark.rdd import PythonEvalType
 from pyspark.sql import Column
-from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf
+from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf
 from pyspark.sql.types import Row
 from pyspark.sql.types import *
 from pyspark.sql.utils import AnalysisException
@@ -265,6 +268,64 @@ def test_vectorized_udf_null_array(self):
         result = df.select(array_f(col('array')))
         self.assertEquals(df.collect(), result.collect())
 
+    def test_vectorized_udf_struct_type(self):
+        import pandas as pd
+
+        df = self.spark.range(10)
+        return_type = StructType([
+            StructField('id', LongType()),
+            StructField('str', StringType())])
+
+        def func(id):
+            return pd.DataFrame({'id': id, 'str': id.apply(unicode)})
+
+        f = pandas_udf(func, returnType=return_type)
+
+        expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
+                             .alias('struct')).collect()
+
+        actual = df.select(f(col('id')).alias('struct')).collect()
+        self.assertEqual(expected, actual)
+
+        g = pandas_udf(func, 'id: long, str: string')
+        actual = df.select(g(col('id')).alias('struct')).collect()
+        self.assertEqual(expected, actual)
+
+    def test_vectorized_udf_struct_complex(self):
+        import pandas as pd
+
+        df = self.spark.range(10)
+        return_type = StructType([
+            StructField('ts', TimestampType()),
+            StructField('arr', ArrayType(LongType()))])
+
+        @pandas_udf(returnType=return_type)
+        def f(id):
+            return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)),
+                                 'arr': id.apply(lambda i: [i, i + 1])})
+
+        actual = df.withColumn('f', f(col('id'))).collect()
+        for i, row in enumerate(actual):
+            id, f = row
+            self.assertEqual(i, id)
+            self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0])
+            self.assertListEqual([i, i + 1], f[1])
+
+    def test_vectorized_udf_nested_struct(self):
+        nested_type = StructType([
+            StructField('id', IntegerType()),
+            StructField('nested', StructType([
+                StructField('foo', StringType()),
+                StructField('bar', FloatType())
+            ]))
+        ])
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    Exception,
+                    'Invalid returnType with scalar Pandas UDFs'):
+                pandas_udf(lambda x: x, returnType=nested_type)
+
     def test_vectorized_udf_complex(self):
         df = self.spark.range(10).select(
             col('id').cast('int').alias('a'),
@@ -331,6 +392,20 @@ def test_vectorized_udf_empty_partition(self):
         res = df.select(f(col('id')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_struct_with_empty_partition(self):
+        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\
+            .withColumn('name', lit('John Doe'))
+
+        @pandas_udf("first string, last string")
+        def split_expand(n):
+            return n.str.split(expand=True)
+
+        result = df.select(split_expand('name')).collect()
+        self.assertEqual(1, len(result))
+        row = result[0]
+        self.assertEqual('John', row[0]['first'])
+        self.assertEqual('Doe', row[0]['last'])
+
     def test_vectorized_udf_varargs(self):
         df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
         f = pandas_udf(lambda *v: v[0], LongType())
@@ -343,6 +418,10 @@ def test_vectorized_udf_unsupported_types(self):
                 NotImplementedError,
                 'Invalid returnType.*scalar Pandas UDF.*MapType'):
             pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
+        with self.assertRaisesRegexp(
+                NotImplementedError,
+                'Invalid returnType.*scalar Pandas UDF.*ArrayType.StructType'):
+            pandas_udf(lambda x: x, ArrayType(StructType([StructField('a', IntegerType())])))
 
     def test_vectorized_udf_dates(self):
         schema = StructType().add("idx", LongType()).add("date", DateType())
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 348cb5b118594..d87f0f91499ae 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1613,9 +1613,15 @@ def to_arrow_type(dt):
         # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
         arrow_type = pa.timestamp('us', tz='UTC')
     elif type(dt) == ArrayType:
-        if type(dt.elementType) == TimestampType:
+        if type(dt.elementType) in [StructType, TimestampType]:
             raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
         arrow_type = pa.list_(to_arrow_type(dt.elementType))
+    elif type(dt) == StructType:
+        if any(type(field.dataType) == StructType for field in dt):
+            raise TypeError("Nested StructType not supported in conversion to Arrow")
+        fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+                  for field in dt]
+        arrow_type = pa.struct(fields)
     else:
         raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
     return arrow_type
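
A pyarrow-only sketch of the mapping the `to_arrow_type` change introduces, with made-up field names: each top-level field is converted individually and wrapped in `pa.struct`, while a nested `StructType` field would raise `TypeError` per the check above.

```python
import pyarrow as pa

# Roughly what to_arrow_type now produces for
# StructType([StructField('id', LongType()), StructField('str', StringType())]):
arrow_struct = pa.struct([
    pa.field("id", pa.int64(), nullable=True),
    pa.field("str", pa.string(), nullable=True),
])
print(arrow_struct)  # struct<id: int64, str: string>
```
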
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 58f4e0dff5ee5..275abe9c85d1e 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -123,7 +123,7 @@ def returnType(self):
         elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
             if isinstance(self._returnType_placeholder, StructType):
                 try:
-                    to_arrow_schema(self._returnType_placeholder)
+                    to_arrow_type(self._returnType_placeholder)
                 except TypeError:
                     raise NotImplementedError(
                         "Invalid returnType with grouped map Pandas UDFs: "
@@ -133,6 +133,9 @@ def returnType(self):
                     "UDFs: returnType must be a StructType.")
         elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
             try:
+                # StructType is not yet allowed as a return type, explicitly check here to fail fast
+                if isinstance(self._returnType_placeholder, StructType):
+                    raise TypeError
                 to_arrow_type(self._returnType_placeholder)
             except TypeError:
                 raise NotImplementedError(
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 01934a0e72758..0e9b6d665a36f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -39,7 +39,7 @@
 from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
-from pyspark.sql.types import to_arrow_type
+from pyspark.sql.types import to_arrow_type, StructType
 from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
 
@@ -90,8 +90,9 @@ def wrap_scalar_pandas_udf(f, return_type):
     def verify_result_length(*a):
         result = f(*a)
         if not hasattr(result, "__len__"):
+            pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series"
             raise TypeError("Return type of the user-defined function should be "
-                            "Pandas.Series, but is {}".format(type(result)))
+                            "{}, but is {}".format(pd_type, type(result)))
         if len(result) != len(a[0]):
             raise RuntimeError("Result vector from pandas_udf was not the required length: "
                                "expected %d, got %d" % (len(a[0]), len(result)))
@@ -254,7 +255,12 @@ def read_udfs(pickleSer, infile, eval_type):
         timezone = runner_conf.get("spark.sql.session.timeZone", None)
         safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
                                     "false").lower() == 'true'
-        ser = ArrowStreamPandasSerializer(timezone, safecheck)
+        # NOTE: this is duplicated from wrap_grouped_map_pandas_udf
+        assign_cols_by_name = runner_conf.get(
+            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
+            .lower() == "true"
+
+        ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
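
Finally, a hedged sketch of the column-assignment flag wired through here. The config key is the one read above; the data and UDF are illustrative. With the default (`true`), a returned `pandas.DataFrame` with string column labels is matched to the struct fields by name; integer labels (as produced by `str.split(expand=True)`) or setting the flag to `false` fall back to positional assignment.

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf

spark = SparkSession.builder.getOrCreate()
spark.conf.set(
    "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")

@pandas_udf("first string, last string")
def split_expand(n):
    # str.split(expand=True) yields integer column labels (0, 1), so this
    # DataFrame is assigned to the struct fields by position.
    return n.str.split(expand=True)

df = spark.range(1).selectExpr("'John Doe' AS name")
df.select(split_expand(col("name"))).show()
```
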