Skip to content

Commit

Permalink
refactor locking and unlocking into methods on *backend
Browse files Browse the repository at this point in the history
  • Loading branch information
catsby committed Sep 11, 2018
1 parent a3e59aa commit 3c74986
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 42 deletions.
42 changes: 42 additions & 0 deletions builtin/logical/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ type backend struct {
// Mutex to protect access to iam/sts clients
clientMutex sync.RWMutex

// iamClient and stsClient hold configured iam and sts clients for reuse, and
// to enable mocking with AWS iface for tests
iamClient iamiface.IAMAPI
stsClient stsiface.STSAPI
}
Expand All @@ -76,3 +78,43 @@ After mounting this backend, credentials to generate IAM keys must
be configured with the "root" path and policies must be written using
the "roles/" endpoints before any access keys can be generated.
`

// clientIAM returns the configured IAM client. If nil, it constructs a new one
// and returns it, setting it the internal variable
func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IAMAPI, error) {
b.clientMutex.RLock()

This comment has been minimized.

Copy link
@vishalnayak

vishalnayak Sep 12, 2018

Contributor

Although I see why you might have used this approach of setting unlockFunc in the first place, now that this is in its own function, I guess we don't need to juggle this much. Probably the following is more readable.

{
	b.clientMutex.RLock()
	if b.iamClient != nil {
		b.clientMutex.RUnlock()
		return b.iamClient, nil
	}

	b.clientMutex.RUnlock()
	b.clientMutex.Lock()

	defer b.clientMutex.Unlock()

	iamClient, err := clientIAM(ctx, s)
	if err != nil {
		return nil, err
	}
	b.iamClient = iamClient

	return b.iamClient, nil
}
unlockFunc := b.clientMutex.RUnlock
defer func() { unlockFunc() }()
if b.iamClient == nil {
// Upgrade the lock for writing
b.clientMutex.RUnlock()
b.clientMutex.Lock()
unlockFunc = b.clientMutex.Unlock

iamClient, err := clientIAM(ctx, s)
if err != nil {
return nil, err
}
b.iamClient = iamClient

This comment has been minimized.

Copy link
@vishalnayak

vishalnayak Sep 12, 2018

Contributor

Since the calling function assumes that the client is always non-nil, do we want to error here if the iamClient is nil?

This comment has been minimized.

Copy link
@catsby

catsby Sep 12, 2018

Author Contributor

both clientIAM and clientSTS (

func clientIAM(ctx context.Context, s logical.Storage) (*iam.IAM, error) {
awsConfig, err := getRootConfig(ctx, s, "iam")
if err != nil {
return nil, err
}
client := iam.New(session.New(awsConfig))
if client == nil {
return nil, fmt.Errorf("could not obtain iam client")
}
return client, nil
}
func clientSTS(ctx context.Context, s logical.Storage) (*sts.STS, error) {
awsConfig, err := getRootConfig(ctx, s, "sts")
if err != nil {
return nil, err
}
client := sts.New(session.New(awsConfig))
if client == nil {
return nil, fmt.Errorf("could not obtain sts client")
}
return client, nil
}
) check for nil clients and error if nil, so we should be good here

}
return b.iamClient, nil
}

func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.STSAPI, error) {
b.clientMutex.RLock()
unlockFunc := b.clientMutex.RUnlock
defer func() { unlockFunc() }()
if b.stsClient == nil {
// Upgrade the lock for writing
b.clientMutex.RUnlock()
b.clientMutex.Lock()
unlockFunc = b.clientMutex.Unlock

stsClient, err := clientSTS(ctx, s)
if err != nil {
return nil, err
}
b.stsClient = stsClient
}
return b.stsClient, nil
}
58 changes: 16 additions & 42 deletions builtin/logical/aws/secret_access_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,14 @@ func (b *backend) secretTokenCreate(ctx context.Context, s logical.Storage,
displayName, policyName, policy string,
lifeTimeInSeconds int64) (*logical.Response, error) {

b.clientMutex.RLock()
unlockFunc := b.clientMutex.RUnlock
defer func() { unlockFunc() }()
if b.stsClient == nil {
// Upgrade the lock for writing
b.clientMutex.RUnlock()
b.clientMutex.Lock()
unlockFunc = b.clientMutex.Unlock

stsClient, err := clientSTS(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
b.stsClient = stsClient
stsClient, err := b.clientSTS(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}

username, usernameWarning := genUsername(displayName, policyName, "sts")

tokenResp, err := b.stsClient.GetFederationToken(
tokenResp, err := stsClient.GetFederationToken(
&sts.GetFederationTokenInput{
Name: aws.String(username),
Policy: aws.String(policy),
Expand Down Expand Up @@ -125,14 +114,10 @@ func (b *backend) secretTokenCreate(ctx context.Context, s logical.Storage,
func (b *backend) assumeRole(ctx context.Context, s logical.Storage,
displayName, roleName, roleArn, policy string,
lifeTimeInSeconds int64) (*logical.Response, error) {
b.clientMutex.Lock()
defer b.clientMutex.Unlock()
if b.stsClient == nil {
stsClient, err := clientSTS(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
b.stsClient = stsClient

stsClient, err := b.clientSTS(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}

username, usernameWarning := genUsername(displayName, roleName, "iam_user")
Expand All @@ -145,7 +130,7 @@ func (b *backend) assumeRole(ctx context.Context, s logical.Storage,
if policy != "" {
assumeRoleInput.SetPolicy(policy)
}
tokenResp, err := b.stsClient.AssumeRole(assumeRoleInput)
tokenResp, err := stsClient.AssumeRole(assumeRoleInput)

if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
Expand Down Expand Up @@ -180,20 +165,9 @@ func (b *backend) secretAccessKeysCreate(
s logical.Storage,
displayName, policyName string, role *awsRoleEntry) (*logical.Response, error) {

b.clientMutex.RLock()
unlockFunc := b.clientMutex.RUnlock
defer func() { unlockFunc() }()
if b.iamClient == nil {
// Upgrade the lock for writing
b.clientMutex.RUnlock()
b.clientMutex.Lock()
unlockFunc = b.clientMutex.Unlock

iamClient, err := clientIAM(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
b.iamClient = iamClient
iamClient, err := b.clientIAM(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}

username, usernameWarning := genUsername(displayName, policyName, "iam_user")
Expand All @@ -210,7 +184,7 @@ func (b *backend) secretAccessKeysCreate(
}

// Create the user
_, err = b.iamClient.CreateUser(&iam.CreateUserInput{
_, err = iamClient.CreateUser(&iam.CreateUserInput{
UserName: aws.String(username),
})
if err != nil {
Expand All @@ -220,7 +194,7 @@ func (b *backend) secretAccessKeysCreate(

for _, arn := range role.PolicyArns {
// Attach existing policy against user
_, err = b.iamClient.AttachUserPolicy(&iam.AttachUserPolicyInput{
_, err = iamClient.AttachUserPolicy(&iam.AttachUserPolicyInput{
UserName: aws.String(username),
PolicyArn: aws.String(arn),
})
Expand All @@ -232,7 +206,7 @@ func (b *backend) secretAccessKeysCreate(
}
if role.PolicyDocument != "" {
// Add new inline user policy against user
_, err = b.iamClient.PutUserPolicy(&iam.PutUserPolicyInput{
_, err = iamClient.PutUserPolicy(&iam.PutUserPolicyInput{
UserName: aws.String(username),
PolicyName: aws.String(policyName),
PolicyDocument: aws.String(role.PolicyDocument),
Expand All @@ -244,7 +218,7 @@ func (b *backend) secretAccessKeysCreate(
}

// Create the keys
keyResp, err := b.iamClient.CreateAccessKey(&iam.CreateAccessKeyInput{
keyResp, err := iamClient.CreateAccessKey(&iam.CreateAccessKeyInput{
UserName: aws.String(username),
})
if err != nil {
Expand Down

0 comments on commit 3c74986

Please sign in to comment.