summaryrefslogtreecommitdiff
path: root/src/pkg/flag/flag_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/flag/flag_test.go')
-rw-r--r--src/pkg/flag/flag_test.go62
1 files changed, 51 insertions, 11 deletions
diff --git a/src/pkg/flag/flag_test.go b/src/pkg/flag/flag_test.go
index a9561f269..ddd54b277 100644
--- a/src/pkg/flag/flag_test.go
+++ b/src/pkg/flag/flag_test.go
@@ -15,17 +15,6 @@ import (
"time"
)
-var (
- test_bool = Bool("test_bool", false, "bool value")
- test_int = Int("test_int", 0, "int value")
- test_int64 = Int64("test_int64", 0, "int64 value")
- test_uint = Uint("test_uint", 0, "uint value")
- test_uint64 = Uint64("test_uint64", 0, "uint64 value")
- test_string = String("test_string", "0", "string value")
- test_float64 = Float64("test_float64", 0, "float64 value")
- test_duration = Duration("test_duration", 0, "time.Duration value")
-)
-
func boolString(s string) string {
if s == "0" {
return "false"
@@ -34,6 +23,16 @@ func boolString(s string) string {
}
func TestEverything(t *testing.T) {
+ ResetForTesting(nil)
+ Bool("test_bool", false, "bool value")
+ Int("test_int", 0, "int value")
+ Int64("test_int64", 0, "int64 value")
+ Uint("test_uint", 0, "uint value")
+ Uint64("test_uint64", 0, "uint64 value")
+ String("test_string", "0", "string value")
+ Float64("test_float64", 0, "float64 value")
+ Duration("test_duration", 0, "time.Duration value")
+
m := make(map[string]*Flag)
desired := "0"
visitor := func(f *Flag) {
@@ -208,6 +207,47 @@ func TestUserDefined(t *testing.T) {
}
}
+// Declare a user-defined boolean flag type.
+type boolFlagVar struct {
+ count int
+}
+
+func (b *boolFlagVar) String() string {
+ return fmt.Sprintf("%d", b.count)
+}
+
+func (b *boolFlagVar) Set(value string) error {
+ if value == "true" {
+ b.count++
+ }
+ return nil
+}
+
+func (b *boolFlagVar) IsBoolFlag() bool {
+ return b.count < 4
+}
+
+func TestUserDefinedBool(t *testing.T) {
+ var flags FlagSet
+ flags.Init("test", ContinueOnError)
+ var b boolFlagVar
+ var err error
+ flags.Var(&b, "b", "usage")
+ if err = flags.Parse([]string{"-b", "-b", "-b", "-b=true", "-b=false", "-b", "barg", "-b"}); err != nil {
+ if b.count < 4 {
+ t.Error(err)
+ }
+ }
+
+ if b.count != 4 {
+ t.Errorf("want: %d; got: %d", 4, b.count)
+ }
+
+ if err == nil {
+ t.Error("expected error; got none")
+ }
+}
+
func TestSetOutput(t *testing.T) {
var flags FlagSet
var buf bytes.Buffer