Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update get/set_did_public to use a storage record pointer #1249

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aries_cloudagent/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..core.error import BaseError

InjectType = TypeVar("Inject")
InjectType = TypeVar("InjectType")


class ConfigError(BaseError):
Expand Down
12 changes: 6 additions & 6 deletions aries_cloudagent/core/tests/test_event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def event():
yield event


class TestProcessor:
class MockProcessor:
def __init__(self):
self.context = None
self.event = None
Expand All @@ -39,7 +39,7 @@ async def __call__(self, context, event):

@pytest.fixture
def processor():
yield TestProcessor()
yield MockProcessor()


def test_event(event):
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_unsub_unsubbed_processor(event_bus: EventBus, processor):
"""Test unsubscribing an unsubscribed processor does not error."""
event_bus.unsubscribe(re.compile(".*"), processor)
event_bus.subscribe(re.compile(".*"), processor)
another_processor = TestProcessor()
another_processor = MockProcessor()
event_bus.unsubscribe(re.compile(".*"), another_processor)


Expand All @@ -101,7 +101,7 @@ async def test_sub_notify_error_logged_and_exec_continues(
def _raise_exception(context, event):
raise Exception()

processor = TestProcessor()
processor = MockProcessor()
bad_processor = _raise_exception
event_bus.subscribe(re.compile(".*"), bad_processor)
event_bus.subscribe(re.compile(".*"), processor)
Expand Down Expand Up @@ -147,7 +147,7 @@ async def test_sub_notify_no_match(event_bus: EventBus, context, event, processo
@pytest.mark.asyncio
async def test_sub_notify_only_one(event_bus: EventBus, context, event, processor):
"""Test only one subscriber is called when pattern matches only one."""
processor1 = TestProcessor()
processor1 = MockProcessor()
event_bus.subscribe(re.compile(".*"), processor)
event_bus.subscribe(re.compile("^$"), processor1)
await event_bus.notify(context, event)
Expand All @@ -160,7 +160,7 @@ async def test_sub_notify_only_one(event_bus: EventBus, context, event, processo
@pytest.mark.asyncio
async def test_sub_notify_both(event_bus: EventBus, context, event, processor):
"""Test both subscribers are called when pattern matches both."""
processor1 = TestProcessor()
processor1 = MockProcessor()
event_bus.subscribe(re.compile(".*"), processor)
event_bus.subscribe(re.compile("anything"), processor1)
await event_bus.notify(context, event)
Expand Down
71 changes: 13 additions & 58 deletions aries_cloudagent/wallet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ..ledger.endpoint_type import EndpointType
from .error import WalletError

from .did_posture import DIDPosture
from .did_info import DIDInfo, KeyInfo
from .key_type import KeyType
from .did_method import DIDMethod
Expand Down Expand Up @@ -122,8 +121,6 @@ async def create_public_did(
"""
Create and store a new public DID.

Implicitly flags all other dids as not public.

Args:
seed: Optional seed to use for DID
did: The DID to use
Expand All @@ -133,73 +130,33 @@ async def create_public_did(
The created `DIDInfo`

"""
if method != DIDMethod.SOV:
raise WalletError("Creating public did is only allowed for did:sov dids")

# validate key_type
if not method.supports_key_type(key_type):
raise WalletError(
f"Invalid key type {key_type.key_type} for method {method.method_name}"
)

metadata = DIDPosture.PUBLIC.metadata
dids = await self.get_local_dids()
for info in dids:
info_meta = info.metadata
info_meta["public"] = False
await self.replace_local_did_metadata(info.did, info_meta)
return await self.create_local_did(
metadata = metadata or {}
metadata.setdefault("posted", True)
did_info = await self.create_local_did(
method=method, key_type=key_type, seed=seed, did=did, metadata=metadata
)
return await self.set_public_did(did_info)

@abstractmethod
async def get_public_did(self) -> DIDInfo:
"""
Retrieve the public DID.

Returns:
The created `DIDInfo`
The currently public `DIDInfo`, if any

"""

dids = await self.get_local_dids()
for info in dids:
if info.metadata.get("public"):
return info

return None

async def set_public_did(self, did: str) -> DIDInfo:
@abstractmethod
async def set_public_did(self, did: Union[str, DIDInfo]) -> DIDInfo:
"""
Assign the public DID.

Returns:
The created `DIDInfo`
The updated `DIDInfo`

"""

did_info = await self.get_local_did(did)
if did_info.method != DIDMethod.SOV:
raise WalletError("Setting public did is only allowed for did:sov dids")

# will raise an exception if not found
info = None if did is None else await self.get_local_did(did)

public = await self.get_public_did()
if public and info and public.did == info.did:
info = public
else:
if public:
metadata = public.metadata.copy()
del metadata["public"]
await self.replace_local_did_metadata(public.did, metadata)

if info:
metadata = {**info.metadata, **DIDPosture.PUBLIC.metadata}
await self.replace_local_did_metadata(info.did, metadata)
info = await self.get_local_did(info.did)

return info

@abstractmethod
async def get_local_dids(self) -> Sequence[DIDInfo]:
"""
Expand Down Expand Up @@ -251,23 +208,21 @@ async def replace_local_did_metadata(self, did: str, metadata: dict):

async def get_posted_dids(self) -> Sequence[DIDInfo]:
"""
Get list of defined posted DIDs, excluding public DID.
Get list of defined posted DIDs.

Returns:
A list of `DIDInfo` instances

"""
return [
info
for info in await self.get_local_dids()
if info.metadata.get("posted") and not info.metadata.get("public")
info for info in await self.get_local_dids() if info.metadata.get("posted")
]

async def set_did_endpoint(
self,
did: str,
endpoint: str,
ledger: BaseLedger,
_ledger: BaseLedger,
endpoint_type: EndpointType = None,
):
"""
Expand All @@ -284,7 +239,7 @@ async def set_did_endpoint(
did_info = await self.get_local_did(did)

if did_info.method != DIDMethod.SOV:
raise WalletError("Setting did endpoint is only allowed for did:sov dids")
raise WalletError("Setting DID endpoint is only allowed for did:sov DIDs")
metadata = {**did_info.metadata}
if not endpoint_type:
endpoint_type = EndpointType.ENDPOINT
Expand Down
59 changes: 54 additions & 5 deletions aries_cloudagent/wallet/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from typing import List, Sequence, Tuple, Union

from ..core.in_memory import InMemoryProfile
from ..did.did_key import DIDKey

from .base import BaseWallet
from .did_info import KeyInfo, DIDInfo
from .crypto import (
create_keypair,
validate_seed,
Expand All @@ -15,12 +15,12 @@
encode_pack_message,
decode_pack_message,
)
from .key_type import KeyType
from .did_info import KeyInfo, DIDInfo
from .did_posture import DIDPosture
from .did_method import DIDMethod
from .util import random_seed
from ..did.did_key import DIDKey
from .error import WalletError, WalletDuplicateError, WalletNotFoundError
from .util import b58_to_bytes, bytes_to_b58
from .key_type import KeyType
from .util import b58_to_bytes, bytes_to_b58, random_seed


class InMemoryWallet(BaseWallet):
Expand Down Expand Up @@ -363,6 +363,55 @@ def _get_private_key(self, verkey: str) -> bytes:

raise WalletError("Private key not found for verkey: {}".format(verkey))

async def get_public_did(self) -> DIDInfo:
"""
Retrieve the public DID.

Returns:
The currently public `DIDInfo`, if any

"""

dids = await self.get_local_dids()
for info in dids:
if info.metadata.get("public"):
return info

return None

async def set_public_did(self, did: Union[str, DIDInfo]) -> DIDInfo:
"""
Assign the public DID.

Returns:
The updated `DIDInfo`

"""

if isinstance(did, str):
# will raise an exception if not found
info = await self.get_local_did(did)
else:
info = did
did = info.did

if info.method != DIDMethod.SOV:
raise WalletError("Setting public DID is only allowed for did:sov DIDs")

public = await self.get_public_did()
if public and public.did == did:
info = public
else:
if public:
metadata = {**public.metadata, **DIDPosture.POSTED.metadata}
await self.replace_local_did_metadata(public.did, metadata)

metadata = {**info.metadata, **DIDPosture.PUBLIC.metadata}
await self.replace_local_did_metadata(did, metadata)
info = await self.get_local_did(did)

return info

async def sign_message(
self, message: Union[List[bytes], bytes], from_verkey: str
) -> bytes:
Expand Down
Loading