diff --git a/command/approle_concurrency_integ_test.go b/command/approle_concurrency_integ_test.go new file mode 100644 index 000000000000..d9f9fe9b92c3 --- /dev/null +++ b/command/approle_concurrency_integ_test.go @@ -0,0 +1,86 @@ +package command + +import ( + "sync" + "testing" + + "github.com/hashicorp/vault/api" + credAppRole "github.com/hashicorp/vault/builtin/credential/approle" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" + logxi "github.com/mgutz/logxi/v1" +) + +func TestAppRole_Integ_ConcurrentLogins(t *testing.T) { + var err error + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: logxi.NullLog, + CredentialBackends: map[string]logical.Factory{ + "approle": credAppRole.Factory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + + vault.TestWaitActive(t, cores[0].Core) + + client := cores[0].Client + + err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{ + Type: "approle", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/approle/role/role1", map[string]interface{}{ + "bind_secret_id": "true", + "period": "300", + }) + if err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Write("auth/approle/role/role1/secret-id", nil) + if err != nil { + t.Fatal(err) + } + secretID := secret.Data["secret_id"].(string) + + secret, err = client.Logical().Read("auth/approle/role/role1/role-id") + if err != nil { + t.Fatal(err) + } + roleID := secret.Data["role_id"].(string) + + wg := &sync.WaitGroup{} + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + secret, err = client.Logical().Write("auth/approle/login", map[string]interface{}{ + "role_id": roleID, + "secret_id": secretID, + }) + if err != nil { + t.Fatal(err) + } + if secret.Auth.ClientToken == "" { + t.Fatalf("expected a successful login") + } + }() + + } + wg.Wait() +} diff --git a/vault/identity_store.go b/vault/identity_store.go index 559c8148d25a..60cd0262fb68 100644 --- a/vault/identity_store.go +++ b/vault/identity_store.go @@ -249,7 +249,27 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl return nil, fmt.Errorf("missing alias name") } - alias, err := i.MemDBAliasByFactors(mountAccessor, aliasName, false, false) + txn := i.db.Txn(false) + + return i.entityByAliasFactorsInTxn(txn, mountAccessor, aliasName, clone) +} + +// entityByAlaisFactorsInTxn fetches the entity based on factors of alias, i.e +// mount accessor and the alias name. +func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool) (*identity.Entity, error) { + if txn == nil { + return nil, fmt.Errorf("nil txn") + } + + if mountAccessor == "" { + return nil, fmt.Errorf("missing mount accessor") + } + + if aliasName == "" { + return nil, fmt.Errorf("missing alias name") + } + + alias, err := i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, false, false) if err != nil { return nil, err } @@ -258,12 +278,12 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl return nil, nil } - return i.MemDBEntityByAliasID(alias.ID, clone) + return i.MemDBEntityByAliasIDInTxn(txn, alias.ID, clone) } -// CreateEntity creates a new entity. This is used by core to +// CreateOrFetchEntity creates a new entity. This is used by core to // associate each login attempt by an alias to a unified entity in Vault. -func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, error) { +func (i *IdentityStore) CreateOrFetchEntity(alias *logical.Alias) (*identity.Entity, error) { var entity *identity.Entity var err error @@ -290,9 +310,24 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er return nil, err } if entity != nil { - return nil, fmt.Errorf("alias already belongs to a different entity") + return entity, nil } + // Create a MemDB transaction to update both alias and entity + txn := i.db.Txn(true) + defer txn.Abort() + + // Check if an entity was created before acquiring the lock + entity, err = i.entityByAliasFactorsInTxn(txn, alias.MountAccessor, alias.Name, false) + if err != nil { + return nil, err + } + if entity != nil { + return entity, nil + } + + i.logger.Debug("identity: creating a new entity", "alias", alias) + entity = &identity.Entity{} err = i.sanitizeEntity(entity) @@ -320,10 +355,12 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er } // Update MemDB and persist entity object - err = i.upsertEntity(entity, nil, true) + err = i.upsertEntityInTxn(txn, entity, nil, true, false) if err != nil { return nil, err } + txn.Commit() + return entity, nil } diff --git a/vault/identity_store_test.go b/vault/identity_store_test.go index 5f5c884609d3..5c7f338191a0 100644 --- a/vault/identity_store_test.go +++ b/vault/identity_store_test.go @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/vault/logical" ) -func TestIdentityStore_CreateEntity(t *testing.T) { +func TestIdentityStore_CreateOrFetchEntity(t *testing.T) { is, ghAccessor, _ := testIdentityStoreWithGithubAuth(t) alias := &logical.Alias{ MountType: "github", @@ -17,7 +17,7 @@ func TestIdentityStore_CreateEntity(t *testing.T) { Name: "githubuser", } - entity, err := is.CreateEntity(alias) + entity, err := is.CreateOrFetchEntity(alias) if err != nil { t.Fatal(err) } @@ -33,10 +33,20 @@ func TestIdentityStore_CreateEntity(t *testing.T) { t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name) } - // Try recreating an entity with the same alias details. It should fail. - entity, err = is.CreateEntity(alias) - if err == nil { - t.Fatalf("expected an error") + entity, err = is.CreateOrFetchEntity(alias) + if err != nil { + t.Fatal(err) + } + if entity == nil { + t.Fatalf("expected a non-nil entity") + } + + if len(entity.Aliases) != 1 { + t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(entity.Aliases)) + } + + if entity.Aliases[0].Name != alias.Name { + t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name) } } diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index 64585171b0a1..2e0e7fea883f 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -666,12 +666,29 @@ func (i *IdentityStore) MemDBAliasByFactors(mountAccessor, aliasName string, clo return nil, fmt.Errorf("missing mount accessor") } + txn := i.db.Txn(false) + + return i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, clone, groupAlias) +} + +func (i *IdentityStore) MemDBAliasByFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) { + if txn == nil { + return nil, fmt.Errorf("nil txn") + } + + if aliasName == "" { + return nil, fmt.Errorf("missing alias name") + } + + if mountAccessor == "" { + return nil, fmt.Errorf("missing mount accessor") + } + tableName := entityAliasesTable if groupAlias { tableName = groupAliasesTable } - txn := i.db.Txn(false) aliasRaw, err := txn.First(tableName, "factors", mountAccessor, aliasName) if err != nil { return nil, fmt.Errorf("failed to fetch alias from memdb using factors: %v", err) diff --git a/vault/request_handling.go b/vault/request_handling.go index 360d39e0b91a..79477db2259c 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -436,22 +436,15 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re var err error - // Check if an entity already exists for the given alias - entity, err = c.identityStore.entityByAliasFactors(auth.Alias.MountAccessor, auth.Alias.Name, false) + // Fetch the entity for the alias, or create an entity if one + // doesn't exist. + entity, err = c.identityStore.CreateOrFetchEntity(auth.Alias) if err != nil { return nil, nil, err } - // If not, create one. if entity == nil { - c.logger.Debug("core: creating a new entity", "alias", auth.Alias) - entity, err = c.identityStore.CreateEntity(auth.Alias) - if err != nil { - return nil, nil, err - } - if entity == nil { - return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias") - } + return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias") } auth.EntityID = entity.ID