diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index e4d2e776bc..acc89941f2 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -546,6 +546,7 @@ def __del__( _grl().call_exception_handler(context) except RuntimeError: pass + self.connection._close() async def aclose(self, close_connection_pool: Optional[bool] = None) -> None: """ diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 1ef9960ff3..7b0443454b 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -5,6 +5,7 @@ import socket import ssl import sys +import warnings import weakref from abc import abstractmethod from itertools import chain @@ -204,6 +205,24 @@ def __init__( raise ConnectionError("protocol must be either 2 or 3") self.protocol = protocol + def __del__(self, _warnings: Any = warnings): + # For some reason, the individual streams don't get properly garbage + # collected and therefore produce no resource warnings. We add one + # here, in the same style as those from the stdlib. + if getattr(self, "_writer", None): + _warnings.warn( + f"unclosed Connection {self!r}", ResourceWarning, source=self + ) + self._close() + + def _close(self): + """ + Internal method to silently close the connection without waiting + """ + if self._writer: + self._writer.close() + self._writer = self._reader = None + def __repr__(self): repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) return f"{self.__class__.__name__}<{repr_args}>" @@ -1017,7 +1036,7 @@ def __repr__(self): def reset(self): self._available_connections = [] - self._in_use_connections = set() + self._in_use_connections = weakref.WeakSet() def can_get_connection(self) -> bool: """Return True if a connection can be retrieved from the pool.""" diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 28e6b0d9c3..474a906091 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -316,7 +316,8 @@ async def mock_aclose(self): url: str = request.config.getoption("--redis-url") r1 = await Redis.from_url(url) with patch.object(r1, "aclose", mock_aclose): - await r1.close() + with pytest.deprecated_call(): + await r1.close() assert calls == 1 with pytest.deprecated_call(): @@ -436,3 +437,52 @@ async def mock_disconnect(_): assert called == 0 await pool.disconnect() + + +async def test_client_garbage_collection(request): + """ + Test that a Redis client will call _close() on any + connection that it holds at time of destruction + """ + + url: str = request.config.getoption("--redis-url") + pool = ConnectionPool.from_url(url) + + # create a client with a connection from the pool + client = Redis(connection_pool=pool, single_connection_client=True) + await client.initialize() + with mock.patch.object(client, "connection") as a: + # we cannot, in unittests, or from asyncio, reliably trigger garbage collection + # so we must just invoke the handler + with pytest.warns(ResourceWarning): + client.__del__() + assert a._close.called + + await client.aclose() + await pool.aclose() + + +async def test_connection_garbage_collection(request): + """ + Test that a Connection object will call close() on the + stream that it holds. + """ + + url: str = request.config.getoption("--redis-url") + pool = ConnectionPool.from_url(url) + + # create a client with a connection from the pool + client = Redis(connection_pool=pool, single_connection_client=True) + await client.initialize() + conn = client.connection + + with mock.patch.object(conn, "_reader"): + with mock.patch.object(conn, "_writer") as a: + # we cannot, in unittests, or from asyncio, reliably trigger + # garbage collection so we must just invoke the handler + with pytest.warns(ResourceWarning): + conn.__del__() + assert a.close.called + + await client.aclose() + await pool.aclose()