Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Wait for streams to catch up when processing HTTP replication. #14820

Merged
merged 10 commits into from
Jan 18, 2023
1 change: 1 addition & 0 deletions changelog.d/14820.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix rare races when using workers.
4 changes: 4 additions & 0 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,6 +2259,10 @@ async def persist_events_and_notify(
event_and_contexts, backfilled=backfilled
)

# After persistence we always need to notify replication there may
# be new data.
self._notifier.notify_replication()

Comment on lines +2262 to +2265
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this change, did we have to wait for something else to notify replication?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We poke the notifier below for all non-backfilled events, and since I don't think anything "waits" on the backfill stream, that has broadly been OK.

But yeah, it's not ideal. I kinda want to move the poke to replication closer to where we advance the stream tokens, but that proved a bit of a PITA due to circular dependencies.

if self._ephemeral_messages_enabled:
for event in events:
# If there's an expiry timestamp on the event, schedule its expiry.
Expand Down
97 changes: 88 additions & 9 deletions synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
import urllib.parse
from inspect import signature
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple

from prometheus_client import Counter, Gauge

Expand All @@ -27,6 +27,7 @@
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.opentracing import trace_with_opname
Expand All @@ -53,6 +54,9 @@
)


_STREAM_POSITION_KEY = "_INT_STREAM_POS"


class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""Helper base class for defining new replication HTTP endpoints.

Expand Down Expand Up @@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
a connection error is received.
RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
receiving connection errors, each will backoff exponentially longer.
WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to
catch up before processing the request and/or response. Defaults to
True.
"""

NAME: str = abc.abstractproperty() # type: ignore
Expand All @@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
RETRY_ON_CONNECT_ERROR = True
RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1)

WAIT_FOR_STREAMS: ClassVar[bool] = True

def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache: ResponseCache[str] = ResponseCache(
Expand All @@ -126,6 +135,10 @@ def __init__(self, hs: "HomeServer"):
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret

self._streams = hs.get_replication_command_handler().get_streams_to_replicate()
self._replication = hs.get_replication_data_handler()
self._instance_name = hs.get_instance_name()

def _check_auth(self, request: Request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
Expand Down Expand Up @@ -160,7 +173,7 @@ async def _serialize_payload(**kwargs) -> JsonDict:

@abc.abstractmethod
async def _handle_request(
self, request: Request, **kwargs: Any
self, request: Request, content: JsonDict, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Handle incoming request.

Expand Down Expand Up @@ -201,6 +214,10 @@ def make_client(cls, hs: "HomeServer") -> Callable:

@trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
# We have to pull these out here to avoid circular dependencies...
streams = hs.get_replication_command_handler().get_streams_to_replicate()
replication = hs.get_replication_data_handler()
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
Expand All @@ -219,6 +236,24 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:

data = await cls._serialize_payload(**kwargs)

if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS:
# Include the current stream positions that we write to. We
# don't do this for GETs as they don't have a body, and we
# generally assume that a GET won't rely on data we have
# written.
if _STREAM_POSITION_KEY in data:
raise Exception(
"data to send contains %r key", _STREAM_POSITION_KEY
)

data[_STREAM_POSITION_KEY] = {
"streams": {
stream.NAME: stream.current_token(local_instance_name)
for stream in streams
},
"instance_name": local_instance_name,
}

url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
Expand Down Expand Up @@ -308,6 +343,18 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
) from e

_outgoing_request_counter.labels(cls.NAME, 200).inc()

# Wait on any streams that the remote may have written to.
for stream_name, position in result.get(
_STREAM_POSITION_KEY, {}
).items():
await replication.wait_for_stream_position(
instance_name=instance_name,
stream_name=stream_name,
position=position,
raise_on_timeout=False,
)

return result

return send_request
Expand Down Expand Up @@ -353,6 +400,23 @@ async def _check_auth_and_handle(
if self._replication_secret:
self._check_auth(request)

if self.METHOD == "GET":
# GET APIs always have an empty body.
content = {}
else:
content = parse_json_object_from_request(request)

# Wait on any streams that the remote may have written to.
for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[
"streams"
].items():
await self._replication.wait_for_stream_position(
instance_name=content[_STREAM_POSITION_KEY]["instance_name"],
stream_name=stream_name,
position=position,
raise_on_timeout=False,
)

if self.CACHE:
txn_id = kwargs.pop("txn_id")

Expand All @@ -361,13 +425,28 @@ async def _check_auth_and_handle(
# correctly yet. In particular, there may be issues to do with logging
# context lifetimes.

return await self.response_cache.wrap(
txn_id, self._handle_request, request, **kwargs
code, response = await self.response_cache.wrap(
txn_id, self._handle_request, request, content, **kwargs
)
else:
# The `@cancellable` decorator may be applied to `_handle_request`. But we
# told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
# so we have to set up the cancellable flag ourselves.
request.is_render_cancellable = is_function_cancellable(
self._handle_request
)

code, response = await self._handle_request(request, content, **kwargs)

# Return streams we may have written to in the course of processing this
# request.
if _STREAM_POSITION_KEY in response:
raise Exception("data to send contains %r key", _STREAM_POSITION_KEY)

# The `@cancellable` decorator may be applied to `_handle_request`. But we
# told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
# so we have to set up the cancellable flag ourselves.
request.is_render_cancellable = is_function_cancellable(self._handle_request)
if self.WAIT_FOR_STREAMS:
response[_STREAM_POSITION_KEY] = {
stream.NAME: stream.current_token(self._instance_name)
for stream in self._streams
}

return await self._handle_request(request, **kwargs)
return code, response
29 changes: 16 additions & 13 deletions synapse/replication/http/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from twisted.web.server import Request

from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict

Expand Down Expand Up @@ -61,10 +60,8 @@ async def _serialize_payload( # type: ignore[override]
return payload

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str
self, request: Request, content: JsonDict, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_account_data_for_user(
user_id, account_data_type, content["content"]
)
Expand Down Expand Up @@ -101,7 +98,7 @@ async def _serialize_payload( # type: ignore[override]
return {}

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str
self, request: Request, content: JsonDict, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_account_data_for_user(
user_id, account_data_type
Expand Down Expand Up @@ -143,10 +140,13 @@ async def _serialize_payload( # type: ignore[override]
return payload

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str
self,
request: Request,
content: JsonDict,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, content["content"]
)
Expand Down Expand Up @@ -183,7 +183,12 @@ async def _serialize_payload( # type: ignore[override]
return {}

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str
self,
request: Request,
content: JsonDict,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_account_data_for_room(
user_id, room_id, account_data_type
Expand Down Expand Up @@ -225,10 +230,8 @@ async def _serialize_payload( # type: ignore[override]
return payload

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_tag_to_room(
user_id, room_id, tag, content["content"]
)
Expand Down Expand Up @@ -266,7 +269,7 @@ async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict:
return {}

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_tag_from_room(
user_id,
Expand Down
10 changes: 3 additions & 7 deletions synapse/replication/http/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from twisted.web.server import Request

from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.logging.opentracing import active_span
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
Expand Down Expand Up @@ -78,7 +77,7 @@ async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override
return {}

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id)

Expand Down Expand Up @@ -138,9 +137,8 @@ async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[o
return {"user_ids": user_ids}

async def _handle_request( # type: ignore[override]
self, request: Request
self, request: Request, content: JsonDict
) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
content = parse_json_object_from_request(request)
user_ids: List[str] = content["user_ids"]

logger.info("Resync for %r", user_ids)
Expand Down Expand Up @@ -205,10 +203,8 @@ async def _serialize_payload( # type: ignore[override]
}

async def _handle_request( # type: ignore[override]
self, request: Request
self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

user_id = content["user_id"]
device_id = content["device_id"]
keys = content["keys"]
Expand Down
28 changes: 9 additions & 19 deletions synapse/replication/http/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
from synapse.util.metrics import Measure
Expand Down Expand Up @@ -114,10 +113,8 @@ async def _serialize_payload( # type: ignore[override]

return payload

async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override]
async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override]
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)

room_id = content["room_id"]
backfilled = content["backfilled"]

Expand Down Expand Up @@ -181,13 +178,10 @@ async def _serialize_payload( # type: ignore[override]
return {"origin": origin, "content": content}

async def _handle_request( # type: ignore[override]
self, request: Request, edu_type: str
self, request: Request, content: JsonDict, edu_type: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)

origin = content["origin"]
edu_content = content["content"]
origin = content["origin"]
edu_content = content["content"]

logger.info("Got %r edu from %s", edu_type, origin)

Expand Down Expand Up @@ -231,13 +225,10 @@ async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # ty
return {"args": args}

async def _handle_request( # type: ignore[override]
self, request: Request, query_type: str
self, request: Request, content: JsonDict, query_type: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)

args = content["args"]
args["origin"] = content["origin"]
args = content["args"]
args["origin"] = content["origin"]

logger.info("Got %r query from %s", query_type, args["origin"])

Expand Down Expand Up @@ -274,7 +265,7 @@ async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override
return {}

async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
await self.store.clean_room_for_join(room_id)

Expand Down Expand Up @@ -307,9 +298,8 @@ async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDic
return {"room_version": room_version.identifier}

async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {}
Expand Down
Loading