author    Ondřej Surý <ondrej@sury.org>  2012-04-06 15:14:11 +0200
committer Ondřej Surý <ondrej@sury.org>  2012-04-06 15:14:11 +0200
commit    505c19580e0f43fe5224431459cacb7c21edd93d (patch)
tree      79e2634c253d60afc0cc0b2f510dc7dcbb48497b /src/cmd/fix
parent    1336a7c91e596c423a49d1194ea42d98bca0d958 (diff)
download  golang-505c19580e0f43fe5224431459cacb7c21edd93d.tar.gz
Imported Upstream version 1 (tag: upstream/1)
Diffstat (limited to 'src/cmd/fix')
-rw-r--r--  src/cmd/fix/doc.go  36
-rw-r--r--  src/cmd/fix/error.go  353
-rw-r--r--  src/cmd/fix/error_test.go  240
-rw-r--r--  src/cmd/fix/filepath.go  56
-rw-r--r--  src/cmd/fix/filepath_test.go  33
-rw-r--r--  src/cmd/fix/fix.go  848
-rw-r--r--  src/cmd/fix/go1pkgrename.go  146
-rw-r--r--  src/cmd/fix/go1pkgrename_test.go  139
-rw-r--r--  src/cmd/fix/go1rename.go  167
-rw-r--r--  src/cmd/fix/go1rename_test.go  195
-rw-r--r--  src/cmd/fix/googlecode.go  41
-rw-r--r--  src/cmd/fix/googlecode_test.go  31
-rw-r--r--  src/cmd/fix/hashsum.go  94
-rw-r--r--  src/cmd/fix/hashsum_test.go  99
-rw-r--r--  src/cmd/fix/hmacnew.go  61
-rw-r--r--  src/cmd/fix/hmacnew_test.go  107
-rw-r--r--  src/cmd/fix/htmlerr.go  47
-rw-r--r--  src/cmd/fix/htmlerr_test.go  39
-rw-r--r--  src/cmd/fix/httpfinalurl.go  57
-rw-r--r--  src/cmd/fix/httpfinalurl_test.go  37
-rw-r--r--  src/cmd/fix/httpfs.go  70
-rw-r--r--  src/cmd/fix/httpfs_test.go  47
-rw-r--r--  src/cmd/fix/httpheaders.go  67
-rw-r--r--  src/cmd/fix/httpheaders_test.go  73
-rw-r--r--  src/cmd/fix/httpserver.go  141
-rw-r--r--  src/cmd/fix/httpserver_test.go  53
-rw-r--r--  src/cmd/fix/imagecolor.go  85
-rw-r--r--  src/cmd/fix/imagecolor_test.go  126
-rw-r--r--  src/cmd/fix/imagenew.go  83
-rw-r--r--  src/cmd/fix/imagenew_test.go  51
-rw-r--r--  src/cmd/fix/imageycbcr.go  64
-rw-r--r--  src/cmd/fix/imageycbcr_test.go  54
-rw-r--r--  src/cmd/fix/import_test.go  458
-rw-r--r--  src/cmd/fix/iocopyn.go  41
-rw-r--r--  src/cmd/fix/iocopyn_test.go  37
-rw-r--r--  src/cmd/fix/main.go  271
-rw-r--r--  src/cmd/fix/main_test.go  129
-rw-r--r--  src/cmd/fix/mapdelete.go  89
-rw-r--r--  src/cmd/fix/mapdelete_test.go  43
-rw-r--r--  src/cmd/fix/math.go  51
-rw-r--r--  src/cmd/fix/math_test.go  47
-rw-r--r--  src/cmd/fix/netdial.go  117
-rw-r--r--  src/cmd/fix/netdial_test.go  57
-rw-r--r--  src/cmd/fix/netudpgroup.go  58
-rw-r--r--  src/cmd/fix/netudpgroup_test.go  53
-rw-r--r--  src/cmd/fix/newwriter.go  90
-rw-r--r--  src/cmd/fix/newwriter_test.go  83
-rw-r--r--  src/cmd/fix/oserrorstring.go  75
-rw-r--r--  src/cmd/fix/oserrorstring_test.go  57
-rw-r--r--  src/cmd/fix/osopen.go  124
-rw-r--r--  src/cmd/fix/osopen_test.go  82
-rw-r--r--  src/cmd/fix/procattr.go  62
-rw-r--r--  src/cmd/fix/procattr_test.go  74
-rw-r--r--  src/cmd/fix/reflect.go  862
-rw-r--r--  src/cmd/fix/reflect_test.go  35
-rw-r--r--  src/cmd/fix/signal.go  50
-rw-r--r--  src/cmd/fix/signal_test.go  94
-rw-r--r--  src/cmd/fix/sorthelpers.go  49
-rw-r--r--  src/cmd/fix/sorthelpers_test.go  45
-rw-r--r--  src/cmd/fix/sortslice.go  52
-rw-r--r--  src/cmd/fix/sortslice_test.go  35
-rw-r--r--  src/cmd/fix/strconv.go  127
-rw-r--r--  src/cmd/fix/strconv_test.go  93
-rw-r--r--  src/cmd/fix/stringssplit.go  72
-rw-r--r--  src/cmd/fix/stringssplit_test.go  51
-rw-r--r--  src/cmd/fix/template.go  111
-rw-r--r--  src/cmd/fix/template_test.go  55
-rw-r--r--  src/cmd/fix/testdata/reflect.asn1.go.in  814
-rw-r--r--  src/cmd/fix/testdata/reflect.asn1.go.out  814
-rw-r--r--  src/cmd/fix/testdata/reflect.datafmt.go.in  710
-rw-r--r--  src/cmd/fix/testdata/reflect.datafmt.go.out  710
-rw-r--r--  src/cmd/fix/testdata/reflect.decode.go.in  905
-rw-r--r--  src/cmd/fix/testdata/reflect.decode.go.out  908
-rw-r--r--  src/cmd/fix/testdata/reflect.decoder.go.in  196
-rw-r--r--  src/cmd/fix/testdata/reflect.decoder.go.out  196
-rw-r--r--  src/cmd/fix/testdata/reflect.dnsmsg.go.in  777
-rw-r--r--  src/cmd/fix/testdata/reflect.dnsmsg.go.out  777
-rw-r--r--  src/cmd/fix/testdata/reflect.encode.go.in  367
-rw-r--r--  src/cmd/fix/testdata/reflect.encode.go.out  367
-rw-r--r--  src/cmd/fix/testdata/reflect.encoder.go.in  240
-rw-r--r--  src/cmd/fix/testdata/reflect.encoder.go.out  240
-rw-r--r--  src/cmd/fix/testdata/reflect.export.go.in  400
-rw-r--r--  src/cmd/fix/testdata/reflect.export.go.out  400
-rw-r--r--  src/cmd/fix/testdata/reflect.print.go.in  944
-rw-r--r--  src/cmd/fix/testdata/reflect.print.go.out  944
-rw-r--r--  src/cmd/fix/testdata/reflect.quick.go.in  364
-rw-r--r--  src/cmd/fix/testdata/reflect.quick.go.out  365
-rw-r--r--  src/cmd/fix/testdata/reflect.read.go.in  620
-rw-r--r--  src/cmd/fix/testdata/reflect.read.go.out  620
-rw-r--r--  src/cmd/fix/testdata/reflect.scan.go.in  1082
-rw-r--r--  src/cmd/fix/testdata/reflect.scan.go.out  1082
-rw-r--r--  src/cmd/fix/testdata/reflect.script.go.in  359
-rw-r--r--  src/cmd/fix/testdata/reflect.script.go.out  359
-rw-r--r--  src/cmd/fix/testdata/reflect.template.go.in  1043
-rw-r--r--  src/cmd/fix/testdata/reflect.template.go.out  1044
-rw-r--r--  src/cmd/fix/testdata/reflect.type.go.in  790
-rw-r--r--  src/cmd/fix/testdata/reflect.type.go.out  790
-rw-r--r--  src/cmd/fix/timefileinfo.go  298
-rw-r--r--  src/cmd/fix/timefileinfo_test.go  187
-rw-r--r--  src/cmd/fix/typecheck.go  673
-rw-r--r--  src/cmd/fix/url.go  101
-rw-r--r--  src/cmd/fix/url2.go  46
-rw-r--r--  src/cmd/fix/url2_test.go  31
-rw-r--r--  src/cmd/fix/url_test.go  159
-rw-r--r--  src/cmd/fix/xmlapi.go  111
-rw-r--r--  src/cmd/fix/xmlapi_test.go  85
106 files changed, 28487 insertions, 0 deletions
diff --git a/src/cmd/fix/doc.go b/src/cmd/fix/doc.go
new file mode 100644
index 000000000..a92e0fc06
--- /dev/null
+++ b/src/cmd/fix/doc.go
@@ -0,0 +1,36 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Fix finds Go programs that use old APIs and rewrites them to use
+newer ones. After you update to a new Go release, fix helps make
+the necessary changes to your programs.
+
+Usage:
+ go tool fix [-r name,...] [path ...]
+
+Without an explicit path, fix reads standard input and writes the
+result to standard output.
+
+If the named path is a file, fix rewrites the named files in place.
+If the named path is a directory, fix rewrites all .go files in that
+directory tree. When fix rewrites a file, it prints a line to standard
+error giving the name of the file and the rewrite applied.
+
+If the -diff flag is set, no files are rewritten. Instead fix prints
+the differences a rewrite would introduce.
+
+The -r flag restricts the set of rewrites considered to those in the
+named list. By default fix considers all known rewrites. Fix's
+rewrites are idempotent, so that it is safe to apply fix to updated
+or partially updated code even without using the -r flag.
+
+Fix prints the full list of fixes it can apply in its help output;
+to see them, run go tool fix -?.
+
+Fix does not make backup copies of the files that it edits.
+Instead, use a version control system's ``diff'' functionality to inspect
+the changes that fix makes before committing them.
+*/
+package documentation
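To make the usage above concrete, here is a small hypothetical run of the error fix (the package, file name, and Conn type are invented for illustration, and the log line printed to standard error is shown schematically):

	$ go tool fix -r error ./mypkg
	mypkg/conn.go: fixed error

	// before
	func dial(addr string) (Conn, os.Error) {
		return nil, os.NewError("not implemented")
	}

	// after
	func dial(addr string) (Conn, error) {
		return nil, errors.New("not implemented")
	}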
diff --git a/src/cmd/fix/error.go b/src/cmd/fix/error.go
new file mode 100644
index 000000000..55613210a
--- /dev/null
+++ b/src/cmd/fix/error.go
@@ -0,0 +1,353 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+ "regexp"
+ "strings"
+)
+
+func init() {
+ register(errorFix)
+}
+
+var errorFix = fix{
+ "error",
+ "2011-11-02",
+ errorFn,
+ `Use error instead of os.Error.
+
+This fix rewrites code using os.Error to use error:
+
+ os.Error -> error
+ os.NewError -> errors.New
+ os.EOF -> io.EOF
+
+Seeing the old names above (os.Error and so on) triggers the following
+heuristic rewrites. The heuristics can be forced using the -force=error flag.
+
+A top-level function, variable, or constant named error is renamed error_.
+
+Error implementations—those types used as os.Error or named
+XxxError—have their String methods renamed to Error. Any existing
+Error field or method is renamed to Err.
+
+Error values—those with type os.Error or named e, err, error, err1,
+and so on—have method calls and field references rewritten just
+as the types do (String to Error, Error to Err). Also, a type assertion
+of the form err.(*os.Waitmsg) becomes err.(*exec.ExitError).
+
+http://codereview.appspot.com/5305066
+`,
+}
+
+// At minimum, this fix applies the following rewrites:
+//
+// os.Error -> error
+// os.NewError -> errors.New
+// os.EOF -> io.EOF
+//
+// However, if it can apply any of those rewrites, it assumes that the
+// file predates the error type and tries to update the code to use
+// the new definition for error - an Error method, not a String method.
+// This more heuristic procedure may not be 100% accurate, so it is
+// only run when the file needs updating anyway. The heuristic can
+// be forced to run using -force=error.
+//
+// First, we must identify the implementations of os.Error.
+// These include the type of any value returned as or assigned to an os.Error.
+// To that set we add any type whose name contains "Error" or "error".
+// The heuristic helps for implementations that are not used as os.Error
+// in the file in which they are defined.
+//
+// In any implementation of os.Error, we rename an existing struct field
+// or method named Error to Err and rename the String method to Error.
+//
+// Second, we must identify the values of type os.Error.
+// These include any value that obviously has type os.Error.
+// To that set we add any variable whose name is e or err or error
+// possibly followed by _ or a numeric or capitalized suffix.
+// The heuristic helps for variables that are initialized using calls
+// to functions in other packages. The type checker does not have
+// information about those packages available, and in general cannot
+// (because the packages may themselves not compile).
+//
+// For any value of type os.Error, we replace a call to String with a call to Error.
+// We also replace type assertion err.(*os.Waitmsg) with err.(*exec.ExitError).
+
+// Variables matching this regexp are assumed to have type os.Error.
+var errVar = regexp.MustCompile(`^(e|err|error)_?([A-Z0-9].*)?$`)
+
+// Types matching this regexp are assumed to be implementations of os.Error.
+var errType = regexp.MustCompile(`^\*?([Ee]rror|.*Error)$`)
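As a rough illustration of what these two heuristics accept, here are a few hypothetical identifiers (not taken from this file) run through the patterns above:

	errVar.MatchString("err")          // true
	errVar.MatchString("err1")         // true
	errVar.MatchString("errFoo")       // true
	errVar.MatchString("myerr")        // false
	errType.MatchString("*PathError")  // true
	errType.MatchString("error")       // true
	errType.MatchString("Reader")      // false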
+
+// Type-checking configuration: tell the type-checker this basic
+// information about types, functions, and variables in external packages.
+var errorTypeConfig = &TypeConfig{
+ Type: map[string]*Type{
+ "os.Error": {},
+ },
+ Func: map[string]string{
+ "fmt.Errorf": "os.Error",
+ "os.NewError": "os.Error",
+ },
+ Var: map[string]string{
+ "os.EPERM": "os.Error",
+ "os.ENOENT": "os.Error",
+ "os.ESRCH": "os.Error",
+ "os.EINTR": "os.Error",
+ "os.EIO": "os.Error",
+ "os.ENXIO": "os.Error",
+ "os.E2BIG": "os.Error",
+ "os.ENOEXEC": "os.Error",
+ "os.EBADF": "os.Error",
+ "os.ECHILD": "os.Error",
+ "os.EDEADLK": "os.Error",
+ "os.ENOMEM": "os.Error",
+ "os.EACCES": "os.Error",
+ "os.EFAULT": "os.Error",
+ "os.EBUSY": "os.Error",
+ "os.EEXIST": "os.Error",
+ "os.EXDEV": "os.Error",
+ "os.ENODEV": "os.Error",
+ "os.ENOTDIR": "os.Error",
+ "os.EISDIR": "os.Error",
+ "os.EINVAL": "os.Error",
+ "os.ENFILE": "os.Error",
+ "os.EMFILE": "os.Error",
+ "os.ENOTTY": "os.Error",
+ "os.EFBIG": "os.Error",
+ "os.ENOSPC": "os.Error",
+ "os.ESPIPE": "os.Error",
+ "os.EROFS": "os.Error",
+ "os.EMLINK": "os.Error",
+ "os.EPIPE": "os.Error",
+ "os.EAGAIN": "os.Error",
+ "os.EDOM": "os.Error",
+ "os.ERANGE": "os.Error",
+ "os.EADDRINUSE": "os.Error",
+ "os.ECONNREFUSED": "os.Error",
+ "os.ENAMETOOLONG": "os.Error",
+ "os.EAFNOSUPPORT": "os.Error",
+ "os.ETIMEDOUT": "os.Error",
+ "os.ENOTCONN": "os.Error",
+ },
+}
+
+func errorFn(f *ast.File) bool {
+ if !imports(f, "os") && !force["error"] {
+ return false
+ }
+
+ // Fix gets called once to run the heuristics described above
+ // when we notice that this file definitely needs fixing
+ // (it mentions os.Error or something similar).
+ var fixed bool
+ var didHeuristic bool
+ heuristic := func() {
+ if didHeuristic {
+ return
+ }
+ didHeuristic = true
+
+ // We have identified a necessary fix (like os.Error -> error)
+ // but have not applied it or any others yet. Prepare the file
+ // for fixing and apply heuristic fixes.
+
+ // Rename error to error_ to make room for error.
+ fixed = renameTop(f, "error", "error_") || fixed
+
+ // Use type checker to build list of error implementations.
+ typeof, assign := typecheck(errorTypeConfig, f)
+
+ isError := map[string]bool{}
+ for _, val := range assign["os.Error"] {
+ t := typeof[val]
+ if strings.HasPrefix(t, "*") {
+ t = t[1:]
+ }
+ if t != "" && !strings.HasPrefix(t, "func(") {
+ isError[t] = true
+ }
+ }
+
+ // We use both the type check results and the "Error" name heuristic
+ // to identify implementations of os.Error.
+ isErrorImpl := func(typ string) bool {
+ return isError[typ] || errType.MatchString(typ)
+ }
+
+ isErrorVar := func(x ast.Expr) bool {
+ if typ := typeof[x]; typ != "" {
+ return isErrorImpl(typ) || typ == "os.Error"
+ }
+ if sel, ok := x.(*ast.SelectorExpr); ok {
+ return sel.Sel.Name == "Error" || sel.Sel.Name == "Err"
+ }
+ if id, ok := x.(*ast.Ident); ok {
+ return errVar.MatchString(id.Name)
+ }
+ return false
+ }
+
+ walk(f, func(n interface{}) {
+ // In method declaration on error implementation type,
+ // rename String() to Error() and Error() to Err().
+ fn, ok := n.(*ast.FuncDecl)
+ if ok &&
+ fn.Recv != nil &&
+ len(fn.Recv.List) == 1 &&
+ isErrorImpl(typeName(fn.Recv.List[0].Type)) {
+ // Rename.
+ switch fn.Name.Name {
+ case "String":
+ fn.Name.Name = "Error"
+ fixed = true
+ case "Error":
+ fn.Name.Name = "Err"
+ fixed = true
+ }
+ return
+ }
+
+ // In type definition of an error implementation type,
+ // rename Error field to Err to make room for method.
+ // Given type XxxError struct { ... Error T } rename field to Err.
+ d, ok := n.(*ast.GenDecl)
+ if ok {
+ for _, s := range d.Specs {
+ switch s := s.(type) {
+ case *ast.TypeSpec:
+ if isErrorImpl(typeName(s.Name)) {
+ st, ok := s.Type.(*ast.StructType)
+ if ok {
+ for _, f := range st.Fields.List {
+ for _, n := range f.Names {
+ if n.Name == "Error" {
+ n.Name = "Err"
+ fixed = true
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // For values that are an error implementation type,
+ // rename .Error to .Err and .String to .Error
+ sel, selok := n.(*ast.SelectorExpr)
+ if selok && isErrorImpl(typeof[sel.X]) {
+ switch sel.Sel.Name {
+ case "Error":
+ sel.Sel.Name = "Err"
+ fixed = true
+ case "String":
+ sel.Sel.Name = "Error"
+ fixed = true
+ }
+ }
+
+ // Assume x.Err is an error value and rename .String to .Error
+ // Children have been processed so the rewrite from Error to Err
+ // has already happened there.
+ if selok {
+ if subsel, ok := sel.X.(*ast.SelectorExpr); ok && subsel.Sel.Name == "Err" && sel.Sel.Name == "String" {
+ sel.Sel.Name = "Error"
+ fixed = true
+ }
+ }
+
+ // For values that are an error variable, rename .String to .Error.
+ if selok && isErrorVar(sel.X) && sel.Sel.Name == "String" {
+ sel.Sel.Name = "Error"
+ fixed = true
+ }
+
+ // Rewrite composite literal of error type to turn Error: into Err:.
+ lit, ok := n.(*ast.CompositeLit)
+ if ok && isErrorImpl(typeof[lit]) {
+ for _, e := range lit.Elts {
+ if kv, ok := e.(*ast.KeyValueExpr); ok && isName(kv.Key, "Error") {
+ kv.Key.(*ast.Ident).Name = "Err"
+ fixed = true
+ }
+ }
+ }
+
+ // Rename os.Waitmsg to exec.ExitError
+ // when used in a type assertion on an error.
+ ta, ok := n.(*ast.TypeAssertExpr)
+ if ok && isErrorVar(ta.X) && isPtrPkgDot(ta.Type, "os", "Waitmsg") {
+ addImport(f, "exec")
+ sel := ta.Type.(*ast.StarExpr).X.(*ast.SelectorExpr)
+ sel.X.(*ast.Ident).Name = "exec"
+ sel.Sel.Name = "ExitError"
+ fixed = true
+ }
+
+ })
+ }
+
+ fix := func() {
+ if fixed {
+ return
+ }
+ fixed = true
+ heuristic()
+ }
+
+ if force["error"] {
+ heuristic()
+ }
+
+ walk(f, func(n interface{}) {
+ p, ok := n.(*ast.Expr)
+ if !ok {
+ return
+ }
+ sel, ok := (*p).(*ast.SelectorExpr)
+ if !ok {
+ return
+ }
+ switch {
+ case isPkgDot(sel, "os", "Error"):
+ fix()
+ *p = &ast.Ident{NamePos: sel.Pos(), Name: "error"}
+ case isPkgDot(sel, "os", "NewError"):
+ fix()
+ addImport(f, "errors")
+ sel.X.(*ast.Ident).Name = "errors"
+ sel.Sel.Name = "New"
+ case isPkgDot(sel, "os", "EOF"):
+ fix()
+ addImport(f, "io")
+ sel.X.(*ast.Ident).Name = "io"
+ }
+ })
+
+ if fixed && !usesImport(f, "os") {
+ deleteImport(f, "os")
+ }
+
+ return fixed
+}
+
+func typeName(typ ast.Expr) string {
+ if p, ok := typ.(*ast.StarExpr); ok {
+ typ = p.X
+ }
+ id, ok := typ.(*ast.Ident)
+ if ok {
+ return id.Name
+ }
+ sel, ok := typ.(*ast.SelectorExpr)
+ if ok {
+ return typeName(sel.X) + "." + sel.Sel.Name
+ }
+ return ""
+}
diff --git a/src/cmd/fix/error_test.go b/src/cmd/fix/error_test.go
new file mode 100644
index 000000000..027eed24f
--- /dev/null
+++ b/src/cmd/fix/error_test.go
@@ -0,0 +1,240 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(errorTests, errorFn)
+}
+
+var errorTests = []testCase{
+ {
+ Name: "error.0",
+ In: `package main
+
+func error() {}
+
+var error int
+`,
+ Out: `package main
+
+func error() {}
+
+var error int
+`,
+ },
+ {
+ Name: "error.1",
+ In: `package main
+
+import "os"
+
+func f() os.Error {
+ return os.EOF
+}
+
+func error() {}
+
+var error int
+
+func g() {
+ error := 1
+ _ = error
+}
+
+func h(os.Error) {}
+
+func i(...os.Error) {}
+`,
+ Out: `package main
+
+import "io"
+
+func f() error {
+ return io.EOF
+}
+
+func error_() {}
+
+var error_ int
+
+func g() {
+ error := 1
+ _ = error
+}
+
+func h(error) {}
+
+func i(...error) {}
+`,
+ },
+ {
+ Name: "error.2",
+ In: `package main
+
+import "os"
+
+func f() os.Error {
+ return os.EOF
+}
+
+func g() string {
+ // these all convert because f is known
+ if err := f(); err != nil {
+ return err.String()
+ }
+ if err1 := f(); err1 != nil {
+ return err1.String()
+ }
+ if e := f(); e != nil {
+ return e.String()
+ }
+ if x := f(); x != nil {
+ return x.String()
+ }
+
+ // only the error names (err, err1, e) convert; u is not known
+ if err := u(); err != nil {
+ return err.String()
+ }
+ if err1 := u(); err1 != nil {
+ return err1.String()
+ }
+ if e := u(); e != nil {
+ return e.String()
+ }
+ if x := u(); x != nil {
+ return x.String()
+ }
+ return ""
+}
+
+type T int
+
+func (t T) String() string { return "t" }
+
+type PT int
+
+func (p *PT) String() string { return "pt" }
+
+type MyError int
+
+func (t MyError) String() string { return "myerror" }
+
+type PMyError int
+
+func (p *PMyError) String() string { return "pmyerror" }
+
+func error() {}
+
+var error int
+`,
+ Out: `package main
+
+import "io"
+
+func f() error {
+ return io.EOF
+}
+
+func g() string {
+ // these all convert because f is known
+ if err := f(); err != nil {
+ return err.Error()
+ }
+ if err1 := f(); err1 != nil {
+ return err1.Error()
+ }
+ if e := f(); e != nil {
+ return e.Error()
+ }
+ if x := f(); x != nil {
+ return x.Error()
+ }
+
+ // only the error names (err, err1, e) convert; u is not known
+ if err := u(); err != nil {
+ return err.Error()
+ }
+ if err1 := u(); err1 != nil {
+ return err1.Error()
+ }
+ if e := u(); e != nil {
+ return e.Error()
+ }
+ if x := u(); x != nil {
+ return x.String()
+ }
+ return ""
+}
+
+type T int
+
+func (t T) String() string { return "t" }
+
+type PT int
+
+func (p *PT) String() string { return "pt" }
+
+type MyError int
+
+func (t MyError) Error() string { return "myerror" }
+
+type PMyError int
+
+func (p *PMyError) Error() string { return "pmyerror" }
+
+func error_() {}
+
+var error_ int
+`,
+ },
+ {
+ Name: "error.3",
+ In: `package main
+
+import "os"
+
+func f() os.Error {
+ return os.EOF
+}
+
+type PathError struct {
+ Name string
+ Error os.Error
+}
+
+func (p *PathError) String() string {
+ return p.Name + ": " + p.Error.String()
+}
+
+func (p *PathError) Error1() string {
+ p = &PathError{Error: nil}
+ return fmt.Sprint(p.Name, ": ", p.Error)
+}
+`,
+ Out: `package main
+
+import "io"
+
+func f() error {
+ return io.EOF
+}
+
+type PathError struct {
+ Name string
+ Err error
+}
+
+func (p *PathError) Error() string {
+ return p.Name + ": " + p.Err.Error()
+}
+
+func (p *PathError) Error1() string {
+ p = &PathError{Err: nil}
+ return fmt.Sprint(p.Name, ": ", p.Err)
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/filepath.go b/src/cmd/fix/filepath.go
new file mode 100644
index 000000000..f31226018
--- /dev/null
+++ b/src/cmd/fix/filepath.go
@@ -0,0 +1,56 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(filepathFix)
+}
+
+var filepathFix = fix{
+ "filepath",
+ "2011-06-26",
+ filepathFunc,
+ `Adapt code from filepath.[List]SeparatorString to string(filepath.[List]Separator).
+
+http://codereview.appspot.com/4527090
+`,
+}
+
+func filepathFunc(f *ast.File) (fixed bool) {
+ if !imports(f, "path/filepath") {
+ return
+ }
+
+ walk(f, func(n interface{}) {
+ e, ok := n.(*ast.Expr)
+ if !ok {
+ return
+ }
+
+ var ident string
+ switch {
+ case isPkgDot(*e, "filepath", "SeparatorString"):
+ ident = "filepath.Separator"
+ case isPkgDot(*e, "filepath", "ListSeparatorString"):
+ ident = "filepath.ListSeparator"
+ default:
+ return
+ }
+
+ // string(filepath.[List]Separator)
+ *e = &ast.CallExpr{
+ Fun: ast.NewIdent("string"),
+ Args: []ast.Expr{ast.NewIdent(ident)},
+ }
+
+ fixed = true
+ })
+
+ return
+}
diff --git a/src/cmd/fix/filepath_test.go b/src/cmd/fix/filepath_test.go
new file mode 100644
index 000000000..37a2f5d9f
--- /dev/null
+++ b/src/cmd/fix/filepath_test.go
@@ -0,0 +1,33 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(filepathTests, filepathFunc)
+}
+
+var filepathTests = []testCase{
+ {
+ Name: "filepath.0",
+ In: `package main
+
+import (
+ "path/filepath"
+)
+
+var _ = filepath.SeparatorString
+var _ = filepath.ListSeparatorString
+`,
+ Out: `package main
+
+import (
+ "path/filepath"
+)
+
+var _ = string(filepath.Separator)
+var _ = string(filepath.ListSeparator)
+`,
+ },
+}
diff --git a/src/cmd/fix/fix.go b/src/cmd/fix/fix.go
new file mode 100644
index 000000000..a100be794
--- /dev/null
+++ b/src/cmd/fix/fix.go
@@ -0,0 +1,848 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "os"
+ "path"
+ "reflect"
+ "strconv"
+ "strings"
+)
+
+type fix struct {
+ name string
+ date string // date that fix was introduced, in YYYY-MM-DD format
+ f func(*ast.File) bool
+ desc string
+}
+
+// main runs sort.Sort(byName(fixes)) before printing the list of fixes.
+type byName []fix
+
+func (f byName) Len() int { return len(f) }
+func (f byName) Swap(i, j int) { f[i], f[j] = f[j], f[i] }
+func (f byName) Less(i, j int) bool { return f[i].name < f[j].name }
+
+// main runs sort.Sort(byDate(fixes)) before applying fixes.
+type byDate []fix
+
+func (f byDate) Len() int { return len(f) }
+func (f byDate) Swap(i, j int) { f[i], f[j] = f[j], f[i] }
+func (f byDate) Less(i, j int) bool { return f[i].date < f[j].date }
+
+var fixes []fix
+
+func register(f fix) {
+ fixes = append(fixes, f)
+}
+
+// walk traverses the AST x, calling visit(y) for each node y in the tree but
+// also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt,
+// in a bottom-up traversal.
+func walk(x interface{}, visit func(interface{})) {
+ walkBeforeAfter(x, nop, visit)
+}
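A minimal sketch of how a fix uses walk (the counting logic is illustrative only, assuming f is the *ast.File being fixed): visit every node bottom-up and inspect the ones of interest, here selector expressions that spell os.Error.

	count := 0
	walk(f, func(n interface{}) {
		// isPkgDot is defined later in this file.
		if sel, ok := n.(*ast.SelectorExpr); ok && isPkgDot(sel, "os", "Error") {
			count++
		}
	})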
+
+func nop(interface{}) {}
+
+// walkBeforeAfter is like walk but calls before(x) before traversing
+// x's children and after(x) afterward.
+func walkBeforeAfter(x interface{}, before, after func(interface{})) {
+ before(x)
+
+ switch n := x.(type) {
+ default:
+ panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x))
+
+ case nil:
+
+ // pointers to interfaces
+ case *ast.Decl:
+ walkBeforeAfter(*n, before, after)
+ case *ast.Expr:
+ walkBeforeAfter(*n, before, after)
+ case *ast.Spec:
+ walkBeforeAfter(*n, before, after)
+ case *ast.Stmt:
+ walkBeforeAfter(*n, before, after)
+
+ // pointers to struct pointers
+ case **ast.BlockStmt:
+ walkBeforeAfter(*n, before, after)
+ case **ast.CallExpr:
+ walkBeforeAfter(*n, before, after)
+ case **ast.FieldList:
+ walkBeforeAfter(*n, before, after)
+ case **ast.FuncType:
+ walkBeforeAfter(*n, before, after)
+ case **ast.Ident:
+ walkBeforeAfter(*n, before, after)
+ case **ast.BasicLit:
+ walkBeforeAfter(*n, before, after)
+
+ // pointers to slices
+ case *[]ast.Decl:
+ walkBeforeAfter(*n, before, after)
+ case *[]ast.Expr:
+ walkBeforeAfter(*n, before, after)
+ case *[]*ast.File:
+ walkBeforeAfter(*n, before, after)
+ case *[]*ast.Ident:
+ walkBeforeAfter(*n, before, after)
+ case *[]ast.Spec:
+ walkBeforeAfter(*n, before, after)
+ case *[]ast.Stmt:
+ walkBeforeAfter(*n, before, after)
+
+ // These are ordered and grouped to match ../../pkg/go/ast/ast.go
+ case *ast.Field:
+ walkBeforeAfter(&n.Names, before, after)
+ walkBeforeAfter(&n.Type, before, after)
+ walkBeforeAfter(&n.Tag, before, after)
+ case *ast.FieldList:
+ for _, field := range n.List {
+ walkBeforeAfter(field, before, after)
+ }
+ case *ast.BadExpr:
+ case *ast.Ident:
+ case *ast.Ellipsis:
+ walkBeforeAfter(&n.Elt, before, after)
+ case *ast.BasicLit:
+ case *ast.FuncLit:
+ walkBeforeAfter(&n.Type, before, after)
+ walkBeforeAfter(&n.Body, before, after)
+ case *ast.CompositeLit:
+ walkBeforeAfter(&n.Type, before, after)
+ walkBeforeAfter(&n.Elts, before, after)
+ case *ast.ParenExpr:
+ walkBeforeAfter(&n.X, before, after)
+ case *ast.SelectorExpr:
+ walkBeforeAfter(&n.X, before, after)
+ case *ast.IndexExpr:
+ walkBeforeAfter(&n.X, before, after)
+ walkBeforeAfter(&n.Index, before, after)
+ case *ast.SliceExpr:
+ walkBeforeAfter(&n.X, before, after)
+ if n.Low != nil {
+ walkBeforeAfter(&n.Low, before, after)
+ }
+ if n.High != nil {
+ walkBeforeAfter(&n.High, before, after)
+ }
+ case *ast.TypeAssertExpr:
+ walkBeforeAfter(&n.X, before, after)
+ walkBeforeAfter(&n.Type, before, after)
+ case *ast.CallExpr:
+ walkBeforeAfter(&n.Fun, before, after)
+ walkBeforeAfter(&n.Args, before, after)
+ case *ast.StarExpr:
+ walkBeforeAfter(&n.X, before, after)
+ case *ast.UnaryExpr:
+ walkBeforeAfter(&n.X, before, after)
+ case *ast.BinaryExpr:
+ walkBeforeAfter(&n.X, before, after)
+ walkBeforeAfter(&n.Y, before, after)
+ case *ast.KeyValueExpr:
+ walkBeforeAfter(&n.Key, before, after)
+ walkBeforeAfter(&n.Value, before, after)
+
+ case *ast.ArrayType:
+ walkBeforeAfter(&n.Len, before, after)
+ walkBeforeAfter(&n.Elt, before, after)
+ case *ast.StructType:
+ walkBeforeAfter(&n.Fields, before, after)
+ case *ast.FuncType:
+ walkBeforeAfter(&n.Params, before, after)
+ if n.Results != nil {
+ walkBeforeAfter(&n.Results, before, after)
+ }
+ case *ast.InterfaceType:
+ walkBeforeAfter(&n.Methods, before, after)
+ case *ast.MapType:
+ walkBeforeAfter(&n.Key, before, after)
+ walkBeforeAfter(&n.Value, before, after)
+ case *ast.ChanType:
+ walkBeforeAfter(&n.Value, before, after)
+
+ case *ast.BadStmt:
+ case *ast.DeclStmt:
+ walkBeforeAfter(&n.Decl, before, after)
+ case *ast.EmptyStmt:
+ case *ast.LabeledStmt:
+ walkBeforeAfter(&n.Stmt, before, after)
+ case *ast.ExprStmt:
+ walkBeforeAfter(&n.X, before, after)
+ case *ast.SendStmt:
+ walkBeforeAfter(&n.Chan, before, after)
+ walkBeforeAfter(&n.Value, before, after)
+ case *ast.IncDecStmt:
+ walkBeforeAfter(&n.X, before, after)
+ case *ast.AssignStmt:
+ walkBeforeAfter(&n.Lhs, before, after)
+ walkBeforeAfter(&n.Rhs, before, after)
+ case *ast.GoStmt:
+ walkBeforeAfter(&n.Call, before, after)
+ case *ast.DeferStmt:
+ walkBeforeAfter(&n.Call, before, after)
+ case *ast.ReturnStmt:
+ walkBeforeAfter(&n.Results, before, after)
+ case *ast.BranchStmt:
+ case *ast.BlockStmt:
+ walkBeforeAfter(&n.List, before, after)
+ case *ast.IfStmt:
+ walkBeforeAfter(&n.Init, before, after)
+ walkBeforeAfter(&n.Cond, before, after)
+ walkBeforeAfter(&n.Body, before, after)
+ walkBeforeAfter(&n.Else, before, after)
+ case *ast.CaseClause:
+ walkBeforeAfter(&n.List, before, after)
+ walkBeforeAfter(&n.Body, before, after)
+ case *ast.SwitchStmt:
+ walkBeforeAfter(&n.Init, before, after)
+ walkBeforeAfter(&n.Tag, before, after)
+ walkBeforeAfter(&n.Body, before, after)
+ case *ast.TypeSwitchStmt:
+ walkBeforeAfter(&n.Init, before, after)
+ walkBeforeAfter(&n.Assign, before, after)
+ walkBeforeAfter(&n.Body, before, after)
+ case *ast.CommClause:
+ walkBeforeAfter(&n.Comm, before, after)
+ walkBeforeAfter(&n.Body, before, after)
+ case *ast.SelectStmt:
+ walkBeforeAfter(&n.Body, before, after)
+ case *ast.ForStmt:
+ walkBeforeAfter(&n.Init, before, after)
+ walkBeforeAfter(&n.Cond, before, after)
+ walkBeforeAfter(&n.Post, before, after)
+ walkBeforeAfter(&n.Body, before, after)
+ case *ast.RangeStmt:
+ walkBeforeAfter(&n.Key, before, after)
+ walkBeforeAfter(&n.Value, before, after)
+ walkBeforeAfter(&n.X, before, after)
+ walkBeforeAfter(&n.Body, before, after)
+
+ case *ast.ImportSpec:
+ case *ast.ValueSpec:
+ walkBeforeAfter(&n.Type, before, after)
+ walkBeforeAfter(&n.Values, before, after)
+ walkBeforeAfter(&n.Names, before, after)
+ case *ast.TypeSpec:
+ walkBeforeAfter(&n.Type, before, after)
+
+ case *ast.BadDecl:
+ case *ast.GenDecl:
+ walkBeforeAfter(&n.Specs, before, after)
+ case *ast.FuncDecl:
+ if n.Recv != nil {
+ walkBeforeAfter(&n.Recv, before, after)
+ }
+ walkBeforeAfter(&n.Type, before, after)
+ if n.Body != nil {
+ walkBeforeAfter(&n.Body, before, after)
+ }
+
+ case *ast.File:
+ walkBeforeAfter(&n.Decls, before, after)
+
+ case *ast.Package:
+ walkBeforeAfter(&n.Files, before, after)
+
+ case []*ast.File:
+ for i := range n {
+ walkBeforeAfter(&n[i], before, after)
+ }
+ case []ast.Decl:
+ for i := range n {
+ walkBeforeAfter(&n[i], before, after)
+ }
+ case []ast.Expr:
+ for i := range n {
+ walkBeforeAfter(&n[i], before, after)
+ }
+ case []*ast.Ident:
+ for i := range n {
+ walkBeforeAfter(&n[i], before, after)
+ }
+ case []ast.Stmt:
+ for i := range n {
+ walkBeforeAfter(&n[i], before, after)
+ }
+ case []ast.Spec:
+ for i := range n {
+ walkBeforeAfter(&n[i], before, after)
+ }
+ }
+ after(x)
+}
+
+// imports returns true if f imports path.
+func imports(f *ast.File, path string) bool {
+ return importSpec(f, path) != nil
+}
+
+// importSpec returns the import spec if f imports path,
+// or nil otherwise.
+func importSpec(f *ast.File, path string) *ast.ImportSpec {
+ for _, s := range f.Imports {
+ if importPath(s) == path {
+ return s
+ }
+ }
+ return nil
+}
+
+// importPath returns the unquoted import path of s,
+// or "" if the path is not properly quoted.
+func importPath(s *ast.ImportSpec) string {
+ t, err := strconv.Unquote(s.Path.Value)
+ if err == nil {
+ return t
+ }
+ return ""
+}
+
+// declImports reports whether gen contains an import of path.
+func declImports(gen *ast.GenDecl, path string) bool {
+ if gen.Tok != token.IMPORT {
+ return false
+ }
+ for _, spec := range gen.Specs {
+ impspec := spec.(*ast.ImportSpec)
+ if importPath(impspec) == path {
+ return true
+ }
+ }
+ return false
+}
+
+// isPkgDot returns true if t is the expression "pkg.name"
+// where pkg is an imported identifier.
+func isPkgDot(t ast.Expr, pkg, name string) bool {
+ sel, ok := t.(*ast.SelectorExpr)
+ return ok && isTopName(sel.X, pkg) && sel.Sel.String() == name
+}
+
+// isPtrPkgDot returns true if t is the expression "*pkg.name"
+// where pkg is an imported identifier.
+func isPtrPkgDot(t ast.Expr, pkg, name string) bool {
+ ptr, ok := t.(*ast.StarExpr)
+ return ok && isPkgDot(ptr.X, pkg, name)
+}
+
+// isTopName returns true if n is a top-level unresolved identifier with the given name.
+func isTopName(n ast.Expr, name string) bool {
+ id, ok := n.(*ast.Ident)
+ return ok && id.Name == name && id.Obj == nil
+}
+
+// isName returns true if n is an identifier with the given name.
+func isName(n ast.Expr, name string) bool {
+ id, ok := n.(*ast.Ident)
+ return ok && id.String() == name
+}
+
+// isCall returns true if t is a call to pkg.name.
+func isCall(t ast.Expr, pkg, name string) bool {
+ call, ok := t.(*ast.CallExpr)
+ return ok && isPkgDot(call.Fun, pkg, name)
+}
+
+// If n is an *ast.Ident, isIdent returns it; otherwise isIdent returns nil.
+func isIdent(n interface{}) *ast.Ident {
+ id, _ := n.(*ast.Ident)
+ return id
+}
+
+// refersTo returns true if n is a reference to the same object as x.
+func refersTo(n ast.Node, x *ast.Ident) bool {
+ id, ok := n.(*ast.Ident)
+ // The test of id.Name == x.Name handles top-level unresolved
+ // identifiers, which all have Obj == nil.
+ return ok && id.Obj == x.Obj && id.Name == x.Name
+}
+
+// isBlank returns true if n is the blank identifier.
+func isBlank(n ast.Expr) bool {
+ return isName(n, "_")
+}
+
+// isEmptyString returns true if n is an empty string literal.
+func isEmptyString(n ast.Expr) bool {
+ lit, ok := n.(*ast.BasicLit)
+ return ok && lit.Kind == token.STRING && len(lit.Value) == 2
+}
+
+func warn(pos token.Pos, msg string, args ...interface{}) {
+ if pos.IsValid() {
+ msg = "%s: " + msg
+ arg1 := []interface{}{fset.Position(pos).String()}
+ args = append(arg1, args...)
+ }
+ fmt.Fprintf(os.Stderr, msg+"\n", args...)
+}
+
+// countUses returns the number of uses of the identifier x in scope.
+func countUses(x *ast.Ident, scope []ast.Stmt) int {
+ count := 0
+ ff := func(n interface{}) {
+ if n, ok := n.(ast.Node); ok && refersTo(n, x) {
+ count++
+ }
+ }
+ for _, n := range scope {
+ walk(n, ff)
+ }
+ return count
+}
+
+// rewriteUses replaces all uses of the identifier x and !x in scope
+// with f(x.Pos()) and fnot(x.Pos()).
+func rewriteUses(x *ast.Ident, f, fnot func(token.Pos) ast.Expr, scope []ast.Stmt) {
+ var lastF ast.Expr
+ ff := func(n interface{}) {
+ ptr, ok := n.(*ast.Expr)
+ if !ok {
+ return
+ }
+ nn := *ptr
+
+ // The child node was just walked and possibly replaced.
+ // If it was replaced and this is a negation, replace with fnot(p).
+ not, ok := nn.(*ast.UnaryExpr)
+ if ok && not.Op == token.NOT && not.X == lastF {
+ *ptr = fnot(nn.Pos())
+ return
+ }
+ if refersTo(nn, x) {
+ lastF = f(nn.Pos())
+ *ptr = lastF
+ }
+ }
+ for _, n := range scope {
+ walk(n, ff)
+ }
+}
+
+// assignsTo returns true if any of the code in scope assigns to or takes the address of x.
+func assignsTo(x *ast.Ident, scope []ast.Stmt) bool {
+ assigned := false
+ ff := func(n interface{}) {
+ if assigned {
+ return
+ }
+ switch n := n.(type) {
+ case *ast.UnaryExpr:
+ // use of &x
+ if n.Op == token.AND && refersTo(n.X, x) {
+ assigned = true
+ return
+ }
+ case *ast.AssignStmt:
+ for _, l := range n.Lhs {
+ if refersTo(l, x) {
+ assigned = true
+ return
+ }
+ }
+ }
+ }
+ for _, n := range scope {
+ if assigned {
+ break
+ }
+ walk(n, ff)
+ }
+ return assigned
+}
+
+// newPkgDot returns an ast.Expr referring to "pkg.name" at position pos.
+func newPkgDot(pos token.Pos, pkg, name string) ast.Expr {
+ return &ast.SelectorExpr{
+ X: &ast.Ident{
+ NamePos: pos,
+ Name: pkg,
+ },
+ Sel: &ast.Ident{
+ NamePos: pos,
+ Name: name,
+ },
+ }
+}
+
+// renameTop renames all references to the top-level name old.
+// It returns true if it makes any changes.
+func renameTop(f *ast.File, old, new string) bool {
+ var fixed bool
+
+ // Rename any conflicting imports
+ // (assuming package name is last element of path).
+ for _, s := range f.Imports {
+ if s.Name != nil {
+ if s.Name.Name == old {
+ s.Name.Name = new
+ fixed = true
+ }
+ } else {
+ _, thisName := path.Split(importPath(s))
+ if thisName == old {
+ s.Name = ast.NewIdent(new)
+ fixed = true
+ }
+ }
+ }
+
+ // Rename any top-level declarations.
+ for _, d := range f.Decls {
+ switch d := d.(type) {
+ case *ast.FuncDecl:
+ if d.Recv == nil && d.Name.Name == old {
+ d.Name.Name = new
+ d.Name.Obj.Name = new
+ fixed = true
+ }
+ case *ast.GenDecl:
+ for _, s := range d.Specs {
+ switch s := s.(type) {
+ case *ast.TypeSpec:
+ if s.Name.Name == old {
+ s.Name.Name = new
+ s.Name.Obj.Name = new
+ fixed = true
+ }
+ case *ast.ValueSpec:
+ for _, n := range s.Names {
+ if n.Name == old {
+ n.Name = new
+ n.Obj.Name = new
+ fixed = true
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Rename top-level old to new, both unresolved names
+ // (probably defined in another file) and names that resolve
+ // to a declaration we renamed.
+ walk(f, func(n interface{}) {
+ id, ok := n.(*ast.Ident)
+ if ok && isTopName(id, old) {
+ id.Name = new
+ fixed = true
+ }
+ if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new {
+ id.Name = id.Obj.Name
+ fixed = true
+ }
+ })
+
+ return fixed
+}
+
+// matchLen returns the length of the longest prefix shared by x and y.
+func matchLen(x, y string) int {
+ i := 0
+ for i < len(x) && i < len(y) && x[i] == y[i] {
+ i++
+ }
+ return i
+}
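For example, matchLen("net/http", "net/url") is 4, the length of the shared "net/" prefix; addImport below uses this to place a new import next to the existing import it most resembles.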
+
+// addImport adds the import path to the file f, if absent.
+func addImport(f *ast.File, ipath string) (added bool) {
+ if imports(f, ipath) {
+ return false
+ }
+
+ // Determine name of import.
+ // Assume added imports follow convention of using last element.
+ _, name := path.Split(ipath)
+
+ // Rename any conflicting top-level references from name to name_.
+ renameTop(f, name, name+"_")
+
+ newImport := &ast.ImportSpec{
+ Path: &ast.BasicLit{
+ Kind: token.STRING,
+ Value: strconv.Quote(ipath),
+ },
+ }
+
+ // Find an import decl to add to.
+ var (
+ bestMatch = -1
+ lastImport = -1
+ impDecl *ast.GenDecl
+ impIndex = -1
+ )
+ for i, decl := range f.Decls {
+ gen, ok := decl.(*ast.GenDecl)
+ if ok && gen.Tok == token.IMPORT {
+ lastImport = i
+ // Do not add to import "C", to avoid disrupting the
+ // association with its doc comment, breaking cgo.
+ if declImports(gen, "C") {
+ continue
+ }
+
+ // Compute longest shared prefix with imports in this block.
+ for j, spec := range gen.Specs {
+ impspec := spec.(*ast.ImportSpec)
+ n := matchLen(importPath(impspec), ipath)
+ if n > bestMatch {
+ bestMatch = n
+ impDecl = gen
+ impIndex = j
+ }
+ }
+ }
+ }
+
+ // If no import decl found, add one after the last import.
+ if impDecl == nil {
+ impDecl = &ast.GenDecl{
+ Tok: token.IMPORT,
+ }
+ f.Decls = append(f.Decls, nil)
+ copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
+ f.Decls[lastImport+1] = impDecl
+ }
+
+ // Ensure the import decl has parentheses, if needed.
+ if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() {
+ impDecl.Lparen = impDecl.Pos()
+ }
+
+ insertAt := impIndex + 1
+ if insertAt == 0 {
+ insertAt = len(impDecl.Specs)
+ }
+ impDecl.Specs = append(impDecl.Specs, nil)
+ copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:])
+ impDecl.Specs[insertAt] = newImport
+ if insertAt > 0 {
+ // Assign same position as the previous import,
+ // so that the sorter sees it as being in the same block.
+ prev := impDecl.Specs[insertAt-1]
+ newImport.Path.ValuePos = prev.Pos()
+ newImport.EndPos = prev.Pos()
+ }
+
+ f.Imports = append(f.Imports, newImport)
+ return true
+}
+
+// deleteImport deletes the import path from the file f, if present.
+func deleteImport(f *ast.File, path string) (deleted bool) {
+ oldImport := importSpec(f, path)
+
+ // Find the import node that imports path, if any.
+ for i, decl := range f.Decls {
+ gen, ok := decl.(*ast.GenDecl)
+ if !ok || gen.Tok != token.IMPORT {
+ continue
+ }
+ for j, spec := range gen.Specs {
+ impspec := spec.(*ast.ImportSpec)
+ if oldImport != impspec {
+ continue
+ }
+
+ // We found an import spec that imports path.
+ // Delete it.
+ deleted = true
+ copy(gen.Specs[j:], gen.Specs[j+1:])
+ gen.Specs = gen.Specs[:len(gen.Specs)-1]
+
+ // If this was the last import spec in this decl,
+ // delete the decl, too.
+ if len(gen.Specs) == 0 {
+ copy(f.Decls[i:], f.Decls[i+1:])
+ f.Decls = f.Decls[:len(f.Decls)-1]
+ } else if len(gen.Specs) == 1 {
+ gen.Lparen = token.NoPos // drop parens
+ }
+ if j > 0 {
+ // We deleted an entry but now there will be
+ // a blank line-sized hole where the import was.
+ // Close the hole by making the previous
+ // import appear to "end" where this one did.
+ gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End()
+ }
+ break
+ }
+ }
+
+ // Delete it from f.Imports.
+ for i, imp := range f.Imports {
+ if imp == oldImport {
+ copy(f.Imports[i:], f.Imports[i+1:])
+ f.Imports = f.Imports[:len(f.Imports)-1]
+ break
+ }
+ }
+
+ return
+}
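Fixes typically pair addImport and deleteImport when a symbol moves between packages. A sketch of the pattern, following what errorFn above does for os.EOF (sel stands for the *ast.SelectorExpr being rewritten):

	addImport(f, "io")
	sel.X.(*ast.Ident).Name = "io" // os.EOF -> io.EOF
	if !usesImport(f, "os") {
		deleteImport(f, "os")
	}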
+
+// rewriteImport rewrites any import of path oldPath to path newPath.
+func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
+ for _, imp := range f.Imports {
+ if importPath(imp) == oldPath {
+ rewrote = true
+ // record old End, because the default is to compute
+ // it using the length of imp.Path.Value.
+ imp.EndPos = imp.End()
+ imp.Path.Value = strconv.Quote(newPath)
+ }
+ }
+ return
+}
+
+func usesImport(f *ast.File, path string) (used bool) {
+ spec := importSpec(f, path)
+ if spec == nil {
+ return
+ }
+
+ name := spec.Name.String()
+ switch name {
+ case "<nil>":
+ // If the package name is not explicitly specified,
+ // make an educated guess. This is not guaranteed to be correct.
+ lastSlash := strings.LastIndex(path, "/")
+ if lastSlash == -1 {
+ name = path
+ } else {
+ name = path[lastSlash+1:]
+ }
+ case "_", ".":
+ // Not sure if this import is used - err on the side of caution.
+ return true
+ }
+
+ walk(f, func(n interface{}) {
+ sel, ok := n.(*ast.SelectorExpr)
+ if ok && isTopName(sel.X, name) {
+ used = true
+ }
+ })
+
+ return
+}
+
+func expr(s string) ast.Expr {
+ x, err := parser.ParseExpr(s)
+ if err != nil {
+ panic("parsing " + s + ": " + err.Error())
+ }
+ // Remove position information to avoid spurious newlines.
+ killPos(reflect.ValueOf(x))
+ return x
+}
+
+var posType = reflect.TypeOf(token.Pos(0))
+
+func killPos(v reflect.Value) {
+ switch v.Kind() {
+ case reflect.Ptr, reflect.Interface:
+ if !v.IsNil() {
+ killPos(v.Elem())
+ }
+ case reflect.Slice:
+ n := v.Len()
+ for i := 0; i < n; i++ {
+ killPos(v.Index(i))
+ }
+ case reflect.Struct:
+ n := v.NumField()
+ for i := 0; i < n; i++ {
+ f := v.Field(i)
+ if f.Type() == posType {
+ f.SetInt(0)
+ continue
+ }
+ killPos(f)
+ }
+ }
+}
+
+// A rename describes a single renaming.
+type rename struct {
+ OldImport string // only apply rename if this import is present
+ NewImport string // add this import during rewrite
+ Old string // old name: p.T or *p.T
+ New string // new name: p.T or *p.T
+}
+
+func renameFix(tab []rename) func(*ast.File) bool {
+ return func(f *ast.File) bool {
+ return renameFixTab(f, tab)
+ }
+}
+
+func parseName(s string) (ptr bool, pkg, nam string) {
+ i := strings.Index(s, ".")
+ if i < 0 {
+ panic("parseName: invalid name " + s)
+ }
+ if strings.HasPrefix(s, "*") {
+ ptr = true
+ s = s[1:]
+ i--
+ }
+ pkg = s[:i]
+ nam = s[i+1:]
+ return
+}
+
+func renameFixTab(f *ast.File, tab []rename) bool {
+ fixed := false
+ added := map[string]bool{}
+ check := map[string]bool{}
+ for _, t := range tab {
+ if !imports(f, t.OldImport) {
+ continue
+ }
+ optr, opkg, onam := parseName(t.Old)
+ walk(f, func(n interface{}) {
+ np, ok := n.(*ast.Expr)
+ if !ok {
+ return
+ }
+ x := *np
+ if optr {
+ p, ok := x.(*ast.StarExpr)
+ if !ok {
+ return
+ }
+ x = p.X
+ }
+ if !isPkgDot(x, opkg, onam) {
+ return
+ }
+ if t.NewImport != "" && !added[t.NewImport] {
+ addImport(f, t.NewImport)
+ added[t.NewImport] = true
+ }
+ *np = expr(t.New)
+ check[t.OldImport] = true
+ fixed = true
+ })
+ }
+
+ for ipath := range check {
+ if !usesImport(f, ipath) {
+ deleteImport(f, ipath)
+ }
+ }
+ return fixed
+}
diff --git a/src/cmd/fix/go1pkgrename.go b/src/cmd/fix/go1pkgrename.go
new file mode 100644
index 000000000..f701f62f0
--- /dev/null
+++ b/src/cmd/fix/go1pkgrename.go
@@ -0,0 +1,146 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+ "strings"
+)
+
+func init() {
+ register(go1pkgrenameFix)
+}
+
+var go1pkgrenameFix = fix{
+ "go1rename",
+ "2011-11-08",
+ go1pkgrename,
+ `Rewrite imports for packages moved during transition to Go 1.
+
+http://codereview.appspot.com/5316078
+`,
+}
+
+var go1PackageRenames = []struct{ old, new string }{
+ {"asn1", "encoding/asn1"},
+ {"big", "math/big"},
+ {"cmath", "math/cmplx"},
+ {"csv", "encoding/csv"},
+ {"exec", "os/exec"},
+ {"exp/template/html", "html/template"},
+ {"gob", "encoding/gob"},
+ {"http", "net/http"},
+ {"http/cgi", "net/http/cgi"},
+ {"http/fcgi", "net/http/fcgi"},
+ {"http/httptest", "net/http/httptest"},
+ {"http/pprof", "net/http/pprof"},
+ {"json", "encoding/json"},
+ {"mail", "net/mail"},
+ {"rpc", "net/rpc"},
+ {"rpc/jsonrpc", "net/rpc/jsonrpc"},
+ {"scanner", "text/scanner"},
+ {"smtp", "net/smtp"},
+ {"syslog", "log/syslog"},
+ {"tabwriter", "text/tabwriter"},
+ {"template", "text/template"},
+ {"template/parse", "text/template/parse"},
+ {"rand", "math/rand"},
+ {"url", "net/url"},
+ {"utf16", "unicode/utf16"},
+ {"utf8", "unicode/utf8"},
+ {"xml", "encoding/xml"},
+
+ // go.crypto sub-repository
+ {"crypto/bcrypt", "code.google.com/p/go.crypto/bcrypt"},
+ {"crypto/blowfish", "code.google.com/p/go.crypto/blowfish"},
+ {"crypto/cast5", "code.google.com/p/go.crypto/cast5"},
+ {"crypto/md4", "code.google.com/p/go.crypto/md4"},
+ {"crypto/ocsp", "code.google.com/p/go.crypto/ocsp"},
+ {"crypto/openpgp", "code.google.com/p/go.crypto/openpgp"},
+ {"crypto/openpgp/armor", "code.google.com/p/go.crypto/openpgp/armor"},
+ {"crypto/openpgp/elgamal", "code.google.com/p/go.crypto/openpgp/elgamal"},
+ {"crypto/openpgp/errors", "code.google.com/p/go.crypto/openpgp/errors"},
+ {"crypto/openpgp/packet", "code.google.com/p/go.crypto/openpgp/packet"},
+ {"crypto/openpgp/s2k", "code.google.com/p/go.crypto/openpgp/s2k"},
+ {"crypto/ripemd160", "code.google.com/p/go.crypto/ripemd160"},
+ {"crypto/twofish", "code.google.com/p/go.crypto/twofish"},
+ {"crypto/xtea", "code.google.com/p/go.crypto/xtea"},
+ {"exp/ssh", "code.google.com/p/go.crypto/ssh"},
+
+ // go.image sub-repository
+ {"image/bmp", "code.google.com/p/go.image/bmp"},
+ {"image/tiff", "code.google.com/p/go.image/tiff"},
+
+ // go.net sub-repository
+ {"net/dict", "code.google.com/p/go.net/dict"},
+ {"net/websocket", "code.google.com/p/go.net/websocket"},
+ {"exp/spdy", "code.google.com/p/go.net/spdy"},
+ {"http/spdy", "code.google.com/p/go.net/spdy"},
+
+ // go.codereview sub-repository
+ {"encoding/git85", "code.google.com/p/go.codereview/git85"},
+ {"patch", "code.google.com/p/go.codereview/patch"},
+
+ // exp
+ {"ebnf", "exp/ebnf"},
+ {"go/types", "exp/types"},
+
+ // deleted
+ {"container/vector", ""},
+ {"exp/datafmt", ""},
+ {"go/typechecker", ""},
+ {"old/netchan", ""},
+ {"old/regexp", ""},
+ {"old/template", ""},
+ {"try", ""},
+}
+
+var go1PackageNameRenames = []struct{ newPath, old, new string }{
+ {"html/template", "html", "template"},
+ {"math/cmplx", "cmath", "cmplx"},
+}
+
+func go1pkgrename(f *ast.File) bool {
+ fixed := false
+
+ // First update the imports.
+ for _, rename := range go1PackageRenames {
+ spec := importSpec(f, rename.old)
+ if spec == nil {
+ continue
+ }
+ if rename.new == "" {
+ warn(spec.Pos(), "package %q has been deleted in Go 1", rename.old)
+ continue
+ }
+ if rewriteImport(f, rename.old, rename.new) {
+ fixed = true
+ }
+ if strings.HasPrefix(rename.new, "exp/") {
+ warn(spec.Pos(), "package %q is not part of Go 1", rename.new)
+ }
+ }
+ if !fixed {
+ return false
+ }
+
+ // Now update the package names used by importers.
+ for _, rename := range go1PackageNameRenames {
+ // These are rare packages, so do the import test before walking.
+ if imports(f, rename.newPath) {
+ walk(f, func(n interface{}) {
+ if sel, ok := n.(*ast.SelectorExpr); ok {
+ if isTopName(sel.X, rename.old) {
+ // We know Sel.X is an Ident.
+ sel.X.(*ast.Ident).Name = rename.new
+ return
+ }
+ }
+ })
+ }
+ }
+
+ return fixed
+}
diff --git a/src/cmd/fix/go1pkgrename_test.go b/src/cmd/fix/go1pkgrename_test.go
new file mode 100644
index 000000000..840e443b0
--- /dev/null
+++ b/src/cmd/fix/go1pkgrename_test.go
@@ -0,0 +1,139 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(go1pkgrenameTests, go1pkgrename)
+}
+
+var go1pkgrenameTests = []testCase{
+ {
+ Name: "go1rename.0",
+ In: `package main
+
+import (
+ "asn1"
+ "big"
+ "cmath"
+ "csv"
+ "exec"
+ "exp/template/html"
+ "gob"
+ "http"
+ "http/cgi"
+ "http/fcgi"
+ "http/httptest"
+ "http/pprof"
+ "json"
+ "mail"
+ "rand"
+ "rpc"
+ "rpc/jsonrpc"
+ "scanner"
+ "smtp"
+ "syslog"
+ "tabwriter"
+ "template"
+ "template/parse"
+ "url"
+ "utf16"
+ "utf8"
+ "xml"
+
+ "crypto/bcrypt"
+)
+`,
+ Out: `package main
+
+import (
+ "encoding/asn1"
+ "encoding/csv"
+ "encoding/gob"
+ "encoding/json"
+ "encoding/xml"
+ "html/template"
+ "log/syslog"
+ "math/big"
+ "math/cmplx"
+ "math/rand"
+ "net/http"
+ "net/http/cgi"
+ "net/http/fcgi"
+ "net/http/httptest"
+ "net/http/pprof"
+ "net/mail"
+ "net/rpc"
+ "net/rpc/jsonrpc"
+ "net/smtp"
+ "net/url"
+ "os/exec"
+ "text/scanner"
+ "text/tabwriter"
+ "text/template"
+ "text/template/parse"
+ "unicode/utf16"
+ "unicode/utf8"
+
+ "code.google.com/p/go.crypto/bcrypt"
+)
+`,
+ },
+ {
+ Name: "go1rename.1",
+ In: `package main
+
+import "cmath"
+import poot "exp/template/html"
+
+import (
+ "ebnf"
+ "old/regexp"
+)
+
+var _ = cmath.Sin
+var _ = poot.Poot
+`,
+ Out: `package main
+
+import "math/cmplx"
+import poot "html/template"
+
+import (
+ "exp/ebnf"
+ "old/regexp"
+)
+
+var _ = cmplx.Sin
+var _ = poot.Poot
+`,
+ },
+ {
+ Name: "go1rename.2",
+ In: `package foo
+
+import (
+ "fmt"
+ "http"
+ "url"
+
+ "google/secret/project/go"
+)
+
+func main() {}
+`,
+ Out: `package foo
+
+import (
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "google/secret/project/go"
+)
+
+func main() {}
+`,
+ },
+}
diff --git a/src/cmd/fix/go1rename.go b/src/cmd/fix/go1rename.go
new file mode 100644
index 000000000..9266c749c
--- /dev/null
+++ b/src/cmd/fix/go1rename.go
@@ -0,0 +1,167 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ register(go1renameFix)
+}
+
+var go1renameFix = fix{
+ "go1rename",
+ "2012-02-12",
+ renameFix(go1renameReplace),
+ `Rewrite package-level names that have been renamed in Go 1.
+
+http://codereview.appspot.com/5625045/
+http://codereview.appspot.com/5672072/
+`,
+}
+
+var go1renameReplace = []rename{
+ {
+ OldImport: "crypto/aes",
+ NewImport: "crypto/cipher",
+ Old: "*aes.Cipher",
+ New: "cipher.Block",
+ },
+ {
+ OldImport: "crypto/des",
+ NewImport: "crypto/cipher",
+ Old: "*des.Cipher",
+ New: "cipher.Block",
+ },
+ {
+ OldImport: "crypto/des",
+ NewImport: "crypto/cipher",
+ Old: "*des.TripleDESCipher",
+ New: "cipher.Block",
+ },
+ {
+ OldImport: "encoding/json",
+ NewImport: "",
+ Old: "json.MarshalForHTML",
+ New: "json.Marshal",
+ },
+ {
+ OldImport: "net/url",
+ NewImport: "",
+ Old: "url.ParseWithReference",
+ New: "url.Parse",
+ },
+ {
+ OldImport: "net/url",
+ NewImport: "",
+ Old: "url.ParseRequest",
+ New: "url.ParseRequestURI",
+ },
+ {
+ OldImport: "os",
+ NewImport: "syscall",
+ Old: "os.Exec",
+ New: "syscall.Exec",
+ },
+ {
+ OldImport: "runtime",
+ NewImport: "",
+ Old: "runtime.Cgocalls",
+ New: "runtime.NumCgoCall",
+ },
+ {
+ OldImport: "runtime",
+ NewImport: "",
+ Old: "runtime.Goroutines",
+ New: "runtime.NumGoroutine",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.ErrPersistEOF",
+ New: "httputil.ErrPersistEOF",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.ErrPipeline",
+ New: "httputil.ErrPipeline",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.ErrClosed",
+ New: "httputil.ErrClosed",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.ServerConn",
+ New: "httputil.ServerConn",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.ClientConn",
+ New: "httputil.ClientConn",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.NewChunkedReader",
+ New: "httputil.NewChunkedReader",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.NewChunkedWriter",
+ New: "httputil.NewChunkedWriter",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.ReverseProxy",
+ New: "httputil.ReverseProxy",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.NewSingleHostReverseProxy",
+ New: "httputil.NewSingleHostReverseProxy",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.DumpRequest",
+ New: "httputil.DumpRequest",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.DumpRequestOut",
+ New: "httputil.DumpRequestOut",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.DumpResponse",
+ New: "httputil.DumpResponse",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.NewClientConn",
+ New: "httputil.NewClientConn",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.NewServerConn",
+ New: "httputil.NewServerConn",
+ },
+ {
+ OldImport: "net/http",
+ NewImport: "net/http/httputil",
+ Old: "http.NewProxyClientConn",
+ New: "httputil.NewProxyClientConn",
+ },
+}
diff --git a/src/cmd/fix/go1rename_test.go b/src/cmd/fix/go1rename_test.go
new file mode 100644
index 000000000..90219ba71
--- /dev/null
+++ b/src/cmd/fix/go1rename_test.go
@@ -0,0 +1,195 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(go1renameTests, go1renameFix.f)
+}
+
+var go1renameTests = []testCase{
+ {
+ Name: "go1rename.0",
+ In: `package main
+
+import (
+ "crypto/aes"
+ "crypto/des"
+ "encoding/json"
+ "net/http"
+ "net/url"
+ "os"
+ "runtime"
+)
+
+var (
+ _ *aes.Cipher
+ _ *des.Cipher
+ _ *des.TripleDESCipher
+ _ = json.MarshalForHTML
+ _ = aes.New()
+ _ = url.Parse
+ _ = url.ParseWithReference
+ _ = url.ParseRequest
+ _ = os.Exec
+ _ = runtime.Cgocalls
+ _ = runtime.Goroutines
+ _ = http.ErrPersistEOF
+ _ = http.ErrPipeline
+ _ = http.ErrClosed
+ _ = http.NewSingleHostReverseProxy
+ _ = http.NewChunkedReader
+ _ = http.NewChunkedWriter
+ _ *http.ReverseProxy
+ _ *http.ClientConn
+ _ *http.ServerConn
+)
+`,
+ Out: `package main
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "encoding/json"
+ "net/http/httputil"
+ "net/url"
+ "runtime"
+ "syscall"
+)
+
+var (
+ _ cipher.Block
+ _ cipher.Block
+ _ cipher.Block
+ _ = json.Marshal
+ _ = aes.New()
+ _ = url.Parse
+ _ = url.Parse
+ _ = url.ParseRequestURI
+ _ = syscall.Exec
+ _ = runtime.NumCgoCall
+ _ = runtime.NumGoroutine
+ _ = httputil.ErrPersistEOF
+ _ = httputil.ErrPipeline
+ _ = httputil.ErrClosed
+ _ = httputil.NewSingleHostReverseProxy
+ _ = httputil.NewChunkedReader
+ _ = httputil.NewChunkedWriter
+ _ *httputil.ReverseProxy
+ _ *httputil.ClientConn
+ _ *httputil.ServerConn
+)
+`,
+ },
+ {
+ Name: "httputil.0",
+ In: `package main
+
+import "net/http"
+
+func f() {
+ http.DumpRequest(nil, false)
+ http.DumpRequestOut(nil, false)
+ http.DumpResponse(nil, false)
+ http.NewChunkedReader(nil)
+ http.NewChunkedWriter(nil)
+ http.NewClientConn(nil, nil)
+ http.NewProxyClientConn(nil, nil)
+ http.NewServerConn(nil, nil)
+ http.NewSingleHostReverseProxy(nil)
+}
+`,
+ Out: `package main
+
+import "net/http/httputil"
+
+func f() {
+ httputil.DumpRequest(nil, false)
+ httputil.DumpRequestOut(nil, false)
+ httputil.DumpResponse(nil, false)
+ httputil.NewChunkedReader(nil)
+ httputil.NewChunkedWriter(nil)
+ httputil.NewClientConn(nil, nil)
+ httputil.NewProxyClientConn(nil, nil)
+ httputil.NewServerConn(nil, nil)
+ httputil.NewSingleHostReverseProxy(nil)
+}
+`,
+ },
+ {
+ Name: "httputil.1",
+ In: `package main
+
+import "net/http"
+
+func f() {
+ http.DumpRequest(nil, false)
+ http.DumpRequestOut(nil, false)
+ http.DumpResponse(nil, false)
+ http.NewChunkedReader(nil)
+ http.NewChunkedWriter(nil)
+ http.NewClientConn(nil, nil)
+ http.NewProxyClientConn(nil, nil)
+ http.NewServerConn(nil, nil)
+ http.NewSingleHostReverseProxy(nil)
+}
+`,
+ Out: `package main
+
+import "net/http/httputil"
+
+func f() {
+ httputil.DumpRequest(nil, false)
+ httputil.DumpRequestOut(nil, false)
+ httputil.DumpResponse(nil, false)
+ httputil.NewChunkedReader(nil)
+ httputil.NewChunkedWriter(nil)
+ httputil.NewClientConn(nil, nil)
+ httputil.NewProxyClientConn(nil, nil)
+ httputil.NewServerConn(nil, nil)
+ httputil.NewSingleHostReverseProxy(nil)
+}
+`,
+ },
+ {
+ Name: "httputil.2",
+ In: `package main
+
+import "net/http"
+
+func f() {
+ http.DumpRequest(nil, false)
+ http.DumpRequestOut(nil, false)
+ http.DumpResponse(nil, false)
+ http.NewChunkedReader(nil)
+ http.NewChunkedWriter(nil)
+ http.NewClientConn(nil, nil)
+ http.NewProxyClientConn(nil, nil)
+ http.NewServerConn(nil, nil)
+ http.NewSingleHostReverseProxy(nil)
+ http.Get("")
+}
+`,
+ Out: `package main
+
+import (
+ "net/http"
+ "net/http/httputil"
+)
+
+func f() {
+ httputil.DumpRequest(nil, false)
+ httputil.DumpRequestOut(nil, false)
+ httputil.DumpResponse(nil, false)
+ httputil.NewChunkedReader(nil)
+ httputil.NewChunkedWriter(nil)
+ httputil.NewClientConn(nil, nil)
+ httputil.NewProxyClientConn(nil, nil)
+ httputil.NewServerConn(nil, nil)
+ httputil.NewSingleHostReverseProxy(nil)
+ http.Get("")
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/googlecode.go b/src/cmd/fix/googlecode.go
new file mode 100644
index 000000000..143781a74
--- /dev/null
+++ b/src/cmd/fix/googlecode.go
@@ -0,0 +1,41 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+ "regexp"
+)
+
+func init() {
+ register(googlecodeFix)
+}
+
+var googlecodeFix = fix{
+ "googlecode",
+ "2011-11-21",
+ googlecode,
+ `Rewrite Google Code imports from the deprecated form
+"foo.googlecode.com/vcs/path" to "code.google.com/p/foo/path".
+`,
+}
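+
+// For example (an illustrative case mirroring googlecode_test.go):
+//
+//	"foo.googlecode.com/hg/bar" -> "code.google.com/p/foo/bar"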
+
+var googlecodeRe = regexp.MustCompile(`^([a-z0-9\-]+)\.googlecode\.com/(svn|git|hg)(/[a-z0-9A-Z_.\-/]+)?$`)
+
+func googlecode(f *ast.File) bool {
+ fixed := false
+
+ for _, s := range f.Imports {
+ old := importPath(s)
+ if m := googlecodeRe.FindStringSubmatch(old); m != nil {
+ new := "code.google.com/p/" + m[1] + m[3]
+ if rewriteImport(f, old, new) {
+ fixed = true
+ }
+ }
+ }
+
+ return fixed
+}
diff --git a/src/cmd/fix/googlecode_test.go b/src/cmd/fix/googlecode_test.go
new file mode 100644
index 000000000..c62ee4f32
--- /dev/null
+++ b/src/cmd/fix/googlecode_test.go
@@ -0,0 +1,31 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(googlecodeTests, googlecode)
+}
+
+var googlecodeTests = []testCase{
+ {
+ Name: "googlecode.0",
+ In: `package main
+
+import (
+ "foo.googlecode.com/hg/bar"
+ "go-qux-23.googlecode.com/svn"
+ "zap.googlecode.com/git/some/path"
+)
+`,
+ Out: `package main
+
+import (
+ "code.google.com/p/foo/bar"
+ "code.google.com/p/go-qux-23"
+ "code.google.com/p/zap/some/path"
+)
+`,
+ },
+}
diff --git a/src/cmd/fix/hashsum.go b/src/cmd/fix/hashsum.go
new file mode 100644
index 000000000..0df6ad749
--- /dev/null
+++ b/src/cmd/fix/hashsum.go
@@ -0,0 +1,94 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(hashSumFix)
+}
+
+var hashSumFix = fix{
+ "hashsum",
+ "2011-11-30",
+ hashSumFn,
+ `Pass a nil argument to calls to hash.Sum
+
+This fix rewrites code so that it passes a nil argument to hash.Sum.
+The additional argument will allow callers to avoid an
+allocation in the future.
+
+http://codereview.appspot.com/5448065
+`,
+}
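+
+// For example (illustrative, mirroring hashsum_test.go):
+//
+//	h := sha256.New()
+//	digest := h.Sum()  ->  digest := h.Sum(nil)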
+
+// Type-checking configuration: tell the type-checker this basic
+// information about types, functions, and variables in external packages.
+var hashSumTypeConfig = &TypeConfig{
+ Var: map[string]string{
+ "crypto.MD4": "crypto.Hash",
+ "crypto.MD5": "crypto.Hash",
+ "crypto.SHA1": "crypto.Hash",
+ "crypto.SHA224": "crypto.Hash",
+ "crypto.SHA256": "crypto.Hash",
+ "crypto.SHA384": "crypto.Hash",
+ "crypto.SHA512": "crypto.Hash",
+ "crypto.MD5SHA1": "crypto.Hash",
+ "crypto.RIPEMD160": "crypto.Hash",
+ },
+
+ Func: map[string]string{
+ "adler32.New": "hash.Hash",
+ "crc32.New": "hash.Hash",
+ "crc32.NewIEEE": "hash.Hash",
+ "crc64.New": "hash.Hash",
+ "fnv.New32a": "hash.Hash",
+ "fnv.New32": "hash.Hash",
+ "fnv.New64a": "hash.Hash",
+ "fnv.New64": "hash.Hash",
+ "hmac.New": "hash.Hash",
+ "hmac.NewMD5": "hash.Hash",
+ "hmac.NewSHA1": "hash.Hash",
+ "hmac.NewSHA256": "hash.Hash",
+ "md4.New": "hash.Hash",
+ "md5.New": "hash.Hash",
+ "ripemd160.New": "hash.Hash",
+ "sha1.New224": "hash.Hash",
+ "sha1.New": "hash.Hash",
+ "sha256.New224": "hash.Hash",
+ "sha256.New": "hash.Hash",
+ "sha512.New384": "hash.Hash",
+ "sha512.New": "hash.Hash",
+ },
+
+ Type: map[string]*Type{
+ "crypto.Hash": {
+ Method: map[string]string{
+ "New": "func() hash.Hash",
+ },
+ },
+ },
+}
+
+func hashSumFn(f *ast.File) bool {
+ typeof, _ := typecheck(hashSumTypeConfig, f)
+
+ fixed := false
+
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ if ok && len(call.Args) == 0 {
+ sel, ok := call.Fun.(*ast.SelectorExpr)
+ if ok && sel.Sel.Name == "Sum" && typeof[sel.X] == "hash.Hash" {
+ call.Args = append(call.Args, ast.NewIdent("nil"))
+ fixed = true
+ }
+ }
+ })
+
+ return fixed
+}
diff --git a/src/cmd/fix/hashsum_test.go b/src/cmd/fix/hashsum_test.go
new file mode 100644
index 000000000..241af2020
--- /dev/null
+++ b/src/cmd/fix/hashsum_test.go
@@ -0,0 +1,99 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(hashSumTests, hashSumFn)
+}
+
+var hashSumTests = []testCase{
+ {
+ Name: "hashsum.0",
+ In: `package main
+
+import "crypto/sha256"
+
+func f() []byte {
+ h := sha256.New()
+ return h.Sum()
+}
+`,
+ Out: `package main
+
+import "crypto/sha256"
+
+func f() []byte {
+ h := sha256.New()
+ return h.Sum(nil)
+}
+`,
+ },
+
+ {
+ Name: "hashsum.1",
+ In: `package main
+
+func f(h hash.Hash) []byte {
+ return h.Sum()
+}
+`,
+ Out: `package main
+
+func f(h hash.Hash) []byte {
+ return h.Sum(nil)
+}
+`,
+ },
+
+ {
+ Name: "hashsum.2",
+ In: `package main
+
+import "crypto/sha256"
+
+func f() []byte {
+ h := sha256.New()
+ h.Write([]byte("foo"))
+ digest := h.Sum()
+}
+`,
+ Out: `package main
+
+import "crypto/sha256"
+
+func f() []byte {
+ h := sha256.New()
+ h.Write([]byte("foo"))
+ digest := h.Sum(nil)
+}
+`,
+ },
+
+ {
+ Name: "hashsum.3",
+ In: `package main
+
+import _ "crypto/sha256"
+import "crypto"
+
+func f() []byte {
+ hashType := crypto.SHA256
+ h := hashType.New()
+ digest := h.Sum()
+}
+`,
+ Out: `package main
+
+import _ "crypto/sha256"
+import "crypto"
+
+func f() []byte {
+ hashType := crypto.SHA256
+ h := hashType.New()
+ digest := h.Sum(nil)
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/hmacnew.go b/src/cmd/fix/hmacnew.go
new file mode 100644
index 000000000..c0c44ef3e
--- /dev/null
+++ b/src/cmd/fix/hmacnew.go
@@ -0,0 +1,61 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "go/ast"
+
+func init() {
+ register(hmacNewFix)
+}
+
+var hmacNewFix = fix{
+ "hmacnew",
+ "2012-01-19",
+ hmacnew,
+ `Deprecate hmac.NewMD5, hmac.NewSHA1 and hmac.NewSHA256.
+
+This fix rewrites code using hmac.NewMD5, hmac.NewSHA1 and hmac.NewSHA256 to
+use hmac.New:
+
+ hmac.NewMD5(key) -> hmac.New(md5.New, key)
+ hmac.NewSHA1(key) -> hmac.New(sha1.New, key)
+ hmac.NewSHA256(key) -> hmac.New(sha256.New, key)
+
+`,
+}
+
+func hmacnew(f *ast.File) (fixed bool) {
+ if !imports(f, "crypto/hmac") {
+ return
+ }
+
+ walk(f, func(n interface{}) {
+ ce, ok := n.(*ast.CallExpr)
+ if !ok {
+ return
+ }
+
+ var pkg string
+ switch {
+ case isPkgDot(ce.Fun, "hmac", "NewMD5"):
+ pkg = "md5"
+ case isPkgDot(ce.Fun, "hmac", "NewSHA1"):
+ pkg = "sha1"
+ case isPkgDot(ce.Fun, "hmac", "NewSHA256"):
+ pkg = "sha256"
+ default:
+ return
+ }
+
+ addImport(f, "crypto/"+pkg)
+
+ ce.Fun = ast.NewIdent("hmac.New")
+ ce.Args = append([]ast.Expr{ast.NewIdent(pkg + ".New")}, ce.Args...)
+
+ fixed = true
+ })
+
+ return
+}
diff --git a/src/cmd/fix/hmacnew_test.go b/src/cmd/fix/hmacnew_test.go
new file mode 100644
index 000000000..5aeee8573
--- /dev/null
+++ b/src/cmd/fix/hmacnew_test.go
@@ -0,0 +1,107 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(hmacNewTests, hmacnew)
+}
+
+var hmacNewTests = []testCase{
+ {
+ Name: "hmacnew.0",
+ In: `package main
+
+import "crypto/hmac"
+
+var f = hmac.NewSHA1([]byte("some key"))
+`,
+ Out: `package main
+
+import (
+ "crypto/hmac"
+ "crypto/sha1"
+)
+
+var f = hmac.New(sha1.New, []byte("some key"))
+`,
+ },
+ {
+ Name: "hmacnew.1",
+ In: `package main
+
+import "crypto/hmac"
+
+var key = make([]byte, 8)
+var f = hmac.NewSHA1(key)
+`,
+ Out: `package main
+
+import (
+ "crypto/hmac"
+ "crypto/sha1"
+)
+
+var key = make([]byte, 8)
+var f = hmac.New(sha1.New, key)
+`,
+ },
+ {
+ Name: "hmacnew.2",
+ In: `package main
+
+import "crypto/hmac"
+
+var f = hmac.NewMD5([]byte("some key"))
+`,
+ Out: `package main
+
+import (
+ "crypto/hmac"
+ "crypto/md5"
+)
+
+var f = hmac.New(md5.New, []byte("some key"))
+`,
+ },
+ {
+ Name: "hmacnew.3",
+ In: `package main
+
+import "crypto/hmac"
+
+var f = hmac.NewSHA256([]byte("some key"))
+`,
+ Out: `package main
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+)
+
+var f = hmac.New(sha256.New, []byte("some key"))
+`,
+ },
+ {
+ Name: "hmacnew.4",
+ In: `package main
+
+import (
+ "crypto/hmac"
+ "crypto/sha1"
+)
+
+var f = hmac.New(sha1.New, []byte("some key"))
+`,
+ Out: `package main
+
+import (
+ "crypto/hmac"
+ "crypto/sha1"
+)
+
+var f = hmac.New(sha1.New, []byte("some key"))
+`,
+ },
+}
diff --git a/src/cmd/fix/htmlerr.go b/src/cmd/fix/htmlerr.go
new file mode 100644
index 000000000..b5105c822
--- /dev/null
+++ b/src/cmd/fix/htmlerr.go
@@ -0,0 +1,47 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(htmlerrFix)
+}
+
+var htmlerrFix = fix{
+ "htmlerr",
+ "2011-11-04",
+ htmlerr,
+ `Rename html's Tokenizer.Error method to Err.
+
+http://codereview.appspot.com/5327064/
+`,
+}
+
+var htmlerrTypeConfig = &TypeConfig{
+ Func: map[string]string{
+ "html.NewTokenizer": "html.Tokenizer",
+ },
+}
+
+func htmlerr(f *ast.File) bool {
+ if !imports(f, "html") {
+ return false
+ }
+
+ typeof, _ := typecheck(htmlerrTypeConfig, f)
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ s, ok := n.(*ast.SelectorExpr)
+ if ok && typeof[s.X] == "html.Tokenizer" && s.Sel.Name == "Error" {
+ s.Sel.Name = "Err"
+ fixed = true
+ }
+ })
+ return fixed
+}
diff --git a/src/cmd/fix/htmlerr_test.go b/src/cmd/fix/htmlerr_test.go
new file mode 100644
index 000000000..043abc42a
--- /dev/null
+++ b/src/cmd/fix/htmlerr_test.go
@@ -0,0 +1,39 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(htmlerrTests, htmlerr)
+}
+
+var htmlerrTests = []testCase{
+ {
+ Name: "htmlerr.0",
+ In: `package main
+
+import (
+ "html"
+)
+
+func f() {
+ e := errors.New("")
+ t := html.NewTokenizer(r)
+ _, _ = e.Error(), t.Error()
+}
+`,
+ Out: `package main
+
+import (
+ "html"
+)
+
+func f() {
+ e := errors.New("")
+ t := html.NewTokenizer(r)
+ _, _ = e.Error(), t.Err()
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/httpfinalurl.go b/src/cmd/fix/httpfinalurl.go
new file mode 100644
index 000000000..49b9f1c51
--- /dev/null
+++ b/src/cmd/fix/httpfinalurl.go
@@ -0,0 +1,57 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(httpFinalURLFix)
+}
+
+var httpFinalURLFix = fix{
+ "httpfinalurl",
+ "2011-05-13",
+ httpfinalurl,
+ `Adapt http Get calls to not have a finalURL result parameter.
+
+http://codereview.appspot.com/4535056/
+`,
+}
+
+func httpfinalurl(f *ast.File) bool {
+ if !imports(f, "http") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ // Fix up calls to http.Get.
+ //
+ // If they have blank identifiers, remove them:
+ // resp, _, err := http.Get(url)
+ // -> resp, err := http.Get(url)
+ //
+ // But if they're using the finalURL parameter, warn:
+ // resp, finalURL, err := http.Get(url)
+ as, ok := n.(*ast.AssignStmt)
+ if !ok || len(as.Lhs) != 3 || len(as.Rhs) != 1 {
+ return
+ }
+
+ if !isCall(as.Rhs[0], "http", "Get") {
+ return
+ }
+
+ if isBlank(as.Lhs[1]) {
+ as.Lhs = []ast.Expr{as.Lhs[0], as.Lhs[2]}
+ fixed = true
+ } else {
+ warn(as.Pos(), "call to http.Get records final URL")
+ }
+ })
+ return fixed
+}
diff --git a/src/cmd/fix/httpfinalurl_test.go b/src/cmd/fix/httpfinalurl_test.go
new file mode 100644
index 000000000..9249f7e18
--- /dev/null
+++ b/src/cmd/fix/httpfinalurl_test.go
@@ -0,0 +1,37 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(httpfinalurlTests, httpfinalurl)
+}
+
+var httpfinalurlTests = []testCase{
+ {
+ Name: "finalurl.0",
+ In: `package main
+
+import (
+ "http"
+)
+
+func f() {
+ resp, _, err := http.Get("http://www.google.com/")
+ _, _ = resp, err
+}
+`,
+ Out: `package main
+
+import (
+ "http"
+)
+
+func f() {
+ resp, err := http.Get("http://www.google.com/")
+ _, _ = resp, err
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/httpfs.go b/src/cmd/fix/httpfs.go
new file mode 100644
index 000000000..d87b30f9d
--- /dev/null
+++ b/src/cmd/fix/httpfs.go
@@ -0,0 +1,70 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+ "go/token"
+)
+
+func init() {
+ register(httpFileSystemFix)
+}
+
+var httpFileSystemFix = fix{
+ "httpfs",
+ "2011-06-27",
+ httpfs,
+ `Adapt http FileServer to take a FileSystem.
+
+http://codereview.appspot.com/4629047 http FileSystem interface
+`,
+}
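+
+// For example (illustrative, mirroring httpfs_test.go):
+//
+//	http.FileServer("/var/www/foo", "/")    -> http.FileServer(http.Dir("/var/www/foo"))
+//	http.FileServer("/var/www/foo", "/bar") -> http.StripPrefix("/bar", http.FileServer(http.Dir("/var/www/foo")))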
+
+func httpfs(f *ast.File) bool {
+ if !imports(f, "http") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ if !ok || !isPkgDot(call.Fun, "http", "FileServer") {
+ return
+ }
+ if len(call.Args) != 2 {
+ return
+ }
+ dir, prefix := call.Args[0], call.Args[1]
+ call.Args = []ast.Expr{&ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: ast.NewIdent("http"),
+ Sel: ast.NewIdent("Dir"),
+ },
+ Args: []ast.Expr{dir},
+ }}
+ wrapInStripHandler := true
+ if prefixLit, ok := prefix.(*ast.BasicLit); ok {
+ if prefixLit.Kind == token.STRING && (prefixLit.Value == `"/"` || prefixLit.Value == `""`) {
+ wrapInStripHandler = false
+ }
+ }
+ if wrapInStripHandler {
+ call.Fun.(*ast.SelectorExpr).Sel = ast.NewIdent("StripPrefix")
+ call.Args = []ast.Expr{
+ prefix,
+ &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: ast.NewIdent("http"),
+ Sel: ast.NewIdent("FileServer"),
+ },
+ Args: call.Args,
+ },
+ }
+ }
+ fixed = true
+ })
+ return fixed
+}
diff --git a/src/cmd/fix/httpfs_test.go b/src/cmd/fix/httpfs_test.go
new file mode 100644
index 000000000..dd8ef2cfd
--- /dev/null
+++ b/src/cmd/fix/httpfs_test.go
@@ -0,0 +1,47 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(httpFileSystemTests, httpfs)
+}
+
+var httpFileSystemTests = []testCase{
+ {
+ Name: "httpfs.0",
+ In: `package httpfs
+
+import (
+ "http"
+)
+
+func f() {
+ _ = http.FileServer("/var/www/foo", "/")
+ _ = http.FileServer("/var/www/foo", "")
+ _ = http.FileServer("/var/www/foo/bar", "/bar")
+ s := "/foo"
+ _ = http.FileServer(s, "/")
+ prefix := "/p"
+ _ = http.FileServer(s, prefix)
+}
+`,
+ Out: `package httpfs
+
+import (
+ "http"
+)
+
+func f() {
+ _ = http.FileServer(http.Dir("/var/www/foo"))
+ _ = http.FileServer(http.Dir("/var/www/foo"))
+ _ = http.StripPrefix("/bar", http.FileServer(http.Dir("/var/www/foo/bar")))
+ s := "/foo"
+ _ = http.FileServer(http.Dir(s))
+ prefix := "/p"
+ _ = http.StripPrefix(prefix, http.FileServer(http.Dir(s)))
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/httpheaders.go b/src/cmd/fix/httpheaders.go
new file mode 100644
index 000000000..15c21ac86
--- /dev/null
+++ b/src/cmd/fix/httpheaders.go
@@ -0,0 +1,67 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(httpHeadersFix)
+}
+
+var httpHeadersFix = fix{
+ "httpheaders",
+ "2011-06-16",
+ httpheaders,
+ `Rename http Referer, UserAgent, Cookie, SetCookie, which are now methods.
+
+http://codereview.appspot.com/4620049/
+`,
+}
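+
+// For example (illustrative, mirroring httpheaders_test.go):
+//
+//	req.Referer   -> req.Referer()
+//	req.UserAgent -> req.UserAgent()
+//	res.Cookie    -> res.Cookies()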
+
+func httpheaders(f *ast.File) bool {
+ if !imports(f, "http") {
+ return false
+ }
+
+ called := make(map[ast.Node]bool)
+ walk(f, func(ni interface{}) {
+ switch n := ni.(type) {
+ case *ast.CallExpr:
+ called[n.Fun] = true
+ }
+ })
+
+ fixed := false
+ typeof, _ := typecheck(headerTypeConfig, f)
+ walk(f, func(ni interface{}) {
+ switch n := ni.(type) {
+ case *ast.SelectorExpr:
+ if called[n] {
+ break
+ }
+ if t := typeof[n.X]; t != "*http.Request" && t != "*http.Response" {
+ break
+ }
+ switch n.Sel.Name {
+ case "Referer", "UserAgent":
+ n.Sel.Name += "()"
+ fixed = true
+ case "Cookie":
+ n.Sel.Name = "Cookies()"
+ fixed = true
+ }
+ }
+ })
+ return fixed
+}
+
+var headerTypeConfig = &TypeConfig{
+ Type: map[string]*Type{
+ "*http.Request": {},
+ "*http.Response": {},
+ },
+}
diff --git a/src/cmd/fix/httpheaders_test.go b/src/cmd/fix/httpheaders_test.go
new file mode 100644
index 000000000..37506b82d
--- /dev/null
+++ b/src/cmd/fix/httpheaders_test.go
@@ -0,0 +1,73 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(httpHeadersTests, httpheaders)
+}
+
+var httpHeadersTests = []testCase{
+ {
+ Name: "httpheaders.0",
+ In: `package headertest
+
+import (
+ "http"
+)
+
+type Other struct {
+ Referer string
+ UserAgent string
+ Cookie []*http.Cookie
+}
+
+func f(req *http.Request, res *http.Response, other *Other) {
+ _ = req.Referer
+ _ = req.UserAgent
+ _ = req.Cookie
+
+ _ = res.Cookie
+
+ _ = other.Referer
+ _ = other.UserAgent
+ _ = other.Cookie
+
+ _ = req.Referer()
+ _ = req.UserAgent()
+ _ = req.Cookies()
+ _ = res.Cookies()
+}
+`,
+ Out: `package headertest
+
+import (
+ "http"
+)
+
+type Other struct {
+ Referer string
+ UserAgent string
+ Cookie []*http.Cookie
+}
+
+func f(req *http.Request, res *http.Response, other *Other) {
+ _ = req.Referer()
+ _ = req.UserAgent()
+ _ = req.Cookies()
+
+ _ = res.Cookies()
+
+ _ = other.Referer
+ _ = other.UserAgent
+ _ = other.Cookie
+
+ _ = req.Referer()
+ _ = req.UserAgent()
+ _ = req.Cookies()
+ _ = res.Cookies()
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/httpserver.go b/src/cmd/fix/httpserver.go
new file mode 100644
index 000000000..7aa651786
--- /dev/null
+++ b/src/cmd/fix/httpserver.go
@@ -0,0 +1,141 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+ "go/token"
+)
+
+func init() {
+ register(httpserverFix)
+}
+
+var httpserverFix = fix{
+ "httpserver",
+ "2011-03-15",
+ httpserver,
+ `Adapt http server methods and functions to changes
+made to the http ResponseWriter interface.
+
+http://codereview.appspot.com/4245064 Hijacker
+http://codereview.appspot.com/4239076 Header
+http://codereview.appspot.com/4239077 Flusher
+http://codereview.appspot.com/4248075 RemoteAddr, UsingTLS
+`,
+}
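+
+// For example (illustrative, mirroring httpserver_test.go, where w is the
+// handler's http.ResponseWriter and req is its *http.Request):
+//
+//	w.SetHeader("k", "v") -> w.Header().Set("k", "v")
+//	w.SetHeader("k", "")  -> w.Header().Del("k")
+//	w.Flush()             -> w.(http.Flusher).Flush()
+//	w.UsingTLS()          -> req.TLS != nil
+//	w.RemoteAddr()        -> req.RemoteAddr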
+
+func httpserver(f *ast.File) bool {
+ if !imports(f, "http") {
+ return false
+ }
+
+ fixed := false
+ for _, decl := range f.Decls {
+ fn, ok := decl.(*ast.FuncDecl)
+ if !ok {
+ continue
+ }
+ w, req, ok := isServeHTTP(fn)
+ if !ok {
+ continue
+ }
+ walk(fn.Body, func(n interface{}) {
+ // Want to replace expression sometimes,
+ // so record pointer to it for updating below.
+ ptr, ok := n.(*ast.Expr)
+ if ok {
+ n = *ptr
+ }
+
+ // Look for w.UsingTLS() and w.RemoteAddr().
+ call, ok := n.(*ast.CallExpr)
+ if !ok || (len(call.Args) != 0 && len(call.Args) != 2) {
+ return
+ }
+ sel, ok := call.Fun.(*ast.SelectorExpr)
+ if !ok {
+ return
+ }
+ if !refersTo(sel.X, w) {
+ return
+ }
+ switch sel.Sel.String() {
+ case "Hijack":
+ // replace w with w.(http.Hijacker)
+ sel.X = &ast.TypeAssertExpr{
+ X: sel.X,
+ Type: ast.NewIdent("http.Hijacker"),
+ }
+ fixed = true
+ case "Flush":
+ // replace w with w.(http.Flusher)
+ sel.X = &ast.TypeAssertExpr{
+ X: sel.X,
+ Type: ast.NewIdent("http.Flusher"),
+ }
+ fixed = true
+ case "UsingTLS":
+ if ptr == nil {
+ // can only replace expression if we have pointer to it
+ break
+ }
+ // replace with req.TLS != nil
+ *ptr = &ast.BinaryExpr{
+ X: &ast.SelectorExpr{
+ X: ast.NewIdent(req.String()),
+ Sel: ast.NewIdent("TLS"),
+ },
+ Op: token.NEQ,
+ Y: ast.NewIdent("nil"),
+ }
+ fixed = true
+ case "RemoteAddr":
+ if ptr == nil {
+ // can only replace expression if we have pointer to it
+ break
+ }
+ // replace with req.RemoteAddr
+ *ptr = &ast.SelectorExpr{
+ X: ast.NewIdent(req.String()),
+ Sel: ast.NewIdent("RemoteAddr"),
+ }
+ fixed = true
+ case "SetHeader":
+ // replace w.SetHeader with w.Header().Set
+ // or w.Header().Del if second argument is ""
+ sel.X = &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: ast.NewIdent(w.String()),
+ Sel: ast.NewIdent("Header"),
+ },
+ }
+ sel.Sel = ast.NewIdent("Set")
+ if len(call.Args) == 2 && isEmptyString(call.Args[1]) {
+ sel.Sel = ast.NewIdent("Del")
+ call.Args = call.Args[:1]
+ }
+ fixed = true
+ }
+ })
+ }
+ return fixed
+}
+
+func isServeHTTP(fn *ast.FuncDecl) (w, req *ast.Ident, ok bool) {
+ for _, field := range fn.Type.Params.List {
+ if isPkgDot(field.Type, "http", "ResponseWriter") {
+ w = field.Names[0]
+ continue
+ }
+ if isPtrPkgDot(field.Type, "http", "Request") {
+ req = field.Names[0]
+ continue
+ }
+ }
+
+ ok = w != nil && req != nil
+ return
+}
diff --git a/src/cmd/fix/httpserver_test.go b/src/cmd/fix/httpserver_test.go
new file mode 100644
index 000000000..b6ddff27e
--- /dev/null
+++ b/src/cmd/fix/httpserver_test.go
@@ -0,0 +1,53 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(httpserverTests, httpserver)
+}
+
+var httpserverTests = []testCase{
+ {
+ Name: "httpserver.0",
+ In: `package main
+
+import "http"
+
+func f(xyz http.ResponseWriter, abc *http.Request, b string) {
+ xyz.SetHeader("foo", "bar")
+ xyz.SetHeader("baz", "")
+ xyz.Hijack()
+ xyz.Flush()
+ go xyz.Hijack()
+ defer xyz.Flush()
+ _ = xyz.UsingTLS()
+ _ = true == xyz.UsingTLS()
+ _ = xyz.RemoteAddr()
+ _ = xyz.RemoteAddr() == "hello"
+ if xyz.UsingTLS() {
+ }
+}
+`,
+ Out: `package main
+
+import "http"
+
+func f(xyz http.ResponseWriter, abc *http.Request, b string) {
+ xyz.Header().Set("foo", "bar")
+ xyz.Header().Del("baz")
+ xyz.(http.Hijacker).Hijack()
+ xyz.(http.Flusher).Flush()
+ go xyz.(http.Hijacker).Hijack()
+ defer xyz.(http.Flusher).Flush()
+ _ = abc.TLS != nil
+ _ = true == (abc.TLS != nil)
+ _ = abc.RemoteAddr
+ _ = abc.RemoteAddr == "hello"
+ if abc.TLS != nil {
+ }
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/imagecolor.go b/src/cmd/fix/imagecolor.go
new file mode 100644
index 000000000..1aac40a6f
--- /dev/null
+++ b/src/cmd/fix/imagecolor.go
@@ -0,0 +1,85 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(imagecolorFix)
+}
+
+var imagecolorFix = fix{
+ "imagecolor",
+ "2011-10-04",
+ imagecolor,
+ `Adapt code to types moved from image to color.
+
+http://codereview.appspot.com/5132048
+`,
+}
+
+var colorRenames = []struct{ in, out string }{
+ {"Color", "Color"},
+ {"ColorModel", "Model"},
+ {"ColorModelFunc", "ModelFunc"},
+ {"PalettedColorModel", "Palette"},
+
+ {"RGBAColor", "RGBA"},
+ {"RGBA64Color", "RGBA64"},
+ {"NRGBAColor", "NRGBA"},
+ {"NRGBA64Color", "NRGBA64"},
+ {"AlphaColor", "Alpha"},
+ {"Alpha16Color", "Alpha16"},
+ {"GrayColor", "Gray"},
+ {"Gray16Color", "Gray16"},
+
+ {"RGBAColorModel", "RGBAModel"},
+ {"RGBA64ColorModel", "RGBA64Model"},
+ {"NRGBAColorModel", "NRGBAModel"},
+ {"NRGBA64ColorModel", "NRGBA64Model"},
+ {"AlphaColorModel", "AlphaModel"},
+ {"Alpha16ColorModel", "Alpha16Model"},
+ {"GrayColorModel", "GrayModel"},
+ {"Gray16ColorModel", "Gray16Model"},
+}
+
+func imagecolor(f *ast.File) (fixed bool) {
+ if !imports(f, "image") {
+ return
+ }
+
+ walk(f, func(n interface{}) {
+ s, ok := n.(*ast.SelectorExpr)
+
+ if !ok || !isTopName(s.X, "image") {
+ return
+ }
+
+ switch sel := s.Sel.String(); {
+ case sel == "ColorImage":
+ s.Sel = &ast.Ident{Name: "Uniform"}
+ fixed = true
+ case sel == "NewColorImage":
+ s.Sel = &ast.Ident{Name: "NewUniform"}
+ fixed = true
+ default:
+ for _, rename := range colorRenames {
+ if sel == rename.in {
+ addImport(f, "image/color")
+ s.X.(*ast.Ident).Name = "color"
+ s.Sel.Name = rename.out
+ fixed = true
+ }
+ }
+ }
+ })
+
+ if fixed && !usesImport(f, "image") {
+ deleteImport(f, "image")
+ }
+ return
+}
diff --git a/src/cmd/fix/imagecolor_test.go b/src/cmd/fix/imagecolor_test.go
new file mode 100644
index 000000000..c62365481
--- /dev/null
+++ b/src/cmd/fix/imagecolor_test.go
@@ -0,0 +1,126 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(colorTests, imagecolor)
+}
+
+var colorTests = []testCase{
+ {
+ Name: "color.0",
+ In: `package main
+
+import (
+ "image"
+)
+
+var (
+ _ image.Image
+ _ image.RGBA
+ _ image.Black
+ _ image.Color
+ _ image.ColorModel
+ _ image.ColorModelFunc
+ _ image.PalettedColorModel
+ _ image.RGBAColor
+ _ image.RGBA64Color
+ _ image.NRGBAColor
+ _ image.NRGBA64Color
+ _ image.AlphaColor
+ _ image.Alpha16Color
+ _ image.GrayColor
+ _ image.Gray16Color
+)
+
+func f() {
+ _ = image.RGBAColorModel
+ _ = image.RGBA64ColorModel
+ _ = image.NRGBAColorModel
+ _ = image.NRGBA64ColorModel
+ _ = image.AlphaColorModel
+ _ = image.Alpha16ColorModel
+ _ = image.GrayColorModel
+ _ = image.Gray16ColorModel
+}
+`,
+ Out: `package main
+
+import (
+ "image"
+ "image/color"
+)
+
+var (
+ _ image.Image
+ _ image.RGBA
+ _ image.Black
+ _ color.Color
+ _ color.Model
+ _ color.ModelFunc
+ _ color.Palette
+ _ color.RGBA
+ _ color.RGBA64
+ _ color.NRGBA
+ _ color.NRGBA64
+ _ color.Alpha
+ _ color.Alpha16
+ _ color.Gray
+ _ color.Gray16
+)
+
+func f() {
+ _ = color.RGBAModel
+ _ = color.RGBA64Model
+ _ = color.NRGBAModel
+ _ = color.NRGBA64Model
+ _ = color.AlphaModel
+ _ = color.Alpha16Model
+ _ = color.GrayModel
+ _ = color.Gray16Model
+}
+`,
+ },
+ {
+ Name: "color.1",
+ In: `package main
+
+import (
+ "fmt"
+ "image"
+)
+
+func f() {
+ fmt.Println(image.RGBAColor{1, 2, 3, 4}.RGBA())
+}
+`,
+ Out: `package main
+
+import (
+ "fmt"
+ "image/color"
+)
+
+func f() {
+ fmt.Println(color.RGBA{1, 2, 3, 4}.RGBA())
+}
+`,
+ },
+ {
+ Name: "color.2",
+ In: `package main
+
+import "image"
+
+var c *image.ColorImage = image.NewColorImage(nil)
+`,
+ Out: `package main
+
+import "image"
+
+var c *image.Uniform = image.NewUniform(nil)
+`,
+ },
+}
diff --git a/src/cmd/fix/imagenew.go b/src/cmd/fix/imagenew.go
new file mode 100644
index 000000000..b4e36d4f0
--- /dev/null
+++ b/src/cmd/fix/imagenew.go
@@ -0,0 +1,83 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(imagenewFix)
+}
+
+var imagenewFix = fix{
+ "imagenew",
+ "2011-09-14",
+ imagenew,
+ `Adapt image.NewXxx calls to pass an image.Rectangle instead of (w, h int).
+
+http://codereview.appspot.com/4964073
+`,
+}
+
+var imagenewFuncs = map[string]bool{
+ "NewRGBA": true,
+ "NewRGBA64": true,
+ "NewNRGBA": true,
+ "NewNRGBA64": true,
+ "NewAlpha": true,
+ "NewAlpha16": true,
+ "NewGray": true,
+ "NewGray16": true,
+}
+
+func imagenew(f *ast.File) bool {
+ if !imports(f, "image") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ if !ok {
+ return
+ }
+ isNewFunc := false
+ for newFunc := range imagenewFuncs {
+ if len(call.Args) == 2 && isPkgDot(call.Fun, "image", newFunc) {
+ isNewFunc = true
+ break
+ }
+ }
+ if len(call.Args) == 3 && isPkgDot(call.Fun, "image", "NewPaletted") {
+ isNewFunc = true
+ }
+ if !isNewFunc {
+ return
+ }
+ // Replace image.NewXxx(w, h) with image.NewXxx(image.Rect(0, 0, w, h)).
+ rectArgs := []ast.Expr{
+ &ast.BasicLit{Value: "0"},
+ &ast.BasicLit{Value: "0"},
+ }
+ rectArgs = append(rectArgs, call.Args[:2]...)
+ rect := []ast.Expr{
+ &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: &ast.Ident{
+ Name: "image",
+ },
+ Sel: &ast.Ident{
+ Name: "Rect",
+ },
+ },
+ Args: rectArgs,
+ },
+ }
+ call.Args = append(rect, call.Args[2:]...)
+ fixed = true
+ })
+ return fixed
+}
diff --git a/src/cmd/fix/imagenew_test.go b/src/cmd/fix/imagenew_test.go
new file mode 100644
index 000000000..30abed23c
--- /dev/null
+++ b/src/cmd/fix/imagenew_test.go
@@ -0,0 +1,51 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(imagenewTests, imagenew)
+}
+
+var imagenewTests = []testCase{
+ {
+ Name: "imagenew.0",
+ In: `package main
+
+import (
+ "image"
+)
+
+func f() {
+ image.NewRGBA(1, 2)
+ image.NewRGBA64(1, 2)
+ image.NewNRGBA(1, 2)
+ image.NewNRGBA64(1, 2)
+ image.NewAlpha(1, 2)
+ image.NewAlpha16(1, 2)
+ image.NewGray(1, 2)
+ image.NewGray16(1, 2)
+ image.NewPaletted(1, 2, nil)
+}
+`,
+ Out: `package main
+
+import (
+ "image"
+)
+
+func f() {
+ image.NewRGBA(image.Rect(0, 0, 1, 2))
+ image.NewRGBA64(image.Rect(0, 0, 1, 2))
+ image.NewNRGBA(image.Rect(0, 0, 1, 2))
+ image.NewNRGBA64(image.Rect(0, 0, 1, 2))
+ image.NewAlpha(image.Rect(0, 0, 1, 2))
+ image.NewAlpha16(image.Rect(0, 0, 1, 2))
+ image.NewGray(image.Rect(0, 0, 1, 2))
+ image.NewGray16(image.Rect(0, 0, 1, 2))
+ image.NewPaletted(image.Rect(0, 0, 1, 2), nil)
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/imageycbcr.go b/src/cmd/fix/imageycbcr.go
new file mode 100644
index 000000000..41b96d18d
--- /dev/null
+++ b/src/cmd/fix/imageycbcr.go
@@ -0,0 +1,64 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(imageycbcrFix)
+}
+
+var imageycbcrFix = fix{
+ "imageycbcr",
+ "2011-12-20",
+ imageycbcr,
+ `Adapt code to types moved from image/ycbcr to image and image/color.
+
+http://codereview.appspot.com/5493084
+`,
+}
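+
+// For example (illustrative, mirroring imageycbcr_test.go):
+//
+//	ycbcr.YCbCrColor        -> color.YCbCr
+//	ycbcr.YCbCrColorModel   -> color.YCbCrModel
+//	ycbcr.SubsampleRatio444 -> image.YCbCrSubsampleRatio444
+//	ycbcr.YCbCr             -> image.YCbCr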
+
+func imageycbcr(f *ast.File) (fixed bool) {
+ if !imports(f, "image/ycbcr") {
+ return
+ }
+
+ walk(f, func(n interface{}) {
+ s, ok := n.(*ast.SelectorExpr)
+
+ if !ok || !isTopName(s.X, "ycbcr") {
+ return
+ }
+
+ switch s.Sel.String() {
+ case "RGBToYCbCr", "YCbCrToRGB":
+ addImport(f, "image/color")
+ s.X.(*ast.Ident).Name = "color"
+ case "YCbCrColor":
+ addImport(f, "image/color")
+ s.X.(*ast.Ident).Name = "color"
+ s.Sel.Name = "YCbCr"
+ case "YCbCrColorModel":
+ addImport(f, "image/color")
+ s.X.(*ast.Ident).Name = "color"
+ s.Sel.Name = "YCbCrModel"
+ case "SubsampleRatio", "SubsampleRatio444", "SubsampleRatio422", "SubsampleRatio420":
+ addImport(f, "image")
+ s.X.(*ast.Ident).Name = "image"
+ s.Sel.Name = "YCbCr" + s.Sel.Name
+ case "YCbCr":
+ addImport(f, "image")
+ s.X.(*ast.Ident).Name = "image"
+ default:
+ return
+ }
+ fixed = true
+ })
+
+ deleteImport(f, "image/ycbcr")
+ return
+}
diff --git a/src/cmd/fix/imageycbcr_test.go b/src/cmd/fix/imageycbcr_test.go
new file mode 100644
index 000000000..23b599dcd
--- /dev/null
+++ b/src/cmd/fix/imageycbcr_test.go
@@ -0,0 +1,54 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(ycbcrTests, imageycbcr)
+}
+
+var ycbcrTests = []testCase{
+ {
+ Name: "ycbcr.0",
+ In: `package main
+
+import (
+ "image/ycbcr"
+)
+
+func f() {
+ _ = ycbcr.RGBToYCbCr
+ _ = ycbcr.YCbCrToRGB
+ _ = ycbcr.YCbCrColorModel
+ var _ ycbcr.YCbCrColor
+ var _ ycbcr.YCbCr
+ var (
+ _ ycbcr.SubsampleRatio = ycbcr.SubsampleRatio444
+ _ ycbcr.SubsampleRatio = ycbcr.SubsampleRatio422
+ _ ycbcr.SubsampleRatio = ycbcr.SubsampleRatio420
+ )
+}
+`,
+ Out: `package main
+
+import (
+ "image"
+ "image/color"
+)
+
+func f() {
+ _ = color.RGBToYCbCr
+ _ = color.YCbCrToRGB
+ _ = color.YCbCrModel
+ var _ color.YCbCr
+ var _ image.YCbCr
+ var (
+ _ image.YCbCrSubsampleRatio = image.YCbCrSubsampleRatio444
+ _ image.YCbCrSubsampleRatio = image.YCbCrSubsampleRatio422
+ _ image.YCbCrSubsampleRatio = image.YCbCrSubsampleRatio420
+ )
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/import_test.go b/src/cmd/fix/import_test.go
new file mode 100644
index 000000000..730119205
--- /dev/null
+++ b/src/cmd/fix/import_test.go
@@ -0,0 +1,458 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "go/ast"
+
+func init() {
+ addTestCases(importTests, nil)
+}
+
+var importTests = []testCase{
+ {
+ Name: "import.0",
+ Fn: addImportFn("os"),
+ In: `package main
+
+import (
+ "os"
+)
+`,
+ Out: `package main
+
+import (
+ "os"
+)
+`,
+ },
+ {
+ Name: "import.1",
+ Fn: addImportFn("os"),
+ In: `package main
+`,
+ Out: `package main
+
+import "os"
+`,
+ },
+ {
+ Name: "import.2",
+ Fn: addImportFn("os"),
+ In: `package main
+
+// Comment
+import "C"
+`,
+ Out: `package main
+
+// Comment
+import "C"
+import "os"
+`,
+ },
+ {
+ Name: "import.3",
+ Fn: addImportFn("os"),
+ In: `package main
+
+// Comment
+import "C"
+
+import (
+ "io"
+ "utf8"
+)
+`,
+ Out: `package main
+
+// Comment
+import "C"
+
+import (
+ "io"
+ "os"
+ "utf8"
+)
+`,
+ },
+ {
+ Name: "import.4",
+ Fn: deleteImportFn("os"),
+ In: `package main
+
+import (
+ "os"
+)
+`,
+ Out: `package main
+`,
+ },
+ {
+ Name: "import.5",
+ Fn: deleteImportFn("os"),
+ In: `package main
+
+// Comment
+import "C"
+import "os"
+`,
+ Out: `package main
+
+// Comment
+import "C"
+`,
+ },
+ {
+ Name: "import.6",
+ Fn: deleteImportFn("os"),
+ In: `package main
+
+// Comment
+import "C"
+
+import (
+ "io"
+ "os"
+ "utf8"
+)
+`,
+ Out: `package main
+
+// Comment
+import "C"
+
+import (
+ "io"
+ "utf8"
+)
+`,
+ },
+ {
+ Name: "import.7",
+ Fn: deleteImportFn("io"),
+ In: `package main
+
+import (
+ "io" // a
+ "os" // b
+ "utf8" // c
+)
+`,
+ Out: `package main
+
+import (
+ // a
+ "os" // b
+ "utf8" // c
+)
+`,
+ },
+ {
+ Name: "import.8",
+ Fn: deleteImportFn("os"),
+ In: `package main
+
+import (
+ "io" // a
+ "os" // b
+ "utf8" // c
+)
+`,
+ Out: `package main
+
+import (
+ "io" // a
+ // b
+ "utf8" // c
+)
+`,
+ },
+ {
+ Name: "import.9",
+ Fn: deleteImportFn("utf8"),
+ In: `package main
+
+import (
+ "io" // a
+ "os" // b
+ "utf8" // c
+)
+`,
+ Out: `package main
+
+import (
+ "io" // a
+ "os" // b
+ // c
+)
+`,
+ },
+ {
+ Name: "import.10",
+ Fn: deleteImportFn("io"),
+ In: `package main
+
+import (
+ "io"
+ "os"
+ "utf8"
+)
+`,
+ Out: `package main
+
+import (
+ "os"
+ "utf8"
+)
+`,
+ },
+ {
+ Name: "import.11",
+ Fn: deleteImportFn("os"),
+ In: `package main
+
+import (
+ "io"
+ "os"
+ "utf8"
+)
+`,
+ Out: `package main
+
+import (
+ "io"
+ "utf8"
+)
+`,
+ },
+ {
+ Name: "import.12",
+ Fn: deleteImportFn("utf8"),
+ In: `package main
+
+import (
+ "io"
+ "os"
+ "utf8"
+)
+`,
+ Out: `package main
+
+import (
+ "io"
+ "os"
+)
+`,
+ },
+ {
+ Name: "import.13",
+ Fn: rewriteImportFn("utf8", "encoding/utf8"),
+ In: `package main
+
+import (
+ "io"
+ "os"
+ "utf8" // thanks ken
+)
+`,
+ Out: `package main
+
+import (
+ "encoding/utf8" // thanks ken
+ "io"
+ "os"
+)
+`,
+ },
+ {
+ Name: "import.14",
+ Fn: rewriteImportFn("asn1", "encoding/asn1"),
+ In: `package main
+
+import (
+ "asn1"
+ "crypto"
+ "crypto/rsa"
+ _ "crypto/sha1"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "time"
+)
+
+var x = 1
+`,
+ Out: `package main
+
+import (
+ "crypto"
+ "crypto/rsa"
+ _ "crypto/sha1"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/asn1"
+ "time"
+)
+
+var x = 1
+`,
+ },
+ {
+ Name: "import.15",
+ Fn: rewriteImportFn("url", "net/url"),
+ In: `package main
+
+import (
+ "bufio"
+ "net"
+ "path"
+ "url"
+)
+
+var x = 1 // comment on x, not on url
+`,
+ Out: `package main
+
+import (
+ "bufio"
+ "net"
+ "net/url"
+ "path"
+)
+
+var x = 1 // comment on x, not on url
+`,
+ },
+ {
+ Name: "import.16",
+ Fn: rewriteImportFn("http", "net/http", "template", "text/template"),
+ In: `package main
+
+import (
+ "flag"
+ "http"
+ "log"
+ "template"
+)
+
+var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
+`,
+ Out: `package main
+
+import (
+ "flag"
+ "log"
+ "net/http"
+ "text/template"
+)
+
+var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
+`,
+ },
+ {
+ Name: "import.17",
+ Fn: addImportFn("x/y/z", "x/a/c"),
+ In: `package main
+
+// Comment
+import "C"
+
+import (
+ "a"
+ "b"
+
+ "x/w"
+
+ "d/f"
+)
+`,
+ Out: `package main
+
+// Comment
+import "C"
+
+import (
+ "a"
+ "b"
+
+ "x/a/c"
+ "x/w"
+ "x/y/z"
+
+ "d/f"
+)
+`,
+ },
+ {
+ Name: "import.18",
+ Fn: addDelImportFn("e", "o"),
+ In: `package main
+
+import (
+ "f"
+ "o"
+ "z"
+)
+`,
+ Out: `package main
+
+import (
+ "e"
+ "f"
+ "z"
+)
+`,
+ },
+}
+
+func addImportFn(path ...string) func(*ast.File) bool {
+ return func(f *ast.File) bool {
+ fixed := false
+ for _, p := range path {
+ if !imports(f, p) {
+ addImport(f, p)
+ fixed = true
+ }
+ }
+ return fixed
+ }
+}
+
+func deleteImportFn(path string) func(*ast.File) bool {
+ return func(f *ast.File) bool {
+ if imports(f, path) {
+ deleteImport(f, path)
+ return true
+ }
+ return false
+ }
+}
+
+func addDelImportFn(p1 string, p2 string) func(*ast.File) bool {
+ return func(f *ast.File) bool {
+ fixed := false
+ if !imports(f, p1) {
+ addImport(f, p1)
+ fixed = true
+ }
+ if imports(f, p2) {
+ deleteImport(f, p2)
+ fixed = true
+ }
+ return fixed
+ }
+}
+
+func rewriteImportFn(oldnew ...string) func(*ast.File) bool {
+ return func(f *ast.File) bool {
+ fixed := false
+ for i := 0; i < len(oldnew); i += 2 {
+ if imports(f, oldnew[i]) {
+ rewriteImport(f, oldnew[i], oldnew[i+1])
+ fixed = true
+ }
+ }
+ return fixed
+ }
+}
diff --git a/src/cmd/fix/iocopyn.go b/src/cmd/fix/iocopyn.go
new file mode 100644
index 000000000..720f3c689
--- /dev/null
+++ b/src/cmd/fix/iocopyn.go
@@ -0,0 +1,41 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(ioCopyNFix)
+}
+
+var ioCopyNFix = fix{
+ "iocopyn",
+ "2011-09-30",
+ ioCopyN,
+ `Rename io.Copyn to io.CopyN.
+
+http://codereview.appspot.com/5157045
+`,
+}
+
+func ioCopyN(f *ast.File) bool {
+ if !imports(f, "io") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ if expr, ok := n.(ast.Expr); ok {
+ if isPkgDot(expr, "io", "Copyn") {
+ expr.(*ast.SelectorExpr).Sel.Name = "CopyN"
+ fixed = true
+ return
+ }
+ }
+ })
+ return fixed
+}
diff --git a/src/cmd/fix/iocopyn_test.go b/src/cmd/fix/iocopyn_test.go
new file mode 100644
index 000000000..f86fad763
--- /dev/null
+++ b/src/cmd/fix/iocopyn_test.go
@@ -0,0 +1,37 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(ioCopyNTests, ioCopyN)
+}
+
+var ioCopyNTests = []testCase{
+ {
+ Name: "io.CopyN.0",
+ In: `package main
+
+import (
+ "io"
+)
+
+func f() {
+ io.Copyn(dst, src)
+ foo.Copyn(dst, src)
+}
+`,
+ Out: `package main
+
+import (
+ "io"
+)
+
+func f() {
+ io.CopyN(dst, src)
+ foo.Copyn(dst, src)
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/main.go b/src/cmd/fix/main.go
new file mode 100644
index 000000000..b151408d7
--- /dev/null
+++ b/src/cmd/fix/main.go
@@ -0,0 +1,271 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/printer"
+ "go/scanner"
+ "go/token"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "sort"
+ "strings"
+)
+
+var (
+ fset = token.NewFileSet()
+ exitCode = 0
+)
+
+var allowedRewrites = flag.String("r", "",
+ "restrict the rewrites to this comma-separated list")
+
+var forceRewrites = flag.String("force", "",
+ "force these fixes to run even if the code looks updated")
+
+var allowed, force map[string]bool
+
+var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
+
+// enable for debugging fix failures
+const debug = false // display incorrectly reformatted source and exit
+
+func usage() {
+ fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
+ flag.PrintDefaults()
+ fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
+ sort.Sort(byName(fixes))
+ for _, f := range fixes {
+ fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
+ desc := strings.TrimSpace(f.desc)
+ desc = strings.Replace(desc, "\n", "\n\t", -1)
+ fmt.Fprintf(os.Stderr, "\t%s\n", desc)
+ }
+ os.Exit(2)
+}
+
+func main() {
+ flag.Usage = usage
+ flag.Parse()
+
+ sort.Sort(byDate(fixes))
+
+ if *allowedRewrites != "" {
+ allowed = make(map[string]bool)
+ for _, f := range strings.Split(*allowedRewrites, ",") {
+ allowed[f] = true
+ }
+ }
+
+ if *forceRewrites != "" {
+ force = make(map[string]bool)
+ for _, f := range strings.Split(*forceRewrites, ",") {
+ force[f] = true
+ }
+ }
+
+ if flag.NArg() == 0 {
+ if err := processFile("standard input", true); err != nil {
+ report(err)
+ }
+ os.Exit(exitCode)
+ }
+
+ for i := 0; i < flag.NArg(); i++ {
+ path := flag.Arg(i)
+ switch dir, err := os.Stat(path); {
+ case err != nil:
+ report(err)
+ case dir.IsDir():
+ walkDir(path)
+ default:
+ if err := processFile(path, false); err != nil {
+ report(err)
+ }
+ }
+ }
+
+ os.Exit(exitCode)
+}
+
+const (
+ tabWidth = 8
+ parserMode = parser.ParseComments
+ printerMode = printer.TabIndent | printer.UseSpaces
+)
+
+var printConfig = &printer.Config{
+ Mode: printerMode,
+ Tabwidth: tabWidth,
+}
+
+func gofmtFile(f *ast.File) ([]byte, error) {
+ var buf bytes.Buffer
+
+ ast.SortImports(fset, f)
+ err := printConfig.Fprint(&buf, fset, f)
+ if err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+func processFile(filename string, useStdin bool) error {
+ var f *os.File
+ var err error
+ var fixlog bytes.Buffer
+
+ if useStdin {
+ f = os.Stdin
+ } else {
+ f, err = os.Open(filename)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ }
+
+ src, err := ioutil.ReadAll(f)
+ if err != nil {
+ return err
+ }
+
+ file, err := parser.ParseFile(fset, filename, src, parserMode)
+ if err != nil {
+ return err
+ }
+
+ // Apply all fixes to file.
+ newFile := file
+ fixed := false
+ for _, fix := range fixes {
+ if allowed != nil && !allowed[fix.name] {
+ continue
+ }
+ if fix.f(newFile) {
+ fixed = true
+ fmt.Fprintf(&fixlog, " %s", fix.name)
+
+ // AST changed.
+ // Print and parse, to update any missing scoping
+ // or position information for subsequent fixers.
+ newSrc, err := gofmtFile(newFile)
+ if err != nil {
+ return err
+ }
+ newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
+ if err != nil {
+ if debug {
+ fmt.Printf("%s", newSrc)
+ report(err)
+ os.Exit(exitCode)
+ }
+ return err
+ }
+ }
+ }
+ if !fixed {
+ return nil
+ }
+ fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
+
+ // Print AST. We did that after each fix, so this appears
+ // redundant, but it is necessary to generate gofmt-compatible
+ // source code in a few cases. The official gofmt style is the
+ // output of the printer run on a standard AST generated by the parser,
+ // but the source we generated inside the loop above is the
+ // output of the printer run on a mangled AST generated by a fixer.
+ newSrc, err := gofmtFile(newFile)
+ if err != nil {
+ return err
+ }
+
+ if *doDiff {
+ data, err := diff(src, newSrc)
+ if err != nil {
+ return fmt.Errorf("computing diff: %s", err)
+ }
+ fmt.Printf("diff %s fixed/%s\n", filename, filename)
+ os.Stdout.Write(data)
+ return nil
+ }
+
+ if useStdin {
+ os.Stdout.Write(newSrc)
+ return nil
+ }
+
+ return ioutil.WriteFile(f.Name(), newSrc, 0)
+}
+
+var gofmtBuf bytes.Buffer
+
+func gofmt(n interface{}) string {
+ gofmtBuf.Reset()
+ err := printConfig.Fprint(&gofmtBuf, fset, n)
+ if err != nil {
+ return "<" + err.Error() + ">"
+ }
+ return gofmtBuf.String()
+}
+
+func report(err error) {
+ scanner.PrintError(os.Stderr, err)
+ exitCode = 2
+}
+
+func walkDir(path string) {
+ filepath.Walk(path, visitFile)
+}
+
+func visitFile(path string, f os.FileInfo, err error) error {
+ if err == nil && isGoFile(f) {
+ err = processFile(path, false)
+ }
+ if err != nil {
+ report(err)
+ }
+ return nil
+}
+
+func isGoFile(f os.FileInfo) bool {
+ // ignore non-Go files
+ name := f.Name()
+ return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
+}
+
+func diff(b1, b2 []byte) (data []byte, err error) {
+ f1, err := ioutil.TempFile("", "go-fix")
+ if err != nil {
+ return nil, err
+ }
+ defer os.Remove(f1.Name())
+ defer f1.Close()
+
+ f2, err := ioutil.TempFile("", "go-fix")
+ if err != nil {
+ return nil, err
+ }
+ defer os.Remove(f2.Name())
+ defer f2.Close()
+
+ f1.Write(b1)
+ f2.Write(b2)
+
+ data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput()
+ if len(data) > 0 {
+ // diff exits with a non-zero status when the files don't match.
+ // Ignore that failure as long as we get output.
+ err = nil
+ }
+ return
+}
diff --git a/src/cmd/fix/main_test.go b/src/cmd/fix/main_test.go
new file mode 100644
index 000000000..2151bf29e
--- /dev/null
+++ b/src/cmd/fix/main_test.go
@@ -0,0 +1,129 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+ "go/parser"
+ "strings"
+ "testing"
+)
+
+type testCase struct {
+ Name string
+ Fn func(*ast.File) bool
+ In string
+ Out string
+}
+
+var testCases []testCase
+
+func addTestCases(t []testCase, fn func(*ast.File) bool) {
+ // Fill in fn to avoid repetition in definitions.
+ if fn != nil {
+ for i := range t {
+ if t[i].Fn == nil {
+ t[i].Fn = fn
+ }
+ }
+ }
+ testCases = append(testCases, t...)
+}
+
+func fnop(*ast.File) bool { return false }
+
+func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
+ file, err := parser.ParseFile(fset, desc, in, parserMode)
+ if err != nil {
+ t.Errorf("%s: parsing: %v", desc, err)
+ return
+ }
+
+ outb, err := gofmtFile(file)
+ if err != nil {
+ t.Errorf("%s: printing: %v", desc, err)
+ return
+ }
+ if s := string(outb); in != s && mustBeGofmt {
+ t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
+ desc, desc, in, desc, s)
+ tdiff(t, in, s)
+ return
+ }
+
+ if fn == nil {
+ for _, fix := range fixes {
+ if fix.f(file) {
+ fixed = true
+ }
+ }
+ } else {
+ fixed = fn(file)
+ }
+
+ outb, err = gofmtFile(file)
+ if err != nil {
+ t.Errorf("%s: printing: %v", desc, err)
+ return
+ }
+
+ return string(outb), fixed, true
+}
+
+func TestRewrite(t *testing.T) {
+ for _, tt := range testCases {
+ // Apply fix: should get tt.Out.
+ out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
+ if !ok {
+ continue
+ }
+
+ // reformat to get printing right
+ out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
+ if !ok {
+ continue
+ }
+
+ if out != tt.Out {
+ t.Errorf("%s: incorrect output.\n", tt.Name)
+ if !strings.HasPrefix(tt.Name, "testdata/") {
+ t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
+ }
+ tdiff(t, out, tt.Out)
+ continue
+ }
+
+ if changed := out != tt.In; changed != fixed {
+ t.Errorf("%s: changed=%v != fixed=%v", tt.Name, changed, fixed)
+ continue
+ }
+
+ // Should not change if run again.
+ out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
+ if !ok {
+ continue
+ }
+
+ if fixed2 {
+ t.Errorf("%s: applied fixes during second round", tt.Name)
+ continue
+ }
+
+ if out2 != out {
+ t.Errorf("%s: changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
+ tt.Name, out, out2)
+ tdiff(t, out, out2)
+ }
+ }
+}
+
+func tdiff(t *testing.T, a, b string) {
+ data, err := diff([]byte(a), []byte(b))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ t.Error(string(data))
+}
diff --git a/src/cmd/fix/mapdelete.go b/src/cmd/fix/mapdelete.go
new file mode 100644
index 000000000..db89c7bf4
--- /dev/null
+++ b/src/cmd/fix/mapdelete.go
@@ -0,0 +1,89 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "go/ast"
+
+func init() {
+ register(mapdeleteFix)
+}
+
+var mapdeleteFix = fix{
+ "mapdelete",
+ "2011-10-18",
+ mapdelete,
+ `Use delete(m, k) instead of m[k] = 0, false.
+
+http://codereview.appspot.com/5272045
+`,
+}
+
+func mapdelete(f *ast.File) bool {
+ fixed := false
+ walk(f, func(n interface{}) {
+ stmt, ok := n.(*ast.Stmt)
+ if !ok {
+ return
+ }
+ as, ok := (*stmt).(*ast.AssignStmt)
+ if !ok || len(as.Lhs) != 1 || len(as.Rhs) != 2 {
+ return
+ }
+ ix, ok := as.Lhs[0].(*ast.IndexExpr)
+ if !ok {
+ return
+ }
+ if !isTopName(as.Rhs[1], "false") {
+ warn(as.Pos(), "two-element map assignment with non-false second value")
+ return
+ }
+ if !canDrop(as.Rhs[0]) {
+ warn(as.Pos(), "two-element map assignment with non-trivial first value")
+ return
+ }
+ *stmt = &ast.ExprStmt{
+ X: &ast.CallExpr{
+ Fun: &ast.Ident{
+ NamePos: as.Pos(),
+ Name: "delete",
+ },
+ Args: []ast.Expr{ix.X, ix.Index},
+ },
+ }
+ fixed = true
+ })
+ return fixed
+}
+
+// canDrop reports whether it is safe to drop the
+// evaluation of n from the program.
+// It is very conservative.
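+// For example, identifiers, basic literals, and composite literals made up
+// only of droppable elements may be dropped; a call such as g() may not.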
+func canDrop(n ast.Expr) bool {
+ switch n := n.(type) {
+ case *ast.Ident, *ast.BasicLit:
+ return true
+ case *ast.ParenExpr:
+ return canDrop(n.X)
+ case *ast.SelectorExpr:
+ return canDrop(n.X)
+ case *ast.CompositeLit:
+ if !canDrop(n.Type) {
+ return false
+ }
+ for _, e := range n.Elts {
+ if !canDrop(e) {
+ return false
+ }
+ }
+ return true
+ case *ast.StarExpr:
+ // Dropping *x is questionable,
+ // but we have to be able to drop (*T)(nil).
+ return canDrop(n.X)
+ case *ast.ArrayType, *ast.ChanType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.StructType:
+ return true
+ }
+ return false
+}
diff --git a/src/cmd/fix/mapdelete_test.go b/src/cmd/fix/mapdelete_test.go
new file mode 100644
index 000000000..8ed50328e
--- /dev/null
+++ b/src/cmd/fix/mapdelete_test.go
@@ -0,0 +1,43 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(mapdeleteTests, mapdelete)
+}
+
+var mapdeleteTests = []testCase{
+ {
+ Name: "mapdelete.0",
+ In: `package main
+
+func f() {
+ m[x] = 0, false
+ m[x] = g(), false
+ m[x] = 1
+ delete(m, x)
+ m[x] = 0, b
+}
+
+func g(false bool) {
+ m[x] = 0, false
+}
+`,
+ Out: `package main
+
+func f() {
+ delete(m, x)
+ m[x] = g(), false
+ m[x] = 1
+ delete(m, x)
+ m[x] = 0, b
+}
+
+func g(false bool) {
+ m[x] = 0, false
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/math.go b/src/cmd/fix/math.go
new file mode 100644
index 000000000..2ec837eb0
--- /dev/null
+++ b/src/cmd/fix/math.go
@@ -0,0 +1,51 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "go/ast"
+
+func init() {
+ register(mathFix)
+}
+
+var mathFix = fix{
+ "math",
+ "2011-09-29",
+ math,
+ `Remove the leading F from math functions such as Fabs.
+
+http://codereview.appspot.com/5158043
+`,
+}
+
+var mathRenames = []struct{ in, out string }{
+ {"Fabs", "Abs"},
+ {"Fdim", "Dim"},
+ {"Fmax", "Max"},
+ {"Fmin", "Min"},
+ {"Fmod", "Mod"},
+}
+
+func math(f *ast.File) bool {
+ if !imports(f, "math") {
+ return false
+ }
+
+ fixed := false
+
+ walk(f, func(n interface{}) {
+ // Rename functions.
+ if expr, ok := n.(ast.Expr); ok {
+ for _, s := range mathRenames {
+ if isPkgDot(expr, "math", s.in) {
+ expr.(*ast.SelectorExpr).Sel.Name = s.out
+ fixed = true
+ return
+ }
+ }
+ }
+ })
+ return fixed
+}
diff --git a/src/cmd/fix/math_test.go b/src/cmd/fix/math_test.go
new file mode 100644
index 000000000..b8d69d2f2
--- /dev/null
+++ b/src/cmd/fix/math_test.go
@@ -0,0 +1,47 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(mathTests, math)
+}
+
+var mathTests = []testCase{
+ {
+ Name: "math.0",
+ In: `package main
+
+import (
+ "math"
+)
+
+func f() {
+ math.Fabs(1)
+ math.Fdim(1)
+ math.Fmax(1)
+ math.Fmin(1)
+ math.Fmod(1)
+ math.Abs(1)
+ foo.Fabs(1)
+}
+`,
+ Out: `package main
+
+import (
+ "math"
+)
+
+func f() {
+ math.Abs(1)
+ math.Dim(1)
+ math.Max(1)
+ math.Min(1)
+ math.Mod(1)
+ math.Abs(1)
+ foo.Fabs(1)
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/netdial.go b/src/cmd/fix/netdial.go
new file mode 100644
index 000000000..2de994cff
--- /dev/null
+++ b/src/cmd/fix/netdial.go
@@ -0,0 +1,117 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(netdialFix)
+ register(tlsdialFix)
+ register(netlookupFix)
+}
+
+var netdialFix = fix{
+ "netdial",
+ "2011-03-28",
+ netdial,
+ `Adapt 3-argument calls of net.Dial to use 2-argument form.
+
+http://codereview.appspot.com/4244055
+`,
+}
+
+var tlsdialFix = fix{
+ "tlsdial",
+ "2011-03-28",
+ tlsdial,
+ `Adapt 4-argument calls of tls.Dial to use 3-argument form.
+
+http://codereview.appspot.com/4244055
+`,
+}
+
+var netlookupFix = fix{
+ "netlookup",
+ "2011-03-28",
+ netlookup,
+ `Adapt 3-result calls to net.LookupHost to use 2-result form.
+
+http://codereview.appspot.com/4244055
+`,
+}
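+
+// For example (illustrative, mirroring netdial_test.go):
+//
+//	foo, bar, _ := net.LookupHost(host)  ->  foo, bar := net.LookupHost(host)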
+
+func netdial(f *ast.File) bool {
+ if !imports(f, "net") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ if !ok || !isPkgDot(call.Fun, "net", "Dial") || len(call.Args) != 3 {
+ return
+ }
+ // net.Dial(a, "", b) -> net.Dial(a, b)
+ if !isEmptyString(call.Args[1]) {
+ warn(call.Pos(), "call to net.Dial with non-empty second argument")
+ return
+ }
+ call.Args[1] = call.Args[2]
+ call.Args = call.Args[:2]
+ fixed = true
+ })
+ return fixed
+}
+
+func tlsdial(f *ast.File) bool {
+ if !imports(f, "crypto/tls") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ if !ok || !isPkgDot(call.Fun, "tls", "Dial") || len(call.Args) != 4 {
+ return
+ }
+ // tls.Dial(a, "", b, c) -> tls.Dial(a, b, c)
+ if !isEmptyString(call.Args[1]) {
+ warn(call.Pos(), "call to tls.Dial with non-empty second argument")
+ return
+ }
+ call.Args[1] = call.Args[2]
+ call.Args[2] = call.Args[3]
+ call.Args = call.Args[:3]
+ fixed = true
+ })
+ return fixed
+}
+
+func netlookup(f *ast.File) bool {
+ if !imports(f, "net") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ as, ok := n.(*ast.AssignStmt)
+ if !ok || len(as.Lhs) != 3 || len(as.Rhs) != 1 {
+ return
+ }
+ call, ok := as.Rhs[0].(*ast.CallExpr)
+ if !ok || !isPkgDot(call.Fun, "net", "LookupHost") {
+ return
+ }
+ if !isBlank(as.Lhs[2]) {
+ warn(as.Pos(), "call to net.LookupHost expecting cname; use net.LookupCNAME")
+ return
+ }
+ as.Lhs = as.Lhs[:2]
+ fixed = true
+ })
+ return fixed
+}
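
As a standalone illustration of the argument-splicing step used above (shift call.Args, then reslice), the sketch below parses a single call expression and drops its empty middle argument. The call text is made up, and the real fix additionally checks the net import and that the second argument really is an empty string.

// Sketch: turn net.Dial(network, "", addr) into net.Dial(network, addr).
package main

import (
    "go/ast"
    "go/parser"
    "go/printer"
    "go/token"
    "os"
)

func main() {
    e, err := parser.ParseExpr(`net.Dial(network, "", addr)`)
    if err != nil {
        panic(err)
    }
    call := e.(*ast.CallExpr)
    call.Args[1] = call.Args[2] // move addr into the second slot
    call.Args = call.Args[:2]   // drop the now-duplicated last slot
    printer.Fprint(os.Stdout, token.NewFileSet(), call) // net.Dial(network, addr)
}
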
diff --git a/src/cmd/fix/netdial_test.go b/src/cmd/fix/netdial_test.go
new file mode 100644
index 000000000..fff00b4ad
--- /dev/null
+++ b/src/cmd/fix/netdial_test.go
@@ -0,0 +1,57 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(netdialTests, nil)
+}
+
+var netdialTests = []testCase{
+ {
+ Name: "netdial.0",
+ Fn: netdial,
+ In: `package main
+
+import "net"
+
+func f() {
+ c, err := net.Dial(net, "", addr)
+ c, err = net.Dial(net, "", addr)
+}
+`,
+ Out: `package main
+
+import "net"
+
+func f() {
+ c, err := net.Dial(net, addr)
+ c, err = net.Dial(net, addr)
+}
+`,
+ },
+
+ {
+ Name: "netlookup.0",
+ Fn: netlookup,
+ In: `package main
+
+import "net"
+
+func f() {
+ foo, bar, _ := net.LookupHost(host)
+ foo, bar, _ = net.LookupHost(host)
+}
+`,
+ Out: `package main
+
+import "net"
+
+func f() {
+ foo, bar := net.LookupHost(host)
+ foo, bar = net.LookupHost(host)
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/netudpgroup.go b/src/cmd/fix/netudpgroup.go
new file mode 100644
index 000000000..b54beb0de
--- /dev/null
+++ b/src/cmd/fix/netudpgroup.go
@@ -0,0 +1,58 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(netudpgroupFix)
+}
+
+var netudpgroupFix = fix{
+ "netudpgroup",
+ "2011-08-18",
+ netudpgroup,
+ `Adapt 1-argument calls of net.(*UDPConn).JoinGroup, LeaveGroup to use 2-argument form.
+
+http://codereview.appspot.com/4815074
+`,
+}
+
+func netudpgroup(f *ast.File) bool {
+ if !imports(f, "net") {
+ return false
+ }
+
+ fixed := false
+ for _, d := range f.Decls {
+ fd, ok := d.(*ast.FuncDecl)
+ if !ok || fd.Body == nil {
+ continue
+ }
+ walk(fd.Body, func(n interface{}) {
+ ce, ok := n.(*ast.CallExpr)
+ if !ok {
+ return
+ }
+ se, ok := ce.Fun.(*ast.SelectorExpr)
+ if !ok || len(ce.Args) != 1 {
+ return
+ }
+ switch se.Sel.String() {
+ case "JoinGroup", "LeaveGroup":
+ // c.JoinGroup(a) -> c.JoinGroup(nil, a)
+ // c.LeaveGroup(a) -> c.LeaveGroup(nil, a)
+ arg := ce.Args[0]
+ ce.Args = make([]ast.Expr, 2)
+ ce.Args[0] = ast.NewIdent("nil")
+ ce.Args[1] = arg
+ fixed = true
+ }
+ })
+ }
+ return fixed
+}
diff --git a/src/cmd/fix/netudpgroup_test.go b/src/cmd/fix/netudpgroup_test.go
new file mode 100644
index 000000000..88c0e093f
--- /dev/null
+++ b/src/cmd/fix/netudpgroup_test.go
@@ -0,0 +1,53 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(netudpgroupTests, netudpgroup)
+}
+
+var netudpgroupTests = []testCase{
+ {
+ Name: "netudpgroup.0",
+ In: `package main
+
+import "net"
+
+func f() {
+ err := x.JoinGroup(gaddr)
+ err = y.LeaveGroup(gaddr)
+}
+`,
+ Out: `package main
+
+import "net"
+
+func f() {
+ err := x.JoinGroup(nil, gaddr)
+ err = y.LeaveGroup(nil, gaddr)
+}
+`,
+ },
+ // Innocent function with no body.
+ {
+ Name: "netudpgroup.1",
+ In: `package main
+
+import "net"
+
+func f()
+
+var _ net.IP
+`,
+ Out: `package main
+
+import "net"
+
+func f()
+
+var _ net.IP
+`,
+ },
+}
diff --git a/src/cmd/fix/newwriter.go b/src/cmd/fix/newwriter.go
new file mode 100644
index 000000000..4befe24fb
--- /dev/null
+++ b/src/cmd/fix/newwriter.go
@@ -0,0 +1,90 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(newWriterFix)
+}
+
+var newWriterFix = fix{
+ "newWriter",
+ "2012-02-14",
+ newWriter,
+ `Adapt bufio, gzip and zlib NewWriterXxx calls to account for whether they return errors.
+
+Also rename gzip.Compressor and gzip.Decompressor to gzip.Writer and gzip.Reader.
+
+http://codereview.appspot.com/5639057 and
+http://codereview.appspot.com/5642054
+`,
+}
+
+func newWriter(f *ast.File) bool {
+ if !imports(f, "bufio") && !imports(f, "compress/gzip") && !imports(f, "compress/zlib") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ switch n := n.(type) {
+ case *ast.SelectorExpr:
+ if isTopName(n.X, "gzip") {
+ switch n.Sel.String() {
+ case "Compressor":
+ n.Sel = &ast.Ident{Name: "Writer"}
+ fixed = true
+ case "Decompressor":
+ n.Sel = &ast.Ident{Name: "Reader"}
+ fixed = true
+ }
+ } else if isTopName(n.X, "zlib") {
+ if n.Sel.String() == "NewWriterDict" {
+ n.Sel = &ast.Ident{Name: "NewWriterLevelDict"}
+ fixed = true
+ }
+ }
+
+ case *ast.AssignStmt:
+ // Drop the ", _" in assignments of the form:
+ // w0, _ = gzip.NewWriter(w1)
+ if len(n.Lhs) != 2 || len(n.Rhs) != 1 {
+ return
+ }
+ i, ok := n.Lhs[1].(*ast.Ident)
+ if !ok {
+ return
+ }
+ if i.String() != "_" {
+ return
+ }
+ c, ok := n.Rhs[0].(*ast.CallExpr)
+ if !ok {
+ return
+ }
+ s, ok := c.Fun.(*ast.SelectorExpr)
+ if !ok {
+ return
+ }
+ sel := s.Sel.String()
+ switch {
+ case isTopName(s.X, "bufio") && (sel == "NewReaderSize" || sel == "NewWriterSize"):
+ // No-op.
+ case isTopName(s.X, "gzip") && sel == "NewWriter":
+ // No-op.
+ case isTopName(s.X, "zlib") && sel == "NewWriter":
+ // No-op.
+ default:
+ return
+ }
+ n.Lhs = n.Lhs[:1]
+ fixed = true
+ }
+ })
+ return fixed
+}
diff --git a/src/cmd/fix/newwriter_test.go b/src/cmd/fix/newwriter_test.go
new file mode 100644
index 000000000..1f59628a0
--- /dev/null
+++ b/src/cmd/fix/newwriter_test.go
@@ -0,0 +1,83 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(newWriterTests, newWriter)
+}
+
+var newWriterTests = []testCase{
+ {
+ Name: "newWriter.0",
+ In: `package main
+
+import (
+ "bufio"
+ "compress/gzip"
+ "compress/zlib"
+ "io"
+
+ "foo"
+)
+
+func f() *gzip.Compressor {
+ var (
+ _ gzip.Compressor
+ _ *gzip.Decompressor
+ _ struct {
+ W *gzip.Compressor
+ R gzip.Decompressor
+ }
+ )
+
+ var w io.Writer
+ br := bufio.NewReader(nil)
+ br, _ = bufio.NewReaderSize(nil, 256)
+ bw, err := bufio.NewWriterSize(w, 256) // Unfixable, as it declares an err variable.
+ bw, _ = bufio.NewWriterSize(w, 256)
+ fw, _ := foo.NewWriter(w)
+ gw, _ := gzip.NewWriter(w)
+ gw, _ = gzip.NewWriter(w)
+ zw, _ := zlib.NewWriter(w)
+ _ = zlib.NewWriterDict(zw, 0, nil)
+ return gw
+}
+`,
+ Out: `package main
+
+import (
+ "bufio"
+ "compress/gzip"
+ "compress/zlib"
+ "io"
+
+ "foo"
+)
+
+func f() *gzip.Writer {
+ var (
+ _ gzip.Writer
+ _ *gzip.Reader
+ _ struct {
+ W *gzip.Writer
+ R gzip.Reader
+ }
+ )
+
+ var w io.Writer
+ br := bufio.NewReader(nil)
+ br = bufio.NewReaderSize(nil, 256)
+ bw, err := bufio.NewWriterSize(w, 256) // Unfixable, as it declares an err variable.
+ bw = bufio.NewWriterSize(w, 256)
+ fw, _ := foo.NewWriter(w)
+ gw := gzip.NewWriter(w)
+ gw = gzip.NewWriter(w)
+ zw := zlib.NewWriter(w)
+ _ = zlib.NewWriterLevelDict(zw, 0, nil)
+ return gw
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/oserrorstring.go b/src/cmd/fix/oserrorstring.go
new file mode 100644
index 000000000..a75a2c12d
--- /dev/null
+++ b/src/cmd/fix/oserrorstring.go
@@ -0,0 +1,75 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(oserrorstringFix)
+}
+
+var oserrorstringFix = fix{
+ "oserrorstring",
+ "2011-06-22",
+ oserrorstring,
+ `Replace os.ErrorString() conversions with calls to os.NewError().
+
+http://codereview.appspot.com/4607052
+`,
+}
+
+func oserrorstring(f *ast.File) bool {
+ if !imports(f, "os") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ // The conversion os.ErrorString(x) looks like a call
+ // of os.ErrorString with one argument.
+ if call := callExpr(n, "os", "ErrorString"); call != nil {
+ // os.ErrorString(args) -> os.NewError(args)
+ call.Fun.(*ast.SelectorExpr).Sel.Name = "NewError"
+ fixed = true
+ return
+ }
+
+ // Remove os.Error type from variable declarations initialized
+ // with an os.NewError.
+ // (An *ast.ValueSpec may also be used in a const declaration
+ // but those won't be initialized with a call to os.NewError.)
+ if spec, ok := n.(*ast.ValueSpec); ok &&
+ len(spec.Names) == 1 &&
+ isPkgDot(spec.Type, "os", "Error") &&
+ len(spec.Values) == 1 &&
+ callExpr(spec.Values[0], "os", "NewError") != nil {
+ // var name os.Error = os.NewError(x) ->
+ // var name = os.NewError(x)
+ spec.Type = nil
+ fixed = true
+ return
+ }
+
+ // Other occurrences of os.ErrorString are not fixed
+ // but they are rare.
+
+ })
+ return fixed
+}
+
+// callExpr returns the call expression if x is a call to pkg.name with one argument;
+// otherwise it returns nil.
+func callExpr(x interface{}, pkg, name string) *ast.CallExpr {
+ if call, ok := x.(*ast.CallExpr); ok &&
+ len(call.Args) == 1 &&
+ isPkgDot(call.Fun, pkg, name) {
+ return call
+ }
+ return nil
+}
diff --git a/src/cmd/fix/oserrorstring_test.go b/src/cmd/fix/oserrorstring_test.go
new file mode 100644
index 000000000..75551480c
--- /dev/null
+++ b/src/cmd/fix/oserrorstring_test.go
@@ -0,0 +1,57 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(oserrorstringTests, oserrorstring)
+}
+
+var oserrorstringTests = []testCase{
+ {
+ Name: "oserrorstring.0",
+ In: `package main
+
+import "os"
+
+var _ = os.ErrorString("foo")
+var _ os.Error = os.ErrorString("bar1")
+var _ os.Error = os.NewError("bar2")
+var _ os.Error = MyError("bal") // don't rewrite this one
+
+var (
+ _ = os.ErrorString("foo")
+ _ os.Error = os.ErrorString("bar1")
+ _ os.Error = os.NewError("bar2")
+ _ os.Error = MyError("bal") // don't rewrite this one
+)
+
+func _() (err os.Error) {
+ err = os.ErrorString("foo")
+ return os.ErrorString("foo")
+}
+`,
+ Out: `package main
+
+import "os"
+
+var _ = os.NewError("foo")
+var _ = os.NewError("bar1")
+var _ = os.NewError("bar2")
+var _ os.Error = MyError("bal") // don't rewrite this one
+
+var (
+ _ = os.NewError("foo")
+ _ = os.NewError("bar1")
+ _ = os.NewError("bar2")
+ _ os.Error = MyError("bal") // don't rewrite this one
+)
+
+func _() (err os.Error) {
+ err = os.NewError("foo")
+ return os.NewError("foo")
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/osopen.go b/src/cmd/fix/osopen.go
new file mode 100644
index 000000000..af2796ac2
--- /dev/null
+++ b/src/cmd/fix/osopen.go
@@ -0,0 +1,124 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(osopenFix)
+}
+
+var osopenFix = fix{
+ "osopen",
+ "2011-04-04",
+ osopen,
+ `Adapt os.Open calls to new, easier API and rename O_CREAT to O_CREATE.
+
+http://codereview.appspot.com/4357052
+`,
+}
+
+func osopen(f *ast.File) bool {
+ if !imports(f, "os") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ // Rename O_CREAT to O_CREATE.
+ if expr, ok := n.(ast.Expr); ok && isPkgDot(expr, "os", "O_CREAT") {
+ expr.(*ast.SelectorExpr).Sel.Name = "O_CREATE"
+ fixed = true
+ return
+ }
+
+ // Fix up calls to Open.
+ call, ok := n.(*ast.CallExpr)
+ if !ok || len(call.Args) != 3 {
+ return
+ }
+ if !isPkgDot(call.Fun, "os", "Open") {
+ return
+ }
+ sel := call.Fun.(*ast.SelectorExpr)
+ args := call.Args
+ // os.Open(a, os.O_RDONLY, c) -> os.Open(a)
+ if isPkgDot(args[1], "os", "O_RDONLY") || isPkgDot(args[1], "syscall", "O_RDONLY") {
+ call.Args = call.Args[0:1]
+ fixed = true
+ return
+ }
+ // os.Open(a, createlike_flags, c) -> os.Create(a)
+ if isCreateFlag(args[1]) {
+ sel.Sel.Name = "Create"
+ if !isSimplePerm(args[2]) {
+ warn(sel.Pos(), "rewrote os.Open to os.Create with permission not 0666")
+ }
+ call.Args = args[0:1]
+ fixed = true
+ return
+ }
+ // Fallback: os.Open(a, b, c) -> os.OpenFile(a, b, c)
+ sel.Sel.Name = "OpenFile"
+ fixed = true
+ })
+ return fixed
+}
+
+func isCreateFlag(flag ast.Expr) bool {
+ foundCreate := false
+ foundTrunc := false
+ // OR'ing of flags: is O_CREATE on? + or | would be fine; we just look for os.O_CREATE
+ // and don't worry about the actual operator.
+ p := flag.Pos()
+ for {
+ lhs := flag
+ expr, isBinary := flag.(*ast.BinaryExpr)
+ if isBinary {
+ lhs = expr.Y
+ }
+ sel, ok := lhs.(*ast.SelectorExpr)
+ if !ok || !isTopName(sel.X, "os") {
+ return false
+ }
+ switch sel.Sel.Name {
+ case "O_CREATE":
+ foundCreate = true
+ case "O_TRUNC":
+ foundTrunc = true
+ case "O_RDONLY", "O_WRONLY", "O_RDWR":
+ // okay
+ default:
+ // Unexpected flag, like O_APPEND or O_EXCL.
+ // Be conservative and do not rewrite.
+ return false
+ }
+ if !isBinary {
+ break
+ }
+ flag = expr.X
+ }
+ if !foundCreate {
+ return false
+ }
+ if !foundTrunc {
+ warn(p, "rewrote os.Open with O_CREATE but not O_TRUNC to os.Create")
+ }
+ return foundCreate
+}
+
+func isSimplePerm(perm ast.Expr) bool {
+ basicLit, ok := perm.(*ast.BasicLit)
+ if !ok {
+ return false
+ }
+ switch basicLit.Value {
+ case "0666":
+ return true
+ }
+ return false
+}
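
isCreateFlag walks an OR'ed chain of flag selectors from right to left, repeatedly taking the Y operand and then the X operand of each *ast.BinaryExpr. A rough standalone sketch of that traversal follows; the flag expression is made up, and the fix's isTopName and warn helpers are replaced by a plain type check and a print.

// Sketch: list the os.O_* names in an OR chain, rightmost first.
package main

import (
    "fmt"
    "go/ast"
    "go/parser"
)

func main() {
    expr, err := parser.ParseExpr("os.O_WRONLY | os.O_CREATE | os.O_TRUNC")
    if err != nil {
        panic(err)
    }
    for {
        lhs := expr
        bin, isBinary := expr.(*ast.BinaryExpr)
        if isBinary {
            lhs = bin.Y // rightmost operand of the remaining chain
        }
        sel, ok := lhs.(*ast.SelectorExpr)
        if !ok {
            break // not a pkg.FLAG operand; isCreateFlag gives up here
        }
        fmt.Println(sel.Sel.Name) // O_TRUNC, then O_CREATE, then O_WRONLY
        if !isBinary {
            break
        }
        expr = bin.X // continue with the rest of the chain
    }
}
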
diff --git a/src/cmd/fix/osopen_test.go b/src/cmd/fix/osopen_test.go
new file mode 100644
index 000000000..5797adb7b
--- /dev/null
+++ b/src/cmd/fix/osopen_test.go
@@ -0,0 +1,82 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(osopenTests, osopen)
+}
+
+var osopenTests = []testCase{
+ {
+ Name: "osopen.0",
+ In: `package main
+
+import (
+ "os"
+)
+
+func f() {
+ os.OpenFile(a, b, c)
+ os.Open(a, os.O_RDONLY, 0)
+ os.Open(a, os.O_RDONLY, 0666)
+ os.Open(a, os.O_RDWR, 0)
+ os.Open(a, os.O_CREAT, 0666)
+ os.Open(a, os.O_CREAT|os.O_TRUNC, 0664)
+ os.Open(a, os.O_CREATE, 0666)
+ os.Open(a, os.O_CREATE|os.O_TRUNC, 0664)
+ os.Open(a, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
+ os.Open(a, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
+ os.Open(a, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666)
+ os.Open(a, os.O_SURPRISE|os.O_CREATE, 0666)
+ _ = os.O_CREAT
+}
+`,
+ Out: `package main
+
+import (
+ "os"
+)
+
+func f() {
+ os.OpenFile(a, b, c)
+ os.Open(a)
+ os.Open(a)
+ os.OpenFile(a, os.O_RDWR, 0)
+ os.Create(a)
+ os.Create(a)
+ os.Create(a)
+ os.Create(a)
+ os.Create(a)
+ os.OpenFile(a, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
+ os.OpenFile(a, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666)
+ os.OpenFile(a, os.O_SURPRISE|os.O_CREATE, 0666)
+ _ = os.O_CREATE
+}
+`,
+ },
+ {
+ Name: "osopen.1",
+ In: `package main
+
+import (
+ "os"
+)
+
+func f() {
+ _ = os.O_CREAT
+}
+`,
+ Out: `package main
+
+import (
+ "os"
+)
+
+func f() {
+ _ = os.O_CREATE
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/procattr.go b/src/cmd/fix/procattr.go
new file mode 100644
index 000000000..ea375ec9d
--- /dev/null
+++ b/src/cmd/fix/procattr.go
@@ -0,0 +1,62 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+ "go/token"
+)
+
+func init() {
+ register(procattrFix)
+}
+
+var procattrFix = fix{
+ "procattr",
+ "2011-03-15",
+ procattr,
+ `Adapt calls to os.StartProcess to use new ProcAttr type.
+
+http://codereview.appspot.com/4253052
+`,
+}
+
+func procattr(f *ast.File) bool {
+ if !imports(f, "os") && !imports(f, "syscall") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ if !ok || len(call.Args) != 5 {
+ return
+ }
+ var pkg string
+ if isPkgDot(call.Fun, "os", "StartProcess") {
+ pkg = "os"
+ } else if isPkgDot(call.Fun, "syscall", "StartProcess") {
+ pkg = "syscall"
+ } else {
+ return
+ }
+ // os.StartProcess(a, b, c, d, e) -> os.StartProcess(a, b, &os.ProcAttr{Env: c, Dir: d, Files: e})
+ lit := &ast.CompositeLit{Type: ast.NewIdent(pkg + ".ProcAttr")}
+ env, dir, files := call.Args[2], call.Args[3], call.Args[4]
+ if !isName(env, "nil") && !isCall(env, "os", "Environ") {
+ lit.Elts = append(lit.Elts, &ast.KeyValueExpr{Key: ast.NewIdent("Env"), Value: env})
+ }
+ if !isEmptyString(dir) {
+ lit.Elts = append(lit.Elts, &ast.KeyValueExpr{Key: ast.NewIdent("Dir"), Value: dir})
+ }
+ if !isName(files, "nil") {
+ lit.Elts = append(lit.Elts, &ast.KeyValueExpr{Key: ast.NewIdent("Files"), Value: files})
+ }
+ call.Args[2] = &ast.UnaryExpr{Op: token.AND, X: lit}
+ call.Args = call.Args[:3]
+ fixed = true
+ })
+ return fixed
+}
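
The interesting part of procattr is how it synthesizes the new third argument: a composite literal with KeyValueExpr fields, wrapped in a unary & expression. A small standalone sketch of that construction follows; the single Env field and the env identifier are placeholders, and the dotted name passed to ast.NewIdent is the same printing shortcut the fix itself relies on.

// Sketch: build and print &os.ProcAttr{Env: env}.
package main

import (
    "go/ast"
    "go/printer"
    "go/token"
    "os"
)

func main() {
    lit := &ast.CompositeLit{Type: ast.NewIdent("os.ProcAttr")}
    lit.Elts = append(lit.Elts, &ast.KeyValueExpr{
        Key:   ast.NewIdent("Env"),
        Value: ast.NewIdent("env"),
    })
    attr := &ast.UnaryExpr{Op: token.AND, X: lit}
    printer.Fprint(os.Stdout, token.NewFileSet(), attr) // &os.ProcAttr{Env: env}
}
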
diff --git a/src/cmd/fix/procattr_test.go b/src/cmd/fix/procattr_test.go
new file mode 100644
index 000000000..9e2b86e74
--- /dev/null
+++ b/src/cmd/fix/procattr_test.go
@@ -0,0 +1,74 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(procattrTests, procattr)
+}
+
+var procattrTests = []testCase{
+ {
+ Name: "procattr.0",
+ In: `package main
+
+import (
+ "os"
+ "syscall"
+)
+
+func f() {
+ os.StartProcess(a, b, c, d, e)
+ os.StartProcess(a, b, os.Environ(), d, e)
+ os.StartProcess(a, b, nil, d, e)
+ os.StartProcess(a, b, c, "", e)
+ os.StartProcess(a, b, c, d, nil)
+ os.StartProcess(a, b, nil, "", nil)
+
+ os.StartProcess(
+ a,
+ b,
+ c,
+ d,
+ e,
+ )
+
+ syscall.StartProcess(a, b, c, d, e)
+ syscall.StartProcess(a, b, os.Environ(), d, e)
+ syscall.StartProcess(a, b, nil, d, e)
+ syscall.StartProcess(a, b, c, "", e)
+ syscall.StartProcess(a, b, c, d, nil)
+ syscall.StartProcess(a, b, nil, "", nil)
+}
+`,
+ Out: `package main
+
+import (
+ "os"
+ "syscall"
+)
+
+func f() {
+ os.StartProcess(a, b, &os.ProcAttr{Env: c, Dir: d, Files: e})
+ os.StartProcess(a, b, &os.ProcAttr{Dir: d, Files: e})
+ os.StartProcess(a, b, &os.ProcAttr{Dir: d, Files: e})
+ os.StartProcess(a, b, &os.ProcAttr{Env: c, Files: e})
+ os.StartProcess(a, b, &os.ProcAttr{Env: c, Dir: d})
+ os.StartProcess(a, b, &os.ProcAttr{})
+
+ os.StartProcess(
+ a,
+ b, &os.ProcAttr{Env: c, Dir: d, Files: e},
+ )
+
+ syscall.StartProcess(a, b, &syscall.ProcAttr{Env: c, Dir: d, Files: e})
+ syscall.StartProcess(a, b, &syscall.ProcAttr{Dir: d, Files: e})
+ syscall.StartProcess(a, b, &syscall.ProcAttr{Dir: d, Files: e})
+ syscall.StartProcess(a, b, &syscall.ProcAttr{Env: c, Files: e})
+ syscall.StartProcess(a, b, &syscall.ProcAttr{Env: c, Dir: d})
+ syscall.StartProcess(a, b, &syscall.ProcAttr{})
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/reflect.go b/src/cmd/fix/reflect.go
new file mode 100644
index 000000000..151da569d
--- /dev/null
+++ b/src/cmd/fix/reflect.go
@@ -0,0 +1,862 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// TODO(rsc): Once there is better support for writing
+// multi-package commands, this should really be in
+// its own package, and then we can drop all the "reflect"
+// prefixes on the global variables and functions.
+
+package main
+
+import (
+ "go/ast"
+ "go/token"
+ "strings"
+)
+
+func init() {
+ register(reflectFix)
+}
+
+var reflectFix = fix{
+ "reflect",
+ "2011-04-08",
+ reflectFn,
+ `Adapt code to new reflect API.
+
+http://codereview.appspot.com/4281055
+http://codereview.appspot.com/4433066
+`,
+}
+
+// The reflect API change dropped the concrete types *reflect.ArrayType etc.
+// Any type assertions prior to method calls can be deleted:
+// x.(*reflect.ArrayType).Len() -> x.Len()
+//
+// Any type checks can be replaced by assignment and check of Kind:
+// x, y := z.(*reflect.ArrayType)
+// ->
+// x := z
+// y := x.Kind() == reflect.Array
+//
+// If z is an ordinary variable name and x is not subsequently assigned to,
+// references to x can be replaced by z and the assignment deleted.
+// We only bother if x and z are the same name.
+// If y is not subsequently assigned to and neither is x, references to
+// y can be replaced by its expression. We only bother when there is
+// just one use or when the use appears in an if clause.
+//
+// Not all type checks result in a single Kind check. The rewrite of the type check for
+// reflect.ArrayOrSliceType checks x.Kind() against reflect.Array and reflect.Slice.
+// The rewrite for *reflect.IntType checks against Int, Int8, Int16, Int32, Int64.
+// The rewrite for *reflect.UintType adds Uintptr.
+//
+// A type switch turns into an assignment and a switch on Kind:
+// switch x := y.(type) {
+// case reflect.ArrayOrSliceType:
+// ...
+// case *reflect.ChanType:
+// ...
+// case *reflect.IntType:
+// ...
+// }
+// ->
+// switch x := y; x.Kind() {
+// case reflect.Array, reflect.Slice:
+// ...
+// case reflect.Chan:
+// ...
+// case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+// ...
+// }
+//
+// The same simplification applies: we drop x := x if x is not assigned
+// to in the switch cases.
+//
+// Because the type check assignment includes a type assertion in its
+// syntax and the rewrite traversal is bottom up, we must do a pass to
+// rewrite the type check assignments and then a separate pass to
+// rewrite the type assertions.
+//
+// The same process applies to the API changes for reflect.Value.
+//
+// For both cases, but especially Value, the code needs to be aware
+// of the type of a receiver when rewriting a method call. For example,
+// x.(*reflect.ArrayValue).Elem(i) becomes x.Index(i) while
+// x.(*reflect.MapValue).Elem(v) becomes x.MapIndex(v).
+// In general, reflectFn needs to know the type of the receiver expression.
+// In most cases (and in all the cases in the Go source tree), the toy
+// type checker in typecheck.go provides enough information for fix
+// to make the rewrite. If fix misses a rewrite, the code that is left over
+// will not compile, so it will be noticed immediately.
+
+func reflectFn(f *ast.File) bool {
+ if !imports(f, "reflect") {
+ return false
+ }
+
+ fixed := false
+
+ // Rewrite names in method calls.
+ // Needs basic type information (see above).
+ typeof, _ := typecheck(reflectTypeConfig, f)
+ walk(f, func(n interface{}) {
+ switch n := n.(type) {
+ case *ast.SelectorExpr:
+ typ := typeof[n.X]
+ if m := reflectRewriteMethod[typ]; m != nil {
+ if replace := m[n.Sel.Name]; replace != "" {
+ n.Sel.Name = replace
+ fixed = true
+ return
+ }
+ }
+
+ // For all reflect Values, replace SetValue with Set.
+ if isReflectValue[typ] && n.Sel.Name == "SetValue" {
+ n.Sel.Name = "Set"
+ fixed = true
+ return
+ }
+
+ // Replace reflect.MakeZero with reflect.Zero.
+ if isPkgDot(n, "reflect", "MakeZero") {
+ n.Sel.Name = "Zero"
+ fixed = true
+ return
+ }
+ }
+ })
+
+ // Replace PtrValue's PointTo(x) with Set(x.Addr()).
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ if !ok || len(call.Args) != 1 {
+ return
+ }
+ sel, ok := call.Fun.(*ast.SelectorExpr)
+ if !ok || sel.Sel.Name != "PointTo" {
+ return
+ }
+ typ := typeof[sel.X]
+ if typ != "*reflect.PtrValue" {
+ return
+ }
+ sel.Sel.Name = "Set"
+ if !isTopName(call.Args[0], "nil") {
+ call.Args[0] = &ast.SelectorExpr{
+ X: call.Args[0],
+ Sel: ast.NewIdent("Addr()"),
+ }
+ }
+ fixed = true
+ })
+
+ // Fix type switches.
+ walk(f, func(n interface{}) {
+ if reflectFixSwitch(n) {
+ fixed = true
+ }
+ })
+
+ // Fix type assertion checks (multiple assignment statements).
+ // Have to work on the statement context (statement list or if statement)
+ // so that we can insert an extra statement occasionally.
+ // Ignoring for and switch because they don't come up in
+ // typical code.
+ walk(f, func(n interface{}) {
+ switch n := n.(type) {
+ case *[]ast.Stmt:
+ // v is the replacement statement list.
+ var v []ast.Stmt
+ insert := func(x ast.Stmt) {
+ v = append(v, x)
+ }
+ for i, x := range *n {
+ // Tentatively append to v; if we rewrite x
+ // we'll have to update the entry, so remember
+ // the index.
+ j := len(v)
+ v = append(v, x)
+ if reflectFixTypecheck(&x, insert, (*n)[i+1:]) {
+ // reflectFixTypecheck may have overwritten x.
+ // Update the entry we appended just before the call.
+ v[j] = x
+ fixed = true
+ }
+ }
+ *n = v
+ case *ast.IfStmt:
+ x := &ast.ExprStmt{X: n.Cond}
+ if reflectFixTypecheck(&n.Init, nil, []ast.Stmt{x, n.Body, n.Else}) {
+ n.Cond = x.X
+ fixed = true
+ }
+ }
+ })
+
+ // Warn about any typecheck statements that we missed.
+ walk(f, reflectWarnTypecheckStmt)
+
+ // Now that those are gone, fix remaining type assertions.
+ // Delayed because the type checks have
+ // type assertions as part of their syntax.
+ walk(f, func(n interface{}) {
+ if reflectFixAssert(n) {
+ fixed = true
+ }
+ })
+
+ // Now that the type assertions are gone, rewrite remaining
+ // references to specific reflect types to use the general ones.
+ walk(f, func(n interface{}) {
+ ptr, ok := n.(*ast.Expr)
+ if !ok {
+ return
+ }
+ nn := *ptr
+ typ := reflectType(nn)
+ if typ == "" {
+ return
+ }
+ if strings.HasSuffix(typ, "Type") {
+ *ptr = newPkgDot(nn.Pos(), "reflect", "Type")
+ } else {
+ *ptr = newPkgDot(nn.Pos(), "reflect", "Value")
+ }
+ fixed = true
+ })
+
+ // Rewrite v.Set(nil) to v.Set(reflect.MakeZero(v.Type())).
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ if !ok || len(call.Args) != 1 || !isTopName(call.Args[0], "nil") {
+ return
+ }
+ sel, ok := call.Fun.(*ast.SelectorExpr)
+ if !ok || !isReflectValue[typeof[sel.X]] || sel.Sel.Name != "Set" {
+ return
+ }
+ call.Args[0] = &ast.CallExpr{
+ Fun: newPkgDot(call.Args[0].Pos(), "reflect", "Zero"),
+ Args: []ast.Expr{
+ &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: sel.X,
+ Sel: &ast.Ident{Name: "Type"},
+ },
+ },
+ },
+ }
+ fixed = true
+ })
+
+ // Rewrite v != nil to v.IsValid().
+ // Rewrite nil used as reflect.Value (in function argument or return) to reflect.Value{}.
+ walk(f, func(n interface{}) {
+ ptr, ok := n.(*ast.Expr)
+ if !ok {
+ return
+ }
+ if isTopName(*ptr, "nil") && isReflectValue[typeof[*ptr]] {
+ *ptr = ast.NewIdent("reflect.Value{}")
+ fixed = true
+ return
+ }
+ nn, ok := (*ptr).(*ast.BinaryExpr)
+ if !ok || (nn.Op != token.EQL && nn.Op != token.NEQ) || !isTopName(nn.Y, "nil") || !isReflectValue[typeof[nn.X]] {
+ return
+ }
+ var call ast.Expr = &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: nn.X,
+ Sel: &ast.Ident{Name: "IsValid"},
+ },
+ }
+ if nn.Op == token.EQL {
+ call = &ast.UnaryExpr{Op: token.NOT, X: call}
+ }
+ *ptr = call
+ fixed = true
+ })
+
+ // Rewrite
+ // reflect.Typeof -> reflect.TypeOf,
+ walk(f, func(n interface{}) {
+ sel, ok := n.(*ast.SelectorExpr)
+ if !ok {
+ return
+ }
+ if isTopName(sel.X, "reflect") && sel.Sel.Name == "Typeof" {
+ sel.Sel.Name = "TypeOf"
+ fixed = true
+ }
+ if isTopName(sel.X, "reflect") && sel.Sel.Name == "NewValue" {
+ sel.Sel.Name = "ValueOf"
+ fixed = true
+ }
+ })
+
+ return fixed
+}
+
+// reflectFixSwitch rewrites *n (if n is an *ast.Stmt) corresponding
+// to a type switch.
+func reflectFixSwitch(n interface{}) bool {
+ ptr, ok := n.(*ast.Stmt)
+ if !ok {
+ return false
+ }
+ n = *ptr
+
+ ts, ok := n.(*ast.TypeSwitchStmt)
+ if !ok {
+ return false
+ }
+
+ // Are any switch cases referring to reflect types?
+ // (That is, is this an old reflect type switch?)
+ for _, cas := range ts.Body.List {
+ for _, typ := range cas.(*ast.CaseClause).List {
+ if reflectType(typ) != "" {
+ goto haveReflect
+ }
+ }
+ }
+ return false
+
+haveReflect:
+ // Now we know it's an old reflect type switch. Prepare the new version,
+ // but don't replace or edit the original until we're sure of success.
+
+ // Figure out the initializer statement, if any, and the receiver for the Kind call.
+ var init ast.Stmt
+ var rcvr ast.Expr
+
+ init = ts.Init
+ switch n := ts.Assign.(type) {
+ default:
+ warn(ts.Pos(), "unexpected form in type switch")
+ return false
+
+ case *ast.AssignStmt:
+ as := n
+ ta := as.Rhs[0].(*ast.TypeAssertExpr)
+ x := isIdent(as.Lhs[0])
+ z := isIdent(ta.X)
+
+ if isBlank(x) || x != nil && z != nil && x.Name == z.Name && !assignsTo(x, ts.Body.List) {
+ // Can drop the variable creation.
+ rcvr = ta.X
+ } else {
+ // Need to use initialization statement.
+ if init != nil {
+ warn(ts.Pos(), "cannot rewrite reflect type switch with initializing statement")
+ return false
+ }
+ init = &ast.AssignStmt{
+ Lhs: []ast.Expr{as.Lhs[0]},
+ TokPos: as.TokPos,
+ Tok: token.DEFINE,
+ Rhs: []ast.Expr{ta.X},
+ }
+ rcvr = as.Lhs[0]
+ }
+
+ case *ast.ExprStmt:
+ rcvr = n.X.(*ast.TypeAssertExpr).X
+ }
+
+ // Prepare rewritten type switch (see large comment above for form).
+ sw := &ast.SwitchStmt{
+ Switch: ts.Switch,
+ Init: init,
+ Tag: &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: rcvr,
+ Sel: &ast.Ident{
+ NamePos: rcvr.End(),
+ Name: "Kind",
+ Obj: nil,
+ },
+ },
+ Lparen: rcvr.End(),
+ Rparen: rcvr.End(),
+ },
+ Body: &ast.BlockStmt{
+ Lbrace: ts.Body.Lbrace,
+ List: nil, // to be filled in
+ Rbrace: ts.Body.Rbrace,
+ },
+ }
+
+ // Translate cases.
+ for _, tcas := range ts.Body.List {
+ tcas := tcas.(*ast.CaseClause)
+ cas := &ast.CaseClause{
+ Case: tcas.Case,
+ Colon: tcas.Colon,
+ Body: tcas.Body,
+ }
+ for _, t := range tcas.List {
+ if isTopName(t, "nil") {
+ cas.List = append(cas.List, newPkgDot(t.Pos(), "reflect", "Invalid"))
+ continue
+ }
+
+ typ := reflectType(t)
+ if typ == "" {
+ warn(t.Pos(), "cannot rewrite reflect type switch case with non-reflect type %s", gofmt(t))
+ cas.List = append(cas.List, t)
+ continue
+ }
+
+ for _, k := range reflectKind[typ] {
+ cas.List = append(cas.List, newPkgDot(t.Pos(), "reflect", k))
+ }
+ }
+ sw.Body.List = append(sw.Body.List, cas)
+ }
+
+ // Everything worked. Rewrite AST.
+ *ptr = sw
+ return true
+}
+
+// Rewrite x, y = z.(T) into
+// x = z
+// y = x.Kind() == K
+// as described in the long comment above.
+//
+// If insert != nil, it can be called to insert a statement after *ptr in its block.
+// If insert == nil, insertion is not possible.
+// At most one call to insert is allowed.
+//
+// Scope gives the statements for which a declaration
+// in *ptr would be in scope.
+//
+// The result is true if the statement was rewritten.
+//
+func reflectFixTypecheck(ptr *ast.Stmt, insert func(ast.Stmt), scope []ast.Stmt) bool {
+ st := *ptr
+ as, ok := st.(*ast.AssignStmt)
+ if !ok || len(as.Lhs) != 2 || len(as.Rhs) != 1 {
+ return false
+ }
+
+ ta, ok := as.Rhs[0].(*ast.TypeAssertExpr)
+ if !ok {
+ return false
+ }
+ typ := reflectType(ta.Type)
+ if typ == "" {
+ return false
+ }
+
+ // Have x, y := z.(t).
+ x := isIdent(as.Lhs[0])
+ y := isIdent(as.Lhs[1])
+ z := isIdent(ta.X)
+
+ // First step is x := z, unless it's x := x and the resulting x is never reassigned.
+ // rcvr is the x in x.Kind().
+ var rcvr ast.Expr
+ if isBlank(x) ||
+ as.Tok == token.DEFINE && x != nil && z != nil && x.Name == z.Name && !assignsTo(x, scope) {
+ // Can drop the statement.
+ // If we need to insert a statement later, now we have a slot.
+ *ptr = &ast.EmptyStmt{}
+ insert = func(x ast.Stmt) { *ptr = x }
+ rcvr = ta.X
+ } else {
+ *ptr = &ast.AssignStmt{
+ Lhs: []ast.Expr{as.Lhs[0]},
+ TokPos: as.TokPos,
+ Tok: as.Tok,
+ Rhs: []ast.Expr{ta.X},
+ }
+ rcvr = as.Lhs[0]
+ }
+
+ // Prepare x.Kind() == T expression appropriate to t.
+ // If x is not a simple identifier, warn that we might be
+ // reevaluating x.
+ if x == nil {
+ warn(as.Pos(), "rewrite reevaluates expr with possible side effects: %s", gofmt(as.Lhs[0]))
+ }
+ yExpr, yNotExpr := reflectKindEq(rcvr, reflectKind[typ])
+
+ // Second step is y := x.Kind() == T, unless it's only used once
+ // or we have no way to insert that statement.
+ var yStmt *ast.AssignStmt
+ if as.Tok == token.DEFINE && countUses(y, scope) <= 1 || insert == nil {
+ // Can drop the statement and use the expression directly.
+ rewriteUses(y,
+ func(token.Pos) ast.Expr { return yExpr },
+ func(token.Pos) ast.Expr { return yNotExpr },
+ scope)
+ } else {
+ yStmt = &ast.AssignStmt{
+ Lhs: []ast.Expr{as.Lhs[1]},
+ TokPos: as.End(),
+ Tok: as.Tok,
+ Rhs: []ast.Expr{yExpr},
+ }
+ insert(yStmt)
+ }
+ return true
+}
+
+// reflectKindEq returns the expression z.Kind() == kinds[0] || z.Kind() == kinds[1] || ...
+// and its negation.
+// The qualifier "reflect." is inserted before each kinds[i] expression.
+func reflectKindEq(z ast.Expr, kinds []string) (ast.Expr, ast.Expr) {
+ n := len(kinds)
+ if n == 1 {
+ y := &ast.BinaryExpr{
+ X: &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: z,
+ Sel: ast.NewIdent("Kind"),
+ },
+ },
+ Op: token.EQL,
+ Y: newPkgDot(token.NoPos, "reflect", kinds[0]),
+ }
+ ynot := &ast.BinaryExpr{
+ X: &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: z,
+ Sel: ast.NewIdent("Kind"),
+ },
+ },
+ Op: token.NEQ,
+ Y: newPkgDot(token.NoPos, "reflect", kinds[0]),
+ }
+ return y, ynot
+ }
+
+ x, xnot := reflectKindEq(z, kinds[0:n-1])
+ y, ynot := reflectKindEq(z, kinds[n-1:])
+
+ or := &ast.BinaryExpr{
+ X: x,
+ Op: token.LOR,
+ Y: y,
+ }
+ andnot := &ast.BinaryExpr{
+ X: xnot,
+ Op: token.LAND,
+ Y: ynot,
+ }
+ return or, andnot
+}
+
+// If x represents a known old reflect type/value like *reflect.PtrType or reflect.ArrayOrSliceValue,
+// reflectType returns the string form of that type; otherwise it returns the empty string.
+func reflectType(x ast.Expr) string {
+ ptr, ok := x.(*ast.StarExpr)
+ if ok {
+ x = ptr.X
+ }
+
+ sel, ok := x.(*ast.SelectorExpr)
+ if !ok || !isName(sel.X, "reflect") {
+ return ""
+ }
+
+ var s = "reflect."
+ if ptr != nil {
+ s = "*reflect."
+ }
+ s += sel.Sel.Name
+
+ if reflectKind[s] != nil {
+ return s
+ }
+ return ""
+}
+
+// reflectWarnTypecheckStmt warns about statements
+// of the form x, y = z.(T) for any old reflect type T.
+// The last pass should have gotten them all, and if it didn't,
+// the next pass is going to turn them into x, y = z.
+func reflectWarnTypecheckStmt(n interface{}) {
+ as, ok := n.(*ast.AssignStmt)
+ if !ok || len(as.Lhs) != 2 || len(as.Rhs) != 1 {
+ return
+ }
+ ta, ok := as.Rhs[0].(*ast.TypeAssertExpr)
+ if !ok || reflectType(ta.Type) == "" {
+ return
+ }
+ warn(n.(ast.Node).Pos(), "unfixed reflect type check")
+}
+
+// reflectFixAssert rewrites x.(T) to x for any old reflect type T.
+func reflectFixAssert(n interface{}) bool {
+ ptr, ok := n.(*ast.Expr)
+ if ok {
+ ta, ok := (*ptr).(*ast.TypeAssertExpr)
+ if ok && reflectType(ta.Type) != "" {
+ *ptr = ta.X
+ return true
+ }
+ }
+ return false
+}
+
+// Tables describing the transformations.
+
+// Description of old reflect API for partial type checking.
+// We pretend the Elem method is on Type and Value instead
+// of enumerating all the types it is actually on.
+// Also, we pretend that ArrayType etc embeds Type for the
+// purposes of describing the API. (In fact they embed commonType,
+// which implements Type.)
+var reflectTypeConfig = &TypeConfig{
+ Type: map[string]*Type{
+ "reflect.ArrayOrSliceType": {Embed: []string{"reflect.Type"}},
+ "reflect.ArrayOrSliceValue": {Embed: []string{"reflect.Value"}},
+ "reflect.ArrayType": {Embed: []string{"reflect.Type"}},
+ "reflect.ArrayValue": {Embed: []string{"reflect.Value"}},
+ "reflect.BoolType": {Embed: []string{"reflect.Type"}},
+ "reflect.BoolValue": {Embed: []string{"reflect.Value"}},
+ "reflect.ChanType": {Embed: []string{"reflect.Type"}},
+ "reflect.ChanValue": {
+ Method: map[string]string{
+ "Recv": "func() (reflect.Value, bool)",
+ "TryRecv": "func() (reflect.Value, bool)",
+ },
+ Embed: []string{"reflect.Value"},
+ },
+ "reflect.ComplexType": {Embed: []string{"reflect.Type"}},
+ "reflect.ComplexValue": {Embed: []string{"reflect.Value"}},
+ "reflect.FloatType": {Embed: []string{"reflect.Type"}},
+ "reflect.FloatValue": {Embed: []string{"reflect.Value"}},
+ "reflect.FuncType": {
+ Method: map[string]string{
+ "In": "func(int) reflect.Type",
+ "Out": "func(int) reflect.Type",
+ },
+ Embed: []string{"reflect.Type"},
+ },
+ "reflect.FuncValue": {
+ Method: map[string]string{
+ "Call": "func([]reflect.Value) []reflect.Value",
+ },
+ },
+ "reflect.IntType": {Embed: []string{"reflect.Type"}},
+ "reflect.IntValue": {Embed: []string{"reflect.Value"}},
+ "reflect.InterfaceType": {Embed: []string{"reflect.Type"}},
+ "reflect.InterfaceValue": {Embed: []string{"reflect.Value"}},
+ "reflect.MapType": {
+ Method: map[string]string{
+ "Key": "func() reflect.Type",
+ },
+ Embed: []string{"reflect.Type"},
+ },
+ "reflect.MapValue": {
+ Method: map[string]string{
+ "Keys": "func() []reflect.Value",
+ },
+ Embed: []string{"reflect.Value"},
+ },
+ "reflect.Method": {
+ Field: map[string]string{
+ "Type": "*reflect.FuncType",
+ "Func": "*reflect.FuncValue",
+ },
+ },
+ "reflect.PtrType": {Embed: []string{"reflect.Type"}},
+ "reflect.PtrValue": {Embed: []string{"reflect.Value"}},
+ "reflect.SliceType": {Embed: []string{"reflect.Type"}},
+ "reflect.SliceValue": {
+ Method: map[string]string{
+ "Slice": "func(int, int) *reflect.SliceValue",
+ },
+ Embed: []string{"reflect.Value"},
+ },
+ "reflect.StringType": {Embed: []string{"reflect.Type"}},
+ "reflect.StringValue": {Embed: []string{"reflect.Value"}},
+ "reflect.StructField": {
+ Field: map[string]string{
+ "Type": "reflect.Type",
+ },
+ },
+ "reflect.StructType": {
+ Method: map[string]string{
+ "Field": "func() reflect.StructField",
+ "FieldByIndex": "func() reflect.StructField",
+ "FieldByName": "func() reflect.StructField,bool",
+ "FieldByNameFunc": "func() reflect.StructField,bool",
+ },
+ Embed: []string{"reflect.Type"},
+ },
+ "reflect.StructValue": {
+ Method: map[string]string{
+ "Field": "func() reflect.Value",
+ "FieldByIndex": "func() reflect.Value",
+ "FieldByName": "func() reflect.Value",
+ "FieldByNameFunc": "func() reflect.Value",
+ },
+ Embed: []string{"reflect.Value"},
+ },
+ "reflect.Type": {
+ Method: map[string]string{
+ "Elem": "func() reflect.Type",
+ "Method": "func() reflect.Method",
+ },
+ },
+ "reflect.UintType": {Embed: []string{"reflect.Type"}},
+ "reflect.UintValue": {Embed: []string{"reflect.Value"}},
+ "reflect.UnsafePointerType": {Embed: []string{"reflect.Type"}},
+ "reflect.UnsafePointerValue": {Embed: []string{"reflect.Value"}},
+ "reflect.Value": {
+ Method: map[string]string{
+ "Addr": "func() *reflect.PtrValue",
+ "Elem": "func() reflect.Value",
+ "Method": "func() *reflect.FuncValue",
+ "SetValue": "func(reflect.Value)",
+ },
+ },
+ },
+ Func: map[string]string{
+ "reflect.Append": "*reflect.SliceValue",
+ "reflect.AppendSlice": "*reflect.SliceValue",
+ "reflect.Indirect": "reflect.Value",
+ "reflect.MakeSlice": "*reflect.SliceValue",
+ "reflect.MakeChan": "*reflect.ChanValue",
+ "reflect.MakeMap": "*reflect.MapValue",
+ "reflect.MakeZero": "reflect.Value",
+ "reflect.NewValue": "reflect.Value",
+ "reflect.PtrTo": "*reflect.PtrType",
+ "reflect.Typeof": "reflect.Type",
+ },
+}
+
+var reflectRewriteMethod = map[string]map[string]string{
+ // The type API didn't change much.
+ "*reflect.ChanType": {"Dir": "ChanDir"},
+ "*reflect.FuncType": {"DotDotDot": "IsVariadic"},
+
+ // The value API has longer names to disambiguate
+ // methods with different signatures.
+ "reflect.ArrayOrSliceValue": { // interface, not pointer
+ "Elem": "Index",
+ },
+ "*reflect.ArrayValue": {
+ "Elem": "Index",
+ },
+ "*reflect.BoolValue": {
+ "Get": "Bool",
+ "Set": "SetBool",
+ },
+ "*reflect.ChanValue": {
+ "Get": "Pointer",
+ },
+ "*reflect.ComplexValue": {
+ "Get": "Complex",
+ "Set": "SetComplex",
+ "Overflow": "OverflowComplex",
+ },
+ "*reflect.FloatValue": {
+ "Get": "Float",
+ "Set": "SetFloat",
+ "Overflow": "OverflowFloat",
+ },
+ "*reflect.FuncValue": {
+ "Get": "Pointer",
+ },
+ "*reflect.IntValue": {
+ "Get": "Int",
+ "Set": "SetInt",
+ "Overflow": "OverflowInt",
+ },
+ "*reflect.InterfaceValue": {
+ "Get": "InterfaceData",
+ },
+ "*reflect.MapValue": {
+ "Elem": "MapIndex",
+ "Get": "Pointer",
+ "Keys": "MapKeys",
+ "SetElem": "SetMapIndex",
+ },
+ "*reflect.PtrValue": {
+ "Get": "Pointer",
+ },
+ "*reflect.SliceValue": {
+ "Elem": "Index",
+ "Get": "Pointer",
+ },
+ "*reflect.StringValue": {
+ "Get": "String",
+ "Set": "SetString",
+ },
+ "*reflect.UintValue": {
+ "Get": "Uint",
+ "Set": "SetUint",
+ "Overflow": "OverflowUint",
+ },
+ "*reflect.UnsafePointerValue": {
+ "Get": "Pointer",
+ "Set": "SetPointer",
+ },
+}
+
+var reflectKind = map[string][]string{
+ "reflect.ArrayOrSliceType": {"Array", "Slice"}, // interface, not pointer
+ "*reflect.ArrayType": {"Array"},
+ "*reflect.BoolType": {"Bool"},
+ "*reflect.ChanType": {"Chan"},
+ "*reflect.ComplexType": {"Complex64", "Complex128"},
+ "*reflect.FloatType": {"Float32", "Float64"},
+ "*reflect.FuncType": {"Func"},
+ "*reflect.IntType": {"Int", "Int8", "Int16", "Int32", "Int64"},
+ "*reflect.InterfaceType": {"Interface"},
+ "*reflect.MapType": {"Map"},
+ "*reflect.PtrType": {"Ptr"},
+ "*reflect.SliceType": {"Slice"},
+ "*reflect.StringType": {"String"},
+ "*reflect.StructType": {"Struct"},
+ "*reflect.UintType": {"Uint", "Uint8", "Uint16", "Uint32", "Uint64", "Uintptr"},
+ "*reflect.UnsafePointerType": {"UnsafePointer"},
+
+ "reflect.ArrayOrSliceValue": {"Array", "Slice"}, // interface, not pointer
+ "*reflect.ArrayValue": {"Array"},
+ "*reflect.BoolValue": {"Bool"},
+ "*reflect.ChanValue": {"Chan"},
+ "*reflect.ComplexValue": {"Complex64", "Complex128"},
+ "*reflect.FloatValue": {"Float32", "Float64"},
+ "*reflect.FuncValue": {"Func"},
+ "*reflect.IntValue": {"Int", "Int8", "Int16", "Int32", "Int64"},
+ "*reflect.InterfaceValue": {"Interface"},
+ "*reflect.MapValue": {"Map"},
+ "*reflect.PtrValue": {"Ptr"},
+ "*reflect.SliceValue": {"Slice"},
+ "*reflect.StringValue": {"String"},
+ "*reflect.StructValue": {"Struct"},
+ "*reflect.UintValue": {"Uint", "Uint8", "Uint16", "Uint32", "Uint64", "Uintptr"},
+ "*reflect.UnsafePointerValue": {"UnsafePointer"},
+}
+
+var isReflectValue = map[string]bool{
+ "reflect.ArrayOrSliceValue": true, // interface, not pointer
+ "*reflect.ArrayValue": true,
+ "*reflect.BoolValue": true,
+ "*reflect.ChanValue": true,
+ "*reflect.ComplexValue": true,
+ "*reflect.FloatValue": true,
+ "*reflect.FuncValue": true,
+ "*reflect.IntValue": true,
+ "*reflect.InterfaceValue": true,
+ "*reflect.MapValue": true,
+ "*reflect.PtrValue": true,
+ "*reflect.SliceValue": true,
+ "*reflect.StringValue": true,
+ "*reflect.StructValue": true,
+ "*reflect.UintValue": true,
+ "*reflect.UnsafePointerValue": true,
+ "reflect.Value": true, // interface, not pointer
+}
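
To make the multi-kind expansion concrete, the sketch below builds the kind-comparison expression that a check against *reflect.IntType turns into. kindEq is a hypothetical stand-in for this file's reflectKindEq: it builds only the positive form (no negated twin) and uses a dotted identifier instead of the newPkgDot helper.

// Sketch: print t.Kind() == reflect.Int || ... || t.Kind() == reflect.Int64.
package main

import (
    "go/ast"
    "go/printer"
    "go/token"
    "os"
)

func kindEq(recv string, kinds []string) ast.Expr {
    var e ast.Expr
    for _, k := range kinds {
        cmp := &ast.BinaryExpr{
            X: &ast.CallExpr{Fun: &ast.SelectorExpr{
                X:   ast.NewIdent(recv),
                Sel: ast.NewIdent("Kind"),
            }},
            Op: token.EQL,
            Y:  ast.NewIdent("reflect." + k),
        }
        if e == nil {
            e = cmp
        } else {
            e = &ast.BinaryExpr{X: e, Op: token.LOR, Y: cmp}
        }
    }
    return e
}

func main() {
    e := kindEq("t", []string{"Int", "Int8", "Int16", "Int32", "Int64"})
    printer.Fprint(os.Stdout, token.NewFileSet(), e)
}
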
diff --git a/src/cmd/fix/reflect_test.go b/src/cmd/fix/reflect_test.go
new file mode 100644
index 000000000..032cbc745
--- /dev/null
+++ b/src/cmd/fix/reflect_test.go
@@ -0,0 +1,35 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "io/ioutil"
+ "log"
+ "path/filepath"
+)
+
+func init() {
+ addTestCases(reflectTests(), reflectFn)
+}
+
+func reflectTests() []testCase {
+ var tests []testCase
+
+ names, _ := filepath.Glob("testdata/reflect.*.in")
+ for _, in := range names {
+ out := in[:len(in)-len(".in")] + ".out"
+ inb, err := ioutil.ReadFile(in)
+ if err != nil {
+ log.Fatal(err)
+ }
+ outb, err := ioutil.ReadFile(out)
+ if err != nil {
+ log.Fatal(err)
+ }
+ tests = append(tests, testCase{Name: in, In: string(inb), Out: string(outb)})
+ }
+
+ return tests
+}
diff --git a/src/cmd/fix/signal.go b/src/cmd/fix/signal.go
new file mode 100644
index 000000000..5a583d41e
--- /dev/null
+++ b/src/cmd/fix/signal.go
@@ -0,0 +1,50 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+ "strings"
+)
+
+func init() {
+ register(signalFix)
+}
+
+var signalFix = fix{
+ "signal",
+ "2011-06-29",
+ signal,
+ `Adapt code to types moved from os/signal to os.
+
+http://codereview.appspot.com/4437091
+`,
+}
+
+func signal(f *ast.File) (fixed bool) {
+ if !imports(f, "os/signal") {
+ return
+ }
+
+ walk(f, func(n interface{}) {
+ s, ok := n.(*ast.SelectorExpr)
+
+ if !ok || !isTopName(s.X, "signal") {
+ return
+ }
+
+ sel := s.Sel.String()
+ if sel == "Signal" || sel == "UnixSignal" || strings.HasPrefix(sel, "SIG") {
+ addImport(f, "os")
+ s.X = &ast.Ident{Name: "os"}
+ fixed = true
+ }
+ })
+
+ if fixed && !usesImport(f, "os/signal") {
+ deleteImport(f, "os/signal")
+ }
+ return
+}
diff --git a/src/cmd/fix/signal_test.go b/src/cmd/fix/signal_test.go
new file mode 100644
index 000000000..7bca7d5c4
--- /dev/null
+++ b/src/cmd/fix/signal_test.go
@@ -0,0 +1,94 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(signalTests, signal)
+}
+
+var signalTests = []testCase{
+ {
+ Name: "signal.0",
+ In: `package main
+
+import (
+ _ "a"
+ "os/signal"
+ _ "z"
+)
+
+type T1 signal.UnixSignal
+type T2 signal.Signal
+
+func f() {
+ _ = signal.SIGHUP
+ _ = signal.Incoming
+}
+`,
+ Out: `package main
+
+import (
+ _ "a"
+ "os"
+ "os/signal"
+ _ "z"
+)
+
+type T1 os.UnixSignal
+type T2 os.Signal
+
+func f() {
+ _ = os.SIGHUP
+ _ = signal.Incoming
+}
+`,
+ },
+ {
+ Name: "signal.1",
+ In: `package main
+
+import (
+ "os"
+ "os/signal"
+)
+
+func f() {
+ var _ os.Error
+ _ = signal.SIGHUP
+}
+`,
+ Out: `package main
+
+import "os"
+
+func f() {
+ var _ os.Error
+ _ = os.SIGHUP
+}
+`,
+ },
+ {
+ Name: "signal.2",
+ In: `package main
+
+import "os"
+import "os/signal"
+
+func f() {
+ var _ os.Error
+ _ = signal.SIGHUP
+}
+`,
+ Out: `package main
+
+import "os"
+
+func f() {
+ var _ os.Error
+ _ = os.SIGHUP
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/sorthelpers.go b/src/cmd/fix/sorthelpers.go
new file mode 100644
index 000000000..fa549313e
--- /dev/null
+++ b/src/cmd/fix/sorthelpers.go
@@ -0,0 +1,49 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(sorthelpersFix)
+}
+
+var sorthelpersFix = fix{
+ "sorthelpers",
+ "2011-07-08",
+ sorthelpers,
+ `Adapt code from sort.Sort[Ints|Float64s|Strings] to sort.[Ints|Float64s|Strings].
+`,
+}
+
+func sorthelpers(f *ast.File) (fixed bool) {
+ if !imports(f, "sort") {
+ return
+ }
+
+ walk(f, func(n interface{}) {
+ s, ok := n.(*ast.SelectorExpr)
+ if !ok || !isTopName(s.X, "sort") {
+ return
+ }
+
+ switch s.Sel.String() {
+ case "SortFloat64s":
+ s.Sel.Name = "Float64s"
+ case "SortInts":
+ s.Sel.Name = "Ints"
+ case "SortStrings":
+ s.Sel.Name = "Strings"
+ default:
+ return
+ }
+
+ fixed = true
+ })
+
+ return
+}
diff --git a/src/cmd/fix/sorthelpers_test.go b/src/cmd/fix/sorthelpers_test.go
new file mode 100644
index 000000000..dd6b58e03
--- /dev/null
+++ b/src/cmd/fix/sorthelpers_test.go
@@ -0,0 +1,45 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(sorthelpersTests, sorthelpers)
+}
+
+var sorthelpersTests = []testCase{
+ {
+ Name: "sortslice.0",
+ In: `package main
+
+import (
+ "sort"
+)
+
+func main() {
+ var s []string
+ sort.SortStrings(s)
+ var i []ints
+ sort.SortInts(i)
+ var f []float64
+ sort.SortFloat64s(f)
+}
+`,
+ Out: `package main
+
+import (
+ "sort"
+)
+
+func main() {
+ var s []string
+ sort.Strings(s)
+ var i []ints
+ sort.Ints(i)
+ var f []float64
+ sort.Float64s(f)
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/sortslice.go b/src/cmd/fix/sortslice.go
new file mode 100644
index 000000000..89267b847
--- /dev/null
+++ b/src/cmd/fix/sortslice.go
@@ -0,0 +1,52 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(sortsliceFix)
+}
+
+var sortsliceFix = fix{
+ "sortslice",
+ "2011-06-26",
+ sortslice,
+ `Adapt code from sort.[Float64|Int|String]Array to sort.[Float64|Int|String]Slice.
+
+http://codereview.appspot.com/4602054
+http://codereview.appspot.com/4639041
+`,
+}
+
+func sortslice(f *ast.File) (fixed bool) {
+ if !imports(f, "sort") {
+ return
+ }
+
+ walk(f, func(n interface{}) {
+ s, ok := n.(*ast.SelectorExpr)
+ if !ok || !isTopName(s.X, "sort") {
+ return
+ }
+
+ switch s.Sel.String() {
+ case "Float64Array":
+ s.Sel.Name = "Float64Slice"
+ case "IntArray":
+ s.Sel.Name = "IntSlice"
+ case "StringArray":
+ s.Sel.Name = "StringSlice"
+ default:
+ return
+ }
+
+ fixed = true
+ })
+
+ return
+}
diff --git a/src/cmd/fix/sortslice_test.go b/src/cmd/fix/sortslice_test.go
new file mode 100644
index 000000000..7b745a232
--- /dev/null
+++ b/src/cmd/fix/sortslice_test.go
@@ -0,0 +1,35 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(sortsliceTests, sortslice)
+}
+
+var sortsliceTests = []testCase{
+ {
+ Name: "sortslice.0",
+ In: `package main
+
+import (
+ "sort"
+)
+
+var _ = sort.Float64Array
+var _ = sort.IntArray
+var _ = sort.StringArray
+`,
+ Out: `package main
+
+import (
+ "sort"
+)
+
+var _ = sort.Float64Slice
+var _ = sort.IntSlice
+var _ = sort.StringSlice
+`,
+ },
+}
diff --git a/src/cmd/fix/strconv.go b/src/cmd/fix/strconv.go
new file mode 100644
index 000000000..6cd69020b
--- /dev/null
+++ b/src/cmd/fix/strconv.go
@@ -0,0 +1,127 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "go/ast"
+
+func init() {
+ register(strconvFix)
+}
+
+var strconvFix = fix{
+ "strconv",
+ "2011-12-01",
+ strconvFn,
+ `Convert to new strconv API.
+
+http://codereview.appspot.com/5434095
+http://codereview.appspot.com/5434069
+`,
+}
+
+func strconvFn(f *ast.File) bool {
+ if !imports(f, "strconv") {
+ return false
+ }
+
+ fixed := false
+
+ walk(f, func(n interface{}) {
+ // Rename functions.
+ call, ok := n.(*ast.CallExpr)
+ if !ok || len(call.Args) < 1 {
+ return
+ }
+ sel, ok := call.Fun.(*ast.SelectorExpr)
+ if !ok || !isTopName(sel.X, "strconv") {
+ return
+ }
+ change := func(name string) {
+ fixed = true
+ sel.Sel.Name = name
+ }
+ add := func(s string) {
+ call.Args = append(call.Args, expr(s))
+ }
+ switch sel.Sel.Name {
+ case "Atob":
+ change("ParseBool")
+ case "Atof32":
+ change("ParseFloat")
+ add("32") // bitSize
+ warn(call.Pos(), "rewrote strconv.Atof32(_) to strconv.ParseFloat(_, 32) but return value must be converted to float32")
+ case "Atof64":
+ change("ParseFloat")
+ add("64") // bitSize
+ case "AtofN":
+ change("ParseFloat")
+ case "Atoi":
+ // Atoi stayed as a convenience wrapper.
+ case "Atoi64":
+ change("ParseInt")
+ add("10") // base
+ add("64") // bitSize
+ case "Atoui":
+ change("ParseUint")
+ add("10") // base
+ add("0") // bitSize
+ warn(call.Pos(), "rewrote strconv.Atoui(_) to strconv.ParseUint(_, 10, 0) but return value must be converted to uint")
+ case "Atoui64":
+ change("ParseUint")
+ add("10") // base
+ add("64") // bitSize
+ case "Btoa":
+ change("FormatBool")
+ case "Btoi64":
+ change("ParseInt")
+ add("64") // bitSize
+ case "Btoui64":
+ change("ParseUint")
+ add("64") // bitSize
+ case "Ftoa32":
+ change("FormatFloat")
+ call.Args[0] = strconvRewrite("float32", "float64", call.Args[0])
+ add("32") // bitSize
+ case "Ftoa64":
+ change("FormatFloat")
+ add("64") // bitSize
+ case "FtoaN":
+ change("FormatFloat")
+ case "Itoa":
+ // Itoa stayed as a convenience wrapper.
+ case "Itoa64":
+ change("FormatInt")
+ add("10") // base
+ case "Itob":
+ change("FormatInt")
+ call.Args[0] = strconvRewrite("int", "int64", call.Args[0])
+ case "Itob64":
+ change("FormatInt")
+ case "Uitoa":
+ change("FormatUint")
+ call.Args[0] = strconvRewrite("uint", "uint64", call.Args[0])
+ add("10") // base
+ case "Uitoa64":
+ change("FormatUint")
+ add("10") // base
+ case "Uitob":
+ change("FormatUint")
+ call.Args[0] = strconvRewrite("uint", "uint64", call.Args[0])
+ case "Uitob64":
+ change("FormatUint")
+ }
+ })
+ return fixed
+}
+
+// Rewrite from type t1 to type t2.
+// If the expression x is of the form t1(_), use t2(_). Otherwise use t2(x).
+func strconvRewrite(t1, t2 string, x ast.Expr) ast.Expr {
+ if call, ok := x.(*ast.CallExpr); ok && isTopName(call.Fun, t1) {
+ call.Fun.(*ast.Ident).Name = t2
+ return x
+ }
+ return &ast.CallExpr{Fun: ast.NewIdent(t2), Args: []ast.Expr{x}}
+}
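
A standalone sketch of the strconvRewrite idea follows: wrap an argument in a t2 conversion unless it is already a t1 conversion that can be renamed in place. The rewrite function below is a simplified stand-in (it checks a plain identifier rather than using isTopName), and the two sample expressions are made up.

// Sketch: uint(x) becomes uint64(x); f(x) becomes uint64(f(x)).
package main

import (
    "go/ast"
    "go/parser"
    "go/printer"
    "go/token"
    "os"
)

func rewrite(t1, t2 string, x ast.Expr) ast.Expr {
    if call, ok := x.(*ast.CallExpr); ok {
        if id, ok := call.Fun.(*ast.Ident); ok && id.Name == t1 {
            id.Name = t2 // rename the conversion in place
            return x
        }
    }
    return &ast.CallExpr{Fun: ast.NewIdent(t2), Args: []ast.Expr{x}} // wrap in a new conversion
}

func main() {
    a, err := parser.ParseExpr("uint(x)")
    if err != nil {
        panic(err)
    }
    b, err := parser.ParseExpr("f(x)")
    if err != nil {
        panic(err)
    }
    fset := token.NewFileSet()
    printer.Fprint(os.Stdout, fset, rewrite("uint", "uint64", a)) // uint64(x)
    os.Stdout.WriteString("\n")
    printer.Fprint(os.Stdout, fset, rewrite("uint", "uint64", b)) // uint64(f(x))
    os.Stdout.WriteString("\n")
}
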
diff --git a/src/cmd/fix/strconv_test.go b/src/cmd/fix/strconv_test.go
new file mode 100644
index 000000000..7fbd4e42e
--- /dev/null
+++ b/src/cmd/fix/strconv_test.go
@@ -0,0 +1,93 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(strconvTests, strconvFn)
+}
+
+var strconvTests = []testCase{
+ {
+ Name: "strconv.0",
+ In: `package main
+
+import "strconv"
+
+func f() {
+ foo.Atob("abc")
+
+ strconv.Atob("true")
+ strconv.Btoa(false)
+
+ strconv.Atof32("1.2")
+ strconv.Atof64("1.2")
+ strconv.AtofN("1.2", 64)
+ strconv.Ftoa32(1.2, 'g', 17)
+ strconv.Ftoa64(1.2, 'g', 17)
+ strconv.FtoaN(1.2, 'g', 17, 64)
+
+ strconv.Atoi("3")
+ strconv.Atoi64("3")
+ strconv.Btoi64("1234", 5)
+
+ strconv.Atoui("3")
+ strconv.Atoui64("3")
+ strconv.Btoui64("1234", 5)
+
+ strconv.Itoa(123)
+ strconv.Itoa64(1234)
+ strconv.Itob(123, 5)
+ strconv.Itob64(1234, 5)
+
+ strconv.Uitoa(123)
+ strconv.Uitoa64(1234)
+ strconv.Uitob(123, 5)
+ strconv.Uitob64(1234, 5)
+
+ strconv.Uitoa(uint(x))
+ strconv.Uitoa(f(x))
+}
+`,
+ Out: `package main
+
+import "strconv"
+
+func f() {
+ foo.Atob("abc")
+
+ strconv.ParseBool("true")
+ strconv.FormatBool(false)
+
+ strconv.ParseFloat("1.2", 32)
+ strconv.ParseFloat("1.2", 64)
+ strconv.ParseFloat("1.2", 64)
+ strconv.FormatFloat(float64(1.2), 'g', 17, 32)
+ strconv.FormatFloat(1.2, 'g', 17, 64)
+ strconv.FormatFloat(1.2, 'g', 17, 64)
+
+ strconv.Atoi("3")
+ strconv.ParseInt("3", 10, 64)
+ strconv.ParseInt("1234", 5, 64)
+
+ strconv.ParseUint("3", 10, 0)
+ strconv.ParseUint("3", 10, 64)
+ strconv.ParseUint("1234", 5, 64)
+
+ strconv.Itoa(123)
+ strconv.FormatInt(1234, 10)
+ strconv.FormatInt(int64(123), 5)
+ strconv.FormatInt(1234, 5)
+
+ strconv.FormatUint(uint64(123), 10)
+ strconv.FormatUint(1234, 10)
+ strconv.FormatUint(uint64(123), 5)
+ strconv.FormatUint(1234, 5)
+
+ strconv.FormatUint(uint64(x), 10)
+ strconv.FormatUint(uint64(f(x)), 10)
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/stringssplit.go b/src/cmd/fix/stringssplit.go
new file mode 100644
index 000000000..d89ecf039
--- /dev/null
+++ b/src/cmd/fix/stringssplit.go
@@ -0,0 +1,72 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+ "go/token"
+)
+
+func init() {
+ register(stringssplitFix)
+}
+
+var stringssplitFix = fix{
+ "stringssplit",
+ "2011-06-28",
+ stringssplit,
+ `Restore strings.Split to its original meaning and add strings.SplitN. Bytes too.
+
+http://codereview.appspot.com/4661051
+`,
+}
+
+func stringssplit(f *ast.File) bool {
+ if !imports(f, "bytes") && !imports(f, "strings") {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ // func Split(s, sep string, n int) []string
+ // func SplitAfter(s, sep string, n int) []string
+ if !ok || len(call.Args) != 3 {
+ return
+ }
+ // Is this our function?
+ switch {
+ case isPkgDot(call.Fun, "bytes", "Split"):
+ case isPkgDot(call.Fun, "bytes", "SplitAfter"):
+ case isPkgDot(call.Fun, "strings", "Split"):
+ case isPkgDot(call.Fun, "strings", "SplitAfter"):
+ default:
+ return
+ }
+
+ sel := call.Fun.(*ast.SelectorExpr)
+ args := call.Args
+ fixed = true // We're committed.
+
+ // Is the last argument -1? If so, drop the arg.
+ // (Actually we just look for a negative integer literal.)
+ // Otherwise, Split->SplitN and keep the arg.
+ final := args[2]
+ if unary, ok := final.(*ast.UnaryExpr); ok && unary.Op == token.SUB {
+ if lit, ok := unary.X.(*ast.BasicLit); ok {
+ // Is it an integer? If so, it's a negative integer and that's what we're after.
+ if lit.Kind == token.INT {
+ // drop the last arg.
+ call.Args = args[0:2]
+ return
+ }
+ }
+ }
+
+ // If not, rename and keep the argument list.
+ sel.Sel.Name += "N"
+ })
+ return fixed
+}
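The rewrite above drops a literal -1 count (restoring the original unlimited Split) and otherwise renames Split to SplitN while keeping the count. A minimal sketch of the resulting Go 1 call forms, with arbitrary input values:

	package main

	import (
		"fmt"
		"strings"
	)

	func main() {
		// Unlimited split: was strings.Split("a,b,c", ",", -1).
		fmt.Println(strings.Split("a,b,c", ","))

		// Limited split: was strings.Split("a,b,c", ",", 2).
		fmt.Println(strings.SplitN("a,b,c", ",", 2))

		// The same renaming applies to SplitAfter / SplitAfterN,
		// and to the bytes package equivalents.
		fmt.Println(strings.SplitAfter("a,b,c", ","))
		fmt.Println(strings.SplitAfterN("a,b,c", ",", 2))
	}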
diff --git a/src/cmd/fix/stringssplit_test.go b/src/cmd/fix/stringssplit_test.go
new file mode 100644
index 000000000..fa42b1bea
--- /dev/null
+++ b/src/cmd/fix/stringssplit_test.go
@@ -0,0 +1,51 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(stringssplitTests, stringssplit)
+}
+
+var stringssplitTests = []testCase{
+ {
+ Name: "stringssplit.0",
+ In: `package main
+
+import (
+ "bytes"
+ "strings"
+)
+
+func f() {
+ bytes.Split(a, b, c)
+ bytes.Split(a, b, -1)
+ bytes.SplitAfter(a, b, c)
+ bytes.SplitAfter(a, b, -1)
+ strings.Split(a, b, c)
+ strings.Split(a, b, -1)
+ strings.SplitAfter(a, b, c)
+ strings.SplitAfter(a, b, -1)
+}
+`,
+ Out: `package main
+
+import (
+ "bytes"
+ "strings"
+)
+
+func f() {
+ bytes.SplitN(a, b, c)
+ bytes.Split(a, b)
+ bytes.SplitAfterN(a, b, c)
+ bytes.SplitAfter(a, b)
+ strings.SplitN(a, b, c)
+ strings.Split(a, b)
+ strings.SplitAfterN(a, b, c)
+ strings.SplitAfter(a, b)
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/template.go b/src/cmd/fix/template.go
new file mode 100644
index 000000000..a3dd1440b
--- /dev/null
+++ b/src/cmd/fix/template.go
@@ -0,0 +1,111 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(templateFix)
+}
+
+var templateFix = fix{
+ "template",
+ "2011-11-22",
+ template,
+ `Rewrite calls to template.ParseFile to template.ParseFiles
+
+http://codereview.appspot.com/5433048
+`,
+}
+
+var templateSetGlobals = []string{
+ "ParseSetFiles",
+ "ParseSetGlob",
+ "ParseTemplateFiles",
+ "ParseTemplateGlob",
+ "Set",
+ "SetMust",
+}
+
+var templateSetMethods = []string{
+ "ParseSetFiles",
+ "ParseSetGlob",
+ "ParseTemplateFiles",
+ "ParseTemplateGlob",
+}
+
+var templateTypeConfig = &TypeConfig{
+ Type: map[string]*Type{
+ "template.Template": {
+ Method: map[string]string{
+ "Funcs": "func() *template.Template",
+ "Delims": "func() *template.Template",
+ "Parse": "func() (*template.Template, error)",
+ "ParseFile": "func() (*template.Template, error)",
+ "ParseInSet": "func() (*template.Template, error)",
+ },
+ },
+ "template.Set": {
+ Method: map[string]string{
+ "ParseSetFiles": "func() (*template.Set, error)",
+ "ParseSetGlob": "func() (*template.Set, error)",
+ "ParseTemplateFiles": "func() (*template.Set, error)",
+ "ParseTemplateGlob": "func() (*template.Set, error)",
+ },
+ },
+ },
+
+ Func: map[string]string{
+ "template.New": "*template.Template",
+ "template.Must": "(*template.Template, error)",
+ "template.SetMust": "(*template.Set, error)",
+ },
+}
+
+func template(f *ast.File) bool {
+ if !imports(f, "text/template") && !imports(f, "html/template") {
+ return false
+ }
+
+ fixed := false
+
+ typeof, _ := typecheck(templateTypeConfig, f)
+
+ // Now update the names used by importers.
+ walk(f, func(n interface{}) {
+ if sel, ok := n.(*ast.SelectorExpr); ok {
+ // Reference to top-level function ParseFile.
+ if isPkgDot(sel, "template", "ParseFile") {
+ sel.Sel.Name = "ParseFiles"
+ fixed = true
+ return
+ }
+ // Reference to ParseFiles method.
+ if typeof[sel.X] == "*template.Template" && sel.Sel.Name == "ParseFile" {
+ sel.Sel.Name = "ParseFiles"
+ fixed = true
+ return
+ }
+ // The Set type and its functions are now gone.
+ for _, name := range templateSetGlobals {
+ if isPkgDot(sel, "template", name) {
+ warn(sel.Pos(), "reference to template.%s must be fixed manually", name)
+ return
+ }
+ }
+ // The methods of Set are now gone.
+ for _, name := range templateSetMethods {
+ if typeof[sel.X] == "*template.Set" && sel.Sel.Name == name {
+ warn(sel.Pos(), "reference to template.*Set.%s must be fixed manually", name)
+ return
+ }
+ }
+ }
+ })
+
+ return fixed
+}
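The fix above only renames ParseFile to ParseFiles, both on the package and on *template.Template; uses of the removed Set type are reported for manual attention. A small, hedged sketch of the Go 1 call forms, where "page.html" is a hypothetical file name used only for illustration:

	package main

	import (
		"log"
		"os"
		"text/template"
	)

	func main() {
		// Top-level call: was template.ParseFile("page.html").
		t, err := template.ParseFiles("page.html")
		if err != nil {
			log.Fatal(err)
		}
		if err := t.Execute(os.Stdout, nil); err != nil {
			log.Fatal(err)
		}
		// The rename applies to chained method calls too, e.g.
		//   template.New("x").Funcs(m).ParseFiles("page.html")
	}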
diff --git a/src/cmd/fix/template_test.go b/src/cmd/fix/template_test.go
new file mode 100644
index 000000000..f713a2901
--- /dev/null
+++ b/src/cmd/fix/template_test.go
@@ -0,0 +1,55 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(templateTests, template)
+}
+
+var templateTests = []testCase{
+ {
+ Name: "template.0",
+ In: `package main
+
+import (
+ "text/template"
+)
+
+func f() {
+ template.ParseFile(a)
+ var t template.Template
+ x, y := template.ParseFile()
+ template.New("x").Funcs(m).ParseFile(a) // chained method
+ // Output should complain about these as functions or methods.
+ var s *template.Set
+ s.ParseSetFiles(a)
+ template.ParseSetGlob(a)
+ s.ParseTemplateFiles(a)
+ template.ParseTemplateGlob(a)
+ x := template.SetMust(a())
+}
+`,
+ Out: `package main
+
+import (
+ "text/template"
+)
+
+func f() {
+ template.ParseFiles(a)
+ var t template.Template
+ x, y := template.ParseFiles()
+ template.New("x").Funcs(m).ParseFiles(a) // chained method
+ // Output should complain about these as functions or methods.
+ var s *template.Set
+ s.ParseSetFiles(a)
+ template.ParseSetGlob(a)
+ s.ParseTemplateFiles(a)
+ template.ParseTemplateGlob(a)
+ x := template.SetMust(a())
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/testdata/reflect.asn1.go.in b/src/cmd/fix/testdata/reflect.asn1.go.in
new file mode 100644
index 000000000..43128f6b2
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.asn1.go.in
@@ -0,0 +1,814 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The asn1 package implements parsing of DER-encoded ASN.1 data structures,
+// as defined in ITU-T Rec X.690.
+//
+// See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,''
+// http://luca.ntop.org/Teaching/Appunti/asn1.html.
+package asn1
+
+// ASN.1 is a syntax for specifying abstract objects and BER, DER, PER, XER etc
+// are different encoding formats for those objects. Here, we'll be dealing
+// with DER, the Distinguished Encoding Rules. DER is used in X.509 because
+// it's fast to parse and, unlike BER, has a unique encoding for every object.
+// When calculating hashes over objects, it's important that the resulting
+// bytes be the same at both ends and DER removes this margin of error.
+//
+// ASN.1 is very complex and this package doesn't attempt to implement
+// everything by any means.
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+ "time"
+)
+
+// A StructuralError suggests that the ASN.1 data is valid, but the Go type
+// which is receiving it doesn't match.
+type StructuralError struct {
+ Msg string
+}
+
+func (e StructuralError) String() string { return "ASN.1 structure error: " + e.Msg }
+
+// A SyntaxError suggests that the ASN.1 data is invalid.
+type SyntaxError struct {
+ Msg string
+}
+
+func (e SyntaxError) String() string { return "ASN.1 syntax error: " + e.Msg }
+
+// We start by dealing with each of the primitive types in turn.
+
+// BOOLEAN
+
+func parseBool(bytes []byte) (ret bool, err os.Error) {
+ if len(bytes) != 1 {
+ err = SyntaxError{"invalid boolean"}
+ return
+ }
+
+ return bytes[0] != 0, nil
+}
+
+// INTEGER
+
+// parseInt64 treats the given bytes as a big-endian, signed integer and
+// returns the result.
+func parseInt64(bytes []byte) (ret int64, err os.Error) {
+ if len(bytes) > 8 {
+ // We'll overflow an int64 in this case.
+ err = StructuralError{"integer too large"}
+ return
+ }
+ for bytesRead := 0; bytesRead < len(bytes); bytesRead++ {
+ ret <<= 8
+ ret |= int64(bytes[bytesRead])
+ }
+
+ // Shift up and down in order to sign extend the result.
+ ret <<= 64 - uint8(len(bytes))*8
+ ret >>= 64 - uint8(len(bytes))*8
+ return
+}
+
+// parseInt treats the given bytes as a big-endian, signed integer and returns
+// the result.
+func parseInt(bytes []byte) (int, os.Error) {
+ ret64, err := parseInt64(bytes)
+ if err != nil {
+ return 0, err
+ }
+ if ret64 != int64(int(ret64)) {
+ return 0, StructuralError{"integer too large"}
+ }
+ return int(ret64), nil
+}
+
+// BIT STRING
+
+// BitString is the structure to use when you want an ASN.1 BIT STRING type. A
+// bit string is padded up to the nearest byte in memory and the number of
+// valid bits is recorded. Padding bits will be zero.
+type BitString struct {
+ Bytes []byte // bits packed into bytes.
+ BitLength int // length in bits.
+}
+
+// At returns the bit at the given index. If the index is out of range it
+// returns 0.
+func (b BitString) At(i int) int {
+ if i < 0 || i >= b.BitLength {
+ return 0
+ }
+ x := i / 8
+ y := 7 - uint(i%8)
+ return int(b.Bytes[x]>>y) & 1
+}
+
+// RightAlign returns a slice where the padding bits are at the beginning. The
+// slice may share memory with the BitString.
+func (b BitString) RightAlign() []byte {
+ shift := uint(8 - (b.BitLength % 8))
+ if shift == 8 || len(b.Bytes) == 0 {
+ return b.Bytes
+ }
+
+ a := make([]byte, len(b.Bytes))
+ a[0] = b.Bytes[0] >> shift
+ for i := 1; i < len(b.Bytes); i++ {
+ a[i] = b.Bytes[i-1] << (8 - shift)
+ a[i] |= b.Bytes[i] >> shift
+ }
+
+ return a
+}
+
+// parseBitString parses an ASN.1 bit string from the given byte array and returns it.
+func parseBitString(bytes []byte) (ret BitString, err os.Error) {
+ if len(bytes) == 0 {
+ err = SyntaxError{"zero length BIT STRING"}
+ return
+ }
+ paddingBits := int(bytes[0])
+ if paddingBits > 7 ||
+ len(bytes) == 1 && paddingBits > 0 ||
+ bytes[len(bytes)-1]&((1<<bytes[0])-1) != 0 {
+ err = SyntaxError{"invalid padding bits in BIT STRING"}
+ return
+ }
+ ret.BitLength = (len(bytes)-1)*8 - paddingBits
+ ret.Bytes = bytes[1:]
+ return
+}
+
+// OBJECT IDENTIFIER
+
+// An ObjectIdentifier represents an ASN.1 OBJECT IDENTIFIER.
+type ObjectIdentifier []int
+
+// Equal returns true iff oi and other represent the same identifier.
+func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool {
+ if len(oi) != len(other) {
+ return false
+ }
+ for i := 0; i < len(oi); i++ {
+ if oi[i] != other[i] {
+ return false
+ }
+ }
+
+ return true
+}
+
+// parseObjectIdentifier parses an OBJECT IDENTIFIER from the given bytes and
+// returns it. An object identifier is a sequence of variable length integers
+// that are assigned in a hierarchy.
+func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
+ if len(bytes) == 0 {
+ err = SyntaxError{"zero length OBJECT IDENTIFIER"}
+ return
+ }
+
+ // In the worst case, we get two elements from the first byte (which is
+ // encoded differently) and then every varint is a single byte long.
+ s = make([]int, len(bytes)+1)
+
+ // The first byte is 40*value1 + value2:
+ s[0] = int(bytes[0]) / 40
+ s[1] = int(bytes[0]) % 40
+ i := 2
+ for offset := 1; offset < len(bytes); i++ {
+ var v int
+ v, offset, err = parseBase128Int(bytes, offset)
+ if err != nil {
+ return
+ }
+ s[i] = v
+ }
+ s = s[0:i]
+ return
+}
+
+// ENUMERATED
+
+// An Enumerated is represented as a plain int.
+type Enumerated int
+
+// FLAG
+
+// A Flag accepts any data and is set to true if present.
+type Flag bool
+
+// parseBase128Int parses a base-128 encoded int from the given offset in the
+// given byte array. It returns the value and the new offset.
+func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) {
+ offset = initOffset
+ for shifted := 0; offset < len(bytes); shifted++ {
+ if shifted > 4 {
+ err = StructuralError{"base 128 integer too large"}
+ return
+ }
+ ret <<= 7
+ b := bytes[offset]
+ ret |= int(b & 0x7f)
+ offset++
+ if b&0x80 == 0 {
+ return
+ }
+ }
+ err = SyntaxError{"truncated base 128 integer"}
+ return
+}
+
+// UTCTime
+
+func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) {
+ s := string(bytes)
+ ret, err = time.Parse("0601021504Z0700", s)
+ if err == nil {
+ return
+ }
+ ret, err = time.Parse("060102150405Z0700", s)
+ return
+}
+
+// parseGeneralizedTime parses the GeneralizedTime from the given byte array
+// and returns the resulting time.
+func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) {
+ return time.Parse("20060102150405Z0700", string(bytes))
+}
+
+// PrintableString
+
+// parsePrintableString parses an ASN.1 PrintableString from the given byte
+// array and returns it.
+func parsePrintableString(bytes []byte) (ret string, err os.Error) {
+ for _, b := range bytes {
+ if !isPrintable(b) {
+ err = SyntaxError{"PrintableString contains invalid character"}
+ return
+ }
+ }
+ ret = string(bytes)
+ return
+}
+
+// isPrintable returns true iff the given b is in the ASN.1 PrintableString set.
+func isPrintable(b byte) bool {
+ return 'a' <= b && b <= 'z' ||
+ 'A' <= b && b <= 'Z' ||
+ '0' <= b && b <= '9' ||
+ '\'' <= b && b <= ')' ||
+ '+' <= b && b <= '/' ||
+ b == ' ' ||
+ b == ':' ||
+ b == '=' ||
+ b == '?' ||
+ // This is technically not allowed in a PrintableString.
+ // However, x509 certificates with wildcard strings don't
+ // always use the correct string type so we permit it.
+ b == '*'
+}
+
+// IA5String
+
+// parseIA5String parses an ASN.1 IA5String (ASCII string) from the given
+// byte array and returns it.
+func parseIA5String(bytes []byte) (ret string, err os.Error) {
+ for _, b := range bytes {
+ if b >= 0x80 {
+ err = SyntaxError{"IA5String contains invalid character"}
+ return
+ }
+ }
+ ret = string(bytes)
+ return
+}
+
+// T61String
+
+// parseT61String parses an ASN.1 T61String (8-bit clean string) from the given
+// byte array and returns it.
+func parseT61String(bytes []byte) (ret string, err os.Error) {
+ return string(bytes), nil
+}
+
+// A RawValue represents an undecoded ASN.1 object.
+type RawValue struct {
+ Class, Tag int
+ IsCompound bool
+ Bytes []byte
+ FullBytes []byte // includes the tag and length
+}
+
+// RawContent is used to signal that the undecoded DER data needs to be
+// preserved for a struct. To use it, the first field of the struct must have
+// this type. It's an error for any of the other fields to have this type.
+type RawContent []byte
+
+// Tagging
+
+// parseTagAndLength parses an ASN.1 tag and length pair from the given offset
+// into a byte array. It returns the parsed data and the new offset. SET and
+// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
+// don't distinguish between ordered and unordered objects in this code.
+func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) {
+ offset = initOffset
+ b := bytes[offset]
+ offset++
+ ret.class = int(b >> 6)
+ ret.isCompound = b&0x20 == 0x20
+ ret.tag = int(b & 0x1f)
+
+ // If the bottom five bits are set, then the tag number is actually base 128
+ // encoded afterwards
+ if ret.tag == 0x1f {
+ ret.tag, offset, err = parseBase128Int(bytes, offset)
+ if err != nil {
+ return
+ }
+ }
+ if offset >= len(bytes) {
+ err = SyntaxError{"truncated tag or length"}
+ return
+ }
+ b = bytes[offset]
+ offset++
+ if b&0x80 == 0 {
+ // The length is encoded in the bottom 7 bits.
+ ret.length = int(b & 0x7f)
+ } else {
+ // Bottom 7 bits give the number of length bytes to follow.
+ numBytes := int(b & 0x7f)
+ // We risk overflowing a signed 32-bit number if we accept more than 3 bytes.
+ if numBytes > 3 {
+ err = StructuralError{"length too large"}
+ return
+ }
+ if numBytes == 0 {
+ err = SyntaxError{"indefinite length found (not DER)"}
+ return
+ }
+ ret.length = 0
+ for i := 0; i < numBytes; i++ {
+ if offset >= len(bytes) {
+ err = SyntaxError{"truncated tag or length"}
+ return
+ }
+ b = bytes[offset]
+ offset++
+ ret.length <<= 8
+ ret.length |= int(b)
+ }
+ }
+
+ return
+}
+
+// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
+// a number of ASN.1 values from the given byte array and returns them as a
+// slice of Go values of the given type.
+func parseSequenceOf(bytes []byte, sliceType *reflect.SliceType, elemType reflect.Type) (ret *reflect.SliceValue, err os.Error) {
+ expectedTag, compoundType, ok := getUniversalType(elemType)
+ if !ok {
+ err = StructuralError{"unknown Go type for slice"}
+ return
+ }
+
+ // First we iterate over the input and count the number of elements,
+ // checking that the types are correct in each case.
+ numElements := 0
+ for offset := 0; offset < len(bytes); {
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ // We pretend that GENERAL STRINGs are PRINTABLE STRINGs so
+ // that a sequence of them can be parsed into a []string.
+ if t.tag == tagGeneralString {
+ t.tag = tagPrintableString
+ }
+ if t.class != classUniversal || t.isCompound != compoundType || t.tag != expectedTag {
+ err = StructuralError{"sequence tag mismatch"}
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"truncated sequence"}
+ return
+ }
+ offset += t.length
+ numElements++
+ }
+ ret = reflect.MakeSlice(sliceType, numElements, numElements)
+ params := fieldParameters{}
+ offset := 0
+ for i := 0; i < numElements; i++ {
+ offset, err = parseField(ret.Elem(i), bytes, offset, params)
+ if err != nil {
+ return
+ }
+ }
+ return
+}
+
+var (
+ bitStringType = reflect.Typeof(BitString{})
+ objectIdentifierType = reflect.Typeof(ObjectIdentifier{})
+ enumeratedType = reflect.Typeof(Enumerated(0))
+ flagType = reflect.Typeof(Flag(false))
+ timeType = reflect.Typeof(&time.Time{})
+ rawValueType = reflect.Typeof(RawValue{})
+ rawContentsType = reflect.Typeof(RawContent(nil))
+)
+
+// invalidLength returns true iff offset + length > sliceLength, or if the
+// addition would overflow.
+func invalidLength(offset, length, sliceLength int) bool {
+ return offset+length < offset || offset+length > sliceLength
+}
+
+// parseField is the main parsing function. Given a byte array and an offset
+// into the array, it will try to parse a suitable ASN.1 value out and store it
+// in the given Value.
+func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) {
+ offset = initOffset
+ fieldType := v.Type()
+
+ // If we have run out of data, it may be that there are optional elements at the end.
+ if offset == len(bytes) {
+ if !setDefaultValue(v, params) {
+ err = SyntaxError{"sequence truncated"}
+ }
+ return
+ }
+
+ // Deal with raw values.
+ if fieldType == rawValueType {
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]}
+ offset += t.length
+ v.(*reflect.StructValue).Set(reflect.NewValue(result).(*reflect.StructValue))
+ return
+ }
+
+ // Deal with the ANY type.
+ if ifaceType, ok := fieldType.(*reflect.InterfaceType); ok && ifaceType.NumMethod() == 0 {
+ ifaceValue := v.(*reflect.InterfaceValue)
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ var result interface{}
+ if !t.isCompound && t.class == classUniversal {
+ innerBytes := bytes[offset : offset+t.length]
+ switch t.tag {
+ case tagPrintableString:
+ result, err = parsePrintableString(innerBytes)
+ case tagIA5String:
+ result, err = parseIA5String(innerBytes)
+ case tagT61String:
+ result, err = parseT61String(innerBytes)
+ case tagInteger:
+ result, err = parseInt64(innerBytes)
+ case tagBitString:
+ result, err = parseBitString(innerBytes)
+ case tagOID:
+ result, err = parseObjectIdentifier(innerBytes)
+ case tagUTCTime:
+ result, err = parseUTCTime(innerBytes)
+ case tagOctetString:
+ result = innerBytes
+ default:
+ // If we don't know how to handle the type, we just leave Value as nil.
+ }
+ }
+ offset += t.length
+ if err != nil {
+ return
+ }
+ if result != nil {
+ ifaceValue.Set(reflect.NewValue(result))
+ }
+ return
+ }
+ universalTag, compoundType, ok1 := getUniversalType(fieldType)
+ if !ok1 {
+ err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)}
+ return
+ }
+
+ t, offset, err := parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if params.explicit {
+ expectedClass := classContextSpecific
+ if params.application {
+ expectedClass = classApplication
+ }
+ if t.class == expectedClass && t.tag == *params.tag && (t.length == 0 || t.isCompound) {
+ if t.length > 0 {
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ } else {
+ if fieldType != flagType {
+ err = StructuralError{"Zero length explicit tag was not an asn1.Flag"}
+ return
+ }
+
+ flagValue := v.(*reflect.BoolValue)
+ flagValue.Set(true)
+ return
+ }
+ } else {
+ // The tags didn't match, it might be an optional element.
+ ok := setDefaultValue(v, params)
+ if ok {
+ offset = initOffset
+ } else {
+ err = StructuralError{"explicitly tagged member didn't match"}
+ }
+ return
+ }
+ }
+
+ // Special case for strings: PrintableString and IA5String both map to
+ // the Go type string. getUniversalType returns the tag for
+ // PrintableString when it sees a string so, if we see an IA5String on
+ // the wire, we change the universal type to match.
+ if universalTag == tagPrintableString && t.tag == tagIA5String {
+ universalTag = tagIA5String
+ }
+ // Likewise for GeneralString
+ if universalTag == tagPrintableString && t.tag == tagGeneralString {
+ universalTag = tagGeneralString
+ }
+
+ // Special case for time: UTCTime and GeneralizedTime both map to the
+ // Go type time.Time.
+ if universalTag == tagUTCTime && t.tag == tagGeneralizedTime {
+ universalTag = tagGeneralizedTime
+ }
+
+ expectedClass := classUniversal
+ expectedTag := universalTag
+
+ if !params.explicit && params.tag != nil {
+ expectedClass = classContextSpecific
+ expectedTag = *params.tag
+ }
+
+ if !params.explicit && params.application && params.tag != nil {
+ expectedClass = classApplication
+ expectedTag = *params.tag
+ }
+
+ // We have unwrapped any explicit tagging at this point.
+ if t.class != expectedClass || t.tag != expectedTag || t.isCompound != compoundType {
+ // Tags don't match. Again, it could be an optional element.
+ ok := setDefaultValue(v, params)
+ if ok {
+ offset = initOffset
+ } else {
+ err = StructuralError{fmt.Sprintf("tags don't match (%d vs %+v) %+v %s @%d", expectedTag, t, params, fieldType.Name(), offset)}
+ }
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ innerBytes := bytes[offset : offset+t.length]
+ offset += t.length
+
+ // We deal with the structures defined in this package first.
+ switch fieldType {
+ case objectIdentifierType:
+ newSlice, err1 := parseObjectIdentifier(innerBytes)
+ sliceValue := v.(*reflect.SliceValue)
+ sliceValue.Set(reflect.MakeSlice(sliceValue.Type().(*reflect.SliceType), len(newSlice), len(newSlice)))
+ if err1 == nil {
+ reflect.Copy(sliceValue, reflect.NewValue(newSlice).(reflect.ArrayOrSliceValue))
+ }
+ err = err1
+ return
+ case bitStringType:
+ structValue := v.(*reflect.StructValue)
+ bs, err1 := parseBitString(innerBytes)
+ if err1 == nil {
+ structValue.Set(reflect.NewValue(bs).(*reflect.StructValue))
+ }
+ err = err1
+ return
+ case timeType:
+ ptrValue := v.(*reflect.PtrValue)
+ var time *time.Time
+ var err1 os.Error
+ if universalTag == tagUTCTime {
+ time, err1 = parseUTCTime(innerBytes)
+ } else {
+ time, err1 = parseGeneralizedTime(innerBytes)
+ }
+ if err1 == nil {
+ ptrValue.Set(reflect.NewValue(time).(*reflect.PtrValue))
+ }
+ err = err1
+ return
+ case enumeratedType:
+ parsedInt, err1 := parseInt(innerBytes)
+ enumValue := v.(*reflect.IntValue)
+ if err1 == nil {
+ enumValue.Set(int64(parsedInt))
+ }
+ err = err1
+ return
+ case flagType:
+ flagValue := v.(*reflect.BoolValue)
+ flagValue.Set(true)
+ return
+ }
+ switch val := v.(type) {
+ case *reflect.BoolValue:
+ parsedBool, err1 := parseBool(innerBytes)
+ if err1 == nil {
+ val.Set(parsedBool)
+ }
+ err = err1
+ return
+ case *reflect.IntValue:
+ switch val.Type().Kind() {
+ case reflect.Int:
+ parsedInt, err1 := parseInt(innerBytes)
+ if err1 == nil {
+ val.Set(int64(parsedInt))
+ }
+ err = err1
+ return
+ case reflect.Int64:
+ parsedInt, err1 := parseInt64(innerBytes)
+ if err1 == nil {
+ val.Set(parsedInt)
+ }
+ err = err1
+ return
+ }
+ case *reflect.StructValue:
+ structType := fieldType.(*reflect.StructType)
+
+ if structType.NumField() > 0 &&
+ structType.Field(0).Type == rawContentsType {
+ bytes := bytes[initOffset:offset]
+ val.Field(0).SetValue(reflect.NewValue(RawContent(bytes)))
+ }
+
+ innerOffset := 0
+ for i := 0; i < structType.NumField(); i++ {
+ field := structType.Field(i)
+ if i == 0 && field.Type == rawContentsType {
+ continue
+ }
+ innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag))
+ if err != nil {
+ return
+ }
+ }
+ // We allow extra bytes at the end of the SEQUENCE because
+ // adding elements to the end has been used in X.509 as the
+ // version numbers have increased.
+ return
+ case *reflect.SliceValue:
+ sliceType := fieldType.(*reflect.SliceType)
+ if sliceType.Elem().Kind() == reflect.Uint8 {
+ val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes)))
+ reflect.Copy(val, reflect.NewValue(innerBytes).(reflect.ArrayOrSliceValue))
+ return
+ }
+ newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem())
+ if err1 == nil {
+ val.Set(newSlice)
+ }
+ err = err1
+ return
+ case *reflect.StringValue:
+ var v string
+ switch universalTag {
+ case tagPrintableString:
+ v, err = parsePrintableString(innerBytes)
+ case tagIA5String:
+ v, err = parseIA5String(innerBytes)
+ case tagT61String:
+ v, err = parseT61String(innerBytes)
+ case tagGeneralString:
+ // GeneralString is specified in ISO-2022/ECMA-35.
+ // A brief review suggests that it includes structures
+ // that allow the encoding to change midstring and
+ // such. We give up and pass it as an 8-bit string.
+ v, err = parseT61String(innerBytes)
+ default:
+ err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)}
+ }
+ if err == nil {
+ val.Set(v)
+ }
+ return
+ }
+ err = StructuralError{"unknown Go type"}
+ return
+}
+
+// setDefaultValue is used to install a default value, from a tag string, into
+// a Value. It is successful is the field was optional, even if a default value
+// wasn't provided or it failed to install it into the Value.
+func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
+ if !params.optional {
+ return
+ }
+ ok = true
+ if params.defaultValue == nil {
+ return
+ }
+ switch val := v.(type) {
+ case *reflect.IntValue:
+ val.Set(*params.defaultValue)
+ }
+ return
+}
+
+// Unmarshal parses the DER-encoded ASN.1 data structure b
+// and uses the reflect package to fill in an arbitrary value pointed at by val.
+// Because Unmarshal uses the reflect package, the structs
+// being written to must use upper case field names.
+//
+// An ASN.1 INTEGER can be written to an int or int64.
+// If the encoded value does not fit in the Go type,
+// Unmarshal returns a parse error.
+//
+// An ASN.1 BIT STRING can be written to a BitString.
+//
+// An ASN.1 OCTET STRING can be written to a []byte.
+//
+// An ASN.1 OBJECT IDENTIFIER can be written to an
+// ObjectIdentifier.
+//
+// An ASN.1 ENUMERATED can be written to an Enumerated.
+//
+// An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a *time.Time.
+//
+// An ASN.1 PrintableString or IA5String can be written to a string.
+//
+// Any of the above ASN.1 values can be written to an interface{}.
+// The value stored in the interface has the corresponding Go type.
+// For integers, that type is int64.
+//
+// An ASN.1 SEQUENCE OF x or SET OF x can be written
+// to a slice if an x can be written to the slice's element type.
+//
+// An ASN.1 SEQUENCE or SET can be written to a struct
+// if each of the elements in the sequence can be
+// written to the corresponding element in the struct.
+//
+// The following tags on struct fields have special meaning to Unmarshal:
+//
+// optional marks the field as ASN.1 OPTIONAL
+// [explicit] tag:x specifies the ASN.1 tag number; implies ASN.1 CONTEXT SPECIFIC
+// default:x sets the default value for optional integer fields
+//
+// If the type of the first field of a structure is RawContent then the raw
+// ASN1 contents of the struct will be stored in it.
+//
+// Other ASN.1 types are not supported; if it encounters them,
+// Unmarshal returns a parse error.
+func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) {
+ return UnmarshalWithParams(b, val, "")
+}
+
+// UnmarshalWithParams allows field parameters to be specified for the
+// top-level element. The form of the params is the same as the field tags.
+func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) {
+ v := reflect.NewValue(val).(*reflect.PtrValue).Elem()
+ offset, err := parseField(v, b, 0, parseFieldParameters(params))
+ if err != nil {
+ return nil, err
+ }
+ return b[offset:], nil
+}
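This testdata file is written against the pre-Go 1 reflect API (reflect.Typeof, reflect.NewValue, and concrete value types such as *reflect.StructValue); the companion .out file below shows what the reflect fix turns it into. As a quick, standalone illustration of the corresponding Go 1 idioms (not part of the testdata), assuming an arbitrary struct type:

	package main

	import (
		"fmt"
		"reflect"
	)

	type point struct {
		X, Y int
	}

	func main() {
		p := point{}
		// was reflect.NewValue(&p).(*reflect.PtrValue).Elem()
		v := reflect.ValueOf(&p).Elem()

		// was reflect.Typeof(p)
		fmt.Println(reflect.TypeOf(p))

		// was v.(*reflect.StructValue) plus field.(*reflect.IntValue).Set(42)
		if v.Kind() == reflect.Struct {
			v.Field(0).SetInt(42)
		}
		fmt.Println(p.X) // 42
	}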
diff --git a/src/cmd/fix/testdata/reflect.asn1.go.out b/src/cmd/fix/testdata/reflect.asn1.go.out
new file mode 100644
index 000000000..ba6224e6d
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.asn1.go.out
@@ -0,0 +1,814 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The asn1 package implements parsing of DER-encoded ASN.1 data structures,
+// as defined in ITU-T Rec X.690.
+//
+// See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,''
+// http://luca.ntop.org/Teaching/Appunti/asn1.html.
+package asn1
+
+// ASN.1 is a syntax for specifying abstract objects and BER, DER, PER, XER etc
+// are different encoding formats for those objects. Here, we'll be dealing
+// with DER, the Distinguished Encoding Rules. DER is used in X.509 because
+// it's fast to parse and, unlike BER, has a unique encoding for every object.
+// When calculating hashes over objects, it's important that the resulting
+// bytes be the same at both ends and DER removes this margin of error.
+//
+// ASN.1 is very complex and this package doesn't attempt to implement
+// everything by any means.
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+ "time"
+)
+
+// A StructuralError suggests that the ASN.1 data is valid, but the Go type
+// which is receiving it doesn't match.
+type StructuralError struct {
+ Msg string
+}
+
+func (e StructuralError) String() string { return "ASN.1 structure error: " + e.Msg }
+
+// A SyntaxError suggests that the ASN.1 data is invalid.
+type SyntaxError struct {
+ Msg string
+}
+
+func (e SyntaxError) String() string { return "ASN.1 syntax error: " + e.Msg }
+
+// We start by dealing with each of the primitive types in turn.
+
+// BOOLEAN
+
+func parseBool(bytes []byte) (ret bool, err os.Error) {
+ if len(bytes) != 1 {
+ err = SyntaxError{"invalid boolean"}
+ return
+ }
+
+ return bytes[0] != 0, nil
+}
+
+// INTEGER
+
+// parseInt64 treats the given bytes as a big-endian, signed integer and
+// returns the result.
+func parseInt64(bytes []byte) (ret int64, err os.Error) {
+ if len(bytes) > 8 {
+ // We'll overflow an int64 in this case.
+ err = StructuralError{"integer too large"}
+ return
+ }
+ for bytesRead := 0; bytesRead < len(bytes); bytesRead++ {
+ ret <<= 8
+ ret |= int64(bytes[bytesRead])
+ }
+
+ // Shift up and down in order to sign extend the result.
+ ret <<= 64 - uint8(len(bytes))*8
+ ret >>= 64 - uint8(len(bytes))*8
+ return
+}
+
+// parseInt treats the given bytes as a big-endian, signed integer and returns
+// the result.
+func parseInt(bytes []byte) (int, os.Error) {
+ ret64, err := parseInt64(bytes)
+ if err != nil {
+ return 0, err
+ }
+ if ret64 != int64(int(ret64)) {
+ return 0, StructuralError{"integer too large"}
+ }
+ return int(ret64), nil
+}
+
+// BIT STRING
+
+// BitString is the structure to use when you want an ASN.1 BIT STRING type. A
+// bit string is padded up to the nearest byte in memory and the number of
+// valid bits is recorded. Padding bits will be zero.
+type BitString struct {
+ Bytes []byte // bits packed into bytes.
+ BitLength int // length in bits.
+}
+
+// At returns the bit at the given index. If the index is out of range it
+// returns 0.
+func (b BitString) At(i int) int {
+ if i < 0 || i >= b.BitLength {
+ return 0
+ }
+ x := i / 8
+ y := 7 - uint(i%8)
+ return int(b.Bytes[x]>>y) & 1
+}
+
+// RightAlign returns a slice where the padding bits are at the beginning. The
+// slice may share memory with the BitString.
+func (b BitString) RightAlign() []byte {
+ shift := uint(8 - (b.BitLength % 8))
+ if shift == 8 || len(b.Bytes) == 0 {
+ return b.Bytes
+ }
+
+ a := make([]byte, len(b.Bytes))
+ a[0] = b.Bytes[0] >> shift
+ for i := 1; i < len(b.Bytes); i++ {
+ a[i] = b.Bytes[i-1] << (8 - shift)
+ a[i] |= b.Bytes[i] >> shift
+ }
+
+ return a
+}
+
+// parseBitString parses an ASN.1 bit string from the given byte array and returns it.
+func parseBitString(bytes []byte) (ret BitString, err os.Error) {
+ if len(bytes) == 0 {
+ err = SyntaxError{"zero length BIT STRING"}
+ return
+ }
+ paddingBits := int(bytes[0])
+ if paddingBits > 7 ||
+ len(bytes) == 1 && paddingBits > 0 ||
+ bytes[len(bytes)-1]&((1<<bytes[0])-1) != 0 {
+ err = SyntaxError{"invalid padding bits in BIT STRING"}
+ return
+ }
+ ret.BitLength = (len(bytes)-1)*8 - paddingBits
+ ret.Bytes = bytes[1:]
+ return
+}
+
+// OBJECT IDENTIFIER
+
+// An ObjectIdentifier represents an ASN.1 OBJECT IDENTIFIER.
+type ObjectIdentifier []int
+
+// Equal returns true iff oi and other represent the same identifier.
+func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool {
+ if len(oi) != len(other) {
+ return false
+ }
+ for i := 0; i < len(oi); i++ {
+ if oi[i] != other[i] {
+ return false
+ }
+ }
+
+ return true
+}
+
+// parseObjectIdentifier parses an OBJECT IDENTIFIER from the given bytes and
+// returns it. An object identifier is a sequence of variable length integers
+// that are assigned in a hierarchy.
+func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
+ if len(bytes) == 0 {
+ err = SyntaxError{"zero length OBJECT IDENTIFIER"}
+ return
+ }
+
+ // In the worst case, we get two elements from the first byte (which is
+ // encoded differently) and then every varint is a single byte long.
+ s = make([]int, len(bytes)+1)
+
+ // The first byte is 40*value1 + value2:
+ s[0] = int(bytes[0]) / 40
+ s[1] = int(bytes[0]) % 40
+ i := 2
+ for offset := 1; offset < len(bytes); i++ {
+ var v int
+ v, offset, err = parseBase128Int(bytes, offset)
+ if err != nil {
+ return
+ }
+ s[i] = v
+ }
+ s = s[0:i]
+ return
+}
+
+// ENUMERATED
+
+// An Enumerated is represented as a plain int.
+type Enumerated int
+
+// FLAG
+
+// A Flag accepts any data and is set to true if present.
+type Flag bool
+
+// parseBase128Int parses a base-128 encoded int from the given offset in the
+// given byte array. It returns the value and the new offset.
+func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) {
+ offset = initOffset
+ for shifted := 0; offset < len(bytes); shifted++ {
+ if shifted > 4 {
+ err = StructuralError{"base 128 integer too large"}
+ return
+ }
+ ret <<= 7
+ b := bytes[offset]
+ ret |= int(b & 0x7f)
+ offset++
+ if b&0x80 == 0 {
+ return
+ }
+ }
+ err = SyntaxError{"truncated base 128 integer"}
+ return
+}
+
+// UTCTime
+
+func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) {
+ s := string(bytes)
+ ret, err = time.Parse("0601021504Z0700", s)
+ if err == nil {
+ return
+ }
+ ret, err = time.Parse("060102150405Z0700", s)
+ return
+}
+
+// parseGeneralizedTime parses the GeneralizedTime from the given byte array
+// and returns the resulting time.
+func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) {
+ return time.Parse("20060102150405Z0700", string(bytes))
+}
+
+// PrintableString
+
+// parsePrintableString parses an ASN.1 PrintableString from the given byte
+// array and returns it.
+func parsePrintableString(bytes []byte) (ret string, err os.Error) {
+ for _, b := range bytes {
+ if !isPrintable(b) {
+ err = SyntaxError{"PrintableString contains invalid character"}
+ return
+ }
+ }
+ ret = string(bytes)
+ return
+}
+
+// isPrintable returns true iff the given b is in the ASN.1 PrintableString set.
+func isPrintable(b byte) bool {
+ return 'a' <= b && b <= 'z' ||
+ 'A' <= b && b <= 'Z' ||
+ '0' <= b && b <= '9' ||
+ '\'' <= b && b <= ')' ||
+ '+' <= b && b <= '/' ||
+ b == ' ' ||
+ b == ':' ||
+ b == '=' ||
+ b == '?' ||
+ // This is technically not allowed in a PrintableString.
+ // However, x509 certificates with wildcard strings don't
+ // always use the correct string type so we permit it.
+ b == '*'
+}
+
+// IA5String
+
+// parseIA5String parses an ASN.1 IA5String (ASCII string) from the given
+// byte array and returns it.
+func parseIA5String(bytes []byte) (ret string, err os.Error) {
+ for _, b := range bytes {
+ if b >= 0x80 {
+ err = SyntaxError{"IA5String contains invalid character"}
+ return
+ }
+ }
+ ret = string(bytes)
+ return
+}
+
+// T61String
+
+// parseT61String parses an ASN.1 T61String (8-bit clean string) from the given
+// byte array and returns it.
+func parseT61String(bytes []byte) (ret string, err os.Error) {
+ return string(bytes), nil
+}
+
+// A RawValue represents an undecoded ASN.1 object.
+type RawValue struct {
+ Class, Tag int
+ IsCompound bool
+ Bytes []byte
+ FullBytes []byte // includes the tag and length
+}
+
+// RawContent is used to signal that the undecoded DER data needs to be
+// preserved for a struct. To use it, the first field of the struct must have
+// this type. It's an error for any of the other fields to have this type.
+type RawContent []byte
+
+// Tagging
+
+// parseTagAndLength parses an ASN.1 tag and length pair from the given offset
+// into a byte array. It returns the parsed data and the new offset. SET and
+// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
+// don't distinguish between ordered and unordered objects in this code.
+func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) {
+ offset = initOffset
+ b := bytes[offset]
+ offset++
+ ret.class = int(b >> 6)
+ ret.isCompound = b&0x20 == 0x20
+ ret.tag = int(b & 0x1f)
+
+ // If the bottom five bits are set, then the tag number is actually base 128
+ // encoded afterwards
+ if ret.tag == 0x1f {
+ ret.tag, offset, err = parseBase128Int(bytes, offset)
+ if err != nil {
+ return
+ }
+ }
+ if offset >= len(bytes) {
+ err = SyntaxError{"truncated tag or length"}
+ return
+ }
+ b = bytes[offset]
+ offset++
+ if b&0x80 == 0 {
+ // The length is encoded in the bottom 7 bits.
+ ret.length = int(b & 0x7f)
+ } else {
+ // Bottom 7 bits give the number of length bytes to follow.
+ numBytes := int(b & 0x7f)
+ // We risk overflowing a signed 32-bit number if we accept more than 3 bytes.
+ if numBytes > 3 {
+ err = StructuralError{"length too large"}
+ return
+ }
+ if numBytes == 0 {
+ err = SyntaxError{"indefinite length found (not DER)"}
+ return
+ }
+ ret.length = 0
+ for i := 0; i < numBytes; i++ {
+ if offset >= len(bytes) {
+ err = SyntaxError{"truncated tag or length"}
+ return
+ }
+ b = bytes[offset]
+ offset++
+ ret.length <<= 8
+ ret.length |= int(b)
+ }
+ }
+
+ return
+}
+
+// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
+// a number of ASN.1 values from the given byte array and returns them as a
+// slice of Go values of the given type.
+func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) {
+ expectedTag, compoundType, ok := getUniversalType(elemType)
+ if !ok {
+ err = StructuralError{"unknown Go type for slice"}
+ return
+ }
+
+ // First we iterate over the input and count the number of elements,
+ // checking that the types are correct in each case.
+ numElements := 0
+ for offset := 0; offset < len(bytes); {
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ // We pretend that GENERAL STRINGs are PRINTABLE STRINGs so
+ // that a sequence of them can be parsed into a []string.
+ if t.tag == tagGeneralString {
+ t.tag = tagPrintableString
+ }
+ if t.class != classUniversal || t.isCompound != compoundType || t.tag != expectedTag {
+ err = StructuralError{"sequence tag mismatch"}
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"truncated sequence"}
+ return
+ }
+ offset += t.length
+ numElements++
+ }
+ ret = reflect.MakeSlice(sliceType, numElements, numElements)
+ params := fieldParameters{}
+ offset := 0
+ for i := 0; i < numElements; i++ {
+ offset, err = parseField(ret.Index(i), bytes, offset, params)
+ if err != nil {
+ return
+ }
+ }
+ return
+}
+
+var (
+ bitStringType = reflect.TypeOf(BitString{})
+ objectIdentifierType = reflect.TypeOf(ObjectIdentifier{})
+ enumeratedType = reflect.TypeOf(Enumerated(0))
+ flagType = reflect.TypeOf(Flag(false))
+ timeType = reflect.TypeOf(&time.Time{})
+ rawValueType = reflect.TypeOf(RawValue{})
+ rawContentsType = reflect.TypeOf(RawContent(nil))
+)
+
+// invalidLength returns true iff offset + length > sliceLength, or if the
+// addition would overflow.
+func invalidLength(offset, length, sliceLength int) bool {
+ return offset+length < offset || offset+length > sliceLength
+}
+
+// parseField is the main parsing function. Given a byte array and an offset
+// into the array, it will try to parse a suitable ASN.1 value out and store it
+// in the given Value.
+func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) {
+ offset = initOffset
+ fieldType := v.Type()
+
+ // If we have run out of data, it may be that there are optional elements at the end.
+ if offset == len(bytes) {
+ if !setDefaultValue(v, params) {
+ err = SyntaxError{"sequence truncated"}
+ }
+ return
+ }
+
+ // Deal with raw values.
+ if fieldType == rawValueType {
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]}
+ offset += t.length
+ v.Set(reflect.ValueOf(result))
+ return
+ }
+
+ // Deal with the ANY type.
+ if ifaceType := fieldType; ifaceType.Kind() == reflect.Interface && ifaceType.NumMethod() == 0 {
+ ifaceValue := v
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ var result interface{}
+ if !t.isCompound && t.class == classUniversal {
+ innerBytes := bytes[offset : offset+t.length]
+ switch t.tag {
+ case tagPrintableString:
+ result, err = parsePrintableString(innerBytes)
+ case tagIA5String:
+ result, err = parseIA5String(innerBytes)
+ case tagT61String:
+ result, err = parseT61String(innerBytes)
+ case tagInteger:
+ result, err = parseInt64(innerBytes)
+ case tagBitString:
+ result, err = parseBitString(innerBytes)
+ case tagOID:
+ result, err = parseObjectIdentifier(innerBytes)
+ case tagUTCTime:
+ result, err = parseUTCTime(innerBytes)
+ case tagOctetString:
+ result = innerBytes
+ default:
+ // If we don't know how to handle the type, we just leave Value as nil.
+ }
+ }
+ offset += t.length
+ if err != nil {
+ return
+ }
+ if result != nil {
+ ifaceValue.Set(reflect.ValueOf(result))
+ }
+ return
+ }
+ universalTag, compoundType, ok1 := getUniversalType(fieldType)
+ if !ok1 {
+ err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)}
+ return
+ }
+
+ t, offset, err := parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if params.explicit {
+ expectedClass := classContextSpecific
+ if params.application {
+ expectedClass = classApplication
+ }
+ if t.class == expectedClass && t.tag == *params.tag && (t.length == 0 || t.isCompound) {
+ if t.length > 0 {
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ } else {
+ if fieldType != flagType {
+ err = StructuralError{"Zero length explicit tag was not an asn1.Flag"}
+ return
+ }
+
+ flagValue := v
+ flagValue.SetBool(true)
+ return
+ }
+ } else {
+ // The tags didn't match, it might be an optional element.
+ ok := setDefaultValue(v, params)
+ if ok {
+ offset = initOffset
+ } else {
+ err = StructuralError{"explicitly tagged member didn't match"}
+ }
+ return
+ }
+ }
+
+ // Special case for strings: PrintableString and IA5String both map to
+ // the Go type string. getUniversalType returns the tag for
+ // PrintableString when it sees a string so, if we see an IA5String on
+ // the wire, we change the universal type to match.
+ if universalTag == tagPrintableString && t.tag == tagIA5String {
+ universalTag = tagIA5String
+ }
+ // Likewise for GeneralString
+ if universalTag == tagPrintableString && t.tag == tagGeneralString {
+ universalTag = tagGeneralString
+ }
+
+ // Special case for time: UTCTime and GeneralizedTime both map to the
+ // Go type time.Time.
+ if universalTag == tagUTCTime && t.tag == tagGeneralizedTime {
+ universalTag = tagGeneralizedTime
+ }
+
+ expectedClass := classUniversal
+ expectedTag := universalTag
+
+ if !params.explicit && params.tag != nil {
+ expectedClass = classContextSpecific
+ expectedTag = *params.tag
+ }
+
+ if !params.explicit && params.application && params.tag != nil {
+ expectedClass = classApplication
+ expectedTag = *params.tag
+ }
+
+ // We have unwrapped any explicit tagging at this point.
+ if t.class != expectedClass || t.tag != expectedTag || t.isCompound != compoundType {
+ // Tags don't match. Again, it could be an optional element.
+ ok := setDefaultValue(v, params)
+ if ok {
+ offset = initOffset
+ } else {
+ err = StructuralError{fmt.Sprintf("tags don't match (%d vs %+v) %+v %s @%d", expectedTag, t, params, fieldType.Name(), offset)}
+ }
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ innerBytes := bytes[offset : offset+t.length]
+ offset += t.length
+
+ // We deal with the structures defined in this package first.
+ switch fieldType {
+ case objectIdentifierType:
+ newSlice, err1 := parseObjectIdentifier(innerBytes)
+ sliceValue := v
+ sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), len(newSlice), len(newSlice)))
+ if err1 == nil {
+ reflect.Copy(sliceValue, reflect.ValueOf(newSlice))
+ }
+ err = err1
+ return
+ case bitStringType:
+ structValue := v
+ bs, err1 := parseBitString(innerBytes)
+ if err1 == nil {
+ structValue.Set(reflect.ValueOf(bs))
+ }
+ err = err1
+ return
+ case timeType:
+ ptrValue := v
+ var time *time.Time
+ var err1 os.Error
+ if universalTag == tagUTCTime {
+ time, err1 = parseUTCTime(innerBytes)
+ } else {
+ time, err1 = parseGeneralizedTime(innerBytes)
+ }
+ if err1 == nil {
+ ptrValue.Set(reflect.ValueOf(time))
+ }
+ err = err1
+ return
+ case enumeratedType:
+ parsedInt, err1 := parseInt(innerBytes)
+ enumValue := v
+ if err1 == nil {
+ enumValue.SetInt(int64(parsedInt))
+ }
+ err = err1
+ return
+ case flagType:
+ flagValue := v
+ flagValue.SetBool(true)
+ return
+ }
+ switch val := v; val.Kind() {
+ case reflect.Bool:
+ parsedBool, err1 := parseBool(innerBytes)
+ if err1 == nil {
+ val.SetBool(parsedBool)
+ }
+ err = err1
+ return
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ switch val.Type().Kind() {
+ case reflect.Int:
+ parsedInt, err1 := parseInt(innerBytes)
+ if err1 == nil {
+ val.SetInt(int64(parsedInt))
+ }
+ err = err1
+ return
+ case reflect.Int64:
+ parsedInt, err1 := parseInt64(innerBytes)
+ if err1 == nil {
+ val.SetInt(parsedInt)
+ }
+ err = err1
+ return
+ }
+ case reflect.Struct:
+ structType := fieldType
+
+ if structType.NumField() > 0 &&
+ structType.Field(0).Type == rawContentsType {
+ bytes := bytes[initOffset:offset]
+ val.Field(0).Set(reflect.ValueOf(RawContent(bytes)))
+ }
+
+ innerOffset := 0
+ for i := 0; i < structType.NumField(); i++ {
+ field := structType.Field(i)
+ if i == 0 && field.Type == rawContentsType {
+ continue
+ }
+ innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag))
+ if err != nil {
+ return
+ }
+ }
+ // We allow extra bytes at the end of the SEQUENCE because
+ // adding elements to the end has been used in X.509 as the
+ // version numbers have increased.
+ return
+ case reflect.Slice:
+ sliceType := fieldType
+ if sliceType.Elem().Kind() == reflect.Uint8 {
+ val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes)))
+ reflect.Copy(val, reflect.ValueOf(innerBytes))
+ return
+ }
+ newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem())
+ if err1 == nil {
+ val.Set(newSlice)
+ }
+ err = err1
+ return
+ case reflect.String:
+ var v string
+ switch universalTag {
+ case tagPrintableString:
+ v, err = parsePrintableString(innerBytes)
+ case tagIA5String:
+ v, err = parseIA5String(innerBytes)
+ case tagT61String:
+ v, err = parseT61String(innerBytes)
+ case tagGeneralString:
+ // GeneralString is specified in ISO-2022/ECMA-35.
+ // A brief review suggests that it includes structures
+ // that allow the encoding to change midstring and
+ // such. We give up and pass it as an 8-bit string.
+ v, err = parseT61String(innerBytes)
+ default:
+ err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)}
+ }
+ if err == nil {
+ val.SetString(v)
+ }
+ return
+ }
+ err = StructuralError{"unknown Go type"}
+ return
+}
+
+// setDefaultValue is used to install a default value, from a tag string, into
+// a Value. It is successful is the field was optional, even if a default value
+// wasn't provided or it failed to install it into the Value.
+func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
+ if !params.optional {
+ return
+ }
+ ok = true
+ if params.defaultValue == nil {
+ return
+ }
+ switch val := v; val.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ val.SetInt(*params.defaultValue)
+ }
+ return
+}
+
+// Unmarshal parses the DER-encoded ASN.1 data structure b
+// and uses the reflect package to fill in an arbitrary value pointed at by val.
+// Because Unmarshal uses the reflect package, the structs
+// being written to must use upper case field names.
+//
+// An ASN.1 INTEGER can be written to an int or int64.
+// If the encoded value does not fit in the Go type,
+// Unmarshal returns a parse error.
+//
+// An ASN.1 BIT STRING can be written to a BitString.
+//
+// An ASN.1 OCTET STRING can be written to a []byte.
+//
+// An ASN.1 OBJECT IDENTIFIER can be written to an
+// ObjectIdentifier.
+//
+// An ASN.1 ENUMERATED can be written to an Enumerated.
+//
+// An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a *time.Time.
+//
+// An ASN.1 PrintableString or IA5String can be written to a string.
+//
+// Any of the above ASN.1 values can be written to an interface{}.
+// The value stored in the interface has the corresponding Go type.
+// For integers, that type is int64.
+//
+// An ASN.1 SEQUENCE OF x or SET OF x can be written
+// to a slice if an x can be written to the slice's element type.
+//
+// An ASN.1 SEQUENCE or SET can be written to a struct
+// if each of the elements in the sequence can be
+// written to the corresponding element in the struct.
+//
+// The following tags on struct fields have special meaning to Unmarshal:
+//
+// optional marks the field as ASN.1 OPTIONAL
+// [explicit] tag:x specifies the ASN.1 tag number; implies ASN.1 CONTEXT SPECIFIC
+// default:x sets the default value for optional integer fields
+//
+// If the type of the first field of a structure is RawContent then the raw
+// ASN1 contents of the struct will be stored in it.
+//
+// Other ASN.1 types are not supported; if it encounters them,
+// Unmarshal returns a parse error.
+func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) {
+ return UnmarshalWithParams(b, val, "")
+}
+
+// UnmarshalWithParams allows field parameters to be specified for the
+// top-level element. The form of the params is the same as the field tags.
+func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) {
+ v := reflect.ValueOf(val).Elem()
+ offset, err := parseField(v, b, 0, parseFieldParameters(params))
+ if err != nil {
+ return nil, err
+ }
+ return b[offset:], nil
+}
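The rewritten file above replaces the concrete value types with a single reflect.Value and a switch on Kind, using SetBool, SetInt, SetString and Set. A simplified, standalone sketch of that style (not part of the asn1 package; the helper name is invented for illustration):

	package main

	import (
		"fmt"
		"reflect"
	)

	// setFromString stores a value derived from s into v, switching on
	// v.Kind() in the same style as the rewritten parseField above.
	func setFromString(v reflect.Value, s string) error {
		switch v.Kind() {
		case reflect.Bool:
			v.SetBool(s == "true")
		case reflect.String:
			v.SetString(s)
		case reflect.Slice:
			if v.Type().Elem().Kind() == reflect.Uint8 {
				v.Set(reflect.ValueOf([]byte(s)))
				return nil
			}
			return fmt.Errorf("unsupported slice type %s", v.Type())
		default:
			return fmt.Errorf("unsupported kind %s", v.Kind())
		}
		return nil
	}

	func main() {
		var name string
		if err := setFromString(reflect.ValueOf(&name).Elem(), "hello"); err != nil {
			fmt.Println(err)
			return
		}
		fmt.Println(name) // hello
	}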
diff --git a/src/cmd/fix/testdata/reflect.datafmt.go.in b/src/cmd/fix/testdata/reflect.datafmt.go.in
new file mode 100644
index 000000000..91f885f9a
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.datafmt.go.in
@@ -0,0 +1,710 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/* The datafmt package implements syntax-directed, type-driven formatting
+ of arbitrary data structures. Formatting a data structure consists of
+ two phases: first, a parser reads a format specification and builds a
+ "compiled" format. Then, the format can be applied repeatedly to
+ arbitrary values. Applying a format to a value evaluates to a []byte
+ containing the formatted value bytes, or nil.
+
+ A format specification is a set of package declarations and format rules:
+
+ Format = [ Entry { ";" Entry } [ ";" ] ] .
+ Entry = PackageDecl | FormatRule .
+
+ (The syntax of a format specification is presented in the same EBNF
+ notation as used in the Go language specification. The syntax of white
+ space, comments, identifiers, and string literals is the same as in Go.)
+
+ A package declaration binds a package name (such as 'ast') to a
+ package import path (such as '"go/ast"'). Each package used (in
+ a type name, see below) must be declared once before use.
+
+ PackageDecl = PackageName ImportPath .
+ PackageName = identifier .
+ ImportPath = string .
+
+ A format rule binds a rule name to a format expression. A rule name
+ may be a type name or one of the special names 'default' or '/'.
+ A type name may be the name of a predeclared type (for example, 'int',
+ 'float32', etc.), the package-qualified name of a user-defined type
+ (for example, 'ast.MapType'), or an identifier indicating the structure
+ of unnamed composite types ('array', 'chan', 'func', 'interface', 'map',
+ or 'ptr'). Each rule must have a unique name; rules can be declared in
+ any order.
+
+ FormatRule = RuleName "=" Expression .
+ RuleName = TypeName | "default" | "/" .
+ TypeName = [ PackageName "." ] identifier .
+
+ To format a value, the value's type name is used to select the format rule
+ (there is an override mechanism, see below). The format expression of the
+ selected rule specifies how the value is formatted. Each format expression,
+ when applied to a value, evaluates to a byte sequence or nil.
+
+ In its most general form, a format expression is a list of alternatives,
+ each of which is a sequence of operands:
+
+ Expression = [ Sequence ] { "|" [ Sequence ] } .
+ Sequence = Operand { Operand } .
+
+ The formatted result produced by an expression is the result of the first
+ alternative sequence that evaluates to a non-nil result; if there is no
+ such alternative, the expression evaluates to nil. The result produced by
+ an operand sequence is the concatenation of the results of its operands.
+ If any operand in the sequence evaluates to nil, the entire sequence
+ evaluates to nil.
+
+ There are five kinds of operands:
+
+ Operand = Literal | Field | Group | Option | Repetition .
+
+ Literals evaluate to themselves, with two substitutions. First,
+ %-formats expand in the manner of fmt.Printf, with the current value
+ passed as the parameter. Second, the current indentation (see below)
+ is inserted after every newline or form feed character.
+
+ Literal = string .
+
+ This table shows string literals applied to the value 42 and the
+ corresponding formatted result:
+
+ "foo" foo
+ "%x" 2a
+ "x = %d" x = 42
+ "%#x = %d" 0x2a = 42
+
+ A field operand is a field name optionally followed by an alternate
+ rule name. The field name may be an identifier or one of the special
+ names @ or *.
+
+ Field = FieldName [ ":" RuleName ] .
+ FieldName = identifier | "@" | "*" .
+
+ If the field name is an identifier, the current value must be a struct,
+ and there must be a field with that name in the struct. The same lookup
+ rules apply as in the Go language (for instance, the name of an anonymous
+ field is the unqualified type name). The field name denotes the field
+ value in the struct. If the field is not found, formatting is aborted
+ and an error message is returned. (TODO consider changing the semantics
+ such that if a field is not found, it evaluates to nil).
+
+ The special name '@' denotes the current value.
+
+ The meaning of the special name '*' depends on the type of the current
+ value:
+
+ array, slice types array, slice element (inside {} only, see below)
+ interfaces value stored in interface
+ pointers value pointed to by pointer
+
+ (Implementation restriction: channel, function and map types are not
+ supported due to missing reflection support).
+
+ Fields are evaluated as follows: If the field value is nil, or an array
+ or slice element does not exist, the result is nil (see below for details
+ on array/slice elements). If the value is not nil the field value is
+ formatted (recursively) using the rule corresponding to its type name,
+ or the alternate rule name, if given.
+
+ The following example shows a complete format specification for a
+ struct 'myPackage.Point'. Assume the package
+
+ package myPackage // in directory myDir/myPackage
+ type Point struct {
+ name string;
+ x, y int;
+ }
+
+ Applying the format specification
+
+ myPackage "myDir/myPackage";
+ int = "%d";
+ hexInt = "0x%x";
+ string = "---%s---";
+ myPackage.Point = name "{" x ", " y:hexInt "}";
+
+ to the value myPackage.Point{"foo", 3, 15} results in
+
+ ---foo---{3, 0xf}
+
+ Finally, an operand may be a grouped, optional, or repeated expression.
+ A grouped expression ("group") groups a more complex expression (body)
+ so that it can be used in place of a single operand:
+
+ Group = "(" [ Indentation ">>" ] Body ")" .
+ Indentation = Expression .
+ Body = Expression .
+
+ A group body may be prefixed by an indentation expression followed by '>>'.
+ The indentation expression is applied to the current value like any other
+ expression and the result, if not nil, is appended to the current indentation
+ during the evaluation of the body (see also formatting state, below).
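+
+	For instance (a hypothetical variation of the myPackage.Point example
+	above, not part of the original specification), the rule
+
+		myPackage.Point = "{" ("\t" >> "\n" name) "\n}";
+
+	would write the name field on its own line, indented by one tab:
+	the indentation expression "\t" is appended to the current indentation
+	while the group body is evaluated, and is therefore inserted after the
+	newline written by the body.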
+
+ An optional expression ("option") is enclosed in '[]' brackets.
+
+ Option = "[" Body "]" .
+
+ An option evaluates to its body, except that if the body evaluates to nil,
+ the option expression evaluates to an empty []byte. Thus an option's purpose
+ is to protect the expression containing the option from a nil operand.
+
+ A repeated expression ("repetition") is enclosed in '{}' braces.
+
+ Repetition = "{" Body [ "/" Separator ] "}" .
+ Separator = Expression .
+
+ A repeated expression is evaluated as follows: The body is evaluated
+ repeatedly and its results are concatenated until the body evaluates
+ to nil. The result of the repetition is the (possibly empty) concatenation,
+ but it is never nil. An implicit index is supplied for the evaluation of
+ the body: that index is used to address elements of arrays or slices. If
+ the corresponding elements do not exist, the field denoting the element
+ evaluates to nil (which in turn may terminate the repetition).
+
+ The body of a repetition may be followed by a '/' and a "separator"
+ expression. If the separator is present, it is invoked between repetitions
+ of the body.
+
+ The following example shows a complete format specification for formatting
+ a slice of unnamed type. Applying the specification
+
+ int = "%b";
+ array = { * / ", " }; // array is the type name for an unnamed slice
+
+ to the value '[]int{2, 3, 5, 7}' results in
+
+ 10, 11, 101, 111
+
+ Default rule: If a format rule named 'default' is present, it is used for
+ formatting a value if no other rule was found. A common default rule is
+
+ default = "%v"
+
+ to provide default formatting for basic types without having to specify
+ a specific rule for each basic type.
+
+ Global separator rule: If a format rule named '/' is present, it is
+ invoked with the current value between literals. If the separator
+ expression evaluates to nil, it is ignored.
+
+ For instance, a global separator rule may be used to punctuate a sequence
+ of values with commas. The rules:
+
+ default = "%v";
+ / = ", ";
+
+ will format an argument list by printing each one in its default format,
+ separated by a comma and a space.
+*/
+package datafmt
+
+import (
+ "bytes"
+ "fmt"
+ "go/token"
+ "io"
+ "os"
+ "reflect"
+ "runtime"
+)
+
+// ----------------------------------------------------------------------------
+// Format representation
+
+// Custom formatters implement the Formatter function type.
+// A formatter is invoked with the current formatting state, the
+// value to format, and the rule name under which the formatter
+// was installed (the same formatter function may be installed
+// under different names). The formatter may access the current state
+// to guide formatting and use State.Write to append to the state's
+// output.
+//
+// A formatter must return a boolean value indicating if it evaluated
+// to a non-nil value (true), or a nil value (false).
+//
+type Formatter func(state *State, value interface{}, ruleName string) bool
+
+// A FormatterMap is a set of custom formatters.
+// It maps a rule name to a formatter function.
+//
+type FormatterMap map[string]Formatter
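+
+// As an illustration only (this helper is not part of the package), a custom
+// formatter that ignores the value and writes the rule name it was installed
+// under could be written as
+//
+//	func ruleNameFmt(state *State, value interface{}, ruleName string) bool {
+//		state.Write([]byte(ruleName))
+//		return true // evaluated to a non-nil result
+//	}
+//
+// and registered as FormatterMap{"ident": ruleNameFmt}.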
+
+// A parsed format expression is built from the following nodes.
+//
+type (
+ expr interface{}
+
+ alternatives []expr // x | y | z
+
+ sequence []expr // x y z
+
+ literal [][]byte // a list of string segments, possibly starting with '%'
+
+ field struct {
+ fieldName string // including "@", "*"
+ ruleName string // "" if no rule name specified
+ }
+
+ group struct {
+ indent, body expr // (indent >> body)
+ }
+
+ option struct {
+ body expr // [body]
+ }
+
+ repetition struct {
+ body, separator expr // {body / separator}
+ }
+
+ custom struct {
+ ruleName string
+ fun Formatter
+ }
+)
+
+// A Format is the result of parsing a format specification.
+// The format may be applied repeatedly to format values.
+//
+type Format map[string]expr
+
+// ----------------------------------------------------------------------------
+// Formatting
+
+// An application-specific environment may be provided to Format.Apply;
+// the environment is available inside custom formatters via State.Env().
+// Environments must implement copying; the Copy method must return a
+// complete copy of the receiver. This is necessary so that the formatter
+// can save and restore an environment (in case of an absent expression).
+//
+// If the Environment doesn't change during formatting (this is under
+// control of the custom formatters), the Copy function can simply return
+// the receiver, and thus can be very light-weight.
+//
+type Environment interface {
+ Copy() Environment
+}
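+
+// As a minimal sketch (not part of this package), an environment that is
+// never modified by the custom formatters can implement Copy by returning
+// the receiver:
+//
+//	type labelEnv struct{ label string }
+//
+//	func (e labelEnv) Copy() Environment { return e } // valid only because e is never mutated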
+
+// State represents the current formatting state.
+// It is provided as argument to custom formatters.
+//
+type State struct {
+ fmt Format // format in use
+ env Environment // user-supplied environment
+ errors chan os.Error // not chan *Error (errors <- nil would be wrong!)
+ hasOutput bool // true after the first literal has been written
+ indent bytes.Buffer // current indentation
+ output bytes.Buffer // format output
+ linePos token.Position // position of line beginning (Column == 0)
+ default_ expr // possibly nil
+ separator expr // possibly nil
+}
+
+func newState(fmt Format, env Environment, errors chan os.Error) *State {
+ s := new(State)
+ s.fmt = fmt
+ s.env = env
+ s.errors = errors
+ s.linePos = token.Position{Line: 1}
+
+	// if we have a default rule, cache its expression for fast access
+ if x, found := fmt["default"]; found {
+ s.default_ = x
+ }
+
+	// if we have a global separator rule, cache its expression for fast access
+ if x, found := fmt["/"]; found {
+ s.separator = x
+ }
+
+ return s
+}
+
+// Env returns the environment passed to Format.Apply.
+func (s *State) Env() interface{} { return s.env }
+
+// LinePos returns the position of the current line beginning
+// in the state's output buffer. Line numbers start at 1.
+//
+func (s *State) LinePos() token.Position { return s.linePos }
+
+// Pos returns the position of the next byte to be written to the
+// output buffer. Line numbers start at 1.
+//
+func (s *State) Pos() token.Position {
+ offs := s.output.Len()
+ return token.Position{Line: s.linePos.Line, Column: offs - s.linePos.Offset, Offset: offs}
+}
+
+// Write writes data to the output buffer, inserting the indentation
+// string after each newline or form feed character. It cannot return an error.
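+// For example (illustration only), with a current indentation of "\t",
+// writing "a\nb" appends "a\n\tb" to the output buffer.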
+//
+func (s *State) Write(data []byte) (int, os.Error) {
+ n := 0
+ i0 := 0
+ for i, ch := range data {
+ if ch == '\n' || ch == '\f' {
+ // write text segment and indentation
+ n1, _ := s.output.Write(data[i0 : i+1])
+ n2, _ := s.output.Write(s.indent.Bytes())
+ n += n1 + n2
+ i0 = i + 1
+ s.linePos.Offset = s.output.Len()
+ s.linePos.Line++
+ }
+ }
+ n3, _ := s.output.Write(data[i0:])
+ return n + n3, nil
+}
+
+type checkpoint struct {
+ env Environment
+ hasOutput bool
+ outputLen int
+ linePos token.Position
+}
+
+func (s *State) save() checkpoint {
+ saved := checkpoint{nil, s.hasOutput, s.output.Len(), s.linePos}
+ if s.env != nil {
+ saved.env = s.env.Copy()
+ }
+ return saved
+}
+
+func (s *State) restore(m checkpoint) {
+ s.env = m.env
+ s.output.Truncate(m.outputLen)
+}
+
+func (s *State) error(msg string) {
+ s.errors <- os.NewError(msg)
+ runtime.Goexit()
+}
+
+// TODO At the moment, unnamed types are simply mapped to the default
+// names below. For instance, all unnamed arrays are mapped to
+// 'array' which is not really sufficient. Eventually one may want
+// to be able to specify rules for say an unnamed slice of T.
+//
+
+func typename(typ reflect.Type) string {
+ switch typ.(type) {
+ case *reflect.ArrayType:
+ return "array"
+ case *reflect.SliceType:
+ return "array"
+ case *reflect.ChanType:
+ return "chan"
+ case *reflect.FuncType:
+ return "func"
+ case *reflect.InterfaceType:
+ return "interface"
+ case *reflect.MapType:
+ return "map"
+ case *reflect.PtrType:
+ return "ptr"
+ }
+ return typ.String()
+}
+
+func (s *State) getFormat(name string) expr {
+ if fexpr, found := s.fmt[name]; found {
+ return fexpr
+ }
+
+ if s.default_ != nil {
+ return s.default_
+ }
+
+ s.error(fmt.Sprintf("no format rule for type: '%s'", name))
+ return nil
+}
+
+// eval applies a format expression fexpr to a value. If the expression
+// evaluates internally to a non-nil []byte, that slice is appended to
+// the state's output buffer and eval returns true. Otherwise, eval
+// returns false and the state remains unchanged.
+//
+func (s *State) eval(fexpr expr, value reflect.Value, index int) bool {
+ // an empty format expression always evaluates
+ // to a non-nil (but empty) []byte
+ if fexpr == nil {
+ return true
+ }
+
+ switch t := fexpr.(type) {
+ case alternatives:
+ // append the result of the first alternative that evaluates to
+ // a non-nil []byte to the state's output
+ mark := s.save()
+ for _, x := range t {
+ if s.eval(x, value, index) {
+ return true
+ }
+ s.restore(mark)
+ }
+ return false
+
+ case sequence:
+ // append the result of all operands to the state's output
+ // unless a nil result is encountered
+ mark := s.save()
+ for _, x := range t {
+ if !s.eval(x, value, index) {
+ s.restore(mark)
+ return false
+ }
+ }
+ return true
+
+ case literal:
+ // write separator, if any
+ if s.hasOutput {
+ // not the first literal
+ if s.separator != nil {
+ sep := s.separator // save current separator
+ s.separator = nil // and disable it (avoid recursion)
+ mark := s.save()
+ if !s.eval(sep, value, index) {
+ s.restore(mark)
+ }
+ s.separator = sep // enable it again
+ }
+ }
+ s.hasOutput = true
+ // write literal segments
+ for _, lit := range t {
+ if len(lit) > 1 && lit[0] == '%' {
+ // segment contains a %-format at the beginning
+ if lit[1] == '%' {
+ // "%%" is printed as a single "%"
+ s.Write(lit[1:])
+ } else {
+ // use s instead of s.output to get indentation right
+ fmt.Fprintf(s, string(lit), value.Interface())
+ }
+ } else {
+ // segment contains no %-formats
+ s.Write(lit)
+ }
+ }
+ return true // a literal never evaluates to nil
+
+ case *field:
+ // determine field value
+ switch t.fieldName {
+ case "@":
+ // field value is current value
+
+ case "*":
+ // indirection: operation is type-specific
+ switch v := value.(type) {
+ case *reflect.ArrayValue:
+ if v.Len() <= index {
+ return false
+ }
+ value = v.Elem(index)
+
+ case *reflect.SliceValue:
+ if v.IsNil() || v.Len() <= index {
+ return false
+ }
+ value = v.Elem(index)
+
+ case *reflect.MapValue:
+ s.error("reflection support for maps incomplete")
+
+ case *reflect.PtrValue:
+ if v.IsNil() {
+ return false
+ }
+ value = v.Elem()
+
+ case *reflect.InterfaceValue:
+ if v.IsNil() {
+ return false
+ }
+ value = v.Elem()
+
+ case *reflect.ChanValue:
+ s.error("reflection support for chans incomplete")
+
+ case *reflect.FuncValue:
+ s.error("reflection support for funcs incomplete")
+
+ default:
+ s.error(fmt.Sprintf("error: * does not apply to `%s`", value.Type()))
+ }
+
+ default:
+ // value is value of named field
+ var field reflect.Value
+ if sval, ok := value.(*reflect.StructValue); ok {
+ field = sval.FieldByName(t.fieldName)
+ if field == nil {
+ // TODO consider just returning false in this case
+ s.error(fmt.Sprintf("error: no field `%s` in `%s`", t.fieldName, value.Type()))
+ }
+ }
+ value = field
+ }
+
+ // determine rule
+ ruleName := t.ruleName
+ if ruleName == "" {
+ // no alternate rule name, value type determines rule
+ ruleName = typename(value.Type())
+ }
+ fexpr = s.getFormat(ruleName)
+
+ mark := s.save()
+ if !s.eval(fexpr, value, index) {
+ s.restore(mark)
+ return false
+ }
+ return true
+
+ case *group:
+ // remember current indentation
+ indentLen := s.indent.Len()
+
+ // update current indentation
+ mark := s.save()
+ s.eval(t.indent, value, index)
+ // if the indentation evaluates to nil, the state's output buffer
+ // didn't change - either way it's ok to append the difference to
+		// the current indentation
+ s.indent.Write(s.output.Bytes()[mark.outputLen:s.output.Len()])
+ s.restore(mark)
+
+ // format group body
+ mark = s.save()
+ b := true
+ if !s.eval(t.body, value, index) {
+ s.restore(mark)
+ b = false
+ }
+
+ // reset indentation
+ s.indent.Truncate(indentLen)
+ return b
+
+ case *option:
+ // evaluate the body and append the result to the state's output
+ // buffer unless the result is nil
+ mark := s.save()
+ if !s.eval(t.body, value, 0) { // TODO is 0 index correct?
+ s.restore(mark)
+ }
+ return true // an option never evaluates to nil
+
+ case *repetition:
+ // evaluate the body and append the result to the state's output
+ // buffer until a result is nil
+ for i := 0; ; i++ {
+ mark := s.save()
+ // write separator, if any
+ if i > 0 && t.separator != nil {
+ // nil result from separator is ignored
+ mark := s.save()
+ if !s.eval(t.separator, value, i) {
+ s.restore(mark)
+ }
+ }
+ if !s.eval(t.body, value, i) {
+ s.restore(mark)
+ break
+ }
+ }
+ return true // a repetition never evaluates to nil
+
+ case *custom:
+ // invoke the custom formatter to obtain the result
+ mark := s.save()
+ if !t.fun(s, value.Interface(), t.ruleName) {
+ s.restore(mark)
+ return false
+ }
+ return true
+ }
+
+ panic("unreachable")
+ return false
+}
+
+// Eval formats each argument according to the format
+// f and returns the resulting []byte and os.Error. If
+// an error occurred, the []byte contains the partially
+// formatted result. An environment env may be passed
+// in which is available in custom formatters through
+// the state parameter.
+//
+func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) {
+ if f == nil {
+ return nil, os.NewError("format is nil")
+ }
+
+ errors := make(chan os.Error)
+ s := newState(f, env, errors)
+
+ go func() {
+ for _, v := range args {
+ fld := reflect.NewValue(v)
+ if fld == nil {
+ errors <- os.NewError("nil argument")
+ return
+ }
+ mark := s.save()
+ if !s.eval(s.getFormat(typename(fld.Type())), fld, 0) { // TODO is 0 index correct?
+ s.restore(mark)
+ }
+ }
+ errors <- nil // no errors
+ }()
+
+ err := <-errors
+ return s.output.Bytes(), err
+}
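+
+// For illustration only (this usage does not appear in the package source):
+// given a Format value f obtained from the package's parser (not shown in
+// this file), Eval formats arbitrary arguments:
+//
+//	data, err := f.Eval(nil, 42, "hello")
+//	if err != nil {
+//		// data still holds the partially formatted result
+//	}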
+
+// ----------------------------------------------------------------------------
+// Convenience functions
+
+// Fprint formats each argument according to the format f
+// and writes to w. The result is the total number of bytes
+// written and an os.Error, if any.
+//
+func (f Format) Fprint(w io.Writer, env Environment, args ...interface{}) (int, os.Error) {
+ data, err := f.Eval(env, args...)
+ if err != nil {
+ // TODO should we print partial result in case of error?
+ return 0, err
+ }
+ return w.Write(data)
+}
+
+// Print formats each argument according to the format f
+// and writes to standard output. The result is the total
+// number of bytes written and an os.Error, if any.
+//
+func (f Format) Print(args ...interface{}) (int, os.Error) {
+ return f.Fprint(os.Stdout, nil, args...)
+}
+
+// Sprint formats each argument according to the format f
+// and returns the resulting string. If an error occurs
+// during formatting, the result string contains the
+// partially formatted result followed by an error message.
+//
+func (f Format) Sprint(args ...interface{}) string {
+ var buf bytes.Buffer
+ _, err := f.Fprint(&buf, nil, args...)
+ if err != nil {
+ var i interface{} = args
+ fmt.Fprintf(&buf, "--- Sprint(%s) failed: %v", fmt.Sprint(i), err)
+ }
+ return buf.String()
+}
diff --git a/src/cmd/fix/testdata/reflect.datafmt.go.out b/src/cmd/fix/testdata/reflect.datafmt.go.out
new file mode 100644
index 000000000..fd447588b
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.datafmt.go.out
@@ -0,0 +1,710 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/* The datafmt package implements syntax-directed, type-driven formatting
+ of arbitrary data structures. Formatting a data structure consists of
+ two phases: first, a parser reads a format specification and builds a
+ "compiled" format. Then, the format can be applied repeatedly to
+ arbitrary values. Applying a format to a value evaluates to a []byte
+ containing the formatted value bytes, or nil.
+
+ A format specification is a set of package declarations and format rules:
+
+ Format = [ Entry { ";" Entry } [ ";" ] ] .
+ Entry = PackageDecl | FormatRule .
+
+ (The syntax of a format specification is presented in the same EBNF
+ notation as used in the Go language specification. The syntax of white
+ space, comments, identifiers, and string literals is the same as in Go.)
+
+ A package declaration binds a package name (such as 'ast') to a
+ package import path (such as '"go/ast"'). Each package used (in
+ a type name, see below) must be declared once before use.
+
+ PackageDecl = PackageName ImportPath .
+ PackageName = identifier .
+ ImportPath = string .
+
+ A format rule binds a rule name to a format expression. A rule name
+ may be a type name or one of the special names 'default' or '/'.
+ A type name may be the name of a predeclared type (for example, 'int',
+ 'float32', etc.), the package-qualified name of a user-defined type
+ (for example, 'ast.MapType'), or an identifier indicating the structure
+ of unnamed composite types ('array', 'chan', 'func', 'interface', 'map',
+ or 'ptr'). Each rule must have a unique name; rules can be declared in
+ any order.
+
+ FormatRule = RuleName "=" Expression .
+ RuleName = TypeName | "default" | "/" .
+ TypeName = [ PackageName "." ] identifier .
+
+ To format a value, the value's type name is used to select the format rule
+ (there is an override mechanism, see below). The format expression of the
+ selected rule specifies how the value is formatted. Each format expression,
+ when applied to a value, evaluates to a byte sequence or nil.
+
+ In its most general form, a format expression is a list of alternatives,
+ each of which is a sequence of operands:
+
+ Expression = [ Sequence ] { "|" [ Sequence ] } .
+ Sequence = Operand { Operand } .
+
+ The formatted result produced by an expression is the result of the first
+ alternative sequence that evaluates to a non-nil result; if there is no
+ such alternative, the expression evaluates to nil. The result produced by
+ an operand sequence is the concatenation of the results of its operands.
+ If any operand in the sequence evaluates to nil, the entire sequence
+ evaluates to nil.
+
+ There are five kinds of operands:
+
+ Operand = Literal | Field | Group | Option | Repetition .
+
+ Literals evaluate to themselves, with two substitutions. First,
+ %-formats expand in the manner of fmt.Printf, with the current value
+ passed as the parameter. Second, the current indentation (see below)
+ is inserted after every newline or form feed character.
+
+ Literal = string .
+
+ This table shows string literals applied to the value 42 and the
+ corresponding formatted result:
+
+ "foo" foo
+ "%x" 2a
+ "x = %d" x = 42
+ "%#x = %d" 0x2a = 42
+
+ A field operand is a field name optionally followed by an alternate
+ rule name. The field name may be an identifier or one of the special
+ names @ or *.
+
+ Field = FieldName [ ":" RuleName ] .
+ FieldName = identifier | "@" | "*" .
+
+ If the field name is an identifier, the current value must be a struct,
+ and there must be a field with that name in the struct. The same lookup
+ rules apply as in the Go language (for instance, the name of an anonymous
+ field is the unqualified type name). The field name denotes the field
+ value in the struct. If the field is not found, formatting is aborted
+ and an error message is returned. (TODO consider changing the semantics
+ such that if a field is not found, it evaluates to nil).
+
+ The special name '@' denotes the current value.
+
+ The meaning of the special name '*' depends on the type of the current
+ value:
+
+ array, slice types array, slice element (inside {} only, see below)
+ interfaces value stored in interface
+ pointers value pointed to by pointer
+
+ (Implementation restriction: channel, function and map types are not
+ supported due to missing reflection support).
+
+ Fields are evaluated as follows: If the field value is nil, or an array
+ or slice element does not exist, the result is nil (see below for details
+ on array/slice elements). If the value is not nil the field value is
+ formatted (recursively) using the rule corresponding to its type name,
+ or the alternate rule name, if given.
+
+ The following example shows a complete format specification for a
+ struct 'myPackage.Point'. Assume the package
+
+ package myPackage // in directory myDir/myPackage
+ type Point struct {
+ name string;
+ x, y int;
+ }
+
+ Applying the format specification
+
+ myPackage "myDir/myPackage";
+ int = "%d";
+ hexInt = "0x%x";
+ string = "---%s---";
+ myPackage.Point = name "{" x ", " y:hexInt "}";
+
+ to the value myPackage.Point{"foo", 3, 15} results in
+
+ ---foo---{3, 0xf}
+
+ Finally, an operand may be a grouped, optional, or repeated expression.
+ A grouped expression ("group") groups a more complex expression (body)
+ so that it can be used in place of a single operand:
+
+ Group = "(" [ Indentation ">>" ] Body ")" .
+ Indentation = Expression .
+ Body = Expression .
+
+ A group body may be prefixed by an indentation expression followed by '>>'.
+ The indentation expression is applied to the current value like any other
+ expression and the result, if not nil, is appended to the current indentation
+ during the evaluation of the body (see also formatting state, below).
+
+ An optional expression ("option") is enclosed in '[]' brackets.
+
+ Option = "[" Body "]" .
+
+ An option evaluates to its body, except that if the body evaluates to nil,
+ the option expression evaluates to an empty []byte. Thus an option's purpose
+ is to protect the expression containing the option from a nil operand.
+
+ A repeated expression ("repetition") is enclosed in '{}' braces.
+
+ Repetition = "{" Body [ "/" Separator ] "}" .
+ Separator = Expression .
+
+ A repeated expression is evaluated as follows: The body is evaluated
+ repeatedly and its results are concatenated until the body evaluates
+ to nil. The result of the repetition is the (possibly empty) concatenation,
+ but it is never nil. An implicit index is supplied for the evaluation of
+ the body: that index is used to address elements of arrays or slices. If
+ the corresponding elements do not exist, the field denoting the element
+ evaluates to nil (which in turn may terminate the repetition).
+
+ The body of a repetition may be followed by a '/' and a "separator"
+ expression. If the separator is present, it is invoked between repetitions
+ of the body.
+
+ The following example shows a complete format specification for formatting
+ a slice of unnamed type. Applying the specification
+
+ int = "%b";
+ array = { * / ", " }; // array is the type name for an unnamed slice
+
+ to the value '[]int{2, 3, 5, 7}' results in
+
+ 10, 11, 101, 111
+
+ Default rule: If a format rule named 'default' is present, it is used for
+ formatting a value if no other rule was found. A common default rule is
+
+ default = "%v"
+
+ to provide default formatting for basic types without having to specify
+ a specific rule for each basic type.
+
+ Global separator rule: If a format rule named '/' is present, it is
+ invoked with the current value between literals. If the separator
+ expression evaluates to nil, it is ignored.
+
+ For instance, a global separator rule may be used to punctuate a sequence
+ of values with commas. The rules:
+
+ default = "%v";
+ / = ", ";
+
+ will format an argument list by printing each one in its default format,
+ separated by a comma and a space.
+*/
+package datafmt
+
+import (
+ "bytes"
+ "fmt"
+ "go/token"
+ "io"
+ "os"
+ "reflect"
+ "runtime"
+)
+
+// ----------------------------------------------------------------------------
+// Format representation
+
+// Custom formatters implement the Formatter function type.
+// A formatter is invoked with the current formatting state, the
+// value to format, and the rule name under which the formatter
+// was installed (the same formatter function may be installed
+// under different names). The formatter may access the current state
+// to guide formatting and use State.Write to append to the state's
+// output.
+//
+// A formatter must return a boolean value indicating if it evaluated
+// to a non-nil value (true), or a nil value (false).
+//
+type Formatter func(state *State, value interface{}, ruleName string) bool
+
+// A FormatterMap is a set of custom formatters.
+// It maps a rule name to a formatter function.
+//
+type FormatterMap map[string]Formatter
+
+// A parsed format expression is built from the following nodes.
+//
+type (
+ expr interface{}
+
+ alternatives []expr // x | y | z
+
+ sequence []expr // x y z
+
+ literal [][]byte // a list of string segments, possibly starting with '%'
+
+ field struct {
+ fieldName string // including "@", "*"
+ ruleName string // "" if no rule name specified
+ }
+
+ group struct {
+ indent, body expr // (indent >> body)
+ }
+
+ option struct {
+ body expr // [body]
+ }
+
+ repetition struct {
+ body, separator expr // {body / separator}
+ }
+
+ custom struct {
+ ruleName string
+ fun Formatter
+ }
+)
+
+// A Format is the result of parsing a format specification.
+// The format may be applied repeatedly to format values.
+//
+type Format map[string]expr
+
+// ----------------------------------------------------------------------------
+// Formatting
+
+// An application-specific environment may be provided to Format.Apply;
+// the environment is available inside custom formatters via State.Env().
+// Environments must implement copying; the Copy method must return a
+// complete copy of the receiver. This is necessary so that the formatter
+// can save and restore an environment (in case of an absent expression).
+//
+// If the Environment doesn't change during formatting (this is under
+// control of the custom formatters), the Copy function can simply return
+// the receiver, and thus can be very light-weight.
+//
+type Environment interface {
+ Copy() Environment
+}
+
+// State represents the current formatting state.
+// It is provided as argument to custom formatters.
+//
+type State struct {
+ fmt Format // format in use
+ env Environment // user-supplied environment
+ errors chan os.Error // not chan *Error (errors <- nil would be wrong!)
+ hasOutput bool // true after the first literal has been written
+ indent bytes.Buffer // current indentation
+ output bytes.Buffer // format output
+ linePos token.Position // position of line beginning (Column == 0)
+ default_ expr // possibly nil
+ separator expr // possibly nil
+}
+
+func newState(fmt Format, env Environment, errors chan os.Error) *State {
+ s := new(State)
+ s.fmt = fmt
+ s.env = env
+ s.errors = errors
+ s.linePos = token.Position{Line: 1}
+
+	// if we have a default rule, cache its expression for fast access
+ if x, found := fmt["default"]; found {
+ s.default_ = x
+ }
+
+	// if we have a global separator rule, cache its expression for fast access
+ if x, found := fmt["/"]; found {
+ s.separator = x
+ }
+
+ return s
+}
+
+// Env returns the environment passed to Format.Apply.
+func (s *State) Env() interface{} { return s.env }
+
+// LinePos returns the position of the current line beginning
+// in the state's output buffer. Line numbers start at 1.
+//
+func (s *State) LinePos() token.Position { return s.linePos }
+
+// Pos returns the position of the next byte to be written to the
+// output buffer. Line numbers start at 1.
+//
+func (s *State) Pos() token.Position {
+ offs := s.output.Len()
+ return token.Position{Line: s.linePos.Line, Column: offs - s.linePos.Offset, Offset: offs}
+}
+
+// Write writes data to the output buffer, inserting the indentation
+// string after each newline or form feed character. It cannot return an error.
+//
+func (s *State) Write(data []byte) (int, os.Error) {
+ n := 0
+ i0 := 0
+ for i, ch := range data {
+ if ch == '\n' || ch == '\f' {
+ // write text segment and indentation
+ n1, _ := s.output.Write(data[i0 : i+1])
+ n2, _ := s.output.Write(s.indent.Bytes())
+ n += n1 + n2
+ i0 = i + 1
+ s.linePos.Offset = s.output.Len()
+ s.linePos.Line++
+ }
+ }
+ n3, _ := s.output.Write(data[i0:])
+ return n + n3, nil
+}
+
+type checkpoint struct {
+ env Environment
+ hasOutput bool
+ outputLen int
+ linePos token.Position
+}
+
+func (s *State) save() checkpoint {
+ saved := checkpoint{nil, s.hasOutput, s.output.Len(), s.linePos}
+ if s.env != nil {
+ saved.env = s.env.Copy()
+ }
+ return saved
+}
+
+func (s *State) restore(m checkpoint) {
+ s.env = m.env
+ s.output.Truncate(m.outputLen)
+}
+
+func (s *State) error(msg string) {
+ s.errors <- os.NewError(msg)
+ runtime.Goexit()
+}
+
+// TODO At the moment, unnamed types are simply mapped to the default
+// names below. For instance, all unnamed arrays are mapped to
+// 'array' which is not really sufficient. Eventually one may want
+// to be able to specify rules for say an unnamed slice of T.
+//
+
+func typename(typ reflect.Type) string {
+ switch typ.Kind() {
+ case reflect.Array:
+ return "array"
+ case reflect.Slice:
+ return "array"
+ case reflect.Chan:
+ return "chan"
+ case reflect.Func:
+ return "func"
+ case reflect.Interface:
+ return "interface"
+ case reflect.Map:
+ return "map"
+ case reflect.Ptr:
+ return "ptr"
+ }
+ return typ.String()
+}
+
+func (s *State) getFormat(name string) expr {
+ if fexpr, found := s.fmt[name]; found {
+ return fexpr
+ }
+
+ if s.default_ != nil {
+ return s.default_
+ }
+
+ s.error(fmt.Sprintf("no format rule for type: '%s'", name))
+ return nil
+}
+
+// eval applies a format expression fexpr to a value. If the expression
+// evaluates internally to a non-nil []byte, that slice is appended to
+// the state's output buffer and eval returns true. Otherwise, eval
+// returns false and the state remains unchanged.
+//
+func (s *State) eval(fexpr expr, value reflect.Value, index int) bool {
+ // an empty format expression always evaluates
+ // to a non-nil (but empty) []byte
+ if fexpr == nil {
+ return true
+ }
+
+ switch t := fexpr.(type) {
+ case alternatives:
+ // append the result of the first alternative that evaluates to
+ // a non-nil []byte to the state's output
+ mark := s.save()
+ for _, x := range t {
+ if s.eval(x, value, index) {
+ return true
+ }
+ s.restore(mark)
+ }
+ return false
+
+ case sequence:
+ // append the result of all operands to the state's output
+ // unless a nil result is encountered
+ mark := s.save()
+ for _, x := range t {
+ if !s.eval(x, value, index) {
+ s.restore(mark)
+ return false
+ }
+ }
+ return true
+
+ case literal:
+ // write separator, if any
+ if s.hasOutput {
+ // not the first literal
+ if s.separator != nil {
+ sep := s.separator // save current separator
+ s.separator = nil // and disable it (avoid recursion)
+ mark := s.save()
+ if !s.eval(sep, value, index) {
+ s.restore(mark)
+ }
+ s.separator = sep // enable it again
+ }
+ }
+ s.hasOutput = true
+ // write literal segments
+ for _, lit := range t {
+ if len(lit) > 1 && lit[0] == '%' {
+ // segment contains a %-format at the beginning
+ if lit[1] == '%' {
+ // "%%" is printed as a single "%"
+ s.Write(lit[1:])
+ } else {
+ // use s instead of s.output to get indentation right
+ fmt.Fprintf(s, string(lit), value.Interface())
+ }
+ } else {
+ // segment contains no %-formats
+ s.Write(lit)
+ }
+ }
+ return true // a literal never evaluates to nil
+
+ case *field:
+ // determine field value
+ switch t.fieldName {
+ case "@":
+ // field value is current value
+
+ case "*":
+ // indirection: operation is type-specific
+ switch v := value; v.Kind() {
+ case reflect.Array:
+ if v.Len() <= index {
+ return false
+ }
+ value = v.Index(index)
+
+ case reflect.Slice:
+ if v.IsNil() || v.Len() <= index {
+ return false
+ }
+ value = v.Index(index)
+
+ case reflect.Map:
+ s.error("reflection support for maps incomplete")
+
+ case reflect.Ptr:
+ if v.IsNil() {
+ return false
+ }
+ value = v.Elem()
+
+ case reflect.Interface:
+ if v.IsNil() {
+ return false
+ }
+ value = v.Elem()
+
+ case reflect.Chan:
+ s.error("reflection support for chans incomplete")
+
+ case reflect.Func:
+ s.error("reflection support for funcs incomplete")
+
+ default:
+ s.error(fmt.Sprintf("error: * does not apply to `%s`", value.Type()))
+ }
+
+ default:
+ // value is value of named field
+ var field reflect.Value
+ if sval := value; sval.Kind() == reflect.Struct {
+ field = sval.FieldByName(t.fieldName)
+ if !field.IsValid() {
+ // TODO consider just returning false in this case
+ s.error(fmt.Sprintf("error: no field `%s` in `%s`", t.fieldName, value.Type()))
+ }
+ }
+ value = field
+ }
+
+ // determine rule
+ ruleName := t.ruleName
+ if ruleName == "" {
+ // no alternate rule name, value type determines rule
+ ruleName = typename(value.Type())
+ }
+ fexpr = s.getFormat(ruleName)
+
+ mark := s.save()
+ if !s.eval(fexpr, value, index) {
+ s.restore(mark)
+ return false
+ }
+ return true
+
+ case *group:
+ // remember current indentation
+ indentLen := s.indent.Len()
+
+ // update current indentation
+ mark := s.save()
+ s.eval(t.indent, value, index)
+ // if the indentation evaluates to nil, the state's output buffer
+ // didn't change - either way it's ok to append the difference to
+		// the current indentation
+ s.indent.Write(s.output.Bytes()[mark.outputLen:s.output.Len()])
+ s.restore(mark)
+
+ // format group body
+ mark = s.save()
+ b := true
+ if !s.eval(t.body, value, index) {
+ s.restore(mark)
+ b = false
+ }
+
+ // reset indentation
+ s.indent.Truncate(indentLen)
+ return b
+
+ case *option:
+ // evaluate the body and append the result to the state's output
+ // buffer unless the result is nil
+ mark := s.save()
+ if !s.eval(t.body, value, 0) { // TODO is 0 index correct?
+ s.restore(mark)
+ }
+ return true // an option never evaluates to nil
+
+ case *repetition:
+ // evaluate the body and append the result to the state's output
+ // buffer until a result is nil
+ for i := 0; ; i++ {
+ mark := s.save()
+ // write separator, if any
+ if i > 0 && t.separator != nil {
+ // nil result from separator is ignored
+ mark := s.save()
+ if !s.eval(t.separator, value, i) {
+ s.restore(mark)
+ }
+ }
+ if !s.eval(t.body, value, i) {
+ s.restore(mark)
+ break
+ }
+ }
+ return true // a repetition never evaluates to nil
+
+ case *custom:
+ // invoke the custom formatter to obtain the result
+ mark := s.save()
+ if !t.fun(s, value.Interface(), t.ruleName) {
+ s.restore(mark)
+ return false
+ }
+ return true
+ }
+
+ panic("unreachable")
+ return false
+}
+
+// Eval formats each argument according to the format
+// f and returns the resulting []byte and os.Error. If
+// an error occurred, the []byte contains the partially
+// formatted result. An environment env may be passed
+// in which is available in custom formatters through
+// the state parameter.
+//
+func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) {
+ if f == nil {
+ return nil, os.NewError("format is nil")
+ }
+
+ errors := make(chan os.Error)
+ s := newState(f, env, errors)
+
+ go func() {
+ for _, v := range args {
+ fld := reflect.ValueOf(v)
+ if !fld.IsValid() {
+ errors <- os.NewError("nil argument")
+ return
+ }
+ mark := s.save()
+ if !s.eval(s.getFormat(typename(fld.Type())), fld, 0) { // TODO is 0 index correct?
+ s.restore(mark)
+ }
+ }
+ errors <- nil // no errors
+ }()
+
+ err := <-errors
+ return s.output.Bytes(), err
+}
+
+// ----------------------------------------------------------------------------
+// Convenience functions
+
+// Fprint formats each argument according to the format f
+// and writes to w. The result is the total number of bytes
+// written and an os.Error, if any.
+//
+func (f Format) Fprint(w io.Writer, env Environment, args ...interface{}) (int, os.Error) {
+ data, err := f.Eval(env, args...)
+ if err != nil {
+ // TODO should we print partial result in case of error?
+ return 0, err
+ }
+ return w.Write(data)
+}
+
+// Print formats each argument according to the format f
+// and writes to standard output. The result is the total
+// number of bytes written and an os.Error, if any.
+//
+func (f Format) Print(args ...interface{}) (int, os.Error) {
+ return f.Fprint(os.Stdout, nil, args...)
+}
+
+// Sprint formats each argument according to the format f
+// and returns the resulting string. If an error occurs
+// during formatting, the result string contains the
+// partially formatted result followed by an error message.
+//
+func (f Format) Sprint(args ...interface{}) string {
+ var buf bytes.Buffer
+ _, err := f.Fprint(&buf, nil, args...)
+ if err != nil {
+ var i interface{} = args
+ fmt.Fprintf(&buf, "--- Sprint(%s) failed: %v", fmt.Sprint(i), err)
+ }
+ return buf.String()
+}
diff --git a/src/cmd/fix/testdata/reflect.decode.go.in b/src/cmd/fix/testdata/reflect.decode.go.in
new file mode 100644
index 000000000..f831abee3
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.decode.go.in
@@ -0,0 +1,905 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Represents JSON data structure using native Go types: booleans, floats,
+// strings, arrays, and maps.
+
+package json
+
+import (
+ "container/vector"
+ "encoding/base64"
+ "os"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf16"
+ "utf8"
+)
+
+// Unmarshal parses the JSON-encoded data and stores the result
+// in the value pointed to by v.
+//
+// Unmarshal traverses the value v recursively.
+// If an encountered value implements the Unmarshaler interface,
+// Unmarshal calls its UnmarshalJSON method with a well-formed
+// JSON encoding.
+//
+// Otherwise, Unmarshal uses the inverse of the encodings that
+// Marshal uses, allocating maps, slices, and pointers as necessary,
+// with the following additional rules:
+//
+// To unmarshal a JSON value into a nil interface value, the
+// type stored in the interface value is one of:
+//
+// bool, for JSON booleans
+// float64, for JSON numbers
+// string, for JSON strings
+// []interface{}, for JSON arrays
+// map[string]interface{}, for JSON objects
+// nil for JSON null
+//
+// If a JSON value is not appropriate for a given target type,
+// or if a JSON number overflows the target type, Unmarshal
+// skips that field and completes the unmarshalling as best it can.
+// If no more serious errors are encountered, Unmarshal returns
+// an UnmarshalTypeError describing the earliest such error.
+//
+func Unmarshal(data []byte, v interface{}) os.Error {
+ d := new(decodeState).init(data)
+
+ // Quick check for well-formedness.
+ // Avoids filling out half a data structure
+ // before discovering a JSON syntax error.
+ err := checkValid(data, &d.scan)
+ if err != nil {
+ return err
+ }
+
+ return d.unmarshal(v)
+}
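+
+// For illustration only (not part of this file): unmarshalling into a nil
+// interface value yields the generic Go types listed above.
+//
+//	var v interface{}
+//	err := Unmarshal([]byte(`{"name":"foo","x":3}`), &v)
+//	// on success, v holds a map[string]interface{} with string and float64 values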
+
+// Unmarshaler is the interface implemented by objects
+// that can unmarshal a JSON description of themselves.
+// The input can be assumed to be a valid JSON object
+// encoding. UnmarshalJSON must copy the JSON data
+// if it wishes to retain the data after returning.
+type Unmarshaler interface {
+ UnmarshalJSON([]byte) os.Error
+}
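+
+// For illustration only (not part of this file), a type can satisfy
+// Unmarshaler to take over its own decoding; the pre-Go 1 APIs used in this
+// file (os.Error, strconv.Atof64) are assumed:
+//
+//	type Celsius float64
+//
+//	func (c *Celsius) UnmarshalJSON(data []byte) os.Error {
+//		f, err := strconv.Atof64(string(data))
+//		if err != nil {
+//			return err
+//		}
+//		*c = Celsius(f)
+//		return nil
+//	}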
+
+// An UnmarshalTypeError describes a JSON value that was
+// not appropriate for a value of a specific Go type.
+type UnmarshalTypeError struct {
+ Value string // description of JSON value - "bool", "array", "number -5"
+ Type reflect.Type // type of Go value it could not be assigned to
+}
+
+func (e *UnmarshalTypeError) String() string {
+ return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String()
+}
+
+// An UnmarshalFieldError describes a JSON object key that
+// led to an unexported (and therefore unwritable) struct field.
+type UnmarshalFieldError struct {
+ Key string
+ Type *reflect.StructType
+ Field reflect.StructField
+}
+
+func (e *UnmarshalFieldError) String() string {
+ return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String()
+}
+
+// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal.
+// (The argument to Unmarshal must be a non-nil pointer.)
+type InvalidUnmarshalError struct {
+ Type reflect.Type
+}
+
+func (e *InvalidUnmarshalError) String() string {
+ if e.Type == nil {
+ return "json: Unmarshal(nil)"
+ }
+
+ if _, ok := e.Type.(*reflect.PtrType); !ok {
+ return "json: Unmarshal(non-pointer " + e.Type.String() + ")"
+ }
+ return "json: Unmarshal(nil " + e.Type.String() + ")"
+}
+
+func (d *decodeState) unmarshal(v interface{}) (err os.Error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(runtime.Error); ok {
+ panic(r)
+ }
+ err = r.(os.Error)
+ }
+ }()
+
+ rv := reflect.NewValue(v)
+ pv, ok := rv.(*reflect.PtrValue)
+ if !ok || pv.IsNil() {
+ return &InvalidUnmarshalError{reflect.Typeof(v)}
+ }
+
+ d.scan.reset()
+ // We decode rv not pv.Elem because the Unmarshaler interface
+ // test must be applied at the top level of the value.
+ d.value(rv)
+ return d.savedError
+}
+
+// decodeState represents the state while decoding a JSON value.
+type decodeState struct {
+ data []byte
+ off int // read offset in data
+ scan scanner
+ nextscan scanner // for calls to nextValue
+ savedError os.Error
+}
+
+// errPhase is used for errors that should not happen unless
+// there is a bug in the JSON decoder or something is editing
+// the data slice while the decoder executes.
+var errPhase = os.NewError("JSON decoder out of sync - data changing underfoot?")
+
+func (d *decodeState) init(data []byte) *decodeState {
+ d.data = data
+ d.off = 0
+ d.savedError = nil
+ return d
+}
+
+// error aborts the decoding by panicking with err.
+func (d *decodeState) error(err os.Error) {
+ panic(err)
+}
+
+// saveError saves the first err it is called with,
+// for reporting at the end of the unmarshal.
+func (d *decodeState) saveError(err os.Error) {
+ if d.savedError == nil {
+ d.savedError = err
+ }
+}
+
+// next cuts off and returns the next full JSON value in d.data[d.off:].
+// The next value is known to be an object or array, not a literal.
+func (d *decodeState) next() []byte {
+ c := d.data[d.off]
+ item, rest, err := nextValue(d.data[d.off:], &d.nextscan)
+ if err != nil {
+ d.error(err)
+ }
+ d.off = len(d.data) - len(rest)
+
+ // Our scanner has seen the opening brace/bracket
+ // and thinks we're still in the middle of the object.
+	// Invent a closing brace/bracket to get it out.
+ if c == '{' {
+ d.scan.step(&d.scan, '}')
+ } else {
+ d.scan.step(&d.scan, ']')
+ }
+
+ return item
+}
+
+// scanWhile processes bytes in d.data[d.off:] until it
+// receives a scan code not equal to op.
+// It updates d.off and returns the new scan code.
+func (d *decodeState) scanWhile(op int) int {
+ var newOp int
+ for {
+ if d.off >= len(d.data) {
+ newOp = d.scan.eof()
+ d.off = len(d.data) + 1 // mark processed EOF with len+1
+ } else {
+ c := int(d.data[d.off])
+ d.off++
+ newOp = d.scan.step(&d.scan, c)
+ }
+ if newOp != op {
+ break
+ }
+ }
+ return newOp
+}
+
+// value decodes a JSON value from d.data[d.off:] into the value.
+// it updates d.off to point past the decoded value.
+func (d *decodeState) value(v reflect.Value) {
+ if v == nil {
+ _, rest, err := nextValue(d.data[d.off:], &d.nextscan)
+ if err != nil {
+ d.error(err)
+ }
+ d.off = len(d.data) - len(rest)
+
+ // d.scan thinks we're still at the beginning of the item.
+ // Feed in an empty string - the shortest, simplest value -
+ // so that it knows we got to the end of the value.
+ if d.scan.step == stateRedo {
+ panic("redo")
+ }
+ d.scan.step(&d.scan, '"')
+ d.scan.step(&d.scan, '"')
+ return
+ }
+
+ switch op := d.scanWhile(scanSkipSpace); op {
+ default:
+ d.error(errPhase)
+
+ case scanBeginArray:
+ d.array(v)
+
+ case scanBeginObject:
+ d.object(v)
+
+ case scanBeginLiteral:
+ d.literal(v)
+ }
+}
+
+// indirect walks down v allocating pointers as needed,
+// until it gets to a non-pointer.
+// if it encounters an Unmarshaler, indirect stops and returns that.
+// if wantptr is true, indirect stops at the last pointer.
+func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, reflect.Value) {
+ for {
+ var isUnmarshaler bool
+ if v.Type().NumMethod() > 0 {
+ // Remember that this is an unmarshaler,
+ // but wait to return it until after allocating
+ // the pointer (if necessary).
+ _, isUnmarshaler = v.Interface().(Unmarshaler)
+ }
+
+ if iv, ok := v.(*reflect.InterfaceValue); ok && !iv.IsNil() {
+ v = iv.Elem()
+ continue
+ }
+ pv, ok := v.(*reflect.PtrValue)
+ if !ok {
+ break
+ }
+ _, isptrptr := pv.Elem().(*reflect.PtrValue)
+ if !isptrptr && wantptr && !isUnmarshaler {
+ return nil, pv
+ }
+ if pv.IsNil() {
+ pv.PointTo(reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem()))
+ }
+ if isUnmarshaler {
+ // Using v.Interface().(Unmarshaler)
+ // here means that we have to use a pointer
+ // as the struct field. We cannot use a value inside
+ // a pointer to a struct, because in that case
+ // v.Interface() is the value (x.f) not the pointer (&x.f).
+ // This is an unfortunate consequence of reflect.
+ // An alternative would be to look up the
+ // UnmarshalJSON method and return a FuncValue.
+ return v.Interface().(Unmarshaler), nil
+ }
+ v = pv.Elem()
+ }
+ return nil, v
+}
+
+// array consumes an array from d.data[d.off-1:], decoding into the value v.
+// the first byte of the array ('[') has been read already.
+func (d *decodeState) array(v reflect.Value) {
+ // Check for unmarshaler.
+ unmarshaler, pv := d.indirect(v, false)
+ if unmarshaler != nil {
+ d.off--
+ err := unmarshaler.UnmarshalJSON(d.next())
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ // Decoding into nil interface? Switch to non-reflect code.
+ iv, ok := v.(*reflect.InterfaceValue)
+ if ok {
+ iv.Set(reflect.NewValue(d.arrayInterface()))
+ return
+ }
+
+ // Check type of target.
+ av, ok := v.(reflect.ArrayOrSliceValue)
+ if !ok {
+ d.saveError(&UnmarshalTypeError{"array", v.Type()})
+ d.off--
+ d.next()
+ return
+ }
+
+ sv, _ := v.(*reflect.SliceValue)
+
+ i := 0
+ for {
+ // Look ahead for ] - can only happen on first iteration.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+
+ // Back up so d.value can have the byte we just read.
+ d.off--
+ d.scan.undo(op)
+
+ // Get element of array, growing if necessary.
+ if i >= av.Cap() && sv != nil {
+ newcap := sv.Cap() + sv.Cap()/2
+ if newcap < 4 {
+ newcap = 4
+ }
+ newv := reflect.MakeSlice(sv.Type().(*reflect.SliceType), sv.Len(), newcap)
+ reflect.Copy(newv, sv)
+ sv.Set(newv)
+ }
+ if i >= av.Len() && sv != nil {
+ // Must be slice; gave up on array during i >= av.Cap().
+ sv.SetLen(i + 1)
+ }
+
+ // Decode into element.
+ if i < av.Len() {
+ d.value(av.Elem(i))
+ } else {
+ // Ran out of fixed array: skip.
+ d.value(nil)
+ }
+ i++
+
+ // Next token must be , or ].
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+ if op != scanArrayValue {
+ d.error(errPhase)
+ }
+ }
+ if i < av.Len() {
+ if sv == nil {
+ // Array. Zero the rest.
+ z := reflect.MakeZero(av.Type().(*reflect.ArrayType).Elem())
+ for ; i < av.Len(); i++ {
+ av.Elem(i).SetValue(z)
+ }
+ } else {
+ sv.SetLen(i)
+ }
+ }
+}
+
+// matchName returns true if key should be written to a field named name.
+func matchName(key, name string) bool {
+ return strings.ToLower(key) == strings.ToLower(name)
+}
+
+// object consumes an object from d.data[d.off-1:], decoding into the value v.
+// the first byte of the object ('{') has been read already.
+func (d *decodeState) object(v reflect.Value) {
+ // Check for unmarshaler.
+ unmarshaler, pv := d.indirect(v, false)
+ if unmarshaler != nil {
+ d.off--
+ err := unmarshaler.UnmarshalJSON(d.next())
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ // Decoding into nil interface? Switch to non-reflect code.
+ iv, ok := v.(*reflect.InterfaceValue)
+ if ok {
+ iv.Set(reflect.NewValue(d.objectInterface()))
+ return
+ }
+
+ // Check type of target: struct or map[string]T
+ var (
+ mv *reflect.MapValue
+ sv *reflect.StructValue
+ )
+ switch v := v.(type) {
+ case *reflect.MapValue:
+ // map must have string type
+ t := v.Type().(*reflect.MapType)
+ if t.Key() != reflect.Typeof("") {
+ d.saveError(&UnmarshalTypeError{"object", v.Type()})
+ break
+ }
+ mv = v
+ if mv.IsNil() {
+ mv.SetValue(reflect.MakeMap(t))
+ }
+ case *reflect.StructValue:
+ sv = v
+ default:
+ d.saveError(&UnmarshalTypeError{"object", v.Type()})
+ }
+
+ if mv == nil && sv == nil {
+ d.off--
+ d.next() // skip over { } in input
+ return
+ }
+
+ for {
+ // Read opening " of string key or closing }.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ // closing } - can only happen on first iteration.
+ break
+ }
+ if op != scanBeginLiteral {
+ d.error(errPhase)
+ }
+
+ // Read string key.
+ start := d.off - 1
+ op = d.scanWhile(scanContinue)
+ item := d.data[start : d.off-1]
+ key, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+
+ // Figure out field corresponding to key.
+ var subv reflect.Value
+ if mv != nil {
+ subv = reflect.MakeZero(mv.Type().(*reflect.MapType).Elem())
+ } else {
+ var f reflect.StructField
+ var ok bool
+ st := sv.Type().(*reflect.StructType)
+ // First try for field with that tag.
+ if isValidTag(key) {
+ for i := 0; i < sv.NumField(); i++ {
+ f = st.Field(i)
+ if f.Tag == key {
+ ok = true
+ break
+ }
+ }
+ }
+ if !ok {
+ // Second, exact match.
+ f, ok = st.FieldByName(key)
+ }
+ if !ok {
+ // Third, case-insensitive match.
+ f, ok = st.FieldByNameFunc(func(s string) bool { return matchName(key, s) })
+ }
+
+ // Extract value; name must be exported.
+ if ok {
+ if f.PkgPath != "" {
+ d.saveError(&UnmarshalFieldError{key, st, f})
+ } else {
+ subv = sv.FieldByIndex(f.Index)
+ }
+ }
+ }
+
+ // Read : before value.
+ if op == scanSkipSpace {
+ op = d.scanWhile(scanSkipSpace)
+ }
+ if op != scanObjectKey {
+ d.error(errPhase)
+ }
+
+ // Read value.
+ d.value(subv)
+
+ // Write value back to map;
+ // if using struct, subv points into struct already.
+ if mv != nil {
+ mv.SetElem(reflect.NewValue(key), subv)
+ }
+
+ // Next token must be , or }.
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ break
+ }
+ if op != scanObjectValue {
+ d.error(errPhase)
+ }
+ }
+}
+
+// literal consumes a literal from d.data[d.off-1:], decoding into the value v.
+// The first byte of the literal has been read already
+// (that's how the caller knows it's a literal).
+func (d *decodeState) literal(v reflect.Value) {
+ // All bytes inside literal return scanContinue op code.
+ start := d.off - 1
+ op := d.scanWhile(scanContinue)
+
+ // Scan read one byte too far; back up.
+ d.off--
+ d.scan.undo(op)
+ item := d.data[start:d.off]
+
+ // Check for unmarshaler.
+ wantptr := item[0] == 'n' // null
+ unmarshaler, pv := d.indirect(v, wantptr)
+ if unmarshaler != nil {
+ err := unmarshaler.UnmarshalJSON(item)
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ switch c := item[0]; c {
+ case 'n': // null
+ switch v.(type) {
+ default:
+ d.saveError(&UnmarshalTypeError{"null", v.Type()})
+ case *reflect.InterfaceValue, *reflect.PtrValue, *reflect.MapValue:
+ v.SetValue(nil)
+ }
+
+ case 't', 'f': // true, false
+ value := c == 't'
+ switch v := v.(type) {
+ default:
+ d.saveError(&UnmarshalTypeError{"bool", v.Type()})
+ case *reflect.BoolValue:
+ v.Set(value)
+ case *reflect.InterfaceValue:
+ v.Set(reflect.NewValue(value))
+ }
+
+ case '"': // string
+ s, ok := unquoteBytes(item)
+ if !ok {
+ d.error(errPhase)
+ }
+ switch v := v.(type) {
+ default:
+ d.saveError(&UnmarshalTypeError{"string", v.Type()})
+ case *reflect.SliceValue:
+ if v.Type() != byteSliceType {
+ d.saveError(&UnmarshalTypeError{"string", v.Type()})
+ break
+ }
+ b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
+ n, err := base64.StdEncoding.Decode(b, s)
+ if err != nil {
+ d.saveError(err)
+ break
+ }
+ v.Set(reflect.NewValue(b[0:n]).(*reflect.SliceValue))
+ case *reflect.StringValue:
+ v.Set(string(s))
+ case *reflect.InterfaceValue:
+ v.Set(reflect.NewValue(string(s)))
+ }
+
+ default: // number
+ if c != '-' && (c < '0' || c > '9') {
+ d.error(errPhase)
+ }
+ s := string(item)
+ switch v := v.(type) {
+ default:
+ d.error(&UnmarshalTypeError{"number", v.Type()})
+ case *reflect.InterfaceValue:
+ n, err := strconv.Atof64(s)
+ if err != nil {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.Set(reflect.NewValue(n))
+
+ case *reflect.IntValue:
+ n, err := strconv.Atoi64(s)
+ if err != nil || v.Overflow(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.Set(n)
+
+ case *reflect.UintValue:
+ n, err := strconv.Atoui64(s)
+ if err != nil || v.Overflow(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.Set(n)
+
+ case *reflect.FloatValue:
+ n, err := strconv.AtofN(s, v.Type().Bits())
+ if err != nil || v.Overflow(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.Set(n)
+ }
+ }
+}
+
+// The xxxInterface routines build up a value to be stored
+// in an empty interface. They are not strictly necessary,
+// but they avoid the weight of reflection in this common case.
+
+// valueInterface is like value but returns interface{}
+func (d *decodeState) valueInterface() interface{} {
+ switch d.scanWhile(scanSkipSpace) {
+ default:
+ d.error(errPhase)
+ case scanBeginArray:
+ return d.arrayInterface()
+ case scanBeginObject:
+ return d.objectInterface()
+ case scanBeginLiteral:
+ return d.literalInterface()
+ }
+ panic("unreachable")
+}
+
+// arrayInterface is like array but returns []interface{}.
+func (d *decodeState) arrayInterface() []interface{} {
+ var v vector.Vector
+ for {
+ // Look ahead for ] - can only happen on first iteration.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+
+ // Back up so d.value can have the byte we just read.
+ d.off--
+ d.scan.undo(op)
+
+ v.Push(d.valueInterface())
+
+ // Next token must be , or ].
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+ if op != scanArrayValue {
+ d.error(errPhase)
+ }
+ }
+ return v
+}
+
+// objectInterface is like object but returns map[string]interface{}.
+func (d *decodeState) objectInterface() map[string]interface{} {
+ m := make(map[string]interface{})
+ for {
+ // Read opening " of string key or closing }.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ // closing } - can only happen on first iteration.
+ break
+ }
+ if op != scanBeginLiteral {
+ d.error(errPhase)
+ }
+
+ // Read string key.
+ start := d.off - 1
+ op = d.scanWhile(scanContinue)
+ item := d.data[start : d.off-1]
+ key, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+
+ // Read : before value.
+ if op == scanSkipSpace {
+ op = d.scanWhile(scanSkipSpace)
+ }
+ if op != scanObjectKey {
+ d.error(errPhase)
+ }
+
+ // Read value.
+ m[key] = d.valueInterface()
+
+ // Next token must be , or }.
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ break
+ }
+ if op != scanObjectValue {
+ d.error(errPhase)
+ }
+ }
+ return m
+}
+
+// literalInterface is like literal but returns an interface value.
+func (d *decodeState) literalInterface() interface{} {
+ // All bytes inside literal return scanContinue op code.
+ start := d.off - 1
+ op := d.scanWhile(scanContinue)
+
+ // Scan read one byte too far; back up.
+ d.off--
+ d.scan.undo(op)
+ item := d.data[start:d.off]
+
+ switch c := item[0]; c {
+ case 'n': // null
+ return nil
+
+ case 't', 'f': // true, false
+ return c == 't'
+
+ case '"': // string
+ s, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+ return s
+
+ default: // number
+ if c != '-' && (c < '0' || c > '9') {
+ d.error(errPhase)
+ }
+ n, err := strconv.Atof64(string(item))
+ if err != nil {
+ d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.Typeof(0.0)})
+ }
+ return n
+ }
+ panic("unreachable")
+}
+
+// getu4 decodes \uXXXX from the beginning of s, returning the hex value,
+// or it returns -1.
+func getu4(s []byte) int {
+ if len(s) < 6 || s[0] != '\\' || s[1] != 'u' {
+ return -1
+ }
+ rune, err := strconv.Btoui64(string(s[2:6]), 16)
+ if err != nil {
+ return -1
+ }
+ return int(rune)
+}
+
+// unquote converts a quoted JSON string literal s into an actual string t.
+// The rules are different from Go's, so we cannot use strconv.Unquote.
+func unquote(s []byte) (t string, ok bool) {
+ s, ok = unquoteBytes(s)
+ t = string(s)
+ return
+}
+
+func unquoteBytes(s []byte) (t []byte, ok bool) {
+ if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' {
+ return
+ }
+ s = s[1 : len(s)-1]
+
+ // Check for unusual characters. If there are none,
+ // then no unquoting is needed, so return a slice of the
+ // original bytes.
+ r := 0
+ for r < len(s) {
+ c := s[r]
+ if c == '\\' || c == '"' || c < ' ' {
+ break
+ }
+ if c < utf8.RuneSelf {
+ r++
+ continue
+ }
+ rune, size := utf8.DecodeRune(s[r:])
+ if rune == utf8.RuneError && size == 1 {
+ break
+ }
+ r += size
+ }
+ if r == len(s) {
+ return s, true
+ }
+
+ b := make([]byte, len(s)+2*utf8.UTFMax)
+ w := copy(b, s[0:r])
+ for r < len(s) {
+ // Out of room? Can only happen if s is full of
+ // malformed UTF-8 and we're replacing each
+ // byte with RuneError.
+ if w >= len(b)-2*utf8.UTFMax {
+ nb := make([]byte, (len(b)+utf8.UTFMax)*2)
+ copy(nb, b[0:w])
+ b = nb
+ }
+ switch c := s[r]; {
+ case c == '\\':
+ r++
+ if r >= len(s) {
+ return
+ }
+ switch s[r] {
+ default:
+ return
+ case '"', '\\', '/', '\'':
+ b[w] = s[r]
+ r++
+ w++
+ case 'b':
+ b[w] = '\b'
+ r++
+ w++
+ case 'f':
+ b[w] = '\f'
+ r++
+ w++
+ case 'n':
+ b[w] = '\n'
+ r++
+ w++
+ case 'r':
+ b[w] = '\r'
+ r++
+ w++
+ case 't':
+ b[w] = '\t'
+ r++
+ w++
+ case 'u':
+ r--
+ rune := getu4(s[r:])
+ if rune < 0 {
+ return
+ }
+ r += 6
+ if utf16.IsSurrogate(rune) {
+ rune1 := getu4(s[r:])
+ if dec := utf16.DecodeRune(rune, rune1); dec != unicode.ReplacementChar {
+ // A valid pair; consume.
+ r += 6
+ w += utf8.EncodeRune(b[w:], dec)
+ break
+ }
+ // Invalid surrogate; fall back to replacement rune.
+ rune = unicode.ReplacementChar
+ }
+ w += utf8.EncodeRune(b[w:], rune)
+ }
+
+ // Quote, control characters are invalid.
+ case c == '"', c < ' ':
+ return
+
+ // ASCII
+ case c < utf8.RuneSelf:
+ b[w] = c
+ r++
+ w++
+
+ // Coerce to well-formed UTF-8.
+ default:
+ rune, size := utf8.DecodeRune(s[r:])
+ r += size
+ w += utf8.EncodeRune(b[w:], rune)
+ }
+ }
+ return b[0:w], true
+}
diff --git a/src/cmd/fix/testdata/reflect.decode.go.out b/src/cmd/fix/testdata/reflect.decode.go.out
new file mode 100644
index 000000000..fb7910ee3
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.decode.go.out
@@ -0,0 +1,908 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Represents JSON data structure using native Go types: booleans, floats,
+// strings, arrays, and maps.
+
+package json
+
+import (
+ "container/vector"
+ "encoding/base64"
+ "os"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf16"
+ "utf8"
+)
+
+// Unmarshal parses the JSON-encoded data and stores the result
+// in the value pointed to by v.
+//
+// Unmarshal traverses the value v recursively.
+// If an encountered value implements the Unmarshaler interface,
+// Unmarshal calls its UnmarshalJSON method with a well-formed
+// JSON encoding.
+//
+// Otherwise, Unmarshal uses the inverse of the encodings that
+// Marshal uses, allocating maps, slices, and pointers as necessary,
+// with the following additional rules:
+//
+// To unmarshal a JSON value into a nil interface value, the
+// type stored in the interface value is one of:
+//
+// bool, for JSON booleans
+// float64, for JSON numbers
+// string, for JSON strings
+// []interface{}, for JSON arrays
+// map[string]interface{}, for JSON objects
+// nil for JSON null
+//
+// If a JSON value is not appropriate for a given target type,
+// or if a JSON number overflows the target type, Unmarshal
+// skips that field and completes the unmarshalling as best it can.
+// If no more serious errors are encountered, Unmarshal returns
+// an UnmarshalTypeError describing the earliest such error.
+//
+func Unmarshal(data []byte, v interface{}) os.Error {
+ d := new(decodeState).init(data)
+
+ // Quick check for well-formedness.
+ // Avoids filling out half a data structure
+ // before discovering a JSON syntax error.
+ err := checkValid(data, &d.scan)
+ if err != nil {
+ return err
+ }
+
+ return d.unmarshal(v)
+}
+
+// Unmarshaler is the interface implemented by objects
+// that can unmarshal a JSON description of themselves.
+// The input can be assumed to be a valid JSON object
+// encoding. UnmarshalJSON must copy the JSON data
+// if it wishes to retain the data after returning.
+type Unmarshaler interface {
+ UnmarshalJSON([]byte) os.Error
+}
+
+// An UnmarshalTypeError describes a JSON value that was
+// not appropriate for a value of a specific Go type.
+type UnmarshalTypeError struct {
+ Value string // description of JSON value - "bool", "array", "number -5"
+ Type reflect.Type // type of Go value it could not be assigned to
+}
+
+func (e *UnmarshalTypeError) String() string {
+ return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String()
+}
+
+// An UnmarshalFieldError describes a JSON object key that
+// led to an unexported (and therefore unwritable) struct field.
+type UnmarshalFieldError struct {
+ Key string
+ Type reflect.Type
+ Field reflect.StructField
+}
+
+func (e *UnmarshalFieldError) String() string {
+ return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String()
+}
+
+// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal.
+// (The argument to Unmarshal must be a non-nil pointer.)
+type InvalidUnmarshalError struct {
+ Type reflect.Type
+}
+
+func (e *InvalidUnmarshalError) String() string {
+ if e.Type == nil {
+ return "json: Unmarshal(nil)"
+ }
+
+ if e.Type.Kind() != reflect.Ptr {
+ return "json: Unmarshal(non-pointer " + e.Type.String() + ")"
+ }
+ return "json: Unmarshal(nil " + e.Type.String() + ")"
+}
+
+func (d *decodeState) unmarshal(v interface{}) (err os.Error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(runtime.Error); ok {
+ panic(r)
+ }
+ err = r.(os.Error)
+ }
+ }()
+
+ rv := reflect.ValueOf(v)
+ pv := rv
+ if pv.Kind() != reflect.Ptr ||
+ pv.IsNil() {
+ return &InvalidUnmarshalError{reflect.TypeOf(v)}
+ }
+
+ d.scan.reset()
+ // We decode rv not pv.Elem because the Unmarshaler interface
+ // test must be applied at the top level of the value.
+ d.value(rv)
+ return d.savedError
+}
+
+// decodeState represents the state while decoding a JSON value.
+type decodeState struct {
+ data []byte
+ off int // read offset in data
+ scan scanner
+ nextscan scanner // for calls to nextValue
+ savedError os.Error
+}
+
+// errPhase is used for errors that should not happen unless
+// there is a bug in the JSON decoder or something is editing
+// the data slice while the decoder executes.
+var errPhase = os.NewError("JSON decoder out of sync - data changing underfoot?")
+
+func (d *decodeState) init(data []byte) *decodeState {
+ d.data = data
+ d.off = 0
+ d.savedError = nil
+ return d
+}
+
+// error aborts the decoding by panicking with err.
+func (d *decodeState) error(err os.Error) {
+ panic(err)
+}
+
+// saveError saves the first err it is called with,
+// for reporting at the end of the unmarshal.
+func (d *decodeState) saveError(err os.Error) {
+ if d.savedError == nil {
+ d.savedError = err
+ }
+}
+
+// next cuts off and returns the next full JSON value in d.data[d.off:].
+// The next value is known to be an object or array, not a literal.
+func (d *decodeState) next() []byte {
+ c := d.data[d.off]
+ item, rest, err := nextValue(d.data[d.off:], &d.nextscan)
+ if err != nil {
+ d.error(err)
+ }
+ d.off = len(d.data) - len(rest)
+
+ // Our scanner has seen the opening brace/bracket
+ // and thinks we're still in the middle of the object.
+ // invent a closing brace/bracket to get it out.
+ if c == '{' {
+ d.scan.step(&d.scan, '}')
+ } else {
+ d.scan.step(&d.scan, ']')
+ }
+
+ return item
+}
+
+// scanWhile processes bytes in d.data[d.off:] until it
+// receives a scan code not equal to op.
+// It updates d.off and returns the new scan code.
+func (d *decodeState) scanWhile(op int) int {
+ var newOp int
+ for {
+ if d.off >= len(d.data) {
+ newOp = d.scan.eof()
+ d.off = len(d.data) + 1 // mark processed EOF with len+1
+ } else {
+ c := int(d.data[d.off])
+ d.off++
+ newOp = d.scan.step(&d.scan, c)
+ }
+ if newOp != op {
+ break
+ }
+ }
+ return newOp
+}
+
+// value decodes a JSON value from d.data[d.off:] into the value.
+// it updates d.off to point past the decoded value.
+func (d *decodeState) value(v reflect.Value) {
+ if !v.IsValid() {
+ _, rest, err := nextValue(d.data[d.off:], &d.nextscan)
+ if err != nil {
+ d.error(err)
+ }
+ d.off = len(d.data) - len(rest)
+
+ // d.scan thinks we're still at the beginning of the item.
+ // Feed in an empty string - the shortest, simplest value -
+ // so that it knows we got to the end of the value.
+ if d.scan.step == stateRedo {
+ panic("redo")
+ }
+ d.scan.step(&d.scan, '"')
+ d.scan.step(&d.scan, '"')
+ return
+ }
+
+ switch op := d.scanWhile(scanSkipSpace); op {
+ default:
+ d.error(errPhase)
+
+ case scanBeginArray:
+ d.array(v)
+
+ case scanBeginObject:
+ d.object(v)
+
+ case scanBeginLiteral:
+ d.literal(v)
+ }
+}
+
+// indirect walks down v allocating pointers as needed,
+// until it gets to a non-pointer.
+// if it encounters an Unmarshaler, indirect stops and returns that.
+// if wantptr is true, indirect stops at the last pointer.
+func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, reflect.Value) {
+ for {
+ var isUnmarshaler bool
+ if v.Type().NumMethod() > 0 {
+ // Remember that this is an unmarshaler,
+ // but wait to return it until after allocating
+ // the pointer (if necessary).
+ _, isUnmarshaler = v.Interface().(Unmarshaler)
+ }
+
+ if iv := v; iv.Kind() == reflect.Interface && !iv.IsNil() {
+ v = iv.Elem()
+ continue
+ }
+ pv := v
+ if pv.Kind() != reflect.Ptr {
+ break
+ }
+
+ if pv.Elem().Kind() != reflect.Ptr &&
+ wantptr && !isUnmarshaler {
+ return nil, pv
+ }
+ if pv.IsNil() {
+ pv.Set(reflect.Zero(pv.Type().Elem()).Addr())
+ }
+ if isUnmarshaler {
+ // Using v.Interface().(Unmarshaler)
+ // here means that we have to use a pointer
+ // as the struct field. We cannot use a value inside
+ // a pointer to a struct, because in that case
+ // v.Interface() is the value (x.f) not the pointer (&x.f).
+ // This is an unfortunate consequence of reflect.
+ // An alternative would be to look up the
+ // UnmarshalJSON method and return a FuncValue.
+ return v.Interface().(Unmarshaler), reflect.Value{}
+ }
+ v = pv.Elem()
+ }
+ return nil, v
+}
+
+// array consumes an array from d.data[d.off-1:], decoding into the value v.
+// the first byte of the array ('[') has been read already.
+func (d *decodeState) array(v reflect.Value) {
+ // Check for unmarshaler.
+ unmarshaler, pv := d.indirect(v, false)
+ if unmarshaler != nil {
+ d.off--
+ err := unmarshaler.UnmarshalJSON(d.next())
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ // Decoding into nil interface? Switch to non-reflect code.
+ iv := v
+ ok := iv.Kind() == reflect.Interface
+ if ok {
+ iv.Set(reflect.ValueOf(d.arrayInterface()))
+ return
+ }
+
+ // Check type of target.
+ av := v
+ if av.Kind() != reflect.Array && av.Kind() != reflect.Slice {
+ d.saveError(&UnmarshalTypeError{"array", v.Type()})
+ d.off--
+ d.next()
+ return
+ }
+
+ sv := v
+
+ i := 0
+ for {
+ // Look ahead for ] - can only happen on first iteration.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+
+ // Back up so d.value can have the byte we just read.
+ d.off--
+ d.scan.undo(op)
+
+ // Get element of array, growing if necessary.
+ if i >= av.Cap() && sv.IsValid() {
+ newcap := sv.Cap() + sv.Cap()/2
+ if newcap < 4 {
+ newcap = 4
+ }
+ newv := reflect.MakeSlice(sv.Type(), sv.Len(), newcap)
+ reflect.Copy(newv, sv)
+ sv.Set(newv)
+ }
+ if i >= av.Len() && sv.IsValid() {
+ // Must be slice; gave up on array during i >= av.Cap().
+ sv.SetLen(i + 1)
+ }
+
+ // Decode into element.
+ if i < av.Len() {
+ d.value(av.Index(i))
+ } else {
+ // Ran out of fixed array: skip.
+ d.value(reflect.Value{})
+ }
+ i++
+
+ // Next token must be , or ].
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+ if op != scanArrayValue {
+ d.error(errPhase)
+ }
+ }
+ if i < av.Len() {
+ if !sv.IsValid() {
+ // Array. Zero the rest.
+ z := reflect.Zero(av.Type().Elem())
+ for ; i < av.Len(); i++ {
+ av.Index(i).Set(z)
+ }
+ } else {
+ sv.SetLen(i)
+ }
+ }
+}
+
+// matchName returns true if key should be written to a field named name.
+func matchName(key, name string) bool {
+ return strings.ToLower(key) == strings.ToLower(name)
+}
+
+// object consumes an object from d.data[d.off-1:], decoding into the value v.
+// the first byte of the object ('{') has been read already.
+func (d *decodeState) object(v reflect.Value) {
+ // Check for unmarshaler.
+ unmarshaler, pv := d.indirect(v, false)
+ if unmarshaler != nil {
+ d.off--
+ err := unmarshaler.UnmarshalJSON(d.next())
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ // Decoding into nil interface? Switch to non-reflect code.
+ iv := v
+ if iv.Kind() == reflect.Interface {
+ iv.Set(reflect.ValueOf(d.objectInterface()))
+ return
+ }
+
+ // Check type of target: struct or map[string]T
+ var (
+ mv reflect.Value
+ sv reflect.Value
+ )
+ switch v.Kind() {
+ case reflect.Map:
+ // map must have string type
+ t := v.Type()
+ if t.Key() != reflect.TypeOf("") {
+ d.saveError(&UnmarshalTypeError{"object", v.Type()})
+ break
+ }
+ mv = v
+ if mv.IsNil() {
+ mv.Set(reflect.MakeMap(t))
+ }
+ case reflect.Struct:
+ sv = v
+ default:
+ d.saveError(&UnmarshalTypeError{"object", v.Type()})
+ }
+
+ if !mv.IsValid() && !sv.IsValid() {
+ d.off--
+ d.next() // skip over { } in input
+ return
+ }
+
+ for {
+ // Read opening " of string key or closing }.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ // closing } - can only happen on first iteration.
+ break
+ }
+ if op != scanBeginLiteral {
+ d.error(errPhase)
+ }
+
+ // Read string key.
+ start := d.off - 1
+ op = d.scanWhile(scanContinue)
+ item := d.data[start : d.off-1]
+ key, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+
+ // Figure out field corresponding to key.
+ var subv reflect.Value
+ if mv.IsValid() {
+ subv = reflect.Zero(mv.Type().Elem())
+ } else {
+ var f reflect.StructField
+ var ok bool
+ st := sv.Type()
+ // First try for field with that tag.
+ if isValidTag(key) {
+ for i := 0; i < sv.NumField(); i++ {
+ f = st.Field(i)
+ if f.Tag == key {
+ ok = true
+ break
+ }
+ }
+ }
+ if !ok {
+ // Second, exact match.
+ f, ok = st.FieldByName(key)
+ }
+ if !ok {
+ // Third, case-insensitive match.
+ f, ok = st.FieldByNameFunc(func(s string) bool { return matchName(key, s) })
+ }
+
+ // Extract value; name must be exported.
+ if ok {
+ if f.PkgPath != "" {
+ d.saveError(&UnmarshalFieldError{key, st, f})
+ } else {
+ subv = sv.FieldByIndex(f.Index)
+ }
+ }
+ }
+
+ // Read : before value.
+ if op == scanSkipSpace {
+ op = d.scanWhile(scanSkipSpace)
+ }
+ if op != scanObjectKey {
+ d.error(errPhase)
+ }
+
+ // Read value.
+ d.value(subv)
+
+ // Write value back to map;
+ // if using struct, subv points into struct already.
+ if mv.IsValid() {
+ mv.SetMapIndex(reflect.ValueOf(key), subv)
+ }
+
+ // Next token must be , or }.
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ break
+ }
+ if op != scanObjectValue {
+ d.error(errPhase)
+ }
+ }
+}
+
+// literal consumes a literal from d.data[d.off-1:], decoding into the value v.
+// The first byte of the literal has been read already
+// (that's how the caller knows it's a literal).
+func (d *decodeState) literal(v reflect.Value) {
+ // All bytes inside literal return scanContinue op code.
+ start := d.off - 1
+ op := d.scanWhile(scanContinue)
+
+ // Scan read one byte too far; back up.
+ d.off--
+ d.scan.undo(op)
+ item := d.data[start:d.off]
+
+ // Check for unmarshaler.
+ wantptr := item[0] == 'n' // null
+ unmarshaler, pv := d.indirect(v, wantptr)
+ if unmarshaler != nil {
+ err := unmarshaler.UnmarshalJSON(item)
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ switch c := item[0]; c {
+ case 'n': // null
+ switch v.Kind() {
+ default:
+ d.saveError(&UnmarshalTypeError{"null", v.Type()})
+ case reflect.Interface, reflect.Ptr, reflect.Map:
+ v.Set(reflect.Zero(v.Type()))
+ }
+
+ case 't', 'f': // true, false
+ value := c == 't'
+ switch v.Kind() {
+ default:
+ d.saveError(&UnmarshalTypeError{"bool", v.Type()})
+ case reflect.Bool:
+ v.SetBool(value)
+ case reflect.Interface:
+ v.Set(reflect.ValueOf(value))
+ }
+
+ case '"': // string
+ s, ok := unquoteBytes(item)
+ if !ok {
+ d.error(errPhase)
+ }
+ switch v.Kind() {
+ default:
+ d.saveError(&UnmarshalTypeError{"string", v.Type()})
+ case reflect.Slice:
+ if v.Type() != byteSliceType {
+ d.saveError(&UnmarshalTypeError{"string", v.Type()})
+ break
+ }
+ b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
+ n, err := base64.StdEncoding.Decode(b, s)
+ if err != nil {
+ d.saveError(err)
+ break
+ }
+ v.Set(reflect.ValueOf(b[0:n]))
+ case reflect.String:
+ v.SetString(string(s))
+ case reflect.Interface:
+ v.Set(reflect.ValueOf(string(s)))
+ }
+
+ default: // number
+ if c != '-' && (c < '0' || c > '9') {
+ d.error(errPhase)
+ }
+ s := string(item)
+ switch v.Kind() {
+ default:
+ d.error(&UnmarshalTypeError{"number", v.Type()})
+ case reflect.Interface:
+ n, err := strconv.Atof64(s)
+ if err != nil {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.Set(reflect.ValueOf(n))
+
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ n, err := strconv.Atoi64(s)
+ if err != nil || v.OverflowInt(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.SetInt(n)
+
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ n, err := strconv.Atoui64(s)
+ if err != nil || v.OverflowUint(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.SetUint(n)
+
+ case reflect.Float32, reflect.Float64:
+ n, err := strconv.AtofN(s, v.Type().Bits())
+ if err != nil || v.OverflowFloat(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.SetFloat(n)
+ }
+ }
+}
+
+// The xxxInterface routines build up a value to be stored
+// in an empty interface. They are not strictly necessary,
+// but they avoid the weight of reflection in this common case.
+
+// valueInterface is like value but returns interface{}
+func (d *decodeState) valueInterface() interface{} {
+ switch d.scanWhile(scanSkipSpace) {
+ default:
+ d.error(errPhase)
+ case scanBeginArray:
+ return d.arrayInterface()
+ case scanBeginObject:
+ return d.objectInterface()
+ case scanBeginLiteral:
+ return d.literalInterface()
+ }
+ panic("unreachable")
+}
+
+// arrayInterface is like array but returns []interface{}.
+func (d *decodeState) arrayInterface() []interface{} {
+ var v vector.Vector
+ for {
+ // Look ahead for ] - can only happen on first iteration.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+
+ // Back up so d.value can have the byte we just read.
+ d.off--
+ d.scan.undo(op)
+
+ v.Push(d.valueInterface())
+
+ // Next token must be , or ].
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+ if op != scanArrayValue {
+ d.error(errPhase)
+ }
+ }
+ return v
+}
+
+// objectInterface is like object but returns map[string]interface{}.
+func (d *decodeState) objectInterface() map[string]interface{} {
+ m := make(map[string]interface{})
+ for {
+ // Read opening " of string key or closing }.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ // closing } - can only happen on first iteration.
+ break
+ }
+ if op != scanBeginLiteral {
+ d.error(errPhase)
+ }
+
+ // Read string key.
+ start := d.off - 1
+ op = d.scanWhile(scanContinue)
+ item := d.data[start : d.off-1]
+ key, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+
+ // Read : before value.
+ if op == scanSkipSpace {
+ op = d.scanWhile(scanSkipSpace)
+ }
+ if op != scanObjectKey {
+ d.error(errPhase)
+ }
+
+ // Read value.
+ m[key] = d.valueInterface()
+
+ // Next token must be , or }.
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ break
+ }
+ if op != scanObjectValue {
+ d.error(errPhase)
+ }
+ }
+ return m
+}
+
+// literalInterface is like literal but returns an interface value.
+func (d *decodeState) literalInterface() interface{} {
+ // All bytes inside literal return scanContinue op code.
+ start := d.off - 1
+ op := d.scanWhile(scanContinue)
+
+ // Scan read one byte too far; back up.
+ d.off--
+ d.scan.undo(op)
+ item := d.data[start:d.off]
+
+ switch c := item[0]; c {
+ case 'n': // null
+ return nil
+
+ case 't', 'f': // true, false
+ return c == 't'
+
+ case '"': // string
+ s, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+ return s
+
+ default: // number
+ if c != '-' && (c < '0' || c > '9') {
+ d.error(errPhase)
+ }
+ n, err := strconv.Atof64(string(item))
+ if err != nil {
+ d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.TypeOf(0.0)})
+ }
+ return n
+ }
+ panic("unreachable")
+}
+
+// getu4 decodes \uXXXX from the beginning of s, returning the hex value,
+// or it returns -1.
+func getu4(s []byte) int {
+ if len(s) < 6 || s[0] != '\\' || s[1] != 'u' {
+ return -1
+ }
+ rune, err := strconv.Btoui64(string(s[2:6]), 16)
+ if err != nil {
+ return -1
+ }
+ return int(rune)
+}
+
+// unquote converts a quoted JSON string literal s into an actual string t.
+// The rules are different from Go's, so we cannot use strconv.Unquote.
+func unquote(s []byte) (t string, ok bool) {
+ s, ok = unquoteBytes(s)
+ t = string(s)
+ return
+}
+
+func unquoteBytes(s []byte) (t []byte, ok bool) {
+ if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' {
+ return
+ }
+ s = s[1 : len(s)-1]
+
+ // Check for unusual characters. If there are none,
+ // then no unquoting is needed, so return a slice of the
+ // original bytes.
+ r := 0
+ for r < len(s) {
+ c := s[r]
+ if c == '\\' || c == '"' || c < ' ' {
+ break
+ }
+ if c < utf8.RuneSelf {
+ r++
+ continue
+ }
+ rune, size := utf8.DecodeRune(s[r:])
+ if rune == utf8.RuneError && size == 1 {
+ break
+ }
+ r += size
+ }
+ if r == len(s) {
+ return s, true
+ }
+
+ b := make([]byte, len(s)+2*utf8.UTFMax)
+ w := copy(b, s[0:r])
+ for r < len(s) {
+ // Out of room? Can only happen if s is full of
+ // malformed UTF-8 and we're replacing each
+ // byte with RuneError.
+ if w >= len(b)-2*utf8.UTFMax {
+ nb := make([]byte, (len(b)+utf8.UTFMax)*2)
+ copy(nb, b[0:w])
+ b = nb
+ }
+ switch c := s[r]; {
+ case c == '\\':
+ r++
+ if r >= len(s) {
+ return
+ }
+ switch s[r] {
+ default:
+ return
+ case '"', '\\', '/', '\'':
+ b[w] = s[r]
+ r++
+ w++
+ case 'b':
+ b[w] = '\b'
+ r++
+ w++
+ case 'f':
+ b[w] = '\f'
+ r++
+ w++
+ case 'n':
+ b[w] = '\n'
+ r++
+ w++
+ case 'r':
+ b[w] = '\r'
+ r++
+ w++
+ case 't':
+ b[w] = '\t'
+ r++
+ w++
+ case 'u':
+ r--
+ rune := getu4(s[r:])
+ if rune < 0 {
+ return
+ }
+ r += 6
+ if utf16.IsSurrogate(rune) {
+ rune1 := getu4(s[r:])
+ if dec := utf16.DecodeRune(rune, rune1); dec != unicode.ReplacementChar {
+ // A valid pair; consume.
+ r += 6
+ w += utf8.EncodeRune(b[w:], dec)
+ break
+ }
+ // Invalid surrogate; fall back to replacement rune.
+ rune = unicode.ReplacementChar
+ }
+ w += utf8.EncodeRune(b[w:], rune)
+ }
+
+ // Quote, control characters are invalid.
+ case c == '"', c < ' ':
+ return
+
+ // ASCII
+ case c < utf8.RuneSelf:
+ b[w] = c
+ r++
+ w++
+
+ // Coerce to well-formed UTF-8.
+ default:
+ rune, size := utf8.DecodeRune(s[r:])
+ r += size
+ w += utf8.EncodeRune(b[w:], rune)
+ }
+ }
+ return b[0:w], true
+}
diff --git a/src/cmd/fix/testdata/reflect.decoder.go.in b/src/cmd/fix/testdata/reflect.decoder.go.in
new file mode 100644
index 000000000..0ce9b06fd
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.decoder.go.in
@@ -0,0 +1,196 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "bufio"
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "sync"
+)
+
+// A Decoder manages the receipt of type and data information read from the
+// remote side of a connection.
+type Decoder struct {
+ mutex sync.Mutex // each item must be received atomically
+ r io.Reader // source of the data
+ buf bytes.Buffer // buffer for more efficient i/o from r
+ wireType map[typeId]*wireType // map from remote ID to local description
+ decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
+ ignorerCache map[typeId]**decEngine // ditto for ignored objects
+ freeList *decoderState // list of free decoderStates; avoids reallocation
+ countBuf []byte // used for decoding integers while parsing messages
+ tmp []byte // temporary storage for i/o; saves reallocating
+ err os.Error
+}
+
+// NewDecoder returns a new decoder that reads from the io.Reader.
+func NewDecoder(r io.Reader) *Decoder {
+ dec := new(Decoder)
+ dec.r = bufio.NewReader(r)
+ dec.wireType = make(map[typeId]*wireType)
+ dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
+ dec.ignorerCache = make(map[typeId]**decEngine)
+ dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes
+
+ return dec
+}
+
+// recvType loads the definition of a type.
+func (dec *Decoder) recvType(id typeId) {
+ // Have we already seen this type? That's an error
+ if id < firstUserId || dec.wireType[id] != nil {
+ dec.err = os.NewError("gob: duplicate type received")
+ return
+ }
+
+ // Type:
+ wire := new(wireType)
+ dec.decodeValue(tWireType, reflect.NewValue(wire))
+ if dec.err != nil {
+ return
+ }
+ // Remember we've seen this type.
+ dec.wireType[id] = wire
+}
+
+// recvMessage reads the next count-delimited item from the input. It is the converse
+// of Encoder.writeMessage. It returns false on EOF or other error reading the message.
+func (dec *Decoder) recvMessage() bool {
+ // Read a count.
+ nbytes, _, err := decodeUintReader(dec.r, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ return false
+ }
+ dec.readMessage(int(nbytes))
+ return dec.err == nil
+}
+
+// readMessage reads the next nbytes bytes from the input.
+func (dec *Decoder) readMessage(nbytes int) {
+ // Allocate the buffer.
+ if cap(dec.tmp) < nbytes {
+ dec.tmp = make([]byte, nbytes+100) // room to grow
+ }
+ dec.tmp = dec.tmp[:nbytes]
+
+ // Read the data
+ _, dec.err = io.ReadFull(dec.r, dec.tmp)
+ if dec.err != nil {
+ if dec.err == os.EOF {
+ dec.err = io.ErrUnexpectedEOF
+ }
+ return
+ }
+ dec.buf.Write(dec.tmp)
+}
+
+// toInt turns an encoded uint64 into an int, according to the marshaling rules.
+func toInt(x uint64) int64 {
+ i := int64(x >> 1)
+ if x&1 != 0 {
+ i = ^i
+ }
+ return i
+}
+
+func (dec *Decoder) nextInt() int64 {
+ n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ }
+ return toInt(n)
+}
+
+func (dec *Decoder) nextUint() uint64 {
+ n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ }
+ return n
+}
+
+// decodeTypeSequence parses:
+// TypeSequence
+// (TypeDefinition DelimitedTypeDefinition*)?
+// and returns the type id of the next value. It returns -1 at
+// EOF. Upon return, the remainder of dec.buf is the value to be
+// decoded. If this is an interface value, it can be ignored by
+// simply resetting that buffer.
+func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId {
+ for dec.err == nil {
+ if dec.buf.Len() == 0 {
+ if !dec.recvMessage() {
+ break
+ }
+ }
+ // Receive a type id.
+ id := typeId(dec.nextInt())
+ if id >= 0 {
+ // Value follows.
+ return id
+ }
+ // Type definition for (-id) follows.
+ dec.recvType(-id)
+ // When decoding an interface, after a type there may be a
+ // DelimitedValue still in the buffer. Skip its count.
+ // (Alternatively, the buffer is empty and the byte count
+ // will be absorbed by recvMessage.)
+ if dec.buf.Len() > 0 {
+ if !isInterface {
+ dec.err = os.NewError("extra data in buffer")
+ break
+ }
+ dec.nextUint()
+ }
+ }
+ return -1
+}
+
+// Decode reads the next value from the connection and stores
+// it in the data represented by the empty interface value.
+// If e is nil, the value will be discarded. Otherwise,
+// the value underlying e must be the correct type for the next
+// data item received, and must be a pointer.
+func (dec *Decoder) Decode(e interface{}) os.Error {
+ if e == nil {
+ return dec.DecodeValue(nil)
+ }
+ value := reflect.NewValue(e)
+ // If e represents a value as opposed to a pointer, the answer won't
+ // get back to the caller. Make sure it's a pointer.
+ if value.Type().Kind() != reflect.Ptr {
+ dec.err = os.NewError("gob: attempt to decode into a non-pointer")
+ return dec.err
+ }
+ return dec.DecodeValue(value)
+}
+
+// DecodeValue reads the next value from the connection and stores
+// it in the data represented by the reflection value.
+// The value must be the correct type for the next
+// data item received, or it may be nil, which means the
+// value will be discarded.
+func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
+ // Make sure we're single-threaded through here.
+ dec.mutex.Lock()
+ defer dec.mutex.Unlock()
+
+ dec.buf.Reset() // In case data lingers from previous invocation.
+ dec.err = nil
+ id := dec.decodeTypeSequence(false)
+ if dec.err == nil {
+ dec.decodeValue(id, value)
+ }
+ return dec.err
+}
+
+// If debug.go is compiled into the program, debugFunc prints a human-readable
+// representation of the gob data read from r by calling that file's Debug function.
+// Otherwise it is nil.
+var debugFunc func(io.Reader)
diff --git a/src/cmd/fix/testdata/reflect.decoder.go.out b/src/cmd/fix/testdata/reflect.decoder.go.out
new file mode 100644
index 000000000..ece88ecbe
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.decoder.go.out
@@ -0,0 +1,196 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "bufio"
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "sync"
+)
+
+// A Decoder manages the receipt of type and data information read from the
+// remote side of a connection.
+type Decoder struct {
+ mutex sync.Mutex // each item must be received atomically
+ r io.Reader // source of the data
+ buf bytes.Buffer // buffer for more efficient i/o from r
+ wireType map[typeId]*wireType // map from remote ID to local description
+ decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
+ ignorerCache map[typeId]**decEngine // ditto for ignored objects
+ freeList *decoderState // list of free decoderStates; avoids reallocation
+ countBuf []byte // used for decoding integers while parsing messages
+ tmp []byte // temporary storage for i/o; saves reallocating
+ err os.Error
+}
+
+// NewDecoder returns a new decoder that reads from the io.Reader.
+func NewDecoder(r io.Reader) *Decoder {
+ dec := new(Decoder)
+ dec.r = bufio.NewReader(r)
+ dec.wireType = make(map[typeId]*wireType)
+ dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
+ dec.ignorerCache = make(map[typeId]**decEngine)
+ dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes
+
+ return dec
+}
+
+// recvType loads the definition of a type.
+func (dec *Decoder) recvType(id typeId) {
+ // Have we already seen this type? That's an error
+ if id < firstUserId || dec.wireType[id] != nil {
+ dec.err = os.NewError("gob: duplicate type received")
+ return
+ }
+
+ // Type:
+ wire := new(wireType)
+ dec.decodeValue(tWireType, reflect.ValueOf(wire))
+ if dec.err != nil {
+ return
+ }
+ // Remember we've seen this type.
+ dec.wireType[id] = wire
+}
+
+// recvMessage reads the next count-delimited item from the input. It is the converse
+// of Encoder.writeMessage. It returns false on EOF or other error reading the message.
+func (dec *Decoder) recvMessage() bool {
+ // Read a count.
+ nbytes, _, err := decodeUintReader(dec.r, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ return false
+ }
+ dec.readMessage(int(nbytes))
+ return dec.err == nil
+}
+
+// readMessage reads the next nbytes bytes from the input.
+func (dec *Decoder) readMessage(nbytes int) {
+ // Allocate the buffer.
+ if cap(dec.tmp) < nbytes {
+ dec.tmp = make([]byte, nbytes+100) // room to grow
+ }
+ dec.tmp = dec.tmp[:nbytes]
+
+ // Read the data
+ _, dec.err = io.ReadFull(dec.r, dec.tmp)
+ if dec.err != nil {
+ if dec.err == os.EOF {
+ dec.err = io.ErrUnexpectedEOF
+ }
+ return
+ }
+ dec.buf.Write(dec.tmp)
+}
+
+// toInt turns an encoded uint64 into an int, according to the marshaling rules.
+func toInt(x uint64) int64 {
+ i := int64(x >> 1)
+ if x&1 != 0 {
+ i = ^i
+ }
+ return i
+}
+
+func (dec *Decoder) nextInt() int64 {
+ n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ }
+ return toInt(n)
+}
+
+func (dec *Decoder) nextUint() uint64 {
+ n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ }
+ return n
+}
+
+// decodeTypeSequence parses:
+// TypeSequence
+// (TypeDefinition DelimitedTypeDefinition*)?
+// and returns the type id of the next value. It returns -1 at
+// EOF. Upon return, the remainder of dec.buf is the value to be
+// decoded. If this is an interface value, it can be ignored by
+// simply resetting that buffer.
+func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId {
+ for dec.err == nil {
+ if dec.buf.Len() == 0 {
+ if !dec.recvMessage() {
+ break
+ }
+ }
+ // Receive a type id.
+ id := typeId(dec.nextInt())
+ if id >= 0 {
+ // Value follows.
+ return id
+ }
+ // Type definition for (-id) follows.
+ dec.recvType(-id)
+ // When decoding an interface, after a type there may be a
+ // DelimitedValue still in the buffer. Skip its count.
+ // (Alternatively, the buffer is empty and the byte count
+ // will be absorbed by recvMessage.)
+ if dec.buf.Len() > 0 {
+ if !isInterface {
+ dec.err = os.NewError("extra data in buffer")
+ break
+ }
+ dec.nextUint()
+ }
+ }
+ return -1
+}
+
+// Decode reads the next value from the connection and stores
+// it in the data represented by the empty interface value.
+// If e is nil, the value will be discarded. Otherwise,
+// the value underlying e must be the correct type for the next
+// data item received, and must be a pointer.
+func (dec *Decoder) Decode(e interface{}) os.Error {
+ if e == nil {
+ return dec.DecodeValue(reflect.Value{})
+ }
+ value := reflect.ValueOf(e)
+ // If e represents a value as opposed to a pointer, the answer won't
+ // get back to the caller. Make sure it's a pointer.
+ if value.Type().Kind() != reflect.Ptr {
+ dec.err = os.NewError("gob: attempt to decode into a non-pointer")
+ return dec.err
+ }
+ return dec.DecodeValue(value)
+}
+
+// DecodeValue reads the next value from the connection and stores
+// it in the data represented by the reflection value.
+// The value must be the correct type for the next
+// data item received, or it may be nil, which means the
+// value will be discarded.
+func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
+ // Make sure we're single-threaded through here.
+ dec.mutex.Lock()
+ defer dec.mutex.Unlock()
+
+ dec.buf.Reset() // In case data lingers from previous invocation.
+ dec.err = nil
+ id := dec.decodeTypeSequence(false)
+ if dec.err == nil {
+ dec.decodeValue(id, value)
+ }
+ return dec.err
+}
+
+// If debug.go is compiled into the program, debugFunc prints a human-readable
+// representation of the gob data read from r by calling that file's Debug function.
+// Otherwise it is nil.
+var debugFunc func(io.Reader)
diff --git a/src/cmd/fix/testdata/reflect.dnsmsg.go.in b/src/cmd/fix/testdata/reflect.dnsmsg.go.in
new file mode 100644
index 000000000..3d9c312f2
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.dnsmsg.go.in
@@ -0,0 +1,777 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// DNS packet assembly. See RFC 1035.
+//
+// This is intended to support name resolution during net.Dial.
+// It doesn't have to be blazing fast.
+//
+// Rather than write the usual handful of routines to pack and
+// unpack every message that can appear on the wire, we use
+// reflection to write a generic pack/unpack for structs and then
+// use it. Thus, if in the future we need to define new message
+// structs, no new pack/unpack/printing code needs to be written.
+//
+// The first half of this file defines the DNS message formats.
+// The second half implements the conversion to and from wire format.
+// A few of the structure elements have string tags to aid the
+// generic pack/unpack routines.
+//
+// TODO(rsc): There are enough names defined in this file that they're all
+// prefixed with dns. Perhaps put this in its own package later.
+
+package net
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+)
+
+// Packet formats
+
+// Wire constants.
+const (
+ // valid dnsRR_Header.Rrtype and dnsQuestion.qtype
+ dnsTypeA = 1
+ dnsTypeNS = 2
+ dnsTypeMD = 3
+ dnsTypeMF = 4
+ dnsTypeCNAME = 5
+ dnsTypeSOA = 6
+ dnsTypeMB = 7
+ dnsTypeMG = 8
+ dnsTypeMR = 9
+ dnsTypeNULL = 10
+ dnsTypeWKS = 11
+ dnsTypePTR = 12
+ dnsTypeHINFO = 13
+ dnsTypeMINFO = 14
+ dnsTypeMX = 15
+ dnsTypeTXT = 16
+ dnsTypeAAAA = 28
+ dnsTypeSRV = 33
+
+ // valid dnsQuestion.qtype only
+ dnsTypeAXFR = 252
+ dnsTypeMAILB = 253
+ dnsTypeMAILA = 254
+ dnsTypeALL = 255
+
+ // valid dnsQuestion.qclass
+ dnsClassINET = 1
+ dnsClassCSNET = 2
+ dnsClassCHAOS = 3
+ dnsClassHESIOD = 4
+ dnsClassANY = 255
+
+ // dnsMsg.rcode
+ dnsRcodeSuccess = 0
+ dnsRcodeFormatError = 1
+ dnsRcodeServerFailure = 2
+ dnsRcodeNameError = 3
+ dnsRcodeNotImplemented = 4
+ dnsRcodeRefused = 5
+)
+
+// The wire format for the DNS packet header.
+type dnsHeader struct {
+ Id uint16
+ Bits uint16
+ Qdcount, Ancount, Nscount, Arcount uint16
+}
+
+const (
+ // dnsHeader.Bits
+ _QR = 1 << 15 // query/response (response=1)
+ _AA = 1 << 10 // authoritative
+ _TC = 1 << 9 // truncated
+ _RD = 1 << 8 // recursion desired
+ _RA = 1 << 7 // recursion available
+)
+
+// DNS queries.
+type dnsQuestion struct {
+ Name string "domain-name" // "domain-name" specifies encoding; see packers below
+ Qtype uint16
+ Qclass uint16
+}
+
+// DNS responses (resource records).
+// There are many types of messages,
+// but they all share the same header.
+type dnsRR_Header struct {
+ Name string "domain-name"
+ Rrtype uint16
+ Class uint16
+ Ttl uint32
+ Rdlength uint16 // length of data after header
+}
+
+func (h *dnsRR_Header) Header() *dnsRR_Header {
+ return h
+}
+
+type dnsRR interface {
+ Header() *dnsRR_Header
+}
+
+// Specific DNS RR formats for each query type.
+
+type dnsRR_CNAME struct {
+ Hdr dnsRR_Header
+ Cname string "domain-name"
+}
+
+func (rr *dnsRR_CNAME) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_HINFO struct {
+ Hdr dnsRR_Header
+ Cpu string
+ Os string
+}
+
+func (rr *dnsRR_HINFO) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MB struct {
+ Hdr dnsRR_Header
+ Mb string "domain-name"
+}
+
+func (rr *dnsRR_MB) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MG struct {
+ Hdr dnsRR_Header
+ Mg string "domain-name"
+}
+
+func (rr *dnsRR_MG) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MINFO struct {
+ Hdr dnsRR_Header
+ Rmail string "domain-name"
+ Email string "domain-name"
+}
+
+func (rr *dnsRR_MINFO) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MR struct {
+ Hdr dnsRR_Header
+ Mr string "domain-name"
+}
+
+func (rr *dnsRR_MR) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MX struct {
+ Hdr dnsRR_Header
+ Pref uint16
+ Mx string "domain-name"
+}
+
+func (rr *dnsRR_MX) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_NS struct {
+ Hdr dnsRR_Header
+ Ns string "domain-name"
+}
+
+func (rr *dnsRR_NS) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_PTR struct {
+ Hdr dnsRR_Header
+ Ptr string "domain-name"
+}
+
+func (rr *dnsRR_PTR) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_SOA struct {
+ Hdr dnsRR_Header
+ Ns string "domain-name"
+ Mbox string "domain-name"
+ Serial uint32
+ Refresh uint32
+ Retry uint32
+ Expire uint32
+ Minttl uint32
+}
+
+func (rr *dnsRR_SOA) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_TXT struct {
+ Hdr dnsRR_Header
+ Txt string // not domain name
+}
+
+func (rr *dnsRR_TXT) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_SRV struct {
+ Hdr dnsRR_Header
+ Priority uint16
+ Weight uint16
+ Port uint16
+ Target string "domain-name"
+}
+
+func (rr *dnsRR_SRV) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_A struct {
+ Hdr dnsRR_Header
+ A uint32 "ipv4"
+}
+
+func (rr *dnsRR_A) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_AAAA struct {
+ Hdr dnsRR_Header
+ AAAA [16]byte "ipv6"
+}
+
+func (rr *dnsRR_AAAA) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+// Packing and unpacking.
+//
+// All the packers and unpackers take a (msg []byte, off int)
+// and return (off1 int, ok bool). If they return ok==false, they
+// also return off1==len(msg), so that the next unpacker will
+// also fail. This lets us avoid checks of ok until the end of a
+// packing sequence.
+
+// Map of constructors for each RR wire type.
+var rr_mk = map[int]func() dnsRR{
+ dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) },
+ dnsTypeHINFO: func() dnsRR { return new(dnsRR_HINFO) },
+ dnsTypeMB: func() dnsRR { return new(dnsRR_MB) },
+ dnsTypeMG: func() dnsRR { return new(dnsRR_MG) },
+ dnsTypeMINFO: func() dnsRR { return new(dnsRR_MINFO) },
+ dnsTypeMR: func() dnsRR { return new(dnsRR_MR) },
+ dnsTypeMX: func() dnsRR { return new(dnsRR_MX) },
+ dnsTypeNS: func() dnsRR { return new(dnsRR_NS) },
+ dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) },
+ dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) },
+ dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) },
+ dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) },
+ dnsTypeA: func() dnsRR { return new(dnsRR_A) },
+ dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) },
+}
+
+// Pack a domain name s into msg[off:].
+// Domain names are a sequence of counted strings
+// split at the dots. They end with a zero-length string.
+func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) {
+ // Add trailing dot to canonicalize name.
+ if n := len(s); n == 0 || s[n-1] != '.' {
+ s += "."
+ }
+
+ // Each dot ends a segment of the name.
+ // We trade each dot byte for a length byte.
+ // There is also a trailing zero.
+ // Check that we have all the space we need.
+ tot := len(s) + 1
+ if off+tot > len(msg) {
+ return len(msg), false
+ }
+
+ // Emit sequence of counted strings, chopping at dots.
+ begin := 0
+ for i := 0; i < len(s); i++ {
+ if s[i] == '.' {
+ if i-begin >= 1<<6 { // top two bits of length must be clear
+ return len(msg), false
+ }
+ msg[off] = byte(i - begin)
+ off++
+ for j := begin; j < i; j++ {
+ msg[off] = s[j]
+ off++
+ }
+ begin = i + 1
+ }
+ }
+ msg[off] = 0
+ off++
+ return off, true
+}
+
+// Unpack a domain name.
+// In addition to the simple sequences of counted strings above,
+// domain names are allowed to refer to strings elsewhere in the
+// packet, to avoid repeating common suffixes when returning
+// many entries in a single domain. The pointers are marked
+// by a length byte with the top two bits set. Ignoring those
+// two bits, that byte and the next give a 14 bit offset from msg[0]
+// where we should pick up the trail.
+// Note that if we jump elsewhere in the packet,
+// we return off1 == the offset after the first pointer we found,
+// which is where the next record will start.
+// In theory, the pointers are only allowed to jump backward.
+// We let them jump anywhere and stop jumping after a while.
+func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
+ s = ""
+ ptr := 0 // number of pointers followed
+Loop:
+ for {
+ if off >= len(msg) {
+ return "", len(msg), false
+ }
+ c := int(msg[off])
+ off++
+ switch c & 0xC0 {
+ case 0x00:
+ if c == 0x00 {
+ // end of name
+ break Loop
+ }
+ // literal string
+ if off+c > len(msg) {
+ return "", len(msg), false
+ }
+ s += string(msg[off:off+c]) + "."
+ off += c
+ case 0xC0:
+ // pointer to somewhere else in msg.
+ // remember location after first ptr,
+ // since that's how many bytes we consumed.
+ // also, don't follow too many pointers --
+ // maybe there's a loop.
+ if off >= len(msg) {
+ return "", len(msg), false
+ }
+ c1 := msg[off]
+ off++
+ if ptr == 0 {
+ off1 = off
+ }
+ if ptr++; ptr > 10 {
+ return "", len(msg), false
+ }
+ off = (c^0xC0)<<8 | int(c1)
+ default:
+ // 0x80 and 0x40 are reserved
+ return "", len(msg), false
+ }
+ }
+ if ptr == 0 {
+ off1 = off
+ }
+ return s, off1, true
+}
+
+// TODO(rsc): Move into generic library?
+// Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string,
+// [n]byte, and other (often anonymous) structs.
+func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) {
+ for i := 0; i < val.NumField(); i++ {
+ f := val.Type().(*reflect.StructType).Field(i)
+ switch fv := val.Field(i).(type) {
+ default:
+ BadType:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
+ return len(msg), false
+ case *reflect.StructValue:
+ off, ok = packStructValue(fv, msg, off)
+ case *reflect.UintValue:
+ i := fv.Get()
+ switch fv.Type().Kind() {
+ default:
+ goto BadType
+ case reflect.Uint16:
+ if off+2 > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(i >> 8)
+ msg[off+1] = byte(i)
+ off += 2
+ case reflect.Uint32:
+ if off+4 > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(i >> 24)
+ msg[off+1] = byte(i >> 16)
+ msg[off+2] = byte(i >> 8)
+ msg[off+3] = byte(i)
+ off += 4
+ }
+ case *reflect.ArrayValue:
+ if fv.Type().(*reflect.ArrayType).Elem().Kind() != reflect.Uint8 {
+ goto BadType
+ }
+ n := fv.Len()
+ if off+n > len(msg) {
+ return len(msg), false
+ }
+ reflect.Copy(reflect.NewValue(msg[off:off+n]).(*reflect.SliceValue), fv)
+ off += n
+ case *reflect.StringValue:
+ // There are multiple string encodings.
+ // The tag distinguishes ordinary strings from domain names.
+ s := fv.Get()
+ switch f.Tag {
+ default:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
+ return len(msg), false
+ case "domain-name":
+ off, ok = packDomainName(s, msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ case "":
+ // Counted string: 1 byte length.
+ if len(s) > 255 || off+1+len(s) > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(len(s))
+ off++
+ off += copy(msg[off:], s)
+ }
+ }
+ }
+ return off, true
+}
+
+func structValue(any interface{}) *reflect.StructValue {
+ return reflect.NewValue(any).(*reflect.PtrValue).Elem().(*reflect.StructValue)
+}
+
+func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
+ off, ok = packStructValue(structValue(any), msg, off)
+ return off, ok
+}
+
+// TODO(rsc): Move into generic library?
+// Unpack a reflect.StructValue from msg.
+// Same restrictions as packStructValue.
+func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) {
+ for i := 0; i < val.NumField(); i++ {
+ f := val.Type().(*reflect.StructType).Field(i)
+ switch fv := val.Field(i).(type) {
+ default:
+ BadType:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
+ return len(msg), false
+ case *reflect.StructValue:
+ off, ok = unpackStructValue(fv, msg, off)
+ case *reflect.UintValue:
+ switch fv.Type().Kind() {
+ default:
+ goto BadType
+ case reflect.Uint16:
+ if off+2 > len(msg) {
+ return len(msg), false
+ }
+ i := uint16(msg[off])<<8 | uint16(msg[off+1])
+ fv.Set(uint64(i))
+ off += 2
+ case reflect.Uint32:
+ if off+4 > len(msg) {
+ return len(msg), false
+ }
+ i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
+ fv.Set(uint64(i))
+ off += 4
+ }
+ case *reflect.ArrayValue:
+ if fv.Type().(*reflect.ArrayType).Elem().Kind() != reflect.Uint8 {
+ goto BadType
+ }
+ n := fv.Len()
+ if off+n > len(msg) {
+ return len(msg), false
+ }
+ reflect.Copy(fv, reflect.NewValue(msg[off:off+n]).(*reflect.SliceValue))
+ off += n
+ case *reflect.StringValue:
+ var s string
+ switch f.Tag {
+ default:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
+ return len(msg), false
+ case "domain-name":
+ s, off, ok = unpackDomainName(msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ case "":
+ if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
+ return len(msg), false
+ }
+ n := int(msg[off])
+ off++
+ b := make([]byte, n)
+ for i := 0; i < n; i++ {
+ b[i] = msg[off+i]
+ }
+ off += n
+ s = string(b)
+ }
+ fv.Set(s)
+ }
+ }
+ return off, true
+}
+
+func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
+ off, ok = unpackStructValue(structValue(any), msg, off)
+ return off, ok
+}
+
+// Generic struct printer.
+// Doesn't care about the string tag "domain-name",
+// but does look for an "ipv4" tag on uint32 variables
+// and the "ipv6" tag on array variables,
+// printing them as IP addresses.
+func printStructValue(val *reflect.StructValue) string {
+ s := "{"
+ for i := 0; i < val.NumField(); i++ {
+ if i > 0 {
+ s += ", "
+ }
+ f := val.Type().(*reflect.StructType).Field(i)
+ if !f.Anonymous {
+ s += f.Name + "="
+ }
+ fval := val.Field(i)
+ if fv, ok := fval.(*reflect.StructValue); ok {
+ s += printStructValue(fv)
+ } else if fv, ok := fval.(*reflect.UintValue); ok && f.Tag == "ipv4" {
+ i := fv.Get()
+ s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
+ } else if fv, ok := fval.(*reflect.ArrayValue); ok && f.Tag == "ipv6" {
+ i := fv.Interface().([]byte)
+ s += IP(i).String()
+ } else {
+ s += fmt.Sprint(fval.Interface())
+ }
+ }
+ s += "}"
+ return s
+}
+
+func printStruct(any interface{}) string { return printStructValue(structValue(any)) }
+
+// Resource record packer.
+func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
+ var off1 int
+ // pack twice, once to find end of header
+ // and again to find end of packet.
+ // a bit inefficient but this doesn't need to be fast.
+ // off1 is end of header
+ // off2 is end of rr
+ off1, ok = packStruct(rr.Header(), msg, off)
+ off2, ok = packStruct(rr, msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ // pack a third time; redo header with correct data length
+ rr.Header().Rdlength = uint16(off2 - off1)
+ packStruct(rr.Header(), msg, off)
+ return off2, true
+}
+
+// Resource record unpacker.
+func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) {
+ // unpack just the header, to find the rr type and length
+ var h dnsRR_Header
+ off0 := off
+ if off, ok = unpackStruct(&h, msg, off); !ok {
+ return nil, len(msg), false
+ }
+ end := off + int(h.Rdlength)
+
+ // make an rr of that type and re-unpack.
+ // again inefficient but doesn't need to be fast.
+ mk, known := rr_mk[int(h.Rrtype)]
+ if !known {
+ return &h, end, true
+ }
+ rr = mk()
+ off, ok = unpackStruct(rr, msg, off0)
+ if off != end {
+ return &h, end, true
+ }
+ return rr, off, ok
+}
+
+// Usable representation of a DNS packet.
+
+// A manually-unpacked version of (id, bits).
+// This is in its own struct for easy printing.
+type dnsMsgHdr struct {
+ id uint16
+ response bool
+ opcode int
+ authoritative bool
+ truncated bool
+ recursion_desired bool
+ recursion_available bool
+ rcode int
+}
+
+type dnsMsg struct {
+ dnsMsgHdr
+ question []dnsQuestion
+ answer []dnsRR
+ ns []dnsRR
+ extra []dnsRR
+}
+
+func (dns *dnsMsg) Pack() (msg []byte, ok bool) {
+ var dh dnsHeader
+
+ // Convert convenient dnsMsg into wire-like dnsHeader.
+ dh.Id = dns.id
+ dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode)
+ if dns.recursion_available {
+ dh.Bits |= _RA
+ }
+ if dns.recursion_desired {
+ dh.Bits |= _RD
+ }
+ if dns.truncated {
+ dh.Bits |= _TC
+ }
+ if dns.authoritative {
+ dh.Bits |= _AA
+ }
+ if dns.response {
+ dh.Bits |= _QR
+ }
+
+ // Prepare variable sized arrays.
+ question := dns.question
+ answer := dns.answer
+ ns := dns.ns
+ extra := dns.extra
+
+ dh.Qdcount = uint16(len(question))
+ dh.Ancount = uint16(len(answer))
+ dh.Nscount = uint16(len(ns))
+ dh.Arcount = uint16(len(extra))
+
+ // Could work harder to calculate message size,
+ // but this is far more than we need and not
+ // big enough to hurt the allocator.
+ msg = make([]byte, 2000)
+
+ // Pack it in: header and then the pieces.
+ off := 0
+ off, ok = packStruct(&dh, msg, off)
+ for i := 0; i < len(question); i++ {
+ off, ok = packStruct(&question[i], msg, off)
+ }
+ for i := 0; i < len(answer); i++ {
+ off, ok = packRR(answer[i], msg, off)
+ }
+ for i := 0; i < len(ns); i++ {
+ off, ok = packRR(ns[i], msg, off)
+ }
+ for i := 0; i < len(extra); i++ {
+ off, ok = packRR(extra[i], msg, off)
+ }
+ if !ok {
+ return nil, false
+ }
+ return msg[0:off], true
+}
+
+func (dns *dnsMsg) Unpack(msg []byte) bool {
+ // Header.
+ var dh dnsHeader
+ off := 0
+ var ok bool
+ if off, ok = unpackStruct(&dh, msg, off); !ok {
+ return false
+ }
+ dns.id = dh.Id
+ dns.response = (dh.Bits & _QR) != 0
+ dns.opcode = int(dh.Bits>>11) & 0xF
+ dns.authoritative = (dh.Bits & _AA) != 0
+ dns.truncated = (dh.Bits & _TC) != 0
+ dns.recursion_desired = (dh.Bits & _RD) != 0
+ dns.recursion_available = (dh.Bits & _RA) != 0
+ dns.rcode = int(dh.Bits & 0xF)
+
+ // Arrays.
+ dns.question = make([]dnsQuestion, dh.Qdcount)
+ dns.answer = make([]dnsRR, dh.Ancount)
+ dns.ns = make([]dnsRR, dh.Nscount)
+ dns.extra = make([]dnsRR, dh.Arcount)
+
+ for i := 0; i < len(dns.question); i++ {
+ off, ok = unpackStruct(&dns.question[i], msg, off)
+ }
+ for i := 0; i < len(dns.answer); i++ {
+ dns.answer[i], off, ok = unpackRR(msg, off)
+ }
+ for i := 0; i < len(dns.ns); i++ {
+ dns.ns[i], off, ok = unpackRR(msg, off)
+ }
+ for i := 0; i < len(dns.extra); i++ {
+ dns.extra[i], off, ok = unpackRR(msg, off)
+ }
+ if !ok {
+ return false
+ }
+ // if off != len(msg) {
+ // println("extra bytes in dns packet", off, "<", len(msg));
+ // }
+ return true
+}
+
+func (dns *dnsMsg) String() string {
+ s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n"
+ if len(dns.question) > 0 {
+ s += "-- Questions\n"
+ for i := 0; i < len(dns.question); i++ {
+ s += printStruct(&dns.question[i]) + "\n"
+ }
+ }
+ if len(dns.answer) > 0 {
+ s += "-- Answers\n"
+ for i := 0; i < len(dns.answer); i++ {
+ s += printStruct(dns.answer[i]) + "\n"
+ }
+ }
+ if len(dns.ns) > 0 {
+ s += "-- Name servers\n"
+ for i := 0; i < len(dns.ns); i++ {
+ s += printStruct(dns.ns[i]) + "\n"
+ }
+ }
+ if len(dns.extra) > 0 {
+ s += "-- Extra\n"
+ for i := 0; i < len(dns.extra); i++ {
+ s += printStruct(dns.extra[i]) + "\n"
+ }
+ }
+ return s
+}
diff --git a/src/cmd/fix/testdata/reflect.dnsmsg.go.out b/src/cmd/fix/testdata/reflect.dnsmsg.go.out
new file mode 100644
index 000000000..c777fe27c
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.dnsmsg.go.out
@@ -0,0 +1,777 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// DNS packet assembly. See RFC 1035.
+//
+// This is intended to support name resolution during net.Dial.
+// It doesn't have to be blazing fast.
+//
+// Rather than write the usual handful of routines to pack and
+// unpack every message that can appear on the wire, we use
+// reflection to write a generic pack/unpack for structs and then
+// use it. Thus, if in the future we need to define new message
+// structs, no new pack/unpack/printing code needs to be written.
+//
+// The first half of this file defines the DNS message formats.
+// The second half implements the conversion to and from wire format.
+// A few of the structure elements have string tags to aid the
+// generic pack/unpack routines.
+//
+// TODO(rsc): There are enough names defined in this file that they're all
+// prefixed with dns. Perhaps put this in its own package later.
+
+package net
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+)
+
+// Packet formats
+
+// Wire constants.
+const (
+ // valid dnsRR_Header.Rrtype and dnsQuestion.qtype
+ dnsTypeA = 1
+ dnsTypeNS = 2
+ dnsTypeMD = 3
+ dnsTypeMF = 4
+ dnsTypeCNAME = 5
+ dnsTypeSOA = 6
+ dnsTypeMB = 7
+ dnsTypeMG = 8
+ dnsTypeMR = 9
+ dnsTypeNULL = 10
+ dnsTypeWKS = 11
+ dnsTypePTR = 12
+ dnsTypeHINFO = 13
+ dnsTypeMINFO = 14
+ dnsTypeMX = 15
+ dnsTypeTXT = 16
+ dnsTypeAAAA = 28
+ dnsTypeSRV = 33
+
+ // valid dnsQuestion.qtype only
+ dnsTypeAXFR = 252
+ dnsTypeMAILB = 253
+ dnsTypeMAILA = 254
+ dnsTypeALL = 255
+
+ // valid dnsQuestion.qclass
+ dnsClassINET = 1
+ dnsClassCSNET = 2
+ dnsClassCHAOS = 3
+ dnsClassHESIOD = 4
+ dnsClassANY = 255
+
+ // dnsMsg.rcode
+ dnsRcodeSuccess = 0
+ dnsRcodeFormatError = 1
+ dnsRcodeServerFailure = 2
+ dnsRcodeNameError = 3
+ dnsRcodeNotImplemented = 4
+ dnsRcodeRefused = 5
+)
+
+// The wire format for the DNS packet header.
+type dnsHeader struct {
+ Id uint16
+ Bits uint16
+ Qdcount, Ancount, Nscount, Arcount uint16
+}
+
+const (
+ // dnsHeader.Bits
+ _QR = 1 << 15 // query/response (response=1)
+ _AA = 1 << 10 // authoritative
+ _TC = 1 << 9 // truncated
+ _RD = 1 << 8 // recursion desired
+ _RA = 1 << 7 // recursion available
+)
+
+// DNS queries.
+type dnsQuestion struct {
+ Name string "domain-name" // "domain-name" specifies encoding; see packers below
+ Qtype uint16
+ Qclass uint16
+}
+
+// DNS responses (resource records).
+// There are many types of messages,
+// but they all share the same header.
+type dnsRR_Header struct {
+ Name string "domain-name"
+ Rrtype uint16
+ Class uint16
+ Ttl uint32
+ Rdlength uint16 // length of data after header
+}
+
+func (h *dnsRR_Header) Header() *dnsRR_Header {
+ return h
+}
+
+type dnsRR interface {
+ Header() *dnsRR_Header
+}
+
+// Specific DNS RR formats for each query type.
+
+type dnsRR_CNAME struct {
+ Hdr dnsRR_Header
+ Cname string "domain-name"
+}
+
+func (rr *dnsRR_CNAME) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_HINFO struct {
+ Hdr dnsRR_Header
+ Cpu string
+ Os string
+}
+
+func (rr *dnsRR_HINFO) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MB struct {
+ Hdr dnsRR_Header
+ Mb string "domain-name"
+}
+
+func (rr *dnsRR_MB) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MG struct {
+ Hdr dnsRR_Header
+ Mg string "domain-name"
+}
+
+func (rr *dnsRR_MG) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MINFO struct {
+ Hdr dnsRR_Header
+ Rmail string "domain-name"
+ Email string "domain-name"
+}
+
+func (rr *dnsRR_MINFO) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MR struct {
+ Hdr dnsRR_Header
+ Mr string "domain-name"
+}
+
+func (rr *dnsRR_MR) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MX struct {
+ Hdr dnsRR_Header
+ Pref uint16
+ Mx string "domain-name"
+}
+
+func (rr *dnsRR_MX) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_NS struct {
+ Hdr dnsRR_Header
+ Ns string "domain-name"
+}
+
+func (rr *dnsRR_NS) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_PTR struct {
+ Hdr dnsRR_Header
+ Ptr string "domain-name"
+}
+
+func (rr *dnsRR_PTR) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_SOA struct {
+ Hdr dnsRR_Header
+ Ns string "domain-name"
+ Mbox string "domain-name"
+ Serial uint32
+ Refresh uint32
+ Retry uint32
+ Expire uint32
+ Minttl uint32
+}
+
+func (rr *dnsRR_SOA) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_TXT struct {
+ Hdr dnsRR_Header
+ Txt string // not domain name
+}
+
+func (rr *dnsRR_TXT) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_SRV struct {
+ Hdr dnsRR_Header
+ Priority uint16
+ Weight uint16
+ Port uint16
+ Target string "domain-name"
+}
+
+func (rr *dnsRR_SRV) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_A struct {
+ Hdr dnsRR_Header
+ A uint32 "ipv4"
+}
+
+func (rr *dnsRR_A) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_AAAA struct {
+ Hdr dnsRR_Header
+ AAAA [16]byte "ipv6"
+}
+
+func (rr *dnsRR_AAAA) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+// Packing and unpacking.
+//
+// All the packers and unpackers take a (msg []byte, off int)
+// and return (off1 int, ok bool). If they return ok==false, they
+// also return off1==len(msg), so that the next unpacker will
+// also fail. This lets us avoid checks of ok until the end of a
+// packing sequence.
+
+// Map of constructors for each RR wire type.
+var rr_mk = map[int]func() dnsRR{
+ dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) },
+ dnsTypeHINFO: func() dnsRR { return new(dnsRR_HINFO) },
+ dnsTypeMB: func() dnsRR { return new(dnsRR_MB) },
+ dnsTypeMG: func() dnsRR { return new(dnsRR_MG) },
+ dnsTypeMINFO: func() dnsRR { return new(dnsRR_MINFO) },
+ dnsTypeMR: func() dnsRR { return new(dnsRR_MR) },
+ dnsTypeMX: func() dnsRR { return new(dnsRR_MX) },
+ dnsTypeNS: func() dnsRR { return new(dnsRR_NS) },
+ dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) },
+ dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) },
+ dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) },
+ dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) },
+ dnsTypeA: func() dnsRR { return new(dnsRR_A) },
+ dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) },
+}
+
+// Pack a domain name s into msg[off:].
+// Domain names are a sequence of counted strings
+// split at the dots. They end with a zero-length string.
+func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) {
+ // Add trailing dot to canonicalize name.
+ if n := len(s); n == 0 || s[n-1] != '.' {
+ s += "."
+ }
+
+ // Each dot ends a segment of the name.
+ // We trade each dot byte for a length byte.
+ // There is also a trailing zero.
+ // Check that we have all the space we need.
+ tot := len(s) + 1
+ if off+tot > len(msg) {
+ return len(msg), false
+ }
+
+ // Emit sequence of counted strings, chopping at dots.
+ begin := 0
+ for i := 0; i < len(s); i++ {
+ if s[i] == '.' {
+ if i-begin >= 1<<6 { // top two bits of length must be clear
+ return len(msg), false
+ }
+ msg[off] = byte(i - begin)
+ off++
+ for j := begin; j < i; j++ {
+ msg[off] = s[j]
+ off++
+ }
+ begin = i + 1
+ }
+ }
+ msg[off] = 0
+ off++
+ return off, true
+}
+
+// Unpack a domain name.
+// In addition to the simple sequences of counted strings above,
+// domain names are allowed to refer to strings elsewhere in the
+// packet, to avoid repeating common suffixes when returning
+// many entries in a single domain. The pointers are marked
+// by a length byte with the top two bits set. Ignoring those
+// two bits, that byte and the next give a 14 bit offset from msg[0]
+// where we should pick up the trail.
+// Note that if we jump elsewhere in the packet,
+// we return off1 == the offset after the first pointer we found,
+// which is where the next record will start.
+// In theory, the pointers are only allowed to jump backward.
+// We let them jump anywhere and stop jumping after a while.
+func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
+ s = ""
+ ptr := 0 // number of pointers followed
+Loop:
+ for {
+ if off >= len(msg) {
+ return "", len(msg), false
+ }
+ c := int(msg[off])
+ off++
+ switch c & 0xC0 {
+ case 0x00:
+ if c == 0x00 {
+ // end of name
+ break Loop
+ }
+ // literal string
+ if off+c > len(msg) {
+ return "", len(msg), false
+ }
+ s += string(msg[off:off+c]) + "."
+ off += c
+ case 0xC0:
+ // pointer to somewhere else in msg.
+ // remember location after first ptr,
+ // since that's how many bytes we consumed.
+ // also, don't follow too many pointers --
+ // maybe there's a loop.
+ if off >= len(msg) {
+ return "", len(msg), false
+ }
+ c1 := msg[off]
+ off++
+ if ptr == 0 {
+ off1 = off
+ }
+ if ptr++; ptr > 10 {
+ return "", len(msg), false
+ }
+ off = (c^0xC0)<<8 | int(c1)
+ default:
+ // 0x80 and 0x40 are reserved
+ return "", len(msg), false
+ }
+ }
+ if ptr == 0 {
+ off1 = off
+ }
+ return s, off1, true
+}
+
+// TODO(rsc): Move into generic library?
+// Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string,
+// [n]byte, and other (often anonymous) structs.
+func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
+ for i := 0; i < val.NumField(); i++ {
+ f := val.Type().Field(i)
+ switch fv := val.Field(i); fv.Kind() {
+ default:
+ BadType:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
+ return len(msg), false
+ case reflect.Struct:
+ off, ok = packStructValue(fv, msg, off)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ i := fv.Uint()
+ switch fv.Type().Kind() {
+ default:
+ goto BadType
+ case reflect.Uint16:
+ if off+2 > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(i >> 8)
+ msg[off+1] = byte(i)
+ off += 2
+ case reflect.Uint32:
+ if off+4 > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(i >> 24)
+ msg[off+1] = byte(i >> 16)
+ msg[off+2] = byte(i >> 8)
+ msg[off+3] = byte(i)
+ off += 4
+ }
+ case reflect.Array:
+ if fv.Type().Elem().Kind() != reflect.Uint8 {
+ goto BadType
+ }
+ n := fv.Len()
+ if off+n > len(msg) {
+ return len(msg), false
+ }
+ reflect.Copy(reflect.ValueOf(msg[off:off+n]), fv)
+ off += n
+ case reflect.String:
+ // There are multiple string encodings.
+ // The tag distinguishes ordinary strings from domain names.
+ s := fv.String()
+ switch f.Tag {
+ default:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
+ return len(msg), false
+ case "domain-name":
+ off, ok = packDomainName(s, msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ case "":
+ // Counted string: 1 byte length.
+ if len(s) > 255 || off+1+len(s) > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(len(s))
+ off++
+ off += copy(msg[off:], s)
+ }
+ }
+ }
+ return off, true
+}
+
+func structValue(any interface{}) reflect.Value {
+ return reflect.ValueOf(any).Elem()
+}
+
+func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
+ off, ok = packStructValue(structValue(any), msg, off)
+ return off, ok
+}
+
+// TODO(rsc): Move into generic library?
+// Unpack a reflect.StructValue from msg.
+// Same restrictions as packStructValue.
+func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
+ for i := 0; i < val.NumField(); i++ {
+ f := val.Type().Field(i)
+ switch fv := val.Field(i); fv.Kind() {
+ default:
+ BadType:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
+ return len(msg), false
+ case reflect.Struct:
+ off, ok = unpackStructValue(fv, msg, off)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ switch fv.Type().Kind() {
+ default:
+ goto BadType
+ case reflect.Uint16:
+ if off+2 > len(msg) {
+ return len(msg), false
+ }
+ i := uint16(msg[off])<<8 | uint16(msg[off+1])
+ fv.SetUint(uint64(i))
+ off += 2
+ case reflect.Uint32:
+ if off+4 > len(msg) {
+ return len(msg), false
+ }
+ i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
+ fv.SetUint(uint64(i))
+ off += 4
+ }
+ case reflect.Array:
+ if fv.Type().Elem().Kind() != reflect.Uint8 {
+ goto BadType
+ }
+ n := fv.Len()
+ if off+n > len(msg) {
+ return len(msg), false
+ }
+ reflect.Copy(fv, reflect.ValueOf(msg[off:off+n]))
+ off += n
+ case reflect.String:
+ var s string
+ switch f.Tag {
+ default:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
+ return len(msg), false
+ case "domain-name":
+ s, off, ok = unpackDomainName(msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ case "":
+ if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
+ return len(msg), false
+ }
+ n := int(msg[off])
+ off++
+ b := make([]byte, n)
+ for i := 0; i < n; i++ {
+ b[i] = msg[off+i]
+ }
+ off += n
+ s = string(b)
+ }
+ fv.SetString(s)
+ }
+ }
+ return off, true
+}
+
+func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
+ off, ok = unpackStructValue(structValue(any), msg, off)
+ return off, ok
+}
+
+// Generic struct printer.
+// Doesn't care about the string tag "domain-name",
+// but does look for an "ipv4" tag on uint32 variables
+// and the "ipv6" tag on array variables,
+// printing them as IP addresses.
+func printStructValue(val reflect.Value) string {
+ s := "{"
+ for i := 0; i < val.NumField(); i++ {
+ if i > 0 {
+ s += ", "
+ }
+ f := val.Type().Field(i)
+ if !f.Anonymous {
+ s += f.Name + "="
+ }
+ fval := val.Field(i)
+ if fv := fval; fv.Kind() == reflect.Struct {
+ s += printStructValue(fv)
+ } else if fv := fval; (fv.Kind() == reflect.Uint || fv.Kind() == reflect.Uint8 || fv.Kind() == reflect.Uint16 || fv.Kind() == reflect.Uint32 || fv.Kind() == reflect.Uint64 || fv.Kind() == reflect.Uintptr) && f.Tag == "ipv4" {
+ i := fv.Uint()
+ s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
+ } else if fv := fval; fv.Kind() == reflect.Array && f.Tag == "ipv6" {
+ i := fv.Interface().([]byte)
+ s += IP(i).String()
+ } else {
+ s += fmt.Sprint(fval.Interface())
+ }
+ }
+ s += "}"
+ return s
+}
+
+func printStruct(any interface{}) string { return printStructValue(structValue(any)) }
+
+// Resource record packer.
+func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
+ var off1 int
+ // pack twice, once to find end of header
+ // and again to find end of packet.
+ // a bit inefficient but this doesn't need to be fast.
+ // off1 is end of header
+ // off2 is end of rr
+ off1, ok = packStruct(rr.Header(), msg, off)
+ off2, ok = packStruct(rr, msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ // pack a third time; redo header with correct data length
+ rr.Header().Rdlength = uint16(off2 - off1)
+ packStruct(rr.Header(), msg, off)
+ return off2, true
+}
+
+// Resource record unpacker.
+func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) {
+ // unpack just the header, to find the rr type and length
+ var h dnsRR_Header
+ off0 := off
+ if off, ok = unpackStruct(&h, msg, off); !ok {
+ return nil, len(msg), false
+ }
+ end := off + int(h.Rdlength)
+
+ // make an rr of that type and re-unpack.
+ // again inefficient but doesn't need to be fast.
+ mk, known := rr_mk[int(h.Rrtype)]
+ if !known {
+ return &h, end, true
+ }
+ rr = mk()
+ off, ok = unpackStruct(rr, msg, off0)
+ if off != end {
+ return &h, end, true
+ }
+ return rr, off, ok
+}
+
+// Usable representation of a DNS packet.
+
+// A manually-unpacked version of (id, bits).
+// This is in its own struct for easy printing.
+type dnsMsgHdr struct {
+ id uint16
+ response bool
+ opcode int
+ authoritative bool
+ truncated bool
+ recursion_desired bool
+ recursion_available bool
+ rcode int
+}
+
+type dnsMsg struct {
+ dnsMsgHdr
+ question []dnsQuestion
+ answer []dnsRR
+ ns []dnsRR
+ extra []dnsRR
+}
+
+func (dns *dnsMsg) Pack() (msg []byte, ok bool) {
+ var dh dnsHeader
+
+ // Convert convenient dnsMsg into wire-like dnsHeader.
+ dh.Id = dns.id
+ dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode)
+ if dns.recursion_available {
+ dh.Bits |= _RA
+ }
+ if dns.recursion_desired {
+ dh.Bits |= _RD
+ }
+ if dns.truncated {
+ dh.Bits |= _TC
+ }
+ if dns.authoritative {
+ dh.Bits |= _AA
+ }
+ if dns.response {
+ dh.Bits |= _QR
+ }
+
+ // Prepare variable sized arrays.
+ question := dns.question
+ answer := dns.answer
+ ns := dns.ns
+ extra := dns.extra
+
+ dh.Qdcount = uint16(len(question))
+ dh.Ancount = uint16(len(answer))
+ dh.Nscount = uint16(len(ns))
+ dh.Arcount = uint16(len(extra))
+
+ // Could work harder to calculate message size,
+ // but this is far more than we need and not
+ // big enough to hurt the allocator.
+ msg = make([]byte, 2000)
+
+ // Pack it in: header and then the pieces.
+ off := 0
+ off, ok = packStruct(&dh, msg, off)
+ for i := 0; i < len(question); i++ {
+ off, ok = packStruct(&question[i], msg, off)
+ }
+ for i := 0; i < len(answer); i++ {
+ off, ok = packRR(answer[i], msg, off)
+ }
+ for i := 0; i < len(ns); i++ {
+ off, ok = packRR(ns[i], msg, off)
+ }
+ for i := 0; i < len(extra); i++ {
+ off, ok = packRR(extra[i], msg, off)
+ }
+ if !ok {
+ return nil, false
+ }
+ return msg[0:off], true
+}
+
+func (dns *dnsMsg) Unpack(msg []byte) bool {
+ // Header.
+ var dh dnsHeader
+ off := 0
+ var ok bool
+ if off, ok = unpackStruct(&dh, msg, off); !ok {
+ return false
+ }
+ dns.id = dh.Id
+ dns.response = (dh.Bits & _QR) != 0
+ dns.opcode = int(dh.Bits>>11) & 0xF
+ dns.authoritative = (dh.Bits & _AA) != 0
+ dns.truncated = (dh.Bits & _TC) != 0
+ dns.recursion_desired = (dh.Bits & _RD) != 0
+ dns.recursion_available = (dh.Bits & _RA) != 0
+ dns.rcode = int(dh.Bits & 0xF)
+
+ // Arrays.
+ dns.question = make([]dnsQuestion, dh.Qdcount)
+ dns.answer = make([]dnsRR, dh.Ancount)
+ dns.ns = make([]dnsRR, dh.Nscount)
+ dns.extra = make([]dnsRR, dh.Arcount)
+
+ for i := 0; i < len(dns.question); i++ {
+ off, ok = unpackStruct(&dns.question[i], msg, off)
+ }
+ for i := 0; i < len(dns.answer); i++ {
+ dns.answer[i], off, ok = unpackRR(msg, off)
+ }
+ for i := 0; i < len(dns.ns); i++ {
+ dns.ns[i], off, ok = unpackRR(msg, off)
+ }
+ for i := 0; i < len(dns.extra); i++ {
+ dns.extra[i], off, ok = unpackRR(msg, off)
+ }
+ if !ok {
+ return false
+ }
+ // if off != len(msg) {
+ // println("extra bytes in dns packet", off, "<", len(msg));
+ // }
+ return true
+}
+
+func (dns *dnsMsg) String() string {
+ s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n"
+ if len(dns.question) > 0 {
+ s += "-- Questions\n"
+ for i := 0; i < len(dns.question); i++ {
+ s += printStruct(&dns.question[i]) + "\n"
+ }
+ }
+ if len(dns.answer) > 0 {
+ s += "-- Answers\n"
+ for i := 0; i < len(dns.answer); i++ {
+ s += printStruct(dns.answer[i]) + "\n"
+ }
+ }
+ if len(dns.ns) > 0 {
+ s += "-- Name servers\n"
+ for i := 0; i < len(dns.ns); i++ {
+ s += printStruct(dns.ns[i]) + "\n"
+ }
+ }
+ if len(dns.extra) > 0 {
+ s += "-- Extra\n"
+ for i := 0; i < len(dns.extra); i++ {
+ s += printStruct(dns.extra[i]) + "\n"
+ }
+ }
+ return s
+}
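
packDomainName and unpackDomainName above encode names as length-prefixed labels terminated by a zero byte, with two-byte pointers (top two bits of the length byte set) used for compression when unpacking. Below is a standalone sketch of just the label encoding and the pointer arithmetic; packName and the demo values are illustrative, not part of the fixture.

    package main

    import (
    	"fmt"
    	"strings"
    )

    // packName encodes a name as counted labels ending in a zero byte,
    // the same wire layout packDomainName produces.
    func packName(s string) ([]byte, error) {
    	if !strings.HasSuffix(s, ".") {
    		s += "." // canonicalize, as packDomainName does
    	}
    	var out []byte
    	begin := 0
    	for i := 0; i < len(s); i++ {
    		if s[i] != '.' {
    			continue
    		}
    		if i-begin >= 1<<6 { // top two bits of a length byte must be clear
    			return nil, fmt.Errorf("label too long")
    		}
    		out = append(out, byte(i-begin))
    		out = append(out, s[begin:i]...)
    		begin = i + 1
    	}
    	return append(out, 0), nil
    }

    func main() {
    	b, _ := packName("go.dev")
    	fmt.Printf("% x\n", b) // 02 67 6f 03 64 65 76 00

    	// A compression pointer is two bytes: 0xC0|high(offset), low(offset).
    	// unpackDomainName recognizes it by the top two bits of the length byte.
    	ptr := []byte{0xC0 | 0x00, 0x0c} // points back to offset 12 in the message
    	fmt.Printf("pointer to offset %d\n", int(ptr[0]&^0xC0)<<8|int(ptr[1]))
    }
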
diff --git a/src/cmd/fix/testdata/reflect.encode.go.in b/src/cmd/fix/testdata/reflect.encode.go.in
new file mode 100644
index 000000000..26ce47039
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.encode.go.in
@@ -0,0 +1,367 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The json package implements encoding and decoding of JSON objects as
+// defined in RFC 4627.
+package json
+
+import (
+ "bytes"
+ "encoding/base64"
+ "os"
+ "reflect"
+ "runtime"
+ "sort"
+ "strconv"
+ "unicode"
+ "utf8"
+)
+
+// Marshal returns the JSON encoding of v.
+//
+// Marshal traverses the value v recursively.
+// If an encountered value implements the Marshaler interface,
+// Marshal calls its MarshalJSON method to produce JSON.
+//
+// Otherwise, Marshal uses the following type-dependent default encodings:
+//
+// Boolean values encode as JSON booleans.
+//
+// Floating point and integer values encode as JSON numbers.
+//
+// String values encode as JSON strings, with each invalid UTF-8 sequence
+// replaced by the encoding of the Unicode replacement character U+FFFD.
+//
+// Array and slice values encode as JSON arrays, except that
+// []byte encodes as a base64-encoded string.
+//
+// Struct values encode as JSON objects. Each struct field becomes
+// a member of the object. By default the object's key name is the
+// struct field name. If the struct field has a non-empty tag consisting
+// of only Unicode letters, digits, and underscores, that tag will be used
+// as the name instead. Only exported fields will be encoded.
+//
+// Map values encode as JSON objects.
+// The map's key type must be string; the object keys are used directly
+// as map keys.
+//
+// Pointer values encode as the value pointed to.
+// A nil pointer encodes as the null JSON object.
+//
+// Interface values encode as the value contained in the interface.
+// A nil interface value encodes as the null JSON object.
+//
+// Channel, complex, and function values cannot be encoded in JSON.
+// Attempting to encode such a value causes Marshal to return
+// an InvalidTypeError.
+//
+// JSON cannot represent cyclic data structures and Marshal does not
+// handle them. Passing cyclic structures to Marshal will result in
+// an infinite recursion.
+//
+func Marshal(v interface{}) ([]byte, os.Error) {
+ e := &encodeState{}
+ err := e.marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ return e.Bytes(), nil
+}
+
+// MarshalIndent is like Marshal but applies Indent to format the output.
+func MarshalIndent(v interface{}, prefix, indent string) ([]byte, os.Error) {
+ b, err := Marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ var buf bytes.Buffer
+ err = Indent(&buf, b, prefix, indent)
+ if err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+// MarshalForHTML is like Marshal but applies HTMLEscape to the output.
+func MarshalForHTML(v interface{}) ([]byte, os.Error) {
+ b, err := Marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ var buf bytes.Buffer
+ HTMLEscape(&buf, b)
+ return buf.Bytes(), nil
+}
+
+// HTMLEscape appends to dst the JSON-encoded src with <, >, and &
+// characters inside string literals changed to \u003c, \u003e, \u0026
+// so that the JSON will be safe to embed inside HTML <script> tags.
+// For historical reasons, web browsers don't honor standard HTML
+// escaping within <script> tags, so an alternative JSON encoding must
+// be used.
+func HTMLEscape(dst *bytes.Buffer, src []byte) {
+ // < > & can only appear in string literals,
+ // so just scan the string one byte at a time.
+ start := 0
+ for i, c := range src {
+ if c == '<' || c == '>' || c == '&' {
+ if start < i {
+ dst.Write(src[start:i])
+ }
+ dst.WriteString(`\u00`)
+ dst.WriteByte(hex[c>>4])
+ dst.WriteByte(hex[c&0xF])
+ start = i + 1
+ }
+ }
+ if start < len(src) {
+ dst.Write(src[start:])
+ }
+}
+
+// Marshaler is the interface implemented by objects that
+// can marshal themselves into valid JSON.
+type Marshaler interface {
+ MarshalJSON() ([]byte, os.Error)
+}
+
+type UnsupportedTypeError struct {
+ Type reflect.Type
+}
+
+func (e *UnsupportedTypeError) String() string {
+ return "json: unsupported type: " + e.Type.String()
+}
+
+type InvalidUTF8Error struct {
+ S string
+}
+
+func (e *InvalidUTF8Error) String() string {
+ return "json: invalid UTF-8 in string: " + strconv.Quote(e.S)
+}
+
+type MarshalerError struct {
+ Type reflect.Type
+ Error os.Error
+}
+
+func (e *MarshalerError) String() string {
+ return "json: error calling MarshalJSON for type " + e.Type.String() + ": " + e.Error.String()
+}
+
+type interfaceOrPtrValue interface {
+ IsNil() bool
+ Elem() reflect.Value
+}
+
+var hex = "0123456789abcdef"
+
+// An encodeState encodes JSON into a bytes.Buffer.
+type encodeState struct {
+ bytes.Buffer // accumulated output
+}
+
+func (e *encodeState) marshal(v interface{}) (err os.Error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(runtime.Error); ok {
+ panic(r)
+ }
+ err = r.(os.Error)
+ }
+ }()
+ e.reflectValue(reflect.NewValue(v))
+ return nil
+}
+
+func (e *encodeState) error(err os.Error) {
+ panic(err)
+}
+
+var byteSliceType = reflect.Typeof([]byte(nil))
+
+func (e *encodeState) reflectValue(v reflect.Value) {
+ if v == nil {
+ e.WriteString("null")
+ return
+ }
+
+ if j, ok := v.Interface().(Marshaler); ok {
+ b, err := j.MarshalJSON()
+ if err == nil {
+ // copy JSON into buffer, checking validity.
+ err = Compact(&e.Buffer, b)
+ }
+ if err != nil {
+ e.error(&MarshalerError{v.Type(), err})
+ }
+ return
+ }
+
+ switch v := v.(type) {
+ case *reflect.BoolValue:
+ x := v.Get()
+ if x {
+ e.WriteString("true")
+ } else {
+ e.WriteString("false")
+ }
+
+ case *reflect.IntValue:
+ e.WriteString(strconv.Itoa64(v.Get()))
+
+ case *reflect.UintValue:
+ e.WriteString(strconv.Uitoa64(v.Get()))
+
+ case *reflect.FloatValue:
+ e.WriteString(strconv.FtoaN(v.Get(), 'g', -1, v.Type().Bits()))
+
+ case *reflect.StringValue:
+ e.string(v.Get())
+
+ case *reflect.StructValue:
+ e.WriteByte('{')
+ t := v.Type().(*reflect.StructType)
+ n := v.NumField()
+ first := true
+ for i := 0; i < n; i++ {
+ f := t.Field(i)
+ if f.PkgPath != "" {
+ continue
+ }
+ if first {
+ first = false
+ } else {
+ e.WriteByte(',')
+ }
+ if isValidTag(f.Tag) {
+ e.string(f.Tag)
+ } else {
+ e.string(f.Name)
+ }
+ e.WriteByte(':')
+ e.reflectValue(v.Field(i))
+ }
+ e.WriteByte('}')
+
+ case *reflect.MapValue:
+ if _, ok := v.Type().(*reflect.MapType).Key().(*reflect.StringType); !ok {
+ e.error(&UnsupportedTypeError{v.Type()})
+ }
+ if v.IsNil() {
+ e.WriteString("null")
+ break
+ }
+ e.WriteByte('{')
+ var sv stringValues = v.Keys()
+ sort.Sort(sv)
+ for i, k := range sv {
+ if i > 0 {
+ e.WriteByte(',')
+ }
+ e.string(k.(*reflect.StringValue).Get())
+ e.WriteByte(':')
+ e.reflectValue(v.Elem(k))
+ }
+ e.WriteByte('}')
+
+ case reflect.ArrayOrSliceValue:
+ if v.Type() == byteSliceType {
+ e.WriteByte('"')
+ s := v.Interface().([]byte)
+ if len(s) < 1024 {
+ // for small buffers, using Encode directly is much faster.
+ dst := make([]byte, base64.StdEncoding.EncodedLen(len(s)))
+ base64.StdEncoding.Encode(dst, s)
+ e.Write(dst)
+ } else {
+ // for large buffers, avoid unnecessary extra temporary
+ // buffer space.
+ enc := base64.NewEncoder(base64.StdEncoding, e)
+ enc.Write(s)
+ enc.Close()
+ }
+ e.WriteByte('"')
+ break
+ }
+ e.WriteByte('[')
+ n := v.Len()
+ for i := 0; i < n; i++ {
+ if i > 0 {
+ e.WriteByte(',')
+ }
+ e.reflectValue(v.Elem(i))
+ }
+ e.WriteByte(']')
+
+ case interfaceOrPtrValue:
+ if v.IsNil() {
+ e.WriteString("null")
+ return
+ }
+ e.reflectValue(v.Elem())
+
+ default:
+ e.error(&UnsupportedTypeError{v.Type()})
+ }
+ return
+}
+
+func isValidTag(s string) bool {
+ if s == "" {
+ return false
+ }
+ for _, c := range s {
+ if c != '_' && !unicode.IsLetter(c) && !unicode.IsDigit(c) {
+ return false
+ }
+ }
+ return true
+}
+
+// stringValues is a slice of reflect.Value holding *reflect.StringValue.
+// It implements the methods to sort by string.
+type stringValues []reflect.Value
+
+func (sv stringValues) Len() int { return len(sv) }
+func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
+func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
+func (sv stringValues) get(i int) string { return sv[i].(*reflect.StringValue).Get() }
+
+func (e *encodeState) string(s string) {
+ e.WriteByte('"')
+ start := 0
+ for i := 0; i < len(s); {
+ if b := s[i]; b < utf8.RuneSelf {
+ if 0x20 <= b && b != '\\' && b != '"' {
+ i++
+ continue
+ }
+ if start < i {
+ e.WriteString(s[start:i])
+ }
+ if b == '\\' || b == '"' {
+ e.WriteByte('\\')
+ e.WriteByte(b)
+ } else {
+ e.WriteString(`\u00`)
+ e.WriteByte(hex[b>>4])
+ e.WriteByte(hex[b&0xF])
+ }
+ i++
+ start = i
+ continue
+ }
+ c, size := utf8.DecodeRuneInString(s[i:])
+ if c == utf8.RuneError && size == 1 {
+ e.error(&InvalidUTF8Error{s})
+ }
+ i += size
+ }
+ if start < len(s) {
+ e.WriteString(s[start:])
+ }
+ e.WriteByte('"')
+}
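
The encodeState.string method in this fixture escapes '"', '\\' and control bytes below 0x20 as \u00XX, copies other ASCII through untouched, and rejects invalid UTF-8. Below is a freestanding sketch of that loop, written against the current standard library so it compiles outside the fixture (the unicode/utf8 and strings imports are the assumption here); quoteJSON and the demo input are illustrative only.

    package main

    import (
    	"fmt"
    	"strings"
    	"unicode/utf8"
    )

    const hexDigits = "0123456789abcdef"

    // quoteJSON mirrors encodeState.string: copy runs of safe bytes, escape
    // quote, backslash and control bytes, and report invalid UTF-8.
    func quoteJSON(s string) (string, error) {
    	var b strings.Builder
    	b.WriteByte('"')
    	start := 0
    	for i := 0; i < len(s); {
    		if c := s[i]; c < utf8.RuneSelf {
    			if c >= 0x20 && c != '\\' && c != '"' {
    				i++
    				continue
    			}
    			b.WriteString(s[start:i])
    			if c == '\\' || c == '"' {
    				b.WriteByte('\\')
    				b.WriteByte(c)
    			} else {
    				b.WriteString(`\u00`)
    				b.WriteByte(hexDigits[c>>4])
    				b.WriteByte(hexDigits[c&0xF])
    			}
    			i++
    			start = i
    			continue
    		}
    		r, size := utf8.DecodeRuneInString(s[i:])
    		if r == utf8.RuneError && size == 1 {
    			return "", fmt.Errorf("invalid UTF-8 in %q", s)
    		}
    		i += size
    	}
    	b.WriteString(s[start:])
    	b.WriteByte('"')
    	return b.String(), nil
    }

    func main() {
    	out, _ := quoteJSON("a\"b\\c\x01")
    	fmt.Println(out) // "a\"b\\c\u0001"
    }
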
diff --git a/src/cmd/fix/testdata/reflect.encode.go.out b/src/cmd/fix/testdata/reflect.encode.go.out
new file mode 100644
index 000000000..9a13a75ab
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.encode.go.out
@@ -0,0 +1,367 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The json package implements encoding and decoding of JSON objects as
+// defined in RFC 4627.
+package json
+
+import (
+ "bytes"
+ "encoding/base64"
+ "os"
+ "reflect"
+ "runtime"
+ "sort"
+ "strconv"
+ "unicode"
+ "utf8"
+)
+
+// Marshal returns the JSON encoding of v.
+//
+// Marshal traverses the value v recursively.
+// If an encountered value implements the Marshaler interface,
+// Marshal calls its MarshalJSON method to produce JSON.
+//
+// Otherwise, Marshal uses the following type-dependent default encodings:
+//
+// Boolean values encode as JSON booleans.
+//
+// Floating point and integer values encode as JSON numbers.
+//
+// String values encode as JSON strings, with each invalid UTF-8 sequence
+// replaced by the encoding of the Unicode replacement character U+FFFD.
+//
+// Array and slice values encode as JSON arrays, except that
+// []byte encodes as a base64-encoded string.
+//
+// Struct values encode as JSON objects. Each struct field becomes
+// a member of the object. By default the object's key name is the
+// struct field name. If the struct field has a non-empty tag consisting
+// of only Unicode letters, digits, and underscores, that tag will be used
+// as the name instead. Only exported fields will be encoded.
+//
+// Map values encode as JSON objects.
+// The map's key type must be string; the object keys are used directly
+// as map keys.
+//
+// Pointer values encode as the value pointed to.
+// A nil pointer encodes as the null JSON object.
+//
+// Interface values encode as the value contained in the interface.
+// A nil interface value encodes as the null JSON object.
+//
+// Channel, complex, and function values cannot be encoded in JSON.
+// Attempting to encode such a value causes Marshal to return
+// an InvalidTypeError.
+//
+// JSON cannot represent cyclic data structures and Marshal does not
+// handle them. Passing cyclic structures to Marshal will result in
+// an infinite recursion.
+//
+func Marshal(v interface{}) ([]byte, os.Error) {
+ e := &encodeState{}
+ err := e.marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ return e.Bytes(), nil
+}
+
+// MarshalIndent is like Marshal but applies Indent to format the output.
+func MarshalIndent(v interface{}, prefix, indent string) ([]byte, os.Error) {
+ b, err := Marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ var buf bytes.Buffer
+ err = Indent(&buf, b, prefix, indent)
+ if err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+// MarshalForHTML is like Marshal but applies HTMLEscape to the output.
+func MarshalForHTML(v interface{}) ([]byte, os.Error) {
+ b, err := Marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ var buf bytes.Buffer
+ HTMLEscape(&buf, b)
+ return buf.Bytes(), nil
+}
+
+// HTMLEscape appends to dst the JSON-encoded src with <, >, and &
+// characters inside string literals changed to \u003c, \u003e, \u0026
+// so that the JSON will be safe to embed inside HTML <script> tags.
+// For historical reasons, web browsers don't honor standard HTML
+// escaping within <script> tags, so an alternative JSON encoding must
+// be used.
+func HTMLEscape(dst *bytes.Buffer, src []byte) {
+ // < > & can only appear in string literals,
+ // so just scan the string one byte at a time.
+ start := 0
+ for i, c := range src {
+ if c == '<' || c == '>' || c == '&' {
+ if start < i {
+ dst.Write(src[start:i])
+ }
+ dst.WriteString(`\u00`)
+ dst.WriteByte(hex[c>>4])
+ dst.WriteByte(hex[c&0xF])
+ start = i + 1
+ }
+ }
+ if start < len(src) {
+ dst.Write(src[start:])
+ }
+}
+
+// Marshaler is the interface implemented by objects that
+// can marshal themselves into valid JSON.
+type Marshaler interface {
+ MarshalJSON() ([]byte, os.Error)
+}
+
+type UnsupportedTypeError struct {
+ Type reflect.Type
+}
+
+func (e *UnsupportedTypeError) String() string {
+ return "json: unsupported type: " + e.Type.String()
+}
+
+type InvalidUTF8Error struct {
+ S string
+}
+
+func (e *InvalidUTF8Error) String() string {
+ return "json: invalid UTF-8 in string: " + strconv.Quote(e.S)
+}
+
+type MarshalerError struct {
+ Type reflect.Type
+ Error os.Error
+}
+
+func (e *MarshalerError) String() string {
+ return "json: error calling MarshalJSON for type " + e.Type.String() + ": " + e.Error.String()
+}
+
+type interfaceOrPtrValue interface {
+ IsNil() bool
+ Elem() reflect.Value
+}
+
+var hex = "0123456789abcdef"
+
+// An encodeState encodes JSON into a bytes.Buffer.
+type encodeState struct {
+ bytes.Buffer // accumulated output
+}
+
+func (e *encodeState) marshal(v interface{}) (err os.Error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(runtime.Error); ok {
+ panic(r)
+ }
+ err = r.(os.Error)
+ }
+ }()
+ e.reflectValue(reflect.ValueOf(v))
+ return nil
+}
+
+func (e *encodeState) error(err os.Error) {
+ panic(err)
+}
+
+var byteSliceType = reflect.TypeOf([]byte(nil))
+
+func (e *encodeState) reflectValue(v reflect.Value) {
+ if !v.IsValid() {
+ e.WriteString("null")
+ return
+ }
+
+ if j, ok := v.Interface().(Marshaler); ok {
+ b, err := j.MarshalJSON()
+ if err == nil {
+ // copy JSON into buffer, checking validity.
+ err = Compact(&e.Buffer, b)
+ }
+ if err != nil {
+ e.error(&MarshalerError{v.Type(), err})
+ }
+ return
+ }
+
+ switch v.Kind() {
+ case reflect.Bool:
+ x := v.Bool()
+ if x {
+ e.WriteString("true")
+ } else {
+ e.WriteString("false")
+ }
+
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ e.WriteString(strconv.Itoa64(v.Int()))
+
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ e.WriteString(strconv.Uitoa64(v.Uint()))
+
+ case reflect.Float32, reflect.Float64:
+ e.WriteString(strconv.FtoaN(v.Float(), 'g', -1, v.Type().Bits()))
+
+ case reflect.String:
+ e.string(v.String())
+
+ case reflect.Struct:
+ e.WriteByte('{')
+ t := v.Type()
+ n := v.NumField()
+ first := true
+ for i := 0; i < n; i++ {
+ f := t.Field(i)
+ if f.PkgPath != "" {
+ continue
+ }
+ if first {
+ first = false
+ } else {
+ e.WriteByte(',')
+ }
+ if isValidTag(f.Tag) {
+ e.string(f.Tag)
+ } else {
+ e.string(f.Name)
+ }
+ e.WriteByte(':')
+ e.reflectValue(v.Field(i))
+ }
+ e.WriteByte('}')
+
+ case reflect.Map:
+ if v.Type().Key().Kind() != reflect.String {
+ e.error(&UnsupportedTypeError{v.Type()})
+ }
+ if v.IsNil() {
+ e.WriteString("null")
+ break
+ }
+ e.WriteByte('{')
+ var sv stringValues = v.MapKeys()
+ sort.Sort(sv)
+ for i, k := range sv {
+ if i > 0 {
+ e.WriteByte(',')
+ }
+ e.string(k.String())
+ e.WriteByte(':')
+ e.reflectValue(v.MapIndex(k))
+ }
+ e.WriteByte('}')
+
+ case reflect.Array, reflect.Slice:
+ if v.Type() == byteSliceType {
+ e.WriteByte('"')
+ s := v.Interface().([]byte)
+ if len(s) < 1024 {
+ // for small buffers, using Encode directly is much faster.
+ dst := make([]byte, base64.StdEncoding.EncodedLen(len(s)))
+ base64.StdEncoding.Encode(dst, s)
+ e.Write(dst)
+ } else {
+ // for large buffers, avoid unnecessary extra temporary
+ // buffer space.
+ enc := base64.NewEncoder(base64.StdEncoding, e)
+ enc.Write(s)
+ enc.Close()
+ }
+ e.WriteByte('"')
+ break
+ }
+ e.WriteByte('[')
+ n := v.Len()
+ for i := 0; i < n; i++ {
+ if i > 0 {
+ e.WriteByte(',')
+ }
+ e.reflectValue(v.Index(i))
+ }
+ e.WriteByte(']')
+
+ case interfaceOrPtrValue:
+ if v.IsNil() {
+ e.WriteString("null")
+ return
+ }
+ e.reflectValue(v.Elem())
+
+ default:
+ e.error(&UnsupportedTypeError{v.Type()})
+ }
+ return
+}
+
+func isValidTag(s string) bool {
+ if s == "" {
+ return false
+ }
+ for _, c := range s {
+ if c != '_' && !unicode.IsLetter(c) && !unicode.IsDigit(c) {
+ return false
+ }
+ }
+ return true
+}
+
+// stringValues is a slice of reflect.Value holding *reflect.StringValue.
+// It implements the methods to sort by string.
+type stringValues []reflect.Value
+
+func (sv stringValues) Len() int { return len(sv) }
+func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
+func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
+func (sv stringValues) get(i int) string { return sv[i].String() }
+
+func (e *encodeState) string(s string) {
+ e.WriteByte('"')
+ start := 0
+ for i := 0; i < len(s); {
+ if b := s[i]; b < utf8.RuneSelf {
+ if 0x20 <= b && b != '\\' && b != '"' {
+ i++
+ continue
+ }
+ if start < i {
+ e.WriteString(s[start:i])
+ }
+ if b == '\\' || b == '"' {
+ e.WriteByte('\\')
+ e.WriteByte(b)
+ } else {
+ e.WriteString(`\u00`)
+ e.WriteByte(hex[b>>4])
+ e.WriteByte(hex[b&0xF])
+ }
+ i++
+ start = i
+ continue
+ }
+ c, size := utf8.DecodeRuneInString(s[i:])
+ if c == utf8.RuneError && size == 1 {
+ e.error(&InvalidUTF8Error{s})
+ }
+ i += size
+ }
+ if start < len(s) {
+ e.WriteString(s[start:])
+ }
+ e.WriteByte('"')
+}
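
The Array/Slice case above special-cases []byte: slices shorter than 1024 bytes are encoded into a buffer pre-sized with EncodedLen, while larger ones are streamed through base64.NewEncoder to avoid the extra temporary. A standalone sketch of the two paths; writeBase64 and the demo input are illustrative only.

    package main

    import (
    	"bytes"
    	"encoding/base64"
    	"fmt"
    )

    // writeBase64 mirrors the []byte branch of reflectValue: pre-sized Encode
    // for short inputs, a streaming encoder for long ones.
    func writeBase64(buf *bytes.Buffer, s []byte) {
    	buf.WriteByte('"')
    	if len(s) < 1024 {
    		dst := make([]byte, base64.StdEncoding.EncodedLen(len(s)))
    		base64.StdEncoding.Encode(dst, s)
    		buf.Write(dst)
    	} else {
    		enc := base64.NewEncoder(base64.StdEncoding, buf)
    		enc.Write(s)
    		enc.Close() // flush any partial quantum
    	}
    	buf.WriteByte('"')
    }

    func main() {
    	var buf bytes.Buffer
    	writeBase64(&buf, []byte("hello gopher"))
    	fmt.Println(buf.String()) // "aGVsbG8gZ29waGVy"
    }
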
diff --git a/src/cmd/fix/testdata/reflect.encoder.go.in b/src/cmd/fix/testdata/reflect.encoder.go.in
new file mode 100644
index 000000000..0202d79ac
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.encoder.go.in
@@ -0,0 +1,240 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "sync"
+)
+
+// An Encoder manages the transmission of type and data information to the
+// other side of a connection.
+type Encoder struct {
+ mutex sync.Mutex // each item must be sent atomically
+ w []io.Writer // where to send the data
+ sent map[reflect.Type]typeId // which types we've already sent
+ countState *encoderState // stage for writing counts
+ freeList *encoderState // list of free encoderStates; avoids reallocation
+ buf []byte // for collecting the output.
+ byteBuf bytes.Buffer // buffer for top-level encoderState
+ err os.Error
+}
+
+// NewEncoder returns a new encoder that will transmit on the io.Writer.
+func NewEncoder(w io.Writer) *Encoder {
+ enc := new(Encoder)
+ enc.w = []io.Writer{w}
+ enc.sent = make(map[reflect.Type]typeId)
+ enc.countState = enc.newEncoderState(new(bytes.Buffer))
+ return enc
+}
+
+// writer() returns the innermost writer the encoder is using
+func (enc *Encoder) writer() io.Writer {
+ return enc.w[len(enc.w)-1]
+}
+
+// pushWriter adds a writer to the encoder.
+func (enc *Encoder) pushWriter(w io.Writer) {
+ enc.w = append(enc.w, w)
+}
+
+// popWriter pops the innermost writer.
+func (enc *Encoder) popWriter() {
+ enc.w = enc.w[0 : len(enc.w)-1]
+}
+
+func (enc *Encoder) badType(rt reflect.Type) {
+ enc.setError(os.NewError("gob: can't encode type " + rt.String()))
+}
+
+func (enc *Encoder) setError(err os.Error) {
+ if enc.err == nil { // remember the first.
+ enc.err = err
+ }
+}
+
+// writeMessage sends the data item preceded by an unsigned count of its length.
+func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
+ enc.countState.encodeUint(uint64(b.Len()))
+ // Build the buffer.
+ countLen := enc.countState.b.Len()
+ total := countLen + b.Len()
+ if total > len(enc.buf) {
+ enc.buf = make([]byte, total+1000) // extra for growth
+ }
+ // Place the length before the data.
+ // TODO(r): avoid the extra copy here.
+ enc.countState.b.Read(enc.buf[0:countLen])
+ // Now the data.
+ b.Read(enc.buf[countLen:total])
+ // Write the data.
+ _, err := w.Write(enc.buf[0:total])
+ if err != nil {
+ enc.setError(err)
+ }
+}
+
+// sendActualType sends the requested type, without further investigation, unless
+// it's been sent before.
+func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
+ if _, alreadySent := enc.sent[actual]; alreadySent {
+ return false
+ }
+ typeLock.Lock()
+ info, err := getTypeInfo(ut)
+ typeLock.Unlock()
+ if err != nil {
+ enc.setError(err)
+ return
+ }
+ // Send the pair (-id, type)
+ // Id:
+ state.encodeInt(-int64(info.id))
+ // Type:
+ enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
+ enc.writeMessage(w, state.b)
+ if enc.err != nil {
+ return
+ }
+
+ // Remember we've sent this type, both what the user gave us and the base type.
+ enc.sent[ut.base] = info.id
+ if ut.user != ut.base {
+ enc.sent[ut.user] = info.id
+ }
+ // Now send the inner types
+ switch st := actual.(type) {
+ case *reflect.StructType:
+ for i := 0; i < st.NumField(); i++ {
+ enc.sendType(w, state, st.Field(i).Type)
+ }
+ case reflect.ArrayOrSliceType:
+ enc.sendType(w, state, st.Elem())
+ }
+ return true
+}
+
+// sendType sends the type info to the other side, if necessary.
+func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
+ ut := userType(origt)
+ if ut.isGobEncoder {
+ // The rules are different: regardless of the underlying type's representation,
+ // we need to tell the other side that this exact type is a GobEncoder.
+ return enc.sendActualType(w, state, ut, ut.user)
+ }
+
+ // It's a concrete value, so drill down to the base type.
+ switch rt := ut.base.(type) {
+ default:
+ // Basic types and interfaces do not need to be described.
+ return
+ case *reflect.SliceType:
+ // If it's []uint8, don't send; it's considered basic.
+ if rt.Elem().Kind() == reflect.Uint8 {
+ return
+ }
+ // Otherwise we do send.
+ break
+ case *reflect.ArrayType:
+ // arrays must be sent so we know their lengths and element types.
+ break
+ case *reflect.MapType:
+ // maps must be sent so we know their lengths and key/value types.
+ break
+ case *reflect.StructType:
+ // structs must be sent so we know their fields.
+ break
+ case *reflect.ChanType, *reflect.FuncType:
+ // Probably a bad field in a struct.
+ enc.badType(rt)
+ return
+ }
+
+ return enc.sendActualType(w, state, ut, ut.base)
+}
+
+// Encode transmits the data item represented by the empty interface value,
+// guaranteeing that all necessary type information has been transmitted first.
+func (enc *Encoder) Encode(e interface{}) os.Error {
+ return enc.EncodeValue(reflect.NewValue(e))
+}
+
+// sendTypeDescriptor makes sure the remote side knows about this type.
+// It will send a descriptor if this is the first time the type has been
+// sent.
+func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
+ // Make sure the type is known to the other side.
+ // First, have we already sent this type?
+ rt := ut.base
+ if ut.isGobEncoder {
+ rt = ut.user
+ }
+ if _, alreadySent := enc.sent[rt]; !alreadySent {
+ // No, so send it.
+ sent := enc.sendType(w, state, rt)
+ if enc.err != nil {
+ return
+ }
+ // If the type info has still not been transmitted, it means we have
+ // a singleton basic type (int, []byte etc.) at top level. We don't
+ // need to send the type info but we do need to update enc.sent.
+ if !sent {
+ typeLock.Lock()
+ info, err := getTypeInfo(ut)
+ typeLock.Unlock()
+ if err != nil {
+ enc.setError(err)
+ return
+ }
+ enc.sent[rt] = info.id
+ }
+ }
+}
+
+// sendTypeId sends the id, which must have already been defined.
+func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) {
+ // Identify the type of this top-level value.
+ state.encodeInt(int64(enc.sent[ut.base]))
+}
+
+// EncodeValue transmits the data item represented by the reflection value,
+// guaranteeing that all necessary type information has been transmitted first.
+func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
+ // Make sure we're single-threaded through here, so multiple
+ // goroutines can share an encoder.
+ enc.mutex.Lock()
+ defer enc.mutex.Unlock()
+
+ // Remove any nested writers remaining due to previous errors.
+ enc.w = enc.w[0:1]
+
+ ut, err := validUserType(value.Type())
+ if err != nil {
+ return err
+ }
+
+ enc.err = nil
+ enc.byteBuf.Reset()
+ state := enc.newEncoderState(&enc.byteBuf)
+
+ enc.sendTypeDescriptor(enc.writer(), state, ut)
+ enc.sendTypeId(state, ut)
+ if enc.err != nil {
+ return enc.err
+ }
+
+ // Encode the object.
+ enc.encode(state.b, value, ut)
+ if enc.err == nil {
+ enc.writeMessage(enc.writer(), state.b)
+ }
+
+ enc.freeEncoderState(state)
+ return enc.err
+}
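
sendActualType and sendType above drive type switches over the pre-Go 1 reflect API (*reflect.StructType, reflect.ArrayOrSliceType, *reflect.MapType, and so on); the matching .out file below shows the same switches rewritten in terms of Kind. A small sketch of the post-fix idiom against the current reflect package; describe and the demo types are illustrative only.

    package main

    import (
    	"fmt"
    	"reflect"
    )

    // describe recurses the way sendType does after the fix: switch on Kind
    // instead of on concrete *reflect.XxxType values.
    func describe(t reflect.Type, indent string) {
    	switch t.Kind() {
    	case reflect.Struct: // was: case *reflect.StructType
    		fmt.Println(indent + t.String() + " struct {")
    		for i := 0; i < t.NumField(); i++ {
    			describe(t.Field(i).Type, indent+"  ")
    		}
    		fmt.Println(indent + "}")
    	case reflect.Array, reflect.Slice: // was: case reflect.ArrayOrSliceType
    		fmt.Println(indent + t.String() + " of")
    		describe(t.Elem(), indent+"  ")
    	default:
    		fmt.Println(indent + t.String())
    	}
    }

    type point struct{ X, Y int }

    func main() {
    	describe(reflect.TypeOf([]point(nil)), "")
    }
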
diff --git a/src/cmd/fix/testdata/reflect.encoder.go.out b/src/cmd/fix/testdata/reflect.encoder.go.out
new file mode 100644
index 000000000..925d39301
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.encoder.go.out
@@ -0,0 +1,240 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "sync"
+)
+
+// An Encoder manages the transmission of type and data information to the
+// other side of a connection.
+type Encoder struct {
+ mutex sync.Mutex // each item must be sent atomically
+ w []io.Writer // where to send the data
+ sent map[reflect.Type]typeId // which types we've already sent
+ countState *encoderState // stage for writing counts
+ freeList *encoderState // list of free encoderStates; avoids reallocation
+ buf []byte // for collecting the output.
+ byteBuf bytes.Buffer // buffer for top-level encoderState
+ err os.Error
+}
+
+// NewEncoder returns a new encoder that will transmit on the io.Writer.
+func NewEncoder(w io.Writer) *Encoder {
+ enc := new(Encoder)
+ enc.w = []io.Writer{w}
+ enc.sent = make(map[reflect.Type]typeId)
+ enc.countState = enc.newEncoderState(new(bytes.Buffer))
+ return enc
+}
+
+// writer() returns the innermost writer the encoder is using
+func (enc *Encoder) writer() io.Writer {
+ return enc.w[len(enc.w)-1]
+}
+
+// pushWriter adds a writer to the encoder.
+func (enc *Encoder) pushWriter(w io.Writer) {
+ enc.w = append(enc.w, w)
+}
+
+// popWriter pops the innermost writer.
+func (enc *Encoder) popWriter() {
+ enc.w = enc.w[0 : len(enc.w)-1]
+}
+
+func (enc *Encoder) badType(rt reflect.Type) {
+ enc.setError(os.NewError("gob: can't encode type " + rt.String()))
+}
+
+func (enc *Encoder) setError(err os.Error) {
+ if enc.err == nil { // remember the first.
+ enc.err = err
+ }
+}
+
+// writeMessage sends the data item preceded by an unsigned count of its length.
+func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
+ enc.countState.encodeUint(uint64(b.Len()))
+ // Build the buffer.
+ countLen := enc.countState.b.Len()
+ total := countLen + b.Len()
+ if total > len(enc.buf) {
+ enc.buf = make([]byte, total+1000) // extra for growth
+ }
+ // Place the length before the data.
+ // TODO(r): avoid the extra copy here.
+ enc.countState.b.Read(enc.buf[0:countLen])
+ // Now the data.
+ b.Read(enc.buf[countLen:total])
+ // Write the data.
+ _, err := w.Write(enc.buf[0:total])
+ if err != nil {
+ enc.setError(err)
+ }
+}
+
+// sendActualType sends the requested type, without further investigation, unless
+// it's been sent before.
+func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
+ if _, alreadySent := enc.sent[actual]; alreadySent {
+ return false
+ }
+ typeLock.Lock()
+ info, err := getTypeInfo(ut)
+ typeLock.Unlock()
+ if err != nil {
+ enc.setError(err)
+ return
+ }
+ // Send the pair (-id, type)
+ // Id:
+ state.encodeInt(-int64(info.id))
+ // Type:
+ enc.encode(state.b, reflect.ValueOf(info.wire), wireTypeUserInfo)
+ enc.writeMessage(w, state.b)
+ if enc.err != nil {
+ return
+ }
+
+ // Remember we've sent this type, both what the user gave us and the base type.
+ enc.sent[ut.base] = info.id
+ if ut.user != ut.base {
+ enc.sent[ut.user] = info.id
+ }
+ // Now send the inner types
+ switch st := actual; st.Kind() {
+ case reflect.Struct:
+ for i := 0; i < st.NumField(); i++ {
+ enc.sendType(w, state, st.Field(i).Type)
+ }
+ case reflect.Array, reflect.Slice:
+ enc.sendType(w, state, st.Elem())
+ }
+ return true
+}
+
+// sendType sends the type info to the other side, if necessary.
+func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
+ ut := userType(origt)
+ if ut.isGobEncoder {
+ // The rules are different: regardless of the underlying type's representation,
+ // we need to tell the other side that this exact type is a GobEncoder.
+ return enc.sendActualType(w, state, ut, ut.user)
+ }
+
+ // It's a concrete value, so drill down to the base type.
+ switch rt := ut.base; rt.Kind() {
+ default:
+ // Basic types and interfaces do not need to be described.
+ return
+ case reflect.Slice:
+ // If it's []uint8, don't send; it's considered basic.
+ if rt.Elem().Kind() == reflect.Uint8 {
+ return
+ }
+ // Otherwise we do send.
+ break
+ case reflect.Array:
+ // arrays must be sent so we know their lengths and element types.
+ break
+ case reflect.Map:
+ // maps must be sent so we know their lengths and key/value types.
+ break
+ case reflect.Struct:
+ // structs must be sent so we know their fields.
+ break
+ case reflect.Chan, reflect.Func:
+ // Probably a bad field in a struct.
+ enc.badType(rt)
+ return
+ }
+
+ return enc.sendActualType(w, state, ut, ut.base)
+}
+
+// Encode transmits the data item represented by the empty interface value,
+// guaranteeing that all necessary type information has been transmitted first.
+func (enc *Encoder) Encode(e interface{}) os.Error {
+ return enc.EncodeValue(reflect.ValueOf(e))
+}
+
+// sendTypeDescriptor makes sure the remote side knows about this type.
+// It will send a descriptor if this is the first time the type has been
+// sent.
+func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
+ // Make sure the type is known to the other side.
+ // First, have we already sent this type?
+ rt := ut.base
+ if ut.isGobEncoder {
+ rt = ut.user
+ }
+ if _, alreadySent := enc.sent[rt]; !alreadySent {
+ // No, so send it.
+ sent := enc.sendType(w, state, rt)
+ if enc.err != nil {
+ return
+ }
+ // If the type info has still not been transmitted, it means we have
+ // a singleton basic type (int, []byte etc.) at top level. We don't
+ // need to send the type info but we do need to update enc.sent.
+ if !sent {
+ typeLock.Lock()
+ info, err := getTypeInfo(ut)
+ typeLock.Unlock()
+ if err != nil {
+ enc.setError(err)
+ return
+ }
+ enc.sent[rt] = info.id
+ }
+ }
+}
+
+// sendTypeId sends the id, which must have already been defined.
+func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) {
+ // Identify the type of this top-level value.
+ state.encodeInt(int64(enc.sent[ut.base]))
+}
+
+// EncodeValue transmits the data item represented by the reflection value,
+// guaranteeing that all necessary type information has been transmitted first.
+func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
+ // Make sure we're single-threaded through here, so multiple
+ // goroutines can share an encoder.
+ enc.mutex.Lock()
+ defer enc.mutex.Unlock()
+
+ // Remove any nested writers remaining due to previous errors.
+ enc.w = enc.w[0:1]
+
+ ut, err := validUserType(value.Type())
+ if err != nil {
+ return err
+ }
+
+ enc.err = nil
+ enc.byteBuf.Reset()
+ state := enc.newEncoderState(&enc.byteBuf)
+
+ enc.sendTypeDescriptor(enc.writer(), state, ut)
+ enc.sendTypeId(state, ut)
+ if enc.err != nil {
+ return enc.err
+ }
+
+ // Encode the object.
+ enc.encode(state.b, value, ut)
+ if enc.err == nil {
+ enc.writeMessage(enc.writer(), state.b)
+ }
+
+ enc.freeEncoderState(state)
+ return enc.err
+}
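
writeMessage above frames every item as a count followed by the payload bytes, assembled into one buffer so both reach the io.Writer in a single call. Below is a simplified stand-in that uses a fixed 4-byte big-endian length instead of gob's variable-length uint encoding for the count; that substitution, and the helper names, are assumptions for illustration.

    package main

    import (
    	"bytes"
    	"encoding/binary"
    	"fmt"
    	"io"
    )

    // writeFramed prefixes the payload with its length and issues one Write,
    // the same shape as Encoder.writeMessage (which uses gob's uint encoding
    // for the count instead of a fixed-width field).
    func writeFramed(w io.Writer, payload []byte) error {
    	buf := make([]byte, 4+len(payload))
    	binary.BigEndian.PutUint32(buf[:4], uint32(len(payload)))
    	copy(buf[4:], payload)
    	_, err := w.Write(buf)
    	return err
    }

    // readFramed undoes the framing: read the count, then exactly that many bytes.
    func readFramed(r io.Reader) ([]byte, error) {
    	var n uint32
    	if err := binary.Read(r, binary.BigEndian, &n); err != nil {
    		return nil, err
    	}
    	payload := make([]byte, n)
    	_, err := io.ReadFull(r, payload)
    	return payload, err
    }

    func main() {
    	var conn bytes.Buffer
    	writeFramed(&conn, []byte("type descriptor"))
    	writeFramed(&conn, []byte("value"))
    	for {
    		msg, err := readFramed(&conn)
    		if err != nil {
    			break
    		}
    		fmt.Printf("%q\n", msg)
    	}
    }
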
diff --git a/src/cmd/fix/testdata/reflect.export.go.in b/src/cmd/fix/testdata/reflect.export.go.in
new file mode 100644
index 000000000..ce7940b29
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.export.go.in
@@ -0,0 +1,400 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ The netchan package implements type-safe networked channels:
+ it allows the two ends of a channel to appear on different
+ computers connected by a network. It does this by transporting
+ data sent to a channel on one machine so it can be recovered
+ by a receive of a channel of the same type on the other.
+
+ An exporter publishes a set of channels by name. An importer
+ connects to the exporting machine and imports the channels
+ by name. After importing the channels, the two machines can
+ use the channels in the usual way.
+
+ Networked channels are not synchronized; they always behave
+ as if they are buffered channels of at least one element.
+*/
+package netchan
+
+// BUG: can't use range clause to receive when using ImportNValues to limit the count.
+
+import (
+ "io"
+ "log"
+ "net"
+ "os"
+ "reflect"
+ "strconv"
+ "sync"
+)
+
+// Export
+
+// expLog is a logging convenience function. The first argument must be a string.
+func expLog(args ...interface{}) {
+ args[0] = "netchan export: " + args[0].(string)
+ log.Print(args...)
+}
+
+// An Exporter allows a set of channels to be published on a single
+// network port. A single machine may have multiple Exporters
+// but they must use different ports.
+type Exporter struct {
+ *clientSet
+}
+
+type expClient struct {
+ *encDec
+ exp *Exporter
+ chans map[int]*netChan // channels in use by client
+ mu sync.Mutex // protects remaining fields
+ errored bool // client has been sent an error
+ seqNum int64 // sequences messages sent to client; has value of highest sent
+ ackNum int64 // highest sequence number acknowledged
+ seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
+}
+
+func newClient(exp *Exporter, conn io.ReadWriter) *expClient {
+ client := new(expClient)
+ client.exp = exp
+ client.encDec = newEncDec(conn)
+ client.seqNum = 0
+ client.ackNum = 0
+ client.chans = make(map[int]*netChan)
+ return client
+}
+
+func (client *expClient) sendError(hdr *header, err string) {
+ error := &error{err}
+ expLog("sending error to client:", error.Error)
+ client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
+ client.mu.Lock()
+ client.errored = true
+ client.mu.Unlock()
+}
+
+func (client *expClient) newChan(hdr *header, dir Dir, name string, size int, count int64) *netChan {
+ exp := client.exp
+ exp.mu.Lock()
+ ech, ok := exp.names[name]
+ exp.mu.Unlock()
+ if !ok {
+ client.sendError(hdr, "no such channel: "+name)
+ return nil
+ }
+ if ech.dir != dir {
+ client.sendError(hdr, "wrong direction for channel: "+name)
+ return nil
+ }
+ nch := newNetChan(name, hdr.Id, ech, client.encDec, size, count)
+ client.chans[hdr.Id] = nch
+ return nch
+}
+
+func (client *expClient) getChan(hdr *header, dir Dir) *netChan {
+ nch := client.chans[hdr.Id]
+ if nch == nil {
+ return nil
+ }
+ if nch.dir != dir {
+ client.sendError(hdr, "wrong direction for channel: "+nch.name)
+ }
+ return nch
+}
+
+// The function run manages sends and receives for a single client. For each
+// (client Recv) request, this will launch a serveRecv goroutine to deliver
+// the data for that channel, while (client Send) requests are handled as
+// data arrives from the client.
+func (client *expClient) run() {
+ hdr := new(header)
+ hdrValue := reflect.NewValue(hdr)
+ req := new(request)
+ reqValue := reflect.NewValue(req)
+ error := new(error)
+ for {
+ *hdr = header{}
+ if err := client.decode(hdrValue); err != nil {
+ if err != os.EOF {
+ expLog("error decoding client header:", err)
+ }
+ break
+ }
+ switch hdr.PayloadType {
+ case payRequest:
+ *req = request{}
+ if err := client.decode(reqValue); err != nil {
+ expLog("error decoding client request:", err)
+ break
+ }
+ if req.Size < 1 {
+ panic("netchan: remote requested " + strconv.Itoa(req.Size) + " values")
+ }
+ switch req.Dir {
+ case Recv:
+ // look up channel before calling serveRecv to
+ // avoid a lock around client.chans.
+ if nch := client.newChan(hdr, Send, req.Name, req.Size, req.Count); nch != nil {
+ go client.serveRecv(nch, *hdr, req.Count)
+ }
+ case Send:
+ client.newChan(hdr, Recv, req.Name, req.Size, req.Count)
+ // The actual sends will have payload type payData.
+ // TODO: manage the count?
+ default:
+ error.Error = "request: can't handle channel direction"
+ expLog(error.Error, req.Dir)
+ client.encode(hdr, payError, error)
+ }
+ case payData:
+ client.serveSend(*hdr)
+ case payClosed:
+ client.serveClosed(*hdr)
+ case payAck:
+ client.mu.Lock()
+ if client.ackNum != hdr.SeqNum-1 {
+ // Since the sequence number is incremented and the message is sent
+ // in a single instance of locking client.mu, the messages are guaranteed
+ // to be sent in order. Therefore receipt of acknowledgement N means
+ // all messages <=N have been seen by the recipient. We check anyway.
+ expLog("sequence out of order:", client.ackNum, hdr.SeqNum)
+ }
+ if client.ackNum < hdr.SeqNum { // If there has been an error, don't back up the count.
+ client.ackNum = hdr.SeqNum
+ }
+ client.mu.Unlock()
+ case payAckSend:
+ if nch := client.getChan(hdr, Send); nch != nil {
+ nch.acked()
+ }
+ default:
+ log.Fatal("netchan export: unknown payload type", hdr.PayloadType)
+ }
+ }
+ client.exp.delClient(client)
+}
+
+// Send all the data on a single channel to a client asking for a Recv.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveRecv(nch *netChan, hdr header, count int64) {
+ for {
+ val, ok := nch.recv()
+ if !ok {
+ if err := client.encode(&hdr, payClosed, nil); err != nil {
+ expLog("error encoding server closed message:", err)
+ }
+ break
+ }
+ // We hold the lock during transmission to guarantee messages are
+ // sent in sequence number order. Also, we increment first so the
+ // value of client.SeqNum is the value of the highest used sequence
+ // number, not one beyond.
+ client.mu.Lock()
+ client.seqNum++
+ hdr.SeqNum = client.seqNum
+ client.seqLock.Lock() // guarantee ordering of messages
+ client.mu.Unlock()
+ err := client.encode(&hdr, payData, val.Interface())
+ client.seqLock.Unlock()
+ if err != nil {
+ expLog("error encoding client response:", err)
+ client.sendError(&hdr, err.String())
+ break
+ }
+ // Negative count means run forever.
+ if count >= 0 {
+ if count--; count <= 0 {
+ break
+ }
+ }
+ }
+}
+
+// Receive and deliver locally one item from a client asking for a Send
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveSend(hdr header) {
+ nch := client.getChan(&hdr, Recv)
+ if nch == nil {
+ return
+ }
+ // Create a new value for each received item.
+ val := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem())
+ if err := client.decode(val); err != nil {
+ expLog("value decode:", err, "; type ", nch.ch.Type())
+ return
+ }
+ nch.send(val)
+}
+
+// Report that client has closed the channel that is sending to us.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveClosed(hdr header) {
+ nch := client.getChan(&hdr, Recv)
+ if nch == nil {
+ return
+ }
+ nch.close()
+}
+
+func (client *expClient) unackedCount() int64 {
+ client.mu.Lock()
+ n := client.seqNum - client.ackNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) seq() int64 {
+ client.mu.Lock()
+ n := client.seqNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) ack() int64 {
+ client.mu.Lock()
+	n := client.ackNum
+ client.mu.Unlock()
+ return n
+}
+
+// Serve waits for incoming connections on the listener
+// and serves the Exporter's channels on each.
+// It blocks until the listener is closed.
+func (exp *Exporter) Serve(listener net.Listener) {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ expLog("listen:", err)
+ break
+ }
+ go exp.ServeConn(conn)
+ }
+}
+
+// ServeConn exports the Exporter's channels on conn.
+// It blocks until the connection is terminated.
+func (exp *Exporter) ServeConn(conn io.ReadWriter) {
+ exp.addClient(conn).run()
+}
+
+// NewExporter creates a new Exporter that exports a set of channels.
+func NewExporter() *Exporter {
+ e := &Exporter{
+ clientSet: &clientSet{
+ names: make(map[string]*chanDir),
+ clients: make(map[unackedCounter]bool),
+ },
+ }
+ return e
+}
+
+// ListenAndServe exports the exporter's channels through the
+// given network and local address defined as in net.Listen.
+func (exp *Exporter) ListenAndServe(network, localaddr string) os.Error {
+ listener, err := net.Listen(network, localaddr)
+ if err != nil {
+ return err
+ }
+ go exp.Serve(listener)
+ return nil
+}
+
+// addClient creates a new expClient and records its existence
+func (exp *Exporter) addClient(conn io.ReadWriter) *expClient {
+ client := newClient(exp, conn)
+ exp.mu.Lock()
+ exp.clients[client] = true
+ exp.mu.Unlock()
+ return client
+}
+
+// delClient forgets the client existed
+func (exp *Exporter) delClient(client *expClient) {
+ exp.mu.Lock()
+ exp.clients[client] = false, false
+ exp.mu.Unlock()
+}
+
+// Drain waits until all messages sent from this exporter/importer, including
+// those not yet sent to any client and possibly including those sent while
+// Drain was executing, have been received by the importer. In short, it
+// waits until all the exporter's messages have been received by a client.
+// If the timeout (measured in nanoseconds) is positive and Drain takes
+// longer than that to complete, an error is returned.
+func (exp *Exporter) Drain(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.drain(timeout)
+}
+
+// Sync waits until all clients of the exporter have received the messages
+// that were sent at the time Sync was invoked. Unlike Drain, it does not
+// wait for messages sent while it is running or messages that have not been
+// dispatched to any client. If the timeout (measured in nanoseconds) is
+// positive and Sync takes longer than that to complete, an error is
+// returned.
+func (exp *Exporter) Sync(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.sync(timeout)
+}
+
+func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) {
+ chanType, ok := reflect.Typeof(chT).(*reflect.ChanType)
+ if !ok {
+ return nil, os.NewError("not a channel")
+ }
+ if dir != Send && dir != Recv {
+ return nil, os.NewError("unknown channel direction")
+ }
+ switch chanType.Dir() {
+ case reflect.BothDir:
+ case reflect.SendDir:
+ if dir != Recv {
+ return nil, os.NewError("to import/export with Send, must provide <-chan")
+ }
+ case reflect.RecvDir:
+ if dir != Send {
+ return nil, os.NewError("to import/export with Recv, must provide chan<-")
+ }
+ }
+ return reflect.NewValue(chT).(*reflect.ChanValue), nil
+}
+
+// Export exports a channel of a given type and specified direction. The
+// channel to be exported is provided in the call and may be of arbitrary
+// channel type.
+// Despite the literal signature, the effective signature is
+// Export(name string, chT chan T, dir Dir)
+func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
+ ch, err := checkChan(chT, dir)
+ if err != nil {
+ return err
+ }
+ exp.mu.Lock()
+ defer exp.mu.Unlock()
+ _, present := exp.names[name]
+ if present {
+ return os.NewError("channel name already being exported:" + name)
+ }
+ exp.names[name] = &chanDir{ch, dir}
+ return nil
+}
+
+// Hangup disassociates the named channel from the Exporter and closes
+// the channel. Messages in flight for the channel may be dropped.
+func (exp *Exporter) Hangup(name string) os.Error {
+ exp.mu.Lock()
+ chDir, ok := exp.names[name]
+ if ok {
+ exp.names[name] = nil, false
+ }
+ // TODO drop all instances of channel from client sets
+ exp.mu.Unlock()
+ if !ok {
+ return os.NewError("netchan export: hangup: no such channel: " + name)
+ }
+ chDir.ch.Close()
+ return nil
+}
diff --git a/src/cmd/fix/testdata/reflect.export.go.out b/src/cmd/fix/testdata/reflect.export.go.out
new file mode 100644
index 000000000..7bd73c5e7
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.export.go.out
@@ -0,0 +1,400 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ The netchan package implements type-safe networked channels:
+ it allows the two ends of a channel to appear on different
+ computers connected by a network. It does this by transporting
+ data sent to a channel on one machine so it can be recovered
+ by a receive of a channel of the same type on the other.
+
+ An exporter publishes a set of channels by name. An importer
+ connects to the exporting machine and imports the channels
+ by name. After importing the channels, the two machines can
+ use the channels in the usual way.
+
+ Networked channels are not synchronized; they always behave
+ as if they are buffered channels of at least one element.
+*/
+package netchan
+
+// BUG: can't use range clause to receive when using ImportNValues to limit the count.
+
+import (
+ "io"
+ "log"
+ "net"
+ "os"
+ "reflect"
+ "strconv"
+ "sync"
+)
+
+// Export
+
+// expLog is a logging convenience function. The first argument must be a string.
+func expLog(args ...interface{}) {
+ args[0] = "netchan export: " + args[0].(string)
+ log.Print(args...)
+}
+
+// An Exporter allows a set of channels to be published on a single
+// network port. A single machine may have multiple Exporters
+// but they must use different ports.
+type Exporter struct {
+ *clientSet
+}
+
+type expClient struct {
+ *encDec
+ exp *Exporter
+ chans map[int]*netChan // channels in use by client
+ mu sync.Mutex // protects remaining fields
+ errored bool // client has been sent an error
+ seqNum int64 // sequences messages sent to client; has value of highest sent
+ ackNum int64 // highest sequence number acknowledged
+ seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
+}
+
+func newClient(exp *Exporter, conn io.ReadWriter) *expClient {
+ client := new(expClient)
+ client.exp = exp
+ client.encDec = newEncDec(conn)
+ client.seqNum = 0
+ client.ackNum = 0
+ client.chans = make(map[int]*netChan)
+ return client
+}
+
+func (client *expClient) sendError(hdr *header, err string) {
+ error := &error{err}
+ expLog("sending error to client:", error.Error)
+ client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
+ client.mu.Lock()
+ client.errored = true
+ client.mu.Unlock()
+}
+
+func (client *expClient) newChan(hdr *header, dir Dir, name string, size int, count int64) *netChan {
+ exp := client.exp
+ exp.mu.Lock()
+ ech, ok := exp.names[name]
+ exp.mu.Unlock()
+ if !ok {
+ client.sendError(hdr, "no such channel: "+name)
+ return nil
+ }
+ if ech.dir != dir {
+ client.sendError(hdr, "wrong direction for channel: "+name)
+ return nil
+ }
+ nch := newNetChan(name, hdr.Id, ech, client.encDec, size, count)
+ client.chans[hdr.Id] = nch
+ return nch
+}
+
+func (client *expClient) getChan(hdr *header, dir Dir) *netChan {
+ nch := client.chans[hdr.Id]
+ if nch == nil {
+ return nil
+ }
+ if nch.dir != dir {
+ client.sendError(hdr, "wrong direction for channel: "+nch.name)
+ }
+ return nch
+}
+
+// The function run manages sends and receives for a single client. For each
+// (client Recv) request, this will launch a serveRecv goroutine to deliver
+// the data for that channel, while (client Send) requests are handled as
+// data arrives from the client.
+func (client *expClient) run() {
+ hdr := new(header)
+ hdrValue := reflect.ValueOf(hdr)
+ req := new(request)
+ reqValue := reflect.ValueOf(req)
+ error := new(error)
+ for {
+ *hdr = header{}
+ if err := client.decode(hdrValue); err != nil {
+ if err != os.EOF {
+ expLog("error decoding client header:", err)
+ }
+ break
+ }
+ switch hdr.PayloadType {
+ case payRequest:
+ *req = request{}
+ if err := client.decode(reqValue); err != nil {
+ expLog("error decoding client request:", err)
+ break
+ }
+ if req.Size < 1 {
+ panic("netchan: remote requested " + strconv.Itoa(req.Size) + " values")
+ }
+ switch req.Dir {
+ case Recv:
+ // look up channel before calling serveRecv to
+ // avoid a lock around client.chans.
+ if nch := client.newChan(hdr, Send, req.Name, req.Size, req.Count); nch != nil {
+ go client.serveRecv(nch, *hdr, req.Count)
+ }
+ case Send:
+ client.newChan(hdr, Recv, req.Name, req.Size, req.Count)
+ // The actual sends will have payload type payData.
+ // TODO: manage the count?
+ default:
+ error.Error = "request: can't handle channel direction"
+ expLog(error.Error, req.Dir)
+ client.encode(hdr, payError, error)
+ }
+ case payData:
+ client.serveSend(*hdr)
+ case payClosed:
+ client.serveClosed(*hdr)
+ case payAck:
+ client.mu.Lock()
+ if client.ackNum != hdr.SeqNum-1 {
+ // Since the sequence number is incremented and the message is sent
+ // in a single instance of locking client.mu, the messages are guaranteed
+ // to be sent in order. Therefore receipt of acknowledgement N means
+ // all messages <=N have been seen by the recipient. We check anyway.
+ expLog("sequence out of order:", client.ackNum, hdr.SeqNum)
+ }
+ if client.ackNum < hdr.SeqNum { // If there has been an error, don't back up the count.
+ client.ackNum = hdr.SeqNum
+ }
+ client.mu.Unlock()
+ case payAckSend:
+ if nch := client.getChan(hdr, Send); nch != nil {
+ nch.acked()
+ }
+ default:
+ log.Fatal("netchan export: unknown payload type", hdr.PayloadType)
+ }
+ }
+ client.exp.delClient(client)
+}
+
+// Send all the data on a single channel to a client asking for a Recv.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveRecv(nch *netChan, hdr header, count int64) {
+ for {
+ val, ok := nch.recv()
+ if !ok {
+ if err := client.encode(&hdr, payClosed, nil); err != nil {
+ expLog("error encoding server closed message:", err)
+ }
+ break
+ }
+ // We hold the lock during transmission to guarantee messages are
+ // sent in sequence number order. Also, we increment first so the
+ // value of client.SeqNum is the value of the highest used sequence
+ // number, not one beyond.
+ client.mu.Lock()
+ client.seqNum++
+ hdr.SeqNum = client.seqNum
+ client.seqLock.Lock() // guarantee ordering of messages
+ client.mu.Unlock()
+ err := client.encode(&hdr, payData, val.Interface())
+ client.seqLock.Unlock()
+ if err != nil {
+ expLog("error encoding client response:", err)
+ client.sendError(&hdr, err.String())
+ break
+ }
+ // Negative count means run forever.
+ if count >= 0 {
+ if count--; count <= 0 {
+ break
+ }
+ }
+ }
+}
+
+// Receive and deliver locally one item from a client asking for a Send
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveSend(hdr header) {
+ nch := client.getChan(&hdr, Recv)
+ if nch == nil {
+ return
+ }
+ // Create a new value for each received item.
+ val := reflect.Zero(nch.ch.Type().Elem())
+ if err := client.decode(val); err != nil {
+ expLog("value decode:", err, "; type ", nch.ch.Type())
+ return
+ }
+ nch.send(val)
+}
+
+// Report that client has closed the channel that is sending to us.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveClosed(hdr header) {
+ nch := client.getChan(&hdr, Recv)
+ if nch == nil {
+ return
+ }
+ nch.close()
+}
+
+func (client *expClient) unackedCount() int64 {
+ client.mu.Lock()
+ n := client.seqNum - client.ackNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) seq() int64 {
+ client.mu.Lock()
+ n := client.seqNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) ack() int64 {
+ client.mu.Lock()
+	n := client.ackNum
+ client.mu.Unlock()
+ return n
+}
+
+// Serve waits for incoming connections on the listener
+// and serves the Exporter's channels on each.
+// It blocks until the listener is closed.
+func (exp *Exporter) Serve(listener net.Listener) {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ expLog("listen:", err)
+ break
+ }
+ go exp.ServeConn(conn)
+ }
+}
+
+// ServeConn exports the Exporter's channels on conn.
+// It blocks until the connection is terminated.
+func (exp *Exporter) ServeConn(conn io.ReadWriter) {
+ exp.addClient(conn).run()
+}
+
+// NewExporter creates a new Exporter that exports a set of channels.
+func NewExporter() *Exporter {
+ e := &Exporter{
+ clientSet: &clientSet{
+ names: make(map[string]*chanDir),
+ clients: make(map[unackedCounter]bool),
+ },
+ }
+ return e
+}
+
+// ListenAndServe exports the exporter's channels through the
+// given network and local address defined as in net.Listen.
+func (exp *Exporter) ListenAndServe(network, localaddr string) os.Error {
+ listener, err := net.Listen(network, localaddr)
+ if err != nil {
+ return err
+ }
+ go exp.Serve(listener)
+ return nil
+}
+
+// addClient creates a new expClient and records its existence
+func (exp *Exporter) addClient(conn io.ReadWriter) *expClient {
+ client := newClient(exp, conn)
+ exp.mu.Lock()
+ exp.clients[client] = true
+ exp.mu.Unlock()
+ return client
+}
+
+// delClient forgets the client existed
+func (exp *Exporter) delClient(client *expClient) {
+ exp.mu.Lock()
+ exp.clients[client] = false, false
+ exp.mu.Unlock()
+}
+
+// Drain waits until all messages sent from this exporter/importer, including
+// those not yet sent to any client and possibly including those sent while
+// Drain was executing, have been received by the importer. In short, it
+// waits until all the exporter's messages have been received by a client.
+// If the timeout (measured in nanoseconds) is positive and Drain takes
+// longer than that to complete, an error is returned.
+func (exp *Exporter) Drain(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.drain(timeout)
+}
+
+// Sync waits until all clients of the exporter have received the messages
+// that were sent at the time Sync was invoked. Unlike Drain, it does not
+// wait for messages sent while it is running or messages that have not been
+// dispatched to any client. If the timeout (measured in nanoseconds) is
+// positive and Sync takes longer than that to complete, an error is
+// returned.
+func (exp *Exporter) Sync(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.sync(timeout)
+}
+
+func checkChan(chT interface{}, dir Dir) (reflect.Value, os.Error) {
+ chanType := reflect.TypeOf(chT)
+ if chanType.Kind() != reflect.Chan {
+ return reflect.Value{}, os.NewError("not a channel")
+ }
+ if dir != Send && dir != Recv {
+ return reflect.Value{}, os.NewError("unknown channel direction")
+ }
+ switch chanType.ChanDir() {
+ case reflect.BothDir:
+ case reflect.SendDir:
+ if dir != Recv {
+ return reflect.Value{}, os.NewError("to import/export with Send, must provide <-chan")
+ }
+ case reflect.RecvDir:
+ if dir != Send {
+ return reflect.Value{}, os.NewError("to import/export with Recv, must provide chan<-")
+ }
+ }
+ return reflect.ValueOf(chT), nil
+}
+
+// Export exports a channel of a given type and specified direction. The
+// channel to be exported is provided in the call and may be of arbitrary
+// channel type.
+// Despite the literal signature, the effective signature is
+// Export(name string, chT chan T, dir Dir)
+func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
+ ch, err := checkChan(chT, dir)
+ if err != nil {
+ return err
+ }
+ exp.mu.Lock()
+ defer exp.mu.Unlock()
+ _, present := exp.names[name]
+ if present {
+ return os.NewError("channel name already being exported:" + name)
+ }
+ exp.names[name] = &chanDir{ch, dir}
+ return nil
+}
+
+// Hangup disassociates the named channel from the Exporter and closes
+// the channel. Messages in flight for the channel may be dropped.
+func (exp *Exporter) Hangup(name string) os.Error {
+ exp.mu.Lock()
+ chDir, ok := exp.names[name]
+ if ok {
+ exp.names[name] = nil, false
+ }
+ // TODO drop all instances of channel from client sets
+ exp.mu.Unlock()
+ if !ok {
+ return os.NewError("netchan export: hangup: no such channel: " + name)
+ }
+ chDir.ch.Close()
+ return nil
+}
diff --git a/src/cmd/fix/testdata/reflect.print.go.in b/src/cmd/fix/testdata/reflect.print.go.in
new file mode 100644
index 000000000..6c9b8e4f9
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.print.go.in
@@ -0,0 +1,944 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "utf8"
+)
+
+// Some constants in the form of bytes, to avoid string overhead.
+// Needlessly fastidious, I suppose.
+var (
+ commaSpaceBytes = []byte(", ")
+ nilAngleBytes = []byte("<nil>")
+ nilParenBytes = []byte("(nil)")
+ nilBytes = []byte("nil")
+ mapBytes = []byte("map[")
+ missingBytes = []byte("(MISSING)")
+ extraBytes = []byte("%!(EXTRA ")
+ irparenBytes = []byte("i)")
+ bytesBytes = []byte("[]byte{")
+ widthBytes = []byte("%!(BADWIDTH)")
+ precBytes = []byte("%!(BADPREC)")
+ noVerbBytes = []byte("%!(NOVERB)")
+)
+
+// State represents the printer state passed to custom formatters.
+// It provides access to the io.Writer interface plus information about
+// the flags and options for the operand's format specifier.
+type State interface {
+ // Write is the function to call to emit formatted output to be printed.
+ Write(b []byte) (ret int, err os.Error)
+ // Width returns the value of the width option and whether it has been set.
+ Width() (wid int, ok bool)
+ // Precision returns the value of the precision option and whether it has been set.
+ Precision() (prec int, ok bool)
+
+ // Flag returns whether the flag c, a character, has been set.
+ Flag(int) bool
+}
+
+// Formatter is the interface implemented by values with a custom formatter.
+// The implementation of Format may call Sprintf or Fprintf(f) etc.
+// to generate its output.
+type Formatter interface {
+ Format(f State, c int)
+}
+
+// Stringer is implemented by any value that has a String() method,
+// which defines the ``native'' format for that value.
+// The String method is used to print values passed as an operand
+// to a %s or %v format or to an unformatted printer such as Print.
+type Stringer interface {
+ String() string
+}
+
+// GoStringer is implemented by any value that has a GoString() method,
+// which defines the Go syntax for that value.
+// The GoString method is used to print values passed as an operand
+// to a %#v format.
+type GoStringer interface {
+ GoString() string
+}
+
+type pp struct {
+ n int
+ buf bytes.Buffer
+ runeBuf [utf8.UTFMax]byte
+ fmt fmt
+}
+
+// A cache holds a set of reusable objects.
+// The buffered channel holds the currently available objects.
+// If more are needed, the cache creates them by calling new.
+type cache struct {
+ saved chan interface{}
+ new func() interface{}
+}
+
+func (c *cache) put(x interface{}) {
+ select {
+ case c.saved <- x:
+ // saved in cache
+ default:
+ // discard
+ }
+}
+
+func (c *cache) get() interface{} {
+ select {
+ case x := <-c.saved:
+ return x // reused from cache
+ default:
+ return c.new()
+ }
+ panic("not reached")
+}
+
+func newCache(f func() interface{}) *cache {
+ return &cache{make(chan interface{}, 100), f}
+}
+
+var ppFree = newCache(func() interface{} { return new(pp) })
+
+// Allocate a new pp struct or grab a cached one.
+func newPrinter() *pp {
+ p := ppFree.get().(*pp)
+ p.fmt.init(&p.buf)
+ return p
+}
+
+// Save used pp structs in ppFree; avoids an allocation per invocation.
+func (p *pp) free() {
+ // Don't hold on to pp structs with large buffers.
+ if cap(p.buf.Bytes()) > 1024 {
+ return
+ }
+ p.buf.Reset()
+ ppFree.put(p)
+}
+
+func (p *pp) Width() (wid int, ok bool) { return p.fmt.wid, p.fmt.widPresent }
+
+func (p *pp) Precision() (prec int, ok bool) { return p.fmt.prec, p.fmt.precPresent }
+
+func (p *pp) Flag(b int) bool {
+ switch b {
+ case '-':
+ return p.fmt.minus
+ case '+':
+ return p.fmt.plus
+ case '#':
+ return p.fmt.sharp
+ case ' ':
+ return p.fmt.space
+ case '0':
+ return p.fmt.zero
+ }
+ return false
+}
+
+func (p *pp) add(c int) {
+ p.buf.WriteRune(c)
+}
+
+// Implement Write so we can call Fprintf on a pp (through State), for
+// recursive use in custom verbs.
+func (p *pp) Write(b []byte) (ret int, err os.Error) {
+ return p.buf.Write(b)
+}
+
+// These routines end in 'f' and take a format string.
+
+// Fprintf formats according to a format specifier and writes to w.
+// It returns the number of bytes written and any write error encountered.
+func Fprintf(w io.Writer, format string, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrintf(format, a)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Printf formats according to a format specifier and writes to standard output.
+// It returns the number of bytes written and any write error encountered.
+func Printf(format string, a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprintf(os.Stdout, format, a...)
+ return n, errno
+}
+
+// Sprintf formats according to a format specifier and returns the resulting string.
+func Sprintf(format string, a ...interface{}) string {
+ p := newPrinter()
+ p.doPrintf(format, a)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+// Errorf formats according to a format specifier and returns the string
+// converted to an os.ErrorString, which satisfies the os.Error interface.
+func Errorf(format string, a ...interface{}) os.Error {
+ return os.NewError(Sprintf(format, a...))
+}
+
+// These routines do not take a format string
+
+// Fprint formats using the default formats for its operands and writes to w.
+// Spaces are added between operands when neither is a string.
+// It returns the number of bytes written and any write error encountered.
+func Fprint(w io.Writer, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrint(a, false, false)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Print formats using the default formats for its operands and writes to standard output.
+// Spaces are added between operands when neither is a string.
+// It returns the number of bytes written and any write error encountered.
+func Print(a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprint(os.Stdout, a...)
+ return n, errno
+}
+
+// Sprint formats using the default formats for its operands and returns the resulting string.
+// Spaces are added between operands when neither is a string.
+func Sprint(a ...interface{}) string {
+ p := newPrinter()
+ p.doPrint(a, false, false)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+// These routines end in 'ln', do not take a format string,
+// always add spaces between operands, and add a newline
+// after the last operand.
+
+// Fprintln formats using the default formats for its operands and writes to w.
+// Spaces are always added between operands and a newline is appended.
+// It returns the number of bytes written and any write error encountered.
+func Fprintln(w io.Writer, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrint(a, true, true)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Println formats using the default formats for its operands and writes to standard output.
+// Spaces are always added between operands and a newline is appended.
+// It returns the number of bytes written and any write error encountered.
+func Println(a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprintln(os.Stdout, a...)
+ return n, errno
+}
+
+// Sprintln formats using the default formats for its operands and returns the resulting string.
+// Spaces are always added between operands and a newline is appended.
+func Sprintln(a ...interface{}) string {
+ p := newPrinter()
+ p.doPrint(a, true, true)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+// Get the i'th arg of the struct value.
+// If the arg itself is an interface, return a value for
+// the thing inside the interface, not the interface itself.
+func getField(v *reflect.StructValue, i int) reflect.Value {
+ val := v.Field(i)
+ if i, ok := val.(*reflect.InterfaceValue); ok {
+ if inter := i.Interface(); inter != nil {
+ return reflect.NewValue(inter)
+ }
+ }
+ return val
+}
+
+// Convert ASCII to integer. num is 0 (and isnum is false) if no number is present.
+func parsenum(s string, start, end int) (num int, isnum bool, newi int) {
+ if start >= end {
+ return 0, false, end
+ }
+ for newi = start; newi < end && '0' <= s[newi] && s[newi] <= '9'; newi++ {
+ num = num*10 + int(s[newi]-'0')
+ isnum = true
+ }
+ return
+}
+
+func (p *pp) unknownType(v interface{}) {
+ if v == nil {
+ p.buf.Write(nilAngleBytes)
+ return
+ }
+ p.buf.WriteByte('?')
+ p.buf.WriteString(reflect.Typeof(v).String())
+ p.buf.WriteByte('?')
+}
+
+func (p *pp) badVerb(verb int, val interface{}) {
+ p.add('%')
+ p.add('!')
+ p.add(verb)
+ p.add('(')
+ if val == nil {
+ p.buf.Write(nilAngleBytes)
+ } else {
+ p.buf.WriteString(reflect.Typeof(val).String())
+ p.add('=')
+ p.printField(val, 'v', false, false, 0)
+ }
+ p.add(')')
+}
+
+func (p *pp) fmtBool(v bool, verb int, value interface{}) {
+ switch verb {
+ case 't', 'v':
+ p.fmt.fmt_boolean(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+// fmtC formats a rune for the 'c' format.
+func (p *pp) fmtC(c int64) {
+ rune := int(c) // Check for overflow.
+ if int64(rune) != c {
+ rune = utf8.RuneError
+ }
+ w := utf8.EncodeRune(p.runeBuf[0:utf8.UTFMax], rune)
+ p.fmt.pad(p.runeBuf[0:w])
+}
+
+func (p *pp) fmtInt64(v int64, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.integer(v, 2, signed, ldigits)
+ case 'c':
+ p.fmtC(v)
+ case 'd', 'v':
+ p.fmt.integer(v, 10, signed, ldigits)
+ case 'o':
+ p.fmt.integer(v, 8, signed, ldigits)
+ case 'x':
+ p.fmt.integer(v, 16, signed, ldigits)
+ case 'U':
+ p.fmtUnicode(v)
+ case 'X':
+ p.fmt.integer(v, 16, signed, udigits)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+// fmt0x64 formats a uint64 in hexadecimal and prefixes it with 0x or
+// not, as requested, by temporarily setting the sharp flag.
+func (p *pp) fmt0x64(v uint64, leading0x bool) {
+ sharp := p.fmt.sharp
+ p.fmt.sharp = leading0x
+ p.fmt.integer(int64(v), 16, unsigned, ldigits)
+ p.fmt.sharp = sharp
+}
+
+// fmtUnicode formats a uint64 in U+1234 form by
+// temporarily turning on the unicode flag and tweaking the precision.
+func (p *pp) fmtUnicode(v int64) {
+ precPresent := p.fmt.precPresent
+ prec := p.fmt.prec
+ if !precPresent {
+ // If prec is already set, leave it alone; otherwise 4 is minimum.
+ p.fmt.prec = 4
+ p.fmt.precPresent = true
+ }
+ p.fmt.unicode = true // turn on U+
+ p.fmt.integer(int64(v), 16, unsigned, udigits)
+ p.fmt.unicode = false
+ p.fmt.prec = prec
+ p.fmt.precPresent = precPresent
+}
+
+func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.integer(int64(v), 2, unsigned, ldigits)
+ case 'c':
+ p.fmtC(int64(v))
+ case 'd':
+ p.fmt.integer(int64(v), 10, unsigned, ldigits)
+ case 'v':
+ if goSyntax {
+ p.fmt0x64(v, true)
+ } else {
+ p.fmt.integer(int64(v), 10, unsigned, ldigits)
+ }
+ case 'o':
+ p.fmt.integer(int64(v), 8, unsigned, ldigits)
+ case 'x':
+ p.fmt.integer(int64(v), 16, unsigned, ldigits)
+ case 'X':
+ p.fmt.integer(int64(v), 16, unsigned, udigits)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtFloat32(v float32, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.fmt_fb32(v)
+ case 'e':
+ p.fmt.fmt_e32(v)
+ case 'E':
+ p.fmt.fmt_E32(v)
+ case 'f':
+ p.fmt.fmt_f32(v)
+ case 'g', 'v':
+ p.fmt.fmt_g32(v)
+ case 'G':
+ p.fmt.fmt_G32(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtFloat64(v float64, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.fmt_fb64(v)
+ case 'e':
+ p.fmt.fmt_e64(v)
+ case 'E':
+ p.fmt.fmt_E64(v)
+ case 'f':
+ p.fmt.fmt_f64(v)
+ case 'g', 'v':
+ p.fmt.fmt_g64(v)
+ case 'G':
+ p.fmt.fmt_G64(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtComplex64(v complex64, verb int, value interface{}) {
+ switch verb {
+ case 'e', 'E', 'f', 'F', 'g', 'G':
+ p.fmt.fmt_c64(v, verb)
+ case 'v':
+ p.fmt.fmt_c64(v, 'g')
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtComplex128(v complex128, verb int, value interface{}) {
+ switch verb {
+ case 'e', 'E', 'f', 'F', 'g', 'G':
+ p.fmt.fmt_c128(v, verb)
+ case 'v':
+ p.fmt.fmt_c128(v, 'g')
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) {
+ switch verb {
+ case 'v':
+ if goSyntax {
+ p.fmt.fmt_q(v)
+ } else {
+ p.fmt.fmt_s(v)
+ }
+ case 's':
+ p.fmt.fmt_s(v)
+ case 'x':
+ p.fmt.fmt_sx(v)
+ case 'X':
+ p.fmt.fmt_sX(v)
+ case 'q':
+ p.fmt.fmt_q(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interface{}) {
+ if verb == 'v' || verb == 'd' {
+ if goSyntax {
+ p.buf.Write(bytesBytes)
+ } else {
+ p.buf.WriteByte('[')
+ }
+ for i, c := range v {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(c, 'v', p.fmt.plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ return
+ }
+ s := string(v)
+ switch verb {
+ case 's':
+ p.fmt.fmt_s(s)
+ case 'x':
+ p.fmt.fmt_sx(s)
+ case 'X':
+ p.fmt.fmt_sX(s)
+ case 'q':
+ p.fmt.fmt_q(s)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSyntax bool) {
+ var u uintptr
+ switch value.(type) {
+ case *reflect.ChanValue, *reflect.FuncValue, *reflect.MapValue, *reflect.PtrValue, *reflect.SliceValue, *reflect.UnsafePointerValue:
+ u = value.Pointer()
+ default:
+ p.badVerb(verb, field)
+ return
+ }
+ if goSyntax {
+ p.add('(')
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.add(')')
+ p.add('(')
+ if u == 0 {
+ p.buf.Write(nilBytes)
+ } else {
+ p.fmt0x64(uint64(u), true)
+ }
+ p.add(')')
+ } else {
+ p.fmt0x64(uint64(u), !p.fmt.sharp)
+ }
+}
+
+var (
+ intBits = reflect.Typeof(0).Bits()
+ floatBits = reflect.Typeof(0.0).Bits()
+ complexBits = reflect.Typeof(1i).Bits()
+ uintptrBits = reflect.Typeof(uintptr(0)).Bits()
+)
+
+func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString bool) {
+ if field == nil {
+ if verb == 'T' || verb == 'v' {
+ p.buf.Write(nilAngleBytes)
+ } else {
+ p.badVerb(verb, field)
+ }
+ return false
+ }
+
+ // Special processing considerations.
+ // %T (the value's type) and %p (its address) are special; we always do them first.
+ switch verb {
+ case 'T':
+ p.printField(reflect.Typeof(field).String(), 's', false, false, 0)
+ return false
+ case 'p':
+ p.fmtPointer(field, reflect.NewValue(field), verb, goSyntax)
+ return false
+ }
+ // Is it a Formatter?
+ if formatter, ok := field.(Formatter); ok {
+ formatter.Format(p, verb)
+ return false // this value is not a string
+
+ }
+ // Must not touch flags before Formatter looks at them.
+ if plus {
+ p.fmt.plus = false
+ }
+ // If we're doing Go syntax and the field knows how to supply it, take care of it now.
+ if goSyntax {
+ p.fmt.sharp = false
+ if stringer, ok := field.(GoStringer); ok {
+ // Print the result of GoString unadorned.
+ p.fmtString(stringer.GoString(), 's', false, field)
+ return false // this value is not a string
+ }
+ } else {
+ // Is it a Stringer?
+ if stringer, ok := field.(Stringer); ok {
+ p.printField(stringer.String(), verb, plus, false, depth)
+ return false // this value is not a string
+ }
+ }
+
+ // Some types can be done without reflection.
+ switch f := field.(type) {
+ case bool:
+ p.fmtBool(f, verb, field)
+ return false
+ case float32:
+ p.fmtFloat32(f, verb, field)
+ return false
+ case float64:
+ p.fmtFloat64(f, verb, field)
+ return false
+ case complex64:
+ p.fmtComplex64(complex64(f), verb, field)
+ return false
+ case complex128:
+ p.fmtComplex128(f, verb, field)
+ return false
+ case int:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int8:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int16:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int32:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int64:
+ p.fmtInt64(f, verb, field)
+ return false
+ case uint:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint8:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint16:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint32:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint64:
+ p.fmtUint64(f, verb, goSyntax, field)
+ return false
+ case uintptr:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case string:
+ p.fmtString(f, verb, goSyntax, field)
+ return verb == 's' || verb == 'v'
+ case []byte:
+ p.fmtBytes(f, verb, goSyntax, depth, field)
+ return verb == 's'
+ }
+
+ // Need to use reflection
+ value := reflect.NewValue(field)
+
+BigSwitch:
+ switch f := value.(type) {
+ case *reflect.BoolValue:
+ p.fmtBool(f.Get(), verb, field)
+ case *reflect.IntValue:
+ p.fmtInt64(f.Get(), verb, field)
+ case *reflect.UintValue:
+ p.fmtUint64(uint64(f.Get()), verb, goSyntax, field)
+ case *reflect.FloatValue:
+ if f.Type().Size() == 4 {
+ p.fmtFloat32(float32(f.Get()), verb, field)
+ } else {
+ p.fmtFloat64(float64(f.Get()), verb, field)
+ }
+ case *reflect.ComplexValue:
+ if f.Type().Size() == 8 {
+ p.fmtComplex64(complex64(f.Get()), verb, field)
+ } else {
+ p.fmtComplex128(complex128(f.Get()), verb, field)
+ }
+ case *reflect.StringValue:
+ p.fmtString(f.Get(), verb, goSyntax, field)
+ case *reflect.MapValue:
+ if goSyntax {
+ p.buf.WriteString(f.Type().String())
+ p.buf.WriteByte('{')
+ } else {
+ p.buf.Write(mapBytes)
+ }
+ keys := f.Keys()
+ for i, key := range keys {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(key.Interface(), verb, plus, goSyntax, depth+1)
+ p.buf.WriteByte(':')
+ p.printField(f.Elem(key).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ case *reflect.StructValue:
+ if goSyntax {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ }
+ p.add('{')
+ v := f
+ t := v.Type().(*reflect.StructType)
+ for i := 0; i < v.NumField(); i++ {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ if plus || goSyntax {
+ if f := t.Field(i); f.Name != "" {
+ p.buf.WriteString(f.Name)
+ p.buf.WriteByte(':')
+ }
+ }
+ p.printField(getField(v, i).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ p.buf.WriteByte('}')
+ case *reflect.InterfaceValue:
+ value := f.Elem()
+ if value == nil {
+ if goSyntax {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.Write(nilParenBytes)
+ } else {
+ p.buf.Write(nilAngleBytes)
+ }
+ } else {
+ return p.printField(value.Interface(), verb, plus, goSyntax, depth+1)
+ }
+ case reflect.ArrayOrSliceValue:
+ // Byte slices are special.
+ if f.Type().(reflect.ArrayOrSliceType).Elem().Kind() == reflect.Uint8 {
+ // We know it's a slice of bytes, but we also know it does not have static type
+ // []byte, or it would have been caught above. Therefore we cannot convert
+ // it directly in the (slightly) obvious way: f.Interface().([]byte); it doesn't have
+ // that type, and we can't write an expression of the right type and do a
+ // conversion because we don't have a static way to write the right type.
+ // So we build a slice by hand. This is a rare case but it would be nice
+ // if reflection could help a little more.
+ bytes := make([]byte, f.Len())
+ for i := range bytes {
+ bytes[i] = byte(f.Elem(i).(*reflect.UintValue).Get())
+ }
+ p.fmtBytes(bytes, verb, goSyntax, depth, field)
+ return verb == 's'
+ }
+ if goSyntax {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.WriteByte('{')
+ } else {
+ p.buf.WriteByte('[')
+ }
+ for i := 0; i < f.Len(); i++ {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(f.Elem(i).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ case *reflect.PtrValue:
+ v := f.Get()
+ // pointer to array or slice or struct? ok at top level
+ // but not embedded (avoid loops)
+ if v != 0 && depth == 0 {
+ switch a := f.Elem().(type) {
+ case reflect.ArrayOrSliceValue:
+ p.buf.WriteByte('&')
+ p.printField(a.Interface(), verb, plus, goSyntax, depth+1)
+ break BigSwitch
+ case *reflect.StructValue:
+ p.buf.WriteByte('&')
+ p.printField(a.Interface(), verb, plus, goSyntax, depth+1)
+ break BigSwitch
+ }
+ }
+ if goSyntax {
+ p.buf.WriteByte('(')
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.WriteByte(')')
+ p.buf.WriteByte('(')
+ if v == 0 {
+ p.buf.Write(nilBytes)
+ } else {
+ p.fmt0x64(uint64(v), true)
+ }
+ p.buf.WriteByte(')')
+ break
+ }
+ if v == 0 {
+ p.buf.Write(nilAngleBytes)
+ break
+ }
+ p.fmt0x64(uint64(v), true)
+ case *reflect.ChanValue, *reflect.FuncValue, *reflect.UnsafePointerValue:
+ p.fmtPointer(field, value, verb, goSyntax)
+ default:
+ p.unknownType(f)
+ }
+ return false
+}
+
+// intFromArg gets the fieldnumth element of a. On return, isInt reports whether the argument has type int.
+func intFromArg(a []interface{}, end, i, fieldnum int) (num int, isInt bool, newi, newfieldnum int) {
+ newi, newfieldnum = end, fieldnum
+ if i < end && fieldnum < len(a) {
+ num, isInt = a[fieldnum].(int)
+ newi, newfieldnum = i+1, fieldnum+1
+ }
+ return
+}
+
+func (p *pp) doPrintf(format string, a []interface{}) {
+ end := len(format)
+ fieldnum := 0 // we process one field per non-trivial format
+ for i := 0; i < end; {
+ lasti := i
+ for i < end && format[i] != '%' {
+ i++
+ }
+ if i > lasti {
+ p.buf.WriteString(format[lasti:i])
+ }
+ if i >= end {
+ // done processing format string
+ break
+ }
+
+ // Process one verb
+ i++
+ // flags and widths
+ p.fmt.clearflags()
+ F:
+ for ; i < end; i++ {
+ switch format[i] {
+ case '#':
+ p.fmt.sharp = true
+ case '0':
+ p.fmt.zero = true
+ case '+':
+ p.fmt.plus = true
+ case '-':
+ p.fmt.minus = true
+ case ' ':
+ p.fmt.space = true
+ default:
+ break F
+ }
+ }
+ // do we have width?
+ if i < end && format[i] == '*' {
+ p.fmt.wid, p.fmt.widPresent, i, fieldnum = intFromArg(a, end, i, fieldnum)
+ if !p.fmt.widPresent {
+ p.buf.Write(widthBytes)
+ }
+ } else {
+ p.fmt.wid, p.fmt.widPresent, i = parsenum(format, i, end)
+ }
+ // do we have precision?
+ if i < end && format[i] == '.' {
+ if format[i+1] == '*' {
+ p.fmt.prec, p.fmt.precPresent, i, fieldnum = intFromArg(a, end, i+1, fieldnum)
+ if !p.fmt.precPresent {
+ p.buf.Write(precBytes)
+ }
+ } else {
+ p.fmt.prec, p.fmt.precPresent, i = parsenum(format, i+1, end)
+ }
+ }
+ if i >= end {
+ p.buf.Write(noVerbBytes)
+ continue
+ }
+ c, w := utf8.DecodeRuneInString(format[i:])
+ i += w
+ // percent is special - absorbs no operand
+ if c == '%' {
+ p.buf.WriteByte('%') // We ignore width and prec.
+ continue
+ }
+ if fieldnum >= len(a) { // out of operands
+ p.buf.WriteByte('%')
+ p.add(c)
+ p.buf.Write(missingBytes)
+ continue
+ }
+ field := a[fieldnum]
+ fieldnum++
+
+ goSyntax := c == 'v' && p.fmt.sharp
+ plus := c == 'v' && p.fmt.plus
+ p.printField(field, c, plus, goSyntax, 0)
+ }
+
+ if fieldnum < len(a) {
+ p.buf.Write(extraBytes)
+ for ; fieldnum < len(a); fieldnum++ {
+ field := a[fieldnum]
+ if field != nil {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.WriteByte('=')
+ }
+ p.printField(field, 'v', false, false, 0)
+ if fieldnum+1 < len(a) {
+ p.buf.Write(commaSpaceBytes)
+ }
+ }
+ p.buf.WriteByte(')')
+ }
+}
+
+func (p *pp) doPrint(a []interface{}, addspace, addnewline bool) {
+ prevString := false
+ for fieldnum := 0; fieldnum < len(a); fieldnum++ {
+ p.fmt.clearflags()
+ // always add spaces if we're doing println
+ field := a[fieldnum]
+ if fieldnum > 0 {
+ isString := field != nil && reflect.Typeof(field).Kind() == reflect.String
+ if addspace || !isString && !prevString {
+ p.buf.WriteByte(' ')
+ }
+ }
+ prevString = p.printField(field, 'v', false, false, 0)
+ }
+ if addnewline {
+ p.buf.WriteByte('\n')
+ }
+}
diff --git a/src/cmd/fix/testdata/reflect.print.go.out b/src/cmd/fix/testdata/reflect.print.go.out
new file mode 100644
index 000000000..b475a2ae1
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.print.go.out
@@ -0,0 +1,944 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "utf8"
+)
+
+// Some constants in the form of bytes, to avoid string overhead.
+// Needlessly fastidious, I suppose.
+var (
+ commaSpaceBytes = []byte(", ")
+ nilAngleBytes = []byte("<nil>")
+ nilParenBytes = []byte("(nil)")
+ nilBytes = []byte("nil")
+ mapBytes = []byte("map[")
+ missingBytes = []byte("(MISSING)")
+ extraBytes = []byte("%!(EXTRA ")
+ irparenBytes = []byte("i)")
+ bytesBytes = []byte("[]byte{")
+ widthBytes = []byte("%!(BADWIDTH)")
+ precBytes = []byte("%!(BADPREC)")
+ noVerbBytes = []byte("%!(NOVERB)")
+)
+
+// State represents the printer state passed to custom formatters.
+// It provides access to the io.Writer interface plus information about
+// the flags and options for the operand's format specifier.
+type State interface {
+ // Write is the function to call to emit formatted output to be printed.
+ Write(b []byte) (ret int, err os.Error)
+ // Width returns the value of the width option and whether it has been set.
+ Width() (wid int, ok bool)
+ // Precision returns the value of the precision option and whether it has been set.
+ Precision() (prec int, ok bool)
+
+ // Flag returns whether the flag c, a character, has been set.
+ Flag(int) bool
+}
+
+// Formatter is the interface implemented by values with a custom formatter.
+// The implementation of Format may call Sprintf or Fprintf(f) etc.
+// to generate its output.
+type Formatter interface {
+ Format(f State, c int)
+}
+
+// Stringer is implemented by any value that has a String() method,
+// which defines the ``native'' format for that value.
+// The String method is used to print values passed as an operand
+// to a %s or %v format or to an unformatted printer such as Print.
+type Stringer interface {
+ String() string
+}
+
+// GoStringer is implemented by any value that has a GoString() method,
+// which defines the Go syntax for that value.
+// The GoString method is used to print values passed as an operand
+// to a %#v format.
+type GoStringer interface {
+ GoString() string
+}
+
+type pp struct {
+ n int
+ buf bytes.Buffer
+ runeBuf [utf8.UTFMax]byte
+ fmt fmt
+}
+
+// A cache holds a set of reusable objects.
+// The buffered channel holds the currently available objects.
+// If more are needed, the cache creates them by calling new.
+type cache struct {
+ saved chan interface{}
+ new func() interface{}
+}
+
+func (c *cache) put(x interface{}) {
+ select {
+ case c.saved <- x:
+ // saved in cache
+ default:
+ // discard
+ }
+}
+
+func (c *cache) get() interface{} {
+ select {
+ case x := <-c.saved:
+ return x // reused from cache
+ default:
+ return c.new()
+ }
+ panic("not reached")
+}
+
+func newCache(f func() interface{}) *cache {
+ return &cache{make(chan interface{}, 100), f}
+}
+
+var ppFree = newCache(func() interface{} { return new(pp) })
+
+// Allocate a new pp struct or grab a cached one.
+func newPrinter() *pp {
+ p := ppFree.get().(*pp)
+ p.fmt.init(&p.buf)
+ return p
+}
+
+// Save used pp structs in ppFree; avoids an allocation per invocation.
+func (p *pp) free() {
+ // Don't hold on to pp structs with large buffers.
+ if cap(p.buf.Bytes()) > 1024 {
+ return
+ }
+ p.buf.Reset()
+ ppFree.put(p)
+}
+
+func (p *pp) Width() (wid int, ok bool) { return p.fmt.wid, p.fmt.widPresent }
+
+func (p *pp) Precision() (prec int, ok bool) { return p.fmt.prec, p.fmt.precPresent }
+
+func (p *pp) Flag(b int) bool {
+ switch b {
+ case '-':
+ return p.fmt.minus
+ case '+':
+ return p.fmt.plus
+ case '#':
+ return p.fmt.sharp
+ case ' ':
+ return p.fmt.space
+ case '0':
+ return p.fmt.zero
+ }
+ return false
+}
+
+func (p *pp) add(c int) {
+ p.buf.WriteRune(c)
+}
+
+// Implement Write so we can call Fprintf on a pp (through State), for
+// recursive use in custom verbs.
+func (p *pp) Write(b []byte) (ret int, err os.Error) {
+ return p.buf.Write(b)
+}
+
+// These routines end in 'f' and take a format string.
+
+// Fprintf formats according to a format specifier and writes to w.
+// It returns the number of bytes written and any write error encountered.
+func Fprintf(w io.Writer, format string, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrintf(format, a)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Printf formats according to a format specifier and writes to standard output.
+// It returns the number of bytes written and any write error encountered.
+func Printf(format string, a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprintf(os.Stdout, format, a...)
+ return n, errno
+}
+
+// Sprintf formats according to a format specifier and returns the resulting string.
+func Sprintf(format string, a ...interface{}) string {
+ p := newPrinter()
+ p.doPrintf(format, a)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+// Errorf formats according to a format specifier and returns the string
+// converted to an os.ErrorString, which satisfies the os.Error interface.
+func Errorf(format string, a ...interface{}) os.Error {
+ return os.NewError(Sprintf(format, a...))
+}
+
+// These routines do not take a format string
+
+// Fprint formats using the default formats for its operands and writes to w.
+// Spaces are added between operands when neither is a string.
+// It returns the number of bytes written and any write error encountered.
+func Fprint(w io.Writer, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrint(a, false, false)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Print formats using the default formats for its operands and writes to standard output.
+// Spaces are added between operands when neither is a string.
+// It returns the number of bytes written and any write error encountered.
+func Print(a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprint(os.Stdout, a...)
+ return n, errno
+}
+
+// Sprint formats using the default formats for its operands and returns the resulting string.
+// Spaces are added between operands when neither is a string.
+func Sprint(a ...interface{}) string {
+ p := newPrinter()
+ p.doPrint(a, false, false)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+// These routines end in 'ln', do not take a format string,
+// always add spaces between operands, and add a newline
+// after the last operand.
+
+// Fprintln formats using the default formats for its operands and writes to w.
+// Spaces are always added between operands and a newline is appended.
+// It returns the number of bytes written and any write error encountered.
+func Fprintln(w io.Writer, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrint(a, true, true)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Println formats using the default formats for its operands and writes to standard output.
+// Spaces are always added between operands and a newline is appended.
+// It returns the number of bytes written and any write error encountered.
+func Println(a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprintln(os.Stdout, a...)
+ return n, errno
+}
+
+// Sprintln formats using the default formats for its operands and returns the resulting string.
+// Spaces are always added between operands and a newline is appended.
+func Sprintln(a ...interface{}) string {
+ p := newPrinter()
+ p.doPrint(a, true, true)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+// Get the i'th arg of the struct value.
+// If the arg itself is an interface, return a value for
+// the thing inside the interface, not the interface itself.
+func getField(v reflect.Value, i int) reflect.Value {
+ val := v.Field(i)
+ if i := val; i.Kind() == reflect.Interface {
+ if inter := i.Interface(); inter != nil {
+ return reflect.ValueOf(inter)
+ }
+ }
+ return val
+}
+
+// Convert ASCII to integer. num is 0 (and isnum is false) if no number is present.
+func parsenum(s string, start, end int) (num int, isnum bool, newi int) {
+ if start >= end {
+ return 0, false, end
+ }
+ for newi = start; newi < end && '0' <= s[newi] && s[newi] <= '9'; newi++ {
+ num = num*10 + int(s[newi]-'0')
+ isnum = true
+ }
+ return
+}
+
+func (p *pp) unknownType(v interface{}) {
+ if v == nil {
+ p.buf.Write(nilAngleBytes)
+ return
+ }
+ p.buf.WriteByte('?')
+ p.buf.WriteString(reflect.TypeOf(v).String())
+ p.buf.WriteByte('?')
+}
+
+func (p *pp) badVerb(verb int, val interface{}) {
+ p.add('%')
+ p.add('!')
+ p.add(verb)
+ p.add('(')
+ if val == nil {
+ p.buf.Write(nilAngleBytes)
+ } else {
+ p.buf.WriteString(reflect.TypeOf(val).String())
+ p.add('=')
+ p.printField(val, 'v', false, false, 0)
+ }
+ p.add(')')
+}
+
+func (p *pp) fmtBool(v bool, verb int, value interface{}) {
+ switch verb {
+ case 't', 'v':
+ p.fmt.fmt_boolean(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+// fmtC formats a rune for the 'c' format.
+func (p *pp) fmtC(c int64) {
+ rune := int(c) // Check for overflow.
+ if int64(rune) != c {
+ rune = utf8.RuneError
+ }
+ w := utf8.EncodeRune(p.runeBuf[0:utf8.UTFMax], rune)
+ p.fmt.pad(p.runeBuf[0:w])
+}
+
+func (p *pp) fmtInt64(v int64, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.integer(v, 2, signed, ldigits)
+ case 'c':
+ p.fmtC(v)
+ case 'd', 'v':
+ p.fmt.integer(v, 10, signed, ldigits)
+ case 'o':
+ p.fmt.integer(v, 8, signed, ldigits)
+ case 'x':
+ p.fmt.integer(v, 16, signed, ldigits)
+ case 'U':
+ p.fmtUnicode(v)
+ case 'X':
+ p.fmt.integer(v, 16, signed, udigits)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+// fmt0x64 formats a uint64 in hexadecimal and prefixes it with 0x or
+// not, as requested, by temporarily setting the sharp flag.
+func (p *pp) fmt0x64(v uint64, leading0x bool) {
+ sharp := p.fmt.sharp
+ p.fmt.sharp = leading0x
+ p.fmt.integer(int64(v), 16, unsigned, ldigits)
+ p.fmt.sharp = sharp
+}
+
+// fmtUnicode formats a uint64 in U+1234 form by
+// temporarily turning on the unicode flag and tweaking the precision.
+func (p *pp) fmtUnicode(v int64) {
+ precPresent := p.fmt.precPresent
+ prec := p.fmt.prec
+ if !precPresent {
+ // If prec is already set, leave it alone; otherwise 4 is minimum.
+ p.fmt.prec = 4
+ p.fmt.precPresent = true
+ }
+ p.fmt.unicode = true // turn on U+
+ p.fmt.integer(int64(v), 16, unsigned, udigits)
+ p.fmt.unicode = false
+ p.fmt.prec = prec
+ p.fmt.precPresent = precPresent
+}
+
+func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.integer(int64(v), 2, unsigned, ldigits)
+ case 'c':
+ p.fmtC(int64(v))
+ case 'd':
+ p.fmt.integer(int64(v), 10, unsigned, ldigits)
+ case 'v':
+ if goSyntax {
+ p.fmt0x64(v, true)
+ } else {
+ p.fmt.integer(int64(v), 10, unsigned, ldigits)
+ }
+ case 'o':
+ p.fmt.integer(int64(v), 8, unsigned, ldigits)
+ case 'x':
+ p.fmt.integer(int64(v), 16, unsigned, ldigits)
+ case 'X':
+ p.fmt.integer(int64(v), 16, unsigned, udigits)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtFloat32(v float32, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.fmt_fb32(v)
+ case 'e':
+ p.fmt.fmt_e32(v)
+ case 'E':
+ p.fmt.fmt_E32(v)
+ case 'f':
+ p.fmt.fmt_f32(v)
+ case 'g', 'v':
+ p.fmt.fmt_g32(v)
+ case 'G':
+ p.fmt.fmt_G32(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtFloat64(v float64, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.fmt_fb64(v)
+ case 'e':
+ p.fmt.fmt_e64(v)
+ case 'E':
+ p.fmt.fmt_E64(v)
+ case 'f':
+ p.fmt.fmt_f64(v)
+ case 'g', 'v':
+ p.fmt.fmt_g64(v)
+ case 'G':
+ p.fmt.fmt_G64(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtComplex64(v complex64, verb int, value interface{}) {
+ switch verb {
+ case 'e', 'E', 'f', 'F', 'g', 'G':
+ p.fmt.fmt_c64(v, verb)
+ case 'v':
+ p.fmt.fmt_c64(v, 'g')
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtComplex128(v complex128, verb int, value interface{}) {
+ switch verb {
+ case 'e', 'E', 'f', 'F', 'g', 'G':
+ p.fmt.fmt_c128(v, verb)
+ case 'v':
+ p.fmt.fmt_c128(v, 'g')
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) {
+ switch verb {
+ case 'v':
+ if goSyntax {
+ p.fmt.fmt_q(v)
+ } else {
+ p.fmt.fmt_s(v)
+ }
+ case 's':
+ p.fmt.fmt_s(v)
+ case 'x':
+ p.fmt.fmt_sx(v)
+ case 'X':
+ p.fmt.fmt_sX(v)
+ case 'q':
+ p.fmt.fmt_q(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interface{}) {
+ if verb == 'v' || verb == 'd' {
+ if goSyntax {
+ p.buf.Write(bytesBytes)
+ } else {
+ p.buf.WriteByte('[')
+ }
+ for i, c := range v {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(c, 'v', p.fmt.plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ return
+ }
+ s := string(v)
+ switch verb {
+ case 's':
+ p.fmt.fmt_s(s)
+ case 'x':
+ p.fmt.fmt_sx(s)
+ case 'X':
+ p.fmt.fmt_sX(s)
+ case 'q':
+ p.fmt.fmt_q(s)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSyntax bool) {
+ var u uintptr
+ switch value.Kind() {
+ case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
+ u = value.Pointer()
+ default:
+ p.badVerb(verb, field)
+ return
+ }
+ if goSyntax {
+ p.add('(')
+ p.buf.WriteString(reflect.TypeOf(field).String())
+ p.add(')')
+ p.add('(')
+ if u == 0 {
+ p.buf.Write(nilBytes)
+ } else {
+ p.fmt0x64(uint64(u), true)
+ }
+ p.add(')')
+ } else {
+ p.fmt0x64(uint64(u), !p.fmt.sharp)
+ }
+}
+
+var (
+ intBits = reflect.TypeOf(0).Bits()
+ floatBits = reflect.TypeOf(0.0).Bits()
+ complexBits = reflect.TypeOf(1i).Bits()
+ uintptrBits = reflect.TypeOf(uintptr(0)).Bits()
+)
+
+func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString bool) {
+ if field == nil {
+ if verb == 'T' || verb == 'v' {
+ p.buf.Write(nilAngleBytes)
+ } else {
+ p.badVerb(verb, field)
+ }
+ return false
+ }
+
+ // Special processing considerations.
+ // %T (the value's type) and %p (its address) are special; we always do them first.
+ switch verb {
+ case 'T':
+ p.printField(reflect.TypeOf(field).String(), 's', false, false, 0)
+ return false
+ case 'p':
+ p.fmtPointer(field, reflect.ValueOf(field), verb, goSyntax)
+ return false
+ }
+ // Is it a Formatter?
+ if formatter, ok := field.(Formatter); ok {
+ formatter.Format(p, verb)
+ return false // this value is not a string
+
+ }
+ // Must not touch flags before Formatter looks at them.
+ if plus {
+ p.fmt.plus = false
+ }
+ // If we're doing Go syntax and the field knows how to supply it, take care of it now.
+ if goSyntax {
+ p.fmt.sharp = false
+ if stringer, ok := field.(GoStringer); ok {
+ // Print the result of GoString unadorned.
+ p.fmtString(stringer.GoString(), 's', false, field)
+ return false // this value is not a string
+ }
+ } else {
+ // Is it a Stringer?
+ if stringer, ok := field.(Stringer); ok {
+ p.printField(stringer.String(), verb, plus, false, depth)
+ return false // this value is not a string
+ }
+ }
+
+ // Some types can be done without reflection.
+ switch f := field.(type) {
+ case bool:
+ p.fmtBool(f, verb, field)
+ return false
+ case float32:
+ p.fmtFloat32(f, verb, field)
+ return false
+ case float64:
+ p.fmtFloat64(f, verb, field)
+ return false
+ case complex64:
+ p.fmtComplex64(complex64(f), verb, field)
+ return false
+ case complex128:
+ p.fmtComplex128(f, verb, field)
+ return false
+ case int:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int8:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int16:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int32:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int64:
+ p.fmtInt64(f, verb, field)
+ return false
+ case uint:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint8:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint16:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint32:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint64:
+ p.fmtUint64(f, verb, goSyntax, field)
+ return false
+ case uintptr:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case string:
+ p.fmtString(f, verb, goSyntax, field)
+ return verb == 's' || verb == 'v'
+ case []byte:
+ p.fmtBytes(f, verb, goSyntax, depth, field)
+ return verb == 's'
+ }
+
+ // Need to use reflection
+ value := reflect.ValueOf(field)
+
+BigSwitch:
+ switch f := value; f.Kind() {
+ case reflect.Bool:
+ p.fmtBool(f.Bool(), verb, field)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ p.fmtInt64(f.Int(), verb, field)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ p.fmtUint64(uint64(f.Uint()), verb, goSyntax, field)
+ case reflect.Float32, reflect.Float64:
+ if f.Type().Size() == 4 {
+ p.fmtFloat32(float32(f.Float()), verb, field)
+ } else {
+ p.fmtFloat64(float64(f.Float()), verb, field)
+ }
+ case reflect.Complex64, reflect.Complex128:
+ if f.Type().Size() == 8 {
+ p.fmtComplex64(complex64(f.Complex()), verb, field)
+ } else {
+ p.fmtComplex128(complex128(f.Complex()), verb, field)
+ }
+ case reflect.String:
+ p.fmtString(f.String(), verb, goSyntax, field)
+ case reflect.Map:
+ if goSyntax {
+ p.buf.WriteString(f.Type().String())
+ p.buf.WriteByte('{')
+ } else {
+ p.buf.Write(mapBytes)
+ }
+ keys := f.MapKeys()
+ for i, key := range keys {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(key.Interface(), verb, plus, goSyntax, depth+1)
+ p.buf.WriteByte(':')
+ p.printField(f.MapIndex(key).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ case reflect.Struct:
+ if goSyntax {
+ p.buf.WriteString(reflect.TypeOf(field).String())
+ }
+ p.add('{')
+ v := f
+ t := v.Type()
+ for i := 0; i < v.NumField(); i++ {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ if plus || goSyntax {
+ if f := t.Field(i); f.Name != "" {
+ p.buf.WriteString(f.Name)
+ p.buf.WriteByte(':')
+ }
+ }
+ p.printField(getField(v, i).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ p.buf.WriteByte('}')
+ case reflect.Interface:
+ value := f.Elem()
+ if !value.IsValid() {
+ if goSyntax {
+ p.buf.WriteString(reflect.TypeOf(field).String())
+ p.buf.Write(nilParenBytes)
+ } else {
+ p.buf.Write(nilAngleBytes)
+ }
+ } else {
+ return p.printField(value.Interface(), verb, plus, goSyntax, depth+1)
+ }
+ case reflect.Array, reflect.Slice:
+ // Byte slices are special.
+ if f.Type().Elem().Kind() == reflect.Uint8 {
+ // We know it's a slice of bytes, but we also know it does not have static type
+ // []byte, or it would have been caught above. Therefore we cannot convert
+ // it directly in the (slightly) obvious way: f.Interface().([]byte); it doesn't have
+ // that type, and we can't write an expression of the right type and do a
+ // conversion because we don't have a static way to write the right type.
+ // So we build a slice by hand. This is a rare case but it would be nice
+ // if reflection could help a little more.
+ bytes := make([]byte, f.Len())
+ for i := range bytes {
+ bytes[i] = byte(f.Index(i).Uint())
+ }
+ p.fmtBytes(bytes, verb, goSyntax, depth, field)
+ return verb == 's'
+ }
+ if goSyntax {
+ p.buf.WriteString(reflect.TypeOf(field).String())
+ p.buf.WriteByte('{')
+ } else {
+ p.buf.WriteByte('[')
+ }
+ for i := 0; i < f.Len(); i++ {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(f.Index(i).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ case reflect.Ptr:
+ v := f.Pointer()
+ // pointer to array or slice or struct? ok at top level
+ // but not embedded (avoid loops)
+ if v != 0 && depth == 0 {
+ switch a := f.Elem(); a.Kind() {
+ case reflect.Array, reflect.Slice:
+ p.buf.WriteByte('&')
+ p.printField(a.Interface(), verb, plus, goSyntax, depth+1)
+ break BigSwitch
+ case reflect.Struct:
+ p.buf.WriteByte('&')
+ p.printField(a.Interface(), verb, plus, goSyntax, depth+1)
+ break BigSwitch
+ }
+ }
+ if goSyntax {
+ p.buf.WriteByte('(')
+ p.buf.WriteString(reflect.TypeOf(field).String())
+ p.buf.WriteByte(')')
+ p.buf.WriteByte('(')
+ if v == 0 {
+ p.buf.Write(nilBytes)
+ } else {
+ p.fmt0x64(uint64(v), true)
+ }
+ p.buf.WriteByte(')')
+ break
+ }
+ if v == 0 {
+ p.buf.Write(nilAngleBytes)
+ break
+ }
+ p.fmt0x64(uint64(v), true)
+ case reflect.Chan, reflect.Func, reflect.UnsafePointer:
+ p.fmtPointer(field, value, verb, goSyntax)
+ default:
+ p.unknownType(f)
+ }
+ return false
+}
+
+// intFromArg gets the fieldnumth element of a. On return, isInt reports whether the argument has type int.
+func intFromArg(a []interface{}, end, i, fieldnum int) (num int, isInt bool, newi, newfieldnum int) {
+ newi, newfieldnum = end, fieldnum
+ if i < end && fieldnum < len(a) {
+ num, isInt = a[fieldnum].(int)
+ newi, newfieldnum = i+1, fieldnum+1
+ }
+ return
+}
+
+func (p *pp) doPrintf(format string, a []interface{}) {
+ end := len(format)
+ fieldnum := 0 // we process one field per non-trivial format
+ for i := 0; i < end; {
+ lasti := i
+ for i < end && format[i] != '%' {
+ i++
+ }
+ if i > lasti {
+ p.buf.WriteString(format[lasti:i])
+ }
+ if i >= end {
+ // done processing format string
+ break
+ }
+
+ // Process one verb
+ i++
+ // flags and widths
+ p.fmt.clearflags()
+ F:
+ for ; i < end; i++ {
+ switch format[i] {
+ case '#':
+ p.fmt.sharp = true
+ case '0':
+ p.fmt.zero = true
+ case '+':
+ p.fmt.plus = true
+ case '-':
+ p.fmt.minus = true
+ case ' ':
+ p.fmt.space = true
+ default:
+ break F
+ }
+ }
+ // do we have width?
+ if i < end && format[i] == '*' {
+ p.fmt.wid, p.fmt.widPresent, i, fieldnum = intFromArg(a, end, i, fieldnum)
+ if !p.fmt.widPresent {
+ p.buf.Write(widthBytes)
+ }
+ } else {
+ p.fmt.wid, p.fmt.widPresent, i = parsenum(format, i, end)
+ }
+ // do we have precision?
+ if i < end && format[i] == '.' {
+ if format[i+1] == '*' {
+ p.fmt.prec, p.fmt.precPresent, i, fieldnum = intFromArg(a, end, i+1, fieldnum)
+ if !p.fmt.precPresent {
+ p.buf.Write(precBytes)
+ }
+ } else {
+ p.fmt.prec, p.fmt.precPresent, i = parsenum(format, i+1, end)
+ }
+ }
+ if i >= end {
+ p.buf.Write(noVerbBytes)
+ continue
+ }
+ c, w := utf8.DecodeRuneInString(format[i:])
+ i += w
+ // percent is special - absorbs no operand
+ if c == '%' {
+ p.buf.WriteByte('%') // We ignore width and prec.
+ continue
+ }
+ if fieldnum >= len(a) { // out of operands
+ p.buf.WriteByte('%')
+ p.add(c)
+ p.buf.Write(missingBytes)
+ continue
+ }
+ field := a[fieldnum]
+ fieldnum++
+
+ goSyntax := c == 'v' && p.fmt.sharp
+ plus := c == 'v' && p.fmt.plus
+ p.printField(field, c, plus, goSyntax, 0)
+ }
+
+ if fieldnum < len(a) {
+ p.buf.Write(extraBytes)
+ for ; fieldnum < len(a); fieldnum++ {
+ field := a[fieldnum]
+ if field != nil {
+ p.buf.WriteString(reflect.TypeOf(field).String())
+ p.buf.WriteByte('=')
+ }
+ p.printField(field, 'v', false, false, 0)
+ if fieldnum+1 < len(a) {
+ p.buf.Write(commaSpaceBytes)
+ }
+ }
+ p.buf.WriteByte(')')
+ }
+}
+
+func (p *pp) doPrint(a []interface{}, addspace, addnewline bool) {
+ prevString := false
+ for fieldnum := 0; fieldnum < len(a); fieldnum++ {
+ p.fmt.clearflags()
+ // always add spaces if we're doing println
+ field := a[fieldnum]
+ if fieldnum > 0 {
+ isString := field != nil && reflect.TypeOf(field).Kind() == reflect.String
+ if addspace || !isString && !prevString {
+ p.buf.WriteByte(' ')
+ }
+ }
+ prevString = p.printField(field, 'v', false, false, 0)
+ }
+ if addnewline {
+ p.buf.WriteByte('\n')
+ }
+}
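
The rewritten print code above dispatches on value.Kind() and constructs values with reflect.ValueOf/reflect.TypeOf throughout. Below is a minimal, self-contained sketch of that dispatch pattern; describe is a hypothetical helper and is not part of the testdata.

package main

import (
	"fmt"
	"reflect"
)

// describe mirrors the Kind-based dispatch used by printField above: one
// reflect.Value is examined via Kind() rather than the older per-type value
// types (*reflect.IntValue, *reflect.StringValue, ...).
func describe(x interface{}) string {
	v := reflect.ValueOf(x)
	switch v.Kind() {
	case reflect.Bool:
		return fmt.Sprintf("bool %v", v.Bool())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return fmt.Sprintf("int %d", v.Int())
	case reflect.String:
		return fmt.Sprintf("string %q", v.String())
	case reflect.Ptr:
		return fmt.Sprintf("pointer %#x", v.Pointer())
	default:
		return "unhandled kind " + v.Kind().String()
	}
}

func main() {
	fmt.Println(describe(42))      // int 42
	fmt.Println(describe("hello")) // string "hello"
}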
diff --git a/src/cmd/fix/testdata/reflect.quick.go.in b/src/cmd/fix/testdata/reflect.quick.go.in
new file mode 100644
index 000000000..a5568b048
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.quick.go.in
@@ -0,0 +1,364 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package implements utility functions to help with black box testing.
+package quick
+
+import (
+ "flag"
+ "fmt"
+ "math"
+ "os"
+ "rand"
+ "reflect"
+ "strings"
+)
+
+var defaultMaxCount *int = flag.Int("quickchecks", 100, "The default number of iterations for each check")
+
+// A Generator can generate random values of its own type.
+type Generator interface {
+ // Generate returns a random instance of the type on which it is a
+ // method using the size as a size hint.
+ Generate(rand *rand.Rand, size int) reflect.Value
+}
+
+// randFloat32 generates a random float taking the full range of a float32.
+func randFloat32(rand *rand.Rand) float32 {
+ f := rand.Float64() * math.MaxFloat32
+ if rand.Int()&1 == 1 {
+ f = -f
+ }
+ return float32(f)
+}
+
+// randFloat64 generates a random float taking the full range of a float64.
+func randFloat64(rand *rand.Rand) float64 {
+ f := rand.Float64()
+ if rand.Int()&1 == 1 {
+ f = -f
+ }
+ return f
+}
+
+// randInt64 returns a random integer taking half the range of an int64.
+func randInt64(rand *rand.Rand) int64 { return rand.Int63() - 1<<62 }
+
+// complexSize is the maximum length of arbitrary values that contain other
+// values.
+const complexSize = 50
+
+// Value returns an arbitrary value of the given type.
+// If the type implements the Generator interface, that will be used.
+// Note: in order to create arbitrary values for structs, all the members must be public.
+func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) {
+ if m, ok := reflect.MakeZero(t).Interface().(Generator); ok {
+ return m.Generate(rand, complexSize), true
+ }
+
+ switch concrete := t.(type) {
+ case *reflect.BoolType:
+ return reflect.NewValue(rand.Int()&1 == 0), true
+ case *reflect.FloatType, *reflect.IntType, *reflect.UintType, *reflect.ComplexType:
+ switch t.Kind() {
+ case reflect.Float32:
+ return reflect.NewValue(randFloat32(rand)), true
+ case reflect.Float64:
+ return reflect.NewValue(randFloat64(rand)), true
+ case reflect.Complex64:
+ return reflect.NewValue(complex(randFloat32(rand), randFloat32(rand))), true
+ case reflect.Complex128:
+ return reflect.NewValue(complex(randFloat64(rand), randFloat64(rand))), true
+ case reflect.Int16:
+ return reflect.NewValue(int16(randInt64(rand))), true
+ case reflect.Int32:
+ return reflect.NewValue(int32(randInt64(rand))), true
+ case reflect.Int64:
+ return reflect.NewValue(randInt64(rand)), true
+ case reflect.Int8:
+ return reflect.NewValue(int8(randInt64(rand))), true
+ case reflect.Int:
+ return reflect.NewValue(int(randInt64(rand))), true
+ case reflect.Uint16:
+ return reflect.NewValue(uint16(randInt64(rand))), true
+ case reflect.Uint32:
+ return reflect.NewValue(uint32(randInt64(rand))), true
+ case reflect.Uint64:
+ return reflect.NewValue(uint64(randInt64(rand))), true
+ case reflect.Uint8:
+ return reflect.NewValue(uint8(randInt64(rand))), true
+ case reflect.Uint:
+ return reflect.NewValue(uint(randInt64(rand))), true
+ case reflect.Uintptr:
+ return reflect.NewValue(uintptr(randInt64(rand))), true
+ }
+ case *reflect.MapType:
+ numElems := rand.Intn(complexSize)
+ m := reflect.MakeMap(concrete)
+ for i := 0; i < numElems; i++ {
+ key, ok1 := Value(concrete.Key(), rand)
+ value, ok2 := Value(concrete.Elem(), rand)
+ if !ok1 || !ok2 {
+ return nil, false
+ }
+ m.SetElem(key, value)
+ }
+ return m, true
+ case *reflect.PtrType:
+ v, ok := Value(concrete.Elem(), rand)
+ if !ok {
+ return nil, false
+ }
+ p := reflect.MakeZero(concrete)
+ p.(*reflect.PtrValue).PointTo(v)
+ return p, true
+ case *reflect.SliceType:
+ numElems := rand.Intn(complexSize)
+ s := reflect.MakeSlice(concrete, numElems, numElems)
+ for i := 0; i < numElems; i++ {
+ v, ok := Value(concrete.Elem(), rand)
+ if !ok {
+ return nil, false
+ }
+ s.Elem(i).SetValue(v)
+ }
+ return s, true
+ case *reflect.StringType:
+ numChars := rand.Intn(complexSize)
+ codePoints := make([]int, numChars)
+ for i := 0; i < numChars; i++ {
+ codePoints[i] = rand.Intn(0x10ffff)
+ }
+ return reflect.NewValue(string(codePoints)), true
+ case *reflect.StructType:
+ s := reflect.MakeZero(t).(*reflect.StructValue)
+ for i := 0; i < s.NumField(); i++ {
+ v, ok := Value(concrete.Field(i).Type, rand)
+ if !ok {
+ return nil, false
+ }
+ s.Field(i).SetValue(v)
+ }
+ return s, true
+ default:
+ return nil, false
+ }
+
+ return
+}
+
+// A Config structure contains options for running a test.
+type Config struct {
+ // MaxCount sets the maximum number of iterations. If zero,
+ // MaxCountScale is used.
+ MaxCount int
+ // MaxCountScale is a non-negative scale factor applied to the default
+ // maximum. If zero, the default is unchanged.
+ MaxCountScale float64
+ // If non-nil, rand is a source of random numbers. Otherwise a default
+ // pseudo-random source will be used.
+ Rand *rand.Rand
+ // If non-nil, Values is a function which generates a slice of arbitrary
+ // Values that are congruent with the arguments to the function being
+	// tested. Otherwise, the top-level Value function is used to generate them.
+ Values func([]reflect.Value, *rand.Rand)
+}
+
+var defaultConfig Config
+
+// getRand returns the *rand.Rand to use for a given Config.
+func (c *Config) getRand() *rand.Rand {
+ if c.Rand == nil {
+ return rand.New(rand.NewSource(0))
+ }
+ return c.Rand
+}
+
+// getMaxCount returns the maximum number of iterations to run for a given
+// Config.
+func (c *Config) getMaxCount() (maxCount int) {
+ maxCount = c.MaxCount
+ if maxCount == 0 {
+ if c.MaxCountScale != 0 {
+ maxCount = int(c.MaxCountScale * float64(*defaultMaxCount))
+ } else {
+ maxCount = *defaultMaxCount
+ }
+ }
+
+ return
+}
+
+// A SetupError is the result of an error in the way that check is being
+// used, independent of the functions being tested.
+type SetupError string
+
+func (s SetupError) String() string { return string(s) }
+
+// A CheckError is the result of Check finding an error.
+type CheckError struct {
+ Count int
+ In []interface{}
+}
+
+func (s *CheckError) String() string {
+ return fmt.Sprintf("#%d: failed on input %s", s.Count, toString(s.In))
+}
+
+// A CheckEqualError is the result of CheckEqual finding an error.
+type CheckEqualError struct {
+ CheckError
+ Out1 []interface{}
+ Out2 []interface{}
+}
+
+func (s *CheckEqualError) String() string {
+ return fmt.Sprintf("#%d: failed on input %s. Output 1: %s. Output 2: %s", s.Count, toString(s.In), toString(s.Out1), toString(s.Out2))
+}
+
+// Check looks for an input to f, any function that returns bool,
+// such that f returns false. It calls f repeatedly, with arbitrary
+// values for each argument. If f returns false on a given input,
+// Check returns that input as a *CheckError.
+// For example:
+//
+// func TestOddMultipleOfThree(t *testing.T) {
+// f := func(x int) bool {
+// y := OddMultipleOfThree(x)
+// return y%2 == 1 && y%3 == 0
+// }
+// if err := quick.Check(f, nil); err != nil {
+// t.Error(err)
+// }
+// }
+func Check(function interface{}, config *Config) (err os.Error) {
+ if config == nil {
+ config = &defaultConfig
+ }
+
+ f, fType, ok := functionAndType(function)
+ if !ok {
+ err = SetupError("argument is not a function")
+ return
+ }
+
+ if fType.NumOut() != 1 {
+ err = SetupError("function returns more than one value.")
+ return
+ }
+ if _, ok := fType.Out(0).(*reflect.BoolType); !ok {
+ err = SetupError("function does not return a bool")
+ return
+ }
+
+ arguments := make([]reflect.Value, fType.NumIn())
+ rand := config.getRand()
+ maxCount := config.getMaxCount()
+
+ for i := 0; i < maxCount; i++ {
+ err = arbitraryValues(arguments, fType, config, rand)
+ if err != nil {
+ return
+ }
+
+ if !f.Call(arguments)[0].(*reflect.BoolValue).Get() {
+ err = &CheckError{i + 1, toInterfaces(arguments)}
+ return
+ }
+ }
+
+ return
+}
+
+// CheckEqual looks for an input on which f and g return different results.
+// It calls f and g repeatedly with arbitrary values for each argument.
+// If f and g return different answers, CheckEqual returns a *CheckEqualError
+// describing the input and the outputs.
+func CheckEqual(f, g interface{}, config *Config) (err os.Error) {
+ if config == nil {
+ config = &defaultConfig
+ }
+
+ x, xType, ok := functionAndType(f)
+ if !ok {
+ err = SetupError("f is not a function")
+ return
+ }
+ y, yType, ok := functionAndType(g)
+ if !ok {
+ err = SetupError("g is not a function")
+ return
+ }
+
+ if xType != yType {
+ err = SetupError("functions have different types")
+ return
+ }
+
+ arguments := make([]reflect.Value, xType.NumIn())
+ rand := config.getRand()
+ maxCount := config.getMaxCount()
+
+ for i := 0; i < maxCount; i++ {
+ err = arbitraryValues(arguments, xType, config, rand)
+ if err != nil {
+ return
+ }
+
+ xOut := toInterfaces(x.Call(arguments))
+ yOut := toInterfaces(y.Call(arguments))
+
+ if !reflect.DeepEqual(xOut, yOut) {
+ err = &CheckEqualError{CheckError{i + 1, toInterfaces(arguments)}, xOut, yOut}
+ return
+ }
+ }
+
+ return
+}
+
+// arbitraryValues writes Values to args such that args contains Values
+// suitable for calling f.
+func arbitraryValues(args []reflect.Value, f *reflect.FuncType, config *Config, rand *rand.Rand) (err os.Error) {
+ if config.Values != nil {
+ config.Values(args, rand)
+ return
+ }
+
+ for j := 0; j < len(args); j++ {
+ var ok bool
+ args[j], ok = Value(f.In(j), rand)
+ if !ok {
+ err = SetupError(fmt.Sprintf("cannot create arbitrary value of type %s for argument %d", f.In(j), j))
+ return
+ }
+ }
+
+ return
+}
+
+func functionAndType(f interface{}) (v *reflect.FuncValue, t *reflect.FuncType, ok bool) {
+ v, ok = reflect.NewValue(f).(*reflect.FuncValue)
+ if !ok {
+ return
+ }
+ t = v.Type().(*reflect.FuncType)
+ return
+}
+
+func toInterfaces(values []reflect.Value) []interface{} {
+ ret := make([]interface{}, len(values))
+ for i, v := range values {
+ ret[i] = v.Interface()
+ }
+ return ret
+}
+
+func toString(interfaces []interface{}) string {
+ s := make([]string, len(interfaces))
+ for i, v := range interfaces {
+ s[i] = fmt.Sprintf("%#v", v)
+ }
+ return strings.Join(s, ", ")
+}
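
The reflect.quick.go.out file that follows is the expected output of the reflect fix applied to the input above. The short, runnable sketch below (illustrative only, not part of the testdata) condenses the API substitutions visible in this pair.

package main

import (
	"fmt"
	"reflect"
)

func main() {
	// reflect.NewValue(x) was rewritten to reflect.ValueOf(x).
	v := reflect.ValueOf(42)

	// reflect.MakeZero(t) was rewritten to reflect.Zero(t).
	z := reflect.Zero(v.Type())

	// Switches on *reflect.IntType, *reflect.MapType, ... became switches on
	// Kind(); slice elements are reached with Index/Set.
	switch v.Kind() {
	case reflect.Int:
		fmt.Println("int:", v.Int(), "zero:", z.Int())
	case reflect.Map:
		// m.SetElem(key, value) became m.SetMapIndex(key, value).
	}

	// The failure result "nil, false" became "reflect.Value{}, false":
	// reflect.Value is a struct after the rewrite, not an interface.
	var none reflect.Value
	fmt.Println("valid:", none.IsValid()) // false
}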
diff --git a/src/cmd/fix/testdata/reflect.quick.go.out b/src/cmd/fix/testdata/reflect.quick.go.out
new file mode 100644
index 000000000..c62305b83
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.quick.go.out
@@ -0,0 +1,365 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package implements utility functions to help with black box testing.
+package quick
+
+import (
+ "flag"
+ "fmt"
+ "math"
+ "os"
+ "rand"
+ "reflect"
+ "strings"
+)
+
+var defaultMaxCount *int = flag.Int("quickchecks", 100, "The default number of iterations for each check")
+
+// A Generator can generate random values of its own type.
+type Generator interface {
+ // Generate returns a random instance of the type on which it is a
+ // method using the size as a size hint.
+ Generate(rand *rand.Rand, size int) reflect.Value
+}
+
+// randFloat32 generates a random float taking the full range of a float32.
+func randFloat32(rand *rand.Rand) float32 {
+ f := rand.Float64() * math.MaxFloat32
+ if rand.Int()&1 == 1 {
+ f = -f
+ }
+ return float32(f)
+}
+
+// randFloat64 generates a random float taking the full range of a float64.
+func randFloat64(rand *rand.Rand) float64 {
+ f := rand.Float64()
+ if rand.Int()&1 == 1 {
+ f = -f
+ }
+ return f
+}
+
+// randInt64 returns a random integer taking half the range of an int64.
+func randInt64(rand *rand.Rand) int64 { return rand.Int63() - 1<<62 }
+
+// complexSize is the maximum length of arbitrary values that contain other
+// values.
+const complexSize = 50
+
+// Value returns an arbitrary value of the given type.
+// If the type implements the Generator interface, that will be used.
+// Note: in order to create arbitrary values for structs, all the members must be public.
+func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) {
+ if m, ok := reflect.Zero(t).Interface().(Generator); ok {
+ return m.Generate(rand, complexSize), true
+ }
+
+ switch concrete := t; concrete.Kind() {
+ case reflect.Bool:
+ return reflect.ValueOf(rand.Int()&1 == 0), true
+ case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Complex64, reflect.Complex128:
+ switch t.Kind() {
+ case reflect.Float32:
+ return reflect.ValueOf(randFloat32(rand)), true
+ case reflect.Float64:
+ return reflect.ValueOf(randFloat64(rand)), true
+ case reflect.Complex64:
+ return reflect.ValueOf(complex(randFloat32(rand), randFloat32(rand))), true
+ case reflect.Complex128:
+ return reflect.ValueOf(complex(randFloat64(rand), randFloat64(rand))), true
+ case reflect.Int16:
+ return reflect.ValueOf(int16(randInt64(rand))), true
+ case reflect.Int32:
+ return reflect.ValueOf(int32(randInt64(rand))), true
+ case reflect.Int64:
+ return reflect.ValueOf(randInt64(rand)), true
+ case reflect.Int8:
+ return reflect.ValueOf(int8(randInt64(rand))), true
+ case reflect.Int:
+ return reflect.ValueOf(int(randInt64(rand))), true
+ case reflect.Uint16:
+ return reflect.ValueOf(uint16(randInt64(rand))), true
+ case reflect.Uint32:
+ return reflect.ValueOf(uint32(randInt64(rand))), true
+ case reflect.Uint64:
+ return reflect.ValueOf(uint64(randInt64(rand))), true
+ case reflect.Uint8:
+ return reflect.ValueOf(uint8(randInt64(rand))), true
+ case reflect.Uint:
+ return reflect.ValueOf(uint(randInt64(rand))), true
+ case reflect.Uintptr:
+ return reflect.ValueOf(uintptr(randInt64(rand))), true
+ }
+ case reflect.Map:
+ numElems := rand.Intn(complexSize)
+ m := reflect.MakeMap(concrete)
+ for i := 0; i < numElems; i++ {
+ key, ok1 := Value(concrete.Key(), rand)
+ value, ok2 := Value(concrete.Elem(), rand)
+ if !ok1 || !ok2 {
+ return reflect.Value{}, false
+ }
+ m.SetMapIndex(key, value)
+ }
+ return m, true
+ case reflect.Ptr:
+ v, ok := Value(concrete.Elem(), rand)
+ if !ok {
+ return reflect.Value{}, false
+ }
+ p := reflect.Zero(concrete)
+ p.Set(v.Addr())
+ return p, true
+ case reflect.Slice:
+ numElems := rand.Intn(complexSize)
+ s := reflect.MakeSlice(concrete, numElems, numElems)
+ for i := 0; i < numElems; i++ {
+ v, ok := Value(concrete.Elem(), rand)
+ if !ok {
+ return reflect.Value{}, false
+ }
+ s.Index(i).Set(v)
+ }
+ return s, true
+ case reflect.String:
+ numChars := rand.Intn(complexSize)
+ codePoints := make([]int, numChars)
+ for i := 0; i < numChars; i++ {
+ codePoints[i] = rand.Intn(0x10ffff)
+ }
+ return reflect.ValueOf(string(codePoints)), true
+ case reflect.Struct:
+ s := reflect.Zero(t)
+ for i := 0; i < s.NumField(); i++ {
+ v, ok := Value(concrete.Field(i).Type, rand)
+ if !ok {
+ return reflect.Value{}, false
+ }
+ s.Field(i).Set(v)
+ }
+ return s, true
+ default:
+ return reflect.Value{}, false
+ }
+
+ return
+}
+
+// A Config structure contains options for running a test.
+type Config struct {
+ // MaxCount sets the maximum number of iterations. If zero,
+ // MaxCountScale is used.
+ MaxCount int
+ // MaxCountScale is a non-negative scale factor applied to the default
+ // maximum. If zero, the default is unchanged.
+ MaxCountScale float64
+ // If non-nil, rand is a source of random numbers. Otherwise a default
+ // pseudo-random source will be used.
+ Rand *rand.Rand
+ // If non-nil, Values is a function which generates a slice of arbitrary
+ // Values that are congruent with the arguments to the function being
+	// tested. Otherwise, the top-level Value function is used to generate them.
+ Values func([]reflect.Value, *rand.Rand)
+}
+
+var defaultConfig Config
+
+// getRand returns the *rand.Rand to use for a given Config.
+func (c *Config) getRand() *rand.Rand {
+ if c.Rand == nil {
+ return rand.New(rand.NewSource(0))
+ }
+ return c.Rand
+}
+
+// getMaxCount returns the maximum number of iterations to run for a given
+// Config.
+func (c *Config) getMaxCount() (maxCount int) {
+ maxCount = c.MaxCount
+ if maxCount == 0 {
+ if c.MaxCountScale != 0 {
+ maxCount = int(c.MaxCountScale * float64(*defaultMaxCount))
+ } else {
+ maxCount = *defaultMaxCount
+ }
+ }
+
+ return
+}
+
+// A SetupError is the result of an error in the way that check is being
+// used, independent of the functions being tested.
+type SetupError string
+
+func (s SetupError) String() string { return string(s) }
+
+// A CheckError is the result of Check finding an error.
+type CheckError struct {
+ Count int
+ In []interface{}
+}
+
+func (s *CheckError) String() string {
+ return fmt.Sprintf("#%d: failed on input %s", s.Count, toString(s.In))
+}
+
+// A CheckEqualError is the result of CheckEqual finding an error.
+type CheckEqualError struct {
+ CheckError
+ Out1 []interface{}
+ Out2 []interface{}
+}
+
+func (s *CheckEqualError) String() string {
+ return fmt.Sprintf("#%d: failed on input %s. Output 1: %s. Output 2: %s", s.Count, toString(s.In), toString(s.Out1), toString(s.Out2))
+}
+
+// Check looks for an input to f, any function that returns bool,
+// such that f returns false. It calls f repeatedly, with arbitrary
+// values for each argument. If f returns false on a given input,
+// Check returns that input as a *CheckError.
+// For example:
+//
+// func TestOddMultipleOfThree(t *testing.T) {
+// f := func(x int) bool {
+// y := OddMultipleOfThree(x)
+// return y%2 == 1 && y%3 == 0
+// }
+// if err := quick.Check(f, nil); err != nil {
+// t.Error(err)
+// }
+// }
+func Check(function interface{}, config *Config) (err os.Error) {
+ if config == nil {
+ config = &defaultConfig
+ }
+
+ f, fType, ok := functionAndType(function)
+ if !ok {
+ err = SetupError("argument is not a function")
+ return
+ }
+
+ if fType.NumOut() != 1 {
+ err = SetupError("function returns more than one value.")
+ return
+ }
+ if fType.Out(0).Kind() != reflect.Bool {
+ err = SetupError("function does not return a bool")
+ return
+ }
+
+ arguments := make([]reflect.Value, fType.NumIn())
+ rand := config.getRand()
+ maxCount := config.getMaxCount()
+
+ for i := 0; i < maxCount; i++ {
+ err = arbitraryValues(arguments, fType, config, rand)
+ if err != nil {
+ return
+ }
+
+ if !f.Call(arguments)[0].Bool() {
+ err = &CheckError{i + 1, toInterfaces(arguments)}
+ return
+ }
+ }
+
+ return
+}
+
+// CheckEqual looks for an input on which f and g return different results.
+// It calls f and g repeatedly with arbitrary values for each argument.
+// If f and g return different answers, CheckEqual returns a *CheckEqualError
+// describing the input and the outputs.
+func CheckEqual(f, g interface{}, config *Config) (err os.Error) {
+ if config == nil {
+ config = &defaultConfig
+ }
+
+ x, xType, ok := functionAndType(f)
+ if !ok {
+ err = SetupError("f is not a function")
+ return
+ }
+ y, yType, ok := functionAndType(g)
+ if !ok {
+ err = SetupError("g is not a function")
+ return
+ }
+
+ if xType != yType {
+ err = SetupError("functions have different types")
+ return
+ }
+
+ arguments := make([]reflect.Value, xType.NumIn())
+ rand := config.getRand()
+ maxCount := config.getMaxCount()
+
+ for i := 0; i < maxCount; i++ {
+ err = arbitraryValues(arguments, xType, config, rand)
+ if err != nil {
+ return
+ }
+
+ xOut := toInterfaces(x.Call(arguments))
+ yOut := toInterfaces(y.Call(arguments))
+
+ if !reflect.DeepEqual(xOut, yOut) {
+ err = &CheckEqualError{CheckError{i + 1, toInterfaces(arguments)}, xOut, yOut}
+ return
+ }
+ }
+
+ return
+}
+
+// arbitraryValues writes Values to args such that args contains Values
+// suitable for calling f.
+func arbitraryValues(args []reflect.Value, f reflect.Type, config *Config, rand *rand.Rand) (err os.Error) {
+ if config.Values != nil {
+ config.Values(args, rand)
+ return
+ }
+
+ for j := 0; j < len(args); j++ {
+ var ok bool
+ args[j], ok = Value(f.In(j), rand)
+ if !ok {
+ err = SetupError(fmt.Sprintf("cannot create arbitrary value of type %s for argument %d", f.In(j), j))
+ return
+ }
+ }
+
+ return
+}
+
+func functionAndType(f interface{}) (v reflect.Value, t reflect.Type, ok bool) {
+ v = reflect.ValueOf(f)
+ ok = v.Kind() == reflect.Func
+ if !ok {
+ return
+ }
+ t = v.Type()
+ return
+}
+
+func toInterfaces(values []reflect.Value) []interface{} {
+ ret := make([]interface{}, len(values))
+ for i, v := range values {
+ ret[i] = v.Interface()
+ }
+ return ret
+}
+
+func toString(interfaces []interface{}) string {
+ s := make([]string, len(interfaces))
+ for i, v := range interfaces {
+ s[i] = fmt.Sprintf("%#v", v)
+ }
+ return strings.Join(s, ", ")
+}
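
In the rewritten file above, functionAndType detects a function with a Kind() comparison instead of asserting to *reflect.FuncValue, and its Type() needs no *reflect.FuncType cast. The following is a minimal sketch of that check in isolation; functionInfo is a hypothetical name, not part of the testdata.

package main

import (
	"fmt"
	"reflect"
)

// functionInfo reports how many inputs and outputs a function value has,
// using the post-fix reflect API: a Kind() check replaces the old
// *reflect.FuncValue type assertion.
func functionInfo(f interface{}) (numIn, numOut int, ok bool) {
	v := reflect.ValueOf(f)
	if v.Kind() != reflect.Func {
		return 0, 0, false
	}
	t := v.Type()
	return t.NumIn(), t.NumOut(), true
}

func main() {
	in, out, ok := functionInfo(func(x int) bool { return x%2 == 0 })
	fmt.Println(in, out, ok) // 1 1 true
}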
diff --git a/src/cmd/fix/testdata/reflect.read.go.in b/src/cmd/fix/testdata/reflect.read.go.in
new file mode 100644
index 000000000..487994ac6
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.read.go.in
@@ -0,0 +1,620 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xml
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
+// an XML element is an order-dependent collection of anonymous
+// values, while a data structure is an order-independent collection
+// of named values.
+// See package json for a textual representation more suitable
+// to data structures.
+
+// Unmarshal parses an XML element from r and uses the
+// reflect library to fill in an arbitrary struct, slice, or string
+// pointed at by val. Well-formed data that does not fit
+// into val is discarded.
+//
+// For example, given these definitions:
+//
+// type Email struct {
+// Where string "attr"
+// Addr string
+// }
+//
+// type Result struct {
+// XMLName xml.Name "result"
+// Name string
+// Phone string
+// Email []Email
+// Groups []string "group>value"
+// }
+//
+// result := Result{Name: "name", Phone: "phone", Email: nil}
+//
+// unmarshalling the XML input
+//
+// <result>
+// <email where="home">
+// <addr>gre@example.com</addr>
+// </email>
+// <email where='work'>
+// <addr>gre@work.com</addr>
+// </email>
+// <name>Grace R. Emlin</name>
+// <group>
+// <value>Friends</value>
+// <value>Squash</value>
+// </group>
+// <address>123 Main Street</address>
+// </result>
+//
+// via Unmarshal(r, &result) is equivalent to assigning
+//
+// r = Result{xml.Name{"", "result"},
+// "Grace R. Emlin", // name
+// "phone", // no phone given
+// []Email{
+// Email{"home", "gre@example.com"},
+// Email{"work", "gre@work.com"},
+// },
+// []string{"Friends", "Squash"},
+// }
+//
+// Note that the field r.Phone has not been modified and
+// that the XML <address> element was discarded. Also, the field
+// Groups was assigned considering the element path provided in the
+// field tag.
+//
+// Because Unmarshal uses the reflect package, it can only
+// assign to upper case fields. Unmarshal uses a case-insensitive
+// comparison to match XML element names to struct field names.
+//
+// Unmarshal maps an XML element to a struct using the following rules:
+//
+// * If the struct has a field of type []byte or string with tag "innerxml",
+// Unmarshal accumulates the raw XML nested inside the element
+// in that field. The rest of the rules still apply.
+//
+// * If the struct has a field named XMLName of type xml.Name,
+// Unmarshal records the element name in that field.
+//
+// * If the XMLName field has an associated tag string of the form
+// "tag" or "namespace-URL tag", the XML element must have
+// the given tag (and, optionally, name space) or else Unmarshal
+// returns an error.
+//
+// * If the XML element has an attribute whose name matches a
+// struct field of type string with tag "attr", Unmarshal records
+// the attribute value in that field.
+//
+// * If the XML element contains character data, that data is
+// accumulated in the first struct field that has tag "chardata".
+// The struct field may have type []byte or string.
+// If there is no such field, the character data is discarded.
+//
+// * If the XML element contains a sub-element whose name matches
+// the prefix of a struct field tag formatted as "a>b>c", unmarshal
+// will descend into the XML structure looking for elements with the
+// given names, and will map the innermost elements to that struct field.
+// A struct field tag starting with ">" is equivalent to one starting
+// with the field name followed by ">".
+//
+// * If the XML element contains a sub-element whose name
+// matches a struct field whose tag is neither "attr" nor "chardata",
+// Unmarshal maps the sub-element to that struct field.
+// Otherwise, if the struct has a field named Any, unmarshal
+// maps the sub-element to that struct field.
+//
+// Unmarshal maps an XML element to a string or []byte by saving the
+// concatenation of that element's character data in the string or []byte.
+//
+// Unmarshal maps an XML element to a slice by extending the length
+// of the slice and mapping the element to the newly created value.
+//
+// Unmarshal maps an XML element to a bool by setting it to the boolean
+// value represented by the string.
+//
+// Unmarshal maps an XML element to an integer or floating-point
+// field by setting the field to the result of interpreting the string
+// value in decimal. There is no check for overflow.
+//
+// Unmarshal maps an XML element to an xml.Name by recording the
+// element name.
+//
+// Unmarshal maps an XML element to a pointer by setting the pointer
+// to a freshly allocated value and then mapping the element to that value.
+//
+func Unmarshal(r io.Reader, val interface{}) os.Error {
+ v, ok := reflect.NewValue(val).(*reflect.PtrValue)
+ if !ok {
+ return os.NewError("non-pointer passed to Unmarshal")
+ }
+ p := NewParser(r)
+ elem := v.Elem()
+ err := p.unmarshal(elem, nil)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// An UnmarshalError represents an error in the unmarshalling process.
+type UnmarshalError string
+
+func (e UnmarshalError) String() string { return string(e) }
+
+// A TagPathError represents an error in the unmarshalling process
+// caused by the use of field tags with conflicting paths.
+type TagPathError struct {
+ Struct reflect.Type
+ Field1, Tag1 string
+ Field2, Tag2 string
+}
+
+func (e *TagPathError) String() string {
+ return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
+}
+
+// The Parser's Unmarshal method is like xml.Unmarshal
+// except that it can be passed a pointer to the initial start element,
+// useful when a client reads some raw XML tokens itself
+// but also defers to Unmarshal for some elements.
+// Passing a nil start element indicates that Unmarshal should
+// read the token stream to find the start element.
+func (p *Parser) Unmarshal(val interface{}, start *StartElement) os.Error {
+ v, ok := reflect.NewValue(val).(*reflect.PtrValue)
+ if !ok {
+ return os.NewError("non-pointer passed to Unmarshal")
+ }
+ return p.unmarshal(v.Elem(), start)
+}
+
+// fieldName strips invalid characters from an XML name
+// to create a valid Go struct name. It also converts the
+// name to lower case letters.
+func fieldName(original string) string {
+
+ var i int
+ //remove leading underscores
+ for i = 0; i < len(original) && original[i] == '_'; i++ {
+ }
+
+ return strings.Map(
+ func(x int) int {
+ if x == '_' || unicode.IsDigit(x) || unicode.IsLetter(x) {
+ return unicode.ToLower(x)
+ }
+ return -1
+ },
+ original[i:])
+}
+
+// Unmarshal a single XML element into val.
+func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error {
+ // Find start element if we need it.
+ if start == nil {
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ if t, ok := tok.(StartElement); ok {
+ start = &t
+ break
+ }
+ }
+ }
+
+ if pv, ok := val.(*reflect.PtrValue); ok {
+ if pv.Get() == 0 {
+ zv := reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem())
+ pv.PointTo(zv)
+ val = zv
+ } else {
+ val = pv.Elem()
+ }
+ }
+
+ var (
+ data []byte
+ saveData reflect.Value
+ comment []byte
+ saveComment reflect.Value
+ saveXML reflect.Value
+ saveXMLIndex int
+ saveXMLData []byte
+ sv *reflect.StructValue
+ styp *reflect.StructType
+ fieldPaths map[string]pathInfo
+ )
+
+ switch v := val.(type) {
+ default:
+ return os.NewError("unknown type " + v.Type().String())
+
+ case *reflect.SliceValue:
+ typ := v.Type().(*reflect.SliceType)
+ if typ.Elem().Kind() == reflect.Uint8 {
+ // []byte
+ saveData = v
+ break
+ }
+
+ // Slice of element values.
+ // Grow slice.
+ n := v.Len()
+ if n >= v.Cap() {
+ ncap := 2 * n
+ if ncap < 4 {
+ ncap = 4
+ }
+ new := reflect.MakeSlice(typ, n, ncap)
+ reflect.Copy(new, v)
+ v.Set(new)
+ }
+ v.SetLen(n + 1)
+
+ // Recur to read element into slice.
+ if err := p.unmarshal(v.Elem(n), start); err != nil {
+ v.SetLen(n)
+ return err
+ }
+ return nil
+
+ case *reflect.BoolValue, *reflect.FloatValue, *reflect.IntValue, *reflect.UintValue, *reflect.StringValue:
+ saveData = v
+
+ case *reflect.StructValue:
+ if _, ok := v.Interface().(Name); ok {
+ v.Set(reflect.NewValue(start.Name).(*reflect.StructValue))
+ break
+ }
+
+ sv = v
+ typ := sv.Type().(*reflect.StructType)
+ styp = typ
+ // Assign name.
+ if f, ok := typ.FieldByName("XMLName"); ok {
+ // Validate element name.
+ if f.Tag != "" {
+ tag := f.Tag
+ ns := ""
+ i := strings.LastIndex(tag, " ")
+ if i >= 0 {
+ ns, tag = tag[0:i], tag[i+1:]
+ }
+ if tag != start.Name.Local {
+ return UnmarshalError("expected element type <" + tag + "> but have <" + start.Name.Local + ">")
+ }
+ if ns != "" && ns != start.Name.Space {
+ e := "expected element <" + tag + "> in name space " + ns + " but have "
+ if start.Name.Space == "" {
+ e += "no name space"
+ } else {
+ e += start.Name.Space
+ }
+ return UnmarshalError(e)
+ }
+ }
+
+ // Save
+ v := sv.FieldByIndex(f.Index)
+ if _, ok := v.Interface().(Name); !ok {
+ return UnmarshalError(sv.Type().String() + " field XMLName does not have type xml.Name")
+ }
+ v.(*reflect.StructValue).Set(reflect.NewValue(start.Name).(*reflect.StructValue))
+ }
+
+ // Assign attributes.
+ // Also, determine whether we need to save character data or comments.
+ for i, n := 0, typ.NumField(); i < n; i++ {
+ f := typ.Field(i)
+ switch f.Tag {
+ case "attr":
+ strv, ok := sv.FieldByIndex(f.Index).(*reflect.StringValue)
+ if !ok {
+ return UnmarshalError(sv.Type().String() + " field " + f.Name + " has attr tag but is not type string")
+ }
+ // Look for attribute.
+ val := ""
+ k := strings.ToLower(f.Name)
+ for _, a := range start.Attr {
+ if fieldName(a.Name.Local) == k {
+ val = a.Value
+ break
+ }
+ }
+ strv.Set(val)
+
+ case "comment":
+ if saveComment == nil {
+ saveComment = sv.FieldByIndex(f.Index)
+ }
+
+ case "chardata":
+ if saveData == nil {
+ saveData = sv.FieldByIndex(f.Index)
+ }
+
+ case "innerxml":
+ if saveXML == nil {
+ saveXML = sv.FieldByIndex(f.Index)
+ if p.saved == nil {
+ saveXMLIndex = 0
+ p.saved = new(bytes.Buffer)
+ } else {
+ saveXMLIndex = p.savedOffset()
+ }
+ }
+
+ default:
+ if strings.Contains(f.Tag, ">") {
+ if fieldPaths == nil {
+ fieldPaths = make(map[string]pathInfo)
+ }
+ path := strings.ToLower(f.Tag)
+ if strings.HasPrefix(f.Tag, ">") {
+ path = strings.ToLower(f.Name) + path
+ }
+ if strings.HasSuffix(f.Tag, ">") {
+ path = path[:len(path)-1]
+ }
+ err := addFieldPath(sv, fieldPaths, path, f.Index)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ }
+
+ // Find end element.
+ // Process sub-elements along the way.
+Loop:
+ for {
+ var savedOffset int
+ if saveXML != nil {
+ savedOffset = p.savedOffset()
+ }
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ // Sub-element.
+ // Look up by tag name.
+ if sv != nil {
+ k := fieldName(t.Name.Local)
+
+ if fieldPaths != nil {
+ if _, found := fieldPaths[k]; found {
+ if err := p.unmarshalPaths(sv, fieldPaths, k, &t); err != nil {
+ return err
+ }
+ continue Loop
+ }
+ }
+
+ match := func(s string) bool {
+ // check if the name matches ignoring case
+ if strings.ToLower(s) != k {
+ return false
+ }
+ // now check that it's public
+ c, _ := utf8.DecodeRuneInString(s)
+ return unicode.IsUpper(c)
+ }
+
+ f, found := styp.FieldByNameFunc(match)
+ if !found { // fall back to mop-up field named "Any"
+ f, found = styp.FieldByName("Any")
+ }
+ if found {
+ if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil {
+ return err
+ }
+ continue Loop
+ }
+ }
+ // Not saving sub-element but still have to skip over it.
+ if err := p.Skip(); err != nil {
+ return err
+ }
+
+ case EndElement:
+ if saveXML != nil {
+ saveXMLData = p.saved.Bytes()[saveXMLIndex:savedOffset]
+ if saveXMLIndex == 0 {
+ p.saved = nil
+ }
+ }
+ break Loop
+
+ case CharData:
+ if saveData != nil {
+ data = append(data, t...)
+ }
+
+ case Comment:
+ if saveComment != nil {
+ comment = append(comment, t...)
+ }
+ }
+ }
+
+ var err os.Error
+ // Helper functions for integer and unsigned integer conversions
+ var itmp int64
+ getInt64 := func() bool {
+ itmp, err = strconv.Atoi64(string(data))
+ // TODO: should check sizes
+ return err == nil
+ }
+ var utmp uint64
+ getUint64 := func() bool {
+ utmp, err = strconv.Atoui64(string(data))
+ // TODO: check for overflow?
+ return err == nil
+ }
+ var ftmp float64
+ getFloat64 := func() bool {
+ ftmp, err = strconv.Atof64(string(data))
+ // TODO: check for overflow?
+ return err == nil
+ }
+
+ // Save accumulated data and comments
+ switch t := saveData.(type) {
+ case nil:
+ // Probably a comment, handled below
+ default:
+ return os.NewError("cannot happen: unknown type " + t.Type().String())
+ case *reflect.IntValue:
+ if !getInt64() {
+ return err
+ }
+ t.Set(itmp)
+ case *reflect.UintValue:
+ if !getUint64() {
+ return err
+ }
+ t.Set(utmp)
+ case *reflect.FloatValue:
+ if !getFloat64() {
+ return err
+ }
+ t.Set(ftmp)
+ case *reflect.BoolValue:
+ value, err := strconv.Atob(strings.TrimSpace(string(data)))
+ if err != nil {
+ return err
+ }
+ t.Set(value)
+ case *reflect.StringValue:
+ t.Set(string(data))
+ case *reflect.SliceValue:
+ t.Set(reflect.NewValue(data).(*reflect.SliceValue))
+ }
+
+ switch t := saveComment.(type) {
+ case *reflect.StringValue:
+ t.Set(string(comment))
+ case *reflect.SliceValue:
+ t.Set(reflect.NewValue(comment).(*reflect.SliceValue))
+ }
+
+ switch t := saveXML.(type) {
+ case *reflect.StringValue:
+ t.Set(string(saveXMLData))
+ case *reflect.SliceValue:
+ t.Set(reflect.NewValue(saveXMLData).(*reflect.SliceValue))
+ }
+
+ return nil
+}
+
+type pathInfo struct {
+ fieldIdx []int
+ complete bool
+}
+
+// addFieldPath takes an element path such as "a>b>c" and fills the
+// paths map with all paths leading to it ("a", "a>b", and "a>b>c").
+// It is okay for paths to share a common, shorter prefix but not ok
+// for one path to itself be a prefix of another.
+func addFieldPath(sv *reflect.StructValue, paths map[string]pathInfo, path string, fieldIdx []int) os.Error {
+ if info, found := paths[path]; found {
+ return tagError(sv, info.fieldIdx, fieldIdx)
+ }
+ paths[path] = pathInfo{fieldIdx, true}
+ for {
+ i := strings.LastIndex(path, ">")
+ if i < 0 {
+ break
+ }
+ path = path[:i]
+ if info, found := paths[path]; found {
+ if info.complete {
+ return tagError(sv, info.fieldIdx, fieldIdx)
+ }
+ } else {
+ paths[path] = pathInfo{fieldIdx, false}
+ }
+ }
+ return nil
+
+}
+
+func tagError(sv *reflect.StructValue, idx1 []int, idx2 []int) os.Error {
+ t := sv.Type().(*reflect.StructType)
+ f1 := t.FieldByIndex(idx1)
+ f2 := t.FieldByIndex(idx2)
+ return &TagPathError{t, f1.Name, f1.Tag, f2.Name, f2.Tag}
+}
+
+// unmarshalPaths walks down an XML structure looking for
+// wanted paths, and calls unmarshal on them.
+func (p *Parser) unmarshalPaths(sv *reflect.StructValue, paths map[string]pathInfo, path string, start *StartElement) os.Error {
+ if info, _ := paths[path]; info.complete {
+ return p.unmarshal(sv.FieldByIndex(info.fieldIdx), start)
+ }
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ k := path + ">" + fieldName(t.Name.Local)
+ if _, found := paths[k]; found {
+ if err := p.unmarshalPaths(sv, paths, k, &t); err != nil {
+ return err
+ }
+ continue
+ }
+ if err := p.Skip(); err != nil {
+ return err
+ }
+ case EndElement:
+ return nil
+ }
+ }
+ panic("unreachable")
+}
+
+// Have already read a start element.
+// Read tokens until we find the end element.
+// Token is taking care of making sure the
+// end element matches the start element we saw.
+func (p *Parser) Skip() os.Error {
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ if err := p.Skip(); err != nil {
+ return err
+ }
+ case EndElement:
+ return nil
+ }
+ }
+ panic("unreachable")
+}
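
The reflect.read.go.out file that follows is the expected result of the fix for the input above. The small runnable sketch below (illustrative only; setStringField is not part of the testdata) shows the patterns the rewritten code relies on: Kind() checks instead of *reflect.PtrValue and *reflect.StringValue assertions, IsValid() instead of nil checks, and SetString instead of the old typed Set methods.

package main

import (
	"fmt"
	"reflect"
)

// setStringField sets the named string field of the struct that dst points
// to, returning false if dst is not a pointer or the field is missing or not
// a string.
func setStringField(dst interface{}, name, val string) bool {
	v := reflect.ValueOf(dst)
	if v.Kind() != reflect.Ptr { // was: type assertion to *reflect.PtrValue
		return false
	}
	f := v.Elem().FieldByName(name)
	if !f.IsValid() || f.Kind() != reflect.String { // was: nil / *reflect.StringValue checks
		return false
	}
	f.SetString(val) // was: f.(*reflect.StringValue).Set(val)
	return true
}

func main() {
	type Email struct {
		Where string
		Addr  string
	}
	e := Email{}
	fmt.Println(setStringField(&e, "Addr", "gre@example.com"), e.Addr)
}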
diff --git a/src/cmd/fix/testdata/reflect.read.go.out b/src/cmd/fix/testdata/reflect.read.go.out
new file mode 100644
index 000000000..a6b126744
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.read.go.out
@@ -0,0 +1,620 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xml
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
+// an XML element is an order-dependent collection of anonymous
+// values, while a data structure is an order-independent collection
+// of named values.
+// See package json for a textual representation more suitable
+// to data structures.
+
+// Unmarshal parses an XML element from r and uses the
+// reflect library to fill in an arbitrary struct, slice, or string
+// pointed at by val. Well-formed data that does not fit
+// into val is discarded.
+//
+// For example, given these definitions:
+//
+// type Email struct {
+// Where string "attr"
+// Addr string
+// }
+//
+// type Result struct {
+// XMLName xml.Name "result"
+// Name string
+// Phone string
+// Email []Email
+// Groups []string "group>value"
+// }
+//
+// result := Result{Name: "name", Phone: "phone", Email: nil}
+//
+// unmarshalling the XML input
+//
+// <result>
+// <email where="home">
+// <addr>gre@example.com</addr>
+// </email>
+// <email where='work'>
+// <addr>gre@work.com</addr>
+// </email>
+// <name>Grace R. Emlin</name>
+// <group>
+// <value>Friends</value>
+// <value>Squash</value>
+// </group>
+// <address>123 Main Street</address>
+// </result>
+//
+// via Unmarshal(r, &result) is equivalent to assigning
+//
+// r = Result{xml.Name{"", "result"},
+// "Grace R. Emlin", // name
+// "phone", // no phone given
+// []Email{
+// Email{"home", "gre@example.com"},
+// Email{"work", "gre@work.com"},
+// },
+// []string{"Friends", "Squash"},
+// }
+//
+// Note that the field r.Phone has not been modified and
+// that the XML <address> element was discarded. Also, the field
+// Groups was assigned considering the element path provided in the
+// field tag.
+//
+// Because Unmarshal uses the reflect package, it can only
+// assign to upper case fields. Unmarshal uses a case-insensitive
+// comparison to match XML element names to struct field names.
+//
+// Unmarshal maps an XML element to a struct using the following rules:
+//
+// * If the struct has a field of type []byte or string with tag "innerxml",
+// Unmarshal accumulates the raw XML nested inside the element
+// in that field. The rest of the rules still apply.
+//
+// * If the struct has a field named XMLName of type xml.Name,
+// Unmarshal records the element name in that field.
+//
+// * If the XMLName field has an associated tag string of the form
+// "tag" or "namespace-URL tag", the XML element must have
+// the given tag (and, optionally, name space) or else Unmarshal
+// returns an error.
+//
+// * If the XML element has an attribute whose name matches a
+// struct field of type string with tag "attr", Unmarshal records
+// the attribute value in that field.
+//
+// * If the XML element contains character data, that data is
+// accumulated in the first struct field that has tag "chardata".
+// The struct field may have type []byte or string.
+// If there is no such field, the character data is discarded.
+//
+// * If the XML element contains a sub-element whose name matches
+// the prefix of a struct field tag formatted as "a>b>c", unmarshal
+// will descend into the XML structure looking for elements with the
+// given names, and will map the innermost elements to that struct field.
+// A struct field tag starting with ">" is equivalent to one starting
+// with the field name followed by ">".
+//
+// * If the XML element contains a sub-element whose name
+// matches a struct field whose tag is neither "attr" nor "chardata",
+// Unmarshal maps the sub-element to that struct field.
+// Otherwise, if the struct has a field named Any, unmarshal
+// maps the sub-element to that struct field.
+//
+// Unmarshal maps an XML element to a string or []byte by saving the
+// concatenation of that element's character data in the string or []byte.
+//
+// Unmarshal maps an XML element to a slice by extending the length
+// of the slice and mapping the element to the newly created value.
+//
+// Unmarshal maps an XML element to a bool by setting it to the boolean
+// value represented by the string.
+//
+// Unmarshal maps an XML element to an integer or floating-point
+// field by setting the field to the result of interpreting the string
+// value in decimal. There is no check for overflow.
+//
+// Unmarshal maps an XML element to an xml.Name by recording the
+// element name.
+//
+// Unmarshal maps an XML element to a pointer by setting the pointer
+// to a freshly allocated value and then mapping the element to that value.
+//
+func Unmarshal(r io.Reader, val interface{}) os.Error {
+ v := reflect.ValueOf(val)
+ if v.Kind() != reflect.Ptr {
+ return os.NewError("non-pointer passed to Unmarshal")
+ }
+ p := NewParser(r)
+ elem := v.Elem()
+ err := p.unmarshal(elem, nil)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// An UnmarshalError represents an error in the unmarshalling process.
+type UnmarshalError string
+
+func (e UnmarshalError) String() string { return string(e) }
+
+// A TagPathError represents an error in the unmarshalling process
+// caused by the use of field tags with conflicting paths.
+type TagPathError struct {
+ Struct reflect.Type
+ Field1, Tag1 string
+ Field2, Tag2 string
+}
+
+func (e *TagPathError) String() string {
+ return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
+}
+
+// The Parser's Unmarshal method is like xml.Unmarshal
+// except that it can be passed a pointer to the initial start element,
+// useful when a client reads some raw XML tokens itself
+// but also defers to Unmarshal for some elements.
+// Passing a nil start element indicates that Unmarshal should
+// read the token stream to find the start element.
+func (p *Parser) Unmarshal(val interface{}, start *StartElement) os.Error {
+ v := reflect.ValueOf(val)
+ if v.Kind() != reflect.Ptr {
+ return os.NewError("non-pointer passed to Unmarshal")
+ }
+ return p.unmarshal(v.Elem(), start)
+}
+
+// fieldName strips invalid characters from an XML name
+// to create a valid Go struct name. It also converts the
+// name to lower case letters.
+func fieldName(original string) string {
+
+ var i int
+ //remove leading underscores
+ for i = 0; i < len(original) && original[i] == '_'; i++ {
+ }
+
+ return strings.Map(
+ func(x int) int {
+ if x == '_' || unicode.IsDigit(x) || unicode.IsLetter(x) {
+ return unicode.ToLower(x)
+ }
+ return -1
+ },
+ original[i:])
+}
+
+// Unmarshal a single XML element into val.
+func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error {
+ // Find start element if we need it.
+ if start == nil {
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ if t, ok := tok.(StartElement); ok {
+ start = &t
+ break
+ }
+ }
+ }
+
+ if pv := val; pv.Kind() == reflect.Ptr {
+ if pv.Pointer() == 0 {
+ zv := reflect.Zero(pv.Type().Elem())
+ pv.Set(zv.Addr())
+ val = zv
+ } else {
+ val = pv.Elem()
+ }
+ }
+
+ var (
+ data []byte
+ saveData reflect.Value
+ comment []byte
+ saveComment reflect.Value
+ saveXML reflect.Value
+ saveXMLIndex int
+ saveXMLData []byte
+ sv reflect.Value
+ styp reflect.Type
+ fieldPaths map[string]pathInfo
+ )
+
+ switch v := val; v.Kind() {
+ default:
+ return os.NewError("unknown type " + v.Type().String())
+
+ case reflect.Slice:
+ typ := v.Type()
+ if typ.Elem().Kind() == reflect.Uint8 {
+ // []byte
+ saveData = v
+ break
+ }
+
+ // Slice of element values.
+ // Grow slice.
+ n := v.Len()
+ if n >= v.Cap() {
+ ncap := 2 * n
+ if ncap < 4 {
+ ncap = 4
+ }
+ new := reflect.MakeSlice(typ, n, ncap)
+ reflect.Copy(new, v)
+ v.Set(new)
+ }
+ v.SetLen(n + 1)
+
+ // Recur to read element into slice.
+ if err := p.unmarshal(v.Index(n), start); err != nil {
+ v.SetLen(n)
+ return err
+ }
+ return nil
+
+ case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String:
+ saveData = v
+
+ case reflect.Struct:
+ if _, ok := v.Interface().(Name); ok {
+ v.Set(reflect.ValueOf(start.Name))
+ break
+ }
+
+ sv = v
+ typ := sv.Type()
+ styp = typ
+ // Assign name.
+ if f, ok := typ.FieldByName("XMLName"); ok {
+ // Validate element name.
+ if f.Tag != "" {
+ tag := f.Tag
+ ns := ""
+ i := strings.LastIndex(tag, " ")
+ if i >= 0 {
+ ns, tag = tag[0:i], tag[i+1:]
+ }
+ if tag != start.Name.Local {
+ return UnmarshalError("expected element type <" + tag + "> but have <" + start.Name.Local + ">")
+ }
+ if ns != "" && ns != start.Name.Space {
+ e := "expected element <" + tag + "> in name space " + ns + " but have "
+ if start.Name.Space == "" {
+ e += "no name space"
+ } else {
+ e += start.Name.Space
+ }
+ return UnmarshalError(e)
+ }
+ }
+
+ // Save
+ v := sv.FieldByIndex(f.Index)
+ if _, ok := v.Interface().(Name); !ok {
+ return UnmarshalError(sv.Type().String() + " field XMLName does not have type xml.Name")
+ }
+ v.Set(reflect.ValueOf(start.Name))
+ }
+
+ // Assign attributes.
+ // Also, determine whether we need to save character data or comments.
+ for i, n := 0, typ.NumField(); i < n; i++ {
+ f := typ.Field(i)
+ switch f.Tag {
+ case "attr":
+ strv := sv.FieldByIndex(f.Index)
+ if strv.Kind() != reflect.String {
+ return UnmarshalError(sv.Type().String() + " field " + f.Name + " has attr tag but is not type string")
+ }
+ // Look for attribute.
+ val := ""
+ k := strings.ToLower(f.Name)
+ for _, a := range start.Attr {
+ if fieldName(a.Name.Local) == k {
+ val = a.Value
+ break
+ }
+ }
+ strv.SetString(val)
+
+ case "comment":
+ if !saveComment.IsValid() {
+ saveComment = sv.FieldByIndex(f.Index)
+ }
+
+ case "chardata":
+ if !saveData.IsValid() {
+ saveData = sv.FieldByIndex(f.Index)
+ }
+
+ case "innerxml":
+ if !saveXML.IsValid() {
+ saveXML = sv.FieldByIndex(f.Index)
+ if p.saved == nil {
+ saveXMLIndex = 0
+ p.saved = new(bytes.Buffer)
+ } else {
+ saveXMLIndex = p.savedOffset()
+ }
+ }
+
+ default:
+ if strings.Contains(f.Tag, ">") {
+ if fieldPaths == nil {
+ fieldPaths = make(map[string]pathInfo)
+ }
+ path := strings.ToLower(f.Tag)
+ if strings.HasPrefix(f.Tag, ">") {
+ path = strings.ToLower(f.Name) + path
+ }
+ if strings.HasSuffix(f.Tag, ">") {
+ path = path[:len(path)-1]
+ }
+ err := addFieldPath(sv, fieldPaths, path, f.Index)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ }
+
+ // Find end element.
+ // Process sub-elements along the way.
+Loop:
+ for {
+ var savedOffset int
+ if saveXML.IsValid() {
+ savedOffset = p.savedOffset()
+ }
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ // Sub-element.
+ // Look up by tag name.
+ if sv.IsValid() {
+ k := fieldName(t.Name.Local)
+
+ if fieldPaths != nil {
+ if _, found := fieldPaths[k]; found {
+ if err := p.unmarshalPaths(sv, fieldPaths, k, &t); err != nil {
+ return err
+ }
+ continue Loop
+ }
+ }
+
+ match := func(s string) bool {
+ // check if the name matches ignoring case
+ if strings.ToLower(s) != k {
+ return false
+ }
+ // now check that it's public
+ c, _ := utf8.DecodeRuneInString(s)
+ return unicode.IsUpper(c)
+ }
+
+ f, found := styp.FieldByNameFunc(match)
+ if !found { // fall back to mop-up field named "Any"
+ f, found = styp.FieldByName("Any")
+ }
+ if found {
+ if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil {
+ return err
+ }
+ continue Loop
+ }
+ }
+ // Not saving sub-element but still have to skip over it.
+ if err := p.Skip(); err != nil {
+ return err
+ }
+
+ case EndElement:
+ if saveXML.IsValid() {
+ saveXMLData = p.saved.Bytes()[saveXMLIndex:savedOffset]
+ if saveXMLIndex == 0 {
+ p.saved = nil
+ }
+ }
+ break Loop
+
+ case CharData:
+ if saveData.IsValid() {
+ data = append(data, t...)
+ }
+
+ case Comment:
+ if saveComment.IsValid() {
+ comment = append(comment, t...)
+ }
+ }
+ }
+
+ var err os.Error
+ // Helper functions for integer and unsigned integer conversions
+ var itmp int64
+ getInt64 := func() bool {
+ itmp, err = strconv.Atoi64(string(data))
+ // TODO: should check sizes
+ return err == nil
+ }
+ var utmp uint64
+ getUint64 := func() bool {
+ utmp, err = strconv.Atoui64(string(data))
+ // TODO: check for overflow?
+ return err == nil
+ }
+ var ftmp float64
+ getFloat64 := func() bool {
+ ftmp, err = strconv.Atof64(string(data))
+ // TODO: check for overflow?
+ return err == nil
+ }
+
+ // Save accumulated data and comments
+ switch t := saveData; t.Kind() {
+ case reflect.Invalid:
+ // Probably a comment, handled below
+ default:
+ return os.NewError("cannot happen: unknown type " + t.Type().String())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if !getInt64() {
+ return err
+ }
+ t.SetInt(itmp)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ if !getUint64() {
+ return err
+ }
+ t.SetUint(utmp)
+ case reflect.Float32, reflect.Float64:
+ if !getFloat64() {
+ return err
+ }
+ t.SetFloat(ftmp)
+ case reflect.Bool:
+ value, err := strconv.Atob(strings.TrimSpace(string(data)))
+ if err != nil {
+ return err
+ }
+ t.SetBool(value)
+ case reflect.String:
+ t.SetString(string(data))
+ case reflect.Slice:
+ t.Set(reflect.ValueOf(data))
+ }
+
+ switch t := saveComment; t.Kind() {
+ case reflect.String:
+ t.SetString(string(comment))
+ case reflect.Slice:
+ t.Set(reflect.ValueOf(comment))
+ }
+
+ switch t := saveXML; t.Kind() {
+ case reflect.String:
+ t.SetString(string(saveXMLData))
+ case reflect.Slice:
+ t.Set(reflect.ValueOf(saveXMLData))
+ }
+
+ return nil
+}
+
+type pathInfo struct {
+ fieldIdx []int
+ complete bool
+}
+
+// addFieldPath takes an element path such as "a>b>c" and fills the
+// paths map with all paths leading to it ("a", "a>b", and "a>b>c").
+// It is okay for paths to share a common, shorter prefix but not ok
+// for one path to itself be a prefix of another.
+func addFieldPath(sv reflect.Value, paths map[string]pathInfo, path string, fieldIdx []int) os.Error {
+ if info, found := paths[path]; found {
+ return tagError(sv, info.fieldIdx, fieldIdx)
+ }
+ paths[path] = pathInfo{fieldIdx, true}
+ for {
+ i := strings.LastIndex(path, ">")
+ if i < 0 {
+ break
+ }
+ path = path[:i]
+ if info, found := paths[path]; found {
+ if info.complete {
+ return tagError(sv, info.fieldIdx, fieldIdx)
+ }
+ } else {
+ paths[path] = pathInfo{fieldIdx, false}
+ }
+ }
+ return nil
+
+}
+
+func tagError(sv reflect.Value, idx1 []int, idx2 []int) os.Error {
+ t := sv.Type()
+ f1 := t.FieldByIndex(idx1)
+ f2 := t.FieldByIndex(idx2)
+ return &TagPathError{t, f1.Name, f1.Tag, f2.Name, f2.Tag}
+}
+
+// unmarshalPaths walks down an XML structure looking for
+// wanted paths, and calls unmarshal on them.
+func (p *Parser) unmarshalPaths(sv reflect.Value, paths map[string]pathInfo, path string, start *StartElement) os.Error {
+ if info, _ := paths[path]; info.complete {
+ return p.unmarshal(sv.FieldByIndex(info.fieldIdx), start)
+ }
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ k := path + ">" + fieldName(t.Name.Local)
+ if _, found := paths[k]; found {
+ if err := p.unmarshalPaths(sv, paths, k, &t); err != nil {
+ return err
+ }
+ continue
+ }
+ if err := p.Skip(); err != nil {
+ return err
+ }
+ case EndElement:
+ return nil
+ }
+ }
+ panic("unreachable")
+}
+
+// Have already read a start element.
+// Read tokens until we find the end element.
+// Token is taking care of making sure the
+// end element matches the start element we saw.
+func (p *Parser) Skip() os.Error {
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ if err := p.Skip(); err != nil {
+ return err
+ }
+ case EndElement:
+ return nil
+ }
+ }
+ panic("unreachable")
+}
diff --git a/src/cmd/fix/testdata/reflect.scan.go.in b/src/cmd/fix/testdata/reflect.scan.go.in
new file mode 100644
index 000000000..51898181f
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.scan.go.in
@@ -0,0 +1,1082 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt
+
+import (
+ "bytes"
+ "io"
+ "math"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// runeUnreader is the interface to something that can unread runes.
+// If the object provided to Scan does not satisfy this interface,
+// a local buffer will be used to back up the input, but its contents
+// will be lost when Scan returns.
+type runeUnreader interface {
+ UnreadRune() os.Error
+}
+
+// ScanState represents the scanner state passed to custom scanners.
+// Scanners may do rune-at-a-time scanning or ask the ScanState
+// to discover the next space-delimited token.
+type ScanState interface {
+ // ReadRune reads the next rune (Unicode code point) from the input.
+ // If invoked during Scanln, Fscanln, or Sscanln, ReadRune() will
+ // return EOF after returning the first '\n' or when reading beyond
+ // the specified width.
+ ReadRune() (rune int, size int, err os.Error)
+ // UnreadRune causes the next call to ReadRune to return the same rune.
+ UnreadRune() os.Error
+ // Token skips space in the input if skipSpace is true, then returns the
+ // run of Unicode code points c satisfying f(c). If f is nil,
+ // !unicode.IsSpace(c) is used; that is, the token will hold non-space
+ // characters. Newlines are treated as space unless the scan operation
+ // is Scanln, Fscanln or Sscanln, in which case a newline is treated as
+ // EOF. The returned slice points to shared data that may be overwritten
+ // by the next call to Token, a call to a Scan function using the ScanState
+ // as input, or when the calling Scan method returns.
+ Token(skipSpace bool, f func(int) bool) (token []byte, err os.Error)
+ // Width returns the value of the width option and whether it has been set.
+ // The unit is Unicode code points.
+ Width() (wid int, ok bool)
+ // Because ReadRune is implemented by the interface, Read should never be
+ // called by the scanning routines and a valid implementation of
+ // ScanState may choose always to return an error from Read.
+ Read(buf []byte) (n int, err os.Error)
+}
+
+// Scanner is implemented by any value that has a Scan method, which scans
+// the input for the representation of a value and stores the result in the
+// receiver, which must be a pointer to be useful. The Scan method is called
+// for any argument to Scan, Scanf, or Scanln that implements it.
+type Scanner interface {
+ Scan(state ScanState, verb int) os.Error
+}
+
+// Scan scans text read from standard input, storing successive
+// space-separated values into successive arguments. Newlines count
+// as space. It returns the number of items successfully scanned.
+// If that is less than the number of arguments, err will report why.
+func Scan(a ...interface{}) (n int, err os.Error) {
+ return Fscan(os.Stdin, a...)
+}
+
+// Scanln is similar to Scan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Scanln(a ...interface{}) (n int, err os.Error) {
+ return Fscanln(os.Stdin, a...)
+}
+
+// Scanf scans text read from standard input, storing successive
+// space-separated values into successive arguments as determined by
+// the format. It returns the number of items successfully scanned.
+func Scanf(format string, a ...interface{}) (n int, err os.Error) {
+ return Fscanf(os.Stdin, format, a...)
+}
+
+// Sscan scans the argument string, storing successive space-separated
+// values into successive arguments. Newlines count as space. It
+// returns the number of items successfully scanned. If that is less
+// than the number of arguments, err will report why.
+func Sscan(str string, a ...interface{}) (n int, err os.Error) {
+ return Fscan(strings.NewReader(str), a...)
+}
+
+// Sscanln is similar to Sscan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Sscanln(str string, a ...interface{}) (n int, err os.Error) {
+ return Fscanln(strings.NewReader(str), a...)
+}
+
+// Sscanf scans the argument string, storing successive space-separated
+// values into successive arguments as determined by the format. It
+// returns the number of items successfully parsed.
+func Sscanf(str string, format string, a ...interface{}) (n int, err os.Error) {
+ return Fscanf(strings.NewReader(str), format, a...)
+}
+
+// Fscan scans text read from r, storing successive space-separated
+// values into successive arguments. Newlines count as space. It
+// returns the number of items successfully scanned. If that is less
+// than the number of arguments, err will report why.
+func Fscan(r io.Reader, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, true, false)
+ n, err = s.doScan(a)
+ s.free(old)
+ return
+}
+
+// Fscanln is similar to Fscan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Fscanln(r io.Reader, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, false, true)
+ n, err = s.doScan(a)
+ s.free(old)
+ return
+}
+
+// Fscanf scans text read from r, storing successive space-separated
+// values into successive arguments as determined by the format. It
+// returns the number of items successfully parsed.
+func Fscanf(r io.Reader, format string, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, false, false)
+ n, err = s.doScanf(format, a)
+ s.free(old)
+ return
+}
+
+// scanError represents an error generated by the scanning software.
+// It's used as a unique signature to identify such errors when recovering.
+type scanError struct {
+ err os.Error
+}
+
+const eof = -1
+
+// ss is the internal implementation of ScanState.
+type ss struct {
+ rr io.RuneReader // where to read input
+ buf bytes.Buffer // token accumulator
+ peekRune int // one-rune lookahead
+ prevRune int // last rune returned by ReadRune
+ count int // runes consumed so far.
+ atEOF bool // already read EOF
+ ssave
+}
+
+// ssave holds the parts of ss that need to be
+// saved and restored on recursive scans.
+type ssave struct {
+ validSave bool // is or was a part of an actual ss.
+ nlIsEnd bool // whether newline terminates scan
+ nlIsSpace bool // whether newline counts as white space
+ fieldLimit int // max value of ss.count for this field; fieldLimit <= limit
+ limit int // max value of ss.count.
+ maxWid int // width of this field.
+}
+
+// The Read method is only in ScanState so that ScanState
+// satisfies io.Reader. It will never be called when used as
+// intended, so there is no need to make it actually work.
+func (s *ss) Read(buf []byte) (n int, err os.Error) {
+ return 0, os.NewError("ScanState's Read should not be called. Use ReadRune")
+}
+
+func (s *ss) ReadRune() (rune int, size int, err os.Error) {
+ if s.peekRune >= 0 {
+ s.count++
+ rune = s.peekRune
+ size = utf8.RuneLen(rune)
+ s.prevRune = rune
+ s.peekRune = -1
+ return
+ }
+ if s.atEOF || s.nlIsEnd && s.prevRune == '\n' || s.count >= s.fieldLimit {
+ err = os.EOF
+ return
+ }
+
+ rune, size, err = s.rr.ReadRune()
+ if err == nil {
+ s.count++
+ s.prevRune = rune
+ } else if err == os.EOF {
+ s.atEOF = true
+ }
+ return
+}
+
+func (s *ss) Width() (wid int, ok bool) {
+ if s.maxWid == hugeWid {
+ return 0, false
+ }
+ return s.maxWid, true
+}
+
+// The public method returns an error; this private one panics.
+// If getRune reaches EOF, the return value is EOF (-1).
+func (s *ss) getRune() (rune int) {
+ rune, _, err := s.ReadRune()
+ if err != nil {
+ if err == os.EOF {
+ return eof
+ }
+ s.error(err)
+ }
+ return
+}
+
+// mustReadRune turns os.EOF into a panic(io.ErrUnexpectedEOF).
+// It is called in cases such as string scanning where an EOF is a
+// syntax error.
+func (s *ss) mustReadRune() (rune int) {
+ rune = s.getRune()
+ if rune == eof {
+ s.error(io.ErrUnexpectedEOF)
+ }
+ return
+}
+
+func (s *ss) UnreadRune() os.Error {
+ if u, ok := s.rr.(runeUnreader); ok {
+ u.UnreadRune()
+ } else {
+ s.peekRune = s.prevRune
+ }
+ s.count--
+ return nil
+}
+
+func (s *ss) error(err os.Error) {
+ panic(scanError{err})
+}
+
+func (s *ss) errorString(err string) {
+ panic(scanError{os.NewError(err)})
+}
+
+func (s *ss) Token(skipSpace bool, f func(int) bool) (tok []byte, err os.Error) {
+ defer func() {
+ if e := recover(); e != nil {
+ if se, ok := e.(scanError); ok {
+ err = se.err
+ } else {
+ panic(e)
+ }
+ }
+ }()
+ if f == nil {
+ f = notSpace
+ }
+ s.buf.Reset()
+ tok = s.token(skipSpace, f)
+ return
+}
+
+// notSpace is the default scanning function used in Token.
+func notSpace(r int) bool {
+ return !unicode.IsSpace(r)
+}
+
+// readRune is a structure to enable reading UTF-8 encoded code points
+// from an io.Reader. It is used if the Reader given to the scanner does
+// not already implement io.RuneReader.
+type readRune struct {
+ reader io.Reader
+ buf [utf8.UTFMax]byte // used only inside ReadRune
+ pending int // number of bytes in pendBuf; only >0 for bad UTF-8
+ pendBuf [utf8.UTFMax]byte // bytes left over
+}
+
+// readByte returns the next byte from the input, which may be
+// left over from a previous read if the UTF-8 was ill-formed.
+func (r *readRune) readByte() (b byte, err os.Error) {
+ if r.pending > 0 {
+ b = r.pendBuf[0]
+ copy(r.pendBuf[0:], r.pendBuf[1:])
+ r.pending--
+ return
+ }
+ _, err = r.reader.Read(r.pendBuf[0:1])
+ return r.pendBuf[0], err
+}
+
+// unread saves the bytes for the next read.
+func (r *readRune) unread(buf []byte) {
+ copy(r.pendBuf[r.pending:], buf)
+ r.pending += len(buf)
+}
+
+// ReadRune returns the next UTF-8 encoded code point from the
+// io.Reader inside r.
+func (r *readRune) ReadRune() (rune int, size int, err os.Error) {
+ r.buf[0], err = r.readByte()
+ if err != nil {
+ return 0, 0, err
+ }
+ if r.buf[0] < utf8.RuneSelf { // fast check for common ASCII case
+ rune = int(r.buf[0])
+ return
+ }
+ var n int
+ for n = 1; !utf8.FullRune(r.buf[0:n]); n++ {
+ r.buf[n], err = r.readByte()
+ if err != nil {
+ if err == os.EOF {
+ err = nil
+ break
+ }
+ return
+ }
+ }
+ rune, size = utf8.DecodeRune(r.buf[0:n])
+ if size < n { // an error
+ r.unread(r.buf[size:n])
+ }
+ return
+}
+
+var ssFree = newCache(func() interface{} { return new(ss) })
+
+// Allocate a new ss struct or grab a cached one.
+func newScanState(r io.Reader, nlIsSpace, nlIsEnd bool) (s *ss, old ssave) {
+ // If the reader is a *ss, then we've got a recursive
+ // call to Scan, so re-use the scan state.
+ s, ok := r.(*ss)
+ if ok {
+ old = s.ssave
+ s.limit = s.fieldLimit
+ s.nlIsEnd = nlIsEnd || s.nlIsEnd
+ s.nlIsSpace = nlIsSpace
+ return
+ }
+
+ s = ssFree.get().(*ss)
+ if rr, ok := r.(io.RuneReader); ok {
+ s.rr = rr
+ } else {
+ s.rr = &readRune{reader: r}
+ }
+ s.nlIsSpace = nlIsSpace
+ s.nlIsEnd = nlIsEnd
+ s.prevRune = -1
+ s.peekRune = -1
+ s.atEOF = false
+ s.limit = hugeWid
+ s.fieldLimit = hugeWid
+ s.maxWid = hugeWid
+ s.validSave = true
+ return
+}
+
+// Save used ss structs in ssFree; avoid an allocation per invocation.
+func (s *ss) free(old ssave) {
+ // If it was used recursively, just restore the old state.
+ if old.validSave {
+ s.ssave = old
+ return
+ }
+ // Don't hold on to ss structs with large buffers.
+ if cap(s.buf.Bytes()) > 1024 {
+ return
+ }
+ s.buf.Reset()
+ s.rr = nil
+ ssFree.put(s)
+}
+
+// skipSpace skips spaces and maybe newlines.
+func (s *ss) skipSpace(stopAtNewline bool) {
+ for {
+ rune := s.getRune()
+ if rune == eof {
+ return
+ }
+ if rune == '\n' {
+ if stopAtNewline {
+ break
+ }
+ if s.nlIsSpace {
+ continue
+ }
+ s.errorString("unexpected newline")
+ return
+ }
+ if !unicode.IsSpace(rune) {
+ s.UnreadRune()
+ break
+ }
+ }
+}
+
+// token returns the next space-delimited string from the input. It
+// skips white space. For Scanln, it stops at newlines. For Scan,
+// newlines are treated as spaces.
+func (s *ss) token(skipSpace bool, f func(int) bool) []byte {
+ if skipSpace {
+ s.skipSpace(false)
+ }
+ // read until white space or newline
+ for {
+ rune := s.getRune()
+ if rune == eof {
+ break
+ }
+ if !f(rune) {
+ s.UnreadRune()
+ break
+ }
+ s.buf.WriteRune(rune)
+ }
+ return s.buf.Bytes()
+}
+
+// typeError indicates that the type of the operand did not match the format
+func (s *ss) typeError(field interface{}, expected string) {
+ s.errorString("expected field of type pointer to " + expected + "; found " + reflect.Typeof(field).String())
+}
+
+var complexError = os.NewError("syntax error scanning complex number")
+var boolError = os.NewError("syntax error scanning boolean")
+
+// consume reads the next rune in the input and reports whether it is in the ok string.
+// If accept is true, it puts the character into the input token.
+func (s *ss) consume(ok string, accept bool) bool {
+ rune := s.getRune()
+ if rune == eof {
+ return false
+ }
+ if strings.IndexRune(ok, rune) >= 0 {
+ if accept {
+ s.buf.WriteRune(rune)
+ }
+ return true
+ }
+ if rune != eof && accept {
+ s.UnreadRune()
+ }
+ return false
+}
+
+// peek reports whether the next character is in the ok string, without consuming it.
+func (s *ss) peek(ok string) bool {
+ rune := s.getRune()
+ if rune != eof {
+ s.UnreadRune()
+ }
+ return strings.IndexRune(ok, rune) >= 0
+}
+
+// accept checks the next rune in the input. If it's a byte (sic) in the string, it puts it in the
+// buffer and returns true. Otherwise it returns false.
+func (s *ss) accept(ok string) bool {
+ return s.consume(ok, true)
+}
+
+// okVerb verifies that the verb is present in the list, setting s.err appropriately if not.
+func (s *ss) okVerb(verb int, okVerbs, typ string) bool {
+ for _, v := range okVerbs {
+ if v == verb {
+ return true
+ }
+ }
+ s.errorString("bad verb %" + string(verb) + " for " + typ)
+ return false
+}
+
+// scanBool returns the value of the boolean represented by the next token.
+func (s *ss) scanBool(verb int) bool {
+ if !s.okVerb(verb, "tv", "boolean") {
+ return false
+ }
+ // Syntax-checking a boolean is annoying. We're not fastidious about case.
+ switch s.mustReadRune() {
+ case '0':
+ return false
+ case '1':
+ return true
+ case 't', 'T':
+ if s.accept("rR") && (!s.accept("uU") || !s.accept("eE")) {
+ s.error(boolError)
+ }
+ return true
+ case 'f', 'F':
+ if s.accept("aL") && (!s.accept("lL") || !s.accept("sS") || !s.accept("eE")) {
+ s.error(boolError)
+ }
+ return false
+ }
+ return false
+}
+
+// Numerical elements
+const (
+ binaryDigits = "01"
+ octalDigits = "01234567"
+ decimalDigits = "0123456789"
+ hexadecimalDigits = "0123456789aAbBcCdDeEfF"
+ sign = "+-"
+ period = "."
+ exponent = "eEp"
+)
+
+// getBase returns the numeric base represented by the verb and its digit string.
+func (s *ss) getBase(verb int) (base int, digits string) {
+ s.okVerb(verb, "bdoUxXv", "integer") // sets s.err
+ base = 10
+ digits = decimalDigits
+ switch verb {
+ case 'b':
+ base = 2
+ digits = binaryDigits
+ case 'o':
+ base = 8
+ digits = octalDigits
+ case 'x', 'X', 'U':
+ base = 16
+ digits = hexadecimalDigits
+ }
+ return
+}
+
+// scanNumber returns the numerical string with specified digits starting here.
+func (s *ss) scanNumber(digits string, haveDigits bool) string {
+ if !haveDigits && !s.accept(digits) {
+ s.errorString("expected integer")
+ }
+ for s.accept(digits) {
+ }
+ return s.buf.String()
+}
+
+// scanRune returns the next rune value in the input.
+func (s *ss) scanRune(bitSize int) int64 {
+ rune := int64(s.mustReadRune())
+ n := uint(bitSize)
+ x := (rune << (64 - n)) >> (64 - n)
+ if x != rune {
+ s.errorString("overflow on character value " + string(rune))
+ }
+ return rune
+}
+
+// scanBasePrefix reports whether the integer begins with a 0 or 0x,
+// and returns the base, digit string, and whether a zero was found.
+// It is called only if the verb is %v.
+func (s *ss) scanBasePrefix() (base int, digits string, found bool) {
+ if !s.peek("0") {
+ return 10, decimalDigits, false
+ }
+ s.accept("0")
+ found = true // We've put a digit into the token buffer.
+ // Special cases for '0' && '0x'
+ base, digits = 8, octalDigits
+ if s.peek("xX") {
+ s.consume("xX", false)
+ base, digits = 16, hexadecimalDigits
+ }
+ return
+}
+
+// scanInt returns the value of the integer represented by the next
+// token, checking for overflow. Any error is stored in s.err.
+func (s *ss) scanInt(verb int, bitSize int) int64 {
+ if verb == 'c' {
+ return s.scanRune(bitSize)
+ }
+ s.skipSpace(false)
+ base, digits := s.getBase(verb)
+ haveDigits := false
+ if verb == 'U' {
+ if !s.consume("U", false) || !s.consume("+", false) {
+ s.errorString("bad unicode format ")
+ }
+ } else {
+ s.accept(sign) // If there's a sign, it will be left in the token buffer.
+ if verb == 'v' {
+ base, digits, haveDigits = s.scanBasePrefix()
+ }
+ }
+ tok := s.scanNumber(digits, haveDigits)
+ i, err := strconv.Btoi64(tok, base)
+ if err != nil {
+ s.error(err)
+ }
+ n := uint(bitSize)
+ x := (i << (64 - n)) >> (64 - n)
+ if x != i {
+ s.errorString("integer overflow on token " + tok)
+ }
+ return i
+}
+
+// scanUint returns the value of the unsigned integer represented
+// by the next token, checking for overflow. Any error is stored in s.err.
+func (s *ss) scanUint(verb int, bitSize int) uint64 {
+ if verb == 'c' {
+ return uint64(s.scanRune(bitSize))
+ }
+ s.skipSpace(false)
+ base, digits := s.getBase(verb)
+ haveDigits := false
+ if verb == 'U' {
+ if !s.consume("U", false) || !s.consume("+", false) {
+ s.errorString("bad unicode format ")
+ }
+ } else if verb == 'v' {
+ base, digits, haveDigits = s.scanBasePrefix()
+ }
+ tok := s.scanNumber(digits, haveDigits)
+ i, err := strconv.Btoui64(tok, base)
+ if err != nil {
+ s.error(err)
+ }
+ n := uint(bitSize)
+ x := (i << (64 - n)) >> (64 - n)
+ if x != i {
+ s.errorString("unsigned integer overflow on token " + tok)
+ }
+ return i
+}
+
+// floatToken returns the floating-point number starting here, no longer than the
+// specified width, if any. It's not rigorous about syntax because it doesn't check that
+// we have at least some digits, but Atof will do that.
+func (s *ss) floatToken() string {
+ s.buf.Reset()
+ // NaN?
+ if s.accept("nN") && s.accept("aA") && s.accept("nN") {
+ return s.buf.String()
+ }
+ // leading sign?
+ s.accept(sign)
+ // Inf?
+ if s.accept("iI") && s.accept("nN") && s.accept("fF") {
+ return s.buf.String()
+ }
+ // digits?
+ for s.accept(decimalDigits) {
+ }
+ // decimal point?
+ if s.accept(period) {
+ // fraction?
+ for s.accept(decimalDigits) {
+ }
+ }
+ // exponent?
+ if s.accept(exponent) {
+ // leading sign?
+ s.accept(sign)
+ // digits?
+ for s.accept(decimalDigits) {
+ }
+ }
+ return s.buf.String()
+}
+
+// complexTokens returns the real and imaginary parts of the complex number starting here.
+// The number might be parenthesized and has the format (N+Ni) where N is a floating-point
+// number and there are no spaces within.
+func (s *ss) complexTokens() (real, imag string) {
+ // TODO: accept N and Ni independently?
+ parens := s.accept("(")
+ real = s.floatToken()
+ s.buf.Reset()
+ // Must now have a sign.
+ if !s.accept("+-") {
+ s.error(complexError)
+ }
+ // Sign is now in buffer
+ imagSign := s.buf.String()
+ imag = s.floatToken()
+ if !s.accept("i") {
+ s.error(complexError)
+ }
+ if parens && !s.accept(")") {
+ s.error(complexError)
+ }
+ return real, imagSign + imag
+}
+
+// convertFloat converts the string to a float64 value.
+func (s *ss) convertFloat(str string, n int) float64 {
+ if p := strings.Index(str, "p"); p >= 0 {
+ // Atof doesn't handle power-of-2 exponents,
+ // but they're easy to evaluate.
+ f, err := strconv.AtofN(str[:p], n)
+ if err != nil {
+ // Put full string into error.
+ if e, ok := err.(*strconv.NumError); ok {
+ e.Num = str
+ }
+ s.error(err)
+ }
+ n, err := strconv.Atoi(str[p+1:])
+ if err != nil {
+ // Put full string into error.
+ if e, ok := err.(*strconv.NumError); ok {
+ e.Num = str
+ }
+ s.error(err)
+ }
+ return math.Ldexp(f, n)
+ }
+ f, err := strconv.AtofN(str, n)
+ if err != nil {
+ s.error(err)
+ }
+ return f
+}
+
+// scanComplex converts the next token to a complex128 value.
+// The bit size n of the complex value determines the precision of its parts:
+// for complex64 the parts are parsed as float32s and converted
+// to float64s, avoiding a separate copy of this code for each complex type.
+func (s *ss) scanComplex(verb int, n int) complex128 {
+ if !s.okVerb(verb, floatVerbs, "complex") {
+ return 0
+ }
+ s.skipSpace(false)
+ sreal, simag := s.complexTokens()
+ real := s.convertFloat(sreal, n/2)
+ imag := s.convertFloat(simag, n/2)
+ return complex(real, imag)
+}
+
+// convertString returns the string represented by the next input characters.
+// The format of the input is determined by the verb.
+func (s *ss) convertString(verb int) (str string) {
+ if !s.okVerb(verb, "svqx", "string") {
+ return ""
+ }
+ s.skipSpace(false)
+ switch verb {
+ case 'q':
+ str = s.quotedString()
+ case 'x':
+ str = s.hexString()
+ default:
+ str = string(s.token(true, notSpace)) // %s and %v just return the next word
+ }
+ // Empty strings other than with %q are not OK.
+ if len(str) == 0 && verb != 'q' && s.maxWid > 0 {
+ s.errorString("Scan: no data for string")
+ }
+ return
+}
+
+// quotedString returns the double- or back-quoted string represented by the next input characters.
+func (s *ss) quotedString() string {
+ quote := s.mustReadRune()
+ switch quote {
+ case '`':
+ // Back-quoted: Anything goes until EOF or back quote.
+ for {
+ rune := s.mustReadRune()
+ if rune == quote {
+ break
+ }
+ s.buf.WriteRune(rune)
+ }
+ return s.buf.String()
+ case '"':
+ // Double-quoted: Include the quotes and let strconv.Unquote do the backslash escapes.
+ s.buf.WriteRune(quote)
+ for {
+ rune := s.mustReadRune()
+ s.buf.WriteRune(rune)
+ if rune == '\\' {
+ // In a legal backslash escape, no matter how long, only the character
+ // immediately after the escape can itself be a backslash or quote.
+ // Thus we only need to protect the first character after the backslash.
+ rune := s.mustReadRune()
+ s.buf.WriteRune(rune)
+ } else if rune == '"' {
+ break
+ }
+ }
+ result, err := strconv.Unquote(s.buf.String())
+ if err != nil {
+ s.error(err)
+ }
+ return result
+ default:
+ s.errorString("expected quoted string")
+ }
+ return ""
+}
+
+// hexDigit returns the value of the hexadecimal digit
+func (s *ss) hexDigit(digit int) int {
+ switch digit {
+ case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+ return digit - '0'
+ case 'a', 'b', 'c', 'd', 'e', 'f':
+ return 10 + digit - 'a'
+ case 'A', 'B', 'C', 'D', 'E', 'F':
+ return 10 + digit - 'A'
+ }
+ s.errorString("Scan: illegal hex digit")
+ return 0
+}
+
+// hexByte returns the next hex-encoded (two-character) byte from the input.
+// There must be either two hexadecimal digits or a space character in the input.
+func (s *ss) hexByte() (b byte, ok bool) {
+ rune1 := s.getRune()
+ if rune1 == eof {
+ return
+ }
+ if unicode.IsSpace(rune1) {
+ s.UnreadRune()
+ return
+ }
+ rune2 := s.mustReadRune()
+ return byte(s.hexDigit(rune1)<<4 | s.hexDigit(rune2)), true
+}
+
+// hexString returns the space-delimited hexpair-encoded string.
+func (s *ss) hexString() string {
+ for {
+ b, ok := s.hexByte()
+ if !ok {
+ break
+ }
+ s.buf.WriteByte(b)
+ }
+ if s.buf.Len() == 0 {
+ s.errorString("Scan: no hex data for %x string")
+ return ""
+ }
+ return s.buf.String()
+}
+
+const floatVerbs = "beEfFgGv"
+
+const hugeWid = 1 << 30
+
+// scanOne scans a single value, deriving the scanner from the type of the argument.
+func (s *ss) scanOne(verb int, field interface{}) {
+ s.buf.Reset()
+ var err os.Error
+ // If the parameter has its own Scan method, use that.
+ if v, ok := field.(Scanner); ok {
+ err = v.Scan(s, verb)
+ if err != nil {
+ if err == os.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ s.error(err)
+ }
+ return
+ }
+ switch v := field.(type) {
+ case *bool:
+ *v = s.scanBool(verb)
+ case *complex64:
+ *v = complex64(s.scanComplex(verb, 64))
+ case *complex128:
+ *v = s.scanComplex(verb, 128)
+ case *int:
+ *v = int(s.scanInt(verb, intBits))
+ case *int8:
+ *v = int8(s.scanInt(verb, 8))
+ case *int16:
+ *v = int16(s.scanInt(verb, 16))
+ case *int32:
+ *v = int32(s.scanInt(verb, 32))
+ case *int64:
+ *v = s.scanInt(verb, 64)
+ case *uint:
+ *v = uint(s.scanUint(verb, intBits))
+ case *uint8:
+ *v = uint8(s.scanUint(verb, 8))
+ case *uint16:
+ *v = uint16(s.scanUint(verb, 16))
+ case *uint32:
+ *v = uint32(s.scanUint(verb, 32))
+ case *uint64:
+ *v = s.scanUint(verb, 64)
+ case *uintptr:
+ *v = uintptr(s.scanUint(verb, uintptrBits))
+ // Floats are tricky because you want to scan in the precision of the result, not
+ // scan in high precision and convert, in order to preserve the correct error condition.
+ case *float32:
+ if s.okVerb(verb, floatVerbs, "float32") {
+ s.skipSpace(false)
+ *v = float32(s.convertFloat(s.floatToken(), 32))
+ }
+ case *float64:
+ if s.okVerb(verb, floatVerbs, "float64") {
+ s.skipSpace(false)
+ *v = s.convertFloat(s.floatToken(), 64)
+ }
+ case *string:
+ *v = s.convertString(verb)
+ case *[]byte:
+ // We scan to string and convert so we get a copy of the data.
+ // If we scanned to bytes, the slice would point at the buffer.
+ *v = []byte(s.convertString(verb))
+ default:
+ val := reflect.NewValue(v)
+ ptr, ok := val.(*reflect.PtrValue)
+ if !ok {
+ s.errorString("Scan: type not a pointer: " + val.Type().String())
+ return
+ }
+ switch v := ptr.Elem().(type) {
+ case *reflect.BoolValue:
+ v.Set(s.scanBool(verb))
+ case *reflect.IntValue:
+ v.Set(s.scanInt(verb, v.Type().Bits()))
+ case *reflect.UintValue:
+ v.Set(s.scanUint(verb, v.Type().Bits()))
+ case *reflect.StringValue:
+ v.Set(s.convertString(verb))
+ case *reflect.SliceValue:
+ // For now, can only handle (renamed) []byte.
+ typ := v.Type().(*reflect.SliceType)
+ if typ.Elem().Kind() != reflect.Uint8 {
+ goto CantHandle
+ }
+ str := s.convertString(verb)
+ v.Set(reflect.MakeSlice(typ, len(str), len(str)))
+ for i := 0; i < len(str); i++ {
+ v.Elem(i).(*reflect.UintValue).Set(uint64(str[i]))
+ }
+ case *reflect.FloatValue:
+ s.skipSpace(false)
+ v.Set(s.convertFloat(s.floatToken(), v.Type().Bits()))
+ case *reflect.ComplexValue:
+ v.Set(s.scanComplex(verb, v.Type().Bits()))
+ default:
+ CantHandle:
+ s.errorString("Scan: can't handle type: " + val.Type().String())
+ }
+ }
+}
+
+// errorHandler turns local panics into error returns. EOFs are benign.
+func errorHandler(errp *os.Error) {
+ if e := recover(); e != nil {
+ if se, ok := e.(scanError); ok { // catch local error
+ if se.err != os.EOF {
+ *errp = se.err
+ }
+ } else {
+ panic(e)
+ }
+ }
+}
+
+// doScan does the real work for scanning without a format string.
+func (s *ss) doScan(a []interface{}) (numProcessed int, err os.Error) {
+ defer errorHandler(&err)
+ for _, field := range a {
+ s.scanOne('v', field)
+ numProcessed++
+ }
+ // Check for newline if required.
+ if !s.nlIsSpace {
+ for {
+ rune := s.getRune()
+ if rune == '\n' || rune == eof {
+ break
+ }
+ if !unicode.IsSpace(rune) {
+ s.errorString("Scan: expected newline")
+ break
+ }
+ }
+ }
+ return
+}
+
+// advance determines whether the next characters in the input match
+// those of the format. It returns the number of bytes (sic) consumed
+// in the format. Newlines included, all runs of space characters in
+// either input or format behave as a single space. This routine also
+// handles the %% case. If the return value is zero, either format
+// starts with a % (with no following %) or the input is empty.
+// If it is negative, the input did not match the string.
+func (s *ss) advance(format string) (i int) {
+ for i < len(format) {
+ fmtc, w := utf8.DecodeRuneInString(format[i:])
+ if fmtc == '%' {
+ // %% acts like a real percent
+ nextc, _ := utf8.DecodeRuneInString(format[i+w:]) // will not match % if string is empty
+ if nextc != '%' {
+ return
+ }
+ i += w // skip the first %
+ }
+ sawSpace := false
+ for unicode.IsSpace(fmtc) && i < len(format) {
+ sawSpace = true
+ i += w
+ fmtc, w = utf8.DecodeRuneInString(format[i:])
+ }
+ if sawSpace {
+ // There was space in the format, so there should be space (EOF)
+ // in the input.
+ inputc := s.getRune()
+ if inputc == eof {
+ return
+ }
+ if !unicode.IsSpace(inputc) {
+ // Space in format but not in input: error
+ s.errorString("expected space in input to match format")
+ }
+ s.skipSpace(true)
+ continue
+ }
+ inputc := s.mustReadRune()
+ if fmtc != inputc {
+ s.UnreadRune()
+ return -1
+ }
+ i += w
+ }
+ return
+}
+
+// doScanf does the real work when scanning with a format string.
+// At the moment, it handles only pointers to basic types.
+func (s *ss) doScanf(format string, a []interface{}) (numProcessed int, err os.Error) {
+ defer errorHandler(&err)
+ end := len(format) - 1
+ // We process one item per non-trivial format
+ for i := 0; i <= end; {
+ w := s.advance(format[i:])
+ if w > 0 {
+ i += w
+ continue
+ }
+ // Either we failed to advance, we have a percent character, or we ran out of input.
+ if format[i] != '%' {
+ // Can't advance format. Why not?
+ if w < 0 {
+ s.errorString("input does not match format")
+ }
+ // Otherwise at EOF; "too many operands" error handled below
+ break
+ }
+ i++ // % is one byte
+
+ // do we have 20 (width)?
+ var widPresent bool
+ s.maxWid, widPresent, i = parsenum(format, i, end)
+ if !widPresent {
+ s.maxWid = hugeWid
+ }
+ s.fieldLimit = s.limit
+ if f := s.count + s.maxWid; f < s.fieldLimit {
+ s.fieldLimit = f
+ }
+
+ c, w := utf8.DecodeRuneInString(format[i:])
+ i += w
+
+ if numProcessed >= len(a) { // out of operands
+ s.errorString("too few operands for format %" + format[i-w:])
+ break
+ }
+ field := a[numProcessed]
+
+ s.scanOne(c, field)
+ numProcessed++
+ s.fieldLimit = s.limit
+ }
+ if numProcessed < len(a) {
+ s.errorString("too many operands")
+ }
+ return
+}
diff --git a/src/cmd/fix/testdata/reflect.scan.go.out b/src/cmd/fix/testdata/reflect.scan.go.out
new file mode 100644
index 000000000..956c13bde
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.scan.go.out
@@ -0,0 +1,1082 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt
+
+import (
+ "bytes"
+ "io"
+ "math"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// runeUnreader is the interface to something that can unread runes.
+// If the object provided to Scan does not satisfy this interface,
+// a local buffer will be used to back up the input, but its contents
+// will be lost when Scan returns.
+type runeUnreader interface {
+ UnreadRune() os.Error
+}
+
+// ScanState represents the scanner state passed to custom scanners.
+// Scanners may do rune-at-a-time scanning or ask the ScanState
+// to discover the next space-delimited token.
+type ScanState interface {
+ // ReadRune reads the next rune (Unicode code point) from the input.
+ // If invoked during Scanln, Fscanln, or Sscanln, ReadRune() will
+ // return EOF after returning the first '\n' or when reading beyond
+ // the specified width.
+ ReadRune() (rune int, size int, err os.Error)
+ // UnreadRune causes the next call to ReadRune to return the same rune.
+ UnreadRune() os.Error
+ // Token skips space in the input if skipSpace is true, then returns the
+ // run of Unicode code points c satisfying f(c). If f is nil,
+ // !unicode.IsSpace(c) is used; that is, the token will hold non-space
+ // characters. Newlines are treated as space unless the scan operation
+ // is Scanln, Fscanln or Sscanln, in which case a newline is treated as
+ // EOF. The returned slice points to shared data that may be overwritten
+ // by the next call to Token, a call to a Scan function using the ScanState
+ // as input, or when the calling Scan method returns.
+ Token(skipSpace bool, f func(int) bool) (token []byte, err os.Error)
+ // Width returns the value of the width option and whether it has been set.
+ // The unit is Unicode code points.
+ Width() (wid int, ok bool)
+ // Because ReadRune is implemented by the interface, Read should never be
+ // called by the scanning routines and a valid implementation of
+ // ScanState may choose always to return an error from Read.
+ Read(buf []byte) (n int, err os.Error)
+}
+
+// Scanner is implemented by any value that has a Scan method, which scans
+// the input for the representation of a value and stores the result in the
+// receiver, which must be a pointer to be useful. The Scan method is called
+// for any argument to Scan, Scanf, or Scanln that implements it.
+type Scanner interface {
+ Scan(state ScanState, verb int) os.Error
+}
+
+// Scan scans text read from standard input, storing successive
+// space-separated values into successive arguments. Newlines count
+// as space. It returns the number of items successfully scanned.
+// If that is less than the number of arguments, err will report why.
+func Scan(a ...interface{}) (n int, err os.Error) {
+ return Fscan(os.Stdin, a...)
+}
+
+// Scanln is similar to Scan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Scanln(a ...interface{}) (n int, err os.Error) {
+ return Fscanln(os.Stdin, a...)
+}
+
+// Scanf scans text read from standard input, storing successive
+// space-separated values into successive arguments as determined by
+// the format. It returns the number of items successfully scanned.
+func Scanf(format string, a ...interface{}) (n int, err os.Error) {
+ return Fscanf(os.Stdin, format, a...)
+}
+
+// Sscan scans the argument string, storing successive space-separated
+// values into successive arguments. Newlines count as space. It
+// returns the number of items successfully scanned. If that is less
+// than the number of arguments, err will report why.
+func Sscan(str string, a ...interface{}) (n int, err os.Error) {
+ return Fscan(strings.NewReader(str), a...)
+}
+
+// Sscanln is similar to Sscan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Sscanln(str string, a ...interface{}) (n int, err os.Error) {
+ return Fscanln(strings.NewReader(str), a...)
+}
+
+// Sscanf scans the argument string, storing successive space-separated
+// values into successive arguments as determined by the format. It
+// returns the number of items successfully parsed.
+func Sscanf(str string, format string, a ...interface{}) (n int, err os.Error) {
+ return Fscanf(strings.NewReader(str), format, a...)
+}
+
+// Fscan scans text read from r, storing successive space-separated
+// values into successive arguments. Newlines count as space. It
+// returns the number of items successfully scanned. If that is less
+// than the number of arguments, err will report why.
+func Fscan(r io.Reader, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, true, false)
+ n, err = s.doScan(a)
+ s.free(old)
+ return
+}
+
+// Fscanln is similar to Fscan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Fscanln(r io.Reader, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, false, true)
+ n, err = s.doScan(a)
+ s.free(old)
+ return
+}
+
+// Fscanf scans text read from r, storing successive space-separated
+// values into successive arguments as determined by the format. It
+// returns the number of items successfully parsed.
+func Fscanf(r io.Reader, format string, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, false, false)
+ n, err = s.doScanf(format, a)
+ s.free(old)
+ return
+}
+
+// scanError represents an error generated by the scanning software.
+// It's used as a unique signature to identify such errors when recovering.
+type scanError struct {
+ err os.Error
+}
+
+const eof = -1
+
+// ss is the internal implementation of ScanState.
+type ss struct {
+ rr io.RuneReader // where to read input
+ buf bytes.Buffer // token accumulator
+ peekRune int // one-rune lookahead
+ prevRune int // last rune returned by ReadRune
+ count int // runes consumed so far.
+ atEOF bool // already read EOF
+ ssave
+}
+
+// ssave holds the parts of ss that need to be
+// saved and restored on recursive scans.
+type ssave struct {
+ validSave bool // is or was a part of an actual ss.
+ nlIsEnd bool // whether newline terminates scan
+ nlIsSpace bool // whether newline counts as white space
+ fieldLimit int // max value of ss.count for this field; fieldLimit <= limit
+ limit int // max value of ss.count.
+ maxWid int // width of this field.
+}
+
+// The Read method is only in ScanState so that ScanState
+// satisfies io.Reader. It will never be called when used as
+// intended, so there is no need to make it actually work.
+func (s *ss) Read(buf []byte) (n int, err os.Error) {
+ return 0, os.NewError("ScanState's Read should not be called. Use ReadRune")
+}
+
+func (s *ss) ReadRune() (rune int, size int, err os.Error) {
+ if s.peekRune >= 0 {
+ s.count++
+ rune = s.peekRune
+ size = utf8.RuneLen(rune)
+ s.prevRune = rune
+ s.peekRune = -1
+ return
+ }
+ if s.atEOF || s.nlIsEnd && s.prevRune == '\n' || s.count >= s.fieldLimit {
+ err = os.EOF
+ return
+ }
+
+ rune, size, err = s.rr.ReadRune()
+ if err == nil {
+ s.count++
+ s.prevRune = rune
+ } else if err == os.EOF {
+ s.atEOF = true
+ }
+ return
+}
+
+func (s *ss) Width() (wid int, ok bool) {
+ if s.maxWid == hugeWid {
+ return 0, false
+ }
+ return s.maxWid, true
+}
+
+// The public method returns an error; this private one panics.
+// If getRune reaches EOF, the return value is EOF (-1).
+func (s *ss) getRune() (rune int) {
+ rune, _, err := s.ReadRune()
+ if err != nil {
+ if err == os.EOF {
+ return eof
+ }
+ s.error(err)
+ }
+ return
+}
+
+// mustReadRune turns os.EOF into a panic(io.ErrUnexpectedEOF).
+// It is called in cases such as string scanning where an EOF is a
+// syntax error.
+func (s *ss) mustReadRune() (rune int) {
+ rune = s.getRune()
+ if rune == eof {
+ s.error(io.ErrUnexpectedEOF)
+ }
+ return
+}
+
+func (s *ss) UnreadRune() os.Error {
+ if u, ok := s.rr.(runeUnreader); ok {
+ u.UnreadRune()
+ } else {
+ s.peekRune = s.prevRune
+ }
+ s.count--
+ return nil
+}
+
+func (s *ss) error(err os.Error) {
+ panic(scanError{err})
+}
+
+func (s *ss) errorString(err string) {
+ panic(scanError{os.NewError(err)})
+}
+
+func (s *ss) Token(skipSpace bool, f func(int) bool) (tok []byte, err os.Error) {
+ defer func() {
+ if e := recover(); e != nil {
+ if se, ok := e.(scanError); ok {
+ err = se.err
+ } else {
+ panic(e)
+ }
+ }
+ }()
+ if f == nil {
+ f = notSpace
+ }
+ s.buf.Reset()
+ tok = s.token(skipSpace, f)
+ return
+}
+
+// notSpace is the default scanning function used in Token.
+func notSpace(r int) bool {
+ return !unicode.IsSpace(r)
+}
+
+// readRune is a structure to enable reading UTF-8 encoded code points
+// from an io.Reader. It is used if the Reader given to the scanner does
+// not already implement io.RuneReader.
+type readRune struct {
+ reader io.Reader
+ buf [utf8.UTFMax]byte // used only inside ReadRune
+ pending int // number of bytes in pendBuf; only >0 for bad UTF-8
+ pendBuf [utf8.UTFMax]byte // bytes left over
+}
+
+// readByte returns the next byte from the input, which may be
+// left over from a previous read if the UTF-8 was ill-formed.
+func (r *readRune) readByte() (b byte, err os.Error) {
+ if r.pending > 0 {
+ b = r.pendBuf[0]
+ copy(r.pendBuf[0:], r.pendBuf[1:])
+ r.pending--
+ return
+ }
+ _, err = r.reader.Read(r.pendBuf[0:1])
+ return r.pendBuf[0], err
+}
+
+// unread saves the bytes for the next read.
+func (r *readRune) unread(buf []byte) {
+ copy(r.pendBuf[r.pending:], buf)
+ r.pending += len(buf)
+}
+
+// ReadRune returns the next UTF-8 encoded code point from the
+// io.Reader inside r.
+func (r *readRune) ReadRune() (rune int, size int, err os.Error) {
+ r.buf[0], err = r.readByte()
+ if err != nil {
+ return 0, 0, err
+ }
+ if r.buf[0] < utf8.RuneSelf { // fast check for common ASCII case
+ rune = int(r.buf[0])
+ return
+ }
+ var n int
+ for n = 1; !utf8.FullRune(r.buf[0:n]); n++ {
+ r.buf[n], err = r.readByte()
+ if err != nil {
+ if err == os.EOF {
+ err = nil
+ break
+ }
+ return
+ }
+ }
+ rune, size = utf8.DecodeRune(r.buf[0:n])
+ if size < n { // an error
+ r.unread(r.buf[size:n])
+ }
+ return
+}
+
+var ssFree = newCache(func() interface{} { return new(ss) })
+
+// Allocate a new ss struct or grab a cached one.
+func newScanState(r io.Reader, nlIsSpace, nlIsEnd bool) (s *ss, old ssave) {
+ // If the reader is a *ss, then we've got a recursive
+ // call to Scan, so re-use the scan state.
+ s, ok := r.(*ss)
+ if ok {
+ old = s.ssave
+ s.limit = s.fieldLimit
+ s.nlIsEnd = nlIsEnd || s.nlIsEnd
+ s.nlIsSpace = nlIsSpace
+ return
+ }
+
+ s = ssFree.get().(*ss)
+ if rr, ok := r.(io.RuneReader); ok {
+ s.rr = rr
+ } else {
+ s.rr = &readRune{reader: r}
+ }
+ s.nlIsSpace = nlIsSpace
+ s.nlIsEnd = nlIsEnd
+ s.prevRune = -1
+ s.peekRune = -1
+ s.atEOF = false
+ s.limit = hugeWid
+ s.fieldLimit = hugeWid
+ s.maxWid = hugeWid
+ s.validSave = true
+ return
+}
+
+// Save used ss structs in ssFree; avoid an allocation per invocation.
+func (s *ss) free(old ssave) {
+ // If it was used recursively, just restore the old state.
+ if old.validSave {
+ s.ssave = old
+ return
+ }
+ // Don't hold on to ss structs with large buffers.
+ if cap(s.buf.Bytes()) > 1024 {
+ return
+ }
+ s.buf.Reset()
+ s.rr = nil
+ ssFree.put(s)
+}
+
+// skipSpace skips spaces and maybe newlines.
+func (s *ss) skipSpace(stopAtNewline bool) {
+ for {
+ rune := s.getRune()
+ if rune == eof {
+ return
+ }
+ if rune == '\n' {
+ if stopAtNewline {
+ break
+ }
+ if s.nlIsSpace {
+ continue
+ }
+ s.errorString("unexpected newline")
+ return
+ }
+ if !unicode.IsSpace(rune) {
+ s.UnreadRune()
+ break
+ }
+ }
+}
+
+// token returns the next space-delimited string from the input. It
+// skips white space. For Scanln, it stops at newlines. For Scan,
+// newlines are treated as spaces.
+func (s *ss) token(skipSpace bool, f func(int) bool) []byte {
+ if skipSpace {
+ s.skipSpace(false)
+ }
+ // read until white space or newline
+ for {
+ rune := s.getRune()
+ if rune == eof {
+ break
+ }
+ if !f(rune) {
+ s.UnreadRune()
+ break
+ }
+ s.buf.WriteRune(rune)
+ }
+ return s.buf.Bytes()
+}
+
+// typeError indicates that the type of the operand did not match the format
+func (s *ss) typeError(field interface{}, expected string) {
+ s.errorString("expected field of type pointer to " + expected + "; found " + reflect.TypeOf(field).String())
+}
+
+var complexError = os.NewError("syntax error scanning complex number")
+var boolError = os.NewError("syntax error scanning boolean")
+
+// consume reads the next rune in the input and reports whether it is in the ok string.
+// If accept is true, it puts the character into the input token.
+func (s *ss) consume(ok string, accept bool) bool {
+ rune := s.getRune()
+ if rune == eof {
+ return false
+ }
+ if strings.IndexRune(ok, rune) >= 0 {
+ if accept {
+ s.buf.WriteRune(rune)
+ }
+ return true
+ }
+ if rune != eof && accept {
+ s.UnreadRune()
+ }
+ return false
+}
+
+// peek reports whether the next character is in the ok string, without consuming it.
+func (s *ss) peek(ok string) bool {
+ rune := s.getRune()
+ if rune != eof {
+ s.UnreadRune()
+ }
+ return strings.IndexRune(ok, rune) >= 0
+}
+
+// accept checks the next rune in the input. If it's a byte (sic) in the string, it puts it in the
+// buffer and returns true. Otherwise it returns false.
+func (s *ss) accept(ok string) bool {
+ return s.consume(ok, true)
+}
+
+// okVerb verifies that the verb is present in the list, setting s.err appropriately if not.
+func (s *ss) okVerb(verb int, okVerbs, typ string) bool {
+ for _, v := range okVerbs {
+ if v == verb {
+ return true
+ }
+ }
+ s.errorString("bad verb %" + string(verb) + " for " + typ)
+ return false
+}
+
+// scanBool returns the value of the boolean represented by the next token.
+func (s *ss) scanBool(verb int) bool {
+ if !s.okVerb(verb, "tv", "boolean") {
+ return false
+ }
+ // Syntax-checking a boolean is annoying. We're not fastidious about case.
+ switch s.mustReadRune() {
+ case '0':
+ return false
+ case '1':
+ return true
+ case 't', 'T':
+ if s.accept("rR") && (!s.accept("uU") || !s.accept("eE")) {
+ s.error(boolError)
+ }
+ return true
+ case 'f', 'F':
+ if s.accept("aL") && (!s.accept("lL") || !s.accept("sS") || !s.accept("eE")) {
+ s.error(boolError)
+ }
+ return false
+ }
+ return false
+}
+
+// Numerical elements
+const (
+ binaryDigits = "01"
+ octalDigits = "01234567"
+ decimalDigits = "0123456789"
+ hexadecimalDigits = "0123456789aAbBcCdDeEfF"
+ sign = "+-"
+ period = "."
+ exponent = "eEp"
+)
+
+// getBase returns the numeric base represented by the verb and its digit string.
+func (s *ss) getBase(verb int) (base int, digits string) {
+ s.okVerb(verb, "bdoUxXv", "integer") // sets s.err
+ base = 10
+ digits = decimalDigits
+ switch verb {
+ case 'b':
+ base = 2
+ digits = binaryDigits
+ case 'o':
+ base = 8
+ digits = octalDigits
+ case 'x', 'X', 'U':
+ base = 16
+ digits = hexadecimalDigits
+ }
+ return
+}
+
+// scanNumber returns the numerical string with specified digits starting here.
+func (s *ss) scanNumber(digits string, haveDigits bool) string {
+ if !haveDigits && !s.accept(digits) {
+ s.errorString("expected integer")
+ }
+ for s.accept(digits) {
+ }
+ return s.buf.String()
+}
+
+// scanRune returns the next rune value in the input.
+func (s *ss) scanRune(bitSize int) int64 {
+ rune := int64(s.mustReadRune())
+ n := uint(bitSize)
+ x := (rune << (64 - n)) >> (64 - n)
+ if x != rune {
+ s.errorString("overflow on character value " + string(rune))
+ }
+ return rune
+}
+
+// scanBasePrefix reports whether the integer begins with a 0 or 0x,
+// and returns the base, digit string, and whether a zero was found.
+// It is called only if the verb is %v.
+func (s *ss) scanBasePrefix() (base int, digits string, found bool) {
+ if !s.peek("0") {
+ return 10, decimalDigits, false
+ }
+ s.accept("0")
+ found = true // We've put a digit into the token buffer.
+ // Special cases for '0' && '0x'
+ base, digits = 8, octalDigits
+ if s.peek("xX") {
+ s.consume("xX", false)
+ base, digits = 16, hexadecimalDigits
+ }
+ return
+}
+
+// scanInt returns the value of the integer represented by the next
+// token, checking for overflow. Any error is stored in s.err.
+func (s *ss) scanInt(verb int, bitSize int) int64 {
+ if verb == 'c' {
+ return s.scanRune(bitSize)
+ }
+ s.skipSpace(false)
+ base, digits := s.getBase(verb)
+ haveDigits := false
+ if verb == 'U' {
+ if !s.consume("U", false) || !s.consume("+", false) {
+ s.errorString("bad unicode format ")
+ }
+ } else {
+ s.accept(sign) // If there's a sign, it will be left in the token buffer.
+ if verb == 'v' {
+ base, digits, haveDigits = s.scanBasePrefix()
+ }
+ }
+ tok := s.scanNumber(digits, haveDigits)
+ i, err := strconv.Btoi64(tok, base)
+ if err != nil {
+ s.error(err)
+ }
+ n := uint(bitSize)
+ x := (i << (64 - n)) >> (64 - n)
+ if x != i {
+ s.errorString("integer overflow on token " + tok)
+ }
+ return i
+}
+
+// scanUint returns the value of the unsigned integer represented
+// by the next token, checking for overflow. Any error is stored in s.err.
+func (s *ss) scanUint(verb int, bitSize int) uint64 {
+ if verb == 'c' {
+ return uint64(s.scanRune(bitSize))
+ }
+ s.skipSpace(false)
+ base, digits := s.getBase(verb)
+ haveDigits := false
+ if verb == 'U' {
+ if !s.consume("U", false) || !s.consume("+", false) {
+ s.errorString("bad unicode format ")
+ }
+ } else if verb == 'v' {
+ base, digits, haveDigits = s.scanBasePrefix()
+ }
+ tok := s.scanNumber(digits, haveDigits)
+ i, err := strconv.Btoui64(tok, base)
+ if err != nil {
+ s.error(err)
+ }
+ n := uint(bitSize)
+ x := (i << (64 - n)) >> (64 - n)
+ if x != i {
+ s.errorString("unsigned integer overflow on token " + tok)
+ }
+ return i
+}
+
+// floatToken returns the floating-point number starting here, no longer than the
+// specified width, if any. It's not rigorous about syntax because it doesn't check that
+// we have at least some digits, but Atof will do that.
+func (s *ss) floatToken() string {
+ s.buf.Reset()
+ // NaN?
+ if s.accept("nN") && s.accept("aA") && s.accept("nN") {
+ return s.buf.String()
+ }
+ // leading sign?
+ s.accept(sign)
+ // Inf?
+ if s.accept("iI") && s.accept("nN") && s.accept("fF") {
+ return s.buf.String()
+ }
+ // digits?
+ for s.accept(decimalDigits) {
+ }
+ // decimal point?
+ if s.accept(period) {
+ // fraction?
+ for s.accept(decimalDigits) {
+ }
+ }
+ // exponent?
+ if s.accept(exponent) {
+ // leading sign?
+ s.accept(sign)
+ // digits?
+ for s.accept(decimalDigits) {
+ }
+ }
+ return s.buf.String()
+}
+
+// complexTokens returns the real and imaginary parts of the complex number starting here.
+// The number might be parenthesized and has the format (N+Ni) where N is a floating-point
+// number and there are no spaces within.
+func (s *ss) complexTokens() (real, imag string) {
+ // TODO: accept N and Ni independently?
+ parens := s.accept("(")
+ real = s.floatToken()
+ s.buf.Reset()
+ // Must now have a sign.
+ if !s.accept("+-") {
+ s.error(complexError)
+ }
+ // Sign is now in buffer
+ imagSign := s.buf.String()
+ imag = s.floatToken()
+ if !s.accept("i") {
+ s.error(complexError)
+ }
+ if parens && !s.accept(")") {
+ s.error(complexError)
+ }
+ return real, imagSign + imag
+}
+
+// convertFloat converts the string to a float64 value.
+func (s *ss) convertFloat(str string, n int) float64 {
+ if p := strings.Index(str, "p"); p >= 0 {
+ // Atof doesn't handle power-of-2 exponents,
+ // but they're easy to evaluate.
+ f, err := strconv.AtofN(str[:p], n)
+ if err != nil {
+ // Put full string into error.
+ if e, ok := err.(*strconv.NumError); ok {
+ e.Num = str
+ }
+ s.error(err)
+ }
+ n, err := strconv.Atoi(str[p+1:])
+ if err != nil {
+ // Put full string into error.
+ if e, ok := err.(*strconv.NumError); ok {
+ e.Num = str
+ }
+ s.error(err)
+ }
+ return math.Ldexp(f, n)
+ }
+ f, err := strconv.AtofN(str, n)
+ if err != nil {
+ s.error(err)
+ }
+ return f
+}
+
+// scanComplex converts the next token to a complex128 value.
+// The real and imaginary parts are parsed with convertFloat at half the
+// requested bit size, so complex64 operands are read as float32s and
+// widened to float64, avoiding a separate copy of this code for each complex type.
+func (s *ss) scanComplex(verb int, n int) complex128 {
+ if !s.okVerb(verb, floatVerbs, "complex") {
+ return 0
+ }
+ s.skipSpace(false)
+ sreal, simag := s.complexTokens()
+ real := s.convertFloat(sreal, n/2)
+ imag := s.convertFloat(simag, n/2)
+ return complex(real, imag)
+}
+
+// convertString returns the string represented by the next input characters.
+// The format of the input is determined by the verb.
+func (s *ss) convertString(verb int) (str string) {
+ if !s.okVerb(verb, "svqx", "string") {
+ return ""
+ }
+ s.skipSpace(false)
+ switch verb {
+ case 'q':
+ str = s.quotedString()
+ case 'x':
+ str = s.hexString()
+ default:
+ str = string(s.token(true, notSpace)) // %s and %v just return the next word
+ }
+ // Empty strings other than with %q are not OK.
+ if len(str) == 0 && verb != 'q' && s.maxWid > 0 {
+ s.errorString("Scan: no data for string")
+ }
+ return
+}
+
+// quotedString returns the double- or back-quoted string represented by the next input characters.
+func (s *ss) quotedString() string {
+ quote := s.mustReadRune()
+ switch quote {
+ case '`':
+ // Back-quoted: Anything goes until EOF or back quote.
+ for {
+ rune := s.mustReadRune()
+ if rune == quote {
+ break
+ }
+ s.buf.WriteRune(rune)
+ }
+ return s.buf.String()
+ case '"':
+ // Double-quoted: Include the quotes and let strconv.Unquote do the backslash escapes.
+ s.buf.WriteRune(quote)
+ for {
+ rune := s.mustReadRune()
+ s.buf.WriteRune(rune)
+ if rune == '\\' {
+ // In a legal backslash escape, no matter how long, only the character
+ // immediately after the escape can itself be a backslash or quote.
+ // Thus we only need to protect the first character after the backslash.
+ rune := s.mustReadRune()
+ s.buf.WriteRune(rune)
+ } else if rune == '"' {
+ break
+ }
+ }
+ result, err := strconv.Unquote(s.buf.String())
+ if err != nil {
+ s.error(err)
+ }
+ return result
+ default:
+ s.errorString("expected quoted string")
+ }
+ return ""
+}
+
+// hexDigit returns the value of the hexadecimal digit
+func (s *ss) hexDigit(digit int) int {
+ switch digit {
+ case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+ return digit - '0'
+ case 'a', 'b', 'c', 'd', 'e', 'f':
+ return 10 + digit - 'a'
+ case 'A', 'B', 'C', 'D', 'E', 'F':
+ return 10 + digit - 'A'
+ }
+ s.errorString("Scan: illegal hex digit")
+ return 0
+}
+
+// hexByte returns the next hex-encoded (two-character) byte from the input.
+// There must be either two hexadecimal digits or a space character in the input.
+func (s *ss) hexByte() (b byte, ok bool) {
+ rune1 := s.getRune()
+ if rune1 == eof {
+ return
+ }
+ if unicode.IsSpace(rune1) {
+ s.UnreadRune()
+ return
+ }
+ rune2 := s.mustReadRune()
+ return byte(s.hexDigit(rune1)<<4 | s.hexDigit(rune2)), true
+}
+
+// hexString returns the space-delimited hexpair-encoded string.
+func (s *ss) hexString() string {
+ for {
+ b, ok := s.hexByte()
+ if !ok {
+ break
+ }
+ s.buf.WriteByte(b)
+ }
+ if s.buf.Len() == 0 {
+ s.errorString("Scan: no hex data for %x string")
+ return ""
+ }
+ return s.buf.String()
+}
+
+const floatVerbs = "beEfFgGv"
+
+const hugeWid = 1 << 30
+
+// scanOne scans a single value, deriving the scanner from the type of the argument.
+func (s *ss) scanOne(verb int, field interface{}) {
+ s.buf.Reset()
+ var err os.Error
+ // If the parameter has its own Scan method, use that.
+ if v, ok := field.(Scanner); ok {
+ err = v.Scan(s, verb)
+ if err != nil {
+ if err == os.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ s.error(err)
+ }
+ return
+ }
+ switch v := field.(type) {
+ case *bool:
+ *v = s.scanBool(verb)
+ case *complex64:
+ *v = complex64(s.scanComplex(verb, 64))
+ case *complex128:
+ *v = s.scanComplex(verb, 128)
+ case *int:
+ *v = int(s.scanInt(verb, intBits))
+ case *int8:
+ *v = int8(s.scanInt(verb, 8))
+ case *int16:
+ *v = int16(s.scanInt(verb, 16))
+ case *int32:
+ *v = int32(s.scanInt(verb, 32))
+ case *int64:
+ *v = s.scanInt(verb, 64)
+ case *uint:
+ *v = uint(s.scanUint(verb, intBits))
+ case *uint8:
+ *v = uint8(s.scanUint(verb, 8))
+ case *uint16:
+ *v = uint16(s.scanUint(verb, 16))
+ case *uint32:
+ *v = uint32(s.scanUint(verb, 32))
+ case *uint64:
+ *v = s.scanUint(verb, 64)
+ case *uintptr:
+ *v = uintptr(s.scanUint(verb, uintptrBits))
+ // Floats are tricky because you want to scan in the precision of the result, not
+ // scan in high precision and convert, in order to preserve the correct error condition.
+ case *float32:
+ if s.okVerb(verb, floatVerbs, "float32") {
+ s.skipSpace(false)
+ *v = float32(s.convertFloat(s.floatToken(), 32))
+ }
+ case *float64:
+ if s.okVerb(verb, floatVerbs, "float64") {
+ s.skipSpace(false)
+ *v = s.convertFloat(s.floatToken(), 64)
+ }
+ case *string:
+ *v = s.convertString(verb)
+ case *[]byte:
+ // We scan to string and convert so we get a copy of the data.
+ // If we scanned to bytes, the slice would point at the buffer.
+ *v = []byte(s.convertString(verb))
+ default:
+ val := reflect.ValueOf(v)
+ ptr := val
+ if ptr.Kind() != reflect.Ptr {
+ s.errorString("Scan: type not a pointer: " + val.Type().String())
+ return
+ }
+ switch v := ptr.Elem(); v.Kind() {
+ case reflect.Bool:
+ v.SetBool(s.scanBool(verb))
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ v.SetInt(s.scanInt(verb, v.Type().Bits()))
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ v.SetUint(s.scanUint(verb, v.Type().Bits()))
+ case reflect.String:
+ v.SetString(s.convertString(verb))
+ case reflect.Slice:
+ // For now, can only handle (renamed) []byte.
+ typ := v.Type()
+ if typ.Elem().Kind() != reflect.Uint8 {
+ goto CantHandle
+ }
+ str := s.convertString(verb)
+ v.Set(reflect.MakeSlice(typ, len(str), len(str)))
+ for i := 0; i < len(str); i++ {
+ v.Index(i).SetUint(uint64(str[i]))
+ }
+ case reflect.Float32, reflect.Float64:
+ s.skipSpace(false)
+ v.SetFloat(s.convertFloat(s.floatToken(), v.Type().Bits()))
+ case reflect.Complex64, reflect.Complex128:
+ v.SetComplex(s.scanComplex(verb, v.Type().Bits()))
+ default:
+ CantHandle:
+ s.errorString("Scan: can't handle type: " + val.Type().String())
+ }
+ }
+}
+
+// errorHandler turns local panics into error returns. EOFs are benign.
+func errorHandler(errp *os.Error) {
+ if e := recover(); e != nil {
+ if se, ok := e.(scanError); ok { // catch local error
+ if se.err != os.EOF {
+ *errp = se.err
+ }
+ } else {
+ panic(e)
+ }
+ }
+}
+
+// doScan does the real work for scanning without a format string.
+func (s *ss) doScan(a []interface{}) (numProcessed int, err os.Error) {
+ defer errorHandler(&err)
+ for _, field := range a {
+ s.scanOne('v', field)
+ numProcessed++
+ }
+ // Check for newline if required.
+ if !s.nlIsSpace {
+ for {
+ rune := s.getRune()
+ if rune == '\n' || rune == eof {
+ break
+ }
+ if !unicode.IsSpace(rune) {
+ s.errorString("Scan: expected newline")
+ break
+ }
+ }
+ }
+ return
+}
+
+// advance determines whether the next characters in the input match
+// those of the format. It returns the number of bytes (sic) consumed
+// in the format. Newlines included, all runs of space characters in
+// either input or format behave as a single space. This routine also
+// handles the %% case. If the return value is zero, either format
+// starts with a % (with no following %) or the input is empty.
+// If it is negative, the input did not match the string.
+func (s *ss) advance(format string) (i int) {
+ for i < len(format) {
+ fmtc, w := utf8.DecodeRuneInString(format[i:])
+ if fmtc == '%' {
+ // %% acts like a real percent
+ nextc, _ := utf8.DecodeRuneInString(format[i+w:]) // will not match % if string is empty
+ if nextc != '%' {
+ return
+ }
+ i += w // skip the first %
+ }
+ sawSpace := false
+ for unicode.IsSpace(fmtc) && i < len(format) {
+ sawSpace = true
+ i += w
+ fmtc, w = utf8.DecodeRuneInString(format[i:])
+ }
+ if sawSpace {
+ // There was space in the format, so there should be space (EOF)
+ // in the input.
+ inputc := s.getRune()
+ if inputc == eof {
+ return
+ }
+ if !unicode.IsSpace(inputc) {
+ // Space in format but not in input: error
+ s.errorString("expected space in input to match format")
+ }
+ s.skipSpace(true)
+ continue
+ }
+ inputc := s.mustReadRune()
+ if fmtc != inputc {
+ s.UnreadRune()
+ return -1
+ }
+ i += w
+ }
+ return
+}
+
+// doScanf does the real work when scanning with a format string.
+// At the moment, it handles only pointers to basic types.
+func (s *ss) doScanf(format string, a []interface{}) (numProcessed int, err os.Error) {
+ defer errorHandler(&err)
+ end := len(format) - 1
+ // We process one item per non-trivial format
+ for i := 0; i <= end; {
+ w := s.advance(format[i:])
+ if w > 0 {
+ i += w
+ continue
+ }
+ // Either we failed to advance, we have a percent character, or we ran out of input.
+ if format[i] != '%' {
+ // Can't advance format. Why not?
+ if w < 0 {
+ s.errorString("input does not match format")
+ }
+ // Otherwise at EOF; "too many operands" error handled below
+ break
+ }
+ i++ // % is one byte
+
+ // do we have 20 (width)?
+ var widPresent bool
+ s.maxWid, widPresent, i = parsenum(format, i, end)
+ if !widPresent {
+ s.maxWid = hugeWid
+ }
+ s.fieldLimit = s.limit
+ if f := s.count + s.maxWid; f < s.fieldLimit {
+ s.fieldLimit = f
+ }
+
+ c, w := utf8.DecodeRuneInString(format[i:])
+ i += w
+
+ if numProcessed >= len(a) { // out of operands
+ s.errorString("too few operands for format %" + format[i-w:])
+ break
+ }
+ field := a[numProcessed]
+
+ s.scanOne(c, field)
+ numProcessed++
+ s.fieldLimit = s.limit
+ }
+ if numProcessed < len(a) {
+ s.errorString("too many operands")
+ }
+ return
+}
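
The scanInt and scanUint functions above detect overflow by truncating the parsed 64-bit value to the requested bit size with a pair of shifts and comparing the result against the original. Below is a minimal standalone sketch of that check in present-day Go, using strconv.ParseInt in place of the pre-Go 1 strconv.Btoi64 that the test data calls; the function and variable names are illustrative only.

    package main

    import (
        "fmt"
        "strconv"
    )

    // fitsSigned parses tok in the given base and reports whether the result
    // fits in a signed integer of bitSize bits, using the same shift trick as
    // scanInt: sign-extend the low bitSize bits and compare with the original.
    func fitsSigned(tok string, base, bitSize int) (int64, bool) {
        i, err := strconv.ParseInt(tok, base, 64)
        if err != nil {
            return 0, false
        }
        n := uint(bitSize)
        x := (i << (64 - n)) >> (64 - n) // arithmetic shift sign-extends
        return i, x == i                 // equal iff the value fits in bitSize bits
    }

    func main() {
        fmt.Println(fitsSigned("127", 10, 8)) // 127 true
        fmt.Println(fitsSigned("128", 10, 8)) // 128 false: overflows int8
    }
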
diff --git a/src/cmd/fix/testdata/reflect.script.go.in b/src/cmd/fix/testdata/reflect.script.go.in
new file mode 100644
index 000000000..b341b1f89
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.script.go.in
@@ -0,0 +1,359 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package aids in the testing of code that uses channels.
+package script
+
+import (
+ "fmt"
+ "os"
+ "rand"
+ "reflect"
+ "strings"
+)
+
+// An Event is an element in a partially ordered set that either sends a value
+// to a channel or expects a value from a channel.
+type Event struct {
+ name string
+ occurred bool
+ predecessors []*Event
+ action action
+}
+
+type action interface {
+ // getSend returns nil if the action is not a send action.
+ getSend() sendAction
+ // getRecv returns nil if the action is not a receive action.
+ getRecv() recvAction
+ // getChannel returns the channel that the action operates on.
+ getChannel() interface{}
+}
+
+type recvAction interface {
+ recvMatch(interface{}) bool
+}
+
+type sendAction interface {
+ send()
+}
+
+// isReady returns true if all the predecessors of an Event have occurred.
+func (e Event) isReady() bool {
+ for _, predecessor := range e.predecessors {
+ if !predecessor.occurred {
+ return false
+ }
+ }
+
+ return true
+}
+
+// A Recv action reads a value from a channel and uses reflect.DeepEqual to
+// compare it with an expected value.
+type Recv struct {
+ Channel interface{}
+ Expected interface{}
+}
+
+func (r Recv) getRecv() recvAction { return r }
+
+func (Recv) getSend() sendAction { return nil }
+
+func (r Recv) getChannel() interface{} { return r.Channel }
+
+func (r Recv) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelRecv)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return reflect.DeepEqual(c.value, r.Expected)
+}
+
+// A RecvMatch action reads a value from a channel and calls a function to
+// determine if the value matches.
+type RecvMatch struct {
+ Channel interface{}
+ Match func(interface{}) bool
+}
+
+func (r RecvMatch) getRecv() recvAction { return r }
+
+func (RecvMatch) getSend() sendAction { return nil }
+
+func (r RecvMatch) getChannel() interface{} { return r.Channel }
+
+func (r RecvMatch) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelRecv)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return r.Match(c.value)
+}
+
+// A Closed action matches if the given channel is closed. The closing is
+// treated as an event, not a state, thus Closed will only match once for a
+// given channel.
+type Closed struct {
+ Channel interface{}
+}
+
+func (r Closed) getRecv() recvAction { return r }
+
+func (Closed) getSend() sendAction { return nil }
+
+func (r Closed) getChannel() interface{} { return r.Channel }
+
+func (r Closed) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelClosed)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return true
+}
+
+// A Send action sends a value to a channel. The value must match the
+// type of the channel exactly unless the channel is of type chan interface{}.
+type Send struct {
+ Channel interface{}
+ Value interface{}
+}
+
+func (Send) getRecv() recvAction { return nil }
+
+func (s Send) getSend() sendAction { return s }
+
+func (s Send) getChannel() interface{} { return s.Channel }
+
+type empty struct {
+ x interface{}
+}
+
+func newEmptyInterface(e empty) reflect.Value {
+ return reflect.NewValue(e).(*reflect.StructValue).Field(0)
+}
+
+func (s Send) send() {
+ // With reflect.ChanValue.Send, we must match the types exactly. So, if
+ // s.Channel is a chan interface{} we convert s.Value to an interface{}
+ // first.
+ c := reflect.NewValue(s.Channel).(*reflect.ChanValue)
+ var v reflect.Value
+ if iface, ok := c.Type().(*reflect.ChanType).Elem().(*reflect.InterfaceType); ok && iface.NumMethod() == 0 {
+ v = newEmptyInterface(empty{s.Value})
+ } else {
+ v = reflect.NewValue(s.Value)
+ }
+ c.Send(v)
+}
+
+// A Close action closes the given channel.
+type Close struct {
+ Channel interface{}
+}
+
+func (Close) getRecv() recvAction { return nil }
+
+func (s Close) getSend() sendAction { return s }
+
+func (s Close) getChannel() interface{} { return s.Channel }
+
+func (s Close) send() { reflect.NewValue(s.Channel).(*reflect.ChanValue).Close() }
+
+// A ReceivedUnexpected error results if no active Events match a value
+// received from a channel.
+type ReceivedUnexpected struct {
+ Value interface{}
+ ready []*Event
+}
+
+func (r ReceivedUnexpected) String() string {
+ names := make([]string, len(r.ready))
+ for i, v := range r.ready {
+ names[i] = v.name
+ }
+ return fmt.Sprintf("received unexpected value on one of the channels: %#v. Runnable events: %s", r.Value, strings.Join(names, ", "))
+}
+
+// A SetupError results if there is an error with the configuration of a set of
+// Events.
+type SetupError string
+
+func (s SetupError) String() string { return string(s) }
+
+func NewEvent(name string, predecessors []*Event, action action) *Event {
+ e := &Event{name, false, predecessors, action}
+ return e
+}
+
+// Given a set of Events, Perform repeatedly iterates over the set and finds the
+// subset of ready Events (that is, all of their predecessors have
+// occurred). From that subset, it pseudo-randomly selects an Event to perform.
+// If the Event is a send event, the send occurs and Perform recalculates the ready
+// set. If the event is a receive event, Perform waits for a value from any of the
+// channels that are contained in any of the events. That value is then matched
+// against the ready events. The first event that matches is considered to
+// have occurred and Perform recalculates the ready set.
+//
+// Perform continues this until all Events have occurred.
+//
+// Note that uncollected goroutines may still be reading from any of the
+// channels read from after Perform returns.
+//
+// For example, consider the problem of testing a function that reads values on
+// one channel and echoes them to two output channels. To test this we would
+// create three events: a send event and two receive events. Each of the
+// receive events must list the send event as a predecessor but there is no
+// ordering between the receive events.
+//
+// send := NewEvent("send", nil, Send{c, 1})
+// recv1 := NewEvent("recv 1", []*Event{send}, Recv{c, 1})
+// recv2 := NewEvent("recv 2", []*Event{send}, Recv{c, 1})
+// Perform(0, []*Event{send, recv1, recv2})
+//
+// At first, only the send event would be in the ready set and thus Perform will
+// send a value to the input channel. Now the two receive events are ready and
+// Perform will match each of them against the values read from the output channels.
+//
+// It would be invalid to list one of the receive events as a predecessor of
+// the other. At each receive step, all the receive channels are considered,
+// thus Perform may see a value from a channel that is not in the current ready
+// set and fail.
+func Perform(seed int64, events []*Event) (err os.Error) {
+ r := rand.New(rand.NewSource(seed))
+
+ channels, err := getChannels(events)
+ if err != nil {
+ return
+ }
+ multiplex := make(chan interface{})
+ for _, channel := range channels {
+ go recvValues(multiplex, channel)
+ }
+
+Outer:
+ for {
+ ready, err := readyEvents(events)
+ if err != nil {
+ return err
+ }
+
+ if len(ready) == 0 {
+ // All events occurred.
+ break
+ }
+
+ event := ready[r.Intn(len(ready))]
+ if send := event.action.getSend(); send != nil {
+ send.send()
+ event.occurred = true
+ continue
+ }
+
+ v := <-multiplex
+ for _, event := range ready {
+ if recv := event.action.getRecv(); recv != nil && recv.recvMatch(v) {
+ event.occurred = true
+ continue Outer
+ }
+ }
+
+ return ReceivedUnexpected{v, ready}
+ }
+
+ return nil
+}
+
+// getChannels returns all the channels listed in any receive events.
+func getChannels(events []*Event) ([]interface{}, os.Error) {
+ channels := make([]interface{}, len(events))
+
+ j := 0
+ for _, event := range events {
+ if recv := event.action.getRecv(); recv == nil {
+ continue
+ }
+ c := event.action.getChannel()
+ if _, ok := reflect.NewValue(c).(*reflect.ChanValue); !ok {
+ return nil, SetupError("one of the channel values is not a channel")
+ }
+
+ duplicate := false
+ for _, other := range channels[0:j] {
+ if c == other {
+ duplicate = true
+ break
+ }
+ }
+
+ if !duplicate {
+ channels[j] = c
+ j++
+ }
+ }
+
+ return channels[0:j], nil
+}
+
+// recvValues is a multiplexing helper function. It reads values from the given
+// channel repeatedly, wrapping them up as either a channelRecv or
+// channelClosed structure, and forwards them to the multiplex channel.
+func recvValues(multiplex chan<- interface{}, channel interface{}) {
+ c := reflect.NewValue(channel).(*reflect.ChanValue)
+
+ for {
+ v, ok := c.Recv()
+ if !ok {
+ multiplex <- channelClosed{channel}
+ return
+ }
+
+ multiplex <- channelRecv{channel, v.Interface()}
+ }
+}
+
+type channelClosed struct {
+ channel interface{}
+}
+
+type channelRecv struct {
+ channel interface{}
+ value interface{}
+}
+
+// readyEvents returns the subset of events that are ready.
+func readyEvents(events []*Event) ([]*Event, os.Error) {
+ ready := make([]*Event, len(events))
+
+ j := 0
+ eventsWaiting := false
+ for _, event := range events {
+ if event.occurred {
+ continue
+ }
+
+ eventsWaiting = true
+ if event.isReady() {
+ ready[j] = event
+ j++
+ }
+ }
+
+ if j == 0 && eventsWaiting {
+ names := make([]string, len(events))
+ for _, event := range events {
+ if event.occurred {
+ continue
+ }
+ names[j] = event.name
+ }
+
+ return nil, SetupError("dependency cycle in events. These events are waiting to run but cannot: " + strings.Join(names, ", "))
+ }
+
+ return ready[0:j], nil
+}
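
The .out file that follows shows what the reflect fix is expected to produce for this test input: expressions like reflect.NewValue(x).(*reflect.ChanValue) become reflect.ValueOf(x), and channel-ness is checked with Kind() == reflect.Chan. As a reference point, here is a small self-contained sketch of the Go 1 reflect calls the rewritten code relies on; the channel and variable names are illustrative, not part of the file.

    package main

    import (
        "fmt"
        "reflect"
    )

    func main() {
        ch := make(chan int, 1)
        var c interface{} = ch // the script package passes channels around as interface{}

        cv := reflect.ValueOf(c)               // Go 1 form of reflect.NewValue(c).(*reflect.ChanValue)
        fmt.Println(cv.Kind() == reflect.Chan) // true: the check getChannels uses in the .out file

        cv.Send(reflect.ValueOf(42)) // send through reflection, as Send.send does
        v, ok := cv.Recv()           // receive through reflection, as recvValues does
        fmt.Println(v.Interface(), ok) // 42 true
    }
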
diff --git a/src/cmd/fix/testdata/reflect.script.go.out b/src/cmd/fix/testdata/reflect.script.go.out
new file mode 100644
index 000000000..bc5a6a41d
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.script.go.out
@@ -0,0 +1,359 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package aids in the testing of code that uses channels.
+package script
+
+import (
+ "fmt"
+ "os"
+ "rand"
+ "reflect"
+ "strings"
+)
+
+// An Event is an element in a partially ordered set that either sends a value
+// to a channel or expects a value from a channel.
+type Event struct {
+ name string
+ occurred bool
+ predecessors []*Event
+ action action
+}
+
+type action interface {
+ // getSend returns nil if the action is not a send action.
+ getSend() sendAction
+ // getRecv returns nil if the action is not a receive action.
+ getRecv() recvAction
+ // getChannel returns the channel that the action operates on.
+ getChannel() interface{}
+}
+
+type recvAction interface {
+ recvMatch(interface{}) bool
+}
+
+type sendAction interface {
+ send()
+}
+
+// isReady returns true if all the predecessors of an Event have occurred.
+func (e Event) isReady() bool {
+ for _, predecessor := range e.predecessors {
+ if !predecessor.occurred {
+ return false
+ }
+ }
+
+ return true
+}
+
+// A Recv action reads a value from a channel and uses reflect.DeepEqual to
+// compare it with an expected value.
+type Recv struct {
+ Channel interface{}
+ Expected interface{}
+}
+
+func (r Recv) getRecv() recvAction { return r }
+
+func (Recv) getSend() sendAction { return nil }
+
+func (r Recv) getChannel() interface{} { return r.Channel }
+
+func (r Recv) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelRecv)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return reflect.DeepEqual(c.value, r.Expected)
+}
+
+// A RecvMatch action reads a value from a channel and calls a function to
+// determine if the value matches.
+type RecvMatch struct {
+ Channel interface{}
+ Match func(interface{}) bool
+}
+
+func (r RecvMatch) getRecv() recvAction { return r }
+
+func (RecvMatch) getSend() sendAction { return nil }
+
+func (r RecvMatch) getChannel() interface{} { return r.Channel }
+
+func (r RecvMatch) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelRecv)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return r.Match(c.value)
+}
+
+// A Closed action matches if the given channel is closed. The closing is
+// treated as an event, not a state, thus Closed will only match once for a
+// given channel.
+type Closed struct {
+ Channel interface{}
+}
+
+func (r Closed) getRecv() recvAction { return r }
+
+func (Closed) getSend() sendAction { return nil }
+
+func (r Closed) getChannel() interface{} { return r.Channel }
+
+func (r Closed) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelClosed)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return true
+}
+
+// A Send action sends a value to a channel. The value must match the
+// type of the channel exactly unless the channel is of type chan interface{}.
+type Send struct {
+ Channel interface{}
+ Value interface{}
+}
+
+func (Send) getRecv() recvAction { return nil }
+
+func (s Send) getSend() sendAction { return s }
+
+func (s Send) getChannel() interface{} { return s.Channel }
+
+type empty struct {
+ x interface{}
+}
+
+func newEmptyInterface(e empty) reflect.Value {
+ return reflect.ValueOf(e).Field(0)
+}
+
+func (s Send) send() {
+ // With reflect.ChanValue.Send, we must match the types exactly. So, if
+ // s.Channel is a chan interface{} we convert s.Value to an interface{}
+ // first.
+ c := reflect.ValueOf(s.Channel)
+ var v reflect.Value
+ if iface := c.Type().Elem(); iface.Kind() == reflect.Interface && iface.NumMethod() == 0 {
+ v = newEmptyInterface(empty{s.Value})
+ } else {
+ v = reflect.ValueOf(s.Value)
+ }
+ c.Send(v)
+}
+
+// A Close action closes the given channel.
+type Close struct {
+ Channel interface{}
+}
+
+func (Close) getRecv() recvAction { return nil }
+
+func (s Close) getSend() sendAction { return s }
+
+func (s Close) getChannel() interface{} { return s.Channel }
+
+func (s Close) send() { reflect.ValueOf(s.Channel).Close() }
+
+// A ReceivedUnexpected error results if no active Events match a value
+// received from a channel.
+type ReceivedUnexpected struct {
+ Value interface{}
+ ready []*Event
+}
+
+func (r ReceivedUnexpected) String() string {
+ names := make([]string, len(r.ready))
+ for i, v := range r.ready {
+ names[i] = v.name
+ }
+ return fmt.Sprintf("received unexpected value on one of the channels: %#v. Runnable events: %s", r.Value, strings.Join(names, ", "))
+}
+
+// A SetupError results if there is an error with the configuration of a set of
+// Events.
+type SetupError string
+
+func (s SetupError) String() string { return string(s) }
+
+func NewEvent(name string, predecessors []*Event, action action) *Event {
+ e := &Event{name, false, predecessors, action}
+ return e
+}
+
+// Given a set of Events, Perform repeatedly iterates over the set and finds the
+// subset of ready Events (that is, all of their predecessors have
+// occurred). From that subset, it pseudo-randomly selects an Event to perform.
+// If the Event is a send event, the send occurs and Perform recalculates the ready
+// set. If the event is a receive event, Perform waits for a value from any of the
+// channels that are contained in any of the events. That value is then matched
+// against the ready events. The first event that matches is considered to
+// have occurred and Perform recalculates the ready set.
+//
+// Perform continues this until all Events have occurred.
+//
+// Note that uncollected goroutines may still be reading from any of the
+// channels read from after Perform returns.
+//
+// For example, consider the problem of testing a function that reads values on
+// one channel and echoes them to two output channels. To test this we would
+// create three events: a send event and two receive events. Each of the
+// receive events must list the send event as a predecessor but there is no
+// ordering between the receive events.
+//
+// send := NewEvent("send", nil, Send{c, 1})
+// recv1 := NewEvent("recv 1", []*Event{send}, Recv{c, 1})
+// recv2 := NewEvent("recv 2", []*Event{send}, Recv{c, 1})
+// Perform(0, []*Event{send, recv1, recv2})
+//
+// At first, only the send event would be in the ready set and thus Perform will
+// send a value to the input channel. Now the two receive events are ready and
+// Perform will match each of them against the values read from the output channels.
+//
+// It would be invalid to list one of the receive events as a predecessor of
+// the other. At each receive step, all the receive channels are considered,
+// thus Perform may see a value from a channel that is not in the current ready
+// set and fail.
+func Perform(seed int64, events []*Event) (err os.Error) {
+ r := rand.New(rand.NewSource(seed))
+
+ channels, err := getChannels(events)
+ if err != nil {
+ return
+ }
+ multiplex := make(chan interface{})
+ for _, channel := range channels {
+ go recvValues(multiplex, channel)
+ }
+
+Outer:
+ for {
+ ready, err := readyEvents(events)
+ if err != nil {
+ return err
+ }
+
+ if len(ready) == 0 {
+ // All events occurred.
+ break
+ }
+
+ event := ready[r.Intn(len(ready))]
+ if send := event.action.getSend(); send != nil {
+ send.send()
+ event.occurred = true
+ continue
+ }
+
+ v := <-multiplex
+ for _, event := range ready {
+ if recv := event.action.getRecv(); recv != nil && recv.recvMatch(v) {
+ event.occurred = true
+ continue Outer
+ }
+ }
+
+ return ReceivedUnexpected{v, ready}
+ }
+
+ return nil
+}
+
+// getChannels returns all the channels listed in any receive events.
+func getChannels(events []*Event) ([]interface{}, os.Error) {
+ channels := make([]interface{}, len(events))
+
+ j := 0
+ for _, event := range events {
+ if recv := event.action.getRecv(); recv == nil {
+ continue
+ }
+ c := event.action.getChannel()
+ if reflect.ValueOf(c).Kind() != reflect.Chan {
+ return nil, SetupError("one of the channel values is not a channel")
+ }
+
+ duplicate := false
+ for _, other := range channels[0:j] {
+ if c == other {
+ duplicate = true
+ break
+ }
+ }
+
+ if !duplicate {
+ channels[j] = c
+ j++
+ }
+ }
+
+ return channels[0:j], nil
+}
+
+// recvValues is a multiplexing helper function. It reads values from the given
+// channel repeatedly, wrapping them up as either a channelRecv or
+// channelClosed structure, and forwards them to the multiplex channel.
+func recvValues(multiplex chan<- interface{}, channel interface{}) {
+ c := reflect.ValueOf(channel)
+
+ for {
+ v, ok := c.Recv()
+ if !ok {
+ multiplex <- channelClosed{channel}
+ return
+ }
+
+ multiplex <- channelRecv{channel, v.Interface()}
+ }
+}
+
+type channelClosed struct {
+ channel interface{}
+}
+
+type channelRecv struct {
+ channel interface{}
+ value interface{}
+}
+
+// readyEvents returns the subset of events that are ready.
+func readyEvents(events []*Event) ([]*Event, os.Error) {
+ ready := make([]*Event, len(events))
+
+ j := 0
+ eventsWaiting := false
+ for _, event := range events {
+ if event.occurred {
+ continue
+ }
+
+ eventsWaiting = true
+ if event.isReady() {
+ ready[j] = event
+ j++
+ }
+ }
+
+ if j == 0 && eventsWaiting {
+ names := make([]string, len(events))
+ for _, event := range events {
+ if event.occurred {
+ continue
+ }
+ names[j] = event.name
+ }
+
+ return nil, SetupError("dependency cycle in events. These events are waiting to run but cannot: " + strings.Join(names, ", "))
+ }
+
+ return ready[0:j], nil
+}
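
Both the .in and .out versions of the script package wait on an arbitrary, run-time-chosen set of channels by starting one recvValues goroutine per channel and funnelling everything into a single multiplex channel. Here is a standalone sketch of that pattern using the Go 1 reflect API; the struct and variable names are illustrative. The print order is not fully deterministic, and, as the Perform comment above notes, the goroutine watching b is still blocked in Recv when main returns.

    package main

    import (
        "fmt"
        "reflect"
    )

    type channelClosed struct{ channel interface{} }

    type channelRecv struct {
        channel interface{}
        value   interface{}
    }

    // recvValues forwards every value received on channel (held as interface{})
    // to multiplex, and sends a channelClosed marker when the channel closes.
    func recvValues(multiplex chan<- interface{}, channel interface{}) {
        c := reflect.ValueOf(channel)
        for {
            v, ok := c.Recv()
            if !ok {
                multiplex <- channelClosed{channel}
                return
            }
            multiplex <- channelRecv{channel, v.Interface()}
        }
    }

    func main() {
        a, b := make(chan int), make(chan string)
        multiplex := make(chan interface{})
        go recvValues(multiplex, a)
        go recvValues(multiplex, b)

        a <- 1
        b <- "hello"
        close(a)

        for i := 0; i < 3; i++ {
            // Prints the two received values and the close marker for a,
            // in an order chosen by the scheduler.
            fmt.Printf("%#v\n", <-multiplex)
        }
    }
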
diff --git a/src/cmd/fix/testdata/reflect.template.go.in b/src/cmd/fix/testdata/reflect.template.go.in
new file mode 100644
index 000000000..1f5a8128f
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.template.go.in
@@ -0,0 +1,1043 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ Data-driven templates for generating textual output such as
+ HTML.
+
+ Templates are executed by applying them to a data structure.
+ Annotations in the template refer to elements of the data
+ structure (typically a field of a struct or a key in a map)
+ to control execution and derive values to be displayed.
+ The template walks the structure as it executes and the
+ "cursor" @ represents the value at the current location
+ in the structure.
+
+ Data items may be values or pointers; the interface hides the
+ indirection.
+
+ In the following, 'field' is one of several things, according to the data.
+
+ - The name of a field of a struct (result = data.field),
+ - The value stored in a map under that key (result = data[field]), or
+ - The result of invoking a niladic single-valued method with that name
+ (result = data.field())
+
+ Major constructs ({} are the default delimiters for template actions;
+ [] are the notation in this comment for optional elements):
+
+ {# comment }
+
+ A one-line comment.
+
+ {.section field} XXX [ {.or} YYY ] {.end}
+
+ Set @ to the value of the field. It may be an explicit @
+ to stay at the same point in the data. If the field is nil
+ or empty, execute YYY; otherwise execute XXX.
+
+ {.repeated section field} XXX [ {.alternates with} ZZZ ] [ {.or} YYY ] {.end}
+
+ Like .section, but field must be an array or slice. XXX
+ is executed for each element. If the array is nil or empty,
+ YYY is executed instead. If the {.alternates with} marker
+ is present, ZZZ is executed between iterations of XXX.
+
+ {field}
+ {field1 field2 ...}
+ {field|formatter}
+ {field1 field2...|formatter}
+ {field|formatter1|formatter2}
+
+ Insert the value of the fields into the output. Each field is
+ first looked for in the cursor, as in .section and .repeated.
+ If it is not found, the search continues in outer sections
+ until the top level is reached.
+
+ If the field value is a pointer, leading asterisks indicate
+ that the value to be inserted should be evaluated through the
+ pointer. For example, if x.p is of type *int, {x.p} will
+ insert the value of the pointer but {*x.p} will insert the
+ value of the underlying integer. If the value is nil or not a
+ pointer, asterisks have no effect.
+
+ If a formatter is specified, it must be named in the formatter
+ map passed to the template set up routines or in the default
+ set ("html","str","") and is used to process the data for
+ output. The formatter function has signature
+ func(wr io.Writer, formatter string, data ...interface{})
+ where wr is the destination for output, data holds the field
+ values at the instantiation, and formatter is its name at
+ the invocation site. The default formatter just concatenates
+ the string representations of the fields.
+
+ Multiple formatters separated by the pipeline character | are
+ executed sequentially, with each formatter receiving the bytes
+ emitted by the one to its left.
+
+ The delimiter strings get their default value, "{" and "}", from
+ JSON-template. They may be set to any non-empty, space-free
+ string using the SetDelims method. Their value can be printed
+ in the output using {.meta-left} and {.meta-right}.
+*/
+package template
+
+import (
+ "bytes"
+ "container/vector"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "reflect"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// Errors returned during parsing and execution. Users may extract the information and reformat
+// if they desire.
+type Error struct {
+ Line int
+ Msg string
+}
+
+func (e *Error) String() string { return fmt.Sprintf("line %d: %s", e.Line, e.Msg) }
+
+// Most of the literals are aces.
+var lbrace = []byte{'{'}
+var rbrace = []byte{'}'}
+var space = []byte{' '}
+var tab = []byte{'\t'}
+
+// The various types of "tokens", which are plain text or (usually) brace-delimited descriptors
+const (
+ tokAlternates = iota
+ tokComment
+ tokEnd
+ tokLiteral
+ tokOr
+ tokRepeated
+ tokSection
+ tokText
+ tokVariable
+)
+
+// FormatterMap is the type describing the mapping from formatter
+// names to the functions that implement them.
+type FormatterMap map[string]func(io.Writer, string, ...interface{})
+
+// Built-in formatters.
+var builtins = FormatterMap{
+ "html": HTMLFormatter,
+ "str": StringFormatter,
+ "": StringFormatter,
+}
+
+// The parsed state of a template is a vector of xxxElement structs.
+// Sections have line numbers so errors can be reported better during execution.
+
+// Plain text.
+type textElement struct {
+ text []byte
+}
+
+// A literal such as .meta-left or .meta-right
+type literalElement struct {
+ text []byte
+}
+
+// A variable invocation to be evaluated
+type variableElement struct {
+ linenum int
+ word []string // The fields in the invocation.
+ fmts []string // Names of formatters to apply. len(fmts) > 0
+}
+
+// A .section block, possibly with a .or
+type sectionElement struct {
+ linenum int // of .section itself
+ field string // cursor field for this block
+ start int // first element
+ or int // first element of .or block
+ end int // one beyond last element
+}
+
+// A .repeated block, possibly with a .or and a .alternates
+type repeatedElement struct {
+ sectionElement // It has the same structure...
+ altstart int // ... except for alternates
+ altend int
+}
+
+// Template is the type that represents a template definition.
+// It is unchanged after parsing.
+type Template struct {
+ fmap FormatterMap // formatters for variables
+ // Used during parsing:
+ ldelim, rdelim []byte // delimiters; default {}
+ buf []byte // input text to process
+ p int // position in buf
+ linenum int // position in input
+ // Parsed results:
+ elems *vector.Vector
+}
+
+// Internal state for executing a Template. As we evaluate the struct,
+// the data item descends into the fields associated with sections, etc.
+// Parent is used to walk upwards to find variables higher in the tree.
+type state struct {
+ parent *state // parent in hierarchy
+ data reflect.Value // the driver data for this section etc.
+ wr io.Writer // where to send output
+ buf [2]bytes.Buffer // alternating buffers used when chaining formatters
+}
+
+func (parent *state) clone(data reflect.Value) *state {
+ return &state{parent: parent, data: data, wr: parent.wr}
+}
+
+// New creates a new template with the specified formatter map (which
+// may be nil) to define auxiliary functions for formatting variables.
+func New(fmap FormatterMap) *Template {
+ t := new(Template)
+ t.fmap = fmap
+ t.ldelim = lbrace
+ t.rdelim = rbrace
+ t.elems = new(vector.Vector)
+ return t
+}
+
+// Report error and stop executing. The line number must be provided explicitly.
+func (t *Template) execError(st *state, line int, err string, args ...interface{}) {
+ panic(&Error{line, fmt.Sprintf(err, args...)})
+}
+
+// Report error, panic to terminate parsing.
+// The line number comes from the template state.
+func (t *Template) parseError(err string, args ...interface{}) {
+ panic(&Error{t.linenum, fmt.Sprintf(err, args...)})
+}
+
+// Is this an exported - upper case - name?
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// -- Lexical analysis
+
+// Is c a white space character?
+func white(c uint8) bool { return c == ' ' || c == '\t' || c == '\r' || c == '\n' }
+
+// Safely, does s[n:n+len(t)] == t?
+func equal(s []byte, n int, t []byte) bool {
+ b := s[n:]
+ if len(t) > len(b) { // not enough space left for a match.
+ return false
+ }
+ for i, c := range t {
+ if c != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// nextItem returns the next item from the input buffer. If the returned
+// item is empty, we are at EOF. The item will be either a
+// delimited string or a non-empty string between delimited
+// strings. Tokens stop at (but include, if plain text) a newline.
+// Action tokens on a line by themselves drop any space on
+// either side, up to and including the newline.
+func (t *Template) nextItem() []byte {
+ startOfLine := t.p == 0 || t.buf[t.p-1] == '\n'
+ start := t.p
+ var i int
+ newline := func() {
+ t.linenum++
+ i++
+ }
+ // Leading white space up to but not including newline
+ for i = start; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' || !white(t.buf[i]) {
+ break
+ }
+ }
+ leadingSpace := i > start
+ // What's left is nothing, newline, delimited string, or plain text
+ switch {
+ case i == len(t.buf):
+ // EOF; nothing to do
+ case t.buf[i] == '\n':
+ newline()
+ case equal(t.buf, i, t.ldelim):
+ left := i // Start of left delimiter.
+ right := -1 // Will be (immediately after) right delimiter.
+ haveText := false // Delimiters contain text.
+ i += len(t.ldelim)
+ // Find the end of the action.
+ for ; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' {
+ break
+ }
+ if equal(t.buf, i, t.rdelim) {
+ i += len(t.rdelim)
+ right = i
+ break
+ }
+ haveText = true
+ }
+ if right < 0 {
+ t.parseError("unmatched opening delimiter")
+ return nil
+ }
+ // Is this a special action (starts with '.' or '#') and the only thing on the line?
+ if startOfLine && haveText {
+ firstChar := t.buf[left+len(t.ldelim)]
+ if firstChar == '.' || firstChar == '#' {
+ // It's special and the first thing on the line. Is it the last?
+ for j := right; j < len(t.buf) && white(t.buf[j]); j++ {
+ if t.buf[j] == '\n' {
+ // Yes it is. Drop the surrounding space and return the {.foo}
+ t.linenum++
+ t.p = j + 1
+ return t.buf[left:right]
+ }
+ }
+ }
+ }
+ // No it's not. If there's leading space, return that.
+ if leadingSpace {
+ // not trimming space: return leading white space if there is some.
+ t.p = left
+ return t.buf[start:left]
+ }
+ // Return the word, leave the trailing space.
+ start = left
+ break
+ default:
+ for ; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' {
+ newline()
+ break
+ }
+ if equal(t.buf, i, t.ldelim) {
+ break
+ }
+ }
+ }
+ item := t.buf[start:i]
+ t.p = i
+ return item
+}
+
+// Turn a byte array into a white-space-split array of strings.
+func words(buf []byte) []string {
+ s := make([]string, 0, 5)
+ p := 0 // position in buf
+ // one word per loop
+ for i := 0; ; i++ {
+ // skip white space
+ for ; p < len(buf) && white(buf[p]); p++ {
+ }
+ // grab word
+ start := p
+ for ; p < len(buf) && !white(buf[p]); p++ {
+ }
+ if start == p { // no text left
+ break
+ }
+ s = append(s, string(buf[start:p]))
+ }
+ return s
+}
+
+// Analyze an item and return its token type and, if it's an action item, an array of
+// its constituent words.
+func (t *Template) analyze(item []byte) (tok int, w []string) {
+ // item is known to be non-empty
+ if !equal(item, 0, t.ldelim) { // doesn't start with left delimiter
+ tok = tokText
+ return
+ }
+ if !equal(item, len(item)-len(t.rdelim), t.rdelim) { // doesn't end with right delimiter
+ t.parseError("internal error: unmatched opening delimiter") // lexing should prevent this
+ return
+ }
+ if len(item) <= len(t.ldelim)+len(t.rdelim) { // no contents
+ t.parseError("empty directive")
+ return
+ }
+ // Comment
+ if item[len(t.ldelim)] == '#' {
+ tok = tokComment
+ return
+ }
+ // Split into words
+ w = words(item[len(t.ldelim) : len(item)-len(t.rdelim)]) // drop final delimiter
+ if len(w) == 0 {
+ t.parseError("empty directive")
+ return
+ }
+ if len(w) > 0 && w[0][0] != '.' {
+ tok = tokVariable
+ return
+ }
+ switch w[0] {
+ case ".meta-left", ".meta-right", ".space", ".tab":
+ tok = tokLiteral
+ return
+ case ".or":
+ tok = tokOr
+ return
+ case ".end":
+ tok = tokEnd
+ return
+ case ".section":
+ if len(w) != 2 {
+ t.parseError("incorrect fields for .section: %s", item)
+ return
+ }
+ tok = tokSection
+ return
+ case ".repeated":
+ if len(w) != 3 || w[1] != "section" {
+ t.parseError("incorrect fields for .repeated: %s", item)
+ return
+ }
+ tok = tokRepeated
+ return
+ case ".alternates":
+ if len(w) != 2 || w[1] != "with" {
+ t.parseError("incorrect fields for .alternates: %s", item)
+ return
+ }
+ tok = tokAlternates
+ return
+ }
+ t.parseError("bad directive: %s", item)
+ return
+}
+
+// formatter returns the Formatter with the given name in the Template, or nil if none exists.
+func (t *Template) formatter(name string) func(io.Writer, string, ...interface{}) {
+ if t.fmap != nil {
+ if fn := t.fmap[name]; fn != nil {
+ return fn
+ }
+ }
+ return builtins[name]
+}
+
+// -- Parsing
+
+// Allocate a new variable-evaluation element.
+func (t *Template) newVariable(words []string) *variableElement {
+ // After the final space-separated argument, formatters may be specified separated
+ // by pipe symbols, for example: {a b c|d|e}
+
+ // Until we learn otherwise, formatters contains a single name: "", the default formatter.
+ formatters := []string{""}
+ lastWord := words[len(words)-1]
+ bar := strings.IndexRune(lastWord, '|')
+ if bar >= 0 {
+ words[len(words)-1] = lastWord[0:bar]
+ formatters = strings.Split(lastWord[bar+1:], "|")
+ }
+
+ // We could remember the function address here and avoid the lookup later,
+ // but it's more dynamic to let the user change the map contents underfoot.
+ // We do require the name to be present, though.
+
+ // Is it in user-supplied map?
+ for _, f := range formatters {
+ if t.formatter(f) == nil {
+ t.parseError("unknown formatter: %q", f)
+ }
+ }
+ return &variableElement{t.linenum, words, formatters}
+}
+
+// Grab the next item. If it's simple, just append it to the template.
+// Otherwise return its details.
+func (t *Template) parseSimple(item []byte) (done bool, tok int, w []string) {
+ tok, w = t.analyze(item)
+ done = true // assume for simplicity
+ switch tok {
+ case tokComment:
+ return
+ case tokText:
+ t.elems.Push(&textElement{item})
+ return
+ case tokLiteral:
+ switch w[0] {
+ case ".meta-left":
+ t.elems.Push(&literalElement{t.ldelim})
+ case ".meta-right":
+ t.elems.Push(&literalElement{t.rdelim})
+ case ".space":
+ t.elems.Push(&literalElement{space})
+ case ".tab":
+ t.elems.Push(&literalElement{tab})
+ default:
+ t.parseError("internal error: unknown literal: %s", w[0])
+ }
+ return
+ case tokVariable:
+ t.elems.Push(t.newVariable(w))
+ return
+ }
+ return false, tok, w
+}
+
+// parseRepeated and parseSection are mutually recursive
+
+func (t *Template) parseRepeated(words []string) *repeatedElement {
+ r := new(repeatedElement)
+ t.elems.Push(r)
+ r.linenum = t.linenum
+ r.field = words[2]
+ // Scan section, collecting true and false (.or) blocks.
+ r.start = t.elems.Len()
+ r.or = -1
+ r.altstart = -1
+ r.altend = -1
+Loop:
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ t.parseError("missing .end for .repeated section")
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokEnd:
+ break Loop
+ case tokOr:
+ if r.or >= 0 {
+ t.parseError("extra .or in .repeated section")
+ break Loop
+ }
+ r.altend = t.elems.Len()
+ r.or = t.elems.Len()
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ case tokAlternates:
+ if r.altstart >= 0 {
+ t.parseError("extra .alternates in .repeated section")
+ break Loop
+ }
+ if r.or >= 0 {
+ t.parseError(".alternates inside .or block in .repeated section")
+ break Loop
+ }
+ r.altstart = t.elems.Len()
+ default:
+ t.parseError("internal error: unknown repeated section item: %s", item)
+ break Loop
+ }
+ }
+ if r.altend < 0 {
+ r.altend = t.elems.Len()
+ }
+ r.end = t.elems.Len()
+ return r
+}
+
+func (t *Template) parseSection(words []string) *sectionElement {
+ s := new(sectionElement)
+ t.elems.Push(s)
+ s.linenum = t.linenum
+ s.field = words[1]
+ // Scan section, collecting true and false (.or) blocks.
+ s.start = t.elems.Len()
+ s.or = -1
+Loop:
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ t.parseError("missing .end for .section")
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokEnd:
+ break Loop
+ case tokOr:
+ if s.or >= 0 {
+ t.parseError("extra .or in .section")
+ break Loop
+ }
+ s.or = t.elems.Len()
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ case tokAlternates:
+ t.parseError(".alternates not in .repeated")
+ default:
+ t.parseError("internal error: unknown section item: %s", item)
+ }
+ }
+ s.end = t.elems.Len()
+ return s
+}
+
+func (t *Template) parse() {
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokOr, tokEnd, tokAlternates:
+ t.parseError("unexpected %s", w[0])
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ default:
+ t.parseError("internal error: bad directive in parse: %s", item)
+ }
+ }
+}
+
+// -- Execution
+
+// Evaluate interfaces and pointers looking for a value that can look up the name, via a
+// struct field, method, or map key, and return the result of the lookup.
+func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value {
+ for v != nil {
+ typ := v.Type()
+ if n := v.Type().NumMethod(); n > 0 {
+ for i := 0; i < n; i++ {
+ m := typ.Method(i)
+ mtyp := m.Type
+ if m.Name == name && mtyp.NumIn() == 1 && mtyp.NumOut() == 1 {
+ if !isExported(name) {
+ t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type())
+ }
+ return v.Method(i).Call(nil)[0]
+ }
+ }
+ }
+ switch av := v.(type) {
+ case *reflect.PtrValue:
+ v = av.Elem()
+ case *reflect.InterfaceValue:
+ v = av.Elem()
+ case *reflect.StructValue:
+ if !isExported(name) {
+ t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type())
+ }
+ return av.FieldByName(name)
+ case *reflect.MapValue:
+ if v := av.Elem(reflect.NewValue(name)); v != nil {
+ return v
+ }
+ return reflect.MakeZero(typ.(*reflect.MapType).Elem())
+ default:
+ return nil
+ }
+ }
+ return v
+}
+
+// indirectPtr returns the item numLevels levels of indirection below the value.
+// It is forgiving: if the value is not a pointer, it returns it rather than giving
+// an error. If the pointer is nil, it is returned as is.
+func indirectPtr(v reflect.Value, numLevels int) reflect.Value {
+ for i := numLevels; v != nil && i > 0; i++ {
+ if p, ok := v.(*reflect.PtrValue); ok {
+ if p.IsNil() {
+ return v
+ }
+ v = p.Elem()
+ } else {
+ break
+ }
+ }
+ return v
+}
+
+// Walk v through pointers and interfaces, extracting the elements within.
+func indirect(v reflect.Value) reflect.Value {
+loop:
+ for v != nil {
+ switch av := v.(type) {
+ case *reflect.PtrValue:
+ v = av.Elem()
+ case *reflect.InterfaceValue:
+ v = av.Elem()
+ default:
+ break loop
+ }
+ }
+ return v
+}
+
+// If the data for this template is a struct, find the named variable.
+// Names of the form a.b.c are walked down the data tree.
+// The special name "@" (the "cursor") denotes the current data.
+// The value coming in (st.data) might need indirecting to reach
+// a struct while the return value is not indirected - that is,
+// it represents the actual named field. Leading stars indicate
+// levels of indirection to be applied to the value.
+func (t *Template) findVar(st *state, s string) reflect.Value {
+ data := st.data
+ flattenedName := strings.TrimLeft(s, "*")
+ numStars := len(s) - len(flattenedName)
+ s = flattenedName
+ if s == "@" {
+ return indirectPtr(data, numStars)
+ }
+ for _, elem := range strings.Split(s, ".") {
+ // Look up field; data must be a struct or map.
+ data = t.lookup(st, data, elem)
+ if data == nil {
+ return nil
+ }
+ }
+ return indirectPtr(data, numStars)
+}
+
+// Is there no data to look at?
+func empty(v reflect.Value) bool {
+ v = indirect(v)
+ if v == nil {
+ return true
+ }
+ switch v := v.(type) {
+ case *reflect.BoolValue:
+ return v.Get() == false
+ case *reflect.StringValue:
+ return v.Get() == ""
+ case *reflect.StructValue:
+ return false
+ case *reflect.MapValue:
+ return false
+ case *reflect.ArrayValue:
+ return v.Len() == 0
+ case *reflect.SliceValue:
+ return v.Len() == 0
+ }
+ return false
+}
+
+// Look up a variable or method, up through the parent if necessary.
+func (t *Template) varValue(name string, st *state) reflect.Value {
+ field := t.findVar(st, name)
+ if field == nil {
+ if st.parent == nil {
+ t.execError(st, t.linenum, "name not found: %s in type %s", name, st.data.Type())
+ }
+ return t.varValue(name, st.parent)
+ }
+ return field
+}
+
+func (t *Template) format(wr io.Writer, fmt string, val []interface{}, v *variableElement, st *state) {
+ fn := t.formatter(fmt)
+ if fn == nil {
+ t.execError(st, v.linenum, "missing formatter %s for variable %s", fmt, v.word[0])
+ }
+ fn(wr, fmt, val...)
+}
+
+// Evaluate a variable, looking up through the parent if necessary.
+// If it has a formatter attached ({var|formatter}) run that too.
+func (t *Template) writeVariable(v *variableElement, st *state) {
+ // Turn the words of the invocation into values.
+ val := make([]interface{}, len(v.word))
+ for i, word := range v.word {
+ val[i] = t.varValue(word, st).Interface()
+ }
+
+ for i, fmt := range v.fmts[:len(v.fmts)-1] {
+ b := &st.buf[i&1]
+ b.Reset()
+ t.format(b, fmt, val, v, st)
+ val = val[0:1]
+ val[0] = b.Bytes()
+ }
+ t.format(st.wr, v.fmts[len(v.fmts)-1], val, v, st)
+}
+
+// Execute element i. Return next index to execute.
+func (t *Template) executeElement(i int, st *state) int {
+ switch elem := t.elems.At(i).(type) {
+ case *textElement:
+ st.wr.Write(elem.text)
+ return i + 1
+ case *literalElement:
+ st.wr.Write(elem.text)
+ return i + 1
+ case *variableElement:
+ t.writeVariable(elem, st)
+ return i + 1
+ case *sectionElement:
+ t.executeSection(elem, st)
+ return elem.end
+ case *repeatedElement:
+ t.executeRepeated(elem, st)
+ return elem.end
+ }
+ e := t.elems.At(i)
+ t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.NewValue(e).Interface(), e)
+ return 0
+}
+
+// Execute the template.
+func (t *Template) execute(start, end int, st *state) {
+ for i := start; i < end; {
+ i = t.executeElement(i, st)
+ }
+}
+
+// Execute a .section
+func (t *Template) executeSection(s *sectionElement, st *state) {
+ // Find driver data for this section. It must be in the current struct.
+ field := t.varValue(s.field, st)
+ if field == nil {
+ t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, st.data.Type())
+ }
+ st = st.clone(field)
+ start, end := s.start, s.or
+ if !empty(field) {
+ // Execute the normal block.
+ if end < 0 {
+ end = s.end
+ }
+ } else {
+ // Execute the .or block. If it's missing, do nothing.
+ start, end = s.or, s.end
+ if start < 0 {
+ return
+ }
+ }
+ for i := start; i < end; {
+ i = t.executeElement(i, st)
+ }
+}
+
+// Return the result of calling the Iter method on v, or nil.
+func iter(v reflect.Value) *reflect.ChanValue {
+ for j := 0; j < v.Type().NumMethod(); j++ {
+ mth := v.Type().Method(j)
+ fv := v.Method(j)
+ ft := fv.Type().(*reflect.FuncType)
+ // TODO(rsc): NumIn() should return 0 here, because ft is from a curried FuncValue.
+ if mth.Name != "Iter" || ft.NumIn() != 1 || ft.NumOut() != 1 {
+ continue
+ }
+ ct, ok := ft.Out(0).(*reflect.ChanType)
+ if !ok || ct.Dir()&reflect.RecvDir == 0 {
+ continue
+ }
+ return fv.Call(nil)[0].(*reflect.ChanValue)
+ }
+ return nil
+}
+
+// Execute a .repeated section
+func (t *Template) executeRepeated(r *repeatedElement, st *state) {
+ // Find driver data for this section. It must be in the current struct.
+ field := t.varValue(r.field, st)
+ if field == nil {
+ t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, st.data.Type())
+ }
+ field = indirect(field)
+
+ start, end := r.start, r.or
+ if end < 0 {
+ end = r.end
+ }
+ if r.altstart >= 0 {
+ end = r.altstart
+ }
+ first := true
+
+ // Code common to all the loops.
+ loopBody := func(newst *state) {
+ // .alternates between elements
+ if !first && r.altstart >= 0 {
+ for i := r.altstart; i < r.altend; {
+ i = t.executeElement(i, newst)
+ }
+ }
+ first = false
+ for i := start; i < end; {
+ i = t.executeElement(i, newst)
+ }
+ }
+
+ if array, ok := field.(reflect.ArrayOrSliceValue); ok {
+ for j := 0; j < array.Len(); j++ {
+ loopBody(st.clone(array.Elem(j)))
+ }
+ } else if m, ok := field.(*reflect.MapValue); ok {
+ for _, key := range m.Keys() {
+ loopBody(st.clone(m.Elem(key)))
+ }
+ } else if ch := iter(field); ch != nil {
+ for {
+ e, ok := ch.Recv()
+ if !ok {
+ break
+ }
+ loopBody(st.clone(e))
+ }
+ } else {
+ t.execError(st, r.linenum, ".repeated: cannot repeat %s (type %s)",
+ r.field, field.Type())
+ }
+
+ if first {
+ // Empty. Execute the .or block, once. If it's missing, do nothing.
+ start, end := r.or, r.end
+ if start >= 0 {
+ newst := st.clone(field)
+ for i := start; i < end; {
+ i = t.executeElement(i, newst)
+ }
+ }
+ return
+ }
+}
+
+// A valid delimiter must contain no white space and be non-empty.
+func validDelim(d []byte) bool {
+ if len(d) == 0 {
+ return false
+ }
+ for _, c := range d {
+ if white(c) {
+ return false
+ }
+ }
+ return true
+}
+
+// checkError is a deferred function to turn a panic with type *Error into a plain error return.
+// Other panics are unexpected and so are re-enabled.
+func checkError(error *os.Error) {
+ if v := recover(); v != nil {
+ if e, ok := v.(*Error); ok {
+ *error = e
+ } else {
+ // runtime errors should crash
+ panic(v)
+ }
+ }
+}
+
+// -- Public interface
+
+// Parse initializes a Template by parsing its definition. The string
+// s contains the template text. If any errors occur, Parse returns
+// the error.
+func (t *Template) Parse(s string) (err os.Error) {
+ if t.elems == nil {
+ return &Error{1, "template not allocated with New"}
+ }
+ if !validDelim(t.ldelim) || !validDelim(t.rdelim) {
+ return &Error{1, fmt.Sprintf("bad delimiter strings %q %q", t.ldelim, t.rdelim)}
+ }
+ defer checkError(&err)
+ t.buf = []byte(s)
+ t.p = 0
+ t.linenum = 1
+ t.parse()
+ return nil
+}
+
+// ParseFile is like Parse but reads the template definition from the
+// named file.
+func (t *Template) ParseFile(filename string) (err os.Error) {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return err
+ }
+ return t.Parse(string(b))
+}
+
+// Execute applies a parsed template to the specified data object,
+// generating output to wr.
+func (t *Template) Execute(wr io.Writer, data interface{}) (err os.Error) {
+ // Extract the driver data.
+ val := reflect.NewValue(data)
+ defer checkError(&err)
+ t.p = 0
+ t.execute(0, t.elems.Len(), &state{parent: nil, data: val, wr: wr})
+ return nil
+}
+
+// SetDelims sets the left and right delimiters for operations in the
+// template. They are validated during parsing. They could be
+// validated here but it's better to keep the routine simple. The
+// delimiters are very rarely invalid and Parse has the necessary
+// error-handling interface already.
+func (t *Template) SetDelims(left, right string) {
+ t.ldelim = []byte(left)
+ t.rdelim = []byte(right)
+}
+
+// Parse creates a Template with default parameters (such as {} for
+// metacharacters). The string s contains the template text while
+// the formatter map fmap, which may be nil, defines auxiliary functions
+// for formatting variables. The template is returned. If any errors
+// occur, err will be non-nil.
+func Parse(s string, fmap FormatterMap) (t *Template, err os.Error) {
+ t = New(fmap)
+ err = t.Parse(s)
+ if err != nil {
+ t = nil
+ }
+ return
+}
+
+// ParseFile is a wrapper function that creates a Template with default
+// parameters (such as {} for metacharacters). The filename identifies
+// a file containing the template text, while the formatter map fmap, which
+// may be nil, defines auxiliary functions for formatting variables.
+// The template is returned. If any errors occur, err will be non-nil.
+func ParseFile(filename string, fmap FormatterMap) (t *Template, err os.Error) {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+ return Parse(string(b), fmap)
+}
+
+// MustParse is like Parse but panics if the template cannot be parsed.
+func MustParse(s string, fmap FormatterMap) *Template {
+ t, err := Parse(s, fmap)
+ if err != nil {
+ panic("template.MustParse error: " + err.String())
+ }
+ return t
+}
+
+// MustParseFile is like ParseFile but panics if the file cannot be read
+// or the template cannot be parsed.
+func MustParseFile(filename string, fmap FormatterMap) *Template {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ panic("template.MustParseFile error: " + err.String())
+ }
+ return MustParse(string(b), fmap)
+}
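
The reflect.template.go.out hunk that follows is the expected result of running the reflect fix over the file above. As orientation only — this sketch is not part of the testdata, and the helper name describe is invented for illustration — the rewrite maps the pre-Go 1 value types used above (reflect.NewValue, *reflect.MapValue, Keys, Elem) onto Kind-based checks on a plain reflect.Value, as the fixed file below does in executeRepeated:

package main

import (
	"fmt"
	"reflect"
)

// describe walks a value the way executeRepeated does in the fixed file:
// Kind checks on reflect.Value replace type assertions on
// reflect.ArrayOrSliceValue, *reflect.MapValue, and friends.
func describe(data interface{}) {
	v := reflect.ValueOf(data) // was reflect.NewValue(data)
	switch v.Kind() {
	case reflect.Array, reflect.Slice:
		for i := 0; i < v.Len(); i++ {
			fmt.Println(v.Index(i).Interface()) // was array.Elem(i)
		}
	case reflect.Map:
		for _, k := range v.MapKeys() { // was m.Keys()
			fmt.Println(k.Interface(), v.MapIndex(k).Interface()) // was m.Elem(key)
		}
	default:
		fmt.Println(v.Interface())
	}
}

func main() {
	describe([]string{"a", "b"})
	describe(map[string]int{"x": 1})
}
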
diff --git a/src/cmd/fix/testdata/reflect.template.go.out b/src/cmd/fix/testdata/reflect.template.go.out
new file mode 100644
index 000000000..f2f56ef3c
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.template.go.out
@@ -0,0 +1,1044 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ Data-driven templates for generating textual output such as
+ HTML.
+
+ Templates are executed by applying them to a data structure.
+ Annotations in the template refer to elements of the data
+ structure (typically a field of a struct or a key in a map)
+ to control execution and derive values to be displayed.
+ The template walks the structure as it executes and the
+ "cursor" @ represents the value at the current location
+ in the structure.
+
+ Data items may be values or pointers; the interface hides the
+ indirection.
+
+ In the following, 'field' is one of several things, according to the data.
+
+ - The name of a field of a struct (result = data.field),
+ - The value stored in a map under that key (result = data[field]), or
+ - The result of invoking a niladic single-valued method with that name
+ (result = data.field())
+
+ Major constructs ({} are the default delimiters for template actions;
+ [] are the notation in this comment for optional elements):
+
+ {# comment }
+
+ A one-line comment.
+
+ {.section field} XXX [ {.or} YYY ] {.end}
+
+ Set @ to the value of the field. It may be an explicit @
+ to stay at the same point in the data. If the field is nil
+ or empty, execute YYY; otherwise execute XXX.
+
+ {.repeated section field} XXX [ {.alternates with} ZZZ ] [ {.or} YYY ] {.end}
+
+ Like .section, but field must be an array or slice. XXX
+ is executed for each element. If the array is nil or empty,
+ YYY is executed instead. If the {.alternates with} marker
+ is present, ZZZ is executed between iterations of XXX.
+
+ {field}
+ {field1 field2 ...}
+ {field|formatter}
+ {field1 field2...|formatter}
+ {field|formatter1|formatter2}
+
+ Insert the value of the fields into the output. Each field is
+ first looked for in the cursor, as in .section and .repeated.
+ If it is not found, the search continues in outer sections
+ until the top level is reached.
+
+ If the field value is a pointer, leading asterisks indicate
+ that the value to be inserted should be evaluated through the
+ pointer. For example, if x.p is of type *int, {x.p} will
+ insert the value of the pointer but {*x.p} will insert the
+ value of the underlying integer. If the value is nil or not a
+ pointer, asterisks have no effect.
+
+ If a formatter is specified, it must be named in the formatter
+ map passed to the template set up routines or in the default
+ set ("html","str","") and is used to process the data for
+ output. The formatter function has signature
+ func(wr io.Writer, formatter string, data ...interface{})
+ where wr is the destination for output, data holds the field
+ values at the instantiation, and formatter is its name at
+ the invocation site. The default formatter just concatenates
+ the string representations of the fields.
+
+ Multiple formatters separated by the pipeline character | are
+ executed sequentially, with each formatter receiving the bytes
+ emitted by the one to its left.
+
+ The delimiter strings get their default value, "{" and "}", from
+ JSON-template. They may be set to any non-empty, space-free
+ string using the SetDelims method. Their value can be printed
+ in the output using {.meta-left} and {.meta-right}.
+*/
+package template
+
+import (
+ "bytes"
+ "container/vector"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "reflect"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// Errors returned during parsing and execution. Users may extract the information and reformat
+// if they desire.
+type Error struct {
+ Line int
+ Msg string
+}
+
+func (e *Error) String() string { return fmt.Sprintf("line %d: %s", e.Line, e.Msg) }
+
+// Most of the literals are aces.
+var lbrace = []byte{'{'}
+var rbrace = []byte{'}'}
+var space = []byte{' '}
+var tab = []byte{'\t'}
+
+// The various types of "tokens", which are plain text or (usually) brace-delimited descriptors
+const (
+ tokAlternates = iota
+ tokComment
+ tokEnd
+ tokLiteral
+ tokOr
+ tokRepeated
+ tokSection
+ tokText
+ tokVariable
+)
+
+// FormatterMap is the type describing the mapping from formatter
+// names to the functions that implement them.
+type FormatterMap map[string]func(io.Writer, string, ...interface{})
+
+// Built-in formatters.
+var builtins = FormatterMap{
+ "html": HTMLFormatter,
+ "str": StringFormatter,
+ "": StringFormatter,
+}
+
+// The parsed state of a template is a vector of xxxElement structs.
+// Sections have line numbers so errors can be reported better during execution.
+
+// Plain text.
+type textElement struct {
+ text []byte
+}
+
+// A literal such as .meta-left or .meta-right
+type literalElement struct {
+ text []byte
+}
+
+// A variable invocation to be evaluated
+type variableElement struct {
+ linenum int
+ word []string // The fields in the invocation.
+ fmts []string // Names of formatters to apply. len(fmts) > 0
+}
+
+// A .section block, possibly with a .or
+type sectionElement struct {
+ linenum int // of .section itself
+ field string // cursor field for this block
+ start int // first element
+ or int // first element of .or block
+ end int // one beyond last element
+}
+
+// A .repeated block, possibly with a .or and a .alternates
+type repeatedElement struct {
+ sectionElement // It has the same structure...
+ altstart int // ... except for alternates
+ altend int
+}
+
+// Template is the type that represents a template definition.
+// It is unchanged after parsing.
+type Template struct {
+ fmap FormatterMap // formatters for variables
+ // Used during parsing:
+ ldelim, rdelim []byte // delimiters; default {}
+ buf []byte // input text to process
+ p int // position in buf
+ linenum int // position in input
+ // Parsed results:
+ elems *vector.Vector
+}
+
+// Internal state for executing a Template. As we evaluate the struct,
+// the data item descends into the fields associated with sections, etc.
+// Parent is used to walk upwards to find variables higher in the tree.
+type state struct {
+ parent *state // parent in hierarchy
+ data reflect.Value // the driver data for this section etc.
+ wr io.Writer // where to send output
+ buf [2]bytes.Buffer // alternating buffers used when chaining formatters
+}
+
+func (parent *state) clone(data reflect.Value) *state {
+ return &state{parent: parent, data: data, wr: parent.wr}
+}
+
+// New creates a new template with the specified formatter map (which
+// may be nil) to define auxiliary functions for formatting variables.
+func New(fmap FormatterMap) *Template {
+ t := new(Template)
+ t.fmap = fmap
+ t.ldelim = lbrace
+ t.rdelim = rbrace
+ t.elems = new(vector.Vector)
+ return t
+}
+
+// Report error and stop executing. The line number must be provided explicitly.
+func (t *Template) execError(st *state, line int, err string, args ...interface{}) {
+ panic(&Error{line, fmt.Sprintf(err, args...)})
+}
+
+// Report error, panic to terminate parsing.
+// The line number comes from the template state.
+func (t *Template) parseError(err string, args ...interface{}) {
+ panic(&Error{t.linenum, fmt.Sprintf(err, args...)})
+}
+
+// Is this an exported - upper case - name?
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// -- Lexical analysis
+
+// Is c a white space character?
+func white(c uint8) bool { return c == ' ' || c == '\t' || c == '\r' || c == '\n' }
+
+// Safely, does s[n:n+len(t)] == t?
+func equal(s []byte, n int, t []byte) bool {
+ b := s[n:]
+ if len(t) > len(b) { // not enough space left for a match.
+ return false
+ }
+ for i, c := range t {
+ if c != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// nextItem returns the next item from the input buffer. If the returned
+// item is empty, we are at EOF. The item will be either a
+// delimited string or a non-empty string between delimited
+// strings. Tokens stop at (but include, if plain text) a newline.
+// Action tokens on a line by themselves drop any space on
+// either side, up to and including the newline.
+func (t *Template) nextItem() []byte {
+ startOfLine := t.p == 0 || t.buf[t.p-1] == '\n'
+ start := t.p
+ var i int
+ newline := func() {
+ t.linenum++
+ i++
+ }
+ // Leading white space up to but not including newline
+ for i = start; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' || !white(t.buf[i]) {
+ break
+ }
+ }
+ leadingSpace := i > start
+ // What's left is nothing, newline, delimited string, or plain text
+ switch {
+ case i == len(t.buf):
+ // EOF; nothing to do
+ case t.buf[i] == '\n':
+ newline()
+ case equal(t.buf, i, t.ldelim):
+ left := i // Start of left delimiter.
+ right := -1 // Will be (immediately after) right delimiter.
+ haveText := false // Delimiters contain text.
+ i += len(t.ldelim)
+ // Find the end of the action.
+ for ; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' {
+ break
+ }
+ if equal(t.buf, i, t.rdelim) {
+ i += len(t.rdelim)
+ right = i
+ break
+ }
+ haveText = true
+ }
+ if right < 0 {
+ t.parseError("unmatched opening delimiter")
+ return nil
+ }
+ // Is this a special action (starts with '.' or '#') and the only thing on the line?
+ if startOfLine && haveText {
+ firstChar := t.buf[left+len(t.ldelim)]
+ if firstChar == '.' || firstChar == '#' {
+ // It's special and the first thing on the line. Is it the last?
+ for j := right; j < len(t.buf) && white(t.buf[j]); j++ {
+ if t.buf[j] == '\n' {
+ // Yes it is. Drop the surrounding space and return the {.foo}
+ t.linenum++
+ t.p = j + 1
+ return t.buf[left:right]
+ }
+ }
+ }
+ }
+ // No it's not. If there's leading space, return that.
+ if leadingSpace {
+ // not trimming space: return leading white space if there is some.
+ t.p = left
+ return t.buf[start:left]
+ }
+ // Return the word, leave the trailing space.
+ start = left
+ break
+ default:
+ for ; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' {
+ newline()
+ break
+ }
+ if equal(t.buf, i, t.ldelim) {
+ break
+ }
+ }
+ }
+ item := t.buf[start:i]
+ t.p = i
+ return item
+}
+
+// Turn a byte array into a white-space-split array of strings.
+func words(buf []byte) []string {
+ s := make([]string, 0, 5)
+ p := 0 // position in buf
+ // one word per loop
+ for i := 0; ; i++ {
+ // skip white space
+ for ; p < len(buf) && white(buf[p]); p++ {
+ }
+ // grab word
+ start := p
+ for ; p < len(buf) && !white(buf[p]); p++ {
+ }
+ if start == p { // no text left
+ break
+ }
+ s = append(s, string(buf[start:p]))
+ }
+ return s
+}
+
+// Analyze an item and return its token type and, if it's an action item, an array of
+// its constituent words.
+func (t *Template) analyze(item []byte) (tok int, w []string) {
+ // item is known to be non-empty
+ if !equal(item, 0, t.ldelim) { // doesn't start with left delimiter
+ tok = tokText
+ return
+ }
+ if !equal(item, len(item)-len(t.rdelim), t.rdelim) { // doesn't end with right delimiter
+ t.parseError("internal error: unmatched opening delimiter") // lexing should prevent this
+ return
+ }
+ if len(item) <= len(t.ldelim)+len(t.rdelim) { // no contents
+ t.parseError("empty directive")
+ return
+ }
+ // Comment
+ if item[len(t.ldelim)] == '#' {
+ tok = tokComment
+ return
+ }
+ // Split into words
+ w = words(item[len(t.ldelim) : len(item)-len(t.rdelim)]) // drop final delimiter
+ if len(w) == 0 {
+ t.parseError("empty directive")
+ return
+ }
+ if len(w) > 0 && w[0][0] != '.' {
+ tok = tokVariable
+ return
+ }
+ switch w[0] {
+ case ".meta-left", ".meta-right", ".space", ".tab":
+ tok = tokLiteral
+ return
+ case ".or":
+ tok = tokOr
+ return
+ case ".end":
+ tok = tokEnd
+ return
+ case ".section":
+ if len(w) != 2 {
+ t.parseError("incorrect fields for .section: %s", item)
+ return
+ }
+ tok = tokSection
+ return
+ case ".repeated":
+ if len(w) != 3 || w[1] != "section" {
+ t.parseError("incorrect fields for .repeated: %s", item)
+ return
+ }
+ tok = tokRepeated
+ return
+ case ".alternates":
+ if len(w) != 2 || w[1] != "with" {
+ t.parseError("incorrect fields for .alternates: %s", item)
+ return
+ }
+ tok = tokAlternates
+ return
+ }
+ t.parseError("bad directive: %s", item)
+ return
+}
+
+// formatter returns the Formatter with the given name in the Template, or nil if none exists.
+func (t *Template) formatter(name string) func(io.Writer, string, ...interface{}) {
+ if t.fmap != nil {
+ if fn := t.fmap[name]; fn != nil {
+ return fn
+ }
+ }
+ return builtins[name]
+}
+
+// -- Parsing
+
+// Allocate a new variable-evaluation element.
+func (t *Template) newVariable(words []string) *variableElement {
+ // After the final space-separated argument, formatters may be specified separated
+ // by pipe symbols, for example: {a b c|d|e}
+
+ // Until we learn otherwise, formatters contains a single name: "", the default formatter.
+ formatters := []string{""}
+ lastWord := words[len(words)-1]
+ bar := strings.IndexRune(lastWord, '|')
+ if bar >= 0 {
+ words[len(words)-1] = lastWord[0:bar]
+ formatters = strings.Split(lastWord[bar+1:], "|")
+ }
+
+ // We could remember the function address here and avoid the lookup later,
+ // but it's more dynamic to let the user change the map contents underfoot.
+ // We do require the name to be present, though.
+
+ // Is it in user-supplied map?
+ for _, f := range formatters {
+ if t.formatter(f) == nil {
+ t.parseError("unknown formatter: %q", f)
+ }
+ }
+ return &variableElement{t.linenum, words, formatters}
+}
+
+// Grab the next item. If it's simple, just append it to the template.
+// Otherwise return its details.
+func (t *Template) parseSimple(item []byte) (done bool, tok int, w []string) {
+ tok, w = t.analyze(item)
+ done = true // assume for simplicity
+ switch tok {
+ case tokComment:
+ return
+ case tokText:
+ t.elems.Push(&textElement{item})
+ return
+ case tokLiteral:
+ switch w[0] {
+ case ".meta-left":
+ t.elems.Push(&literalElement{t.ldelim})
+ case ".meta-right":
+ t.elems.Push(&literalElement{t.rdelim})
+ case ".space":
+ t.elems.Push(&literalElement{space})
+ case ".tab":
+ t.elems.Push(&literalElement{tab})
+ default:
+ t.parseError("internal error: unknown literal: %s", w[0])
+ }
+ return
+ case tokVariable:
+ t.elems.Push(t.newVariable(w))
+ return
+ }
+ return false, tok, w
+}
+
+// parseRepeated and parseSection are mutually recursive
+
+func (t *Template) parseRepeated(words []string) *repeatedElement {
+ r := new(repeatedElement)
+ t.elems.Push(r)
+ r.linenum = t.linenum
+ r.field = words[2]
+ // Scan section, collecting true and false (.or) blocks.
+ r.start = t.elems.Len()
+ r.or = -1
+ r.altstart = -1
+ r.altend = -1
+Loop:
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ t.parseError("missing .end for .repeated section")
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokEnd:
+ break Loop
+ case tokOr:
+ if r.or >= 0 {
+ t.parseError("extra .or in .repeated section")
+ break Loop
+ }
+ r.altend = t.elems.Len()
+ r.or = t.elems.Len()
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ case tokAlternates:
+ if r.altstart >= 0 {
+ t.parseError("extra .alternates in .repeated section")
+ break Loop
+ }
+ if r.or >= 0 {
+ t.parseError(".alternates inside .or block in .repeated section")
+ break Loop
+ }
+ r.altstart = t.elems.Len()
+ default:
+ t.parseError("internal error: unknown repeated section item: %s", item)
+ break Loop
+ }
+ }
+ if r.altend < 0 {
+ r.altend = t.elems.Len()
+ }
+ r.end = t.elems.Len()
+ return r
+}
+
+func (t *Template) parseSection(words []string) *sectionElement {
+ s := new(sectionElement)
+ t.elems.Push(s)
+ s.linenum = t.linenum
+ s.field = words[1]
+ // Scan section, collecting true and false (.or) blocks.
+ s.start = t.elems.Len()
+ s.or = -1
+Loop:
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ t.parseError("missing .end for .section")
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokEnd:
+ break Loop
+ case tokOr:
+ if s.or >= 0 {
+ t.parseError("extra .or in .section")
+ break Loop
+ }
+ s.or = t.elems.Len()
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ case tokAlternates:
+ t.parseError(".alternates not in .repeated")
+ default:
+ t.parseError("internal error: unknown section item: %s", item)
+ }
+ }
+ s.end = t.elems.Len()
+ return s
+}
+
+func (t *Template) parse() {
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokOr, tokEnd, tokAlternates:
+ t.parseError("unexpected %s", w[0])
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ default:
+ t.parseError("internal error: bad directive in parse: %s", item)
+ }
+ }
+}
+
+// -- Execution
+
+// Evaluate interfaces and pointers looking for a value that can look up the name, via a
+// struct field, method, or map key, and return the result of the lookup.
+func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value {
+ for v.IsValid() {
+ typ := v.Type()
+ if n := v.Type().NumMethod(); n > 0 {
+ for i := 0; i < n; i++ {
+ m := typ.Method(i)
+ mtyp := m.Type
+ if m.Name == name && mtyp.NumIn() == 1 && mtyp.NumOut() == 1 {
+ if !isExported(name) {
+ t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type())
+ }
+ return v.Method(i).Call(nil)[0]
+ }
+ }
+ }
+ switch av := v; av.Kind() {
+ case reflect.Ptr:
+ v = av.Elem()
+ case reflect.Interface:
+ v = av.Elem()
+ case reflect.Struct:
+ if !isExported(name) {
+ t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type())
+ }
+ return av.FieldByName(name)
+ case reflect.Map:
+ if v := av.MapIndex(reflect.ValueOf(name)); v.IsValid() {
+ return v
+ }
+ return reflect.Zero(typ.Elem())
+ default:
+ return reflect.Value{}
+ }
+ }
+ return v
+}
+
+// indirectPtr returns the item numLevels levels of indirection below the value.
+// It is forgiving: if the value is not a pointer, it returns it rather than giving
+// an error. If the pointer is nil, it is returned as is.
+func indirectPtr(v reflect.Value, numLevels int) reflect.Value {
+ for i := numLevels; v.IsValid() && i > 0; i++ {
+ if p := v; p.Kind() == reflect.Ptr {
+ if p.IsNil() {
+ return v
+ }
+ v = p.Elem()
+ } else {
+ break
+ }
+ }
+ return v
+}
+
+// Walk v through pointers and interfaces, extracting the elements within.
+func indirect(v reflect.Value) reflect.Value {
+loop:
+ for v.IsValid() {
+ switch av := v; av.Kind() {
+ case reflect.Ptr:
+ v = av.Elem()
+ case reflect.Interface:
+ v = av.Elem()
+ default:
+ break loop
+ }
+ }
+ return v
+}
+
+// If the data for this template is a struct, find the named variable.
+// Names of the form a.b.c are walked down the data tree.
+// The special name "@" (the "cursor") denotes the current data.
+// The value coming in (st.data) might need indirecting to reach
+// a struct while the return value is not indirected - that is,
+// it represents the actual named field. Leading stars indicate
+// levels of indirection to be applied to the value.
+func (t *Template) findVar(st *state, s string) reflect.Value {
+ data := st.data
+ flattenedName := strings.TrimLeft(s, "*")
+ numStars := len(s) - len(flattenedName)
+ s = flattenedName
+ if s == "@" {
+ return indirectPtr(data, numStars)
+ }
+ for _, elem := range strings.Split(s, ".") {
+ // Look up field; data must be a struct or map.
+ data = t.lookup(st, data, elem)
+ if !data.IsValid() {
+ return reflect.Value{}
+ }
+ }
+ return indirectPtr(data, numStars)
+}
+
+// Is there no data to look at?
+func empty(v reflect.Value) bool {
+ v = indirect(v)
+ if !v.IsValid() {
+ return true
+ }
+ switch v.Kind() {
+ case reflect.Bool:
+ return v.Bool() == false
+ case reflect.String:
+ return v.String() == ""
+ case reflect.Struct:
+ return false
+ case reflect.Map:
+ return false
+ case reflect.Array:
+ return v.Len() == 0
+ case reflect.Slice:
+ return v.Len() == 0
+ }
+ return false
+}
+
+// Look up a variable or method, up through the parent if necessary.
+func (t *Template) varValue(name string, st *state) reflect.Value {
+ field := t.findVar(st, name)
+ if !field.IsValid() {
+ if st.parent == nil {
+ t.execError(st, t.linenum, "name not found: %s in type %s", name, st.data.Type())
+ }
+ return t.varValue(name, st.parent)
+ }
+ return field
+}
+
+func (t *Template) format(wr io.Writer, fmt string, val []interface{}, v *variableElement, st *state) {
+ fn := t.formatter(fmt)
+ if fn == nil {
+ t.execError(st, v.linenum, "missing formatter %s for variable %s", fmt, v.word[0])
+ }
+ fn(wr, fmt, val...)
+}
+
+// Evaluate a variable, looking up through the parent if necessary.
+// If it has a formatter attached ({var|formatter}) run that too.
+func (t *Template) writeVariable(v *variableElement, st *state) {
+ // Turn the words of the invocation into values.
+ val := make([]interface{}, len(v.word))
+ for i, word := range v.word {
+ val[i] = t.varValue(word, st).Interface()
+ }
+
+ for i, fmt := range v.fmts[:len(v.fmts)-1] {
+ b := &st.buf[i&1]
+ b.Reset()
+ t.format(b, fmt, val, v, st)
+ val = val[0:1]
+ val[0] = b.Bytes()
+ }
+ t.format(st.wr, v.fmts[len(v.fmts)-1], val, v, st)
+}
+
+// Execute element i. Return next index to execute.
+func (t *Template) executeElement(i int, st *state) int {
+ switch elem := t.elems.At(i).(type) {
+ case *textElement:
+ st.wr.Write(elem.text)
+ return i + 1
+ case *literalElement:
+ st.wr.Write(elem.text)
+ return i + 1
+ case *variableElement:
+ t.writeVariable(elem, st)
+ return i + 1
+ case *sectionElement:
+ t.executeSection(elem, st)
+ return elem.end
+ case *repeatedElement:
+ t.executeRepeated(elem, st)
+ return elem.end
+ }
+ e := t.elems.At(i)
+ t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.ValueOf(e).Interface(), e)
+ return 0
+}
+
+// Execute the template.
+func (t *Template) execute(start, end int, st *state) {
+ for i := start; i < end; {
+ i = t.executeElement(i, st)
+ }
+}
+
+// Execute a .section
+func (t *Template) executeSection(s *sectionElement, st *state) {
+ // Find driver data for this section. It must be in the current struct.
+ field := t.varValue(s.field, st)
+ if !field.IsValid() {
+ t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, st.data.Type())
+ }
+ st = st.clone(field)
+ start, end := s.start, s.or
+ if !empty(field) {
+ // Execute the normal block.
+ if end < 0 {
+ end = s.end
+ }
+ } else {
+ // Execute the .or block. If it's missing, do nothing.
+ start, end = s.or, s.end
+ if start < 0 {
+ return
+ }
+ }
+ for i := start; i < end; {
+ i = t.executeElement(i, st)
+ }
+}
+
+// Return the result of calling the Iter method on v, or nil.
+func iter(v reflect.Value) reflect.Value {
+ for j := 0; j < v.Type().NumMethod(); j++ {
+ mth := v.Type().Method(j)
+ fv := v.Method(j)
+ ft := fv.Type()
+ // TODO(rsc): NumIn() should return 0 here, because ft is from a curried FuncValue.
+ if mth.Name != "Iter" || ft.NumIn() != 1 || ft.NumOut() != 1 {
+ continue
+ }
+ ct := ft.Out(0)
+ if ct.Kind() != reflect.Chan ||
+ ct.ChanDir()&reflect.RecvDir == 0 {
+ continue
+ }
+ return fv.Call(nil)[0]
+ }
+ return reflect.Value{}
+}
+
+// Execute a .repeated section
+func (t *Template) executeRepeated(r *repeatedElement, st *state) {
+ // Find driver data for this section. It must be in the current struct.
+ field := t.varValue(r.field, st)
+ if !field.IsValid() {
+ t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, st.data.Type())
+ }
+ field = indirect(field)
+
+ start, end := r.start, r.or
+ if end < 0 {
+ end = r.end
+ }
+ if r.altstart >= 0 {
+ end = r.altstart
+ }
+ first := true
+
+ // Code common to all the loops.
+ loopBody := func(newst *state) {
+ // .alternates between elements
+ if !first && r.altstart >= 0 {
+ for i := r.altstart; i < r.altend; {
+ i = t.executeElement(i, newst)
+ }
+ }
+ first = false
+ for i := start; i < end; {
+ i = t.executeElement(i, newst)
+ }
+ }
+
+ if array := field; array.Kind() == reflect.Array || array.Kind() == reflect.Slice {
+ for j := 0; j < array.Len(); j++ {
+ loopBody(st.clone(array.Index(j)))
+ }
+ } else if m := field; m.Kind() == reflect.Map {
+ for _, key := range m.MapKeys() {
+ loopBody(st.clone(m.MapIndex(key)))
+ }
+ } else if ch := iter(field); ch.IsValid() {
+ for {
+ e, ok := ch.Recv()
+ if !ok {
+ break
+ }
+ loopBody(st.clone(e))
+ }
+ } else {
+ t.execError(st, r.linenum, ".repeated: cannot repeat %s (type %s)",
+ r.field, field.Type())
+ }
+
+ if first {
+ // Empty. Execute the .or block, once. If it's missing, do nothing.
+ start, end := r.or, r.end
+ if start >= 0 {
+ newst := st.clone(field)
+ for i := start; i < end; {
+ i = t.executeElement(i, newst)
+ }
+ }
+ return
+ }
+}
+
+// A valid delimiter must contain no white space and be non-empty.
+func validDelim(d []byte) bool {
+ if len(d) == 0 {
+ return false
+ }
+ for _, c := range d {
+ if white(c) {
+ return false
+ }
+ }
+ return true
+}
+
+// checkError is a deferred function to turn a panic with type *Error into a plain error return.
+// Other panics are unexpected and so are re-enabled.
+func checkError(error *os.Error) {
+ if v := recover(); v != nil {
+ if e, ok := v.(*Error); ok {
+ *error = e
+ } else {
+ // runtime errors should crash
+ panic(v)
+ }
+ }
+}
+
+// -- Public interface
+
+// Parse initializes a Template by parsing its definition. The string
+// s contains the template text. If any errors occur, Parse returns
+// the error.
+func (t *Template) Parse(s string) (err os.Error) {
+ if t.elems == nil {
+ return &Error{1, "template not allocated with New"}
+ }
+ if !validDelim(t.ldelim) || !validDelim(t.rdelim) {
+ return &Error{1, fmt.Sprintf("bad delimiter strings %q %q", t.ldelim, t.rdelim)}
+ }
+ defer checkError(&err)
+ t.buf = []byte(s)
+ t.p = 0
+ t.linenum = 1
+ t.parse()
+ return nil
+}
+
+// ParseFile is like Parse but reads the template definition from the
+// named file.
+func (t *Template) ParseFile(filename string) (err os.Error) {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return err
+ }
+ return t.Parse(string(b))
+}
+
+// Execute applies a parsed template to the specified data object,
+// generating output to wr.
+func (t *Template) Execute(wr io.Writer, data interface{}) (err os.Error) {
+ // Extract the driver data.
+ val := reflect.ValueOf(data)
+ defer checkError(&err)
+ t.p = 0
+ t.execute(0, t.elems.Len(), &state{parent: nil, data: val, wr: wr})
+ return nil
+}
+
+// SetDelims sets the left and right delimiters for operations in the
+// template. They are validated during parsing. They could be
+// validated here but it's better to keep the routine simple. The
+// delimiters are very rarely invalid and Parse has the necessary
+// error-handling interface already.
+func (t *Template) SetDelims(left, right string) {
+ t.ldelim = []byte(left)
+ t.rdelim = []byte(right)
+}
+
+// Parse creates a Template with default parameters (such as {} for
+// metacharacters). The string s contains the template text while
+// the formatter map fmap, which may be nil, defines auxiliary functions
+// for formatting variables. The template is returned. If any errors
+// occur, err will be non-nil.
+func Parse(s string, fmap FormatterMap) (t *Template, err os.Error) {
+ t = New(fmap)
+ err = t.Parse(s)
+ if err != nil {
+ t = nil
+ }
+ return
+}
+
+// ParseFile is a wrapper function that creates a Template with default
+// parameters (such as {} for metacharacters). The filename identifies
+// a file containing the template text, while the formatter map fmap, which
+// may be nil, defines auxiliary functions for formatting variables.
+// The template is returned. If any errors occur, err will be non-nil.
+func ParseFile(filename string, fmap FormatterMap) (t *Template, err os.Error) {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+ return Parse(string(b), fmap)
+}
+
+// MustParse is like Parse but panics if the template cannot be parsed.
+func MustParse(s string, fmap FormatterMap) *Template {
+ t, err := Parse(s, fmap)
+ if err != nil {
+ panic("template.MustParse error: " + err.String())
+ }
+ return t
+}
+
+// MustParseFile is like ParseFile but panics if the file cannot be read
+// or the template cannot be parsed.
+func MustParseFile(filename string, fmap FormatterMap) *Template {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ panic("template.MustParseFile error: " + err.String())
+ }
+ return MustParse(string(b), fmap)
+}
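
For readers unfamiliar with the template language exercised by this fixture, here is a minimal usage sketch against the API documented in the file above (Parse, Execute, and the {.repeated section} action). It targets the historical pre-Go 1 package, so it will not build with a current toolchain; the Page type and the template text are invented for illustration only.

package main

import (
	"os"
	"template"
)

type Page struct {
	Title string
	Items []string
}

// {Title} inserts a field; {.repeated section Items} runs its body once per
// element, with the cursor @ bound to the current element.
const text = "{Title}\n{.repeated section Items}\n- {@}\n{.end}\n"

func main() {
	t, err := template.Parse(text, nil) // nil FormatterMap: built-in formatters only
	if err != nil {
		panic(err.String())
	}
	if err := t.Execute(os.Stdout, &Page{Title: "todo", Items: []string{"a", "b"}}); err != nil {
		panic(err.String())
	}
}
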
diff --git a/src/cmd/fix/testdata/reflect.type.go.in b/src/cmd/fix/testdata/reflect.type.go.in
new file mode 100644
index 000000000..34963bef9
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.type.go.in
@@ -0,0 +1,790 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+ "sync"
+ "unicode"
+ "utf8"
+)
+
+// userTypeInfo stores the information associated with a type the user has handed
+// to the package. It's computed once and stored in a map keyed by reflection
+// type.
+type userTypeInfo struct {
+ user reflect.Type // the type the user handed us
+ base reflect.Type // the base type after all indirections
+ indir int // number of indirections to reach the base type
+ isGobEncoder bool // does the type implement GobEncoder?
+ isGobDecoder bool // does the type implement GobDecoder?
+ encIndir int8 // number of indirections to reach the receiver type; may be negative
+ decIndir int8 // number of indirections to reach the receiver type; may be negative
+}
+
+var (
+ // Protected by an RWMutex because we read it a lot and write
+ // it only when we see a new type, typically when compiling.
+ userTypeLock sync.RWMutex
+ userTypeCache = make(map[reflect.Type]*userTypeInfo)
+)
+
+// validType returns, and saves, the information associated with user-provided type rt.
+// If the user type is not valid, err will be non-nil. To be used when the error handler
+// is not set up.
+func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
+ userTypeLock.RLock()
+ ut = userTypeCache[rt]
+ userTypeLock.RUnlock()
+ if ut != nil {
+ return
+ }
+ // Now set the value under the write lock.
+ userTypeLock.Lock()
+ defer userTypeLock.Unlock()
+ if ut = userTypeCache[rt]; ut != nil {
+ // Lost the race; not a problem.
+ return
+ }
+ ut = new(userTypeInfo)
+ ut.base = rt
+ ut.user = rt
+ // A type that is just a cycle of pointers (such as type T *T) cannot
+ // be represented in gobs, which need some concrete data. We use a
+ // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
+ // pp 539-540. As we step through indirections, run another type at
+ // half speed. If they meet up, there's a cycle.
+ slowpoke := ut.base // walks half as fast as ut.base
+ for {
+ pt, ok := ut.base.(*reflect.PtrType)
+ if !ok {
+ break
+ }
+ ut.base = pt.Elem()
+ if ut.base == slowpoke { // ut.base lapped slowpoke
+ // recursive pointer type.
+ return nil, os.NewError("can't represent recursive pointer type " + ut.base.String())
+ }
+ if ut.indir%2 == 0 {
+ slowpoke = slowpoke.(*reflect.PtrType).Elem()
+ }
+ ut.indir++
+ }
+ ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck)
+ ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck)
+ userTypeCache[rt] = ut
+ return
+}
+
+const (
+ gobEncodeMethodName = "GobEncode"
+ gobDecodeMethodName = "GobDecode"
+)
+
+// implements returns whether the type implements the interface, as encoded
+// in the check function.
+func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool {
+ if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance
+ return false
+ }
+ return check(typ)
+}
+
+// gobEncoderCheck makes the type assertion a boolean function.
+func gobEncoderCheck(typ reflect.Type) bool {
+ _, ok := reflect.MakeZero(typ).Interface().(GobEncoder)
+ return ok
+}
+
+// gobDecoderCheck makes the type assertion a boolean function.
+func gobDecoderCheck(typ reflect.Type) bool {
+ _, ok := reflect.MakeZero(typ).Interface().(GobDecoder)
+ return ok
+}
+
+// implementsInterface reports whether the type implements the
+// interface. (The actual check is done through the provided function.)
+// It also returns the number of indirections required to get to the
+// implementation.
+func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) {
+ if typ == nil {
+ return
+ }
+ rt := typ
+ // The type might be a pointer and we need to keep
+ // dereferencing to the base type until we find an implementation.
+ for {
+ if implements(rt, check) {
+ return true, indir
+ }
+ if p, ok := rt.(*reflect.PtrType); ok {
+ indir++
+ if indir > 100 { // insane number of indirections
+ return false, 0
+ }
+ rt = p.Elem()
+ continue
+ }
+ break
+ }
+ // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy.
+ if _, ok := typ.(*reflect.PtrType); !ok {
+ // Not a pointer, but does the pointer work?
+ if implements(reflect.PtrTo(typ), check) {
+ return true, -1
+ }
+ }
+ return false, 0
+}
+
+// userType returns, and saves, the information associated with user-provided type rt.
+// If the user type is not valid, it calls error.
+func userType(rt reflect.Type) *userTypeInfo {
+ ut, err := validUserType(rt)
+ if err != nil {
+ error(err)
+ }
+ return ut
+}
+
+// A typeId represents a gob Type as an integer that can be passed on the wire.
+// Internally, typeIds are used as keys to a map to recover the underlying type info.
+type typeId int32
+
+var nextId typeId // incremented for each new type we build
+var typeLock sync.Mutex // set while building a type
+const firstUserId = 64 // lowest id number granted to user
+
+type gobType interface {
+ id() typeId
+ setId(id typeId)
+ name() string
+ string() string // not public; only for debugging
+ safeString(seen map[typeId]bool) string
+}
+
+var types = make(map[reflect.Type]gobType)
+var idToType = make(map[typeId]gobType)
+var builtinIdToType map[typeId]gobType // set in init() after builtins are established
+
+func setTypeId(typ gobType) {
+ nextId++
+ typ.setId(nextId)
+ idToType[nextId] = typ
+}
+
+func (t typeId) gobType() gobType {
+ if t == 0 {
+ return nil
+ }
+ return idToType[t]
+}
+
+// string returns the string representation of the type associated with the typeId.
+func (t typeId) string() string {
+ if t.gobType() == nil {
+ return "<nil>"
+ }
+ return t.gobType().string()
+}
+
+// Name returns the name of the type associated with the typeId.
+func (t typeId) name() string {
+ if t.gobType() == nil {
+ return "<nil>"
+ }
+ return t.gobType().name()
+}
+
+// Common elements of all types.
+type CommonType struct {
+ Name string
+ Id typeId
+}
+
+func (t *CommonType) id() typeId { return t.Id }
+
+func (t *CommonType) setId(id typeId) { t.Id = id }
+
+func (t *CommonType) string() string { return t.Name }
+
+func (t *CommonType) safeString(seen map[typeId]bool) string {
+ return t.Name
+}
+
+func (t *CommonType) name() string { return t.Name }
+
+// Create and check predefined types
+// The string for tBytes is "bytes" not "[]byte" to signify its specialness.
+
+var (
+ // Primordial types, needed during initialization.
+ // Always passed as pointers so the interface{} type
+ // goes through without losing its interfaceness.
+ tBool = bootstrapType("bool", (*bool)(nil), 1)
+ tInt = bootstrapType("int", (*int)(nil), 2)
+ tUint = bootstrapType("uint", (*uint)(nil), 3)
+ tFloat = bootstrapType("float", (*float64)(nil), 4)
+ tBytes = bootstrapType("bytes", (*[]byte)(nil), 5)
+ tString = bootstrapType("string", (*string)(nil), 6)
+ tComplex = bootstrapType("complex", (*complex128)(nil), 7)
+ tInterface = bootstrapType("interface", (*interface{})(nil), 8)
+ // Reserve some Ids for compatible expansion
+ tReserved7 = bootstrapType("_reserved1", (*struct{ r7 int })(nil), 9)
+ tReserved6 = bootstrapType("_reserved1", (*struct{ r6 int })(nil), 10)
+ tReserved5 = bootstrapType("_reserved1", (*struct{ r5 int })(nil), 11)
+ tReserved4 = bootstrapType("_reserved1", (*struct{ r4 int })(nil), 12)
+ tReserved3 = bootstrapType("_reserved1", (*struct{ r3 int })(nil), 13)
+ tReserved2 = bootstrapType("_reserved1", (*struct{ r2 int })(nil), 14)
+ tReserved1 = bootstrapType("_reserved1", (*struct{ r1 int })(nil), 15)
+)
+
+// Predefined because it's needed by the Decoder
+var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id
+var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType)
+
+func init() {
+ // Some magic numbers to make sure there are no surprises.
+ checkId(16, tWireType)
+ checkId(17, mustGetTypeInfo(reflect.Typeof(arrayType{})).id)
+ checkId(18, mustGetTypeInfo(reflect.Typeof(CommonType{})).id)
+ checkId(19, mustGetTypeInfo(reflect.Typeof(sliceType{})).id)
+ checkId(20, mustGetTypeInfo(reflect.Typeof(structType{})).id)
+ checkId(21, mustGetTypeInfo(reflect.Typeof(fieldType{})).id)
+ checkId(23, mustGetTypeInfo(reflect.Typeof(mapType{})).id)
+
+ builtinIdToType = make(map[typeId]gobType)
+ for k, v := range idToType {
+ builtinIdToType[k] = v
+ }
+
+ // Move the id space upwards to allow for growth in the predefined world
+ // without breaking existing files.
+ if nextId > firstUserId {
+ panic(fmt.Sprintln("nextId too large:", nextId))
+ }
+ nextId = firstUserId
+ registerBasics()
+ wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil)))
+}
+
+// Array type
+type arrayType struct {
+ CommonType
+ Elem typeId
+ Len int
+}
+
+func newArrayType(name string) *arrayType {
+ a := &arrayType{CommonType{Name: name}, 0, 0}
+ return a
+}
+
+func (a *arrayType) init(elem gobType, len int) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(a)
+ a.Elem = elem.id()
+ a.Len = len
+}
+
+func (a *arrayType) safeString(seen map[typeId]bool) string {
+ if seen[a.Id] {
+ return a.Name
+ }
+ seen[a.Id] = true
+ return fmt.Sprintf("[%d]%s", a.Len, a.Elem.gobType().safeString(seen))
+}
+
+func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
+
+// GobEncoder type (something that implements the GobEncoder interface)
+type gobEncoderType struct {
+ CommonType
+}
+
+func newGobEncoderType(name string) *gobEncoderType {
+ g := &gobEncoderType{CommonType{Name: name}}
+ setTypeId(g)
+ return g
+}
+
+func (g *gobEncoderType) safeString(seen map[typeId]bool) string {
+ return g.Name
+}
+
+func (g *gobEncoderType) string() string { return g.Name }
+
+// Map type
+type mapType struct {
+ CommonType
+ Key typeId
+ Elem typeId
+}
+
+func newMapType(name string) *mapType {
+ m := &mapType{CommonType{Name: name}, 0, 0}
+ return m
+}
+
+func (m *mapType) init(key, elem gobType) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(m)
+ m.Key = key.id()
+ m.Elem = elem.id()
+}
+
+func (m *mapType) safeString(seen map[typeId]bool) string {
+ if seen[m.Id] {
+ return m.Name
+ }
+ seen[m.Id] = true
+ key := m.Key.gobType().safeString(seen)
+ elem := m.Elem.gobType().safeString(seen)
+ return fmt.Sprintf("map[%s]%s", key, elem)
+}
+
+func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) }
+
+// Slice type
+type sliceType struct {
+ CommonType
+ Elem typeId
+}
+
+func newSliceType(name string) *sliceType {
+ s := &sliceType{CommonType{Name: name}, 0}
+ return s
+}
+
+func (s *sliceType) init(elem gobType) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(s)
+ s.Elem = elem.id()
+}
+
+func (s *sliceType) safeString(seen map[typeId]bool) string {
+ if seen[s.Id] {
+ return s.Name
+ }
+ seen[s.Id] = true
+ return fmt.Sprintf("[]%s", s.Elem.gobType().safeString(seen))
+}
+
+func (s *sliceType) string() string { return s.safeString(make(map[typeId]bool)) }
+
+// Struct type
+type fieldType struct {
+ Name string
+ Id typeId
+}
+
+type structType struct {
+ CommonType
+ Field []*fieldType
+}
+
+func (s *structType) safeString(seen map[typeId]bool) string {
+ if s == nil {
+ return "<nil>"
+ }
+ if _, ok := seen[s.Id]; ok {
+ return s.Name
+ }
+ seen[s.Id] = true
+ str := s.Name + " = struct { "
+ for _, f := range s.Field {
+ str += fmt.Sprintf("%s %s; ", f.Name, f.Id.gobType().safeString(seen))
+ }
+ str += "}"
+ return str
+}
+
+func (s *structType) string() string { return s.safeString(make(map[typeId]bool)) }
+
+func newStructType(name string) *structType {
+ s := &structType{CommonType{Name: name}, nil}
+ // For historical reasons we set the id here rather than init.
+ // See the comment in newTypeObject for details.
+ setTypeId(s)
+ return s
+}
+
+// newTypeObject allocates a gobType for the reflection type rt.
+// Unless ut represents a GobEncoder, rt should be the base type
+// of ut.
+// This is only called from the encoding side. The decoding side
+// works through typeIds and userTypeInfos alone.
+func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
+ // Does this type implement GobEncoder?
+ if ut.isGobEncoder {
+ return newGobEncoderType(name), nil
+ }
+ var err os.Error
+ var type0, type1 gobType
+ defer func() {
+ if err != nil {
+ types[rt] = nil, false
+ }
+ }()
+ // Install the top-level type before the subtypes (e.g. struct before
+ // fields) so recursive types can be constructed safely.
+ switch t := rt.(type) {
+ // All basic types are easy: they are predefined.
+ case *reflect.BoolType:
+ return tBool.gobType(), nil
+
+ case *reflect.IntType:
+ return tInt.gobType(), nil
+
+ case *reflect.UintType:
+ return tUint.gobType(), nil
+
+ case *reflect.FloatType:
+ return tFloat.gobType(), nil
+
+ case *reflect.ComplexType:
+ return tComplex.gobType(), nil
+
+ case *reflect.StringType:
+ return tString.gobType(), nil
+
+ case *reflect.InterfaceType:
+ return tInterface.gobType(), nil
+
+ case *reflect.ArrayType:
+ at := newArrayType(name)
+ types[rt] = at
+ type0, err = getBaseType("", t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ // Historical aside:
+ // For arrays, maps, and slices, we set the type id after the elements
+ // are constructed. This is to retain the order of type id allocation after
+ // a fix made to handle recursive types, which changed the order in
+ // which types are built. Delaying the setting in this way preserves
+ // type ids while allowing recursive types to be described. Structs,
+ // done below, were already handling recursion correctly so they
+ // assign the top-level id before those of the field.
+ at.init(type0, t.Len())
+ return at, nil
+
+ case *reflect.MapType:
+ mt := newMapType(name)
+ types[rt] = mt
+ type0, err = getBaseType("", t.Key())
+ if err != nil {
+ return nil, err
+ }
+ type1, err = getBaseType("", t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ mt.init(type0, type1)
+ return mt, nil
+
+ case *reflect.SliceType:
+ // []byte == []uint8 is a special case
+ if t.Elem().Kind() == reflect.Uint8 {
+ return tBytes.gobType(), nil
+ }
+ st := newSliceType(name)
+ types[rt] = st
+ type0, err = getBaseType(t.Elem().Name(), t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ st.init(type0)
+ return st, nil
+
+ case *reflect.StructType:
+ st := newStructType(name)
+ types[rt] = st
+ idToType[st.id()] = st
+ for i := 0; i < t.NumField(); i++ {
+ f := t.Field(i)
+ if !isExported(f.Name) {
+ continue
+ }
+ typ := userType(f.Type).base
+ tname := typ.Name()
+ if tname == "" {
+ t := userType(f.Type).base
+ tname = t.String()
+ }
+ gt, err := getBaseType(tname, f.Type)
+ if err != nil {
+ return nil, err
+ }
+ st.Field = append(st.Field, &fieldType{f.Name, gt.id()})
+ }
+ return st, nil
+
+ default:
+ return nil, os.NewError("gob NewTypeObject can't handle type: " + rt.String())
+ }
+ return nil, nil
+}
+
+// isExported reports whether this is an exported - upper case - name.
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// getBaseType returns the Gob type describing the given reflect.Type's base type.
+// typeLock must be held.
+func getBaseType(name string, rt reflect.Type) (gobType, os.Error) {
+ ut := userType(rt)
+ return getType(name, ut, ut.base)
+}
+
+// getType returns the Gob type describing the given reflect.Type.
+// Should be called only when handling GobEncoders/Decoders,
+// which may be pointers. All other types are handled through the
+// base type, never a pointer.
+// typeLock must be held.
+func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
+ typ, present := types[rt]
+ if present {
+ return typ, nil
+ }
+ typ, err := newTypeObject(name, ut, rt)
+ if err == nil {
+ types[rt] = typ
+ }
+ return typ, err
+}
+
+func checkId(want, got typeId) {
+ if want != got {
+ fmt.Fprintf(os.Stderr, "checkId: %d should be %d\n", int(got), int(want))
+ panic("bootstrap type wrong id: " + got.name() + " " + got.string() + " not " + want.string())
+ }
+}
+
+// used for building the basic types; called only from init(). the incoming
+// interface always refers to a pointer.
+func bootstrapType(name string, e interface{}, expect typeId) typeId {
+ rt := reflect.Typeof(e).(*reflect.PtrType).Elem()
+ _, present := types[rt]
+ if present {
+ panic("bootstrap type already present: " + name + ", " + rt.String())
+ }
+ typ := &CommonType{Name: name}
+ types[rt] = typ
+ setTypeId(typ)
+ checkId(expect, nextId)
+ userType(rt) // might as well cache it now
+ return nextId
+}
+
+// Representation of the information we send and receive about this type.
+// Each value we send is preceded by its type definition: an encoded int.
+// However, the very first time we send the value, we first send the pair
+// (-id, wireType).
+// For bootstrapping purposes, we assume that the recipient knows how
+// to decode a wireType; it is exactly the wireType struct here, interpreted
+// using the gob rules for sending a structure, except that we assume the
+// ids for wireType and structType etc. are known. The relevant pieces
+// are built in encode.go's init() function.
+// To maintain binary compatibility, if you extend this type, always put
+// the new fields last.
+type wireType struct {
+ ArrayT *arrayType
+ SliceT *sliceType
+ StructT *structType
+ MapT *mapType
+ GobEncoderT *gobEncoderType
+}
+
+func (w *wireType) string() string {
+ const unknown = "unknown type"
+ if w == nil {
+ return unknown
+ }
+ switch {
+ case w.ArrayT != nil:
+ return w.ArrayT.Name
+ case w.SliceT != nil:
+ return w.SliceT.Name
+ case w.StructT != nil:
+ return w.StructT.Name
+ case w.MapT != nil:
+ return w.MapT.Name
+ case w.GobEncoderT != nil:
+ return w.GobEncoderT.Name
+ }
+ return unknown
+}
+
+type typeInfo struct {
+ id typeId
+ encoder *encEngine
+ wire *wireType
+}
+
+var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock
+
+// typeLock must be held.
+func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) {
+ rt := ut.base
+ if ut.isGobEncoder {
+ // We want the user type, not the base type.
+ rt = ut.user
+ }
+ info, ok := typeInfoMap[rt]
+ if ok {
+ return info, nil
+ }
+ info = new(typeInfo)
+ gt, err := getBaseType(rt.Name(), rt)
+ if err != nil {
+ return nil, err
+ }
+ info.id = gt.id()
+
+ if ut.isGobEncoder {
+ userType, err := getType(rt.Name(), ut, rt)
+ if err != nil {
+ return nil, err
+ }
+ info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)}
+ typeInfoMap[ut.user] = info
+ return info, nil
+ }
+
+ t := info.id.gobType()
+ switch typ := rt.(type) {
+ case *reflect.ArrayType:
+ info.wire = &wireType{ArrayT: t.(*arrayType)}
+ case *reflect.MapType:
+ info.wire = &wireType{MapT: t.(*mapType)}
+ case *reflect.SliceType:
+ // []byte == []uint8 is a special case handled separately
+ if typ.Elem().Kind() != reflect.Uint8 {
+ info.wire = &wireType{SliceT: t.(*sliceType)}
+ }
+ case *reflect.StructType:
+ info.wire = &wireType{StructT: t.(*structType)}
+ }
+ typeInfoMap[rt] = info
+ return info, nil
+}
+
+// Called only when a panic is acceptable and unexpected.
+func mustGetTypeInfo(rt reflect.Type) *typeInfo {
+ t, err := getTypeInfo(userType(rt))
+ if err != nil {
+ panic("getTypeInfo: " + err.String())
+ }
+ return t
+}
+
+// GobEncoder is the interface describing data that provides its own
+// representation for encoding values for transmission to a GobDecoder.
+// A type that implements GobEncoder and GobDecoder has complete
+// control over the representation of its data and may therefore
+// contain things such as private fields, channels, and functions,
+// which are not usually transmissible in gob streams.
+//
+// Note: Since gobs can be stored permanently, it is good design
+// to guarantee the encoding used by a GobEncoder is stable as the
+// software evolves. For instance, it might make sense for GobEncode
+// to include a version number in the encoding.
+type GobEncoder interface {
+ // GobEncode returns a byte slice representing the encoding of the
+ // receiver for transmission to a GobDecoder, usually of the same
+ // concrete type.
+ GobEncode() ([]byte, os.Error)
+}
+
+// GobDecoder is the interface describing data that provides its own
+// routine for decoding transmitted values sent by a GobEncoder.
+type GobDecoder interface {
+ // GobDecode overwrites the receiver, which must be a pointer,
+ // with the value represented by the byte slice, which was written
+ // by GobEncode, usually for the same concrete type.
+ GobDecode([]byte) os.Error
+}
+
+var (
+ nameToConcreteType = make(map[string]reflect.Type)
+ concreteTypeToName = make(map[reflect.Type]string)
+)
+
+// RegisterName is like Register but uses the provided name rather than the
+// type's default.
+func RegisterName(name string, value interface{}) {
+ if name == "" {
+ // reserved for nil
+ panic("attempt to register empty name")
+ }
+ base := userType(reflect.Typeof(value)).base
+ // Check for incompatible duplicates.
+ if t, ok := nameToConcreteType[name]; ok && t != base {
+ panic("gob: registering duplicate types for " + name)
+ }
+ if n, ok := concreteTypeToName[base]; ok && n != name {
+ panic("gob: registering duplicate names for " + base.String())
+ }
+ // Store the name and type provided by the user....
+ nameToConcreteType[name] = reflect.Typeof(value)
+ // but the flattened type in the type table, since that's what decode needs.
+ concreteTypeToName[base] = name
+}
+
+// Register records a type, identified by a value for that type, under its
+// internal type name. That name will identify the concrete type of a value
+// sent or received as an interface variable. Only types that will be
+// transferred as implementations of interface values need to be registered.
+// Expecting to be used only during initialization, it panics if the mapping
+// between types and names is not a bijection.
+func Register(value interface{}) {
+ // Default to printed representation for unnamed types
+ rt := reflect.Typeof(value)
+ name := rt.String()
+
+ // But for named types (or pointers to them), qualify with import path.
+ // Dereference one pointer looking for a named type.
+ star := ""
+ if rt.Name() == "" {
+ if pt, ok := rt.(*reflect.PtrType); ok {
+ star = "*"
+ rt = pt
+ }
+ }
+ if rt.Name() != "" {
+ if rt.PkgPath() == "" {
+ name = star + rt.Name()
+ } else {
+ name = star + rt.PkgPath() + "." + rt.Name()
+ }
+ }
+
+ RegisterName(name, value)
+}
+
+func registerBasics() {
+ Register(int(0))
+ Register(int8(0))
+ Register(int16(0))
+ Register(int32(0))
+ Register(int64(0))
+ Register(uint(0))
+ Register(uint8(0))
+ Register(uint16(0))
+ Register(uint32(0))
+ Register(uint64(0))
+ Register(float32(0))
+ Register(float64(0))
+ Register(complex64(0i))
+ Register(complex128(0i))
+ Register(false)
+ Register("")
+ Register([]byte(nil))
+}
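
The reflect.type.go.out hunk that follows is the expected output for the file above. As a rough orientation — not part of the testdata, and baseType is an invented name — the central rewrite replaces type assertions such as rt.(*reflect.PtrType) and calls like reflect.Typeof and reflect.MakeZero with Kind checks, reflect.TypeOf, and reflect.Zero. A runnable sketch of the pointer-stripping loop from validUserType in the Go 1 style:

package main

import (
	"fmt"
	"reflect"
)

// baseType strips pointer indirections the way validUserType does in the
// fixed file below, using Kind()/Elem() instead of a *reflect.PtrType
// assertion, and reports how many indirections it removed.
func baseType(rt reflect.Type) (base reflect.Type, indir int) {
	for rt.Kind() == reflect.Ptr { // was: pt, ok := rt.(*reflect.PtrType)
		rt = rt.Elem()
		indir++
	}
	return rt, indir
}

func main() {
	t, n := baseType(reflect.TypeOf((**int)(nil))) // was reflect.Typeof
	fmt.Println(t, n)                              // prints "int 2"
}
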
diff --git a/src/cmd/fix/testdata/reflect.type.go.out b/src/cmd/fix/testdata/reflect.type.go.out
new file mode 100644
index 000000000..d729ea471
--- /dev/null
+++ b/src/cmd/fix/testdata/reflect.type.go.out
@@ -0,0 +1,790 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+ "sync"
+ "unicode"
+ "utf8"
+)
+
+// userTypeInfo stores the information associated with a type the user has handed
+// to the package. It's computed once and stored in a map keyed by reflection
+// type.
+type userTypeInfo struct {
+ user reflect.Type // the type the user handed us
+ base reflect.Type // the base type after all indirections
+ indir int // number of indirections to reach the base type
+ isGobEncoder bool // does the type implement GobEncoder?
+ isGobDecoder bool // does the type implement GobDecoder?
+ encIndir int8 // number of indirections to reach the receiver type; may be negative
+ decIndir int8 // number of indirections to reach the receiver type; may be negative
+}
+
+var (
+ // Protected by an RWMutex because we read it a lot and write
+ // it only when we see a new type, typically when compiling.
+ userTypeLock sync.RWMutex
+ userTypeCache = make(map[reflect.Type]*userTypeInfo)
+)
+
+// validType returns, and saves, the information associated with user-provided type rt.
+// If the user type is not valid, err will be non-nil. To be used when the error handler
+// is not set up.
+func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
+ userTypeLock.RLock()
+ ut = userTypeCache[rt]
+ userTypeLock.RUnlock()
+ if ut != nil {
+ return
+ }
+ // Now set the value under the write lock.
+ userTypeLock.Lock()
+ defer userTypeLock.Unlock()
+ if ut = userTypeCache[rt]; ut != nil {
+ // Lost the race; not a problem.
+ return
+ }
+ ut = new(userTypeInfo)
+ ut.base = rt
+ ut.user = rt
+ // A type that is just a cycle of pointers (such as type T *T) cannot
+ // be represented in gobs, which need some concrete data. We use a
+ // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
+ // pp 539-540. As we step through indirections, run another type at
+ // half speed. If they meet up, there's a cycle.
+ slowpoke := ut.base // walks half as fast as ut.base
+ for {
+ pt := ut.base
+ if pt.Kind() != reflect.Ptr {
+ break
+ }
+ ut.base = pt.Elem()
+ if ut.base == slowpoke { // ut.base lapped slowpoke
+ // recursive pointer type.
+ return nil, os.NewError("can't represent recursive pointer type " + ut.base.String())
+ }
+ if ut.indir%2 == 0 {
+ slowpoke = slowpoke.Elem()
+ }
+ ut.indir++
+ }
+ ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck)
+ ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck)
+ userTypeCache[rt] = ut
+ return
+}
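
The pointer-cycle check above is the classic two-speed cycle detection (the Knuth exercise it cites) applied to type indirections. A minimal, hypothetical standalone sketch of the same idea follows; it is not part of the fixture, and hasPointerCycle is invented for illustration.

package main

import (
	"fmt"
	"reflect"
)

type T *T // a pure cycle of pointers, which gob cannot represent

// hasPointerCycle dereferences rt at full speed and slow at half speed, the
// same scheme used in validUserType above; if the fast walker ever lands on
// the slow one, the indirections form a cycle.
func hasPointerCycle(rt reflect.Type) bool {
	slow, indir := rt, 0
	for rt.Kind() == reflect.Ptr {
		rt = rt.Elem()
		if rt == slow {
			return true
		}
		if indir%2 == 0 {
			slow = slow.Elem()
		}
		indir++
	}
	return false
}

func main() {
	fmt.Println(hasPointerCycle(reflect.TypeOf(T(nil))))       // true
	fmt.Println(hasPointerCycle(reflect.TypeOf((**int)(nil)))) // false
}
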
+
+const (
+ gobEncodeMethodName = "GobEncode"
+ gobDecodeMethodName = "GobDecode"
+)
+
+// implements returns whether the type implements the interface, as encoded
+// in the check function.
+func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool {
+ if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance
+ return false
+ }
+ return check(typ)
+}
+
+// gobEncoderCheck makes the type assertion a boolean function.
+func gobEncoderCheck(typ reflect.Type) bool {
+ _, ok := reflect.Zero(typ).Interface().(GobEncoder)
+ return ok
+}
+
+// gobDecoderCheck makes the type assertion a boolean function.
+func gobDecoderCheck(typ reflect.Type) bool {
+ _, ok := reflect.Zero(typ).Interface().(GobDecoder)
+ return ok
+}
+
+// implementsInterface reports whether the type implements the
+// interface. (The actual check is done through the provided function.)
+// It also returns the number of indirections required to get to the
+// implementation.
+func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) {
+ if typ == nil {
+ return
+ }
+ rt := typ
+ // The type might be a pointer and we need to keep
+ // dereferencing to the base type until we find an implementation.
+ for {
+ if implements(rt, check) {
+ return true, indir
+ }
+ if p := rt; p.Kind() == reflect.Ptr {
+ indir++
+ if indir > 100 { // insane number of indirections
+ return false, 0
+ }
+ rt = p.Elem()
+ continue
+ }
+ break
+ }
+ // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy.
+ if typ.Kind() != reflect.Ptr {
+ // Not a pointer, but does the pointer work?
+ if implements(reflect.PtrTo(typ), check) {
+ return true, -1
+ }
+ }
+ return false, 0
+}
+
+// userType returns, and saves, the information associated with user-provided type rt.
+// If the user type is not valid, it calls error.
+func userType(rt reflect.Type) *userTypeInfo {
+ ut, err := validUserType(rt)
+ if err != nil {
+ error(err)
+ }
+ return ut
+}
+
+// A typeId represents a gob Type as an integer that can be passed on the wire.
+// Internally, typeIds are used as keys to a map to recover the underlying type info.
+type typeId int32
+
+var nextId typeId // incremented for each new type we build
+var typeLock sync.Mutex // set while building a type
+const firstUserId = 64 // lowest id number granted to user
+
+type gobType interface {
+ id() typeId
+ setId(id typeId)
+ name() string
+ string() string // not public; only for debugging
+ safeString(seen map[typeId]bool) string
+}
+
+var types = make(map[reflect.Type]gobType)
+var idToType = make(map[typeId]gobType)
+var builtinIdToType map[typeId]gobType // set in init() after builtins are established
+
+func setTypeId(typ gobType) {
+ nextId++
+ typ.setId(nextId)
+ idToType[nextId] = typ
+}
+
+func (t typeId) gobType() gobType {
+ if t == 0 {
+ return nil
+ }
+ return idToType[t]
+}
+
+// string returns the string representation of the type associated with the typeId.
+func (t typeId) string() string {
+ if t.gobType() == nil {
+ return "<nil>"
+ }
+ return t.gobType().string()
+}
+
+// Name returns the name of the type associated with the typeId.
+func (t typeId) name() string {
+ if t.gobType() == nil {
+ return "<nil>"
+ }
+ return t.gobType().name()
+}
+
+// Common elements of all types.
+type CommonType struct {
+ Name string
+ Id typeId
+}
+
+func (t *CommonType) id() typeId { return t.Id }
+
+func (t *CommonType) setId(id typeId) { t.Id = id }
+
+func (t *CommonType) string() string { return t.Name }
+
+func (t *CommonType) safeString(seen map[typeId]bool) string {
+ return t.Name
+}
+
+func (t *CommonType) name() string { return t.Name }
+
+// Create and check predefined types
+// The string for tBytes is "bytes" not "[]byte" to signify its specialness.
+
+var (
+ // Primordial types, needed during initialization.
+ // Always passed as pointers so the interface{} type
+ // goes through without losing its interfaceness.
+ tBool = bootstrapType("bool", (*bool)(nil), 1)
+ tInt = bootstrapType("int", (*int)(nil), 2)
+ tUint = bootstrapType("uint", (*uint)(nil), 3)
+ tFloat = bootstrapType("float", (*float64)(nil), 4)
+ tBytes = bootstrapType("bytes", (*[]byte)(nil), 5)
+ tString = bootstrapType("string", (*string)(nil), 6)
+ tComplex = bootstrapType("complex", (*complex128)(nil), 7)
+ tInterface = bootstrapType("interface", (*interface{})(nil), 8)
+ // Reserve some Ids for compatible expansion
+ tReserved7 = bootstrapType("_reserved1", (*struct{ r7 int })(nil), 9)
+ tReserved6 = bootstrapType("_reserved1", (*struct{ r6 int })(nil), 10)
+ tReserved5 = bootstrapType("_reserved1", (*struct{ r5 int })(nil), 11)
+ tReserved4 = bootstrapType("_reserved1", (*struct{ r4 int })(nil), 12)
+ tReserved3 = bootstrapType("_reserved1", (*struct{ r3 int })(nil), 13)
+ tReserved2 = bootstrapType("_reserved1", (*struct{ r2 int })(nil), 14)
+ tReserved1 = bootstrapType("_reserved1", (*struct{ r1 int })(nil), 15)
+)
+
+// Predefined because it's needed by the Decoder
+var tWireType = mustGetTypeInfo(reflect.TypeOf(wireType{})).id
+var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType)
+
+func init() {
+ // Some magic numbers to make sure there are no surprises.
+ checkId(16, tWireType)
+ checkId(17, mustGetTypeInfo(reflect.TypeOf(arrayType{})).id)
+ checkId(18, mustGetTypeInfo(reflect.TypeOf(CommonType{})).id)
+ checkId(19, mustGetTypeInfo(reflect.TypeOf(sliceType{})).id)
+ checkId(20, mustGetTypeInfo(reflect.TypeOf(structType{})).id)
+ checkId(21, mustGetTypeInfo(reflect.TypeOf(fieldType{})).id)
+ checkId(23, mustGetTypeInfo(reflect.TypeOf(mapType{})).id)
+
+ builtinIdToType = make(map[typeId]gobType)
+ for k, v := range idToType {
+ builtinIdToType[k] = v
+ }
+
+ // Move the id space upwards to allow for growth in the predefined world
+ // without breaking existing files.
+ if nextId > firstUserId {
+ panic(fmt.Sprintln("nextId too large:", nextId))
+ }
+ nextId = firstUserId
+ registerBasics()
+ wireTypeUserInfo = userType(reflect.TypeOf((*wireType)(nil)))
+}
+
+// Array type
+type arrayType struct {
+ CommonType
+ Elem typeId
+ Len int
+}
+
+func newArrayType(name string) *arrayType {
+ a := &arrayType{CommonType{Name: name}, 0, 0}
+ return a
+}
+
+func (a *arrayType) init(elem gobType, len int) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(a)
+ a.Elem = elem.id()
+ a.Len = len
+}
+
+func (a *arrayType) safeString(seen map[typeId]bool) string {
+ if seen[a.Id] {
+ return a.Name
+ }
+ seen[a.Id] = true
+ return fmt.Sprintf("[%d]%s", a.Len, a.Elem.gobType().safeString(seen))
+}
+
+func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
+
+// GobEncoder type (something that implements the GobEncoder interface)
+type gobEncoderType struct {
+ CommonType
+}
+
+func newGobEncoderType(name string) *gobEncoderType {
+ g := &gobEncoderType{CommonType{Name: name}}
+ setTypeId(g)
+ return g
+}
+
+func (g *gobEncoderType) safeString(seen map[typeId]bool) string {
+ return g.Name
+}
+
+func (g *gobEncoderType) string() string { return g.Name }
+
+// Map type
+type mapType struct {
+ CommonType
+ Key typeId
+ Elem typeId
+}
+
+func newMapType(name string) *mapType {
+ m := &mapType{CommonType{Name: name}, 0, 0}
+ return m
+}
+
+func (m *mapType) init(key, elem gobType) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(m)
+ m.Key = key.id()
+ m.Elem = elem.id()
+}
+
+func (m *mapType) safeString(seen map[typeId]bool) string {
+ if seen[m.Id] {
+ return m.Name
+ }
+ seen[m.Id] = true
+ key := m.Key.gobType().safeString(seen)
+ elem := m.Elem.gobType().safeString(seen)
+ return fmt.Sprintf("map[%s]%s", key, elem)
+}
+
+func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) }
+
+// Slice type
+type sliceType struct {
+ CommonType
+ Elem typeId
+}
+
+func newSliceType(name string) *sliceType {
+ s := &sliceType{CommonType{Name: name}, 0}
+ return s
+}
+
+func (s *sliceType) init(elem gobType) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(s)
+ s.Elem = elem.id()
+}
+
+func (s *sliceType) safeString(seen map[typeId]bool) string {
+ if seen[s.Id] {
+ return s.Name
+ }
+ seen[s.Id] = true
+ return fmt.Sprintf("[]%s", s.Elem.gobType().safeString(seen))
+}
+
+func (s *sliceType) string() string { return s.safeString(make(map[typeId]bool)) }
+
+// Struct type
+type fieldType struct {
+ Name string
+ Id typeId
+}
+
+type structType struct {
+ CommonType
+ Field []*fieldType
+}
+
+func (s *structType) safeString(seen map[typeId]bool) string {
+ if s == nil {
+ return "<nil>"
+ }
+ if _, ok := seen[s.Id]; ok {
+ return s.Name
+ }
+ seen[s.Id] = true
+ str := s.Name + " = struct { "
+ for _, f := range s.Field {
+ str += fmt.Sprintf("%s %s; ", f.Name, f.Id.gobType().safeString(seen))
+ }
+ str += "}"
+ return str
+}
+
+func (s *structType) string() string { return s.safeString(make(map[typeId]bool)) }
+
+func newStructType(name string) *structType {
+ s := &structType{CommonType{Name: name}, nil}
+ // For historical reasons we set the id here rather than init.
+ // See the comment in newTypeObject for details.
+ setTypeId(s)
+ return s
+}
+
+// newTypeObject allocates a gobType for the reflection type rt.
+// Unless ut represents a GobEncoder, rt should be the base type
+// of ut.
+// This is only called from the encoding side. The decoding side
+// works through typeIds and userTypeInfos alone.
+func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
+ // Does this type implement GobEncoder?
+ if ut.isGobEncoder {
+ return newGobEncoderType(name), nil
+ }
+ var err os.Error
+ var type0, type1 gobType
+ defer func() {
+ if err != nil {
+ types[rt] = nil, false
+ }
+ }()
+ // Install the top-level type before the subtypes (e.g. struct before
+ // fields) so recursive types can be constructed safely.
+ switch t := rt; t.Kind() {
+ // All basic types are easy: they are predefined.
+ case reflect.Bool:
+ return tBool.gobType(), nil
+
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return tInt.gobType(), nil
+
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return tUint.gobType(), nil
+
+ case reflect.Float32, reflect.Float64:
+ return tFloat.gobType(), nil
+
+ case reflect.Complex64, reflect.Complex128:
+ return tComplex.gobType(), nil
+
+ case reflect.String:
+ return tString.gobType(), nil
+
+ case reflect.Interface:
+ return tInterface.gobType(), nil
+
+ case reflect.Array:
+ at := newArrayType(name)
+ types[rt] = at
+ type0, err = getBaseType("", t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ // Historical aside:
+ // For arrays, maps, and slices, we set the type id after the elements
+ // are constructed. This is to retain the order of type id allocation after
+ // a fix made to handle recursive types, which changed the order in
+ // which types are built. Delaying the setting in this way preserves
+ // type ids while allowing recursive types to be described. Structs,
+ // done below, were already handling recursion correctly so they
+ // assign the top-level id before those of the field.
+ at.init(type0, t.Len())
+ return at, nil
+
+ case reflect.Map:
+ mt := newMapType(name)
+ types[rt] = mt
+ type0, err = getBaseType("", t.Key())
+ if err != nil {
+ return nil, err
+ }
+ type1, err = getBaseType("", t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ mt.init(type0, type1)
+ return mt, nil
+
+ case reflect.Slice:
+ // []byte == []uint8 is a special case
+ if t.Elem().Kind() == reflect.Uint8 {
+ return tBytes.gobType(), nil
+ }
+ st := newSliceType(name)
+ types[rt] = st
+ type0, err = getBaseType(t.Elem().Name(), t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ st.init(type0)
+ return st, nil
+
+ case reflect.Struct:
+ st := newStructType(name)
+ types[rt] = st
+ idToType[st.id()] = st
+ for i := 0; i < t.NumField(); i++ {
+ f := t.Field(i)
+ if !isExported(f.Name) {
+ continue
+ }
+ typ := userType(f.Type).base
+ tname := typ.Name()
+ if tname == "" {
+ t := userType(f.Type).base
+ tname = t.String()
+ }
+ gt, err := getBaseType(tname, f.Type)
+ if err != nil {
+ return nil, err
+ }
+ st.Field = append(st.Field, &fieldType{f.Name, gt.id()})
+ }
+ return st, nil
+
+ default:
+ return nil, os.NewError("gob NewTypeObject can't handle type: " + rt.String())
+ }
+ return nil, nil
+}
+
+// isExported reports whether this is an exported - upper case - name.
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// getBaseType returns the Gob type describing the given reflect.Type's base type.
+// typeLock must be held.
+func getBaseType(name string, rt reflect.Type) (gobType, os.Error) {
+ ut := userType(rt)
+ return getType(name, ut, ut.base)
+}
+
+// getType returns the Gob type describing the given reflect.Type.
+// Should be called only when handling GobEncoders/Decoders,
+// which may be pointers. All other types are handled through the
+// base type, never a pointer.
+// typeLock must be held.
+func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
+ typ, present := types[rt]
+ if present {
+ return typ, nil
+ }
+ typ, err := newTypeObject(name, ut, rt)
+ if err == nil {
+ types[rt] = typ
+ }
+ return typ, err
+}
+
+func checkId(want, got typeId) {
+ if want != got {
+ fmt.Fprintf(os.Stderr, "checkId: %d should be %d\n", int(got), int(want))
+ panic("bootstrap type wrong id: " + got.name() + " " + got.string() + " not " + want.string())
+ }
+}
+
+// used for building the basic types; called only from init(). the incoming
+// interface always refers to a pointer.
+func bootstrapType(name string, e interface{}, expect typeId) typeId {
+ rt := reflect.TypeOf(e).Elem()
+ _, present := types[rt]
+ if present {
+ panic("bootstrap type already present: " + name + ", " + rt.String())
+ }
+ typ := &CommonType{Name: name}
+ types[rt] = typ
+ setTypeId(typ)
+ checkId(expect, nextId)
+ userType(rt) // might as well cache it now
+ return nextId
+}
+
+// Representation of the information we send and receive about this type.
+// Each value we send is preceded by its type definition: an encoded int.
+// However, the very first time we send the value, we first send the pair
+// (-id, wireType).
+// For bootstrapping purposes, we assume that the recipient knows how
+// to decode a wireType; it is exactly the wireType struct here, interpreted
+// using the gob rules for sending a structure, except that we assume the
+// ids for wireType and structType etc. are known. The relevant pieces
+// are built in encode.go's init() function.
+// To maintain binary compatibility, if you extend this type, always put
+// the new fields last.
+type wireType struct {
+ ArrayT *arrayType
+ SliceT *sliceType
+ StructT *structType
+ MapT *mapType
+ GobEncoderT *gobEncoderType
+}
+
+func (w *wireType) string() string {
+ const unknown = "unknown type"
+ if w == nil {
+ return unknown
+ }
+ switch {
+ case w.ArrayT != nil:
+ return w.ArrayT.Name
+ case w.SliceT != nil:
+ return w.SliceT.Name
+ case w.StructT != nil:
+ return w.StructT.Name
+ case w.MapT != nil:
+ return w.MapT.Name
+ case w.GobEncoderT != nil:
+ return w.GobEncoderT.Name
+ }
+ return unknown
+}
+
+type typeInfo struct {
+ id typeId
+ encoder *encEngine
+ wire *wireType
+}
+
+var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock
+
+// typeLock must be held.
+func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) {
+ rt := ut.base
+ if ut.isGobEncoder {
+ // We want the user type, not the base type.
+ rt = ut.user
+ }
+ info, ok := typeInfoMap[rt]
+ if ok {
+ return info, nil
+ }
+ info = new(typeInfo)
+ gt, err := getBaseType(rt.Name(), rt)
+ if err != nil {
+ return nil, err
+ }
+ info.id = gt.id()
+
+ if ut.isGobEncoder {
+ userType, err := getType(rt.Name(), ut, rt)
+ if err != nil {
+ return nil, err
+ }
+ info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)}
+ typeInfoMap[ut.user] = info
+ return info, nil
+ }
+
+ t := info.id.gobType()
+ switch typ := rt; typ.Kind() {
+ case reflect.Array:
+ info.wire = &wireType{ArrayT: t.(*arrayType)}
+ case reflect.Map:
+ info.wire = &wireType{MapT: t.(*mapType)}
+ case reflect.Slice:
+ // []byte == []uint8 is a special case handled separately
+ if typ.Elem().Kind() != reflect.Uint8 {
+ info.wire = &wireType{SliceT: t.(*sliceType)}
+ }
+ case reflect.Struct:
+ info.wire = &wireType{StructT: t.(*structType)}
+ }
+ typeInfoMap[rt] = info
+ return info, nil
+}
+
+// Called only when a panic is acceptable and unexpected.
+func mustGetTypeInfo(rt reflect.Type) *typeInfo {
+ t, err := getTypeInfo(userType(rt))
+ if err != nil {
+ panic("getTypeInfo: " + err.String())
+ }
+ return t
+}
+
+// GobEncoder is the interface describing data that provides its own
+// representation for encoding values for transmission to a GobDecoder.
+// A type that implements GobEncoder and GobDecoder has complete
+// control over the representation of its data and may therefore
+// contain things such as private fields, channels, and functions,
+// which are not usually transmissible in gob streams.
+//
+// Note: Since gobs can be stored permanently, it is good design
+// to guarantee the encoding used by a GobEncoder is stable as the
+// software evolves. For instance, it might make sense for GobEncode
+// to include a version number in the encoding.
+type GobEncoder interface {
+ // GobEncode returns a byte slice representing the encoding of the
+ // receiver for transmission to a GobDecoder, usually of the same
+ // concrete type.
+ GobEncode() ([]byte, os.Error)
+}
+
+// GobDecoder is the interface describing data that provides its own
+// routine for decoding transmitted values sent by a GobEncoder.
+type GobDecoder interface {
+ // GobDecode overwrites the receiver, which must be a pointer,
+ // with the value represented by the byte slice, which was written
+ // by GobEncode, usually for the same concrete type.
+ GobDecode([]byte) os.Error
+}
+
+var (
+ nameToConcreteType = make(map[string]reflect.Type)
+ concreteTypeToName = make(map[reflect.Type]string)
+)
+
+// RegisterName is like Register but uses the provided name rather than the
+// type's default.
+func RegisterName(name string, value interface{}) {
+ if name == "" {
+ // reserved for nil
+ panic("attempt to register empty name")
+ }
+ base := userType(reflect.TypeOf(value)).base
+ // Check for incompatible duplicates.
+ if t, ok := nameToConcreteType[name]; ok && t != base {
+ panic("gob: registering duplicate types for " + name)
+ }
+ if n, ok := concreteTypeToName[base]; ok && n != name {
+ panic("gob: registering duplicate names for " + base.String())
+ }
+ // Store the name and type provided by the user....
+ nameToConcreteType[name] = reflect.TypeOf(value)
+ // but the flattened type in the type table, since that's what decode needs.
+ concreteTypeToName[base] = name
+}
+
+// Register records a type, identified by a value for that type, under its
+// internal type name. That name will identify the concrete type of a value
+// sent or received as an interface variable. Only types that will be
+// transferred as implementations of interface values need to be registered.
+// Expecting to be used only during initialization, it panics if the mapping
+// between types and names is not a bijection.
+func Register(value interface{}) {
+ // Default to printed representation for unnamed types
+ rt := reflect.TypeOf(value)
+ name := rt.String()
+
+ // But for named types (or pointers to them), qualify with import path.
+ // Dereference one pointer looking for a named type.
+ star := ""
+ if rt.Name() == "" {
+ if pt := rt; pt.Kind() == reflect.Ptr {
+ star = "*"
+ rt = pt
+ }
+ }
+ if rt.Name() != "" {
+ if rt.PkgPath() == "" {
+ name = star + rt.Name()
+ } else {
+ name = star + rt.PkgPath() + "." + rt.Name()
+ }
+ }
+
+ RegisterName(name, value)
+}
+
+func registerBasics() {
+ Register(int(0))
+ Register(int8(0))
+ Register(int16(0))
+ Register(int32(0))
+ Register(int64(0))
+ Register(uint(0))
+ Register(uint8(0))
+ Register(uint16(0))
+ Register(uint32(0))
+ Register(uint64(0))
+ Register(float32(0))
+ Register(float64(0))
+ Register(complex64(0i))
+ Register(complex128(0i))
+ Register(false)
+ Register("")
+ Register([]byte(nil))
+}
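
The GobEncoder/GobDecoder interfaces and Register documented above are the user-facing half of this file. As a self-contained illustration of how they fit together — written against the final Go 1 API, which uses error rather than the os.Error still visible in this fixture — here is a minimal, hypothetical example; Point and its hand-rolled wire format are invented for illustration.

package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
)

// Point keeps its fields unexported; GobEncode/GobDecode control the wire form.
type Point struct{ x, y int }

func (p Point) GobEncode() ([]byte, error) {
	return []byte(fmt.Sprintf("%d,%d", p.x, p.y)), nil
}

func (p *Point) GobDecode(b []byte) error {
	_, err := fmt.Sscanf(string(b), "%d,%d", &p.x, &p.y)
	return err
}

func main() {
	gob.Register(Point{}) // needed only if Point travels inside an interface value

	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(Point{1, 2}); err != nil {
		panic(err)
	}
	var q Point
	if err := gob.NewDecoder(&buf).Decode(&q); err != nil {
		panic(err)
	}
	fmt.Println(q) // {1 2}
}
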
diff --git a/src/cmd/fix/timefileinfo.go b/src/cmd/fix/timefileinfo.go
new file mode 100644
index 000000000..b2ea23d8f
--- /dev/null
+++ b/src/cmd/fix/timefileinfo.go
@@ -0,0 +1,298 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+ "go/token"
+ "strings"
+)
+
+func init() {
+ register(timefileinfoFix)
+}
+
+var timefileinfoFix = fix{
+ "time+fileinfo",
+ "2011-11-29",
+ timefileinfo,
+ `Rewrite for new time and os.FileInfo APIs.
+
+This fix applies some of the more mechanical changes,
+but most code will still need manual cleanup.
+
+http://codereview.appspot.com/5392041
+http://codereview.appspot.com/5416060
+`,
+}
+
+var timefileinfoTypeConfig = &TypeConfig{
+ Type: map[string]*Type{
+ "os.File": {
+ Method: map[string]string{
+ "Readdir": "func() []*os.FileInfo",
+ "Stat": "func() (*os.FileInfo, error)",
+ },
+ },
+ "time.Time": {
+ Method: map[string]string{
+ "Seconds": "time.raw",
+ "Nanoseconds": "time.raw",
+ },
+ },
+ },
+ Func: map[string]string{
+ "ioutil.ReadDir": "([]*os.FileInfo, error)",
+ "os.Stat": "(*os.FileInfo, error)",
+ "os.Lstat": "(*os.FileInfo, error)",
+ "time.LocalTime": "*time.Time",
+ "time.UTC": "*time.Time",
+ "time.SecondsToLocalTime": "*time.Time",
+ "time.SecondsToUTC": "*time.Time",
+ "time.NanosecondsToLocalTime": "*time.Time",
+ "time.NanosecondsToUTC": "*time.Time",
+ "time.Parse": "(*time.Time, error)",
+ "time.Nanoseconds": "time.raw",
+ "time.Seconds": "time.raw",
+ },
+}
+
+// timefileinfoIsOld reports whether f has evidence of being
+// "old code", from before the API changes. Evidence means:
+//
+// a mention of *os.FileInfo (the pointer)
+// a mention of *time.Time (the pointer)
+// a mention of old functions from package time
+// an attempt to call time.UTC
+//
+func timefileinfoIsOld(f *ast.File, typeof map[interface{}]string) bool {
+ old := false
+
+ // called records the expressions that appear as
+ // the function part of a function call, so that
+ // we can distinguish a ref to the possibly new time.UTC
+ // from the definitely old time.UTC() function call.
+ called := make(map[interface{}]bool)
+
+ before := func(n interface{}) {
+ if old {
+ return
+ }
+ if star, ok := n.(*ast.StarExpr); ok {
+ if isPkgDot(star.X, "os", "FileInfo") || isPkgDot(star.X, "time", "Time") {
+ old = true
+ return
+ }
+ }
+ if sel, ok := n.(*ast.SelectorExpr); ok {
+ if isTopName(sel.X, "time") {
+ if timefileinfoOldTimeFunc[sel.Sel.Name] {
+ old = true
+ return
+ }
+ }
+ if typeof[sel.X] == "os.FileInfo" || typeof[sel.X] == "*os.FileInfo" {
+ switch sel.Sel.Name {
+ case "Mtime_ns", "IsDirectory", "IsRegular":
+ old = true
+ return
+ case "Name", "Mode", "Size":
+ if !called[sel] {
+ old = true
+ return
+ }
+ }
+ }
+ }
+ call, ok := n.(*ast.CallExpr)
+ if ok && isPkgDot(call.Fun, "time", "UTC") {
+ old = true
+ return
+ }
+ if ok {
+ called[call.Fun] = true
+ }
+ }
+ walkBeforeAfter(f, before, nop)
+ return old
+}
+
+var timefileinfoOldTimeFunc = map[string]bool{
+ "LocalTime": true,
+ "SecondsToLocalTime": true,
+ "SecondsToUTC": true,
+ "NanosecondsToLocalTime": true,
+ "NanosecondsToUTC": true,
+ "Seconds": true,
+ "Nanoseconds": true,
+}
+
+var isTimeNow = map[string]bool{
+ "LocalTime": true,
+ "UTC": true,
+ "Seconds": true,
+ "Nanoseconds": true,
+}
+
+func timefileinfo(f *ast.File) bool {
+ if !imports(f, "os") && !imports(f, "time") && !imports(f, "io/ioutil") {
+ return false
+ }
+
+ typeof, _ := typecheck(timefileinfoTypeConfig, f)
+
+ if !timefileinfoIsOld(f, typeof) {
+ return false
+ }
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ p, ok := n.(*ast.Expr)
+ if !ok {
+ return
+ }
+ nn := *p
+
+ // Rewrite *os.FileInfo and *time.Time to drop the pointer.
+ if star, ok := nn.(*ast.StarExpr); ok {
+ if isPkgDot(star.X, "os", "FileInfo") || isPkgDot(star.X, "time", "Time") {
+ fixed = true
+ *p = star.X
+ return
+ }
+ }
+
+ // Rewrite old time API calls to new calls.
+ // The code will still not compile after this edit,
+ // but the compiler will catch that, and the replacement
+ // code will be the correct functions to use in the new API.
+ if sel, ok := nn.(*ast.SelectorExpr); ok && isTopName(sel.X, "time") {
+ fn := sel.Sel.Name
+ if fn == "LocalTime" || fn == "Seconds" || fn == "Nanoseconds" {
+ fixed = true
+ sel.Sel.Name = "Now"
+ return
+ }
+ }
+
+ if call, ok := nn.(*ast.CallExpr); ok {
+ if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
+ // Rewrite time.UTC but only when called (there's a new time.UTC var now).
+ if isPkgDot(sel, "time", "UTC") {
+ fixed = true
+ sel.Sel.Name = "Now"
+ // rewrite time.Now() into time.Now().UTC()
+ *p = &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: call,
+ Sel: ast.NewIdent("UTC"),
+ },
+ }
+ return
+ }
+
+ // Rewrite conversions.
+ if ok && isTopName(sel.X, "time") && len(call.Args) == 1 {
+ fn := sel.Sel.Name
+ switch fn {
+ case "SecondsToLocalTime", "SecondsToUTC",
+ "NanosecondsToLocalTime", "NanosecondsToUTC":
+ fixed = true
+ sel.Sel.Name = "Unix"
+ call.Args = append(call.Args, nil)
+ if strings.HasPrefix(fn, "Seconds") {
+ // Unix(sec, 0)
+ call.Args[1] = ast.NewIdent("0")
+ } else {
+ // Unix(0, nsec)
+ call.Args[1] = call.Args[0]
+ call.Args[0] = ast.NewIdent("0")
+ }
+ if strings.HasSuffix(fn, "ToUTC") {
+ // rewrite call into call.UTC()
+ *p = &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: call,
+ Sel: ast.NewIdent("UTC"),
+ },
+ }
+ }
+ return
+ }
+ }
+
+ // Rewrite method calls.
+ switch typeof[sel.X] {
+ case "*time.Time", "time.Time":
+ switch sel.Sel.Name {
+ case "Seconds":
+ fixed = true
+ sel.Sel.Name = "Unix"
+ return
+ case "Nanoseconds":
+ fixed = true
+ sel.Sel.Name = "UnixNano"
+ return
+ }
+
+ case "*os.FileInfo", "os.FileInfo":
+ switch sel.Sel.Name {
+ case "IsDirectory":
+ fixed = true
+ sel.Sel.Name = "IsDir"
+ return
+ case "IsRegular":
+ fixed = true
+ sel.Sel.Name = "IsDir"
+ *p = &ast.UnaryExpr{
+ Op: token.NOT,
+ X: call,
+ }
+ return
+ }
+ }
+ }
+ }
+
+ // Rewrite subtraction of two times.
+ // Cannot handle +=/-=.
+ if bin, ok := nn.(*ast.BinaryExpr); ok &&
+ bin.Op == token.SUB &&
+ (typeof[bin.X] == "time.raw" || typeof[bin.Y] == "time.raw") {
+ fixed = true
+ *p = &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: bin.X,
+ Sel: ast.NewIdent("Sub"),
+ },
+ Args: []ast.Expr{bin.Y},
+ }
+ }
+
+ // Rewrite field references for os.FileInfo.
+ if sel, ok := nn.(*ast.SelectorExpr); ok {
+ if typ := typeof[sel.X]; typ == "*os.FileInfo" || typ == "os.FileInfo" {
+ addCall := false
+ switch sel.Sel.Name {
+ case "Name", "Size", "Mode":
+ fixed = true
+ addCall = true
+ case "Mtime_ns":
+ fixed = true
+ sel.Sel.Name = "ModTime"
+ addCall = true
+ }
+ if addCall {
+ *p = &ast.CallExpr{
+ Fun: sel,
+ }
+ return
+ }
+ }
+ }
+ })
+
+ return true
+}
diff --git a/src/cmd/fix/timefileinfo_test.go b/src/cmd/fix/timefileinfo_test.go
new file mode 100644
index 000000000..6573b8545
--- /dev/null
+++ b/src/cmd/fix/timefileinfo_test.go
@@ -0,0 +1,187 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(timefileinfoTests, timefileinfo)
+}
+
+var timefileinfoTests = []testCase{
+ {
+ Name: "timefileinfo.0",
+ In: `package main
+
+import "os"
+
+func main() {
+ st, _ := os.Stat("/etc/passwd")
+ _ = st.Name
+}
+`,
+ Out: `package main
+
+import "os"
+
+func main() {
+ st, _ := os.Stat("/etc/passwd")
+ _ = st.Name()
+}
+`,
+ },
+ {
+ Name: "timefileinfo.1",
+ In: `package main
+
+import "os"
+
+func main() {
+ st, _ := os.Stat("/etc/passwd")
+ _ = st.Size
+ _ = st.Mode
+ _ = st.Mtime_ns
+ _ = st.IsDirectory()
+ _ = st.IsRegular()
+}
+`,
+ Out: `package main
+
+import "os"
+
+func main() {
+ st, _ := os.Stat("/etc/passwd")
+ _ = st.Size()
+ _ = st.Mode()
+ _ = st.ModTime()
+ _ = st.IsDir()
+ _ = !st.IsDir()
+}
+`,
+ },
+ {
+ Name: "timefileinfo.2",
+ In: `package main
+
+import "os"
+
+func f(st *os.FileInfo) {
+ _ = st.Name
+ _ = st.Size
+ _ = st.Mode
+ _ = st.Mtime_ns
+ _ = st.IsDirectory()
+ _ = st.IsRegular()
+}
+`,
+ Out: `package main
+
+import "os"
+
+func f(st os.FileInfo) {
+ _ = st.Name()
+ _ = st.Size()
+ _ = st.Mode()
+ _ = st.ModTime()
+ _ = st.IsDir()
+ _ = !st.IsDir()
+}
+`,
+ },
+ {
+ Name: "timefileinfo.3",
+ In: `package main
+
+import "time"
+
+func main() {
+ _ = time.Seconds()
+ _ = time.Nanoseconds()
+ _ = time.LocalTime()
+ _ = time.UTC()
+ _ = time.SecondsToLocalTime(sec)
+ _ = time.SecondsToUTC(sec)
+ _ = time.NanosecondsToLocalTime(nsec)
+ _ = time.NanosecondsToUTC(nsec)
+}
+`,
+ Out: `package main
+
+import "time"
+
+func main() {
+ _ = time.Now()
+ _ = time.Now()
+ _ = time.Now()
+ _ = time.Now().UTC()
+ _ = time.Unix(sec, 0)
+ _ = time.Unix(sec, 0).UTC()
+ _ = time.Unix(0, nsec)
+ _ = time.Unix(0, nsec).UTC()
+}
+`,
+ },
+ {
+ Name: "timefileinfo.4",
+ In: `package main
+
+import "time"
+
+func f(*time.Time)
+
+func main() {
+ t := time.LocalTime()
+ _ = t.Seconds()
+ _ = t.Nanoseconds()
+
+ t1 := time.Nanoseconds()
+ f(nil)
+ t2 := time.Nanoseconds()
+ dt := t2 - t1
+}
+`,
+ Out: `package main
+
+import "time"
+
+func f(time.Time)
+
+func main() {
+ t := time.Now()
+ _ = t.Unix()
+ _ = t.UnixNano()
+
+ t1 := time.Now()
+ f(nil)
+ t2 := time.Now()
+ dt := t2.Sub(t1)
+}
+`,
+ },
+ {
+ Name: "timefileinfo.5", // test for issues 1505, 2636
+ In: `package main
+
+import (
+ "fmt"
+ "time"
+)
+
+func main() {
+ fmt.Println(time.SecondsToUTC(now)) // this comment must not introduce an illegal linebreak
+}
+`,
+ Out: `package main
+
+import (
+ "fmt"
+ "time"
+)
+
+func main() {
+ fmt.Println(time.Unix(now, 0).UTC( // this comment must not introduce an illegal linebreak
+ ))
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/typecheck.go b/src/cmd/fix/typecheck.go
new file mode 100644
index 000000000..8e54314d1
--- /dev/null
+++ b/src/cmd/fix/typecheck.go
@@ -0,0 +1,673 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "os"
+ "reflect"
+ "strings"
+)
+
+// Partial type checker.
+//
+// The fact that it is partial is very important: the input is
+// an AST and a description of some type information to
+// assume about one or more packages, but not all the
+// packages that the program imports. The checker is
+// expected to do as much as it can with what it has been
+// given. There is not enough information supplied to do
+// a full type check, but the type checker is expected to
+// apply information that can be derived from variable
+// declarations, function and method returns, and type switches
+// as far as it can, so that the caller can still tell the types
+// of expression relevant to a particular fix.
+//
+// TODO(rsc,gri): Replace with go/typechecker.
+// Doing that could be an interesting test case for go/typechecker:
+// the constraints about working with partial information will
+// likely exercise it in interesting ways. The ideal interface would
+// be to pass typecheck a map from importpath to package API text
+// (Go source code), but for now we use data structures (TypeConfig, Type).
+//
+// The strings mostly use gofmt form.
+//
+// A Field or FieldList has as its type a comma-separated list
+// of the types of the fields. For example, the field list
+// x, y, z int
+// has type "int, int, int".
+
+// The prefix "type " is the type of a type.
+// For example, given
+// var x int
+// type T int
+// x's type is "int" but T's type is "type int".
+// mkType inserts the "type " prefix.
+// getType removes it.
+// isType tests for it.
+
+func mkType(t string) string {
+ return "type " + t
+}
+
+func getType(t string) string {
+ if !isType(t) {
+ return ""
+ }
+ return t[len("type "):]
+}
+
+func isType(t string) bool {
+ return strings.HasPrefix(t, "type ")
+}
+
+// TypeConfig describes the universe of relevant types.
+// For ease of creation, the types are all referred to by string
+// name (e.g., "reflect.Value"). TypeByName is the only place
+// where the strings are resolved.
+
+type TypeConfig struct {
+ Type map[string]*Type
+ Var map[string]string
+ Func map[string]string
+}
+
+// typeof returns the type of the given name, which may be of
+// the form "x" or "p.X".
+func (cfg *TypeConfig) typeof(name string) string {
+ if cfg.Var != nil {
+ if t := cfg.Var[name]; t != "" {
+ return t
+ }
+ }
+ if cfg.Func != nil {
+ if t := cfg.Func[name]; t != "" {
+ return "func()" + t
+ }
+ }
+ return ""
+}
+
+// Type describes the Fields and Methods of a type.
+// If the field or method cannot be found there, it is next
+// looked for in the Embed list.
+type Type struct {
+ Field map[string]string // map field name to type
+ Method map[string]string // map method name to comma-separated return types (should start with "func ")
+ Embed []string // list of types this type embeds (for extra methods)
+ Def string // definition of named type
+}
+
+// dot returns the type of "typ.name", making its decision
+// using the type information in cfg.
+func (typ *Type) dot(cfg *TypeConfig, name string) string {
+ if typ.Field != nil {
+ if t := typ.Field[name]; t != "" {
+ return t
+ }
+ }
+ if typ.Method != nil {
+ if t := typ.Method[name]; t != "" {
+ return t
+ }
+ }
+
+ for _, e := range typ.Embed {
+ etyp := cfg.Type[e]
+ if etyp != nil {
+ if t := etyp.dot(cfg, name); t != "" {
+ return t
+ }
+ }
+ }
+
+ return ""
+}
+
+// typecheck type checks the AST f assuming the information in cfg.
+// It returns two maps with type information:
+// typeof maps AST nodes to type information in gofmt string form.
+// assign maps type strings to lists of expressions that were assigned
+// to values of another type that were assigned to that type.
+func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, assign map[string][]interface{}) {
+ typeof = make(map[interface{}]string)
+ assign = make(map[string][]interface{})
+ cfg1 := &TypeConfig{}
+ *cfg1 = *cfg // make copy so we can add locally
+ copied := false
+
+ // gather function declarations
+ for _, decl := range f.Decls {
+ fn, ok := decl.(*ast.FuncDecl)
+ if !ok {
+ continue
+ }
+ typecheck1(cfg, fn.Type, typeof, assign)
+ t := typeof[fn.Type]
+ if fn.Recv != nil {
+ // The receiver must be a type.
+ rcvr := typeof[fn.Recv]
+ if !isType(rcvr) {
+ if len(fn.Recv.List) != 1 {
+ continue
+ }
+ rcvr = mkType(gofmt(fn.Recv.List[0].Type))
+ typeof[fn.Recv.List[0].Type] = rcvr
+ }
+ rcvr = getType(rcvr)
+ if rcvr != "" && rcvr[0] == '*' {
+ rcvr = rcvr[1:]
+ }
+ typeof[rcvr+"."+fn.Name.Name] = t
+ } else {
+ if isType(t) {
+ t = getType(t)
+ } else {
+ t = gofmt(fn.Type)
+ }
+ typeof[fn.Name] = t
+
+ // Record typeof[fn.Name.Obj] for future references to fn.Name.
+ typeof[fn.Name.Obj] = t
+ }
+ }
+
+ // gather struct declarations
+ for _, decl := range f.Decls {
+ d, ok := decl.(*ast.GenDecl)
+ if ok {
+ for _, s := range d.Specs {
+ switch s := s.(type) {
+ case *ast.TypeSpec:
+ if cfg1.Type[s.Name.Name] != nil {
+ break
+ }
+ if !copied {
+ copied = true
+ // Copy map lazily: it's time.
+ cfg1.Type = make(map[string]*Type)
+ for k, v := range cfg.Type {
+ cfg1.Type[k] = v
+ }
+ }
+ t := &Type{Field: map[string]string{}}
+ cfg1.Type[s.Name.Name] = t
+ switch st := s.Type.(type) {
+ case *ast.StructType:
+ for _, f := range st.Fields.List {
+ for _, n := range f.Names {
+ t.Field[n.Name] = gofmt(f.Type)
+ }
+ }
+ case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
+ t.Def = gofmt(st)
+ }
+ }
+ }
+ }
+ }
+
+ typecheck1(cfg1, f, typeof, assign)
+ return typeof, assign
+}
+
+func makeExprList(a []*ast.Ident) []ast.Expr {
+ var b []ast.Expr
+ for _, x := range a {
+ b = append(b, x)
+ }
+ return b
+}
+
+// Typecheck1 is the recursive form of typecheck.
+// It is like typecheck but adds to the information in typeof
+// instead of allocating a new map.
+func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string, assign map[string][]interface{}) {
+ // set sets the type of n to typ.
+ // If isDecl is true, n is being declared.
+ set := func(n ast.Expr, typ string, isDecl bool) {
+ if typeof[n] != "" || typ == "" {
+ if typeof[n] != typ {
+ assign[typ] = append(assign[typ], n)
+ }
+ return
+ }
+ typeof[n] = typ
+
+ // If we obtained typ from the declaration of x
+ // propagate the type to all the uses.
+ // The !isDecl case is a cheat here, but it makes
+ // up in some cases for not paying attention to
+ // struct fields. The real type checker will be
+ // more accurate so we won't need the cheat.
+ if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
+ typeof[id.Obj] = typ
+ }
+ }
+
+ // Type-check an assignment lhs = rhs.
+ // If isDecl is true, this is := so we can update
+ // the types of the objects that lhs refers to.
+ typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
+ if len(lhs) > 1 && len(rhs) == 1 {
+ if _, ok := rhs[0].(*ast.CallExpr); ok {
+ t := split(typeof[rhs[0]])
+ // Lists should have same length but may not; pair what can be paired.
+ for i := 0; i < len(lhs) && i < len(t); i++ {
+ set(lhs[i], t[i], isDecl)
+ }
+ return
+ }
+ }
+ if len(lhs) == 1 && len(rhs) == 2 {
+ // x = y, ok
+ rhs = rhs[:1]
+ } else if len(lhs) == 2 && len(rhs) == 1 {
+ // x, ok = y
+ lhs = lhs[:1]
+ }
+
+ // Match as much as we can.
+ for i := 0; i < len(lhs) && i < len(rhs); i++ {
+ x, y := lhs[i], rhs[i]
+ if typeof[y] != "" {
+ set(x, typeof[y], isDecl)
+ } else {
+ set(y, typeof[x], false)
+ }
+ }
+ }
+
+ expand := func(s string) string {
+ typ := cfg.Type[s]
+ if typ != nil && typ.Def != "" {
+ return typ.Def
+ }
+ return s
+ }
+
+ // The main type check is a recursive algorithm implemented
+ // by walkBeforeAfter(n, before, after).
+ // Most of it is bottom-up, but in a few places we need
+ // to know the type of the function we are checking.
+ // The before function records that information on
+ // the curfn stack.
+ var curfn []*ast.FuncType
+
+ before := func(n interface{}) {
+ // push function type on stack
+ switch n := n.(type) {
+ case *ast.FuncDecl:
+ curfn = append(curfn, n.Type)
+ case *ast.FuncLit:
+ curfn = append(curfn, n.Type)
+ }
+ }
+
+ // After is the real type checker.
+ after := func(n interface{}) {
+ if n == nil {
+ return
+ }
+ if false && reflect.TypeOf(n).Kind() == reflect.Ptr { // debugging trace
+ defer func() {
+ if t := typeof[n]; t != "" {
+ pos := fset.Position(n.(ast.Node).Pos())
+ fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
+ }
+ }()
+ }
+
+ switch n := n.(type) {
+ case *ast.FuncDecl, *ast.FuncLit:
+ // pop function type off stack
+ curfn = curfn[:len(curfn)-1]
+
+ case *ast.FuncType:
+ typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))
+
+ case *ast.FieldList:
+ // Field list is concatenation of sub-lists.
+ t := ""
+ for _, field := range n.List {
+ if t != "" {
+ t += ", "
+ }
+ t += typeof[field]
+ }
+ typeof[n] = t
+
+ case *ast.Field:
+ // Field is one instance of the type per name.
+ all := ""
+ t := typeof[n.Type]
+ if !isType(t) {
+ // Create a type, because it is typically *T or *p.T
+ // and we might care about that type.
+ t = mkType(gofmt(n.Type))
+ typeof[n.Type] = t
+ }
+ t = getType(t)
+ if len(n.Names) == 0 {
+ all = t
+ } else {
+ for _, id := range n.Names {
+ if all != "" {
+ all += ", "
+ }
+ all += t
+ typeof[id.Obj] = t
+ typeof[id] = t
+ }
+ }
+ typeof[n] = all
+
+ case *ast.ValueSpec:
+ // var declaration. Use type if present.
+ if n.Type != nil {
+ t := typeof[n.Type]
+ if !isType(t) {
+ t = mkType(gofmt(n.Type))
+ typeof[n.Type] = t
+ }
+ t = getType(t)
+ for _, id := range n.Names {
+ set(id, t, true)
+ }
+ }
+ // Now treat same as assignment.
+ typecheckAssign(makeExprList(n.Names), n.Values, true)
+
+ case *ast.AssignStmt:
+ typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)
+
+ case *ast.Ident:
+ // Identifier can take its type from underlying object.
+ if t := typeof[n.Obj]; t != "" {
+ typeof[n] = t
+ }
+
+ case *ast.SelectorExpr:
+ // Field or method.
+ name := n.Sel.Name
+ if t := typeof[n.X]; t != "" {
+ if strings.HasPrefix(t, "*") {
+ t = t[1:] // implicit *
+ }
+ if typ := cfg.Type[t]; typ != nil {
+ if t := typ.dot(cfg, name); t != "" {
+ typeof[n] = t
+ return
+ }
+ }
+ tt := typeof[t+"."+name]
+ if isType(tt) {
+ typeof[n] = getType(tt)
+ return
+ }
+ }
+ // Package selector.
+ if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
+ str := x.Name + "." + name
+ if cfg.Type[str] != nil {
+ typeof[n] = mkType(str)
+ return
+ }
+ if t := cfg.typeof(x.Name + "." + name); t != "" {
+ typeof[n] = t
+ return
+ }
+ }
+
+ case *ast.CallExpr:
+ // make(T) has type T.
+ if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
+ typeof[n] = gofmt(n.Args[0])
+ return
+ }
+ // new(T) has type *T
+ if isTopName(n.Fun, "new") && len(n.Args) == 1 {
+ typeof[n] = "*" + gofmt(n.Args[0])
+ return
+ }
+ // Otherwise, use type of function to determine arguments.
+ t := typeof[n.Fun]
+ in, out := splitFunc(t)
+ if in == nil && out == nil {
+ return
+ }
+ typeof[n] = join(out)
+ for i, arg := range n.Args {
+ if i >= len(in) {
+ break
+ }
+ if typeof[arg] == "" {
+ typeof[arg] = in[i]
+ }
+ }
+
+ case *ast.TypeAssertExpr:
+ // x.(type) has type of x.
+ if n.Type == nil {
+ typeof[n] = typeof[n.X]
+ return
+ }
+ // x.(T) has type T.
+ if t := typeof[n.Type]; isType(t) {
+ typeof[n] = getType(t)
+ } else {
+ typeof[n] = gofmt(n.Type)
+ }
+
+ case *ast.SliceExpr:
+ // x[i:j] has type of x.
+ typeof[n] = typeof[n.X]
+
+ case *ast.IndexExpr:
+ // x[i] has key type of x's type.
+ t := expand(typeof[n.X])
+ if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
+ // Lazy: assume there are no nested [] in the array
+ // length or map key type.
+ if i := strings.Index(t, "]"); i >= 0 {
+ typeof[n] = t[i+1:]
+ }
+ }
+
+ case *ast.StarExpr:
+ // *x for x of type *T has type T when x is an expr.
+ // We don't use the result when *x is a type, but
+ // compute it anyway.
+ t := expand(typeof[n.X])
+ if isType(t) {
+ typeof[n] = "type *" + getType(t)
+ } else if strings.HasPrefix(t, "*") {
+ typeof[n] = t[len("*"):]
+ }
+
+ case *ast.UnaryExpr:
+ // &x for x of type T has type *T.
+ t := typeof[n.X]
+ if t != "" && n.Op == token.AND {
+ typeof[n] = "*" + t
+ }
+
+ case *ast.CompositeLit:
+ // T{...} has type T.
+ typeof[n] = gofmt(n.Type)
+
+ case *ast.ParenExpr:
+ // (x) has type of x.
+ typeof[n] = typeof[n.X]
+
+ case *ast.RangeStmt:
+ t := expand(typeof[n.X])
+ if t == "" {
+ return
+ }
+ var key, value string
+ if t == "string" {
+ key, value = "int", "rune"
+ } else if strings.HasPrefix(t, "[") {
+ key = "int"
+ if i := strings.Index(t, "]"); i >= 0 {
+ value = t[i+1:]
+ }
+ } else if strings.HasPrefix(t, "map[") {
+ if i := strings.Index(t, "]"); i >= 0 {
+ key, value = t[4:i], t[i+1:]
+ }
+ }
+ changed := false
+ if n.Key != nil && key != "" {
+ changed = true
+ set(n.Key, key, n.Tok == token.DEFINE)
+ }
+ if n.Value != nil && value != "" {
+ changed = true
+ set(n.Value, value, n.Tok == token.DEFINE)
+ }
+ // Ugly failure of vision: already type-checked body.
+ // Do it again now that we have that type info.
+ if changed {
+ typecheck1(cfg, n.Body, typeof, assign)
+ }
+
+ case *ast.TypeSwitchStmt:
+ // Type of variable changes for each case in type switch,
+ // but go/parser generates just one variable.
+ // Repeat type check for each case with more precise
+ // type information.
+ as, ok := n.Assign.(*ast.AssignStmt)
+ if !ok {
+ return
+ }
+ varx, ok := as.Lhs[0].(*ast.Ident)
+ if !ok {
+ return
+ }
+ t := typeof[varx]
+ for _, cas := range n.Body.List {
+ cas := cas.(*ast.CaseClause)
+ if len(cas.List) == 1 {
+ // Variable has specific type only when there is
+ // exactly one type in the case list.
+ if tt := typeof[cas.List[0]]; isType(tt) {
+ tt = getType(tt)
+ typeof[varx] = tt
+ typeof[varx.Obj] = tt
+ typecheck1(cfg, cas.Body, typeof, assign)
+ }
+ }
+ }
+ // Restore t.
+ typeof[varx] = t
+ typeof[varx.Obj] = t
+
+ case *ast.ReturnStmt:
+ if len(curfn) == 0 {
+ // Probably can't happen.
+ return
+ }
+ f := curfn[len(curfn)-1]
+ res := n.Results
+ if f.Results != nil {
+ t := split(typeof[f.Results])
+ for i := 0; i < len(res) && i < len(t); i++ {
+ set(res[i], t[i], false)
+ }
+ }
+ }
+ }
+ walkBeforeAfter(f, before, after)
+}
+
+// Convert between function type strings and lists of types.
+// Using strings makes this a little harder, but it makes
+// a lot of the rest of the code easier. This will all go away
+// when we can use go/typechecker directly.
+
+// splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"].
+func splitFunc(s string) (in, out []string) {
+ if !strings.HasPrefix(s, "func(") {
+ return nil, nil
+ }
+
+ i := len("func(") // index of beginning of 'in' arguments
+ nparen := 0
+ for j := i; j < len(s); j++ {
+ switch s[j] {
+ case '(':
+ nparen++
+ case ')':
+ nparen--
+ if nparen < 0 {
+ // found end of parameter list
+ out := strings.TrimSpace(s[j+1:])
+ if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
+ out = out[1 : len(out)-1]
+ }
+ return split(s[i:j]), split(out)
+ }
+ }
+ }
+ return nil, nil
+}
+
+// joinFunc is the inverse of splitFunc.
+func joinFunc(in, out []string) string {
+ outs := ""
+ if len(out) == 1 {
+ outs = " " + out[0]
+ } else if len(out) > 1 {
+ outs = " (" + join(out) + ")"
+ }
+ return "func(" + join(in) + ")" + outs
+}
+
+// split splits "int, float" into ["int", "float"] and splits "" into [].
+func split(s string) []string {
+ out := []string{}
+ i := 0 // current type being scanned is s[i:j].
+ nparen := 0
+ for j := 0; j < len(s); j++ {
+ switch s[j] {
+ case ' ':
+ if i == j {
+ i++
+ }
+ case '(':
+ nparen++
+ case ')':
+ nparen--
+ if nparen < 0 {
+ // probably can't happen
+ return nil
+ }
+ case ',':
+ if nparen == 0 {
+ if i < j {
+ out = append(out, s[i:j])
+ }
+ i = j + 1
+ }
+ }
+ }
+ if nparen != 0 {
+ // probably can't happen
+ return nil
+ }
+ if i < len(s) {
+ out = append(out, s[i:])
+ }
+ return out
+}
+
+// join is the inverse of split.
+func join(x []string) string {
+ return strings.Join(x, ", ")
+}
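
typecheck.go keeps all type information as gofmt-style strings, so splitFunc, joinFunc, split, and join above do the work a real type checker would do with structured data. A hypothetical test-style sketch of the round trip follows; it is illustrative only and not part of the patch.

package main

import "testing"

// TestSplitJoinFunc is a hypothetical check of the string helpers above:
// splitFunc breaks "func(...) (...)" into parameter and result lists, and
// joinFunc puts them back together.
func TestSplitJoinFunc(t *testing.T) {
	const sig = "func(int, string) (bool, error)"
	in, out := splitFunc(sig)
	if len(in) != 2 || in[0] != "int" || in[1] != "string" {
		t.Fatalf("splitFunc in = %v", in)
	}
	if len(out) != 2 || out[0] != "bool" || out[1] != "error" {
		t.Fatalf("splitFunc out = %v", out)
	}
	if got := joinFunc(in, out); got != sig {
		t.Errorf("joinFunc = %q, want %q", got, sig)
	}
}
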
diff --git a/src/cmd/fix/url.go b/src/cmd/fix/url.go
new file mode 100644
index 000000000..49aac739b
--- /dev/null
+++ b/src/cmd/fix/url.go
@@ -0,0 +1,101 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "go/ast"
+
+func init() {
+ register(urlFix)
+}
+
+var urlFix = fix{
+ "url",
+ "2011-08-17",
+ url,
+ `Move the URL pieces of package http into a new package, url.
+
+http://codereview.appspot.com/4893043
+`,
+}
+
+var urlRenames = []struct{ in, out string }{
+ {"URL", "URL"},
+ {"ParseURL", "Parse"},
+ {"ParseURLReference", "ParseWithReference"},
+ {"ParseQuery", "ParseQuery"},
+ {"Values", "Values"},
+ {"URLEscape", "QueryEscape"},
+ {"URLUnescape", "QueryUnescape"},
+ {"URLError", "Error"},
+ {"URLEscapeError", "EscapeError"},
+}
+
+func url(f *ast.File) bool {
+ if imports(f, "url") || !imports(f, "http") {
+ return false
+ }
+
+ fixed := false
+
+ // Update URL code.
+ urlWalk := func(n interface{}) {
+ // Is it an identifier?
+ if ident, ok := n.(*ast.Ident); ok && ident.Name == "url" {
+ ident.Name = "url_"
+ return
+ }
+ // Parameter and result names.
+ if fn, ok := n.(*ast.FuncType); ok {
+ fixed = urlDoFields(fn.Params) || fixed
+ fixed = urlDoFields(fn.Results) || fixed
+ }
+ }
+
+ // Fix up URL code and add import, at most once.
+ fix := func() {
+ if fixed {
+ return
+ }
+ addImport(f, "url")
+ walkBeforeAfter(f, urlWalk, nop)
+ fixed = true
+ }
+
+ walk(f, func(n interface{}) {
+ // Rename functions and methods.
+ if expr, ok := n.(ast.Expr); ok {
+ for _, s := range urlRenames {
+ if isPkgDot(expr, "http", s.in) {
+ fix()
+ expr.(*ast.SelectorExpr).X.(*ast.Ident).Name = "url"
+ expr.(*ast.SelectorExpr).Sel.Name = s.out
+ return
+ }
+ }
+ }
+ })
+
+ // Remove the http import if no longer needed.
+ if fixed && !usesImport(f, "http") {
+ deleteImport(f, "http")
+ }
+
+ return fixed
+}
+
+func urlDoFields(list *ast.FieldList) (fixed bool) {
+ if list == nil {
+ return
+ }
+ for _, field := range list.List {
+ for _, ident := range field.Names {
+ if ident.Name == "url" {
+ fixed = true
+ ident.Name = "url_"
+ }
+ }
+ }
+ return
+}
diff --git a/src/cmd/fix/url2.go b/src/cmd/fix/url2.go
new file mode 100644
index 000000000..5fd05ad2a
--- /dev/null
+++ b/src/cmd/fix/url2.go
@@ -0,0 +1,46 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "go/ast"
+
+func init() {
+ register(url2Fix)
+}
+
+var url2Fix = fix{
+ "url2",
+ "2012-02-16",
+ url2,
+ `Rename some functions in net/url.
+
+http://codereview.appspot.com/5671061
+`,
+}
+
+func url2(f *ast.File) bool {
+ if !imports(f, "net/url") {
+ return false
+ }
+
+ fixed := false
+
+ walk(f, func(n interface{}) {
+ // Rename functions and methods.
+ sel, ok := n.(*ast.SelectorExpr)
+ if !ok {
+ return
+ }
+ if !isTopName(sel.X, "url") {
+ return
+ }
+ if sel.Sel.Name == "ParseWithReference" {
+ sel.Sel.Name = "ParseWithFragment"
+ fixed = true
+ }
+ })
+
+ return fixed
+}
diff --git a/src/cmd/fix/url2_test.go b/src/cmd/fix/url2_test.go
new file mode 100644
index 000000000..c68dd88f1
--- /dev/null
+++ b/src/cmd/fix/url2_test.go
@@ -0,0 +1,31 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(url2Tests, url2)
+}
+
+var url2Tests = []testCase{
+ {
+ Name: "url2.0",
+ In: `package main
+
+import "net/url"
+
+func f() {
+ url.ParseWithReference("foo")
+}
+`,
+ Out: `package main
+
+import "net/url"
+
+func f() {
+ url.ParseWithFragment("foo")
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/url_test.go b/src/cmd/fix/url_test.go
new file mode 100644
index 000000000..39827f780
--- /dev/null
+++ b/src/cmd/fix/url_test.go
@@ -0,0 +1,159 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(urlTests, url)
+}
+
+var urlTests = []testCase{
+ {
+ Name: "url.0",
+ In: `package main
+
+import (
+ "http"
+)
+
+func f() {
+ var _ http.URL
+ http.ParseURL(a)
+ http.ParseURLReference(a)
+ http.ParseQuery(a)
+ m := http.Values{a: b}
+ http.URLEscape(a)
+ http.URLUnescape(a)
+ var x http.URLError
+ var y http.URLEscapeError
+}
+`,
+ Out: `package main
+
+import "url"
+
+func f() {
+ var _ url.URL
+ url.Parse(a)
+ url.ParseWithReference(a)
+ url.ParseQuery(a)
+ m := url.Values{a: b}
+ url.QueryEscape(a)
+ url.QueryUnescape(a)
+ var x url.Error
+ var y url.EscapeError
+}
+`,
+ },
+ {
+ Name: "url.1",
+ In: `package main
+
+import (
+ "http"
+)
+
+func f() {
+ http.ParseURL(a)
+ var x http.Request
+}
+`,
+ Out: `package main
+
+import (
+ "http"
+ "url"
+)
+
+func f() {
+ url.Parse(a)
+ var x http.Request
+}
+`,
+ },
+ {
+ Name: "url.2",
+ In: `package main
+
+import (
+ "http"
+)
+
+type U struct{ url int }
+type M map[int]int
+
+func f() {
+ http.ParseURL(a)
+ var url = 23
+ url, x := 45, y
+ _ = U{url: url}
+ _ = M{url + 1: url}
+}
+
+func g(url string) string {
+ return url
+}
+
+func h() (url string) {
+ return url
+}
+`,
+ Out: `package main
+
+import "url"
+
+type U struct{ url_ int }
+type M map[int]int
+
+func f() {
+ url.Parse(a)
+ var url_ = 23
+ url_, x := 45, y
+ _ = U{url_: url_}
+ _ = M{url_ + 1: url_}
+}
+
+func g(url_ string) string {
+ return url_
+}
+
+func h() (url_ string) {
+ return url_
+}
+`,
+ },
+ {
+ Name: "url.3",
+ In: `package main
+
+import "http"
+
+type U struct{ url string }
+
+func f() {
+ var u U
+ u.url = "x"
+}
+
+func (url *T) m() string {
+ return url
+}
+`,
+ Out: `package main
+
+import "http"
+
+type U struct{ url string }
+
+func f() {
+ var u U
+ u.url = "x"
+}
+
+func (url *T) m() string {
+ return url
+}
+`,
+ },
+}
diff --git a/src/cmd/fix/xmlapi.go b/src/cmd/fix/xmlapi.go
new file mode 100644
index 000000000..e74425914
--- /dev/null
+++ b/src/cmd/fix/xmlapi.go
@@ -0,0 +1,111 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "go/ast"
+)
+
+func init() {
+ register(xmlapiFix)
+}
+
+var xmlapiFix = fix{
+ "xmlapi",
+ "2012-01-23",
+ xmlapi,
+ `
+ Make encoding/xml's API look more like the rest of the encoding packages.
+
+http://codereview.appspot.com/5574053
+`,
+}
+
+var xmlapiTypeConfig = &TypeConfig{
+ Func: map[string]string{
+ "xml.NewParser": "*xml.Parser",
+ "os.Open": "*os.File",
+ "os.OpenFile": "*os.File",
+ "bytes.NewBuffer": "*bytes.Buffer",
+ "bytes.NewBufferString": "*bytes.Buffer",
+ "bufio.NewReader": "*bufio.Reader",
+ "bufio.NewReadWriter": "*bufio.ReadWriter",
+ },
+}
+
+var isReader = map[string]bool{
+ "*os.File": true,
+ "*bytes.Buffer": true,
+ "*bufio.Reader": true,
+ "*bufio.ReadWriter": true,
+ "io.Reader": true,
+}
+
+func xmlapi(f *ast.File) bool {
+ if !imports(f, "encoding/xml") {
+ return false
+ }
+
+ typeof, _ := typecheck(xmlapiTypeConfig, f)
+
+ fixed := false
+ walk(f, func(n interface{}) {
+ s, ok := n.(*ast.SelectorExpr)
+ if ok && typeof[s.X] == "*xml.Parser" && s.Sel.Name == "Unmarshal" {
+ s.Sel.Name = "DecodeElement"
+ fixed = true
+ return
+ }
+ if ok && isPkgDot(s, "xml", "Parser") {
+ s.Sel.Name = "Decoder"
+ fixed = true
+ return
+ }
+
+ call, ok := n.(*ast.CallExpr)
+ if !ok {
+ return
+ }
+ switch {
+ case len(call.Args) == 2 && isPkgDot(call.Fun, "xml", "Marshal"):
+ *call = xmlMarshal(call.Args)
+ fixed = true
+ case len(call.Args) == 2 && isPkgDot(call.Fun, "xml", "Unmarshal"):
+ if isReader[typeof[call.Args[0]]] {
+ *call = xmlUnmarshal(call.Args)
+ fixed = true
+ }
+ case len(call.Args) == 1 && isPkgDot(call.Fun, "xml", "NewParser"):
+ sel := call.Fun.(*ast.SelectorExpr).Sel
+ sel.Name = "NewDecoder"
+ fixed = true
+ }
+ })
+ return fixed
+}
+
+func xmlMarshal(args []ast.Expr) ast.CallExpr {
+ return xmlCallChain("NewEncoder", "Encode", args)
+}
+
+func xmlUnmarshal(args []ast.Expr) ast.CallExpr {
+ return xmlCallChain("NewDecoder", "Decode", args)
+}
+
+func xmlCallChain(first, second string, args []ast.Expr) ast.CallExpr {
+ return ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: ast.NewIdent("xml"),
+ Sel: ast.NewIdent(first),
+ },
+ Args: args[:1],
+ },
+ Sel: ast.NewIdent(second),
+ },
+ Args: args[1:2],
+ }
+}
diff --git a/src/cmd/fix/xmlapi_test.go b/src/cmd/fix/xmlapi_test.go
new file mode 100644
index 000000000..6486c8124
--- /dev/null
+++ b/src/cmd/fix/xmlapi_test.go
@@ -0,0 +1,85 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+ addTestCases(xmlapiTests, xmlapi)
+}
+
+var xmlapiTests = []testCase{
+ {
+ Name: "xmlapi.0",
+ In: `package main
+
+import "encoding/xml"
+
+func f() {
+ xml.Marshal(a, b)
+ xml.Unmarshal(a, b)
+
+ var buf1 bytes.Buffer
+ buf2 := &bytes.Buffer{}
+ buf3 := bytes.NewBuffer(data)
+ buf4 := bytes.NewBufferString(data)
+ buf5 := bufio.NewReader(r)
+ xml.Unmarshal(&buf1, v)
+ xml.Unmarshal(buf2, v)
+ xml.Unmarshal(buf3, v)
+ xml.Unmarshal(buf4, v)
+ xml.Unmarshal(buf5, v)
+
+ f := os.Open("foo.xml")
+ xml.Unmarshal(f, v)
+
+ p1 := xml.NewParser(stream)
+ p1.Unmarshal(v, start)
+
+ var p2 *xml.Parser
+ p2.Unmarshal(v, start)
+}
+
+func g(r io.Reader, f *os.File, b []byte) {
+ xml.Unmarshal(r, v)
+ xml.Unmarshal(f, v)
+ xml.Unmarshal(b, v)
+}
+`,
+ Out: `package main
+
+import "encoding/xml"
+
+func f() {
+ xml.NewEncoder(a).Encode(b)
+ xml.Unmarshal(a, b)
+
+ var buf1 bytes.Buffer
+ buf2 := &bytes.Buffer{}
+ buf3 := bytes.NewBuffer(data)
+ buf4 := bytes.NewBufferString(data)
+ buf5 := bufio.NewReader(r)
+ xml.NewDecoder(&buf1).Decode(v)
+ xml.NewDecoder(buf2).Decode(v)
+ xml.NewDecoder(buf3).Decode(v)
+ xml.NewDecoder(buf4).Decode(v)
+ xml.NewDecoder(buf5).Decode(v)
+
+ f := os.Open("foo.xml")
+ xml.NewDecoder(f).Decode(v)
+
+ p1 := xml.NewDecoder(stream)
+ p1.DecodeElement(v, start)
+
+ var p2 *xml.Decoder
+ p2.DecodeElement(v, start)
+}
+
+func g(r io.Reader, f *os.File, b []byte) {
+ xml.NewDecoder(r).Decode(v)
+ xml.NewDecoder(f).Decode(v)
+ xml.Unmarshal(b, v)
+}
+`,
+ },
+}