diff --git a/example_test.go b/example_test.go index abd7806f..db7fb95c 100644 --- a/example_test.go +++ b/example_test.go @@ -34,3 +34,26 @@ func ExampleFlagSet_ShorthandLookup() { fmt.Println(flag.Name) } + +func ExampleFlagSet_StringToString() { + args := []string{ + "--arg", "a=1,b=2", + "--arg", "a=2", + "--arg=d=4", + } + + fs := pflag.NewFlagSet("Example", pflag.ContinueOnError) + fs.StringToString("arg", make(map[string]string), "string-to-string arg accepting key=value pairs") + + if err := fs.Parse(args); err != nil { + panic(err) + } + + value, err := fs.GetStringToString("arg") + if err != nil { + panic(err) + } + + fmt.Println(value) + // Output: map[a:2 b:2 d:4] +} diff --git a/flag.go b/flag.go index e9ca46e2..76e9b5fd 100644 --- a/flag.go +++ b/flag.go @@ -400,7 +400,12 @@ func (f *FlagSet) lookup(name NormalizedName) *Flag { return f.formal[name] } -// func to return a given type for a given flag name +// getFlagType performs a lookup of a flag with the given name and ftype. The flag is stringified and passed through +// convFunc before being returned to enforce flag immutablility. +// +// convFunc may be nil, in which case the raw flag value is returned directly and no immutability is enforced. This is +// particularly useful when users need to access the pointer of the underlying flag value for manipulation (e.g. +// resetting flag values in tests). func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval string) (interface{}, error)) (interface{}, error) { flag := f.Lookup(name) if flag == nil { @@ -413,6 +418,10 @@ func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval stri return nil, err } + if convFunc == nil { + return flag.Value, nil + } + sval := flag.Value.String() result, err := convFunc(sval) if err != nil { diff --git a/golangflag.go b/golangflag.go index e62eab53..20594384 100644 --- a/golangflag.go +++ b/golangflag.go @@ -158,4 +158,3 @@ func ParseSkippedFlags(osArgs []string, goFlagSet *goflag.FlagSet) error { } return goFlagSet.Parse(skippedFlags) } - diff --git a/string_slice.go b/string_slice.go index 3cb2e69d..d421887e 100644 --- a/string_slice.go +++ b/string_slice.go @@ -98,9 +98,12 @@ func (f *FlagSet) GetStringSlice(name string) ([]string, error) { // The argument p points to a []string variable in which to store the value of the flag. // Compared to StringArray flags, StringSlice flags take comma-separated value as arguments and split them accordingly. // For example: -// --ss="v1,v2" --ss="v3" +// +// --ss="v1,v2" --ss="v3" +// // will result in -// []string{"v1", "v2", "v3"} +// +// []string{"v1", "v2", "v3"} func (f *FlagSet) StringSliceVar(p *[]string, name string, value []string, usage string) { f.VarP(newStringSliceValue(value, p), name, "", usage) } @@ -114,9 +117,12 @@ func (f *FlagSet) StringSliceVarP(p *[]string, name, shorthand string, value []s // The argument p points to a []string variable in which to store the value of the flag. // Compared to StringArray flags, StringSlice flags take comma-separated value as arguments and split them accordingly. // For example: -// --ss="v1,v2" --ss="v3" +// +// --ss="v1,v2" --ss="v3" +// // will result in -// []string{"v1", "v2", "v3"} +// +// []string{"v1", "v2", "v3"} func StringSliceVar(p *[]string, name string, value []string, usage string) { CommandLine.VarP(newStringSliceValue(value, p), name, "", usage) } @@ -130,9 +136,12 @@ func StringSliceVarP(p *[]string, name, shorthand string, value []string, usage // The return value is the address of a []string variable that stores the value of the flag. // Compared to StringArray flags, StringSlice flags take comma-separated value as arguments and split them accordingly. // For example: -// --ss="v1,v2" --ss="v3" +// +// --ss="v1,v2" --ss="v3" +// // will result in -// []string{"v1", "v2", "v3"} +// +// []string{"v1", "v2", "v3"} func (f *FlagSet) StringSlice(name string, value []string, usage string) *[]string { p := []string{} f.StringSliceVarP(&p, name, "", value, usage) @@ -150,9 +159,12 @@ func (f *FlagSet) StringSliceP(name, shorthand string, value []string, usage str // The return value is the address of a []string variable that stores the value of the flag. // Compared to StringArray flags, StringSlice flags take comma-separated value as arguments and split them accordingly. // For example: -// --ss="v1,v2" --ss="v3" +// +// --ss="v1,v2" --ss="v3" +// // will result in -// []string{"v1", "v2", "v3"} +// +// []string{"v1", "v2", "v3"} func StringSlice(name string, value []string, usage string) *[]string { return CommandLine.StringSliceP(name, "", value, usage) } diff --git a/string_to_string.go b/string_to_string.go index 1d1e3bf9..f3327e33 100644 --- a/string_to_string.go +++ b/string_to_string.go @@ -21,7 +21,7 @@ func newStringToStringValue(val map[string]string, p *map[string]string) *string return ssv } -// Format: a=1,b=2 +// Set updates the flag value from the given string, adding additional mappings or updating existing ones. func (s *stringToStringValue) Set(val string) error { var ss []string n := strings.Count(val, "=") @@ -47,13 +47,17 @@ func (s *stringToStringValue) Set(val string) error { } out[kv[0]] = kv[1] } + + // clear out any default flag values if !s.changed { - *s.value = out - } else { - for k, v := range out { - (*s.value)[k] = v + for k := range *s.value { + delete(*s.value, k) } } + + for k, v := range out { + (*s.value)[k] = v + } s.changed = true return nil } @@ -84,85 +88,100 @@ func (s *stringToStringValue) String() string { return "[" + strings.TrimSpace(buf.String()) + "]" } -func stringToStringConv(val string) (interface{}, error) { - val = strings.Trim(val, "[]") - // An empty string would cause an empty map - if len(val) == 0 { - return map[string]string{}, nil - } - r := csv.NewReader(strings.NewReader(val)) - ss, err := r.Read() - if err != nil { - return nil, err - } - out := make(map[string]string, len(ss)) - for _, pair := range ss { - kv := strings.SplitN(pair, "=", 2) - if len(kv) != 2 { - return nil, fmt.Errorf("%s must be formatted as key=value", pair) - } - out[kv[0]] = kv[1] - } - return out, nil -} - -// GetStringToString return the map[string]string value of a flag with the given name +// GetStringToString return the map value of a flag with the given name from f. The returned map shares memory with the +// internal flag value [Flag.Value]. func (f *FlagSet) GetStringToString(name string) (map[string]string, error) { - val, err := f.getFlagType(name, "stringToString", stringToStringConv) + val, err := f.getFlagType(name, "stringToString", nil) if err != nil { return map[string]string{}, err } - return val.(map[string]string), nil -} - -// StringToStringVar defines a string flag with specified name, default value, and usage string. -// The argument p points to a map[string]string variable in which to store the values of the multiple flags. -// The value of each argument will not try to be separated by comma -func (f *FlagSet) StringToStringVar(p *map[string]string, name string, value map[string]string, usage string) { - f.VarP(newStringToStringValue(value, p), name, "", usage) -} - -// StringToStringVarP is like StringToStringVar, but accepts a shorthand letter that can be used after a single dash. -func (f *FlagSet) StringToStringVarP(p *map[string]string, name, shorthand string, value map[string]string, usage string) { - f.VarP(newStringToStringValue(value, p), name, shorthand, usage) -} -// StringToStringVar defines a string flag with specified name, default value, and usage string. -// The argument p points to a map[string]string variable in which to store the value of the flag. -// The value of each argument will not try to be separated by comma -func StringToStringVar(p *map[string]string, name string, value map[string]string, usage string) { - CommandLine.VarP(newStringToStringValue(value, p), name, "", usage) -} + fv, ok := val.(*stringToStringValue) + if !ok { + panic(fmt.Errorf("illegal state: unspected internal type for stringToString flag '%s'", name)) + } + if fv.value == nil { + return nil, nil + } -// StringToStringVarP is like StringToStringVar, but accepts a shorthand letter that can be used after a single dash. -func StringToStringVarP(p *map[string]string, name, shorthand string, value map[string]string, usage string) { - CommandLine.VarP(newStringToStringValue(value, p), name, shorthand, usage) + return *fv.value, nil } -// StringToString defines a string flag with specified name, default value, and usage string. -// The return value is the address of a map[string]string variable that stores the value of the flag. -// The value of each argument will not try to be separated by comma +// StringToString defines a map flag with specified name, default value, and usage string. +// +// StringToString flags are used to pass key=value pairs to applications. The same flag can be provided more than once +// with all key=value pairs being merged into a final map. Multiple key=value pairs may be provided in a single arg, +// separated by commas. A few simple examples include: +// +// --arg a=1 +// --arg a=1 --arg b=2 +// --arg a=1,b=2 +// --arg=a=1 +// +// As a special case, a single key-value pair with a value containing a comma will be interpreted as a single pair: +// +// --arg a=1,2 +// +// Returns a pointer to the map which will be updated upon invocation of [FlagSet.Parse], [Flag.Value.Set], and others. func (f *FlagSet) StringToString(name string, value map[string]string, usage string) *map[string]string { p := map[string]string{} f.StringToStringVarP(&p, name, "", value, usage) return &p } -// StringToStringP is like StringToString, but accepts a shorthand letter that can be used after a single dash. +// StringToStringP is like [FlagSet.StringToString], but also accepts a shorthand letter that can be used after a single +// dash. +// +// See [FlagSet.StringToString]. func (f *FlagSet) StringToStringP(name, shorthand string, value map[string]string, usage string) *map[string]string { p := map[string]string{} f.StringToStringVarP(&p, name, shorthand, value, usage) return &p } +// StringToStringVar is like [FlagSet.StringToString], but also accepts a map ppointer argument p which is updated with +// the parsed key-value pairs. +// +// See [FlagSet.StringToString]. +func (f *FlagSet) StringToStringVar(p *map[string]string, name string, value map[string]string, usage string) { + f.VarP(newStringToStringValue(value, p), name, "", usage) +} + +// StringToStringVarP is like [FlagSet.StringToString], but also accepts a map ppointer argument p which is updated with +// the parsed key-value pairs, and a shorthand letter that can be used after a single dash. +// +// See [FlagSet.StringToString]. +func (f *FlagSet) StringToStringVarP(p *map[string]string, name, shorthand string, value map[string]string, usage string) { + f.VarP(newStringToStringValue(value, p), name, shorthand, usage) +} + // StringToString defines a string flag with specified name, default value, and usage string. -// The return value is the address of a map[string]string variable that stores the value of the flag. -// The value of each argument will not try to be separated by comma +// +// See [FlagSet.StringToString]. func StringToString(name string, value map[string]string, usage string) *map[string]string { return CommandLine.StringToStringP(name, "", value, usage) } -// StringToStringP is like StringToString, but accepts a shorthand letter that can be used after a single dash. +// StringToStringP is like [FlagSet.StringToString], but also accepts a shorthand letter that can be used after a single +// dash. +// +// See [FlagSet.StringToString]. func StringToStringP(name, shorthand string, value map[string]string, usage string) *map[string]string { return CommandLine.StringToStringP(name, shorthand, value, usage) } + +// StringToStringVar is like [FlagSet.StringToString], but also accepts a map ppointer argument p which is updated with +// the parsed key-value pairs. +// +// See [FlagSet.StringToString]. +func StringToStringVar(p *map[string]string, name string, value map[string]string, usage string) { + CommandLine.VarP(newStringToStringValue(value, p), name, "", usage) +} + +// StringToStringVarP is like [FlagSet.StringToString], but also accepts a map ppointer argument p which is updated with +// the parsed key-value pairs, and a shorthand letter that can be used after a single dash. +// +// See [FlagSet.StringToString]. +func StringToStringVarP(p *map[string]string, name, shorthand string, value map[string]string, usage string) { + CommandLine.VarP(newStringToStringValue(value, p), name, shorthand, usage) +} diff --git a/string_to_string_test.go b/string_to_string_test.go index 0777f03f..707da48a 100644 --- a/string_to_string_test.go +++ b/string_to_string_test.go @@ -5,158 +5,243 @@ package pflag import ( - "bytes" - "encoding/csv" - "fmt" - "strings" + "reflect" "testing" ) -func setUpS2SFlagSet(s2sp *map[string]string) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringToStringVar(s2sp, "s2s", map[string]string{}, "Command separated ls2st!") - return f +func TestStringToString(t *testing.T) { + tt := []struct { + args []string + def map[string]string + expected map[string]string + }{ + { + // should permit no args and defaults + args: []string{}, + def: map[string]string{}, + expected: map[string]string{}, + }, + { + // should use defaults when no args given + args: []string{}, + def: map[string]string{"a": "1", "b": "2"}, + expected: map[string]string{"a": "1", "b": "2"}, + }, + { + // should parse single key-value pair + args: []string{"--arg", "a=1"}, + def: map[string]string{}, + expected: map[string]string{"a": "1"}, + }, + { + // should allow comma-separated key-value pairs + args: []string{"--arg", "a=1,b=2"}, + def: map[string]string{}, + expected: map[string]string{"a": "1", "b": "2"}, + }, + { + // should correctly parse values with commas + args: []string{"--arg", "a=1,2"}, + def: map[string]string{}, + expected: map[string]string{"a": "1,2"}, + }, + { + // should correctly parse values with equal symbols + args: []string{"--arg", "a=1="}, + def: map[string]string{}, + expected: map[string]string{"a": "1="}, + }, + { + // should allow multiple map args, merging into a single result + args: []string{"--arg", "a=1,b=2", "--arg", "c=3", "--arg", "a=2"}, + def: map[string]string{}, + expected: map[string]string{"a": "2", "b": "2", "c": "3"}, + }, + { + // should ensure command-line args take precedence over defaults + args: []string{"--arg", "a=4"}, + def: map[string]string{"a": "1", "b": "2"}, + expected: map[string]string{"a": "4"}, + }, + { + // should allow quoting of values to handle values with '=' and ',' + args: []string{"--arg", `"foo=bar,bar=qix",qix=foo`}, + def: map[string]string{}, + expected: map[string]string{"foo": "bar,bar=qix", "qix": "foo"}, + }, + { + // should allow quoting of values to handle values with '=' and ',' + args: []string{"--arg", `"foo=bar,bar=qix"`, "--arg", "qix=foo"}, + def: map[string]string{}, + expected: map[string]string{"foo": "bar,bar=qix", "qix": "foo"}, + }, + { + // should allow stuck values + args: []string{`--arg="e=5,6",a=1,b=2,d=4,c=3`}, + def: map[string]string{}, + expected: map[string]string{"a": "1", "b": "2", "d": "4", "c": "3", "e": "5,6"}, + }, + { + // should allow stuck values with defaults + args: []string{`--arg=a=1,b=2,"e=5,6"`}, + def: map[string]string{"da": "1", "db": "2", "de": "5,6"}, + expected: map[string]string{"a": "1", "b": "2", "e": "5,6"}, + }, + { + // should allow multiple stuck value args + args: []string{"--arg=a=1,b=2", "--arg=b=3", `--arg="e=5,6"`, `--arg=f=7,8`}, + def: map[string]string{}, + expected: map[string]string{"a": "1", "b": "3", "e": "5,6", "f": "7,8"}, + }, + { + // should parse arg with empty key and value + args: []string{"--arg", "="}, + def: map[string]string{}, + expected: map[string]string{"": ""}, + }, + { + // should parse comma delimited empty mappings + args: []string{"--arg", "=,=,="}, + def: map[string]string{}, + expected: map[string]string{"": ""}, + }, + { + // should peremit overlapping mappings + args: []string{"--arg", "a=1,a=2"}, + def: map[string]string{}, + expected: map[string]string{"a": "2"}, + }, + { + // should correctly parse short args + args: []string{"-a", "a=1,b=2", "-a=c=3"}, + def: map[string]string{}, + expected: map[string]string{"a": "1", "b": "2", "c": "3"}, + }, + } + + for num, test := range tt { + t.Logf("=== TEST %d ===", num) + t.Logf(" Args: %v", test.args) + t.Logf(" Default Value: %v", test.def) + t.Logf(" Expected: %v", test.expected) + + f := NewFlagSet("test", ContinueOnError) + f.StringToStringP("arg", "a", test.def, "test string-to-string arg") + + if err := f.Parse(test.args); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + result, err := f.GetStringToString("arg") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + t.Logf(" Actual: %v", result) + + for k, v := range test.expected { + actual, ok := result[k] + if !ok { + t.Fatalf("missing key in result: %s", k) + } + if actual != v { + t.Fatalf("unexpected value in result for key '%s': %s", k, actual) + } + } + + if len(test.expected) != len(result) { + t.Fatalf("unexpected extra key-value pairs in result: %v", result) + } + } } -func setUpS2SFlagSetWithDefault(s2sp *map[string]string) *FlagSet { +// This test ensures that [FlagSet.GetStringToString] always return the pointers which were given during flag +// initialization. +// +// This behaviour is important as it ensures consumers of the library can access the underlying map in a stable, +// consistent manner. +func TestS2SStablePointers(t *testing.T) { f := NewFlagSet("test", ContinueOnError) - f.StringToStringVar(s2sp, "s2s", map[string]string{"da": "1", "db": "2", "de": "5,6"}, "Command separated ls2st!") - return f -} -func createS2SFlag(vals map[string]string) string { - records := make([]string, 0, len(vals)>>1) - for k, v := range vals { - records = append(records, k+"="+v) - } + defval := map[string]string{"a": "1", "b": "2"} + + ptr := f.StringToString("map-flag", defval, "test for s2s arg") - var buf bytes.Buffer - w := csv.NewWriter(&buf) - if err := w.Write(records); err != nil { - panic(err) + if reflect.ValueOf(*ptr).Pointer() != reflect.ValueOf(defval).Pointer() { + t.Fatal("pointer mismatch") } - w.Flush() - return strings.TrimSpace(buf.String()) -} -func TestEmptyS2S(t *testing.T) { - var s2s map[string]string - f := setUpS2SFlagSet(&s2s) - err := f.Parse([]string{}) + // initially, arg should have defaults + result0, err := f.GetStringToString("map-flag") if err != nil { - t.Fatal("expected no error; got", err) + t.Fatalf("unexpected error: %v", err) + } + if v, ok := result0["a"]; !ok || v != "1" { + t.Fatalf("value not present in map or unexpected value: %v", result0) + } + if v, ok := result0["b"]; !ok || v != "2" { + t.Fatalf("value not present in map or unexpected value: %v", result0) } - getS2S, err := f.GetStringToString("s2s") - if err != nil { - t.Fatal("got an error from GetStringToString():", err) + if reflect.ValueOf(result0).Pointer() != reflect.ValueOf(defval).Pointer() { + t.Fatal("pointer mismatch") } - if len(getS2S) != 0 { - t.Fatalf("got s2s %v with len=%d but expected length=0", getS2S, len(getS2S)) + if reflect.ValueOf(*ptr).Pointer() != reflect.ValueOf(result0).Pointer() { + t.Fatal("pointer mismatch") } -} -func TestS2S(t *testing.T) { - var s2s map[string]string - f := setUpS2SFlagSet(&s2s) + // manipulate the map; the map should now have a single mapping and the pointers should be stable + if err := f.Set("map-flag", "c=3"); err != nil { + t.Fatalf("unexpected error: %v", err) + } - vals := map[string]string{"a": "1", "b": "2", "d": "4", "c": "3", "e": "5,6"} - arg := fmt.Sprintf("--s2s=%s", createS2SFlag(vals)) - err := f.Parse([]string{arg}) + result1, err := f.GetStringToString("map-flag") if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2s { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s but got: %s", k, vals[k], v) - } + t.Fatalf("unexpected error: %v", err) } - getS2S, err := f.GetStringToString("s2s") - if err != nil { - t.Fatalf("got error: %v", err) + + if reflect.ValueOf(*ptr).Pointer() != reflect.ValueOf(result1).Pointer() { + t.Fatal("pointer mismatch") } - for k, v := range getS2S { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s but got: %s from GetStringToString", k, vals[k], v) - } + if reflect.ValueOf(result0).Pointer() != reflect.ValueOf(result1).Pointer() { + t.Fatal("pointer mismatch") } -} - -func TestS2SDefault(t *testing.T) { - var s2s map[string]string - f := setUpS2SFlagSetWithDefault(&s2s) - vals := map[string]string{"da": "1", "db": "2", "de": "5,6"} + // manipulate the map once more + if err := f.Set("map-flag", "d=4"); err != nil { + t.Fatalf("unexpected error: %v", err) + } - err := f.Parse([]string{}) + result2, err := f.GetStringToString("map-flag") if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2s { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s but got: %s", k, vals[k], v) - } + t.Fatalf("unexpected error: %v", err) } - getS2S, err := f.GetStringToString("s2s") - if err != nil { - t.Fatal("got an error from GetStringToString():", err) + if reflect.ValueOf(*ptr).Pointer() != reflect.ValueOf(result2).Pointer() { + t.Fatal("pointer mismatch") } - for k, v := range getS2S { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s from GetStringToString but got: %s", k, vals[k], v) - } + if reflect.ValueOf(result1).Pointer() != reflect.ValueOf(result2).Pointer() { + t.Fatal("pointer mismatch") } -} - -func TestS2SWithDefault(t *testing.T) { - var s2s map[string]string - f := setUpS2SFlagSetWithDefault(&s2s) - vals := map[string]string{"a": "1", "b": "2", "e": "5,6"} - arg := fmt.Sprintf("--s2s=%s", createS2SFlag(vals)) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) + // check that the newly added flag value was updated + if v, ok := result1["c"]; !ok || v != "3" { + t.Fatalf("value not present in map or unexpected value: %v", result1) } - for k, v := range s2s { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s but got: %s", k, vals[k], v) - } + if v, ok := result1["d"]; !ok || v != "4" { + t.Fatalf("value not present in map or unexpected value: %v", result1) } - getS2S, err := f.GetStringToString("s2s") - if err != nil { - t.Fatal("got an error from GetStringToString():", err) + // finally, if we clear the map, it should reset flag + for k := range result1 { + delete(result1, k) } - for k, v := range getS2S { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s from GetStringToString but got: %s", k, vals[k], v) - } - } -} -func TestS2SCalledTwice(t *testing.T) { - var s2s map[string]string - f := setUpS2SFlagSet(&s2s) - - in := []string{"a=1,b=2", "b=3", `"e=5,6"`, `f=7,8`} - expected := map[string]string{"a": "1", "b": "3", "e": "5,6", "f": "7,8"} - argfmt := "--s2s=%s" - arg0 := fmt.Sprintf(argfmt, in[0]) - arg1 := fmt.Sprintf(argfmt, in[1]) - arg2 := fmt.Sprintf(argfmt, in[2]) - arg3 := fmt.Sprintf(argfmt, in[3]) - err := f.Parse([]string{arg0, arg1, arg2, arg3}) + result3, err := f.GetStringToString("map-flag") if err != nil { - t.Fatal("expected no error; got", err) + t.Fatalf("unexpected error: %v", err) } - if len(s2s) != len(expected) { - t.Fatalf("expected %d flags; got %d flags", len(expected), len(s2s)) - } - for i, v := range s2s { - if expected[i] != v { - t.Fatalf("expected s2s[%s] to be %s but got: %s", i, expected[i], v) - } + if len(result3) != 0 { + t.Fatalf("unexpected map values: %v", result3) } }