diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 23a42b9074ac..af652a76596c 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -39,7 +39,7 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.logging.context import make_deferred_yieldable from synapse.logging.utils import log_function -from synapse.util import batch_iter, unwrapFirstError +from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -327,74 +327,7 @@ def get_room_state_ids(self, destination: str, room_id: str, event_id: str): ): raise Exception("invalid response from /state_ids") - desired_events = set(state_event_ids + auth_event_ids) - event_map = yield self.get_events_from_store_or_dest( - destination, room_id, desired_events - ) - - failed_to_fetch = desired_events - event_map.keys() - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, - failed_to_fetch, - ) - - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] - - auth_chain.sort(key=lambda e: e.depth) - - return pdus, auth_chain - - @defer.inlineCallbacks - def get_events_from_store_or_dest(self, destination, room_id, event_ids): - """Fetch events from a remote destination, checking if we already have them. - - Args: - destination (str) - room_id (str) - event_ids (Iterable[str]) - - Returns: - Deferred[dict[str, EventBase]]: A deferred resolving to a map - from event_id to event - """ - fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) - - missing_events = set(event_ids) - fetched_events.keys() - - if not missing_events: - return fetched_events - - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - event_ids, - ) - - room_version = yield self.store.get_room_version(room_id) - - # XXX 20 requests at once? really? - for batch in batch_iter(missing_events, 20): - deferreds = [ - run_in_background( - self.get_pdu, - destinations=[destination], - event_id=e_id, - room_version=room_version, - ) - for e_id in batch - ] - - res = yield make_deferred_yieldable( - defer.DeferredList(deferreds, consumeErrors=True) - ) - for success, result in res: - if success and result: - fetched_events[result.event_id] = result - - return fetched_events + return state_event_ids, auth_event_ids @defer.inlineCallbacks @log_function diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4d152d598592..3992b4791b23 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -65,6 +65,7 @@ from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import UserID, get_domain_from_id +from synapse.util import batch_iter, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -237,6 +238,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: return None state = None + auth_chain = [] # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): @@ -338,14 +340,9 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: affected=pdu.event_id, ) - logger.info( - "Event %s is missing prev_events: calculating state for a " - "backwards extremity", - event_id, - ) - # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. + auth_chains = set() event_map = {event_id: pdu} try: # Get the state of the events we know about @@ -363,17 +360,42 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: # know about for p in prevs - seen: logger.info( - "Requesting state at missing prev_event %s", event_id, + "[%s %s] Requesting state at missing prev_event %s", + room_id, + event_id, + p, ) + room_version = await self.store.get_room_version(room_id) + with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - (remote_state, _,) = await self._get_state_for_room( - origin, room_id, p, include_event_in_state=True + ( + remote_state, + got_auth_chain, + ) = await self._get_state_for_room(origin, room_id, p) + + # we want the state *after* p; _get_state_for_room returns the + # state *before* p. + remote_event = await self.federation_client.get_pdu( + [origin], p, room_version, outlier=True ) + if remote_event is None: + raise Exception( + "Unable to get missing prev_event %s" % (p,) + ) + + if remote_event.is_state(): + remote_state.append(remote_event) + + # XXX hrm I'm not convinced that duplicate events will compare + # for equality, so I'm not sure this does what the author + # hoped. + auth_chains.update(got_auth_chain) + remote_state_map = { (x.type, x.state_key): x.event_id for x in remote_state } @@ -382,7 +404,6 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: for x in remote_state: event_map[x.event_id] = x - room_version = await self.store.get_room_version(room_id) state_map = await resolve_events_with_store( room_id, room_version, @@ -399,11 +420,12 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: evs = await self.store.get_events( list(state_map.values()), get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, + check_redacted=False, ) event_map.update(evs) state = [event_map[e] for e in six.itervalues(state_map)] + auth_chain = list(auth_chains) except Exception: logger.warning( "[%s %s] Error attempting to resolve state at missing " @@ -419,7 +441,9 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: affected=event_id, ) - await self._process_received_pdu(origin, pdu, state=state) + await self._process_received_pdu( + origin, pdu, state=state, auth_chain=auth_chain + ) async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): """ @@ -553,131 +577,99 @@ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): else: raise - async def _get_state_for_room( - self, - destination: str, - room_id: str, - event_id: str, - include_event_in_state: bool = False, - ) -> Tuple[List[EventBase], List[EventBase]]: + @defer.inlineCallbacks + @log_function + def _get_state_for_room(self, destination, room_id, event_id): """Requests all of the room state at a given event from a remote homeserver. Args: - destination: The remote homeserver to query for the state. - room_id: The id of the room we're interested in. - event_id: The id of the event we want the state at. - include_event_in_state: if true, the event itself will be included in the - returned state event list. + destination (str): The remote homeserver to query for the state. + room_id (str): The id of the room we're interested in. + event_id (str): The id of the event we want the state at. Returns: - A list of events in the state, possibly including the event itself, and - a list of events in the auth chain for the given event. + Deferred[Tuple[List[EventBase], List[EventBase]]]: + A list of events in the state, and a list of events in the auth chain + for the given event. """ ( state_event_ids, auth_event_ids, - ) = await self.federation_client.get_room_state_ids( + ) = yield self.federation_client.get_room_state_ids( destination, room_id, event_id=event_id ) desired_events = set(state_event_ids + auth_event_ids) - - if include_event_in_state: - desired_events.add(event_id) - - event_map = await self._get_events_from_store_or_dest( + event_map = yield self._get_events_from_store_or_dest( destination, room_id, desired_events ) failed_to_fetch = desired_events - event_map.keys() if failed_to_fetch: logger.warning( - "Failed to fetch missing state/auth events for %s %s", - event_id, + "Failed to fetch missing state/auth events for %s: %s", + room_id, failed_to_fetch, ) - remote_state = [ - event_map[e_id] for e_id in state_event_ids if e_id in event_map - ] - - if include_event_in_state: - remote_event = event_map.get(event_id) - if not remote_event: - raise Exception("Unable to get missing prev_event %s" % (event_id,)) - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) - + pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + auth_chain.sort(key=lambda e: e.depth) - return remote_state, auth_chain + return pdus, auth_chain - async def _get_events_from_store_or_dest( - self, destination: str, room_id: str, event_ids: Iterable[str] - ) -> Dict[str, EventBase]: + @defer.inlineCallbacks + def _get_events_from_store_or_dest(self, destination, room_id, event_ids): """Fetch events from a remote destination, checking if we already have them. - Persists any events we don't already have as outliers. - - If we fail to fetch any of the events, a warning will be logged, and the event - will be omitted from the result. Likewise, any events which turn out not to - be in the given room. + Args: + destination (str) + room_id (str) + event_ids (Iterable[str]) Returns: - map from event_id to event + Deferred[dict[str, EventBase]]: A deferred resolving to a map + from event_id to event """ - fetched_events = await self.store.get_events(event_ids, allow_rejected=True) + fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) missing_events = set(event_ids) - fetched_events.keys() - if missing_events: - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - room_id, - ) + if not missing_events: + return fetched_events - await self._get_events_and_persist( - destination=destination, room_id=room_id, events=missing_events - ) - - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - (await self.store.get_events(missing_events, allow_rejected=True)) - ) + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, + event_ids, + ) - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B + room_version = yield self.store.get_room_version(room_id) - bad_events = list( - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ) + # XXX 20 requests at once? really? + for batch in batch_iter(missing_events, 20): + deferreds = [ + run_in_background( + self.federation_client.get_pdu, + destinations=[destination], + event_id=e_id, + room_version=room_version, + ) + for e_id in batch + ] - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned auth/state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, + res = yield make_deferred_yieldable( + defer.DeferredList(deferreds, consumeErrors=True) ) - del fetched_events[bad_event_id] + for success, result in res: + if success and result: + fetched_events[result.event_id] = result return fetched_events - async def _process_received_pdu( - self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - ): + @defer.inlineCallbacks + def _process_received_pdu(self, origin, event, state, auth_chain): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. @@ -696,15 +688,15 @@ async def _process_received_pdu( logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) try: - context = await self._handle_new_event(origin, event, state=state) + context = yield self._handle_new_event(origin, event, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - room = await self.store.get_room(room_id) + room = yield self.store.get_room(room_id) if not room: try: - await self.store.store_room( + yield self.store.store_room( room_id=room_id, room_creator_user_id="", is_public=False ) except StoreError: @@ -717,11 +709,11 @@ async def _process_received_pdu( # changing their profile info. newly_joined = True - prev_state_ids = await context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: - prev_state = await self.store.get_event( + prev_state = yield self.store.get_event( prev_state_id, allow_none=True ) if prev_state and prev_state.membership == Membership.JOIN: @@ -729,7 +721,7 @@ async def _process_received_pdu( if newly_joined: user = UserID.from_string(event.state_key) - await self.user_joined_room(user, room_id) + yield self.user_joined_room(user, room_id) @log_function async def backfill(self, dest, room_id, limit, extremities):