Skip to content

Commit

Permalink
Merge pull request #199 from andrewwhitehead/unwrap-direct-response
Browse files Browse the repository at this point in the history
Don't wrap direct response messages with Forward wrapper(s)
  • Loading branch information
swcurran authored Sep 26, 2019
2 parents fd02a50 + 2a5be41 commit 5db1d89
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 42 deletions.
2 changes: 1 addition & 1 deletion aries_cloudagent/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def get(self, key: Text):
"""

@abstractmethod
async def set(self, key: Text, value: Any, ttl: int):
async def set(self, key: Text, value: Any, ttl: int = None):
"""
Add an item to the cache with an optional ttl.
Expand Down
87 changes: 55 additions & 32 deletions aries_cloudagent/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,8 @@ async def start(self) -> None:
_connection, invitation = await mgr.create_invitation(
their_role=context.settings.get("debug.invite_role"),
my_label=context.settings.get("debug.invite_label"),
multi_use=context.settings.get(
"debug.invite_multi_use", False
),
public=context.settings.get("debug.invite_public", False)
multi_use=context.settings.get("debug.invite_multi_use", False),
public=context.settings.get("debug.invite_public", False),
)
base_url = context.settings.get("invite_base_url")
invite_url = invitation.to_url(base_url)
Expand Down Expand Up @@ -288,47 +286,61 @@ async def inbound_message_router(
complete.add_done_callback(lambda fut: socket.dispatch_complete())
return complete

async def get_connection_target(
self, connection_id: str, context: InjectionContext = None
):
"""Get a `ConnectionTarget` instance representing a connection.
Args:
connection_id: The connection record identifier
context: An optional injection context
"""

context = context or self.context

try:
record = await ConnectionRecord.retrieve_by_id(context, connection_id)
except StorageNotFoundError as e:
raise MessagePrepareError(
"Could not locate connection record: {}".format(connection_id)
) from e
mgr = ConnectionManager(context)
try:
target = await mgr.get_connection_target(record)
except ConnectionManagerError as e:
raise MessagePrepareError(str(e)) from e
if not target:
raise MessagePrepareError(
"No target found for connection: {}".format(connection_id)
)
return target

async def prepare_outbound_message(
self, message: OutboundMessage, context: InjectionContext = None
self,
message: OutboundMessage,
context: InjectionContext = None,
direct_response: bool = False,
):
"""Prepare a response message for transmission.
Args:
message: An outbound message to be sent
context: Optional request context
direct_response: Skip wrapping the response in forward messages
"""

context = context or self.context

if message.connection_id and not message.target:
try:
record = await ConnectionRecord.retrieve_by_id(
context, message.connection_id
)
except StorageNotFoundError as e:
raise MessagePrepareError(
"Could not locate connection record: {}".format(
message.connection_id
)
) from e
mgr = ConnectionManager(context)
try:
target = await mgr.get_connection_target(record)
except ConnectionManagerError as e:
raise MessagePrepareError(str(e)) from e
if not target:
raise MessagePrepareError(
"No connection target for message: {}".format(message.connection_id)
)
message.target = target
message.target = await self.get_connection_target(message.connection_id)

if not message.encoded and message.target:
target = message.target
message.payload = await self.message_serializer.encode_message(
context,
message.payload,
target.recipient_keys,
target.routing_keys,
target.recipient_keys or [],
(not direct_response) and target.routing_keys or [],
target.sender_key,
)
message.encoded = True
Expand All @@ -343,11 +355,6 @@ async def outbound_message_router(
message: An outbound message to be sent
context: Optional request context
"""
try:
await self.prepare_outbound_message(message, context)
except MessagePrepareError:
self.logger.exception("Error preparing outbound message for transmission")
return

# try socket connections first, preferring the same socket ID
socket_id = message.reply_socket_id
Expand All @@ -364,12 +371,28 @@ async def outbound_message_router(
sel_socket = socket
break
if sel_socket:
try:
await self.prepare_outbound_message(message, context, True)
except MessagePrepareError:
self.logger.exception(
"Error preparing outbound message for direct response"
)
return

await sel_socket.send(message)
self.logger.debug("Returned message to socket %s", sel_socket.socket_id)
return

# deliver directly to endpoint
if message.endpoint:
try:
await self.prepare_outbound_message(message, context)
except MessagePrepareError:
self.logger.exception(
"Error preparing outbound message for transmission"
)
return

await self.outbound_transport_manager.send_message(message)
return

Expand Down
4 changes: 2 additions & 2 deletions aries_cloudagent/config/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ class ArgumentGroup(abc.ABC):
GROUP_NAME = None

@abc.abstractmethod
def add_arguments(parser: ArgumentParser):
def add_arguments(self, parser: ArgumentParser):
"""Add arguments to the provided argument parser."""

@abc.abstractmethod
def get_settings(args: Namespace) -> dict:
def get_settings(self, args: Namespace) -> dict:
"""Extract settings from the parsed arguments."""


Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/messaging/actionmenu/messages/menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ class Meta:

title = fields.Str(required=False)
description = fields.Str(required=False)
errormsg = description = fields.Str(required=False)
errormsg = fields.Str(required=False)
options = fields.List(fields.Nested(MenuOptionSchema), required=True)
2 changes: 1 addition & 1 deletion aries_cloudagent/messaging/actionmenu/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class MenuJsonSchema(Schema):

title = fields.Str(required=False)
description = fields.Str(required=False)
errormsg = description = fields.Str(required=False)
errormsg = fields.Str(required=False)
options = fields.List(fields.Nested(MenuOptionSchema), required=True)


Expand Down
1 change: 0 additions & 1 deletion aries_cloudagent/messaging/credentials/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,6 @@ async def credential_exchange_problem_report(request: web.BaseRequest):
context = request.app["request_context"]
outbound_handler = request.app["outbound_message_router"]

credential_exchange_id = request.match_info["id"]
body = await request.json()

try:
Expand Down
1 change: 0 additions & 1 deletion aries_cloudagent/messaging/presentations/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,6 @@ async def presentation_exchange_remove(request: web.BaseRequest):
request: aiohttp request object
"""
context = request.app["request_context"]
presentation_exchange_id = request.match_info["id"]
try:
presentation_exchange_id = request.match_info["id"]
presentation_exchange_record = await PresentationExchange.retrieve_by_id(
Expand Down
122 changes: 120 additions & 2 deletions aries_cloudagent/tests/test_conductor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
from io import StringIO
from unittest import mock, TestCase
from asynctest import TestCase as AsyncTestCase
from asynctest import mock as async_mock
Expand All @@ -6,11 +8,16 @@
from ..admin.base_server import BaseAdminServer
from ..config.base_context import ContextBuilder
from ..config.injection_context import InjectionContext
from ..messaging.connections.manager import ConnectionManager
from ..messaging.connections.models.connection_record import ConnectionRecord
from ..messaging.connections.models.connection_target import ConnectionTarget
from ..messaging.message_delivery import MessageDelivery
from ..messaging.serializer import MessageSerializer
from ..messaging.outbound_message import OutboundMessage
from ..messaging.protocol_registry import ProtocolRegistry
from ..stats import Collector
from ..storage.base import BaseStorage
from ..storage.basic import BasicStorage
from ..transport.inbound.base import InboundTransportConfiguration
from ..transport.outbound.queue.base import BaseOutboundMessageQueue
from ..transport.outbound.queue.basic import BasicOutboundMessageQueue
Expand All @@ -37,12 +44,20 @@ async def build(self) -> InjectionContext:
context.injector.bind_instance(
BaseOutboundMessageQueue, BasicOutboundMessageQueue()
)
context.injector.bind_instance(BaseStorage, BasicStorage())
context.injector.bind_instance(BaseWallet, BasicWallet())
context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
context.injector.bind_instance(MessageSerializer, self.message_serializer)
return context


class StubCollectorContextBuilder(StubContextBuilder):
async def build(self) -> InjectionContext:
context = await super().build()
context.injector.bind_instance(Collector, Collector())
return context


class TestConductor(AsyncTestCase, Config):
async def test_startup(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
Expand Down Expand Up @@ -86,7 +101,12 @@ async def test_inbound_message_handler(self):

await conductor.setup()

with async_mock.patch.object(conductor.dispatcher, "dispatch") as mock_dispatch:
with async_mock.patch.object(
conductor.dispatcher, "dispatch", autospec=True
) as mock_dispatch:
dispatch_result = """{"@type": "..."}"""
mock_dispatch.return_value.return_value = asyncio.Future()
mock_dispatch.return_value.return_value.set_result(None)

delivery = MessageDelivery()
parsed_msg = {}
Expand All @@ -96,7 +116,8 @@ async def test_inbound_message_handler(self):

message_body = "{}"
transport = "http"
await conductor.inbound_message_router(message_body, transport)
complete = await conductor.inbound_message_router(message_body, transport)
asyncio.wait_for(complete, 1.0)

mock_serializer.parse_message.assert_awaited_once_with(
conductor.context, message_body, transport
Expand All @@ -106,6 +127,48 @@ async def test_inbound_message_handler(self):
parsed_msg, delivery, None, conductor.outbound_message_router
)

async def test_direct_response(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
conductor = test_module.Conductor(builder)

await conductor.setup()

single_response = asyncio.Future()
dispatch_result = """{"@type": "..."}"""

async def mock_dispatch(parsed_msg, delivery, connection, outbound):
socket_id = delivery.socket_id
socket = conductor.sockets[socket_id]
socket.reply_mode = "all"
reply = OutboundMessage(
dispatch_result,
connection_id=None,
encoded=False,
endpoint=None,
reply_socket_id=socket_id,
)
await outbound(reply)
result = asyncio.Future()
result.set_result(None)
return result

with async_mock.patch.object(conductor.dispatcher, "dispatch", mock_dispatch):

delivery = MessageDelivery()
parsed_msg = {}
mock_serializer = builder.message_serializer
mock_serializer.extract_message_type.return_value = "message_type"
mock_serializer.parse_message.return_value = (parsed_msg, delivery)

message_body = "{}"
transport = "http"
complete = await conductor.inbound_message_router(
message_body, transport, None, single_response
)
asyncio.wait_for(complete, 1.0)

assert single_response.result() == dispatch_result

async def test_outbound_message_handler(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
conductor = test_module.Conductor(builder)
Expand Down Expand Up @@ -137,6 +200,29 @@ async def test_outbound_message_handler(self):
message
)

async def test_connection_target(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
conductor = test_module.Conductor(builder)

await conductor.setup()

test_target = ConnectionTarget(
endpoint="endpoint", recipient_keys=(), routing_keys=(), sender_key=""
)
test_conn_id = "1"

with async_mock.patch.object(
ConnectionRecord, "retrieve_by_id", autospec=True
) as retrieve_by_id, async_mock.patch.object(
ConnectionManager, "get_connection_target", autospec=True
) as get_target:

get_target.return_value = test_target

target = await conductor.get_connection_target(test_conn_id)

assert target is test_target

async def test_admin(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
builder.update_settings({"admin.enabled": "1"})
Expand All @@ -156,3 +242,35 @@ async def test_admin(self):

await conductor.stop()
admin_stop.assert_awaited_once_with()

async def test_setup_collector(self):
builder: ContextBuilder = StubCollectorContextBuilder(self.test_settings)
builder.update_settings(self.good_inbound_transports)
builder.update_settings(self.good_outbound_transports)
conductor = test_module.Conductor(builder)

with async_mock.patch.object(
test_module, "InboundTransportManager", autospec=True
) as mock_inbound_mgr, async_mock.patch.object(
test_module, "OutboundTransportManager", autospec=True
) as mock_outbound_mgr, async_mock.patch.object(
test_module, "LoggingConfigurator", autospec=True
) as mock_logger:

await conductor.setup()

async def test_print_invite(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
builder.update_settings(
{"debug.print_invitation": True, "invite_base_url": "http://localhost"}
)
conductor = test_module.Conductor(builder)

with mock.patch("sys.stdout", new=StringIO()) as captured:
await conductor.setup()

await conductor.start()

await conductor.stop()

assert "http://localhost?c_i=" in captured.getvalue()
2 changes: 1 addition & 1 deletion aries_cloudagent/wallet/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def decode_pack_message(
try:
recips_outer = PackRecipientsSchema().loads(recips_json)
except ValidationError:
ValueError("Invalid packed message recipients")
raise ValueError("Invalid packed message recipients")

alg = recips_outer["alg"]
is_authcrypt = alg == "Authcrypt"
Expand Down

0 comments on commit 5db1d89

Please sign in to comment.