diff --git a/changelog.d/16923.bugfix b/changelog.d/16923.bugfix
new file mode 100644
index 00000000000..bd6f24925ee
--- /dev/null
+++ b/changelog.d/16923.bugfix
@@ -0,0 +1 @@
+Return `400 M_NOT_JSON` upon receiving invalid JSON in query parameters across various client and admin endpoints, rather than an internal server error.
\ No newline at end of file
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 0ca08038f42..ab12951da8e 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -23,6 +23,7 @@
 
 import enum
 import logging
+import urllib.parse as urlparse
 from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
@@ -450,6 +451,96 @@ def parse_string(
     )
 
 
+def parse_json(
+    request: Request,
+    name: str,
+    default: Optional[dict] = None,
+    required: bool = False,
+    encoding: str = "ascii",
+) -> Optional[JsonDict]:
+    """
+    Parse a JSON parameter from the request query string.
+
+    Args:
+        request: the twisted HTTP request.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        encoding: The encoding to decode the string content with.
+
+    Returns:
+        A JSON object, or `default` if the named query parameter was not found
+        and `required` was False.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+        parameter is present and not a JSON object.
+    """
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
+    return parse_json_from_args(
+        args,
+        name,
+        default,
+        required=required,
+        encoding=encoding,
+    )
+
+
+def parse_json_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[dict] = None,
+    required: bool = False,
+    encoding: str = "ascii",
+) -> Optional[JsonDict]:
+    """
+    Parse a JSON parameter from the request query string.
+
+    Args:
+        args: a mapping of request args as bytes to a list of bytes (e.g. request.args).
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        encoding: the encoding to decode the string content with.
+
+    Returns:
+        A JSON object, or `default` if the named query parameter was not found
+        and `required` was False.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+        parameter is present and not a JSON object.
+    """
+    name_bytes = name.encode("ascii")
+
+    if name_bytes not in args:
+        if not required:
+            return default
+
+        message = f"Missing required JSON query parameter {name}"
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
+
+    json_str = parse_string_from_args(args, name, required=True, encoding=encoding)
+
+    try:
+        json_value = json_decoder.decode(urlparse.unquote(json_str))
+    except Exception:
+        message = f"Query parameter {name} must be a valid JSON object"
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.NOT_JSON)
+
+    # Valid JSON that is not an object (e.g. `1` or `"x"`) would otherwise
+    # reach callers that expect a dict and surface as a 500.
+    if not isinstance(json_value, dict):
+        message = f"Query parameter {name} must be a valid JSON object"
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON)
+
+    return json_value
+
+
 EnumT = TypeVar("EnumT", bound=enum.Enum)
 
 
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 4252f98a6c3..0d86a4e15f1 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -21,7 +21,6 @@
 import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
-from urllib import parse as urlparse
 
 import attr
 
@@ -38,6 +37,7 @@
     assert_params_in_dict,
     parse_enum,
     parse_integer,
+    parse_json,
     parse_json_object_from_request,
     parse_string,
 )
@@ -51,7 +51,6 @@
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, RoomID, ScheduledTask, UserID, create_requester
 from synapse.types.state import StateFilter
-from synapse.util import json_decoder
 
 if TYPE_CHECKING:
     from synapse.api.auth import Auth
@@ -776,14 +775,8 @@ async def on_GET(
         limit = parse_integer(request, "limit", default=10)
 
         # picking the API shape for symmetry with /messages
-        filter_str = parse_string(request, "filter", encoding="utf-8")
-        if filter_str:
-            filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(
-                self._hs, json_decoder.decode(filter_json)
-            )
-        else:
-            event_filter = None
+        filter_json = parse_json(request, "filter", encoding="utf-8")
+        event_filter = Filter(self._hs, filter_json) if filter_json else None
 
         event_context = await self.room_context_handler.get_event_context(
             requester,
@@ -914,21 +907,16 @@ async def on_GET(
         )
         # Twisted will have processed the args by now.
         assert request.args is not None
+
+        filter_json = parse_json(request, "filter", encoding="utf-8")
+        event_filter = Filter(self._hs, filter_json) if filter_json else None
+
         as_client_event = b"raw" not in request.args
-        filter_str = parse_string(request, "filter", encoding="utf-8")
-        if filter_str:
-            filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(
-                self._hs, json_decoder.decode(filter_json)
-            )
-            if (
-                event_filter
-                and event_filter.filter_json.get("event_format", "client")
-                == "federation"
-            ):
-                as_client_event = False
-        else:
-            event_filter = None
+        if (
+            event_filter
+            and event_filter.filter_json.get("event_format", "client") == "federation"
+        ):
+            as_client_event = False
 
         msgs = await self._pagination_handler.get_messages(
             room_id=room_id,
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 4eeadf8779a..e4c7dd1a583 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -52,6 +52,7 @@
     parse_boolean,
     parse_enum,
     parse_integer,
+    parse_json,
     parse_json_object_from_request,
     parse_string,
     parse_strings_from_args,
@@ -65,7 +66,6 @@
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.types.state import StateFilter
-from synapse.util import json_decoder
 from synapse.util.cancellation import cancellable
 from synapse.util.stringutils import parse_and_validate_server_name, random_string
 
@@ -703,21 +703,16 @@ async def on_GET(
         )
         # Twisted will have processed the args by now.
         assert request.args is not None
+
+        filter_json = parse_json(request, "filter", encoding="utf-8")
+        event_filter = Filter(self._hs, filter_json) if filter_json else None
+
         as_client_event = b"raw" not in request.args
-        filter_str = parse_string(request, "filter", encoding="utf-8")
-        if filter_str:
-            filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(
-                self._hs, json_decoder.decode(filter_json)
-            )
-            if (
-                event_filter
-                and event_filter.filter_json.get("event_format", "client")
-                == "federation"
-            ):
-                as_client_event = False
-        else:
-            event_filter = None
+        if (
+            event_filter
+            and event_filter.filter_json.get("event_format", "client") == "federation"
+        ):
+            as_client_event = False
 
         msgs = await self.pagination_handler.get_messages(
             room_id=room_id,
@@ -898,14 +893,8 @@ async def on_GET(
         limit = parse_integer(request, "limit", default=10)
 
         # picking the API shape for symmetry with /messages
-        filter_str = parse_string(request, "filter", encoding="utf-8")
-        if filter_str:
-            filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(
-                self._hs, json_decoder.decode(filter_json)
-            )
-        else:
-            event_filter = None
+        filter_json = parse_json(request, "filter", encoding="utf-8")
+        event_filter = Filter(self._hs, filter_json) if filter_json else None
 
         event_context = await self.room_context_handler.get_event_context(
             requester, room_id, event_id, limit, event_filter
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 0b669b6ee7c..75627472605 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -21,6 +21,7 @@
 import json
 import time
 import urllib.parse
+from http import HTTPStatus
 from typing import List, Optional
 from unittest.mock import AsyncMock, Mock
 
@@ -2190,6 +2191,33 @@ def test_room_messages_purge(self) -> None:
         chunk = channel.json_body["chunk"]
         self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
 
+    def test_room_message_filter_query_validation(self) -> None:
+        # Test json validation in (filter) query parameter.
+        # Does not test the validity of the filter, only the json validation.
+
+        # Check Get with valid json filter parameter, expect 200.
+        valid_filter_str = '{"types": ["m.room.message"]}'
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={valid_filter_str}",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
+        invalid_filter_str = "}}}{}"
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={invalid_filter_str}",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
+        )
+
 
 class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
     servlets = [
@@ -2522,6 +2550,39 @@ def test_context_as_admin(self) -> None:
             else:
                 self.fail("Event %s from events_after not found" % j)
 
+    def test_room_event_context_filter_query_validation(self) -> None:
+        # Test json validation in (filter) query parameter.
+        # Does not test the validity of the filter, only the json validation.
+
+        # Create a user with room and event_id.
+        user_id = self.register_user("test", "test")
+        user_tok = self.login("test", "test")
+        room_id = self.helper.create_room_as(user_id, tok=user_tok)
+        event_id = self.helper.send(room_id, "message 1", tok=user_tok)["event_id"]
+
+        # Check Get with valid json filter parameter, expect 200.
+        valid_filter_str = '{"types": ["m.room.message"]}'
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={valid_filter_str}",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
+        invalid_filter_str = "}}}{}"
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={invalid_filter_str}",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
+        )
+
 
 class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 1364615085e..b796163dcbb 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -2175,6 +2175,31 @@ def test_room_messages_purge(self) -> None:
         chunk = channel.json_body["chunk"]
         self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
 
+    def test_room_message_filter_query_validation(self) -> None:
+        # Test json validation in (filter) query parameter.
+        # Does not test the validity of the filter, only the json validation.
+
+        # Check Get with valid json filter parameter, expect 200.
+        valid_filter_str = '{"types": ["m.room.message"]}'
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={valid_filter_str}",
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
+        invalid_filter_str = "}}}{}"
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={invalid_filter_str}",
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
+        )
+
 
 class RoomMessageFilterTestCase(RoomBase):
     """Tests /rooms/$room_id/messages REST events."""
@@ -3213,6 +3238,33 @@ def test_erased_sender(self) -> None:
         self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
         self.assertEqual(events_after[1].get("content"), {}, events_after[1])
 
+    def test_room_event_context_filter_query_validation(self) -> None:
+        # Test json validation in (filter) query parameter.
+        # Does not test the validity of the filter, only the json validation.
+        event_id = self.helper.send(self.room_id, "message 7", tok=self.tok)["event_id"]
+
+        # Check Get with valid json filter parameter, expect 200.
+        valid_filter_str = '{"types": ["m.room.message"]}'
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room_id}/context/{event_id}?filter={valid_filter_str}",
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
+        invalid_filter_str = "}}}{}"
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room_id}/context/{event_id}?filter={invalid_filter_str}",
+            access_token=self.tok,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
+        )
+
 
 class RoomAliasListTestCase(unittest.HomeserverTestCase):
     servlets = [