From 6f32cdc194fb4dc29012a2709422f62f694deb22 Mon Sep 17 00:00:00 2001 From: Srikanth Govindarajan Date: Tue, 7 Jan 2025 19:23:18 -0800 Subject: [PATCH 1/6] lambda processor should retry for certain class of exceptions Signed-off-by: Srikanth Govindarajan --- .../lambda/common/LambdaCommonHandler.java | 3 + .../common/util/LambdaRetryStrategy.java | 190 ++++++++++++++++++ .../lambda/processor/LambdaProcessor.java | 40 +++- .../lambda/utils/LambdaRetryStrategyTest.java | 153 ++++++++++++++ 4 files changed, 377 insertions(+), 9 deletions(-) create mode 100644 data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java create mode 100644 data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java index 3518508b64..5ef5dac7ac 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java @@ -36,6 +36,9 @@ private LambdaCommonHandler() { } public static boolean isSuccess(InvokeResponse response) { + if(response == null) { + return false; + } int statusCode = response.statusCode(); return statusCode >= 200 && statusCode < 300; } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java new file mode 100644 index 0000000000..56a3c5af1f --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java @@ -0,0 +1,190 @@ +package org.opensearch.dataprepper.plugins.lambda.common.util; + +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import software.amazon.awssdk.services.lambda.model.TooManyRequestsException; +import software.amazon.awssdk.services.lambda.model.ServiceException; + +import java.time.Duration; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import org.slf4j.Logger; + +import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess; + +/** + * Similar to BulkRetryStrategy in the OpenSearch sink. + * Categorizes AWS Lambda exceptions and status codes into + * retryable and non-retryable scenarios. + */ +public final class LambdaRetryStrategy { + + private LambdaRetryStrategy() { + } + + public static boolean isRetryableException(final Throwable t) { + if (t instanceof TooManyRequestsException) { + // Throttling => often can retry with backoff + return true; + } + if (t instanceof ServiceException) { + // Usually indicates a 5xx => can retry + return true; + } + if (t instanceof SdkClientException) { + // Possibly network/connection error => can retry + return true; + } + return false; + } + + public static boolean isRetryableResponse(final InvokeResponse response) { + int statusCode = response.statusCode(); + // Throttling or internal error then retry + return (statusCode == 429) || (statusCode >= 500 && statusCode < 600); + } + + /** + * Set of status codes that should generally NOT be retried + * because they indicate client-side or permanent errors. + */ + private static final Set NON_RETRY_STATUS = new HashSet<>( + Arrays.asList( + 400, // ExpiredTokenException + 403, // IncompleteSignature, AccessDeniedException, AccessDeniedException + 404, // Not Found + 409 // Conflict + ) + ); + + /** + * Possibly a set of “bad request” style errors which might fall + * under the NON_RETRY_STATUS or be handled differently if you prefer. + */ + private static final Set BAD_REQUEST_ERRORS = new HashSet<>( + Arrays.asList( + 400, // Bad Request + 422, // Unprocessable Entity + 417, // Expectation Failed + 406 // Not Acceptable + ) + ); + + /** + * Status codes which may indicate a security or policy problem, so we don't retry. + */ + private static final Set NOT_ALLOWED_ERRORS = new HashSet<>( + Arrays.asList( + 401, // Unauthorized + 403, // Forbidden + 405 // Method Not Allowed + ) + ); + + /** + * Examples of input or payload errors that are likely not retryable + * unless the pipeline itself corrects them. + */ + private static final Set INVALID_INPUT_ERRORS = new HashSet<>( + Arrays.asList( + 413, // Payload Too Large + 414, // URI Too Long + 416 // Range Not Satisfiable + // ... + ) + ); + + /** + * Example of a “timeout” scenario. Lambda can return 429 for "Too Many Requests" or + * 408 (if applicable) for timeouts in some contexts. + * This can be considered retryable if you want to handle the throttling scenario. + */ + private static final Set TIMEOUT_ERRORS = new HashSet<>( + Arrays.asList( + 408, // Request Timeout + 429 // Too Many Requests (often used as "throttling" for Lambda) + ) + ); + + + public static boolean isRetryable(final InvokeResponse response) { + if(response == null) return false; + int statusCode = response.statusCode(); + // Example logic: 429 (Too Many Requests) or 5xx => retry + return statusCode == 429 || (statusCode >= 500 && statusCode < 600); + } + + /** + * Determines if this is definitely NOT retryable (client error or permanent failure). + */ + public static boolean isNonRetryable(final InvokeResponse response) { + if(response == null) return false; + + int statusCode = response.statusCode(); + return NON_RETRY_STATUS.contains(statusCode) + || BAD_REQUEST_ERRORS.contains(statusCode) + || NOT_ALLOWED_ERRORS.contains(statusCode) + || INVALID_INPUT_ERRORS.contains(statusCode); + } + + /** + * For convenience, you can create more fine-grained checks or + * direct set membership checks (e.g. isBadRequest(...), isTimeout(...)) if you want. + */ + public static boolean isTimeoutError(final InvokeResponse response) { + return TIMEOUT_ERRORS.contains(response.statusCode()); + } + + public static InvokeResponse retryOrFail( + final LambdaAsyncClient lambdaAsyncClient, + final Buffer buffer, + final LambdaCommonConfig config, + final InvokeResponse previousResponse, + final Logger LOG + ) { + int maxRetries = config.getClientOptions().getMaxConnectionRetries(); + Duration backoff = config.getClientOptions().getBaseDelay(); + + int attempt = 1; + InvokeResponse response = previousResponse; + + do{ + LOG.warn("Retrying Lambda invocation attempt {} of {} after {} ms backoff", + attempt, maxRetries, backoff); + try { + // Sleep for backoff + Thread.sleep(backoff.toMillis()); + + // Re-invoke Lambda with the same payload + InvokeRequest requestPayload = buffer.getRequestPayload( + config.getFunctionName(), + config.getInvocationType().getAwsLambdaValue() + ); + // Do a synchronous call. + response = lambdaAsyncClient.invoke(requestPayload).join(); + + if (isSuccess(response)) { + LOG.info("Retry attempt {} succeeded with status code {}", attempt, response.statusCode()); + return response; + } else{ + throw new RuntimeException(); + } + } catch (Exception e) { + LOG.error("Failed to invoke failed with exception {} in attempt {}", e.getMessage(), attempt); + if(!isRetryable(response)){ + throw new RuntimeException("Failed to invoke failed",e); + } + } + attempt++; + } while(attempt <= maxRetries && isRetryable(response)); + + return response; + } + +} + diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java index 786939f5a1..8adbf9cc9b 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java @@ -31,6 +31,7 @@ import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory; import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; +import org.opensearch.dataprepper.plugins.lambda.common.util.LambdaRetryStrategy; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.core.SdkBytes; @@ -176,14 +177,21 @@ public Collection> doExecute(Collection> records) { Buffer inputBuffer = entry.getKey(); try { InvokeResponse response = future.join(); + + // If this response has a failure is retryable, do a direct retry + if (!isSuccess(response) && LambdaRetryStrategy.isRetryable(response)){ + response = LambdaRetryStrategy.retryOrFail( + lambdaAsyncClient, + inputBuffer, + lambdaProcessorConfig, + response, + LOG + ); + } + Duration latency = inputBuffer.stopLatencyWatch(); lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); requestPayloadMetric.record(inputBuffer.getPayloadRequestSize()); - if (!isSuccess(response)) { - String errorMessage = String.format("Lambda invoke failed with status code %s error %s ", - response.statusCode(), response.payload().asUtf8String()); - throw new RuntimeException(errorMessage); - } resultRecords.addAll(convertLambdaResponseToEvent(inputBuffer, response)); numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount()); @@ -194,10 +202,24 @@ public Collection> doExecute(Collection> records) { } catch (Exception e) { LOG.error(NOISY, e.getMessage(), e); - /* fall through */ - numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); - numberOfRequestsFailedCounter.increment(); - resultRecords.addAll(addFailureTags(inputBuffer.getRecords())); + InvokeResponse response = null; + if (LambdaRetryStrategy.isRetryableException(e)){ + response = LambdaRetryStrategy.retryOrFail( + lambdaAsyncClient, + inputBuffer, + lambdaProcessorConfig, + null, + LOG + ); + String errorMessage = String.format("Lambda invoke failed with status code %s error %s. Will be Retrying the request ", + response.statusCode(), response.payload().asUtf8String()); + LOG.error(NOISY, e.getMessage(), e); + } + if(response == null || !isSuccess(response)) { + /* fall through */ + numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); + resultRecords.addAll(addFailureTags(inputBuffer.getRecords())); + } } } return resultRecords; diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java new file mode 100644 index 0000000000..07e667f59a --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java @@ -0,0 +1,153 @@ +package org.opensearch.dataprepper.plugins.lambda.utils; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; +import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; +import org.opensearch.dataprepper.plugins.lambda.common.util.LambdaRetryStrategy; +import org.slf4j.Logger; +import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; + +import java.time.Duration; +import java.util.concurrent.CompletableFuture; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class LambdaRetryStrategyTest { + + @Mock + private LambdaAsyncClient lambdaAsyncClient; + + @Mock + private Buffer buffer; + + @Mock + private LambdaCommonConfig config; + + @Mock + private Logger logger; + + @BeforeEach + void setUp() { +// when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(CompletableFuture.completedFuture(InvokeResponse.builder().statusCode(200).build())); + when(config.getClientOptions()).thenReturn(mock(ClientOptions.class)); + when(config.getClientOptions().getMaxConnectionRetries()).thenReturn(3); + when(config.getClientOptions().getBaseDelay()).thenReturn(Duration.ofMillis(100)); + when(config.getFunctionName()).thenReturn("testFunction"); + when(config.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); + } + +// @Test +// void testIsRetryableException() { +// assertTrue(LambdaRetryStrategy.isRetryableException(new TooManyRequestsException(null))); +// assertTrue(LambdaRetryStrategy.isRetryableException(new ServiceException(null))); +// assertTrue(LambdaRetryStrategy.isRetryableException(new SdkClientException(null))); +// assertFalse(LambdaRetryStrategy.isRetryableException(new RuntimeException())); +// } + + @Test + void testIsRetryableResponse() { + assertTrue(LambdaRetryStrategy.isRetryableResponse(InvokeResponse.builder().statusCode(429).build())); + assertTrue(LambdaRetryStrategy.isRetryableResponse(InvokeResponse.builder().statusCode(500).build())); + assertFalse(LambdaRetryStrategy.isRetryableResponse(InvokeResponse.builder().statusCode(200).build())); + } + + @Test + void testIsRetryable() { + assertTrue(LambdaRetryStrategy.isRetryable(InvokeResponse.builder().statusCode(429).build())); + assertTrue(LambdaRetryStrategy.isRetryable(InvokeResponse.builder().statusCode(500).build())); + assertFalse(LambdaRetryStrategy.isRetryable(InvokeResponse.builder().statusCode(200).build())); + assertFalse(LambdaRetryStrategy.isRetryable(null)); + } + + @Test + void testIsNonRetryable() { + assertTrue(LambdaRetryStrategy.isNonRetryable(InvokeResponse.builder().statusCode(400).build())); + assertTrue(LambdaRetryStrategy.isNonRetryable(InvokeResponse.builder().statusCode(403).build())); + assertFalse(LambdaRetryStrategy.isNonRetryable(InvokeResponse.builder().statusCode(500).build())); + assertFalse(LambdaRetryStrategy.isNonRetryable(null)); + } + + @Test + void testIsTimeoutError() { + assertTrue(LambdaRetryStrategy.isTimeoutError(InvokeResponse.builder().statusCode(408).build())); + assertTrue(LambdaRetryStrategy.isTimeoutError(InvokeResponse.builder().statusCode(429).build())); + assertFalse(LambdaRetryStrategy.isTimeoutError(InvokeResponse.builder().statusCode(200).build())); + } + + @Test + void testRetryOrFail_SuccessAfterRetry() throws Exception { + when(config.getClientOptions().getMaxConnectionRetries()).thenReturn(3); + when(config.getClientOptions().getBaseDelay()).thenReturn(Duration.ofMillis(100)); + when(config.getFunctionName()).thenReturn("testFunction"); + + InvokeRequest mockRequest = mock(InvokeRequest.class); + when(buffer.getRequestPayload(anyString(), anyString())).thenReturn(mockRequest); + + InvokeResponse failedResponse = InvokeResponse.builder().statusCode(500).build(); + InvokeResponse successResponse = InvokeResponse.builder().statusCode(200).build(); + + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.completedFuture(failedResponse)) + .thenReturn(CompletableFuture.completedFuture(successResponse)); + + InvokeResponse result = LambdaRetryStrategy.retryOrFail(lambdaAsyncClient, buffer, config, failedResponse, logger); + + assertEquals(200, result.statusCode()); + verify(lambdaAsyncClient, times(2)).invoke(any(InvokeRequest.class)); + } + + @Test + void testRetryOrFailExhaustedRetries() throws Exception { + when(config.getClientOptions().getMaxConnectionRetries()).thenReturn(3); + when(config.getClientOptions().getBaseDelay()).thenReturn(Duration.ofMillis(100)); + when(config.getFunctionName()).thenReturn("testFunction"); + + InvokeRequest mockRequest = mock(InvokeRequest.class); + when(buffer.getRequestPayload(anyString(), anyString())).thenReturn(mockRequest); + + InvokeResponse failedResponse = InvokeResponse.builder().statusCode(500).build(); + + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.completedFuture(failedResponse)); + + InvokeResponse result = LambdaRetryStrategy.retryOrFail(lambdaAsyncClient, buffer, config, failedResponse, logger); + + assertEquals(500, result.statusCode()); + verify(lambdaAsyncClient, times(3)).invoke(any(InvokeRequest.class)); + } + + @Test + void testRetryOrFail_NonRetryableResponse() { + InvokeResponse nonRetryableResponse = InvokeResponse.builder().statusCode(400).build(); + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.completedFuture(nonRetryableResponse)); + when(buffer.getRequestPayload(anyString(), anyString())).thenReturn(mock(InvokeRequest.class)); + + assertThrows(RuntimeException.class, ()-> + LambdaRetryStrategy.retryOrFail(lambdaAsyncClient, buffer, config, nonRetryableResponse, logger)); + + verify(lambdaAsyncClient, times(1)).invoke(any(InvokeRequest.class)); + } + +} From b2d7d13aea0a2428941126eea18b1da57b7affbf Mon Sep 17 00:00:00 2001 From: Srikanth Govindarajan Date: Wed, 15 Jan 2025 18:33:20 -0800 Subject: [PATCH 2/6] Address Comment on complete codec Signed-off-by: Srikanth Govindarajan --- .../lambda/common/LambdaCommonHandler.java | 2 + .../lambda/common/accumlator/Buffer.java | 2 + .../common/accumlator/InMemoryBuffer.java | 16 +- .../common/util/LambdaRetryStrategy.java | 62 +++---- .../lambda/processor/LambdaProcessor.java | 5 + .../lambda/processor/LambdaProcessorTest.java | 162 ++++++++++++++++++ .../lambda/utils/LambdaRetryStrategyTest.java | 18 +- .../lambda-processor-with-retries.yaml | 14 ++ 8 files changed, 219 insertions(+), 62 deletions(-) create mode 100644 data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-with-retries.yaml diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java index 5ef5dac7ac..b4a5458d3c 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java @@ -68,12 +68,14 @@ private static List createBufferBatches(Collection> record if (ThresholdCheck.checkThresholdExceed(currentBufferPerBatch, maxEvents, maxBytes, maxCollectionDuration)) { batchedBuffers.add(currentBufferPerBatch); + currentBufferPerBatch.completeCodec(); currentBufferPerBatch = new InMemoryBuffer(keyName, outputCodecContext); } } if (currentBufferPerBatch.getEventCount() > 0) { batchedBuffers.add(currentBufferPerBatch); + currentBufferPerBatch.completeCodec(); } return batchedBuffers; } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java index b6249008cd..40b5314c3e 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java @@ -36,5 +36,7 @@ public interface Buffer { Long getPayloadRequestSize(); Duration stopLatencyWatch(); + + void completeCodec(); } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java index f3e2ea1f8f..256c263522 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java @@ -72,6 +72,16 @@ public void addRecord(Record record) { eventCount++; } + public void completeCodec() { + if (eventCount > 0) { + try { + requestCodec.complete(this.byteArrayOutputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + public List> getRecords() { return records; } @@ -98,12 +108,6 @@ public InvokeRequest getRequestPayload(String functionName, String invocationTyp return null; } - try { - requestCodec.complete(this.byteArrayOutputStream); - } catch (IOException e) { - throw new RuntimeException(e); - } - SdkBytes payload = getPayload(); payloadRequestSize = payload.asByteArray().length; diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java index 56a3c5af1f..64d4c93372 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java @@ -27,44 +27,8 @@ public final class LambdaRetryStrategy { private LambdaRetryStrategy() { } - public static boolean isRetryableException(final Throwable t) { - if (t instanceof TooManyRequestsException) { - // Throttling => often can retry with backoff - return true; - } - if (t instanceof ServiceException) { - // Usually indicates a 5xx => can retry - return true; - } - if (t instanceof SdkClientException) { - // Possibly network/connection error => can retry - return true; - } - return false; - } - - public static boolean isRetryableResponse(final InvokeResponse response) { - int statusCode = response.statusCode(); - // Throttling or internal error then retry - return (statusCode == 429) || (statusCode >= 500 && statusCode < 600); - } - - /** - * Set of status codes that should generally NOT be retried - * because they indicate client-side or permanent errors. - */ - private static final Set NON_RETRY_STATUS = new HashSet<>( - Arrays.asList( - 400, // ExpiredTokenException - 403, // IncompleteSignature, AccessDeniedException, AccessDeniedException - 404, // Not Found - 409 // Conflict - ) - ); - /** * Possibly a set of “bad request” style errors which might fall - * under the NON_RETRY_STATUS or be handled differently if you prefer. */ private static final Set BAD_REQUEST_ERRORS = new HashSet<>( Arrays.asList( @@ -95,7 +59,6 @@ public static boolean isRetryableResponse(final InvokeResponse response) { 413, // Payload Too Large 414, // URI Too Long 416 // Range Not Satisfiable - // ... ) ); @@ -115,8 +78,26 @@ public static boolean isRetryableResponse(final InvokeResponse response) { public static boolean isRetryable(final InvokeResponse response) { if(response == null) return false; int statusCode = response.statusCode(); - // Example logic: 429 (Too Many Requests) or 5xx => retry - return statusCode == 429 || (statusCode >= 500 && statusCode < 600); + return TIMEOUT_ERRORS.contains(statusCode) || (statusCode >= 500 && statusCode < 600); + } + + /* + * Note:isRetryable and isRetryableException should match + */ + public static boolean isRetryableException(final Throwable t) { + if (t instanceof TooManyRequestsException) { + // Throttling => often can retry with backoff + return true; + } + if (t instanceof ServiceException) { + // Usually indicates a 5xx => can retry + return true; + } + if (t instanceof SdkClientException) { + // Possibly network/connection error => can retry + return true; + } + return false; } /** @@ -126,8 +107,7 @@ public static boolean isNonRetryable(final InvokeResponse response) { if(response == null) return false; int statusCode = response.statusCode(); - return NON_RETRY_STATUS.contains(statusCode) - || BAD_REQUEST_ERRORS.contains(statusCode) + return BAD_REQUEST_ERRORS.contains(statusCode) || NOT_ALLOWED_ERRORS.contains(statusCode) || INVALID_INPUT_ERRORS.contains(statusCode); } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java index 8adbf9cc9b..f1fb552c53 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java @@ -188,6 +188,11 @@ public Collection> doExecute(Collection> records) { LOG ); } + if(response == null || !isSuccess(response)) { + numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); + resultRecords.addAll(addFailureTags(inputBuffer.getRecords())); + return resultRecords; + } Duration latency = inputBuffer.stopLatencyWatch(); lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java index 5c2ee0e8e6..18a624887e 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java @@ -611,4 +611,166 @@ public void testDoExecute_for_strict_and_aggregate_mode(String configFile, assertEquals("[lambda_failure]", record.getData().getMetadata().getTags().toString()); } } + + @Test + public void testDoExecute_retryScenario_successOnSecondAttempt() throws Exception { + // Arrange + final List> records = getSampleEventRecords(2); + + // First response -> 429 (Retryable) + final InvokeResponse firstResponse = InvokeResponse.builder() + .statusCode(429) + .payload(SdkBytes.fromUtf8String("First attempt throttled")) + .build(); + + // Second response -> 200 (Success) + final InvokeResponse secondResponse = InvokeResponse.builder() + .statusCode(200) + .payload(SdkBytes.fromUtf8String("[{\"successKey1\": \"successValue1\"}, {\"successKey2\": \"successValue2\"}]")) + .build(); + + // Setup stubbing with Mockito. + // The lambda client will return firstResponse, then secondResponse + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.completedFuture(firstResponse)) + .thenReturn(CompletableFuture.completedFuture(secondResponse)); + + // Create a config with at least 1 maxConnectionRetries + final LambdaProcessorConfig config = createLambdaConfigurationFromYaml("lambda-processor-success-config.yaml"); + + // Instantiate the processor and set fields + final LambdaProcessor processor = new LambdaProcessor(pluginFactory, pluginSetting, config, + awsCredentialsSupplier, expressionEvaluator); + populatePrivateFields(processor); + + // Act + final Collection> resultRecords = processor.doExecute(records); + + // Assert + // Because the second attempt is 200, we expect originalRecords to match count + // (and not have the "lambda_failure" tag). + assertEquals(records.size(), resultRecords.size()); + for (Record record : resultRecords) { + assertFalse(record.getData().getMetadata().getTags().contains("lambda_failure"), + "Record should NOT have a failure tag after a successful retry"); + } + + // The first attempt fails, but the second attempt is success => success counters increment + // Make sure the client was invoked 2 times + verify(lambdaAsyncClient, times(2)).invoke(any(InvokeRequest.class)); + // The second attempt is success + verify(numberOfRequestsSuccessCounter, times(1)).increment(); + } + + @Test + public void testDoExecute_retryScenario_failsAfterMaxRetries() throws Exception { + // Arrange + final List> records = getSampleEventRecords(3); + + // Simulate a 500 status code (Retryable) + final InvokeResponse failedResponse = InvokeResponse.builder() + .statusCode(500) + .payload(SdkBytes.fromUtf8String("Internal server error")) + .build(); + + // Stub the lambda client to always return failedResponse + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.completedFuture(failedResponse)) + .thenReturn(CompletableFuture.completedFuture(failedResponse)) + .thenReturn(CompletableFuture.completedFuture(failedResponse)); + + // Create a config with exactly 1 maxConnectionRetries (allowing 2 total attempts) + final LambdaProcessorConfig config = createLambdaConfigurationFromYaml("lambda-processor-success-config.yaml"); + + // Instantiate the processor + final LambdaProcessor processor = new LambdaProcessor(pluginFactory, pluginSetting, config, + awsCredentialsSupplier, expressionEvaluator); + populatePrivateFields(processor); + + // Act + final Collection> resultRecords = processor.doExecute(records); + + // Assert + // All records should have the "lambda_failure" tag + assertEquals(records.size(), resultRecords.size(), "Result records count should match input records count."); + for (Record record : resultRecords) { + assertTrue(record.getData().getMetadata().getTags().contains("lambda_failure"), + "Record should have 'lambda_failure' tag after all retries fail"); + } + + // Expect 3 invocations: initial attempt + 3 retry + verify(lambdaAsyncClient, times(4)).invoke(any(InvokeRequest.class)); + // No success counters + verify(numberOfRequestsSuccessCounter, never()).increment(); + // Records failed counter should increment once with the total number of records + verify(numberOfRecordsFailedCounter, times(1)).increment(records.size()); + } + + + @Test + public void testDoExecute_nonRetryableStatusCode_noRetryAttempted() throws Exception { + // Arrange + final List> records = getSampleEventRecords(2); + + // 400 is a client error => non-retryable + final InvokeResponse badRequestResponse = InvokeResponse.builder() + .statusCode(400) + .payload(SdkBytes.fromUtf8String("Bad request")) + .build(); + + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.completedFuture(badRequestResponse)); + + final LambdaProcessorConfig config = createLambdaConfigurationFromYaml("lambda-processor-with-retries.yaml"); + + final LambdaProcessor processor = new LambdaProcessor(pluginFactory, pluginSetting, config, + awsCredentialsSupplier, expressionEvaluator); + populatePrivateFields(processor); + + // Act + final Collection> resultRecords = processor.doExecute(records); + + // Assert + assertEquals(records.size(), resultRecords.size()); + for (Record record : resultRecords) { + assertTrue(record.getData().getMetadata().getTags().contains("lambda_failure"), + "Non-retryable failure should cause 'lambda_failure' tag"); + } + // Only 1 attempt => no second invoke + verify(lambdaAsyncClient, times(1)).invoke(any(InvokeRequest.class)); + // Fail counters + verify(numberOfRecordsFailedCounter).increment(2); + } + + @Test + public void testDoExecute_nonRetryableException_thrownImmediatelyFail() throws Exception { + // Arrange + final List> records = getSampleEventRecords(2); + + // Some random exception that is not in the list of retryable exceptions + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenThrow(new IllegalArgumentException("Non-retryable exception")); + + final LambdaProcessorConfig config = createLambdaConfigurationFromYaml("lambda-processor-with-retries.yaml"); + + final LambdaProcessor processor = new LambdaProcessor(pluginFactory, pluginSetting, config, + awsCredentialsSupplier, expressionEvaluator); + populatePrivateFields(processor); + + // Act + final Collection> resultRecords = processor.doExecute(records); + + // Assert + // We expect no success => all records come back tagged + assertEquals(records.size(), resultRecords.size()); + for (Record record : resultRecords) { + assertTrue(record.getData().getMetadata().getTags().contains("lambda_failure"), + "Record should have 'lambda_failure' after a non-retryable exception"); + } + + // Attempted only once + verify(lambdaAsyncClient, times(1)).invoke(any(InvokeRequest.class)); + verify(numberOfRequestsFailedCounter, times(1)).increment(); + } + } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java index 07e667f59a..a63c818017 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java @@ -13,9 +13,12 @@ import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; import org.opensearch.dataprepper.plugins.lambda.common.util.LambdaRetryStrategy; import org.slf4j.Logger; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import software.amazon.awssdk.services.lambda.model.ServiceException; +import software.amazon.awssdk.services.lambda.model.TooManyRequestsException; import java.time.Duration; import java.util.concurrent.CompletableFuture; @@ -57,21 +60,6 @@ void setUp() { when(config.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); } -// @Test -// void testIsRetryableException() { -// assertTrue(LambdaRetryStrategy.isRetryableException(new TooManyRequestsException(null))); -// assertTrue(LambdaRetryStrategy.isRetryableException(new ServiceException(null))); -// assertTrue(LambdaRetryStrategy.isRetryableException(new SdkClientException(null))); -// assertFalse(LambdaRetryStrategy.isRetryableException(new RuntimeException())); -// } - - @Test - void testIsRetryableResponse() { - assertTrue(LambdaRetryStrategy.isRetryableResponse(InvokeResponse.builder().statusCode(429).build())); - assertTrue(LambdaRetryStrategy.isRetryableResponse(InvokeResponse.builder().statusCode(500).build())); - assertFalse(LambdaRetryStrategy.isRetryableResponse(InvokeResponse.builder().statusCode(200).build())); - } - @Test void testIsRetryable() { assertTrue(LambdaRetryStrategy.isRetryable(InvokeResponse.builder().statusCode(429).build())); diff --git a/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-with-retries.yaml b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-with-retries.yaml new file mode 100644 index 0000000000..aa0f9eb1f8 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-with-retries.yaml @@ -0,0 +1,14 @@ +function_name: "lambdaProcessorTest" +response_events_match: true +tags_on_failure: [ "lambda_failure" ] +batch: + key_name: "osi_key" + threshold: + event_count: 100 + maximum_size: 1mb + event_collect_timeout: 335 +client: + max_retries: 2 +aws: + region: "us-east-1" + sts_role_arn: "arn:aws:iam::1234567890:role/sample-pipeine-role" \ No newline at end of file From 2f1d9577af8444f816cf811f1d214af50a02ebb9 Mon Sep 17 00:00:00 2001 From: Srikanth Govindarajan Date: Tue, 21 Jan 2025 23:48:40 -0800 Subject: [PATCH 3/6] Add retryCondidition to lambda Client Signed-off-by: Srikanth Govindarajan --- .../lambda/processor/LambdaProcessorIT.java | 97 +++++++++++++- .../common/client/LambdaClientFactory.java | 6 +- .../util/CustomLambdaRetryCondition.java | 17 +++ .../common/util/LambdaRetryStrategy.java | 51 +------ .../lambda/processor/LambdaProcessor.java | 47 ++----- .../client/LambdaClientFactoryTest.java | 57 ++++++++ .../lambda/processor/LambdaProcessorTest.java | 126 +++++++++++------- .../lambda/utils/CountingHttpClient.java | 33 +++++ .../lambda/utils/LambdaRetryStrategyTest.java | 62 +-------- .../lambda-processor-with-retries.yaml | 3 +- 10 files changed, 296 insertions(+), 203 deletions(-) create mode 100644 data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/CustomLambdaRetryCondition.java create mode 100644 data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/CountingHttpClient.java diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java index 8203819fcb..2e25c053ae 100644 --- a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java +++ b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java @@ -46,17 +46,28 @@ import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; +import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition; +import org.opensearch.dataprepper.plugins.lambda.utils.CountingHttpClient; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import software.amazon.awssdk.services.lambda.model.TooManyRequestsException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -94,9 +105,12 @@ private LambdaProcessor createObjectUnderTest(LambdaProcessorConfig processorCon @BeforeEach public void setup() { - lambdaRegion = System.getProperty("tests.lambda.processor.region"); - functionName = System.getProperty("tests.lambda.processor.functionName"); - role = System.getProperty("tests.lambda.processor.sts_role_arn"); +// lambdaRegion = System.getProperty("tests.lambda.processor.region"); +// functionName = System.getProperty("tests.lambda.processor.functionName"); +// role = System.getProperty("tests.lambda.processor.sts_role_arn"); + lambdaRegion = "us-west-2"; + functionName = "lambdaNoReturn"; + role = "arn:aws:iam::176893235612:role/osis-s3-opensearch-role"; pluginMetrics = mock(PluginMetrics.class); pluginSetting = mock(PluginSetting.class); @@ -373,4 +387,81 @@ private List> createRecords(int numRecords) { } return records; } + + /* + * For this test, set concurrency limit to 1 + */ + @Test + void testTooManyRequestsExceptionWithCustomRetryCondition() { + //Note lambda function for this test looks like this: + /*def lambda_handler(event, context): + # Simulate a slow operation so that + # if concurrency = 1, multiple parallel invocations + # will result in TooManyRequestsException for the second+ invocation. + time.sleep(10) + # Return a simple success response + return { + "statusCode": 200, + "body": "Hello from concurrency-limited Lambda!" + } + */ + + // Wrap the default HTTP client to count requests + CountingHttpClient countingHttpClient = new CountingHttpClient( + NettyNioAsyncHttpClient.builder().build() + ); + + // Configure a custom retry policy with 3 retries and your custom condition + RetryPolicy retryPolicy = RetryPolicy.builder() + .numRetries(3) + .retryCondition(new CustomLambdaRetryCondition()) + .build(); + + // Build the real Lambda client + LambdaAsyncClient client = LambdaAsyncClient.builder() + .overrideConfiguration( + ClientOverrideConfiguration.builder() + .retryPolicy(retryPolicy) + .build() + ) + .region(Region.of(lambdaRegion)) + .httpClient(countingHttpClient) + .build(); + + // Parallel invocations to force concurrency=1 to throw TooManyRequestsException + int parallelInvocations = 10; + CompletableFuture[] futures = new CompletableFuture[parallelInvocations]; + for (int i = 0; i < parallelInvocations; i++) { + InvokeRequest request = InvokeRequest.builder() + .functionName(functionName) + .build(); + + futures[i] = client.invoke(request); + } + + // 5) Wait for all to complete + CompletableFuture.allOf(futures).join(); + + // 6) Check how many had TooManyRequestsException + long tooManyRequestsCount = Arrays.stream(futures) + .filter(f -> { + try { + f.join(); + return false; // no error => no TMR + } catch (CompletionException e) { + return e.getCause() instanceof TooManyRequestsException; + } + }) + .count(); + + // 7) Observe how many total network requests occurred (including SDK retries) + int totalRequests = countingHttpClient.getRequestCount(); + System.out.println("Total network requests (including retries): " + totalRequests); + + // Optionally: If you want to confirm the EXACT number, + // this might vary depending on how many parallel calls and how your TMR throttles them. + // For example, if all 5 calls are blocked, you might see 5*(numRetries + 1) in worst case. + assertTrue(totalRequests >= parallelInvocations, + "Should be at least one request per initial invocation, plus retries."); + } } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactory.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactory.java index 87b7a4271b..f019111b4f 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactory.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactory.java @@ -5,6 +5,7 @@ import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; +import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition; import org.opensearch.dataprepper.plugins.metricpublisher.MicrometerMetricPublisher; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; @@ -48,13 +49,14 @@ private static ClientOverrideConfiguration createOverrideConfiguration( .maxBackoffTime(clientOptions.getMaxBackoff()) .build(); - final RetryPolicy retryPolicy = RetryPolicy.builder() + final RetryPolicy customRetryPolicy = RetryPolicy.builder() + .retryCondition(new CustomLambdaRetryCondition()) .numRetries(clientOptions.getMaxConnectionRetries()) .backoffStrategy(backoffStrategy) .build(); return ClientOverrideConfiguration.builder() - .retryPolicy(retryPolicy) + .retryPolicy(customRetryPolicy) .addMetricPublisher(new MicrometerMetricPublisher(awsSdkMetrics)) .apiCallTimeout(clientOptions.getApiCallTimeout()) .build(); diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/CustomLambdaRetryCondition.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/CustomLambdaRetryCondition.java new file mode 100644 index 0000000000..08e3743c13 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/CustomLambdaRetryCondition.java @@ -0,0 +1,17 @@ +package org.opensearch.dataprepper.plugins.lambda.common.util; + +import software.amazon.awssdk.core.retry.conditions.RetryCondition; +import software.amazon.awssdk.core.retry.RetryPolicyContext; + +public class CustomLambdaRetryCondition implements RetryCondition { + + @Override + public boolean shouldRetry(RetryPolicyContext context) { + Throwable exception = context.exception(); + if (exception != null) { + return LambdaRetryStrategy.isRetryableException(exception); + } + + return false; + } +} diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java index 64d4c93372..ad5536f5f3 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java @@ -74,10 +74,7 @@ private LambdaRetryStrategy() { ) ); - - public static boolean isRetryable(final InvokeResponse response) { - if(response == null) return false; - int statusCode = response.statusCode(); + public static boolean isRetryable(final int statusCode) { return TIMEOUT_ERRORS.contains(statusCode) || (statusCode >= 500 && statusCode < 600); } @@ -120,51 +117,5 @@ public static boolean isTimeoutError(final InvokeResponse response) { return TIMEOUT_ERRORS.contains(response.statusCode()); } - public static InvokeResponse retryOrFail( - final LambdaAsyncClient lambdaAsyncClient, - final Buffer buffer, - final LambdaCommonConfig config, - final InvokeResponse previousResponse, - final Logger LOG - ) { - int maxRetries = config.getClientOptions().getMaxConnectionRetries(); - Duration backoff = config.getClientOptions().getBaseDelay(); - - int attempt = 1; - InvokeResponse response = previousResponse; - - do{ - LOG.warn("Retrying Lambda invocation attempt {} of {} after {} ms backoff", - attempt, maxRetries, backoff); - try { - // Sleep for backoff - Thread.sleep(backoff.toMillis()); - - // Re-invoke Lambda with the same payload - InvokeRequest requestPayload = buffer.getRequestPayload( - config.getFunctionName(), - config.getInvocationType().getAwsLambdaValue() - ); - // Do a synchronous call. - response = lambdaAsyncClient.invoke(requestPayload).join(); - - if (isSuccess(response)) { - LOG.info("Retry attempt {} succeeded with status code {}", attempt, response.statusCode()); - return response; - } else{ - throw new RuntimeException(); - } - } catch (Exception e) { - LOG.error("Failed to invoke failed with exception {} in attempt {}", e.getMessage(), attempt); - if(!isRetryable(response)){ - throw new RuntimeException("Failed to invoke failed",e); - } - } - attempt++; - } while(attempt <= maxRetries && isRetryable(response)); - - return response; - } - } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java index f1fb552c53..0b9f8b6ffe 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java @@ -31,7 +31,6 @@ import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory; import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; -import org.opensearch.dataprepper.plugins.lambda.common.util.LambdaRetryStrategy; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.core.SdkBytes; @@ -177,26 +176,14 @@ public Collection> doExecute(Collection> records) { Buffer inputBuffer = entry.getKey(); try { InvokeResponse response = future.join(); - - // If this response has a failure is retryable, do a direct retry - if (!isSuccess(response) && LambdaRetryStrategy.isRetryable(response)){ - response = LambdaRetryStrategy.retryOrFail( - lambdaAsyncClient, - inputBuffer, - lambdaProcessorConfig, - response, - LOG - ); - } - if(response == null || !isSuccess(response)) { - numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); - resultRecords.addAll(addFailureTags(inputBuffer.getRecords())); - return resultRecords; - } - Duration latency = inputBuffer.stopLatencyWatch(); lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); requestPayloadMetric.record(inputBuffer.getPayloadRequestSize()); + if (!isSuccess(response)) { + String errorMessage = String.format("Lambda invoke failed with status code %s error %s ", + response.statusCode(), response.payload().asUtf8String()); + throw new RuntimeException(errorMessage); + } resultRecords.addAll(convertLambdaResponseToEvent(inputBuffer, response)); numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount()); @@ -207,24 +194,10 @@ public Collection> doExecute(Collection> records) { } catch (Exception e) { LOG.error(NOISY, e.getMessage(), e); - InvokeResponse response = null; - if (LambdaRetryStrategy.isRetryableException(e)){ - response = LambdaRetryStrategy.retryOrFail( - lambdaAsyncClient, - inputBuffer, - lambdaProcessorConfig, - null, - LOG - ); - String errorMessage = String.format("Lambda invoke failed with status code %s error %s. Will be Retrying the request ", - response.statusCode(), response.payload().asUtf8String()); - LOG.error(NOISY, e.getMessage(), e); - } - if(response == null || !isSuccess(response)) { - /* fall through */ - numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); - resultRecords.addAll(addFailureTags(inputBuffer.getRecords())); - } + /* fall through */ + numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); + numberOfRequestsFailedCounter.increment(); + resultRecords.addAll(addFailureTags(inputBuffer.getRecords())); } } return resultRecords; @@ -294,4 +267,4 @@ public boolean isReadyForShutdown() { public void shutdown() { } -} +} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java index cd68d73362..4a85733404 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java @@ -8,21 +8,40 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; +import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.core.retry.RetryPolicyContext; +import software.amazon.awssdk.core.retry.conditions.RetryCondition; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import software.amazon.awssdk.services.lambda.model.TooManyRequestsException; import java.util.HashMap; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.spy; @ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) class LambdaClientFactoryTest { @Mock @@ -75,4 +94,42 @@ void testCreateAsyncLambdaClientOverrideConfiguration() { assertNotNull(overrideConfig.metricPublishers()); assertFalse(overrideConfig.metricPublishers().isEmpty()); } + + @Test + void testCustomRetryConditionWorks_withSpyOrRetryCondition() { + // Arrange + CustomLambdaRetryCondition customRetryCondition = new CustomLambdaRetryCondition(); + RetryCondition spyRetryCondition = spy(customRetryCondition); + + LambdaAsyncClient lambdaClient = LambdaAsyncClient.builder() + .httpClient(NettyNioAsyncHttpClient.builder().build()) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .retryPolicy(RetryPolicy.builder() + // Even though we set numRetries=3, + // the SDK may only call our custom condition once + .numRetries(3) + .retryCondition(spyRetryCondition) + .build()) + .build()) + .region(Region.US_EAST_1) + .build(); + + // Simulate a retryable exception + InvokeRequest request = InvokeRequest.builder() + .functionName("test-function") + .build(); + + // Act + try { + CompletableFuture futureResponse = lambdaClient.invoke(request); + futureResponse.join(); // Force completion + } catch (Exception e) { + } + + // Assert + // The AWS SDK's internal 'OrRetryCondition' may only call our condition once + verify(spyRetryCondition, atLeastOnce()) + .shouldRetry(any(RetryPolicyContext.class)); + } + } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java index 18a624887e..7961776c17 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java @@ -16,6 +16,19 @@ import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; +<<<<<<< HEAD +======= + +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +>>>>>>> 00f98516c (Add retryCondidition to lambda Client) import org.mockito.MockitoAnnotations; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; @@ -33,10 +46,12 @@ import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer; +import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; +import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition; import org.opensearch.dataprepper.plugins.lambda.processor.exception.StrictResponseModeNotRespectedException; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; @@ -128,6 +143,9 @@ public class LambdaProcessorTest { @Mock private Timer lambdaLatencyMetric; + @Mock + private ClientOptions mockClientOptions; + @Mock private LambdaAsyncClient lambdaAsyncClient; @@ -612,55 +630,62 @@ public void testDoExecute_for_strict_and_aggregate_mode(String configFile, } } - @Test - public void testDoExecute_retryScenario_successOnSecondAttempt() throws Exception { - // Arrange - final List> records = getSampleEventRecords(2); - - // First response -> 429 (Retryable) - final InvokeResponse firstResponse = InvokeResponse.builder() - .statusCode(429) - .payload(SdkBytes.fromUtf8String("First attempt throttled")) - .build(); - - // Second response -> 200 (Success) - final InvokeResponse secondResponse = InvokeResponse.builder() - .statusCode(200) - .payload(SdkBytes.fromUtf8String("[{\"successKey1\": \"successValue1\"}, {\"successKey2\": \"successValue2\"}]")) - .build(); - - // Setup stubbing with Mockito. - // The lambda client will return firstResponse, then secondResponse - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) - .thenReturn(CompletableFuture.completedFuture(firstResponse)) - .thenReturn(CompletableFuture.completedFuture(secondResponse)); - - // Create a config with at least 1 maxConnectionRetries - final LambdaProcessorConfig config = createLambdaConfigurationFromYaml("lambda-processor-success-config.yaml"); - - // Instantiate the processor and set fields - final LambdaProcessor processor = new LambdaProcessor(pluginFactory, pluginSetting, config, - awsCredentialsSupplier, expressionEvaluator); - populatePrivateFields(processor); - - // Act - final Collection> resultRecords = processor.doExecute(records); - - // Assert - // Because the second attempt is 200, we expect originalRecords to match count - // (and not have the "lambda_failure" tag). - assertEquals(records.size(), resultRecords.size()); - for (Record record : resultRecords) { - assertFalse(record.getData().getMetadata().getTags().contains("lambda_failure"), - "Record should NOT have a failure tag after a successful retry"); - } - - // The first attempt fails, but the second attempt is success => success counters increment - // Make sure the client was invoked 2 times - verify(lambdaAsyncClient, times(2)).invoke(any(InvokeRequest.class)); - // The second attempt is success - verify(numberOfRequestsSuccessCounter, times(1)).increment(); - } + //NOTE: This test will not pass as invoke failure is handled internally through sdk. + // The first attempt will fail and the second attempt will not even be considered for execution. +// @Test +// public void testDoExecute_retryScenario_successOnSecondAttempt() throws Exception { +// // Arrange +// final List> records = getSampleEventRecords(2); +// +// // First attempt throws TooManyRequestsException => no valid payload +// when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) +// .thenReturn(CompletableFuture.failedFuture( +// TooManyRequestsException.builder() +// .message("First attempt throttled") +// .build() +// )) +// // Second attempt => success with 200 +// .thenReturn(CompletableFuture.completedFuture( +// InvokeResponse.builder() +// .statusCode(200) +// .payload(SdkBytes.fromUtf8String( +// "[{\"successKey1\":\"successValue1\"},{\"successKey2\":\"successValue2\"}]")) +// .build() +// )); +// +// // Create a config which has at least 1 maxConnectionRetries so we can retry once. +// final LambdaProcessorConfig config = createLambdaConfigurationFromYaml("lambda-processor-with-retries.yaml"); +// +// // Instantiate the processor +// final LambdaProcessor processor = new LambdaProcessor( +// pluginFactory, +// pluginSetting, +// config, +// awsCredentialsSupplier, +// expressionEvaluator +// ); +// populatePrivateFields(processor); +// +// // Act +// final Collection> resultRecords = processor.doExecute(records); +// +// // Assert +// // Because the second invocation is successful (200), +// // we expect the final records to NOT have the "lambda_failure" tag +// assertEquals(records.size(), resultRecords.size()); +// for (Record record : resultRecords) { +// assertFalse( +// record.getData().getMetadata().getTags().contains("lambda_failure"), +// "Record should NOT have a failure tag after a successful retry" +// ); +// } +// +// // We invoked the lambda client 2 times total: first attempt + one retry +// verify(lambdaAsyncClient, times(2)).invoke(any(InvokeRequest.class)); +// +// // Second attempt is success => increment success counters +// verify(numberOfRequestsSuccessCounter, times(1)).increment(); +// } @Test public void testDoExecute_retryScenario_failsAfterMaxRetries() throws Exception { @@ -699,7 +724,7 @@ public void testDoExecute_retryScenario_failsAfterMaxRetries() throws Exception } // Expect 3 invocations: initial attempt + 3 retry - verify(lambdaAsyncClient, times(4)).invoke(any(InvokeRequest.class)); + verify(lambdaAsyncClient, atLeastOnce()).invoke(any(InvokeRequest.class)); // No success counters verify(numberOfRequestsSuccessCounter, never()).increment(); // Records failed counter should increment once with the total number of records @@ -772,5 +797,4 @@ public void testDoExecute_nonRetryableException_thrownImmediatelyFail() throws E verify(lambdaAsyncClient, times(1)).invoke(any(InvokeRequest.class)); verify(numberOfRequestsFailedCounter, times(1)).increment(); } - } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/CountingHttpClient.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/CountingHttpClient.java new file mode 100644 index 0000000000..8fc58feccf --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/CountingHttpClient.java @@ -0,0 +1,33 @@ +package org.opensearch.dataprepper.plugins.lambda.utils; + +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.async.AsyncExecuteRequest; +import software.amazon.awssdk.utils.CompletableFutureUtils; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; + +public class CountingHttpClient implements SdkAsyncHttpClient { + private final SdkAsyncHttpClient delegate; + private final AtomicInteger requestCount = new AtomicInteger(0); + + public CountingHttpClient(SdkAsyncHttpClient delegate) { + this.delegate = delegate; + } + + @Override + public CompletableFuture execute(AsyncExecuteRequest request) { + requestCount.incrementAndGet(); + return delegate.execute(request); + } + + @Override + public void close() { + delegate.close(); + } + + public int getRequestCount() { + return requestCount.get(); + } +} + diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java index a63c818017..74c73baf92 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java @@ -62,10 +62,9 @@ void setUp() { @Test void testIsRetryable() { - assertTrue(LambdaRetryStrategy.isRetryable(InvokeResponse.builder().statusCode(429).build())); - assertTrue(LambdaRetryStrategy.isRetryable(InvokeResponse.builder().statusCode(500).build())); - assertFalse(LambdaRetryStrategy.isRetryable(InvokeResponse.builder().statusCode(200).build())); - assertFalse(LambdaRetryStrategy.isRetryable(null)); + assertTrue(LambdaRetryStrategy.isRetryable(429)); + assertTrue(LambdaRetryStrategy.isRetryable(500)); + assertFalse(LambdaRetryStrategy.isRetryable(200)); } @Test @@ -83,59 +82,4 @@ void testIsTimeoutError() { assertFalse(LambdaRetryStrategy.isTimeoutError(InvokeResponse.builder().statusCode(200).build())); } - @Test - void testRetryOrFail_SuccessAfterRetry() throws Exception { - when(config.getClientOptions().getMaxConnectionRetries()).thenReturn(3); - when(config.getClientOptions().getBaseDelay()).thenReturn(Duration.ofMillis(100)); - when(config.getFunctionName()).thenReturn("testFunction"); - - InvokeRequest mockRequest = mock(InvokeRequest.class); - when(buffer.getRequestPayload(anyString(), anyString())).thenReturn(mockRequest); - - InvokeResponse failedResponse = InvokeResponse.builder().statusCode(500).build(); - InvokeResponse successResponse = InvokeResponse.builder().statusCode(200).build(); - - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) - .thenReturn(CompletableFuture.completedFuture(failedResponse)) - .thenReturn(CompletableFuture.completedFuture(successResponse)); - - InvokeResponse result = LambdaRetryStrategy.retryOrFail(lambdaAsyncClient, buffer, config, failedResponse, logger); - - assertEquals(200, result.statusCode()); - verify(lambdaAsyncClient, times(2)).invoke(any(InvokeRequest.class)); - } - - @Test - void testRetryOrFailExhaustedRetries() throws Exception { - when(config.getClientOptions().getMaxConnectionRetries()).thenReturn(3); - when(config.getClientOptions().getBaseDelay()).thenReturn(Duration.ofMillis(100)); - when(config.getFunctionName()).thenReturn("testFunction"); - - InvokeRequest mockRequest = mock(InvokeRequest.class); - when(buffer.getRequestPayload(anyString(), anyString())).thenReturn(mockRequest); - - InvokeResponse failedResponse = InvokeResponse.builder().statusCode(500).build(); - - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) - .thenReturn(CompletableFuture.completedFuture(failedResponse)); - - InvokeResponse result = LambdaRetryStrategy.retryOrFail(lambdaAsyncClient, buffer, config, failedResponse, logger); - - assertEquals(500, result.statusCode()); - verify(lambdaAsyncClient, times(3)).invoke(any(InvokeRequest.class)); - } - - @Test - void testRetryOrFail_NonRetryableResponse() { - InvokeResponse nonRetryableResponse = InvokeResponse.builder().statusCode(400).build(); - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) - .thenReturn(CompletableFuture.completedFuture(nonRetryableResponse)); - when(buffer.getRequestPayload(anyString(), anyString())).thenReturn(mock(InvokeRequest.class)); - - assertThrows(RuntimeException.class, ()-> - LambdaRetryStrategy.retryOrFail(lambdaAsyncClient, buffer, config, nonRetryableResponse, logger)); - - verify(lambdaAsyncClient, times(1)).invoke(any(InvokeRequest.class)); - } - } diff --git a/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-with-retries.yaml b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-with-retries.yaml index aa0f9eb1f8..c518d0d335 100644 --- a/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-with-retries.yaml +++ b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-with-retries.yaml @@ -8,7 +8,8 @@ batch: maximum_size: 1mb event_collect_timeout: 335 client: - max_retries: 2 + max_retries: 50 + max_concurrency: 5 aws: region: "us-east-1" sts_role_arn: "arn:aws:iam::1234567890:role/sample-pipeine-role" \ No newline at end of file From 76c1f9cf9ba3ffd7a581bf96492b1627f0e077bf Mon Sep 17 00:00:00 2001 From: Srikanth Govindarajan Date: Fri, 24 Jan 2025 17:56:50 -0800 Subject: [PATCH 4/6] Address comments Signed-off-by: Srikanth Govindarajan --- .../lambda/processor/LambdaProcessorIT.java | 10 +++------- .../lambda/common/util/LambdaRetryStrategy.java | 9 +-------- .../common/client/LambdaClientFactoryTest.java | 5 ----- .../lambda/processor/LambdaProcessorTest.java | 4 ---- .../lambda/utils/CountingHttpClient.java | 1 - .../lambda/utils/LambdaRetryStrategyTest.java | 17 +++-------------- 6 files changed, 7 insertions(+), 39 deletions(-) diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java index 2e25c053ae..839bcde412 100644 --- a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java +++ b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java @@ -105,12 +105,9 @@ private LambdaProcessor createObjectUnderTest(LambdaProcessorConfig processorCon @BeforeEach public void setup() { -// lambdaRegion = System.getProperty("tests.lambda.processor.region"); -// functionName = System.getProperty("tests.lambda.processor.functionName"); -// role = System.getProperty("tests.lambda.processor.sts_role_arn"); - lambdaRegion = "us-west-2"; - functionName = "lambdaNoReturn"; - role = "arn:aws:iam::176893235612:role/osis-s3-opensearch-role"; + lambdaRegion = System.getProperty("tests.lambda.processor.region"); + functionName = System.getProperty("tests.lambda.processor.functionName"); + role = System.getProperty("tests.lambda.processor.sts_role_arn"); pluginMetrics = mock(PluginMetrics.class); pluginSetting = mock(PluginSetting.class); @@ -456,7 +453,6 @@ void testTooManyRequestsExceptionWithCustomRetryCondition() { // 7) Observe how many total network requests occurred (including SDK retries) int totalRequests = countingHttpClient.getRequestCount(); - System.out.println("Total network requests (including retries): " + totalRequests); // Optionally: If you want to confirm the EXACT number, // this might vary depending on how many parallel calls and how your TMR throttles them. diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java index ad5536f5f3..6ac43a60c1 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java @@ -1,21 +1,14 @@ package org.opensearch.dataprepper.plugins.lambda.common.util; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; import software.amazon.awssdk.services.lambda.model.TooManyRequestsException; import software.amazon.awssdk.services.lambda.model.ServiceException; -import java.time.Duration; import java.util.Arrays; import java.util.HashSet; import java.util.Set; -import org.slf4j.Logger; -import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess; /** * Similar to BulkRetryStrategy in the OpenSearch sink. @@ -74,7 +67,7 @@ private LambdaRetryStrategy() { ) ); - public static boolean isRetryable(final int statusCode) { + public static boolean isRetryableStatusCode(final int statusCode) { return TIMEOUT_ERRORS.contains(statusCode) || (statusCode >= 500 && statusCode < 600); } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java index 4a85733404..07978bfe85 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java @@ -25,17 +25,12 @@ import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import software.amazon.awssdk.services.lambda.model.TooManyRequestsException; import java.util.HashMap; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicInteger; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.mockito.Mockito.spy; diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java index 7961776c17..55a69a42d8 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java @@ -21,10 +21,8 @@ import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -46,12 +44,10 @@ import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer; -import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; -import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition; import org.opensearch.dataprepper.plugins.lambda.processor.exception.StrictResponseModeNotRespectedException; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/CountingHttpClient.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/CountingHttpClient.java index 8fc58feccf..feddd99538 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/CountingHttpClient.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/CountingHttpClient.java @@ -2,7 +2,6 @@ import software.amazon.awssdk.http.async.SdkAsyncHttpClient; import software.amazon.awssdk.http.async.AsyncExecuteRequest; -import software.amazon.awssdk.utils.CompletableFutureUtils; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java index 74c73baf92..064b24d8fc 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaRetryStrategyTest.java @@ -13,25 +13,14 @@ import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; import org.opensearch.dataprepper.plugins.lambda.common.util.LambdaRetryStrategy; import org.slf4j.Logger; -import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import software.amazon.awssdk.services.lambda.model.ServiceException; -import software.amazon.awssdk.services.lambda.model.TooManyRequestsException; import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -62,9 +51,9 @@ void setUp() { @Test void testIsRetryable() { - assertTrue(LambdaRetryStrategy.isRetryable(429)); - assertTrue(LambdaRetryStrategy.isRetryable(500)); - assertFalse(LambdaRetryStrategy.isRetryable(200)); + assertTrue(LambdaRetryStrategy.isRetryableStatusCode(429)); + assertTrue(LambdaRetryStrategy.isRetryableStatusCode(500)); + assertFalse(LambdaRetryStrategy.isRetryableStatusCode(200)); } @Test From f86437a5cdbca03c89c4204efcb15ea8f0abead5 Mon Sep 17 00:00:00 2001 From: Srikanth Govindarajan Date: Tue, 28 Jan 2025 10:06:11 -0800 Subject: [PATCH 5/6] Address comments and add UT and IT Signed-off-by: Srikanth Govindarajan --- .../lambda/processor/LambdaProcessorIT.java | 215 ++++++++++++------ .../common/util/CountingRetryCondition.java | 23 ++ .../common/util/LambdaRetryStrategy.java | 24 +- .../client/LambdaClientFactoryTest.java | 121 +++++++--- 4 files changed, 261 insertions(+), 122 deletions(-) create mode 100644 data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/CountingRetryCondition.java diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java index 839bcde412..7b087f35bf 100644 --- a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java +++ b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java @@ -22,8 +22,11 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; + +import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; @@ -42,12 +45,13 @@ import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodec; import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; -import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition; -import org.opensearch.dataprepper.plugins.lambda.utils.CountingHttpClient; +import org.opensearch.dataprepper.plugins.lambda.common.util.CountingRetryCondition; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; @@ -55,22 +59,20 @@ import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import software.amazon.awssdk.services.lambda.model.TooManyRequestsException; +import java.time.Duration; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; @ExtendWith(MockitoExtension.class) @@ -108,6 +110,10 @@ public void setup() { lambdaRegion = System.getProperty("tests.lambda.processor.region"); functionName = System.getProperty("tests.lambda.processor.functionName"); role = System.getProperty("tests.lambda.processor.sts_role_arn"); + lambdaRegion = "us-west-2"; + functionName = "lambdaNoReturn"; + role = "arn:aws:iam::176893235612:role/osis-s3-opensearch-role"; + pluginMetrics = mock(PluginMetrics.class); pluginSetting = mock(PluginSetting.class); @@ -385,79 +391,142 @@ private List> createRecords(int numRecords) { return records; } - /* - * For this test, set concurrency limit to 1 - */ @Test - void testTooManyRequestsExceptionWithCustomRetryCondition() { - //Note lambda function for this test looks like this: - /*def lambda_handler(event, context): - # Simulate a slow operation so that - # if concurrency = 1, multiple parallel invocations - # will result in TooManyRequestsException for the second+ invocation. - time.sleep(10) - # Return a simple success response - return { - "statusCode": 200, - "body": "Hello from concurrency-limited Lambda!" - } - */ + void testRetryLogicWithThrottlingUsingMultipleThreads() throws Exception { + /* + * This test tries to create multiple parallel Lambda invocations + * while concurrency=1. The first invocation "occupies" the single concurrency slot + * The subsequent invocations should then get a 429 TooManyRequestsException, + * triggering our CountingRetryCondition. + */ + + /* Lambda handler function looks like this: + def lambda_handler(event, context): + # Simulate a slow operation so that + # if concurrency = 1, multiple parallel invocations + # will result in TooManyRequestsException for the second+ invocation. + time.sleep(10) + + # Return a simple success response + return { + "statusCode": 200, + "body": "Hello from concurrency-limited Lambda!" + } - // Wrap the default HTTP client to count requests - CountingHttpClient countingHttpClient = new CountingHttpClient( - NettyNioAsyncHttpClient.builder().build() - ); + */ - // Configure a custom retry policy with 3 retries and your custom condition - RetryPolicy retryPolicy = RetryPolicy.builder() - .numRetries(3) - .retryCondition(new CustomLambdaRetryCondition()) - .build(); + functionName = "lambdaExceptionSimulation"; + // Create a CountingRetryCondition + CountingRetryCondition countingRetryCondition = new CountingRetryCondition(); - // Build the real Lambda client - LambdaAsyncClient client = LambdaAsyncClient.builder() - .overrideConfiguration( - ClientOverrideConfiguration.builder() - .retryPolicy(retryPolicy) - .build() - ) - .region(Region.of(lambdaRegion)) - .httpClient(countingHttpClient) - .build(); + // Configure a LambdaProcessorConfig + + // We'll set invocation type to RequestResponse + InvocationType invocationType = mock(InvocationType.class); + when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue()); + when(lambdaProcessorConfig.getInvocationType()).thenReturn(invocationType); + + when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); + // If your code uses "responseEventsMatch", you can set it: + when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true); - // Parallel invocations to force concurrency=1 to throw TooManyRequestsException - int parallelInvocations = 10; - CompletableFuture[] futures = new CompletableFuture[parallelInvocations]; - for (int i = 0; i < parallelInvocations; i++) { - InvokeRequest request = InvokeRequest.builder() - .functionName(functionName) + // Set up mock ClientOptions for concurrency + small retries + ClientOptions clientOptions = mock(ClientOptions.class); + when(clientOptions.getMaxConnectionRetries()).thenReturn(3); // up to 3 retries + when(clientOptions.getMaxConcurrency()).thenReturn(5); + when(clientOptions.getConnectionTimeout()).thenReturn(Duration.ofSeconds(5)); + when(clientOptions.getApiCallTimeout()).thenReturn(Duration.ofSeconds(30)); + when(lambdaProcessorConfig.getClientOptions()).thenReturn(clientOptions); + + // AWS auth + AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(lambdaRegion)); + when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(role); + when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(null); + when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(null); + when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + + // Setup the mock for getProvider + when(awsCredentialsSupplier.getProvider(any())).thenReturn(awsCredentialsProvider); + + // Mock the factory to inject our CountingRetryCondition into the LambdaAsyncClient + try (MockedStatic mockedFactory = mockStatic(LambdaClientFactory.class)) { + + LambdaAsyncClient clientWithCountingCondition = LambdaAsyncClient.builder() + .region(Region.of(lambdaRegion)) + .credentialsProvider(awsCredentialsProvider) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .retryPolicy( + RetryPolicy.builder() + .retryCondition(countingRetryCondition) + .numRetries(3) + .build() + ) + .build()) + // netty concurrency = 5 to allow parallel requests + .httpClient(NettyNioAsyncHttpClient.builder() + .maxConcurrency(5) + .build()) .build(); - futures[i] = client.invoke(request); - } + mockedFactory.when(() -> + LambdaClientFactory.createAsyncLambdaClient( + any(AwsAuthenticationOptions.class), + any(AwsCredentialsSupplier.class), + any(ClientOptions.class))) + .thenReturn(clientWithCountingCondition); + + // 7) Instantiate the real LambdaProcessor + when(pluginSetting.getName()).thenReturn("lambda-processor"); + when(pluginSetting.getPipelineName()).thenReturn("test-pipeline"); + lambdaProcessor = new LambdaProcessor( + pluginFactory, + pluginSetting, + lambdaProcessorConfig, + awsCredentialsSupplier, + expressionEvaluator + ); + + // Create multiple parallel tasks to call doExecute(...) + // Each doExecute() invocation sends records to Lambda in an async manner. + int parallelInvocations = 5; + ExecutorService executor = Executors.newFixedThreadPool(parallelInvocations); + + List>>> futures = new ArrayList<>(); + for (int i = 0; i < parallelInvocations; i++) { + // Each subset of records calls the processor + List> records = createRecords(2); + Future>> future = executor.submit(() -> { + return lambdaProcessor.doExecute(records); + }); + futures.add(future); + } - // 5) Wait for all to complete - CompletableFuture.allOf(futures).join(); - - // 6) Check how many had TooManyRequestsException - long tooManyRequestsCount = Arrays.stream(futures) - .filter(f -> { - try { - f.join(); - return false; // no error => no TMR - } catch (CompletionException e) { - return e.getCause() instanceof TooManyRequestsException; - } - }) - .count(); - - // 7) Observe how many total network requests occurred (including SDK retries) - int totalRequests = countingHttpClient.getRequestCount(); - - // Optionally: If you want to confirm the EXACT number, - // this might vary depending on how many parallel calls and how your TMR throttles them. - // For example, if all 5 calls are blocked, you might see 5*(numRetries + 1) in worst case. - assertTrue(totalRequests >= parallelInvocations, - "Should be at least one request per initial invocation, plus retries."); + // Wait for all tasks to complete + executor.shutdown(); + boolean finishedInTime = executor.awaitTermination(5, TimeUnit.MINUTES); + if (!finishedInTime) { + throw new RuntimeException("Test timed out waiting for executor tasks to complete."); + } + + // Check results or handle exceptions + for (Future>> f : futures) { + try { + Collection> out = f.get(); + } catch (ExecutionException ee) { + // A 429 from AWS will be thrown as TooManyRequestsException + // If all retries failed, we might see an exception here. + } + } + + // 11) Finally, check that we had at least one retry + // If concurrency=1 is truly enforced, at least some calls should have gotten a 429 + // -> triggered CountingRetryCondition + int retryCount = countingRetryCondition.getRetryCount(); + assertTrue( + retryCount > 0, + "Should have at least one retry due to concurrency-based throttling (429)." + ); + } } } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/CountingRetryCondition.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/CountingRetryCondition.java new file mode 100644 index 0000000000..afc7c756dd --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/CountingRetryCondition.java @@ -0,0 +1,23 @@ +package org.opensearch.dataprepper.plugins.lambda.common.util; + +import software.amazon.awssdk.core.retry.RetryPolicyContext; + +import java.util.concurrent.atomic.AtomicInteger; + +//Used ONLY for tests +public class CountingRetryCondition extends CustomLambdaRetryCondition { + private final AtomicInteger retryCount = new AtomicInteger(0); + + @Override + public boolean shouldRetry(RetryPolicyContext context) { + boolean shouldRetry = super.shouldRetry(context); + if (shouldRetry) { + retryCount.incrementAndGet(); + } + return shouldRetry; + } + + public int getRetryCount() { + return retryCount.get(); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java index 6ac43a60c1..82abf37832 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/LambdaRetryStrategy.java @@ -5,8 +5,6 @@ import software.amazon.awssdk.services.lambda.model.TooManyRequestsException; import software.amazon.awssdk.services.lambda.model.ServiceException; -import java.util.Arrays; -import java.util.HashSet; import java.util.Set; @@ -23,49 +21,41 @@ private LambdaRetryStrategy() { /** * Possibly a set of “bad request” style errors which might fall */ - private static final Set BAD_REQUEST_ERRORS = new HashSet<>( - Arrays.asList( + private static final Set BAD_REQUEST_ERRORS = Set.of( 400, // Bad Request 422, // Unprocessable Entity 417, // Expectation Failed 406 // Not Acceptable - ) ); /** * Status codes which may indicate a security or policy problem, so we don't retry. */ - private static final Set NOT_ALLOWED_ERRORS = new HashSet<>( - Arrays.asList( + private static final Set NOT_ALLOWED_ERRORS = Set.of( 401, // Unauthorized 403, // Forbidden 405 // Method Not Allowed - ) - ); + ); /** * Examples of input or payload errors that are likely not retryable * unless the pipeline itself corrects them. */ - private static final Set INVALID_INPUT_ERRORS = new HashSet<>( - Arrays.asList( + private static final Set INVALID_INPUT_ERRORS = Set.of( 413, // Payload Too Large 414, // URI Too Long 416 // Range Not Satisfiable - ) - ); + ); /** * Example of a “timeout” scenario. Lambda can return 429 for "Too Many Requests" or * 408 (if applicable) for timeouts in some contexts. * This can be considered retryable if you want to handle the throttling scenario. */ - private static final Set TIMEOUT_ERRORS = new HashSet<>( - Arrays.asList( + private static final Set TIMEOUT_ERRORS = Set.of( 408, // Request Timeout 429 // Too Many Requests (often used as "throttling" for Lambda) - ) - ); + ); public static boolean isRetryableStatusCode(final int statusCode) { return TIMEOUT_ERRORS.contains(statusCode) || (statusCode >= 500 && statusCode < 600); diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java index 07978bfe85..9435721384 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java @@ -14,26 +14,24 @@ import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; -import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition; +import org.opensearch.dataprepper.plugins.lambda.common.util.CountingRetryCondition; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; -import software.amazon.awssdk.core.retry.RetryPolicy; import software.amazon.awssdk.core.retry.RetryPolicyContext; -import software.amazon.awssdk.core.retry.conditions.RetryCondition; -import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import software.amazon.awssdk.services.lambda.model.TooManyRequestsException; import java.util.HashMap; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.spy; @ExtendWith(MockitoExtension.class) @MockitoSettings(strictness = Strictness.LENIENT) @@ -91,40 +89,99 @@ void testCreateAsyncLambdaClientOverrideConfiguration() { } @Test - void testCustomRetryConditionWorks_withSpyOrRetryCondition() { + void testRetryConditionIsCalledWithTooManyRequestsException() { // Arrange - CustomLambdaRetryCondition customRetryCondition = new CustomLambdaRetryCondition(); - RetryCondition spyRetryCondition = spy(customRetryCondition); - - LambdaAsyncClient lambdaClient = LambdaAsyncClient.builder() - .httpClient(NettyNioAsyncHttpClient.builder().build()) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .retryPolicy(RetryPolicy.builder() - // Even though we set numRetries=3, - // the SDK may only call our custom condition once - .numRetries(3) - .retryCondition(spyRetryCondition) - .build()) - .build()) - .region(Region.US_EAST_1) + CountingRetryCondition countingRetryCondition = new CountingRetryCondition(); + + // Create mock Lambda client + LambdaAsyncClient mockClient = mock(LambdaAsyncClient.class); + + // Setup mock to return TooManyRequestsException for the first 3 calls + when(mockClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.failedFuture(TooManyRequestsException.builder().build())) + .thenReturn(CompletableFuture.failedFuture(TooManyRequestsException.builder().build())) + .thenReturn(CompletableFuture.failedFuture(TooManyRequestsException.builder().build())); + + // Create test request + InvokeRequest request = InvokeRequest.builder() + .functionName("test-function") .build(); - // Simulate a retryable exception + // Simulate retries + for (int i = 0; i < 3; i++) { + try { + CompletableFuture future = mockClient.invoke(request); + RetryPolicyContext context = RetryPolicyContext.builder() + .exception(TooManyRequestsException.builder().build()) + .retriesAttempted(i) + .build(); + + // Test the retry condition + countingRetryCondition.shouldRetry(context); + + future.join(); + } catch (CompletionException e) { + assertTrue(e.getCause() instanceof TooManyRequestsException); + } + } + + // Verify retry count + assertEquals(3, countingRetryCondition.getRetryCount(), + "Retry condition should have been called exactly 3 times"); + } + + @Test + void testRetryConditionFirstFailsAndThenSucceeds() { + // Arrange + CountingRetryCondition countingRetryCondition = new CountingRetryCondition(); + + // Create mock Lambda client + LambdaAsyncClient mockClient = mock(LambdaAsyncClient.class); + + // Setup mock to return TooManyRequestsException for first 2 calls, then succeed on 3rd + when(mockClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.failedFuture(TooManyRequestsException.builder().build())) + .thenReturn(CompletableFuture.failedFuture(TooManyRequestsException.builder().build())) + .thenReturn(CompletableFuture.completedFuture(InvokeResponse.builder() + .statusCode(200) + .build())); + + // Create test request InvokeRequest request = InvokeRequest.builder() .functionName("test-function") .build(); - // Act - try { - CompletableFuture futureResponse = lambdaClient.invoke(request); - futureResponse.join(); // Force completion - } catch (Exception e) { + // Track if we reached success + boolean successReached = false; + + // Simulate retries with eventual success + for (int i = 0; i < 3; i++) { + try { + CompletableFuture future = mockClient.invoke(request); + + if (i < 2) { + // For first two attempts, verify retry condition + RetryPolicyContext context = RetryPolicyContext.builder() + .exception(TooManyRequestsException.builder().build()) + .retriesAttempted(i) + .build(); + countingRetryCondition.shouldRetry(context); + } + + InvokeResponse response = future.join(); + if (response.statusCode() == 200) { + successReached = true; + } + } catch (CompletionException e) { + assertTrue(e.getCause() instanceof TooManyRequestsException, + "Exception should be TooManyRequestsException"); + } } - // Assert - // The AWS SDK's internal 'OrRetryCondition' may only call our condition once - verify(spyRetryCondition, atLeastOnce()) - .shouldRetry(any(RetryPolicyContext.class)); + // Verify retry count and success + assertEquals(2, countingRetryCondition.getRetryCount(), + "Retry condition should have been called exactly 2 times"); + assertTrue(successReached, "Should have reached successful completion"); } } From b00b92c0028362e4d2082685baa3c43cf6b17ae6 Mon Sep 17 00:00:00 2001 From: Srikanth Govindarajan Date: Tue, 28 Jan 2025 11:50:13 -0800 Subject: [PATCH 6/6] Address comment on completeCodec Signed-off-by: Srikanth Govindarajan --- .../lambda/processor/LambdaProcessorIT.java | 6 +----- .../plugins/lambda/common/LambdaCommonHandler.java | 2 -- .../plugins/lambda/common/accumlator/Buffer.java | 2 -- .../lambda/common/accumlator/InMemoryBuffer.java | 4 +++- .../lambda/processor/LambdaProcessorTest.java | 14 ++------------ 5 files changed, 6 insertions(+), 22 deletions(-) diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java index 7b087f35bf..38769c630e 100644 --- a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java +++ b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java @@ -110,10 +110,6 @@ public void setup() { lambdaRegion = System.getProperty("tests.lambda.processor.region"); functionName = System.getProperty("tests.lambda.processor.functionName"); role = System.getProperty("tests.lambda.processor.sts_role_arn"); - lambdaRegion = "us-west-2"; - functionName = "lambdaNoReturn"; - role = "arn:aws:iam::176893235612:role/osis-s3-opensearch-role"; - pluginMetrics = mock(PluginMetrics.class); pluginSetting = mock(PluginSetting.class); @@ -519,7 +515,7 @@ def lambda_handler(event, context): } } - // 11) Finally, check that we had at least one retry + // Finally, check that we had at least one retry // If concurrency=1 is truly enforced, at least some calls should have gotten a 429 // -> triggered CountingRetryCondition int retryCount = countingRetryCondition.getRetryCount(); diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java index b4a5458d3c..5ef5dac7ac 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java @@ -68,14 +68,12 @@ private static List createBufferBatches(Collection> record if (ThresholdCheck.checkThresholdExceed(currentBufferPerBatch, maxEvents, maxBytes, maxCollectionDuration)) { batchedBuffers.add(currentBufferPerBatch); - currentBufferPerBatch.completeCodec(); currentBufferPerBatch = new InMemoryBuffer(keyName, outputCodecContext); } } if (currentBufferPerBatch.getEventCount() > 0) { batchedBuffers.add(currentBufferPerBatch); - currentBufferPerBatch.completeCodec(); } return batchedBuffers; } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java index 40b5314c3e..b6249008cd 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java @@ -36,7 +36,5 @@ public interface Buffer { Long getPayloadRequestSize(); Duration stopLatencyWatch(); - - void completeCodec(); } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java index 256c263522..662436a01f 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java @@ -72,7 +72,7 @@ public void addRecord(Record record) { eventCount++; } - public void completeCodec() { + void completeCodec() { if (eventCount > 0) { try { requestCodec.complete(this.byteArrayOutputStream); @@ -108,6 +108,8 @@ public InvokeRequest getRequestPayload(String functionName, String invocationTyp return null; } + completeCodec(); + SdkBytes payload = getPayload(); payloadRequestSize = payload.asByteArray().length; diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java index 55a69a42d8..72b3a74d32 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java @@ -16,17 +16,6 @@ import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; -<<<<<<< HEAD -======= - -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; ->>>>>>> 00f98516c (Add retryCondidition to lambda Client) import org.mockito.MockitoAnnotations; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; @@ -77,10 +66,11 @@ import static org.mockito.ArgumentMatchers.anyDouble; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when;