Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error on cancelled tasks due to disconnect #8705

Merged
merged 11 commits into from
Jun 24, 2024
63 changes: 35 additions & 28 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,12 @@ async def _result(self, raiseit=True):
raise exc.with_traceback(tb)
else:
return exc
elif self.status == "cancelled":
exception = CancelledError(self.key)
elif self.cancelled():
if self._state:
exception = self._state.exception
assert isinstance(exception, CancelledError)
else:
exception = CancelledError(self.key)
if raiseit:
raise exception
else:
Expand Down Expand Up @@ -414,15 +418,15 @@ def execute_callback(fut):
done_callback, self, partial(cls._cb_executor.submit, execute_callback)
)

def cancel(self, **kwargs):
def cancel(self, msg=None, **kwargs):
"""Cancel the request to run this future

See Also
--------
Client.cancel
"""
self._verify_initialized()
return self.client.cancel([self], **kwargs)
return self.client.cancel([self], msg=msg, **kwargs)

def retry(self, **kwargs):
"""Retry this future if it has failed
Expand Down Expand Up @@ -556,7 +560,9 @@ class FutureState:

def __init__(self):
    # A new state starts out pending; cancel() and finish() change the
    # status later on.
    self.status = "pending"
    # Nothing is recorded up front: the event is created on demand by
    # _get_event(), and the exception is filled in when the state is
    # cancelled.
    self._event = self.exception = self.traceback = self.type = None

def _get_event(self):
Expand All @@ -568,10 +574,10 @@ def _get_event(self):
event = self._event = asyncio.Event()
return event

def cancel(self):
def cancel(self, msg=None):
"""Cancels the operation"""
self.status = "cancelled"
self.exception = CancelledError()
self.exception = CancelledError(msg) if msg else CancelledError()
self._get_event().set()

def finish(self, type=None):
Expand Down Expand Up @@ -1321,7 +1327,10 @@ async def _reconnect(self):
self.scheduler_comm = None

for st in self.futures.values():
st.cancel()
st.cancel(
"Cancelled because the client lost the connection to the scheduler. "
"Please check your connection and re-run your work."
)
self.futures.clear()

timeout = self._timeout
Expand Down Expand Up @@ -2220,19 +2229,15 @@ async def wait(k):

exceptions = set()
bad_keys = set()
for key in keys:
if key not in self.futures or self.futures[key].status in failed:
for future in future_set:
key = future.key
if key not in self.futures or future.status in failed:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As opposed to futures_of and wait, we only raise a single exception here. We should probably align this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should do this in a follow-up PR; this code path mixes cancellations with actual errors.

exceptions.add(key)
if errors == "raise":
try:
st = self.futures[key]
exception = st.exception
traceback = st.traceback
except (KeyError, AttributeError):
exc = CancelledError(key)
else:
raise exception.with_traceback(traceback)
raise exc
st = future._state
exception = st.exception
traceback = st.traceback
raise exception.with_traceback(traceback)
if errors == "skip":
bad_keys.add(key)
bad_data[key] = None
Expand Down Expand Up @@ -2602,16 +2607,16 @@ def scatter(
hash=hash,
)

async def _cancel(self, futures, force=False):
async def _cancel(self, futures, msg=None, force=False):
# FIXME: This method is asynchronous since interacting with the FutureState below requires an event loop.
keys = list({f.key for f in futures_of(futures)})
self._send_to_scheduler({"op": "cancel-keys", "keys": keys, "force": force})
for k in keys:
st = self.futures.pop(k, None)
if st is not None:
st.cancel()
st.cancel(msg=msg)

def cancel(self, futures, asynchronous=None, force=False):
def cancel(self, futures, asynchronous=None, force=False, msg=None):
"""
Cancel running futures
This stops future tasks from being scheduled if they have not yet run
Expand All @@ -2626,8 +2631,12 @@ def cancel(self, futures, asynchronous=None, force=False):
If True the client is in asynchronous mode
force : boolean (False)
Cancel this future even if other clients desire it
msg : str
Message that will be attached to the cancelled future
"""
return self.sync(self._cancel, futures, asynchronous=asynchronous, force=force)
return self.sync(
self._cancel, futures, asynchronous=asynchronous, force=force, msg=msg
)

async def _retry(self, futures):
keys = list({f.key for f in futures_of(futures)})
Expand Down Expand Up @@ -5445,7 +5454,7 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED):
{fu for fu in fs if fu.status != "pending"},
{fu for fu in fs if fu.status == "pending"},
)
cancelled = [f.key for f in done if f.status == "cancelled"]
cancelled = {f.key: f._state.exception for f in done if f.cancelled()}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This dict might become very large. Ideally, we'd group the entries by the reason for cancellation.

if cancelled:
raise CancelledError(cancelled)

Expand Down Expand Up @@ -5678,8 +5687,6 @@ def _get_and_raise(self):
if self.raise_errors and future.status == "error":
typ, exc, tb = result
raise exc.with_traceback(tb)
elif future.status == "cancelled":
res = (res[0], CancelledError(future.key))
return res

def __next__(self):
Expand Down Expand Up @@ -5891,9 +5898,9 @@ def futures_of(o, client=None):
stack.extend(x.__dask_graph__().values())

if client is not None:
bad = {f for f in futures if f.cancelled()}
if bad:
raise CancelledError(bad)
cancelled = {f.key: f._state.exception for f in futures if f.cancelled()}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This dict might become very large. Ideally, we'd group the entries by the reason for cancellation.

if cancelled:
raise CancelledError(cancelled)

return futures[::-1]

Expand Down
24 changes: 24 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8566,3 +8566,27 @@ async def test_gather_race_vs_AMM(c, s, a, direct):
b.block_get_data.set()

assert await fut == 3 # It's from a; it would be 2 if it were from b


@gen_cluster(client=True)
async def test_client_disconnect_exception_on_cancelled_futures(c, s, a, b):
    """Check the error reported for a future after the scheduler is closed.

    Once the scheduler has been closed, every way of consuming the future
    (``wait``, awaiting it, ``gather``, ``futures_of`` and ``as_completed``)
    is expected to surface a ``CancelledError`` whose message mentions the
    lost connection to the scheduler.
    """
    # Fragment of the cancellation message that every access path must show.
    expected = "connection to the scheduler"

    fut = c.submit(inc, 1)
    await wait(fut)

    await s.close()

    # Waiting on the future raises.
    with pytest.raises(CancelledError, match=expected):
        await wait(fut)

    # Awaiting the future directly raises.
    with pytest.raises(CancelledError, match=expected):
        await fut

    # Gathering through the client raises.
    with pytest.raises(CancelledError, match=expected):
        await c.gather([fut])

    # Collecting futures with a client attached raises.
    with pytest.raises(CancelledError, match=expected):
        futures_of(fut, client=c)

    # as_completed hands the exception back as the result instead of raising.
    async for _done, outcome in as_completed([fut], with_results=True):
        assert isinstance(outcome, CancelledError)
        assert expected in outcome.args[0]
Loading