diff options
Diffstat (limited to 'src/cmd/gofmt/rewrite.go')
-rw-r--r-- | src/cmd/gofmt/rewrite.go | 118 |
1 files changed, 71 insertions, 47 deletions
diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go index fbcd46aa2..93643dced 100644 --- a/src/cmd/gofmt/rewrite.go +++ b/src/cmd/gofmt/rewrite.go @@ -46,6 +46,16 @@ func parseExpr(s string, what string) ast.Expr { } +// Keep this function for debugging. +/* +func dump(msg string, val reflect.Value) { + fmt.Printf("%s:\n", msg) + ast.Print(fset, val.Interface()) + fmt.Println() +} +*/ + + // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file. func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { m := make(map[string]reflect.Value) @@ -54,7 +64,7 @@ func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { var f func(val reflect.Value) reflect.Value // f is recursive f = func(val reflect.Value) reflect.Value { for k := range m { - m[k] = nil, false + m[k] = reflect.Value{}, false } val = apply(f, val) if match(m, pat, val) { @@ -78,28 +88,45 @@ func setValue(x, y reflect.Value) { panic(x) } }() - x.SetValue(y) + x.Set(y) } +// Values/types for special cases. +var ( + objectPtrNil = reflect.NewValue((*ast.Object)(nil)) + + identType = reflect.Typeof((*ast.Ident)(nil)) + objectPtrType = reflect.Typeof((*ast.Object)(nil)) + positionType = reflect.Typeof(token.NoPos) +) + + // apply replaces each AST field x in val with f(x), returning val. // To avoid extra conversions, f operates on the reflect.Value form. func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value { - if val == nil { - return nil + if !val.IsValid() { + return reflect.Value{} + } + + // *ast.Objects introduce cycles and are likely incorrect after + // rewrite; don't follow them but replace with nil instead + if val.Type() == objectPtrType { + return objectPtrNil } - switch v := reflect.Indirect(val).(type) { - case *reflect.SliceValue: + + switch v := reflect.Indirect(val); v.Kind() { + case reflect.Slice: for i := 0; i < v.Len(); i++ { - e := v.Elem(i) + e := v.Index(i) setValue(e, f(e)) } - case *reflect.StructValue: + case reflect.Struct: for i := 0; i < v.NumField(); i++ { e := v.Field(i) setValue(e, f(e)) } - case *reflect.InterfaceValue: + case reflect.Interface: e := v.Elem() setValue(v, f(e)) } @@ -107,10 +134,6 @@ func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value } -var positionType = reflect.Typeof(token.NoPos) -var identType = reflect.Typeof((*ast.Ident)(nil)) - - func isWildcard(s string) bool { rune, size := utf8.DecodeRuneInString(s) return size == len(s) && unicode.IsLower(rune) @@ -124,9 +147,9 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { // Wildcard matches any expression. If it appears multiple // times in the pattern, it must match the same expression // each time. - if m != nil && pattern != nil && pattern.Type() == identType { + if m != nil && pattern.IsValid() && pattern.Type() == identType { name := pattern.Interface().(*ast.Ident).Name - if isWildcard(name) && val != nil { + if isWildcard(name) && val.IsValid() { // wildcards only match expressions if _, ok := val.Interface().(ast.Expr); ok { if old, ok := m[name]; ok { @@ -139,8 +162,8 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { } // Otherwise, pattern and val must match recursively. - if pattern == nil || val == nil { - return pattern == nil && val == nil + if !pattern.IsValid() || !val.IsValid() { + return !pattern.IsValid() && !val.IsValid() } if pattern.Type() != val.Type() { return false @@ -148,9 +171,6 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { // Special cases. switch pattern.Type() { - case positionType: - // token positions don't need to match - return true case identType: // For identifiers, only the names need to match // (and none of the other *ast.Object information). @@ -159,29 +179,30 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { p := pattern.Interface().(*ast.Ident) v := val.Interface().(*ast.Ident) return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name + case objectPtrType, positionType: + // object pointers and token positions don't need to match + return true } p := reflect.Indirect(pattern) v := reflect.Indirect(val) - if p == nil || v == nil { - return p == nil && v == nil + if !p.IsValid() || !v.IsValid() { + return !p.IsValid() && !v.IsValid() } - switch p := p.(type) { - case *reflect.SliceValue: - v := v.(*reflect.SliceValue) + switch p.Kind() { + case reflect.Slice: if p.Len() != v.Len() { return false } for i := 0; i < p.Len(); i++ { - if !match(m, p.Elem(i), v.Elem(i)) { + if !match(m, p.Index(i), v.Index(i)) { return false } } return true - case *reflect.StructValue: - v := v.(*reflect.StructValue) + case reflect.Struct: if p.NumField() != v.NumField() { return false } @@ -192,8 +213,7 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { } return true - case *reflect.InterfaceValue: - v := v.(*reflect.InterfaceValue) + case reflect.Interface: return match(m, p.Elem(), v.Elem()) } @@ -207,8 +227,8 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { // if m == nil, subst returns a copy of pattern and doesn't change the line // number information. func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value { - if pattern == nil { - return nil + if !pattern.IsValid() { + return reflect.Value{} } // Wildcard gets replaced with map value. @@ -216,12 +236,12 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) name := pattern.Interface().(*ast.Ident).Name if isWildcard(name) { if old, ok := m[name]; ok { - return subst(nil, old, nil) + return subst(nil, old, reflect.Value{}) } } } - if pos != nil && pattern.Type() == positionType { + if pos.IsValid() && pattern.Type() == positionType { // use new position only if old position was valid in the first place if old := pattern.Interface().(token.Pos); !old.IsValid() { return pattern @@ -230,29 +250,33 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) } // Otherwise copy. - switch p := pattern.(type) { - case *reflect.SliceValue: - v := reflect.MakeSlice(p.Type().(*reflect.SliceType), p.Len(), p.Len()) + switch p := pattern; p.Kind() { + case reflect.Slice: + v := reflect.MakeSlice(p.Type(), p.Len(), p.Len()) for i := 0; i < p.Len(); i++ { - v.Elem(i).SetValue(subst(m, p.Elem(i), pos)) + v.Index(i).Set(subst(m, p.Index(i), pos)) } return v - case *reflect.StructValue: - v := reflect.MakeZero(p.Type()).(*reflect.StructValue) + case reflect.Struct: + v := reflect.Zero(p.Type()) for i := 0; i < p.NumField(); i++ { - v.Field(i).SetValue(subst(m, p.Field(i), pos)) + v.Field(i).Set(subst(m, p.Field(i), pos)) } return v - case *reflect.PtrValue: - v := reflect.MakeZero(p.Type()).(*reflect.PtrValue) - v.PointTo(subst(m, p.Elem(), pos)) + case reflect.Ptr: + v := reflect.Zero(p.Type()) + if elem := p.Elem(); elem.IsValid() { + v.Set(subst(m, elem, pos).Addr()) + } return v - case *reflect.InterfaceValue: - v := reflect.MakeZero(p.Type()).(*reflect.InterfaceValue) - v.SetValue(subst(m, p.Elem(), pos)) + case reflect.Interface: + v := reflect.Zero(p.Type()) + if elem := p.Elem(); elem.IsValid() { + v.Set(subst(m, elem, pos)) + } return v } |