diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index aa64869f7e14..ccb340c6f5d6 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -4,10 +4,9 @@ package s3 import ( - "context" "encoding/base64" - "errors" "fmt" + "os" "strings" "github.com/aws/aws-sdk-go/aws" @@ -15,339 +14,384 @@ import ( "github.com/aws/aws-sdk-go/service/s3" awsbase "github.com/hashicorp/aws-sdk-go-base" "github.com/hashicorp/terraform/internal/backend" - "github.com/hashicorp/terraform/internal/legacy/helper/schema" + "github.com/hashicorp/terraform/internal/configs/configschema" "github.com/hashicorp/terraform/internal/logging" + "github.com/hashicorp/terraform/internal/tfdiags" "github.com/hashicorp/terraform/version" + "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/gocty" ) -// New creates a new backend for S3 remote state. func New() backend.Backend { - s := &schema.Backend{ - Schema: map[string]*schema.Schema{ + return &Backend{} +} + +type Backend struct { + s3Client *s3.S3 + dynClient *dynamodb.DynamoDB + + bucketName string + keyName string + serverSideEncryption bool + customerEncryptionKey []byte + acl string + kmsKeyID string + ddbTable string + workspaceKeyPrefix string +} + +// ConfigSchema returns a description of the expected configuration +// structure for the receiving backend. +func (b *Backend) ConfigSchema() *configschema.Block { + return &configschema.Block{ + Attributes: map[string]*configschema.Attribute{ "bucket": { - Type: schema.TypeString, + Type: cty.String, Required: true, Description: "The name of the S3 bucket", }, - "key": { - Type: schema.TypeString, + Type: cty.String, Required: true, Description: "The path to the state file inside the bucket", - ValidateFunc: func(v interface{}, s string) ([]string, []error) { - // s3 will strip leading slashes from an object, so while this will - // technically be accepted by s3, it will break our workspace hierarchy. - if strings.HasPrefix(v.(string), "/") { - return nil, []error{errors.New("key must not start with '/'")} - } - // s3 will recognize objects with a trailing slash as a directory - // so they should not be valid keys - if strings.HasSuffix(v.(string), "/") { - return nil, []error{errors.New("key must not end with '/'")} - } - return nil, nil - }, }, - "region": { - Type: schema.TypeString, - Required: true, + Type: cty.String, + Optional: true, Description: "AWS region of the S3 Bucket and DynamoDB Table (if used).", - DefaultFunc: schema.MultiEnvDefaultFunc([]string{ - "AWS_REGION", - "AWS_DEFAULT_REGION", - }, nil), }, - "dynamodb_endpoint": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "A custom endpoint for the DynamoDB API", - DefaultFunc: schema.EnvDefaultFunc("AWS_DYNAMODB_ENDPOINT", ""), }, - "endpoint": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "A custom endpoint for the S3 API", - DefaultFunc: schema.EnvDefaultFunc("AWS_S3_ENDPOINT", ""), }, - "iam_endpoint": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "A custom endpoint for the IAM API", - DefaultFunc: schema.EnvDefaultFunc("AWS_IAM_ENDPOINT", ""), }, - "sts_endpoint": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "A custom endpoint for the STS API", - DefaultFunc: schema.EnvDefaultFunc("AWS_STS_ENDPOINT", ""), }, - "encrypt": { - Type: schema.TypeBool, + Type: cty.Bool, Optional: true, Description: "Whether to enable server side encryption of the state file", - Default: false, }, - "acl": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "Canned ACL to be applied to the state file", - Default: "", }, - "access_key": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "AWS access key", - Default: "", }, - "secret_key": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "AWS secret key", - Default: "", }, - "kms_key_id": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The ARN of a KMS Key to use for encrypting the state", - Default: "", }, - "dynamodb_table": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "DynamoDB table for state locking and consistency", - Default: "", }, - "profile": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "AWS profile name", - Default: "", }, - "shared_credentials_file": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "Path to a shared credentials file", - Default: "", }, - "token": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "MFA token", - Default: "", }, - "skip_credentials_validation": { - Type: schema.TypeBool, + Type: cty.Bool, Optional: true, Description: "Skip the credentials validation via STS API.", - Default: false, - }, - - "skip_region_validation": { - Type: schema.TypeBool, - Optional: true, - Description: "Skip static validation of region name.", - Default: false, }, - "skip_metadata_api_check": { - Type: schema.TypeBool, + Type: cty.Bool, Optional: true, Description: "Skip the AWS Metadata API check.", - Default: false, }, - + "skip_region_validation": { + Type: cty.Bool, + Optional: true, + Description: "Skip static validation of region name.", + }, "sse_customer_key": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The base64-encoded encryption key to use for server-side encryption with customer-provided keys (SSE-C).", - DefaultFunc: schema.EnvDefaultFunc("AWS_SSE_CUSTOMER_KEY", ""), Sensitive: true, - ValidateFunc: func(v interface{}, s string) ([]string, []error) { - key := v.(string) - if key != "" && len(key) != 44 { - return nil, []error{errors.New("sse_customer_key must be 44 characters in length (256 bits, base64 encoded)")} - } - return nil, nil - }, }, - "role_arn": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The role to be assumed", - Default: "", }, - "session_name": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The session name to use when assuming the role.", - Default: "", }, - "external_id": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The external ID to use when assuming the role", - Default: "", }, "assume_role_duration_seconds": { - Type: schema.TypeInt, + Type: cty.Number, Optional: true, Description: "Seconds to restrict the assume role session duration.", }, "assume_role_policy": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "IAM Policy JSON describing further restricting permissions for the IAM Role being assumed.", - Default: "", }, "assume_role_policy_arns": { - Type: schema.TypeSet, + Type: cty.Set(cty.String), Optional: true, Description: "Amazon Resource Names (ARNs) of IAM Policies describing further restricting permissions for the IAM Role being assumed.", - Elem: &schema.Schema{Type: schema.TypeString}, }, "assume_role_tags": { - Type: schema.TypeMap, + Type: cty.Map(cty.String), Optional: true, Description: "Assume role session tags.", - Elem: &schema.Schema{Type: schema.TypeString}, }, "assume_role_transitive_tag_keys": { - Type: schema.TypeSet, + Type: cty.Set(cty.String), Optional: true, Description: "Assume role session tag keys to pass to any subsequent sessions.", - Elem: &schema.Schema{Type: schema.TypeString}, }, "workspace_key_prefix": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The prefix applied to the non-default state path inside the bucket.", - Default: "env:", - ValidateFunc: func(v interface{}, s string) ([]string, []error) { - prefix := v.(string) - if strings.HasPrefix(prefix, "/") || strings.HasSuffix(prefix, "/") { - return nil, []error{errors.New("workspace_key_prefix must not start or end with '/'")} - } - return nil, nil - }, }, "force_path_style": { - Type: schema.TypeBool, + Type: cty.Bool, Optional: true, Description: "Force s3 to use path style api.", - Default: false, }, "max_retries": { - Type: schema.TypeInt, + Type: cty.Number, Optional: true, Description: "The maximum number of times an AWS API request is retried on retryable failure.", - Default: 5, }, }, } - - result := &Backend{Backend: s} - result.Backend.ConfigureFunc = result.configure - return result } -type Backend struct { - *schema.Backend +// PrepareConfig checks the validity of the values in the given +// configuration, and inserts any missing defaults, assuming that its +// structure has already been validated per the schema returned by +// ConfigSchema. +func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) { + var diags tfdiags.Diagnostics + if obj.IsNull() { + return obj, diags + } - // The fields below are set from configure - s3Client *s3.S3 - dynClient *dynamodb.DynamoDB + if val := obj.GetAttr("bucket"); val.IsNull() || val.AsString() == "" { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid bucket value", + `The "bucket" attribute value must not be empty.`, + cty.Path{cty.GetAttrStep{Name: "bucket"}}, + )) + } - bucketName string - keyName string - serverSideEncryption bool - customerEncryptionKey []byte - acl string - kmsKeyID string - ddbTable string - workspaceKeyPrefix string -} + if val := obj.GetAttr("key"); val.IsNull() || val.AsString() == "" { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid key value", + `The "key" attribute value must not be empty.`, + cty.Path{cty.GetAttrStep{Name: "key"}}, + )) + } else if strings.HasPrefix(val.AsString(), "/") || strings.HasSuffix(val.AsString(), "/") { + // S3 will strip leading slashes from an object, so while this will + // technically be accepted by S3, it will break our workspace hierarchy. + // S3 will recognize objects with a trailing slash as a directory + // so they should not be valid keys + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid key value", + `The "key" attribute value must not start or end with with "/".`, + cty.Path{cty.GetAttrStep{Name: "key"}}, + )) + } -func (b *Backend) configure(ctx context.Context) error { - if b.s3Client != nil { - return nil + if val := obj.GetAttr("region"); val.IsNull() || val.AsString() == "" { + if os.Getenv("AWS_REGION") == "" && os.Getenv("AWS_DEFAULT_REGION") == "" { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Missing region value", + `The "region" attribute or the "AWS_REGION" or "AWS_DEFAULT_REGION" environment variables must be set.`, + cty.Path{cty.GetAttrStep{Name: "region"}}, + )) + } } - // Grab the resource data - data := schema.FromContextBackendConfig(ctx) + if val := obj.GetAttr("kms_key_id"); !val.IsNull() && val.AsString() != "" { + if val := obj.GetAttr("sse_customer_key"); !val.IsNull() && val.AsString() != "" { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid encryption configuration", + encryptionKeyConflictError, + cty.Path{}, + )) + } else if customerKey := os.Getenv("AWS_SSE_CUSTOMER_KEY"); customerKey != "" { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid encryption configuration", + encryptionKeyConflictEnvVarError, + cty.Path{}, + )) + } + + diags = diags.Append(validateKMSKey(cty.Path{cty.GetAttrStep{Name: "kms_key_id"}}, val.AsString())) + } - if !data.Get("skip_region_validation").(bool) { - if err := awsbase.ValidateRegion(data.Get("region").(string)); err != nil { - return err + if val := obj.GetAttr("workspace_key_prefix"); !val.IsNull() { + if v := val.AsString(); strings.HasPrefix(v, "/") || strings.HasSuffix(v, "/") { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid workspace_key_prefix value", + `The "workspace_key_prefix" attribute value must not start with "/".`, + cty.Path{cty.GetAttrStep{Name: "workspace_key_prefix"}}, + )) } } - b.bucketName = data.Get("bucket").(string) - b.keyName = data.Get("key").(string) - b.acl = data.Get("acl").(string) - b.workspaceKeyPrefix = data.Get("workspace_key_prefix").(string) - b.serverSideEncryption = data.Get("encrypt").(bool) - b.kmsKeyID = data.Get("kms_key_id").(string) - b.ddbTable = data.Get("dynamodb_table").(string) - - customerKeyString := data.Get("sse_customer_key").(string) - if customerKeyString != "" { - if b.kmsKeyID != "" { - return errors.New(encryptionKeyConflictError) + return obj, diags +} + +// Configure uses the provided configuration to set configuration fields +// within the backend. +// +// The given configuration is assumed to have already been validated +// against the schema returned by ConfigSchema and passed validation +// via PrepareConfig. +func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { + var diags tfdiags.Diagnostics + if obj.IsNull() { + return diags + } + + var region string + if v, ok := stringAttrOk(obj, "region"); ok { + region = v + } + + if region != "" && !boolAttr(obj, "skip_region_validation") { + if err := awsbase.ValidateRegion(region); err != nil { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid region value", + err.Error(), + cty.Path{cty.GetAttrStep{Name: "region"}}, + )) + return diags } + } - var err error - b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKeyString) - if err != nil { - return fmt.Errorf("Failed to decode sse_customer_key: %s", err.Error()) + b.bucketName = stringAttr(obj, "bucket") + b.keyName = stringAttr(obj, "key") + b.acl = stringAttr(obj, "acl") + b.workspaceKeyPrefix = stringAttrDefault(obj, "workspace_key_prefix", "env:") + b.serverSideEncryption = boolAttr(obj, "encrypt") + b.kmsKeyID = stringAttr(obj, "kms_key_id") + b.ddbTable = stringAttr(obj, "dynamodb_table") + + if customerKey, ok := stringAttrOk(obj, "sse_customer_key"); ok { + if len(customerKey) != 44 { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid sse_customer_key value", + "sse_customer_key must be 44 characters in length", + cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, + )) + } else { + var err error + if b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKey); err != nil { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid sse_customer_key value", + fmt.Sprintf("sse_customer_key must be base64 encoded: %s", err), + cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, + )) + } + } + } else if customerKey := os.Getenv("AWS_SSE_CUSTOMER_KEY"); customerKey != "" { + if len(customerKey) != 44 { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid AWS_SSE_CUSTOMER_KEY value", + `The environment variable "AWS_SSE_CUSTOMER_KEY" must be 44 characters in length`, + )) + } else { + var err error + if b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKey); err != nil { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid AWS_SSE_CUSTOMER_KEY value", + fmt.Sprintf(`The environment variable "AWS_SSE_CUSTOMER_KEY" must be base64 encoded: %s`, err), + )) + } } } cfg := &awsbase.Config{ - AccessKey: data.Get("access_key").(string), - AssumeRoleARN: data.Get("role_arn").(string), - AssumeRoleDurationSeconds: data.Get("assume_role_duration_seconds").(int), - AssumeRoleExternalID: data.Get("external_id").(string), - AssumeRolePolicy: data.Get("assume_role_policy").(string), - AssumeRoleSessionName: data.Get("session_name").(string), + AccessKey: stringAttr(obj, "access_key"), + AssumeRoleARN: stringAttr(obj, "role_arn"), + AssumeRoleDurationSeconds: intAttr(obj, "assume_role_duration_seconds"), + AssumeRoleExternalID: stringAttr(obj, "external_id"), + AssumeRolePolicy: stringAttr(obj, "assume_role_policy"), + AssumeRoleSessionName: stringAttr(obj, "session_name"), CallerDocumentationURL: "https://www.terraform.io/docs/language/settings/backends/s3.html", CallerName: "S3 Backend", - CredsFilename: data.Get("shared_credentials_file").(string), + CredsFilename: stringAttr(obj, "shared_credentials_file"), DebugLogging: logging.IsDebugOrHigher(), - IamEndpoint: data.Get("iam_endpoint").(string), - MaxRetries: data.Get("max_retries").(int), - Profile: data.Get("profile").(string), - Region: data.Get("region").(string), - SecretKey: data.Get("secret_key").(string), - SkipCredsValidation: data.Get("skip_credentials_validation").(bool), - SkipMetadataApiCheck: data.Get("skip_metadata_api_check").(bool), - StsEndpoint: data.Get("sts_endpoint").(string), - Token: data.Get("token").(string), + IamEndpoint: stringAttrDefaultEnvVar(obj, "iam_endpoint", "AWS_IAM_ENDPOINT"), + MaxRetries: intAttrDefault(obj, "max_retries", 5), + Profile: stringAttr(obj, "profile"), + Region: stringAttr(obj, "region"), + SecretKey: stringAttr(obj, "secret_key"), + SkipCredsValidation: boolAttr(obj, "skip_credentials_validation"), + SkipMetadataApiCheck: boolAttr(obj, "skip_metadata_api_check"), + StsEndpoint: stringAttrDefaultEnvVar(obj, "sts_endpoint", "AWS_STS_ENDPOINT"), + Token: stringAttr(obj, "token"), UserAgentProducts: []*awsbase.UserAgentProduct{ {Name: "APN", Version: "1.0"}, {Name: "HashiCorp", Version: "1.0"}, @@ -355,62 +399,162 @@ func (b *Backend) configure(ctx context.Context) error { }, } - if policyARNSet := data.Get("assume_role_policy_arns").(*schema.Set); policyARNSet.Len() > 0 { - for _, policyARNRaw := range policyARNSet.List() { - policyARN, ok := policyARNRaw.(string) + if policyARNSet := obj.GetAttr("assume_role_policy_arns"); !policyARNSet.IsNull() { + policyARNSet.ForEachElement(func(key, val cty.Value) (stop bool) { + v, ok := stringValueOk(val) + if ok { + cfg.AssumeRolePolicyARNs = append(cfg.AssumeRolePolicyARNs, v) + } + return + }) + } - if !ok { - continue + if tagMap := obj.GetAttr("assume_role_tags"); !tagMap.IsNull() { + cfg.AssumeRoleTags = make(map[string]string, tagMap.LengthInt()) + tagMap.ForEachElement(func(key, val cty.Value) (stop bool) { + k := stringValue(key) + v, ok := stringValueOk(val) + if ok { + cfg.AssumeRoleTags[k] = v } + return + }) + } - cfg.AssumeRolePolicyARNs = append(cfg.AssumeRolePolicyARNs, policyARN) - } + if transitiveTagKeySet := obj.GetAttr("assume_role_transitive_tag_keys"); !transitiveTagKeySet.IsNull() { + transitiveTagKeySet.ForEachElement(func(key, val cty.Value) (stop bool) { + v, ok := stringValueOk(val) + if ok { + cfg.AssumeRoleTransitiveTagKeys = append(cfg.AssumeRoleTransitiveTagKeys, v) + } + return + }) } - if tagMap := data.Get("assume_role_tags").(map[string]interface{}); len(tagMap) > 0 { - cfg.AssumeRoleTags = make(map[string]string) + sess, err := awsbase.GetSession(cfg) + if err != nil { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Failed to configure AWS client", + fmt.Sprintf(`The "S3" backend encountered an unexpected error while creating the AWS client: %s`, err), + )) + return diags + } - for k, vRaw := range tagMap { - v, ok := vRaw.(string) + var dynamoConfig aws.Config + if v, ok := stringAttrDefaultEnvVarOk(obj, "dynamodb_endpoint", "AWS_DYNAMODB_ENDPOINT"); ok { + dynamoConfig.Endpoint = aws.String(v) + } + b.dynClient = dynamodb.New(sess.Copy(&dynamoConfig)) - if !ok { - continue - } + var s3Config aws.Config + if v, ok := stringAttrDefaultEnvVarOk(obj, "endpoint", "AWS_S3_ENDPOINT"); ok { + s3Config.Endpoint = aws.String(v) + } + if v, ok := boolAttrOk(obj, "force_path_style"); ok { + s3Config.S3ForcePathStyle = aws.Bool(v) + } + b.s3Client = s3.New(sess.Copy(&s3Config)) - cfg.AssumeRoleTags[k] = v - } + return diags +} + +func stringValue(val cty.Value) string { + v, _ := stringValueOk(val) + return v +} + +func stringValueOk(val cty.Value) (string, bool) { + if val.IsNull() { + return "", false + } else { + return val.AsString(), true } +} - if transitiveTagKeySet := data.Get("assume_role_transitive_tag_keys").(*schema.Set); transitiveTagKeySet.Len() > 0 { - for _, transitiveTagKeyRaw := range transitiveTagKeySet.List() { - transitiveTagKey, ok := transitiveTagKeyRaw.(string) +func stringAttr(obj cty.Value, name string) string { + return stringValue(obj.GetAttr(name)) +} - if !ok { - continue - } +func stringAttrOk(obj cty.Value, name string) (string, bool) { + return stringValueOk(obj.GetAttr(name)) +} - cfg.AssumeRoleTransitiveTagKeys = append(cfg.AssumeRoleTransitiveTagKeys, transitiveTagKey) +func stringAttrDefault(obj cty.Value, name, def string) string { + if v, ok := stringAttrOk(obj, name); !ok { + return def + } else { + return v + } +} + +func stringAttrDefaultEnvVar(obj cty.Value, name string, envvars ...string) string { + if v, ok := stringAttrDefaultEnvVarOk(obj, name, envvars...); !ok { + return "" + } else { + return v + } +} + +func stringAttrDefaultEnvVarOk(obj cty.Value, name string, envvars ...string) (string, bool) { + if v, ok := stringAttrOk(obj, name); !ok { + for _, envvar := range envvars { + if v := os.Getenv(envvar); v != "" { + return v, true + } } + return "", false + } else { + return v, true } +} - sess, err := awsbase.GetSession(cfg) - if err != nil { - return fmt.Errorf("error configuring S3 Backend: %w", err) +func boolAttr(obj cty.Value, name string) bool { + v, _ := boolAttrOk(obj, name) + return v +} + +func boolAttrOk(obj cty.Value, name string) (bool, bool) { + if val := obj.GetAttr(name); val.IsNull() { + return false, false + } else { + return val.True(), true } +} - b.dynClient = dynamodb.New(sess.Copy(&aws.Config{ - Endpoint: aws.String(data.Get("dynamodb_endpoint").(string)), - })) - b.s3Client = s3.New(sess.Copy(&aws.Config{ - Endpoint: aws.String(data.Get("endpoint").(string)), - S3ForcePathStyle: aws.Bool(data.Get("force_path_style").(bool)), - })) +func intAttr(obj cty.Value, name string) int { + v, _ := intAttrOk(obj, name) + return v +} - return nil +func intAttrOk(obj cty.Value, name string) (int, bool) { + if val := obj.GetAttr(name); val.IsNull() { + return 0, false + } else { + var v int + if err := gocty.FromCtyValue(val, &v); err != nil { + return 0, false + } + return v, true + } +} + +func intAttrDefault(obj cty.Value, name string, def int) int { + if v, ok := intAttrOk(obj, name); !ok { + return def + } else { + return v + } } -const encryptionKeyConflictError = `Cannot have both kms_key_id and sse_customer_key set. +const encryptionKeyConflictError = `Only one of "kms_key_id" and "sse_customer_key" can be set. + +The "kms_key_id" is used for encryption with KMS-Managed Keys (SSE-KMS) +while "sse_customer_key" is used for encryption with customer-managed keys (SSE-C). +Please choose one or the other.` + +const encryptionKeyConflictEnvVarError = `Only one of "kms_key_id" and the environment variable "AWS_SSE_CUSTOMER_KEY" can be set. -The kms_key_id is used for encryption with KMS-Managed Keys (SSE-KMS) -while sse_customer_key is used for encryption with customer-managed keys (SSE-C). +The "kms_key_id" is used for encryption with KMS-Managed Keys (SSE-KMS) +while "AWS_SSE_CUSTOMER_KEY" is used for encryption with customer-managed keys (SSE-C). Please choose one or the other.` diff --git a/internal/backend/remote-state/s3/backend_state.go b/internal/backend/remote-state/s3/backend_state.go index add7d399d88c..89c2b08735ca 100644 --- a/internal/backend/remote-state/s3/backend_state.go +++ b/internal/backend/remote-state/s3/backend_state.go @@ -203,10 +203,6 @@ func (b *Backend) StateMgr(name string) (statemgr.Full, error) { return stateMgr, nil } -func (b *Backend) client() *RemoteClient { - return &RemoteClient{} -} - func (b *Backend) path(name string) string { if name == backend.DefaultStateName { return b.keyName diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index 92fe40b2c346..4e39e258cfc2 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -4,21 +4,27 @@ package s3 import ( + "encoding/base64" "fmt" "net/url" "os" "reflect" + "strings" "testing" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/s3" + "github.com/google/go-cmp/cmp" awsbase "github.com/hashicorp/aws-sdk-go-base" "github.com/hashicorp/terraform/internal/backend" + "github.com/hashicorp/terraform/internal/configs/configschema" "github.com/hashicorp/terraform/internal/configs/hcl2shim" "github.com/hashicorp/terraform/internal/states" "github.com/hashicorp/terraform/internal/states/remote" + "github.com/hashicorp/terraform/internal/tfdiags" + "github.com/zclconf/go-cty/cty" ) var ( @@ -59,6 +65,9 @@ func TestBackendConfig(t *testing.T) { if *b.s3Client.Config.Region != "us-west-1" { t.Fatalf("Incorrect region was populated") } + if *b.s3Client.Config.MaxRetries != 5 { + t.Fatalf("Default max_retries was not set") + } if b.bucketName != "tf-test" { t.Fatalf("Incorrect bucketName was populated") } @@ -66,6 +75,10 @@ func TestBackendConfig(t *testing.T) { t.Fatalf("Incorrect keyName was populated") } + checkClientEndpoint(t, b.s3Client.Config, "") + + checkClientEndpoint(t, b.dynClient.Config, "") + credentials, err := b.s3Client.Config.Credentials.Get() if err != nil { t.Fatalf("Error when requesting credentials") @@ -78,6 +91,311 @@ func TestBackendConfig(t *testing.T) { } } +func checkClientEndpoint(t *testing.T, config aws.Config, expected string) { + if a := aws.StringValue(config.Endpoint); a != expected { + t.Errorf("expected endpoint %q, got %q", expected, a) + } +} + +func TestBackendConfig_InvalidRegion(t *testing.T) { + testACC(t) + + cases := map[string]struct { + config map[string]any + expectedDiags tfdiags.Diagnostics + }{ + "with region validation": { + config: map[string]interface{}{ + "region": "nonesuch", + "bucket": "tf-test", + "key": "state", + "skip_credentials_validation": true, + }, + expectedDiags: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid region value", + `Invalid AWS Region: nonesuch`, + cty.Path{cty.GetAttrStep{Name: "region"}}, + ), + }, + }, + "skip region validation": { + config: map[string]interface{}{ + "region": "nonesuch", + "bucket": "tf-test", + "key": "state", + "skip_region_validation": true, + "skip_credentials_validation": true, + }, + expectedDiags: nil, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + b := New() + configSchema := populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(tc.config)) + + configSchema, diags := b.PrepareConfig(configSchema) + if len(diags) > 0 { + t.Fatal(diags.ErrWithWarnings()) + } + + confDiags := b.Configure(configSchema) + diags = diags.Append(confDiags) + + if diff := cmp.Diff(diags, tc.expectedDiags, cmp.Comparer(diagnosticComparer)); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + }) + } +} + +func TestBackendConfig_RegionEnvVar(t *testing.T) { + testACC(t) + config := map[string]interface{}{ + "bucket": "tf-test", + "key": "state", + } + + cases := map[string]struct { + vars map[string]string + }{ + "AWS_REGION": { + vars: map[string]string{ + "AWS_REGION": "us-west-1", + }, + }, + + "AWS_DEFAULT_REGION": { + vars: map[string]string{ + "AWS_DEFAULT_REGION": "us-west-1", + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + for k, v := range tc.vars { + os.Setenv(k, v) + } + t.Cleanup(func() { + for k := range tc.vars { + os.Unsetenv(k) + } + }) + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) + + if *b.s3Client.Config.Region != "us-west-1" { + t.Fatalf("Incorrect region was populated") + } + }) + } +} + +func TestBackendConfig_DynamoDBEndpoint(t *testing.T) { + testACC(t) + + cases := map[string]struct { + config map[string]any + vars map[string]string + expected string + }{ + "none": { + expected: "", + }, + "config": { + config: map[string]any{ + "dynamodb_endpoint": "dynamo.test", + }, + expected: "dynamo.test", + }, + "envvar": { + vars: map[string]string{ + "AWS_DYNAMODB_ENDPOINT": "dynamo.test", + }, + expected: "dynamo.test", + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + config := map[string]interface{}{ + "region": "us-west-1", + "bucket": "tf-test", + "key": "state", + } + + if tc.vars != nil { + for k, v := range tc.vars { + os.Setenv(k, v) + } + t.Cleanup(func() { + for k := range tc.vars { + os.Unsetenv(k) + } + }) + } + + if tc.config != nil { + for k, v := range tc.config { + config[k] = v + } + } + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) + + checkClientEndpoint(t, b.dynClient.Config, tc.expected) + }) + } +} + +func TestBackendConfig_S3Endpoint(t *testing.T) { + testACC(t) + + cases := map[string]struct { + config map[string]any + vars map[string]string + expected string + }{ + "none": { + expected: "", + }, + "config": { + config: map[string]any{ + "endpoint": "s3.test", + }, + expected: "s3.test", + }, + "envvar": { + vars: map[string]string{ + "AWS_S3_ENDPOINT": "s3.test", + }, + expected: "s3.test", + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + config := map[string]interface{}{ + "region": "us-west-1", + "bucket": "tf-test", + "key": "state", + } + + if tc.vars != nil { + for k, v := range tc.vars { + os.Setenv(k, v) + } + t.Cleanup(func() { + for k := range tc.vars { + os.Unsetenv(k) + } + }) + } + + if tc.config != nil { + for k, v := range tc.config { + config[k] = v + } + } + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) + + checkClientEndpoint(t, b.s3Client.Config, tc.expected) + }) + } +} + +func TestBackendConfig_STSEndpoint(t *testing.T) { + testACC(t) + + stsMocks := []*awsbase.MockEndpoint{ + { + Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: url.Values{ + "Action": []string{"AssumeRole"}, + "DurationSeconds": []string{"900"}, + "RoleArn": []string{awsbase.MockStsAssumeRoleArn}, + "RoleSessionName": []string{awsbase.MockStsAssumeRoleSessionName}, + "Version": []string{"2011-06-15"}, + }.Encode()}, + Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsAssumeRoleValidResponseBody, ContentType: "text/xml"}, + }, + { + Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: mockStsGetCallerIdentityRequestBody}, + Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsGetCallerIdentityValidResponseBody, ContentType: "text/xml"}, + }, + } + + cases := map[string]struct { + setConfig bool + setEnvVars bool + expectedDiags tfdiags.Diagnostics + }{ + "none": { + expectedDiags: tfdiags.Diagnostics{ + tfdiags.Sourceless( + tfdiags.Error, + "Failed to configure AWS client", + ``, + ), + }, + }, + "config": { + setConfig: true, + }, + "envvar": { + setEnvVars: true, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + config := map[string]interface{}{ + "region": "us-west-1", + "bucket": "tf-test", + "key": "state", + "role_arn": awsbase.MockStsAssumeRoleArn, + "session_name": awsbase.MockStsAssumeRoleSessionName, + } + + closeSts, mockStsSession, err := awsbase.GetMockedAwsApiSession("STS", stsMocks) + if err != nil { + t.Fatalf("unexpected error creating mock STS server: %s", err) + } + defer closeSts() + + if tc.setEnvVars { + os.Setenv("AWS_STS_ENDPOINT", aws.StringValue(mockStsSession.Config.Endpoint)) + t.Cleanup(func() { + os.Unsetenv("AWS_STS_ENDPOINT") + }) + } + + if tc.setConfig { + config["sts_endpoint"] = aws.StringValue(mockStsSession.Config.Endpoint) + } + + b := New() + configSchema := populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(config)) + + configSchema, diags := b.PrepareConfig(configSchema) + if len(diags) > 0 { + t.Fatal(diags.ErrWithWarnings()) + } + + confDiags := b.Configure(configSchema) + diags = diags.Append(confDiags) + + if diff := cmp.Diff(diags, tc.expectedDiags, cmp.Comparer(diagnosticSummaryComparer)); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + }) + } +} + func TestBackendConfig_AssumeRole(t *testing.T) { testACC(t) @@ -304,7 +622,8 @@ func TestBackendConfig_AssumeRole(t *testing.T) { testCase.Config["sts_endpoint"] = aws.StringValue(mockStsSession.Config.Endpoint) } - diags := New().Configure(hcl2shim.HCL2ValueFromConfigValue(testCase.Config)) + b := New() + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(testCase.Config))) if diags.HasErrors() { for _, diag := range diags { @@ -315,84 +634,196 @@ func TestBackendConfig_AssumeRole(t *testing.T) { } } -func TestBackendConfig_invalidKey(t *testing.T) { - testACC(t) - cfg := hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "key": "/leading-slash", - "encrypt": true, - "dynamodb_table": "dynamoTable", - }) - - _, diags := New().PrepareConfig(cfg) - if !diags.HasErrors() { - t.Fatal("expected config validation error") +func TestBackendConfig_PrepareConfigValidation(t *testing.T) { + cases := map[string]struct { + config cty.Value + expectedErr string + }{ + "null bucket": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.NullVal(cty.String), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `The "bucket" attribute value must not be empty.`, + }, + "empty bucket": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal(""), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `The "bucket" attribute value must not be empty.`, + }, + "null key": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.NullVal(cty.String), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `The "key" attribute value must not be empty.`, + }, + "empty key": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal(""), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `The "key" attribute value must not be empty.`, + }, + "key with leading slash": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("/leading-slash"), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `The "key" attribute value must not start or end with with "/".`, + }, + "key with trailing slash": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("trailing-slash/"), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `The "key" attribute value must not start or end with with "/".`, + }, + "null region": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.NullVal(cty.String), + }), + expectedErr: `The "region" attribute or the "AWS_REGION" or "AWS_DEFAULT_REGION" environment variables must be set.`, + }, + "empty region": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal(""), + }), + expectedErr: `The "region" attribute or the "AWS_REGION" or "AWS_DEFAULT_REGION" environment variables must be set.`, + }, + "workspace_key_prefix with leading slash": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "workspace_key_prefix": cty.StringVal("/env"), + }), + expectedErr: `The "workspace_key_prefix" attribute value must not start with "/".`, + }, + "workspace_key_prefix with trailing slash": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "workspace_key_prefix": cty.StringVal("env/"), + }), + expectedErr: `The "workspace_key_prefix" attribute value must not start with "/".`, + }, + "encyrption key conflict": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "workspace_key_prefix": cty.StringVal("env"), + "sse_customer_key": cty.StringVal("1hwbcNPGWL+AwDiyGmRidTWAEVmCWMKbEHA+Es8w75o="), + "kms_key_id": cty.StringVal("arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab"), + }), + expectedErr: `Only one of "kms_key_id" and "sse_customer_key" can be set`, + }, } - cfg = hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "key": "trailing-slash/", - "encrypt": true, - "dynamodb_table": "dynamoTable", - }) - - _, diags = New().PrepareConfig(cfg) - if !diags.HasErrors() { - t.Fatal("expected config validation error") + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + oldEnv := stashEnv() + defer popEnv(oldEnv) + + b := New() + + _, valDiags := b.PrepareConfig(populateSchema(t, b.ConfigSchema(), tc.config)) + if tc.expectedErr != "" { + if valDiags.Err() != nil { + actualErr := valDiags.Err().Error() + if !strings.Contains(actualErr, tc.expectedErr) { + t.Fatalf("unexpected validation result: %v", valDiags.Err()) + } + } else { + t.Fatal("expected an error, got none") + } + } else if valDiags.Err() != nil { + t.Fatalf("expected no error, got %s", valDiags.Err()) + } + }) } } -func TestBackendConfig_invalidSSECustomerKeyLength(t *testing.T) { - testACC(t) - cfg := hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "encrypt": true, - "key": "state", - "dynamodb_table": "dynamoTable", - "sse_customer_key": "key", - }) - - _, diags := New().PrepareConfig(cfg) - if !diags.HasErrors() { - t.Fatal("expected error for invalid sse_customer_key length") +func TestBackendConfig_PrepareConfigWithEnvVars(t *testing.T) { + cases := map[string]struct { + config cty.Value + vars map[string]string + expectedErr string + }{ + "region env var AWS_REGION": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.NullVal(cty.String), + }), + vars: map[string]string{ + "AWS_REGION": "us-west-1", + }, + }, + "region env var AWS_DEFAULT_REGION": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.NullVal(cty.String), + }), + vars: map[string]string{ + "AWS_DEFAULT_REGION": "us-west-1", + }, + }, + "encyrption key conflict": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "workspace_key_prefix": cty.StringVal("env"), + "kms_key_id": cty.StringVal("arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab"), + }), + vars: map[string]string{ + "AWS_SSE_CUSTOMER_KEY": "1hwbcNPGWL+AwDiyGmRidTWAEVmCWMKbEHA+Es8w75o=", + }, + expectedErr: `Only one of "kms_key_id" and the environment variable "AWS_SSE_CUSTOMER_KEY" can be set`, + }, } -} -func TestBackendConfig_invalidSSECustomerKeyEncoding(t *testing.T) { - testACC(t) - cfg := hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "encrypt": true, - "key": "state", - "dynamodb_table": "dynamoTable", - "sse_customer_key": "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka", - }) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + oldEnv := stashEnv() + defer popEnv(oldEnv) - diags := New().Configure(cfg) - if !diags.HasErrors() { - t.Fatal("expected error for failing to decode sse_customer_key") - } -} + b := New() -func TestBackendConfig_conflictingEncryptionSchema(t *testing.T) { - testACC(t) - cfg := hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "key": "state", - "encrypt": true, - "dynamodb_table": "dynamoTable", - "sse_customer_key": "1hwbcNPGWL+AwDiyGmRidTWAEVmCWMKbEHA+Es8w75o=", - "kms_key_id": "arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab", - }) + for k, v := range tc.vars { + os.Setenv(k, v) + } - diags := New().Configure(cfg) - if !diags.HasErrors() { - t.Fatal("expected error for simultaneous usage of kms_key_id and sse_customer_key") + _, valDiags := b.PrepareConfig(populateSchema(t, b.ConfigSchema(), tc.config)) + if tc.expectedErr != "" { + if valDiags.Err() != nil { + actualErr := valDiags.Err().Error() + if !strings.Contains(actualErr, tc.expectedErr) { + t.Fatalf("unexpected validation result: %v", valDiags.Err()) + } + } else { + t.Fatal("expected an error, got none") + } + } else if valDiags.Err() != nil { + t.Fatalf("expected no error, got %s", valDiags.Err()) + } + }) } } @@ -406,6 +837,7 @@ func TestBackend(t *testing.T) { "bucket": bucketName, "key": keyName, "encrypt": true, + "region": "us-west-1", })).(*Backend) createS3Bucket(t, b.s3Client, bucketName) @@ -425,6 +857,7 @@ func TestBackendLocked(t *testing.T) { "key": keyName, "encrypt": true, "dynamodb_table": bucketName, + "region": "us-west-1", })).(*Backend) b2 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{ @@ -432,6 +865,7 @@ func TestBackendLocked(t *testing.T) { "key": keyName, "encrypt": true, "dynamodb_table": bucketName, + "region": "us-west-1", })).(*Backend) createS3Bucket(t, b1.s3Client, bucketName) @@ -443,21 +877,132 @@ func TestBackendLocked(t *testing.T) { backend.TestBackendStateForceUnlock(t, b1, b2) } -func TestBackendSSECustomerKey(t *testing.T) { +func TestBackendSSECustomerKeyConfig(t *testing.T) { testACC(t) - bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) - b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{ - "bucket": bucketName, - "encrypt": true, - "key": "test-SSE-C", - "sse_customer_key": "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=", - })).(*Backend) + testCases := map[string]struct { + customerKey string + expectedErr string + }{ + "invalid length": { + customerKey: "test", + expectedErr: `sse_customer_key must be 44 characters in length`, + }, + "invalid encoding": { + customerKey: "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka", + expectedErr: `sse_customer_key must be base64 encoded`, + }, + "valid": { + customerKey: "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=", + }, + } - createS3Bucket(t, b.s3Client, bucketName) - defer deleteS3Bucket(t, b.s3Client, bucketName) + for name, testCase := range testCases { + testCase := testCase - backend.TestBackendStates(t, b) + t.Run(name, func(t *testing.T) { + bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) + config := map[string]interface{}{ + "bucket": bucketName, + "encrypt": true, + "key": "test-SSE-C", + "sse_customer_key": testCase.customerKey, + "region": "us-west-1", + } + + b := New().(*Backend) + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(config))) + + if testCase.expectedErr != "" { + if diags.Err() != nil { + actualErr := diags.Err().Error() + if !strings.Contains(actualErr, testCase.expectedErr) { + t.Fatalf("unexpected validation result: %v", diags.Err()) + } + } else { + t.Fatal("expected an error, got none") + } + } else { + if diags.Err() != nil { + t.Fatalf("expected no error, got %s", diags.Err()) + } + if string(b.customerEncryptionKey) != string(must(base64.StdEncoding.DecodeString(testCase.customerKey))) { + t.Fatal("unexpected value for customer encryption key") + } + + createS3Bucket(t, b.s3Client, bucketName) + defer deleteS3Bucket(t, b.s3Client, bucketName) + + backend.TestBackendStates(t, b) + } + }) + } +} + +func TestBackendSSECustomerKeyEnvVar(t *testing.T) { + testACC(t) + + testCases := map[string]struct { + customerKey string + expectedErr string + }{ + "invalid length": { + customerKey: "test", + expectedErr: `The environment variable "AWS_SSE_CUSTOMER_KEY" must be 44 characters in length`, + }, + "invalid encoding": { + customerKey: "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka", + expectedErr: `The environment variable "AWS_SSE_CUSTOMER_KEY" must be base64 encoded`, + }, + "valid": { + customerKey: "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=", + }, + } + + for name, testCase := range testCases { + testCase := testCase + + t.Run(name, func(t *testing.T) { + bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) + config := map[string]interface{}{ + "bucket": bucketName, + "encrypt": true, + "key": "test-SSE-C", + "region": "us-west-1", + } + + os.Setenv("AWS_SSE_CUSTOMER_KEY", testCase.customerKey) + t.Cleanup(func() { + os.Unsetenv("AWS_SSE_CUSTOMER_KEY") + }) + + b := New().(*Backend) + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(config))) + + if testCase.expectedErr != "" { + if diags.Err() != nil { + actualErr := diags.Err().Error() + if !strings.Contains(actualErr, testCase.expectedErr) { + t.Fatalf("unexpected validation result: %v", diags.Err()) + } + } else { + t.Fatal("expected an error, got none") + } + } else { + if diags.Err() != nil { + t.Fatalf("expected no error, got %s", diags.Err()) + } + if string(b.customerEncryptionKey) != string(must(base64.StdEncoding.DecodeString(testCase.customerKey))) { + t.Fatal("unexpected value for customer encryption key") + } + + createS3Bucket(t, b.s3Client, bucketName) + defer deleteS3Bucket(t, b.s3Client, bucketName) + + backend.TestBackendStates(t, b) + } + }) + } } // add some extra junk in S3 to try and confuse the env listing. @@ -493,7 +1038,9 @@ func TestBackendExtraPaths(t *testing.T) { // Write the first state stateMgr := &remote.State{Client: client} - stateMgr.WriteState(s1) + if err := stateMgr.WriteState(s1); err != nil { + t.Fatal(err) + } if err := stateMgr.PersistState(nil); err != nil { t.Fatal(err) } @@ -503,7 +1050,9 @@ func TestBackendExtraPaths(t *testing.T) { // states are equal, the state will not Put to the remote client.path = b.path("s2") stateMgr2 := &remote.State{Client: client} - stateMgr2.WriteState(s2) + if err := stateMgr2.WriteState(s2); err != nil { + t.Fatal(err) + } if err := stateMgr2.PersistState(nil); err != nil { t.Fatal(err) } @@ -516,7 +1065,9 @@ func TestBackendExtraPaths(t *testing.T) { // put a state in an env directory name client.path = b.workspaceKeyPrefix + "/error" - stateMgr.WriteState(states.NewState()) + if err := stateMgr.WriteState(states.NewState()); err != nil { + t.Fatal(err) + } if err := stateMgr.PersistState(nil); err != nil { t.Fatal(err) } @@ -526,7 +1077,9 @@ func TestBackendExtraPaths(t *testing.T) { // add state with the wrong key for an existing env client.path = b.workspaceKeyPrefix + "/s2/notTestState" - stateMgr.WriteState(states.NewState()) + if err := stateMgr.WriteState(states.NewState()); err != nil { + t.Fatal(err) + } if err := stateMgr.PersistState(nil); err != nil { t.Fatal(err) } @@ -560,12 +1113,14 @@ func TestBackendExtraPaths(t *testing.T) { if s2Mgr.(*remote.State).StateSnapshotMeta().Lineage == s2Lineage { t.Fatal("state s2 was not deleted") } - s2 = s2Mgr.State() + _ = s2Mgr.State() // We need the side-effect s2Lineage = stateMgr.StateSnapshotMeta().Lineage // add a state with a key that matches an existing environment dir name client.path = b.workspaceKeyPrefix + "/s2/" - stateMgr.WriteState(states.NewState()) + if err := stateMgr.WriteState(states.NewState()); err != nil { + t.Fatal(err) + } if err := stateMgr.PersistState(nil); err != nil { t.Fatal(err) } @@ -796,3 +1351,127 @@ func deleteDynamoDBTable(t *testing.T, dynClient *dynamodb.DynamoDB, tableName s t.Logf("WARNING: Failed to delete the test DynamoDB table %q. It has been left in your AWS account and may incur charges. (error was %s)", tableName, err) } } + +func populateSchema(t *testing.T, schema *configschema.Block, value cty.Value) cty.Value { + ty := schema.ImpliedType() + var path cty.Path + val, err := unmarshal(value, ty, path) + if err != nil { + t.Fatalf("populating schema: %s", err) + } + return val +} + +func unmarshal(value cty.Value, ty cty.Type, path cty.Path) (cty.Value, error) { + switch { + case ty.IsPrimitiveType(): + return value, nil + // case ty.IsListType(): + // return unmarshalList(value, ty.ElementType(), path) + case ty.IsSetType(): + return unmarshalSet(value, ty.ElementType(), path) + case ty.IsMapType(): + return unmarshalMap(value, ty.ElementType(), path) + // case ty.IsTupleType(): + // return unmarshalTuple(value, ty.TupleElementTypes(), path) + case ty.IsObjectType(): + return unmarshalObject(value, ty.AttributeTypes(), path) + default: + return cty.NilVal, path.NewErrorf("unsupported type %s", ty.FriendlyName()) + } +} + +func unmarshalSet(dec cty.Value, ety cty.Type, path cty.Path) (cty.Value, error) { + if dec.IsNull() { + return dec, nil + } + + length := dec.LengthInt() + + if length == 0 { + return cty.SetValEmpty(ety), nil + } + + vals := make([]cty.Value, 0, length) + dec.ForEachElement(func(key, val cty.Value) (stop bool) { + vals = append(vals, val) + return + }) + + return cty.SetVal(vals), nil +} + +func unmarshalMap(dec cty.Value, ety cty.Type, path cty.Path) (cty.Value, error) { + if dec.IsNull() { + return dec, nil + } + + length := dec.LengthInt() + + if length == 0 { + return cty.MapValEmpty(ety), nil + } + + vals := make(map[string]cty.Value, length) + dec.ForEachElement(func(key, val cty.Value) (stop bool) { + k := stringValue(key) + vals[k] = val + return + }) + + return cty.MapVal(vals), nil +} + +func unmarshalObject(dec cty.Value, atys map[string]cty.Type, path cty.Path) (cty.Value, error) { + if dec.IsNull() { + return dec, nil + } + valueTy := dec.Type() + + vals := make(map[string]cty.Value, len(atys)) + path = append(path, nil) + for key, aty := range atys { + path[len(path)-1] = cty.IndexStep{ + Key: cty.StringVal(key), + } + + if !valueTy.HasAttribute(key) { + vals[key] = cty.NullVal(aty) + } else { + val, err := unmarshal(dec.GetAttr(key), aty, path) + if err != nil { + return cty.DynamicVal, err + } + vals[key] = val + } + } + + return cty.ObjectVal(vals), nil +} + +func stashEnv() []string { + env := os.Environ() + os.Clearenv() + return env +} + +func popEnv(env []string) { + os.Clearenv() + + for _, e := range env { + p := strings.SplitN(e, "=", 2) + k, v := p[0], "" + if len(p) > 1 { + v = p[1] + } + os.Setenv(k, v) + } +} + +func must[T any](v T, err error) T { + if err != nil { + panic(err) + } else { + return v + } +} diff --git a/internal/backend/remote-state/s3/testing.go b/internal/backend/remote-state/s3/testing.go new file mode 100644 index 000000000000..7cade2bc5729 --- /dev/null +++ b/internal/backend/remote-state/s3/testing.go @@ -0,0 +1,34 @@ +package s3 + +import ( + "github.com/hashicorp/terraform/internal/tfdiags" +) + +// diagnosticComparer is a Comparer function for use with cmp.Diff to compare two tfdiags.Diagnostic values +func diagnosticComparer(l, r tfdiags.Diagnostic) bool { + if l.Severity() != r.Severity() { + return false + } + if l.Description() != r.Description() { + return false + } + + lp := tfdiags.GetAttribute(l) + rp := tfdiags.GetAttribute(r) + if len(lp) != len(rp) { + return false + } + return lp.Equals(rp) +} + +// diagnosticSummaryComparer is a Comparer function for use with cmp.Diff to compare +// the Severity and Summary fields two tfdiags.Diagnostic values +func diagnosticSummaryComparer(l, r tfdiags.Diagnostic) bool { + if l.Severity() != r.Severity() { + return false + } + + ld := l.Description() + rd := r.Description() + return ld.Summary == rd.Summary +} diff --git a/internal/backend/remote-state/s3/validate.go b/internal/backend/remote-state/s3/validate.go new file mode 100644 index 000000000000..97a9e1087e7a --- /dev/null +++ b/internal/backend/remote-state/s3/validate.go @@ -0,0 +1,76 @@ +package s3 + +import ( + "fmt" + "regexp" + + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/hashicorp/terraform/internal/tfdiags" + "github.com/zclconf/go-cty/cty" +) + +const ( + multiRegionKeyIdPattern = `mrk-[a-f0-9]{32}` + uuidRegexPattern = `[a-f0-9]{8}-[a-f0-9]{4}-[1-5][a-f0-9]{3}-[ab89][a-f0-9]{3}-[a-f0-9]{12}` +) + +func validateKMSKey(path cty.Path, s string) (diags tfdiags.Diagnostics) { + if arn.IsARN(s) { + return validateKMSKeyARN(path, s) + } + return validateKMSKeyID(path, s) +} + +func validateKMSKeyID(path cty.Path, s string) (diags tfdiags.Diagnostics) { + keyIdRegex := regexp.MustCompile(`^` + uuidRegexPattern + `|` + multiRegionKeyIdPattern + `$`) + if !keyIdRegex.MatchString(s) { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ID", + fmt.Sprintf("Value must be a valid KMS Key ID, got %q", s), + path, + )) + return diags + } + + return diags +} + +func validateKMSKeyARN(path cty.Path, s string) (diags tfdiags.Diagnostics) { + parsedARN, err := arn.Parse(s) + if err != nil { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + fmt.Sprintf("Value must be a valid KMS Key ARN, got %q", s), + path, + )) + return diags + } + + if !isKeyARN(parsedARN) { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + fmt.Sprintf("Value must be a valid KMS Key ARN, got %q", s), + path, + )) + return diags + } + + return diags +} + +func isKeyARN(arn arn.ARN) bool { + return keyIdFromARNResource(arn.Resource) != "" +} + +func keyIdFromARNResource(s string) string { + keyIdResourceRegex := regexp.MustCompile(`^key/(` + uuidRegexPattern + `|` + multiRegionKeyIdPattern + `)$`) + matches := keyIdResourceRegex.FindStringSubmatch(s) + if matches == nil || len(matches) != 2 { + return "" + } + + return matches[1] +} diff --git a/internal/backend/remote-state/s3/validate_test.go b/internal/backend/remote-state/s3/validate_test.go new file mode 100644 index 000000000000..a4d5e32ca390 --- /dev/null +++ b/internal/backend/remote-state/s3/validate_test.go @@ -0,0 +1,154 @@ +package s3 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/terraform/internal/tfdiags" + "github.com/zclconf/go-cty/cty" +) + +func TestValidateKMSKey(t *testing.T) { + t.Parallel() + + path := cty.Path{cty.GetAttrStep{Name: "field"}} + + testcases := map[string]struct { + in string + expected tfdiags.Diagnostics + }{ + "kms key id": { + in: "57ff7a43-341d-46b6-aee3-a450c9de6dc8", + }, + "kms key arn": { + in: "arn:aws:kms:us-west-2:111122223333:key/57ff7a43-341d-46b6-aee3-a450c9de6dc8", + }, + "kms multi-region key id": { + in: "mrk-f827515944fb43f9b902a09d2c8b554f", + }, + "kms multi-region key arn": { + in: "arn:aws:kms:us-west-2:111122223333:key/mrk-a835af0b39c94b86a21a8fc9535df681", + }, + "kms key alias": { + in: "alias/arbitrary-key", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ID", + `Value must be a valid KMS Key ID, got "alias/arbitrary-key"`, + path, + ), + }, + }, + "kms key alias arn": { + in: "arn:aws:kms:us-west-2:111122223333:alias/arbitrary-key", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + `Value must be a valid KMS Key ARN, got "arn:aws:kms:us-west-2:111122223333:alias/arbitrary-key"`, + path, + ), + }, + }, + "invalid key": { + in: "$%wrongkey", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ID", + `Value must be a valid KMS Key ID, got "$%wrongkey"`, + path, + ), + }, + }, + "non-kms arn": { + in: "arn:aws:lamda:foo:bar:key/xyz", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + `Value must be a valid KMS Key ARN, got "arn:aws:lamda:foo:bar:key/xyz"`, + path, + ), + }, + }, + } + + for name, testcase := range testcases { + testcase := testcase + t.Run(name, func(t *testing.T) { + t.Parallel() + + diags := validateKMSKey(path, testcase.in) + + if diff := cmp.Diff(diags, testcase.expected, cmp.Comparer(diagnosticComparer)); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + }) + } +} + +func TestValidateKeyARN(t *testing.T) { + t.Parallel() + + path := cty.Path{cty.GetAttrStep{Name: "field"}} + + testcases := map[string]struct { + in string + expected tfdiags.Diagnostics + }{ + "kms key id": { + in: "arn:aws:kms:us-west-2:123456789012:key/57ff7a43-341d-46b6-aee3-a450c9de6dc8", + }, + "kms mrk key id": { + in: "arn:aws:kms:us-west-2:111122223333:key/mrk-a835af0b39c94b86a21a8fc9535df681", + }, + "kms non-key id": { + in: "arn:aws:kms:us-west-2:123456789012:something/else", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + `Value must be a valid KMS Key ARN, got "arn:aws:kms:us-west-2:123456789012:something/else"`, + path, + ), + }, + }, + "non-kms arn": { + in: "arn:aws:iam::123456789012:user/David", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + `Value must be a valid KMS Key ARN, got "arn:aws:iam::123456789012:user/David"`, + path, + ), + }, + }, + "not an arn": { + in: "not an arn", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + `Value must be a valid KMS Key ARN, got "not an arn"`, + path, + ), + }, + }, + } + + for name, testcase := range testcases { + testcase := testcase + t.Run(name, func(t *testing.T) { + t.Parallel() + + diags := validateKMSKeyARN(path, testcase.in) + + if diff := cmp.Diff(diags, testcase.expected, cmp.Comparer(diagnosticComparer)); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + }) + } +}