From 1ab5400c0b0abe0483181dbbe972d00fa1d6e077 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Mon, 3 Oct 2022 23:01:21 +0200 Subject: [PATCH] Move JSON response body validation to middleware (#1591) * Extract boilerplate code into Routed base classes * Use typing_extensions for Python 3.7 Protocol support * Use Mock instead of AsyncMock * Extract response validation to middleware * Refactor Request validation to match Response validation * Factor out shared functionality * Fix typo in TextResponseBodyValidator class name * Fix string formatting * Use correct schema to check nullability in response validation --- connexion/decorators/response.py | 135 --------------- connexion/decorators/validation.py | 25 +-- connexion/middleware/main.py | 6 +- .../{validation.py => request_validation.py} | 80 ++++----- connexion/middleware/response_validation.py | 158 ++++++++++++++++++ connexion/operations/abstract.py | 17 -- connexion/utils.py | 30 ++++ connexion/validators.py | 129 +++++++++++--- tests/api/test_headers.py | 2 +- tests/api/test_schema.py | 28 ++-- tests/fakeapi/example_method_view.py | 17 +- tests/test_json_validation.py | 5 +- tests/test_resolver_methodview.py | 8 +- 13 files changed, 358 insertions(+), 282 deletions(-) delete mode 100644 connexion/decorators/response.py rename connexion/middleware/{validation.py => request_validation.py} (60%) create mode 100644 connexion/middleware/response_validation.py diff --git a/connexion/decorators/response.py b/connexion/decorators/response.py deleted file mode 100644 index 98511259a..000000000 --- a/connexion/decorators/response.py +++ /dev/null @@ -1,135 +0,0 @@ -""" -This module defines a view function decorator to validate its responses. -""" - -import asyncio -import functools -import logging - -from jsonschema import ValidationError - -from ..exceptions import NonConformingResponseBody, NonConformingResponseHeaders -from ..utils import all_json, has_coroutine -from .decorator import BaseDecorator -from .validation import ResponseBodyValidator - -logger = logging.getLogger("connexion.decorators.response") - - -class ResponseValidator(BaseDecorator): - def __init__(self, operation, mimetype, validator=None): - """ - :type operation: Operation - :type mimetype: str - :param validator: Validator class that should be used to validate passed data - against API schema. - :type validator: jsonschema.IValidator - """ - self.operation = operation - self.mimetype = mimetype - self.validator = validator - - def validate_response(self, data, status_code, headers, url): - """ - Validates the Response object based on what has been declared in the specification. - Ensures the response body matches the declared schema. - :type data: dict - :type status_code: int - :type headers: dict - :rtype bool | None - """ - # check against returned header, fall back to expected mimetype - content_type = headers.get("Content-Type", self.mimetype) - content_type = content_type.rsplit(";", 1)[ - 0 - ] # remove things like utf8 metadata - - response_definition = self.operation.response_definition( - str(status_code), content_type - ) - response_schema = self.operation.response_schema(str(status_code), content_type) - - if self.is_json_schema_compatible(response_schema): - v = ResponseBodyValidator(response_schema, validator=self.validator) - try: - data = self.operation.json_loads(data) - v.validate_schema(data, url) - except ValidationError as e: - raise NonConformingResponseBody(message=str(e)) - - if response_definition and response_definition.get("headers"): - required_header_keys = { - k - for (k, v) in response_definition.get("headers").items() - if v.get("required", False) - } - header_keys = set(headers.keys()) - missing_keys = required_header_keys - header_keys - if missing_keys: - pretty_list = ", ".join(missing_keys) - msg = ( - "Keys in header don't match response specification. " - "Difference: {}" - ).format(pretty_list) - raise NonConformingResponseHeaders(message=msg) - return True - - def is_json_schema_compatible(self, response_schema: dict) -> bool: - """ - Verify if the specified operation responses are JSON schema - compatible. - - All operations that specify a JSON schema and have content - type "application/json" or "text/plain" can be validated using - json_schema package. - """ - if not response_schema: - return False - return all_json([self.mimetype]) or self.mimetype == "text/plain" - - def __call__(self, function): - """ - :type function: types.FunctionType - :rtype: types.FunctionType - """ - - def _wrapper(request, response): - connexion_response = self.operation.api.get_connexion_response( - response, self.mimetype - ) - if not connexion_response.is_streamed: - self.validate_response( - connexion_response.body, - connexion_response.status_code, - connexion_response.headers, - request.url, - ) - else: - logger.warning("Skipping response validation for streamed response.") - - return response - - if has_coroutine(function): - - @functools.wraps(function) - async def wrapper(request): - response = function(request) - while asyncio.iscoroutine(response): - response = await response - - return _wrapper(request, response) - - else: # pragma: no cover - - @functools.wraps(function) - def wrapper(request): - response = function(request) - return _wrapper(request, response) - - return wrapper - - def __repr__(self): - """ - :rtype: str - """ - return "" # pragma: no cover diff --git a/connexion/decorators/validation.py b/connexion/decorators/validation.py index bfa004ed2..83a258d67 100644 --- a/connexion/decorators/validation.py +++ b/connexion/decorators/validation.py @@ -14,7 +14,7 @@ from ..exceptions import BadRequestProblem, ExtraParameterProblem from ..http_facts import FORM_CONTENT_TYPES -from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator +from ..json_schema import Draft4RequestValidator from ..lifecycle import ConnexionResponse from ..utils import boolean, is_null, is_nullable @@ -196,29 +196,6 @@ def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse] return None -class ResponseBodyValidator: - def __init__(self, schema, validator=None): - """ - :param schema: The schema of the response body - :param validator: Validator class that should be used to validate passed data - against API schema. Default is Draft4ResponseValidator. - :type validator: jsonschema.IValidator - """ - ValidatorClass = validator or Draft4ResponseValidator - self.validator = ValidatorClass(schema, format_checker=draft4_format_checker) - - def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse]: - try: - self.validator.validate(data) - except ValidationError as exception: - logger.error( - f"{url} validation error: {exception}", extra={"validator": "response"} - ) - raise exception - - return None - - class ParameterValidator: def __init__(self, parameters, api, strict_validation=False): """ diff --git a/connexion/middleware/main.py b/connexion/middleware/main.py index 950cdb532..2e2acea1b 100644 --- a/connexion/middleware/main.py +++ b/connexion/middleware/main.py @@ -5,10 +5,11 @@ from connexion.middleware.abstract import AppMiddleware from connexion.middleware.exceptions import ExceptionMiddleware +from connexion.middleware.request_validation import RequestValidationMiddleware +from connexion.middleware.response_validation import ResponseValidationMiddleware from connexion.middleware.routing import RoutingMiddleware from connexion.middleware.security import SecurityMiddleware from connexion.middleware.swagger_ui import SwaggerUIMiddleware -from connexion.middleware.validation import ValidationMiddleware class ConnexionMiddleware: @@ -18,7 +19,8 @@ class ConnexionMiddleware: SwaggerUIMiddleware, RoutingMiddleware, SecurityMiddleware, - ValidationMiddleware, + RequestValidationMiddleware, + ResponseValidationMiddleware, ] def __init__( diff --git a/connexion/middleware/validation.py b/connexion/middleware/request_validation.py similarity index 60% rename from connexion/middleware/validation.py rename to connexion/middleware/request_validation.py index 6a0ae86b8..ff3918d5a 100644 --- a/connexion/middleware/validation.py +++ b/connexion/middleware/request_validation.py @@ -6,71 +6,51 @@ from starlette.types import ASGIApp, Receive, Scope, Send +from connexion import utils from connexion.decorators.uri_parsing import AbstractURIParser from connexion.exceptions import UnsupportedMediaTypeProblem from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation -from connexion.utils import is_nullable -from connexion.validators import JSONBodyValidator - -from ..decorators.response import ResponseValidator -from ..decorators.validation import ParameterValidator +from connexion.validators import VALIDATOR_MAP logger = logging.getLogger("connexion.middleware.validation") -VALIDATOR_MAP = { - "parameter": ParameterValidator, - "body": {"application/json": JSONBodyValidator}, - "response": ResponseValidator, -} - -class ValidationOperation: +class RequestValidationOperation: def __init__( self, next_app: ASGIApp, *, operation: AbstractOperation, - validate_responses: bool = False, strict_validation: bool = False, validator_map: t.Optional[dict] = None, uri_parser_class: t.Optional[AbstractURIParser] = None, ) -> None: self.next_app = next_app self._operation = operation - self.validate_responses = validate_responses self.strict_validation = strict_validation self._validator_map = VALIDATOR_MAP self._validator_map.update(validator_map or {}) self.uri_parser_class = uri_parser_class - def extract_content_type(self, headers: dict) -> t.Tuple[str, str]: + def extract_content_type( + self, headers: t.List[t.Tuple[bytes, bytes]] + ) -> t.Tuple[str, str]: """Extract the mime type and encoding from the content type headers. - :param headers: Header dict from ASGI scope + :param headers: Headers from ASGI scope :return: A tuple of mime type, encoding """ - encoding = "utf-8" - for key, value in headers: - # Headers can always be decoded using latin-1: - # https://stackoverflow.com/a/27357138/4098821 - key = key.decode("latin-1") - if key.lower() == "content-type": - content_type = value.decode("latin-1") - if ";" in content_type: - mime_type, parameters = content_type.split(";", maxsplit=1) - - prefix = "charset=" - for parameter in parameters.split(";"): - if parameter.startswith(prefix): - encoding = parameter[len(prefix) :] - else: - mime_type = content_type - break - else: + mime_type, encoding = utils.extract_content_type(headers) + if mime_type is None: # Content-type header is not required. Take a best guess. - mime_type = self._operation.consumes[0] + try: + mime_type = self._operation.consumes[0] + except IndexError: + mime_type = "application/octet-stream" + if encoding is None: + encoding = "utf-8" return mime_type, encoding @@ -86,6 +66,8 @@ def validate_mime_type(self, mime_type: str) -> None: ) async def __call__(self, scope: Scope, receive: Receive, send: Send): + receive_fn = receive + headers = scope["headers"] mime_type, encoding = self.extract_content_type(headers) self.validate_mime_type(mime_type) @@ -102,25 +84,25 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): ) else: validator = body_validator( - self.next_app, + scope, + receive, schema=self._operation.body_schema, - nullable=is_nullable(self._operation.body_definition), + nullable=utils.is_nullable(self._operation.body_definition), encoding=encoding, ) - return await validator(scope, receive, send) + receive_fn = validator.receive - await self.next_app(scope, receive, send) + await self.next_app(scope, receive_fn, send) -class ValidationAPI(RoutedAPI[ValidationOperation]): +class RequestValidationAPI(RoutedAPI[RequestValidationOperation]): """Validation API.""" - operation_cls = ValidationOperation + operation_cls = RequestValidationOperation def __init__( self, *args, - validate_responses=False, strict_validation=False, validator_map=None, uri_parser_class=None, @@ -129,9 +111,6 @@ def __init__( super().__init__(*args, **kwargs) self.validator_map = validator_map - logger.debug("Validate Responses: %s", str(validate_responses)) - self.validate_responses = validate_responses - logger.debug("Strict Request Validation: %s", str(strict_validation)) self.strict_validation = strict_validation @@ -139,21 +118,22 @@ def __init__( self.add_paths() - def make_operation(self, operation: AbstractOperation) -> ValidationOperation: - return ValidationOperation( + def make_operation( + self, operation: AbstractOperation + ) -> RequestValidationOperation: + return RequestValidationOperation( self.next_app, operation=operation, - validate_responses=self.validate_responses, strict_validation=self.strict_validation, validator_map=self.validator_map, uri_parser_class=self.uri_parser_class, ) -class ValidationMiddleware(RoutedMiddleware[ValidationAPI]): +class RequestValidationMiddleware(RoutedMiddleware[RequestValidationAPI]): """Middleware for validating requests according to the API contract.""" - api_cls = ValidationAPI + api_cls = RequestValidationAPI class MissingValidationOperation(Exception): diff --git a/connexion/middleware/response_validation.py b/connexion/middleware/response_validation.py new file mode 100644 index 000000000..bcd66b8c9 --- /dev/null +++ b/connexion/middleware/response_validation.py @@ -0,0 +1,158 @@ +""" +Validation Middleware. +""" +import logging +import typing as t + +from starlette.types import ASGIApp, Receive, Scope, Send + +from connexion import utils +from connexion.exceptions import NonConformingResponseHeaders +from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware +from connexion.operations import AbstractOperation +from connexion.validators import VALIDATOR_MAP + +logger = logging.getLogger("connexion.middleware.validation") + + +class ResponseValidationOperation: + def __init__( + self, + next_app: ASGIApp, + *, + operation: AbstractOperation, + validator_map: t.Optional[dict] = None, + ) -> None: + self.next_app = next_app + self._operation = operation + self._validator_map = VALIDATOR_MAP + self._validator_map.update(validator_map or {}) + + def extract_content_type( + self, headers: t.List[t.Tuple[bytes, bytes]] + ) -> t.Tuple[str, str]: + """Extract the mime type and encoding from the content type headers. + + :param headers: Headers from ASGI scope + + :return: A tuple of mime type, encoding + """ + mime_type, encoding = utils.extract_content_type(headers) + if mime_type is None: + # Content-type header is not required. Take a best guess. + try: + mime_type = self._operation.produces[0] + except IndexError: + mime_type = "application/octet-stream" + if encoding is None: + encoding = "utf-8" + + return mime_type, encoding + + def validate_mime_type(self, mime_type: str) -> None: + """Validate the mime type against the spec. + + :param mime_type: mime type from content type header + """ + if mime_type.lower() not in [c.lower() for c in self._operation.produces]: + raise NonConformingResponseHeaders( + reason="Invalid Response Content-type", + message=f"Invalid Response Content-type ({mime_type}), " + f"expected {self._operation.produces}", + ) + + @staticmethod + def validate_required_headers( + headers: t.List[tuple], response_definition: dict + ) -> None: + required_header_keys = { + k.lower() + for (k, v) in response_definition.get("headers", {}).items() + if v.get("required", False) + } + header_keys = set(header[0].decode("latin-1").lower() for header in headers) + missing_keys = required_header_keys - header_keys + if missing_keys: + pretty_list = ", ".join(missing_keys) + msg = ( + "Keys in header don't match response specification. Difference: {}" + ).format(pretty_list) + raise NonConformingResponseHeaders(message=msg) + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + + send_fn = send + + async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None: + nonlocal send_fn + + if message["type"] == "http.response.start": + status = str(message["status"]) + headers = message["headers"] + mime_type, encoding = self.extract_content_type(headers) + # TODO: Add produces to all tests and fix response content types + # self.validate_mime_type(mime_type) + response_definition = self._operation.response_definition( + status, mime_type + ) + self.validate_required_headers(headers, response_definition) + + # Validate body + try: + body_validator = self._validator_map["response"][mime_type] # type: ignore + except KeyError: + logging.info( + f"Skipping validation. No validator registered for content type: " + f"{mime_type}." + ) + else: + validator = body_validator( + scope, + send, + schema=self._operation.response_schema(status, mime_type), + nullable=utils.is_nullable( + self._operation.response_definition(status, mime_type) + ), + encoding=encoding, + ) + send_fn = validator.send + + return await send_fn(message) + + await self.next_app(scope, receive, wrapped_send) + + +class ResponseValidationAPI(RoutedAPI[ResponseValidationOperation]): + """Validation API.""" + + operation_cls = ResponseValidationOperation + + def __init__( + self, + *args, + validator_map=None, + validate_responses=False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.validator_map = validator_map + self.validate_responses = validate_responses + self.add_paths() + + def make_operation( + self, operation: AbstractOperation + ) -> ResponseValidationOperation: + if self.validate_responses: + return ResponseValidationOperation( + self.next_app, + operation=operation, + validator_map=self.validator_map, + ) + else: + return self.next_app # type: ignore + + +class ResponseValidationMiddleware(RoutedMiddleware[ResponseValidationAPI]): + """Middleware for validating requests according to the API contract.""" + + api_cls = ResponseValidationAPI diff --git a/connexion/operations/abstract.py b/connexion/operations/abstract.py index bb26cdc74..16e027b9e 100644 --- a/connexion/operations/abstract.py +++ b/connexion/operations/abstract.py @@ -9,7 +9,6 @@ from ..decorators.decorator import RequestResponseDecorator from ..decorators.parameter import parameter_to_arg from ..decorators.produces import BaseSerializer, Produces -from ..decorators.response import ResponseValidator from ..decorators.validation import ParameterValidator, RequestBodyValidator from ..utils import all_json, is_nullable @@ -20,7 +19,6 @@ VALIDATOR_MAP = { "parameter": ParameterValidator, "body": RequestBodyValidator, - "response": ResponseValidator, } @@ -389,12 +387,6 @@ def function(self): self.pythonic_params, ) - if self.validate_responses: - logger.debug("... Response validation enabled.") - response_decorator = self.__response_validation_decorator - logger.debug("... Adding response decorator (%r)", response_decorator) - function = response_decorator(function) - produces_decorator = self.__content_type_decorator logger.debug("... Adding produces decorator (%r)", produces_decorator) function = produces_decorator(function) @@ -473,15 +465,6 @@ def __validation_decorators(self): strict_validation=self.strict_validation, ) - @property - def __response_validation_decorator(self): - """ - Get a decorator for validating the generated Response. - :rtype: types.FunctionType - """ - ResponseValidator = self.validator_map["response"] - return ResponseValidator(self, self.get_mimetype()) - def json_loads(self, data): """ A wrapper for calling the API specific JSON loader. diff --git a/connexion/utils.py b/connexion/utils.py index e92d403ce..cc6cdd632 100644 --- a/connexion/utils.py +++ b/connexion/utils.py @@ -5,6 +5,7 @@ import asyncio import functools import importlib +import typing as t import yaml @@ -266,3 +267,32 @@ def _required_lib(exc, *args, **kwargs): raise exc return functools.partial(_required_lib, exc) + + +def extract_content_type( + headers: t.List[t.Tuple[bytes, bytes]] +) -> t.Tuple[t.Optional[str], t.Optional[str]]: + """Extract the mime type and encoding from the content type headers. + + :param headers: Headers from ASGI scope + + :return: A tuple of mime type, encoding + """ + mime_type, encoding = None, None + for key, value in headers: + # Headers can always be decoded using latin-1: + # https://stackoverflow.com/a/27357138/4098821 + decoded_key = key.decode("latin-1") + if decoded_key.lower() == "content-type": + content_type = value.decode("latin-1") + if ";" in content_type: + mime_type, parameters = content_type.split(";", maxsplit=1) + + prefix = "charset=" + for parameter in parameters.split(";"): + if parameter.startswith(prefix): + encoding = parameter[len(prefix) :] + else: + mime_type = content_type + break + return mime_type, encoding diff --git a/connexion/validators.py b/connexion/validators.py index 4d726a20d..dff5c41bc 100644 --- a/connexion/validators.py +++ b/connexion/validators.py @@ -6,36 +6,38 @@ import typing as t from jsonschema import Draft4Validator, ValidationError, draft4_format_checker -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import Receive, Scope, Send -from connexion.exceptions import BadRequestProblem -from connexion.json_schema import Draft4RequestValidator +from connexion.decorators.validation import ParameterValidator +from connexion.exceptions import BadRequestProblem, NonConformingResponseBody +from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator from connexion.utils import is_null logger = logging.getLogger("connexion.middleware.validators") -class JSONBodyValidator: +class JSONRequestBodyValidator: """Request body validator for json content types.""" def __init__( self, - next_app: ASGIApp, + scope: Scope, + receive: Receive, *, schema: dict, validator: t.Type[Draft4Validator] = None, nullable=False, encoding: str, ) -> None: - self.next_app = next_app + self._scope = scope + self._receive = receive self.schema = schema self.has_default = schema.get("default", False) self.nullable = nullable - self.validator_cls = validator or Draft4RequestValidator - self.validator = self.validator_cls( - schema, format_checker=draft4_format_checker - ) + validator_cls = validator or Draft4RequestValidator + self.validator = validator_cls(schema, format_checker=draft4_format_checker) self.encoding = encoding + self._messages: t.List[t.MutableMapping[str, t.Any]] = [] @classmethod def _error_path_message(cls, exception): @@ -44,7 +46,6 @@ def _error_path_message(cls, exception): return error_path_msg def validate(self, body: dict): - try: self.validator.validate(body) except ValidationError as exception: @@ -55,18 +56,15 @@ def validate(self, body: dict): ) raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - # Based on https://github.com/encode/starlette/pull/1519#issuecomment-1060633787 - # Ingest all body messages from the ASGI `receive` callable. - messages = [] + async def receive(self) -> t.Optional[t.MutableMapping[str, t.Any]]: more_body = True while more_body: - message = await receive() - messages.append(message) + message = await self._receive() + self._messages.append(message) more_body = message.get("more_body", False) # TODO: make json library pluggable - bytes_body = b"".join([message.get("body", b"") for message in messages]) + bytes_body = b"".join([message.get("body", b"") for message in self._messages]) decoded_body = bytes_body.decode(self.encoding) if decoded_body and not (self.nullable and is_null(decoded_body)): @@ -77,11 +75,92 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: self.validate(body) - async def wrapped_receive(): - # First up we want to return any messages we've stashed. - if messages: - return messages.pop(0) - # Once that's done we can just await any other messages. - return await receive() + while self._messages: + return self._messages.pop(0) + return None + + +class JSONResponseBodyValidator: + """Response body validator for json content types.""" + + def __init__( + self, + scope: Scope, + send: Send, + *, + schema: dict, + validator: t.Type[Draft4Validator] = None, + nullable=False, + encoding: str, + ) -> None: + self._scope = scope + self._send = send + self.schema = schema + self.has_default = schema.get("default", False) + self.nullable = nullable + validator_cls = validator or Draft4ResponseValidator + self.validator = validator_cls(schema, format_checker=draft4_format_checker) + self.encoding = encoding + self._messages: t.List[t.MutableMapping[str, t.Any]] = [] + + @classmethod + def _error_path_message(cls, exception): + error_path = ".".join(str(item) for item in exception.path) + error_path_msg = f" - '{error_path}'" if error_path else "" + return error_path_msg + + def validate(self, body: dict): + try: + self.validator.validate(body) + except ValidationError as exception: + error_path_msg = self._error_path_message(exception=exception) + logger.error( + f"Validation error: {exception.message}{error_path_msg}", + extra={"validator": "body"}, + ) + raise NonConformingResponseBody( + message=f"{exception.message}{error_path_msg}" + ) + + @staticmethod + def parse(body: str) -> dict: + try: + return json.loads(body) + except json.decoder.JSONDecodeError as e: + raise BadRequestProblem(str(e)) + + async def send(self, message: t.MutableMapping[str, t.Any]) -> None: + self._messages.append(message) + + if message["type"] == "http.response.start" or message.get("more_body", False): + return + + # TODO: make json library pluggable + bytes_body = b"".join([message.get("body", b"") for message in self._messages]) + decoded_body = bytes_body.decode(self.encoding) + + if decoded_body and not (self.nullable and is_null(decoded_body)): + body = self.parse(decoded_body) + self.validate(body) + + while self._messages: + await self._send(self._messages.pop(0)) + + +class TextResponseBodyValidator(JSONResponseBodyValidator): + @staticmethod + def parse(body: str) -> str: # type: ignore + try: + return json.loads(body) + except json.decoder.JSONDecodeError: + return body + - await self.next_app(scope, wrapped_receive, send) +VALIDATOR_MAP = { + "parameter": ParameterValidator, + "body": {"application/json": JSONRequestBodyValidator}, + "response": { + "application/json": JSONResponseBodyValidator, + "text/plain": TextResponseBodyValidator, + }, +} diff --git a/tests/api/test_headers.py b/tests/api/test_headers.py index 709b8821f..6b6ec4236 100644 --- a/tests/api/test_headers.py +++ b/tests/api/test_headers.py @@ -34,7 +34,7 @@ def test_header_not_returned(simple_openapi_app): assert data["title"] == "Response headers do not conform to specification" assert ( data["detail"] - == "Keys in header don't match response specification. Difference: Location" + == "Keys in header don't match response specification. Difference: location" ) assert data["status"] == 500 diff --git a/tests/api/test_schema.py b/tests/api/test_schema.py index ec20cfd9a..fa48c35d7 100644 --- a/tests/api/test_schema.py +++ b/tests/api/test_schema.py @@ -80,59 +80,59 @@ def test_schema_response(schema_app): request = app_client.get( "/v1.0/test_schema/response/object/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/object/invalid_type", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/object/invalid_requirements", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/string/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/string/invalid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/integer/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/integer/invalid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/number/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/number/invalid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/boolean/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/boolean/invalid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/array/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/array/invalid_dict", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/array/invalid_string", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text def test_schema_in_query(schema_app): diff --git a/tests/fakeapi/example_method_view.py b/tests/fakeapi/example_method_view.py index f66d2b549..d6ad0e93b 100644 --- a/tests/fakeapi/example_method_view.py +++ b/tests/fakeapi/example_method_view.py @@ -6,19 +6,22 @@ class PetsView(MethodView): mycontent = "demonstrate return from MethodView class" def get(self, **kwargs): - kwargs.update({"method": "get"}) - return kwargs + if kwargs: + kwargs.update({"name": "get"}) + return kwargs + else: + return [{"name": "get"}] def search(self): - return "search" + return [{"name": "search"}] def post(self, **kwargs): - kwargs.update({"method": "post"}) - return kwargs + kwargs.update({"name": "post"}) + return kwargs, 201 def put(self, *args, **kwargs): - kwargs.update({"method": "put"}) - return kwargs + kwargs.update({"name": "put"}) + return kwargs, 201 # Test that operation_id can still override resolver diff --git a/tests/test_json_validation.py b/tests/test_json_validation.py index 3857fc3d8..36916dd0b 100644 --- a/tests/test_json_validation.py +++ b/tests/test_json_validation.py @@ -3,10 +3,9 @@ import pytest from connexion import App -from connexion.decorators.validation import RequestBodyValidator from connexion.json_schema import Draft4RequestValidator from connexion.spec import Specification -from connexion.validators import JSONBodyValidator +from connexion.validators import JSONRequestBodyValidator from jsonschema.validators import _utils, extend from conftest import build_app_from_fixture @@ -31,7 +30,7 @@ def validate_type(validator, types, instance, schema): MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type}) - class MyJSONBodyValidator(JSONBodyValidator): + class MyJSONBodyValidator(JSONRequestBodyValidator): def __init__(self, *args, **kwargs): super().__init__(*args, validator=MinLengthRequestValidator, **kwargs) diff --git a/tests/test_resolver_methodview.py b/tests/test_resolver_methodview.py index 06436860a..5d315e1ac 100644 --- a/tests/test_resolver_methodview.py +++ b/tests/test_resolver_methodview.py @@ -192,13 +192,13 @@ def test_method_view_resolver_integration(method_view_app): client = method_view_app.app.test_client() r = client.get("/v1.0/pets") - assert r.json == {"method": "get"} + assert r.json == [{"name": "get"}] r = client.get("/v1.0/pets/1") - assert r.json == {"method": "get", "petId": 1} + assert r.json == {"name": "get", "petId": 1} r = client.post("/v1.0/pets", json={"name": "Musti"}) - assert r.json == {"method": "post", "body": {"name": "Musti"}} + assert r.json == {"name": "post", "body": {"name": "Musti"}} r = client.put("/v1.0/pets/1", json={"name": "Igor"}) - assert r.json == {"method": "put", "petId": 1, "body": {"name": "Igor"}} + assert r.json == {"name": "put", "petId": 1, "body": {"name": "Igor"}}