avoid Client._handle_report cancelling itself on Client._close (dask#5672)

In dask#4617 and dask#5666 an `asyncio.gather` call isn't correctly awaited
and logs the following unhandled asyncio error:

```
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
asyncio.exceptions.CancelledError
```

This exception happens because, on reconnect, `_close` cancels itself before
calling gather:
https://github.com/dask/distributed/blob/feac52b49292781e78beff8226407f3a5f2e563e/distributed/client.py#L1335-L1343
`_handle_report()` calls `_reconnect()`, which calls `_close()`, which then cancels the very task it is running in.

`self.coroutines` can only ever contain one task, `_handle_report`, so it can
be removed in favour of explicitly tracking the `_handle_report` task.
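Below is a minimal, self-contained sketch of the guard this change introduces (the `Worker` class here is hypothetical, not distributed's API): before cancelling or awaiting its background task, `close()` compares it against `asyncio.current_task()`, so the task can never end up cancelling or awaiting itself.

```python
import asyncio
import contextlib


class Worker:
    """Hypothetical stand-in for Client: it owns one background task."""

    def __init__(self):
        self._background_task = None

    async def start(self):
        self._background_task = asyncio.create_task(self._run())

    async def _run(self):
        await asyncio.sleep(0.05)
        # The background task may decide to close its own owner,
        # mirroring _handle_report() -> _reconnect() -> _close().
        await self.close()

    async def close(self):
        task = self._background_task
        # Skip the cancel/await when close() runs inside the task itself;
        # otherwise the task would cancel itself and never be awaited.
        if task is not None and task is not asyncio.current_task():
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task


async def main():
    w = Worker()
    await w.start()
    await asyncio.sleep(0.1)
    await w.close()  # safe from outside the task as well


asyncio.run(main())
```

In the actual change, the same comparison guards both the 100 ms shielded grace period and the final 2-second wait in `Client._close`.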
graingert authored and gjoseph92 committed Feb 1, 2022
1 parent 39ed017 commit 5166a4b
Showing 2 changed files with 46 additions and 21 deletions.
35 changes: 16 additions & 19 deletions distributed/client.py
@@ -768,7 +768,7 @@ def __init__(

         self.futures = dict()
         self.refcount = defaultdict(lambda: 0)
-        self.coroutines = []
+        self._handle_report_task = None
         if name is None:
             name = dask.config.get("client-name", None)
         self.id = (
@@ -1164,8 +1164,7 @@ async def _start(self, timeout=no_default, **kwargs):
         for topic, handler in Client._default_event_handlers.items():
             self.subscribe_topic(topic, handler)

-        self._handle_scheduler_coroutine = asyncio.ensure_future(self._handle_report())
-        self.coroutines.append(self._handle_scheduler_coroutine)
+        self._handle_report_task = asyncio.create_task(self._handle_report())

         return self

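As an aside, the hunk above also swaps `asyncio.ensure_future` for `asyncio.create_task`. A small standalone sketch (not distributed code) of how the two relate:

```python
import asyncio


async def handle_report():
    await asyncio.sleep(0)


async def main():
    # create_task accepts only a coroutine and requires a running event
    # loop, which is the situation inside an async _start(); ensure_future
    # additionally accepts futures and other awaitables.
    task = asyncio.create_task(handle_report())
    fut = asyncio.ensure_future(handle_report())
    assert isinstance(task, asyncio.Task)
    assert isinstance(fut, asyncio.Task)
    await asyncio.gather(task, fut)


asyncio.run(main())
```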
@@ -1466,12 +1465,16 @@ async def _close(self, fast=False):
            self._send_to_scheduler({"op": "close-client"})
            self._send_to_scheduler({"op": "close-stream"})

+        current_task = asyncio.current_task()
+        handle_report_task = self._handle_report_task
         # Give the scheduler 'stream-closed' message 100ms to come through
         # This makes the shutdown slightly smoother and quieter
-        with suppress(AttributeError, asyncio.CancelledError, TimeoutError):
-            await asyncio.wait_for(
-                asyncio.shield(self._handle_scheduler_coroutine), 0.1
-            )
+        if (
+            handle_report_task is not None
+            and handle_report_task is not current_task
+        ):
+            with suppress(asyncio.CancelledError, TimeoutError):
+                await asyncio.wait_for(asyncio.shield(handle_report_task), 0.1)

         if (
             self.scheduler_comm
@@ -1494,19 +1497,13 @@
         if _get_global_client() is self:
             _set_global_client(None)

-        coroutines = set(self.coroutines)
-        for f in self.coroutines:
-            # cancel() works on asyncio futures (Tornado 5)
-            # but is a no-op on Tornado futures
-            with suppress(RuntimeError):
-                f.cancel()
-            if f.cancelled():
-                coroutines.remove(f)
-        del self.coroutines[:]
-
-        if not fast:
+        if (
+            not fast
+            and handle_report_task is not None
+            and handle_report_task is not current_task
+        ):
             with suppress(TimeoutError, asyncio.CancelledError):
-                await asyncio.wait_for(asyncio.gather(*coroutines), 2)
+                await asyncio.wait_for(handle_report_task, 2)

         with suppress(AttributeError):
             await self.scheduler.close_rpc()
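The `_close` hunks above give the report task a short, shielded grace period. A standalone sketch (not distributed code) of why `asyncio.shield` is paired with `asyncio.wait_for`: the timeout cancels only the shield wrapper, while the underlying task keeps running.

```python
import asyncio
import contextlib


async def background():
    await asyncio.sleep(10)


async def main():
    task = asyncio.create_task(background())
    # wait_for cancels whatever it awaits when the timeout expires; the
    # shield absorbs that cancellation so the real task keeps running.
    with contextlib.suppress(asyncio.TimeoutError):
        await asyncio.wait_for(asyncio.shield(task), 0.1)
    assert not task.cancelled()  # the background task is still alive
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task


asyncio.run(main())
```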
32 changes: 30 additions & 2 deletions distributed/tests/test_client.py
@@ -16,11 +16,13 @@
 import weakref
 import zipfile
 from collections import deque
-from contextlib import suppress
+from collections.abc import Generator
+from contextlib import contextmanager, suppress
 from functools import partial
 from operator import add
 from threading import Semaphore
 from time import sleep
+from typing import Any

 import psutil
 import pytest
@@ -3782,9 +3784,35 @@ def test_reconnect(loop):
         c.close()


+class UnhandledException(Exception):
+    pass
+
+
+@contextmanager
+def catch_unhandled_exceptions() -> Generator[None, None, None]:
+    loop = asyncio.get_running_loop()
+    ctx: dict[str, Any] | None = None
+
+    old_handler = loop.get_exception_handler()
+
+    @loop.set_exception_handler
+    def _(loop: object, context: dict[str, Any]) -> None:
+        nonlocal ctx
+        ctx = context
+
+    try:
+        yield
+    finally:
+        loop.set_exception_handler(old_handler)
+        if ctx:
+            raise UnhandledException(ctx["message"]) from ctx.get("exception")
+
+
 @gen_cluster(client=True, nthreads=[], client_kwargs={"timeout": 0.5})
 async def test_reconnect_timeout(c, s):
-    with captured_logger(logging.getLogger("distributed.client")) as logger:
+    with catch_unhandled_exceptions(), captured_logger(
+        logging.getLogger("distributed.client")
+    ) as logger:
         await s.close()
         while c.status != "closed":
             await c._update_scheduler_info()
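The `catch_unhandled_exceptions` helper above relies on the event loop's exception-handler hook. A minimal standalone sketch (separate from the test suite) of that mechanism:

```python
import asyncio


async def main():
    loop = asyncio.get_running_loop()
    old_handler = loop.get_exception_handler()
    captured = []
    # Anything the loop would normally only log (such as an exception that
    # was never retrieved from a task) is routed to this handler instead.
    loop.set_exception_handler(lambda loop, context: captured.append(context))
    try:
        loop.call_exception_handler({"message": "example unhandled error"})
    finally:
        loop.set_exception_handler(old_handler)
    assert captured[0]["message"] == "example unhandled error"


asyncio.run(main())
```

The helper installs such a handler for the duration of the test, restores the previous one, and re-raises any captured context as `UnhandledException`, so an error that asyncio would otherwise only log now fails the test.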
