Skip to content

Commit

Permalink
In sync wait for worker to catch up since token (element-hq#17215)
Browse files Browse the repository at this point in the history
Otherwise things will get confused.

An alternative would be to make sure that for lagging stream we don't
return anything (and make sure the returned next_batch token doesn't go
backwards). But that is a faff.
  • Loading branch information
erikjohnston authored and H-Shay committed May 31, 2024
1 parent 44bee74 commit b9efff0
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 7 deletions.
1 change: 1 addition & 0 deletions changelog.d/17215.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where duplicate events could be sent down sync when using workers that are overloaded.
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,8 @@ netaddr = ">=0.7.18"
# add a lower bound to the Jinja2 dependency.
Jinja2 = ">=3.0"
bleach = ">=1.4.3"
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
# Additionally we need https://github.com/python/typing/pull/817 to allow types to be
# generic over ParamSpecs.
typing-extensions = ">=3.10.0.1"
# We use `Self`, which was added in `typing-extensions` 4.0.
typing-extensions = ">=4.0"
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
cryptography = ">=3.4.7"
Expand Down
35 changes: 35 additions & 0 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,23 @@ def __bool__(self) -> bool:
or self.device_lists
)

@staticmethod
def empty(next_batch: StreamToken) -> "SyncResult":
    """Build a `SyncResult` that carries no updates.

    Args:
        next_batch: Token to store in the result's `next_batch` field.

    Returns:
        A `SyncResult` whose list and dict fields are all empty and whose
        `device_lists` is a default-constructed `DeviceListUpdates`.
    """
    no_device_list_changes = DeviceListUpdates()

    return SyncResult(
        next_batch=next_batch,
        # Per-room sections.
        joined=[],
        invited=[],
        knocked=[],
        archived=[],
        # Per-user data.
        presence=[],
        account_data=[],
        # Device-related fields.
        to_device=[],
        device_lists=no_device_list_changes,
        device_one_time_keys_count={},
        device_unused_fallback_key_types=[],
    )


@attr.s(slots=True, frozen=True, auto_attribs=True)
class E2eeSyncResult:
Expand Down Expand Up @@ -497,6 +514,24 @@ async def _wait_for_sync_for_user(
if context:
context.tag = sync_label

if since_token is not None:
# We need to make sure this worker has caught up with the token. If
# this returns false it means we timed out waiting, and we should
# just return an empty response.
start = self.clock.time_msec()
if not await self.notifier.wait_for_stream_token(since_token):
logger.warning(
"Timed out waiting for worker to catch up. Returning empty response"
)
return SyncResult.empty(since_token)

# If we've spent significant time waiting to catch up, take it off
# the timeout.
now = self.clock.time_msec()
if now - start > 1_000:
timeout -= now - start
timeout = max(timeout, 0)

# if we have a since token, delete any to-device messages before that token
# (since we now know that the device has received them)
if since_token is not None:
Expand Down
23 changes: 23 additions & 0 deletions synapse/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,29 @@ async def check_for_updates(

return result

async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
    """Wait for this worker to catch up with the given stream token.

    Polls the current token from `self.event_sources` until `stream_token`
    is before-or-equal to it, pausing via `self.clock.sleep(0.5)` between
    polls (presumably seconds; confirm against the clock implementation).

    Args:
        stream_token: The token the worker needs to have reached.

    Returns:
        True as soon as `stream_token.is_before_or_eq(current_token)` holds.
        False once more than 10_000 ms (measured with
        `self.clock.time_msec()`) have elapsed without that happening.
    """

    start = self.clock.time_msec()
    # Log only the first time we have to wait. Logging on every poll
    # produces a burst of identical lines per lagging request, at exactly
    # the moment the worker is already struggling to keep up.
    logged = False
    while True:
        current_token = self.event_sources.get_current_token()
        if stream_token.is_before_or_eq(current_token):
            return True

        now = self.clock.time_msec()

        # Give up once we have waited for more than ten seconds in total.
        if now - start > 10_000:
            return False

        if not logged:
            logger.info(
                "Waiting for current token to reach %s; currently at %s",
                stream_token,
                current_token,
            )
            logged = True

        # TODO: be better
        await self.clock.sleep(0.5)

async def _get_room_ids(
self, user: UserID, explicit_room_id: Optional[str]
) -> Tuple[StrCollection, bool]:
Expand Down
7 changes: 7 additions & 0 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class DeltaState:
to_insert: StateMap[str]
no_longer_in_room: bool = False

def is_noop(self) -> bool:
    """Report whether applying this delta would change nothing.

    Returns:
        True when `to_delete` and `to_insert` are both empty and
        `no_longer_in_room` is falsy; False otherwise.
    """
    has_changes = self.to_delete or self.to_insert or self.no_longer_in_room
    return not has_changes


class PersistEventsStore:
"""Contains all the functions for writing events to the database.
Expand Down Expand Up @@ -1017,6 +1021,9 @@ async def update_current_state(
) -> None:
"""Update the current state stored in the datatabase for the given room"""

if state_delta.is_noop():
return

async with self._stream_id_gen.get_next() as stream_ordering:
await self.db_pool.runInteraction(
"update_current_state",
Expand Down
11 changes: 9 additions & 2 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,11 @@ def __init__(
notifier=hs.get_replication_notifier(),
stream_name="events",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
tables=[
("events", "instance_name", "stream_ordering"),
("current_state_delta_stream", "instance_name", "stream_id"),
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
],
sequence_name="events_stream_seq",
writers=hs.config.worker.writers.events,
)
Expand All @@ -210,7 +214,10 @@ def __init__(
notifier=hs.get_replication_notifier(),
stream_name="backfill",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
tables=[
("events", "instance_name", "stream_ordering"),
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
],
sequence_name="events_backfill_stream_seq",
positive=False,
writers=hs.config.worker.writers.events,
Expand Down
58 changes: 57 additions & 1 deletion synapse/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from immutabledict import immutabledict
from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey
from typing_extensions import TypedDict
from typing_extensions import Self, TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface

Expand Down Expand Up @@ -515,6 +515,27 @@ def get_stream_pos_for_instance(self, instance_name: str) -> int:
# at `self.stream`.
return self.instance_map.get(instance_name, self.stream)

def is_before_or_eq(self, other_token: Self) -> bool:
"""Wether this token is before the other token, i.e. every constituent
part is before the other.
Essentially it is `self <= other`.
Note: if `self.is_before_or_eq(other_token) is False` then that does not
imply that the reverse is True.
"""
if self.stream > other_token.stream:
return False

instances = self.instance_map.keys() | other_token.instance_map.keys()
for instance in instances:
if self.instance_map.get(
instance, self.stream
) > other_token.instance_map.get(instance, other_token.stream):
return False

return True


@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken(AbstractMultiWriterStreamToken):
Expand Down Expand Up @@ -1008,6 +1029,41 @@ def get_field(
"""Returns the stream ID for the given key."""
return getattr(self, key.value)

def is_before_or_eq(self, other_token: "StreamToken") -> bool:
    """Whether this token is before or equal to the other token, i.e. every
    constituent part is before or equal to the other's.

    Essentially it is `self <= other`, but the ordering is only partial:
    if `self.is_before_or_eq(other_token)` is False, that does not imply
    that `other_token.is_before_or_eq(self)` is True.
    """

    for key in StreamKeyType.__members__.values():
        if key == StreamKeyType.TYPING:
            # Typing stream is allowed to "reset", and so comparisons don't
            # really make sense as is.
            # TODO: Figure out a better way of tracking resets.
            continue

        mine = self.get_field(key)
        theirs = other_token.get_field(key)

        # Multi-writer tokens are compared with their own
        # `is_before_or_eq`; everything else is compared as a plain int.
        if isinstance(mine, RoomStreamToken):
            assert isinstance(theirs, RoomStreamToken)
            in_order = mine.is_before_or_eq(theirs)
        elif isinstance(mine, MultiWriterStreamToken):
            assert isinstance(theirs, MultiWriterStreamToken)
            in_order = mine.is_before_or_eq(theirs)
        else:
            assert isinstance(theirs, int)
            in_order = mine <= theirs

        if not in_order:
            return False

    return True


StreamToken.START = StreamToken(
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
Expand Down

0 comments on commit b9efff0

Please sign in to comment.