diff --git a/changelog.d/12519.misc b/changelog.d/12519.misc new file mode 100644 index 000000000000..9c023d8e3ec6 --- /dev/null +++ b/changelog.d/12519.misc @@ -0,0 +1 @@ +Refactor the relations code for clarity. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 2174b4a0948b..f8d3ba5456a5 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -479,9 +479,9 @@ def _inject_bundled_aggregations( Args: event: The event being serialized. time_now: The current time in milliseconds + config: Event serialization config aggregations: The bundled aggregation to serialize. serialized_event: The serialized event which may be modified. - config: Event serialization config apply_edits: Whether the content of the event should be modified to reflect any replacement in `aggregations.replace`. """ diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0be231957750..5efb5612736f 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -256,64 +256,6 @@ async def get_annotations_for_event( return filtered_results - async def _get_bundled_aggregation_for_event( - self, event: EventBase, ignored_users: FrozenSet[str] - ) -> Optional[BundledAggregations]: - """Generate bundled aggregations for an event. - - Note that this does not use a cache, but depends on cached methods. - - Args: - event: The event to calculate bundled aggregations for. - ignored_users: The users ignored by the requesting user. - - Returns: - The bundled aggregations for an event, if bundled aggregations are - enabled and the event can have bundled aggregations. - """ - - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return None - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include, a mapping of relation type to a - # type-specific value. Some types include the direct return type here - # while others need more processing during serialization. - aggregations = BundledAggregations() - - annotations = await self.get_annotations_for_event( - event_id, room_id, ignored_users=ignored_users - ) - if annotations: - aggregations.annotations = {"chunk": annotations} - - references, next_token = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.REFERENCE, - ignored_users=ignored_users, - ) - if references: - aggregations.references = { - "chunk": [{"event_id": event.event_id} for event in references] - } - - if next_token: - aggregations.references["next_batch"] = await next_token.to_string( - self._main_store - ) - - # Store the bundled aggregations in the event metadata for later use. - return aggregations - async def get_threads_for_events( self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str] ) -> Dict[str, _ThreadAggregation]: @@ -435,11 +377,39 @@ async def get_bundled_aggregations( # Fetch other relations per event. for event in events_by_id.values(): - event_result = await self._get_bundled_aggregation_for_event( - event, ignored_users + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + continue + + annotations = await self.get_annotations_for_event( + event.event_id, event.room_id, ignored_users=ignored_users ) - if event_result: - results[event.event_id] = event_result + if annotations: + results.setdefault( + event.event_id, BundledAggregations() + ).annotations = {"chunk": annotations} + + references, next_token = await self.get_relations_for_event( + event.event_id, + event, + event.room_id, + RelationTypes.REFERENCE, + ignored_users=ignored_users, + ) + if references: + aggregations = results.setdefault(event.event_id, BundledAggregations()) + aggregations.references = { + "chunk": [{"event_id": ev.event_id} for ev in references] + } + + if next_token: + aggregations.references["next_batch"] = await next_token.to_string( + self._main_store + ) # Fetch any edits (but not for redacted events). # diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 65743cdf6777..39667e3225e2 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -560,43 +560,6 @@ def test_edit_reply(self) -> None: {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_edit_thread(self) -> None: - """Test that editing a thread works.""" - - # Create a thread and edit the last event. - channel = self._send_relation( - RelationTypes.THREAD, - "m.room.message", - content={"msgtype": "m.text", "body": "A threaded reply!"}, - ) - threaded_event_id = channel.json_body["event_id"] - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - parent_id=threaded_event_id, - ) - - # Fetch the thread root, to get the bundled aggregation for the thread. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # We expect that the edit message appears in the thread summary in the - # unsigned relations section. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.THREAD, relations_dict) - - thread_summary = relations_dict[RelationTypes.THREAD] - self.assertIn("latest_event", thread_summary) - latest_event_in_thread = thread_summary["latest_event"] - self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") - def test_edit_edit(self) -> None: """Test that an edit cannot be edited.""" new_body = {"msgtype": "m.text", "body": "Initial edit"} @@ -1047,7 +1010,7 @@ def test_thread(self) -> None: channel = self._send_relation(RelationTypes.THREAD, "m.room.test") thread_2 = channel.json_body["event_id"] - def assert_annotations(bundled_aggregations: JsonDict) -> None: + def assert_thread(bundled_aggregations: JsonDict) -> None: self.assertEqual(2, bundled_aggregations.get("count")) self.assertTrue(bundled_aggregations.get("current_user_participated")) # The latest thread event has some fields that don't matter. @@ -1066,7 +1029,38 @@ def assert_annotations(bundled_aggregations: JsonDict) -> None: bundled_aggregations.get("latest_event"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + + def test_thread_edit_latest_event(self) -> None: + """Test that editing the latest event in a thread works.""" + + # Create a thread and edit the last event. + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "A threaded reply!"}, + ) + threaded_event_id = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=threaded_event_id, + ) + + # Fetch the thread root, to get the bundled aggregation for the thread. + relations_dict = self._get_bundled_aggregations() + + # We expect that the edit message appears in the thread summary in the + # unsigned relations section. + self.assertIn(RelationTypes.THREAD, relations_dict) + + thread_summary = relations_dict[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary) + latest_event_in_thread = thread_summary["latest_event"] + self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") def test_aggregation_get_event_for_annotation(self) -> None: """Test that annotations do not get bundled aggregations included @@ -1093,7 +1087,7 @@ def test_aggregation_get_event_for_thread(self) -> None: channel = self._send_relation(RelationTypes.THREAD, "m.room.test") thread_id = channel.json_body["event_id"] - # Annotate the annotation. + # Annotate the thread. self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id )