Skip to content

Commit

Permalink
Support reloading database plugins across multiple mounts (#24512)
Browse files Browse the repository at this point in the history
* Support reloading database plugins across multiple mounts
* Add clarifying comment to MountEntry.Path field
* Tests: Replace non-parallelisable t.Setenv with plugin env settings
  • Loading branch information
tomhjp authored Jan 8, 2024
1 parent ee0ccea commit 6e537bb
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 168 deletions.
271 changes: 133 additions & 138 deletions builtin/logical/database/backend_test.go

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions builtin/logical/database/versioning_large_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ func TestPlugin_lifecycle(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", []string{})
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", []string{})
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", []string{})
env := []string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)}
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", env)
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", env)
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", env)

config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
Expand Down
5 changes: 2 additions & 3 deletions builtin/plugin/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,8 @@ func testConfig(t *testing.T, pluginCmd string) (*logical.BackendConfig, func())
},
}

os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)

vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd, []string{})
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd,
[]string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)})

return config, func() {
cluster.Cleanup()
Expand Down
6 changes: 6 additions & 0 deletions changelog/24512.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
```release-note:change
plugins: Add a warning to the response from sys/plugins/reload/backend if no plugins were reloaded.
```
```release-note:improvement
secrets/database: Support reloading named database plugins using the sys/plugins/reload/backend API endpoint.
```
6 changes: 3 additions & 3 deletions http/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package http

import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"reflect"
Expand Down Expand Up @@ -55,10 +56,9 @@ func getPluginClusterAndCore(t *testing.T, logger log.Logger) (*vault.TestCluste
cores := cluster.Cores
core := cores[0]

os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)

vault.TestWaitActive(benchhelpers.TBtoT(t), core.Core)
vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain", []string{})
vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain",
[]string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)})

// Mount the mock plugin
err = core.Client.Sys().Mount("mock", &api.MountInput{
Expand Down
32 changes: 32 additions & 0 deletions vault/external_tests/plugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,9 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
if resp.Data["reload_id"] == nil {
t.Fatal("no reload_id in response")
}
if len(resp.Warnings) != 0 {
t.Fatal(resp.Warnings)
}

for i := 0; i < 2; i++ {
// Ensure internal backed value is reset
Expand All @@ -578,6 +581,35 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
}
}

func TestSystemBackend_PluginReload_WarningIfNoneReloaded(t *testing.T) {
cluster := testSystemBackendMock(t, 1, 2, logical.TypeLogical, "v5")
defer cluster.Cleanup()

core := cluster.Cores[0]
client := core.Client

for _, backendType := range []logical.BackendType{logical.TypeLogical, logical.TypeCredential} {
t.Run(backendType.String(), func(t *testing.T) {
// Perform plugin reload
resp, err := client.Logical().Write("sys/plugins/reload/backend", map[string]any{
"plugin": "does-not-exist",
})
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.Data["reload_id"] == nil {
t.Fatal("no reload_id in response")
}
if len(resp.Warnings) == 0 {
t.Fatal("expected warning")
}
})
}
}

// testSystemBackendMock returns a systemBackend with the desired number
// of mounted mock plugin backends. numMounts alternates between different
// ways of providing the plugin_name.
Expand Down
25 changes: 16 additions & 9 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,32 +738,39 @@ func (b *SystemBackend) handlePluginReloadUpdate(ctx context.Context, req *logic
return logical.ErrorResponse("plugin or mounts must be provided"), nil
}

resp := logical.Response{
Data: map[string]interface{}{
"reload_id": req.ID,
},
}

if pluginName != "" {
err := b.Core.reloadMatchingPlugin(ctx, pluginName)
reloaded, err := b.Core.reloadMatchingPlugin(ctx, pluginName)
if err != nil {
return nil, err
}
if reloaded == 0 {
if scope == globalScope {
resp.AddWarning("no plugins were reloaded locally (but they may be reloaded on other nodes)")
} else {
resp.AddWarning("no plugins were reloaded")
}
}
} else if len(pluginMounts) > 0 {
err := b.Core.reloadMatchingPluginMounts(ctx, pluginMounts)
if err != nil {
return nil, err
}
}

r := logical.Response{
Data: map[string]interface{}{
"reload_id": req.ID,
},
}

if scope == globalScope {
err := handleGlobalPluginReload(ctx, b.Core, req.ID, pluginName, pluginMounts)
if err != nil {
return nil, err
}
return logical.RespondWithStatusCode(&r, req, http.StatusAccepted)
return logical.RespondWithStatusCode(&resp, req, http.StatusAccepted)
}
return &r, nil
return &resp, nil
}

func (b *SystemBackend) handlePluginRuntimeCatalogUpdate(ctx context.Context, _ *logical.Request, d *framework.FieldData) (*logical.Response, error) {
Expand Down
2 changes: 1 addition & 1 deletion vault/mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ const mountStateUnmounting = "unmounting"
// MountEntry is used to represent a mount table entry
type MountEntry struct {
Table string `json:"table"` // The table it belongs to
Path string `json:"path"` // Mount Path
Path string `json:"path"` // Mount Path, as provided in the mount API call but with a trailing slash, i.e. no auth/ or namespace prefix.
Type string `json:"type"` // Logical backend Type. NB: This is the plugin name, e.g. my-vault-plugin, NOT plugin type (e.g. auth).
Description string `json:"description"` // User-provided description
UUID string `json:"uuid"` // Barrier view UUID
Expand Down
47 changes: 36 additions & 11 deletions vault/plugin_reload.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,36 +70,60 @@ func (c *Core) reloadMatchingPluginMounts(ctx context.Context, mounts []string)
return errors
}

// reloadPlugin reloads all mounted backends that are of
// plugin pluginName (name of the plugin as registered in
// the plugin catalog).
func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) error {
// reloadMatchingPlugin reloads all mounted backends that are named pluginName
// (name of the plugin as registered in the plugin catalog). It returns the
// number of plugins that were reloaded and an error if any.
func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) (reloaded int, err error) {
c.mountsLock.RLock()
defer c.mountsLock.RUnlock()
c.authLock.RLock()
defer c.authLock.RUnlock()

ns, err := namespace.FromContext(ctx)
if err != nil {
return err
return reloaded, err
}

// Filter mount entries that only matches the plugin name
for _, entry := range c.mounts.Entries {
// We dont reload mounts that are not in the same namespace
if ns.ID != entry.Namespace().ID {
continue
}

if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
err := c.reloadBackendCommon(ctx, entry, false)
if err != nil {
return err
return reloaded, err
}
reloaded++
c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "version", entry.Version)
} else if entry.Type == "database" {
// The combined database plugin is itself a secrets engine, but
// knowledge of whether a database plugin is in use within a particular
// mount is internal to the combined database plugin's storage, so
// we delegate the reload request with an internally routed request.
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: entry.Path + "reload/" + pluginName,
}
resp, err := c.router.Route(ctx, req)
if err != nil {
return reloaded, err
}
if resp == nil {
return reloaded, fmt.Errorf("failed to reload %q database plugin(s) mounted under %s", pluginName, entry.Path)
}
if resp.IsError() {
return reloaded, fmt.Errorf("failed to reload %q database plugin(s) mounted under %s: %s", pluginName, entry.Path, resp.Error())
}

if count, ok := resp.Data["count"].(int); ok && count > 0 {
c.logger.Info("successfully reloaded database plugin(s)", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "connections", resp.Data["connections"])
reloaded += count
}
c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "path", entry.Path, "version", entry.Version)
}
}

// Filter auth mount entries that ony matches the plugin name
for _, entry := range c.auth.Entries {
// We dont reload mounts that are not in the same namespace
if ns.ID != entry.Namespace().ID {
Expand All @@ -109,13 +133,14 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) erro
if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
err := c.reloadBackendCommon(ctx, entry, true)
if err != nil {
return err
return reloaded, err
}
reloaded++
c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.Version)
}
}

return nil
return reloaded, nil
}

// reloadBackendCommon is a generic method to reload a backend provided a
Expand Down

0 comments on commit 6e537bb

Please sign in to comment.