diff --git a/aries_cloudagent/cache/base.py b/aries_cloudagent/cache/base.py index 3f4e4c8893..84abdf64d7 100644 --- a/aries_cloudagent/cache/base.py +++ b/aries_cloudagent/cache/base.py @@ -21,7 +21,7 @@ async def get(self, key: Text): """ @abstractmethod - async def set(self, key: Text, value: Any, ttl: int): + async def set(self, key: Text, value: Any, ttl: int = None): """ Add an item to the cache with an optional ttl. diff --git a/aries_cloudagent/conductor.py b/aries_cloudagent/conductor.py index 4fed1a5c7d..9ee132844b 100644 --- a/aries_cloudagent/conductor.py +++ b/aries_cloudagent/conductor.py @@ -185,10 +185,8 @@ async def start(self) -> None: _connection, invitation = await mgr.create_invitation( their_role=context.settings.get("debug.invite_role"), my_label=context.settings.get("debug.invite_label"), - multi_use=context.settings.get( - "debug.invite_multi_use", False - ), - public=context.settings.get("debug.invite_public", False) + multi_use=context.settings.get("debug.invite_multi_use", False), + public=context.settings.get("debug.invite_public", False), ) base_url = context.settings.get("invite_base_url") invite_url = invitation.to_url(base_url) @@ -288,47 +286,61 @@ async def inbound_message_router( complete.add_done_callback(lambda fut: socket.dispatch_complete()) return complete + async def get_connection_target( + self, connection_id: str, context: InjectionContext = None + ): + """Get a `ConnectionTarget` instance representing a connection. + + Args: + connection_id: The connection record identifier + context: An optional injection context + """ + + context = context or self.context + + try: + record = await ConnectionRecord.retrieve_by_id(context, connection_id) + except StorageNotFoundError as e: + raise MessagePrepareError( + "Could not locate connection record: {}".format(connection_id) + ) from e + mgr = ConnectionManager(context) + try: + target = await mgr.get_connection_target(record) + except ConnectionManagerError as e: + raise MessagePrepareError(str(e)) from e + if not target: + raise MessagePrepareError( + "No target found for connection: {}".format(connection_id) + ) + return target + async def prepare_outbound_message( - self, message: OutboundMessage, context: InjectionContext = None + self, + message: OutboundMessage, + context: InjectionContext = None, + direct_response: bool = False, ): """Prepare a response message for transmission. Args: message: An outbound message to be sent context: Optional request context + direct_response: Skip wrapping the response in forward messages """ context = context or self.context if message.connection_id and not message.target: - try: - record = await ConnectionRecord.retrieve_by_id( - context, message.connection_id - ) - except StorageNotFoundError as e: - raise MessagePrepareError( - "Could not locate connection record: {}".format( - message.connection_id - ) - ) from e - mgr = ConnectionManager(context) - try: - target = await mgr.get_connection_target(record) - except ConnectionManagerError as e: - raise MessagePrepareError(str(e)) from e - if not target: - raise MessagePrepareError( - "No connection target for message: {}".format(message.connection_id) - ) - message.target = target + message.target = await self.get_connection_target(message.connection_id) if not message.encoded and message.target: target = message.target message.payload = await self.message_serializer.encode_message( context, message.payload, - target.recipient_keys, - target.routing_keys, + target.recipient_keys or [], + (not direct_response) and target.routing_keys or [], target.sender_key, ) message.encoded = True @@ -343,11 +355,6 @@ async def outbound_message_router( message: An outbound message to be sent context: Optional request context """ - try: - await self.prepare_outbound_message(message, context) - except MessagePrepareError: - self.logger.exception("Error preparing outbound message for transmission") - return # try socket connections first, preferring the same socket ID socket_id = message.reply_socket_id @@ -364,12 +371,28 @@ async def outbound_message_router( sel_socket = socket break if sel_socket: + try: + await self.prepare_outbound_message(message, context, True) + except MessagePrepareError: + self.logger.exception( + "Error preparing outbound message for direct response" + ) + return + await sel_socket.send(message) self.logger.debug("Returned message to socket %s", sel_socket.socket_id) return # deliver directly to endpoint if message.endpoint: + try: + await self.prepare_outbound_message(message, context) + except MessagePrepareError: + self.logger.exception( + "Error preparing outbound message for transmission" + ) + return + await self.outbound_transport_manager.send_message(message) return diff --git a/aries_cloudagent/config/argparse.py b/aries_cloudagent/config/argparse.py index ac43f7af72..dd038a9b8c 100644 --- a/aries_cloudagent/config/argparse.py +++ b/aries_cloudagent/config/argparse.py @@ -18,11 +18,11 @@ class ArgumentGroup(abc.ABC): GROUP_NAME = None @abc.abstractmethod - def add_arguments(parser: ArgumentParser): + def add_arguments(self, parser: ArgumentParser): """Add arguments to the provided argument parser.""" @abc.abstractmethod - def get_settings(args: Namespace) -> dict: + def get_settings(self, args: Namespace) -> dict: """Extract settings from the parsed arguments.""" diff --git a/aries_cloudagent/messaging/actionmenu/messages/menu.py b/aries_cloudagent/messaging/actionmenu/messages/menu.py index 4252fea0e3..46543e7dcc 100644 --- a/aries_cloudagent/messaging/actionmenu/messages/menu.py +++ b/aries_cloudagent/messaging/actionmenu/messages/menu.py @@ -57,5 +57,5 @@ class Meta: title = fields.Str(required=False) description = fields.Str(required=False) - errormsg = description = fields.Str(required=False) + errormsg = fields.Str(required=False) options = fields.List(fields.Nested(MenuOptionSchema), required=True) diff --git a/aries_cloudagent/messaging/actionmenu/routes.py b/aries_cloudagent/messaging/actionmenu/routes.py index 018d591890..7dc7d32594 100644 --- a/aries_cloudagent/messaging/actionmenu/routes.py +++ b/aries_cloudagent/messaging/actionmenu/routes.py @@ -30,7 +30,7 @@ class MenuJsonSchema(Schema): title = fields.Str(required=False) description = fields.Str(required=False) - errormsg = description = fields.Str(required=False) + errormsg = fields.Str(required=False) options = fields.List(fields.Nested(MenuOptionSchema), required=True) diff --git a/aries_cloudagent/messaging/credentials/routes.py b/aries_cloudagent/messaging/credentials/routes.py index e5d1a1e7bf..5cfac7e294 100644 --- a/aries_cloudagent/messaging/credentials/routes.py +++ b/aries_cloudagent/messaging/credentials/routes.py @@ -519,7 +519,6 @@ async def credential_exchange_problem_report(request: web.BaseRequest): context = request.app["request_context"] outbound_handler = request.app["outbound_message_router"] - credential_exchange_id = request.match_info["id"] body = await request.json() try: diff --git a/aries_cloudagent/messaging/presentations/routes.py b/aries_cloudagent/messaging/presentations/routes.py index 6071fb23d4..747fb4d638 100644 --- a/aries_cloudagent/messaging/presentations/routes.py +++ b/aries_cloudagent/messaging/presentations/routes.py @@ -376,7 +376,6 @@ async def presentation_exchange_remove(request: web.BaseRequest): request: aiohttp request object """ context = request.app["request_context"] - presentation_exchange_id = request.match_info["id"] try: presentation_exchange_id = request.match_info["id"] presentation_exchange_record = await PresentationExchange.retrieve_by_id( diff --git a/aries_cloudagent/tests/test_conductor.py b/aries_cloudagent/tests/test_conductor.py index cdb10883c6..8fb27f6a3c 100644 --- a/aries_cloudagent/tests/test_conductor.py +++ b/aries_cloudagent/tests/test_conductor.py @@ -1,3 +1,5 @@ +import asyncio +from io import StringIO from unittest import mock, TestCase from asynctest import TestCase as AsyncTestCase from asynctest import mock as async_mock @@ -6,11 +8,16 @@ from ..admin.base_server import BaseAdminServer from ..config.base_context import ContextBuilder from ..config.injection_context import InjectionContext +from ..messaging.connections.manager import ConnectionManager +from ..messaging.connections.models.connection_record import ConnectionRecord from ..messaging.connections.models.connection_target import ConnectionTarget from ..messaging.message_delivery import MessageDelivery from ..messaging.serializer import MessageSerializer from ..messaging.outbound_message import OutboundMessage from ..messaging.protocol_registry import ProtocolRegistry +from ..stats import Collector +from ..storage.base import BaseStorage +from ..storage.basic import BasicStorage from ..transport.inbound.base import InboundTransportConfiguration from ..transport.outbound.queue.base import BaseOutboundMessageQueue from ..transport.outbound.queue.basic import BasicOutboundMessageQueue @@ -37,12 +44,20 @@ async def build(self) -> InjectionContext: context.injector.bind_instance( BaseOutboundMessageQueue, BasicOutboundMessageQueue() ) + context.injector.bind_instance(BaseStorage, BasicStorage()) context.injector.bind_instance(BaseWallet, BasicWallet()) context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry()) context.injector.bind_instance(MessageSerializer, self.message_serializer) return context +class StubCollectorContextBuilder(StubContextBuilder): + async def build(self) -> InjectionContext: + context = await super().build() + context.injector.bind_instance(Collector, Collector()) + return context + + class TestConductor(AsyncTestCase, Config): async def test_startup(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) @@ -86,7 +101,12 @@ async def test_inbound_message_handler(self): await conductor.setup() - with async_mock.patch.object(conductor.dispatcher, "dispatch") as mock_dispatch: + with async_mock.patch.object( + conductor.dispatcher, "dispatch", autospec=True + ) as mock_dispatch: + dispatch_result = """{"@type": "..."}""" + mock_dispatch.return_value.return_value = asyncio.Future() + mock_dispatch.return_value.return_value.set_result(None) delivery = MessageDelivery() parsed_msg = {} @@ -96,7 +116,8 @@ async def test_inbound_message_handler(self): message_body = "{}" transport = "http" - await conductor.inbound_message_router(message_body, transport) + complete = await conductor.inbound_message_router(message_body, transport) + asyncio.wait_for(complete, 1.0) mock_serializer.parse_message.assert_awaited_once_with( conductor.context, message_body, transport @@ -106,6 +127,48 @@ async def test_inbound_message_handler(self): parsed_msg, delivery, None, conductor.outbound_message_router ) + async def test_direct_response(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings) + conductor = test_module.Conductor(builder) + + await conductor.setup() + + single_response = asyncio.Future() + dispatch_result = """{"@type": "..."}""" + + async def mock_dispatch(parsed_msg, delivery, connection, outbound): + socket_id = delivery.socket_id + socket = conductor.sockets[socket_id] + socket.reply_mode = "all" + reply = OutboundMessage( + dispatch_result, + connection_id=None, + encoded=False, + endpoint=None, + reply_socket_id=socket_id, + ) + await outbound(reply) + result = asyncio.Future() + result.set_result(None) + return result + + with async_mock.patch.object(conductor.dispatcher, "dispatch", mock_dispatch): + + delivery = MessageDelivery() + parsed_msg = {} + mock_serializer = builder.message_serializer + mock_serializer.extract_message_type.return_value = "message_type" + mock_serializer.parse_message.return_value = (parsed_msg, delivery) + + message_body = "{}" + transport = "http" + complete = await conductor.inbound_message_router( + message_body, transport, None, single_response + ) + asyncio.wait_for(complete, 1.0) + + assert single_response.result() == dispatch_result + async def test_outbound_message_handler(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) conductor = test_module.Conductor(builder) @@ -137,6 +200,29 @@ async def test_outbound_message_handler(self): message ) + async def test_connection_target(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings) + conductor = test_module.Conductor(builder) + + await conductor.setup() + + test_target = ConnectionTarget( + endpoint="endpoint", recipient_keys=(), routing_keys=(), sender_key="" + ) + test_conn_id = "1" + + with async_mock.patch.object( + ConnectionRecord, "retrieve_by_id", autospec=True + ) as retrieve_by_id, async_mock.patch.object( + ConnectionManager, "get_connection_target", autospec=True + ) as get_target: + + get_target.return_value = test_target + + target = await conductor.get_connection_target(test_conn_id) + + assert target is test_target + async def test_admin(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) builder.update_settings({"admin.enabled": "1"}) @@ -156,3 +242,35 @@ async def test_admin(self): await conductor.stop() admin_stop.assert_awaited_once_with() + + async def test_setup_collector(self): + builder: ContextBuilder = StubCollectorContextBuilder(self.test_settings) + builder.update_settings(self.good_inbound_transports) + builder.update_settings(self.good_outbound_transports) + conductor = test_module.Conductor(builder) + + with async_mock.patch.object( + test_module, "InboundTransportManager", autospec=True + ) as mock_inbound_mgr, async_mock.patch.object( + test_module, "OutboundTransportManager", autospec=True + ) as mock_outbound_mgr, async_mock.patch.object( + test_module, "LoggingConfigurator", autospec=True + ) as mock_logger: + + await conductor.setup() + + async def test_print_invite(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings) + builder.update_settings( + {"debug.print_invitation": True, "invite_base_url": "http://localhost"} + ) + conductor = test_module.Conductor(builder) + + with mock.patch("sys.stdout", new=StringIO()) as captured: + await conductor.setup() + + await conductor.start() + + await conductor.stop() + + assert "http://localhost?c_i=" in captured.getvalue() diff --git a/aries_cloudagent/wallet/crypto.py b/aries_cloudagent/wallet/crypto.py index dd78135d7e..db9d32f3d9 100644 --- a/aries_cloudagent/wallet/crypto.py +++ b/aries_cloudagent/wallet/crypto.py @@ -487,7 +487,7 @@ def decode_pack_message( try: recips_outer = PackRecipientsSchema().loads(recips_json) except ValidationError: - ValueError("Invalid packed message recipients") + raise ValueError("Invalid packed message recipients") alg = recips_outer["alg"] is_authcrypt = alg == "Authcrypt"