Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't wrap direct response messages with Forward wrapper(s) #199

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aries_cloudagent/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def get(self, key: Text):
"""

@abstractmethod
async def set(self, key: Text, value: Any, ttl: int):
async def set(self, key: Text, value: Any, ttl: int = None):
"""
Add an item to the cache with an optional ttl.

Expand Down
87 changes: 55 additions & 32 deletions aries_cloudagent/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,8 @@ async def start(self) -> None:
_connection, invitation = await mgr.create_invitation(
their_role=context.settings.get("debug.invite_role"),
my_label=context.settings.get("debug.invite_label"),
multi_use=context.settings.get(
"debug.invite_multi_use", False
),
public=context.settings.get("debug.invite_public", False)
multi_use=context.settings.get("debug.invite_multi_use", False),
public=context.settings.get("debug.invite_public", False),
)
base_url = context.settings.get("invite_base_url")
invite_url = invitation.to_url(base_url)
Expand Down Expand Up @@ -288,47 +286,61 @@ async def inbound_message_router(
complete.add_done_callback(lambda fut: socket.dispatch_complete())
return complete

async def get_connection_target(
self, connection_id: str, context: InjectionContext = None
):
"""Get a `ConnectionTarget` instance representing a connection.

Args:
connection_id: The connection record identifier
context: An optional injection context
"""

context = context or self.context

try:
record = await ConnectionRecord.retrieve_by_id(context, connection_id)
except StorageNotFoundError as e:
raise MessagePrepareError(
"Could not locate connection record: {}".format(connection_id)
) from e
mgr = ConnectionManager(context)
try:
target = await mgr.get_connection_target(record)
except ConnectionManagerError as e:
raise MessagePrepareError(str(e)) from e
if not target:
raise MessagePrepareError(
"No target found for connection: {}".format(connection_id)
)
return target

async def prepare_outbound_message(
self, message: OutboundMessage, context: InjectionContext = None
self,
message: OutboundMessage,
context: InjectionContext = None,
direct_response: bool = False,
):
"""Prepare a response message for transmission.

Args:
message: An outbound message to be sent
context: Optional request context
direct_response: Skip wrapping the response in forward messages
"""

context = context or self.context

if message.connection_id and not message.target:
try:
record = await ConnectionRecord.retrieve_by_id(
context, message.connection_id
)
except StorageNotFoundError as e:
raise MessagePrepareError(
"Could not locate connection record: {}".format(
message.connection_id
)
) from e
mgr = ConnectionManager(context)
try:
target = await mgr.get_connection_target(record)
except ConnectionManagerError as e:
raise MessagePrepareError(str(e)) from e
if not target:
raise MessagePrepareError(
"No connection target for message: {}".format(message.connection_id)
)
message.target = target
message.target = await self.get_connection_target(message.connection_id)

if not message.encoded and message.target:
target = message.target
message.payload = await self.message_serializer.encode_message(
context,
message.payload,
target.recipient_keys,
target.routing_keys,
target.recipient_keys or [],
(not direct_response) and target.routing_keys or [],
target.sender_key,
)
message.encoded = True
Expand All @@ -343,11 +355,6 @@ async def outbound_message_router(
message: An outbound message to be sent
context: Optional request context
"""
try:
await self.prepare_outbound_message(message, context)
except MessagePrepareError:
self.logger.exception("Error preparing outbound message for transmission")
return

# try socket connections first, preferring the same socket ID
socket_id = message.reply_socket_id
Expand All @@ -364,12 +371,28 @@ async def outbound_message_router(
sel_socket = socket
break
if sel_socket:
try:
await self.prepare_outbound_message(message, context, True)
except MessagePrepareError:
self.logger.exception(
"Error preparing outbound message for direct response"
)
return

await sel_socket.send(message)
self.logger.debug("Returned message to socket %s", sel_socket.socket_id)
return

# deliver directly to endpoint
if message.endpoint:
try:
await self.prepare_outbound_message(message, context)
except MessagePrepareError:
self.logger.exception(
"Error preparing outbound message for transmission"
)
return

await self.outbound_transport_manager.send_message(message)
return

Expand Down
4 changes: 2 additions & 2 deletions aries_cloudagent/config/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ class ArgumentGroup(abc.ABC):
GROUP_NAME = None

@abc.abstractmethod
def add_arguments(parser: ArgumentParser):
def add_arguments(self, parser: ArgumentParser):
"""Add arguments to the provided argument parser."""

@abc.abstractmethod
def get_settings(args: Namespace) -> dict:
def get_settings(self, args: Namespace) -> dict:
"""Extract settings from the parsed arguments."""


Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/messaging/actionmenu/messages/menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ class Meta:

title = fields.Str(required=False)
description = fields.Str(required=False)
errormsg = description = fields.Str(required=False)
errormsg = fields.Str(required=False)
options = fields.List(fields.Nested(MenuOptionSchema), required=True)
2 changes: 1 addition & 1 deletion aries_cloudagent/messaging/actionmenu/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class MenuJsonSchema(Schema):

title = fields.Str(required=False)
description = fields.Str(required=False)
errormsg = description = fields.Str(required=False)
errormsg = fields.Str(required=False)
options = fields.List(fields.Nested(MenuOptionSchema), required=True)


Expand Down
1 change: 0 additions & 1 deletion aries_cloudagent/messaging/credentials/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,6 @@ async def credential_exchange_problem_report(request: web.BaseRequest):
context = request.app["request_context"]
outbound_handler = request.app["outbound_message_router"]

credential_exchange_id = request.match_info["id"]
body = await request.json()

try:
Expand Down
1 change: 0 additions & 1 deletion aries_cloudagent/messaging/presentations/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,6 @@ async def presentation_exchange_remove(request: web.BaseRequest):
request: aiohttp request object
"""
context = request.app["request_context"]
presentation_exchange_id = request.match_info["id"]
try:
presentation_exchange_id = request.match_info["id"]
presentation_exchange_record = await PresentationExchange.retrieve_by_id(
Expand Down
122 changes: 120 additions & 2 deletions aries_cloudagent/tests/test_conductor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
from io import StringIO
from unittest import mock, TestCase
from asynctest import TestCase as AsyncTestCase
from asynctest import mock as async_mock
Expand All @@ -6,11 +8,16 @@
from ..admin.base_server import BaseAdminServer
from ..config.base_context import ContextBuilder
from ..config.injection_context import InjectionContext
from ..messaging.connections.manager import ConnectionManager
from ..messaging.connections.models.connection_record import ConnectionRecord
from ..messaging.connections.models.connection_target import ConnectionTarget
from ..messaging.message_delivery import MessageDelivery
from ..messaging.serializer import MessageSerializer
from ..messaging.outbound_message import OutboundMessage
from ..messaging.protocol_registry import ProtocolRegistry
from ..stats import Collector
from ..storage.base import BaseStorage
from ..storage.basic import BasicStorage
from ..transport.inbound.base import InboundTransportConfiguration
from ..transport.outbound.queue.base import BaseOutboundMessageQueue
from ..transport.outbound.queue.basic import BasicOutboundMessageQueue
Expand All @@ -37,12 +44,20 @@ async def build(self) -> InjectionContext:
context.injector.bind_instance(
BaseOutboundMessageQueue, BasicOutboundMessageQueue()
)
context.injector.bind_instance(BaseStorage, BasicStorage())
context.injector.bind_instance(BaseWallet, BasicWallet())
context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
context.injector.bind_instance(MessageSerializer, self.message_serializer)
return context


class StubCollectorContextBuilder(StubContextBuilder):
async def build(self) -> InjectionContext:
context = await super().build()
context.injector.bind_instance(Collector, Collector())
return context


class TestConductor(AsyncTestCase, Config):
async def test_startup(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
Expand Down Expand Up @@ -86,7 +101,12 @@ async def test_inbound_message_handler(self):

await conductor.setup()

with async_mock.patch.object(conductor.dispatcher, "dispatch") as mock_dispatch:
with async_mock.patch.object(
conductor.dispatcher, "dispatch", autospec=True
) as mock_dispatch:
dispatch_result = """{"@type": "..."}"""
mock_dispatch.return_value.return_value = asyncio.Future()
mock_dispatch.return_value.return_value.set_result(None)

delivery = MessageDelivery()
parsed_msg = {}
Expand All @@ -96,7 +116,8 @@ async def test_inbound_message_handler(self):

message_body = "{}"
transport = "http"
await conductor.inbound_message_router(message_body, transport)
complete = await conductor.inbound_message_router(message_body, transport)
asyncio.wait_for(complete, 1.0)

mock_serializer.parse_message.assert_awaited_once_with(
conductor.context, message_body, transport
Expand All @@ -106,6 +127,48 @@ async def test_inbound_message_handler(self):
parsed_msg, delivery, None, conductor.outbound_message_router
)

async def test_direct_response(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
conductor = test_module.Conductor(builder)

await conductor.setup()

single_response = asyncio.Future()
dispatch_result = """{"@type": "..."}"""

async def mock_dispatch(parsed_msg, delivery, connection, outbound):
socket_id = delivery.socket_id
socket = conductor.sockets[socket_id]
socket.reply_mode = "all"
reply = OutboundMessage(
dispatch_result,
connection_id=None,
encoded=False,
endpoint=None,
reply_socket_id=socket_id,
)
await outbound(reply)
result = asyncio.Future()
result.set_result(None)
return result

with async_mock.patch.object(conductor.dispatcher, "dispatch", mock_dispatch):

delivery = MessageDelivery()
parsed_msg = {}
mock_serializer = builder.message_serializer
mock_serializer.extract_message_type.return_value = "message_type"
mock_serializer.parse_message.return_value = (parsed_msg, delivery)

message_body = "{}"
transport = "http"
complete = await conductor.inbound_message_router(
message_body, transport, None, single_response
)
asyncio.wait_for(complete, 1.0)

assert single_response.result() == dispatch_result

async def test_outbound_message_handler(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
conductor = test_module.Conductor(builder)
Expand Down Expand Up @@ -137,6 +200,29 @@ async def test_outbound_message_handler(self):
message
)

async def test_connection_target(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
conductor = test_module.Conductor(builder)

await conductor.setup()

test_target = ConnectionTarget(
endpoint="endpoint", recipient_keys=(), routing_keys=(), sender_key=""
)
test_conn_id = "1"

with async_mock.patch.object(
ConnectionRecord, "retrieve_by_id", autospec=True
) as retrieve_by_id, async_mock.patch.object(
ConnectionManager, "get_connection_target", autospec=True
) as get_target:

get_target.return_value = test_target

target = await conductor.get_connection_target(test_conn_id)

assert target is test_target

async def test_admin(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
builder.update_settings({"admin.enabled": "1"})
Expand All @@ -156,3 +242,35 @@ async def test_admin(self):

await conductor.stop()
admin_stop.assert_awaited_once_with()

async def test_setup_collector(self):
builder: ContextBuilder = StubCollectorContextBuilder(self.test_settings)
builder.update_settings(self.good_inbound_transports)
builder.update_settings(self.good_outbound_transports)
conductor = test_module.Conductor(builder)

with async_mock.patch.object(
test_module, "InboundTransportManager", autospec=True
) as mock_inbound_mgr, async_mock.patch.object(
test_module, "OutboundTransportManager", autospec=True
) as mock_outbound_mgr, async_mock.patch.object(
test_module, "LoggingConfigurator", autospec=True
) as mock_logger:

await conductor.setup()

async def test_print_invite(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
builder.update_settings(
{"debug.print_invitation": True, "invite_base_url": "http://localhost"}
)
conductor = test_module.Conductor(builder)

with mock.patch("sys.stdout", new=StringIO()) as captured:
await conductor.setup()

await conductor.start()

await conductor.stop()

assert "http://localhost?c_i=" in captured.getvalue()
2 changes: 1 addition & 1 deletion aries_cloudagent/wallet/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def decode_pack_message(
try:
recips_outer = PackRecipientsSchema().loads(recips_json)
except ValidationError:
ValueError("Invalid packed message recipients")
raise ValueError("Invalid packed message recipients")

alg = recips_outer["alg"]
is_authcrypt = alg == "Authcrypt"
Expand Down