Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bulk claim OTKs
Browse files Browse the repository at this point in the history
David Robertson committed Oct 28, 2023

Unverified

This user has not yet uploaded their public signing key.
1 parent e30ae68 commit da69538
Showing 1 changed file with 62 additions and 48 deletions.
110 changes: 62 additions & 48 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
@@ -1133,25 +1134,31 @@ async def claim_e2e_one_time_keys(
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
# allows us to use autocommit mode.
unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
for user_id, device_id, algorithm, count in query_list:
claim_rows = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
self._claim_e2e_one_time_key_returning,
user_id,
device_id,
algorithm,
count,
db_autocommit=True,
unfulfilled_claim_counts[user_id, device_id, algorithm] = count

bulk_claims = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
self._claim_e2e_one_time_keys_returning,
query_list,
db_autocommit=True,
)

for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
if claim_rows:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
for claim_row in claim_rows:
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
# Did we get enough OTKs?
count -= len(claim_rows)
if count:
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1

# Did we get enough OTKs?
for (
user_id,
device_id,
algorithm,
), count in unfulfilled_claim_counts.items():
if count > 0:
missing.append((user_id, device_id, algorithm, count))
else:
for user_id, device_id, algorithm, count in query_list:
@@ -1276,46 +1283,53 @@ def _claim_e2e_one_time_key_simple(
return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]

@trace
def _claim_e2e_one_time_key_returning(
def _claim_e2e_one_time_keys_returning(
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
query_list: Iterable[Tuple[str, str, str, int]],
) -> List[Tuple[str, str, str, str, str]]:
"""Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
Args:
query_list: Collection of tuples (user_id, device_id, algorithm, count)
as passed to claim_e2e_one_time_keys.
Returns:
A tuple of key name (algorithm + key ID) and key JSON, if an
OTK was found.
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
for each OTK claimed.
"""

# We can use RETURNING to do the fetch and DELETE in once step.
sql = """
DELETE FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
AND key_id IN (
SELECT key_id FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT ?
)
RETURNING key_id, key_json
"""

txn.execute(
sql,
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
WITH claims(user_id, device_id, algorithm, claim_count) AS (
VALUES ?
), ranked_keys AS (
SELECT
user_id, device_id, algorithm, key_id, claim_count,
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
FROM e2e_one_time_keys_json
JOIN claims USING (user_id, device_id, algorithm)
)
DELETE FROM e2e_one_time_keys_json k
WHERE (user_id, device_id, algorithm, key_id) IN (
SELECT user_id, device_id, algorithm, key_id
FROM ranked_keys
WHERE r <= claim_count
)
RETURNING user_id, device_id, algorithm, key_id, key_json;
"""
otk_rows = cast(
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
)
otk_rows = list(txn)
if not otk_rows:
return []

self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, _, _, _ in otk_rows:
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)

return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
return otk_rows


class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):

0 comments on commit da69538

Please sign in to comment.