Skip to content

Commit

Permalink
Fix Okta auth to allow group names containing slashes (#6665)
Browse files Browse the repository at this point in the history
This PR also adds CollectKeysPrefix which allows a more memory efficient
key scan for those cases where the result is immediately filtered by
prefix.
  • Loading branch information
Jim Kalafut authored May 1, 2019
1 parent 3acee26 commit c9ac721
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 25 deletions.
15 changes: 6 additions & 9 deletions builtin/credential/ldap/path_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ func pathGroups(b *backend) *framework.Path {
return &framework.Path{
Pattern: `groups/(?P<name>.+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the LDAP group.",
},

"policies": &framework.FieldSchema{
"policies": {
Type: framework.TypeCommaStringSlice,
Description: "Comma-separated list of policies associated to the group.",
},
Expand Down Expand Up @@ -132,17 +132,14 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
}

func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
keys, err := logical.CollectKeys(ctx, req.Storage)
keys, err := logical.CollectKeysWithPrefix(ctx, req.Storage, "group/")
if err != nil {
return nil, err
}
retKeys := make([]string, 0)
for _, key := range keys {
if strings.HasPrefix(key, "group/") && !strings.HasPrefix(key, "/") {
retKeys = append(retKeys, strings.TrimPrefix(key, "group/"))
}
for i := range keys {
keys[i] = strings.TrimPrefix(keys[i], "group/")
}
return logical.ListResponse(retKeys), nil
return logical.ListResponse(keys), nil
}

type GroupEntry struct {
Expand Down
12 changes: 4 additions & 8 deletions builtin/credential/ldap/path_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,14 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
}

func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
keys, err := logical.CollectKeys(ctx, req.Storage)
keys, err := logical.CollectKeysWithPrefix(ctx, req.Storage, "user/")
if err != nil {
return nil, err
}
retKeys := make([]string, 0)
for _, key := range keys {
if strings.HasPrefix(key, "user/") && !strings.HasPrefix(key, "/") {
retKeys = append(retKeys, strings.TrimPrefix(key, "user/"))
}
for i := range keys {
keys[i] = strings.TrimPrefix(keys[i], "user/")
}
return logical.ListResponse(retKeys), nil

return logical.ListResponse(keys), nil
}

type UserEntry struct {
Expand Down
24 changes: 20 additions & 4 deletions builtin/credential/okta/path_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ func pathGroups(b *backend) *framework.Path {
return &framework.Path{
Pattern: `groups/(?P<name>.+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the Okta group.",
},

"policies": &framework.FieldSchema{
"policies": {
Type: framework.TypeCommaStringSlice,
Description: "Comma-separated list of policies associated to the group.",
},
Expand All @@ -57,10 +57,12 @@ func (b *backend) Group(ctx context.Context, s logical.Storage, n string) (*Grou
return nil, "", err
}
if entry == nil {
entries, err := s.List(ctx, "group/")
entries, err := groupList(ctx, s)
if err != nil {
return nil, "", err

}

for _, groupName := range entries {
if strings.EqualFold(groupName, n) {
entry, err = s.Get(ctx, "group/"+groupName)
Expand Down Expand Up @@ -157,13 +159,27 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
}

func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
groups, err := req.Storage.List(ctx, "group/")
groups, err := groupList(ctx, req.Storage)
if err != nil {
return nil, err
}

return logical.ListResponse(groups), nil
}

func groupList(ctx context.Context, s logical.Storage) ([]string, error) {
keys, err := logical.CollectKeysWithPrefix(ctx, s, "group/")
if err != nil {
return nil, err
}

for i := range keys {
keys[i] = strings.TrimPrefix(keys[i], "group/")
}

return keys, nil
}

type GroupEntry struct {
Policies []string
}
Expand Down
108 changes: 108 additions & 0 deletions builtin/credential/okta/path_groups_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package okta

import (
"context"
"strings"
"testing"
"time"

"github.com/go-test/deep"

log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
)

func TestGroupsList(t *testing.T) {
b, storage := getBackend(t)

groups := []string{
"%20\\",
"foo",
"zfoo",
"🙂",
"foo/nested",
"foo/even/more/nested",
}

for _, group := range groups {
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "groups/" + group,
Storage: storage,
Data: map[string]interface{}{
"policies": []string{group + "_a", group + "_b"},
},
}

resp, err := b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}

}

for _, group := range groups {
for _, upper := range []bool{false, true} {
groupPath := group
if upper {
groupPath = strings.ToUpper(group)
}
req := &logical.Request{
Operation: logical.ReadOperation,
Path: "groups/" + groupPath,
Storage: storage,
}

resp, err := b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if resp == nil {
t.Fatal("unexpected nil response")
}

expected := []string{group + "_a", group + "_b"}

if diff := deep.Equal(resp.Data["policies"].([]string), expected); diff != nil {
t.Fatal(diff)
}
}
}

req := &logical.Request{
Operation: logical.ListOperation,
Path: "groups",
Storage: storage,
}

resp, err := b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}

if diff := deep.Equal(resp.Data["keys"].([]string), groups); diff != nil {
t.Fatal(diff)
}
}

func getBackend(t *testing.T) (logical.Backend, logical.Storage) {
defaultLeaseTTLVal := time.Hour * 12
maxLeaseTTLVal := time.Hour * 24

config := &logical.BackendConfig{
Logger: logging.NewVaultLogger(log.Trace),

System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
StorageView: &logical.InmemStorage{},
}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatalf("unable to create backend: %v", err)
}

return b, config.StorageView
}
15 changes: 11 additions & 4 deletions sdk/logical/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,24 @@ func ScanView(ctx context.Context, view ClearableView, cb func(path string)) err

// CollectKeys is used to collect all the keys in a view
func CollectKeys(ctx context.Context, view ClearableView) ([]string, error) {
// Accumulate the keys
var existing []string
return CollectKeysWithPrefix(ctx, view, "")
}

// CollectKeysPrefix is used to collect all the keys in a view with a given prefix string
func CollectKeysWithPrefix(ctx context.Context, view ClearableView, prefix string) ([]string, error) {
var keys []string

cb := func(path string) {
existing = append(existing, path)
if strings.HasPrefix(path, prefix) {
keys = append(keys, path)
}
}

// Scan for all the keys
if err := ScanView(ctx, view, cb); err != nil {
return nil, err
}
return existing, nil
return keys, nil
}

// ClearView is used to delete all the keys in a view
Expand Down
86 changes: 86 additions & 0 deletions sdk/logical/storage_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package logical

import (
"context"
"testing"

"github.com/go-test/deep"
)

var keyList = []string{
"a",
"b",
"d",
"foo",
"foo42",
"foo/a/b/c",
"c/d/e/f/g",
}

func TestScanView(t *testing.T) {
s := prepKeyStorage(t)

keys := make([]string, 0)
err := ScanView(context.Background(), s, func(path string) {
keys = append(keys, path)
})

if err != nil {
t.Fatal(err)
}

if diff := deep.Equal(keys, keyList); diff != nil {
t.Fatal(diff)
}
}

func TestCollectKeys(t *testing.T) {
s := prepKeyStorage(t)

keys, err := CollectKeys(context.Background(), s)

if err != nil {
t.Fatal(err)
}

if diff := deep.Equal(keys, keyList); diff != nil {
t.Fatal(diff)
}
}

func TestCollectKeysPrefix(t *testing.T) {
s := prepKeyStorage(t)

keys, err := CollectKeysWithPrefix(context.Background(), s, "foo")

if err != nil {
t.Fatal(err)
}

exp := []string{
"foo",
"foo42",
"foo/a/b/c",
}

if diff := deep.Equal(keys, exp); diff != nil {
t.Fatal(diff)
}
}

func prepKeyStorage(t *testing.T) Storage {
t.Helper()
s := &InmemStorage{}

for _, key := range keyList {
if err := s.Put(context.Background(), &StorageEntry{
Key: key,
Value: nil,
SealWrap: false,
}); err != nil {
t.Fatal(err)
}
}

return s
}

0 comments on commit c9ac721

Please sign in to comment.