diff --git a/changelog.d/6526.bugfix b/changelog.d/6526.bugfix new file mode 100644 index 000000000000..53214b0748d5 --- /dev/null +++ b/changelog.d/6526.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs. \ No newline at end of file diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fd3f5ced55f4..f4ac0bfbc89f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -64,8 +64,7 @@ from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.types import UserID, get_domain_from_id -from synapse.util import batch_iter, unwrapFirstError -from synapse.util.async_helpers import Linearizer +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination from synapse.visibility import filter_events_for_server @@ -240,7 +239,6 @@ def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): return None state = None - auth_chain = [] # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): @@ -346,7 +344,6 @@ def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. - auth_chains = set() event_map = {event_id: pdu} try: # Get the state of the events we know about @@ -370,24 +367,14 @@ def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): p, ) - room_version = yield self.store.get_room_version(room_id) - with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - ( - remote_state, - got_auth_chain, - ) = yield self._get_state_for_room( + (remote_state, _,) = yield self._get_state_for_room( origin, room_id, p, include_event_in_state=True ) - # XXX hrm I'm not convinced that duplicate events will compare - # for equality, so I'm not sure this does what the author - # hoped. - auth_chains.update(got_auth_chain) - remote_state_map = { (x.type, x.state_key): x.event_id for x in remote_state } @@ -396,6 +383,7 @@ def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): for x in remote_state: event_map[x.event_id] = x + room_version = yield self.store.get_room_version(room_id) state_map = yield resolve_events_with_store( room_id, room_version, @@ -417,7 +405,6 @@ def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): event_map.update(evs) state = [event_map[e] for e in six.itervalues(state_map)] - auth_chain = list(auth_chains) except Exception: logger.warning( "[%s %s] Error attempting to resolve state at missing " @@ -433,9 +420,7 @@ def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): affected=event_id, ) - yield self._process_received_pdu( - origin, pdu, state=state, auth_chain=auth_chain - ) + yield self._process_received_pdu(origin, pdu, state=state) @defer.inlineCallbacks def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): @@ -638,6 +623,8 @@ def _get_events_from_store_or_dest(self, destination, room_id, event_ids): room_id (str) event_ids (Iterable[str]) + Persists any events we don't already have as outliers. + If we fail to fetch any of the events, a warning will be logged, and the event will be omitted from the result. Likewise, any events which turn out not to be in the given room. @@ -657,27 +644,15 @@ def _get_events_from_store_or_dest(self, destination, room_id, event_ids): room_id, ) - room_version = yield self.store.get_room_version(room_id) - - # XXX 20 requests at once? really? - for batch in batch_iter(missing_events, 20): - deferreds = [ - run_in_background( - self.federation_client.get_pdu, - destinations=[destination], - event_id=e_id, - room_version=room_version, - ) - for e_id in batch - ] - - res = yield make_deferred_yieldable( - defer.DeferredList(deferreds, consumeErrors=True) - ) + yield self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) - for success, result in res: - if success and result: - fetched_events[result.event_id] = result + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + (yield self.store.get_events(missing_events, allow_rejected=True)) + ) # check for events which were in the wrong room. # @@ -707,50 +682,24 @@ def _get_events_from_store_or_dest(self, destination, room_id, event_ids): return fetched_events @defer.inlineCallbacks - def _process_received_pdu(self, origin, event, state, auth_chain): + def _process_received_pdu(self, origin, event, state): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. + + Args: + origin: server sending the event + + event: event to be persisted + + state: Normally None, but if we are handling a gap in the graph + (ie, we are missing one or more prev_events), the resolved state at the + event """ room_id = event.room_id event_id = event.event_id logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) - event_ids = set() - if state: - event_ids |= {e.event_id for e in state} - if auth_chain: - event_ids |= {e.event_id for e in auth_chain} - - seen_ids = yield self.store.have_seen_events(event_ids) - - if state and auth_chain is not None: - # If we have any state or auth_chain given to us by the replication - # layer, then we should handle them (if we haven't before.) - - event_infos = [] - - for e in itertools.chain(auth_chain, state): - if e.event_id in seen_ids: - continue - e.internal_metadata.outlier = True - auth_ids = e.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in auth_chain - if e.event_id in auth_ids or e.type == EventTypes.Create - } - event_infos.append(_NewEventInfo(event=e, auth_events=auth)) - seen_ids.add(e.event_id) - - logger.info( - "[%s %s] persisting newly-received auth/state events %s", - room_id, - event_id, - [e.event.event_id for e in event_infos], - ) - yield self._handle_new_events(origin, event_infos) - try: context = yield self._handle_new_event(origin, event, state=state) except AuthError as e: @@ -806,8 +755,6 @@ def backfill(self, dest, room_id, limit, extremities): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - room_version = yield self.store.get_room_version(room_id) - events = yield self.federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -836,6 +783,9 @@ def backfill(self, dest, room_id, limit, extremities): event_ids = set(e.event_id for e in events) + # build a list of events whose prev_events weren't in the batch. + # (XXX: this will include events whose prev_events we already have; that doesn't + # sound right?) edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) @@ -864,95 +814,11 @@ def backfill(self, dest, room_id, limit, extremities): auth_events.update( {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} ) - missing_auth = required_auth - set(auth_events) - failed_to_fetch = set() - - # Try and fetch any missing auth events from both DB and remote servers. - # We repeatedly do this until we stop finding new auth events. - while missing_auth - failed_to_fetch: - logger.info("Missing auth for backfill: %r", missing_auth) - ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) - auth_events.update(ret_events) - - required_auth.update( - a_id for event in ret_events.values() for a_id in event.auth_event_ids() - ) - missing_auth = required_auth - set(auth_events) - if missing_auth - failed_to_fetch: - logger.info( - "Fetching missing auth for backfill: %r", - missing_auth - failed_to_fetch, - ) - - results = yield make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.federation_client.get_pdu, - [dest], - event_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - failed_to_fetch - ], - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results if a}) - required_auth.update( - a_id - for event in results - if event - for a_id in event.auth_event_ids() - ) - missing_auth = required_auth - set(auth_events) - - failed_to_fetch = missing_auth - set(auth_events) - - seen_events = yield self.store.have_seen_events( - set(auth_events.keys()) | set(state_events.keys()) - ) - - # We now have a chunk of events plus associated state and auth chain to - # persist. We do the persistence in two steps: - # 1. Auth events and state get persisted as outliers, plus the - # backward extremities get persisted (as non-outliers). - # 2. The rest of the events in the chunk get persisted one by one, as - # each one depends on the previous event for its state. - # - # The important thing is that events in the chunk get persisted as - # non-outliers, including when those events are also in the state or - # auth chain. Caution must therefore be taken to ensure that they are - # not accidentally marked as outliers. - - # Step 1a: persist auth events that *don't* appear in the chunk ev_infos = [] - for a in auth_events.values(): - # We only want to persist auth events as outliers that we haven't - # seen and aren't about to persist as part of the backfilled chunk. - if a.event_id in seen_events or a.event_id in event_map: - continue - a.internal_metadata.outlier = True - ev_infos.append( - _NewEventInfo( - event=a, - auth_events={ - ( - auth_events[a_id].type, - auth_events[a_id].state_key, - ): auth_events[a_id] - for a_id in a.auth_event_ids() - if a_id in auth_events - }, - ) - ) - - # Step 1b: persist the events in the chunk we fetched state for (i.e. - # the backwards extremities) as non-outliers. + # Step 1: persist the events in the chunk we fetched state for (i.e. + # the backwards extremities), with custom auth events and state for e_id in events_to_state: # For paranoia we ensure that these events are marked as # non-outliers @@ -1194,6 +1060,57 @@ def try_backfill(domains): return False + @defer.inlineCallbacks + def _get_events_and_persist( + self, destination: str, room_id: str, events: Iterable[str] + ): + """Fetch the given events from a server, and persist them as outliers. + + Logs a warning if we can't find the given event. + """ + + room_version = yield self.store.get_room_version(room_id) + + event_infos = [] + + async def get_event(event_id: str): + with nested_logging_context(event_id): + try: + event = await self.federation_client.get_pdu( + [destination], event_id, room_version, outlier=True, + ) + if event is None: + logger.warning( + "Server %s didn't return event %s", destination, event_id, + ) + return + + # recursively fetch the auth events for this event + auth_events = await self._get_events_from_store_or_dest( + destination, room_id, event.auth_event_ids() + ) + auth = {} + for auth_event_id in event.auth_event_ids(): + ae = auth_events.get(auth_event_id) + if ae: + auth[(ae.type, ae.state_key)] = ae + + event_infos.append(_NewEventInfo(event, None, auth)) + + except Exception as e: + logger.warning( + "Error fetching missing state/auth event %s: %s %s", + event_id, + type(e), + e, + ) + + yield concurrently_execute(get_event, events, 5) + + yield self._handle_new_events( + destination, event_infos, + ) + def _sanity_check_event(self, ev): """ Do some early sanity checks of a received event diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 5c4de2e69f23..04b6abdc243c 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -140,8 +140,8 @@ def concurrently_execute(func, args, limit): Args: func (func): Function to execute, should return a deferred or coroutine. - args (list): List of arguments to pass to func, each invocation of func - gets a signle argument. + args (Iterable): List of arguments to pass to func, each invocation of func + gets a single argument. limit (int): Maximum number of conccurent executions. Returns: