Skip to content

Commit

Permalink
Merge pull request #1940 from shaangill025/fix_oob_accept_minor_version
Browse files Browse the repository at this point in the history
Fix: OOB - Handling of minor versions
  • Loading branch information
swcurran authored Sep 29, 2022
2 parents 98f5537 + 84628f9 commit e62d5ba
Show file tree
Hide file tree
Showing 16 changed files with 620 additions and 44 deletions.
72 changes: 67 additions & 5 deletions aries_cloudagent/core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import warnings

from typing import Callable, Coroutine, Union
from typing import Callable, Coroutine, Optional, Union, Tuple
import weakref

from aiohttp.web import HTTPException
Expand All @@ -36,6 +36,13 @@

from .error import ProtocolMinorVersionNotSupported
from .protocol_registry import ProtocolRegistry
from .util import (
get_version_from_message_type,
validate_get_response_version,
# WARNING_DEGRADED_FEATURES,
# WARNING_VERSION_MISMATCH,
# WARNING_VERSION_NOT_SUPPORTED,
)

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -133,16 +140,22 @@ async def handle_message(
inbound_message: The inbound message instance
send_outbound: Async function to send outbound messages
# Raises:
# MessageParseError: If the message type version is not supported
Returns:
The response from the handler
"""
r_time = get_timer()

error_result = None
version_warning = None
message = None
try:
message = await self.make_message(inbound_message.payload)
(message, warning) = await self.make_message(
profile, inbound_message.payload
)
except ProblemReportParseError:
pass # avoid problem report recursion
except MessageParseError as e:
Expand All @@ -155,6 +168,47 @@ async def handle_message(
)
if inbound_message.receipt.thread_id:
error_result.assign_thread_id(inbound_message.receipt.thread_id)
# if warning:
# warning_message_type = inbound_message.payload.get("@type")
# if warning == WARNING_DEGRADED_FEATURES:
# LOGGER.error(
# f"Sending {WARNING_DEGRADED_FEATURES} problem report, "
# "message type received with a minor version at or higher"
# " than protocol minimum supported and current minor version "
# f"for message_type {warning_message_type}"
# )
# version_warning = ProblemReport(
# description={
# "en": (
# "message type received with a minor version at or "
# "higher than protocol minimum supported and current"
# f" minor version for message_type {warning_message_type}"
# ),
# "code": WARNING_DEGRADED_FEATURES,
# }
# )
# elif warning == WARNING_VERSION_MISMATCH:
# LOGGER.error(
# f"Sending {WARNING_VERSION_MISMATCH} problem report, message "
# "type received with a minor version higher than current minor "
# f"version for message_type {warning_message_type}"
# )
# version_warning = ProblemReport(
# description={
# "en": (
# "message type received with a minor version higher"
# " than current minor version for message_type"
# f" {warning_message_type}"
# ),
# "code": WARNING_VERSION_MISMATCH,
# }
# )
# elif warning == WARNING_VERSION_NOT_SUPPORTED:
# raise MessageParseError(
# f"Message type version not supported for {warning_message_type}"
# )
# if version_warning and inbound_message.receipt.thread_id:
# version_warning.assign_thread_id(inbound_message.receipt.thread_id)

trace_event(
self.profile.settings,
Expand Down Expand Up @@ -199,6 +253,8 @@ async def handle_message(

if error_result:
await responder.send_reply(error_result)
elif version_warning:
await responder.send_reply(version_warning)
elif context.message:
context.injector.bind_instance(BaseResponder, responder)

Expand All @@ -215,7 +271,9 @@ async def handle_message(
perf_counter=r_time,
)

async def make_message(self, parsed_msg: dict) -> BaseMessage:
async def make_message(
self, profile: Profile, parsed_msg: dict
) -> Tuple[BaseMessage, Optional[str]]:
"""
Deserialize a message dict into the appropriate message instance.
Expand All @@ -224,6 +282,7 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage:
Args:
parsed_msg: The parsed message
profile: Profile
Returns:
An instance of the corresponding message class for this message
Expand All @@ -237,6 +296,7 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage:
if not isinstance(parsed_msg, dict):
raise MessageParseError("Expected a JSON object")
message_type = parsed_msg.get("@type")
message_type_rec_version = get_version_from_message_type(message_type)

if not message_type:
raise MessageParseError("Message does not contain '@type' parameter")
Expand All @@ -256,8 +316,10 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage:
if "/problem-report" in message_type:
raise ProblemReportParseError("Error parsing problem report message")
raise MessageParseError(f"Error deserializing message: {e}") from e

return instance
_, warning = await validate_get_response_version(
profile, message_type_rec_version, message_cls
)
return (instance, warning)

async def complete(self, timeout: float = 0.1):
"""Wait for pending tasks to complete."""
Expand Down
94 changes: 82 additions & 12 deletions aries_cloudagent/core/protocol_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging

from string import Template
from typing import Mapping, Sequence

from ..config.injection_context import InjectionContext
Expand Down Expand Up @@ -74,6 +75,73 @@ def parse_type_string(self, message_type):
"minor_version": int(version_string_tokens[1]),
}

def create_msg_types_for_minor_version(self, typesets, version_definition):
"""
Return mapping of message type to module path for minor versions.
Args:
typesets: Mappings of message types to register
version_definition: Optional version definition dict
Returns:
Typesets mapping
"""
updated_typeset = {}
curr_minor_version = version_definition["current_minor_version"]
min_minor_version = version_definition["minimum_minor_version"]
major_version = version_definition["major_version"]
if curr_minor_version >= min_minor_version and curr_minor_version >= 1:
for version_index in range(min_minor_version, curr_minor_version + 1):
to_check = f"{str(major_version)}.{str(version_index)}"
updated_typeset.update(
self._get_updated_tyoeset_dict(typesets, to_check, updated_typeset)
)
return (updated_typeset,)

def _get_updated_tyoeset_dict(self, typesets, to_check, updated_typeset) -> dict:
for typeset in typesets:
for msg_type_string, module_path in typeset.items():
updated_msg_type_string = Template(msg_type_string).substitute(
version=to_check
)
updated_typeset[updated_msg_type_string] = module_path
return updated_typeset

def _template_message_type_check(self, typeset) -> bool:
for msg_type_string, _ in typeset.items():
if "$version" in msg_type_string:
return True
return False

def _create_and_register_updated_typesets(self, typesets, version_definition):
updated_typesets = self.create_msg_types_for_minor_version(
typesets, version_definition
)
update_flag = False
for typeset in updated_typesets:
if typeset:
self._typemap.update(typeset)
update_flag = True
if update_flag:
return updated_typesets
else:
return None

def _update_version_map(self, message_type_string, module_path, version_definition):
parsed_type_string = self.parse_type_string(message_type_string)

if version_definition["major_version"] not in self._versionmap:
self._versionmap[version_definition["major_version"]] = []

self._versionmap[version_definition["major_version"]].append(
{
"parsed_type_string": parsed_type_string,
"version_definition": version_definition,
"message_module": module_path,
}
)

def register_message_types(self, *typesets, version_definition=None):
"""
Add new supported message types.
Expand All @@ -85,24 +153,26 @@ def register_message_types(self, *typesets, version_definition=None):
"""

# Maintain support for versionless protocol modules
template_msg_type_version = True
updated_typesets = None
for typeset in typesets:
self._typemap.update(typeset)
if not self._template_message_type_check(typeset):
self._typemap.update(typeset)
template_msg_type_version = False

# Track versioned modules for version routing
if version_definition:
# create updated typesets for minor versions and register them
if template_msg_type_version:
updated_typesets = self._create_and_register_updated_typesets(
typesets, version_definition
)
if updated_typesets:
typesets = updated_typesets
for typeset in typesets:
for message_type_string, module_path in typeset.items():
parsed_type_string = self.parse_type_string(message_type_string)

if version_definition["major_version"] not in self._versionmap:
self._versionmap[version_definition["major_version"]] = []

self._versionmap[version_definition["major_version"]].append(
{
"parsed_type_string": parsed_type_string,
"version_definition": version_definition,
"message_module": module_path,
}
self._update_version_map(
message_type_string, module_path, version_definition
)

def register_controllers(self, *controller_sets, version_definition=None):
Expand Down
Loading

0 comments on commit e62d5ba

Please sign in to comment.