diff --git a/aries_cloudagent/core/dispatcher.py b/aries_cloudagent/core/dispatcher.py index e3f37b45ac..2193dd20b9 100644 --- a/aries_cloudagent/core/dispatcher.py +++ b/aries_cloudagent/core/dispatcher.py @@ -10,7 +10,7 @@ import os import warnings -from typing import Callable, Coroutine, Union +from typing import Callable, Coroutine, Optional, Union, Tuple import weakref from aiohttp.web import HTTPException @@ -36,6 +36,13 @@ from .error import ProtocolMinorVersionNotSupported from .protocol_registry import ProtocolRegistry +from .util import ( + get_version_from_message_type, + validate_get_response_version, + # WARNING_DEGRADED_FEATURES, + # WARNING_VERSION_MISMATCH, + # WARNING_VERSION_NOT_SUPPORTED, +) LOGGER = logging.getLogger(__name__) @@ -133,6 +140,9 @@ async def handle_message( inbound_message: The inbound message instance send_outbound: Async function to send outbound messages + # Raises: + # MessageParseError: If the message type version is not supported + Returns: The response from the handler @@ -140,9 +150,12 @@ async def handle_message( r_time = get_timer() error_result = None + version_warning = None message = None try: - message = await self.make_message(inbound_message.payload) + (message, warning) = await self.make_message( + profile, inbound_message.payload + ) except ProblemReportParseError: pass # avoid problem report recursion except MessageParseError as e: @@ -155,6 +168,47 @@ async def handle_message( ) if inbound_message.receipt.thread_id: error_result.assign_thread_id(inbound_message.receipt.thread_id) + # if warning: + # warning_message_type = inbound_message.payload.get("@type") + # if warning == WARNING_DEGRADED_FEATURES: + # LOGGER.error( + # f"Sending {WARNING_DEGRADED_FEATURES} problem report, " + # "message type received with a minor version at or higher" + # " than protocol minimum supported and current minor version " + # f"for message_type {warning_message_type}" + # ) + # version_warning = ProblemReport( + # description={ + # "en": ( + # "message type received with a minor version at or " + # "higher than protocol minimum supported and current" + # f" minor version for message_type {warning_message_type}" + # ), + # "code": WARNING_DEGRADED_FEATURES, + # } + # ) + # elif warning == WARNING_VERSION_MISMATCH: + # LOGGER.error( + # f"Sending {WARNING_VERSION_MISMATCH} problem report, message " + # "type received with a minor version higher than current minor " + # f"version for message_type {warning_message_type}" + # ) + # version_warning = ProblemReport( + # description={ + # "en": ( + # "message type received with a minor version higher" + # " than current minor version for message_type" + # f" {warning_message_type}" + # ), + # "code": WARNING_VERSION_MISMATCH, + # } + # ) + # elif warning == WARNING_VERSION_NOT_SUPPORTED: + # raise MessageParseError( + # f"Message type version not supported for {warning_message_type}" + # ) + # if version_warning and inbound_message.receipt.thread_id: + # version_warning.assign_thread_id(inbound_message.receipt.thread_id) trace_event( self.profile.settings, @@ -199,6 +253,8 @@ async def handle_message( if error_result: await responder.send_reply(error_result) + elif version_warning: + await responder.send_reply(version_warning) elif context.message: context.injector.bind_instance(BaseResponder, responder) @@ -215,7 +271,9 @@ async def handle_message( perf_counter=r_time, ) - async def make_message(self, parsed_msg: dict) -> BaseMessage: + async def make_message( + self, profile: Profile, parsed_msg: dict + ) -> Tuple[BaseMessage, Optional[str]]: """ Deserialize a message dict into the appropriate message instance. @@ -224,6 +282,7 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage: Args: parsed_msg: The parsed message + profile: Profile Returns: An instance of the corresponding message class for this message @@ -237,6 +296,7 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage: if not isinstance(parsed_msg, dict): raise MessageParseError("Expected a JSON object") message_type = parsed_msg.get("@type") + message_type_rec_version = get_version_from_message_type(message_type) if not message_type: raise MessageParseError("Message does not contain '@type' parameter") @@ -256,8 +316,10 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage: if "/problem-report" in message_type: raise ProblemReportParseError("Error parsing problem report message") raise MessageParseError(f"Error deserializing message: {e}") from e - - return instance + _, warning = await validate_get_response_version( + profile, message_type_rec_version, message_cls + ) + return (instance, warning) async def complete(self, timeout: float = 0.1): """Wait for pending tasks to complete.""" diff --git a/aries_cloudagent/core/protocol_registry.py b/aries_cloudagent/core/protocol_registry.py index 805c35efa7..cd18a814dd 100644 --- a/aries_cloudagent/core/protocol_registry.py +++ b/aries_cloudagent/core/protocol_registry.py @@ -2,6 +2,7 @@ import logging +from string import Template from typing import Mapping, Sequence from ..config.injection_context import InjectionContext @@ -74,6 +75,73 @@ def parse_type_string(self, message_type): "minor_version": int(version_string_tokens[1]), } + def create_msg_types_for_minor_version(self, typesets, version_definition): + """ + Return mapping of message type to module path for minor versions. + + Args: + typesets: Mappings of message types to register + version_definition: Optional version definition dict + + Returns: + Typesets mapping + + """ + updated_typeset = {} + curr_minor_version = version_definition["current_minor_version"] + min_minor_version = version_definition["minimum_minor_version"] + major_version = version_definition["major_version"] + if curr_minor_version >= min_minor_version and curr_minor_version >= 1: + for version_index in range(min_minor_version, curr_minor_version + 1): + to_check = f"{str(major_version)}.{str(version_index)}" + updated_typeset.update( + self._get_updated_tyoeset_dict(typesets, to_check, updated_typeset) + ) + return (updated_typeset,) + + def _get_updated_tyoeset_dict(self, typesets, to_check, updated_typeset) -> dict: + for typeset in typesets: + for msg_type_string, module_path in typeset.items(): + updated_msg_type_string = Template(msg_type_string).substitute( + version=to_check + ) + updated_typeset[updated_msg_type_string] = module_path + return updated_typeset + + def _template_message_type_check(self, typeset) -> bool: + for msg_type_string, _ in typeset.items(): + if "$version" in msg_type_string: + return True + return False + + def _create_and_register_updated_typesets(self, typesets, version_definition): + updated_typesets = self.create_msg_types_for_minor_version( + typesets, version_definition + ) + update_flag = False + for typeset in updated_typesets: + if typeset: + self._typemap.update(typeset) + update_flag = True + if update_flag: + return updated_typesets + else: + return None + + def _update_version_map(self, message_type_string, module_path, version_definition): + parsed_type_string = self.parse_type_string(message_type_string) + + if version_definition["major_version"] not in self._versionmap: + self._versionmap[version_definition["major_version"]] = [] + + self._versionmap[version_definition["major_version"]].append( + { + "parsed_type_string": parsed_type_string, + "version_definition": version_definition, + "message_module": module_path, + } + ) + def register_message_types(self, *typesets, version_definition=None): """ Add new supported message types. @@ -85,24 +153,26 @@ def register_message_types(self, *typesets, version_definition=None): """ # Maintain support for versionless protocol modules + template_msg_type_version = True + updated_typesets = None for typeset in typesets: - self._typemap.update(typeset) + if not self._template_message_type_check(typeset): + self._typemap.update(typeset) + template_msg_type_version = False # Track versioned modules for version routing if version_definition: + # create updated typesets for minor versions and register them + if template_msg_type_version: + updated_typesets = self._create_and_register_updated_typesets( + typesets, version_definition + ) + if updated_typesets: + typesets = updated_typesets for typeset in typesets: for message_type_string, module_path in typeset.items(): - parsed_type_string = self.parse_type_string(message_type_string) - - if version_definition["major_version"] not in self._versionmap: - self._versionmap[version_definition["major_version"]] = [] - - self._versionmap[version_definition["major_version"]].append( - { - "parsed_type_string": parsed_type_string, - "version_definition": version_definition, - "message_module": module_path, - } + self._update_version_map( + message_type_string, module_path, version_definition ) def register_controllers(self, *controller_sets, version_definition=None): diff --git a/aries_cloudagent/core/tests/test_dispatcher.py b/aries_cloudagent/core/tests/test_dispatcher.py index 4722ad9530..ffca07ce14 100644 --- a/aries_cloudagent/core/tests/test_dispatcher.py +++ b/aries_cloudagent/core/tests/test_dispatcher.py @@ -111,7 +111,15 @@ async def test_dispatch(self): StubAgentMessageHandler, "handle", autospec=True ) as handler_mock, async_mock.patch.object( test_module, "ConnectionManager", autospec=True - ) as conn_mgr_mock: + ) as conn_mgr_mock, async_mock.patch.object( + test_module, + "get_version_from_message_type", + async_mock.AsyncMock(return_value="1.1"), + ), async_mock.patch.object( + test_module, + "validate_get_response_version", + async_mock.AsyncMock(return_value=("1.1", None)), + ): conn_mgr_mock.return_value = async_mock.MagicMock( find_inbound_connection=async_mock.AsyncMock( return_value=async_mock.MagicMock(connection_id="dummy") @@ -152,7 +160,15 @@ async def test_dispatch_versioned_message(self): with async_mock.patch.object( StubAgentMessageHandler, "handle", autospec=True - ) as handler_mock: + ) as handler_mock, async_mock.patch.object( + test_module, + "get_version_from_message_type", + async_mock.AsyncMock(return_value="1.1"), + ), async_mock.patch.object( + test_module, + "validate_get_response_version", + async_mock.AsyncMock(return_value=("1.1", None)), + ): await dispatcher.queue_message( dispatcher.profile, make_inbound(message), rcv.send ) @@ -265,7 +281,15 @@ async def test_dispatch_versioned_message_handle_greater_succeeds(self): with async_mock.patch.object( StubAgentMessageHandler, "handle", autospec=True - ) as handler_mock: + ) as handler_mock, async_mock.patch.object( + test_module, + "get_version_from_message_type", + async_mock.AsyncMock(return_value="1.1"), + ), async_mock.patch.object( + test_module, + "validate_get_response_version", + async_mock.AsyncMock(return_value=("1.1", None)), + ): await dispatcher.queue_message( dispatcher.profile, make_inbound(message), rcv.send ) @@ -317,17 +341,22 @@ async def test_bad_message_dispatch_parse_x(self): await dispatcher.setup() rcv = Receiver() bad_messages = ["not even a dict", {"bad": "message"}] - for bad in bad_messages: - await dispatcher.queue_message( - dispatcher.profile, make_inbound(bad), rcv.send - ) - await dispatcher.task_queue - assert rcv.messages and isinstance(rcv.messages[0][1], OutboundMessage) - payload = json.loads(rcv.messages[0][1].payload) - assert payload["@type"] == DIDCommPrefix.qualify_current( - ProblemReport.Meta.message_type - ) - rcv.messages.clear() + with async_mock.patch.object( + test_module, "get_version_from_message_type", async_mock.AsyncMock() + ), async_mock.patch.object( + test_module, "validate_get_response_version", async_mock.AsyncMock() + ): + for bad in bad_messages: + await dispatcher.queue_message( + dispatcher.profile, make_inbound(bad), rcv.send + ) + await dispatcher.task_queue + assert rcv.messages and isinstance(rcv.messages[0][1], OutboundMessage) + payload = json.loads(rcv.messages[0][1].payload) + assert payload["@type"] == DIDCommPrefix.qualify_current( + ProblemReport.Meta.message_type + ) + rcv.messages.clear() async def test_bad_message_dispatch_problem_report_x(self): profile = make_profile() @@ -425,3 +454,91 @@ def _smaller_scope(): with self.assertRaises(RuntimeError): await responder.send_webhook("test", {}) + + # async def test_dispatch_version_with_degraded_features(self): + # profile = make_profile() + # registry = profile.inject(ProtocolRegistry) + # registry.register_message_types( + # { + # pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage + # for pfx in DIDCommPrefix + # } + # ) + # dispatcher = test_module.Dispatcher(profile) + # await dispatcher.setup() + # rcv = Receiver() + # message = { + # "@type": DIDCommPrefix.qualify_current(StubAgentMessage.Meta.message_type) + # } + + # with async_mock.patch.object( + # test_module, + # "get_version_from_message_type", + # async_mock.AsyncMock(return_value="1.1"), + # ), async_mock.patch.object( + # test_module, + # "validate_get_response_version", + # async_mock.AsyncMock(return_value=("1.1", "fields-ignored-due-to-version-mismatch")), + # ): + # await dispatcher.queue_message( + # dispatcher.profile, make_inbound(message), rcv.send + # ) + + # async def test_dispatch_fields_ignored_due_to_version_mismatch(self): + # profile = make_profile() + # registry = profile.inject(ProtocolRegistry) + # registry.register_message_types( + # { + # pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage + # for pfx in DIDCommPrefix + # } + # ) + # dispatcher = test_module.Dispatcher(profile) + # await dispatcher.setup() + # rcv = Receiver() + # message = { + # "@type": DIDCommPrefix.qualify_current(StubAgentMessage.Meta.message_type) + # } + + # with async_mock.patch.object( + # test_module, + # "get_version_from_message_type", + # async_mock.AsyncMock(return_value="1.1"), + # ), async_mock.patch.object( + # test_module, + # "validate_get_response_version", + # async_mock.AsyncMock(return_value=("1.1", "version-with-degraded-features")), + # ): + # await dispatcher.queue_message( + # dispatcher.profile, make_inbound(message), rcv.send + # ) + + # async def test_dispatch_version_not_supported(self): + # profile = make_profile() + # registry = profile.inject(ProtocolRegistry) + # registry.register_message_types( + # { + # pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage + # for pfx in DIDCommPrefix + # } + # ) + # dispatcher = test_module.Dispatcher(profile) + # await dispatcher.setup() + # rcv = Receiver() + # message = { + # "@type": DIDCommPrefix.qualify_current(StubAgentMessage.Meta.message_type) + # } + + # with async_mock.patch.object( + # test_module, + # "get_version_from_message_type", + # async_mock.AsyncMock(return_value="1.1"), + # ), async_mock.patch.object( + # test_module, + # "validate_get_response_version", + # async_mock.AsyncMock(return_value=("1.1", "version-not-supported")), + # ): + # with self.assertRaises(test_module.MessageParseError): + # await dispatcher.queue_message( + # dispatcher.profile, make_inbound(message), rcv.send + # ) diff --git a/aries_cloudagent/core/tests/test_protocol_registry.py b/aries_cloudagent/core/tests/test_protocol_registry.py index 5c43668d8b..15d99cbc36 100644 --- a/aries_cloudagent/core/tests/test_protocol_registry.py +++ b/aries_cloudagent/core/tests/test_protocol_registry.py @@ -44,6 +44,60 @@ def test_message_type_query(self): matches = self.registry.protocols_matching_query(q) assert matches == () + def test_create_msg_types_for_minor_version(self): + test_typesets = ( + { + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/$version/fake-forward-invitation": "aries_cloudagent.protocols.introduction.v0_1.messages.forward_invitation.ForwardInvitation", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/$version/fake-invitation": "aries_cloudagent.protocols.introduction.v0_1.messages.invitation.Invitation", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/$version/fake-invitation-request": "aries_cloudagent.protocols.introduction.v0_1.messages.invitation_request.InvitationRequest", + "https://didcom.org/introduction-service/$version/fake-forward-invitation": "aries_cloudagent.protocols.introduction.v0_1.messages.forward_invitation.ForwardInvitation", + "https://didcom.org/introduction-service/$version/fake-invitation": "aries_cloudagent.protocols.introduction.v0_1.messages.invitation.Invitation", + "https://didcom.org/introduction-service/$version/fake-invitation-request": "aries_cloudagent.protocols.introduction.v0_1.messages.invitation_request.InvitationRequest", + }, + ) + test_version_def = { + "current_minor_version": 1, + "major_version": 1, + "minimum_minor_version": 0, + "path": "v0_1", + } + updated_typesets = self.registry.create_msg_types_for_minor_version( + test_typesets, test_version_def + ) + updated_typeset = updated_typesets[0] + assert ( + "https://didcom.org/introduction-service/1.0/fake-forward-invitation" + in updated_typeset + ) + assert ( + "https://didcom.org/introduction-service/1.0/fake-invitation" + in updated_typeset + ) + assert ( + "https://didcom.org/introduction-service/1.0/fake-invitation-request" + in updated_typeset + ) + assert ( + "https://didcom.org/introduction-service/1.1/fake-forward-invitation" + in updated_typeset + ) + assert ( + "https://didcom.org/introduction-service/1.1/fake-invitation" + in updated_typeset + ) + assert ( + "https://didcom.org/introduction-service/1.1/fake-invitation-request" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-forward-invitation" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.1/fake-invitation-request" + in updated_typeset + ) + async def test_disclosed(self): self.registry.register_message_types( {self.test_message_type: self.test_message_handler} diff --git a/aries_cloudagent/core/tests/test_util.py b/aries_cloudagent/core/tests/test_util.py new file mode 100644 index 0000000000..1b20487e71 --- /dev/null +++ b/aries_cloudagent/core/tests/test_util.py @@ -0,0 +1,73 @@ +from async_case import IsolatedAsyncioTestCase + +from ...cache.base import BaseCache +from ...cache.in_memory import InMemoryCache +from ...core.in_memory import InMemoryProfile +from ...core.profile import Profile +from ...protocols.didcomm_prefix import DIDCommPrefix +from ...protocols.introduction.v0_1.messages.invitation import Invitation +from ...protocols.out_of_band.v1_0.messages.reuse import HandshakeReuse + +from .. import util as test_module + + +def make_profile() -> Profile: + profile = InMemoryProfile.test_profile() + profile.context.injector.bind_instance(BaseCache, InMemoryCache()) + return profile + + +class TestUtils(IsolatedAsyncioTestCase): + async def test_validate_get_response_version(self): + profile = make_profile() + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "1.1", HandshakeReuse + ) + assert resp_version == "1.1" + assert not warning + + # cached + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "1.1", HandshakeReuse + ) + assert resp_version == "1.1" + assert not warning + + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "1.0", HandshakeReuse + ) + assert resp_version == "1.0" + assert warning == test_module.WARNING_DEGRADED_FEATURES + + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "1.2", HandshakeReuse + ) + assert resp_version == "1.1" + assert warning == test_module.WARNING_VERSION_MISMATCH + + with self.assertRaises(test_module.ProtocolMinorVersionNotSupported): + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "0.0", Invitation + ) + + with self.assertRaises(Exception): + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "1.0", Invitation + ) + + def test_get_version_from_message_type(self): + assert ( + test_module.get_version_from_message_type( + DIDCommPrefix.qualify_current("out-of-band/1.1/handshake-reuse") + ) + == "1.1" + ) + + def test_get_version_from_message(self): + assert test_module.get_version_from_message(HandshakeReuse()) == "1.0" + + async def test_get_proto_default_version(self): + profile = make_profile() + assert ( + await test_module.get_proto_default_version(profile, HandshakeReuse) + ) == "1.1" diff --git a/aries_cloudagent/core/util.py b/aries_cloudagent/core/util.py index 791f80c95d..ebe03de929 100644 --- a/aries_cloudagent/core/util.py +++ b/aries_cloudagent/core/util.py @@ -1,10 +1,145 @@ """Core utilities and constants.""" +import inspect +import os import re +from typing import Optional, Tuple + +from ..cache.base import BaseCache +from ..core.profile import Profile +from ..messaging.agent_message import AgentMessage +from ..utils.classloader import ClassLoader + +from .error import ProtocolMinorVersionNotSupported, ProtocolDefinitionValidationError CORE_EVENT_PREFIX = "acapy::core::" STARTUP_EVENT_TOPIC = CORE_EVENT_PREFIX + "startup" STARTUP_EVENT_PATTERN = re.compile(f"^{STARTUP_EVENT_TOPIC}?$") SHUTDOWN_EVENT_TOPIC = CORE_EVENT_PREFIX + "shutdown" SHUTDOWN_EVENT_PATTERN = re.compile(f"^{SHUTDOWN_EVENT_TOPIC}?$") +WARNING_DEGRADED_FEATURES = "version-with-degraded-features" +WARNING_VERSION_MISMATCH = "fields-ignored-due-to-version-mismatch" +WARNING_VERSION_NOT_SUPPORTED = "version-not-supported" + + +async def validate_get_response_version( + profile: Profile, rec_version: str, msg_class: type +) -> Tuple[str, Optional[str]]: + """ + Return a tuple with version to respond with and warnings. + + Process received version and protocol version definition, + returns the tuple. + + Args: + profile: Profile + rec_version: received version from message + msg_class: type + + Returns: + Tuple with response version and any warnings + + """ + resp_version = rec_version + warning = None + version_string_tokens = rec_version.split(".") + rec_major_version = int(version_string_tokens[0]) + rec_minor_version = int(version_string_tokens[1]) + version_definition = await get_version_def_from_msg_class( + profile, msg_class, rec_major_version + ) + proto_major_version = int(version_definition["major_version"]) + proto_curr_minor_version = int(version_definition["current_minor_version"]) + proto_min_minor_version = int(version_definition["minimum_minor_version"]) + if rec_minor_version < proto_min_minor_version: + warning = WARNING_VERSION_NOT_SUPPORTED + elif ( + rec_minor_version >= proto_min_minor_version + and rec_minor_version < proto_curr_minor_version + ): + warning = WARNING_DEGRADED_FEATURES + elif rec_minor_version > proto_curr_minor_version: + warning = WARNING_VERSION_MISMATCH + if proto_major_version == rec_major_version: + if ( + proto_min_minor_version <= rec_minor_version + and proto_curr_minor_version >= rec_minor_version + ): + resp_version = f"{str(proto_major_version)}.{str(rec_minor_version)}" + elif rec_minor_version > proto_curr_minor_version: + resp_version = f"{str(proto_major_version)}.{str(proto_curr_minor_version)}" + elif rec_minor_version < proto_min_minor_version: + raise ProtocolMinorVersionNotSupported( + "Minimum supported minor version is " + + f"{proto_min_minor_version}." + + f" Received {rec_minor_version}." + ) + else: + raise ProtocolMinorVersionNotSupported( + f"Supported major version {proto_major_version}" + " is not same as received major version" + f" {rec_major_version}." + ) + return (resp_version, warning) + + +def get_version_from_message_type(msg_type: str) -> str: + """Return version from provided message_type.""" + return (re.search(r"(\d+\.)?(\*|\d+)", msg_type)).group() + + +def get_version_from_message(msg: AgentMessage) -> str: + """Return version from provided AgentMessage.""" + msg_type = msg._type + return get_version_from_message_type(msg_type) + + +async def get_proto_default_version( + profile: Profile, msg_class: type, major_version: int = 1 +) -> str: + """Return default protocol version from version_definition.""" + version_definition = await get_version_def_from_msg_class( + profile, msg_class, major_version + ) + default_major_version = version_definition["major_version"] + default_minor_version = version_definition["current_minor_version"] + return f"{default_major_version}.{default_minor_version}" + + +def _get_path_from_msg_class(msg_class: type) -> str: + path = os.path.normpath(inspect.getfile(msg_class)) + split_str = os.getenv("ACAPY_HOME") or "aries_cloudagent" + path = split_str + path.rsplit(split_str, 1)[1] + version = (re.search(r"v(\d+\_)?(\*|\d+)", path)).group() + path = path.split(version, 1)[0] + return (path.replace("/", ".")) + "definition" + + +async def get_version_def_from_msg_class( + profile: Profile, msg_class: type, major_version: int = 1 +): + """Return version_definition of a protocol.""" + cache = profile.inject_or(BaseCache) + version_definition = None + if cache: + version_definition = await cache.get( + f"version_definition::{str(msg_class).lower()}" + ) + if version_definition: + return version_definition + definition_path = _get_path_from_msg_class(msg_class) + definition = ClassLoader.load_module(definition_path) + for protocol_version in definition.versions: + if major_version == protocol_version["major_version"]: + version_definition = protocol_version + break + if not version_definition: + raise ProtocolDefinitionValidationError( + f"Unable to load protocol version_definition for {str(msg_class)}" + ) + if cache: + await cache.set( + f"version_definition::{str(msg_class).lower()}", version_definition + ) + return version_definition diff --git a/aries_cloudagent/protocols/out_of_band/definition.py b/aries_cloudagent/protocols/out_of_band/definition.py index 62bddef6f5..13c1f8a8ef 100644 --- a/aries_cloudagent/protocols/out_of_band/definition.py +++ b/aries_cloudagent/protocols/out_of_band/definition.py @@ -4,7 +4,7 @@ { "major_version": 1, "minimum_minor_version": 0, - "current_minor_version": 0, + "current_minor_version": 1, "path": "v1_0", } ] diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index 4bdfcc39a6..1679069e1a 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -9,6 +9,7 @@ from ....messaging.decorators.service_decorator import ServiceDecorator from ....core.event_bus import EventBus +from ....core.util import get_version_from_message from ....connections.base_manager import BaseConnectionManager from ....connections.models.conn_record import ConnRecord from ....core.error import BaseError @@ -906,7 +907,9 @@ async def receive_reuse_message( invi_msg_id = reuse_msg._thread.pthid reuse_msg_id = reuse_msg._thread_id - reuse_accept_msg = HandshakeReuseAccept() + reuse_accept_msg = HandshakeReuseAccept( + version=get_version_from_message(reuse_msg) + ) reuse_accept_msg.assign_thread_id(thid=reuse_msg_id, pthid=invi_msg_id) connection_targets = await self.fetch_connection_targets(connection=conn_rec) diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py b/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py index d8fb709e09..c130b92e4f 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py @@ -8,10 +8,10 @@ ) # Message types -INVITATION = "out-of-band/1.0/invitation" -MESSAGE_REUSE = "out-of-band/1.0/handshake-reuse" -MESSAGE_REUSE_ACCEPT = "out-of-band/1.0/handshake-reuse-accepted" -PROBLEM_REPORT = "out-of-band/1.0/problem_report" +INVITATION = "out-of-band/$version/invitation" +MESSAGE_REUSE = "out-of-band/$version/handshake-reuse" +MESSAGE_REUSE_ACCEPT = "out-of-band/$version/handshake-reuse-accepted" +PROBLEM_REPORT = "out-of-band/$version/problem_report" PROTOCOL_PACKAGE = "aries_cloudagent.protocols.out_of_band.v1_0" diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py index 04a2cfa3fa..14271ccd01 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py @@ -3,6 +3,7 @@ from collections import namedtuple from enum import Enum from re import sub +from string import Template from typing import Sequence, Text, Union from urllib.parse import parse_qs, urljoin, urlparse @@ -30,6 +31,7 @@ from .service import Service +BASE_PROTO_VERSION = "1.0" HSProtoSpec = namedtuple("HSProtoSpec", "rfc name aka") @@ -123,6 +125,7 @@ def __init__( handshake_protocols: Sequence[Text] = None, requests_attach: Sequence[AttachDecorator] = None, services: Sequence[Union[Service, Text]] = None, + version: str = BASE_PROTO_VERSION, **kwargs, ): """ @@ -140,12 +143,20 @@ def __init__( ) self.requests_attach = list(requests_attach) if requests_attach else [] self.services = services + self.assign_version_to_message_type(version=version) @classmethod def wrap_message(cls, message: dict) -> AttachDecorator: """Convert an aries message to an attachment decorator.""" return AttachDecorator.data_json(mapping=message, ident="request-0") + @classmethod + def assign_version_to_message_type(cls, version: str): + """Assign version to Meta.message_type.""" + cls.Meta.message_type = Template(cls.Meta.message_type).substitute( + version=version + ) + def to_url(self, base_url: str = None) -> str: """ Convert an invitation message to URL format for sharing. diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/problem_report.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/problem_report.py index f6ddb3bf86..679adacb5a 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/problem_report.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/problem_report.py @@ -10,6 +10,7 @@ validates_schema, ValidationError, ) +from string import Template from ....problem_report.v1_0.message import ProblemReport, ProblemReportSchema @@ -21,6 +22,7 @@ ) LOGGER = logging.getLogger(__name__) +BASE_PROTO_VERSION = "1.0" class ProblemReportReason(Enum): @@ -40,9 +42,17 @@ class Meta: message_type = PROBLEM_REPORT schema_class = "OOBProblemReportSchema" - def __init__(self, *args, **kwargs): + def __init__(self, version: str = BASE_PROTO_VERSION, *args, **kwargs): """Initialize a ProblemReport message instance.""" super().__init__(*args, **kwargs) + self.assign_version_to_message_type(version=version) + + @classmethod + def assign_version_to_message_type(cls, version: str): + """Assign version to Meta.message_type.""" + cls.Meta.message_type = Template(cls.Meta.message_type).substitute( + version=version + ) class OOBProblemReportSchema(ProblemReportSchema): diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse.py index df40511e80..f6b24d0e8b 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse.py @@ -1,6 +1,7 @@ """Represents a Handshake Reuse message under RFC 0434.""" from marshmallow import EXCLUDE, pre_dump, ValidationError +from string import Template from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -9,6 +10,7 @@ HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers.reuse_handler.HandshakeReuseMessageHandler" ) +BASE_PROTO_VERSION = "1.0" class HandshakeReuse(AgentMessage): @@ -23,10 +25,19 @@ class Meta: def __init__( self, + version: str = BASE_PROTO_VERSION, **kwargs, ): """Initialize Handshake Reuse message object.""" super().__init__(**kwargs) + self.assign_version_to_message_type(version=version) + + @classmethod + def assign_version_to_message_type(cls, version: str): + """Assign version to Meta.message_type.""" + cls.Meta.message_type = Template(cls.Meta.message_type).substitute( + version=version + ) class HandshakeReuseSchema(AgentMessageSchema): diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse_accept.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse_accept.py index d519ab0a2b..e920506205 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse_accept.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse_accept.py @@ -1,6 +1,7 @@ """Represents a Handshake Reuse Accept message under RFC 0434.""" from marshmallow import EXCLUDE, pre_dump, ValidationError +from string import Template from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -10,6 +11,7 @@ f"{PROTOCOL_PACKAGE}.handlers" ".reuse_accept_handler.HandshakeReuseAcceptMessageHandler" ) +BASE_PROTO_VERSION = "1.0" class HandshakeReuseAccept(AgentMessage): @@ -24,10 +26,19 @@ class Meta: def __init__( self, + version: str = BASE_PROTO_VERSION, **kwargs, ): """Initialize Handshake Reuse Accept object.""" super().__init__(**kwargs) + self.assign_version_to_message_type(version=version) + + @classmethod + def assign_version_to_message_type(cls, version: str): + """Assign version to Meta.message_type.""" + cls.Meta.message_type = Template(cls.Meta.message_type).substitute( + version=version + ) class HandshakeReuseAcceptSchema(AgentMessageSchema): diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py index 5340dd66dc..83b80773a6 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py @@ -1,5 +1,6 @@ import pytest +from string import Template from unittest import TestCase from ......messaging.models.base import BaseModelError @@ -51,7 +52,9 @@ def test_init(self): services=[TEST_DID], ) assert invi.services == [TEST_DID] - assert invi._type == DIDCommPrefix.qualify_current(INVITATION) + assert invi._type == DIDCommPrefix.qualify_current( + Template(INVITATION).substitute(version="1.0") + ) service = Service(_id="#inline", _type=DID_COMM, did=TEST_DID) invi_msg = InvitationMessage( @@ -61,7 +64,9 @@ def test_init(self): services=[service], ) assert invi_msg.services == [service] - assert invi_msg._type == DIDCommPrefix.qualify_current(INVITATION) + assert invi_msg._type == DIDCommPrefix.qualify_current( + Template(INVITATION).substitute(version="1.0") + ) def test_wrap_serde(self): """Test conversion of aries message to attachment decorator.""" diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/routes.py b/aries_cloudagent/protocols/out_of_band/v1_0/routes.py index 3f7384c164..59454a69d2 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/routes.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/routes.py @@ -75,6 +75,15 @@ class AttachmentDefSchema(OpenAPISchema): ), required=False, ) + # accept = fields.List( + # fields.Str(), + # description=( + # "List of mime type in order of preference that should be" + # " use in responding to the message" + # ), + # example="['didcomm/aip1', 'didcomm/aip2;env=rfc19']", + # required=False, + # ) use_public_did = fields.Boolean( default=False, description="Whether to use public DID in invitation", diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py index 738e295a1f..df39e85088 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py @@ -3,6 +3,7 @@ import json from copy import deepcopy from datetime import datetime, timedelta, timezone +from string import Template from typing import List from unittest.mock import ANY @@ -388,7 +389,7 @@ async def test_create_invitation_handshake_succeeds(self): ) assert invi_rec.invitation._type == DIDCommPrefix.qualify_current( - INVITATION + Template(INVITATION).substitute(version="1.0") ) assert not invi_rec.invitation.requests_attach assert ( @@ -475,7 +476,7 @@ async def test_create_invitation_mediation_overwrites_routing_and_endpoint(self) ) assert isinstance(invite, InvitationRecord) assert invite.invitation._type == DIDCommPrefix.qualify_current( - INVITATION + Template(INVITATION).substitute(version="1.0") ) assert invite.invitation.label == "test123" assert ( @@ -793,7 +794,9 @@ async def test_create_invitation_peer_did(self): assert invi_rec._invitation.ser[ "@type" - ] == DIDCommPrefix.qualify_current(INVITATION) + ] == DIDCommPrefix.qualify_current( + Template(INVITATION).substitute(version="1.0") + ) assert not invi_rec._invitation.ser.get("requests~attach") assert invi_rec.invitation.label == "That guy" assert ( @@ -900,7 +903,9 @@ async def test_create_handshake_reuse_msg(self): assert oob_record.state == OobRecord.STATE_AWAIT_RESPONSE # Assert responder has been called with the reuse message - assert reuse_message._type == DIDCommPrefix.qualify_current(MESSAGE_REUSE) + assert reuse_message._type == DIDCommPrefix.qualify_current( + Template(MESSAGE_REUSE).substitute(version="1.0") + ) assert oob_record.reuse_msg_id == reuse_message._id async def test_create_handshake_reuse_msg_catch_exception(self):