diff --git a/cobra.go b/cobra.go index d01becc8f..45f679e17 100644 --- a/cobra.go +++ b/cobra.go @@ -39,6 +39,9 @@ var templateFuncs = template.FuncMap{ var initializers []func() +// EnablePersistentRunOverride ensures Persistent*Run* functions in childs override their parents +var EnablePersistentRunOverride = true + // EnablePrefixMatching allows to set automatic prefix matching. Automatic prefix matching can be a dangerous thing // to automatically enable in CLI tools. // Set this to true to enable it. diff --git a/command.go b/command.go index 5f1cacccb..0f3e1cc25 100644 --- a/command.go +++ b/command.go @@ -118,6 +118,17 @@ type Command struct { // PersistentPostRunE: PersistentPostRun but returns an error. PersistentPostRunE func(cmd *Command, args []string) error + // persistentPreRunHooks are executed before the command or one of its children are executed + persistentPreRunHooks []func(cmd *Command, args []string) error + // preRunHooks are executed before the command is executed + preRunHooks []func(cmd *Command, args []string) error + // runHooks are executed when the command is executed + runHooks []func(cmd *Command, args []string) error + // postRunHooks are executed after the command has executed + postRunHooks []func(cmd *Command, args []string) error + // persistentPostRunHooks are executed after the command or one of its children have executed + persistentPostRunHooks []func(cmd *Command, args []string) error + // SilenceErrors is an option to quiet errors down stream. SilenceErrors bool @@ -816,51 +827,82 @@ func (c *Command) execute(a []string) (err error) { return err } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPreRunE != nil { - if err := p.PersistentPreRunE(c, argWoFlags); err != nil { - return err - } - break - } else if p.PersistentPreRun != nil { - p.PersistentPreRun(c, argWoFlags) - break - } - } + // Allocate the hooks execution chain for the current command + var hooks []func(cmd *Command, args []string) error + + // First append the PreRun* hooks + hooks = append(hooks, c.preRunHooks...) if c.PreRunE != nil { - if err := c.PreRunE(c, argWoFlags); err != nil { - return err - } + hooks = append(hooks, c.PreRunE) } else if c.PreRun != nil { - c.PreRun(c, argWoFlags) + hooks = append(hooks, wrapVoidHook(c.PreRun)) } - if err := c.validateRequiredFlags(); err != nil { - return err - } - if c.RunE != nil { - if err := c.RunE(c, argWoFlags); err != nil { + // Include the validateRequiredFlags() logic as a hook + // to be executed before running the main Run hooks. + hooks = append(hooks, func(cmd *Command, args []string) error { + if err := cmd.validateRequiredFlags(); err != nil { return err } - } else { - c.Run(c, argWoFlags) + return nil + }) + + // Append the main Run* hooks + hooks = append(hooks, c.runHooks...) + if c.RunE != nil { + hooks = append(hooks, c.RunE) + } else if c.Run != nil { + hooks = append(hooks, wrapVoidHook(c.Run)) } + + // Append the PostRun* hooks + hooks = append(hooks, c.postRunHooks...) if c.PostRunE != nil { - if err := c.PostRunE(c, argWoFlags); err != nil { - return err - } + hooks = append(hooks, c.PostRunE) } else if c.PostRun != nil { - c.PostRun(c, argWoFlags) + hooks = append(hooks, wrapVoidHook(c.PostRun)) } + + // Lastly find and append/prepend the Persistent*Run hooks. + // Setting EnablePersistentRunOverride to true (default) preserves + // the previous behavior/concern where childs should override their parents. + // Any hooks registered through OnPersistent*Run will always + // be executed and cannot be overriden. + hasPersistentPreRunFromStruct := false + hasPersistentPostRunFromStruct := false for p := c; p != nil; p = p.Parent() { - if p.PersistentPostRunE != nil { - if err := p.PersistentPostRunE(c, argWoFlags); err != nil { - return err + // Find and prepend the PersistentPreRun* hooks as defined on the commands + if !hasPersistentPreRunFromStruct || !EnablePersistentRunOverride { + if p.PersistentPreRunE != nil { + hooks = prependHook(&hooks, p.PersistentPreRunE) + hasPersistentPreRunFromStruct = true + } else if p.PersistentPreRun != nil { + hooks = prependHook(&hooks, wrapVoidHook(p.PersistentPreRun)) + hasPersistentPreRunFromStruct = true + } + } + // Find and append the PersistentPostRun* hooks as defined on the commands + if !hasPersistentPostRunFromStruct || !EnablePersistentRunOverride { + if p.PersistentPostRunE != nil { + hooks = append(hooks, p.PersistentPostRunE) + hasPersistentPostRunFromStruct = true + } else if p.PersistentPostRun != nil { + hooks = append(hooks, wrapVoidHook(p.PersistentPostRun)) + hasPersistentPostRunFromStruct = true } - break - } else if p.PersistentPostRun != nil { - p.PersistentPostRun(c, argWoFlags) - break + } + + // Hooks registered through OnPersistent*Run should always be executed + // Prepend the PersistentPreRun* hooks + hooks = append(p.persistentPreRunHooks, hooks...) + // Append the PersistentPostRun* hooks + hooks = append(hooks, p.persistentPostRunHooks...) + } + + // Execute the hooks execution chain: + for _, x := range hooks { + if err := x(c, argWoFlags); err != nil { + return err } } @@ -873,6 +915,46 @@ func (c *Command) preRun() { } } +// prependHook prepends a hook onto the array of hooks +func prependHook(hooks *[]func(cmd *Command, args []string) error, hook ...func(cmd *Command, args []string) error) []func(cmd *Command, args []string) error { + return append(hook, *hooks...) +} + +// wrapVoidHook wraps a void hook into a function having the return error signature +func wrapVoidHook(hook func(cmd *Command, args []string)) func(cmd *Command, args []string) error { + return func(cmd *Command, args []string) error { + hook(cmd, args) + return nil + } +} + +// OnPersistentPreRun registers one or more hooks on the command to be executed +// before the command or one of its children are executed +func (c *Command) OnPersistentPreRun(f ...func(cmd *Command, args []string) error) { + c.persistentPreRunHooks = append(c.persistentPreRunHooks, f...) +} + +// OnPreRun registers one or more hooks on the command to be executed before the command is executed +func (c *Command) OnPreRun(f ...func(cmd *Command, args []string) error) { + c.preRunHooks = append(c.preRunHooks, f...) +} + +// OnRun registers one or more hooks on the command to be executed when the command is executed +func (c *Command) OnRun(f ...func(cmd *Command, args []string) error) { + c.runHooks = append(c.runHooks, f...) +} + +// OnPostRun registers one or more hooks on the command to be executed after the command has executed +func (c *Command) OnPostRun(f ...func(cmd *Command, args []string) error) { + c.postRunHooks = append(c.postRunHooks, f...) +} + +// OnPersistentPostRun register one or more hooks on the command to be executed +// after the command or one of its children have executed +func (c *Command) OnPersistentPostRun(f ...func(cmd *Command, args []string) error) { + c.persistentPostRunHooks = append(c.persistentPostRunHooks, f...) +} + // ExecuteContext is the same as Execute(), but sets the ctx on the command. // Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions. func (c *Command) ExecuteContext(ctx context.Context) error { diff --git a/command_test.go b/command_test.go index 16cc41b4c..2095377e1 100644 --- a/command_test.go +++ b/command_test.go @@ -1332,6 +1332,23 @@ func TestPersistentHooks(t *testing.T) { childPersPostArgs string ) + var ( + persParentPersPreArgs string + persParentPreArgs string + persParentRunArgs string + persParentPostArgs string + persParentPersPostArgs string + ) + + var ( + persChildPersPreArgs string + persChildPreArgs string + persChildPreArgs2 string + persChildRunArgs string + persChildPostArgs string + persChildPersPostArgs string + ) + parentCmd := &Command{ Use: "parent", PersistentPreRun: func(_ *Command, args []string) { @@ -1371,6 +1388,52 @@ func TestPersistentHooks(t *testing.T) { } parentCmd.AddCommand(childCmd) + parentCmd.OnPersistentPreRun(func(_ *Command, args []string) error { + persParentPersPreArgs = strings.Join(args, " ") + return nil + }) + parentCmd.OnPreRun(func(_ *Command, args []string) error { + persParentPreArgs = strings.Join(args, " ") + return nil + }) + parentCmd.OnRun(func(_ *Command, args []string) error { + persParentRunArgs = strings.Join(args, " ") + return nil + }) + parentCmd.OnPostRun(func(_ *Command, args []string) error { + persParentPostArgs = strings.Join(args, " ") + return nil + }) + parentCmd.OnPersistentPostRun(func(_ *Command, args []string) error { + persParentPersPostArgs = strings.Join(args, " ") + return nil + }) + + childCmd.OnPersistentPreRun(func(_ *Command, args []string) error { + persChildPersPreArgs = strings.Join(args, " ") + return nil + }) + childCmd.OnPreRun(func(_ *Command, args []string) error { + persChildPreArgs = strings.Join(args, " ") + return nil + }) + childCmd.OnPreRun(func(_ *Command, args []string) error { + persChildPreArgs2 = strings.Join(args, " ") + " three" + return nil + }) + childCmd.OnRun(func(_ *Command, args []string) error { + persChildRunArgs = strings.Join(args, " ") + return nil + }) + childCmd.OnPostRun(func(_ *Command, args []string) error { + persChildPostArgs = strings.Join(args, " ") + return nil + }) + childCmd.OnPersistentPostRun(func(_ *Command, args []string) error { + persChildPersPostArgs = strings.Join(args, " ") + return nil + }) + output, err := executeCommand(parentCmd, "child", "one", "two") if output != "" { t.Errorf("Unexpected output: %v", output) @@ -1378,14 +1441,12 @@ func TestPersistentHooks(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - - // TODO: currently PersistenPreRun* defined in parent does not - // run if the matchin child subcommand has PersistenPreRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - if parentPersPreArgs != "" { + if EnablePersistentRunOverride && parentPersPreArgs != "" { t.Errorf("Expected blank parentPersPreArgs, got %q", parentPersPreArgs) } + if !EnablePersistentRunOverride && parentPersPreArgs != "one two" { + t.Errorf("Expected parentPersPreArgs %q, got %q", "one two", parentPersPreArgs) + } if parentPreArgs != "" { t.Errorf("Expected blank parentPreArgs, got %q", parentPreArgs) } @@ -1395,14 +1456,12 @@ func TestPersistentHooks(t *testing.T) { if parentPostArgs != "" { t.Errorf("Expected blank parentPostArgs, got %q", parentPostArgs) } - // TODO: currently PersistenPostRun* defined in parent does not - // run if the matchin child subcommand has PersistenPostRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - if parentPersPostArgs != "" { + if EnablePersistentRunOverride && parentPersPostArgs != "" { t.Errorf("Expected blank parentPersPostArgs, got %q", parentPersPostArgs) } - + if !EnablePersistentRunOverride && parentPersPostArgs != "one two" { + t.Errorf("Expected parentPersPostArgs %q, got %q", "one two", parentPersPostArgs) + } if childPersPreArgs != "one two" { t.Errorf("Expected childPersPreArgs %q, got %q", "one two", childPersPreArgs) } @@ -1418,6 +1477,49 @@ func TestPersistentHooks(t *testing.T) { if childPersPostArgs != "one two" { t.Errorf("Expected childPersPostArgs %q, got %q", "one two", childPersPostArgs) } + + // Test On*Run hooks + + if persParentPersPreArgs != "one two" { + t.Errorf("Expected persParentPersPreArgs %q, got %q", "one two", persParentPersPreArgs) + } + if persParentPreArgs != "" { + t.Errorf("Expected blank persParentPreArgs, got %q", persParentPreArgs) + } + if persParentRunArgs != "" { + t.Errorf("Expected blank persParentRunArgs, got %q", persParentRunArgs) + } + if persParentPostArgs != "" { + t.Errorf("Expected blank persParentPostArg, got %q", persParentPostArgs) + } + if persParentPersPostArgs != "one two" { + t.Errorf("Expected persParentPersPostArgs %q, got %q", "one two", persParentPersPostArgs) + } + + if persChildPersPreArgs != "one two" { + t.Errorf("Expected persChildPersPreArgs %q, got %q", "one two", persChildPersPreArgs) + } + if persChildPreArgs != "one two" { + t.Errorf("Expected persChildPreArgs %q, got %q", "one two", persChildPreArgs) + } + if persChildPreArgs2 != "one two three" { + t.Errorf("Expected persChildPreArgs %q, got %q", "one two three", persChildPreArgs2) + } + if persChildRunArgs != "one two" { + t.Errorf("Expected persChildRunArgs %q, got %q", "one two", persChildRunArgs) + } + if persChildPostArgs != "one two" { + t.Errorf("Expected persChildPostArgs %q, got %q", "one two", persChildPostArgs) + } + if persChildPersPostArgs != "one two" { + t.Errorf("Expected persChildPersPostArgs %q, got %q", "one two", persChildPersPostArgs) + } +} + +func TestPersistentHooksWoOverride(t *testing.T) { + EnablePersistentRunOverride = false + TestPersistentHooks(t) + EnablePersistentRunOverride = true } // Related to https://github.com/spf13/cobra/issues/521.