diff options
Diffstat (limited to 'src/cmd/gofix/main_test.go')
| -rw-r--r-- | src/cmd/gofix/main_test.go | 31 |
1 files changed, 17 insertions, 14 deletions
diff --git a/src/cmd/gofix/main_test.go b/src/cmd/gofix/main_test.go index 275778e5b..2151bf29e 100644 --- a/src/cmd/gofix/main_test.go +++ b/src/cmd/gofix/main_test.go @@ -5,10 +5,8 @@ package main import ( - "bytes" "go/ast" "go/parser" - "go/printer" "strings" "testing" ) @@ -22,27 +20,33 @@ type testCase struct { var testCases []testCase -func addTestCases(t []testCase) { +func addTestCases(t []testCase, fn func(*ast.File) bool) { + // Fill in fn to avoid repetition in definitions. + if fn != nil { + for i := range t { + if t[i].Fn == nil { + t[i].Fn = fn + } + } + } testCases = append(testCases, t...) } func fnop(*ast.File) bool { return false } -func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) { +func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) { file, err := parser.ParseFile(fset, desc, in, parserMode) if err != nil { t.Errorf("%s: parsing: %v", desc, err) return } - var buf bytes.Buffer - buf.Reset() - _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file) + outb, err := gofmtFile(file) if err != nil { t.Errorf("%s: printing: %v", desc, err) return } - if s := buf.String(); in != s && fn != fnop { + if s := string(outb); in != s && mustBeGofmt { t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s", desc, desc, in, desc, s) tdiff(t, in, s) @@ -59,26 +63,25 @@ func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out fixed = fn(file) } - buf.Reset() - _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file) + outb, err = gofmtFile(file) if err != nil { t.Errorf("%s: printing: %v", desc, err) return } - return buf.String(), fixed, true + return string(outb), fixed, true } func TestRewrite(t *testing.T) { for _, tt := range testCases { // Apply fix: should get tt.Out. - out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In) + out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true) if !ok { continue } // reformat to get printing right - out, _, ok = parseFixPrint(t, fnop, tt.Name, out) + out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false) if !ok { continue } @@ -98,7 +101,7 @@ func TestRewrite(t *testing.T) { } // Should not change if run again. - out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out) + out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true) if !ok { continue } |
