Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Return read-only collections from @cached methods (#13755)
Browse files Browse the repository at this point in the history
It's important that collections returned from `@cached` methods are not
modified, otherwise future retrievals from the cache will return the
modified collection.

This applies to the return values from `@cached` methods and the values
inside the dictionaries returned by `@cachedList` methods. It's not
necessary for the dictionaries returned by `@cachedList` methods
themselves to be read-only.

Signed-off-by: Sean Quah <[email protected]>
Co-authored-by: David Robertson <[email protected]>
  • Loading branch information
squahtx and David Robertson authored Feb 10, 2023
1 parent 14be78d commit d0c713c
Show file tree
Hide file tree
Showing 27 changed files with 98 additions and 77 deletions.
1 change: 1 addition & 0 deletions changelog.d/13755.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Re-type hint some collections as read-only.
4 changes: 2 additions & 2 deletions synapse/app/phone_stats_home.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import math
import resource
import sys
from typing import TYPE_CHECKING, List, Sized, Tuple
from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple

from prometheus_client import Gauge

Expand Down Expand Up @@ -194,7 +194,7 @@ def performance_stats_init() -> None:
@wrap_as_background_process("generate_monthly_active_users")
async def generate_monthly_active_users() -> None:
current_mau_count = 0
current_mau_count_by_service = {}
current_mau_count_by_service: Mapping[str, int] = {}
reserved_users: Sized = ()
store = hs.get_datastores().main
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
Expand Down
6 changes: 3 additions & 3 deletions synapse/config/room_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List
from typing import Any, Collection

from matrix_common.regex import glob_to_regex

Expand Down Expand Up @@ -70,7 +70,7 @@ def is_alias_creation_allowed(self, user_id: str, room_id: str, alias: str) -> b
return False

def is_publishing_room_allowed(
self, user_id: str, room_id: str, aliases: List[str]
self, user_id: str, room_id: str, aliases: Collection[str]
) -> bool:
"""Checks if the given user is allowed to publish the room
Expand Down Expand Up @@ -122,7 +122,7 @@ def __init__(self, option_name: str, rule: JsonDict):
except Exception as e:
raise ConfigError("Failed to parse glob into regex") from e

def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool:
def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool:
"""Tests if this rule matches the given user_id, room_id and aliases.
Args:
Expand Down
6 changes: 3 additions & 3 deletions synapse/events/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union

import attr
from signedjson.types import SigningKey
Expand Down Expand Up @@ -103,7 +103,7 @@ def is_state(self) -> bool:

async def build(
self,
prev_event_ids: List[str],
prev_event_ids: Collection[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
Expand Down Expand Up @@ -136,7 +136,7 @@ async def build(

format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)
Expand Down
3 changes: 2 additions & 1 deletion synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Collection,
Dict,
List,
Mapping,
Optional,
Tuple,
Union,
Expand Down Expand Up @@ -1512,7 +1513,7 @@ async def on_query(self, query_type: str, args: dict) -> JsonDict:
def _get_event_ids_for_partial_state_join(
join_event: EventBase,
prev_state_ids: StateMap[str],
summary: Dict[str, MemberSummary],
summary: Mapping[str, MemberSummary],
) -> Collection[str]:
"""Calculate state to be returned in a partial_state send_join
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
import string
from typing import TYPE_CHECKING, Iterable, List, Optional
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence

from typing_extensions import Literal

Expand Down Expand Up @@ -486,7 +486,7 @@ async def edit_published_room_list(
)
if canonical_alias:
# Ensure we do not mutate room_aliases.
room_aliases = room_aliases + [canonical_alias]
room_aliases = list(room_aliases) + [canonical_alias]

if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases
Expand Down Expand Up @@ -529,7 +529,7 @@ async def edit_published_appservice_room_list(

async def get_aliases_for_room(
self, requester: Requester, room_id: str
) -> List[str]:
) -> Sequence[str]:
"""
Get a list of the aliases that currently point to this room on this server
"""
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple

from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.appservice import ApplicationService
Expand Down Expand Up @@ -189,7 +189,7 @@ def __init__(self, hs: "HomeServer"):

@staticmethod
def filter_out_private_receipts(
rooms: List[JsonDict], user_id: str
rooms: Sequence[JsonDict], user_id: str
) -> List[JsonDict]:
"""
Filters a list of serialized receipts (as returned by /sync and /initialSync)
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -1928,6 +1928,6 @@ async def shutdown_room(
return {
"kicked_users": kicked_users,
"failed_to_kick_users": failed_to_kick_users,
"local_aliases": aliases_for_room,
"local_aliases": list(aliases_for_room),
"new_room_id": new_room_id,
}
4 changes: 2 additions & 2 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,7 +1519,7 @@ async def generate_sync_result(
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = (
unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)

Expand Down Expand Up @@ -2301,7 +2301,7 @@ async def _generate_room_entry(
sync_result_builder: "SyncResultBuilder",
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
tags: Optional[Mapping[str, Mapping[str, Any]]],
account_data: Mapping[str, JsonDict],
always_include: bool = False,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
Expand Down Expand Up @@ -149,7 +150,7 @@ async def _get_rules_for_event(
# little, we can skip fetching a huge number of push rules in large rooms.
# This helps make joins and leaves faster.
if event.type == EventTypes.Member:
local_users = []
local_users: Sequence[str] = []
# We never notify a user about their own actions. This is enforced in
# `_action_for_event_by_user` in the loop over `rules_by_user`, but we
# do the same check here to avoid unnecessary DB queries.
Expand Down Expand Up @@ -184,7 +185,6 @@ async def _get_rules_for_event(
if event.type == EventTypes.Member and event.membership == Membership.INVITE:
invited = event.state_key
if invited and self.hs.is_mine_id(invited) and invited not in local_users:
local_users = list(local_users)
local_users.append(invited)

if not local_users:
Expand Down
2 changes: 1 addition & 1 deletion synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ async def compute_state_after_events(
return await ret.get_state(self._state_storage_controller, state_filter)

async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str]
self, room_id: str, latest_event_ids: Collection[str]
) -> Set[str]:
"""
Get the users IDs who are currently in a room.
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import logging
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Awaitable,
Callable,
Expand All @@ -23,7 +24,6 @@
List,
Mapping,
Optional,
Set,
Tuple,
)

Expand Down Expand Up @@ -527,7 +527,7 @@ async def get_current_state_event(
)
return state_map.get(key)

async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
Expand Down Expand Up @@ -584,7 +584,7 @@ async def get_current_hosts_in_room_or_partial_state_approximation(

async def get_users_in_room_with_profiles(
self, room_id: str
) -> Dict[str, ProfileInfo]:
) -> Mapping[str, ProfileInfo]:
"""
Get the current users in the room with their profiles.
If the room is currently partial-stated, this will block until the room has
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ async def get_global_account_data_by_type_for_user(
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
) -> Mapping[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async def get_app_service_users_in_room(
room_id: str,
app_service: "ApplicationService",
cache_context: _CacheContext,
) -> List[str]:
) -> Sequence[str]:
"""
Get all users in a room that the appservice controls.
Expand Down
17 changes: 11 additions & 6 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
Expand Down Expand Up @@ -202,7 +203,9 @@ def _invalidate_caches_for_devices(
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()

async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
async def count_devices_by_users(
self, user_ids: Optional[Collection[str]] = None
) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
Expand All @@ -213,7 +216,7 @@ async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) ->
"""

def count_devices_by_users_txn(
txn: LoggingTransaction, user_ids: List[str]
txn: LoggingTransaction, user_ids: Collection[str]
) -> int:
sql = """
SELECT count(*)
Expand Down Expand Up @@ -747,7 +750,7 @@ def _add_user_signature_change_txn(
@cancellable
async def get_user_devices_from_cache(
self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
Expand Down Expand Up @@ -775,16 +778,18 @@ async def get_user_devices_from_cache(
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache

# First fetch all the users which all devices are to be returned.
results: Dict[str, Dict[str, JsonDict]] = {}
results: Dict[str, Mapping[str, JsonDict]] = {}
for user_id in user_ids:
if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
# Then fetch all device-specific requests, but skip users we've already
# fetched all devices for.
device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in user_and_device_ids:
if user_id in user_ids_in_cache and user_id not in user_ids:
device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
device_specific_results.setdefault(user_id, {})[device_id] = device
results.update(device_specific_results)

set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))
Expand All @@ -802,7 +807,7 @@ async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDic
return db_to_json(content)

@cached()
async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Sequence, Tuple

import attr

Expand Down Expand Up @@ -74,7 +74,7 @@ async def get_room_alias_creator(self, room_alias: str) -> str:
)

@cached(max_entries=5000)
async def get_aliases_for_room(self, room_id: str) -> List[str]:
async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
Expand Down
Loading

0 comments on commit d0c713c

Please sign in to comment.