From af75c086a0a4137beacf735dcb233f31cb3d7234 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 28 Aug 2020 16:44:23 +0100 Subject: [PATCH 1/5] Move `get_devices_with_keys_by_user` to `EndToEndKeyWorkerStore` this seems a better fit for it. This commit simply moves the existing code: no other changes at all. --- synapse/storage/databases/main/devices.py | 45 ----------------- .../storage/databases/main/end_to_end_keys.py | 48 ++++++++++++++++++- 2 files changed, 47 insertions(+), 46 deletions(-) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index def96637a261..710bfdfa17f8 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -481,51 +481,6 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict] device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id: str): - """Get all devices (with any device keys) for a user - - Returns: - Deferred which resolves to (stream_id, devices) - """ - return self.db_pool.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, - user_id, - ) - - def _get_devices_with_keys_by_user_txn( - self, txn: LoggingTransaction, user_id: str - ) -> Tuple[int, List[JsonDict]]: - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in user_devices.items(): - result = {"device_id": device_id} - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - async def get_users_whose_devices_changed( self, from_key: str, user_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index af0b85e2c92c..c686a6b06187 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -22,7 +22,8 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -32,6 +33,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore): + def get_devices_with_keys_by_user(self, user_id: str): + """Get all devices (with any device keys) for a user + + Returns: + Deferred which resolves to (stream_id, devices) + """ + return self.db_pool.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, + user_id, + ) + + def _get_devices_with_keys_by_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.items(): + result = {"device_id": device_id} + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + @trace async def get_e2e_device_keys( self, query_list, include_all_devices=False, include_deleted_devices=False From 80db26f2af55b91ef055d625ea92c9690ccf9e85 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 28 Aug 2020 17:55:39 +0100 Subject: [PATCH 2/5] Rename `get_devices_with_keys_by_user` to better reflect what it does. --- synapse/handlers/device.py | 4 +++- synapse/storage/databases/main/end_to_end_keys.py | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index db417d60deb4..ee4666337a62 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -234,7 +234,9 @@ async def get_user_ids_changed(self, user_id, from_token): return result async def on_federation_query_user_devices(self, user_id): - stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id) + stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query( + user_id + ) master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") self_signing_key = await self.store.get_e2e_cross_signing_key( user_id, "self_signing" diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c686a6b06187..f81064ced7b3 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -33,19 +33,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore): - def get_devices_with_keys_by_user(self, user_id: str): + def get_e2e_device_keys_for_federation_query(self, user_id: str): """Get all devices (with any device keys) for a user Returns: Deferred which resolves to (stream_id, devices) """ return self.db_pool.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, + "get_e2e_device_keys_for_federation_query", + self._get_e2e_device_keys_for_federation_query_txn, user_id, ) - def _get_devices_with_keys_by_user_txn( + def _get_e2e_device_keys_for_federation_query( self, txn: LoggingTransaction, user_id: str ) -> Tuple[int, List[JsonDict]]: now_stream_id = self._device_list_id_gen.get_current_token() From a39fe73d485e9b88110595e3f5e4acbc7d09b732 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 28 Aug 2020 16:51:02 +0100 Subject: [PATCH 3/5] get_device_stream_token abstract method To avoid referencing fields which are declared in the derived classes, make `get_device_stream_token` abstract, and define that in the classes which define `_device_list_id_gen`. --- synapse/replication/slave/storage/devices.py | 3 +++ synapse/storage/databases/main/__init__.py | 3 +++ synapse/storage/databases/main/devices.py | 7 +++++-- synapse/storage/databases/main/end_to_end_keys.py | 8 +++++++- 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 596c72eb92af..3b788c96250d 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,6 +48,9 @@ def __init__(self, database: DatabasePool, db_conn, hs): "DeviceListFederationStreamChangeCache", device_list_max ) + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 70cf15dd7f46..e6536c8456de 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -264,6 +264,9 @@ def __init__(self, database: DatabasePool, db_conn, hs): # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 710bfdfa17f8..e8379c73c460 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging from typing import Any, Dict, Iterable, List, Optional, Set, Tuple @@ -101,7 +102,7 @@ async def get_device_updates_by_remote( update included in the response), and the list of updates, where each update is a pair of EDU type and EDU contents. """ - now_stream_id = self._device_list_id_gen.get_current_token() + now_stream_id = self.get_device_stream_token() has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) @@ -412,8 +413,10 @@ def _add_user_signature_change_txn( }, ) + @abc.abstractmethod def get_device_stream_token(self) -> int: - return self._device_list_id_gen.get_current_token() + """Get the current stream id from the _device_list_id_gen""" + ... @trace async def get_user_devices_from_cache( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f81064ced7b3..b7628baefd22 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -48,7 +49,7 @@ def get_e2e_device_keys_for_federation_query(self, user_id: str): def _get_e2e_device_keys_for_federation_query( self, txn: LoggingTransaction, user_id: str ) -> Tuple[int, List[JsonDict]]: - now_stream_id = self._device_list_id_gen.get_current_token() + now_stream_id = self.get_device_stream_token() devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) @@ -587,6 +588,11 @@ def _get_all_user_signature_changes_for_remotes_txn(txn): _get_all_user_signature_changes_for_remotes_txn, ) + @abc.abstractmethod + def get_device_stream_token(self) -> int: + """Get the current stream id from the _device_list_id_gen""" + ... + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): From 2969d1cbd970babfbb3108f3241e1ebd3e97f356 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 28 Aug 2020 17:57:26 +0100 Subject: [PATCH 4/5] changelog --- changelog.d/8204.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/8204.misc diff --git a/changelog.d/8204.misc b/changelog.d/8204.misc new file mode 100644 index 000000000000..979c8b227bbc --- /dev/null +++ b/changelog.d/8204.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. From b48168c4975c59c7451bc880122f482a5590706a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 28 Aug 2020 18:18:45 +0100 Subject: [PATCH 5/5] Update synapse/storage/databases/main/end_to_end_keys.py Co-authored-by: Patrick Cloke --- synapse/storage/databases/main/end_to_end_keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index b7628baefd22..706f17800d96 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -46,7 +46,7 @@ def get_e2e_device_keys_for_federation_query(self, user_id: str): user_id, ) - def _get_e2e_device_keys_for_federation_query( + def _get_e2e_device_keys_for_federation_query_txn( self, txn: LoggingTransaction, user_id: str ) -> Tuple[int, List[JsonDict]]: now_stream_id = self.get_device_stream_token()