diff --git a/grain/_src/python/dataset/transformations/map.py b/grain/_src/python/dataset/transformations/map.py
index 18e59bc3..bd1ba722 100644
--- a/grain/_src/python/dataset/transformations/map.py
+++ b/grain/_src/python/dataset/transformations/map.py
@@ -20,6 +20,7 @@
 from absl import logging
 from grain._src.core import transforms
 from grain._src.python.dataset import dataset
+from grain._src.python.dataset.transformations import prefetch
 import numpy as np
@@ -233,7 +234,14 @@ def __next__(self):
     with self._stats.record_self_time():
       if element is not None:
         if self._seed is not None:
-          _reset_rng_state(self._rng, op_seed=0, index=self._index_for_rng)
+          # Shift the index for the current worker process in case of
+          # multiprocess execution. The actual index value doesn't matter as
+          # long as it is unique for each process.
+          index_for_rng = (
+              prefetch.worker_process_index
+              + self._index_for_rng * prefetch.worker_process_count
+          )
+          _reset_rng_state(self._rng, op_seed=0, index=index_for_rng)
           element = self._map_fn(element, self._rng)
         else:
           element = self._map_fn(element)
diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py
index 5353451a..2bd5fc2e 100644
--- a/grain/_src/python/dataset/transformations/prefetch.py
+++ b/grain/_src/python/dataset/transformations/prefetch.py
@@ -41,6 +41,14 @@
 T = TypeVar("T")
 
+# Index of the current worker process and total number of processes. Set by
+# multiprocess prefetch when workers start; upstream transformations must only
+# read these during or after iterator initialization.
+# TODO: Introduce context shared by all iterators and put these
+# variables there.
+worker_process_index = 0
+worker_process_count = 1
+
 
 @typing.runtime_checkable
 class SupportsInPlaceSlicing(Protocol):
@@ -374,12 +382,16 @@ def __init__(
   def __call__(
       self, *, worker_index: int, worker_count: int
   ) -> Iterator[tuple[T, Optional[dict[str, Any]]]]:
-    # Recover from the last recorded state for the given worker.
-    worker_state = self._state[_WORKERS_STATE][str(worker_index)]
+    global worker_process_index, worker_process_count
+    worker_process_index = worker_index
+    worker_process_count = worker_count
     if worker_count > 1:
       _set_slice(self._ds, slice(worker_index, None, worker_count))
-    it = iter(self._ds)
-    it.set_state(worker_state)  # pytype: disable=attribute-error
+    it = self._ds.__iter__()
+    # Recover from the last recorded state for the given worker.
+    worker_state = self._state[_WORKERS_STATE][str(worker_index)]
+    if worker_state is not None:
+      it.set_state(worker_state)
     # Skip the required number of iterations after the last recorded state.
     for _ in range(self._state[_ITERATIONS_TO_SKIP][str(worker_index)]):
       _ = next(it)
@@ -432,13 +444,12 @@ def __init__(
     # Create initial state. We record the state of each worker periodically,
     # together with the number of iterations since the last recorded state and
     # the index of the last worker.
-    workers_state: dict[str, Any] = {}
-    iterations_to_skip: dict[str, int] = {}
-    for i in range(multiprocessing_options.num_workers):
-      workers_state[str(i)] = iter(
-          self._iter_parent
-      ).get_state()  # pytype: disable=attribute-error
-      iterations_to_skip[str(i)] = 0
+    iterations_to_skip: dict[str, int] = {
+        str(i): 0 for i in range(multiprocessing_options.num_workers)
+    }
+    workers_state: dict[str, Any] = {
+        str(i): None for i in range(multiprocessing_options.num_workers)
+    }
 
     self._state: dict[str, dict[str, Any] | int] = {
         _WORKERS_STATE: workers_state,
@@ -483,7 +494,16 @@ def set_state(self, state: dict[str, dict[str, Any] | int]) -> None:
     self._iterator = None
 
   def get_state(self) -> dict[str, Any]:
-    return copy.deepcopy(self._state)
+    result = copy.deepcopy(self._state)
+    workers_state: dict[str, Any] = result[_WORKERS_STATE]  # pytype: disable=annotation-type-mismatch
+    parent_state = None
+    for worker_index, worker_state in workers_state.items():
+      # Create initial state from the parent iterator. This is to make sure the
+      # spec of the produced iterator does not change.
+      if worker_state is None:
+        parent_state = parent_state or self._iter_parent.__iter__().get_state()
+        workers_state[worker_index] = copy.deepcopy(parent_state)
+    return result
 
   def _ensure_iterator_initialized(self) -> None:
     if self._iterator is None:
diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py
index 792e7b8b..db76c0a7 100644
--- a/grain/_src/python/dataset/transformations/prefetch_test.py
+++ b/grain/_src/python/dataset/transformations/prefetch_test.py
@@ -561,6 +561,19 @@ def map(self, features):
       # buffers.
       time.sleep(30)
 
+  def test_prefetch_with_random_map(self):
+    ds = dataset.MapDataset.source([0]).repeat(100).to_iter_dataset()
+    ds = ds.random_map(lambda x, rng: x + rng.integers(sys.maxsize), seed=42)
+    ds = prefetch.MultiprocessPrefetchIterDataset(
+        ds,
+        options.MultiprocessingOptions(num_workers=5),
+    )
+    # Make sure that sliced datasets on workers are seeded differently and thus
+    # produce different random elements.
+    elements = list(ds)
+    distinct_elements = set(elements)
+    self.assertLen(distinct_elements, len(elements))
+
 
 class ThreadPrefetchIterDatasetTest(parameterized.TestCase):
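
Note on the seeding scheme above: _set_slice gives worker i of n every n-th record, and each worker's map iterator counts its own records locally as 0, 1, 2, ... Without the shift in map.py, every worker would therefore reset its RNG with the same indices and emit identical random streams; with it, worker i uses indices i, i + n, i + 2n, ..., which are disjoint across workers. The following is a minimal standalone sketch of that idea, not grain's implementation: derive_rng and index_for_rng are hypothetical helpers, and the Philox construction is an assumption (the real _reset_rng_state, as the diff shows, resets the existing self._rng in place rather than building a new Generator per element).

    import numpy as np

    def derive_rng(seed: int, index: int) -> np.random.Generator:
      # Hypothetical stand-in for grain's _reset_rng_state: a counter-based
      # Philox stream keyed by the seed and positioned at the given index.
      return np.random.Generator(np.random.Philox(key=seed, counter=index))

    def index_for_rng(local_index: int, worker_index: int, worker_count: int) -> int:
      # Mirrors the map.py change: shift the worker-local record index so that
      # worker i of n draws from indices i, i + n, i + 2n, ...
      return worker_index + local_index * worker_count

    # Two workers, local record indices 0..3: the shifted index sets are
    # disjoint, so the derived random streams cannot collide.
    w0 = {index_for_rng(k, worker_index=0, worker_count=2) for k in range(4)}
    w1 = {index_for_rng(k, worker_index=1, worker_count=2) for k in range(4)}
    assert w0 == {0, 2, 4, 6} and w1 == {1, 3, 5, 7}
    assert w0.isdisjoint(w1)

    # Distinct counters give distinct draws; collisions over [0, 2**62) are
    # astronomically unlikely.
    draws = [int(derive_rng(42, i).integers(2**62)) for i in sorted(w0 | w1)]
    assert len(set(draws)) == len(draws)

The test added in prefetch_test.py checks this property end to end: 100 records spread across 5 workers must yield 100 distinct random elements.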