Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(connection): Fix creation of FormMultiDict in Request.form to properly handle multi-keys #3639

Merged
merged 1 commit into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions litestar/connection/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send =
super().__init__(scope, receive, send)
self.is_connected: bool = True
self._body: bytes | EmptyType = Empty
self._form: dict[str, str | list[str]] | EmptyType = Empty
self._form: FormMultiDict | EmptyType = Empty
self._json: Any = Empty
self._msgpack: Any = Empty
self._content_type: tuple[str, dict[str, str]] | EmptyType = Empty
Expand Down Expand Up @@ -205,26 +205,36 @@ async def form(self) -> FormMultiDict:
A FormMultiDict instance
"""
if self._form is Empty:
if (form := self._connection_state.form) is not Empty:
self._form = form
else:
if (form_data := self._connection_state.form) is Empty:
content_type, options = self.content_type
if content_type == RequestEncodingType.MULTI_PART:
self._form = parse_multipart_form(
form_data = parse_multipart_form(
body=await self.body(),
boundary=options.get("boundary", "").encode(),
multipart_form_part_limit=self.app.multipart_form_part_limit,
)
elif content_type == RequestEncodingType.URL_ENCODED:
self._form = parse_url_encoded_form_data(
form_data = parse_url_encoded_form_data(
await self.body(),
)
else:
self._form = {}

self._connection_state.form = self._form
form_data = {}

self._connection_state.form = form_data

# form_data is a dict[str, list[str] | str | UploadFile]. Convert it to a
# list[tuple[str, str | UploadFile]] before passing it to FormMultiDict so
# multi-keys can be accessed properly
items = []
for k, v in form_data.items():
if isinstance(v, list):
for sv in v:
items.append((k, sv))
else:
items.append((k, v))
self._form = FormMultiDict(items)

return FormMultiDict(self._form)
return self._form

async def send_push_promise(self, path: str, raise_if_unavailable: bool = False) -> None:
"""Send a push promise.
Expand Down
20 changes: 19 additions & 1 deletion tests/unit/test_connection/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import pytest

from litestar import MediaType, Request, asgi, get
from litestar import MediaType, Request, asgi, get, post
from litestar.connection.base import empty_send
from litestar.datastructures import Address, Cookie
from litestar.exceptions import (
Expand Down Expand Up @@ -282,6 +282,24 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_form_urlencoded_multi_keys() -> None:
@post("/")
async def handler(request: Request) -> Any:
return (await request.form()).getall("foo")

with create_test_client(handler) as client:
assert client.post("/", data={"foo": ["1", "2"]}).json() == ["1", "2"]


def test_request_form_multipart_multi_keys() -> None:
@post("/")
async def handler(request: Request) -> int:
return len((await request.form()).getall("foo"))

with create_test_client(handler) as client:
assert client.post("/", data={"foo": "1"}, files={"foo": b"a"}).json() == 2


def test_request_body_then_stream() -> None:
async def app(scope: Any, receive: Receive, send: Send) -> None:
request = Request[Any, Any, Any](scope, receive)
Expand Down
23 changes: 11 additions & 12 deletions tests/unit/test_kwargs/test_multipart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,17 @@ async def form_multi_item_handler(request: Request) -> DefaultDict[str, list]:
data = await request.form()
output = defaultdict(list)
for key, value in data.multi_items():
for v in value:
if isinstance(v, UploadFile):
content = await v.read()
output[key].append(
{
"filename": v.filename,
"content": content.decode(),
"content_type": v.content_type,
}
)
else:
output[key].append(v)
if isinstance(value, UploadFile):
content = await value.read()
output[key].append(
{
"filename": value.filename,
"content": content.decode(),
"content_type": value.content_type,
}
)
else:
output[key].append(value)
return output


Expand Down
Loading