This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit 112266e: Add StreamStore to mypy (#8232)

erikjohnston authored Sep 2, 2020
1 parent 5a1dd29
Showing 5 changed files with 66 additions and 20 deletions.
1 change: 1 addition & 0 deletions changelog.d/8232.misc
@@ -0,0 +1 @@
+Add type hints to `StreamStore`.
1 change: 1 addition & 0 deletions mypy.ini
@@ -43,6 +43,7 @@ files =
   synapse/server_notices,
   synapse/spam_checker_api,
   synapse/state,
+  synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
   synapse/storage/engines,
4 changes: 2 additions & 2 deletions synapse/events/__init__.py
@@ -18,7 +18,7 @@
 import abc
 import os
 from distutils.util import strtobool
-from typing import Dict, Optional, Type
+from typing import Dict, Optional, Tuple, Type

 from unpaddedbase64 import encode_base64

@@ -120,7 +120,7 @@ def __init__(self, internal_metadata_dict: JsonDict):
     # be here
     before = DictProperty("before")  # type: str
     after = DictProperty("after")  # type: str
-    order = DictProperty("order")  # type: int
+    order = DictProperty("order")  # type: Tuple[int, int]

     def get_dict(self) -> JsonDict:
         return dict(self._dict)
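For context: `DictProperty` is a descriptor Synapse uses to proxy an attribute to a key of the instance's `_dict`, and the `# type:` comments tell mypy what the stored values are; this commit corrects `order` to a pair of ints. A minimal sketch of the idea, simplified from Synapse's real implementation:

from typing import Any, Dict, Tuple


class DictProperty:
    """Expose one key of the owner's ``_dict`` as an attribute.

    A simplified sketch of Synapse's descriptor; the real one also
    handles deletion and missing keys.
    """

    def __init__(self, key: str):
        self.key = key

    def __get__(self, instance: Any, owner: Any = None) -> Any:
        if instance is None:
            return self
        return instance._dict[self.key]

    def __set__(self, instance: Any, value: Any) -> None:
        instance._dict[self.key] = value


class Metadata:
    def __init__(self, d: Dict[str, Any]):
        self._dict = d

    order = DictProperty("order")  # type: Tuple[int, int]


m = Metadata({"order": (1, 2)})
assert m.order == (1, 2)  # reads through to m._dict["order"]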
34 changes: 34 additions & 0 deletions synapse/storage/database.py
@@ -604,6 +604,18 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results

+    @overload
+    async def execute(
+        self, desc: str, decoder: Literal[None], query: str, *args: Any
+    ) -> List[Tuple[Any, ...]]:
+        ...
+
+    @overload
+    async def execute(
+        self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
+    ) -> R:
+        ...
+
     async def execute(
         self,
         desc: str,
@@ -1088,6 +1100,28 @@ async def simple_select_one(
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
         )

+    @overload
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> Any:
+        ...
+
+    @overload
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one_onecol",
+    ) -> Optional[Any]:
+        ...
+
     async def simple_select_one_onecol(
         self,
         table: str,
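The `@overload` stubs above let mypy pick a return type from the arguments at the call site: `execute` returns raw rows when `decoder` is `None` and the decoder's result otherwise, and `simple_select_one_onecol` is `Optional` only when `allow_none=True`. A self-contained sketch of the `Literal`-based pattern (illustrative names, not Synapse's API; assumes Python 3.8+, where `Literal` lives in `typing` rather than `typing_extensions`):

from typing import Any, Callable, List, Literal, Optional, Tuple, TypeVar, overload

R = TypeVar("R")
Rows = List[Tuple[Any, ...]]


@overload
def execute(decoder: Literal[None], query: str) -> Rows:
    ...


@overload
def execute(decoder: Callable[[Rows], R], query: str) -> R:
    ...


def execute(decoder: Optional[Callable[[Rows], Any]], query: str) -> Any:
    rows: Rows = [("event", 1)]  # stand-in for a real cursor fetch
    return decoder(rows) if decoder is not None else rows


raw = execute(None, "SELECT ...")   # mypy infers List[Tuple[Any, ...]]
count = execute(len, "SELECT ...")  # mypy infers int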
46 changes: 28 additions & 18 deletions synapse/storage/databases/main/stream.py
@@ -39,7 +39,7 @@
 import abc
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

 from twisted.internet import defer

@@ -54,9 +54,12 @@
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import RoomStreamToken
+from synapse.types import Collection, RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)

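The `TYPE_CHECKING` block makes `HomeServer` available to the type checker without importing `synapse.server` at runtime, which would otherwise create an import cycle; because the import never actually runs, annotations using it must be the quoted forward reference `"HomeServer"`. The shape of the pattern in isolation:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by mypy, never at runtime, so importing a module
    # that (indirectly) imports this one cannot create a cycle.
    from synapse.server import HomeServer


def get_instance_name(hs: "HomeServer") -> str:
    # Quoted annotation: the name only exists for the type checker.
    return hs.get_instance_name()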
@@ -206,7 +209,7 @@ def _make_generic_sql_bound(
     )


-def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
+def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
     # "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -264,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):

     __metaclass__ = abc.ABCMeta

-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super(StreamWorkerStore, self).__init__(database, db_conn, hs)

         self._instance_name = hs.get_instance_name()
@@ -297,16 +300,16 @@ def __init__(self, database: DatabasePool, db_conn, hs):
         self._stream_order_on_start = self.get_room_max_stream_ordering()

     @abc.abstractmethod
-    def get_room_max_stream_ordering(self):
+    def get_room_max_stream_ordering(self) -> int:
         raise NotImplementedError()

     @abc.abstractmethod
-    def get_room_min_stream_ordering(self):
+    def get_room_min_stream_ordering(self) -> int:
         raise NotImplementedError()

     async def get_room_events_stream_for_rooms(
         self,
-        room_ids: Iterable[str],
+        room_ids: Collection[str],
         from_key: str,
         to_key: str,
         limit: int = 0,
@@ -360,19 +363,21 @@ async def get_room_events_stream_for_rooms(

         return results

-    def get_rooms_that_changed(self, room_ids, from_key):
+    def get_rooms_that_changed(
+        self, room_ids: Collection[str], from_key: str
+    ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.

         Args:
-            room_ids (list)
-            from_key (str): The room_key portion of a StreamToken
+            room_ids
+            from_key: The room_key portion of a StreamToken
         """
-        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
         return {
             room_id
             for room_id in room_ids
-            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+            if self._events_stream_cache.has_entity_changed(room_id, from_id)
         }

     async def get_room_events_stream_for_room(
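In `get_rooms_that_changed`, binding the parsed position to a new name `from_id`, instead of reassigning the `str` parameter `from_key` to an `int`, is what lets the annotations hold. Room stream keys are strings of the form "s<stream ordering>"; a rough, self-contained sketch of the parse-then-filter shape, with a plain dict standing in for the `StreamChangeCache`:

from typing import Collection, Dict, Set


def parse_stream_token(key: str) -> int:
    # Simplified stand-in for RoomStreamToken.parse_stream_token(...).stream
    if not key.startswith("s"):
        raise ValueError("not a stream token: %r" % (key,))
    return int(key[1:])


def rooms_that_changed(
    room_ids: Collection[str], from_key: str, last_change: Dict[str, int]
) -> Set[str]:
    from_id = parse_stream_token(from_key)  # new name, new type: int
    return {r for r in room_ids if last_change.get(r, 0) > from_id}


assert rooms_that_changed(["!a:hs", "!b:hs"], "s10", {"!a:hs": 12}) == {"!a:hs"}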
@@ -444,7 +449,9 @@ def f(txn):

         return ret, key

-    async def get_membership_changes_for_user(self, user_id, from_key, to_key):
+    async def get_membership_changes_for_user(
+        self, user_id: str, from_key: str, to_key: str
+    ) -> List[EventBase]:
         from_id = RoomStreamToken.parse_stream_token(from_key).stream
         to_id = RoomStreamToken.parse_stream_token(to_key).stream

@@ -661,7 +668,7 @@ async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
         )
         return row[0][0] if row else 0

-    def _get_max_topological_txn(self, txn, room_id):
+    def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
         txn.execute(
             "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
             (room_id,),
@@ -734,7 +741,7 @@ async def get_events_around(

     def _get_events_around_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         room_id: str,
         event_id: str,
         before_limit: int,
@@ -762,6 +769,9 @@ def _get_events_around_txn(
             retcols=["stream_ordering", "topological_ordering"],
         )

+        # This cannot happen as `allow_none=False`.
+        assert results is not None
+
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
@@ -871,7 +881,7 @@ async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
             desc="update_federation_out_pos",
         )

-    def _reset_federation_positions_txn(self, txn) -> None:
+    def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
         """Fiddles with the `federation_stream_position` table to make it match
         the configured federation sender instances during start up.
         """
@@ -910,7 +920,7 @@ def _reset_federation_positions_txn(self, txn) -> None:
             GROUP BY type
         """
         txn.execute(sql)
-        min_positions = dict(txn)  # Map from type -> min position
+        min_positions = {typ: pos for typ, pos in txn}  # Map from type -> min position

         # Ensure we do actually have some values here
         assert set(min_positions) == {"federation", "events"}
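The `dict(txn)` to comprehension change is behaviour-neutral: iterating a DB-API cursor yields the result rows as tuples, so both spellings build the same type-to-position map, but the comprehension gives mypy concrete key and value types. With plain tuples standing in for cursor rows:

from typing import Dict, List, Tuple

rows: List[Tuple[str, int]] = [("federation", 10), ("events", 7)]

min_positions: Dict[str, int] = {typ: pos for typ, pos in rows}
assert min_positions == dict(rows)  # same mapping, better-typed construction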
@@ -937,7 +947,7 @@ def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:

     def _paginate_room_events_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         room_id: str,
         from_token: RoomStreamToken,
         to_token: Optional[RoomStreamToken] = None,
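Finally, the `assert results is not None` added to `_get_events_around_txn` is the usual way to narrow an `Optional` for mypy when a code path guarantees a value; there, `simple_select_one_txn` is called with `allow_none=False`, so `None` cannot actually come back. The same narrowing in miniature (a hypothetical `lookup` standing in for the Synapse helper):

from typing import Dict, Optional


def lookup(allow_none: bool = False) -> Optional[Dict[str, int]]:
    # Stand-in for simple_select_one_txn: with allow_none=False the real
    # helper raises on a missing row instead of returning None.
    return {"stream_ordering": 42, "topological_ordering": 7}


results = lookup(allow_none=False)
assert results is not None  # narrows Optional[...] to Dict[str, int] for mypy
stream_ordering = results["stream_ordering"]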
