diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index b02fca12a68c..96fae2a46689 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -360,8 +360,7 @@ async def batch_persist_unpersisted_contexts(
         """
         Takes a list of events and their associated unpersisted contexts and persists
         the unpersisted contexts, returning a list of events and persisted contexts.
-        Note that all the events must be in a linear chain (ie a <- b <- c)
-        and must be state events.
+        Note that all the events must be in a linear chain (ie a <- b <- c).

         Args:
             events_and_context: A list of events and their unpersisted contexts
@@ -375,15 +374,14 @@ async def batch_persist_unpersisted_contexts(

         events_and_persisted_context = []
         for event, unpersisted_context in amended_events_and_context:
-            assert unpersisted_context.partial_state is not None
             context = EventContext(
                 storage=unpersisted_context._storage,
                 state_group=unpersisted_context.state_group_after_event,
                 state_group_before_event=unpersisted_context.state_group_before_event,
                 state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
                 partial_state=unpersisted_context.partial_state,
-                prev_group=unpersisted_context.prev_group_for_state_group_after_event,
-                delta_ids=unpersisted_context.delta_ids_to_state_group_after_event,
+                prev_group=unpersisted_context.state_group_before_event,
+                delta_ids=unpersisted_context.state_delta_due_to_event,
             )
             events_and_persisted_context.append((event, context))
         return events_and_persisted_context
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 176c8afb97a7..7873064b5e05 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -413,8 +413,7 @@ async def store_state_deltas_for_batched(
         prev_group: int,
     ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
         """Generate and store state deltas for a group of events and contexts created to be
-        batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c)
-        and must be state events.
+        batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).

         Args:
             events_and_context: the events to generate and store a state groups for
@@ -449,31 +448,32 @@ def insert_deltas_group_txn(
                     % (prev_group,)
                 )

-            num_state_groups = len(events_and_context)
+            num_state_groups = 0
+            for event, _ in events_and_context:
+                if event.is_state():
+                    num_state_groups += 1

             state_groups = self._state_group_seq_gen.get_next_mult_txn(
                 txn, num_state_groups
            )

+            sg_before = prev_group
+            state_group_iter = iter(state_groups)
             for index, (event, context) in enumerate(events_and_context):
-                context.state_group_after_event = state_groups[index]
-                # The first prev_group will be the last persisted state group, which is passed in
-                # else it will be the group most recently assigned
-                if index > 0:
-                    context.prev_group_for_state_group_after_event = state_groups[
-                        index - 1
-                    ]
-                    context.state_group_before_event = state_groups[index - 1]
-                else:
-                    context.prev_group_for_state_group_after_event = prev_group
-                    context.state_group_before_event = prev_group
-                context.delta_ids_to_state_group_after_event = {
+                if not event.is_state():
+                    context.state_group_after_event = sg_before
+                    context.state_group_before_event = sg_before
+                    continue
+
+                sg_after = next(state_group_iter)
+                context.state_group_after_event = sg_after
+                context.state_group_before_event = sg_before
+                context.delta_ids_to_state_group_before_event = {
                     (event.type, event.state_key): event.event_id
                 }
                 context.state_delta_due_to_event = {
                     (event.type, event.state_key): event.event_id
                 }
-                index += 1
+                sg_before = sg_after

             self.db_pool.simple_insert_many_txn(
                 txn,
@@ -492,19 +492,20 @@ def insert_deltas_group_txn(
                 values=[
                     (
                         context.state_group_after_event,
-                        context.prev_group_for_state_group_after_event,
+                        context.state_group_before_event,
                     )
                     for _, context in events_and_context
                 ],
             )

+            values = []
             for _, context in events_and_context:
-                assert context.delta_ids_to_state_group_after_event is not None
-                self.db_pool.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
-                    values=[
+                assert context.delta_ids_to_state_group_before_event is not None
+                for (
+                    key,
+                    state_id,
+                ) in context.delta_ids_to_state_group_before_event.items():
+                    values.append(
                         (
                             context.state_group_after_event,
                             room_id,
@@ -512,9 +513,14 @@ def insert_deltas_group_txn(
                             key[1],
                             state_id,
                         )
-                        for key, state_id in context.delta_ids_to_state_group_after_event.items()
-                    ],
-                )
+                    )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=values,
+            )

             return events_and_context

         return await self.db_pool.runInteraction(
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 46df79f730a5..ea61dd00ae09 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -379,7 +379,7 @@ def test_suppress_edits(self) -> None:
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)

         # Create & persist an event to use as the parent of the relation.
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_event(
                 self.requester,
                 {
@@ -393,6 +393,7 @@ def test_suppress_edits(self) -> None:
                 },
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             self.event_creation_handler.handle_new_client_event(
                 self.requester, events_and_context=[(event, context)]
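
Review note: as a minimal, standalone sketch of the chaining behaviour the store.py hunk introduces, the snippet below walks a mixed batch of state and non-state events and assigns state groups the way the patched insert_deltas_group_txn does: each state event gets a freshly allocated group chained off the previous one, while non-state events reuse the group that precedes them. FakeEvent and the plain integer allocator are stand-ins invented for illustration only; they are not Synapse types or APIs.

    from dataclasses import dataclass
    from itertools import count
    from typing import Dict, List, Optional, Tuple


    @dataclass
    class FakeEvent:
        """Stand-in for EventBase: only the fields the chaining logic needs."""

        event_id: str
        type: str
        state_key: Optional[str] = None

        def is_state(self) -> bool:
            return self.state_key is not None


    def chain_state_groups(
        events: List[FakeEvent], prev_group: int
    ) -> List[Tuple[str, int, int, Dict[Tuple[str, str], str]]]:
        """Return (event_id, sg_before, sg_after, state delta) for each event."""
        # Allocate one new state group per *state* event, as the patch does.
        allocator = count(prev_group + 1)
        state_groups = [next(allocator) for e in events if e.is_state()]

        results = []
        sg_before = prev_group
        state_group_iter = iter(state_groups)
        for event in events:
            if not event.is_state():
                # Non-state events do not advance the chain: before == after.
                results.append((event.event_id, sg_before, sg_before, {}))
                continue

            sg_after = next(state_group_iter)
            delta = {(event.type, event.state_key): event.event_id}
            results.append((event.event_id, sg_before, sg_after, delta))
            sg_before = sg_after
        return results


    if __name__ == "__main__":
        batch = [
            FakeEvent("$create", "m.room.create", state_key=""),
            FakeEvent("$msg", "m.room.message"),  # non-state: reuses previous group
            FakeEvent("$name", "m.room.name", state_key=""),
        ]
        for row in chain_state_groups(batch, prev_group=100):
            print(row)

Run with prev_group=100, this prints groups 101 and 102 for the two state events, with the message event pinned to 101 in between, which is the invariant the new loop maintains for linear batches containing non-state events.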