diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt index d57d379f2c..3182dbcc28 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt @@ -6,8 +6,11 @@ package software.amazon.smithy.rustsdk import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rulesengine.language.syntax.parameters.Builtins +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.endpoint.RulesEngineBuiltInResolver import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator @@ -97,6 +100,22 @@ class RegionDecorator : RustCodegenDecorator { + return listOf( + object : RulesEngineBuiltInResolver { + override fun defaultFor( + parameter: Parameter, + configRef: String, + ): Writable? { + return when (parameter) { + Builtins.REGION -> writable { rust("$configRef.region.as_ref().map(|r|r.as_ref())") } + else -> null + } + } + }, + ) + } + override fun supportsCodegenContext(clazz: Class): Boolean = clazz.isAssignableFrom(ClientCodegenContext::class.java) } @@ -117,8 +136,10 @@ class RegionProviderConfig(codegenContext: CodegenContext) : ConfigCustomization """, *codegenScope, ) + is ServiceConfig.BuilderStruct -> rustTemplate("region: Option<#{Region}>,", *codegenScope) + ServiceConfig.BuilderImpl -> rustTemplate( """ @@ -162,6 +183,7 @@ class RegionConfigPlugin : OperationCustomization() { """, ) } + else -> emptySection } } @@ -176,6 +198,7 @@ class PubUseRegion(private val runtimeConfig: RuntimeConfig) : LibRsCustomizatio region(runtimeConfig), ) } + else -> emptySection } } diff --git a/codegen-client-test/build.gradle.kts b/codegen-client-test/build.gradle.kts index c4c085bfd9..765f09c761 100644 --- a/codegen-client-test/build.gradle.kts +++ b/codegen-client-test/build.gradle.kts @@ -88,6 +88,7 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> """.trimIndent(), imports = listOf("$commonModels/naming-obstacle-course-structs.smithy"), ), + CodegenTest("aws.protocoltests.json#TestService", "endpoint-rules"), CodegenTest("com.aws.example.rust#PokemonService", "pokemon-service-client", imports = listOf("$commonModels/pokemon.smithy", "$commonModels/pokemon-common.smithy")), ) } diff --git a/codegen-client-test/model/endpoint-rules.smithy b/codegen-client-test/model/endpoint-rules.smithy new file mode 100644 index 0000000000..54a1b3593b --- /dev/null +++ b/codegen-client-test/model/endpoint-rules.smithy @@ -0,0 +1,33 @@ +$version: "1.0" + +namespace aws.protocoltests.json + +use smithy.rules#endpointRuleSet +use smithy.rules#endpointTests + +use smithy.rules#clientContextParams +use smithy.rules#staticContextParams +use smithy.rules#contextParam +use aws.protocols#awsJson1_1 + +@awsJson1_1 +@endpointRuleSet({ + "version": "1.0", + "rules": [], + "parameters": { + "Bucket": { "required": false, "type": "String" }, + "Region": { "required": false, "type": "String", "builtIn": "AWS::Region" }, + } +}) +service TestService { + operations: [TestOperation] +} + +operation TestOperation { + input: TestOperationInput +} + +structure TestOperationInput { + @contextParam(name: "Bucket") + bucket: String +} diff --git a/codegen-client/build.gradle.kts b/codegen-client/build.gradle.kts index 7130b5c301..3400510065 100644 --- a/codegen-client/build.gradle.kts +++ b/codegen-client/build.gradle.kts @@ -29,6 +29,7 @@ dependencies { implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") + implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion") runtimeOnly(project(":rust-runtime")) testImplementation("org.junit.jupiter:junit-jupiter:5.6.1") testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion") diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt index 03eeeb24a0..6f58a557d5 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt @@ -8,6 +8,8 @@ package software.amazon.smithy.rust.codegen.client.smithy import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider @@ -24,6 +26,9 @@ data class ClientCodegenContext( override val serviceShape: ServiceShape, override val protocol: ShapeId, override val settings: ClientRustSettings, + // Expose the `rootDecorator`, enabling customizations to compose by referencing information from the root codegen + // decorator + val rootDecorator: RustCodegenDecorator, ) : CodegenContext( model, symbolProvider, serviceShape, protocol, settings, CodegenTarget.CLIENT, ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitor.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitor.kt index 80026b8659..95657353d9 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitor.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitor.kt @@ -80,7 +80,7 @@ class CodegenVisitor( model = codegenDecorator.transformModel(service, baseModel) symbolProvider = RustCodegenPlugin.baseSymbolProvider(model, service, symbolVisitorConfig) - codegenContext = ClientCodegenContext(model, symbolProvider, service, protocol, settings) + codegenContext = ClientCodegenContext(model, symbolProvider, service, protocol, settings, codegenDecorator) val clientPublicModules = setOf( RustModule.Error, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RustCodegenDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RustCodegenDecorator.kt index d5f4e4e9a3..1a09f82ff3 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RustCodegenDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RustCodegenDecorator.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.rust.codegen.client.smithy.endpoint.RulesEngineBuiltInResolver import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RustCrate @@ -69,6 +70,8 @@ interface RustCodegenDecorator { fun transformModel(service: ServiceShape, model: Model): Model = model + fun builtInResolvers(codegenContext: C): List = listOf() + fun supportsCodegenContext(clazz: Class): Boolean } @@ -141,6 +144,10 @@ open class CombinedCodegenDecorator(decorators: List { + return orderedDecorators.flatMap { it.builtInResolvers(codegenContext) } + } + override fun supportsCodegenContext(clazz: Class): Boolean = // `CombinedCodegenDecorator` can work with all types of codegen context. CodegenContext::class.java.isAssignableFrom(clazz) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextParamDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextParamDecorator.kt new file mode 100644 index 0000000000..2f47ad275d --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextParamDecorator.kt @@ -0,0 +1,97 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.endpoint + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ShapeType +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rulesengine.traits.ClientContextParamDefinition +import software.amazon.smithy.rulesengine.traits.ClientContextParamsTrait +import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization +import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.makeOptional +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.orNull +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase + +/** + * This decorator adds `ClientContextParams` to the service config. + * + * This handles injecting parameters like `s3::Accelerate` or `s3::ForcePathStyle` + */ +class ClientContextDecorator(ctx: CodegenContext) : ConfigCustomization() { + private val contextParams = ctx.serviceShape.getTrait()?.parameters.orEmpty().toList() + .map { (key, value) -> ContextParam.fromClientParam(key, value, ctx.symbolProvider) } + + data class ContextParam(val name: String, val type: Symbol, val docs: String?) { + companion object { + private fun toSymbol(shapeType: ShapeType, symbolProvider: RustSymbolProvider): Symbol = + symbolProvider.toSymbol( + when (shapeType) { + ShapeType.STRING -> StringShape.builder().id("smithy.api#String").build() + ShapeType.BOOLEAN -> BooleanShape.builder().id("smithy.api#Boolean").build() + else -> TODO("unsupported type") + }, + ) + + fun fromClientParam( + name: String, + definition: ClientContextParamDefinition, + symbolProvider: RustSymbolProvider, + ): ContextParam { + return ContextParam( + RustReservedWords.escapeIfNeeded(name.toSnakeCase()), + toSymbol(definition.type, symbolProvider), + definition.documentation.orNull(), + ) + } + } + } + + override fun section(section: ServiceConfig): Writable { + return when (section) { + is ServiceConfig.ConfigStruct -> writable { + contextParams.forEach { param -> + rust("pub (crate) ${param.name}: #T,", param.type.makeOptional()) + } + } + ServiceConfig.ConfigImpl -> emptySection + ServiceConfig.BuilderStruct -> writable { + contextParams.forEach { param -> + rust("${param.name}: #T,", param.type.makeOptional()) + } + } + ServiceConfig.BuilderImpl -> writable { + contextParams.forEach { param -> + param.docs?.also { docs(it) } + rust( + """ + pub fn ${param.name}(mut self, ${param.name}: impl Into<#T>) -> Self { + self.${param.name} = Some(${param.name}.into()); + self + } + """, + param.type, + ) + } + } + ServiceConfig.BuilderBuild -> writable { + contextParams.forEach { param -> + rust("${param.name}: self.${param.name},") + } + } + else -> emptySection + } + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointParamsGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointParamsGenerator.kt new file mode 100644 index 0000000000..9fa3f45d30 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointParamsGenerator.kt @@ -0,0 +1,293 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.endpoint + +import software.amazon.smithy.rulesengine.language.eval.Value +import software.amazon.smithy.rulesengine.language.syntax.Identifier +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.asDeref +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.isCopy +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.Clone +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.Debug +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.Default +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.PartialEq +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.makeOptional +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.orNull + +// TODO(https://github.com/awslabs/smithy-rs/issues/1927): When endpoint resolution is implemented, remove doc-hidden +/** + * The module containing all endpoint resolution machinery. Module layout: + * ``` + * crate::endpoints:: + * struct Params // Endpoint parameter struct + * struct ParamsBuilder // Builder for Params + * enum InvalidParams + * DefaultResolver // struct implementing the endpoint resolver based on the provided rules for the service + * internal // private module containing the endpoints library functions, the private version of the default resolver + * endpoints_lib::{endpoints_fn*, ...} + * fn default_resolver(params: &Params, partition_metadata: &PartitionMetadata, error_collector: &mut ErrorCollector) + * ``` + */ +val EndpointsModule = RustModule.public("endpoint", "Endpoint resolution functionality") + .copy(rustMetadata = RustMetadata(additionalAttributes = listOf(Attribute.DocHidden), visibility = Visibility.PUBLIC)) + +/** Endpoint Parameters generator. + * + * This class generates the `Params` struct for an [EndpointRuleset]. The struct has `pub(crate)` fields, a `Builder`, + * and an error type, `InvalidParams` that is created to handle when construction fails. + * + * The builder of this struct generates a fallible `build()` method because endpoint params MAY have required fields. + * However, the external parts of this struct (the public accessors) will _always_ be optional to ensure a public + * interface is maintained. + * + * The following snippet contains an example of what is generated (eliding the error): + * ```rust + * #[non_exhaustive] + * #[derive(std::clone::Clone, std::cmp::PartialEq, std::fmt::Debug)] + * /// Configuration parameters for resolving the correct endpoint + * pub struct Params { + * pub(crate) region: std::option::Option, + * } + * impl Params { + * /// Create a builder for [`Params`] + * pub fn builder() -> crate::endpoint::ParamsBuilder { + * crate::endpoint::Builder::default() + * } + * /// Gets the value for region + * pub fn region(&self) -> std::option::Option<&str> { + * self.region.as_deref() + * } + * } + * + * /// Builder for [`Params`] + * #[derive(std::default::Default, std::clone::Clone, std::cmp::PartialEq, std::fmt::Debug)] + * pub struct ParamsBuilder { + * region: std::option::Option, + * } + * impl ParamsBuilder { + * /// Consume this builder, creating [`Params`]. + * pub fn build( + * self, + * ) -> Result { + * Ok(crate::endpoint::Params { + * region: self.region, + * }) + * } + * + * /// Sets the value for region + * pub fn region(mut self, value: std::string::String) -> Self { + * self.region = Some(value); + * self + * } + * + * /// Sets the value for region + * pub fn set_region(mut self, param: Option>) -> Self { + * self.region = param.map(|t| t.into()); + * self + * } + * } + * ``` + */ + +class EndpointParamsGenerator(private val parameters: Parameters) { + + companion object { + fun memberName(parameterName: String) = Identifier.of(parameterName).rustName() + fun setterName(parameterName: String) = "set_${memberName(parameterName)}" + } + + fun paramsStruct(): RuntimeType = RuntimeType.forInlineFun("Params", EndpointsModule) { + generateEndpointsStruct(this) + } + + private fun endpointsBuilder(): RuntimeType = RuntimeType.forInlineFun("ParamsBuilder", EndpointsModule) { + generateEndpointParamsBuilder(this) + } + + private fun paramsError(): RuntimeType = RuntimeType.forInlineFun("InvalidParams", EndpointsModule) { + rust( + """ + /// An error that occurred during endpoint resolution + ##[derive(Debug)] + pub struct InvalidParams { + field: std::borrow::Cow<'static, str> + } + + impl InvalidParams { + ##[allow(dead_code)] + fn missing(field: &'static str) -> Self { + Self { field: field.into() } + } + } + + impl std::fmt::Display for InvalidParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "a required field was missing: `{}`", self.field) + } + } + + impl std::error::Error for InvalidParams { } + """, + ) + } + + /** + * Generates an endpoints struct based on the provided endpoint rules. The struct fields are `pub(crate)` + * with optionality as indicated by the required status of the parameter. + */ + private fun generateEndpointsStruct(writer: RustWriter) { + // Ensure that fields can be added in the future + Attribute.NonExhaustive.render(writer) + // Automatically implement standard Rust functionality + Attribute.Derives(setOf(Debug, PartialEq, Clone)).render(writer) + // Generate the struct block: + /* + pub struct Params { + ... members: pub(crate) field + } + */ + writer.docs("Configuration parameters for resolving the correct endpoint") + writer.rustBlock("pub struct Params") { + parameters.toList().forEach { parameter -> + // Render documentation for each parameter + parameter.documentation.orNull()?.also { docs(it) } + rust("pub(crate) ${parameter.memberName()}: #T,", parameter.symbol()) + } + } + + // Generate the impl block for the struct + writer.rustBlock("impl Params") { + rustTemplate( + """ + /// Create a builder for [`Params`] + pub fn builder() -> #{Builder} { + #{Builder}::default() + } + """, + "Builder" to endpointsBuilder(), + ) + parameters.toList().forEach { parameter -> + val name = parameter.memberName() + val type = parameter.symbol() + + (parameter.documentation.orNull() ?: "Gets the value for `$name`").also { docs(it) } + rustTemplate( + """ + pub fn ${parameter.memberName()}(&self) -> #{paramType} { + #{param:W} + } + + """, + "paramType" to type.makeOptional().mapRustType { t -> t.asDeref() }, + "param" to writable { + when { + type.isOptional() && type.rustType().isCopy() -> rust("self.$name") + type.isOptional() -> rust("self.$name.as_deref()") + type.rustType().isCopy() -> rust("Some(self.$name)") + else -> rust("Some(&self.$name)") + } + }, + ) + } + } + } + + private fun value(value: Value): String { + return when (value) { + is Value.String -> value.value().dq() + ".to_string()" + is Value.Bool -> value.expectBool().toString() + else -> TODO("unexpected type: $value") + } + } + + private fun generateEndpointParamsBuilder(rustWriter: RustWriter) { + rustWriter.docs("Builder for [`Params`]") + Attribute.Derives(setOf(Debug, Default, PartialEq, Clone)).render(rustWriter) + rustWriter.rustBlock("pub struct ParamsBuilder") { + parameters.toList().forEach { parameter -> + val name = parameter.memberName() + val type = parameter.symbol().makeOptional() + rust("$name: #T,", type) + } + } + + rustWriter.rustBlock("impl ParamsBuilder") { + docs("Consume this builder, creating [`Params`].") + rustBlockTemplate( + "pub fn build(self) -> Result<#{Params}, #{ParamsError}>", + "Params" to paramsStruct(), + "ParamsError" to paramsError(), + ) { + val params = writable { + rustBlockTemplate("#{Params}", "Params" to paramsStruct()) { + parameters.toList().forEach { parameter -> + rust("${parameter.memberName()}: self.${parameter.memberName()}") + parameter.default.orNull()?.also { default -> rust(".or(Some(${value(default)}))") } + if (parameter.isRequired) { + rustTemplate( + ".ok_or_else(||#{Error}::missing(${parameter.memberName().dq()}))?", + "Error" to paramsError(), + ) + } + rust(",") + } + } + } + rust("Ok(#W)", params) + } + parameters.toList().forEach { parameter -> + val name = parameter.memberName() + check(name == memberName(parameter.name.toString())) + check("set_$name" == setterName(parameter.name.toString())) + val type = parameter.symbol().mapRustType { t -> t.stripOuter() } + rustTemplate( + """ + /// Sets the value for $name #{extraDocs:W} + pub fn $name(mut self, value: impl Into<#{type}>) -> Self { + self.$name = Some(value.into()); + self + } + + /// Sets the value for $name #{extraDocs:W} + pub fn set_$name(mut self, param: Option<#{nonOptionalType}>) -> Self { + self.$name = param; + self + } + """, + "nonOptionalType" to parameter.symbol().mapRustType { it.stripOuter() }, + "type" to type, + "extraDocs" to writable { + if (parameter.default.isPresent || parameter.documentation.isPresent) { + docs("") + } + parameter.default.orNull()?.also { + docs("When unset, this parameter has a default value of `$it`.") + } + parameter.documentation.orNull()?.also { docs(it) } + }, + ) + } + } + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointRulesetIndex.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointRulesetIndex.kt new file mode 100644 index 0000000000..b0e0b09e2f --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointRulesetIndex.kt @@ -0,0 +1,29 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.endpoint + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.KnowledgeIndex +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.rulesengine.language.EndpointRuleSet +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait +import software.amazon.smithy.rust.codegen.core.util.getTrait +import java.util.concurrent.ConcurrentHashMap + +class EndpointRulesetIndex(model: Model) : KnowledgeIndex { + + private val rulesets: ConcurrentHashMap = ConcurrentHashMap() + + fun endpointRulesForService(serviceShape: ServiceShape) = rulesets.computeIfAbsent( + serviceShape, + ) { serviceShape.getTrait()?.ruleSet?.let { EndpointRuleSet.fromNode(it) } } + + companion object { + fun of(model: Model): EndpointRulesetIndex { + return model.getKnowledge(EndpointRulesetIndex::class.java) { EndpointRulesetIndex(it) } + } + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecorator.kt new file mode 100644 index 0000000000..e033775267 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecorator.kt @@ -0,0 +1,168 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.endpoint + +import software.amazon.smithy.model.node.BooleanNode +import software.amazon.smithy.model.node.StringNode +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeType +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters +import software.amazon.smithy.rulesengine.traits.ContextIndex +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization +import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection +import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.orNull + +/** + * BuiltInResolver enables potentially external codegen stages to provide sources for `builtIn` parameters. + * For example, this allows AWS to provide the value for the region builtIn in separate codegen. + * + * If this resolver does not recognize the value, it MUST return `null`. + */ +interface RulesEngineBuiltInResolver { + fun defaultFor(parameter: Parameter, configRef: String): Writable? +} + +class EndpointsDecorator : RustCodegenDecorator { + override val name: String = "Endpoints" + override val order: Byte = 0 + + override fun supportsCodegenContext(clazz: Class): Boolean = + clazz.isAssignableFrom(ClientCodegenContext::class.java) + + override fun operationCustomizations( + codegenContext: ClientCodegenContext, + operation: OperationShape, + baseCustomizations: List, + ): List { + return baseCustomizations + CreateEndpointParams( + codegenContext, + operation, + codegenContext.rootDecorator.builtInResolvers(codegenContext), + ) + } + + override fun configCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List { + return baseCustomizations + ClientContextDecorator(codegenContext) + } +} + +/** + * Creates an `::endpoint_resolver::Params` structure in make operation generator. This combines state from the + * client, the operation, and the model to create parameters. + * + * Example generated code: + * ```rust + * let _endpoint_params = crate::endpoint_resolver::Params::builder() + * .set_region(Some("test-region")) + * .set_disable_everything(Some(true)) + * .set_bucket(input.bucket.as_ref()) + * .build(); + * ``` + */ +class CreateEndpointParams( + private val ctx: ClientCodegenContext, + private val operationShape: OperationShape, + private val rulesEngineBuiltInResolvers: List, +) : + OperationCustomization() { + + private val runtimeConfig = ctx.runtimeConfig + private val params = + EndpointRulesetIndex.of(ctx.model).endpointRulesForService(ctx.serviceShape)?.parameters + private val idx = ContextIndex.of(ctx.model) + + override fun section(section: OperationSection): Writable { + // if we don't have any parameters, then we have no rules, don't bother + if (params == null) { + return emptySection + } + val codegenScope = arrayOf( + "Params" to EndpointParamsGenerator(params).paramsStruct(), + "BuildError" to runtimeConfig.operationBuildError(), + ) + return when (section) { + is OperationSection.MutateInput -> writable { + rustTemplate( + """ + let endpoint_params = #{Params}::builder()#{builderFields:W}.build(); + """, + "builderFields" to builderFields(params, section), + *codegenScope, + ) + } + + is OperationSection.MutateRequest -> writable { + // insert the endpoint resolution _result_ into the bag (note that this won't bail if endpoint + // resolution failed) + // this is temporary—in the long term, we will insert the endpoint into the bag directly, but this makes + // it testable + rustTemplate("${section.request}.properties_mut().insert(endpoint_params);") + } + + else -> emptySection + } + } + + private fun builderFields(params: Parameters, section: OperationSection.MutateInput) = writable { + val memberParams = idx.getContextParams(operationShape) + val builtInParams = params.toList().filter { it.isBuiltIn } + // first load builtins and their defaults + builtInParams.forEach { param -> + val defaultProviders = rulesEngineBuiltInResolvers.mapNotNull { it.defaultFor(param, section.config) } + if (defaultProviders.size > 1) { + error("Multiple providers provided a value for the builtin $param") + } + defaultProviders.firstOrNull()?.also { defaultValue -> + rust(".set_${param.name.rustName()}(#W)", defaultValue) + } + } + + idx.getClientContextParams(ctx.serviceShape).orNull()?.parameters?.forEach { (name, param) -> + val paramName = EndpointParamsGenerator.memberName(name) + val setterName = EndpointParamsGenerator.setterName(name) + if (param.type == ShapeType.BOOLEAN) { + rust(".$setterName(${section.config}.$paramName)") + } else { + rust(".$setterName(${section.config}.$paramName.clone())") + } + } + + idx.getStaticContextParams(operationShape).orNull()?.parameters?.forEach { (name, param) -> + val setterName = EndpointParamsGenerator.setterName(name) + val value = writable { + when (val v = param.value) { + is BooleanNode -> rust("Some(${v.value})") + is StringNode -> rust("Some(${v.value.dq()}.to_string())") + else -> TODO("Unexpected static value type: $v") + } + } + rust(".$setterName(#W)", value) + } + + // lastly, allow these to be overridden by members + memberParams.forEach { (memberShape, param) -> + val memberName = ctx.symbolProvider.toMemberName(memberShape) + rust( + ".${EndpointParamsGenerator.setterName(param.name)}(${section.input}.$memberName.clone())", + ) + } + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/Util.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/Util.kt new file mode 100644 index 0000000000..3a2a94ed3e --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/Util.kt @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.endpoint + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.rulesengine.language.syntax.Identifier +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType +import software.amazon.smithy.rulesengine.traits.ContextParamTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.makeOptional +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase + +/** + * Utility function to convert an [Identifier] into a valid Rust identifier (snake case) + */ +fun Identifier.rustName(): String { + return this.toString().stringToRustName() +} + +private fun String.stringToRustName(): String = RustReservedWords.escapeIfNeeded(this.toSnakeCase()) + +/** + * Returns the memberName() for a given [Parameter] + */ +fun Parameter.memberName(): String { + return name.rustName() +} + +fun ContextParamTrait.memberName(): String = this.name.stringToRustName() + +/** + * Returns the symbol for a given parameter. This enables [RustWriter] to generate the correct [RustType]. + */ +fun Parameter.symbol(): Symbol { + val rustType = when (this.type) { + ParameterType.STRING -> RustType.String + ParameterType.BOOLEAN -> RustType.Bool + else -> TODO("unexpected type: ${this.type}") + } + // Parameter return types are always optional + return Symbol.builder().rustType(rustType).build().letIf(!this.isRequired) { it.makeOptional() } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/HttpVersionListGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/HttpVersionListGeneratorTest.kt index 3a20dbd9e5..ae0663290e 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/HttpVersionListGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/HttpVersionListGeneratorTest.kt @@ -207,7 +207,6 @@ class FakeSigningDecorator : RustCodegenDecorator, ): List { - println(baseCustomizations) return baseCustomizations.filterNot { it is EventStreamSigningConfig } + FakeSigningConfig(codegenContext.runtimeConfig) } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/ClientContextParamsDecoratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/ClientContextParamsDecoratorTest.kt new file mode 100644 index 0000000000..bda2d80817 --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/ClientContextParamsDecoratorTest.kt @@ -0,0 +1,57 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.endpoint + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.client.smithy.endpoint.ClientContextDecorator +import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext +import software.amazon.smithy.rust.codegen.client.testutil.validateConfigCustomizations +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.unitTest + +class ClientContextParamsDecoratorTest { + val model = """ + namespace test + use smithy.rules#clientContextParams + + @clientContextParams(aStringParam: { + documentation: "string docs", + type: "string" + }, + aBoolParam: { + documentation: "bool docs", + type: "boolean" + }) + service TestService { operations: [] } + """.asSmithyModel() + + @Test + fun `client params generate a valid customization`() { + val project = TestWorkspace.testProject() + project.unitTest { + rust( + """ + let conf = crate::Config::builder().a_string_param("hello!").a_bool_param(true).build(); + assert_eq!(conf.a_string_param.unwrap(), "hello!"); + assert_eq!(conf.a_bool_param, Some(true)); + """, + ) + } + // unset fields + project.unitTest { + rust( + """ + let conf = crate::Config::builder().a_string_param("hello!").build(); + assert_eq!(conf.a_string_param.unwrap(), "hello!"); + assert_eq!(conf.a_bool_param, None); + """, + ) + } + validateConfigCustomizations(ClientContextDecorator(testCodegenContext(model)), project) + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointParamsGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointParamsGeneratorTest.kt new file mode 100644 index 0000000000..e070e8a6f9 --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointParamsGeneratorTest.kt @@ -0,0 +1,41 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.endpoint + +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import software.amazon.smithy.rulesengine.testutil.TestDiscovery +import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointParamsGenerator +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import java.util.stream.Stream + +internal class EndpointParamsGeneratorTest { + companion object { + @JvmStatic + fun testSuites(): Stream = TestDiscovery().testSuites() + } + + @ParameterizedTest() + @MethodSource("testSuites") + fun `generate endpoint params for provided test suites`(testSuite: TestDiscovery.RulesTestSuite) { + val project = TestWorkspace.testProject() + project.lib { + unitTest("params_work") { + rustTemplate( + """ + // this might fail if there are required fields + let _ = #{Params}::builder().build(); + """, + "Params" to EndpointParamsGenerator(testSuite.ruleSet().parameters).paramsStruct(), + ) + } + } + project.compileAndTest() + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointsDecoratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointsDecoratorTest.kt new file mode 100644 index 0000000000..0babf3354e --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointsDecoratorTest.kt @@ -0,0 +1,97 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.endpoint + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsDecorator +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.TokioTest +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.integrationTest + +class EndpointsDecoratorTest { + + val model = """ + namespace test + + use smithy.rules#endpointRuleSet + use smithy.rules#endpointTests + + use smithy.rules#clientContextParams + use smithy.rules#staticContextParams + use smithy.rules#contextParam + use aws.protocols#awsJson1_1 + + @awsJson1_1 + @endpointRuleSet({ + "version": "1.0", + "rules": [], + "parameters": { + "Bucket": { "required": false, "type": "String" }, + "Region": { "required": false, "type": "String", "builtIn": "AWS::Region" }, + "AStringParam": { "required": false, "type": "String" }, + "ABoolParam": { "required": false, "type": "Boolean" } + } + }) + @clientContextParams(AStringParam: { + documentation: "string docs", + type: "string" + }, + aBoolParam: { + documentation: "bool docs", + type: "boolean" + }) + service TestService { + operations: [TestOperation] + } + + @staticContextParams(Region: { value: "us-east-2" }) + operation TestOperation { + input: TestOperationInput + } + + structure TestOperationInput { + @contextParam(name: "Bucket") + bucket: String + } + """.asSmithyModel() + + // NOTE: this test will fail once the endpoint starts being added directly (unless we preserve endpoint params in the + // property bag. + @Test + fun `add endpoint params to the property bag`() { + clientIntegrationTest(model, addtionalDecorators = listOf(EndpointsDecorator())) { clientCodegenContext, rustCrate -> + rustCrate.integrationTest("endpoint_params_test") { + val moduleName = clientCodegenContext.moduleUseName() + TokioTest.render(this) + rust( + """ + async fn endpoint_params_are_set() { + let conf = $moduleName::Config::builder().a_string_param("hello").a_bool_param(false).build(); + let operation = $moduleName::operation::TestOperation::builder() + .bucket("bucket-name").build().expect("input is valid") + .make_operation(&conf).await.expect("valid operation"); + use $moduleName::endpoint::{Params, InvalidParams}; + let props = operation.properties(); + let endpoint_params = props.get::>().unwrap(); + assert_eq!( + endpoint_params.as_ref().expect("ok"), + &Params::builder() + .bucket("bucket-name".to_string()) + .a_bool_param(false) + .a_string_param("hello".to_string()) + .region("us-east-2".to_string()) + .build().unwrap() + ); + } + + """, + ) + } + } + } +}