diff --git a/command.go b/command.go index 78088db69..e2c8e6c42 100644 --- a/command.go +++ b/command.go @@ -190,6 +190,8 @@ type Command struct { // versionTemplate is the version template defined by user. versionTemplate *tmplFunc + // versionFunc is the version func defined by user. + versionFunc func(*Command) error // errPrefix is the error message prefix defined by user. errPrefix string @@ -363,6 +365,14 @@ func (c *Command) SetHelpTemplate(s string) { c.helpTemplate = tmpl(s) } +// SetVersionFunc sets version function. Can be defined by Application. +// +// Setting this means that [Command.SetVersionTemplate] will not have any +// effect any more. +func (c *Command) SetVersionFunc(f func(*Command) error) { + c.versionFunc = f +} + // SetVersionTemplate sets version template to be used. Application can use it to set custom template. func (c *Command) SetVersionTemplate(s string) { if s == "" { @@ -639,6 +649,18 @@ func (c *Command) getVersionTemplateFunc() func(w io.Writer, data interface{}) e return defaultVersionFunc } +// getVersionFunc returns the version function for the command going up the +// command tree if necessary. +func (c *Command) getVersionFunc() func(*Command) error { + if c.versionFunc != nil { + return c.versionFunc + } + if c.HasParent() { + return c.parent.getVersionFunc() + } + return nil +} + // ErrPrefix return error message prefix for the command func (c *Command) ErrPrefix() string { if c.errPrefix != "" { @@ -943,6 +965,9 @@ func (c *Command) execute(a []string) (err error) { return err } if versionVal { + if fn := c.getVersionFunc(); fn != nil { + return fn(c) + } fn := c.getVersionTemplateFunc() err := fn(c.OutOrStdout(), c) if err != nil { diff --git a/command_test.go b/command_test.go index a86e57f0a..365114a85 100644 --- a/command_test.go +++ b/command_test.go @@ -2952,3 +2952,18 @@ func TestHelpFuncExecuted(t *testing.T) { checkStringContains(t, output, helpText) } + +func TestVersionFuncExecuted(t *testing.T) { + rootCmd := &Command{Use: "root", Run: emptyRun, Version: "v2.3.4"} + rootCmd.SetVersionFunc(func(cmd *Command) error { + _, err := fmt.Fprint(cmd.OutOrStdout(), "custom version function: "+rootCmd.Version) + return err + }) + + output, err := executeCommandWithContext(context.Background(), rootCmd, "-v") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + checkStringContains(t, output, "custom version function: v2.3.4") +}