summaryrefslogtreecommitdiff
path: root/src/cmd/gofmt
diff options
context:
space:
mode:
Diffstat (limited to 'src/cmd/gofmt')
-rw-r--r--src/cmd/gofmt/Makefile4
-rw-r--r--src/cmd/gofmt/doc.go16
-rw-r--r--src/cmd/gofmt/gofmt.go103
-rw-r--r--src/cmd/gofmt/gofmt_test.go81
-rw-r--r--src/cmd/gofmt/rewrite.go118
-rw-r--r--src/cmd/gofmt/testdata/rewrite1.golden8
-rw-r--r--src/cmd/gofmt/testdata/rewrite1.input8
-rwxr-xr-xsrc/cmd/gofmt/testdata/test.sh65
8 files changed, 230 insertions, 173 deletions
diff --git a/src/cmd/gofmt/Makefile b/src/cmd/gofmt/Makefile
index 5f2f454e8..dc5b060e6 100644
--- a/src/cmd/gofmt/Makefile
+++ b/src/cmd/gofmt/Makefile
@@ -15,5 +15,5 @@ include ../../Make.cmd
test: $(TARG)
./test.sh
-smoketest: $(TARG)
- (cd testdata; ./test.sh)
+testshort:
+ gotest -test.short
diff --git a/src/cmd/gofmt/doc.go b/src/cmd/gofmt/doc.go
index 2d2c9ae61..e44030eee 100644
--- a/src/cmd/gofmt/doc.go
+++ b/src/cmd/gofmt/doc.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
/*
-
Gofmt formats Go programs.
Without an explicit path, it processes the standard input. Given a file,
@@ -16,14 +15,16 @@ Usage:
The flags are:
-l
- just list files whose formatting differs from gofmt's; generate no other output
- unless -w is also set.
+ just list files whose formatting differs from gofmt's;
+ generate no other output unless -w is also set.
-r rule
apply the rewrite rule to the source before reformatting.
-s
try to simplify code (after applying the rewrite rule, if any).
-w
if set, overwrite each input file with its output.
+ -comments=true
+ print comments; if false, all comments are elided from the output.
-spaces
align with spaces instead of tabs.
-tabindent
@@ -31,15 +32,6 @@ The flags are:
-tabwidth=8
tab width in spaces.
-Debugging flags:
-
- -trace
- print parse trace.
- -ast
- print AST (before rewrites).
- -comments=true
- print comments; if false, all comments are elided from the output.
-
The rewrite rule specified with the -r flag must be a string of the form:
pattern -> replacement
diff --git a/src/cmd/gofmt/gofmt.go b/src/cmd/gofmt/gofmt.go
index 224aee717..ce274aa21 100644
--- a/src/cmd/gofmt/gofmt.go
+++ b/src/cmd/gofmt/gofmt.go
@@ -13,9 +13,11 @@ import (
"go/printer"
"go/scanner"
"go/token"
+ "io"
"io/ioutil"
"os"
"path/filepath"
+ "runtime/pprof"
"strings"
)
@@ -27,15 +29,14 @@ var (
rewriteRule = flag.String("r", "", "rewrite rule (e.g., 'α[β:len(α)] -> α[β:]')")
simplifyAST = flag.Bool("s", false, "simplify code")
- // debugging support
- comments = flag.Bool("comments", true, "print comments")
- trace = flag.Bool("trace", false, "print parse trace")
- printAST = flag.Bool("ast", false, "print AST (before rewrites)")
-
// layout control
+ comments = flag.Bool("comments", true, "print comments")
tabWidth = flag.Int("tabwidth", 8, "tab width")
tabIndent = flag.Bool("tabindent", true, "indent with tabs independent of -spaces")
useSpaces = flag.Bool("spaces", true, "align with spaces instead of tabs")
+
+ // debugging
+ cpuprofile = flag.String("cpuprofile", "", "write cpu profile to this file")
)
@@ -66,9 +67,6 @@ func initParserMode() {
if *comments {
parserMode |= parser.ParseComments
}
- if *trace {
- parserMode |= parser.Trace
- }
}
@@ -89,20 +87,25 @@ func isGoFile(f *os.FileInfo) bool {
}
-func processFile(f *os.File) os.Error {
- src, err := ioutil.ReadAll(f)
- if err != nil {
- return err
+// If in == nil, the source is the contents of the file with the given filename.
+func processFile(filename string, in io.Reader, out io.Writer) os.Error {
+ if in == nil {
+ f, err := os.Open(filename)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ in = f
}
- file, err := parser.ParseFile(fset, f.Name(), src, parserMode)
-
+ src, err := ioutil.ReadAll(in)
if err != nil {
return err
}
- if *printAST {
- ast.Print(file)
+ file, err := parser.ParseFile(fset, filename, src, parserMode)
+ if err != nil {
+ return err
}
if rewrite != nil {
@@ -123,10 +126,10 @@ func processFile(f *os.File) os.Error {
if !bytes.Equal(src, res) {
// formatting has changed
if *list {
- fmt.Fprintln(os.Stdout, f.Name())
+ fmt.Fprintln(out, filename)
}
if *write {
- err = ioutil.WriteFile(f.Name(), res, 0)
+ err = ioutil.WriteFile(filename, res, 0)
if err != nil {
return err
}
@@ -134,23 +137,13 @@ func processFile(f *os.File) os.Error {
}
if !*list && !*write {
- _, err = os.Stdout.Write(res)
+ _, err = out.Write(res)
}
return err
}
-func processFileByName(filename string) os.Error {
- file, err := os.Open(filename, os.O_RDONLY, 0)
- if err != nil {
- return err
- }
- defer file.Close()
- return processFile(file)
-}
-
-
type fileVisitor chan os.Error
func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool {
@@ -161,7 +154,7 @@ func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool {
func (v fileVisitor) VisitFile(path string, f *os.FileInfo) {
if isGoFile(f) {
v <- nil // synchronize error handler
- if err := processFileByName(path); err != nil {
+ if err := processFile(path, nil, os.Stdout); err != nil {
v <- err
}
}
@@ -169,30 +162,47 @@ func (v fileVisitor) VisitFile(path string, f *os.FileInfo) {
func walkDir(path string) {
- // start an error handler
- done := make(chan bool)
v := make(fileVisitor)
go func() {
- for err := range v {
- if err != nil {
- report(err)
- }
- }
- done <- true
+ filepath.Walk(path, v, v)
+ close(v)
}()
- // walk the tree
- filepath.Walk(path, v, v)
- close(v) // terminate error handler loop
- <-done // wait for all errors to be reported
+ for err := range v {
+ if err != nil {
+ report(err)
+ }
+ }
}
func main() {
+ // call gofmtMain in a separate function
+ // so that it can use defer and have them
+ // run before the exit.
+ gofmtMain()
+ os.Exit(exitCode)
+}
+
+
+func gofmtMain() {
flag.Usage = usage
flag.Parse()
if *tabWidth < 0 {
fmt.Fprintf(os.Stderr, "negative tabwidth %d\n", *tabWidth)
- os.Exit(2)
+ exitCode = 2
+ return
+ }
+
+ if *cpuprofile != "" {
+ f, err := os.Create(*cpuprofile)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "creating cpu profile: %s\n", err)
+ exitCode = 2
+ return
+ }
+ defer f.Close()
+ pprof.StartCPUProfile(f)
+ defer pprof.StopCPUProfile()
}
initParserMode()
@@ -200,9 +210,10 @@ func main() {
initRewrite()
if flag.NArg() == 0 {
- if err := processFile(os.Stdin); err != nil {
+ if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil {
report(err)
}
+ return
}
for i := 0; i < flag.NArg(); i++ {
@@ -211,13 +222,11 @@ func main() {
case err != nil:
report(err)
case dir.IsRegular():
- if err := processFileByName(path); err != nil {
+ if err := processFile(path, nil, os.Stdout); err != nil {
report(err)
}
case dir.IsDirectory():
walkDir(path)
}
}
-
- os.Exit(exitCode)
}
diff --git a/src/cmd/gofmt/gofmt_test.go b/src/cmd/gofmt/gofmt_test.go
new file mode 100644
index 000000000..4ec94e293
--- /dev/null
+++ b/src/cmd/gofmt/gofmt_test.go
@@ -0,0 +1,81 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "bytes"
+ "io/ioutil"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+
+func runTest(t *testing.T, dirname, in, out, flags string) {
+ in = filepath.Join(dirname, in)
+ out = filepath.Join(dirname, out)
+
+ // process flags
+ *simplifyAST = false
+ *rewriteRule = ""
+ for _, flag := range strings.Split(flags, " ", -1) {
+ elts := strings.Split(flag, "=", 2)
+ name := elts[0]
+ value := ""
+ if len(elts) == 2 {
+ value = elts[1]
+ }
+ switch name {
+ case "":
+ // no flags
+ case "-r":
+ *rewriteRule = value
+ case "-s":
+ *simplifyAST = true
+ default:
+ t.Errorf("unrecognized flag name: %s", name)
+ }
+ }
+
+ initParserMode()
+ initPrinterMode()
+ initRewrite()
+
+ var buf bytes.Buffer
+ err := processFile(in, nil, &buf)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ expected, err := ioutil.ReadFile(out)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ if got := buf.Bytes(); bytes.Compare(got, expected) != 0 {
+ t.Errorf("(gofmt %s) != %s (see %s.gofmt)", in, out, in)
+ ioutil.WriteFile(in+".gofmt", got, 0666)
+ }
+}
+
+
+// TODO(gri) Add more test cases!
+var tests = []struct {
+ dirname, in, out, flags string
+}{
+ {".", "gofmt.go", "gofmt.go", ""},
+ {".", "gofmt_test.go", "gofmt_test.go", ""},
+ {"testdata", "composites.input", "composites.golden", "-s"},
+ {"testdata", "rewrite1.input", "rewrite1.golden", "-r=Foo->Bar"},
+}
+
+
+func TestRewrite(t *testing.T) {
+ for _, test := range tests {
+ runTest(t, test.dirname, test.in, test.out, test.flags)
+ }
+}
diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go
index fbcd46aa2..93643dced 100644
--- a/src/cmd/gofmt/rewrite.go
+++ b/src/cmd/gofmt/rewrite.go
@@ -46,6 +46,16 @@ func parseExpr(s string, what string) ast.Expr {
}
+// Keep this function for debugging.
+/*
+func dump(msg string, val reflect.Value) {
+ fmt.Printf("%s:\n", msg)
+ ast.Print(fset, val.Interface())
+ fmt.Println()
+}
+*/
+
+
// rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
m := make(map[string]reflect.Value)
@@ -54,7 +64,7 @@ func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
var f func(val reflect.Value) reflect.Value // f is recursive
f = func(val reflect.Value) reflect.Value {
for k := range m {
- m[k] = nil, false
+ m[k] = reflect.Value{}, false
}
val = apply(f, val)
if match(m, pat, val) {
@@ -78,28 +88,45 @@ func setValue(x, y reflect.Value) {
panic(x)
}
}()
- x.SetValue(y)
+ x.Set(y)
}
+// Values/types for special cases.
+var (
+ objectPtrNil = reflect.NewValue((*ast.Object)(nil))
+
+ identType = reflect.Typeof((*ast.Ident)(nil))
+ objectPtrType = reflect.Typeof((*ast.Object)(nil))
+ positionType = reflect.Typeof(token.NoPos)
+)
+
+
// apply replaces each AST field x in val with f(x), returning val.
// To avoid extra conversions, f operates on the reflect.Value form.
func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
- if val == nil {
- return nil
+ if !val.IsValid() {
+ return reflect.Value{}
+ }
+
+ // *ast.Objects introduce cycles and are likely incorrect after
+ // rewrite; don't follow them but replace with nil instead
+ if val.Type() == objectPtrType {
+ return objectPtrNil
}
- switch v := reflect.Indirect(val).(type) {
- case *reflect.SliceValue:
+
+ switch v := reflect.Indirect(val); v.Kind() {
+ case reflect.Slice:
for i := 0; i < v.Len(); i++ {
- e := v.Elem(i)
+ e := v.Index(i)
setValue(e, f(e))
}
- case *reflect.StructValue:
+ case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
e := v.Field(i)
setValue(e, f(e))
}
- case *reflect.InterfaceValue:
+ case reflect.Interface:
e := v.Elem()
setValue(v, f(e))
}
@@ -107,10 +134,6 @@ func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value
}
-var positionType = reflect.Typeof(token.NoPos)
-var identType = reflect.Typeof((*ast.Ident)(nil))
-
-
func isWildcard(s string) bool {
rune, size := utf8.DecodeRuneInString(s)
return size == len(s) && unicode.IsLower(rune)
@@ -124,9 +147,9 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
// Wildcard matches any expression. If it appears multiple
// times in the pattern, it must match the same expression
// each time.
- if m != nil && pattern != nil && pattern.Type() == identType {
+ if m != nil && pattern.IsValid() && pattern.Type() == identType {
name := pattern.Interface().(*ast.Ident).Name
- if isWildcard(name) && val != nil {
+ if isWildcard(name) && val.IsValid() {
// wildcards only match expressions
if _, ok := val.Interface().(ast.Expr); ok {
if old, ok := m[name]; ok {
@@ -139,8 +162,8 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
}
// Otherwise, pattern and val must match recursively.
- if pattern == nil || val == nil {
- return pattern == nil && val == nil
+ if !pattern.IsValid() || !val.IsValid() {
+ return !pattern.IsValid() && !val.IsValid()
}
if pattern.Type() != val.Type() {
return false
@@ -148,9 +171,6 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
// Special cases.
switch pattern.Type() {
- case positionType:
- // token positions don't need to match
- return true
case identType:
// For identifiers, only the names need to match
// (and none of the other *ast.Object information).
@@ -159,29 +179,30 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
p := pattern.Interface().(*ast.Ident)
v := val.Interface().(*ast.Ident)
return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
+ case objectPtrType, positionType:
+ // object pointers and token positions don't need to match
+ return true
}
p := reflect.Indirect(pattern)
v := reflect.Indirect(val)
- if p == nil || v == nil {
- return p == nil && v == nil
+ if !p.IsValid() || !v.IsValid() {
+ return !p.IsValid() && !v.IsValid()
}
- switch p := p.(type) {
- case *reflect.SliceValue:
- v := v.(*reflect.SliceValue)
+ switch p.Kind() {
+ case reflect.Slice:
if p.Len() != v.Len() {
return false
}
for i := 0; i < p.Len(); i++ {
- if !match(m, p.Elem(i), v.Elem(i)) {
+ if !match(m, p.Index(i), v.Index(i)) {
return false
}
}
return true
- case *reflect.StructValue:
- v := v.(*reflect.StructValue)
+ case reflect.Struct:
if p.NumField() != v.NumField() {
return false
}
@@ -192,8 +213,7 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
}
return true
- case *reflect.InterfaceValue:
- v := v.(*reflect.InterfaceValue)
+ case reflect.Interface:
return match(m, p.Elem(), v.Elem())
}
@@ -207,8 +227,8 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
// if m == nil, subst returns a copy of pattern and doesn't change the line
// number information.
func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
- if pattern == nil {
- return nil
+ if !pattern.IsValid() {
+ return reflect.Value{}
}
// Wildcard gets replaced with map value.
@@ -216,12 +236,12 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value)
name := pattern.Interface().(*ast.Ident).Name
if isWildcard(name) {
if old, ok := m[name]; ok {
- return subst(nil, old, nil)
+ return subst(nil, old, reflect.Value{})
}
}
}
- if pos != nil && pattern.Type() == positionType {
+ if pos.IsValid() && pattern.Type() == positionType {
// use new position only if old position was valid in the first place
if old := pattern.Interface().(token.Pos); !old.IsValid() {
return pattern
@@ -230,29 +250,33 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value)
}
// Otherwise copy.
- switch p := pattern.(type) {
- case *reflect.SliceValue:
- v := reflect.MakeSlice(p.Type().(*reflect.SliceType), p.Len(), p.Len())
+ switch p := pattern; p.Kind() {
+ case reflect.Slice:
+ v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
for i := 0; i < p.Len(); i++ {
- v.Elem(i).SetValue(subst(m, p.Elem(i), pos))
+ v.Index(i).Set(subst(m, p.Index(i), pos))
}
return v
- case *reflect.StructValue:
- v := reflect.MakeZero(p.Type()).(*reflect.StructValue)
+ case reflect.Struct:
+ v := reflect.Zero(p.Type())
for i := 0; i < p.NumField(); i++ {
- v.Field(i).SetValue(subst(m, p.Field(i), pos))
+ v.Field(i).Set(subst(m, p.Field(i), pos))
}
return v
- case *reflect.PtrValue:
- v := reflect.MakeZero(p.Type()).(*reflect.PtrValue)
- v.PointTo(subst(m, p.Elem(), pos))
+ case reflect.Ptr:
+ v := reflect.Zero(p.Type())
+ if elem := p.Elem(); elem.IsValid() {
+ v.Set(subst(m, elem, pos).Addr())
+ }
return v
- case *reflect.InterfaceValue:
- v := reflect.MakeZero(p.Type()).(*reflect.InterfaceValue)
- v.SetValue(subst(m, p.Elem(), pos))
+ case reflect.Interface:
+ v := reflect.Zero(p.Type())
+ if elem := p.Elem(); elem.IsValid() {
+ v.Set(subst(m, elem, pos))
+ }
return v
}
diff --git a/src/cmd/gofmt/testdata/rewrite1.golden b/src/cmd/gofmt/testdata/rewrite1.golden
new file mode 100644
index 000000000..3f909ff4a
--- /dev/null
+++ b/src/cmd/gofmt/testdata/rewrite1.golden
@@ -0,0 +1,8 @@
+package main
+
+type Bar int
+
+func main() {
+ var a Bar
+ println(a)
+}
diff --git a/src/cmd/gofmt/testdata/rewrite1.input b/src/cmd/gofmt/testdata/rewrite1.input
new file mode 100644
index 000000000..1f10e3601
--- /dev/null
+++ b/src/cmd/gofmt/testdata/rewrite1.input
@@ -0,0 +1,8 @@
+package main
+
+type Foo int
+
+func main() {
+ var a Foo
+ println(a)
+}
diff --git a/src/cmd/gofmt/testdata/test.sh b/src/cmd/gofmt/testdata/test.sh
deleted file mode 100755
index a1d5d823e..000000000
--- a/src/cmd/gofmt/testdata/test.sh
+++ /dev/null
@@ -1,65 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2010 The Go Authors. All rights reserved.
-# Use of this source code is governed by a BSD-style
-# license that can be found in the LICENSE file.
-
-CMD="../gofmt"
-TMP=test_tmp.go
-COUNT=0
-
-
-cleanup() {
- rm -f $TMP
-}
-
-
-error() {
- echo $1
- exit 1
-}
-
-
-count() {
- #echo $1
- let COUNT=$COUNT+1
- let M=$COUNT%10
- if [ $M == 0 ]; then
- echo -n "."
- fi
-}
-
-
-test() {
- count $1
-
- # compare against .golden file
- cleanup
- $CMD -s $1 > $TMP
- cmp -s $TMP $2
- if [ $? != 0 ]; then
- diff $TMP $2
- error "Error: simplified $1 does not match $2"
- fi
-
- # make sure .golden is idempotent
- cleanup
- $CMD -s $2 > $TMP
- cmp -s $TMP $2
- if [ $? != 0 ]; then
- diff $TMP $2
- error "Error: $2 is not idempotent"
- fi
-}
-
-
-runtests() {
- smoketest=../../../pkg/go/parser/parser.go
- test $smoketest $smoketest
- test composites.input composites.golden
- # add more test cases here
-}
-
-
-runtests
-cleanup
-echo "PASSED ($COUNT tests)"