Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(signature): Fix #3593: Ensure signature model internal function signatures don't clash with model signature #3605

Merged
merged 3 commits into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion litestar/_kwargs/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def resolve_dependency(
"""
signature_model = dependency.provide.signature_model
dependency_kwargs = (
signature_model.parse_values_from_connection_kwargs(connection=connection, **kwargs)
signature_model.parse_values_from_connection_kwargs(connection=connection, kwargs=kwargs)
if signature_model._fields
else {}
)
Expand Down
10 changes: 6 additions & 4 deletions litestar/_signature/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASG
return message

@classmethod
def _collect_errors(cls, deserializer: Callable[[Any, Any], Any], **kwargs: Any) -> list[tuple[str, Exception]]:
def _collect_errors(
cls, deserializer: Callable[[Any, Any], Any], kwargs: dict[str, Any]
) -> list[tuple[str, Exception]]:
exceptions: list[tuple[str, Exception]] = []
for field_name in cls._fields:
try:
Expand All @@ -181,12 +183,12 @@ def _collect_errors(cls, deserializer: Callable[[Any, Any], Any], **kwargs: Any)
return exceptions

@classmethod
def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]:
def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, kwargs: dict[str, Any]) -> dict[str, Any]:
"""Extract values from the connection instance and return a dict of parsed values.

Args:
connection: The ASGI connection instance.
**kwargs: A dictionary of kwargs.
kwargs: A dictionary of kwargs.

Raises:
ValidationException: If validation failed.
Expand All @@ -206,7 +208,7 @@ def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwarg
messages.append(message)
raise cls._create_exception(messages=messages, connection=connection) from e
except ValidationError as e:
for field_name, exc in cls._collect_errors(deserializer=deserializer, **kwargs): # type: ignore[assignment]
for field_name, exc in cls._collect_errors(deserializer=deserializer, kwargs=kwargs): # type: ignore[assignment]
match = ERR_RE.search(str(exc))
keys = [field_name, str(match.group(1))] if match else [field_name]
message = cls._build_error_message(keys=keys, exc_msg=str(exc), connection=connection)
Expand Down
2 changes: 1 addition & 1 deletion litestar/routes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ async def _get_response_data(
cleanup_group = await parameter_model.resolve_dependencies(request, kwargs)

parsed_kwargs = route_handler.signature_model.parse_values_from_connection_kwargs(
connection=request, **kwargs
connection=request, kwargs=kwargs
)

if cleanup_group:
Expand Down
2 changes: 1 addition & 1 deletion litestar/routes/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N
cleanup_group = await self.handler_parameter_model.resolve_dependencies(websocket, parsed_kwargs)

parsed_kwargs = self.route_handler.signature_model.parse_values_from_connection_kwargs(
connection=websocket, **parsed_kwargs
connection=websocket, kwargs=parsed_kwargs
)

if cleanup_group:
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,11 @@ exclude-classes = """
"""

[tool.ruff]
include = [
"{litestar,tests,docs,test_apps,tools}/**/*.{py,pyi}",
"pyproject.toml"
]

lint.select = [
"A", # flake8-builtins
"B", # flake8-bugbear
Expand Down
30 changes: 27 additions & 3 deletions tests/unit/test_signature/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def fn(a: int) -> None:
type_decoders=[],
)
with pytest.raises(ValidationException):
model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), a="not an int")
model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), kwargs={"a": "not an int"})


def test_create_signature_validation() -> None:
Expand Down Expand Up @@ -128,7 +128,7 @@ def handler(data: Parent) -> None:

with pytest.raises(ValidationException) as exc_info:
model.parse_values_from_connection_kwargs(
connection=RequestFactory().get(route_handler=handler), data={"child": {}, "other_child": {}}
connection=RequestFactory().get(route_handler=handler), kwargs={"data": {"child": {}, "other_child": {}}}
)

assert isinstance(exc_info.value.extra, list)
Expand Down Expand Up @@ -283,7 +283,7 @@ def fn(a: Annotated[int, Parameter(gt=5)], b: Annotated[int, Parameter(lt=5)]) -
type_decoders=[],
)
with pytest.raises(ValidationException) as exc:
model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), a=0, b=9)
model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), kwargs={"a": 0, "b": 9})

assert exc.value.extra == [
{"message": "Expected `int` >= 6", "key": "a", "source": ParamType.QUERY},
Expand All @@ -303,3 +303,27 @@ async def something(foo: Foo[str] = Foo()) -> None:

with create_test_client([something]) as client:
assert client.get("/").status_code == 200


def test_separate_model_namespace() -> None:
# https://github.com/litestar-org/litestar/issues/3593

async def provide_connection() -> str:
return "connection"

@get("/connection", dependencies={"connection": provide_connection})
async def get_connection(connection: str) -> str:
return connection

async def provide_deserializer() -> str:
return "deserializer"

@get("/deserializer", dependencies={"deserializer": provide_deserializer})
async def get_deserializer(deserializer: int) -> str:
return deserializer # type: ignore[return-value]

with create_test_client([get_connection, get_deserializer], raise_server_exceptions=True, debug=True) as client:
assert client.get("/connection").text == "connection"
res = client.get("/deserializer")
assert res.status_code == 500
assert "Expected `int`, got `str` - at `$.deserializer`" in res.text
Loading