diff --git a/changelog.d/17915.bugfix b/changelog.d/17915.bugfix new file mode 100644 index 00000000000..a5d82e486db --- /dev/null +++ b/changelog.d/17915.bugfix @@ -0,0 +1 @@ +Fix experimental support for [MSC4222](https://github.com/matrix-org/matrix-spec-proposals/pull/4222) where we would return the full state on incremental syncs when using lazy loaded members and there were no new events in the timeline. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 204965afeec..df3010ecf68 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -196,7 +196,9 @@ async def get_state_events( AuthError (403) if the user doesn't have permission to view members of this room. """ - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() + user_id = requester.user.to_string() if at_token: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index df9a088063f..350c3fa09a1 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1520,7 +1520,7 @@ async def _compute_state_delta_for_incremental_sync( if sync_config.use_state_after: delta_state_ids: MutableStateMap[str] = {} - if members_to_fetch is not None: + if members_to_fetch: # We're lazy-loading, so the client might need some more member # events to understand the events in this timeline. 
So we always # fish out all the member events corresponding to the timeline diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index b50eb8868ec..f28f5d7e039 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -234,8 +234,11 @@ async def get_state_for_events( RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ + if state_filter is None: + state_filter = StateFilter.all() + await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + if not state_filter.must_await_full_state(self._is_mine_id): await_full_state = False event_to_groups = await self.get_state_group_for_events( @@ -244,7 +247,7 @@ async def get_state_for_events( groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() + groups, state_filter ) state_event_map = await self.stores.main.get_events( @@ -292,10 +295,11 @@ async def get_state_ids_for_events( RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - if ( - await_full_state - and state_filter - and not state_filter.must_await_full_state(self._is_mine_id) + if state_filter is None: + state_filter = StateFilter.all() + + if await_full_state and not state_filter.must_await_full_state( + self._is_mine_id ): # Full state is not required if the state filter is restrictive enough. 
await_full_state = False @@ -306,7 +310,7 @@ async def get_state_ids_for_events( groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() + groups, state_filter ) event_to_state = { @@ -335,9 +339,10 @@ async def get_state_for_event( RuntimeError if we don't have a state group for the event (ie it is an outlier or is unknown) """ - state_map = await self.get_state_for_events( - [event_id], state_filter or StateFilter.all() - ) + if state_filter is None: + state_filter = StateFilter.all() + + state_map = await self.get_state_for_events([event_id], state_filter) return state_map[event_id] @trace @@ -365,9 +370,12 @@ async def get_state_ids_for_event( RuntimeError if we don't have a state group for the event (ie it is an outlier or is unknown) """ + if state_filter is None: + state_filter = StateFilter.all() + state_map = await self.get_state_ids_for_events( [event_id], - state_filter or StateFilter.all(), + state_filter, await_full_state=await_full_state, ) return state_map[event_id] @@ -388,9 +396,12 @@ async def get_state_after_event( at the event and `state_filter` is not satisfied by partial state. Defaults to `True`. """ + if state_filter is None: + state_filter = StateFilter.all() + state_ids = await self.get_state_ids_for_event( event_id, - state_filter=state_filter or StateFilter.all(), + state_filter=state_filter, await_full_state=await_full_state, ) @@ -426,6 +437,9 @@ async def get_state_ids_at( at the last event in the room before `stream_position` and `state_filter` is not satisfied by partial state. Defaults to `True`. """ + if state_filter is None: + state_filter = StateFilter.all() + # FIXME: This gets the state at the latest event before the stream ordering, # which might not be the same as the "current state" of the room at the time # of the stream token if there were multiple forward extremities at the time. 
@@ -442,7 +456,7 @@ async def get_state_ids_at( if last_event_id: state = await self.get_state_after_event( last_event_id, - state_filter=state_filter or StateFilter.all(), + state_filter=state_filter, await_full_state=await_full_state, ) @@ -500,9 +514,10 @@ async def get_state_for_groups( Returns: Dict of state group to state map. """ - return await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) + if state_filter is None: + state_filter = StateFilter.all() + + return await self.stores.state._get_state_for_groups(groups, state_filter) @trace @tag_args @@ -583,12 +598,13 @@ async def get_current_state_ids( Returns: The current state of the room. """ - if await_full_state and ( - not state_filter or state_filter.must_await_full_state(self._is_mine_id) - ): + if state_filter is None: + state_filter = StateFilter.all() + + if await_full_state and state_filter.must_await_full_state(self._is_mine_id): await self._partial_state_room_tracker.await_full_state(room_id) - if state_filter and not state_filter.is_full(): + if not state_filter.is_full(): return await self.stores.main.get_partial_filtered_current_state_ids( room_id, state_filter ) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 42b3638e1c8..788f7d1e325 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -572,10 +572,10 @@ async def get_partial_filtered_current_state_ids( Returns: Map from type/state_key to event ID. 
""" + if state_filter is None: + state_filter = StateFilter.all() - where_clause, where_args = ( - state_filter or StateFilter.all() - ).make_sql_filter_clause() + where_clause, where_args = state_filter.make_sql_filter_clause() if not where_clause: # We delegate to the cached version @@ -584,7 +584,7 @@ async def get_partial_filtered_current_state_ids( def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, ) -> StateMap[str]: - results = StateMapWrapper(state_filter=state_filter or StateFilter.all()) + results = StateMapWrapper(state_filter=state_filter) sql = """ SELECT type, state_key, event_id FROM current_state_events diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index ea7d8199a7d..f7824cba0f2 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -112,8 +112,8 @@ def _get_state_groups_from_groups_txn( Returns: Map from state_group to a StateMap at that point. """ - - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 875dba33496..f7a59c8992d 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -284,7 +284,8 @@ async def _get_state_for_groups( Returns: Dict of state group to state map. 
""" - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() member_filter, non_member_filter = state_filter.get_member_split() diff --git a/synapse/types/state.py b/synapse/types/state.py index 67d1c3fe972..e641215f184 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -68,15 +68,23 @@ class StateFilter: include_others: bool = False def __attrs_post_init__(self) -> None: - # If `include_others` is set we canonicalise the filter by removing - # wildcards from the types dictionary if self.include_others: + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary. + # `object.__setattr__` is needed to work around the fact that StateFilter is frozen object.__setattr__( self, "types", immutabledict({k: v for k, v in self.types.items() if v is not None}), ) + else: + # Otherwise we remove entries where the value is the empty set. + object.__setattr__( + self, + "types", + immutabledict({k: v for k, v in self.types.items() if v is None or v}), + ) @staticmethod def all() -> "StateFilter": diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 1960d2f0e10..9dd0e98971b 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -1262,3 +1262,35 @@ def test_incremental_sync_multiple_deltas(self) -> None: ) ) self.assertEqual(state[("m.test_event", "")], second_state["event_id"]) + + def test_incremental_sync_lazy_loaded_no_timeline(self) -> None: + """Test that lazy-loading with an empty timeline doesn't return the full + state. + + There was a bug where an empty state filter would cause the DB to return + the full state, rather than an empty set. + """ + user = self.register_user("user", "password") + tok = self.login("user", "password") + + # Create a room as the user. 
+ joined_room = self.helper.create_room_as(user, tok=tok) + + since_token = self.hs.get_event_sources().get_current_token() + end_stream_token = self.hs.get_event_sources().get_current_token() + + state = self.get_success( + self.sync_handler._compute_state_delta_for_incremental_sync( + room_id=joined_room, + sync_config=generate_sync_config(user, use_state_after=True), + batch=TimelineBatch( + prev_batch=end_stream_token, events=[], limited=True + ), + since_token=since_token, + end_token=end_stream_token, + members_to_fetch=set(), + timeline_state={}, + ) + ) + + self.assertEqual(state, {})