From 13f4c30803a4c65dc0c48ec39046c1f1ad6bd70f Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 17 Nov 2021 19:01:42 -0700 Subject: [PATCH 01/10] First draft of p2p shuffle via scheduler plugin --- distributed/shuffle/common.py | 34 ++++ distributed/shuffle/shuffle_scheduler.py | 190 +++++++++++++++++++ distributed/shuffle/shuffle_worker.py | 228 +++++++++++++++++++++++ 3 files changed, 452 insertions(+) create mode 100644 distributed/shuffle/common.py create mode 100644 distributed/shuffle/shuffle_scheduler.py create mode 100644 distributed/shuffle/shuffle_worker.py diff --git a/distributed/shuffle/common.py b/distributed/shuffle/common.py new file mode 100644 index 00000000000..b748afcead8 --- /dev/null +++ b/distributed/shuffle/common.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import math +from typing import NewType + +ShuffleId = NewType("ShuffleId", str) + + +def worker_for(output_partition: int, npartitions: int, workers: list[str]) -> str: + "Get the address of the worker which should hold this output partition number" + if output_partition < 0: + raise IndexError(f"Negative output partition: {output_partition}") + if output_partition >= npartitions: + raise IndexError( + f"Output partition {output_partition} does not exist in a shuffle producing {npartitions} partitions" + ) + i = len(workers) * output_partition // npartitions + return workers[i] + + +def partition_range( + worker: str, npartitions: int, workers: list[str] +) -> tuple[int, int]: + "Get the output partition numbers (inclusive) that a worker will hold" + i = workers.index(worker) + first = math.ceil(npartitions * i / len(workers)) + last = math.ceil(npartitions * (i + 1) / len(workers)) - 1 + return first, last + + +def npartitions_for(worker: str, npartitions: int, workers: list[str]) -> int: + "Get the number of output partitions a worker will hold" + first, last = partition_range(worker, npartitions, workers) + return last - first + 1 diff --git a/distributed/shuffle/shuffle_scheduler.py b/distributed/shuffle/shuffle_scheduler.py new file mode 100644 index 00000000000..18d7f3f7b56 --- /dev/null +++ b/distributed/shuffle/shuffle_scheduler.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable + +from distributed.diagnostics import SchedulerPlugin +from distributed.utils import key_split_group + +from .common import ShuffleId, worker_for + +if TYPE_CHECKING: + from distributed import Scheduler + from distributed.scheduler import TaskState + + +TASK_PREFIX = "('shuffle_" + + +@dataclass +class ShuffleState: + workers: list[str] + out_tasks_left: int + barrier_reached: bool = False + + +class ShuffleSchedulerPlugin(SchedulerPlugin): + started_prefixes: dict[str, Callable[[ShuffleId, str], None]] + output_keys: dict[str, ShuffleId] + shuffles: dict[ShuffleId, ShuffleState] + scheduler: Scheduler + + def __init__(self) -> None: + super().__init__() + self.shuffles = {} + self.output_keys = {} + self.started_prefixes = { + f"{TASK_PREFIX}transfer-": self.transfer, + f"{TASK_PREFIX}barrier-": self.barrier, + } + + async def start(self, scheduler: Scheduler) -> None: + self.scheduler = scheduler + + def transfer(self, id: ShuffleId, key: str) -> None: + state = self.shuffles.get(id, None) + if state: + assert ( + not state.barrier_reached + ), f"Duplicate shuffle: {key} running after barrier already reached" + # TODO allow plugins to return recommendations, so we can error this task in some way + return + + addrs = list(self.scheduler.workers) + # TODO handle resource/worker restrictions + + # Check how many output tasks there actually are, purely for validation right now. + # This lets us catch the "error" of culling shuffle output tasks. + # In the future, we may use it to actually handle culled shuffles properly. + ts: TaskState = self.scheduler.tasks[key] + assert ( + len(ts.dependents) == 1 + ), f"{key} should have exactly one dependency (the barrier), not {ts.dependents}" + barrier = next(iter(ts.dependents)) + nout = len(barrier.dependents) + + self.shuffles[id] = ShuffleState(addrs, nout) + + # TODO allow plugins to return worker messages (though hopefully these will get batched anyway) + msg = [{"op": "shuffle_init", "id": id, "workers": addrs, "n_out_tasks": nout}] + self.scheduler.send_all( + {}, + {addr: msg for addr in addrs}, + ) + + def barrier(self, id: ShuffleId, key: str) -> None: + state = self.shuffles[id] + assert ( + not state.barrier_reached + ), f"Duplicate barrier: {key} running but barrier already reached" + state.barrier_reached = True + + # Identify output tasks + ts: TaskState = self.scheduler.tasks[key] + + # Set worker restrictions on output tasks, and register their keys for us to watch in transitions + for dts in ts.dependents: + assert ( + len(dts.dependencies) == 1 + ), f"Output task {dts} (of shuffle {id}) should have 1 dependency, not {dts.dependencies}" + + assert ( + not dts.worker_restrictions + ), f"Output task {dts.key} (of shuffle {id}) already has worker restrictions {dts.worker_restrictions}" + + try: + dts._worker_restrictions = { + self.worker_for_key(dts.key, state.out_tasks_left, state.workers) + } + except (RuntimeError, IndexError, ValueError) as e: + raise type(e)( + f"Could not pick worker to run dependent {dts.key} of {key}: {e}" + ) from None + + self.output_keys[dts.key] = id + + def unpack(self, id: ShuffleId, key: str) -> None: + # Check if all output keys are done + + # NOTE: we don't actually need this `unpack` step or tracking output keys; + # we could just delete the state in `barrier`. + # But we do it so we can detect duplicate shuffles, where a `transfer` task + # tries to reuse a shuffle ID that we're unpacking. + # (It does also allow us to clean up worker restrictions on error) + state = self.shuffles[id] + assert ( + state.barrier_reached + ), f"Output {key} complete but barrier for shuffle {id} not yet reached" + assert ( + state.out_tasks_left > 0 + ), f"Output {key} complete; nout_left = {state.out_tasks_left} for shuffle {id}" + + state.out_tasks_left -= 1 + + if not state.out_tasks_left: + # Shuffle is done. Yay! + del self.shuffles[id] + + def erred(self, id: ShuffleId, key: str) -> None: + try: + state = self.shuffles.pop(id) + except KeyError: + return + + if state.barrier_reached: + # Remove worker restrictions for output tasks, in case the shuffle is re-submitted + for k, id_ in list(self.output_keys.items()): + if id_ == id: + ts: TaskState = self.scheduler.tasks[k] + ts._worker_restrictions.clear() + del self.output_keys[k] + + def transition(self, key: str, start: str, finish: str, *args, **kwargs): + if key.startswith(TASK_PREFIX): + # transfer/barrier starting to run + if start == "waiting" and finish in ("processing", "memory"): + for prefix, handler in self.started_prefixes.items(): + if key.startswith(prefix): + # Is this too brittle, assuming IDs are 32 characters? + # Inferring from the key name is brittle in general... + id = key[len(prefix) : len(prefix) + 32] + return handler(ShuffleId(id), key) + + # transfer/barrier task erred + elif finish == "erred": + id = key_split_group(key).split("-")[-1] + return self.erred(ShuffleId(id), key) + + # Task completed + if start in ("waiting", "processing") and finish in ( + "memory", + "released", + "erred", + ): + try: + id = self.output_keys[key] + except KeyError: + return + # Known unpack task completed or erred + if finish == "erred": + return self.erred(id, key) + return self.unpack(id, key) + + def worker_for_key(self, key: str, npartitions: int, workers: list[str]) -> str: + "Worker address this task should be assigned to" + # Infer which output partition number this task is fetching by parsing its key + # FIXME this is so brittle. + # For example, after `df.set_index(...).to_delayed()`, you could create + # keys that don't have indices in them, and get fused (because they should!). + m = re.match(r"\(.+, (\d+)\)$", key) + if not m: + raise RuntimeError(f"{key} does not look like a DataFrame key") + + idx = int(m.group(0)) + addr = worker_for(idx, npartitions, workers) + if addr not in self.scheduler.workers: + raise RuntimeError( + f"Worker {addr} for output partition {idx} no longer known" + ) + return addr diff --git a/distributed/shuffle/shuffle_worker.py b/distributed/shuffle/shuffle_worker.py new file mode 100644 index 00000000000..6e60fd35167 --- /dev/null +++ b/distributed/shuffle/shuffle_worker.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import asyncio +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import pandas as pd + +from distributed.protocol import to_serialize + +from .common import ShuffleId, npartitions_for, worker_for + +if TYPE_CHECKING: + from distributed import Worker + + +@dataclass +class ShuffleState: + workers: list[str] + npartitions: int + out_parts_left: int + barrier_reached: bool = False + + +class ShuffleWorkerExtension: + "Extend the Worker with routes and state for peer-to-peer shuffles" + worker: Worker + shuffles: dict[ShuffleId, ShuffleState] + waiting_for_metadata: dict[ShuffleId, asyncio.Event] + output_data: defaultdict[ShuffleId, defaultdict[int, list[pd.DataFrame]]] + + def __init__(self, worker: Worker) -> None: + # Attach to worker + worker.extensions["shuffle"] = self + worker.stream_handlers["shuffle_init"] = self.shuffle_init + worker.handlers["shuffle_receive"] = self.shuffle_receive + worker.handlers["shuffle_inputs_done"] = self.shuffle_inputs_done + + # Initialize + self.worker: Worker = worker + self.shuffles = {} + self.waiting_for_metadata = {} + self.output_data = defaultdict(lambda: defaultdict(list)) + + # Handlers + ########## + + def shuffle_init(self, id: ShuffleId, workers: list[str], n_out_tasks: int) -> None: + if id in self.shuffles: + raise ValueError( + f"Shuffle {id!r} is already registered on worker {self.worker.address}" + ) + self.shuffles[id] = ShuffleState( + workers, + n_out_tasks, + npartitions_for(self.worker.address, n_out_tasks, workers), + ) + try: + # Invariant: if `waiting_for_metadata` event is set, key is already in `shuffles` + self.waiting_for_metadata[id].set() + except KeyError: + pass + + def shuffle_receive( + self, + comm: object, + id: ShuffleId, + output_partition: int, + data: pd.DataFrame, + ) -> None: + try: + state = self.shuffles[id] + except KeyError: + # NOTE: `receive` could be called before `init`, if some other worker + # processed their `init` faster than us and then sent us data. + # That's why we keep `output_data` separate from `shuffles`. + pass + else: + assert not state.barrier_reached, f"`receive` called after barrier for {id}" + receiver = worker_for(output_partition, state.npartitions, state.workers) + assert receiver == self.worker.address, ( + f"{self.worker.address} received output partition {output_partition} " + f"for shuffle {id}, which was expected to go to {receiver}." + ) + + self.output_data[id][output_partition].append(data) + + async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None: + state = await self.get_shuffle(id) + assert not state.barrier_reached, f"`inputs_done` called again for {id}" + state.barrier_reached = True + + if not state.out_parts_left: + # No output partitions, remove shuffle it now: + # `get_output_partition` will never be called. + # This happens when there are fewer output partitions than workers. + self.remove(id) + + # Tasks + ####### + + async def _add_partition( + self, id: ShuffleId, npartitions: int, column: str, data: pd.DataFrame + ) -> None: + # Block until scheduler has called init + state = await self.get_shuffle(id) + assert not state.barrier_reached, f"`add_partition` for {id} after barrier" + + if npartitions != state.npartitions: + raise NotImplementedError( + f"Expected shuffle {id} to produce {npartitions} output tasks, " + f"but it only has {state.npartitions}. Did you sub-select from the " + "shuffled DataFrame, like `df.set_index(...).loc['foo':'bar']`?\n" + "This is not yet supported for peer-to-peer shuffles. Either remove " + "the sub-selection or use `shuffle='tasks'` for now." + ) + # Group and send data + await self.send_partition(data, column, id, npartitions, state.workers) + + async def _barrier(self, id: ShuffleId) -> None: + # NOTE: requires workers list. This is guaranteed because it depends on `add_partition`, + # which got the workers list from the scheduler. So this task must run on a worker where + # `add_partition` has already run. + state = await self.get_shuffle(id) + assert not state.barrier_reached, f"`barrier` for {id} called multiple times" + + # Call `shuffle_inputs_done` on peers. + # Note that this will call `shuffle_inputs_done` on our own worker as well. + # Concurrently, the scheduler is setting worker restrictions on its own. + await asyncio.gather( + *( + self.worker.rpc(worker).shuffle_inputs_done(id=id) + for worker in state.workers + ), + ) + + async def get_output_partition( + self, id: ShuffleId, i: int, empty: pd.DataFrame + ) -> pd.DataFrame: + state = self.shuffles[id] + # ^ Don't need to `get_shuffle`; `shuffle_inputs_done` has run already and guarantees it's there + assert state.barrier_reached, f"`get_output_partition` for {id} before barrier" + assert ( + state.out_parts_left > 0 + ), f"No outputs remaining, but requested output partition {i} on {self.worker.address} for {id}." + # ^ Note: this is impossible with our cleanup-on-empty + + worker = worker_for(i, state.npartitions, state.workers) + assert worker == self.worker.address, ( + f"{self.worker.address} received output partition {i} " + f"for shuffle {id}, which was expected to go to {worker}." + ) + + try: + parts = self.output_data[id].pop(i) + except KeyError: + result = empty + else: + result = pd.concat(parts, copy=False) + + state.out_parts_left -= 1 + if not state.out_parts_left: + # Shuffle is done. Yay! + self.remove(id) + + return result + + # Helpers + ######### + + def remove(self, id: ShuffleId) -> None: + state = self.shuffles.pop(id) + assert state.barrier_reached, f"Removed {id} before barrier" + assert ( + not state.out_parts_left + ), f"Removed {id} with {state.out_parts_left} outputs left" + + event = self.waiting_for_metadata.pop(id, None) + if event: + assert event.is_set(), f"Removed {id} while still waiting for metadata" + + data = self.output_data.pop(id, None) + assert ( + not data + ), f"Removed {id}, which still has data for output partitions {list(data)}" + + async def get_shuffle(self, id: ShuffleId): + try: + return self.shuffles[id] + except KeyError: + event = self.waiting_for_metadata.setdefault(id, asyncio.Event()) + try: + await asyncio.wait_for(event.wait(), timeout=5) # TODO config + except TimeoutError: + raise TimeoutError( + f"Timed out waiting for scheduler to start shuffle {id}" + ) from None + # Invariant: once `waiting_for_metadata` event is set, key is already in `shuffles`. + # And once key is in `shuffles`, no `get_shuffle` will create a new event. + # So we can safely remove the event now. + self.waiting_for_metadata.pop(id, None) + return self.shuffles[id] + + async def send_partition( + self, + data: pd.DataFrame, + column: str, + id: ShuffleId, + npartitions: int, + workers: list[str], + ) -> None: + tasks = [] + # TODO grouping is blocking, should it be offloaded to a thread? + # It mostly doesn't release the GIL though, so may not make much difference. + for output_partition, data in data.groupby(column): + addr = worker_for(int(output_partition), npartitions, workers) + task = asyncio.create_task( + self.worker.rpc(addr).shuffle_receive( + id=id, + output_partition=output_partition, + data=to_serialize(data), + ) + ) + tasks.append(task) + + # TODO handle errors and cancellation here + await asyncio.gather(*tasks) From a331e114d9f5bbeb3b289e4fb75e6dfadd78259b Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 17 Nov 2021 19:27:50 -0700 Subject: [PATCH 02/10] Add graph and import on scheduler --- distributed/scheduler.py | 7 ++- distributed/shuffle/__init__.py | 20 ++++++ distributed/shuffle/graph.py | 91 +++++++++++++++++++++++++++ distributed/shuffle/shuffle_worker.py | 19 ++++-- distributed/worker.py | 4 +- 5 files changed, 134 insertions(+), 7 deletions(-) create mode 100644 distributed/shuffle/__init__.py create mode 100644 distributed/shuffle/graph.py diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d56cf8e0bfa..7af5149aaef 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -51,7 +51,7 @@ from distributed.utils import recursive_to_dict -from . import preloading, profile +from . import preloading, profile, shuffle from . import versions as version_module from .active_memory_manager import ActiveMemoryManagerExtension from .batched import BatchedSend @@ -188,6 +188,9 @@ def nogil(func): ActiveMemoryManagerExtension, MemorySamplerExtension, ] +DEFAULT_PLUGINS: tuple[SchedulerPlugin, ...] = ( + (shuffle.ShuffleSchedulerPlugin(),) if shuffle.SHUFFLE_AVAILABLE else () +) ALL_TASK_STATES = declare( set, {"released", "waiting", "no-worker", "processing", "erred", "memory"} @@ -3623,7 +3626,7 @@ def __init__( http_prefix="/", preload=None, preload_argv=(), - plugins=(), + plugins=DEFAULT_PLUGINS, **kwargs, ): self._setup_logging(logger) diff --git a/distributed/shuffle/__init__.py b/distributed/shuffle/__init__.py new file mode 100644 index 00000000000..265d6aa07a1 --- /dev/null +++ b/distributed/shuffle/__init__.py @@ -0,0 +1,20 @@ +try: + import pandas +except ImportError: + SHUFFLE_AVAILABLE = False +else: + del pandas + SHUFFLE_AVAILABLE = True + + from .common import ShuffleId + from .graph import rearrange_by_column_p2p + from .shuffle_scheduler import ShuffleSchedulerPlugin + from .shuffle_worker import ShuffleWorkerExtension + +__all__ = [ + "SHUFFLE_AVAILABLE", + "ShuffleId", + "rearrange_by_column_p2p", + "ShuffleWorkerExtension", + "ShuffleSchedulerPlugin", +] diff --git a/distributed/shuffle/graph.py b/distributed/shuffle/graph.py new file mode 100644 index 00000000000..c0747546512 --- /dev/null +++ b/distributed/shuffle/graph.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from dask.base import tokenize +from dask.blockwise import BlockwiseDepDict, blockwise +from dask.dataframe import DataFrame +from dask.delayed import Delayed +from dask.highlevelgraph import HighLevelGraph + +from .common import ShuffleId +from .shuffle_worker import ShuffleWorkerExtension + +if TYPE_CHECKING: + import pandas as pd + + +def get_shuffle_extension() -> ShuffleWorkerExtension: + from distributed import get_worker + + return get_worker().extensions["shuffle"] + + +def shuffle_transfer( + data: pd.DataFrame, id: ShuffleId, npartitions: int, column: str +) -> None: + ext = get_shuffle_extension() + ext.sync(ext.add_partition(data, id, npartitions, column)) + + +def shuffle_unpack( + id: ShuffleId, i: int, empty: pd.DataFrame, barrier=None +) -> pd.DataFrame: + ext = get_shuffle_extension() + return ext.sync(ext.get_output_partition(id, i, empty)) + + +def shuffle_barrier(id: ShuffleId, transfers: list[None]) -> None: + ext = get_shuffle_extension() + ext.sync(ext.barrier(id)) + + +def rearrange_by_column_p2p( + df: DataFrame, + column: str, + npartitions: int | None = None, +): + npartitions = npartitions or df.npartitions + token = tokenize(df, column, npartitions) + + transferred = df.map_partitions( + shuffle_transfer, + token, + npartitions, + column, + meta=df, + enforce_metadata=False, + transform_divisions=False, + ) + + barrier_key = "shuffle-barrier-" + token + barrier_dsk = {barrier_key: (shuffle_barrier, token, transferred.__dask_keys__())} + barrier = Delayed( + barrier_key, + HighLevelGraph.from_collections( + barrier_key, barrier_dsk, dependencies=[transferred] + ), + ) + + name = "shuffle-unpack-" + token + dsk = blockwise( + shuffle_unpack, + name, + "i", + token, + None, + BlockwiseDepDict({(i,): i for i in range(npartitions)}), + "i", + df._meta, + None, + barrier_key, + None, + numblocks={}, + ) + + return DataFrame( + HighLevelGraph.from_collections(name, dsk, [barrier]), + name, + df._meta, + [None] * (npartitions + 1), + ) diff --git a/distributed/shuffle/shuffle_worker.py b/distributed/shuffle/shuffle_worker.py index 6e60fd35167..20df2ed0419 100644 --- a/distributed/shuffle/shuffle_worker.py +++ b/distributed/shuffle/shuffle_worker.py @@ -3,7 +3,7 @@ import asyncio from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Coroutine, TypeVar import pandas as pd @@ -14,6 +14,8 @@ if TYPE_CHECKING: from distributed import Worker +T = TypeVar("T") + @dataclass class ShuffleState: @@ -100,8 +102,8 @@ async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None: # Tasks ####### - async def _add_partition( - self, id: ShuffleId, npartitions: int, column: str, data: pd.DataFrame + async def add_partition( + self, data: pd.DataFrame, id: ShuffleId, npartitions: int, column: str ) -> None: # Block until scheduler has called init state = await self.get_shuffle(id) @@ -118,7 +120,7 @@ async def _add_partition( # Group and send data await self.send_partition(data, column, id, npartitions, state.workers) - async def _barrier(self, id: ShuffleId) -> None: + async def barrier(self, id: ShuffleId) -> None: # NOTE: requires workers list. This is guaranteed because it depends on `add_partition`, # which got the workers list from the scheduler. So this task must run on a worker where # `add_partition` has already run. @@ -226,3 +228,12 @@ async def send_partition( # TODO handle errors and cancellation here await asyncio.gather(*tasks) + + @property + def loop(self) -> asyncio.AbstractEventLoop: + return self.worker.loop.asyncio_loop # type: ignore + + def sync(self, coro: Coroutine[object, object, T]) -> T: + # Is it a bad idea not to use `distributed.utils.sync`? + # It's much nicer to use asyncio, because among other things it gives us typechecking. + return asyncio.run_coroutine_threadsafe(coro, self.loop).result() diff --git a/distributed/worker.py b/distributed/worker.py index bbd707ae7bb..b5f74a062bc 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -44,7 +44,7 @@ typename, ) -from . import comm, preloading, profile, system, utils +from . import comm, preloading, profile, shuffle, system, utils from .batched import BatchedSend from .comm import Comm, connect, get_address_host from .comm.addressing import address_from_user_args, parse_address @@ -114,6 +114,8 @@ RUNNING = {Status.running, Status.paused, Status.closing_gracefully} DEFAULT_EXTENSIONS: list[type] = [PubSubWorkerExtension] +if shuffle.SHUFFLE_AVAILABLE: + DEFAULT_EXTENSIONS.append(shuffle.ShuffleWorkerExtension) DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {} From e3fb4b3c78f6544ccaa623c4c019e86fe77c2e93 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 17 Nov 2021 23:20:38 -0700 Subject: [PATCH 03/10] Ensure graph keys are what we expect --- distributed/shuffle/graph.py | 62 +++++++++++++++--------- distributed/shuffle/shuffle_scheduler.py | 48 +++++++++--------- 2 files changed, 64 insertions(+), 46 deletions(-) diff --git a/distributed/shuffle/graph.py b/distributed/shuffle/graph.py index c0747546512..a338e75a6e6 100644 --- a/distributed/shuffle/graph.py +++ b/distributed/shuffle/graph.py @@ -5,7 +5,7 @@ from dask.base import tokenize from dask.blockwise import BlockwiseDepDict, blockwise from dask.dataframe import DataFrame -from dask.delayed import Delayed +from dask.dataframe.core import partitionwise_graph from dask.highlevelgraph import HighLevelGraph from .common import ShuffleId @@ -48,29 +48,28 @@ def rearrange_by_column_p2p( npartitions = npartitions or df.npartitions token = tokenize(df, column, npartitions) - transferred = df.map_partitions( - shuffle_transfer, - token, - npartitions, - column, - meta=df, - enforce_metadata=False, - transform_divisions=False, - ) - - barrier_key = "shuffle-barrier-" + token - barrier_dsk = {barrier_key: (shuffle_barrier, token, transferred.__dask_keys__())} - barrier = Delayed( - barrier_key, - HighLevelGraph.from_collections( - barrier_key, barrier_dsk, dependencies=[transferred] - ), + # We use `partitionwise_graph` instead of `map_partitions` so we can pass in our own key. + # The scheduler needs the task key to contain the shuffle ID; it's the only way it knows + # what shuffle a task belongs to. + # (Yes, this is rather brittle.) + transfer_name = "shuffle-transfer-" + token + transfer_dsk = partitionwise_graph( + shuffle_transfer, transfer_name, df, token, npartitions, column ) - name = "shuffle-unpack-" + token - dsk = blockwise( + barrier_name = "shuffle-barrier-" + token + barrier_dsk = { + barrier_name: ( + shuffle_barrier, + token, + [(transfer_name, i) for i in range(df.npartitions)], + ) + } + + unpack_name = "shuffle-unpack-" + token + unpack_dsk = blockwise( shuffle_unpack, - name, + unpack_name, "i", token, None, @@ -78,14 +77,29 @@ def rearrange_by_column_p2p( "i", df._meta, None, - barrier_key, + barrier_name, None, numblocks={}, ) + hlg = HighLevelGraph( + { + transfer_name: transfer_dsk, + barrier_name: barrier_dsk, + unpack_name: unpack_dsk, + **df.dask.layers, + }, + { + transfer_name: set(df.__dask_layers__()), + barrier_name: {transfer_name}, + unpack_name: {barrier_name}, + **df.dask.dependencies, + }, + ) + return DataFrame( - HighLevelGraph.from_collections(name, dsk, [barrier]), - name, + hlg, + unpack_name, df._meta, [None] * (npartitions + 1), ) diff --git a/distributed/shuffle/shuffle_scheduler.py b/distributed/shuffle/shuffle_scheduler.py index 18d7f3f7b56..c111aef8825 100644 --- a/distributed/shuffle/shuffle_scheduler.py +++ b/distributed/shuffle/shuffle_scheduler.py @@ -2,7 +2,7 @@ import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from distributed.diagnostics import SchedulerPlugin from distributed.utils import key_split_group @@ -14,7 +14,7 @@ from distributed.scheduler import TaskState -TASK_PREFIX = "('shuffle_" +TASK_PREFIX = "shuffle" @dataclass @@ -25,7 +25,6 @@ class ShuffleState: class ShuffleSchedulerPlugin(SchedulerPlugin): - started_prefixes: dict[str, Callable[[ShuffleId, str], None]] output_keys: dict[str, ShuffleId] shuffles: dict[ShuffleId, ShuffleState] scheduler: Scheduler @@ -34,10 +33,6 @@ def __init__(self) -> None: super().__init__() self.shuffles = {} self.output_keys = {} - self.started_prefixes = { - f"{TASK_PREFIX}transfer-": self.transfer, - f"{TASK_PREFIX}barrier-": self.barrier, - } async def start(self, scheduler: Scheduler) -> None: self.scheduler = scheduler @@ -141,20 +136,21 @@ def erred(self, id: ShuffleId, key: str) -> None: del self.output_keys[k] def transition(self, key: str, start: str, finish: str, *args, **kwargs): - if key.startswith(TASK_PREFIX): - # transfer/barrier starting to run - if start == "waiting" and finish in ("processing", "memory"): - for prefix, handler in self.started_prefixes.items(): - if key.startswith(prefix): - # Is this too brittle, assuming IDs are 32 characters? - # Inferring from the key name is brittle in general... - id = key[len(prefix) : len(prefix) + 32] - return handler(ShuffleId(id), key) - - # transfer/barrier task erred - elif finish == "erred": - id = key_split_group(key).split("-")[-1] - return self.erred(ShuffleId(id), key) + parts = parse_key(key) + if parts and len(parts) == 3: + prefix, group, id = parts + + if prefix == TASK_PREFIX: + if start == "waiting" and finish in ("processing", "memory"): + # transfer/barrier starting to run + if group == "transfer": + return self.transfer(ShuffleId(id), key) + if group == "barrier": + return self.barrier(ShuffleId(id), key) + + # transfer/barrier task erred + elif finish == "erred": + return self.erred(ShuffleId(id), key) # Task completed if start in ("waiting", "processing") and finish in ( @@ -181,10 +177,18 @@ def worker_for_key(self, key: str, npartitions: int, workers: list[str]) -> str: if not m: raise RuntimeError(f"{key} does not look like a DataFrame key") - idx = int(m.group(0)) + idx = int(m.group(1)) addr = worker_for(idx, npartitions, workers) if addr not in self.scheduler.workers: raise RuntimeError( f"Worker {addr} for output partition {idx} no longer known" ) return addr + + +def parse_key(key: str) -> list[str] | None: + if TASK_PREFIX in key[: len(TASK_PREFIX) + 2]: + if key[0] == "(": + key = key_split_group(key) + return key.split("-") + return None From 76a8219031dbd75b1cdff527507a177375eda9aa Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 17 Nov 2021 23:21:46 -0700 Subject: [PATCH 04/10] E2E shuffling tests All pass with https://github.com/dask/dask/pull/8392. Rather crude; needs unit testing. --- distributed/scheduler.py | 1 + distributed/shuffle/shuffle_scheduler.py | 4 +- distributed/shuffle/tests/__init__.py | 0 distributed/shuffle/tests/test_graph.py | 131 +++++++++++++++++++++++ 4 files changed, 134 insertions(+), 2 deletions(-) create mode 100644 distributed/shuffle/tests/__init__.py create mode 100644 distributed/shuffle/tests/test_graph.py diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 7af5149aaef..134a486ccae 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -191,6 +191,7 @@ def nogil(func): DEFAULT_PLUGINS: tuple[SchedulerPlugin, ...] = ( (shuffle.ShuffleSchedulerPlugin(),) if shuffle.SHUFFLE_AVAILABLE else () ) +# ^ TODO this assumes one Scheduler per process; probably a bad idea. ALL_TASK_STATES = declare( set, {"released", "waiting", "no-worker", "processing", "erred", "memory"} diff --git a/distributed/shuffle/shuffle_scheduler.py b/distributed/shuffle/shuffle_scheduler.py index c111aef8825..b52d071f16b 100644 --- a/distributed/shuffle/shuffle_scheduler.py +++ b/distributed/shuffle/shuffle_scheduler.py @@ -2,7 +2,7 @@ import re from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from distributed.diagnostics import SchedulerPlugin from distributed.utils import key_split_group @@ -25,12 +25,12 @@ class ShuffleState: class ShuffleSchedulerPlugin(SchedulerPlugin): + name: ClassVar[str] = "ShuffleSchedulerPlugin" output_keys: dict[str, ShuffleId] shuffles: dict[ShuffleId, ShuffleState] scheduler: Scheduler def __init__(self) -> None: - super().__init__() self.shuffles = {} self.output_keys = {} diff --git a/distributed/shuffle/tests/__init__.py b/distributed/shuffle/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/distributed/shuffle/tests/test_graph.py b/distributed/shuffle/tests/test_graph.py new file mode 100644 index 00000000000..71948c9ab2f --- /dev/null +++ b/distributed/shuffle/tests/test_graph.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +import dask +import dask.dataframe as dd +from dask.blockwise import Blockwise +from dask.dataframe.shuffle import partitioning_index, rearrange_by_column_tasks +from dask.utils_test import hlg_layer_topological + +from distributed.utils_test import gen_cluster + +from .. import ShuffleWorkerExtension, rearrange_by_column_p2p +from ..shuffle_scheduler import TASK_PREFIX, ShuffleSchedulerPlugin, parse_key + +if TYPE_CHECKING: + from distributed import Client, Scheduler, Worker + + +def shuffle( + df: dd.DataFrame, on: str, rearrange=rearrange_by_column_p2p +) -> dd.DataFrame: + "Simple version of `DataFrame.shuffle`, so we don't need dask to know about 'p2p'" + return ( + df.assign( + partition=lambda df: df[on].map_partitions( + partitioning_index, df.npartitions, transform_divisions=False + ) + ) + .pipe(rearrange, "partition") + .drop("partition", axis=1) + ) + + +def test_shuffle_helper(client: Client): + df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + shuffle_helper = shuffle(df, "id", rearrange=rearrange_by_column_tasks) + dask_shuffle = df.shuffle("id", shuffle="tasks") + dd.utils.assert_eq(shuffle_helper, dask_shuffle) + + +def test_graph(): + df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + shuffled = shuffle(df, "id") + shuffled.dask.validate() + + # Check graph optimizes correctly + (opt,) = dask.optimize(shuffled) + opt.dask.validate() + + assert len(opt.dask.layers) == 3 + # create+transfer -> barrier -> unpack+drop_by_shallow_copy + transfer_layer = hlg_layer_topological(opt.dask, 0) + assert isinstance(transfer_layer, Blockwise) + shuffle_id = transfer_layer.indices[0][0] + # ^ don't ask why it's in position 0; some oddity of blockwise fusion. + # Don't be surprised if this breaks unexpectedly. + assert isinstance(hlg_layer_topological(opt.dask, -1), Blockwise) + + # Check that task names contain the shuffle ID itself. + # This is how the scheduler plugin infers the shuffle ID. + for key in opt.dask.to_dict(): + key = str(key) + if "transfer" in key or "barrier" in key: + try: + parts = parse_key(key) + assert parts + prefix, group, id = parts + except Exception: + print(key) + raise + assert prefix == TASK_PREFIX + assert id == shuffle_id + + +def test_basic(client: Client): + df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + shuffled = shuffle(df, "id") + + dd.utils.assert_eq(shuffled, df.shuffle("id", shuffle="tasks")) + # ^ NOTE: this works because `assert_eq` sorts the rows before comparing + + +@gen_cluster([("", 2)] * 4, client=True) +async def test_basic_state(c: Client, s: Scheduler, *workers: Worker): + df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + shuffled = shuffle(df, "id") + + exts: list[ShuffleWorkerExtension] = [w.extensions["shuffle"] for w in workers] + for ext in exts: + assert not ext.shuffles + assert not ext.output_data + assert not ext.waiting_for_metadata + + plugin = s.plugins[ShuffleSchedulerPlugin.name] + assert isinstance(plugin, ShuffleSchedulerPlugin) + assert not plugin.shuffles + assert not plugin.output_keys + + f = c.compute(shuffled) + # TODO this is a bad/pointless test. the `f.done()` is necessary in case the shuffle is really fast. + # To test state more thoroughly, we'd need a way to 'stop the world' at various stages. Like have the + # scheduler pause everything when the barrier is reached. Not sure yet how to implement that. + while ( + not all(len(ext.shuffles) == 1 for ext in exts) + and len(plugin.shuffles) == 1 + and not f.done() + ): + await asyncio.sleep(0.1) + + await f + assert all(not ext.shuffles for ext in exts) + assert not plugin.shuffles + assert not plugin.output_keys + assert not any(ts.worker_restrictions for ts in s.tasks.values()) + + +def test_multiple_linear(client: Client): + df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + s1 = shuffle(df, "id") + s1["x"] = s1["x"] + 1 + s2 = shuffle(s1, "x") + + (opt,) = dask.optimize(s2) + assert len(opt.dask.layers) == 5 + # create+transfer -> barrier -> unpack+transfer -> barrier -> unpack + + dd.utils.assert_eq( + s2, df.assign(x=lambda df: df.x + 1).shuffle("x", shuffle="tasks") + ) From 9955b60bd5f7ade58d5ec8620be74a3a01d7c0ac Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Thu, 18 Nov 2021 01:36:23 -0700 Subject: [PATCH 05/10] Test merges Surprisingly, blockwise decides to merge the two output layers. This really throws things off. The test passes right now by disabling an aggressive assertion, but we need more robust validation here. --- distributed/shuffle/shuffle_scheduler.py | 31 +++++++++---- distributed/shuffle/tests/test_graph.py | 59 ++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 8 deletions(-) diff --git a/distributed/shuffle/shuffle_scheduler.py b/distributed/shuffle/shuffle_scheduler.py index b52d071f16b..774bf03d3c0 100644 --- a/distributed/shuffle/shuffle_scheduler.py +++ b/distributed/shuffle/shuffle_scheduler.py @@ -80,16 +80,14 @@ def barrier(self, id: ShuffleId, key: str) -> None: # Set worker restrictions on output tasks, and register their keys for us to watch in transitions for dts in ts.dependents: - assert ( - len(dts.dependencies) == 1 - ), f"Output task {dts} (of shuffle {id}) should have 1 dependency, not {dts.dependencies}" - - assert ( - not dts.worker_restrictions - ), f"Output task {dts.key} (of shuffle {id}) already has worker restrictions {dts.worker_restrictions}" + # TODO this is often not true thanks to blockwise fusion. + # Currently disabled so tests pass, but needs more careful logic. + # assert ( + # len(dts.dependencies) == 1 + # ), f"Output task {dts} (of shuffle {id}) should have 1 dependency, not {dts.dependencies}" try: - dts._worker_restrictions = { + restrictions = { self.worker_for_key(dts.key, state.out_tasks_left, state.workers) } except (RuntimeError, IndexError, ValueError) as e: @@ -97,6 +95,16 @@ def barrier(self, id: ShuffleId, key: str) -> None: f"Could not pick worker to run dependent {dts.key} of {key}: {e}" ) from None + assert ( + not dts.worker_restrictions or dts.worker_restrictions == restrictions + ), ( + f"Output task {dts.key} (of shuffle {id}) has unexpected worker restrictions " + f"{dts.worker_restrictions}, not {restrictions}" + ) + # TODO if these checks fail, we need to error the task! + # Otherwise it'll still run, and maybe even succeed, but just produce wrong data? + + dts._worker_restrictions = restrictions self.output_keys[dts.key] = id def unpack(self, id: ShuffleId, key: str) -> None: @@ -117,6 +125,13 @@ def unpack(self, id: ShuffleId, key: str) -> None: state.out_tasks_left -= 1 + ts: TaskState = self.scheduler.tasks[key] + assert ( + len(ts._worker_restrictions) == 1 + ), f"Output {key} missing worker restrictions" + ts._worker_restrictions.clear() + del self.output_keys[key] + if not state.out_tasks_left: # Shuffle is done. Yay! del self.shuffles[id] diff --git a/distributed/shuffle/tests/test_graph.py b/distributed/shuffle/tests/test_graph.py index 71948c9ab2f..98e1c99f217 100644 --- a/distributed/shuffle/tests/test_graph.py +++ b/distributed/shuffle/tests/test_graph.py @@ -3,6 +3,8 @@ import asyncio from typing import TYPE_CHECKING +import pandas as pd + import dask import dask.dataframe as dd from dask.blockwise import Blockwise @@ -129,3 +131,60 @@ def test_multiple_linear(client: Client): dd.utils.assert_eq( s2, df.assign(x=lambda df: df.x + 1).shuffle("x", shuffle="tasks") ) + + +def test_multiple_concurrent(client: Client): + df1 = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + df2 = dd.demo.make_timeseries( + start="2001-01-01", end="2001-12-31", freq="15D", partition_freq="30D" + ) + s1 = shuffle(df1, "id") + s2 = shuffle(df2, "id") + assert s1._name != s2._name + + merged = dd.map_partitions( + lambda p1, p2: pd.merge(p1, p2, on="id"), s1, s2, align_dataframes=False + ) + + # TODO this fails because blockwise merges the two `unpack` layers together like + # X + # / \ --> X + # X X + + # So the HLG structure is + # + # Actual: Expected: + # + # merge + # / \ + # unpack+merge unpack unpack + # / \ | | + # barrier barrier barrier barrier + # | | | | + # xfer xfer xfer xfer + + # And in the scheduler plugin's barrier, we check that the dependents + # of a `barrier` depend only on that one barrier. + # But here, they depend on _both_ barriers. + # This check is probably overly restrictive, because with blockwise fusion + # after the unpack, it's in fact quite likely that other dependencies would + # appear. + # + # This is probably solveable, but tricky. + # We'd have to confirm that: + # 1. The other dependencies aren't barriers + # 2. If the other dependencies are barriers: + # - that shuffle has the same number of partitions _and_ the same set of workers + # + # Otherwise, these tasks just cannot be fused, because their data is going to + # different places. Yet another thing we might need to deal with at the optimization level. + + # _Or_, could the scheduler plugin look for this situation in advance before starting the + # shuffle, and if it sees that multiple shuffles feed into the output tasks of the one + # it's starting, ensure that their worker assignments all line up? (This would mean ensuring + # they all have the same list of workers; in order to be fused they must already have the + # same number output partitions.) + + dd.utils.assert_eq( + merged, dd.merge(df1, df2, on="id", shuffle="tasks"), check_index=False + ) From f79985a32e226c1948b28a0664ec2f959e39d663 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Thu, 18 Nov 2021 01:51:13 -0700 Subject: [PATCH 06/10] Light docs --- distributed/scheduler.py | 2 +- distributed/shuffle/shuffle_scheduler.py | 12 ++++- distributed/shuffle/shuffle_worker.py | 58 +++++++++++++++++++++--- 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 134a486ccae..aa2d80775aa 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -188,7 +188,7 @@ def nogil(func): ActiveMemoryManagerExtension, MemorySamplerExtension, ] -DEFAULT_PLUGINS: tuple[SchedulerPlugin, ...] = ( +DEFAULT_PLUGINS: "tuple[SchedulerPlugin, ...]" = ( (shuffle.ShuffleSchedulerPlugin(),) if shuffle.SHUFFLE_AVAILABLE else () ) # ^ TODO this assumes one Scheduler per process; probably a bad idea. diff --git a/distributed/shuffle/shuffle_scheduler.py b/distributed/shuffle/shuffle_scheduler.py index 774bf03d3c0..2b7538d6dd3 100644 --- a/distributed/shuffle/shuffle_scheduler.py +++ b/distributed/shuffle/shuffle_scheduler.py @@ -38,6 +38,7 @@ async def start(self, scheduler: Scheduler) -> None: self.scheduler = scheduler def transfer(self, id: ShuffleId, key: str) -> None: + "Handle a `transfer` task for a shuffle being scheduled" state = self.shuffles.get(id, None) if state: assert ( @@ -69,6 +70,7 @@ def transfer(self, id: ShuffleId, key: str) -> None: ) def barrier(self, id: ShuffleId, key: str) -> None: + "Handle a `barrier` task for a shuffle being scheduled" state = self.shuffles[id] assert ( not state.barrier_reached @@ -108,6 +110,7 @@ def barrier(self, id: ShuffleId, key: str) -> None: self.output_keys[dts.key] = id def unpack(self, id: ShuffleId, key: str) -> None: + "Handle an `unpack` task for a shuffle completing" # Check if all output keys are done # NOTE: we don't actually need this `unpack` step or tracking output keys; @@ -137,6 +140,7 @@ def unpack(self, id: ShuffleId, key: str) -> None: del self.shuffles[id] def erred(self, id: ShuffleId, key: str) -> None: + "Handle any task for a shuffle erroring" try: state = self.shuffles.pop(id) except KeyError: @@ -151,6 +155,7 @@ def erred(self, id: ShuffleId, key: str) -> None: del self.output_keys[k] def transition(self, key: str, start: str, finish: str, *args, **kwargs): + "Watch transitions for keys we care about" parts = parse_key(key) if parts and len(parts) == 3: prefix, group, id = parts @@ -184,8 +189,10 @@ def transition(self, key: str, start: str, finish: str, *args, **kwargs): def worker_for_key(self, key: str, npartitions: int, workers: list[str]) -> str: "Worker address this task should be assigned to" - # Infer which output partition number this task is fetching by parsing its key - # FIXME this is so brittle. + # Infer which output partition number this task is fetching by parsing its key. + # We have to parse keys, instead of generating the list of expected keys, because + # blockwise fusion means they won't just be `shuffle-unpack-abcde`. + # FIXME this feels very hacky/brittle. # For example, after `df.set_index(...).to_delayed()`, you could create # keys that don't have indices in them, and get fused (because they should!). m = re.match(r"\(.+, (\d+)\)$", key) @@ -202,6 +209,7 @@ def worker_for_key(self, key: str, npartitions: int, workers: list[str]) -> str: def parse_key(key: str) -> list[str] | None: + "Split a shuffle key into its prefix, group, and shuffle ID, or None if not a shuffle key." if TASK_PREFIX in key[: len(TASK_PREFIX) + 2]: if key[0] == "(": key = key_split_group(key) diff --git a/distributed/shuffle/shuffle_worker.py b/distributed/shuffle/shuffle_worker.py index 20df2ed0419..ab35c4b7f5f 100644 --- a/distributed/shuffle/shuffle_worker.py +++ b/distributed/shuffle/shuffle_worker.py @@ -49,6 +49,11 @@ def __init__(self, worker: Worker) -> None: ########## def shuffle_init(self, id: ShuffleId, workers: list[str], n_out_tasks: int) -> None: + """ + Handler: initialize a shuffle. Called by scheduler on all workers. + + Must be called exactly once per ID. + """ if id in self.shuffles: raise ValueError( f"Shuffle {id!r} is already registered on worker {self.worker.address}" @@ -71,6 +76,12 @@ def shuffle_receive( output_partition: int, data: pd.DataFrame, ) -> None: + """ + Handler: receive data from a peer. + + The shuffle ID can be unknown. + Calling after the barrier task is an error. + """ try: state = self.shuffles[id] except KeyError: @@ -89,6 +100,13 @@ def shuffle_receive( self.output_data[id][output_partition].append(data) async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None: + """ + Handler: note that the barrier task has been reached. Called by a peer. + + The shuffle will be removed if this worker holds no output partitions for it. + + Must be called exactly once per ID. Blocks until `shuffle_init` has been called. + """ state = await self.get_shuffle(id) assert not state.barrier_reached, f"`inputs_done` called again for {id}" state.barrier_reached = True @@ -105,6 +123,15 @@ async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None: async def add_partition( self, data: pd.DataFrame, id: ShuffleId, npartitions: int, column: str ) -> None: + """ + Task: Hand off an input partition to the extension. + + This will block until the extension is ready to receive another input partition. + Also blocks until `shuffle_init` has been called. + + Using an unknown ``shuffle_id`` is an error. + Calling after the barrier task is an error. + """ # Block until scheduler has called init state = await self.get_shuffle(id) assert not state.barrier_reached, f"`add_partition` for {id} after barrier" @@ -121,9 +148,14 @@ async def add_partition( await self.send_partition(data, column, id, npartitions, state.workers) async def barrier(self, id: ShuffleId) -> None: - # NOTE: requires workers list. This is guaranteed because it depends on `add_partition`, - # which got the workers list from the scheduler. So this task must run on a worker where - # `add_partition` has already run. + """ + Task: Note that the barrier task has been reached (`add_partition` called for all input partitions) + + Using an unknown ``shuffle_id`` is an error. + Must be called exactly once per ID. + Blocks until `shuffle_init` has been called (on all workers). + Calling this before all partitions have been added will cause `add_partition` to fail. + """ state = await self.get_shuffle(id) assert not state.barrier_reached, f"`barrier` for {id} called multiple times" @@ -140,13 +172,20 @@ async def barrier(self, id: ShuffleId) -> None: async def get_output_partition( self, id: ShuffleId, i: int, empty: pd.DataFrame ) -> pd.DataFrame: - state = self.shuffles[id] - # ^ Don't need to `get_shuffle`; `shuffle_inputs_done` has run already and guarantees it's there + """ + Task: Retrieve a shuffled output partition from the extension. + + After calling on the final output partition remaining on this worker, the shuffle will be cleaned up. + + Using an unknown ``shuffle_id`` is an error. + Requesting a partition which doesn't belong on this worker, or has already been retrieved, is an error. + """ + state = await self.get_shuffle(id) # should never have to wait assert state.barrier_reached, f"`get_output_partition` for {id} before barrier" assert ( state.out_parts_left > 0 ), f"No outputs remaining, but requested output partition {i} on {self.worker.address} for {id}." - # ^ Note: this is impossible with our cleanup-on-empty + # ^ Note: impossible with our cleanup-on-empty worker = worker_for(i, state.npartitions, state.workers) assert worker == self.worker.address, ( @@ -172,6 +211,7 @@ async def get_output_partition( ######### def remove(self, id: ShuffleId) -> None: + "Remove state for this shuffle. The shuffle must be complete and in a valid state." state = self.shuffles.pop(id) assert state.barrier_reached, f"Removed {id} before barrier" assert ( @@ -187,7 +227,8 @@ def remove(self, id: ShuffleId) -> None: not data ), f"Removed {id}, which still has data for output partitions {list(data)}" - async def get_shuffle(self, id: ShuffleId): + async def get_shuffle(self, id: ShuffleId) -> ShuffleState: + "Get the `ShuffleState`, blocking until it's been received from the scheduler." try: return self.shuffles[id] except KeyError: @@ -212,6 +253,7 @@ async def send_partition( npartitions: int, workers: list[str], ) -> None: + "Split up an input partition and send its parts to peers." tasks = [] # TODO grouping is blocking, should it be offloaded to a thread? # It mostly doesn't release the GIL though, so may not make much difference. @@ -231,9 +273,11 @@ async def send_partition( @property def loop(self) -> asyncio.AbstractEventLoop: + "The asyncio event loop for the worker" return self.worker.loop.asyncio_loop # type: ignore def sync(self, coro: Coroutine[object, object, T]) -> T: + "Run an async function on the worker's event loop, synchronously from another thread." # Is it a bad idea not to use `distributed.utils.sync`? # It's much nicer to use asyncio, because among other things it gives us typechecking. return asyncio.run_coroutine_threadsafe(coro, self.loop).result() From 2afc22c214f81d46f6d3f15a8b96c90e0d51ec6b Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Fri, 19 Nov 2021 14:58:43 -0700 Subject: [PATCH 07/10] Better error when not running on worker Whenver I forget to switch to https://github.com/dask/distributed/pull/5520, the errors are confusing. --- distributed/shuffle/graph.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/graph.py b/distributed/shuffle/graph.py index a338e75a6e6..db7ab7efde7 100644 --- a/distributed/shuffle/graph.py +++ b/distributed/shuffle/graph.py @@ -18,7 +18,20 @@ def get_shuffle_extension() -> ShuffleWorkerExtension: from distributed import get_worker - return get_worker().extensions["shuffle"] + try: + worker = get_worker() + except ValueError as e: + raise RuntimeError( + "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; " + "please confirm that you've created a distributed Client and are submitting this computation through it." + ) from e + extension: ShuffleWorkerExtension | None = worker.extensions.get("shuffle") + if not extension: + raise RuntimeError( + f"The worker {worker.address} does not have a ShuffleExtension. " + "Is pandas installed on the worker?" + ) + return extension def shuffle_transfer( From e3170aef09428a3ca92c2b19baceee4fff03aac6 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Fri, 19 Nov 2021 15:02:13 -0700 Subject: [PATCH 08/10] Add responses to comments from 9b9a68b --- distributed/shuffle/shuffle_worker.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/distributed/shuffle/shuffle_worker.py b/distributed/shuffle/shuffle_worker.py index ab35c4b7f5f..9ded0a294ef 100644 --- a/distributed/shuffle/shuffle_worker.py +++ b/distributed/shuffle/shuffle_worker.py @@ -112,8 +112,7 @@ async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None: state.barrier_reached = True if not state.out_parts_left: - # No output partitions, remove shuffle it now: - # `get_output_partition` will never be called. + # No output partitions, remove shuffle now: `get_output_partition` will never be called. # This happens when there are fewer output partitions than workers. self.remove(id) @@ -255,9 +254,11 @@ async def send_partition( ) -> None: "Split up an input partition and send its parts to peers." tasks = [] - # TODO grouping is blocking, should it be offloaded to a thread? - # It mostly doesn't release the GIL though, so may not make much difference. + # NOTE: `groupby` blocks the event loop, but it also holds the GIL, + # so we don't bother offloading to a thread. See bpo-7946. for output_partition, data in data.groupby(column): + # NOTE: `column` must refer to an integer column, which is the output partition number for the row. + # This is always `_partitions`, added by `dask/dataframe/shuffle.py::shuffle`. addr = worker_for(int(output_partition), npartitions, workers) task = asyncio.create_task( self.worker.rpc(addr).shuffle_receive( @@ -268,7 +269,9 @@ async def send_partition( ) tasks.append(task) - # TODO handle errors and cancellation here + # TODO Once RerunGroup logic exists (https://github.com/dask/distributed/issues/5403), + # handle errors and cancellation here in a way that lets other workers cancel & clean up their shuffles. + # Without it, letting errors kill the task is all we can do. await asyncio.gather(*tasks) @property From 9315e4c78a3722c936d5db0fc3918d63f12a6ce4 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Fri, 19 Nov 2021 15:34:29 -0700 Subject: [PATCH 09/10] Remove waiting for shuffle_init See https://github.com/dask/distributed/pull/5524#discussion_r752647896. Since messages from scheduler to workers remain ordered in `BatchedSend` (and TCP preserves ordering), we should be able to count on the `shuffle_init` always hitting the worker before the `add_partition` does, so long as we trust the transition logic of our plugin. --- distributed/shuffle/graph.py | 13 ++++---- distributed/shuffle/shuffle_worker.py | 42 +++++++------------------ distributed/shuffle/tests/test_graph.py | 1 - 3 files changed, 17 insertions(+), 39 deletions(-) diff --git a/distributed/shuffle/graph.py b/distributed/shuffle/graph.py index db7ab7efde7..69a797fe9d4 100644 --- a/distributed/shuffle/graph.py +++ b/distributed/shuffle/graph.py @@ -41,18 +41,17 @@ def shuffle_transfer( ext.sync(ext.add_partition(data, id, npartitions, column)) -def shuffle_unpack( - id: ShuffleId, i: int, empty: pd.DataFrame, barrier=None -) -> pd.DataFrame: - ext = get_shuffle_extension() - return ext.sync(ext.get_output_partition(id, i, empty)) - - def shuffle_barrier(id: ShuffleId, transfers: list[None]) -> None: ext = get_shuffle_extension() ext.sync(ext.barrier(id)) +def shuffle_unpack( + id: ShuffleId, i: int, empty: pd.DataFrame, barrier=None +) -> pd.DataFrame: + return get_shuffle_extension().get_output_partition(id, i, empty) + + def rearrange_by_column_p2p( df: DataFrame, column: str, diff --git a/distributed/shuffle/shuffle_worker.py b/distributed/shuffle/shuffle_worker.py index 9ded0a294ef..d93de519e3a 100644 --- a/distributed/shuffle/shuffle_worker.py +++ b/distributed/shuffle/shuffle_worker.py @@ -29,7 +29,6 @@ class ShuffleWorkerExtension: "Extend the Worker with routes and state for peer-to-peer shuffles" worker: Worker shuffles: dict[ShuffleId, ShuffleState] - waiting_for_metadata: dict[ShuffleId, asyncio.Event] output_data: defaultdict[ShuffleId, defaultdict[int, list[pd.DataFrame]]] def __init__(self, worker: Worker) -> None: @@ -42,7 +41,6 @@ def __init__(self, worker: Worker) -> None: # Initialize self.worker: Worker = worker self.shuffles = {} - self.waiting_for_metadata = {} self.output_data = defaultdict(lambda: defaultdict(list)) # Handlers @@ -63,11 +61,6 @@ def shuffle_init(self, id: ShuffleId, workers: list[str], n_out_tasks: int) -> N n_out_tasks, npartitions_for(self.worker.address, n_out_tasks, workers), ) - try: - # Invariant: if `waiting_for_metadata` event is set, key is already in `shuffles` - self.waiting_for_metadata[id].set() - except KeyError: - pass def shuffle_receive( self, @@ -99,7 +92,7 @@ def shuffle_receive( self.output_data[id][output_partition].append(data) - async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None: + def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None: """ Handler: note that the barrier task has been reached. Called by a peer. @@ -107,7 +100,7 @@ async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None: Must be called exactly once per ID. Blocks until `shuffle_init` has been called. """ - state = await self.get_shuffle(id) + state = self.get_shuffle(id) assert not state.barrier_reached, f"`inputs_done` called again for {id}" state.barrier_reached = True @@ -132,7 +125,7 @@ async def add_partition( Calling after the barrier task is an error. """ # Block until scheduler has called init - state = await self.get_shuffle(id) + state = self.get_shuffle(id) assert not state.barrier_reached, f"`add_partition` for {id} after barrier" if npartitions != state.npartitions: @@ -155,7 +148,7 @@ async def barrier(self, id: ShuffleId) -> None: Blocks until `shuffle_init` has been called (on all workers). Calling this before all partitions have been added will cause `add_partition` to fail. """ - state = await self.get_shuffle(id) + state = self.get_shuffle(id) assert not state.barrier_reached, f"`barrier` for {id} called multiple times" # Call `shuffle_inputs_done` on peers. @@ -168,7 +161,7 @@ async def barrier(self, id: ShuffleId) -> None: ), ) - async def get_output_partition( + def get_output_partition( self, id: ShuffleId, i: int, empty: pd.DataFrame ) -> pd.DataFrame: """ @@ -179,7 +172,7 @@ async def get_output_partition( Using an unknown ``shuffle_id`` is an error. Requesting a partition which doesn't belong on this worker, or has already been retrieved, is an error. """ - state = await self.get_shuffle(id) # should never have to wait + state = self.get_shuffle(id) assert state.barrier_reached, f"`get_output_partition` for {id} before barrier" assert ( state.out_parts_left > 0 @@ -217,32 +210,19 @@ def remove(self, id: ShuffleId) -> None: not state.out_parts_left ), f"Removed {id} with {state.out_parts_left} outputs left" - event = self.waiting_for_metadata.pop(id, None) - if event: - assert event.is_set(), f"Removed {id} while still waiting for metadata" - data = self.output_data.pop(id, None) assert ( not data ), f"Removed {id}, which still has data for output partitions {list(data)}" - async def get_shuffle(self, id: ShuffleId) -> ShuffleState: - "Get the `ShuffleState`, blocking until it's been received from the scheduler." + def get_shuffle(self, id: ShuffleId) -> ShuffleState: + "Get the `ShuffleState` by ID, raise ValueError if it's not registered." try: return self.shuffles[id] except KeyError: - event = self.waiting_for_metadata.setdefault(id, asyncio.Event()) - try: - await asyncio.wait_for(event.wait(), timeout=5) # TODO config - except TimeoutError: - raise TimeoutError( - f"Timed out waiting for scheduler to start shuffle {id}" - ) from None - # Invariant: once `waiting_for_metadata` event is set, key is already in `shuffles`. - # And once key is in `shuffles`, no `get_shuffle` will create a new event. - # So we can safely remove the event now. - self.waiting_for_metadata.pop(id, None) - return self.shuffles[id] + raise ValueError( + f"Shuffle {id!r} is not registered on worker {self.worker.address}" + ) from None async def send_partition( self, diff --git a/distributed/shuffle/tests/test_graph.py b/distributed/shuffle/tests/test_graph.py index 98e1c99f217..048ea949788 100644 --- a/distributed/shuffle/tests/test_graph.py +++ b/distributed/shuffle/tests/test_graph.py @@ -93,7 +93,6 @@ async def test_basic_state(c: Client, s: Scheduler, *workers: Worker): for ext in exts: assert not ext.shuffles assert not ext.output_data - assert not ext.waiting_for_metadata plugin = s.plugins[ShuffleSchedulerPlugin.name] assert isinstance(plugin, ShuffleSchedulerPlugin) From 04833a3899a7bf1de0fa8edca9a7ce9c97fc273c Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Fri, 19 Nov 2021 16:08:40 -0700 Subject: [PATCH 10/10] Reuse some tests from #5520 --- distributed/shuffle/tests/test_common.py | 51 +++++ .../shuffle/tests/test_shuffle_worker.py | 197 ++++++++++++++++++ 2 files changed, 248 insertions(+) create mode 100644 distributed/shuffle/tests/test_common.py create mode 100644 distributed/shuffle/tests/test_shuffle_worker.py diff --git a/distributed/shuffle/tests/test_common.py b/distributed/shuffle/tests/test_common.py new file mode 100644 index 00000000000..d99ef4e0e3b --- /dev/null +++ b/distributed/shuffle/tests/test_common.py @@ -0,0 +1,51 @@ +import string +from collections import Counter + +import pytest + +from ..common import npartitions_for, partition_range, worker_for + + +@pytest.mark.parametrize("npartitions", [1, 2, 3, 5]) +@pytest.mark.parametrize("n_workers", [1, 2, 3, 5]) +def test_worker_for_distribution(npartitions: int, n_workers: int): + "Test that `worker_for` distributes evenly" + workers = list(string.ascii_lowercase[:n_workers]) + + with pytest.raises(IndexError, match="Negative"): + worker_for(-1, npartitions, workers) + + assignments = [worker_for(i, npartitions, workers) for i in range(npartitions)] + + # Test `partition_range` + for w in workers: + first, last = partition_range(w, npartitions, workers) + assert all( + [ + first <= p_i <= last if a == w else p_i < first or p_i > last + for p_i, a in enumerate(assignments) + ] + ) + + counter = Counter(assignments) + assert len(counter) == min(npartitions, n_workers) + + # Test `npartitions_for` + calculated_counter = {w: npartitions_for(w, npartitions, workers) for w in workers} + assert counter == { + w: count for w, count in calculated_counter.items() if count != 0 + } + assert calculated_counter.keys() == set(workers) + # ^ this also checks that workers receiving 0 output partitions were calculated properly + + # Test the distribution of worker assignments. + # All workers should be assigned the same number of partitions, or if + # there's an odd number, some workers will be assigned only one extra partition. + counts = set(counter.values()) + assert len(counts) <= 2 + if len(counts) == 2: + lo, hi = sorted(counts) + assert lo == hi - 1 + + with pytest.raises(IndexError, match="does not exist"): + worker_for(npartitions, npartitions, workers) diff --git a/distributed/shuffle/tests/test_shuffle_worker.py b/distributed/shuffle/tests/test_shuffle_worker.py new file mode 100644 index 00000000000..ef87ee58d87 --- /dev/null +++ b/distributed/shuffle/tests/test_shuffle_worker.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +import pandas as pd +import pytest +from pandas.testing import assert_frame_equal + +from distributed.utils_test import gen_cluster + +from ..common import ShuffleId, npartitions_for, worker_for +from ..shuffle_worker import ShuffleState, ShuffleWorkerExtension + +if TYPE_CHECKING: + from distributed import Client, Scheduler, Worker + + +@gen_cluster([("", 1)]) +async def test_installation(s: Scheduler, worker: Worker): + ext = worker.extensions["shuffle"] + assert isinstance(ext, ShuffleWorkerExtension) + assert worker.stream_handlers["shuffle_init"] == ext.shuffle_init + assert worker.handlers["shuffle_receive"] == ext.shuffle_receive + assert worker.handlers["shuffle_inputs_done"] == ext.shuffle_inputs_done + + assert ext.worker is worker + assert not ext.shuffles + assert not ext.output_data + + +@gen_cluster([("", 1)]) +async def test_init(s: Scheduler, worker: Worker): + ext: ShuffleWorkerExtension = worker.extensions["shuffle"] + assert not ext.shuffles + + id = ShuffleId("foo") + workers = [worker.address, "tcp://foo"] + npartitions = 4 + + ext.shuffle_init(id, workers, npartitions) + assert ext.shuffles == { + id: ShuffleState(workers, npartitions, 2, barrier_reached=False) + } + + with pytest.raises(ValueError, match="already registered"): + ext.shuffle_init(id, [], 0) + + # Unchanged after trying to re-register + assert list(ext.shuffles) == [id] + + +@gen_cluster([("", 1)] * 4) +async def test_add_partition(s: Scheduler, *workers: Worker): + exts: dict[str, ShuffleWorkerExtension] = { + w.address: w.extensions["shuffle"] for w in workers + } + + id = ShuffleId("foo") + npartitions = 8 + addrs = list(exts) + column = "partition" + + for ext in exts.values(): + ext.shuffle_init(id, addrs, npartitions) + + partition = pd.DataFrame( + { + "A": ["a", "b", "c", "d", "e", "f", "g", "h"], + column: [0, 1, 2, 3, 4, 5, 6, 7], + } + ) + + ext = exts[addrs[0]] + await ext.add_partition(partition, id, npartitions, column) + + for i, data in partition.groupby(column): + i = int(i) + addr = worker_for(i, npartitions, addrs) + ext = exts[addr] + received = ext.output_data[id][i] + assert len(received) == 1 + assert_frame_equal(data, received[0]) + + with pytest.raises(ValueError, match="not registered"): + await ext.add_partition(partition, ShuffleId("bar"), npartitions, column) + + # TODO (resilience stage) test failed sends + + +@gen_cluster([("", 1)] * 4, client=True) +async def test_barrier(c: Client, s: Scheduler, *workers: Worker): + exts: dict[str, ShuffleWorkerExtension] = { + w.address: w.extensions["shuffle"] for w in workers + } + + id = ShuffleId("foo") + npartitions = 3 + addrs = list(exts) + column = "partition" + + for ext in exts.values(): + ext.shuffle_init(id, addrs, npartitions) + + partition = pd.DataFrame( + { + "A": ["a", "b", "c"], + column: [0, 1, 2], + } + ) + first_ext = exts[addrs[0]] + await first_ext.add_partition(partition, id, npartitions, column) + + await first_ext.barrier(id) + + # Check all workers have been informed of the barrier + for addr, ext in exts.items(): + if npartitions_for(addr, npartitions, addrs): + assert ext.shuffles[id].barrier_reached + else: + # No output partitions on this worker; shuffle already cleaned up + assert not ext.shuffles + assert not ext.output_data + + # Test check on self + with pytest.raises(AssertionError, match="called multiple times"): + await first_ext.barrier(id) + + first_ext.shuffles[id].barrier_reached = False + + # RPC to other workers fails + with pytest.raises(AssertionError, match="`inputs_done` called again"): + await first_ext.barrier(id) + + +@gen_cluster([("", 1)] * 4, client=True) +async def test_get_partition(c: Client, s: Scheduler, *workers: Worker): + exts: dict[str, ShuffleWorkerExtension] = { + w.address: w.extensions["shuffle"] for w in workers + } + + id = ShuffleId("foo") + npartitions = 8 + addrs = list(exts) + column = "partition" + + for ext in exts.values(): + ext.shuffle_init(id, addrs, npartitions) + + p1 = pd.DataFrame( + { + "A": ["a", "b", "c", "d", "e", "f", "g", "h"], + "partition": [0, 1, 2, 3, 4, 5, 6, 6], + } + ) + p2 = pd.DataFrame( + { + "A": ["a", "b", "c", "d", "e", "f", "g", "h"], + "partition": [0, 1, 2, 3, 0, 0, 2, 3], + } + ) + + first_ext = exts[addrs[0]] + await asyncio.gather( + first_ext.add_partition(p1, id, npartitions, column), + first_ext.add_partition(p2, id, npartitions, column), + ) + await first_ext.barrier(id) + + empty = pd.DataFrame({"A": [], column: []}) + + with pytest.raises(AssertionError, match="was expected to go"): + first_ext.get_output_partition(id, 7, empty) + + full = pd.concat([p1, p2]) + expected_groups = full.groupby("partition") + for output_i in range(npartitions): + addr = worker_for(output_i, npartitions, addrs) + ext = exts[addr] + shuffle = ext.shuffles[id] + parts_left_before = shuffle.out_parts_left + + result = ext.get_output_partition(id, output_i, empty) + + try: + expected = expected_groups.get_group(output_i) + except KeyError: + expected = empty + assert_frame_equal(expected, result) + assert shuffle.out_parts_left == parts_left_before - 1 + + # Once all partitions are retrieved, shuffles are cleaned up + for ext in exts.values(): + assert not ext.shuffles + + with pytest.raises(ValueError, match="not registered"): + first_ext.get_output_partition(id, 0, empty)