Skip to content

Commit

Permalink
backport of commit 10bd15f (#28827)
Browse files Browse the repository at this point in the history
Co-authored-by: miagilepner <[email protected]>
  • Loading branch information
1 parent 50ef8a5 commit 74339e7
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 30 deletions.
68 changes: 68 additions & 0 deletions builtin/logical/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package aws

import (
"context"
"fmt"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -33,6 +34,7 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend,

func Backend(_ *logical.BackendConfig) *backend {
var b backend
b.minAllowableRotationPeriod = minAllowableRotationPeriod
b.credRotationQueue = queue.New()
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
Expand Down Expand Up @@ -62,6 +64,7 @@ func Backend(_ *logical.BackendConfig) *backend {
secretAccessKeys(&b),
},

InitializeFunc: b.initialize,
Invalidate: b.invalidate,
WALRollback: b.walRollback,
WALRollbackMinAge: minAwsUserRollbackAge,
Expand Down Expand Up @@ -94,6 +97,8 @@ type backend struct {
// the age of a static role's credential is tracked by a priority queue and handled
// by the PeriodicFunc
credRotationQueue *queue.PriorityQueue

minAllowableRotationPeriod time.Duration
}

const backendHelp = `
Expand Down Expand Up @@ -176,3 +181,66 @@ func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.ST

return b.stsClient, nil
}

func (b *backend) initialize(ctx context.Context, request *logical.InitializationRequest) error {
if !b.WriteSafeReplicationState() {
b.Logger().Info("skipping populating rotation queue")
return nil
}
b.Logger().Info("populating rotation queue")

creds, err := request.Storage.List(ctx, pathStaticCreds+"/")
if err != nil {
return err
}
b.Logger().Debug(fmt.Sprintf("Adding %d items to the rotation queue", len(creds)))
for _, roleName := range creds {
if roleName == "" {
continue
}
credPath := formatCredsStoragePath(roleName)
credsEntry, err := request.Storage.Get(ctx, credPath)
if err != nil {
return fmt.Errorf("could not read credentials: %w", err)
}
if credsEntry == nil {
continue
}
credentials := awsCredentials{}
if err := credsEntry.DecodeJSON(&credentials); err != nil {
return fmt.Errorf("failed to decode credentials: %w", err)
}

configEntry, err := request.Storage.Get(ctx, formatRoleStoragePath(roleName))
if err != nil {
return fmt.Errorf("could not read role: %w", err)
}
if configEntry == nil {
continue
}
config := staticRoleEntry{}
if err := configEntry.DecodeJSON(&config); err != nil {
return fmt.Errorf("failed to decode role config: %w", err)
}

if credentials.Expiration == nil {
expiration := time.Now().UTC().Add(config.RotationPeriod)
credentials.Expiration = &expiration
_, err := logical.StorageEntryJSON(credPath, creds)
if err != nil {
return fmt.Errorf("failed to marshal object to JSON: %w", err)
}
b.Logger().Debug("no known expiration time for credentials so resetting the expiration", "role", roleName, "new expiration", expiration)
}

err = b.credRotationQueue.Push(&queue.Item{
Key: config.Name,
Value: config,
Priority: credentials.priority(config),
})
if err != nil {
return fmt.Errorf("failed to add creds for role %s to queue: %w", roleName, err)
}
}
return nil
}
13 changes: 11 additions & 2 deletions builtin/logical/aws/path_static_creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"fmt"
"net/http"
"time"

"github.com/fatih/structs"
"github.com/hashicorp/vault/sdk/framework"
Expand All @@ -21,8 +22,9 @@ const (
)

type awsCredentials struct {
AccessKeyID string `json:"access_key" structs:"access_key" mapstructure:"access_key"`
SecretAccessKey string `json:"secret_key" structs:"secret_key" mapstructure:"secret_key"`
AccessKeyID string `json:"access_key" structs:"access_key" mapstructure:"access_key"`
Expiration *time.Time `json:"expiration,omitempty" structs:"expiration" mapstructure:"expiration"`
SecretAccessKey string `json:"secret_key" structs:"secret_key" mapstructure:"secret_key"`
}

func pathStaticCredentials(b *backend) *framework.Path {
Expand Down Expand Up @@ -89,6 +91,13 @@ func formatCredsStoragePath(roleName string) string {
return fmt.Sprintf("%s/%s", pathStaticCreds, roleName)
}

func (a *awsCredentials) priority(role staticRoleEntry) int64 {
if a.Expiration != nil {
return a.Expiration.Unix()
}
return time.Now().Add(role.RotationPeriod).Unix()
}

const pathStaticCredsHelpSyn = `Retrieve static credentials from the named role.`

const pathStaticCredsHelpDesc = `
Expand Down
21 changes: 21 additions & 0 deletions builtin/logical/aws/path_static_creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import (
"context"
"reflect"
"testing"
"time"

"github.com/fatih/structs"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)

// TestStaticCredsRead verifies that we can correctly read a cred that exists, and correctly _not read_
Expand Down Expand Up @@ -91,3 +93,22 @@ func staticCredsFieldData(data map[string]interface{}) *framework.FieldData {
Schema: schema,
}
}

// Test_awsCredentials_priority verifies that the expiration in the credentials
// is returned as the priority value when it is present, but otherwise the
// priority is now + the rotation period
func Test_awsCredentials_priority(t *testing.T) {
expiration := time.Date(2023, 10, 24, 15, 21, 0o0, 0o0, time.UTC)
roleConfig := staticRoleEntry{RotationPeriod: time.Hour}
t.Run("use credential value", func(t *testing.T) {
creds := &awsCredentials{
Expiration: &expiration,
}
require.Equal(t, expiration.Unix(), creds.priority(roleConfig))
})
t.Run("use role value", func(t *testing.T) {
hourUnix := time.Now().Add(time.Hour).Unix()
creds := &awsCredentials{}
require.InDelta(t, hourUnix, creds.priority(roleConfig), float64(time.Minute/time.Second))
})
}
27 changes: 20 additions & 7 deletions builtin/logical/aws/path_static_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,25 +194,31 @@ func (b *backend) pathStaticRolesWrite(ctx context.Context, req *logical.Request

// Bootstrap initial set of keys if they did not exist before. AWS Secret Access Keys can only be obtained on creation,
// so we need to boostrap new roles with a new initial set of keys to be able to serve valid credentials to Vault clients.
existingCreds, err := req.Storage.Get(ctx, formatCredsStoragePath(config.Name))
credsPath := formatCredsStoragePath(config.Name)
existingCredsEntry, err := req.Storage.Get(ctx, credsPath)
if err != nil {
return nil, fmt.Errorf("unable to verify if credentials already exist for role %q: %w", config.Name, err)
}
if existingCreds == nil {
err := b.createCredential(ctx, req.Storage, config, false)
if existingCredsEntry == nil {
creds, err := b.createCredential(ctx, req.Storage, config, false)
if err != nil {
return nil, fmt.Errorf("failed to create new credentials for role %q: %w", config.Name, err)
}

err = b.credRotationQueue.Push(&queue.Item{
Key: config.Name,
Value: config,
Priority: time.Now().Add(config.RotationPeriod).Unix(),
Priority: creds.priority(config),
})
if err != nil {
return nil, fmt.Errorf("failed to add item into the rotation queue for role %q: %w", config.Name, err)
}
} else {
var existingCreds awsCredentials
err := existingCredsEntry.DecodeJSON(&existingCreds)
if err != nil {
return nil, fmt.Errorf("unable to decode existing credentials for role %s: %w", config.Name, err)
}
// creds already exist, so all we need to do is update the rotation
// what here stays the same and what changes? Can we change the name?
i, err := b.credRotationQueue.PopByKey(config.Name)
Expand All @@ -221,7 +227,14 @@ func (b *backend) pathStaticRolesWrite(ctx context.Context, req *logical.Request
}
i.Value = config
// update the next rotation to occur at now + the new rotation period
i.Priority = time.Now().Add(config.RotationPeriod).Unix()
newExpiration := time.Now().Add(config.RotationPeriod)
existingCreds.Expiration = &newExpiration
_, err = logical.StorageEntryJSON(credsPath, &existingCreds)
if err != nil {
return nil, fmt.Errorf("error updating credentials for role %s: %w", config.Name, err)
}
i.Priority = existingCreds.priority(config)

err = b.credRotationQueue.Push(i)
if err != nil {
return nil, fmt.Errorf("failed to add updated item into the rotation queue for role %q: %w", config.Name, err)
Expand Down Expand Up @@ -312,8 +325,8 @@ const (
)

func (b *backend) validateRotationPeriod(period time.Duration) error {
if period < minAllowableRotationPeriod {
return fmt.Errorf("role rotation period out of range: must be greater than %.2f seconds", minAllowableRotationPeriod.Seconds())
if period < b.minAllowableRotationPeriod {
return fmt.Errorf("role rotation period out of range: must be greater than %.2f seconds", b.minAllowableRotationPeriod.Seconds())
}
return nil
}
Expand Down
32 changes: 18 additions & 14 deletions builtin/logical/aws/rotation.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage)

cfg := item.Value.(staticRoleEntry)

err = b.createCredential(ctx, storage, cfg, true)
creds, err := b.createCredential(ctx, storage, cfg, true)
if err != nil {
b.Logger().Error("failed to create credential, re-queueing", "error", err)
// put it back in the queue with a backoff
item.Priority = time.Now().Add(10 * time.Second).Unix()
innerErr := b.credRotationQueue.Push(item)
Expand All @@ -74,7 +75,7 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage)
}

// set new priority and re-queue
item.Priority = time.Now().Add(cfg.RotationPeriod).Unix()
item.Priority = creds.priority(cfg)
err = b.credRotationQueue.Push(item)
if err != nil {
return true, fmt.Errorf("failed to add item into the rotation queue for role %q: %w", cfg.Name, err)
Expand All @@ -84,10 +85,10 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage)
}

// createCredential will create a new iam credential, deleting the oldest one if necessary.
func (b *backend) createCredential(ctx context.Context, storage logical.Storage, cfg staticRoleEntry, shouldLockStorage bool) error {
func (b *backend) createCredential(ctx context.Context, storage logical.Storage, cfg staticRoleEntry, shouldLockStorage bool) (*awsCredentials, error) {
iamClient, err := b.clientIAM(ctx, storage)
if err != nil {
return fmt.Errorf("unable to get the AWS IAM client: %w", err)
return nil, fmt.Errorf("unable to get the AWS IAM client: %w", err)
}

// IAM users can have a most 2 sets of keys at a time.
Expand All @@ -97,14 +98,14 @@ func (b *backend) createCredential(ctx context.Context, storage logical.Storage,

err = b.validateIAMUserExists(ctx, storage, &cfg, false)
if err != nil {
return fmt.Errorf("iam user didn't exist, or username/userid didn't match: %w", err)
return nil, fmt.Errorf("iam user didn't exist, or username/userid didn't match: %w", err)
}

accessKeys, err := iamClient.ListAccessKeys(&iam.ListAccessKeysInput{
UserName: aws.String(cfg.Username),
})
if err != nil {
return fmt.Errorf("unable to list existing access keys for IAM user %q: %w", cfg.Username, err)
return nil, fmt.Errorf("unable to list existing access keys for IAM user %q: %w", cfg.Username, err)
}

// If we have the maximum number of keys, we have to delete one to make another (so we can get the credentials).
Expand All @@ -127,7 +128,7 @@ func (b *backend) createCredential(ctx context.Context, storage logical.Storage,
UserName: oldestKey.UserName,
})
if err != nil {
return fmt.Errorf("unable to delete oldest access keys for user %q: %w", cfg.Username, err)
return nil, fmt.Errorf("unable to delete oldest access keys for user %q: %w", cfg.Username, err)
}
}

Expand All @@ -136,27 +137,30 @@ func (b *backend) createCredential(ctx context.Context, storage logical.Storage,
UserName: aws.String(cfg.Username),
})
if err != nil {
return fmt.Errorf("unable to create new access keys for user %q: %w", cfg.Username, err)
return nil, fmt.Errorf("unable to create new access keys for user %q: %w", cfg.Username, err)
}
expiration := time.Now().UTC().Add(cfg.RotationPeriod)

// Persist new keys
entry, err := logical.StorageEntryJSON(formatCredsStoragePath(cfg.Name), &awsCredentials{
creds := &awsCredentials{
AccessKeyID: *out.AccessKey.AccessKeyId,
SecretAccessKey: *out.AccessKey.SecretAccessKey,
})
Expiration: &expiration,
}
// Persist new keys
entry, err := logical.StorageEntryJSON(formatCredsStoragePath(cfg.Name), creds)
if err != nil {
return fmt.Errorf("failed to marshal object to JSON: %w", err)
return nil, fmt.Errorf("failed to marshal object to JSON: %w", err)
}
if shouldLockStorage {
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
}
err = storage.Put(ctx, entry)
if err != nil {
return fmt.Errorf("failed to save object in storage: %w", err)
return nil, fmt.Errorf("failed to save object in storage: %w", err)
}

return nil
return creds, nil
}

// delete credential will remove the credential associated with the role from storage.
Expand Down
Loading

0 comments on commit 74339e7

Please sign in to comment.