Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend Persistent*Run behavior to allow multiple hooks throughout the execution chain #1142

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cobra.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ var templateFuncs = template.FuncMap{

var initializers []func()

// EnablePersistentRunOverride ensures Persistent*Run* functions in childs override their parents
var EnablePersistentRunOverride = true

// EnablePrefixMatching allows to set automatic prefix matching. Automatic prefix matching can be a dangerous thing
// to automatically enable in CLI tools.
// Set this to true to enable it.
Expand Down
148 changes: 115 additions & 33 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ type Command struct {
// PersistentPostRunE: PersistentPostRun but returns an error.
PersistentPostRunE func(cmd *Command, args []string) error

// persistentPreRunHooks are executed before the command or one of its children are executed
persistentPreRunHooks []func(cmd *Command, args []string) error
// preRunHooks are executed before the command is executed
preRunHooks []func(cmd *Command, args []string) error
// runHooks are executed when the command is executed
runHooks []func(cmd *Command, args []string) error
// postRunHooks are executed after the command has executed
postRunHooks []func(cmd *Command, args []string) error
// persistentPostRunHooks are executed after the command or one of its children have executed
persistentPostRunHooks []func(cmd *Command, args []string) error

// SilenceErrors is an option to quiet errors down stream.
SilenceErrors bool

Expand Down Expand Up @@ -816,51 +827,82 @@ func (c *Command) execute(a []string) (err error) {
return err
}

for p := c; p != nil; p = p.Parent() {
if p.PersistentPreRunE != nil {
if err := p.PersistentPreRunE(c, argWoFlags); err != nil {
return err
}
break
} else if p.PersistentPreRun != nil {
p.PersistentPreRun(c, argWoFlags)
break
}
}
// Allocate the hooks execution chain for the current command
var hooks []func(cmd *Command, args []string) error

// First append the PreRun* hooks
hooks = append(hooks, c.preRunHooks...)
if c.PreRunE != nil {
if err := c.PreRunE(c, argWoFlags); err != nil {
return err
}
hooks = append(hooks, c.PreRunE)
} else if c.PreRun != nil {
c.PreRun(c, argWoFlags)
hooks = append(hooks, wrapVoidHook(c.PreRun))
}

if err := c.validateRequiredFlags(); err != nil {
return err
}
if c.RunE != nil {
if err := c.RunE(c, argWoFlags); err != nil {
// Include the validateRequiredFlags() logic as a hook
// to be executed before running the main Run hooks.
hooks = append(hooks, func(cmd *Command, args []string) error {
if err := cmd.validateRequiredFlags(); err != nil {
return err
}
} else {
c.Run(c, argWoFlags)
return nil
})

// Append the main Run* hooks
hooks = append(hooks, c.runHooks...)
if c.RunE != nil {
hooks = append(hooks, c.RunE)
} else if c.Run != nil {
hooks = append(hooks, wrapVoidHook(c.Run))
}

// Append the PostRun* hooks
hooks = append(hooks, c.postRunHooks...)
if c.PostRunE != nil {
if err := c.PostRunE(c, argWoFlags); err != nil {
return err
}
hooks = append(hooks, c.PostRunE)
} else if c.PostRun != nil {
c.PostRun(c, argWoFlags)
hooks = append(hooks, wrapVoidHook(c.PostRun))
}

// Lastly find and append/prepend the Persistent*Run hooks.
// Setting EnablePersistentRunOverride to true (default) preserves
// the previous behavior/concern where childs should override their parents.
// Any hooks registered through OnPersistent*Run will always
// be executed and cannot be overriden.
hasPersistentPreRunFromStruct := false
hasPersistentPostRunFromStruct := false
for p := c; p != nil; p = p.Parent() {
if p.PersistentPostRunE != nil {
if err := p.PersistentPostRunE(c, argWoFlags); err != nil {
return err
// Find and prepend the PersistentPreRun* hooks as defined on the commands
if !hasPersistentPreRunFromStruct || !EnablePersistentRunOverride {
if p.PersistentPreRunE != nil {
hooks = prependHook(&hooks, p.PersistentPreRunE)
hasPersistentPreRunFromStruct = true
} else if p.PersistentPreRun != nil {
hooks = prependHook(&hooks, wrapVoidHook(p.PersistentPreRun))
hasPersistentPreRunFromStruct = true
}
}
// Find and append the PersistentPostRun* hooks as defined on the commands
if !hasPersistentPostRunFromStruct || !EnablePersistentRunOverride {
if p.PersistentPostRunE != nil {
hooks = append(hooks, p.PersistentPostRunE)
hasPersistentPostRunFromStruct = true
} else if p.PersistentPostRun != nil {
hooks = append(hooks, wrapVoidHook(p.PersistentPostRun))
hasPersistentPostRunFromStruct = true
}
break
} else if p.PersistentPostRun != nil {
p.PersistentPostRun(c, argWoFlags)
break
}

// Hooks registered through OnPersistent*Run should always be executed
// Prepend the PersistentPreRun* hooks
hooks = append(p.persistentPreRunHooks, hooks...)
// Append the PersistentPostRun* hooks
hooks = append(hooks, p.persistentPostRunHooks...)
}

// Execute the hooks execution chain:
for _, x := range hooks {
if err := x(c, argWoFlags); err != nil {
return err
}
}

Expand All @@ -873,6 +915,46 @@ func (c *Command) preRun() {
}
}

// prependHook prepends a hook onto the array of hooks
func prependHook(hooks *[]func(cmd *Command, args []string) error, hook ...func(cmd *Command, args []string) error) []func(cmd *Command, args []string) error {
return append(hook, *hooks...)
}

// wrapVoidHook wraps a void hook into a function having the return error signature
func wrapVoidHook(hook func(cmd *Command, args []string)) func(cmd *Command, args []string) error {
return func(cmd *Command, args []string) error {
hook(cmd, args)
return nil
}
}

// OnPersistentPreRun registers one or more hooks on the command to be executed
// before the command or one of its children are executed
func (c *Command) OnPersistentPreRun(f ...func(cmd *Command, args []string) error) {
c.persistentPreRunHooks = append(c.persistentPreRunHooks, f...)
}

// OnPreRun registers one or more hooks on the command to be executed before the command is executed
func (c *Command) OnPreRun(f ...func(cmd *Command, args []string) error) {
c.preRunHooks = append(c.preRunHooks, f...)
}

// OnRun registers one or more hooks on the command to be executed when the command is executed
func (c *Command) OnRun(f ...func(cmd *Command, args []string) error) {
c.runHooks = append(c.runHooks, f...)
}

// OnPostRun registers one or more hooks on the command to be executed after the command has executed
func (c *Command) OnPostRun(f ...func(cmd *Command, args []string) error) {
c.postRunHooks = append(c.postRunHooks, f...)
}

// OnPersistentPostRun register one or more hooks on the command to be executed
// after the command or one of its children have executed
func (c *Command) OnPersistentPostRun(f ...func(cmd *Command, args []string) error) {
c.persistentPostRunHooks = append(c.persistentPostRunHooks, f...)
}

// ExecuteContext is the same as Execute(), but sets the ctx on the command.
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions.
func (c *Command) ExecuteContext(ctx context.Context) error {
Expand Down
126 changes: 114 additions & 12 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,23 @@ func TestPersistentHooks(t *testing.T) {
childPersPostArgs string
)

var (
persParentPersPreArgs string
persParentPreArgs string
persParentRunArgs string
persParentPostArgs string
persParentPersPostArgs string
)

var (
persChildPersPreArgs string
persChildPreArgs string
persChildPreArgs2 string
persChildRunArgs string
persChildPostArgs string
persChildPersPostArgs string
)

parentCmd := &Command{
Use: "parent",
PersistentPreRun: func(_ *Command, args []string) {
Expand Down Expand Up @@ -1371,21 +1388,65 @@ func TestPersistentHooks(t *testing.T) {
}
parentCmd.AddCommand(childCmd)

parentCmd.OnPersistentPreRun(func(_ *Command, args []string) error {
persParentPersPreArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnPreRun(func(_ *Command, args []string) error {
persParentPreArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnRun(func(_ *Command, args []string) error {
persParentRunArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnPostRun(func(_ *Command, args []string) error {
persParentPostArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnPersistentPostRun(func(_ *Command, args []string) error {
persParentPersPostArgs = strings.Join(args, " ")
return nil
})

childCmd.OnPersistentPreRun(func(_ *Command, args []string) error {
persChildPersPreArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPreRun(func(_ *Command, args []string) error {
persChildPreArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPreRun(func(_ *Command, args []string) error {
persChildPreArgs2 = strings.Join(args, " ") + " three"
return nil
})
childCmd.OnRun(func(_ *Command, args []string) error {
persChildRunArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPostRun(func(_ *Command, args []string) error {
persChildPostArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPersistentPostRun(func(_ *Command, args []string) error {
persChildPersPostArgs = strings.Join(args, " ")
return nil
})

output, err := executeCommand(parentCmd, "child", "one", "two")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

// TODO: currently PersistenPreRun* defined in parent does not
// run if the matchin child subcommand has PersistenPreRun.
// If the behavior changes (https://github.com/spf13/cobra/issues/252)
// this test must be fixed.
if parentPersPreArgs != "" {
if EnablePersistentRunOverride && parentPersPreArgs != "" {
t.Errorf("Expected blank parentPersPreArgs, got %q", parentPersPreArgs)
}
if !EnablePersistentRunOverride && parentPersPreArgs != "one two" {
t.Errorf("Expected parentPersPreArgs %q, got %q", "one two", parentPersPreArgs)
}
if parentPreArgs != "" {
t.Errorf("Expected blank parentPreArgs, got %q", parentPreArgs)
}
Expand All @@ -1395,14 +1456,12 @@ func TestPersistentHooks(t *testing.T) {
if parentPostArgs != "" {
t.Errorf("Expected blank parentPostArgs, got %q", parentPostArgs)
}
// TODO: currently PersistenPostRun* defined in parent does not
// run if the matchin child subcommand has PersistenPostRun.
// If the behavior changes (https://github.com/spf13/cobra/issues/252)
// this test must be fixed.
if parentPersPostArgs != "" {
if EnablePersistentRunOverride && parentPersPostArgs != "" {
t.Errorf("Expected blank parentPersPostArgs, got %q", parentPersPostArgs)
}

if !EnablePersistentRunOverride && parentPersPostArgs != "one two" {
t.Errorf("Expected parentPersPostArgs %q, got %q", "one two", parentPersPostArgs)
}
if childPersPreArgs != "one two" {
t.Errorf("Expected childPersPreArgs %q, got %q", "one two", childPersPreArgs)
}
Expand All @@ -1418,6 +1477,49 @@ func TestPersistentHooks(t *testing.T) {
if childPersPostArgs != "one two" {
t.Errorf("Expected childPersPostArgs %q, got %q", "one two", childPersPostArgs)
}

// Test On*Run hooks

if persParentPersPreArgs != "one two" {
t.Errorf("Expected persParentPersPreArgs %q, got %q", "one two", persParentPersPreArgs)
}
if persParentPreArgs != "" {
t.Errorf("Expected blank persParentPreArgs, got %q", persParentPreArgs)
}
if persParentRunArgs != "" {
t.Errorf("Expected blank persParentRunArgs, got %q", persParentRunArgs)
}
if persParentPostArgs != "" {
t.Errorf("Expected blank persParentPostArg, got %q", persParentPostArgs)
}
if persParentPersPostArgs != "one two" {
t.Errorf("Expected persParentPersPostArgs %q, got %q", "one two", persParentPersPostArgs)
}

if persChildPersPreArgs != "one two" {
t.Errorf("Expected persChildPersPreArgs %q, got %q", "one two", persChildPersPreArgs)
}
if persChildPreArgs != "one two" {
t.Errorf("Expected persChildPreArgs %q, got %q", "one two", persChildPreArgs)
}
if persChildPreArgs2 != "one two three" {
t.Errorf("Expected persChildPreArgs %q, got %q", "one two three", persChildPreArgs2)
}
if persChildRunArgs != "one two" {
t.Errorf("Expected persChildRunArgs %q, got %q", "one two", persChildRunArgs)
}
if persChildPostArgs != "one two" {
t.Errorf("Expected persChildPostArgs %q, got %q", "one two", persChildPostArgs)
}
if persChildPersPostArgs != "one two" {
t.Errorf("Expected persChildPersPostArgs %q, got %q", "one two", persChildPersPostArgs)
}
}

func TestPersistentHooksWoOverride(t *testing.T) {
EnablePersistentRunOverride = false
TestPersistentHooks(t)
EnablePersistentRunOverride = true
}

// Related to https://github.com/spf13/cobra/issues/521.
Expand Down