From e054899d92c86cbe45ac9ae6d1dd426d81e96b42 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Mon, 23 Jul 2018 22:29:37 -0400 Subject: [PATCH] Modify approle tidy to validate dangling accessors --- builtin/credential/approle/backend.go | 3 + .../credential/approle/path_tidy_user_id.go | 83 ++++++++++++++-- .../approle/path_tidy_user_id_test.go | 94 ++++++++++++++++++- 3 files changed, 170 insertions(+), 10 deletions(-) diff --git a/builtin/credential/approle/backend.go b/builtin/credential/approle/backend.go index a16794e20321..3705dbb1eb38 100644 --- a/builtin/credential/approle/backend.go +++ b/builtin/credential/approle/backend.go @@ -3,6 +3,7 @@ package approle import ( "context" "sync" + "time" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/locksutil" @@ -56,6 +57,8 @@ type backend struct { // secretIDListingLock is a dedicated lock for listing SecretIDAccessors // for all the SecretIDs issued against an approle secretIDListingLock sync.RWMutex + + testTidyDelay time.Duration } func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { diff --git a/builtin/credential/approle/path_tidy_user_id.go b/builtin/credential/approle/path_tidy_user_id.go index 7f5cec894f0c..5bdd12640fcd 100644 --- a/builtin/credential/approle/path_tidy_user_id.go +++ b/builtin/credential/approle/path_tidy_user_id.go @@ -38,17 +38,29 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi go func() { defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0) + logger := b.Logger().Named("tidy") + + checkCount := 0 + + defer func() { + if b.testTidyDelay > 0 { + logger.Trace("done checking entries", "num_entries", checkCount) + } + }() + // Don't cancel when the original client request goes away ctx = context.Background() - logger := b.Logger().Named("tidy") - tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error { + logger.Trace("listing role HMACs", "prefix", secretIDPrefixToUse) + roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse) if err != nil { return err } + logger.Trace("listing accessors", "prefix", accessorIDPrefixToUse) + // List all the accessors and add them all to a map accessorHashes, err := s.List(ctx, accessorIDPrefixToUse) if err != nil { @@ -59,7 +71,10 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi accessorMap[accessorHash] = true } + time.Sleep(b.testTidyDelay) + secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error { + checkCount++ lock := b.secretIDLock(secretIDHMAC) lock.Lock() defer lock.Unlock() @@ -91,6 +106,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err) } if accessorEntry == nil { + logger.Trace("found nil accessor") if err := s.Delete(ctx, entryIndex); err != nil { return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err) } @@ -99,6 +115,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi // ExpirationTime not being set indicates non-expiring SecretIDs if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) { + logger.Trace("found expired secret ID") // Clean up the accessor of the secret ID first err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse) if err != nil { @@ -126,6 +143,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi } for _, roleNameHMAC := range roleNameHMACs { + logger.Trace("listing secret ID HMACs", "role_hmac", roleNameHMAC) secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC)) if err != nil { return err @@ -140,13 +158,60 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi // Accessor indexes were not getting cleaned up until 0.9.3. This is a fix // to clean up the dangling accessor entries. - for accessorHash, _ := range accessorMap { - // Ideally, locking should be performed here. But for that, accessors - // are required in plaintext, which are not available. Hence performing - // a racy cleanup. - err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash) - if err != nil { - return err + if len(accessorMap) > 0 { + for _, lock := range b.secretIDLocks { + lock.Lock() + defer lock.Unlock() + } + for accessorHash, _ := range accessorMap { + logger.Trace("found dangling accessor, verifying") + // Ideally, locking on accessors should be performed here too + // but for that, accessors are required in plaintext, which are + // not available. The code above helps but it may still be + // racy. + // ... + // Look up the secret again now that we have all the locks. The + // lock is held when writing accessor/secret so if we have the + // lock we know we're not in a + // wrote-accessor-but-not-yet-secret case, which can be racy. + var entry secretIDAccessorStorageEntry + entryIndex := accessorIDPrefixToUse + accessorHash + se, err := s.Get(ctx, entryIndex) + if err != nil { + return err + } + if se != nil { + err = se.DecodeJSON(&entry) + if err != nil { + return err + } + + // The storage entry doesn't store the role ID, so we have + // to go about this the long way; fortunately we shouldn't + // actually hit this very often + var found bool + searchloop: + for _, roleNameHMAC := range roleNameHMACs { + secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC)) + if err != nil { + return err + } + for _, v := range secretIDHMACs { + if v == entry.SecretIDHMAC { + found = true + logger.Trace("accessor verified, not removing") + break searchloop + } + } + } + if !found { + logger.Trace("could not verify dangling accessor, removing") + err = s.Delete(ctx, entryIndex) + if err != nil { + return err + } + } + } } } diff --git a/builtin/credential/approle/path_tidy_user_id_test.go b/builtin/credential/approle/path_tidy_user_id_test.go index 0e1d5fd198b5..bc261efea55a 100644 --- a/builtin/credential/approle/path_tidy_user_id_test.go +++ b/builtin/credential/approle/path_tidy_user_id_test.go @@ -2,13 +2,15 @@ package approle import ( "context" + "fmt" + "sync" "testing" "time" "github.com/hashicorp/vault/logical" ) -func TestAppRole_TidyDanglingAccessors(t *testing.T) { +func TestAppRole_TidyDanglingAccessors_Normal(t *testing.T) { var resp *logical.Response var err error b, storage := createBackendWithStorage(t) @@ -83,3 +85,93 @@ func TestAppRole_TidyDanglingAccessors(t *testing.T) { t.Fatalf("bad: len(accessorHashes); expect 1, got %d", len(accessorHashes)) } } + +func TestAppRole_TidyDanglingAccessors_RaceTest(t *testing.T) { + var resp *logical.Response + var err error + b, storage := createBackendWithStorage(t) + + b.testTidyDelay = 300 * time.Millisecond + + // Create a role + createRole(t, b, storage, "role1", "a,b,c") + + // Create an initial entry + roleSecretIDReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "role/role1/secret-id", + Storage: storage, + } + resp, err = b.HandleRequest(context.Background(), roleSecretIDReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + count := 1 + + wg := sync.WaitGroup{} + now := time.Now() + started := false + for { + if time.Now().Sub(now) > 700*time.Millisecond { + break + } + if time.Now().Sub(now) > 100*time.Millisecond && !started { + started = true + _, err = b.tidySecretID(context.Background(), &logical.Request{ + Storage: storage, + }) + if err != nil { + t.Fatal(err) + } + } + go func() { + wg.Add(1) + defer wg.Done() + roleSecretIDReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "role/role1/secret-id", + Storage: storage, + } + resp, err = b.HandleRequest(context.Background(), roleSecretIDReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + }() + count++ + } + + t.Logf("wrote %d entries", count) + + wg.Wait() + // Let tidy finish + time.Sleep(1 * time.Second) + + // Run tidy again + _, err = b.tidySecretID(context.Background(), &logical.Request{ + Storage: storage, + }) + if err != nil { + t.Fatal(err) + } + time.Sleep(2 * time.Second) + + accessorHashes, err := storage.List(context.Background(), "accessor/") + if err != nil { + t.Fatal(err) + } + if len(accessorHashes) != count { + t.Fatalf("bad: len(accessorHashes); expect %d, got %d", count, len(accessorHashes)) + } + + roleHMACs, err := storage.List(context.Background(), secretIDPrefix) + if err != nil { + t.Fatal(err) + } + secretIDs, err := storage.List(context.Background(), fmt.Sprintf("%s%s", secretIDPrefix, roleHMACs[0])) + if err != nil { + t.Fatal(err) + } + if len(secretIDs) != count { + t.Fatalf("bad: len(secretIDs); expect %d, got %d", count, len(secretIDs)) + } +}