From 9b418ed5f9de7bd6fc1c7c594db8347142f3f70a Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 24 Apr 2023 16:54:27 -0700 Subject: [PATCH 1/4] cast Signed-off-by: Eric Liang --- python/ray/data/_internal/planner/map_batches.py | 15 ++++++++++++--- python/ray/data/block.py | 2 +- python/ray/data/tests/test_strict_mode.py | 12 ++++++++++++ 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/python/ray/data/_internal/planner/map_batches.py b/python/ray/data/_internal/planner/map_batches.py index 1a03d3b6e45e..718794d05df7 100644 --- a/python/ray/data/_internal/planner/map_batches.py +++ b/python/ray/data/_internal/planner/map_batches.py @@ -1,6 +1,6 @@ import collections from types import GeneratorType -from typing import Callable, Iterator, Optional +from typing import Callable, Iterable, Iterator, Optional from ray.data._internal.block_batching import batch_blocks from ray.data._internal.execution.interfaces import TaskContext @@ -51,8 +51,8 @@ def validate_batch(batch: Block) -> None: ) if isinstance(batch, collections.abc.Mapping): - for key, value in batch.items(): - if not isinstance(value, np.ndarray): + for key, value in list(batch.items()): + if not isinstance(value, (np.ndarray, Iterable)): raise ValueError( f"Error validating {_truncated_repr(batch)}: " "The `fn` you passed to `map_batches` returned a " @@ -62,6 +62,15 @@ def validate_batch(batch: Block) -> None: f"{type(value)}. To fix this issue, convert " f"the {type(value)} to a `numpy.ndarray`." ) + if not isinstance(value, np.ndarray): + # Try to convert iterable values into an numpy array via + # np.array(), so users don't need to manually cast. + try: + batch[key] = np.array(value) + except Exception: + raise ValueError( + "Failed to convert column values to numpy array: " + f"({_truncated_repr(value)}).") def process_next_batch(batch: DataBatch) -> Iterator[Block]: # Apply UDF. diff --git a/python/ray/data/block.py b/python/ray/data/block.py index e213ad33887b..08101a060c80 100644 --- a/python/ray/data/block.py +++ b/python/ray/data/block.py @@ -414,7 +414,7 @@ def batch_to_block(batch: DataBatch) -> Block: return ArrowBlockAccessor.numpy_to_block( batch, passthrough_arrow_not_implemented_errors=True ) - except pa.ArrowNotImplementedError: + except (pa.ArrowNotImplementedError, pa.ArrowInvalid): import pandas as pd # TODO(ekl) once we support Python objects within Arrow blocks, we diff --git a/python/ray/data/tests/test_strict_mode.py b/python/ray/data/tests/test_strict_mode.py index 100097c91f6b..baefd5193629 100644 --- a/python/ray/data/tests/test_strict_mode.py +++ b/python/ray/data/tests/test_strict_mode.py @@ -84,6 +84,18 @@ def test_strict_map_output(ray_start_regular_shared): ds.map(lambda x: UserDict({"x": object()})).materialize() +def test_strict_convert_map_output(ray_start_regular_shared): + ds = ray.data.range(1).map_batches(lambda x: {"id": [0, 1, 2, 3]}).materialize() + assert ds.take_batch()["id"].tolist() == [0, 1, 2, 3] + + class UserObj: + def __eq__(self, other): + return isinstance(other, UserObj) + + ds = ray.data.range(1).map_batches(lambda x: {"id": [0, 1, 2, UserObj()]}).materialize() + assert ds.take_batch()["id"].tolist() == [0, 1, 2, UserObj()] + + def test_strict_default_batch_format(ray_start_regular_shared): ds = ray.data.range(1) From eca85a577fc1b9b77dd8050018f149171562f13f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 24 Apr 2023 16:54:37 -0700 Subject: [PATCH 2/4] lint Signed-off-by: Eric Liang --- python/ray/data/_internal/planner/map_batches.py | 3 ++- python/ray/data/tests/test_strict_mode.py | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/planner/map_batches.py b/python/ray/data/_internal/planner/map_batches.py index 718794d05df7..4a2ebd9e1ad4 100644 --- a/python/ray/data/_internal/planner/map_batches.py +++ b/python/ray/data/_internal/planner/map_batches.py @@ -70,7 +70,8 @@ def validate_batch(batch: Block) -> None: except Exception: raise ValueError( "Failed to convert column values to numpy array: " - f"({_truncated_repr(value)}).") + f"({_truncated_repr(value)})." + ) def process_next_batch(batch: DataBatch) -> Iterator[Block]: # Apply UDF. diff --git a/python/ray/data/tests/test_strict_mode.py b/python/ray/data/tests/test_strict_mode.py index baefd5193629..cdb80907c9cd 100644 --- a/python/ray/data/tests/test_strict_mode.py +++ b/python/ray/data/tests/test_strict_mode.py @@ -92,7 +92,11 @@ class UserObj: def __eq__(self, other): return isinstance(other, UserObj) - ds = ray.data.range(1).map_batches(lambda x: {"id": [0, 1, 2, UserObj()]}).materialize() + ds = ( + ray.data.range(1) + .map_batches(lambda x: {"id": [0, 1, 2, UserObj()]}) + .materialize() + ) assert ds.take_batch()["id"].tolist() == [0, 1, 2, UserObj()] From 9ac18b304c896547c3fde1600cb18040ab8700c9 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 24 Apr 2023 16:59:50 -0700 Subject: [PATCH 3/4] fix strings Signed-off-by: Eric Liang --- python/ray/data/_internal/planner/map_batches.py | 12 +++++++----- python/ray/data/tests/test_strict_mode.py | 4 ++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/planner/map_batches.py b/python/ray/data/_internal/planner/map_batches.py index 4a2ebd9e1ad4..6d6282425228 100644 --- a/python/ray/data/_internal/planner/map_batches.py +++ b/python/ray/data/_internal/planner/map_batches.py @@ -1,6 +1,6 @@ import collections from types import GeneratorType -from typing import Callable, Iterable, Iterator, Optional +from typing import Callable, Iterator, Optional from ray.data._internal.block_batching import batch_blocks from ray.data._internal.execution.interfaces import TaskContext @@ -52,19 +52,21 @@ def validate_batch(batch: Block) -> None: if isinstance(batch, collections.abc.Mapping): for key, value in list(batch.items()): - if not isinstance(value, (np.ndarray, Iterable)): + if not isinstance(value, (np.ndarray, list)): raise ValueError( f"Error validating {_truncated_repr(batch)}: " "The `fn` you passed to `map_batches` returned a " f"`dict`. `map_batches` expects all `dict` values " - f"to be of type `numpy.ndarray`, but the value " + f"to be `list` or `np.ndarray` type, but the value " f"corresponding to key {key!r} is of type " f"{type(value)}. To fix this issue, convert " - f"the {type(value)} to a `numpy.ndarray`." + f"the {type(value)} to a `np.ndarray`." ) if not isinstance(value, np.ndarray): - # Try to convert iterable values into an numpy array via + # Try to convert list values into an numpy array via # np.array(), so users don't need to manually cast. + # NOTE: we don't cast generic iterables, since types like + # `str` are also Iterable. try: batch[key] = np.array(value) except Exception: diff --git a/python/ray/data/tests/test_strict_mode.py b/python/ray/data/tests/test_strict_mode.py index cdb80907c9cd..99dba6b1f5a6 100644 --- a/python/ray/data/tests/test_strict_mode.py +++ b/python/ray/data/tests/test_strict_mode.py @@ -88,6 +88,10 @@ def test_strict_convert_map_output(ray_start_regular_shared): ds = ray.data.range(1).map_batches(lambda x: {"id": [0, 1, 2, 3]}).materialize() assert ds.take_batch()["id"].tolist() == [0, 1, 2, 3] + with pytest.raises(ValueError): + # Strings not converted into array. + ray.data.range(1).map_batches(lambda x: {"id": "string"}).materialize() + class UserObj: def __eq__(self, other): return isinstance(other, UserObj) From b5a6b7f66671579c8767b816c6de8a405cfc3eb0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 24 Apr 2023 21:06:45 -0700 Subject: [PATCH 4/4] check list type Signed-off-by: Eric Liang --- python/ray/data/_internal/planner/map_batches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/planner/map_batches.py b/python/ray/data/_internal/planner/map_batches.py index 6d6282425228..a0d528bd8519 100644 --- a/python/ray/data/_internal/planner/map_batches.py +++ b/python/ray/data/_internal/planner/map_batches.py @@ -62,7 +62,7 @@ def validate_batch(batch: Block) -> None: f"{type(value)}. To fix this issue, convert " f"the {type(value)} to a `np.ndarray`." ) - if not isinstance(value, np.ndarray): + if isinstance(value, list): # Try to convert list values into an numpy array via # np.array(), so users don't need to manually cast. # NOTE: we don't cast generic iterables, since types like