Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#233 Added interface for performing resource clean up and updated tests #236

Merged
merged 16 commits into from
May 25, 2017
Merged
3 changes: 3 additions & 0 deletions aiocache/backends/memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs):
return value.decode(encoding)
return value

async def _close(self, *args, _conn=None, **kwargs):
await self.client.close()


class MemcachedCache(MemcachedBackend, BaseCache):
"""
Expand Down
3 changes: 3 additions & 0 deletions aiocache/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ async def _clear(self, namespace=None, _conn=None):
async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs):
return getattr(SimpleMemoryBackend._cache, command)(*args, **kwargs)

async def _close(self, *args, _con=None, **kwargs):
pass

@classmethod
def __delete(cls, key):
if cls._cache.pop(key, None):
Expand Down
6 changes: 6 additions & 0 deletions aiocache/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs):
kwargs["encoding"] = encoding
return await getattr(_conn, command)(*args, **kwargs)

async def _close(self, *args, **kwargs):
async with self._lock:
if self._pool is not None:
self._pool.close()
await self._pool.wait_closed()

async def _connect(self):
async with self._lock:
if self._pool is None:
Expand Down
16 changes: 16 additions & 0 deletions aiocache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,22 @@ async def raw(self, command, *args, _conn=None, **kwargs):
async def _raw(self, command, *args, **kwargs):
raise NotImplementedError()

@API.timeout
async def close(self, *args, _conn=None, **kwargs):
"""
Perform any resource clean up necessary when the cache is no longer
needed (generally when the controlling program exits).

:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.time()
ret = await self._close(*args, _conn=_conn, **kwargs)
logger.debug("CLOSE (%.4f)s", time.time() - start)
return ret

async def _close(self, *args, **kwargs):
raise NotImplementedError()

def _build_key(self, key, namespace=None):
if namespace is not None:
return "{}{}".format(namespace, key)
Expand Down
7 changes: 6 additions & 1 deletion aiocache/decorators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
import functools

Expand Down Expand Up @@ -55,7 +56,9 @@ async def wrapper(*args, **kwargs):

try:
if await cache_instance.exists(cache_key):
return await cache_instance.get(cache_key)
result = await cache_instance.get(cache_key)
asyncio.ensure_future(cache_instance.close())
return result

except Exception:
logger.exception("Unexpected error with %s", cache_instance)
Expand All @@ -67,6 +70,7 @@ async def wrapper(*args, **kwargs):
except Exception:
logger.exception("Unexpected error with %s", cache_instance)

asyncio.ensure_future(cache_instance.close())
return result

return wrapper
Expand Down Expand Up @@ -145,6 +149,7 @@ async def wrapper(*args, **kwargs):
except Exception:
logger.exception("Unexpected error with %s", cache_instance)

asyncio.ensure_future(cache_instance.close())
return partial_result

return wrapper
Expand Down
7 changes: 6 additions & 1 deletion examples/cached_alias_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,19 @@ async def alt_cache():
assert cache.endpoint == "127.0.0.1"
assert cache.timeout == 1
assert cache.port == 6379
await cache.close()


def test_alias():
loop = asyncio.get_event_loop()
loop.run_until_complete(default_cache())
loop.run_until_complete(alt_cache())

loop.run_until_complete(RedisCache().delete("key"))
cache = RedisCache()
loop.run_until_complete(cache.delete("key"))
loop.run_until_complete(cache.close())

loop.run_until_complete(caches.get('default').close())


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion examples/cached_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from aiocache.serializers import PickleSerializer

Result = namedtuple('Result', "content, status")
cache = RedisCache(endpoint="127.0.0.1", port=6379, namespace="main")


@cached(
Expand All @@ -16,10 +15,12 @@ async def cached_call():


def test_cached():
cache = RedisCache(endpoint="127.0.0.1", port=6379, namespace="main")
loop = asyncio.get_event_loop()
loop.run_until_complete(cached_call())
assert loop.run_until_complete(cache.exists("key")) is True
loop.run_until_complete(cache.delete("key"))
loop.run_until_complete(cache.close())


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions examples/lru_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def test_redis():
loop.run_until_complete(cache.delete("key"))
loop.run_until_complete(cache.delete("key_1"))
loop.run_until_complete(cache.delete("key_2"))
loop.run_until_complete(cache.close())


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions examples/multicached_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_multi_cached():
loop.run_until_complete(cache.delete("b"))
loop.run_until_complete(cache.delete("c"))
loop.run_until_complete(cache.delete("d"))
loop.run_until_complete(cache.close())


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions examples/python_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_python_object():
loop = asyncio.get_event_loop()
loop.run_until_complete(complex_object())
loop.run_until_complete(cache.delete("key"))
loop.run_until_complete(cache.close())


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions examples/serializer_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_serializer():
loop = asyncio.get_event_loop()
loop.run_until_complete(serializer())
loop.run_until_complete(cache.delete("key"))
loop.run_until_complete(cache.close())


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions examples/serializer_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_serializer_function():
loop = asyncio.get_event_loop()
loop.run_until_complete(serializer_function())
loop.run_until_complete(cache.delete("key"))
loop.run_until_complete(cache.close())


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions examples/simple_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_redis():
loop.run_until_complete(redis())
loop.run_until_complete(cache.delete("key"))
loop.run_until_complete(cache.delete("expire_me"))
loop.run_until_complete(cache.close())


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions tests/acceptance/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def redis_cache(event_loop):
for _, pool in RedisBackend.pools.items():
pool.close()
event_loop.run_until_complete(pool.wait_closed())
event_loop.run_until_complete(cache.close())


@pytest.fixture
Expand All @@ -49,6 +50,7 @@ def memory_cache(event_loop):

event_loop.run_until_complete(cache.delete(pytest.KEY))
event_loop.run_until_complete(cache.delete(pytest.KEY_1))
event_loop.run_until_complete(cache.close())


@pytest.fixture
Expand All @@ -58,6 +60,7 @@ def memcached_cache(event_loop):

event_loop.run_until_complete(cache.delete(pytest.KEY))
event_loop.run_until_complete(cache.delete(pytest.KEY_1))
event_loop.run_until_complete(cache.close())


@pytest.fixture(params=[
Expand Down
6 changes: 6 additions & 0 deletions tests/ut/backends/test_memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,12 @@ async def test_raw_bytes(self, memcached):
memcached.client.get.assert_called_with(pytest.KEY)
memcached.client.set.assert_called_with(pytest.KEY, "asd")

@pytest.mark.asyncio
async def test_close(self):
memcached = MemcachedBackend()
await memcached._close()
assert memcached.client._pool._pool.empty()


class TestMemcachedCache:

Expand Down
13 changes: 13 additions & 0 deletions tests/ut/backends/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,19 @@ async def test_raw(self, redis):
pool.conn.get.assert_called_with(pytest.KEY, encoding=ANY)
pool.conn.set.assert_called_with(pytest.KEY, 1)

@pytest.mark.asyncio
async def test_close_when_connected(self):
redis = RedisBackend()
await redis._raw("set", pytest.KEY, 1)
await redis._close()
assert redis._pool.closed

@pytest.mark.asyncio
async def test_close_when_not_connected(self):
redis = RedisBackend()
await redis._close()
assert redis._pool is None


class TestConn:

Expand Down
1 change: 1 addition & 0 deletions tests/ut/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self):
self._expire = asynctest.CoroutineMock()
self._clear = asynctest.CoroutineMock()
self._raw = asynctest.CoroutineMock()
self._close = asynctest.CoroutineMock()
self.acquire = asynctest.CoroutineMock()
self.release = asynctest.CoroutineMock()

Expand Down
9 changes: 9 additions & 0 deletions tests/ut/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ async def test_raw(self, base_cache):
with pytest.raises(NotImplementedError):
await base_cache._raw("get", pytest.KEY)

@pytest.mark.asyncio
async def test_close(self, base_cache):
with pytest.raises(NotImplementedError):
await base_cache._close()

@pytest.mark.asyncio
async def test_acquire(self, base_cache):
assert await base_cache.acquire() == base_cache
Expand Down Expand Up @@ -409,6 +414,10 @@ async def test_get_connection(self, mock_cache):
assert mock_cache.acquire.call_count == 1
assert mock_cache.release.call_count == 1

@pytest.mark.asyncio
async def test_close(self, mock_cache):
await mock_cache.close()


@pytest.fixture
def conn(mock_cache):
Expand Down