Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate AWS Rate limiting errors to 502 errors #5270

Merged
merged 4 commits into from
Sep 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions builtin/credential/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/hashicorp/vault/helper/awsutil"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
Expand Down Expand Up @@ -233,14 +234,14 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag
}
iamClient, err := b.clientIAM(ctx, s, region.ID(), entity.AccountNumber)
if err != nil {
return "", err
return "", awsutil.AppendLogicalError(err)
}

switch entity.Type {
case "user":
userInfo, err := iamClient.GetUser(&iam.GetUserInput{UserName: &entity.FriendlyName})
if err != nil {
return "", err
return "", awsutil.AppendLogicalError(err)
}
if userInfo == nil {
return "", fmt.Errorf("got nil result from GetUser")
Expand All @@ -249,7 +250,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag
case "role":
roleInfo, err := iamClient.GetRole(&iam.GetRoleInput{RoleName: &entity.FriendlyName})
if err != nil {
return "", err
return "", awsutil.AppendLogicalError(err)
}
if roleInfo == nil {
return "", fmt.Errorf("got nil result from GetRole")
Expand All @@ -258,7 +259,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag
case "instance-profile":
profileInfo, err := iamClient.GetInstanceProfile(&iam.GetInstanceProfileInput{InstanceProfileName: &entity.FriendlyName})
if err != nil {
return "", err
return "", awsutil.AppendLogicalError(err)
}
if profileInfo == nil {
return "", fmt.Errorf("got nil result from GetInstanceProfile")
Expand Down
6 changes: 4 additions & 2 deletions builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/awsutil"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
Expand Down Expand Up @@ -132,7 +133,7 @@ func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName str
InstanceProfileName: aws.String(instanceProfileName),
})
if err != nil {
return "", err
return "", awsutil.AppendLogicalError(err)
}
if profile == nil {
return "", fmt.Errorf("nil output while getting instance profile details")
Expand Down Expand Up @@ -168,7 +169,8 @@ func (b *backend) validateInstance(ctx context.Context, s logical.Storage, insta
},
})
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error fetching description for instance ID %q: {{err}}", instanceID), err)
errW := errwrap.Wrapf(fmt.Sprintf("error fetching description for instance ID %q: {{err}}", instanceID), err)
return nil, errwrap.Wrap(errW, awsutil.CheckAWSError(err))
}
if status == nil {
return nil, fmt.Errorf("nil output from describe instances")
Expand Down
68 changes: 67 additions & 1 deletion builtin/logical/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"sync"
"time"

"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
Expand Down Expand Up @@ -33,7 +35,7 @@ func Backend() *backend {
},

Paths: []*framework.Path{
pathConfigRoot(),
pathConfigRoot(&b),
pathConfigLease(&b),
pathRoles(&b),
pathListRoles(&b),
Expand All @@ -57,6 +59,14 @@ type backend struct {

// Mutex to protect access to reading and writing policies
roleMutex sync.RWMutex

// Mutex to protect access to iam/sts clients
clientMutex sync.RWMutex

// iamClient and stsClient hold configured iam and sts clients for reuse, and
// to enable mocking with AWS iface for tests
iamClient iamiface.IAMAPI
stsClient stsiface.STSAPI
}

const backendHelp = `
Expand All @@ -68,3 +78,59 @@ After mounting this backend, credentials to generate IAM keys must
be configured with the "root" path and policies must be written using
the "roles/" endpoints before any access keys can be generated.
`

// clientIAM returns the configured IAM client. If nil, it constructs a new one
// and returns it, setting it the internal variable
func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IAMAPI, error) {
b.clientMutex.RLock()
if b.iamClient != nil {
b.clientMutex.RUnlock()
return b.iamClient, nil
}

// Upgrade the lock for writing
b.clientMutex.RUnlock()
b.clientMutex.Lock()
defer b.clientMutex.Unlock()

// check client again, in the event that a client was being created while we
// waited for Lock()
if b.iamClient != nil {
return b.iamClient, nil
}

iamClient, err := clientIAM(ctx, s)
if err != nil {
return nil, err
}
b.iamClient = iamClient

return b.iamClient, nil
}

func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.STSAPI, error) {
b.clientMutex.RLock()
if b.stsClient != nil {
b.clientMutex.RUnlock()
return b.stsClient, nil
}

// Upgrade the lock for writing
b.clientMutex.RUnlock()
b.clientMutex.Lock()
defer b.clientMutex.Unlock()

// check client again, in the event that a client was being created while we
// waited for Lock()
if b.stsClient != nil {
return b.stsClient, nil
}

stsClient, err := clientSTS(ctx, s)
if err != nil {
return nil, err
}
b.stsClient = stsClient

return b.stsClient, nil
}
81 changes: 73 additions & 8 deletions builtin/logical/aws/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log"
"net/http"
"os"
"reflect"
"testing"
Expand All @@ -16,13 +17,22 @@ import (
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/mitchellh/mapstructure"
)

type mockIAMClient struct {
iamiface.IAMAPI
}

func (m *mockIAMClient) CreateUser(input *iam.CreateUserInput) (*iam.CreateUserOutput, error) {
return nil, awserr.New("Throttling", "", nil)
}

func getBackend(t *testing.T) logical.Backend {
be, _ := Factory(context.Background(), logical.TestBackendConfig())
return be
Expand Down Expand Up @@ -89,24 +99,79 @@ func TestBackend_policyCrud(t *testing.T) {
})
}

func TestBackend_throttled(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}

b := Backend()
if err := b.Setup(context.Background(), config); err != nil {
t.Fatal(err)
}

connData := map[string]interface{}{
"credential_type": "iam_user",
}

confReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/something",
Storage: config.StorageView,
Data: connData,
}

resp, err := b.HandleRequest(context.Background(), confReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err)
}

// Mock the IAM API call to return a throttled response to the CreateUser API
// call
b.iamClient = &mockIAMClient{}

credReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "creds/something",
Storage: config.StorageView,
}

credResp, err := b.HandleRequest(context.Background(), credReq)
if err == nil {
t.Fatalf("failed to trigger expected throttling error condition: resp:%#v", credResp)
}
rErr := credResp.Error()
expected := "Error creating IAM user: Throttling: "
if rErr.Error() != expected {
t.Fatalf("error message did not match, expected (%s), got (%s)", expected, rErr.Error())
}

// verify the error we got back is returned with a http.StatusBadGateway
code, err := logical.RespondErrorCommon(credReq, credResp, err)
if err == nil {
t.Fatal("expected error after running req/resp/err through RespondErrorCommon, got nil")
}
if code != http.StatusBadGateway {
t.Fatalf("expected HTTP status 'bad gateway', got: (%d)", code)
}
}

func testAccPreCheck(t *testing.T) {
if v := os.Getenv("AWS_DEFAULT_REGION"); v == "" {
log.Println("[INFO] Test: Using us-west-2 as test region")
os.Setenv("AWS_DEFAULT_REGION", "us-west-2")
}

if v := os.Getenv("AWS_ACCOUNT_ID"); v == "" {
accountId, err := getAccountId()
accountID, err := getAccountID()
if err != nil {
t.Logf("Unable to retrive user via iam:GetUser: %#v", err)
t.Skip("AWS_ACCOUNT_ID not explicitly set and could not be read from iam:GetUser for acceptance tests, skipping")
}
log.Printf("[INFO] Test: Used %s as AWS_ACCOUNT_ID", accountId)
os.Setenv("AWS_ACCOUNT_ID", accountId)
log.Printf("[INFO] Test: Used %s as AWS_ACCOUNT_ID", accountID)
os.Setenv("AWS_ACCOUNT_ID", accountID)
}
}

func getAccountId() (string, error) {
func getAccountID() (string, error) {
awsConfig := &aws.Config{
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
Expand Down Expand Up @@ -251,7 +316,7 @@ func createUser(t *testing.T, accessKey *awsAccessKey) {
}
genAccessKey := createAccessKeyOutput.AccessKey

accessKey.AccessKeyId = *genAccessKey.AccessKeyId
accessKey.AccessKeyID = *genAccessKey.AccessKeyId
accessKey.SecretAccessKey = *genAccessKey.SecretAccessKey
}

Expand Down Expand Up @@ -308,7 +373,7 @@ func teardown(accessKey *awsAccessKey) error {
}

deleteAccessKeyInput := &iam.DeleteAccessKeyInput{
AccessKeyId: aws.String(accessKey.AccessKeyId),
AccessKeyId: aws.String(accessKey.AccessKeyID),
UserName: aws.String(testUserName),
}
_, err = svc.DeleteAccessKey(deleteAccessKeyInput)
Expand Down Expand Up @@ -361,7 +426,7 @@ func testAccStepConfigWithCreds(t *testing.T, accessKey *awsAccessKey) logicalte
// In particular, they get evaluated before accessKey gets set by CreateUser
// and thus would fail. By moving to a closure in a PreFlight, we ensure that
// the creds get evaluated lazily after they've been properly set
req.Data["access_key"] = accessKey.AccessKeyId
req.Data["access_key"] = accessKey.AccessKeyID
req.Data["secret_key"] = accessKey.SecretAccessKey
return nil
},
Expand Down Expand Up @@ -731,7 +796,7 @@ func testAccStepWriteArnRoleRef(t *testing.T, roleName string) logicaltest.TestS
}

type awsAccessKey struct {
AccessKeyId string
AccessKeyID string
SecretAccessKey string
}

Expand Down
13 changes: 10 additions & 3 deletions builtin/logical/aws/path_config_root.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/hashicorp/vault/logical/framework"
)

func pathConfigRoot() *framework.Path {
func pathConfigRoot(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/root",
Fields: map[string]*framework.FieldSchema{
Expand Down Expand Up @@ -42,15 +42,15 @@ func pathConfigRoot() *framework.Path {
},

Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: pathConfigRootWrite,
logical.UpdateOperation: b.pathConfigRootWrite,
},

HelpSynopsis: pathConfigRootHelpSyn,
HelpDescription: pathConfigRootHelpDesc,
}
}

func pathConfigRootWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
region := data.Get("region").(string)
iamendpoint := data.Get("iam_endpoint").(string)
stsendpoint := data.Get("sts_endpoint").(string)
Expand All @@ -72,6 +72,13 @@ func pathConfigRootWrite(ctx context.Context, req *logical.Request, data *framew
return nil, err
}

// clear possible cached IAM / STS clients after successfully updating
// config/root
b.clientMutex.Lock()
defer b.clientMutex.Unlock()
b.iamClient = nil
b.stsClient = nil

return nil, nil
}

Expand Down
Loading