Skip to content

Commit

Permalink
update get/set_public_did to use a storage record pointer
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Whitehead <[email protected]>
  • Loading branch information
andrewwhitehead committed Jun 17, 2021
1 parent 4c61a25 commit b911fe3
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 95 deletions.
71 changes: 13 additions & 58 deletions aries_cloudagent/wallet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ..ledger.endpoint_type import EndpointType
from .error import WalletError

from .did_posture import DIDPosture
from .did_info import DIDInfo, KeyInfo
from .key_type import KeyType
from .did_method import DIDMethod
Expand Down Expand Up @@ -122,8 +121,6 @@ async def create_public_did(
"""
Create and store a new public DID.
Implicitly flags all other dids as not public.
Args:
seed: Optional seed to use for DID
did: The DID to use
Expand All @@ -133,73 +130,33 @@ async def create_public_did(
The created `DIDInfo`
"""
if method != DIDMethod.SOV:
raise WalletError("Creating public did is only allowed for did:sov dids")

# validate key_type
if not method.supports_key_type(key_type):
raise WalletError(
f"Invalid key type {key_type.key_type} for method {method.method_name}"
)

metadata = DIDPosture.PUBLIC.metadata
dids = await self.get_local_dids()
for info in dids:
info_meta = info.metadata
info_meta["public"] = False
await self.replace_local_did_metadata(info.did, info_meta)
return await self.create_local_did(
metadata = metadata or {}
metadata.setdefault("posted", True)
did_info = await self.create_local_did(
method=method, key_type=key_type, seed=seed, did=did, metadata=metadata
)
return await self.set_public_did(did_info)

@abstractmethod
async def get_public_did(self) -> DIDInfo:
"""
Retrieve the public DID.
Returns:
The created `DIDInfo`
The currently public `DIDInfo`, if any
"""

dids = await self.get_local_dids()
for info in dids:
if info.metadata.get("public"):
return info

return None

async def set_public_did(self, did: str) -> DIDInfo:
@abstractmethod
async def set_public_did(self, did: Union[str, DIDInfo]) -> DIDInfo:
"""
Assign the public DID.
Returns:
The created `DIDInfo`
The updated `DIDInfo`
"""

did_info = await self.get_local_did(did)
if did_info.method != DIDMethod.SOV:
raise WalletError("Setting public did is only allowed for did:sov dids")

# will raise an exception if not found
info = None if did is None else await self.get_local_did(did)

public = await self.get_public_did()
if public and info and public.did == info.did:
info = public
else:
if public:
metadata = public.metadata.copy()
del metadata["public"]
await self.replace_local_did_metadata(public.did, metadata)

if info:
metadata = {**info.metadata, **DIDPosture.PUBLIC.metadata}
await self.replace_local_did_metadata(info.did, metadata)
info = await self.get_local_did(info.did)

return info

@abstractmethod
async def get_local_dids(self) -> Sequence[DIDInfo]:
"""
Expand Down Expand Up @@ -251,23 +208,21 @@ async def replace_local_did_metadata(self, did: str, metadata: dict):

async def get_posted_dids(self) -> Sequence[DIDInfo]:
"""
Get list of defined posted DIDs, excluding public DID.
Get list of defined posted DIDs.
Returns:
A list of `DIDInfo` instances
"""
return [
info
for info in await self.get_local_dids()
if info.metadata.get("posted") and not info.metadata.get("public")
info for info in await self.get_local_dids() if info.metadata.get("posted")
]

async def set_did_endpoint(
self,
did: str,
endpoint: str,
ledger: BaseLedger,
_ledger: BaseLedger,
endpoint_type: EndpointType = None,
):
"""
Expand All @@ -284,7 +239,7 @@ async def set_did_endpoint(
did_info = await self.get_local_did(did)

if did_info.method != DIDMethod.SOV:
raise WalletError("Setting did endpoint is only allowed for did:sov dids")
raise WalletError("Setting DID endpoint is only allowed for did:sov DIDs")
metadata = {**did_info.metadata}
if not endpoint_type:
endpoint_type = EndpointType.ENDPOINT
Expand Down
59 changes: 54 additions & 5 deletions aries_cloudagent/wallet/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from typing import List, Sequence, Tuple, Union

from ..core.in_memory import InMemoryProfile
from ..did.did_key import DIDKey

from .base import BaseWallet
from .did_info import KeyInfo, DIDInfo
from .crypto import (
create_keypair,
validate_seed,
Expand All @@ -15,12 +15,12 @@
encode_pack_message,
decode_pack_message,
)
from .key_type import KeyType
from .did_info import KeyInfo, DIDInfo
from .did_posture import DIDPosture
from .did_method import DIDMethod
from .util import random_seed
from ..did.did_key import DIDKey
from .error import WalletError, WalletDuplicateError, WalletNotFoundError
from .util import b58_to_bytes, bytes_to_b58
from .key_type import KeyType
from .util import b58_to_bytes, bytes_to_b58, random_seed


class InMemoryWallet(BaseWallet):
Expand Down Expand Up @@ -363,6 +363,55 @@ def _get_private_key(self, verkey: str) -> bytes:

raise WalletError("Private key not found for verkey: {}".format(verkey))

async def get_public_did(self) -> DIDInfo:
"""
Retrieve the public DID.
Returns:
The currently public `DIDInfo`, if any
"""

dids = await self.get_local_dids()
for info in dids:
if info.metadata.get("public"):
return info

return None

async def set_public_did(self, did: Union[str, DIDInfo]) -> DIDInfo:
"""
Assign the public DID.
Returns:
The updated `DIDInfo`
"""

if isinstance(did, str):
# will raise an exception if not found
info = await self.get_local_did(did)
else:
info = did
did = info.did

if info.method != DIDMethod.SOV:
raise WalletError("Setting public DID is only allowed for did:sov DIDs")

public = await self.get_public_did()
if public and public.did == did:
info = public
else:
if public:
metadata = {**public.metadata, **DIDPosture.POSTED.metadata}
await self.replace_local_did_metadata(public.did, metadata)

metadata = {**info.metadata, **DIDPosture.PUBLIC.metadata}
await self.replace_local_did_metadata(did, metadata)
info = await self.get_local_did(did)

return info

async def sign_message(
self, message: Union[List[bytes], bytes], from_verkey: str
) -> bytes:
Expand Down
102 changes: 95 additions & 7 deletions aries_cloudagent/wallet/indy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Indy implementation of BaseWallet interface."""

from aries_cloudagent.storage.record import StorageRecord
import json

from typing import List, Sequence, Tuple, Union
Expand All @@ -11,30 +12,34 @@

from indy.error import IndyError, ErrorCode

from ..did.did_key import DIDKey
from ..indy.sdk.error import IndyErrorHandler
from ..indy.sdk.wallet_setup import IndyOpenWallet
from ..ledger.base import BaseLedger
from ..ledger.endpoint_type import EndpointType
from ..ledger.error import LedgerConfigError
from ..storage.indy import IndySdkStorage
from ..storage.error import StorageDuplicateError, StorageNotFoundError

from ..did.did_key import DIDKey
from .base import BaseWallet
from .crypto import (
create_keypair,
sign_message,
validate_seed,
verify_signed_message,
)
from .key_type import KeyType
from .did_method import DIDMethod
from .key_pair import KeyPairStorageManager
from ..storage.indy import IndySdkStorage
from ..storage.error import StorageDuplicateError, StorageNotFoundError
from .did_info import DIDInfo, KeyInfo
from .did_method import DIDMethod
from .error import WalletError, WalletDuplicateError, WalletNotFoundError
from .key_pair import KeyPairStorageManager
from .key_type import KeyType
from .util import b58_to_bytes, bytes_to_b58, bytes_to_b64


RECORD_TYPE_CONFIG = "config"
RECORD_NAME_PUBLIC_DID = "default_public_did"


class IndySdkWallet(BaseWallet):
"""Indy identity wallet implementation."""

Expand Down Expand Up @@ -604,6 +609,89 @@ async def replace_local_did_metadata(self, did: str, metadata: dict):
verkey=did_info.verkey, metadata=metadata
)

async def get_public_did(self) -> DIDInfo:
"""
Retrieve the public DID.
Returns:
The currently public `DIDInfo`, if any
"""

public_did = None
public_info = None
storage = IndySdkStorage(self.opened)
try:
public = await storage.get_record(
RECORD_TYPE_CONFIG, RECORD_NAME_PUBLIC_DID
)
except StorageNotFoundError:
# populate public DID record
# this should only happen once, for an upgraded wallet
# the 'public' metadata flag is no longer used
dids = await self.get_local_dids()
for info in dids:
if info.metadata.get("public"):
public_did = info.did
public_info = info
break
try:
# even if public is not set, store a record
# to avoid repeated queries
await storage.add_record(
StorageRecord(
type=RECORD_TYPE_CONFIG,
id=RECORD_NAME_PUBLIC_DID,
value=json.dumps({"did": public_did}),
)
)
except StorageDuplicateError:
# another process stored the record first
pass
else:
public_did = json.loads(public.value)["did"]
if public_did:
try:
public_info = await self.get_local_did(public_did)
except WalletNotFoundError:
pass

return public_info

async def set_public_did(self, did: Union[str, DIDInfo]) -> DIDInfo:
"""
Assign the public DID.
Returns:
The updated `DIDInfo`
"""

if isinstance(did, str):
# will raise an exception if not found
info = await self.get_local_did(did)
else:
info = did

if info.method != DIDMethod.SOV:
raise WalletError("Setting public DID is only allowed for did:sov DIDs")

public = await self.get_public_did()
if not public or public.did != info.did:
storage = IndySdkStorage(self.opened)
await storage.update_record(
StorageRecord(
type=RECORD_TYPE_CONFIG,
id=RECORD_NAME_PUBLIC_DID,
value="{}",
),
value=json.dumps({"did": info.did}),
tags=None,
)
public = info

return public

async def set_did_endpoint(
self,
did: str,
Expand All @@ -624,7 +712,7 @@ async def set_did_endpoint(
"""
did_info = await self.get_local_did(did)
if did_info.method != DIDMethod.SOV:
raise WalletError("Setting did endpoint is only allowed for did:sov dids")
raise WalletError("Setting DID endpoint is only allowed for did:sov DIDs")

metadata = {**did_info.metadata}
if not endpoint_type:
Expand Down
6 changes: 3 additions & 3 deletions aries_cloudagent/wallet/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,9 @@ async def wallet_did_list(request: web.BaseRequest):
filter_posture = DIDPosture.get(request.query.get("posture"))
filter_key_type = KeyType.from_key_type(request.query.get("key_type"))
results = []
public_did_info = await wallet.get_public_did()
posted_did_infos = await wallet.get_posted_dids()

if filter_posture is DIDPosture.PUBLIC:
public_did_info = await wallet.get_public_did()
if (
public_did_info
and (not filter_verkey or public_did_info.verkey == filter_verkey)
Expand All @@ -221,6 +220,7 @@ async def wallet_did_list(request: web.BaseRequest):
results.append(format_did_info(public_did_info))
elif filter_posture is DIDPosture.POSTED:
results = []
posted_did_infos = await wallet.get_posted_dids()
for info in posted_did_infos:
if (
(not filter_verkey or info.verkey == filter_verkey)
Expand Down Expand Up @@ -401,7 +401,7 @@ async def wallet_set_public_did(request: web.BaseRequest):
raise web.HTTPNotFound(reason=f"DID {did} is not posted to the ledger")

did_info = await wallet.get_local_did(did)
info = await wallet.set_public_did(did)
info = await wallet.set_public_did(did_info)
if info:
# Publish endpoint if necessary
endpoint = did_info.metadata.get("endpoint")
Expand Down
Loading

0 comments on commit b911fe3

Please sign in to comment.