Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(OpenAPI): Correctly handle typing.NewType #3580

Merged
merged 3 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from litestar.utils.typing import (
get_origin_or_inner_type,
make_non_optional_union,
unwrap_new_type,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -325,7 +326,9 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re

result: Schema | Reference

if plugin_for_annotation := self.get_plugin_for(field_definition):
if field_definition.is_new_type:
result = self.for_new_type(field_definition)
elif plugin_for_annotation := self.get_plugin_for(field_definition):
result = self.for_plugin(field_definition, plugin_for_annotation)
elif _should_create_enum_schema(field_definition):
annotation = _type_or_first_not_none_inner_type(field_definition)
Expand Down Expand Up @@ -354,6 +357,15 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re

return self.process_schema_result(field_definition, result) if isinstance(result, Schema) else result

def for_new_type(self, field_definition: FieldDefinition) -> Schema | Reference:
return self.for_field_definition(
FieldDefinition.from_kwarg(
annotation=unwrap_new_type(field_definition.raw),
name=field_definition.name,
default=field_definition.default,
)
)

@staticmethod
def for_upload_file(field_definition: FieldDefinition) -> Schema:
"""Create schema for UploadFile.
Expand Down
27 changes: 12 additions & 15 deletions litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,10 @@
from copy import deepcopy
from dataclasses import dataclass, is_dataclass, replace
from inspect import Parameter, Signature
from typing import (
Any,
AnyStr,
Callable,
Collection,
ForwardRef,
Literal,
Mapping,
Protocol,
Sequence,
TypeVar,
cast,
)
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast

from msgspec import UnsetType
from typing_extensions import NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict
from typing_extensions import NewType, NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict

from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning
from litestar.openapi.spec import Example
Expand Down Expand Up @@ -314,7 +302,12 @@ def is_generic(self) -> bool:
def is_simple_type(self) -> bool:
"""Check if the field type is a singleton value (e.g. int, str etc.)."""
return not (
self.is_generic or self.is_optional or self.is_union or self.is_mapping or self.is_non_string_iterable
self.is_generic
or self.is_optional
or self.is_union
or self.is_mapping
or self.is_non_string_iterable
or self.is_new_type
)

@property
Expand Down Expand Up @@ -366,6 +359,10 @@ def is_tuple(self) -> bool:
"""Whether the annotation is a ``tuple`` or not."""
return self.is_subclass_of(tuple)

@property
def is_new_type(self) -> bool:
return isinstance(self.annotation, NewType)

@property
def is_type_var(self) -> bool:
"""Whether the annotation is a TypeVar or not."""
Expand Down
10 changes: 9 additions & 1 deletion litestar/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
cast,
)

from typing_extensions import Annotated, NotRequired, Required, get_args, get_origin, get_type_hints
from typing_extensions import Annotated, NewType, NotRequired, Required, get_args, get_origin, get_type_hints

from litestar.types.builtin_types import NoneType, UnionTypes

Expand Down Expand Up @@ -174,6 +174,14 @@ def unwrap_annotation(annotation: Any) -> tuple[Any, tuple[Any, ...], set[Any]]:
return annotation, tuple(metadata), wrappers


def unwrap_new_type(new_type: Any) -> Any:
"""Unwrap a (nested) ``typing.NewType``"""
inner = new_type
while isinstance(inner, NewType):
inner = inner.__supertype__
return inner


def get_origin_or_inner_type(annotation: Any) -> Any:
"""Get origin or unwrap it. Returns None for non-generic types.

Expand Down
40 changes: 39 additions & 1 deletion tests/unit/test_openapi/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

import pytest
from typing_extensions import Annotated
from typing_extensions import Annotated, NewType

from litestar import Controller, Litestar, Router, get
from litestar._openapi.datastructures import OpenAPIContext
Expand Down Expand Up @@ -380,3 +380,41 @@ async def uuid_path(id: Annotated[UUID, Parameter(description="UUID ID")]) -> UU
response = client.get("/schema/openapi.json")
assert response.json()["paths"]["/str/{id}"]["get"]["parameters"][0]["description"] == "String ID"
assert response.json()["paths"]["/uuid/{id}"]["get"]["parameters"][0]["description"] == "UUID ID"


def test_unwrap_new_type() -> None:
FancyString = NewType("FancyString", str)

@get("/{path_param:str}")
async def handler(
param: FancyString,
optional_param: Optional[FancyString],
path_param: FancyString,
) -> FancyString:
return FancyString("")

app = Litestar([handler])
assert app.openapi_schema.paths["/{path_param}"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
assert app.openapi_schema.paths["/{path_param}"].get.parameters[1].schema.one_of == [ # type: ignore[index, union-attr]
Schema(type=OpenAPIType.NULL),
Schema(type=OpenAPIType.STRING),
]
assert app.openapi_schema.paths["/{path_param}"].get.parameters[2].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
assert (
app.openapi_schema.paths["/{path_param}"].get.responses["200"].content["application/json"].schema.type # type: ignore[index, union-attr]
== OpenAPIType.STRING
)


def test_unwrap_nested_new_type() -> None:
FancyString = NewType("FancyString", str)
FancierString = NewType("FancierString", FancyString) # pyright: ignore

@get("/")
async def handler(
param: FancierString,
) -> None:
return None

app = Litestar([handler])
assert app.openapi_schema.paths["/"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
Loading