Skip to content

Commit

Permalink
Merge pull request openwallet-foundation#2816 from petridishdev/featu…
Browse files Browse the repository at this point in the history
…re/did-rotate

feat: did-rotate
  • Loading branch information
dbluhm authored Mar 15, 2024
2 parents 63250d8 + bdb2600 commit 3eb0bf7
Show file tree
Hide file tree
Showing 41 changed files with 1,919 additions and 145 deletions.
22 changes: 19 additions & 3 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ async def store_did_document(self, value: Union[DIDDoc, dict]):
await storage.update_record(record, doc, {"did": did})

await self.remove_keys_for_did(did)
await self.record_did(did)
await self.record_keys_for_resolvable_did(did)

async def add_key_for_did(self, did: str, key: str):
"""Store a verkey for lookup against a DID.
Expand Down Expand Up @@ -441,8 +441,8 @@ async def resolve_invitation(
[self._extract_key_material_in_base58_format(key) for key in routing_keys],
)

async def record_did(self, did: str):
"""Record DID for later use.
async def record_keys_for_resolvable_did(self, did: str):
"""Record the keys for a public DID.
This is required to correlate sender verkeys back to a connection.
"""
Expand Down Expand Up @@ -739,6 +739,21 @@ async def get_connection_targets(
targets = await self.fetch_connection_targets(connection)
return targets

async def clear_connection_targets_cache(self, connection_id: str):
"""Clear the connection targets cache for a given connection ID.
Historically, connections have not been updatable after the protocol
completes. However, with DID Rotation, we need to be able to update
the connection targets and clear the cache of targets.
"""
# TODO it would be better to include the DIDs of the connection in the
# target cache key This solution only works when using whole cluster
# caching or have only a single instance with local caching
cache = self._profile.inject_or(BaseCache)
if cache:
cache_key = f"connection_target::{connection_id}"
await cache.clear(cache_key)

def diddoc_connection_targets(
self,
doc: Optional[Union[DIDDoc, dict]],
Expand Down Expand Up @@ -959,6 +974,7 @@ async def get_endpoints(self, conn_id: str) -> Tuple[Optional[str], Optional[str
connection = await ConnRecord.retrieve_by_id(session, conn_id)
wallet = session.inject(BaseWallet)
my_did_info = await wallet.get_local_did(connection.my_did)

my_endpoint = my_did_info.metadata.get(
"endpoint",
self._profile.settings.get("default_endpoint"),
Expand Down
8 changes: 4 additions & 4 deletions aries_cloudagent/connections/tests/test_base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ async def test_resolve_connection_targets_x_unsupported_key(self):
await self.manager.resolve_connection_targets(did)
assert "not supported" in str(cm.exception)

async def test_record_did_empty(self):
async def test_record_keys_for_resolvable_did_empty(self):
did = "did:sov:" + self.test_did
service_builder = ServiceBuilder(DID(did))
service_builder.add_didcomm(
Expand All @@ -1083,9 +1083,9 @@ async def test_record_did_empty(self):
self.manager.resolve_didcomm_services = mock.CoroutineMock(
return_value=(DIDDocument(id=DID(did)), service_builder.services)
)
await self.manager.record_did(did)
await self.manager.record_keys_for_resolvable_did(did)

async def test_record_did(self):
async def test_record_keys_for_resolvable_did(self):
did = "did:sov:" + self.test_did
doc_builder = DIDDocumentBuilder(did)
vm = doc_builder.verification_method.add(
Expand All @@ -1099,7 +1099,7 @@ async def test_record_did(self):
self.manager.resolve_didcomm_services = mock.CoroutineMock(
return_value=(doc, doc.service)
)
await self.manager.record_did(did)
await self.manager.record_keys_for_resolvable_did(did)

async def test_diddoc_connection_targets_diddoc_underspecified(self):
with self.assertRaises(BaseConnectionManagerError):
Expand Down
3 changes: 3 additions & 0 deletions aries_cloudagent/messaging/valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,9 @@ def __init__(
DID_POSTURE_VALIDATE = DIDPosture()
DID_POSTURE_EXAMPLE = DIDPosture.EXAMPLE

DID_WEB_VALIDATE = DIDWeb()
DID_WEB_EXAMPLE = DIDWeb.EXAMPLE

ROUTING_KEY_VALIDATE = RoutingKey()
ROUTING_KEY_EXAMPLE = RoutingKey.EXAMPLE

Expand Down
Empty file.
10 changes: 10 additions & 0 deletions aries_cloudagent/protocols/did_rotate/definition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Version definitions for this protocol."""

versions = [
{
"major_version": 1,
"minimum_minor_version": 0,
"current_minor_version": 0,
"path": "v1_0",
}
]
Empty file.
Empty file.
29 changes: 29 additions & 0 deletions aries_cloudagent/protocols/did_rotate/v1_0/handlers/ack_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Rotate ack handler."""

from .....messaging.base_handler import BaseHandler
from .....messaging.request_context import RequestContext
from .....messaging.responder import BaseResponder
from ..manager import DIDRotateManager
from ..messages.ack import RotateAck


class RotateAckHandler(BaseHandler):
"""Message handler class for rotate ack message."""

async def handle(self, context: RequestContext, responder: BaseResponder):
"""Handle rotate ack message.
Args:
context: request context
responder: responder callback
"""
self._logger.debug("RotateAckHandler called with context %s", context)
assert isinstance(context.message, RotateAck)

connection_record = context.connection_record
ack = context.message

profile = context.profile
did_rotate_mgr = DIDRotateManager(profile)

await did_rotate_mgr.receive_ack(connection_record, ack)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Rotate hangup handler."""

from .....messaging.base_handler import BaseHandler
from .....messaging.request_context import RequestContext
from .....messaging.responder import BaseResponder
from ..manager import DIDRotateManager
from ..messages.hangup import Hangup


class HangupHandler(BaseHandler):
"""Message handler class for rotate message."""

async def handle(self, context: RequestContext, responder: BaseResponder):
"""Handle rotate hangup message.
Args:
context: request context
responder: responder callback
"""
self._logger.debug("HangupHandler called with context %s", context)
assert isinstance(context.message, Hangup)

connection_record = context.connection_record

profile = context.profile
did_rotate_mgr = DIDRotateManager(profile)

await did_rotate_mgr.receive_hangup(connection_record)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Rotate problem report handler."""

from .....messaging.base_handler import BaseHandler
from .....messaging.request_context import RequestContext
from .....messaging.responder import BaseResponder
from ..manager import DIDRotateManager
from ..messages.problem_report import RotateProblemReport


class ProblemReportHandler(BaseHandler):
"""Message handler class for rotate message."""

async def handle(self, context: RequestContext, responder: BaseResponder):
"""Handle rotate problem report message.
Args:
context: request context
responder: responder callback
"""
self._logger.debug("ProblemReportHandler called with context %s", context)
assert isinstance(context.message, RotateProblemReport)

problem_report = context.message

profile = context.profile
did_rotate_mgr = DIDRotateManager(profile)

await did_rotate_mgr.receive_problem_report(problem_report)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Rotate handler."""

from .....messaging.base_handler import BaseHandler
from .....messaging.request_context import RequestContext
from .....messaging.responder import BaseResponder
from ..manager import DIDRotateManager
from ..messages.rotate import Rotate


class RotateHandler(BaseHandler):
"""Message handler class for rotate message."""

async def handle(self, context: RequestContext, responder: BaseResponder):
"""Handle rotate message.
Args:
context: request context
responder: responder callback
"""
self._logger.debug("RotateHandler called with context %s", context)
assert isinstance(context.message, Rotate)

connection_record = context.connection_record
rotate = context.message

profile = context.profile
did_rotate_mgr = DIDRotateManager(profile)

if record := await did_rotate_mgr.receive_rotate(connection_record, rotate):
await did_rotate_mgr.commit_rotate(connection_record, record)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest

from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder
from ......tests import mock
from ...messages.ack import RotateAck
from .. import ack_handler as test_module


@pytest.fixture()
def request_context():
ctx = RequestContext.test_context()
yield ctx


class TestAckHandler:
"""Unit tests for AckHandler."""

@pytest.mark.asyncio
@mock.patch.object(test_module, "DIDRotateManager")
async def test_handle(self, MockDIDRotateManager, request_context):
MockDIDRotateManager.return_value.receive_ack = mock.CoroutineMock()

request_context.message = RotateAck()
request_context.connection_record = mock.MagicMock()

handler = test_module.RotateAckHandler()
responder = MockResponder()
await handler.handle(request_context, responder)

MockDIDRotateManager.return_value.receive_ack.assert_called_once_with(
request_context.connection_record, request_context.message
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest

from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder
from ......tests import mock
from ...messages.hangup import Hangup
from .. import hangup_handler as test_module


@pytest.fixture()
def request_context():
ctx = RequestContext.test_context()
yield ctx


class TestHangupHandler:
"""Unit tests for HangupHandler."""

@pytest.mark.asyncio
@mock.patch.object(test_module, "DIDRotateManager")
async def test_handle(self, MockDIDRotateManager, request_context):
MockDIDRotateManager.return_value.receive_hangup = mock.CoroutineMock()

request_context.message = Hangup()
request_context.connection_record = mock.MagicMock()

handler = test_module.HangupHandler()
responder = MockResponder()
await handler.handle(request_context, responder)

MockDIDRotateManager.return_value.receive_hangup.assert_called_once_with(
request_context.connection_record
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder
from ......tests import mock
from ...messages.problem_report import RotateProblemReport
from .. import problem_report_handler as test_module

test_valid_rotate_request = {
"to_did": "did:example:newdid",
}


@pytest.fixture()
def request_context():
ctx = RequestContext.test_context()
yield ctx


class TestProblemReportHandler:
"""Unit tests for ProblemReportHandler."""

@pytest.mark.asyncio
@mock.patch.object(test_module, "DIDRotateManager")
async def test_handle(self, MockDIDRotateManager, request_context):
MockDIDRotateManager.return_value.receive_problem_report = mock.CoroutineMock()

request_context.message = RotateProblemReport()
request_context.connection_record = mock.MagicMock()

handler = test_module.ProblemReportHandler()
responder = MockResponder()
await handler.handle(request_context, responder)

MockDIDRotateManager.return_value.receive_problem_report.assert_called_once_with(
request_context.message
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder
from ......tests import mock
from ...messages.rotate import Rotate
from .. import rotate_handler as test_module

test_valid_rotate_request = {
"to_did": "did:example:newdid",
}


@pytest.fixture()
def request_context():
ctx = RequestContext.test_context()
yield ctx


class TestRotateHandler:
"""Unit tests for RotateHandler."""

@pytest.mark.asyncio
@mock.patch.object(test_module, "DIDRotateManager")
async def test_handle(self, MockDIDRotateManager, request_context):
MockDIDRotateManager.return_value.receive_rotate = mock.CoroutineMock()
MockDIDRotateManager.return_value.commit_rotate = mock.CoroutineMock()

request_context.message = Rotate(**test_valid_rotate_request)
request_context.connection_record = mock.MagicMock()

handler = test_module.RotateHandler()
responder = MockResponder()
await handler.handle(request_context, responder)

MockDIDRotateManager.return_value.receive_rotate.assert_called_once_with(
request_context.connection_record, request_context.message
)
MockDIDRotateManager.return_value.commit_rotate.assert_called_once_with(
request_context.connection_record,
MockDIDRotateManager.return_value.receive_rotate.return_value,
)
Loading

0 comments on commit 3eb0bf7

Please sign in to comment.