diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 5284940ba809..695d52066796 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -23,6 +23,7 @@ """ import abc import contextlib +import itertools import logging from bisect import bisect from contextlib import contextmanager @@ -748,14 +749,18 @@ def __init__(self, hs: "HomeServer"): ] = {} # Keeps track of the number of *ongoing* syncs on other processes. + # # While any sync is ongoing on another process the user will never # go offline. + # # Each process has a unique identifier and an update frequency. If # no update is received from that process within the update period then # we assume that all the sync requests on that process have stopped. - # Stored as a dict from process_id to set of user_id, and a dict of - # process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs: Dict[str, Set[str]] = {} + # Stored as a dict from process_id to set of (user_id, device_id), and + # a dict of process_id to millisecond timestamp last updated. + self.external_process_to_current_syncs: Dict[ + str, Set[Tuple[str, Optional[str]]] + ] = {} self.external_process_last_updated_ms: Dict[str, int] = {} self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") @@ -960,7 +965,10 @@ async def _handle_timeouts(self) -> None: # that were syncing on that process to see if they need to be timed # out. users_to_check.update( - self.external_process_to_current_syncs.pop(process_id, ()) + user_id + for user_id, device_id in self.external_process_to_current_syncs.pop( + process_id, () + ) ) self.external_process_last_updated_ms.pop(process_id) @@ -973,14 +981,13 @@ async def _handle_timeouts(self) -> None: syncing_user_ids = { user_id - for ( - user_id, - device_id, - ), count in self._user_device_to_num_current_syncs.items() + for (user_id, _), count in self._user_device_to_num_current_syncs.items() if count } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) + for user_id, device_id in itertools.chain( + *self.external_process_to_current_syncs.values() + ): + syncing_user_ids.add(user_id) changes = handle_timeouts( states, @@ -1106,26 +1113,27 @@ async def update_external_syncs_row( process_id, set() ) - # USER_SYNC is sent when a user starts or stops syncing on a remote - # process. (But only for the initial and last device.) + # USER_SYNC is sent when a user's device starts or stops syncing on + # a remote # process. (But only for the initial and last sync for that + # device.) # - # When a user *starts* syncing it also calls set_state(...) which + # When a device *starts* syncing it also calls set_state(...) which # will update the state, last_active_ts, and last_user_sync_ts. - # Simply ensure the user is tracked as syncing in this case. + # Simply ensure the user & device is tracked as syncing in this case. # - # When a user *stops* syncing, update the last_user_sync_ts and mark + # When a device *stops* syncing, update the last_user_sync_ts and mark # them as no longer syncing. Note this doesn't quite match the # monolith behaviour, which updates last_user_sync_ts at the end of # every sync, not just the last in-flight sync. - if is_syncing and user_id not in process_presence: - process_presence.add(user_id) - elif not is_syncing and user_id in process_presence: + if is_syncing and (user_id, device_id) not in process_presence: + process_presence.add((user_id, device_id)) + elif not is_syncing and (user_id, device_id) in process_presence: new_state = prev_state.copy_and_replace( last_user_sync_ts=sync_time_msec ) await self._update_states([new_state]) - process_presence.discard(user_id) + process_presence.discard((user_id, device_id)) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() @@ -1139,7 +1147,9 @@ async def update_external_syncs_clear(self, process_id: str) -> None: process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = await self.current_state_for_users(process_presence) + prev_states = await self.current_state_for_users( + {user_id for user_id, device_id in process_presence} + ) time_now_ms = self.clock.time_msec() await self._update_states(