diff --git a/changelog.d/17169.misc b/changelog.d/17169.misc
new file mode 100644
index 00000000000..6b06b002ead
--- /dev/null
+++ b/changelog.d/17169.misc
@@ -0,0 +1 @@
+Add database performance improvement when fetching auth chains.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 24abab4a235..11217096569 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -39,6 +39,7 @@
 
 import attr
 from prometheus_client import Counter, Gauge
+from sortedcontainers import SortedSet
 
 from synapse.api.constants import MAX_DEPTH
 from synapse.api.errors import StoreError
@@ -118,6 +119,11 @@ class BackfillQueueNavigationItem:
     type: str
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _ChainLinksCacheEntry:
+    links: List[Tuple[int, int, int, "_ChainLinksCacheEntry"]] = attr.Factory(list)
+
+
 class _NoChainCoverIndex(Exception):
     def __init__(self, room_id: str):
         super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -138,6 +144,10 @@ def __init__(
 
         self.hs = hs
 
+        self._chain_links_cache: LruCache[int, _ChainLinksCacheEntry] = LruCache(
+            max_size=10000, cache_name="chain_links_cache"
+        )
+
         if hs.config.worker.run_background_tasks:
             hs.get_clock().looping_call(
                 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
@@ -289,7 +299,9 @@ def _get_auth_chain_ids_using_cover_index_txn(
 
         # A map from chain ID to max sequence number *reachable* from any event ID.
         chains: Dict[int, int] = {}
-        for links in self._get_chain_links(txn, set(event_chains.keys())):
+        for links in self._get_chain_links(
+            txn, event_chains.keys(), self._chain_links_cache
+        ):
             for chain_id in links:
                 if chain_id not in event_chains:
                     continue
@@ -341,7 +353,10 @@ def _get_auth_chain_ids_using_cover_index_txn(
 
     @classmethod
     def _get_chain_links(
-        cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
+        cls,
+        txn: LoggingTransaction,
+        chains_to_fetch: Collection[int],
+        cache: Optional[LruCache[int, _ChainLinksCacheEntry]] = None,
     ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
         """Fetch all auth chain links from the given set of chains, and all
         links from those chains, recursively.
@@ -353,12 +368,55 @@ def _get_chain_links(
             of origin sequence number, target chain ID and target sequence number.
         """
 
+        found_cached_chains: Set[int] = set()
+        if cache:
+            entries: Dict[int, _ChainLinksCacheEntry] = {}
+            for chain_id in chains_to_fetch:
+                entry = cache.get(chain_id)
+                if entry:
+                    entries[chain_id] = entry
+
+            cached_links: Dict[int, List[Tuple[int, int, int]]] = {}
+            while entries:
+                origin_chain_id, entry = entries.popitem()
+
+                for (
+                    origin_sequence_number,
+                    target_chain_id,
+                    target_sequence_number,
+                    target_entry,
+                ) in entry.links:
+                    if target_chain_id in found_cached_chains:
+                        continue
+
+                    found_cached_chains.add(target_chain_id)
+
+                    # Fetching the target entry bumps it in the LRU cache.
+                    cache.get(target_chain_id)
+
+                    entries[target_chain_id] = target_entry
+                    cached_links.setdefault(origin_chain_id, []).append(
+                        (
+                            origin_sequence_number,
+                            target_chain_id,
+                            target_sequence_number,
+                        )
+                    )
+
+            yield cached_links
+
         # This query is structured to first get all chain IDs reachable, and
         # then pull out all links from those chains. This does pull out more
         # rows than is strictly necessary, however there isn't a way of
         # structuring the recursive part of query to pull out the links without
         # also returning large quantities of redundant data (which can make it a
         # lot slower).
+
+        if isinstance(txn.database_engine, PostgresEngine):
+            # JIT and sequential scans sometimes get hit on this code path, which
+            # can make the queries much more expensive
+            txn.execute("SET LOCAL jit = off")
+            txn.execute("SET LOCAL enable_seqscan = off")
+
         sql = """
             WITH RECURSIVE links(chain_id) AS (
                 SELECT
                     chain_id
                 FROM event_auth_chain_links
                 WHERE %s
                 UNION
                 SELECT
                     target_chain_id
                 FROM event_auth_chain_links
                 INNER JOIN links ON (chain_id = origin_chain_id)
             )
             SELECT
                 origin_chain_id, origin_sequence_number,
                 target_chain_id, target_sequence_number
             FROM links
             INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
         """
 
-        while chains_to_fetch:
-            batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
-            chains_to_fetch.difference_update(batch2)
+        # We fetch the links in batches. Separate batches will likely fetch the
+        # same set of links (e.g. they'll always pull in the links to the create
+        # event). To try and minimize the amount of redundant links, we query
+        # the chain IDs in reverse order, as there will be a correlation between
+        # the order of chain IDs and links (i.e., higher chain IDs are more
+        # likely to depend on lower chain IDs than vice versa).
+        BATCH_SIZE = 5000
+        chains_to_fetch_sorted = SortedSet(chains_to_fetch)
+        chains_to_fetch_sorted.difference_update(found_cached_chains)
+
+        while chains_to_fetch_sorted:
+            batch2 = list(chains_to_fetch_sorted.islice(-BATCH_SIZE))
+            chains_to_fetch_sorted.difference_update(batch2)
+
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "origin_chain_id", batch2
             )
@@ -387,6 +456,8 @@ def _get_chain_links(
 
             links: Dict[int, List[Tuple[int, int, int]]] = {}
 
+            cache_entries: Dict[int, _ChainLinksCacheEntry] = {}
+
             for (
                 origin_chain_id,
                 origin_sequence_number,
@@ -397,7 +468,28 @@ def _get_chain_links(
                     (origin_sequence_number, target_chain_id, target_sequence_number)
                 )
 
-            chains_to_fetch.difference_update(links)
+                if cache:
+                    origin_entry = cache_entries.setdefault(
+                        origin_chain_id, _ChainLinksCacheEntry()
+                    )
+                    target_entry = cache_entries.setdefault(
+                        target_chain_id, _ChainLinksCacheEntry()
+                    )
+                    origin_entry.links.append(
+                        (
+                            origin_sequence_number,
+                            target_chain_id,
+                            target_sequence_number,
+                            target_entry,
+                        )
+                    )
+
+            if cache:
+                for chain_id, entry in cache_entries.items():
+                    if chain_id not in cache:
+                        cache[chain_id] = entry
+
+            chains_to_fetch_sorted.difference_update(links)
 
             yield links
 
@@ -589,7 +681,7 @@ def fetch_chain_info(events_to_fetch: Collection[str]) -> None:
         # are reachable from any event.
-        # (We need to take a copy of `seen_chains` as the function mutates it)
-        for links in self._get_chain_links(txn, set(seen_chains)):
+        for links in self._get_chain_links(txn, seen_chains, self._chain_links_cache):
             for chains in set_to_chain:
                 for chain_id in links:
                     if chain_id not in chains:
                         continue
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 080d5640a5b..9fa69f6581e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -25,6 +25,7 @@
 from synapse.server import HomeServer
 from synapse.util import Clock
 
+from tests.test_utils.event_injection import inject_event
 from tests.unittest import HomeserverTestCase
 
 
@@ -128,3 +129,76 @@ def test_purge_room(self) -> None:
         self.store._invalidate_local_get_event_cache(create_event.event_id)
         self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
         self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+
+    def test_state_groups_state_decreases(self) -> None:
+        response = self.helper.send(self.room_id, body="first")
+        first_event_id = response["event_id"]
+
+        batches = []
+
+        previous_event_id = first_event_id
+        for i in range(50):
+            state_event1 = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="test.state",
+                    sender=self.user_id,
+                    state_key="",
+                    room_id=self.room_id,
+                    content={"key": i, "e": 1},
+                    prev_event_ids=[previous_event_id],
+                    origin_server_ts=1,
+                )
+            )
+
+            state_event2 = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="test.state",
+                    sender=self.user_id,
+                    state_key="",
+                    room_id=self.room_id,
+                    content={"key": i, "e": 2},
+                    prev_event_ids=[previous_event_id],
+                    origin_server_ts=2,
+                )
+            )
+
+            message_event = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="dummy_event",
+                    sender=self.user_id,
+                    room_id=self.room_id,
+                    content={},
+                    prev_event_ids=[state_event1.event_id, state_event2.event_id],
+                )
+            )
+
+            token = self.get_success(
+                self.store.get_topological_token_for_event(state_event1.event_id)
+            )
+            batches.append(token)
+
+            previous_event_id = message_event.event_id
+
+        self.helper.send(self.room_id, body="last event")
+
+        def count_state_groups() -> int:
+            sql = "SELECT COUNT(*) FROM state_groups_state WHERE room_id = ?"
+            rows = self.get_success(
+                self.store.db_pool.execute("count_state_groups", sql, self.room_id)
+            )
+            return rows[0][0]
+
+        # Purging each batch of history should strictly reduce the number of
+        # rows in `state_groups_state` for the room.
+        count_before = count_state_groups()
+
+        for token in batches:
+            token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
+            self.get_success(
+                self._storage_controllers.purge_events.purge_history(
+                    self.room_id, token_str, False
+                )
+            )
+
+        count_after = count_state_groups()
+        self.assertLess(count_after, count_before)
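
As context for the cache introduced in `_get_chain_links` above: each `_ChainLinksCacheEntry` stores a chain's outgoing links together with direct references to the target chains' own cache entries, so any previously seen portion of the chain graph can be walked without another database round trip. The standalone sketch below illustrates that traversal; it is not Synapse code, and the names `ChainEntry` and `walk_cached_links` are illustrative stand-ins only.

from typing import Dict, List, Set, Tuple


class ChainEntry:
    """Stand-in for `_ChainLinksCacheEntry`: a cached chain's outgoing links,
    each carrying a direct reference to the target chain's own entry."""

    def __init__(self) -> None:
        self.links: List[Tuple[int, int, int, "ChainEntry"]] = []


def walk_cached_links(
    seeds: Dict[int, ChainEntry],
) -> Tuple[Dict[int, List[Tuple[int, int, int]]], Set[int]]:
    """Follow cached links transitively from the seed chains, returning the
    links found and the set of target chains that need not be fetched from
    the database."""
    found: Set[int] = set()
    links: Dict[int, List[Tuple[int, int, int]]] = {}
    pending = dict(seeds)

    while pending:
        origin_chain_id, entry = pending.popitem()
        for origin_seq, target_chain_id, target_seq, target_entry in entry.links:
            if target_chain_id in found:
                continue
            found.add(target_chain_id)
            # Queue the target's entry so its own links are walked too.
            pending[target_chain_id] = target_entry
            links.setdefault(origin_chain_id, []).append(
                (origin_seq, target_chain_id, target_seq)
            )
    return links, found


# Chain 3 links back to chain 2, which links back to chain 1.
c1, c2, c3 = ChainEntry(), ChainEntry(), ChainEntry()
c3.links.append((5, 2, 7, c2))
c2.links.append((7, 1, 1, c1))

links, skip = walk_cached_links({3: c3})
assert skip == {1, 2}
assert links == {3: [(5, 2, 7)], 2: [(7, 1, 1)]}

Linking entries to one another, rather than storing bare chain IDs, means a single cache hit can short-circuit an entire reachable subgraph; the trade-off is that an entry keeps its reachable neighbours alive for as long as it remains in the LRU cache.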