Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Require type hints in the handlers module. (#10831)
Browse files Browse the repository at this point in the history
Adds missing type hints to methods in the synapse.handlers
module and requires all methods to have type hints there.

This also removes the unused construct_auth_difference method
from the FederationHandler.
  • Loading branch information
clokep authored Sep 20, 2021
1 parent 4379617 commit b359061
Show file tree
Hide file tree
Showing 35 changed files with 194 additions and 295 deletions.
1 change: 1 addition & 0 deletions changelog.d/10831.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to handlers.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

[mypy-synapse.handlers.*]
disallow_untyped_defs = True

[mypy-synapse.rest.*]
disallow_untyped_defs = True

Expand Down
4 changes: 2 additions & 2 deletions synapse/config/password_auth_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List
from typing import Any, List, Tuple, Type

from synapse.util.module_loader import load_module

Expand All @@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
section = "authproviders"

def read_config(self, config, **kwargs):
self.password_providers: List[Any] = []
self.password_providers: List[Tuple[Type, Any]] = []
providers = []

# We want to be backwards compatible with the old `ldap_config`
Expand Down
14 changes: 10 additions & 4 deletions synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TYPE_CHECKING, Optional

from synapse.api.ratelimiting import Ratelimiter
from synapse.types import Requester

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -63,16 +64,21 @@ def __init__(self, hs: "HomeServer"):

self.event_builder_factory = hs.get_event_builder_factory()

async def ratelimit(self, requester, update=True, is_admin_redaction=False):
async def ratelimit(
self,
requester: Requester,
update: bool = True,
is_admin_redaction: bool = False,
) -> None:
"""Ratelimits requests.
Args:
requester (Requester)
update (bool): Whether to record that a request is being processed.
requester
update: Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
is_admin_redaction (bool): Whether this is a room admin/moderator
is_admin_redaction: Whether this is a room admin/moderator
redacting an event. If so then we may apply different
ratelimits depending on config.
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, Any, List, Tuple

from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
Expand Down Expand Up @@ -171,7 +171,7 @@ def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_account_data_stream_id()

async def get_new_events(
self, user: UserID, from_key: int, **kwargs
self, user: UserID, from_key: int, **kwargs: Any
) -> Tuple[List[JsonDict], int]:
user_id = user.to_string()
last_stream_id = from_key
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def register_account_validity_callbacks(
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
):
) -> None:
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self._is_user_expired_callbacks.append(is_user_expired)
Expand Down Expand Up @@ -165,7 +165,7 @@ async def is_user_expired(self, user_id: str) -> bool:

return False

async def on_user_registration(self, user_id: str):
async def on_user_registration(self, user_id: str) -> None:
"""Tell third-party modules about a user's registration.
Args:
Expand Down
18 changes: 9 additions & 9 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union

from prometheus_client import Counter

Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(self, hs: "HomeServer"):
self.current_max = 0
self.is_processing = False

def notify_interested_services(self, max_token: RoomStreamToken):
def notify_interested_services(self, max_token: RoomStreamToken) -> None:
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
Expand All @@ -82,7 +82,7 @@ def notify_interested_services(self, max_token: RoomStreamToken):
self._notify_interested_services(max_token)

@wrap_as_background_process("notify_interested_services")
async def _notify_interested_services(self, max_token: RoomStreamToken):
async def _notify_interested_services(self, max_token: RoomStreamToken) -> None:
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
Expand All @@ -100,7 +100,7 @@ async def _notify_interested_services(self, max_token: RoomStreamToken):
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)

async def handle_event(event):
async def handle_event(event: EventBase) -> None:
# Gather interested services
services = await self._get_services_for_event(event)
if len(services) == 0:
Expand All @@ -116,9 +116,9 @@ async def handle_event(event):

if not self.started_scheduler:

async def start_scheduler():
async def start_scheduler() -> None:
try:
return await self.scheduler.start()
await self.scheduler.start()
except Exception:
logger.error("Application Services Failure")

Expand All @@ -137,7 +137,7 @@ async def start_scheduler():
"appservice_sender"
).observe((now - ts) / 1000)

async def handle_room_events(events):
async def handle_room_events(events: Iterable[EventBase]) -> None:
for event in events:
await handle_event(event)

Expand Down Expand Up @@ -184,7 +184,7 @@ def notify_interested_services_ephemeral(
stream_key: str,
new_token: Optional[int],
users: Optional[Collection[Union[str, UserID]]] = None,
):
) -> None:
"""This is called by the notifier in the background
when a ephemeral event handled by the homeserver.
Expand Down Expand Up @@ -226,7 +226,7 @@ async def _notify_interested_services_ephemeral(
stream_key: str,
new_token: Optional[int],
users: Collection[Union[str, UserID]],
):
) -> None:
logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
Expand Down
45 changes: 24 additions & 21 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
)
Expand Down Expand Up @@ -439,7 +440,7 @@ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:

return ui_auth_types

def get_enabled_auth_types(self):
def get_enabled_auth_types(self) -> Iterable[str]:
"""Return the enabled user-interactive authentication types
Returns the UI-Auth types which are supported by the homeserver's current
Expand Down Expand Up @@ -702,7 +703,7 @@ async def get_session_data(
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

async def _expire_old_sessions(self):
async def _expire_old_sessions(self) -> None:
"""
Invalidate any user interactive authentication sessions that have expired.
"""
Expand Down Expand Up @@ -1352,7 +1353,7 @@ async def validate_short_term_login_token(
await self.auth.check_auth_blocking(res.user_id)
return res

async def delete_access_token(self, access_token: str):
async def delete_access_token(self, access_token: str) -> None:
"""Invalidate a single access token
Args:
Expand Down Expand Up @@ -1381,7 +1382,7 @@ async def delete_access_tokens_for_user(
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
):
) -> None:
"""Invalidate access tokens belonging to a user
Args:
Expand Down Expand Up @@ -1409,7 +1410,7 @@ async def delete_access_tokens_for_user(

async def add_threepid(
self, user_id: str, medium: str, address: str, validated_at: int
):
) -> None:
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
Expand Down Expand Up @@ -1480,7 +1481,7 @@ async def hash(self, password: str) -> str:
Hashed password.
"""

def _do_hash():
def _do_hash() -> str:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)

Expand All @@ -1504,7 +1505,7 @@ async def validate_hash(
Whether self.hash(password) == stored_hash.
"""

def _do_validate_hash(checked_hash: bytes):
def _do_validate_hash(checked_hash: bytes) -> bool:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)

Expand Down Expand Up @@ -1581,7 +1582,7 @@ async def complete_sso_login(
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
):
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
Args:
Expand Down Expand Up @@ -1627,7 +1628,7 @@ def _complete_sso_login(
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
):
) -> None:
"""
The synchronous portion of complete_sso_login.
Expand Down Expand Up @@ -1726,17 +1727,17 @@ def _expire_sso_extra_attributes(self) -> None:
del self._extra_attributes[user_id]

@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
url_parts = list(urllib.parse.urlparse(url))
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
query.append((param_name, param))
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)


@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class MacaroonGenerator:
hs = attr.ib()
hs: "HomeServer"

def generate_guest_access_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
Expand Down Expand Up @@ -1816,15 +1817,17 @@ class PasswordProvider:
"""

@classmethod
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
def load(
cls, module: Type, config: JsonDict, module_api: ModuleApi
) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
return cls(pp, module_api)

def __init__(self, pp, module_api: ModuleApi):
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
self._pp = pp
self._module_api = module_api

Expand All @@ -1838,7 +1841,7 @@ def __init__(self, pp, module_api: ModuleApi):
if g:
self._supported_login_types.update(g())

def __str__(self):
def __str__(self) -> str:
return str(self._pp)

def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
Expand Down Expand Up @@ -1876,19 +1879,19 @@ async def check_auth(
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
g = getattr(self._pp, "check_password", None)
if g:
check_password = getattr(self._pp, "check_password", None)
if check_password:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await self._pp.check_password(
is_valid = await check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None

g = getattr(self._pp, "check_auth", None)
if not g:
check_auth = getattr(self._pp, "check_auth", None)
if not check_auth:
return None
result = await g(username, login_type, login_dict)
result = await check_auth(username, login_type, login_dict)

# Check if the return value is a str or a tuple
if isinstance(result, str):
Expand Down
18 changes: 8 additions & 10 deletions synapse/handlers/cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@
class CasError(Exception):
"""Used to catch errors when validating the CAS ticket."""

def __init__(self, error, error_description=None):
def __init__(self, error: str, error_description: Optional[str] = None):
self.error = error
self.error_description = error_description

def __str__(self):
def __str__(self) -> str:
if self.error_description:
return f"{self.error}: {self.error_description}"
return self.error


@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
username: str
attributes: Dict[str, List[Optional[str]]]


class CasHandler:
Expand Down Expand Up @@ -133,11 +133,9 @@ async def _validate_ticket(
body = pde.response
except HttpResponseException as e:
description = (
(
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code),
)
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code)
raise CasError("server_error", description) from e

return self._parse_cas_response(body)
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __init__(self, hs: "HomeServer"):

hs.get_distributor().observe("user_left_room", self.user_left_room)

def _check_device_name_length(self, name: Optional[str]):
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
Expand Down
Loading

0 comments on commit b359061

Please sign in to comment.