summaryrefslogtreecommitdiff
path: root/src/pkg/flag
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/flag')
-rw-r--r--src/pkg/flag/flag.go19
-rw-r--r--src/pkg/flag/flag_test.go62
2 files changed, 67 insertions, 14 deletions
diff --git a/src/pkg/flag/flag.go b/src/pkg/flag/flag.go
index 5444ad141..85dd8c3b3 100644
--- a/src/pkg/flag/flag.go
+++ b/src/pkg/flag/flag.go
@@ -33,7 +33,7 @@
After parsing, the arguments after the flag are available as the
slice flag.Args() or individually as flag.Arg(i).
- The arguments are indexed from 0 up to flag.NArg().
+ The arguments are indexed from 0 through flag.NArg()-1.
Command line flag syntax:
-flag
@@ -91,6 +91,15 @@ func (b *boolValue) Set(s string) error {
func (b *boolValue) String() string { return fmt.Sprintf("%v", *b) }
+func (b *boolValue) IsBoolFlag() bool { return true }
+
+// optional interface to indicate boolean flags that can be
+// supplied without "=value" text
+type boolFlag interface {
+ Value
+ IsBoolFlag() bool
+}
+
// -- int Value
type intValue int
@@ -204,6 +213,10 @@ func (d *durationValue) String() string { return (*time.Duration)(d).String() }
// Value is the interface to the dynamic value stored in a flag.
// (The default value is represented as a string.)
+//
+// If a Value has an IsBoolFlag() bool method returning true,
+// the command-line parser makes -name equivalent to -name=true
+// rather than using the next command-line argument.
type Value interface {
String() string
Set(string) error
@@ -704,10 +717,10 @@ func (f *FlagSet) parseOne() (bool, error) {
}
return false, f.failf("flag provided but not defined: -%s", name)
}
- if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg
+ if fv, ok := flag.Value.(boolFlag); ok && fv.IsBoolFlag() { // special case: doesn't need an arg
if has_value {
if err := fv.Set(value); err != nil {
- f.failf("invalid boolean value %q for -%s: %v", value, name, err)
+ return false, f.failf("invalid boolean value %q for -%s: %v", value, name, err)
}
} else {
fv.Set("true")
diff --git a/src/pkg/flag/flag_test.go b/src/pkg/flag/flag_test.go
index a9561f269..ddd54b277 100644
--- a/src/pkg/flag/flag_test.go
+++ b/src/pkg/flag/flag_test.go
@@ -15,17 +15,6 @@ import (
"time"
)
-var (
- test_bool = Bool("test_bool", false, "bool value")
- test_int = Int("test_int", 0, "int value")
- test_int64 = Int64("test_int64", 0, "int64 value")
- test_uint = Uint("test_uint", 0, "uint value")
- test_uint64 = Uint64("test_uint64", 0, "uint64 value")
- test_string = String("test_string", "0", "string value")
- test_float64 = Float64("test_float64", 0, "float64 value")
- test_duration = Duration("test_duration", 0, "time.Duration value")
-)
-
func boolString(s string) string {
if s == "0" {
return "false"
@@ -34,6 +23,16 @@ func boolString(s string) string {
}
func TestEverything(t *testing.T) {
+ ResetForTesting(nil)
+ Bool("test_bool", false, "bool value")
+ Int("test_int", 0, "int value")
+ Int64("test_int64", 0, "int64 value")
+ Uint("test_uint", 0, "uint value")
+ Uint64("test_uint64", 0, "uint64 value")
+ String("test_string", "0", "string value")
+ Float64("test_float64", 0, "float64 value")
+ Duration("test_duration", 0, "time.Duration value")
+
m := make(map[string]*Flag)
desired := "0"
visitor := func(f *Flag) {
@@ -208,6 +207,47 @@ func TestUserDefined(t *testing.T) {
}
}
+// Declare a user-defined boolean flag type.
+type boolFlagVar struct {
+ count int
+}
+
+func (b *boolFlagVar) String() string {
+ return fmt.Sprintf("%d", b.count)
+}
+
+func (b *boolFlagVar) Set(value string) error {
+ if value == "true" {
+ b.count++
+ }
+ return nil
+}
+
+func (b *boolFlagVar) IsBoolFlag() bool {
+ return b.count < 4
+}
+
+func TestUserDefinedBool(t *testing.T) {
+ var flags FlagSet
+ flags.Init("test", ContinueOnError)
+ var b boolFlagVar
+ var err error
+ flags.Var(&b, "b", "usage")
+ if err = flags.Parse([]string{"-b", "-b", "-b", "-b=true", "-b=false", "-b", "barg", "-b"}); err != nil {
+ if b.count < 4 {
+ t.Error(err)
+ }
+ }
+
+ if b.count != 4 {
+ t.Errorf("want: %d; got: %d", 4, b.count)
+ }
+
+ if err == nil {
+ t.Error("expected error; got none")
+ }
+}
+
func TestSetOutput(t *testing.T) {
var flags FlagSet
var buf bytes.Buffer