From cdf3dcabeb9d80a46a239710b2d889657a932a54 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 10 Oct 2024 12:40:12 -0500 Subject: [PATCH] [PR #9454/b20908e backport][3.10] Simplify DNS throttle implementation (#9457) --- CHANGES/9454.misc.rst | 1 + aiohttp/connector.py | 96 +++++++----- aiohttp/locks.py | 41 ------ tests/test_connector.py | 317 ++++++++++++++++++++++++++++++++++++++-- tests/test_locks.py | 54 ------- 5 files changed, 368 insertions(+), 141 deletions(-) create mode 100644 CHANGES/9454.misc.rst delete mode 100644 aiohttp/locks.py delete mode 100644 tests/test_locks.py diff --git a/CHANGES/9454.misc.rst b/CHANGES/9454.misc.rst new file mode 100644 index 00000000000..5c842590512 --- /dev/null +++ b/CHANGES/9454.misc.rst @@ -0,0 +1 @@ +Simplified DNS resolution throttling code to reduce chance of race conditions -- by :user:`bdraco`. diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 1c1283190d4..6e3c9e18db8 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -9,7 +9,7 @@ from contextlib import suppress from http import HTTPStatus from http.cookies import SimpleCookie -from itertools import cycle, islice +from itertools import chain, cycle, islice from time import monotonic from types import TracebackType from typing import ( @@ -50,8 +50,14 @@ ) from .client_proto import ResponseHandler from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params -from .helpers import ceil_timeout, is_ip_address, noop, sentinel -from .locks import EventResultOrError +from .helpers import ( + ceil_timeout, + is_ip_address, + noop, + sentinel, + set_exception, + set_result, +) from .resolver import DefaultResolver try: @@ -840,7 +846,9 @@ def __init__( self._use_dns_cache = use_dns_cache self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) - self._throttle_dns_events: Dict[Tuple[str, int], EventResultOrError] = {} + self._throttle_dns_futures: Dict[ + Tuple[str, int], Set["asyncio.Future[None]"] + ] = {} self._family = 
family self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr) self._happy_eyeballs_delay = happy_eyeballs_delay @@ -849,8 +857,8 @@ def __init__( def close(self) -> Awaitable[None]: """Close all ongoing DNS calls.""" - for ev in self._throttle_dns_events.values(): - ev.cancel() + for fut in chain.from_iterable(self._throttle_dns_futures.values()): + fut.cancel() for t in self._resolve_host_tasks: t.cancel() @@ -918,18 +926,35 @@ async def _resolve_host( await trace.send_dns_cache_hit(host) return result + futures: Set["asyncio.Future[None]"] # # If multiple connectors are resolving the same host, we wait # for the first one to resolve and then use the result for all of them. - # We use a throttle event to ensure that we only resolve the host once + # We use a throttle to ensure that we only resolve the host once # and then use the result for all the waiters. # + if key in self._throttle_dns_futures: + # get futures early, before any await (#4014) + futures = self._throttle_dns_futures[key] + future: asyncio.Future[None] = self._loop.create_future() + futures.add(future) + if traces: + for trace in traces: + await trace.send_dns_cache_hit(host) + try: + await future + finally: + futures.discard(future) + return self._cached_hosts.next_addrs(key) + + # update dict early, before any await (#4014) + self._throttle_dns_futures[key] = futures = set() # In this case we need to create a task to ensure that we can shield # the task from cancellation as cancelling this lookup should not cancel # the underlying lookup or else the cancel event will get broadcast to # all the waiters across all connections. 
# - coro = self._resolve_host_with_throttle(key, host, port, traces) + coro = self._resolve_host_with_throttle(key, host, port, futures, traces) loop = asyncio.get_running_loop() if sys.version_info >= (3, 12): # Optimization for Python 3.12, try to send immediately @@ -957,42 +982,39 @@ async def _resolve_host_with_throttle( key: Tuple[str, int], host: str, port: int, + futures: Set["asyncio.Future[None]"], traces: Optional[Sequence["Trace"]], ) -> List[ResolveResult]: - """Resolve host with a dns events throttle.""" - if key in self._throttle_dns_events: - # get event early, before any await (#4014) - event = self._throttle_dns_events[key] + """Resolve host and set result for all waiters. + + This method must be run in a task and shielded from cancellation + to avoid cancelling the underlying lookup. + """ + if traces: + for trace in traces: + await trace.send_dns_cache_miss(host) + try: if traces: for trace in traces: - await trace.send_dns_cache_hit(host) - await event.wait() - else: - # update dict early, before any await (#4014) - self._throttle_dns_events[key] = EventResultOrError(self._loop) + await trace.send_dns_resolvehost_start(host) + + addrs = await self._resolver.resolve(host, port, family=self._family) if traces: for trace in traces: - await trace.send_dns_cache_miss(host) - try: - - if traces: - for trace in traces: - await trace.send_dns_resolvehost_start(host) - - addrs = await self._resolver.resolve(host, port, family=self._family) - if traces: - for trace in traces: - await trace.send_dns_resolvehost_end(host) + await trace.send_dns_resolvehost_end(host) - self._cached_hosts.add(key, addrs) - self._throttle_dns_events[key].set() - except BaseException as e: - # any DNS exception, independently of the implementation - # is set for the waiters to raise the same exception. 
- self._throttle_dns_events[key].set(exc=e) - raise - finally: - self._throttle_dns_events.pop(key) + self._cached_hosts.add(key, addrs) + for fut in futures: + set_result(fut, None) + except BaseException as e: + # any DNS exception is set for the waiters to raise the same exception. + # This coro is always run in a task that is shielded from cancellation so + # we should never be propagating cancellation here. + for fut in futures: + set_exception(fut, e) + raise + finally: + self._throttle_dns_futures.pop(key) return self._cached_hosts.next_addrs(key) diff --git a/aiohttp/locks.py b/aiohttp/locks.py deleted file mode 100644 index de2dc83d09d..00000000000 --- a/aiohttp/locks.py +++ /dev/null @@ -1,41 +0,0 @@ -import asyncio -import collections -from typing import Any, Deque, Optional - - -class EventResultOrError: - """Event asyncio lock helper class. - - Wraps the Event asyncio lock allowing either to awake the - locked Tasks without any error or raising an exception. - - thanks to @vorpalsmith for the simple design. 
- """ - - def __init__(self, loop: asyncio.AbstractEventLoop) -> None: - self._loop = loop - self._exc: Optional[BaseException] = None - self._event = asyncio.Event() - self._waiters: Deque[asyncio.Future[Any]] = collections.deque() - - def set(self, exc: Optional[BaseException] = None) -> None: - self._exc = exc - self._event.set() - - async def wait(self) -> Any: - waiter = self._loop.create_task(self._event.wait()) - self._waiters.append(waiter) - try: - val = await waiter - finally: - self._waiters.remove(waiter) - - if self._exc is not None: - raise self._exc - - return val - - def cancel(self) -> None: - """Cancel all waiters""" - for waiter in self._waiters: - waiter.cancel() diff --git a/tests/test_connector.py b/tests/test_connector.py index a21dd872993..94eeb3ca85b 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -34,7 +34,6 @@ TCPConnector, _DNSCacheTable, ) -from aiohttp.locks import EventResultOrError from aiohttp.resolver import ResolveResult from aiohttp.test_utils import make_mocked_coro, unused_port from aiohttp.tracing import Trace @@ -1105,6 +1104,7 @@ def dns_response(loop): async def coro(): # simulates a network operation await asyncio.sleep(0) + await asyncio.sleep(0) return ["127.0.0.1"] return coro @@ -1766,8 +1766,8 @@ async def test_close_cancels_cleanup_handle(loop) -> None: async def test_close_cancels_resolve_host(loop: asyncio.AbstractEventLoop) -> None: cancelled = False - async def delay_resolve_host(*args: object) -> None: - """Delay _resolve_host() task in order to test cancellation.""" + async def delay_resolve(*args: object, **kwargs: object) -> None: + """Delay resolve() task in order to test cancellation.""" nonlocal cancelled try: await asyncio.sleep(10) @@ -1779,7 +1779,7 @@ async def delay_resolve_host(*args: object) -> None: req = ClientRequest( "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() ) - with mock.patch.object(conn, "_resolve_host_with_throttle", delay_resolve_host): + 
with mock.patch.object(conn._resolver, "resolve", delay_resolve): t = asyncio.create_task(conn.connect(req, [], ClientTimeout())) # Let it create the internal task await asyncio.sleep(0) @@ -1797,6 +1797,301 @@ async def delay_resolve_host(*args: object) -> None: await t +async def test_multiple_dns_resolution_requests_success( + loop: asyncio.AbstractEventLoop, +) -> None: + """Verify that multiple DNS resolution requests are handled correctly.""" + + async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]: + """Delayed resolve() task.""" + for _ in range(3): + await asyncio.sleep(0) + return [ + { + "hostname": "localhost", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + }, + ] + + conn = aiohttp.TCPConnector(force_close=True) + req = ClientRequest( + "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() + ) + with mock.patch.object(conn._resolver, "resolve", delay_resolve), mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + side_effect=OSError(1, "Forced connection to fail"), + ): + task1 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + + # Let it create the internal task + await asyncio.sleep(0) + # Let that task start running + await asyncio.sleep(0) + + # Ensure the task is running + assert len(conn._resolve_host_tasks) == 1 + + task2 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + task3 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + + with pytest.raises( + aiohttp.ClientConnectorError, match="Forced connection to fail" + ): + await task1 + + # Verify the task is finished + assert len(conn._resolve_host_tasks) == 0 + + with pytest.raises( + aiohttp.ClientConnectorError, match="Forced connection to fail" + ): + await task2 + with pytest.raises( + aiohttp.ClientConnectorError, match="Forced connection to fail" + ): + await task3 + + +async def 
test_multiple_dns_resolution_requests_failure( + loop: asyncio.AbstractEventLoop, +) -> None: + """Verify that DNS resolution failure for multiple requests is handled correctly.""" + + async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]: + """Delayed resolve() task.""" + for _ in range(3): + await asyncio.sleep(0) + raise OSError(None, "DNS Resolution mock failure") + + conn = aiohttp.TCPConnector(force_close=True) + req = ClientRequest( + "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() + ) + with mock.patch.object(conn._resolver, "resolve", delay_resolve), mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + side_effect=OSError(1, "Forced connection to fail"), + ): + task1 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + + # Let it create the internal task + await asyncio.sleep(0) + # Let that task start running + await asyncio.sleep(0) + + # Ensure the task is running + assert len(conn._resolve_host_tasks) == 1 + + task2 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + task3 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + + with pytest.raises( + aiohttp.ClientConnectorError, match="DNS Resolution mock failure" + ): + await task1 + + # Verify the task is finished + assert len(conn._resolve_host_tasks) == 0 + + with pytest.raises( + aiohttp.ClientConnectorError, match="DNS Resolution mock failure" + ): + await task2 + with pytest.raises( + aiohttp.ClientConnectorError, match="DNS Resolution mock failure" + ): + await task3 + + +async def test_multiple_dns_resolution_requests_cancelled( + loop: asyncio.AbstractEventLoop, +) -> None: + """Verify that DNS resolution cancellation does not affect other tasks.""" + + async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]: + """Delayed resolve() task.""" + for _ in range(3): + await asyncio.sleep(0) + raise OSError(None, "DNS Resolution mock failure") + + conn = 
aiohttp.TCPConnector(force_close=True) + req = ClientRequest( + "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() + ) + with mock.patch.object(conn._resolver, "resolve", delay_resolve), mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + side_effect=OSError(1, "Forced connection to fail"), + ): + task1 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + + # Let it create the internal task + await asyncio.sleep(0) + # Let that task start running + await asyncio.sleep(0) + + # Ensure the task is running + assert len(conn._resolve_host_tasks) == 1 + + task2 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + task3 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + + task1.cancel() + with pytest.raises(asyncio.CancelledError): + await task1 + + with pytest.raises( + aiohttp.ClientConnectorError, match="DNS Resolution mock failure" + ): + await task2 + with pytest.raises( + aiohttp.ClientConnectorError, match="DNS Resolution mock failure" + ): + await task3 + + # Verify the task is finished + assert len(conn._resolve_host_tasks) == 0 + + +async def test_multiple_dns_resolution_requests_first_cancelled( + loop: asyncio.AbstractEventLoop, +) -> None: + """Verify that first DNS resolution cancellation does not make other resolutions fail.""" + + async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]: + """Delayed resolve() task.""" + for _ in range(3): + await asyncio.sleep(0) + return [ + { + "hostname": "localhost", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + }, + ] + + conn = aiohttp.TCPConnector(force_close=True) + req = ClientRequest( + "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() + ) + with mock.patch.object(conn._resolver, "resolve", delay_resolve), mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + side_effect=OSError(1, "Forced connection to 
fail"), + ): + task1 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + + # Let it create the internal task + await asyncio.sleep(0) + # Let that task start running + await asyncio.sleep(0) + + # Ensure the task is running + assert len(conn._resolve_host_tasks) == 1 + + task2 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + task3 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + + task1.cancel() + with pytest.raises(asyncio.CancelledError): + await task1 + + # The second and third tasks should still make the connection + # even if the first one is cancelled + with pytest.raises( + aiohttp.ClientConnectorError, match="Forced connection to fail" + ): + await task2 + with pytest.raises( + aiohttp.ClientConnectorError, match="Forced connection to fail" + ): + await task3 + + # Verify the task is finished + assert len(conn._resolve_host_tasks) == 0 + + +async def test_multiple_dns_resolution_requests_first_fails_second_successful( + loop: asyncio.AbstractEventLoop, +) -> None: + """Verify that first DNS resolution fails the first time and is successful the second time.""" + attempt = 0 + + async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]: + """Delayed resolve() task.""" + nonlocal attempt + for _ in range(3): + await asyncio.sleep(0) + attempt += 1 + if attempt == 1: + raise OSError(None, "DNS Resolution mock failure") + return [ + { + "hostname": "localhost", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + }, + ] + + conn = aiohttp.TCPConnector(force_close=True) + req = ClientRequest( + "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() + ) + with mock.patch.object(conn._resolver, "resolve", delay_resolve), mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + side_effect=OSError(1, "Forced connection to fail"), + ): + task1 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + + 
# Let it create the internal task + await asyncio.sleep(0) + # Let that task start running + await asyncio.sleep(0) + + # Ensure the task is running + assert len(conn._resolve_host_tasks) == 1 + + task2 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + + with pytest.raises( + aiohttp.ClientConnectorError, match="DNS Resolution mock failure" + ): + await task1 + + assert len(conn._resolve_host_tasks) == 0 + # The second task should also get the dns resolution failure + with pytest.raises( + aiohttp.ClientConnectorError, match="DNS Resolution mock failure" + ): + await task2 + + # The third task is created after the resolution finished so + # it should try again and succeed + task3 = asyncio.create_task(conn.connect(req, [], ClientTimeout())) + # Let it create the internal task + await asyncio.sleep(0) + # Let that task start running + await asyncio.sleep(0) + + # Ensure the task is running + assert len(conn._resolve_host_tasks) == 1 + + with pytest.raises( + aiohttp.ClientConnectorError, match="Forced connection to fail" + ): + await task3 + + # Verify the task is finished + assert len(conn._resolve_host_tasks) == 0 + + async def test_close_abort_closed_transports(loop: asyncio.AbstractEventLoop) -> None: tr = mock.Mock() @@ -2762,14 +3057,18 @@ async def test_connector_throttle_trace_race(loop): key = ("", 0) token = object() - class DummyTracer: - async def send_dns_cache_hit(self, *args, **kwargs): - event = connector._throttle_dns_events.pop(key) - event.set() + class DummyTracer(Trace): + def __init__(self) -> None: + """Dummy""" + + async def send_dns_cache_hit(self, *args: object, **kwargs: object) -> None: + futures = connector._throttle_dns_futures.pop(key) + for fut in futures: + fut.set_result(None) connector._cached_hosts.add(key, [token]) connector = TCPConnector() - connector._throttle_dns_events[key] = EventResultOrError(loop) + connector._throttle_dns_futures[key] = set() traces = [DummyTracer()] assert await 
connector._resolve_host("", 0, traces) == [token] diff --git a/tests/test_locks.py b/tests/test_locks.py deleted file mode 100644 index 5f434eace97..00000000000 --- a/tests/test_locks.py +++ /dev/null @@ -1,54 +0,0 @@ -# Tests of custom aiohttp locks implementations -import asyncio - -import pytest - -from aiohttp.locks import EventResultOrError - - -class TestEventResultOrError: - async def test_set_exception(self, loop) -> None: - ev = EventResultOrError(loop=loop) - - async def c(): - try: - await ev.wait() - except Exception as e: - return e - return 1 - - t = loop.create_task(c()) - await asyncio.sleep(0) - e = Exception() - ev.set(exc=e) - assert (await t) == e - - async def test_set(self, loop) -> None: - ev = EventResultOrError(loop=loop) - - async def c(): - await ev.wait() - return 1 - - t = loop.create_task(c()) - await asyncio.sleep(0) - ev.set() - assert (await t) == 1 - - async def test_cancel_waiters(self, loop) -> None: - ev = EventResultOrError(loop=loop) - - async def c(): - await ev.wait() - - t1 = loop.create_task(c()) - t2 = loop.create_task(c()) - await asyncio.sleep(0) - ev.cancel() - ev.set() - - with pytest.raises(asyncio.CancelledError): - await t1 - - with pytest.raises(asyncio.CancelledError): - await t2