From 0b7f62cb9fb3585fb76a1441b13916bf3e500503 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 30 May 2022 09:20:07 +0100
Subject: [PATCH 01/12] Rename storage classes

---
 synapse/events/snapshot.py        | 10 +++++-----
 synapse/push/push_tools.py        |  4 ++--
 synapse/storage/__init__.py       | 12 ++++++------
 synapse/storage/persist_events.py |  2 +-
 synapse/storage/purge_events.py   |  2 +-
 synapse/storage/state.py          |  2 +-
 synapse/visibility.py             | 10 +++++-----
 7 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 7a91544119f7..18c03a46160b 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -22,7 +22,7 @@
 from synapse.types import JsonDict, StateMap

 if TYPE_CHECKING:
-    from synapse.storage import Storage
+    from synapse.storage import StorageControllers
     from synapse.storage.databases.main import DataStore
     from synapse.storage.state import StateFilter

@@ -84,7 +84,7 @@ class EventContext:
         incomplete state.
     """

-    _storage: "Storage"
+    _storage: "StorageControllers"
     rejected: Union[Literal[False], str] = False
     _state_group: Optional[int] = None
     state_group_before_event: Optional[int] = None
@@ -97,7 +97,7 @@ class EventContext:

     @staticmethod
     def with_state(
-        storage: "Storage",
+        storage: "StorageControllers",
         state_group: Optional[int],
         state_group_before_event: Optional[int],
        state_delta_due_to_event: Optional[StateMap[str]],
@@ -117,7 +117,7 @@ def with_state(

     @staticmethod
     def for_outlier(
-        storage: "Storage",
+        storage: "StorageControllers",
     ) -> "EventContext":
         """Return an EventContext instance suitable for persisting an outlier event"""
         return EventContext(storage=storage)
@@ -147,7 +147,7 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         }

     @staticmethod
-    def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
+    def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext":
         """Converts a dict that was produced by `serialize` back into a
         EventContext.
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index a1bf5b20dd42..83af3b7fdfe8 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -16,7 +16,7 @@
 from synapse.api.constants import ReceiptTypes
 from synapse.events import EventBase
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
-from synapse.storage import Storage
+from synapse.storage import StorageControllers
 from synapse.storage.databases.main import DataStore

@@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -

 async def get_context_for_event(
-    storage: Storage, ev: EventBase, user_id: str
+    storage: StorageControllers, ev: EventBase, user_id: str
 ) -> Dict[str, str]:
     ctx = {}

diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 105e4e1fec1b..d82365222a2f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -30,9 +30,9 @@
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main import DataStore
-from synapse.storage.persist_events import EventsPersistenceStorage
-from synapse.storage.purge_events import PurgeEventsStorage
-from synapse.storage.state import StateGroupStorage
+from synapse.storage.persist_events import EventsPersistenceStorageController
+from synapse.storage.purge_events import PurgeEventsStorageController
+from synapse.storage.state import StateGroupStorageController

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -50,9 +50,9 @@ def __init__(self, hs: "HomeServer", stores: Databases):
         # interfaces.
         self.main = stores.main

-        self.purge_events = PurgeEventsStorage(hs, stores)
-        self.state = StateGroupStorage(hs, stores)
+        self.purge_events = PurgeEventsStorageController(hs, stores)
+        self.state = StateGroupStorageController(hs, stores)

         self.persistence = None
         if stores.persist_events:
-            self.persistence = EventsPersistenceStorage(hs, stores)
+            self.persistence = EventsPersistenceStorageController(hs, stores)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index a21dea91c852..ef8c135b1253 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -272,7 +272,7 @@ def _get_drainining_queue(
         pass


-class EventsPersistenceStorage:
+class EventsPersistenceStorageController:
     """High level interface for handling persisting newly received events.

     Takes care of batching up events by room, and calculating the necessary
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index 30669beb7c6a..9ca50d6a0982 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -24,7 +24,7 @@
 logger = logging.getLogger(__name__)


-class PurgeEventsStorage:
+class PurgeEventsStorageController:
     """High level interface for purging rooms and event history."""

     def __init__(self, hs: "HomeServer", stores: Databases):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index ab630953ac93..96896b4fb43b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -580,7 +580,7 @@ def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
 _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)


-class StateGroupStorage:
+class StateGroupStorageController:
     """High level interface to fetching state for event."""

     def __init__(self, hs: "HomeServer", stores: "Databases"):
diff --git a/synapse/visibility.py b/synapse/visibility.py
index da4af02796c3..13a5ff63490a 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -20,7 +20,7 @@
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.events import EventBase
 from synapse.events.utils import prune_event
-from synapse.storage import Storage
+from synapse.storage import StorageControllers
 from synapse.storage.state import StateFilter
 from synapse.types import RetentionPolicy, StateMap, get_domain_from_id

@@ -47,7 +47,7 @@

 async def filter_events_for_client(
-    storage: Storage,
+    storage: StorageControllers,
     user_id: str,
     events: List[EventBase],
     is_peeking: bool = False,
@@ -268,7 +268,7 @@ def allowed(event: EventBase) -> Optional[EventBase]:

 async def filter_events_for_server(
-    storage: Storage,
+    storage: StorageControllers,
     server_name: str,
     events: List[EventBase],
     redact: bool = True,
@@ -360,7 +360,7 @@ def check_event_is_visible(

 async def _event_to_history_vis(
-    storage: Storage, events: Collection[EventBase]
+    storage: StorageControllers, events: Collection[EventBase]
 ) -> Dict[str, str]:
     """Get the history visibility at each of the given events

@@ -407,7 +407,7 @@ async def _event_to_history_vis(

 async def _event_to_memberships(
-    storage: Storage, events: Collection[EventBase], server_name: str
+    storage: StorageControllers, events: Collection[EventBase], server_name: str
 ) -> Dict[str, StateMap[EventBase]]:
     """Get the remote membership list at each of the given events

From 78211cbaa4c7d86d588376a26499b3960a9b241f Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 30 May 2022 09:23:07 +0100
Subject: [PATCH 02/12] Rename vars

---
 synapse/handlers/admin.py               |  6 ++--
 synapse/handlers/federation.py          |  4 +--
 synapse/handlers/federation_event.py    | 14 +++++---
 synapse/handlers/initial_sync.py        |  9 +++--
 synapse/handlers/message.py             | 18 ++++++----
 synapse/handlers/pagination.py          |  4 +--
 synapse/handlers/room.py                |  4 +--
 synapse/handlers/room_batch.py          |  4 +--
 synapse/handlers/search.py              |  4 +--
 synapse/handlers/sync.py                | 20 ++++++-----
 synapse/push/mailer.py                  | 10 +++---
 synapse/state/__init__.py               | 33 ++++++++++++-------
 tests/handlers/test_federation.py       |  6 ++--
 tests/handlers/test_federation_event.py |  4 +--
 tests/handlers/test_message.py          | 12 ++++---
 .../slave/storage/test_receipts.py      | 10 ++++--
 16 files changed, 98 insertions(+), 64 deletions(-)

diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 50e34743b73d..dd134747780b 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -31,7 +31,7 @@ class AdminHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self.state_storage_controller = self.storage.state

     async def get_whois(self, user: UserID) -> JsonDict:
         connections = []
@@ -233,7 +233,9 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->
             for event_id in extremities:
                 if not event_to_unseen_prevs[event_id]:
                     continue
-                state = await self.state_storage.get_state_for_event(event_id)
+                state = await self.state_storage_controller.get_state_for_event(
+                    event_id
+                )
                 writer.write_state(room_id, event_id, state)

         return writer.finished()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c8233270d72c..6efa5fd678b3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -126,7 +126,7 @@ def __init__(self, hs: "HomeServer"):

         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self.state_storage_controller = self.storage.state
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
@@ -1027,7 +1027,7 @@ async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         if event.internal_metadata.outlier:
             raise NotFoundError("State not known at event %s" % (event_id,))

-        state_groups = await self.state_storage.get_state_groups_ids(
+        state_groups = await self.state_storage_controller.get_state_groups_ids(
             room_id, [event_id]
         )
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index a1361af2727d..c65e7386340a 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -99,7 +99,7 @@ class FederationEventHandler:
     def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastores().main
         self._storage = hs.get_storage()
-        self._state_storage = self._storage.state
+        self._state_storage_controller = self._storage.state

         self._state_handler = hs.get_state_handler()
         self._event_creation_handler = hs.get_event_creation_handler()
@@ -535,7 +535,9 @@ async def update_state_for_partial_state_event(
             )
             return

         await self._store.update_state_for_partial_state_event(event, context)
-        self._state_storage.notify_event_un_partial_stated(event.event_id)
+        self._state_storage_controller.notify_event_un_partial_stated(
+            event.event_id
+        )

     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Collection[str]
@@ -835,7 +837,9 @@ async def _resolve_state_at_missing_prevs(

         try:
             # Get the state of the events we know about
-            ours = await self._state_storage.get_state_groups_ids(room_id, seen)
+            ours = await self._state_storage_controller.get_state_groups_ids(
+                room_id, seen
+            )

             # state_maps is a list of mappings from (type, state_key) to event_id
             state_maps: List[StateMap[str]] = list(ours.values())
@@ -1613,7 +1617,7 @@ async def _check_for_soft_fail(
             # given state at the event. This should correctly handle cases
             # like bans, especially with state res v2.
-            state_sets_d = await self._state_storage.get_state_groups_ids(
+            state_sets_d = await self._state_storage_controller.get_state_groups_ids(
                 event.room_id, extrem_ids
             )
             state_sets: List[StateMap[str]] = list(state_sets_d.values())
@@ -1885,7 +1889,7 @@ async def _update_context_for_auth_events(
         # create a new state group as a delta from the existing one.
         prev_group = context.state_group
-        state_group = await self._state_storage.store_state_group(
+        state_group = await self._state_storage_controller.store_state_group(
             event.event_id,
             event.room_id,
             prev_group=prev_group,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index c06932a41acf..6b3d20e5593e 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -68,7 +68,7 @@ def __init__(self, hs: "HomeServer"):
         ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self.state_storage_controller = self.storage.state

     async def snapshot_all_rooms(
         self,
@@ -198,7 +198,8 @@ async def handle_room(event: RoomsForUser) -> None:
                     event.stream_ordering,
                 )
                 deferred_room_state = run_in_background(
-                    self.state_storage.get_state_for_events, [event.event_id]
+                    self.state_storage_controller.get_state_for_events,
+                    [event.event_id],
                 ).addCallback(
                     lambda states: cast(StateMap[EventBase], states[event.event_id])
                 )
@@ -355,7 +356,9 @@ async def _room_initial_sync_parted(
         member_event_id: str,
         is_peeking: bool,
     ) -> JsonDict:
-        room_state = await self.state_storage.get_state_for_event(member_event_id)
+        room_state = await self.state_storage_controller.get_state_for_event(
+            member_event_id
+        )

         limit = pagin_config.limit if pagin_config else None
         if limit is None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7ca126dbd171..6b92a4a161a4 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -85,7 +85,7 @@ def __init__(self, hs: "HomeServer"):
         self.state = hs.get_state_handler()
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self.state_storage_controller = self.storage.state
         self._event_serializer = hs.get_event_client_serializer()
         self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages

@@ -132,7 +132,7 @@ async def get_room_data(
             assert (
                 membership_event_id is not None
             ), "check_user_in_room_or_world_readable returned invalid data"
-            room_state = await self.state_storage.get_state_for_events(
+            room_state = await self.state_storage_controller.get_state_for_events(
                 [membership_event_id], StateFilter.from_types([key])
             )
             data = room_state[membership_event_id].get(key)
@@ -193,7 +193,7 @@ async def get_state_events(

             # check whether the user is in the room at that time to determine
             # whether they should be treated as peeking.
-            state_map = await self.state_storage.get_state_for_event(
+            state_map = await self.state_storage_controller.get_state_for_event(
                 last_event.event_id,
                 StateFilter.from_types([(EventTypes.Member, user_id)]),
             )
@@ -214,8 +214,10 @@ async def get_state_events(
             )

             if visible_events:
-                room_state_events = await self.state_storage.get_state_for_events(
-                    [last_event.event_id], state_filter=state_filter
+                room_state_events = (
+                    await self.state_storage_controller.get_state_for_events(
+                        [last_event.event_id], state_filter=state_filter
+                    )
                 )
                 room_state: Mapping[Any, EventBase] = room_state_events[
                     last_event.event_id
@@ -244,8 +246,10 @@ async def get_state_events(
             assert (
                 membership_event_id is not None
             ), "check_user_in_room_or_world_readable returned invalid data"
-            room_state_events = await self.state_storage.get_state_for_events(
-                [membership_event_id], state_filter=state_filter
+            room_state_events = (
+                await self.state_storage_controller.get_state_for_events(
+                    [membership_event_id], state_filter=state_filter
+                )
             )
             room_state = room_state_events[membership_event_id]
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6f4820c240cc..4f1541ace6d4 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -130,7 +130,7 @@ def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self.state_storage_controller = self.storage_controllers.state
         self.clock = hs.get_clock()
         self._server_name = hs.hostname
         self._room_shutdown_handler = hs.get_room_shutdown_handler()
@@ -539,7 +539,7 @@ async def get_messages(
                 (EventTypes.Member, event.sender) for event in events
             )

-            state_ids = await self.state_storage.get_state_ids_for_event(
+            state_ids = await self.state_storage_controller.get_state_ids_for_event(
                 events[0].event_id, state_filter=state_filter
             )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e2775b34f10b..b9502cd7143c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1193,7 +1193,7 @@ def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self.state_storage_controller = self.storage_controllers.state
         self._relations_handler = hs.get_relations_handler()

     async def get_event_context(
@@ -1293,7 +1293,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
         # first? Shouldn't we be consistent with /sync?
         # https://github.com/matrix-org/matrix-doc/issues/687

-        state = await self.state_storage.get_state_for_events(
+        state = await self.state_storage_controller.get_state_for_events(
             [last_event_id], state_filter=state_filter
         )
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 7ce32f2e9ce6..edd35a66bb24 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -17,7 +17,7 @@ class RoomBatchHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastores().main
-        self.state_storage = hs.get_storage().state
+        self.state_storage_controller = hs.get_storage().state
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
@@ -141,7 +141,7 @@ async def get_most_recent_full_state_ids_from_event_id_list(
         ) = await self.store.get_max_depth_of(event_ids)
         # mapping from (type, state_key) -> state_event_id
         assert most_recent_event_id is not None
-        prev_state_map = await self.state_storage.get_state_ids_for_event(
+        prev_state_map = await self.state_storage_controller.get_state_ids_for_event(
             most_recent_event_id
         )
         # List of state event ID's
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index e02c915248c1..314a02fe6545 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -56,7 +56,7 @@ def __init__(self, hs: "HomeServer"):
         self._event_serializer = hs.get_event_client_serializer()
         self._relations_handler = hs.get_relations_handler()
         self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self.state_storage_controller = self.storage.state
         self.auth = hs.get_auth()

     async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
@@ -677,7 +677,7 @@ async def _calculate_event_contexts(
             [(EventTypes.Member, sender) for sender in senders]
         )

-        state = await self.state_storage.get_state_for_event(
+        state = await self.state_storage_controller.get_state_for_event(
             last_event_id, state_filter
         )
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c5c538e0c35e..7bed951db80d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -239,7 +239,7 @@ def __init__(self, hs: "HomeServer"):
         self.state = hs.get_state_handler()
         self.auth = hs.get_auth()
         self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self.state_storage_controller = self.storage.state

         # TODO: flush cache entries on subsequent sync request.
         # Once we get the next /sync request (ie, one with the same access token
@@ -630,7 +630,7 @@ async def get_state_after_event(
             event: event of interest
             state_filter: The state filter used to fetch state from the database.
         """
-        state_ids = await self.state_storage.get_state_ids_for_event(
+        state_ids = await self.state_storage_controller.get_state_ids_for_event(
             event.event_id, state_filter=state_filter or StateFilter.all()
         )
         if event.is_state():
@@ -710,7 +710,7 @@ async def compute_summary(
             return None

         last_event = last_events[-1]
-        state_ids = await self.state_storage.get_state_ids_for_event(
+        state_ids = await self.state_storage_controller.get_state_ids_for_event(
             last_event.event_id,
             state_filter=StateFilter.from_types(
                 [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -889,13 +889,15 @@ async def compute_state_delta(
         if full_state:
             if batch:
                 current_state_ids = (
-                    await self.state_storage.get_state_ids_for_event(
+                    await self.state_storage_controller.get_state_ids_for_event(
                         batch.events[-1].event_id, state_filter=state_filter
                     )
                 )

-                state_ids = await self.state_storage.get_state_ids_for_event(
-                    batch.events[0].event_id, state_filter=state_filter
+                state_ids = (
+                    await self.state_storage_controller.get_state_ids_for_event(
+                        batch.events[0].event_id, state_filter=state_filter
+                    )
                 )

             else:
@@ -915,7 +917,7 @@ async def compute_state_delta(
         elif batch.limited:
             if batch:
                 state_at_timeline_start = (
-                    await self.state_storage.get_state_ids_for_event(
+                    await self.state_storage_controller.get_state_ids_for_event(
                         batch.events[0].event_id, state_filter=state_filter
                     )
                 )
@@ -950,7 +952,7 @@ async def compute_state_delta(

             if batch:
                 current_state_ids = (
-                    await self.state_storage.get_state_ids_for_event(
+                    await self.state_storage_controller.get_state_ids_for_event(
                         batch.events[-1].event_id, state_filter=state_filter
                     )
                 )
@@ -982,7 +984,7 @@ async def compute_state_delta(
                 # So we fish out all the member events corresponding to the
                 # timeline here, and then dedupe any redundant ones below.

-                state_ids = await self.state_storage.get_state_ids_for_event(
+                state_ids = await self.state_storage_controller.get_state_ids_for_event(
                     batch.events[0].event_id,
                     # we only want members!
                     state_filter=StateFilter.from_types(
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 84124af96527..cc5ac5bcdf7b 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -114,7 +114,7 @@ def __init__(
         self.send_email_handler = hs.get_send_email_handler()
         self.store = self.hs.get_datastores().main

-        self.state_storage = self.hs.get_storage().state
+        self.state_storage_controller = self.hs.get_storage().state
         self.macaroon_gen = self.hs.get_macaroon_generator()
         self.state_handler = self.hs.get_state_handler()
         self.storage = hs.get_storage()
@@ -494,7 +494,7 @@ async def _get_message_vars(
             )
         else:
             # Attempt to check the historical state for the room.
-            historical_state = await self.state_storage.get_state_for_event(
+            historical_state = await self.state_storage_controller.get_state_for_event(
                 event.event_id, StateFilter.from_types((type_state_key,))
             )
             sender_state_event = historical_state.get(type_state_key)
@@ -767,8 +767,10 @@ async def _make_summary_text_from_member_events(
                 member_event_ids.append(sender_state_event_id)
             else:
                 # Attempt to check the historical state for the room.
-                historical_state = await self.state_storage.get_state_for_event(
-                    event_id, StateFilter.from_types((type_state_key,))
+                historical_state = (
+                    await self.state_storage_controller.get_state_for_event(
+                        event_id, StateFilter.from_types((type_state_key,))
+                    )
                 )
                 sender_state_event = historical_state.get(type_state_key)
                 if sender_state_event:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 9c9d946f38c0..417d110e1fb3 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -127,7 +127,7 @@ class StateHandler:
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
-        self.state_storage = hs.get_storage().state
+        self.state_storage_controller = hs.get_storage().state
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
         self._storage = hs.get_storage()
@@ -337,12 +337,14 @@ async def compute_event_context(
             #

             if not state_group_before_event:
-                state_group_before_event = await self.state_storage.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=state_group_before_event_prev_group,
-                    delta_ids=deltas_to_state_group_before_event,
-                    current_state_ids=state_ids_before_event,
+                state_group_before_event = (
+                    await self.state_storage_controller.store_state_group(
+                        event.event_id,
+                        event.room_id,
+                        prev_group=state_group_before_event_prev_group,
+                        delta_ids=deltas_to_state_group_before_event,
+                        current_state_ids=state_ids_before_event,
+                    )
                 )

                 # Assign the new state group to the cached state entry.
@@ -382,7 +384,7 @@ async def compute_event_context(
             state_ids_after_event[key] = event.event_id
             delta_ids = {key: event.event_id}

-        state_group_after_event = await self.state_storage.store_state_group(
+        state_group_after_event = await self.state_storage_controller.store_state_group(
             event.event_id,
             event.room_id,
             prev_group=state_group_before_event,
@@ -416,7 +418,9 @@ async def resolve_state_groups_for_events(
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)

-        state_groups = await self.state_storage.get_state_group_for_events(event_ids)
+        state_groups = await self.state_storage_controller.get_state_group_for_events(
+            event_ids
+        )

         state_group_ids = state_groups.values()

@@ -424,8 +428,13 @@ async def resolve_state_groups_for_events(
         state_group_ids_set = set(state_group_ids)
         if len(state_group_ids_set) == 1:
             (state_group_id,) = state_group_ids_set
-            state = await self.state_storage.get_state_for_groups(state_group_ids_set)
-            prev_group, delta_ids = await self.state_storage.get_state_group_delta(
+            state = await self.state_storage_controller.get_state_for_groups(
+                state_group_ids_set
+            )
+            (
+                prev_group,
+                delta_ids,
+            ) = await self.state_storage_controller.get_state_group_delta(
                 state_group_id
             )
             return _StateCacheEntry(
@@ -439,7 +448,7 @@ async def resolve_state_groups_for_events(

         room_version = await self.store.get_room_version_id(room_id)

-        state_to_resolve = await self.state_storage.get_state_for_groups(
+        state_to_resolve = await self.state_storage_controller.get_state_for_groups(
             state_group_ids_set
         )
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index ec0090062166..1e1cf42240b5 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -50,7 +50,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver(federation_http_client=None)
         self.handler = hs.get_federation_handler()
         self.store = hs.get_datastores().main
-        self.state_storage = hs.get_storage().state
+        self.state_storage_controller = hs.get_storage().state
         self._event_auth_handler = hs.get_event_auth_handler()

         return hs
@@ -338,7 +338,9 @@ def test_backfill_floating_outlier_membership_auth(self) -> None:
         # mapping from (type, state_key) -> state_event_id
         assert most_recent_prev_event_id is not None
         prev_state_map = self.get_success(
-            self.state_storage.get_state_ids_for_event(most_recent_prev_event_id)
+            self.state_storage_controller.get_state_ids_for_event(
+                most_recent_prev_event_id
+            )
         )
         # List of state event ID's
         prev_state_ids = list(prev_state_map.values())
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index e64b28f28b86..75b2ef8a3092 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -70,7 +70,7 @@ def _test_process_pulled_event_with_missing_state(
     ) -> None:
         OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
         main_store = self.hs.get_datastores().main
-        state_storage = self.hs.get_storage().state
+        state_storage_controller = self.hs.get_storage().state

         # create the room
         user_id = self.register_user("kermit", "test")
@@ -216,7 +216,7 @@ async def get_event(destination: str, event_id: str, timeout=None):

         # check that the state at that event is as expected
         state = self.get_success(
-            state_storage.get_state_ids_for_event(pulled_event.event_id)
+            state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
         )
         expected_state = {
             (e.type, e.state_key): e.event_id for e in state_at_prev_event
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index f4f7ab48458e..e0bc1e71d11e 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -37,7 +37,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):

     def prepare(self, reactor, clock, hs):
         self.handler = self.hs.get_event_creation_handler()
-        self.persist_event_storage = self.hs.get_storage().persistence
+        self.persist_event_storage_controller = self.hs.get_storage().persistence

         self.user_id = self.register_user("tester", "foobar")
         self.access_token = self.login("tester", "foobar")
@@ -65,7 +65,9 @@ def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]:
             )
         )
         self.get_success(
-            self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+            self.persist_event_storage_controller.persist_event(
+                memberEvent, memberEventContext
+            )
         )

         return memberEvent, memberEventContext
@@ -129,7 +131,7 @@ def test_duplicated_txn_id(self):
         self.assertNotEqual(event1.event_id, event3.event_id)

         ret_event3, event_pos3, _ = self.get_success(
-            self.persist_event_storage.persist_event(event3, context)
+            self.persist_event_storage_controller.persist_event(event3, context)
         )

         # Assert that the returned values match those from the initial event
@@ -143,7 +145,7 @@ def test_duplicated_txn_id(self):
         self.assertNotEqual(event1.event_id, event3.event_id)

         events, _ = self.get_success(
-            self.persist_event_storage.persist_events([(event3, context)])
+            self.persist_event_storage_controller.persist_events([(event3, context)])
         )

         ret_event4 = events[0]
@@ -166,7 +168,7 @@ def test_duplicated_txn_id_one_call(self):
         self.assertNotEqual(event1.event_id, event2.event_id)

         events, _ = self.get_success(
-            self.persist_event_storage.persist_events(
+            self.persist_event_storage_controller.persist_events(
                 [(event1, context1), (event2, context2)]
             )
         )
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index 5bbbd5fbcbab..91f500678e26 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -31,7 +31,7 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
     def prepare(self, reactor, clock, homeserver):
         super().prepare(reactor, clock, homeserver)
         self.room_creator = homeserver.get_room_creation_handler()
-        self.persist_event_storage = self.hs.get_storage().persistence
+        self.persist_event_storage_controller = self.hs.get_storage().persistence

         # Create a test user
         self.ourUser = UserID.from_string(OUR_USER_ID)
@@ -61,7 +61,9 @@ def prepare(self, reactor, clock, homeserver):
             )
         )
         self.get_success(
-            self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+            self.persist_event_storage_controller.persist_event(
+                memberEvent, memberEventContext
+            )
         )

         # Join the second user to the second room
@@ -76,7 +78,9 @@ def prepare(self, reactor, clock, homeserver):
             )
         )
         self.get_success(
-            self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+            self.persist_event_storage_controller.persist_event(
+                memberEvent, memberEventContext
+            )
         )

     def test_return_empty_with_no_data(self):

From c188ed64cc4c4cc04746733d9dcbf7388787dd64 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 30 May 2022 09:27:02 +0100
Subject: [PATCH 03/12] Rename Storage and vars

---
 synapse/federation/federation_server.py  |  2 +-
 synapse/handlers/admin.py                |  8 ++--
 synapse/handlers/device.py               |  2 +-
 synapse/handlers/events.py               |  2 +-
 synapse/handlers/federation.py           |  2 +-
 synapse/handlers/federation_event.py     | 15 +++---
 synapse/handlers/initial_sync.py         | 10 ++--
 synapse/handlers/message.py              | 14 +++---
 synapse/handlers/pagination.py           | 13 ++++--
 synapse/handlers/relations.py            |  7 ++-
 synapse/handlers/room.py                 |  7 ++-
 synapse/handlers/room_batch.py           |  2 +-
 synapse/handlers/search.py               | 12 ++---
 synapse/handlers/sync.py                 |  8 ++--
 synapse/notifier.py                      |  4 +-
 synapse/push/httppusher.py               |  6 ++-
 synapse/push/mailer.py                   |  6 +--
 synapse/replication/http/federation.py   |  2 +-
 synapse/replication/http/send_event.py   |  6 ++-
 synapse/server.py                        |  6 +--
 synapse/state/__init__.py                |  8 ++--
 synapse/storage/__init__.py              |  4 +-
 tests/events/test_snapshot.py            |  2 +-
 tests/handlers/test_federation.py        |  2 +-
 tests/handlers/test_federation_event.py  |  7 +--
 tests/handlers/test_message.py           |  4 +-
 tests/handlers/test_user_directory.py    |  2 +-
 tests/replication/slave/storage/_base.py |  2 +-
 .../slave/storage/test_receipts.py       |  4 +-
 tests/rest/admin/test_user.py            |  4 +-
 tests/rest/client/test_retention.py      |  4 +-
 tests/rest/client/test_room_batch.py     |  6 ++-
 tests/storage/test_event_chain.py        |  3 +-
 tests/storage/test_events.py             |  4 +-
 tests/storage/test_purge.py              | 12 +++--
 tests/storage/test_redaction.py          |  2 +-
 tests/storage/test_room.py               |  2 +-
 tests/storage/test_room_search.py        |  4 +-
 tests/storage/test_state.py              |  2 +-
 tests/test_utils/event_injection.py      |  2 +-
 tests/test_visibility.py                 | 46 +++++++++++++------
 tests/utils.py                           |  2 +-
 42 files changed, 157 insertions(+), 105 deletions(-)

diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index b8232e5257d2..58def6bdf526 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -109,7 +109,7 @@ def __init__(self, hs: "HomeServer"):
         super().__init__(hs)

         self.handler = hs.get_federation_handler()
-        self.storage = hs.get_storage()
+        self.storage_controllers = hs.get_storage_controllers()
         self._spam_checker = hs.get_spam_checker()
         self._federation_event_handler = hs.get_federation_event_handler()
         self.state = hs.get_state_handler()
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index dd134747780b..31f2e60c320e 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -30,8 +30,8 @@ class AdminHandler:

     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
-        self.state_storage_controller = self.storage.state
+        self.storage_controllers = hs.get_storage_controllers()
+        self.state_storage_controller = self.storage_controllers.state

     async def get_whois(self, user: UserID) -> JsonDict:
         connections = []
@@ -197,7 +197,9 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->

                 from_key = events[-1].internal_metadata.after

-                events = await filter_events_for_client(self.storage, user_id, events)
+                events = await filter_events_for_client(
+                    self.storage_controllers, user_id, events
+                )

                 writer.write_events(room_id, events)

diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b21e46986543..fe73978bebd8 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -70,7 +70,7 @@ def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.state = hs.get_state_handler()
-        self.state_storage = hs.get_storage().state
+        self.state_storage = hs.get_storage_controllers().state
         self._auth_handler = hs.get_auth_handler()
         self.server_name = hs.hostname

diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 82a5aac3dda6..586bca06ec08 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -139,7 +139,7 @@ async def get_stream(
 class EventHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
+        self.storage = hs.get_storage_controllers()

     async def get_event(
         self,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6efa5fd678b3..68d40828f5cc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -125,7 +125,7 @@ def __init__(self, hs: "HomeServer"):
         self.hs = hs

         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
+        self.storage = hs.get_storage_controllers()
         self.state_storage_controller = self.storage.state
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index c65e7386340a..b9086745298a 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -98,8 +98,8 @@ class FederationEventHandler:

     def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastores().main
-        self._storage = hs.get_storage()
-        self._state_storage_controller = self._storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state

         self._state_handler = hs.get_state_handler()
         self._event_creation_handler = hs.get_event_creation_handler()
@@ -1440,7 +1440,7 @@ def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
             # we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True - context = EventContext.for_outlier(self._storage) + context = EventContext.for_outlier(self._storage_controllers) try: validate_event_for_room_version(room_version_obj, event) check_auth_rules_for_event(room_version_obj, event, auth) @@ -1898,7 +1898,7 @@ async def _update_context_for_auth_events( ) return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group=state_group, state_group_before_event=context.state_group_before_event, state_delta_due_to_event=state_updates, @@ -1988,11 +1988,14 @@ async def persist_events_and_notify( ) return result["max_stream_id"] else: - assert self._storage.persistence + assert self._storage_controllers.persistence # Note that this returns the events that were persisted, which may not be # the same as were passed in if some were deduplicated due to transaction IDs. - events, max_stream_token = await self._storage.persistence.persist_events( + ( + events, + max_stream_token, + ) = await self._storage_controllers.persistence.persist_events( event_and_contexts, backfilled=backfilled ) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 6b3d20e5593e..876bdfc7ae39 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -67,8 +67,8 @@ def __init__(self, hs: "HomeServer"): ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() - self.storage = hs.get_storage() - self.state_storage_controller = self.storage.state + self.storage_controllers = hs.get_storage_controllers() + self.state_storage_controller = self.storage_controllers.state async def snapshot_all_rooms( self, @@ -219,7 +219,7 @@ async def handle_room(event: RoomsForUser) -> None: ).addErrback(unwrapFirstError) messages = await filter_events_for_client( - self.storage, user_id, messages + self.storage_controllers, user_id, messages ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) @@ -372,7 +372,7 @@ async def _room_initial_sync_parted( ) messages = await filter_events_for_client( - self.storage, user_id, messages, is_peeking=is_peeking + self.storage_controllers, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) @@ -477,7 +477,7 @@ async def get_receipts() -> List[JsonDict]: ) messages = await filter_events_for_client( - self.storage, user_id, messages, is_peeking=is_peeking + self.storage_controllers, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6b92a4a161a4..a78f6cd3d995 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -84,8 +84,8 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage_controller = self.storage.state + self.storage_controllers = hs.get_storage_controllers() + self.state_storage_controller = self.storage_controllers.state self._event_serializer = hs.get_event_client_serializer() self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages @@ -206,7 +206,7 @@ async def get_state_events( is_peeking = not joined visible_events = await filter_events_for_client( - self.storage, + self.storage_controllers, user_id, [last_event], filter_send_to_client=False, @@ -406,7 +406,7 
@@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -1021,7 +1021,7 @@ async def create_new_client_event( # after it is created if builder.internal_metadata.outlier: event.internal_metadata.outlier = True - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self.storage_controllers) elif ( event.type == EventTypes.MSC2716_INSERTION and state_event_ids @@ -1434,7 +1434,7 @@ async def persist_and_notify_client_event( """ extra_users = extra_users or [] - assert self.storage.persistence is not None + assert self.storage_controllers.persistence is not None assert self._events_shard_config.should_handle( self._instance_name, event.room_id ) @@ -1668,7 +1668,7 @@ async def persist_and_notify_client_event( event, event_pos, max_stream_token, - ) = await self.storage.persistence.persist_event( + ) = await self.storage_controllers.persistence.persist_event( event, context=context, backfilled=backfilled ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 4f1541ace6d4..c4d0b2d3e237 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -129,7 +129,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage_controllers = hs.get_storage_controllers() self.state_storage_controller = self.storage_controllers.state self.clock = hs.get_clock() self._server_name = hs.hostname @@ -352,7 +352,7 @@ async def _purge_history( self._purges_in_progress_by_room.add(room_id) try: async with self.pagination_lock.write(room_id): - await self.storage.purge_events.purge_history( + await self.storage_controllers.purge_events.purge_history( room_id, token, delete_local_events ) logger.info("[purge] complete") @@ -414,7 +414,7 @@ async def purge_room(self, room_id: str, force: bool = False) -> None: if joined: raise SynapseError(400, "Users are still joined to this room") - await self.storage.purge_events.purge_room(room_id) + await self.storage_controllers.purge_events.purge_room(room_id) async def get_messages( self, @@ -520,7 +520,10 @@ async def get_messages( events = await event_filter.filter(events) events = await filter_events_for_client( - self.storage, user_id, events, is_peeking=(member_event_id is None) + self.storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), ) if not events: @@ -653,7 +656,7 @@ async def _shutdown_and_purge_room( 400, "Users are still joined to this room" ) - await self.storage.purge_events.purge_room(room_id) + await self.storage_controllers.purge_events.purge_room(room_id) logger.info("complete") self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index ab7e54857d56..9a1cc11bb3eb 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -69,7 +69,7 @@ def __bool__(self) -> bool: class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main - self._storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._clock = hs.get_clock() self._event_handler = 
hs.get_event_handler() @@ -143,7 +143,10 @@ async def get_relations( ) events = await filter_events_for_client( - self._storage, user_id, events, is_peeking=(member_event_id is None) + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), ) now = self._clock.time_msec() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b9502cd7143c..6fd4af932a0b 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1192,7 +1192,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage_controllers = hs.get_storage_controllers() self.state_storage_controller = self.storage_controllers.state self._relations_handler = hs.get_relations_handler() @@ -1236,7 +1236,10 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: if use_admin_priviledge: return events return await filter_events_for_client( - self.storage, user.to_string(), events, is_peeking=is_peeking + self.storage_controllers, + user.to_string(), + events, + is_peeking=is_peeking, ) event = await self.store.get_event( diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index edd35a66bb24..d65126248e0c 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -17,7 +17,7 @@ class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main - self.state_storage_controller = hs.get_storage().state + self.state_storage_controller = hs.get_storage_controllers().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 314a02fe6545..40a7e18586fe 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -55,8 +55,8 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self._event_serializer = hs.get_event_client_serializer() self._relations_handler = hs.get_relations_handler() - self.storage = hs.get_storage() - self.state_storage_controller = self.storage.state + self.storage_controllers = hs.get_storage_controllers() + self.state_storage_controller = self.storage_controllers.state self.auth = hs.get_auth() async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: @@ -460,7 +460,7 @@ async def _search_by_rank( filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + self.storage_controllers, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -559,7 +559,7 @@ async def _search_by_recent( filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + self.storage_controllers, user.to_string(), filtered_events ) room_events.extend(events) @@ -644,11 +644,11 @@ async def _calculate_event_contexts( ) events_before = await filter_events_for_client( - self.storage, user.to_string(), res.events_before + self.storage_controllers, user.to_string(), res.events_before ) events_after = await filter_events_for_client( - self.storage, user.to_string(), res.events_after + self.storage_controllers, user.to_string(), res.events_after ) context: JsonDict = { diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 
7bed951db80d..048b3d06bf36 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -238,8 +238,8 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.auth = hs.get_auth() - self.storage = hs.get_storage() - self.state_storage_controller = self.storage.state + self.storage_controllers = hs.get_storage_controllers() + self.state_storage_controller = self.storage_controllers.state # TODO: flush cache entries on subsequent sync request. # Once we get the next /sync request (ie, one with the same access token @@ -512,7 +512,7 @@ async def _load_filtered_recents( current_state_ids = frozenset(current_state_ids_map.values()) recents = await filter_events_for_client( - self.storage, + self.storage_controllers, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -580,7 +580,7 @@ async def _load_filtered_recents( current_state_ids = frozenset(current_state_ids_map.values()) loaded_recents = await filter_events_for_client( - self.storage, + self.storage_controllers, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, diff --git a/synapse/notifier.py b/synapse/notifier.py index ba23257f5498..ed3a4d58fa86 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -221,7 +221,7 @@ def __init__(self, hs: "HomeServer"): self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} self.hs = hs - self.storage = hs.get_storage() + self.storage_controllers = hs.get_storage_controllers() self.event_sources = hs.get_event_sources() self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] @@ -623,7 +623,7 @@ async def check_for_updates( if name == "room": new_events = await filter_events_for_client( - self.storage, + self.storage_controllers, user.to_string(), new_events, is_peeking=is_peeking, diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index d5603596c004..c437bd808fa9 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -65,7 +65,7 @@ class HttpPusher(Pusher): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): super().__init__(hs, pusher_config) - self.storage = self.hs.get_storage() + self.storage_controllers = self.hs.get_storage_controllers() self.app_display_name = pusher_config.app_display_name self.device_display_name = pusher_config.device_display_name self.pushkey_ts = pusher_config.ts @@ -343,7 +343,9 @@ async def _build_notification_dict( } return d - ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id) + ctx = await push_tools.get_context_for_event( + self.storage_controllers, event, self.user_id + ) d = { "notification": { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index cc5ac5bcdf7b..cb9a8f2c1337 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -114,10 +114,10 @@ def __init__( self.send_email_handler = hs.get_send_email_handler() self.store = self.hs.get_datastores().main - self.state_storage_controller = self.hs.get_storage().state + self.state_storage_controller = self.hs.get_storage_controllers().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() - self.storage = hs.get_storage() + self.storage_controllers = hs.get_storage_controllers() self.app_name = app_name self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects @@ -456,7 +456,7 @@ async def _get_notif_vars( } the_events = await filter_events_for_client( - 
self.storage, user_id, results.events_before + self.storage_controllers, user_id, results.events_before ) the_events.append(notif_event) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 3e7300b4a148..b1d33e30db80 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -69,7 +69,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage = hs.get_storage_controllers() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index ce781768364b..aacabfac4a7e 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -70,7 +70,7 @@ def __init__(self, hs: "HomeServer"): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() @staticmethod @@ -127,7 +127,9 @@ async def _handle_request( # type: ignore[override] event.internal_metadata.outlier = content["outlier"] requester = Requester.deserialize(self.store, content["requester"]) - context = EventContext.deserialize(self.storage, content["context"]) + context = EventContext.deserialize( + self.storage_controllers, content["context"] + ) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] diff --git a/synapse/server.py b/synapse/server.py index 3fd23aaf52cd..d32f1652232e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -123,7 +123,7 @@ WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import Databases, Storage +from synapse.storage import Databases, StorageControllers from synapse.streams.events import EventSources from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock @@ -729,8 +729,8 @@ def get_password_policy_handler(self) -> PasswordPolicyHandler: return PasswordPolicyHandler(self) @cache_in_self - def get_storage(self) -> Storage: - return Storage(self, self.get_datastores()) + def get_storage_controllers(self) -> StorageControllers: + return StorageControllers(self, self.get_datastores()) @cache_in_self def get_replication_streamer(self) -> ReplicationStreamer: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 417d110e1fb3..1854e6ec7024 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -127,10 +127,10 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main - self.state_storage_controller = hs.get_storage().state + self.state_storage_controller = hs.get_storage_controllers().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - self._storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() @overload async def get_current_state( @@ -361,7 +361,7 @@ async def compute_event_context( if not event.is_state(): return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group_before_event=state_group_before_event, state_group=state_group_before_event, state_delta_due_to_event={}, @@ -393,7 +393,7 @@ async def compute_event_context( ) return 
EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group=state_group_after_event, state_group_before_event=state_group_before_event, state_delta_due_to_event=delta_ids, diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index d82365222a2f..9964091b8ffe 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -41,8 +41,8 @@ __all__ = ["Databases", "DataStore"] -class Storage: - """The high level interfaces for talking to various storage layers.""" +class StorageControllers: + """The high level interfaces for talking to various storage controller layers.""" def __init__(self, hs: "HomeServer", stores: Databases): # We include the main data store here mainly so that we don't have to diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index defbc68c18cd..2cf3f1a4c9c3 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage = hs.get_storage_controllers() self.user_id = self.register_user("u1", "pass") self.user_tok = self.login("u1", "pass") diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 1e1cf42240b5..500c9ccfbc96 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -50,7 +50,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main - self.state_storage_controller = hs.get_storage().state + self.state_storage_controller = hs.get_storage_controllers().state self._event_auth_handler = hs.get_event_auth_handler() return hs diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 75b2ef8a3092..1d5b2492c00e 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -70,7 +70,7 @@ def _test_process_pulled_event_with_missing_state( ) -> None: OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" main_store = self.hs.get_datastores().main - state_storage_controller = self.hs.get_storage().state + state_storage_controller = self.hs.get_storage_controllers().state # create the room user_id = self.register_user("kermit", "test") @@ -146,10 +146,11 @@ def _test_process_pulled_event_with_missing_state( ) if prev_exists_as_outlier: prev_event.internal_metadata.outlier = True - persistence = self.hs.get_storage().persistence + persistence = self.hs.get_storage_controllers().persistence self.get_success( persistence.persist_event( - prev_event, EventContext.for_outlier(self.hs.get_storage()) + prev_event, + EventContext.for_outlier(self.hs.get_storage_controllers()), ) ) else: diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index e0bc1e71d11e..b5779c485f83 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_event_creation_handler() - self.persist_event_storage_controller = self.hs.get_storage().persistence + self.persist_event_storage_controller = ( + self.hs.get_storage_controllers().persistence + ) self.user_id = self.register_user("tester", "foobar") self.access_token = 
self.login("tester", "foobar") diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 4d658d29cab5..a68c2ffd4530 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -954,7 +954,7 @@ def _add_user_to_room( ) self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) + self.hs.get_storage_controllers().persistence.persist_event(event, context) ) def test_local_user_leaving_room_remains_in_user_directory(self) -> None: diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 85be79d19d48..9b41d9309122 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -32,7 +32,7 @@ def prepare(self, reactor, clock, hs): self.master_store = hs.get_datastores().main self.slaved_store = self.worker_hs.get_datastores().main - self.storage = hs.get_storage() + self.storage = hs.get_storage_controllers() def replicate(self): """Tell the master side of replication that something has happened, and then diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index 91f500678e26..19f57115a1c4 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): def prepare(self, reactor, clock, homeserver): super().prepare(reactor, clock, homeserver) self.room_creator = homeserver.get_room_creation_handler() - self.persist_event_storage_controller = self.hs.get_storage().persistence + self.persist_event_storage_controller = ( + self.hs.get_storage_controllers().persistence + ) # Create a test user self.ourUser = UserID.from_string(OUR_USER_ID) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 0cdf1dec4042..0d44102237fe 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2579,7 +2579,7 @@ def test_get_rooms_with_nonlocal_user(self) -> None: other_user_tok = self.login("user", "pass") event_builder_factory = self.hs.get_event_builder_factory() event_creation_handler = self.hs.get_event_creation_handler() - storage = self.hs.get_storage() + storage_controllers = self.hs.get_storage_controllers() # Create two rooms, one with a local user only and one with both a local # and remote user. @@ -2604,7 +2604,7 @@ def test_get_rooms_with_nonlocal_user(self) -> None: event_creation_handler.create_new_client_event(builder) ) - self.get_success(storage.persistence.persist_event(event, context)) + self.get_success(storage_controllers.persistence.persist_event(event, context)) # Now get rooms url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 2cd7a9e6c5f8..ac9c11335460 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -130,7 +130,7 @@ def test_visibility(self) -> None: We do this by setting a very long time between purge jobs. """ store = self.hs.get_datastores().main - storage = self.hs.get_storage() + storage_controllers = self.hs.get_storage_controllers() room_id = self.helper.create_room_as(self.user_id, tok=self.token) # Send a first event, which should be filtered out at the end of the test. 
@@ -155,7 +155,7 @@ def test_visibility(self) -> None: ) self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( - filter_events_for_client(storage, self.user_id, events) + filter_events_for_client(storage_controllers, self.user_id, events) ) # We should only get one event back. diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 41a1bf6d890e..20010e028350 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -88,7 +88,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.clock = clock - self.storage = hs.get_storage() + self.storage_controllers = hs.get_storage_controllers() self.virtual_user_id, _ = self.register_appservice_user( "as_user_potato", self.appservice.token @@ -168,7 +168,9 @@ def test_same_state_groups_for_whole_historical_batch(self) -> None: # Fetch the state_groups state_group_map = self.get_success( - self.storage.state.get_state_groups_ids(room_id, historical_event_ids) + self.storage_controllers.state.get_state_groups_ids( + room_id, historical_event_ids + ) ) # We expect all of the historical events to be using the same state_group diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index c7661e71868f..a0ce077a9957 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -393,7 +393,8 @@ def _persist(txn): # We need to persist the events to the events and state_events # tables. persist_events_store._store_event_txn( - txn, [(e, EventContext(self.hs.get_storage())) for e in events] + txn, + [(e, EventContext(self.hs.get_storage_controllers())) for e in events], ) # Actually call the function that calculates the auth chain stuff. 
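The event-chain test above and the extremity-pruning tests that follow both build an EventContext against the controller bundle and write events through its persistence controller. A rough sketch of that flow for an outlier, assuming an already-built event and a configured hs (the function name persist_outlier is illustrative only):

from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.server import HomeServer


async def persist_outlier(hs: HomeServer, event: EventBase) -> None:
    controllers = hs.get_storage_controllers()
    # Outliers carry no state group, so the context only needs the controllers.
    event.internal_metadata.outlier = True
    context = EventContext.for_outlier(controllers)
    # The persistence controller is Optional: only instances configured to
    # persist events have one.
    persistence = controllers.persistence
    assert persistence is not None
    await persistence.persist_event(event, context)
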
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index aaa3189b16ef..27b20b6b048e 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage().persistence + self.persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main self.register_user("user", "pass") @@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage().persistence + self.persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main def test_remote_user_rooms_cache_invalidated(self): diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 08cc60237ec1..78c8744e3b0f 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -31,7 +31,7 @@ def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) self.store = hs.get_datastores().main - self.storage = self.hs.get_storage() + self.storage_controllers = self.hs.get_storage_controllers() def test_purge_history(self): """ @@ -51,7 +51,9 @@ def test_purge_history(self): # Purge everything before this topological token self.get_success( - self.storage.purge_events.purge_history(self.room_id, token_str, True) + self.storage_controllers.purge_events.purge_history( + self.room_id, token_str, True + ) ) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted @@ -79,7 +81,9 @@ def test_purge_history_wont_delete_extrems(self): # Purge everything before this topological token f = self.get_failure( - self.storage.purge_events.purge_history(self.room_id, event, True), + self.storage_controllers.purge_events.purge_history( + self.room_id, event, True + ), SynapseError, ) self.assertIn("greater than forward", f.value.args[0]) @@ -105,7 +109,7 @@ def test_purge_room(self): self.assertIsNotNone(create_event) # Purge everything before this topological token - self.get_success(self.storage.purge_events.purge_room(self.room_id)) + self.get_success(self.storage_controllers.purge_events.purge_room(self.room_id)) # The events aren't found. 
self.store._invalidate_get_event_cache(create_event.event_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index d8d17ef37925..8e737332fc06 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -31,7 +31,7 @@ def default_config(self): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage = hs.get_storage_controllers() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 5b011e18cd69..fec81a9f712b 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -72,7 +72,7 @@ def prepare(self, reactor, clock, hs): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage = hs.get_storage_controllers() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 8dfc1e1db903..e747c6b50eb9 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -99,7 +99,9 @@ def test_non_string(self): prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) prev_event = self.get_success(store.get_event(prev_event_ids[0])) prev_state_map = self.get_success( - self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) + self.hs.get_storage_controllers().state.get_state_ids_for_event( + prev_event_ids[0] + ) ) event_dict = { diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index f88f1c55fc6f..8043bdbde2c7 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -29,7 +29,7 @@ class StateStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage = hs.get_storage_controllers() self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index c654e36ee4f4..8027c7a856e2 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -70,7 +70,7 @@ async def inject_event( """ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - persistence = hs.get_storage().persistence + persistence = hs.get_storage_controllers().persistence assert persistence is not None await persistence.persist_event(event, context) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 7a9b01ef9d44..380a1839a11b 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -34,7 +34,7 @@ def setUp(self) -> None: super(FilterEventsForServerTestCase, self).setUp() self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() - self.storage = self.hs.get_storage() + self.storage_controllers = self.hs.get_storage_controllers() self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @@ -60,7 +60,9 @@ def test_filtering(self) -> None: events_to_filter.append(evt) filtered = self.get_success( - filter_events_for_server(self.storage, "test_server", events_to_filter) + 
filter_events_for_server( + self.storage_controllers, "test_server", events_to_filter + ) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -80,7 +82,9 @@ def test_filter_outlier(self) -> None: outlier = self._inject_outlier() self.assertEqual( self.get_success( - filter_events_for_server(self.storage, "remote_hs", [outlier]) + filter_events_for_server( + self.storage_controllers, "remote_hs", [outlier] + ) ), [outlier], ) @@ -89,7 +93,9 @@ def test_filter_outlier(self) -> None: evt = self._inject_message("@unerased:local_hs") filtered = self.get_success( - filter_events_for_server(self.storage, "remote_hs", [outlier, evt]) + filter_events_for_server( + self.storage_controllers, "remote_hs", [outlier, evt] + ) ) self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") self.assertEqual(filtered[0], outlier) @@ -99,7 +105,9 @@ def test_filter_outlier(self) -> None: # ... but other servers should only be able to see the outlier (the other should # be redacted) filtered = self.get_success( - filter_events_for_server(self.storage, "other_server", [outlier, evt]) + filter_events_for_server( + self.storage_controllers, "other_server", [outlier, evt] + ) ) self.assertEqual(filtered[0], outlier) self.assertEqual(filtered[1].event_id, evt.event_id) @@ -132,7 +140,9 @@ def test_erased_user(self) -> None: # ... and the filtering happens. filtered = self.get_success( - filter_events_for_server(self.storage, "test_server", events_to_filter) + filter_events_for_server( + self.storage_controllers, "test_server", events_to_filter + ) ) for i in range(0, len(events_to_filter)): @@ -168,7 +178,9 @@ def _inject_visibility(self, user_id: str, visibility: str) -> EventBase: event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self.storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_room_member( @@ -194,7 +206,9 @@ def _inject_room_member( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self.storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_message( @@ -216,7 +230,9 @@ def _inject_message( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self.storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_outlier(self) -> EventBase: @@ -234,8 +250,8 @@ def _inject_outlier(self) -> EventBase: event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event.internal_metadata.outlier = True self.get_success( - self.storage.persistence.persist_event( - event, EventContext.for_outlier(self.storage) + self.storage_controllers.persistence.persist_event( + event, EventContext.for_outlier(self.storage_controllers) ) ) return event @@ -293,7 +309,9 @@ def test_out_of_band_invite_rejection(self): self.assertEqual( self.get_success( filter_events_for_client( - self.hs.get_storage(), "@user:test", [invite_event, reject_event] + self.hs.get_storage_controllers(), + "@user:test", + [invite_event, reject_event], ) ), [invite_event, reject_event], @@ -303,7 +321,9 @@ def test_out_of_band_invite_rejection(self): self.assertEqual( self.get_success( filter_events_for_client( - self.hs.get_storage(), 
"@other:test", [invite_event, reject_event] + self.hs.get_storage_controllers(), + "@other:test", + [invite_event, reject_event], ) ), [], diff --git a/tests/utils.py b/tests/utils.py index d4ba3a9b99cf..3059c453d595 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -264,7 +264,7 @@ def time_bound_deferred(self, d, *args, **kwargs): async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" - persistence_store = hs.get_storage().persistence + persistence_store = hs.get_storage_controllers().persistence store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() From cc945716b2660a0efcc9f659c949352eb9dde4ab Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 May 2022 09:31:39 +0100 Subject: [PATCH 04/12] Newsfile --- changelog.d/12913.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12913.misc diff --git a/changelog.d/12913.misc b/changelog.d/12913.misc new file mode 100644 index 000000000000..a2bc940557f2 --- /dev/null +++ b/changelog.d/12913.misc @@ -0,0 +1 @@ +Rename storage classes. From 77aa3ae1069e84425410b64c50811badecbd7402 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 May 2022 10:37:12 +0100 Subject: [PATCH 05/12] Fix tests --- tests/test_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_state.py b/tests/test_state.py index 84694d368d8b..95f81bebae19 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -179,12 +179,12 @@ def get_leaves(self): class StateTestCase(unittest.TestCase): def setUp(self): self.dummy_store = _DummyStore() - storage = Mock(main=self.dummy_store, state=self.dummy_store) + storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( spec_set=[ "config", "get_datastores", - "get_storage", + "get_storage_controllers", "get_auth", "get_state_handler", "get_clock", @@ -199,7 +199,7 @@ def setUp(self): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) - hs.get_storage.return_value = storage + hs.get_storage_controllers.return_value = storage_controllers self.state = StateHandler(hs) self.event_id = 0 From 25b5c86b30c8f3b3d4d247345c1982b9d53806bd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 May 2022 10:39:17 +0100 Subject: [PATCH 06/12] Commit unsaved files... 
--- synapse/handlers/events.py | 4 ++-- synapse/handlers/federation.py | 26 +++++++++++++++----------- synapse/replication/http/federation.py | 4 ++-- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 586bca06ec08..fa729bf4e617 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -139,7 +139,7 @@ async def get_stream( class EventHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - self.storage = hs.get_storage_controllers() + self.storage_controllers = hs.get_storage_controllers() async def get_event( self, @@ -177,7 +177,7 @@ async def get_event( is_peeking = user.to_string() not in users filtered = await filter_events_for_client( - self.storage, user.to_string(), [event], is_peeking=is_peeking + self.storage_controllers, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 68d40828f5cc..42bec7e12740 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -125,8 +125,8 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main - self.storage = hs.get_storage_controllers() - self.state_storage_controller = self.storage.state + self.storage_controllers = hs.get_storage_controllers() + self.state_storage_controller = self.storage_controllers.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -324,7 +324,7 @@ async def _maybe_backfill_inner( # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = await filter_events_for_server( - self.storage, + self.storage_controllers, self.server_name, events_to_check, redact=False, @@ -660,7 +660,7 @@ async def do_knock( # in the invitee's sync stream. It is stripped out for all other local users. 
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self.storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -849,7 +849,7 @@ async def on_invite_request( ) ) - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self.storage_controllers) await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -878,7 +878,7 @@ async def do_remotely_reject_invite( await self.federation_client.send_leave(host_list, event) - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self.storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -1078,7 +1078,9 @@ async def on_backfill_request( ], ) - events = await filter_events_for_server(self.storage, origin, events) + events = await filter_events_for_server( + self.storage_controllers, origin, events + ) return events @@ -1109,7 +1111,9 @@ async def get_persisted_pdu( if not in_room: raise AuthError(403, "Host not in room.") - events = await filter_events_for_server(self.storage, origin, [event]) + events = await filter_events_for_server( + self.storage_controllers, origin, [event] + ) event = events[0] return event else: @@ -1138,7 +1142,7 @@ async def on_get_missing_events( ) missing_events = await filter_events_for_server( - self.storage, origin, missing_events + self.storage_controllers, origin, missing_events ) return missing_events @@ -1480,9 +1484,9 @@ async def _sync_partial_state_room( # clear the lazy-loading flag. logger.info("Updating current state for %s", room_id) assert ( - self.storage.persistence is not None + self.storage_controllers.persistence is not None ), "TODO(faster_joins): support for workers" - await self.storage.persistence.update_current_state(room_id) + await self.storage_controllers.persistence.update_current_state(room_id) logger.info("Clearing partial-state flag for %s", room_id) success = await self.store.clear_partial_state_room(room_id) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index b1d33e30db80..2d2d06bcf7d6 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -69,7 +69,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastores().main - self.storage = hs.get_storage_controllers() + self.storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() @@ -133,7 +133,7 @@ async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # ty event.internal_metadata.outlier = event_payload["outlier"] context = EventContext.deserialize( - self.storage, event_payload["context"] + self.storage_controllers, event_payload["context"] ) event_and_contexts.append((event, context)) From 179e1c37e542748dabdbda3e3dc909ead20ace8f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 May 2022 13:37:51 +0100 Subject: [PATCH 07/12] Mvoe stuff to storage.controllers --- synapse/events/snapshot.py | 2 +- synapse/push/push_tools.py | 2 +- synapse/server.py | 3 +- synapse/storage/__init__.py | 25 -- synapse/storage/controllers/__init__.py | 46 +++ .../{ => controllers}/persist_events.py | 0 .../storage/{ => controllers}/purge_events.py | 0 
synapse/storage/controllers/state.py | 352 ++++++++++++++++++ synapse/storage/state.py | 320 ---------------- synapse/visibility.py | 2 +- 10 files changed, 403 insertions(+), 349 deletions(-) create mode 100644 synapse/storage/controllers/__init__.py rename synapse/storage/{ => controllers}/persist_events.py (100%) rename synapse/storage/{ => controllers}/purge_events.py (100%) create mode 100644 synapse/storage/controllers/state.py diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 18c03a46160b..b700cbbfa197 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -22,7 +22,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: - from synapse.storage import StorageControllers + from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore from synapse.storage.state import StateFilter diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 83af3b7fdfe8..8397229ccb72 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -16,7 +16,7 @@ from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event -from synapse.storage import StorageControllers +from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore diff --git a/synapse/server.py b/synapse/server.py index d32f1652232e..a66ec228dbab 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -123,7 +123,8 @@ WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import Databases, StorageControllers +from synapse.storage import Databases +from synapse.storage.controllers import StorageControllers from synapse.streams.events import EventSources from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 9964091b8ffe..2817acf2c008 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -26,33 +26,8 @@ data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ -from typing import TYPE_CHECKING from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore -from synapse.storage.persist_events import EventsPersistenceStorageController -from synapse.storage.purge_events import PurgeEventsStorageController -from synapse.storage.state import StateGroupStorageController - -if TYPE_CHECKING: - from synapse.server import HomeServer - __all__ = ["Databases", "DataStore"] - - -class StorageControllers: - """The high level interfaces for talking to various storage controller layers.""" - - def __init__(self, hs: "HomeServer", stores: Databases): - # We include the main data store here mainly so that we don't have to - # rewrite all the existing code to split it into high vs low level - # interfaces. 
- self.main = stores.main - - self.purge_events = PurgeEventsStorageController(hs, stores) - self.state = StateGroupStorageController(hs, stores) - - self.persistence = None - if stores.persist_events: - self.persistence = EventsPersistenceStorageController(hs, stores) diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py new file mode 100644 index 000000000000..992261d07be5 --- /dev/null +++ b/synapse/storage/controllers/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from synapse.storage.controllers.persist_events import ( + EventsPersistenceStorageController, +) +from synapse.storage.controllers.purge_events import PurgeEventsStorageController +from synapse.storage.controllers.state import StateGroupStorageController +from synapse.storage.databases import Databases +from synapse.storage.databases.main import DataStore + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +__all__ = ["Databases", "DataStore"] + + +class StorageControllers: + """The high level interfaces for talking to various storage controller layers.""" + + def __init__(self, hs: "HomeServer", stores: Databases): + # We include the main data store here mainly so that we don't have to + # rewrite all the existing code to split it into high vs low level + # interfaces. + self.main = stores.main + + self.purge_events = PurgeEventsStorageController(hs, stores) + self.state = StateGroupStorageController(hs, stores) + + self.persistence = None + if stores.persist_events: + self.persistence = EventsPersistenceStorageController(hs, stores) diff --git a/synapse/storage/persist_events.py b/synapse/storage/controllers/persist_events.py similarity index 100% rename from synapse/storage/persist_events.py rename to synapse/storage/controllers/persist_events.py diff --git a/synapse/storage/purge_events.py b/synapse/storage/controllers/purge_events.py similarity index 100% rename from synapse/storage/purge_events.py rename to synapse/storage/controllers/purge_events.py diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py new file mode 100644 index 000000000000..c6fac27888d4 --- /dev/null +++ b/synapse/storage/controllers/state.py @@ -0,0 +1,352 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +from typing import ( + TYPE_CHECKING, + Awaitable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, +) + +from synapse.events import EventBase +from synapse.storage.state import StateFilter +from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.types import MutableStateMap, StateMap + +if TYPE_CHECKING: + + from synapse.server import HomeServer + from synapse.storage.databases import Databases + +logger = logging.getLogger(__name__) + + +class StateGroupStorageController: + """High level interface to fetching state for event.""" + + def __init__(self, hs: "HomeServer", stores: "Databases"): + self._is_mine_id = hs.is_mine_id + self.stores = stores + self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) + + def notify_event_un_partial_stated(self, event_id: str) -> None: + self._partial_state_events_tracker.notify_un_partial_stated(event_id) + + async def get_state_group_delta( + self, state_group: int + ) -> Tuple[Optional[int], Optional[StateMap[str]]]: + """Given a state group try to return a previous group and a delta between + the old and the new. + + Args: + state_group: The state group used to retrieve state deltas. + + Returns: + A tuple of the previous group and a state map of the event IDs which + make up the delta between the old and new state groups. + """ + + state_group_delta = await self.stores.state.get_state_group_delta(state_group) + return state_group_delta.prev_group, state_group_delta.delta_ids + + async def get_state_groups_ids( + self, _room_id: str, event_ids: Collection[str] + ) -> Dict[int, MutableStateMap[str]]: + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id: id of the room for these events + event_ids: ids of the events + + Returns: + dict of state_group_id -> (dict of (type, state_key) -> event id) + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + if not event_ids: + return {} + + event_to_groups = await self.get_state_group_for_events(event_ids) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups(groups) + + return group_to_state + + async def get_state_ids_for_group( + self, state_group: int, state_filter: Optional[StateFilter] = None + ) -> StateMap[str]: + """Get the event IDs of all the state in the given state group + + Args: + state_group: A state group for which we want to get the state IDs. + state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules + + Returns: + Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = await self.get_state_for_groups((state_group,), state_filter) + + return group_to_state[state_group] + + async def get_state_groups( + self, room_id: str, event_ids: Collection[str] + ) -> Dict[int, List[EventBase]]: + """Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + + Returns: + dict of state_group_id -> list of state events. 
+ """ + if not event_ids: + return {} + + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) + + state_event_map = await self.stores.main.get_events( + [ + ev_id + for group_ids in group_to_ids.values() + for ev_id in group_ids.values() + ], + get_prev_content=False, + ) + + return { + group: [ + state_event_map[v] + for v in event_id_map.values() + if v in state_event_map + ] + for group, event_id_map in group_to_ids.items() + } + + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ) -> Awaitable[Dict[int, StateMap[str]]]: + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups: list of state group IDs to query + state_filter: The state filter used to fetch state + from the database. + + Returns: + Dict of state group to state map. + """ + + return self.stores.state._get_state_groups_from_groups(groups, state_filter) + + async def get_state_for_events( + self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None + ) -> Dict[str, StateMap[EventBase]]: + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. + + Args: + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. + + Returns: + A dict of (event_id) -> (type, state_key) -> [state_events] + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + await_full_state = True + if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + await_full_state = False + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + state_event_map = await self.stores.main.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } + for event_id, group in event_to_groups.items() + } + + return {event: event_to_state[event] for event in event_ids} + + async def get_state_ids_for_events( + self, + event_ids: Collection[str], + state_filter: Optional[StateFilter] = None, + ) -> Dict[str, StateMap[str]]: + """ + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) + + Args: + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. 
+ + Returns: + A dict from event_id -> (type, state_key) -> event_id + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + await_full_state = True + if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + await_full_state = False + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_to_groups.items() + } + + return {event: event_to_state[event] for event in event_ids} + + async def get_state_for_event( + self, event_id: str, state_filter: Optional[StateFilter] = None + ) -> StateMap[EventBase]: + """ + Get the state dict corresponding to a particular event + + Args: + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. + + Returns: + A dict from (type, state_key) -> state_event + + Raises: + RuntimeError if we don't have a state group for the event (ie it is an + outlier or is unknown) + """ + state_map = await self.get_state_for_events( + [event_id], state_filter or StateFilter.all() + ) + return state_map[event_id] + + async def get_state_ids_for_event( + self, event_id: str, state_filter: Optional[StateFilter] = None + ) -> StateMap[str]: + """ + Get the state dict corresponding to a particular event + + Args: + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. + + Returns: + A dict from (type, state_key) -> state_event_id + + Raises: + RuntimeError if we don't have a state group for the event (ie it is an + outlier or is unknown) + """ + state_map = await self.get_state_ids_for_events( + [event_id], state_filter or StateFilter.all() + ) + return state_map[event_id] + + def get_state_for_groups( + self, groups: Iterable[int], state_filter: Optional[StateFilter] = None + ) -> Awaitable[Dict[int, MutableStateMap[str]]]: + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state. + from the database. + + Returns: + Dict of state group to state map. + """ + return self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + async def get_state_group_for_events( + self, + event_ids: Collection[str], + await_full_state: bool = True, + ) -> Mapping[str, int]: + """Returns mapping event_id -> state_group + + Args: + event_ids: events to get state groups for + await_full_state: if true, will block if we do not yet have complete + state at these events. + """ + if await_full_state: + await self._partial_state_events_tracker.await_full_state(event_ids) + + return await self.stores.main._get_state_group_for_events(event_ids) + + async def store_state_group( + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[StateMap[str]], + current_state_ids: StateMap[str], + ) -> int: + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional. 
+ delta_ids: The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids: The state to store. Map of (type, state_key) + to event_id. + + Returns: + The state group ID + """ + return await self.stores.state.store_state_group( + event_id, room_id, prev_group, delta_ids, current_state_ids + ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 96896b4fb43b..96aaffb53c04 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -15,7 +15,6 @@ import logging from typing import ( TYPE_CHECKING, - Awaitable, Callable, Collection, Dict, @@ -32,15 +31,11 @@ from frozendict import frozendict from synapse.api.constants import EventTypes -from synapse.events import EventBase -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker from synapse.types import MutableStateMap, StateKey, StateMap if TYPE_CHECKING: from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad - from synapse.server import HomeServer - from synapse.storage.databases import Databases logger = logging.getLogger(__name__) @@ -578,318 +573,3 @@ def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: types=frozendict({EventTypes.Member: frozenset()}), include_others=True ) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) - - -class StateGroupStorageController: - """High level interface to fetching state for event.""" - - def __init__(self, hs: "HomeServer", stores: "Databases"): - self._is_mine_id = hs.is_mine_id - self.stores = stores - self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) - - def notify_event_un_partial_stated(self, event_id: str) -> None: - self._partial_state_events_tracker.notify_un_partial_stated(event_id) - - async def get_state_group_delta( - self, state_group: int - ) -> Tuple[Optional[int], Optional[StateMap[str]]]: - """Given a state group try to return a previous group and a delta between - the old and the new. - - Args: - state_group: The state group used to retrieve state deltas. - - Returns: - A tuple of the previous group and a state map of the event IDs which - make up the delta between the old and new state groups. - """ - - state_group_delta = await self.stores.state.get_state_group_delta(state_group) - return state_group_delta.prev_group, state_group_delta.delta_ids - - async def get_state_groups_ids( - self, _room_id: str, event_ids: Collection[str] - ) -> Dict[int, MutableStateMap[str]]: - """Get the event IDs of all the state for the state groups for the given events - - Args: - _room_id: id of the room for these events - event_ids: ids of the events - - Returns: - dict of state_group_id -> (dict of (type, state_key) -> event id) - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - if not event_ids: - return {} - - event_to_groups = await self.get_state_group_for_events(event_ids) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups(groups) - - return group_to_state - - async def get_state_ids_for_group( - self, state_group: int, state_filter: Optional[StateFilter] = None - ) -> StateMap[str]: - """Get the event IDs of all the state in the given state group - - Args: - state_group: A state group for which we want to get the state IDs. 
- state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules - - Returns: - Resolves to a map of (type, state_key) -> event_id - """ - group_to_state = await self.get_state_for_groups((state_group,), state_filter) - - return group_to_state[state_group] - - async def get_state_groups( - self, room_id: str, event_ids: Collection[str] - ) -> Dict[int, List[EventBase]]: - """Get the state groups for the given list of event_ids - - Args: - room_id: ID of the room for these events. - event_ids: The event IDs to retrieve state for. - - Returns: - dict of state_group_id -> list of state events. - """ - if not event_ids: - return {} - - group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - - state_event_map = await self.stores.main.get_events( - [ - ev_id - for group_ids in group_to_ids.values() - for ev_id in group_ids.values() - ], - get_prev_content=False, - ) - - return { - group: [ - state_event_map[v] - for v in event_id_map.values() - if v in state_event_map - ] - for group, event_id_map in group_to_ids.items() - } - - def _get_state_groups_from_groups( - self, groups: List[int], state_filter: StateFilter - ) -> Awaitable[Dict[int, StateMap[str]]]: - """Returns the state groups for a given set of groups, filtering on - types of state events. - - Args: - groups: list of state group IDs to query - state_filter: The state filter used to fetch state - from the database. - - Returns: - Dict of state group to state map. - """ - - return self.stores.state._get_state_groups_from_groups(groups, state_filter) - - async def get_state_for_events( - self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None - ) -> Dict[str, StateMap[EventBase]]: - """Given a list of event_ids and type tuples, return a list of state - dicts for each event. - - Args: - event_ids: The events to fetch the state of. - state_filter: The state filter used to fetch state. - - Returns: - A dict of (event_id) -> (type, state_key) -> [state_events] - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): - await_full_state = False - - event_to_groups = await self.get_state_group_for_events( - event_ids, await_full_state=await_full_state - ) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - state_event_map = await self.stores.main.get_events( - [ev_id for sd in group_to_state.values() for ev_id in sd.values()], - get_prev_content=False, - ) - - event_to_state = { - event_id: { - k: state_event_map[v] - for k, v in group_to_state[group].items() - if v in state_event_map - } - for event_id, group in event_to_groups.items() - } - - return {event: event_to_state[event] for event in event_ids} - - async def get_state_ids_for_events( - self, - event_ids: Collection[str], - state_filter: Optional[StateFilter] = None, - ) -> Dict[str, StateMap[str]]: - """ - Get the state dicts corresponding to a list of events, containing the event_ids - of the state events (as opposed to the events themselves) - - Args: - event_ids: events whose state should be returned - state_filter: The state filter used to fetch state from the database. 
- - Returns: - A dict from event_id -> (type, state_key) -> event_id - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): - await_full_state = False - - event_to_groups = await self.get_state_group_for_events( - event_ids, await_full_state=await_full_state - ) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - event_to_state = { - event_id: group_to_state[group] - for event_id, group in event_to_groups.items() - } - - return {event: event_to_state[event] for event in event_ids} - - async def get_state_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None - ) -> StateMap[EventBase]: - """ - Get the state dict corresponding to a particular event - - Args: - event_id: event whose state should be returned - state_filter: The state filter used to fetch state from the database. - - Returns: - A dict from (type, state_key) -> state_event - - Raises: - RuntimeError if we don't have a state group for the event (ie it is an - outlier or is unknown) - """ - state_map = await self.get_state_for_events( - [event_id], state_filter or StateFilter.all() - ) - return state_map[event_id] - - async def get_state_ids_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None - ) -> StateMap[str]: - """ - Get the state dict corresponding to a particular event - - Args: - event_id: event whose state should be returned - state_filter: The state filter used to fetch state from the database. - - Returns: - A dict from (type, state_key) -> state_event_id - - Raises: - RuntimeError if we don't have a state group for the event (ie it is an - outlier or is unknown) - """ - state_map = await self.get_state_ids_for_events( - [event_id], state_filter or StateFilter.all() - ) - return state_map[event_id] - - def get_state_for_groups( - self, groups: Iterable[int], state_filter: Optional[StateFilter] = None - ) -> Awaitable[Dict[int, MutableStateMap[str]]]: - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups: list of state groups for which we want to get the state. - state_filter: The state filter used to fetch state. - from the database. - - Returns: - Dict of state group to state map. - """ - return self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - async def get_state_group_for_events( - self, - event_ids: Collection[str], - await_full_state: bool = True, - ) -> Mapping[str, int]: - """Returns mapping event_id -> state_group - - Args: - event_ids: events to get state groups for - await_full_state: if true, will block if we do not yet have complete - state at these events. - """ - if await_full_state: - await self._partial_state_events_tracker.await_full_state(event_ids) - - return await self.stores.main._get_state_group_for_events(event_ids) - - async def store_state_group( - self, - event_id: str, - room_id: str, - prev_group: Optional[int], - delta_ids: Optional[StateMap[str]], - current_state_ids: StateMap[str], - ) -> int: - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id: The event ID for which the state was calculated. - room_id: ID of the room for which the state was calculated. - prev_group: A previous state group for the room, optional. 
- delta_ids: The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids: The state to store. Map of (type, state_key) - to event_id. - - Returns: - The state group ID - """ - return await self.stores.state.store_state_group( - event_id, room_id, prev_group, delta_ids, current_state_ids - ) diff --git a/synapse/visibility.py b/synapse/visibility.py index 13a5ff63490a..97548c14e34c 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase from synapse.events.utils import prune_event -from synapse.storage import StorageControllers +from synapse.storage.controllers import StorageControllers from synapse.storage.state import StateFilter from synapse.types import RetentionPolicy, StateMap, get_domain_from_id From 4302238a744ac5dd5492f36eab1efa3a83ec9d8e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 May 2022 14:30:22 +0100 Subject: [PATCH 08/12] Rename vars --- synapse/federation/federation_server.py | 1 - synapse/handlers/admin.py | 8 +++--- synapse/handlers/device.py | 4 +-- synapse/handlers/events.py | 4 +-- synapse/handlers/federation.py | 26 ++++++++++--------- synapse/handlers/initial_sync.py | 14 +++++----- synapse/handlers/message.py | 22 ++++++++-------- synapse/handlers/pagination.py | 14 +++++----- synapse/handlers/room.py | 8 +++--- synapse/handlers/room_batch.py | 4 +-- synapse/handlers/search.py | 14 +++++----- synapse/handlers/sync.py | 22 ++++++++-------- synapse/notifier.py | 4 +-- synapse/push/httppusher.py | 4 +-- synapse/push/mailer.py | 10 +++---- synapse/replication/http/federation.py | 4 +-- synapse/replication/http/send_event.py | 4 +-- tests/events/test_snapshot.py | 4 +-- tests/handlers/test_message.py | 10 +++---- tests/replication/slave/storage/_base.py | 2 +- .../replication/slave/storage/test_events.py | 10 ++++--- tests/rest/client/test_room_batch.py | 4 +-- tests/storage/test_events.py | 12 ++++----- tests/storage/test_purge.py | 10 ++++--- tests/storage/test_redaction.py | 14 +++++----- tests/storage/test_room.py | 4 +-- tests/test_visibility.py | 22 ++++++++-------- 27 files changed, 133 insertions(+), 126 deletions(-) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 58def6bdf526..4a8f996cb87b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -109,7 +109,6 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_federation_handler() - self.storage_controllers = hs.get_storage_controllers() self._spam_checker = hs.get_spam_checker() self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 31f2e60c320e..d4fe7df533a1 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,8 +30,8 @@ class AdminHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - self.state_storage_controller = self.storage_controllers.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state async def get_whois(self, user: UserID) -> JsonDict: connections = [] @@ -198,7 +198,7 @@ async def export_user_data(self, user_id: str, 
writer: "ExfiltrationWriter") -> from_key = events[-1].internal_metadata.after events = await filter_events_for_client( - self.storage_controllers, user_id, events + self._storage_controllers, user_id, events ) writer.write_events(room_id, events) @@ -235,7 +235,7 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = await self.state_storage_controller.get_state_for_event( + state = await self._state_storage_controller.get_state_for_event( event_id ) writer.write_state(room_id, event_id, state) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index fe73978bebd8..3b002c6072eb 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -70,7 +70,7 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.state = hs.get_state_handler() - self.state_storage = hs.get_storage_controllers().state + self._state_storage = hs.get_storage_controllers().state self._auth_handler = hs.get_auth_handler() self.server_name = hs.hostname @@ -203,7 +203,7 @@ async def get_user_ids_changed( continue # mapping from event_id -> state_dict - prev_state_ids = await self.state_storage.get_state_ids_for_events( + prev_state_ids = await self._state_storage.get_state_ids_for_events( event_ids ) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index fa729bf4e617..e5410caf992e 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -139,7 +139,7 @@ async def get_stream( class EventHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() async def get_event( self, @@ -177,7 +177,7 @@ async def get_event( is_peeking = user.to_string() not in users filtered = await filter_events_for_client( - self.storage_controllers, user.to_string(), [event], is_peeking=is_peeking + self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 42bec7e12740..80ee7e7b4e7c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -125,8 +125,8 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - self.state_storage_controller = self.storage_controllers.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -324,7 +324,7 @@ async def _maybe_backfill_inner( # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = await filter_events_for_server( - self.storage_controllers, + self._storage_controllers, self.server_name, events_to_check, redact=False, @@ -660,7 +660,7 @@ async def do_knock( # in the invitee's sync stream. It is stripped out for all other local users. 
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] - context = EventContext.for_outlier(self.storage_controllers) + context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -849,7 +849,7 @@ async def on_invite_request( ) ) - context = EventContext.for_outlier(self.storage_controllers) + context = EventContext.for_outlier(self._storage_controllers) await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -878,7 +878,7 @@ async def do_remotely_reject_invite( await self.federation_client.send_leave(host_list, event) - context = EventContext.for_outlier(self.storage_controllers) + context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -1027,7 +1027,7 @@ async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: if event.internal_metadata.outlier: raise NotFoundError("State not known at event %s" % (event_id,)) - state_groups = await self.state_storage_controller.get_state_groups_ids( + state_groups = await self._state_storage_controller.get_state_groups_ids( room_id, [event_id] ) @@ -1079,7 +1079,7 @@ async def on_backfill_request( ) events = await filter_events_for_server( - self.storage_controllers, origin, events + self._storage_controllers, origin, events ) return events @@ -1112,7 +1112,7 @@ async def get_persisted_pdu( raise AuthError(403, "Host not in room.") events = await filter_events_for_server( - self.storage_controllers, origin, [event] + self._storage_controllers, origin, [event] ) event = events[0] return event @@ -1142,7 +1142,7 @@ async def on_get_missing_events( ) missing_events = await filter_events_for_server( - self.storage_controllers, origin, missing_events + self._storage_controllers, origin, missing_events ) return missing_events @@ -1484,9 +1484,11 @@ async def _sync_partial_state_room( # clear the lazy-loading flag. 
logger.info("Updating current state for %s", room_id) assert ( - self.storage_controllers.persistence is not None + self._storage_controllers.persistence is not None ), "TODO(faster_joins): support for workers" - await self.storage_controllers.persistence.update_current_state(room_id) + await self._storage_controllers.persistence.update_current_state( + room_id + ) logger.info("Clearing partial-state flag for %s", room_id) success = await self.store.clear_partial_state_room(room_id) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 876bdfc7ae39..d78802b34d69 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -67,8 +67,8 @@ def __init__(self, hs: "HomeServer"): ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() - self.storage_controllers = hs.get_storage_controllers() - self.state_storage_controller = self.storage_controllers.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state async def snapshot_all_rooms( self, @@ -198,7 +198,7 @@ async def handle_room(event: RoomsForUser) -> None: event.stream_ordering, ) deferred_room_state = run_in_background( - self.state_storage_controller.get_state_for_events, + self._state_storage_controller.get_state_for_events, [event.event_id], ).addCallback( lambda states: cast(StateMap[EventBase], states[event.event_id]) @@ -219,7 +219,7 @@ async def handle_room(event: RoomsForUser) -> None: ).addErrback(unwrapFirstError) messages = await filter_events_for_client( - self.storage_controllers, user_id, messages + self._storage_controllers, user_id, messages ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) @@ -356,7 +356,7 @@ async def _room_initial_sync_parted( member_event_id: str, is_peeking: bool, ) -> JsonDict: - room_state = await self.state_storage_controller.get_state_for_event( + room_state = await self._state_storage_controller.get_state_for_event( member_event_id ) @@ -372,7 +372,7 @@ async def _room_initial_sync_parted( ) messages = await filter_events_for_client( - self.storage_controllers, user_id, messages, is_peeking=is_peeking + self._storage_controllers, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) @@ -477,7 +477,7 @@ async def get_receipts() -> List[JsonDict]: ) messages = await filter_events_for_client( - self.storage_controllers, user_id, messages, is_peeking=is_peeking + self._storage_controllers, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a78f6cd3d995..bdcd1818774c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -84,8 +84,8 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - self.state_storage_controller = self.storage_controllers.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._event_serializer = hs.get_event_client_serializer() self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages @@ -132,7 +132,7 @@ async def get_room_data( assert ( membership_event_id is not None ), 
"check_user_in_room_or_world_readable returned invalid data" - room_state = await self.state_storage_controller.get_state_for_events( + room_state = await self._state_storage_controller.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -193,7 +193,7 @@ async def get_state_events( # check whether the user is in the room at that time to determine # whether they should be treated as peeking. - state_map = await self.state_storage_controller.get_state_for_event( + state_map = await self._state_storage_controller.get_state_for_event( last_event.event_id, StateFilter.from_types([(EventTypes.Member, user_id)]), ) @@ -206,7 +206,7 @@ async def get_state_events( is_peeking = not joined visible_events = await filter_events_for_client( - self.storage_controllers, + self._storage_controllers, user_id, [last_event], filter_send_to_client=False, @@ -215,7 +215,7 @@ async def get_state_events( if visible_events: room_state_events = ( - await self.state_storage_controller.get_state_for_events( + await self._state_storage_controller.get_state_for_events( [last_event.event_id], state_filter=state_filter ) ) @@ -247,7 +247,7 @@ async def get_state_events( membership_event_id is not None ), "check_user_in_room_or_world_readable returned invalid data" room_state_events = ( - await self.state_storage_controller.get_state_for_events( + await self._state_storage_controller.get_state_for_events( [membership_event_id], state_filter=state_filter ) ) @@ -406,7 +406,7 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -1021,7 +1021,7 @@ async def create_new_client_event( # after it is created if builder.internal_metadata.outlier: event.internal_metadata.outlier = True - context = EventContext.for_outlier(self.storage_controllers) + context = EventContext.for_outlier(self._storage_controllers) elif ( event.type == EventTypes.MSC2716_INSERTION and state_event_ids @@ -1434,7 +1434,7 @@ async def persist_and_notify_client_event( """ extra_users = extra_users or [] - assert self.storage_controllers.persistence is not None + assert self._storage_controllers.persistence is not None assert self._events_shard_config.should_handle( self._instance_name, event.room_id ) @@ -1668,7 +1668,7 @@ async def persist_and_notify_client_event( event, event_pos, max_stream_token, - ) = await self.storage_controllers.persistence.persist_event( + ) = await self._storage_controllers.persistence.persist_event( event, context=context, backfilled=backfilled ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index c4d0b2d3e237..acabee62ec38 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -129,8 +129,8 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - self.state_storage_controller = self.storage_controllers.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() @@ 
-352,7 +352,7 @@ async def _purge_history( self._purges_in_progress_by_room.add(room_id) try: async with self.pagination_lock.write(room_id): - await self.storage_controllers.purge_events.purge_history( + await self._storage_controllers.purge_events.purge_history( room_id, token, delete_local_events ) logger.info("[purge] complete") @@ -414,7 +414,7 @@ async def purge_room(self, room_id: str, force: bool = False) -> None: if joined: raise SynapseError(400, "Users are still joined to this room") - await self.storage_controllers.purge_events.purge_room(room_id) + await self._storage_controllers.purge_events.purge_room(room_id) async def get_messages( self, @@ -520,7 +520,7 @@ async def get_messages( events = await event_filter.filter(events) events = await filter_events_for_client( - self.storage_controllers, + self._storage_controllers, user_id, events, is_peeking=(member_event_id is None), @@ -542,7 +542,7 @@ async def get_messages( (EventTypes.Member, event.sender) for event in events ) - state_ids = await self.state_storage_controller.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) @@ -656,7 +656,7 @@ async def _shutdown_and_purge_room( 400, "Users are still joined to this room" ) - await self.storage_controllers.purge_events.purge_room(room_id) + await self._storage_controllers.purge_events.purge_room(room_id) logger.info("complete") self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6fd4af932a0b..5c91d33f583f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1192,8 +1192,8 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - self.state_storage_controller = self.storage_controllers.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._relations_handler = hs.get_relations_handler() async def get_event_context( @@ -1236,7 +1236,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: if use_admin_priviledge: return events return await filter_events_for_client( - self.storage_controllers, + self._storage_controllers, user.to_string(), events, is_peeking=is_peeking, @@ -1296,7 +1296,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: # first? Shouldn't we be consistent with /sync? 
# https://github.com/matrix-org/matrix-doc/issues/687 - state = await self.state_storage_controller.get_state_for_events( + state = await self._state_storage_controller.get_state_for_events( [last_event_id], state_filter=state_filter ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index d65126248e0c..1414e575d6fc 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -17,7 +17,7 @@ class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main - self.state_storage_controller = hs.get_storage_controllers().state + self._state_storage_controller = hs.get_storage_controllers().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -141,7 +141,7 @@ async def get_most_recent_full_state_ids_from_event_id_list( ) = await self.store.get_max_depth_of(event_ids) # mapping from (type, state_key) -> state_event_id assert most_recent_event_id is not None - prev_state_map = await self.state_storage_controller.get_state_ids_for_event( + prev_state_map = await self._state_storage_controller.get_state_ids_for_event( most_recent_event_id ) # List of state event ID's diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 40a7e18586fe..659f99f7e2a2 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -55,8 +55,8 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self._event_serializer = hs.get_event_client_serializer() self._relations_handler = hs.get_relations_handler() - self.storage_controllers = hs.get_storage_controllers() - self.state_storage_controller = self.storage_controllers.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.auth = hs.get_auth() async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: @@ -460,7 +460,7 @@ async def _search_by_rank( filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self.storage_controllers, user.to_string(), filtered_events + self._storage_controllers, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -559,7 +559,7 @@ async def _search_by_recent( filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self.storage_controllers, user.to_string(), filtered_events + self._storage_controllers, user.to_string(), filtered_events ) room_events.extend(events) @@ -644,11 +644,11 @@ async def _calculate_event_contexts( ) events_before = await filter_events_for_client( - self.storage_controllers, user.to_string(), res.events_before + self._storage_controllers, user.to_string(), res.events_before ) events_after = await filter_events_for_client( - self.storage_controllers, user.to_string(), res.events_after + self._storage_controllers, user.to_string(), res.events_after ) context: JsonDict = { @@ -677,7 +677,7 @@ async def _calculate_event_contexts( [(EventTypes.Member, sender) for sender in senders] ) - state = await self.state_storage_controller.get_state_for_event( + state = await self._state_storage_controller.get_state_for_event( last_event_id, state_filter ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 048b3d06bf36..b5859dcb28ca 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -238,8 +238,8 @@ def 
__init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.auth = hs.get_auth() - self.storage_controllers = hs.get_storage_controllers() - self.state_storage_controller = self.storage_controllers.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state # TODO: flush cache entries on subsequent sync request. # Once we get the next /sync request (ie, one with the same access token @@ -512,7 +512,7 @@ async def _load_filtered_recents( current_state_ids = frozenset(current_state_ids_map.values()) recents = await filter_events_for_client( - self.storage_controllers, + self._storage_controllers, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -580,7 +580,7 @@ async def _load_filtered_recents( current_state_ids = frozenset(current_state_ids_map.values()) loaded_recents = await filter_events_for_client( - self.storage_controllers, + self._storage_controllers, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, @@ -630,7 +630,7 @@ async def get_state_after_event( event: event of interest state_filter: The state filter used to fetch state from the database. """ - state_ids = await self.state_storage_controller.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( event.event_id, state_filter=state_filter or StateFilter.all() ) if event.is_state(): @@ -710,7 +710,7 @@ async def compute_summary( return None last_event = last_events[-1] - state_ids = await self.state_storage_controller.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -889,13 +889,13 @@ async def compute_state_delta( if full_state: if batch: current_state_ids = ( - await self.state_storage_controller.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) ) state_ids = ( - await self.state_storage_controller.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) ) @@ -917,7 +917,7 @@ async def compute_state_delta( elif batch.limited: if batch: state_at_timeline_start = ( - await self.state_storage_controller.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) ) @@ -952,7 +952,7 @@ async def compute_state_delta( if batch: current_state_ids = ( - await self.state_storage_controller.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) ) @@ -984,7 +984,7 @@ async def compute_state_delta( # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = await self.state_storage_controller.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, # we only want members! 
state_filter=StateFilter.from_types( diff --git a/synapse/notifier.py b/synapse/notifier.py index ed3a4d58fa86..25c70802e17f 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -221,7 +221,7 @@ def __init__(self, hs: "HomeServer"): self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} self.hs = hs - self.storage_controllers = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() self.event_sources = hs.get_event_sources() self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] @@ -623,7 +623,7 @@ async def check_for_updates( if name == "room": new_events = await filter_events_for_client( - self.storage_controllers, + self._storage_controllers, user.to_string(), new_events, is_peeking=is_peeking, diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index c437bd808fa9..e96fb45e9f55 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -65,7 +65,7 @@ class HttpPusher(Pusher): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): super().__init__(hs, pusher_config) - self.storage_controllers = self.hs.get_storage_controllers() + self._storage_controllers = self.hs.get_storage_controllers() self.app_display_name = pusher_config.app_display_name self.device_display_name = pusher_config.device_display_name self.pushkey_ts = pusher_config.ts @@ -344,7 +344,7 @@ async def _build_notification_dict( return d ctx = await push_tools.get_context_for_event( - self.storage_controllers, event, self.user_id + self._storage_controllers, event, self.user_id ) d = { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index cb9a8f2c1337..63aefd07f55c 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -114,10 +114,10 @@ def __init__( self.send_email_handler = hs.get_send_email_handler() self.store = self.hs.get_datastores().main - self.state_storage_controller = self.hs.get_storage_controllers().state + self._state_storage_controller = self.hs.get_storage_controllers().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() - self.storage_controllers = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() self.app_name = app_name self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects @@ -456,7 +456,7 @@ async def _get_notif_vars( } the_events = await filter_events_for_client( - self.storage_controllers, user_id, results.events_before + self._storage_controllers, user_id, results.events_before ) the_events.append(notif_event) @@ -494,7 +494,7 @@ async def _get_message_vars( ) else: # Attempt to check the historical state for the room. - historical_state = await self.state_storage_controller.get_state_for_event( + historical_state = await self._state_storage_controller.get_state_for_event( event.event_id, StateFilter.from_types((type_state_key,)) ) sender_state_event = historical_state.get(type_state_key) @@ -768,7 +768,7 @@ async def _make_summary_text_from_member_events( else: # Attempt to check the historical state for the room. 
historical_state = ( - await self.state_storage_controller.get_state_for_event( + await self._state_storage_controller.get_state_for_event( event_id, StateFilter.from_types((type_state_key,)) ) ) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 2d2d06bcf7d6..eed29cd59739 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -69,7 +69,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() @@ -133,7 +133,7 @@ async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # ty event.internal_metadata.outlier = event_payload["outlier"] context = EventContext.deserialize( - self.storage_controllers, event_payload["context"] + self._storage_controllers, event_payload["context"] ) event_and_contexts.append((event, context)) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index aacabfac4a7e..c2b2588ea548 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -70,7 +70,7 @@ def __init__(self, hs: "HomeServer"): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() @staticmethod @@ -128,7 +128,7 @@ async def _handle_request( # type: ignore[override] requester = Requester.deserialize(self.store, content["requester"]) context = EventContext.deserialize( - self.storage_controllers, content["context"] + self._storage_controllers, content["context"] ) ratelimit = content["ratelimit"] diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 2cf3f1a4c9c3..8ddce83b830d 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() self.user_id = self.register_user("u1", "pass") self.user_tok = self.login("u1", "pass") @@ -87,7 +87,7 @@ def test_serialize_deserialize_state_prev(self): def _check_serialize_deserialize(self, event, context): serialized = self.get_success(context.serialize(event, self.store)) - d_context = EventContext.deserialize(self.storage, serialized) + d_context = EventContext.deserialize(self._storage_controllers, serialized) self.assertEqual(context.state_group, d_context.state_group) self.assertEqual(context.rejected, d_context.rejected) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index b5779c485f83..44da96c792fe 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -37,7 +37,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_event_creation_handler() - self.persist_event_storage_controller = ( + self._persist_event_storage_controller = ( self.hs.get_storage_controllers().persistence ) @@ -67,7 +67,7 @@ def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]: ) ) self.get_success( - 
self.persist_event_storage_controller.persist_event( + self._persist_event_storage_controller.persist_event( memberEvent, memberEventContext ) ) @@ -133,7 +133,7 @@ def test_duplicated_txn_id(self): self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( - self.persist_event_storage_controller.persist_event(event3, context) + self._persist_event_storage_controller.persist_event(event3, context) ) # Assert that the returned values match those from the initial event @@ -147,7 +147,7 @@ def test_duplicated_txn_id(self): self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( - self.persist_event_storage_controller.persist_events([(event3, context)]) + self._persist_event_storage_controller.persist_events([(event3, context)]) ) ret_event4 = events[0] @@ -170,7 +170,7 @@ def test_duplicated_txn_id_one_call(self): self.assertNotEqual(event1.event_id, event2.event_id) events, _ = self.get_success( - self.persist_event_storage_controller.persist_events( + self._persist_event_storage_controller.persist_events( [(event1, context1), (event2, context2)] ) ) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 9b41d9309122..c5705256e6fa 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -32,7 +32,7 @@ def prepare(self, reactor, clock, hs): self.master_store = hs.get_datastores().main self.slaved_store = self.worker_hs.get_datastores().main - self.storage = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() def replicate(self): """Tell the master side of replication that something has happened, and then diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 297a9e77f8c3..6d3d4afe52c7 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -262,7 +262,9 @@ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): ) msg, msgctx = self.build_event() self.get_success( - self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) + self._storage_controllers.persistence.persist_events( + [(j2, j2ctx), (msg, msgctx)] + ) ) self.replicate() @@ -323,12 +325,14 @@ def persist(self, backfill=False, **kwargs): if backfill: self.get_success( - self.storage.persistence.persist_events( + self._storage_controllers.persistence.persist_events( [(event, context)], backfilled=True ) ) else: - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 20010e028350..1b7ee08ab2c0 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -88,7 +88,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.clock = clock - self.storage_controllers = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() self.virtual_user_id, _ = self.register_appservice_user( "as_user_potato", self.appservice.token @@ -168,7 +168,7 @@ def test_same_state_groups_for_whole_historical_batch(self) -> None: # Fetch the state_groups state_group_map = self.get_success( - 
self.storage_controllers.state.get_state_groups_ids( + self._storage_controllers.state.get_state_groups_ids( room_id, historical_event_ids ) ) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 27b20b6b048e..a76718e8f995 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage_controllers().persistence + self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main self.register_user("user", "pass") @@ -71,7 +71,7 @@ def persist_event(self, event, state=None): context = self.get_success( self.state.compute_event_context(event, state_ids_before_event=state) ) - self.get_success(self.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) def assert_extremities(self, expected_extremities): """Assert the current extremities for the room""" @@ -148,7 +148,7 @@ def test_do_not_prune_gap_if_state_different(self): ) ) - self.get_success(self.persistence.persist_event(remote_event_2, context)) + self.get_success(self._persistence.persist_event(remote_event_2, context)) # Check that we haven't dropped the old extremity. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) @@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage_controllers().persistence + self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main def test_remote_user_rooms_cache_invalidated(self): @@ -390,7 +390,7 @@ def test_remote_user_rooms_cache_invalidated(self): ) context = self.get_success(self.state.compute_event_context(remote_event_1)) - self.get_success(self.persistence.persist_event(remote_event_1, context)) + self.get_success(self._persistence.persist_event(remote_event_1, context)) # Call `get_rooms_for_user` to add the remote user to the cache rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) @@ -437,7 +437,7 @@ def test_room_remote_user_cache_invalidated(self): ) context = self.get_success(self.state.compute_event_context(remote_event_1)) - self.get_success(self.persistence.persist_event(remote_event_1, context)) + self.get_success(self._persistence.persist_event(remote_event_1, context)) # Call `get_users_in_room` to add the remote user to the cache users = self.get_success(self.store.get_users_in_room(room_id)) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 78c8744e3b0f..92cd0dfc0557 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -31,7 +31,7 @@ def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) self.store = hs.get_datastores().main - self.storage_controllers = self.hs.get_storage_controllers() + self._storage_controllers = self.hs.get_storage_controllers() def test_purge_history(self): """ @@ -51,7 +51,7 @@ def test_purge_history(self): # Purge everything before this topological token self.get_success( - self.storage_controllers.purge_events.purge_history( + self._storage_controllers.purge_events.purge_history( self.room_id, token_str, True ) ) @@ -81,7 +81,7 @@ def test_purge_history_wont_delete_extrems(self): # Purge everything 
before this topological token f = self.get_failure( - self.storage_controllers.purge_events.purge_history( + self._storage_controllers.purge_events.purge_history( self.room_id, event, True ), SynapseError, @@ -109,7 +109,9 @@ def test_purge_room(self): self.assertIsNotNone(create_event) # Purge everything before this topological token - self.get_success(self.storage_controllers.purge_events.purge_room(self.room_id)) + self.get_success( + self._storage_controllers.purge_events.purge_room(self.room_id) + ) # The events aren't found. self.store._invalidate_get_event_cache(create_event.event_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 8e737332fc06..6c4e63b77cac 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -31,7 +31,7 @@ def default_config(self): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage_controllers() + self._storage = hs.get_storage_controllers() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -71,7 +71,7 @@ def inject_room_member( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -93,7 +93,7 @@ def inject_message(self, room, user, body): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -114,7 +114,7 @@ def inject_redaction(self, room, event_id, user, reason): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -268,7 +268,7 @@ def internal_metadata(self): ) ) - self.get_success(self.storage.persistence.persist_event(event_1, context_1)) + self.get_success(self._storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -287,7 +287,7 @@ def internal_metadata(self): ) ) ) - self.get_success(self.storage.persistence.persist_event(event_2, context_2)) + self.get_success(self._storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) @@ -411,7 +411,7 @@ def test_store_redacted_redaction(self): ) self.get_success( - self.storage.persistence.persist_event(redaction_event, context) + self._storage.persistence.persist_event(redaction_event, context) ) # Now lets jump to the future where we have censored the redaction event diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index fec81a9f712b..d497a19f6336 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -72,7 +72,7 @@ def prepare(self, reactor, clock, hs): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastores().main - self.storage = hs.get_storage_controllers() + self._storage = hs.get_storage_controllers() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -88,7 +88,7 @@ def prepare(self, reactor, clock, hs): def inject_room_event(self, **kwargs): self.get_success( - 
self.storage.persistence.persist_event( + self._storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) ) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 380a1839a11b..f338af6c36d0 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -34,7 +34,7 @@ def setUp(self) -> None: super(FilterEventsForServerTestCase, self).setUp() self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() - self.storage_controllers = self.hs.get_storage_controllers() + self._storage_controllers = self.hs.get_storage_controllers() self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @@ -61,7 +61,7 @@ def test_filtering(self) -> None: filtered = self.get_success( filter_events_for_server( - self.storage_controllers, "test_server", events_to_filter + self._storage_controllers, "test_server", events_to_filter ) ) @@ -83,7 +83,7 @@ def test_filter_outlier(self) -> None: self.assertEqual( self.get_success( filter_events_for_server( - self.storage_controllers, "remote_hs", [outlier] + self._storage_controllers, "remote_hs", [outlier] ) ), [outlier], @@ -94,7 +94,7 @@ def test_filter_outlier(self) -> None: filtered = self.get_success( filter_events_for_server( - self.storage_controllers, "remote_hs", [outlier, evt] + self._storage_controllers, "remote_hs", [outlier, evt] ) ) self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") @@ -106,7 +106,7 @@ def test_filter_outlier(self) -> None: # be redacted) filtered = self.get_success( filter_events_for_server( - self.storage_controllers, "other_server", [outlier, evt] + self._storage_controllers, "other_server", [outlier, evt] ) ) self.assertEqual(filtered[0], outlier) @@ -141,7 +141,7 @@ def test_erased_user(self) -> None: # ... and the filtering happens. 
filtered = self.get_success( filter_events_for_server( - self.storage_controllers, "test_server", events_to_filter + self._storage_controllers, "test_server", events_to_filter ) ) @@ -179,7 +179,7 @@ def _inject_visibility(self, user_id: str, visibility: str) -> EventBase: self.event_creation_handler.create_new_client_event(builder) ) self.get_success( - self.storage_controllers.persistence.persist_event(event, context) + self._storage_controllers.persistence.persist_event(event, context) ) return event @@ -207,7 +207,7 @@ def _inject_room_member( ) self.get_success( - self.storage_controllers.persistence.persist_event(event, context) + self._storage_controllers.persistence.persist_event(event, context) ) return event @@ -231,7 +231,7 @@ def _inject_message( ) self.get_success( - self.storage_controllers.persistence.persist_event(event, context) + self._storage_controllers.persistence.persist_event(event, context) ) return event @@ -250,8 +250,8 @@ def _inject_outlier(self) -> EventBase: event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event.internal_metadata.outlier = True self.get_success( - self.storage_controllers.persistence.persist_event( - event, EventContext.for_outlier(self.storage_controllers) + self._storage_controllers.persistence.persist_event( + event, EventContext.for_outlier(self._storage_controllers) ) ) return event From 2c7a4f7283323c7d1c2d115cac24bc8dc4f8a80e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 May 2022 17:03:51 +0100 Subject: [PATCH 09/12] Fix missing rename --- synapse/state/__init__.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 1854e6ec7024..bf09f5128aac 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -127,7 +127,7 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main - self.state_storage_controller = hs.get_storage_controllers().state + self._state_storage_controller = hs.get_storage_controllers().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() self._storage_controllers = hs.get_storage_controllers() @@ -338,7 +338,7 @@ async def compute_event_context( if not state_group_before_event: state_group_before_event = ( - await self.state_storage_controller.store_state_group( + await self._state_storage_controller.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event_prev_group, @@ -384,12 +384,14 @@ async def compute_event_context( state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = await self.state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=state_ids_after_event, + state_group_after_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, + delta_ids=delta_ids, + current_state_ids=state_ids_after_event, + ) ) return EventContext.with_state( @@ -418,7 +420,7 @@ async def resolve_state_groups_for_events( """ logger.debug("resolve_state_groups event_ids %s", event_ids) - state_groups = await self.state_storage_controller.get_state_group_for_events( + state_groups = await self._state_storage_controller.get_state_group_for_events( event_ids ) @@ -428,13 +430,13 @@ async def 
resolve_state_groups_for_events( state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: (state_group_id,) = state_group_ids_set - state = await self.state_storage_controller.get_state_for_groups( + state = await self._state_storage_controller.get_state_for_groups( state_group_ids_set ) ( prev_group, delta_ids, - ) = await self.state_storage_controller.get_state_group_delta( + ) = await self._state_storage_controller.get_state_group_delta( state_group_id ) return _StateCacheEntry( @@ -448,7 +450,7 @@ async def resolve_state_groups_for_events( room_version = await self.store.get_room_version_id(room_id) - state_to_resolve = await self.state_storage_controller.get_state_for_groups( + state_to_resolve = await self._state_storage_controller.get_state_for_groups( state_group_ids_set ) From 8dbf7fd582d2b8b5b7168ad427124e756add93cc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 May 2022 17:06:00 +0100 Subject: [PATCH 10/12] Spurious newline --- synapse/storage/controllers/state.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index c6fac27888d4..0f099530863b 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -30,7 +30,6 @@ from synapse.types import MutableStateMap, StateMap if TYPE_CHECKING: - from synapse.server import HomeServer from synapse.storage.databases import Databases From 4dd4d30d91f93fde2432f2631b29d428ac45c6ae Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 May 2022 17:07:41 +0100 Subject: [PATCH 11/12] Mention storage controllers --- synapse/storage/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 2817acf2c008..b51ad5d72636 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -20,7 +20,9 @@ database. The `databases` are classes that talk directly to a `DatabasePool` instance and have associated schemas, background updates, etc. On top of those there are classes that provide high level interfaces that combine calls to -multiple `databases`. +multiple `databases`, called storage controllers and are located in the +`controller` module. These are bundled into a single `StorageControllers` class +for ease of use and exposed via `HomeServer.get_storage_controllers()`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are From caf2dba3fd4f1f1de10566ad24ce38f94f848e30 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 May 2022 12:46:08 +0100 Subject: [PATCH 12/12] Fix wording --- synapse/storage/__init__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index b51ad5d72636..bac21ecf9cff 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -18,11 +18,13 @@ against different configurations of databases (e.g. single or multiple databases). The `DatabasePool` class represents connections to a single physical database. The `databases` are classes that talk directly to a `DatabasePool` -instance and have associated schemas, background updates, etc. On top of those -there are classes that provide high level interfaces that combine calls to -multiple `databases`, called storage controllers and are located in the -`controller` module. 
These are bundled into a single `StorageControllers` class -for ease of use and exposed via `HomeServer.get_storage_controllers()`. +instance and have associated schemas, background updates, etc. + +On top of the databases are the StorageControllers, located in the +`synapse.storage.controllers` module. These classes provide high level +interfaces that combine calls to multiple `databases`. They are bundled into the +`StorageControllers` singleton for ease of use, and exposed via +`HomeServer.get_storage_controllers()`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are