diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py
index 62dd255cc30cd..68251d0caaed3 100644
--- a/python/ray/data/_internal/plan.py
+++ b/python/ray/data/_internal/plan.py
@@ -497,7 +497,7 @@ def execute_to_iterator(
         """
 
         ctx = DatasetContext.get_current()
-        if not ctx.use_streaming_executor:
+        if not ctx.use_streaming_executor or self.has_computed_output():
             return (
                 self.execute(allow_clear_input_blocks, force_read).iter_blocks(),
                 self._snapshot_stats,
diff --git a/python/ray/data/tests/test_dataset_consumption.py b/python/ray/data/tests/test_dataset_consumption.py
index f4882dc4e743b..affa3f2ceae07 100644
--- a/python/ray/data/tests/test_dataset_consumption.py
+++ b/python/ray/data/tests/test_dataset_consumption.py
@@ -180,6 +180,33 @@ def test_empty_dataset(ray_start_regular_shared):
     assert ds.count() == 0
 
 
+def test_cache_dataset(ray_start_regular_shared):
+    @ray.remote
+    class Counter:
+        def __init__(self):
+            self.i = 0
+
+        def inc(self):
+            print("INC")
+            self.i += 1
+            return self.i
+
+    c = Counter.remote()
+
+    def inc(x):
+        ray.get(c.inc.remote())
+        return x
+
+    ds = ray.data.range(1)
+    ds = ds.map(inc)
+    ds = ds.cache()
+
+    for _ in range(10):
+        ds.take_all()
+
+    assert ray.get(c.inc.remote()) == 2
+
+
 def test_schema(ray_start_regular_shared):
     ds = ray.data.range(10, parallelism=10)
     ds2 = ray.data.range_table(10, parallelism=10)
diff --git a/python/ray/data/tests/test_dataset_parquet.py b/python/ray/data/tests/test_dataset_parquet.py
index 15345b5942c32..0d98d792ca678 100644
--- a/python/ray/data/tests/test_dataset_parquet.py
+++ b/python/ray/data/tests/test_dataset_parquet.py
@@ -156,7 +156,7 @@ def test_parquet_read_basic(ray_start_regular_shared, fs, data_path):
 
     # Forces a data read.
     values = [[s["one"], s["two"]] for s in ds.take_all()]
-    check_num_computed(ds, 2, 1)
+    check_num_computed(ds, 2, 2)
     assert sorted(values) == [
         [1, "a"],
         [2, "b"],
@@ -466,7 +466,7 @@ def test_parquet_read_partitioned(ray_start_regular_shared, fs, data_path):
 
     # Forces a data read.
     values = [[s["one"], s["two"]] for s in ds.take()]
-    check_num_computed(ds, 2, 1)
+    check_num_computed(ds, 2, 2)
     assert sorted(values) == [
         [1, "a"],
         [1, "b"],
@@ -550,7 +550,7 @@ def test_parquet_read_partitioned_explicit(ray_start_regular_shared, tmp_path):
 
     # Forces a data read.
     values = [[s["one"], s["two"]] for s in ds.take()]
-    check_num_computed(ds, 2, 1)
+    check_num_computed(ds, 2, 2)
     assert sorted(values) == [
         [1, "a"],
         [1, "b"],
@@ -640,7 +640,7 @@ def test_parquet_read_parallel_meta_fetch(ray_start_regular_shared, fs, data_pat
 
     # Forces a data read.
     values = [s["one"] for s in ds.take(limit=3 * num_dfs)]
-    check_num_computed(ds, parallelism, 1)
+    check_num_computed(ds, parallelism, parallelism)
     assert sorted(values) == list(range(3 * num_dfs))
 
 
diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py
index 0c4dc3c843679..2eeab5edb43ad 100644
--- a/python/ray/data/tests/test_stats.py
+++ b/python/ray/data/tests/test_stats.py
@@ -121,24 +121,9 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
     stats = canonicalize(ds.cache().stats())
 
     if context.new_execution_backend:
-        if context.use_streaming_executor:
-            assert (
-                stats
-                == """Stage N ReadRange->MapBatches(dummy_map_batches)->Map: N/N blocks executed in T
-* Remote wall time: T min, T max, T mean, T total
-* Remote cpu time: T min, T max, T mean, T total
-* Peak heap memory usage (MiB): N min, N max, N mean
-* Output num rows: N min, N max, N mean, N total
-* Output size bytes: N min, N max, N mean, N total
-* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
-'obj_store_mem_peak': N}
-"""
-            )
-        else:
-            assert (
-                stats
-                == """Stage N ReadRange->MapBatches(dummy_map_batches): N/N blocks executed in T
+        assert (
+            stats
+            == """Stage N ReadRange->MapBatches(dummy_map_batches): N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Peak heap memory usage (MiB): N min, N max, N mean
@@ -169,7 +154,7 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
 * In user code: T
 * Total time: T
 """
-            )
+        )
     else:
         assert (
             stats