Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TM: downloadDirectory refactor how SDK sends concurrent download file requests #3867

Merged
merged 6 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,14 @@

package software.amazon.awssdk.transfer.s3.internal;

import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;
import software.amazon.awssdk.utils.async.DemandIgnoringSubscription;
import software.amazon.awssdk.utils.async.StoringSubscriber;

/**
* An implementation of {@link Subscriber} that execute the provided function for every event and limits the number of concurrent
Expand All @@ -41,20 +37,16 @@ public class AsyncBufferingSubscriber<T> implements Subscriber<T> {
private final Function<T, CompletableFuture<?>> consumer;
private final int maxConcurrentExecutions;
private final AtomicInteger numRequestsInFlight;
private final AtomicBoolean isDelivering = new AtomicBoolean(false);
private volatile boolean isStreamingDone;
private volatile boolean upstreamDone;
private Subscription subscription;

private final StoringSubscriber<T> storingSubscriber;

public AsyncBufferingSubscriber(Function<T, CompletableFuture<?>> consumer,
CompletableFuture<Void> returnFuture,
int maxConcurrentExecutions) {
this.returnFuture = returnFuture;
this.consumer = consumer;
this.maxConcurrentExecutions = maxConcurrentExecutions;
this.numRequestsInFlight = new AtomicInteger(0);
this.storingSubscriber = new StoringSubscriber<>(Integer.MAX_VALUE);
}

@Override
Expand All @@ -65,89 +57,41 @@ public void onSubscribe(Subscription subscription) {
subscription.cancel();
return;
}
storingSubscriber.onSubscribe(new DemandIgnoringSubscription(subscription));
this.subscription = subscription;
subscription.request(maxConcurrentExecutions);
}

@Override
public void onNext(T item) {
storingSubscriber.onNext(item);
flushBufferIfNeeded();
}

private void flushBufferIfNeeded() {
if (isDelivering.compareAndSet(false, true)) {
try {
Optional<StoringSubscriber.Event<T>> next = storingSubscriber.peek();
while (numRequestsInFlight.get() < maxConcurrentExecutions) {
if (!next.isPresent()) {
subscription.request(1);
break;
}

switch (next.get().type()) {
case ON_COMPLETE:
handleCompleteEvent();
break;
case ON_ERROR:
handleError(next.get().runtimeError());
break;
case ON_NEXT:
handleOnNext(next.get().value());
break;
default:
handleError(new IllegalStateException("Unknown stored type: " + next.get().type()));
break;
}

next = storingSubscriber.peek();
}
} finally {
isDelivering.set(false);
}
}
}

private void handleOnNext(T item) {
storingSubscriber.poll();

int numberOfRequestInFlight = numRequestsInFlight.incrementAndGet();
log.debug(() -> "Delivering next item, numRequestInFlight=" + numberOfRequestInFlight);

numRequestsInFlight.incrementAndGet();
consumer.apply(item).whenComplete((r, t) -> {
numRequestsInFlight.decrementAndGet();
if (!isStreamingDone) {
checkForCompletion(numRequestsInFlight.decrementAndGet());
synchronized (this) {
subscription.request(1);
} else {
flushBufferIfNeeded();
}
});
}

private void handleCompleteEvent() {
if (numRequestsInFlight.get() == 0) {
returnFuture.complete(null);
storingSubscriber.poll();
}
}

@Override
public void onError(Throwable t) {
handleError(t);
storingSubscriber.onError(t);
}

private void handleError(Throwable t) {
// Need to complete future exceptionally first to prevent
// accidental successful completion by a concurrent checkForCompletion.
returnFuture.completeExceptionally(t);
storingSubscriber.poll();
upstreamDone = true;
}

@Override
public void onComplete() {
isStreamingDone = true;
storingSubscriber.onComplete();
flushBufferIfNeeded();
upstreamDone = true;
checkForCompletion(numRequestsInFlight.get());
}

private void checkForCompletion(int requestsInFlight) {
if (upstreamDone && requestsInFlight == 0) {
// This could get invoked multiple times, but it doesn't matter
// because future.complete is idempotent.
returnFuture.complete(null);
zoewangg marked this conversation as resolved.
Show resolved Hide resolved
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,52 @@

import static org.assertj.core.api.Assertions.assertThat;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.core.LogEvent;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.testutils.LogCaptor;
import software.amazon.awssdk.transfer.s3.S3TransferManager;

public class TransferManagerLoggingTest {
class TransferManagerLoggingTest {

@Test
public void transferManager_withCrtClient_shouldNotLogWarnMessages(){
LogCaptor logCaptor = LogCaptor.create(Level.WARN);
S3AsyncClient s3Crt = S3AsyncClient.crtCreate();
S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build();
void transferManager_withCrtClient_shouldNotLogWarnMessages() {

List<LogEvent> events = logCaptor.loggedEvents();
assertThat(events).isEmpty();
logCaptor.clear();
logCaptor.close();
try (S3AsyncClient s3Crt = S3AsyncClient.crtBuilder()
.region(Region.US_WEST_2)
.credentialsProvider(() -> AwsBasicCredentials.create("foo", "bar"))
.build();
LogCaptor logCaptor = LogCaptor.create(Level.WARN);
S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) {
List<LogEvent> events = logCaptor.loggedEvents();
assertThat(events).isEmpty();
}
}

@Test
public void transferManager_withJavaClient_shouldLogWarnMessage(){
LogCaptor logCaptor = LogCaptor.create(Level.WARN);
S3AsyncClient s3Java = S3AsyncClient.create();
S3TransferManager tm = S3TransferManager.builder().s3Client(s3Java).build();
void transferManager_withJavaClient_shouldLogWarnMessage() {

List<LogEvent> events = logCaptor.loggedEvents();
assertLogged(events, Level.WARN, "The provided DefaultS3AsyncClient is not an instance of S3CrtAsyncClient, and "
+ "thus multipart upload/download feature is not enabled and resumable file upload is "
+ "not supported. To benefit from maximum throughput, consider using "
+ "S3AsyncClient.crtBuilder().build() instead.");
logCaptor.clear();
logCaptor.close();

try (S3AsyncClient s3Crt = S3AsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(() -> AwsBasicCredentials.create("foo", "bar"))
.build();
LogCaptor logCaptor = LogCaptor.create(Level.WARN);
S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) {
List<LogEvent> events = logCaptor.loggedEvents();
assertLogged(events, Level.WARN, "The provided DefaultS3AsyncClient is not an instance of S3CrtAsyncClient, and "
+ "thus multipart upload/download feature is not enabled and resumable file upload"
+ " is "
+ "not supported. To benefit from maximum throughput, consider using "
+ "S3AsyncClient.crtBuilder().build() instead.");
}
}

private static void assertLogged(List<LogEvent> events, org.apache.logging.log4j.Level level, String message) {
Expand Down