Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse json validation #16923

Merged
merged 16 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16923.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Return `400 M_NOT_JSON` upon receiving invalid JSON in query parameters across various client and admin endpoints, rather than an internal server error.
82 changes: 82 additions & 0 deletions synapse/http/servlet.py
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import enum
import logging
import urllib.parse as urlparse
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -450,6 +451,87 @@ def parse_string(
)


def parse_json(
request: Request,
name: str,
default: Optional[dict] = None,
required: bool = False,
encoding: str = "ascii",
) -> Optional[JsonDict]:
"""
Parse a JSON parameter from the request query string.

Args:
request: the twisted HTTP request.
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
encoding: The encoding to decode the string content with.

Returns:
A JSON value, or `default` if the named query parameter was not found
and `required` was False.

Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present and not a JSON object.
"""
args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore
return parse_json_from_args(
args,
name,
default,
required=required,
encoding=encoding,
)


def parse_json_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[dict] = None,
required: bool = False,
encoding: str = "ascii",
) -> Optional[JsonDict]:
"""
Parse a JSON parameter from the request query string.

Args:
args: a mapping of request args as bytes to a list of bytes (e.g. request.args).
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
encoding: the encoding to decode the string content with.

A JSON value, or `default` if the named query parameter was not found
and `required` was False.

Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present and not a JSON object.
"""
name_bytes = name.encode("ascii")

if name_bytes not in args:
if not required:
return default

message = f"Missing required integer query parameter {name}"
TrevisGordan marked this conversation as resolved.
Show resolved Hide resolved
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)

json_str = parse_string_from_args(args, name, required=True, encoding=encoding)

try:
return json_decoder.decode(urlparse.unquote(json_str))
except Exception:
message = f"Query parameter {name} must be a valid JSON object"
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.NOT_JSON)


EnumT = TypeVar("EnumT", bound=enum.Enum)


Expand Down
36 changes: 12 additions & 24 deletions synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from urllib import parse as urlparse

import attr

Expand All @@ -38,6 +37,7 @@
assert_params_in_dict,
parse_enum,
parse_integer,
parse_json,
parse_json_object_from_request,
parse_string,
)
Expand All @@ -51,7 +51,6 @@
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, RoomID, ScheduledTask, UserID, create_requester
from synapse.types.state import StateFilter
from synapse.util import json_decoder

if TYPE_CHECKING:
from synapse.api.auth import Auth
Expand Down Expand Up @@ -776,14 +775,8 @@ async def on_GET(
limit = parse_integer(request, "limit", default=10)

# picking the API shape for symmetry with /messages
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None
filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

event_context = await self.room_context_handler.get_event_context(
requester,
Expand Down Expand Up @@ -914,21 +907,16 @@ async def on_GET(
)
# Twisted will have processed the args by now.
assert request.args is not None

filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
== "federation"
):
as_client_event = False
else:
event_filter = None
if (
event_filter
and event_filter.filter_json.get("event_format", "client") == "federation"
):
as_client_event = False

msgs = await self._pagination_handler.get_messages(
room_id=room_id,
Expand Down
35 changes: 12 additions & 23 deletions synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
parse_boolean,
parse_enum,
parse_integer,
parse_json,
parse_json_object_from_request,
parse_string,
parse_strings_from_args,
Expand All @@ -65,7 +66,6 @@
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types.state import StateFilter
from synapse.util import json_decoder
from synapse.util.cancellation import cancellable
from synapse.util.stringutils import parse_and_validate_server_name, random_string

Expand Down Expand Up @@ -703,21 +703,16 @@ async def on_GET(
)
# Twisted will have processed the args by now.
assert request.args is not None

filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
== "federation"
):
as_client_event = False
else:
event_filter = None
if (
event_filter
and event_filter.filter_json.get("event_format", "client") == "federation"
):
as_client_event = False

msgs = await self.pagination_handler.get_messages(
room_id=room_id,
Expand Down Expand Up @@ -898,14 +893,8 @@ async def on_GET(
limit = parse_integer(request, "limit", default=10)

# picking the API shape for symmetry with /messages
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None
filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

event_context = await self.room_context_handler.get_event_context(
requester, room_id, event_id, limit, event_filter
Expand Down
61 changes: 61 additions & 0 deletions tests/rest/admin/test_room.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import json
import time
import urllib.parse
from http import HTTPStatus
from typing import List, Optional
from unittest.mock import AsyncMock, Mock

Expand Down Expand Up @@ -2190,6 +2191,33 @@ def test_room_messages_purge(self) -> None:
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])

def test_room_message_filter_query_validation(self) -> None:
# Test json validation in (filter) query parameter.
# Does not test the validity of the filter, only the json validation.

# Check Get with valid json filter parameter, expect 200.
valid_filter_str = '{"types": ["m.room.message"]}'
channel = self.make_request(
"GET",
f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={valid_filter_str}",
access_token=self.admin_user_tok,
)

self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
invalid_filter_str = "}}}{}"
channel = self.make_request(
"GET",
f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={invalid_filter_str}",
access_token=self.admin_user_tok,
)

self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
)


class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
servlets = [
Expand Down Expand Up @@ -2522,6 +2550,39 @@ def test_context_as_admin(self) -> None:
else:
self.fail("Event %s from events_after not found" % j)

def test_room_event_context_filter_query_validation(self) -> None:
# Test json validation in (filter) query parameter.
# Does not test the validity of the filter, only the json validation.

# Create a user with room and event_id.
user_id = self.register_user("test", "test")
user_tok = self.login("test", "test")
room_id = self.helper.create_room_as(user_id, tok=user_tok)
event_id = self.helper.send(room_id, "message 1", tok=user_tok)["event_id"]

# Check Get with valid json filter parameter, expect 200.
valid_filter_str = '{"types": ["m.room.message"]}'
channel = self.make_request(
"GET",
f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={valid_filter_str}",
access_token=self.admin_user_tok,
)

self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
invalid_filter_str = "}}}{}"
channel = self.make_request(
"GET",
f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={invalid_filter_str}",
access_token=self.admin_user_tok,
)

self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
)


class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
servlets = [
Expand Down
52 changes: 52 additions & 0 deletions tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2175,6 +2175,31 @@ def test_room_messages_purge(self) -> None:
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])

def test_room_message_filter_query_validation(self) -> None:
# Test json validation in (filter) query parameter.
# Does not test the validity of the filter, only the json validation.

# Check Get with valid json filter parameter, expect 200.
valid_filter_str = '{"types": ["m.room.message"]}'
channel = self.make_request(
"GET",
f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={valid_filter_str}",
)

self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
invalid_filter_str = "}}}{}"
channel = self.make_request(
"GET",
f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={invalid_filter_str}",
)

self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
)


class RoomMessageFilterTestCase(RoomBase):
"""Tests /rooms/$room_id/messages REST events."""
Expand Down Expand Up @@ -3213,6 +3238,33 @@ def test_erased_sender(self) -> None:
self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
self.assertEqual(events_after[1].get("content"), {}, events_after[1])

def test_room_event_context_filter_query_validation(self) -> None:
# Test json validation in (filter) query parameter.
# Does not test the validity of the filter, only the json validation.
event_id = self.helper.send(self.room_id, "message 7", tok=self.tok)["event_id"]

# Check Get with valid json filter parameter, expect 200.
valid_filter_str = '{"types": ["m.room.message"]}'
channel = self.make_request(
"GET",
f"/rooms/{self.room_id}/context/{event_id}?filter={valid_filter_str}",
access_token=self.tok,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
invalid_filter_str = "}}}{}"
channel = self.make_request(
"GET",
f"/rooms/{self.room_id}/context/{event_id}?filter={invalid_filter_str}",
access_token=self.tok,
)

self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
)


class RoomAliasListTestCase(unittest.HomeserverTestCase):
servlets = [
Expand Down
Loading