From 9257dfa8a70083522ba1ce8bdfb4d93adebc287b Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Tue, 25 Oct 2022 10:04:45 -0700 Subject: [PATCH 01/20] S3 backend: Adds tests for defaults and validation --- .../backend/remote-state/s3/backend_test.go | 301 ++++++++++++++++-- 1 file changed, 271 insertions(+), 30 deletions(-) diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index 92fe40b2c346..185ba3b44aa4 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "reflect" + "strings" "testing" "time" @@ -19,6 +20,8 @@ import ( "github.com/hashicorp/terraform/internal/configs/hcl2shim" "github.com/hashicorp/terraform/internal/states" "github.com/hashicorp/terraform/internal/states/remote" + "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/gocty" ) var ( @@ -315,37 +318,7 @@ func TestBackendConfig_AssumeRole(t *testing.T) { } } -func TestBackendConfig_invalidKey(t *testing.T) { - testACC(t) - cfg := hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "key": "/leading-slash", - "encrypt": true, - "dynamodb_table": "dynamoTable", - }) - - _, diags := New().PrepareConfig(cfg) - if !diags.HasErrors() { - t.Fatal("expected config validation error") - } - - cfg = hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "key": "trailing-slash/", - "encrypt": true, - "dynamodb_table": "dynamoTable", - }) - - _, diags = New().PrepareConfig(cfg) - if !diags.HasErrors() { - t.Fatal("expected config validation error") - } -} - func TestBackendConfig_invalidSSECustomerKeyLength(t *testing.T) { - testACC(t) cfg := hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ "region": "us-west-1", "bucket": "tf-test", @@ -396,6 +369,274 @@ func TestBackendConfig_conflictingEncryptionSchema(t *testing.T) { } } +func TestBackendConfig_PrepareConfigValidation(t *testing.T) { + cases := map[string]struct { + config cty.Value + expectedErr string + }{ + "null bucket": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.NullVal(cty.String), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `"bucket": required field is not set`, + }, + "empty bucket": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal(""), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `"bucket": required field is not set`, + }, + "null key": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.NullVal(cty.String), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `"key": required field is not set`, + }, + "empty key": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal(""), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `"key": required field is not set`, + }, + "key with leading slash": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("/leading-slash"), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `key must not start with '/'`, + }, + "key with trailing slash": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("trailing-slash/"), + "region": cty.StringVal("us-west-2"), + }), + expectedErr: `key must not end with '/'`, + }, + "null region": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.NullVal(cty.String), + }), + expectedErr: `"region": required field is not set`, + }, + "empty region": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal(""), + }), + expectedErr: `"region": required field is not set`, + }, + "workspace_key_prefix with leading slash": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "workspace_key_prefix": cty.StringVal("/env"), + }), + expectedErr: `workspace_key_prefix must not start or end with '/'`, + }, + "workspace_key_prefix with trailing lash": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "workspace_key_prefix": cty.StringVal("env/"), + }), + expectedErr: `workspace_key_prefix must not start or end with '/'`, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + b := New() + + // Validate + _, valDiags := b.PrepareConfig(tc.config) + if valDiags.Err() != nil && tc.expectedErr != "" { + actualErr := valDiags.Err().Error() + if !strings.Contains(actualErr, tc.expectedErr) { + t.Fatalf("unexpected validation result: %v", valDiags.Err()) + } + } + }) + } +} + +func TestBackendConfig_PrepareConfigWithEnvVars(t *testing.T) { + cases := map[string]struct { + config cty.Value + vars map[string]string + expectedErr string + }{ + "region env var AWS_REGION": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.NullVal(cty.String), + }), + vars: map[string]string{ + "AWS_REGION": "us-west-2", + }, + }, + "region env var AWS_DEFAULT_REGION": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.NullVal(cty.String), + }), + vars: map[string]string{ + "AWS_DEFAULT_REGION": "us-west-2", + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + b := New() + + for k, v := range tc.vars { + os.Setenv(k, v) + } + t.Cleanup(func() { + for k := range tc.vars { + os.Unsetenv(k) + } + }) + + _, valDiags := b.PrepareConfig(tc.config) + if valDiags.Err() != nil && tc.expectedErr != "" { + actualErr := valDiags.Err().Error() + if !strings.Contains(actualErr, tc.expectedErr) { + t.Fatalf("unexpected validation result: %v", valDiags.Err()) + } + } + }) + } +} + +func TestBackendConfig_PrepareConfigDefaults(t *testing.T) { + cases := map[string]struct { + config cty.Value + validate func(cty.Value, *testing.T) + }{ + "workspace_key_prefix": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "workspace_key_prefix": cty.NullVal(cty.String), + }), + validate: func(c cty.Value, t *testing.T) { + v := c.GetAttr("workspace_key_prefix") + if v.IsNull() { + t.Fatal(`expected value for "workspace_key_prefix", got null`) + } else if a, e := v.AsString(), "env:"; a != e { + t.Fatalf(`expected "workspace_key_prefix" to be "%v", got "%v"`, e, a) + } + }, + }, + "max_retries": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "max_retries": cty.NullVal(cty.Number), + }), + validate: func(c cty.Value, t *testing.T) { + v := c.GetAttr("max_retries") + if v.IsNull() { + t.Fatal(`expected value for "max_retries", got null`) + } else { + var a int + err := gocty.FromCtyValue(v, &a) + if err != nil { + t.Fatalf("unexpected value error: %v", err) + } + if e := 5; a != e { + t.Fatalf(`expected "max_retries" to be "%v", got "%v"`, e, a) + } + } + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + b := New() + + // Validate + val, valDiags := b.PrepareConfig(tc.config) + if valDiags.Err() != nil { + t.Fatalf("unexpected validation result: %v", valDiags.Err()) + } + tc.validate(val, t) + }) + } +} + +func TestBackendConfig_prefixDefault(t *testing.T) { + config := cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "workspace_key_prefix": cty.NullVal(cty.String), + }) + + b := New() + val, diags := b.PrepareConfig(config) + if err := diags.Err(); err != nil { + t.Fatalf("unexpected validation result: %v", err) + } + + v := val.GetAttr("workspace_key_prefix") + if v.IsNull() { + t.Fatal(`expected value for "workspace_key_prefix", got null`) + } else if v := v.AsString(); v != "env:" { + t.Fatalf(`expected "workspace_key_prefix" to be "env:", got %q`, v) + } +} + +func TestBackendConfig_maxRetriesDefault(t *testing.T) { + config := cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "max_retries": cty.NullVal(cty.Number), + }) + + b := New() + val, diags := b.PrepareConfig(config) + if err := diags.Err(); err != nil { + t.Fatalf("unexpected validation result: %v", err) + } + + v := val.GetAttr("max_retries") + if v.IsNull() { + t.Fatal(`expected value for "max_retries", got null`) + } else { + var foo int + err := gocty.FromCtyValue(v, &foo) + if err != nil { + t.Fatalf("unexpected value error: %v", err) + } + if foo != 5 { + t.Fatalf(`expected "max_retries" to be 5, got %v`, foo) + } + } +} + func TestBackend(t *testing.T) { testACC(t) From b5de54064350081b37607015059ce573d975d575 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Tue, 25 Oct 2022 14:19:28 -0700 Subject: [PATCH 02/20] Fully populates schema values --- .../backend/remote-state/s3/backend_test.go | 216 ++++++------------ 1 file changed, 76 insertions(+), 140 deletions(-) diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index 185ba3b44aa4..dce3bbe8681a 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -17,11 +17,11 @@ import ( "github.com/aws/aws-sdk-go/service/s3" awsbase "github.com/hashicorp/aws-sdk-go-base" "github.com/hashicorp/terraform/internal/backend" + "github.com/hashicorp/terraform/internal/configs/configschema" "github.com/hashicorp/terraform/internal/configs/hcl2shim" "github.com/hashicorp/terraform/internal/states" "github.com/hashicorp/terraform/internal/states/remote" "github.com/zclconf/go-cty/cty" - "github.com/zclconf/go-cty/cty/gocty" ) var ( @@ -319,14 +319,14 @@ func TestBackendConfig_AssumeRole(t *testing.T) { } func TestBackendConfig_invalidSSECustomerKeyLength(t *testing.T) { - cfg := hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ + cfg := populateSchema(t, New().ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ "region": "us-west-1", "bucket": "tf-test", "encrypt": true, "key": "state", "dynamodb_table": "dynamoTable", "sse_customer_key": "key", - }) + })) _, diags := New().PrepareConfig(cfg) if !diags.HasErrors() { @@ -335,40 +335,21 @@ func TestBackendConfig_invalidSSECustomerKeyLength(t *testing.T) { } func TestBackendConfig_invalidSSECustomerKeyEncoding(t *testing.T) { - testACC(t) - cfg := hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ + cfg := populateSchema(t, New().ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ "region": "us-west-1", "bucket": "tf-test", "encrypt": true, "key": "state", "dynamodb_table": "dynamoTable", "sse_customer_key": "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka", - }) + })) - diags := New().Configure(cfg) + _, diags := New().PrepareConfig(cfg) if !diags.HasErrors() { t.Fatal("expected error for failing to decode sse_customer_key") } } -func TestBackendConfig_conflictingEncryptionSchema(t *testing.T) { - testACC(t) - cfg := hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "key": "state", - "encrypt": true, - "dynamodb_table": "dynamoTable", - "sse_customer_key": "1hwbcNPGWL+AwDiyGmRidTWAEVmCWMKbEHA+Es8w75o=", - "kms_key_id": "arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab", - }) - - diags := New().Configure(cfg) - if !diags.HasErrors() { - t.Fatal("expected error for simultaneous usage of kms_key_id and sse_customer_key") - } -} - func TestBackendConfig_PrepareConfigValidation(t *testing.T) { cases := map[string]struct { config cty.Value @@ -447,7 +428,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { }), expectedErr: `workspace_key_prefix must not start or end with '/'`, }, - "workspace_key_prefix with trailing lash": { + "workspace_key_prefix with trailing slash": { config: cty.ObjectVal(map[string]cty.Value{ "bucket": cty.StringVal("test"), "key": cty.StringVal("test"), @@ -456,14 +437,24 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { }), expectedErr: `workspace_key_prefix must not start or end with '/'`, }, + "encyrption key conflict": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "workspace_key_prefix": cty.StringVal("env/"), + "sse_customer_key": cty.StringVal("1hwbcNPGWL+AwDiyGmRidTWAEVmCWMKbEHA+Es8w75o="), + "kms_key_id": cty.StringVal("arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab"), + }), + expectedErr: `Only one of "kms_key_id" and "sse_customer_key" can be set`, + }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { b := New() - // Validate - _, valDiags := b.PrepareConfig(tc.config) + _, valDiags := b.PrepareConfig(populateSchema(t, b.ConfigSchema(), tc.config)) if valDiags.Err() != nil && tc.expectedErr != "" { actualErr := valDiags.Err().Error() if !strings.Contains(actualErr, tc.expectedErr) { @@ -515,7 +506,7 @@ func TestBackendConfig_PrepareConfigWithEnvVars(t *testing.T) { } }) - _, valDiags := b.PrepareConfig(tc.config) + _, valDiags := b.PrepareConfig(populateSchema(t, b.ConfigSchema(), tc.config)) if valDiags.Err() != nil && tc.expectedErr != "" { actualErr := valDiags.Err().Error() if !strings.Contains(actualErr, tc.expectedErr) { @@ -526,117 +517,6 @@ func TestBackendConfig_PrepareConfigWithEnvVars(t *testing.T) { } } -func TestBackendConfig_PrepareConfigDefaults(t *testing.T) { - cases := map[string]struct { - config cty.Value - validate func(cty.Value, *testing.T) - }{ - "workspace_key_prefix": { - config: cty.ObjectVal(map[string]cty.Value{ - "bucket": cty.StringVal("test"), - "key": cty.StringVal("test"), - "region": cty.StringVal("us-west-2"), - "workspace_key_prefix": cty.NullVal(cty.String), - }), - validate: func(c cty.Value, t *testing.T) { - v := c.GetAttr("workspace_key_prefix") - if v.IsNull() { - t.Fatal(`expected value for "workspace_key_prefix", got null`) - } else if a, e := v.AsString(), "env:"; a != e { - t.Fatalf(`expected "workspace_key_prefix" to be "%v", got "%v"`, e, a) - } - }, - }, - "max_retries": { - config: cty.ObjectVal(map[string]cty.Value{ - "bucket": cty.StringVal("test"), - "key": cty.StringVal("test"), - "region": cty.StringVal("us-west-2"), - "max_retries": cty.NullVal(cty.Number), - }), - validate: func(c cty.Value, t *testing.T) { - v := c.GetAttr("max_retries") - if v.IsNull() { - t.Fatal(`expected value for "max_retries", got null`) - } else { - var a int - err := gocty.FromCtyValue(v, &a) - if err != nil { - t.Fatalf("unexpected value error: %v", err) - } - if e := 5; a != e { - t.Fatalf(`expected "max_retries" to be "%v", got "%v"`, e, a) - } - } - }, - }, - } - - for name, tc := range cases { - t.Run(name, func(t *testing.T) { - b := New() - - // Validate - val, valDiags := b.PrepareConfig(tc.config) - if valDiags.Err() != nil { - t.Fatalf("unexpected validation result: %v", valDiags.Err()) - } - tc.validate(val, t) - }) - } -} - -func TestBackendConfig_prefixDefault(t *testing.T) { - config := cty.ObjectVal(map[string]cty.Value{ - "bucket": cty.StringVal("test"), - "key": cty.StringVal("test"), - "region": cty.StringVal("us-west-2"), - "workspace_key_prefix": cty.NullVal(cty.String), - }) - - b := New() - val, diags := b.PrepareConfig(config) - if err := diags.Err(); err != nil { - t.Fatalf("unexpected validation result: %v", err) - } - - v := val.GetAttr("workspace_key_prefix") - if v.IsNull() { - t.Fatal(`expected value for "workspace_key_prefix", got null`) - } else if v := v.AsString(); v != "env:" { - t.Fatalf(`expected "workspace_key_prefix" to be "env:", got %q`, v) - } -} - -func TestBackendConfig_maxRetriesDefault(t *testing.T) { - config := cty.ObjectVal(map[string]cty.Value{ - "bucket": cty.StringVal("test"), - "key": cty.StringVal("test"), - "region": cty.StringVal("us-west-2"), - "max_retries": cty.NullVal(cty.Number), - }) - - b := New() - val, diags := b.PrepareConfig(config) - if err := diags.Err(); err != nil { - t.Fatalf("unexpected validation result: %v", err) - } - - v := val.GetAttr("max_retries") - if v.IsNull() { - t.Fatal(`expected value for "max_retries", got null`) - } else { - var foo int - err := gocty.FromCtyValue(v, &foo) - if err != nil { - t.Fatalf("unexpected value error: %v", err) - } - if foo != 5 { - t.Fatalf(`expected "max_retries" to be 5, got %v`, foo) - } - } -} - func TestBackend(t *testing.T) { testACC(t) @@ -1037,3 +917,59 @@ func deleteDynamoDBTable(t *testing.T, dynClient *dynamodb.DynamoDB, tableName s t.Logf("WARNING: Failed to delete the test DynamoDB table %q. It has been left in your AWS account and may incur charges. (error was %s)", tableName, err) } } + +func populateSchema(t *testing.T, schema *configschema.Block, value cty.Value) cty.Value { + ty := schema.ImpliedType() + var path cty.Path + val, err := unmarshal(value, ty, path) + if err != nil { + t.Fatalf("populating schema: %s", err) + } + return val +} + +func unmarshal(value cty.Value, ty cty.Type, path cty.Path) (cty.Value, error) { + switch { + case ty.IsPrimitiveType(): + val, err := unmarshalPrimitive(value, ty, path) + if err != nil { + return cty.NilVal, err + } + return val, nil + case ty.IsObjectType(): + return unmarshalObject(value, ty.AttributeTypes(), path) + default: + return cty.NilVal, path.NewErrorf("unsupported type %s", ty.FriendlyName()) + } +} + +func unmarshalPrimitive(value cty.Value, ty cty.Type, path cty.Path) (cty.Value, error) { + return value, nil +} + +func unmarshalObject(dec cty.Value, atys map[string]cty.Type, path cty.Path) (cty.Value, error) { + if dec.IsNull() { + return dec, nil + } + valueTy := dec.Type() + + vals := make(map[string]cty.Value, len(atys)) + path = append(path, nil) + for key, aty := range atys { + path[len(path)-1] = cty.IndexStep{ + Key: cty.StringVal(key), + } + + if !valueTy.HasAttribute(key) { + vals[key] = cty.NullVal(aty) + } else { + val, err := unmarshal(dec.GetAttr(key), aty, path) + if err != nil { + return cty.DynamicVal, err + } + vals[key] = val + } + } + + return cty.ObjectVal(vals), nil +} From 381006742b8811608be7770d18467799fd2b1778 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Tue, 25 Oct 2022 14:23:22 -0700 Subject: [PATCH 03/20] Moves validation to `PrepareConfig` --- internal/backend/remote-state/s3/backend.go | 117 +++++++++++++----- .../backend/remote-state/s3/backend_test.go | 54 +++----- 2 files changed, 109 insertions(+), 62 deletions(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index aa64869f7e14..aa4aa4bbc58c 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -8,6 +8,7 @@ import ( "encoding/base64" "errors" "fmt" + "os" "strings" "github.com/aws/aws-sdk-go/aws" @@ -17,7 +18,9 @@ import ( "github.com/hashicorp/terraform/internal/backend" "github.com/hashicorp/terraform/internal/legacy/helper/schema" "github.com/hashicorp/terraform/internal/logging" + "github.com/hashicorp/terraform/internal/tfdiags" "github.com/hashicorp/terraform/version" + "github.com/zclconf/go-cty/cty" ) // New creates a new backend for S3 remote state. @@ -34,19 +37,6 @@ func New() backend.Backend { Type: schema.TypeString, Required: true, Description: "The path to the state file inside the bucket", - ValidateFunc: func(v interface{}, s string) ([]string, []error) { - // s3 will strip leading slashes from an object, so while this will - // technically be accepted by s3, it will break our workspace hierarchy. - if strings.HasPrefix(v.(string), "/") { - return nil, []error{errors.New("key must not start with '/'")} - } - // s3 will recognize objects with a trailing slash as a directory - // so they should not be valid keys - if strings.HasSuffix(v.(string), "/") { - return nil, []error{errors.New("key must not end with '/'")} - } - return nil, nil - }, }, "region": { @@ -177,13 +167,6 @@ func New() backend.Backend { Description: "The base64-encoded encryption key to use for server-side encryption with customer-provided keys (SSE-C).", DefaultFunc: schema.EnvDefaultFunc("AWS_SSE_CUSTOMER_KEY", ""), Sensitive: true, - ValidateFunc: func(v interface{}, s string) ([]string, []error) { - key := v.(string) - if key != "" && len(key) != 44 { - return nil, []error{errors.New("sse_customer_key must be 44 characters in length (256 bits, base64 encoded)")} - } - return nil, nil - }, }, "role_arn": { @@ -246,13 +229,6 @@ func New() backend.Backend { Optional: true, Description: "The prefix applied to the non-default state path inside the bucket.", Default: "env:", - ValidateFunc: func(v interface{}, s string) ([]string, []error) { - prefix := v.(string) - if strings.HasPrefix(prefix, "/") || strings.HasSuffix(prefix, "/") { - return nil, []error{errors.New("workspace_key_prefix must not start or end with '/'")} - } - return nil, nil - }, }, "force_path_style": { @@ -276,6 +252,91 @@ func New() backend.Backend { return result } +func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) { + var diags tfdiags.Diagnostics + if obj.IsNull() { + return obj, diags + } + + if val := obj.GetAttr("key"); val.IsNull() || val.AsString() == "" { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid key value", + `"key": required field is not set`, + cty.Path{cty.GetAttrStep{Name: "key"}}, + )) + } else if strings.HasPrefix(val.AsString(), "/") || strings.HasSuffix(val.AsString(), "/") { + // S3 will strip leading slashes from an object, so while this will + // technically be accepted by S3, it will break our workspace hierarchy. + // S3 will recognize objects with a trailing slash as a directory + // so they should not be valid keys + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid key value", + "key must not start or end with '/'", + cty.Path{cty.GetAttrStep{Name: "key"}}, + )) + } + + if val := obj.GetAttr("region"); val.IsNull() { + if os.Getenv("AWS_REGION") == "" && os.Getenv("AWS_DEFAULT_REGION") == "" { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Missing region value", + `"region": required field is not set`, + cty.Path{cty.GetAttrStep{Name: "region"}}, + )) + } + } + + if val := obj.GetAttr("sse_customer_key"); !val.IsNull() { + s := val.AsString() + if len(s) != 44 { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid sse_customer_key value", + "sse_customer_key must be 44 characters in length", + cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, + )) + } else { + var err error + _, err = base64.StdEncoding.DecodeString(s) + if err != nil { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid sse_customer_key value", + fmt.Sprintf("sse_customer_key must be base64 encoded: %s", err), + cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, + )) + } + } + } + + if val := obj.GetAttr("kms_key_id"); !val.IsNull() && val.AsString() != "" { + if val := obj.GetAttr("sse_customer_key"); !val.IsNull() && val.AsString() != "" { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid encryption configuration", + encryptionKeyConflictError, + cty.Path{}, + )) + } + } + + if val := obj.GetAttr("workspace_key_prefix"); !val.IsNull() { + if v := val.AsString(); strings.HasPrefix(v, "/") || strings.HasSuffix(v, "/") { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid workspace_key_prefix value", + "workspace_key_prefix must not start or end with '/'", + cty.Path{cty.GetAttrStep{Name: "workspace_key_prefix"}}, + )) + } + } + + return obj, diags +} + type Backend struct { *schema.Backend @@ -409,7 +470,7 @@ func (b *Backend) configure(ctx context.Context) error { return nil } -const encryptionKeyConflictError = `Cannot have both kms_key_id and sse_customer_key set. +const encryptionKeyConflictError = `Only one of "kms_key_id" and "sse_customer_key" can be set. The kms_key_id is used for encryption with KMS-Managed Keys (SSE-KMS) while sse_customer_key is used for encryption with customer-managed keys (SSE-C). diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index dce3bbe8681a..feae781a037c 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -318,38 +318,6 @@ func TestBackendConfig_AssumeRole(t *testing.T) { } } -func TestBackendConfig_invalidSSECustomerKeyLength(t *testing.T) { - cfg := populateSchema(t, New().ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "encrypt": true, - "key": "state", - "dynamodb_table": "dynamoTable", - "sse_customer_key": "key", - })) - - _, diags := New().PrepareConfig(cfg) - if !diags.HasErrors() { - t.Fatal("expected error for invalid sse_customer_key length") - } -} - -func TestBackendConfig_invalidSSECustomerKeyEncoding(t *testing.T) { - cfg := populateSchema(t, New().ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "encrypt": true, - "key": "state", - "dynamodb_table": "dynamoTable", - "sse_customer_key": "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka", - })) - - _, diags := New().PrepareConfig(cfg) - if !diags.HasErrors() { - t.Fatal("expected error for failing to decode sse_customer_key") - } -} - func TestBackendConfig_PrepareConfigValidation(t *testing.T) { cases := map[string]struct { config cty.Value @@ -393,7 +361,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "key": cty.StringVal("/leading-slash"), "region": cty.StringVal("us-west-2"), }), - expectedErr: `key must not start with '/'`, + expectedErr: `key must not start or end with '/'`, }, "key with trailing slash": { config: cty.ObjectVal(map[string]cty.Value{ @@ -401,7 +369,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "key": cty.StringVal("trailing-slash/"), "region": cty.StringVal("us-west-2"), }), - expectedErr: `key must not end with '/'`, + expectedErr: `key must not start or end with '/'`, }, "null region": { config: cty.ObjectVal(map[string]cty.Value{ @@ -437,6 +405,24 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { }), expectedErr: `workspace_key_prefix must not start or end with '/'`, }, + "sse_customer_key invalid length": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "sse_customer_key": cty.StringVal("key"), + }), + expectedErr: `sse_customer_key must be 44 characters in length`, + }, + "sse_customer_key invalid encoding": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "sse_customer_key": cty.StringVal("====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka"), + }), + expectedErr: `sse_customer_key must be base64 encoded`, + }, "encyrption key conflict": { config: cty.ObjectVal(map[string]cty.Value{ "bucket": cty.StringVal("test"), From 467e6256da8387b4717b1771ee6deec29d1aac44 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 26 Oct 2022 15:05:50 -0700 Subject: [PATCH 04/20] Moves configuration to `Configure` --- internal/backend/remote-state/s3/backend.go | 262 +++++++++++------- .../backend/remote-state/s3/backend_test.go | 61 +++- 2 files changed, 216 insertions(+), 107 deletions(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index aa4aa4bbc58c..cf651c50abdd 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -4,9 +4,7 @@ package s3 import ( - "context" "encoding/base64" - "errors" "fmt" "os" "strings" @@ -21,6 +19,7 @@ import ( "github.com/hashicorp/terraform/internal/tfdiags" "github.com/hashicorp/terraform/version" "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/gocty" ) // New creates a new backend for S3 remote state. @@ -247,9 +246,24 @@ func New() backend.Backend { }, } - result := &Backend{Backend: s} - result.Backend.ConfigureFunc = result.configure - return result + return &Backend{Backend: s} +} + +type Backend struct { + *schema.Backend + + // The fields below are set from configure + s3Client *s3.S3 + dynClient *dynamodb.DynamoDB + + bucketName string + keyName string + serverSideEncryption bool + customerEncryptionKey []byte + acl string + kmsKeyID string + ddbTable string + workspaceKeyPrefix string } func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) { @@ -337,78 +351,62 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) return obj, diags } -type Backend struct { - *schema.Backend - - // The fields below are set from configure - s3Client *s3.S3 - dynClient *dynamodb.DynamoDB - - bucketName string - keyName string - serverSideEncryption bool - customerEncryptionKey []byte - acl string - kmsKeyID string - ddbTable string - workspaceKeyPrefix string -} - -func (b *Backend) configure(ctx context.Context) error { - if b.s3Client != nil { - return nil +func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { + var diags tfdiags.Diagnostics + if obj.IsNull() { + return diags } - // Grab the resource data - data := schema.FromContextBackendConfig(ctx) - - if !data.Get("skip_region_validation").(bool) { - if err := awsbase.ValidateRegion(data.Get("region").(string)); err != nil { - return err - } + var region string + if v, ok := stringAttrOk(obj, "region"); ok { + region = v } - b.bucketName = data.Get("bucket").(string) - b.keyName = data.Get("key").(string) - b.acl = data.Get("acl").(string) - b.workspaceKeyPrefix = data.Get("workspace_key_prefix").(string) - b.serverSideEncryption = data.Get("encrypt").(bool) - b.kmsKeyID = data.Get("kms_key_id").(string) - b.ddbTable = data.Get("dynamodb_table").(string) - - customerKeyString := data.Get("sse_customer_key").(string) - if customerKeyString != "" { - if b.kmsKeyID != "" { - return errors.New(encryptionKeyConflictError) + if boolAttr(obj, "skip_region_validation") { + if err := awsbase.ValidateRegion(region); err != nil { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid region value", + err.Error(), + cty.Path{cty.GetAttrStep{Name: "region"}}, + )) + return diags } + } - var err error - b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKeyString) - if err != nil { - return fmt.Errorf("Failed to decode sse_customer_key: %s", err.Error()) - } + b.bucketName = stringAttr(obj, "bucket") + b.keyName = stringAttr(obj, "key") + b.acl = stringAttr(obj, "acl") + b.workspaceKeyPrefix = stringAttrDefault(obj, "workspace_key_prefix", "env:") + b.serverSideEncryption = boolAttr(obj, "encrypt") + b.kmsKeyID = stringAttr(obj, "kms_key_id") + b.ddbTable = stringAttr(obj, "dynamodb_table") + + if customerKeyString, ok := stringAttrOk(obj, "sse_customer_key"); ok { + // Validation is handled in PrepareConfig, so ignore it here + b.customerEncryptionKey, _ = base64.StdEncoding.DecodeString(customerKeyString) } cfg := &awsbase.Config{ - AccessKey: data.Get("access_key").(string), - AssumeRoleARN: data.Get("role_arn").(string), - AssumeRoleDurationSeconds: data.Get("assume_role_duration_seconds").(int), - AssumeRoleExternalID: data.Get("external_id").(string), - AssumeRolePolicy: data.Get("assume_role_policy").(string), - AssumeRoleSessionName: data.Get("session_name").(string), + AccessKey: stringAttr(obj, "access_key"), + AssumeRoleARN: stringAttr(obj, "role_arn"), + AssumeRoleDurationSeconds: intAttr(obj, "assume_role_duration_seconds"), + AssumeRoleExternalID: stringAttr(obj, "external_id"), + AssumeRolePolicy: stringAttr(obj, "assume_role_policy"), + AssumeRoleSessionName: stringAttr(obj, "session_name"), CallerDocumentationURL: "https://www.terraform.io/docs/language/settings/backends/s3.html", CallerName: "S3 Backend", - CredsFilename: data.Get("shared_credentials_file").(string), + CredsFilename: stringAttr(obj, "shared_credentials_file"), DebugLogging: logging.IsDebugOrHigher(), - IamEndpoint: data.Get("iam_endpoint").(string), - MaxRetries: data.Get("max_retries").(int), - Profile: data.Get("profile").(string), - Region: data.Get("region").(string), - SecretKey: data.Get("secret_key").(string), - SkipCredsValidation: data.Get("skip_credentials_validation").(bool), - SkipMetadataApiCheck: data.Get("skip_metadata_api_check").(bool), - StsEndpoint: data.Get("sts_endpoint").(string), - Token: data.Get("token").(string), + IamEndpoint: stringAttr(obj, "iam_endpoint"), + MaxRetries: intAttrDefault(obj, "max_retries", 5), + Profile: stringAttr(obj, "profile"), + Region: stringAttr(obj, "region"), + SecretKey: stringAttr(obj, "secret_key"), + SkipCredsValidation: boolAttr(obj, "skip_credentials_validation"), + SkipMetadataApiCheck: boolAttr(obj, "skip_metadata_api_check"), + StsEndpoint: stringAttr(obj, "sts_endpoint"), + Token: stringAttr(obj, "token"), UserAgentProducts: []*awsbase.UserAgentProduct{ {Name: "APN", Version: "1.0"}, {Name: "HashiCorp", Version: "1.0"}, @@ -416,58 +414,124 @@ func (b *Backend) configure(ctx context.Context) error { }, } - if policyARNSet := data.Get("assume_role_policy_arns").(*schema.Set); policyARNSet.Len() > 0 { - for _, policyARNRaw := range policyARNSet.List() { - policyARN, ok := policyARNRaw.(string) - - if !ok { - continue + if policyARNSet := obj.GetAttr("assume_role_policy_arns"); !policyARNSet.IsNull() { + policyARNSet.ForEachElement(func(key, val cty.Value) (stop bool) { + v, ok := stringValueOk(val) + if ok { + cfg.AssumeRolePolicyARNs = append(cfg.AssumeRolePolicyARNs, v) } - - cfg.AssumeRolePolicyARNs = append(cfg.AssumeRolePolicyARNs, policyARN) - } + return + }) } - if tagMap := data.Get("assume_role_tags").(map[string]interface{}); len(tagMap) > 0 { - cfg.AssumeRoleTags = make(map[string]string) - - for k, vRaw := range tagMap { - v, ok := vRaw.(string) - - if !ok { - continue + if tagMap := obj.GetAttr("assume_role_tags"); !tagMap.IsNull() { + cfg.AssumeRoleTags = make(map[string]string, tagMap.LengthInt()) + tagMap.ForEachElement(func(key, val cty.Value) (stop bool) { + k := stringValue(key) + v, ok := stringValueOk(val) + if ok { + cfg.AssumeRoleTags[k] = v } - - cfg.AssumeRoleTags[k] = v - } + return + }) } - if transitiveTagKeySet := data.Get("assume_role_transitive_tag_keys").(*schema.Set); transitiveTagKeySet.Len() > 0 { - for _, transitiveTagKeyRaw := range transitiveTagKeySet.List() { - transitiveTagKey, ok := transitiveTagKeyRaw.(string) - - if !ok { - continue + if transitiveTagKeySet := obj.GetAttr("assume_role_transitive_tag_keys"); !transitiveTagKeySet.IsNull() { + transitiveTagKeySet.ForEachElement(func(key, val cty.Value) (stop bool) { + v, ok := stringValueOk(val) + if ok { + cfg.AssumeRoleTransitiveTagKeys = append(cfg.AssumeRoleTransitiveTagKeys, v) } - - cfg.AssumeRoleTransitiveTagKeys = append(cfg.AssumeRoleTransitiveTagKeys, transitiveTagKey) - } + return + }) } sess, err := awsbase.GetSession(cfg) if err != nil { - return fmt.Errorf("error configuring S3 Backend: %w", err) + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Failed to configure AWS client", + fmt.Sprintf(`The "S3" backend encountered an unexpected error while creating the AWS client: %s`, err), + )) + return diags } b.dynClient = dynamodb.New(sess.Copy(&aws.Config{ - Endpoint: aws.String(data.Get("dynamodb_endpoint").(string)), + Endpoint: aws.String(stringAttr(obj, "dynamodb_endpoint")), })) b.s3Client = s3.New(sess.Copy(&aws.Config{ - Endpoint: aws.String(data.Get("endpoint").(string)), - S3ForcePathStyle: aws.Bool(data.Get("force_path_style").(bool)), + Endpoint: aws.String(stringAttr(obj, "endpoint")), + S3ForcePathStyle: aws.Bool(boolAttr(obj, "force_path_style")), })) - return nil + return diags +} + +func stringValue(val cty.Value) string { + v, _ := stringValueOk(val) + return v +} + +func stringValueOk(val cty.Value) (string, bool) { + if val.IsNull() { + return "", false + } else { + return val.AsString(), true + } +} + +func stringAttr(obj cty.Value, name string) string { + return stringValue(obj.GetAttr(name)) +} + +func stringAttrOk(obj cty.Value, name string) (string, bool) { + return stringValueOk(obj.GetAttr(name)) +} + +func stringAttrDefault(obj cty.Value, name, def string) string { + if v, ok := stringAttrOk(obj, name); !ok { + return def + } else { + return v + } +} + +func boolAttr(obj cty.Value, name string) bool { + v, _ := boolAttrOk(obj, name) + return v +} + +func boolAttrOk(obj cty.Value, name string) (bool, bool) { + if val := obj.GetAttr(name); val.IsNull() { + return false, false + } else { + return val.True(), true + } +} + +func intAttr(obj cty.Value, name string) int { + v, _ := intAttrOk(obj, name) + return v +} + +func intAttrOk(obj cty.Value, name string) (int, bool) { + if val := obj.GetAttr(name); val.IsNull() { + return 0, false + } else { + var v int + if err := gocty.FromCtyValue(val, &v); err != nil { + return 0, false + } + return v, true + } +} + +func intAttrDefault(obj cty.Value, name string, def int) int { + if v, ok := intAttrOk(obj, name); !ok { + return def + } else { + return v + } } const encryptionKeyConflictError = `Only one of "kms_key_id" and "sse_customer_key" can be set. diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index feae781a037c..85a62add9ada 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -62,6 +62,9 @@ func TestBackendConfig(t *testing.T) { if *b.s3Client.Config.Region != "us-west-1" { t.Fatalf("Incorrect region was populated") } + if *b.s3Client.Config.MaxRetries != 5 { + t.Fatalf("Default max_retries was not set") + } if b.bucketName != "tf-test" { t.Fatalf("Incorrect bucketName was populated") } @@ -307,7 +310,8 @@ func TestBackendConfig_AssumeRole(t *testing.T) { testCase.Config["sts_endpoint"] = aws.StringValue(mockStsSession.Config.Endpoint) } - diags := New().Configure(hcl2shim.HCL2ValueFromConfigValue(testCase.Config)) + b := New() + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(testCase.Config))) if diags.HasErrors() { for _, diag := range diags { @@ -917,11 +921,15 @@ func populateSchema(t *testing.T, schema *configschema.Block, value cty.Value) c func unmarshal(value cty.Value, ty cty.Type, path cty.Path) (cty.Value, error) { switch { case ty.IsPrimitiveType(): - val, err := unmarshalPrimitive(value, ty, path) - if err != nil { - return cty.NilVal, err - } - return val, nil + return value, nil + // case ty.IsListType(): + // return unmarshalList(value, ty.ElementType(), path) + case ty.IsSetType(): + return unmarshalSet(value, ty.ElementType(), path) + case ty.IsMapType(): + return unmarshalMap(value, ty.ElementType(), path) + // case ty.IsTupleType(): + // return unmarshalTuple(value, ty.TupleElementTypes(), path) case ty.IsObjectType(): return unmarshalObject(value, ty.AttributeTypes(), path) default: @@ -929,8 +937,45 @@ func unmarshal(value cty.Value, ty cty.Type, path cty.Path) (cty.Value, error) { } } -func unmarshalPrimitive(value cty.Value, ty cty.Type, path cty.Path) (cty.Value, error) { - return value, nil +func unmarshalSet(dec cty.Value, ety cty.Type, path cty.Path) (cty.Value, error) { + if dec.IsNull() { + return dec, nil + } + + length := dec.LengthInt() + + if length == 0 { + return cty.SetValEmpty(ety), nil + } + + vals := make([]cty.Value, 0, length) + dec.ForEachElement(func(key, val cty.Value) (stop bool) { + vals = append(vals, val) + return + }) + + return cty.SetVal(vals), nil +} + +func unmarshalMap(dec cty.Value, ety cty.Type, path cty.Path) (cty.Value, error) { + if dec.IsNull() { + return dec, nil + } + + length := dec.LengthInt() + + if length == 0 { + return cty.MapValEmpty(ety), nil + } + + vals := make(map[string]cty.Value, length) + dec.ForEachElement(func(key, val cty.Value) (stop bool) { + k := stringValue(key) + vals[k] = val + return + }) + + return cty.MapVal(vals), nil } func unmarshalObject(dec cty.Value, atys map[string]cty.Type, path cty.Path) (cty.Value, error) { From 9bea21e8b28d7bcfdd4a464ce72b8a85e57ecdd5 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 26 Oct 2022 15:59:13 -0700 Subject: [PATCH 05/20] Moves schema to `ConfigSchema` and removes references to legacy schema --- internal/backend/remote-state/s3/backend.go | 161 +++++++------------- 1 file changed, 53 insertions(+), 108 deletions(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index cf651c50abdd..8c7b3fb585ae 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -14,7 +14,7 @@ import ( "github.com/aws/aws-sdk-go/service/s3" awsbase "github.com/hashicorp/aws-sdk-go-base" "github.com/hashicorp/terraform/internal/backend" - "github.com/hashicorp/terraform/internal/legacy/helper/schema" + "github.com/hashicorp/terraform/internal/configs/configschema" "github.com/hashicorp/terraform/internal/logging" "github.com/hashicorp/terraform/internal/tfdiags" "github.com/hashicorp/terraform/version" @@ -22,248 +22,193 @@ import ( "github.com/zclconf/go-cty/cty/gocty" ) -// New creates a new backend for S3 remote state. func New() backend.Backend { - s := &schema.Backend{ - Schema: map[string]*schema.Schema{ + return &Backend{} +} + +type Backend struct { + s3Client *s3.S3 + dynClient *dynamodb.DynamoDB + + bucketName string + keyName string + serverSideEncryption bool + customerEncryptionKey []byte + acl string + kmsKeyID string + ddbTable string + workspaceKeyPrefix string +} + +func (b *Backend) ConfigSchema() *configschema.Block { + return &configschema.Block{ + Attributes: map[string]*configschema.Attribute{ "bucket": { - Type: schema.TypeString, + Type: cty.String, Required: true, Description: "The name of the S3 bucket", }, - "key": { - Type: schema.TypeString, + Type: cty.String, Required: true, Description: "The path to the state file inside the bucket", }, - "region": { - Type: schema.TypeString, - Required: true, + Type: cty.String, + Optional: true, Description: "AWS region of the S3 Bucket and DynamoDB Table (if used).", - DefaultFunc: schema.MultiEnvDefaultFunc([]string{ - "AWS_REGION", - "AWS_DEFAULT_REGION", - }, nil), }, - "dynamodb_endpoint": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "A custom endpoint for the DynamoDB API", - DefaultFunc: schema.EnvDefaultFunc("AWS_DYNAMODB_ENDPOINT", ""), }, - "endpoint": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "A custom endpoint for the S3 API", - DefaultFunc: schema.EnvDefaultFunc("AWS_S3_ENDPOINT", ""), }, - "iam_endpoint": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "A custom endpoint for the IAM API", - DefaultFunc: schema.EnvDefaultFunc("AWS_IAM_ENDPOINT", ""), }, - "sts_endpoint": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "A custom endpoint for the STS API", - DefaultFunc: schema.EnvDefaultFunc("AWS_STS_ENDPOINT", ""), }, - "encrypt": { - Type: schema.TypeBool, + Type: cty.Bool, Optional: true, Description: "Whether to enable server side encryption of the state file", - Default: false, }, - "acl": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "Canned ACL to be applied to the state file", - Default: "", }, - "access_key": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "AWS access key", - Default: "", }, - "secret_key": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "AWS secret key", - Default: "", }, - "kms_key_id": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The ARN of a KMS Key to use for encrypting the state", - Default: "", }, - "dynamodb_table": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "DynamoDB table for state locking and consistency", - Default: "", }, - "profile": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "AWS profile name", - Default: "", }, - "shared_credentials_file": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "Path to a shared credentials file", - Default: "", }, - "token": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "MFA token", - Default: "", }, - "skip_credentials_validation": { - Type: schema.TypeBool, + Type: cty.String, Optional: true, Description: "Skip the credentials validation via STS API.", - Default: false, }, - "skip_region_validation": { - Type: schema.TypeBool, + Type: cty.Bool, Optional: true, Description: "Skip static validation of region name.", - Default: false, }, - "skip_metadata_api_check": { - Type: schema.TypeBool, + Type: cty.Bool, Optional: true, Description: "Skip the AWS Metadata API check.", - Default: false, }, - "sse_customer_key": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The base64-encoded encryption key to use for server-side encryption with customer-provided keys (SSE-C).", - DefaultFunc: schema.EnvDefaultFunc("AWS_SSE_CUSTOMER_KEY", ""), Sensitive: true, }, - "role_arn": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The role to be assumed", - Default: "", }, - "session_name": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The session name to use when assuming the role.", - Default: "", }, - "external_id": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The external ID to use when assuming the role", - Default: "", }, "assume_role_duration_seconds": { - Type: schema.TypeInt, + Type: cty.Number, Optional: true, Description: "Seconds to restrict the assume role session duration.", }, "assume_role_policy": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "IAM Policy JSON describing further restricting permissions for the IAM Role being assumed.", - Default: "", }, "assume_role_policy_arns": { - Type: schema.TypeSet, + Type: cty.Set(cty.String), Optional: true, Description: "Amazon Resource Names (ARNs) of IAM Policies describing further restricting permissions for the IAM Role being assumed.", - Elem: &schema.Schema{Type: schema.TypeString}, }, "assume_role_tags": { - Type: schema.TypeMap, + Type: cty.Map(cty.String), Optional: true, Description: "Assume role session tags.", - Elem: &schema.Schema{Type: schema.TypeString}, }, "assume_role_transitive_tag_keys": { - Type: schema.TypeSet, + Type: cty.Set(cty.String), Optional: true, Description: "Assume role session tag keys to pass to any subsequent sessions.", - Elem: &schema.Schema{Type: schema.TypeString}, }, "workspace_key_prefix": { - Type: schema.TypeString, + Type: cty.String, Optional: true, Description: "The prefix applied to the non-default state path inside the bucket.", - Default: "env:", }, "force_path_style": { - Type: schema.TypeBool, + Type: cty.Bool, Optional: true, Description: "Force s3 to use path style api.", - Default: false, }, "max_retries": { - Type: schema.TypeInt, + Type: cty.Number, Optional: true, Description: "The maximum number of times an AWS API request is retried on retryable failure.", - Default: 5, }, }, } - - return &Backend{Backend: s} -} - -type Backend struct { - *schema.Backend - - // The fields below are set from configure - s3Client *s3.S3 - dynClient *dynamodb.DynamoDB - - bucketName string - keyName string - serverSideEncryption bool - customerEncryptionKey []byte - acl string - kmsKeyID string - ddbTable string - workspaceKeyPrefix string } func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) { From 8d018cfef3795e8a8826a0701af0d7cd0304241f Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 26 Oct 2022 16:19:32 -0700 Subject: [PATCH 06/20] Adds test for setting region from envvars --- .../backend/remote-state/s3/backend_test.go | 47 ++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index 85a62add9ada..d0a13f05aba4 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -84,6 +84,49 @@ func TestBackendConfig(t *testing.T) { } } +func TestBackendConfig_RegionEnvVar(t *testing.T) { + testACC(t) + config := map[string]interface{}{ + "bucket": "tf-test", + "key": "state", + } + + cases := map[string]struct { + vars map[string]string + }{ + "AWS_REGION": { + vars: map[string]string{ + "AWS_REGION": "us-west-1", + }, + }, + + "AWS_DEFAULT_REGION": { + vars: map[string]string{ + "AWS_DEFAULT_REGION": "us-west-1", + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + for k, v := range tc.vars { + os.Setenv(k, v) + } + t.Cleanup(func() { + for k := range tc.vars { + os.Unsetenv(k) + } + }) + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) + + if *b.s3Client.Config.Region != "us-west-1" { + t.Fatalf("Incorrect region was populated") + } + }) + } +} + func TestBackendConfig_AssumeRole(t *testing.T) { testACC(t) @@ -468,7 +511,7 @@ func TestBackendConfig_PrepareConfigWithEnvVars(t *testing.T) { "region": cty.NullVal(cty.String), }), vars: map[string]string{ - "AWS_REGION": "us-west-2", + "AWS_REGION": "us-west-1", }, }, "region env var AWS_DEFAULT_REGION": { @@ -478,7 +521,7 @@ func TestBackendConfig_PrepareConfigWithEnvVars(t *testing.T) { "region": cty.NullVal(cty.String), }), vars: map[string]string{ - "AWS_DEFAULT_REGION": "us-west-2", + "AWS_DEFAULT_REGION": "us-west-1", }, }, } From 95eb523c0292830a67c52c379f72a65d2a09c879 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 26 Oct 2022 17:17:03 -0700 Subject: [PATCH 07/20] Sets service endpoints from envvar and adds tests --- internal/backend/remote-state/s3/backend.go | 46 +++++-- .../backend/remote-state/s3/backend_test.go | 126 ++++++++++++++++++ 2 files changed, 163 insertions(+), 9 deletions(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index 8c7b3fb585ae..959f40f8a7b3 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -343,14 +343,14 @@ func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { CallerName: "S3 Backend", CredsFilename: stringAttr(obj, "shared_credentials_file"), DebugLogging: logging.IsDebugOrHigher(), - IamEndpoint: stringAttr(obj, "iam_endpoint"), + IamEndpoint: stringAttrDefaultEnvVar(obj, "iam_endpoint", "AWS_IAM_ENDPOINT"), MaxRetries: intAttrDefault(obj, "max_retries", 5), Profile: stringAttr(obj, "profile"), Region: stringAttr(obj, "region"), SecretKey: stringAttr(obj, "secret_key"), SkipCredsValidation: boolAttr(obj, "skip_credentials_validation"), SkipMetadataApiCheck: boolAttr(obj, "skip_metadata_api_check"), - StsEndpoint: stringAttr(obj, "sts_endpoint"), + StsEndpoint: stringAttrDefaultEnvVar(obj, "sts_endpoint", "AWS_STS_ENDPOINT"), Token: stringAttr(obj, "token"), UserAgentProducts: []*awsbase.UserAgentProduct{ {Name: "APN", Version: "1.0"}, @@ -401,13 +401,20 @@ func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { return diags } - b.dynClient = dynamodb.New(sess.Copy(&aws.Config{ - Endpoint: aws.String(stringAttr(obj, "dynamodb_endpoint")), - })) - b.s3Client = s3.New(sess.Copy(&aws.Config{ - Endpoint: aws.String(stringAttr(obj, "endpoint")), - S3ForcePathStyle: aws.Bool(boolAttr(obj, "force_path_style")), - })) + var dynamoConfig aws.Config + if v, ok := stringAttrDefaultEnvVarOk(obj, "dynamodb_endpoint", "AWS_DYNAMODB_ENDPOINT"); ok { + dynamoConfig.Endpoint = aws.String(v) + } + b.dynClient = dynamodb.New(sess.Copy(&dynamoConfig)) + + var s3Config aws.Config + if v, ok := stringAttrDefaultEnvVarOk(obj, "endpoint", "AWS_S3_ENDPOINT"); ok { + s3Config.Endpoint = aws.String(v) + } + if v, ok := boolAttrOk(obj, "force_path_style"); ok { + s3Config.S3ForcePathStyle = aws.Bool(v) + } + b.s3Client = s3.New(sess.Copy(&s3Config)) return diags } @@ -441,6 +448,27 @@ func stringAttrDefault(obj cty.Value, name, def string) string { } } +func stringAttrDefaultEnvVar(obj cty.Value, name string, envvars ...string) string { + if v, ok := stringAttrDefaultEnvVarOk(obj, name, envvars...); !ok { + return "" + } else { + return v + } +} + +func stringAttrDefaultEnvVarOk(obj cty.Value, name string, envvars ...string) (string, bool) { + if v, ok := stringAttrOk(obj, name); !ok { + for _, envvar := range envvars { + if v := os.Getenv(envvar); v != "" { + return v, true + } + } + return "", false + } else { + return v, true + } +} + func boolAttr(obj cty.Value, name string) bool { v, _ := boolAttrOk(obj, name) return v diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index d0a13f05aba4..6662ad4dfee9 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -72,6 +72,10 @@ func TestBackendConfig(t *testing.T) { t.Fatalf("Incorrect keyName was populated") } + checkClientEndpoint(t, b.s3Client.Config, "") + + checkClientEndpoint(t, b.dynClient.Config, "") + credentials, err := b.s3Client.Config.Credentials.Get() if err != nil { t.Fatalf("Error when requesting credentials") @@ -84,6 +88,12 @@ func TestBackendConfig(t *testing.T) { } } +func checkClientEndpoint(t *testing.T, config aws.Config, expected string) { + if a := aws.StringValue(config.Endpoint); a != expected { + t.Errorf("expected endpoint %q, got %q", expected, a) + } +} + func TestBackendConfig_RegionEnvVar(t *testing.T) { testACC(t) config := map[string]interface{}{ @@ -127,6 +137,122 @@ func TestBackendConfig_RegionEnvVar(t *testing.T) { } } +func TestBackendConfig_DynamoDBEndpoint(t *testing.T) { + testACC(t) + config := map[string]interface{}{ + "region": "us-west-1", + "bucket": "tf-test", + "key": "state", + } + + vars := map[string]string{ + "AWS_DYNAMODB_ENDPOINT": "dynamo.test", + } + for k, v := range vars { + os.Setenv(k, v) + } + t.Cleanup(func() { + for k := range vars { + os.Unsetenv(k) + } + }) + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) + + checkClientEndpoint(t, b.dynClient.Config, "dynamo.test") +} + +func TestBackendConfig_S3Endpoint(t *testing.T) { + testACC(t) + config := map[string]interface{}{ + "region": "us-west-1", + "bucket": "tf-test", + "key": "state", + } + + vars := map[string]string{ + "AWS_S3_ENDPOINT": "s3.test", + } + for k, v := range vars { + os.Setenv(k, v) + } + t.Cleanup(func() { + for k := range vars { + os.Unsetenv(k) + } + }) + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) + + checkClientEndpoint(t, b.s3Client.Config, "s3.test") +} + +func TestBackendConfig_STSEndpoint(t *testing.T) { + testACC(t) + + testCases := []struct { + Config map[string]interface{} + Description string + MockStsEndpoints []*awsbase.MockEndpoint + }{ + { + Config: map[string]interface{}{ + "bucket": "tf-test", + "key": "state", + "region": "us-west-1", + "role_arn": awsbase.MockStsAssumeRoleArn, + "session_name": awsbase.MockStsAssumeRoleSessionName, + }, + Description: "role_arn", + MockStsEndpoints: []*awsbase.MockEndpoint{ + { + Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: url.Values{ + "Action": []string{"AssumeRole"}, + "DurationSeconds": []string{"900"}, + "RoleArn": []string{awsbase.MockStsAssumeRoleArn}, + "RoleSessionName": []string{awsbase.MockStsAssumeRoleSessionName}, + "Version": []string{"2011-06-15"}, + }.Encode()}, + Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsAssumeRoleValidResponseBody, ContentType: "text/xml"}, + }, + { + Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: mockStsGetCallerIdentityRequestBody}, + Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsGetCallerIdentityValidResponseBody, ContentType: "text/xml"}, + }, + }, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run(testCase.Description, func(t *testing.T) { + closeSts, mockStsSession, err := awsbase.GetMockedAwsApiSession("STS", testCase.MockStsEndpoints) + defer closeSts() + + if err != nil { + t.Fatalf("unexpected error creating mock STS server: %s", err) + } + + if mockStsSession != nil && mockStsSession.Config != nil { + os.Setenv("AWS_STS_ENDPOINT", aws.StringValue(mockStsSession.Config.Endpoint)) + t.Cleanup(func() { + os.Unsetenv("AWS_STS_ENDPOINT") + }) + } + + b := New() + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(testCase.Config))) + + if diags.HasErrors() { + for _, diag := range diags { + t.Errorf("unexpected error: %s", diag.Description().Summary) + } + } + }) + } +} + func TestBackendConfig_AssumeRole(t *testing.T) { testACC(t) From 2d12f242c754c4c6ac5b825214a6be6503f604dd Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Thu, 27 Oct 2022 11:07:41 -0700 Subject: [PATCH 08/20] Adds checks for not getting validation errors when they are expected --- internal/backend/remote-state/s3/backend.go | 12 +++++++- .../backend/remote-state/s3/backend_test.go | 28 +++++++++++++------ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index 959f40f8a7b3..bc47c33af639 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -217,6 +217,16 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) return obj, diags } + if val := obj.GetAttr("bucket"); val.IsNull() || val.AsString() == "" { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid bucket value", + // `The "bucket" attribute value must not be empty.`, + `"bucket": required field is not set`, + cty.Path{cty.GetAttrStep{Name: "bucket"}}, + )) + } + if val := obj.GetAttr("key"); val.IsNull() || val.AsString() == "" { diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, @@ -237,7 +247,7 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) )) } - if val := obj.GetAttr("region"); val.IsNull() { + if val := obj.GetAttr("region"); val.IsNull() || val.AsString() == "" { if os.Getenv("AWS_REGION") == "" && os.Getenv("AWS_DEFAULT_REGION") == "" { diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index 6662ad4dfee9..b9c8e51d60df 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -614,11 +614,17 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { b := New() _, valDiags := b.PrepareConfig(populateSchema(t, b.ConfigSchema(), tc.config)) - if valDiags.Err() != nil && tc.expectedErr != "" { - actualErr := valDiags.Err().Error() - if !strings.Contains(actualErr, tc.expectedErr) { - t.Fatalf("unexpected validation result: %v", valDiags.Err()) + if tc.expectedErr != "" { + if valDiags.Err() != nil { + actualErr := valDiags.Err().Error() + if !strings.Contains(actualErr, tc.expectedErr) { + t.Fatalf("unexpected validation result: %v", valDiags.Err()) + } + } else { + t.Fatal("expected an error, got none") } + } else if valDiags.Err() != nil { + t.Fatalf("expected no error, got %s", valDiags.Err()) } }) } @@ -666,11 +672,17 @@ func TestBackendConfig_PrepareConfigWithEnvVars(t *testing.T) { }) _, valDiags := b.PrepareConfig(populateSchema(t, b.ConfigSchema(), tc.config)) - if valDiags.Err() != nil && tc.expectedErr != "" { - actualErr := valDiags.Err().Error() - if !strings.Contains(actualErr, tc.expectedErr) { - t.Fatalf("unexpected validation result: %v", valDiags.Err()) + if tc.expectedErr != "" { + if valDiags.Err() != nil { + actualErr := valDiags.Err().Error() + if !strings.Contains(actualErr, tc.expectedErr) { + t.Fatalf("unexpected validation result: %v", valDiags.Err()) + } + } else { + t.Fatal("expected an error, got none") } + } else if valDiags.Err() != nil { + t.Fatalf("expected no error, got %s", valDiags.Err()) } }) } From 4eaa44c5a50c7dce18576fadd74138eb149996e4 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Thu, 27 Oct 2022 12:07:46 -0700 Subject: [PATCH 09/20] Adds functions for clearing all envvars --- .../backend/remote-state/s3/backend_test.go | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index b9c8e51d60df..d2d41d3ec610 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -611,6 +611,9 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { for name, tc := range cases { t.Run(name, func(t *testing.T) { + oldEnv := stashEnv() + defer popEnv(oldEnv) + b := New() _, valDiags := b.PrepareConfig(populateSchema(t, b.ConfigSchema(), tc.config)) @@ -660,16 +663,14 @@ func TestBackendConfig_PrepareConfigWithEnvVars(t *testing.T) { for name, tc := range cases { t.Run(name, func(t *testing.T) { + oldEnv := stashEnv() + defer popEnv(oldEnv) + b := New() for k, v := range tc.vars { os.Setenv(k, v) } - t.Cleanup(func() { - for k := range tc.vars { - os.Unsetenv(k) - } - }) _, valDiags := b.PrepareConfig(populateSchema(t, b.ConfigSchema(), tc.config)) if tc.expectedErr != "" { @@ -1185,3 +1186,22 @@ func unmarshalObject(dec cty.Value, atys map[string]cty.Type, path cty.Path) (ct return cty.ObjectVal(vals), nil } + +func stashEnv() []string { + env := os.Environ() + os.Clearenv() + return env +} + +func popEnv(env []string) { + os.Clearenv() + + for _, e := range env { + p := strings.SplitN(e, "=", 2) + k, v := p[0], "" + if len(p) > 1 { + v = p[1] + } + os.Setenv(k, v) + } +} From 827d7bd384a418b6b485fc88782379ccce3085e7 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Thu, 27 Oct 2022 14:39:48 -0700 Subject: [PATCH 10/20] Combines `sse_customer_key` and `AWS_SSE_CUSTOMER_KEY` validation --- internal/backend/remote-state/s3/backend.go | 65 ++++--- .../backend/remote-state/s3/backend_test.go | 158 ++++++++++++++---- 2 files changed, 168 insertions(+), 55 deletions(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index bc47c33af639..940553b50b21 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -258,29 +258,6 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) } } - if val := obj.GetAttr("sse_customer_key"); !val.IsNull() { - s := val.AsString() - if len(s) != 44 { - diags = diags.Append(tfdiags.AttributeValue( - tfdiags.Error, - "Invalid sse_customer_key value", - "sse_customer_key must be 44 characters in length", - cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, - )) - } else { - var err error - _, err = base64.StdEncoding.DecodeString(s) - if err != nil { - diags = diags.Append(tfdiags.AttributeValue( - tfdiags.Error, - "Invalid sse_customer_key value", - fmt.Sprintf("sse_customer_key must be base64 encoded: %s", err), - cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, - )) - } - } - } - if val := obj.GetAttr("kms_key_id"); !val.IsNull() && val.AsString() != "" { if val := obj.GetAttr("sse_customer_key"); !val.IsNull() && val.AsString() != "" { diags = diags.Append(tfdiags.AttributeValue( @@ -337,9 +314,45 @@ func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { b.kmsKeyID = stringAttr(obj, "kms_key_id") b.ddbTable = stringAttr(obj, "dynamodb_table") - if customerKeyString, ok := stringAttrOk(obj, "sse_customer_key"); ok { - // Validation is handled in PrepareConfig, so ignore it here - b.customerEncryptionKey, _ = base64.StdEncoding.DecodeString(customerKeyString) + // WarnOnEmptyString(), LenEquals(44), IsBase64Encoded() + if customerKey, ok := stringAttrOk(obj, "sse_customer_key"); ok { + if len(customerKey) != 44 { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid sse_customer_key value", + "sse_customer_key must be 44 characters in length", + cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, + )) + } else { + var err error + if b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKey); err != nil { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid sse_customer_key value", + fmt.Sprintf("sse_customer_key must be base64 encoded: %s", err), + cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, + )) + } + } + } else { + if customerKey := os.Getenv("AWS_SSE_CUSTOMER_KEY"); customerKey != "" { + if len(customerKey) != 44 { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid AWS_SSE_CUSTOMER_KEY value", + "AWS_SSE_CUSTOMER_KEY must be 44 characters in length", + )) + } else { + var err error + if b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKey); err != nil { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid AWS_SSE_CUSTOMER_KEY value", + fmt.Sprintf("AWS_SSE_CUSTOMER_KEY must be base64 encoded: %s", err), + )) + } + } + } } cfg := &awsbase.Config{ diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index d2d41d3ec610..ad57665ea8c5 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -4,6 +4,7 @@ package s3 import ( + "encoding/base64" "fmt" "net/url" "os" @@ -578,24 +579,6 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { }), expectedErr: `workspace_key_prefix must not start or end with '/'`, }, - "sse_customer_key invalid length": { - config: cty.ObjectVal(map[string]cty.Value{ - "bucket": cty.StringVal("test"), - "key": cty.StringVal("test"), - "region": cty.StringVal("us-west-2"), - "sse_customer_key": cty.StringVal("key"), - }), - expectedErr: `sse_customer_key must be 44 characters in length`, - }, - "sse_customer_key invalid encoding": { - config: cty.ObjectVal(map[string]cty.Value{ - "bucket": cty.StringVal("test"), - "key": cty.StringVal("test"), - "region": cty.StringVal("us-west-2"), - "sse_customer_key": cty.StringVal("====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka"), - }), - expectedErr: `sse_customer_key must be base64 encoded`, - }, "encyrption key conflict": { config: cty.ObjectVal(map[string]cty.Value{ "bucket": cty.StringVal("test"), @@ -736,21 +719,130 @@ func TestBackendLocked(t *testing.T) { backend.TestBackendStateForceUnlock(t, b1, b2) } -func TestBackendSSECustomerKey(t *testing.T) { +func TestBackendSSECustomerKeyConfig(t *testing.T) { testACC(t) - bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) - b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{ - "bucket": bucketName, - "encrypt": true, - "key": "test-SSE-C", - "sse_customer_key": "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=", - })).(*Backend) + testCases := map[string]struct { + customerKey string + expectedErr string + }{ + "invalid length": { + customerKey: "test", + expectedErr: `sse_customer_key must be 44 characters in length`, + }, + "invalid encoding": { + customerKey: "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka", + expectedErr: `sse_customer_key must be base64 encoded`, + }, + "valid": { + customerKey: "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=", + }, + } - createS3Bucket(t, b.s3Client, bucketName) - defer deleteS3Bucket(t, b.s3Client, bucketName) + for name, testCase := range testCases { + testCase := testCase - backend.TestBackendStates(t, b) + t.Run(name, func(t *testing.T) { + bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) + config := map[string]interface{}{ + "bucket": bucketName, + "encrypt": true, + "key": "test-SSE-C", + "sse_customer_key": testCase.customerKey, + } + + b := New().(*Backend) + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(config))) + + if testCase.expectedErr != "" { + if diags.Err() != nil { + actualErr := diags.Err().Error() + if !strings.Contains(actualErr, testCase.expectedErr) { + t.Fatalf("unexpected validation result: %v", diags.Err()) + } + } else { + t.Fatal("expected an error, got none") + } + } else { + if diags.Err() != nil { + t.Fatalf("expected no error, got %s", diags.Err()) + } + if string(b.customerEncryptionKey) != string(must(base64.StdEncoding.DecodeString(testCase.customerKey))) { + t.Fatal("unexpected value for customer encryption key") + } + + createS3Bucket(t, b.s3Client, bucketName) + defer deleteS3Bucket(t, b.s3Client, bucketName) + + backend.TestBackendStates(t, b) + } + }) + } +} + +func TestBackendSSECustomerKeyEnvVar(t *testing.T) { + testACC(t) + + testCases := map[string]struct { + customerKey string + expectedErr string + }{ + "invalid length": { + customerKey: "test", + expectedErr: `AWS_SSE_CUSTOMER_KEY must be 44 characters in length`, + }, + "invalid encoding": { + customerKey: "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka", + expectedErr: `AWS_SSE_CUSTOMER_KEY must be base64 encoded`, + }, + "valid": { + customerKey: "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=", + }, + } + + for name, testCase := range testCases { + testCase := testCase + + t.Run(name, func(t *testing.T) { + bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) + config := map[string]interface{}{ + "bucket": bucketName, + "encrypt": true, + "key": "test-SSE-C", + } + + os.Setenv("AWS_SSE_CUSTOMER_KEY", testCase.customerKey) + t.Cleanup(func() { + os.Unsetenv("AWS_SSE_CUSTOMER_KEY") + }) + + b := New().(*Backend) + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(config))) + + if testCase.expectedErr != "" { + if diags.Err() != nil { + actualErr := diags.Err().Error() + if !strings.Contains(actualErr, testCase.expectedErr) { + t.Fatalf("unexpected validation result: %v", diags.Err()) + } + } else { + t.Fatal("expected an error, got none") + } + } else { + if diags.Err() != nil { + t.Fatalf("expected no error, got %s", diags.Err()) + } + if string(b.customerEncryptionKey) != string(must(base64.StdEncoding.DecodeString(testCase.customerKey))) { + t.Fatal("unexpected value for customer encryption key") + } + + createS3Bucket(t, b.s3Client, bucketName) + defer deleteS3Bucket(t, b.s3Client, bucketName) + + backend.TestBackendStates(t, b) + } + }) + } } // add some extra junk in S3 to try and confuse the env listing. @@ -1205,3 +1297,11 @@ func popEnv(env []string) { os.Setenv(k, v) } } + +func must[T any](v T, err error) T { + if err != nil { + panic(err) + } else { + return v + } +} From e8c7722d3ea64f818231bcf557be368de996122d Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Thu, 27 Oct 2022 16:25:16 -0700 Subject: [PATCH 11/20] Restores conflict between `kms_key_id` and envvar `AWS_SSE_CUSTOMER_KEY` --- internal/backend/remote-state/s3/backend.go | 40 ++++++++++++------- .../backend/remote-state/s3/backend_test.go | 19 +++++++-- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index 940553b50b21..777edc548a04 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -266,6 +266,12 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) encryptionKeyConflictError, cty.Path{}, )) + } else if customerKey := os.Getenv("AWS_SSE_CUSTOMER_KEY"); customerKey != "" { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid encryption configuration", + encryptionKeyConflictEnvVarError, + )) } } @@ -334,23 +340,21 @@ func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { )) } } - } else { - if customerKey := os.Getenv("AWS_SSE_CUSTOMER_KEY"); customerKey != "" { - if len(customerKey) != 44 { + } else if customerKey := os.Getenv("AWS_SSE_CUSTOMER_KEY"); customerKey != "" { + if len(customerKey) != 44 { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid AWS_SSE_CUSTOMER_KEY value", + `The environment variable "AWS_SSE_CUSTOMER_KEY" must be 44 characters in length`, + )) + } else { + var err error + if b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKey); err != nil { diags = diags.Append(tfdiags.Sourceless( tfdiags.Error, "Invalid AWS_SSE_CUSTOMER_KEY value", - "AWS_SSE_CUSTOMER_KEY must be 44 characters in length", + fmt.Sprintf(`The environment variable "AWS_SSE_CUSTOMER_KEY" must be base64 encoded: %s`, err), )) - } else { - var err error - if b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKey); err != nil { - diags = diags.Append(tfdiags.Sourceless( - tfdiags.Error, - "Invalid AWS_SSE_CUSTOMER_KEY value", - fmt.Sprintf("AWS_SSE_CUSTOMER_KEY must be base64 encoded: %s", err), - )) - } } } } @@ -532,6 +536,12 @@ func intAttrDefault(obj cty.Value, name string, def int) int { const encryptionKeyConflictError = `Only one of "kms_key_id" and "sse_customer_key" can be set. -The kms_key_id is used for encryption with KMS-Managed Keys (SSE-KMS) -while sse_customer_key is used for encryption with customer-managed keys (SSE-C). +The "kms_key_id" is used for encryption with KMS-Managed Keys (SSE-KMS) +while "sse_customer_key" is used for encryption with customer-managed keys (SSE-C). +Please choose one or the other.` + +const encryptionKeyConflictEnvVarError = `Only one of "kms_key_id" and the environment variable "AWS_SSE_CUSTOMER_KEY" can be set. + +The "kms_key_id" is used for encryption with KMS-Managed Keys (SSE-KMS) +while "AWS_SSE_CUSTOMER_KEY" is used for encryption with customer-managed keys (SSE-C). Please choose one or the other.` diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index ad57665ea8c5..adc1038a8a7f 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -584,7 +584,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "bucket": cty.StringVal("test"), "key": cty.StringVal("test"), "region": cty.StringVal("us-west-2"), - "workspace_key_prefix": cty.StringVal("env/"), + "workspace_key_prefix": cty.StringVal("env"), "sse_customer_key": cty.StringVal("1hwbcNPGWL+AwDiyGmRidTWAEVmCWMKbEHA+Es8w75o="), "kms_key_id": cty.StringVal("arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab"), }), @@ -642,6 +642,19 @@ func TestBackendConfig_PrepareConfigWithEnvVars(t *testing.T) { "AWS_DEFAULT_REGION": "us-west-1", }, }, + "encyrption key conflict": { + config: cty.ObjectVal(map[string]cty.Value{ + "bucket": cty.StringVal("test"), + "key": cty.StringVal("test"), + "region": cty.StringVal("us-west-2"), + "workspace_key_prefix": cty.StringVal("env"), + "kms_key_id": cty.StringVal("arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab"), + }), + vars: map[string]string{ + "AWS_SSE_CUSTOMER_KEY": "1hwbcNPGWL+AwDiyGmRidTWAEVmCWMKbEHA+Es8w75o=", + }, + expectedErr: `Only one of "kms_key_id" and the environment variable "AWS_SSE_CUSTOMER_KEY" can be set`, + }, } for name, tc := range cases { @@ -789,11 +802,11 @@ func TestBackendSSECustomerKeyEnvVar(t *testing.T) { }{ "invalid length": { customerKey: "test", - expectedErr: `AWS_SSE_CUSTOMER_KEY must be 44 characters in length`, + expectedErr: `The environment variable "AWS_SSE_CUSTOMER_KEY" must be 44 characters in length`, }, "invalid encoding": { customerKey: "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka", - expectedErr: `AWS_SSE_CUSTOMER_KEY must be base64 encoded`, + expectedErr: `The environment variable "AWS_SSE_CUSTOMER_KEY" must be base64 encoded`, }, "valid": { customerKey: "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=", From 2fda09aab27c4a213ab8fa613db9db4fffc70fa2 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Mon, 26 Jun 2023 16:01:05 -0700 Subject: [PATCH 12/20] Updates attribute validation messages --- internal/backend/remote-state/s3/backend.go | 26 ++++++++++++++----- .../backend/remote-state/s3/backend_test.go | 20 +++++++------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index 777edc548a04..6b0707064781 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -40,6 +40,8 @@ type Backend struct { workspaceKeyPrefix string } +// ConfigSchema returns a description of the expected configuration +// structure for the receiving backend. func (b *Backend) ConfigSchema() *configschema.Block { return &configschema.Block{ Attributes: map[string]*configschema.Attribute{ @@ -211,6 +213,10 @@ func (b *Backend) ConfigSchema() *configschema.Block { } } +// PrepareConfig checks the validity of the values in the given +// configuration, and inserts any missing defaults, assuming that its +// structure has already been validated per the schema returned by +// ConfigSchema. func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) { var diags tfdiags.Diagnostics if obj.IsNull() { @@ -221,8 +227,7 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, "Invalid bucket value", - // `The "bucket" attribute value must not be empty.`, - `"bucket": required field is not set`, + `The "bucket" attribute value must not be empty.`, cty.Path{cty.GetAttrStep{Name: "bucket"}}, )) } @@ -231,7 +236,7 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, "Invalid key value", - `"key": required field is not set`, + `The "key" attribute value must not be empty.`, cty.Path{cty.GetAttrStep{Name: "key"}}, )) } else if strings.HasPrefix(val.AsString(), "/") || strings.HasSuffix(val.AsString(), "/") { @@ -242,7 +247,7 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, "Invalid key value", - "key must not start or end with '/'", + `The "key" attribute value must not start or end with with "/".`, cty.Path{cty.GetAttrStep{Name: "key"}}, )) } @@ -252,7 +257,7 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, "Missing region value", - `"region": required field is not set`, + `The "region" attribute or the "AWS_REGION" or "AWS_DEFAULT_REGION" environment variables must be set.`, cty.Path{cty.GetAttrStep{Name: "region"}}, )) } @@ -267,10 +272,11 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) cty.Path{}, )) } else if customerKey := os.Getenv("AWS_SSE_CUSTOMER_KEY"); customerKey != "" { - diags = diags.Append(tfdiags.Sourceless( + diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, "Invalid encryption configuration", encryptionKeyConflictEnvVarError, + cty.Path{}, )) } } @@ -280,7 +286,7 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, "Invalid workspace_key_prefix value", - "workspace_key_prefix must not start or end with '/'", + `The "workspace_key_prefix" attribute value must not start with "/".`, cty.Path{cty.GetAttrStep{Name: "workspace_key_prefix"}}, )) } @@ -289,6 +295,12 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) return obj, diags } +// Configure uses the provided configuration to set configuration fields +// within the backend. +// +// The given configuration is assumed to have already been validated +// against the schema returned by ConfigSchema and passed validation +// via PrepareConfig. func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { var diags tfdiags.Diagnostics if obj.IsNull() { diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index adc1038a8a7f..95f97bf21bae 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -503,7 +503,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "key": cty.StringVal("test"), "region": cty.StringVal("us-west-2"), }), - expectedErr: `"bucket": required field is not set`, + expectedErr: `The "bucket" attribute value must not be empty.`, }, "empty bucket": { config: cty.ObjectVal(map[string]cty.Value{ @@ -511,7 +511,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "key": cty.StringVal("test"), "region": cty.StringVal("us-west-2"), }), - expectedErr: `"bucket": required field is not set`, + expectedErr: `The "bucket" attribute value must not be empty.`, }, "null key": { config: cty.ObjectVal(map[string]cty.Value{ @@ -519,7 +519,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "key": cty.NullVal(cty.String), "region": cty.StringVal("us-west-2"), }), - expectedErr: `"key": required field is not set`, + expectedErr: `The "key" attribute value must not be empty.`, }, "empty key": { config: cty.ObjectVal(map[string]cty.Value{ @@ -527,7 +527,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "key": cty.StringVal(""), "region": cty.StringVal("us-west-2"), }), - expectedErr: `"key": required field is not set`, + expectedErr: `The "key" attribute value must not be empty.`, }, "key with leading slash": { config: cty.ObjectVal(map[string]cty.Value{ @@ -535,7 +535,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "key": cty.StringVal("/leading-slash"), "region": cty.StringVal("us-west-2"), }), - expectedErr: `key must not start or end with '/'`, + expectedErr: `The "key" attribute value must not start or end with with "/".`, }, "key with trailing slash": { config: cty.ObjectVal(map[string]cty.Value{ @@ -543,7 +543,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "key": cty.StringVal("trailing-slash/"), "region": cty.StringVal("us-west-2"), }), - expectedErr: `key must not start or end with '/'`, + expectedErr: `The "key" attribute value must not start or end with with "/".`, }, "null region": { config: cty.ObjectVal(map[string]cty.Value{ @@ -551,7 +551,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "key": cty.StringVal("test"), "region": cty.NullVal(cty.String), }), - expectedErr: `"region": required field is not set`, + expectedErr: `The "region" attribute or the "AWS_REGION" or "AWS_DEFAULT_REGION" environment variables must be set.`, }, "empty region": { config: cty.ObjectVal(map[string]cty.Value{ @@ -559,7 +559,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "key": cty.StringVal("test"), "region": cty.StringVal(""), }), - expectedErr: `"region": required field is not set`, + expectedErr: `The "region" attribute or the "AWS_REGION" or "AWS_DEFAULT_REGION" environment variables must be set.`, }, "workspace_key_prefix with leading slash": { config: cty.ObjectVal(map[string]cty.Value{ @@ -568,7 +568,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "region": cty.StringVal("us-west-2"), "workspace_key_prefix": cty.StringVal("/env"), }), - expectedErr: `workspace_key_prefix must not start or end with '/'`, + expectedErr: `The "workspace_key_prefix" attribute value must not start with "/".`, }, "workspace_key_prefix with trailing slash": { config: cty.ObjectVal(map[string]cty.Value{ @@ -577,7 +577,7 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { "region": cty.StringVal("us-west-2"), "workspace_key_prefix": cty.StringVal("env/"), }), - expectedErr: `workspace_key_prefix must not start or end with '/'`, + expectedErr: `The "workspace_key_prefix" attribute value must not start with "/".`, }, "encyrption key conflict": { config: cty.ObjectVal(map[string]cty.Value{ From 454eed63e7791aeee929b82908fc5b5b0dcc77e6 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Mon, 26 Jun 2023 16:41:46 -0700 Subject: [PATCH 13/20] Adds KMS Key validation --- internal/backend/remote-state/s3/backend.go | 2 + internal/backend/remote-state/s3/testing.go | 26 +++ internal/backend/remote-state/s3/validate.go | 80 +++++++++ .../backend/remote-state/s3/validate_test.go | 154 ++++++++++++++++++ 4 files changed, 262 insertions(+) create mode 100644 internal/backend/remote-state/s3/testing.go create mode 100644 internal/backend/remote-state/s3/validate.go create mode 100644 internal/backend/remote-state/s3/validate_test.go diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index 6b0707064781..791576a723c3 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -279,6 +279,8 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) cty.Path{}, )) } + + diags = diags.Append(validateKMSKey(cty.Path{cty.GetAttrStep{Name: "kms_key_id"}}, val.AsString())) } if val := obj.GetAttr("workspace_key_prefix"); !val.IsNull() { diff --git a/internal/backend/remote-state/s3/testing.go b/internal/backend/remote-state/s3/testing.go new file mode 100644 index 000000000000..25c58625a2c0 --- /dev/null +++ b/internal/backend/remote-state/s3/testing.go @@ -0,0 +1,26 @@ +package s3 + +import ( + "github.com/hashicorp/terraform/internal/tfdiags" +) + +// diagnosticComparer is a Comparer function for use with cmp.Diff to compare two tfdiags.Diagnostic values +func diagnosticComparer(l, r tfdiags.Diagnostic) bool { + if l.Severity() != r.Severity() { + return false + } + if l.Description() != r.Description() { + return false + } + + lp := tfdiags.GetAttribute(l) + rp := tfdiags.GetAttribute(r) + if len(lp) != len(rp) { + return false + } + if !lp.Equals(rp) { + return false + } + + return true +} diff --git a/internal/backend/remote-state/s3/validate.go b/internal/backend/remote-state/s3/validate.go new file mode 100644 index 000000000000..6f8c0a1a2a21 --- /dev/null +++ b/internal/backend/remote-state/s3/validate.go @@ -0,0 +1,80 @@ +package s3 + +import ( + "fmt" + "regexp" + + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/hashicorp/terraform/internal/tfdiags" + "github.com/zclconf/go-cty/cty" +) + +const ( + multiRegionKeyIdPattern = `mrk-[a-f0-9]{32}` + uuidRegexPattern = `[a-f0-9]{8}-[a-f0-9]{4}-[1-5][a-f0-9]{3}-[ab89][a-f0-9]{3}-[a-f0-9]{12}` +) + +func validateKMSKey(path cty.Path, s string) (diags tfdiags.Diagnostics) { + if arn.IsARN(s) { + return validateKMSKeyARN(path, s) + } + return validateKMSKeyID(path, s) +} + +func validateKMSKeyID(path cty.Path, s string) (diags tfdiags.Diagnostics) { + keyIdRegex := regexp.MustCompile(`^` + uuidRegexPattern + `|` + multiRegionKeyIdPattern + `$`) + if !keyIdRegex.MatchString(s) { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ID", + fmt.Sprintf("Value must be a valid KMS Key ID, got %q", s), + path, + )) + return diags + } + + return diags +} + +func validateKMSKeyARN(path cty.Path, s string) (diags tfdiags.Diagnostics) { + if _, err := arn.Parse(s); err != nil { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + fmt.Sprintf("Value must be a valid KMS Key ARN, got %q", s), + path, + )) + return diags + } + + if !isKeyARN(s) { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + fmt.Sprintf("Value must be a valid KMS Key ARN, got %q", s), + path, + )) + return diags + } + + return diags +} + +func isKeyARN(s string) bool { + parsedARN, err := arn.Parse(s) + if err != nil { + return false + } + + return keyIdFromARNResource(parsedARN.Resource) != "" +} + +func keyIdFromARNResource(s string) string { + keyIdResourceRegex := regexp.MustCompile(`^key/(` + uuidRegexPattern + `|` + multiRegionKeyIdPattern + `)$`) + matches := keyIdResourceRegex.FindStringSubmatch(s) + if matches == nil || len(matches) != 2 { + return "" + } + + return matches[1] +} diff --git a/internal/backend/remote-state/s3/validate_test.go b/internal/backend/remote-state/s3/validate_test.go new file mode 100644 index 000000000000..a4d5e32ca390 --- /dev/null +++ b/internal/backend/remote-state/s3/validate_test.go @@ -0,0 +1,154 @@ +package s3 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/terraform/internal/tfdiags" + "github.com/zclconf/go-cty/cty" +) + +func TestValidateKMSKey(t *testing.T) { + t.Parallel() + + path := cty.Path{cty.GetAttrStep{Name: "field"}} + + testcases := map[string]struct { + in string + expected tfdiags.Diagnostics + }{ + "kms key id": { + in: "57ff7a43-341d-46b6-aee3-a450c9de6dc8", + }, + "kms key arn": { + in: "arn:aws:kms:us-west-2:111122223333:key/57ff7a43-341d-46b6-aee3-a450c9de6dc8", + }, + "kms multi-region key id": { + in: "mrk-f827515944fb43f9b902a09d2c8b554f", + }, + "kms multi-region key arn": { + in: "arn:aws:kms:us-west-2:111122223333:key/mrk-a835af0b39c94b86a21a8fc9535df681", + }, + "kms key alias": { + in: "alias/arbitrary-key", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ID", + `Value must be a valid KMS Key ID, got "alias/arbitrary-key"`, + path, + ), + }, + }, + "kms key alias arn": { + in: "arn:aws:kms:us-west-2:111122223333:alias/arbitrary-key", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + `Value must be a valid KMS Key ARN, got "arn:aws:kms:us-west-2:111122223333:alias/arbitrary-key"`, + path, + ), + }, + }, + "invalid key": { + in: "$%wrongkey", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ID", + `Value must be a valid KMS Key ID, got "$%wrongkey"`, + path, + ), + }, + }, + "non-kms arn": { + in: "arn:aws:lamda:foo:bar:key/xyz", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + `Value must be a valid KMS Key ARN, got "arn:aws:lamda:foo:bar:key/xyz"`, + path, + ), + }, + }, + } + + for name, testcase := range testcases { + testcase := testcase + t.Run(name, func(t *testing.T) { + t.Parallel() + + diags := validateKMSKey(path, testcase.in) + + if diff := cmp.Diff(diags, testcase.expected, cmp.Comparer(diagnosticComparer)); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + }) + } +} + +func TestValidateKeyARN(t *testing.T) { + t.Parallel() + + path := cty.Path{cty.GetAttrStep{Name: "field"}} + + testcases := map[string]struct { + in string + expected tfdiags.Diagnostics + }{ + "kms key id": { + in: "arn:aws:kms:us-west-2:123456789012:key/57ff7a43-341d-46b6-aee3-a450c9de6dc8", + }, + "kms mrk key id": { + in: "arn:aws:kms:us-west-2:111122223333:key/mrk-a835af0b39c94b86a21a8fc9535df681", + }, + "kms non-key id": { + in: "arn:aws:kms:us-west-2:123456789012:something/else", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + `Value must be a valid KMS Key ARN, got "arn:aws:kms:us-west-2:123456789012:something/else"`, + path, + ), + }, + }, + "non-kms arn": { + in: "arn:aws:iam::123456789012:user/David", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + `Value must be a valid KMS Key ARN, got "arn:aws:iam::123456789012:user/David"`, + path, + ), + }, + }, + "not an arn": { + in: "not an arn", + expected: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid KMS Key ARN", + `Value must be a valid KMS Key ARN, got "not an arn"`, + path, + ), + }, + }, + } + + for name, testcase := range testcases { + testcase := testcase + t.Run(name, func(t *testing.T) { + t.Parallel() + + diags := validateKMSKeyARN(path, testcase.in) + + if diff := cmp.Diff(diags, testcase.expected, cmp.Comparer(diagnosticComparer)); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + }) + } +} From c3f4f9cedde3aa550057fa08659402aa9ba768dc Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 5 Jul 2023 17:33:20 -0700 Subject: [PATCH 14/20] Fixes region validation --- internal/backend/remote-state/s3/backend.go | 2 +- .../backend/remote-state/s3/backend_test.go | 62 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index 791576a723c3..3a50e9705a69 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -314,7 +314,7 @@ func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { region = v } - if boolAttr(obj, "skip_region_validation") { + if region != "" && !boolAttr(obj, "skip_region_validation") { if err := awsbase.ValidateRegion(region); err != nil { diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index 95f97bf21bae..fc7995139f4b 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -16,12 +16,14 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/s3" + "github.com/google/go-cmp/cmp" awsbase "github.com/hashicorp/aws-sdk-go-base" "github.com/hashicorp/terraform/internal/backend" "github.com/hashicorp/terraform/internal/configs/configschema" "github.com/hashicorp/terraform/internal/configs/hcl2shim" "github.com/hashicorp/terraform/internal/states" "github.com/hashicorp/terraform/internal/states/remote" + "github.com/hashicorp/terraform/internal/tfdiags" "github.com/zclconf/go-cty/cty" ) @@ -95,6 +97,61 @@ func checkClientEndpoint(t *testing.T, config aws.Config, expected string) { } } +func TestBackendConfig_InvalidRegion(t *testing.T) { + testACC(t) + + cases := map[string]struct { + config map[string]any + expectedDiags tfdiags.Diagnostics + }{ + "with region validation": { + config: map[string]interface{}{ + "region": "nonesuch", + "bucket": "tf-test", + "key": "state", + "skip_credentials_validation": true, + }, + expectedDiags: tfdiags.Diagnostics{ + tfdiags.AttributeValue( + tfdiags.Error, + "Invalid region value", + `Invalid AWS Region: nonesuch`, + cty.Path{cty.GetAttrStep{Name: "region"}}, + ), + }, + }, + "skip region validation": { + config: map[string]interface{}{ + "region": "nonesuch", + "bucket": "tf-test", + "key": "state", + "skip_region_validation": true, + "skip_credentials_validation": true, + }, + expectedDiags: nil, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + b := New() + configSchema := populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(tc.config)) + + configSchema, diags := b.PrepareConfig(configSchema) + if len(diags) > 0 { + t.Fatal(diags.ErrWithWarnings()) + } + + confDiags := b.Configure(configSchema) + diags = diags.Append(confDiags) + + if diff := cmp.Diff(diags, tc.expectedDiags, cmp.Comparer(diagnosticComparer)); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } + }) + } +} + func TestBackendConfig_RegionEnvVar(t *testing.T) { testACC(t) config := map[string]interface{}{ @@ -695,6 +752,7 @@ func TestBackend(t *testing.T) { "bucket": bucketName, "key": keyName, "encrypt": true, + "region": "us-west-1", })).(*Backend) createS3Bucket(t, b.s3Client, bucketName) @@ -714,6 +772,7 @@ func TestBackendLocked(t *testing.T) { "key": keyName, "encrypt": true, "dynamodb_table": bucketName, + "region": "us-west-1", })).(*Backend) b2 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{ @@ -721,6 +780,7 @@ func TestBackendLocked(t *testing.T) { "key": keyName, "encrypt": true, "dynamodb_table": bucketName, + "region": "us-west-1", })).(*Backend) createS3Bucket(t, b1.s3Client, bucketName) @@ -762,6 +822,7 @@ func TestBackendSSECustomerKeyConfig(t *testing.T) { "encrypt": true, "key": "test-SSE-C", "sse_customer_key": testCase.customerKey, + "region": "us-west-1", } b := New().(*Backend) @@ -822,6 +883,7 @@ func TestBackendSSECustomerKeyEnvVar(t *testing.T) { "bucket": bucketName, "encrypt": true, "key": "test-SSE-C", + "region": "us-west-1", } os.Setenv("AWS_SSE_CUSTOMER_KEY", testCase.customerKey) From 29e14d148be7ba3e94283e43f2e8a0e2c2edcfd6 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 5 Jul 2023 17:33:36 -0700 Subject: [PATCH 15/20] Reorders `skip_...` parameters --- internal/backend/remote-state/s3/backend.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index 3a50e9705a69..26ea72c9fa91 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -130,15 +130,15 @@ func (b *Backend) ConfigSchema() *configschema.Block { Optional: true, Description: "Skip the credentials validation via STS API.", }, - "skip_region_validation": { + "skip_metadata_api_check": { Type: cty.Bool, Optional: true, - Description: "Skip static validation of region name.", + Description: "Skip the AWS Metadata API check.", }, - "skip_metadata_api_check": { + "skip_region_validation": { Type: cty.Bool, Optional: true, - Description: "Skip the AWS Metadata API check.", + Description: "Skip static validation of region name.", }, "sse_customer_key": { Type: cty.String, From 90c10ebbe9b56bb3701dc17fa64d852f0ce29bb9 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Thu, 6 Jul 2023 14:58:11 -0700 Subject: [PATCH 16/20] Removes redundant ARN parsing --- internal/backend/remote-state/s3/validate.go | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/internal/backend/remote-state/s3/validate.go b/internal/backend/remote-state/s3/validate.go index 6f8c0a1a2a21..97a9e1087e7a 100644 --- a/internal/backend/remote-state/s3/validate.go +++ b/internal/backend/remote-state/s3/validate.go @@ -37,7 +37,8 @@ func validateKMSKeyID(path cty.Path, s string) (diags tfdiags.Diagnostics) { } func validateKMSKeyARN(path cty.Path, s string) (diags tfdiags.Diagnostics) { - if _, err := arn.Parse(s); err != nil { + parsedARN, err := arn.Parse(s) + if err != nil { diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, "Invalid KMS Key ARN", @@ -47,7 +48,7 @@ func validateKMSKeyARN(path cty.Path, s string) (diags tfdiags.Diagnostics) { return diags } - if !isKeyARN(s) { + if !isKeyARN(parsedARN) { diags = diags.Append(tfdiags.AttributeValue( tfdiags.Error, "Invalid KMS Key ARN", @@ -60,13 +61,8 @@ func validateKMSKeyARN(path cty.Path, s string) (diags tfdiags.Diagnostics) { return diags } -func isKeyARN(s string) bool { - parsedARN, err := arn.Parse(s) - if err != nil { - return false - } - - return keyIdFromARNResource(parsedARN.Resource) != "" +func isKeyARN(arn arn.ARN) bool { + return keyIdFromARNResource(arn.Resource) != "" } func keyIdFromARNResource(s string) string { From 2f00c86255ff85778941878d7bb4da0945748f4a Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Thu, 6 Jul 2023 15:00:02 -0700 Subject: [PATCH 17/20] Adds endpoint tests when configured in configuration --- internal/backend/remote-state/s3/backend.go | 1 - .../backend/remote-state/s3/backend_test.go | 97 ++++++++++++++++++- 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index 26ea72c9fa91..f1d0345562a6 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -334,7 +334,6 @@ func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { b.kmsKeyID = stringAttr(obj, "kms_key_id") b.ddbTable = stringAttr(obj, "dynamodb_table") - // WarnOnEmptyString(), LenEquals(44), IsBase64Encoded() if customerKey, ok := stringAttrOk(obj, "sse_customer_key"); ok { if len(customerKey) != 44 { diags = diags.Append(tfdiags.AttributeValue( diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index fc7995139f4b..7674590937b4 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -195,7 +195,7 @@ func TestBackendConfig_RegionEnvVar(t *testing.T) { } } -func TestBackendConfig_DynamoDBEndpoint(t *testing.T) { +func TestBackendConfig_DynamoDBEndpointEnvVar(t *testing.T) { testACC(t) config := map[string]interface{}{ "region": "us-west-1", @@ -220,7 +220,21 @@ func TestBackendConfig_DynamoDBEndpoint(t *testing.T) { checkClientEndpoint(t, b.dynClient.Config, "dynamo.test") } -func TestBackendConfig_S3Endpoint(t *testing.T) { +func TestBackendConfig_DynamoDBEndpointConfig(t *testing.T) { + testACC(t) + config := map[string]interface{}{ + "region": "us-west-1", + "bucket": "tf-test", + "key": "state", + "dynamodb_endpoint": "dynamo.test", + } + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) + + checkClientEndpoint(t, b.dynClient.Config, "dynamo.test") +} + +func TestBackendConfig_S3EndpointEnvVar(t *testing.T) { testACC(t) config := map[string]interface{}{ "region": "us-west-1", @@ -245,7 +259,21 @@ func TestBackendConfig_S3Endpoint(t *testing.T) { checkClientEndpoint(t, b.s3Client.Config, "s3.test") } -func TestBackendConfig_STSEndpoint(t *testing.T) { +func TestBackendConfig_S3EndpointConfig(t *testing.T) { + testACC(t) + config := map[string]interface{}{ + "region": "us-west-1", + "bucket": "tf-test", + "key": "state", + "endpoint": "s3.test", + } + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) + + checkClientEndpoint(t, b.s3Client.Config, "s3.test") +} + +func TestBackendConfig_STSEndpointEnvVar(t *testing.T) { testACC(t) testCases := []struct { @@ -311,6 +339,69 @@ func TestBackendConfig_STSEndpoint(t *testing.T) { } } +func TestBackendConfig_STSEndpointConfig(t *testing.T) { + testACC(t) + + testCases := []struct { + Config map[string]interface{} + Description string + MockStsEndpoints []*awsbase.MockEndpoint + }{ + { + Config: map[string]interface{}{ + "bucket": "tf-test", + "key": "state", + "region": "us-west-1", + "role_arn": awsbase.MockStsAssumeRoleArn, + "session_name": awsbase.MockStsAssumeRoleSessionName, + }, + Description: "role_arn", + MockStsEndpoints: []*awsbase.MockEndpoint{ + { + Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: url.Values{ + "Action": []string{"AssumeRole"}, + "DurationSeconds": []string{"900"}, + "RoleArn": []string{awsbase.MockStsAssumeRoleArn}, + "RoleSessionName": []string{awsbase.MockStsAssumeRoleSessionName}, + "Version": []string{"2011-06-15"}, + }.Encode()}, + Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsAssumeRoleValidResponseBody, ContentType: "text/xml"}, + }, + { + Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: mockStsGetCallerIdentityRequestBody}, + Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsGetCallerIdentityValidResponseBody, ContentType: "text/xml"}, + }, + }, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run(testCase.Description, func(t *testing.T) { + closeSts, mockStsSession, err := awsbase.GetMockedAwsApiSession("STS", testCase.MockStsEndpoints) + defer closeSts() + + if err != nil { + t.Fatalf("unexpected error creating mock STS server: %s", err) + } + + if mockStsSession != nil && mockStsSession.Config != nil { + testCase.Config["sts_endpoint"] = aws.StringValue(mockStsSession.Config.Endpoint) + } + + b := New() + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(testCase.Config))) + + if diags.HasErrors() { + for _, diag := range diags { + t.Errorf("unexpected error: %s", diag.Description().Summary) + } + } + }) + } +} + func TestBackendConfig_AssumeRole(t *testing.T) { testACC(t) From d179b686d926c85f442e6305731cf57f740038b0 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Thu, 6 Jul 2023 16:02:07 -0700 Subject: [PATCH 18/20] Consolidates endpoint tests --- .../backend/remote-state/s3/backend_test.go | 304 +++++++++--------- internal/backend/remote-state/s3/testing.go | 16 + 2 files changed, 165 insertions(+), 155 deletions(-) diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index 7674590937b4..3b244455a0e3 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -195,208 +195,202 @@ func TestBackendConfig_RegionEnvVar(t *testing.T) { } } -func TestBackendConfig_DynamoDBEndpointEnvVar(t *testing.T) { +func TestBackendConfig_DynamoDBEndpoint(t *testing.T) { testACC(t) - config := map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "key": "state", - } - vars := map[string]string{ - "AWS_DYNAMODB_ENDPOINT": "dynamo.test", - } - for k, v := range vars { - os.Setenv(k, v) - } - t.Cleanup(func() { - for k := range vars { - os.Unsetenv(k) - } - }) - - b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) - - checkClientEndpoint(t, b.dynClient.Config, "dynamo.test") -} - -func TestBackendConfig_DynamoDBEndpointConfig(t *testing.T) { - testACC(t) - config := map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "key": "state", - "dynamodb_endpoint": "dynamo.test", + cases := map[string]struct { + config map[string]any + vars map[string]string + expected string + }{ + "none": { + expected: "", + }, + "config": { + config: map[string]any{ + "dynamodb_endpoint": "dynamo.test", + }, + expected: "dynamo.test", + }, + "envvar": { + vars: map[string]string{ + "AWS_DYNAMODB_ENDPOINT": "dynamo.test", + }, + expected: "dynamo.test", + }, } - b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) - - checkClientEndpoint(t, b.dynClient.Config, "dynamo.test") -} - -func TestBackendConfig_S3EndpointEnvVar(t *testing.T) { - testACC(t) - config := map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "key": "state", - } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + config := map[string]interface{}{ + "region": "us-west-1", + "bucket": "tf-test", + "key": "state", + } - vars := map[string]string{ - "AWS_S3_ENDPOINT": "s3.test", - } - for k, v := range vars { - os.Setenv(k, v) - } - t.Cleanup(func() { - for k := range vars { - os.Unsetenv(k) - } - }) + if tc.vars != nil { + for k, v := range tc.vars { + os.Setenv(k, v) + } + t.Cleanup(func() { + for k := range tc.vars { + os.Unsetenv(k) + } + }) + } - b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) + if tc.config != nil { + for k, v := range tc.config { + config[k] = v + } + } - checkClientEndpoint(t, b.s3Client.Config, "s3.test") -} + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) -func TestBackendConfig_S3EndpointConfig(t *testing.T) { - testACC(t) - config := map[string]interface{}{ - "region": "us-west-1", - "bucket": "tf-test", - "key": "state", - "endpoint": "s3.test", + checkClientEndpoint(t, b.dynClient.Config, tc.expected) + }) } - - b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) - - checkClientEndpoint(t, b.s3Client.Config, "s3.test") } -func TestBackendConfig_STSEndpointEnvVar(t *testing.T) { +func TestBackendConfig_S3Endpoint(t *testing.T) { testACC(t) - testCases := []struct { - Config map[string]interface{} - Description string - MockStsEndpoints []*awsbase.MockEndpoint + cases := map[string]struct { + config map[string]any + vars map[string]string + expected string }{ - { - Config: map[string]interface{}{ - "bucket": "tf-test", - "key": "state", - "region": "us-west-1", - "role_arn": awsbase.MockStsAssumeRoleArn, - "session_name": awsbase.MockStsAssumeRoleSessionName, + "none": { + expected: "", + }, + "config": { + config: map[string]any{ + "endpoint": "s3.test", }, - Description: "role_arn", - MockStsEndpoints: []*awsbase.MockEndpoint{ - { - Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: url.Values{ - "Action": []string{"AssumeRole"}, - "DurationSeconds": []string{"900"}, - "RoleArn": []string{awsbase.MockStsAssumeRoleArn}, - "RoleSessionName": []string{awsbase.MockStsAssumeRoleSessionName}, - "Version": []string{"2011-06-15"}, - }.Encode()}, - Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsAssumeRoleValidResponseBody, ContentType: "text/xml"}, - }, - { - Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: mockStsGetCallerIdentityRequestBody}, - Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsGetCallerIdentityValidResponseBody, ContentType: "text/xml"}, - }, + expected: "s3.test", + }, + "envvar": { + vars: map[string]string{ + "AWS_S3_ENDPOINT": "s3.test", }, + expected: "s3.test", }, } - for _, testCase := range testCases { - testCase := testCase - - t.Run(testCase.Description, func(t *testing.T) { - closeSts, mockStsSession, err := awsbase.GetMockedAwsApiSession("STS", testCase.MockStsEndpoints) - defer closeSts() - - if err != nil { - t.Fatalf("unexpected error creating mock STS server: %s", err) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + config := map[string]interface{}{ + "region": "us-west-1", + "bucket": "tf-test", + "key": "state", } - if mockStsSession != nil && mockStsSession.Config != nil { - os.Setenv("AWS_STS_ENDPOINT", aws.StringValue(mockStsSession.Config.Endpoint)) + if tc.vars != nil { + for k, v := range tc.vars { + os.Setenv(k, v) + } t.Cleanup(func() { - os.Unsetenv("AWS_STS_ENDPOINT") + for k := range tc.vars { + os.Unsetenv(k) + } }) } - b := New() - diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(testCase.Config))) - - if diags.HasErrors() { - for _, diag := range diags { - t.Errorf("unexpected error: %s", diag.Description().Summary) + if tc.config != nil { + for k, v := range tc.config { + config[k] = v } } + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend) + + checkClientEndpoint(t, b.s3Client.Config, tc.expected) }) } } -func TestBackendConfig_STSEndpointConfig(t *testing.T) { +func TestBackendConfig_STSEndpoint(t *testing.T) { testACC(t) - testCases := []struct { - Config map[string]interface{} - Description string - MockStsEndpoints []*awsbase.MockEndpoint - }{ + stsMocks := []*awsbase.MockEndpoint{ { - Config: map[string]interface{}{ - "bucket": "tf-test", - "key": "state", - "region": "us-west-1", - "role_arn": awsbase.MockStsAssumeRoleArn, - "session_name": awsbase.MockStsAssumeRoleSessionName, - }, - Description: "role_arn", - MockStsEndpoints: []*awsbase.MockEndpoint{ - { - Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: url.Values{ - "Action": []string{"AssumeRole"}, - "DurationSeconds": []string{"900"}, - "RoleArn": []string{awsbase.MockStsAssumeRoleArn}, - "RoleSessionName": []string{awsbase.MockStsAssumeRoleSessionName}, - "Version": []string{"2011-06-15"}, - }.Encode()}, - Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsAssumeRoleValidResponseBody, ContentType: "text/xml"}, - }, - { - Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: mockStsGetCallerIdentityRequestBody}, - Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsGetCallerIdentityValidResponseBody, ContentType: "text/xml"}, - }, - }, + Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: url.Values{ + "Action": []string{"AssumeRole"}, + "DurationSeconds": []string{"900"}, + "RoleArn": []string{awsbase.MockStsAssumeRoleArn}, + "RoleSessionName": []string{awsbase.MockStsAssumeRoleSessionName}, + "Version": []string{"2011-06-15"}, + }.Encode()}, + Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsAssumeRoleValidResponseBody, ContentType: "text/xml"}, + }, + { + Request: &awsbase.MockRequest{Method: "POST", Uri: "/", Body: mockStsGetCallerIdentityRequestBody}, + Response: &awsbase.MockResponse{StatusCode: 200, Body: awsbase.MockStsGetCallerIdentityValidResponseBody, ContentType: "text/xml"}, }, } - for _, testCase := range testCases { - testCase := testCase + cases := map[string]struct { + setConfig bool + setEnvVars bool + expectedDiags tfdiags.Diagnostics + }{ + "none": { + expectedDiags: tfdiags.Diagnostics{ + tfdiags.Sourceless( + tfdiags.Error, + "Failed to configure AWS client", + ``, + ), + }, + }, + "config": { + setConfig: true, + }, + "envvar": { + setEnvVars: true, + }, + } - t.Run(testCase.Description, func(t *testing.T) { - closeSts, mockStsSession, err := awsbase.GetMockedAwsApiSession("STS", testCase.MockStsEndpoints) - defer closeSts() + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + config := map[string]interface{}{ + "region": "us-west-1", + "bucket": "tf-test", + "key": "state", + "role_arn": awsbase.MockStsAssumeRoleArn, + "session_name": awsbase.MockStsAssumeRoleSessionName, + } + closeSts, mockStsSession, err := awsbase.GetMockedAwsApiSession("STS", stsMocks) if err != nil { t.Fatalf("unexpected error creating mock STS server: %s", err) } + defer closeSts() - if mockStsSession != nil && mockStsSession.Config != nil { - testCase.Config["sts_endpoint"] = aws.StringValue(mockStsSession.Config.Endpoint) + if tc.setEnvVars { + os.Setenv("AWS_STS_ENDPOINT", aws.StringValue(mockStsSession.Config.Endpoint)) + t.Cleanup(func() { + os.Unsetenv("AWS_STS_ENDPOINT") + }) + } + + if tc.setConfig { + config["sts_endpoint"] = aws.StringValue(mockStsSession.Config.Endpoint) } b := New() - diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(testCase.Config))) + configSchema := populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(config)) - if diags.HasErrors() { - for _, diag := range diags { - t.Errorf("unexpected error: %s", diag.Description().Summary) - } + configSchema, diags := b.PrepareConfig(configSchema) + if len(diags) > 0 { + t.Fatal(diags.ErrWithWarnings()) + } + + confDiags := b.Configure(configSchema) + diags = diags.Append(confDiags) + + if diff := cmp.Diff(diags, tc.expectedDiags, cmp.Comparer(diagnosticSummaryComparer)); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) } }) } diff --git a/internal/backend/remote-state/s3/testing.go b/internal/backend/remote-state/s3/testing.go index 25c58625a2c0..cd519a26c29f 100644 --- a/internal/backend/remote-state/s3/testing.go +++ b/internal/backend/remote-state/s3/testing.go @@ -24,3 +24,19 @@ func diagnosticComparer(l, r tfdiags.Diagnostic) bool { return true } + +// diagnosticSummaryComparer is a Comparer function for use with cmp.Diff to compare +// the Severity and Summary fields two tfdiags.Diagnostic values +func diagnosticSummaryComparer(l, r tfdiags.Diagnostic) bool { + if l.Severity() != r.Severity() { + return false + } + + ld := l.Description() + rd := r.Description() + if ld.Summary != rd.Summary { + return false + } + + return true +} From 344e9de6b925a97f7c9d03b94c979a815a72135f Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Mon, 17 Jul 2023 11:34:33 -0700 Subject: [PATCH 19/20] Linting fixes --- .../backend/remote-state/s3/backend_state.go | 4 ---- .../backend/remote-state/s3/backend_test.go | 22 ++++++++++++++----- internal/backend/remote-state/s3/testing.go | 12 ++-------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/internal/backend/remote-state/s3/backend_state.go b/internal/backend/remote-state/s3/backend_state.go index add7d399d88c..89c2b08735ca 100644 --- a/internal/backend/remote-state/s3/backend_state.go +++ b/internal/backend/remote-state/s3/backend_state.go @@ -203,10 +203,6 @@ func (b *Backend) StateMgr(name string) (statemgr.Full, error) { return stateMgr, nil } -func (b *Backend) client() *RemoteClient { - return &RemoteClient{} -} - func (b *Backend) path(name string) string { if name == backend.DefaultStateName { return b.keyName diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index 3b244455a0e3..4e39e258cfc2 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -1038,7 +1038,9 @@ func TestBackendExtraPaths(t *testing.T) { // Write the first state stateMgr := &remote.State{Client: client} - stateMgr.WriteState(s1) + if err := stateMgr.WriteState(s1); err != nil { + t.Fatal(err) + } if err := stateMgr.PersistState(nil); err != nil { t.Fatal(err) } @@ -1048,7 +1050,9 @@ func TestBackendExtraPaths(t *testing.T) { // states are equal, the state will not Put to the remote client.path = b.path("s2") stateMgr2 := &remote.State{Client: client} - stateMgr2.WriteState(s2) + if err := stateMgr2.WriteState(s2); err != nil { + t.Fatal(err) + } if err := stateMgr2.PersistState(nil); err != nil { t.Fatal(err) } @@ -1061,7 +1065,9 @@ func TestBackendExtraPaths(t *testing.T) { // put a state in an env directory name client.path = b.workspaceKeyPrefix + "/error" - stateMgr.WriteState(states.NewState()) + if err := stateMgr.WriteState(states.NewState()); err != nil { + t.Fatal(err) + } if err := stateMgr.PersistState(nil); err != nil { t.Fatal(err) } @@ -1071,7 +1077,9 @@ func TestBackendExtraPaths(t *testing.T) { // add state with the wrong key for an existing env client.path = b.workspaceKeyPrefix + "/s2/notTestState" - stateMgr.WriteState(states.NewState()) + if err := stateMgr.WriteState(states.NewState()); err != nil { + t.Fatal(err) + } if err := stateMgr.PersistState(nil); err != nil { t.Fatal(err) } @@ -1105,12 +1113,14 @@ func TestBackendExtraPaths(t *testing.T) { if s2Mgr.(*remote.State).StateSnapshotMeta().Lineage == s2Lineage { t.Fatal("state s2 was not deleted") } - s2 = s2Mgr.State() + _ = s2Mgr.State() // We need the side-effect s2Lineage = stateMgr.StateSnapshotMeta().Lineage // add a state with a key that matches an existing environment dir name client.path = b.workspaceKeyPrefix + "/s2/" - stateMgr.WriteState(states.NewState()) + if err := stateMgr.WriteState(states.NewState()); err != nil { + t.Fatal(err) + } if err := stateMgr.PersistState(nil); err != nil { t.Fatal(err) } diff --git a/internal/backend/remote-state/s3/testing.go b/internal/backend/remote-state/s3/testing.go index cd519a26c29f..7cade2bc5729 100644 --- a/internal/backend/remote-state/s3/testing.go +++ b/internal/backend/remote-state/s3/testing.go @@ -18,11 +18,7 @@ func diagnosticComparer(l, r tfdiags.Diagnostic) bool { if len(lp) != len(rp) { return false } - if !lp.Equals(rp) { - return false - } - - return true + return lp.Equals(rp) } // diagnosticSummaryComparer is a Comparer function for use with cmp.Diff to compare @@ -34,9 +30,5 @@ func diagnosticSummaryComparer(l, r tfdiags.Diagnostic) bool { ld := l.Description() rd := r.Description() - if ld.Summary != rd.Summary { - return false - } - - return true + return ld.Summary == rd.Summary } From 8564a5bf0ee5a3b23ded7609423384fc7c0f2be4 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Mon, 24 Jul 2023 17:36:51 -0700 Subject: [PATCH 20/20] Fixes type of parameter `skip_credentials_validation` --- internal/backend/remote-state/s3/backend.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index f1d0345562a6..ccb340c6f5d6 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -126,7 +126,7 @@ func (b *Backend) ConfigSchema() *configschema.Block { Description: "MFA token", }, "skip_credentials_validation": { - Type: cty.String, + Type: cty.Bool, Optional: true, Description: "Skip the credentials validation via STS API.", },