Skip to content

Commit

Permalink
Merge pull request #95 from piercefreeman/feature/fix-ref-in-schemas
Browse files Browse the repository at this point in the history
Resolve enum refs in URL schemas
  • Loading branch information
piercefreeman authored Apr 16, 2024
2 parents 2d3b595 + 3df72a0 commit ecf8756
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 61 deletions.
60 changes: 60 additions & 0 deletions mountaineer/__tests__/client_builder/test_build_actions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import datetime
from re import sub as re_sub
from uuid import UUID

import pytest
from fastapi import APIRouter
Expand Down Expand Up @@ -125,6 +127,37 @@ def fn2():
"ExampleResponseModel",
],
),
# No request body parameter
(
"my_method_fn",
"/testing/url",
ActionDefinition(
action_type=ActionType.POST,
summary="",
operationId="",
requestBody=None,
responses={
"200": EXAMPLE_RESPONSE_200,
"422": EXAMPLE_RESPONSE_400,
},
),
(
"""
export const my_method_fn = (): Promise<ExampleResponseModel> => {
return __request({
'method': 'POST',
'url': '/testing/url',
'errors': {
422: HTTPValidationErrorException
}
});
}
"""
),
[
"ExampleResponseModel",
],
),
# Path and query parameters
(
"my_method_fn",
Expand Down Expand Up @@ -363,3 +396,30 @@ def test_build_raw_response_action(
built_function, build_imports = builder.build_action(url, definition, method_name)
assert re_sub(r"\s+", "", built_function) == re_sub(r"\s+", "", expected_function)
assert set(build_imports) == set(expected_imports)


AnyType = None | bool | str | int | datetime | UUID
DictParamItem = dict[str, AnyType]


def test_build_invalid_action_api():
"""
Ensure that we throw an error if the user has provided a schema payload
that is technically valid, but doesn't allow typehinting with pydantic. All non-raw
JSON requests should be Pydantic methods.
https://github.com/piercefreeman/mountaineer/issues/94
"""

def fn1(payload: DictParamItem, item_id: str, other_id: str) -> None:
pass

router = APIRouter()
router.post("/fn1")(fn1)

builder = OpenAPIToTypescriptActionConverter()
openapi_spec = get_openapi(title="", version="", routes=router.routes)

with pytest.raises(ValueError):
builder.convert(openapi_spec)
45 changes: 45 additions & 0 deletions mountaineer/__tests__/client_builder/test_build_links.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from enum import StrEnum
from re import sub as re_sub
from typing import Callable
from uuid import UUID

import pytest
from fastapi import APIRouter
Expand All @@ -20,6 +22,20 @@ def view_endpoint_query_params(query_a: str, query_b: int | None = None):
pass


class RouteType(StrEnum):
ROUTE_A = "route_a"
ROUTE_B = "route_b"


def enum_view_url(model_type: RouteType, model_id: UUID):
"""
Model view paths like /{model_type}/{model_id} where we want a flexible
string to be captured in the model_type path.
"""
pass


@pytest.mark.parametrize(
"url, endpoint, expected_link",
[
Expand Down Expand Up @@ -98,6 +114,35 @@ def view_endpoint_query_params(query_a: str, query_b: int | None = None):
"""
),
),
# Path with enum path variables - we should typehint explicitly
# as the enum based values
(
"/enum_view/{model_type}/{model_id}",
enum_view_url,
(
"""
export const getLink = ({
model_type,
model_id
}:{
model_type: 'route_a' | 'route_b',
model_id: string
}) => {
const url = `/enum_view/{model_type}/{model_id}`;
const queryParameters: Record<string,any> = {};
const pathParameters: Record<string,any> = {
model_type,
model_id
};
return __getLink({
rawUrl: url,
queryParameters,
pathParameters
});
};
"""
),
),
],
)
def test_convert(url: str, endpoint: Callable, expected_link: str):
Expand Down
19 changes: 15 additions & 4 deletions mountaineer/client_builder/build_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,16 @@ def build_action_payload(self, url: str, action: ActionDefinition):
if not action.is_raw_response:
response_types.append(
self.get_typescript_name_from_content_definition(
response_definition.content_schema
response_definition.content_schema,
url=url,
status_code=status_int,
)
)
else:
error_typehint = self.get_typescript_name_from_content_definition(
response_definition.content_schema
response_definition.content_schema,
url=url,
status_code=status_int,
)
common_params["errors"][status_int] = TSLiteral(
# Provide a mapping to the error class
Expand Down Expand Up @@ -253,11 +257,18 @@ def get_method_names(self, url: str, actions: list[ActionDefinition]):
return method_names

def get_typescript_name_from_content_definition(
self, definition: ContentDefinition
self,
definition: ContentDefinition,
# Url and status are provided for more context about where the error
# is being thrown. Can pass None if not available.
url: str | None = None,
status_code: int | None = None,
):
if not definition.schema_ref.ref:
raise ValueError(
f"Content definition {definition} does not have a schema reference"
f"Content definition {definition} does not have a schema reference.\n"
f"Double check your action definition for {url} with response code {status_code}.\n"
"Are you typehinting your response with a Pydantic BaseModel?"
)
return definition.schema_ref.ref.split("/")[-1]

Expand Down
4 changes: 3 additions & 1 deletion mountaineer/client_builder/build_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def convert(self, openapi: dict[str, Any]):
}:
continue

typehint_key, typehint_value = get_typehint_for_parameter(parameter)
typehint_key, typehint_value = get_typehint_for_parameter(
parameter, openapi_spec
)
input_parameters[TSLiteral(parameter.name)] = TSLiteral(parameter.name)
typehint_parameters[typehint_key] = typehint_value

Expand Down
29 changes: 4 additions & 25 deletions mountaineer/client_builder/build_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from inflection import camelize
from pydantic import BaseModel, create_model

from mountaineer.annotation_helpers import get_value_by_alias, yield_all_subtypes
from mountaineer.annotation_helpers import yield_all_subtypes
from mountaineer.client_builder.openapi import (
EmptyAPIProperty,
OpenAPIProperty,
OpenAPISchema,
OpenAPISchemaType,
resolve_ref,
)
from mountaineer.client_builder.typescript import (
TSLiteral,
Expand Down Expand Up @@ -103,7 +104,7 @@ def walk_models(
):
yield property
if property.ref is not None:
yield from walk_models(self.resolve_ref(property.ref, base))
yield from walk_models(resolve_ref(property.ref, base))
if property.items:
yield from walk_models(property.items)
if property.anyOf:
Expand All @@ -116,28 +117,6 @@ def walk_models(

return list(set(walk_models(base)))

def resolve_ref(self, ref: str, base: BaseModel) -> OpenAPIProperty:
"""
Resolve a $ref that points to a propery-compliant schema in the same document. If this
ref points somewhere else in the document (that is valid but not a data model) than we
raise a ValueError.
"""
current_obj = base
for part in ref.split("/"):
if part == "#":
current_obj = base
else:
try:
current_obj = get_value_by_alias(current_obj, part)
except AttributeError as e:
raise AttributeError(
f"Invalid $ref, couldn't resolve path: {ref}"
) from e
if not isinstance(current_obj, OpenAPIProperty):
raise ValueError(f"Resolved $ref is not a valid OpenAPIProperty: {ref}")
return current_obj

def convert_schema_to_interface(
self,
model: OpenAPIProperty,
Expand Down Expand Up @@ -188,7 +167,7 @@ def walk_array_types(prop: OpenAPIProperty | EmptyAPIProperty) -> Iterator[str]:
yield f"Array<{' | '.join(array_types)}>"
elif prop.ref:
yield self.get_typescript_interface_name(
self.resolve_ref(prop.ref, base=base)
resolve_ref(prop.ref, base=base)
)
elif prop.items:
yield from walk_array_types(prop.items)
Expand Down
47 changes: 20 additions & 27 deletions mountaineer/client_builder/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from pydantic import BaseModel, ConfigDict, Field, model_validator

from mountaineer.annotation_helpers import get_value_by_alias

#
# Enum definitions
#
Expand Down Expand Up @@ -324,33 +326,24 @@ class Components(BaseModel):
#


def get_types_from_parameters(schema: OpenAPIProperty | EmptyAPIProperty):
def resolve_ref(ref: str, base: BaseModel) -> OpenAPIProperty:
"""
Handle potentially complex types from the parameter schema, like the case
of optional fields.
Resolve a $ref that points to a propery-compliant schema in the same document. If this
ref points somewhere else in the document (that is valid but not a data model) than we
raise a ValueError.
"""
if isinstance(schema, EmptyAPIProperty):
return "any"

# Recursively gather all of the types that might be nested
if schema.variable_type:
yield schema.variable_type

for property in schema.properties.values():
yield from get_types_from_parameters(property)

if schema.additionalProperties:
yield from get_types_from_parameters(schema.additionalProperties)

if schema.items:
yield from get_types_from_parameters(schema.items)

if schema.anyOf:
for one_of in schema.anyOf:
yield from get_types_from_parameters(one_of)

# We don't expect $ref values in the URL schema, if we do then the parsing
# is likely incorrect
if schema.ref:
raise ValueError(f"Unexpected $ref in URL schema: {schema.ref}")
current_obj = base
for part in ref.split("/"):
if part == "#":
current_obj = base
else:
try:
current_obj = get_value_by_alias(current_obj, part)
except AttributeError as e:
raise AttributeError(
f"Invalid $ref, couldn't resolve path: {ref}"
) from e
if not isinstance(current_obj, OpenAPIProperty):
raise ValueError(f"Resolved $ref is not a valid OpenAPIProperty: {ref}")
return current_obj
Loading

0 comments on commit ecf8756

Please sign in to comment.