From d96d076c4f1720f844b0f0335b643360c3d4d122 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 22 Sep 2021 11:43:34 +0100 Subject: [PATCH 1/5] Improve typing in user_directory files This makes the user_directory.py in storage pass most of mypy's checks (including `no-untyped-defs`). Unfortunately that file is in the tangled web of Store class inheritance so doesn't pass mypy at the moment. The handlers directory has already been mypyed. --- mypy.ini | 2 + .../storage/databases/main/user_directory.py | 109 ++++++++++++------ tests/handlers/test_user_directory.py | 5 +- 3 files changed, 79 insertions(+), 37 deletions(-) diff --git a/mypy.ini b/mypy.ini index 3cb6cecd7e8c..437d0a46a593 100644 --- a/mypy.ini +++ b/mypy.ini @@ -85,9 +85,11 @@ files = tests/handlers/test_room_summary.py, tests/handlers/test_send_email.py, tests/handlers/test_sync.py, + tests/handlers/test_user_directory.py, tests/rest/client/test_login.py, tests/rest/client/test_auth.py, tests/storage/test_state.py, + tests/storage/test_user_directory.py, tests/util/test_itertools.py, tests/util/test_stream_change_cache.py diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 718f3e99768e..ed837d666c3b 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -14,14 +14,16 @@ import logging import re -from typing import Any, Dict, Iterable, Optional, Set, Tuple +from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, cast from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules -from synapse.storage.database import DatabasePool +from synapse.server import HomeServer +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import get_domain_from_id, get_localpart_from_id +from synapse.storage.types import Connection +from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -36,7 +38,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__( + self, + database: DatabasePool, + db_conn: Connection, + hs: HomeServer, + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -57,10 +64,12 @@ def __init__(self, database: DatabasePool, db_conn, hs): "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) - async def _populate_user_directory_createtables(self, progress, batch_size): + async def _populate_user_directory_createtables( + self, progress: JsonDict, batch_size: int + ) -> int: # Get all the rooms that we want to process. - def _make_staging_area(txn): + def _make_staging_area(txn: LoggingTransaction) -> None: sql = ( "CREATE TABLE IF NOT EXISTS " + TEMP_TABLE @@ -110,16 +119,20 @@ def _make_staging_area(txn): ) return 1 - async def _populate_user_directory_cleanup(self, progress, batch_size): + async def _populate_user_directory_cleanup( + self, + progress: JsonDict, + batch_size: int, + ) -> int: """ Update the user directory stream position, then clean up the old tables. """ position = await self.db_pool.simple_select_one_onecol( - TEMP_TABLE + "_position", None, "position" + TEMP_TABLE + "_position", {}, "position" ) await self.update_user_directory_stream_pos(position) - def _delete_staging_area(txn): + def _delete_staging_area(txn: LoggingTransaction) -> None: txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") @@ -133,18 +146,32 @@ def _delete_staging_area(txn): ) return 1 - async def _populate_user_directory_process_rooms(self, progress, batch_size): + async def _populate_user_directory_process_rooms( + self, progress: JsonDict, batch_size: int + ) -> int: """ + Rescan the state of all rooms so we can track + + - who's in a public room; + - which local users share a private room with other users (local + and remote); and + - who should be in the user_directory. + Args: progress (dict) batch_size (int): Maximum number of state events to process per cycle. + + Returns: + number of events processed. """ # If we don't have progress filed, delete everything. if not progress: await self.delete_all_from_user_dir() - def _get_next_batch(txn): + def _get_next_batch( + txn: LoggingTransaction, + ) -> Optional[Sequence[Tuple[str, int]]]: # Only fetch 250 rooms, so we don't fetch too many at once, even # if those 250 rooms have less than batch_size state events. sql = """ @@ -155,7 +182,7 @@ def _get_next_batch(txn): TEMP_TABLE + "_rooms", ) txn.execute(sql) - rooms_to_work_on = txn.fetchall() + rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall()) if not rooms_to_work_on: return None @@ -163,7 +190,9 @@ def _get_next_batch(txn): # Get how many are left to process, so we can give status on how # far we are in processing txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") - progress["remaining"] = txn.fetchone()[0] + result = txn.fetchone() + assert result is not None + progress["remaining"] = result[0] return rooms_to_work_on @@ -261,29 +290,33 @@ def _get_next_batch(txn): return processed_event_count - async def _populate_user_directory_process_users(self, progress, batch_size): + async def _populate_user_directory_process_users( + self, progress: JsonDict, batch_size: int + ) -> int: """ Add all local users to the user directory. """ - def _get_next_batch(txn): + def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]: sql = "SELECT user_id FROM %s LIMIT %s" % ( TEMP_TABLE + "_users", str(batch_size), ) txn.execute(sql) - users_to_work_on = txn.fetchall() + user_result = cast(List[Tuple[str]], txn.fetchall()) - if not users_to_work_on: + if not user_result: return None - users_to_work_on = [x[0] for x in users_to_work_on] + users_to_work_on = [x[0] for x in user_result] # Get how many are left to process, so we can give status on how # far we are in processing sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users" txn.execute(sql) - progress["remaining"] = txn.fetchone()[0] + count_result = txn.fetchone() + assert count_result is not None + progress["remaining"] = count_result[0] return users_to_work_on @@ -324,7 +357,7 @@ def _get_next_batch(txn): return len(users_to_work_on) - async def is_room_world_readable_or_publicly_joinable(self, room_id): + async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: """Check if the room is either world_readable or publically joinable""" # Create a state filter that only queries join and history state event @@ -368,7 +401,7 @@ async def update_profile_in_user_dir( if not isinstance(avatar_url, str): avatar_url = None - def _update_profile_in_user_dir_txn(txn): + def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_upsert_txn( txn, table="user_directory", @@ -435,7 +468,7 @@ async def add_users_who_share_private_room( for user_id, other_user_id in user_id_tuples ], value_names=(), - value_values=None, + value_values=(), desc="add_users_who_share_room", ) @@ -454,14 +487,14 @@ async def add_users_in_public_rooms( key_names=["user_id", "room_id"], key_values=[(user_id, room_id) for user_id in user_ids], value_names=(), - value_values=None, + value_values=(), desc="add_users_in_public_rooms", ) async def delete_all_from_user_dir(self) -> None: """Delete the entire user directory""" - def _delete_all_from_user_dir_txn(txn): + def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None: txn.execute("DELETE FROM user_directory") txn.execute("DELETE FROM user_directory_search") txn.execute("DELETE FROM users_in_public_rooms") @@ -473,7 +506,7 @@ def _delete_all_from_user_dir_txn(txn): ) @cached() - async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: + async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]: return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, @@ -497,7 +530,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__( + self, database: DatabasePool, db_conn: Connection, hs: HomeServer + ) -> None: super().__init__(database, db_conn, hs) self._prefer_local_users_in_search = ( @@ -506,7 +541,7 @@ def __init__(self, database: DatabasePool, db_conn, hs): self._server_name = hs.config.server.server_name async def remove_from_user_dir(self, user_id: str) -> None: - def _remove_from_user_dir_txn(txn): + def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) @@ -532,7 +567,7 @@ def _remove_from_user_dir_txn(txn): "remove_from_user_dir", _remove_from_user_dir_txn ) - async def get_users_in_dir_due_to_room(self, room_id): + async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]: """Get all user_ids that are in the room directory because they're in the given room_id """ @@ -565,7 +600,7 @@ async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None: room_id """ - def _remove_user_who_share_room_txn(txn): + def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="users_who_share_private_rooms", @@ -586,7 +621,7 @@ def _remove_user_who_share_room_txn(txn): "remove_user_who_share_room", _remove_user_who_share_room_txn ) - async def get_user_dir_rooms_user_is_in(self, user_id): + async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]: """ Returns the rooms that a user is in. @@ -628,7 +663,9 @@ async def get_shared_rooms_for_users( A set of room ID's that the users share. """ - def _get_shared_rooms_for_users_txn(txn): + def _get_shared_rooms_for_users_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, str]]: txn.execute( """ SELECT p1.room_id @@ -669,7 +706,9 @@ async def get_user_directory_stream_pos(self) -> Optional[int]: desc="get_user_directory_stream_pos", ) - async def search_user_dir(self, user_id, search_term, limit): + async def search_user_dir( + self, user_id: str, search_term: str, limit: int + ) -> JsonDict: """Searches for users in directory Returns: @@ -705,7 +744,7 @@ async def search_user_dir(self, user_id, search_term, limit): # We allow manipulating the ranking algorithm by injecting statements # based on config options. additional_ordering_statements = [] - ordering_arguments = () + ordering_arguments: Tuple[str, ...] = () if isinstance(self.database_engine, PostgresEngine): full_query, exact_query, prefix_query = _parse_query_postgres(search_term) @@ -811,7 +850,7 @@ async def search_user_dir(self, user_id, search_term, limit): return {"limited": limited, "results": results} -def _parse_query_sqlite(search_term): +def _parse_query_sqlite(search_term: str) -> str: """Takes a plain unicode string from the user and converts it into a form that can be passed to database. We use this so that we can add prefix matching, which isn't something @@ -826,7 +865,7 @@ def _parse_query_sqlite(search_term): return " & ".join("(%s* OR %s)" % (result, result) for result in results) -def _parse_query_postgres(search_term): +def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: """Takes a plain unicode string from the user and converts it into a form that can be passed to database. We use this so that we can add prefix matching, which isn't something diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index ae88ed89aad8..2156d9baaf3a 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple from unittest.mock import Mock from twisted.internet import defer @@ -285,7 +286,7 @@ def _compress_shared(self, shared): r.add((i["user_id"], i["other_user_id"], i["room_id"])) return r - def get_users_in_public_rooms(self): + def get_users_in_public_rooms(self) -> List[Tuple[str, str]]: r = self.get_success( self.store.db_pool.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") @@ -296,7 +297,7 @@ def get_users_in_public_rooms(self): retval.append((i["user_id"], i["room_id"])) return retval - def get_users_who_share_private_rooms(self): + def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]: return self.get_success( self.store.db_pool.simple_select_list( "users_who_share_private_rooms", From 61ccaf36232abd5df34617516698a0979badb1e2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 23 Sep 2021 13:40:22 +0100 Subject: [PATCH 2/5] Changelog --- changelog.d/10891.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/10891.misc diff --git a/changelog.d/10891.misc b/changelog.d/10891.misc new file mode 100644 index 000000000000..621ee0bc7470 --- /dev/null +++ b/changelog.d/10891.misc @@ -0,0 +1 @@ +Improve type hinting in the user_directory code. \ No newline at end of file From 0ff821641abe53531cb25c16b99f3f483947fc49 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 23 Sep 2021 13:47:23 +0100 Subject: [PATCH 3/5] Break typecheck-only circular import --- .../storage/databases/main/user_directory.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index ed837d666c3b..ca9799aa2579 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -14,10 +14,22 @@ import logging import re -from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, cast +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + cast, +) + +if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules -from synapse.server import HomeServer from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore @@ -531,7 +543,10 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SHARE_PRIVATE_WORKING_SET = 500 def __init__( - self, database: DatabasePool, db_conn: Connection, hs: HomeServer + self, + database: DatabasePool, + db_conn: Connection, + hs: HomeServer, ) -> None: super().__init__(database, db_conn, hs) From f15e9ad2f8cd2cc2b055bf7d1ca4b4fd0786050e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 23 Sep 2021 13:58:29 +0100 Subject: [PATCH 4/5] Don't have PEP 563 postponed annotations yet--3.7+ --- synapse/storage/databases/main/user_directory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index ca9799aa2579..7ca04237a508 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -54,7 +54,7 @@ def __init__( self, database: DatabasePool, db_conn: Connection, - hs: HomeServer, + hs: "HomeServer", ): super().__init__(database, db_conn, hs) @@ -546,7 +546,7 @@ def __init__( self, database: DatabasePool, db_conn: Connection, - hs: HomeServer, + hs: "HomeServer", ) -> None: super().__init__(database, db_conn, hs) From 60bf3f41ebdfeb2e38b01887a84d626f120a6ea6 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 23 Sep 2021 15:54:47 +0100 Subject: [PATCH 5/5] Update changelog.d/10891.misc Co-authored-by: reivilibre --- changelog.d/10891.misc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/10891.misc b/changelog.d/10891.misc index 621ee0bc7470..6eecea4065e9 100644 --- a/changelog.d/10891.misc +++ b/changelog.d/10891.misc @@ -1 +1 @@ -Improve type hinting in the user_directory code. \ No newline at end of file +Improve type hinting in the user directory code. \ No newline at end of file