Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve perf of looking up device lists changes in /sync #17205

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/17205.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve performance of calculating device lists changes in `/sync`.
13 changes: 7 additions & 6 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ async def on_rdata(
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)

self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
Expand Down Expand Up @@ -146,12 +153,6 @@ async def on_rdata(
StreamKeyType.TO_DEVICE, token, users=entities
)
elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)

# `all_room_ids` can be large, so let's wake up those streams in batches
for batched_room_ids in batch_iter(all_room_ids, 100):
self.notifier.on_new_event(
Expand Down
69 changes: 64 additions & 5 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ def __init__(
prefilled_cache=device_list_prefill,
)

device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_changes_in_room",
entity_column="room_id",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_room_stream_cache = StreamChangeCache(
"DeviceListRoomStreamChangeCache",
min_device_list_room_id,
prefilled_cache=device_list_room_prefill,
)

(
user_signature_stream_prefill,
user_signature_stream_list_id,
Expand Down Expand Up @@ -206,6 +220,13 @@ def _invalidate_caches_for_devices(
row.entity, token
)

def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int
) -> None:
"Record that device lists have changed in rooms"
for room_id in room_ids:
self._device_list_room_stream_cache.entity_has_changed(room_id, token)

def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()

Expand Down Expand Up @@ -1460,6 +1481,12 @@ async def get_device_list_changes_in_rooms(
if min_stream_id > from_id:
return None

changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
room_ids, from_id
)
if not changed_room_ids:
return set()

sql = """
SELECT DISTINCT user_id FROM device_lists_changes_in_room
WHERE {clause} AND stream_id > ? AND stream_id <= ?
Expand All @@ -1474,7 +1501,7 @@ def _get_device_list_changes_in_rooms_txn(
return {user_id for user_id, in txn}

changes = set()
for chunk in batch_iter(room_ids, 1000):
for chunk in batch_iter(changed_room_ids, 1000):
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", chunk
)
Expand All @@ -1490,6 +1517,34 @@ def _get_device_list_changes_in_rooms_txn(

return changes

async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
"""Return the set of rooms where devices have changed since the given
stream ID.

Will raise an exception if the given stream ID is too old.
"""

min_stream_id = await self._get_min_device_lists_changes_in_room()

if min_stream_id > from_id:
raise Exception("stream ID is too old")

sql = """
SELECT DISTINCT room_id FROM device_lists_changes_in_room
WHERE stream_id > ? AND stream_id <= ?
"""

def _get_all_device_list_changes_txn(
txn: LoggingTransaction,
) -> Set[str]:
txn.execute(sql, (from_id, to_id))
return {room_id for room_id, in txn}

return await self.db_pool.runInteraction(
"get_all_device_list_changes",
_get_all_device_list_changes_txn,
)

async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
Expand Down Expand Up @@ -1950,8 +2005,8 @@ def _update_remote_device_list_cache_txn(
async def add_device_change_to_streams(
self,
user_id: str,
device_ids: Collection[str],
room_ids: Collection[str],
device_ids: StrCollection,
room_ids: StrCollection,
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
Expand Down Expand Up @@ -2110,8 +2165,8 @@ def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
device_ids: StrCollection,
room_ids: StrCollection,
stream_ids: List[int],
context: Dict[str, str],
) -> None:
Expand Down Expand Up @@ -2149,6 +2204,10 @@ def _add_device_outbound_room_poke_txn(
],
)

txn.call_after(
self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
)

async def get_uncoverted_outbound_room_pokes(
self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
Expand Down
Loading