From 57311659ca9d87709c04dc10a428ecdcc286f951 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Jan 2023 15:53:28 +0000 Subject: [PATCH] STUFF for current token --- synapse/handlers/presence.py | 2 +- synapse/replication/http/_base.py | 4 +-- synapse/replication/tcp/client.py | 8 +++++- synapse/replication/tcp/handler.py | 8 +++++- synapse/replication/tcp/resource.py | 3 ++- synapse/replication/tcp/streams/_base.py | 24 +++++++++++++----- synapse/replication/tcp/streams/federation.py | 2 +- synapse/storage/databases/main/cache.py | 9 +++++-- synapse/storage/util/id_generators.py | 25 +++++++++++++------ 9 files changed, 63 insertions(+), 22 deletions(-) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 43e4e7b1b4c2..595ec36fa33d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -2178,7 +2178,7 @@ def send_presence_to_destinations( self._notifier.notify_replication() - def get_current_token(self, instance_name: str) -> int: + def get_current_token(self, instance_name: str, minimum: bool = False) -> int: """Get the current position of the stream. On workers this returns the last stream ID received from replication. diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index b95b4777975c..908933083f6e 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -248,7 +248,7 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: data[_STREAM_POSITION_KEY] = { "streams": { - stream.NAME: stream.current_token(local_instance_name) + stream.NAME: stream.current_token(local_instance_name, True) for stream in streams }, "instance_name": local_instance_name, @@ -443,7 +443,7 @@ async def _check_auth_and_handle( if self.WAIT_FOR_STREAMS: response[_STREAM_POSITION_KEY] = { - stream.NAME: stream.current_token(self._instance_name) + stream.NAME: stream.current_token(self._instance_name, True) for stream in self._streams } diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 322d695bc7f0..fe2ae58c2d31 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -344,7 +344,13 @@ async def wait_for_stream_position( # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): - logger.info("Waiting for repl stream %r to reach %s", stream_name, position) + logger.info( + "Waiting for repl stream %r to reach %s (%s) (current: %s)", + stream_name, + position, + instance_name, + current_position, + ) await make_deferred_yieldable(deferred) logger.info( "Finished waiting for repl stream %r to reach %s", stream_name, position diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d03a53d76429..598530c593e7 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -540,7 +540,13 @@ async def on_rdata( rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token) + logger.debug( + "%s: Received rdata %s (%s) -> %s", + self._instance_name, + stream_name, + instance_name, + token, + ) await self._replication_data_handler.on_rdata( stream_name, instance_name, token, rows ) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 9d17eff71451..d03c352acec8 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -160,7 +160,8 @@ async def _run_notifier_loop(self) -> None: for stream in all_streams: if stream.last_token == stream.current_token( - self._instance_name + self._instance_name, + minimum=stream.NAME == EventsStream.NAME, ): continue diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index cb782ee01363..6c25548d5b40 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -27,6 +27,7 @@ ) import attr +from typing_extensions import Protocol from synapse.api.constants import AccountDataTypes from synapse.replication.http.streams import ReplicationGetStreamUpdates @@ -78,6 +79,11 @@ UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] +class CurrentTokenFunction(Protocol): + def __call__(self, instance_name: str, minimum: bool = False) -> Token: + ... + + class Stream: """Base class for the streams. @@ -107,7 +113,7 @@ def parse_row(cls, row: StreamRow) -> Any: def __init__( self, local_instance_name: str, - current_token_function: Callable[[str], Token], + current_token_function: CurrentTokenFunction, update_function: UpdateFunction, ): """Instantiate a Stream @@ -192,12 +198,16 @@ async def get_updates_since( def current_token_without_instance( current_token: Callable[[], int] -) -> Callable[[str], int]: +) -> CurrentTokenFunction: """Takes a current token callback function for a single writer stream that doesn't take an instance name parameter and wraps it in a function that does accept an instance name parameter but ignores it. """ - return lambda instance_name: current_token() + + def expanded_current_token(instance_name: str, minimum: bool = False) -> int: + return current_token() + + return expanded_current_token def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction: @@ -246,10 +256,12 @@ def __init__(self, hs: "HomeServer"): self.store.get_all_new_backfill_event_rows, ) - def _current_token(self, instance_name: str) -> int: + def _current_token(self, instance_name: str, minimum: bool = False) -> int: # The backfill stream over replication operates on *positive* numbers, # which means we need to negate it. - return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name) + return -self.store._backfill_id_gen.get_current_token_for_writer( + instance_name, minimum + ) class PresenceStream(Stream): @@ -395,7 +407,7 @@ def __init__(self, hs: "HomeServer"): self.store.get_all_push_rule_updates, ) - def _current_token(self, instance_name: str) -> int: + def _current_token(self, instance_name: str, minimum: bool = False) -> int: push_rules_token = self.store.get_max_push_rules_stream_id() return push_rules_token diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 4046bdec6931..0fdfa618ca04 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -68,7 +68,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs.get_instance_name(), current_token, update_function) @staticmethod - def _stub_current_token(instance_name: str) -> int: + def _stub_current_token(instance_name: str, minimum: bool = False) -> int: # dummy current-token method for use on workers return 0 diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2179a8bf5922..9bf4a7135c46 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -184,6 +184,7 @@ def process_replication_position( ) -> None: if stream_name == CachesStream.NAME: if self._cache_id_gen: + logger.info("Advancing cache for %s to %s", instance_name, token) self._cache_id_gen.advance(instance_name, token) super().process_replication_position(stream_name, instance_name, token) @@ -402,8 +403,12 @@ def _send_invalidation_to_replication( }, ) - def get_cache_stream_token_for_writer(self, instance_name: str) -> int: + def get_cache_stream_token_for_writer( + self, instance_name: str, minimum: bool = False + ) -> int: if self._cache_id_gen: - return self._cache_id_gen.get_current_token_for_writer(instance_name) + return self._cache_id_gen.get_current_token_for_writer( + instance_name, minimum + ) else: return 0 diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index e66164c15fdd..765a8c76c939 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -119,7 +119,9 @@ def get_current_token(self) -> int: raise NotImplementedError() @abc.abstractmethod - def get_current_token_for_writer(self, instance_name: str) -> int: + def get_current_token_for_writer( + self, instance_name: str, minimum: bool = False + ) -> int: """Returns the position of the given writer. For streams with single writers this is equivalent to `get_current_token`. @@ -262,7 +264,9 @@ def get_current_token(self) -> int: return self._current - def get_current_token_for_writer(self, instance_name: str) -> int: + def get_current_token_for_writer( + self, instance_name: str, minimum: bool = False + ) -> int: return self.get_current_token() @@ -378,6 +382,8 @@ def __init__( self._current_positions.values(), default=1 ) + self._last_persisted_position = self._persisted_upto_position + def _load_current_ids( self, db_conn: LoggingDatabaseConnection, @@ -627,13 +633,16 @@ def _mark_id_as_finished(self, next_id: int) -> None: if new_cur: curr = self._current_positions.get(self._instance_name, 0) self._current_positions[self._instance_name] = max(curr, new_cur) + self._last_persisted_position = max(curr, new_cur) self._add_persisted_position(next_id) def get_current_token(self) -> int: return self.get_persisted_upto_position() - def get_current_token_for_writer(self, instance_name: str) -> int: + def get_current_token_for_writer( + self, instance_name: str, minimum: bool = False + ) -> int: # If we don't have an entry for the given instance name, we assume it's a # new writer. # @@ -641,10 +650,12 @@ def get_current_token_for_writer(self, instance_name: str) -> int: # persisted up to position. This stops Synapse from doing a full table # scan when a new writer announces itself over replication. with self._lock: - return self._return_factor * max( - self._current_positions.get(instance_name, 0), - self._persisted_upto_position, - ) + if minimum and instance_name == self._instance_name: + return self._last_persisted_position + else: + return self._return_factor * self._current_positions.get( + instance_name, self._persisted_upto_position + ) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map.