diff options
-rw-r--r-- | src/lib/flag.go | 138 | ||||
-rw-r--r-- | src/lib/flag_test.go | 56 |
2 files changed, 139 insertions, 55 deletions
diff --git a/src/lib/flag.go b/src/lib/flag.go index 35e18f9e6..541966d87 100644 --- a/src/lib/flag.go +++ b/src/lib/flag.go @@ -115,11 +115,13 @@ func newBoolValue(val bool, p *bool) *boolValue { return &boolValue(p) } -func (b *boolValue) set(val bool) { - *b.p = val; +func (b *boolValue) set(s string) bool { + v, ok := atob(s); + *b.p = v; + return ok } -func (b *boolValue) str() string { +func (b *boolValue) String() string { return fmt.Sprintf("%v", *b.p) } @@ -133,11 +135,13 @@ func newIntValue(val int, p *int) *intValue { return &intValue(p) } -func (i *intValue) set(val int) { - *i.p = val; +func (i *intValue) set(s string) bool { + v, ok := atoi(s); + *i.p = int(v); + return ok } -func (i *intValue) str() string { +func (i *intValue) String() string { return fmt.Sprintf("%v", *i.p) } @@ -151,11 +155,13 @@ func newInt64Value(val int64, p *int64) *int64Value { return &int64Value(p) } -func (i *int64Value) set(val int64) { - *i.p = val; +func (i *int64Value) set(s string) bool { + v, ok := atoi(s); + *i.p = v; + return ok; } -func (i *int64Value) str() string { +func (i *int64Value) String() string { return fmt.Sprintf("%v", *i.p) } @@ -169,11 +175,13 @@ func newUintValue(val uint, p *uint) *uintValue { return &uintValue(p) } -func (i *uintValue) set(val uint) { - *i.p = val +func (i *uintValue) set(s string) bool { + v, ok := atoi(s); // TODO(r): want unsigned + *i.p = uint(v); + return ok; } -func (i *uintValue) str() string { +func (i *uintValue) String() string { return fmt.Sprintf("%v", *i.p) } @@ -187,11 +195,13 @@ func newUint64Value(val uint64, p *uint64) *uint64Value { return &uint64Value(p) } -func (i *uint64Value) set(val uint64) { - *i.p = val; +func (i *uint64Value) set(s string) bool { + v, ok := atoi(s); // TODO(r): want unsigned + *i.p = uint64(v); + return ok; } -func (i *uint64Value) str() string { +func (i *uint64Value) String() string { return fmt.Sprintf("%v", *i.p) } @@ -205,25 +215,27 @@ func newStringValue(val string, p *string) *stringValue { return &stringValue(p) } -func (s *stringValue) set(val string) { +func (s *stringValue) set(val string) bool { *s.p = val; + return true; } -func (s *stringValue) str() string { +func (s *stringValue) String() string { return fmt.Sprintf("%#q", *s.p) } -// -- Value interface -type _Value interface { - str() string; +// -- FlagValue interface +type FlagValue interface { + String() string; + set(string) bool; } -// -- Flag structure (internal) +// -- Flag structure type Flag struct { - name string; // name as it appears on command line - usage string; // help message - value _Value; // value as set - defvalue string; // default value (as text); for usage message + Name string; // name as it appears on command line + Usage string; // help message + Value FlagValue; // value as set + DefValue string; // default value (as text); for usage message } type allFlags struct { @@ -234,10 +246,45 @@ type allFlags struct { var flags *allFlags = &allFlags(make(map[string] *Flag), make(map[string] *Flag), 1) -func PrintDefaults() { +// Visit all flags, including those defined but not set. +func VisitAll(fn func(*Flag)) { for k, f := range flags.formal { - print(" -", f.name, "=", f.defvalue, ": ", f.usage, "\n"); + fn(f) + } +} + +// Visit only those flags that have been set +func Visit(fn func(*Flag)) { + for k, f := range flags.actual { + fn(f) + } +} + +func Lookup(name string) *Flag { + f, ok := flags.formal[name]; + if !ok { + return nil + } + return f +} + +func Set(name, value string) bool { + f, ok := flags.formal[name]; + if !ok { + return false + } + ok = f.Value.set(value); + if !ok { + return false } + flags.actual[name] = f; + return true; +} + +func PrintDefaults() { + VisitAll(func(f *Flag) { + print(" -", f.Name, "=", f.DefValue, ": ", f.Usage, "\n"); + }) } func Usage() { @@ -266,12 +313,9 @@ func NArg() int { return len(sys.Args) - flags.first_arg } -func add(name string, value _Value, usage string) { - f := new(Flag); - f.name = name; - f.usage = usage; - f.value = value; - f.defvalue = value.str(); // Remember the default value as a string; it won't change. +func add(name string, value FlagValue, usage string) { + // Remember the default value as a string; it won't change. + f := &Flag(name, usage, value, value.String()); dummy, alreadythere := flags.formal[name]; if alreadythere { print("flag redefined: ", name, "\n"); @@ -388,16 +432,14 @@ func (f *allFlags) parseOne(index int) (ok bool, next int) print("flag provided but not defined: -", name, "\n"); Usage(); } - if f, ok := flag.value.(*boolValue); ok { + if f, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg if has_value { - k, ok := atob(value); - if !ok { + if !f.set(value) { print("invalid boolean value ", value, " for flag: -", name, "\n"); Usage(); } - f.set(k) } else { - f.set(true) + f.set("true") } } else { // It must have a value, which might be the next argument. @@ -411,24 +453,10 @@ func (f *allFlags) parseOne(index int) (ok bool, next int) print("flag needs an argument: -", name, "\n"); Usage(); } - if f, ok := flag.value.(*stringValue); ok { - f.set(value) - } else { - // It's an integer flag. TODO(r): check for overflow? - k, ok := atoi(value); - if !ok { - print("invalid integer value ", value, " for flag: -", name, "\n"); + ok = flag.Value.set(value); + if !ok { + print("invalid value ", value, " for flag: -", name, "\n"); Usage(); - } - if f, ok := flag.value.(*intValue); ok { - f.set(int(k)); - } else if f, ok := flag.value.(*int64Value); ok { - f.set(k); - } else if f, ok := flag.value.(*uintValue); ok { - f.set(uint(k)); - } else if f, ok := flag.value.(*uint64Value); ok { - f.set(uint64(k)); - } } } flags.actual[name] = flag; diff --git a/src/lib/flag_test.go b/src/lib/flag_test.go new file mode 100644 index 000000000..1212cf89f --- /dev/null +++ b/src/lib/flag_test.go @@ -0,0 +1,56 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flag + +import ( + "flag"; + "fmt"; + "testing"; +) + +var ( + test_bool = flag.Bool("test_bool", true, "bool value"); + test_int = flag.Int("test_int", 1, "int value"); + test_int64 = flag.Int64("test_int64", 1, "int64 value"); + test_uint = flag.Uint("test_uint", 1, "uint value"); + test_uint64 = flag.Uint64("test_uint64", 1, "uint64 value"); + test_string = flag.String("test_string", "1", "string value"); +) + +// Because this calls flag.Parse, it needs to be the only Test* function +func TestEverything(t *testing.T) { + flag.Parse(); + m := make(map[string] *flag.Flag); + visitor := func(f *flag.Flag) { + if len(f.Name) > 5 && f.Name[0:5] == "test_" { + m[f.Name] = f + } + }; + flag.VisitAll(visitor); + if len(m) != 6 { + t.Error("flag.VisitAll misses some flags"); + for k, v := range m { + t.Log(k, *v) + } + } + m = make(map[string] *flag.Flag); + flag.Visit(visitor); + if len(m) != 0 { + t.Errorf("flag.Visit sees unset flags"); + for k, v := range m { + t.Log(k, *v) + } + } + // Now set some flags + flag.Set("test_bool", "false"); + flag.Set("test_uint", "1234"); + flag.Visit(visitor); + if len(m) != 2 { + t.Error("flag.Visit fails after set"); + for k, v := range m { + t.Log(k, *v) + } + } +} |