Skip to content

Commit

Permalink
Merge pull request #1249 from andrewwhitehead/fast-did-public
Browse files Browse the repository at this point in the history
Update get/set_did_public to use a storage record pointer
  • Loading branch information
andrewwhitehead authored Jun 22, 2021
2 parents 68347f8 + ce6454a commit 0fb2c4f
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 102 deletions.
2 changes: 1 addition & 1 deletion aries_cloudagent/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..core.error import BaseError

InjectType = TypeVar("Inject")
InjectType = TypeVar("InjectType")


class ConfigError(BaseError):
Expand Down
12 changes: 6 additions & 6 deletions aries_cloudagent/core/tests/test_event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def event():
yield event


class TestProcessor:
class MockProcessor:
def __init__(self):
self.context = None
self.event = None
Expand All @@ -39,7 +39,7 @@ async def __call__(self, context, event):

@pytest.fixture
def processor():
yield TestProcessor()
yield MockProcessor()


def test_event(event):
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_unsub_unsubbed_processor(event_bus: EventBus, processor):
"""Test unsubscribing an unsubscribed processor does not error."""
event_bus.unsubscribe(re.compile(".*"), processor)
event_bus.subscribe(re.compile(".*"), processor)
another_processor = TestProcessor()
another_processor = MockProcessor()
event_bus.unsubscribe(re.compile(".*"), another_processor)


Expand All @@ -101,7 +101,7 @@ async def test_sub_notify_error_logged_and_exec_continues(
def _raise_exception(context, event):
raise Exception()

processor = TestProcessor()
processor = MockProcessor()
bad_processor = _raise_exception
event_bus.subscribe(re.compile(".*"), bad_processor)
event_bus.subscribe(re.compile(".*"), processor)
Expand Down Expand Up @@ -147,7 +147,7 @@ async def test_sub_notify_no_match(event_bus: EventBus, context, event, processo
@pytest.mark.asyncio
async def test_sub_notify_only_one(event_bus: EventBus, context, event, processor):
"""Test only one subscriber is called when pattern matches only one."""
processor1 = TestProcessor()
processor1 = MockProcessor()
event_bus.subscribe(re.compile(".*"), processor)
event_bus.subscribe(re.compile("^$"), processor1)
await event_bus.notify(context, event)
Expand All @@ -160,7 +160,7 @@ async def test_sub_notify_only_one(event_bus: EventBus, context, event, processo
@pytest.mark.asyncio
async def test_sub_notify_both(event_bus: EventBus, context, event, processor):
"""Test both subscribers are called when pattern matches both."""
processor1 = TestProcessor()
processor1 = MockProcessor()
event_bus.subscribe(re.compile(".*"), processor)
event_bus.subscribe(re.compile("anything"), processor1)
await event_bus.notify(context, event)
Expand Down
71 changes: 13 additions & 58 deletions aries_cloudagent/wallet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ..ledger.endpoint_type import EndpointType
from .error import WalletError

from .did_posture import DIDPosture
from .did_info import DIDInfo, KeyInfo
from .key_type import KeyType
from .did_method import DIDMethod
Expand Down Expand Up @@ -122,8 +121,6 @@ async def create_public_did(
"""
Create and store a new public DID.
Implicitly flags all other dids as not public.
Args:
seed: Optional seed to use for DID
did: The DID to use
Expand All @@ -133,73 +130,33 @@ async def create_public_did(
The created `DIDInfo`
"""
if method != DIDMethod.SOV:
raise WalletError("Creating public did is only allowed for did:sov dids")

# validate key_type
if not method.supports_key_type(key_type):
raise WalletError(
f"Invalid key type {key_type.key_type} for method {method.method_name}"
)

metadata = DIDPosture.PUBLIC.metadata
dids = await self.get_local_dids()
for info in dids:
info_meta = info.metadata
info_meta["public"] = False
await self.replace_local_did_metadata(info.did, info_meta)
return await self.create_local_did(
metadata = metadata or {}
metadata.setdefault("posted", True)
did_info = await self.create_local_did(
method=method, key_type=key_type, seed=seed, did=did, metadata=metadata
)
return await self.set_public_did(did_info)

@abstractmethod
async def get_public_did(self) -> DIDInfo:
"""
Retrieve the public DID.
Returns:
The created `DIDInfo`
The currently public `DIDInfo`, if any
"""

dids = await self.get_local_dids()
for info in dids:
if info.metadata.get("public"):
return info

return None

async def set_public_did(self, did: str) -> DIDInfo:
@abstractmethod
async def set_public_did(self, did: Union[str, DIDInfo]) -> DIDInfo:
"""
Assign the public DID.
Returns:
The created `DIDInfo`
The updated `DIDInfo`
"""

did_info = await self.get_local_did(did)
if did_info.method != DIDMethod.SOV:
raise WalletError("Setting public did is only allowed for did:sov dids")

# will raise an exception if not found
info = None if did is None else await self.get_local_did(did)

public = await self.get_public_did()
if public and info and public.did == info.did:
info = public
else:
if public:
metadata = public.metadata.copy()
del metadata["public"]
await self.replace_local_did_metadata(public.did, metadata)

if info:
metadata = {**info.metadata, **DIDPosture.PUBLIC.metadata}
await self.replace_local_did_metadata(info.did, metadata)
info = await self.get_local_did(info.did)

return info

@abstractmethod
async def get_local_dids(self) -> Sequence[DIDInfo]:
"""
Expand Down Expand Up @@ -251,23 +208,21 @@ async def replace_local_did_metadata(self, did: str, metadata: dict):

async def get_posted_dids(self) -> Sequence[DIDInfo]:
"""
Get list of defined posted DIDs, excluding public DID.
Get list of defined posted DIDs.
Returns:
A list of `DIDInfo` instances
"""
return [
info
for info in await self.get_local_dids()
if info.metadata.get("posted") and not info.metadata.get("public")
info for info in await self.get_local_dids() if info.metadata.get("posted")
]

async def set_did_endpoint(
self,
did: str,
endpoint: str,
ledger: BaseLedger,
_ledger: BaseLedger,
endpoint_type: EndpointType = None,
):
"""
Expand All @@ -284,7 +239,7 @@ async def set_did_endpoint(
did_info = await self.get_local_did(did)

if did_info.method != DIDMethod.SOV:
raise WalletError("Setting did endpoint is only allowed for did:sov dids")
raise WalletError("Setting DID endpoint is only allowed for did:sov DIDs")
metadata = {**did_info.metadata}
if not endpoint_type:
endpoint_type = EndpointType.ENDPOINT
Expand Down
59 changes: 54 additions & 5 deletions aries_cloudagent/wallet/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from typing import List, Sequence, Tuple, Union

from ..core.in_memory import InMemoryProfile
from ..did.did_key import DIDKey

from .base import BaseWallet
from .did_info import KeyInfo, DIDInfo
from .crypto import (
create_keypair,
validate_seed,
Expand All @@ -15,12 +15,12 @@
encode_pack_message,
decode_pack_message,
)
from .key_type import KeyType
from .did_info import KeyInfo, DIDInfo
from .did_posture import DIDPosture
from .did_method import DIDMethod
from .util import random_seed
from ..did.did_key import DIDKey
from .error import WalletError, WalletDuplicateError, WalletNotFoundError
from .util import b58_to_bytes, bytes_to_b58
from .key_type import KeyType
from .util import b58_to_bytes, bytes_to_b58, random_seed


class InMemoryWallet(BaseWallet):
Expand Down Expand Up @@ -363,6 +363,55 @@ def _get_private_key(self, verkey: str) -> bytes:

raise WalletError("Private key not found for verkey: {}".format(verkey))

async def get_public_did(self) -> DIDInfo:
"""
Retrieve the public DID.
Returns:
The currently public `DIDInfo`, if any
"""

dids = await self.get_local_dids()
for info in dids:
if info.metadata.get("public"):
return info

return None

async def set_public_did(self, did: Union[str, DIDInfo]) -> DIDInfo:
"""
Assign the public DID.
Returns:
The updated `DIDInfo`
"""

if isinstance(did, str):
# will raise an exception if not found
info = await self.get_local_did(did)
else:
info = did
did = info.did

if info.method != DIDMethod.SOV:
raise WalletError("Setting public DID is only allowed for did:sov DIDs")

public = await self.get_public_did()
if public and public.did == did:
info = public
else:
if public:
metadata = {**public.metadata, **DIDPosture.POSTED.metadata}
await self.replace_local_did_metadata(public.did, metadata)

metadata = {**info.metadata, **DIDPosture.PUBLIC.metadata}
await self.replace_local_did_metadata(did, metadata)
info = await self.get_local_did(did)

return info

async def sign_message(
self, message: Union[List[bytes], bytes], from_verkey: str
) -> bytes:
Expand Down
Loading

0 comments on commit 0fb2c4f

Please sign in to comment.