diff --git a/command.go b/command.go index 01f7c6f1c..90f98c3f0 100644 --- a/command.go +++ b/command.go @@ -150,6 +150,8 @@ type Command struct { pflags *flag.FlagSet // lflags contains local flags. lflags *flag.FlagSet + // lnpflags contains local non persistent flags + lnpflags *flag.FlagSet // iflags contains inherited flags. iflags *flag.FlagSet // parentsPflags is all persistent flags of cmd's parents. @@ -1027,7 +1029,6 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { c.checkCommandGroups() args := c.args - // Workaround FAIL with "go test -v" or "cobra.test -test.v", see #155 if c.args == nil && filepath.Base(os.Args[0]) != "cobra.test" { args = os.Args[1:] @@ -1603,15 +1604,19 @@ func (c *Command) Flags() *flag.FlagSet { // LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands. func (c *Command) LocalNonPersistentFlags() *flag.FlagSet { - persistentFlags := c.PersistentFlags() + if c.lnpflags == nil { + persistentFlags := c.PersistentFlags() + + c.lnpflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.LocalFlags().VisitAll(func(f *flag.Flag) { + if persistentFlags.Lookup(f.Name) == nil { + f.Changed = false + c.lnpflags.AddFlag(f) + } + }) + } - out := flag.NewFlagSet(c.Name(), flag.ContinueOnError) - c.LocalFlags().VisitAll(func(f *flag.Flag) { - if persistentFlags.Lookup(f.Name) == nil { - out.AddFlag(f) - } - }) - return out + return c.lnpflags } // LocalFlags returns the local FlagSet specifically set in the current command. @@ -1633,6 +1638,7 @@ func (c *Command) LocalFlags() *flag.FlagSet { addToLocal := func(f *flag.Flag) { // Add the flag if it is not a parent PFlag, or it shadows a parent PFlag if c.lflags.Lookup(f.Name) == nil && f != c.parentsPflags.Lookup(f.Name) { + f.Changed = false c.lflags.AddFlag(f) } } @@ -1764,12 +1770,97 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) { return } +func (c *Command) parseLongArgs(s string, args []string, flags *flag.FlagSet) (passedArgs, restArgs []string) { + restArgs = args + name := s[2:] + if len(name) == 0 { + passedArgs = append(passedArgs, s) + return + } + + split := strings.SplitN(s[2:], "=", 2) + name = split[0] + searchedFlag := flags.Lookup(name) + if searchedFlag == nil { + // ignore the flag that is not registered in passed flags but is registered in c.parentsPflags + c.parentsPflags.VisitAll(func(f *flag.Flag) { + if name == f.Name { + if len(split) == 1 && f.NoOptDefVal == "" && len(args) > 0 { + // '--flag arg' + restArgs = args[1:] + } + } + }) + return + } + + passedArgs = append(passedArgs, fmt.Sprintf("--%s", s[2:])) + if len(split) == 1 && searchedFlag.NoOptDefVal == "" && len(args) > 0 { + passedArgs = append(passedArgs, args[0]) + restArgs = args[1:] + } + + return +} + +func (c *Command) parseShortArgs(s string, args []string, flags *flag.FlagSet) (passedArgs []string, restArgs []string) { + restArgs = args + + shorthands := s[1:] + shorthand := string(s[1]) + + searchedFlag := flags.ShorthandLookup(shorthand) + if searchedFlag == nil { + // ignore the flag that is not registered in passed flags but is registered in c.parentsPflags + c.parentsPflags.VisitAll(func(f *flag.Flag) { + if shorthand == f.Shorthand { + if len(shorthands) == 1 && f.NoOptDefVal == "" && len(args) > 0 { + // '-f arg' + restArgs = args[1:] + } + } + }) + return + } + + passedArgs = append(passedArgs, s) + if len(shorthands) == 1 && searchedFlag.NoOptDefVal == "" && len(args) > 0 { + // '-f arg' + passedArgs = append(passedArgs, args[0]) + restArgs = args[1:] + } + + return +} + +func (c *Command) removeParentPersistentArgs(args []string, flags *flag.FlagSet) (newArgs []string) { + for len(args) > 0 { + s := args[0] + args = args[1:] + if len(s) == 0 || s[0] != '-' { + newArgs = append(newArgs, s) + continue + } + + var passedArgs, restArgs []string + if s[1] == '-' { + passedArgs, restArgs = c.parseLongArgs(s, args, flags) + } else { + passedArgs, restArgs = c.parseShortArgs(s, args, flags) + } + if len(passedArgs) > 0 { + newArgs = append(newArgs, passedArgs...) + } + args = restArgs + } + return +} + // ParseFlags parses persistent flag tree and local flags. func (c *Command) ParseFlags(args []string) error { if c.DisableFlagParsing { return nil } - if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } @@ -1779,11 +1870,38 @@ func (c *Command) ParseFlags(args []string) error { // do it here after merging all flags and just before parse c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) + // parse Flags err := c.Flags().Parse(args) // Print warnings if they occurred (e.g. deprecated flag messages). if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil { c.Print(c.flagErrorBuf.String()) } + if err != nil { + return err + } + + // parse Local Flags + c.LocalFlags() // need to execute LocalFlags() to set the value in c.lflags before executing removeParentPersistentArgs + c.lflags.ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) + localArgs := c.removeParentPersistentArgs(args, c.lflags) // get only arguments related to c.lflags + err = c.lflags.Parse(localArgs) + // Print warnings if they occurred (e.g. deprecated flag messages). + if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil { + c.Print(c.flagErrorBuf.String()) + } + if err != nil { + return err + } + + // parse local non persistent flags + c.LocalNonPersistentFlags() // need to execute LocalNonPersistentFlags() to set the value in c.lnpflags before executing removeParentPersistentArgs + c.lnpflags.ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) + localNonPersistentArgs := c.removeParentPersistentArgs(args, c.lnpflags) + err = c.lnpflags.Parse(localNonPersistentArgs) + // Print warnings if they occurred (e.g. deprecated flag messages). + if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil { + c.Print(c.flagErrorBuf.String()) + } return err } diff --git a/command_test.go b/command_test.go index 0212f5ae9..3b6d59d4c 100644 --- a/command_test.go +++ b/command_test.go @@ -2735,3 +2735,143 @@ func TestUnknownFlagShouldReturnSameErrorRegardlessOfArgPosition(t *testing.T) { }) } } + +func TestNFlagForFlags(t *testing.T) { + var rootNFlag, childNFlag int + rootCmd := &Command{ + Use: "root", + Run: func(cmd *Command, _ []string) { + rootNFlag = cmd.Flags().NFlag() + }, + } + childCmd := &Command{ + Use: "child", + Run: func(cmd *Command, args []string) { + childNFlag = cmd.Flags().NFlag() + }, + } + rootCmd.AddCommand(childCmd) + + rootCmd.PersistentFlags().Bool("rp", false, "") + rpFlag := rootCmd.PersistentFlags().Lookup("rp") + childCmd.PersistentFlags().Bool("cp", false, "") + childCmd.Flags().Int("int", 0, "") + + output, err := executeCommand(rootCmd, "--rp") + if output != "" { + t.Errorf("Unexpected output: %v", output) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if rootNFlag != 1 { + t.Errorf("Expected NFlag: %v, got %v", 1, rootNFlag) + } + // set Changed false for the next test + rpFlag.Changed = false + + output, err = executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10") + if output != "" { + t.Errorf("Unexpected output: %v", output) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if childNFlag != 3 { + t.Errorf("Expected NFlag: %v, got %v", 3, childNFlag) + } +} + +func TestNFlagForLocalFlags(t *testing.T) { + var localNFlag int + rootCmd := &Command{ + Use: "root", + Run: emptyRun, + } + childCmd := &Command{ + Use: "child", + Run: func(cmd *Command, args []string) { + localNFlag = cmd.LocalFlags().NFlag() + }, + } + rootCmd.AddCommand(childCmd) + + rootCmd.PersistentFlags().Bool("rp", false, "") + childCmd.PersistentFlags().Bool("cp", false, "") + childCmd.Flags().Int("int", 0, "") + + output, err := executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10") + if output != "" { + t.Errorf("Unexpected output: %v", output) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if localNFlag != 2 { // LocalFlags().NFlag() ignores '--rp' + t.Errorf("Expected NFlag: %v, got %v", 2, localNFlag) + } +} + +func TestNFlagForLocalNonPersistentFlags(t *testing.T) { + var localNonPNFlag int + rootCmd := &Command{ + Use: "root", + Run: emptyRun, + } + childCmd := &Command{ + Use: "child", + Run: func(cmd *Command, args []string) { + localNonPNFlag = cmd.LocalNonPersistentFlags().NFlag() + }, + } + rootCmd.AddCommand(childCmd) + + rootCmd.PersistentFlags().Bool("rp", false, "") + childCmd.PersistentFlags().Bool("cp", false, "") + childCmd.Flags().Int("int", 0, "") + + output, err := executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10") + if output != "" { + t.Errorf("Unexpected output: %v", output) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if localNonPNFlag != 1 { // LocalNonPersistentFlags().NFlag() ignores '--rp' and '--cp' + t.Errorf("Expected NFlag: %v, got %v", 1, localNonPNFlag) + } +} + +func TestRemoveParentPersistentArgs(t *testing.T) { + rootCmd := &Command{Use: "root", Run: emptyRun} + childCmd := &Command{Use: "child", Run: emptyRun} + rootCmd.AddCommand(childCmd) + + rootCmd.PersistentFlags().BoolP("rp", "r", false, "") + rootCmd.PersistentFlags().Int("ri", 0, "") + childCmd.PersistentFlags().Bool("cp", false, "") + childCmd.Flags().Int("int", 0, "") + + output, err := executeCommand(rootCmd, "child", "-r", "--ri", "10", "--cp", "--int", "10") + if output != "" { + t.Errorf("Unexpected output: %v", output) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + args := rootCmd.args + _, args, _ = rootCmd.Find(args) + + gotLocalArgs := childCmd.removeParentPersistentArgs(args, childCmd.lflags) + expectedLocalArgs := []string{"--cp", "--int", "10"} + if !reflect.DeepEqual(gotLocalArgs, expectedLocalArgs) { + t.Errorf("Expected localArgs: %v, got %v", expectedLocalArgs, gotLocalArgs) + } + + gotLocalNonPersistentArgs := childCmd.removeParentPersistentArgs(args, childCmd.lnpflags) + expectedLocalNonPersistentArgs := []string{"--int", "10"} + if !reflect.DeepEqual(gotLocalNonPersistentArgs, expectedLocalNonPersistentArgs) { + t.Errorf("Expected localArgs: %v, got %v", expectedLocalNonPersistentArgs, gotLocalNonPersistentArgs) + } +}