diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index 8b9abce3353..96a047bc0a3 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -13,6 +13,8 @@ from distributed import Scheduler from distributed._signals import wait_for_signals +from distributed.compatibility import asyncio_run +from distributed.config import get_loop_factory from distributed.preloading import validate_preload_argv from distributed.proctitle import ( enable_proctitle_on_children, @@ -246,7 +248,7 @@ async def wait_for_signals_and_close(): logger.info("Stopped scheduler at %r", scheduler.address) try: - asyncio.run(run()) + asyncio_run(run(), loop_factory=get_loop_factory()) finally: logger.info("End scheduler") diff --git a/distributed/cli/dask_spec.py b/distributed/cli/dask_spec.py index f53c7fc9578..a09fbabdf98 100644 --- a/distributed/cli/dask_spec.py +++ b/distributed/cli/dask_spec.py @@ -7,6 +7,8 @@ import click import yaml +from distributed.compatibility import asyncio_run +from distributed.config import get_loop_factory from distributed.deploy.spec import run_spec @@ -39,7 +41,7 @@ async def run(): except KeyboardInterrupt: await asyncio.gather(*(w.close() for w in servers.values())) - asyncio.run(run()) + asyncio_run(run(), loop_factory=get_loop_factory()) if __name__ == "__main__": diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 91a824f4edc..2c4989a8861 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -21,6 +21,8 @@ from distributed import Nanny from distributed._signals import wait_for_signals from distributed.comm import get_address_host_port +from distributed.compatibility import asyncio_run +from distributed.config import get_loop_factory from distributed.deploy.utils import nprocesses_nthreads from distributed.preloading import validate_preload_argv from distributed.proctitle import ( @@ -443,7 +445,7 @@ async def wait_for_signals_and_close(): [task.result() for task in done] try: - asyncio.run(run()) + asyncio_run(run(), loop_factory=get_loop_factory()) except (TimeoutError, asyncio.TimeoutError): # We already log the exception in nanny / worker. Don't do it again. if not signal_fired: diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 070953eeb86..7c6bb5476eb 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -25,6 +25,8 @@ unparse_host_port, ) from distributed.comm.registry import backends, get_backend +from distributed.compatibility import asyncio_run +from distributed.config import get_loop_factory from distributed.metrics import time from distributed.protocol import Serialized, deserialize, serialize, to_serialize from distributed.utils import get_ip, get_ipv6, get_mp_context, wait_for @@ -438,7 +440,9 @@ async def run_with_timeout(): t = asyncio.create_task(func(*args, **kwargs)) return await wait_for(t, timeout=10) - return await asyncio.to_thread(asyncio.run, run_with_timeout()) + return await asyncio.to_thread( + asyncio_run, run_with_timeout(), loop_factory=get_loop_factory() + ) @gen_test() diff --git a/distributed/compatibility.py b/distributed/compatibility.py index 84bfcc74db0..f7b92d12c11 100644 --- a/distributed/compatibility.py +++ b/distributed/compatibility.py @@ -5,6 +5,8 @@ import random import sys import warnings +from collections.abc import Callable, Coroutine +from typing import Any, TypeVar import tornado @@ -48,7 +50,7 @@ def randbytes(*args, **kwargs): # takes longer than the interval import datetime import math - from collections.abc import Awaitable, Callable + from collections.abc import Awaitable from inspect import isawaitable from tornado.ioloop import IOLoop @@ -182,3 +184,84 @@ def _update_next(self, current_time: float) -> None: # time.monotonic(). # https://github.com/tornadoweb/tornado/issues/2333 self._next_timeout += callback_time_sec + + +_T = TypeVar("_T") + +if sys.version_info >= (3, 12): + asyncio_run = asyncio.run +elif sys.version_info >= (3, 11): + + def asyncio_run( + main: Coroutine[Any, Any, _T], + *, + debug: bool = False, + loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None, + ) -> _T: + # asyncio.run from Python 3.12 + # https://docs.python.org/3/license.html#psf-license + with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner: + return runner.run(main) + +else: + # modified version of asyncio.run from Python 3.10 to add loop_factory kwarg + # https://docs.python.org/3/license.html#psf-license + def asyncio_run( + main: Coroutine[Any, Any, _T], + *, + debug: bool = False, + loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None, + ) -> _T: + try: + asyncio.get_running_loop() + except RuntimeError: + pass + else: + raise RuntimeError( + "asyncio.run() cannot be called from a running event loop" + ) + + if not asyncio.iscoroutine(main): + raise ValueError(f"a coroutine was expected, got {main!r}") + + if loop_factory is None: + loop = asyncio.new_event_loop() + else: + loop = loop_factory() + try: + if loop_factory is None: + asyncio.set_event_loop(loop) + if debug is not None: + loop.set_debug(debug) + return loop.run_until_complete(main) + finally: + try: + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete(loop.shutdown_default_executor()) + finally: + if loop_factory is None: + asyncio.set_event_loop(None) + loop.close() + + def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None: + to_cancel = asyncio.all_tasks(loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) diff --git a/distributed/config.py b/distributed/config.py index 89b176012b1..bf9c3ebe01e 100644 --- a/distributed/config.py +++ b/distributed/config.py @@ -4,6 +4,7 @@ import logging.config import os import sys +from collections.abc import Callable from typing import Any import yaml @@ -177,7 +178,7 @@ def initialize_logging(config: dict[Any, Any]) -> None: _initialize_logging_old_style(config) -def initialize_event_loop(config: dict[Any, Any]) -> None: +def get_loop_factory() -> Callable[[], asyncio.AbstractEventLoop] | None: event_loop = dask.config.get("distributed.admin.event-loop") if event_loop == "uvloop": uvloop = import_required( @@ -189,19 +190,18 @@ def initialize_event_loop(config: dict[Any, Any]) -> None: " conda install uvloop\n" " pip install uvloop", ) - uvloop.install() - elif event_loop in {"asyncio", "tornado"}: + return uvloop.new_event_loop + if event_loop in {"asyncio", "tornado"}: if sys.platform == "win32": - # WindowsProactorEventLoopPolicy is not compatible with tornado 6 + # ProactorEventLoop is not compatible with tornado 6 # fallback to the pre-3.8 default of Selector # https://github.com/tornadoweb/tornado/issues/2608 - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - else: - raise ValueError( - "Expected distributed.admin.event-loop to be in ('asyncio', 'tornado', 'uvloop'), got %s" - % dask.config.get("distributed.admin.event-loop") - ) + return asyncio.SelectorEventLoop + return None + raise ValueError( + "Expected distributed.admin.event-loop to be in ('asyncio', 'tornado', 'uvloop'), got %s" + % dask.config.get("distributed.admin.event-loop") + ) initialize_logging(dask.config.config) -initialize_event_loop(dask.config.config) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index f5fcef967d5..d31252a64bd 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -14,7 +14,8 @@ from dask.system import CPU_COUNT from distributed import Client, LocalCluster, Nanny, Worker, get_client -from distributed.compatibility import LINUX +from distributed.compatibility import LINUX, asyncio_run +from distributed.config import get_loop_factory from distributed.core import Status from distributed.metrics import time from distributed.system import MEMORY_LIMIT @@ -670,7 +671,7 @@ async def amain(): box = cluster._cached_widget assert isinstance(box, ipywidgets.Widget) - asyncio.run(amain()) + asyncio_run(amain(), loop_factory=get_loop_factory()) def test_no_ipywidgets(loop, monkeypatch): diff --git a/distributed/nanny.py b/distributed/nanny.py index 512a3fde170..1bd467b27a5 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -28,6 +28,8 @@ from distributed import preloading from distributed.comm import get_address_host from distributed.comm.addressing import address_from_user_args +from distributed.compatibility import asyncio_run +from distributed.config import get_loop_factory from distributed.core import ( AsyncTaskGroupClosedError, CommClosedError, @@ -996,7 +998,7 @@ def close_stop_q() -> None: if silence_logs: logger.setLevel(silence_logs) - asyncio.run(run()) + asyncio_run(run(), loop_factory=get_loop_factory()) def _get_env_variables(config_key: str) -> dict[str, str]: diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index 790bfbda4ed..d3f1b0bd911 100644 --- a/distributed/tests/test_asyncprocess.py +++ b/distributed/tests/test_asyncprocess.py @@ -13,7 +13,8 @@ import pytest from tornado.ioloop import IOLoop -from distributed.compatibility import LINUX, MACOS, WINDOWS +from distributed.compatibility import LINUX, MACOS, WINDOWS, asyncio_run +from distributed.config import get_loop_factory from distributed.metrics import time from distributed.process import AsyncProcess from distributed.utils import get_mp_context, wait_for @@ -389,7 +390,7 @@ async def run_with_timeout(): t = asyncio.create_task(parent_process_coroutine()) return await wait_for(t, timeout=10) - asyncio.run(run_with_timeout()) + asyncio_run(run_with_timeout(), loop_factory=get_loop_factory()) raise RuntimeError("this should be unreachable due to os._exit") diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 497cc1dd35f..4c0728f68d0 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1365,15 +1365,13 @@ async def test_update_graph_culls(s, a, b): assert "z" not in s.tasks -def test_io_loop(loop): - async def main(): - with pytest.warns( - DeprecationWarning, match=r"the loop kwarg to Scheduler is deprecated" - ): - s = Scheduler(loop=loop, dashboard_address=":0", validate=True) - assert s.io_loop is IOLoop.current() - - asyncio.run(main()) +@gen_test() +async def test_io_loop(loop): + with pytest.warns( + DeprecationWarning, match=r"the loop kwarg to Scheduler is deprecated" + ): + s = Scheduler(loop=loop, dashboard_address=":0", validate=True) + assert s.io_loop is IOLoop.current() @gen_cluster(client=True) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 73329a81b33..313137135a2 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -23,7 +23,8 @@ import dask -from distributed.compatibility import MACOS, WINDOWS +from distributed.compatibility import MACOS, WINDOWS, asyncio_run +from distributed.config import get_loop_factory from distributed.metrics import time from distributed.utils import ( All, @@ -134,7 +135,7 @@ def test_sync_closed_loop(): async def get_loop(): return IOLoop.current() - loop = asyncio.run(get_loop()) + loop = asyncio_run(get_loop(), loop_factory=get_loop_factory()) loop.close() with pytest.raises(RuntimeError) as exc_info: @@ -399,7 +400,9 @@ def test_loop_runner(loop_in_thread): async def make_looprunner_in_async_context(): return IOLoop.current(), LoopRunner() - loop, runner = asyncio.run(make_looprunner_in_async_context()) + loop, runner = asyncio_run( + make_looprunner_in_async_context(), loop_factory=get_loop_factory() + ) with pytest.raises( RuntimeError, match=r"Accessing the loop property while the loop is not running is not supported", @@ -423,7 +426,7 @@ async def make_io_loop_in_async_context(): return IOLoop.current() # Explicit loop - loop = asyncio.run(make_io_loop_in_async_context()) + loop = asyncio_run(make_io_loop_in_async_context(), loop_factory=get_loop_factory()) with pytest.raises( RuntimeError, match=r"Constructing LoopRunner\(loop=loop\) without a running loop is not supported", @@ -449,7 +452,7 @@ async def make_io_loop_in_async_context(): LoopRunner(asynchronous=True) # Explicit loop - loop = asyncio.run(make_io_loop_in_async_context()) + loop = asyncio_run(make_io_loop_in_async_context(), loop_factory=get_loop_factory()) with pytest.raises( RuntimeError, match=r"Constructing LoopRunner\(loop=loop\) without a running loop is not supported", diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 239b2a0ac93..44b5e52e7ed 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -1,12 +1,13 @@ from __future__ import annotations -import asyncio from unittest import mock import pytest from dask.optimization import SubgraphCallable +from distributed.compatibility import asyncio_run +from distributed.config import get_loop_factory from distributed.core import ConnectionPool from distributed.utils_comm import ( WrappedKey, @@ -81,7 +82,7 @@ async def coro(): async def f(): return await retry(coro, count=0, delay_min=-1, delay_max=-1) - assert asyncio.run(f()) is retval + assert asyncio_run(f(), loop_factory=get_loop_factory()) is retval assert n_calls == 1 @@ -99,7 +100,7 @@ async def f(): return await retry(coro, count=0, delay_min=-1, delay_max=-1) with pytest.raises(RuntimeError, match="RT_ERROR 1"): - asyncio.run(f()) + asyncio_run(f(), loop_factory=get_loop_factory()) assert n_calls == 1 @@ -134,7 +135,7 @@ async def f(): with mock.patch("asyncio.sleep", my_sleep): with pytest.raises(MyEx, match="RT_ERROR 6"): - asyncio.run(f()) + asyncio_run(f(), loop_factory=get_loop_factory()) assert n_calls == 6 assert sleep_calls == [0.0, 1.0, 3.0, 6.0, 6.0] diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 0d3f88180bf..4438866477d 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -24,7 +24,8 @@ from distributed import Client, Event, Nanny, Scheduler, Worker, config, default_client from distributed.batched import BatchedSend from distributed.comm.core import connect -from distributed.compatibility import WINDOWS +from distributed.compatibility import WINDOWS, asyncio_run +from distributed.config import get_loop_factory from distributed.core import Server, Status, rpc from distributed.metrics import time from distributed.tests.test_batched import EchoServer @@ -73,7 +74,7 @@ async def identity(): return await scheduler_rpc.identity() with cluster() as (s, [a, b]): - ident = asyncio.run(identity()) + ident = asyncio_run(identity(), loop_factory=get_loop_factory()) assert ident["type"] == "Scheduler" assert len(ident["workers"]) == 2 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 20548ef1df0..cf902c0b41d 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -592,16 +592,13 @@ async def test_io_loop(s): @gen_cluster(nthreads=[]) async def test_io_loop_alternate_loop(s, loop): - async def main(): - with pytest.warns( - DeprecationWarning, - match=r"The `loop` argument to `Worker` is ignored, and will be " - r"removed in a future release. The Worker always binds to the current loop", - ): - async with Worker(s.address, loop=loop) as w: - assert w.io_loop is w.loop is IOLoop.current() - - await asyncio.to_thread(asyncio.run, main()) + with pytest.warns( + DeprecationWarning, + match=r"The `loop` argument to `Worker` is ignored, and will be " + r"removed in a future release. The Worker always binds to the current loop", + ): + async with Worker(s.address, loop=loop) as w: + assert w.io_loop is w.loop is IOLoop.current() @gen_cluster(client=True) diff --git a/distributed/utils.py b/distributed/utils.py index a7494234f0a..5ad983e7069 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -52,6 +52,9 @@ import psutil import tblib.pickling_support +from distributed.compatibility import asyncio_run +from distributed.config import get_loop_factory + try: import resource except ImportError: @@ -569,7 +572,7 @@ async def amain() -> None: def run_loop() -> None: nonlocal start_exc try: - asyncio.run(amain()) + asyncio_run(amain(), loop_factory=get_loop_factory()) except BaseException as e: if start_evt.is_set(): raise diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 20067b32637..fbceba4d356 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -43,8 +43,8 @@ from distributed.client import Client, _global_clients, default_client from distributed.comm import Comm from distributed.comm.tcp import TCP -from distributed.compatibility import MACOS, WINDOWS -from distributed.config import initialize_logging +from distributed.compatibility import MACOS, WINDOWS, asyncio_run +from distributed.config import get_loop_factory, initialize_logging from distributed.core import ( CommClosedError, ConnectionPool, @@ -375,7 +375,7 @@ async def inner_fn(): return await async_fn(*args, **kwargs) try: - return asyncio.run(inner_fn()) + return asyncio_run(inner_fn(), loop_factory=get_loop_factory()) finally: tornado_loop.close(all_fds=True)