[SPARK-37228][SQL][PYTHON] Implement DataFrame.mapInArrow in Python
### What changes were proposed in this pull request?

This PR proposes to implement `DataFrame.mapInArrow`, which allows users to apply a function to PyArrow record batches, for example:

```python
def do_something(iterator):
    for arrow_batch in iterator:
        # do something with `pyarrow.RecordBatch` and create new `pyarrow.RecordBatch`.
        # ...
        yield arrow_batch

df.mapInArrow(do_something, df.schema).show()
```

The general idea is simple: it shares the same codebase as `DataFrame.mapInPandas`, except for the pandas conversion logic.
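
For illustration only (a hypothetical sketch, not code from this PR), the user-facing difference is just the element type flowing through the iterator: pandas `DataFrame`s for `mapInPandas` versus `pyarrow.RecordBatch`es for `mapInArrow`, assuming a DataFrame `df` as in the snippet above:

```python
def first_row_pandas(iterator):
    # mapInPandas: each element is a pandas.DataFrame.
    for pdf in iterator:
        yield pdf.head(1)

def first_row_arrow(iterator):
    # mapInArrow: each element is a pyarrow.RecordBatch.
    for batch in iterator:
        yield batch.slice(0, 1)

df.mapInPandas(first_row_pandas, df.schema).show()
df.mapInArrow(first_row_arrow, df.schema).show()
```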

This PR also piggy-backs:
- Removes the check in `spark.udf.register` on `SQL_MAP_PANDAS_ITER_UDF`. This type is only used internally for `DataFrame.mapInPandas`, and it cannot be registered as a SQL UDF.
- Removes the type hints for `pandas_udf` that are used for internal purposes such as `SQL_MAP_PANDAS_ITER_UDF` and `SQL_COGROUPED_MAP_PANDAS_UDF`. Neither can be used with `pandas_udf` as a SQL expression, and they should be hidden from end users.

Note that documentation will be done in another PR.

### Why are the changes needed?

For usability and to address technical problems. Both are elaborated in more detail in SPARK-37227.
Please also see the discussion at #26783.

### Does this PR introduce _any_ user-facing change?

Yes, this PR adds a new API:

```python
import pyarrow as pa

df = spark.createDataFrame(
    [(1, "foo"), (2, None), (3, "bar"), (4, "bar")], "a int, b string")

def func(iterator):
    for batch in iterator:
        # `batch` is pyarrow.RecordBatch.
        yield batch

df.mapInArrow(func, df.schema).collect()
```
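
Because the function receives raw Arrow data, Arrow compute kernels can be applied without converting to pandas. A hypothetical variation of the example above (assumes a reasonably recent PyArrow with the `pyarrow.compute` module; not part of this PR):

```python
import pyarrow.compute as pc

def drop_nulls(iterator):
    for batch in iterator:
        # Keep only rows where column "b" (index 1) is not null, using Arrow kernels.
        mask = pc.is_valid(batch.column(1))
        yield pc.filter(batch, mask)

df.mapInArrow(drop_nulls, df.schema).collect()
```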

### How was this patch tested?

Manually tested, and unit tests were added.

Closes #34505 from HyukjinKwon/SPARK-37228.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
HyukjinKwon committed Nov 14, 2021
1 parent 950422f commit 775e05f
Showing 20 changed files with 468 additions and 131 deletions.
@@ -52,6 +52,7 @@ private[spark] object PythonEvalType {
val SQL_SCALAR_PANDAS_ITER_UDF = 204
val SQL_MAP_PANDAS_ITER_UDF = 205
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
val SQL_MAP_ARROW_ITER_UDF = 207

def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@@ -63,6 +64,7 @@ private[spark] object PythonEvalType {
case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
}
}

1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -451,6 +451,7 @@ def __hash__(self):
"pyspark.sql.tests.test_pandas_cogrouped_map",
"pyspark.sql.tests.test_pandas_grouped_map",
"pyspark.sql.tests.test_pandas_map",
"pyspark.sql.tests.test_arrow_map",
"pyspark.sql.tests.test_pandas_udf",
"pyspark.sql.tests.test_pandas_udf_grouped_agg",
"pyspark.sql.tests.test_pandas_udf_scalar",
1 change: 1 addition & 0 deletions python/pyspark/rdd.py
@@ -90,6 +90,7 @@ class PythonEvalType(object):
SQL_SCALAR_PANDAS_ITER_UDF = 204
SQL_MAP_PANDAS_ITER_UDF = 205
SQL_COGROUPED_MAP_PANDAS_UDF = 206
SQL_MAP_ARROW_ITER_UDF = 207


def portable_hash(x):
2 changes: 2 additions & 0 deletions python/pyspark/rdd.pyi
@@ -43,6 +43,7 @@ from pyspark.sql.pandas._typing import (
PandasCogroupedMapUDFType,
PandasGroupedAggUDFType,
PandasMapIterUDFType,
ArrowMapIterUDFType,
)
import pyspark.context
from pyspark.resultiterable import ResultIterable
@@ -83,6 +84,7 @@ class PythonEvalType:
SQL_SCALAR_PANDAS_ITER_UDF: PandasScalarIterUDFType
SQL_MAP_PANDAS_ITER_UDF: PandasMapIterUDFType
SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType
SQL_MAP_ARROW_ITER_UDF: ArrowMapIterUDFType

class BoundedFloat(float):
def __new__(cls, mean: float, confidence: float, low: float, high: float) -> BoundedFloat: ...
9 changes: 5 additions & 4 deletions python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -35,6 +35,8 @@ from pyspark.sql.pandas._typing.protocols.series import SeriesLike as SeriesLike
import pandas.core.frame # type: ignore[import]
import pandas.core.series # type: ignore[import]

import pyarrow # type: ignore[import]

# POC compatibility annotations
PandasDataFrame: Type[DataFrameLike] = pandas.core.frame.DataFrame
PandasSeries: Type[SeriesLike] = pandas.core.series.Series
@@ -48,6 +50,7 @@ PandasGroupedMapUDFType = Literal[201]
PandasCogroupedMapUDFType = Literal[206]
PandasGroupedAggUDFType = Literal[202]
PandasMapIterUDFType = Literal[205]
ArrowMapIterUDFType = Literal[207]

class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike) -> SeriesLike: ...
@@ -325,10 +328,8 @@ PandasGroupedAggFunction = Union[

PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]]

ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]]

PandasCogroupedMapFunction = Callable[[DataFrameLike, DataFrameLike], DataFrameLike]

MapIterPandasUserDefinedFunction = NewType("MapIterPandasUserDefinedFunction", FunctionType)
GroupedMapPandasUserDefinedFunction = NewType("GroupedMapPandasUserDefinedFunction", FunctionType)
CogroupedMapPandasUserDefinedFunction = NewType(
"CogroupedMapPandasUserDefinedFunction", FunctionType
)
8 changes: 5 additions & 3 deletions python/pyspark/sql/pandas/functions.py
@@ -366,6 +366,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
None,
]: # None means it should infer the type from type hints.
@@ -400,12 +401,13 @@ def _create_pandas_udf(f, returnType, evalType):
elif evalType in [
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
]:
# In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered
# at `apply` instead.
# In case of 'SQL_MAP_PANDAS_ITER_UDF' and 'SQL_COGROUPED_MAP_PANDAS_UDF', the
# evaluation type will always be set.
# In case of 'SQL_MAP_PANDAS_ITER_UDF', 'SQL_MAP_ARROW_ITER_UDF' and
# 'SQL_COGROUPED_MAP_PANDAS_UDF', the evaluation type will always be set.
pass
elif len(argspec.annotations) > 0:
evalType = infer_eval_type(signature(f))
42 changes: 0 additions & 42 deletions python/pyspark/sql/pandas/functions.pyi
@@ -25,16 +25,10 @@ from pyspark.sql._typing import (
)
from pyspark.sql.pandas._typing import (
GroupedMapPandasUserDefinedFunction,
MapIterPandasUserDefinedFunction,
CogroupedMapPandasUserDefinedFunction,
PandasCogroupedMapFunction,
PandasCogroupedMapUDFType,
PandasGroupedAggFunction,
PandasGroupedAggUDFType,
PandasGroupedMapFunction,
PandasGroupedMapUDFType,
PandasMapIterFunction,
PandasMapIterUDFType,
PandasScalarIterFunction,
PandasScalarIterUDFType,
PandasScalarToScalarFunction,
@@ -130,39 +124,3 @@ def pandas_udf(
def pandas_udf(
f: Union[AtomicDataTypeOrString, ArrayType], *, functionType: PandasGroupedAggUDFType
) -> Callable[[PandasGroupedAggFunction], UserDefinedFunctionLike]: ...
@overload
def pandas_udf(
f: PandasMapIterFunction,
returnType: Union[StructType, str],
functionType: PandasMapIterUDFType,
) -> MapIterPandasUserDefinedFunction: ...
@overload
def pandas_udf(
f: Union[StructType, str], returnType: PandasMapIterUDFType
) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ...
@overload
def pandas_udf(
*, returnType: Union[StructType, str], functionType: PandasMapIterUDFType
) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ...
@overload
def pandas_udf(
f: Union[StructType, str], *, functionType: PandasMapIterUDFType
) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ...
@overload
def pandas_udf(
f: PandasCogroupedMapFunction,
returnType: Union[StructType, str],
functionType: PandasCogroupedMapUDFType,
) -> CogroupedMapPandasUserDefinedFunction: ...
@overload
def pandas_udf(
f: Union[StructType, str], returnType: PandasCogroupedMapUDFType
) -> Callable[[PandasCogroupedMapFunction], CogroupedMapPandasUserDefinedFunction]: ...
@overload
def pandas_udf(
*, returnType: Union[StructType, str], functionType: PandasCogroupedMapUDFType
) -> Callable[[PandasCogroupedMapFunction], CogroupedMapPandasUserDefinedFunction]: ...
@overload
def pandas_udf(
f: Union[StructType, str], *, functionType: PandasCogroupedMapUDFType
) -> Callable[[PandasCogroupedMapFunction], CogroupedMapPandasUserDefinedFunction]: ...
4 changes: 3 additions & 1 deletion python/pyspark/sql/pandas/group_ops.py
@@ -347,9 +347,11 @@ def applyInPandas(
"""
from pyspark.sql.pandas.functions import pandas_udf

# The usage of the pandas_udf is internal so type checking is disabled.
udf = pandas_udf(
func, returnType=schema, functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
)
) # type: ignore[call-overload]

all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
udf_column = udf(*all_cols)
jdf = self._gd1._jgd.flatMapCoGroupsInPandas( # type: ignore[attr-defined]
68 changes: 66 additions & 2 deletions python/pyspark/sql/pandas/map_ops.py
@@ -22,7 +22,7 @@

if TYPE_CHECKING:
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas._typing import PandasMapIterFunction
from pyspark.sql.pandas._typing import PandasMapIterFunction, ArrowMapIterFunction


class PandasMapOpsMixin(object):
@@ -84,13 +84,77 @@ def mapInPandas(

assert isinstance(self, DataFrame)

# The usage of the pandas_udf is internal so type checking is disabled.
udf = pandas_udf(
func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
)
) # type: ignore[call-overload]
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.mapInPandas(udf_column._jc.expr()) # type: ignore[operator]
return DataFrame(jdf, self.sql_ctx)

def mapInArrow(
self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
function that takes and outputs a PyArrow's `RecordBatch`, and returns the result as a
:class:`DataFrame`.
The function should take an iterator of `pyarrow.RecordBatch`\\s and return
another iterator of `pyarrow.RecordBatch`\\s. All columns are passed
together as an iterator of `pyarrow.RecordBatch`\\s to the function and the
returned iterator of `pyarrow.RecordBatch`\\s are combined as a :class:`DataFrame`.
Each `pyarrow.RecordBatch` size can be controlled by
`spark.sql.execution.arrow.maxRecordsPerBatch`.

.. versionadded:: 3.3.0

Parameters
----------
func : function
a Python native function that takes an iterator of `pyarrow.RecordBatch`\\s, and
outputs an iterator of `pyarrow.RecordBatch`\\s.
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

Examples
--------
>>> import pyarrow # doctest: +SKIP
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
>>> def filter_func(iterator):
... for batch in iterator:
... pdf = batch.to_pandas()
... yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])
>>> df.mapInArrow(filter_func, df.schema).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+

Notes
-----
This API is unstable, and for developers.

See Also
--------
pyspark.sql.functions.pandas_udf
pyspark.sql.DataFrame.mapInPandas
"""
from pyspark.sql import DataFrame
from pyspark.sql.pandas.functions import pandas_udf

assert isinstance(self, DataFrame)

# The usage of the pandas_udf is internal so type checking is disabled.
udf = pandas_udf(
func, returnType=schema, functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF
) # type: ignore[call-overload]
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)


def _test() -> None:
import doctest
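
Not part of the diff, but as a usage sketch: the `spark.sql.execution.arrow.maxRecordsPerBatch` setting mentioned in the `mapInArrow` docstring above bounds how many rows each `pyarrow.RecordBatch` handed to the user function can hold. Something along these lines (hypothetical example):

```python
# Hypothetical usage sketch: with the setting below, each pyarrow.RecordBatch
# passed to the function holds at most 1,000 rows.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")

def passthrough(iterator):
    for batch in iterator:
        # batch.num_rows is bounded by the configured maximum.
        yield batch

df.mapInArrow(passthrough, df.schema).collect()
```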
45 changes: 45 additions & 0 deletions python/pyspark/sql/pandas/serializers.py
@@ -100,6 +100,51 @@ def __repr__(self):
return "ArrowStreamSerializer"


class ArrowStreamUDFSerializer(ArrowStreamSerializer):
"""
Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow record batch
for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`.
"""

def load_stream(self, stream):
"""
Flatten the struct into Arrow's record batches.
"""
import pyarrow as pa

batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
for batch in batches:
struct = batch.column(0)
yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))]

def dump_stream(self, iterator, stream):
"""
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
This should be sent after creating the first record batch so in case of an error, it can
be sent back to the JVM before the Arrow stream starts.
"""
import pyarrow as pa

def wrap_and_init_stream():
should_write_start_length = True
for batch, _ in iterator:
assert isinstance(batch, pa.RecordBatch)

# Wrap the root struct
struct = pa.StructArray.from_arrays(
batch.columns, fields=pa.struct(list(batch.schema))
)
batch = pa.RecordBatch.from_arrays([struct], ["_0"])

# Write the first record batch with initialization.
if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False
yield batch

return super(ArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream)
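
As a standalone illustration (plain PyArrow usage, assumed for clarity and not part of this diff), the round trip the serializer performs, wrapping all columns into a single struct column on the way out and flattening it back on the way in, looks roughly like this:

```python
import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2]), pa.array(["x", "y"])], ["a", "b"]
)

# dump_stream side: wrap the whole batch into one struct column named "_0".
struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema)))
wrapped = pa.RecordBatch.from_arrays([struct], ["_0"])

# load_stream side: flatten the struct back into a plain record batch.
struct_col = wrapped.column(0)
restored = pa.RecordBatch.from_arrays(
    struct_col.flatten(), schema=pa.schema(struct_col.type)
)
# `restored` carries the same columns and values as the original `batch`.
```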


class ArrowStreamPandasSerializer(ArrowStreamSerializer):
"""
Serializes Pandas.Series as Arrow data with Arrow streaming format.