lambda processor should retry for certain class of exceptions #5320

Merged
6 commits merged on Jan 29, 2025
@@ -22,8 +22,11 @@
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
@@ -42,24 +45,34 @@
import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodec;
import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType;
import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.CountingRetryCondition;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

@ExtendWith(MockitoExtension.class)
@@ -373,4 +386,143 @@ private List<Record<Event>> createRecords(int numRecords) {
}
return records;
}

@Test
void testRetryLogicWithThrottlingUsingMultipleThreads() throws Exception {
/*
* This test tries to create multiple parallel Lambda invocations
* while concurrency=1. The first invocation "occupies" the single concurrency slot
* The subsequent invocations should then get a 429 TooManyRequestsException,
* triggering our CountingRetryCondition.
*/

/* Lambda handler function looks like this:
def lambda_handler(event, context):
# Simulate a slow operation so that
# if concurrency = 1, multiple parallel invocations
# will result in TooManyRequestsException for the second+ invocation.
time.sleep(10)

# Return a simple success response
return {
"statusCode": 200,
"body": "Hello from concurrency-limited Lambda!"
}

*/

functionName = "lambdaExceptionSimulation";
// Create a CountingRetryCondition
CountingRetryCondition countingRetryCondition = new CountingRetryCondition();

// Configure a LambdaProcessorConfig

// We'll set invocation type to RequestResponse
InvocationType invocationType = mock(InvocationType.class);
when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue());
when(lambdaProcessorConfig.getInvocationType()).thenReturn(invocationType);

when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName);
// If your code uses "responseEventsMatch", you can set it:
when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true);

// Set up mock ClientOptions for concurrency + small retries
ClientOptions clientOptions = mock(ClientOptions.class);
when(clientOptions.getMaxConnectionRetries()).thenReturn(3); // up to 3 retries
when(clientOptions.getMaxConcurrency()).thenReturn(5);
when(clientOptions.getConnectionTimeout()).thenReturn(Duration.ofSeconds(5));
when(clientOptions.getApiCallTimeout()).thenReturn(Duration.ofSeconds(30));
when(lambdaProcessorConfig.getClientOptions()).thenReturn(clientOptions);

// AWS auth
AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(lambdaRegion));
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(role);
when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(null);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(null);
when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);

// Setup the mock for getProvider
when(awsCredentialsSupplier.getProvider(any())).thenReturn(awsCredentialsProvider);

// Mock the factory to inject our CountingRetryCondition into the LambdaAsyncClient
try (MockedStatic<LambdaClientFactory> mockedFactory = mockStatic(LambdaClientFactory.class)) {

LambdaAsyncClient clientWithCountingCondition = LambdaAsyncClient.builder()
.region(Region.of(lambdaRegion))
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(ClientOverrideConfiguration.builder()
.retryPolicy(
RetryPolicy.builder()
.retryCondition(countingRetryCondition)
.numRetries(3)
.build()
)
.build())
// netty concurrency = 5 to allow parallel requests
.httpClient(NettyNioAsyncHttpClient.builder()
.maxConcurrency(5)
.build())
.build();

mockedFactory.when(() ->
LambdaClientFactory.createAsyncLambdaClient(
any(AwsAuthenticationOptions.class),
any(AwsCredentialsSupplier.class),
any(ClientOptions.class)))
.thenReturn(clientWithCountingCondition);

// Instantiate the real LambdaProcessor
when(pluginSetting.getName()).thenReturn("lambda-processor");
when(pluginSetting.getPipelineName()).thenReturn("test-pipeline");
lambdaProcessor = new LambdaProcessor(
pluginFactory,
pluginSetting,
lambdaProcessorConfig,
awsCredentialsSupplier,
expressionEvaluator
);

// Create multiple parallel tasks to call doExecute(...)
// Each doExecute() invocation sends records to Lambda in an async manner.
int parallelInvocations = 5;
ExecutorService executor = Executors.newFixedThreadPool(parallelInvocations);

List<Future<Collection<Record<Event>>>> futures = new ArrayList<>();
for (int i = 0; i < parallelInvocations; i++) {
// Each subset of records calls the processor
List<Record<Event>> records = createRecords(2);
Future<Collection<Record<Event>>> future = executor.submit(() -> {
return lambdaProcessor.doExecute(records);
});
futures.add(future);
}

// Wait for all tasks to complete
executor.shutdown();
boolean finishedInTime = executor.awaitTermination(5, TimeUnit.MINUTES);
if (!finishedInTime) {
throw new RuntimeException("Test timed out waiting for executor tasks to complete.");
}

// Check results or handle exceptions
for (Future<Collection<Record<Event>>> f : futures) {
try {
Collection<Record<Event>> out = f.get();
} catch (ExecutionException ee) {
// A 429 from AWS will be thrown as TooManyRequestsException
// If all retries failed, we might see an exception here.
}
}

// Finally, check that we had at least one retry
// If concurrency=1 is truly enforced, at least some calls should have gotten a 429
// -> triggered CountingRetryCondition
int retryCount = countingRetryCondition.getRetryCount();
assertTrue(
retryCount > 0,
"Should have at least one retry due to concurrency-based throttling (429)."
);
}
}
}
@@ -36,6 +36,9 @@ private LambdaCommonHandler() {
}

public static boolean isSuccess(InvokeResponse response) {
if(response == null) {
return false;
}
int statusCode = response.statusCode();
return statusCode >= 200 && statusCode < 300;
}
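
The null guard added above matters when an invocation never produces a response at all. A minimal caller-side sketch of how that can be handled; the wrapper class, method, and the import package for LambdaCommonHandler are assumptions, only isSuccess itself comes from this PR:

import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;

// Hypothetical caller sketch; only LambdaCommonHandler.isSuccess is from this PR.
class IsSuccessUsageSketch {

    // Treats a null response (e.g. a failed async invocation that never produced
    // a result) the same as a non-2xx status code instead of throwing an NPE.
    boolean handleResponse(InvokeResponse response) {
        if (!LambdaCommonHandler.isSuccess(response)) {
            // failure path: the caller can retry the batch or route it to failure handling
            return false;
        }
        return true;
    }
}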
@@ -72,6 +72,16 @@ public void addRecord(Record<Event> record) {
eventCount++;
}

void completeCodec() {
if (eventCount > 0) {
try {
requestCodec.complete(this.byteArrayOutputStream);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

public List<Record<Event>> getRecords() {
return records;
}
@@ -98,11 +108,7 @@ public InvokeRequest getRequestPayload(String functionName, String invocationTyp
return null;
}

try {
requestCodec.complete(this.byteArrayOutputStream);
} catch (IOException e) {
throw new RuntimeException(e);
}
completeCodec();

SdkBytes payload = getPayload();
payloadRequestSize = payload.asByteArray().length;

Collaborator:
Any advantage in moving this to another public method? One disadvantage I see is that the caller needs to know, and must remember, to call completeCodec before calling getRequestPayload - right?

Collaborator (Author):
I don't think we are allowed to have private methods in an interface (the Buffer interface). That is the reason for using public here.

Collaborator:
Why do we want to have this as a separate method instead of keeping the code where it is?

Collaborator (Author):
I felt that completing the codec should not be associated with getRequestPayload; completeCodec should be independent of the request, and should happen as soon as the buffer is done being added to. These are just my thoughts - please let me know what you think.
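
For context on the thread above: with this change a caller only accumulates records and then builds the request, because completeCodec() is now invoked inside getRequestPayload(). A minimal sketch of that caller-side flow; the wrapper class and method are hypothetical, only the Buffer methods shown in this diff are assumed:

import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;

import java.util.List;

// Hypothetical caller sketch; only addRecord/getRequestPayload come from the Buffer interface in this PR.
class BufferUsageSketch {

    InvokeRequest buildRequest(Buffer buffer, List<Record<Event>> records,
                               String functionName, String invocationType) {
        for (Record<Event> record : records) {
            buffer.addRecord(record);   // writes each event through the request codec
        }
        // getRequestPayload(...) now calls completeCodec() internally (see the diff above),
        // so the caller does not have to remember to close the codec first.
        return buffer.getRequestPayload(functionName, invocationType);
    }
}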
@@ -5,6 +5,7 @@
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition;
import org.opensearch.dataprepper.plugins.metricpublisher.MicrometerMetricPublisher;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
@@ -48,13 +49,14 @@ private static ClientOverrideConfiguration createOverrideConfiguration(
.maxBackoffTime(clientOptions.getMaxBackoff())
.build();

final RetryPolicy retryPolicy = RetryPolicy.builder()
final RetryPolicy customRetryPolicy = RetryPolicy.builder()
.retryCondition(new CustomLambdaRetryCondition())
.numRetries(clientOptions.getMaxConnectionRetries())
.backoffStrategy(backoffStrategy)
.build();

return ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.retryPolicy(customRetryPolicy)
.addMetricPublisher(new MicrometerMetricPublisher(awsSdkMetrics))
.apiCallTimeout(clientOptions.getApiCallTimeout())
.build();
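
Not shown in this hunk is where the returned override configuration ends up. A minimal wiring sketch, assuming standard AWS SDK v2 builder calls; the class and method here are hypothetical and only illustrate how the retry-policy configuration above reaches the client:

import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;

// Hypothetical wiring sketch; the full LambdaClientFactory is not shown in this diff.
class ClientWiringSketch {

    static LambdaAsyncClient buildClient(Region region, ClientOverrideConfiguration overrideConfiguration) {
        return LambdaAsyncClient.builder()
                .region(region)
                // applies the custom retry condition, backoff, metrics publisher and API call timeout
                .overrideConfiguration(overrideConfiguration)
                .build();
    }
}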
@@ -0,0 +1,23 @@
package org.opensearch.dataprepper.plugins.lambda.common.util;

import software.amazon.awssdk.core.retry.RetryPolicyContext;

import java.util.concurrent.atomic.AtomicInteger;

//Used ONLY for tests
public class CountingRetryCondition extends CustomLambdaRetryCondition {
private final AtomicInteger retryCount = new AtomicInteger(0);

@Override
public boolean shouldRetry(RetryPolicyContext context) {
boolean shouldRetry = super.shouldRetry(context);
if (shouldRetry) {
retryCount.incrementAndGet();
}
return shouldRetry;
}

public int getRetryCount() {
return retryCount.get();
}
}
@@ -0,0 +1,17 @@
package org.opensearch.dataprepper.plugins.lambda.common.util;

import software.amazon.awssdk.core.retry.conditions.RetryCondition;
import software.amazon.awssdk.core.retry.RetryPolicyContext;

public class CustomLambdaRetryCondition implements RetryCondition {

@Override
public boolean shouldRetry(RetryPolicyContext context) {
Throwable exception = context.exception();
if (exception != null) {
return LambdaRetryStrategy.isRetryableException(exception);
}

return false;
}
}
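
CustomLambdaRetryCondition delegates to LambdaRetryStrategy.isRetryableException, which is not part of this diff. A minimal sketch of what such a check could look like, assuming throttling (TooManyRequestsException, the 429 the new test provokes), Lambda ServiceException, and transient SdkClientException are the retryable cases; the real class in the PR may differ:

import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.lambda.model.ServiceException;
import software.amazon.awssdk.services.lambda.model.TooManyRequestsException;

// Illustrative sketch only; the real LambdaRetryStrategy in this PR may differ.
class LambdaRetryStrategySketch {

    static boolean isRetryableException(Throwable exception) {
        // 429: concurrency/throttling limits - the case the new integration test provokes
        if (exception instanceof TooManyRequestsException) {
            return true;
        }
        // 5xx returned by the Lambda service
        if (exception instanceof ServiceException) {
            return true;
        }
        // transient client-side failures (connection reset, timeout, etc.)
        if (exception instanceof SdkClientException) {
            return true;
        }
        return false;
    }
}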