diff --git a/aries_cloudagent/messaging/credentials/manager.py b/aries_cloudagent/messaging/credentials/manager.py index 93c479a7c0..658441db54 100644 --- a/aries_cloudagent/messaging/credentials/manager.py +++ b/aries_cloudagent/messaging/credentials/manager.py @@ -490,7 +490,7 @@ async def store_credential( credential_definition, raw_credential, credential_exchange_record.credential_request_metadata, - credential_id=credential_id + credential_id=credential_id, ) credential = await holder.get_credential(credential_id) @@ -498,6 +498,13 @@ async def store_credential( credential_exchange_record.state = CredentialExchange.STATE_STORED credential_exchange_record.credential_id = credential_id credential_exchange_record.credential = credential + + # clear unnecessary data + credential_exchange_record.credential_offer = None + credential_exchange_record.credential_request = None + credential_exchange_record.raw_credential = None + # credential_request_metadata may be reused + await credential_exchange_record.save(self.context, reason="Store credential") credential_stored_message = CredentialStored() @@ -525,6 +532,12 @@ async def credential_stored(self, credential_stored_message: CredentialStored): }, ) + # clear unnecessary data + credential_exchange_record.credential_offer = None + credential_exchange_record.credential_request = None + credential_exchange_record.credential_request_metadata = None + credential_exchange_record.credential_values = None + credential_exchange_record.state = CredentialExchange.STATE_STORED await credential_exchange_record.save(self.context, reason="Credential stored") diff --git a/aries_cloudagent/messaging/models/base_record.py b/aries_cloudagent/messaging/models/base_record.py index ac8dde13ca..9eab004871 100644 --- a/aries_cloudagent/messaging/models/base_record.py +++ b/aries_cloudagent/messaging/models/base_record.py @@ -31,6 +31,7 @@ class Meta: LOG_STATE_FLAG = None CACHE_TTL = 60 CACHE_ENABLED = False + UNENCRYPTED_TAGS = () def __init__( self, @@ -83,7 +84,7 @@ def record_value(self) -> dict: @property def value(self) -> dict: """Accessor for the JSON record value generated for this record.""" - ret = self.tags + ret = self.strip_tag_prefix(self.tags) ret.update({"created_at": self.created_at, "updated_at": self.updated_at}) ret.update(self.record_value) return ret @@ -98,8 +99,9 @@ def tags(self) -> dict: """Accessor for the record tags generated for this record.""" tags = {"state": self.state} tags.update(self.record_tags) + unenc = self.UNENCRYPTED_TAGS or () # tag values must be non-empty - return {k: v for (k, v) in tags.items() if v} + return {(f"~{k}" if k in unenc else k): v for (k, v) in tags.items() if v} @classmethod def cache_key(cls, record_id: str, record_type: str = None): @@ -186,7 +188,7 @@ async def retrieve_by_id( result = await storage.get_record(cls.RECORD_TYPE, record_id) vals = json.loads(result.value) if result.tags: - vals.update(result.tags) + vals.update(cls.strip_tag_prefix(result.tags)) if cls.CACHE_ENABLED: await cls.set_cached_key(context, cache_key, vals) @@ -204,10 +206,10 @@ async def retrieve_by_tag_filter( """ storage: BaseStorage = await context.inject(BaseStorage) result = await storage.search_records( - cls.RECORD_TYPE, tag_filter + cls.RECORD_TYPE, cls.prefix_tag_filter(tag_filter) ).fetch_single() vals = json.loads(result.value) - vals.update(result.tags) + vals.update(cls.strip_tag_prefix(result.tags)) return cls.from_storage(result.id, vals) @classmethod @@ -221,11 +223,13 @@ async def query( tag_filter: An optional dictionary of tag filter clauses """ storage: BaseStorage = await context.inject(BaseStorage) - found = await storage.search_records(cls.RECORD_TYPE, tag_filter).fetch_all() + found = await storage.search_records( + cls.RECORD_TYPE, cls.prefix_tag_filter(tag_filter) + ).fetch_all() result = [] for record in found: vals = json.loads(record.value) - vals.update(record.tags) + vals.update(cls.strip_tag_prefix(record.tags)) result.append(cls.from_storage(record.id, vals)) return result @@ -358,6 +362,31 @@ def log_state( out += f" {k}: {v}\n" print(out, file=sys.stderr) + @classmethod + def strip_tag_prefix(cls, tags: dict): + """Strip tilde from unencrypted tag names.""" + return ( + {(k[1:] if "~" in k else k): v for (k, v) in tags.items()} if tags else {} + ) + + @classmethod + def prefix_tag_filter(cls, tag_filter: dict): + """Prefix unencrypted tags used in the tag filter.""" + ret = None + if tag_filter: + unenc = cls.UNENCRYPTED_TAGS or () + ret = {} + for k, v in tag_filter.items(): + if k in ("$or", "$and") and isinstance(v, list): + ret[k] = [cls.prefix_tag_filter(clause) for clause in v] + elif k == "$not" and isinstance(v, dict): + ret[k] = cls.prefix_tag_filter(v) + elif k in unenc: + ret[f"~{k}"] = v + else: + ret[k] = v + return ret + def __eq__(self, other: Any) -> bool: """Comparison between records.""" if type(other) is type(self): diff --git a/aries_cloudagent/messaging/models/tests/test_base_record.py b/aries_cloudagent/messaging/models/tests/test_base_record.py index b18c96dd4d..2012bfceb8 100644 --- a/aries_cloudagent/messaging/models/tests/test_base_record.py +++ b/aries_cloudagent/messaging/models/tests/test_base_record.py @@ -25,6 +25,10 @@ class Meta: model_class = BaseRecordImpl +class UnencTestImpl(BaseRecord): + UNENCRYPTED_TAGS = {"a", "b"} + + class TestBaseRecord(AsyncTestCase): def test_init_undef(self): with self.assertRaises(TypeError): @@ -180,3 +184,15 @@ async def test_webhook(self): topic = "topic" await record.send_webhook(context, payload, topic=topic) assert mock_responder.webhooks == [(topic, payload)] + + async def test_tag_prefix(self): + tags = {"~x": "a", "y": "b"} + assert UnencTestImpl.strip_tag_prefix(tags) == {"x": "a", "y": "b"} + + tags = {"a": "x", "b": "y", "c": "z"} + assert UnencTestImpl.prefix_tag_filter(tags) == {"~a": "x", "~b": "y", "c": "z"} + + tags = {"$or": [{"a": "x"}, {"c": "z"}]} + assert UnencTestImpl.prefix_tag_filter(tags) == { + "$or": [{"~a": "x"}, {"c": "z"}] + } diff --git a/demo/requirements.txt b/demo/requirements.txt index c9fa262f0f..f2cdf42d3e 100644 --- a/demo/requirements.txt +++ b/demo/requirements.txt @@ -1,2 +1,3 @@ +asyncpg~=0.18.0 prompt_toolkit~=2.0.9 git+https://github.com/webpy/webpy.git#egg=web.py diff --git a/demo/runners/performance.py b/demo/runners/performance.py index 62afc6ff8f..69754e7f4d 100644 --- a/demo/runners/performance.py +++ b/demo/runners/performance.py @@ -94,6 +94,10 @@ def check_received_creds(self) -> (int, int): async def update_creds(self): await self.credential_event.wait() + async def set_tag_policy(self, cred_def_id, taggables): + req_body = {"taggables": taggables} + await self.admin_POST(f"/wallet/tag-policy/{cred_def_id}", req_body) + class FaberAgent(BaseAgent): def __init__(self, port: int, **kwargs): @@ -205,6 +209,7 @@ async def main(start_port: int, show_timing: bool = False, routing: bool = False with log_timer("Publish duration:"): await faber.publish_defs() + # await alice.set_tag_policy(faber.credential_definition_id, ["name"]) with log_timer("Connect duration:"): if routing: @@ -292,6 +297,15 @@ async def check_received(agent, issue_count, pb): avg = recv_timer.duration / issue_count alice.log(f"Average time per credential: {avg:.2f}s ({1/avg:.2f}/s)") + if alice.postgres: + await alice.collect_postgres_stats(str(issue_count) + " creds") + for line in alice.format_postgres_stats(): + alice.log(line) + if faber.postgres: + await faber.collect_postgres_stats(str(issue_count) + " creds") + for line in faber.format_postgres_stats(): + faber.log(line) + if show_timing: timing = await alice.fetch_timing() if timing: @@ -358,6 +372,8 @@ async def check_received(agent, issue_count, pb): require_indy() try: - asyncio.get_event_loop().run_until_complete(main(args.port, True, args.routing)) + asyncio.get_event_loop().run_until_complete( + main(args.port, False, args.routing) + ) except KeyboardInterrupt: os._exit(1) diff --git a/demo/runners/support/agent.py b/demo/runners/support/agent.py index 61dd9b6f85..a82f959a89 100644 --- a/demo/runners/support/agent.py +++ b/demo/runners/support/agent.py @@ -1,4 +1,5 @@ import asyncio +import asyncpg import functools import json import logging @@ -117,6 +118,7 @@ def __init__( ) self.wallet_key = params.get("wallet_key") or self.ident + rand_name self.did = None + self.wallet_stats = [] async def register_schema_and_creddef(self, schema_name, version, schema_attrs): # Create a schema @@ -168,29 +170,8 @@ def get_agent_args(self): result.extend( [ ("--wallet-storage-type", "postgres_storage"), - ( - "--wallet-storage-config", - json.dumps( - { - "url": f"{self.internal_host}:5432", - "tls": "None", - "max_connections": 5, - "min_idle_time": 0, - "connection_timeout": 10, - } - ), - ), - ( - "--wallet-storage-creds", - json.dumps( - { - "account": "postgres", - "password": "mysecretpassword", - "admin_account": "postgres", - "admin_password": "mysecretpassword", - } - ), - ), + ("--wallet-storage-config", json.dumps(self.postgres_config)), + ("--wallet-storage-creds", json.dumps(self.postgres_creds)), ] ) if self.webhook_url: @@ -426,3 +407,80 @@ def format_timing(self, timing: dict) -> dict: async def reset_timing(self): await self.admin_POST("/status/reset", text=True) + + @property + def postgres_config(self): + return { + "url": f"{self.internal_host}:5432", + "tls": "None", + "max_connections": 5, + "min_idle_time": 0, + "connection_timeout": 10, + } + + @property + def postgres_creds(self): + return { + "account": "postgres", + "password": "mysecretpassword", + "admin_account": "postgres", + "admin_password": "mysecretpassword", + } + + async def collect_postgres_stats(self, ident: str, vacuum_full: bool = True): + creds = self.postgres_creds + + conn = await asyncpg.connect( + host=self.internal_host, + port="5432", + user=creds["admin_account"], + password=creds["admin_password"], + database=self.wallet_name, + ) + + tables = ("items", "tags_encrypted", "tags_plaintext") + for t in tables: + await conn.execute(f"VACUUM FULL {t}" if vacuum_full else f"VACUUM {t}") + + sizes = await conn.fetch( + """ + SELECT relname AS "relation", + pg_size_pretty(pg_total_relation_size(C.oid)) AS "total_size" + FROM pg_class C + LEFT JOIN pg_namespace N ON (N.oid = C.relnamespace) + WHERE nspname = 'public' + ORDER BY pg_total_relation_size(C.oid) DESC; + """ + ) + results = {k: [0, "0B"] for k in tables} + for row in sizes: + if row["relation"] in results: + results[row["relation"]][1] = row["total_size"].replace(" ", "") + for t in tables: + row = await conn.fetchrow(f"""SELECT COUNT(*) AS "count" FROM {t}""") + results[t][0] = row["count"] + self.wallet_stats.append((ident, results)) + + await conn.close() + + def format_postgres_stats(self): + if not self.wallet_stats: + return + yield "{:30} | {:>17} | {:>17} | {:>17}".format( + f"{self.wallet_name} DB", "items", "tags_encrypted", "tags_plaintext" + ) + yield "=" * 90 + for ident, stats in self.wallet_stats: + yield "{:30} | {:8d} {:>8} | {:8d} {:>8} | {:8d} {:>8}".format( + ident, + stats["items"][0], + stats["items"][1], + stats["tags_encrypted"][0], + stats["tags_encrypted"][1], + stats["tags_plaintext"][0], + stats["tags_plaintext"][1], + ) + yield "" + + def reset_postgres_stats(self): + self.wallet_stats.clear()