diff --git a/api/api_integration_test.go b/api/api_integration_test.go index 7d4eb0705f3e..21a988dcbdef 100644 --- a/api/api_integration_test.go +++ b/api/api_integration_test.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/vault/builtin/logical/database" "github.com/hashicorp/vault/builtin/logical/pki" "github.com/hashicorp/vault/builtin/logical/transit" + "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/vault" @@ -56,6 +57,7 @@ func testVaultServerUnseal(t testing.TB) (*api.Client, []string, func()) { "pki": pki.Factory, "transit": transit.Factory, }, + BuiltinRegistry: builtinplugins.Registry, }) } diff --git a/api/client.go b/api/client.go index 1fc2530eeb2a..3b46fef1b8bd 100644 --- a/api/client.go +++ b/api/client.go @@ -17,7 +17,7 @@ import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/go-cleanhttp" - retryablehttp "github.com/hashicorp/go-retryablehttp" + "github.com/hashicorp/go-retryablehttp" "github.com/hashicorp/go-rootcerts" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/parseutil" diff --git a/api/sys_mounts.go b/api/sys_mounts.go index ec6a09a56c75..e42f22b74c57 100644 --- a/api/sys_mounts.go +++ b/api/sys_mounts.go @@ -133,7 +133,6 @@ type MountInput struct { Description string `json:"description"` Config MountConfigInput `json:"config"` Local bool `json:"local"` - PluginName string `json:"plugin_name,omitempty"` SealWrap bool `json:"seal_wrap" mapstructure:"seal_wrap"` Options map[string]string `json:"options"` } @@ -144,7 +143,6 @@ type MountConfigInput struct { Description *string `json:"description,omitempty" mapstructure:"description"` MaxLeaseTTL string `json:"max_lease_ttl" mapstructure:"max_lease_ttl"` ForceNoCache bool `json:"force_no_cache" mapstructure:"force_no_cache"` - PluginName string `json:"plugin_name,omitempty" mapstructure:"plugin_name"` AuditNonHMACRequestKeys []string `json:"audit_non_hmac_request_keys,omitempty" mapstructure:"audit_non_hmac_request_keys"` AuditNonHMACResponseKeys []string `json:"audit_non_hmac_response_keys,omitempty" mapstructure:"audit_non_hmac_response_keys"` ListingVisibility string `json:"listing_visibility,omitempty" mapstructure:"listing_visibility"` @@ -166,7 +164,6 @@ type MountConfigOutput struct { DefaultLeaseTTL int `json:"default_lease_ttl" mapstructure:"default_lease_ttl"` MaxLeaseTTL int `json:"max_lease_ttl" mapstructure:"max_lease_ttl"` ForceNoCache bool `json:"force_no_cache" mapstructure:"force_no_cache"` - PluginName string `json:"plugin_name,omitempty" mapstructure:"plugin_name"` AuditNonHMACRequestKeys []string `json:"audit_non_hmac_request_keys,omitempty" mapstructure:"audit_non_hmac_request_keys"` AuditNonHMACResponseKeys []string `json:"audit_non_hmac_response_keys,omitempty" mapstructure:"audit_non_hmac_response_keys"` ListingVisibility string `json:"listing_visibility,omitempty" mapstructure:"listing_visibility"` diff --git a/api/sys_plugins.go b/api/sys_plugins.go index b2f18d94d769..52be04f6678a 100644 --- a/api/sys_plugins.go +++ b/api/sys_plugins.go @@ -2,24 +2,43 @@ package api import ( "context" + "errors" "fmt" "net/http" + + "github.com/hashicorp/vault/helper/consts" + "github.com/mitchellh/mapstructure" ) // ListPluginsInput is used as input to the ListPlugins function. -type ListPluginsInput struct{} +type ListPluginsInput struct { + // Type of the plugin. Required. + Type consts.PluginType `json:"type"` +} // ListPluginsResponse is the response from the ListPlugins call. type ListPluginsResponse struct { - // Names is the list of names of the plugins. - Names []string `json:"names"` + // PluginsByType is the list of plugins by type. + PluginsByType map[consts.PluginType][]string `json:"types"` + + // NamesDeprecated is the list of names of the plugins. + NamesDeprecated []string `json:"names"` } // ListPlugins lists all plugins in the catalog and returns their names as a // list of strings. func (c *Sys) ListPlugins(i *ListPluginsInput) (*ListPluginsResponse, error) { - path := "/v1/sys/plugins/catalog" - req := c.c.NewRequest("LIST", path) + path := "" + method := "" + if i.Type == consts.PluginTypeUnknown { + path = "/v1/sys/plugins/catalog" + method = "GET" + } else { + path = fmt.Sprintf("/v1/sys/plugins/catalog/%s", i.Type) + method = "LIST" + } + + req := c.c.NewRequest(method, path) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() @@ -29,21 +48,76 @@ func (c *Sys) ListPlugins(i *ListPluginsInput) (*ListPluginsResponse, error) { } defer resp.Body.Close() - var result struct { - Data struct { - Keys []string `json:"keys"` - } `json:"data"` - } - if err := resp.DecodeJSON(&result); err != nil { + secret, err := ParseSecret(resp.Body) + if err != nil { return nil, err } + if secret == nil || secret.Data == nil { + return nil, errors.New("data from server response is empty") + } - return &ListPluginsResponse{Names: result.Data.Keys}, nil + if resp.StatusCode == 405 && req.Method == "GET" { + // We received an Unsupported Operation response from Vault, indicating + // Vault of an older version that doesn't support the READ method yet. + req.Method = "LIST" + resp, err := c.c.RawRequestWithContext(ctx, req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + var result struct { + Data struct { + Keys []string `json:"keys"` + } `json:"data"` + } + if err := resp.DecodeJSON(&result); err != nil { + return nil, err + } + return &ListPluginsResponse{NamesDeprecated: result.Data.Keys}, nil + } + + result := &ListPluginsResponse{ + PluginsByType: make(map[consts.PluginType][]string), + } + if i.Type == consts.PluginTypeUnknown { + for pluginTypeStr, pluginsRaw := range secret.Data { + pluginType, err := consts.ParsePluginType(pluginTypeStr) + if err != nil { + return nil, err + } + + pluginsIfc, ok := pluginsRaw.([]interface{}) + if !ok { + return nil, fmt.Errorf("unable to parse plugins for %q type", pluginTypeStr) + } + + plugins := make([]string, len(pluginsIfc)) + for i, nameIfc := range pluginsIfc { + name, ok := nameIfc.(string) + if !ok { + + } + plugins[i] = name + } + result.PluginsByType[pluginType] = plugins + } + } else { + var respKeys []string + if err := mapstructure.Decode(secret.Data["keys"], &respKeys); err != nil { + return nil, err + } + result.PluginsByType[i.Type] = respKeys + } + + return result, nil } // GetPluginInput is used as input to the GetPlugin function. type GetPluginInput struct { Name string `json:"-"` + + // Type of the plugin. Required. + Type consts.PluginType `json:"type"` } // GetPluginResponse is the response from the GetPlugin call. @@ -56,7 +130,7 @@ type GetPluginResponse struct { } func (c *Sys) GetPlugin(i *GetPluginInput) (*GetPluginResponse, error) { - path := fmt.Sprintf("/v1/sys/plugins/catalog/%s", i.Name) + path := fmt.Sprintf("/v1/sys/plugins/catalog/%s/%s", i.Type, i.Name) req := c.c.NewRequest(http.MethodGet, path) ctx, cancelFunc := context.WithCancel(context.Background()) @@ -82,6 +156,9 @@ type RegisterPluginInput struct { // Name is the name of the plugin. Required. Name string `json:"-"` + // Type of the plugin. Required. + Type consts.PluginType `json:"type"` + // Args is the list of args to spawn the process with. Args []string `json:"args,omitempty"` @@ -94,7 +171,7 @@ type RegisterPluginInput struct { // RegisterPlugin registers the plugin with the given information. func (c *Sys) RegisterPlugin(i *RegisterPluginInput) error { - path := fmt.Sprintf("/v1/sys/plugins/catalog/%s", i.Name) + path := fmt.Sprintf("/v1/sys/plugins/catalog/%s/%s", i.Type, i.Name) req := c.c.NewRequest(http.MethodPut, path) if err := req.SetJSONBody(i); err != nil { return err @@ -113,12 +190,15 @@ func (c *Sys) RegisterPlugin(i *RegisterPluginInput) error { type DeregisterPluginInput struct { // Name is the name of the plugin. Required. Name string `json:"-"` + + // Type of the plugin. Required. + Type consts.PluginType `json:"type"` } // DeregisterPlugin removes the plugin with the given name from the plugin // catalog. func (c *Sys) DeregisterPlugin(i *DeregisterPluginInput) error { - path := fmt.Sprintf("/v1/sys/plugins/catalog/%s", i.Name) + path := fmt.Sprintf("/v1/sys/plugins/catalog/%s/%s", i.Type, i.Name) req := c.c.NewRequest(http.MethodDelete, path) ctx, cancelFunc := context.WithCancel(context.Background()) diff --git a/builtin/credential/app-id/backend_test.go b/builtin/credential/app-id/backend_test.go index e25fa9cbb7aa..68a5e31d94bd 100644 --- a/builtin/credential/app-id/backend_test.go +++ b/builtin/credential/app-id/backend_test.go @@ -26,7 +26,7 @@ func TestBackend_basic(t *testing.T) { return b, nil } logicaltest.Test(t, logicaltest.TestCase{ - Factory: factory, + CredentialFactory: factory, Steps: []logicaltest.TestStep{ testAccStepMapAppId(t), testAccStepMapUserId(t), @@ -65,7 +65,7 @@ func TestBackend_basic(t *testing.T) { func TestBackend_cidr(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ - Factory: Factory, + CredentialFactory: Factory, Steps: []logicaltest.TestStep{ testAccStepMapAppIdDisplayName(t), testAccStepMapUserIdCidr(t, "192.168.1.0/16"), @@ -78,7 +78,7 @@ func TestBackend_cidr(t *testing.T) { func TestBackend_displayName(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ - Factory: Factory, + CredentialFactory: Factory, Steps: []logicaltest.TestStep{ testAccStepMapAppIdDisplayName(t), testAccStepMapUserId(t), diff --git a/builtin/credential/approle/cmd/main.go b/builtin/credential/approle/cmd/approle/main.go similarity index 100% rename from builtin/credential/approle/cmd/main.go rename to builtin/credential/approle/cmd/approle/main.go diff --git a/builtin/credential/aws/backend_test.go b/builtin/credential/aws/backend_test.go index 5f05594c03c3..8cdf0f5abe47 100644 --- a/builtin/credential/aws/backend_test.go +++ b/builtin/credential/aws/backend_test.go @@ -508,8 +508,8 @@ func TestBackend_ConfigClient(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - AcceptanceTest: false, - Backend: b, + AcceptanceTest: false, + CredentialBackend: b, Steps: []logicaltest.TestStep{ stepCreate, stepInvalidAccessKey, diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index 442ad54674e5..d1fc28e12b5f 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -220,7 +220,7 @@ func TestBackend_PermittedDNSDomainsIntermediateCA(t *testing.T) { // Sign the intermediate CSR using /pki secret, err = client.Logical().Write("pki/root/sign-intermediate", map[string]interface{}{ "permitted_dns_domains": ".myvault.com", - "csr": intermediateCSR, + "csr": intermediateCSR, }) if err != nil { t.Fatal(err) @@ -840,7 +840,7 @@ func TestBackend_CertWrites(t *testing.T) { } tc := logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "aaa", ca1, "foo", allowed{}, false), testAccStepCert(t, "bbb", ca2, "foo", allowed{}, false), @@ -863,7 +863,7 @@ func TestBackend_basic_CA(t *testing.T) { t.Fatalf("err: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{}, false), testAccStepLogin(t, connState), @@ -898,7 +898,7 @@ func TestBackend_Basic_CRLs(t *testing.T) { t.Fatalf("err: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCertNoLease(t, "web", ca, "foo"), testAccStepLoginDefaultLease(t, connState), @@ -923,7 +923,7 @@ func TestBackend_basic_singleCert(t *testing.T) { t.Fatalf("err: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{}, false), testAccStepLogin(t, connState), @@ -948,7 +948,7 @@ func TestBackend_common_name_singleCert(t *testing.T) { t.Fatalf("err: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{}, false), testAccStepLogin(t, connState), @@ -977,7 +977,7 @@ func TestBackend_ext_singleCert(t *testing.T) { t.Fatalf("err: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{ext: "2.1.1.1:A UTF8String Extension"}, false), testAccStepLogin(t, connState), @@ -1032,7 +1032,7 @@ func TestBackend_dns_singleCert(t *testing.T) { t.Fatalf("err: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{dns: "example.com"}, false), testAccStepLogin(t, connState), @@ -1063,7 +1063,7 @@ func TestBackend_email_singleCert(t *testing.T) { t.Fatalf("err: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{emails: "valid@example.com"}, false), testAccStepLogin(t, connState), @@ -1094,7 +1094,7 @@ func TestBackend_organizationalUnit_singleCert(t *testing.T) { t.Fatalf("err: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{organizational_units: "engineering"}, false), testAccStepLogin(t, connState), @@ -1123,7 +1123,7 @@ func TestBackend_uri_singleCert(t *testing.T) { t.Fatalf("err: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{uris: "spiffe://example.com/*"}, false), testAccStepLogin(t, connState), @@ -1151,7 +1151,7 @@ func TestBackend_mixed_constraints(t *testing.T) { t.Fatalf("err: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "1unconstrained", ca, "foo", allowed{}, false), testAccStepCert(t, "2matching", ca, "foo", allowed{names: "*.example.com,whatever"}, false), @@ -1172,7 +1172,7 @@ func TestBackend_untrusted(t *testing.T) { t.Fatalf("error testing connection state: %v", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: testFactory(t), + CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepLoginInvalid(t, connState), }, diff --git a/builtin/credential/github/backend_test.go b/builtin/credential/github/backend_test.go index d2511caac5d6..05ed630ce221 100644 --- a/builtin/credential/github/backend_test.go +++ b/builtin/credential/github/backend_test.go @@ -49,8 +49,8 @@ func TestBackend_Config(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - PreCheck: func() { testAccPreCheck(t) }, - Backend: b, + PreCheck: func() { testAccPreCheck(t) }, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testConfigWrite(t, config_data1), testLoginWrite(t, login_data, expectedTTL1.Nanoseconds(), false), @@ -105,8 +105,8 @@ func TestBackend_basic(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - PreCheck: func() { testAccPreCheck(t) }, - Backend: b, + PreCheck: func() { testAccPreCheck(t) }, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, false), testAccMap(t, "default", "fakepol"), diff --git a/builtin/credential/ldap/backend_test.go b/builtin/credential/ldap/backend_test.go index 1df92ec8b5ab..f4702a6c9cd8 100644 --- a/builtin/credential/ldap/backend_test.go +++ b/builtin/credential/ldap/backend_test.go @@ -402,7 +402,7 @@ func TestBackend_basic(t *testing.T) { b := factory(t) logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfigUrl(t), // Map Scientists group (from LDAP server) with foo policy @@ -429,7 +429,7 @@ func TestBackend_basic(t *testing.T) { func TestBackend_basic_noPolicies(t *testing.T) { b := factory(t) logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfigUrl(t), // Create LDAP user @@ -444,7 +444,7 @@ func TestBackend_basic_noPolicies(t *testing.T) { func TestBackend_basic_group_noPolicies(t *testing.T) { b := factory(t) logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfigUrl(t), // Create engineers group with no policies @@ -463,7 +463,7 @@ func TestBackend_basic_authbind(t *testing.T) { b := factory(t) logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfigUrlWithAuthBind(t), testAccStepGroup(t, "Scientists", "foo"), @@ -478,7 +478,7 @@ func TestBackend_basic_discover(t *testing.T) { b := factory(t) logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfigUrlWithDiscover(t), testAccStepGroup(t, "Scientists", "foo"), @@ -493,7 +493,7 @@ func TestBackend_basic_nogroupdn(t *testing.T) { b := factory(t) logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfigUrlNoGroupDN(t), testAccStepGroup(t, "Scientists", "foo"), @@ -508,7 +508,7 @@ func TestBackend_groupCrud(t *testing.T) { b := factory(t) logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepGroup(t, "g1", "foo"), testAccStepReadGroup(t, "g1", "foo"), @@ -525,7 +525,7 @@ func TestBackend_configDefaultsAfterUpdate(t *testing.T) { b := factory(t) logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ logicaltest.TestStep{ Operation: logical.UpdateOperation, @@ -687,7 +687,7 @@ func TestBackend_userCrud(t *testing.T) { b := Backend() logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepUser(t, "g1", "bar"), testAccStepReadUser(t, "g1", "bar"), diff --git a/builtin/credential/okta/backend_test.go b/builtin/credential/okta/backend_test.go index f30310eed84a..e19c56852a32 100644 --- a/builtin/credential/okta/backend_test.go +++ b/builtin/credential/okta/backend_test.go @@ -49,7 +49,7 @@ func TestBackend_Config(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, PreCheck: func() { testAccPreCheck(t) }, - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testConfigCreate(t, configData), testLoginWrite(t, username, "wrong", "E0000004", 0, nil), diff --git a/builtin/credential/radius/backend_test.go b/builtin/credential/radius/backend_test.go index a8eed280b736..4deddfbf61aa 100644 --- a/builtin/credential/radius/backend_test.go +++ b/builtin/credential/radius/backend_test.go @@ -112,7 +112,7 @@ func TestBackend_Config(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: false, // PreCheck: func() { testAccPreCheck(t) }, - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testConfigWrite(t, configDataBasic, false), testConfigWrite(t, configDataMissingRequired, true), @@ -135,7 +135,7 @@ func TestBackend_users(t *testing.T) { t.Fatalf("Unable to create backend: %s", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testStepUpdateUser(t, "web", "foo"), testStepUpdateUser(t, "web2", "foo"), @@ -210,9 +210,9 @@ func TestBackend_acceptance(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - PreCheck: testAccPreCheck(t, host, port), - AcceptanceTest: true, + CredentialBackend: b, + PreCheck: testAccPreCheck(t, host, port), + AcceptanceTest: true, Steps: []logicaltest.TestStep{ // Login with valid but unknown user will fail because unregistered_user_policies is emtpy testConfigWrite(t, configDataAcceptanceNoAllowUnreg, false), diff --git a/builtin/credential/userpass/backend_test.go b/builtin/credential/userpass/backend_test.go index 355f94e5a506..fcca9273b4f6 100644 --- a/builtin/credential/userpass/backend_test.go +++ b/builtin/credential/userpass/backend_test.go @@ -101,7 +101,7 @@ func TestBackend_basic(t *testing.T) { t.Fatalf("Unable to create backend: %s", err) } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepUser(t, "web", "password", "foo"), testAccStepUser(t, "web2", "password", "foo"), @@ -125,7 +125,7 @@ func TestBackend_userCrud(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepUser(t, "web", "password", "foo"), testAccStepReadUser(t, "web", "foo"), @@ -148,7 +148,7 @@ func TestBackend_userCreateOperation(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testUserCreateOperation(t, "web", "password", "foo"), testAccStepLogin(t, "web", "password", []string{"default", "foo"}), @@ -169,7 +169,7 @@ func TestBackend_passwordUpdate(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepUser(t, "web", "password", "foo"), testAccStepReadUser(t, "web", "foo"), @@ -194,7 +194,7 @@ func TestBackend_policiesUpdate(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepUser(t, "web", "password", "foo"), testAccStepReadUser(t, "web", "foo"), diff --git a/builtin/logical/aws/backend_test.go b/builtin/logical/aws/backend_test.go index 5eb76104d945..1002cc4d43d7 100644 --- a/builtin/logical/aws/backend_test.go +++ b/builtin/logical/aws/backend_test.go @@ -47,7 +47,7 @@ func TestBackend_basic(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, PreCheck: func() { testAccPreCheck(t) }, - Backend: getBackend(t), + LogicalBackend: getBackend(t), Steps: []logicaltest.TestStep{ testAccStepConfig(t), testAccStepWritePolicy(t, "test", testDynamoPolicy), @@ -77,7 +77,7 @@ func TestBackend_basicSTS(t *testing.T) { log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...") time.Sleep(10 * time.Second) }, - Backend: getBackend(t), + LogicalBackend: getBackend(t), Steps: []logicaltest.TestStep{ testAccStepConfigWithCreds(t, accessKey), testAccStepRotateRoot(accessKey), @@ -103,7 +103,7 @@ func TestBackend_policyCrud(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Backend: getBackend(t), + LogicalBackend: getBackend(t), Steps: []logicaltest.TestStep{ testAccStepConfig(t), testAccStepWritePolicy(t, "test", testDynamoPolicy), @@ -724,7 +724,7 @@ func TestBackend_basicPolicyArnRef(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, PreCheck: func() { testAccPreCheck(t) }, - Backend: getBackend(t), + LogicalBackend: getBackend(t), Steps: []logicaltest.TestStep{ testAccStepConfig(t), testAccStepWriteArnPolicyRef(t, "test", ec2PolicyArn), @@ -755,7 +755,7 @@ func TestBackend_iamUserManagedInlinePolicies(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, PreCheck: func() { testAccPreCheck(t) }, - Backend: getBackend(t), + LogicalBackend: getBackend(t), Steps: []logicaltest.TestStep{ testAccStepConfig(t), testAccStepWriteRole(t, "test", roleData), @@ -804,7 +804,7 @@ func TestBackend_AssumedRoleWithPolicyDoc(t *testing.T) { log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...") time.Sleep(10 * time.Second) }, - Backend: getBackend(t), + LogicalBackend: getBackend(t), Steps: []logicaltest.TestStep{ testAccStepConfig(t), testAccStepWriteRole(t, "test", roleData), @@ -840,7 +840,7 @@ func TestBackend_RoleDefaultSTSTTL(t *testing.T) { log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...") time.Sleep(10 * time.Second) }, - Backend: getBackend(t), + LogicalBackend: getBackend(t), Steps: []logicaltest.TestStep{ testAccStepConfig(t), testAccStepWriteRole(t, "test", roleData), @@ -856,7 +856,7 @@ func TestBackend_policyArnCrud(t *testing.T) { t.Parallel() logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Backend: getBackend(t), + LogicalBackend: getBackend(t), Steps: []logicaltest.TestStep{ testAccStepConfig(t), testAccStepWriteArnPolicyRef(t, "test", ec2PolicyArn), diff --git a/builtin/logical/cassandra/backend_test.go b/builtin/logical/cassandra/backend_test.go index 47060a73155b..d476f4aa0d71 100644 --- a/builtin/logical/cassandra/backend_test.go +++ b/builtin/logical/cassandra/backend_test.go @@ -92,7 +92,7 @@ func TestBackend_basic(t *testing.T) { defer cleanup() logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, hostname), testAccStepRole(t), @@ -116,7 +116,7 @@ func TestBackend_roleCrud(t *testing.T) { defer cleanup() logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, hostname), testAccStepRole(t), diff --git a/builtin/logical/consul/backend_test.go b/builtin/logical/consul/backend_test.go index dba4be62d2c1..a5a011abde7a 100644 --- a/builtin/logical/consul/backend_test.go +++ b/builtin/logical/consul/backend_test.go @@ -564,7 +564,7 @@ func testBackendManagement(t *testing.T, version string) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData), testAccStepWriteManagementPolicy(t, "test", ""), @@ -608,7 +608,7 @@ func testBackendBasic(t *testing.T, version string) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData), testAccStepWritePolicy(t, "test", testPolicy, ""), @@ -620,7 +620,7 @@ func testBackendBasic(t *testing.T, version string) { func TestBackend_crud(t *testing.T) { b, _ := Factory(context.Background(), logical.TestBackendConfig()) logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test", testPolicy, ""), testAccStepWritePolicy(t, "test2", testPolicy, ""), @@ -635,7 +635,7 @@ func TestBackend_crud(t *testing.T) { func TestBackend_role_lease(t *testing.T) { b, _ := Factory(context.Background(), logical.TestBackendConfig()) logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test", testPolicy, "6h"), testAccStepReadPolicy(t, "test", testPolicy, 6*time.Hour), diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index de25a9693fc7..1af3ad961a77 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -14,6 +14,7 @@ import ( "github.com/go-test/deep" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/pluginutil" vaulthttp "github.com/hashicorp/vault/http" @@ -101,7 +102,7 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) sys := vault.TestDynamicSystemView(cores[0].Core) - vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", "TestBackend_PluginMain", []string{}, "") + vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain", []string{}, "") return cluster, sys } diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 82b43551ca35..ea71a6f7d15a 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -28,10 +28,11 @@ func (dc *DatabasePluginClient) Close() error { return err } -// newPluginClient returns a databaseRPCClient with a connection to a running +// NewPluginClient returns a databaseRPCClient with a connection to a running // plugin. The client is wrapped in a DatabasePluginClient object to ensure the // plugin is killed on call of Close(). -func newPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger) (Database, error) { +func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (Database, error) { + // pluginSets is the map of plugins we can dispense. pluginSets := map[int]plugin.PluginSet{ // Version 3 supports both protocols @@ -46,7 +47,13 @@ func newPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne }, } - client, err := pluginRunner.Run(ctx, sys, pluginSets, handshakeConfig, []string{}, logger) + var client *plugin.Client + var err error + if isMetadataMode { + client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginSets, handshakeConfig, []string{}, logger) + } else { + client, err = pluginRunner.Run(ctx, sys, pluginSets, handshakeConfig, []string{}, logger) + } if err != nil { return nil, err } diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 1a39e5e03a3c..918b98b388c9 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/pluginutil" ) @@ -34,7 +35,7 @@ type Database interface { // object in a logging and metrics middleware. func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) { // Look for plugin in the plugin catalog - pluginRunner, err := sys.LookupPlugin(ctx, pluginName) + pluginRunner, err := sys.LookupPlugin(ctx, pluginName, consts.PluginTypeDatabase) if err != nil { return nil, err } @@ -61,7 +62,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu } else { // create a DatabasePluginClient instance - db, err = newPluginClient(ctx, sys, pluginRunner, namedLogger) + db, err = NewPluginClient(ctx, sys, pluginRunner, namedLogger, false) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 8b5ebee5bc19..c61b27321ae5 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -10,6 +10,7 @@ import ( log "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/pluginutil" vaulthttp "github.com/hashicorp/vault/http" @@ -94,8 +95,8 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { cores := cluster.Cores sys := vault.TestDynamicSystemView(cores[0].Core) - vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main", []string{}, "") - vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main", []string{}, "") + vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", consts.PluginTypeDatabase, "TestPlugin_GRPC_Main", []string{}, "") + vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", consts.PluginTypeDatabase, "TestPlugin_NetRPC_Main", []string{}, "") return cluster, sys } diff --git a/builtin/logical/mongodb/backend_test.go b/builtin/logical/mongodb/backend_test.go index f91104903437..8c4a9bb888a0 100644 --- a/builtin/logical/mongodb/backend_test.go +++ b/builtin/logical/mongodb/backend_test.go @@ -121,7 +121,7 @@ func TestBackend_basic(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(connData, false), testAccStepRole(), @@ -147,7 +147,7 @@ func TestBackend_roleCrud(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(connData, false), testAccStepRole(), @@ -175,7 +175,7 @@ func TestBackend_leaseWriteRead(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(connData, false), testAccStepWriteLease(), diff --git a/builtin/logical/mssql/backend_test.go b/builtin/logical/mssql/backend_test.go index 9f91517c36ad..37ccf2600de9 100644 --- a/builtin/logical/mssql/backend_test.go +++ b/builtin/logical/mssql/backend_test.go @@ -116,7 +116,7 @@ func TestBackend_basic(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, PreCheck: testAccPreCheckFunc(t, connURL), - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connURL), testAccStepRole(t), @@ -138,7 +138,7 @@ func TestBackend_roleCrud(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, PreCheck: testAccPreCheckFunc(t, connURL), - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connURL), testAccStepRole(t), @@ -162,7 +162,7 @@ func TestBackend_leaseWriteRead(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, PreCheck: testAccPreCheckFunc(t, connURL), - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connURL), testAccStepWriteLease(t), diff --git a/builtin/logical/mysql/backend_test.go b/builtin/logical/mysql/backend_test.go index 4cc4bd75220f..8d684772b885 100644 --- a/builtin/logical/mysql/backend_test.go +++ b/builtin/logical/mysql/backend_test.go @@ -115,7 +115,7 @@ func TestBackend_basic(t *testing.T) { // for wildcard based mysql user logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData, false), testAccStepRole(t, true), @@ -141,7 +141,7 @@ func TestBackend_basicHostRevoke(t *testing.T) { // for host based mysql user logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData, false), testAccStepRole(t, false), @@ -166,7 +166,7 @@ func TestBackend_roleCrud(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData, false), // test SQL with wildcard based user @@ -197,7 +197,7 @@ func TestBackend_leaseWriteRead(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData, false), testAccStepWriteLease(t), diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 40081d623a45..2ea878a3779c 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -151,8 +151,8 @@ func TestBackend_CSRValues(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{}, + LogicalBackend: b, + Steps: []logicaltest.TestStep{}, } intdata := map[string]interface{}{} @@ -178,8 +178,8 @@ func TestBackend_URLsCRUD(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{}, + LogicalBackend: b, + Steps: []logicaltest.TestStep{}, } intdata := map[string]interface{}{} @@ -208,7 +208,7 @@ func TestBackend_RSARoles(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ logicaltest.TestStep{ Operation: logical.UpdateOperation, @@ -249,7 +249,7 @@ func TestBackend_RSARoles_CSR(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ logicaltest.TestStep{ Operation: logical.UpdateOperation, @@ -290,7 +290,7 @@ func TestBackend_ECRoles(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ logicaltest.TestStep{ Operation: logical.UpdateOperation, @@ -331,7 +331,7 @@ func TestBackend_ECRoles_CSR(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ logicaltest.TestStep{ Operation: logical.UpdateOperation, @@ -1986,8 +1986,8 @@ func TestBackend_SignSelfIssued(t *testing.T) { Subject: pkix.Name{ CommonName: "foo.bar.com", }, - SerialNumber: big.NewInt(1234), - IsCA: false, + SerialNumber: big.NewInt(1234), + IsCA: false, BasicConstraintsValid: true, } @@ -2017,8 +2017,8 @@ func TestBackend_SignSelfIssued(t *testing.T) { Subject: pkix.Name{ CommonName: "bar.foo.com", }, - SerialNumber: big.NewInt(2345), - IsCA: true, + SerialNumber: big.NewInt(2345), + IsCA: true, BasicConstraintsValid: true, } ss, ssCert := getSelfSigned(template, issuer) @@ -2600,7 +2600,7 @@ func setCerts() { SerialNumber: big.NewInt(mathrand.Int63()), NotAfter: time.Now().Add(262980 * time.Hour), BasicConstraintsValid: true, - IsCA: true, + IsCA: true, } caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, cak.Public(), cak) if err != nil { diff --git a/builtin/logical/postgresql/backend_test.go b/builtin/logical/postgresql/backend_test.go index af20ee6797e9..192cbe3820dc 100644 --- a/builtin/logical/postgresql/backend_test.go +++ b/builtin/logical/postgresql/backend_test.go @@ -116,7 +116,7 @@ func TestBackend_basic(t *testing.T) { "connection_url": connURL, } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData, false), testAccStepCreateRole(t, "web", testRole, false), @@ -140,7 +140,7 @@ func TestBackend_roleCrud(t *testing.T) { "connection_url": connURL, } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData, false), testAccStepCreateRole(t, "web", testRole, false), @@ -171,7 +171,7 @@ func TestBackend_BlockStatements(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData, false), // This will also validate the query @@ -196,7 +196,7 @@ func TestBackend_roleReadOnly(t *testing.T) { "connection_url": connURL, } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData, false), testAccStepCreateRole(t, "web", testRole, false), @@ -227,7 +227,7 @@ func TestBackend_roleReadOnly_revocationSQL(t *testing.T) { "connection_url": connURL, } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, connData, false), testAccStepCreateRoleWithRevocationSQL(t, "web", testRole, defaultRevocationSQL, false), diff --git a/builtin/logical/rabbitmq/backend_test.go b/builtin/logical/rabbitmq/backend_test.go index 1fcd04ffbc64..e37424987842 100644 --- a/builtin/logical/rabbitmq/backend_test.go +++ b/builtin/logical/rabbitmq/backend_test.go @@ -82,8 +82,8 @@ func TestBackend_basic(t *testing.T) { defer cleanup() logicaltest.Test(t, logicaltest.TestCase{ - PreCheck: testAccPreCheckFunc(t, uri), - Backend: b, + PreCheck: testAccPreCheckFunc(t, uri), + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, uri), testAccStepRole(t), @@ -104,8 +104,8 @@ func TestBackend_roleCrud(t *testing.T) { defer cleanup() logicaltest.Test(t, logicaltest.TestCase{ - PreCheck: testAccPreCheckFunc(t, uri), - Backend: b, + PreCheck: testAccPreCheckFunc(t, uri), + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepConfig(t, uri), testAccStepRole(t), diff --git a/builtin/logical/ssh/backend_test.go b/builtin/logical/ssh/backend_test.go index 140f0306f2f2..85888839bf70 100644 --- a/builtin/logical/ssh/backend_test.go +++ b/builtin/logical/ssh/backend_test.go @@ -253,7 +253,7 @@ func TestSSHBackend_Lookup(t *testing.T) { resp4 := []string{testDynamicRoleName} logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testLookupRead(t, data, resp1), testRoleWrite(t, testOTPRoleName, testOTPRoleData), @@ -285,7 +285,7 @@ func TestSSHBackend_RoleList(t *testing.T) { }, } logicaltest.Test(t, logicaltest.TestCase{ - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testRoleList(t, resp1), testRoleWrite(t, testOTPRoleName, testOTPRoleData), @@ -309,7 +309,7 @@ func TestSSHBackend_DynamicKeyCreate(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ PreCheck: testAccUserPrecheckFunc(t), AcceptanceTest: true, - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testNamedKeysWrite(t, testKeyName, testSharedPrivateKey), testRoleWrite(t, testDynamicRoleName, testDynamicRoleData), @@ -332,7 +332,7 @@ func TestSSHBackend_OTPRoleCrud(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testRoleWrite(t, testOTPRoleName, testOTPRoleData), testRoleRead(t, testOTPRoleName, respOTPRoleData), @@ -362,7 +362,7 @@ func TestSSHBackend_DynamicRoleCrud(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testNamedKeysWrite(t, testKeyName, testSharedPrivateKey), testRoleWrite(t, testDynamicRoleName, testDynamicRoleData), @@ -376,7 +376,7 @@ func TestSSHBackend_DynamicRoleCrud(t *testing.T) { func TestSSHBackend_NamedKeysCrud(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testNamedKeysWrite(t, testKeyName, testSharedPrivateKey), testNamedKeysDelete(t), @@ -396,7 +396,7 @@ func TestSSHBackend_OTPCreate(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testRoleWrite(t, testOTPRoleName, testOTPRoleData), testCredsWrite(t, testOTPRoleName, data, false), @@ -413,7 +413,7 @@ func TestSSHBackend_VerifyEcho(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testVerifyWrite(t, verifyData, expectedData), }, @@ -451,7 +451,7 @@ func TestSSHBackend_ConfigZeroAddressCRUD(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testRoleWrite(t, testOTPRoleName, testOTPRoleData), testConfigZeroAddressWrite(t, req1), @@ -483,7 +483,7 @@ func TestSSHBackend_CredsForZeroAddressRoles_otp(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testRoleWrite(t, testOTPRoleName, otpRoleData), testCredsWrite(t, testOTPRoleName, data, true), @@ -512,7 +512,7 @@ func TestSSHBackend_CredsForZeroAddressRoles_dynamic(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ PreCheck: testAccUserPrecheckFunc(t), AcceptanceTest: true, - Factory: testingFactory, + LogicalFactory: testingFactory, Steps: []logicaltest.TestStep{ testNamedKeysWrite(t, testKeyName, testSharedPrivateKey), testRoleWrite(t, testDynamicRoleName, dynamicRoleData), @@ -535,7 +535,7 @@ func TestBackend_AbleToRetrievePublicKey(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ configCaStep(), @@ -571,7 +571,7 @@ func TestBackend_AbleToAutoGenerateSigningKeys(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ logicaltest.TestStep{ Operation: logical.UpdateOperation, @@ -609,7 +609,7 @@ func TestBackend_ValidPrincipalsValidatedForHostCertificates(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ configCaStep(), @@ -652,7 +652,7 @@ func TestBackend_OptionsOverrideDefaults(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ configCaStep(), @@ -700,7 +700,7 @@ func TestBackend_CustomKeyIDFormat(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ configCaStep(), @@ -749,7 +749,7 @@ func TestBackend_DisallowUserProvidedKeyIDs(t *testing.T) { } testCase := logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ configCaStep(), diff --git a/builtin/logical/totp/backend_test.go b/builtin/logical/totp/backend_test.go index a57d29e43fa3..8fb4cefa53b7 100644 --- a/builtin/logical/totp/backend_test.go +++ b/builtin/logical/totp/backend_test.go @@ -138,7 +138,7 @@ func TestBackend_readCredentialsDefaultValues(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -176,7 +176,7 @@ func TestBackend_readCredentialsEightDigitsThirtySecondPeriod(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -214,7 +214,7 @@ func TestBackend_readCredentialsSixDigitsNinetySecondPeriod(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -252,7 +252,7 @@ func TestBackend_readCredentialsSHA256(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -290,7 +290,7 @@ func TestBackend_readCredentialsSHA512(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -329,7 +329,7 @@ func TestBackend_keyCrudDefaultValues(t *testing.T) { invalidCode := "12345678" logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -358,7 +358,7 @@ func TestBackend_createKeyMissingKeyValue(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -382,7 +382,7 @@ func TestBackend_createKeyInvalidKeyValue(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -410,7 +410,7 @@ func TestBackend_createKeyInvalidAlgorithm(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -438,7 +438,7 @@ func TestBackend_createKeyInvalidPeriod(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -466,7 +466,7 @@ func TestBackend_createKeyInvalidDigits(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -500,7 +500,7 @@ func TestBackend_generatedKeyDefaultValues(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -526,7 +526,7 @@ func TestBackend_generatedKeyDefaultValuesNoQR(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), }, @@ -559,7 +559,7 @@ func TestBackend_generatedKeyNonDefaultKeySize(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -583,7 +583,7 @@ func TestBackend_urlPassedNonGeneratedKeyInvalidPeriod(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -607,7 +607,7 @@ func TestBackend_urlPassedNonGeneratedKeyInvalidDigits(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -640,7 +640,7 @@ func TestBackend_urlPassedNonGeneratedKeyIssuerInFirstPosition(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -674,7 +674,7 @@ func TestBackend_urlPassedNonGeneratedKeyIssuerInQueryString(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -708,7 +708,7 @@ func TestBackend_urlPassedNonGeneratedKeyMissingIssuer(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -742,7 +742,7 @@ func TestBackend_urlPassedNonGeneratedKeyMissingAccountName(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -776,7 +776,7 @@ func TestBackend_urlPassedNonGeneratedKeyMissingAccountNameandIssuer(t *testing. } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), @@ -801,7 +801,7 @@ func TestBackend_generatedKeyInvalidSkew(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -825,7 +825,7 @@ func TestBackend_generatedKeyInvalidQRSize(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -849,7 +849,7 @@ func TestBackend_generatedKeyInvalidKeySize(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -871,7 +871,7 @@ func TestBackend_generatedKeyMissingAccountName(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -893,7 +893,7 @@ func TestBackend_generatedKeyMissingIssuer(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -915,7 +915,7 @@ func TestBackend_invalidURLValue(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -937,7 +937,7 @@ func TestBackend_urlAndGenerateTrue(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -959,7 +959,7 @@ func TestBackend_keyAndGenerateTrue(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, true), testAccStepReadKey(t, "test", nil), @@ -991,7 +991,7 @@ func TestBackend_generatedKeyExportedFalse(t *testing.T) { } logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepCreateKey(t, "test", keyData, false), testAccStepReadKey(t, "test", expected), diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 95debdc5585c..6c4edd7d2a5e 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -251,7 +251,7 @@ func testTransit_RSA(t *testing.T, keyType string) { func TestBackend_basic(t *testing.T) { decryptData := make(map[string]interface{}) logicaltest.Test(t, logicaltest.TestCase{ - Factory: Factory, + LogicalFactory: Factory, Steps: []logicaltest.TestStep{ testAccStepListPolicy(t, "test", true), testAccStepWritePolicy(t, "test", false), @@ -278,7 +278,7 @@ func TestBackend_basic(t *testing.T) { func TestBackend_upsert(t *testing.T) { decryptData := make(map[string]interface{}) logicaltest.Test(t, logicaltest.TestCase{ - Factory: Factory, + LogicalFactory: Factory, Steps: []logicaltest.TestStep{ testAccStepReadPolicy(t, "test", true, false), testAccStepListPolicy(t, "test", true), @@ -293,7 +293,7 @@ func TestBackend_upsert(t *testing.T) { func TestBackend_datakey(t *testing.T) { dataKeyInfo := make(map[string]interface{}) logicaltest.Test(t, logicaltest.TestCase{ - Factory: Factory, + LogicalFactory: Factory, Steps: []logicaltest.TestStep{ testAccStepListPolicy(t, "test", true), testAccStepWritePolicy(t, "test", false), @@ -317,7 +317,7 @@ func testBackendRotation(t *testing.T) { decryptData := make(map[string]interface{}) encryptHistory := make(map[int]map[string]interface{}) logicaltest.Test(t, logicaltest.TestCase{ - Factory: Factory, + LogicalFactory: Factory, Steps: []logicaltest.TestStep{ testAccStepListPolicy(t, "test", true), testAccStepWritePolicy(t, "test", false), @@ -380,7 +380,7 @@ func testBackendRotation(t *testing.T) { func TestBackend_basic_derived(t *testing.T) { decryptData := make(map[string]interface{}) logicaltest.Test(t, logicaltest.TestCase{ - Factory: Factory, + LogicalFactory: Factory, Steps: []logicaltest.TestStep{ testAccStepListPolicy(t, "test", true), testAccStepWritePolicy(t, "test", true), diff --git a/builtin/plugin/backend.go b/builtin/plugin/backend.go index c448bb89fa2d..34f0512e6f7e 100644 --- a/builtin/plugin/backend.go +++ b/builtin/plugin/backend.go @@ -8,6 +8,7 @@ import ( "sync" uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" bplugin "github.com/hashicorp/vault/logical/plugin" @@ -38,13 +39,18 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, // Backend returns an instance of the backend, either as a plugin if external // or as a concrete implementation if builtin, casted as logical.Backend. func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { - var b backend + var b PluginBackend name := conf.Config["plugin_name"] + pluginType, err := consts.ParsePluginType(conf.Config["plugin_type"]) + if err != nil { + return nil, err + } + sys := conf.System // NewBackend with isMetadataMode set to true - raw, err := bplugin.NewBackend(ctx, name, sys, conf.Logger, true) + raw, err := bplugin.NewBackend(ctx, name, pluginType, sys, conf, true) if err != nil { return nil, err } @@ -71,8 +77,8 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, return &b, nil } -// backend is a thin wrapper around plugin.BackendPluginClient -type backend struct { +// PluginBackend is a thin wrapper around plugin.BackendPluginClient +type PluginBackend struct { logical.Backend sync.RWMutex @@ -85,19 +91,23 @@ type backend struct { loaded bool } -func (b *backend) reloadBackend(ctx context.Context) error { +func (b *PluginBackend) reloadBackend(ctx context.Context) error { b.Logger().Debug("reloading plugin backend", "plugin", b.config.Config["plugin_name"]) return b.startBackend(ctx) } // startBackend starts a plugin backend -func (b *backend) startBackend(ctx context.Context) error { +func (b *PluginBackend) startBackend(ctx context.Context) error { pluginName := b.config.Config["plugin_name"] + pluginType, err := consts.ParsePluginType(b.config.Config["plugin_type"]) + if err != nil { + return err + } // Ensure proper cleanup of the backend (i.e. call client.Kill()) b.Backend.Cleanup(ctx) - nb, err := bplugin.NewBackend(ctx, pluginName, b.config.System, b.config.Logger, false) + nb, err := bplugin.NewBackend(ctx, pluginName, pluginType, b.config.System, b.config, false) if err != nil { return err } @@ -128,7 +138,7 @@ func (b *backend) startBackend(ctx context.Context) error { } // HandleRequest is a thin wrapper implementation of HandleRequest that includes automatic plugin reload. -func (b *backend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { +func (b *PluginBackend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { b.RLock() canary := b.canary @@ -179,7 +189,7 @@ func (b *backend) HandleRequest(ctx context.Context, req *logical.Request) (*log } // HandleExistenceCheck is a thin wrapper implementation of HandleRequest that includes automatic plugin reload. -func (b *backend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { +func (b *PluginBackend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { b.RLock() canary := b.canary diff --git a/builtin/plugin/backend_test.go b/builtin/plugin/backend_test.go index f90d9450b65e..0b116db9c130 100644 --- a/builtin/plugin/backend_test.go +++ b/builtin/plugin/backend_test.go @@ -1,4 +1,4 @@ -package plugin +package plugin_test import ( "context" @@ -7,24 +7,26 @@ import ( "testing" log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/builtin/plugin" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/logging" "github.com/hashicorp/vault/helper/pluginutil" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/logical/plugin" + logicalPlugin "github.com/hashicorp/vault/logical/plugin" "github.com/hashicorp/vault/logical/plugin/mock" "github.com/hashicorp/vault/vault" ) func TestBackend_impl(t *testing.T) { - var _ logical.Backend = &backend{} + var _ logical.Backend = &plugin.PluginBackend{} } func TestBackend(t *testing.T) { config, cleanup := testConfig(t) defer cleanup() - _, err := Backend(context.Background(), config) + _, err := plugin.Backend(context.Background(), config) if err != nil { t.Fatal(err) } @@ -34,7 +36,7 @@ func TestBackend_Factory(t *testing.T) { config, cleanup := testConfig(t) defer cleanup() - _, err := Factory(context.Background(), config) + _, err := plugin.Factory(context.Background(), config) if err != nil { t.Fatal(err) } @@ -59,7 +61,7 @@ func TestBackend_PluginMain(t *testing.T) { tlsConfig := apiClientMeta.GetTLSConfig() tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig) - err := plugin.Serve(&plugin.ServeOpts{ + err := logicalPlugin.Serve(&logicalPlugin.ServeOpts{ BackendFactoryFunc: mock.Factory, TLSProviderFunc: tlsProviderFunc, }) @@ -84,12 +86,13 @@ func testConfig(t *testing.T) (*logical.BackendConfig, func()) { System: sys, Config: map[string]string{ "plugin_name": "mock-plugin", + "plugin_type": "database", }, } os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMain", []string{}, "") + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain", []string{}, "") return config, func() { cluster.Cleanup() diff --git a/command/auth_enable.go b/command/auth_enable.go index 70bb23d646b0..e8eae33ccb13 100644 --- a/command/auth_enable.go +++ b/command/auth_enable.go @@ -8,6 +8,7 @@ import ( "time" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -58,6 +59,10 @@ Usage: vault auth enable [options] TYPE $ vault auth enable -path=my-auth -plugin-name=my-auth-plugin plugin + OR (preferred way): + + $ vault auth enable -path=my-auth my-auth-plugin + ` + c.Flags().Help() return strings.TrimSpace(helpText) @@ -135,7 +140,7 @@ func (c *AuthEnableCommand) Flags() *FlagSets { f.StringVar(&StringVar{ Name: "plugin-name", Target: &c.flagPluginName, - Completion: c.PredictVaultPlugins(), + Completion: c.PredictVaultPlugins(consts.PluginTypeCredential), Usage: "Name of the auth method plugin. This plugin name must already " + "exist in the Vault server's plugin catalog.", }) @@ -212,6 +217,9 @@ func (c *AuthEnableCommand) Run(args []string) int { } authType := strings.TrimSpace(args[0]) + if authType == "plugin" { + authType = c.flagPluginName + } // If no path is specified, we default the path to the backend type // or use the plugin name if it's a plugin backend @@ -242,7 +250,6 @@ func (c *AuthEnableCommand) Run(args []string) int { Config: api.AuthConfigInput{ DefaultLeaseTTL: c.flagDefaultLeaseTTL.String(), MaxLeaseTTL: c.flagMaxLeaseTTL.String(), - PluginName: c.flagPluginName, }, Options: c.flagOptions, } @@ -279,7 +286,6 @@ func (c *AuthEnableCommand) Run(args []string) int { if authType == "plugin" { authThing = c.flagPluginName + " plugin" } - c.UI.Output(fmt.Sprintf("Success! Enabled %s at: %s", authThing, authPath)) return 0 } diff --git a/command/auth_enable_test.go b/command/auth_enable_test.go index d3d693d8126b..4c94cee3da53 100644 --- a/command/auth_enable_test.go +++ b/command/auth_enable_test.go @@ -5,6 +5,8 @@ import ( "strings" "testing" + "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" ) @@ -171,8 +173,11 @@ func TestAuthEnableCommand_Run(t *testing.T) { } } - if len(backends) != len(credentialBackends) { - t.Fatalf("expected %d credential backends, got %d", len(credentialBackends), len(backends)) + // Add 1 to account for the "token" backend, which is visible when you walk the filesystem but + // is treated as special and excluded from the registry. + expected := len(builtinplugins.Registry.Keys(consts.PluginTypeCredential)) + 1 + if len(backends) != expected { + t.Fatalf("expected %d credential backends, got %d", expected, len(backends)) } for _, b := range backends { diff --git a/command/auth_list.go b/command/auth_list.go index 61794fb40d84..ba1dcf8b6d3c 100644 --- a/command/auth_list.go +++ b/command/auth_list.go @@ -155,11 +155,10 @@ func (c *AuthListCommand) detailedMounts(auths map[string]*api.AuthMount) []stri replication = "local" } - out = append(out, fmt.Sprintf("%s | %s | %s | %s | %s | %s | %s | %s | %t | %v | %s", + out = append(out, fmt.Sprintf("%s | %s | %s | %s | %s | %s | %s | %t | %v | %s", path, mount.Type, mount.Accessor, - mount.Config.PluginName, defaultTTL, maxTTL, mount.Config.TokenType, diff --git a/command/base_predict.go b/command/base_predict.go index a280b86e0b2f..a434b2a22f73 100644 --- a/command/base_predict.go +++ b/command/base_predict.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" "github.com/posener/complete" ) @@ -146,8 +147,8 @@ func (b *BaseCommand) PredictVaultAuths() complete.Predictor { } // PredictVaultPlugins returns a predictor for installed plugins. -func (b *BaseCommand) PredictVaultPlugins() complete.Predictor { - return NewPredict().VaultPlugins() +func (b *BaseCommand) PredictVaultPlugins(pluginTypes ...consts.PluginType) complete.Predictor { + return NewPredict().VaultPlugins(pluginTypes...) } // PredictVaultPolicies returns a predictor for "folders". See PredictVaultFiles @@ -191,8 +192,11 @@ func (p *Predict) VaultAuths() complete.Predictor { // VaultPlugins returns a predictor for Vault's plugin catalog. This is a public // API for consumers, but you probably want BaseCommand.PredictVaultPlugins // instead. -func (p *Predict) VaultPlugins() complete.Predictor { - return p.filterFunc(p.plugins) +func (p *Predict) VaultPlugins(pluginTypes ...consts.PluginType) complete.Predictor { + filterFunc := func() []string { + return p.plugins(pluginTypes...) + } + return p.filterFunc(filterFunc) } // VaultPolicies returns a predictor for Vault "folders". This is a public API for @@ -329,17 +333,35 @@ func (p *Predict) auths() []string { } // plugins returns a sorted list of the plugins in the catalog. -func (p *Predict) plugins() []string { +func (p *Predict) plugins(pluginTypes ...consts.PluginType) []string { + // This method's signature doesn't enforce that a pluginType must be passed in. + // If it's not, it's likely the caller's intent is go get a list of all of them, + // so let's help them out. + if len(pluginTypes) == 0 { + pluginTypes = append(pluginTypes, consts.PluginTypeUnknown) + } + client := p.Client() if client == nil { return nil } - result, err := client.Sys().ListPlugins(nil) - if err != nil { - return nil + var plugins []string + pluginsAdded := make(map[string]bool) + for _, pluginType := range pluginTypes { + result, err := client.Sys().ListPlugins(&api.ListPluginsInput{Type: pluginType}) + if err != nil { + return nil + } + for _, names := range result.PluginsByType { + for _, name := range names { + if _, ok := pluginsAdded[name]; !ok { + plugins = append(plugins, name) + pluginsAdded[name] = true + } + } + } } - plugins := result.Names sort.Strings(plugins) return plugins } diff --git a/command/base_predict_test.go b/command/base_predict_test.go index c9ad1891b2ff..9bfedec41164 100644 --- a/command/base_predict_test.go +++ b/command/base_predict_test.go @@ -322,15 +322,45 @@ func TestPredict_Plugins(t *testing.T) { "good_path", client, []string{ + "ad", + "alicloud", + "app-id", + "approle", + "aws", + "azure", + "cassandra", "cassandra-database-plugin", + "centrify", + "cert", + "consul", + "gcp", + "gcpkms", + "github", "hana-database-plugin", + "jwt", + "kubernetes", + "kv", + "ldap", + "mongodb", "mongodb-database-plugin", + "mssql", "mssql-database-plugin", + "mysql", "mysql-aurora-database-plugin", "mysql-database-plugin", "mysql-legacy-database-plugin", "mysql-rds-database-plugin", + "nomad", + "okta", + "pki", + "postgresql", "postgresql-database-plugin", + "rabbitmq", + "radius", + "ssh", + "totp", + "transit", + "userpass", }, }, } diff --git a/command/command_test.go b/command/command_test.go index ecfc7441d880..a2ae19ec9cf9 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -10,12 +10,13 @@ import ( "time" log "github.com/hashicorp/go-hclog" - kv "github.com/hashicorp/vault-plugin-secrets-kv" + "github.com/hashicorp/vault-plugin-secrets-kv" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/builtin/logical/pki" "github.com/hashicorp/vault/builtin/logical/ssh" "github.com/hashicorp/vault/builtin/logical/transit" + "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/physical/inmem" "github.com/hashicorp/vault/vault" @@ -74,6 +75,7 @@ func testVaultServerAllBackends(tb testing.TB) (*api.Client, func()) { CredentialBackends: credentialBackends, AuditBackends: auditBackends, LogicalBackends: logicalBackends, + BuiltinRegistry: builtinplugins.Registry, }) return client, closer } @@ -90,6 +92,7 @@ func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) { CredentialBackends: defaultVaultCredentialBackends, AuditBackends: defaultVaultAuditBackends, LogicalBackends: defaultVaultLogicalBackends, + BuiltinRegistry: builtinplugins.Registry, }) } @@ -107,6 +110,7 @@ func testVaultServerPluginDir(tb testing.TB, pluginDir string) (*api.Client, []s AuditBackends: defaultVaultAuditBackends, LogicalBackends: defaultVaultLogicalBackends, PluginDirectory: pluginDir, + BuiltinRegistry: builtinplugins.Registry, }) } @@ -156,6 +160,7 @@ func testVaultServerUninit(tb testing.TB) (*api.Client, func()) { CredentialBackends: defaultVaultCredentialBackends, AuditBackends: defaultVaultAuditBackends, LogicalBackends: defaultVaultLogicalBackends, + BuiltinRegistry: builtinplugins.Registry, }) if err != nil { tb.Fatal(err) diff --git a/command/commands.go b/command/commands.go index dc3daa0d186c..c46987045651 100644 --- a/command/commands.go +++ b/command/commands.go @@ -6,55 +6,38 @@ import ( "os/signal" "syscall" - ad "github.com/hashicorp/vault-plugin-secrets-ad/plugin" - alicloud "github.com/hashicorp/vault-plugin-secrets-alicloud" - azure "github.com/hashicorp/vault-plugin-secrets-azure" - gcp "github.com/hashicorp/vault-plugin-secrets-gcp/plugin" - gcpkms "github.com/hashicorp/vault-plugin-secrets-gcpkms" - kv "github.com/hashicorp/vault-plugin-secrets-kv" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/builtin/plugin" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/version" "github.com/mitchellh/cli" - "github.com/hashicorp/vault/builtin/logical/aws" - "github.com/hashicorp/vault/builtin/logical/cassandra" - "github.com/hashicorp/vault/builtin/logical/consul" - "github.com/hashicorp/vault/builtin/logical/database" - "github.com/hashicorp/vault/builtin/logical/mongodb" - "github.com/hashicorp/vault/builtin/logical/mssql" - "github.com/hashicorp/vault/builtin/logical/mysql" - "github.com/hashicorp/vault/builtin/logical/nomad" - "github.com/hashicorp/vault/builtin/logical/pki" - "github.com/hashicorp/vault/builtin/logical/postgresql" - "github.com/hashicorp/vault/builtin/logical/rabbitmq" - "github.com/hashicorp/vault/builtin/logical/ssh" - "github.com/hashicorp/vault/builtin/logical/totp" - "github.com/hashicorp/vault/builtin/logical/transit" - "github.com/hashicorp/vault/builtin/plugin" + /* + The builtinplugins package is initialized here because it, in turn, + initializes the database plugins. + They register multiple database drivers for the "database/sql" package. + */ + _ "github.com/hashicorp/vault/helper/builtinplugins" auditFile "github.com/hashicorp/vault/builtin/audit/file" auditSocket "github.com/hashicorp/vault/builtin/audit/socket" auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog" credAliCloud "github.com/hashicorp/vault-plugin-auth-alicloud" - credAzure "github.com/hashicorp/vault-plugin-auth-azure" credCentrify "github.com/hashicorp/vault-plugin-auth-centrify" credGcp "github.com/hashicorp/vault-plugin-auth-gcp/plugin" - credJWT "github.com/hashicorp/vault-plugin-auth-jwt" - credKube "github.com/hashicorp/vault-plugin-auth-kubernetes" - credAppId "github.com/hashicorp/vault/builtin/credential/app-id" - credAppRole "github.com/hashicorp/vault/builtin/credential/approle" credAws "github.com/hashicorp/vault/builtin/credential/aws" credCert "github.com/hashicorp/vault/builtin/credential/cert" credGitHub "github.com/hashicorp/vault/builtin/credential/github" credLdap "github.com/hashicorp/vault/builtin/credential/ldap" credOkta "github.com/hashicorp/vault/builtin/credential/okta" - credRadius "github.com/hashicorp/vault/builtin/credential/radius" credToken "github.com/hashicorp/vault/builtin/credential/token" credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" + logicalKv "github.com/hashicorp/vault-plugin-secrets-kv" + logicalDb "github.com/hashicorp/vault/builtin/logical/database" + physAliCloudOSS "github.com/hashicorp/vault/physical/alicloudoss" physAzure "github.com/hashicorp/vault/physical/azure" physCassandra "github.com/hashicorp/vault/physical/cassandra" @@ -105,46 +88,15 @@ var ( } credentialBackends = map[string]logical.Factory{ - "alicloud": credAliCloud.Factory, - "app-id": credAppId.Factory, - "approle": credAppRole.Factory, - "aws": credAws.Factory, - "azure": credAzure.Factory, - "centrify": credCentrify.Factory, - "cert": credCert.Factory, - "gcp": credGcp.Factory, - "github": credGitHub.Factory, - "jwt": credJWT.Factory, - "kubernetes": credKube.Factory, - "ldap": credLdap.Factory, - "okta": credOkta.Factory, - "plugin": plugin.Factory, - "radius": credRadius.Factory, - "userpass": credUserpass.Factory, + "plugin": plugin.Factory, } logicalBackends = map[string]logical.Factory{ - "ad": ad.Factory, - "alicloud": alicloud.Factory, - "aws": aws.Factory, - "azure": azure.Factory, - "cassandra": cassandra.Factory, - "consul": consul.Factory, - "database": database.Factory, - "gcp": gcp.Factory, - "gcpkms": gcpkms.Factory, - "kv": kv.Factory, - "mongodb": mongodb.Factory, - "mssql": mssql.Factory, - "mysql": mysql.Factory, - "nomad": nomad.Factory, - "pki": pki.Factory, - "plugin": plugin.Factory, - "postgresql": postgresql.Factory, - "rabbitmq": rabbitmq.Factory, - "ssh": ssh.Factory, - "totp": totp.Factory, - "transit": transit.Factory, + "plugin": plugin.Factory, + "database": logicalDb.Factory, + // This is also available in the plugin catalog, but is here due to the need to + // automatically mount it. + "kv": logicalKv.Factory, } physicalBackends = map[string]physical.Factory{ diff --git a/command/plugin.go b/command/plugin.go index 4ed82850333b..cf0a5009f626 100644 --- a/command/plugin.go +++ b/command/plugin.go @@ -21,19 +21,21 @@ func (c *PluginCommand) Help() string { Usage: vault plugin [options] [args] This command groups subcommands for interacting with Vault's plugins and the - plugin catalog. Here are a few examples of the plugin commands: + plugin catalog. The plugin catalog is divided into three types: "auth", + "database", and "secret" plugins. A type must be specified on each call. Here + are a few examples of the plugin commands. - List all available plugins in the catalog: + List all available plugins in the catalog of a particular type: - $ vault plugin list + $ vault plugin list database - Register a new plugin to the catalog: + Register a new plugin to the catalog as a particular type: - $ vault plugin register -sha256=d3f0a8b... my-custom-plugin + $ vault plugin register -sha256=d3f0a8b... auth my-custom-plugin - Get information about a plugin in the catalog: + Get information about a plugin in the catalog listed under a particular type: - $ vault plugin info my-custom-plugin + $ vault plugin info auth my-custom-plugin Please see the individual subcommand help for detailed usage information. ` diff --git a/command/plugin_deregister.go b/command/plugin_deregister.go index ad6fd66b0245..2e3b8c3b21e0 100644 --- a/command/plugin_deregister.go +++ b/command/plugin_deregister.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -22,14 +23,15 @@ func (c *PluginDeregisterCommand) Synopsis() string { func (c *PluginDeregisterCommand) Help() string { helpText := ` -Usage: vault plugin deregister [options] NAME +Usage: vault plugin deregister [options] TYPE NAME Deregister an existing plugin in the catalog. If the plugin does not exist, - no action is taken (the command is idempotent). + no action is taken (the command is idempotent). The argument of type + takes "auth", "database", or "secret". Deregister the plugin named my-custom-plugin: - $ vault plugin deregister my-custom-plugin + $ vault plugin deregister auth my-custom-plugin ` + c.Flags().Help() @@ -41,7 +43,7 @@ func (c *PluginDeregisterCommand) Flags() *FlagSets { } func (c *PluginDeregisterCommand) AutocompleteArgs() complete.Predictor { - return c.PredictVaultPlugins() + return c.PredictVaultPlugins(consts.PluginTypeUnknown) } func (c *PluginDeregisterCommand) AutocompleteFlags() complete.Flags { @@ -58,11 +60,11 @@ func (c *PluginDeregisterCommand) Run(args []string) int { args = f.Args() switch { - case len(args) < 1: - c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + case len(args) < 2: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 2, got %d)", len(args))) return 1 - case len(args) > 1: - c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + case len(args) > 2: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 2, got %d)", len(args))) return 1 } @@ -72,10 +74,16 @@ func (c *PluginDeregisterCommand) Run(args []string) int { return 2 } - pluginName := strings.TrimSpace(args[0]) + pluginType, err := consts.ParsePluginType(strings.TrimSpace(args[0])) + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + pluginName := strings.TrimSpace(args[1]) if err := client.Sys().DeregisterPlugin(&api.DeregisterPluginInput{ Name: pluginName, + Type: pluginType, }); err != nil { c.UI.Error(fmt.Sprintf("Error deregistering plugin named %s: %s", pluginName, err)) return 2 diff --git a/command/plugin_deregister_test.go b/command/plugin_deregister_test.go index 4d326aafc905..11671aef7cdd 100644 --- a/command/plugin_deregister_test.go +++ b/command/plugin_deregister_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" ) @@ -36,13 +37,13 @@ func TestPluginDeregisterCommand_Run(t *testing.T) { }, { "too_many_args", - []string{"foo", "bar"}, + []string{"foo", "bar", "fizz"}, "Too many arguments", 1, }, { "not_a_plugin", - []string{"nope_definitely_never_a_plugin_nope"}, + []string{consts.PluginTypeCredential.String(), "nope_definitely_never_a_plugin_nope"}, "", 0, }, @@ -82,13 +83,14 @@ func TestPluginDeregisterCommand_Run(t *testing.T) { defer closer() pluginName := "my-plugin" - _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName) + _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential) ui, cmd := testPluginDeregisterCommand(t) cmd.client = client if err := client.Sys().RegisterPlugin(&api.RegisterPluginInput{ Name: pluginName, + Type: consts.PluginTypeCredential, Command: pluginName, SHA256: sha256Sum, }); err != nil { @@ -96,6 +98,7 @@ func TestPluginDeregisterCommand_Run(t *testing.T) { } code := cmd.Run([]string{ + consts.PluginTypeCredential.String(), pluginName, }) if exp := 0; code != exp { @@ -108,19 +111,23 @@ func TestPluginDeregisterCommand_Run(t *testing.T) { t.Errorf("expected %q to contain %q", combined, expected) } - resp, err := client.Sys().ListPlugins(&api.ListPluginsInput{}) + resp, err := client.Sys().ListPlugins(&api.ListPluginsInput{ + Type: consts.PluginTypeCredential, + }) if err != nil { t.Fatal(err) } found := false - for _, p := range resp.Names { - if p == pluginName { - found = true + for _, plugins := range resp.PluginsByType { + for _, p := range plugins { + if p == pluginName { + found = true + } } } if found { - t.Errorf("expected %q to not be in %q", pluginName, resp.Names) + t.Errorf("expected %q to not be in %q", pluginName, resp.PluginsByType) } }) @@ -134,6 +141,7 @@ func TestPluginDeregisterCommand_Run(t *testing.T) { cmd.client = client code := cmd.Run([]string{ + consts.PluginTypeCredential.String(), "my-plugin", }) if exp := 2; code != exp { diff --git a/command/plugin_info.go b/command/plugin_info.go index c4232e9f5eca..98c121482bac 100644 --- a/command/plugin_info.go +++ b/command/plugin_info.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -22,14 +23,15 @@ func (c *PluginInfoCommand) Synopsis() string { func (c *PluginInfoCommand) Help() string { helpText := ` -Usage: vault plugin info [options] NAME +Usage: vault plugin info [options] TYPE NAME Displays information about a plugin in the catalog with the given name. If - the plugin does not exist, an error is returned. + the plugin does not exist, an error is returned. The argument of type + takes "auth", "database", or "secret". Get info about a plugin: - $ vault plugin info mysql-database-plugin + $ vault plugin info database mysql-database-plugin ` + c.Flags().Help() @@ -41,7 +43,7 @@ func (c *PluginInfoCommand) Flags() *FlagSets { } func (c *PluginInfoCommand) AutocompleteArgs() complete.Predictor { - return c.PredictVaultPlugins() + return c.PredictVaultPlugins(consts.PluginTypeUnknown) } func (c *PluginInfoCommand) AutocompleteFlags() complete.Flags { @@ -58,10 +60,10 @@ func (c *PluginInfoCommand) Run(args []string) int { args = f.Args() switch { - case len(args) < 1: + case len(args) < 2: c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) return 1 - case len(args) > 1: + case len(args) > 2: c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) return 1 } @@ -72,10 +74,16 @@ func (c *PluginInfoCommand) Run(args []string) int { return 2 } - pluginName := strings.TrimSpace(args[0]) + pluginType, err := consts.ParsePluginType(strings.TrimSpace(args[0])) + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + pluginName := strings.TrimSpace(args[1]) resp, err := client.Sys().GetPlugin(&api.GetPluginInput{ Name: pluginName, + Type: pluginType, }) if err != nil { c.UI.Error(fmt.Sprintf("Error reading plugin named %s: %s", pluginName, err)) diff --git a/command/plugin_info_test.go b/command/plugin_info_test.go index bc6e8bc3badb..a3818a1af06b 100644 --- a/command/plugin_info_test.go +++ b/command/plugin_info_test.go @@ -4,6 +4,7 @@ import ( "strings" "testing" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" ) @@ -29,13 +30,13 @@ func TestPluginInfoCommand_Run(t *testing.T) { }{ { "too_many_args", - []string{"foo", "bar"}, + []string{"foo", "bar", "fizz"}, "Too many arguments", 1, }, { "no_plugin_exist", - []string{"not-a-real-plugin-like-ever"}, + []string{consts.PluginTypeCredential.String(), "not-a-real-plugin-like-ever"}, "Error reading plugin", 2, }, @@ -79,13 +80,13 @@ func TestPluginInfoCommand_Run(t *testing.T) { defer closer() pluginName := "my-plugin" - _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName) + _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential) ui, cmd := testPluginInfoCommand(t) cmd.client = client code := cmd.Run([]string{ - pluginName, + consts.PluginTypeCredential.String(), pluginName, }) if exp := 0; code != exp { t.Errorf("expected %d to be %d", code, exp) @@ -110,14 +111,14 @@ func TestPluginInfoCommand_Run(t *testing.T) { defer closer() pluginName := "my-plugin" - testPluginCreateAndRegister(t, client, pluginDir, pluginName) + testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential) ui, cmd := testPluginInfoCommand(t) cmd.client = client code := cmd.Run([]string{ "-field", "builtin", - pluginName, + consts.PluginTypeCredential.String(), pluginName, }) if exp := 0; code != exp { t.Errorf("expected %d to be %d", code, exp) @@ -139,7 +140,7 @@ func TestPluginInfoCommand_Run(t *testing.T) { cmd.client = client code := cmd.Run([]string{ - "my-plugin", + consts.PluginTypeCredential.String(), "my-plugin", }) if exp := 2; code != exp { t.Errorf("expected %d to be %d", code, exp) diff --git a/command/plugin_list.go b/command/plugin_list.go index 4e8375c66266..e6f34b80bab6 100644 --- a/command/plugin_list.go +++ b/command/plugin_list.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -23,15 +24,20 @@ func (c *PluginListCommand) Synopsis() string { func (c *PluginListCommand) Help() string { helpText := ` -Usage: vault plugin list [options] +Usage: vault plugin list [options] [TYPE] Lists available plugins registered in the catalog. This does not list whether - plugins are in use, but rather just their availability. + plugins are in use, but rather just their availability. The last argument of + type takes "auth", "database", or "secret". List all available plugins in the catalog: $ vault plugin list + List all available database plugins in the catalog: + + $ vault plugin list database + ` + c.Flags().Help() return strings.TrimSpace(helpText) @@ -58,32 +64,60 @@ func (c *PluginListCommand) Run(args []string) int { } args = f.Args() - if len(args) > 0 { - c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) + switch { + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0 or 1, got %d)", len(args))) return 1 } + pluginType := consts.PluginTypeUnknown + if len(args) > 0 { + pluginTypeStr := strings.TrimSpace(args[0]) + if pluginTypeStr != "" { + var err error + pluginType, err = consts.ParsePluginType(pluginTypeStr) + if err != nil { + c.UI.Error(fmt.Sprintf("Error parsing type: %s", err)) + return 2 + } + } + } + client, err := c.Client() if err != nil { c.UI.Error(err.Error()) return 2 } - resp, err := client.Sys().ListPlugins(&api.ListPluginsInput{}) + resp, err := client.Sys().ListPlugins(&api.ListPluginsInput{ + Type: pluginType, + }) if err != nil { c.UI.Error(fmt.Sprintf("Error listing available plugins: %s", err)) return 2 } - pluginNames := resp.Names - sort.Strings(pluginNames) - switch Format(c.UI) { case "table": - list := append([]string{"Plugins"}, pluginNames...) + var flattenedNames []string + namesAdded := make(map[string]bool) + for _, names := range resp.PluginsByType { + for _, name := range names { + if ok := namesAdded[name]; !ok { + flattenedNames = append(flattenedNames, name) + namesAdded[name] = true + } + } + sort.Strings(flattenedNames) + } + list := append([]string{"Plugins"}, flattenedNames...) c.UI.Output(tableOutput(list, nil)) return 0 default: - return OutputData(c.UI, pluginNames) + res := make(map[string]interface{}) + for k, v := range resp.PluginsByType { + res[k.String()] = v + } + return OutputData(c.UI, res) } } diff --git a/command/plugin_list_test.go b/command/plugin_list_test.go index 86a274d71fd3..ef0788c69c4c 100644 --- a/command/plugin_list_test.go +++ b/command/plugin_list_test.go @@ -29,7 +29,7 @@ func TestPluginListCommand_Run(t *testing.T) { }{ { "too_many_args", - []string{"foo"}, + []string{"foo", "fizz"}, "Too many arguments", 1, }, @@ -78,7 +78,7 @@ func TestPluginListCommand_Run(t *testing.T) { ui, cmd := testPluginListCommand(t) cmd.client = client - code := cmd.Run([]string{}) + code := cmd.Run([]string{"database"}) if exp := 2; code != exp { t.Errorf("expected %d to be %d", code, exp) } diff --git a/command/plugin_register.go b/command/plugin_register.go index 9bd5e76b33b1..e9a74a3424b7 100644 --- a/command/plugin_register.go +++ b/command/plugin_register.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -26,21 +27,22 @@ func (c *PluginRegisterCommand) Synopsis() string { func (c *PluginRegisterCommand) Help() string { helpText := ` -Usage: vault plugin register [options] NAME +Usage: vault plugin register [options] TYPE NAME Registers a new plugin in the catalog. The plugin binary must exist in Vault's - configured plugin directory. + configured plugin directory. The argument of type takes "auth", "database", + or "secret". Register the plugin named my-custom-plugin: - $ vault plugin register -sha256=d3f0a8b... my-custom-plugin + $ vault plugin register -sha256=d3f0a8b... auth my-custom-plugin Register a plugin with custom arguments: $ vault plugin register \ -sha256=d3f0a8b... \ -args=--with-glibc,--with-cgo \ - my-custom-plugin + auth my-custom-plugin ` + c.Flags().Help() @@ -79,7 +81,7 @@ func (c *PluginRegisterCommand) Flags() *FlagSets { } func (c *PluginRegisterCommand) AutocompleteArgs() complete.Predictor { - return c.PredictVaultPlugins() + return c.PredictVaultPlugins(consts.PluginTypeUnknown) } func (c *PluginRegisterCommand) AutocompleteFlags() complete.Flags { @@ -96,11 +98,11 @@ func (c *PluginRegisterCommand) Run(args []string) int { args = f.Args() switch { - case len(args) < 1: - c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + case len(args) < 2: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 2, got %d)", len(args))) return 1 - case len(args) > 1: - c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + case len(args) > 2: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 2, got %d)", len(args))) return 1 case c.flagSHA256 == "": c.UI.Error("SHA256 is required for all plugins, please provide -sha256") @@ -113,7 +115,12 @@ func (c *PluginRegisterCommand) Run(args []string) int { return 2 } - pluginName := strings.TrimSpace(args[0]) + pluginType, err := consts.ParsePluginType(strings.TrimSpace(args[0])) + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + pluginName := strings.TrimSpace(args[1]) command := c.flagCommand if command == "" { @@ -122,6 +129,7 @@ func (c *PluginRegisterCommand) Run(args []string) int { if err := client.Sys().RegisterPlugin(&api.RegisterPluginInput{ Name: pluginName, + Type: pluginType, Args: c.flagArgs, Command: command, SHA256: c.flagSHA256, diff --git a/command/plugin_register_test.go b/command/plugin_register_test.go index aae490298f46..e84108d9ab5f 100644 --- a/command/plugin_register_test.go +++ b/command/plugin_register_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" ) @@ -36,13 +37,13 @@ func TestPluginRegisterCommand_Run(t *testing.T) { }, { "too_many_args", - []string{"foo", "bar"}, + []string{"foo", "bar", "fizz"}, "Too many arguments", 1, }, { "not_a_plugin", - []string{"nope_definitely_never_a_plugin_nope"}, + []string{consts.PluginTypeCredential.String(), "nope_definitely_never_a_plugin_nope"}, "", 2, }, @@ -90,7 +91,7 @@ func TestPluginRegisterCommand_Run(t *testing.T) { code := cmd.Run([]string{ "-sha256", sha256Sum, - pluginName, + consts.PluginTypeCredential.String(), pluginName, }) if exp := 0; code != exp { t.Errorf("expected %d to be %d", code, exp) @@ -102,19 +103,23 @@ func TestPluginRegisterCommand_Run(t *testing.T) { t.Errorf("expected %q to contain %q", combined, expected) } - resp, err := client.Sys().ListPlugins(&api.ListPluginsInput{}) + resp, err := client.Sys().ListPlugins(&api.ListPluginsInput{ + Type: consts.PluginTypeCredential, + }) if err != nil { t.Fatal(err) } found := false - for _, p := range resp.Names { - if p == pluginName { - found = true + for _, plugins := range resp.PluginsByType { + for _, p := range plugins { + if p == pluginName { + found = true + } } } if !found { - t.Errorf("expected %q to be in %q", pluginName, resp.Names) + t.Errorf("expected %q to be in %q", pluginName, resp.PluginsByType) } }) @@ -129,7 +134,7 @@ func TestPluginRegisterCommand_Run(t *testing.T) { code := cmd.Run([]string{ "-sha256", "abcd1234", - "my-plugin", + consts.PluginTypeCredential.String(), "my-plugin", }) if exp := 2; code != exp { t.Errorf("expected %d to be %d", code, exp) diff --git a/command/plugin_test.go b/command/plugin_test.go index 7f64c14721d0..fb0d87ce10c3 100644 --- a/command/plugin_test.go +++ b/command/plugin_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" ) // testPluginDir creates a temporary directory suitable for holding plugins. @@ -61,13 +62,14 @@ func testPluginCreate(tb testing.TB, dir, name string) (string, string) { } // testPluginCreateAndRegister creates a plugin and registers it in the catalog. -func testPluginCreateAndRegister(tb testing.TB, client *api.Client, dir, name string) (string, string) { +func testPluginCreateAndRegister(tb testing.TB, client *api.Client, dir, name string, pluginType consts.PluginType) (string, string) { tb.Helper() pth, sha256Sum := testPluginCreate(tb, dir, name) if err := client.Sys().RegisterPlugin(&api.RegisterPluginInput{ Name: name, + Type: pluginType, Command: name, SHA256: sha256Sum, }); err != nil { diff --git a/command/secrets_enable.go b/command/secrets_enable.go index 1756f40bd475..2f2ae69c6ce9 100644 --- a/command/secrets_enable.go +++ b/command/secrets_enable.go @@ -8,6 +8,7 @@ import ( "time" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -65,6 +66,10 @@ Usage: vault secrets enable [options] TYPE $ vault secrets enable -path=my-secrets -plugin-name=my-plugin plugin + OR (preferred way): + + $ vault secrets enable -path=my-secrets my-plugin + For a full list of secrets engines and examples, please see the documentation. ` + c.Flags().Help() @@ -151,7 +156,7 @@ func (c *SecretsEnableCommand) Flags() *FlagSets { f.StringVar(&StringVar{ Name: "plugin-name", Target: &c.flagPluginName, - Completion: c.PredictVaultPlugins(), + Completion: c.PredictVaultPlugins(consts.PluginTypeSecrets, consts.PluginTypeDatabase), Usage: "Name of the secrets engine plugin. This plugin name must already " + "exist in Vault's plugin catalog.", }) @@ -223,6 +228,9 @@ func (c *SecretsEnableCommand) Run(args []string) int { // Get the engine type type (first arg) engineType := strings.TrimSpace(args[0]) + if engineType == "plugin" { + engineType = c.flagPluginName + } // If no path is specified, we default the path to the backend type // or use the plugin name if it's a plugin backend @@ -255,7 +263,6 @@ func (c *SecretsEnableCommand) Run(args []string) int { DefaultLeaseTTL: c.flagDefaultLeaseTTL.String(), MaxLeaseTTL: c.flagMaxLeaseTTL.String(), ForceNoCache: c.flagForceNoCache, - PluginName: c.flagPluginName, }, Options: c.flagOptions, } @@ -288,7 +295,6 @@ func (c *SecretsEnableCommand) Run(args []string) int { if engineType == "plugin" { thing = c.flagPluginName + " plugin" } - c.UI.Output(fmt.Sprintf("Success! Enabled the %s at: %s", thing, mountPath)) return 0 } diff --git a/command/secrets_enable_test.go b/command/secrets_enable_test.go index 3b300bb3241e..f3be8ffaf310 100644 --- a/command/secrets_enable_test.go +++ b/command/secrets_enable_test.go @@ -5,6 +5,8 @@ import ( "strings" "testing" + "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/consts" "github.com/mitchellh/cli" ) @@ -201,9 +203,10 @@ func TestSecretsEnableCommand_Run(t *testing.T) { } } - // Removing one from logical list since plugin is a virtual backend - if len(backends) != len(logicalBackends)-1 { - t.Fatalf("expected %d logical backends, got %d", len(logicalBackends)-1, len(backends)) + // backends are found by walking the directory, which includes the database backend, + // however, the plugins registry omits that one + if len(backends) != len(builtinplugins.Registry.Keys(consts.PluginTypeSecrets))+1 { + t.Fatalf("expected %d logical backends, got %d", len(builtinplugins.Registry.Keys(consts.PluginTypeSecrets))+1, len(backends)) } for _, b := range backends { diff --git a/command/secrets_list.go b/command/secrets_list.go index 3d52a323bca9..71d9e7ff3457 100644 --- a/command/secrets_list.go +++ b/command/secrets_list.go @@ -155,11 +155,10 @@ func (c *SecretsListCommand) detailedMounts(mounts map[string]*api.MountOutput) replication = "local" } - out = append(out, fmt.Sprintf("%s | %s | %s | %s | %s | %s | %t | %s | %t | %v | %s", + out = append(out, fmt.Sprintf("%s | %s | %s | %s | %s | %t | %s | %t | %v | %s", path, mount.Type, mount.Accessor, - mount.Config.PluginName, defaultTTL, maxTTL, mount.Config.ForceNoCache, diff --git a/command/server.go b/command/server.go index 2de6487af710..b1cc61d8eeb0 100644 --- a/command/server.go +++ b/command/server.go @@ -26,10 +26,11 @@ import ( "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-multierror" - sockaddr "github.com/hashicorp/go-sockaddr" + "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/command/server" serverseal "github.com/hashicorp/vault/command/server/seal" + "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/helper/gated-writer" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/logging" @@ -550,6 +551,7 @@ func (c *ServerCommand) Run(args []string) int { DisablePerformanceStandby: config.DisablePerformanceStandby, DisableIndexing: config.DisableIndexing, AllLoggers: allLoggers, + BuiltinRegistry: builtinplugins.Registry, } if c.flagDev { coreConfig.DevToken = c.flagDevRootTokenID @@ -967,6 +969,7 @@ CLUSTER_SYNTHESIS_COMPLETE: var plugins []string if c.flagDevPluginDir != "" && c.flagDevPluginInit { + f, err := os.Open(c.flagDevPluginDir) if err != nil { c.UI.Error(fmt.Sprintf("Error reading plugin dir: %s", err)) @@ -1553,7 +1556,7 @@ func (c *ServerCommand) addPlugin(path, token string, core *vault.Core) error { req := &logical.Request{ Operation: logical.UpdateOperation, ClientToken: token, - Path: "sys/plugins/catalog/" + name, + Path: fmt.Sprintf("sys/plugins/catalog/%s", name), Data: map[string]interface{}{ "sha256": sha256sum, "command": name, diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go deleted file mode 100644 index df424cee676b..000000000000 --- a/helper/builtinplugins/builtin.go +++ /dev/null @@ -1,50 +0,0 @@ -package builtinplugins - -import ( - "github.com/hashicorp/vault/plugins/database/cassandra" - "github.com/hashicorp/vault/plugins/database/hana" - "github.com/hashicorp/vault/plugins/database/mongodb" - "github.com/hashicorp/vault/plugins/database/mssql" - "github.com/hashicorp/vault/plugins/database/mysql" - "github.com/hashicorp/vault/plugins/database/postgresql" - "github.com/hashicorp/vault/plugins/helper/database/credsutil" -) - -// BuiltinFactory is the func signature that should be returned by -// the plugin's New() func. -type BuiltinFactory func() (interface{}, error) - -var plugins = map[string]BuiltinFactory{ - // These four plugins all use the same mysql implementation but with - // different username settings passed by the constructor. - "mysql-database-plugin": mysql.New(mysql.MetadataLen, mysql.MetadataLen, mysql.UsernameLen), - "mysql-aurora-database-plugin": mysql.New(credsutil.NoneLength, mysql.LegacyMetadataLen, mysql.LegacyUsernameLen), - "mysql-rds-database-plugin": mysql.New(credsutil.NoneLength, mysql.LegacyMetadataLen, mysql.LegacyUsernameLen), - "mysql-legacy-database-plugin": mysql.New(credsutil.NoneLength, mysql.LegacyMetadataLen, mysql.LegacyUsernameLen), - - "postgresql-database-plugin": postgresql.New, - "mssql-database-plugin": mssql.New, - "cassandra-database-plugin": cassandra.New, - "mongodb-database-plugin": mongodb.New, - "hana-database-plugin": hana.New, -} - -// Get returns the BuiltinFactory func for a particular backend plugin -// from the plugins map. -func Get(name string) (BuiltinFactory, bool) { - f, ok := plugins[name] - return f, ok -} - -// Keys returns the list of plugin names that are considered builtin plugins. -func Keys() []string { - keys := make([]string, len(plugins)) - - i := 0 - for k := range plugins { - keys[i] = k - i++ - } - - return keys -} diff --git a/helper/builtinplugins/registry.go b/helper/builtinplugins/registry.go new file mode 100644 index 000000000000..021ef3644c12 --- /dev/null +++ b/helper/builtinplugins/registry.go @@ -0,0 +1,174 @@ +package builtinplugins + +import ( + "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + + credAliCloud "github.com/hashicorp/vault-plugin-auth-alicloud" + credAzure "github.com/hashicorp/vault-plugin-auth-azure" + credCentrify "github.com/hashicorp/vault-plugin-auth-centrify" + credGcp "github.com/hashicorp/vault-plugin-auth-gcp/plugin" + credJWT "github.com/hashicorp/vault-plugin-auth-jwt" + credKube "github.com/hashicorp/vault-plugin-auth-kubernetes" + credAppId "github.com/hashicorp/vault/builtin/credential/app-id" + credAppRole "github.com/hashicorp/vault/builtin/credential/approle" + credAws "github.com/hashicorp/vault/builtin/credential/aws" + credCert "github.com/hashicorp/vault/builtin/credential/cert" + credGitHub "github.com/hashicorp/vault/builtin/credential/github" + credLdap "github.com/hashicorp/vault/builtin/credential/ldap" + credOkta "github.com/hashicorp/vault/builtin/credential/okta" + credRadius "github.com/hashicorp/vault/builtin/credential/radius" + credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" + + dbCass "github.com/hashicorp/vault/plugins/database/cassandra" + dbHana "github.com/hashicorp/vault/plugins/database/hana" + dbMongo "github.com/hashicorp/vault/plugins/database/mongodb" + dbMssql "github.com/hashicorp/vault/plugins/database/mssql" + dbMysql "github.com/hashicorp/vault/plugins/database/mysql" + dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql" + + logicalAd "github.com/hashicorp/vault-plugin-secrets-ad/plugin" + logicalAlicloud "github.com/hashicorp/vault-plugin-secrets-alicloud" + logicalAzure "github.com/hashicorp/vault-plugin-secrets-azure" + logicalGcp "github.com/hashicorp/vault-plugin-secrets-gcp/plugin" + logicalGcpKms "github.com/hashicorp/vault-plugin-secrets-gcpkms" + logicalKv "github.com/hashicorp/vault-plugin-secrets-kv" + logicalAws "github.com/hashicorp/vault/builtin/logical/aws" + logicalCass "github.com/hashicorp/vault/builtin/logical/cassandra" + logicalConsul "github.com/hashicorp/vault/builtin/logical/consul" + logicalMongo "github.com/hashicorp/vault/builtin/logical/mongodb" + logicalMssql "github.com/hashicorp/vault/builtin/logical/mssql" + logicalMysql "github.com/hashicorp/vault/builtin/logical/mysql" + logicalNomad "github.com/hashicorp/vault/builtin/logical/nomad" + logicalPki "github.com/hashicorp/vault/builtin/logical/pki" + logicalPostgres "github.com/hashicorp/vault/builtin/logical/postgresql" + logicalRabbit "github.com/hashicorp/vault/builtin/logical/rabbitmq" + logicalSsh "github.com/hashicorp/vault/builtin/logical/ssh" + logicalTotp "github.com/hashicorp/vault/builtin/logical/totp" + logicalTransit "github.com/hashicorp/vault/builtin/logical/transit" +) + +// Registry is inherently thread-safe because it's immutable. +// Thus, rather than creating multiple instances of it, we only need one. +var Registry = newRegistry() + +// BuiltinFactory is the func signature that should be returned by +// the plugin's New() func. +type BuiltinFactory func() (interface{}, error) + +func newRegistry() *registry { + return ®istry{ + credentialBackends: map[string]logical.Factory{ + "alicloud": credAliCloud.Factory, + "app-id": credAppId.Factory, + "approle": credAppRole.Factory, + "aws": credAws.Factory, + "azure": credAzure.Factory, + "centrify": credCentrify.Factory, + "cert": credCert.Factory, + "gcp": credGcp.Factory, + "github": credGitHub.Factory, + "jwt": credJWT.Factory, + "kubernetes": credKube.Factory, + "ldap": credLdap.Factory, + "okta": credOkta.Factory, + "radius": credRadius.Factory, + "userpass": credUserpass.Factory, + }, + databasePlugins: map[string]BuiltinFactory{ + // These four plugins all use the same mysql implementation but with + // different username settings passed by the constructor. + "mysql-database-plugin": dbMysql.New(dbMysql.MetadataLen, dbMysql.MetadataLen, dbMysql.UsernameLen), + "mysql-aurora-database-plugin": dbMysql.New(credsutil.NoneLength, dbMysql.LegacyMetadataLen, dbMysql.LegacyUsernameLen), + "mysql-rds-database-plugin": dbMysql.New(credsutil.NoneLength, dbMysql.LegacyMetadataLen, dbMysql.LegacyUsernameLen), + "mysql-legacy-database-plugin": dbMysql.New(credsutil.NoneLength, dbMysql.LegacyMetadataLen, dbMysql.LegacyUsernameLen), + + "postgresql-database-plugin": dbPostgres.New, + "mssql-database-plugin": dbMssql.New, + "cassandra-database-plugin": dbCass.New, + "mongodb-database-plugin": dbMongo.New, + "hana-database-plugin": dbHana.New, + }, + logicalBackends: map[string]logical.Factory{ + "ad": logicalAd.Factory, + "alicloud": logicalAlicloud.Factory, + "aws": logicalAws.Factory, + "azure": logicalAzure.Factory, + "cassandra": logicalCass.Factory, + "consul": logicalConsul.Factory, + "gcp": logicalGcp.Factory, + "gcpkms": logicalGcpKms.Factory, + "kv": logicalKv.Factory, + "mongodb": logicalMongo.Factory, + "mssql": logicalMssql.Factory, + "mysql": logicalMysql.Factory, + "nomad": logicalNomad.Factory, + "pki": logicalPki.Factory, + "postgresql": logicalPostgres.Factory, + "rabbitmq": logicalRabbit.Factory, + "ssh": logicalSsh.Factory, + "totp": logicalTotp.Factory, + "transit": logicalTransit.Factory, + }, + } +} + +type registry struct { + credentialBackends map[string]logical.Factory + databasePlugins map[string]BuiltinFactory + logicalBackends map[string]logical.Factory +} + +// Get returns the BuiltinFactory func for a particular backend plugin +// from the plugins map. +func (r *registry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) { + switch pluginType { + case consts.PluginTypeCredential: + f, ok := r.credentialBackends[name] + return toFunc(f), ok + case consts.PluginTypeSecrets: + f, ok := r.logicalBackends[name] + return toFunc(f), ok + case consts.PluginTypeDatabase: + f, ok := r.databasePlugins[name] + return f, ok + default: + return nil, false + } +} + +// Keys returns the list of plugin names that are considered builtin plugins. +func (r *registry) Keys(pluginType consts.PluginType) []string { + var keys []string + switch pluginType { + case consts.PluginTypeDatabase: + for key := range r.databasePlugins { + keys = append(keys, key) + } + case consts.PluginTypeCredential: + for key := range r.credentialBackends { + keys = append(keys, key) + } + case consts.PluginTypeSecrets: + for key := range r.logicalBackends { + keys = append(keys, key) + } + } + return keys +} + +func (r *registry) Contains(name string, pluginType consts.PluginType) bool { + for _, key := range r.Keys(pluginType) { + if key == name { + return true + } + } + return false +} + +func toFunc(ifc interface{}) func() (interface{}, error) { + return func() (interface{}, error) { + return ifc, nil + } +} diff --git a/helper/consts/plugin_types.go b/helper/consts/plugin_types.go new file mode 100644 index 000000000000..71915ffa2a92 --- /dev/null +++ b/helper/consts/plugin_types.go @@ -0,0 +1,59 @@ +package consts + +import "fmt" + +var PluginTypes = []PluginType{ + PluginTypeUnknown, + PluginTypeCredential, + PluginTypeDatabase, + PluginTypeSecrets, +} + +type PluginType uint32 + +// This is a list of PluginTypes used by Vault. +// If we need to add any in the future, it would +// be best to add them to the _end_ of the list below +// because they resolve to incrementing numbers, +// which may be saved in state somewhere. Thus if +// the name for one of those numbers changed because +// a value were added to the middle, that could cause +// the wrong plugin types to be read from storage +// for a given underlying number. Example of the problem +// here: https://play.golang.org/p/YAaPw5ww3er +const ( + PluginTypeUnknown PluginType = iota + PluginTypeCredential + PluginTypeDatabase + PluginTypeSecrets +) + +func (p PluginType) String() string { + switch p { + case PluginTypeUnknown: + return "unknown" + case PluginTypeCredential: + return "auth" + case PluginTypeDatabase: + return "database" + case PluginTypeSecrets: + return "secret" + default: + return "unsupported" + } +} + +func ParsePluginType(pluginType string) (PluginType, error) { + switch pluginType { + case "unknown": + return PluginTypeUnknown, nil + case "auth": + return PluginTypeCredential, nil + case "database": + return PluginTypeDatabase, nil + case "secret": + return PluginTypeSecrets, nil + default: + return PluginTypeUnknown, fmt.Errorf("%s is not a supported plugin type", pluginType) + } +} diff --git a/helper/mfa/mfa_test.go b/helper/mfa/mfa_test.go index 2706d4f2e3ab..29ec0ec4e95a 100644 --- a/helper/mfa/mfa_test.go +++ b/helper/mfa/mfa_test.go @@ -71,7 +71,7 @@ func TestMFALogin(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepEnableMFA(t), testAccStepLogin(t, "user"), @@ -84,7 +84,7 @@ func TestMFALoginDenied(t *testing.T) { logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, - Backend: b, + LogicalBackend: b, Steps: []logicaltest.TestStep{ testAccStepEnableMFA(t), testAccStepLoginDenied(t, "user"), diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 74fe95cc2441..2323684dff8e 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -12,6 +12,7 @@ import ( log "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/version" ) @@ -19,7 +20,7 @@ import ( // Looker defines the plugin Lookup function that looks into the plugin catalog // for available plugins and returns a PluginRunner type Looker interface { - LookupPlugin(context.Context, string) (*PluginRunner, error) + LookupPlugin(context.Context, string, consts.PluginType) (*PluginRunner, error) } // RunnerUtil interface defines the functions needed by the runner to wrap the @@ -41,6 +42,7 @@ type LookRunnerUtil interface { // go-plugin. type PluginRunner struct { Name string `json:"name" structs:"name"` + Type consts.PluginType `json:"type" structs:"type"` Command string `json:"command" structs:"command"` Args []string `json:"args" structs:"args"` Env []string `json:"env" structs:"env"` @@ -73,7 +75,7 @@ func (r *PluginRunner) runCommon(ctx context.Context, wrapper RunnerUtil, plugin cmd.Env = append(cmd.Env, env...) // Add the mlock setting to the ENV of the plugin - if wrapper.MlockEnabled() { + if wrapper != nil && wrapper.MlockEnabled() { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) } cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version)) diff --git a/http/handler_test.go b/http/handler_test.go index d2553e25093f..45b4bafca2f0 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -11,10 +11,9 @@ import ( "strings" "testing" - "github.com/hashicorp/vault/helper/namespace" - "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/vault" ) @@ -274,7 +273,6 @@ func TestSysMounts_headerAuth(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -287,7 +285,6 @@ func TestSysMounts_headerAuth(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -300,7 +297,6 @@ func TestSysMounts_headerAuth(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -313,7 +309,6 @@ func TestSysMounts_headerAuth(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -327,7 +322,6 @@ func TestSysMounts_headerAuth(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -340,7 +334,6 @@ func TestSysMounts_headerAuth(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -353,7 +346,6 @@ func TestSysMounts_headerAuth(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -366,7 +358,6 @@ func TestSysMounts_headerAuth(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, diff --git a/http/plugin_test.go b/http/plugin_test.go index bdedc0e699bb..71b051c69e71 100644 --- a/http/plugin_test.go +++ b/http/plugin_test.go @@ -11,6 +11,7 @@ import ( log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" bplugin "github.com/hashicorp/vault/builtin/plugin" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/plugin" @@ -50,12 +51,11 @@ func getPluginClusterAndCore(t testing.TB, logger log.Logger) (*vault.TestCluste os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) vault.TestWaitActive(t, core.Core) - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestPlugin_PluginMain", []string{}, "") + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestPlugin_PluginMain", []string{}, "") // Mount the mock plugin err = core.Client.Sys().Mount("mock", &api.MountInput{ - Type: "plugin", - PluginName: "mock-plugin", + Type: "mock-plugin", }) if err != nil { t.Fatal(err) diff --git a/http/sys_auth_test.go b/http/sys_auth_test.go index e47073927e25..33f631429814 100644 --- a/http/sys_auth_test.go +++ b/http/sys_auth_test.go @@ -32,7 +32,6 @@ func TestSysAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), - "plugin_name": "", "token_type": "default-service", "force_no_cache": false, }, @@ -47,7 +46,6 @@ func TestSysAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), - "plugin_name": "", "token_type": "default-service", "force_no_cache": false, }, @@ -102,7 +100,6 @@ func TestSysEnableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), - "plugin_name": "", "token_type": "default-service", "force_no_cache": false, }, @@ -116,7 +113,6 @@ func TestSysEnableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), - "plugin_name": "", "force_no_cache": false, "token_type": "default-service", }, @@ -131,7 +127,6 @@ func TestSysEnableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), - "plugin_name": "", "token_type": "default-service", "force_no_cache": false, }, @@ -145,7 +140,6 @@ func TestSysEnableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), - "plugin_name": "", "token_type": "default-service", "force_no_cache": false, }, @@ -201,7 +195,6 @@ func TestSysDisableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), - "plugin_name": "", "token_type": "default-service", "force_no_cache": false, }, @@ -216,7 +209,6 @@ func TestSysDisableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), - "plugin_name": "", "token_type": "default-service", "force_no_cache": false, }, diff --git a/http/sys_mount_test.go b/http/sys_mount_test.go index 5ad03b72ed3e..5193b101d0fd 100644 --- a/http/sys_mount_test.go +++ b/http/sys_mount_test.go @@ -33,7 +33,6 @@ func TestSysMounts(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -46,7 +45,6 @@ func TestSysMounts(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -59,7 +57,6 @@ func TestSysMounts(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -72,7 +69,6 @@ func TestSysMounts(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -86,7 +82,6 @@ func TestSysMounts(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -99,7 +94,6 @@ func TestSysMounts(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -112,7 +106,6 @@ func TestSysMounts(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -125,7 +118,6 @@ func TestSysMounts(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -181,7 +173,6 @@ func TestSysMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -194,7 +185,6 @@ func TestSysMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -207,7 +197,6 @@ func TestSysMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -220,7 +209,6 @@ func TestSysMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -233,7 +221,6 @@ func TestSysMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -247,7 +234,6 @@ func TestSysMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -260,7 +246,6 @@ func TestSysMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -273,7 +258,6 @@ func TestSysMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -286,7 +270,6 @@ func TestSysMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -299,7 +282,6 @@ func TestSysMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -374,7 +356,6 @@ func TestSysRemount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -387,7 +368,6 @@ func TestSysRemount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -400,7 +380,6 @@ func TestSysRemount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -413,7 +392,6 @@ func TestSysRemount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -426,7 +404,6 @@ func TestSysRemount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -440,7 +417,6 @@ func TestSysRemount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -453,7 +429,6 @@ func TestSysRemount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -466,7 +441,6 @@ func TestSysRemount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -479,7 +453,6 @@ func TestSysRemount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -492,7 +465,6 @@ func TestSysRemount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -548,7 +520,6 @@ func TestSysUnmount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -561,7 +532,6 @@ func TestSysUnmount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -574,7 +544,6 @@ func TestSysUnmount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -587,7 +556,6 @@ func TestSysUnmount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -601,7 +569,6 @@ func TestSysUnmount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -614,7 +581,6 @@ func TestSysUnmount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -627,7 +593,6 @@ func TestSysUnmount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -640,7 +605,6 @@ func TestSysUnmount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -778,7 +742,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -791,7 +754,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -804,7 +766,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -817,7 +778,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -830,7 +790,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -844,7 +803,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -857,7 +815,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -870,7 +827,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -883,7 +839,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -896,7 +851,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -978,7 +932,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("259196400"), "max_lease_ttl": json.Number("259200000"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -991,7 +944,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -1004,7 +956,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -1017,7 +968,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -1030,7 +980,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -1044,7 +993,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("259196400"), "max_lease_ttl": json.Number("259200000"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -1057,7 +1005,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -1070,7 +1017,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -1083,7 +1029,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": true, "seal_wrap": false, @@ -1096,7 +1041,6 @@ func TestSysTuneMount(t *testing.T) { "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, - "plugin_name": "", }, "local": false, "seal_wrap": false, diff --git a/logical/plugin/grpc_system.go b/logical/plugin/grpc_system.go index bcf5e70b6561..5b7a5824c446 100644 --- a/logical/plugin/grpc_system.go +++ b/logical/plugin/grpc_system.go @@ -4,11 +4,8 @@ import ( "context" "encoding/json" "errors" - "time" - - "google.golang.org/grpc" - "fmt" + "time" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/license" @@ -16,6 +13,7 @@ import ( "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/plugin/pb" + "google.golang.org/grpc" ) func newGRPCSystemView(conn *grpc.ClientConn) *gRPCSystemViewClient { @@ -111,7 +109,7 @@ func (s *gRPCSystemViewClient) ResponseWrapData(ctx context.Context, data map[st return info, nil } -func (s *gRPCSystemViewClient) LookupPlugin(ctx context.Context, name string) (*pluginutil.PluginRunner, error) { +func (s *gRPCSystemViewClient) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) { return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend") } diff --git a/logical/plugin/grpc_system_test.go b/logical/plugin/grpc_system_test.go index 42d087beb94b..98ec70f65ada 100644 --- a/logical/plugin/grpc_system_test.go +++ b/logical/plugin/grpc_system_test.go @@ -142,7 +142,7 @@ func TestSystem_GRPC_lookupPlugin(t *testing.T) { testSystemView := newGRPCSystemView(client) - if _, err := testSystemView.LookupPlugin(context.Background(), "foo"); err == nil { + if _, err := testSystemView.LookupPlugin(context.Background(), "foo", consts.PluginTypeDatabase); err == nil { t.Fatal("LookPlugin(): expected error on due to unsupported call from plugin") } } diff --git a/logical/plugin/plugin.go b/logical/plugin/plugin.go index 7ef3a5577094..250097c22a68 100644 --- a/logical/plugin/plugin.go +++ b/logical/plugin/plugin.go @@ -7,13 +7,13 @@ import ( "encoding/gob" "errors" "fmt" - "time" - "sync" + "time" "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" ) @@ -61,9 +61,9 @@ func (b *BackendPluginClient) Cleanup(ctx context.Context) { // external plugins, or a concrete implementation of the backend if it is a builtin backend. // The backend is returned as a logical.Backend interface. The isMetadataMode param determines whether // the plugin should run in metadata mode. -func NewBackend(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger, isMetadataMode bool) (logical.Backend, error) { +func NewBackend(ctx context.Context, pluginName string, pluginType consts.PluginType, sys pluginutil.LookRunnerUtil, conf *logical.BackendConfig, isMetadataMode bool) (logical.Backend, error) { // Look for plugin in the plugin catalog - pluginRunner, err := sys.LookupPlugin(ctx, pluginName) + pluginRunner, err := sys.LookupPlugin(ctx, pluginName, pluginType) if err != nil { return nil, err } @@ -71,21 +71,22 @@ func NewBackend(ctx context.Context, pluginName string, sys pluginutil.LookRunne var backend logical.Backend if pluginRunner.Builtin { // Plugin is builtin so we can retrieve an instance of the interface - // from the pluginRunner. Then cast it to logical.Backend. - backendRaw, err := pluginRunner.BuiltinFactory() + // from the pluginRunner. Then cast it to logical.Factory. + rawFactory, err := pluginRunner.BuiltinFactory() if err != nil { return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err) } - var ok bool - backend, ok = backendRaw.(logical.Backend) - if !ok { + if factory, ok := rawFactory.(logical.Factory); !ok { return nil, fmt.Errorf("unsupported backend type: %q", pluginName) + } else { + if backend, err = factory(ctx, conf); err != nil { + return nil, err + } } - } else { // create a backendPluginClient instance - backend, err = newPluginClient(ctx, sys, pluginRunner, logger, isMetadataMode) + backend, err = NewPluginClient(ctx, sys, pluginRunner, conf.Logger, isMetadataMode) if err != nil { return nil, err } @@ -94,7 +95,7 @@ func NewBackend(ctx context.Context, pluginName string, sys pluginutil.LookRunne return backend, nil } -func newPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) { +func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) { // pluginMap is the map of plugins we can dispense. pluginSet := map[int]plugin.PluginSet{ 3: plugin.PluginSet{ diff --git a/logical/plugin/system.go b/logical/plugin/system.go index 890f4ef5a2df..148f39a96d8d 100644 --- a/logical/plugin/system.go +++ b/logical/plugin/system.go @@ -106,7 +106,7 @@ func (s *SystemViewClient) ResponseWrapData(ctx context.Context, data map[string return reply.ResponseWrapInfo, nil } -func (s *SystemViewClient) LookupPlugin(ctx context.Context, name string) (*pluginutil.PluginRunner, error) { +func (s *SystemViewClient) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) { return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend") } diff --git a/logical/plugin/system_test.go b/logical/plugin/system_test.go index 32e13b4d430c..dd712631af00 100644 --- a/logical/plugin/system_test.go +++ b/logical/plugin/system_test.go @@ -150,7 +150,7 @@ func TestSystem_lookupPlugin(t *testing.T) { testSystemView := &SystemViewClient{client: client} - if _, err := testSystemView.LookupPlugin(context.Background(), "foo"); err == nil { + if _, err := testSystemView.LookupPlugin(context.Background(), "foo", consts.PluginTypeDatabase); err == nil { t.Fatal("LookPlugin(): expected error on due to unsupported call from plugin") } } diff --git a/logical/system_view.go b/logical/system_view.go index f970847485a2..dff258b11c09 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -56,7 +56,7 @@ type SystemView interface { // LookupPlugin looks into the plugin catalog for a plugin with the given // name. Returns a PluginRunner or an error if a plugin can not be found. - LookupPlugin(context.Context, string) (*pluginutil.PluginRunner, error) + LookupPlugin(context.Context, string, consts.PluginType) (*pluginutil.PluginRunner, error) // MlockEnabled returns the configuration setting for enabling mlock on // plugins. @@ -118,7 +118,7 @@ func (d StaticSystemView) ResponseWrapData(_ context.Context, data map[string]in return nil, errors.New("ResponseWrapData is not implemented in StaticSystemView") } -func (d StaticSystemView) LookupPlugin(_ context.Context, name string) (*pluginutil.PluginRunner, error) { +func (d StaticSystemView) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) { return nil, errors.New("LookupPlugin is not implemented in StaticSystemView") } diff --git a/logical/testing/testing.go b/logical/testing/testing.go index e3ada93b02a4..0171330f1788 100644 --- a/logical/testing/testing.go +++ b/logical/testing/testing.go @@ -33,12 +33,19 @@ type TestCase struct { // test running. PreCheck func() - // Backend is the backend that will be mounted. - Backend logical.Backend + // LogicalBackend is the backend that will be mounted. + LogicalBackend logical.Backend - // Factory can be used instead of Backend if the + // LogicalFactory can be used instead of LogicalBackend if the // backend requires more construction - Factory logical.Factory + LogicalFactory logical.Factory + + // CredentialBackend is the backend that will be mounted. + CredentialBackend logical.Backend + + // CredentialFactory can be used instead of CredentialBackend if the + // backend requires more construction + CredentialFactory logical.Factory // Steps are the set of operations that are run for this test case. Steps []TestStep @@ -135,8 +142,15 @@ func Test(tt TestT, c TestCase) { } // Check that something is provided - if c.Backend == nil && c.Factory == nil { - tt.Fatal("Must provide either Backend or Factory") + if c.LogicalBackend == nil && c.LogicalFactory == nil { + if c.CredentialBackend == nil && c.CredentialFactory == nil { + tt.Fatal("Must provide either Backend or Factory") + return + } + } + // We currently only support doing one logical OR one credential test at a time. + if (c.LogicalFactory != nil || c.LogicalBackend != nil) && (c.CredentialFactory != nil || c.CredentialBackend != nil) { + tt.Fatal("Must provide only one backend or factory") return } @@ -149,18 +163,34 @@ func Test(tt TestT, c TestCase) { return } - core, err := vault.NewCore(&vault.CoreConfig{ - Physical: phys, - LogicalBackends: map[string]logical.Factory{ + config := &vault.CoreConfig{ + Physical: phys, + DisableMlock: true, + BuiltinRegistry: vault.NewMockBuiltinRegistry(), + } + + if c.LogicalBackend != nil || c.LogicalFactory != nil { + config.LogicalBackends = map[string]logical.Factory{ "test": func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { - if c.Backend != nil { - return c.Backend, nil + if c.LogicalBackend != nil { + return c.LogicalBackend, nil } - return c.Factory(ctx, conf) + return c.LogicalFactory(ctx, conf) }, - }, - DisableMlock: true, - }) + } + } + if c.CredentialBackend != nil || c.CredentialFactory != nil { + config.CredentialBackends = map[string]logical.Factory{ + "test": func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { + if c.CredentialBackend != nil { + return c.CredentialBackend, nil + } + return c.CredentialFactory(ctx, conf) + }, + } + } + + core, err := vault.NewCore(config) if err != nil { tt.Fatal("error initializing core: ", err) return @@ -202,15 +232,31 @@ func Test(tt TestT, c TestCase) { // Set the token so we're authenticated client.SetToken(init.RootToken) - // Mount the backend prefix := "mnt" - mountInfo := &api.MountInput{ - Type: "test", - Description: "acceptance test", + if c.LogicalBackend != nil || c.LogicalFactory != nil { + // Mount the backend + mountInfo := &api.MountInput{ + Type: "test", + Description: "acceptance test", + } + if err := client.Sys().Mount(prefix, mountInfo); err != nil { + tt.Fatal("error mounting backend: ", err) + return + } } - if err := client.Sys().Mount(prefix, mountInfo); err != nil { - tt.Fatal("error mounting backend: ", err) - return + + isAuthBackend := false + if c.CredentialBackend != nil || c.CredentialFactory != nil { + isAuthBackend = true + + // Enable the test auth method + opts := &api.EnableAuthOptions{ + Type: "test", + } + if err := client.Sys().EnableAuthWithOptions(prefix, opts); err != nil { + tt.Fatal("error enabling backend: ", err) + return + } } tokenInfo, err := client.Auth().Token().LookupSelf() @@ -269,6 +315,11 @@ func Test(tt TestT, c TestCase) { // Make sure to prefix the path with where we mounted the thing req.Path = fmt.Sprintf("%s/%s", prefix, req.Path) + if isAuthBackend { + // Prepend the path with "auth" + req.Path = "auth/" + req.Path + } + // Make the request resp, err := core.HandleRequest(namespace.RootContext(nil), req) if resp != nil && resp.Secret != nil { @@ -338,7 +389,11 @@ func Test(tt TestT, c TestCase) { // We set the "immediate" flag here that any backend can pick up on // to do all rollbacks immediately even if the WAL entries are new. logger.Warn("Requesting RollbackOperation") - req := logical.RollbackRequest(prefix + "/") + rollbackPath := prefix + "/" + if c.CredentialFactory != nil || c.CredentialBackend != nil { + rollbackPath = "auth/" + rollbackPath + } + req := logical.RollbackRequest(rollbackPath) req.Data["immediate"] = true req.ClientToken = client.Token() resp, err := core.HandleRequest(namespace.RootContext(nil), req) diff --git a/vault/auth.go b/vault/auth.go index bd67261c0dd9..49b28c2dec7e 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/builtin/plugin" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/namespace" @@ -141,11 +142,6 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry, var backend logical.Backend // Create the new backend sysView := c.mountEntrySysView(entry) - conf := make(map[string]string) - if entry.Config.PluginName != "" { - conf["plugin_name"] = entry.Config.PluginName - } - // Create the new backend backend, err = c.newCredentialBackend(ctx, entry, sysView, view) if err != nil { return err @@ -156,8 +152,8 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry, // Check for the correct backend type backendType := backend.Type() - if entry.Type == "plugin" && backendType != logical.TypeCredential { - return fmt.Errorf("cannot mount %q of type %q as an auth backend", entry.Config.PluginName, backendType) + if backendType != logical.TypeCredential { + return fmt.Errorf("cannot mount %q of type %q as an auth backend", entry.Type, backendType) } addPathCheckers(c, entry, backend, viewPath) @@ -600,15 +596,11 @@ func (c *Core) setupCredentials(ctx context.Context) error { // Initialize the backend sysView := c.mountEntrySysView(entry) - conf := make(map[string]string) - if entry.Config.PluginName != "" { - conf["plugin_name"] = entry.Config.PluginName - } backend, err = c.newCredentialBackend(ctx, entry, sysView, view) if err != nil { c.logger.Error("failed to create credential entry", "path", entry.Path, "error", err) - if entry.Type == "plugin" { + if !c.builtinRegistry.Contains(entry.Type, consts.PluginTypeCredential) { // If we encounter an error instantiating the backend due to an error, // skip backend initialization but register the entry to the mount table // to preserve storage and path. @@ -624,8 +616,8 @@ func (c *Core) setupCredentials(ctx context.Context) error { { // Check for the correct backend type backendType := backend.Type() - if entry.Type == "plugin" && backendType != logical.TypeCredential { - return fmt.Errorf("cannot mount %q of type %q as an auth backend", entry.Config.PluginName, backendType) + if backendType != logical.TypeCredential { + return fmt.Errorf("cannot mount %q of type %q as an auth backend", entry.Type, backendType) } addPathCheckers(c, entry, backend, viewPath) @@ -717,7 +709,7 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV f, ok := c.credentialBackends[t] if !ok { - return nil, fmt.Errorf("unknown backend type: %q", t) + f = plugin.Factory } // Set up conf to pass in plugin_name @@ -725,10 +717,16 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV for k, v := range entry.Options { conf[k] = v } - if entry.Config.PluginName != "" { - conf["plugin_name"] = entry.Config.PluginName + + switch { + case entry.Type == "plugin": + conf["plugin_name"] = entry.Config.PluginNameDeprecated + default: + conf["plugin_name"] = t } + conf["plugin_type"] = consts.PluginTypeCredential.String() + authLogger := c.baseLogger.Named(fmt.Sprintf("auth.%s.%s", t, entry.Accessor)) c.AddLogger(authLogger) config := &logical.BackendConfig{ diff --git a/vault/auth_test.go b/vault/auth_test.go index 22fbacd36c39..db3409256e2a 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -21,7 +21,9 @@ func TestAuth_ReadOnlyViewDuringMount(t *testing.T) { if err == nil || !strings.Contains(err.Error(), logical.ErrSetupReadOnly.Error()) { t.Fatalf("expected a read-only error") } - return &NoopBackend{}, nil + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil } me := &MountEntry{ @@ -67,7 +69,9 @@ func TestCore_DefaultAuthTable(t *testing.T) { func TestCore_EnableCredential(t *testing.T) { c, keys, _ := TestCoreUnsealed(t) c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { - return &NoopBackend{}, nil + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil } me := &MountEntry{ @@ -94,7 +98,9 @@ func TestCore_EnableCredential(t *testing.T) { t.Fatalf("err: %v", err) } c2.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { - return &NoopBackend{}, nil + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil } for i, key := range keys { unseal, err := TestCoreUnseal(c2, key) @@ -118,7 +124,9 @@ func TestCore_EnableCredential(t *testing.T) { func TestCore_EnableCredential_Local(t *testing.T) { c, _, _ := TestCoreUnsealed(t) c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { - return &NoopBackend{}, nil + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil } c.auth = &MountTable{ @@ -205,7 +213,9 @@ func TestCore_EnableCredential_Local(t *testing.T) { func TestCore_EnableCredential_twice_409(t *testing.T) { c, _, _ := TestCoreUnsealed(t) c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { - return &NoopBackend{}, nil + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil } me := &MountEntry{ @@ -246,7 +256,9 @@ func TestCore_EnableCredential_Token(t *testing.T) { func TestCore_DisableCredential(t *testing.T) { c, keys, _ := TestCoreUnsealed(t) c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { - return &NoopBackend{}, nil + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil } err := c.disableCredential(namespace.RootContext(nil), "foo") @@ -308,7 +320,8 @@ func TestCore_DisableCredential_Protected(t *testing.T) { func TestCore_DisableCredential_Cleanup(t *testing.T) { noop := &NoopBackend{ - Login: []string{"login"}, + Login: []string{"login"}, + BackendType: logical.TypeCredential, } c, _, _ := TestCoreUnsealed(t) c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { diff --git a/vault/core.go b/vault/core.go index 034ef3a4add1..d559f14ca3c8 100644 --- a/vault/core.go +++ b/vault/core.go @@ -17,7 +17,7 @@ import ( "github.com/armon/go-metrics" log "github.com/hashicorp/go-hclog" - cache "github.com/patrickmn/go-cache" + "github.com/patrickmn/go-cache" "google.golang.org/grpc" @@ -137,6 +137,10 @@ type unlockInformation struct { type Core struct { entCore + // The registry of builtin plugins is passed in here as an interface because + // if it's used directly, it results in import cycles. + builtinRegistry BuiltinRegistry + // N.B.: This is used to populate a dev token down replication, as // otherwise, after replication is started, a dev would have to go through // the generate-root process simply to talk to the new follower cluster. @@ -403,6 +407,8 @@ type Core struct { type CoreConfig struct { DevToken string `json:"dev_token" structs:"dev_token" mapstructure:"dev_token"` + BuiltinRegistry BuiltinRegistry `json:"builtin_registry" structs:"builtin_registry" mapstructure:"builtin_registry"` + LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"` CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"` @@ -567,6 +573,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { disablePerfStandby: true, activeContextCancelFunc: new(atomic.Value), allLoggers: conf.AllLoggers, + builtinRegistry: conf.BuiltinRegistry, } atomic.StoreUint32(c.sealed, 1) @@ -619,7 +626,6 @@ func NewCore(conf *CoreConfig) (*Core, error) { } var err error - var ok bool if conf.PluginDirectory != "" { c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory) @@ -648,15 +654,15 @@ func NewCore(conf *CoreConfig) (*Core, error) { c.reloadFuncsLock.Unlock() conf.ReloadFuncs = &c.reloadFuncs - // Setup the backends logicalBackends := make(map[string]logical.Factory) for k, f := range conf.LogicalBackends { logicalBackends[k] = f } - _, ok = logicalBackends["kv"] + _, ok := logicalBackends["kv"] if !ok { logicalBackends["kv"] = PassthroughBackendFactory } + logicalBackends["cubbyhole"] = CubbyholeBackendFactory logicalBackends[systemMountType] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { sysBackendLogger := conf.Logger.Named("system") @@ -1382,7 +1388,7 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c return err } } - if err := c.setupPluginCatalog(); err != nil { + if err := c.setupPluginCatalog(ctx); err != nil { return err } if err := c.loadMounts(ctx); err != nil { @@ -1695,3 +1701,12 @@ func (c *Core) SetLogLevel(level log.Level) { logger.SetLevel(level) } } + +// BuiltinRegistry is an interface that allows the "vault" package to use +// the registry of builtin plugins without getting an import cycle. It +// also allows for mocking the registry easily. +type BuiltinRegistry interface { + Contains(name string, pluginType consts.PluginType) bool + Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) + Keys(pluginType consts.PluginType) []string +} diff --git a/vault/core_test.go b/vault/core_test.go index 39c0715f264b..3a4174e1748a 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -632,8 +632,9 @@ func TestCore_HandleRequest_NoClientToken(t *testing.T) { func TestCore_HandleRequest_ConnOnLogin(t *testing.T) { noop := &NoopBackend{ - Login: []string{"login"}, - Response: &logical.Response{}, + Login: []string{"login"}, + Response: &logical.Response{}, + BackendType: logical.TypeCredential, } c, _, root := TestCoreUnsealed(t) c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { @@ -675,6 +676,7 @@ func TestCore_HandleLogin_Token(t *testing.T) { DisplayName: "armon", }, }, + BackendType: logical.TypeCredential, } c, _, root := TestCoreUnsealed(t) c.credentialBackends["noop"] = func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { @@ -919,6 +921,7 @@ func TestCore_HandleLogin_AuditTrail(t *testing.T) { }, }, }, + BackendType: logical.TypeCredential, } c, _, root := TestCoreUnsealed(t) c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { @@ -1782,6 +1785,7 @@ func TestCore_HandleRequest_Login_InternalData(t *testing.T) { }, }, }, + BackendType: logical.TypeCredential, } c, _, root := TestCoreUnsealed(t) @@ -1871,6 +1875,7 @@ func TestCore_HandleLogin_ReturnSecret(t *testing.T) { Policies: []string{"foo", "bar"}, }, }, + BackendType: logical.TypeCredential, } c, _, root := TestCoreUnsealed(t) c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { @@ -2022,6 +2027,7 @@ func TestCore_EnableDisableCred_WithLease(t *testing.T) { Policies: []string{"root"}, }, }, + BackendType: logical.TypeCredential, } c, _, root := TestCoreUnsealed(t) diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index a83e60c7bda3..eef5e19c4561 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -164,14 +164,14 @@ func (d dynamicSystemView) ResponseWrapData(ctx context.Context, data map[string // LookupPlugin looks for a plugin with the given name in the plugin catalog. It // returns a PluginRunner or an error if no plugin was found. -func (d dynamicSystemView) LookupPlugin(ctx context.Context, name string) (*pluginutil.PluginRunner, error) { +func (d dynamicSystemView) LookupPlugin(ctx context.Context, name string, pluginType consts.PluginType) (*pluginutil.PluginRunner, error) { if d.core == nil { return nil, fmt.Errorf("system view core is nil") } if d.core.pluginCatalog == nil { return nil, fmt.Errorf("system view core plugin catalog is nil") } - r, err := d.core.pluginCatalog.Get(ctx, name) + r, err := d.core.pluginCatalog.Get(ctx, name, pluginType) if err != nil { return nil, err } diff --git a/vault/expiration_test.go b/vault/expiration_test.go index fc55c8c00a2d..f6256de33986 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -1943,6 +1943,7 @@ func badRenewFactory(ctx context.Context, conf *logical.BackendConfig) (logical. }, }, }, + BackendType: logical.TypeLogical, } err := be.Setup(namespace.RootContext(nil), conf) diff --git a/vault/logical_passthrough.go b/vault/logical_passthrough.go index c971c4479894..6c10cc7baf6a 100644 --- a/vault/logical_passthrough.go +++ b/vault/logical_passthrough.go @@ -58,6 +58,7 @@ func LeaseSwitchedPassthroughBackend(ctx context.Context, conf *logical.BackendC HelpDescription: strings.TrimSpace(passthroughHelpDescription), }, }, + BackendType: logical.TypeLogical, } b.Backend.Secrets = []*framework.Secret{ diff --git a/vault/logical_system.go b/vault/logical_system.go index 3b7b7b00af78..a0cf9d15cd12 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -12,6 +12,7 @@ import ( "hash" "net/http" "path/filepath" + "sort" "strconv" "strings" "sync" @@ -19,8 +20,8 @@ import ( "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" - memdb "github.com/hashicorp/go-memdb" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-memdb" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/compressutil" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/identity" @@ -131,8 +132,7 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend { b.Backend.Paths = append(b.Backend.Paths, b.configPaths()...) b.Backend.Paths = append(b.Backend.Paths, b.rekeyPaths()...) b.Backend.Paths = append(b.Backend.Paths, b.sealPaths()...) - b.Backend.Paths = append(b.Backend.Paths, b.pluginsCatalogPath()) - b.Backend.Paths = append(b.Backend.Paths, b.pluginsCatalogListPath()) + b.Backend.Paths = append(b.Backend.Paths, b.pluginsCatalogPaths()...) b.Backend.Paths = append(b.Backend.Paths, b.pluginsReloadPath()) b.Backend.Paths = append(b.Backend.Paths, b.auditPaths()...) b.Backend.Paths = append(b.Backend.Paths, b.mountPaths()...) @@ -255,21 +255,53 @@ func (b *SystemBackend) handleTidyLeases(ctx context.Context, req *logical.Reque return logical.RespondWithStatusCode(resp, req, http.StatusAccepted) } -func (b *SystemBackend) handlePluginCatalogList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - plugins, err := b.Core.pluginCatalog.List(ctx) +func (b *SystemBackend) handlePluginCatalogTypedList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginType, err := consts.ParsePluginType(d.Get("type").(string)) if err != nil { return nil, err } + plugins, err := b.Core.pluginCatalog.List(ctx, pluginType) + if err != nil { + return nil, err + } return logical.ListResponse(plugins), nil } +func (b *SystemBackend) handlePluginCatalogUntypedList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginsByType := make(map[string]interface{}) + for _, pluginType := range consts.PluginTypes { + plugins, err := b.Core.pluginCatalog.List(ctx, pluginType) + if err != nil { + return nil, err + } + if len(plugins) > 0 { + sort.Strings(plugins) + pluginsByType[pluginType.String()] = plugins + } + } + return &logical.Response{ + Data: pluginsByType, + }, nil +} + func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { pluginName := d.Get("name").(string) if pluginName == "" { return logical.ErrorResponse("missing plugin name"), nil } + pluginTypeStr := d.Get("type").(string) + if pluginTypeStr == "" { + // If the plugin type is not provided, list it as unknown so that we + // add it to the catalog and UpdatePlugins later will sort it. + pluginTypeStr = "unknown" + } + pluginType, err := consts.ParsePluginType(pluginTypeStr) + if err != nil { + return nil, err + } + sha256 := d.Get("sha256").(string) if sha256 == "" { sha256 = d.Get("sha_256").(string) @@ -302,7 +334,7 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, req *logi return logical.ErrorResponse("Could not decode SHA-256 value from Hex"), err } - err = b.Core.pluginCatalog.Set(ctx, pluginName, parts[0], args, env, sha256Bytes) + err = b.Core.pluginCatalog.Set(ctx, pluginName, pluginType, parts[0], args, env, sha256Bytes) if err != nil { return nil, err } @@ -315,7 +347,13 @@ func (b *SystemBackend) handlePluginCatalogRead(ctx context.Context, req *logica if pluginName == "" { return logical.ErrorResponse("missing plugin name"), nil } - plugin, err := b.Core.pluginCatalog.Get(ctx, pluginName) + + pluginType, err := consts.ParsePluginType(d.Get("type").(string)) + if err != nil { + return nil, err + } + + plugin, err := b.Core.pluginCatalog.Get(ctx, pluginName, pluginType) if err != nil { return nil, err } @@ -349,10 +387,13 @@ func (b *SystemBackend) handlePluginCatalogDelete(ctx context.Context, req *logi if pluginName == "" { return logical.ErrorResponse("missing plugin name"), nil } - err := b.Core.pluginCatalog.Delete(ctx, pluginName) + pluginType, err := consts.ParsePluginType(d.Get("type").(string)) if err != nil { return nil, err } + if err := b.Core.pluginCatalog.Delete(ctx, pluginName, pluginType); err != nil { + return nil, err + } return nil, nil } @@ -599,7 +640,6 @@ func mountInfo(entry *MountEntry) map[string]interface{} { "default_lease_ttl": int64(entry.Config.DefaultLeaseTTL.Seconds()), "max_lease_ttl": int64(entry.Config.MaxLeaseTTL.Seconds()), "force_no_cache": entry.Config.ForceNoCache, - "plugin_name": entry.Config.PluginName, } if rawVal, ok := entry.synthesizedConfigCache.Load("audit_non_hmac_request_keys"); ok { entryConfig["audit_non_hmac_request_keys"] = rawVal.([]string) @@ -735,15 +775,14 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d return logical.ErrorResponse( "backend type must be specified as a string"), logical.ErrInvalidRequest - case "plugin": - // Only set plugin-name if mount is of type plugin, with apiConfig.PluginName + // Only set plugin-name if mount is of type plugin, with apiConfig.PluginNameDeprecated // option taking precedence. switch { - case apiConfig.PluginName != "": - config.PluginName = apiConfig.PluginName + case apiConfig.PluginNameDeprecated != "": + logicalType = apiConfig.PluginNameDeprecated case pluginName != "": - config.PluginName = pluginName + logicalType = pluginName default: return logical.ErrorResponse( "plugin_name must be provided for plugin backend"), @@ -1620,15 +1659,14 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque return logical.ErrorResponse( "backend type must be specified as a string"), logical.ErrInvalidRequest - case "plugin": - // Only set plugin name if mount is of type plugin, with apiConfig.PluginName + // Only set plugin name if mount is of type plugin, with apiConfig.PluginNameDeprecated // option taking precedence. switch { - case apiConfig.PluginName != "": - config.PluginName = apiConfig.PluginName + case apiConfig.PluginNameDeprecated != "": + logicalType = apiConfig.PluginNameDeprecated case pluginName != "": - config.PluginName = pluginName + logicalType = pluginName default: return logical.ErrorResponse( "plugin_name must be provided for plugin backend"), @@ -3642,8 +3680,16 @@ This path responds to the following HTTP methods. "Lists the headers configured to be audited.", `Returns a list of headers that have been configured to be audited.`, }, + "plugin-catalog-list-all": { + "Lists all the plugins known to Vault", + ` +This path responds to the following HTTP methods. + LIST / + Returns a list of names of configured plugins. + `, + }, "plugin-catalog": { - "Configures the plugins known to vault", + "Configures the plugins known to Vault", ` This path responds to the following HTTP methods. LIST / @@ -3663,6 +3709,10 @@ This path responds to the following HTTP methods. "The name of the plugin", "", }, + "plugin-catalog_type": { + "The type of the plugin, may be auth, secret, or database", + "", + }, "plugin-catalog_sha-256": { `The SHA256 sum of the executable used in the command field. This should be HEX encoded.`, diff --git a/vault/logical_system_integ_test.go b/vault/logical_system_integ_test.go index e61518db803d..8c8c146cacea 100644 --- a/vault/logical_system_integ_test.go +++ b/vault/logical_system_integ_test.go @@ -11,6 +11,7 @@ import ( "github.com/go-test/deep" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/plugin" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/pluginutil" vaulthttp "github.com/hashicorp/vault/http" @@ -107,16 +108,16 @@ func TestSystemBackend_Plugin_MismatchType(t *testing.T) { core := cluster.Cores[0] - // Replace the plugin with a credential backend - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", []string{}, "") + // Add a credential backend with the same name + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, "") // Make a request to lazy load the now-credential plugin // and expect an error req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") req.ClientToken = core.Client.Token() _, err := core.HandleRequest(namespace.RootContext(nil), req) - if err == nil { - t.Fatalf("expected error due to mismatch on error type: %s", err) + if err != nil { + t.Fatalf("adding a same-named plugin of a different type should be no problem: %s", err) } // Sleep a bit before cleanup is called @@ -148,7 +149,7 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun core := cluster.Cores[0] // Remove the plugin from the catalog - req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/mock-plugin") + req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/database/mock-plugin") req.ClientToken = core.Client.Token() resp, err := core.HandleRequest(namespace.RootContext(nil), req) if err != nil || (resp != nil && resp.IsError()) { @@ -183,19 +184,15 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun switch btype { case logical.TypeLogical: // Add plugin back to the catalog - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", []string{}, "") + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, "") _, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{ - "type": "plugin", - "config": map[string]interface{}{ - "plugin_name": "mock-plugin", - }, + "type": "test", }) case logical.TypeCredential: // Add plugin back to the catalog - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", []string{}, "") + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, "") _, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{ - "type": "plugin", - "plugin_name": "mock-plugin", + "type": "test", }) } if err == nil { @@ -207,33 +204,33 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun func TestSystemBackend_Plugin_continueOnError(t *testing.T) { t.Run("secret", func(t *testing.T) { t.Run("sha256_mismatch", func(t *testing.T) { - testPlugin_continueOnError(t, logical.TypeLogical, true) + testPlugin_continueOnError(t, logical.TypeLogical, true, consts.PluginTypeSecrets) }) t.Run("missing_plugin", func(t *testing.T) { - testPlugin_continueOnError(t, logical.TypeLogical, false) + testPlugin_continueOnError(t, logical.TypeLogical, false, consts.PluginTypeSecrets) }) }) t.Run("auth", func(t *testing.T) { t.Run("sha256_mismatch", func(t *testing.T) { - testPlugin_continueOnError(t, logical.TypeCredential, true) + testPlugin_continueOnError(t, logical.TypeCredential, true, consts.PluginTypeCredential) }) t.Run("missing_plugin", func(t *testing.T) { - testPlugin_continueOnError(t, logical.TypeCredential, false) + testPlugin_continueOnError(t, logical.TypeCredential, false, consts.PluginTypeCredential) }) }) } -func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool) { +func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool, pluginType consts.PluginType) { cluster := testSystemBackendMock(t, 1, 1, btype) defer cluster.Cleanup() core := cluster.Cores[0] // Get the registered plugin - req := logical.TestRequest(t, logical.ReadOperation, "sys/plugins/catalog/mock-plugin") + req := logical.TestRequest(t, logical.ReadOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) req.ClientToken = core.Client.Token() resp, err := core.HandleRequest(namespace.RootContext(nil), req) if err != nil || resp == nil || (resp != nil && resp.IsError()) { @@ -247,7 +244,7 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc // Trigger a sha256 mismatch or missing plugin error if mismatch { - req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/catalog/mock-plugin") + req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/catalog/database/mock-plugin") req.Data = map[string]interface{}{ "sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216", "command": filepath.Base(command), @@ -288,9 +285,9 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc // Re-add the plugin to the catalog switch btype { case logical.TypeLogical: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", []string{}, cluster.TempDir) + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, cluster.TempDir) case logical.TypeCredential: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", []string{}, cluster.TempDir) + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, cluster.TempDir) } // Reload the plugin @@ -485,18 +482,11 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo switch backendType { case logical.TypeLogical: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", []string{}, tempDir) + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, tempDir) for i := 0; i < numMounts; i++ { // Alternate input styles for plugin_name on every other mount options := map[string]interface{}{ - "type": "plugin", - } - if (i+1)%2 == 0 { - options["config"] = map[string]interface{}{ - "plugin_name": "mock-plugin", - } - } else { - options["plugin_name"] = "mock-plugin" + "type": "mock-plugin", } resp, err := client.Logical().Write(fmt.Sprintf("sys/mounts/mock-%d", i), options) if err != nil { @@ -507,18 +497,11 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo } } case logical.TypeCredential: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", []string{}, tempDir) + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, tempDir) for i := 0; i < numMounts; i++ { // Alternate input styles for plugin_name on every other mount options := map[string]interface{}{ - "type": "plugin", - } - if (i+1)%2 == 0 { - options["config"] = map[string]interface{}{ - "plugin_name": "mock-plugin", - } - } else { - options["plugin_name"] = "mock-plugin" + "type": "mock-plugin", } resp, err := client.Logical().Write(fmt.Sprintf("sys/auth/mock-%d", i), options) if err != nil { @@ -546,7 +529,7 @@ func TestSystemBackend_Plugin_Env(t *testing.T) { func testSystemBackend_SingleCluster_Env(t *testing.T, env []string) *vault.TestCluster { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ - "plugin": plugin.Factory, + "test": plugin.Factory, }, } @@ -570,10 +553,9 @@ func testSystemBackend_SingleCluster_Env(t *testing.T, env []string) *vault.Test os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainEnv", env, tempDir) + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainEnv", env, tempDir) options := map[string]interface{}{ - "type": "plugin", - "plugin_name": "mock-plugin", + "type": "mock-plugin", } resp, err := client.Logical().Write("sys/mounts/mock", options) diff --git a/vault/logical_system_paths.go b/vault/logical_system_paths.go index dea1f5ad2dda..ab97aa04ebbe 100644 --- a/vault/logical_system_paths.go +++ b/vault/logical_system_paths.go @@ -558,54 +558,90 @@ func (b *SystemBackend) sealPaths() []*framework.Path { } } -func (b *SystemBackend) pluginsCatalogPath() *framework.Path { - return &framework.Path{ - Pattern: "plugins/catalog/(?P.+)", +func (b *SystemBackend) pluginsCatalogPaths() []*framework.Path { + return []*framework.Path{ + { + Pattern: "plugins/catalog/(?Pauth|database|secret)/?$", - Fields: map[string]*framework.FieldSchema{ - "name": &framework.FieldSchema{ - Type: framework.TypeString, - Description: strings.TrimSpace(sysHelp["plugin-catalog_name"][0]), - }, - "sha256": &framework.FieldSchema{ - Type: framework.TypeString, - Description: strings.TrimSpace(sysHelp["plugin-catalog_sha-256"][0]), - }, - "sha_256": &framework.FieldSchema{ - Type: framework.TypeString, - Description: strings.TrimSpace(sysHelp["plugin-catalog_sha-256"][0]), - }, - "command": &framework.FieldSchema{ - Type: framework.TypeString, - Description: strings.TrimSpace(sysHelp["plugin-catalog_command"][0]), - }, - "args": &framework.FieldSchema{ - Type: framework.TypeStringSlice, - Description: strings.TrimSpace(sysHelp["plugin-catalog_args"][0]), + Fields: map[string]*framework.FieldSchema{ + "type": &framework.FieldSchema{ + Type: framework.TypeString, + Description: strings.TrimSpace(sysHelp["plugin-catalog_type"][0]), + }, }, - "env": &framework.FieldSchema{ - Type: framework.TypeStringSlice, - Description: strings.TrimSpace(sysHelp["plugin-catalog_env"][0]), + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ListOperation: &framework.PathOperation{ + Callback: b.handlePluginCatalogTypedList, + Summary: "List the plugins in the catalog.", + }, }, + + HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]), + HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), }, + { + Pattern: "plugins/catalog(/(?Pauth|database|secret))?/(?P.+)", - Operations: map[logical.Operation]framework.OperationHandler{ - logical.UpdateOperation: &framework.PathOperation{ - Callback: b.handlePluginCatalogUpdate, - Summary: "Register a new plugin, or updates an existing one with the supplied name.", - }, - logical.DeleteOperation: &framework.PathOperation{ - Callback: b.handlePluginCatalogDelete, - Summary: "Remove the plugin with the given name.", + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: strings.TrimSpace(sysHelp["plugin-catalog_name"][0]), + }, + "type": &framework.FieldSchema{ + Type: framework.TypeString, + Description: strings.TrimSpace(sysHelp["plugin-catalog_type"][0]), + }, + "sha256": &framework.FieldSchema{ + Type: framework.TypeString, + Description: strings.TrimSpace(sysHelp["plugin-catalog_sha-256"][0]), + }, + "sha_256": &framework.FieldSchema{ + Type: framework.TypeString, + Description: strings.TrimSpace(sysHelp["plugin-catalog_sha-256"][0]), + }, + "command": &framework.FieldSchema{ + Type: framework.TypeString, + Description: strings.TrimSpace(sysHelp["plugin-catalog_command"][0]), + }, + "args": &framework.FieldSchema{ + Type: framework.TypeStringSlice, + Description: strings.TrimSpace(sysHelp["plugin-catalog_args"][0]), + }, + "env": &framework.FieldSchema{ + Type: framework.TypeStringSlice, + Description: strings.TrimSpace(sysHelp["plugin-catalog_env"][0]), + }, }, - logical.ReadOperation: &framework.PathOperation{ - Callback: b.handlePluginCatalogRead, - Summary: "Return the configuration data for the plugin with the given name.", + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.handlePluginCatalogUpdate, + Summary: "Register a new plugin, or updates an existing one with the supplied name.", + }, + logical.DeleteOperation: &framework.PathOperation{ + Callback: b.handlePluginCatalogDelete, + Summary: "Remove the plugin with the given name.", + }, + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handlePluginCatalogRead, + Summary: "Return the configuration data for the plugin with the given name.", + }, }, + + HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]), + HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), }, + { + Pattern: "plugins/catalog/?$", + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.handlePluginCatalogUntypedList, + }, - HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]), - HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), + HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog-list-all"][0]), + HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog-list-all"][1]), + }, } } @@ -637,24 +673,6 @@ func (b *SystemBackend) pluginsReloadPath() *framework.Path { } } -func (b *SystemBackend) pluginsCatalogListPath() *framework.Path { - return &framework.Path{ - Pattern: "plugins/catalog/?$", - - Fields: map[string]*framework.FieldSchema{}, - - Operations: map[logical.Operation]framework.OperationHandler{ - logical.ListOperation: &framework.PathOperation{ - Callback: b.handlePluginCatalogList, - Summary: "List the plugins in the catalog.", - }, - }, - - HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]), - HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), - } -} - func (b *SystemBackend) toolsPaths() []*framework.Path { return []*framework.Path{ { diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index dff2175f09db..0e7b8a0ba797 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -19,6 +19,7 @@ import ( hclog "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/salt" @@ -135,7 +136,6 @@ func TestSystemBackend_mounts(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["secret/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["secret/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": false, @@ -151,7 +151,6 @@ func TestSystemBackend_mounts(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["sys/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["sys/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": false, @@ -165,7 +164,6 @@ func TestSystemBackend_mounts(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["cubbyhole/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["cubbyhole/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": true, @@ -179,7 +177,6 @@ func TestSystemBackend_mounts(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["identity/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["identity/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": false, @@ -231,7 +228,6 @@ func TestSystemBackend_mount(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["secret/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["secret/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": false, @@ -247,7 +243,6 @@ func TestSystemBackend_mount(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["sys/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["sys/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": false, @@ -261,7 +256,6 @@ func TestSystemBackend_mount(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["cubbyhole/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["cubbyhole/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": true, @@ -275,7 +269,6 @@ func TestSystemBackend_mount(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["identity/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["identity/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": false, @@ -289,7 +282,6 @@ func TestSystemBackend_mount(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": int64(2100), "max_lease_ttl": int64(2700), - "plugin_name": "", "force_no_cache": false, }, "local": true, @@ -340,7 +332,7 @@ func TestSystemBackend_mount_invalid(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp.Data["error"] != `unknown backend type: "nope"` { + if resp.Data["error"] != `plugin not found in the catalog: nope` { t.Fatalf("bad: %v", resp) } } @@ -1430,7 +1422,6 @@ func TestSystemBackend_authTable(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": int64(0), "max_lease_ttl": int64(0), - "plugin_name": "", "force_no_cache": false, "token_type": "default-service", }, @@ -1447,7 +1438,7 @@ func TestSystemBackend_authTable(t *testing.T) { func TestSystemBackend_enableAuth(t *testing.T) { c, b, _ := testCoreSystemBackend(t) c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { - return &NoopBackend{}, nil + return &NoopBackend{BackendType: logical.TypeCredential}, nil } req := logical.TestRequest(t, logical.UpdateOperation, "auth/foo") @@ -1485,7 +1476,6 @@ func TestSystemBackend_enableAuth(t *testing.T) { "default_lease_ttl": int64(2100), "max_lease_ttl": int64(2700), "force_no_cache": false, - "plugin_name": "", "token_type": "default-service", }, "local": true, @@ -1499,7 +1489,6 @@ func TestSystemBackend_enableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": int64(0), "max_lease_ttl": int64(0), - "plugin_name": "", "force_no_cache": false, "token_type": "default-service", }, @@ -1521,7 +1510,7 @@ func TestSystemBackend_enableAuth_invalid(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp.Data["error"] != `unknown backend type: "nope"` { + if resp.Data["error"] != `plugin not found in the catalog: nope` { t.Fatalf("bad: %v", resp) } } @@ -1987,17 +1976,17 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { } c.pluginCatalog.directory = sym - req := logical.TestRequest(t, logical.ListOperation, "plugins/catalog/") + req := logical.TestRequest(t, logical.ListOperation, "plugins/catalog/database") resp, err := b.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } - if len(resp.Data["keys"].([]string)) != len(builtinplugins.Keys()) { - t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.Keys())) + if len(resp.Data["keys"].([]string)) != len(c.builtinRegistry.Keys(consts.PluginTypeDatabase)) { + t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.Registry.Keys(consts.PluginTypeDatabase))) } - req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/mysql-database-plugin") + req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/database/mysql-database-plugin") resp, err = b.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) @@ -2024,7 +2013,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { // Check we can only specify args in one of command or args. command := fmt.Sprintf("%s --test", filepath.Base(file.Name())) - req = logical.TestRequest(t, logical.UpdateOperation, "plugins/catalog/test-plugin") + req = logical.TestRequest(t, logical.UpdateOperation, "plugins/catalog/database/test-plugin") req.Data["args"] = []string{"--foo"} req.Data["sha_256"] = hex.EncodeToString([]byte{'1'}) req.Data["command"] = command @@ -2042,7 +2031,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { t.Fatalf("err: %v %v", err, resp.Error()) } - req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/test-plugin") + req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/database/test-plugin") resp, err = b.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) @@ -2061,13 +2050,13 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { } // Delete plugin - req = logical.TestRequest(t, logical.DeleteOperation, "plugins/catalog/test-plugin") + req = logical.TestRequest(t, logical.DeleteOperation, "plugins/catalog/database/test-plugin") resp, err = b.HandleRequest(namespace.RootContext(nil), req) if err != nil { t.Fatalf("err: %v", err) } - req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/test-plugin") + req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/database/test-plugin") resp, err = b.HandleRequest(namespace.RootContext(nil), req) if resp != nil || err != nil { t.Fatalf("expected nil response, plugin not deleted correctly got resp: %v, err: %v", resp, err) @@ -2262,7 +2251,6 @@ func TestSystemBackend_InternalUIMounts(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["secret"].(map[string]interface{})["secret/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["secret"].(map[string]interface{})["secret/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": false, @@ -2278,7 +2266,6 @@ func TestSystemBackend_InternalUIMounts(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["secret"].(map[string]interface{})["sys/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["secret"].(map[string]interface{})["sys/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": false, @@ -2292,7 +2279,6 @@ func TestSystemBackend_InternalUIMounts(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["secret"].(map[string]interface{})["cubbyhole/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["secret"].(map[string]interface{})["cubbyhole/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": true, @@ -2306,7 +2292,6 @@ func TestSystemBackend_InternalUIMounts(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": resp.Data["secret"].(map[string]interface{})["identity/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["secret"].(map[string]interface{})["identity/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), - "plugin_name": "", "force_no_cache": false, }, "local": false, @@ -2321,7 +2306,6 @@ func TestSystemBackend_InternalUIMounts(t *testing.T) { "default_lease_ttl": int64(0), "max_lease_ttl": int64(0), "force_no_cache": false, - "plugin_name": "", "token_type": "default-service", }, "type": "token", diff --git a/vault/mount.go b/vault/mount.go index fcf481a2d332..e3f5adc995be 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -11,6 +11,7 @@ import ( "time" "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/builtin/plugin" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/namespace" @@ -60,6 +61,7 @@ const ( systemMountType = "system" identityMountType = "identity" cubbyholeMountType = "cubbyhole" + pluginMountType = "plugin" MountTableUpdateStorage = true MountTableNoUpdateStorage = false @@ -221,7 +223,7 @@ type MountConfig struct { DefaultLeaseTTL time.Duration `json:"default_lease_ttl" structs:"default_lease_ttl" mapstructure:"default_lease_ttl"` // Override for global default MaxLeaseTTL time.Duration `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` // Override for global default ForceNoCache bool `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"` // Override for global default - PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` + PluginNameDeprecated string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` AuditNonHMACRequestKeys []string `json:"audit_non_hmac_request_keys,omitempty" structs:"audit_non_hmac_request_keys" mapstructure:"audit_non_hmac_request_keys"` AuditNonHMACResponseKeys []string `json:"audit_non_hmac_response_keys,omitempty" structs:"audit_non_hmac_response_keys" mapstructure:"audit_non_hmac_response_keys"` ListingVisibility ListingVisibilityType `json:"listing_visibility,omitempty" structs:"listing_visibility" mapstructure:"listing_visibility"` @@ -234,7 +236,7 @@ type APIMountConfig struct { DefaultLeaseTTL string `json:"default_lease_ttl" structs:"default_lease_ttl" mapstructure:"default_lease_ttl"` MaxLeaseTTL string `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` ForceNoCache bool `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"` - PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` + PluginNameDeprecated string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` AuditNonHMACRequestKeys []string `json:"audit_non_hmac_request_keys,omitempty" structs:"audit_non_hmac_request_keys" mapstructure:"audit_non_hmac_request_keys"` AuditNonHMACResponseKeys []string `json:"audit_non_hmac_response_keys,omitempty" structs:"audit_non_hmac_response_keys" mapstructure:"audit_non_hmac_response_keys"` ListingVisibility ListingVisibilityType `json:"listing_visibility,omitempty" structs:"listing_visibility" mapstructure:"listing_visibility"` @@ -417,12 +419,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora var backend logical.Backend sysView := c.mountEntrySysView(entry) - conf := make(map[string]string) - if entry.Config.PluginName != "" { - conf["plugin_name"] = entry.Config.PluginName - } - // Consider having plugin name under entry.Options backend, err = c.newLogicalBackend(ctx, entry, sysView, view) if err != nil { return err @@ -433,8 +430,10 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora // Check for the correct backend type backendType := backend.Type() - if entry.Type == "plugin" && backendType != logical.TypeLogical { - return fmt.Errorf("cannot mount %q of type %q as a logical backend", entry.Config.PluginName, backendType) + if backendType != logical.TypeLogical { + if entry.Type != "kv" && entry.Type != "system" && entry.Type != "cubbyhole" { + return fmt.Errorf(`unknown backend type: "%s"`, entry.Type) + } } addPathCheckers(c, entry, backend, viewPath) @@ -1021,15 +1020,10 @@ func (c *Core) setupMounts(ctx context.Context) error { var backend logical.Backend // Create the new backend sysView := c.mountEntrySysView(entry) - // Set up conf to pass in plugin_name - conf := make(map[string]string) - if entry.Config.PluginName != "" { - conf["plugin_name"] = entry.Config.PluginName - } backend, err = c.newLogicalBackend(ctx, entry, sysView, view) if err != nil { c.logger.Error("failed to create mount entry", "path", entry.Path, "error", err) - if entry.Type == "plugin" { + if !c.builtinRegistry.Contains(entry.Type, consts.PluginTypeSecrets) { // If we encounter an error instantiating the backend due to an error, // skip backend initialization but register the entry to the mount table // to preserve storage and path. @@ -1045,8 +1039,11 @@ func (c *Core) setupMounts(ctx context.Context) error { { // Check for the correct backend type backendType := backend.Type() - if entry.Type == "plugin" && backendType != logical.TypeLogical { - return fmt.Errorf("cannot mount %q of type %q as a logical backend", entry.Config.PluginName, backendType) + + if backendType != logical.TypeLogical { + if entry.Type != "kv" && entry.Type != "system" && entry.Type != "cubbyhole" { + return fmt.Errorf(`unknown backend type: "%s"`, entry.Type) + } } addPathCheckers(c, entry, backend, barrierPath) @@ -1116,9 +1113,10 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView if alias, ok := mountAliases[t]; ok { t = alias } + f, ok := c.logicalBackends[t] if !ok { - return nil, fmt.Errorf("unknown backend type: %q", t) + f = plugin.Factory } // Set up conf to pass in plugin_name @@ -1126,10 +1124,16 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView for k, v := range entry.Options { conf[k] = v } - if entry.Config.PluginName != "" { - conf["plugin_name"] = entry.Config.PluginName + + switch { + case entry.Type == "plugin": + conf["plugin_name"] = entry.Config.PluginNameDeprecated + default: + conf["plugin_name"] = t } + conf["plugin_type"] = consts.PluginTypeSecrets.String() + backendLogger := c.baseLogger.Named(fmt.Sprintf("secrets.%s.%s", t, entry.Accessor)) c.AddLogger(backendLogger) config := &logical.BackendConfig{ diff --git a/vault/mount_test.go b/vault/mount_test.go index b3009f7bb38e..0a28e68f743c 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -487,7 +487,9 @@ func TestCore_MountTable_UpgradeToTyped(t *testing.T) { } c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { - return &NoopBackend{}, nil + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil } me = &MountEntry{ diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index f7a25619fe32..ad5fc2b4d999 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -10,12 +10,16 @@ import ( "strings" "sync" + log "github.com/hashicorp/go-hclog" + multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/errwrap" - "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" + backendplugin "github.com/hashicorp/vault/logical/plugin" ) var ( @@ -28,16 +32,24 @@ var ( // to be registered to the catalog before they can be used in backends. Builtin // plugins are automatically detected and included in the catalog. type PluginCatalog struct { - catalogView *BarrierView - directory string + builtinRegistry BuiltinRegistry + catalogView *BarrierView + directory string lock sync.RWMutex } -func (c *Core) setupPluginCatalog() error { +func (c *Core) setupPluginCatalog(ctx context.Context) error { c.pluginCatalog = &PluginCatalog{ - catalogView: NewBarrierView(c.barrier, pluginCatalogPath), - directory: c.pluginDirectory, + builtinRegistry: c.builtinRegistry, + catalogView: NewBarrierView(c.barrier, pluginCatalogPath), + directory: c.pluginDirectory, + } + + // Run upgrade if untyped plugins exist + err := c.pluginCatalog.UpgradePlugins(ctx, c.logger) + if err != nil { + c.logger.Error("error while upgrading plugin storage", "error", err) } if c.logger.IsInfo() { @@ -47,25 +59,145 @@ func (c *Core) setupPluginCatalog() error { return nil } +// getPluginTypeFromUnknown will attempt to run the plugin to determine the +// type. It will first attempt to run as a database plugin then a backend +// plugin. Both of these will be run in metadata mode. +func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, plugin *pluginutil.PluginRunner) (consts.PluginType, error) { + { + // Attempt to run as database plugin + client, err := dbplugin.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true) + if err == nil { + // Close the client and cleanup the plugin process + client.Close() + return consts.PluginTypeDatabase, nil + } + } + + { + // Attempt to run as backend plugin + client, err := backendplugin.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true) + if err == nil { + err := client.Setup(ctx, &logical.BackendConfig{}) + if err != nil { + return consts.PluginTypeUnknown, err + } + + backendType := client.Type() + client.Cleanup(ctx) + + switch backendType { + case logical.TypeCredential: + return consts.PluginTypeCredential, nil + case logical.TypeLogical: + return consts.PluginTypeSecrets, nil + } + } + } + + return consts.PluginTypeUnknown, nil +} + +// UpdatePlugins will loop over all the plugins of unknown type and attempt to +// upgrade them to typed plugins +func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) error { + c.lock.Lock() + defer c.lock.Unlock() + + // If the directory isn't set we can skip the upgrade attempt + if c.directory == "" { + return nil + } + + // List plugins from old location + pluginsRaw, err := c.catalogView.List(ctx, "") + if err != nil { + return err + } + plugins := make([]string, 0, len(pluginsRaw)) + for _, p := range pluginsRaw { + if !strings.HasSuffix(p, "/") { + plugins = append(plugins, p) + } + } + + logger.Info("upgrading plugin information", "plugins", plugins) + + var retErr error + for _, pluginName := range plugins { + pluginRaw, err := c.catalogView.Get(ctx, pluginName) + if err != nil { + retErr = multierror.Append(errwrap.Wrapf("failed to load plugin entry: {{err}}", err)) + continue + } + + plugin := new(pluginutil.PluginRunner) + if err := jsonutil.DecodeJSON(pluginRaw.Value, plugin); err != nil { + retErr = multierror.Append(errwrap.Wrapf("failed to decode plugin entry: {{err}}", err)) + continue + } + + // prepend the plugin directory to the command + cmdOld := plugin.Command + plugin.Command = filepath.Join(c.directory, plugin.Command) + + pluginType, err := c.getPluginTypeFromUnknown(ctx, plugin) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err)) + continue + } + if pluginType == consts.PluginTypeUnknown { + retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: plugin of unknown type", pluginName)) + continue + } + + // Upgrade the storage + err = c.setInternal(ctx, pluginName, pluginType, cmdOld, plugin.Args, plugin.Env, plugin.Sha256) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err)) + continue + } + + err = c.catalogView.Delete(ctx, pluginName) + if err != nil { + logger.Error("could not remove plugin", "plugin", pluginName, "error", err) + } + + logger.Info("upgraded plugin type", "plugin", pluginName, "type", pluginType.String()) + } + + return retErr +} + // Get retrieves a plugin with the specified name from the catalog. It first // looks for external plugins with this name and then looks for builtin plugins. // It returns a PluginRunner or an error if no plugin was found. -func (c *PluginCatalog) Get(ctx context.Context, name string) (*pluginutil.PluginRunner, error) { +func (c *PluginCatalog) Get(ctx context.Context, name string, pluginType consts.PluginType) (*pluginutil.PluginRunner, error) { c.lock.RLock() defer c.lock.RUnlock() // If the directory isn't set only look for builtin plugins. if c.directory != "" { // Look for external plugins in the barrier - out, err := c.catalogView.Get(ctx, name) + out, err := c.catalogView.Get(ctx, pluginType.String()+"/"+name) if err != nil { return nil, errwrap.Wrapf(fmt.Sprintf("failed to retrieve plugin %q: {{err}}", name), err) } + if out == nil { + // Also look for external plugins under what their name would have been if they + // were registered before plugin types existed. + out, err = c.catalogView.Get(ctx, name) + if err != nil { + return nil, errwrap.Wrapf(fmt.Sprintf("failed to retrieve plugin %q: {{err}}", name), err) + } + } if out != nil { entry := new(pluginutil.PluginRunner) if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { return nil, errwrap.Wrapf("failed to decode plugin entry: {{err}}", err) } + if entry.Type != pluginType && entry.Type != consts.PluginTypeUnknown { + return nil, nil + } // prepend the plugin directory to the command entry.Command = filepath.Join(c.directory, entry.Command) @@ -74,9 +206,10 @@ func (c *PluginCatalog) Get(ctx context.Context, name string) (*pluginutil.Plugi } } // Look for builtin plugins - if factory, ok := builtinplugins.Get(name); ok { + if factory, ok := c.builtinRegistry.Get(name, pluginType); ok { return &pluginutil.PluginRunner{ Name: name, + Type: pluginType, Builtin: true, BuiltinFactory: factory, }, nil @@ -87,7 +220,7 @@ func (c *PluginCatalog) Get(ctx context.Context, name string) (*pluginutil.Plugi // Set registers a new external plugin with the catalog, or updates an existing // external plugin. It takes the name, command and SHA256 of the plugin. -func (c *PluginCatalog) Set(ctx context.Context, name, command string, args []string, env []string, sha256 []byte) error { +func (c *PluginCatalog) Set(ctx context.Context, name string, pluginType consts.PluginType, command string, args []string, env []string, sha256 []byte) error { if c.directory == "" { return ErrDirectoryNotConfigured } @@ -102,6 +235,10 @@ func (c *PluginCatalog) Set(ctx context.Context, name, command string, args []st c.lock.Lock() defer c.lock.Unlock() + return c.setInternal(ctx, name, pluginType, command, args, env, sha256) +} + +func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, command string, args []string, env []string, sha256 []byte) error { // Best effort check to make sure the command isn't breaking out of the // configured plugin directory. commandFull := filepath.Join(c.directory, command) @@ -118,8 +255,28 @@ func (c *PluginCatalog) Set(ctx context.Context, name, command string, args []st return errors.New("can not execute files outside of configured plugin directory") } + // If the plugin type is unknown, we want to attempt to determine the type + if pluginType == consts.PluginTypeUnknown { + // entryTmp should only be used for the below type check, it uses the + // full command instead of the relative command. + entryTmp := &pluginutil.PluginRunner{ + Name: name, + Command: commandFull, + Args: args, + Env: env, + Sha256: sha256, + Builtin: false, + } + + pluginType, err = c.getPluginTypeFromUnknown(ctx, entryTmp) + if err != nil || pluginType == consts.PluginTypeUnknown { + return errors.New("unable to determine plugin type") + } + } + entry := &pluginutil.PluginRunner{ Name: name, + Type: pluginType, Command: command, Args: args, Env: env, @@ -133,7 +290,7 @@ func (c *PluginCatalog) Set(ctx context.Context, name, command string, args []st } logicalEntry := logical.StorageEntry{ - Key: name, + Key: pluginType.String() + "/" + name, Value: buf, } if err := c.catalogView.Put(ctx, &logicalEntry); err != nil { @@ -144,16 +301,23 @@ func (c *PluginCatalog) Set(ctx context.Context, name, command string, args []st // Delete is used to remove an external plugin from the catalog. Builtin plugins // can not be deleted. -func (c *PluginCatalog) Delete(ctx context.Context, name string) error { +func (c *PluginCatalog) Delete(ctx context.Context, name string, pluginType consts.PluginType) error { c.lock.Lock() defer c.lock.Unlock() - return c.catalogView.Delete(ctx, name) + // Check the name under which the plugin exists, but if it's unfound, don't return any error. + pluginKey := pluginType.String() + "/" + name + out, err := c.catalogView.Get(ctx, pluginKey) + if err != nil || out == nil { + pluginKey = name + } + + return c.catalogView.Delete(ctx, pluginKey) } // List returns a list of all the known plugin names. If an external and builtin // plugin share the same name, only one instance of the name will be returned. -func (c *PluginCatalog) List(ctx context.Context) ([]string, error) { +func (c *PluginCatalog) List(ctx context.Context, pluginType consts.PluginType) ([]string, error) { c.lock.RLock() defer c.lock.RUnlock() @@ -163,14 +327,27 @@ func (c *PluginCatalog) List(ctx context.Context) ([]string, error) { return nil, err } - // Get the keys for builtin plugins - builtinKeys := builtinplugins.Keys() + // Get the builtin plugins. + builtinKeys := c.builtinRegistry.Keys(pluginType) - // Use a map to unique the two lists + // Use a map to unique the two lists. mapKeys := make(map[string]bool) + pluginTypePrefix := pluginType.String() + "/" + for _, plugin := range keys { - mapKeys[plugin] = true + + // Only list user-added plugins if they're of the given type. + if entry, err := c.Get(ctx, plugin, pluginType); err == nil && entry != nil { + + // Some keys will be prepended with the plugin type, but other ones won't. + // Users don't expect to see the plugin type, so we need to strip that here. + idx := strings.Index(plugin, pluginTypePrefix) + if idx == 0 { + plugin = plugin[len(pluginTypePrefix):] + } + mapKeys[plugin] = true + } } for _, plugin := range builtinKeys { diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index 7222959cc954..c9818d36475d 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/pluginutil" ) @@ -24,16 +25,17 @@ func TestPluginCatalog_CRUD(t *testing.T) { core.pluginCatalog.directory = sym // Get builtin plugin - p, err := core.pluginCatalog.Get(context.Background(), "mysql-database-plugin") + p, err := core.pluginCatalog.Get(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase) if err != nil { t.Fatalf("unexpected error %v", err) } expectedBuiltin := &pluginutil.PluginRunner{ Name: "mysql-database-plugin", + Type: consts.PluginTypeDatabase, Builtin: true, } - expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin") + expectedBuiltin.BuiltinFactory, _ = builtinplugins.Registry.Get("mysql-database-plugin", consts.PluginTypeDatabase) if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { t.Fatal("expected BuiltinFactory did not match actual") @@ -52,19 +54,20 @@ func TestPluginCatalog_CRUD(t *testing.T) { defer file.Close() command := fmt.Sprintf("%s", filepath.Base(file.Name())) - err = core.pluginCatalog.Set(context.Background(), "mysql-database-plugin", command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'}) + err = core.pluginCatalog.Set(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'}) if err != nil { t.Fatal(err) } // Get the plugin - p, err = core.pluginCatalog.Get(context.Background(), "mysql-database-plugin") + p, err = core.pluginCatalog.Get(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase) if err != nil { t.Fatalf("unexpected error %v", err) } expected := &pluginutil.PluginRunner{ Name: "mysql-database-plugin", + Type: consts.PluginTypeDatabase, Command: filepath.Join(sym, filepath.Base(file.Name())), Args: []string{"--test"}, Env: []string{"FOO=BAR"}, @@ -77,22 +80,23 @@ func TestPluginCatalog_CRUD(t *testing.T) { } // Delete the plugin - err = core.pluginCatalog.Delete(context.Background(), "mysql-database-plugin") + err = core.pluginCatalog.Delete(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase) if err != nil { t.Fatalf("unexpected err: %v", err) } // Get builtin plugin - p, err = core.pluginCatalog.Get(context.Background(), "mysql-database-plugin") + p, err = core.pluginCatalog.Get(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase) if err != nil { t.Fatalf("unexpected error %v", err) } expectedBuiltin = &pluginutil.PluginRunner{ Name: "mysql-database-plugin", + Type: consts.PluginTypeDatabase, Builtin: true, } - expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin") + expectedBuiltin.BuiltinFactory, _ = builtinplugins.Registry.Get("mysql-database-plugin", consts.PluginTypeDatabase) if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { t.Fatal("expected BuiltinFactory did not match actual") @@ -115,11 +119,11 @@ func TestPluginCatalog_List(t *testing.T) { core.pluginCatalog.directory = sym // Get builtin plugins and sort them - builtinKeys := builtinplugins.Keys() + builtinKeys := builtinplugins.Registry.Keys(consts.PluginTypeDatabase) sort.Strings(builtinKeys) // List only builtin plugins - plugins, err := core.pluginCatalog.List(context.Background()) + plugins, err := core.pluginCatalog.List(context.Background(), consts.PluginTypeDatabase) if err != nil { t.Fatalf("unexpected error %v", err) } @@ -142,23 +146,24 @@ func TestPluginCatalog_List(t *testing.T) { defer file.Close() command := filepath.Base(file.Name()) - err = core.pluginCatalog.Set(context.Background(), "mysql-database-plugin", command, []string{"--test"}, []string{}, []byte{'1'}) + err = core.pluginCatalog.Set(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, command, []string{"--test"}, []string{}, []byte{'1'}) if err != nil { t.Fatal(err) } // Set another plugin - err = core.pluginCatalog.Set(context.Background(), "aaaaaaa", command, []string{"--test"}, []string{}, []byte{'1'}) + err = core.pluginCatalog.Set(context.Background(), "aaaaaaa", consts.PluginTypeDatabase, command, []string{"--test"}, []string{}, []byte{'1'}) if err != nil { t.Fatal(err) } // List the plugins - plugins, err = core.pluginCatalog.List(context.Background()) + plugins, err = core.pluginCatalog.List(context.Background(), consts.PluginTypeDatabase) if err != nil { t.Fatalf("unexpected error %v", err) } + // plugins has a test-added plugin called "aaaaaaa" that is not built in if len(plugins) != len(builtinKeys)+1 { t.Fatalf("unexpected length of plugin list, expected %d, got %d", len(builtinKeys)+1, len(plugins)) } diff --git a/vault/plugin_reload.go b/vault/plugin_reload.go index c0fb715c3ee8..1bb2cdc63698 100644 --- a/vault/plugin_reload.go +++ b/vault/plugin_reload.go @@ -45,14 +45,12 @@ func (c *Core) reloadMatchingPluginMounts(ctx context.Context, mounts []string) continue } - if entry.Type == "plugin" { - err := c.reloadBackendCommon(ctx, entry, isAuth) - if err != nil { - errors = multierror.Append(errors, errwrap.Wrapf(fmt.Sprintf("cannot reload plugin on %q: {{err}}", mount), err)) - continue - } - c.logger.Info("successfully reloaded plugin", "plugin", entry.Config.PluginName, "path", entry.Path) + err := c.reloadBackendCommon(ctx, entry, isAuth) + if err != nil { + errors = multierror.Append(errors, errwrap.Wrapf(fmt.Sprintf("cannot reload plugin on %q: {{err}}", mount), err)) + continue } + c.logger.Info("successfully reloaded plugin", "plugin", entry.Type, "path", entry.Path) } return errors } @@ -77,8 +75,7 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) erro if ns.ID != entry.Namespace().ID { continue } - - if entry.Config.PluginName == pluginName && entry.Type == "plugin" { + if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginNameDeprecated == pluginName) { err := c.reloadBackendCommon(ctx, entry, false) if err != nil { return err @@ -94,7 +91,7 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) erro continue } - if entry.Config.PluginName == pluginName && entry.Type == "plugin" { + if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginNameDeprecated == pluginName) { err := c.reloadBackendCommon(ctx, entry, true) if err != nil { return err diff --git a/vault/request_handling.go b/vault/request_handling.go index eb4326660783..d55dfcaca747 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -713,7 +713,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp case "plugin": // If we are a plugin type and the plugin name is "kv" check the // mount entry options. - if matchingMountEntry.Config.PluginName == "kv" && (matchingMountEntry.Options == nil || matchingMountEntry.Options["leased_passthrough"] != "true") { + if matchingMountEntry.Config.PluginNameDeprecated == "kv" && (matchingMountEntry.Options == nil || matchingMountEntry.Options["leased_passthrough"] != "true") { registerLease = false resp.Secret.Renewable = false } diff --git a/vault/router_test.go b/vault/router_test.go index c08e631a021f..287fca27ca5b 100644 --- a/vault/router_test.go +++ b/vault/router_test.go @@ -29,6 +29,7 @@ type NoopBackend struct { Invalidations []string DefaultLeaseTTL time.Duration MaxLeaseTTL time.Duration + BackendType logical.BackendType } func (n *NoopBackend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { @@ -104,7 +105,10 @@ func (n *NoopBackend) Initialize(ctx context.Context) error { } func (n *NoopBackend) Type() logical.BackendType { - return logical.TypeLogical + if n.BackendType == logical.TypeUnknown { + return logical.TypeLogical + } + return n.BackendType } func TestRouter_Mount(t *testing.T) { diff --git a/vault/testing.go b/vault/testing.go index d5237fc63a8e..86475c4afce7 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -33,15 +33,18 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/net/http2" - cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/logging" "github.com/hashicorp/vault/helper/reload" "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/physical" + dbMysql "github.com/hashicorp/vault/plugins/database/mysql" + dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql" "github.com/mitchellh/go-testing-interface" physInmem "github.com/hashicorp/vault/physical/inmem" @@ -113,17 +116,19 @@ func TestCoreWithConfig(t testing.T, conf *CoreConfig) *Core { // specified seal for testing. func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core { conf := &CoreConfig{ - Seal: testSeal, - EnableUI: false, - EnableRaw: enableRaw, + Seal: testSeal, + EnableUI: false, + EnableRaw: enableRaw, + BuiltinRegistry: NewMockBuiltinRegistry(), } return TestCoreWithSealAndUI(t, conf) } func TestCoreUI(t testing.T, enableUI bool) *Core { conf := &CoreConfig{ - EnableUI: enableUI, - EnableRaw: true, + EnableUI: enableUI, + EnableRaw: true, + BuiltinRegistry: NewMockBuiltinRegistry(), } return TestCoreWithSealAndUI(t, conf) } @@ -176,6 +181,7 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo noopBackends["noop"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { b := new(framework.Backend) b.Setup(ctx, config) + b.BackendType = logical.TypeCredential return b, nil } noopBackends["http"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { @@ -207,6 +213,7 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo CredentialBackends: credentialBackends, DisableMlock: true, Logger: logger, + BuiltinRegistry: NewMockBuiltinRegistry(), } return conf @@ -348,7 +355,7 @@ func TestDynamicSystemView(c *Core) *dynamicSystemView { // TestAddTestPlugin registers the testFunc as part of the plugin command to the // plugin catalog. If provided, uses tmpDir as the plugin directory. -func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string, env []string, tempDir string) { +func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, testFunc string, env []string, tempDir string) { file, err := os.Open(os.Args[0]) if err != nil { t.Fatal(err) @@ -410,7 +417,7 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string, env []string c.pluginCatalog.directory = fullPath args := []string{fmt.Sprintf("--test.run=%s", testFunc)} - err = c.pluginCatalog.Set(context.Background(), name, fileName, args, env, sum) + err = c.pluginCatalog.Set(context.Background(), name, pluginType, fileName, args, env, sum) if err != nil { t.Fatal(err) } @@ -638,7 +645,7 @@ func (n *rawHTTP) Setup(ctx context.Context, config *logical.BackendConfig) erro } func (n *rawHTTP) Type() logical.BackendType { - return logical.TypeUnknown + return logical.TypeLogical } func GenerateRandBytes(length int) ([]byte, error) { @@ -1177,6 +1184,7 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te DisableMlock: true, EnableUI: true, EnableRaw: true, + BuiltinRegistry: NewMockBuiltinRegistry(), } if base != nil { @@ -1192,6 +1200,9 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te coreConfig.DisableSealWrap = base.DisableSealWrap coreConfig.DevLicenseDuration = base.DevLicenseDuration coreConfig.DisableCache = base.DisableCache + if base.BuiltinRegistry != nil { + coreConfig.BuiltinRegistry = base.BuiltinRegistry + } if !coreConfig.DisableMlock { base.DisableMlock = false @@ -1262,7 +1273,7 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te cores := []*Core{} coreConfigs := []*CoreConfig{} for i := 0; i < numCores; i++ { - localConfig := coreConfig.Clone() + localConfig := *coreConfig localConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port) if localConfig.ClusterAddr != "" { localConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port+105) @@ -1279,12 +1290,12 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te localConfig.LicensingConfig = testGetLicensingConfig(pubKey) - c, err := NewCore(localConfig) + c, err := NewCore(&localConfig) if err != nil { t.Fatalf("err: %v", err) } cores = append(cores, c) - coreConfigs = append(coreConfigs, localConfig) + coreConfigs = append(coreConfigs, &localConfig) if opts != nil && opts.HandlerFunc != nil { handlers[i] = opts.HandlerFunc(&HandlerProperties{ Core: c, @@ -1483,3 +1494,56 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te return &testCluster } + +func NewMockBuiltinRegistry() *mockBuiltinRegistry { + return &mockBuiltinRegistry{ + forTesting: map[string]consts.PluginType{ + "mysql-database-plugin": consts.PluginTypeDatabase, + "postgresql-database-plugin": consts.PluginTypeDatabase, + }, + } +} + +type mockBuiltinRegistry struct { + forTesting map[string]consts.PluginType +} + +func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) { + testPluginType, ok := m.forTesting[name] + if !ok { + return nil, false + } + if pluginType != testPluginType { + return nil, false + } + if name == "postgresql-database-plugin" { + return dbPostgres.New, true + } + return dbMysql.New(dbMysql.MetadataLen, dbMysql.MetadataLen, dbMysql.UsernameLen), true +} + +// Keys only supports getting a realistic list of the keys for database plugins. +func (m *mockBuiltinRegistry) Keys(pluginType consts.PluginType) []string { + if pluginType != consts.PluginTypeDatabase { + return []string{} + } + /* + This is a hard-coded reproduction of the db plugin keys in helper/builtinplugins/registry.go. + The registry isn't directly used because it causes import cycles. + */ + return []string{ + "mysql-database-plugin", + "mysql-aurora-database-plugin", + "mysql-rds-database-plugin", + "mysql-legacy-database-plugin", + "postgresql-database-plugin", + "mssql-database-plugin", + "cassandra-database-plugin", + "mongodb-database-plugin", + "hana-database-plugin", + } +} + +func (m *mockBuiltinRegistry) Contains(name string, pluginType consts.PluginType) bool { + return false +} diff --git a/vault/token_store.go b/vault/token_store.go index da2d7efbb16d..2d2b1a1f3fa1 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -561,6 +561,7 @@ func NewTokenStore(ctx context.Context, logger log.Logger, core *Core, config *l salt.DefaultLocation, }, }, + BackendType: logical.TypeCredential, } t.Backend.Paths = append(t.Backend.Paths, t.paths()...)