Skip to content

Commit

Permalink
Strip empty strings from database revocation stmts (#5955)
Browse files Browse the repository at this point in the history
* Strip empty strings from database revocation stmts

It's technically valid to give empty strings as statements to run on
most databases. However, in the case of revocation statements, it's not
only generally inadvisable but can lead to lack of revocations when you
expect them. This strips empty strings from the array of revocation
statements.

It also makes two other changes:

* Return statements on read as empty but valid arrays rather than nulls,
so that typing information is inferred (this is more in line with the
rest of Vault these days)

* Changes field data for TypeStringSlice and TypeCommaStringSlice such
that a client-supplied value of `""` doesn't turn into `[]string{""}`
but rather `[]string{}`.

The latter and the explicit revocation statement changes are related,
and defense in depth.
  • Loading branch information
jefferai committed Dec 14, 2018
1 parent 9f404f2 commit 21e7462
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 15 deletions.
3 changes: 3 additions & 0 deletions builtin/logical/database/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/hashicorp/errwrap"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
Expand Down Expand Up @@ -159,6 +160,8 @@ func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName
result.Statements = stmts
}

result.Statements.Revocation = strutil.RemoveEmpty(result.Statements.Revocation)

// For backwards compatibility, copy the values back into the string form
// of the fields
result.Statements = dbutil.StatementCompatibilityHelper(result.Statements)
Expand Down
12 changes: 8 additions & 4 deletions builtin/logical/database/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,8 @@ func TestBackend_roleCrud(t *testing.T) {
expected := dbplugin.Statements{
Creation: []string{strings.TrimSpace(testRole)},
Revocation: []string{strings.TrimSpace(defaultRevocationSQL)},
Rollback: []string{},
Renewal: []string{},
}

actual := dbplugin.Statements{
Expand All @@ -886,8 +888,8 @@ func TestBackend_roleCrud(t *testing.T) {
Renewal: resp.Data["renew_statements"].([]string),
}

if !reflect.DeepEqual(expected, actual) {
t.Fatalf("Statements did not match, expected %#v, got %#v", expected, actual)
if diff := deep.Equal(expected, actual); diff != nil {
t.Fatal(diff)
}

if diff := deep.Equal(resp.Data["db_name"], "plugin-test"); diff != nil {
Expand Down Expand Up @@ -945,6 +947,8 @@ func TestBackend_roleCrud(t *testing.T) {
expected := dbplugin.Statements{
Creation: []string{strings.TrimSpace(testRole)},
Revocation: []string{strings.TrimSpace(defaultRevocationSQL)},
Rollback: []string{},
Renewal: []string{},
}

actual := dbplugin.Statements{
Expand Down Expand Up @@ -1028,8 +1032,8 @@ func TestBackend_roleCrud(t *testing.T) {
Renewal: resp.Data["renew_statements"].([]string),
}

if !reflect.DeepEqual(expected, actual) {
t.Fatalf("Statements did not match, expected %#v, got %#v", expected, actual)
if diff := deep.Equal(expected, actual); diff != nil {
t.Fatal(diff)
}

if diff := deep.Equal(resp.Data["db_name"], "plugin-test"); diff != nil {
Expand Down
39 changes: 28 additions & 11 deletions builtin/logical/database/path_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
Expand Down Expand Up @@ -110,25 +111,39 @@ func (b *databaseBackend) pathRoleDelete() framework.OperationFunc {
}

func (b *databaseBackend) pathRoleRead() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
role, err := b.Role(ctx, req.Storage, d.Get("name").(string))
if err != nil {
return nil, err
}
if role == nil {
return nil, nil
}

data := map[string]interface{}{
"db_name": role.DBName,
"creation_statements": role.Statements.Creation,
"revocation_statements": role.Statements.Revocation,
"rollback_statements": role.Statements.Rollback,
"renew_statements": role.Statements.Renewal,
"default_ttl": role.DefaultTTL.Seconds(),
"max_ttl": role.MaxTTL.Seconds(),
}
if len(role.Statements.Creation) == 0 {
data["creation_statements"] = []string{}
}
if len(role.Statements.Revocation) == 0 {
data["revocation_statements"] = []string{}
}
if len(role.Statements.Rollback) == 0 {
data["rollback_statements"] = []string{}
}
if len(role.Statements.Renewal) == 0 {
data["renew_statements"] = []string{}
}

return &logical.Response{
Data: map[string]interface{}{
"db_name": role.DBName,
"creation_statements": role.Statements.Creation,
"revocation_statements": role.Statements.Revocation,
"rollback_statements": role.Statements.Rollback,
"renew_statements": role.Statements.Renewal,
"default_ttl": role.DefaultTTL.Seconds(),
"max_ttl": role.MaxTTL.Seconds(),
},
Data: data,
}, nil
}
}
Expand Down Expand Up @@ -218,6 +233,8 @@ func (b *databaseBackend) pathRoleCreateUpdate() framework.OperationFunc {
role.Statements.RollbackStatements = ""
}

role.Statements.Revocation = strutil.RemoveEmpty(role.Statements.Revocation)

// Store it
entry, err := logical.StorageEntryJSON("role/"+name, role)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions helper/parseutil/parseutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ func ParseBool(in interface{}) (bool, error) {
}

func ParseCommaStringSlice(in interface{}) ([]string, error) {
rawString, ok := in.(string)
if ok && rawString == "" {
return []string{}, nil
}
var result []string
config := &mapstructure.DecoderConfig{
Result: &result,
Expand Down
16 changes: 16 additions & 0 deletions helper/strutil/strutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,22 @@ func RemoveDuplicates(items []string, lowercase bool) []string {
return items
}

// RemoveEmpty removes empty elements from a slice of
// strings
func RemoveEmpty(items []string) []string {
if len(items) == 0 {
return items
}
itemsSlice := make([]string, 0, len(items))
for _, item := range items {
if item == "" {
continue
}
itemsSlice = append(itemsSlice, item)
}
return itemsSlice
}

// EquivalentSlices checks whether the given string sets are equivalent, as in,
// they contain the same values.
func EquivalentSlices(a, b []string) bool {
Expand Down
16 changes: 16 additions & 0 deletions helper/strutil/strutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,22 @@ func TestTrimStrings(t *testing.T) {
}
}

func TestRemoveEmpty(t *testing.T) {
input := []string{"abc", "", "abc", ""}
expected := []string{"abc", "abc"}
actual := RemoveEmpty(input)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("Bad TrimStrings: expected:%#v, got:%#v", expected, actual)
}

input = []string{""}
expected = []string{}
actual = RemoveEmpty(input)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("Bad TrimStrings: expected:%#v, got:%#v", expected, actual)
}
}

func TestStrutil_AppendIfMissing(t *testing.T) {
keys := []string{}

Expand Down
5 changes: 5 additions & 0 deletions logical/framework/field_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ func (d *FieldData) getPrimitive(k string, schema *FieldSchema) (interface{}, bo
return result, true, nil

case TypeStringSlice:
rawString, ok := raw.(string)
if ok && rawString == "" {
return []string{}, true, nil
}

var result []string
if err := mapstructure.WeakDecode(raw, &result); err != nil {
return nil, false, err
Expand Down
22 changes: 22 additions & 0 deletions logical/framework/field_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,28 @@ func TestFieldDataGet(t *testing.T) {
[]string{"abc"},
},

"string slice type, empty string": {
map[string]*FieldSchema{
"foo": &FieldSchema{Type: TypeStringSlice},
},
map[string]interface{}{
"foo": "",
},
"foo",
[]string{},
},

"comma string slice type, empty string": {
map[string]*FieldSchema{
"foo": &FieldSchema{Type: TypeCommaStringSlice},
},
map[string]interface{}{
"foo": "",
},
"foo",
[]string{},
},

"comma string slice type, comma string with one value": {
map[string]*FieldSchema{
"foo": &FieldSchema{Type: TypeCommaStringSlice},
Expand Down

0 comments on commit 21e7462

Please sign in to comment.