Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints for event streams. (#10856)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Sep 21, 2021
1 parent b25a494 commit 4054dfa
Show file tree
Hide file tree
Showing 18 changed files with 169 additions and 60 deletions.
1 change: 1 addition & 0 deletions changelog.d/10856.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to handlers.
13 changes: 10 additions & 3 deletions synapse/handlers/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import TYPE_CHECKING, Any, List, Tuple
from typing import TYPE_CHECKING, Collection, List, Optional, Tuple

from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
ReplicationRemoveTagRestServlet,
ReplicationRoomAccountDataRestServlet,
ReplicationUserAccountDataRestServlet,
)
from synapse.streams import EventSource
from synapse.types import JsonDict, UserID

if TYPE_CHECKING:
Expand Down Expand Up @@ -163,15 +164,21 @@ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> in
return response["max_stream_id"]


class AccountDataEventSource:
class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()

def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_account_data_stream_id()

async def get_new_events(
self, user: UserID, from_key: int, **kwargs: Any
self,
user: UserID,
from_key: int,
limit: Optional[int],
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
user_id = user.to_string()
last_stream_id = from_key
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ async def _notify_interested_services_ephemeral(
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
typing_source = self.event_sources.sources["typing"]
typing_source = self.event_sources.sources.typing
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
service=service,
Expand All @@ -269,7 +269,7 @@ async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
receipts_source = self.event_sources.sources["receipt"]
receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key
)
Expand All @@ -279,7 +279,7 @@ async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
events: List[JsonDict] = []
presence_source = self.event_sources.sources["presence"]
presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
)
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def _snapshot_all_rooms(

now_token = self.hs.get_event_sources().get_current_token()

presence_stream = self.hs.get_event_sources().sources["presence"]
presence_stream = self.hs.get_event_sources().sources.presence
presence, _ = await presence_stream.get_new_events(
user, from_key=None, include_offline=False
)
Expand Down
8 changes: 5 additions & 3 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
from synapse.streams import EventSource
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import _CacheContext, cached
Expand Down Expand Up @@ -1500,7 +1501,7 @@ def format_user_presence_state(
return content


class PresenceEventSource:
class PresenceEventSource(EventSource[int, UserPresenceState]):
def __init__(self, hs: "HomeServer"):
# We can't call get_presence_handler here because there's a cycle:
#
Expand All @@ -1519,10 +1520,11 @@ async def get_new_events(
self,
user: UserID,
from_key: Optional[int],
limit: Optional[int] = None,
room_ids: Optional[List[str]] = None,
include_offline: bool = True,
is_guest: bool = False,
explicit_room_id: Optional[str] = None,
**kwargs: Any,
include_offline: bool = True,
) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events are:
# 1. Get the rooms the user is in.
Expand Down
13 changes: 10 additions & 3 deletions synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple

from synapse.api.constants import ReadReceiptEventFields
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id

if TYPE_CHECKING:
Expand Down Expand Up @@ -162,7 +163,7 @@ async def received_client_receipt(
await self.federation_sender.send_read_receipt(receipt)


class ReceiptEventSource:
class ReceiptEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.config = hs.config
Expand Down Expand Up @@ -216,7 +217,13 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:
return visible_events

async def get_new_events(
self, from_key: int, room_ids: List[str], user: UserID, **kwargs: Any
self,
user: UserID,
from_key: int,
limit: Optional[int],
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()
Expand Down
18 changes: 14 additions & 4 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,16 @@
import random
import string
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Collection,
Dict,
List,
Optional,
Tuple,
)

from synapse.api.constants import (
EventContentFields,
Expand All @@ -47,6 +56,7 @@
from synapse.events.utils import copy_power_levels_contents
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
MutableStateMap,
Expand Down Expand Up @@ -1173,16 +1183,16 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
return results


class RoomEventSource:
class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()

async def get_new_events(
self,
user: UserID,
from_key: RoomStreamToken,
limit: int,
room_ids: List[str],
limit: Optional[int],
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ async def ephemeral_by_room(

room_ids = sync_result_builder.joined_room_ids

typing_source = self.event_sources.sources["typing"]
typing_source = self.event_sources.sources.typing
typing, typing_key = await typing_source.get_new_events(
user=sync_config.user,
from_key=typing_key,
Expand All @@ -465,7 +465,7 @@ async def ephemeral_by_room(

receipt_key = since_token.receipt_key if since_token else 0

receipt_source = self.event_sources.sources["receipt"]
receipt_source = self.event_sources.sources.receipt
receipts, receipt_key = await receipt_source.get_new_events(
user=sync_config.user,
from_key=receipt_key,
Expand Down Expand Up @@ -1415,7 +1415,7 @@ async def _generate_sync_entry_for_presence(
sync_config = sync_result_builder.sync_config
user = sync_result_builder.sync_config.user

presence_source = self.event_sources.sources["presence"]
presence_source = self.event_sources.sources.presence

since_token = sync_result_builder.since_token
presence_key = None
Expand Down
13 changes: 10 additions & 3 deletions synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import random
from collections import namedtuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple

from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
Expand All @@ -23,6 +23,7 @@
wrap_as_background_process,
)
from synapse.replication.tcp.streams import TypingStream
from synapse.streams import EventSource
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
Expand Down Expand Up @@ -439,7 +440,7 @@ def process_replication_rows(
raise Exception("Typing writer instance got typing info over replication")


class TypingNotificationEventSource:
class TypingNotificationEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
Expand Down Expand Up @@ -485,7 +486,13 @@ async def get_new_events_as(
return (events, handler._latest_room_serial)

async def get_new_events(
self, from_key: int, room_ids: Iterable[str], **kwargs: Any
self,
user: UserID,
from_key: int,
limit: Optional[int],
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
Expand Down
2 changes: 1 addition & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(self, hs: "HomeServer", auth_handler):
self._auth = hs.get_auth()
self._auth_handler = auth_handler
self._server_name = hs.hostname
self._presence_stream = hs.get_event_sources().sources["presence"]
self._presence_stream = hs.get_event_sources().sources.presence
self._state = hs.get_state_handler()
self._clock: Clock = hs.get_clock()
self._send_email_handler = hs.get_send_email_handler()
Expand Down
2 changes: 1 addition & 1 deletion synapse/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ async def check_for_updates(
events: List[EventBase] = []
end_token = from_token

for name, source in self.event_sources.sources.items():
for name, source in self.event_sources.sources.get_sources():
keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname)
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple

from twisted.internet import defer

Expand Down Expand Up @@ -153,12 +153,12 @@ def f(txn):
}

async def get_linearized_receipts_for_rooms(
self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
room_id: List of room_ids.
room_id: The room IDs to fetch receipts of.
to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
Expand Down
22 changes: 22 additions & 0 deletions synapse/streams/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Collection, Generic, List, Optional, Tuple, TypeVar

from synapse.types import UserID

# The key, this is either a stream token or int.
K = TypeVar("K")
# The return type.
R = TypeVar("R")


class EventSource(Generic[K, R]):
async def get_new_events(
self,
user: UserID,
from_key: K,
limit: Optional[int],
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[R], K]:
...
Loading

0 comments on commit 4054dfa

Please sign in to comment.