summaryrefslogtreecommitdiff
path: root/src/cmd/gofmt/rewrite.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/cmd/gofmt/rewrite.go')
-rw-r--r--src/cmd/gofmt/rewrite.go118
1 files changed, 71 insertions, 47 deletions
diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go
index fbcd46aa2..93643dced 100644
--- a/src/cmd/gofmt/rewrite.go
+++ b/src/cmd/gofmt/rewrite.go
@@ -46,6 +46,16 @@ func parseExpr(s string, what string) ast.Expr {
}
+// Keep this function for debugging.
+/*
+func dump(msg string, val reflect.Value) {
+ fmt.Printf("%s:\n", msg)
+ ast.Print(fset, val.Interface())
+ fmt.Println()
+}
+*/
+
+
// rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
m := make(map[string]reflect.Value)
@@ -54,7 +64,7 @@ func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
var f func(val reflect.Value) reflect.Value // f is recursive
f = func(val reflect.Value) reflect.Value {
for k := range m {
- m[k] = nil, false
+ m[k] = reflect.Value{}, false
}
val = apply(f, val)
if match(m, pat, val) {
@@ -78,28 +88,45 @@ func setValue(x, y reflect.Value) {
panic(x)
}
}()
- x.SetValue(y)
+ x.Set(y)
}
+// Values/types for special cases.
+var (
+ objectPtrNil = reflect.NewValue((*ast.Object)(nil))
+
+ identType = reflect.Typeof((*ast.Ident)(nil))
+ objectPtrType = reflect.Typeof((*ast.Object)(nil))
+ positionType = reflect.Typeof(token.NoPos)
+)
+
+
// apply replaces each AST field x in val with f(x), returning val.
// To avoid extra conversions, f operates on the reflect.Value form.
func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
- if val == nil {
- return nil
+ if !val.IsValid() {
+ return reflect.Value{}
+ }
+
+ // *ast.Objects introduce cycles and are likely incorrect after
+ // rewrite; don't follow them but replace with nil instead
+ if val.Type() == objectPtrType {
+ return objectPtrNil
}
- switch v := reflect.Indirect(val).(type) {
- case *reflect.SliceValue:
+
+ switch v := reflect.Indirect(val); v.Kind() {
+ case reflect.Slice:
for i := 0; i < v.Len(); i++ {
- e := v.Elem(i)
+ e := v.Index(i)
setValue(e, f(e))
}
- case *reflect.StructValue:
+ case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
e := v.Field(i)
setValue(e, f(e))
}
- case *reflect.InterfaceValue:
+ case reflect.Interface:
e := v.Elem()
setValue(v, f(e))
}
@@ -107,10 +134,6 @@ func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value
}
-var positionType = reflect.Typeof(token.NoPos)
-var identType = reflect.Typeof((*ast.Ident)(nil))
-
-
func isWildcard(s string) bool {
rune, size := utf8.DecodeRuneInString(s)
return size == len(s) && unicode.IsLower(rune)
@@ -124,9 +147,9 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
// Wildcard matches any expression. If it appears multiple
// times in the pattern, it must match the same expression
// each time.
- if m != nil && pattern != nil && pattern.Type() == identType {
+ if m != nil && pattern.IsValid() && pattern.Type() == identType {
name := pattern.Interface().(*ast.Ident).Name
- if isWildcard(name) && val != nil {
+ if isWildcard(name) && val.IsValid() {
// wildcards only match expressions
if _, ok := val.Interface().(ast.Expr); ok {
if old, ok := m[name]; ok {
@@ -139,8 +162,8 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
}
// Otherwise, pattern and val must match recursively.
- if pattern == nil || val == nil {
- return pattern == nil && val == nil
+ if !pattern.IsValid() || !val.IsValid() {
+ return !pattern.IsValid() && !val.IsValid()
}
if pattern.Type() != val.Type() {
return false
@@ -148,9 +171,6 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
// Special cases.
switch pattern.Type() {
- case positionType:
- // token positions don't need to match
- return true
case identType:
// For identifiers, only the names need to match
// (and none of the other *ast.Object information).
@@ -159,29 +179,30 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
p := pattern.Interface().(*ast.Ident)
v := val.Interface().(*ast.Ident)
return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
+ case objectPtrType, positionType:
+ // object pointers and token positions don't need to match
+ return true
}
p := reflect.Indirect(pattern)
v := reflect.Indirect(val)
- if p == nil || v == nil {
- return p == nil && v == nil
+ if !p.IsValid() || !v.IsValid() {
+ return !p.IsValid() && !v.IsValid()
}
- switch p := p.(type) {
- case *reflect.SliceValue:
- v := v.(*reflect.SliceValue)
+ switch p.Kind() {
+ case reflect.Slice:
if p.Len() != v.Len() {
return false
}
for i := 0; i < p.Len(); i++ {
- if !match(m, p.Elem(i), v.Elem(i)) {
+ if !match(m, p.Index(i), v.Index(i)) {
return false
}
}
return true
- case *reflect.StructValue:
- v := v.(*reflect.StructValue)
+ case reflect.Struct:
if p.NumField() != v.NumField() {
return false
}
@@ -192,8 +213,7 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
}
return true
- case *reflect.InterfaceValue:
- v := v.(*reflect.InterfaceValue)
+ case reflect.Interface:
return match(m, p.Elem(), v.Elem())
}
@@ -207,8 +227,8 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
// if m == nil, subst returns a copy of pattern and doesn't change the line
// number information.
func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
- if pattern == nil {
- return nil
+ if !pattern.IsValid() {
+ return reflect.Value{}
}
// Wildcard gets replaced with map value.
@@ -216,12 +236,12 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value)
name := pattern.Interface().(*ast.Ident).Name
if isWildcard(name) {
if old, ok := m[name]; ok {
- return subst(nil, old, nil)
+ return subst(nil, old, reflect.Value{})
}
}
}
- if pos != nil && pattern.Type() == positionType {
+ if pos.IsValid() && pattern.Type() == positionType {
// use new position only if old position was valid in the first place
if old := pattern.Interface().(token.Pos); !old.IsValid() {
return pattern
@@ -230,29 +250,33 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value)
}
// Otherwise copy.
- switch p := pattern.(type) {
- case *reflect.SliceValue:
- v := reflect.MakeSlice(p.Type().(*reflect.SliceType), p.Len(), p.Len())
+ switch p := pattern; p.Kind() {
+ case reflect.Slice:
+ v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
for i := 0; i < p.Len(); i++ {
- v.Elem(i).SetValue(subst(m, p.Elem(i), pos))
+ v.Index(i).Set(subst(m, p.Index(i), pos))
}
return v
- case *reflect.StructValue:
- v := reflect.MakeZero(p.Type()).(*reflect.StructValue)
+ case reflect.Struct:
+ v := reflect.Zero(p.Type())
for i := 0; i < p.NumField(); i++ {
- v.Field(i).SetValue(subst(m, p.Field(i), pos))
+ v.Field(i).Set(subst(m, p.Field(i), pos))
}
return v
- case *reflect.PtrValue:
- v := reflect.MakeZero(p.Type()).(*reflect.PtrValue)
- v.PointTo(subst(m, p.Elem(), pos))
+ case reflect.Ptr:
+ v := reflect.Zero(p.Type())
+ if elem := p.Elem(); elem.IsValid() {
+ v.Set(subst(m, elem, pos).Addr())
+ }
return v
- case *reflect.InterfaceValue:
- v := reflect.MakeZero(p.Type()).(*reflect.InterfaceValue)
- v.SetValue(subst(m, p.Elem(), pos))
+ case reflect.Interface:
+ v := reflect.Zero(p.Type())
+ if elem := p.Elem(); elem.IsValid() {
+ v.Set(subst(m, elem, pos))
+ }
return v
}