Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions args.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package cobra

import (
"fmt"
"strings"
)

Expand All @@ -33,15 +32,23 @@ func legacyArgs(cmd *Command, args []string) error {

// root command with subcommands, do subcommand checking.
if !cmd.HasParent() && len(args) > 0 {
return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.findSuggestions(args[0]))
return UnknownSubcommandError{
cmd: cmd,
subcmd: args[0],
suggestions: cmd.findSuggestions(args[0]),
}
}
return nil
}

// NoArgs returns an error if any args are included.
func NoArgs(cmd *Command, args []string) error {
if len(args) > 0 {
return fmt.Errorf("unknown command %q for %q", args[0], cmd.CommandPath())
return UnknownSubcommandError{
cmd: cmd,
subcmd: args[0],
suggestions: nil,
}
}
return nil
}
Expand All @@ -58,7 +65,11 @@ func OnlyValidArgs(cmd *Command, args []string) error {
}
for _, v := range args {
if !stringInSlice(v, validArgs) {
return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.findSuggestions(args[0]))
return InvalidArgValueError{
cmd: cmd,
arg: v,
suggestions: cmd.findSuggestions(args[0]),
}
}
}
}
Expand All @@ -74,7 +85,12 @@ func ArbitraryArgs(cmd *Command, args []string) error {
func MinimumNArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) < n {
return fmt.Errorf("requires at least %d arg(s), only received %d", n, len(args))
return InvalidArgCountError{
cmd: cmd,
args: args,
atLeast: n,
atMost: -1,
}
}
return nil
}
Expand All @@ -84,7 +100,12 @@ func MinimumNArgs(n int) PositionalArgs {
func MaximumNArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) > n {
return fmt.Errorf("accepts at most %d arg(s), received %d", n, len(args))
return InvalidArgCountError{
cmd: cmd,
args: args,
atLeast: -1,
atMost: n,
}
}
return nil
}
Expand All @@ -94,7 +115,12 @@ func MaximumNArgs(n int) PositionalArgs {
func ExactArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) != n {
return fmt.Errorf("accepts %d arg(s), received %d", n, len(args))
return InvalidArgCountError{
cmd: cmd,
args: args,
atLeast: n,
atMost: n,
}
}
return nil
}
Expand All @@ -104,7 +130,12 @@ func ExactArgs(n int) PositionalArgs {
func RangeArgs(min int, max int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) < min || len(args) > max {
return fmt.Errorf("accepts between %d and %d arg(s), received %d", min, max, len(args))
return InvalidArgCountError{
cmd: cmd,
args: args,
atLeast: min,
atMost: max,
}
}
return nil
}
Expand Down
77 changes: 77 additions & 0 deletions args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
package cobra

import (
"errors"
"fmt"
"reflect"
"strings"
"testing"
)
Expand All @@ -32,6 +34,14 @@ func getCommand(args PositionalArgs, withValid bool) *Command {
return c
}

func getCommandName(c *Command) string {
if c == nil {
return "<nil>"
} else {
return c.Name()
}
}
Comment on lines +37 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func getCommandName(c *Command) string {
if c == nil {
return "<nil>"
} else {
return c.Name()
}
}
func getCommandName(c *Command) string {
if c == nil {
return "<nil>"
}
return c.Name()
}


func expectSuccess(output string, err error, t *testing.T) {
if output != "" {
t.Errorf("Unexpected output: %v", output)
Expand All @@ -41,6 +51,31 @@ func expectSuccess(output string, err error, t *testing.T) {
}
}

func expectErrorAs(err error, target error, t *testing.T) {
if err == nil {
t.Fatalf("Expected error, got nil")
}

targetType := reflect.TypeOf(target)
targetPtr := reflect.New(targetType).Interface() // *SomeError
if !errors.As(err, targetPtr) {
t.Fatalf("Expected error to be %T, got %T", target, err)
}
}
Comment on lines +54 to +64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not using simply using this?

Suggested change
func expectErrorAs(err error, target error, t *testing.T) {
if err == nil {
t.Fatalf("Expected error, got nil")
}
targetType := reflect.TypeOf(target)
targetPtr := reflect.New(targetType).Interface() // *SomeError
if !errors.As(err, targetPtr) {
t.Fatalf("Expected error to be %T, got %T", target, err)
}
}
func expectErrorAs(err error, target error, t *testing.T) {
if err == nil {
t.Fatalf("Expected error, got nil")
}
if !errors.As(err, &target) {
t.Fatalf("Expected error to be %T, got %T", target, err)
}
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, actually, yeah. Good point. I'll fix it in the next set of changes.


func expectErrorHasCommand(err error, cmd *Command, t *testing.T) {
getCommand, ok := err.(interface{ GetCommand() *Command })
if !ok {
t.Fatalf("Expected error to have GetCommand method, but did not")
}

got := getCommand.GetCommand()
if cmd != got {
t.Errorf("Expected err.GetCommand to return %v, got %v",
getCommandName(cmd), getCommandName(got))
}
}

func validOnlyWithInvalidArgs(err error, t *testing.T) {
if err == nil {
t.Fatal("Expected an error")
Expand Down Expand Up @@ -139,6 +174,13 @@ func TestNoArgs_WithValidOnly_WithInvalidArgs(t *testing.T) {
validOnlyWithInvalidArgs(err, t)
}

func TestNoArgs_ReturnsUnknownSubcommandError(t *testing.T) {
c := getCommand(NoArgs, false)
_, err := executeCommand(c, "a")
expectErrorAs(err, UnknownSubcommandError{}, t)
expectErrorHasCommand(err, c, t)
}

// OnlyValidArgs

func TestOnlyValidArgs(t *testing.T) {
Expand All @@ -153,6 +195,13 @@ func TestOnlyValidArgs_WithInvalidArgs(t *testing.T) {
validOnlyWithInvalidArgs(err, t)
}

func TestOnlyValidArgs_ReturnsInvalidArgValueError(t *testing.T) {
c := getCommand(OnlyValidArgs, true)
_, err := executeCommand(c, "a")
expectErrorAs(err, InvalidArgValueError{}, t)
expectErrorHasCommand(err, c, t)
}

// ArbitraryArgs

func TestArbitraryArgs(t *testing.T) {
Expand Down Expand Up @@ -229,6 +278,13 @@ func TestMinimumNArgs_WithLessArgs_WithValidOnly_WithInvalidArgs(t *testing.T) {
validOnlyWithInvalidArgs(err, t)
}

func TestMinimumNArgs_ReturnsInvalidArgCountError(t *testing.T) {
c := getCommand(MinimumNArgs(2), true)
_, err := executeCommand(c, "a")
expectErrorAs(err, InvalidArgCountError{}, t)
expectErrorHasCommand(err, c, t)
}

// MaximumNArgs

func TestMaximumNArgs(t *testing.T) {
Expand Down Expand Up @@ -279,6 +335,13 @@ func TestMaximumNArgs_WithMoreArgs_WithValidOnly_WithInvalidArgs(t *testing.T) {
validOnlyWithInvalidArgs(err, t)
}

func TestMaximumNArgs_ReturnsInvalidArgCountError(t *testing.T) {
c := getCommand(MaximumNArgs(2), true)
_, err := executeCommand(c, "a", "b", "c")
expectErrorAs(err, InvalidArgCountError{}, t)
expectErrorHasCommand(err, c, t)
}

// ExactArgs

func TestExactArgs(t *testing.T) {
Expand Down Expand Up @@ -329,6 +392,13 @@ func TestExactArgs_WithInvalidCount_WithValidOnly_WithInvalidArgs(t *testing.T)
validOnlyWithInvalidArgs(err, t)
}

func TestExactArgs_ReturnsInvalidArgCountError(t *testing.T) {
c := getCommand(ExactArgs(2), true)
_, err := executeCommand(c, "a")
expectErrorAs(err, InvalidArgCountError{}, t)
expectErrorHasCommand(err, c, t)
}

// RangeArgs

func TestRangeArgs(t *testing.T) {
Expand Down Expand Up @@ -379,6 +449,13 @@ func TestRangeArgs_WithInvalidCount_WithValidOnly_WithInvalidArgs(t *testing.T)
validOnlyWithInvalidArgs(err, t)
}

func TestRangeArgs_ReturnsInvalidArgCountError(t *testing.T) {
c := getCommand(RangeArgs(2, 4), true)
_, err := executeCommand(c, "a")
expectErrorAs(err, InvalidArgCountError{}, t)
expectErrorHasCommand(err, c, t)
}

// Takes(No)Args

func TestRootTakesNoArgs(t *testing.T) {
Expand Down
31 changes: 20 additions & 11 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,21 +778,14 @@ func (c *Command) Find(args []string) (*Command, []string, error) {
return commandFound, a, nil
}

func (c *Command) findSuggestions(arg string) string {
func (c *Command) findSuggestions(arg string) []string {
if c.DisableSuggestions {
return ""
return nil
}
if c.SuggestionsMinimumDistance <= 0 {
c.SuggestionsMinimumDistance = 2
}
var sb strings.Builder
if suggestions := c.SuggestionsFor(arg); len(suggestions) > 0 {
sb.WriteString("\n\nDid you mean this?\n")
for _, s := range suggestions {
_, _ = fmt.Fprintf(&sb, "\t%v\n", s)
}
}
return sb.String()
return c.SuggestionsFor(arg)
}

func (c *Command) findNext(next string) *Command {
Expand Down Expand Up @@ -1195,7 +1188,10 @@ func (c *Command) ValidateRequiredFlags() error {
})

if len(missingFlagNames) > 0 {
return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`))
return RequiredFlagError{
cmd: c,
missingFlagNames: missingFlagNames,
}
}
return nil
}
Expand Down Expand Up @@ -1933,6 +1929,19 @@ func commandNameMatches(s string, t string) bool {
return s == t
}

// helpTextForSuggestions joins a slice of command suggestions into a string.
// If the provided slice is empty or nil, an empty string is returned.
func helpTextForSuggestions(suggestions []string) string {
var sb strings.Builder
if len(suggestions) > 0 {
sb.WriteString("\n\nDid you mean this?\n")
for _, s := range suggestions {
_, _ = fmt.Fprintf(&sb, "\t%v\n", s)
}
}
return sb.String()
}

// tmplFunc holds a template and a function that will execute said template.
type tmplFunc struct {
tmpl string
Expand Down
16 changes: 16 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package cobra
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -866,6 +867,21 @@ func TestRequiredFlags(t *testing.T) {
if got != expected {
t.Errorf("Expected error: %q, got: %q", expected, got)
}

// Test it returns valid RequiredFlagError.
var requiredFlagErr RequiredFlagError
if !errors.As(err, &requiredFlagErr) {
t.Fatalf("Expected error to be RequiredFlagError, got %T", err)
}

expectedMissingFlagNames := "foo1 foo2"
gotMissingFlagNames := strings.Join(requiredFlagErr.missingFlagNames, " ")
if expectedMissingFlagNames != gotMissingFlagNames {
t.Errorf("Expected error missingFlagNames to be %q, got %q",
expectedMissingFlagNames, gotMissingFlagNames)
}

expectErrorHasCommand(err, c, t)
}

func TestPersistentRequiredFlags(t *testing.T) {
Expand Down
Loading