Skip to content

Commit

Permalink
make ConnectionPool.remove cancel connection attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Feb 15, 2023
1 parent e203c63 commit 802b146
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 8 deletions.
41 changes: 33 additions & 8 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,10 +1321,13 @@ def __init__(
# _connecting contains futures actively trying to establish a connection
# while the _n_connecting also accounts for connection attempts which
# are waiting due to the connection limit
self._connecting: set[asyncio.Task[Comm]] = set()
self._connecting: defaultdict[str, set[asyncio.Task[Comm]]] = defaultdict(set)
self._pending_count = 0
self._connecting_count = 0
self.status = Status.init
self._reasons: weakref.WeakKeyDictionary[
asyncio.Task[Any], str
] = weakref.WeakKeyDictionary()

def _validate(self) -> None:
"""
Expand Down Expand Up @@ -1403,7 +1406,10 @@ async def _connect(self, addr: str, timeout: float | None = None) -> Comm:
finally:
self._connecting_count -= 1
except asyncio.CancelledError:
raise CommClosedError("ConnectionPool closing.")
current_task = asyncio.current_task()
assert current_task
reason = self._reasons.pop(current_task, "ConnectionPool closing.")
raise CommClosedError(reason)
finally:
self._pending_count -= 1

Expand Down Expand Up @@ -1433,9 +1439,21 @@ async def connect(self, addr: str, timeout: float | None = None) -> Comm:
# it to propagate
connect_attempt = asyncio.create_task(self._connect(addr, timeout))
done = asyncio.Event()
self._connecting.add(connect_attempt)
connect_attempt.add_done_callback(lambda _: done.set())
connect_attempt.add_done_callback(self._connecting.discard)
connecting = self._connecting[addr]
connecting.add(connect_attempt)

def callback(task: asyncio.Task[Comm]) -> None:
done.set()
connecting = self._connecting[addr]
connecting.discard(task)

if not connecting:
try:
del self._connecting[addr]
except KeyError:
pass

connect_attempt.add_done_callback(callback)

try:
await done.wait()
Expand Down Expand Up @@ -1484,7 +1502,7 @@ def collect(self) -> None:
self.semaphore.release()
comms.clear()

def remove(self, addr: str) -> None:
def remove(self, addr: str, *, reason: str = "Address removed.") -> None:
"""
Remove all Comms to a given address.
"""
Expand All @@ -1500,13 +1518,20 @@ def remove(self, addr: str) -> None:
IOLoop.current().add_callback(comm.close)
self.semaphore.release()

if addr in self._connecting:
tasks = self._connecting[addr]
for task in tasks:
self._reasons[task] = reason
task.cancel()

async def close(self) -> None:
"""
Close all communications
"""
self.status = Status.closed
for conn_fut in self._connecting:
conn_fut.cancel()
for tasks in self._connecting.values():
for task in tasks:
task.cancel()
for d in [self.available, self.occupied]:
comms = set()
while d:
Expand Down
29 changes: 29 additions & 0 deletions distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import threading
import time as timemod
import weakref
from unittest import mock

import pytest
from tornado.ioloop import IOLoop
Expand Down Expand Up @@ -839,6 +840,34 @@ async def connect_to_server():
assert all(t.cancelled() for t in tasks)


@gen_test()
async def test_remove_cancels_connect_attempts():
loop = asyncio.get_running_loop()
connect_started = asyncio.Event()
connect_finished = loop.create_future()

async def connect(*args, **kwargs):
connect_started.set()
await connect_finished

async def connect_to_server():
with pytest.raises(CommClosedError, match="Address removed."):
await rpc.connect("tcp://0.0.0.0")

async def remove_address():
await connect_started.wait()
rpc.remove("tcp://0.0.0.0")

rpc = await ConnectionPool(limit=1)

with mock.patch("distributed.core.connect", connect):
await asyncio.gather(
connect_to_server(),
remove_address(),
)
assert connect_finished.cancelled()


@gen_test()
async def test_connection_pool_respects_limit():
limit = 5
Expand Down

0 comments on commit 802b146

Please sign in to comment.