Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve perf of sync device lists #17216

Merged
merged 5 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ async def on_rdata(
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
prev_token = self.store.get_device_stream_token()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not that familiar as always, but it sounds unintuitive to get the latest stream token and call it prev_token — what's going on here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this really needs a comment. Basically at this point we haven't updated the id generators, and so they don't know about the new stuff that has just come in (that happens in process_replication_position). What we're basically doing is getting an ID of "this is where we're currently at", so that we can get the deltas from the DB for what has changed since then up to the new token.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup OK I was hoping it was something like that, but a comment would be good to avoid wtfery

all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)

self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
Expand Down Expand Up @@ -146,12 +154,6 @@ async def on_rdata(
StreamKeyType.TO_DEVICE, token, users=entities
)
elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)

# `all_room_ids` can be large, so let's wake up those streams in batches
for batched_room_ids in batch_iter(all_room_ids, 100):
self.notifier.on_new_event(
Expand Down
69 changes: 64 additions & 5 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ def __init__(
prefilled_cache=device_list_prefill,
)

device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_changes_in_room",
entity_column="room_id",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_room_stream_cache = StreamChangeCache(
"DeviceListRoomStreamChangeCache",
min_device_list_room_id,
prefilled_cache=device_list_room_prefill,
)

(
user_signature_stream_prefill,
user_signature_stream_list_id,
Expand Down Expand Up @@ -206,6 +220,13 @@ def _invalidate_caches_for_devices(
row.entity, token
)

def device_lists_in_rooms_have_changed(
    self, room_ids: StrCollection, token: int
) -> None:
    """Record in the stream-change cache that device lists changed in each
    of the given rooms at the given stream token."""
    cache = self._device_list_room_stream_cache
    for room_id in room_ids:
        cache.entity_has_changed(room_id, token)

def get_device_stream_token(self) -> int:
    """Return the device list stream's current position, as reported by
    the stream ID generator."""
    id_gen = self._device_list_id_gen
    return id_gen.get_current_token()

Expand Down Expand Up @@ -1460,6 +1481,12 @@ async def get_device_list_changes_in_rooms(
if min_stream_id > from_id:
return None

changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
room_ids, from_id
)
if not changed_room_ids:
return set()

sql = """
SELECT DISTINCT user_id FROM device_lists_changes_in_room
WHERE {clause} AND stream_id > ? AND stream_id <= ?
Expand All @@ -1474,7 +1501,7 @@ def _get_device_list_changes_in_rooms_txn(
return {user_id for user_id, in txn}

changes = set()
for chunk in batch_iter(room_ids, 1000):
for chunk in batch_iter(changed_room_ids, 1000):
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", chunk
)
Expand All @@ -1490,6 +1517,34 @@ def _get_device_list_changes_in_rooms_txn(

return changes

async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
    """Return the set of rooms where devices have changed since the given
    stream ID.

    Args:
        from_id: stream ID to fetch changes after (exclusive).
        to_id: stream ID to fetch changes up to (inclusive).

    Returns:
        The set of room IDs with device list changes in (from_id, to_id].

    Raises:
        Exception: if `from_id` is older than the minimum stream ID for
            which we still have rows in `device_lists_changes_in_room`,
            so the delta cannot be computed.
    """

    min_stream_id = await self._get_min_device_lists_changes_in_room()

    if min_stream_id > from_id:
        # TODO(review): raise a more specific exception type so callers
        # can distinguish "token too old" from other failures.
        raise Exception("stream ID is too old")

    sql = """
        SELECT DISTINCT room_id FROM device_lists_changes_in_room
        WHERE stream_id > ? AND stream_id <= ?
    """

    def _get_all_device_list_changes_txn(
        txn: "LoggingTransaction",
    ) -> Set[str]:
        txn.execute(sql, (from_id, to_id))
        return {room_id for room_id, in txn}

    return await self.db_pool.runInteraction(
        "get_all_device_list_changes",
        _get_all_device_list_changes_txn,
    )

async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
Expand Down Expand Up @@ -1950,8 +2005,8 @@ def _update_remote_device_list_cache_txn(
async def add_device_change_to_streams(
self,
user_id: str,
device_ids: Collection[str],
room_ids: Collection[str],
device_ids: StrCollection,
room_ids: StrCollection,
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
Expand Down Expand Up @@ -2110,8 +2165,8 @@ def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
device_ids: StrCollection,
room_ids: StrCollection,
stream_ids: List[int],
context: Dict[str, str],
) -> None:
Expand Down Expand Up @@ -2149,6 +2204,10 @@ def _add_device_outbound_room_poke_txn(
],
)

txn.call_after(
self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
)

async def get_uncoverted_outbound_room_pokes(
self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
Expand Down