[data] [strict mode] Allow returning lists instead of arrays for numpy batches #34734

Merged (5 commits) on Apr 25, 2023
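In short: after this change, a `map_batches` UDF may return plain Python `list` values in a `dict` batch, and Ray Data casts them to numpy arrays via `np.array()`. A minimal usage sketch, mirroring the new test in `test_strict_mode.py` below (the test runs under the `enable_strict_mode` fixture, so strict-mode semantics are assumed here):

```python
import ray

# Each dict column may now be a plain list; Ray Data converts it to a
# numpy array during batch validation, so no manual np.array() is needed.
ds = ray.data.range(1).map_batches(lambda batch: {"id": [0, 1, 2, 3]})
print(ds.take_batch()["id"].tolist())  # [0, 1, 2, 3]
```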
20 changes: 16 additions & 4 deletions python/ray/data/_internal/planner/map_batches.py

@@ -51,17 +51,29 @@ def validate_batch(batch: Block) -> None:
             )

         if isinstance(batch, collections.abc.Mapping):
-            for key, value in batch.items():
-                if not isinstance(value, np.ndarray):
+            for key, value in list(batch.items()):
Review comment (Contributor):
I'm curious why you need the `list(` here -- shouldn't `.items()` always return an object that can be iterated on with the `for key, value in batch.items()` pattern?

Reply (Contributor, Author):
This was to avoid mutating the dict we are iterating over. Though maybe in this case it is OK, since it doesn't change the keys.
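(An aside illustrating the hazard under discussion, not part of the diff: changing a dict's key set while iterating over a live `.items()` view raises `RuntimeError`, whereas overwriting values for existing keys, which is all `validate_batch` does, is safe either way.)

```python
d = {"a": 1}

# Adding or removing keys mid-iteration raises RuntimeError.
try:
    for k, v in d.items():
        d[k + "_new"] = v
except RuntimeError as e:
    print(e)  # dictionary changed size during iteration

# Overwriting values for existing keys does not change the dict's size,
# so it works even without the list() snapshot.
for k, v in d.items():
    d[k] = v * 2
```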

+                if not isinstance(value, (np.ndarray, list)):
                     raise ValueError(
                         f"Error validating {_truncated_repr(batch)}: "
                         "The `fn` you passed to `map_batches` returned a "
                         f"`dict`. `map_batches` expects all `dict` values "
-                        f"to be of type `numpy.ndarray`, but the value "
+                        f"to be `list` or `np.ndarray` type, but the value "
                         f"corresponding to key {key!r} is of type "
                         f"{type(value)}. To fix this issue, convert "
-                        f"the {type(value)} to a `numpy.ndarray`."
+                        f"the {type(value)} to a `np.ndarray`."
                     )
+                if isinstance(value, list):
+                    # Try to convert list values into a numpy array via
+                    # np.array(), so users don't need to manually cast.
+                    # NOTE: we don't cast generic iterables, since types like
+                    # `str` are also Iterable.
+                    try:
+                        batch[key] = np.array(value)
+                    except Exception:
+                        raise ValueError(
+                            "Failed to convert column values to numpy array: "
+                            f"({_truncated_repr(value)})."
+                        )

     def process_next_batch(batch: DataBatch) -> Iterator[Block]:
         # Apply UDF.
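Worth noting how the `np.array()` cast behaves for the cases exercised by the new test below (a standalone sketch, independent of the diff): homogeneous lists become typed arrays, mixed lists fall back to `dtype=object`, and strings never reach the cast because `str` is not a `list`.

```python
import numpy as np

print(np.array([0, 1, 2, 3]).dtype)         # int64 (platform-dependent)
print(np.array([0, 1, 2, object()]).dtype)  # object

# `str` is Iterable but not a `list`, so {"id": "string"} still fails the
# isinstance check above and raises ValueError instead of being cast.
print(isinstance("string", (np.ndarray, list)))  # False
```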
2 changes: 1 addition & 1 deletion python/ray/data/block.py
@@ -439,7 +439,7 @@ def batch_to_block(batch: DataBatch) -> Block:
             return ArrowBlockAccessor.numpy_to_block(
                 batch, passthrough_arrow_not_implemented_errors=True
             )
-        except pa.ArrowNotImplementedError:
+        except (pa.ArrowNotImplementedError, pa.ArrowInvalid):
             import pandas as pd

             # TODO(ekl) once we support Python objects within Arrow blocks, we
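The extra `pa.ArrowInvalid` matters because Arrow can reject the object-dtype arrays produced by the new list conversion, at which point `batch_to_block` falls back to a pandas-based block. A rough illustration of the failure mode, assuming pyarrow's usual type-inference error for arbitrary Python objects:

```python
import numpy as np
import pyarrow as pa

class UserObj:
    pass

arr = np.array([0, 1, 2, UserObj()])  # dtype=object
try:
    pa.array(arr)
except pa.ArrowInvalid as e:
    # Arrow cannot infer a type for UserObj, so the caller falls back
    # to building a pandas DataFrame block instead.
    print(type(e).__name__)  # ArrowInvalid
```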
20 changes: 20 additions & 0 deletions python/ray/data/tests/test_strict_mode.py
@@ -80,6 +80,26 @@ def test_strict_map_output(ray_start_regular_shared, enable_strict_mode):
         ds.map(lambda x: UserDict({"x": object()})).materialize()


+def test_strict_convert_map_output(ray_start_regular_shared, enable_strict_mode):
+    ds = ray.data.range(1).map_batches(lambda x: {"id": [0, 1, 2, 3]}).materialize()
+    assert ds.take_batch()["id"].tolist() == [0, 1, 2, 3]
+
+    with pytest.raises(ValueError):
+        # Strings not converted into array.
+        ray.data.range(1).map_batches(lambda x: {"id": "string"}).materialize()
+
+    class UserObj:
+        def __eq__(self, other):
+            return isinstance(other, UserObj)
+
+    ds = (
+        ray.data.range(1)
+        .map_batches(lambda x: {"id": [0, 1, 2, UserObj()]})
+        .materialize()
+    )
+    assert ds.take_batch()["id"].tolist() == [0, 1, 2, UserObj()]
+
+
 def test_strict_default_batch_format(ray_start_regular_shared, enable_strict_mode):
     ds = ray.data.range(1)
