From a2d8ed5439ed2e6e53abdaf6df9e50b65cc146c5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 11 Jan 2023 15:32:09 +0000 Subject: [PATCH 1/8] Always pass request body --- synapse/replication/http/_base.py | 13 ++++++++--- synapse/replication/http/account_data.py | 29 +++++++++++++----------- synapse/replication/http/devices.py | 10 +++----- synapse/replication/http/federation.py | 28 ++++++++--------------- synapse/replication/http/login.py | 5 +--- synapse/replication/http/membership.py | 22 ++++++++---------- synapse/replication/http/presence.py | 7 ++---- synapse/replication/http/push.py | 5 +--- synapse/replication/http/register.py | 9 ++------ synapse/replication/http/send_event.py | 5 +--- synapse/replication/http/send_events.py | 4 +--- synapse/replication/http/state.py | 2 +- synapse/replication/http/streams.py | 2 +- tests/replication/http/test__base.py | 4 ++-- 14 files changed, 60 insertions(+), 85 deletions(-) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 3f4d3fc51ae3..3fccc380c078 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -27,6 +27,7 @@ from synapse.api.errors import HttpResponseException, SynapseError from synapse.http import RequestTimedOutError from synapse.http.server import HttpServer +from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging import opentracing from synapse.logging.opentracing import trace_with_opname @@ -160,7 +161,7 @@ async def _serialize_payload(**kwargs) -> JsonDict: @abc.abstractmethod async def _handle_request( - self, request: Request, **kwargs: Any + self, request: Request, content: JsonDict, **kwargs: Any ) -> Tuple[int, JsonDict]: """Handle incoming request. @@ -353,6 +354,12 @@ async def _check_auth_and_handle( if self._replication_secret: self._check_auth(request) + if self.METHOD == "GET": + # GET APIs always have an empty body. + content = {} + else: + content = parse_json_object_from_request(request) + if self.CACHE: txn_id = kwargs.pop("txn_id") @@ -362,7 +369,7 @@ async def _check_auth_and_handle( # context lifetimes. return await self.response_cache.wrap( - txn_id, self._handle_request, request, **kwargs + txn_id, self._handle_request, request, content, **kwargs ) # The `@cancellable` decorator may be applied to `_handle_request`. But we # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`, # so we have to set up the cancellable flag ourselves.
request.is_render_cancellable = is_function_cancellable(self._handle_request) - return await self._handle_request(request, **kwargs) + return await self._handle_request(request, content, **kwargs) diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 0edc95977b3a..2374f810c94f 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -61,10 +60,8 @@ async def _serialize_payload( # type: ignore[override] return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, account_data_type: str + self, request: Request, content: JsonDict, user_id: str, account_data_type: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - max_stream_id = await self.handler.add_account_data_for_user( user_id, account_data_type, content["content"] ) @@ -101,7 +98,7 @@ async def _serialize_payload( # type: ignore[override] return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, account_data_type: str + self, request: Request, content: JsonDict, user_id: str, account_data_type: str ) -> Tuple[int, JsonDict]: max_stream_id = await self.handler.remove_account_data_for_user( user_id, account_data_type @@ -143,10 +140,13 @@ async def _serialize_payload( # type: ignore[override] return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, account_data_type: str + self, + request: Request, + content: JsonDict, + user_id: str, + room_id: str, + account_data_type: str, ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - max_stream_id = await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, content["content"] ) @@ -183,7 +183,12 @@ async def _serialize_payload( # type: ignore[override] return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, account_data_type: str + self, + request: Request, + content: JsonDict, + user_id: str, + room_id: str, + account_data_type: str, ) -> Tuple[int, JsonDict]: max_stream_id = await self.handler.remove_account_data_for_room( user_id, room_id, account_data_type @@ -225,10 +230,8 @@ async def _serialize_payload( # type: ignore[override] return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, tag: str + self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - max_stream_id = await self.handler.add_tag_to_room( user_id, room_id, tag, content["content"] ) @@ -266,7 +269,7 @@ async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, tag: str + self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str ) -> Tuple[int, JsonDict]: max_stream_id = await self.handler.remove_tag_from_room( user_id, diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index ea5c08e6cfdf..ecea6fc915c7 100644 --- 
a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.logging.opentracing import active_span from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -78,7 +77,7 @@ async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) @@ -138,9 +137,8 @@ async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[o return {"user_ids": user_ids} async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, Dict[str, Optional[JsonDict]]]: - content = parse_json_object_from_request(request) user_ids: List[str] = content["user_ids"] logger.info("Resync for %r", user_ids) @@ -205,10 +203,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - user_id = content["user_id"] device_id = content["device_id"] keys = content["keys"] diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index d3abafed2871..53ad32703029 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -21,7 +21,6 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict from synapse.util.metrics import Measure @@ -114,10 +113,8 @@ async def _serialize_payload( # type: ignore[override] return payload - async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override] + async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override] with Measure(self.clock, "repl_fed_send_events_parse"): - content = parse_json_object_from_request(request) - room_id = content["room_id"] backfilled = content["backfilled"] @@ -181,13 +178,10 @@ async def _serialize_payload( # type: ignore[override] return {"origin": origin, "content": content} async def _handle_request( # type: ignore[override] - self, request: Request, edu_type: str + self, request: Request, content: JsonDict, edu_type: str ) -> Tuple[int, JsonDict]: - with Measure(self.clock, "repl_fed_send_edu_parse"): - content = parse_json_object_from_request(request) - - origin = content["origin"] - edu_content = content["content"] + origin = content["origin"] + edu_content = content["content"] logger.info("Got %r edu from %s", edu_type, origin) @@ -231,13 +225,10 @@ async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # ty return {"args": args} async def _handle_request( # type: ignore[override] - self, request: Request, query_type: str + self, request: Request, content: JsonDict, query_type: str ) -> Tuple[int, JsonDict]: - with 
Measure(self.clock, "repl_fed_query_parse"): - content = parse_json_object_from_request(request) - - args = content["args"] - args["origin"] = content["origin"] + args = content["args"] + args["origin"] = content["origin"] logger.info("Got %r query from %s", query_type, args["origin"]) @@ -274,7 +265,7 @@ async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override return {} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str + self, request: Request, content: JsonDict, room_id: str ) -> Tuple[int, JsonDict]: await self.store.clean_room_for_join(room_id) @@ -307,9 +298,8 @@ async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDic return {"room_version": room_version.identifier} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str + self, request: Request, content: JsonDict, room_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] await self.store.maybe_store_room_on_outlier_membership(room_id, room_version) return 200, {} diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index c68e18da129b..6ad6cb1bfe4e 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -73,10 +72,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 663bff573848..9fa1060d48f6 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -17,7 +17,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID @@ -79,10 +78,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: SynapseRequest, room_id: str, user_id: str + self, request: SynapseRequest, content: JsonDict, room_id: str, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - remote_room_hosts = content["remote_room_hosts"] event_content = content["content"] @@ -147,11 +144,10 @@ async def _serialize_payload( # type: ignore[override] async def _handle_request( # type: ignore[override] self, request: SynapseRequest, + content: JsonDict, room_id: str, user_id: str, ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - remote_room_hosts = content["remote_room_hosts"] event_content = content["content"] @@ -217,10 +213,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: 
SynapseRequest, invite_event_id: str + self, request: SynapseRequest, content: JsonDict, invite_event_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - txn_id = content["txn_id"] event_content = content["content"] @@ -285,10 +279,9 @@ async def _serialize_payload( # type: ignore[override] async def _handle_request( # type: ignore[override] self, request: SynapseRequest, + content: JsonDict, knock_event_id: str, ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - txn_id = content["txn_id"] event_content = content["content"] @@ -347,7 +340,12 @@ async def _serialize_payload( # type: ignore[override] return {} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str, user_id: str, change: str + self, + request: Request, + content: JsonDict, + room_id: str, + user_id: str, + change: str, ) -> Tuple[int, JsonDict]: logger.info("user membership change: %s in %s", user_id, room_id) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index 4a5b08f56f73..db16aac9c206 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, UserID @@ -56,7 +55,7 @@ async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: await self._presence_handler.bump_presence_active_time( UserID.from_string(user_id) @@ -107,10 +106,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - await self._presence_handler.set_state( UserID.from_string(user_id), content["state"], diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py index af5c2f66a735..297e8ad564bd 100644 --- a/synapse/replication/http/push.py +++ b/synapse/replication/http/push.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -61,10 +60,8 @@ async def _serialize_payload(app_id: str, pushkey: str, user_id: str) -> JsonDic return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - app_id = content["app_id"] pushkey = content["pushkey"] diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 976c2833603d..265e601b96a9 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from 
synapse.types import JsonDict @@ -96,10 +95,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - await self.registration_handler.check_registration_ratelimit(content["address"]) # Always default admin users to approved (since it means they were created by @@ -150,10 +147,8 @@ async def _serialize_payload( # type: ignore[override] return {"auth_result": auth_result, "access_token": access_token} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - auth_result = content["auth_result"] access_token = content["access_token"] diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 4215a1c1bc41..27ad91407502 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -21,7 +21,6 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID from synapse.util.metrics import Measure @@ -114,11 +113,9 @@ async def _serialize_payload( # type: ignore[override] return payload async def _handle_request( # type: ignore[override] - self, request: Request, event_id: str + self, request: Request, content: JsonDict, event_id: str ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_send_event_parse"): - content = parse_json_object_from_request(request) - event_dict = content["event"] room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]] internal_metadata = content["internal_metadata"] diff --git a/synapse/replication/http/send_events.py b/synapse/replication/http/send_events.py index 8889bbb644e1..4f82c9f96daa 100644 --- a/synapse/replication/http/send_events.py +++ b/synapse/replication/http/send_events.py @@ -21,7 +21,6 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID from synapse.util.metrics import Measure @@ -114,10 +113,9 @@ async def _serialize_payload( # type: ignore[override] return payload async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, payload: JsonDict ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_send_events_parse"): - payload = parse_json_object_from_request(request) events_and_context = [] events = payload["events"] diff --git a/synapse/replication/http/state.py b/synapse/replication/http/state.py index 838b7584e56f..0c524e7de3fd 100644 --- a/synapse/replication/http/state.py +++ b/synapse/replication/http/state.py @@ -57,7 +57,7 @@ async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override return {} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str + self, request: Request, content: JsonDict, room_id: str ) -> 
Tuple[int, JsonDict]: writer_instance = self._events_shard_config.get_instance(room_id) if writer_instance != self._instance_name: diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index c06522536254..3a7b8a1a4ceb 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -67,7 +67,7 @@ async def _serialize_payload( # type: ignore[override] return {"from_token": from_token, "upto_token": upto_token} async def _handle_request( # type: ignore[override] - self, request: Request, stream_name: str + self, request: Request, content: JsonDict, stream_name: str ) -> Tuple[int, JsonDict]: stream = self.streams.get(stream_name) if stream is None: diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py index 936ab4504a79..349584941e2d 100644 --- a/tests/replication/http/test__base.py +++ b/tests/replication/http/test__base.py @@ -44,7 +44,7 @@ async def _serialize_payload() -> JsonDict: @cancellable async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, JsonDict]: await self.clock.sleep(1.0) return HTTPStatus.OK, {"result": True} @@ -64,7 +64,7 @@ async def _serialize_payload() -> JsonDict: return {} async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, JsonDict]: await self.clock.sleep(1.0) return HTTPStatus.OK, {"result": True} From e36ff7bd278da57a849a28ef2bc3522bb3c2e560 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 11 Jan 2023 15:51:48 +0000 Subject: [PATCH 2/8] Wait for streams to catch up when processing HTTP replication. This should hopefully mitigate a class of races where data gets out of sync due to an HTTP replication request racing with the replication streams. --- synapse/replication/http/_base.py | 76 +++++++++++++++++++++++++--- tests/replication/http/test__base.py | 4 +- 2 files changed, 72 insertions(+), 8 deletions(-) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 3fccc380c078..225b2bc46437 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -54,6 +54,9 @@ ) +_STREAM_POSITION_KEY = "_INT_STREAM_POS" + + class ReplicationEndpoint(metaclass=abc.ABCMeta): """Helper base class for defining new replication HTTP endpoints. @@ -127,6 +130,10 @@ def __init__(self, hs: "HomeServer"): if hs.config.worker.worker_replication_secret: self._replication_secret = hs.config.worker.worker_replication_secret + self._streams = hs.get_replication_command_handler().get_streams_to_replicate() + self._replication = hs.get_replication_data_handler() + self._instance_name = hs.get_instance_name() + def _check_auth(self, request: Request) -> None: # Get the authorization header. auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") @@ -202,6 +209,10 @@ def make_client(cls, hs: "HomeServer") -> Callable: @trace_with_opname("outgoing_replication_request") async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: + # We have to pull these out here to avoid circular dependencies...
+ streams = hs.get_replication_command_handler().get_streams_to_replicate() + replication = hs.get_replication_data_handler() + with outgoing_gauge.track_inprogress(): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") @@ -220,6 +231,24 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: data = await cls._serialize_payload(**kwargs) + if cls.METHOD != "GET": + # Include the current stream positions that we write to. We + # don't do this for GETs as they don't have a body, and we + # generally assume that a GET won't rely on data we have + # written. + if _STREAM_POSITION_KEY in data: + raise Exception( + "data to send contains %r key", _STREAM_POSITION_KEY + ) + + data[_STREAM_POSITION_KEY] = { + "streams": { + stream.NAME: stream.current_token(local_instance_name) + for stream in streams + }, + "instance_name": local_instance_name, + } + url_args = [ urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS ] @@ -309,6 +338,17 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: ) from e _outgoing_request_counter.labels(cls.NAME, 200).inc() + + # Wait on any streams that the remote may have written to. + for stream_name, position in result.get( + _STREAM_POSITION_KEY, {} + ).items(): + await replication.wait_for_stream_position( + instance_name=instance_name, + stream_name=stream_name, + position=position, + ) + return result return send_request @@ -360,6 +400,16 @@ async def _check_auth_and_handle( else: content = parse_json_object_from_request(request) + # Wait on any streams that the remote may have written to. + for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[ + "streams" + ].items(): + await self._replication.wait_for_stream_position( + instance_name=content[_STREAM_POSITION_KEY]["instance_name"], + stream_name=stream_name, + position=position, + ) + if self.CACHE: txn_id = kwargs.pop("txn_id") @@ -368,13 +418,27 @@ async def _check_auth_and_handle( # correctly yet. In particular, there may be issues to do with logging # context lifetimes. - return await self.response_cache.wrap( + code, response = await self.response_cache.wrap( txn_id, self._handle_request, request, content, **kwargs ) + else: + # The `@cancellable` decorator may be applied to `_handle_request`. But we + # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`, + # so we have to set up the cancellable flag ourselves. + request.is_render_cancellable = is_function_cancellable( + self._handle_request + ) + + code, response = await self._handle_request(request, content, **kwargs) + + # Return streams we may have written to in the course of processing this + # request. + if _STREAM_POSITION_KEY in response: + raise Exception("data to send contains %r key", _STREAM_POSITION_KEY) - # The `@cancellable` decorator may be applied to `_handle_request`. But we - # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`, - # so we have to set up the cancellable flag ourselves. 
- request.is_render_cancellable = is_function_cancellable(self._handle_request) + response[_STREAM_POSITION_KEY] = { + stream.NAME: stream.current_token(self._instance_name) + for stream in self._streams + } - return await self._handle_request(request, content, **kwargs) + return code, response diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py index 349584941e2d..75a2b1e0dec6 100644 --- a/tests/replication/http/test__base.py +++ b/tests/replication/http/test__base.py @@ -85,7 +85,7 @@ def create_test_resource(self): def test_cancellable_disconnect(self) -> None: """Test that handlers with the `@cancellable` flag can be cancelled.""" path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/" - channel = self.make_request("POST", path, await_result=False) + channel = self.make_request("POST", path, await_result=False, content={}) test_disconnect( self.reactor, channel, @@ -96,7 +96,7 @@ def test_cancellable_disconnect(self) -> None: def test_uncancellable_disconnect(self) -> None: """Test that handlers without the `@cancellable` flag cannot be cancelled.""" path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/" - channel = self.make_request("POST", path, await_result=False) + channel = self.make_request("POST", path, await_result=False, content={}) test_disconnect( self.reactor, channel, From 7f2700bd6a2e602ea5f21a7efdac37ba20bdb0d6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 11 Jan 2023 17:51:48 +0000 Subject: [PATCH 3/8] Don't wait for streams when asking for stream updates Otherwise we can deadlock as we wait for the positions we are asking for. --- synapse/replication/http/_base.py | 16 +++++++++++----- synapse/replication/http/streams.py | 4 ++++ tests/replication/http/test__base.py | 1 + 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 225b2bc46437..b95b4777975c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -98,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): a connection error is received. RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when receiving connection errors, each will backoff exponentially longer. + WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to + catch up before processing the request and/or response. Defaults to + True. """ NAME: str = abc.abstractproperty() # type: ignore @@ -108,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): RETRY_ON_CONNECT_ERROR = True RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1) + WAIT_FOR_STREAMS = True + def __init__(self, hs: "HomeServer"): if self.CACHE: self.response_cache: ResponseCache[str] = ResponseCache( @@ -231,7 +236,7 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: data = await cls._serialize_payload(**kwargs) - if cls.METHOD != "GET": + if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS: # Include the current stream positions that we write to. 
We # don't do this for GETs as they don't have a body, and we # generally assume that a GET won't rely on data we have # written. @@ -436,9 +441,10 @@ async def _check_auth_and_handle( if _STREAM_POSITION_KEY in response: raise Exception("data to send contains %r key", _STREAM_POSITION_KEY) - response[_STREAM_POSITION_KEY] = { - stream.NAME: stream.current_token(self._instance_name) - for stream in self._streams - } + if self.WAIT_FOR_STREAMS: + response[_STREAM_POSITION_KEY] = { + stream.NAME: stream.current_token(self._instance_name) + for stream in self._streams + } return code, response diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index 3a7b8a1a4ceb..3c7b5b18eab8 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -54,6 +54,10 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): PATH_ARGS = ("stream_name",) METHOD = "GET" + # We don't want to wait for replication streams to catch up, as this gets + # called in the process of catching replication streams up. + WAIT_FOR_STREAMS = False + def __init__(self, hs: "HomeServer"): super().__init__(hs) diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py index 75a2b1e0dec6..e03d9b4cc098 100644 --- a/tests/replication/http/test__base.py +++ b/tests/replication/http/test__base.py @@ -54,6 +54,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint): NAME = "uncancellable_sleep" PATH_ARGS = () CACHE = False + WAIT_FOR_STREAMS = False def __init__(self, hs: HomeServer): super().__init__(hs) From 473cc101ee1438d80523add0c8006bf0de9a9fba Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Jan 2023 12:02:07 +0000 Subject: [PATCH 4/8] Send out `POSITION` commands for all streams This is so that if a stream advances its position *without* writing a row to the stream, other instances will get told about the updated position quickly anyway. --- synapse/replication/tcp/resource.py | 43 +++++++++++++---------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 99f09669f00b..9d17eff71451 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -199,33 +199,28 @@ async def _run_notifier_loop(self) -> None: # The token has advanced but there is no data to # send, so we send a `POSITION` to inform other # workers of the updated position. - if stream.NAME == EventsStream.NAME: - # XXX: We only do this for the EventStream as it - # turns out that e.g. account data streams share - # their "current token" with each other, meaning - # that it is *not* safe to send a POSITION. - - # Note: `last_token` may not *actually* be the - # last token we sent out in a RDATA or POSITION. - # This can happen if we sent out an RDATA for - # position X when our current token was say X+1. - # Other workers will see RDATA for X and then a - # POSITION with last token of X+1, which will - # cause them to check if there were any missing - # updates between X and X+1. - logger.info( - "Sending position: %s -> %s", + + # Note: `last_token` may not *actually* be the + # last token we sent out in a RDATA or POSITION. + # This can happen if we sent out an RDATA for + # position X when our current token was say X+1. + # Other workers will see RDATA for X and then a + # POSITION with last token of X+1, which will + # cause them to check if there were any missing + # updates between X and X+1.
+ logger.info( + "Sending position: %s -> %s", + stream.NAME, + current_token, + ) + self.command_handler.send_command( + PositionCommand( stream.NAME, + self._instance_name, + last_token, current_token, ) - self.command_handler.send_command( - PositionCommand( - stream.NAME, - self._instance_name, - last_token, - current_token, - ) - ) + ) continue # Some streams return multiple rows with the same stream IDs, From a03ee6ec7d200dda62c0ef96511c73f410c97360 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2023 09:29:38 +0000 Subject: [PATCH 5/8] Change ID generator to return position of last write This is already true when asking for stream positions of other instances, but for our own instance we have fudged it. Changing this should be fine (it was just an optimisation), and I don't think it should have much impact in practice at all. The reason to do this is so that when we tell remotes what our current position is we only include *our* writes, rather than writes of other instances. This reduces delays when the remote instance is waiting for stream positions to update. In practice, this is probably only a problem for tests, though we may as well do it for all of them. --- synapse/handlers/federation_event.py | 4 ++++ synapse/storage/util/id_generators.py | 34 +++++++++++++++------------ synapse/types/__init__.py | 6 +++++ tests/storage/test_id_generators.py | 20 +++++++--------- 4 files changed, 38 insertions(+), 26 deletions(-) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 6df000faafed..904a721483c9 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2259,6 +2259,10 @@ async def persist_events_and_notify( event_and_contexts, backfilled=backfilled ) + # After persistence we always need to notify replication there may + # be new data. + self._notifier.notify_replication() + if self._ephemeral_messages_enabled: for event in events: # If there's an expiry timestamp on the event, schedule its expiry. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 0d7108f01b41..8670ffbfa374 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -378,6 +378,12 @@ def __init__( self._current_positions.values(), default=1 ) + if not writers: + # If there have been no explicit writers given then any instance can + # write to the stream. In which case, let's pre-seed our own + # position with the current minimum. + self._current_positions[self._instance_name] = self._persisted_upto_position + def _load_current_ids( self, db_conn: LoggingDatabaseConnection, @@ -695,24 +701,22 @@ def _add_persisted_position(self, new_id: int) -> None: heapq.heappush(self._known_persisted_positions, new_id) - # If we're a writer and we don't have any active writes we update our - # current position to the latest position seen. This allows the instance - # to report a recent position when asked, rather than a potentially old - # one (if this instance hasn't written anything for a while). - our_current_position = self._current_positions.get(self._instance_name) - if ( - our_current_position - and not self._unfinished_ids - and not self._in_flight_fetches - ): - self._current_positions[self._instance_name] = max( - our_current_position, new_id - ) - # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted).
- min_curr = min(self._current_positions.values(), default=0) + our_current_position = self._current_positions.get(self._instance_name, 0) + min_curr = min( + ( + token + for name, token in self._current_positions.items() + if name != self._instance_name + ), + default=our_current_position, + ) + + if our_current_position and (self._unfinished_ids or self._in_flight_fetches): + min_curr = min(min_curr, our_current_position) + self._persisted_upto_position = max(min_curr, self._persisted_upto_position) # We now iterate through the seen positions, discarding those that are diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 0c725eb9677d..c59eca24301e 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -604,6 +604,12 @@ async def to_string(self, store: "DataStore") -> str: elif self.instance_map: entries = [] for name, pos in self.instance_map.items(): + if pos <= self.stream: + # Ignore instances who are below the minimum stream position + # (we might know they've advanced without seeing a recent + # write from them). + continue + instance_id = await store.get_id_for_instance(name) entries.append(f"{instance_id}.{pos}") diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index d6a2b8d2743e..ff9691c518bc 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -349,8 +349,8 @@ def test_multi_instance(self) -> None: # The first ID gen will notice that it can advance its token to 7 as it # has no in progress writes... - self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7}) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) # ... but the second ID gen doesn't know that. @@ -366,8 +366,9 @@ async def _get_next_async() -> None: self.assertEqual(stream_id, 8) self.assertEqual( - first_id_gen.get_positions(), {"first": 7, "second": 7} + first_id_gen.get_positions(), {"first": 3, "second": 7} ) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.get_success(_get_next_async()) @@ -473,7 +474,7 @@ def test_get_persisted_upto_position_get_next(self) -> None: id_gen = self._create_id_generator("first", writers=["first", "second"]) - self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5}) + self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) self.assertEqual(id_gen.get_persisted_upto_position(), 5) @@ -720,7 +721,7 @@ async def _get_next_async2() -> None: self.get_success(_get_next_async2()) - self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2}) + self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) @@ -816,15 +817,12 @@ def test_load_existing_stream(self) -> None: first_id_gen = self._create_id_generator("first", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"]) - # The first ID gen will notice that it can advance its token to 7 as it - # has no in progress writes... 
- self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6}) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) - # ... but the second ID gen doesn't know that. self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) - self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) + self.assertEqual(second_id_gen.get_persisted_upto_position(), 7) From 6edfd625d1b59189c360e752048672a7999f04c3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 11 Jan 2023 15:53:45 +0000 Subject: [PATCH 6/8] Newsfile --- changelog.d/14820.bugfix | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/14820.bugfix diff --git a/changelog.d/14820.bugfix b/changelog.d/14820.bugfix new file mode 100644 index 000000000000..36e94f2b9b96 --- /dev/null +++ b/changelog.d/14820.bugfix @@ -0,0 +1 @@ +Fix rare races when using workers. From 01ae502c0c8dfa5b758f2ba023ef1a2b55d50645 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Jan 2023 13:43:51 +0000 Subject: [PATCH 7/8] Make ClassVar --- synapse/replication/http/_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index b95b4777975c..908f3f1db7da 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -17,7 +17,7 @@ import re import urllib.parse from inspect import signature -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple from prometheus_client import Counter, Gauge @@ -111,7 +111,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): RETRY_ON_CONNECT_ERROR = True RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1) - WAIT_FOR_STREAMS = True + WAIT_FOR_STREAMS: ClassVar[bool] = True def __init__(self, hs: "HomeServer"): if self.CACHE: From ace2b8c5763fe889450d23afd313fb212e96fd4a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Jan 2023 13:49:29 +0000 Subject: [PATCH 8/8] Don't fail if waiting for stream update times out --- synapse/replication/http/_base.py | 2 ++ synapse/replication/tcp/client.py | 25 +++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 908f3f1db7da..709327b97fb4 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -352,6 +352,7 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: instance_name=instance_name, stream_name=stream_name, position=position, + raise_on_timeout=False, ) return result @@ -413,6 +414,7 @@ async def _check_auth_and_handle( instance_name=content[_STREAM_POSITION_KEY]["instance_name"], stream_name=stream_name, position=position, + raise_on_timeout=False, ) if self.CACHE: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 322d695bc7f0..5c2482e40cb6 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,6 +16,7 @@
import logging from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.protocol import ReconnectingClientFactory @@ -314,10 +315,21 @@ def on_remote_server_up(self, server: str) -> None: self.send_handler.wake_destination(server) async def wait_for_stream_position( - self, instance_name: str, stream_name: str, position: int + self, + instance_name: str, + stream_name: str, + position: int, + raise_on_timeout: bool = True, ) -> None: """Wait until this instance has received updates up to and including the given stream position. + + Args: + instance_name + stream_name + position + raise_on_timeout: Whether to raise an exception if we time out + waiting for the updates, or to log an error and return. """ if instance_name == self._instance_name: @@ -345,7 +357,16 @@ async def wait_for_stream_position( # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): logger.info("Waiting for repl stream %r to reach %s", stream_name, position) - await make_deferred_yieldable(deferred) + try: + await make_deferred_yieldable(deferred) + except defer.TimeoutError: + logger.error("Timed out waiting for stream %s", stream_name) + + if raise_on_timeout: + raise + + return + logger.info( "Finished waiting for repl stream %r to reach %s", stream_name, position )
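Taken together, the patches implement a simple handshake: the caller stamps the stream positions it has written up to into the request body under `_STREAM_POSITION_KEY`, the receiving worker waits for its replication streams to catch up to those positions before handling the request, and the response carries the handler's own post-write positions back so the caller can wait in turn. The sketch below illustrates the shape of that handshake. It is an illustration only: `PositionTracker`, `advance` and `wait_for_position` are hypothetical stand-ins for Synapse's replication data handler, not its real API. Note the asymmetry mirrored from the patches: the request nests positions under "streams" alongside an "instance_name", while the response maps stream names to tokens directly.

import asyncio
from typing import Any, Dict, Tuple

_STREAM_POSITION_KEY = "_INT_STREAM_POS"


class PositionTracker:
    """Hypothetical stand-in: records the highest position seen per
    (instance, stream) and lets callers block until replication catches up."""

    def __init__(self) -> None:
        self._positions: Dict[Tuple[str, str], int] = {}
        self._cond = asyncio.Condition()

    async def advance(self, instance: str, stream: str, position: int) -> None:
        # Called as RDATA/POSITION commands arrive over replication.
        async with self._cond:
            key = (instance, stream)
            self._positions[key] = max(self._positions.get(key, 0), position)
            self._cond.notify_all()

    async def wait_for_position(self, instance: str, stream: str, position: int) -> None:
        # Block until replication from `instance` on `stream` reaches `position`.
        # (Patch 8 additionally bounds this wait and, with raise_on_timeout=False,
        # logs an error instead of failing the request on a timeout.)
        async with self._cond:
            await self._cond.wait_for(
                lambda: self._positions.get((instance, stream), 0) >= position
            )


async def handle_replication_request(
    tracker: PositionTracker, content: Dict[str, Any]
) -> Dict[str, Any]:
    # Server side (patch 2): catch up to every position the caller claims to
    # have written before doing any work.
    positions = content.pop(_STREAM_POSITION_KEY, {"streams": {}})
    for stream_name, position in positions["streams"].items():
        await tracker.wait_for_position(positions["instance_name"], stream_name, position)
    # ... handle the request proper, then echo our own positions back
    # (stream name -> token, illustrative values) so the caller can wait
    # on anything this handler wrote.
    return {"result": True, _STREAM_POSITION_KEY: {"events": 8}}


async def main() -> None:
    tracker = PositionTracker()
    # Client side (patch 2): stamp our current positions into the body.
    request = {
        "foo": "bar",
        _STREAM_POSITION_KEY: {"streams": {"events": 7}, "instance_name": "worker1"},
    }
    # Simulate replication catching up while the request is in flight.
    catch_up = asyncio.create_task(tracker.advance("worker1", "events", 7))
    response = await handle_replication_request(tracker, request)
    await catch_up
    print(response)


asyncio.run(main())

In the real patches the endpoint also refuses payloads that already contain `_STREAM_POSITION_KEY`, and `ReplicationGetStreamUpdates` opts out via `WAIT_FOR_STREAMS = False`, since that endpoint is itself the mechanism by which streams catch up and waiting there would deadlock.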