Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make ConnectionPool.remove cancel connection attempts #7547

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 77 additions & 41 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import Container, Coroutine
from collections.abc import Callable, Container, Coroutine, Generator
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, ClassVar, TypedDict, TypeVar, final
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, TypeVar, final

import tblib
from tlz import merge
Expand Down Expand Up @@ -46,7 +46,7 @@
)

if TYPE_CHECKING:
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Self

P = ParamSpec("P")
R = TypeVar("R")
Expand Down Expand Up @@ -1014,13 +1014,19 @@ async def send_recv( # type: ignore[no-untyped-def]
return response


def addr_from_args(addr=None, ip=None, port=None):
def addr_from_args(
addr: str | tuple[str, int | None] | None = None,
ip: str | None = None,
port: int | None = None,
) -> str:
if addr is None:
addr = (ip, port)
else:
assert ip is None and port is None
assert ip is not None
return normalize_address(unparse_host_port(ip, port))

assert ip is None and port is None
if isinstance(addr, tuple):
addr = unparse_host_port(*addr)
return normalize_address(unparse_host_port(*addr))

return normalize_address(addr)


Expand Down Expand Up @@ -1288,39 +1294,42 @@ class ConnectionPool:

def __init__(
self,
limit=512,
deserialize=True,
serializers=None,
allow_offload=True,
deserializers=None,
connection_args=None,
timeout=None,
server=None,
):
limit: int = 512,
deserialize: bool = True,
serializers: list[str] | None = None,
allow_offload: bool = True,
deserializers: list[str] | None = None,
connection_args: dict[str, object] | None = None,
timeout: float | None = None,
server: object = None,
) -> None:
self.limit = limit # Max number of open comms
# Invariant: len(available) == open - active
self.available = defaultdict(set)
self.available: defaultdict[str, set[Comm]] = defaultdict(set)
# Invariant: len(occupied) == active
self.occupied = defaultdict(set)
self.occupied: defaultdict[str, set[Comm]] = defaultdict(set)
self.allow_offload = allow_offload
self.deserialize = deserialize
self.serializers = serializers
self.deserializers = deserializers if deserializers is not None else serializers
self.connection_args = connection_args or {}
self.timeout = timeout
self.server = weakref.ref(server) if server else None
self._created = weakref.WeakSet()
self._created: weakref.WeakSet[Comm] = weakref.WeakSet()
self._instances.add(self)
# _n_connecting and _connecting have subtle different semantics. The set
# _connecting contains futures actively trying to establish a connection
# while the _n_connecting also accounts for connection attempts which
# are waiting due to the connection limit
self._connecting = set()
self._connecting: defaultdict[str, set[asyncio.Task[Comm]]] = defaultdict(set)
self._pending_count = 0
self._connecting_count = 0
self.status = Status.init
self._reasons: weakref.WeakKeyDictionary[
asyncio.Task[Any], str
] = weakref.WeakKeyDictionary()

def _validate(self):
def _validate(self) -> None:
"""
Validate important invariants of this class

Expand All @@ -1329,35 +1338,40 @@ def _validate(self):
assert self.semaphore._value == self.limit - self.open - self._n_connecting

@property
def active(self):
def active(self) -> int:
return sum(map(len, self.occupied.values()))

@property
def open(self):
def open(self) -> int:
return self.active + sum(map(len, self.available.values()))

def __repr__(self):
def __repr__(self) -> str:
return "<ConnectionPool: open=%d, active=%d, connecting=%d>" % (
self.open,
self.active,
len(self._connecting),
)

def __call__(self, addr=None, ip=None, port=None):
def __call__(
self,
addr: str | tuple[str, int | None] | None = None,
ip: str | None = None,
port: int | None = None,
) -> PooledRPCCall:
"""Cached rpc objects"""
addr = addr_from_args(addr=addr, ip=ip, port=port)
return PooledRPCCall(
addr, self, serializers=self.serializers, deserializers=self.deserializers
)

def __await__(self):
async def _():
def __await__(self) -> Generator[Any, Any, Self]:
async def _() -> Self:
await self.start()
return self

return _().__await__()

async def start(self):
async def start(self) -> None:
# Invariant: semaphore._value == limit - open - _n_connecting
self.semaphore = asyncio.Semaphore(self.limit)
self.status = Status.running
Expand All @@ -1366,7 +1380,7 @@ async def start(self):
def _n_connecting(self) -> int:
return self._connecting_count

async def _connect(self, addr, timeout=None):
async def _connect(self, addr: str, timeout: float | None = None) -> Comm:
self._pending_count += 1
try:
await self.semaphore.acquire()
Expand All @@ -1392,11 +1406,14 @@ async def _connect(self, addr, timeout=None):
finally:
self._connecting_count -= 1
except asyncio.CancelledError:
raise CommClosedError("ConnectionPool closing.")
current_task = asyncio.current_task()
assert current_task
reason = self._reasons.pop(current_task, "ConnectionPool closing.")
raise CommClosedError(reason)
finally:
self._pending_count -= 1

async def connect(self, addr, timeout=None):
async def connect(self, addr: str, timeout: float | None = None) -> Comm:
"""
Get a Comm to the given address. For internal use.
"""
Expand All @@ -1422,9 +1439,21 @@ async def connect(self, addr, timeout=None):
# it to propagate
connect_attempt = asyncio.create_task(self._connect(addr, timeout))
done = asyncio.Event()
self._connecting.add(connect_attempt)
connect_attempt.add_done_callback(lambda _: done.set())
connect_attempt.add_done_callback(self._connecting.discard)
connecting = self._connecting[addr]
connecting.add(connect_attempt)

def callback(task: asyncio.Task[Comm]) -> None:
done.set()
connecting = self._connecting[addr]
connecting.discard(task)

if not connecting:
try:
del self._connecting[addr]
except KeyError: # pragma: no cover
pass

connect_attempt.add_done_callback(callback)

try:
await done.wait()
Expand All @@ -1438,7 +1467,7 @@ async def connect(self, addr, timeout=None):
raise
return await connect_attempt

def reuse(self, addr, comm):
def reuse(self, addr: str, comm: Comm) -> None:
"""
Reuse an open communication to the given address. For internal use.
"""
Expand All @@ -1457,7 +1486,7 @@ def reuse(self, addr, comm):
if self.semaphore.locked() and self._pending_count:
self.collect()

def collect(self):
def collect(self) -> None:
"""
Collect open but unused communications, to allow opening other ones.
"""
Expand All @@ -1473,7 +1502,7 @@ def collect(self):
self.semaphore.release()
comms.clear()

def remove(self, addr):
def remove(self, addr: str, *, reason: str = "Address removed.") -> None:
"""
Remove all Comms to a given address.
"""
Expand All @@ -1489,13 +1518,20 @@ def remove(self, addr):
IOLoop.current().add_callback(comm.close)
self.semaphore.release()

async def close(self):
if addr in self._connecting:
tasks = self._connecting[addr]
for task in tasks:
self._reasons[task] = reason
task.cancel()

async def close(self) -> None:
"""
Close all communications
"""
self.status = Status.closed
for conn_fut in self._connecting:
conn_fut.cancel()
for tasks in self._connecting.values():
for task in tasks:
task.cancel()
for d in [self.available, self.occupied]:
comms = set()
while d:
Expand Down
33 changes: 32 additions & 1 deletion distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import threading
import time as timemod
import weakref
from unittest import mock

import pytest
from tornado.ioloop import IOLoop
Expand Down Expand Up @@ -724,10 +725,12 @@ async def ping(comm, delay=0.1):
*(rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[:5])
)
await asyncio.gather(*(rpc(s.address).ping() for s in servers[:5]))
await asyncio.gather(*(rpc("127.0.0.1:%d" % s.port).ping() for s in servers[:5]))
await asyncio.gather(*(rpc(f"127.0.0.1:{s.port}").ping() for s in servers[:5]))
await asyncio.gather(
*(rpc(ip="127.0.0.1", port=s.port).ping() for s in servers[:5])
)
await asyncio.gather(*(rpc(("127.0.0.1", s.port)).ping() for s in servers[:5]))

assert sum(map(len, rpc.available.values())) == 5
assert sum(map(len, rpc.occupied.values())) == 0
assert rpc.active == 0
Expand Down Expand Up @@ -839,6 +842,34 @@ async def connect_to_server():
assert all(t.cancelled() for t in tasks)


@gen_test()
async def test_remove_cancels_connect_attempts():
loop = asyncio.get_running_loop()
connect_started = asyncio.Event()
connect_finished = loop.create_future()

async def connect(*args, **kwargs):
connect_started.set()
await connect_finished

async def connect_to_server():
with pytest.raises(CommClosedError, match="Address removed."):
await rpc.connect("tcp://0.0.0.0")

async def remove_address():
await connect_started.wait()
rpc.remove("tcp://0.0.0.0")

rpc = await ConnectionPool(limit=1)

with mock.patch("distributed.core.connect", connect):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of mocking, you could also just connect to a server with either a very slow or blocked listener, or one that is not replying to the handshake

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought mocking was cleaner here, I'll look into using a slow server

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also the other tests monkeypatch with monkeypatch.setitem(backends, "tcp", SlowBackend()), so I still think using mock.patch on the connect function is cleanest here

Copy link
Member Author

@graingert graingert Feb 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when testing this:

@gen_test()
async def test_remove_cancels_connect_attempts():
    # ConnectionPool.remove(addr) must cancel in-flight connection attempts to
    # that address; the pending connect() call should then raise
    # CommClosedError with the reason "Address removed.".
    # NOTE(review): `loop` is never used below — presumably leftover scaffolding.
    loop = asyncio.get_running_loop()
    connect_started = asyncio.Event()
    connect_finished = asyncio.Event()

    class BrokenHandshakeListener(TCPListener):
        # Accepts the TCP connection but never completes the comm handshake,
        # so the pool's connection attempt stays pending until cancelled.
        async def on_connection(self, comm):
            try:
                connect_started.set()
                await comm.read(1)
            finally:
                connect_finished.set()

    async with BrokenHandshakeListener(
        address="tcp://",
        comm_handler=lambda: None,
    ) as listener:
        rpc = await ConnectionPool(limit=1)

        async def connect_to_server():
            # The cancelled attempt must surface the reason passed by remove().
            with pytest.raises(CommClosedError, match="Address removed."):
                await rpc.connect(listener.contact_address)

        async def remove_address():
            # Wait until the handshake is actually in progress before
            # removing the address, so a live attempt gets cancelled.
            await connect_started.wait()
            rpc.remove(listener.contact_address)

        await asyncio.gather(
            connect_to_server(),
            remove_address(),
        )

        await connect_finished.wait()

I hit a race condition in this asyncio.wait_for where the cancellation is ignored:

comm = await asyncio.wait_for(
connector.connect(loc, deserialize=deserialize, **connection_args),
timeout=min(intermediate_cap, time_left()),
)
break

https://github.com/python/cpython/blob/924a3bfa28578802eb9ca77a66fb5d4762a62f14/Lib/asyncio/tasks.py#L472

this is because the connector.connect and on_connection tasks resume in the same event loop cycle, so the cancellation arrives just as the connect() -> asyncio.wait_for() coroutine is about to be resumed; this is very unlikely in production because the code will be waiting in comm.read() or asyncio.sleep(backoff), and it can be resolved in the test by adding an asyncio.sleep(0.5) before calling rpc.remove():

        async def remove_address():
            await connect_started.wait()
            await asyncio.sleep(0.5)  # avoid issuing a .cancel() after connect but before comm.read()
            rpc.remove(listener.contact_address)

or do you think it's best to leave this with a mocked connect function?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should come back to this once #7571 is done. Triggering this edge case is interesting and I believe we've encountered this a couple of times in the past (in CI).

No need for any action on this PR.

await asyncio.gather(
connect_to_server(),
remove_address(),
)
assert connect_finished.cancelled()


@gen_test()
async def test_connection_pool_respects_limit():
limit = 5
Expand Down