WIP co-assign related root-ish tasks #4899
@@ -950,6 +950,9 @@ class TaskGroup:
_start: double
_stop: double
_all_durations: object
_last_worker: WorkerState
_last_worker_tasks_left: int  # TODO Py_ssize_t?
_last_worker_priority: tuple  # TODO remove (debugging only)

def __init__(self, name: str):
self._name = name
@@ -964,6 +967,9 @@ def __init__(self, name: str):
self._start = 0.0
self._stop = 0.0
self._all_durations = defaultdict(float)
self._last_worker = None
self._last_worker_tasks_left = 0
self._last_worker_priority = ()

@property
def name(self):
@@ -1009,6 +1015,26 @@ def start(self):
def stop(self):
return self._stop

@property
def last_worker(self):
return self._last_worker

@property
def last_worker_tasks_left(self):
return self._last_worker_tasks_left

@last_worker_tasks_left.setter
def last_worker_tasks_left(self, n: int):
self._last_worker_tasks_left = n

@property
def last_worker_priority(self):
return self._last_worker_priority

@last_worker_priority.setter
def last_worker_priority(self, x: tuple):
self._last_worker_priority = x

@ccall
def add(self, o):
ts: TaskState = o
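For orientation, here is a minimal pure-Python sketch of the per-group bookkeeping these new properties expose. The real `TaskGroup` is a Cython `@cclass` with many more fields; `TaskGroupSketch` is just an illustrative name, not part of the diff.

```python
class TaskGroupSketch:
    """Toy stand-in for the three fields this diff adds to TaskGroup."""

    def __init__(self, name: str):
        self.name = name
        # Worker most recently assigned a task from this group.
        self.last_worker = None
        # How many more tasks from this group may still go to that worker.
        self.last_worker_tasks_left = 0
        # Priority of the last task sent there, used to detect
        # out-of-priority-order calls (a debugging aid, per the TODO above).
        self.last_worker_priority = ()
```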
@@ -2337,14 +2363,20 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
ts.state = "no-worker"
return ws

if ts._dependencies or valid_workers is not None:
if (
ts._dependencies
or valid_workers is not None
or ts._group._last_worker is not None
):
ws = decide_worker(
ts,
self._workers_dv.values(),
valid_workers,
partial(self.worker_objective, ts),
self._total_nthreads,
)
else:
# Fastpath when there are no related tasks or restrictions
worker_pool = self._idle or self._workers
worker_pool_dv = cast(dict, worker_pool)
wp_vals = worker_pool.values()
@@ -2366,6 +2398,15 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
else:  # dumb but fast in large case
ws = wp_vals[self._n_tasks % n_workers]

ts._group._last_worker = ws
group_tasks_per_thread = (
len(ts._group) / self._total_nthreads if self._total_nthreads > 0 else 0
)
ts._group._last_worker_tasks_left = (
math.floor(group_tasks_per_thread * ws._nthreads) - 1
)
ts._group._last_worker_priority = ts._priority

if self._validate:
assert ws is None or isinstance(ws, WorkerState), (
type(ws),
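A rough worked example of the `_last_worker_tasks_left` arithmetic in this hunk, with made-up numbers purely for illustration: 100 tasks in the group, 16 threads cluster-wide, and 4 threads on the chosen worker give that worker a "share" of 25 tasks, one of which was just assigned.

```python
import math

len_group = 100      # len(ts._group): tasks in the group (hypothetical)
total_nthreads = 16  # threads across the cluster (hypothetical)
ws_nthreads = 4      # threads on the worker just picked (hypothetical)

group_tasks_per_thread = len_group / total_nthreads if total_nthreads > 0 else 0
# The worker's share of the group, minus the task assigned right now:
last_worker_tasks_left = math.floor(group_tasks_per_thread * ws_nthreads) - 1

print(group_tasks_per_thread, last_worker_tasks_left)  # 6.25 24
```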
@@ -4671,6 +4712,9 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True):
recommendations[ts._key] = "released"
else:  # pure data
recommendations[ts._key] = "forgotten"
if ts._group._last_worker is ws:
ts._group._last_worker = None
ts._group._last_worker_tasks_left = 0
ws._has_what.clear()

self.transitions(recommendations)
@@ -6244,8 +6288,9 @@ async def retire_workers(
logger.info("Retire workers %s", workers)

# Keys orphaned by retiring those workers
keys = {k for w in workers for k in w.has_what}
keys = {ts._key for ts in keys if ts._who_has.issubset(workers)}
tasks = {ts for w in workers for ts in w.has_what}
keys = {ts._key for ts in tasks if ts._who_has.issubset(workers)}
groups = {ts._group for ts in tasks}

if keys:
other_workers = set(parent._workers_dv.values()) - workers
@@ -6260,6 +6305,11 @@ async def retire_workers(
lock=False,
)

for group in groups:
if group._last_worker in workers:
group._last_worker = None
group._last_worker_tasks_left = 0

worker_keys = {ws._address: ws.identity() for ws in workers}
if close_workers:
await asyncio.gather(
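Both `remove_worker` and `retire_workers` now apply the same invalidation. A hedged standalone sketch of that pattern follows; the names are illustrative, not the scheduler's API.

```python
def forget_departing_last_workers(groups, departing_workers):
    """If a group's cached last worker is leaving the cluster, drop the cache
    so the next task from that group picks a fresh worker instead."""
    for group in groups:
        if group.last_worker in departing_workers:
            group.last_worker = None
            group.last_worker_tasks_left = 0
```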
@@ -7471,11 +7521,52 @@ def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
@cfunc
@exceptval(check=False)
def decide_worker(
ts: TaskState, all_workers, valid_workers: set, objective
ts: TaskState,
all_workers,
valid_workers: set,
objective,
total_nthreads: Py_ssize_t,
) -> WorkerState:
"""
r"""
Decide which worker should take task *ts*.

There are two modes: root(ish) tasks, and normal tasks.

Root(ish) tasks
~~~~~~~~~~~~~~~

Root(ish) tasks have no (or very few) dependencies and fan out widely:
they belong to TaskGroups that contain more tasks than there are workers.
We want neighboring root tasks to run on the same worker, since there's a
good chance those neighbors will be combined in a downstream operation:

  i       j
 / \     / \
e   f   g   h
|   |   |   |
a   b   c   d
 \   \ /   /
      X

In the above case, we want ``a`` and ``b`` to run on the same worker,
and ``c`` and ``d`` to run on the same worker, reducing future
data transfer. We can also ignore the location of ``X``, because
as a common dependency, it will eventually get transferred everywhere.

Calculating this directly from the graph would be expensive, so instead
we use task priority as a proxy. We aim to send tasks that are close in
priority within a `TaskGroup` to the same worker. To do this efficiently,
we rely on the fact that `decide_worker` is generally called in priority
order for root tasks (because `Scheduler.update_graph` creates
recommendations in priority order), and track only the last worker used
for a `TaskGroup` and how many more tasks can be assigned to it before
picking a new one.

By colocating related root tasks, we ensure that placing their downstream
normal tasks is set up for success.

Normal tasks
~~~~~~~~~~~~

We choose the worker that has the data on which *ts* depends.

If several workers have dependencies then we choose the less-busy worker.
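To make the priority-as-proxy idea concrete, here is a toy simulation, independent of the scheduler internals, that visits 16 root tasks in priority order and fills each of four two-threaded workers up to its share of the group before moving on. All names and numbers are illustrative only.

```python
import math
from itertools import cycle

def assign_root_tasks(n_tasks: int, workers: dict) -> dict:
    """workers maps worker name -> nthreads; returns task index -> worker."""
    total_nthreads = sum(workers.values())
    tasks_per_thread = n_tasks / total_nthreads
    assignment = {}
    worker_iter = cycle(workers)
    current, left = None, 0
    for task in range(n_tasks):  # tasks visited in priority order
        if left <= 0:
            # Previous worker has received its share; move to the next one.
            current = next(worker_iter)
            left = math.floor(tasks_per_thread * workers[current])
        assignment[task] = current
        left -= 1
    return assignment

print(assign_root_tasks(16, {"w1": 2, "w2": 2, "w3": 2, "w4": 2}))
# tasks 0-3 land on w1, 4-7 on w2, 8-11 on w3, 12-15 on w4
```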
@@ -7488,36 +7579,83 @@ def decide_worker(
of bytes sent between workers. This is determined by calling the
*objective* function.
"""
ws: WorkerState = None
wws: WorkerState
dts: TaskState

group: TaskGroup = ts._group
ws: WorkerState = group._last_worker

if valid_workers is not None:
total_nthreads = sum(wws._nthreads for wws in valid_workers)

Review comments on the line above:

This walks through all workers for all tasks. We may not be able to do this.

See below; I believe

OK, grand.

group_tasks_per_thread = (len(group) / total_nthreads) if total_nthreads > 0 else 0
ignore_deps_while_picking: bool = False

# Try to schedule sibling root-like tasks on the same workers.
if (
ws is not None
and group._last_worker_priority is not None
# ^ `decide_worker` hasn't previously been called out of priority order
and group_tasks_per_thread > 1
and sum(map(len, group._dependencies)) < 5  # TODO what number
):
if group._last_worker_tasks_left > 0:
group._last_worker_tasks_left -= 1
if group._last_worker_priority < ts.priority and (
valid_workers is None or ws in valid_workers
):
group._last_worker_priority = ts.priority
return ws

# `decide_worker` called out of priority order, or the last used worker is not valid for this task.
# This is probably not actually a root-ish task; disable root-ish mode in the future.
group._last_worker = None
group._last_worker_tasks_left = 0
group._last_worker_priority = None

# Previous worker is fully assigned, so pick a new worker.
ignore_deps_while_picking = True

deps: set = ts._dependencies
dts: TaskState
candidates: set
assert all([dts._who_has for dts in deps])
if ts._actor:
candidates = set(all_workers)
if ignore_deps_while_picking:
candidates = valid_workers if valid_workers is not None else set(all_workers)
else:
candidates = {wws for dts in deps for wws in dts._who_has}
if valid_workers is None:
if not candidates:
if ts._actor:
candidates = set(all_workers)
else:
candidates &= valid_workers
if not candidates:
candidates = valid_workers
else:
candidates = {wws for dts in deps for wws in dts._who_has}
if valid_workers is None:
if not candidates:
if ts._loose_restrictions:
ws = decide_worker(ts, all_workers, None, objective)
return ws
candidates = set(all_workers)
else:
candidates &= valid_workers
if not candidates:
candidates = valid_workers
if not candidates:
if ts._loose_restrictions:
ws = decide_worker(
ts, all_workers, None, objective, total_nthreads
)
return ws

ncandidates: Py_ssize_t = len(candidates)
if ncandidates == 0:
pass
elif ncandidates == 1:
# NOTE: this is the ideal case: all the deps are already on the same worker.
for ws in candidates:
break
else:
ws = min(candidates, key=objective)

if group._last_worker_priority is not None:
group._last_worker = ws
group._last_worker_tasks_left = (
math.floor(group_tasks_per_thread * ws._nthreads) - 1
)
group._last_worker_priority = ts.priority
return ws
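Because the diff above is flattened, here is a condensed, hedged restatement of just the root-ish fastpath as self-contained Python. It omits the group-size and dependency-count guards, and the "previous worker is fully assigned" branch is folded into the `pick_new_worker` callback; `fastpath_decide` and its parameters are illustrative names, not the scheduler's API.

```python
from types import SimpleNamespace

def fastpath_decide(task_priority, group, valid_workers, pick_new_worker):
    """Toy restatement of the root-ish fastpath in the hunk above."""
    ws = group.last_worker
    if ws is not None and group.last_worker_priority is not None:
        if group.last_worker_tasks_left > 0:
            group.last_worker_tasks_left -= 1
            if group.last_worker_priority < task_priority and (
                valid_workers is None or ws in valid_workers
            ):
                # Still in priority order and the cached worker is allowed:
                # reuse it without looking at dependencies at all.
                group.last_worker_priority = task_priority
                return ws
            # Called out of priority order, or the cached worker is invalid:
            # permanently drop out of root-ish mode for this group.
            group.last_worker = None
            group.last_worker_tasks_left = 0
            group.last_worker_priority = None
    return pick_new_worker()

group = SimpleNamespace(last_worker="w1", last_worker_tasks_left=3,
                        last_worker_priority=(0, 1))
print(fastpath_decide((0, 2), group, None, lambda: "w2"))  # -> "w1"
```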
<3 the ascii art
Comment/question: Do we want to explain all of this here? Historically I haven't put the logic behind heuristics in the code. This is a subjective opinion, and far from universal, but I find that heavily commented/documented logic makes it harder to understand the code at a glance. I really like that the current decide_worker implementation fits in a terminal window. I think that single-line comments are cool, but that long multi-line comments would better be written as documentation.
Thoughts? If you are not in disagreement then I would encourage us to write up a small docpage or maybe a blogpost and then link to that external resource from the code.
I was also planning on updating https://distributed.dask.org/en/latest/scheduling-policies.html#choosing-workers, probably with this same ascii art. So just linking to that page in the docstring seems appropriate.