diff --git a/aries_cloudagent/admin/server.py b/aries_cloudagent/admin/server.py index 0bbada1da6..7330699f7c 100644 --- a/aries_cloudagent/admin/server.py +++ b/aries_cloudagent/admin/server.py @@ -7,6 +7,7 @@ from typing import Callable, Coroutine import uuid import warnings +import weakref from aiohttp import web from aiohttp_apispec import ( @@ -115,7 +116,11 @@ def __init__( """ super().__init__(**kwargs) - self._profile = profile + # Weakly hold the profile so this reference doesn't prevent profiles + # from being cleaned up when appropriate. + # Binding this AdminResponder to the profile's context creates a circular + # reference. + self._profile = weakref.ref(profile) self._send = send async def send_outbound(self, message: OutboundMessage) -> OutboundSendStatus: @@ -125,7 +130,10 @@ async def send_outbound(self, message: OutboundMessage) -> OutboundSendStatus: Args: message: The `OutboundMessage` to be sent """ - return await self._send(self._profile, message) + profile = self._profile() + if not profile: + raise RuntimeError("weakref to profile has expired") + return await self._send(profile, message) async def send_webhook(self, topic: str, payload: dict): """ @@ -139,7 +147,10 @@ async def send_webhook(self, topic: str, payload: dict): "responder.send_webhook is deprecated; please use the event bus instead.", DeprecationWarning, ) - await self._profile.notify("acapy::webhook::" + topic, payload) + profile = self._profile() + if not profile: + raise RuntimeError("weakref to profile has expired") + await profile.notify("acapy::webhook::" + topic, payload) @property def send_fn(self) -> Coroutine: diff --git a/aries_cloudagent/admin/tests/test_admin_server.py b/aries_cloudagent/admin/tests/test_admin_server.py index 7707315e12..6482371c9e 100644 --- a/aries_cloudagent/admin/tests/test_admin_server.py +++ b/aries_cloudagent/admin/tests/test_admin_server.py @@ -198,10 +198,9 @@ async def test_import_routes_multitenant_middleware(self): context = InjectionContext() context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry()) context.injector.bind_instance(GoalCodeRegistry, GoalCodeRegistry()) - profile = InMemoryProfile.test_profile() context.injector.bind_instance( test_module.BaseMultitenantManager, - test_module.BaseMultitenantManager(profile), + async_mock.MagicMock(spec=test_module.BaseMultitenantManager), ) await DefaultContextBuilder().load_plugins(context) server = self.get_admin_server( @@ -486,3 +485,17 @@ async def test_on_record_event(server, event_topic, webhook_topic): ) as mock_send_webhook: await server._on_record_event(profile, Event(event_topic, None)) mock_send_webhook.assert_called_once_with(profile, webhook_topic, None) + + +@pytest.mark.asyncio +async def test_admin_responder_profile_expired_x(): + def _smaller_scope(): + profile = InMemoryProfile.test_profile() + return test_module.AdminResponder(profile, None) + + responder = _smaller_scope() + with pytest.raises(RuntimeError): + await responder.send_outbound(None) + + with pytest.raises(RuntimeError): + await responder.send_webhook("test", {}) diff --git a/aries_cloudagent/askar/profile.py b/aries_cloudagent/askar/profile.py index c20d436f01..a8cb10df71 100644 --- a/aries_cloudagent/askar/profile.py +++ b/aries_cloudagent/askar/profile.py @@ -36,11 +36,18 @@ class AskarProfile(Profile): BACKEND_NAME = "askar" - def __init__(self, opened: AskarOpenStore, context: InjectionContext = None): + def __init__( + self, + opened: AskarOpenStore, + context: InjectionContext = None, + *, + profile_id: str = None + ): """Create a new AskarProfile instance.""" super().__init__(context=context, name=opened.name, created=opened.created) self.opened = opened self.ledger_pool: IndyVdrLedgerPool = None + self.profile_id = profile_id self.init_ledger_pool() self.bind_providers() @@ -56,8 +63,8 @@ def store(self) -> Store: async def remove(self): """Remove the profile.""" - if self.settings.get("multitenant.wallet_type") == "askar-profile": - await self.store.remove_profile(self.settings.get("wallet.askar_profile")) + if self.profile_id: + await self.store.remove_profile(self.profile_id) def init_ledger_pool(self): """Initialize the ledger pool.""" @@ -160,11 +167,10 @@ def __init__( ): """Create a new IndySdkProfileSession instance.""" super().__init__(profile=profile, context=context, settings=settings) - profile_id = profile.context.settings.get("wallet.askar_profile") if is_txn: - self._opener = self.profile.store.transaction(profile_id) + self._opener = self.profile.store.transaction(profile.profile_id) else: - self._opener = self.profile.store.session(profile_id) + self._opener = self.profile.store.session(profile.profile_id) self._handle: Session = None self._acquire_start: float = None self._acquire_end: float = None diff --git a/aries_cloudagent/askar/tests/test_profile.py b/aries_cloudagent/askar/tests/test_profile.py index f01da0d2fe..aef94fa481 100644 --- a/aries_cloudagent/askar/tests/test_profile.py +++ b/aries_cloudagent/askar/tests/test_profile.py @@ -1,7 +1,8 @@ import asyncio +import logging import pytest -from asynctest import TestCase as AsyncTestCase, mock +from asynctest import mock from ...askar.profile import AskarProfile from ...config.injection_context import InjectionContext @@ -9,79 +10,79 @@ from .. import profile as test_module -class TestProfile(AsyncTestCase): - @mock.patch("aries_cloudagent.askar.store.AskarOpenStore") - async def test_init_success(self, AskarOpenStore): - askar_profile = AskarProfile( - AskarOpenStore, - ) - - assert askar_profile.opened == AskarOpenStore - - @mock.patch("aries_cloudagent.askar.store.AskarOpenStore") - async def test_remove_success(self, AskarOpenStore): - openStore = AskarOpenStore - context = InjectionContext() - profile_id = "profile_id" - context.settings = { - "multitenant.wallet_type": "askar-profile", - "wallet.askar_profile": profile_id, - "ledger.genesis_transactions": mock.MagicMock(), - } - askar_profile = AskarProfile(openStore, context) - remove_profile_stub = asyncio.Future() - remove_profile_stub.set_result(True) - openStore.store.remove_profile.return_value = remove_profile_stub - - await askar_profile.remove() - - openStore.store.remove_profile.assert_called_once_with(profile_id) - - @mock.patch("aries_cloudagent.askar.store.AskarOpenStore") - async def test_remove_profile_not_removed_if_wallet_type_not_askar_profile( - self, AskarOpenStore - ): - openStore = AskarOpenStore - context = InjectionContext() - context.settings = {"multitenant.wallet_type": "basic"} - askar_profile = AskarProfile(openStore, context) - - await askar_profile.remove() - - openStore.store.remove_profile.assert_not_called() - - @pytest.mark.asyncio - async def test_profile_manager_transaction(self): - profile = "profileId" - - with mock.patch("aries_cloudagent.askar.profile.AskarProfile") as AskarProfile: - askar_profile = AskarProfile(None, True) - askar_profile_transaction = mock.MagicMock() - askar_profile.store.transaction.return_value = askar_profile_transaction - askar_profile.context.settings.get.return_value = profile - - transactionProfile = test_module.AskarProfileSession(askar_profile, True) - - assert transactionProfile._opener == askar_profile_transaction - askar_profile.context.settings.get.assert_called_once_with( - "wallet.askar_profile" - ) - askar_profile.store.transaction.assert_called_once_with(profile) - - @pytest.mark.asyncio - async def test_profile_manager_store(self): - profile = "profileId" - - with mock.patch("aries_cloudagent.askar.profile.AskarProfile") as AskarProfile: - askar_profile = AskarProfile(None, False) - askar_profile_session = mock.MagicMock() - askar_profile.store.session.return_value = askar_profile_session - askar_profile.context.settings.get.return_value = profile - - sessionProfile = test_module.AskarProfileSession(askar_profile, False) - - assert sessionProfile._opener == askar_profile_session - askar_profile.context.settings.get.assert_called_once_with( - "wallet.askar_profile" - ) - askar_profile.store.session.assert_called_once_with(profile) +@pytest.fixture +def open_store(): + yield mock.MagicMock() + + +@pytest.mark.asyncio +async def test_init_success(open_store): + askar_profile = AskarProfile( + open_store, + ) + + assert askar_profile.opened == open_store + + +@pytest.mark.asyncio +async def test_remove_success(open_store): + openStore = open_store + context = InjectionContext() + profile_id = "profile_id" + context.settings = { + "multitenant.wallet_type": "askar-profile", + "wallet.askar_profile": profile_id, + "ledger.genesis_transactions": mock.MagicMock(), + } + askar_profile = AskarProfile(openStore, context, profile_id=profile_id) + remove_profile_stub = asyncio.Future() + remove_profile_stub.set_result(True) + openStore.store.remove_profile.return_value = remove_profile_stub + + await askar_profile.remove() + + openStore.store.remove_profile.assert_called_once_with(profile_id) + + +@pytest.mark.asyncio +async def test_remove_profile_not_removed_if_wallet_type_not_askar_profile(open_store): + openStore = open_store + context = InjectionContext() + context.settings = {"multitenant.wallet_type": "basic"} + askar_profile = AskarProfile(openStore, context) + + await askar_profile.remove() + + openStore.store.remove_profile.assert_not_called() + + +@pytest.mark.asyncio +async def test_profile_manager_transaction(): + profile = "profileId" + + with mock.patch("aries_cloudagent.askar.profile.AskarProfile") as AskarProfile: + askar_profile = AskarProfile(None, True, profile_id=profile) + askar_profile.profile_id = profile + askar_profile_transaction = mock.MagicMock() + askar_profile.store.transaction.return_value = askar_profile_transaction + + transactionProfile = test_module.AskarProfileSession(askar_profile, True) + + assert transactionProfile._opener == askar_profile_transaction + askar_profile.store.transaction.assert_called_once_with(profile) + + +@pytest.mark.asyncio +async def test_profile_manager_store(): + profile = "profileId" + + with mock.patch("aries_cloudagent.askar.profile.AskarProfile") as AskarProfile: + askar_profile = AskarProfile(None, False, profile_id=profile) + askar_profile.profile_id = profile + askar_profile_session = mock.MagicMock() + askar_profile.store.session.return_value = askar_profile_session + + sessionProfile = test_module.AskarProfileSession(askar_profile, False) + + assert sessionProfile._opener == askar_profile_session + askar_profile.store.session.assert_called_once_with(profile) diff --git a/aries_cloudagent/config/argparse.py b/aries_cloudagent/config/argparse.py index 2c2ac13dd4..510d3b54c7 100644 --- a/aries_cloudagent/config/argparse.py +++ b/aries_cloudagent/config/argparse.py @@ -1620,13 +1620,15 @@ def add_arguments(self, parser: ArgumentParser): parser.add_argument( "--multitenancy-config", type=str, - metavar="", + nargs="+", + metavar="key=value", env_var="ACAPY_MULTITENANCY_CONFIGURATION", help=( - 'Specify multitenancy configuration ("wallet_type" and "wallet_name"). ' - 'For example: "{"wallet_type":"askar-profile","wallet_name":' - '"askar-profile-name", "key_derivation_method":"RAW"}"' - '"wallet_name" is only used when "wallet_type" is "askar-profile"' + "Specify multitenancy configuration in key=value pairs. " + 'For example: "wallet_type=askar-profile wallet_name=askar-profile-name" ' + "Possible values: wallet_name, wallet_key, cache_size, " + 'key_derivation_method. "wallet_name" is only used when ' + '"wallet_type" is "askar-profile"' ), ) @@ -1647,22 +1649,37 @@ def get_settings(self, args: Namespace): settings["multitenant.admin_enabled"] = True if args.multitenancy_config: - multitenancyConfig = json.loads(args.multitenancy_config) - - if multitenancyConfig.get("wallet_type"): - settings["multitenant.wallet_type"] = multitenancyConfig.get( - "wallet_type" - ) - - if multitenancyConfig.get("wallet_name"): - settings["multitenant.wallet_name"] = multitenancyConfig.get( - "wallet_name" - ) - - if multitenancyConfig.get("key_derivation_method"): - settings[ - "multitenant.key_derivation_method" - ] = multitenancyConfig.get("key_derivation_method") + # Legacy support + if ( + len(args.multitenancy_config) == 1 + and args.multitenancy_config[0][0] == "{" + ): + multitenancy_config = json.loads(args.multitenancy_config[0]) + if multitenancy_config.get("wallet_type"): + settings["multitenant.wallet_type"] = multitenancy_config.get( + "wallet_type" + ) + + if multitenancy_config.get("wallet_name"): + settings["multitenant.wallet_name"] = multitenancy_config.get( + "wallet_name" + ) + + if multitenancy_config.get("cache_size"): + settings["multitenant.cache_size"] = multitenancy_config.get( + "cache_size" + ) + + if multitenancy_config.get("key_derivation_method"): + settings[ + "multitenant.key_derivation_method" + ] = multitenancy_config.get("key_derivation_method") + + else: + for value_str in args.multitenancy_config: + key, value = value_str.split("=", maxsplit=1) + value = yaml.safe_load(value) + settings[f"multitenant.{key}"] = value return settings diff --git a/aries_cloudagent/config/tests/test_argparse.py b/aries_cloudagent/config/tests/test_argparse.py index 08c47256dc..364a584987 100644 --- a/aries_cloudagent/config/tests/test_argparse.py +++ b/aries_cloudagent/config/tests/test_argparse.py @@ -231,7 +231,26 @@ async def test_multitenancy_settings(self): "--jwt-secret", "secret", "--multitenancy-config", - '{"wallet_type":"askar","wallet_name":"test"}', + '{"wallet_type":"askar","wallet_name":"test", "cache_size": 10}', + ] + ) + + settings = group.get_settings(result) + + assert settings.get("multitenant.enabled") == True + assert settings.get("multitenant.jwt_secret") == "secret" + assert settings.get("multitenant.wallet_type") == "askar" + assert settings.get("multitenant.wallet_name") == "test" + + result = parser.parse_args( + [ + "--multitenant", + "--jwt-secret", + "secret", + "--multitenancy-config", + "wallet_type=askar", + "wallet_name=test", + "cache_size=10", ] ) diff --git a/aries_cloudagent/core/conductor.py b/aries_cloudagent/core/conductor.py index 5a1a048842..377e04b6f0 100644 --- a/aries_cloudagent/core/conductor.py +++ b/aries_cloudagent/core/conductor.py @@ -488,7 +488,7 @@ async def stop(self, timeout=1.0): # close multitenant profiles multitenant_mgr = self.context.inject_or(BaseMultitenantManager) if multitenant_mgr: - for profile in multitenant_mgr._instances.values(): + for profile in multitenant_mgr.open_profiles: shutdown.run(profile.close()) if self.root_profile: diff --git a/aries_cloudagent/core/dispatcher.py b/aries_cloudagent/core/dispatcher.py index 57ef012021..e3f37b45ac 100644 --- a/aries_cloudagent/core/dispatcher.py +++ b/aries_cloudagent/core/dispatcher.py @@ -11,6 +11,7 @@ import warnings from typing import Callable, Coroutine, Union +import weakref from aiohttp.web import HTTPException @@ -283,7 +284,10 @@ def __init__( """ super().__init__(**kwargs) - self._context = context + # Weakly hold the context so it can be properly garbage collected. + # Binding this DispatcherResponder into the context creates a circular + # reference. + self._context = weakref.ref(context) self._inbound_message = inbound_message self._send = send_outbound @@ -296,13 +300,13 @@ async def create_outbound( Args: message: The message payload """ - if isinstance(message, AgentMessage) and self._context.settings.get( - "timing.enabled" - ): + context = self._context() + if not context: + raise RuntimeError("weakref to context has expired") + + if isinstance(message, AgentMessage) and context.settings.get("timing.enabled"): # Inject the timing decorator - in_time = ( - self._context.message_receipt and self._context.message_receipt.in_time - ) + in_time = context.message_receipt and context.message_receipt.in_time if not message._decorators.get("timing"): message._decorators["timing"] = { "in_time": in_time, @@ -318,7 +322,11 @@ async def send_outbound(self, message: OutboundMessage) -> OutboundSendStatus: Args: message: The `OutboundMessage` to be sent """ - return await self._send(self._context.profile, message, self._inbound_message) + context = self._context() + if not context: + raise RuntimeError("weakref to context has expired") + + return await self._send(context.profile, message, self._inbound_message) async def send_webhook(self, topic: str, payload: dict): """ @@ -332,4 +340,8 @@ async def send_webhook(self, topic: str, payload: dict): "responder.send_webhook is deprecated; please use the event bus instead.", DeprecationWarning, ) - await self._context.profile.notify("acapy::webhook::" + topic, payload) + context = self._context() + if not context: + raise RuntimeError("weakref to context has expired") + + await context.profile.notify("acapy::webhook::" + topic, payload) diff --git a/aries_cloudagent/core/tests/test_conductor.py b/aries_cloudagent/core/tests/test_conductor.py index 73f6e00528..b65b0494d8 100644 --- a/aries_cloudagent/core/tests/test_conductor.py +++ b/aries_cloudagent/core/tests/test_conductor.py @@ -28,6 +28,7 @@ ) from ...resolver.did_resolver import DIDResolver, DIDResolverRegistry from ...multitenant.base import BaseMultitenantManager +from ...multitenant.manager import MultitenantManager from ...storage.base import BaseStorage from ...storage.error import StorageNotFoundError from ...transport.inbound.message import InboundMessage @@ -1059,16 +1060,21 @@ async def test_shutdown_multitenant_profiles(self): } await conductor.setup() multitenant_mgr = conductor.context.inject(BaseMultitenantManager) + assert isinstance(multitenant_mgr, MultitenantManager) - multitenant_mgr._instances = { - "test1": async_mock.MagicMock(close=async_mock.CoroutineMock()), - "test2": async_mock.MagicMock(close=async_mock.CoroutineMock()), - } + multitenant_mgr._profiles.put( + "test1", + async_mock.MagicMock(close=async_mock.CoroutineMock()), + ) + multitenant_mgr._profiles.put( + "test2", + async_mock.MagicMock(close=async_mock.CoroutineMock()), + ) await conductor.stop() - multitenant_mgr._instances["test1"].close.assert_called_once_with() - multitenant_mgr._instances["test2"].close.assert_called_once_with() + multitenant_mgr._profiles.profiles["test1"].close.assert_called_once_with() + multitenant_mgr._profiles.profiles["test2"].close.assert_called_once_with() def get_invite_store_mock( diff --git a/aries_cloudagent/core/tests/test_dispatcher.py b/aries_cloudagent/core/tests/test_dispatcher.py index 3b5e245367..b9f00564b1 100644 --- a/aries_cloudagent/core/tests/test_dispatcher.py +++ b/aries_cloudagent/core/tests/test_dispatcher.py @@ -404,3 +404,20 @@ async def test_create_enc_outbound(self): ) as mock_send_outbound: await responder.send(message) assert mock_send_outbound.called_once() + + async def test_expired_context_x(self): + def _smaller_scope(): + profile = make_profile() + context = RequestContext(profile) + message = b"abc123xyz7890000" + return test_module.DispatcherResponder(context, message, None) + + responder = _smaller_scope() + with self.assertRaises(RuntimeError): + await responder.create_outbound(b"test") + + with self.assertRaises(RuntimeError): + await responder.send_outbound(None) + + with self.assertRaises(RuntimeError): + await responder.send_webhook("test", {}) diff --git a/aries_cloudagent/indy/sdk/profile.py b/aries_cloudagent/indy/sdk/profile.py index 24badc748a..d8151a5050 100644 --- a/aries_cloudagent/indy/sdk/profile.py +++ b/aries_cloudagent/indy/sdk/profile.py @@ -1,9 +1,10 @@ """Manage Indy-SDK profile interaction.""" +import asyncio import logging from typing import Any, Mapping -from weakref import ref +from weakref import finalize, ref from ...config.injection_context import InjectionContext from ...config.provider import ClassProvider @@ -30,13 +31,18 @@ class IndySdkProfile(Profile): BACKEND_NAME = "indy" - def __init__(self, opened: IndyOpenWallet, context: InjectionContext = None): + def __init__( + self, + opened: IndyOpenWallet, + context: InjectionContext = None, + ): """Create a new IndyProfile instance.""" super().__init__(context=context, name=opened.name, created=opened.created) self.opened = opened self.ledger_pool: IndySdkLedgerPool = None self.init_ledger_pool() self.bind_providers() + self._finalizer = self._make_finalizer(opened) @property def name(self) -> str: @@ -116,6 +122,18 @@ async def close(self): await self.opened.close() self.opened = None + def _make_finalizer(self, opened: IndyOpenWallet) -> finalize: + """Return a finalizer for this profile. + + See docs for weakref.finalize for more details on behavior of finalizers. + """ + + def _finalize(opened: IndyOpenWallet): + LOGGER.debug("Profile finalizer called; closing wallet") + asyncio.get_event_loop().create_task(opened.close()) + + return finalize(self, _finalize, opened) + async def remove(self): """Remove the profile associated with this instance.""" if not self.opened: diff --git a/aries_cloudagent/indy/sdk/tests/test_profile.py b/aries_cloudagent/indy/sdk/tests/test_profile.py index 6db4425574..6bd97474e6 100644 --- a/aries_cloudagent/indy/sdk/tests/test_profile.py +++ b/aries_cloudagent/indy/sdk/tests/test_profile.py @@ -1,3 +1,4 @@ +import logging import pytest from asynctest import mock as async_mock @@ -10,80 +11,69 @@ from ..wallet_setup import IndyWalletConfig, IndyOpenWallet +@pytest.fixture +async def open_wallet(): + yield IndyOpenWallet( + config=IndyWalletConfig({"name": "test-profile"}), + created=True, + handle=1, + master_secret_id="master-secret", + ) + + @pytest.fixture() -async def profile(): +async def profile(open_wallet): context = InjectionContext() context.injector.bind_instance(IndySdkLedgerPool, IndySdkLedgerPool("name")) - yield IndySdkProfile( - IndyOpenWallet( - config=IndyWalletConfig({"name": "test-profile"}), - created=True, - handle=1, - master_secret_id="master-secret", - ), - context, + yield IndySdkProfile(open_wallet, context) + + +@pytest.mark.asyncio +async def test_properties(profile): + assert profile.name == "test-profile" + assert profile.backend == "indy" + assert profile.wallet and profile.wallet.handle == 1 + + assert "IndySdkProfile" in str(profile) + assert profile.created + assert profile.wallet.created + assert profile.wallet.master_secret_id == "master-secret" + + with async_mock.patch.object(profile, "opened", False): + with pytest.raises(ProfileError): + await profile.remove() + + with async_mock.patch.object(profile.opened, "close", async_mock.CoroutineMock()): + await profile.remove() + assert profile.opened is None + + +def test_settings_genesis_transactions(open_wallet): + context = InjectionContext( + settings={"ledger.genesis_transactions": async_mock.MagicMock()} ) + context.injector.bind_instance(IndySdkLedgerPool, IndySdkLedgerPool("name")) + profile = IndySdkProfile(open_wallet, context) -class TestIndySdkProfile: - @pytest.mark.asyncio - async def test_properties(self, profile): - assert profile.name == "test-profile" - assert profile.backend == "indy" - assert profile.wallet and profile.wallet.handle == 1 +def test_settings_ledger_config(open_wallet): + context = InjectionContext(settings={"ledger.ledger_config_list": True}) + context.injector.bind_instance(IndySdkLedgerPool, IndySdkLedgerPool("name")) + profile = IndySdkProfile(open_wallet, context) - assert "IndySdkProfile" in str(profile) - assert profile.created - assert profile.wallet.created - assert profile.wallet.master_secret_id == "master-secret" - with async_mock.patch.object(profile, "opened", False): - with pytest.raises(ProfileError): - await profile.remove() +def test_read_only(open_wallet): + context = InjectionContext(settings={"ledger.read_only": True}) + context.injector.bind_instance(IndySdkLedgerPool, IndySdkLedgerPool("name")) + ro_profile = IndySdkProfile(open_wallet, context) - with async_mock.patch.object( - profile.opened, "close", async_mock.CoroutineMock() - ): - await profile.remove() - assert profile.opened is None - - def test_settings_genesis_transactions(self): - context = InjectionContext( - settings={"ledger.genesis_transactions": async_mock.MagicMock()} - ) - context.injector.bind_instance(IndySdkLedgerPool, IndySdkLedgerPool("name")) - profile = IndySdkProfile( - IndyOpenWallet( - config=IndyWalletConfig({"name": "test-profile"}), - created=True, - handle=1, - master_secret_id="master-secret", - ), - context, - ) - - def test_settings_ledger_config(self): - context = InjectionContext(settings={"ledger.ledger_config_list": True}) - context.injector.bind_instance(IndySdkLedgerPool, IndySdkLedgerPool("name")) - profile = IndySdkProfile( - IndyOpenWallet( - config=IndyWalletConfig({"name": "test-profile"}), - created=True, - handle=1, - master_secret_id="master-secret", - ), - context, - ) - - def test_read_only(self): - context = InjectionContext(settings={"ledger.read_only": True}) - context.injector.bind_instance(IndySdkLedgerPool, IndySdkLedgerPool("name")) - ro_profile = IndySdkProfile( - IndyOpenWallet( - config=IndyWalletConfig({"name": "test-profile"}), - created=True, - handle=1, - master_secret_id="master-secret", - ), - context, - ) + +def test_finalizer(open_wallet, caplog): + def _smaller_scope(): + profile = IndySdkProfile(open_wallet) + assert profile + + with caplog.at_level(logging.DEBUG): + _smaller_scope() + + assert "finalizer called" in caplog.text diff --git a/aries_cloudagent/multitenant/askar_profile_manager.py b/aries_cloudagent/multitenant/askar_profile_manager.py index c692b04ea8..83135cfe8e 100644 --- a/aries_cloudagent/multitenant/askar_profile_manager.py +++ b/aries_cloudagent/multitenant/askar_profile_manager.py @@ -1,5 +1,6 @@ """Manager for askar profile multitenancy mode.""" +from typing import Iterable, Optional, cast from ..core.profile import ( Profile, ) @@ -13,15 +14,25 @@ class AskarProfileMultitenantManager(BaseMultitenantManager): """Class for handling askar profile multitenancy.""" - DEFAULT_MULTIENANT_WALLET_NAME = "multitenant_sub_wallet" + DEFAULT_MULTITENANT_WALLET_NAME = "multitenant_sub_wallet" - def __init__(self, profile: Profile): + def __init__(self, profile: Profile, multitenant_profile: AskarProfile = None): """Initialize askar profile multitenant Manager. Args: profile: The base profile for this manager """ super().__init__(profile) + self._multitenant_profile: Optional[AskarProfile] = multitenant_profile + + @property + def open_profiles(self) -> Iterable[Profile]: + """Return iterator over open profiles. + + Only the core multitenant profile is considered open. + """ + if self._multitenant_profile: + yield self._multitenant_profile async def get_wallet_profile( self, @@ -33,6 +44,13 @@ async def get_wallet_profile( ) -> Profile: """Get Askar profile for a wallet record. + An object of type AskarProfile is returned but this should not be + confused with the underlying profile mechanism provided by Askar that + enables multiple "profiles" to share a wallet. Usage of this mechanism + is what causes this implementation of BaseMultitenantManager.get_wallet_profile + to look different from others, especially since no explicit clean up is + required for profiles that are no longer in use. + Args: base_context: Base context to extend from wallet_record: Wallet record to get the context for @@ -42,12 +60,10 @@ async def get_wallet_profile( Profile: Profile for the wallet record """ - multitenant_wallet_name = ( - base_context.settings.get("multitenant.wallet_name") - or self.DEFAULT_MULTIENANT_WALLET_NAME - ) - - if multitenant_wallet_name not in self._instances: + if not self._multitenant_profile: + multitenant_wallet_name = base_context.settings.get( + "multitenant.wallet_name", self.DEFAULT_MULTITENANT_WALLET_NAME + ) context = base_context.copy() sub_wallet_settings = { "wallet.recreate": False, @@ -65,13 +81,14 @@ async def get_wallet_profile( context.settings = context.settings.extend(sub_wallet_settings) profile, _ = await wallet_config(context, provision=False) - self._instances[multitenant_wallet_name] = profile + self._multitenant_profile = cast(AskarProfile, profile) - multitenant_wallet = self._instances[multitenant_wallet_name] - profile_context = multitenant_wallet.context.copy() + profile_context = self._multitenant_profile.context.copy() if provision: - await multitenant_wallet.store.create_profile(wallet_record.wallet_id) + await self._multitenant_profile.store.create_profile( + wallet_record.wallet_id + ) extra_settings = { "admin.webhook_urls": self.get_webhook_urls(base_context, wallet_record), @@ -82,7 +99,13 @@ async def get_wallet_profile( wallet_record.settings ).extend(extra_settings) - return AskarProfile(multitenant_wallet.opened, profile_context) + assert self._multitenant_profile.opened + + return AskarProfile( + self._multitenant_profile.opened, + profile_context, + profile_id=wallet_record.wallet_id, + ) async def remove_wallet_profile(self, profile: Profile): """Remove the wallet profile instance. diff --git a/aries_cloudagent/multitenant/base.py b/aries_cloudagent/multitenant/base.py index 77decdbb80..3f01900e2c 100644 --- a/aries_cloudagent/multitenant/base.py +++ b/aries_cloudagent/multitenant/base.py @@ -2,10 +2,10 @@ from datetime import datetime import logging -from abc import abstractmethod +from abc import abstractmethod, ABC import jwt -from typing import List, Optional, cast +from typing import Iterable, List, Optional, cast from ..core.profile import ( Profile, @@ -35,7 +35,7 @@ class MultitenantManagerError(BaseError): """Generic multitenant error.""" -class BaseMultitenantManager: +class BaseMultitenantManager(ABC): """Base class for handling multitenancy.""" def __init__(self, profile: Profile): @@ -48,7 +48,10 @@ def __init__(self, profile: Profile): if not profile: raise MultitenantManagerError("Missing profile") - self._instances: dict[str, Profile] = {} + @property + @abstractmethod + def open_profiles(self) -> Iterable[Profile]: + """Return iterator over open profiles.""" async def get_default_mediator(self) -> Optional[MediationRecord]: """Retrieve the default mediator used for subwallet routing. @@ -215,7 +218,7 @@ async def update_wallet( wallet_id: str, new_settings: dict, ) -> WalletRecord: - """Update a existing wallet and wallet record. + """Update an existing wallet record. Args: wallet_id: The wallet id of the wallet record @@ -231,18 +234,6 @@ async def update_wallet( wallet_record.update_settings(new_settings) await wallet_record.save(session) - # update profile only if loaded - if wallet_id in self._instances: - profile = self._instances[wallet_id] - profile.settings.update(wallet_record.settings) - - extra_settings = { - "admin.webhook_urls": self.get_webhook_urls( - self._profile.context, wallet_record - ), - } - profile.settings.update(extra_settings) - return wallet_record async def remove_wallet(self, wallet_id: str, wallet_key: str = None): diff --git a/aries_cloudagent/multitenant/cache.py b/aries_cloudagent/multitenant/cache.py new file mode 100644 index 0000000000..1fb3f37e3c --- /dev/null +++ b/aries_cloudagent/multitenant/cache.py @@ -0,0 +1,113 @@ +"""Cache for multitenancy profiles.""" + +import logging +from collections import OrderedDict +from typing import Optional +from weakref import WeakValueDictionary + +from ..core.profile import Profile + +LOGGER = logging.getLogger(__name__) + + +class ProfileCache: + """Profile cache that caches based on LRU strategy.""" + + def __init__(self, capacity: int): + """Initialize ProfileCache. + + Args: + capacity: The capacity of the cache. If capacity is exceeded + profiles are closed. + """ + + LOGGER.debug(f"Profile cache initialized with capacity {capacity}") + + self._cache: OrderedDict[str, Profile] = OrderedDict() + self.profiles: WeakValueDictionary[str, Profile] = WeakValueDictionary() + self.capacity = capacity + + def _cleanup(self): + """Prune cache until size matches defined capacity.""" + if len(self._cache) > self.capacity: + LOGGER.debug( + f"Profile limit of {self.capacity} reached." + " Evicting least recently used profiles..." + ) + while len(self._cache) > self.capacity: + key, _ = self._cache.popitem(last=False) + LOGGER.debug(f"Evicted profile with key {key}") + + def get(self, key: str) -> Optional[Profile]: + """Get profile with associated key from cache. + + If a profile is open but has been evicted from the cache, this will + reinsert the profile back into the cache. This prevents attempting to + open a profile that is already open. Triggers clean up. + + Args: + key (str): the key to get the profile for. + + Returns: + Optional[Profile]: Profile if found in cache. + + """ + value = self.profiles.get(key) + if value: + if key not in self._cache: + LOGGER.debug( + f"Rescuing profile {key} from eviction from cache; profile " + "will be reinserted into cache" + ) + self._cache[key] = value + self._cache.move_to_end(key) + self._cleanup() + + return value + + def has(self, key: str) -> bool: + """Check whether there is a profile with associated key in the cache. + + Args: + key (str): the key to check for a profile + + Returns: + bool: Whether the key exists in the cache + + """ + return key in self.profiles + + def put(self, key: str, value: Profile) -> None: + """Add profile with associated key to the cache. + + If new profile exceeds the cache capacity least recently used profiles + that are not used will be removed from the cache. + + Args: + key (str): the key to set + value (Profile): the profile to set + """ + + # Profiles are responsible for cleaning up after themselves when they + # fall out of scope. Previously the cache needed to create a finalizer. + # value.finalzer() + + # Keep track of currently opened profiles using weak references + self.profiles[key] = value + + # Strong reference to profile to hold open until evicted + LOGGER.debug(f"Setting profile with id {key} in profile cache") + self._cache[key] = value + + # Refresh profile livliness + self._cache.move_to_end(key) + self._cleanup() + + def remove(self, key: str): + """Remove profile with associated key from the cache. + + Args: + key (str): The key to remove from the cache. + """ + del self.profiles[key] + del self._cache[key] diff --git a/aries_cloudagent/multitenant/manager.py b/aries_cloudagent/multitenant/manager.py index 5bbbcc6632..e7bf2d9447 100644 --- a/aries_cloudagent/multitenant/manager.py +++ b/aries_cloudagent/multitenant/manager.py @@ -1,12 +1,16 @@ """Manager for multitenancy.""" -from ..core.profile import ( - Profile, -) -from ..config.wallet import wallet_config +import logging +from typing import Iterable + from ..config.injection_context import InjectionContext -from ..wallet.models.wallet_record import WalletRecord +from ..config.wallet import wallet_config +from ..core.profile import Profile from ..multitenant.base import BaseMultitenantManager +from ..wallet.models.wallet_record import WalletRecord +from .cache import ProfileCache + +LOGGER = logging.getLogger(__name__) class MultitenantManager(BaseMultitenantManager): @@ -19,6 +23,14 @@ def __init__(self, profile: Profile): profile: The profile for this manager """ super().__init__(profile) + self._profiles = ProfileCache( + profile.settings.get_int("multitenant.cache_size") or 100 + ) + + @property + def open_profiles(self) -> Iterable[Profile]: + """Return iterator over open profiles.""" + yield from self._profiles.profiles.values() async def get_wallet_profile( self, @@ -40,7 +52,8 @@ async def get_wallet_profile( """ wallet_id = wallet_record.wallet_id - if wallet_id not in self._instances: + profile = self._profiles.get(wallet_id) + if not profile: # Extend base context context = base_context.copy() @@ -68,9 +81,37 @@ async def get_wallet_profile( # MTODO: add ledger config profile, _ = await wallet_config(context, provision=provision) - self._instances[wallet_id] = profile + self._profiles.put(wallet_id, profile) + + return profile + + async def update_wallet(self, wallet_id: str, new_settings: dict) -> WalletRecord: + """Update an existing wallet and wallet record. + + Args: + wallet_id: The wallet id of the wallet record + new_settings: The context settings to be updated for this wallet + + Returns: + WalletRecord: The updated wallet record + + """ + wallet_record = await super().update_wallet(wallet_id, new_settings) + + # Wallet record has been updated but profile settings in memory must + # also be refreshed; update profile only if loaded + profile = self._profiles.get(wallet_id) + if profile: + profile.settings.update(wallet_record.settings) + + extra_settings = { + "admin.webhook_urls": self.get_webhook_urls( + self._profile.context, wallet_record + ), + } + profile.settings.update(extra_settings) - return self._instances[wallet_id] + return wallet_record async def remove_wallet_profile(self, profile: Profile): """Remove the wallet profile instance. @@ -79,6 +120,6 @@ async def remove_wallet_profile(self, profile: Profile): profile: The wallet profile instance """ - wallet_id = profile.settings.get("wallet.id") - del self._instances[wallet_id] + wallet_id = profile.settings.get_str("wallet.id") + self._profiles.remove(wallet_id) await profile.remove() diff --git a/aries_cloudagent/multitenant/tests/test_askar_profile_manager.py b/aries_cloudagent/multitenant/tests/test_askar_profile_manager.py index 6bad949cb4..5bbc1d926c 100644 --- a/aries_cloudagent/multitenant/tests/test_askar_profile_manager.py +++ b/aries_cloudagent/multitenant/tests/test_askar_profile_manager.py @@ -43,79 +43,61 @@ async def test_get_wallet_profile_should_open_store_and_return_profile_with_wall with async_mock.patch( "aries_cloudagent.multitenant.askar_profile_manager.wallet_config" - ) as wallet_config: - with async_mock.patch( - "aries_cloudagent.multitenant.askar_profile_manager.AskarProfile" - ) as AskarProfile: - sub_wallet_profile_context = InjectionContext() - sub_wallet_profile = AskarProfile(None, None) - sub_wallet_profile.context.copy.return_value = ( - sub_wallet_profile_context - ) + ) as wallet_config, async_mock.patch( + "aries_cloudagent.multitenant.askar_profile_manager.AskarProfile", + ) as AskarProfile: + sub_wallet_profile_context = InjectionContext() + sub_wallet_profile = AskarProfile(None, None) + sub_wallet_profile.context.copy.return_value = sub_wallet_profile_context - def side_effect(context, provision): - sub_wallet_profile.name = askar_profile_mock_name - return sub_wallet_profile, None + def side_effect(context, provision): + sub_wallet_profile.name = askar_profile_mock_name + return sub_wallet_profile, None - wallet_config.side_effect = side_effect + wallet_config.side_effect = side_effect - profile = await self.manager.get_wallet_profile( - self.profile.context, wallet_record - ) + profile = await self.manager.get_wallet_profile( + self.profile.context, wallet_record + ) - assert profile.name == askar_profile_mock_name - wallet_config.assert_called_once() - wallet_config_settings_argument = wallet_config.call_args[0][0].settings - assert ( - wallet_config_settings_argument.get("wallet.name") - == self.DEFAULT_MULTIENANT_WALLET_NAME - ) - assert wallet_config_settings_argument.get("wallet.id") == None - assert wallet_config_settings_argument.get("auto_provision") == True - assert wallet_config_settings_argument.get("wallet.type") == "askar" - AskarProfile.assert_called_with( - sub_wallet_profile.opened, sub_wallet_profile_context - ) - assert ( - sub_wallet_profile_context.settings.get("wallet.seed") - == "test_seed" - ) - assert ( - sub_wallet_profile_context.settings.get("wallet.rekey") - == "test_rekey" - ) - assert ( - sub_wallet_profile_context.settings.get("wallet.name") - == "test_name" - ) - assert ( - sub_wallet_profile_context.settings.get("wallet.type") - == "test_type" - ) - assert sub_wallet_profile_context.settings.get("mediation.open") == True - assert ( - sub_wallet_profile_context.settings.get("mediation.invite") - == "http://invite.com" - ) - assert ( - sub_wallet_profile_context.settings.get("mediation.default_id") - == "24a96ef5" - ) - assert ( - sub_wallet_profile_context.settings.get("mediation.clear") == True - ) - assert ( - sub_wallet_profile_context.settings.get("wallet.id") - == wallet_record.wallet_id - ) - assert ( - sub_wallet_profile_context.settings.get("wallet.name") - == "test_name" - ) - assert ( - sub_wallet_profile_context.settings.get("wallet.askar_profile") - == wallet_record.wallet_id - ) + assert profile.name == askar_profile_mock_name + wallet_config.assert_called_once() + wallet_config_settings_argument = wallet_config.call_args[0][0].settings + assert ( + wallet_config_settings_argument.get("wallet.name") + == self.DEFAULT_MULTIENANT_WALLET_NAME + ) + assert wallet_config_settings_argument.get("wallet.id") == None + assert wallet_config_settings_argument.get("auto_provision") == True + assert wallet_config_settings_argument.get("wallet.type") == "askar" + AskarProfile.assert_called_with( + sub_wallet_profile.opened, sub_wallet_profile_context, profile_id="test" + ) + assert sub_wallet_profile_context.settings.get("wallet.seed") == "test_seed" + assert ( + sub_wallet_profile_context.settings.get("wallet.rekey") == "test_rekey" + ) + assert sub_wallet_profile_context.settings.get("wallet.name") == "test_name" + assert sub_wallet_profile_context.settings.get("wallet.type") == "test_type" + assert sub_wallet_profile_context.settings.get("mediation.open") == True + assert ( + sub_wallet_profile_context.settings.get("mediation.invite") + == "http://invite.com" + ) + assert ( + sub_wallet_profile_context.settings.get("mediation.default_id") + == "24a96ef5" + ) + assert sub_wallet_profile_context.settings.get("mediation.clear") == True + assert ( + sub_wallet_profile_context.settings.get("wallet.id") + == wallet_record.wallet_id + ) + assert sub_wallet_profile_context.settings.get("wallet.name") == "test_name" + assert ( + sub_wallet_profile_context.settings.get("wallet.askar_profile") + == wallet_record.wallet_id + ) async def test_get_wallet_profile_should_create_profile(self): wallet_record = WalletRecord(wallet_id="test", settings={}) @@ -128,9 +110,7 @@ async def test_get_wallet_profile_should_create_profile(self): sub_wallet_profile = AskarProfile(None, None) sub_wallet_profile.context.copy.return_value = InjectionContext() sub_wallet_profile.store.create_profile.return_value = create_profile_stub - self.manager._instances[ - self.DEFAULT_MULTIENANT_WALLET_NAME - ] = sub_wallet_profile + self.manager._multitenant_profile = sub_wallet_profile await self.manager.get_wallet_profile( self.profile.context, wallet_record, provision=True @@ -172,8 +152,23 @@ def side_effect(context, provision): ) async def test_remove_wallet_profile(self): - test_profile = InMemoryProfile.test_profile() + test_profile = InMemoryProfile.test_profile({"wallet.id": "test"}) with async_mock.patch.object(InMemoryProfile, "remove") as profile_remove: await self.manager.remove_wallet_profile(test_profile) profile_remove.assert_called_once_with() + + async def test_open_profiles(self): + assert len(list(self.manager.open_profiles)) == 0 + + create_profile_stub = asyncio.Future() + create_profile_stub.set_result("") + with async_mock.patch( + "aries_cloudagent.multitenant.askar_profile_manager.AskarProfile" + ) as AskarProfile: + sub_wallet_profile = AskarProfile(None, None) + sub_wallet_profile.context.copy.return_value = InjectionContext() + sub_wallet_profile.store.create_profile.return_value = create_profile_stub + self.manager._multitenant_profile = sub_wallet_profile + + assert len(list(self.manager.open_profiles)) == 1 diff --git a/aries_cloudagent/multitenant/tests/test_base.py b/aries_cloudagent/multitenant/tests/test_base.py index f20605e9cb..5d37297b9d 100644 --- a/aries_cloudagent/multitenant/tests/test_base.py +++ b/aries_cloudagent/multitenant/tests/test_base.py @@ -26,6 +26,25 @@ from .. import base as test_module +class MockMultitenantManager(BaseMultitenantManager): + async def get_wallet_profile( + self, + base_context, + wallet_record: WalletRecord, + extra_settings: dict = ..., + *, + provision=False + ): + """Do nothing.""" + + async def remove_wallet_profile(self, profile): + """Do nothing.""" + + @property + def open_profiles(self): + """Do nothing.""" + + class TestBaseMultitenantManager(AsyncTestCase): async def setUp(self): self.profile = InMemoryProfile.test_profile() @@ -34,11 +53,11 @@ async def setUp(self): self.responder = async_mock.CoroutineMock(send=async_mock.CoroutineMock()) self.context.injector.bind_instance(BaseResponder, self.responder) - self.manager = BaseMultitenantManager(self.profile) + self.manager = MockMultitenantManager(self.profile) async def test_init_throws_no_profile(self): with self.assertRaises(MultitenantManagerError): - BaseMultitenantManager(None) + MockMultitenantManager(None) async def test_get_default_mediator(self): with async_mock.patch.object( @@ -161,7 +180,7 @@ async def test_get_wallet_by_key(self): async def test_create_wallet_removes_key_only_unmanaged_mode(self): with async_mock.patch.object( - BaseMultitenantManager, "get_wallet_profile" + self.manager, "get_wallet_profile" ) as get_wallet_profile: get_wallet_profile.return_value = InMemoryProfile.test_profile() @@ -177,7 +196,7 @@ async def test_create_wallet_removes_key_only_unmanaged_mode(self): async def test_create_wallet_fails_if_wallet_name_exists(self): with async_mock.patch.object( - BaseMultitenantManager, "_wallet_name_exists" + self.manager, "_wallet_name_exists" ) as _wallet_name_exists: _wallet_name_exists.return_value = True @@ -194,9 +213,9 @@ async def test_create_wallet_saves_wallet_record_creates_profile(self): with async_mock.patch.object( WalletRecord, "save" ) as wallet_record_save, async_mock.patch.object( - BaseMultitenantManager, "get_wallet_profile" + self.manager, "get_wallet_profile" ) as get_wallet_profile, async_mock.patch.object( - BaseMultitenantManager, "add_key" + self.manager, "add_key" ) as add_key: get_wallet_profile.return_value = InMemoryProfile.test_profile() @@ -230,9 +249,9 @@ async def test_create_wallet_adds_wallet_route(self): with async_mock.patch.object( WalletRecord, "save" ) as wallet_record_save, async_mock.patch.object( - BaseMultitenantManager, "get_wallet_profile" + self.manager, "get_wallet_profile" ) as get_wallet_profile, async_mock.patch.object( - BaseMultitenantManager, "add_key" + self.manager, "add_key" ) as add_key, async_mock.patch.object( InMemoryWallet, "get_public_did" ) as get_public_did: @@ -260,15 +279,13 @@ async def test_create_wallet_adds_wallet_route(self): assert wallet_record.key_management_mode == WalletRecord.MODE_MANAGED assert wallet_record.wallet_key == "test_key" - async def test_update_wallet_update_wallet_profile(self): + async def test_update_wallet(self): with async_mock.patch.object( WalletRecord, "retrieve_by_id" ) as retrieve_by_id, async_mock.patch.object( WalletRecord, "save" ) as wallet_record_save: wallet_id = "test-wallet-id" - wallet_profile = InMemoryProfile.test_profile() - self.manager._instances["test-wallet-id"] = wallet_profile retrieve_by_id.return_value = WalletRecord( wallet_id=wallet_id, settings={ @@ -288,10 +305,6 @@ async def test_update_wallet_update_wallet_profile(self): assert isinstance(wallet_record, WalletRecord) assert wallet_record.wallet_webhook_urls == ["new-webhook-url"] assert wallet_record.wallet_dispatch_type == "default" - assert wallet_profile.settings.get("wallet.webhook_urls") == [ - "new-webhook-url" - ] - assert wallet_profile.settings.get("wallet.dispatch_type") == "default" async def test_remove_wallet_fails_no_wallet_key_but_required(self): with async_mock.patch.object(WalletRecord, "retrieve_by_id") as retrieve_by_id: @@ -308,9 +321,9 @@ async def test_remove_wallet_removes_profile_wallet_storage_records(self): with async_mock.patch.object( WalletRecord, "retrieve_by_id" ) as retrieve_by_id, async_mock.patch.object( - BaseMultitenantManager, "get_wallet_profile" + self.manager, "get_wallet_profile" ) as get_wallet_profile, async_mock.patch.object( - BaseMultitenantManager, "remove_wallet_profile" + self.manager, "remove_wallet_profile" ) as remove_wallet_profile, async_mock.patch.object( WalletRecord, "delete_record" ) as wallet_delete_record, async_mock.patch.object( @@ -506,7 +519,7 @@ async def test_get_profile_for_token_managed_wallet_no_iat(self): ) with async_mock.patch.object( - BaseMultitenantManager, "get_wallet_profile" + self.manager, "get_wallet_profile" ) as get_wallet_profile: mock_profile = InMemoryProfile.test_profile() get_wallet_profile.return_value = mock_profile @@ -543,7 +556,7 @@ async def test_get_profile_for_token_managed_wallet_iat(self): ) with async_mock.patch.object( - BaseMultitenantManager, "get_wallet_profile" + self.manager, "get_wallet_profile" ) as get_wallet_profile: mock_profile = InMemoryProfile.test_profile() get_wallet_profile.return_value = mock_profile @@ -581,7 +594,7 @@ async def test_get_profile_for_token_managed_wallet_x_iat_no_match(self): ) with async_mock.patch.object( - BaseMultitenantManager, "get_wallet_profile" + self.manager, "get_wallet_profile" ) as get_wallet_profile, self.assertRaises( MultitenantManagerError, msg="Token not valid" ): @@ -617,7 +630,7 @@ async def test_get_profile_for_token_unmanaged_wallet(self): ) with async_mock.patch.object( - BaseMultitenantManager, "get_wallet_profile" + self.manager, "get_wallet_profile" ) as get_wallet_profile: mock_profile = InMemoryProfile.test_profile() get_wallet_profile.return_value = mock_profile @@ -657,7 +670,7 @@ async def test_get_wallets_by_message(self): ] with async_mock.patch.object( - BaseMultitenantManager, "_get_wallet_by_key" + self.manager, "_get_wallet_by_key" ) as get_wallet_by_key: get_wallet_by_key.side_effect = return_wallets diff --git a/aries_cloudagent/multitenant/tests/test_cache.py b/aries_cloudagent/multitenant/tests/test_cache.py new file mode 100644 index 0000000000..ae7dbcc303 --- /dev/null +++ b/aries_cloudagent/multitenant/tests/test_cache.py @@ -0,0 +1,107 @@ +from ...core.profile import Profile + +from ..cache import ProfileCache + + +class MockProfile(Profile): + def session(self, context=None): + ... + + def transaction(self, context=None): + ... + + +def test_get_not_in_cache(): + cache = ProfileCache(1) + + assert cache.get("1") is None + + +def test_put_get_in_cache(): + cache = ProfileCache(1) + + profile = MockProfile() + cache.put("1", profile) + + assert cache.get("1") is profile + + +def test_remove(): + cache = ProfileCache(1) + + profile = MockProfile() + cache.put("1", profile) + + assert cache.get("1") is profile + + cache.remove("1") + + assert cache.get("1") is None + + +def test_has_true(): + cache = ProfileCache(1) + + profile = MockProfile() + + assert cache.has("1") is False + cache.put("1", profile) + assert cache.has("1") is True + + +def test_cleanup(): + cache = ProfileCache(1) + + cache.put("1", MockProfile()) + + assert len(cache.profiles) == 1 + + cache.put("2", MockProfile()) + + assert len(cache.profiles) == 1 + assert cache.get("1") == None + + +def test_cleanup_lru(): + cache = ProfileCache(3) + + cache.put("1", MockProfile()) + cache.put("2", MockProfile()) + cache.put("3", MockProfile()) + + assert len(cache.profiles) == 3 + + cache.get("1") + + cache.put("4", MockProfile()) + + assert len(cache._cache) == 3 + assert cache.get("1") + assert cache.get("2") is None + assert cache.get("3") + assert cache.get("4") + + +def test_rescue_open_profile(): + cache = ProfileCache(3) + + cache.put("1", MockProfile()) + cache.put("2", MockProfile()) + cache.put("3", MockProfile()) + + assert len(cache.profiles) == 3 + + held = cache.profiles["1"] + cache.put("4", MockProfile()) + + assert len(cache.profiles) == 4 + assert len(cache._cache) == 3 + + cache.get("1") + + assert len(cache.profiles) == 3 + assert len(cache._cache) == 3 + assert cache.get("1") + assert cache.get("2") is None + assert cache.get("3") + assert cache.get("4") diff --git a/aries_cloudagent/multitenant/tests/test_manager.py b/aries_cloudagent/multitenant/tests/test_manager.py index 7e01959f95..c851369c11 100644 --- a/aries_cloudagent/multitenant/tests/test_manager.py +++ b/aries_cloudagent/multitenant/tests/test_manager.py @@ -19,7 +19,7 @@ async def setUp(self): async def test_get_wallet_profile_returns_from_cache(self): wallet_record = WalletRecord(wallet_id="test") - self.manager._instances["test"] = InMemoryProfile.test_profile() + self.manager._profiles.put("test", InMemoryProfile.test_profile()) with async_mock.patch( "aries_cloudagent.config.wallet.wallet_config" @@ -27,12 +27,12 @@ async def test_get_wallet_profile_returns_from_cache(self): profile = await self.manager.get_wallet_profile( self.profile.context, wallet_record ) - assert profile is self.manager._instances["test"] + assert profile is self.manager._profiles.get("test") wallet_config.assert_not_called() async def test_get_wallet_profile_not_in_cache(self): wallet_record = WalletRecord(wallet_id="test", settings={}) - self.manager._instances["test"] = InMemoryProfile.test_profile() + self.manager._profiles.put("test", InMemoryProfile.test_profile()) self.profile.context.update_settings( {"admin.webhook_urls": ["http://localhost:8020"]} ) @@ -43,7 +43,7 @@ async def test_get_wallet_profile_not_in_cache(self): profile = await self.manager.get_wallet_profile( self.profile.context, wallet_record ) - assert profile is self.manager._instances["test"] + assert profile is self.manager._profiles.get("test") wallet_config.assert_not_called() async def test_get_wallet_profile_settings(self): @@ -174,13 +174,46 @@ def side_effect(context, provision): assert profile.settings.get("mediation.default_id") == "24a96ef5" assert profile.settings.get("mediation.clear") == True + async def test_update_wallet_update_wallet_profile(self): + with async_mock.patch.object( + WalletRecord, "retrieve_by_id" + ) as retrieve_by_id, async_mock.patch.object( + WalletRecord, "save" + ) as wallet_record_save: + wallet_id = "test-wallet-id" + wallet_profile = InMemoryProfile.test_profile() + self.manager._profiles.put("test-wallet-id", wallet_profile) + retrieve_by_id.return_value = WalletRecord( + wallet_id=wallet_id, + settings={ + "wallet.webhook_urls": ["test-webhook-url"], + "wallet.dispatch_type": "both", + }, + ) + + new_settings = { + "wallet.webhook_urls": ["new-webhook-url"], + "wallet.dispatch_type": "default", + } + wallet_record = await self.manager.update_wallet(wallet_id, new_settings) + + wallet_record_save.assert_called_once() + + assert isinstance(wallet_record, WalletRecord) + assert wallet_record.wallet_webhook_urls == ["new-webhook-url"] + assert wallet_record.wallet_dispatch_type == "default" + assert wallet_profile.settings.get("wallet.webhook_urls") == [ + "new-webhook-url" + ] + assert wallet_profile.settings.get("wallet.dispatch_type") == "default" + async def test_remove_wallet_profile(self): test_profile = InMemoryProfile.test_profile( settings={"wallet.id": "test"}, ) - self.manager._instances["test"] = test_profile + self.manager._profiles.put("test", test_profile) with async_mock.patch.object(InMemoryProfile, "remove") as profile_remove: await self.manager.remove_wallet_profile(test_profile) - assert "test" not in self.manager._instances + assert not self.manager._profiles.has("test") profile_remove.assert_called_once_with()