diff --git a/aries_cloudagent/cache/base.py b/aries_cloudagent/cache/base.py
index b469958cb9..f489450c04 100644
--- a/aries_cloudagent/cache/base.py
+++ b/aries_cloudagent/cache/base.py
@@ -1,12 +1,23 @@
 """Abstract base classes for cache."""
 
+import asyncio
 from abc import ABC, abstractmethod
 from typing import Any, Sequence, Text, Union
 
+from ..error import BaseError
+
+
+class CacheError(BaseError):
+    """Base class for cache-related errors."""
+
 
 class BaseCache(ABC):
     """Abstract cache interface."""
 
+    def __init__(self):
+        """Initialize the cache instance."""
+        self._key_locks = {}
+
     @abstractmethod
     async def get(self, key: Text):
         """
@@ -46,6 +57,117 @@ async def clear(self, key: Text):
     async def flush(self):
         """Remove all items from the cache."""
 
+    def acquire(self, key: Text):
+        """Acquire a lock on a given cache key."""
+        result = CacheKeyLock(self, key)
+        first = self._key_locks.setdefault(key, result)
+        if first is not result:
+            result.parent = first
+        return result
+
+    def release(self, key: Text):
+        """Release the lock on a given cache key."""
+        if key in self._key_locks:
+            del self._key_locks[key]
+
     def __repr__(self) -> str:
         """Human readable representation of `BaseStorageRecordSearch`."""
         return "<{}>".format(self.__class__.__name__)
+
+
+class CacheKeyLock:
+    """
+    A lock on a particular cache key.
+
+    Used to prevent multiple async threads from generating
+    or querying the same semi-expensive data. Not thread safe.
+    """
+
+    def __init__(self, cache: BaseCache, key: Text):
+        """Initialize the key lock."""
+        self.cache = cache
+        self.exception: BaseException = None
+        self.key = key
+        self.released = False
+        self._future: asyncio.Future = asyncio.get_event_loop().create_future()
+        self._parent: "CacheKeyLock" = None
+
+    @property
+    def done(self) -> bool:
+        """Accessor for the done state."""
+        return self._future.done()
+
+    @property
+    def future(self) -> asyncio.Future:
+        """Fetch the result in the form of an awaitable future."""
+        return self._future
+
+    @property
+    def result(self) -> Any:
+        """Fetch the current result, if any."""
+        if self.done:
+            return self._future.result()
+
+    @property
+    def parent(self) -> "CacheKeyLock":
+        """Accessor for the parent key lock, if any."""
+        return self._parent
+
+    @parent.setter
+    def parent(self, parent: "CacheKeyLock"):
+        """Setter for the parent lock."""
+        self._parent = parent
+        parent._future.add_done_callback(self._handle_parent_done)
+
+    def _handle_parent_done(self, fut: asyncio.Future):
+        """Handle completion of parent's future."""
+        result = fut.result()
+        if result:
+            self._future.set_result(fut.result())
+
+    async def set_result(self, value: Any, ttl: int = None):
+        """Set the result, updating the cache and any waiters."""
+        if self.done and value:
+            raise CacheError("Result already set")
+        self._future.set_result(value)
+        if not self._parent or self._parent.done:
+            await self.cache.set(self.key, value, ttl)
+
+    def __await__(self):
+        """Wait for a result to be produced."""
+        return (yield from self._future)
+
+    async def __aenter__(self):
+        """Async context manager entry."""
+        result = None
+        if self.parent:
+            result = await self.parent
+            if result:
+                await self  # wait for parent's done handler to complete
+        if not result:
+            found = await self.cache.get(self.key)
+            if found:
+                self._future.set_result(found)
+        return self
+
+    def release(self):
+        """Release the cache lock."""
+        if not self.parent and not self.released:
+            self.cache.release(self.key)
+            self.released = True
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """
+        Async context manager exit.
+
+        `None` is returned to any waiters if no value is produced.
+        """
+        if exc_val:
+            self.exception = exc_val
+        if not self.done:
+            self._future.set_result(None)
+        self.release()
+
+    def __del__(self):
+        """Handle deletion."""
+        self.release()
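Reviewer note: the new `BaseCache.acquire()` / `CacheKeyLock` API added above is intended to be used as an async context manager around a read-through lookup, so only one coroutine per process computes a given key while later acquirers of the same key wait on the first lock and reuse its result. A minimal usage sketch; the `get_or_create` helper, the `loader` coroutine and the 30-second TTL are illustrative and not part of this diff:

    import asyncio

    from aries_cloudagent.cache.basic import BasicCache


    async def get_or_create(cache: BasicCache, key: str, loader, ttl: int = 30):
        """Return the cached value for `key`, computing and caching it on a miss."""
        async with cache.acquire(key) as entry:
            if entry.result:
                # Either the value was already cached, or the holder of the parent
                # lock on this key produced it while we were waiting.
                return entry.result
            value = await loader()              # loader: illustrative slow coroutine
            # Writes through to cache.set() and resolves any child locks on this key.
            await entry.set_result(value, ttl)
            return value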
diff --git a/aries_cloudagent/cache/basic.py b/aries_cloudagent/cache/basic.py
index 3af3df53a0..a5d6f5b12e 100644
--- a/aries_cloudagent/cache/basic.py
+++ b/aries_cloudagent/cache/basic.py
@@ -1,7 +1,6 @@
 """Basic in-memory cache implementation."""
 
-from datetime import datetime, timedelta
-
+import time
 from typing import Any, Sequence, Text, Union
 
 from .base import BaseCache
@@ -12,17 +11,17 @@ class BasicCache(BaseCache):
 
     def __init__(self):
         """Initialize a `BasicCache` instance."""
-
+        super().__init__()
         # looks like { "key": { "expires": <epoch timestamp>, "value": <val> } }
         self._cache = {}
 
     def _remove_expired_cache_items(self):
         """Remove all expired items from cache."""
-        for key in self._cache.copy():  # iterate copy, del from original
-            cache_item_expiry = self._cache[key]["expires"]
+        for key, val in self._cache.copy().items():  # iterate copy, del from original
+            cache_item_expiry = val["expires"]
             if cache_item_expiry is None:
                 continue
-            now = datetime.now().timestamp()
+            now = time.perf_counter()
             if now >= cache_item_expiry:
                 del self._cache[key]
 
@@ -53,16 +52,9 @@ async def set(self, keys: Union[Text, Sequence[Text]], value: Any, ttl: int = No
 
         """
         self._remove_expired_cache_items()
-        now = datetime.now()
-        expires_ts = None
-        if ttl:
-            expires = now + timedelta(seconds=ttl)
-            expires_ts = expires.timestamp()
-        for key in ([keys] if isinstance(keys, Text) else keys):
-            self._cache[key] = {
-                "expires": expires_ts,
-                "value": value
-            }
+        expires_ts = time.perf_counter() + ttl if ttl else None
+        for key in [keys] if isinstance(keys, Text) else keys:
+            self._cache[key] = {"expires": expires_ts, "value": value}
 
     async def clear(self, key: Text):
         """
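The switch from `datetime.now()` to `time.perf_counter()` puts expiry times on a monotonic clock, so the stored "expires" value is now a perf_counter reading rather than an epoch timestamp and TTLs are unaffected by wall-clock adjustments. A small sketch of the resulting behaviour, assuming arbitrary key names and a one-second TTL:

    import asyncio

    from aries_cloudagent.cache.basic import BasicCache


    async def demo():
        cache = BasicCache()
        await cache.set("static", "kept")            # no TTL: never expires
        await cache.set("temp", "dropped", ttl=1)    # expires ~1s after being set

        assert await cache.get("temp") == "dropped"
        await asyncio.sleep(1.1)
        assert await cache.get("temp") is None       # purged on next access
        assert await cache.get("static") == "kept"


    asyncio.run(demo())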
test_result = "test_result" + lock = cache.acquire(test_key) + await lock.__aenter__() + + lock2 = cache.acquire(test_key) + assert lock.parent is None + assert lock2.parent is lock + await lock.set_result(test_result) + await lock.__aexit__(None, None, None) + + assert await cache.get(test_key) == test_result + assert await wait_for(lock, 1) == test_result + assert await wait_for(lock2, 1) == test_result + + @pytest.mark.asyncio + async def test_duplicate_set(self, cache): + test_key = "test_key" + test_result = "test_result" + lock = cache.acquire(test_key) + async with lock: + assert not lock.done + await lock.set_result(test_result) + with pytest.raises(CacheError): + await lock.set_result(test_result) + assert lock.done + assert test_key not in cache._key_locks + + @pytest.mark.asyncio + async def test_populated(self, cache): + test_key = "test_key" + test_result = "test_result" + await cache.set(test_key, test_result) + lock = cache.acquire(test_key) + lock2 = cache.acquire(test_key) + + async def check(): + async with lock as entry: + async with lock2 as entry2: + assert entry2.done # parent value located + assert entry2.result == test_result + assert entry.done + assert entry.result == test_result + assert test_key not in cache._key_locks + + await wait_for(check(), 1) + + @pytest.mark.asyncio + async def test_acquire_exception(self, cache): + test_key = "test_key" + test_result = "test_result" + lock = cache.acquire(test_key) + with pytest.raises(ValueError): + async with lock: + raise ValueError + assert isinstance(lock.exception, ValueError) + assert lock.done + assert lock.result is None + + @pytest.mark.asyncio + async def test_repr(self, cache): + assert isinstance(repr(cache), str) diff --git a/aries_cloudagent/conductor.py b/aries_cloudagent/conductor.py index 0393e4892d..215a573485 100644 --- a/aries_cloudagent/conductor.py +++ b/aries_cloudagent/conductor.py @@ -144,7 +144,11 @@ async def setup(self): # at the class level (!) 
diff --git a/aries_cloudagent/conductor.py b/aries_cloudagent/conductor.py
index 0393e4892d..215a573485 100644
--- a/aries_cloudagent/conductor.py
+++ b/aries_cloudagent/conductor.py
@@ -144,7 +144,11 @@ async def setup(self):
             # at the class level (!) should not be performed multiple times
             collector.wrap(
                 ConnectionManager,
-                ("get_connection_target", "fetch_did_document", "find_connection"),
+                (
+                    "get_connection_target",
+                    "fetch_did_document",
+                    "find_message_connection",
+                ),
             )
 
     async def start(self) -> None:
diff --git a/aries_cloudagent/protocols/connections/manager.py b/aries_cloudagent/protocols/connections/manager.py
index 056334b235..ebe05bf9fa 100644
--- a/aries_cloudagent/protocols/connections/manager.py
+++ b/aries_cloudagent/protocols/connections/manager.py
@@ -704,6 +704,56 @@ async def find_message_connection(
 
         """
+        cache_key = None
+        connection = None
+        resolved = False
+
+        if delivery.sender_verkey and delivery.recipient_verkey:
+            cache_key = (
+                f"connection_by_verkey::{delivery.sender_verkey}"
+                f"::{delivery.recipient_verkey}"
+            )
+            cache: BaseCache = await self.context.inject(BaseCache, required=False)
+            if cache:
+                async with cache.acquire(cache_key) as entry:
+                    if entry.result:
+                        cached = entry.result
+                        delivery.sender_did = cached["sender_did"]
+                        delivery.recipient_did_public = cached["recipient_did_public"]
+                        delivery.recipient_did = cached["recipient_did"]
+                        connection = await ConnectionRecord.retrieve_by_id(
+                            self.context, cached["id"]
+                        )
+                    else:
+                        connection = await self.resolve_message_connection(delivery)
+                        if connection:
+                            cache_val = {
+                                "id": connection.connection_id,
+                                "sender_did": delivery.sender_did,
+                                "recipient_did": delivery.recipient_did,
+                                "recipient_did_public": delivery.recipient_did_public,
+                            }
+                            await entry.set_result(cache_val, 3600)
+                        resolved = True
+
+        if not connection and not resolved:
+            connection = await self.resolve_message_connection(delivery)
+        return connection
+
+    async def resolve_message_connection(
+        self, delivery: MessageDelivery
+    ) -> ConnectionRecord:
+        """
+        Populate the delivery DID information and find the related ``ConnectionRecord``.
+
+        Args:
+            delivery: The message delivery details
+
+        Returns:
+            The `ConnectionRecord` associated with the expanded message, if any
+
+        """
+
         if delivery.sender_verkey:
             try:
                 delivery.sender_did = await self.find_did_for_key(
@@ -722,11 +772,15 @@
                     delivery.recipient_verkey
                 )
                 delivery.recipient_did = my_info.did
-                if "public" in my_info.metadata and my_info.metadata["public"] is True:
+                if (
+                    "public" in my_info.metadata
+                    and my_info.metadata["public"] is True
+                ):
                     delivery.recipient_did_public = True
             except InjectorError:
                 self._logger.warning(
-                    "Cannot resolve recipient verkey, no wallet defined by context: %s",
+                    "Cannot resolve recipient verkey, no wallet defined by "
+                    "context: %s",
                     delivery.recipient_verkey,
                 )
             except WalletNotFoundError:
@@ -735,13 +789,12 @@
                     delivery.recipient_verkey,
                 )
 
-        connection = await self.find_connection(
-            delivery.sender_did, delivery.recipient_did, delivery.recipient_verkey, True
+        return await self.find_connection(
+            delivery.sender_did,
+            delivery.recipient_did,
+            delivery.recipient_verkey,
+            True,
         )
-        # if connection:
-        #     self._log_state("Found connection", {"connection": connection})
-
-        return connection
 
     async def create_did_document(
         self,
@@ -902,9 +955,25 @@ async def get_connection_target(
         cache: BaseCache = await self.context.inject(BaseCache, required=False)
         cache_key = f"connection_target::{connection.connection_id}"
         if cache:
-            target_json = await cache.get(cache_key)
-            if target_json:
-                return ConnectionTarget.deserialize(target_json)
+            async with cache.acquire(cache_key) as entry:
+                if entry.result:
+                    return ConnectionTarget.deserialize(entry.result)
+                else:
+                    target = await self.fetch_connection_target(connection)
+                    await entry.set_result(target.serialize(), 60)
+        else:
+            target = await self.fetch_connection_target(connection)
+        return target
+
+    async def fetch_connection_target(
+        self, connection: ConnectionRecord
+    ) -> ConnectionTarget:
+        """Create a connection target from a ``ConnectionRecord``.
+
+        Args:
+            connection: The connection record (with associated ``DIDDoc``)
+                used to generate the connection target
+        """
 
         if not connection.my_did:
             self._logger.debug("No local DID associated with connection")
@@ -954,9 +1023,6 @@ async def get_connection_target(
             doc, my_info.verkey, connection.their_label
         )
 
-        if cache:
-            await cache.set(cache_key, result.serialize(), 60)
-
         return result
 
     def diddoc_connection_target(
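The connection manager changes above, and the credential manager changes below, all follow the same caching shape: inject `BaseCache` optionally, guard the expensive call with `cache.acquire()`, and fall back to computing the value directly when no cache is configured. Condensed into one hedged sketch (the `cached_or_computed` helper, `compute` coroutine and TTL are illustrative, not code from this diff):

    from aries_cloudagent.cache.base import BaseCache


    async def cached_or_computed(context, cache_key: str, compute, ttl: int = 3600):
        """Return a cached result when a cache is configured, else compute it."""
        result = None
        cache: BaseCache = await context.inject(BaseCache, required=False)
        if cache:
            async with cache.acquire(cache_key) as entry:
                if entry.result:
                    result = entry.result
                else:
                    result = await compute()
                    await entry.set_result(result, ttl)
        if not result:
            # No cache configured (or nothing produced): compute without caching.
            result = await compute()
        return result

One difference worth noting in review: `create_offer` and `create_request` recompute when the cached path produced nothing (as above), while `find_message_connection` uses its `resolved` flag to skip the second resolution attempt even when no connection was found.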
diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/manager.py b/aries_cloudagent/protocols/issue_credential/v1_0/manager.py
index 0f1f416935..6e0949d32e 100644
--- a/aries_cloudagent/protocols/issue_credential/v1_0/manager.py
+++ b/aries_cloudagent/protocols/issue_credential/v1_0/manager.py
@@ -3,6 +3,7 @@
 import logging
 from typing import Tuple
 
+from ....cache.base import BaseCache
 from ....config.injection_context import InjectionContext
 from ....error import BaseError
 from ....holder.base import BaseHolder
@@ -211,18 +212,24 @@ async def create_offer(
         else:
             cred_preview = None
 
-        cache_key = f"credential_offer::{credential_definition_id}"
-        cached = await V10CredentialExchange.get_cached_key(self.context, cache_key)
-        if cached:
-            credential_offer = cached["offer"]
-        else:
+        async def _create():
             issuer: BaseIssuer = await self.context.inject(BaseIssuer)
-            credential_offer = await issuer.create_credential_offer(
+            return await issuer.create_credential_offer(
                 credential_definition_id
             )
-            await V10CredentialExchange.set_cached_key(
-                self.context, cache_key, {"offer": credential_offer}, 3600
-            )
+
+        credential_offer = None
+        cache_key = f"credential_offer::{credential_definition_id}"
+        cache: BaseCache = await self.context.inject(BaseCache, required=False)
+        if cache:
+            async with cache.acquire(cache_key) as entry:
+                if entry.result:
+                    credential_offer = entry.result
+                else:
+                    credential_offer = await _create()
+                    await entry.set_result(credential_offer, 3600)
+        if not credential_offer:
+            credential_offer = await _create()
 
         credential_offer_message = CredentialOffer(
             comment=comment,
@@ -320,6 +327,19 @@ async def create_request(
         credential_definition_id = credential_exchange_record.credential_definition_id
         credential_offer = credential_exchange_record.credential_offer
 
+        async def _create():
+            ledger: BaseLedger = await self.context.inject(BaseLedger)
+            async with ledger:
+                credential_definition = await ledger.get_credential_definition(
+                    credential_definition_id
+                )
+
+            holder: BaseHolder = await self.context.inject(BaseHolder)
+            request, metadata = await holder.create_credential_request(
+                credential_offer, credential_definition, holder_did
+            )
+            return {"request": request, "metadata": metadata}
+
         if credential_exchange_record.credential_request:
             self._logger.warning(
                 "create_request called multiple times for v1.0 credential exchange: %s",
@@ -332,38 +352,22 @@
         cache_key = (
             f"credential_request::{credential_definition_id}::{holder_did}::{nonce}"
         )
-        cached = await V10CredentialExchange.get_cached_key(self.context, cache_key)
-        if cached:
-            (
-                credential_exchange_record.credential_request,
-                credential_exchange_record.credential_request_metadata,
-            ) = (cached["request"], cached["metadata"])
-        else:
-            ledger: BaseLedger = await self.context.inject(BaseLedger)
-            async with ledger:
-                credential_definition = await ledger.get_credential_definition(
-                    credential_definition_id
-                )
-
-            holder: BaseHolder = await self.context.inject(BaseHolder)
-            (
-                credential_exchange_record.credential_request,
-                credential_exchange_record.credential_request_metadata,
-            ) = await holder.create_credential_request(
-                credential_offer, credential_definition, holder_did
-            )
+        cred_req_result = None
+        cache: BaseCache = await self.context.inject(BaseCache, required=False)
+        if cache:
+            async with cache.acquire(cache_key) as entry:
+                if entry.result:
+                    cred_req_result = entry.result
+                else:
+                    cred_req_result = await _create()
+                    await entry.set_result(cred_req_result, 3600)
+        if not cred_req_result:
+            cred_req_result = await _create()
 
-            await V10CredentialExchange.set_cached_key(
-                self.context,
-                cache_key,
-                {
-                    "request": credential_exchange_record.credential_request,
-                    "metadata": (
-                        credential_exchange_record.credential_request_metadata
-                    ),
-                },
-                7200,
-            )
+        (
+            credential_exchange_record.credential_request,
+            credential_exchange_record.credential_request_metadata,
+        ) = (cred_req_result["request"], cred_req_result["metadata"])
 
         credential_request_message = CredentialRequest(
             requests_attach=[