diff --git a/changelog.d/17130.misc b/changelog.d/17130.misc new file mode 100644 index 00000000000..ac20c90bdea --- /dev/null +++ b/changelog.d/17130.misc @@ -0,0 +1 @@ +Add optimisation to `StreamChangeCache.get_entities_changed(..)`. diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index d8253bd942b..7488ba56f5d 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -165,7 +165,7 @@ def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: return False def get_entities_changed( - self, entities: Collection[EntityType], stream_pos: int + self, entities: Collection[EntityType], stream_pos: int, _perf_factor: int = 1 ) -> Union[Set[EntityType], FrozenSet[EntityType]]: """ Returns the subset of the given entities that have had changes after the given position. @@ -177,6 +177,8 @@ def get_entities_changed( Args: entities: Entities to check for changes. stream_pos: The stream position to check for changes after. + _perf_factor: Used by unit tests to choose when to use each + optimisation. Return: A subset of entities which have changed after the given stream position. @@ -184,6 +186,22 @@ def get_entities_changed( This will be all entities if the given stream position is at or earlier than the earliest known stream position. """ + if not self._cache or stream_pos <= self._earliest_known_stream_pos: + self.metrics.inc_misses() + return set(entities) + + # If there have been tonnes of changes compared with the number of + # entities, it is faster to check each entities stream ordering + # one-by-one. + max_stream_pos, _ = self._cache.peekitem() + if max_stream_pos - stream_pos > _perf_factor * len(entities): + self.metrics.inc_hits() + return { + entity + for entity in entities + if self._entity_to_key.get(entity, -1) > stream_pos + } + cache_result = self.get_all_entities_changed(stream_pos) if cache_result.hit: # We now do an intersection, trying to do so in the most efficient diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 3df053493b3..5d38718a509 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,3 +1,5 @@ +from parameterized import parameterized + from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest @@ -161,7 +163,8 @@ def test_has_any_entity_changed(self) -> None: self.assertFalse(cache.has_any_entity_changed(2)) self.assertFalse(cache.has_any_entity_changed(3)) - def test_get_entities_changed(self) -> None: + @parameterized.expand([(0,), (1000000000,)]) + def test_get_entities_changed(self, perf_factor: int) -> None: """ StreamChangeCache.get_entities_changed will return the entities in the given list that have changed since the provided stream ID. If the @@ -178,7 +181,9 @@ def test_get_entities_changed(self) -> None: # get the ones after that point. self.assertEqual( cache.get_entities_changed( - ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 + ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], + stream_pos=2, + _perf_factor=perf_factor, ), {"bar@baz.net", "user@elsewhere.org"}, ) @@ -195,6 +200,7 @@ def test_get_entities_changed(self) -> None: "not@here.website", ], stream_pos=2, + _perf_factor=perf_factor, ), {"bar@baz.net", "user@elsewhere.org"}, ) @@ -210,6 +216,7 @@ def test_get_entities_changed(self) -> None: "not@here.website", ], stream_pos=0, + _perf_factor=perf_factor, ), {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"}, ) @@ -217,7 +224,11 @@ def test_get_entities_changed(self) -> None: # Query a subset of the entries mid-way through the stream. We should # only get back the subset. self.assertEqual( - cache.get_entities_changed(["bar@baz.net"], stream_pos=2), + cache.get_entities_changed( + ["bar@baz.net"], + stream_pos=2, + _perf_factor=perf_factor, + ), {"bar@baz.net"}, )