Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Modify StoreKeyFetcher to read from server_keys_json. #15417

Merged
merged 8 commits into from
Apr 20, 2023
Merged
1 change: 1 addition & 0 deletions changelog.d/15417.bugfix
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where cached key results which were directly fetched would not be properly re-used.
30 changes: 15 additions & 15 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,18 +150,19 @@ class Keyring:
def __init__(
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
):
self.clock = hs.get_clock()

if key_fetchers is None:
key_fetchers = (
# Fetch keys from the database.
StoreKeyFetcher(hs),
# Fetch keys from a configured Perspectives server.
PerspectivesKeyFetcher(hs),
# Fetch keys from the origin server directly.
ServerKeyFetcher(hs),
)
self._key_fetchers = key_fetchers
# Always fetch keys from the database.
mutable_key_fetchers: List[KeyFetcher] = [StoreKeyFetcher(hs)]
# Fetch keys from configured trusted key servers, if any exist.
key_servers = hs.config.key.key_servers
if key_servers:
mutable_key_fetchers.append(PerspectivesKeyFetcher(hs))
# Finally, fetch keys from the origin server directly.
mutable_key_fetchers.append(ServerKeyFetcher(hs))

self._key_fetchers: Iterable[KeyFetcher] = tuple(mutable_key_fetchers)
else:
self._key_fetchers = key_fetchers

self._fetch_keys_queue: BatchingQueue[
_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
Expand Down Expand Up @@ -510,7 +511,7 @@ async def _fetch_keys(
for key_id in queue_value.key_ids
)

res = await self.store.get_server_verify_keys(key_ids_to_fetch)
res = await self.store.get_server_keys_json(key_ids_to_fetch)
keys: Dict[str, Dict[str, FetchKeyResult]] = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
Expand All @@ -522,7 +523,6 @@ def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.store = hs.get_datastores().main
self.config = hs.config

async def process_v2_response(
self, from_server: str, response_json: JsonDict, time_added_ms: int
Expand Down Expand Up @@ -626,7 +626,7 @@ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
self.key_servers = self.config.key.key_servers
self.key_servers = hs.config.key.key_servers

async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
Expand Down Expand Up @@ -775,7 +775,7 @@ async def get_server_verify_key_v2_indirect(

keys.setdefault(server_name, {}).update(processed_response)

await self.store.store_server_verify_keys(
await self.store.store_server_signature_keys(
perspective_name, time_now_ms, added_keys
)

Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/key/v2/remote_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def query_keys(
for key_id in key_ids:
store_queries.append((server_name, key_id, None))

cached = await self.store.get_server_keys_json(store_queries)
cached = await self.store.get_server_keys_json_for_remote(store_queries)

json_results: Set[bytes] = set()

Expand Down
99 changes: 87 additions & 12 deletions synapse/storage/databases/main/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# limitations under the License.

import itertools
import json
import logging
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
Expand All @@ -36,15 +38,16 @@ class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys"""

@cached()
def _get_server_verify_key(
def _get_server_signature_key(
self, server_name_and_key_id: Tuple[str, str]
) -> FetchKeyResult:
raise NotImplementedError()

@cachedList(
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
cached_method_name="_get_server_signature_key",
list_name="server_name_and_key_ids",
)
async def get_server_verify_keys(
async def get_server_signature_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], FetchKeyResult]:
"""
Expand All @@ -62,10 +65,12 @@ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""

# batch_iter always returns tuples so it's safe to do len(batch)
sql = (
"SELECT server_name, key_id, verify_key, ts_valid_until_ms "
"FROM server_signature_keys WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch)
sql = """
SELECT server_name, key_id, verify_key, ts_valid_until_ms
FROM server_signature_keys WHERE 1=0
""" + " OR (server_name=? AND key_id=?)" * len(
batch
)

txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))

Expand All @@ -89,9 +94,9 @@ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
_get_keys(txn, batch)
return keys

return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
return await self.db_pool.runInteraction("get_server_signature_keys", _txn)

async def store_server_verify_keys(
async def store_server_signature_keys(
self,
from_server: str,
ts_added_ms: int,
Expand Down Expand Up @@ -119,7 +124,7 @@ async def store_server_verify_keys(
)
)
# invalidate takes a tuple corresponding to the params of
# _get_server_verify_key. _get_server_verify_key only takes one
# _get_server_signature_key. _get_server_signature_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))

Expand All @@ -134,10 +139,10 @@ async def store_server_verify_keys(
"verify_key",
),
value_values=value_values,
desc="store_server_verify_keys",
desc="store_server_signature_keys",
)

invalidate = self._get_server_verify_key.invalidate
invalidate = self._get_server_signature_key.invalidate
for i in invalidations:
invalidate((i,))

Expand Down Expand Up @@ -180,16 +185,86 @@ async def store_server_keys_json(
desc="store_server_keys_json",
)

# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
self._get_server_keys_json.invalidate((((server_name, key_id),)))

@cached()
def _get_server_keys_json(
self, server_name_and_key_id: Tuple[str, str]
) -> FetchKeyResult:
raise NotImplementedError()

@cachedList(
cached_method_name="_get_server_keys_json", list_name="server_name_and_key_ids"
)
async def get_server_keys_json(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], FetchKeyResult]:
"""
Args:
server_name_and_key_ids:
iterable of (server_name, key-id) tuples to fetch keys for
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a Collection rather than an Iterable? I thought we try to avoid passing iterables to DB queries because they might be exhausted when we come to retry them? (Or is this an Iterable versus Iterator thing?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe? Don't we pass iterables in like everywhere?!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#11569 is what I had in mind.

I'm happy for this to land as-is (since it's no worse and should stop trusted key servers from spamming hosts). Though I would like to better understand if Iterables are still a problem that we should worry about.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #11569 and #11564.

I think what we have is probably fine for now then?

Returns:
A map from (server_name, key_id) -> FetchKeyResult, or None if the
key is unknown
"""
keys = {}

def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""

# batch_iter always returns tuples so it's safe to do len(batch)
sql = """
SELECT server_name, key_id, key_json, ts_valid_until_ms
FROM server_keys_json WHERE 1=0
""" + " OR (server_name=? AND key_id=?)" * len(
batch
)

txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))

for server_name, key_id, key_json_bytes, ts_valid_until_ms in txn:
if ts_valid_until_ms is None:
# Old keys may be stored with a ts_valid_until_ms of null,
# in which case we treat this as if it was set to `0`, i.e.
# it won't match key requests that define a minimum
# `ts_valid_until_ms`.
ts_valid_until_ms = 0

# The entire signed JSON response is stored in server_keys_json,
# fetch out the bits needed.
key_json = json.loads(bytes(key_json_bytes))
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
key_base64 = key_json["verify_keys"][key_id]["key"]

keys[(server_name, key_id)] = FetchKeyResult(
verify_key=decode_verify_key_bytes(
key_id, decode_base64(key_base64)
),
valid_until_ts=ts_valid_until_ms,
)

def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
for batch in batch_iter(server_name_and_key_ids, 50):
_get_keys(txn, batch)
return keys

return await self.db_pool.runInteraction("get_server_keys_json", _txn)

async def get_server_keys_json_for_remote(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
"""Retrieve the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
clokep marked this conversation as resolved.
Show resolved Hide resolved
used in an HTTP response.
Args:
server_keys: List of (server_name, key_id, source) triplets.
Returns:
A mapping from (server_name, key_id, source) triplets to a list of dicts
"""
Expand Down
62 changes: 31 additions & 31 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,23 @@ def test_verify_json_for_server(self) -> None:
kr = keyring.Keyring(self.hs)

key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
r = self.hs.get_datastores().main.store_server_keys_json(
"server9",
int(time.time() * 1000),
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), 1000)},
get_key_id(key1),
from_server="test",
ts_now_ms=int(time.time() * 1000),
ts_expires_ms=1000,
# The entire response gets signed & stored, just include the bits we
# care about.
key_json_bytes=canonicaljson.encode_canonical_json(
{
"verify_keys": {
get_key_id(key1): {
"key": encode_verify_key_base64(get_verify_key(key1))
}
}
}
),
)
self.get_success(r)

Expand Down Expand Up @@ -280,45 +293,26 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))

kr = keyring.Keyring(
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
)

key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
r = self.hs.get_datastores().main.store_server_signature_keys(
"server9",
int(time.time() * 1000),
# None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_verify_keys.
# to 0 when fetched in KeyStore.get_server_signature_keys.
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type]
)
self.get_success(r)

json1: JsonDict = {}
signedjson.sign.sign_json(json1, "server9", key1)

# should fail immediately on an unsigned object
d = kr.verify_json_for_server("server9", {}, 0)
self.get_failure(d, SynapseError)

# should fail on a signed object with a non-zero minimum_valid_until_ms,
# as it tries to refetch the keys and fails.
d = kr.verify_json_for_server("server9", json1, 500)
self.get_failure(d, SynapseError)

# We expect the keyring tried to refetch the key once.
mock_fetcher.get_keys.assert_called_once_with(
"server9", [get_key_id(key1)], 500
)

# should succeed on a signed object with a 0 minimum_valid_until_ms
d = kr.verify_json_for_server(
"server9",
json1,
0,
d = self.hs.get_datastores().main.get_server_signature_keys(
[("server9", get_key_id(key1))]
)
self.get_success(d)
result = self.get_success(d)
self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0)

def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped."""
Expand Down Expand Up @@ -464,7 +458,9 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict:
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
Expand Down Expand Up @@ -582,7 +578,9 @@ def test_get_keys_from_perspectives(self) -> None:
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
Expand Down Expand Up @@ -703,7 +701,9 @@ def test_get_perspectives_own_key(self) -> None:
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
Expand Down
Loading