Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add awsbase.ErrCodeEquals, AWS SDK for Go v2 variant of helper in v2/awsv1shim/tfawserr #524

Merged
merged 4 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,22 @@
package awsbase

import (
"errors"

"github.com/hashicorp/aws-sdk-go-base/v2/internal/config"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/errs"
)

// CannotAssumeRoleError occurs when AssumeRole cannot complete.
type CannotAssumeRoleError = config.CannotAssumeRoleError

// IsCannotAssumeRoleError returns true if the error contains the CannotAssumeRoleError type.
func IsCannotAssumeRoleError(err error) bool {
var e CannotAssumeRoleError
return errors.As(err, &e)
return errs.IsA[CannotAssumeRoleError](err)
}

// NoValidCredentialSourcesError occurs when all credential lookup methods have been exhausted without results.
type NoValidCredentialSourcesError = config.NoValidCredentialSourcesError

// IsNoValidCredentialSourcesError returns true if the error contains the NoValidCredentialSourcesError type.
func IsNoValidCredentialSourcesError(err error) bool {
var e NoValidCredentialSourcesError
return errors.As(err, &e)
return errs.IsA[NoValidCredentialSourcesError](err)
}
85 changes: 85 additions & 0 deletions errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package awsbase

import (
"fmt"
"testing"
)

func TestIsCannotAssumeRoleError(t *testing.T) {
testCases := []struct {
Name string
Err error
Expected bool
}{
{
Name: "nil error",
},
{
Name: "Top-level NoValidCredentialSourcesError",
Err: NoValidCredentialSourcesError{},
},
{
Name: "Top-level CannotAssumeRoleError",
Err: CannotAssumeRoleError{},
Expected: true,
},
{
Name: "Nested CannotAssumeRoleError",
Err: fmt.Errorf("test: %w", CannotAssumeRoleError{}),
Expected: true,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.Name, func(t *testing.T) {
got := IsCannotAssumeRoleError(testCase.Err)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}

func TestIsNoValidCredentialSourcesError(t *testing.T) {
testCases := []struct {
Name string
Err error
Expected bool
}{
{
Name: "nil error",
},
{
Name: "Top-level CannotAssumeRoleError",
Err: CannotAssumeRoleError{},
},
{
Name: "Top-level NoValidCredentialSourcesError",
Err: NoValidCredentialSourcesError{},
Expected: true,
},
{
Name: "Nested NoValidCredentialSourcesError",
Err: fmt.Errorf("test: %w", NoValidCredentialSourcesError{}),
Expected: true,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.Name, func(t *testing.T) {
got := IsNoValidCredentialSourcesError(testCase.Err)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}
21 changes: 21 additions & 0 deletions internal/errs/errs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package errs

import (
"errors"
)

// IsA indicates whether an error matches an error type.
func IsA[T error](err error) bool {
_, ok := As[T](err)
return ok
}

// As is equivalent to errors.As(), but returns the value in-line.
func As[T error](err error) (T, bool) {
var as T
ok := errors.As(err, &as)
return as, ok
}
23 changes: 23 additions & 0 deletions tfawserr/awserr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package tfawserr

import (
smithy "github.com/aws/smithy-go"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/errs"
)

// ErrCodeEquals returns true if the error matches all these conditions:
// - err is of type smithy.APIError
// - Error.Code() equals one of the passed codes
func ErrCodeEquals(err error, codes ...string) bool {
if apiErr, ok := errs.As[smithy.APIError](err); ok {
for _, code := range codes {
if apiErr.ErrorCode() == code {
return true
}
}
}
return false
}
105 changes: 105 additions & 0 deletions tfawserr/awserr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package tfawserr

import (
"fmt"
"testing"

"github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/aws/aws-sdk-go/aws"
Copy link
Contributor Author

@ewbankkit ewbankkit Jun 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be github.com/aws/aws-sdk-go-v2/aws to prevent the go.mod diff.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy and paste :(

smithy "github.com/aws/smithy-go"
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
)

func TestErrCodeEquals(t *testing.T) {
testCases := []struct {
Name string
Err error
Codes []string
Expected bool
}{
{
Name: "nil error",
},
{
Name: "Top-level CannotAssumeRoleError",
Err: awsbase.CannotAssumeRoleError{},
},
{
Name: "Top-level smithy.GenericAPIError matching first code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Top-level smithy.GenericAPIError matching last code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Top-level smithy.GenericAPIError no code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
},
{
Name: "Top-level smithy.GenericAPIError non-matching codes",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"NotMatching", "AlsoNotMatching"},
},
{
Name: "Wrapped smithy.GenericAPIError matching first code",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Wrapped smithy.GenericAPIError matching last code",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Wrapped smithy.GenericAPIError non-matching codes",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"NotMatching", "AlsoNotMatching"},
},
{
Name: "Top-level sts ExpiredTokenException matching first code",
Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")},
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Top-level sts ExpiredTokenException matching last code",
Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")},
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Wrapped sts ExpiredTokenException matching first code",
Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}),
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Wrapped sts ExpiredTokenException matching last code",
Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}),
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.Name, func(t *testing.T) {
got := ErrCodeEquals(testCase.Err, testCase.Codes...)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}