Commit 20f1174
Fix a bug where iter random_map before mp_prefetch is seeded the same way on each worker.

Also, avoid initializing the parent iterator in mp_prefetch unless its state is requested before any elements have been processed.

PiperOrigin-RevId: 711788838
iindyk authored and copybara-github committed Jan 3, 2025
1 parent 7e90b81 commit 20f1174
Showing 3 changed files with 54 additions and 13 deletions.
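
For context before the diffs: prior to this commit, every worker process reset the per-element RNG from the same (seed, element index) pair, so the sliced pipelines on different workers drew identical random values. A minimal sketch of the failure mode and the fix, using NumPy directly rather than grain internals (the names seed, num_workers, and random_value are illustrative):

import numpy as np

# A minimal sketch of the bug being fixed (illustrative, not grain code).
seed = 42
num_workers = 3


def random_value(local_index: int, worker_index: int, fixed: bool) -> int:
  if fixed:
    # After the fix: fold the worker index into the RNG index so that every
    # worker draws from its own stream.
    rng_index = worker_index + local_index * num_workers
  else:
    # Before the fix: the worker index is ignored, so all workers share one
    # stream per element position.
    rng_index = local_index
  rng = np.random.default_rng([seed, rng_index])
  return int(rng.integers(1_000_000))


# First element on each of the three workers:
print([random_value(0, w, fixed=False) for w in range(num_workers)])  # all equal
print([random_value(0, w, fixed=True) for w in range(num_workers)])   # distinct (with overwhelming probability)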
10 changes: 9 additions & 1 deletion grain/_src/python/dataset/transformations/map.py
@@ -20,6 +20,7 @@
 from absl import logging
 from grain._src.core import transforms
 from grain._src.python.dataset import dataset
+from grain._src.python.dataset.transformations import prefetch
 import numpy as np


@@ -233,7 +234,14 @@ def __next__(self):
     with self._stats.record_self_time():
       if element is not None:
         if self._seed is not None:
-          _reset_rng_state(self._rng, op_seed=0, index=self._index_for_rng)
+          # Shift index for the current worker process in case of multiprocess
+          # execution. The actual index value doesn't matter as long as it is
+          # unique for each process.
+          index_for_rng = (
+              prefetch.worker_process_index
+              + self._index_for_rng * prefetch.worker_process_count
+          )
+          _reset_rng_state(self._rng, op_seed=0, index=index_for_rng)
           element = self._map_fn(element, self._rng)
         else:
           element = self._map_fn(element)
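The shifted index implements a standard interleaving: worker i uses i + k * worker_process_count for its k-th element, and since every non-negative integer decomposes uniquely as i + k * count with 0 <= i < count, no two workers can ever produce the same RNG index. A quick check of that property (illustrative, not grain code):

# Interleaved RNG indices cannot collide across workers.
worker_count = 5
elements_per_worker = 100
indices = {
    i + k * worker_count
    for i in range(worker_count)
    for k in range(elements_per_worker)
}
assert len(indices) == worker_count * elements_per_worker  # no collisions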
44 changes: 32 additions & 12 deletions grain/_src/python/dataset/transformations/prefetch.py
@@ -41,6 +41,14 @@

 T = TypeVar("T")

+# Index of the current worker process and total number of processes. If used
+# before multiprocess prefetch, must only be used during or after iterator
+# initialization.
+# TODO: Introduce context shared by all iterators and put these
+# variables there.
+worker_process_index = 0
+worker_process_count = 1
+

 @typing.runtime_checkable
 class SupportsInPlaceSlicing(Protocol):
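
The comment's constraint can be read as: these module-level globals only hold meaningful values inside a worker process, where __call__ (below) assigns them before the worker builds its iterator; anything that runs in the parent process at pipeline-construction time still sees the defaults 0 and 1. A hedged sketch of the resulting usage pattern (ShardAwareIterator is hypothetical, and the import assumes the post-commit grain source):

from grain._src.python.dataset.transformations import prefetch


class ShardAwareIterator:
  """Hypothetical iterator showing where the globals may be read."""

  def __init__(self):
    # Too early: at construction time in the parent process the globals still
    # hold their defaults (0 and 1), so nothing shard-specific is known yet.
    self._step = 0

  def __next__(self) -> int:
    # Safe: by the time elements are produced, the worker's __call__ has
    # already assigned worker_process_index / worker_process_count.
    value = (
        prefetch.worker_process_index
        + self._step * prefetch.worker_process_count
    )
    self._step += 1
    return value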
@@ -374,12 +382,16 @@ def __init__(
   def __call__(
       self, *, worker_index: int, worker_count: int
   ) -> Iterator[tuple[T, Optional[dict[str, Any]]]]:
-    # Recover from the last recorded state for the given worker.
-    worker_state = self._state[_WORKERS_STATE][str(worker_index)]
+    global worker_process_index, worker_process_count
+    worker_process_index = worker_index
+    worker_process_count = worker_count
     if worker_count > 1:
       _set_slice(self._ds, slice(worker_index, None, worker_count))
-    it = iter(self._ds)
-    it.set_state(worker_state)  # pytype: disable=attribute-error
+    it = self._ds.__iter__()
+    # Recover from the last recorded state for the given worker.
+    worker_state = self._state[_WORKERS_STATE][str(worker_index)]
+    if worker_state is not None:
+      it.set_state(worker_state)
     # Skip the required number of iterations after the last recorded state.
     for _ in range(self._state[_ITERATIONS_TO_SKIP][str(worker_index)]):
       _ = next(it)
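
The last three lines implement a checkpoint-and-replay scheme: worker state is recorded only periodically, so on restore the iterator is advanced past the elements consumed since the last recorded state. A simplified, self-contained sketch of the idea (CountingIterator and restore are toy stand-ins, not grain code):

class CountingIterator:
  """Toy stateful iterator (illustrative stand-in for a grain iterator)."""

  def __init__(self):
    self._next_index = 0

  def __next__(self) -> int:
    value = self._next_index
    self._next_index += 1
    return value

  def get_state(self):
    return {"next_index": self._next_index}

  def set_state(self, state):
    self._next_index = state["next_index"]


def restore(it, worker_state, iterations_to_skip: int):
  """Restores the last checkpoint, then replays uncheckpointed iterations."""
  if worker_state is not None:
    it.set_state(worker_state)
  for _ in range(iterations_to_skip):
    next(it)  # skip elements consumed after the last recorded state
  return it


# State was recorded at element 10; three more elements were consumed since.
it = restore(CountingIterator(), {"next_index": 10}, iterations_to_skip=3)
assert next(it) == 13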
@@ -432,13 +444,12 @@ def __init__(
     # Create initial state. We record state of each worker periodically together
     # with the number of iterations without the recorded state and index of the
     # last worker.
-    workers_state: dict[str, Any] = {}
-    iterations_to_skip: dict[str, int] = {}
-    for i in range(multiprocessing_options.num_workers):
-      workers_state[str(i)] = iter(
-          self._iter_parent
-      ).get_state()  # pytype: disable=attribute-error
-      iterations_to_skip[str(i)] = 0
+    iterations_to_skip: dict[str, int] = {
+        str(i): 0 for i in range(multiprocessing_options.num_workers)
+    }
+    workers_state: dict[str, Any] = {
+        str(i): None for i in range(multiprocessing_options.num_workers)
+    }

     self._state: dict[str, dict[str, Any] | int] = {
         _WORKERS_STATE: workers_state,
@@ -483,7 +494,16 @@ def set_state(self, state: dict[str, dict[str, Any] | int]) -> None:
     self._iterator = None

   def get_state(self) -> dict[str, Any]:
-    return copy.deepcopy(self._state)
+    result = copy.deepcopy(self._state)
+    workers_state: dict[str, Any] = result[_WORKERS_STATE]  # pytype: disable=annotation-type-mismatch
+    parent_state = None
+    for worker_index, worker_state in workers_state.items():
+      # Create initial state from the parent iterator. This is to make sure the
+      # spec of the produced iterator does not change.
+      if worker_state is None:
+        parent_state = parent_state or self._iter_parent.__iter__().get_state()
+        workers_state[worker_index] = copy.deepcopy(parent_state)
+    return result

   def _ensure_iterator_initialized(self) -> None:
     if self._iterator is None:
13 changes: 13 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch_test.py
@@ -561,6 +561,19 @@ def map(self, features):
       # buffers.
       time.sleep(30)

+  def test_prefetch_with_random_map(self):
+    ds = dataset.MapDataset.source([0]).repeat(100).to_iter_dataset()
+    ds = ds.random_map(lambda x, rng: x + rng.integers(sys.maxsize), seed=42)
+    ds = prefetch.MultiprocessPrefetchIterDataset(
+        ds,
+        options.MultiprocessingOptions(num_workers=5),
+    )
+    # Make sure that sliced datasets on workers are seeded differently and thus
+    # produce different random elements.
+    elements = list(ds)
+    distinct_elements = set(elements)
+    self.assertLen(distinct_elements, len(elements))
+

 class ThreadPrefetchIterDatasetTest(parameterized.TestCase):
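For reference, roughly how a user-facing pipeline would exercise this fix through grain's public API; a hedged sketch assuming current grain.python exports (MapDataset.source, random_map, mp_prefetch, MultiprocessingOptions) and mirroring the test above:

import sys

import grain.python as grain

ds = grain.MapDataset.source([0]).repeat(100).to_iter_dataset()
ds = ds.random_map(lambda x, rng: x + rng.integers(sys.maxsize), seed=42)
ds = ds.mp_prefetch(grain.MultiprocessingOptions(num_workers=5))

elements = list(ds)
# With the fix, each worker derives a distinct RNG stream, so the 100
# elements are (with overwhelming probability) all different.
assert len(set(elements)) == len(elements)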
