Commit cce0f3d

[data] [streaming] Dataset.cache() doesn't work properly for streaming executor (ray-project#33713)

We didn't have a test for the caching behavior, so enabling the streaming executor broke caching unnoticed: the cache implicitly relied on Dataset executing all operations eagerly.

Signed-off-by: Jack He <[email protected]>
ericl authored and ProjectsByJackHe committed May 4, 2023
1 parent 03d989a commit cce0f3d
Showing 4 changed files with 36 additions and 24 deletions.
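For context, a minimal sketch of the contract this commit restores. This is a hypothetical repro, not code from the commit, assuming a local Ray installation from this era with both execution backends available:

```python
# Hypothetical repro sketch (not from the commit), assuming a local Ray
# installation from this era.
import ray

ray.init()

ds = ray.data.range(8).map(lambda x: x)
ds = ds.cache()  # executes the plan once and keeps the output blocks

# Before this fix, each take_all() under the streaming executor re-ran
# the map stage; afterwards, every iteration reads the cached blocks.
for _ in range(10):
    ds.take_all()
```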
2 changes: 1 addition & 1 deletion python/ray/data/_internal/plan.py
@@ -497,7 +497,7 @@ def execute_to_iterator(
         """
 
         ctx = DatasetContext.get_current()
-        if not ctx.use_streaming_executor:
+        if not ctx.use_streaming_executor or self.has_computed_output():
             return (
                 self.execute(allow_clear_input_blocks, force_read).iter_blocks(),
                 self._snapshot_stats,
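The one-line fix routes consumption through the bulk path whenever the plan already holds a fully computed snapshot. A hedged sketch of the kind of check `has_computed_output` performs — the attribute names and body here are assumptions modeled on plan.py conventions, not the actual implementation:

```python
# Illustrative sketch only: attribute names and logic are assumptions,
# not the real plan.py method body.
class ExecutionPlanSketch:
    def __init__(self):
        self._snapshot_blocks = None      # blocks from the last execution, if any
        self._stages_after_snapshot = []  # stages appended since that execution

    def has_computed_output(self) -> bool:
        # The output is already materialized when an executed snapshot
        # exists and no further stages were appended after it.
        return (
            self._snapshot_blocks is not None
            and not self._stages_after_snapshot
        )
```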
27 changes: 27 additions & 0 deletions python/ray/data/tests/test_dataset_consumption.py
@@ -180,6 +180,33 @@ def test_empty_dataset(ray_start_regular_shared):
     assert ds.count() == 0
 
 
+def test_cache_dataset(ray_start_regular_shared):
+    @ray.remote
+    class Counter:
+        def __init__(self):
+            self.i = 0
+
+        def inc(self):
+            print("INC")
+            self.i += 1
+            return self.i
+
+    c = Counter.remote()
+
+    def inc(x):
+        ray.get(c.inc.remote())
+        return x
+
+    ds = ray.data.range(1)
+    ds = ds.map(inc)
+    ds = ds.cache()
+
+    for _ in range(10):
+        ds.take_all()
+
+    assert ray.get(c.inc.remote()) == 2
+
+
 def test_schema(ray_start_regular_shared):
     ds = ray.data.range(10, parallelism=10)
     ds2 = ray.data.range_table(10, parallelism=10)
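Why the assertion expects 2: `range(1)` yields a single row, so the cached execution runs `map(inc)` exactly once, leaving the counter at 1; the direct `c.inc.remote()` call in the assertion then returns 2. With broken caching, each of the ten `take_all()` calls would re-run the map stage and the final value would be 11.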
8 changes: 4 additions & 4 deletions python/ray/data/tests/test_dataset_parquet.py
@@ -156,7 +156,7 @@ def test_parquet_read_basic(ray_start_regular_shared, fs, data_path):
 
     # Forces a data read.
     values = [[s["one"], s["two"]] for s in ds.take_all()]
-    check_num_computed(ds, 2, 1)
+    check_num_computed(ds, 2, 2)
     assert sorted(values) == [
         [1, "a"],
         [2, "b"],
@@ -466,7 +466,7 @@ def test_parquet_read_partitioned(ray_start_regular_shared, fs, data_path):
 
     # Forces a data read.
     values = [[s["one"], s["two"]] for s in ds.take()]
-    check_num_computed(ds, 2, 1)
+    check_num_computed(ds, 2, 2)
     assert sorted(values) == [
         [1, "a"],
         [1, "b"],
@@ -550,7 +550,7 @@ def test_parquet_read_partitioned_explicit(ray_start_regular_shared, tmp_path):
 
     # Forces a data read.
     values = [[s["one"], s["two"]] for s in ds.take()]
-    check_num_computed(ds, 2, 1)
+    check_num_computed(ds, 2, 2)
     assert sorted(values) == [
         [1, "a"],
         [1, "b"],
@@ -640,7 +640,7 @@ def test_parquet_read_parallel_meta_fetch(ray_start_regular_shared, fs, data_path):
 
     # Forces a data read.
     values = [s["one"] for s in ds.take(limit=3 * num_dfs)]
-    check_num_computed(ds, parallelism, 1)
+    check_num_computed(ds, parallelism, parallelism)
     assert sorted(values) == list(range(3 * num_dfs))
 
 
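The last argument to `check_num_computed` is the expected computed-block count under the streaming executor. These expectations rise from 1 to match the bulk-execution counts, presumably because a dataset consumed through the fixed code path now retains its computed snapshot, so every block read is accounted for rather than just the first.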
23 changes: 4 additions & 19 deletions python/ray/data/tests/test_stats.py
@@ -121,24 +121,9 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
     stats = canonicalize(ds.cache().stats())
 
     if context.new_execution_backend:
-        if context.use_streaming_executor:
-            assert (
-                stats
-                == """Stage N ReadRange->MapBatches(dummy_map_batches)->Map: N/N blocks executed in T
-* Remote wall time: T min, T max, T mean, T total
-* Remote cpu time: T min, T max, T mean, T total
-* Peak heap memory usage (MiB): N min, N max, N mean
-* Output num rows: N min, N max, N mean, N total
-* Output size bytes: N min, N max, N mean, N total
-* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
-'obj_store_mem_peak': N}
-"""
-            )
-        else:
-            assert (
-                stats
-                == """Stage N ReadRange->MapBatches(dummy_map_batches): N/N blocks executed in T
+        assert (
+            stats
+            == """Stage N ReadRange->MapBatches(dummy_map_batches): N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Peak heap memory usage (MiB): N min, N max, N mean
@@ -169,7 +154,7 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
 * In user code: T
 * Total time: T
 """
-            )
+        )
     else:
         assert (
             stats
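Since `cache()` now forces bulk execution of the snapshot even when the streaming executor is enabled, `ds.cache().stats()` produces the same output on both backends; the streaming-specific expected string (with its extra `->Map` stage and object-store metrics) is dropped and the two branches collapse into one.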
