Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Return immutable objects for cachedList decorators (#16350)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Sep 19, 2023
1 parent 5a66ff2 commit d7c89c5
Show file tree
Hide file tree
Showing 24 changed files with 134 additions and 100 deletions.
1 change: 1 addition & 0 deletions changelog.d/16350.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
6 changes: 3 additions & 3 deletions synapse/appservice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import DeviceListUpdates, JsonDict, UserID
from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, UserID
from synapse.util.caches.descriptors import _CacheContext, cached

if TYPE_CHECKING:
Expand Down Expand Up @@ -379,8 +379,8 @@ def __init__(
service: ApplicationService,
id: int,
events: Sequence[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
ephemeral: List[JsonMapping],
to_device_messages: List[JsonMapping],
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
Expand Down
6 changes: 3 additions & 3 deletions synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from synapse.events.utils import SerializeEventConfig, serialize_event
from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
from synapse.logging import opentracing
from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache

if TYPE_CHECKING:
Expand Down Expand Up @@ -306,8 +306,8 @@ async def push_bulk(
self,
service: "ApplicationService",
events: Sequence[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
ephemeral: List[JsonMapping],
to_device_messages: List[JsonMapping],
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
Expand Down
18 changes: 9 additions & 9 deletions synapse/appservice/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main import DataStore
from synapse.types import DeviceListUpdates, JsonDict
from synapse.types import DeviceListUpdates, JsonMapping
from synapse.util import Clock

if TYPE_CHECKING:
Expand Down Expand Up @@ -121,8 +121,8 @@ def enqueue_for_appservice(
self,
appservice: ApplicationService,
events: Optional[Collection[EventBase]] = None,
ephemeral: Optional[Collection[JsonDict]] = None,
to_device_messages: Optional[Collection[JsonDict]] = None,
ephemeral: Optional[Collection[JsonMapping]] = None,
to_device_messages: Optional[Collection[JsonMapping]] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
) -> None:
"""
Expand Down Expand Up @@ -180,9 +180,9 @@ def __init__(
# dict of {service_id: [events]}
self.queued_events: Dict[str, List[EventBase]] = {}
# dict of {service_id: [events]}
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
self.queued_ephemeral: Dict[str, List[JsonMapping]] = {}
# dict of {service_id: [to_device_message_json]}
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
self.queued_to_device_messages: Dict[str, List[JsonMapping]] = {}
# dict of {service_id: [device_list_summary]}
self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {}

Expand Down Expand Up @@ -293,8 +293,8 @@ async def _compute_msc3202_otk_counts_and_fallback_keys(
self,
service: ApplicationService,
events: Iterable[EventBase],
ephemerals: Iterable[JsonDict],
to_device_messages: Iterable[JsonDict],
ephemerals: Iterable[JsonMapping],
to_device_messages: Iterable[JsonMapping],
) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]:
"""
Given a list of the events, ephemeral messages and to-device messages,
Expand Down Expand Up @@ -364,8 +364,8 @@ async def send(
self,
service: ApplicationService,
events: Sequence[EventBase],
ephemeral: Optional[List[JsonDict]] = None,
to_device_messages: Optional[List[JsonDict]] = None,
ephemeral: Optional[List[JsonMapping]] = None,
to_device_messages: Optional[List[JsonMapping]] = None,
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
Expand Down
9 changes: 5 additions & 4 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
RoomAlias,
RoomStreamToken,
StreamKeyType,
Expand Down Expand Up @@ -397,7 +398,7 @@ async def _notify_interested_services_ephemeral(

async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
) -> List[JsonMapping]:
"""
Return the typing events since the given stream token that the given application
service should receive.
Expand Down Expand Up @@ -432,7 +433,7 @@ async def _handle_typing(

async def _handle_receipts(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
) -> List[JsonMapping]:
"""
Return the latest read receipts that the given application service should receive.
Expand Down Expand Up @@ -471,7 +472,7 @@ async def _handle_presence(
service: ApplicationService,
users: Collection[Union[str, UserID]],
new_token: Optional[int],
) -> List[JsonDict]:
) -> List[JsonMapping]:
"""
Return the latest presence updates that the given application service should receive.
Expand All @@ -491,7 +492,7 @@ async def _handle_presence(
A list of json dictionaries containing data derived from the presence events
that should be sent to the given application service.
"""
events: List[JsonDict] = []
events: List[JsonMapping] = []
presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
Expand Down
24 changes: 9 additions & 15 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple

import attr
from canonicaljson import encode_canonical_json
Expand All @@ -31,6 +31,7 @@
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.types import (
JsonDict,
JsonMapping,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
Expand Down Expand Up @@ -272,11 +273,7 @@ async def _query(
delay_cancellation=True,
)

ret = {"device_keys": results, "failures": failures}

ret.update(cross_signing_keys)

return ret
return {"device_keys": results, "failures": failures, **cross_signing_keys}

@trace
async def _query_devices_for_destination(
Expand Down Expand Up @@ -408,7 +405,7 @@ async def _query_devices_for_destination(
@cancellable
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
) -> Dict[str, Dict[str, JsonMapping]]:
"""Get cross-signing keys for users from the database
Args:
Expand Down Expand Up @@ -551,16 +548,13 @@ async def on_federation_query_client_keys(
self.config.federation.allow_device_name_lookup_over_federation
),
)
ret = {"device_keys": res}

# add in the cross-signing keys
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, None
)

ret.update(cross_signing_keys)

return ret
return {"device_keys": res, **cross_signing_keys}

async def claim_local_one_time_keys(
self,
Expand Down Expand Up @@ -1127,7 +1121,7 @@ def _check_master_key_signature(
user_id: str,
master_key_id: str,
signed_master_key: JsonDict,
stored_master_key: JsonDict,
stored_master_key: JsonMapping,
devices: Dict[str, Dict[str, JsonDict]],
) -> List["SignatureListItem"]:
"""Check signatures of a user's master key made by their devices.
Expand Down Expand Up @@ -1278,7 +1272,7 @@ async def _process_other_signatures(

async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Tuple[JsonDict, str, VerifyKey]:
) -> Tuple[JsonMapping, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key.
First, attempt to fetch the cross-signing public key from storage.
Expand Down Expand Up @@ -1333,7 +1327,7 @@ async def _retrieve_cross_signing_keys_for_remote_user(
self,
user: UserID,
desired_key_type: str,
) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
) -> Optional[Tuple[JsonMapping, str, VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys
Expand Down Expand Up @@ -1474,7 +1468,7 @@ def _check_device_signature(
user_id: str,
verify_key: VerifyKey,
signed_device: JsonDict,
stored_device: JsonDict,
stored_device: JsonMapping,
) -> None:
"""Check that a signature on a device or cross-signing key is correct and
matches the copy of the device/key that we have stored. Throws an
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from synapse.streams.config import PaginationConfig
from synapse.types import (
JsonDict,
JsonMapping,
Requester,
RoomStreamToken,
StreamKeyType,
Expand Down Expand Up @@ -454,7 +455,7 @@ async def get_presence() -> List[JsonDict]:
for s in states
]

async def get_receipts() -> List[JsonDict]:
async def get_receipts() -> List[JsonMapping]:
receipts = await self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key
)
Expand Down
13 changes: 7 additions & 6 deletions synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
JsonMapping,
ReadReceipt,
StreamKeyType,
UserID,
Expand Down Expand Up @@ -204,15 +205,15 @@ async def received_client_receipt(
await self.federation_sender.send_read_receipt(receipt)


class ReceiptEventSource(EventSource[int, JsonDict]):
class ReceiptEventSource(EventSource[int, JsonMapping]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.config = hs.config

@staticmethod
def filter_out_private_receipts(
rooms: Sequence[JsonDict], user_id: str
) -> List[JsonDict]:
rooms: Sequence[JsonMapping], user_id: str
) -> List[JsonMapping]:
"""
Filters a list of serialized receipts (as returned by /sync and /initialSync)
and removes private read receipts of other users.
Expand All @@ -229,7 +230,7 @@ def filter_out_private_receipts(
The same as rooms, but filtered.
"""

result = []
result: List[JsonMapping] = []

# Iterate through each room's receipt content.
for room in rooms:
Expand Down Expand Up @@ -282,7 +283,7 @@ async def get_new_events(
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
) -> Tuple[List[JsonMapping], int]:
from_key = int(from_key)
to_key = self.get_current_key()

Expand All @@ -301,7 +302,7 @@ async def get_new_events(

async def get_new_events_as(
self, from_key: int, to_key: int, service: ApplicationService
) -> Tuple[List[JsonDict], int]:
) -> Tuple[List[JsonMapping], int]:
"""Returns a set of new read receipt events that an appservice
may be interested in.
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ class SyncResult:
archived: List[ArchivedSyncResult]
to_device: List[JsonDict]
device_lists: DeviceListUpdates
device_one_time_keys_count: JsonDict
device_one_time_keys_count: JsonMapping
device_unused_fallback_key_types: List[str]

def __bool__(self) -> bool:
Expand Down Expand Up @@ -1558,7 +1558,7 @@ async def generate_sync_result(

logger.debug("Fetching OTK data")
device_id = sync_config.device_id
one_time_keys_count: JsonDict = {}
one_time_keys_count: JsonMapping = {}
unused_fallback_key_types: List[str] = []
if device_id:
# TODO: We should have a way to let clients differentiate between the states of:
Expand Down
17 changes: 12 additions & 5 deletions synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@
)
from synapse.replication.tcp.streams import TypingStream
from synapse.streams import EventSource
from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID
from synapse.types import (
JsonDict,
JsonMapping,
Requester,
StrCollection,
StreamKeyType,
UserID,
)
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.retryutils import filter_destinations_by_retry_limiter
Expand Down Expand Up @@ -487,7 +494,7 @@ def process_replication_rows(
raise Exception("Typing writer instance got typing info over replication")


class TypingNotificationEventSource(EventSource[int, JsonDict]):
class TypingNotificationEventSource(EventSource[int, JsonMapping]):
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
self.clock = hs.get_clock()
Expand All @@ -497,7 +504,7 @@ def __init__(self, hs: "HomeServer"):
#
self.get_typing_handler = hs.get_typing_handler

def _make_event_for(self, room_id: str) -> JsonDict:
def _make_event_for(self, room_id: str) -> JsonMapping:
typing = self.get_typing_handler()._room_typing[room_id]
return {
"type": EduTypes.TYPING,
Expand All @@ -507,7 +514,7 @@ def _make_event_for(self, room_id: str) -> JsonDict:

async def get_new_events_as(
self, from_key: int, service: ApplicationService
) -> Tuple[List[JsonDict], int]:
) -> Tuple[List[JsonMapping], int]:
"""Returns a set of new typing events that an appservice
may be interested in.
Expand Down Expand Up @@ -551,7 +558,7 @@ async def get_new_events(
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
) -> Tuple[List[JsonMapping], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.get_typing_handler()
Expand Down
2 changes: 1 addition & 1 deletion synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(self, hs: "HomeServer"):
async def _get_rules_for_event(
self,
event: EventBase,
) -> Dict[str, FilteredPushRules]:
) -> Mapping[str, FilteredPushRules]:
"""Get the push rules for all users who may need to be notified about
the event.
Expand Down
Loading

0 comments on commit d7c89c5

Please sign in to comment.