diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/pkg/reflect/type.go | 16 | ||||
-rw-r--r-- | src/pkg/reflect/value.go | 10 | ||||
-rw-r--r-- | src/pkg/xml/embed_test.go | 124 | ||||
-rw-r--r-- | src/pkg/xml/read.go | 29 |
4 files changed, 160 insertions, 19 deletions
diff --git a/src/pkg/reflect/type.go b/src/pkg/reflect/type.go index a8df033af..eb1ba52a9 100644 --- a/src/pkg/reflect/type.go +++ b/src/pkg/reflect/type.go @@ -507,7 +507,7 @@ func (t *StructType) FieldByIndex(index []int) (f StructField) { const inf = 1 << 30 // infinity - no struct has that many nesting levels -func (t *StructType) fieldByName(name string, mark map[*StructType]bool, depth int) (ff StructField, fd int) { +func (t *StructType) fieldByNameFunc(match func(string) bool, mark map[*StructType]bool, depth int) (ff StructField, fd int) { fd = inf // field depth if mark[t] { @@ -522,7 +522,7 @@ L: for i, _ := range t.fields { f := t.Field(i) d := inf switch { - case f.Name == name: + case match(f.Name): // Matching top-level field. d = depth case f.Anonymous: @@ -531,13 +531,13 @@ L: for i, _ := range t.fields { ft = pt.Elem() } switch { - case ft.Name() == name: + case match(ft.Name()): // Matching anonymous top-level field. d = depth case fd > depth: // No top-level field yet; look inside nested structs. if st, ok := ft.(*StructType); ok { - f, d = st.fieldByName(name, mark, depth+1) + f, d = st.fieldByNameFunc(match, mark, depth+1) } } } @@ -576,7 +576,13 @@ L: for i, _ := range t.fields { // FieldByName returns the struct field with the given name // and a boolean to indicate if the field was found. func (t *StructType) FieldByName(name string) (f StructField, present bool) { - if ff, fd := t.fieldByName(name, make(map[*StructType]bool), 0); fd < inf { + return t.FieldByNameFunc(func(s string) bool { return s == name }) +} + +// FieldByNameFunc returns the struct field with a name that satisfies the +// match function and a boolean to indicate if the field was found. +func (t *StructType) FieldByNameFunc(match func(string) bool) (f StructField, present bool) { + if ff, fd := t.fieldByNameFunc(match, make(map[*StructType]bool), 0); fd < inf { ff.Index = ff.Index[0 : fd+1] f, present = ff, true } diff --git a/src/pkg/reflect/value.go b/src/pkg/reflect/value.go index f21c564d5..d8ddb289a 100644 --- a/src/pkg/reflect/value.go +++ b/src/pkg/reflect/value.go @@ -1251,6 +1251,16 @@ func (t *StructValue) FieldByName(name string) Value { return nil } +// FieldByNameFunc returns the struct field with a name that satisfies the +// match function. +// The result is nil if no field was found. +func (t *StructValue) FieldByNameFunc(match func(string) bool) Value { + if f, ok := t.Type().(*StructType).FieldByNameFunc(match); ok { + return t.FieldByIndex(f.Index) + } + return nil +} + // NumField returns the number of fields in the struct. func (v *StructValue) NumField() int { return v.typ.(*StructType).NumField() } diff --git a/src/pkg/xml/embed_test.go b/src/pkg/xml/embed_test.go new file mode 100644 index 000000000..abfe781ac --- /dev/null +++ b/src/pkg/xml/embed_test.go @@ -0,0 +1,124 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import "testing" + +type C struct { + Name string + Open bool +} + +type A struct { + XMLName Name "http://domain a" + C + B B + FieldA string +} + +type B struct { + XMLName Name "b" + C + FieldB string +} + +const _1a = ` +<?xml version="1.0" encoding="UTF-8"?> +<a xmlns="http://domain"> + <name>KmlFile</name> + <open>1</open> + <b> + <name>Absolute</name> + <open>0</open> + <fieldb>bar</fieldb> + </b> + <fielda>foo</fielda> +</a> +` + +// Tests that embedded structs are marshalled. +func TestEmbedded1(t *testing.T) { + var a A + if e := Unmarshal(StringReader(_1a), &a); e != nil { + t.Fatalf("Unmarshal: %s", e) + } + if a.FieldA != "foo" { + t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.FieldA) + } + if a.Name != "KmlFile" { + t.Fatalf("Unmarshal: expected 'KmlFile' but found '%s'", a.Name) + } + if !a.Open { + t.Fatal("Unmarshal: expected 'true' but found otherwise") + } + if a.B.FieldB != "bar" { + t.Fatalf("Unmarshal: expected 'bar' but found '%s'", a.B.FieldB) + } + if a.B.Name != "Absolute" { + t.Fatalf("Unmarshal: expected 'Absolute' but found '%s'", a.B.Name) + } + if a.B.Open { + t.Fatal("Unmarshal: expected 'false' but found otherwise") + } +} + +type A2 struct { + XMLName Name "http://domain a" + XY string + Xy string +} + +const _2a = ` +<?xml version="1.0" encoding="UTF-8"?> +<a xmlns="http://domain"> + <xy>foo</xy> +</a> +` + +// Tests that conflicting field names get excluded. +func TestEmbedded2(t *testing.T) { + var a A2 + if e := Unmarshal(StringReader(_2a), &a); e != nil { + t.Fatalf("Unmarshal: %s", e) + } + if a.XY != "" { + t.Fatalf("Unmarshal: expected empty string but found '%s'", a.XY) + } + if a.Xy != "" { + t.Fatalf("Unmarshal: expected empty string but found '%s'", a.Xy) + } +} + +type A3 struct { + XMLName Name "http://domain a" + xy string +} + +// Tests that private fields are not set. +func TestEmbedded3(t *testing.T) { + var a A3 + if e := Unmarshal(StringReader(_2a), &a); e != nil { + t.Fatalf("Unmarshal: %s", e) + } + if a.xy != "" { + t.Fatalf("Unmarshal: expected empty string but found '%s'", a.xy) + } +} + +type A4 struct { + XMLName Name "http://domain a" + Any string +} + +// Tests that private fields are not set. +func TestEmbedded4(t *testing.T) { + var a A4 + if e := Unmarshal(StringReader(_2a), &a); e != nil { + t.Fatalf("Unmarshal: %s", e) + } + if a.Any != "foo" { + t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.Any) + } +} diff --git a/src/pkg/xml/read.go b/src/pkg/xml/read.go index 9eb0be253..45db7daa3 100644 --- a/src/pkg/xml/read.go +++ b/src/pkg/xml/read.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" "unicode" + "utf8" ) // BUG(rsc): Mapping between XML elements and data structures is inherently flawed: @@ -331,24 +332,24 @@ Loop: case StartElement: // Sub-element. // Look up by tag name. - // If that fails, fall back to mop-up field named "Any". if sv != nil { k := fieldName(t.Name.Local) - any := -1 - for i, n := 0, styp.NumField(); i < n; i++ { - f := styp.Field(i) - if strings.ToLower(f.Name) == k { - if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil { - return err - } - continue Loop - } - if any < 0 && f.Name == "Any" { - any = i + match := func(s string) bool { + // check if the name matches ignoring case + if strings.ToLower(s) != strings.ToLower(k) { + return false } + // now check that it's public + c, _ := utf8.DecodeRuneInString(s) + return unicode.IsUpper(c) + } + + f, found := styp.FieldByNameFunc(match) + if !found { // fall back to mop-up field named "Any" + f, found = styp.FieldByName("Any") } - if any >= 0 { - if err := p.unmarshal(sv.FieldByIndex(styp.Field(any).Index), &t); err != nil { + if found { + if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil { return err } continue Loop |