Skip to content

Commit

Permalink
Move QueryParametersToBodyInterceptor to front of interceptor chain (#…
Browse files Browse the repository at this point in the history
…4109)

* Move QueryParametersToBodyInterceptor to front of interceptor chain

* Move customization.config interceptors to front of interceptor chain - for query protocols

* Refactoring

* Add codegen tests

* Refactoring

* Refactoring
  • Loading branch information
davidh44 authored and L-Applin committed Jul 19, 2023
1 parent c3083a4 commit ee2cd81
Show file tree
Hide file tree
Showing 10 changed files with 261 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.lang.model.element.Modifier;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.auth.token.credentials.SdkTokenProvider;
Expand All @@ -32,6 +36,9 @@
import software.amazon.awssdk.codegen.utils.AuthUtils;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.protocols.query.interceptor.QueryParametersToBodyInterceptor;
import software.amazon.awssdk.utils.CollectionUtils;

public class AsyncClientBuilderClass implements ClassSpec {
private final IntermediateModel model;
Expand Down Expand Up @@ -119,26 +126,53 @@ private MethodSpec endpointProviderMethod() {
}

private MethodSpec buildClientMethod() {
return MethodSpec.methodBuilder("buildClient")
.addAnnotation(Override.class)
.addModifiers(Modifier.PROTECTED, Modifier.FINAL)
.returns(clientInterfaceName)
.addStatement("$T clientConfiguration = super.asyncClientConfiguration()", SdkClientConfiguration.class)
.addStatement("this.validateClientOptions(clientConfiguration)")
.addStatement("$T endpointOverride = null", URI.class)
.addCode("if (clientConfiguration.option($T.ENDPOINT_OVERRIDDEN) != null"
+ "&& $T.TRUE.equals(clientConfiguration.option($T.ENDPOINT_OVERRIDDEN))) {"
+ "endpointOverride = clientConfiguration.option($T.ENDPOINT);"
+ "}",
SdkClientOption.class, Boolean.class, SdkClientOption.class, SdkClientOption.class)
.addStatement("$T serviceClientConfiguration = $T.builder()"
+ ".overrideConfiguration(overrideConfiguration())"
+ ".region(clientConfiguration.option($T.AWS_REGION))"
+ ".endpointOverride(endpointOverride)"
+ ".build()",
serviceConfigClassName, serviceConfigClassName, AwsClientOption.class)
.addStatement("return new $T(serviceClientConfiguration, clientConfiguration)", clientClassName)
.build();
MethodSpec.Builder b = MethodSpec.methodBuilder("buildClient")
.addAnnotation(Override.class)
.addModifiers(Modifier.PROTECTED, Modifier.FINAL)
.returns(clientInterfaceName)
.addStatement("$T clientConfiguration = super.asyncClientConfiguration()",
SdkClientConfiguration.class);

addQueryProtocolInterceptors(b);

return b.addStatement("this.validateClientOptions(clientConfiguration)")
.addStatement("$T endpointOverride = null", URI.class)
.addCode("if (clientConfiguration.option($T.ENDPOINT_OVERRIDDEN) != null"
+ "&& $T.TRUE.equals(clientConfiguration.option($T.ENDPOINT_OVERRIDDEN))) {"
+ "endpointOverride = clientConfiguration.option($T.ENDPOINT);"
+ "}",
SdkClientOption.class, Boolean.class, SdkClientOption.class, SdkClientOption.class)
.addStatement("$T serviceClientConfiguration = $T.builder()"
+ ".overrideConfiguration(overrideConfiguration())"
+ ".region(clientConfiguration.option($T.AWS_REGION))"
+ ".endpointOverride(endpointOverride)"
+ ".build()",
serviceConfigClassName, serviceConfigClassName, AwsClientOption.class)
.addStatement("return new $T(serviceClientConfiguration, clientConfiguration)", clientClassName)
.build();
}

private MethodSpec.Builder addQueryProtocolInterceptors(MethodSpec.Builder b) {
if (!model.getMetadata().isQueryProtocol()) {
return b;
}

TypeName listType = ParameterizedTypeName.get(List.class, ExecutionInterceptor.class);

b.addStatement("$T interceptors = clientConfiguration.option($T.EXECUTION_INTERCEPTORS)",
listType, SdkClientOption.class)
.addStatement("$T queryParamsToBodyInterceptor = $T.singletonList(new $T())",
listType, Collections.class, QueryParametersToBodyInterceptor.class)
.addStatement("$T customizationInterceptors = new $T<>()", listType, ArrayList.class);

List<String> customInterceptors = model.getCustomizationConfig().getInterceptors();
customInterceptors.forEach(i -> b.addStatement("customizationInterceptors.add(new $T())", ClassName.bestGuess(i)));

b.addStatement("interceptors = $T.mergeLists(queryParamsToBodyInterceptor, interceptors)", CollectionUtils.class)
.addStatement("interceptors = $T.mergeLists(customizationInterceptors, interceptors)", CollectionUtils.class);

return b.addStatement("clientConfiguration = clientConfiguration.toBuilder().option($T.EXECUTION_INTERCEPTORS, "
+ "interceptors).build()", SdkClientOption.class);
}

private MethodSpec bearerTokenProviderMethod() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.TypeVariableName;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -59,7 +58,6 @@
import software.amazon.awssdk.core.signer.Signer;
import software.amazon.awssdk.http.Protocol;
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
import software.amazon.awssdk.protocols.query.interceptor.QueryParametersToBodyInterceptor;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.StringUtils;
Expand Down Expand Up @@ -262,8 +260,10 @@ private MethodSpec finalizeServiceConfigurationMethod() {
builtInInterceptors.add(endpointRulesSpecUtils.authSchemesInterceptorName());
builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName());

for (String interceptor : model.getCustomizationConfig().getInterceptors()) {
builtInInterceptors.add(ClassName.bestGuess(interceptor));
if (!model.getMetadata().isQueryProtocol()) {
for (String interceptor : model.getCustomizationConfig().getInterceptors()) {
builtInInterceptors.add(ClassName.bestGuess(interceptor));
}
}

for (ClassName interceptor : builtInInterceptors) {
Expand All @@ -288,16 +288,6 @@ private MethodSpec finalizeServiceConfigurationMethod() {
builder.addCode("interceptors = $T.mergeLists(interceptors, config.option($T.EXECUTION_INTERCEPTORS));\n",
CollectionUtils.class, SdkClientOption.class);

if (model.getMetadata().isQueryProtocol()) {
TypeName listType = ParameterizedTypeName.get(List.class, ExecutionInterceptor.class);
builder.addStatement("$T protocolInterceptors = $T.singletonList(new $T())",
listType,
Collections.class,
QueryParametersToBodyInterceptor.class);
builder.addStatement("interceptors = $T.mergeLists(interceptors, protocolInterceptors)",
CollectionUtils.class);
}

if (model.getEndpointOperation().isPresent()) {
builder.beginControlFlow("if (!endpointDiscoveryEnabled)")
.addStatement("$1T chain = new $1T(config)", DefaultEndpointDiscoveryProviderChain.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.lang.model.element.Modifier;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.auth.token.credentials.SdkTokenProvider;
Expand All @@ -32,6 +36,9 @@
import software.amazon.awssdk.codegen.utils.AuthUtils;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.protocols.query.interceptor.QueryParametersToBodyInterceptor;
import software.amazon.awssdk.utils.CollectionUtils;

public class SyncClientBuilderClass implements ClassSpec {
private final IntermediateModel model;
Expand Down Expand Up @@ -119,26 +126,53 @@ private MethodSpec endpointProviderMethod() {


private MethodSpec buildClientMethod() {
return MethodSpec.methodBuilder("buildClient")
.addAnnotation(Override.class)
.addModifiers(Modifier.PROTECTED, Modifier.FINAL)
.returns(clientInterfaceName)
.addStatement("$T clientConfiguration = super.syncClientConfiguration()", SdkClientConfiguration.class)
.addStatement("this.validateClientOptions(clientConfiguration)")
.addStatement("$T endpointOverride = null", URI.class)
.addCode("if (clientConfiguration.option($T.ENDPOINT_OVERRIDDEN) != null"
+ "&& $T.TRUE.equals(clientConfiguration.option($T.ENDPOINT_OVERRIDDEN))) {"
+ "endpointOverride = clientConfiguration.option($T.ENDPOINT);"
+ "}",
SdkClientOption.class, Boolean.class, SdkClientOption.class, SdkClientOption.class)
.addStatement("$T serviceClientConfiguration = $T.builder()"
+ ".overrideConfiguration(overrideConfiguration())"
+ ".region(clientConfiguration.option($T.AWS_REGION))"
+ ".endpointOverride(endpointOverride)"
+ ".build()",
serviceConfigClassName, serviceConfigClassName, AwsClientOption.class)
.addStatement("return new $T(serviceClientConfiguration, clientConfiguration)", clientClassName)
.build();
MethodSpec.Builder b = MethodSpec.methodBuilder("buildClient")
.addAnnotation(Override.class)
.addModifiers(Modifier.PROTECTED, Modifier.FINAL)
.returns(clientInterfaceName)
.addStatement("$T clientConfiguration = super.syncClientConfiguration()",
SdkClientConfiguration.class);

addQueryProtocolInterceptors(b);

return b.addStatement("this.validateClientOptions(clientConfiguration)")
.addStatement("$T endpointOverride = null", URI.class)
.addCode("if (clientConfiguration.option($T.ENDPOINT_OVERRIDDEN) != null"
+ "&& $T.TRUE.equals(clientConfiguration.option($T.ENDPOINT_OVERRIDDEN))) {"
+ "endpointOverride = clientConfiguration.option($T.ENDPOINT);"
+ "}",
SdkClientOption.class, Boolean.class, SdkClientOption.class, SdkClientOption.class)
.addStatement("$T serviceClientConfiguration = $T.builder()"
+ ".overrideConfiguration(overrideConfiguration())"
+ ".region(clientConfiguration.option($T.AWS_REGION))"
+ ".endpointOverride(endpointOverride)"
+ ".build()",
serviceConfigClassName, serviceConfigClassName, AwsClientOption.class)
.addStatement("return new $T(serviceClientConfiguration, clientConfiguration)", clientClassName)
.build();
}

private MethodSpec.Builder addQueryProtocolInterceptors(MethodSpec.Builder b) {
if (!model.getMetadata().isQueryProtocol()) {
return b;
}

TypeName listType = ParameterizedTypeName.get(List.class, ExecutionInterceptor.class);

b.addStatement("$T interceptors = clientConfiguration.option($T.EXECUTION_INTERCEPTORS)",
listType, SdkClientOption.class)
.addStatement("$T queryParamsToBodyInterceptor = $T.singletonList(new $T())",
listType, Collections.class, QueryParametersToBodyInterceptor.class)
.addStatement("$T customizationInterceptors = new $T<>()", listType, ArrayList.class);

List<String> customInterceptors = model.getCustomizationConfig().getInterceptors();
customInterceptors.forEach(i -> b.addStatement("customizationInterceptors.add(new $T())", ClassName.bestGuess(i)));

b.addStatement("interceptors = $T.mergeLists(queryParamsToBodyInterceptor, interceptors)", CollectionUtils.class)
.addStatement("interceptors = $T.mergeLists(customizationInterceptors, interceptors)", CollectionUtils.class);

return b.addStatement("clientConfiguration = clientConfiguration.toBuilder().option($T.EXECUTION_INTERCEPTORS, "
+ "interceptors).build()", SdkClientOption.class);
}

private MethodSpec tokenProviderMethodImpl() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package software.amazon.awssdk.codegen.internal;

import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.codegen.poet.builder.BuilderClassTest;

/**
* Empty no-op test interceptor for query protocols to view generated code in test-query-sync-client-builder-class.java and
* test-query-async-client-builder-class.java and validate in {@link BuilderClassTest}.
*/
@SdkInternalApi
public class QueryProtocolCustomTestInterceptor {
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ public void baseQueryClientBuilderClass() throws Exception {
validateQueryGeneration(BaseClientBuilderClass::new, "test-query-client-builder-class.java");
}

@Test
public void syncQueryClientBuilderClass() throws Exception {
validateQueryGeneration(SyncClientBuilderClass::new, "test-query-sync-client-builder-class.java");
}

@Test
public void asyncQueryClientBuilderClass() throws Exception {
validateQueryGeneration(AsyncClientBuilderClass::new, "test-query-async-client-builder-class.java");
}

@Test
public void syncClientBuilderInterface() throws Exception {
validateGeneration(SyncClientBuilderInterface::new, "test-sync-client-builder-interface.java");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package software.amazon.awssdk.services.query;

import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import software.amazon.awssdk.annotations.Generated;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.auth.token.credentials.SdkTokenProvider;
import software.amazon.awssdk.awscore.client.config.AwsClientOption;
import software.amazon.awssdk.codegen.internal.QueryProtocolCustomTestInterceptor;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.protocols.query.interceptor.QueryParametersToBodyInterceptor;
import software.amazon.awssdk.services.query.endpoints.QueryEndpointProvider;
import software.amazon.awssdk.utils.CollectionUtils;

/**
* Internal implementation of {@link QueryAsyncClientBuilder}.
*/
@Generated("software.amazon.awssdk:codegen")
@SdkInternalApi
final class DefaultQueryAsyncClientBuilder extends DefaultQueryBaseClientBuilder<QueryAsyncClientBuilder, QueryAsyncClient>
implements QueryAsyncClientBuilder {
@Override
public DefaultQueryAsyncClientBuilder endpointProvider(QueryEndpointProvider endpointProvider) {
clientConfiguration.option(SdkClientOption.ENDPOINT_PROVIDER, endpointProvider);
return this;
}

@Override
public DefaultQueryAsyncClientBuilder tokenProvider(SdkTokenProvider tokenProvider) {
clientConfiguration.option(AwsClientOption.TOKEN_PROVIDER, tokenProvider);
return this;
}

@Override
protected final QueryAsyncClient buildClient() {
SdkClientConfiguration clientConfiguration = super.asyncClientConfiguration();
List<ExecutionInterceptor> interceptors = clientConfiguration.option(SdkClientOption.EXECUTION_INTERCEPTORS);
List<ExecutionInterceptor> queryParamsToBodyInterceptor = Collections
.singletonList(new QueryParametersToBodyInterceptor());
List<ExecutionInterceptor> customizationInterceptors = new ArrayList<>();
customizationInterceptors.add(new QueryProtocolCustomTestInterceptor());
interceptors = CollectionUtils.mergeLists(queryParamsToBodyInterceptor, interceptors);
interceptors = CollectionUtils.mergeLists(customizationInterceptors, interceptors);
clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.EXECUTION_INTERCEPTORS, interceptors)
.build();
this.validateClientOptions(clientConfiguration);
URI endpointOverride = null;
if (clientConfiguration.option(SdkClientOption.ENDPOINT_OVERRIDDEN) != null
&& Boolean.TRUE.equals(clientConfiguration.option(SdkClientOption.ENDPOINT_OVERRIDDEN))) {
endpointOverride = clientConfiguration.option(SdkClientOption.ENDPOINT);
}
QueryServiceClientConfiguration serviceClientConfiguration = QueryServiceClientConfiguration.builder()
.overrideConfiguration(overrideConfiguration()).region(clientConfiguration.option(AwsClientOption.AWS_REGION))
.endpointOverride(endpointOverride).build();
return new DefaultQueryAsyncClient(serviceClientConfiguration, clientConfiguration);
}
}
Loading

0 comments on commit ee2cd81

Please sign in to comment.