diff --git a/connexion/middleware/__init__.py b/connexion/middleware/__init__.py index 302bc67c7..342848f42 100644 --- a/connexion/middleware/__init__.py +++ b/connexion/middleware/__init__.py @@ -1,4 +1,4 @@ -from .abstract import AppMiddleware # NOQA +from .abstract import AppMiddleware, RoutedMiddleware # NOQA from .main import ConnexionMiddleware # NOQA from .routing import RoutingMiddleware # NOQA from .swagger_ui import SwaggerUIMiddleware # NOQA diff --git a/connexion/middleware/abstract.py b/connexion/middleware/abstract.py index 69c065cb4..75c263c31 100644 --- a/connexion/middleware/abstract.py +++ b/connexion/middleware/abstract.py @@ -1,7 +1,21 @@ import abc +import logging import pathlib import typing as t +import typing_extensions as te +from starlette.types import ASGIApp, Receive, Scope, Send + +from connexion.apis.abstract import AbstractSpecAPI +from connexion.exceptions import MissingMiddleware +from connexion.http_facts import METHODS +from connexion.operations import AbstractOperation +from connexion.resolver import ResolverError + +logger = logging.getLogger("connexion.middleware.abstract") + +ROUTING_CONTEXT = "connexion_routing" + class AppMiddleware(abc.ABC): """Middlewares that need the APIs to be registered on them should inherit from this base @@ -12,3 +26,117 @@ def add_api( self, specification: t.Union[pathlib.Path, str, dict], **kwargs ) -> None: pass + + +class RoutedOperation(te.Protocol): + def __init__(self, next_app: ASGIApp, **kwargs) -> None: + ... + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + ... + + +OP = t.TypeVar("OP", bound=RoutedOperation) + + +class RoutedAPI(AbstractSpecAPI, t.Generic[OP]): + + operation_cls: t.Type[OP] + """The operation this middleware uses, which should implement the RoutingOperation protocol.""" + + def __init__( + self, + specification: t.Union[pathlib.Path, str, dict], + *args, + next_app: ASGIApp, + **kwargs, + ) -> None: + super().__init__(specification, *args, **kwargs) + self.next_app = next_app + self.operations: t.MutableMapping[str, OP] = {} + + def add_paths(self) -> None: + paths = self.specification.get("paths", {}) + for path, methods in paths.items(): + for method in methods: + if method not in METHODS: + continue + try: + self.add_operation(path, method) + except ResolverError: + # ResolverErrors are either raised or handled in routing middleware. + pass + + def add_operation(self, path: str, method: str) -> None: + operation_spec_cls = self.specification.operation_cls + operation = operation_spec_cls.from_spec( + self.specification, self, path, method, self.resolver + ) + routed_operation = self.make_operation(operation) + self.operations[operation.operation_id] = routed_operation + + @abc.abstractmethod + def make_operation(self, operation: AbstractOperation) -> OP: + """Create an operation of the `operation_cls` type.""" + raise NotImplementedError + + +API = t.TypeVar("API", bound="RoutedAPI") + + +class RoutedMiddleware(AppMiddleware, t.Generic[API]): + """Baseclass for middleware that wants to leverage the RoutingMiddleware to route requests to + its operations. + + The RoutingMiddleware adds the operation_id to the ASGI scope. This middleware registers its + operations by operation_id at startup. At request time, the operation is fetched by an + operation_id lookup. + """ + + api_cls: t.Type[API] + """The subclass of RoutedAPI this middleware uses.""" + + def __init__(self, app: ASGIApp) -> None: + self.app = app + self.apis: t.Dict[str, API] = {} + + def add_api( + self, specification: t.Union[pathlib.Path, str, dict], **kwargs + ) -> None: + api = self.api_cls(specification, next_app=self.app, **kwargs) + self.apis[api.base_path] = api + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Fetches the operation related to the request and calls it.""" + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + try: + connexion_context = scope["extensions"][ROUTING_CONTEXT] + except KeyError: + raise MissingMiddleware( + "Could not find routing information in scope. Please make sure " + "you have a routing middleware registered upstream. " + ) + api_base_path = connexion_context.get("api_base_path") + if api_base_path: + api = self.apis[api_base_path] + operation_id = connexion_context.get("operation_id") + try: + operation = api.operations[operation_id] + except KeyError as e: + if operation_id is None: + logger.debug("Skipping validation check for operation without id.") + await self.app(scope, receive, send) + return + else: + raise MissingOperation("Encountered unknown operation_id.") from e + else: + return await operation(scope, receive, send) + + await self.app(scope, receive, send) + + +class MissingOperation(Exception): + """Missing operation""" diff --git a/connexion/middleware/exceptions.py b/connexion/middleware/exceptions.py index 81ea509ea..e23102299 100644 --- a/connexion/middleware/exceptions.py +++ b/connexion/middleware/exceptions.py @@ -39,7 +39,7 @@ def problem_handler(self, _, exception: ProblemException): def http_exception(self, request: Request, exc: HTTPException) -> Response: try: - headers = exc.headers + headers = exc.headers # type: ignore except AttributeError: # Starlette < 0.19 headers = {} diff --git a/connexion/middleware/routing.py b/connexion/middleware/routing.py index 54d409597..905ecb315 100644 --- a/connexion/middleware/routing.py +++ b/connexion/middleware/routing.py @@ -6,61 +6,36 @@ from starlette.types import ASGIApp, Receive, Scope, Send from connexion.apis import AbstractRoutingAPI -from connexion.middleware import AppMiddleware +from connexion.middleware.abstract import ROUTING_CONTEXT, AppMiddleware from connexion.operations import AbstractOperation from connexion.resolver import Resolver -ROUTING_CONTEXT = "connexion_routing" - - _scope: ContextVar[dict] = ContextVar("SCOPE") -class RoutingMiddleware(AppMiddleware): - def __init__(self, app: ASGIApp) -> None: - """Middleware that resolves the Operation for an incoming request and attaches it to the - scope. - - :param app: app to wrap in middleware. - """ - self.app = app - # Pass unknown routes to next app - self.router = Router(default=RoutingOperation(None, self.app)) - - def add_api( - self, - specification: t.Union[pathlib.Path, str, dict], - base_path: t.Optional[str] = None, - arguments: t.Optional[dict] = None, - **kwargs - ) -> None: - """Add an API to the router based on a OpenAPI spec. +class RoutingOperation: + def __init__(self, operation_id: t.Optional[str], next_app: ASGIApp) -> None: + self.operation_id = operation_id + self.next_app = next_app - :param specification: OpenAPI spec as dict or path to file. - :param base_path: Base path where to add this API. - :param arguments: Jinja arguments to replace in the spec. - """ - api = RoutingAPI( - specification, - base_path=base_path, - arguments=arguments, - next_app=self.app, - **kwargs - ) - self.router.mount(api.base_path, app=api.router) + @classmethod + def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp): + return cls(operation.operation_id, next_app) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Route request to matching operation, and attach it to the scope before calling the - next app.""" - if scope["type"] != "http": - await self.app(scope, receive, send) - return + """Attach operation to scope and pass it to the next app""" + original_scope = _scope.get() - _scope.set(scope.copy()) # type: ignore + api_base_path = scope.get("root_path", "")[ + len(original_scope.get("root_path", "")) : + ] - # Needs to be set so starlette router throws exceptions instead of returning error responses - scope["app"] = self - await self.router(scope, receive, send) + extensions = original_scope.setdefault("extensions", {}) + connexion_routing = extensions.setdefault(ROUTING_CONTEXT, {}) + connexion_routing.update( + {"api_base_path": api_base_path, "operation_id": self.operation_id} + ) + await self.next_app(original_scope, receive, send) class RoutingAPI(AbstractRoutingAPI): @@ -105,26 +80,48 @@ def _add_operation_internal( self.router.add_route(path, operation, methods=[method]) -class RoutingOperation: - def __init__(self, operation_id: t.Optional[str], next_app: ASGIApp) -> None: - self.operation_id = operation_id - self.next_app = next_app +class RoutingMiddleware(AppMiddleware): + def __init__(self, app: ASGIApp) -> None: + """Middleware that resolves the Operation for an incoming request and attaches it to the + scope. - @classmethod - def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp): - return cls(operation.operation_id, next_app) + :param app: app to wrap in middleware. + """ + self.app = app + # Pass unknown routes to next app + self.router = Router(default=RoutingOperation(None, self.app)) + + def add_api( + self, + specification: t.Union[pathlib.Path, str, dict], + base_path: t.Optional[str] = None, + arguments: t.Optional[dict] = None, + **kwargs + ) -> None: + """Add an API to the router based on a OpenAPI spec. + + :param specification: OpenAPI spec as dict or path to file. + :param base_path: Base path where to add this API. + :param arguments: Jinja arguments to replace in the spec. + """ + api = RoutingAPI( + specification, + base_path=base_path, + arguments=arguments, + next_app=self.app, + **kwargs + ) + self.router.mount(api.base_path, app=api.router) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Attach operation to scope and pass it to the next app""" - original_scope = _scope.get() + """Route request to matching operation, and attach it to the scope before calling the + next app.""" + if scope["type"] != "http": + await self.app(scope, receive, send) + return - api_base_path = scope.get("root_path", "")[ - len(original_scope.get("root_path", "")) : - ] + _scope.set(scope.copy()) # type: ignore - extensions = original_scope.setdefault("extensions", {}) - connexion_routing = extensions.setdefault(ROUTING_CONTEXT, {}) - connexion_routing.update( - {"api_base_path": api_base_path, "operation_id": self.operation_id} - ) - await self.next_app(original_scope, receive, send) + # Needs to be set so starlette router throws exceptions instead of returning error responses + scope["app"] = self + await self.router(scope, receive, send) diff --git a/connexion/middleware/security.py b/connexion/middleware/security.py index 485649f73..376f5286b 100644 --- a/connexion/middleware/security.py +++ b/connexion/middleware/security.py @@ -1,136 +1,28 @@ import logging -import pathlib import typing as t from collections import defaultdict from starlette.types import ASGIApp, Receive, Scope, Send -from connexion.apis.abstract import AbstractSpecAPI -from connexion.exceptions import MissingMiddleware, ProblemException -from connexion.http_facts import METHODS +from connexion.exceptions import ProblemException from connexion.lifecycle import MiddlewareRequest -from connexion.middleware import AppMiddleware -from connexion.middleware.routing import ROUTING_CONTEXT +from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation -from connexion.resolver import ResolverError from connexion.security import SecurityHandlerFactory -from connexion.spec import Specification logger = logging.getLogger("connexion.middleware.security") -class SecurityMiddleware(AppMiddleware): - """Middleware to check if operation is accessible on scope.""" - - def __init__(self, app: ASGIApp) -> None: - self.app = app - self.apis: t.Dict[str, SecurityAPI] = {} - - def add_api( - self, specification: t.Union[pathlib.Path, str, dict], **kwargs - ) -> None: - api = SecurityAPI(specification, **kwargs) - self.apis[api.base_path] = api - - async def __call__(self, scope: Scope, receive: Receive, send: Send): - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - try: - connexion_context = scope["extensions"][ROUTING_CONTEXT] - except KeyError: - raise MissingMiddleware( - "Could not find routing information in scope. Please make sure " - "you have a routing middleware registered upstream. " - ) - - api_base_path = connexion_context.get("api_base_path") - if api_base_path: - api = self.apis[api_base_path] - operation_id = connexion_context.get("operation_id") - try: - operation = api.operations[operation_id] - except KeyError as e: - if operation_id is None: - logger.debug( - "Skipping security check for operation without id. Enable " - "`auth_all_paths` to check security for unknown operations." - ) - else: - raise MissingSecurityOperation( - "Encountered unknown operation_id." - ) from e - - else: - request = MiddlewareRequest(scope) - await operation(request) - - await self.app(scope, receive, send) - - -class SecurityAPI(AbstractSpecAPI): - def __init__( - self, - specification: t.Union[pathlib.Path, str, dict], - auth_all_paths: bool = False, - *args, - **kwargs - ): - super().__init__(specification, *args, **kwargs) - self.security_handler_factory = SecurityHandlerFactory() - - if auth_all_paths: - self.add_auth_on_not_found() - else: - self.operations: t.Dict[str, SecurityOperation] = {} - - self.add_paths() - - def add_auth_on_not_found(self): - """Register a default SecurityOperation for routes that are not found.""" - default_operation = self.make_operation(self.specification) - self.operations = defaultdict(lambda: default_operation) - - def add_paths(self): - paths = self.specification.get("paths", {}) - for path, methods in paths.items(): - for method in methods: - if method not in METHODS: - continue - try: - self.add_operation(path, method) - except ResolverError: - # ResolverErrors are either raised or handled in routing middleware. - pass - - def add_operation(self, path: str, method: str) -> None: - operation_cls = self.specification.operation_cls - operation = operation_cls.from_spec( - self.specification, self, path, method, self.resolver - ) - security_operation = self.make_operation(operation) - self._add_operation_internal(operation.operation_id, security_operation) - - def make_operation(self, operation: t.Union[AbstractOperation, Specification]): - return SecurityOperation.from_operation( - operation, - security_handler_factory=self.security_handler_factory, - ) - - def _add_operation_internal( - self, operation_id: str, operation: "SecurityOperation" - ): - self.operations[operation_id] = operation - - class SecurityOperation: def __init__( self, + next_app: ASGIApp, + *, security_handler_factory: SecurityHandlerFactory, security: list, security_schemes: dict, ): + self.next_app = next_app self.security_handler_factory = security_handler_factory self.security = security self.security_schemes = security_schemes @@ -139,12 +31,14 @@ def __init__( @classmethod def from_operation( cls, - operation: t.Union[AbstractOperation, Specification], + operation: AbstractOperation, + *, + next_app: ASGIApp, security_handler_factory: SecurityHandlerFactory, - ): - # TODO: Turn Operation class into OperationSpec and use as init argument instead + ) -> "SecurityOperation": return cls( - security_handler_factory, + next_app=next_app, + security_handler_factory=security_handler_factory, security=operation.security, security_schemes=operation.security_schemes, ) @@ -304,8 +198,45 @@ def _get_verification_fn(self): return self.security_handler_factory.verify_security(auth_funcs) - async def __call__(self, request: MiddlewareRequest): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + request = MiddlewareRequest(scope) await self.verification_fn(request) + await self.next_app(scope, receive, send) + + +class SecurityAPI(RoutedAPI[SecurityOperation]): + + operation_cls = SecurityOperation + + def __init__(self, *args, auth_all_paths: bool = False, **kwargs): + super().__init__(*args, **kwargs) + + self.security_handler_factory = SecurityHandlerFactory() + + if auth_all_paths: + self.add_auth_on_not_found() + else: + self.operations: t.MutableMapping[str, SecurityOperation] = {} + + self.add_paths() + + def add_auth_on_not_found(self) -> None: + """Register a default SecurityOperation for routes that are not found.""" + default_operation = self.make_operation(self.specification) + self.operations = defaultdict(lambda: default_operation) + + def make_operation(self, operation: AbstractOperation) -> SecurityOperation: + return SecurityOperation.from_operation( + operation, + next_app=self.next_app, + security_handler_factory=self.security_handler_factory, + ) + + +class SecurityMiddleware(RoutedMiddleware[SecurityAPI]): + """Middleware to check if operation is accessible on scope.""" + + api_cls = SecurityAPI class MissingSecurityOperation(ProblemException): diff --git a/connexion/middleware/swagger_ui.py b/connexion/middleware/swagger_ui.py index 129959ce1..5ef7b765b 100644 --- a/connexion/middleware/swagger_ui.py +++ b/connexion/middleware/swagger_ui.py @@ -21,56 +21,6 @@ _original_scope: ContextVar[Scope] = ContextVar("SCOPE") -class SwaggerUIMiddleware(AppMiddleware): - def __init__(self, app: ASGIApp) -> None: - """Middleware that hosts a swagger UI. - - :param app: app to wrap in middleware. - """ - self.app = app - # Set default to pass unknown routes to next app - self.router = Router(default=self.default_fn) - - def add_api( - self, - specification: t.Union[pathlib.Path, str, dict], - base_path: t.Optional[str] = None, - arguments: t.Optional[dict] = None, - **kwargs - ) -> None: - """Add an API to the router based on a OpenAPI spec. - - :param specification: OpenAPI spec as dict or path to file. - :param base_path: Base path where to add this API. - :param arguments: Jinja arguments to replace in the spec. - """ - api = SwaggerUIAPI( - specification, - base_path=base_path, - arguments=arguments, - default=self.default_fn, - **kwargs - ) - self.router.mount(api.base_path, app=api.router) - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - _original_scope.set(scope.copy()) # type: ignore - await self.router(scope, receive, send) - - async def default_fn(self, _scope: Scope, receive: Receive, send: Send) -> None: - """ - Callback to call next app as default when no matching route is found. - - Unfortunately we cannot just pass the next app as default, since the router manipulates - the scope when descending into mounts, losing information about the base path. Therefore, - we use the original scope instead. - - This is caused by https://github.com/encode/starlette/issues/1336. - """ - original_scope = _original_scope.get() - await self.app(original_scope, receive, send) - - class SwaggerUIAPI(AbstractSpecAPI): def __init__(self, *args, default: ASGIApp, **kwargs): super().__init__(*args, **kwargs) @@ -213,3 +163,53 @@ async def _get_swagger_ui_config(self, request): media_type="application/json", content=self.jsonifier.dumps(self.options.openapi_console_ui_config), ) + + +class SwaggerUIMiddleware(AppMiddleware): + def __init__(self, app: ASGIApp) -> None: + """Middleware that hosts a swagger UI. + + :param app: app to wrap in middleware. + """ + self.app = app + # Set default to pass unknown routes to next app + self.router = Router(default=self.default_fn) + + def add_api( + self, + specification: t.Union[pathlib.Path, str, dict], + base_path: t.Optional[str] = None, + arguments: t.Optional[dict] = None, + **kwargs + ) -> None: + """Add an API to the router based on a OpenAPI spec. + + :param specification: OpenAPI spec as dict or path to file. + :param base_path: Base path where to add this API. + :param arguments: Jinja arguments to replace in the spec. + """ + api = SwaggerUIAPI( + specification, + base_path=base_path, + arguments=arguments, + default=self.default_fn, + **kwargs + ) + self.router.mount(api.base_path, app=api.router) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + _original_scope.set(scope.copy()) # type: ignore + await self.router(scope, receive, send) + + async def default_fn(self, _scope: Scope, receive: Receive, send: Send) -> None: + """ + Callback to call next app as default when no matching route is found. + + Unfortunately we cannot just pass the next app as default, since the router manipulates + the scope when descending into mounts, losing information about the base path. Therefore, + we use the original scope instead. + + This is caused by https://github.com/encode/starlette/issues/1336. + """ + original_scope = _original_scope.get() + await self.app(original_scope, receive, send) diff --git a/connexion/middleware/validation.py b/connexion/middleware/validation.py index 0b7a050eb..6a0ae86b8 100644 --- a/connexion/middleware/validation.py +++ b/connexion/middleware/validation.py @@ -2,19 +2,14 @@ Validation Middleware. """ import logging -import pathlib import typing as t from starlette.types import ASGIApp, Receive, Scope, Send -from connexion.apis.abstract import AbstractSpecAPI from connexion.decorators.uri_parsing import AbstractURIParser -from connexion.exceptions import MissingMiddleware, UnsupportedMediaTypeProblem -from connexion.http_facts import METHODS -from connexion.middleware import AppMiddleware -from connexion.middleware.routing import ROUTING_CONTEXT +from connexion.exceptions import UnsupportedMediaTypeProblem +from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation -from connexion.resolver import ResolverError from connexion.utils import is_nullable from connexion.validators import JSONBodyValidator @@ -30,130 +25,19 @@ } -class ValidationMiddleware(AppMiddleware): - """Middleware for validating requests according to the API contract.""" - - def __init__(self, app: ASGIApp) -> None: - self.app = app - self.apis: t.Dict[str, ValidationAPI] = {} - - def add_api( - self, specification: t.Union[pathlib.Path, str, dict], **kwargs - ) -> None: - api = ValidationAPI(specification, next_app=self.app, **kwargs) - self.apis[api.base_path] = api - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - try: - connexion_context = scope["extensions"][ROUTING_CONTEXT] - except KeyError: - raise MissingMiddleware( - "Could not find routing information in scope. Please make sure " - "you have a routing middleware registered upstream. " - ) - api_base_path = connexion_context.get("api_base_path") - if api_base_path: - api = self.apis[api_base_path] - operation_id = connexion_context.get("operation_id") - try: - operation = api.operations[operation_id] - except KeyError as e: - if operation_id is None: - logger.debug("Skipping validation check for operation without id.") - await self.app(scope, receive, send) - return - else: - raise MissingValidationOperation( - "Encountered unknown operation_id." - ) from e - else: - return await operation(scope, receive, send) - - await self.app(scope, receive, send) - - -class ValidationAPI(AbstractSpecAPI): - """Validation API.""" - - def __init__( - self, - specification: t.Union[pathlib.Path, str, dict], - *args, - next_app: ASGIApp, - validate_responses=False, - strict_validation=False, - validator_map=None, - uri_parser_class=None, - **kwargs, - ): - super().__init__(specification, *args, **kwargs) - self.next_app = next_app - - self.validator_map = validator_map - - logger.debug("Validate Responses: %s", str(validate_responses)) - self.validate_responses = validate_responses - - logger.debug("Strict Request Validation: %s", str(strict_validation)) - self.strict_validation = strict_validation - - self.uri_parser_class = uri_parser_class - - self.operations: t.Dict[str, ValidationOperation] = {} - self.add_paths() - - def add_paths(self): - paths = self.specification.get("paths", {}) - for path, methods in paths.items(): - for method in methods: - if method not in METHODS: - continue - try: - self.add_operation(path, method) - except ResolverError: - # ResolverErrors are either raised or handled in routing middleware. - pass - - def add_operation(self, path: str, method: str) -> None: - operation_cls = self.specification.operation_cls - operation = operation_cls.from_spec( - self.specification, self, path, method, self.resolver - ) - validation_operation = self.make_operation(operation) - self._add_operation_internal(operation.operation_id, validation_operation) - - def make_operation(self, operation: AbstractOperation): - return ValidationOperation( - operation, - self.next_app, - validate_responses=self.validate_responses, - strict_validation=self.strict_validation, - validator_map=self.validator_map, - uri_parser_class=self.uri_parser_class, - ) - - def _add_operation_internal( - self, operation_id: str, operation: "ValidationOperation" - ): - self.operations[operation_id] = operation - - class ValidationOperation: def __init__( self, - operation: AbstractOperation, next_app: ASGIApp, + *, + operation: AbstractOperation, validate_responses: bool = False, strict_validation: bool = False, validator_map: t.Optional[dict] = None, uri_parser_class: t.Optional[AbstractURIParser] = None, ) -> None: - self._operation = operation self.next_app = next_app + self._operation = operation self.validate_responses = validate_responses self.strict_validation = strict_validation self._validator_map = VALIDATOR_MAP @@ -228,5 +112,49 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): await self.next_app(scope, receive, send) +class ValidationAPI(RoutedAPI[ValidationOperation]): + """Validation API.""" + + operation_cls = ValidationOperation + + def __init__( + self, + *args, + validate_responses=False, + strict_validation=False, + validator_map=None, + uri_parser_class=None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.validator_map = validator_map + + logger.debug("Validate Responses: %s", str(validate_responses)) + self.validate_responses = validate_responses + + logger.debug("Strict Request Validation: %s", str(strict_validation)) + self.strict_validation = strict_validation + + self.uri_parser_class = uri_parser_class + + self.add_paths() + + def make_operation(self, operation: AbstractOperation) -> ValidationOperation: + return ValidationOperation( + self.next_app, + operation=operation, + validate_responses=self.validate_responses, + strict_validation=self.strict_validation, + validator_map=self.validator_map, + uri_parser_class=self.uri_parser_class, + ) + + +class ValidationMiddleware(RoutedMiddleware[ValidationAPI]): + """Middleware for validating requests according to the API contract.""" + + api_cls = ValidationAPI + + class MissingValidationOperation(Exception): """Missing validation operation""" diff --git a/setup.py b/setup.py index 3636247cb..6d4407187 100755 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ def read_version(package): 'werkzeug>=2.2.1,<3', 'starlette>=0.15,<1', 'httpx>=0.15,<1', + 'typing-extensions>=4,<5', ] swagger_ui_require = 'swagger-ui-bundle>=0.0.2,<0.1' diff --git a/tests/test_operation2.py b/tests/test_operation2.py index bac5be61f..cd37d14ff 100644 --- a/tests/test_operation2.py +++ b/tests/test_operation2.py @@ -425,6 +425,7 @@ def test_operation_remote_token_info(security_handler_factory): ) SecurityOperation( + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_REMOTE, @@ -494,6 +495,7 @@ def test_operation_local_security_oauth2(security_handler_factory): security_handler_factory.verify_oauth = verify_oauth SecurityOperation( + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_LOCAL, @@ -509,7 +511,8 @@ def test_operation_local_security_duplicate_token_info(security_handler_factory) security_handler_factory.verify_oauth = verify_oauth SecurityOperation( - security_handler_factory, + next_app=mock.Mock, + security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_BOTH, ) @@ -545,6 +548,7 @@ def test_multi_body(api): def test_no_token_info(security_handler_factory): SecurityOperation( + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_WO_INFO, @@ -565,6 +569,7 @@ def return_api_key_name(func, in_, name): security = [{"key1": [], "key2": []}] SecurityOperation( + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=security, security_schemes=SECURITY_DEFINITIONS_2_KEYS, @@ -589,6 +594,7 @@ def test_multiple_oauth_in_and(security_handler_factory, caplog): security = [{"oauth_1": ["uid"], "oauth_2": ["uid"]}] SecurityOperation( + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=security, security_schemes=SECURITY_DEFINITIONS_2_OAUTH, @@ -685,6 +691,7 @@ def test_oauth_scopes_in_or(security_handler_factory): security = [{"oauth": ["myscope"]}, {"oauth": ["myscope2"]}] SecurityOperation( + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=security, security_schemes=SECURITY_DEFINITIONS_LOCAL,