diff options
Diffstat (limited to 'src/pkg/encoding')
46 files changed, 2540 insertions, 643 deletions
diff --git a/src/pkg/encoding/asn1/asn1.go b/src/pkg/encoding/asn1/asn1.go index ac2b5f8da..cac9d64b5 100644 --- a/src/pkg/encoding/asn1/asn1.go +++ b/src/pkg/encoding/asn1/asn1.go @@ -77,15 +77,15 @@ func parseInt64(bytes []byte) (ret int64, err error) { // parseInt treats the given bytes as a big-endian, signed integer and returns // the result. -func parseInt(bytes []byte) (int, error) { +func parseInt32(bytes []byte) (int32, error) { ret64, err := parseInt64(bytes) if err != nil { return 0, err } - if ret64 != int64(int(ret64)) { + if ret64 != int64(int32(ret64)) { return 0, StructuralError{"integer too large"} } - return int(ret64), nil + return int32(ret64), nil } var bigOne = big.NewInt(1) @@ -670,7 +670,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam err = err1 return case enumeratedType: - parsedInt, err1 := parseInt(innerBytes) + parsedInt, err1 := parseInt32(innerBytes) if err1 == nil { v.SetInt(int64(parsedInt)) } @@ -692,19 +692,20 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam } err = err1 return - case reflect.Int, reflect.Int32: - parsedInt, err1 := parseInt(innerBytes) - if err1 == nil { - val.SetInt(int64(parsedInt)) - } - err = err1 - return - case reflect.Int64: - parsedInt, err1 := parseInt64(innerBytes) - if err1 == nil { - val.SetInt(parsedInt) + case reflect.Int, reflect.Int32, reflect.Int64: + if val.Type().Size() == 4 { + parsedInt, err1 := parseInt32(innerBytes) + if err1 == nil { + val.SetInt(int64(parsedInt)) + } + err = err1 + } else { + parsedInt, err1 := parseInt64(innerBytes) + if err1 == nil { + val.SetInt(parsedInt) + } + err = err1 } - err = err1 return // TODO(dfc) Add support for the remaining integer types case reflect.Struct: diff --git a/src/pkg/encoding/asn1/asn1_test.go b/src/pkg/encoding/asn1/asn1_test.go index eb848bdb4..6e98dcf0b 100644 --- a/src/pkg/encoding/asn1/asn1_test.go +++ b/src/pkg/encoding/asn1/asn1_test.go @@ -64,7 +64,7 @@ var int32TestData = []int32Test{ func TestParseInt32(t *testing.T) { for i, test := range int32TestData { - ret, err := parseInt(test.in) + ret, err := parseInt32(test.in) if (err == nil) != test.ok { t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok) } @@ -124,7 +124,7 @@ func TestBitString(t *testing.T) { t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok) } if err == nil { - if test.bitLength != ret.BitLength || bytes.Compare(ret.Bytes, test.out) != 0 { + if test.bitLength != ret.BitLength || !bytes.Equal(ret.Bytes, test.out) { t.Errorf("#%d: Bad result: %v (expected %v %v)", i, ret, test.out, test.bitLength) } } @@ -166,7 +166,7 @@ func TestBitStringRightAlign(t *testing.T) { for i, test := range bitStringRightAlignTests { bs := BitString{test.in, test.inlen} out := bs.RightAlign() - if bytes.Compare(out, test.out) != 0 { + if !bytes.Equal(out, test.out) { t.Errorf("#%d got: %x want: %x", i, out, test.out) } } @@ -477,7 +477,7 @@ func TestRawStructs(t *testing.T) { if s.A != 0x50 { t.Errorf("bad value for A: got %d want %d", s.A, 0x50) } - if bytes.Compare([]byte(s.Raw), input) != 0 { + if !bytes.Equal([]byte(s.Raw), input) { t.Errorf("bad value for Raw: got %x want %x", s.Raw, input) } } diff --git a/src/pkg/encoding/asn1/common.go b/src/pkg/encoding/asn1/common.go index 03856bc55..33a117ece 100644 --- a/src/pkg/encoding/asn1/common.go +++ b/src/pkg/encoding/asn1/common.go @@ -98,6 +98,8 @@ func parseFieldParameters(str string) (ret fieldParameters) { ret.stringType = tagIA5String case part == "printable": ret.stringType = tagPrintableString + case part == "utf8": + ret.stringType = tagUTF8String case strings.HasPrefix(part, "default:"): i, err := strconv.ParseInt(part[8:], 10, 64) if err == nil { diff --git a/src/pkg/encoding/asn1/marshal.go b/src/pkg/encoding/asn1/marshal.go index 163bca575..0c216fdb3 100644 --- a/src/pkg/encoding/asn1/marshal.go +++ b/src/pkg/encoding/asn1/marshal.go @@ -6,11 +6,13 @@ package asn1 import ( "bytes" + "errors" "fmt" "io" "math/big" "reflect" "time" + "unicode/utf8" ) // A forkableWriter is an in-memory buffer that can be @@ -280,6 +282,11 @@ func marshalIA5String(out *forkableWriter, s string) (err error) { return } +func marshalUTF8String(out *forkableWriter, s string) (err error) { + _, err = out.Write([]byte(s)) + return +} + func marshalTwoDigits(out *forkableWriter, v int) (err error) { err = out.WriteByte(byte('0' + (v/10)%10)) if err != nil { @@ -289,8 +296,7 @@ func marshalTwoDigits(out *forkableWriter, v int) (err error) { } func marshalUTCTime(out *forkableWriter, t time.Time) (err error) { - utc := t.UTC() - year, month, day := utc.Date() + year, month, day := t.Date() switch { case 1950 <= year && year < 2000: @@ -314,7 +320,7 @@ func marshalUTCTime(out *forkableWriter, t time.Time) (err error) { return } - hour, min, sec := utc.Clock() + hour, min, sec := t.Clock() err = marshalTwoDigits(out, hour) if err != nil { @@ -446,10 +452,13 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter } return case reflect.String: - if params.stringType == tagIA5String { + switch params.stringType { + case tagIA5String: return marshalIA5String(out, v.String()) - } else { + case tagPrintableString: return marshalPrintableString(out, v.String()) + default: + return marshalUTF8String(out, v.String()) } return } @@ -492,11 +501,27 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) } class := classUniversal - if params.stringType != 0 { - if tag != tagPrintableString { - return StructuralError{"Explicit string type given to non-string member"} + if params.stringType != 0 && tag != tagPrintableString { + return StructuralError{"Explicit string type given to non-string member"} + } + + if tag == tagPrintableString { + if params.stringType == 0 { + // This is a string without an explicit string type. We'll use + // a PrintableString if the character set in the string is + // sufficiently limited, otherwise we'll use a UTF8String. + for _, r := range v.String() { + if r >= utf8.RuneSelf || !isPrintable(byte(r)) { + if !utf8.ValidString(v.String()) { + return errors.New("asn1: string not valid UTF-8") + } + tag = tagUTF8String + break + } + } + } else { + tag = params.stringType } - tag = params.stringType } if params.set { diff --git a/src/pkg/encoding/asn1/marshal_test.go b/src/pkg/encoding/asn1/marshal_test.go index f43bcae68..b4dbe71ef 100644 --- a/src/pkg/encoding/asn1/marshal_test.go +++ b/src/pkg/encoding/asn1/marshal_test.go @@ -82,7 +82,7 @@ var marshalTests = []marshalTest{ {explicitTagTest{64}, "3005a503020140"}, {time.Unix(0, 0).UTC(), "170d3730303130313030303030305a"}, {time.Unix(1258325776, 0).UTC(), "170d3039313131353232353631365a"}, - {time.Unix(1258325776, 0).In(PST), "17113039313131353232353631362d30383030"}, + {time.Unix(1258325776, 0).In(PST), "17113039313131353134353631362d30383030"}, {BitString{[]byte{0x80}, 1}, "03020780"}, {BitString{[]byte{0x81, 0xf0}, 12}, "03030481f0"}, {ObjectIdentifier([]int{1, 2, 3, 4}), "06032a0304"}, @@ -122,6 +122,7 @@ var marshalTests = []marshalTest{ {testSET([]int{10}), "310302010a"}, {omitEmptyTest{[]string{}}, "3000"}, {omitEmptyTest{[]string{"1"}}, "30053003130131"}, + {"Σ", "0c02cea3"}, } func TestMarshal(t *testing.T) { @@ -131,9 +132,16 @@ func TestMarshal(t *testing.T) { t.Errorf("#%d failed: %s", i, err) } out, _ := hex.DecodeString(test.out) - if bytes.Compare(out, data) != 0 { + if !bytes.Equal(out, data) { t.Errorf("#%d got: %x want %x\n\t%q\n\t%q", i, data, out, data, out) } } } + +func TestInvalidUTF8(t *testing.T) { + _, err := Marshal(string([]byte{0xff, 0xff})) + if err == nil { + t.Errorf("invalid UTF8 string was accepted") + } +} diff --git a/src/pkg/encoding/base32/base32.go b/src/pkg/encoding/base32/base32.go index 71da6e22b..dbefc48fa 100644 --- a/src/pkg/encoding/base32/base32.go +++ b/src/pkg/encoding/base32/base32.go @@ -237,7 +237,6 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { dlen := 8 // do the top bytes contain any data? - dbufloop: for j := 0; j < 8; { if len(src) == 0 { return n, false, CorruptInputError(len(osrc) - len(src) - j) @@ -258,7 +257,7 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { } dlen = j end = true - break dbufloop + break } dbuf[j] = enc.decodeMap[in] if dbuf[j] == 0xFF { diff --git a/src/pkg/encoding/base32/example_test.go b/src/pkg/encoding/base32/example_test.go new file mode 100644 index 000000000..f6128d900 --- /dev/null +++ b/src/pkg/encoding/base32/example_test.go @@ -0,0 +1,45 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Keep in sync with ../base64/example_test.go. + +package base32_test + +import ( + "encoding/base32" + "fmt" + "os" +) + +func ExampleEncoding_EncodeToString() { + data := []byte("any + old & data") + str := base32.StdEncoding.EncodeToString(data) + fmt.Println(str) + // Output: + // MFXHSIBLEBXWYZBAEYQGIYLUME====== +} + +func ExampleEncoding_DecodeString() { + str := "ONXW2ZJAMRQXIYJAO5UXI2BAAAQGC3TEEDX3XPY=" + data, err := base32.StdEncoding.DecodeString(str) + if err != nil { + fmt.Println("error:", err) + return + } + fmt.Printf("%q\n", data) + // Output: + // "some data with \x00 and \ufeff" +} + +func ExampleNewEncoder() { + input := []byte("foo\x00bar") + encoder := base32.NewEncoder(base32.StdEncoding, os.Stdout) + encoder.Write(input) + // Must close the encoder when finished to flush any partial blocks. + // If you comment out the following line, the last partial block "r" + // won't be encoded. + encoder.Close() + // Output: + // MZXW6ADCMFZA==== +} diff --git a/src/pkg/encoding/base64/base64.go b/src/pkg/encoding/base64/base64.go index 0b842f066..e66672a1c 100644 --- a/src/pkg/encoding/base64/base64.go +++ b/src/pkg/encoding/base64/base64.go @@ -216,7 +216,6 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { var dbuf [4]byte dlen := 4 - dbufloop: for j := 0; j < 4; { if len(src) == 0 { return n, false, CorruptInputError(len(osrc) - len(src) - j) @@ -240,7 +239,7 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { } dlen = j end = true - break dbufloop + break } dbuf[j] = enc.decodeMap[in] if dbuf[j] == 0xFF { diff --git a/src/pkg/encoding/base64/base64_test.go b/src/pkg/encoding/base64/base64_test.go index f9b863c36..2166abd7a 100644 --- a/src/pkg/encoding/base64/base64_test.go +++ b/src/pkg/encoding/base64/base64_test.go @@ -257,6 +257,7 @@ func TestDecoderIssue3577(t *testing.T) { wantErr := errors.New("my error") next <- nextRead{5, nil} next <- nextRead{10, wantErr} + next <- nextRead{0, wantErr} d := NewDecoder(StdEncoding, &faultInjectReader{ source: "VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==", // twas brillig... nextc: next, diff --git a/src/pkg/encoding/base64/example_test.go b/src/pkg/encoding/base64/example_test.go new file mode 100644 index 000000000..d18b856a0 --- /dev/null +++ b/src/pkg/encoding/base64/example_test.go @@ -0,0 +1,45 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Keep in sync with ../base32/example_test.go. + +package base64_test + +import ( + "encoding/base64" + "fmt" + "os" +) + +func ExampleEncoding_EncodeToString() { + data := []byte("any + old & data") + str := base64.StdEncoding.EncodeToString(data) + fmt.Println(str) + // Output: + // YW55ICsgb2xkICYgZGF0YQ== +} + +func ExampleEncoding_DecodeString() { + str := "c29tZSBkYXRhIHdpdGggACBhbmQg77u/" + data, err := base64.StdEncoding.DecodeString(str) + if err != nil { + fmt.Println("error:", err) + return + } + fmt.Printf("%q\n", data) + // Output: + // "some data with \x00 and \ufeff" +} + +func ExampleNewEncoder() { + input := []byte("foo\x00bar") + encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout) + encoder.Write(input) + // Must close the encoder when finished to flush any partial blocks. + // If you comment out the following line, the last partial block "r" + // won't be encoded. + encoder.Close() + // Output: + // Zm9vAGJhcg== +} diff --git a/src/pkg/encoding/binary/binary.go b/src/pkg/encoding/binary/binary.go index 712e490e6..edbac197d 100644 --- a/src/pkg/encoding/binary/binary.go +++ b/src/pkg/encoding/binary/binary.go @@ -125,6 +125,9 @@ func (bigEndian) GoString() string { return "binary.BigEndian" } // of fixed-size values. // Bytes read from r are decoded using the specified byte order // and written to successive fields of the data. +// When reading into structs, the field data for fields with +// blank (_) field names is skipped; i.e., blank field names +// may be used for padding. func Read(r io.Reader, order ByteOrder, data interface{}) error { // Fast path for basic types. if n := intDestSize(data); n != 0 { @@ -154,7 +157,7 @@ func Read(r io.Reader, order ByteOrder, data interface{}) error { return nil } - // Fallback to reflect-based. + // Fallback to reflect-based decoding. var v reflect.Value switch d := reflect.ValueOf(data); d.Kind() { case reflect.Ptr: @@ -164,9 +167,9 @@ func Read(r io.Reader, order ByteOrder, data interface{}) error { default: return errors.New("binary.Read: invalid type " + d.Type().String()) } - size := dataSize(v) - if size < 0 { - return errors.New("binary.Read: invalid type " + v.Type().String()) + size, err := dataSize(v) + if err != nil { + return errors.New("binary.Read: " + err.Error()) } d := &decoder{order: order, buf: make([]byte, size)} if _, err := io.ReadFull(r, d.buf); err != nil { @@ -181,6 +184,8 @@ func Read(r io.Reader, order ByteOrder, data interface{}) error { // values, or a pointer to such data. // Bytes written to w are encoded using the specified byte order // and read from successive fields of the data. +// When writing structs, zero values are written for fields +// with blank (_) field names. func Write(w io.Writer, order ByteOrder, data interface{}) error { // Fast path for basic types. var b [8]byte @@ -239,76 +244,80 @@ func Write(w io.Writer, order ByteOrder, data interface{}) error { _, err := w.Write(bs) return err } + + // Fallback to reflect-based encoding. v := reflect.Indirect(reflect.ValueOf(data)) - size := dataSize(v) - if size < 0 { - return errors.New("binary.Write: invalid type " + v.Type().String()) + size, err := dataSize(v) + if err != nil { + return errors.New("binary.Write: " + err.Error()) } buf := make([]byte, size) e := &encoder{order: order, buf: buf} e.value(v) - _, err := w.Write(buf) + _, err = w.Write(buf) return err } // Size returns how many bytes Write would generate to encode the value v, which // must be a fixed-size value or a slice of fixed-size values, or a pointer to such data. func Size(v interface{}) int { - return dataSize(reflect.Indirect(reflect.ValueOf(v))) + n, err := dataSize(reflect.Indirect(reflect.ValueOf(v))) + if err != nil { + return -1 + } + return n } // dataSize returns the number of bytes the actual data represented by v occupies in memory. // For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice // it returns the length of the slice times the element size and does not count the memory // occupied by the header. -func dataSize(v reflect.Value) int { +func dataSize(v reflect.Value) (int, error) { if v.Kind() == reflect.Slice { - elem := sizeof(v.Type().Elem()) - if elem < 0 { - return -1 + elem, err := sizeof(v.Type().Elem()) + if err != nil { + return 0, err } - return v.Len() * elem + return v.Len() * elem, nil } return sizeof(v.Type()) } -func sizeof(t reflect.Type) int { +func sizeof(t reflect.Type) (int, error) { switch t.Kind() { case reflect.Array: - n := sizeof(t.Elem()) - if n < 0 { - return -1 + n, err := sizeof(t.Elem()) + if err != nil { + return 0, err } - return t.Len() * n + return t.Len() * n, nil case reflect.Struct: sum := 0 for i, n := 0, t.NumField(); i < n; i++ { - s := sizeof(t.Field(i).Type) - if s < 0 { - return -1 + s, err := sizeof(t.Field(i).Type) + if err != nil { + return 0, err } sum += s } - return sum + return sum, nil case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: - return int(t.Size()) + return int(t.Size()), nil } - return -1 + return 0, errors.New("invalid type " + t.String()) } -type decoder struct { +type coder struct { order ByteOrder buf []byte } -type encoder struct { - order ByteOrder - buf []byte -} +type decoder coder +type encoder coder func (d *decoder) uint8() uint8 { x := d.buf[0] @@ -379,9 +388,19 @@ func (d *decoder) value(v reflect.Value) { } case reflect.Struct: + t := v.Type() l := v.NumField() for i := 0; i < l; i++ { - d.value(v.Field(i)) + // Note: Calling v.CanSet() below is an optimization. + // It would be sufficient to check the field name, + // but creating the StructField info for each field is + // costly (run "go test -bench=ReadStruct" and compare + // results when making changes to this code). + if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" { + d.value(v) + } else { + d.skip(v) + } } case reflect.Slice: @@ -435,9 +454,15 @@ func (e *encoder) value(v reflect.Value) { } case reflect.Struct: + t := v.Type() l := v.NumField() for i := 0; i < l; i++ { - e.value(v.Field(i)) + // see comment for corresponding code in decoder.value() + if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" { + e.value(v) + } else { + e.skip(v) + } } case reflect.Slice: @@ -492,6 +517,19 @@ func (e *encoder) value(v reflect.Value) { } } +func (d *decoder) skip(v reflect.Value) { + n, _ := dataSize(v) + d.buf = d.buf[n:] +} + +func (e *encoder) skip(v reflect.Value) { + n, _ := dataSize(v) + for i := range e.buf[0:n] { + e.buf[i] = 0 + } + e.buf = e.buf[n:] +} + // intDestSize returns the size of the integer that ptrType points to, // or 0 if the type is not supported. func intDestSize(ptrType interface{}) int { diff --git a/src/pkg/encoding/binary/binary_test.go b/src/pkg/encoding/binary/binary_test.go index ff361b7e3..056f0998f 100644 --- a/src/pkg/encoding/binary/binary_test.go +++ b/src/pkg/encoding/binary/binary_test.go @@ -9,6 +9,7 @@ import ( "io" "math" "reflect" + "strings" "testing" ) @@ -120,18 +121,14 @@ func testWrite(t *testing.T, order ByteOrder, b []byte, s1 interface{}) { checkResult(t, "Write", order, err, buf.Bytes(), b) } -func TestBigEndianRead(t *testing.T) { testRead(t, BigEndian, big, s) } - -func TestLittleEndianRead(t *testing.T) { testRead(t, LittleEndian, little, s) } - -func TestBigEndianWrite(t *testing.T) { testWrite(t, BigEndian, big, s) } - -func TestLittleEndianWrite(t *testing.T) { testWrite(t, LittleEndian, little, s) } +func TestLittleEndianRead(t *testing.T) { testRead(t, LittleEndian, little, s) } +func TestLittleEndianWrite(t *testing.T) { testWrite(t, LittleEndian, little, s) } +func TestLittleEndianPtrWrite(t *testing.T) { testWrite(t, LittleEndian, little, &s) } +func TestBigEndianRead(t *testing.T) { testRead(t, BigEndian, big, s) } +func TestBigEndianWrite(t *testing.T) { testWrite(t, BigEndian, big, s) } func TestBigEndianPtrWrite(t *testing.T) { testWrite(t, BigEndian, big, &s) } -func TestLittleEndianPtrWrite(t *testing.T) { testWrite(t, LittleEndian, little, &s) } - func TestReadSlice(t *testing.T) { slice := make([]int32, 2) err := Read(bytes.NewBuffer(src), BigEndian, slice) @@ -147,20 +144,81 @@ func TestWriteSlice(t *testing.T) { func TestWriteT(t *testing.T) { buf := new(bytes.Buffer) ts := T{} - err := Write(buf, BigEndian, ts) - if err == nil { - t.Errorf("WriteT: have nil, want non-nil") + if err := Write(buf, BigEndian, ts); err == nil { + t.Errorf("WriteT: have err == nil, want non-nil") } tv := reflect.Indirect(reflect.ValueOf(ts)) for i, n := 0, tv.NumField(); i < n; i++ { - err = Write(buf, BigEndian, tv.Field(i).Interface()) - if err == nil { - t.Errorf("WriteT.%v: have nil, want non-nil", tv.Field(i).Type()) + typ := tv.Field(i).Type().String() + if typ == "[4]int" { + typ = "int" // the problem is int, not the [4] + } + if err := Write(buf, BigEndian, tv.Field(i).Interface()); err == nil { + t.Errorf("WriteT.%v: have err == nil, want non-nil", tv.Field(i).Type()) + } else if !strings.Contains(err.Error(), typ) { + t.Errorf("WriteT: have err == %q, want it to mention %s", err, typ) } } } +type BlankFields struct { + A uint32 + _ int32 + B float64 + _ [4]int16 + C byte + _ [7]byte + _ struct { + f [8]float32 + } +} + +type BlankFieldsProbe struct { + A uint32 + P0 int32 + B float64 + P1 [4]int16 + C byte + P2 [7]byte + P3 struct { + F [8]float32 + } +} + +func TestBlankFields(t *testing.T) { + buf := new(bytes.Buffer) + b1 := BlankFields{A: 1234567890, B: 2.718281828, C: 42} + if err := Write(buf, LittleEndian, &b1); err != nil { + t.Error(err) + } + + // zero values must have been written for blank fields + var p BlankFieldsProbe + if err := Read(buf, LittleEndian, &p); err != nil { + t.Error(err) + } + + // quick test: only check first value of slices + if p.P0 != 0 || p.P1[0] != 0 || p.P2[0] != 0 || p.P3.F[0] != 0 { + t.Errorf("non-zero values for originally blank fields: %#v", p) + } + + // write p and see if we can probe only some fields + if err := Write(buf, LittleEndian, &p); err != nil { + t.Error(err) + } + + // read should ignore blank fields in b2 + var b2 BlankFields + if err := Read(buf, LittleEndian, &b2); err != nil { + t.Error(err) + } + if b1.A != b2.A || b1.B != b2.B || b1.C != b2.C { + t.Errorf("%#v != %#v", b1, b2) + } +} + type byteSliceReader struct { remain []byte } @@ -187,7 +245,7 @@ func BenchmarkReadStruct(b *testing.B) { bsr := &byteSliceReader{} var buf bytes.Buffer Write(&buf, BigEndian, &s) - n := dataSize(reflect.ValueOf(s)) + n, _ := dataSize(reflect.ValueOf(s)) b.SetBytes(int64(n)) t := s b.ResetTimer() diff --git a/src/pkg/encoding/binary/varint.go b/src/pkg/encoding/binary/varint.go index b756afdd0..7035529f2 100644 --- a/src/pkg/encoding/binary/varint.go +++ b/src/pkg/encoding/binary/varint.go @@ -123,7 +123,7 @@ func ReadUvarint(r io.ByteReader) (uint64, error) { panic("unreachable") } -// ReadVarint reads an encoded unsigned integer from r and returns it as a uint64. +// ReadVarint reads an encoded signed integer from r and returns it as an int64. func ReadVarint(r io.ByteReader) (int64, error) { ux, err := ReadUvarint(r) // ok to continue in presence of error x := int64(ux >> 1) diff --git a/src/pkg/encoding/csv/writer.go b/src/pkg/encoding/csv/writer.go index c4dcba566..1faecb664 100644 --- a/src/pkg/encoding/csv/writer.go +++ b/src/pkg/encoding/csv/writer.go @@ -22,7 +22,7 @@ import ( // // If UseCRLF is true, the Writer ends each record with \r\n instead of \n. type Writer struct { - Comma rune // Field delimiter (set to to ',' by NewWriter) + Comma rune // Field delimiter (set to ',' by NewWriter) UseCRLF bool // True to use \r\n as the line terminator w *bufio.Writer } @@ -92,20 +92,26 @@ func (w *Writer) Write(record []string) (err error) { } // Flush writes any buffered data to the underlying io.Writer. +// To check if an error occurred during the Flush, call Error. func (w *Writer) Flush() { w.w.Flush() } +// Error reports any error that has occurred during a previous Write or Flush. +func (w *Writer) Error() error { + _, err := w.w.Write(nil) + return err +} + // WriteAll writes multiple CSV records to w using Write and then calls Flush. func (w *Writer) WriteAll(records [][]string) (err error) { for _, record := range records { err = w.Write(record) if err != nil { - break + return err } } - w.Flush() - return nil + return w.w.Flush() } // fieldNeedsQuotes returns true if our field must be enclosed in quotes. diff --git a/src/pkg/encoding/csv/writer_test.go b/src/pkg/encoding/csv/writer_test.go index 578959007..03ca6b093 100644 --- a/src/pkg/encoding/csv/writer_test.go +++ b/src/pkg/encoding/csv/writer_test.go @@ -6,6 +6,7 @@ package csv import ( "bytes" + "errors" "testing" ) @@ -42,3 +43,30 @@ func TestWrite(t *testing.T) { } } } + +type errorWriter struct{} + +func (e errorWriter) Write(b []byte) (int, error) { + return 0, errors.New("Test") +} + +func TestError(t *testing.T) { + b := &bytes.Buffer{} + f := NewWriter(b) + f.Write([]string{"abc"}) + f.Flush() + err := f.Error() + + if err != nil { + t.Errorf("Unexpected error: %s\n", err) + } + + f = NewWriter(errorWriter{}) + f.Write([]string{"abc"}) + f.Flush() + err = f.Error() + + if err == nil { + t.Error("Error should not be nil") + } +} diff --git a/src/pkg/encoding/gob/codec_test.go b/src/pkg/encoding/gob/codec_test.go index ebcbb78eb..482212b74 100644 --- a/src/pkg/encoding/gob/codec_test.go +++ b/src/pkg/encoding/gob/codec_test.go @@ -7,6 +7,7 @@ package gob import ( "bytes" "errors" + "flag" "math" "math/rand" "reflect" @@ -16,6 +17,8 @@ import ( "unsafe" ) +var doFuzzTests = flag.Bool("gob.fuzz", false, "run the fuzz tests, which are large and very slow") + // Guarantee encoding format by comparing some encodings to hand-written values type EncodeT struct { x uint64 @@ -1434,7 +1437,8 @@ func encFuzzDec(rng *rand.Rand, in interface{}) error { // This does some "fuzz testing" by attempting to decode a sequence of random bytes. func TestFuzz(t *testing.T) { - if testing.Short() { + if !*doFuzzTests { + t.Logf("disabled; run with -gob.fuzz to enable") return } @@ -1453,11 +1457,16 @@ func TestFuzz(t *testing.T) { } func TestFuzzRegressions(t *testing.T) { + if !*doFuzzTests { + t.Logf("disabled; run with -gob.fuzz to enable") + return + } + // An instance triggering a type name of length ~102 GB. testFuzz(t, 1328492090837718000, 100, new(float32)) // An instance triggering a type name of 1.6 GB. - // Commented out because it takes 5m to run. - //testFuzz(t, 1330522872628565000, 100, new(int)) + // Note: can take several minutes to run. + testFuzz(t, 1330522872628565000, 100, new(int)) } func testFuzz(t *testing.T, seed int64, n int, input ...interface{}) { diff --git a/src/pkg/encoding/gob/decode.go b/src/pkg/encoding/gob/decode.go index e32a178ab..a80d9f919 100644 --- a/src/pkg/encoding/gob/decode.go +++ b/src/pkg/encoding/gob/decode.go @@ -62,15 +62,15 @@ func overflow(name string) error { // Used only by the Decoder to read the message length. func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err error) { width = 1 - _, err = r.Read(buf[0:width]) - if err != nil { + n, err := io.ReadFull(r, buf[0:width]) + if n == 0 { return } b := buf[0] if b <= 0x7f { return uint64(b), width, nil } - n := -int(int8(b)) + n = -int(int8(b)) if n > uint64Size { err = errBadUint return @@ -562,6 +562,9 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) { func (dec *Decoder) decodeArrayHelper(state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl error) { instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl} for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding array or slice: length exceeds input size (%d elements)", length) + } up := unsafe.Pointer(p) if elemIndir > 1 { up = decIndirect(up, elemIndir) @@ -652,9 +655,6 @@ func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) { // Slices are encoded as an unsigned length followed by the elements. func (dec *Decoder) decodeSlice(atyp reflect.Type, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl error) { nr := state.decodeUint() - if nr > uint64(state.b.Len()) { - errorf("length of slice exceeds input size (%d elements)", nr) - } n := int(nr) if indir > 0 { up := unsafe.Pointer(p) @@ -717,7 +717,9 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui errorf("name too long (%d bytes): %.20q...", len(name), name) } // The concrete type must be registered. + registerLock.RLock() typ, ok := nameToConcreteType[name] + registerLock.RUnlock() if !ok { errorf("name not registered for interface: %q", name) } diff --git a/src/pkg/encoding/gob/decoder.go b/src/pkg/encoding/gob/decoder.go index c5c7d3fdb..04f706ca5 100644 --- a/src/pkg/encoding/gob/decoder.go +++ b/src/pkg/encoding/gob/decoder.go @@ -87,21 +87,38 @@ func (dec *Decoder) recvMessage() bool { // readMessage reads the next nbytes bytes from the input. func (dec *Decoder) readMessage(nbytes int) { - // Allocate the buffer. - if cap(dec.tmp) < nbytes { - dec.tmp = make([]byte, nbytes+100) // room to grow + // Allocate the dec.tmp buffer, up to 10KB. + const maxBuf = 10 * 1024 + nTmp := nbytes + if nTmp > maxBuf { + nTmp = maxBuf } - dec.tmp = dec.tmp[:nbytes] + if cap(dec.tmp) < nTmp { + nAlloc := nTmp + 100 // A little extra for growth. + if nAlloc > maxBuf { + nAlloc = maxBuf + } + dec.tmp = make([]byte, nAlloc) + } + dec.tmp = dec.tmp[:nTmp] // Read the data - _, dec.err = io.ReadFull(dec.r, dec.tmp) - if dec.err != nil { - if dec.err == io.EOF { - dec.err = io.ErrUnexpectedEOF + dec.buf.Grow(nbytes) + for nbytes > 0 { + if nbytes < nTmp { + dec.tmp = dec.tmp[:nbytes] } - return + var nRead int + nRead, dec.err = io.ReadFull(dec.r, dec.tmp) + if dec.err != nil { + if dec.err == io.EOF { + dec.err = io.ErrUnexpectedEOF + } + return + } + dec.buf.Write(dec.tmp) + nbytes -= nRead } - dec.buf.Write(dec.tmp) } // toInt turns an encoded uint64 into an int, according to the marshaling rules. diff --git a/src/pkg/encoding/gob/doc.go b/src/pkg/encoding/gob/doc.go index 821d9a3fe..5bd61b12e 100644 --- a/src/pkg/encoding/gob/doc.go +++ b/src/pkg/encoding/gob/doc.go @@ -67,11 +67,13 @@ point values may be received into any floating point variable. However, the destination variable must be able to represent the value or the decode operation will fail. -Structs, arrays and slices are also supported. Strings and arrays of bytes are -supported with a special, efficient representation (see below). When a slice is -decoded, if the existing slice has capacity the slice will be extended in place; -if not, a new array is allocated. Regardless, the length of the resulting slice -reports the number of elements decoded. +Structs, arrays and slices are also supported. Structs encode and +decode only exported fields. Strings and arrays of bytes are supported +with a special, efficient representation (see below). When a slice +is decoded, if the existing slice has capacity the slice will be +extended in place; if not, a new array is allocated. Regardless, +the length of the resulting slice reports the number of elements +decoded. Functions and channels cannot be sent in a gob. Attempting to encode a value that contains one will fail. @@ -118,7 +120,7 @@ elements using the standard gob encoding for their type, recursively. Maps are sent as an unsigned count followed by that many key, element pairs. Empty but non-nil maps are sent, so if the sender has allocated -a map, the receiver will allocate a map even no elements are +a map, the receiver will allocate a map even if no elements are transmitted. Structs are sent as a sequence of (field number, field value) pairs. The field @@ -328,7 +330,7 @@ reserved). 01 // Add 1 to get field number 0: field[1].name 01 // 1 byte 59 // structType.field[1].name = "Y" - 01 // Add 1 to get field number 1: field[0].id + 01 // Add 1 to get field number 1: field[1].id 04 // struct.Type.field[1].typeId is 2 (signed int). 00 // End of structType.field[1]; end of structType.field. 00 // end of wireType.structType structure diff --git a/src/pkg/encoding/gob/encode.go b/src/pkg/encoding/gob/encode.go index 168e08b13..ea37a6cbd 100644 --- a/src/pkg/encoding/gob/encode.go +++ b/src/pkg/encoding/gob/encode.go @@ -426,6 +426,12 @@ func (enc *Encoder) encodeMap(b *bytes.Buffer, mv reflect.Value, keyOp, elemOp e // by the concrete value. A nil value gets sent as the empty string for the name, // followed by no value. func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv reflect.Value) { + // Gobs can encode nil interface values but not typed interface + // values holding nil pointers, since nil pointers point to no value. + elem := iv.Elem() + if elem.Kind() == reflect.Ptr && elem.IsNil() { + errorf("gob: cannot encode nil pointer of type %s inside interface", iv.Elem().Type()) + } state := enc.newEncoderState(b) state.fieldnum = -1 state.sendZero = true @@ -435,7 +441,9 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv reflect.Value) { } ut := userType(iv.Elem().Type()) + registerLock.RLock() name, ok := concreteTypeToName[ut.base] + registerLock.RUnlock() if !ok { errorf("type not registered for interface: %s", ut.base) } @@ -454,7 +462,7 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv reflect.Value) { enc.pushWriter(b) data := new(bytes.Buffer) data.Write(spaceForLength) - enc.encode(data, iv.Elem(), ut) + enc.encode(data, elem, ut) if enc.err != nil { error_(enc.err) } @@ -698,9 +706,20 @@ func (enc *Encoder) getEncEngine(ut *userTypeInfo) *encEngine { error_(err1) } if info.encoder == nil { - // mark this engine as underway before compiling to handle recursive types. + // Assign the encEngine now, so recursive types work correctly. But... info.encoder = new(encEngine) + // ... if we fail to complete building the engine, don't cache the half-built machine. + // Doing this here means we won't cache a type that is itself OK but + // that contains a nested type that won't compile. The result is consistent + // error behavior when Encode is called multiple times on the top-level type. + ok := false + defer func() { + if !ok { + info.encoder = nil + } + }() info.encoder = enc.compileEnc(ut) + ok = true } return info.encoder } diff --git a/src/pkg/encoding/gob/encoder.go b/src/pkg/encoding/gob/encoder.go index a15b5a1f9..f669c3d5b 100644 --- a/src/pkg/encoding/gob/encoder.go +++ b/src/pkg/encoding/gob/encoder.go @@ -132,13 +132,13 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp return true } -// sendType sends the type info to the other side, if necessary. +// sendType sends the type info to the other side, if necessary. func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { ut := userType(origt) if ut.isGobEncoder { // The rules are different: regardless of the underlying type's representation, - // we need to tell the other side that this exact type is a GobEncoder. - return enc.sendActualType(w, state, ut, ut.user) + // we need to tell the other side that the base type is a GobEncoder. + return enc.sendActualType(w, state, ut, ut.base) } // It's a concrete value, so drill down to the base type. @@ -218,6 +218,12 @@ func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) { // EncodeValue transmits the data item represented by the reflection value, // guaranteeing that all necessary type information has been transmitted first. func (enc *Encoder) EncodeValue(value reflect.Value) error { + // Gobs contain values. They cannot represent nil pointers, which + // have no value to encode. + if value.Kind() == reflect.Ptr && value.IsNil() { + panic("gob: cannot encode nil pointer of type " + value.Type().String()) + } + // Make sure we're single-threaded through here, so multiple // goroutines can share an encoder. enc.mutex.Lock() diff --git a/src/pkg/encoding/gob/encoder_test.go b/src/pkg/encoding/gob/encoder_test.go index c4947cbb8..b684772c6 100644 --- a/src/pkg/encoding/gob/encoder_test.go +++ b/src/pkg/encoding/gob/encoder_test.go @@ -736,3 +736,109 @@ func TestPtrToMapOfMap(t *testing.T) { t.Fatalf("expected %v got %v", data, newData) } } + +// A top-level nil pointer generates a panic with a helpful string-valued message. +func TestTopLevelNilPointer(t *testing.T) { + errMsg := topLevelNilPanic(t) + if errMsg == "" { + t.Fatal("top-level nil pointer did not panic") + } + if !strings.Contains(errMsg, "nil pointer") { + t.Fatal("expected nil pointer error, got:", errMsg) + } +} + +func topLevelNilPanic(t *testing.T) (panicErr string) { + defer func() { + e := recover() + if err, ok := e.(string); ok { + panicErr = err + } + }() + var ip *int + buf := new(bytes.Buffer) + if err := NewEncoder(buf).Encode(ip); err != nil { + t.Fatal("error in encode:", err) + } + return +} + +func TestNilPointerInsideInterface(t *testing.T) { + var ip *int + si := struct { + I interface{} + }{ + I: ip, + } + buf := new(bytes.Buffer) + err := NewEncoder(buf).Encode(si) + if err == nil { + t.Fatal("expected error, got none") + } + errMsg := err.Error() + if !strings.Contains(errMsg, "nil pointer") || !strings.Contains(errMsg, "interface") { + t.Fatal("expected error about nil pointer and interface, got:", errMsg) + } +} + +type Bug4Public struct { + Name string + Secret Bug4Secret +} + +type Bug4Secret struct { + a int // error: no exported fields. +} + +// Test that a failed compilation doesn't leave around an executable encoder. +// Issue 3273. +func TestMutipleEncodingsOfBadType(t *testing.T) { + x := Bug4Public{ + Name: "name", + Secret: Bug4Secret{1}, + } + buf := new(bytes.Buffer) + enc := NewEncoder(buf) + err := enc.Encode(x) + if err == nil { + t.Fatal("first encoding: expected error") + } + buf.Reset() + enc = NewEncoder(buf) + err = enc.Encode(x) + if err == nil { + t.Fatal("second encoding: expected error") + } + if !strings.Contains(err.Error(), "no exported fields") { + t.Errorf("expected error about no exported fields; got %v", err) + } +} + +// There was an error check comparing the length of the input with the +// length of the slice being decoded. It was wrong because the next +// thing in the input might be a type definition, which would lead to +// an incorrect length check. This test reproduces the corner case. + +type Z struct { +} + +func Test29ElementSlice(t *testing.T) { + Register(Z{}) + src := make([]interface{}, 100) // Size needs to be bigger than size of type definition. + for i := range src { + src[i] = Z{} + } + buf := new(bytes.Buffer) + err := NewEncoder(buf).Encode(src) + if err != nil { + t.Fatalf("encode: %v", err) + return + } + + var dst []interface{} + err = NewDecoder(buf).Decode(&dst) + if err != nil { + t.Errorf("decode: %v", err) + return + } +} diff --git a/src/pkg/encoding/gob/gobencdec_test.go b/src/pkg/encoding/gob/gobencdec_test.go index 45240d764..ddcd80b1a 100644 --- a/src/pkg/encoding/gob/gobencdec_test.go +++ b/src/pkg/encoding/gob/gobencdec_test.go @@ -1,4 +1,4 @@ -// Copyright 20011 The Go Authors. All rights reserved. +// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -142,6 +142,18 @@ type GobTest5 struct { V *ValueGobber } +type GobTest6 struct { + X int // guarantee we have something in common with GobTest* + V ValueGobber + W *ValueGobber +} + +type GobTest7 struct { + X int // guarantee we have something in common with GobTest* + V *ValueGobber + W ValueGobber +} + type GobTestIgnoreEncoder struct { X int // guarantee we have something in common with GobTest* } @@ -336,7 +348,7 @@ func TestGobEncoderFieldsOfDifferentType(t *testing.T) { t.Fatal("decode error:", err) } if y.G.s != "XYZ" { - t.Fatalf("expected `XYZ` got %c", y.G.s) + t.Fatalf("expected `XYZ` got %q", y.G.s) } } @@ -360,6 +372,61 @@ func TestGobEncoderValueEncoder(t *testing.T) { } } +// Test that we can use a value then a pointer type of a GobEncoder +// in the same encoded value. Bug 4647. +func TestGobEncoderValueThenPointer(t *testing.T) { + v := ValueGobber("forty-two") + w := ValueGobber("six-by-nine") + + // this was a bug: encoding a GobEncoder by value before a GobEncoder + // pointer would cause duplicate type definitions to be sent. + + b := new(bytes.Buffer) + enc := NewEncoder(b) + if err := enc.Encode(GobTest6{42, v, &w}); err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest6) + if err := dec.Decode(x); err != nil { + t.Fatal("decode error:", err) + } + if got, want := x.V, v; got != want { + t.Errorf("v = %q, want %q", got, want) + } + if got, want := x.W, w; got == nil { + t.Errorf("w = nil, want %q", want) + } else if *got != want { + t.Errorf("w = %q, want %q", *got, want) + } +} + +// Test that we can use a pointer then a value type of a GobEncoder +// in the same encoded value. +func TestGobEncoderPointerThenValue(t *testing.T) { + v := ValueGobber("forty-two") + w := ValueGobber("six-by-nine") + + b := new(bytes.Buffer) + enc := NewEncoder(b) + if err := enc.Encode(GobTest7{42, &v, w}); err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest7) + if err := dec.Decode(x); err != nil { + t.Fatal("decode error:", err) + } + if got, want := x.V, v; got == nil { + t.Errorf("v = nil, want %q", want) + } else if *got != want { + t.Errorf("v = %q, want %q", got, want) + } + if got, want := x.W, w; got != want { + t.Errorf("w = %q, want %q", got, want) + } +} + func TestGobEncoderFieldTypeError(t *testing.T) { // GobEncoder to non-decoder: error b := new(bytes.Buffer) diff --git a/src/pkg/encoding/gob/timing_test.go b/src/pkg/encoding/gob/timing_test.go index b9371c423..13eb11925 100644 --- a/src/pkg/encoding/gob/timing_test.go +++ b/src/pkg/encoding/gob/timing_test.go @@ -9,7 +9,6 @@ import ( "fmt" "io" "os" - "runtime" "testing" ) @@ -50,47 +49,43 @@ func BenchmarkEndToEndByteBuffer(b *testing.B) { } func TestCountEncodeMallocs(t *testing.T) { + const N = 1000 + var buf bytes.Buffer enc := NewEncoder(&buf) bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")} - memstats := new(runtime.MemStats) - runtime.ReadMemStats(memstats) - mallocs := 0 - memstats.Mallocs - const count = 1000 - for i := 0; i < count; i++ { + + allocs := testing.AllocsPerRun(N, func() { err := enc.Encode(bench) if err != nil { t.Fatal("encode:", err) } - } - runtime.ReadMemStats(memstats) - mallocs += memstats.Mallocs - fmt.Printf("mallocs per encode of type Bench: %d\n", mallocs/count) + }) + fmt.Printf("mallocs per encode of type Bench: %v\n", allocs) } func TestCountDecodeMallocs(t *testing.T) { + const N = 1000 + var buf bytes.Buffer enc := NewEncoder(&buf) bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")} - const count = 1000 - for i := 0; i < count; i++ { + + // Fill the buffer with enough to decode + testing.AllocsPerRun(N, func() { err := enc.Encode(bench) if err != nil { t.Fatal("encode:", err) } - } + }) + dec := NewDecoder(&buf) - memstats := new(runtime.MemStats) - runtime.ReadMemStats(memstats) - mallocs := 0 - memstats.Mallocs - for i := 0; i < count; i++ { + allocs := testing.AllocsPerRun(N, func() { *bench = Bench{} err := dec.Decode(&bench) if err != nil { t.Fatal("decode:", err) } - } - runtime.ReadMemStats(memstats) - mallocs += memstats.Mallocs - fmt.Printf("mallocs per decode of type Bench: %d\n", mallocs/count) + }) + fmt.Printf("mallocs per decode of type Bench: %v\n", allocs) } diff --git a/src/pkg/encoding/gob/type.go b/src/pkg/encoding/gob/type.go index 0dd7a0a77..ea0db4eac 100644 --- a/src/pkg/encoding/gob/type.go +++ b/src/pkg/encoding/gob/type.go @@ -712,6 +712,7 @@ type GobDecoder interface { } var ( + registerLock sync.RWMutex nameToConcreteType = make(map[string]reflect.Type) concreteTypeToName = make(map[reflect.Type]string) ) @@ -723,6 +724,8 @@ func RegisterName(name string, value interface{}) { // reserved for nil panic("attempt to register empty name") } + registerLock.Lock() + defer registerLock.Unlock() ut := userType(reflect.TypeOf(value)) // Check for incompatible duplicates. The name must refer to the // same user type, and vice versa. @@ -749,12 +752,28 @@ func Register(value interface{}) { rt := reflect.TypeOf(value) name := rt.String() - // But for named types (or pointers to them), qualify with import path. + // But for named types (or pointers to them), qualify with import path (but see inner comment). // Dereference one pointer looking for a named type. star := "" if rt.Name() == "" { if pt := rt; pt.Kind() == reflect.Ptr { star = "*" + // NOTE: The following line should be rt = pt.Elem() to implement + // what the comment above claims, but fixing it would break compatibility + // with existing gobs. + // + // Given package p imported as "full/p" with these definitions: + // package p + // type T1 struct { ... } + // this table shows the intended and actual strings used by gob to + // name the types: + // + // Type Correct string Actual string + // + // T1 full/p.T1 full/p.T1 + // *T1 *full/p.T1 *p.T1 + // + // The missing full path cannot be fixed without breaking existing gob decoders. rt = pt } } diff --git a/src/pkg/encoding/gob/type_test.go b/src/pkg/encoding/gob/type_test.go index 42bdb4cf7..e230d22d4 100644 --- a/src/pkg/encoding/gob/type_test.go +++ b/src/pkg/encoding/gob/type_test.go @@ -5,6 +5,7 @@ package gob import ( + "bytes" "reflect" "testing" ) @@ -159,3 +160,63 @@ func TestRegistration(t *testing.T) { Register(new(T)) Register(new(T)) } + +type N1 struct{} +type N2 struct{} + +// See comment in type.go/Register. +func TestRegistrationNaming(t *testing.T) { + testCases := []struct { + t interface{} + name string + }{ + {&N1{}, "*gob.N1"}, + {N2{}, "encoding/gob.N2"}, + } + + for _, tc := range testCases { + Register(tc.t) + + tct := reflect.TypeOf(tc.t) + registerLock.RLock() + ct := nameToConcreteType[tc.name] + registerLock.RUnlock() + if ct != tct { + t.Errorf("nameToConcreteType[%q] = %v, want %v", tc.name, ct, tct) + } + // concreteTypeToName is keyed off the base type. + if tct.Kind() == reflect.Ptr { + tct = tct.Elem() + } + if n := concreteTypeToName[tct]; n != tc.name { + t.Errorf("concreteTypeToName[%v] got %v, want %v", tct, n, tc.name) + } + } +} + +func TestStressParallel(t *testing.T) { + type T2 struct{ A int } + c := make(chan bool) + const N = 10 + for i := 0; i < N; i++ { + go func() { + p := new(T2) + Register(p) + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(p) + if err != nil { + t.Error("encoder fail:", err) + } + dec := NewDecoder(b) + err = dec.Decode(p) + if err != nil { + t.Error("decoder fail:", err) + } + c <- true + }() + } + for i := 0; i < N; i++ { + <-c + } +} diff --git a/src/pkg/encoding/hex/hex_test.go b/src/pkg/encoding/hex/hex_test.go index 456f9eac7..356f590f0 100644 --- a/src/pkg/encoding/hex/hex_test.go +++ b/src/pkg/encoding/hex/hex_test.go @@ -65,7 +65,7 @@ func TestDecodeString(t *testing.T) { t.Errorf("#%d: unexpected err value: %s", i, err) continue } - if bytes.Compare(dst, test.dec) != 0 { + if !bytes.Equal(dst, test.dec) { t.Errorf("#%d: got: %#v want: #%v", i, dst, test.dec) } } diff --git a/src/pkg/encoding/json/bench_test.go b/src/pkg/encoding/json/bench_test.go index 333c1c0ce..29dbc26d4 100644 --- a/src/pkg/encoding/json/bench_test.go +++ b/src/pkg/encoding/json/bench_test.go @@ -153,5 +153,37 @@ func BenchmarkCodeUnmarshalReuse(b *testing.B) { b.Fatal("Unmmarshal:", err) } } - b.SetBytes(int64(len(codeJSON))) +} + +func BenchmarkUnmarshalString(b *testing.B) { + data := []byte(`"hello, world"`) + var s string + + for i := 0; i < b.N; i++ { + if err := Unmarshal(data, &s); err != nil { + b.Fatal("Unmarshal:", err) + } + } +} + +func BenchmarkUnmarshalFloat64(b *testing.B) { + var f float64 + data := []byte(`3.14`) + + for i := 0; i < b.N; i++ { + if err := Unmarshal(data, &f); err != nil { + b.Fatal("Unmarshal:", err) + } + } +} + +func BenchmarkUnmarshalInt64(b *testing.B) { + var x int64 + data := []byte(`3`) + + for i := 0; i < b.N; i++ { + if err := Unmarshal(data, &x); err != nil { + b.Fatal("Unmarshal:", err) + } + } } diff --git a/src/pkg/encoding/json/decode.go b/src/pkg/encoding/json/decode.go index d61f88706..f2ec9cb67 100644 --- a/src/pkg/encoding/json/decode.go +++ b/src/pkg/encoding/json/decode.go @@ -33,6 +33,10 @@ import ( // the value pointed at by the pointer. If the pointer is nil, Unmarshal // allocates a new value for it to point to. // +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by Marshal (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. +// // To unmarshal JSON into an interface value, Unmarshal unmarshals // the JSON into the concrete value contained in the interface value. // If the interface value is nil, that is, has no concrete value stored in it, @@ -51,24 +55,29 @@ import ( // If no more serious errors are encountered, Unmarshal returns // an UnmarshalTypeError describing the earliest such error. // +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +// func Unmarshal(data []byte, v interface{}) error { - d := new(decodeState).init(data) - - // Quick check for well-formedness. + // Check for well-formedness. // Avoids filling out half a data structure // before discovering a JSON syntax error. + var d decodeState err := checkValid(data, &d.scan) if err != nil { return err } + d.init(data) return d.unmarshal(v) } // Unmarshaler is the interface implemented by objects // that can unmarshal a JSON description of themselves. -// The input can be assumed to be a valid JSON object -// encoding. UnmarshalJSON must copy the JSON data +// The input can be assumed to be a valid encoding of +// a JSON value. UnmarshalJSON must copy the JSON data // if it wishes to retain the data after returning. type Unmarshaler interface { UnmarshalJSON([]byte) error @@ -87,6 +96,7 @@ func (e *UnmarshalTypeError) Error() string { // An UnmarshalFieldError describes a JSON object key that // led to an unexported (and therefore unwritable) struct field. +// (No longer used; kept for compatibility.) type UnmarshalFieldError struct { Key string Type reflect.Type @@ -125,18 +135,33 @@ func (d *decodeState) unmarshal(v interface{}) (err error) { }() rv := reflect.ValueOf(v) - pv := rv - if pv.Kind() != reflect.Ptr || pv.IsNil() { + if rv.Kind() != reflect.Ptr || rv.IsNil() { return &InvalidUnmarshalError{reflect.TypeOf(v)} } d.scan.reset() - // We decode rv not pv.Elem because the Unmarshaler interface + // We decode rv not rv.Elem because the Unmarshaler interface // test must be applied at the top level of the value. d.value(rv) return d.savedError } +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + // decodeState represents the state while decoding a JSON value. type decodeState struct { data []byte @@ -145,6 +170,7 @@ type decodeState struct { nextscan scanner // for calls to nextValue savedError error tempstr string // scratch space to avoid some allocations + useNumber bool } // errPhase is used for errors that should not happen unless @@ -265,47 +291,32 @@ func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, v = v.Addr() } for { - var isUnmarshaler bool - if v.Type().NumMethod() > 0 { - // Remember that this is an unmarshaler, - // but wait to return it until after allocating - // the pointer (if necessary). - _, isUnmarshaler = v.Interface().(Unmarshaler) - } - // Load value from interface, but only if the result will be // usefully addressable. - if iv := v; iv.Kind() == reflect.Interface && !iv.IsNil() { - e := iv.Elem() + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { v = e continue } } - pv := v - if pv.Kind() != reflect.Ptr { + if v.Kind() != reflect.Ptr { break } - if pv.Elem().Kind() != reflect.Ptr && decodingNull && pv.CanSet() { - return nil, pv + if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { + break } - if pv.IsNil() { - pv.Set(reflect.New(pv.Type().Elem())) + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) } - if isUnmarshaler { - // Using v.Interface().(Unmarshaler) - // here means that we have to use a pointer - // as the struct field. We cannot use a value inside - // a pointer to a struct, because in that case - // v.Interface() is the value (x.f) not the pointer (&x.f). - // This is an unfortunate consequence of reflect. - // An alternative would be to look up the - // UnmarshalJSON method and return a FuncValue. - return v.Interface().(Unmarshaler), reflect.Value{} + if v.Type().NumMethod() > 0 { + if unmarshaler, ok := v.Interface().(Unmarshaler); ok { + return unmarshaler, reflect.Value{} + } } - v = pv.Elem() + v = v.Elem() } return nil, v } @@ -327,15 +338,19 @@ func (d *decodeState) array(v reflect.Value) { // Check type of target. switch v.Kind() { + case reflect.Interface: + if v.NumMethod() == 0 { + // Decoding into nil interface? Switch to non-reflect code. + v.Set(reflect.ValueOf(d.arrayInterface())) + return + } + // Otherwise it's invalid. + fallthrough default: d.saveError(&UnmarshalTypeError{"array", v.Type()}) d.off-- d.next() return - case reflect.Interface: - // Decoding into nil interface? Switch to non-reflect code. - v.Set(reflect.ValueOf(d.arrayInterface())) - return case reflect.Array: case reflect.Slice: break @@ -421,36 +436,27 @@ func (d *decodeState) object(v reflect.Value) { v = pv // Decoding into nil interface? Switch to non-reflect code. - iv := v - if iv.Kind() == reflect.Interface { - iv.Set(reflect.ValueOf(d.objectInterface())) + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + v.Set(reflect.ValueOf(d.objectInterface())) return } // Check type of target: struct or map[string]T - var ( - mv reflect.Value - sv reflect.Value - ) switch v.Kind() { case reflect.Map: - // map must have string type + // map must have string kind t := v.Type() - if t.Key() != reflect.TypeOf("") { + if t.Key().Kind() != reflect.String { d.saveError(&UnmarshalTypeError{"object", v.Type()}) break } - mv = v - if mv.IsNil() { - mv.Set(reflect.MakeMap(t)) + if v.IsNil() { + v.Set(reflect.MakeMap(t)) } case reflect.Struct: - sv = v + default: d.saveError(&UnmarshalTypeError{"object", v.Type()}) - } - - if !mv.IsValid() && !sv.IsValid() { d.off-- d.next() // skip over { } in input return @@ -482,8 +488,8 @@ func (d *decodeState) object(v reflect.Value) { var subv reflect.Value destring := false // whether the value is wrapped in a string to be decoded first - if mv.IsValid() { - elemType := mv.Type().Elem() + if v.Kind() == reflect.Map { + elemType := v.Type().Elem() if !mapElem.IsValid() { mapElem = reflect.New(elemType).Elem() } else { @@ -491,51 +497,30 @@ func (d *decodeState) object(v reflect.Value) { } subv = mapElem } else { - var f reflect.StructField - var ok bool - st := sv.Type() - for i := 0; i < sv.NumField(); i++ { - sf := st.Field(i) - tag := sf.Tag.Get("json") - if tag == "-" { - // Pretend this field doesn't exist. - continue - } - if sf.Anonymous { - // Pretend this field doesn't exist, - // so that we can do a good job with - // these in a later version. - continue - } - // First, tag match - tagName, _ := parseTag(tag) - if tagName == key { - f = sf - ok = true - break // no better match possible - } - // Second, exact field name match - if sf.Name == key { - f = sf - ok = true + var f *field + fields := cachedTypeFields(v.Type()) + for i := range fields { + ff := &fields[i] + if ff.name == key { + f = ff + break } - // Third, case-insensitive field name match, - // but only if a better match hasn't already been seen - if !ok && strings.EqualFold(sf.Name, key) { - f = sf - ok = true + if f == nil && strings.EqualFold(ff.name, key) { + f = ff } } - - // Extract value; name must be exported. - if ok { - if f.PkgPath != "" { - d.saveError(&UnmarshalFieldError{key, st, f}) - } else { - subv = sv.FieldByIndex(f.Index) + if f != nil { + subv = v + destring = f.quoted + for _, i := range f.index { + if subv.Kind() == reflect.Ptr { + if subv.IsNil() { + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + subv = subv.Field(i) } - _, opts := parseTag(f.Tag.Get("json")) - destring = opts.Contains("string") } } @@ -554,10 +539,12 @@ func (d *decodeState) object(v reflect.Value) { } else { d.value(subv) } + // Write value back to map; // if using struct, subv points into struct already. - if mv.IsValid() { - mv.SetMapIndex(reflect.ValueOf(key), subv) + if v.Kind() == reflect.Map { + kv := reflect.ValueOf(key).Convert(v.Type().Key()) + v.SetMapIndex(kv, subv) } // Next token must be , or }. @@ -586,6 +573,21 @@ func (d *decodeState) literal(v reflect.Value) { d.literalStore(d.data[start:d.off], v, false) } +// convertNumber converts the number literal s to a float64 or a Number +// depending on the setting of d.useNumber. +func (d *decodeState) convertNumber(s string) (interface{}, error) { + if d.useNumber { + return Number(s), nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, &UnmarshalTypeError{"number " + s, reflect.TypeOf(0.0)} + } + return f, nil +} + +var numberType = reflect.TypeOf(Number("")) + // literalStore decodes a literal stored in item into v. // // fromQuoted indicates whether this literal came from unwrapping a @@ -612,12 +614,10 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool switch c := item[0]; c { case 'n': // null switch v.Kind() { - default: - d.saveError(&UnmarshalTypeError{"null", v.Type()}) case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: v.Set(reflect.Zero(v.Type())) + // otherwise, ignore null for primitives/string } - case 't', 'f': // true, false value := c == 't' switch v.Kind() { @@ -630,7 +630,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool case reflect.Bool: v.SetBool(value) case reflect.Interface: - v.Set(reflect.ValueOf(value)) + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(value)) + } else { + d.saveError(&UnmarshalTypeError{"bool", v.Type()}) + } } case '"': // string @@ -660,7 +664,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool case reflect.String: v.SetString(string(s)) case reflect.Interface: - v.Set(reflect.ValueOf(string(s))) + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(string(s))) + } else { + d.saveError(&UnmarshalTypeError{"string", v.Type()}) + } } default: // number @@ -674,15 +682,23 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool s := string(item) switch v.Kind() { default: + if v.Kind() == reflect.String && v.Type() == numberType { + v.SetString(s) + break + } if fromQuoted { d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) } else { d.error(&UnmarshalTypeError{"number", v.Type()}) } case reflect.Interface: - n, err := strconv.ParseFloat(s, 64) + n, err := d.convertNumber(s) if err != nil { - d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + d.saveError(err) + break + } + if v.NumMethod() != 0 { + d.saveError(&UnmarshalTypeError{"number", v.Type()}) break } v.Set(reflect.ValueOf(n)) @@ -735,7 +751,7 @@ func (d *decodeState) valueInterface() interface{} { // arrayInterface is like array but returns []interface{}. func (d *decodeState) arrayInterface() []interface{} { - var v []interface{} + var v = make([]interface{}, 0) for { // Look ahead for ] - can only happen on first iteration. op := d.scanWhile(scanSkipSpace) @@ -836,9 +852,9 @@ func (d *decodeState) literalInterface() interface{} { if c != '-' && (c < '0' || c > '9') { d.error(errPhase) } - n, err := strconv.ParseFloat(string(item), 64) + n, err := d.convertNumber(string(item)) if err != nil { - d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.TypeOf(0.0)}) + d.saveError(err) } return n } @@ -979,11 +995,3 @@ func unquoteBytes(s []byte) (t []byte, ok bool) { } return b[0:w], true } - -// The following is issue 3069. - -// BUG(rsc): This package ignores anonymous (embedded) struct fields -// during encoding and decoding. A future version may assign meaning -// to them. To force an anonymous field to be ignored in all future -// versions of this package, use an explicit `json:"-"` tag in the struct -// definition. diff --git a/src/pkg/encoding/json/decode_test.go b/src/pkg/encoding/json/decode_test.go index 6fac22c4a..e1bd918dd 100644 --- a/src/pkg/encoding/json/decode_test.go +++ b/src/pkg/encoding/json/decode_test.go @@ -7,9 +7,11 @@ package json import ( "bytes" "fmt" + "image" "reflect" "strings" "testing" + "time" ) type T struct { @@ -18,6 +20,32 @@ type T struct { Z int `json:"-"` } +type U struct { + Alphabet string `json:"alpha"` +} + +type V struct { + F1 interface{} + F2 int32 + F3 Number +} + +// ifaceNumAsFloat64/ifaceNumAsNumber are used to test unmarshalling with and +// without UseNumber +var ifaceNumAsFloat64 = map[string]interface{}{ + "k1": float64(1), + "k2": "s", + "k3": []interface{}{float64(1), float64(2.0), float64(3e-3)}, + "k4": map[string]interface{}{"kk1": "s", "kk2": float64(2)}, +} + +var ifaceNumAsNumber = map[string]interface{}{ + "k1": Number("1"), + "k2": "s", + "k3": []interface{}{Number("1"), Number("2.0"), Number("3e-3")}, + "k4": map[string]interface{}{"kk1": "s", "kk2": Number("2")}, +} + type tx struct { x int } @@ -48,55 +76,297 @@ var ( umstruct = ustruct{unmarshaler{true}} ) +// Test data structures for anonymous fields. + +type Point struct { + Z int +} + +type Top struct { + Level0 int + Embed0 + *Embed0a + *Embed0b `json:"e,omitempty"` // treated as named + Embed0c `json:"-"` // ignored + Loop + Embed0p // has Point with X, Y, used + Embed0q // has Point with Z, used +} + +type Embed0 struct { + Level1a int // overridden by Embed0a's Level1a with json tag + Level1b int // used because Embed0a's Level1b is renamed + Level1c int // used because Embed0a's Level1c is ignored + Level1d int // annihilated by Embed0a's Level1d + Level1e int `json:"x"` // annihilated by Embed0a.Level1e +} + +type Embed0a struct { + Level1a int `json:"Level1a,omitempty"` + Level1b int `json:"LEVEL1B,omitempty"` + Level1c int `json:"-"` + Level1d int // annihilated by Embed0's Level1d + Level1f int `json:"x"` // annihilated by Embed0's Level1e +} + +type Embed0b Embed0 + +type Embed0c Embed0 + +type Embed0p struct { + image.Point +} + +type Embed0q struct { + Point +} + +type Loop struct { + Loop1 int `json:",omitempty"` + Loop2 int `json:",omitempty"` + *Loop +} + +// From reflect test: +// The X in S6 and S7 annihilate, but they also block the X in S8.S9. +type S5 struct { + S6 + S7 + S8 +} + +type S6 struct { + X int +} + +type S7 S6 + +type S8 struct { + S9 +} + +type S9 struct { + X int + Y int +} + +// From reflect test: +// The X in S11.S6 and S12.S6 annihilate, but they also block the X in S13.S8.S9. +type S10 struct { + S11 + S12 + S13 +} + +type S11 struct { + S6 +} + +type S12 struct { + S6 +} + +type S13 struct { + S8 +} + type unmarshalTest struct { - in string - ptr interface{} - out interface{} - err error + in string + ptr interface{} + out interface{} + err error + useNumber bool +} + +type Ambig struct { + // Given "hello", the first match should win. + First int `json:"HELLO"` + Second int `json:"Hello"` } var unmarshalTests = []unmarshalTest{ // basic types - {`true`, new(bool), true, nil}, - {`1`, new(int), 1, nil}, - {`1.2`, new(float64), 1.2, nil}, - {`-5`, new(int16), int16(-5), nil}, - {`"a\u1234"`, new(string), "a\u1234", nil}, - {`"http:\/\/"`, new(string), "http://", nil}, - {`"g-clef: \uD834\uDD1E"`, new(string), "g-clef: \U0001D11E", nil}, - {`"invalid: \uD834x\uDD1E"`, new(string), "invalid: \uFFFDx\uFFFD", nil}, - {"null", new(interface{}), nil, nil}, - {`{"X": [1,2,3], "Y": 4}`, new(T), T{Y: 4}, &UnmarshalTypeError{"array", reflect.TypeOf("")}}, - {`{"x": 1}`, new(tx), tx{}, &UnmarshalFieldError{"x", txType, txType.Field(0)}}, + {in: `true`, ptr: new(bool), out: true}, + {in: `1`, ptr: new(int), out: 1}, + {in: `1.2`, ptr: new(float64), out: 1.2}, + {in: `-5`, ptr: new(int16), out: int16(-5)}, + {in: `2`, ptr: new(Number), out: Number("2"), useNumber: true}, + {in: `2`, ptr: new(Number), out: Number("2")}, + {in: `2`, ptr: new(interface{}), out: float64(2.0)}, + {in: `2`, ptr: new(interface{}), out: Number("2"), useNumber: true}, + {in: `"a\u1234"`, ptr: new(string), out: "a\u1234"}, + {in: `"http:\/\/"`, ptr: new(string), out: "http://"}, + {in: `"g-clef: \uD834\uDD1E"`, ptr: new(string), out: "g-clef: \U0001D11E"}, + {in: `"invalid: \uD834x\uDD1E"`, ptr: new(string), out: "invalid: \uFFFDx\uFFFD"}, + {in: "null", ptr: new(interface{}), out: nil}, + {in: `{"X": [1,2,3], "Y": 4}`, ptr: new(T), out: T{Y: 4}, err: &UnmarshalTypeError{"array", reflect.TypeOf("")}}, + {in: `{"x": 1}`, ptr: new(tx), out: tx{}}, + {in: `{"F1":1,"F2":2,"F3":3}`, ptr: new(V), out: V{F1: float64(1), F2: int32(2), F3: Number("3")}}, + {in: `{"F1":1,"F2":2,"F3":3}`, ptr: new(V), out: V{F1: Number("1"), F2: int32(2), F3: Number("3")}, useNumber: true}, + {in: `{"k1":1,"k2":"s","k3":[1,2.0,3e-3],"k4":{"kk1":"s","kk2":2}}`, ptr: new(interface{}), out: ifaceNumAsFloat64}, + {in: `{"k1":1,"k2":"s","k3":[1,2.0,3e-3],"k4":{"kk1":"s","kk2":2}}`, ptr: new(interface{}), out: ifaceNumAsNumber, useNumber: true}, + + // raw values with whitespace + {in: "\n true ", ptr: new(bool), out: true}, + {in: "\t 1 ", ptr: new(int), out: 1}, + {in: "\r 1.2 ", ptr: new(float64), out: 1.2}, + {in: "\t -5 \n", ptr: new(int16), out: int16(-5)}, + {in: "\t \"a\\u1234\" \n", ptr: new(string), out: "a\u1234"}, // Z has a "-" tag. - {`{"Y": 1, "Z": 2}`, new(T), T{Y: 1}, nil}, + {in: `{"Y": 1, "Z": 2}`, ptr: new(T), out: T{Y: 1}}, + + {in: `{"alpha": "abc", "alphabet": "xyz"}`, ptr: new(U), out: U{Alphabet: "abc"}}, + {in: `{"alpha": "abc"}`, ptr: new(U), out: U{Alphabet: "abc"}}, + {in: `{"alphabet": "xyz"}`, ptr: new(U), out: U{}}, // syntax errors - {`{"X": "foo", "Y"}`, nil, nil, &SyntaxError{"invalid character '}' after object key", 17}}, - {`[1, 2, 3+]`, nil, nil, &SyntaxError{"invalid character '+' after array element", 9}}, + {in: `{"X": "foo", "Y"}`, err: &SyntaxError{"invalid character '}' after object key", 17}}, + {in: `[1, 2, 3+]`, err: &SyntaxError{"invalid character '+' after array element", 9}}, + {in: `{"X":12x}`, err: &SyntaxError{"invalid character 'x' after object key:value pair", 8}, useNumber: true}, + + // raw value errors + {in: "\x01 42", err: &SyntaxError{"invalid character '\\x01' looking for beginning of value", 1}}, + {in: " 42 \x01", err: &SyntaxError{"invalid character '\\x01' after top-level value", 5}}, + {in: "\x01 true", err: &SyntaxError{"invalid character '\\x01' looking for beginning of value", 1}}, + {in: " false \x01", err: &SyntaxError{"invalid character '\\x01' after top-level value", 8}}, + {in: "\x01 1.2", err: &SyntaxError{"invalid character '\\x01' looking for beginning of value", 1}}, + {in: " 3.4 \x01", err: &SyntaxError{"invalid character '\\x01' after top-level value", 6}}, + {in: "\x01 \"string\"", err: &SyntaxError{"invalid character '\\x01' looking for beginning of value", 1}}, + {in: " \"string\" \x01", err: &SyntaxError{"invalid character '\\x01' after top-level value", 11}}, // array tests - {`[1, 2, 3]`, new([3]int), [3]int{1, 2, 3}, nil}, - {`[1, 2, 3]`, new([1]int), [1]int{1}, nil}, - {`[1, 2, 3]`, new([5]int), [5]int{1, 2, 3, 0, 0}, nil}, + {in: `[1, 2, 3]`, ptr: new([3]int), out: [3]int{1, 2, 3}}, + {in: `[1, 2, 3]`, ptr: new([1]int), out: [1]int{1}}, + {in: `[1, 2, 3]`, ptr: new([5]int), out: [5]int{1, 2, 3, 0, 0}}, + + // empty array to interface test + {in: `[]`, ptr: new([]interface{}), out: []interface{}{}}, + {in: `null`, ptr: new([]interface{}), out: []interface{}(nil)}, + {in: `{"T":[]}`, ptr: new(map[string]interface{}), out: map[string]interface{}{"T": []interface{}{}}}, + {in: `{"T":null}`, ptr: new(map[string]interface{}), out: map[string]interface{}{"T": interface{}(nil)}}, // composite tests - {allValueIndent, new(All), allValue, nil}, - {allValueCompact, new(All), allValue, nil}, - {allValueIndent, new(*All), &allValue, nil}, - {allValueCompact, new(*All), &allValue, nil}, - {pallValueIndent, new(All), pallValue, nil}, - {pallValueCompact, new(All), pallValue, nil}, - {pallValueIndent, new(*All), &pallValue, nil}, - {pallValueCompact, new(*All), &pallValue, nil}, + {in: allValueIndent, ptr: new(All), out: allValue}, + {in: allValueCompact, ptr: new(All), out: allValue}, + {in: allValueIndent, ptr: new(*All), out: &allValue}, + {in: allValueCompact, ptr: new(*All), out: &allValue}, + {in: pallValueIndent, ptr: new(All), out: pallValue}, + {in: pallValueCompact, ptr: new(All), out: pallValue}, + {in: pallValueIndent, ptr: new(*All), out: &pallValue}, + {in: pallValueCompact, ptr: new(*All), out: &pallValue}, // unmarshal interface test - {`{"T":false}`, &um0, umtrue, nil}, // use "false" so test will fail if custom unmarshaler is not called - {`{"T":false}`, &ump, &umtrue, nil}, - {`[{"T":false}]`, &umslice, umslice, nil}, - {`[{"T":false}]`, &umslicep, &umslice, nil}, - {`{"M":{"T":false}}`, &umstruct, umstruct, nil}, + {in: `{"T":false}`, ptr: &um0, out: umtrue}, // use "false" so test will fail if custom unmarshaler is not called + {in: `{"T":false}`, ptr: &ump, out: &umtrue}, + {in: `[{"T":false}]`, ptr: &umslice, out: umslice}, + {in: `[{"T":false}]`, ptr: &umslicep, out: &umslice}, + {in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct}, + + { + in: `{ + "Level0": 1, + "Level1b": 2, + "Level1c": 3, + "x": 4, + "Level1a": 5, + "LEVEL1B": 6, + "e": { + "Level1a": 8, + "Level1b": 9, + "Level1c": 10, + "Level1d": 11, + "x": 12 + }, + "Loop1": 13, + "Loop2": 14, + "X": 15, + "Y": 16, + "Z": 17 + }`, + ptr: new(Top), + out: Top{ + Level0: 1, + Embed0: Embed0{ + Level1b: 2, + Level1c: 3, + }, + Embed0a: &Embed0a{ + Level1a: 5, + Level1b: 6, + }, + Embed0b: &Embed0b{ + Level1a: 8, + Level1b: 9, + Level1c: 10, + Level1d: 11, + Level1e: 12, + }, + Loop: Loop{ + Loop1: 13, + Loop2: 14, + }, + Embed0p: Embed0p{ + Point: image.Point{X: 15, Y: 16}, + }, + Embed0q: Embed0q{ + Point: Point{Z: 17}, + }, + }, + }, + { + in: `{"hello": 1}`, + ptr: new(Ambig), + out: Ambig{First: 1}, + }, + + { + in: `{"X": 1,"Y":2}`, + ptr: new(S5), + out: S5{S8: S8{S9: S9{Y: 2}}}, + }, + { + in: `{"X": 1,"Y":2}`, + ptr: new(S10), + out: S10{S13: S13{S8: S8{S9: S9{Y: 2}}}}, + }, + + // invalid UTF-8 is coerced to valid UTF-8. + { + in: "\"hello\xffworld\"", + ptr: new(string), + out: "hello\ufffdworld", + }, + { + in: "\"hello\xc2\xc2world\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\xc2\xffworld\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\\ud800world\"", + ptr: new(string), + out: "hello\ufffdworld", + }, + { + in: "\"hello\\ud800\\ud800world\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\\ud800\\ud800world\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\xed\xa0\x80\xed\xb0\x80world\"", + ptr: new(string), + out: "hello\ufffd\ufffd\ufffd\ufffd\ufffd\ufffdworld", + }, } func TestMarshal(t *testing.T) { @@ -135,6 +405,18 @@ func TestMarshalBadUTF8(t *testing.T) { } } +func TestMarshalNumberZeroVal(t *testing.T) { + var n Number + out, err := Marshal(n) + if err != nil { + t.Fatal(err) + } + outStr := string(out) + if outStr != "0" { + t.Fatalf("Invalid zero val for Number: %q", outStr) + } +} + func TestUnmarshal(t *testing.T) { for i, tt := range unmarshalTests { var scan scanner @@ -150,7 +432,11 @@ func TestUnmarshal(t *testing.T) { } // v = new(right-type) v := reflect.New(reflect.TypeOf(tt.ptr).Elem()) - if err := Unmarshal([]byte(in), v.Interface()); !reflect.DeepEqual(err, tt.err) { + dec := NewDecoder(bytes.NewBuffer(in)) + if tt.useNumber { + dec.UseNumber() + } + if err := dec.Decode(v.Interface()); !reflect.DeepEqual(err, tt.err) { t.Errorf("#%d: %v want %v", i, err, tt.err) continue } @@ -162,6 +448,28 @@ func TestUnmarshal(t *testing.T) { println(string(data)) continue } + + // Check round trip. + if tt.err == nil { + enc, err := Marshal(v.Interface()) + if err != nil { + t.Errorf("#%d: error re-marshaling: %v", i, err) + continue + } + vv := reflect.New(reflect.TypeOf(tt.ptr).Elem()) + dec = NewDecoder(bytes.NewBuffer(enc)) + if tt.useNumber { + dec.UseNumber() + } + if err := dec.Decode(vv.Interface()); err != nil { + t.Errorf("#%d: error re-unmarshaling: %v", i, err) + continue + } + if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) { + t.Errorf("#%d: mismatch\nhave: %#+v\nwant: %#+v", i, v.Elem().Interface(), vv.Elem().Interface()) + continue + } + } } } @@ -175,13 +483,45 @@ func TestUnmarshalMarshal(t *testing.T) { if err != nil { t.Fatalf("Marshal: %v", err) } - if bytes.Compare(jsonBig, b) != 0 { + if !bytes.Equal(jsonBig, b) { t.Errorf("Marshal jsonBig") diff(t, b, jsonBig) return } } +var numberTests = []struct { + in string + i int64 + intErr string + f float64 + floatErr string +}{ + {in: "-1.23e1", intErr: "strconv.ParseInt: parsing \"-1.23e1\": invalid syntax", f: -1.23e1}, + {in: "-12", i: -12, f: -12.0}, + {in: "1e1000", intErr: "strconv.ParseInt: parsing \"1e1000\": invalid syntax", floatErr: "strconv.ParseFloat: parsing \"1e1000\": value out of range"}, +} + +// Independent of Decode, basic coverage of the accessors in Number +func TestNumberAccessors(t *testing.T) { + for _, tt := range numberTests { + n := Number(tt.in) + if s := n.String(); s != tt.in { + t.Errorf("Number(%q).String() is %q", tt.in, s) + } + if i, err := n.Int64(); err == nil && tt.intErr == "" && i != tt.i { + t.Errorf("Number(%q).Int64() is %d", tt.in, i) + } else if (err == nil && tt.intErr != "") || (err != nil && err.Error() != tt.intErr) { + t.Errorf("Number(%q).Int64() wanted error %q but got: %v", tt.in, tt.intErr, err) + } + if f, err := n.Float64(); err == nil && tt.floatErr == "" && f != tt.f { + t.Errorf("Number(%q).Float64() is %g", tt.in, f) + } else if (err == nil && tt.floatErr != "") || (err != nil && err.Error() != tt.floatErr) { + t.Errorf("Number(%q).Float64() wanted error %q but got: %v", tt.in, tt.floatErr, err) + } + } +} + func TestLargeByteSlice(t *testing.T) { s0 := make([]byte, 2000) for i := range s0 { @@ -195,7 +535,7 @@ func TestLargeByteSlice(t *testing.T) { if err := Unmarshal(b, &s1); err != nil { t.Fatalf("Unmarshal: %v", err) } - if bytes.Compare(s0, s1) != 0 { + if !bytes.Equal(s0, s1) { t.Errorf("Marshal large byte slice") diff(t, s0, s1) } @@ -610,35 +950,6 @@ func TestRefUnmarshal(t *testing.T) { } } -// Test that anonymous fields are ignored. -// We may assign meaning to them later. -func TestAnonymous(t *testing.T) { - type S struct { - T - N int - } - - data, err := Marshal(new(S)) - if err != nil { - t.Fatalf("Marshal: %v", err) - } - want := `{"N":0}` - if string(data) != want { - t.Fatalf("Marshal = %#q, want %#q", string(data), want) - } - - var s S - if err := Unmarshal([]byte(`{"T": 1, "T": {"Y": 1}, "N": 2}`), &s); err != nil { - t.Fatalf("Unmarshal: %v", err) - } - if s.N != 2 { - t.Fatal("Unmarshal: did not set N") - } - if s.T.Y != 0 { - t.Fatal("Unmarshal: did set T.Y") - } -} - // Test that the empty string doesn't panic decoding when ,string is specified // Issue 3450 func TestEmptyString(t *testing.T) { @@ -703,3 +1014,167 @@ func TestInterfaceSet(t *testing.T) { } } } + +// JSON null values should be ignored for primitives and string values instead of resulting in an error. +// Issue 2540 +func TestUnmarshalNulls(t *testing.T) { + jsonData := []byte(`{ + "Bool" : null, + "Int" : null, + "Int8" : null, + "Int16" : null, + "Int32" : null, + "Int64" : null, + "Uint" : null, + "Uint8" : null, + "Uint16" : null, + "Uint32" : null, + "Uint64" : null, + "Float32" : null, + "Float64" : null, + "String" : null}`) + + nulls := All{ + Bool: true, + Int: 2, + Int8: 3, + Int16: 4, + Int32: 5, + Int64: 6, + Uint: 7, + Uint8: 8, + Uint16: 9, + Uint32: 10, + Uint64: 11, + Float32: 12.1, + Float64: 13.1, + String: "14"} + + err := Unmarshal(jsonData, &nulls) + if err != nil { + t.Errorf("Unmarshal of null values failed: %v", err) + } + if !nulls.Bool || nulls.Int != 2 || nulls.Int8 != 3 || nulls.Int16 != 4 || nulls.Int32 != 5 || nulls.Int64 != 6 || + nulls.Uint != 7 || nulls.Uint8 != 8 || nulls.Uint16 != 9 || nulls.Uint32 != 10 || nulls.Uint64 != 11 || + nulls.Float32 != 12.1 || nulls.Float64 != 13.1 || nulls.String != "14" { + + t.Errorf("Unmarshal of null values affected primitives") + } +} + +func TestStringKind(t *testing.T) { + type stringKind string + type aMap map[stringKind]int + + var m1, m2 map[stringKind]int + m1 = map[stringKind]int{ + "foo": 42, + } + + data, err := Marshal(m1) + if err != nil { + t.Errorf("Unexpected error marshalling: %v", err) + } + + err = Unmarshal(data, &m2) + if err != nil { + t.Errorf("Unexpected error unmarshalling: %v", err) + } + + if !reflect.DeepEqual(m1, m2) { + t.Error("Items should be equal after encoding and then decoding") + } + +} + +var decodeTypeErrorTests = []struct { + dest interface{} + src string +}{ + {new(string), `{"user": "name"}`}, // issue 4628. + {new(error), `{}`}, // issue 4222 + {new(error), `[]`}, + {new(error), `""`}, + {new(error), `123`}, + {new(error), `true`}, +} + +func TestUnmarshalTypeError(t *testing.T) { + for _, item := range decodeTypeErrorTests { + err := Unmarshal([]byte(item.src), item.dest) + if _, ok := err.(*UnmarshalTypeError); !ok { + t.Errorf("expected type error for Unmarshal(%q, type %T): got %T", + item.src, item.dest, err) + } + } +} + +var unmarshalSyntaxTests = []string{ + "tru", + "fals", + "nul", + "123e", + `"hello`, + `[1,2,3`, + `{"key":1`, + `{"key":1,`, +} + +func TestUnmarshalSyntax(t *testing.T) { + var x interface{} + for _, src := range unmarshalSyntaxTests { + err := Unmarshal([]byte(src), &x) + if _, ok := err.(*SyntaxError); !ok { + t.Errorf("expected syntax error for Unmarshal(%q): got %T", src, err) + } + } +} + +// Test handling of unexported fields that should be ignored. +// Issue 4660 +type unexportedFields struct { + Name string + m map[string]interface{} `json:"-"` + m2 map[string]interface{} `json:"abcd"` +} + +func TestUnmarshalUnexported(t *testing.T) { + input := `{"Name": "Bob", "m": {"x": 123}, "m2": {"y": 456}, "abcd": {"z": 789}}` + want := &unexportedFields{Name: "Bob"} + + out := &unexportedFields{} + err := Unmarshal([]byte(input), out) + if err != nil { + t.Errorf("got error %v, expected nil", err) + } + if !reflect.DeepEqual(out, want) { + t.Errorf("got %q, want %q", out, want) + } +} + +// Time3339 is a time.Time which encodes to and from JSON +// as an RFC 3339 time in UTC. +type Time3339 time.Time + +func (t *Time3339) UnmarshalJSON(b []byte) error { + if len(b) < 2 || b[0] != '"' || b[len(b)-1] != '"' { + return fmt.Errorf("types: failed to unmarshal non-string value %q as an RFC 3339 time", b) + } + tm, err := time.Parse(time.RFC3339, string(b[1:len(b)-1])) + if err != nil { + return err + } + *t = Time3339(tm) + return nil +} + +func TestUnmarshalJSONLiteralError(t *testing.T) { + var t3 Time3339 + err := Unmarshal([]byte(`"0000-00-00T00:00:00Z"`), &t3) + if err == nil { + t.Fatalf("expected error; got time %v", time.Time(t3)) + } + if !strings.Contains(err.Error(), "range") { + t.Errorf("got err = %v; want out of range error", err) + } +} diff --git a/src/pkg/encoding/json/encode.go b/src/pkg/encoding/json/encode.go index b6e1cb16e..fb57f1d51 100644 --- a/src/pkg/encoding/json/encode.go +++ b/src/pkg/encoding/json/encode.go @@ -36,7 +36,7 @@ import ( // // Boolean values encode as JSON booleans. // -// Floating point and integer values encode as JSON numbers. +// Floating point, integer, and Number values encode as JSON numbers. // // String values encode as JSON strings, with each invalid UTF-8 sequence // replaced by the encoding of the Unicode replacement character U+FFFD. @@ -55,7 +55,7 @@ import ( // nil pointer or interface value, and any array, slice, map, or string of // length zero. The object's default key string is the struct field name // but can be specified in the struct field's tag value. The "json" key in -// struct field's tag value is the key name, followed by an optional comma +// the struct field's tag value is the key name, followed by an optional comma // and options. Examples: // // // Field is ignored by this package. @@ -75,8 +75,9 @@ import ( // Field int `json:",omitempty"` // // The "string" option signals that a field is stored as JSON inside a -// JSON-encoded string. This extra level of encoding is sometimes -// used when communicating with JavaScript programs: +// JSON-encoded string. It applies only to fields of string, floating point, +// or integer types. This extra level of encoding is sometimes used when +// communicating with JavaScript programs: // // Int64String int64 `json:",string"` // @@ -84,6 +85,16 @@ import ( // only Unicode letters, digits, dollar signs, percent signs, hyphens, // underscores and slashes. // +// Anonymous struct fields are usually marshaled as if their inner exported fields +// were fields in the outer struct, subject to the usual Go visibility rules. +// An anonymous struct field with a name given in its JSON tag is treated as +// having that name instead of as anonymous. +// +// Handling of anonymous struct fields is new in Go 1.1. +// Prior to Go 1.1, anonymous struct fields were ignored. To force ignoring of +// an anonymous struct field in both current and earlier versions, give the field +// a JSON tag of "-". +// // Map values encode as JSON objects. // The map's key type must be string; the object keys are used directly // as map keys. @@ -312,6 +323,14 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) { e.Write(b) } case reflect.String: + if v.Type() == numberType { + numStr := v.String() + if numStr == "" { + numStr = "0" // Number's zero-val + } + e.WriteString(numStr) + break + } if quoted { sb, err := Marshal(v.String()) if err != nil { @@ -325,9 +344,9 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) { case reflect.Struct: e.WriteByte('{') first := true - for _, ef := range encodeFields(v.Type()) { - fieldValue := v.Field(ef.i) - if ef.omitEmpty && isEmptyValue(fieldValue) { + for _, f := range cachedTypeFields(v.Type()) { + fv := fieldByIndex(v, f.index) + if !fv.IsValid() || f.omitEmpty && isEmptyValue(fv) { continue } if first { @@ -335,9 +354,9 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) { } else { e.WriteByte(',') } - e.string(ef.tag) + e.string(f.name) e.WriteByte(':') - e.reflectValueQuoted(fieldValue, ef.quoted) + e.reflectValueQuoted(fv, f.quoted) } e.WriteByte('}') @@ -419,7 +438,7 @@ func isValidTag(s string) bool { } for _, c := range s { switch { - case strings.ContainsRune("!#$%&()*+-./:<=>?@[]^_{|}~", c): + case strings.ContainsRune("!#$%&()*+-./:<=>?@[]^_{|}~ ", c): // Backslash and quote chars are reserved, but // otherwise any punctuation chars are allowed // in a tag name. @@ -432,6 +451,19 @@ func isValidTag(s string) bool { return true } +func fieldByIndex(v reflect.Value, index []int) reflect.Value { + for _, i := range index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return reflect.Value{} + } + v = v.Elem() + } + v = v.Field(i) + } + return v +} + // stringValues is a slice of reflect.Value holding *reflect.StringValue. // It implements the methods to sort by string. type stringValues []reflect.Value @@ -490,67 +522,185 @@ func (e *encodeState) string(s string) (int, error) { return e.Len() - len0, nil } -// encodeField contains information about how to encode a field of a -// struct. -type encodeField struct { - i int // field index in struct - tag string - quoted bool +// A field represents a single field found in a struct. +type field struct { + name string + tag bool + index []int + typ reflect.Type omitEmpty bool + quoted bool } -var ( - typeCacheLock sync.RWMutex - encodeFieldsCache = make(map[reflect.Type][]encodeField) -) +// byName sorts field by name, breaking ties with depth, +// then breaking ties with "name came from json tag", then +// breaking ties with index sequence. +type byName []field -// encodeFields returns a slice of encodeField for a given -// struct type. -func encodeFields(t reflect.Type) []encodeField { - typeCacheLock.RLock() - fs, ok := encodeFieldsCache[t] - typeCacheLock.RUnlock() - if ok { - return fs - } +func (x byName) Len() int { return len(x) } + +func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] } - typeCacheLock.Lock() - defer typeCacheLock.Unlock() - fs, ok = encodeFieldsCache[t] - if ok { - return fs +func (x byName) Less(i, j int) bool { + if x[i].name != x[j].name { + return x[i].name < x[j].name } + if len(x[i].index) != len(x[j].index) { + return len(x[i].index) < len(x[j].index) + } + if x[i].tag != x[j].tag { + return x[i].tag + } + return byIndex(x).Less(i, j) +} - v := reflect.Zero(t) - n := v.NumField() - for i := 0; i < n; i++ { - f := t.Field(i) - if f.PkgPath != "" { - continue +// byIndex sorts field by index sequence. +type byIndex []field + +func (x byIndex) Len() int { return len(x) } + +func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +func (x byIndex) Less(i, j int) bool { + for k, xik := range x[i].index { + if k >= len(x[j].index) { + return false } - if f.Anonymous { - // We want to do a better job with these later, - // so for now pretend they don't exist. - continue + if xik != x[j].index[k] { + return xik < x[j].index[k] } - var ef encodeField - ef.i = i - ef.tag = f.Name + } + return len(x[i].index) < len(x[j].index) +} + +// typeFields returns a list of fields that JSON should recognize for the given type. +// The algorithm is breadth-first search over the set of structs to include - the top struct +// and then any reachable anonymous structs. +func typeFields(t reflect.Type) []field { + // Anonymous fields to explore at the current level and the next. + current := []field{} + next := []field{{typ: t}} + + // Count of queued names for current level and the next. + count := map[reflect.Type]int{} + nextCount := map[reflect.Type]int{} + + // Types already visited at an earlier level. + visited := map[reflect.Type]bool{} + + // Fields found. + var fields []field + + for len(next) > 0 { + current, next = next, current[:0] + count, nextCount = nextCount, map[reflect.Type]int{} - tv := f.Tag.Get("json") - if tv != "" { - if tv == "-" { + for _, f := range current { + if visited[f.typ] { continue } - name, opts := parseTag(tv) - if isValidTag(name) { - ef.tag = name + visited[f.typ] = true + + // Scan f.typ for fields to include. + for i := 0; i < f.typ.NumField(); i++ { + sf := f.typ.Field(i) + if sf.PkgPath != "" { // unexported + continue + } + tag := sf.Tag.Get("json") + if tag == "-" { + continue + } + name, opts := parseTag(tag) + if !isValidTag(name) { + name = "" + } + index := make([]int, len(f.index)+1) + copy(index, f.index) + index[len(f.index)] = i + + ft := sf.Type + if ft.Name() == "" && ft.Kind() == reflect.Ptr { + // Follow pointer. + ft = ft.Elem() + } + + // Record found field and index sequence. + if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { + tagged := name != "" + if name == "" { + name = sf.Name + } + fields = append(fields, field{name, tagged, index, ft, + opts.Contains("omitempty"), opts.Contains("string")}) + if count[f.typ] > 1 { + // If there were multiple instances, add a second, + // so that the annihilation code will see a duplicate. + // It only cares about the distinction between 1 or 2, + // so don't bother generating any more copies. + fields = append(fields, fields[len(fields)-1]) + } + continue + } + + // Record new anonymous struct to explore in next round. + nextCount[ft]++ + if nextCount[ft] == 1 { + next = append(next, field{name: ft.Name(), index: index, typ: ft}) + } } - ef.omitEmpty = opts.Contains("omitempty") - ef.quoted = opts.Contains("string") } - fs = append(fs, ef) } - encodeFieldsCache[t] = fs - return fs + + sort.Sort(byName(fields)) + + // Remove fields with annihilating name collisions + // and also fields shadowed by fields with explicit JSON tags. + name := "" + out := fields[:0] + for _, f := range fields { + if f.name != name { + name = f.name + out = append(out, f) + continue + } + if n := len(out); n > 0 && out[n-1].name == name && (!out[n-1].tag || f.tag) { + out = out[:n-1] + } + } + fields = out + + sort.Sort(byIndex(fields)) + + return fields +} + +var fieldCache struct { + sync.RWMutex + m map[reflect.Type][]field +} + +// cachedTypeFields is like typeFields but uses a cache to avoid repeated work. +func cachedTypeFields(t reflect.Type) []field { + fieldCache.RLock() + f := fieldCache.m[t] + fieldCache.RUnlock() + if f != nil { + return f + } + + // Compute fields without lock. + // Might duplicate effort but won't hold other computations back. + f = typeFields(t) + if f == nil { + f = []field{} + } + + fieldCache.Lock() + if fieldCache.m == nil { + fieldCache.m = map[reflect.Type][]field{} + } + fieldCache.m[t] = f + fieldCache.Unlock() + return f } diff --git a/src/pkg/encoding/json/encode_test.go b/src/pkg/encoding/json/encode_test.go index cb1c77eb5..be74c997c 100644 --- a/src/pkg/encoding/json/encode_test.go +++ b/src/pkg/encoding/json/encode_test.go @@ -186,3 +186,23 @@ func TestMarshalerEscaping(t *testing.T) { t.Errorf("got %q, want %q", got, want) } } + +type IntType int + +type MyStruct struct { + IntType +} + +func TestAnonymousNonstruct(t *testing.T) { + var i IntType = 11 + a := MyStruct{i} + const want = `{"IntType":11}` + + b, err := Marshal(a) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if got := string(b); got != want { + t.Errorf("got %q, want %q", got, want) + } +} diff --git a/src/pkg/encoding/json/scanner_test.go b/src/pkg/encoding/json/scanner_test.go index 14d850865..77d3455d3 100644 --- a/src/pkg/encoding/json/scanner_test.go +++ b/src/pkg/encoding/json/scanner_test.go @@ -92,7 +92,7 @@ func TestCompactBig(t *testing.T) { t.Fatalf("Compact: %v", err) } b := buf.Bytes() - if bytes.Compare(b, jsonBig) != 0 { + if !bytes.Equal(b, jsonBig) { t.Error("Compact(jsonBig) != jsonBig") diff(t, b, jsonBig) return @@ -118,7 +118,7 @@ func TestIndentBig(t *testing.T) { t.Fatalf("Indent2: %v", err) } b1 := buf1.Bytes() - if bytes.Compare(b1, b) != 0 { + if !bytes.Equal(b1, b) { t.Error("Indent(Indent(jsonBig)) != Indent(jsonBig)") diff(t, b1, b) return @@ -130,7 +130,7 @@ func TestIndentBig(t *testing.T) { t.Fatalf("Compact: %v", err) } b1 = buf1.Bytes() - if bytes.Compare(b1, jsonBig) != 0 { + if !bytes.Equal(b1, jsonBig) { t.Error("Compact(Indent(jsonBig)) != jsonBig") diff(t, b1, jsonBig) return @@ -277,9 +277,6 @@ func genArray(n int) []interface{} { if f > n { f = n } - if n > 0 && f == 0 { - f = 1 - } x := make([]interface{}, f) for i := range x { x[i] = genValue(((i+1)*n)/f - (i*n)/f) diff --git a/src/pkg/encoding/json/stream.go b/src/pkg/encoding/json/stream.go index 7d1cc5f11..00f4726cf 100644 --- a/src/pkg/encoding/json/stream.go +++ b/src/pkg/encoding/json/stream.go @@ -5,6 +5,7 @@ package json import ( + "bytes" "errors" "io" ) @@ -26,6 +27,10 @@ func NewDecoder(r io.Reader) *Decoder { return &Decoder{r: r} } +// UseNumber causes the Decoder to unmarshal a number into an interface{} as a +// Number instead of as a float64. +func (dec *Decoder) UseNumber() { dec.d.useNumber = true } + // Decode reads the next JSON-encoded value from its // input and stores it in the value pointed to by v. // @@ -54,6 +59,12 @@ func (dec *Decoder) Decode(v interface{}) error { return err } +// Buffered returns a reader of the data remaining in the Decoder's +// buffer. The reader is valid until the next call to Decode. +func (dec *Decoder) Buffered() io.Reader { + return bytes.NewReader(dec.buf) +} + // readValue reads a JSON value into dec.buf. // It returns the length of the encoding. func (dec *Decoder) readValue() (int, error) { @@ -74,7 +85,7 @@ Input: // scanEnd is delayed one byte. // We might block trying to get that byte from src, // so instead invent a space byte. - if v == scanEndObject && dec.scan.step(&dec.scan, ' ') == scanEnd { + if (v == scanEndObject || v == scanEndArray) && dec.scan.step(&dec.scan, ' ') == scanEnd { scanp += i + 1 break Input } diff --git a/src/pkg/encoding/json/stream_test.go b/src/pkg/encoding/json/stream_test.go index ce5a7e6d6..07c9e1d39 100644 --- a/src/pkg/encoding/json/stream_test.go +++ b/src/pkg/encoding/json/stream_test.go @@ -6,7 +6,10 @@ package json import ( "bytes" + "io/ioutil" + "net" "reflect" + "strings" "testing" ) @@ -82,6 +85,28 @@ func TestDecoder(t *testing.T) { } } +func TestDecoderBuffered(t *testing.T) { + r := strings.NewReader(`{"Name": "Gopher"} extra `) + var m struct { + Name string + } + d := NewDecoder(r) + err := d.Decode(&m) + if err != nil { + t.Fatal(err) + } + if m.Name != "Gopher" { + t.Errorf("Name = %q; want Gopher", m.Name) + } + rest, err := ioutil.ReadAll(d.Buffered()) + if err != nil { + t.Fatal(err) + } + if g, w := string(rest), " extra "; g != w { + t.Errorf("Remaining = %q; want %q", g, w) + } +} + func nlines(s string, n int) string { if n <= 0 { return "" @@ -145,3 +170,24 @@ func TestNullRawMessage(t *testing.T) { t.Fatalf("Marshal: have %#q want %#q", b, msg) } } + +var blockingTests = []string{ + `{"x": 1}`, + `[1, 2, 3]`, +} + +func TestBlocking(t *testing.T) { + for _, enc := range blockingTests { + r, w := net.Pipe() + go w.Write([]byte(enc)) + var val interface{} + + // If Decode reads beyond what w.Write writes above, + // it will block, and the test will deadlock. + if err := NewDecoder(r).Decode(&val); err != nil { + t.Errorf("decoding %s: %v", enc, err) + } + r.Close() + w.Close() + } +} diff --git a/src/pkg/encoding/json/tagkey_test.go b/src/pkg/encoding/json/tagkey_test.go index da8b12bd8..23e71c752 100644 --- a/src/pkg/encoding/json/tagkey_test.go +++ b/src/pkg/encoding/json/tagkey_test.go @@ -60,6 +60,14 @@ type badCodeTag struct { Z string `json:" !\"#&'()*+,."` } +type spaceTag struct { + Q string `json:"With space"` +} + +type unicodeTag struct { + W string `json:"Ελλάδα"` +} + var structTagObjectKeyTests = []struct { raw interface{} value string @@ -78,6 +86,8 @@ var structTagObjectKeyTests = []struct { {badCodeTag{"Reliable Man"}, "Reliable Man", "Z"}, {percentSlashTag{"brut"}, "brut", "text/html%"}, {punctuationTag{"Union Rags"}, "Union Rags", "!#$%&()*+-./:<=>?@[]^_{|}~"}, + {spaceTag{"Perreddu"}, "Perreddu", "With space"}, + {unicodeTag{"Loukanikos"}, "Loukanikos", "Ελλάδα"}, } func TestStructTagObjectKey(t *testing.T) { diff --git a/src/pkg/encoding/pem/pem.go b/src/pkg/encoding/pem/pem.go index 3c1f5ab70..8ff7ee8c3 100644 --- a/src/pkg/encoding/pem/pem.go +++ b/src/pkg/encoding/pem/pem.go @@ -11,6 +11,7 @@ import ( "bytes" "encoding/base64" "io" + "sort" ) // A Block represents a PEM encoded structure. @@ -209,26 +210,46 @@ func (l *lineBreaker) Close() (err error) { return } -func Encode(out io.Writer, b *Block) (err error) { - _, err = out.Write(pemStart[1:]) - if err != nil { - return +func writeHeader(out io.Writer, k, v string) error { + _, err := out.Write([]byte(k + ": " + v + "\n")) + return err +} + +func Encode(out io.Writer, b *Block) error { + if _, err := out.Write(pemStart[1:]); err != nil { + return err } - _, err = out.Write([]byte(b.Type + "-----\n")) - if err != nil { - return + if _, err := out.Write([]byte(b.Type + "-----\n")); err != nil { + return err } if len(b.Headers) > 0 { - for k, v := range b.Headers { - _, err = out.Write([]byte(k + ": " + v + "\n")) - if err != nil { - return + const procType = "Proc-Type" + h := make([]string, 0, len(b.Headers)) + hasProcType := false + for k := range b.Headers { + if k == procType { + hasProcType = true + continue } + h = append(h, k) } - _, err = out.Write([]byte{'\n'}) - if err != nil { - return + // The Proc-Type header must be written first. + // See RFC 1421, section 4.6.1.1 + if hasProcType { + if err := writeHeader(out, procType, b.Headers[procType]); err != nil { + return err + } + } + // For consistency of output, write other headers sorted by key. + sort.Strings(h) + for _, k := range h { + if err := writeHeader(out, k, b.Headers[k]); err != nil { + return err + } + } + if _, err := out.Write([]byte{'\n'}); err != nil { + return err } } @@ -236,19 +257,17 @@ func Encode(out io.Writer, b *Block) (err error) { breaker.out = out b64 := base64.NewEncoder(base64.StdEncoding, &breaker) - _, err = b64.Write(b.Bytes) - if err != nil { - return + if _, err := b64.Write(b.Bytes); err != nil { + return err } b64.Close() breaker.Close() - _, err = out.Write(pemEnd[1:]) - if err != nil { - return + if _, err := out.Write(pemEnd[1:]); err != nil { + return err } - _, err = out.Write([]byte(b.Type + "-----\n")) - return + _, err := out.Write([]byte(b.Type + "-----\n")) + return err } func EncodeToMemory(b *Block) []byte { diff --git a/src/pkg/encoding/pem/pem_test.go b/src/pkg/encoding/pem/pem_test.go index 613353483..ccce42cf1 100644 --- a/src/pkg/encoding/pem/pem_test.go +++ b/src/pkg/encoding/pem/pem_test.go @@ -43,7 +43,7 @@ func TestDecode(t *testing.T) { if !reflect.DeepEqual(result, privateKey) { t.Errorf("#1 got:%#v want:%#v", result, privateKey) } - result, _ = Decode([]byte(pemPrivateKey)) + result, _ = Decode([]byte(pemPrivateKey2)) if !reflect.DeepEqual(result, privateKey2) { t.Errorf("#2 got:%#v want:%#v", result, privateKey2) } @@ -51,8 +51,8 @@ func TestDecode(t *testing.T) { func TestEncode(t *testing.T) { r := EncodeToMemory(privateKey2) - if string(r) != pemPrivateKey { - t.Errorf("got:%s want:%s", r, pemPrivateKey) + if string(r) != pemPrivateKey2 { + t.Errorf("got:%s want:%s", r, pemPrivateKey2) } } @@ -341,50 +341,64 @@ var privateKey = &Block{Type: "RSA PRIVATE KEY", }, } -var privateKey2 = &Block{Type: "RSA PRIVATE KEY", - Headers: map[string]string{}, - Bytes: []uint8{0x30, 0x82, 0x1, 0x3a, 0x2, 0x1, 0x0, 0x2, - 0x41, 0x0, 0xb2, 0x99, 0xf, 0x49, 0xc4, 0x7d, 0xfa, 0x8c, - 0xd4, 0x0, 0xae, 0x6a, 0x4d, 0x1b, 0x8a, 0x3b, 0x6a, 0x13, - 0x64, 0x2b, 0x23, 0xf2, 0x8b, 0x0, 0x3b, 0xfb, 0x97, 0x79, - 0xa, 0xde, 0x9a, 0x4c, 0xc8, 0x2b, 0x8b, 0x2a, 0x81, 0x74, - 0x7d, 0xde, 0xc0, 0x8b, 0x62, 0x96, 0xe5, 0x3a, 0x8, 0xc3, - 0x31, 0x68, 0x7e, 0xf2, 0x5c, 0x4b, 0xf4, 0x93, 0x6b, 0xa1, - 0xc0, 0xe6, 0x4, 0x1e, 0x9d, 0x15, 0x2, 0x3, 0x1, 0x0, 0x1, - 0x2, 0x41, 0x0, 0x8a, 0xbd, 0x6a, 0x69, 0xf4, 0xd1, 0xa4, - 0xb4, 0x87, 0xf0, 0xab, 0x8d, 0x7a, 0xae, 0xfd, 0x38, 0x60, - 0x94, 0x5, 0xc9, 0x99, 0x98, 0x4e, 0x30, 0xf5, 0x67, 0xe1, - 0xe8, 0xae, 0xef, 0xf4, 0x4e, 0x8b, 0x18, 0xbd, 0xb1, 0xec, - 0x78, 0xdf, 0xa3, 0x1a, 0x55, 0xe3, 0x2a, 0x48, 0xd7, 0xfb, - 0x13, 0x1f, 0x5a, 0xf1, 0xf4, 0x4d, 0x7d, 0x6b, 0x2c, 0xed, - 0x2a, 0x9d, 0xf5, 0xe5, 0xae, 0x45, 0x35, 0x2, 0x21, 0x0, - 0xda, 0xb2, 0xf1, 0x80, 0x48, 0xba, 0xa6, 0x8d, 0xe7, 0xdf, - 0x4, 0xd2, 0xd3, 0x5d, 0x5d, 0x80, 0xe6, 0xe, 0x2d, 0xfa, - 0x42, 0xd5, 0xa, 0x9b, 0x4, 0x21, 0x90, 0x32, 0x71, 0x5e, - 0x46, 0xb3, 0x2, 0x21, 0x0, 0xd1, 0xf, 0x2e, 0x66, 0xb1, - 0xd0, 0xc1, 0x3f, 0x10, 0xef, 0x99, 0x27, 0xbf, 0x53, 0x24, - 0xa3, 0x79, 0xca, 0x21, 0x81, 0x46, 0xcb, 0xf9, 0xca, 0xfc, - 0x79, 0x52, 0x21, 0xf1, 0x6a, 0x31, 0x17, 0x2, 0x20, 0x21, - 0x2, 0x89, 0x79, 0x37, 0x81, 0x14, 0xca, 0xae, 0x88, 0xf7, - 0xd, 0x6b, 0x61, 0xd8, 0x4f, 0x30, 0x6a, 0x4b, 0x7e, 0x4e, - 0xc0, 0x21, 0x4d, 0xac, 0x9d, 0xf4, 0x49, 0xe8, 0xda, 0xb6, - 0x9, 0x2, 0x20, 0x16, 0xb3, 0xec, 0x59, 0x10, 0xa4, 0x57, - 0xe8, 0xe, 0x61, 0xc6, 0xa3, 0xf, 0x5e, 0xeb, 0x12, 0xa9, - 0xae, 0x2e, 0xb7, 0x48, 0x45, 0xec, 0x69, 0x83, 0xc3, 0x75, - 0xc, 0xe4, 0x97, 0xa0, 0x9f, 0x2, 0x20, 0x69, 0x52, 0xb4, - 0x6, 0xe8, 0x50, 0x60, 0x71, 0x4c, 0x3a, 0xb7, 0x66, 0xba, - 0xd, 0x8a, 0xc9, 0xb7, 0xd, 0xa3, 0x8, 0x6c, 0xa3, 0xf2, - 0x62, 0xb0, 0x2a, 0x84, 0xaa, 0x2f, 0xd6, 0x1e, 0x55, +var privateKey2 = &Block{ + Type: "RSA PRIVATE KEY", + Headers: map[string]string{ + "Proc-Type": "4,ENCRYPTED", + "DEK-Info": "AES-128-CBC,BFCD243FEDBB40A4AA6DDAA1335473A4", + "Content-Domain": "RFC822", + }, + Bytes: []uint8{ + 0xa8, 0x35, 0xcc, 0x2b, 0xb9, 0xcb, 0x21, 0xab, 0xc0, + 0x9d, 0x76, 0x61, 0x0, 0xf4, 0x81, 0xad, 0x69, 0xd2, + 0xc0, 0x42, 0x41, 0x3b, 0xe4, 0x3c, 0xaf, 0x59, 0x5e, + 0x6d, 0x2a, 0x3c, 0x9c, 0xa1, 0xa4, 0x5e, 0x68, 0x37, + 0xc4, 0x8c, 0x70, 0x1c, 0xa9, 0x18, 0xe6, 0xc2, 0x2b, + 0x8a, 0x91, 0xdc, 0x2d, 0x1f, 0x8, 0x23, 0x39, 0xf1, + 0x4b, 0x8b, 0x1b, 0x2f, 0x46, 0xb, 0xb2, 0x26, 0xba, + 0x4f, 0x40, 0x80, 0x39, 0xc4, 0xb1, 0xcb, 0x3b, 0xb4, + 0x65, 0x3f, 0x1b, 0xb2, 0xf7, 0x8, 0xd2, 0xc6, 0xd5, + 0xa8, 0x9f, 0x23, 0x69, 0xb6, 0x3d, 0xf9, 0xac, 0x1c, + 0xb3, 0x13, 0x87, 0x64, 0x4, 0x37, 0xdb, 0x40, 0xc8, + 0x82, 0xc, 0xd0, 0xf8, 0x21, 0x7c, 0xdc, 0xbd, 0x9, 0x4, + 0x20, 0x16, 0xb0, 0x97, 0xe2, 0x6d, 0x56, 0x1d, 0xe3, + 0xec, 0xf0, 0xfc, 0xe2, 0x56, 0xad, 0xa4, 0x3, 0x70, + 0x6d, 0x63, 0x3c, 0x1, 0xbe, 0x3e, 0x28, 0x38, 0x6f, + 0xc0, 0xe6, 0xfd, 0x85, 0xd1, 0x53, 0xa8, 0x9b, 0xcb, + 0xd4, 0x4, 0xb1, 0x73, 0xb9, 0x73, 0x32, 0xd6, 0x7a, + 0xc6, 0x29, 0x25, 0xa5, 0xda, 0x17, 0x93, 0x7a, 0x10, + 0xe8, 0x41, 0xfb, 0xa5, 0x17, 0x20, 0xf8, 0x4e, 0xe9, + 0xe3, 0x8f, 0x51, 0x20, 0x13, 0xbb, 0xde, 0xb7, 0x93, + 0xae, 0x13, 0x8a, 0xf6, 0x9, 0xf4, 0xa6, 0x41, 0xe0, + 0x2b, 0x51, 0x1a, 0x30, 0x38, 0xd, 0xb1, 0x3b, 0x67, + 0x87, 0x64, 0xf5, 0xca, 0x32, 0x67, 0xd1, 0xc8, 0xa5, + 0x3d, 0x23, 0x72, 0xc4, 0x6, 0xaf, 0x8f, 0x7b, 0x26, + 0xac, 0x3c, 0x75, 0x91, 0xa1, 0x0, 0x13, 0xc6, 0x5c, + 0x49, 0xd5, 0x3c, 0xe7, 0xb2, 0xb2, 0x99, 0xe0, 0xd5, + 0x25, 0xfa, 0xe2, 0x12, 0x80, 0x37, 0x85, 0xcf, 0x92, + 0xca, 0x1b, 0x9f, 0xf3, 0x4e, 0xd8, 0x80, 0xef, 0x3c, + 0xce, 0xcd, 0xf5, 0x90, 0x9e, 0xf9, 0xa7, 0xb2, 0xc, + 0x49, 0x4, 0xf1, 0x9, 0x8f, 0xea, 0x63, 0xd2, 0x70, + 0xbb, 0x86, 0xbf, 0x34, 0xab, 0xb2, 0x3, 0xb1, 0x59, + 0x33, 0x16, 0x17, 0xb0, 0xdb, 0x77, 0x38, 0xf4, 0xb4, + 0x94, 0xb, 0x25, 0x16, 0x7e, 0x22, 0xd4, 0xf9, 0x22, + 0xb9, 0x78, 0xa3, 0x4, 0x84, 0x4, 0xd2, 0xda, 0x84, + 0x2d, 0x63, 0xdd, 0xf8, 0x50, 0x6a, 0xf6, 0xe3, 0xf5, + 0x65, 0x40, 0x7c, 0xa9, }, } -var pemPrivateKey = `-----BEGIN RSA PRIVATE KEY----- -MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0 -fd7Ai2KW5ToIwzFofvJcS/STa6HA5gQenRUCAwEAAQJBAIq9amn00aS0h/CrjXqu -/ThglAXJmZhOMPVn4eiu7/ROixi9sex436MaVeMqSNf7Ex9a8fRNfWss7Sqd9eWu -RTUCIQDasvGASLqmjeffBNLTXV2A5g4t+kLVCpsEIZAycV5GswIhANEPLmax0ME/ -EO+ZJ79TJKN5yiGBRsv5yvx5UiHxajEXAiAhAol5N4EUyq6I9w1rYdhPMGpLfk7A -IU2snfRJ6Nq2CQIgFrPsWRCkV+gOYcajD17rEqmuLrdIRexpg8N1DOSXoJ8CIGlS -tAboUGBxTDq3ZroNism3DaMIbKPyYrAqhKov1h5V +var pemPrivateKey2 = `-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +Content-Domain: RFC822 +DEK-Info: AES-128-CBC,BFCD243FEDBB40A4AA6DDAA1335473A4 + +qDXMK7nLIavAnXZhAPSBrWnSwEJBO+Q8r1lebSo8nKGkXmg3xIxwHKkY5sIripHc +LR8IIznxS4sbL0YLsia6T0CAOcSxyzu0ZT8bsvcI0sbVqJ8jabY9+awcsxOHZAQ3 +20DIggzQ+CF83L0JBCAWsJfibVYd4+zw/OJWraQDcG1jPAG+Pig4b8Dm/YXRU6ib +y9QEsXO5czLWesYpJaXaF5N6EOhB+6UXIPhO6eOPUSATu963k64TivYJ9KZB4CtR +GjA4DbE7Z4dk9coyZ9HIpT0jcsQGr497Jqw8dZGhABPGXEnVPOeyspng1SX64hKA +N4XPksobn/NO2IDvPM7N9ZCe+aeyDEkE8QmP6mPScLuGvzSrsgOxWTMWF7Dbdzj0 +tJQLJRZ+ItT5Irl4owSEBNLahC1j3fhQavbj9WVAfKk= -----END RSA PRIVATE KEY----- ` diff --git a/src/pkg/encoding/xml/example_test.go b/src/pkg/encoding/xml/example_test.go index 97c8c0b0d..becedd583 100644 --- a/src/pkg/encoding/xml/example_test.go +++ b/src/pkg/encoding/xml/example_test.go @@ -50,6 +50,46 @@ func ExampleMarshalIndent() { // </person> } +func ExampleEncoder() { + type Address struct { + City, State string + } + type Person struct { + XMLName xml.Name `xml:"person"` + Id int `xml:"id,attr"` + FirstName string `xml:"name>first"` + LastName string `xml:"name>last"` + Age int `xml:"age"` + Height float32 `xml:"height,omitempty"` + Married bool + Address + Comment string `xml:",comment"` + } + + v := &Person{Id: 13, FirstName: "John", LastName: "Doe", Age: 42} + v.Comment = " Need more details. " + v.Address = Address{"Hanga Roa", "Easter Island"} + + enc := xml.NewEncoder(os.Stdout) + enc.Indent(" ", " ") + if err := enc.Encode(v); err != nil { + fmt.Printf("error: %v\n", err) + } + + // Output: + // <person id="13"> + // <name> + // <first>John</first> + // <last>Doe</last> + // </name> + // <age>42</age> + // <Married>false</Married> + // <City>Hanga Roa</City> + // <State>Easter Island</State> + // <!-- Need more details. --> + // </person> +} + // This example demonstrates unmarshaling an XML excerpt into a value with // some preset fields. Note that the Phone field isn't modified and that // the XML <Company> element is ignored. Also, the Groups field is assigned diff --git a/src/pkg/encoding/xml/marshal.go b/src/pkg/encoding/xml/marshal.go index 6c3170bdd..ea58ce254 100644 --- a/src/pkg/encoding/xml/marshal.go +++ b/src/pkg/encoding/xml/marshal.go @@ -45,7 +45,7 @@ const ( // - a field with tag "name,attr" becomes an attribute with // the given name in the XML element. // - a field with tag ",attr" becomes an attribute with the -// field name in the in the XML element. +// field name in the XML element. // - a field with tag ",chardata" is written as character data, // not as an XML element. // - a field with tag ",innerxml" is written verbatim, not subject @@ -57,8 +57,8 @@ const ( // if the field value is empty. The empty values are false, 0, any // nil pointer or interface value, and any array, slice, map, or // string of length zero. -// - a non-pointer anonymous struct field is handled as if the -// fields of its value were part of the outer struct. +// - an anonymous struct field is handled as if the fields of its +// value were part of the outer struct. // // If a field uses a tag "a>b>c", then the element c will be nested inside // parent elements a and b. Fields that appear next to each other that name @@ -81,11 +81,8 @@ func Marshal(v interface{}) ([]byte, error) { func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) { var b bytes.Buffer enc := NewEncoder(&b) - enc.prefix = prefix - enc.indent = indent - err := enc.marshalValue(reflect.ValueOf(v), nil) - enc.Flush() - if err != nil { + enc.Indent(prefix, indent) + if err := enc.Encode(v); err != nil { return nil, err } return b.Bytes(), nil @@ -101,14 +98,24 @@ func NewEncoder(w io.Writer) *Encoder { return &Encoder{printer{Writer: bufio.NewWriter(w)}} } +// Indent sets the encoder to generate XML in which each element +// begins on a new indented line that starts with prefix and is followed by +// one or more copies of indent according to the nesting depth. +func (enc *Encoder) Indent(prefix, indent string) { + enc.prefix = prefix + enc.indent = indent +} + // Encode writes the XML encoding of v to the stream. // // See the documentation for Marshal for details about the conversion // of Go values to XML. func (enc *Encoder) Encode(v interface{}) error { err := enc.marshalValue(reflect.ValueOf(v), nil) - enc.Flush() - return err + if err != nil { + return err + } + return enc.Flush() } type printer struct { @@ -117,6 +124,7 @@ type printer struct { prefix string depth int indentedIn bool + putNewline bool } // marshalValue writes one or more XML elements representing val. @@ -164,7 +172,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { xmlname := tinfo.xmlname if xmlname.name != "" { xmlns, name = xmlname.xmlns, xmlname.name - } else if v, ok := val.FieldByIndex(xmlname.idx).Interface().(Name); ok && v.Local != "" { + } else if v, ok := xmlname.value(val).Interface().(Name); ok && v.Local != "" { xmlns, name = v.Space, v.Local } } @@ -185,7 +193,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { if xmlns != "" { p.WriteString(` xmlns="`) // TODO: EscapeString, to avoid the allocation. - Escape(p, []byte(xmlns)) + if err := EscapeText(p, []byte(xmlns)); err != nil { + return err + } p.WriteByte('"') } @@ -195,7 +205,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { if finfo.flags&fAttr == 0 { continue } - fv := val.FieldByIndex(finfo.idx) + fv := finfo.value(val) if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) { continue } @@ -224,7 +234,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { p.WriteString(name) p.WriteByte('>') - return nil + return p.cachedWriteError() } var timeType = reflect.TypeOf(time.Time{}) @@ -241,50 +251,70 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error { case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: p.WriteString(strconv.FormatUint(val.Uint(), 10)) case reflect.Float32, reflect.Float64: - p.WriteString(strconv.FormatFloat(val.Float(), 'g', -1, 64)) + p.WriteString(strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits())) case reflect.String: // TODO: Add EscapeString. - Escape(p, []byte(val.String())) + EscapeText(p, []byte(val.String())) case reflect.Bool: p.WriteString(strconv.FormatBool(val.Bool())) case reflect.Array: // will be [...]byte - bytes := make([]byte, val.Len()) - for i := range bytes { - bytes[i] = val.Index(i).Interface().(byte) + var bytes []byte + if val.CanAddr() { + bytes = val.Slice(0, val.Len()).Bytes() + } else { + bytes = make([]byte, val.Len()) + reflect.Copy(reflect.ValueOf(bytes), val) } - Escape(p, bytes) + EscapeText(p, bytes) case reflect.Slice: // will be []byte - Escape(p, val.Bytes()) + EscapeText(p, val.Bytes()) default: return &UnsupportedTypeError{typ} } - return nil + return p.cachedWriteError() } var ddBytes = []byte("--") func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { if val.Type() == timeType { - p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano)) - return nil + _, err := p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano)) + return err } s := parentStack{printer: p} for i := range tinfo.fields { finfo := &tinfo.fields[i] - if finfo.flags&(fAttr|fAny) != 0 { + if finfo.flags&(fAttr) != 0 { continue } - vf := val.FieldByIndex(finfo.idx) + vf := finfo.value(val) switch finfo.flags & fMode { case fCharData: + var scratch [64]byte switch vf.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + Escape(p, strconv.AppendInt(scratch[:0], vf.Int(), 10)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + Escape(p, strconv.AppendUint(scratch[:0], vf.Uint(), 10)) + case reflect.Float32, reflect.Float64: + Escape(p, strconv.AppendFloat(scratch[:0], vf.Float(), 'g', -1, vf.Type().Bits())) + case reflect.Bool: + Escape(p, strconv.AppendBool(scratch[:0], vf.Bool())) case reflect.String: - Escape(p, []byte(vf.String())) + if err := EscapeText(p, []byte(vf.String())); err != nil { + return err + } case reflect.Slice: if elem, ok := vf.Interface().([]byte); ok { - Escape(p, elem) + if err := EscapeText(p, elem); err != nil { + return err + } + } + case reflect.Struct: + if vf.Type() == timeType { + Escape(p, []byte(vf.Interface().(time.Time).Format(time.RFC3339Nano))) } } continue @@ -340,7 +370,7 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { continue } - case fElement: + case fElement, fElement | fAny: s.trim(finfo.parents) if len(finfo.parents) > len(s.stack) { if vf.Kind() != reflect.Ptr && vf.Kind() != reflect.Interface || !vf.IsNil() { @@ -353,7 +383,13 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { } } s.trim(nil) - return nil + return p.cachedWriteError() +} + +// return the bufio Writer's cached write error +func (p *printer) cachedWriteError() error { + _, err := p.Write(nil) + return err } func (p *printer) writeIndent(depthDelta int) { @@ -368,7 +404,11 @@ func (p *printer) writeIndent(depthDelta int) { } p.indentedIn = false } - p.WriteByte('\n') + if p.putNewline { + p.WriteByte('\n') + } else { + p.putNewline = true + } if len(p.prefix) > 0 { p.WriteString(p.prefix) } diff --git a/src/pkg/encoding/xml/marshal_test.go b/src/pkg/encoding/xml/marshal_test.go index b6978a1e6..3a190def6 100644 --- a/src/pkg/encoding/xml/marshal_test.go +++ b/src/pkg/encoding/xml/marshal_test.go @@ -5,6 +5,10 @@ package xml import ( + "bytes" + "errors" + "fmt" + "io" "reflect" "strconv" "strings" @@ -56,6 +60,36 @@ type Book struct { Title string `xml:",chardata"` } +type Event struct { + XMLName struct{} `xml:"event"` + Year int `xml:",chardata"` +} + +type Movie struct { + XMLName struct{} `xml:"movie"` + Length uint `xml:",chardata"` +} + +type Pi struct { + XMLName struct{} `xml:"pi"` + Approximation float32 `xml:",chardata"` +} + +type Universe struct { + XMLName struct{} `xml:"universe"` + Visible float64 `xml:",chardata"` +} + +type Particle struct { + XMLName struct{} `xml:"particle"` + HasMass bool `xml:",chardata"` +} + +type Departure struct { + XMLName struct{} `xml:"departure"` + When time.Time `xml:",chardata"` +} + type SecretAgent struct { XMLName struct{} `xml:"agent"` Handle string `xml:"handle,attr"` @@ -108,7 +142,7 @@ type EmbedA struct { type EmbedB struct { FieldB string - EmbedC + *EmbedC } type EmbedC struct { @@ -185,6 +219,18 @@ type AnyTest struct { AnyField AnyHolder `xml:",any"` } +type AnyOmitTest struct { + XMLName struct{} `xml:"a"` + Nested string `xml:"nested>value"` + AnyField *AnyHolder `xml:",any,omitempty"` +} + +type AnySliceTest struct { + XMLName struct{} `xml:"a"` + Nested string `xml:"nested>value"` + AnyField []AnyHolder `xml:",any"` +} + type AnyHolder struct { XMLName Name XML string `xml:",innerxml"` @@ -330,6 +376,12 @@ var marshalTests = []struct { {Value: &Domain{Name: []byte("google.com&friends")}, ExpectXML: `<domain>google.com&friends</domain>`}, {Value: &Domain{Name: []byte("google.com"), Comment: []byte(" &friends ")}, ExpectXML: `<domain>google.com<!-- &friends --></domain>`}, {Value: &Book{Title: "Pride & Prejudice"}, ExpectXML: `<book>Pride & Prejudice</book>`}, + {Value: &Event{Year: -3114}, ExpectXML: `<event>-3114</event>`}, + {Value: &Movie{Length: 13440}, ExpectXML: `<movie>13440</movie>`}, + {Value: &Pi{Approximation: 3.14159265}, ExpectXML: `<pi>3.1415927</pi>`}, + {Value: &Universe{Visible: 9.3e13}, ExpectXML: `<universe>9.3e+13</universe>`}, + {Value: &Particle{HasMass: true}, ExpectXML: `<particle>true</particle>`}, + {Value: &Departure{When: ParseTime("2013-01-09T00:15:00-09:00")}, ExpectXML: `<departure>2013-01-09T00:15:00-09:00</departure>`}, {Value: atomValue, ExpectXML: atomXml}, { Value: &Ship{ @@ -493,7 +545,7 @@ var marshalTests = []struct { }, EmbedB: EmbedB{ FieldB: "A.B.B", - EmbedC: EmbedC{ + EmbedC: &EmbedC{ FieldA1: "A.B.C.A1", FieldA2: "A.B.C.A2", FieldB: "", // Shadowed by A.B.B @@ -649,12 +701,43 @@ var marshalTests = []struct { XML: "<sub>unknown</sub>", }, }, - UnmarshalOnly: true, }, { - Value: &AnyTest{Nested: "known", AnyField: AnyHolder{XML: "<unknown/>"}}, - ExpectXML: `<a><nested><value>known</value></nested></a>`, - MarshalOnly: true, + Value: &AnyTest{Nested: "known", + AnyField: AnyHolder{ + XML: "<unknown/>", + XMLName: Name{Local: "AnyField"}, + }, + }, + ExpectXML: `<a><nested><value>known</value></nested><AnyField><unknown/></AnyField></a>`, + }, + { + ExpectXML: `<a><nested><value>b</value></nested></a>`, + Value: &AnyOmitTest{ + Nested: "b", + }, + }, + { + ExpectXML: `<a><nested><value>b</value></nested><c><d>e</d></c><g xmlns="f"><h>i</h></g></a>`, + Value: &AnySliceTest{ + Nested: "b", + AnyField: []AnyHolder{ + { + XMLName: Name{Local: "c"}, + XML: "<d>e</d>", + }, + { + XMLName: Name{Space: "f", Local: "g"}, + XML: "<h>i</h>", + }, + }, + }, + }, + { + ExpectXML: `<a><nested><value>b</value></nested></a>`, + Value: &AnySliceTest{ + Nested: "b", + }, }, // Test recursive types. @@ -684,6 +767,29 @@ var marshalTests = []struct { Value: &IgnoreTest{}, UnmarshalOnly: true, }, + + // Test escaping. + { + ExpectXML: `<a><nested><value>dquote: "; squote: '; ampersand: &; less: <; greater: >;</value></nested><empty></empty></a>`, + Value: &AnyTest{ + Nested: `dquote: "; squote: '; ampersand: &; less: <; greater: >;`, + AnyField: AnyHolder{XMLName: Name{Local: "empty"}}, + }, + }, + { + ExpectXML: `<a><nested><value>newline: 
; cr: 
; tab: 	;</value></nested><AnyField></AnyField></a>`, + Value: &AnyTest{ + Nested: "newline: \n; cr: \r; tab: \t;", + AnyField: AnyHolder{XMLName: Name{Local: "AnyField"}}, + }, + }, + { + ExpectXML: "<a><nested><value>1\r2\r\n3\n\r4\n5</value></nested></a>", + Value: &AnyTest{ + Nested: "1\n2\n3\n\n4\n5", + }, + UnmarshalOnly: true, + }, } func TestMarshal(t *testing.T) { @@ -735,6 +841,24 @@ var marshalErrorTests = []struct { }, } +var marshalIndentTests = []struct { + Value interface{} + Prefix string + Indent string + ExpectXML string +}{ + { + Value: &SecretAgent{ + Handle: "007", + Identity: "James Bond", + Obfuscate: "<redacted/>", + }, + Prefix: "", + Indent: "\t", + ExpectXML: fmt.Sprintf("<agent handle=\"007\">\n\t<Identity>James Bond</Identity><redacted/>\n</agent>"), + }, +} + func TestMarshalErrors(t *testing.T) { for idx, test := range marshalErrorTests { _, err := Marshal(test.Value) @@ -779,6 +903,78 @@ func TestUnmarshal(t *testing.T) { } } +func TestMarshalIndent(t *testing.T) { + for i, test := range marshalIndentTests { + data, err := MarshalIndent(test.Value, test.Prefix, test.Indent) + if err != nil { + t.Errorf("#%d: Error: %s", i, err) + continue + } + if got, want := string(data), test.ExpectXML; got != want { + t.Errorf("#%d: MarshalIndent:\nGot:%s\nWant:\n%s", i, got, want) + } + } +} + +type limitedBytesWriter struct { + w io.Writer + remain int // until writes fail +} + +func (lw *limitedBytesWriter) Write(p []byte) (n int, err error) { + if lw.remain <= 0 { + println("error") + return 0, errors.New("write limit hit") + } + if len(p) > lw.remain { + p = p[:lw.remain] + n, _ = lw.w.Write(p) + lw.remain = 0 + return n, errors.New("write limit hit") + } + n, err = lw.w.Write(p) + lw.remain -= n + return n, err +} + +func TestMarshalWriteErrors(t *testing.T) { + var buf bytes.Buffer + const writeCap = 1024 + w := &limitedBytesWriter{&buf, writeCap} + enc := NewEncoder(w) + var err error + var i int + const n = 4000 + for i = 1; i <= n; i++ { + err = enc.Encode(&Passenger{ + Name: []string{"Alice", "Bob"}, + Weight: 5, + }) + if err != nil { + break + } + } + if err == nil { + t.Error("expected an error") + } + if i == n { + t.Errorf("expected to fail before the end") + } + if buf.Len() != writeCap { + t.Errorf("buf.Len() = %d; want %d", buf.Len(), writeCap) + } +} + +func TestMarshalWriteIOErrors(t *testing.T) { + enc := NewEncoder(errWriter{}) + + expectErr := "unwritable" + err := enc.Encode(&Passenger{}) + if err == nil || err.Error() != expectErr { + t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr) + } +} + func BenchmarkMarshal(b *testing.B) { for i := 0; i < b.N; i++ { Marshal(atomValue) diff --git a/src/pkg/encoding/xml/read.go b/src/pkg/encoding/xml/read.go index c21682420..344ab514e 100644 --- a/src/pkg/encoding/xml/read.go +++ b/src/pkg/encoding/xml/read.go @@ -81,8 +81,8 @@ import ( // of the above rules and the struct has a field with tag ",any", // unmarshal maps the sub-element to that struct field. // -// * A non-pointer anonymous struct field is handled as if the -// fields of its value were part of the outer struct. +// * An anonymous struct field is handled as if the fields of its +// value were part of the outer struct. // // * A struct field with tag "-" is never unmarshalled into. // @@ -248,7 +248,7 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { } return UnmarshalError(e) } - fv := sv.FieldByIndex(finfo.idx) + fv := finfo.value(sv) if _, ok := fv.Interface().(Name); ok { fv.Set(reflect.ValueOf(start.Name)) } @@ -260,7 +260,7 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { finfo := &tinfo.fields[i] switch finfo.flags & fMode { case fAttr: - strv := sv.FieldByIndex(finfo.idx) + strv := finfo.value(sv) // Look for attribute. for _, a := range start.Attr { if a.Name.Local == finfo.name { @@ -271,22 +271,22 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { case fCharData: if !saveData.IsValid() { - saveData = sv.FieldByIndex(finfo.idx) + saveData = finfo.value(sv) } case fComment: if !saveComment.IsValid() { - saveComment = sv.FieldByIndex(finfo.idx) + saveComment = finfo.value(sv) } - case fAny: + case fAny, fAny | fElement: if !saveAny.IsValid() { - saveAny = sv.FieldByIndex(finfo.idx) + saveAny = finfo.value(sv) } case fInnerXml: if !saveXML.IsValid() { - saveXML = sv.FieldByIndex(finfo.idx) + saveXML = finfo.value(sv) if p.saved == nil { saveXMLIndex = 0 p.saved = new(bytes.Buffer) @@ -374,68 +374,58 @@ Loop: } func copyValue(dst reflect.Value, src []byte) (err error) { - // Helper functions for integer and unsigned integer conversions - var itmp int64 - getInt64 := func() bool { - itmp, err = strconv.ParseInt(string(src), 10, 64) - // TODO: should check sizes - return err == nil - } - var utmp uint64 - getUint64 := func() bool { - utmp, err = strconv.ParseUint(string(src), 10, 64) - // TODO: check for overflow? - return err == nil - } - var ftmp float64 - getFloat64 := func() bool { - ftmp, err = strconv.ParseFloat(string(src), 64) - // TODO: check for overflow? - return err == nil + if dst.Kind() == reflect.Ptr { + if dst.IsNil() { + dst.Set(reflect.New(dst.Type().Elem())) + } + dst = dst.Elem() } // Save accumulated data. - switch t := dst; t.Kind() { + switch dst.Kind() { case reflect.Invalid: - // Probably a comment. + // Probably a commendst. default: - return errors.New("cannot happen: unknown type " + t.Type().String()) + return errors.New("cannot happen: unknown type " + dst.Type().String()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if !getInt64() { + itmp, err := strconv.ParseInt(string(src), 10, dst.Type().Bits()) + if err != nil { return err } - t.SetInt(itmp) + dst.SetInt(itmp) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - if !getUint64() { + utmp, err := strconv.ParseUint(string(src), 10, dst.Type().Bits()) + if err != nil { return err } - t.SetUint(utmp) + dst.SetUint(utmp) case reflect.Float32, reflect.Float64: - if !getFloat64() { + ftmp, err := strconv.ParseFloat(string(src), dst.Type().Bits()) + if err != nil { return err } - t.SetFloat(ftmp) + dst.SetFloat(ftmp) case reflect.Bool: value, err := strconv.ParseBool(strings.TrimSpace(string(src))) if err != nil { return err } - t.SetBool(value) + dst.SetBool(value) case reflect.String: - t.SetString(string(src)) + dst.SetString(string(src)) case reflect.Slice: if len(src) == 0 { // non-nil to flag presence src = []byte{} } - t.SetBytes(src) + dst.SetBytes(src) case reflect.Struct: - if t.Type() == timeType { + if dst.Type() == timeType { tv, err := time.Parse(time.RFC3339, string(src)) if err != nil { return err } - t.Set(reflect.ValueOf(tv)) + dst.Set(reflect.ValueOf(tv)) } } return nil @@ -461,7 +451,7 @@ Loop: } if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local { // It's a perfect match, unmarshal the field. - return true, p.unmarshal(sv.FieldByIndex(finfo.idx), start) + return true, p.unmarshal(finfo.value(sv), start) } if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local { // It's a prefix for the field. Break and recurse diff --git a/src/pkg/encoding/xml/read_test.go b/src/pkg/encoding/xml/read_test.go index 8df09b3cc..b45e2f0e6 100644 --- a/src/pkg/encoding/xml/read_test.go +++ b/src/pkg/encoding/xml/read_test.go @@ -355,3 +355,47 @@ func TestUnmarshalWithoutNameType(t *testing.T) { t.Fatalf("have %v\nwant %v", x.Attr, OK) } } + +func TestUnmarshalAttr(t *testing.T) { + type ParamVal struct { + Int int `xml:"int,attr"` + } + + type ParamPtr struct { + Int *int `xml:"int,attr"` + } + + type ParamStringPtr struct { + Int *string `xml:"int,attr"` + } + + x := []byte(`<Param int="1" />`) + + p1 := &ParamPtr{} + if err := Unmarshal(x, p1); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if p1.Int == nil { + t.Fatalf("Unmarshal failed in to *int field") + } else if *p1.Int != 1 { + t.Fatalf("Unmarshal with %s failed:\nhave %#v,\n want %#v", x, p1.Int, 1) + } + + p2 := &ParamVal{} + if err := Unmarshal(x, p2); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if p2.Int != 1 { + t.Fatalf("Unmarshal with %s failed:\nhave %#v,\n want %#v", x, p2.Int, 1) + } + + p3 := &ParamStringPtr{} + if err := Unmarshal(x, p3); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if p3.Int == nil { + t.Fatalf("Unmarshal failed in to *string field") + } else if *p3.Int != "1" { + t.Fatalf("Unmarshal with %s failed:\nhave %#v,\n want %#v", x, p3.Int, 1) + } +} diff --git a/src/pkg/encoding/xml/typeinfo.go b/src/pkg/encoding/xml/typeinfo.go index 8e2e4508b..bbeb28d87 100644 --- a/src/pkg/encoding/xml/typeinfo.go +++ b/src/pkg/encoding/xml/typeinfo.go @@ -66,10 +66,14 @@ func getTypeInfo(typ reflect.Type) (*typeInfo, error) { // For embedded structs, embed its fields. if f.Anonymous { - if f.Type.Kind() != reflect.Struct { + t := f.Type + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { continue } - inner, err := getTypeInfo(f.Type) + inner, err := getTypeInfo(t) if err != nil { return nil, err } @@ -150,6 +154,9 @@ func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, erro // This will also catch multiple modes in a single field. valid = false } + if finfo.flags&fMode == fAny { + finfo.flags |= fElement + } if finfo.flags&fOmitEmpty != 0 && finfo.flags&(fElement|fAttr) == 0 { valid = false } @@ -327,3 +334,22 @@ type TagPathError struct { func (e *TagPathError) Error() string { return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2) } + +// value returns v's field value corresponding to finfo. +// It's equivalent to v.FieldByIndex(finfo.idx), but initializes +// and dereferences pointers as necessary. +func (finfo *fieldInfo) value(v reflect.Value) reflect.Value { + for i, x := range finfo.idx { + if i > 0 { + t := v.Type() + if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + } + v = v.Field(x) + } + return v +} diff --git a/src/pkg/encoding/xml/xml.go b/src/pkg/encoding/xml/xml.go index 5066f5c01..143fec554 100644 --- a/src/pkg/encoding/xml/xml.go +++ b/src/pkg/encoding/xml/xml.go @@ -181,7 +181,6 @@ type Decoder struct { ns map[string]string err error line int - tmp [32]byte } // NewDecoder creates a new XML parser reading from r. @@ -584,6 +583,7 @@ func (d *Decoder) RawToken() (Token, error) { if inquote == 0 && b == '>' && depth == 0 { break } + HandleB: d.buf.WriteByte(b) switch { case b == inquote: @@ -599,7 +599,35 @@ func (d *Decoder) RawToken() (Token, error) { depth-- case b == '<' && inquote == 0: - depth++ + // Look for <!-- to begin comment. + s := "!--" + for i := 0; i < len(s); i++ { + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if b != s[i] { + for j := 0; j < i; j++ { + d.buf.WriteByte(s[j]) + } + depth++ + goto HandleB + } + } + + // Remove < that was written above. + d.buf.Truncate(d.buf.Len() - 1) + + // Look for terminator. + var b0, b1 byte + for { + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if b0 == '-' && b1 == '-' && b == '>' { + break + } + b0, b1 = b1, b + } } } return Directive(d.buf.Bytes()), nil @@ -848,78 +876,103 @@ Input: // XML in all its glory allows a document to define and use // its own character names with <!ENTITY ...> directives. // Parsers are required to recognize lt, gt, amp, apos, and quot - // even if they have not been declared. That's all we allow. - var i int - for i = 0; i < len(d.tmp); i++ { - var ok bool - d.tmp[i], ok = d.getc() - if !ok { - if d.err == io.EOF { - d.err = d.syntaxError("unexpected EOF") - } + // even if they have not been declared. + before := d.buf.Len() + d.buf.WriteByte('&') + var ok bool + var text string + var haveText bool + if b, ok = d.mustgetc(); !ok { + return nil + } + if b == '#' { + d.buf.WriteByte(b) + if b, ok = d.mustgetc(); !ok { return nil } - c := d.tmp[i] - if c == ';' { - break - } - if 'a' <= c && c <= 'z' || - 'A' <= c && c <= 'Z' || - '0' <= c && c <= '9' || - c == '_' || c == '#' { - continue + base := 10 + if b == 'x' { + base = 16 + d.buf.WriteByte(b) + if b, ok = d.mustgetc(); !ok { + return nil + } } - d.ungetc(c) - break - } - s := string(d.tmp[0:i]) - if i >= len(d.tmp) { - if !d.Strict { - b0, b1 = 0, 0 - d.buf.WriteByte('&') - d.buf.Write(d.tmp[0:i]) - continue Input + start := d.buf.Len() + for '0' <= b && b <= '9' || + base == 16 && 'a' <= b && b <= 'f' || + base == 16 && 'A' <= b && b <= 'F' { + d.buf.WriteByte(b) + if b, ok = d.mustgetc(); !ok { + return nil + } } - d.err = d.syntaxError("character entity expression &" + s + "... too long") - return nil - } - var haveText bool - var text string - if i >= 2 && s[0] == '#' { - var n uint64 - var err error - if i >= 3 && s[1] == 'x' { - n, err = strconv.ParseUint(s[2:], 16, 64) + if b != ';' { + d.ungetc(b) } else { - n, err = strconv.ParseUint(s[1:], 10, 64) - } - if err == nil && n <= unicode.MaxRune { - text = string(n) - haveText = true + s := string(d.buf.Bytes()[start:]) + d.buf.WriteByte(';') + n, err := strconv.ParseUint(s, base, 64) + if err == nil && n <= unicode.MaxRune { + text = string(n) + haveText = true + } } } else { - if r, ok := entity[s]; ok { - text = string(r) - haveText = true - } else if d.Entity != nil { - text, haveText = d.Entity[s] + d.ungetc(b) + if !d.readName() { + if d.err != nil { + return nil + } + ok = false } - } - if !haveText { - if !d.Strict { - b0, b1 = 0, 0 - d.buf.WriteByte('&') - d.buf.Write(d.tmp[0:i]) - continue Input + if b, ok = d.mustgetc(); !ok { + return nil } - d.err = d.syntaxError("invalid character entity &" + s + ";") - return nil + if b != ';' { + d.ungetc(b) + } else { + name := d.buf.Bytes()[before+1:] + d.buf.WriteByte(';') + if isName(name) { + s := string(name) + if r, ok := entity[s]; ok { + text = string(r) + haveText = true + } else if d.Entity != nil { + text, haveText = d.Entity[s] + } + } + } + } + + if haveText { + d.buf.Truncate(before) + d.buf.Write([]byte(text)) + b0, b1 = 0, 0 + continue Input } - d.buf.Write([]byte(text)) - b0, b1 = 0, 0 - continue Input + if !d.Strict { + b0, b1 = 0, 0 + continue Input + } + ent := string(d.buf.Bytes()[before]) + if ent[len(ent)-1] != ';' { + ent += " (no semicolon)" + } + d.err = d.syntaxError("invalid character entity " + ent) + return nil } - d.buf.WriteByte(b) + + // We must rewrite unescaped \r and \r\n into \n. + if b == '\r' { + d.buf.WriteByte('\n') + } else if b1 == '\r' && b == '\n' { + // Skip \r\n--we already wrote \n. + } else { + d.buf.WriteByte(b) + } + b0, b1 = b1, b } data := d.buf.Bytes() @@ -940,20 +993,7 @@ Input: } } - // Must rewrite \r and \r\n into \n. - w := 0 - for r := 0; r < len(data); r++ { - b := data[r] - if b == '\r' { - if r+1 < len(data) && data[r+1] == '\n' { - continue - } - b = '\n' - } - data[w] = b - w++ - } - return data[0:w] + return data } // Decide whether the given rune is in the XML Character Range, per @@ -989,18 +1029,34 @@ func (d *Decoder) nsname() (name Name, ok bool) { // Do not set d.err if the name is missing (unless unexpected EOF is received): // let the caller provide better context. func (d *Decoder) name() (s string, ok bool) { + d.buf.Reset() + if !d.readName() { + return "", false + } + + // Now we check the characters. + s = d.buf.String() + if !isName([]byte(s)) { + d.err = d.syntaxError("invalid XML name: " + s) + return "", false + } + return s, true +} + +// Read a name and append its bytes to d.buf. +// The name is delimited by any single-byte character not valid in names. +// All multi-byte characters are accepted; the caller must check their validity. +func (d *Decoder) readName() (ok bool) { var b byte if b, ok = d.mustgetc(); !ok { return } - - // As a first approximation, we gather the bytes [A-Za-z_:.-\x80-\xFF]* if b < utf8.RuneSelf && !isNameByte(b) { d.ungetc(b) - return "", false + return false } - d.buf.Reset() d.buf.WriteByte(b) + for { if b, ok = d.mustgetc(); !ok { return @@ -1011,16 +1067,7 @@ func (d *Decoder) name() (s string, ok bool) { } d.buf.WriteByte(b) } - - // Then we check the characters. - s = d.buf.String() - for i, c := range s { - if !unicode.Is(first, c) && (i == 0 || !unicode.Is(second, c)) { - d.err = d.syntaxError("invalid XML name: " + s) - return "", false - } - } - return s, true + return true } func isNameByte(c byte) bool { @@ -1030,6 +1077,30 @@ func isNameByte(c byte) bool { c == '_' || c == ':' || c == '.' || c == '-' } +func isName(s []byte) bool { + if len(s) == 0 { + return false + } + c, n := utf8.DecodeRune(s) + if c == utf8.RuneError && n == 1 { + return false + } + if !unicode.Is(first, c) { + return false + } + for n < len(s) { + s = s[n:] + c, n = utf8.DecodeRune(s) + if c == utf8.RuneError && n == 1 { + return false + } + if !unicode.Is(first, c) && !unicode.Is(second, c) { + return false + } + } + return true +} + // These tables were generated by cut and paste from Appendix B of // the XML spec at http://www.xml.com/axml/testaxml.htm // and then reformatting. First corresponds to (Letter | '_' | ':') @@ -1621,7 +1692,7 @@ var HTMLAutoClose = htmlAutoClose var htmlAutoClose = []string{ /* hget http://www.w3.org/TR/html4/loose.dtd | - 9 sed -n 's/<!ELEMENT (.*) - O EMPTY.+/ "\1",/p' | tr A-Z a-z + 9 sed -n 's/<!ELEMENT ([^ ]*) +- O EMPTY.+/ "\1",/p' | tr A-Z a-z */ "basefont", "br", @@ -1631,7 +1702,7 @@ var htmlAutoClose = []string{ "param", "hr", "input", - "col ", + "col", "frame", "isindex", "base", @@ -1644,11 +1715,14 @@ var ( esc_amp = []byte("&") esc_lt = []byte("<") esc_gt = []byte(">") + esc_tab = []byte("	") + esc_nl = []byte("
") + esc_cr = []byte("
") ) -// Escape writes to w the properly escaped XML equivalent +// EscapeText writes to w the properly escaped XML equivalent // of the plain text data s. -func Escape(w io.Writer, s []byte) { +func EscapeText(w io.Writer, s []byte) error { var esc []byte last := 0 for i, c := range s { @@ -1663,14 +1737,34 @@ func Escape(w io.Writer, s []byte) { esc = esc_lt case '>': esc = esc_gt + case '\t': + esc = esc_tab + case '\n': + esc = esc_nl + case '\r': + esc = esc_cr default: continue } - w.Write(s[last:i]) - w.Write(esc) + if _, err := w.Write(s[last:i]); err != nil { + return err + } + if _, err := w.Write(esc); err != nil { + return err + } last = i + 1 } - w.Write(s[last:]) + if _, err := w.Write(s[last:]); err != nil { + return err + } + return nil +} + +// Escape is like EscapeText but omits the error return value. +// It is provided for backwards compatibility with Go 1.0. +// Code targeting Go 1.1 or later should use EscapeText. +func Escape(w io.Writer, s []byte) { + EscapeText(w, s) } // procInstEncoding parses the `encoding="..."` or `encoding='...'` diff --git a/src/pkg/encoding/xml/xml_test.go b/src/pkg/encoding/xml/xml_test.go index 1d0696ce0..54dab5484 100644 --- a/src/pkg/encoding/xml/xml_test.go +++ b/src/pkg/encoding/xml/xml_test.go @@ -5,6 +5,7 @@ package xml import ( + "fmt" "io" "reflect" "strings" @@ -18,6 +19,7 @@ const testInput = ` <body xmlns:foo="ns1" xmlns="ns2" xmlns:tag="ns3" ` + "\r\n\t" + ` > <hello lang="en">World <>'" 白鵬翔</hello> + <query>&何; &is-it;</query> <goodbye /> <outer foo:attr="value" xmlns:tag="ns4"> <inner/> @@ -27,6 +29,8 @@ const testInput = ` </tag:name> </body><!-- missing final newline -->` +var testEntity = map[string]string{"何": "What", "is-it": "is it?"} + var rawTokens = []Token{ CharData("\n"), ProcInst{"xml", []byte(`version="1.0" encoding="UTF-8"`)}, @@ -40,6 +44,10 @@ var rawTokens = []Token{ CharData("World <>'\" 白鵬翔"), EndElement{Name{"", "hello"}}, CharData("\n "), + StartElement{Name{"", "query"}, []Attr{}}, + CharData("What is it?"), + EndElement{Name{"", "query"}}, + CharData("\n "), StartElement{Name{"", "goodbye"}, []Attr{}}, EndElement{Name{"", "goodbye"}}, CharData("\n "), @@ -73,6 +81,10 @@ var cookedTokens = []Token{ CharData("World <>'\" 白鵬翔"), EndElement{Name{"ns2", "hello"}}, CharData("\n "), + StartElement{Name{"ns2", "query"}, []Attr{}}, + CharData("What is it?"), + EndElement{Name{"ns2", "query"}}, + CharData("\n "), StartElement{Name{"ns2", "goodbye"}, []Attr{}}, EndElement{Name{"ns2", "goodbye"}}, CharData("\n "), @@ -155,9 +167,65 @@ var xmlInput = []string{ func TestRawToken(t *testing.T) { d := NewDecoder(strings.NewReader(testInput)) + d.Entity = testEntity testRawToken(t, d, rawTokens) } +const nonStrictInput = ` +<tag>non&entity</tag> +<tag>&unknown;entity</tag> +<tag>{</tag> +<tag>&#zzz;</tag> +<tag>&なまえ3;</tag> +<tag><-gt;</tag> +<tag>&;</tag> +<tag>&0a;</tag> +` + +var nonStringEntity = map[string]string{"": "oops!", "0a": "oops!"} + +var nonStrictTokens = []Token{ + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("non&entity"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&unknown;entity"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("{"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&#zzz;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&なまえ3;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("<-gt;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&0a;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), +} + +func TestNonStrictRawToken(t *testing.T) { + d := NewDecoder(strings.NewReader(nonStrictInput)) + d.Strict = false + testRawToken(t, d, nonStrictTokens) +} + type downCaser struct { t *testing.T r io.ByteReader @@ -219,7 +287,18 @@ func testRawToken(t *testing.T, d *Decoder, rawTokens []Token) { t.Fatalf("token %d: unexpected error: %s", i, err) } if !reflect.DeepEqual(have, want) { - t.Errorf("token %d = %#v want %#v", i, have, want) + var shave, swant string + if _, ok := have.(CharData); ok { + shave = fmt.Sprintf("CharData(%q)", have) + } else { + shave = fmt.Sprintf("%#v", have) + } + if _, ok := want.(CharData); ok { + swant = fmt.Sprintf("CharData(%q)", want) + } else { + swant = fmt.Sprintf("%#v", want) + } + t.Errorf("token %d = %s, want %s", i, shave, swant) } } } @@ -272,6 +351,7 @@ func TestNestedDirectives(t *testing.T) { func TestToken(t *testing.T) { d := NewDecoder(strings.NewReader(testInput)) + d.Entity = testEntity for i, want := range cookedTokens { have, err := d.Token() @@ -531,8 +611,8 @@ var characterTests = []struct { {"\xef\xbf\xbe<doc/>", "illegal character code U+FFFE"}, {"<?xml version=\"1.0\"?><doc>\r\n<hiya/>\x07<toots/></doc>", "illegal character code U+0007"}, {"<?xml version=\"1.0\"?><doc \x12='value'>what's up</doc>", "expected attribute name in element"}, - {"<doc>&\x01;</doc>", "invalid character entity &;"}, - {"<doc>&\xef\xbf\xbe;</doc>", "invalid character entity &;"}, + {"<doc>&\x01;</doc>", "invalid character entity & (no semicolon)"}, + {"<doc>&\xef\xbf\xbe;</doc>", "invalid character entity & (no semicolon)"}, } func TestDisallowedCharacters(t *testing.T) { @@ -576,3 +656,50 @@ func TestProcInstEncoding(t *testing.T) { } } } + +// Ensure that directives with comments include the complete +// text of any nested directives. + +var directivesWithCommentsInput = ` +<!DOCTYPE [<!-- a comment --><!ENTITY rdf "http://www.w3.org/1999/02/22-rdf-syntax-ns#">]> +<!DOCTYPE [<!ENTITY go "Golang"><!-- a comment-->]> +<!DOCTYPE <!-> <!> <!----> <!-->--> <!--->--> [<!ENTITY go "Golang"><!-- a comment-->]> +` + +var directivesWithCommentsTokens = []Token{ + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY rdf "http://www.w3.org/1999/02/22-rdf-syntax-ns#">]`), + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY go "Golang">]`), + CharData("\n"), + Directive(`DOCTYPE <!-> <!> [<!ENTITY go "Golang">]`), + CharData("\n"), +} + +func TestDirectivesWithComments(t *testing.T) { + d := NewDecoder(strings.NewReader(directivesWithCommentsInput)) + + for i, want := range directivesWithCommentsTokens { + have, err := d.Token() + if err != nil { + t.Fatalf("token %d: unexpected error: %s", i, err) + } + if !reflect.DeepEqual(have, want) { + t.Errorf("token %d = %#v want %#v", i, have, want) + } + } +} + +// Writer whose Write method always returns an error. +type errWriter struct{} + +func (errWriter) Write(p []byte) (n int, err error) { return 0, fmt.Errorf("unwritable") } + +func TestEscapeTextIOErrors(t *testing.T) { + expectErr := "unwritable" + err := EscapeText(errWriter{}, []byte{'A'}) + + if err == nil || err.Error() != expectErr { + t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr) + } +} |