Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Persist auth/state events at backwards extremities when we fetch them (
Browse files Browse the repository at this point in the history
…#6526)

* commit 'bc7de8765':
  Persist auth/state events at backwards extremities when we fetch them (#6526)
  • Loading branch information
anoadragon453 committed Mar 19, 2020
2 parents 63f56c0 + bc7de87 commit 3fcb360
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 169 deletions.
1 change: 1 addition & 0 deletions changelog.d/6526.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs.
247 changes: 80 additions & 167 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
from synapse.util import batch_iter, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
Expand Down Expand Up @@ -238,7 +237,6 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
return None

state = None
auth_chain = []

# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
Expand Down Expand Up @@ -348,7 +346,6 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:

# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
auth_chains = set()
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
Expand All @@ -369,24 +366,14 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
"Requesting state at missing prev_event %s", event_id,
)

room_version = await self.store.get_room_version(room_id)

with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
(
remote_state,
got_auth_chain,
) = await self._get_state_for_room(
(remote_state, _,) = await self._get_state_for_room(
origin, room_id, p, include_event_in_state=True
)

# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
auth_chains.update(got_auth_chain)

remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
Expand All @@ -395,6 +382,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
for x in remote_state:
event_map[x.event_id] = x

room_version = await self.store.get_room_version(room_id)
state_map = await resolve_events_with_store(
room_id,
room_version,
Expand All @@ -416,7 +404,6 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
event_map.update(evs)

state = [event_map[e] for e in six.itervalues(state_map)]
auth_chain = list(auth_chains)
except Exception:
logger.warning(
"[%s %s] Error attempting to resolve state at missing "
Expand All @@ -432,9 +419,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
affected=event_id,
)

await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain
)
await self._process_received_pdu(origin, pdu, state=state)

async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Expand Down Expand Up @@ -633,10 +618,7 @@ async def _get_events_from_store_or_dest(
) -> Dict[str, EventBase]:
"""Fetch events from a remote destination, checking if we already have them.
Args:
destination
room_id
event_ids
Persists any events we don't already have as outliers.
If we fail to fetch any of the events, a warning will be logged, and the event
will be omitted from the result. Likewise, any events which turn out not to
Expand All @@ -656,27 +638,15 @@ async def _get_events_from_store_or_dest(
room_id,
)

room_version = await self.store.get_room_version(room_id)

# XXX 20 requests at once? really?
for batch in batch_iter(missing_events, 20):
deferreds = [
run_in_background(
self.federation_client.get_pdu,
destinations=[destination],
event_id=e_id,
room_version=room_version,
)
for e_id in batch
]

res = await make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)
await self._get_events_and_persist(
destination=destination, room_id=room_id, events=missing_events
)

for success, result in res:
if success and result:
fetched_events[result.event_id] = result
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
(await self.store.get_events(missing_events, allow_rejected=True))
)

# check for events which were in the wrong room.
#
Expand Down Expand Up @@ -705,50 +675,26 @@ async def _get_events_from_store_or_dest(

return fetched_events

async def _process_received_pdu(self, origin, event, state, auth_chain):
async def _process_received_pdu(
self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
Args:
origin: server sending the event
event: event to be persisted
state: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
"""
room_id = event.room_id
event_id = event.event_id

logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)

event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}

seen_ids = await self.store.have_seen_events(event_ids)

if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)

event_infos = []

for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = e.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
event_infos.append(_NewEventInfo(event=e, auth_events=auth))
seen_ids.add(e.event_id)

logger.info(
"[%s %s] persisting newly-received auth/state events %s",
room_id,
event_id,
[e.event.event_id for e in event_infos],
)
await self._handle_new_events(origin, event_infos)

try:
context = await self._handle_new_event(origin, event, state=state)
except AuthError as e:
Expand Down Expand Up @@ -803,8 +749,6 @@ async def backfill(self, dest, room_id, limit, extremities):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")

room_version = await self.store.get_room_version(room_id)

events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
Expand Down Expand Up @@ -833,6 +777,9 @@ async def backfill(self, dest, room_id, limit, extremities):

event_ids = set(e.event_id for e in events)

# build a list of events whose prev_events weren't in the batch.
# (XXX: this will include events whose prev_events we already have; that doesn't
# sound right?)
edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]

logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
Expand Down Expand Up @@ -861,95 +808,11 @@ async def backfill(self, dest, room_id, limit, extremities):
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
missing_auth = required_auth - set(auth_events)
failed_to_fetch = set()

# Try and fetch any missing auth events from both DB and remote servers.
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth)
ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events)

required_auth.update(
a_id for event in ret_events.values() for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)

if missing_auth - failed_to_fetch:
logger.info(
"Fetching missing auth for backfill: %r",
missing_auth - failed_to_fetch,
)

results = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.federation_client.get_pdu,
[dest],
event_id,
room_version=room_version,
outlier=True,
timeout=10000,
)
for event_id in missing_auth - failed_to_fetch
],
consumeErrors=True,
)
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a})
required_auth.update(
a_id
for event in results
if event
for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)

failed_to_fetch = missing_auth - set(auth_events)

seen_events = await self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)

# We now have a chunk of events plus associated state and auth chain to
# persist. We do the persistence in two steps:
# 1. Auth events and state get persisted as outliers, plus the
# backward extremities get persisted (as non-outliers).
# 2. The rest of the events in the chunk get persisted one by one, as
# each one depends on the previous event for its state.
#
# The important thing is that events in the chunk get persisted as
# non-outliers, including when those events are also in the state or
# auth chain. Caution must therefore be taken to ensure that they are
# not accidentally marked as outliers.

# Step 1a: persist auth events that *don't* appear in the chunk
ev_infos = []
for a in auth_events.values():
# We only want to persist auth events as outliers that we haven't
# seen and aren't about to persist as part of the backfilled chunk.
if a.event_id in seen_events or a.event_id in event_map:
continue

a.internal_metadata.outlier = True
ev_infos.append(
_NewEventInfo(
event=a,
auth_events={
(
auth_events[a_id].type,
auth_events[a_id].state_key,
): auth_events[a_id]
for a_id in a.auth_event_ids()
if a_id in auth_events
},
)
)

# Step 1b: persist the events in the chunk we fetched state for (i.e.
# the backwards extremities) as non-outliers.
# Step 1: persist the events in the chunk we fetched state for (i.e.
# the backwards extremities), with custom auth events and state
for e_id in events_to_state:
# For paranoia we ensure that these events are marked as
# non-outliers
Expand Down Expand Up @@ -1191,6 +1054,56 @@ async def try_backfill(domains):

return False

async def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str]
):
"""Fetch the given events from a server, and persist them as outliers.
Logs a warning if we can't find the given event.
"""

room_version = await self.store.get_room_version(room_id)

event_infos = []

async def get_event(event_id: str):
with nested_logging_context(event_id):
try:
event = await self.federation_client.get_pdu(
[destination], event_id, room_version, outlier=True,
)
if event is None:
logger.warning(
"Server %s didn't return event %s", destination, event_id,
)
return

# recursively fetch the auth events for this event
auth_events = await self._get_events_from_store_or_dest(
destination, room_id, event.auth_event_ids()
)
auth = {}
for auth_event_id in event.auth_event_ids():
ae = auth_events.get(auth_event_id)
if ae:
auth[(ae.type, ae.state_key)] = ae

event_infos.append(_NewEventInfo(event, None, auth))

except Exception as e:
logger.warning(
"Error fetching missing state/auth event %s: %s %s",
event_id,
type(e),
e,
)

await concurrently_execute(get_event, events, 5)

await self._handle_new_events(
destination, event_infos,
)

def _sanity_check_event(self, ev):
"""
Do some early sanity checks of a received event
Expand Down
4 changes: 2 additions & 2 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def concurrently_execute(func, args, limit):
Args:
func (func): Function to execute, should return a deferred or coroutine.
args (list): List of arguments to pass to func, each invocation of func
gets a signle argument.
args (Iterable): List of arguments to pass to func, each invocation of func
gets a single argument.
limit (int): Maximum number of conccurent executions.
Returns:
Expand Down

0 comments on commit 3fcb360

Please sign in to comment.