[Data] Prevent from_pandas from combining input blocks (#46363)
Originally, the number of blocks output by from_pandas equaled the number of input DataFrames (i.e., each input DataFrame became one block). For consistency with how we treat other inputs, #44937 changed the behavior so that each output block is sized to the target block size. This meant that you could pass many DataFrames as input yet from_pandas would output only one block.

The change is problematic because many users do something like from_pandas(np.array_split(metadata, num_blocks)) to get better performance, and after #44937 the array_split is pointless. So, this PR reverts the change.
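To illustrate the pattern (a minimal sketch, not from the PR; metadata and num_blocks are placeholder names, and a local Ray session is assumed):

import numpy as np
import pandas as pd
import ray

# Split one large DataFrame into num_blocks pieces so that from_pandas
# emits one block per piece, which is the behavior this PR restores.
metadata = pd.DataFrame({"a": range(1000)})
num_blocks = 4

ds = ray.data.from_pandas(np.array_split(metadata, num_blocks))
assert ds.materialize().num_blocks() == num_blocks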

Signed-off-by: Balaji Veeramani <[email protected]>
bveeramani authored Jul 9, 2024
1 parent c9b14d7 commit 5874960
Showing 3 changed files with 10 additions and 62 deletions.
30 changes: 0 additions & 30 deletions python/ray/data/_internal/pandas_block.py
@@ -630,33 +630,3 @@ def gen():
 
     def block_type(self) -> BlockType:
         return BlockType.PANDAS
-
-
-def _estimate_dataframe_size(df: "pandas.DataFrame") -> int:
-    """Estimate the size of a pandas DataFrame.
-
-    This function is necessary because `DataFrame.memory_usage` doesn't count values in
-    columns with `dtype=object`.
-
-    The runtime complexity is linear in the number of values, so don't use this in
-    performance-critical code.
-
-    Args:
-        df: The DataFrame to estimate the size of.
-
-    Returns:
-        The estimated size of the DataFrame in bytes.
-    """
-    size = 0
-    for column in df.columns:
-        if df[column].dtype == object:
-            for item in df[column]:
-                if isinstance(item, str):
-                    size += len(item)
-                elif isinstance(item, np.ndarray):
-                    size += item.nbytes
-                else:
-                    size += 8  # pandas assumes object values are 8 bytes.
-        else:
-            size += df[column].nbytes
-    return size
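For context on the deleted helper: with the default deep=False, DataFrame.memory_usage counts only the 8-byte object pointers for dtype=object columns, which is why the helper walked the values itself. A small illustration (not part of the diff):

import pandas as pd

# Ten strings of 1,000 characters each.
s = pd.Series(["x" * 1000] * 10)

print(s.memory_usage(index=False))             # ~80 bytes: pointers only
print(s.memory_usage(index=False, deep=True))  # also counts string payloads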
26 changes: 10 additions & 16 deletions python/ray/data/read_api.py
@@ -1,6 +1,5 @@
 import collections
 import logging
-import math
 import os
 import warnings
 from typing import (
@@ -48,7 +47,6 @@
 )
 from ray.data._internal.logical.operators.read_operator import Read
 from ray.data._internal.logical.optimizers import LogicalPlan
-from ray.data._internal.pandas_block import _estimate_dataframe_size
 from ray.data._internal.plan import ExecutionPlan
 from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data._internal.stats import DatasetStats
@@ -2388,7 +2386,7 @@ def from_pandas(
     Create a Ray Dataset from a list of Pandas DataFrames.
 
         >>> ray.data.from_pandas([df, df])
-        MaterializedDataset(num_blocks=1, num_rows=6, schema={a: int64, b: int64})
+        MaterializedDataset(num_blocks=2, num_rows=6, schema={a: int64, b: int64})
 
     Args:
         dfs: A pandas dataframe or a list of pandas dataframes.
@@ -2405,24 +2403,20 @@ def from_pandas(
     if isinstance(dfs, pd.DataFrame):
         dfs = [dfs]
 
-    context = DataContext.get_current()
-    num_blocks = override_num_blocks
-    if num_blocks is None:
-        total_size = sum(_estimate_dataframe_size(df) for df in dfs)
-        num_blocks = max(math.ceil(total_size / context.target_max_block_size), 1)
-
-    if len(dfs) > 1:
-        # I assume most users pass a single DataFrame as input. For simplicity, I'm
-        # concatenating DataFrames, even though it's not efficient.
-        ary = pd.concat(dfs, axis=0)
-    else:
-        ary = dfs[0]
-    dfs = np.array_split(ary, num_blocks)
+    if override_num_blocks is not None:
+        if len(dfs) > 1:
+            # I assume most users pass a single DataFrame as input. For simplicity, I'm
+            # concatenating DataFrames, even though it's not efficient.
+            ary = pd.concat(dfs, axis=0)
+        else:
+            ary = dfs[0]
+        dfs = np.array_split(ary, override_num_blocks)
 
     from ray.air.util.data_batch_conversion import (
         _cast_ndarray_columns_to_tensor_extension,
     )
 
+    context = DataContext.get_current()
     if context.enable_tensor_extension_casting:
         dfs = [_cast_ndarray_columns_to_tensor_extension(df.copy()) for df in dfs]
 
16 changes: 0 additions & 16 deletions python/ray/data/tests/test_pandas.py
@@ -48,22 +48,6 @@ def test_from_pandas(ray_start_regular_shared, enable_pandas_block):
     ctx.enable_pandas_block = old_enable_pandas_block
 
 
-def test_from_pandas_default_num_blocks(ray_start_regular_shared, restore_data_context):
-    ray.data.DataContext.get_current().target_max_block_size = 8 * 1024 * 1024  # 8 MiB
-
-    record = {"number": 0, "string": "\0"}
-    record_size_bytes = 8 + 1  # 8 bytes for int64 and 1 byte for char
-    dataframe_size_bytes = 64 * 1024 * 1024  # 64 MiB
-    num_records = int(dataframe_size_bytes / record_size_bytes)
-    df = pd.DataFrame.from_records([record] * num_records)
-
-    ds = ray.data.from_pandas(df)
-
-    # If the target block size is 8 MiB, the DataFrame should be split into
-    # 64 MiB / (8 MiB / block) = 8 blocks.
-    assert ds.materialize().num_blocks() == 8
-
-
 @pytest.mark.parametrize("num_inputs", [1, 2])
 def test_from_pandas_override_num_blocks(num_inputs, ray_start_regular_shared):
     df = pd.DataFrame({"number": [0]})
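A hypothetical replacement test for the restored default might look like the following (a sketch, not part of this PR; it reuses the module's pd/ray imports and the ray_start_regular_shared fixture):

def test_from_pandas_one_block_per_dataframe(ray_start_regular_shared):
    # Sketch: after the revert, the default number of output blocks
    # equals the number of input DataFrames.
    dfs = [pd.DataFrame({"number": [i]}) for i in range(3)]
    ds = ray.data.from_pandas(dfs)
    assert ds.materialize().num_blocks() == len(dfs)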
