diff --git a/python/ray/data/_internal/planner/map_batches.py b/python/ray/data/_internal/planner/map_batches.py index 1a03d3b6e45e2..a0d528bd85190 100644 --- a/python/ray/data/_internal/planner/map_batches.py +++ b/python/ray/data/_internal/planner/map_batches.py @@ -51,17 +51,29 @@ def validate_batch(batch: Block) -> None: ) if isinstance(batch, collections.abc.Mapping): - for key, value in batch.items(): - if not isinstance(value, np.ndarray): + for key, value in list(batch.items()): + if not isinstance(value, (np.ndarray, list)): raise ValueError( f"Error validating {_truncated_repr(batch)}: " "The `fn` you passed to `map_batches` returned a " f"`dict`. `map_batches` expects all `dict` values " - f"to be of type `numpy.ndarray`, but the value " + f"to be `list` or `np.ndarray` type, but the value " f"corresponding to key {key!r} is of type " f"{type(value)}. To fix this issue, convert " - f"the {type(value)} to a `numpy.ndarray`." + f"the {type(value)} to a `np.ndarray`." ) + if isinstance(value, list): + # Try to convert list values into an numpy array via + # np.array(), so users don't need to manually cast. + # NOTE: we don't cast generic iterables, since types like + # `str` are also Iterable. + try: + batch[key] = np.array(value) + except Exception: + raise ValueError( + "Failed to convert column values to numpy array: " + f"({_truncated_repr(value)})." + ) def process_next_batch(batch: DataBatch) -> Iterator[Block]: # Apply UDF. diff --git a/python/ray/data/block.py b/python/ray/data/block.py index e47c4a8f241db..d525e0f07185c 100644 --- a/python/ray/data/block.py +++ b/python/ray/data/block.py @@ -439,7 +439,7 @@ def batch_to_block(batch: DataBatch) -> Block: return ArrowBlockAccessor.numpy_to_block( batch, passthrough_arrow_not_implemented_errors=True ) - except pa.ArrowNotImplementedError: + except (pa.ArrowNotImplementedError, pa.ArrowInvalid): import pandas as pd # TODO(ekl) once we support Python objects within Arrow blocks, we diff --git a/python/ray/data/tests/test_strict_mode.py b/python/ray/data/tests/test_strict_mode.py index 19eb1853f5eb4..bbbf3c6b1b238 100644 --- a/python/ray/data/tests/test_strict_mode.py +++ b/python/ray/data/tests/test_strict_mode.py @@ -80,6 +80,26 @@ def test_strict_map_output(ray_start_regular_shared, enable_strict_mode): ds.map(lambda x: UserDict({"x": object()})).materialize() +def test_strict_convert_map_output(ray_start_regular_shared, enable_strict_mode): + ds = ray.data.range(1).map_batches(lambda x: {"id": [0, 1, 2, 3]}).materialize() + assert ds.take_batch()["id"].tolist() == [0, 1, 2, 3] + + with pytest.raises(ValueError): + # Strings not converted into array. + ray.data.range(1).map_batches(lambda x: {"id": "string"}).materialize() + + class UserObj: + def __eq__(self, other): + return isinstance(other, UserObj) + + ds = ( + ray.data.range(1) + .map_batches(lambda x: {"id": [0, 1, 2, UserObj()]}) + .materialize() + ) + assert ds.take_batch()["id"].tolist() == [0, 1, 2, UserObj()] + + def test_strict_default_batch_format(ray_start_regular_shared, enable_strict_mode): ds = ray.data.range(1)