diff options
Diffstat (limited to 'src/cmd/gofmt/rewrite.go')
-rw-r--r-- | src/cmd/gofmt/rewrite.go | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go index 93643dced..4c24282f3 100644 --- a/src/cmd/gofmt/rewrite.go +++ b/src/cmd/gofmt/rewrite.go @@ -19,6 +19,7 @@ import ( func initRewrite() { if *rewriteRule == "" { + rewrite = nil // disable any previous rewrite return } f := strings.Split(*rewriteRule, "->", -1) @@ -59,26 +60,34 @@ func dump(msg string, val reflect.Value) { // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file. func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { m := make(map[string]reflect.Value) - pat := reflect.NewValue(pattern) - repl := reflect.NewValue(replace) + pat := reflect.ValueOf(pattern) + repl := reflect.ValueOf(replace) var f func(val reflect.Value) reflect.Value // f is recursive f = func(val reflect.Value) reflect.Value { + // don't bother if val is invalid to start with + if !val.IsValid() { + return reflect.Value{} + } for k := range m { m[k] = reflect.Value{}, false } val = apply(f, val) if match(m, pat, val) { - val = subst(m, repl, reflect.NewValue(val.Interface().(ast.Node).Pos())) + val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos())) } return val } - return apply(f, reflect.NewValue(p)).Interface().(*ast.File) + return apply(f, reflect.ValueOf(p)).Interface().(*ast.File) } // setValue is a wrapper for x.SetValue(y); it protects // the caller from panics if x cannot be changed to y. func setValue(x, y reflect.Value) { + // don't bother if y is invalid to start with + if !y.IsValid() { + return + } defer func() { if x := recover(); x != nil { if s, ok := x.(string); ok && strings.HasPrefix(s, "type mismatch") { @@ -94,11 +103,13 @@ func setValue(x, y reflect.Value) { // Values/types for special cases. var ( - objectPtrNil = reflect.NewValue((*ast.Object)(nil)) + objectPtrNil = reflect.ValueOf((*ast.Object)(nil)) + scopePtrNil = reflect.ValueOf((*ast.Scope)(nil)) - identType = reflect.Typeof((*ast.Ident)(nil)) - objectPtrType = reflect.Typeof((*ast.Object)(nil)) - positionType = reflect.Typeof(token.NoPos) + identType = reflect.TypeOf((*ast.Ident)(nil)) + objectPtrType = reflect.TypeOf((*ast.Object)(nil)) + positionType = reflect.TypeOf(token.NoPos) + scopePtrType = reflect.TypeOf((*ast.Scope)(nil)) ) @@ -115,6 +126,12 @@ func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value return objectPtrNil } + // similarly for scopes: they are likely incorrect after a rewrite; + // replace them with nil + if val.Type() == scopePtrType { + return scopePtrNil + } + switch v := reflect.Indirect(val); v.Kind() { case reflect.Slice: for i := 0; i < v.Len(); i++ { @@ -259,21 +276,21 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) return v case reflect.Struct: - v := reflect.Zero(p.Type()) + v := reflect.New(p.Type()).Elem() for i := 0; i < p.NumField(); i++ { v.Field(i).Set(subst(m, p.Field(i), pos)) } return v case reflect.Ptr: - v := reflect.Zero(p.Type()) + v := reflect.New(p.Type()).Elem() if elem := p.Elem(); elem.IsValid() { v.Set(subst(m, elem, pos).Addr()) } return v case reflect.Interface: - v := reflect.Zero(p.Type()) + v := reflect.New(p.Type()).Elem() if elem := p.Elem(); elem.IsValid() { v.Set(subst(m, elem, pos)) } |