diff --git a/Makefile b/Makefile index 06b63909..ad69a17a 100644 --- a/Makefile +++ b/Makefile @@ -100,7 +100,7 @@ test-scripts: define test-common echo "Running tests for $(2)..." - @(cd $(1) && poetry run pytest -vvv -W error $(test-args) $(2)) + @(cd $(1) && poetry run pytest -vvv -W "error::Warning" -W "default::PendingDeprecationWarning" $(test-args) $(2)) endef define test-rust-common @@ -111,7 +111,7 @@ endef # Use `-n auto` to run tests in parallel define test-common-integrations echo "Running tests for $(2)..." - @(cd $(1) && poetry run pytest -s -m integration_tests -W error $(2)) + @(cd $(1) && poetry run pytest -s -m integration_tests -vvv -W "error::Warning" -W "default::PendingDeprecationWarning" $(2)) endef define lint-common diff --git a/mountaineer/__tests__/client_builder/file_generators/test_globals.py b/mountaineer/__tests__/client_builder/file_generators/test_globals.py new file mode 100644 index 00000000..5bb955e8 --- /dev/null +++ b/mountaineer/__tests__/client_builder/file_generators/test_globals.py @@ -0,0 +1,234 @@ +from enum import Enum +from pathlib import Path +from typing import List, Sequence + +import pytest + +from mountaineer.actions.passthrough_dec import passthrough +from mountaineer.actions.sideeffect_dec import sideeffect +from mountaineer.app import AppController +from mountaineer.client_builder.file_generators.base import ParsedController +from mountaineer.client_builder.file_generators.globals import ( + GlobalControllerGenerator, + GlobalLinkGenerator, +) +from mountaineer.client_builder.parser import ( + ControllerParser, + ControllerWrapper, + EnumWrapper, + ModelWrapper, +) +from mountaineer.controller import ControllerBase +from mountaineer.controller_layout import LayoutControllerBase +from mountaineer.paths import ManagedViewPath +from mountaineer.render import RenderBase + + +# Test Classes +class StatusEnum(Enum): + ACTIVE = "active" + PENDING = "pending" + INACTIVE = "inactive" + + +class MainModel(RenderBase): + name: str + status: StatusEnum + + +class ChildModel(MainModel): + child_field: str + count: int + + +class DependentModel(MainModel): + base: MainModel + child: ChildModel + current_status: StatusEnum + + +# Controllers +class BaseController(ControllerBase): + @passthrough + def base_action(self) -> MainModel: # type: ignore + pass + + +class ChildController(BaseController): + url: str = "/child" + view_path = "/child.tsx" + + async def render(self) -> DependentModel: # type: ignore + pass + + @sideeffect + def update(self, data: ChildModel) -> DependentModel: # type: ignore + pass + + +class LayoutController(LayoutControllerBase): + view_path = "/layout.tsx" + + async def render(self) -> MainModel: # type: ignore + pass + + +@pytest.fixture +def managed_path(tmp_path: Path) -> ManagedViewPath: + return ManagedViewPath(tmp_path) + + +@pytest.fixture +def controller_parser() -> ControllerParser: + return ControllerParser() + + +@pytest.fixture +def controller_wrappers(controller_parser: ControllerParser) -> list[ControllerWrapper]: + # Concrete instances should be mounted to an AppController to augment + # some of the runtime type information + app_controller = AppController(view_root=Path()) + app_controller.register(ChildController()) + app_controller.register(LayoutController()) + + return [ + controller_parser.parse_controller(ChildController), + controller_parser.parse_controller(LayoutController), + ] + + +# Tests +class TestGlobalControllerGenerator: + @pytest.fixture + def generator( + self, + managed_path: ManagedViewPath, + controller_wrappers: List[ControllerWrapper], + ) -> GlobalControllerGenerator: + return GlobalControllerGenerator( + managed_path=managed_path, controller_wrappers=controller_wrappers + ) + + def test_model_enum_graph_resolution( + self, generator: GlobalControllerGenerator + ) -> None: + """Test that models and enums are sorted correctly""" + # Get embedded types + controllers = ControllerWrapper.get_all_embedded_controllers( + generator.controller_wrappers + ) + embedded = ControllerWrapper.get_all_embedded_types( + controllers, include_superclasses=True + ) + + # Sort them + sorted_items = generator._build_model_enum_graph( + embedded.models, embedded.enums + ) + + # Hierarchy is: + # StatusEnum + # MainModel <- DependentModel + # <- ChildModel + # Verify StatusEnum comes before BaseModel + # Enums come before models and parents come before their subclasses + enum_idx = self.get_item_order(sorted_items, "StatusEnum") + main_model_idx = self.get_item_order(sorted_items, "MainModel") + assert enum_idx < main_model_idx + + child_model_idx = self.get_item_order(sorted_items, "ChildModel") + assert main_model_idx < child_model_idx + + dependent_model_idx = self.get_item_order(sorted_items, "DependentModel") + assert main_model_idx < dependent_model_idx + + def test_controller_graph_resolution( + self, generator: GlobalControllerGenerator + ) -> None: + """Test that controllers are sorted correctly""" + controllers = ControllerWrapper.get_all_embedded_controllers( + generator.controller_wrappers + ) + sorted_controllers = generator._build_controller_graph(controllers) + + base_idx = self.get_item_order(sorted_controllers, "BaseController") + child_idx = self.get_item_order(sorted_controllers, "ChildController") + + # Base should come before Child + assert base_idx < child_idx + + def test_script_generation(self, generator: GlobalControllerGenerator) -> None: + """Test the complete script generation""" + blocks = generator.script() + content = "\n".join(block.content for block in blocks) + + # Verify models are generated + assert "export interface MainModel" in content + assert "export interface ChildModel extends MainModel" in content + assert "export interface DependentModel extends MainModel" in content + + # Verify enum is generated + assert "export enum StatusEnum" in content + assert "ACTIVE = " in content + + # Verify controllers are generated + assert "export interface BaseController" in content + assert "export interface ChildController extends BaseController" in content + + # Verify layout controller is included + assert "export interface LayoutController" in content + + def get_item_order( + self, + sorted_items: Sequence[ModelWrapper | EnumWrapper | ControllerWrapper], + raw_name: str, + ): + return next( + i for i, item in enumerate(sorted_items) if item.name.raw_name == raw_name + ) + + +class TestGlobalLinkGenerator: + @pytest.fixture + def parsed_controllers( + self, controller_parser: ControllerParser, managed_path: ManagedViewPath + ) -> List[ParsedController]: + (managed_path / "child").mkdir() + (managed_path / "layout").mkdir() + + return [ + ParsedController( + wrapper=controller_parser.parse_controller(ChildController), + view_path=ManagedViewPath(managed_path / "child"), + is_layout=False, + ), + ParsedController( + wrapper=controller_parser.parse_controller(LayoutController), + view_path=ManagedViewPath(managed_path / "layout"), + is_layout=True, + ), + ] + + @pytest.fixture + def generator( + self, managed_path: ManagedViewPath, parsed_controllers: List[ParsedController] + ) -> GlobalLinkGenerator: + return GlobalLinkGenerator( + managed_path=managed_path, parsed_controllers=parsed_controllers + ) + + def test_script_generation(self, generator: GlobalLinkGenerator) -> None: + """Test link aggregator generation""" + blocks = generator.script() + content = "\n".join(block.content for block in blocks) + + # Verify imports + assert "import { getLink as ChildControllerGetLinks }" in content + + # Verify layout controller is excluded + assert "LayoutControllerGetLinks" not in content + + # Verify link generator object + assert "const linkGenerator = {" in content + assert "childController: ChildControllerGetLinks" in content + assert "export default linkGenerator" in content diff --git a/mountaineer/__tests__/client_builder/file_generators/test_locals.py b/mountaineer/__tests__/client_builder/file_generators/test_locals.py new file mode 100644 index 00000000..9a6d99b7 --- /dev/null +++ b/mountaineer/__tests__/client_builder/file_generators/test_locals.py @@ -0,0 +1,290 @@ +from enum import Enum +from pathlib import Path +from typing import Any, List + +import pytest +from fastapi import File +from pydantic import BaseModel + +from mountaineer.actions.passthrough_dec import passthrough +from mountaineer.actions.sideeffect_dec import sideeffect +from mountaineer.app import AppController +from mountaineer.client_builder.file_generators.base import CodeBlock +from mountaineer.client_builder.file_generators.locals import ( + LocalActionGenerator, + LocalGeneratorBase, + LocalIndexGenerator, + LocalLinkGenerator, + LocalModelGenerator, + LocalUseServerGenerator, +) +from mountaineer.client_builder.parser import ( + ControllerParser, + ControllerWrapper, +) +from mountaineer.controller import ControllerBase +from mountaineer.paths import ManagedViewPath +from mountaineer.render import RenderBase + + +# Test Models and Enums +class ExampleStatus(Enum): + ACTIVE = "active" + INACTIVE = "inactive" + + +class ExampleBaseModel(BaseModel): + name: str + status: ExampleStatus + + +class ExampleRequestModel(BaseModel): + query: str + limit: int = 10 + + +class ExampleResponseModel(BaseModel): + results: List[str] + total: int + + +class ExampleRenderModel(RenderBase): + title: str + items: List[ExampleBaseModel] + + +# Test Controllers +class ExampleBaseController(ControllerBase): + @passthrough + def base_action(self) -> ExampleResponseModel: # type: ignore + """Base action that returns a response model""" + pass + + +class ExampleController(ExampleBaseController): + url = "/test" + view_path = "/test.tsx" + + async def render( # type: ignore + self, + path_param: str, + query_param: int = 0, + enum_param: ExampleStatus = ExampleStatus.ACTIVE, + ) -> ExampleRenderModel: # type: ignore + """Main render method""" + pass + + @passthrough + def get_data(self) -> ExampleBaseModel: # type: ignore + """Get basic data""" + pass + + @sideeffect + def update_data(self, data: ExampleRequestModel) -> ExampleResponseModel: # type: ignore + """Update data with side effects""" + pass + + @sideeffect + async def upload_file(self, file: bytes = File(...)) -> ExampleResponseModel: # type: ignore + """File upload endpoint""" + pass + + +@pytest.fixture +def managed_path(tmp_path: Path) -> ManagedViewPath: + controller_path = tmp_path / "test_controller" + controller_path.mkdir() + return ManagedViewPath(controller_path) + + +@pytest.fixture +def global_root(tmp_path: Path) -> ManagedViewPath: + return ManagedViewPath(tmp_path) + + +@pytest.fixture +def controller_parser() -> ControllerParser: + return ControllerParser() + + +@pytest.fixture +def controller_wrapper(controller_parser: ControllerParser) -> ControllerWrapper: + app_controller = AppController(view_root=Path()) + app_controller.register(ExampleController()) + + return controller_parser.parse_controller(ExampleController) + + +class TestLocalGeneratorBase: + @pytest.fixture + def generator(self, managed_path: ManagedViewPath, global_root: ManagedViewPath): + class ConcreteGeneratorBase(LocalGeneratorBase): + def script(self): + yield CodeBlock() + + return ConcreteGeneratorBase(managed_path=managed_path, global_root=global_root) + + def test_get_global_import_path(self, generator: LocalGeneratorBase) -> None: + result: str = generator.get_global_import_path("test.ts") + assert isinstance(result, str) + assert "../" in result + assert result.endswith("test") + + +class TestLocalLinkGenerator: + @pytest.fixture + def generator( + self, + managed_path: ManagedViewPath, + global_root: ManagedViewPath, + controller_wrapper: ControllerWrapper, + ) -> Any: + return LocalLinkGenerator( + controller=controller_wrapper, + managed_path=managed_path, + global_root=global_root, + ) + + def test_script_generation(self, generator: LocalLinkGenerator) -> None: + result = list(generator.script()) + assert len(result) > 0 + content = "\n".join(block.content for block in result) + assert "import" in content + assert "getLink" in content + assert "path_param" in content + assert "query_param" in content + assert "enum_param" in content + + def test_get_link_implementation_with_parameters( + self, generator: LocalLinkGenerator + ) -> None: + impl = generator._get_link_implementation(generator.controller) + assert "path_param" in impl + assert "query_param?" in impl # Optional parameter + assert "enum_param?" in impl # Optional parameter + assert "/test" in impl + + def test_get_imports(self, generator: LocalLinkGenerator) -> None: + imports = list(generator._get_imports(generator.controller)) + assert any("../api" in block.content for block in imports) + assert any("../controllers" in block.content for block in imports) + + +class TestLocalActionGenerator: + @pytest.fixture + def generator( + self, + managed_path: ManagedViewPath, + global_root: ManagedViewPath, + controller_wrapper: ControllerWrapper, + ) -> Any: + return LocalActionGenerator( + controller=controller_wrapper, + managed_path=managed_path, + global_root=global_root, + ) + + def test_generate_controller_actions(self, generator: LocalActionGenerator) -> None: + actions = list(generator._generate_controller_actions(generator.controller)) + assert len(actions) == 4 # base_action, get_data, update_data, upload_file + action_names: set[str] = { + action + for action in " ".join(actions).split() + if action in ["base_action", "get_data", "update_data", "upload_file"] + } + assert len(action_names) == 4 + + def test_get_dependent_imports(self, generator: LocalActionGenerator) -> None: + deps = generator._get_dependent_imports(generator.controller) + + # Response wrapped models + assert deps == { + "BaseActionResponseWrapped", + "ExampleRequestModel", + "GetDataResponseWrapped", + "UpdateDataResponseWrapped", + "UploadFileForm", + "UploadFileResponseWrapped", + } + + +class TestLocalModelGenerator: + @pytest.fixture + def generator( + self, + managed_path: ManagedViewPath, + global_root: ManagedViewPath, + controller_wrapper: ControllerWrapper, + ) -> Any: + return LocalModelGenerator( + controller=controller_wrapper, + managed_path=managed_path, + global_root=global_root, + ) + + def test_script_generation(self, generator: LocalModelGenerator) -> None: + result: List[Any] = list(generator.script()) + assert len(result) > 0 + content: str = "\n".join(block.content for block in result) + + # Check for model exports + assert "export type { ExampleRequestModel as ExampleRequestModel }" in content + assert "export type { ExampleResponseModel as ExampleResponseModel }" in content + assert "export type { ExampleRenderModel as ExampleRenderModel }" in content + + # Check for enum exports + assert "export { ExampleStatus as ExampleStatus }" in content + + +class TestLocalUseServerGenerator: + @pytest.fixture + def generator( + self, + managed_path: ManagedViewPath, + global_root: ManagedViewPath, + controller_wrapper: ControllerWrapper, + ) -> Any: + return LocalUseServerGenerator( + controller=controller_wrapper, + managed_path=managed_path, + global_root=global_root, + ) + + def test_script_generation_with_render( + self, generator: LocalUseServerGenerator + ) -> None: + result: List[Any] = list(generator.script()) + content: str = "\n".join(block.content for block in result) + assert "useServer" in content + assert "ServerState" in content + assert "useState" in content + assert "applySideEffect" in content + + +class TestLocalIndexGenerator: + @pytest.fixture + def generator( + self, + managed_path: ManagedViewPath, + global_root: ManagedViewPath, + controller_wrapper: ControllerWrapper, + ) -> Any: + return LocalIndexGenerator( + controller=controller_wrapper, + managed_path=managed_path, + global_root=global_root, + ) + + def test_script_generation( + self, generator: LocalIndexGenerator, managed_path: ManagedViewPath + ) -> None: + (managed_path.parent / "actions.ts").write_text( + "export const action = () => {}" + ) + (managed_path.parent / "models.ts").write_text("export type Model = {}") + + result: List[Any] = list(generator.script()) + assert len(result) > 0 + content: str = "\n".join(block.content for block in result) + assert "export * from './actions'" in content + assert "export * from './models'" in content diff --git a/mountaineer/__tests__/client_builder/interface_builders/__init__.py b/mountaineer/__tests__/client_builder/interface_builders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mountaineer/__tests__/client_builder/interface_builders/common.py b/mountaineer/__tests__/client_builder/interface_builders/common.py new file mode 100644 index 00000000..5b270b12 --- /dev/null +++ b/mountaineer/__tests__/client_builder/interface_builders/common.py @@ -0,0 +1,114 @@ +from enum import Enum +from typing import Type + +from pydantic import BaseModel + +from mountaineer.actions.fields import FunctionActionType +from mountaineer.client_builder.parser import ( + ActionWrapper, + ControllerWrapper, + EnumWrapper, + ExceptionWrapper, + FieldWrapper, + ModelWrapper, + WrapperName, +) +from mountaineer.client_builder.types import TypeDefinition +from mountaineer.controller import ControllerBase +from mountaineer.exceptions import APIException + + +def create_model_wrapper( + model: Type[BaseModel], + name: str, + fields: list[FieldWrapper] | None = None, + superclasses: list[ModelWrapper] | None = None, +) -> ModelWrapper: + wrapper_name = WrapperName(name) + return ModelWrapper( + name=wrapper_name, + module_name="test_module", + model=model, + isolated_model=model, # Simplified for testing + superclasses=superclasses or [], + value_models=fields or [], + ) + + +# Helper function to create field wrappers +def create_field_wrapper( + name: str, + type_hint: type | ModelWrapper | EnumWrapper | TypeDefinition, + required: bool = True, +) -> FieldWrapper: + return FieldWrapper(name=name, value=type_hint, required=required) + + +# Helper function to create exception wrappers +def create_exception_wrapper( + exception: Type[APIException], + name: str, + status_code: int, + value_models: list[FieldWrapper] | None = None, +) -> ExceptionWrapper: + wrapper_name = WrapperName(name) + return ExceptionWrapper( + name=wrapper_name, + module_name="test_module", + status_code=status_code, + exception=exception, + value_models=value_models or [], + ) + + +def create_action_wrapper( + name: str, + params: list[FieldWrapper] | None = None, + response_model: Type[BaseModel] | None = None, + request_body: ModelWrapper | None = None, + action_type: FunctionActionType = FunctionActionType.PASSTHROUGH, +) -> ActionWrapper: + response_wrapper = ( + create_model_wrapper(response_model, response_model.__name__) + if response_model + else None + ) + return ActionWrapper( + name=name, + module_name="test_module", + action_type=action_type, + params=params or [], + headers=[], + request_body=request_body, + response_bodies={ControllerBase: response_wrapper} if response_wrapper else {}, + exceptions=[], + is_raw_response=False, + is_streaming_response=False, + controller_to_url={ControllerBase: f"/api/{name}"}, + ) + + +def create_controller_wrapper( + name: str, + actions: dict[str, ActionWrapper] | None = None, + superclasses: list[ControllerWrapper] | None = None, + entrypoint_url: str | None = None, +) -> ControllerWrapper: + wrapper_name = WrapperName(name) + return ControllerWrapper( + name=wrapper_name, + module_name="test_module", + entrypoint_url=entrypoint_url, + controller=type(name, (ControllerBase,), {}), + superclasses=superclasses or [], + queries=[], + paths=[], + render=None, + actions=actions or {}, + ) + + +def create_enum_wrapper(enum_class: Type[Enum]) -> EnumWrapper: + """Helper function to create enum wrappers""" + wrapper_name = WrapperName(enum_class.__name__) + return EnumWrapper(name=wrapper_name, module_name="test_module", enum=enum_class) diff --git a/mountaineer/__tests__/client_builder/interface_builders/test_action.py b/mountaineer/__tests__/client_builder/interface_builders/test_action.py new file mode 100644 index 00000000..5e682bad --- /dev/null +++ b/mountaineer/__tests__/client_builder/interface_builders/test_action.py @@ -0,0 +1,305 @@ +from datetime import datetime + +import pytest +from pydantic import BaseModel + +from mountaineer.__tests__.client_builder.interface_builders.common import ( + create_exception_wrapper, + create_field_wrapper, + create_model_wrapper, +) +from mountaineer.actions.fields import FunctionActionType +from mountaineer.client_builder.interface_builders.action import ActionInterface +from mountaineer.client_builder.parser import ( + ActionWrapper, + ExceptionWrapper, + FieldWrapper, + ModelWrapper, +) +from mountaineer.client_builder.types import ListOf, Or +from mountaineer.controller import ControllerBase +from mountaineer.exceptions import APIException + + +# Test Models +class StandardResponse(BaseModel): + value: str + timestamp: datetime + + +class ErrorResponse(APIException): + error_code: str + message: str + + +class FormData(BaseModel): + name: str + email: str + + +class AlternateResponse(BaseModel): + status: str + code: int + + +@pytest.fixture +def standard_response_wrapper(): + return create_model_wrapper(StandardResponse, "StandardResponse") + + +@pytest.fixture +def error_response_wrapper(): + return create_exception_wrapper(ErrorResponse, "ErrorResponse", 400) + + +@pytest.fixture +def form_data_wrapper(): + return create_model_wrapper(FormData, "FormData") + + +class TestBasicInterfaceGeneration: + def test_simple_action_interface(self, standard_response_wrapper: ModelWrapper): + action = ActionWrapper( + name="simple_action", + module_name="test_module", + action_type=FunctionActionType.PASSTHROUGH, + params=[], + headers=[], + request_body=None, + response_bodies={ControllerBase: standard_response_wrapper}, + exceptions=[], + is_raw_response=False, + is_streaming_response=False, + controller_to_url={ControllerBase: "/api/base/simple_action"}, + ) + + interface = ActionInterface.from_action( + action, "/api/base/simple_action", ControllerBase + ) + + assert interface.name == "simple_action" + assert "signal?: AbortSignal" in interface.typehints + assert "Promise" in interface.response_type + assert "StandardResponse" in interface.response_type + + @pytest.mark.parametrize( + "params, expected_strs, expected_default_initializer", + [ + # At least one required parameter, should require user input on functions + ( + [ + create_field_wrapper("required_param", str, True), + create_field_wrapper("optional_param", int, False), + ], + ["required_param: string", "optional_param?: number"], + False, + ), + # All optional parameters, should not require user input on functions + ( + [ + create_field_wrapper("optional_param", int, False), + ], + ["optional_param?: number"], + True, + ), + ], + ) + def test_action_url_query_parameters( + self, + params: list[FieldWrapper], + expected_strs: list[str], + expected_default_initializer: bool, + ): + action = ActionWrapper( + name="parametrized_action", + module_name="test_module", + action_type=FunctionActionType.PASSTHROUGH, + params=params, + headers=[], + request_body=None, + response_bodies={ + ControllerBase: create_model_wrapper( + StandardResponse, "StandardResponse" + ) + }, + exceptions=[], + is_raw_response=False, + is_streaming_response=False, + controller_to_url={ControllerBase: "/api/main/parametrized_action"}, + ) + + interface = ActionInterface.from_action( + action, "/api/main/parametrized_action", ControllerBase + ) + + ts_code = interface.to_js() + for expected_str in expected_strs: + assert expected_str in ts_code + assert interface.default_initializer == expected_default_initializer + + +class TestRequestBodyHandling: + def test_form_action_interface(self, form_data_wrapper): + action = ActionWrapper( + name="form_action", + module_name="test_module", + action_type=FunctionActionType.SIDEEFFECT, + params=[], + headers=[], + request_body=form_data_wrapper, + response_bodies={ + ControllerBase: create_model_wrapper( + StandardResponse, "StandardResponse" + ) + }, + exceptions=[], + is_raw_response=False, + is_streaming_response=False, + controller_to_url={ControllerBase: "/api/main/form"}, + ) + + interface = ActionInterface.from_action( + action, "/api/main/form", ControllerBase + ) + + ts_code = interface.to_js() + assert "requestBody: FormData" in ts_code + assert "mediaType" in "".join(interface.body) + + +class TestResponseTypeHandling: + def test_raw_response_handling(self, standard_response_wrapper: ModelWrapper): + action = ActionWrapper( + name="raw_action", + module_name="test_module", + action_type=FunctionActionType.PASSTHROUGH, + params=[], + headers=[], + request_body=None, + response_bodies={ControllerBase: standard_response_wrapper}, + exceptions=[], + is_raw_response=True, + is_streaming_response=False, + controller_to_url={ControllerBase: "/api/main/raw"}, + ) + + interface = ActionInterface.from_action(action, "/api/main/raw", ControllerBase) + + assert "Promise" in interface.response_type + + def test_streaming_response_handling(self, standard_response_wrapper: ModelWrapper): + action = ActionWrapper( + name="stream_action", + module_name="test_module", + action_type=FunctionActionType.PASSTHROUGH, + params=[], + headers=[], + request_body=None, + response_bodies={ControllerBase: standard_response_wrapper}, + exceptions=[], + is_raw_response=False, + is_streaming_response=True, + controller_to_url={ControllerBase: "/api/main/stream"}, + ) + + interface = ActionInterface.from_action( + action, "/api/main/stream", ControllerBase + ) + + assert "AsyncGenerator" in ts_code + assert "optional_param?: string" in ts_code diff --git a/mountaineer/__tests__/client_builder/interface_builders/test_base.py b/mountaineer/__tests__/client_builder/interface_builders/test_base.py new file mode 100644 index 00000000..551bf073 --- /dev/null +++ b/mountaineer/__tests__/client_builder/interface_builders/test_base.py @@ -0,0 +1,267 @@ +from datetime import date, datetime, time +from enum import Enum +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from uuid import UUID + +import pytest +from fastapi import UploadFile +from pydantic import BaseModel + +from mountaineer.client_builder.interface_builders.base import InterfaceBase +from mountaineer.client_builder.parser import ( + EnumWrapper, + ModelWrapper, + SelfReference, + WrapperName, +) +from mountaineer.client_builder.types import ( + DictOf, + ListOf, + LiteralOf, + Or, + SetOf, + TupleOf, + TypeDefinition, +) + + +# Test Models and Types +class Status(Enum): + ACTIVE = "active" + INACTIVE = "inactive" + PENDING = "pending" + + +class SimpleModel(BaseModel): + name: str + count: int + active: bool + + +class NestedModel(BaseModel): + status: Status + data: SimpleModel + optional: Optional[SimpleModel] + + +class ComplexTypes(BaseModel): + list_field: List[str] + dict_field: Dict[str, int] + set_field: Set[float] + tuple_field: Tuple[str, int] + union_field: Union[str, int] + literal_field: Literal["a", "b", "c"] + + +# Test Implementation +T = TypeVar("T") + + +class TypeConverter(InterfaceBase): + """Helper class for testing the abstract base model""" + + @classmethod + def convert(cls, value: Any) -> str: + return cast(str, cls._get_annotated_value(value)) + + +# Test Fixtures +@pytest.fixture +def model_wrapper() -> ModelWrapper: + return ModelWrapper( + name=WrapperName("SimpleModel"), + module_name="test_module", + model=SimpleModel, + isolated_model=SimpleModel, + superclasses=[], + value_models=[], + ) + + +@pytest.fixture +def enum_wrapper() -> EnumWrapper: + return EnumWrapper( + name=WrapperName("Status"), + module_name="test_module", + enum=Status, + ) + + +@pytest.fixture +def self_reference() -> SelfReference: + return SelfReference( + name="CircularModel", + model=SimpleModel, + ) + + +class TestPrimitiveTypeMapping: + @pytest.mark.parametrize( + "py_type,ts_type", + [ + (str, "string"), + (int, "number"), + (float, "number"), + (bool, "boolean"), + (datetime, "string"), + (date, "string"), + (time, "string"), + (UUID, "string"), + (UploadFile, "Blob"), + (None, "null"), + (Any, "any"), + ], + ) + def test_primitive_type_mapping(self, py_type: Type[Any], ts_type: str) -> None: + result: str = TypeConverter.convert(py_type) + assert result == ts_type + + def test_none_type_variations(self) -> None: + assert TypeConverter.convert(None) == "null" + assert TypeConverter.convert(type(None)) == "null" + + def test_any_type_fallback(self) -> None: + class CustomType: + pass + + result: str = TypeConverter.convert(CustomType) + assert result == "any" + + +class TestComplexTypeHandling: + def test_list_conversion(self) -> None: + type_def: ListOf = ListOf(str) + result: str = TypeConverter.convert(type_def) + assert result == "Array" + + def test_nested_list_conversion(self) -> None: + type_def: ListOf = ListOf(ListOf(str)) + result: str = TypeConverter.convert(type_def) + assert result == "Array>" + + def test_dict_conversion(self) -> None: + type_def: DictOf = DictOf(str, int) + result: str = TypeConverter.convert(type_def) + assert result == "Record" + + def test_set_conversion(self) -> None: + type_def: SetOf = SetOf(float) + result: str = TypeConverter.convert(type_def) + assert result == "Set" + + def test_tuple_conversion(self) -> None: + type_def = TupleOf(str, int, bool) + result: str = TypeConverter.convert(type_def) + assert result == "[string,number,boolean]" + + def test_union_conversion(self) -> None: + type_def = Or(str, int, bool) + result: str = TypeConverter.convert(type_def) + assert result == "string | number | boolean" + + def test_literal_conversion(self) -> None: + type_def = LiteralOf("active", "inactive") + result: str = TypeConverter.convert(type_def) + assert result == '"active" | "inactive"' + + @pytest.mark.parametrize( + "type_def,expected", + [ + (DictOf(str, ListOf(int)), "Record>"), + (ListOf(DictOf(str, bool)), "Array>"), + ( + Or(ListOf(str), DictOf(str, int)), + "Array | Record", + ), + ], + ) + def test_nested_complex_types( + self, type_def: TypeDefinition, expected: str + ) -> None: + result: str = TypeConverter.convert(type_def) + assert result == expected + + +class TestModelHandling: + def test_model_wrapper_conversion(self, model_wrapper: ModelWrapper) -> None: + result: str = TypeConverter.convert(model_wrapper) + assert result == "SimpleModel" + + def test_enum_wrapper_conversion(self, enum_wrapper: EnumWrapper) -> None: + result: str = TypeConverter.convert(enum_wrapper) + assert result == "Status" + + def test_self_reference_conversion(self, self_reference: SelfReference) -> None: + result: str = TypeConverter.convert(self_reference) + assert result == "CircularModel" + + def test_nested_model_references(self, model_wrapper: ModelWrapper) -> None: + type_def: ListOf = ListOf(model_wrapper) + result: str = TypeConverter.convert(type_def) + assert result == "Array" + + +class TestComplexScenarios: + def test_deeply_nested_structure(self) -> None: + # Create a deeply nested structure + deep_type = ListOf( + DictOf(str, Or(ListOf(TupleOf(str, int)), DictOf(str, SetOf(float)))) + ) + result = TypeConverter.convert(deep_type) + expected = "Array | Record>>>" + assert result == expected + + def test_mixed_model_and_primitive_types( + self, model_wrapper: ModelWrapper, enum_wrapper: EnumWrapper + ) -> None: + type_def = Or(model_wrapper, ListOf(enum_wrapper), DictOf(str, int)) + result: str = TypeConverter.convert(type_def) + assert "SimpleModel | Array | Record" == result + + def test_complex_union_types(self) -> None: + # Test union with various nested types + type_def = Or( + ListOf(str), + DictOf(str, bool), + SetOf(int), + TupleOf(str, int), + LiteralOf("a", "b"), + ) + result: str = TypeConverter.convert(type_def) + expected: str = 'Array | Record | Set | [string,number] | "a" | "b"' + assert result == expected + + @pytest.mark.parametrize( + "type_def,expected", + [ + ( + DictOf(LiteralOf("id", "name"), Or(str, int)), + 'Record<"id" | "name", string | number>', + ), + ( + ListOf(TupleOf(LiteralOf("GET", "POST"), str)), + 'Array<["GET" | "POST",string]>', + ), + ( + SetOf(Or(LiteralOf(1, 2), LiteralOf("a", "b"))), + 'Set<1 | 2 | "a" | "b">', + ), + ], + ) + def test_complex_literal_combinations( + self, type_def: TypeDefinition, expected: str + ) -> None: + result: str = TypeConverter.convert(type_def) + assert result == expected diff --git a/mountaineer/__tests__/client_builder/interface_builders/test_controller.py b/mountaineer/__tests__/client_builder/interface_builders/test_controller.py new file mode 100644 index 00000000..0bc26a27 --- /dev/null +++ b/mountaineer/__tests__/client_builder/interface_builders/test_controller.py @@ -0,0 +1,172 @@ +from datetime import datetime +from typing import Any, Optional + +from pydantic import BaseModel + +from mountaineer.__tests__.client_builder.interface_builders.common import ( + create_action_wrapper, + create_controller_wrapper, + create_model_wrapper, +) +from mountaineer.actions.fields import FunctionActionType +from mountaineer.client_builder.interface_builders.controller import ControllerInterface +from mountaineer.client_builder.parser import ( + FieldWrapper, +) +from mountaineer.client_builder.types import Or + + +# Test Models +class SimpleResponse(BaseModel): + message: str + timestamp: datetime + + +class ComplexResponse(BaseModel): + data: dict[str, Any] + status: bool + metadata: Optional[dict[str, str]] = None + + +class FormData(BaseModel): + name: str + email: str + preferences: dict[str, bool] = {} + + +class TestBasicInterfaceGeneration: + def test_simple_controller_interface(self): + health_check = create_action_wrapper( + "health_check", response_model=SimpleResponse + ) + + controller = create_controller_wrapper( + "ApiBaseController", actions={"health_check": health_check} + ) + + interface = ControllerInterface.from_controller(controller) + ts_code = interface.to_js() + + assert "export interface" in ts_code + assert "ApiBaseController" in ts_code + assert "health_check" in ts_code + assert "SimpleResponse" in ts_code + + def test_controller_with_multiple_actions(self): + get_resource = create_action_wrapper( + "get_resource", + params=[FieldWrapper("id", str, True)], + response_model=SimpleResponse, + ) + + update_resource = create_action_wrapper( + "update_resource", + request_body=create_model_wrapper(FormData, "FormData"), + response_model=ComplexResponse, + action_type=FunctionActionType.SIDEEFFECT, + ) + + controller = create_controller_wrapper( + "ResourceController", + actions={"get_resource": get_resource, "update_resource": update_resource}, + ) + + interface = ControllerInterface.from_controller(controller) + ts_code = interface.to_js() + + assert "get_resource" in ts_code + assert "update_resource" in ts_code + assert "params: {" in ts_code + assert "FormData" in ts_code + + +class TestInheritanceHandling: + def test_single_inheritance(self): + # Create base controller + base_controller = create_controller_wrapper( + "ApiBaseController", + actions={ + "health_check": create_action_wrapper( + "health_check", response_model=SimpleResponse + ) + }, + ) + + # Create resource controller inheriting from base + resource_controller = create_controller_wrapper( + "ResourceController", + actions={ + "get_resource": create_action_wrapper( + "get_resource", response_model=SimpleResponse + ) + }, + superclasses=[base_controller], + ) + + interface = ControllerInterface.from_controller(resource_controller) + + assert "ApiBaseController" in interface.include_superclasses + assert "extends ApiBaseController" in interface.to_js() + + def test_multi_level_inheritance(self): + # Create the inheritance chain + base_controller = create_controller_wrapper("ApiBaseController") + resource_controller = create_controller_wrapper( + "ResourceController", superclasses=[base_controller] + ) + extended_controller = create_controller_wrapper( + "ExtendedController", + actions={ + "specialized_action": create_action_wrapper( + "specialized_action", response_model=ComplexResponse + ) + }, + superclasses=[resource_controller], + ) + + interface = ControllerInterface.from_controller(extended_controller) + + # Only the immediate superclasses should be specified + assert "ResourceController" in interface.include_superclasses + assert "extends ResourceController" in interface.to_js() + assert "ApiBaseController" not in interface.to_js() + + +class TestParameterHandling: + def test_optional_parameters(self): + action = create_action_wrapper( + "optional_action", + params=[ + FieldWrapper("param1", Or(str, None), False), + FieldWrapper("param2", int, False), + ], + response_model=SimpleResponse, + ) + + controller = create_controller_wrapper( + "OptionalParamsController", actions={"optional_action": action} + ) + + interface = ControllerInterface.from_controller(controller) + ts_code = interface.to_js() + + assert "params?: {" in ts_code + assert "param1?: string" in ts_code + assert "param2?: number" in ts_code + + def test_required_parameters(self): + action = create_action_wrapper( + "required_action", + params=[FieldWrapper("required_param", str, True)], + response_model=SimpleResponse, + ) + + controller = create_controller_wrapper( + "RequiredParamsController", actions={"required_action": action} + ) + + interface = ControllerInterface.from_controller(controller) + ts_code = interface.to_js() + + assert "params: {" in ts_code + assert "required_param: string" in ts_code diff --git a/mountaineer/__tests__/client_builder/interface_builders/test_enum.py b/mountaineer/__tests__/client_builder/interface_builders/test_enum.py new file mode 100644 index 00000000..a24f704d --- /dev/null +++ b/mountaineer/__tests__/client_builder/interface_builders/test_enum.py @@ -0,0 +1,165 @@ +from enum import Enum, auto +from typing import Union + +import pytest + +from mountaineer.__tests__.client_builder.interface_builders.common import ( + create_enum_wrapper, +) +from mountaineer.client_builder.interface_builders.enum import EnumInterface + + +class TestBasicEnumGeneration: + def test_string_enum(self): + class StringEnum(Enum): + ALPHA = "alpha" + BETA = "beta" + GAMMA = "gamma" + + interface = EnumInterface.from_enum(create_enum_wrapper(StringEnum)) + ts_code = interface.to_js() + + assert "export enum StringEnum" in ts_code + assert "ALPHA = " in ts_code + assert "'alpha'" in ts_code + assert ts_code.count(",") == 2 # Two commas for three values + + def test_number_enum(self): + class NumberEnum(Enum): + ONE = 1 + TWO = 2 + THREE = 3 + + interface = EnumInterface.from_enum(create_enum_wrapper(NumberEnum)) + ts_code = interface.to_js() + + assert "ONE = 1" in ts_code + assert "TWO = 2" in ts_code + assert "THREE = 3" in ts_code + + def test_mixed_type_enum(self): + class MixedEnum(Enum): + STRING = "value" + NUMBER = 42 + BOOLEAN = True + NULL = None + + interface = EnumInterface.from_enum(create_enum_wrapper(MixedEnum)) + ts_code = interface.to_js() + + assert "'value'" in ts_code + assert "42" in ts_code + assert "true" in ts_code + assert "null" in ts_code + + +class TestEnumFormatting: + def test_export_statement(self): + class ExampleEnum(Enum): + A = "a" + B = "b" + + interface = EnumInterface.from_enum(create_enum_wrapper(ExampleEnum)) + + # Test with export + assert interface.to_js().startswith("export enum") + + # Test without export + interface.include_export = False + assert interface.to_js().startswith("enum") + assert "export" not in interface.to_js() + + def test_enum_structure(self): + class ExampleEnum(Enum): + A = "a" + B = "b" + C = "c" + + interface = EnumInterface.from_enum(create_enum_wrapper(ExampleEnum)) + ts_code = interface.to_js() + + assert ts_code.count("{") == 1 + assert ts_code.count("}") == 1 + assert ts_code.count(",") == len(ExampleEnum) - 1 + + @pytest.mark.parametrize( + "value, expected", + [ + ("string", "'string'"), + (42, "42"), + (True, "true"), + (None, "null"), + ], + ) + def test_value_formatting(self, value: str, expected: Union[str, int, bool, None]): + class ValueEnum(Enum): + TEST = value + + interface = EnumInterface.from_enum(create_enum_wrapper(ValueEnum)) + ts_code = interface.to_js() + + assert f"TEST = {expected}" in ts_code + + +class TestComplexCases: + def test_auto_enum(self): + class AutoEnum(Enum): + FIRST = auto() + SECOND = auto() + THIRD = auto() + + interface = EnumInterface.from_enum(create_enum_wrapper(AutoEnum)) + ts_code = interface.to_js() + + assert "FIRST = 1" in ts_code + assert "SECOND = 2" in ts_code + assert "THIRD = 3" in ts_code + + def test_duplicate_values(self): + class DuplicateValueEnum(Enum): + A = "value" + B = "value" + C = "value" + + interface = EnumInterface.from_enum(create_enum_wrapper(DuplicateValueEnum)) + ts_code = interface.to_js() + + assert "A = " in ts_code + assert "B = " in ts_code + assert "C = " in ts_code + assert ts_code.count("'value'") == 3 + + +class TestEdgeCases: + def test_single_member_enum(self): + class SingleEnum(Enum): + ONLY = "only" + + interface = EnumInterface.from_enum(create_enum_wrapper(SingleEnum)) + ts_code = interface.to_js() + + assert "ONLY = " in ts_code + assert ts_code.count(",") == 0 + + def test_special_characters(self): + class SpecialCharEnum(Enum): + DASH_VALUE = "dash-value" + UNDERSCORE_VALUE = "underscore_value" + SPACE_VALUE = "space value" + + interface = EnumInterface.from_enum(create_enum_wrapper(SpecialCharEnum)) + ts_code = interface.to_js() + + assert "DASH_VALUE = 'dash-value'" in ts_code + assert "UNDERSCORE_VALUE = 'underscore_value'" in ts_code + assert "SPACE_VALUE = 'space value'" in ts_code + + def test_empty_enum(self): + class EmptyEnum(Enum): + pass + + interface = EnumInterface.from_enum(create_enum_wrapper(EmptyEnum)) + ts_code = interface.to_js() + + assert "enum EmptyEnum {" in ts_code + assert ts_code.strip().endswith("}") diff --git a/mountaineer/__tests__/client_builder/interface_builders/test_exception.py b/mountaineer/__tests__/client_builder/interface_builders/test_exception.py new file mode 100644 index 00000000..b44c6859 --- /dev/null +++ b/mountaineer/__tests__/client_builder/interface_builders/test_exception.py @@ -0,0 +1,81 @@ +from datetime import datetime +from typing import Any + +from pydantic import BaseModel + +from mountaineer.__tests__.client_builder.interface_builders.common import ( + create_exception_wrapper, + create_field_wrapper, + create_model_wrapper, +) +from mountaineer.client_builder.interface_builders.exception import ExceptionInterface +from mountaineer.client_builder.types import DictOf, ListOf +from mountaineer.exceptions import APIException + + +class TestBasicGeneration: + def test_simple_exception(self): + wrapper = create_exception_wrapper( + APIException, + "SimpleException", + 400, + [create_field_wrapper("message", str), create_field_wrapper("code", int)], + ) + + interface = ExceptionInterface.from_exception(wrapper) + ts_code = interface.to_js() + + assert "export interface SimpleException" in ts_code + assert "message: string" in ts_code + assert "code: number" in ts_code + + def test_optional_fields(self): + wrapper = create_exception_wrapper( + APIException, + "OptionalFieldsException", + 400, + [ + create_field_wrapper("message", str), + create_field_wrapper("details", str, required=False), + create_field_wrapper("timestamp", datetime, required=False), + ], + ) + + interface = ExceptionInterface.from_exception(wrapper) + ts_code = interface.to_js() + + assert "message: string" in ts_code + assert "details?: string" in ts_code + assert "timestamp?: string" in ts_code # datetime converts to string + + def test_complex_fields(self): + class ValidationError(BaseModel): + field: str + message: str + + class ErrorDetail(BaseModel): + code: str + description: str + + wrapper = create_exception_wrapper( + APIException, + "ComplexException", + 500, + [ + create_field_wrapper("data", DictOf(str, Any)), + create_field_wrapper( + "errors", + ListOf(create_model_wrapper(ValidationError, "ValidationError")), + ), + create_field_wrapper( + "details", create_model_wrapper(ErrorDetail, "ErrorDetail") + ), + ], + ) + + interface = ExceptionInterface.from_exception(wrapper) + ts_code = interface.to_js() + + assert "data: Record" in ts_code + assert "errors: Array" in ts_code + assert "details: ErrorDetail" in ts_code diff --git a/mountaineer/__tests__/client_builder/interface_builders/test_model.py b/mountaineer/__tests__/client_builder/interface_builders/test_model.py new file mode 100644 index 00000000..0370ec2a --- /dev/null +++ b/mountaineer/__tests__/client_builder/interface_builders/test_model.py @@ -0,0 +1,151 @@ +from datetime import datetime +from enum import Enum +from typing import Any, Type + +import pytest +from pydantic import BaseModel + +from mountaineer.__tests__.client_builder.interface_builders.common import ( + create_field_wrapper, + create_model_wrapper, +) +from mountaineer.__tests__.client_builder.interface_builders.test_enum import ( + create_enum_wrapper, +) +from mountaineer.client_builder.interface_builders.model import ModelInterface +from mountaineer.client_builder.types import Or + + +class TestBasicInterfaceGeneration: + def test_simple_model_interface(self): + class FieldType(Enum): + STRING = "string" + NUMBER = "number" + + wrapper = create_model_wrapper( + BaseModel, + "SimpleModel", + [ + create_field_wrapper("string_field", str), + create_field_wrapper("int_field", int), + create_field_wrapper("optional_field", Or(str, None), required=False), + create_field_wrapper("enum_field", create_enum_wrapper(FieldType)), + ], + ) + + interface = ModelInterface.from_model(wrapper) + ts_code = interface.to_js() + + assert "export interface SimpleModel" in ts_code + assert "string_field: string" in ts_code + assert "int_field: number" in ts_code + assert "optional_field?: string" in ts_code + assert "enum_field: FieldType" in ts_code + + def test_no_export(self): + wrapper = create_model_wrapper( + BaseModel, "SimpleModel", [create_field_wrapper("field", str)] + ) + + interface = ModelInterface.from_model(wrapper) + interface.include_export = False + ts_code = interface.to_js() + + assert not ts_code.startswith("export") + assert ts_code.startswith("interface SimpleModel") + + @pytest.mark.parametrize( + "field_name,field_type,expected_ts", + [ + ("string_field", str, "string"), + ("int_field", int, "number"), + ("bool_field", bool, "boolean"), + ("float_field", float, "number"), + ("date_field", datetime, "string"), + ], + ) + def test_field_type_conversion( + self, field_name: str, field_type: Type[Any], expected_ts: str + ): + wrapper = create_model_wrapper( + BaseModel, "DynamicModel", [create_field_wrapper(field_name, field_type)] + ) + + interface = ModelInterface.from_model(wrapper) + assert f"{field_name}: {expected_ts}" in interface.to_js() + + +class TestInheritanceHandling: + def test_single_inheritance(self): + parent_wrapper = create_model_wrapper( + BaseModel, + "ParentModel", + [ + create_field_wrapper("parent_field", str), + create_field_wrapper("shared_field", int), + ], + ) + + child_wrapper = create_model_wrapper( + BaseModel, + "ChildModel", + [ + create_field_wrapper("child_field", bool), + create_field_wrapper("shared_field", float), # Override type + ], + superclasses=[parent_wrapper], + ) + + interface = ModelInterface.from_model(child_wrapper) + ts_code = interface.to_js() + + assert "interface ChildModel extends ParentModel" in ts_code + assert "child_field: boolean" in ts_code + assert "shared_field: number" in ts_code + + def test_multiple_inheritance(self): + base1_wrapper = create_model_wrapper( + BaseModel, "MultiInheritBase1", [create_field_wrapper("base1_field", str)] + ) + + base2_wrapper = create_model_wrapper( + BaseModel, "MultiInheritBase2", [create_field_wrapper("base2_field", int)] + ) + + child_wrapper = create_model_wrapper( + BaseModel, + "MultiInheritChild", + [create_field_wrapper("child_field", bool)], + superclasses=[base1_wrapper, base2_wrapper], + ) + + interface = ModelInterface.from_model(child_wrapper) + ts_code = interface.to_js() + + assert "extends MultiInheritBase1, MultiInheritBase2" in ts_code + assert "child_field: boolean" in ts_code + + +class TestEdgeCases: + def test_empty_model(self): + wrapper = create_model_wrapper(BaseModel, "EmptyModel", []) + interface = ModelInterface.from_model(wrapper) + ts_code = interface.to_js() + + assert "interface EmptyModel {\n\n}" in ts_code + + def test_all_optional_fields(self): + wrapper = create_model_wrapper( + BaseModel, + "OptionalModel", + [ + create_field_wrapper("field1", str, required=False), + create_field_wrapper("field2", int, required=False), + ], + ) + + interface = ModelInterface.from_model(wrapper) + ts_code = interface.to_js() + + assert "field1?: string" in ts_code + assert "field2?: number" in ts_code diff --git a/mountaineer/__tests__/client_builder/test_aliases.py b/mountaineer/__tests__/client_builder/test_aliases.py new file mode 100644 index 00000000..5dce8e75 --- /dev/null +++ b/mountaineer/__tests__/client_builder/test_aliases.py @@ -0,0 +1,291 @@ +from enum import Enum +from typing import Generic, Optional, Type, TypeVar, cast + +import pytest +from pydantic import BaseModel + +from mountaineer.client_builder.aliases import AliasManager +from mountaineer.client_builder.parser import ( + ControllerParser, + ControllerWrapper, + EnumWrapper, + ModelWrapper, + SelfReference, + WrapperName, +) +from mountaineer.controller import ControllerBase + +T = TypeVar("T") + + +class TestAliasManager: + @pytest.fixture + def parser(self) -> ControllerParser: + return ControllerParser() + + @pytest.fixture + def alias_manager(self) -> AliasManager: + return AliasManager() + + def test_basic_name_normalization( + self, parser: ControllerParser, alias_manager: AliasManager + ) -> None: + class TestModel(BaseModel): + field: str + + wrapper: ModelWrapper = ModelWrapper( + name=WrapperName(TestModel.__name__), + module_name="test.module", + model=TestModel, + isolated_model=TestModel, + superclasses=[], + value_models=[], + ) + parser.parsed_models[TestModel] = wrapper + + alias_manager.assign_global_names(parser) + assert wrapper.name.global_name == "TestModel" + assert wrapper.name.raw_name == "TestModel" + + def test_global_model_conflict_resolution( + self, parser: ControllerParser, alias_manager: AliasManager + ) -> None: + class User(BaseModel): + name: str + + class User2(BaseModel): + email: str + + User2.__name__ = "User" # Force name conflict + User2.__module__ = "auth.models" + + wrapper1: ModelWrapper = ModelWrapper( + name=WrapperName(User.__name__), + module_name="users.models", + model=User, + isolated_model=User, + superclasses=[], + value_models=[], + ) + + wrapper2: ModelWrapper = ModelWrapper( + name=WrapperName(User2.__name__), + module_name="auth.models", + model=User2, + isolated_model=User2, + superclasses=[], + value_models=[], + ) + + parser.parsed_models[User] = wrapper1 + parser.parsed_models[User2] = wrapper2 + + alias_manager.assign_global_names(parser) + + assert wrapper1.name.global_name == "UsersModels_User" + assert wrapper2.name.global_name == "AuthModels_User" + + def test_cross_type_conflict_resolution( + self, parser: ControllerParser, alias_manager: AliasManager + ) -> None: + class Status1(BaseModel): + code: int + + class Status2(Enum): + ACTIVE = "active" + INACTIVE = "inactive" + + Status1.__name__ = "Status" + Status2.__name__ = "Status" + + model_wrapper: ModelWrapper = ModelWrapper( + name=WrapperName("Status"), + module_name="models.status", + model=Status1, + isolated_model=Status1, + superclasses=[], + value_models=[], + ) + + enum_wrapper: EnumWrapper = EnumWrapper( + name=WrapperName("Status"), module_name="enums.status", enum=Status2 + ) + + parser.parsed_models[Status1] = model_wrapper + parser.parsed_enums[Status2] = enum_wrapper + + alias_manager.assign_global_names(parser) + + assert model_wrapper.name.global_name == "ModelsStatus_Status" + assert enum_wrapper.name.global_name == "EnumsStatus_Status" + + def test_self_reference_updating( + self, parser: ControllerParser, alias_manager: AliasManager + ) -> None: + class Node(BaseModel): + value: str + parent: Optional["Node"] = None + + class Node2(BaseModel): + value: str + parent: Optional[Node] = None + + wrapper: ModelWrapper = ModelWrapper( + name=WrapperName("Node"), + module_name="tree.models", + model=Node, + isolated_model=Node, + superclasses=[], + value_models=[], + ) + + # We just need to insert two of the same value so the duplicate detection + # is triggered, their keys don't matter + parser.parsed_models[Node] = wrapper + parser.parsed_models[Node2] = wrapper + parser.parsed_self_references.append(SelfReference(name="Node", model=Node)) + + alias_manager.assign_global_names(parser) + + assert wrapper.name.global_name == "TreeModels_Node" + assert parser.parsed_self_references[0].name == "TreeModels_Node" + + def test_local_name_resolution( + self, parser: ControllerParser, alias_manager: AliasManager + ) -> None: + class StatusEnum(Enum): + ACTIVE = "active" + INACTIVE = "inactive" + + class UserModel(BaseModel): + name: str + status: StatusEnum + + class TestController(ControllerBase): + url: str = "/test" + + controller_wrapper: ControllerWrapper = ControllerWrapper( + name=WrapperName("TestController"), + module_name="controllers", + entrypoint_url="/test", + controller=TestController, + superclasses=[], + queries=[], + paths=[], + render=None, + actions={}, + ) + + model_wrapper: ModelWrapper = ModelWrapper( + name=WrapperName("UserModel"), + module_name="models", + model=UserModel, + isolated_model=UserModel, + superclasses=[], + value_models=[], + ) + + enum_wrapper: EnumWrapper = EnumWrapper( + name=WrapperName("StatusEnum"), module_name="enums", enum=StatusEnum + ) + + parser.parsed_controllers[TestController] = controller_wrapper + parser.parsed_models[UserModel] = model_wrapper + parser.parsed_enums[StatusEnum] = enum_wrapper + + alias_manager.assign_local_names(parser) + + assert model_wrapper.name.local_name == "UserModel" + assert enum_wrapper.name.local_name == "StatusEnum" + + def test_generic_model_naming( + self, parser: ControllerParser, alias_manager: AliasManager + ) -> None: + class Container(BaseModel, Generic[T]): + value: T + + string_generic = cast(Type[BaseModel], Container[str]) + int_generic = cast(Type[BaseModel], Container[int]) + + wrappers = { + cls: ModelWrapper( + name=WrapperName(cls.__name__), + module_name=cls.__module__, + model=cls, + isolated_model=cls, + superclasses=[], + value_models=[], + ) + for cls in [string_generic, int_generic] + } + + for cls, wrapper in wrappers.items(): + parser.parsed_models[cls] = wrapper + + alias_manager.assign_global_names(parser) + + assert wrappers[string_generic].name.global_name == "ContainerStr" + assert wrappers[int_generic].name.global_name == "ContainerInt" + + @pytest.mark.parametrize( + "original_name, module, expected", + [ + ("A", "test.module", "TestModule_A"), + ("with_underscore", "test.module", "TestModule_WithUnderscore"), + ("MyClass", "a.b.c.d", "ABCD_MyClass"), + ("123Invalid", "test", "Test_123Invalid"), + ("With Space", "test", "Test_WithSpace"), + ], + ) + def test_edge_cases( + self, + parser: ControllerParser, + alias_manager: AliasManager, + original_name: str, + module: str, + expected: str, + ) -> None: + class DynamicModel(BaseModel): + field: str + + class OtherModel(BaseModel): + field: str + + wrapper: ModelWrapper = ModelWrapper( + name=WrapperName(original_name), + module_name=module, + model=DynamicModel, + isolated_model=DynamicModel, + superclasses=[], + value_models=[], + ) + + parser.parsed_models[DynamicModel] = wrapper + parser.parsed_models[OtherModel] = wrapper + alias_manager.assign_global_names(parser) + assert wrapper.name.global_name == expected + + @pytest.mark.parametrize( + "module_path,expected_prefix", + [ + ("users", "Users"), + ("auth.models", "AuthModels"), + ("api.v1.users.models", "ApiV1UsersModels"), + ("my_app.user_models", "MyAppUserModels"), + ("a.b.c", "ABC"), + ("", ""), + ("with.multiple.dots.at.end...", "WithMultipleDotsAtEnd"), + ("___internal.models", "InternalModels"), + ("api.v2", "ApiV2"), + ( + "complex_name.with_underscores.and_numbers2", + "ComplexNameWithUnderscoresAndNumbers2", + ), + ], + ) + def test_typescript_prefix_from_module( + self, alias_manager: AliasManager, module_path: str, expected_prefix: str + ) -> None: + """Test the module prefix formatting for various module path patterns""" + result: str = alias_manager._typescript_prefix_from_module(module_path) + assert result == expected_prefix diff --git a/mountaineer/__tests__/client_builder/test_parser.py b/mountaineer/__tests__/client_builder/test_parser.py index 04629d37..942ac4e3 100644 --- a/mountaineer/__tests__/client_builder/test_parser.py +++ b/mountaineer/__tests__/client_builder/test_parser.py @@ -1,12 +1,13 @@ from datetime import datetime from enum import Enum from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Generic, Optional, TypeVar, get_args, get_origin +from typing import AsyncIterator, Generic, Optional, TypeVar import pytest -from pydantic import BaseModel +from fastapi import File, Form, UploadFile +from pydantic import BaseModel, ConfigDict, field_validator +from mountaineer.actions.fields import FunctionActionType from mountaineer.actions.passthrough_dec import passthrough from mountaineer.actions.sideeffect_dec import sideeffect from mountaineer.app import AppController @@ -15,528 +16,499 @@ ControllerWrapper, EnumWrapper, ModelWrapper, - SelfReference, ) -from mountaineer.client_builder.types import Or +from mountaineer.client_builder.types import ListOf from mountaineer.controller import ControllerBase from mountaineer.render import RenderBase +# Type variables for generic tests +T = TypeVar("T") +S = TypeVar("S") -# Test Models -class UserRole(Enum): - ADMIN = "admin" - USER = "user" - GUEST = "guest" +# Core test enum +class ExampleEnum(Enum): + A = "a" + B = "b" + C = "c" -class BaseStats(BaseModel): - created_at: datetime - updated_at: datetime +# Core test models +class ExampleModelBase(BaseModel): + string_field: str + int_field: int + enum_field: ExampleEnum -class BaseMetadata(BaseModel): - version: int - is_active: bool = True + @field_validator("string_field") + def validate_string(cls, v): + if len(v) < 3: + raise ValueError("String too short") + return v.upper() -class Location(BaseModel): - city: str - country: str - postal_code: Optional[str] = None +# Generic test models +class GenericTestModel(BaseModel, Generic[T]): + value: T + metadata: str -class UserSettings(BaseModel): - theme: str = "light" - notifications_enabled: bool = True - preferred_language: str = "en" +class MultiGenericTestModel(GenericTestModel[T], Generic[T, S]): + second_value: S -class UserProfile(BaseStats, BaseMetadata): - id: int - username: str - role: UserRole - location: Location - settings: UserSettings - friends: list[int] = [] - last_login: Optional[datetime] = None +class NestedGenericTestModel(BaseModel, Generic[T]): + wrapper: GenericTestModel[T] + list_of: list[GenericTestModel[T]] -class SystemStatus(BaseModel): - status: bool - last_check: datetime +# Inheritance test models +class BaseInheritanceModel(ExampleModelBase): + base_field: str -class UserRoleResponse(BaseModel): - role: UserRole - permissions: list[str] +class LeftInheritanceModel(BaseInheritanceModel): + left_field: int -class DashboardData(RenderBase): - user: UserProfile - pending_notifications: int = 0 +class RightInheritanceModel(BaseInheritanceModel): + right_field: float -class ProfileUpdateRequest(BaseModel): - location: Optional[Location] = None - settings: Optional[UserSettings] = None +class DiamondInheritanceModel(LeftInheritanceModel, RightInheritanceModel): + final_field: bool -class ProfileUpdateResponse(BaseModel): - success: bool - updated_user: UserProfile +# Circular reference model +class CircularModel(BaseModel): + name: str + parent: Optional["CircularModel"] = None -# Test Controllers -class BaseController(ControllerBase): - @passthrough - def get_system_status(self) -> SystemStatus: - return SystemStatus(status=True, last_check=datetime.now()) +CircularModel.model_rebuild() + + +# Controller response models +class ControllerResponse(BaseModel): + message: str + timestamp: datetime + + +class FileUploadResponse(BaseModel): + filename: str + size: int + +class RenderResponse(RenderBase): + data: ExampleModelBase + count: int = 0 -class SharedController(BaseController): + +# Controller hierarchy +class BaseExampleController(ControllerBase): @passthrough - def shared_friends(self, limit: int = 10) -> UserProfile: - return UserProfile( - id=1, - username="test", - role=UserRole.USER, - location=Location(city="Test City", country="Test Country"), - settings=UserSettings(), - created_at=datetime.now(), - updated_at=datetime.now(), - version=1, - ) + def base_action(self) -> ControllerResponse: # type: ignore + pass + + +class ExampleController(BaseExampleController): + url = "/test/{path_param}" + view_path = "/test.tsx" + + async def render(self, path_param: str, query_param: int = 0) -> RenderResponse: # type: ignore + pass @passthrough - def get_user_role(self) -> UserRoleResponse: - return UserRoleResponse(role=UserRole.USER, permissions=[]) - - -class UserDashboardController(SharedController): - url = "/dashboard/" - view_path = "/dashboard/page.tsx" - - async def render(self) -> DashboardData: - return DashboardData( - user=UserProfile( - id=1, - username="test", - role=UserRole.USER, - location=Location(city="Test City", country="Test Country"), - settings=UserSettings(), - created_at=datetime.now(), - updated_at=datetime.now(), - version=1, - ), - pending_notifications=5, - ) + def get_data(self) -> ExampleModelBase: # type: ignore + pass @sideeffect - def update_profile(self, update: ProfileUpdateRequest) -> ProfileUpdateResponse: - return ProfileUpdateResponse( - success=True, - updated_user=UserProfile( - id=1, - username="test", - role=UserRole.USER, - location=Location(city="Test City", country="Test Country"), - settings=UserSettings(), - created_at=datetime.now(), - updated_at=datetime.now(), - version=1, - ), - ) + def update_form( # type: ignore + self, name: str = Form(...), size: int = Form(...) + ) -> FileUploadResponse: # type: ignore + pass + + @sideeffect + def upload_file(self, file: UploadFile = File(...)) -> FileUploadResponse: # type: ignore + pass + + +class SpecialTypesController(ControllerBase): + url = "/test2" + view_path = "/test2.tsx" + + @passthrough(raw_response=True) # type: ignore + async def raw_action(self) -> ExampleModelBase: # type: ignore + pass @passthrough - def get_friends(self, limit: int = 10) -> UserProfile: - return UserProfile( - id=1, - username="test", - role=UserRole.USER, - location=Location(city="Test City", country="Test Country"), - settings=UserSettings(), - created_at=datetime.now(), - updated_at=datetime.now(), - version=1, + async def stream_action(self) -> AsyncIterator[ExampleModelBase]: + yield ExampleModelBase( + string_field="test", int_field=1, enum_field=ExampleEnum.A ) -# Fixtures -@pytest.fixture -def controller_parser(): - return ControllerParser() +# Tests +class TestControllerParser: + @pytest.fixture + def parser(self): + parser = ControllerParser() + app_controller = AppController(view_root=Path()) + app_controller.register(ExampleController()) + return parser + + @pytest.fixture + def base_model_wrapper(self, parser: ControllerParser): + return parser._parse_model(ExampleModelBase) + + @pytest.fixture + def test_controller_wrapper(self, parser: ControllerParser): + return parser.parse_controller(ExampleController) + + def test_parse_enum(self, parser: ControllerParser): + wrapper = parser._parse_enum(ExampleEnum) + assert isinstance(wrapper, EnumWrapper) + assert wrapper.enum == ExampleEnum + assert wrapper.name.raw_name == "ExampleEnum" + + def test_parse_base_model(self, parser: ControllerParser): + wrapper = parser._parse_model(ExampleModelBase) + assert isinstance(wrapper, ModelWrapper) + assert wrapper.model == ExampleModelBase + assert len(wrapper.value_models) == 3 + assert wrapper.superclasses == [] + + def test_parse_generic_model(self, parser: ControllerParser): + wrapper = parser._parse_model(GenericTestModel[str]) + assert isinstance(wrapper, ModelWrapper) + assert wrapper.model == GenericTestModel[str] + assert len(wrapper.value_models) == 2 + + def test_parse_nested_generic_model(self, parser: ControllerParser): + wrapper = parser._parse_model(NestedGenericTestModel[int]) + assert isinstance(wrapper, ModelWrapper) + assert len(wrapper.value_models) == 2 + # Verify nested models were parsed + assert any(isinstance(f.value, ModelWrapper) for f in wrapper.value_models) + + def test_parse_circular_model(self, parser: ControllerParser): + wrapper = parser._parse_model(CircularModel) + assert isinstance(wrapper, ModelWrapper) + assert len(wrapper.value_models) == 2 + assert any(f.name == "parent" for f in wrapper.value_models) + assert len(parser.parsed_self_references) == 1 + + def test_parse_controller_inheritance(self, parser: ControllerParser): + wrapper = parser.parse_controller(ExampleController) + assert isinstance(wrapper, ControllerWrapper) + assert len(wrapper.superclasses) == 1 + assert wrapper.superclasses[0].controller == BaseExampleController + + def test_parse_controller_actions(self, parser: ControllerParser): + wrapper = parser.parse_controller(ExampleController) + assert len(wrapper.actions) == 3 + + # Actions explicitly tied to this class + known_passthrough = {"get_data"} + known_sideeffect = {"update_form", "upload_file"} + + # Verify different action types + assert { + a.name + for a in wrapper.actions.values() + if a.action_type == FunctionActionType.PASSTHROUGH + } == known_passthrough + assert { + a.name + for a in wrapper.actions.values() + if a.action_type == FunctionActionType.SIDEEFFECT + } == known_sideeffect + + def test_parse_render_method(self, parser: ControllerParser): + wrapper = parser.parse_controller(ExampleController) + assert wrapper.render is not None + assert len(wrapper.paths) == 1 + assert len(wrapper.queries) == 1 + + def test_parse_form_action(self, parser: ControllerParser): + wrapper = parser.parse_controller(ExampleController) + action = wrapper.actions.get("update_form") + assert action is not None + assert action.request_body is not None + assert action.request_body.body_type == "application/x-www-form-urlencoded" + + def test_parse_file_upload_action(self, parser: ControllerParser): + wrapper = parser.parse_controller(ExampleController) + action = wrapper.actions.get("upload_file") + assert action is not None + assert action.request_body is not None + assert action.request_body.body_type == "multipart/form-data" + + def test_parse_raw_response(self, parser: ControllerParser): + wrapper = parser.parse_controller(SpecialTypesController) + action = wrapper.actions.get("raw_action") + assert action is not None + assert action.is_raw_response + + def test_parse_server_side_renderer(self, parser: ControllerParser): + wrapper = parser.parse_controller(SpecialTypesController) + action = wrapper.actions.get("stream_action") + assert action is not None + assert action.is_streaming_response + + def test_parse_multiple_response_types(self, parser: ControllerParser): + """ + Test that one action that's shared by multiple children will have + a response type that includes all possible response models. + + """ + + class ResponseA(RenderBase): + pass + + class ResponseB(RenderBase): + pass + + class MultiResponseParent(ControllerBase): + @sideeffect + def multi_action(self) -> None: + pass + + class ResponseAController(MultiResponseParent): + url = "/response_a" + view_path = "/response_a.tsx" + + async def render(self) -> ResponseA: # type: ignore + pass + + class ResponseBController(MultiResponseParent): + url = "/response_b" + view_path = "/response_b.tsx" + + async def render(self) -> ResponseB: # type: ignore + pass + + app_controller = AppController(view_root=Path()) + app_controller.register(ResponseAController()) + app_controller.register(ResponseBController()) + + wrapper: ControllerWrapper = parser.parse_controller(MultiResponseParent) + action = wrapper.actions["multi_action"] + + a_response = action.response_bodies[ResponseAController] + b_response = action.response_bodies[ResponseBController] + + assert a_response + assert b_response + + response_a_sideeffect = next( + field for field in a_response.value_models if field.name == "sideeffect" + ) + response_b_sideeffect = next( + field for field in b_response.value_models if field.name == "sideeffect" + ) + assert isinstance(response_a_sideeffect.value, ModelWrapper) + assert isinstance(response_b_sideeffect.value, ModelWrapper) + + assert response_a_sideeffect.value.model == ResponseA + assert response_b_sideeffect.value.model == ResponseB + + +class TestInheritanceHandling: + @pytest.fixture + def parser(self): + return ControllerParser() + + def test_basic_inheritance(self, parser: ControllerParser): + wrapper = parser._parse_model(BaseInheritanceModel) + assert len(wrapper.superclasses) == 1 + assert wrapper.superclasses[0].model == ExampleModelBase + + def test_diamond_inheritance(self, parser: ControllerParser): + wrapper = parser._parse_model(DiamondInheritanceModel) + assert len([sc.name.global_name for sc in wrapper.superclasses]) == 2 + superclass_models = {s.model for s in wrapper.superclasses} + assert LeftInheritanceModel in superclass_models + assert RightInheritanceModel in superclass_models + + +class TestGenericHandling: + @pytest.fixture + def parser(self): + return ControllerParser() + + def test_basic_generic(self, parser: ControllerParser): + wrapper = parser._parse_model(GenericTestModel[str]) + assert len(wrapper.value_models) == 2 + assert wrapper.value_models[0].value == str + + def test_multi_generic(self, parser: ControllerParser): + wrapper = parser._parse_model(MultiGenericTestModel[str, int]) + assert len(wrapper.value_models) == 1 + assert any( + f.name == "second_value" and f.value == int for f in wrapper.value_models + ) -@pytest.fixture -def app_controller(): - with TemporaryDirectory() as temp_dir_name: - temp_view_path = Path(temp_dir_name) - app_controller = AppController(view_root=temp_view_path) - yield app_controller - - -@pytest.fixture -def parsed_controller( - controller_parser: ControllerParser, app_controller: AppController -): - controller = UserDashboardController() - app_controller.register(controller) - return controller_parser.parse_controller(controller.__class__) - - -# Test Controller Inheritance Structure -def test_controller_inheritance(parsed_controller: ControllerWrapper): - """Test the inheritance structure of the controller""" - assert len(parsed_controller.superclasses) == 2 - - superclass_names = [sc.name.global_name for sc in parsed_controller.superclasses] - assert "SharedController" in superclass_names - assert "BaseController" in superclass_names - - -# Test Base Controller Actions -def test_base_controller_actions(parsed_controller: ControllerWrapper): - """Test the actions defined in the base controller""" - base_controller = parsed_controller.superclasses[1] - assert set(base_controller.actions.keys()) == {"get_system_status"} - - action = base_controller.actions["get_system_status"] - assert len(action.params) == 0 - - # Test response structure - response_model = action.response_bodies[parsed_controller.controller] - assert isinstance(response_model, ModelWrapper) - fields = {f.name: f for f in response_model.value_models} - assert set(fields.keys()) == {"passthrough"} - - system_status = fields["passthrough"].value - assert isinstance(system_status, ModelWrapper) - status_fields = {f.name: f for f in system_status.value_models} - assert set(status_fields.keys()) == {"status", "last_check"} - assert status_fields["status"].value == bool - assert status_fields["last_check"].value == datetime - - -# Test Shared Controller Actions -def test_shared_controller_actions(parsed_controller: ControllerWrapper): - """Test the actions defined in the shared controller""" - shared_controller = parsed_controller.superclasses[0] - assert set(shared_controller.actions.keys()) == {"shared_friends", "get_user_role"} - - # Test shared_friends action - friends_action = shared_controller.actions["shared_friends"] - assert len(friends_action.params) == 1 - assert friends_action.params[0].name == "limit" - assert friends_action.params[0].value == int - assert not friends_action.params[0].required - - # Test get_user_role action - role_action = shared_controller.actions["get_user_role"] - assert len(role_action.params) == 0 - - response_model = role_action.response_bodies[parsed_controller.controller] - assert isinstance(response_model, ModelWrapper) - fields = {f.name: f for f in response_model.value_models} - role_wrapper = fields["passthrough"].value - assert isinstance(role_wrapper, ModelWrapper) - role_fields = {f.name: f for f in role_wrapper.value_models} - assert set(role_fields.keys()) == {"role", "permissions"} - assert isinstance(role_fields["role"].value, EnumWrapper) - assert role_fields["role"].value.enum == UserRole - - -# Test Dashboard Controller Actions -def test_dashboard_controller_actions(parsed_controller: ControllerWrapper): - """Test the actions defined in the dashboard controller""" - assert set(parsed_controller.actions.keys()) == {"update_profile", "get_friends"} - - # Test update_profile action - update_action = parsed_controller.actions["update_profile"] - assert update_action.request_body is not None - - update_fields = {f.name: f for f in update_action.request_body.value_models} - assert set(update_fields.keys()) == {"location", "settings"} - assert not update_fields["location"].required - assert not update_fields["settings"].required - - # Test get_friends action - friends_action = parsed_controller.actions["get_friends"] - assert len(friends_action.params) == 1 - assert friends_action.params[0].name == "limit" - assert not friends_action.params[0].required - - -# Test Render Model Structure -def test_render_model_structure(parsed_controller: ControllerWrapper): - """Test the structure of the render model""" - assert parsed_controller.render is not None - assert len(parsed_controller.render.superclasses) == 0 - - fields = {f.name: f for f in parsed_controller.render.value_models} - assert set(fields.keys()) == {"user", "pending_notifications"} - assert not fields["pending_notifications"].required - - user_profile = fields["user"].value - assert isinstance(user_profile, ModelWrapper) - profile_fields = {f.name: f for f in user_profile.value_models} - assert set(profile_fields.keys()) == { - "id", - "username", - "role", - "location", - "settings", - "friends", - "last_login", - } - - -# Test Complex Model Inheritance -def test_complex_model_inheritance(parsed_controller: ControllerWrapper): - """Test the inheritance structure of complex models""" - assert parsed_controller.render is not None - user_field = parsed_controller.render.value_models[0] - profile = user_field.value - assert isinstance(profile, ModelWrapper) - - assert len(profile.superclasses) == 2 - superclass_names = {sc.model.__name__ for sc in profile.superclasses} - assert superclass_names == {"BaseStats", "BaseMetadata"} - - # Test BaseStats fields - base_stats = next( - sc for sc in profile.superclasses if sc.model.__name__ == "BaseStats" - ) - stats_fields = {f.name: f for f in base_stats.value_models} - assert set(stats_fields.keys()) == {"created_at", "updated_at"} - - # Test BaseMetadata fields - base_metadata = next( - sc for sc in profile.superclasses if sc.model.__name__ == "BaseMetadata" + def test_nested_generic_resolution(self, parser: ControllerParser): + wrapper = parser._parse_model(NestedGenericTestModel[str]) + assert len(wrapper.value_models) == 2 + list_field = next(f for f in wrapper.value_models if f.name == "list_of") + assert isinstance(list_field.value, ListOf) + assert isinstance(list_field.value.children[0], ModelWrapper) + assert "GenericTestModel[str]" in str(list_field.value.children[0].model) + + +class TestControllerWrapperFeatures: + @pytest.fixture + def parser(self): + parser = ControllerParser() + app_controller = AppController(view_root=Path()) + app_controller.register(ExampleController()) + return parser + + def test_all_actions_collection(self, parser: ControllerParser): + wrapper = parser.parse_controller(ExampleController) + action_names = {action.name for action in wrapper.all_actions} + assert action_names == {"base_action", "get_data", "update_form", "upload_file"} + + @pytest.mark.parametrize( + "include_superclasses, expected_models", + [ + (False, {"ExampleModelBase", "FileUploadResponse", "RenderResponse"}), + ( + True, + { + "ExampleModelBase", + "ControllerResponse", + "FileUploadResponse", + "RenderResponse", + }, + ), + ], ) - metadata_fields = {f.name: f for f in base_metadata.value_models} - assert set(metadata_fields.keys()) == {"version", "is_active"} - - -# Test Nested Models -def test_nested_models(parsed_controller: ControllerWrapper): - """Test the structure of nested models""" - assert parsed_controller.render is not None - user_field = parsed_controller.render.value_models[0] - assert isinstance(user_field.value, ModelWrapper) - profile_fields = {f.name: f for f in user_field.value.value_models} - - # Test Location model - location = profile_fields["location"].value - assert isinstance(location, ModelWrapper) - location_fields = {f.name: f for f in location.value_models} - assert set(location_fields.keys()) == {"city", "country", "postal_code"} - assert not location_fields["postal_code"].required - - # Test UserSettings model - settings = profile_fields["settings"].value - assert isinstance(settings, ModelWrapper) - settings_fields = {f.name: f for f in settings.value_models} - assert set(settings_fields.keys()) == { - "theme", - "notifications_enabled", - "preferred_language", - } - assert not settings_fields["theme"].required - assert not settings_fields["notifications_enabled"].required - assert not settings_fields["preferred_language"].required - - -# Test Action Response Wrappers -def test_action_response_wrappers(parsed_controller: ControllerWrapper): - """Test the structure of action response wrappers""" - # Test sideeffect wrapper - update_action = parsed_controller.actions["update_profile"] - - response_model = update_action.response_bodies[parsed_controller.controller] - assert isinstance(response_model, ModelWrapper) - - response_fields = {f.name: f for f in response_model.value_models} - assert set(response_fields.keys()) == {"sideeffect", "passthrough"} - - # Test passthrough wrapper - friends_action = parsed_controller.actions["get_friends"] - response_model = friends_action.response_bodies[parsed_controller.controller] - assert isinstance(response_model, ModelWrapper) - friends_fields = {f.name: f for f in response_model.value_models} - assert set(friends_fields.keys()) == {"passthrough"} - - -# Define a self-referencing model -class CategoryNode(BaseModel): - id: int - name: str - parent: Optional["CategoryNode"] = None - children: list["CategoryNode"] = [] - created_at: datetime - - -# Update forward refs -CategoryNode.model_rebuild() - - -def test_parse_self_referencing_model(controller_parser: ControllerParser): - """Test parsing a Pydantic model with self-referencing fields""" - # Parse the model - parsed_model = controller_parser._parse_model(CategoryNode) - - # Test basic model properties - assert parsed_model.model == CategoryNode - assert len(parsed_model.superclasses) == 0 - - # Get all field names - field_names = {field.name for field in parsed_model.value_models} - assert field_names == {"id", "name", "parent", "children", "created_at"} - - # Get fields by name for detailed testing - fields = {f.name: f for f in parsed_model.value_models} - - # Test primitive fields - assert fields["id"].value == int - assert fields["id"].required - assert fields["name"].value == str - assert fields["name"].required - assert fields["created_at"].value == datetime - assert fields["created_at"].required - - # Test self-referencing fields - parent_field = fields["parent"] - assert not parent_field.required # Optional field - assert isinstance(parent_field.value, Or) - self_reference = parent_field.value.children[0] - assert isinstance(self_reference, SelfReference) - assert self_reference.model == CategoryNode - - children_field = fields["children"] - assert not children_field.required # Has default value - assert hasattr(children_field.value, "children") # Should have type info preserved - - # Verify isolated model contains all direct fields - assert set(parsed_model.isolated_model.model_fields.keys()) == { - "id", - "name", - "parent", - "children", - "created_at", - } - - -# Define a generic model -T = TypeVar("T") + def test_embedded_types_collection( + self, + parser: ControllerParser, + include_superclasses: bool, + expected_models: set[str], + ): + wrapper = parser.parse_controller(ExampleController) + embedded = ControllerWrapper.get_all_embedded_types( + [wrapper], include_superclasses=include_superclasses + ) + model_names = {m.model.__name__ for m in embedded.models} + assert model_names.intersection(expected_models) == expected_models -class GenericValueModel(BaseModel, Generic[T]): - value: T + def test_embedded_controllers(self, parser: ControllerParser): + wrapper = parser.parse_controller(ExampleController) + controllers = ControllerWrapper.get_all_embedded_controllers([wrapper]) + + controller_types = {c.controller for c in controllers} + assert controller_types == {ExampleController, BaseExampleController} + + +class TestIsolatedModelCreation: + @pytest.fixture + def parser(self): + return ControllerParser() + + def test_standard_model_isolation(self, parser: ControllerParser): + class ParentModel(BaseModel): + parent_field: str + shared_field: int = 0 + + class ChildModel(ParentModel): + child_field: str + shared_field: int = 1 # Override parent field + + isolated = parser._create_isolated_model(ChildModel) + + # Should only include fields directly defined in ChildModel + assert set(isolated.model_fields.keys()) == {"child_field", "shared_field"} + assert "parent_field" not in isolated.model_fields + + # Check that field types are preserved + assert isolated.model_fields["child_field"].annotation == str + assert isolated.model_fields["shared_field"].annotation == int + + def test_nested_inheritance_isolation(self, parser: ControllerParser): + class GrandparentModel(BaseModel): + grandparent_field: str + + class ParentModel(GrandparentModel): + parent_field: int + + class ChildModel(ParentModel): + child_field: bool + + isolated = parser._create_isolated_model(ChildModel) + + # Should only include fields from ChildModel + assert set(isolated.model_fields.keys()) == {"child_field"} + assert "parent_field" not in isolated.model_fields + assert "grandparent_field" not in isolated.model_fields + + def test_generic_model_isolation(self, parser: ControllerParser): + class GenericParent(BaseModel, Generic[T]): + parent_field: T + shared_field: str = "parent" + + class GenericChild(GenericParent[int]): + child_field: str + shared_field: str = "child" # Override parent field + + isolated = parser._create_isolated_model(GenericChild) + + # Should only include fields defined directly in GenericChild + assert set(isolated.model_fields.keys()) == {"child_field", "shared_field"} + assert "parent_field" not in isolated.model_fields + + # Verify field types are correctly resolved + assert isolated.model_fields["child_field"].annotation == str + assert isolated.model_fields["shared_field"].annotation == str + + def test_multi_generic_model_isolation(self, parser: ControllerParser): + class MultiGenericParent(BaseModel, Generic[T, S]): + field_t: T + field_s: S + + class MultiGenericChild(MultiGenericParent[str, int]): + child_field: bool + + isolated = parser._create_isolated_model(MultiGenericChild) + + # Should only include fields defined in child + assert set(isolated.model_fields.keys()) == {"child_field"} + assert "field_t" not in isolated.model_fields + assert "field_s" not in isolated.model_fields + + # Verify field type + assert isolated.model_fields["child_field"].annotation == bool + + def test_model_config_preservation(self, parser: ControllerParser): + class CustomModel(BaseModel): + field: str + + model_config = ConfigDict( + str_strip_whitespace=True, + frozen=True, + ) + + isolated = parser._create_isolated_model(CustomModel) + + # Check that model configuration is preserved + assert isolated.model_config["str_strip_whitespace"] is True # type: ignore + assert isolated.model_config["frozen"] is True # type: ignore + + def test_empty_model_isolation(self, parser: ControllerParser): + class EmptyParent(BaseModel): + parent_field: str + + class EmptyChild(EmptyParent): + pass # No direct fields + isolated = parser._create_isolated_model(EmptyChild) -def test_parse_generic_model(controller_parser: ControllerParser): - """Test parsing a Pydantic model with generic fields""" - # Test with different type parameters - StringValue = GenericValueModel[str] - IntValue = GenericValueModel[int] - DateValue = GenericValueModel[datetime] - - # Parse models with different type parameters - string_model = controller_parser._parse_model(StringValue) - int_model = controller_parser._parse_model(IntValue) - date_model = controller_parser._parse_model(DateValue) - - # Test basic model properties - assert string_model.model == StringValue - assert int_model.model == IntValue - assert date_model.model == DateValue - assert len(string_model.superclasses) == 1 - - # Test field structure for string variant - string_fields = {f.name: f for f in string_model.value_models} - assert len(string_fields) == 1 - assert "value" in string_fields - assert string_fields["value"].value == str - assert string_fields["value"].required - - # Test field structure for int variant - int_fields = {f.name: f for f in int_model.value_models} - assert len(int_fields) == 1 - assert "value" in int_fields - assert int_fields["value"].value == int - assert int_fields["value"].required - - # Test field structure for datetime variant - date_fields = {f.name: f for f in date_model.value_models} - assert len(date_fields) == 1 - assert "value" in date_fields - assert date_fields["value"].value == datetime - assert date_fields["value"].required - - # Test that the isolated model preserves the field structure - assert set(string_model.isolated_model.model_fields.keys()) == {"value"} - assert set(int_model.isolated_model.model_fields.keys()) == {"value"} - assert set(date_model.isolated_model.model_fields.keys()) == {"value"} - - -def test_parse_generic_nested(controller_parser: ControllerParser): - # Test with nested generic types - NestedValue = GenericValueModel[GenericValueModel[str]] - nested_model = controller_parser._parse_model(NestedValue) - nested_fields = {f.name: f for f in nested_model.value_models} - - assert len(nested_fields) == 1 - assert "value" in nested_fields - assert nested_fields["value"].required - - # Test that the isolated model preserves the field structure - assert set(nested_model.isolated_model.model_fields.keys()) == {"value"} - - -def test_parse_generic_optional_generic(controller_parser: ControllerParser): - # Test with optional generic types - OptionalValue = GenericValueModel[Optional[str]] - optional_model = controller_parser._parse_model(OptionalValue) - optional_fields = {f.name: f for f in optional_model.value_models} - - assert len(optional_fields) == 1 - assert "value" in optional_fields - assert isinstance( - optional_fields["value"].value, Or - ) # Should be wrapped in Or type - assert str in [child for child in optional_fields["value"].value.children] - - -def test_parse_generic_list_generic(controller_parser): - # Test with list of generic types - ListValue = GenericValueModel[list[str]] - list_model = controller_parser._parse_model(ListValue) - list_fields = {f.name: f for f in list_model.value_models} - - assert len(list_fields) == 1 - assert "value" in list_fields - assert hasattr( - list_fields["value"].value, "children" - ) # Should preserve list type info - assert list_fields["value"].required - - -def test_parse_generic_model_complex(controller_parser: ControllerParser): - # Test with multiple generic parameters - T = TypeVar("T") - S = TypeVar("S") - - class MultiGenericParent(BaseModel, Generic[T]): - value: T - - class MultiGenericModel(MultiGenericParent[T], Generic[T, S]): - first: T - second: S - - ComplexValue = MultiGenericModel[str, list[int]] - complex_model = controller_parser._parse_model(ComplexValue) - isolated_complex = complex_model.isolated_model - - assert set(isolated_complex.model_fields.keys()) == {"first", "second"} - assert isolated_complex.model_fields["first"].annotation == str - assert get_origin(isolated_complex.model_fields["second"].annotation) == list - assert get_args(isolated_complex.model_fields["second"].annotation) == (int,) + # Should have no fields since child defines none directly + assert not isolated.model_fields diff --git a/mountaineer/__tests__/client_builder/test_types.py b/mountaineer/__tests__/client_builder/test_types.py index 1bfb5d63..670d86a5 100644 --- a/mountaineer/__tests__/client_builder/test_types.py +++ b/mountaineer/__tests__/client_builder/test_types.py @@ -82,51 +82,31 @@ def are_type_definitions_equivalent( if type(def1) != type(def2): return False - # Get all relevant attributes (those that define the type) - attrs = [ - attr - for attr in dir(def1) - if not attr.startswith("_") and not callable(getattr(def1, attr)) + non_type_1 = [ + child for child in def1.children if not isinstance(child, TypeDefinition) + ] + non_type_2 = [ + child for child in def2.children if not isinstance(child, TypeDefinition) ] - # Compare each attribute recursively - return all( - TypeComparisonHelpers.are_types_equivalent( - getattr(def1, attr), getattr(def2, attr) - ) - for attr in attrs - ) - - @staticmethod - def describe_type_difference(type1: Any, type2: Any) -> str: - """ - Returns a detailed description of why two types are different. - Useful for debugging test failures. - """ - if type(type1) != type(type2): - return f"Type mismatch: {type(type1)} != {type(type2)}" + if non_type_1 != non_type_2: + return False - if isinstance(type1, Or) and isinstance(type2, Or): - if len(type1.types) != len(type2.types): - return f"Or types have different lengths: {len(type1.types)} != {len(type2.types)}" - return f"Or types contain different types: {type1.types} != {type2.types}" + child_types_1 = [ + child for child in def1.children if isinstance(child, TypeDefinition) + ] + child_types_2 = [ + child for child in def2.children if isinstance(child, TypeDefinition) + ] - if isinstance(type1, LiteralOf) and isinstance(type2, LiteralOf): - if set(type1.values) != set(type2.values): - return f"LiteralOf values differ: {set(type1.values)} != {set(type2.values)}" + if len(child_types_1) != len(child_types_2): + return False - if isinstance(type1, TypeDefinition) and isinstance(type2, TypeDefinition): - attrs = [ - attr - for attr in dir(type1) - if not attr.startswith("_") and not callable(getattr(type1, attr)) - ] - for attr in attrs: - v1, v2 = getattr(type1, attr), getattr(type2, attr) - if not TypeComparisonHelpers.are_types_equivalent(v1, v2): - return f"Attribute '{attr}' differs: {v1} != {v2}" + for a, b in zip(child_types_1, child_types_2): + if not TypeComparisonHelpers.are_types_equivalent(a, b): + return False - return f"Values differ: {type1} != {type2}" + return True @pytest.fixture @@ -181,61 +161,49 @@ def test_get_union_types_invalid(self): class TestModernTypeSyntax: @pytest.mark.parametrize( - "input_type,expected_type,expected_attributes", + "input_type,expected_type", [ - (str | int, Or, {"types": (str, int)}), - (str | None, Or, {"types": (str, type(None))}), - (int | str | float, Or, {"types": (int, str, float)}), - (list[int], ListOf, {"type": int}), - (dict[str, int], DictOf, {"key_type": str, "value_type": int}), - (tuple[str, int], TupleOf, {"types": (str, int)}), - (set[int], SetOf, {"type": int}), - (list[str | int], ListOf, {"type": Or(types=(str, int))}), + (str | int, Or(str, int)), + (str | None, Or(str, type(None))), + (int | str | float, Or(int, str, float)), + (list[int], ListOf(int)), + (dict[str, int], DictOf(str, int)), + (tuple[str, int], TupleOf(str, int)), + (set[int], SetOf(int)), + (list[str | int], ListOf(Or(str, int))), ( dict[str, list[int]], - DictOf, - {"key_type": str, "value_type": ListOf(type=int)}, + DictOf(str, ListOf(int)), ), ( tuple[int, str | None], - TupleOf, - {"types": (int, Or(types=(str, type(None))))}, + TupleOf(int, Or(str, type(None))), ), ( dict[str | int, list[tuple[int, str]]], - DictOf, - { - "key_type": Or(types=(str, int)), - "value_type": ListOf(type=TupleOf(types=(int, str))), - }, + DictOf(Or(str, int), ListOf(TupleOf(int, str))), ), ( list[dict[str, int | None]], - ListOf, - {"type": DictOf(key_type=str, value_type=Or(types=(int, type(None))))}, + ListOf(DictOf(str, Or(int, type(None)))), ), ], ) - def test_modern_type_syntax( - self, parser, type_compare, input_type, expected_type, expected_attributes - ): + def test_modern_type_syntax(self, parser, type_compare, input_type, expected_type): result = parser.parse_type(input_type) - assert isinstance(result, expected_type) - expected = expected_type(**expected_attributes) - assert type_compare.are_types_equivalent(result, expected) + assert isinstance(result, type(expected_type)) + assert type_compare.are_types_equivalent(result, expected_type) def test_complex_modern_nested_types(self, parser, type_compare): complex_type = dict[str, list[tuple[int | None, str] | set[bool]]] result = parser.parse_type(complex_type) expected = DictOf( - key_type=str, - value_type=ListOf( - type=Or( - types=( - TupleOf(types=(Or(types=(int, type(None))), str)), - SetOf(type=bool), - ) + key=str, + value=ListOf( + Or( + TupleOf(Or(int, type(None)), str), + SetOf(bool), ) ), ) @@ -258,43 +226,32 @@ def test_various_modern_combinations(self, parser, input_type): class TestLiteralTypes: @pytest.mark.parametrize( - "input_type,expected_type,expected_attributes", + "input_type,expected_type", [ - (Literal["a", "b"], LiteralOf, {"values": ["a", "b"]}), - (Literal[1, 2, 3], LiteralOf, {"values": [1, 2, 3]}), - (Literal[True, False], LiteralOf, {"values": [True, False]}), - (Literal[None], LiteralOf, {"values": [None]}), - (Literal["a", 1, True, None], LiteralOf, {"values": ["a", 1, True, None]}), - (list[Literal["a", "b"]], ListOf, {"type": LiteralOf(values=["a", "b"])}), + (Literal["a", "b"], LiteralOf("a", "b")), + (Literal[1, 2, 3], LiteralOf(1, 2, 3)), + (Literal[True, False], LiteralOf(True, False)), + (Literal[None], LiteralOf(None)), + (Literal["a", 1, True, None], LiteralOf("a", 1, True, None)), + (list[Literal["a", "b"]], ListOf(LiteralOf("a", "b"))), ( dict[str, Literal[1, 2, 3]], - DictOf, - {"key_type": str, "value_type": LiteralOf(values=[1, 2, 3])}, + DictOf(str, LiteralOf(1, 2, 3)), ), ( Literal["a", "b"] | int, - Or, - {"types": (LiteralOf(values=["a", "b"]), int)}, + Or(LiteralOf("a", "b"), int), ), ( dict[Literal["x", "y"], list[Literal[1, 2] | str]], - DictOf, - { - "key_type": LiteralOf(values=["x", "y"]), - "value_type": ListOf( - type=Or(types=(LiteralOf(values=[1, 2]), str)) - ), - }, + DictOf(LiteralOf("x", "y"), ListOf(type=Or(LiteralOf(1, 2), str))), ), ], ) - def test_literal_types( - self, parser, type_compare, input_type, expected_type, expected_attributes - ): + def test_literal_types(self, parser, type_compare, input_type, expected_type): result = parser.parse_type(input_type) - assert isinstance(result, expected_type) - expected = expected_type(**expected_attributes) - assert type_compare.are_types_equivalent(result, expected) + assert isinstance(result, type(expected_type)) + assert type_compare.are_types_equivalent(result, expected_type) def test_invalid_literal_values(self, parser): with pytest.raises(TypeError): diff --git a/mountaineer/app.py b/mountaineer/app.py index 472bfce5..91b048d7 100644 --- a/mountaineer/app.py +++ b/mountaineer/app.py @@ -727,7 +727,7 @@ def compile_html(
{ssr_html}
{client_import} diff --git a/mountaineer/client_builder/aliases.py b/mountaineer/client_builder/aliases.py index 38b750b8..219eb03a 100644 --- a/mountaineer/client_builder/aliases.py +++ b/mountaineer/client_builder/aliases.py @@ -131,4 +131,5 @@ def assign_local_names(self, parser: ControllerParser): def _typescript_prefix_from_module(self, module: str): module_parts = module.split(".") + module_parts = [module.strip("_") for module in module_parts] return "".join([camelize(component) for component in module_parts]) diff --git a/mountaineer/client_builder/file_generators/base.py b/mountaineer/client_builder/file_generators/base.py index 08b9a65e..94ded48d 100644 --- a/mountaineer/client_builder/file_generators/base.py +++ b/mountaineer/client_builder/file_generators/base.py @@ -32,9 +32,7 @@ def __init__(self, *, managed_path: Path): def build(self): blocks = list(self.script()) blocks = [self.standard_header] + blocks - self.managed_path.write_text( - "\n\n".join("\n".join(block.lines) for block in blocks) - ) + self.managed_path.write_text("\n\n".join(block.content for block in blocks)) @abstractmethod def script(self) -> Generator["CodeBlock", None, None]: @@ -96,3 +94,7 @@ def _get_indent_level(cls, line: str) -> tuple[str, int]: break indent_str += char return indent_str, len(indent_str) + + @property + def content(self): + return "\n".join(self.lines) diff --git a/mountaineer/client_builder/file_generators/globals.py b/mountaineer/client_builder/file_generators/globals.py index e234ac74..43e3a756 100644 --- a/mountaineer/client_builder/file_generators/globals.py +++ b/mountaineer/client_builder/file_generators/globals.py @@ -180,9 +180,11 @@ def script(self): ) link_setters[ # @pierce: 12-11-2024: Mirror the lowercase camelcase convention of previous versions - camelize( - parsed_controller.wrapper.controller.__name__, - uppercase_first_letter=False, + TSLiteral( + camelize( + parsed_controller.wrapper.controller.__name__, + uppercase_first_letter=False, + ) ) ] = TSLiteral(local_name) diff --git a/mountaineer/client_builder/interface_builders/action.py b/mountaineer/client_builder/interface_builders/action.py index 42f5488b..55acb933 100644 --- a/mountaineer/client_builder/interface_builders/action.py +++ b/mountaineer/client_builder/interface_builders/action.py @@ -21,7 +21,6 @@ class ActionInterface(InterfaceBase): default_initializer: bool response_type: str body: list[str] - required_models: list[str] def to_js(self) -> str: script = f"export const {self.name} = ({self.parameters} : {self.typehints}" @@ -43,7 +42,6 @@ def from_action( """ parameters_dict: dict[str, Any] = {} typehint_dict: dict[str, Any] = {} - required_models: list[str] = [] # System parameters (always optional) system_parameters = {"signal": TSLiteral("signal")} @@ -63,10 +61,11 @@ def from_action( model_name = action.request_body.name.global_name parameters_dict["requestBody"] = TSLiteral("requestBody") typehint_dict[TSLiteral("requestBody")] = TSLiteral(model_name) - required_models.append(model_name) # Merge system parameters - has_nonsystem_parameters = bool(parameters_dict) + has_required_nonsystem_parameters = ( + any([param.required for param in action.params]) if action.params else False + ) parameters_dict.update(system_parameters) typehint_dict.update(system_typehints) @@ -77,9 +76,6 @@ def from_action( if controllers: for controller in controllers: response_types.add(cls._get_response_type(action, controller)) - response_body = cls._get_response_body(action, controller) - if response_body: - required_models.append(response_body.name.global_name) else: # Fallback in the case that no concrete controllers are mounted with this action # In this case we just use a generic typehint for the return value @@ -89,10 +85,9 @@ def from_action( name=action.name, parameters=python_payload_to_typescript(parameters_dict), typehints=python_payload_to_typescript(typehint_dict), - default_initializer=not has_nonsystem_parameters, + default_initializer=not has_required_nonsystem_parameters, response_type=" | ".join(response_types), body=["return __request(", CodeBlock.indent(f" {request_payload}"), ");"], - required_models=required_models, ) @classmethod diff --git a/mountaineer/client_builder/interface_builders/enum.py b/mountaineer/client_builder/interface_builders/enum.py index 3163199a..d0c38668 100644 --- a/mountaineer/client_builder/interface_builders/enum.py +++ b/mountaineer/client_builder/interface_builders/enum.py @@ -34,7 +34,8 @@ def from_enum(cls, enum: EnumWrapper): return cls( name=enum.name.global_name, - body=python_payload_to_typescript(fields, dict_equality="="), + # Optional spacing, but make for better enum definitions (A = 'A') + body=python_payload_to_typescript(fields, dict_equality=" ="), ) def to_js(self) -> str: diff --git a/mountaineer/client_builder/parser.py b/mountaineer/client_builder/parser.py index c247e348..c39b67c8 100644 --- a/mountaineer/client_builder/parser.py +++ b/mountaineer/client_builder/parser.py @@ -64,7 +64,7 @@ class CoreWrapper: @dataclass class FieldWrapper: name: str - value: Union[type["ModelWrapper"], type["EnumWrapper"], type] + value: Union["ModelWrapper", "EnumWrapper", "TypeDefinition", type] required: bool @@ -304,8 +304,9 @@ def parse_controller(self, controller: type[ControllerBase]) -> ControllerWrappe return self.parsed_controllers[controller] # Get all valid superclasses in MRO order - base_exclude = (RenderBase, ControllerBase, LayoutControllerBase) - controller_classes = self._get_valid_mro_classes(controller, base_exclude) + controller_classes = self._get_valid_parent_classes( + controller, ControllerBase, (ControllerBase, LayoutControllerBase) + ) # Get render model from the concrete controller render, render_path, render_query, entrypoint_url = self._parse_render( @@ -315,7 +316,7 @@ def parse_controller(self, controller: type[ControllerBase]) -> ControllerWrappe # Parse superclasses superclass_controllers: list[ControllerWrapper] = [] - for superclass in controller_classes[1:]: + for superclass in controller_classes: superclass_controllers.append(self.parse_controller(superclass)) wrapper = ControllerWrapper( @@ -341,11 +342,13 @@ def _parse_model( return self.parsed_models[model] # Get all valid superclasses in MRO order, excluding BaseModel and above - model_classes = self._get_valid_mro_classes(model, (BaseModel, RenderBase)) + model_classes = self._get_valid_parent_classes( + model, BaseModel, (BaseModel, RenderBase) + ) # Parse direct superclasses (excluding the model itself) superclasses: list[ModelWrapper] = [] - for base in model_classes[1:]: # Skip the first class (model itself) + for base in model_classes: if base not in self.parsed_models: # Now parse it properly self.parsed_models[base] = self._parse_model(base) @@ -527,6 +530,10 @@ def _create_temp_route(self, func: Callable, name: str, url: str) -> APIRoute: # from the query params path=f"/{url}", endpoint=func, + # We don't use the FastAPI's sniffed response model parsing in our pipeline, and + # some definitions like server-side streaming AsyncIterables can't be handled + # natively by FastAPI + response_model=None, ) route = next( @@ -704,17 +711,19 @@ def _parse_exception(self, exception: Type[APIException]): self.parsed_exceptions[exception] = wrapper return wrapper - def _get_valid_mro_classes( - self, cls: type, base_exclude_classes: tuple[type, ...] + def _get_valid_parent_classes( + self, cls: type, base_require: type, base_exclude_classes: tuple[type, ...] ) -> list[type]: - """Helper to get valid MRO classes, excluding certain base classes and anything above them""" - mro = [] - for base in cls.__mro__: - # Stop when we hit any of the excluded base classes - if base in base_exclude_classes: - break - # Skip object() as well - if base is object: - continue - mro.append(base) - return mro + """ + Helper to get valid MRO parents, excluding certain base classes + + """ + return [ + base + for base in cls.__bases__ + if ( + base not in base_exclude_classes + and base is not object + and issubclass(base, base_require) + ) + ] diff --git a/mountaineer/client_builder/types.py b/mountaineer/client_builder/types.py index 7c0d959d..2b9aecbb 100644 --- a/mountaineer/client_builder/types.py +++ b/mountaineer/client_builder/types.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod, abstractproperty -from dataclasses import dataclass from types import UnionType from typing import ( Any, @@ -43,17 +42,17 @@ def update_children(self, children: list[Any]): """Return a new instance with updated children""" raise NotImplementedError + def __repr__(self): + return f"{self.__class__.__name__}({', '.join(repr(child) for child in self.children)})" + -@dataclass class Or(TypeDefinition): """Represents a Union type""" types: tuple[Any, ...] - def __class_getitem__(cls, types): - if not isinstance(types, tuple): - types = (types,) - return cls(types=types) + def __init__(self, *types): + self.types = types @property def children(self): @@ -63,14 +62,13 @@ def update_children(self, children): self.types = tuple(children) -@dataclass class ListOf(TypeDefinition): """Represents a List type""" type: Any - def __class_getitem__(cls, type): - return cls(type=type) + def __init__(self, type): + self.type = type @property def children(self): @@ -81,18 +79,15 @@ def update_children(self, children): self.type = children[0] -@dataclass class DictOf(TypeDefinition): """Represents a Dict type""" key_type: Any value_type: Any - def __class_getitem__(cls, types): - if not isinstance(types, tuple) or len(types) != 2: - raise ValueError("DictOf requires exactly two type parameters") - key_type, value_type = types - return cls(key_type=key_type, value_type=value_type) + def __init__(self, key, value): + self.key_type = key + self.value_type = value @property def children(self): @@ -104,16 +99,13 @@ def update_children(self, children): self.value_type = children[1] -@dataclass class TupleOf(TypeDefinition): """Represents a Tuple type""" types: tuple[Any, ...] - def __class_getitem__(cls, types): - if not isinstance(types, tuple): - types = (types,) - return cls(types=types) + def __init__(self, *types): + self.types = types @property def children(self): @@ -123,14 +115,13 @@ def update_children(self, children): self.types = tuple(children) -@dataclass class SetOf(TypeDefinition): """Represents a Set type""" type: Any - def __class_getitem__(cls, type_): - return cls(type=type_) + def __init__(self, type_): + self.type = type_ @property def children(self): @@ -140,21 +131,13 @@ def update_children(self, children): self.types = tuple(children) -@dataclass class LiteralOf(TypeDefinition): """Represents a Literal type""" values: list[Any] - def __class_getitem__(cls, values): - return cls(values=values) - - def __post_init__(self): - """Validate that all values are primitive types""" - # Convert None-like values to None - self.values = [ - value if not is_none_type(value) else None for value in self.values - ] + def __init__(self, *values): + self.values = [value if not is_none_type(value) else None for value in values] self._validate_primitive_values(self.values) @staticmethod @@ -162,6 +145,7 @@ def _validate_primitive_values(values: List[Any]) -> None: """ Ensures all values are primitive types (str, int, float, bool, None). Raises TypeError for non-primitive values. + """ for value in values: if not isinstance(value, (str, int, float, bool)) and value is not None: @@ -200,7 +184,7 @@ def parse_type(self, field_type: Any): # Handle unions if is_union_type(field_type): union_types = get_union_types(field_type) - return Or(types=tuple(self.parse_type(arg) for arg in union_types)) + return Or(*[self.parse_type(arg) for arg in union_types]) # Handle built-in collections origin_type = get_origin(field_type) @@ -219,16 +203,16 @@ def _parse_origin_type(self, field_type: Any, origin_type: Any) -> TypeDefinitio return ListOf(type=args[0]) if origin_type in (dict, Dict): - return DictOf(key_type=args[0], value_type=args[1]) + return DictOf(key=args[0], value=args[1]) if origin_type in (tuple, Tuple): - return TupleOf(types=args) + return TupleOf(*args) if origin_type in (set, Set): - return SetOf(type=args[0]) + return SetOf(args[0]) if origin_type is Literal: - return LiteralOf(values=list(args)) + return LiteralOf(*args) # For unknown origin types, wrap in Or raise ValueError(f"Unsupported origin type: {origin_type}") @@ -237,12 +221,12 @@ def _parse_basic_type(self, field_type: Any) -> Any: """Parse basic types without args""" if isinstance(field_type, type): if issubclass(field_type, (list, List)): - return ListOf(type=Any) + return ListOf(Any) elif issubclass(field_type, (dict, Dict)): - return DictOf(key_type=Any, value_type=Any) + return DictOf(key=Any, value=Any) elif issubclass(field_type, (tuple, Tuple)): # type: ignore - return TupleOf(types=(Any,)) + return TupleOf(Any) elif issubclass(field_type, (set, Set)): - return SetOf(type=Any) + return SetOf(Any) return field_type diff --git a/mountaineer/ssr.py b/mountaineer/ssr.py index 694402bf..65dd975d 100644 --- a/mountaineer/ssr.py +++ b/mountaineer/ssr.py @@ -71,7 +71,7 @@ def render_ssr( polyfill_script = get_static_path("ssr_polyfills.js").read_text() data_json = json_dumps(render_data) - injected_script = f"const SERVER_DATA = {data_json};\n{polyfill_script}\n" + injected_script = f"var SERVER_DATA = {data_json};\n{polyfill_script}\n" full_script = f"{injected_script}{script}" try: diff --git a/poetry.lock b/poetry.lock index 74364353..32acd223 100644 --- a/poetry.lock +++ b/poetry.lock @@ -153,13 +153,13 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.114.1" +version = "0.114.2" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.114.1-py3-none-any.whl", hash = "sha256:5d4746f6e4b7dff0b4f6b6c6d5445645285f662fe75886e99af7ee2d6b58bb3e"}, - {file = "fastapi-0.114.1.tar.gz", hash = "sha256:1d7bbbeabbaae0acb0c22f0ab0b040f642d3093ca3645f8c876b6f91391861d8"}, + {file = "fastapi-0.114.2-py3-none-any.whl", hash = "sha256:44474a22913057b1acb973ab90f4b671ba5200482e7622816d79105dcece1ac5"}, + {file = "fastapi-0.114.2.tar.gz", hash = "sha256:0adb148b62edb09e8c6eeefa3ea934e8f276dabc038c5a82989ea6346050c3da"}, ] [package.dependencies] @@ -835,6 +835,17 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-multipart" +version = "0.0.19" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python_multipart-0.0.19-py3-none-any.whl", hash = "sha256:f8d5b0b9c618575bf9df01c684ded1d94a338839bdd8223838afacfb4bb2082d"}, + {file = "python_multipart-0.0.19.tar.gz", hash = "sha256:905502ef39050557b7a6af411f454bc19526529ca46ae6831508438890ce12cc"}, +] + [[package]] name = "pyyaml" version = "6.0.1" @@ -1070,13 +1081,13 @@ SQLAlchemy = ">=2.0.14,<2.1.0" [[package]] name = "starlette" -version = "0.38.5" +version = "0.38.6" description = "The little ASGI library that shines." optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.38.5-py3-none-any.whl", hash = "sha256:632f420a9d13e3ee2a6f18f437b0a9f1faecb0bc42e1942aa2ea0e379a4c4206"}, - {file = "starlette-0.38.5.tar.gz", hash = "sha256:04a92830a9b6eb1442c766199d62260c3d4dc9c4f9188360626b1e0273cb7077"}, + {file = "starlette-0.38.6-py3-none-any.whl", hash = "sha256:4517a1409e2e73ee4951214ba012052b9e16f60e90d73cfb06192c19203bbb05"}, + {file = "starlette-0.38.6.tar.gz", hash = "sha256:863a1588f5574e70a821dadefb41e4881ea451a47a3cd1b4df359d4ffefe5ead"}, ] [package.dependencies] @@ -1502,4 +1513,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "0598ebbf3c38b117643e2504e034e35c5e64599db7aadfd76957fe33141bc999" +content-hash = "09e9494dd1e2d4415a92216890a528ff98562ae6d879fc64dfff9ffc25b9cd5f" diff --git a/pyproject.toml b/pyproject.toml index 818b0e10..6e3c22a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ tqdm = "^4.66.2" toml = "^0.10.2" types-toml = "^0.10.8.20240310" types-pygments = "^2.18.0.20240506" +python-multipart = "^0.0.19" [build-system] # Pin for maturin until the latest pypi distribution action supports diff --git a/src/benches/fixtures/ssr_polyfill_archive.js b/src/benches/fixtures/ssr_polyfill_archive.js index 7f34d112..f15bde4c 100644 --- a/src/benches/fixtures/ssr_polyfill_archive.js +++ b/src/benches/fixtures/ssr_polyfill_archive.js @@ -10,4 +10,4 @@ class TextEncoder { } } -const SERVER_DATA = {}; +var SERVER_DATA = {};