diff options
Diffstat (limited to 'src/pkg/flag')
-rw-r--r-- | src/pkg/flag/flag.go | 19 | ||||
-rw-r--r-- | src/pkg/flag/flag_test.go | 62 |
2 files changed, 67 insertions, 14 deletions
diff --git a/src/pkg/flag/flag.go b/src/pkg/flag/flag.go index 5444ad141..85dd8c3b3 100644 --- a/src/pkg/flag/flag.go +++ b/src/pkg/flag/flag.go @@ -33,7 +33,7 @@ After parsing, the arguments after the flag are available as the slice flag.Args() or individually as flag.Arg(i). - The arguments are indexed from 0 up to flag.NArg(). + The arguments are indexed from 0 through flag.NArg()-1. Command line flag syntax: -flag @@ -91,6 +91,15 @@ func (b *boolValue) Set(s string) error { func (b *boolValue) String() string { return fmt.Sprintf("%v", *b) } +func (b *boolValue) IsBoolFlag() bool { return true } + +// optional interface to indicate boolean flags that can be +// supplied without "=value" text +type boolFlag interface { + Value + IsBoolFlag() bool +} + // -- int Value type intValue int @@ -204,6 +213,10 @@ func (d *durationValue) String() string { return (*time.Duration)(d).String() } // Value is the interface to the dynamic value stored in a flag. // (The default value is represented as a string.) +// +// If a Value has an IsBoolFlag() bool method returning true, +// the command-line parser makes -name equivalent to -name=true +// rather than using the next command-line argument. type Value interface { String() string Set(string) error @@ -704,10 +717,10 @@ func (f *FlagSet) parseOne() (bool, error) { } return false, f.failf("flag provided but not defined: -%s", name) } - if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg + if fv, ok := flag.Value.(boolFlag); ok && fv.IsBoolFlag() { // special case: doesn't need an arg if has_value { if err := fv.Set(value); err != nil { - f.failf("invalid boolean value %q for -%s: %v", value, name, err) + return false, f.failf("invalid boolean value %q for -%s: %v", value, name, err) } } else { fv.Set("true") diff --git a/src/pkg/flag/flag_test.go b/src/pkg/flag/flag_test.go index a9561f269..ddd54b277 100644 --- a/src/pkg/flag/flag_test.go +++ b/src/pkg/flag/flag_test.go @@ -15,17 +15,6 @@ import ( "time" ) -var ( - test_bool = Bool("test_bool", false, "bool value") - test_int = Int("test_int", 0, "int value") - test_int64 = Int64("test_int64", 0, "int64 value") - test_uint = Uint("test_uint", 0, "uint value") - test_uint64 = Uint64("test_uint64", 0, "uint64 value") - test_string = String("test_string", "0", "string value") - test_float64 = Float64("test_float64", 0, "float64 value") - test_duration = Duration("test_duration", 0, "time.Duration value") -) - func boolString(s string) string { if s == "0" { return "false" @@ -34,6 +23,16 @@ func boolString(s string) string { } func TestEverything(t *testing.T) { + ResetForTesting(nil) + Bool("test_bool", false, "bool value") + Int("test_int", 0, "int value") + Int64("test_int64", 0, "int64 value") + Uint("test_uint", 0, "uint value") + Uint64("test_uint64", 0, "uint64 value") + String("test_string", "0", "string value") + Float64("test_float64", 0, "float64 value") + Duration("test_duration", 0, "time.Duration value") + m := make(map[string]*Flag) desired := "0" visitor := func(f *Flag) { @@ -208,6 +207,47 @@ func TestUserDefined(t *testing.T) { } } +// Declare a user-defined boolean flag type. +type boolFlagVar struct { + count int +} + +func (b *boolFlagVar) String() string { + return fmt.Sprintf("%d", b.count) +} + +func (b *boolFlagVar) Set(value string) error { + if value == "true" { + b.count++ + } + return nil +} + +func (b *boolFlagVar) IsBoolFlag() bool { + return b.count < 4 +} + +func TestUserDefinedBool(t *testing.T) { + var flags FlagSet + flags.Init("test", ContinueOnError) + var b boolFlagVar + var err error + flags.Var(&b, "b", "usage") + if err = flags.Parse([]string{"-b", "-b", "-b", "-b=true", "-b=false", "-b", "barg", "-b"}); err != nil { + if b.count < 4 { + t.Error(err) + } + } + + if b.count != 4 { + t.Errorf("want: %d; got: %d", 4, b.count) + } + + if err == nil { + t.Error("expected error; got none") + } +} + func TestSetOutput(t *testing.T) { var flags FlagSet var buf bytes.Buffer |