Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Rewrite store_server_verify_key to store several keys at once #5234

Merged
merged 1 commit into from
May 23, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5234.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Rewrite store_server_verify_key to store several keys at once.
59 changes: 14 additions & 45 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
@@ -453,10 +453,11 @@ def get_server_verify_key_v2_indirect(
raise_from(KeyLookupError("Remote server returned an error"), e)

keys = {}
added_keys = []

responses = query_response["server_keys"]
time_now_ms = self.clock.time_msec()

for response in responses:
for response in query_response["server_keys"]:
if (
u"signatures" not in response
or perspective_name not in response[u"signatures"]
@@ -492,21 +493,13 @@ def get_server_verify_key_v2_indirect(
)
server_name = response["server_name"]

added_keys.extend(
(server_name, key_id, key) for key_id, key in processed_response.items()
)
keys.setdefault(server_name, {}).update(processed_response)

yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store_keys,
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
)
for server_name, response_keys in keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
yield self.store.store_server_verify_keys(
perspective_name, time_now_ms, added_keys
)

defer.returnValue(keys)
@@ -519,6 +512,7 @@ def get_server_verify_key_v2_direct(self, server_name, key_ids):
if requested_key_id in keys:
continue

time_now_ms = self.clock.time_msec()
try:
response = yield self.client.get_json(
destination=server_name,
@@ -548,12 +542,13 @@ def get_server_verify_key_v2_direct(self, server_name, key_ids):
requested_ids=[requested_key_id],
response_json=response,
)

yield self.store.store_server_verify_keys(
server_name,
time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()),
)
keys.update(response_keys)

yield self.store_keys(
server_name=server_name, from_server=server_name, verify_keys=keys
)
defer.returnValue({server_name: keys})

@defer.inlineCallbacks
@@ -650,32 +645,6 @@ def process_v2_response(self, from_server, response_json, requested_ids=[]):

defer.returnValue(response_keys)

def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
server_name(str): The name of the server the keys are for.
from_server(str): The server the keys were downloaded from.
verify_keys(dict): A mapping of key_id to VerifyKey.
Returns:
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
return logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_verify_key,
server_name,
server_name,
key.time_added,
key,
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)


@defer.inlineCallbacks
def _handle_key_deferred(verify_request):
65 changes: 39 additions & 26 deletions synapse/storage/keys.py
Original file line number Diff line number Diff line change
@@ -84,38 +84,51 @@ def _txn(txn):

return self.runInteraction("get_server_verify_keys", _txn)

def store_server_verify_key(
self, server_name, from_server, time_now_ms, verify_key
):
"""Stores a NACL verification key for the given server.
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
"""Stores NACL verification keys for remote servers.
Args:
server_name (str): The name of the server.
from_server (str): Where the verification key was looked up
time_now_ms (int): The time now in milliseconds
verify_key (nacl.signing.VerifyKey): The NACL verify key.
from_server (str): Where the verification keys were looked up
ts_added_ms (int): The time to record that the key was added
verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]):
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
key_id = "%s:%s" % (verify_key.alg, verify_key.version)

# XXX fix this to not need a lock (#3819)
def _txn(txn):
self._simple_upsert_txn(
txn,
table="server_signature_keys",
keyvalues={"server_name": server_name, "key_id": key_id},
values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
"verify_key": db_binary_type(verify_key.encode()),
},
key_values = []
value_values = []
invalidations = []
for server_name, key_id, verify_key in verify_keys:
key_values.append((server_name, key_id))
value_values.append(
(
from_server,
ts_added_ms,
db_binary_type(verify_key.encode()),
)
)
# invalidate takes a tuple corresponding to the params of
# _get_server_verify_key. _get_server_verify_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
txn.call_after(
self._get_server_verify_key.invalidate, ((server_name, key_id),)
)

return self.runInteraction("store_server_verify_key", _txn)
invalidations.append((server_name, key_id))

def _invalidate(res):
f = self._get_server_verify_key.invalidate
for i in invalidations:
f((i, ))
return res

return self.runInteraction(
"store_server_verify_keys",
self._simple_upsert_many_txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
value_names=(
"from_server",
"ts_added_ms",
"verify_key",
),
value_values=value_values,
).addCallback(_invalidate)

def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
14 changes: 12 additions & 2 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
@@ -192,8 +192,18 @@ def test_verify_json_for_server(self):
kr = keyring.Keyring(self.hs)

key1 = signedjson.key.generate_signing_key(1)
r = self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
key1_id = "%s:%s" % (key1.alg, key1.version)

r = self.hs.datastore.store_server_verify_keys(
"server9",
time.time() * 1000,
[
(
"server9",
key1_id,
signedjson.key.get_verify_key(key1),
),
],
)
self.get_success(r)
json1 = {}
44 changes: 30 additions & 14 deletions tests/storage/test_keys.py
Original file line number Diff line number Diff line change
@@ -31,23 +31,32 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_verify_keys(self):
store = self.hs.get_datastore()

d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
self.get_success(d)
d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:KEY_ID_2"
d = store.store_server_verify_keys(
"from_server",
10,
[
("server1", key_id_1, KEY_1),
("server1", key_id_2, KEY_2),
],
)
self.get_success(d)

d = store.get_server_verify_keys(
[
("server1", "ed25519:key1"),
("server1", "ed25519:key2"),
("server1", "ed25519:key3"),
]
[("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
)
res = self.get_success(d)

self.assertEqual(len(res.keys()), 3)
self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
res1 = res[("server1", key_id_1)]
self.assertEqual(res1, KEY_1)
self.assertEqual(res1.version, "key1")

res2 = res[("server1", key_id_2)]
self.assertEqual(res2, KEY_2)
# version comes from the ID it was stored with
self.assertEqual(res2.version, "KEY_ID_2")

# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -60,9 +69,14 @@ def test_cache(self):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"

d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
self.get_success(d)
d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
d = store.store_server_verify_keys(
"from_server",
0,
[
("srv1", key_id_1, KEY_1),
("srv1", key_id_2, KEY_2),
],
)
self.get_success(d)

d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
@@ -81,7 +95,9 @@ def test_cache(self):
new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2")
)
d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
d = store.store_server_verify_keys(
"from_server", 10, [("srv1", key_id_2, new_key_2)]
)
self.get_success(d)

d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])