diff --git a/command.go b/command.go index ee5365bcb..6b696a970 100644 --- a/command.go +++ b/command.go @@ -1021,11 +1021,7 @@ func (c *Command) validateRequiredFlags() error { flags := c.Flags() missingFlagNames := []string{} flags.VisitAll(func(pflag *flag.Flag) { - requiredAnnotation, found := pflag.Annotations[BashCompOneRequiredFlag] - if !found { - return - } - if (requiredAnnotation[0] == "true") && !pflag.Changed { + if c.IsFlagRequired(pflag.Name) && !pflag.Changed { missingFlagNames = append(missingFlagNames, pflag.Name) } }) @@ -1608,6 +1604,36 @@ func (c *Command) HasAvailableInheritedFlags() bool { return c.InheritedFlags().HasAvailableFlags() } +// IsFlagRequired returns true if the flag identified by 'name' is a local flag +// is marked as required. +func (c *Command) IsFlagRequired(name string) bool { + f := c.Flags().Lookup(name) + if f == nil { + return false + } + + requiredAnnotation, found := f.Annotations[BashCompOneRequiredFlag] + if !found { + return false + } + return len(requiredAnnotation) > 0 && requiredAnnotation[0] == "true" +} + +// IsPersistentFlagRequired returns true if the persistent flag identified by +// 'name' is marked as required. +func (c *Command) IsPersistentFlagRequired(name string) bool { + f := c.PersistentFlags().Lookup(name) + if f == nil { + return false + } + + requiredAnnotation, found := f.Annotations[BashCompOneRequiredFlag] + if !found { + return false + } + return len(requiredAnnotation) > 0 && requiredAnnotation[0] == "true" +} + // Flag climbs up the command tree looking for matching flag. func (c *Command) Flag(name string) (flag *flag.Flag) { flag = c.Flags().Lookup(name) diff --git a/command_test.go b/command_test.go index d48fef1a0..e472f09d7 100644 --- a/command_test.go +++ b/command_test.go @@ -792,6 +792,20 @@ func TestRequiredFlags(t *testing.T) { } } +func TestIsFlagRequired(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun} + c.Flags().String("foo", "", "") + assertNoErr(t, c.MarkFlagRequired("foo")) + c.Flags().String("bar", "", "") + + expected := true + got := c.IsFlagRequired("foo") + + if got != expected { + t.Errorf("Expected %v: got: %v", expected, got) + } +} + func TestPersistentRequiredFlags(t *testing.T) { parent := &Command{Use: "parent", Run: emptyRun} parent.PersistentFlags().String("foo1", "", "") @@ -817,6 +831,20 @@ func TestPersistentRequiredFlags(t *testing.T) { } } +func TestIsPersistentFlagRequired(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun} + c.PersistentFlags().String("foo", "", "") + assertNoErr(t, c.MarkPersistentFlagRequired("foo")) + c.PersistentFlags().String("bar", "", "") + + expected := true + got := c.IsPersistentFlagRequired("foo") + + if got != expected { + t.Errorf("Expected %v: got: %v", expected, got) + } +} + func TestPersistentRequiredFlagsWithDisableFlagParsing(t *testing.T) { // Make sure a required persistent flag does not break // commands that disable flag parsing