summaryrefslogtreecommitdiff
path: root/src/cmd/gofix/main_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/cmd/gofix/main_test.go')
-rw-r--r--src/cmd/gofix/main_test.go31
1 files changed, 17 insertions, 14 deletions
diff --git a/src/cmd/gofix/main_test.go b/src/cmd/gofix/main_test.go
index 275778e5b..2151bf29e 100644
--- a/src/cmd/gofix/main_test.go
+++ b/src/cmd/gofix/main_test.go
@@ -5,10 +5,8 @@
package main
import (
- "bytes"
"go/ast"
"go/parser"
- "go/printer"
"strings"
"testing"
)
@@ -22,27 +20,33 @@ type testCase struct {
var testCases []testCase
-func addTestCases(t []testCase) {
+func addTestCases(t []testCase, fn func(*ast.File) bool) {
+ // Fill in fn to avoid repetition in definitions.
+ if fn != nil {
+ for i := range t {
+ if t[i].Fn == nil {
+ t[i].Fn = fn
+ }
+ }
+ }
testCases = append(testCases, t...)
}
func fnop(*ast.File) bool { return false }
-func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) {
+func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
file, err := parser.ParseFile(fset, desc, in, parserMode)
if err != nil {
t.Errorf("%s: parsing: %v", desc, err)
return
}
- var buf bytes.Buffer
- buf.Reset()
- _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
+ outb, err := gofmtFile(file)
if err != nil {
t.Errorf("%s: printing: %v", desc, err)
return
}
- if s := buf.String(); in != s && fn != fnop {
+ if s := string(outb); in != s && mustBeGofmt {
t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
desc, desc, in, desc, s)
tdiff(t, in, s)
@@ -59,26 +63,25 @@ func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out
fixed = fn(file)
}
- buf.Reset()
- _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
+ outb, err = gofmtFile(file)
if err != nil {
t.Errorf("%s: printing: %v", desc, err)
return
}
- return buf.String(), fixed, true
+ return string(outb), fixed, true
}
func TestRewrite(t *testing.T) {
for _, tt := range testCases {
// Apply fix: should get tt.Out.
- out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In)
+ out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
if !ok {
continue
}
// reformat to get printing right
- out, _, ok = parseFixPrint(t, fnop, tt.Name, out)
+ out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
if !ok {
continue
}
@@ -98,7 +101,7 @@ func TestRewrite(t *testing.T) {
}
// Should not change if run again.
- out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out)
+ out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
if !ok {
continue
}