Skip to content

Commit

Permalink
Make STSCredentialsProvider prefetch and stale times configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
Helmsdown authored and bmaizels committed Sep 15, 2020
1 parent da9ebb9 commit ff3f0ae
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/feature-AmazonSTS-289d9e7.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"category": "Amazon STS",
"type": "feature",
"description": "Make the STSCredentialsProvider stale and prefetch times configurable so clients can control when session credentials are refreshed"
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import java.util.function.Function;
import software.amazon.awssdk.annotations.NotThreadSafe;
import software.amazon.awssdk.annotations.SdkInternalApi;
Expand All @@ -31,6 +32,7 @@
import software.amazon.awssdk.utils.cache.NonBlocking;
import software.amazon.awssdk.utils.cache.RefreshResult;


/**
* An implementation of {@link AwsCredentialsProvider} that is extended within this package to provide support for periodically-
* updating session credentials. When credentials get close to expiration, this class will attempt to update them asynchronously
Expand All @@ -40,6 +42,10 @@
@ThreadSafe
@SdkInternalApi
abstract class StsCredentialsProvider implements AwsCredentialsProvider, SdkAutoCloseable {

private static final Duration DEFAULT_STALE_TIME = Duration.ofMinutes(1);
private static final Duration DEFAULT_PREFETCH_TIME = Duration.ofMinutes(5);

/**
* The STS client that should be used for periodically updating the session credentials in the background.
*/
Expand All @@ -50,9 +56,15 @@ abstract class StsCredentialsProvider implements AwsCredentialsProvider, SdkAuto
*/
private final CachedSupplier<SessionCredentialsHolder> sessionCache;

private final Duration staleTime;
private final Duration prefetchTime;

protected StsCredentialsProvider(BaseBuilder<?, ?> builder, String asyncThreadName) {
this.stsClient = Validate.notNull(builder.stsClient, "STS client must not be null.");

this.staleTime = Optional.ofNullable(builder.staleTime).orElse(DEFAULT_STALE_TIME);
this.prefetchTime = Optional.ofNullable(builder.prefetchTime).orElse(DEFAULT_PREFETCH_TIME);

CachedSupplier.Builder<SessionCredentialsHolder> cacheBuilder = CachedSupplier.builder(this::updateSessionCredentials);
if (builder.asyncCredentialUpdateEnabled) {
cacheBuilder.prefetchStrategy(new NonBlocking(asyncThreadName));
Expand All @@ -67,9 +79,10 @@ protected StsCredentialsProvider(BaseBuilder<?, ?> builder, String asyncThreadNa
private RefreshResult<SessionCredentialsHolder> updateSessionCredentials() {
SessionCredentialsHolder credentials = new SessionCredentialsHolder(getUpdatedCredentials(stsClient));
Instant actualTokenExpiration = credentials.getSessionCredentialsExpiration().toInstant();

return RefreshResult.builder(credentials)
.staleTime(actualTokenExpiration.minus(Duration.ofMinutes(1)))
.prefetchTime(actualTokenExpiration.minus(Duration.ofMinutes(5)))
.staleTime(actualTokenExpiration.minus(staleTime))
.prefetchTime(actualTokenExpiration.minus(prefetchTime))
.build();
}

Expand All @@ -83,6 +96,21 @@ public void close() {
sessionCache.close();
}

/**
* The amount of time, relative to STS token expiration, that the cached credentials are considered stale and should no longer be used.
* All threads will block until the value is updated.
*/
public Duration staleTime() {
return staleTime;
}

/**
* The amount of time, relative to STS token expiration, that the cached credentials are considered close to stale and should be updated.
*/
public Duration prefetchTime() {
return prefetchTime;
}

/**
* Implemented by a child class to call STS and get a new set of credentials to be used by this provider.
*/
Expand All @@ -97,6 +125,8 @@ protected abstract static class BaseBuilder<B extends BaseBuilder<B, T>, T> {

private Boolean asyncCredentialUpdateEnabled = false;
private StsClient stsClient;
private Duration staleTime;
private Duration prefetchTime;

protected BaseBuilder(Function<B, T> providerConstructor) {
this.providerConstructor = providerConstructor;
Expand Down Expand Up @@ -127,6 +157,31 @@ public B asyncCredentialUpdateEnabled(Boolean asyncCredentialUpdateEnabled) {
return (B) this;
}

/**
* Configure the amount of time, relative to STS token expiration, that the cached credentials are considered stale and should no longer be used.
* All threads will block until the value is updated.
*
* <p>By default, this is 1 minute.</p>
*/
@SuppressWarnings("unchecked")
public B staleTime(Duration staleTime) {
this.staleTime = staleTime;
return (B) this;
}

/**
* Configure the amount of time, relative to STS token expiration, that the cached credentials are considered close to stale and should be updated.
* See {@link #asyncCredentialUpdateEnabled}.
*
* <p>By default, this is 5 minutes.</p>
*/
@SuppressWarnings("unchecked")
public B prefetchTime(Duration prefetchTime) {
this.prefetchTime = prefetchTime;
return (B) this;
}


/**
* Build the credentials provider using the configuration applied to this builder.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,43 @@ public abstract class StsCredentialsProviderTestBase<RequestT, ResponseT> {

@Test
public void cachingDoesNotApplyToExpiredSession() {
callClientWithCredentialsProvider(Instant.now().minus(Duration.ofSeconds(5)), 2);
callClientWithCredentialsProvider(Instant.now().minus(Duration.ofSeconds(5)), 2, false);
callClient(verify(stsClient, times(2)), Mockito.any());
}

@Test
public void cachingDoesNotApplyToExpiredSession_OverridePrefetchAndStaleTimes() {
callClientWithCredentialsProvider(Instant.now().minus(Duration.ofSeconds(5)), 2, true);
callClient(verify(stsClient, times(2)), Mockito.any());
}

@Test
public void cachingAppliesToNonExpiredSession() {
callClientWithCredentialsProvider(Instant.now().plus(Duration.ofHours(5)), 2);
callClientWithCredentialsProvider(Instant.now().plus(Duration.ofHours(5)), 2, false);
callClient(verify(stsClient, times(1)), Mockito.any());
}

@Test
public void cachingAppliesToNonExpiredSession_OverridePrefetchAndStaleTimes() {
callClientWithCredentialsProvider(Instant.now().plus(Duration.ofHours(5)), 2, true);
callClient(verify(stsClient, times(1)), Mockito.any());
}

@Test
public void distantExpiringCredentialsUpdatedInBackground() throws InterruptedException {
callClientWithCredentialsProvider(Instant.now().plusSeconds(90), 2);
callClientWithCredentialsProvider(Instant.now().plusSeconds(90), 2, false);

Instant endCheckTime = Instant.now().plus(Duration.ofSeconds(5));
while (Mockito.mockingDetails(stsClient).getInvocations().size() < 2 && endCheckTime.isAfter(Instant.now())) {
Thread.sleep(100);
}

callClient(verify(stsClient, times(2)), Mockito.any());
}

@Test
public void distantExpiringCredentialsUpdatedInBackground_OverridePrefetchAndStaleTimes() throws InterruptedException {
callClientWithCredentialsProvider(Instant.now().plusSeconds(90), 2, true);

Instant endCheckTime = Instant.now().plus(Duration.ofSeconds(5));
while (Mockito.mockingDetails(stsClient).getInvocations().size() < 2 && endCheckTime.isAfter(Instant.now())) {
Expand All @@ -72,14 +96,32 @@ public void distantExpiringCredentialsUpdatedInBackground() throws InterruptedEx

protected abstract ResponseT callClient(StsClient client, RequestT request);

public void callClientWithCredentialsProvider(Instant credentialsExpirationDate, int numTimesInvokeCredentialsProvider) {
public void callClientWithCredentialsProvider(Instant credentialsExpirationDate, int numTimesInvokeCredentialsProvider, boolean overrideStaleAndPrefetchTimes) {
Credentials credentials = Credentials.builder().accessKeyId("a").secretAccessKey("b").sessionToken("c").expiration(credentialsExpirationDate).build();
RequestT request = getRequest();
ResponseT response = getResponse(credentials);

when(callClient(stsClient, request)).thenReturn(response);

try (StsCredentialsProvider credentialsProvider = createCredentialsProviderBuilder(request).stsClient(stsClient).build()) {
StsCredentialsProvider.BaseBuilder<?, ? extends StsCredentialsProvider> credentialsProviderBuilder = createCredentialsProviderBuilder(request);

if(overrideStaleAndPrefetchTimes) {
//do the same values as we would do without overriding the stale and prefetch times
credentialsProviderBuilder.staleTime(Duration.ofMinutes(2));
credentialsProviderBuilder.prefetchTime(Duration.ofMinutes(4));
}

try (StsCredentialsProvider credentialsProvider = credentialsProviderBuilder.stsClient(stsClient).build()) {
if(overrideStaleAndPrefetchTimes) {
//validate that we actually stored the override values in the build provider
assertThat(credentialsProvider.staleTime()).as("stale time").isEqualTo(Duration.ofMinutes(2));
assertThat(credentialsProvider.prefetchTime()).as("prefetch time").isEqualTo(Duration.ofMinutes(4));
} else {
//validate that the default values are used
assertThat(credentialsProvider.staleTime()).as("stale time").isEqualTo(Duration.ofMinutes(1));
assertThat(credentialsProvider.prefetchTime()).as("prefetch time").isEqualTo(Duration.ofMinutes(5));
}

for (int i = 0; i < numTimesInvokeCredentialsProvider; ++i) {
AwsSessionCredentials providedCredentials = (AwsSessionCredentials) credentialsProvider.resolveCredentials();
assertThat(providedCredentials.accessKeyId()).isEqualTo("a");
Expand Down

0 comments on commit ff3f0ae

Please sign in to comment.