Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added plugin reload function to api #8777

Merged
merged 11 commits into from
May 4, 2020
28 changes: 28 additions & 0 deletions api/sys_plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,34 @@ func (c *Sys) DeregisterPlugin(i *DeregisterPluginInput) error {
return err
}

//ReloadPluginInput is used as input fo the ReloadPlugin function.
type ReloadPluginInput struct {
//Plugin is the name of the plugin to reload, as registered in the plugin catalog
Plugin string `json:"plugin"`

//Mounts is the array of string mount paths of the plugin backends to reload
Mounts []string `json:"mounts"`
}

// ReloadPlugin reloads mounted plugin backends
func (c *Sys) ReloadPlugin(i *ReloadPluginInput) error {
path := "/v1/sys/plugins/reload/backend"
req := c.c.NewRequest(http.MethodPut, path)

if err := req.SetJSONBody(i); err != nil {
return err
}

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()

resp, err := c.c.RawRequestWithContext(ctx, req)
if err == nil {
defer resp.Body.Close()
}
return err
}

// catalogPathByType is a helper to construct the proper API path by plugin type
func catalogPathByType(pluginType consts.PluginType, name string) string {
path := fmt.Sprintf("/v1/sys/plugins/catalog/%s/%s", pluginType, name)
Expand Down
104 changes: 104 additions & 0 deletions command/plugin_reload.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package command

import (
"fmt"
"strings"

"github.com/hashicorp/vault/api"
"github.com/mitchellh/cli"
"github.com/posener/complete"
)

var _ cli.Command = (*PluginReloadCommand)(nil)
var _ cli.CommandAutocomplete = (*PluginReloadCommand)(nil)

type PluginReloadCommand struct {
*BaseCommand
plugin string
mounts []string
}

func (c *PluginReloadCommand) Synopsis() string {
return "Reload mounted plugin backend"
}

func (c *PluginReloadCommand) Help() string {
helpText := `
Usage: vault plugin reload [options]

Reloads mounted plugin backends. Either the plugin name or the desired plugin
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sentence is a bit confusing (specifically around "plugin backend mounts (mounts)"). Can this be paraphrased?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check the new version, English is not my native language, so I'm open to suggestions :)

backend mounts (mounts) must be provided, but not both.

Reload the plugin named my-custom-plugin:

$ vault plugin reload --plugin=my-custom-plugin

` + c.Flags().Help()

return strings.TrimSpace(helpText)
}

func (c *PluginReloadCommand) Flags() *FlagSets {
set := c.flagSet(FlagSetHTTP)

f := set.NewFlagSet("Command Options")

f.StringVar(&StringVar{
Name: "plugin",
Target: &c.plugin,
Completion: complete.PredictAnything,
Usage: "The name of the plugin to reload, as registered in the plugin catalog.",
})

f.StringSliceVar(&StringSliceVar{
Name: "mounts",
Target: &c.mounts,
Completion: complete.PredictAnything,
Usage: "Array or comma-separated string mount paths of the plugin backends to reload.",
})

return set
}

func (c *PluginReloadCommand) AutocompleteArgs() complete.Predictor {
return nil
}

func (c *PluginReloadCommand) AutocompleteFlags() complete.Flags {
return c.Flags().Completions()
}

func (c *PluginReloadCommand) Run(args []string) int {
f := c.Flags()

if err := f.Parse(args); err != nil {
c.UI.Error(err.Error())
return 1
}

switch {
case c.plugin == "" && len(c.mounts) == 0:
c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args)))
return 1
case c.plugin != "" && len(c.mounts) > 0:
c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args)))
return 1
}

client, err := c.Client()
if err != nil {
c.UI.Error(err.Error())
return 2
}

if err := client.Sys().ReloadPlugin(&api.ReloadPluginInput{
Plugin: c.plugin,
Mounts: c.mounts,
}); err != nil {
c.UI.Error(fmt.Sprintf("Error reloading plugin/mounts: %s", err))
return 2
}

c.UI.Output(fmt.Sprintf("Success! Reloaded plugin/mounts: %s%s", c.plugin, c.mounts))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we split this into two print statements based on which one is set?

Copy link
Contributor Author

@olicuzo olicuzo May 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

return 0
}
110 changes: 110 additions & 0 deletions command/plugin_reload_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package command

import (
"strings"
"testing"

"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/mitchellh/cli"
)

func testPluginReloadCommand(tb testing.TB) (*cli.MockUi, *PluginReloadCommand) {
tb.Helper()

ui := cli.NewMockUi()
return ui, &PluginReloadCommand{
BaseCommand: &BaseCommand{
UI: ui,
},
}
}

func TestPluginReloadCommand_Run(t *testing.T) {
t.Parallel()

cases := []struct {
name string
args []string
out string
code int
}{
{
"not_enough_args",
nil,
"Not enough arguments",
1,
},
{
"too_many_args",
[]string{"-plugin", "foo", "-mounts", "bar"},
"Too many arguments",
1,
},
}

for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

client, closer := testVaultServer(t)
defer closer()

ui, cmd := testPluginReloadCommand(t)
cmd.client = client

args := append([]string{}, tc.args...)
code := cmd.Run(args)
if code != tc.code {
t.Errorf("expected %d to be %d", code, tc.code)
}

combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, tc.out) {
t.Errorf("expected %q to contain %q", combined, tc.out)
}
})
}

t.Run("integration", func(t *testing.T) {
t.Parallel()

pluginDir, cleanup := testPluginDir(t)
defer cleanup(t)

client, _, closer := testVaultServerPluginDir(t, pluginDir)
defer closer()

pluginName := "my-plugin"
_, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential)

ui, cmd := testPluginReloadCommand(t)
cmd.client = client

if err := client.Sys().RegisterPlugin(&api.RegisterPluginInput{
Name: pluginName,
Type: consts.PluginTypeCredential,
Command: pluginName,
SHA256: sha256Sum,
}); err != nil {
t.Fatal(err)
}

code := cmd.Run([]string{
"-plugin", pluginName,
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}

expected := "Success! Reloaded plugin/mounts: "
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, expected) {
t.Errorf("expected %q to contain %q", combined, expected)
}

})

}