Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error on cancelled tasks due to disconnect #8705

Merged
merged 11 commits into from
Jun 24, 2024
165 changes: 130 additions & 35 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,68 @@
TOPIC_PREFIX_FORWARDED_LOG_RECORD = "forwarded-log-record"


class FutureCancelledError(CancelledError):
key: str
reason: str
msg: str | None

def __init__(self, key: str, reason: str | None, msg: str | None = None):
self.key = key
self.reason = reason if reason else "unknown"
self.msg = msg

def __str__(self) -> str:
result = f"{self.key} cancelled for reason: {self.reason}."
if self.msg:
result = "\n".join([result, self.msg])
return result


class FuturesCancelledError(CancelledError):
error_groups: list[CancelledFuturesGroup]

def __init__(self, error_groups: list[CancelledFuturesGroup]):
self.error_groups = sorted(
error_groups, key=lambda group: len(group.errors), reverse=True
)

def __str__(self):
count = sum(map(lambda group: len(group.errors), self.error_groups))
result = f"{count} Future{'s' if count > 1 else ''} cancelled:"
return "\n".join(
[result, "Reasons:"] + [str(group) for group in self.error_groups]
)


class CancelledFuturesGroup:
#: Errors of the cancelled futures
errors: list[FutureCancelledError]

#: Reason for cancelling the futures
reason: str

__slots__ = tuple(__annotations__)

def __init__(self, errors: list[FutureCancelledError], reason: str):
self.errors = errors
self.reason = reason

def __str__(self):
keys = [error.key for error in self.errors]
example_message = None

for error in self.errors:
if error.msg:
example_message = error.msg
break

return (
f"{len(keys)} Future{'s' if len(keys) > 1 else ''} cancelled for reason: "
f"{self.reason}.\nMessage: {example_message}\n"
f"Future{'s' if len(keys) > 1 else ''}: {keys}"
)


class SourceCode(NamedTuple):
code: str
lineno_frame: int
Expand Down Expand Up @@ -245,7 +307,7 @@ def _bind_late(self):
if self.key in self._client.futures:
self._state = self._client.futures[self.key]
else:
self._state = self._client.futures[self.key] = FutureState()
self._state = self._client.futures[self.key] = FutureState(self.key)

if self._inform:
self._client._send_to_scheduler(
Expand Down Expand Up @@ -337,8 +399,10 @@ async def _result(self, raiseit=True):
raise exc.with_traceback(tb)
else:
return exc
elif self.status == "cancelled":
exception = CancelledError(self.key)
elif self.cancelled():
assert self._state
exception = self._state.exception
assert isinstance(exception, CancelledError)
if raiseit:
raise exception
else:
Expand Down Expand Up @@ -414,15 +478,15 @@ def execute_callback(fut):
done_callback, self, partial(cls._cb_executor.submit, execute_callback)
)

def cancel(self, **kwargs):
def cancel(self, reason=None, msg=None, **kwargs):
"""Cancel the request to run this future

See Also
--------
Client.cancel
"""
self._verify_initialized()
return self.client.cancel([self], **kwargs)
return self.client.cancel([self], reason=reason, msg=msg, **kwargs)

def retry(self, **kwargs):
"""Retry this future if it has failed
Expand Down Expand Up @@ -552,11 +616,14 @@ class FutureState:
This is shared between all Futures with the same key and client.
"""

__slots__ = ("_event", "status", "type", "exception", "traceback")
__slots__ = ("_event", "key", "status", "type", "exception", "traceback")

def __init__(self):
def __init__(self, key: str):
self._event = None
self.key = key
self.exception = None
self.status = "pending"
self.traceback = None
self.type = None

def _get_event(self):
Expand All @@ -568,10 +635,10 @@ def _get_event(self):
event = self._event = asyncio.Event()
return event

def cancel(self):
def cancel(self, reason=None, msg=None):
"""Cancels the operation"""
self.status = "cancelled"
self.exception = CancelledError()
self.exception = FutureCancelledError(key=self.key, reason=reason, msg=msg)
self._get_event().set()

def finish(self, type=None):
Expand Down Expand Up @@ -1321,7 +1388,13 @@ async def _reconnect(self):
self.scheduler_comm = None

for st in self.futures.values():
st.cancel()
st.cancel(
reason="scheduler-connection-lost",
msg=(
"Client lost the connection to the scheduler. "
"Please check your connection and re-run your work."
),
)
self.futures.clear()

timeout = self._timeout
Expand Down Expand Up @@ -1640,7 +1713,10 @@ def _handle_task_erred(self, key=None, exception=None, traceback=None):
def _handle_restart(self):
logger.info("Receive restart signal from scheduler")
for state in self.futures.values():
state.cancel()
state.cancel(
reason="scheduler-restart",
msg="Scheduler has restarted. Please re-run your work.",
)
self.futures.clear()
self.generation += 1
with self._refcount_lock:
Expand Down Expand Up @@ -2220,19 +2296,15 @@ async def wait(k):

exceptions = set()
bad_keys = set()
for key in keys:
if key not in self.futures or self.futures[key].status in failed:
for future in future_set:
key = future.key
if key not in self.futures or future.status in failed:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As opposed to futures_of and wait, we only raise a single exception here. We should probably align this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should do this in a follow-up PR, this mixes cancellations with actual errors.

exceptions.add(key)
if errors == "raise":
try:
st = self.futures[key]
exception = st.exception
traceback = st.traceback
except (KeyError, AttributeError):
exc = CancelledError(key)
else:
raise exception.with_traceback(traceback)
raise exc
st = future._state
exception = st.exception
traceback = st.traceback
raise exception.with_traceback(traceback)
if errors == "skip":
bad_keys.add(key)
bad_data[key] = None
Expand Down Expand Up @@ -2602,16 +2674,16 @@ def scatter(
hash=hash,
)

async def _cancel(self, futures, force=False):
async def _cancel(self, futures, reason=None, msg=None, force=False):
# FIXME: This method is asynchronous since interacting with the FutureState below requires an event loop.
keys = list({f.key for f in futures_of(futures)})
self._send_to_scheduler({"op": "cancel-keys", "keys": keys, "force": force})
for k in keys:
st = self.futures.pop(k, None)
if st is not None:
st.cancel()
st.cancel(reason=reason, msg=msg)

def cancel(self, futures, asynchronous=None, force=False):
def cancel(self, futures, asynchronous=None, force=False, reason=None, msg=None):
"""
Cancel running futures
This stops future tasks from being scheduled if they have not yet run
Expand All @@ -2626,8 +2698,14 @@ def cancel(self, futures, asynchronous=None, force=False):
If True the client is in asynchronous mode
force : boolean (False)
Cancel this future even if other clients desire it
reason: str
Reason for cancelling the futures
msg : str
Message that will be attached to the cancelled future
"""
return self.sync(self._cancel, futures, asynchronous=asynchronous, force=force)
return self.sync(
self._cancel, futures, asynchronous=asynchronous, force=force, msg=msg
)

async def _retry(self, futures):
keys = list({f.key for f in futures_of(futures)})
Expand Down Expand Up @@ -5445,9 +5523,19 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED):
{fu for fu in fs if fu.status != "pending"},
{fu for fu in fs if fu.status == "pending"},
)
cancelled = [f.key for f in done if f.status == "cancelled"]
if cancelled:
raise CancelledError(cancelled)
cancelled_errors = defaultdict(list)
for f in done:
if not f.cancelled():
continue
exception = f._state.exception
assert isinstance(exception, FutureCancelledError)
cancelled_errors[exception.reason].append(exception)
if cancelled_errors:
groups = [
CancelledFuturesGroup(errors=errors, reason=reason)
for reason, errors in cancelled_errors.items()
]
raise FuturesCancelledError(groups)

return DoneAndNotDoneFutures(done, not_done)

Expand Down Expand Up @@ -5678,8 +5766,6 @@ def _get_and_raise(self):
if self.raise_errors and future.status == "error":
typ, exc, tb = result
raise exc.with_traceback(tb)
elif future.status == "cancelled":
res = (res[0], CancelledError(future.key))
return res

def __next__(self):
Expand Down Expand Up @@ -5891,10 +5977,19 @@ def futures_of(o, client=None):
stack.extend(x.__dask_graph__().values())

if client is not None:
bad = {f for f in futures if f.cancelled()}
if bad:
raise CancelledError(bad)

cancelled_errors = defaultdict(list)
for f in futures:
if not f.cancelled():
continue
exception = f._state.exception
assert isinstance(exception, FutureCancelledError)
cancelled_errors[exception.reason].append(exception)
if cancelled_errors:
groups = [
CancelledFuturesGroup(errors=errors, reason=reason)
for reason, errors in cancelled_errors.items()
]
raise FuturesCancelledError(groups)
return futures[::-1]


Expand Down
41 changes: 41 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
from distributed.client import (
Client,
Future,
FutureCancelledError,
FuturesCancelledError,
_get_global_client,
_global_clients,
as_completed,
Expand Down Expand Up @@ -8566,3 +8568,42 @@ async def test_gather_race_vs_AMM(c, s, a, direct):
b.block_get_data.set()

assert await fut == 3 # It's from a; it would be 2 if it were from b


@gen_cluster(client=True)
async def test_client_disconnect_exception_on_cancelled_futures(c, s, a, b):
fut = c.submit(inc, 1)
await wait(fut)

await s.close()

with pytest.raises(FutureCancelledError, match="connection to the scheduler"):
await fut.result()

with pytest.raises(FuturesCancelledError, match="connection to the scheduler"):
await wait(fut)

with pytest.raises(FutureCancelledError, match="connection to the scheduler"):
await fut

with pytest.raises(FutureCancelledError, match="connection to the scheduler"):
await c.gather([fut])

with pytest.raises(FuturesCancelledError, match="connection to the scheduler"):
futures_of(fut, client=c)

async for fut, res in as_completed([fut], with_results=True):
assert isinstance(res, FutureCancelledError)
assert "connection to the scheduler" in res.msg


@pytest.mark.slow
@gen_cluster(client=True, Worker=Nanny, timeout=60)
async def test_scheduler_restart_exception_on_cancelled_futures(c, s, a, b):
fut = c.submit(inc, 1)
await wait(fut)

await s.restart(stimulus_id="test")

with pytest.raises(CancelledError, match="Scheduler has restarted"):
await fut.result()
Loading