Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LocalFlags().NFlags to count the number of local flags that have been set explicitly #1999

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
138 changes: 128 additions & 10 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ type Command struct {
pflags *flag.FlagSet
// lflags contains local flags.
lflags *flag.FlagSet
// lnpflags contains local non persistent flags
lnpflags *flag.FlagSet
// iflags contains inherited flags.
iflags *flag.FlagSet
// parentsPflags is all persistent flags of cmd's parents.
Expand Down Expand Up @@ -1027,7 +1029,6 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
c.checkCommandGroups()

args := c.args

// Workaround FAIL with "go test -v" or "cobra.test -test.v", see #155
if c.args == nil && filepath.Base(os.Args[0]) != "cobra.test" {
args = os.Args[1:]
Expand Down Expand Up @@ -1603,15 +1604,19 @@ func (c *Command) Flags() *flag.FlagSet {

// LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands.
func (c *Command) LocalNonPersistentFlags() *flag.FlagSet {
persistentFlags := c.PersistentFlags()
if c.lnpflags == nil {
persistentFlags := c.PersistentFlags()

c.lnpflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.LocalFlags().VisitAll(func(f *flag.Flag) {
if persistentFlags.Lookup(f.Name) == nil {
f.Changed = false
c.lnpflags.AddFlag(f)
}
})
}

out := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.LocalFlags().VisitAll(func(f *flag.Flag) {
if persistentFlags.Lookup(f.Name) == nil {
out.AddFlag(f)
}
})
return out
return c.lnpflags
}

// LocalFlags returns the local FlagSet specifically set in the current command.
Expand All @@ -1633,6 +1638,7 @@ func (c *Command) LocalFlags() *flag.FlagSet {
addToLocal := func(f *flag.Flag) {
// Add the flag if it is not a parent PFlag, or it shadows a parent PFlag
if c.lflags.Lookup(f.Name) == nil && f != c.parentsPflags.Lookup(f.Name) {
f.Changed = false
c.lflags.AddFlag(f)
}
}
Expand Down Expand Up @@ -1764,12 +1770,97 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) {
return
}

func (c *Command) parseLongArgs(s string, args []string, flags *flag.FlagSet) (passedArgs, restArgs []string) {
restArgs = args
name := s[2:]
if len(name) == 0 {
passedArgs = append(passedArgs, s)
return
}

split := strings.SplitN(s[2:], "=", 2)
name = split[0]
searchedFlag := flags.Lookup(name)
if searchedFlag == nil {
// ignore the flag that is not registered in passed flags but is registered in c.parentsPflags
c.parentsPflags.VisitAll(func(f *flag.Flag) {
if name == f.Name {
if len(split) == 1 && f.NoOptDefVal == "" && len(args) > 0 {
// '--flag arg'
restArgs = args[1:]
}
}
})
return
}

passedArgs = append(passedArgs, fmt.Sprintf("--%s", s[2:]))
if len(split) == 1 && searchedFlag.NoOptDefVal == "" && len(args) > 0 {
passedArgs = append(passedArgs, args[0])
restArgs = args[1:]
}

return
}

func (c *Command) parseShortArgs(s string, args []string, flags *flag.FlagSet) (passedArgs []string, restArgs []string) {
restArgs = args

shorthands := s[1:]
shorthand := string(s[1])

searchedFlag := flags.ShorthandLookup(shorthand)
if searchedFlag == nil {
// ignore the flag that is not registered in passed flags but is registered in c.parentsPflags
c.parentsPflags.VisitAll(func(f *flag.Flag) {
if shorthand == f.Shorthand {
if len(shorthands) == 1 && f.NoOptDefVal == "" && len(args) > 0 {
// '-f arg'
restArgs = args[1:]
}
}
})
return
}

passedArgs = append(passedArgs, s)
if len(shorthands) == 1 && searchedFlag.NoOptDefVal == "" && len(args) > 0 {
// '-f arg'
passedArgs = append(passedArgs, args[0])
restArgs = args[1:]
}

return
}

func (c *Command) removeParentPersistentArgs(args []string, flags *flag.FlagSet) (newArgs []string) {
for len(args) > 0 {
s := args[0]
args = args[1:]
if len(s) == 0 || s[0] != '-' {
newArgs = append(newArgs, s)
continue
}

var passedArgs, restArgs []string
if s[1] == '-' {
passedArgs, restArgs = c.parseLongArgs(s, args, flags)
} else {
passedArgs, restArgs = c.parseShortArgs(s, args, flags)
}
if len(passedArgs) > 0 {
newArgs = append(newArgs, passedArgs...)
}
args = restArgs
}
return
}

// ParseFlags parses persistent flag tree and local flags.
func (c *Command) ParseFlags(args []string) error {
if c.DisableFlagParsing {
return nil
}

if c.flagErrorBuf == nil {
c.flagErrorBuf = new(bytes.Buffer)
}
Expand All @@ -1779,11 +1870,38 @@ func (c *Command) ParseFlags(args []string) error {
// do it here after merging all flags and just before parse
c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)

// parse Flags
err := c.Flags().Parse(args)
// Print warnings if they occurred (e.g. deprecated flag messages).
if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil {
c.Print(c.flagErrorBuf.String())
}
if err != nil {
return err
}

// parse Local Flags
c.LocalFlags() // need to execute LocalFlags() to set the value in c.lflags before executing removeParentPersistentArgs
c.lflags.ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)
localArgs := c.removeParentPersistentArgs(args, c.lflags) // get only arguments related to c.lflags
err = c.lflags.Parse(localArgs)
// Print warnings if they occurred (e.g. deprecated flag messages).
if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil {
c.Print(c.flagErrorBuf.String())
}
if err != nil {
return err
}

// parse local non persistent flags
c.LocalNonPersistentFlags() // need to execute LocalNonPersistentFlags() to set the value in c.lnpflags before executing removeParentPersistentArgs
c.lnpflags.ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)
localNonPersistentArgs := c.removeParentPersistentArgs(args, c.lnpflags)
err = c.lnpflags.Parse(localNonPersistentArgs)
// Print warnings if they occurred (e.g. deprecated flag messages).
if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil {
c.Print(c.flagErrorBuf.String())
}

return err
}
Expand Down
140 changes: 140 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2735,3 +2735,143 @@ func TestUnknownFlagShouldReturnSameErrorRegardlessOfArgPosition(t *testing.T) {
})
}
}

func TestNFlagForFlags(t *testing.T) {
var rootNFlag, childNFlag int
rootCmd := &Command{
Use: "root",
Run: func(cmd *Command, _ []string) {
rootNFlag = cmd.Flags().NFlag()
},
}
childCmd := &Command{
Use: "child",
Run: func(cmd *Command, args []string) {
childNFlag = cmd.Flags().NFlag()
},
}
rootCmd.AddCommand(childCmd)

rootCmd.PersistentFlags().Bool("rp", false, "")
rpFlag := rootCmd.PersistentFlags().Lookup("rp")
childCmd.PersistentFlags().Bool("cp", false, "")
childCmd.Flags().Int("int", 0, "")

output, err := executeCommand(rootCmd, "--rp")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if rootNFlag != 1 {
t.Errorf("Expected NFlag: %v, got %v", 1, rootNFlag)
}
// set Changed false for the next test
rpFlag.Changed = false

output, err = executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if childNFlag != 3 {
t.Errorf("Expected NFlag: %v, got %v", 3, childNFlag)
}
}

func TestNFlagForLocalFlags(t *testing.T) {
var localNFlag int
rootCmd := &Command{
Use: "root",
Run: emptyRun,
}
childCmd := &Command{
Use: "child",
Run: func(cmd *Command, args []string) {
localNFlag = cmd.LocalFlags().NFlag()
},
}
rootCmd.AddCommand(childCmd)

rootCmd.PersistentFlags().Bool("rp", false, "")
childCmd.PersistentFlags().Bool("cp", false, "")
childCmd.Flags().Int("int", 0, "")

output, err := executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if localNFlag != 2 { // LocalFlags().NFlag() ignores '--rp'
t.Errorf("Expected NFlag: %v, got %v", 2, localNFlag)
}
}

func TestNFlagForLocalNonPersistentFlags(t *testing.T) {
var localNonPNFlag int
rootCmd := &Command{
Use: "root",
Run: emptyRun,
}
childCmd := &Command{
Use: "child",
Run: func(cmd *Command, args []string) {
localNonPNFlag = cmd.LocalNonPersistentFlags().NFlag()
},
}
rootCmd.AddCommand(childCmd)

rootCmd.PersistentFlags().Bool("rp", false, "")
childCmd.PersistentFlags().Bool("cp", false, "")
childCmd.Flags().Int("int", 0, "")

output, err := executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if localNonPNFlag != 1 { // LocalNonPersistentFlags().NFlag() ignores '--rp' and '--cp'
t.Errorf("Expected NFlag: %v, got %v", 1, localNonPNFlag)
}
}

func TestRemoveParentPersistentArgs(t *testing.T) {
rootCmd := &Command{Use: "root", Run: emptyRun}
childCmd := &Command{Use: "child", Run: emptyRun}
rootCmd.AddCommand(childCmd)

rootCmd.PersistentFlags().BoolP("rp", "r", false, "")
rootCmd.PersistentFlags().Int("ri", 0, "")
childCmd.PersistentFlags().Bool("cp", false, "")
childCmd.Flags().Int("int", 0, "")

output, err := executeCommand(rootCmd, "child", "-r", "--ri", "10", "--cp", "--int", "10")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

args := rootCmd.args
_, args, _ = rootCmd.Find(args)

gotLocalArgs := childCmd.removeParentPersistentArgs(args, childCmd.lflags)
expectedLocalArgs := []string{"--cp", "--int", "10"}
if !reflect.DeepEqual(gotLocalArgs, expectedLocalArgs) {
t.Errorf("Expected localArgs: %v, got %v", expectedLocalArgs, gotLocalArgs)
}

gotLocalNonPersistentArgs := childCmd.removeParentPersistentArgs(args, childCmd.lnpflags)
expectedLocalNonPersistentArgs := []string{"--int", "10"}
if !reflect.DeepEqual(gotLocalNonPersistentArgs, expectedLocalNonPersistentArgs) {
t.Errorf("Expected localArgs: %v, got %v", expectedLocalNonPersistentArgs, gotLocalNonPersistentArgs)
}
}