Perf improvement to getting auth chains #17169

Open

wants to merge 11 commits into base: develop
1 change: 1 addition & 0 deletions changelog.d/17169.misc
@@ -0,0 +1 @@
Improve database performance when fetching auth chains.
106 changes: 99 additions & 7 deletions synapse/storage/databases/main/event_federation.py
@@ -39,6 +39,7 @@

import attr
from prometheus_client import Counter, Gauge
from sortedcontainers import SortedSet

from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError
@@ -118,6 +119,11 @@ class BackfillQueueNavigationItem:
type: str


@attr.s(frozen=True, slots=True, auto_attribs=True)
class _ChainLinksCacheEntry:
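"""A cached set of links from one chain to other chains.

Each entry in `links` is a tuple of (origin sequence number, target
chain ID, target sequence number, target chain's cache entry), allowing
cached chains to be walked transitively without querying the database.
"""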
links: List[Tuple[int, int, int, "_ChainLinksCacheEntry"]] = attr.Factory(list)


class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -138,6 +144,10 @@ def __init__(

self.hs = hs

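# Cache of previously fetched auth chain links, keyed by chain ID. This
# lets repeated auth chain lookups skip re-querying links for chains
# that have been seen recently.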
self._chain_links_cache: LruCache[int, _ChainLinksCacheEntry] = LruCache(
max_size=10000, cache_name="chain_links_cache"
)

if hs.config.worker.run_background_tasks:
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
@@ -289,7 +299,9 @@ def _get_auth_chain_ids_using_cover_index_txn(

# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}
for links in self._get_chain_links(txn, set(event_chains.keys())):
for links in self._get_chain_links(
txn, event_chains.keys(), self._chain_links_cache
):
for chain_id in links:
if chain_id not in event_chains:
continue
@@ -341,7 +353,10 @@ def _get_auth_chain_ids_using_cover_index_txn(

@classmethod
def _get_chain_links(
cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
cls,
txn: LoggingTransaction,
chains_to_fetch: Collection[int],
cache: Optional[LruCache[int, _ChainLinksCacheEntry]] = None,
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
"""Fetch all auth chain links from the given set of chains, and all
links from those chains, recursively.
@@ -353,12 +368,55 @@ def _get_chain_links(
of origin sequence number, target chain ID and target sequence number.
"""

found_cached_chains = set()
if cache:
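# Pull out any chains we already have cache entries for, so we can
# return their links without querying the database.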
entries: Dict[int, _ChainLinksCacheEntry] = {}
for chain_id in chains_to_fetch:
entry = cache.get(chain_id)
if entry:
entries[chain_id] = entry

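# Walk the cached link graph from the entries we found, recording each
# link and queueing the target chain's entry to be walked in turn.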
cached_links: Dict[int, List[Tuple[int, int, int]]] = {}
while entries:
origin_chain_id, entry = entries.popitem()

for (
origin_sequence_number,
target_chain_id,
target_sequence_number,
target_entry,
) in entry.links:
if target_chain_id in found_cached_chains:
continue

found_cached_chains.add(target_chain_id)

# Touch the target chain's entry so it stays fresh in the LRU.
cache.get(target_chain_id)

entries[target_chain_id] = target_entry
cached_links.setdefault(origin_chain_id, []).append(
(
origin_sequence_number,
target_chain_id,
target_sequence_number,
)
)

yield cached_links

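# Chains reached via the cached link graph are excluded from the
# database query below.
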
# This query is structured to first get all reachable chain IDs, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, but there isn't a way of structuring
# the recursive part of the query to pull out the links without also
# returning large quantities of redundant data (which can make it a lot
# slower).

if isinstance(txn.database_engine, PostgresEngine):
# JIT and sequential scans sometimes get hit on this code path, which
# can make the queries much more expensive
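# (SET LOCAL only lasts until the end of the current transaction, so
# these settings don't leak into other queries.)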
txn.execute("SET LOCAL jit = off")
txn.execute("SET LOCAL enable_seqscan = off")

sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
@@ -377,16 +435,29 @@ def _get_chain_links(
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""

while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
chains_to_fetch.difference_update(batch2)
# We fetch the links in batches. Separate batches will likely fetch the
# same set of links (e.g. they'll always pull in the links to the
# create event). To try to minimize the amount of redundant links, we
# query the chain IDs in reverse order, as there will be a correlation
# between the order of chain IDs and links (i.e., higher chain IDs are
# more likely to depend on lower chain IDs than vice versa).
BATCH_SIZE = 5000
chains_to_fetch_sorted = SortedSet(chains_to_fetch)
chains_to_fetch_sorted.difference_update(found_cached_chains)

while chains_to_fetch_sorted:
batch2 = list(chains_to_fetch_sorted.islice(-BATCH_SIZE))
chains_to_fetch_sorted.difference_update(batch2)

clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)

links: Dict[int, List[Tuple[int, int, int]]] = {}

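# New cache entries built from the links in this batch, keyed by chain ID.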
cache_entries: Dict[int, _ChainLinksCacheEntry] = {}

for (
origin_chain_id,
origin_sequence_number,
@@ -397,7 +468,28 @@
(origin_sequence_number, target_chain_id, target_sequence_number)
)

chains_to_fetch.difference_update(links)
if cache:
origin_entry = cache_entries.setdefault(
origin_chain_id, _ChainLinksCacheEntry()
)
target_entry = cache_entries.setdefault(
target_chain_id, _ChainLinksCacheEntry()
)
origin_entry.links.append(
(
origin_sequence_number,
target_chain_id,
target_sequence_number,
target_entry,
)
)

if cache:
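# Only add entries for chains that aren't already cached; an existing
# entry may already be referenced by other cached entries, so we don't
# want to replace it.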
for chain_id, entry in cache_entries.items():
if chain_id not in cache:
cache[chain_id] = entry

chains_to_fetch_sorted.difference_update(links)

yield links

@@ -589,7 +681,7 @@ def fetch_chain_info(events_to_fetch: Collection[str]) -> None:
# are reachable from any event.

# (We need to take a copy of `seen_chains` as the function mutates it)
[Review comment on the line above: this comment no longer applies.]

for links in self._get_chain_links(txn, set(seen_chains)):
for links in self._get_chain_links(txn, seen_chains, self._chain_links_cache):
for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:
74 changes: 74 additions & 0 deletions tests/storage/test_purge.py
@@ -25,6 +25,7 @@
from synapse.server import HomeServer
from synapse.util import Clock

from tests.test_utils.event_injection import inject_event
from tests.unittest import HomeserverTestCase


@@ -128,3 +129,76 @@ def test_purge_room(self) -> None:
self.store._invalidate_local_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)

def test_state_groups_state_decreases(self) -> None:
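"""Purging historical events should shrink `state_groups_state`.

The room is filled with forked state (which produces large state
groups); purging history in batches should reduce the number of rows
stored for the room.
"""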
response = self.helper.send(self.room_id, body="first")
first_event_id = response["event_id"]

batches = []

previous_event_id = first_event_id
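# Build a series of forks: at each step inject two conflicting state
# events sharing the same prev event, followed by a message event that
# resolves the fork. Record a topological token per batch so history
# can be purged up to it later.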
for i in range(50):
state_event1 = self.get_success(
inject_event(
self.hs,
type="test.state",
sender=self.user_id,
state_key="",
room_id=self.room_id,
content={"key": i, "e": 1},
prev_event_ids=[previous_event_id],
origin_server_ts=1,
)
)

state_event2 = self.get_success(
inject_event(
self.hs,
type="test.state",
sender=self.user_id,
state_key="",
room_id=self.room_id,
content={"key": i, "e": 2},
prev_event_ids=[previous_event_id],
origin_server_ts=2,
)
)

message_event = self.get_success(
inject_event(
self.hs,
type="dummy_event",
sender=self.user_id,
room_id=self.room_id,
content={},
prev_event_ids=[state_event1.event_id, state_event2.event_id],
)
)

token = self.get_success(
self.store.get_topological_token_for_event(state_event1.event_id)
)
batches.append(token)

previous_event_id = message_event.event_id

self.helper.send(self.room_id, body="last event")

def count_state_groups() -> int:
sql = "SELECT COUNT(*) FROM state_groups_state WHERE room_id = ?"
rows = self.get_success(
self.store.db_pool.execute("count_state_groups", sql, self.room_id)
)
return rows[0][0]

count_before = count_state_groups()

for token in batches:
token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
self.get_success(
self._storage_controllers.purge_events.purge_history(
self.room_id, token_str, False
)
)

# Purging the forked history should have reduced the number of rows in
# `state_groups_state` for the room.
count_after = count_state_groups()
self.assertLess(count_after, count_before)