[SPARK-47365][PYTHON] Add toArrow() DataFrame method to PySpark
### What changes were proposed in this pull request?
- Add a PySpark DataFrame method `toArrow()` which returns the contents of the DataFrame as a [PyArrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html), for both local Spark and Spark Connect.
- Add a new entry to the **Apache Arrow in PySpark** user guide page describing usage of the `toArrow()` method.
- Add a new argument to the experimental `_collect_as_arrow()` method to provide more useful output when zero records are returned. (This keeps the implementation of `toArrow()` simpler.)

### Why are the changes needed?
In the Apache Arrow community, we hear from a lot of users who want to return the contents of a PySpark DataFrame as a PyArrow Table. Currently, the only documented way to do this is to collect the contents as a pandas DataFrame and then use PyArrow (`pa`) to convert it to a PyArrow Table:
```py
pa.Table.from_pandas(df.toPandas())
```
But going through pandas adds significant overhead that is easily avoided: when `spark.sql.execution.arrow.pyspark.enabled` is `true`, `toPandas()` already converts the contents of the Spark DataFrame to Arrow format internally as an intermediate step.
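With `toArrow()` the pandas round trip disappears entirely. A minimal usage sketch (the example DataFrame below is only an illustration, not part of this PR):

```py
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Any DataFrame whose column types are supported by Arrow works here
df = spark.range(3).selectExpr("id", "id * 2 AS doubled")

# Collect directly into a pyarrow.Table, skipping the pandas intermediate step
table = df.toArrow()
print(table.num_rows)  # 3
print(table.schema)    # one Arrow field per DataFrame column
```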

Currently it is also possible to use the experimental `_collect_as_arrow()` method to return the contents of a PySpark DataFrame as a list of PyArrow RecordBatches. This PR adds a new non-experimental method `toArrow()` which returns the more user-friendly PyArrow Table object.

This PR also adds a new argument `empty_list_if_zero_records` to the experimental method `_collect_as_arrow()` to control what it returns when the result has zero rows. If set to `True` (the default), the existing behavior is preserved and the method returns an empty Python list. If set to `False`, the method returns a length-one list containing an empty Arrow RecordBatch that includes the schema. `toArrow()` uses this because it requires the schema even when the data has zero rows.
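A rough illustration of the difference, using the internal experimental API against a zero-row result (assumes a classic, non-Connect session; `df` is the DataFrame from the sketch above):

```py
import pyarrow as pa

empty_df = df.limit(0)  # a result with zero rows

# Default behavior: an empty Python list, which carries no schema information
assert empty_df._collect_as_arrow() == []

# New behavior: one empty RecordBatch that still carries the schema, so an
# empty pyarrow.Table with the correct columns can be built from it
batches = empty_df._collect_as_arrow(empty_list_if_zero_records=False)
table = pa.Table.from_batches(batches)
assert table.num_rows == 0
print(table.schema)  # the DataFrame schema converted to Arrow types
```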

For Spark Connect, there is already a `SparkSession.client.to_table()` method that returns a PyArrow Table. This PR uses that method to expose `toArrow()` for Spark Connect.
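From the user's perspective the Connect path looks identical; a sketch (the server address below is a hypothetical placeholder):

```py
from pyspark.sql import SparkSession

# Connect to a (hypothetical) Spark Connect server
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.range(5)
table = df.toArrow()  # served internally via SparkSession.client.to_table()
print(table.num_rows)  # 5
```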

### Does this PR introduce _any_ user-facing change?

- It adds a DataFrame method `toArrow()` to the PySpark SQL DataFrame API.
- It adds a new argument `empty_list_if_zero_records` to the experimental DataFrame method `_collect_as_arrow()` with a default value which preserves the method's existing behavior.
- It exposes `toArrow()` for Spark Connect, via the existing `SparkSession.client.to_table()` method.
- It does not introduce any other user-facing changes.

### How was this patch tested?
This adds a new test and a new helper function for the test in `pyspark/sql/tests/test_arrow.py`.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#45481 from ianmcook/SPARK-47365.

Lead-authored-by: Ian Cook <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
2 people authored and JacobZheng0927 committed May 11, 2024
1 parent 98eadca commit 1388434
Showing 8 changed files with 169 additions and 20 deletions.
18 changes: 18 additions & 0 deletions examples/src/main/python/sql/arrow.py
@@ -33,6 +33,22 @@
require_minimum_pyarrow_version()


def dataframe_to_arrow_table_example(spark: SparkSession) -> None:
import pyarrow as pa # noqa: F401
from pyspark.sql.functions import rand

# Create a Spark DataFrame
df = spark.range(100).drop("id").withColumns({"0": rand(), "1": rand(), "2": rand()})

# Convert the Spark DataFrame to a PyArrow Table
table = df.select("*").toArrow()

print(table.schema)
# 0: double not null
# 1: double not null
# 2: double not null


def dataframe_with_arrow_example(spark: SparkSession) -> None:
import numpy as np
import pandas as pd
@@ -302,6 +318,8 @@ def arrow_slen(s): # type: ignore[no-untyped-def]
.appName("Python Arrow-in-Spark example") \
.getOrCreate()

print("Running Arrow conversion example: DataFrame to Table")
dataframe_to_arrow_table_example(spark)
print("Running Pandas to/from conversion example")
dataframe_with_arrow_example(spark)
print("Running pandas_udf example: Series to Frame")
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -109,6 +109,7 @@ DataFrame
DataFrame.tail
DataFrame.take
DataFrame.to
DataFrame.toArrow
DataFrame.toDF
DataFrame.toJSON
DataFrame.toLocalIterator
49 changes: 33 additions & 16 deletions python/docs/source/user_guide/sql/arrow_pandas.rst
@@ -39,6 +39,20 @@ is installed and available on all cluster nodes.
You can install it using pip or conda from the conda-forge channel. See PyArrow
`installation <https://arrow.apache.org/docs/python/install.html>`_ for details.

Conversion to Arrow Table
-------------------------

You can call :meth:`DataFrame.toArrow` to convert a Spark DataFrame to a PyArrow Table.

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 37-49
:dedent: 4

Note that :meth:`DataFrame.toArrow` results in the collection of all records in the DataFrame to
the driver program and should be done on a small subset of the data. Not all Spark data types are
currently supported and an error can be raised if a column has an unsupported type.

Enabling for Conversion to/from Pandas
--------------------------------------

@@ -53,7 +67,7 @@ This can be controlled by ``spark.sql.execution.arrow.pyspark.fallback.enabled``

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 37-52
:lines: 53-68
:dedent: 4

Using the above optimizations with Arrow will produce the same results as when Arrow is not
@@ -90,7 +104,7 @@ specify the type hints of ``pandas.Series`` and ``pandas.DataFrame`` as below:

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 56-80
:lines: 72-96
:dedent: 4

In the following sections, it describes the combinations of the supported type hints. For simplicity,
@@ -113,7 +127,7 @@ The following example shows how to create this Pandas UDF that computes the product

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 84-114
:lines: 100-130
:dedent: 4

For detailed usage, please see :func:`pandas_udf`.
@@ -152,7 +166,7 @@ The following example shows how to create this Pandas UDF:

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 118-140
:lines: 134-156
:dedent: 4

For detailed usage, please see :func:`pandas_udf`.
@@ -174,7 +188,7 @@ The following example shows how to create this Pandas UDF:

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 144-167
:lines: 160-183
:dedent: 4

For detailed usage, please see :func:`pandas_udf`.
@@ -205,7 +219,7 @@ and window operations:

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 171-212
:lines: 187-228
:dedent: 4

.. currentmodule:: pyspark.sql.functions
@@ -270,7 +284,7 @@ in the group.

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 216-234
:lines: 232-250
:dedent: 4

For detailed usage, please see :meth:`GroupedData.applyInPandas`
@@ -288,7 +302,7 @@ The following example shows how to use :meth:`DataFrame.mapInPandas`:

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 238-249
:lines: 254-265
:dedent: 4

For detailed usage, please see :meth:`DataFrame.mapInPandas`.
@@ -327,7 +341,7 @@ The following example shows how to use ``DataFrame.groupby().cogroup().applyInPa

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 253-275
:lines: 269-291
:dedent: 4


@@ -349,7 +363,7 @@ Here's an example that demonstrates the usage of both a default, pickled Python

.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
:language: python
:lines: 279-297
:lines: 295-313
:dedent: 4

Compared to the default, pickled Python UDFs, Arrow Python UDFs provide a more coherent type coercion mechanism. UDF
@@ -421,9 +435,12 @@ be verified by the user.
Setting Arrow ``self_destruct`` for memory savings
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled`` can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas DataFrame.
This option is experimental, and some operations may fail on the resulting Pandas DataFrame due to immutable backing arrays.
Typically, you would see the error ``ValueError: buffer source array is read-only``.
Newer versions of Pandas may fix these errors by improving support for such cases.
You can work around this error by copying the column(s) beforehand.
Additionally, this conversion may be slower because it is single-threaded.
Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled``
can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a
Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas
DataFrame. This option can also save memory when creating a PyArrow Table via ``toArrow``.
This option is experimental. When used with ``toPandas``, some operations may fail on the resulting
Pandas DataFrame due to immutable backing arrays. Typically, you would see the error
``ValueError: buffer source array is read-only``. Newer versions of Pandas may fix these errors by
improving support for such cases. You can work around this error by copying the column(s)
beforehand. Additionally, this conversion may be slower because it is single-threaded.
4 changes: 4 additions & 0 deletions python/pyspark/sql/classic/dataframe.py
@@ -74,6 +74,7 @@

if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
import pyarrow as pa
from pyspark.core.rdd import RDD
from pyspark.core.context import SparkContext
from pyspark._typing import PrimitiveType
@@ -1825,6 +1826,9 @@ def mapInArrow(
) -> ParentDataFrame:
return PandasMapOpsMixin.mapInArrow(self, func, schema, barrier, profile)

def toArrow(self) -> "pa.Table":
return PandasConversionMixin.toArrow(self)

def toPandas(self) -> "PandasDataFrameLike":
return PandasConversionMixin.toPandas(self)

4 changes: 4 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -1768,6 +1768,10 @@ def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
assert table is not None
return (table, schema)

def toArrow(self) -> "pa.Table":
table, _ = self._to_table()
return table

def toPandas(self) -> "PandasDataFrameLike":
query = self._plan.to_proto(self._session.client)
return self._session.client.to_pandas(query, self._plan.observations)
30 changes: 30 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -44,6 +44,7 @@

if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
import pyarrow as pa
from pyspark.core.context import SparkContext
from pyspark.core.rdd import RDD
from pyspark._typing import PrimitiveType
@@ -1200,6 +1201,7 @@ def collect(self) -> List[Row]:
DataFrame.take : Returns the first `n` rows.
DataFrame.head : Returns the first `n` rows.
DataFrame.toPandas : Returns the data as a pandas DataFrame.
DataFrame.toArrow : Returns the data as a PyArrow Table.
Notes
-----
@@ -6213,6 +6215,34 @@ def mapInArrow(
"""
...

@dispatch_df_method
def toArrow(self) -> "pa.Table":
"""
Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``.
This is only available if PyArrow is installed and available.
.. versionadded:: 4.0.0
Notes
-----
This method should only be used if the resulting PyArrow ``pyarrow.Table`` is
expected to be small, as all the data is loaded into the driver's memory.
This API is a developer API.
Examples
--------
>>> df.toArrow() # doctest: +SKIP
pyarrow.Table
age: int64
name: string
----
age: [[2,5]]
name: [["Alice","Bob"]]
"""
...

def toPandas(self) -> "PandasDataFrameLike":
"""
Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
48 changes: 44 additions & 4 deletions python/pyspark/sql/pandas/conversion.py
@@ -225,15 +225,48 @@ def toPandas(self) -> "PandasDataFrameLike":
else:
return pdf

def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch"]:
def toArrow(self) -> "pa.Table":
from pyspark.sql.dataframe import DataFrame

assert isinstance(self, DataFrame)

jconf = self.sparkSession._jconf

from pyspark.sql.pandas.types import to_arrow_schema
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

require_minimum_pyarrow_version()
to_arrow_schema(self.schema)

import pyarrow as pa

self_destruct = jconf.arrowPySparkSelfDestructEnabled()
batches = self._collect_as_arrow(
split_batches=self_destruct, empty_list_if_zero_records=False
)
table = pa.Table.from_batches(batches)
# Ensure only the table has a reference to the batches, so that
# self_destruct (if enabled) is effective
del batches
return table

def _collect_as_arrow(
self,
split_batches: bool = False,
empty_list_if_zero_records: bool = True,
) -> List["pa.RecordBatch"]:
"""
Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
Returns all records as a list of Arrow RecordBatches. PyArrow must be installed
and available on driver and worker Python environments.
This is an experimental feature.
:param split_batches: split batches such that each column is in its own allocation, so
that the selfDestruct optimization is effective; default False.
:param empty_list_if_zero_records: If True (the default), returns an empty list if the
result has 0 records. Otherwise, returns a list of length 1 containing an empty
Arrow RecordBatch which includes the schema.
.. note:: Experimental.
"""
from pyspark.sql.dataframe import DataFrame
@@ -282,8 +315,15 @@ def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch
batches = results[:-1]
batch_order = results[-1]

# Re-order the batch list using the correct order
return [batches[i] for i in batch_order]
if len(batches) or empty_list_if_zero_records:
# Re-order the batch list using the correct order
return [batches[i] for i in batch_order]
else:
from pyspark.sql.pandas.types import to_arrow_schema

schema = to_arrow_schema(self.schema)
empty_arrays = [pa.array([], type=field.type) for field in schema]
return [pa.RecordBatch.from_arrays(empty_arrays, schema=schema)]


class SparkConversionMixin:
35 changes: 35 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -179,6 +179,35 @@ def create_pandas_data_frame(self):
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
return pd.DataFrame(data=data_dict)

def create_arrow_table(self):
import pyarrow as pa
import pyarrow.compute as pc

data_dict = {}
for j, name in enumerate(self.schema.names):
data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
t = pa.Table.from_pydict(data_dict)
# convert these to Arrow types
new_schema = t.schema.set(
t.schema.get_field_index("2_int_t"), pa.field("2_int_t", pa.int32())
)
new_schema = new_schema.set(
new_schema.get_field_index("4_float_t"), pa.field("4_float_t", pa.float32())
)
new_schema = new_schema.set(
new_schema.get_field_index("6_decimal_t"),
pa.field("6_decimal_t", pa.decimal128(38, 18)),
)
t = t.cast(new_schema)
# convert timestamp to local timezone
timezone = self.spark.conf.get("spark.sql.session.timeZone")
t = t.set_column(
t.schema.get_field_index("8_timestamp_t"),
"8_timestamp_t",
pc.assume_timezone(t["8_timestamp_t"], timezone),
)
return t

@property
def create_np_arrs(self):
import numpy as np
@@ -339,6 +368,12 @@ def test_pandas_round_trip(self):
pdf_arrow = df.toPandas()
assert_frame_equal(pdf_arrow, pdf)

def test_arrow_round_trip(self):
t_in = self.create_arrow_table()
df = self.spark.createDataFrame(self.data, schema=self.schema)
t_out = df.toArrow()
self.assertTrue(t_out.equals(t_in))

def test_pandas_self_destruct(self):
import pyarrow as pa

