Skip to content

Commit

Permalink
Add a test of the upgrade and deprecated field functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
jefferai committed May 1, 2019
1 parent 91e2367 commit 5e9a402
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 2 deletions.
25 changes: 23 additions & 2 deletions vault/token_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -3070,6 +3070,8 @@ func (ts *TokenStore) tokenStoreRoleCreateUpdate(ctx context.Context, req *logic
return logical.ErrorResponse(errwrap.Wrapf("error parsing role fields: {{err}}", err).Error()), nil
}

var resp *logical.Response

// Now handle backwards compat. Prefer token_ fields over others if both
// are set. We set the original fields here so that on read of token role
// we can return the same values that were set. We clear out the Token*
Expand All @@ -3083,6 +3085,13 @@ func (ts *TokenStore) tokenStoreRoleCreateUpdate(ctx context.Context, req *logic
entry.TokenPeriod = 0
}
} else {
_, ok = data.GetOk("period")
if ok {
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning("Both 'token_period' and deprecated 'period' value supplied, ignoring the deprecated value")
}
entry.Period = 0
}

Expand All @@ -3098,11 +3107,16 @@ func (ts *TokenStore) tokenStoreRoleCreateUpdate(ctx context.Context, req *logic
entry.TokenBoundCIDRs = nil
}
} else {
_, ok = data.GetOk("bound_cidrs")
if ok {
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning("Both 'token_bound_cidrs' and deprecated 'bound_cidrs' value supplied, ignoring the deprecated value")
}
entry.BoundCIDRs = nil
}

var resp *logical.Response

finalExplicitMaxTTL := entry.TokenExplicitMaxTTL
explicitMaxTTLRaw, ok := data.GetOk("token_explicit_max_ttl")
if !ok {
Expand All @@ -3113,6 +3127,13 @@ func (ts *TokenStore) tokenStoreRoleCreateUpdate(ctx context.Context, req *logic
}
finalExplicitMaxTTL = entry.ExplicitMaxTTL
} else {
_, ok = data.GetOk("explicit_max_ttl")
if ok {
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning("Both 'token_explicit_max_ttl' and deprecated 'explicit_max_ttl' value supplied, ignoring the deprecated value")
}
entry.ExplicitMaxTTL = 0
}
if finalExplicitMaxTTL != 0 {
Expand Down
244 changes: 244 additions & 0 deletions vault/token_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/locksutil"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/tokenhelper"
"github.com/hashicorp/vault/sdk/logical"
)

Expand Down Expand Up @@ -3607,6 +3609,248 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) {
}
}

func TestTokenStore_RoleTokenFields(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
//c, _, root := TestCoreUnsealed(t)
ts := c.tokenStore
rootContext := namespace.RootContext(context.Background())

boundCIDRs, err := parseutil.ParseAddrs([]string{"127.0.0.1/32"})
if err != nil {
t.Fatal(err)
}

// First test the upgrade case. Create a role with values and ensure they
// are reflected properly on read.
{
roleEntry := &tsRoleEntry{
Name: "test",
TokenParams: tokenhelper.TokenParams{
TokenType: logical.TokenTypeBatch,
},
Period: time.Second,
ExplicitMaxTTL: time.Hour,
}
roleEntry.BoundCIDRs = boundCIDRs
ns := namespace.RootNamespace
jsonEntry, err := logical.StorageEntryJSON("test", roleEntry)
if err != nil {
t.Fatal(err)
}
if err := ts.rolesView(ns).Put(rootContext, jsonEntry); err != nil {
t.Fatal(err)
}
// Read it back
roleEntry, err = ts.tokenStoreRole(rootContext, "test")
if err != nil {
t.Fatal(err)
}
expRoleEntry := &tsRoleEntry{
Name: "test",
TokenParams: tokenhelper.TokenParams{
TokenPeriod: time.Second,
TokenExplicitMaxTTL: time.Hour,
TokenBoundCIDRs: boundCIDRs,
TokenType: logical.TokenTypeBatch,
},
Period: time.Second,
ExplicitMaxTTL: time.Hour,
BoundCIDRs: boundCIDRs,
}
if diff := deep.Equal(expRoleEntry, roleEntry); diff != nil {
t.Fatal(diff)
}
}

// Now, read that back through the API and verify we see what we expect
{
req := logical.TestRequest(t, logical.ReadOperation, "roles/test")
resp, err := ts.HandleRequest(rootContext, req)
if err != nil {
t.Fatalf("err: %v", err)
}

expected := map[string]interface{}{
"name": "test",
"orphan": false,
"period": int64(1),
"token_period": int64(1),
"allowed_policies": []string(nil),
"disallowed_policies": []string(nil),
"path_suffix": "",
"token_explicit_max_ttl": int64(3600),
"explicit_max_ttl": int64(3600),
"renewable": false,
"token_type": "batch",
}

if resp.Data["bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String() != "127.0.0.1" {
t.Fatalf("unexpected bound cidrs: %s", resp.Data["bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String())
}
delete(resp.Data, "bound_cidrs")
if resp.Data["token_bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String() != "127.0.0.1" {
t.Fatalf("unexpected token bound cidrs: %s", resp.Data["token_bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String())
}
delete(resp.Data, "token_bound_cidrs")

if diff := deep.Equal(expected, resp.Data); diff != nil {
t.Fatal(diff)
}
}

// Put values in just the old locations, but through the API
{
req := logical.TestRequest(t, logical.UpdateOperation, "roles/test")
req.Data = map[string]interface{}{
"explicit_max_ttl": 7200,
"token_type": "default-batch",
"period": 5,
"bound_cidrs": boundCIDRs[0].String(),
}

resp, err := ts.HandleRequest(rootContext, req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err: %v\nresp: %#v", err, resp)
}
if resp != nil {
t.Fatalf("expected a nil response")
}

req = logical.TestRequest(t, logical.ReadOperation, "roles/test")
resp, err = ts.HandleRequest(rootContext, req)
if err != nil {
t.Fatalf("err: %v", err)
}

expected := map[string]interface{}{
"name": "test",
"orphan": false,
"period": int64(5),
"token_period": int64(5),
"allowed_policies": []string(nil),
"disallowed_policies": []string(nil),
"path_suffix": "",
"token_explicit_max_ttl": int64(7200),
"explicit_max_ttl": int64(7200),
"renewable": false,
"token_type": "default-batch",
}

if resp.Data["bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String() != "127.0.0.1" {
t.Fatalf("unexpected bound cidrs: %s", resp.Data["bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String())
}
delete(resp.Data, "bound_cidrs")
if resp.Data["token_bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String() != "127.0.0.1" {
t.Fatalf("unexpected token bound cidrs: %s", resp.Data["token_bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String())
}
delete(resp.Data, "token_bound_cidrs")

if diff := deep.Equal(expected, resp.Data); diff != nil {
t.Fatal(diff)
}
}
// Same thing for just the new locations
{
req := logical.TestRequest(t, logical.UpdateOperation, "roles/test")
req.Data = map[string]interface{}{
"token_explicit_max_ttl": 5200,
"token_type": "default-service",
"token_period": 7,
"token_bound_cidrs": boundCIDRs[0].String(),
}

resp, err := ts.HandleRequest(rootContext, req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err: %v\nresp: %#v", err, resp)
}
if resp != nil {
t.Fatalf("expected a nil response")
}

req = logical.TestRequest(t, logical.ReadOperation, "roles/test")
resp, err = ts.HandleRequest(rootContext, req)
if err != nil {
t.Fatalf("err: %v", err)
}

expected := map[string]interface{}{
"name": "test",
"orphan": false,
"period": int64(0),
"token_period": int64(7),
"allowed_policies": []string(nil),
"disallowed_policies": []string(nil),
"path_suffix": "",
"token_explicit_max_ttl": int64(5200),
"explicit_max_ttl": int64(0),
"renewable": false,
"token_type": "default-service",
}

if resp.Data["token_bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String() != "127.0.0.1" {
t.Fatalf("unexpected token bound cidrs: %s", resp.Data["token_bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String())
}
delete(resp.Data, "token_bound_cidrs")

if diff := deep.Equal(expected, resp.Data); diff != nil {
t.Fatal(diff)
}
}
// Put values in both locations
{
req := logical.TestRequest(t, logical.UpdateOperation, "roles/test")
req.Data = map[string]interface{}{
"token_explicit_max_ttl": 7200,
"explicit_max_ttl": 5200,
"token_type": "service",
"token_period": 5,
"period": 1,
"token_bound_cidrs": boundCIDRs[0].String(),
"bound_cidrs": boundCIDRs[0].String(),
}

resp, err := ts.HandleRequest(rootContext, req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err: %v\nresp: %#v", err, resp)
}
if resp == nil {
t.Fatalf("expected a non-nil response")
}
if len(resp.Warnings) != 3 {
t.Fatalf("expected 3 warnings, got %#v", resp.Warnings)
}

req = logical.TestRequest(t, logical.ReadOperation, "roles/test")
resp, err = ts.HandleRequest(rootContext, req)
if err != nil {
t.Fatalf("err: %v", err)
}

expected := map[string]interface{}{
"name": "test",
"orphan": false,
"period": int64(0),
"token_period": int64(5),
"allowed_policies": []string(nil),
"disallowed_policies": []string(nil),
"path_suffix": "",
"token_explicit_max_ttl": int64(7200),
"explicit_max_ttl": int64(0),
"renewable": false,
"token_type": "service",
}

if resp.Data["token_bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String() != "127.0.0.1" {
t.Fatalf("unexpected token bound cidrs: %s", resp.Data["token_bound_cidrs"].([]*sockaddr.SockAddrMarshaler)[0].String())
}
delete(resp.Data, "token_bound_cidrs")

if diff := deep.Equal(expected, resp.Data); diff != nil {
t.Fatal(diff)
}
}
}

func TestTokenStore_Periodic(t *testing.T) {
core, _, root := TestCoreUnsealed(t)

Expand Down

0 comments on commit 5e9a402

Please sign in to comment.