#233 Add close function to clean up resources (#236)
Call `await cache.close()` to close a pool and its connections
Quinny authored and argaen committed May 25, 2017
1 parent 9f2358d commit 333119e
Showing 18 changed files with 80 additions and 3 deletions.
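As the commit message says, callers can now invoke `await cache.close()` to release a pool and its connections. A minimal usage sketch of the new API (illustrative only, shown with the in-memory backend so no external server is needed):

import asyncio

from aiocache import SimpleMemoryCache


async def main():
    cache = SimpleMemoryCache()
    await cache.set("key", "value")
    assert await cache.get("key") == "value"
    # New in this commit: release any resources held by the backend.
    await cache.close()


if __name__ == "__main__":
    asyncio.get_event_loop().run_until_complete(main())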
3 changes: 3 additions & 0 deletions aiocache/backends/memcached.py
@@ -93,6 +93,9 @@ async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs):
            return value.decode(encoding)
        return value

    async def _close(self, *args, _conn=None, **kwargs):
        await self.client.close()


class MemcachedCache(MemcachedBackend, BaseCache):
    """
3 changes: 3 additions & 0 deletions aiocache/backends/memory.py
@@ -81,6 +81,9 @@ async def _clear(self, namespace=None, _conn=None):
    async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs):
        return getattr(SimpleMemoryBackend._cache, command)(*args, **kwargs)

    async def _close(self, *args, _conn=None, **kwargs):
        pass

    @classmethod
    def __delete(cls, key):
        if cls._cache.pop(key, None):
6 changes: 6 additions & 0 deletions aiocache/backends/redis.py
@@ -121,6 +121,12 @@ async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs):
            kwargs["encoding"] = encoding
        return await getattr(_conn, command)(*args, **kwargs)

    async def _close(self, *args, **kwargs):
        async with self._lock:
            if self._pool is not None:
                self._pool.close()
                await self._pool.wait_closed()

    async def _connect(self):
        async with self._lock:
            if self._pool is None:
16 changes: 16 additions & 0 deletions aiocache/base.py
@@ -429,6 +429,22 @@ async def raw(self, command, *args, _conn=None, **kwargs):
    async def _raw(self, command, *args, **kwargs):
        raise NotImplementedError()

    @API.timeout
    async def close(self, *args, _conn=None, **kwargs):
        """
        Perform any resource clean up necessary when the cache is no longer
        needed (generally when the controlling program exits).
        :raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
        """
        start = time.time()
        ret = await self._close(*args, _conn=_conn, **kwargs)
        logger.debug("CLOSE (%.4f)s", time.time() - start)
        return ret

    async def _close(self, *args, **kwargs):
        raise NotImplementedError()

    def _build_key(self, key, namespace=None):
        if namespace is not None:
            return "{}{}".format(namespace, key)
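The docstring above notes that `close()` raises `asyncio.TimeoutError` if it takes longer than `self.timeout`. A hedged sketch of shutdown code that handles that case (illustrative, not part of this commit; the endpoint and port mirror the examples in this repo):

import asyncio

from aiocache import RedisCache


async def shutdown(cache):
    try:
        # Per the docstring, close() raises asyncio.TimeoutError if the
        # teardown takes longer than cache.timeout.
        await cache.close()
    except asyncio.TimeoutError:
        # Illustrative handling only: report and keep shutting down.
        print("cache.close() did not finish within the configured timeout")


if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    loop.run_until_complete(shutdown(RedisCache(endpoint="127.0.0.1", port=6379)))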
7 changes: 6 additions & 1 deletion aiocache/decorators.py
@@ -1,3 +1,4 @@
import asyncio
import inspect
import functools

@@ -55,7 +56,9 @@ async def wrapper(*args, **kwargs):

            try:
                if await cache_instance.exists(cache_key):
                    return await cache_instance.get(cache_key)
                    result = await cache_instance.get(cache_key)
                    asyncio.ensure_future(cache_instance.close())
                    return result

            except Exception:
                logger.exception("Unexpected error with %s", cache_instance)
@@ -67,6 +70,7 @@ async def wrapper(*args, **kwargs):
            except Exception:
                logger.exception("Unexpected error with %s", cache_instance)

            asyncio.ensure_future(cache_instance.close())
            return result

        return wrapper
@@ -145,6 +149,7 @@ async def wrapper(*args, **kwargs):
            except Exception:
                logger.exception("Unexpected error with %s", cache_instance)

            asyncio.ensure_future(cache_instance.close())
            return partial_result

        return wrapper
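With the decorator changes above, the wrapper schedules `cache_instance.close()` via `asyncio.ensure_future` after serving a hit or storing a result, so the backing pool is cleaned up in the background. A small sketch of a decorated call using the library's default in-memory cache (illustrative only):

import asyncio

from aiocache import cached


@cached(ttl=10)
async def expensive_call():
    # Pretend this hits a slow service; the result is cached for 10 seconds.
    return {"status": 200}


if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    # After the wrapper returns, close() has already been scheduled in the
    # background by the decorator, so no explicit cleanup is needed here.
    loop.run_until_complete(expensive_call())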
7 changes: 6 additions & 1 deletion examples/cached_alias_config.py
@@ -46,14 +46,19 @@ async def alt_cache():
    assert cache.endpoint == "127.0.0.1"
    assert cache.timeout == 1
    assert cache.port == 6379
    await cache.close()


def test_alias():
    loop = asyncio.get_event_loop()
    loop.run_until_complete(default_cache())
    loop.run_until_complete(alt_cache())

    loop.run_until_complete(RedisCache().delete("key"))
    cache = RedisCache()
    loop.run_until_complete(cache.delete("key"))
    loop.run_until_complete(cache.close())

    loop.run_until_complete(caches.get('default').close())


if __name__ == "__main__":
3 changes: 2 additions & 1 deletion examples/cached_decorator.py
@@ -6,7 +6,6 @@
from aiocache.serializers import PickleSerializer

Result = namedtuple('Result', "content, status")
cache = RedisCache(endpoint="127.0.0.1", port=6379, namespace="main")


@cached(
@@ -16,10 +15,12 @@ async def cached_call():


def test_cached():
    cache = RedisCache(endpoint="127.0.0.1", port=6379, namespace="main")
    loop = asyncio.get_event_loop()
    loop.run_until_complete(cached_call())
    assert loop.run_until_complete(cache.exists("key")) is True
    loop.run_until_complete(cache.delete("key"))
    loop.run_until_complete(cache.close())


if __name__ == "__main__":
1 change: 1 addition & 0 deletions examples/lru_plugin.py
@@ -73,6 +73,7 @@ def test_redis():
    loop.run_until_complete(cache.delete("key"))
    loop.run_until_complete(cache.delete("key_1"))
    loop.run_until_complete(cache.delete("key_2"))
    loop.run_until_complete(cache.close())


if __name__ == "__main__":
1 change: 1 addition & 0 deletions examples/multicached_decorator.py
@@ -38,6 +38,7 @@ def test_multi_cached():
    loop.run_until_complete(cache.delete("b"))
    loop.run_until_complete(cache.delete("c"))
    loop.run_until_complete(cache.delete("d"))
    loop.run_until_complete(cache.close())


if __name__ == "__main__":
1 change: 1 addition & 0 deletions examples/python_object.py
@@ -22,6 +22,7 @@ def test_python_object():
    loop = asyncio.get_event_loop()
    loop.run_until_complete(complex_object())
    loop.run_until_complete(cache.delete("key"))
    loop.run_until_complete(cache.close())


if __name__ == "__main__":
1 change: 1 addition & 0 deletions examples/serializer_class.py
@@ -49,6 +49,7 @@ def test_serializer():
    loop = asyncio.get_event_loop()
    loop.run_until_complete(serializer())
    loop.run_until_complete(cache.delete("key"))
    loop.run_until_complete(cache.close())


if __name__ == "__main__":
1 change: 1 addition & 0 deletions examples/serializer_function.py
@@ -46,6 +46,7 @@ def test_serializer_function():
    loop = asyncio.get_event_loop()
    loop.run_until_complete(serializer_function())
    loop.run_until_complete(cache.delete("key"))
    loop.run_until_complete(cache.close())


if __name__ == "__main__":
1 change: 1 addition & 0 deletions examples/simple_redis.py
@@ -20,6 +20,7 @@ def test_redis():
    loop.run_until_complete(redis())
    loop.run_until_complete(cache.delete("key"))
    loop.run_until_complete(cache.delete("expire_me"))
    loop.run_until_complete(cache.close())


if __name__ == "__main__":
3 changes: 3 additions & 0 deletions tests/acceptance/conftest.py
@@ -40,6 +40,7 @@ def redis_cache(event_loop):
    for _, pool in RedisBackend.pools.items():
        pool.close()
        event_loop.run_until_complete(pool.wait_closed())
    event_loop.run_until_complete(cache.close())


@pytest.fixture
@@ -49,6 +50,7 @@ def memory_cache(event_loop):

    event_loop.run_until_complete(cache.delete(pytest.KEY))
    event_loop.run_until_complete(cache.delete(pytest.KEY_1))
    event_loop.run_until_complete(cache.close())


@pytest.fixture
@@ -58,6 +60,7 @@ def memcached_cache(event_loop):

    event_loop.run_until_complete(cache.delete(pytest.KEY))
    event_loop.run_until_complete(cache.delete(pytest.KEY_1))
    event_loop.run_until_complete(cache.close())


@pytest.fixture(params=[
6 changes: 6 additions & 0 deletions tests/ut/backends/test_memcached.py
@@ -186,6 +186,12 @@ async def test_raw_bytes(self, memcached):
        memcached.client.get.assert_called_with(pytest.KEY)
        memcached.client.set.assert_called_with(pytest.KEY, "asd")

    @pytest.mark.asyncio
    async def test_close(self):
        memcached = MemcachedBackend()
        await memcached._close()
        assert memcached.client._pool._pool.empty()


class TestMemcachedCache:

13 changes: 13 additions & 0 deletions tests/ut/backends/test_redis.py
@@ -254,6 +254,19 @@ async def test_raw(self, redis):
        pool.conn.get.assert_called_with(pytest.KEY, encoding=ANY)
        pool.conn.set.assert_called_with(pytest.KEY, 1)

    @pytest.mark.asyncio
    async def test_close_when_connected(self):
        redis = RedisBackend()
        await redis._raw("set", pytest.KEY, 1)
        await redis._close()
        assert redis._pool.closed

    @pytest.mark.asyncio
    async def test_close_when_not_connected(self):
        redis = RedisBackend()
        await redis._close()
        assert redis._pool is None


class TestConn:

1 change: 1 addition & 0 deletions tests/ut/conftest.py
@@ -46,6 +46,7 @@ def __init__(self):
        self._expire = asynctest.CoroutineMock()
        self._clear = asynctest.CoroutineMock()
        self._raw = asynctest.CoroutineMock()
        self._close = asynctest.CoroutineMock()
        self.acquire = asynctest.CoroutineMock()
        self.release = asynctest.CoroutineMock()

9 changes: 9 additions & 0 deletions tests/ut/test_base.py
@@ -196,6 +196,11 @@ async def test_raw(self, base_cache):
        with pytest.raises(NotImplementedError):
            await base_cache._raw("get", pytest.KEY)

    @pytest.mark.asyncio
    async def test_close(self, base_cache):
        with pytest.raises(NotImplementedError):
            await base_cache._close()

    @pytest.mark.asyncio
    async def test_acquire(self, base_cache):
        assert await base_cache.acquire() == base_cache
@@ -409,6 +414,10 @@ async def test_get_connection(self, mock_cache):
        assert mock_cache.acquire.call_count == 1
        assert mock_cache.release.call_count == 1

    @pytest.mark.asyncio
    async def test_close(self, mock_cache):
        await mock_cache.close()


@pytest.fixture
def conn(mock_cache):