Skip to content

Commit

Permalink
refactor: make response_validation_error_http_status accept more type…
Browse files Browse the repository at this point in the history
…s and add more detailed error messages.
  • Loading branch information
Amin Farjadi committed Feb 28, 2025
1 parent 1c33611 commit f8ead84
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 12 deletions.
44 changes: 35 additions & 9 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,7 +1500,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | None = None,
response_validation_error_http_status=None,
):
"""
Parameters
Expand All @@ -1520,6 +1520,8 @@ def __init__(
Each prefix can be a static string or a compiled regex pattern
enable_validation: bool | None
Enables validation of the request body against the route schema, by default False.
response_validation_error_http_status
Enables response validation and sets returned status code if response is not validated.
"""
self._proxy_type = proxy_type
self._dynamic_routes: list[Route] = []
Expand All @@ -1535,7 +1537,28 @@ def __init__(
self.context: dict = {} # early init as customers might add context before event resolution
self.processed_stack_frames = []
self._response_builder_class = ResponseBuilder[BaseProxyEvent]
self._response_validation_error_http_status = response_validation_error_http_status
self._has_response_validation_error = response_validation_error_http_status is not None

if response_validation_error_http_status and not enable_validation:
msg = "'response_validation_error_http_status' cannot be set when enable_validation is False."
raise ValueError(msg)

if (
not isinstance(response_validation_error_http_status, HTTPStatus)
and response_validation_error_http_status is not None
):

try:
response_validation_error_http_status = HTTPStatus(response_validation_error_http_status)
except ValueError:
msg = f"'{response_validation_error_http_status}' must be an integer representing an HTTP status code."
raise ValueError(msg) from None

self._response_validation_error_http_status = (
response_validation_error_http_status
if response_validation_error_http_status
else HTTPStatus.UNPROCESSABLE_ENTITY
)

# Allow for a custom serializer or a concise json serialization
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
Expand All @@ -1549,7 +1572,7 @@ def __init__(
[
OpenAPIValidationMiddleware(
validation_serializer=serializer,
has_response_validation_error=self._response_validation_error_http_status is not None,
has_response_validation_error=self._has_response_validation_error,
),
],
)
Expand Down Expand Up @@ -2386,12 +2409,15 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
# OpenAPIValidationMiddleware will only raise ResponseValidationError when
# 'self._response_validation_error_http_status' is not None
if isinstance(exp, ResponseValidationError):
if self._response_validation_error_http_status is None:
raise TypeError
http_status = (
self._response_validation_error_http_status
if self._response_validation_error_http_status
else HTTPStatus.UNPROCESSABLE_ENTITY
)
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
return self._response_builder_class(
response=Response(
status_code=self._response_validation_error_http_status,
status_code=http_status.value,
content_type=content_types.APPLICATION_JSON,
body={"statusCode": self._response_validation_error_http_status, "detail": errors},
),
Expand Down Expand Up @@ -2611,7 +2637,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | None = None,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
"""Amazon API Gateway REST and HTTP API v1 payload resolver"""
super().__init__(
Expand Down Expand Up @@ -2695,7 +2721,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | None = None,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
"""Amazon API Gateway HTTP API v2 payload resolver"""
super().__init__(
Expand Down Expand Up @@ -2734,7 +2760,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | None = None,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
"""Amazon Application Load Balancer (ALB) resolver"""
super().__init__(
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/event_handler/lambda_function_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | None = None,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
super().__init__(
ProxyEventType.LambdaFunctionUrlEvent,
Expand Down
4 changes: 2 additions & 2 deletions aws_lambda_powertools/event_handler/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | None = None,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
"""Amazon VPC Lattice resolver"""
super().__init__(
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | None = None,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
"""Amazon VPC Lattice resolver"""
super().__init__(
Expand Down

0 comments on commit f8ead84

Please sign in to comment.