From 308c20a827acf55369742f5af3c95c529930acd5 Mon Sep 17 00:00:00 2001 From: Chris Hoffman Date: Tue, 16 Jan 2018 15:40:58 -0500 Subject: [PATCH] Converting OU and Organization role fields to CommaStringSlice --- builtin/logical/pki/backend_test.go | 12 +-- builtin/logical/pki/cert_util.go | 15 +--- builtin/logical/pki/path_roles.go | 30 +++++-- builtin/logical/pki/path_roles_test.go | 108 +++++++++++++++++++++++++ helper/strutil/strutil_test.go | 24 ++++++ 5 files changed, 162 insertions(+), 27 deletions(-) diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 0e4f73634d23..774145a5dd72 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -1492,7 +1492,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep { } cert := parsedCertBundle.Certificate - expected := strutil.ParseDedupLowercaseAndSortStrings(role.OU, ",") + expected := strutil.RemoveDuplicates(role.OU, true) if !reflect.DeepEqual(cert.Subject.OrganizationalUnit, expected) { return fmt.Errorf("Error: returned certificate has OU of %s but %s was specified in the role.", cert.Subject.OrganizationalUnit, expected) } @@ -1513,7 +1513,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep { } cert := parsedCertBundle.Certificate - expected := strutil.ParseDedupLowercaseAndSortStrings(role.Organization, ",") + expected := strutil.RemoveDuplicates(role.Organization, true) if !reflect.DeepEqual(cert.Subject.Organization, expected) { return fmt.Errorf("Error: returned certificate has Organization of %s but %s was specified in the role.", cert.Subject.Organization, expected) } @@ -1798,18 +1798,18 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep { } // OU tests { - roleVals.OU = "foo" + roleVals.OU = []string{"foo"} addTests(getOuCheck(roleVals)) - roleVals.OU = "foo,bar" + roleVals.OU = []string{"foo", "bar"} addTests(getOuCheck(roleVals)) } // Organization tests { - roleVals.Organization = "system:masters" + roleVals.Organization = []string{"system:masters"} addTests(getOrganizationCheck(roleVals)) - roleVals.Organization = "foo,bar" + roleVals.Organization = []string{"foo", "bar"} addTests(getOrganizationCheck(roleVals)) } // IP SAN tests diff --git a/builtin/logical/pki/cert_util.go b/builtin/logical/pki/cert_util.go index 22161118b647..acd5d43cd438 100644 --- a/builtin/logical/pki/cert_util.go +++ b/builtin/logical/pki/cert_util.go @@ -715,20 +715,9 @@ func generateCreationBundle(b *backend, } // Set OU (organizationalUnit) values if specified in the role - ou := []string{} - { - if role.OU != "" { - ou = strutil.RemoveDuplicates(strutil.ParseStringSlice(role.OU, ","), false) - } - } - + ou := strutil.RemoveDuplicates(role.OU, false) // Set O (organization) values if specified in the role - organization := []string{} - { - if role.Organization != "" { - organization = strutil.RemoveDuplicates(strutil.ParseStringSlice(role.Organization, ","), false) - } - } + organization := strutil.RemoveDuplicates(role.Organization, false) // Get the TTL and verify it against the max allowed var ttl time.Duration diff --git a/builtin/logical/pki/path_roles.go b/builtin/logical/pki/path_roles.go index 0c3602590cc8..73cdd5777dc6 100644 --- a/builtin/logical/pki/path_roles.go +++ b/builtin/logical/pki/path_roles.go @@ -185,15 +185,13 @@ include the Common Name (cn). Defaults to true.`, }, "ou": &framework.FieldSchema{ - Type: framework.TypeString, - Default: "", + Type: framework.TypeCommaStringSlice, Description: `If set, the OU (OrganizationalUnit) will be set to this value in certificates issued by this role.`, }, "organization": &framework.FieldSchema{ - Type: framework.TypeString, - Default: "", + Type: framework.TypeCommaStringSlice, Description: `If set, the O (Organization) will be set to this value in certificates issued by this role.`, }, @@ -303,6 +301,20 @@ func (b *backend) getRole(s logical.Storage, n string) (*roleEntry, error) { modified = true } + // Upgrade OU + if result.OUOld != "" { + result.OU = strings.Split(result.OUOld, ",") + result.OUOld = "" + modified = true + } + + // Upgrade Organization + if result.OrganizationOld != "" { + result.Organization = strings.Split(result.OrganizationOld, ",") + result.OrganizationOld = "" + modified = true + } + if modified { jsonEntry, err := logical.StorageEntryJSON("role/"+n, &result) if err != nil { @@ -394,8 +406,8 @@ func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data UseCSRCommonName: data.Get("use_csr_common_name").(bool), UseCSRSANs: data.Get("use_csr_sans").(bool), KeyUsage: data.Get("key_usage").([]string), - OU: data.Get("ou").(string), - Organization: data.Get("organization").(string), + OU: data.Get("ou").([]string), + Organization: data.Get("organization").([]string), GenerateLease: new(bool), NoStore: data.Get("no_store").(bool), } @@ -522,8 +534,10 @@ type roleEntry struct { MaxPathLength *int `json:",omitempty" mapstructure:"max_path_length"` KeyUsageOld string `json:"key_usage,omitempty"` KeyUsage []string `json:"key_usage_list" mapstructure:"key_usage"` - OU string `json:"ou" mapstructure:"ou"` - Organization string `json:"organization" mapstructure:"organization"` + OUOld string `json:"ou,omitempty"` + OU []string `json:"ou_list" mapstructure:"ou"` + OrganizationOld string `json:"organization,omitempty"` + Organization []string `json:"organization_list" mapstructure:"organization"` GenerateLease *bool `json:"generate_lease,omitempty"` NoStore bool `json:"no_store" mapstructure:"no_store"` diff --git a/builtin/logical/pki/path_roles_test.go b/builtin/logical/pki/path_roles_test.go index 7a336e6c6bd4..801e13925764 100644 --- a/builtin/logical/pki/path_roles_test.go +++ b/builtin/logical/pki/path_roles_test.go @@ -209,6 +209,114 @@ func TestPki_RoleKeyUsage(t *testing.T) { } } +func TestPki_RoleOUOrganizationUpgrade(t *testing.T) { + var resp *logical.Response + var err error + b, storage := createBackendWithStorage(t) + + roleData := map[string]interface{}{ + "allowed_domains": "myvault.com", + "ttl": "5h", + "ou": []string{"abc", "123"}, + "organization": []string{"org1", "org2"}, + } + + roleReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/testrole", + Storage: storage, + Data: roleData, + } + + resp, err = b.HandleRequest(context.Background(), roleReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v resp: %#v", err, resp) + } + + roleReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(context.Background(), roleReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v resp: %#v", err, resp) + } + + ou := resp.Data["ou"].([]string) + if len(ou) != 2 { + t.Fatalf("ou should have 2 values") + } + organization := resp.Data["organization"].([]string) + if len(organization) != 2 { + t.Fatalf("organziation should have 2 values") + } + + // Check that old key usage value is nil + var role roleEntry + err = mapstructure.Decode(resp.Data, &role) + if err != nil { + t.Fatal(err) + } + if role.OUOld != "" { + t.Fatalf("old ou storage value should be blank") + } + if role.OrganizationOld != "" { + t.Fatalf("old organization storage value should be blank") + } + + // Make it explicit + role.OUOld = "abc,123" + role.OU = nil + role.OrganizationOld = "org1,org2" + role.Organization = nil + + entry, err := logical.StorageEntryJSON("role/testrole", role) + if err != nil { + t.Fatal(err) + } + if err := storage.Put(entry); err != nil { + t.Fatal(err) + } + + // Reading should upgrade key_usage + resp, err = b.HandleRequest(context.Background(), roleReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v resp: %#v", err, resp) + } + + ou = resp.Data["ou"].([]string) + if len(ou) != 2 { + t.Fatalf("ou should have 2 values") + } + organization = resp.Data["organization"].([]string) + if len(organization) != 2 { + t.Fatalf("organization should have 2 values") + } + + // Read back from storage to ensure upgrade + entry, err = storage.Get("role/testrole") + if err != nil { + t.Fatalf("err: %v", err) + } + if entry == nil { + t.Fatalf("role should not be nil") + } + var result roleEntry + if err := entry.DecodeJSON(&result); err != nil { + t.Fatalf("err: %v", err) + } + + if result.OUOld != "" { + t.Fatal("old ou value should be blank") + } + if len(result.OU) != 2 { + t.Fatal("ou should have 2 values") + } + if result.OrganizationOld != "" { + t.Fatal("old organization value should be blank") + } + if len(result.Organization) != 2 { + t.Fatal("organization should have 2 values") + } +} + func TestPki_RoleAllowedDomains(t *testing.T) { var resp *logical.Response var err error diff --git a/helper/strutil/strutil_test.go b/helper/strutil/strutil_test.go index 293926500662..87feb4a35215 100644 --- a/helper/strutil/strutil_test.go +++ b/helper/strutil/strutil_test.go @@ -399,3 +399,27 @@ func TestStrutil_AppendIfMissing(t *testing.T) { t.Fatalf("expected slice to still contain key 'bar': %v", keys) } } + +func TestStrUtil_RemoveDuplicates(t *testing.T) { + type tCase struct { + input []string + expect []string + lowercase bool + } + + tCases := []tCase{ + tCase{[]string{}, []string{}, false}, + tCase{[]string{}, []string{}, true}, + tCase{[]string{"a", "b", "a"}, []string{"a", "b"}, false}, + tCase{[]string{"A", "b", "a"}, []string{"A", "a", "b"}, false}, + tCase{[]string{"A", "b", "a"}, []string{"a", "b"}, true}, + } + + for _, tc := range tCases { + actual := RemoveDuplicates(tc.input, tc.lowercase) + + if !reflect.DeepEqual(actual, tc.expect) { + t.Fatalf("Bad testcase %#v, expected %v, got %v", tc, tc.expect, actual) + } + } +}