Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add helper to parse an enum from query args & use it. (#14956)
Browse files Browse the repository at this point in the history
The `parse_enum` helper pulls an enum value from the query string
(by delegating down to the parse_string helper with values generated
from the enum).

This is used to pull out "f" and "b" in most places and then we thread
the resulting Direction enum throughout more code.
  • Loading branch information
clokep authored Feb 1, 2023
1 parent 230a831 commit 1182ae5
Show file tree
Hide file tree
Showing 25 changed files with 176 additions and 96 deletions.
1 change: 1 addition & 0 deletions changelog.d/14956.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints.
15 changes: 10 additions & 5 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import attr
from prometheus_client import Counter

from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.errors import (
CodeMessageException,
Codes,
Expand Down Expand Up @@ -1680,7 +1680,12 @@ async def send_request(
return result

async def timestamp_to_event(
self, *, destinations: List[str], room_id: str, timestamp: int, direction: str
self,
*,
destinations: List[str],
room_id: str,
timestamp: int,
direction: Direction,
) -> Optional["TimestampToEventResponse"]:
"""
Calls each remote federating server from `destinations` asking for their closest
Expand All @@ -1693,7 +1698,7 @@ async def timestamp_to_event(
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
Expand Down Expand Up @@ -1738,7 +1743,7 @@ async def _timestamp_to_event_from_destination(
return None

async def _timestamp_to_event_from_destination(
self, destination: str, room_id: str, timestamp: int, direction: str
self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> "TimestampToEventResponse":
"""
Calls a remote federating server at `destination` asking for their
Expand All @@ -1751,7 +1756,7 @@ async def _timestamp_to_event_from_destination(
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
Expand Down
12 changes: 9 additions & 3 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@
from twisted.internet.abstract import isIPAddress
from twisted.python import failure

from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership
from synapse.api.constants import (
Direction,
EduTypes,
EventContentFields,
EventTypes,
Membership,
)
from synapse.api.errors import (
AuthError,
Codes,
Expand Down Expand Up @@ -218,7 +224,7 @@ async def on_backfill_request(
return 200, res

async def on_timestamp_to_event_request(
self, origin: str, room_id: str, timestamp: int, direction: str
self, origin: str, room_id: str, timestamp: int, direction: Direction
) -> Tuple[int, Dict[str, Any]]:
"""When we receive a federated `/timestamp_to_event` request,
handle all of the logic for validating and fetching the event.
Expand All @@ -228,7 +234,7 @@ async def on_timestamp_to_event_request(
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
Expand Down
8 changes: 4 additions & 4 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import attr
import ijson

from synapse.api.constants import Membership
from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
Expand Down Expand Up @@ -169,7 +169,7 @@ async def backfill(
)

async def timestamp_to_event(
self, destination: str, room_id: str, timestamp: int, direction: str
self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> Union[JsonDict, List]:
"""
Calls a remote federating server at `destination` asking for their
Expand All @@ -180,7 +180,7 @@ async def timestamp_to_event(
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
Expand All @@ -194,7 +194,7 @@ async def timestamp_to_event(
room_id,
)

args = {"ts": [str(timestamp)], "dir": [direction]}
args = {"ts": [str(timestamp)], "dir": [direction.value]}

remote_response = await self.client.get_json(
destination, path=path, args=args, try_trailing_slash_on_400=True
Expand Down
7 changes: 4 additions & 3 deletions synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from typing_extensions import Literal

from synapse.api.constants import EduTypes
from synapse.api.constants import Direction, EduTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX
Expand Down Expand Up @@ -234,9 +234,10 @@ async def on_GET(
room_id: str,
) -> Tuple[int, JsonDict]:
timestamp = parse_integer_from_args(query, "ts", required=True)
direction = parse_string_from_args(
query, "dir", default="f", allowed_values=["f", "b"], required=True
direction_str = parse_string_from_args(
query, "dir", allowed_values=["f", "b"], required=True
)
direction = Direction(direction_str)

return await self.handler.on_timestamp_to_event_request(
origin, room_id, timestamp, direction
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main

def get_current_key(self, direction: str = "f") -> int:
def get_current_key(self) -> int:
return self.store.get_max_account_data_stream_id()

async def get_new_events(
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,5 +315,5 @@ async def get_new_events_as(

return events, to_key

def get_current_key(self, direction: str = "f") -> int:
def get_current_key(self) -> int:
return self.store.get_max_receipt_stream_id()
9 changes: 5 additions & 4 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import synapse.events.snapshot
from synapse.api.constants import (
Direction,
EventContentFields,
EventTypes,
GuestAccess,
Expand Down Expand Up @@ -1487,7 +1488,7 @@ async def get_event_for_timestamp(
requester: Requester,
room_id: str,
timestamp: int,
direction: str,
direction: Direction,
) -> Tuple[str, int]:
"""Find the closest event to the given timestamp in the given direction.
If we can't find an event locally or the event we have locally is next to a gap,
Expand All @@ -1498,7 +1499,7 @@ async def get_event_for_timestamp(
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
Expand Down Expand Up @@ -1533,13 +1534,13 @@ async def get_event_for_timestamp(
local_event_id, allow_none=False, allow_rejected=False
)

if direction == "f":
if direction == Direction.FORWARDS:
# We only need to check for a backward gap if we're looking forwards
# to ensure there is nothing in between.
is_event_next_to_backward_gap = (
await self.store.is_event_next_to_backward_gap(local_event)
)
elif direction == "b":
elif direction == Direction.BACKWARDS:
# We only need to check for a forward gap if we're looking backwards
# to ensure there is nothing in between
is_event_next_to_forward_gap = (
Expand Down
70 changes: 70 additions & 0 deletions synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

""" This module contains base REST classes for constructing REST servlets. """
import enum
import logging
from http import HTTPStatus
from typing import (
Expand Down Expand Up @@ -362,6 +363,7 @@ def parse_string(
request: Request,
name: str,
*,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
Expand Down Expand Up @@ -413,6 +415,74 @@ def parse_string(
)


EnumT = TypeVar("EnumT", bound=enum.Enum)


@overload
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
default: EnumT,
) -> EnumT:
...


@overload
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
*,
required: Literal[True],
) -> EnumT:
...


def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
default: Optional[EnumT] = None,
required: bool = False,
) -> Optional[EnumT]:
"""
Parse an enum parameter from the request query string.
Note that the enum *must only have string values*.
Args:
request: the twisted HTTP request.
name: the name of the query parameter.
E: the enum which represents valid values
default: enum value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
Returns:
An enum value.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
# Assert the enum values are strings.
assert all(
isinstance(e.value, str) for e in E
), "parse_enum only works with string values"
str_value = parse_string(
request,
name,
default=default.value if default is not None else None,
required=required,
allowed_values=[e.value for e in E],
)
if str_value is None:
return None
return E(str_value)


def _parse_string_value(
value: bytes,
allowed_values: Optional[Iterable[str]],
Expand Down
12 changes: 3 additions & 9 deletions synapse/rest/admin/event_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
Expand Down Expand Up @@ -60,7 +61,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
direction = parse_string(request, "dir", default="b")
direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS)
user_id = parse_string(request, "user_id")
room_id = parse_string(request, "room_id")

Expand All @@ -78,13 +79,6 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
errcode=Codes.INVALID_PARAM,
)

if direction not in ("f", "b"):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown direction: %s" % (direction,),
errcode=Codes.INVALID_PARAM,
)

event_reports, total = await self.store.get_event_reports_paginate(
start, limit, direction, user_id, room_id
)
Expand Down
7 changes: 4 additions & 3 deletions synapse/rest/admin/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.federation.transport.server import Authenticator
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.storage.databases.main.transactions import DestinationSortOrder
Expand Down Expand Up @@ -79,7 +80,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
allowed_values=[dest.value for dest in DestinationSortOrder],
)

direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)

destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
Expand Down Expand Up @@ -192,7 +193,7 @@ async def on_GET(
errcode=Codes.INVALID_PARAM,
)

direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)

rooms, total = await self._store.get_destination_rooms_paginate(
destination, start, limit, direction
Expand Down
Loading

0 comments on commit 1182ae5

Please sign in to comment.