From 3c74986d2593373a26b7c919871bdb4d45d0d683 Mon Sep 17 00:00:00 2001 From: Clint Shryock Date: Tue, 11 Sep 2018 16:38:51 -0500 Subject: [PATCH] refactor locking and unlocking into methods on *backend --- builtin/logical/aws/backend.go | 42 ++++++++++++++++ builtin/logical/aws/secret_access_keys.go | 58 +++++++---------------- 2 files changed, 58 insertions(+), 42 deletions(-) diff --git a/builtin/logical/aws/backend.go b/builtin/logical/aws/backend.go index 72a24b8c7a6a..baea7599de13 100644 --- a/builtin/logical/aws/backend.go +++ b/builtin/logical/aws/backend.go @@ -63,6 +63,8 @@ type backend struct { // Mutex to protect access to iam/sts clients clientMutex sync.RWMutex + // iamClient and stsClient hold configured iam and sts clients for reuse, and + // to enable mocking with AWS iface for tests iamClient iamiface.IAMAPI stsClient stsiface.STSAPI } @@ -76,3 +78,43 @@ After mounting this backend, credentials to generate IAM keys must be configured with the "root" path and policies must be written using the "roles/" endpoints before any access keys can be generated. ` + +// clientIAM returns the configured IAM client. If nil, it constructs a new one +// and returns it, setting it the internal variable +func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IAMAPI, error) { + b.clientMutex.RLock() + unlockFunc := b.clientMutex.RUnlock + defer func() { unlockFunc() }() + if b.iamClient == nil { + // Upgrade the lock for writing + b.clientMutex.RUnlock() + b.clientMutex.Lock() + unlockFunc = b.clientMutex.Unlock + + iamClient, err := clientIAM(ctx, s) + if err != nil { + return nil, err + } + b.iamClient = iamClient + } + return b.iamClient, nil +} + +func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.STSAPI, error) { + b.clientMutex.RLock() + unlockFunc := b.clientMutex.RUnlock + defer func() { unlockFunc() }() + if b.stsClient == nil { + // Upgrade the lock for writing + b.clientMutex.RUnlock() + b.clientMutex.Lock() + unlockFunc = b.clientMutex.Unlock + + stsClient, err := clientSTS(ctx, s) + if err != nil { + return nil, err + } + b.stsClient = stsClient + } + return b.stsClient, nil +} diff --git a/builtin/logical/aws/secret_access_keys.go b/builtin/logical/aws/secret_access_keys.go index 71e35fe867fc..d7dfed34ef2f 100644 --- a/builtin/logical/aws/secret_access_keys.go +++ b/builtin/logical/aws/secret_access_keys.go @@ -69,25 +69,14 @@ func (b *backend) secretTokenCreate(ctx context.Context, s logical.Storage, displayName, policyName, policy string, lifeTimeInSeconds int64) (*logical.Response, error) { - b.clientMutex.RLock() - unlockFunc := b.clientMutex.RUnlock - defer func() { unlockFunc() }() - if b.stsClient == nil { - // Upgrade the lock for writing - b.clientMutex.RUnlock() - b.clientMutex.Lock() - unlockFunc = b.clientMutex.Unlock - - stsClient, err := clientSTS(ctx, s) - if err != nil { - return logical.ErrorResponse(err.Error()), nil - } - b.stsClient = stsClient + stsClient, err := b.clientSTS(ctx, s) + if err != nil { + return logical.ErrorResponse(err.Error()), nil } username, usernameWarning := genUsername(displayName, policyName, "sts") - tokenResp, err := b.stsClient.GetFederationToken( + tokenResp, err := stsClient.GetFederationToken( &sts.GetFederationTokenInput{ Name: aws.String(username), Policy: aws.String(policy), @@ -125,14 +114,10 @@ func (b *backend) secretTokenCreate(ctx context.Context, s logical.Storage, func (b *backend) assumeRole(ctx context.Context, s logical.Storage, displayName, roleName, roleArn, policy string, lifeTimeInSeconds int64) (*logical.Response, error) { - b.clientMutex.Lock() - defer b.clientMutex.Unlock() - if b.stsClient == nil { - stsClient, err := clientSTS(ctx, s) - if err != nil { - return logical.ErrorResponse(err.Error()), nil - } - b.stsClient = stsClient + + stsClient, err := b.clientSTS(ctx, s) + if err != nil { + return logical.ErrorResponse(err.Error()), nil } username, usernameWarning := genUsername(displayName, roleName, "iam_user") @@ -145,7 +130,7 @@ func (b *backend) assumeRole(ctx context.Context, s logical.Storage, if policy != "" { assumeRoleInput.SetPolicy(policy) } - tokenResp, err := b.stsClient.AssumeRole(assumeRoleInput) + tokenResp, err := stsClient.AssumeRole(assumeRoleInput) if err != nil { return logical.ErrorResponse(fmt.Sprintf( @@ -180,20 +165,9 @@ func (b *backend) secretAccessKeysCreate( s logical.Storage, displayName, policyName string, role *awsRoleEntry) (*logical.Response, error) { - b.clientMutex.RLock() - unlockFunc := b.clientMutex.RUnlock - defer func() { unlockFunc() }() - if b.iamClient == nil { - // Upgrade the lock for writing - b.clientMutex.RUnlock() - b.clientMutex.Lock() - unlockFunc = b.clientMutex.Unlock - - iamClient, err := clientIAM(ctx, s) - if err != nil { - return logical.ErrorResponse(err.Error()), nil - } - b.iamClient = iamClient + iamClient, err := b.clientIAM(ctx, s) + if err != nil { + return logical.ErrorResponse(err.Error()), nil } username, usernameWarning := genUsername(displayName, policyName, "iam_user") @@ -210,7 +184,7 @@ func (b *backend) secretAccessKeysCreate( } // Create the user - _, err = b.iamClient.CreateUser(&iam.CreateUserInput{ + _, err = iamClient.CreateUser(&iam.CreateUserInput{ UserName: aws.String(username), }) if err != nil { @@ -220,7 +194,7 @@ func (b *backend) secretAccessKeysCreate( for _, arn := range role.PolicyArns { // Attach existing policy against user - _, err = b.iamClient.AttachUserPolicy(&iam.AttachUserPolicyInput{ + _, err = iamClient.AttachUserPolicy(&iam.AttachUserPolicyInput{ UserName: aws.String(username), PolicyArn: aws.String(arn), }) @@ -232,7 +206,7 @@ func (b *backend) secretAccessKeysCreate( } if role.PolicyDocument != "" { // Add new inline user policy against user - _, err = b.iamClient.PutUserPolicy(&iam.PutUserPolicyInput{ + _, err = iamClient.PutUserPolicy(&iam.PutUserPolicyInput{ UserName: aws.String(username), PolicyName: aws.String(policyName), PolicyDocument: aws.String(role.PolicyDocument), @@ -244,7 +218,7 @@ func (b *backend) secretAccessKeysCreate( } // Create the keys - keyResp, err := b.iamClient.CreateAccessKey(&iam.CreateAccessKeyInput{ + keyResp, err := iamClient.CreateAccessKey(&iam.CreateAccessKeyInput{ UserName: aws.String(username), }) if err != nil {