Skip to content

Commit

Permalink
Prevent race condition in P2P shuffle run manager (#8262)
Browse files Browse the repository at this point in the history
Co-authored-by: crusaderky <[email protected]>
  • Loading branch information
hendrikmakait and crusaderky authored Oct 13, 2023
1 parent 5cedc47 commit be23012
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 22 deletions.
56 changes: 40 additions & 16 deletions distributed/shuffle/_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ class _ShuffleRunManager:
closed: bool
_active_runs: dict[ShuffleId, ShuffleRun]
_runs: set[ShuffleRun]
#: Mapping of shuffle IDs to the largest stale run ID.
#: This is used to prevent race conditions between fetching shuffle run data
#: from the scheduler and failing a shuffle run.
#: TODO: Remove once ordering between fetching and failing is guaranteed.
_stale_run_ids: dict[ShuffleId, int]
_runs_cleanup_condition: asyncio.Condition
_plugin: ShuffleWorkerPlugin

def __init__(self, plugin: ShuffleWorkerPlugin) -> None:
self.closed = False
self._active_runs = {}
self._runs = set()
self._stale_run_ids = {}
self._runs_cleanup_condition = asyncio.Condition()
self._plugin = plugin

Expand All @@ -52,6 +58,10 @@ def heartbeat(self) -> dict[ShuffleId, Any]:
}

def fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None:
stale_run_id = self._stale_run_ids.setdefault(shuffle_id, run_id)
if stale_run_id < run_id:
self._stale_run_ids[shuffle_id] = run_id

shuffle_run = self._active_runs.get(shuffle_id, None)
if shuffle_run is None or shuffle_run.run_id != run_id:
return
Expand Down Expand Up @@ -168,6 +178,29 @@ async def get_most_recent(
"""
return await self.get_with_run_id(shuffle_id=shuffle_id, run_id=max(run_ids))

async def _fetch(
self,
shuffle_id: ShuffleId,
spec: ShuffleSpec | None = None,
key: str | None = None,
) -> ShuffleRunSpec:
# FIXME: This should never be ToPickle[ShuffleRunSpec]
result: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
if spec is None:
result = await self._plugin.worker.scheduler.shuffle_get(
id=shuffle_id,
worker=self._plugin.worker.address,
)
else:
result = await self._plugin.worker.scheduler.shuffle_get_or_create(
spec=ToPickle(spec),
key=key,
worker=self._plugin.worker.address,
)
if isinstance(result, ToPickle):
result = result.data
return result

@overload
async def _refresh(
self,
Expand All @@ -190,21 +223,7 @@ async def _refresh(
spec: ShuffleSpec | None = None,
key: str | None = None,
) -> ShuffleRun:
# FIXME: This should never be ToPickle[ShuffleRunSpec]
result: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
if spec is None:
result = await self._plugin.worker.scheduler.shuffle_get(
id=shuffle_id,
worker=self._plugin.worker.address,
)
else:
result = await self._plugin.worker.scheduler.shuffle_get_or_create(
spec=ToPickle(spec),
key=key,
worker=self._plugin.worker.address,
)
if isinstance(result, ToPickle):
result = result.data
result = await self._fetch(shuffle_id=shuffle_id, spec=spec, key=key)
if self.closed:
raise ShuffleClosedError(f"{self} has already been closed")
if existing := self._active_runs.get(shuffle_id, None):
Expand All @@ -216,7 +235,12 @@ async def _refresh(
existing.run_id,
f"{existing!r} stale, expected run_id=={result.run_id}",
)

stale_run_id = self._stale_run_ids.get(shuffle_id, None)
if stale_run_id is not None and stale_run_id >= result.run_id:
raise RuntimeError(
f"Received stale shuffle run with run_id={result.run_id};"
f" expected run_id > {stale_run_id}"
)
shuffle_run = result.spec.create_run_on_worker(
result.run_id, result.worker_for, self._plugin
)
Expand Down
60 changes: 54 additions & 6 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,6 +1935,7 @@ async def test_shuffle_run_consistency(c, s, a):
but it is an implementation detail that users should not rely upon.
"""
await c.register_plugin(BlockedBarrierShuffleWorkerPlugin(), name="shuffle")
run_manager = get_shuffle_run_manager(a)
worker_plugin = a.plugins["shuffle"]
scheduler_ext = s.plugins["shuffle"]

Expand All @@ -1951,13 +1952,13 @@ async def test_shuffle_run_consistency(c, s, a):
shuffle_id = await wait_until_new_shuffle_is_initialized(s)
spec = scheduler_ext.get(shuffle_id, a.worker_address).data

# Worker plugin can fetch the current run
assert await worker_plugin._get_shuffle_run(shuffle_id, spec.run_id)
# Shuffle run manager can fetch the current run
assert await run_manager.get_with_run_id(shuffle_id, spec.run_id)

# This should never occur, but fetching an ID larger than the ID available on
# the scheduler should result in an error.
with pytest.raises(RuntimeError, match="invalid"):
await worker_plugin._get_shuffle_run(shuffle_id, spec.run_id + 1)
await run_manager.get_with_run_id(shuffle_id, spec.run_id + 1)

# Finish first execution
worker_plugin.block_barrier.set()
Expand All @@ -1979,12 +1980,12 @@ async def test_shuffle_run_consistency(c, s, a):
# Check invariant that the new run ID is larger than the previous
assert spec.run_id < new_spec.run_id

# Worker plugin can fetch the new shuffle run
assert await worker_plugin._get_shuffle_run(shuffle_id, new_spec.run_id)
# Shuffle run manager can fetch the new shuffle run
assert await run_manager.get_with_run_id(shuffle_id, new_spec.run_id)

# Fetching a stale run from a worker aware of the new run raises an error
with pytest.raises(RuntimeError, match="stale"):
await worker_plugin._get_shuffle_run(shuffle_id, spec.run_id)
await run_manager.get_with_run_id(shuffle_id, spec.run_id)

worker_plugin.block_barrier.set()
await out
Expand All @@ -1993,6 +1994,7 @@ async def test_shuffle_run_consistency(c, s, a):
await asyncio.sleep(0)
worker_plugin.block_barrier.clear()

# Create an unrelated shuffle on a different column
out = dd.shuffle.shuffle(df, "y", shuffle="p2p")
out = out.persist()
independent_shuffle_id = await wait_until_new_shuffle_is_initialized(s)
Expand All @@ -2012,6 +2014,52 @@ async def test_shuffle_run_consistency(c, s, a):
await check_scheduler_cleanup(s)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_fail_fetch_race(c, s, a):
"""This test manually triggers a race condition where a `shuffle_fail` arrives on
the worker before the result of `get` or `get_or_create`.
TODO: This assumes that there are no ordering guarantees between failing and fetching
This test checks the correct creation of shuffle run IDs through the scheduler
as well as the correct handling through the workers. It can be removed once ordering
is guaranteed.
"""
await c.register_plugin(BlockedBarrierShuffleWorkerPlugin(), name="shuffle")
run_manager = get_shuffle_run_manager(a)
worker_plugin = a.plugins["shuffle"]
scheduler_ext = s.plugins["shuffle"]

df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-01-10",
dtypes={"x": float, "y": float},
freq="100 s",
)
out = dd.shuffle.shuffle(df, "x", shuffle="p2p")
out = out.persist()

shuffle_id = await wait_until_new_shuffle_is_initialized(s)
spec = scheduler_ext.get(shuffle_id, a.worker_address).data
await worker_plugin.in_barrier.wait()
# Pretend that the fail from the scheduler arrives first
run_manager.fail(shuffle_id, spec.run_id, "error")
assert shuffle_id not in run_manager._active_runs

with pytest.raises(RuntimeError, match="Received stale shuffle run"):
await run_manager.get_with_run_id(shuffle_id, spec.run_id)
assert shuffle_id not in run_manager._active_runs

with pytest.raises(RuntimeError, match="Received stale shuffle run"):
await run_manager.get_or_create(spec.spec, "test-key")
assert shuffle_id not in run_manager._active_runs

worker_plugin.block_barrier.set()
del out

await check_worker_cleanup(a)
await check_scheduler_cleanup(s)


class BlockedShuffleAccessAndFailShuffleRunManager(_ShuffleRunManager):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
Expand Down

0 comments on commit be23012

Please sign in to comment.