Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport 1.3: Fix identity token panic during invalidation #8043

Merged
merged 1 commit into from
Dec 17, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion vault/identity_store.go
Original file line number Diff line number Diff line change
@@ -315,7 +315,9 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) {
return

case strings.HasPrefix(key, oidcTokensPrefix):
i.oidcCache.Flush(nil)
if err := i.oidcCache.Flush(noNamespace); err != nil {
i.logger.Error("error flushing oidc cache", "error", err)
}
}
}

135 changes: 108 additions & 27 deletions vault/identity_store_oidc.go
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ import (
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
@@ -90,6 +91,8 @@ type oidcCache struct {
c *cache.Cache
}

var errNilNamespace = errors.New("nil namespace in oidc cache request")

const (
issuerPath = "identity/oidc"
oidcTokensPrefix = "oidc_tokens/"
@@ -111,7 +114,7 @@ var supportedAlgs = []string{
}

// pseudo-namespace for cache items that don't belong to any real namespace.
var nilNamespace = &namespace.Namespace{ID: "__NIL_NAMESPACE"}
var noNamespace = &namespace.Namespace{ID: "__NO_NAMESPACE"}

func oidcPaths(i *IdentityStore) []*framework.Path {
return []*framework.Path{
@@ -370,7 +373,9 @@ func (i *IdentityStore) pathOIDCUpdateConfig(ctx context.Context, req *logical.R
return nil, err
}

i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}

return resp, nil
}
@@ -381,7 +386,12 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (*
return nil, err
}

if v, ok := i.oidcCache.Get(ns, "config"); ok {
v, ok, err := i.oidcCache.Get(ns, "config")
if err != nil {
return nil, err
}

if ok {
return v.(*oidcConfig), nil
}

@@ -404,7 +414,9 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (*

c.effectiveIssuer += "/v1/" + ns.Path + issuerPath

i.oidcCache.SetDefault(ns, "config", &c)
if err := i.oidcCache.SetDefault(ns, "config", &c); err != nil {
return nil, err
}

return &c, nil
}
@@ -416,8 +428,6 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
return nil, err
}

defer i.oidcCache.Flush(ns)

name := d.Get("name").(string)

i.oidcLock.Lock()
@@ -494,6 +504,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
}
}

if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}

// store named key
entry, err := logical.StorageEntryJSON(namedKeyConfigPath+name, key)
if err != nil {
@@ -590,7 +604,9 @@ func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Requ
return nil, err
}

i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}

return nil, nil
}
@@ -645,7 +661,9 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ
return nil, err
}

i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}

return nil, nil
}
@@ -683,7 +701,12 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.

var key *namedKey

if keyRaw, found := i.oidcCache.Get(ns, "namedKeys/"+role.Key); found {
keyRaw, found, err := i.oidcCache.Get(ns, "namedKeys/"+role.Key)
if err != nil {
return nil, err
}

if found {
key = keyRaw.(*namedKey)
} else {
entry, _ := req.Storage.Get(ctx, namedKeyConfigPath+role.Key)
@@ -695,7 +718,9 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.
return nil, err
}

i.oidcCache.SetDefault(ns, "namedKeys/"+role.Key, key)
if err := i.oidcCache.SetDefault(ns, "namedKeys/"+role.Key, key); err != nil {
return nil, err
}
}
// Validate that the role is allowed to sign with its key (the key could have been updated)
if !strutil.StrListContains(key.AllowedClientIDs, "*") && !strutil.StrListContains(key.AllowedClientIDs, role.ClientID) {
@@ -923,7 +948,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateRole(ctx context.Context, req *logic
return nil, err
}

i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}

return nil, nil
}

@@ -994,7 +1022,12 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
return nil, err
}

if v, ok := i.oidcCache.Get(ns, "discoveryResponse"); ok {
v, ok, err := i.oidcCache.Get(ns, "discoveryResponse")
if err != nil {
return nil, err
}

if ok {
data = v.([]byte)
} else {
c, err := i.getOIDCConfig(ctx, req.Storage)
@@ -1015,7 +1048,9 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
return nil, err
}

i.oidcCache.SetDefault(ns, "discoveryResponse", data)
if err := i.oidcCache.SetDefault(ns, "discoveryResponse", data); err != nil {
return nil, err
}
}

resp := &logical.Response{
@@ -1040,7 +1075,12 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
return nil, err
}

if v, ok := i.oidcCache.Get(ns, "jwksResponse"); ok {
v, ok, err := i.oidcCache.Get(ns, "jwksResponse")
if err != nil {
return nil, err
}

if ok {
data = v.([]byte)
} else {
jwks, err := i.generatePublicJWKS(ctx, req.Storage)
@@ -1053,7 +1093,9 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
return nil, err
}

i.oidcCache.SetDefault(ns, "jwksResponse", data)
if err := i.oidcCache.SetDefault(ns, "jwksResponse", data); err != nil {
return nil, err
}
}

resp := &logical.Response{
@@ -1072,7 +1114,12 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
return nil, err
}
if len(keys) > 0 {
if v, ok := i.oidcCache.Get(nilNamespace, "nextRun"); ok {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
if err != nil {
return nil, err
}

if ok {
now := time.Now()
expireAt := v.(time.Time)
if expireAt.After(now) {
@@ -1311,7 +1358,12 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
return nil, err
}

if jwksRaw, ok := i.oidcCache.Get(ns, "jwks"); ok {
jwksRaw, ok, err := i.oidcCache.Get(ns, "jwks")
if err != nil {
return nil, err
}

if ok {
return jwksRaw.(*jose.JSONWebKeySet), nil
}

@@ -1336,7 +1388,9 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
jwks.Keys = append(jwks.Keys, *key)
}

i.oidcCache.SetDefault(ns, "jwks", jwks)
if err := i.oidcCache.SetDefault(ns, "jwks", jwks); err != nil {
return nil, err
}

return jwks, nil
}
@@ -1435,7 +1489,9 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
}

if didUpdate {
i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
i.Logger().Error("error flushing oidc cache", "error", err)
}
}

return nextExpiration, nil
@@ -1501,7 +1557,13 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {

nsPaths := i.listNamespacePaths()

if v, ok := i.oidcCache.Get(nilNamespace, "nextRun"); ok {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
if err != nil {
i.Logger().Error("error reading oidc cache", "err", err)
return
}

if ok {
nextRun = v.(time.Time)
}

@@ -1531,7 +1593,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
i.Logger().Warn("error expiring OIDC public keys", "err", err)
}

i.oidcCache.Flush(nilNamespace)
if err := i.oidcCache.Flush(noNamespace); err != nil {
i.Logger().Error("error flushing oidc cache", "err", err)
}

// re-run at the soonest expiration or rotation time
if nextRotation.Before(nextRun) {
@@ -1542,7 +1606,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
nextRun = nextExpiration
}
}
i.oidcCache.SetDefault(nilNamespace, "nextRun", nextRun)
if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
i.Logger().Error("error setting oidc cache", "err", err)
}
}
}

@@ -1556,20 +1622,35 @@ func (c *oidcCache) nskey(ns *namespace.Namespace, key string) string {
return fmt.Sprintf("v0:%s:%s", ns.ID, key)
}

func (c *oidcCache) Get(ns *namespace.Namespace, key string) (interface{}, bool) {
return c.c.Get(c.nskey(ns, key))
func (c *oidcCache) Get(ns *namespace.Namespace, key string) (interface{}, bool, error) {
if ns == nil {
return nil, false, errNilNamespace
}
v, found := c.c.Get(c.nskey(ns, key))
return v, found, nil
}

func (c *oidcCache) SetDefault(ns *namespace.Namespace, key string, obj interface{}) {
func (c *oidcCache) SetDefault(ns *namespace.Namespace, key string, obj interface{}) error {
if ns == nil {
return errNilNamespace
}
c.c.SetDefault(c.nskey(ns, key), obj)

return nil
}

func (c *oidcCache) Flush(ns *namespace.Namespace) {
func (c *oidcCache) Flush(ns *namespace.Namespace) error {
if ns == nil {
return errNilNamespace
}

for itemKey := range c.c.Items() {
if isTargetNamespacedKey(itemKey, []string{nilNamespace.ID, ns.ID}) {
if isTargetNamespacedKey(itemKey, []string{noNamespace.ID, ns.ID}) {
c.c.Delete(itemKey)
}
}

return nil
}

// isTargetNamespacedKey returns true for a properly constructed namespaced key (<version>:<nsID>:<key>)
32 changes: 27 additions & 5 deletions vault/identity_store_oidc_test.go
Original file line number Diff line number Diff line change
@@ -619,7 +619,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
currentCycle = currentCycle + 1

// sleep until we are in the next cycle - where a next run will happen
v, _ := c.identityStore.oidcCache.Get(nilNamespace, "nextRun")
v, _, _ := c.identityStore.oidcCache.Get(noNamespace, "nextRun")
nextRun := v.(time.Time)
now := time.Now()
diff := nextRun.Sub(now)
@@ -1012,7 +1012,7 @@ func TestOIDC_isTargetNamespacedKey(t *testing.T) {
func TestOIDC_Flush(t *testing.T) {
c := newOIDCCache()
ns := []*namespace.Namespace{
nilNamespace, //ns[0] is nilNamespace
noNamespace, //ns[0] is nilNamespace
&namespace.Namespace{ID: "ns1"},
&namespace.Namespace{ID: "ns2"},
}
@@ -1021,7 +1021,9 @@ func TestOIDC_Flush(t *testing.T) {
populateNs := func() {
for i := range ns {
for _, val := range []string{"keyA", "keyB", "keyC"} {
c.SetDefault(ns[i], val, struct{}{})
if err := c.SetDefault(ns[i], val, struct{}{}); err != nil {
t.Fatal(err)
}
}
}
}
@@ -1052,17 +1054,37 @@ func TestOIDC_Flush(t *testing.T) {

// flushing ns1 should flush ns1 and nilNamespace but not ns2
populateNs()
c.Flush(ns[1])
if err := c.Flush(ns[1]); err != nil {
t.Fatal(err)
}
items := c.c.Items()
verify(items, []*namespace.Namespace{ns[2]}, []*namespace.Namespace{ns[0], ns[1]})

// flushing nilNamespace should flush nilNamespace but not ns1 or ns2
populateNs()
c.Flush(ns[0])
if err := c.Flush(ns[0]); err != nil {
t.Fatal(err)
}
items = c.c.Items()
verify(items, []*namespace.Namespace{ns[1], ns[2]}, []*namespace.Namespace{ns[0]})
}

func TestOIDC_CacheNamespaceNilCheck(t *testing.T) {
cache := newOIDCCache()

if _, _, err := cache.Get(nil, "foo"); err == nil {
t.Fatal("expected error, got nil")
}

if err := cache.SetDefault(nil, "foo", 42); err == nil {
t.Fatal("expected error, got nil")
}

if err := cache.Flush(nil); err == nil {
t.Fatal("expected error, got nil")
}
}

// some helpers
func expectSuccess(t *testing.T, resp *logical.Response, err error) {
t.Helper()