Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Force Lambda failure #558

Merged
merged 8 commits into from
Dec 24, 2024
File renamed without changes.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ The readiness check port/path and traffic port can be configured using environme
| AWS_LWA_INVOKE_MODE | Lambda function invoke mode: "buffered" or "response_stream", default is "buffered" | "buffered" |
| AWS_LWA_PASS_THROUGH_PATH | the path for receiving event payloads that are passed through from non-http triggers | "/events" |
| AWS_LWA_AUTHORIZATION_SOURCE | a header name to be replaced to `Authorization` | None |
| AWS_LWA_ERROR_STATUS_CODES | comma-separated list of HTTP status codes that will cause Lambda invocations to fail (e.g. "500,502-504,422") | None |

> **Note:**
> We use "AWS_LWA_" prefix to namespacing all environment variables used by Lambda Web Adapter. The original ones will be supported until we reach version 1.0.
Expand Down Expand Up @@ -137,6 +138,8 @@ Please check out [FastAPI with Response Streaming](examples/fastapi-response-str

**AWS_LWA_AUTHORIZATION_SOURCE** - When set, Lambda Web Adapter replaces the specified header name to `Authorization` before proxying a request. This is useful when you use Lambda function URL with [IAM auth type](https://docs.aws.amazon.com/lambda/latest/dg/urls-auth.html), which reserves Authorization header for IAM authentication, but you want to still use Authorization header for your backend apps. This feature is disabled by default.

**AWS_LWA_ERROR_STATUS_CODES** - A comma-separated list of HTTP status codes that will cause Lambda invocations to fail. Supports individual codes and ranges (e.g. "500,502-504,422"). When the web application returns any of these status codes, the Lambda invocation will fail and trigger error handling behaviors like retries or DLQ processing. This is useful for treating certain HTTP errors as Lambda execution failures. This feature is disabled by default.

## Request Context

**Request Context** is metadata API Gateway sends to Lambda for a request. It usually contains requestId, requestTime, apiId, identity, and authorizer. Identity and authorizer are useful to get client identity for authorization. API Gateway Developer Guide contains more details [here](https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format).
Expand Down
60 changes: 60 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub struct AdapterOptions {
pub compression: bool,
pub invoke_mode: LambdaInvokeMode,
pub authorization_source: Option<String>,
pub error_status_codes: Option<Vec<u16>>,
}

impl Default for AdapterOptions {
Expand Down Expand Up @@ -116,10 +117,42 @@ impl Default for AdapterOptions {
.as_str()
.into(),
authorization_source: env::var("AWS_LWA_AUTHORIZATION_SOURCE").ok(),
error_status_codes: env::var("AWS_LWA_ERROR_STATUS_CODES")
.ok()
.map(|codes| parse_status_codes(&codes)),
}
}
}

fn parse_status_codes(input: &str) -> Vec<u16> {
input
.split(',')
.flat_map(|part| {
let part = part.trim();
if part.contains('-') {
let range: Vec<&str> = part.split('-').collect();
if range.len() == 2 {
if let (Ok(start), Ok(end)) = (range[0].parse::<u16>(), range[1].parse::<u16>()) {
return (start..=end).collect::<Vec<_>>();
}
}
tracing::warn!("Failed to parse status code range: {}", part);
vec![]
DiscreteTom marked this conversation as resolved.
Show resolved Hide resolved
} else {
part.parse::<u16>().map_or_else(
|_| {
if !part.is_empty() {
tracing::warn!("Failed to parse status code: {}", part);
}
vec![]
},
|code| vec![code],
)
}
})
.collect()
}

#[derive(Clone)]
pub struct Adapter<C, B> {
client: Arc<Client<C, B>>,
Expand All @@ -134,6 +167,7 @@ pub struct Adapter<C, B> {
compression: bool,
invoke_mode: LambdaInvokeMode,
authorization_source: Option<String>,
error_status_codes: Option<Vec<u16>>,
}

impl Adapter<HttpConnector, Body> {
Expand Down Expand Up @@ -171,6 +205,7 @@ impl Adapter<HttpConnector, Body> {
compression: options.compression,
invoke_mode: options.invoke_mode,
authorization_source: options.authorization_source.clone(),
error_status_codes: options.error_status_codes.clone(),
}
}
}
Expand Down Expand Up @@ -341,6 +376,17 @@ impl Adapter<HttpConnector, Body> {

let mut app_response = self.client.request(request).await?;

// Check if status code should trigger an error
if let Some(error_codes) = &self.error_status_codes {
let status = app_response.status().as_u16();
if error_codes.contains(&status) {
return Err(Error::from(format!(
"Request failed with configured error status code: {}",
status
)));
}
}

// remove "transfer-encoding" from the response to support "sam local start-api"
app_response.headers_mut().remove("transfer-encoding");

Expand Down Expand Up @@ -373,6 +419,20 @@ mod tests {
use super::*;
use httpmock::{Method::GET, MockServer};

#[test]
fn test_parse_status_codes() {
assert_eq!(parse_status_codes("500,502-504,422"), vec![500, 502, 503, 504, 422]);
assert_eq!(
parse_status_codes("500, 502-504, 422"), // with spaces
vec![500, 502, 503, 504, 422]
);
assert_eq!(parse_status_codes("500"), vec![500]);
assert_eq!(parse_status_codes("500-502"), vec![500, 501, 502]);
assert_eq!(parse_status_codes("invalid"), Vec::<u16>::new());
assert_eq!(parse_status_codes("500-invalid"), Vec::<u16>::new());
assert_eq!(parse_status_codes(""), Vec::<u16>::new());
}

#[tokio::test]
async fn test_status_200_is_ok() {
// Start app server
Expand Down
32 changes: 32 additions & 0 deletions tests/integ_tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,38 @@ async fn test_http_content_encoding_suffix() {
assert_eq!(json_data.to_owned(), body_to_string(response).await);
}

#[tokio::test]
async fn test_http_error_status_codes() {
// Start app server
let app_server = MockServer::start();
let error_endpoint = app_server.mock(|when, then| {
when.method(GET).path("/error");
then.status(502).body("Bad Gateway");
});

// Initialize adapter with error status codes
let mut adapter = Adapter::new(&AdapterOptions {
host: app_server.host(),
port: app_server.port().to_string(),
readiness_check_port: app_server.port().to_string(),
readiness_check_path: "/healthcheck".to_string(),
error_status_codes: Some(vec![500, 502, 503, 504]),
..Default::default()
});

// Call the adapter service with request that should trigger error
let req = LambdaEventBuilder::new().with_path("/error").build();
let mut request = Request::from(req);
add_lambda_context_to_request(&mut request);

let result = adapter.call(request).await;
assert!(result.is_err(), "Expected error response for status code 502");
assert!(result.unwrap_err().to_string().contains("502"));

// Assert endpoint was called
error_endpoint.assert();
}

#[tokio::test]
async fn test_http_authorization_source() {
// Start app server
Expand Down