diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index fd49ed9aef9..80898c14f78 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1500,7 +1500,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, - response_validation_error_http_status: HTTPStatus | None = None, + response_validation_error_http_status=None, ): """ Parameters @@ -1520,6 +1520,8 @@ def __init__( Each prefix can be a static string or a compiled regex pattern enable_validation: bool | None Enables validation of the request body against the route schema, by default False. + response_validation_error_http_status + Enables response validation and sets returned status code if response is not validated. """ self._proxy_type = proxy_type self._dynamic_routes: list[Route] = [] @@ -1535,7 +1537,28 @@ def __init__( self.context: dict = {} # early init as customers might add context before event resolution self.processed_stack_frames = [] self._response_builder_class = ResponseBuilder[BaseProxyEvent] - self._response_validation_error_http_status = response_validation_error_http_status + self._has_response_validation_error = response_validation_error_http_status is not None + + if response_validation_error_http_status and not enable_validation: + msg = "'response_validation_error_http_status' cannot be set when enable_validation is False." + raise ValueError(msg) + + if ( + not isinstance(response_validation_error_http_status, HTTPStatus) + and response_validation_error_http_status is not None + ): + + try: + response_validation_error_http_status = HTTPStatus(response_validation_error_http_status) + except ValueError: + msg = f"'{response_validation_error_http_status}' must be an integer representing an HTTP status code." + raise ValueError(msg) from None + + self._response_validation_error_http_status = ( + response_validation_error_http_status + if response_validation_error_http_status + else HTTPStatus.UNPROCESSABLE_ENTITY + ) # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) @@ -1549,7 +1572,7 @@ def __init__( [ OpenAPIValidationMiddleware( validation_serializer=serializer, - has_response_validation_error=self._response_validation_error_http_status is not None, + has_response_validation_error=self._has_response_validation_error, ), ], ) @@ -2386,12 +2409,15 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild # OpenAPIValidationMiddleware will only raise ResponseValidationError when # 'self._response_validation_error_http_status' is not None if isinstance(exp, ResponseValidationError): - if self._response_validation_error_http_status is None: - raise TypeError + http_status = ( + self._response_validation_error_http_status + if self._response_validation_error_http_status + else HTTPStatus.UNPROCESSABLE_ENTITY + ) errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()] return self._response_builder_class( response=Response( - status_code=self._response_validation_error_http_status, + status_code=http_status.value, content_type=content_types.APPLICATION_JSON, body={"statusCode": self._response_validation_error_http_status, "detail": errors}, ), @@ -2611,7 +2637,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, - response_validation_error_http_status: HTTPStatus | None = None, + response_validation_error_http_status: HTTPStatus | int | None = None, ): """Amazon API Gateway REST and HTTP API v1 payload resolver""" super().__init__( @@ -2695,7 +2721,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, - response_validation_error_http_status: HTTPStatus | None = None, + response_validation_error_http_status: HTTPStatus | int | None = None, ): """Amazon API Gateway HTTP API v2 payload resolver""" super().__init__( @@ -2734,7 +2760,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, - response_validation_error_http_status: HTTPStatus | None = None, + response_validation_error_http_status: HTTPStatus | int | None = None, ): """Amazon Application Load Balancer (ALB) resolver""" super().__init__( diff --git a/aws_lambda_powertools/event_handler/lambda_function_url.py b/aws_lambda_powertools/event_handler/lambda_function_url.py index f45d7253cd7..2120bdeb28a 100644 --- a/aws_lambda_powertools/event_handler/lambda_function_url.py +++ b/aws_lambda_powertools/event_handler/lambda_function_url.py @@ -59,7 +59,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, - response_validation_error_http_status: HTTPStatus | None = None, + response_validation_error_http_status: HTTPStatus | int | None = None, ): super().__init__( ProxyEventType.LambdaFunctionUrlEvent, diff --git a/aws_lambda_powertools/event_handler/vpc_lattice.py b/aws_lambda_powertools/event_handler/vpc_lattice.py index 9e36b540ffd..fcb58545055 100644 --- a/aws_lambda_powertools/event_handler/vpc_lattice.py +++ b/aws_lambda_powertools/event_handler/vpc_lattice.py @@ -55,7 +55,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, - response_validation_error_http_status: HTTPStatus | None = None, + response_validation_error_http_status: HTTPStatus | int | None = None, ): """Amazon VPC Lattice resolver""" super().__init__( @@ -113,7 +113,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, - response_validation_error_http_status: HTTPStatus | None = None, + response_validation_error_http_status: HTTPStatus | int | None = None, ): """Amazon VPC Lattice resolver""" super().__init__(