Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Extract verification method ID generation to a separate class #2235

Merged
4 changes: 4 additions & 0 deletions aries_cloudagent/config/default_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..transport.wire_format import BaseWireFormat
from ..utils.dependencies import is_indy_sdk_module_installed
from ..utils.stats import Collector
from ..wallet.default_verification_key_strategy import DefaultVerificationKeyStrategy
from ..wallet.did_method import DIDMethods
from ..wallet.key_type import KeyTypes
from .base_context import ContextBuilder
Expand Down Expand Up @@ -53,6 +54,9 @@ async def build_context(self) -> InjectionContext:
context.injector.bind_instance(DIDResolver, DIDResolver([]))
context.injector.bind_instance(DIDMethods, DIDMethods())
context.injector.bind_instance(KeyTypes, KeyTypes())
context.injector.bind_instance(
DefaultVerificationKeyStrategy, DefaultVerificationKeyStrategy()
)

await self.bind_providers(context)
await self.load_plugins(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pyld import jsonld
from pyld.jsonld import JsonLdProcessor

from ......did.did_key import DIDKey
from ......messaging.decorators.attach_decorator import AttachDecorator
from ......storage.vc_holder.base import VCHolder
from ......storage.vc_holder.vc_record import VCRecord
Expand All @@ -35,6 +34,9 @@
)
from ......vc.ld_proofs.constants import SECURITY_CONTEXT_BBS_URL
from ......wallet.base import BaseWallet, DIDInfo
from ......wallet.default_verification_key_strategy import (
DefaultVerificationKeyStrategy,
)
from ......wallet.error import WalletNotFoundError
from ......wallet.key_type import BLS12381G2, ED25519

Expand Down Expand Up @@ -270,10 +272,17 @@ async def _get_suite_for_detail(
)

did_info = await self._did_info_for_did(issuer_id)
verification_method = verification_method or self._get_verification_method(
issuer_id
verkey_id_strategy = self.profile.context.inject(DefaultVerificationKeyStrategy)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be Base instead of Default, as you can override the default with a custom one? It would't make sense to make the injection token the default implementation IMO.

Copy link
Contributor Author

@yvgny yvgny Jun 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TimoGlastra So the injection token should be renamed to BaseVerificationKeyStrategy? This would mean the base (abstract) class would be named BaseVerificationKeyStrategy and the default implementation DefaultVerificationKeyStrategy, is this correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yvgny that is correct

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

verification_method = (
verification_method
or verkey_id_strategy.get_verification_method_id_for_did(issuer_id)
)

if verification_method is None:
raise V20CredFormatError(
f"Unable to get retrieve verification method for did {issuer_id}"
)

suite = await self._get_suite(
proof_type=proof_type,
verification_method=verification_method,
Expand Down Expand Up @@ -309,19 +318,6 @@ async def _get_suite(
),
)

def _get_verification_method(self, did: str):
"""Get the verification method for a did."""

if did.startswith("did:key:"):
return DIDKey.from_did(did).key_id
elif did.startswith("did:sov:"):
# key-1 is what the resolver uses for key id
return did + "#key-1"
else:
raise V20CredFormatError(
f"Unable to get retrieve verification method for did {did}"
)

def _get_proof_purpose(
self, *, proof_purpose: str = None, challenge: str = None, domain: str = None
) -> ProofPurpose:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
)
from .......vc.ld_proofs.constants import SECURITY_CONTEXT_BBS_URL
from .......vc.tests.document_loader import custom_document_loader
from .......wallet.default_verification_key_strategy import (
DefaultVerificationKeyStrategy,
)
from .......wallet.key_type import BLS12381G2, ED25519
from .......wallet.error import WalletNotFoundError
from .......wallet.did_method import SOV
Expand Down Expand Up @@ -124,6 +127,11 @@ async def setUp(self):
# Set custom document loader
self.context.injector.bind_instance(DocumentLoader, custom_document_loader)

# Set default verkey ID strategy
self.context.injector.bind_instance(
DefaultVerificationKeyStrategy, DefaultVerificationKeyStrategy()
)

self.handler = LDProofCredFormatHandler(self.profile)

self.cred_proposal = V20CredProposal(
Expand Down Expand Up @@ -318,24 +326,6 @@ async def test_get_suite(self):
assert suite.key_pair.key_type == ED25519
assert suite.key_pair.public_key_base58 == did_info.verkey

async def test_get_verification_method(self):
assert (
self.handler._get_verification_method(TEST_DID_KEY)
== DIDKey.from_did(TEST_DID_KEY).key_id
)

assert (
self.handler._get_verification_method(TEST_DID_SOV)
== TEST_DID_SOV + "#key-1"
)

with self.assertRaises(V20CredFormatError) as context:
self.handler._get_verification_method("did:random:not-supported")

assert "Unable to get retrieve verification method for did" in str(
context.exception
)

async def test_get_proof_purpose(self):
purpose = self.handler._get_proof_purpose()
assert type(purpose) == CredentialIssuancePurpose
Expand Down
24 changes: 10 additions & 14 deletions aries_cloudagent/protocols/present_proof/dif/pres_exch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from ....core.error import BaseError
from ....core.profile import Profile
from ....did.did_key import DIDKey
from ....storage.vc_holder.vc_record import VCRecord
from ....vc.ld_proofs import (
Ed25519Signature2018,
Expand All @@ -39,6 +38,7 @@
)
from ....vc.vc_ld.prove import sign_presentation, create_presentation, derive_credential
from ....wallet.base import BaseWallet, DIDInfo
from ....wallet.default_verification_key_strategy import DefaultVerificationKeyStrategy
from ....wallet.error import WalletError, WalletNotFoundError
from ....wallet.key_type import BLS12381G2, ED25519

Expand Down Expand Up @@ -117,7 +117,15 @@ async def _get_issue_suite(
):
"""Get signature suite for signing presentation."""
did_info = await self._did_info_for_did(issuer_id)
verification_method = self._get_verification_method(issuer_id)
verkey_id_strategy = self.profile.context.inject(DefaultVerificationKeyStrategy)
verification_method = verkey_id_strategy.get_verification_method_id_for_did(
issuer_id
)

if verification_method is None:
raise DIFPresExchError(
f"Unable to get retrieve verification method for did {issuer_id}"
)

# Get signature class based on proof type
SignatureClass = self.PROOF_TYPE_SIGNATURE_SUITE_MAPPING[self.proof_type]
Expand Down Expand Up @@ -151,18 +159,6 @@ async def _get_derive_suite(
),
)

def _get_verification_method(self, did: str):
"""Get the verification method for a did."""
if did.startswith("did:key:"):
return DIDKey.from_did(did).key_id
elif did.startswith("did:sov:"):
# key-1 is what uniresolver uses for key id
return did + "#key-1"
else:
raise DIFPresExchError(
f"Unable to get retrieve verification method for did {did}"
)

async def _did_info_for_did(self, did: str) -> DIDInfo:
"""Get the did info for specified did.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .....storage.vc_holder.vc_record import VCRecord
from .....wallet.base import BaseWallet, DIDInfo
from .....wallet.crypto import KeyType
from .....wallet.default_verification_key_strategy import DefaultVerificationKeyStrategy
from .....wallet.did_method import SOV, KEY, DIDMethods
from .....wallet.error import WalletNotFoundError
from .....vc.ld_proofs import (
Expand Down Expand Up @@ -73,6 +74,9 @@ def profile():
context = profile.context
context.injector.bind_instance(DIDResolver, DIDResolver([]))
context.injector.bind_instance(DocumentLoader, custom_document_loader)
context.injector.bind_instance(
DefaultVerificationKeyStrategy, DefaultVerificationKeyStrategy()
)
context.settings["debug.auto_respond_presentation_request"] = True
return profile

Expand Down Expand Up @@ -1867,19 +1871,6 @@ def test_cred_schema_match_b(self, profile, setup_tuple):
test_cred, "https://example.org/examples/degree.json"
)

def test_verification_method(self, profile):
dif_pres_exch_handler = DIFPresExchHandler(profile)
assert (
dif_pres_exch_handler._get_verification_method(
"did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL"
)
== DIDKey.from_did(
"did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL"
).key_id
)
with pytest.raises(DIFPresExchError):
dif_pres_exch_handler._get_verification_method("did:test:test")

@pytest.mark.asyncio
@pytest.mark.ursa_bbs_signatures
async def test_sign_pres_no_cred_subject_id(self, profile, setup_tuple):
Expand Down
43 changes: 43 additions & 0 deletions aries_cloudagent/wallet/default_verification_key_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Utilities for specifying which verification method is in use for a given DID."""
from abc import ABC, abstractmethod
from typing import Optional

from aries_cloudagent.did.did_key import DIDKey


class DefaultVerificationKeyStrategyBase(ABC):
"""Base class for defining which verification method is in use."""

@abstractmethod
def get_verification_method_id_for_did(self, did) -> Optional[str]:
"""Given a DID, returns the verification key ID in use.

Returns None if no strategy is specified for this DID.

:params str did: the did
:returns Optional[str]: the current verkey ID
"""
pass


class DefaultVerificationKeyStrategy(DefaultVerificationKeyStrategyBase):
"""A basic implementation for verkey strategy.

Supports did:key: and did:sov only.
"""

def get_verification_method_id_for_did(self, did) -> Optional[str]:
"""Given a did:key or did:sov, returns the verification key ID in use.

Returns None if no strategy is specified for this DID.

:params str did: the did
:returns Optional[str]: the current verkey ID
"""
if did.startswith("did:key:"):
return DIDKey.from_did(did).key_id
elif did.startswith("did:sov:"):
# key-1 is what uniresolver uses for key id
return did + "#key-1"

return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from unittest import TestCase

from aries_cloudagent.did.did_key import DIDKey

from aries_cloudagent.wallet.default_verification_key_strategy import (
DefaultVerificationKeyStrategy,
)

TEST_DID_SOV = "did:sov:LjgpST2rjsoxYegQDRm7EL"
TEST_DID_KEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL"


class TestDefaultVerificationKeyStrategy(TestCase):
def test_with_did_sov(self):
strategy = DefaultVerificationKeyStrategy()
assert (
strategy.get_verification_method_id_for_did(TEST_DID_SOV)
== TEST_DID_SOV + "#key-1"
)

def test_with_did_key(self):
strategy = DefaultVerificationKeyStrategy()
assert (
strategy.get_verification_method_id_for_did(TEST_DID_KEY)
== DIDKey.from_did(TEST_DID_KEY).key_id
)

def test_unsupported_did_method(self):
strategy = DefaultVerificationKeyStrategy()
assert strategy.get_verification_method_id_for_did("did:test:test") is None