From 45821e1ce9613deeced1858ee7f6b939d14fa781 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 7 Sep 2022 21:59:07 +0100 Subject: [PATCH 1/7] Return read-only collections from `@cached` methods Signed-off-by: Sean Quah --- changelog.d/13755.misc | 1 + synapse/app/phone_stats_home.py | 4 +-- synapse/config/room_directory.py | 6 ++--- synapse/handlers/directory.py | 5 ++-- synapse/handlers/receipts.py | 4 +-- synapse/handlers/room.py | 2 +- synapse/handlers/sync.py | 12 +++++---- synapse/push/bulk_push_rule_evaluator.py | 4 +-- synapse/push/push_rule_evaluator.py | 4 +-- synapse/state/__init__.py | 2 +- synapse/storage/controllers/state.py | 5 ++-- .../storage/databases/main/account_data.py | 7 ++--- synapse/storage/databases/main/appservice.py | 14 ++++++++-- synapse/storage/databases/main/devices.py | 11 +++++--- synapse/storage/databases/main/directory.py | 4 +-- .../storage/databases/main/end_to_end_keys.py | 25 ++++++++++------- .../databases/main/event_federation.py | 7 ++--- .../databases/main/monthly_active_users.py | 4 +-- synapse/storage/databases/main/receipts.py | 10 ++++--- .../storage/databases/main/registration.py | 4 +-- synapse/storage/databases/main/relations.py | 26 +++++++++++++----- synapse/storage/databases/main/roommember.py | 27 ++++++++++--------- synapse/storage/databases/main/signatures.py | 6 ++--- synapse/storage/databases/main/tags.py | 8 +++--- .../storage/databases/main/user_directory.py | 4 +-- tests/rest/admin/test_server_notice.py | 4 +-- 26 files changed, 127 insertions(+), 83 deletions(-) create mode 100644 changelog.d/13755.misc diff --git a/changelog.d/13755.misc b/changelog.d/13755.misc new file mode 100644 index 000000000000..662ee00e99d5 --- /dev/null +++ b/changelog.d/13755.misc @@ -0,0 +1 @@ +Re-type hint some collections as read-only. diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 53db1e85b3d8..897dd3edac3e 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -15,7 +15,7 @@ import math import resource import sys -from typing import TYPE_CHECKING, List, Sized, Tuple +from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple from prometheus_client import Gauge @@ -194,7 +194,7 @@ def performance_stats_init() -> None: @wrap_as_background_process("generate_monthly_active_users") async def generate_monthly_active_users() -> None: current_mau_count = 0 - current_mau_count_by_service = {} + current_mau_count_by_service: Mapping[str, int] = {} reserved_users: Sized = () store = hs.get_datastores().main if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 3ed236217fd4..8666c22f010d 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, Collection from matrix_common.regex import glob_to_regex @@ -70,7 +70,7 @@ def is_alias_creation_allowed(self, user_id: str, room_id: str, alias: str) -> b return False def is_publishing_room_allowed( - self, user_id: str, room_id: str, aliases: List[str] + self, user_id: str, room_id: str, aliases: Collection[str] ) -> bool: """Checks if the given user is allowed to publish the room @@ -122,7 +122,7 @@ def __init__(self, option_name: str, rule: JsonDict): except Exception as e: raise ConfigError("Failed to parse glob into regex") from e - def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool: + def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool: """Tests if this rule matches the given user_id, room_id and aliases. Args: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 7127d5aefcb7..1d1b62bfffff 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -14,7 +14,7 @@ import logging import string -from typing import TYPE_CHECKING, Iterable, List, Optional +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( @@ -483,6 +483,7 @@ async def edit_published_room_list( ) ) if canonical_alias: + room_aliases = list(room_aliases) room_aliases.append(canonical_alias) if not self.config.roomdirectory.is_publishing_room_allowed( @@ -525,7 +526,7 @@ async def edit_published_appservice_room_list( async def get_aliases_for_room( self, requester: Requester, room_id: str - ) -> List[str]: + ) -> Sequence[str]: """ Get a list of the aliases that currently point to this room on this server """ diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index d2bdb9c8be79..c6492e689d89 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple from synapse.api.constants import EduTypes, ReceiptTypes from synapse.appservice import ApplicationService @@ -174,7 +174,7 @@ def __init__(self, hs: "HomeServer"): @staticmethod def filter_out_private_receipts( - rooms: List[JsonDict], user_id: str + rooms: Sequence[JsonDict], user_id: str ) -> List[JsonDict]: """ Filters a list of serialized receipts (as returned by /sync and /initialSync) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 33e9a870022a..8c5ce303378a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1872,6 +1872,6 @@ async def shutdown_room( return { "kicked_users": kicked_users, "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, + "local_aliases": list(aliases_for_room), "new_room_id": new_room_id, } diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2d95b1fa24e6..bf6a6d88084b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1377,7 +1377,7 @@ async def generate_sync_result( one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - unused_fallback_key_types = ( + unused_fallback_key_types = list( await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) ) @@ -1573,7 +1573,7 @@ async def _generate_sync_entry_for_to_device( async def _generate_sync_entry_for_account_data( self, sync_result_builder: "SyncResultBuilder" - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonDict]]: """Generates the account data portion of the sync response. Account data (called "Client Config" in the spec) can be set either globally @@ -1608,6 +1608,7 @@ async def _generate_sync_entry_for_account_data( ) if push_rules_changed: + global_account_data = dict(global_account_data) global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) @@ -1617,6 +1618,7 @@ async def _generate_sync_entry_for_account_data( account_data_by_room, ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) + global_account_data = dict(global_account_data) global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) @@ -1693,7 +1695,7 @@ async def _generate_sync_entry_for_presence( async def _generate_sync_entry_for_rooms( self, sync_result_builder: "SyncResultBuilder", - account_data_by_room: Dict[str, Dict[str, JsonDict]], + account_data_by_room: Mapping[str, Mapping[str, JsonDict]], ) -> Tuple[Set[str], Set[str], Set[str], Set[str]]: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -2170,8 +2172,8 @@ async def _generate_room_entry( sync_result_builder: "SyncResultBuilder", room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], - tags: Optional[Dict[str, Dict[str, Any]]], - account_data: Dict[str, JsonDict], + tags: Optional[Mapping[str, Mapping[str, Any]]], + account_data: Mapping[str, JsonDict], always_include: bool = False, ) -> None: """Populates the `joined` and `archived` section of `sync_result_builder` diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d1caf8a0f7a0..03c8ca4bbb13 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -17,13 +17,13 @@ import logging from typing import ( TYPE_CHECKING, + AbstractSet, Collection, Dict, Iterable, List, Mapping, Optional, - Set, Tuple, Union, ) @@ -199,7 +199,7 @@ async def _get_power_levels_and_sender_level( async def _get_mutual_relations( self, event: EventBase, rules: Iterable[Tuple[PushRule, bool]] - ) -> Dict[str, Set[Tuple[str, str]]]: + ) -> Dict[str, AbstractSet[Tuple[str, str]]]: """ Fetch event metadata for events which related to the same event as the given event. diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 3c5632cd9153..23206792f921 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,6 +16,7 @@ import logging import re from typing import ( + AbstractSet, Any, Dict, List, @@ -23,7 +24,6 @@ Optional, Pattern, Sequence, - Set, Tuple, Union, ) @@ -131,7 +131,7 @@ def __init__( room_member_count: int, sender_power_level: int, power_levels: Dict[str, Union[int, Dict[str, int]]], - relations: Dict[str, Set[Tuple[str, str]]], + relations: Dict[str, AbstractSet[Tuple[str, str]]], relations_match_enabled: bool, ): self._event = event diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3787d35b244f..47169983a6fa 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -210,7 +210,7 @@ async def compute_state_after_events( return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( - self, room_id: str, latest_event_ids: List[str] + self, room_id: str, latest_event_ids: Collection[str] ) -> Set[str]: """ Get the users IDs who are currently in a room. diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index bbe568bf053e..5f479b5a4a90 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -23,6 +23,7 @@ List, Mapping, Optional, + Sequence, Tuple, ) @@ -523,7 +524,7 @@ async def get_current_state_event( ) return state_map.get(key) - async def get_current_hosts_in_room(self, room_id: str) -> List[str]: + async def get_current_hosts_in_room(self, room_id: str) -> Sequence[str]: """Get current hosts in room based on current state.""" await self._partial_state_room_tracker.await_full_state(room_id) @@ -532,7 +533,7 @@ async def get_current_hosts_in_room(self, room_id: str) -> List[str]: async def get_users_in_room_with_profiles( self, room_id: str - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """ Get the current users in the room with their profiles. If the room is currently partial-stated, this will block until the room has diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c38b8a9e5a7e..50b301bce93a 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -21,6 +21,7 @@ FrozenSet, Iterable, List, + Mapping, Optional, Tuple, cast, @@ -132,7 +133,7 @@ def get_max_account_data_stream_id(self) -> int: @cached() async def get_account_data_for_user( self, user_id: str - ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: + ) -> Tuple[Mapping[str, JsonDict], Mapping[str, Mapping[str, JsonDict]]]: """Get all the client account_data for a user. Args: @@ -198,7 +199,7 @@ async def get_global_account_data_by_type_for_user( @cached(num_args=2, tree=True) async def get_account_data_for_room( self, user_id: str, room_id: str - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonDict]: """Get all the client account_data for a user for a room. Args: @@ -327,7 +328,7 @@ def get_updated_room_account_data_txn( async def get_updated_account_data_for_user( self, user_id: str, stream_id: int - ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: + ) -> Tuple[Mapping[str, JsonDict], Mapping[str, Mapping[str, JsonDict]]]: """Get all the client account_data for a that's changed for a user Args: diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 64b70a7b28ee..e4fa8aa72ee4 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -14,7 +14,17 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Pattern, + Sequence, + Tuple, + cast, +) from synapse.appservice import ( ApplicationService, @@ -156,7 +166,7 @@ async def get_app_service_users_in_room( room_id: str, app_service: "ApplicationService", cache_context: _CacheContext, - ) -> List[str]: + ) -> Sequence[str]: users_in_room = await self.get_users_in_room( room_id, on_invalidate=cache_context.invalidate ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 5d700ca6c307..d190530a49cb 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -22,6 +22,7 @@ Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -135,7 +136,9 @@ def __init__( self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) - async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: + async def count_devices_by_users( + self, user_ids: Optional[Collection[str]] = None + ) -> int: """Retrieve number of all devices of given users. Only returns number of devices that are not marked as hidden. @@ -146,7 +149,7 @@ async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> """ def count_devices_by_users_txn( - txn: LoggingTransaction, user_ids: List[str] + txn: LoggingTransaction, user_ids: Collection[str] ) -> int: sql = """ SELECT count(*) @@ -706,7 +709,7 @@ async def get_user_devices_from_cache( device = await self._get_cached_user_device(user_id, device_id) results.setdefault(user_id, {})[device_id] = device else: - results[user_id] = await self.get_cached_devices_for_user(user_id) + results[user_id] = dict(await self.get_cached_devices_for_user(user_id)) set_tag("in_cache", str(results)) set_tag("not_in_cache", str(user_ids_not_in_cache)) @@ -724,7 +727,7 @@ async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDic return db_to_json(content) @cached() - async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: + async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]: devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 5903fdaf007a..44aa181174ac 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Sequence, Tuple import attr @@ -74,7 +74,7 @@ async def get_room_alias_creator(self, room_alias: str) -> str: ) @cached(max_entries=5000) - async def get_aliases_for_room(self, room_id: str) -> List[str]: + async def get_aliases_for_room(self, room_id: str) -> Sequence[str]: return await self.db_pool.simple_select_onecol( "room_aliases", {"room_id": room_id}, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 8e9e1b0b4b41..375c5e7a03a7 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -20,7 +20,9 @@ Dict, Iterable, List, + Mapping, Optional, + Sequence, Tuple, Union, cast, @@ -660,7 +662,7 @@ def _set_e2e_fallback_keys_txn( @cached(max_entries=10000) async def get_e2e_unused_fallback_key_types( self, user_id: str, device_id: str - ) -> List[str]: + ) -> Sequence[str]: """Returns the fallback key types that have an unused key. Args: @@ -700,7 +702,7 @@ async def get_e2e_cross_signing_key( return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: + def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]: """Dummy function. Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -713,7 +715,7 @@ def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Optional[Dict[str, JsonDict]]]: + ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -734,7 +736,7 @@ async def _get_bare_e2e_cross_signing_keys_bulk( ) # The `Optional` comes from the `@cachedList` decorator. - return cast(Dict[str, Optional[Dict[str, JsonDict]]], result) + return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result) def _get_bare_e2e_cross_signing_keys_bulk_txn( self, @@ -893,7 +895,7 @@ def _get_e2e_cross_signing_signatures_txn( @cancellable async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Dict[str, JsonDict]]]: + ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -909,11 +911,14 @@ async def get_e2e_cross_signing_keys_bulk( result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: - result = await self.db_pool.runInteraction( - "get_e2e_cross_signing_signatures", - self._get_e2e_cross_signing_signatures_txn, - result, - from_user_id, + result = cast( + Dict[str, Optional[Mapping[str, JsonDict]]], + await self.db_pool.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ), ) return result diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ca47a22bf179..1691ac990b38 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -21,6 +21,7 @@ Iterable, List, Optional, + Sequence, Set, Tuple, cast, @@ -950,7 +951,7 @@ def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]: ) @cached(max_entries=5000, iterable=True) - async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: + async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]: return await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, @@ -980,7 +981,7 @@ def _get_min_depth_interaction( @cancellable async def get_forward_extremities_for_room_at_stream_ordering( self, room_id: str, stream_ordering: int - ) -> List[str]: + ) -> Sequence[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -1013,7 +1014,7 @@ async def get_forward_extremities_for_room_at_stream_ordering( @cached(max_entries=5000, num_args=2) async def _get_forward_extremeties_for_room( self, room_id: str, stream_ordering: int - ) -> List[str]: + ) -> Sequence[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index efd136a86474..d5567ad3a950 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -95,7 +95,7 @@ def _count_users(txn: LoggingTransaction) -> int: return await self.db_pool.runInteraction("count_users", _count_users) @cached(num_args=0) - async def get_monthly_active_count_by_service(self) -> Dict[str, int]: + async def get_monthly_active_count_by_service(self) -> Mapping[str, int]: """Generates current count of monthly active users broken down by service. A service is typically an appservice but also includes native matrix users. Since the `monthly_active_users` table is populated from the `user_ips` table diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 124c70ad37b6..6e26cc2126e5 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -21,7 +21,9 @@ Dict, Iterable, List, + Mapping, Optional, + Sequence, Tuple, cast, ) @@ -312,7 +314,7 @@ async def get_linearized_receipts_for_rooms( async def get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> Sequence[dict]: """Get receipts for a single room for sending to clients. Args: @@ -335,7 +337,7 @@ async def get_linearized_receipts_for_room( @cached(tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[JsonDict]: + ) -> Sequence[JsonDict]: """See get_linearized_receipts_for_room""" def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: @@ -378,7 +380,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: ) async def _get_linearized_receipts_for_rooms( self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None - ) -> Dict[str, List[JsonDict]]: + ) -> Dict[str, Sequence[JsonDict]]: if not room_ids: return {} @@ -438,7 +440,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: ) async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonDict]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7fb9c801dac8..46d44e99f8f0 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast import attr @@ -164,7 +164,7 @@ def __init__( ) @cached() - async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]: """Deprecated: use get_userinfo_by_id instead""" return await self.db_pool.simple_select_one( table="users", diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7bd27790ebfe..c101bdf0a2b6 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,12 +14,14 @@ import logging from typing import ( + AbstractSet, Collection, Dict, FrozenSet, Iterable, List, Optional, + Sequence, Set, Tuple, Union, @@ -66,7 +68,7 @@ async def get_relations_for_event( direction: str = "b", from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: + ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -243,7 +245,7 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool: @cached(tree=True) async def get_aggregation_groups_for_event( self, event_id: str, room_id: str, limit: int = 5 - ) -> List[JsonDict]: + ) -> Sequence[JsonDict]: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. @@ -764,7 +766,7 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool: @cached(iterable=True) async def get_mutual_event_relations_for_rel_type( self, event_id: str, relation_type: str - ) -> Set[Tuple[str, str]]: + ) -> AbstractSet[Tuple[str, str]]: raise NotImplementedError() @cachedList( @@ -773,7 +775,7 @@ async def get_mutual_event_relations_for_rel_type( ) async def get_mutual_event_relations( self, event_id: str, relation_types: Collection[str] - ) -> Dict[str, Set[Tuple[str, str]]]: + ) -> Dict[str, AbstractSet[Tuple[str, str]]]: """ Fetch event metadata for events which related to the same event as the given event. @@ -810,8 +812,20 @@ def _get_event_relations( result[rel_type].add((sender, type)) return result - return await self.db_pool.runInteraction( - "get_event_relations", _get_event_relations + # Cast the values from `Set`s to `AbstractSet`s. + return cast( + Dict[ + str, + AbstractSet[ + Tuple[ + str, + str, + ] + ], + ], + await self.db_pool.runInteraction( + "get_event_relations", _get_event_relations + ), ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 6e1ff5626bcf..eff58ff27644 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,6 +15,7 @@ import logging from typing import ( TYPE_CHECKING, + AbstractSet, Callable, Collection, Dict, @@ -23,6 +24,7 @@ List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -186,7 +188,7 @@ def _check_safe_current_state_events_membership_updated_txn( ) @cached(max_entries=100000, iterable=True) - async def get_users_in_room(self, room_id: str) -> List[str]: + async def get_users_in_room(self, room_id: str) -> Sequence[str]: """ Returns a list of users in the room sorted by longest in the room first (aka. with the lowest depth). This is done to match the sort in @@ -242,9 +244,7 @@ def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[s return [r[0] for r in txn] @cached() - def get_user_in_room_with_profile( - self, room_id: str, user_id: str - ) -> Dict[str, ProfileInfo]: + def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo: raise NotImplementedError() @cachedList( @@ -298,7 +298,7 @@ def _get_subset_users_in_room_with_profiles( @cached(max_entries=100000, iterable=True) async def get_users_in_room_with_profiles( self, room_id: str - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """Get a mapping from user ID to profile information for all users in a given room. The profile information comes directly from this room's `m.room.member` @@ -337,7 +337,7 @@ def _get_users_in_room_with_profiles( ) @cached(max_entries=100000) - async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: + async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]: """Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: @@ -435,7 +435,7 @@ async def get_number_joined_users_in_room(self, room_id: str) -> int: @cached() async def get_invited_rooms_for_local_user( self, user_id: str - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Get all the rooms the *local* user is invited to. Args: @@ -498,10 +498,11 @@ async def get_rooms_for_local_user_where_membership_is( ) # Now we filter out forgotten and excluded rooms - rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id) + rooms_to_exclude = await self.get_forgotten_rooms_for_user(user_id) if excluded_rooms is not None: - rooms_to_exclude.update(set(excluded_rooms)) + rooms_to_exclude = set(rooms_to_exclude) + rooms_to_exclude.update(excluded_rooms) return [room for room in rooms if room.room_id not in rooms_to_exclude] @@ -551,7 +552,7 @@ def _get_rooms_for_local_user_where_membership_is_txn( return results @cached(iterable=True) - async def get_local_users_in_room(self, room_id: str) -> List[str]: + async def get_local_users_in_room(self, room_id: str) -> Sequence[str]: """ Retrieves a list of the current roommembers who are local to the server. """ @@ -859,7 +860,7 @@ async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]: """Returns the set of users who share a room with `user_id`""" room_ids = await self.get_rooms_for_user(user_id) - user_who_share_room = set() + user_who_share_room: Set[str] = set() for room_id in room_ids: user_ids = await self.get_users_in_room(room_id) user_who_share_room.update(user_ids) @@ -1021,7 +1022,7 @@ async def _check_host_room_membership( return True @cached(iterable=True, max_entries=10000) - async def get_current_hosts_in_room(self, room_id: str) -> List[str]: + async def get_current_hosts_in_room(self, room_id: str) -> Sequence[str]: """ Get current hosts in room based on current state. @@ -1225,7 +1226,7 @@ def f(txn: LoggingTransaction) -> int: return count == 0 @cached() - async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]: + async def get_forgotten_rooms_for_user(self, user_id: str) -> AbstractSet[str]: """Gets all rooms the user has forgotten. Args: diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index 05da15074a73..5dcb1fc0b5f4 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Dict, List, Tuple +from typing import Collection, Dict, List, Mapping, Tuple from unpaddedbase64 import encode_base64 @@ -26,7 +26,7 @@ class SignatureWorkerStore(EventsWorkerStore): @cached() - def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]: + def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]: # This is a dummy function to allow get_event_reference_hashes # to use its cache raise NotImplementedError() @@ -36,7 +36,7 @@ def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]] ) async def get_event_reference_hashes( self, event_ids: Collection[str] - ) -> Dict[str, Dict[str, bytes]]: + ) -> Mapping[str, Mapping[str, bytes]]: """Get all hashes for given events. Args: diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index b0f5de67a30d..082a09880d25 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import Any, Dict, Iterable, List, Tuple, cast +from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast from synapse.replication.tcp.streams import TagAccountDataStream from synapse.storage._base import db_to_json @@ -31,7 +31,9 @@ class TagsWorkerStore(AccountDataWorkerStore): @cached() - async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: + async def get_tags_for_user( + self, user_id: str + ) -> Mapping[str, Mapping[str, JsonDict]]: """Get all the tags for a user. @@ -131,7 +133,7 @@ def get_tag_content( async def get_updated_tags( self, user_id: str, stream_id: int - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonDict]]: """Get all the tags for the rooms where the tags have changed since the given version diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index ddb25b5cea7f..d5829dd20556 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -16,9 +16,9 @@ import re from typing import ( TYPE_CHECKING, - Dict, Iterable, List, + Mapping, Optional, Sequence, Set, @@ -581,7 +581,7 @@ def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None: ) @cached() - async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]: + async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index a2f347f666e8..f71ff46d8777 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Sequence from twisted.test.proto_helpers import MemoryReactor @@ -558,7 +558,7 @@ def test_update_notice_user_avatar_when_changed(self) -> None: def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Check invite and room membership status of a user. Args From ed43ae3c0865174684f51e9883c4e631e4a34456 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 31 Jan 2023 17:14:14 +0000 Subject: [PATCH 2/7] WIP cleanup merge --- synapse/events/builder.py | 6 +++--- synapse/federation/federation_server.py | 3 ++- synapse/handlers/sync.py | 1 - synapse/push/bulk_push_rule_evaluator.py | 1 - synapse/storage/controllers/state.py | 3 +-- synapse/storage/databases/main/devices.py | 4 +++- synapse/storage/databases/main/event_federation.py | 4 +++- synapse/storage/databases/main/relations.py | 6 +++--- synapse/storage/databases/main/roommember.py | 2 -- 9 files changed, 15 insertions(+), 15 deletions(-) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 94dd1298e177..c82745275f94 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union import attr from signedjson.types import SigningKey @@ -103,7 +103,7 @@ def is_state(self) -> bool: async def build( self, - prev_event_ids: List[str], + prev_event_ids: Collection[str], auth_event_ids: Optional[List[str]], depth: Optional[int] = None, ) -> EventBase: @@ -136,7 +136,7 @@ async def build( format_version = self.room_version.event_format # The types of auth/prev events changes between event versions. - prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]] auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] if format_version == EventFormatVersions.ROOM_V1_V2: auth_events = await self._store.add_event_hashes(auth_event_ids) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index c9a6dfd1a4bf..28571597cf25 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -23,6 +23,7 @@ Collection, Dict, List, + Mapping, Optional, Tuple, Union, @@ -1506,7 +1507,7 @@ async def on_query(self, query_type: str, args: dict) -> JsonDict: def _get_event_ids_for_partial_state_join( join_event: EventBase, prev_state_ids: StateMap[str], - summary: Dict[str, MemberSummary], + summary: Mapping[str, MemberSummary], ) -> Collection[str]: """Calculate state to be returned in a partial_state send_join diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6e60e9501f9a..8c903a7d80c4 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1818,7 +1818,6 @@ async def _generate_sync_entry_for_rooms( self, sync_result_builder: "SyncResultBuilder", account_data_by_room: Mapping[str, Mapping[str, JsonDict]], - account_data_by_room: Dict[str, Dict[str, JsonDict]], ) -> Tuple[AbstractSet[str], AbstractSet[str], AbstractSet[str], AbstractSet[str]]: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 9e1206d86055..88cfc05d0552 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -16,7 +16,6 @@ import logging from typing import ( TYPE_CHECKING, - AbstractSet, Any, Collection, Dict, diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 58de894b4a88..3ef1d7249926 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -14,6 +14,7 @@ import logging from typing import ( TYPE_CHECKING, + AbstractSet, Any, Awaitable, Callable, @@ -23,7 +24,6 @@ List, Mapping, Optional, - Set, Tuple, ) @@ -539,7 +539,6 @@ async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: return await self.stores.main.get_current_hosts_in_room(room_id) async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: ->>>>>>> 3dfc4a08dc2e77178f2c2af68dc14b32da2d8b8f """Get current hosts in room based on current state. Blocks until we have full state for the given room. This only happens for rooms diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index beb9c9596ad4..f44205dcdab2 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -202,7 +202,9 @@ def _invalidate_caches_for_devices( def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() - async def count_devices_by_users(self, user_ids: Optional[Collection[str]] = None) -> int: + async def count_devices_by_users( + self, user_ids: Optional[Collection[str]] = None + ) -> int: """Retrieve number of all devices of given users. Only returns number of devices that are not marked as hidden. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 07e46b2268b3..ca780cca36ec 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1005,7 +1005,9 @@ def get_insertion_event_backward_extremities_in_room_txn( room_id, ) - async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: + async def get_max_depth_of( + self, event_ids: Collection[str] + ) -> Tuple[Optional[str], int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs Args: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 1ae142023a5d..fa3266c081b5 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -15,7 +15,6 @@ import logging from typing import ( TYPE_CHECKING, - AbstractSet, Collection, Dict, FrozenSet, @@ -399,7 +398,9 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool: return result is not None @cached() - async def get_aggregation_groups_for_event(self, event_id: str) -> Sequence[JsonDict]: + async def get_aggregation_groups_for_event( + self, event_id: str + ) -> Sequence[JsonDict]: raise NotImplementedError() @cachedList( @@ -409,7 +410,6 @@ async def get_aggregation_groups_for_events( self, event_ids: Collection[str] ) -> Mapping[str, Optional[List[JsonDict]]]: """Get a list of annotations on the given events, grouped by event type and ->>>>>>> 1799a54a545618782840a60950ef4b64da9ee24d aggregation key, sorted by count. This is used e.g. to get the what and how many reactions have happend diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 58d2444331da..482e3eedc805 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -16,7 +16,6 @@ from typing import ( TYPE_CHECKING, AbstractSet, - Callable, Collection, Dict, FrozenSet, @@ -997,7 +996,6 @@ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]: @cached(iterable=True, max_entries=10000) async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: ->>>>>>> 3dfc4a08dc2e77178f2c2af68dc14b32da2d8b8f """ Get current hosts in room based on current state. From ed6f1fe7df9e566215b2a711d9cba474310aea74 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 10 Feb 2023 03:18:26 +0000 Subject: [PATCH 3/7] fixup errors from merge --- synapse/push/push_rule_evaluator.py | 361 ------------------- synapse/storage/databases/main/roommember.py | 9 +- 2 files changed, 1 insertion(+), 369 deletions(-) delete mode 100644 synapse/push/push_rule_evaluator.py diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py deleted file mode 100644 index a6f40b6d2db8..000000000000 --- a/synapse/push/push_rule_evaluator.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re -from typing import ( - AbstractSet, - Any, - Dict, - List, - Mapping, - Optional, - Pattern, - Sequence, - Tuple, - Union, -) - -from matrix_common.regex import glob_to_regex, to_word_pattern - -from synapse.events import EventBase -from synapse.types import UserID -from synapse.util.caches.lrucache import LruCache - -logger = logging.getLogger(__name__) - - -GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]") -IS_GLOB = re.compile(r"[\?\*\[\]]") -INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") - - -def _room_member_count(condition: Mapping[str, Any], room_member_count: int) -> bool: - return _test_ineq_condition(condition, room_member_count) - - -def _sender_notification_permission( - condition: Mapping[str, Any], - sender_power_level: Optional[int], - power_levels: Dict[str, Union[int, Dict[str, int]]], -) -> bool: - if sender_power_level is None: - return False - - notif_level_key = condition.get("key") - if notif_level_key is None: - return False - - notif_levels = power_levels.get("notifications", {}) - assert isinstance(notif_levels, dict) - room_notif_level = notif_levels.get(notif_level_key, 50) - - return sender_power_level >= room_notif_level - - -def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool: - if "is" not in condition: - return False - m = INEQUALITY_EXPR.match(condition["is"]) - if not m: - return False - ineq = m.group(1) - rhs = m.group(2) - if not rhs.isdigit(): - return False - rhs_int = int(rhs) - - if ineq == "" or ineq == "==": - return number == rhs_int - elif ineq == "<": - return number < rhs_int - elif ineq == ">": - return number > rhs_int - elif ineq == ">=": - return number >= rhs_int - elif ineq == "<=": - return number <= rhs_int - else: - return False - - -def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: - """ - Converts a list of actions into a `tweaks` dict (which can then be passed to - the push gateway). - - This function ignores all actions other than `set_tweak` actions, and treats - absent `value`s as `True`, which agrees with the only spec-defined treatment - of absent `value`s (namely, for `highlight` tweaks). - - Args: - actions: list of actions - e.g. [ - {"set_tweak": "a", "value": "AAA"}, - {"set_tweak": "b", "value": "BBB"}, - {"set_tweak": "highlight"}, - "notify" - ] - - Returns: - dictionary of tweaks for those actions - e.g. {"a": "AAA", "b": "BBB", "highlight": True} - """ - tweaks = {} - for a in actions: - if not isinstance(a, dict): - continue - if "set_tweak" in a: - # value is allowed to be absent in which case the value assumed - # should be True. - tweaks[a["set_tweak"]] = a.get("value", True) - return tweaks - - -class PushRuleEvaluatorForEvent: - def __init__( - self, - event: EventBase, - room_member_count: int, - sender_power_level: Optional[int], - power_levels: Dict[str, Union[int, Dict[str, int]]], - relations: Dict[str, AbstractSet[Tuple[str, str]]], - relations_match_enabled: bool, - ): - self._event = event - self._room_member_count = room_member_count - self._sender_power_level = sender_power_level - self._power_levels = power_levels - self._relations = relations - self._relations_match_enabled = relations_match_enabled - - # Maps strings of e.g. 'content.body' -> event["content"]["body"] - self._value_cache = _flatten_dict(event) - - # Maps cache keys to final values. - self._condition_cache: Dict[str, bool] = {} - - def check_conditions( - self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str] - ) -> bool: - """ - Returns true if a user's conditions/user ID/display name match the event. - - Args: - conditions: The user's conditions to match. - uid: The user's MXID. - display_name: The display name. - - Returns: - True if all conditions match the event, False otherwise. - """ - for cond in conditions: - _cache_key = cond.get("_cache_key", None) - if _cache_key: - res = self._condition_cache.get(_cache_key, None) - if res is False: - return False - elif res is True: - continue - - res = self.matches(cond, uid, display_name) - if _cache_key: - self._condition_cache[_cache_key] = bool(res) - - if not res: - return False - - return True - - def matches( - self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str] - ) -> bool: - """ - Returns true if a user's condition/user ID/display name match the event. - - Args: - condition: The user's condition to match. - uid: The user's MXID. - display_name: The display name, or None if there is not one. - - Returns: - True if the condition matches the event, False otherwise. - """ - if condition["kind"] == "event_match": - return self._event_match(condition, user_id) - elif condition["kind"] == "contains_display_name": - return self._contains_display_name(display_name) - elif condition["kind"] == "room_member_count": - return _room_member_count(condition, self._room_member_count) - elif condition["kind"] == "sender_notification_permission": - return _sender_notification_permission( - condition, self._sender_power_level, self._power_levels - ) - elif ( - condition["kind"] == "org.matrix.msc3772.relation_match" - and self._relations_match_enabled - ): - return self._relation_match(condition, user_id) - else: - # XXX This looks incorrect -- we have reached an unknown condition - # kind and are unconditionally returning that it matches. Note - # that it seems possible to provide a condition to the /pushrules - # endpoint with an unknown kind, see _rule_tuple_from_request_object. - return True - - def _event_match(self, condition: Mapping, user_id: str) -> bool: - """ - Check an "event_match" push rule condition. - - Args: - condition: The "event_match" push rule condition to match. - user_id: The user's MXID. - - Returns: - True if the condition matches the event, False otherwise. - """ - pattern = condition.get("pattern", None) - - if not pattern: - pattern_type = condition.get("pattern_type", None) - if pattern_type == "user_id": - pattern = user_id - elif pattern_type == "user_localpart": - pattern = UserID.from_string(user_id).localpart - - if not pattern: - logger.warning("event_match condition with no pattern") - return False - - # XXX: optimisation: cache our pattern regexps - if condition["key"] == "content.body": - body = self._event.content.get("body", None) - if not body or not isinstance(body, str): - return False - - return _glob_matches(pattern, body, word_boundary=True) - else: - haystack = self._value_cache.get(condition["key"], None) - if haystack is None: - return False - - return _glob_matches(pattern, haystack) - - def _contains_display_name(self, display_name: Optional[str]) -> bool: - """ - Check an "event_match" push rule condition. - - Args: - display_name: The display name, or None if there is not one. - - Returns: - True if the display name is found in the event body, False otherwise. - """ - if not display_name: - return False - - body = self._event.content.get("body", None) - if not body or not isinstance(body, str): - return False - - # Similar to _glob_matches, but do not treat display_name as a glob. - r = regex_cache.get((display_name, False, True), None) - if not r: - r1 = re.escape(display_name) - r1 = to_word_pattern(r1) - r = re.compile(r1, flags=re.IGNORECASE) - regex_cache[(display_name, False, True)] = r - - return bool(r.search(body)) - - def _relation_match(self, condition: Mapping, user_id: str) -> bool: - """ - Check an "relation_match" push rule condition. - - Args: - condition: The "event_match" push rule condition to match. - user_id: The user's MXID. - - Returns: - True if the condition matches the event, False otherwise. - """ - rel_type = condition.get("rel_type") - if not rel_type: - logger.warning("relation_match condition missing rel_type") - return False - - sender_pattern = condition.get("sender") - if sender_pattern is None: - sender_type = condition.get("sender_type") - if sender_type == "user_id": - sender_pattern = user_id - type_pattern = condition.get("type") - - # If any other relations matches, return True. - for sender, event_type in self._relations.get(rel_type, ()): - if sender_pattern and not _glob_matches(sender_pattern, sender): - continue - if type_pattern and not _glob_matches(type_pattern, event_type): - continue - # All values must have matched. - return True - - # No relations matched. - return False - - -# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( - 50000, "regex_push_cache" -) - - -def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: - """Tests if value matches glob. - - Args: - glob - value: String to test against glob. - word_boundary: Whether to match against word boundaries or entire - string. Defaults to False. - """ - - try: - r = regex_cache.get((glob, True, word_boundary), None) - if not r: - r = glob_to_regex(glob, word_boundary=word_boundary) - regex_cache[(glob, True, word_boundary)] = r - return bool(r.search(value)) - except re.error: - logger.warning("Failed to parse glob to regex: %r", glob) - return False - - -def _flatten_dict( - d: Union[EventBase, Mapping[str, Any]], - prefix: Optional[List[str]] = None, - result: Optional[Dict[str, str]] = None, -) -> Dict[str, str]: - if prefix is None: - prefix = [] - if result is None: - result = {} - for key, value in d.items(): - if isinstance(value, str): - result[".".join(prefix + [key])] = value.lower() - elif isinstance(value, Mapping): - _flatten_dict(value, prefix=(prefix + [key]), result=result) - - return result diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 3f32864b1c0b..694a5b802c7c 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -155,14 +155,7 @@ def _transact(txn: LoggingTransaction) -> int: @cached(max_entries=100000, iterable=True) async def get_users_in_room(self, room_id: str) -> Sequence[str]: - """ - Returns a list of users in the room sorted by longest in the room first - (aka. with the lowest depth). This is done to match the sort in - `get_current_hosts_in_room()` and so we can re-use the cache but it's - not horrible to have here either. - - Uses `m.room.member`s in the room state at the current forward extremities to - determine which users are in the room. + """Returns a list of users in the room. Will return inaccurate results for rooms with partial state, since the state for the forward extremities of those rooms will exclude most members. We may also From 1fe7d82947ae92959398ef3fd5858e4948916c54 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 10 Feb 2023 03:19:56 +0000 Subject: [PATCH 4/7] fixup: change return type of get_linearized_receipts_for_room --- synapse/storage/databases/main/receipts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 2709d28cbff6..dddf49c2d575 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -290,7 +290,7 @@ async def get_linearized_receipts_for_rooms( async def get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> Sequence[dict]: + ) -> Sequence[JsonDict]: """Get receipts for a single room for sending to clients. Args: From 7f0f531d0fdebd700c08d0d90a264fe204d40a51 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 10 Feb 2023 03:41:50 +0000 Subject: [PATCH 5/7] fix mypy complaint --- synapse/push/bulk_push_rule_evaluator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 39d2f88f0349..00a0041fdc20 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -22,6 +22,7 @@ List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -148,7 +149,7 @@ async def _get_rules_for_event( # little, we can skip fetching a huge number of push rules in large rooms. # This helps make joins and leaves faster. if event.type == EventTypes.Member: - local_users = [] + local_users: Sequence[str] = [] # We never notify a user about their own actions. This is enforced in # `_action_for_event_by_user` in the loop over `rules_by_user`, but we # do the same check here to avoid unnecessary DB queries. @@ -183,7 +184,6 @@ async def _get_rules_for_event( if event.type == EventTypes.Member and event.membership == Membership.INVITE: invited = event.state_key if invited and self.hs.is_mine_id(invited) and invited not in local_users: - local_users = list(local_users) local_users.append(invited) if not local_users: From d7f84d7bdcc6bee1fb976e65eb190b97f4a68d4e Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 10 Feb 2023 04:07:47 +0000 Subject: [PATCH 6/7] make mypy happy again after merge --- synapse/handlers/directory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 0a67b8db063b..a5798e9483ce 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -486,7 +486,7 @@ async def edit_published_room_list( ) if canonical_alias: # Ensure we do not mutate room_aliases. - room_aliases = room_aliases + [canonical_alias] + room_aliases = list(room_aliases) + [canonical_alias] if not self.config.roomdirectory.is_publishing_room_allowed( user_id, room_id, room_aliases From 08e9135e5daa402dabd65ca73b75ba51f3b78bc5 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 10 Feb 2023 15:46:56 +0000 Subject: [PATCH 7/7] fix mypy complaints after merge --- synapse/storage/databases/main/devices.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 933333e82cb6..1ca66d57d40c 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -750,7 +750,7 @@ def _add_user_signature_change_txn( @cancellable async def get_user_devices_from_cache( self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] - ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: + ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. Args: @@ -778,16 +778,18 @@ async def get_user_devices_from_cache( user_ids_not_in_cache = unique_user_ids - user_ids_in_cache # First fetch all the users which all devices are to be returned. - results: Dict[str, Dict[str, JsonDict]] = {} + results: Dict[str, Mapping[str, JsonDict]] = {} for user_id in user_ids: if user_id in user_ids_in_cache: results[user_id] = await self.get_cached_devices_for_user(user_id) # Then fetch all device-specific requests, but skip users we've already # fetched all devices for. + device_specific_results: Dict[str, Dict[str, JsonDict]] = {} for user_id, device_id in user_and_device_ids: if user_id in user_ids_in_cache and user_id not in user_ids: device = await self._get_cached_user_device(user_id, device_id) - results.setdefault(user_id, {})[device_id] = device + device_specific_results.setdefault(user_id, {})[device_id] = device + results.update(device_specific_results) set_tag("in_cache", str(results)) set_tag("not_in_cache", str(user_ids_not_in_cache))