diff options
Diffstat (limited to 'src/cmd/gofmt')
-rw-r--r-- | src/cmd/gofmt/Makefile | 4 | ||||
-rw-r--r-- | src/cmd/gofmt/doc.go | 16 | ||||
-rw-r--r-- | src/cmd/gofmt/gofmt.go | 103 | ||||
-rw-r--r-- | src/cmd/gofmt/gofmt_test.go | 81 | ||||
-rw-r--r-- | src/cmd/gofmt/rewrite.go | 118 | ||||
-rw-r--r-- | src/cmd/gofmt/testdata/rewrite1.golden | 8 | ||||
-rw-r--r-- | src/cmd/gofmt/testdata/rewrite1.input | 8 | ||||
-rwxr-xr-x | src/cmd/gofmt/testdata/test.sh | 65 |
8 files changed, 230 insertions, 173 deletions
diff --git a/src/cmd/gofmt/Makefile b/src/cmd/gofmt/Makefile index 5f2f454e8..dc5b060e6 100644 --- a/src/cmd/gofmt/Makefile +++ b/src/cmd/gofmt/Makefile @@ -15,5 +15,5 @@ include ../../Make.cmd test: $(TARG) ./test.sh -smoketest: $(TARG) - (cd testdata; ./test.sh) +testshort: + gotest -test.short diff --git a/src/cmd/gofmt/doc.go b/src/cmd/gofmt/doc.go index 2d2c9ae61..e44030eee 100644 --- a/src/cmd/gofmt/doc.go +++ b/src/cmd/gofmt/doc.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. /* - Gofmt formats Go programs. Without an explicit path, it processes the standard input. Given a file, @@ -16,14 +15,16 @@ Usage: The flags are: -l - just list files whose formatting differs from gofmt's; generate no other output - unless -w is also set. + just list files whose formatting differs from gofmt's; + generate no other output unless -w is also set. -r rule apply the rewrite rule to the source before reformatting. -s try to simplify code (after applying the rewrite rule, if any). -w if set, overwrite each input file with its output. + -comments=true + print comments; if false, all comments are elided from the output. -spaces align with spaces instead of tabs. -tabindent @@ -31,15 +32,6 @@ The flags are: -tabwidth=8 tab width in spaces. -Debugging flags: - - -trace - print parse trace. - -ast - print AST (before rewrites). - -comments=true - print comments; if false, all comments are elided from the output. - The rewrite rule specified with the -r flag must be a string of the form: pattern -> replacement diff --git a/src/cmd/gofmt/gofmt.go b/src/cmd/gofmt/gofmt.go index 224aee717..ce274aa21 100644 --- a/src/cmd/gofmt/gofmt.go +++ b/src/cmd/gofmt/gofmt.go @@ -13,9 +13,11 @@ import ( "go/printer" "go/scanner" "go/token" + "io" "io/ioutil" "os" "path/filepath" + "runtime/pprof" "strings" ) @@ -27,15 +29,14 @@ var ( rewriteRule = flag.String("r", "", "rewrite rule (e.g., 'α[β:len(α)] -> α[β:]')") simplifyAST = flag.Bool("s", false, "simplify code") - // debugging support - comments = flag.Bool("comments", true, "print comments") - trace = flag.Bool("trace", false, "print parse trace") - printAST = flag.Bool("ast", false, "print AST (before rewrites)") - // layout control + comments = flag.Bool("comments", true, "print comments") tabWidth = flag.Int("tabwidth", 8, "tab width") tabIndent = flag.Bool("tabindent", true, "indent with tabs independent of -spaces") useSpaces = flag.Bool("spaces", true, "align with spaces instead of tabs") + + // debugging + cpuprofile = flag.String("cpuprofile", "", "write cpu profile to this file") ) @@ -66,9 +67,6 @@ func initParserMode() { if *comments { parserMode |= parser.ParseComments } - if *trace { - parserMode |= parser.Trace - } } @@ -89,20 +87,25 @@ func isGoFile(f *os.FileInfo) bool { } -func processFile(f *os.File) os.Error { - src, err := ioutil.ReadAll(f) - if err != nil { - return err +// If in == nil, the source is the contents of the file with the given filename. +func processFile(filename string, in io.Reader, out io.Writer) os.Error { + if in == nil { + f, err := os.Open(filename) + if err != nil { + return err + } + defer f.Close() + in = f } - file, err := parser.ParseFile(fset, f.Name(), src, parserMode) - + src, err := ioutil.ReadAll(in) if err != nil { return err } - if *printAST { - ast.Print(file) + file, err := parser.ParseFile(fset, filename, src, parserMode) + if err != nil { + return err } if rewrite != nil { @@ -123,10 +126,10 @@ func processFile(f *os.File) os.Error { if !bytes.Equal(src, res) { // formatting has changed if *list { - fmt.Fprintln(os.Stdout, f.Name()) + fmt.Fprintln(out, filename) } if *write { - err = ioutil.WriteFile(f.Name(), res, 0) + err = ioutil.WriteFile(filename, res, 0) if err != nil { return err } @@ -134,23 +137,13 @@ func processFile(f *os.File) os.Error { } if !*list && !*write { - _, err = os.Stdout.Write(res) + _, err = out.Write(res) } return err } -func processFileByName(filename string) os.Error { - file, err := os.Open(filename, os.O_RDONLY, 0) - if err != nil { - return err - } - defer file.Close() - return processFile(file) -} - - type fileVisitor chan os.Error func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { @@ -161,7 +154,7 @@ func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { func (v fileVisitor) VisitFile(path string, f *os.FileInfo) { if isGoFile(f) { v <- nil // synchronize error handler - if err := processFileByName(path); err != nil { + if err := processFile(path, nil, os.Stdout); err != nil { v <- err } } @@ -169,30 +162,47 @@ func (v fileVisitor) VisitFile(path string, f *os.FileInfo) { func walkDir(path string) { - // start an error handler - done := make(chan bool) v := make(fileVisitor) go func() { - for err := range v { - if err != nil { - report(err) - } - } - done <- true + filepath.Walk(path, v, v) + close(v) }() - // walk the tree - filepath.Walk(path, v, v) - close(v) // terminate error handler loop - <-done // wait for all errors to be reported + for err := range v { + if err != nil { + report(err) + } + } } func main() { + // call gofmtMain in a separate function + // so that it can use defer and have them + // run before the exit. + gofmtMain() + os.Exit(exitCode) +} + + +func gofmtMain() { flag.Usage = usage flag.Parse() if *tabWidth < 0 { fmt.Fprintf(os.Stderr, "negative tabwidth %d\n", *tabWidth) - os.Exit(2) + exitCode = 2 + return + } + + if *cpuprofile != "" { + f, err := os.Create(*cpuprofile) + if err != nil { + fmt.Fprintf(os.Stderr, "creating cpu profile: %s\n", err) + exitCode = 2 + return + } + defer f.Close() + pprof.StartCPUProfile(f) + defer pprof.StopCPUProfile() } initParserMode() @@ -200,9 +210,10 @@ func main() { initRewrite() if flag.NArg() == 0 { - if err := processFile(os.Stdin); err != nil { + if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil { report(err) } + return } for i := 0; i < flag.NArg(); i++ { @@ -211,13 +222,11 @@ func main() { case err != nil: report(err) case dir.IsRegular(): - if err := processFileByName(path); err != nil { + if err := processFile(path, nil, os.Stdout); err != nil { report(err) } case dir.IsDirectory(): walkDir(path) } } - - os.Exit(exitCode) } diff --git a/src/cmd/gofmt/gofmt_test.go b/src/cmd/gofmt/gofmt_test.go new file mode 100644 index 000000000..4ec94e293 --- /dev/null +++ b/src/cmd/gofmt/gofmt_test.go @@ -0,0 +1,81 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "io/ioutil" + "path/filepath" + "strings" + "testing" +) + + +func runTest(t *testing.T, dirname, in, out, flags string) { + in = filepath.Join(dirname, in) + out = filepath.Join(dirname, out) + + // process flags + *simplifyAST = false + *rewriteRule = "" + for _, flag := range strings.Split(flags, " ", -1) { + elts := strings.Split(flag, "=", 2) + name := elts[0] + value := "" + if len(elts) == 2 { + value = elts[1] + } + switch name { + case "": + // no flags + case "-r": + *rewriteRule = value + case "-s": + *simplifyAST = true + default: + t.Errorf("unrecognized flag name: %s", name) + } + } + + initParserMode() + initPrinterMode() + initRewrite() + + var buf bytes.Buffer + err := processFile(in, nil, &buf) + if err != nil { + t.Error(err) + return + } + + expected, err := ioutil.ReadFile(out) + if err != nil { + t.Error(err) + return + } + + if got := buf.Bytes(); bytes.Compare(got, expected) != 0 { + t.Errorf("(gofmt %s) != %s (see %s.gofmt)", in, out, in) + ioutil.WriteFile(in+".gofmt", got, 0666) + } +} + + +// TODO(gri) Add more test cases! +var tests = []struct { + dirname, in, out, flags string +}{ + {".", "gofmt.go", "gofmt.go", ""}, + {".", "gofmt_test.go", "gofmt_test.go", ""}, + {"testdata", "composites.input", "composites.golden", "-s"}, + {"testdata", "rewrite1.input", "rewrite1.golden", "-r=Foo->Bar"}, +} + + +func TestRewrite(t *testing.T) { + for _, test := range tests { + runTest(t, test.dirname, test.in, test.out, test.flags) + } +} diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go index fbcd46aa2..93643dced 100644 --- a/src/cmd/gofmt/rewrite.go +++ b/src/cmd/gofmt/rewrite.go @@ -46,6 +46,16 @@ func parseExpr(s string, what string) ast.Expr { } +// Keep this function for debugging. +/* +func dump(msg string, val reflect.Value) { + fmt.Printf("%s:\n", msg) + ast.Print(fset, val.Interface()) + fmt.Println() +} +*/ + + // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file. func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { m := make(map[string]reflect.Value) @@ -54,7 +64,7 @@ func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { var f func(val reflect.Value) reflect.Value // f is recursive f = func(val reflect.Value) reflect.Value { for k := range m { - m[k] = nil, false + m[k] = reflect.Value{}, false } val = apply(f, val) if match(m, pat, val) { @@ -78,28 +88,45 @@ func setValue(x, y reflect.Value) { panic(x) } }() - x.SetValue(y) + x.Set(y) } +// Values/types for special cases. +var ( + objectPtrNil = reflect.NewValue((*ast.Object)(nil)) + + identType = reflect.Typeof((*ast.Ident)(nil)) + objectPtrType = reflect.Typeof((*ast.Object)(nil)) + positionType = reflect.Typeof(token.NoPos) +) + + // apply replaces each AST field x in val with f(x), returning val. // To avoid extra conversions, f operates on the reflect.Value form. func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value { - if val == nil { - return nil + if !val.IsValid() { + return reflect.Value{} + } + + // *ast.Objects introduce cycles and are likely incorrect after + // rewrite; don't follow them but replace with nil instead + if val.Type() == objectPtrType { + return objectPtrNil } - switch v := reflect.Indirect(val).(type) { - case *reflect.SliceValue: + + switch v := reflect.Indirect(val); v.Kind() { + case reflect.Slice: for i := 0; i < v.Len(); i++ { - e := v.Elem(i) + e := v.Index(i) setValue(e, f(e)) } - case *reflect.StructValue: + case reflect.Struct: for i := 0; i < v.NumField(); i++ { e := v.Field(i) setValue(e, f(e)) } - case *reflect.InterfaceValue: + case reflect.Interface: e := v.Elem() setValue(v, f(e)) } @@ -107,10 +134,6 @@ func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value } -var positionType = reflect.Typeof(token.NoPos) -var identType = reflect.Typeof((*ast.Ident)(nil)) - - func isWildcard(s string) bool { rune, size := utf8.DecodeRuneInString(s) return size == len(s) && unicode.IsLower(rune) @@ -124,9 +147,9 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { // Wildcard matches any expression. If it appears multiple // times in the pattern, it must match the same expression // each time. - if m != nil && pattern != nil && pattern.Type() == identType { + if m != nil && pattern.IsValid() && pattern.Type() == identType { name := pattern.Interface().(*ast.Ident).Name - if isWildcard(name) && val != nil { + if isWildcard(name) && val.IsValid() { // wildcards only match expressions if _, ok := val.Interface().(ast.Expr); ok { if old, ok := m[name]; ok { @@ -139,8 +162,8 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { } // Otherwise, pattern and val must match recursively. - if pattern == nil || val == nil { - return pattern == nil && val == nil + if !pattern.IsValid() || !val.IsValid() { + return !pattern.IsValid() && !val.IsValid() } if pattern.Type() != val.Type() { return false @@ -148,9 +171,6 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { // Special cases. switch pattern.Type() { - case positionType: - // token positions don't need to match - return true case identType: // For identifiers, only the names need to match // (and none of the other *ast.Object information). @@ -159,29 +179,30 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { p := pattern.Interface().(*ast.Ident) v := val.Interface().(*ast.Ident) return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name + case objectPtrType, positionType: + // object pointers and token positions don't need to match + return true } p := reflect.Indirect(pattern) v := reflect.Indirect(val) - if p == nil || v == nil { - return p == nil && v == nil + if !p.IsValid() || !v.IsValid() { + return !p.IsValid() && !v.IsValid() } - switch p := p.(type) { - case *reflect.SliceValue: - v := v.(*reflect.SliceValue) + switch p.Kind() { + case reflect.Slice: if p.Len() != v.Len() { return false } for i := 0; i < p.Len(); i++ { - if !match(m, p.Elem(i), v.Elem(i)) { + if !match(m, p.Index(i), v.Index(i)) { return false } } return true - case *reflect.StructValue: - v := v.(*reflect.StructValue) + case reflect.Struct: if p.NumField() != v.NumField() { return false } @@ -192,8 +213,7 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { } return true - case *reflect.InterfaceValue: - v := v.(*reflect.InterfaceValue) + case reflect.Interface: return match(m, p.Elem(), v.Elem()) } @@ -207,8 +227,8 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { // if m == nil, subst returns a copy of pattern and doesn't change the line // number information. func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value { - if pattern == nil { - return nil + if !pattern.IsValid() { + return reflect.Value{} } // Wildcard gets replaced with map value. @@ -216,12 +236,12 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) name := pattern.Interface().(*ast.Ident).Name if isWildcard(name) { if old, ok := m[name]; ok { - return subst(nil, old, nil) + return subst(nil, old, reflect.Value{}) } } } - if pos != nil && pattern.Type() == positionType { + if pos.IsValid() && pattern.Type() == positionType { // use new position only if old position was valid in the first place if old := pattern.Interface().(token.Pos); !old.IsValid() { return pattern @@ -230,29 +250,33 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) } // Otherwise copy. - switch p := pattern.(type) { - case *reflect.SliceValue: - v := reflect.MakeSlice(p.Type().(*reflect.SliceType), p.Len(), p.Len()) + switch p := pattern; p.Kind() { + case reflect.Slice: + v := reflect.MakeSlice(p.Type(), p.Len(), p.Len()) for i := 0; i < p.Len(); i++ { - v.Elem(i).SetValue(subst(m, p.Elem(i), pos)) + v.Index(i).Set(subst(m, p.Index(i), pos)) } return v - case *reflect.StructValue: - v := reflect.MakeZero(p.Type()).(*reflect.StructValue) + case reflect.Struct: + v := reflect.Zero(p.Type()) for i := 0; i < p.NumField(); i++ { - v.Field(i).SetValue(subst(m, p.Field(i), pos)) + v.Field(i).Set(subst(m, p.Field(i), pos)) } return v - case *reflect.PtrValue: - v := reflect.MakeZero(p.Type()).(*reflect.PtrValue) - v.PointTo(subst(m, p.Elem(), pos)) + case reflect.Ptr: + v := reflect.Zero(p.Type()) + if elem := p.Elem(); elem.IsValid() { + v.Set(subst(m, elem, pos).Addr()) + } return v - case *reflect.InterfaceValue: - v := reflect.MakeZero(p.Type()).(*reflect.InterfaceValue) - v.SetValue(subst(m, p.Elem(), pos)) + case reflect.Interface: + v := reflect.Zero(p.Type()) + if elem := p.Elem(); elem.IsValid() { + v.Set(subst(m, elem, pos)) + } return v } diff --git a/src/cmd/gofmt/testdata/rewrite1.golden b/src/cmd/gofmt/testdata/rewrite1.golden new file mode 100644 index 000000000..3f909ff4a --- /dev/null +++ b/src/cmd/gofmt/testdata/rewrite1.golden @@ -0,0 +1,8 @@ +package main + +type Bar int + +func main() { + var a Bar + println(a) +} diff --git a/src/cmd/gofmt/testdata/rewrite1.input b/src/cmd/gofmt/testdata/rewrite1.input new file mode 100644 index 000000000..1f10e3601 --- /dev/null +++ b/src/cmd/gofmt/testdata/rewrite1.input @@ -0,0 +1,8 @@ +package main + +type Foo int + +func main() { + var a Foo + println(a) +} diff --git a/src/cmd/gofmt/testdata/test.sh b/src/cmd/gofmt/testdata/test.sh deleted file mode 100755 index a1d5d823e..000000000 --- a/src/cmd/gofmt/testdata/test.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2010 The Go Authors. All rights reserved. -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file. - -CMD="../gofmt" -TMP=test_tmp.go -COUNT=0 - - -cleanup() { - rm -f $TMP -} - - -error() { - echo $1 - exit 1 -} - - -count() { - #echo $1 - let COUNT=$COUNT+1 - let M=$COUNT%10 - if [ $M == 0 ]; then - echo -n "." - fi -} - - -test() { - count $1 - - # compare against .golden file - cleanup - $CMD -s $1 > $TMP - cmp -s $TMP $2 - if [ $? != 0 ]; then - diff $TMP $2 - error "Error: simplified $1 does not match $2" - fi - - # make sure .golden is idempotent - cleanup - $CMD -s $2 > $TMP - cmp -s $TMP $2 - if [ $? != 0 ]; then - diff $TMP $2 - error "Error: $2 is not idempotent" - fi -} - - -runtests() { - smoketest=../../../pkg/go/parser/parser.go - test $smoketest $smoketest - test composites.input composites.golden - # add more test cases here -} - - -runtests -cleanup -echo "PASSED ($COUNT tests)" |