Skip to content

Commit

Permalink
Move checksum calculation from afterMarshalling to modifyHttpRequest (#…
Browse files Browse the repository at this point in the history
…4108)

* Update HttpChecksumRequiredInterceptor

* Update HttpChecksumInHeaderInterceptor

* Update tests and remove constant

* Add back constant to resolve japicmp

* Add back javadocs
  • Loading branch information
davidh44 authored and L-Applin committed Jul 19, 2023
1 parent 014a284 commit bd84f34
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@

package software.amazon.awssdk.core.internal.interceptor;

import static software.amazon.awssdk.core.HttpChecksumConstant.HTTP_CHECKSUM_VALUE;
import static software.amazon.awssdk.core.HttpChecksumConstant.SIGNING_METHOD;
import static software.amazon.awssdk.core.internal.util.HttpChecksumResolver.getResolvedChecksumSpecs;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Optional;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.checksums.Algorithm;
import software.amazon.awssdk.core.checksums.ChecksumSpecs;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
Expand All @@ -47,49 +45,27 @@
@SdkInternalApi
public class HttpChecksumInHeaderInterceptor implements ExecutionInterceptor {

@Override
public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttributes executionAttributes) {
ChecksumSpecs headerChecksumSpecs = HttpChecksumUtils.checksumSpecWithRequestAlgorithm(executionAttributes).orElse(null);

if (shouldSkipHttpChecksumInHeader(context, executionAttributes, headerChecksumSpecs)) {
return;
}
Optional<RequestBody> syncContent = context.requestBody();
syncContent.ifPresent(
requestBody -> saveContentChecksum(requestBody, executionAttributes, headerChecksumSpecs.algorithm()));
}

@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
ChecksumSpecs checksumSpecs = getResolvedChecksumSpecs(executionAttributes);

if (shouldSkipHttpChecksumInHeader(context, executionAttributes, checksumSpecs)) {
return context.httpRequest();
}

String httpChecksumValue = executionAttributes.getAttribute(HTTP_CHECKSUM_VALUE);
if (httpChecksumValue != null) {
return context.httpRequest().copy(r -> r.putHeader(checksumSpecs.headerName(), httpChecksumValue));
}
return context.httpRequest();

}

/**
* Calculates the checksumSpecs of the provided request (and base64 encodes it), storing the result in
* executionAttribute "HttpChecksumValue".
* Calculates the checksum of the provided request (and base64 encodes it), and adds the header to the request.
*
* <p>Note: This assumes that the content stream provider can create multiple new streams. If it only supports one (e.g. with
* an input stream that doesn't support mark/reset), we could consider buffering the content in memory here and updating the
* request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do
* that yet.
*/
private static void saveContentChecksum(RequestBody requestBody, ExecutionAttributes executionAttributes,
Algorithm algorithm) {
@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
ChecksumSpecs checksumSpecs = getResolvedChecksumSpecs(executionAttributes);
Optional<RequestBody> syncContent = context.requestBody();

if (shouldSkipHttpChecksumInHeader(context, executionAttributes, checksumSpecs) || !syncContent.isPresent()) {
return context.httpRequest();
}

try {
String payloadChecksum = BinaryUtils.toBase64(HttpChecksumUtils.computeChecksum(
requestBody.contentStreamProvider().newStream(), algorithm));
executionAttributes.putAttribute(HTTP_CHECKSUM_VALUE, payloadChecksum);
syncContent.get().contentStreamProvider().newStream(), checksumSpecs.algorithm()));
return context.httpRequest().copy(r -> r.putHeader(checksumSpecs.headerName(), payloadChecksum));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttribute;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
Expand All @@ -41,33 +40,39 @@
*/
@SdkInternalApi
public class HttpChecksumRequiredInterceptor implements ExecutionInterceptor {
private static final ExecutionAttribute<String> CONTENT_MD5_VALUE = new ExecutionAttribute<>("ContentMd5");

/**
* Calculates the MD5 checksum of the provided request (and base64 encodes it), and adds the header to the request.
*
* <p>Note: This assumes that the content stream provider can create multiple new streams. If it only supports one (e.g. with
* an input stream that doesn't support mark/reset), we could consider buffering the content in memory here and updating the
* request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do
* that yet.
*/
@Override
public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttributes executionAttributes) {
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
boolean isHttpChecksumRequired = isHttpChecksumRequired(executionAttributes);
boolean requestAlreadyHasMd5 = context.httpRequest().firstMatchingHeader(Header.CONTENT_MD5).isPresent();

Optional<RequestBody> syncContent = context.requestBody();
Optional<AsyncRequestBody> asyncContent = context.asyncRequestBody();

if (!isHttpChecksumRequired || requestAlreadyHasMd5) {
return;
return context.httpRequest();
}

if (asyncContent.isPresent()) {
throw new IllegalArgumentException("This operation requires a content-MD5 checksum, but one cannot be calculated "
+ "for non-blocking content.");
}

syncContent.ifPresent(requestBody -> saveContentMd5(requestBody, executionAttributes));
}

@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
String contentMd5 = executionAttributes.getAttribute(CONTENT_MD5_VALUE);
if (contentMd5 != null) {
return context.httpRequest().copy(r -> r.putHeader(Header.CONTENT_MD5, contentMd5));
if (syncContent.isPresent()) {
try {
String payloadMd5 = Md5Utils.md5AsBase64(syncContent.get().contentStreamProvider().newStream());
return context.httpRequest().copy(r -> r.putHeader(Header.CONTENT_MD5, payloadMd5));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
return context.httpRequest();
}
Expand All @@ -76,22 +81,4 @@ private boolean isHttpChecksumRequired(ExecutionAttributes executionAttributes)
return executionAttributes.getAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED) != null
|| HttpChecksumUtils.isMd5ChecksumRequired(executionAttributes);
}

/**
* Calculates the MD5 checksum of the provided request (and base64 encodes it), storing the result in
* {@link #CONTENT_MD5_VALUE}.
*
* <p>Note: This assumes that the content stream provider can create multiple new streams. If it only supports one (e.g. with
* an input stream that doesn't support mark/reset), we could consider buffering the content in memory here and updating the
* request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do
* that yet.
*/
private void saveContentMd5(RequestBody requestBody, ExecutionAttributes executionAttributes) {
try {
String payloadMd5 = Md5Utils.md5AsBase64(requestBody.contentStreamProvider().newStream());
executionAttributes.putAttribute(CONTENT_MD5_VALUE, payloadMd5);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static software.amazon.awssdk.core.HttpChecksumConstant.HTTP_CHECKSUM_VALUE;

import io.reactivex.Flowable;
import java.io.IOException;
Expand All @@ -28,7 +27,6 @@
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
Expand All @@ -38,9 +36,6 @@
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.core.checksums.Algorithm;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.http.ExecutableHttpRequest;
import software.amazon.awssdk.http.HttpExecuteRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
Expand Down Expand Up @@ -103,11 +98,6 @@ public void setup() throws IOException {
});
}

@After
public void clear() {
CaptureChecksumValueInterceptor.reset();
}

@Test
public void sync_json_nonStreaming_unsignedPayload_with_Sha1_in_header() {
// jsonClient.flexibleCheckSumOperationWithShaChecksum(r -> r.stringMember("Hello world"));
Expand All @@ -118,9 +108,6 @@ public void sync_json_nonStreaming_unsignedPayload_with_Sha1_in_header() {
assertThat(getSyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("M68rRwFal7o7B3KEMt3m0w39TaA=");
// Assertion to make sure signer was not executed
assertThat(getSyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();

assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("M68rRwFal7o7B3KEMt3m0w39TaA=");

}

@Test
Expand All @@ -133,9 +120,6 @@ public void aync_json_nonStreaming_unsignedPayload_with_Sha1_in_header() {
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("M68rRwFal7o7B3KEMt3m0w39TaA=");
// Assertion to make sure signer was not executed
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();
assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("M68rRwFal7o7B3KEMt3m0w39TaA=");


}

@Test
Expand All @@ -148,9 +132,6 @@ public void sync_xml_nonStreaming_unsignedPayload_with_Sha1_in_header() {
assertThat(getSyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("FB/utBbwFLbIIt5ul3Ojuy5dKgU=");
// Assertion to make sure signer was not executed
assertThat(getSyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();

assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("FB/utBbwFLbIIt5ul3Ojuy5dKgU=");

}

@Test
Expand All @@ -169,9 +150,6 @@ public void sync_xml_nonStreaming_unsignedEmptyPayload_with_Sha1_in_header() {

// Assertion to make sure signer was not executed
assertThat(getSyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();

assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isNull();

}

@Test
Expand All @@ -185,8 +163,6 @@ public void aync_xml_nonStreaming_unsignedPayload_with_Sha1_in_header() {
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("FB/utBbwFLbIIt5ul3Ojuy5dKgU=");
// Assertion to make sure signer was not executed
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();
assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("FB/utBbwFLbIIt5ul3Ojuy5dKgU=");

}

@Test
Expand All @@ -206,8 +182,6 @@ public void aync_xml_nonStreaming_unsignedEmptyPayload_with_Sha1_in_header() {
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).isNotPresent();
// Assertion to make sure signer was not executed
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();
assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isNull();

}

private SdkHttpRequest getSyncRequest() {
Expand All @@ -224,32 +198,15 @@ private SdkHttpRequest getAsyncRequest() {


private <T extends AwsSyncClientBuilder<T, ?> & AwsClientBuilder<T, ?>> T initializeSync(T syncClientBuilder) {
return initialize(syncClientBuilder.httpClient(httpClient)
.overrideConfiguration(o -> o.addExecutionInterceptor(new CaptureChecksumValueInterceptor())));
return initialize(syncClientBuilder.httpClient(httpClient));
}

private <T extends AwsAsyncClientBuilder<T, ?> & AwsClientBuilder<T, ?>> T initializeAsync(T asyncClientBuilder) {
return initialize(asyncClientBuilder.httpClient(httpAsyncClient)
.overrideConfiguration(o -> o.addExecutionInterceptor(new CaptureChecksumValueInterceptor())));
return initialize(asyncClientBuilder.httpClient(httpAsyncClient));
}

private <T extends AwsClientBuilder<T, ?>> T initialize(T clientBuilder) {
return clientBuilder.credentialsProvider(AnonymousCredentialsProvider.create())
.region(Region.US_WEST_2);
}


private static class CaptureChecksumValueInterceptor implements ExecutionInterceptor {
private static String interceptorComputedChecksum;

private static void reset() {
interceptorComputedChecksum = null;
}

@Override
public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
interceptorComputedChecksum = executionAttributes.getAttribute(HTTP_CHECKSUM_VALUE);

}
}
}

0 comments on commit bd84f34

Please sign in to comment.