summaryrefslogtreecommitdiff
path: root/src/cmd/gofmt/rewrite.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/cmd/gofmt/rewrite.go')
-rw-r--r--src/cmd/gofmt/rewrite.go39
1 files changed, 28 insertions, 11 deletions
diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go
index 93643dced..4c24282f3 100644
--- a/src/cmd/gofmt/rewrite.go
+++ b/src/cmd/gofmt/rewrite.go
@@ -19,6 +19,7 @@ import (
func initRewrite() {
if *rewriteRule == "" {
+ rewrite = nil // disable any previous rewrite
return
}
f := strings.Split(*rewriteRule, "->", -1)
@@ -59,26 +60,34 @@ func dump(msg string, val reflect.Value) {
// rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
m := make(map[string]reflect.Value)
- pat := reflect.NewValue(pattern)
- repl := reflect.NewValue(replace)
+ pat := reflect.ValueOf(pattern)
+ repl := reflect.ValueOf(replace)
var f func(val reflect.Value) reflect.Value // f is recursive
f = func(val reflect.Value) reflect.Value {
+ // don't bother if val is invalid to start with
+ if !val.IsValid() {
+ return reflect.Value{}
+ }
for k := range m {
m[k] = reflect.Value{}, false
}
val = apply(f, val)
if match(m, pat, val) {
- val = subst(m, repl, reflect.NewValue(val.Interface().(ast.Node).Pos()))
+ val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
}
return val
}
- return apply(f, reflect.NewValue(p)).Interface().(*ast.File)
+ return apply(f, reflect.ValueOf(p)).Interface().(*ast.File)
}
// setValue is a wrapper for x.SetValue(y); it protects
// the caller from panics if x cannot be changed to y.
func setValue(x, y reflect.Value) {
+ // don't bother if y is invalid to start with
+ if !y.IsValid() {
+ return
+ }
defer func() {
if x := recover(); x != nil {
if s, ok := x.(string); ok && strings.HasPrefix(s, "type mismatch") {
@@ -94,11 +103,13 @@ func setValue(x, y reflect.Value) {
// Values/types for special cases.
var (
- objectPtrNil = reflect.NewValue((*ast.Object)(nil))
+ objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
+ scopePtrNil = reflect.ValueOf((*ast.Scope)(nil))
- identType = reflect.Typeof((*ast.Ident)(nil))
- objectPtrType = reflect.Typeof((*ast.Object)(nil))
- positionType = reflect.Typeof(token.NoPos)
+ identType = reflect.TypeOf((*ast.Ident)(nil))
+ objectPtrType = reflect.TypeOf((*ast.Object)(nil))
+ positionType = reflect.TypeOf(token.NoPos)
+ scopePtrType = reflect.TypeOf((*ast.Scope)(nil))
)
@@ -115,6 +126,12 @@ func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value
return objectPtrNil
}
+ // similarly for scopes: they are likely incorrect after a rewrite;
+ // replace them with nil
+ if val.Type() == scopePtrType {
+ return scopePtrNil
+ }
+
switch v := reflect.Indirect(val); v.Kind() {
case reflect.Slice:
for i := 0; i < v.Len(); i++ {
@@ -259,21 +276,21 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value)
return v
case reflect.Struct:
- v := reflect.Zero(p.Type())
+ v := reflect.New(p.Type()).Elem()
for i := 0; i < p.NumField(); i++ {
v.Field(i).Set(subst(m, p.Field(i), pos))
}
return v
case reflect.Ptr:
- v := reflect.Zero(p.Type())
+ v := reflect.New(p.Type()).Elem()
if elem := p.Elem(); elem.IsValid() {
v.Set(subst(m, elem, pos).Addr())
}
return v
case reflect.Interface:
- v := reflect.Zero(p.Type())
+ v := reflect.New(p.Type()).Elem()
if elem := p.Elem(); elem.IsValid() {
v.Set(subst(m, elem, pos))
}