diff --git a/aries_cloudagent/admin/routes.py b/aries_cloudagent/admin/routes.py index b14014fc5d..02e1031f8d 100644 --- a/aries_cloudagent/admin/routes.py +++ b/aries_cloudagent/admin/routes.py @@ -14,6 +14,7 @@ from ..messaging.basicmessage.routes import register as register_basicmessages from ..messaging.discovery.routes import register as register_discovery from ..messaging.trustping.routes import register as register_trustping +from ..wallet.routes import register as register_wallet async def register_module_routes(app: web.Application): @@ -33,3 +34,4 @@ async def register_module_routes(app: web.Application): await register_basicmessages(app) await register_discovery(app) await register_trustping(app) + await register_wallet(app) diff --git a/aries_cloudagent/conductor.py b/aries_cloudagent/conductor.py index 0ee2df5e02..6116a60ce9 100644 --- a/aries_cloudagent/conductor.py +++ b/aries_cloudagent/conductor.py @@ -135,11 +135,7 @@ async def setup(self): # at the class level (!) should not be performed multiple times collector.wrap( ConnectionManager, - ( - "get_connection_target", - "fetch_did_document", - "find_connection", - ), + ("get_connection_target", "fetch_did_document", "find_connection"), ) async def start(self) -> None: diff --git a/aries_cloudagent/wallet/base.py b/aries_cloudagent/wallet/base.py index 98dbb10d97..cd1820be75 100644 --- a/aries_cloudagent/wallet/base.py +++ b/aries_cloudagent/wallet/base.py @@ -4,6 +4,7 @@ from collections import namedtuple from typing import Sequence + KeyInfo = namedtuple("KeyInfo", "verkey metadata") DIDInfo = namedtuple("DIDInfo", "did verkey metadata") @@ -184,6 +185,35 @@ async def get_public_did(self) -> DIDInfo: return None + async def set_public_did(self, did: str) -> DIDInfo: + """ + Assign the public did. + + Returns: + The created `DIDInfo` + + """ + + # will raise an exception if not found + info = None if did is None else await self.get_local_did(did) + + public = await self.get_public_did() + if public and info and public.did == info.did: + info = public + else: + if public: + metadata = public.metadata.copy() + del metadata["public"] + await self.replace_local_did_metadata(public.did, metadata) + + if info: + metadata = info.metadata.copy() + metadata["public"] = True + await self.replace_local_did_metadata(info.did, metadata) + info = await self.get_local_did(info.did) + + return info + @abstractmethod async def get_local_dids(self) -> Sequence[DIDInfo]: """ diff --git a/aries_cloudagent/wallet/routes.py b/aries_cloudagent/wallet/routes.py new file mode 100644 index 0000000000..89ef57f52e --- /dev/null +++ b/aries_cloudagent/wallet/routes.py @@ -0,0 +1,215 @@ +"""Wallet admin routes.""" + +from aiohttp import web +from aiohttp_apispec import docs, response_schema + +from marshmallow import fields, Schema + +from ..ledger.base import BaseLedger + +from .base import DIDInfo, BaseWallet +from .error import WalletError + + +class DIDSchema(Schema): + """Result schema for a DID.""" + + did = fields.Str() + verkey = fields.Str() + public = fields.Bool() + + +class DIDResultSchema(Schema): + """Result schema for a DID.""" + + result = fields.Nested(DIDSchema()) + + +class DIDListSchema(Schema): + """Result schema for connection list.""" + + results = fields.List(fields.Nested(DIDSchema())) + + +def format_did_info(info: DIDInfo): + """Serialize a DIDInfo object.""" + if info: + return { + "did": info.did, + "verkey": info.verkey, + "public": info.metadata + and info.metadata.get("public") + and "true" + or "false", + } + + +@docs( + tags=["wallet"], + summary="List wallet DIDs", + parameters=[ + {"name": "did", "in": "query", "schema": {"type": "string"}, "required": False}, + { + "name": "verkey", + "in": "query", + "schema": {"type": "string"}, + "required": False, + }, + { + "name": "public", + "in": "query", + "schema": {"type": "boolean"}, + "required": False, + }, + ], +) +@response_schema(DIDListSchema, 200) +async def wallet_did_list(request: web.BaseRequest): + """ + Request handler for searching wallet DIDs. + + Args: + request: aiohttp request object + + Returns: + The DID list response + + """ + context = request.app["request_context"] + wallet: BaseWallet = await context.inject(BaseWallet, required=False) + if not wallet: + raise web.HTTPForbidden() + filter_did = request.query.get("did") + filter_verkey = request.query.get("verkey") + filter_public = request.query.get("public") + results = [] + + if filter_public == "true": + info = await wallet.get_public_did() + if ( + info + and (not filter_verkey or info.verkey == filter_verkey) + and (not filter_did or info.did == filter_did) + ): + results.append(format_did_info(info)) + elif filter_did: + try: + info = await wallet.get_local_did(filter_did) + except WalletError: + # badly formatted DID or record not found + info = None + if info and (not filter_verkey or info.verkey == filter_verkey): + results.append(format_did_info(info)) + elif filter_verkey: + try: + info = await wallet.get_local_did_for_verkey(filter_verkey) + except WalletError: + info = None + if info: + results.append(format_did_info(info)) + else: + dids = await wallet.get_local_dids() + results = [] + for info in dids: + results.append(format_did_info(info)) + + results.sort(key=lambda info: info["did"]) + return web.json_response({"results": results}) + + +@docs(tags=["wallet"], summary="Create a local DID") +@response_schema(DIDResultSchema, 200) +async def wallet_create_did(request: web.BaseRequest): + """ + Request handler for creating a new wallet DID. + + Args: + request: aiohttp request object + + Returns: + The DID list response + + """ + context = request.app["request_context"] + wallet: BaseWallet = await context.inject(BaseWallet, required=False) + if not wallet: + raise web.HTTPForbidden() + info = await wallet.create_local_did() + return web.json_response({"result": format_did_info(info)}) + + +@docs(tags=["wallet"], summary="Fetch the current public DID") +@response_schema(DIDResultSchema, 200) +async def wallet_get_public_did(request: web.BaseRequest): + """ + Request handler for fetching the current public DID. + + Args: + request: aiohttp request object + + Returns: + The DID list response + + """ + context = request.app["request_context"] + wallet: BaseWallet = await context.inject(BaseWallet, required=False) + if not wallet: + raise web.HTTPForbidden() + info = await wallet.get_public_did() + return web.json_response({"result": format_did_info(info)}) + + +@docs( + tags=["wallet"], + summary="Assign the current public DID", + parameters=[ + {"name": "did", "in": "query", "schema": {"type": "string"}, "required": True} + ], +) +@response_schema(DIDResultSchema, 200) +async def wallet_set_public_did(request: web.BaseRequest): + """ + Request handler for setting the current public DID. + + Args: + request: aiohttp request object + + Returns: + The updated DID info + + """ + context = request.app["request_context"] + wallet: BaseWallet = await context.inject(BaseWallet, required=False) + if not wallet: + raise web.HTTPForbidden() + did = request.query.get("did") + if not did: + raise web.HTTPBadRequest() + try: + info = await wallet.get_local_did(did) + except WalletError: + # DID not found or not in valid format + raise web.HTTPBadRequest() + info = await wallet.set_public_did(did) + if info: + # Publish endpoint if necessary + endpoint = context.settings.get("default_endpoint") + ledger = await context.inject(BaseLedger, required=False) + if ledger: + async with ledger: + await ledger.update_endpoint_for_did(info.did, endpoint) + + return web.json_response({"result": format_did_info(info)}) + + +async def register(app: web.Application): + """Register routes.""" + + app.add_routes( + [ + web.get("/wallet/did", wallet_did_list), + web.post("/wallet/did/create", wallet_create_did), + web.get("/wallet/did/public", wallet_get_public_did), + web.post("/wallet/did/public", wallet_set_public_did), + ] + ) diff --git a/aries_cloudagent/wallet/tests/test_routes.py b/aries_cloudagent/wallet/tests/test_routes.py new file mode 100644 index 0000000000..76593c6ec7 --- /dev/null +++ b/aries_cloudagent/wallet/tests/test_routes.py @@ -0,0 +1,137 @@ +from asynctest import TestCase as AsyncTestCase +from asynctest import mock as async_mock +import pytest + +from aiohttp.web import HTTPForbidden + +from ...config.injection_context import InjectionContext +from ...wallet.base import BaseWallet, DIDInfo + +from .. import routes as test_module + + +class TestWalletRoutes(AsyncTestCase): + def setUp(self): + self.context = InjectionContext(enforce_typing=False) + self.wallet = async_mock.create_autospec(BaseWallet) + self.context.injector.bind_instance(BaseWallet, self.wallet) + self.app = { + "outbound_message_router": async_mock.CoroutineMock(), + "request_context": self.context, + } + self.test_did = "did" + self.test_verkey = "verkey" + + async def test_missing_wallet(self): + request = async_mock.MagicMock() + request.app = self.app + self.context.injector.clear_binding(BaseWallet) + + with self.assertRaises(HTTPForbidden): + await test_module.wallet_create_did(request) + + with self.assertRaises(HTTPForbidden): + await test_module.wallet_did_list(request) + + with self.assertRaises(HTTPForbidden): + await test_module.wallet_get_public_did(request) + + with self.assertRaises(HTTPForbidden): + await test_module.wallet_set_public_did(request) + + def test_format_did_info(self): + did_info = DIDInfo(self.test_did, self.test_verkey, {}) + result = test_module.format_did_info(did_info) + assert ( + result["did"] == self.test_did + and result["verkey"] == self.test_verkey + and result["public"] == "false" + ) + did_info = DIDInfo(self.test_did, self.test_verkey, {"public": True}) + result = test_module.format_did_info(did_info) + assert result["public"] == "true" + + async def test_create_did(self): + request = async_mock.MagicMock() + request.app = self.app + with async_mock.patch.object( + test_module.web, "json_response", async_mock.Mock() + ) as json_response, async_mock.patch.object( + test_module, "format_did_info", async_mock.Mock() + ) as format_did_info: + self.wallet.create_local_did.return_value = DIDInfo( + self.test_did, self.test_verkey, {} + ) + result = await test_module.wallet_create_did(request) + format_did_info.assert_called_once_with( + self.wallet.create_local_did.return_value + ) + json_response.assert_called_once_with( + {"result": format_did_info.return_value} + ) + assert result is json_response.return_value + + async def test_did_list(self): + request = async_mock.MagicMock() + request.app = self.app + request.query = {} + with async_mock.patch.object( + test_module.web, "json_response", async_mock.Mock() + ) as json_response, async_mock.patch.object( + test_module, "format_did_info", async_mock.Mock() + ) as format_did_info: + self.wallet.get_local_dids.return_value = [ + DIDInfo(self.test_did, self.test_verkey, {}) + ] + format_did_info.return_value = {"did": self.test_did} + result = await test_module.wallet_did_list(request) + format_did_info.assert_called_once_with( + self.wallet.get_local_dids.return_value[0] + ) + json_response.assert_called_once_with( + {"results": [format_did_info.return_value]} + ) + assert json_response.return_value is json_response() + assert result is json_response.return_value + + async def test_get_public_did(self): + request = async_mock.MagicMock() + request.app = self.app + with async_mock.patch.object( + test_module.web, "json_response", async_mock.Mock() + ) as json_response, async_mock.patch.object( + test_module, "format_did_info", async_mock.Mock() + ) as format_did_info: + self.wallet.get_public_did.return_value = DIDInfo( + self.test_did, self.test_verkey, {} + ) + result = await test_module.wallet_get_public_did(request) + format_did_info.assert_called_once_with( + self.wallet.get_public_did.return_value + ) + json_response.assert_called_once_with( + {"result": format_did_info.return_value} + ) + assert result is json_response.return_value + + async def test_set_public_did(self): + request = async_mock.MagicMock() + request.app = self.app + request.query = {"did": self.test_did} + with async_mock.patch.object( + test_module.web, "json_response", async_mock.Mock() + ) as json_response, async_mock.patch.object( + test_module, "format_did_info", async_mock.Mock() + ) as format_did_info: + self.wallet.get_public_did.return_value = DIDInfo( + self.test_did, self.test_verkey, {} + ) + result = await test_module.wallet_set_public_did(request) + self.wallet.set_public_did.assert_awaited_once_with(request.query["did"]) + format_did_info.assert_called_once_with( + self.wallet.set_public_did.return_value + ) + json_response.assert_called_once_with( + {"result": format_did_info.return_value} + ) + assert result is json_response.return_value