Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move checksum calculation from afterMarshalling to modifyHttpRequest #4108

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@

package software.amazon.awssdk.core.internal.interceptor;

import static software.amazon.awssdk.core.HttpChecksumConstant.HTTP_CHECKSUM_VALUE;
import static software.amazon.awssdk.core.HttpChecksumConstant.SIGNING_METHOD;
import static software.amazon.awssdk.core.internal.util.HttpChecksumResolver.getResolvedChecksumSpecs;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Optional;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.checksums.Algorithm;
import software.amazon.awssdk.core.checksums.ChecksumSpecs;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
Expand All @@ -47,49 +45,27 @@
@SdkInternalApi
public class HttpChecksumInHeaderInterceptor implements ExecutionInterceptor {

@Override
public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttributes executionAttributes) {
ChecksumSpecs headerChecksumSpecs = HttpChecksumUtils.checksumSpecWithRequestAlgorithm(executionAttributes).orElse(null);

if (shouldSkipHttpChecksumInHeader(context, executionAttributes, headerChecksumSpecs)) {
return;
}
Optional<RequestBody> syncContent = context.requestBody();
syncContent.ifPresent(
requestBody -> saveContentChecksum(requestBody, executionAttributes, headerChecksumSpecs.algorithm()));
}

@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
ChecksumSpecs checksumSpecs = getResolvedChecksumSpecs(executionAttributes);

if (shouldSkipHttpChecksumInHeader(context, executionAttributes, checksumSpecs)) {
return context.httpRequest();
}

String httpChecksumValue = executionAttributes.getAttribute(HTTP_CHECKSUM_VALUE);
if (httpChecksumValue != null) {
return context.httpRequest().copy(r -> r.putHeader(checksumSpecs.headerName(), httpChecksumValue));
}
return context.httpRequest();

}

/**
* Calculates the checksumSpecs of the provided request (and base64 encodes it), storing the result in
* executionAttribute "HttpChecksumValue".
* Calculates the checksum of the provided request (and base64 encodes it), and adds the header to the request.
*
* <p>Note: This assumes that the content stream provider can create multiple new streams. If it only supports one (e.g. with
* an input stream that doesn't support mark/reset), we could consider buffering the content in memory here and updating the
* request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do
* that yet.
*/
private static void saveContentChecksum(RequestBody requestBody, ExecutionAttributes executionAttributes,
Algorithm algorithm) {
@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
ChecksumSpecs checksumSpecs = getResolvedChecksumSpecs(executionAttributes);
Optional<RequestBody> syncContent = context.requestBody();

if (shouldSkipHttpChecksumInHeader(context, executionAttributes, checksumSpecs) || !syncContent.isPresent()) {
return context.httpRequest();
}

try {
String payloadChecksum = BinaryUtils.toBase64(HttpChecksumUtils.computeChecksum(
requestBody.contentStreamProvider().newStream(), algorithm));
executionAttributes.putAttribute(HTTP_CHECKSUM_VALUE, payloadChecksum);
syncContent.get().contentStreamProvider().newStream(), checksumSpecs.algorithm()));
return context.httpRequest().copy(r -> r.putHeader(checksumSpecs.headerName(), payloadChecksum));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttribute;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
Expand All @@ -41,33 +40,39 @@
*/
@SdkInternalApi
public class HttpChecksumRequiredInterceptor implements ExecutionInterceptor {
private static final ExecutionAttribute<String> CONTENT_MD5_VALUE = new ExecutionAttribute<>("ContentMd5");

/**
* Calculates the MD5 checksum of the provided request (and base64 encodes it), and adds the header to the request.
*
* <p>Note: This assumes that the content stream provider can create multiple new streams. If it only supports one (e.g. with
* an input stream that doesn't support mark/reset), we could consider buffering the content in memory here and updating the
* request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do
* that yet.
*/
@Override
public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttributes executionAttributes) {
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
boolean isHttpChecksumRequired = isHttpChecksumRequired(executionAttributes);
boolean requestAlreadyHasMd5 = context.httpRequest().firstMatchingHeader(Header.CONTENT_MD5).isPresent();

Optional<RequestBody> syncContent = context.requestBody();
Optional<AsyncRequestBody> asyncContent = context.asyncRequestBody();

if (!isHttpChecksumRequired || requestAlreadyHasMd5) {
return;
return context.httpRequest();
}

if (asyncContent.isPresent()) {
throw new IllegalArgumentException("This operation requires a content-MD5 checksum, but one cannot be calculated "
+ "for non-blocking content.");
}

syncContent.ifPresent(requestBody -> saveContentMd5(requestBody, executionAttributes));
}

@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
String contentMd5 = executionAttributes.getAttribute(CONTENT_MD5_VALUE);
if (contentMd5 != null) {
return context.httpRequest().copy(r -> r.putHeader(Header.CONTENT_MD5, contentMd5));
if (syncContent.isPresent()) {
try {
String payloadMd5 = Md5Utils.md5AsBase64(syncContent.get().contentStreamProvider().newStream());
return context.httpRequest().copy(r -> r.putHeader(Header.CONTENT_MD5, payloadMd5));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
return context.httpRequest();
}
Expand All @@ -76,22 +81,4 @@ private boolean isHttpChecksumRequired(ExecutionAttributes executionAttributes)
return executionAttributes.getAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED) != null
|| HttpChecksumUtils.isMd5ChecksumRequired(executionAttributes);
}

/**
* Calculates the MD5 checksum of the provided request (and base64 encodes it), storing the result in
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason we removed this function and its comment and we keep the function instead of inline ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to store the checksum as an ExecutionAttribute anymore, so it is just calling Md5Utils.md5AsBase64

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep the comment atleast ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, adding back javadocs

* {@link #CONTENT_MD5_VALUE}.
*
* <p>Note: This assumes that the content stream provider can create multiple new streams. If it only supports one (e.g. with
* an input stream that doesn't support mark/reset), we could consider buffering the content in memory here and updating the
* request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do
* that yet.
*/
private void saveContentMd5(RequestBody requestBody, ExecutionAttributes executionAttributes) {
try {
String payloadMd5 = Md5Utils.md5AsBase64(requestBody.contentStreamProvider().newStream());
executionAttributes.putAttribute(CONTENT_MD5_VALUE, payloadMd5);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static software.amazon.awssdk.core.HttpChecksumConstant.HTTP_CHECKSUM_VALUE;

import io.reactivex.Flowable;
import java.io.IOException;
Expand All @@ -28,7 +27,6 @@
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
Expand All @@ -38,9 +36,6 @@
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.core.checksums.Algorithm;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.http.ExecutableHttpRequest;
import software.amazon.awssdk.http.HttpExecuteRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
Expand Down Expand Up @@ -103,11 +98,6 @@ public void setup() throws IOException {
});
}

@After
public void clear() {
CaptureChecksumValueInterceptor.reset();
}

@Test
public void sync_json_nonStreaming_unsignedPayload_with_Sha1_in_header() {
// jsonClient.flexibleCheckSumOperationWithShaChecksum(r -> r.stringMember("Hello world"));
Expand All @@ -118,9 +108,6 @@ public void sync_json_nonStreaming_unsignedPayload_with_Sha1_in_header() {
assertThat(getSyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("M68rRwFal7o7B3KEMt3m0w39TaA=");
// Assertion to make sure signer was not executed
assertThat(getSyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();

assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("M68rRwFal7o7B3KEMt3m0w39TaA=");

}

@Test
Expand All @@ -133,9 +120,6 @@ public void aync_json_nonStreaming_unsignedPayload_with_Sha1_in_header() {
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("M68rRwFal7o7B3KEMt3m0w39TaA=");
// Assertion to make sure signer was not executed
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();
assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("M68rRwFal7o7B3KEMt3m0w39TaA=");


}

@Test
Expand All @@ -148,9 +132,6 @@ public void sync_xml_nonStreaming_unsignedPayload_with_Sha1_in_header() {
assertThat(getSyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("FB/utBbwFLbIIt5ul3Ojuy5dKgU=");
// Assertion to make sure signer was not executed
assertThat(getSyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();

assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("FB/utBbwFLbIIt5ul3Ojuy5dKgU=");

}

@Test
Expand All @@ -169,9 +150,6 @@ public void sync_xml_nonStreaming_unsignedEmptyPayload_with_Sha1_in_header() {

// Assertion to make sure signer was not executed
assertThat(getSyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();

assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isNull();

}

@Test
Expand All @@ -185,8 +163,6 @@ public void aync_xml_nonStreaming_unsignedPayload_with_Sha1_in_header() {
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("FB/utBbwFLbIIt5ul3Ojuy5dKgU=");
// Assertion to make sure signer was not executed
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();
assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("FB/utBbwFLbIIt5ul3Ojuy5dKgU=");

}

@Test
Expand All @@ -206,8 +182,6 @@ public void aync_xml_nonStreaming_unsignedEmptyPayload_with_Sha1_in_header() {
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).isNotPresent();
// Assertion to make sure signer was not executed
assertThat(getAsyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent();
assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isNull();

}

private SdkHttpRequest getSyncRequest() {
Expand All @@ -224,32 +198,15 @@ private SdkHttpRequest getAsyncRequest() {


private <T extends AwsSyncClientBuilder<T, ?> & AwsClientBuilder<T, ?>> T initializeSync(T syncClientBuilder) {
return initialize(syncClientBuilder.httpClient(httpClient)
.overrideConfiguration(o -> o.addExecutionInterceptor(new CaptureChecksumValueInterceptor())));
return initialize(syncClientBuilder.httpClient(httpClient));
}

private <T extends AwsAsyncClientBuilder<T, ?> & AwsClientBuilder<T, ?>> T initializeAsync(T asyncClientBuilder) {
return initialize(asyncClientBuilder.httpClient(httpAsyncClient)
.overrideConfiguration(o -> o.addExecutionInterceptor(new CaptureChecksumValueInterceptor())));
return initialize(asyncClientBuilder.httpClient(httpAsyncClient));
}

private <T extends AwsClientBuilder<T, ?>> T initialize(T clientBuilder) {
return clientBuilder.credentialsProvider(AnonymousCredentialsProvider.create())
.region(Region.US_WEST_2);
}


private static class CaptureChecksumValueInterceptor implements ExecutionInterceptor {
private static String interceptorComputedChecksum;

private static void reset() {
interceptorComputedChecksum = null;
}

@Override
public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
interceptorComputedChecksum = executionAttributes.getAttribute(HTTP_CHECKSUM_VALUE);

}
}
}