Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

look up cross-signing keys from the DB in bulk #6486

Merged
merged 16 commits into from
Dec 12, 2019
Merged
1 change: 1 addition & 0 deletions changelog.d/6486.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve performance of looking up cross-signing keys.
29 changes: 21 additions & 8 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def do_remote_query(destination):

return ret

@defer.inlineCallbacks
def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database

Expand All @@ -283,14 +284,26 @@ def get_cross_signing_keys_from_cache(self, query, from_user_id):
self_signing_keys = {}
user_signing_keys = {}

# Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
return defer.succeed(
{
"master_keys": master_keys,
"self_signing_keys": self_signing_keys,
"user_signing_keys": user_signing_keys,
}
)
user_ids = list(query)

keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
uhoreg marked this conversation as resolved.
Show resolved Hide resolved

for user_id, user_info in keys.items():
if "master" in user_info:
master_keys[user_id] = user_info["master"]
if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"]

if from_user_id in keys and "user_signing" in keys[from_user_id]:
# users can see other users' master and self-signing keys, but can
# only see their own user-signing keys
user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]

return {
"master_keys": master_keys,
"self_signing_keys": self_signing_keys,
"user_signing_keys": user_signing_keys,
}

@trace
@defer.inlineCallbacks
Expand Down
211 changes: 207 additions & 4 deletions synapse/storage/data_stores/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from six import iteritems

from canonicaljson import encode_canonical_json, json

from twisted.enterprise.adbapi import Connection
from twisted.internet import defer

from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import cached, cachedList


class EndToEndKeyWorkerStore(SQLBaseStore):
Expand Down Expand Up @@ -271,7 +273,7 @@ def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=No
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_id (str): the user whose key is being requested
key_type (str): the type of key that is being set: either 'master'
key_type (str): the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
Expand Down Expand Up @@ -316,8 +318,10 @@ def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
"""Returns a user's cross-signing key.

Args:
user_id (str): the user whose self-signing key is being requested
key_type (str): the type of cross-signing key to get
user_id (str): the user whose key is being requested
key_type (str): the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
the self-signing key will be included in the result

Expand All @@ -332,6 +336,201 @@ def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
from_user_id,
)

@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id):
    """Placeholder method; it performs no work and returns None.

    It exists only so that a per-user cache is created under this name,
    which _get_bare_e2e_cross_signing_keys_bulk references through its
    ``cached_method_name`` argument.
    """
uhoreg marked this conversation as resolved.
Show resolved Hide resolved

@cachedList(
    cached_method_name="_get_bare_e2e_cross_signing_keys",
    list_name="user_ids",
    num_args=1,
)
def _get_bare_e2e_cross_signing_keys_bulk(self, user_ids: list) -> dict:
    """Returns the cross-signing keys for a set of users.  The output of this
    function should be passed to _get_e2e_cross_signing_signatures_txn if
    the signatures for the calling user need to be fetched.

    The lookup is delegated to _get_bare_e2e_cross_signing_keys_bulk_txn,
    run as a database interaction; this method itself takes no transaction
    argument.

    Args:
        user_ids (list[str]): the users whose keys are being requested

    Returns:
        dict[str, dict[str, dict]]: mapping from user ID to key type to key
            data.  If a user's cross-signing keys were not found, their user
            ID will not be in the dict.

    """
    # NOTE(review): a reviewer suggested tightening the annotations on this
    # method to ``List[str]`` and ``Dict[str, Dict[str, dict]]``; doing so
    # would require importing ``List`` and ``Dict`` from ``typing``.
    return self.db.runInteraction(
        "get_bare_e2e_cross_signing_keys_bulk",
        self._get_bare_e2e_cross_signing_keys_bulk_txn,
        user_ids,
    )

def _get_bare_e2e_cross_signing_keys_bulk_txn(
self, txn: Connection, user_ids: list,
) -> dict:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.

Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_ids (list[str]): the users whose keys are being requested

Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, their user
ID will not be in the dict.

"""
result = {}

batch_size = 100
chunks = [
user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
]
for user_chunk in chunks:
sql = """
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype)
WHERE k.user_id IN (%s)
""" % (
",".join("?" for u in user_chunk),
)
query_params = []
query_params.extend(user_chunk)

txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn)

for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
key = json.loads(row["keydata"])
user_info = result.setdefault(user_id, {})
user_info[key_type] = key

return result

def _get_e2e_cross_signing_signatures_txn(
self, txn: Connection, keys: dict, from_user_id: str,
) -> dict:
"""Returns the cross-signing signatures made by a user on a set of keys.

Args:
txn (twisted.enterprise.adbapi.Connection): db connection
keys (dict[str, dict[str, dict]]): a map of user ID to key type to
key data. This dict will be modified to add signatures.
from_user_id (str): fetch the signatures made by this user

Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. The return value will be the same as the keys argument,
with the modifications included.
"""

# find out what cross-signing keys (a.k.a. devices) we need to get
# signatures for. This is a map of (user_id, device_id) to key type
# (device_id is the key's public part).
devices = {}

for user_id, user_info in keys.items():
for key_type, key in user_info.items():
device_id = None
for k in key["keys"].values():
device_id = k
devices[(user_id, device_id)] = key_type

device_list = list(devices)

# split into batches
batch_size = 100
chunks = [
device_list[i : i + batch_size]
for i in range(0, len(device_list), batch_size)
]
for user_chunk in chunks:
sql = """
SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures
WHERE user_id = ?
AND (%s)
""" % (
" OR ".join(
"(target_user_id = ? AND target_device_id = ?)" for d in devices
)
)
query_params = [from_user_id]
for item in devices:
# item is a (user_id, device_id) tuple
query_params.extend(item)

txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn)

# and add the signatures to the appropriate keys
for row in rows:
key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
# need to recursively copy the dicts that will be modified.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
user_info = keys[target_user_id] = keys[target_user_id].copy()
target_user_key = user_info[key_type] = user_info[key_type].copy()
if "signatures" in target_user_key:
signatures = target_user_key["signatures"] = target_user_key[
"signatures"
].copy()
if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id]
user_sigs[key_id] = row["signature"]
else:
signatures[from_user_id] = {key_id: row["signature"]}
else:
target_user_key["signatures"] = {
from_user_id: {key_id: row["signature"]}
}

return keys

@defer.inlineCallbacks
def get_e2e_cross_signing_keys_bulk(
    self, user_ids: list, from_user_id: str = None
) -> defer.Deferred:
    """Look up the cross-signing keys of several users at once.

    The bare keys are fetched through _get_bare_e2e_cross_signing_keys_bulk.
    When from_user_id is given, a second database interaction adds the
    signatures that user has made on the returned keys.

    Args:
        user_ids (list[str]): the users whose keys are being requested
        from_user_id (str): if specified (and non-empty), signatures made by
            this user on the returned keys will be included in the result

    Returns:
        Deferred[dict[str, dict]]: map of user ID to key data.  If a user's
            cross-signing key was not found, their user ID will not be in
            the dict.
    """
    bare_keys = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)

    # no requesting user: nothing to add, hand back the bare keys
    if not from_user_id:
        return bare_keys

    signed_keys = yield self.db.runInteraction(
        "get_e2e_cross_signing_signatures",
        self._get_e2e_cross_signing_signatures_txn,
        bare_keys,
        from_user_id,
    )
    return signed_keys

def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
Expand Down Expand Up @@ -520,6 +719,10 @@ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
},
)

self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)

def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.

Expand Down
2 changes: 1 addition & 1 deletion synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def __init__(
else:
self.function_to_call = orig

arg_spec = inspect.getargspec(orig)
arg_spec = inspect.getfullargspec(orig)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why ooi?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I was going to add a comment explaining it. It's because _get_bare_e2e_cross_signing_keys_bulk has type annotations, so getargspec doesn't work on it, and Python suggests using getfullargspec instead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aaaaaah, cool

all_args = arg_spec.args

if "cache_context" in all_args:
Expand Down
8 changes: 0 additions & 8 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,6 @@ def test_replace_master_key(self):
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})

test_replace_master_key.skip = (
"Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
)

@defer.inlineCallbacks
def test_reupload_signatures(self):
"""re-uploading a signature should not fail"""
Expand Down Expand Up @@ -507,7 +503,3 @@ def test_upload_signatures(self):
],
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)

test_upload_signatures.skip = (
"Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
)