Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse json validation #16923

Merged
merged 16 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog.d/16923.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Adds parse_json servlet function for standardized JSON parsing from query parameters, ensuring enhanced data validation and error handling.
Introduces INVALID_PARAM error response for invalid JSON objects, improving parameter validation feedback.
Adds validation check to prevent 500 internal server error on invalid Json Filter request.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally we try to keep changelog entries short, and acknowledge that the audience is system administrators. Such an audience won't care to know the details of the implementation, but rather the user-facing impact. My suggestion would be:

Suggested change
Adds parse_json servlet function for standardized JSON parsing from query parameters, ensuring enhanced data validation and error handling.
Introduces INVALID_PARAM error response for invalid JSON objects, improving parameter validation feedback.
Adds validation check to prevent 500 internal server error on invalid Json Filter request.
Return `400 M_INVALID_PARAM` upon receiving invalid JSON in query parameters across various client and admin endpoints, rather than an internal server error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the insights! I'll keep that in mind.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you happy to accept my suggestion here? I prefer the suggested version of the changelog for the reasons in my initial comment.

82 changes: 82 additions & 0 deletions synapse/http/servlet.py
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
""" This module contains base REST classes for constructing REST servlets. """
import enum
import logging
import urllib.parse as urlparse
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -428,6 +429,87 @@ def parse_string(
)


def parse_json(
    request: Request,
    name: str,
    default: Optional[dict] = None,
    required: bool = False,
    encoding: str = "ascii",
) -> Optional[JsonDict]:
    """
    Read a JSON-encoded query parameter from a request.

    This is a convenience wrapper: it takes the query arguments from
    `request.args` and hands them to `parse_json_from_args`.

    Args:
        request: the twisted HTTP request.
        name: the name of the query parameter.
        default: value to return if the parameter is absent and not
            required, defaults to None.
        required: whether to raise a 400 SynapseError if the parameter
            is absent, defaults to False.
        encoding: the encoding to decode the string content with.

    Returns:
        Whatever `parse_json_from_args` returns: the decoded JSON, or
        `default` if the named query parameter was not found and
        `required` was False.

    Raises:
        SynapseError if the parameter is absent and required, or if the
        parameter is present but could not be decoded.
    """
    query_args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
    decoded = parse_json_from_args(
        query_args,
        name,
        default,
        required=required,
        encoding=encoding,
    )
    return decoded


def parse_json_from_args(
    args: Mapping[bytes, Sequence[bytes]],
    name: str,
    default: Optional[dict] = None,
    required: bool = False,
    encoding: str = "ascii",
) -> Optional[JsonDict]:
    """
    Parse a JSON parameter from the request query string.

    Args:
        args: a mapping of request args as bytes to a list of bytes (e.g. request.args).
        name: the name of the query parameter.
        default: value to use if the parameter is absent,
            defaults to None.
        required: whether to raise a 400 SynapseError if the
            parameter is absent, defaults to False.
        encoding: the encoding to decode the string content with.

    Returns:
        The decoded JSON object, or `default` if the named query parameter
        was not found and `required` was False.

    Raises:
        SynapseError (400, MISSING_PARAM) if the parameter is absent and
        required.
        SynapseError (400, INVALID_PARAM) if the parameter is present and
        not a JSON object.
    """
    name_bytes = name.encode("ascii")

    if name_bytes not in args:
        if not required:
            return default

        message = f"Missing required JSON query parameter {name}"
        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)

    json_str = parse_string_from_args(args, name, required=True, encoding=encoding)

    invalid_message = f"Query parameter {name} must be a valid JSON object"

    try:
        # The value may still be percent-encoded, so unquote before decoding.
        decoded = json_decoder.decode(urlparse.unquote(json_str))
    except Exception:
        # Deliberately broad: any failure to decode this client-supplied
        # value should be reported as a 400 rather than surfacing as an
        # internal server error.
        raise SynapseError(
            HTTPStatus.BAD_REQUEST, invalid_message, errcode=Codes.INVALID_PARAM
        )

    # Well-formed JSON that is not an object (e.g. a list, string or null)
    # does not satisfy the documented contract or the JsonDict return type.
    if not isinstance(decoded, dict):
        raise SynapseError(
            HTTPStatus.BAD_REQUEST, invalid_message, errcode=Codes.INVALID_PARAM
        )

    return decoded
TrevisGordan marked this conversation as resolved.
Show resolved Hide resolved


EnumT = TypeVar("EnumT", bound=enum.Enum)


Expand Down
34 changes: 12 additions & 22 deletions synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
assert_params_in_dict,
parse_enum,
parse_integer,
parse_json,
parse_json_object_from_request,
parse_string,
)
Expand Down Expand Up @@ -776,14 +777,8 @@ async def on_GET(
limit = parse_integer(request, "limit", default=10)

# picking the API shape for symmetry with /messages
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None
filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

event_context = await self.room_context_handler.get_event_context(
requester,
Expand Down Expand Up @@ -914,21 +909,16 @@ async def on_GET(
)
# Twisted will have processed the args by now.
assert request.args is not None

filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
== "federation"
):
as_client_event = False
else:
event_filter = None
if (
event_filter
and event_filter.filter_json.get("event_format", "client") == "federation"
):
as_client_event = False

msgs = await self._pagination_handler.get_messages(
room_id=room_id,
Expand Down
35 changes: 12 additions & 23 deletions synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
parse_boolean,
parse_enum,
parse_integer,
parse_json,
parse_json_object_from_request,
parse_string,
parse_strings_from_args,
Expand All @@ -65,7 +66,6 @@
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types.state import StateFilter
from synapse.util import json_decoder
from synapse.util.cancellation import cancellable
from synapse.util.stringutils import parse_and_validate_server_name, random_string

Expand Down Expand Up @@ -703,21 +703,16 @@ async def on_GET(
)
# Twisted will have processed the args by now.
assert request.args is not None

filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
== "federation"
):
as_client_event = False
else:
event_filter = None
if (
event_filter
and event_filter.filter_json.get("event_format", "client") == "federation"
):
as_client_event = False

msgs = await self.pagination_handler.get_messages(
room_id=room_id,
Expand Down Expand Up @@ -898,14 +893,8 @@ async def on_GET(
limit = parse_integer(request, "limit", default=10)

# picking the API shape for symmetry with /messages
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None
filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

event_context = await self.room_context_handler.get_event_context(
requester, room_id, event_id, limit, event_filter
Expand Down
61 changes: 61 additions & 0 deletions tests/rest/admin/test_room.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import json
import time
import urllib.parse
from http import HTTPStatus
from typing import List, Optional
from unittest.mock import AsyncMock, Mock

Expand Down Expand Up @@ -2190,6 +2191,33 @@ def test_room_messages_purge(self) -> None:
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])

def test_room_message_filter_query_validation(self) -> None:
# Test json validation in (filter) query parameter.
# Does not test the validity of the filter, only the json validation.

# Check Get with valid json filter parameter, expect 200.
_valid_filter_str = '{"types": ["m.room.message"]}'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally you should only precede a variable name with an underscore in Python if you'd like to label the output of a function, but not actually use it. Here we are using _valid_filter_str, so it should not have a leading underscore.

Could you remove it, along with _invalid_filter_str and from other tests please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh took me a bit! But now I think I know what you mean.
You are referring to the use of a single underscore, _, as a throwaway variable, commonly used for temporary or insignificant values, as in:

for _ in range(32):
    print('Hello, World.')

this would be correct. But in this case, following PEP 8, _single_leading_underscore signals internal use, here serving as a minor internal string helper. It marks variables that are temporary or specific to this test's context, distinguishing between main test logic and setup details. However, I'm also more than happy to adjust this for you too! ;) Always wanted to cite a PEP tho. 😉

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for citing the PEP! It's interesting to read where this convention came from.

In Synapse, we certainly do use leading underscores for internal function/method names and private class variables (self._internal_var). This is to signal to code external to classes that they shouldn't try to access this field (it is internal).

However, we don't use this convention for local variable names - so at least for this codebase, I would remove the leading underscore.

channel = self.make_request(
"GET",
f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={_valid_filter_str}",
access_token=self.admin_user_tok,
)

self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

# Check Get with invalid json filter parameter, expect 400 INVALID_PARAM.
_invalid_filter_str = "}}}{}"
channel = self.make_request(
"GET",
f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={_invalid_filter_str}",
access_token=self.admin_user_tok,
)

self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.INVALID_PARAM, channel.json_body
)


class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
servlets = [
Expand Down Expand Up @@ -2522,6 +2550,39 @@ def test_context_as_admin(self) -> None:
else:
self.fail("Event %s from events_after not found" % j)

def test_room_event_context_filter_query_validation(self) -> None:
    # Exercises JSON validation of the `filter` query parameter on the admin
    # room context endpoint. Only whether the value is valid JSON is checked
    # here, not the validity of the filter itself.

    # Set up a user with a room and one event in it.
    user_id = self.register_user("test", "test")
    user_tok = self.login("test", "test")
    room_id = self.helper.create_room_as(user_id, tok=user_tok)
    send_response = self.helper.send(room_id, "message 1", tok=user_tok)
    event_id = send_response["event_id"]

    context_url = f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}"

    # A GET with a valid JSON filter should succeed with 200.
    valid_filter_str = '{"types": ["m.room.message"]}'
    channel = self.make_request(
        "GET",
        f"{context_url}?filter={valid_filter_str}",
        access_token=self.admin_user_tok,
    )
    self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

    # A GET with an invalid JSON filter should be rejected with
    # 400 INVALID_PARAM.
    invalid_filter_str = "}}}{}"
    channel = self.make_request(
        "GET",
        f"{context_url}?filter={invalid_filter_str}",
        access_token=self.admin_user_tok,
    )
    self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
    self.assertEqual(
        channel.json_body["errcode"], Codes.INVALID_PARAM, channel.json_body
    )


class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
servlets = [
Expand Down
52 changes: 52 additions & 0 deletions tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2157,6 +2157,31 @@ def test_room_messages_purge(self) -> None:
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])

def test_room_message_filter_query_validation(self) -> None:
    # Exercises JSON validation of the `filter` query parameter on the
    # client /messages endpoint. Only whether the value is valid JSON is
    # checked here, not the validity of the filter itself.
    messages_url = f"/rooms/{self.room_id}/messages?access_token=x&dir=b"

    # A GET with a valid JSON filter should succeed with 200.
    valid_filter_str = '{"types": ["m.room.message"]}'
    channel = self.make_request(
        "GET",
        f"{messages_url}&filter={valid_filter_str}",
    )
    self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

    # A GET with an invalid JSON filter should be rejected with
    # 400 INVALID_PARAM.
    invalid_filter_str = "}}}{}"
    channel = self.make_request(
        "GET",
        f"{messages_url}&filter={invalid_filter_str}",
    )
    self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
    self.assertEqual(
        channel.json_body["errcode"], Codes.INVALID_PARAM, channel.json_body
    )


class RoomMessageFilterTestCase(RoomBase):
"""Tests /rooms/$room_id/messages REST events."""
Expand Down Expand Up @@ -3195,6 +3220,33 @@ def test_erased_sender(self) -> None:
self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
self.assertEqual(events_after[1].get("content"), {}, events_after[1])

def test_room_event_context_filter_query_validation(self) -> None:
    # Exercises JSON validation of the `filter` query parameter on the
    # client room context endpoint. Only whether the value is valid JSON is
    # checked here, not the validity of the filter itself.
    send_response = self.helper.send(self.room_id, "message 7", tok=self.tok)
    event_id = send_response["event_id"]

    context_url = f"/rooms/{self.room_id}/context/{event_id}"

    # A GET with a valid JSON filter should succeed with 200.
    valid_filter_str = '{"types": ["m.room.message"]}'
    channel = self.make_request(
        "GET",
        f"{context_url}?filter={valid_filter_str}",
        access_token=self.tok,
    )
    self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

    # A GET with an invalid JSON filter should be rejected with
    # 400 INVALID_PARAM.
    invalid_filter_str = "}}}{}"
    channel = self.make_request(
        "GET",
        f"{context_url}?filter={invalid_filter_str}",
        access_token=self.tok,
    )
    self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
    self.assertEqual(
        channel.json_body["errcode"], Codes.INVALID_PARAM, channel.json_body
    )


class RoomAliasListTestCase(unittest.HomeserverTestCase):
servlets = [
Expand Down