diff options
author | Tianon Gravi <admwiggin@gmail.com> | 2015-01-15 11:54:00 -0700 |
---|---|---|
committer | Tianon Gravi <admwiggin@gmail.com> | 2015-01-15 11:54:00 -0700 |
commit | f154da9e12608589e8d5f0508f908a0c3e88a1bb (patch) | |
tree | f8255d51e10c6f1e0ed69702200b966c9556a431 /src/encoding | |
parent | 8d8329ed5dfb9622c82a9fbec6fd99a580f9c9f6 (diff) | |
download | golang-upstream/1.4.tar.gz |
Imported Upstream version 1.4upstream/1.4
Diffstat (limited to 'src/encoding')
74 files changed, 31303 insertions, 0 deletions
diff --git a/src/encoding/ascii85/ascii85.go b/src/encoding/ascii85/ascii85.go new file mode 100644 index 000000000..4d7193873 --- /dev/null +++ b/src/encoding/ascii85/ascii85.go @@ -0,0 +1,310 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package ascii85 implements the ascii85 data encoding +// as used in the btoa tool and Adobe's PostScript and PDF document formats. +package ascii85 + +import ( + "io" + "strconv" +) + +/* + * Encoder + */ + +// Encode encodes src into at most MaxEncodedLen(len(src)) +// bytes of dst, returning the actual number of bytes written. +// +// The encoding handles 4-byte chunks, using a special encoding +// for the last fragment, so Encode is not appropriate for use on +// individual blocks of a large data stream. Use NewEncoder() instead. +// +// Often, ascii85-encoded data is wrapped in <~ and ~> symbols. +// Encode does not add these. +func Encode(dst, src []byte) int { + if len(src) == 0 { + return 0 + } + + n := 0 + for len(src) > 0 { + dst[0] = 0 + dst[1] = 0 + dst[2] = 0 + dst[3] = 0 + dst[4] = 0 + + // Unpack 4 bytes into uint32 to repack into base 85 5-byte. + var v uint32 + switch len(src) { + default: + v |= uint32(src[3]) + fallthrough + case 3: + v |= uint32(src[2]) << 8 + fallthrough + case 2: + v |= uint32(src[1]) << 16 + fallthrough + case 1: + v |= uint32(src[0]) << 24 + } + + // Special case: zero (!!!!!) shortens to z. + if v == 0 && len(src) >= 4 { + dst[0] = 'z' + dst = dst[1:] + src = src[4:] + n++ + continue + } + + // Otherwise, 5 base 85 digits starting at !. + for i := 4; i >= 0; i-- { + dst[i] = '!' + byte(v%85) + v /= 85 + } + + // If src was short, discard the low destination bytes. + m := 5 + if len(src) < 4 { + m -= 4 - len(src) + src = nil + } else { + src = src[4:] + } + dst = dst[m:] + n += m + } + return n +} + +// MaxEncodedLen returns the maximum length of an encoding of n source bytes. +func MaxEncodedLen(n int) int { return (n + 3) / 4 * 5 } + +// NewEncoder returns a new ascii85 stream encoder. Data written to +// the returned writer will be encoded and then written to w. +// Ascii85 encodings operate in 32-bit blocks; when finished +// writing, the caller must Close the returned encoder to flush any +// trailing partial block. +func NewEncoder(w io.Writer) io.WriteCloser { return &encoder{w: w} } + +type encoder struct { + err error + w io.Writer + buf [4]byte // buffered data waiting to be encoded + nbuf int // number of bytes in buf + out [1024]byte // output buffer +} + +func (e *encoder) Write(p []byte) (n int, err error) { + if e.err != nil { + return 0, e.err + } + + // Leading fringe. + if e.nbuf > 0 { + var i int + for i = 0; i < len(p) && e.nbuf < 4; i++ { + e.buf[e.nbuf] = p[i] + e.nbuf++ + } + n += i + p = p[i:] + if e.nbuf < 4 { + return + } + nout := Encode(e.out[0:], e.buf[0:]) + if _, e.err = e.w.Write(e.out[0:nout]); e.err != nil { + return n, e.err + } + e.nbuf = 0 + } + + // Large interior chunks. + for len(p) >= 4 { + nn := len(e.out) / 5 * 4 + if nn > len(p) { + nn = len(p) + } + nn -= nn % 4 + if nn > 0 { + nout := Encode(e.out[0:], p[0:nn]) + if _, e.err = e.w.Write(e.out[0:nout]); e.err != nil { + return n, e.err + } + } + n += nn + p = p[nn:] + } + + // Trailing fringe. + for i := 0; i < len(p); i++ { + e.buf[i] = p[i] + } + e.nbuf = len(p) + n += len(p) + return +} + +// Close flushes any pending output from the encoder. +// It is an error to call Write after calling Close. +func (e *encoder) Close() error { + // If there's anything left in the buffer, flush it out + if e.err == nil && e.nbuf > 0 { + nout := Encode(e.out[0:], e.buf[0:e.nbuf]) + e.nbuf = 0 + _, e.err = e.w.Write(e.out[0:nout]) + } + return e.err +} + +/* + * Decoder + */ + +type CorruptInputError int64 + +func (e CorruptInputError) Error() string { + return "illegal ascii85 data at input byte " + strconv.FormatInt(int64(e), 10) +} + +// Decode decodes src into dst, returning both the number +// of bytes written to dst and the number consumed from src. +// If src contains invalid ascii85 data, Decode will return the +// number of bytes successfully written and a CorruptInputError. +// Decode ignores space and control characters in src. +// Often, ascii85-encoded data is wrapped in <~ and ~> symbols. +// Decode expects these to have been stripped by the caller. +// +// If flush is true, Decode assumes that src represents the +// end of the input stream and processes it completely rather +// than wait for the completion of another 32-bit block. +// +// NewDecoder wraps an io.Reader interface around Decode. +// +func Decode(dst, src []byte, flush bool) (ndst, nsrc int, err error) { + var v uint32 + var nb int + for i, b := range src { + if len(dst)-ndst < 4 { + return + } + switch { + case b <= ' ': + continue + case b == 'z' && nb == 0: + nb = 5 + v = 0 + case '!' <= b && b <= 'u': + v = v*85 + uint32(b-'!') + nb++ + default: + return 0, 0, CorruptInputError(i) + } + if nb == 5 { + nsrc = i + 1 + dst[ndst] = byte(v >> 24) + dst[ndst+1] = byte(v >> 16) + dst[ndst+2] = byte(v >> 8) + dst[ndst+3] = byte(v) + ndst += 4 + nb = 0 + v = 0 + } + } + if flush { + nsrc = len(src) + if nb > 0 { + // The number of output bytes in the last fragment + // is the number of leftover input bytes - 1: + // the extra byte provides enough bits to cover + // the inefficiency of the encoding for the block. + if nb == 1 { + return 0, 0, CorruptInputError(len(src)) + } + for i := nb; i < 5; i++ { + // The short encoding truncated the output value. + // We have to assume the worst case values (digit 84) + // in order to ensure that the top bits are correct. + v = v*85 + 84 + } + for i := 0; i < nb-1; i++ { + dst[ndst] = byte(v >> 24) + v <<= 8 + ndst++ + } + } + } + return +} + +// NewDecoder constructs a new ascii85 stream decoder. +func NewDecoder(r io.Reader) io.Reader { return &decoder{r: r} } + +type decoder struct { + err error + readErr error + r io.Reader + buf [1024]byte // leftover input + nbuf int + out []byte // leftover decoded output + outbuf [1024]byte +} + +func (d *decoder) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + if d.err != nil { + return 0, d.err + } + + for { + // Copy leftover output from last decode. + if len(d.out) > 0 { + n = copy(p, d.out) + d.out = d.out[n:] + return + } + + // Decode leftover input from last read. + var nn, nsrc, ndst int + if d.nbuf > 0 { + ndst, nsrc, d.err = Decode(d.outbuf[0:], d.buf[0:d.nbuf], d.readErr != nil) + if ndst > 0 { + d.out = d.outbuf[0:ndst] + d.nbuf = copy(d.buf[0:], d.buf[nsrc:d.nbuf]) + continue // copy out and return + } + if ndst == 0 && d.err == nil { + // Special case: input buffer is mostly filled with non-data bytes. + // Filter out such bytes to make room for more input. + off := 0 + for i := 0; i < d.nbuf; i++ { + if d.buf[i] > ' ' { + d.buf[off] = d.buf[i] + off++ + } + } + d.nbuf = off + } + } + + // Out of input, out of decoded output. Check errors. + if d.err != nil { + return 0, d.err + } + if d.readErr != nil { + d.err = d.readErr + return 0, d.err + } + + // Read more data. + nn, d.readErr = d.r.Read(d.buf[d.nbuf:]) + d.nbuf += nn + } +} diff --git a/src/encoding/ascii85/ascii85_test.go b/src/encoding/ascii85/ascii85_test.go new file mode 100644 index 000000000..aad199b4f --- /dev/null +++ b/src/encoding/ascii85/ascii85_test.go @@ -0,0 +1,210 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ascii85 + +import ( + "bytes" + "io" + "io/ioutil" + "strings" + "testing" +) + +type testpair struct { + decoded, encoded string +} + +var pairs = []testpair{ + // Encode returns 0 when len(src) is 0 + { + "", + "", + }, + // Wikipedia example + { + "Man is distinguished, not only by his reason, but by this singular passion from " + + "other animals, which is a lust of the mind, that by a perseverance of delight in " + + "the continued and indefatigable generation of knowledge, exceeds the short " + + "vehemence of any carnal pleasure.", + "9jqo^BlbD-BleB1DJ+*+F(f,q/0JhKF<GL>Cj@.4Gp$d7F!,L7@<6@)/0JDEF<G%<+EV:2F!,\n" + + "O<DJ+*.@<*K0@<6L(Df-\\0Ec5e;DffZ(EZee.Bl.9pF\"AGXBPCsi+DGm>@3BB/F*&OCAfu2/AKY\n" + + "i(DIb:@FD,*)+C]U=@3BN#EcYf8ATD3s@q?d$AftVqCh[NqF<G:8+EV:.+Cf>-FD5W8ARlolDIa\n" + + "l(DId<j@<?3r@:F%a+D58'ATD4$Bl@l3De:,-DJs`8ARoFb/0JMK@qB4^F!,R<AKZ&-DfTqBG%G\n" + + ">uD.RTpAKYo'+CT/5+Cei#DII?(E,9)oF*2M7/c\n", + }, + // Special case when shortening !!!!! to z. + { + "\000\000\000\000", + "z", + }, +} + +var bigtest = pairs[len(pairs)-1] + +func testEqual(t *testing.T, msg string, args ...interface{}) bool { + if args[len(args)-2] != args[len(args)-1] { + t.Errorf(msg, args...) + return false + } + return true +} + +func strip85(s string) string { + t := make([]byte, len(s)) + w := 0 + for r := 0; r < len(s); r++ { + c := s[r] + if c > ' ' { + t[w] = c + w++ + } + } + return string(t[0:w]) +} + +func TestEncode(t *testing.T) { + for _, p := range pairs { + buf := make([]byte, MaxEncodedLen(len(p.decoded))) + n := Encode(buf, []byte(p.decoded)) + buf = buf[0:n] + testEqual(t, "Encode(%q) = %q, want %q", p.decoded, strip85(string(buf)), strip85(p.encoded)) + } +} + +func TestEncoder(t *testing.T) { + for _, p := range pairs { + bb := &bytes.Buffer{} + encoder := NewEncoder(bb) + encoder.Write([]byte(p.decoded)) + encoder.Close() + testEqual(t, "Encode(%q) = %q, want %q", p.decoded, strip85(bb.String()), strip85(p.encoded)) + } +} + +func TestEncoderBuffering(t *testing.T) { + input := []byte(bigtest.decoded) + for bs := 1; bs <= 12; bs++ { + bb := &bytes.Buffer{} + encoder := NewEncoder(bb) + for pos := 0; pos < len(input); pos += bs { + end := pos + bs + if end > len(input) { + end = len(input) + } + n, err := encoder.Write(input[pos:end]) + testEqual(t, "Write(%q) gave error %v, want %v", input[pos:end], err, error(nil)) + testEqual(t, "Write(%q) gave length %v, want %v", input[pos:end], n, end-pos) + } + err := encoder.Close() + testEqual(t, "Close gave error %v, want %v", err, error(nil)) + testEqual(t, "Encoding/%d of %q = %q, want %q", bs, bigtest.decoded, strip85(bb.String()), strip85(bigtest.encoded)) + } +} + +func TestDecode(t *testing.T) { + for _, p := range pairs { + dbuf := make([]byte, 4*len(p.encoded)) + ndst, nsrc, err := Decode(dbuf, []byte(p.encoded), true) + testEqual(t, "Decode(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, "Decode(%q) = nsrc %v, want %v", p.encoded, nsrc, len(p.encoded)) + testEqual(t, "Decode(%q) = ndst %v, want %v", p.encoded, ndst, len(p.decoded)) + testEqual(t, "Decode(%q) = %q, want %q", p.encoded, string(dbuf[0:ndst]), p.decoded) + } +} + +func TestDecoder(t *testing.T) { + for _, p := range pairs { + decoder := NewDecoder(strings.NewReader(p.encoded)) + dbuf, err := ioutil.ReadAll(decoder) + if err != nil { + t.Fatal("Read failed", err) + } + testEqual(t, "Read from %q = length %v, want %v", p.encoded, len(dbuf), len(p.decoded)) + testEqual(t, "Decoding of %q = %q, want %q", p.encoded, string(dbuf), p.decoded) + if err != nil { + testEqual(t, "Read from %q = %v, want %v", p.encoded, err, io.EOF) + } + } +} + +func TestDecoderBuffering(t *testing.T) { + for bs := 1; bs <= 12; bs++ { + decoder := NewDecoder(strings.NewReader(bigtest.encoded)) + buf := make([]byte, len(bigtest.decoded)+12) + var total int + for total = 0; total < len(bigtest.decoded); { + n, err := decoder.Read(buf[total : total+bs]) + testEqual(t, "Read from %q at pos %d = %d, %v, want _, %v", bigtest.encoded, total, n, err, error(nil)) + total += n + } + testEqual(t, "Decoding/%d of %q = %q, want %q", bs, bigtest.encoded, string(buf[0:total]), bigtest.decoded) + } +} + +func TestDecodeCorrupt(t *testing.T) { + type corrupt struct { + e string + p int + } + examples := []corrupt{ + {"v", 0}, + {"!z!!!!!!!!!", 1}, + } + + for _, e := range examples { + dbuf := make([]byte, 4*len(e.e)) + _, _, err := Decode(dbuf, []byte(e.e), true) + switch err := err.(type) { + case CorruptInputError: + testEqual(t, "Corruption in %q at offset %v, want %v", e.e, int(err), e.p) + default: + t.Error("Decoder failed to detect corruption in", e) + } + } +} + +func TestBig(t *testing.T) { + n := 3*1000 + 1 + raw := make([]byte, n) + const alpha = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + for i := 0; i < n; i++ { + raw[i] = alpha[i%len(alpha)] + } + encoded := new(bytes.Buffer) + w := NewEncoder(encoded) + nn, err := w.Write(raw) + if nn != n || err != nil { + t.Fatalf("Encoder.Write(raw) = %d, %v want %d, nil", nn, err, n) + } + err = w.Close() + if err != nil { + t.Fatalf("Encoder.Close() = %v want nil", err) + } + decoded, err := ioutil.ReadAll(NewDecoder(encoded)) + if err != nil { + t.Fatalf("io.ReadAll(NewDecoder(...)): %v", err) + } + + if !bytes.Equal(raw, decoded) { + var i int + for i = 0; i < len(decoded) && i < len(raw); i++ { + if decoded[i] != raw[i] { + break + } + } + t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i) + } +} + +func TestDecoderInternalWhitespace(t *testing.T) { + s := strings.Repeat(" ", 2048) + "z" + decoded, err := ioutil.ReadAll(NewDecoder(strings.NewReader(s))) + if err != nil { + t.Errorf("Decode gave error %v", err) + } + if want := []byte("\000\000\000\000"); !bytes.Equal(want, decoded) { + t.Errorf("Decode failed: got %v, want %v", decoded, want) + } +} diff --git a/src/encoding/asn1/asn1.go b/src/encoding/asn1/asn1.go new file mode 100644 index 000000000..8b3d1b341 --- /dev/null +++ b/src/encoding/asn1/asn1.go @@ -0,0 +1,922 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package asn1 implements parsing of DER-encoded ASN.1 data structures, +// as defined in ITU-T Rec X.690. +// +// See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,'' +// http://luca.ntop.org/Teaching/Appunti/asn1.html. +package asn1 + +// ASN.1 is a syntax for specifying abstract objects and BER, DER, PER, XER etc +// are different encoding formats for those objects. Here, we'll be dealing +// with DER, the Distinguished Encoding Rules. DER is used in X.509 because +// it's fast to parse and, unlike BER, has a unique encoding for every object. +// When calculating hashes over objects, it's important that the resulting +// bytes be the same at both ends and DER removes this margin of error. +// +// ASN.1 is very complex and this package doesn't attempt to implement +// everything by any means. + +import ( + "fmt" + "math/big" + "reflect" + "strconv" + "time" +) + +// A StructuralError suggests that the ASN.1 data is valid, but the Go type +// which is receiving it doesn't match. +type StructuralError struct { + Msg string +} + +func (e StructuralError) Error() string { return "asn1: structure error: " + e.Msg } + +// A SyntaxError suggests that the ASN.1 data is invalid. +type SyntaxError struct { + Msg string +} + +func (e SyntaxError) Error() string { return "asn1: syntax error: " + e.Msg } + +// We start by dealing with each of the primitive types in turn. + +// BOOLEAN + +func parseBool(bytes []byte) (ret bool, err error) { + if len(bytes) != 1 { + err = SyntaxError{"invalid boolean"} + return + } + + // DER demands that "If the encoding represents the boolean value TRUE, + // its single contents octet shall have all eight bits set to one." + // Thus only 0 and 255 are valid encoded values. + switch bytes[0] { + case 0: + ret = false + case 0xff: + ret = true + default: + err = SyntaxError{"invalid boolean"} + } + + return +} + +// INTEGER + +// parseInt64 treats the given bytes as a big-endian, signed integer and +// returns the result. +func parseInt64(bytes []byte) (ret int64, err error) { + if len(bytes) > 8 { + // We'll overflow an int64 in this case. + err = StructuralError{"integer too large"} + return + } + for bytesRead := 0; bytesRead < len(bytes); bytesRead++ { + ret <<= 8 + ret |= int64(bytes[bytesRead]) + } + + // Shift up and down in order to sign extend the result. + ret <<= 64 - uint8(len(bytes))*8 + ret >>= 64 - uint8(len(bytes))*8 + return +} + +// parseInt treats the given bytes as a big-endian, signed integer and returns +// the result. +func parseInt32(bytes []byte) (int32, error) { + ret64, err := parseInt64(bytes) + if err != nil { + return 0, err + } + if ret64 != int64(int32(ret64)) { + return 0, StructuralError{"integer too large"} + } + return int32(ret64), nil +} + +var bigOne = big.NewInt(1) + +// parseBigInt treats the given bytes as a big-endian, signed integer and returns +// the result. +func parseBigInt(bytes []byte) *big.Int { + ret := new(big.Int) + if len(bytes) > 0 && bytes[0]&0x80 == 0x80 { + // This is a negative number. + notBytes := make([]byte, len(bytes)) + for i := range notBytes { + notBytes[i] = ^bytes[i] + } + ret.SetBytes(notBytes) + ret.Add(ret, bigOne) + ret.Neg(ret) + return ret + } + ret.SetBytes(bytes) + return ret +} + +// BIT STRING + +// BitString is the structure to use when you want an ASN.1 BIT STRING type. A +// bit string is padded up to the nearest byte in memory and the number of +// valid bits is recorded. Padding bits will be zero. +type BitString struct { + Bytes []byte // bits packed into bytes. + BitLength int // length in bits. +} + +// At returns the bit at the given index. If the index is out of range it +// returns false. +func (b BitString) At(i int) int { + if i < 0 || i >= b.BitLength { + return 0 + } + x := i / 8 + y := 7 - uint(i%8) + return int(b.Bytes[x]>>y) & 1 +} + +// RightAlign returns a slice where the padding bits are at the beginning. The +// slice may share memory with the BitString. +func (b BitString) RightAlign() []byte { + shift := uint(8 - (b.BitLength % 8)) + if shift == 8 || len(b.Bytes) == 0 { + return b.Bytes + } + + a := make([]byte, len(b.Bytes)) + a[0] = b.Bytes[0] >> shift + for i := 1; i < len(b.Bytes); i++ { + a[i] = b.Bytes[i-1] << (8 - shift) + a[i] |= b.Bytes[i] >> shift + } + + return a +} + +// parseBitString parses an ASN.1 bit string from the given byte slice and returns it. +func parseBitString(bytes []byte) (ret BitString, err error) { + if len(bytes) == 0 { + err = SyntaxError{"zero length BIT STRING"} + return + } + paddingBits := int(bytes[0]) + if paddingBits > 7 || + len(bytes) == 1 && paddingBits > 0 || + bytes[len(bytes)-1]&((1<<bytes[0])-1) != 0 { + err = SyntaxError{"invalid padding bits in BIT STRING"} + return + } + ret.BitLength = (len(bytes)-1)*8 - paddingBits + ret.Bytes = bytes[1:] + return +} + +// OBJECT IDENTIFIER + +// An ObjectIdentifier represents an ASN.1 OBJECT IDENTIFIER. +type ObjectIdentifier []int + +// Equal reports whether oi and other represent the same identifier. +func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool { + if len(oi) != len(other) { + return false + } + for i := 0; i < len(oi); i++ { + if oi[i] != other[i] { + return false + } + } + + return true +} + +func (oi ObjectIdentifier) String() string { + var s string + + for i, v := range oi { + if i > 0 { + s += "." + } + s += strconv.Itoa(v) + } + + return s +} + +// parseObjectIdentifier parses an OBJECT IDENTIFIER from the given bytes and +// returns it. An object identifier is a sequence of variable length integers +// that are assigned in a hierarchy. +func parseObjectIdentifier(bytes []byte) (s []int, err error) { + if len(bytes) == 0 { + err = SyntaxError{"zero length OBJECT IDENTIFIER"} + return + } + + // In the worst case, we get two elements from the first byte (which is + // encoded differently) and then every varint is a single byte long. + s = make([]int, len(bytes)+1) + + // The first varint is 40*value1 + value2: + // According to this packing, value1 can take the values 0, 1 and 2 only. + // When value1 = 0 or value1 = 1, then value2 is <= 39. When value1 = 2, + // then there are no restrictions on value2. + v, offset, err := parseBase128Int(bytes, 0) + if err != nil { + return + } + if v < 80 { + s[0] = v / 40 + s[1] = v % 40 + } else { + s[0] = 2 + s[1] = v - 80 + } + + i := 2 + for ; offset < len(bytes); i++ { + v, offset, err = parseBase128Int(bytes, offset) + if err != nil { + return + } + s[i] = v + } + s = s[0:i] + return +} + +// ENUMERATED + +// An Enumerated is represented as a plain int. +type Enumerated int + +// FLAG + +// A Flag accepts any data and is set to true if present. +type Flag bool + +// parseBase128Int parses a base-128 encoded int from the given offset in the +// given byte slice. It returns the value and the new offset. +func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error) { + offset = initOffset + for shifted := 0; offset < len(bytes); shifted++ { + if shifted > 4 { + err = StructuralError{"base 128 integer too large"} + return + } + ret <<= 7 + b := bytes[offset] + ret |= int(b & 0x7f) + offset++ + if b&0x80 == 0 { + return + } + } + err = SyntaxError{"truncated base 128 integer"} + return +} + +// UTCTime + +func parseUTCTime(bytes []byte) (ret time.Time, err error) { + s := string(bytes) + ret, err = time.Parse("0601021504Z0700", s) + if err != nil { + ret, err = time.Parse("060102150405Z0700", s) + } + if err == nil && ret.Year() >= 2050 { + // UTCTime only encodes times prior to 2050. See https://tools.ietf.org/html/rfc5280#section-4.1.2.5.1 + ret = ret.AddDate(-100, 0, 0) + } + + return +} + +// parseGeneralizedTime parses the GeneralizedTime from the given byte slice +// and returns the resulting time. +func parseGeneralizedTime(bytes []byte) (ret time.Time, err error) { + return time.Parse("20060102150405Z0700", string(bytes)) +} + +// PrintableString + +// parsePrintableString parses a ASN.1 PrintableString from the given byte +// array and returns it. +func parsePrintableString(bytes []byte) (ret string, err error) { + for _, b := range bytes { + if !isPrintable(b) { + err = SyntaxError{"PrintableString contains invalid character"} + return + } + } + ret = string(bytes) + return +} + +// isPrintable returns true iff the given b is in the ASN.1 PrintableString set. +func isPrintable(b byte) bool { + return 'a' <= b && b <= 'z' || + 'A' <= b && b <= 'Z' || + '0' <= b && b <= '9' || + '\'' <= b && b <= ')' || + '+' <= b && b <= '/' || + b == ' ' || + b == ':' || + b == '=' || + b == '?' || + // This is technically not allowed in a PrintableString. + // However, x509 certificates with wildcard strings don't + // always use the correct string type so we permit it. + b == '*' +} + +// IA5String + +// parseIA5String parses a ASN.1 IA5String (ASCII string) from the given +// byte slice and returns it. +func parseIA5String(bytes []byte) (ret string, err error) { + for _, b := range bytes { + if b >= 0x80 { + err = SyntaxError{"IA5String contains invalid character"} + return + } + } + ret = string(bytes) + return +} + +// T61String + +// parseT61String parses a ASN.1 T61String (8-bit clean string) from the given +// byte slice and returns it. +func parseT61String(bytes []byte) (ret string, err error) { + return string(bytes), nil +} + +// UTF8String + +// parseUTF8String parses a ASN.1 UTF8String (raw UTF-8) from the given byte +// array and returns it. +func parseUTF8String(bytes []byte) (ret string, err error) { + return string(bytes), nil +} + +// A RawValue represents an undecoded ASN.1 object. +type RawValue struct { + Class, Tag int + IsCompound bool + Bytes []byte + FullBytes []byte // includes the tag and length +} + +// RawContent is used to signal that the undecoded, DER data needs to be +// preserved for a struct. To use it, the first field of the struct must have +// this type. It's an error for any of the other fields to have this type. +type RawContent []byte + +// Tagging + +// parseTagAndLength parses an ASN.1 tag and length pair from the given offset +// into a byte slice. It returns the parsed data and the new offset. SET and +// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we +// don't distinguish between ordered and unordered objects in this code. +func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err error) { + offset = initOffset + b := bytes[offset] + offset++ + ret.class = int(b >> 6) + ret.isCompound = b&0x20 == 0x20 + ret.tag = int(b & 0x1f) + + // If the bottom five bits are set, then the tag number is actually base 128 + // encoded afterwards + if ret.tag == 0x1f { + ret.tag, offset, err = parseBase128Int(bytes, offset) + if err != nil { + return + } + } + if offset >= len(bytes) { + err = SyntaxError{"truncated tag or length"} + return + } + b = bytes[offset] + offset++ + if b&0x80 == 0 { + // The length is encoded in the bottom 7 bits. + ret.length = int(b & 0x7f) + } else { + // Bottom 7 bits give the number of length bytes to follow. + numBytes := int(b & 0x7f) + if numBytes == 0 { + err = SyntaxError{"indefinite length found (not DER)"} + return + } + ret.length = 0 + for i := 0; i < numBytes; i++ { + if offset >= len(bytes) { + err = SyntaxError{"truncated tag or length"} + return + } + b = bytes[offset] + offset++ + if ret.length >= 1<<23 { + // We can't shift ret.length up without + // overflowing. + err = StructuralError{"length too large"} + return + } + ret.length <<= 8 + ret.length |= int(b) + if ret.length == 0 { + // DER requires that lengths be minimal. + err = StructuralError{"superfluous leading zeros in length"} + return + } + } + } + + return +} + +// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse +// a number of ASN.1 values from the given byte slice and returns them as a +// slice of Go values of the given type. +func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err error) { + expectedTag, compoundType, ok := getUniversalType(elemType) + if !ok { + err = StructuralError{"unknown Go type for slice"} + return + } + + // First we iterate over the input and count the number of elements, + // checking that the types are correct in each case. + numElements := 0 + for offset := 0; offset < len(bytes); { + var t tagAndLength + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + switch t.tag { + case tagIA5String, tagGeneralString, tagT61String, tagUTF8String: + // We pretend that various other string types are + // PRINTABLE STRINGs so that a sequence of them can be + // parsed into a []string. + t.tag = tagPrintableString + case tagGeneralizedTime, tagUTCTime: + // Likewise, both time types are treated the same. + t.tag = tagUTCTime + } + + if t.class != classUniversal || t.isCompound != compoundType || t.tag != expectedTag { + err = StructuralError{"sequence tag mismatch"} + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"truncated sequence"} + return + } + offset += t.length + numElements++ + } + ret = reflect.MakeSlice(sliceType, numElements, numElements) + params := fieldParameters{} + offset := 0 + for i := 0; i < numElements; i++ { + offset, err = parseField(ret.Index(i), bytes, offset, params) + if err != nil { + return + } + } + return +} + +var ( + bitStringType = reflect.TypeOf(BitString{}) + objectIdentifierType = reflect.TypeOf(ObjectIdentifier{}) + enumeratedType = reflect.TypeOf(Enumerated(0)) + flagType = reflect.TypeOf(Flag(false)) + timeType = reflect.TypeOf(time.Time{}) + rawValueType = reflect.TypeOf(RawValue{}) + rawContentsType = reflect.TypeOf(RawContent(nil)) + bigIntType = reflect.TypeOf(new(big.Int)) +) + +// invalidLength returns true iff offset + length > sliceLength, or if the +// addition would overflow. +func invalidLength(offset, length, sliceLength int) bool { + return offset+length < offset || offset+length > sliceLength +} + +// parseField is the main parsing function. Given a byte slice and an offset +// into the array, it will try to parse a suitable ASN.1 value out and store it +// in the given Value. +func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err error) { + offset = initOffset + fieldType := v.Type() + + // If we have run out of data, it may be that there are optional elements at the end. + if offset == len(bytes) { + if !setDefaultValue(v, params) { + err = SyntaxError{"sequence truncated"} + } + return + } + + // Deal with raw values. + if fieldType == rawValueType { + var t tagAndLength + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"data truncated"} + return + } + result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]} + offset += t.length + v.Set(reflect.ValueOf(result)) + return + } + + // Deal with the ANY type. + if ifaceType := fieldType; ifaceType.Kind() == reflect.Interface && ifaceType.NumMethod() == 0 { + var t tagAndLength + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"data truncated"} + return + } + var result interface{} + if !t.isCompound && t.class == classUniversal { + innerBytes := bytes[offset : offset+t.length] + switch t.tag { + case tagPrintableString: + result, err = parsePrintableString(innerBytes) + case tagIA5String: + result, err = parseIA5String(innerBytes) + case tagT61String: + result, err = parseT61String(innerBytes) + case tagUTF8String: + result, err = parseUTF8String(innerBytes) + case tagInteger: + result, err = parseInt64(innerBytes) + case tagBitString: + result, err = parseBitString(innerBytes) + case tagOID: + result, err = parseObjectIdentifier(innerBytes) + case tagUTCTime: + result, err = parseUTCTime(innerBytes) + case tagOctetString: + result = innerBytes + default: + // If we don't know how to handle the type, we just leave Value as nil. + } + } + offset += t.length + if err != nil { + return + } + if result != nil { + v.Set(reflect.ValueOf(result)) + } + return + } + universalTag, compoundType, ok1 := getUniversalType(fieldType) + if !ok1 { + err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)} + return + } + + t, offset, err := parseTagAndLength(bytes, offset) + if err != nil { + return + } + if params.explicit { + expectedClass := classContextSpecific + if params.application { + expectedClass = classApplication + } + if t.class == expectedClass && t.tag == *params.tag && (t.length == 0 || t.isCompound) { + if t.length > 0 { + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + } else { + if fieldType != flagType { + err = StructuralError{"zero length explicit tag was not an asn1.Flag"} + return + } + v.SetBool(true) + return + } + } else { + // The tags didn't match, it might be an optional element. + ok := setDefaultValue(v, params) + if ok { + offset = initOffset + } else { + err = StructuralError{"explicitly tagged member didn't match"} + } + return + } + } + + // Special case for strings: all the ASN.1 string types map to the Go + // type string. getUniversalType returns the tag for PrintableString + // when it sees a string, so if we see a different string type on the + // wire, we change the universal type to match. + if universalTag == tagPrintableString { + if t.class == classUniversal { + switch t.tag { + case tagIA5String, tagGeneralString, tagT61String, tagUTF8String: + universalTag = t.tag + } + } else if params.stringType != 0 { + universalTag = params.stringType + } + } + + // Special case for time: UTCTime and GeneralizedTime both map to the + // Go type time.Time. + if universalTag == tagUTCTime && t.tag == tagGeneralizedTime && t.class == classUniversal { + universalTag = tagGeneralizedTime + } + + if params.set { + universalTag = tagSet + } + + expectedClass := classUniversal + expectedTag := universalTag + + if !params.explicit && params.tag != nil { + expectedClass = classContextSpecific + expectedTag = *params.tag + } + + if !params.explicit && params.application && params.tag != nil { + expectedClass = classApplication + expectedTag = *params.tag + } + + // We have unwrapped any explicit tagging at this point. + if t.class != expectedClass || t.tag != expectedTag || t.isCompound != compoundType { + // Tags don't match. Again, it could be an optional element. + ok := setDefaultValue(v, params) + if ok { + offset = initOffset + } else { + err = StructuralError{fmt.Sprintf("tags don't match (%d vs %+v) %+v %s @%d", expectedTag, t, params, fieldType.Name(), offset)} + } + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"data truncated"} + return + } + innerBytes := bytes[offset : offset+t.length] + offset += t.length + + // We deal with the structures defined in this package first. + switch fieldType { + case objectIdentifierType: + newSlice, err1 := parseObjectIdentifier(innerBytes) + v.Set(reflect.MakeSlice(v.Type(), len(newSlice), len(newSlice))) + if err1 == nil { + reflect.Copy(v, reflect.ValueOf(newSlice)) + } + err = err1 + return + case bitStringType: + bs, err1 := parseBitString(innerBytes) + if err1 == nil { + v.Set(reflect.ValueOf(bs)) + } + err = err1 + return + case timeType: + var time time.Time + var err1 error + if universalTag == tagUTCTime { + time, err1 = parseUTCTime(innerBytes) + } else { + time, err1 = parseGeneralizedTime(innerBytes) + } + if err1 == nil { + v.Set(reflect.ValueOf(time)) + } + err = err1 + return + case enumeratedType: + parsedInt, err1 := parseInt32(innerBytes) + if err1 == nil { + v.SetInt(int64(parsedInt)) + } + err = err1 + return + case flagType: + v.SetBool(true) + return + case bigIntType: + parsedInt := parseBigInt(innerBytes) + v.Set(reflect.ValueOf(parsedInt)) + return + } + switch val := v; val.Kind() { + case reflect.Bool: + parsedBool, err1 := parseBool(innerBytes) + if err1 == nil { + val.SetBool(parsedBool) + } + err = err1 + return + case reflect.Int, reflect.Int32, reflect.Int64: + if val.Type().Size() == 4 { + parsedInt, err1 := parseInt32(innerBytes) + if err1 == nil { + val.SetInt(int64(parsedInt)) + } + err = err1 + } else { + parsedInt, err1 := parseInt64(innerBytes) + if err1 == nil { + val.SetInt(parsedInt) + } + err = err1 + } + return + // TODO(dfc) Add support for the remaining integer types + case reflect.Struct: + structType := fieldType + + if structType.NumField() > 0 && + structType.Field(0).Type == rawContentsType { + bytes := bytes[initOffset:offset] + val.Field(0).Set(reflect.ValueOf(RawContent(bytes))) + } + + innerOffset := 0 + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + if i == 0 && field.Type == rawContentsType { + continue + } + innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag.Get("asn1"))) + if err != nil { + return + } + } + // We allow extra bytes at the end of the SEQUENCE because + // adding elements to the end has been used in X.509 as the + // version numbers have increased. + return + case reflect.Slice: + sliceType := fieldType + if sliceType.Elem().Kind() == reflect.Uint8 { + val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes))) + reflect.Copy(val, reflect.ValueOf(innerBytes)) + return + } + newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem()) + if err1 == nil { + val.Set(newSlice) + } + err = err1 + return + case reflect.String: + var v string + switch universalTag { + case tagPrintableString: + v, err = parsePrintableString(innerBytes) + case tagIA5String: + v, err = parseIA5String(innerBytes) + case tagT61String: + v, err = parseT61String(innerBytes) + case tagUTF8String: + v, err = parseUTF8String(innerBytes) + case tagGeneralString: + // GeneralString is specified in ISO-2022/ECMA-35, + // A brief review suggests that it includes structures + // that allow the encoding to change midstring and + // such. We give up and pass it as an 8-bit string. + v, err = parseT61String(innerBytes) + default: + err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)} + } + if err == nil { + val.SetString(v) + } + return + } + err = StructuralError{"unsupported: " + v.Type().String()} + return +} + +// canHaveDefaultValue reports whether k is a Kind that we will set a default +// value for. (A signed integer, essentially.) +func canHaveDefaultValue(k reflect.Kind) bool { + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + } + + return false +} + +// setDefaultValue is used to install a default value, from a tag string, into +// a Value. It is successful if the field was optional, even if a default value +// wasn't provided or it failed to install it into the Value. +func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) { + if !params.optional { + return + } + ok = true + if params.defaultValue == nil { + return + } + if canHaveDefaultValue(v.Kind()) { + v.SetInt(*params.defaultValue) + } + return +} + +// Unmarshal parses the DER-encoded ASN.1 data structure b +// and uses the reflect package to fill in an arbitrary value pointed at by val. +// Because Unmarshal uses the reflect package, the structs +// being written to must use upper case field names. +// +// An ASN.1 INTEGER can be written to an int, int32, int64, +// or *big.Int (from the math/big package). +// If the encoded value does not fit in the Go type, +// Unmarshal returns a parse error. +// +// An ASN.1 BIT STRING can be written to a BitString. +// +// An ASN.1 OCTET STRING can be written to a []byte. +// +// An ASN.1 OBJECT IDENTIFIER can be written to an +// ObjectIdentifier. +// +// An ASN.1 ENUMERATED can be written to an Enumerated. +// +// An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a time.Time. +// +// An ASN.1 PrintableString or IA5String can be written to a string. +// +// Any of the above ASN.1 values can be written to an interface{}. +// The value stored in the interface has the corresponding Go type. +// For integers, that type is int64. +// +// An ASN.1 SEQUENCE OF x or SET OF x can be written +// to a slice if an x can be written to the slice's element type. +// +// An ASN.1 SEQUENCE or SET can be written to a struct +// if each of the elements in the sequence can be +// written to the corresponding element in the struct. +// +// The following tags on struct fields have special meaning to Unmarshal: +// +// application specifies that a APPLICATION tag is used +// default:x sets the default value for optional integer fields +// explicit specifies that an additional, explicit tag wraps the implicit one +// optional marks the field as ASN.1 OPTIONAL +// set causes a SET, rather than a SEQUENCE type to be expected +// tag:x specifies the ASN.1 tag number; implies ASN.1 CONTEXT SPECIFIC +// +// If the type of the first field of a structure is RawContent then the raw +// ASN1 contents of the struct will be stored in it. +// +// If the type name of a slice element ends with "SET" then it's treated as if +// the "set" tag was set on it. This can be used with nested slices where a +// struct tag cannot be given. +// +// Other ASN.1 types are not supported; if it encounters them, +// Unmarshal returns a parse error. +func Unmarshal(b []byte, val interface{}) (rest []byte, err error) { + return UnmarshalWithParams(b, val, "") +} + +// UnmarshalWithParams allows field parameters to be specified for the +// top-level element. The form of the params is the same as the field tags. +func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err error) { + v := reflect.ValueOf(val).Elem() + offset, err := parseField(v, b, 0, parseFieldParameters(params)) + if err != nil { + return nil, err + } + return b[offset:], nil +} diff --git a/src/encoding/asn1/asn1_test.go b/src/encoding/asn1/asn1_test.go new file mode 100644 index 000000000..4e864d08a --- /dev/null +++ b/src/encoding/asn1/asn1_test.go @@ -0,0 +1,867 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package asn1 + +import ( + "bytes" + "fmt" + "math/big" + "reflect" + "testing" + "time" +) + +type boolTest struct { + in []byte + ok bool + out bool +} + +var boolTestData = []boolTest{ + {[]byte{0x00}, true, false}, + {[]byte{0xff}, true, true}, + {[]byte{0x00, 0x00}, false, false}, + {[]byte{0xff, 0xff}, false, false}, + {[]byte{0x01}, false, false}, +} + +func TestParseBool(t *testing.T) { + for i, test := range boolTestData { + ret, err := parseBool(test.in) + if (err == nil) != test.ok { + t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok) + } + if test.ok && ret != test.out { + t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out) + } + } +} + +type int64Test struct { + in []byte + ok bool + out int64 +} + +var int64TestData = []int64Test{ + {[]byte{0x00}, true, 0}, + {[]byte{0x7f}, true, 127}, + {[]byte{0x00, 0x80}, true, 128}, + {[]byte{0x01, 0x00}, true, 256}, + {[]byte{0x80}, true, -128}, + {[]byte{0xff, 0x7f}, true, -129}, + {[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, true, -1}, + {[]byte{0xff}, true, -1}, + {[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, true, -9223372036854775808}, + {[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, false, 0}, +} + +func TestParseInt64(t *testing.T) { + for i, test := range int64TestData { + ret, err := parseInt64(test.in) + if (err == nil) != test.ok { + t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok) + } + if test.ok && ret != test.out { + t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out) + } + } +} + +type int32Test struct { + in []byte + ok bool + out int32 +} + +var int32TestData = []int32Test{ + {[]byte{0x00}, true, 0}, + {[]byte{0x7f}, true, 127}, + {[]byte{0x00, 0x80}, true, 128}, + {[]byte{0x01, 0x00}, true, 256}, + {[]byte{0x80}, true, -128}, + {[]byte{0xff, 0x7f}, true, -129}, + {[]byte{0xff, 0xff, 0xff, 0xff}, true, -1}, + {[]byte{0xff}, true, -1}, + {[]byte{0x80, 0x00, 0x00, 0x00}, true, -2147483648}, + {[]byte{0x80, 0x00, 0x00, 0x00, 0x00}, false, 0}, +} + +func TestParseInt32(t *testing.T) { + for i, test := range int32TestData { + ret, err := parseInt32(test.in) + if (err == nil) != test.ok { + t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok) + } + if test.ok && int32(ret) != test.out { + t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out) + } + } +} + +var bigIntTests = []struct { + in []byte + base10 string +}{ + {[]byte{0xff}, "-1"}, + {[]byte{0x00}, "0"}, + {[]byte{0x01}, "1"}, + {[]byte{0x00, 0xff}, "255"}, + {[]byte{0xff, 0x00}, "-256"}, + {[]byte{0x01, 0x00}, "256"}, +} + +func TestParseBigInt(t *testing.T) { + for i, test := range bigIntTests { + ret := parseBigInt(test.in) + if ret.String() != test.base10 { + t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10) + } + fw := newForkableWriter() + marshalBigInt(fw, ret) + result := fw.Bytes() + if !bytes.Equal(result, test.in) { + t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in) + } + } +} + +type bitStringTest struct { + in []byte + ok bool + out []byte + bitLength int +} + +var bitStringTestData = []bitStringTest{ + {[]byte{}, false, []byte{}, 0}, + {[]byte{0x00}, true, []byte{}, 0}, + {[]byte{0x07, 0x00}, true, []byte{0x00}, 1}, + {[]byte{0x07, 0x01}, false, []byte{}, 0}, + {[]byte{0x07, 0x40}, false, []byte{}, 0}, + {[]byte{0x08, 0x00}, false, []byte{}, 0}, +} + +func TestBitString(t *testing.T) { + for i, test := range bitStringTestData { + ret, err := parseBitString(test.in) + if (err == nil) != test.ok { + t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok) + } + if err == nil { + if test.bitLength != ret.BitLength || !bytes.Equal(ret.Bytes, test.out) { + t.Errorf("#%d: Bad result: %v (expected %v %v)", i, ret, test.out, test.bitLength) + } + } + } +} + +func TestBitStringAt(t *testing.T) { + bs := BitString{[]byte{0x82, 0x40}, 16} + if bs.At(0) != 1 { + t.Error("#1: Failed") + } + if bs.At(1) != 0 { + t.Error("#2: Failed") + } + if bs.At(6) != 1 { + t.Error("#3: Failed") + } + if bs.At(9) != 1 { + t.Error("#4: Failed") + } + if bs.At(-1) != 0 { + t.Error("#5: Failed") + } + if bs.At(17) != 0 { + t.Error("#6: Failed") + } +} + +type bitStringRightAlignTest struct { + in []byte + inlen int + out []byte +} + +var bitStringRightAlignTests = []bitStringRightAlignTest{ + {[]byte{0x80}, 1, []byte{0x01}}, + {[]byte{0x80, 0x80}, 9, []byte{0x01, 0x01}}, + {[]byte{}, 0, []byte{}}, + {[]byte{0xce}, 8, []byte{0xce}}, + {[]byte{0xce, 0x47}, 16, []byte{0xce, 0x47}}, + {[]byte{0x34, 0x50}, 12, []byte{0x03, 0x45}}, +} + +func TestBitStringRightAlign(t *testing.T) { + for i, test := range bitStringRightAlignTests { + bs := BitString{test.in, test.inlen} + out := bs.RightAlign() + if !bytes.Equal(out, test.out) { + t.Errorf("#%d got: %x want: %x", i, out, test.out) + } + } +} + +type objectIdentifierTest struct { + in []byte + ok bool + out []int +} + +var objectIdentifierTestData = []objectIdentifierTest{ + {[]byte{}, false, []int{}}, + {[]byte{85}, true, []int{2, 5}}, + {[]byte{85, 0x02}, true, []int{2, 5, 2}}, + {[]byte{85, 0x02, 0xc0, 0x00}, true, []int{2, 5, 2, 0x2000}}, + {[]byte{0x81, 0x34, 0x03}, true, []int{2, 100, 3}}, + {[]byte{85, 0x02, 0xc0, 0x80, 0x80, 0x80, 0x80}, false, []int{}}, +} + +func TestObjectIdentifier(t *testing.T) { + for i, test := range objectIdentifierTestData { + ret, err := parseObjectIdentifier(test.in) + if (err == nil) != test.ok { + t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok) + } + if err == nil { + if !reflect.DeepEqual(test.out, ret) { + t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out) + } + } + } + + if s := ObjectIdentifier([]int{1, 2, 3, 4}).String(); s != "1.2.3.4" { + t.Errorf("bad ObjectIdentifier.String(). Got %s, want 1.2.3.4", s) + } +} + +type timeTest struct { + in string + ok bool + out time.Time +} + +var utcTestData = []timeTest{ + {"910506164540-0700", true, time.Date(1991, 05, 06, 16, 45, 40, 0, time.FixedZone("", -7*60*60))}, + {"910506164540+0730", true, time.Date(1991, 05, 06, 16, 45, 40, 0, time.FixedZone("", 7*60*60+30*60))}, + {"910506234540Z", true, time.Date(1991, 05, 06, 23, 45, 40, 0, time.UTC)}, + {"9105062345Z", true, time.Date(1991, 05, 06, 23, 45, 0, 0, time.UTC)}, + {"5105062345Z", true, time.Date(1951, 05, 06, 23, 45, 0, 0, time.UTC)}, + {"a10506234540Z", false, time.Time{}}, + {"91a506234540Z", false, time.Time{}}, + {"9105a6234540Z", false, time.Time{}}, + {"910506a34540Z", false, time.Time{}}, + {"910506334a40Z", false, time.Time{}}, + {"91050633444aZ", false, time.Time{}}, + {"910506334461Z", false, time.Time{}}, + {"910506334400Za", false, time.Time{}}, +} + +func TestUTCTime(t *testing.T) { + for i, test := range utcTestData { + ret, err := parseUTCTime([]byte(test.in)) + if err != nil { + if test.ok { + t.Errorf("#%d: parseUTCTime(%q) = error %v", i, test.in, err) + } + continue + } + if !test.ok { + t.Errorf("#%d: parseUTCTime(%q) succeeded, should have failed", i, test.in) + continue + } + const format = "Jan _2 15:04:05 -0700 2006" // ignore zone name, just offset + have := ret.Format(format) + want := test.out.Format(format) + if have != want { + t.Errorf("#%d: parseUTCTime(%q) = %s, want %s", i, test.in, have, want) + } + } +} + +var generalizedTimeTestData = []timeTest{ + {"20100102030405Z", true, time.Date(2010, 01, 02, 03, 04, 05, 0, time.UTC)}, + {"20100102030405", false, time.Time{}}, + {"20100102030405+0607", true, time.Date(2010, 01, 02, 03, 04, 05, 0, time.FixedZone("", 6*60*60+7*60))}, + {"20100102030405-0607", true, time.Date(2010, 01, 02, 03, 04, 05, 0, time.FixedZone("", -6*60*60-7*60))}, +} + +func TestGeneralizedTime(t *testing.T) { + for i, test := range generalizedTimeTestData { + ret, err := parseGeneralizedTime([]byte(test.in)) + if (err == nil) != test.ok { + t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok) + } + if err == nil { + if !reflect.DeepEqual(test.out, ret) { + t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out) + } + } + } +} + +type tagAndLengthTest struct { + in []byte + ok bool + out tagAndLength +} + +var tagAndLengthData = []tagAndLengthTest{ + {[]byte{0x80, 0x01}, true, tagAndLength{2, 0, 1, false}}, + {[]byte{0xa0, 0x01}, true, tagAndLength{2, 0, 1, true}}, + {[]byte{0x02, 0x00}, true, tagAndLength{0, 2, 0, false}}, + {[]byte{0xfe, 0x00}, true, tagAndLength{3, 30, 0, true}}, + {[]byte{0x1f, 0x01, 0x00}, true, tagAndLength{0, 1, 0, false}}, + {[]byte{0x1f, 0x81, 0x00, 0x00}, true, tagAndLength{0, 128, 0, false}}, + {[]byte{0x1f, 0x81, 0x80, 0x01, 0x00}, true, tagAndLength{0, 0x4001, 0, false}}, + {[]byte{0x00, 0x81, 0x01}, true, tagAndLength{0, 0, 1, false}}, + {[]byte{0x00, 0x82, 0x01, 0x00}, true, tagAndLength{0, 0, 256, false}}, + {[]byte{0x00, 0x83, 0x01, 0x00}, false, tagAndLength{}}, + {[]byte{0x1f, 0x85}, false, tagAndLength{}}, + {[]byte{0x30, 0x80}, false, tagAndLength{}}, + // Superfluous zeros in the length should be an error. + {[]byte{0xa0, 0x82, 0x00, 0x01}, false, tagAndLength{}}, + // Lengths up to the maximum size of an int should work. + {[]byte{0xa0, 0x84, 0x7f, 0xff, 0xff, 0xff}, true, tagAndLength{2, 0, 0x7fffffff, true}}, + // Lengths that would overflow an int should be rejected. + {[]byte{0xa0, 0x84, 0x80, 0x00, 0x00, 0x00}, false, tagAndLength{}}, +} + +func TestParseTagAndLength(t *testing.T) { + for i, test := range tagAndLengthData { + tagAndLength, _, err := parseTagAndLength(test.in, 0) + if (err == nil) != test.ok { + t.Errorf("#%d: Incorrect error result (did pass? %v, expected: %v)", i, err == nil, test.ok) + } + if err == nil && !reflect.DeepEqual(test.out, tagAndLength) { + t.Errorf("#%d: Bad result: %v (expected %v)", i, tagAndLength, test.out) + } + } +} + +type parseFieldParametersTest struct { + in string + out fieldParameters +} + +func newInt(n int) *int { return &n } + +func newInt64(n int64) *int64 { return &n } + +func newString(s string) *string { return &s } + +func newBool(b bool) *bool { return &b } + +var parseFieldParametersTestData []parseFieldParametersTest = []parseFieldParametersTest{ + {"", fieldParameters{}}, + {"ia5", fieldParameters{stringType: tagIA5String}}, + {"printable", fieldParameters{stringType: tagPrintableString}}, + {"optional", fieldParameters{optional: true}}, + {"explicit", fieldParameters{explicit: true, tag: new(int)}}, + {"application", fieldParameters{application: true, tag: new(int)}}, + {"optional,explicit", fieldParameters{optional: true, explicit: true, tag: new(int)}}, + {"default:42", fieldParameters{defaultValue: newInt64(42)}}, + {"tag:17", fieldParameters{tag: newInt(17)}}, + {"optional,explicit,default:42,tag:17", fieldParameters{optional: true, explicit: true, defaultValue: newInt64(42), tag: newInt(17)}}, + {"optional,explicit,default:42,tag:17,rubbish1", fieldParameters{true, true, false, newInt64(42), newInt(17), 0, false, false}}, + {"set", fieldParameters{set: true}}, +} + +func TestParseFieldParameters(t *testing.T) { + for i, test := range parseFieldParametersTestData { + f := parseFieldParameters(test.in) + if !reflect.DeepEqual(f, test.out) { + t.Errorf("#%d: Bad result: %v (expected %v)", i, f, test.out) + } + } +} + +type TestObjectIdentifierStruct struct { + OID ObjectIdentifier +} + +type TestContextSpecificTags struct { + A int `asn1:"tag:1"` +} + +type TestContextSpecificTags2 struct { + A int `asn1:"explicit,tag:1"` + B int +} + +type TestContextSpecificTags3 struct { + S string `asn1:"tag:1,utf8"` +} + +type TestElementsAfterString struct { + S string + A, B int +} + +type TestBigInt struct { + X *big.Int +} + +type TestSet struct { + Ints []int `asn1:"set"` +} + +var unmarshalTestData = []struct { + in []byte + out interface{} +}{ + {[]byte{0x02, 0x01, 0x42}, newInt(0x42)}, + {[]byte{0x30, 0x08, 0x06, 0x06, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d}, &TestObjectIdentifierStruct{[]int{1, 2, 840, 113549}}}, + {[]byte{0x03, 0x04, 0x06, 0x6e, 0x5d, 0xc0}, &BitString{[]byte{110, 93, 192}, 18}}, + {[]byte{0x30, 0x09, 0x02, 0x01, 0x01, 0x02, 0x01, 0x02, 0x02, 0x01, 0x03}, &[]int{1, 2, 3}}, + {[]byte{0x02, 0x01, 0x10}, newInt(16)}, + {[]byte{0x13, 0x04, 't', 'e', 's', 't'}, newString("test")}, + {[]byte{0x16, 0x04, 't', 'e', 's', 't'}, newString("test")}, + {[]byte{0x16, 0x04, 't', 'e', 's', 't'}, &RawValue{0, 22, false, []byte("test"), []byte("\x16\x04test")}}, + {[]byte{0x04, 0x04, 1, 2, 3, 4}, &RawValue{0, 4, false, []byte{1, 2, 3, 4}, []byte{4, 4, 1, 2, 3, 4}}}, + {[]byte{0x30, 0x03, 0x81, 0x01, 0x01}, &TestContextSpecificTags{1}}, + {[]byte{0x30, 0x08, 0xa1, 0x03, 0x02, 0x01, 0x01, 0x02, 0x01, 0x02}, &TestContextSpecificTags2{1, 2}}, + {[]byte{0x30, 0x03, 0x81, 0x01, '@'}, &TestContextSpecificTags3{"@"}}, + {[]byte{0x01, 0x01, 0x00}, newBool(false)}, + {[]byte{0x01, 0x01, 0xff}, newBool(true)}, + {[]byte{0x30, 0x0b, 0x13, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x01, 0x22, 0x02, 0x01, 0x33}, &TestElementsAfterString{"foo", 0x22, 0x33}}, + {[]byte{0x30, 0x05, 0x02, 0x03, 0x12, 0x34, 0x56}, &TestBigInt{big.NewInt(0x123456)}}, + {[]byte{0x30, 0x0b, 0x31, 0x09, 0x02, 0x01, 0x01, 0x02, 0x01, 0x02, 0x02, 0x01, 0x03}, &TestSet{Ints: []int{1, 2, 3}}}, +} + +func TestUnmarshal(t *testing.T) { + for i, test := range unmarshalTestData { + pv := reflect.New(reflect.TypeOf(test.out).Elem()) + val := pv.Interface() + _, err := Unmarshal(test.in, val) + if err != nil { + t.Errorf("Unmarshal failed at index %d %v", i, err) + } + if !reflect.DeepEqual(val, test.out) { + t.Errorf("#%d:\nhave %#v\nwant %#v", i, val, test.out) + } + } +} + +type Certificate struct { + TBSCertificate TBSCertificate + SignatureAlgorithm AlgorithmIdentifier + SignatureValue BitString +} + +type TBSCertificate struct { + Version int `asn1:"optional,explicit,default:0,tag:0"` + SerialNumber RawValue + SignatureAlgorithm AlgorithmIdentifier + Issuer RDNSequence + Validity Validity + Subject RDNSequence + PublicKey PublicKeyInfo +} + +type AlgorithmIdentifier struct { + Algorithm ObjectIdentifier +} + +type RDNSequence []RelativeDistinguishedNameSET + +type RelativeDistinguishedNameSET []AttributeTypeAndValue + +type AttributeTypeAndValue struct { + Type ObjectIdentifier + Value interface{} +} + +type Validity struct { + NotBefore, NotAfter time.Time +} + +type PublicKeyInfo struct { + Algorithm AlgorithmIdentifier + PublicKey BitString +} + +func TestCertificate(t *testing.T) { + // This is a minimal, self-signed certificate that should parse correctly. + var cert Certificate + if _, err := Unmarshal(derEncodedSelfSignedCertBytes, &cert); err != nil { + t.Errorf("Unmarshal failed: %v", err) + } + if !reflect.DeepEqual(cert, derEncodedSelfSignedCert) { + t.Errorf("Bad result:\ngot: %+v\nwant: %+v", cert, derEncodedSelfSignedCert) + } +} + +func TestCertificateWithNUL(t *testing.T) { + // This is the paypal NUL-hack certificate. It should fail to parse because + // NUL isn't a permitted character in a PrintableString. + + var cert Certificate + if _, err := Unmarshal(derEncodedPaypalNULCertBytes, &cert); err == nil { + t.Error("Unmarshal succeeded, should not have") + } +} + +type rawStructTest struct { + Raw RawContent + A int +} + +func TestRawStructs(t *testing.T) { + var s rawStructTest + input := []byte{0x30, 0x03, 0x02, 0x01, 0x50} + + rest, err := Unmarshal(input, &s) + if len(rest) != 0 { + t.Errorf("incomplete parse: %x", rest) + return + } + if err != nil { + t.Error(err) + return + } + if s.A != 0x50 { + t.Errorf("bad value for A: got %d want %d", s.A, 0x50) + } + if !bytes.Equal([]byte(s.Raw), input) { + t.Errorf("bad value for Raw: got %x want %x", s.Raw, input) + } +} + +type oiEqualTest struct { + first ObjectIdentifier + second ObjectIdentifier + same bool +} + +var oiEqualTests = []oiEqualTest{ + { + ObjectIdentifier{1, 2, 3}, + ObjectIdentifier{1, 2, 3}, + true, + }, + { + ObjectIdentifier{1}, + ObjectIdentifier{1, 2, 3}, + false, + }, + { + ObjectIdentifier{1, 2, 3}, + ObjectIdentifier{10, 11, 12}, + false, + }, +} + +func TestObjectIdentifierEqual(t *testing.T) { + for _, o := range oiEqualTests { + if s := o.first.Equal(o.second); s != o.same { + t.Errorf("ObjectIdentifier.Equal: got: %t want: %t", s, o.same) + } + } +} + +var derEncodedSelfSignedCert = Certificate{ + TBSCertificate: TBSCertificate{ + Version: 0, + SerialNumber: RawValue{Class: 0, Tag: 2, IsCompound: false, Bytes: []uint8{0x0, 0x8c, 0xc3, 0x37, 0x92, 0x10, 0xec, 0x2c, 0x98}, FullBytes: []byte{2, 9, 0x0, 0x8c, 0xc3, 0x37, 0x92, 0x10, 0xec, 0x2c, 0x98}}, + SignatureAlgorithm: AlgorithmIdentifier{Algorithm: ObjectIdentifier{1, 2, 840, 113549, 1, 1, 5}}, + Issuer: RDNSequence{ + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 6}, Value: "XX"}}, + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 8}, Value: "Some-State"}}, + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 7}, Value: "City"}}, + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 10}, Value: "Internet Widgits Pty Ltd"}}, + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 3}, Value: "false.example.com"}}, + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{1, 2, 840, 113549, 1, 9, 1}, Value: "false@example.com"}}, + }, + Validity: Validity{ + NotBefore: time.Date(2009, 10, 8, 00, 25, 53, 0, time.UTC), + NotAfter: time.Date(2010, 10, 8, 00, 25, 53, 0, time.UTC), + }, + Subject: RDNSequence{ + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 6}, Value: "XX"}}, + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 8}, Value: "Some-State"}}, + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 7}, Value: "City"}}, + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 10}, Value: "Internet Widgits Pty Ltd"}}, + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 3}, Value: "false.example.com"}}, + RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{1, 2, 840, 113549, 1, 9, 1}, Value: "false@example.com"}}, + }, + PublicKey: PublicKeyInfo{ + Algorithm: AlgorithmIdentifier{Algorithm: ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1}}, + PublicKey: BitString{ + Bytes: []uint8{ + 0x30, 0x48, 0x2, 0x41, 0x0, 0xcd, 0xb7, + 0x63, 0x9c, 0x32, 0x78, 0xf0, 0x6, 0xaa, 0x27, 0x7f, 0x6e, 0xaf, 0x42, + 0x90, 0x2b, 0x59, 0x2d, 0x8c, 0xbc, 0xbe, 0x38, 0xa1, 0xc9, 0x2b, 0xa4, + 0x69, 0x5a, 0x33, 0x1b, 0x1d, 0xea, 0xde, 0xad, 0xd8, 0xe9, 0xa5, 0xc2, + 0x7e, 0x8c, 0x4c, 0x2f, 0xd0, 0xa8, 0x88, 0x96, 0x57, 0x72, 0x2a, 0x4f, + 0x2a, 0xf7, 0x58, 0x9c, 0xf2, 0xc7, 0x70, 0x45, 0xdc, 0x8f, 0xde, 0xec, + 0x35, 0x7d, 0x2, 0x3, 0x1, 0x0, 0x1, + }, + BitLength: 592, + }, + }, + }, + SignatureAlgorithm: AlgorithmIdentifier{Algorithm: ObjectIdentifier{1, 2, 840, 113549, 1, 1, 5}}, + SignatureValue: BitString{ + Bytes: []uint8{ + 0xa6, 0x7b, 0x6, 0xec, 0x5e, 0xce, + 0x92, 0x77, 0x2c, 0xa4, 0x13, 0xcb, 0xa3, 0xca, 0x12, 0x56, 0x8f, 0xdc, 0x6c, + 0x7b, 0x45, 0x11, 0xcd, 0x40, 0xa7, 0xf6, 0x59, 0x98, 0x4, 0x2, 0xdf, 0x2b, + 0x99, 0x8b, 0xb9, 0xa4, 0xa8, 0xcb, 0xeb, 0x34, 0xc0, 0xf0, 0xa7, 0x8c, 0xf8, + 0xd9, 0x1e, 0xde, 0x14, 0xa5, 0xed, 0x76, 0xbf, 0x11, 0x6f, 0xe3, 0x60, 0xaa, + 0xfa, 0x88, 0x21, 0x49, 0x4, 0x35, + }, + BitLength: 512, + }, +} + +var derEncodedSelfSignedCertBytes = []byte{ + 0x30, 0x82, 0x02, 0x18, 0x30, + 0x82, 0x01, 0xc2, 0x02, 0x09, 0x00, 0x8c, 0xc3, 0x37, 0x92, 0x10, 0xec, 0x2c, + 0x98, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, + 0x05, 0x05, 0x00, 0x30, 0x81, 0x92, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, + 0x04, 0x06, 0x13, 0x02, 0x58, 0x58, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, + 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x04, 0x43, + 0x69, 0x74, 0x79, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, + 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, + 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x31, + 0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x66, 0x61, 0x6c, + 0x73, 0x65, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, + 0x6d, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, + 0x01, 0x09, 0x01, 0x16, 0x11, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x40, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x30, 0x1e, 0x17, 0x0d, + 0x30, 0x39, 0x31, 0x30, 0x30, 0x38, 0x30, 0x30, 0x32, 0x35, 0x35, 0x33, 0x5a, + 0x17, 0x0d, 0x31, 0x30, 0x31, 0x30, 0x30, 0x38, 0x30, 0x30, 0x32, 0x35, 0x35, + 0x33, 0x5a, 0x30, 0x81, 0x92, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, + 0x06, 0x13, 0x02, 0x58, 0x58, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, + 0x08, 0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x04, 0x43, 0x69, + 0x74, 0x79, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, + 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67, + 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x31, 0x1a, + 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x66, 0x61, 0x6c, 0x73, + 0x65, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, + 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, + 0x09, 0x01, 0x16, 0x11, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x40, 0x65, 0x78, 0x61, + 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x30, 0x5c, 0x30, 0x0d, 0x06, + 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, + 0x4b, 0x00, 0x30, 0x48, 0x02, 0x41, 0x00, 0xcd, 0xb7, 0x63, 0x9c, 0x32, 0x78, + 0xf0, 0x06, 0xaa, 0x27, 0x7f, 0x6e, 0xaf, 0x42, 0x90, 0x2b, 0x59, 0x2d, 0x8c, + 0xbc, 0xbe, 0x38, 0xa1, 0xc9, 0x2b, 0xa4, 0x69, 0x5a, 0x33, 0x1b, 0x1d, 0xea, + 0xde, 0xad, 0xd8, 0xe9, 0xa5, 0xc2, 0x7e, 0x8c, 0x4c, 0x2f, 0xd0, 0xa8, 0x88, + 0x96, 0x57, 0x72, 0x2a, 0x4f, 0x2a, 0xf7, 0x58, 0x9c, 0xf2, 0xc7, 0x70, 0x45, + 0xdc, 0x8f, 0xde, 0xec, 0x35, 0x7d, 0x02, 0x03, 0x01, 0x00, 0x01, 0x30, 0x0d, + 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, + 0x03, 0x41, 0x00, 0xa6, 0x7b, 0x06, 0xec, 0x5e, 0xce, 0x92, 0x77, 0x2c, 0xa4, + 0x13, 0xcb, 0xa3, 0xca, 0x12, 0x56, 0x8f, 0xdc, 0x6c, 0x7b, 0x45, 0x11, 0xcd, + 0x40, 0xa7, 0xf6, 0x59, 0x98, 0x04, 0x02, 0xdf, 0x2b, 0x99, 0x8b, 0xb9, 0xa4, + 0xa8, 0xcb, 0xeb, 0x34, 0xc0, 0xf0, 0xa7, 0x8c, 0xf8, 0xd9, 0x1e, 0xde, 0x14, + 0xa5, 0xed, 0x76, 0xbf, 0x11, 0x6f, 0xe3, 0x60, 0xaa, 0xfa, 0x88, 0x21, 0x49, + 0x04, 0x35, +} + +var derEncodedPaypalNULCertBytes = []byte{ + 0x30, 0x82, 0x06, 0x44, 0x30, + 0x82, 0x05, 0xad, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x03, 0x00, 0xf0, 0x9b, + 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, + 0x05, 0x00, 0x30, 0x82, 0x01, 0x12, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, + 0x04, 0x06, 0x13, 0x02, 0x45, 0x53, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, + 0x04, 0x08, 0x13, 0x09, 0x42, 0x61, 0x72, 0x63, 0x65, 0x6c, 0x6f, 0x6e, 0x61, + 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x09, 0x42, 0x61, + 0x72, 0x63, 0x65, 0x6c, 0x6f, 0x6e, 0x61, 0x31, 0x29, 0x30, 0x27, 0x06, 0x03, + 0x55, 0x04, 0x0a, 0x13, 0x20, 0x49, 0x50, 0x53, 0x20, 0x43, 0x65, 0x72, 0x74, + 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x74, 0x79, 0x20, 0x73, 0x2e, 0x6c, 0x2e, 0x31, 0x2e, + 0x30, 0x2c, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x14, 0x25, 0x67, 0x65, 0x6e, 0x65, + 0x72, 0x61, 0x6c, 0x40, 0x69, 0x70, 0x73, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, + 0x20, 0x43, 0x2e, 0x49, 0x2e, 0x46, 0x2e, 0x20, 0x20, 0x42, 0x2d, 0x42, 0x36, + 0x32, 0x32, 0x31, 0x30, 0x36, 0x39, 0x35, 0x31, 0x2e, 0x30, 0x2c, 0x06, 0x03, + 0x55, 0x04, 0x0b, 0x13, 0x25, 0x69, 0x70, 0x73, 0x43, 0x41, 0x20, 0x43, 0x4c, + 0x41, 0x53, 0x45, 0x41, 0x31, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x74, 0x79, 0x31, 0x2e, 0x30, 0x2c, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, + 0x25, 0x69, 0x70, 0x73, 0x43, 0x41, 0x20, 0x43, 0x4c, 0x41, 0x53, 0x45, 0x41, + 0x31, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x79, 0x31, + 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x09, + 0x01, 0x16, 0x11, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x40, 0x69, 0x70, + 0x73, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x30, 0x1e, 0x17, 0x0d, 0x30, 0x39, + 0x30, 0x32, 0x32, 0x34, 0x32, 0x33, 0x30, 0x34, 0x31, 0x37, 0x5a, 0x17, 0x0d, + 0x31, 0x31, 0x30, 0x32, 0x32, 0x34, 0x32, 0x33, 0x30, 0x34, 0x31, 0x37, 0x5a, + 0x30, 0x81, 0x94, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, + 0x02, 0x55, 0x53, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13, + 0x0a, 0x43, 0x61, 0x6c, 0x69, 0x66, 0x6f, 0x72, 0x6e, 0x69, 0x61, 0x31, 0x16, + 0x30, 0x14, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x0d, 0x53, 0x61, 0x6e, 0x20, + 0x46, 0x72, 0x61, 0x6e, 0x63, 0x69, 0x73, 0x63, 0x6f, 0x31, 0x11, 0x30, 0x0f, + 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x08, 0x53, 0x65, 0x63, 0x75, 0x72, 0x69, + 0x74, 0x79, 0x31, 0x14, 0x30, 0x12, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x0b, + 0x53, 0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x55, 0x6e, 0x69, 0x74, 0x31, 0x2f, + 0x30, 0x2d, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x26, 0x77, 0x77, 0x77, 0x2e, + 0x70, 0x61, 0x79, 0x70, 0x61, 0x6c, 0x2e, 0x63, 0x6f, 0x6d, 0x00, 0x73, 0x73, + 0x6c, 0x2e, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x63, 0x6f, 0x6e, 0x6e, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x63, 0x63, 0x30, 0x81, 0x9f, 0x30, 0x0d, + 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, + 0x03, 0x81, 0x8d, 0x00, 0x30, 0x81, 0x89, 0x02, 0x81, 0x81, 0x00, 0xd2, 0x69, + 0xfa, 0x6f, 0x3a, 0x00, 0xb4, 0x21, 0x1b, 0xc8, 0xb1, 0x02, 0xd7, 0x3f, 0x19, + 0xb2, 0xc4, 0x6d, 0xb4, 0x54, 0xf8, 0x8b, 0x8a, 0xcc, 0xdb, 0x72, 0xc2, 0x9e, + 0x3c, 0x60, 0xb9, 0xc6, 0x91, 0x3d, 0x82, 0xb7, 0x7d, 0x99, 0xff, 0xd1, 0x29, + 0x84, 0xc1, 0x73, 0x53, 0x9c, 0x82, 0xdd, 0xfc, 0x24, 0x8c, 0x77, 0xd5, 0x41, + 0xf3, 0xe8, 0x1e, 0x42, 0xa1, 0xad, 0x2d, 0x9e, 0xff, 0x5b, 0x10, 0x26, 0xce, + 0x9d, 0x57, 0x17, 0x73, 0x16, 0x23, 0x38, 0xc8, 0xd6, 0xf1, 0xba, 0xa3, 0x96, + 0x5b, 0x16, 0x67, 0x4a, 0x4f, 0x73, 0x97, 0x3a, 0x4d, 0x14, 0xa4, 0xf4, 0xe2, + 0x3f, 0x8b, 0x05, 0x83, 0x42, 0xd1, 0xd0, 0xdc, 0x2f, 0x7a, 0xe5, 0xb6, 0x10, + 0xb2, 0x11, 0xc0, 0xdc, 0x21, 0x2a, 0x90, 0xff, 0xae, 0x97, 0x71, 0x5a, 0x49, + 0x81, 0xac, 0x40, 0xf3, 0x3b, 0xb8, 0x59, 0xb2, 0x4f, 0x02, 0x03, 0x01, 0x00, + 0x01, 0xa3, 0x82, 0x03, 0x21, 0x30, 0x82, 0x03, 0x1d, 0x30, 0x09, 0x06, 0x03, + 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x11, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x01, 0x04, 0x04, 0x03, 0x02, 0x06, 0x40, + 0x30, 0x0b, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x04, 0x04, 0x03, 0x02, 0x03, 0xf8, + 0x30, 0x13, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x0c, 0x30, 0x0a, 0x06, 0x08, + 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x1d, 0x06, 0x03, 0x55, + 0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0x61, 0x8f, 0x61, 0x34, 0x43, 0x55, 0x14, + 0x7f, 0x27, 0x09, 0xce, 0x4c, 0x8b, 0xea, 0x9b, 0x7b, 0x19, 0x25, 0xbc, 0x6e, + 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, + 0x0e, 0x07, 0x60, 0xd4, 0x39, 0xc9, 0x1b, 0x5b, 0x5d, 0x90, 0x7b, 0x23, 0xc8, + 0xd2, 0x34, 0x9d, 0x4a, 0x9a, 0x46, 0x39, 0x30, 0x09, 0x06, 0x03, 0x55, 0x1d, + 0x11, 0x04, 0x02, 0x30, 0x00, 0x30, 0x1c, 0x06, 0x03, 0x55, 0x1d, 0x12, 0x04, + 0x15, 0x30, 0x13, 0x81, 0x11, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x40, + 0x69, 0x70, 0x73, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x30, 0x72, 0x06, 0x09, + 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x0d, 0x04, 0x65, 0x16, 0x63, + 0x4f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, + 0x49, 0x6e, 0x66, 0x6f, 0x72, 0x6d, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x4e, + 0x4f, 0x54, 0x20, 0x56, 0x41, 0x4c, 0x49, 0x44, 0x41, 0x54, 0x45, 0x44, 0x2e, + 0x20, 0x43, 0x4c, 0x41, 0x53, 0x45, 0x41, 0x31, 0x20, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, + 0x65, 0x20, 0x69, 0x73, 0x73, 0x75, 0x65, 0x64, 0x20, 0x62, 0x79, 0x20, 0x68, + 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x69, 0x70, + 0x73, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x30, 0x2f, 0x06, 0x09, 0x60, + 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x02, 0x04, 0x22, 0x16, 0x20, 0x68, + 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x69, 0x70, + 0x73, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x69, 0x70, 0x73, 0x63, 0x61, + 0x32, 0x30, 0x30, 0x32, 0x2f, 0x30, 0x43, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, + 0x86, 0xf8, 0x42, 0x01, 0x04, 0x04, 0x36, 0x16, 0x34, 0x68, 0x74, 0x74, 0x70, + 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x69, 0x70, 0x73, 0x63, 0x61, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x69, 0x70, 0x73, 0x63, 0x61, 0x32, 0x30, 0x30, + 0x32, 0x2f, 0x69, 0x70, 0x73, 0x63, 0x61, 0x32, 0x30, 0x30, 0x32, 0x43, 0x4c, + 0x41, 0x53, 0x45, 0x41, 0x31, 0x2e, 0x63, 0x72, 0x6c, 0x30, 0x46, 0x06, 0x09, + 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x03, 0x04, 0x39, 0x16, 0x37, + 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x69, + 0x70, 0x73, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x69, 0x70, 0x73, 0x63, + 0x61, 0x32, 0x30, 0x30, 0x32, 0x2f, 0x72, 0x65, 0x76, 0x6f, 0x63, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x43, 0x4c, 0x41, 0x53, 0x45, 0x41, 0x31, 0x2e, 0x68, 0x74, + 0x6d, 0x6c, 0x3f, 0x30, 0x43, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, + 0x42, 0x01, 0x07, 0x04, 0x36, 0x16, 0x34, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, + 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x69, 0x70, 0x73, 0x63, 0x61, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x69, 0x70, 0x73, 0x63, 0x61, 0x32, 0x30, 0x30, 0x32, 0x2f, + 0x72, 0x65, 0x6e, 0x65, 0x77, 0x61, 0x6c, 0x43, 0x4c, 0x41, 0x53, 0x45, 0x41, + 0x31, 0x2e, 0x68, 0x74, 0x6d, 0x6c, 0x3f, 0x30, 0x41, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x08, 0x04, 0x34, 0x16, 0x32, 0x68, 0x74, + 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x69, 0x70, 0x73, + 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x69, 0x70, 0x73, 0x63, 0x61, 0x32, + 0x30, 0x30, 0x32, 0x2f, 0x70, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x43, 0x4c, 0x41, + 0x53, 0x45, 0x41, 0x31, 0x2e, 0x68, 0x74, 0x6d, 0x6c, 0x30, 0x81, 0x83, 0x06, + 0x03, 0x55, 0x1d, 0x1f, 0x04, 0x7c, 0x30, 0x7a, 0x30, 0x39, 0xa0, 0x37, 0xa0, + 0x35, 0x86, 0x33, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, + 0x2e, 0x69, 0x70, 0x73, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x69, 0x70, + 0x73, 0x63, 0x61, 0x32, 0x30, 0x30, 0x32, 0x2f, 0x69, 0x70, 0x73, 0x63, 0x61, + 0x32, 0x30, 0x30, 0x32, 0x43, 0x4c, 0x41, 0x53, 0x45, 0x41, 0x31, 0x2e, 0x63, + 0x72, 0x6c, 0x30, 0x3d, 0xa0, 0x3b, 0xa0, 0x39, 0x86, 0x37, 0x68, 0x74, 0x74, + 0x70, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x62, 0x61, 0x63, 0x6b, 0x2e, 0x69, + 0x70, 0x73, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x69, 0x70, 0x73, 0x63, + 0x61, 0x32, 0x30, 0x30, 0x32, 0x2f, 0x69, 0x70, 0x73, 0x63, 0x61, 0x32, 0x30, + 0x30, 0x32, 0x43, 0x4c, 0x41, 0x53, 0x45, 0x41, 0x31, 0x2e, 0x63, 0x72, 0x6c, + 0x30, 0x32, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x01, 0x04, + 0x26, 0x30, 0x24, 0x30, 0x22, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, + 0x30, 0x01, 0x86, 0x16, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, + 0x73, 0x70, 0x2e, 0x69, 0x70, 0x73, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, + 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, + 0x05, 0x00, 0x03, 0x81, 0x81, 0x00, 0x68, 0xee, 0x79, 0x97, 0x97, 0xdd, 0x3b, + 0xef, 0x16, 0x6a, 0x06, 0xf2, 0x14, 0x9a, 0x6e, 0xcd, 0x9e, 0x12, 0xf7, 0xaa, + 0x83, 0x10, 0xbd, 0xd1, 0x7c, 0x98, 0xfa, 0xc7, 0xae, 0xd4, 0x0e, 0x2c, 0x9e, + 0x38, 0x05, 0x9d, 0x52, 0x60, 0xa9, 0x99, 0x0a, 0x81, 0xb4, 0x98, 0x90, 0x1d, + 0xae, 0xbb, 0x4a, 0xd7, 0xb9, 0xdc, 0x88, 0x9e, 0x37, 0x78, 0x41, 0x5b, 0xf7, + 0x82, 0xa5, 0xf2, 0xba, 0x41, 0x25, 0x5a, 0x90, 0x1a, 0x1e, 0x45, 0x38, 0xa1, + 0x52, 0x58, 0x75, 0x94, 0x26, 0x44, 0xfb, 0x20, 0x07, 0xba, 0x44, 0xcc, 0xe5, + 0x4a, 0x2d, 0x72, 0x3f, 0x98, 0x47, 0xf6, 0x26, 0xdc, 0x05, 0x46, 0x05, 0x07, + 0x63, 0x21, 0xab, 0x46, 0x9b, 0x9c, 0x78, 0xd5, 0x54, 0x5b, 0x3d, 0x0c, 0x1e, + 0xc8, 0x64, 0x8c, 0xb5, 0x50, 0x23, 0x82, 0x6f, 0xdb, 0xb8, 0x22, 0x1c, 0x43, + 0x96, 0x07, 0xa8, 0xbb, +} + +var stringSliceTestData = [][]string{ + {"foo", "bar"}, + {"foo", "\\bar"}, + {"foo", "\"bar\""}, + {"foo", "åäö"}, +} + +func TestStringSlice(t *testing.T) { + for _, test := range stringSliceTestData { + bs, err := Marshal(test) + if err != nil { + t.Error(err) + } + + var res []string + _, err = Unmarshal(bs, &res) + if err != nil { + t.Error(err) + } + + if fmt.Sprintf("%v", res) != fmt.Sprintf("%v", test) { + t.Errorf("incorrect marshal/unmarshal; %v != %v", res, test) + } + } +} + +type explicitTaggedTimeTest struct { + Time time.Time `asn1:"explicit,tag:0"` +} + +var explicitTaggedTimeTestData = []struct { + in []byte + out explicitTaggedTimeTest +}{ + {[]byte{0x30, 0x11, 0xa0, 0xf, 0x17, 0xd, '9', '1', '0', '5', '0', '6', '1', '6', '4', '5', '4', '0', 'Z'}, + explicitTaggedTimeTest{time.Date(1991, 05, 06, 16, 45, 40, 0, time.UTC)}}, + {[]byte{0x30, 0x17, 0xa0, 0xf, 0x18, 0x13, '2', '0', '1', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '+', '0', '6', '0', '7'}, + explicitTaggedTimeTest{time.Date(2010, 01, 02, 03, 04, 05, 0, time.FixedZone("", 6*60*60+7*60))}}, +} + +func TestExplicitTaggedTime(t *testing.T) { + // Test that a time.Time will match either tagUTCTime or + // tagGeneralizedTime. + for i, test := range explicitTaggedTimeTestData { + var got explicitTaggedTimeTest + _, err := Unmarshal(test.in, &got) + if err != nil { + t.Errorf("Unmarshal failed at index %d %v", i, err) + } + if !got.Time.Equal(test.out.Time) { + t.Errorf("#%d: got %v, want %v", i, got.Time, test.out.Time) + } + } +} + +type implicitTaggedTimeTest struct { + Time time.Time `asn1:"tag:24"` +} + +func TestImplicitTaggedTime(t *testing.T) { + // An implicitly tagged time value, that happens to have an implicit + // tag equal to a GENERALIZEDTIME, should still be parsed as a UTCTime. + // (There's no "timeType" in fieldParameters to determine what type of + // time should be expected when implicitly tagged.) + der := []byte{0x30, 0x0f, 0x80 | 24, 0xd, '9', '1', '0', '5', '0', '6', '1', '6', '4', '5', '4', '0', 'Z'} + var result implicitTaggedTimeTest + if _, err := Unmarshal(der, &result); err != nil { + t.Fatalf("Error while parsing: %s", err) + } + if expected := time.Date(1991, 05, 06, 16, 45, 40, 0, time.UTC); !result.Time.Equal(expected) { + t.Errorf("Wrong result. Got %v, want %v", result.Time, expected) + } +} diff --git a/src/encoding/asn1/common.go b/src/encoding/asn1/common.go new file mode 100644 index 000000000..33a117ece --- /dev/null +++ b/src/encoding/asn1/common.go @@ -0,0 +1,163 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package asn1 + +import ( + "reflect" + "strconv" + "strings" +) + +// ASN.1 objects have metadata preceding them: +// the tag: the type of the object +// a flag denoting if this object is compound or not +// the class type: the namespace of the tag +// the length of the object, in bytes + +// Here are some standard tags and classes + +const ( + tagBoolean = 1 + tagInteger = 2 + tagBitString = 3 + tagOctetString = 4 + tagOID = 6 + tagEnum = 10 + tagUTF8String = 12 + tagSequence = 16 + tagSet = 17 + tagPrintableString = 19 + tagT61String = 20 + tagIA5String = 22 + tagUTCTime = 23 + tagGeneralizedTime = 24 + tagGeneralString = 27 +) + +const ( + classUniversal = 0 + classApplication = 1 + classContextSpecific = 2 + classPrivate = 3 +) + +type tagAndLength struct { + class, tag, length int + isCompound bool +} + +// ASN.1 has IMPLICIT and EXPLICIT tags, which can be translated as "instead +// of" and "in addition to". When not specified, every primitive type has a +// default tag in the UNIVERSAL class. +// +// For example: a BIT STRING is tagged [UNIVERSAL 3] by default (although ASN.1 +// doesn't actually have a UNIVERSAL keyword). However, by saying [IMPLICIT +// CONTEXT-SPECIFIC 42], that means that the tag is replaced by another. +// +// On the other hand, if it said [EXPLICIT CONTEXT-SPECIFIC 10], then an +// /additional/ tag would wrap the default tag. This explicit tag will have the +// compound flag set. +// +// (This is used in order to remove ambiguity with optional elements.) +// +// You can layer EXPLICIT and IMPLICIT tags to an arbitrary depth, however we +// don't support that here. We support a single layer of EXPLICIT or IMPLICIT +// tagging with tag strings on the fields of a structure. + +// fieldParameters is the parsed representation of tag string from a structure field. +type fieldParameters struct { + optional bool // true iff the field is OPTIONAL + explicit bool // true iff an EXPLICIT tag is in use. + application bool // true iff an APPLICATION tag is in use. + defaultValue *int64 // a default value for INTEGER typed fields (maybe nil). + tag *int // the EXPLICIT or IMPLICIT tag (maybe nil). + stringType int // the string tag to use when marshaling. + set bool // true iff this should be encoded as a SET + omitEmpty bool // true iff this should be omitted if empty when marshaling. + + // Invariants: + // if explicit is set, tag is non-nil. +} + +// Given a tag string with the format specified in the package comment, +// parseFieldParameters will parse it into a fieldParameters structure, +// ignoring unknown parts of the string. +func parseFieldParameters(str string) (ret fieldParameters) { + for _, part := range strings.Split(str, ",") { + switch { + case part == "optional": + ret.optional = true + case part == "explicit": + ret.explicit = true + if ret.tag == nil { + ret.tag = new(int) + } + case part == "ia5": + ret.stringType = tagIA5String + case part == "printable": + ret.stringType = tagPrintableString + case part == "utf8": + ret.stringType = tagUTF8String + case strings.HasPrefix(part, "default:"): + i, err := strconv.ParseInt(part[8:], 10, 64) + if err == nil { + ret.defaultValue = new(int64) + *ret.defaultValue = i + } + case strings.HasPrefix(part, "tag:"): + i, err := strconv.Atoi(part[4:]) + if err == nil { + ret.tag = new(int) + *ret.tag = i + } + case part == "set": + ret.set = true + case part == "application": + ret.application = true + if ret.tag == nil { + ret.tag = new(int) + } + case part == "omitempty": + ret.omitEmpty = true + } + } + return +} + +// Given a reflected Go type, getUniversalType returns the default tag number +// and expected compound flag. +func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) { + switch t { + case objectIdentifierType: + return tagOID, false, true + case bitStringType: + return tagBitString, false, true + case timeType: + return tagUTCTime, false, true + case enumeratedType: + return tagEnum, false, true + case bigIntType: + return tagInteger, false, true + } + switch t.Kind() { + case reflect.Bool: + return tagBoolean, false, true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return tagInteger, false, true + case reflect.Struct: + return tagSequence, true, true + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { + return tagOctetString, false, true + } + if strings.HasSuffix(t.Name(), "SET") { + return tagSet, true, true + } + return tagSequence, true, true + case reflect.String: + return tagPrintableString, false, true + } + return 0, false, false +} diff --git a/src/encoding/asn1/marshal.go b/src/encoding/asn1/marshal.go new file mode 100644 index 000000000..b2f104b4c --- /dev/null +++ b/src/encoding/asn1/marshal.go @@ -0,0 +1,646 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package asn1 + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/big" + "reflect" + "time" + "unicode/utf8" +) + +// A forkableWriter is an in-memory buffer that can be +// 'forked' to create new forkableWriters that bracket the +// original. After +// pre, post := w.fork(); +// the overall sequence of bytes represented is logically w+pre+post. +type forkableWriter struct { + *bytes.Buffer + pre, post *forkableWriter +} + +func newForkableWriter() *forkableWriter { + return &forkableWriter{new(bytes.Buffer), nil, nil} +} + +func (f *forkableWriter) fork() (pre, post *forkableWriter) { + if f.pre != nil || f.post != nil { + panic("have already forked") + } + f.pre = newForkableWriter() + f.post = newForkableWriter() + return f.pre, f.post +} + +func (f *forkableWriter) Len() (l int) { + l += f.Buffer.Len() + if f.pre != nil { + l += f.pre.Len() + } + if f.post != nil { + l += f.post.Len() + } + return +} + +func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) { + n, err = out.Write(f.Bytes()) + if err != nil { + return + } + + var nn int + + if f.pre != nil { + nn, err = f.pre.writeTo(out) + n += nn + if err != nil { + return + } + } + + if f.post != nil { + nn, err = f.post.writeTo(out) + n += nn + } + return +} + +func marshalBase128Int(out *forkableWriter, n int64) (err error) { + if n == 0 { + err = out.WriteByte(0) + return + } + + l := 0 + for i := n; i > 0; i >>= 7 { + l++ + } + + for i := l - 1; i >= 0; i-- { + o := byte(n >> uint(i*7)) + o &= 0x7f + if i != 0 { + o |= 0x80 + } + err = out.WriteByte(o) + if err != nil { + return + } + } + + return nil +} + +func marshalInt64(out *forkableWriter, i int64) (err error) { + n := int64Length(i) + + for ; n > 0; n-- { + err = out.WriteByte(byte(i >> uint((n-1)*8))) + if err != nil { + return + } + } + + return nil +} + +func int64Length(i int64) (numBytes int) { + numBytes = 1 + + for i > 127 { + numBytes++ + i >>= 8 + } + + for i < -128 { + numBytes++ + i >>= 8 + } + + return +} + +func marshalBigInt(out *forkableWriter, n *big.Int) (err error) { + if n.Sign() < 0 { + // A negative number has to be converted to two's-complement + // form. So we'll subtract 1 and invert. If the + // most-significant-bit isn't set then we'll need to pad the + // beginning with 0xff in order to keep the number negative. + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bytes := nMinus1.Bytes() + for i := range bytes { + bytes[i] ^= 0xff + } + if len(bytes) == 0 || bytes[0]&0x80 == 0 { + err = out.WriteByte(0xff) + if err != nil { + return + } + } + _, err = out.Write(bytes) + } else if n.Sign() == 0 { + // Zero is written as a single 0 zero rather than no bytes. + err = out.WriteByte(0x00) + } else { + bytes := n.Bytes() + if len(bytes) > 0 && bytes[0]&0x80 != 0 { + // We'll have to pad this with 0x00 in order to stop it + // looking like a negative number. + err = out.WriteByte(0) + if err != nil { + return + } + } + _, err = out.Write(bytes) + } + return +} + +func marshalLength(out *forkableWriter, i int) (err error) { + n := lengthLength(i) + + for ; n > 0; n-- { + err = out.WriteByte(byte(i >> uint((n-1)*8))) + if err != nil { + return + } + } + + return nil +} + +func lengthLength(i int) (numBytes int) { + numBytes = 1 + for i > 255 { + numBytes++ + i >>= 8 + } + return +} + +func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) { + b := uint8(t.class) << 6 + if t.isCompound { + b |= 0x20 + } + if t.tag >= 31 { + b |= 0x1f + err = out.WriteByte(b) + if err != nil { + return + } + err = marshalBase128Int(out, int64(t.tag)) + if err != nil { + return + } + } else { + b |= uint8(t.tag) + err = out.WriteByte(b) + if err != nil { + return + } + } + + if t.length >= 128 { + l := lengthLength(t.length) + err = out.WriteByte(0x80 | byte(l)) + if err != nil { + return + } + err = marshalLength(out, t.length) + if err != nil { + return + } + } else { + err = out.WriteByte(byte(t.length)) + if err != nil { + return + } + } + + return nil +} + +func marshalBitString(out *forkableWriter, b BitString) (err error) { + paddingBits := byte((8 - b.BitLength%8) % 8) + err = out.WriteByte(paddingBits) + if err != nil { + return + } + _, err = out.Write(b.Bytes) + return +} + +func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) { + if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { + return StructuralError{"invalid object identifier"} + } + + err = marshalBase128Int(out, int64(oid[0]*40+oid[1])) + if err != nil { + return + } + for i := 2; i < len(oid); i++ { + err = marshalBase128Int(out, int64(oid[i])) + if err != nil { + return + } + } + + return +} + +func marshalPrintableString(out *forkableWriter, s string) (err error) { + b := []byte(s) + for _, c := range b { + if !isPrintable(c) { + return StructuralError{"PrintableString contains invalid character"} + } + } + + _, err = out.Write(b) + return +} + +func marshalIA5String(out *forkableWriter, s string) (err error) { + b := []byte(s) + for _, c := range b { + if c > 127 { + return StructuralError{"IA5String contains invalid character"} + } + } + + _, err = out.Write(b) + return +} + +func marshalUTF8String(out *forkableWriter, s string) (err error) { + _, err = out.Write([]byte(s)) + return +} + +func marshalTwoDigits(out *forkableWriter, v int) (err error) { + err = out.WriteByte(byte('0' + (v/10)%10)) + if err != nil { + return + } + return out.WriteByte(byte('0' + v%10)) +} + +func marshalFourDigits(out *forkableWriter, v int) (err error) { + var bytes [4]byte + for i := range bytes { + bytes[3-i] = '0' + byte(v%10) + v /= 10 + } + _, err = out.Write(bytes[:]) + return +} + +func outsideUTCRange(t time.Time) bool { + year := t.Year() + return year < 1950 || year >= 2050 +} + +func marshalUTCTime(out *forkableWriter, t time.Time) (err error) { + year := t.Year() + + switch { + case 1950 <= year && year < 2000: + err = marshalTwoDigits(out, int(year-1900)) + case 2000 <= year && year < 2050: + err = marshalTwoDigits(out, int(year-2000)) + default: + return StructuralError{"cannot represent time as UTCTime"} + } + if err != nil { + return + } + + return marshalTimeCommon(out, t) +} + +func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) { + year := t.Year() + if year < 0 || year > 9999 { + return StructuralError{"cannot represent time as GeneralizedTime"} + } + if err = marshalFourDigits(out, year); err != nil { + return + } + + return marshalTimeCommon(out, t) +} + +func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) { + _, month, day := t.Date() + + err = marshalTwoDigits(out, int(month)) + if err != nil { + return + } + + err = marshalTwoDigits(out, day) + if err != nil { + return + } + + hour, min, sec := t.Clock() + + err = marshalTwoDigits(out, hour) + if err != nil { + return + } + + err = marshalTwoDigits(out, min) + if err != nil { + return + } + + err = marshalTwoDigits(out, sec) + if err != nil { + return + } + + _, offset := t.Zone() + + switch { + case offset/60 == 0: + err = out.WriteByte('Z') + return + case offset > 0: + err = out.WriteByte('+') + case offset < 0: + err = out.WriteByte('-') + } + + if err != nil { + return + } + + offsetMinutes := offset / 60 + if offsetMinutes < 0 { + offsetMinutes = -offsetMinutes + } + + err = marshalTwoDigits(out, offsetMinutes/60) + if err != nil { + return + } + + err = marshalTwoDigits(out, offsetMinutes%60) + return +} + +func stripTagAndLength(in []byte) []byte { + _, offset, err := parseTagAndLength(in, 0) + if err != nil { + return in + } + return in[offset:] +} + +func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) { + switch value.Type() { + case timeType: + t := value.Interface().(time.Time) + if outsideUTCRange(t) { + return marshalGeneralizedTime(out, t) + } else { + return marshalUTCTime(out, t) + } + case bitStringType: + return marshalBitString(out, value.Interface().(BitString)) + case objectIdentifierType: + return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier)) + case bigIntType: + return marshalBigInt(out, value.Interface().(*big.Int)) + } + + switch v := value; v.Kind() { + case reflect.Bool: + if v.Bool() { + return out.WriteByte(255) + } else { + return out.WriteByte(0) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return marshalInt64(out, int64(v.Int())) + case reflect.Struct: + t := v.Type() + + startingField := 0 + + // If the first element of the structure is a non-empty + // RawContents, then we don't bother serializing the rest. + if t.NumField() > 0 && t.Field(0).Type == rawContentsType { + s := v.Field(0) + if s.Len() > 0 { + bytes := make([]byte, s.Len()) + for i := 0; i < s.Len(); i++ { + bytes[i] = uint8(s.Index(i).Uint()) + } + /* The RawContents will contain the tag and + * length fields but we'll also be writing + * those ourselves, so we strip them out of + * bytes */ + _, err = out.Write(stripTagAndLength(bytes)) + return + } else { + startingField = 1 + } + } + + for i := startingField; i < t.NumField(); i++ { + var pre *forkableWriter + pre, out = out.fork() + err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1"))) + if err != nil { + return + } + } + return + case reflect.Slice: + sliceType := v.Type() + if sliceType.Elem().Kind() == reflect.Uint8 { + bytes := make([]byte, v.Len()) + for i := 0; i < v.Len(); i++ { + bytes[i] = uint8(v.Index(i).Uint()) + } + _, err = out.Write(bytes) + return + } + + var fp fieldParameters + for i := 0; i < v.Len(); i++ { + var pre *forkableWriter + pre, out = out.fork() + err = marshalField(pre, v.Index(i), fp) + if err != nil { + return + } + } + return + case reflect.String: + switch params.stringType { + case tagIA5String: + return marshalIA5String(out, v.String()) + case tagPrintableString: + return marshalPrintableString(out, v.String()) + default: + return marshalUTF8String(out, v.String()) + } + } + + return StructuralError{"unknown Go type"} +} + +func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) { + // If the field is an interface{} then recurse into it. + if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 { + return marshalField(out, v.Elem(), params) + } + + if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty { + return + } + + if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) { + defaultValue := reflect.New(v.Type()).Elem() + defaultValue.SetInt(*params.defaultValue) + + if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) { + return + } + } + + // If no default value is given then the zero value for the type is + // assumed to be the default value. This isn't obviously the correct + // behaviour, but it's what Go has traditionally done. + if params.optional && params.defaultValue == nil { + if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) { + return + } + } + + if v.Type() == rawValueType { + rv := v.Interface().(RawValue) + if len(rv.FullBytes) != 0 { + _, err = out.Write(rv.FullBytes) + } else { + err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}) + if err != nil { + return + } + _, err = out.Write(rv.Bytes) + } + return + } + + tag, isCompound, ok := getUniversalType(v.Type()) + if !ok { + err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())} + return + } + class := classUniversal + + if params.stringType != 0 && tag != tagPrintableString { + return StructuralError{"explicit string type given to non-string member"} + } + + switch tag { + case tagPrintableString: + if params.stringType == 0 { + // This is a string without an explicit string type. We'll use + // a PrintableString if the character set in the string is + // sufficiently limited, otherwise we'll use a UTF8String. + for _, r := range v.String() { + if r >= utf8.RuneSelf || !isPrintable(byte(r)) { + if !utf8.ValidString(v.String()) { + return errors.New("asn1: string not valid UTF-8") + } + tag = tagUTF8String + break + } + } + } else { + tag = params.stringType + } + case tagUTCTime: + if outsideUTCRange(v.Interface().(time.Time)) { + tag = tagGeneralizedTime + } + } + + if params.set { + if tag != tagSequence { + return StructuralError{"non sequence tagged as set"} + } + tag = tagSet + } + + tags, body := out.fork() + + err = marshalBody(body, v, params) + if err != nil { + return + } + + bodyLen := body.Len() + + var explicitTag *forkableWriter + if params.explicit { + explicitTag, tags = tags.fork() + } + + if !params.explicit && params.tag != nil { + // implicit tag. + tag = *params.tag + class = classContextSpecific + } + + err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound}) + if err != nil { + return + } + + if params.explicit { + err = marshalTagAndLength(explicitTag, tagAndLength{ + class: classContextSpecific, + tag: *params.tag, + length: bodyLen + tags.Len(), + isCompound: true, + }) + } + + return nil +} + +// Marshal returns the ASN.1 encoding of val. +// +// In addition to the struct tags recognised by Unmarshal, the following can be +// used: +// +// ia5: causes strings to be marshaled as ASN.1, IA5 strings +// omitempty: causes empty slices to be skipped +// printable: causes strings to be marshaled as ASN.1, PrintableString strings. +// utf8: causes strings to be marshaled as ASN.1, UTF8 strings +func Marshal(val interface{}) ([]byte, error) { + var out bytes.Buffer + v := reflect.ValueOf(val) + f := newForkableWriter() + err := marshalField(f, v, fieldParameters{}) + if err != nil { + return nil, err + } + _, err = f.writeTo(&out) + return out.Bytes(), nil +} diff --git a/src/encoding/asn1/marshal_test.go b/src/encoding/asn1/marshal_test.go new file mode 100644 index 000000000..5b0115f28 --- /dev/null +++ b/src/encoding/asn1/marshal_test.go @@ -0,0 +1,164 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package asn1 + +import ( + "bytes" + "encoding/hex" + "math/big" + "testing" + "time" +) + +type intStruct struct { + A int +} + +type twoIntStruct struct { + A int + B int +} + +type bigIntStruct struct { + A *big.Int +} + +type nestedStruct struct { + A intStruct +} + +type rawContentsStruct struct { + Raw RawContent + A int +} + +type implicitTagTest struct { + A int `asn1:"implicit,tag:5"` +} + +type explicitTagTest struct { + A int `asn1:"explicit,tag:5"` +} + +type ia5StringTest struct { + A string `asn1:"ia5"` +} + +type printableStringTest struct { + A string `asn1:"printable"` +} + +type optionalRawValueTest struct { + A RawValue `asn1:"optional"` +} + +type omitEmptyTest struct { + A []string `asn1:"omitempty"` +} + +type defaultTest struct { + A int `asn1:"optional,default:1"` +} + +type testSET []int + +var PST = time.FixedZone("PST", -8*60*60) + +type marshalTest struct { + in interface{} + out string // hex encoded +} + +func farFuture() time.Time { + t, err := time.Parse(time.RFC3339, "2100-04-05T12:01:01Z") + if err != nil { + panic(err) + } + return t +} + +var marshalTests = []marshalTest{ + {10, "02010a"}, + {127, "02017f"}, + {128, "02020080"}, + {-128, "020180"}, + {-129, "0202ff7f"}, + {intStruct{64}, "3003020140"}, + {bigIntStruct{big.NewInt(0x123456)}, "30050203123456"}, + {twoIntStruct{64, 65}, "3006020140020141"}, + {nestedStruct{intStruct{127}}, "3005300302017f"}, + {[]byte{1, 2, 3}, "0403010203"}, + {implicitTagTest{64}, "3003850140"}, + {explicitTagTest{64}, "3005a503020140"}, + {time.Unix(0, 0).UTC(), "170d3730303130313030303030305a"}, + {time.Unix(1258325776, 0).UTC(), "170d3039313131353232353631365a"}, + {time.Unix(1258325776, 0).In(PST), "17113039313131353134353631362d30383030"}, + {farFuture(), "180f32313030303430353132303130315a"}, + {BitString{[]byte{0x80}, 1}, "03020780"}, + {BitString{[]byte{0x81, 0xf0}, 12}, "03030481f0"}, + {ObjectIdentifier([]int{1, 2, 3, 4}), "06032a0304"}, + {ObjectIdentifier([]int{1, 2, 840, 133549, 1, 1, 5}), "06092a864888932d010105"}, + {ObjectIdentifier([]int{2, 100, 3}), "0603813403"}, + {"test", "130474657374"}, + { + "" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", // This is 127 times 'x' + "137f" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "78787878787878787878787878787878787878787878787878787878787878", + }, + { + "" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", // This is 128 times 'x' + "138180" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "7878787878787878787878787878787878787878787878787878787878787878", + }, + {ia5StringTest{"test"}, "3006160474657374"}, + {optionalRawValueTest{}, "3000"}, + {printableStringTest{"test"}, "3006130474657374"}, + {printableStringTest{"test*"}, "30071305746573742a"}, + {rawContentsStruct{nil, 64}, "3003020140"}, + {rawContentsStruct{[]byte{0x30, 3, 1, 2, 3}, 64}, "3003010203"}, + {RawValue{Tag: 1, Class: 2, IsCompound: false, Bytes: []byte{1, 2, 3}}, "8103010203"}, + {testSET([]int{10}), "310302010a"}, + {omitEmptyTest{[]string{}}, "3000"}, + {omitEmptyTest{[]string{"1"}}, "30053003130131"}, + {"Σ", "0c02cea3"}, + {defaultTest{0}, "3003020100"}, + {defaultTest{1}, "3000"}, + {defaultTest{2}, "3003020102"}, +} + +func TestMarshal(t *testing.T) { + for i, test := range marshalTests { + data, err := Marshal(test.in) + if err != nil { + t.Errorf("#%d failed: %s", i, err) + } + out, _ := hex.DecodeString(test.out) + if !bytes.Equal(out, data) { + t.Errorf("#%d got: %x want %x\n\t%q\n\t%q", i, data, out, data, out) + + } + } +} + +func TestInvalidUTF8(t *testing.T) { + _, err := Marshal(string([]byte{0xff, 0xff})) + if err == nil { + t.Errorf("invalid UTF8 string was accepted") + } +} diff --git a/src/encoding/base32/base32.go b/src/encoding/base32/base32.go new file mode 100644 index 000000000..5a9e86919 --- /dev/null +++ b/src/encoding/base32/base32.go @@ -0,0 +1,426 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package base32 implements base32 encoding as specified by RFC 4648. +package base32 + +import ( + "bytes" + "io" + "strconv" + "strings" +) + +/* + * Encodings + */ + +// An Encoding is a radix 32 encoding/decoding scheme, defined by a +// 32-character alphabet. The most common is the "base32" encoding +// introduced for SASL GSSAPI and standardized in RFC 4648. +// The alternate "base32hex" encoding is used in DNSSEC. +type Encoding struct { + encode string + decodeMap [256]byte +} + +const encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" +const encodeHex = "0123456789ABCDEFGHIJKLMNOPQRSTUV" + +// NewEncoding returns a new Encoding defined by the given alphabet, +// which must be a 32-byte string. +func NewEncoding(encoder string) *Encoding { + e := new(Encoding) + e.encode = encoder + for i := 0; i < len(e.decodeMap); i++ { + e.decodeMap[i] = 0xFF + } + for i := 0; i < len(encoder); i++ { + e.decodeMap[encoder[i]] = byte(i) + } + return e +} + +// StdEncoding is the standard base32 encoding, as defined in +// RFC 4648. +var StdEncoding = NewEncoding(encodeStd) + +// HexEncoding is the ``Extended Hex Alphabet'' defined in RFC 4648. +// It is typically used in DNS. +var HexEncoding = NewEncoding(encodeHex) + +var removeNewlinesMapper = func(r rune) rune { + if r == '\r' || r == '\n' { + return -1 + } + return r +} + +/* + * Encoder + */ + +// Encode encodes src using the encoding enc, writing +// EncodedLen(len(src)) bytes to dst. +// +// The encoding pads the output to a multiple of 8 bytes, +// so Encode is not appropriate for use on individual blocks +// of a large data stream. Use NewEncoder() instead. +func (enc *Encoding) Encode(dst, src []byte) { + if len(src) == 0 { + return + } + + for len(src) > 0 { + var b0, b1, b2, b3, b4, b5, b6, b7 byte + + // Unpack 8x 5-bit source blocks into a 5 byte + // destination quantum + switch len(src) { + default: + b7 = src[4] & 0x1F + b6 = src[4] >> 5 + fallthrough + case 4: + b6 |= (src[3] << 3) & 0x1F + b5 = (src[3] >> 2) & 0x1F + b4 = src[3] >> 7 + fallthrough + case 3: + b4 |= (src[2] << 1) & 0x1F + b3 = (src[2] >> 4) & 0x1F + fallthrough + case 2: + b3 |= (src[1] << 4) & 0x1F + b2 = (src[1] >> 1) & 0x1F + b1 = (src[1] >> 6) & 0x1F + fallthrough + case 1: + b1 |= (src[0] << 2) & 0x1F + b0 = src[0] >> 3 + } + + // Encode 5-bit blocks using the base32 alphabet + dst[0] = enc.encode[b0] + dst[1] = enc.encode[b1] + dst[2] = enc.encode[b2] + dst[3] = enc.encode[b3] + dst[4] = enc.encode[b4] + dst[5] = enc.encode[b5] + dst[6] = enc.encode[b6] + dst[7] = enc.encode[b7] + + // Pad the final quantum + if len(src) < 5 { + dst[7] = '=' + if len(src) < 4 { + dst[6] = '=' + dst[5] = '=' + if len(src) < 3 { + dst[4] = '=' + if len(src) < 2 { + dst[3] = '=' + dst[2] = '=' + } + } + } + break + } + src = src[5:] + dst = dst[8:] + } +} + +// EncodeToString returns the base32 encoding of src. +func (enc *Encoding) EncodeToString(src []byte) string { + buf := make([]byte, enc.EncodedLen(len(src))) + enc.Encode(buf, src) + return string(buf) +} + +type encoder struct { + err error + enc *Encoding + w io.Writer + buf [5]byte // buffered data waiting to be encoded + nbuf int // number of bytes in buf + out [1024]byte // output buffer +} + +func (e *encoder) Write(p []byte) (n int, err error) { + if e.err != nil { + return 0, e.err + } + + // Leading fringe. + if e.nbuf > 0 { + var i int + for i = 0; i < len(p) && e.nbuf < 5; i++ { + e.buf[e.nbuf] = p[i] + e.nbuf++ + } + n += i + p = p[i:] + if e.nbuf < 5 { + return + } + e.enc.Encode(e.out[0:], e.buf[0:]) + if _, e.err = e.w.Write(e.out[0:8]); e.err != nil { + return n, e.err + } + e.nbuf = 0 + } + + // Large interior chunks. + for len(p) >= 5 { + nn := len(e.out) / 8 * 5 + if nn > len(p) { + nn = len(p) + nn -= nn % 5 + } + e.enc.Encode(e.out[0:], p[0:nn]) + if _, e.err = e.w.Write(e.out[0 : nn/5*8]); e.err != nil { + return n, e.err + } + n += nn + p = p[nn:] + } + + // Trailing fringe. + for i := 0; i < len(p); i++ { + e.buf[i] = p[i] + } + e.nbuf = len(p) + n += len(p) + return +} + +// Close flushes any pending output from the encoder. +// It is an error to call Write after calling Close. +func (e *encoder) Close() error { + // If there's anything left in the buffer, flush it out + if e.err == nil && e.nbuf > 0 { + e.enc.Encode(e.out[0:], e.buf[0:e.nbuf]) + e.nbuf = 0 + _, e.err = e.w.Write(e.out[0:8]) + } + return e.err +} + +// NewEncoder returns a new base32 stream encoder. Data written to +// the returned writer will be encoded using enc and then written to w. +// Base32 encodings operate in 5-byte blocks; when finished +// writing, the caller must Close the returned encoder to flush any +// partially written blocks. +func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser { + return &encoder{enc: enc, w: w} +} + +// EncodedLen returns the length in bytes of the base32 encoding +// of an input buffer of length n. +func (enc *Encoding) EncodedLen(n int) int { return (n + 4) / 5 * 8 } + +/* + * Decoder + */ + +type CorruptInputError int64 + +func (e CorruptInputError) Error() string { + return "illegal base32 data at input byte " + strconv.FormatInt(int64(e), 10) +} + +// decode is like Decode but returns an additional 'end' value, which +// indicates if end-of-message padding was encountered and thus any +// additional data is an error. This method assumes that src has been +// stripped of all supported whitespace ('\r' and '\n'). +func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { + olen := len(src) + for len(src) > 0 && !end { + // Decode quantum using the base32 alphabet + var dbuf [8]byte + dlen := 8 + + for j := 0; j < 8; { + if len(src) == 0 { + return n, false, CorruptInputError(olen - len(src) - j) + } + in := src[0] + src = src[1:] + if in == '=' && j >= 2 && len(src) < 8 { + // We've reached the end and there's padding + if len(src)+j < 8-1 { + // not enough padding + return n, false, CorruptInputError(olen) + } + for k := 0; k < 8-1-j; k++ { + if len(src) > k && src[k] != '=' { + // incorrect padding + return n, false, CorruptInputError(olen - len(src) + k - 1) + } + } + dlen, end = j, true + // 7, 5 and 2 are not valid padding lengths, and so 1, 3 and 6 are not + // valid dlen values. See RFC 4648 Section 6 "Base 32 Encoding" listing + // the five valid padding lengths, and Section 9 "Illustrations and + // Examples" for an illustration for how the 1st, 3rd and 6th base32 + // src bytes do not yield enough information to decode a dst byte. + if dlen == 1 || dlen == 3 || dlen == 6 { + return n, false, CorruptInputError(olen - len(src) - 1) + } + break + } + dbuf[j] = enc.decodeMap[in] + if dbuf[j] == 0xFF { + return n, false, CorruptInputError(olen - len(src) - 1) + } + j++ + } + + // Pack 8x 5-bit source blocks into 5 byte destination + // quantum + switch dlen { + case 8: + dst[4] = dbuf[6]<<5 | dbuf[7] + fallthrough + case 7: + dst[3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3 + fallthrough + case 5: + dst[2] = dbuf[3]<<4 | dbuf[4]>>1 + fallthrough + case 4: + dst[1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4 + fallthrough + case 2: + dst[0] = dbuf[0]<<3 | dbuf[1]>>2 + } + dst = dst[5:] + switch dlen { + case 2: + n += 1 + case 4: + n += 2 + case 5: + n += 3 + case 7: + n += 4 + case 8: + n += 5 + } + } + return n, end, nil +} + +// Decode decodes src using the encoding enc. It writes at most +// DecodedLen(len(src)) bytes to dst and returns the number of bytes +// written. If src contains invalid base32 data, it will return the +// number of bytes successfully written and CorruptInputError. +// New line characters (\r and \n) are ignored. +func (enc *Encoding) Decode(dst, src []byte) (n int, err error) { + src = bytes.Map(removeNewlinesMapper, src) + n, _, err = enc.decode(dst, src) + return +} + +// DecodeString returns the bytes represented by the base32 string s. +func (enc *Encoding) DecodeString(s string) ([]byte, error) { + s = strings.Map(removeNewlinesMapper, s) + dbuf := make([]byte, enc.DecodedLen(len(s))) + n, _, err := enc.decode(dbuf, []byte(s)) + return dbuf[:n], err +} + +type decoder struct { + err error + enc *Encoding + r io.Reader + end bool // saw end of message + buf [1024]byte // leftover input + nbuf int + out []byte // leftover decoded output + outbuf [1024 / 8 * 5]byte +} + +func (d *decoder) Read(p []byte) (n int, err error) { + if d.err != nil { + return 0, d.err + } + + // Use leftover decoded output from last read. + if len(d.out) > 0 { + n = copy(p, d.out) + d.out = d.out[n:] + return n, nil + } + + // Read a chunk. + nn := len(p) / 5 * 8 + if nn < 8 { + nn = 8 + } + if nn > len(d.buf) { + nn = len(d.buf) + } + nn, d.err = io.ReadAtLeast(d.r, d.buf[d.nbuf:nn], 8-d.nbuf) + d.nbuf += nn + if d.nbuf < 8 { + return 0, d.err + } + + // Decode chunk into p, or d.out and then p if p is too small. + nr := d.nbuf / 8 * 8 + nw := d.nbuf / 8 * 5 + if nw > len(p) { + nw, d.end, d.err = d.enc.decode(d.outbuf[0:], d.buf[0:nr]) + d.out = d.outbuf[0:nw] + n = copy(p, d.out) + d.out = d.out[n:] + } else { + n, d.end, d.err = d.enc.decode(p, d.buf[0:nr]) + } + d.nbuf -= nr + for i := 0; i < d.nbuf; i++ { + d.buf[i] = d.buf[i+nr] + } + + if d.err == nil { + d.err = err + } + return n, d.err +} + +type newlineFilteringReader struct { + wrapped io.Reader +} + +func (r *newlineFilteringReader) Read(p []byte) (int, error) { + n, err := r.wrapped.Read(p) + for n > 0 { + offset := 0 + for i, b := range p[0:n] { + if b != '\r' && b != '\n' { + if i != offset { + p[offset] = b + } + offset++ + } + } + if offset > 0 { + return offset, err + } + // Previous buffer entirely whitespace, read again + n, err = r.wrapped.Read(p) + } + return n, err +} + +// NewDecoder constructs a new base32 stream decoder. +func NewDecoder(enc *Encoding, r io.Reader) io.Reader { + return &decoder{enc: enc, r: &newlineFilteringReader{r}} +} + +// DecodedLen returns the maximum length in bytes of the decoded data +// corresponding to n bytes of base32-encoded data. +func (enc *Encoding) DecodedLen(n int) int { return n / 8 * 5 } diff --git a/src/encoding/base32/base32_test.go b/src/encoding/base32/base32_test.go new file mode 100644 index 000000000..5a68f06e1 --- /dev/null +++ b/src/encoding/base32/base32_test.go @@ -0,0 +1,302 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package base32 + +import ( + "bytes" + "io" + "io/ioutil" + "strings" + "testing" +) + +type testpair struct { + decoded, encoded string +} + +var pairs = []testpair{ + // RFC 4648 examples + {"", ""}, + {"f", "MY======"}, + {"fo", "MZXQ===="}, + {"foo", "MZXW6==="}, + {"foob", "MZXW6YQ="}, + {"fooba", "MZXW6YTB"}, + {"foobar", "MZXW6YTBOI======"}, + + // Wikipedia examples, converted to base32 + {"sure.", "ON2XEZJO"}, + {"sure", "ON2XEZI="}, + {"sur", "ON2XE==="}, + {"su", "ON2Q===="}, + {"leasure.", "NRSWC43VOJSS4==="}, + {"easure.", "MVQXG5LSMUXA===="}, + {"asure.", "MFZXK4TFFY======"}, + {"sure.", "ON2XEZJO"}, +} + +var bigtest = testpair{ + "Twas brillig, and the slithy toves", + "KR3WC4ZAMJZGS3DMNFTSYIDBNZSCA5DIMUQHG3DJORUHSIDUN53GK4Y=", +} + +func testEqual(t *testing.T, msg string, args ...interface{}) bool { + if args[len(args)-2] != args[len(args)-1] { + t.Errorf(msg, args...) + return false + } + return true +} + +func TestEncode(t *testing.T) { + for _, p := range pairs { + got := StdEncoding.EncodeToString([]byte(p.decoded)) + testEqual(t, "Encode(%q) = %q, want %q", p.decoded, got, p.encoded) + } +} + +func TestEncoder(t *testing.T) { + for _, p := range pairs { + bb := &bytes.Buffer{} + encoder := NewEncoder(StdEncoding, bb) + encoder.Write([]byte(p.decoded)) + encoder.Close() + testEqual(t, "Encode(%q) = %q, want %q", p.decoded, bb.String(), p.encoded) + } +} + +func TestEncoderBuffering(t *testing.T) { + input := []byte(bigtest.decoded) + for bs := 1; bs <= 12; bs++ { + bb := &bytes.Buffer{} + encoder := NewEncoder(StdEncoding, bb) + for pos := 0; pos < len(input); pos += bs { + end := pos + bs + if end > len(input) { + end = len(input) + } + n, err := encoder.Write(input[pos:end]) + testEqual(t, "Write(%q) gave error %v, want %v", input[pos:end], err, error(nil)) + testEqual(t, "Write(%q) gave length %v, want %v", input[pos:end], n, end-pos) + } + err := encoder.Close() + testEqual(t, "Close gave error %v, want %v", err, error(nil)) + testEqual(t, "Encoding/%d of %q = %q, want %q", bs, bigtest.decoded, bb.String(), bigtest.encoded) + } +} + +func TestDecode(t *testing.T) { + for _, p := range pairs { + dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded))) + count, end, err := StdEncoding.decode(dbuf, []byte(p.encoded)) + testEqual(t, "Decode(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, "Decode(%q) = length %v, want %v", p.encoded, count, len(p.decoded)) + if len(p.encoded) > 0 { + testEqual(t, "Decode(%q) = end %v, want %v", p.encoded, end, (p.encoded[len(p.encoded)-1] == '=')) + } + testEqual(t, "Decode(%q) = %q, want %q", p.encoded, + string(dbuf[0:count]), + p.decoded) + + dbuf, err = StdEncoding.DecodeString(p.encoded) + testEqual(t, "DecodeString(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, "DecodeString(%q) = %q, want %q", p.encoded, string(dbuf), p.decoded) + } +} + +func TestDecoder(t *testing.T) { + for _, p := range pairs { + decoder := NewDecoder(StdEncoding, strings.NewReader(p.encoded)) + dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded))) + count, err := decoder.Read(dbuf) + if err != nil && err != io.EOF { + t.Fatal("Read failed", err) + } + testEqual(t, "Read from %q = length %v, want %v", p.encoded, count, len(p.decoded)) + testEqual(t, "Decoding of %q = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded) + if err != io.EOF { + count, err = decoder.Read(dbuf) + } + testEqual(t, "Read from %q = %v, want %v", p.encoded, err, io.EOF) + } +} + +func TestDecoderBuffering(t *testing.T) { + for bs := 1; bs <= 12; bs++ { + decoder := NewDecoder(StdEncoding, strings.NewReader(bigtest.encoded)) + buf := make([]byte, len(bigtest.decoded)+12) + var total int + for total = 0; total < len(bigtest.decoded); { + n, err := decoder.Read(buf[total : total+bs]) + testEqual(t, "Read from %q at pos %d = %d, %v, want _, %v", bigtest.encoded, total, n, err, error(nil)) + total += n + } + testEqual(t, "Decoding/%d of %q = %q, want %q", bs, bigtest.encoded, string(buf[0:total]), bigtest.decoded) + } +} + +func TestDecodeCorrupt(t *testing.T) { + testCases := []struct { + input string + offset int // -1 means no corruption. + }{ + {"", -1}, + {"!!!!", 0}, + {"x===", 0}, + {"AA=A====", 2}, + {"AAA=AAAA", 3}, + {"MMMMMMMMM", 8}, + {"MMMMMM", 0}, + {"A=", 1}, + {"AA=", 3}, + {"AA==", 4}, + {"AA===", 5}, + {"AAAA=", 5}, + {"AAAA==", 6}, + {"AAAAA=", 6}, + {"AAAAA==", 7}, + {"A=======", 1}, + {"AA======", -1}, + {"AAA=====", 3}, + {"AAAA====", -1}, + {"AAAAA===", -1}, + {"AAAAAA==", 6}, + {"AAAAAAA=", -1}, + {"AAAAAAAA", -1}, + } + for _, tc := range testCases { + dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input))) + _, err := StdEncoding.Decode(dbuf, []byte(tc.input)) + if tc.offset == -1 { + if err != nil { + t.Error("Decoder wrongly detected coruption in", tc.input) + } + continue + } + switch err := err.(type) { + case CorruptInputError: + testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset) + default: + t.Error("Decoder failed to detect corruption in", tc) + } + } +} + +func TestBig(t *testing.T) { + n := 3*1000 + 1 + raw := make([]byte, n) + const alpha = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + for i := 0; i < n; i++ { + raw[i] = alpha[i%len(alpha)] + } + encoded := new(bytes.Buffer) + w := NewEncoder(StdEncoding, encoded) + nn, err := w.Write(raw) + if nn != n || err != nil { + t.Fatalf("Encoder.Write(raw) = %d, %v want %d, nil", nn, err, n) + } + err = w.Close() + if err != nil { + t.Fatalf("Encoder.Close() = %v want nil", err) + } + decoded, err := ioutil.ReadAll(NewDecoder(StdEncoding, encoded)) + if err != nil { + t.Fatalf("ioutil.ReadAll(NewDecoder(...)): %v", err) + } + + if !bytes.Equal(raw, decoded) { + var i int + for i = 0; i < len(decoded) && i < len(raw); i++ { + if decoded[i] != raw[i] { + break + } + } + t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i) + } +} + +func testStringEncoding(t *testing.T, expected string, examples []string) { + for _, e := range examples { + buf, err := StdEncoding.DecodeString(e) + if err != nil { + t.Errorf("Decode(%q) failed: %v", e, err) + continue + } + if s := string(buf); s != expected { + t.Errorf("Decode(%q) = %q, want %q", e, s, expected) + } + } +} + +func TestNewLineCharacters(t *testing.T) { + // Each of these should decode to the string "sure", without errors. + examples := []string{ + "ON2XEZI=", + "ON2XEZI=\r", + "ON2XEZI=\n", + "ON2XEZI=\r\n", + "ON2XEZ\r\nI=", + "ON2X\rEZ\nI=", + "ON2X\nEZ\rI=", + "ON2XEZ\nI=", + "ON2XEZI\n=", + } + testStringEncoding(t, "sure", examples) + + // Each of these should decode to the string "foobar", without errors. + examples = []string{ + "MZXW6YTBOI======", + "MZXW6YTBOI=\r\n=====", + } + testStringEncoding(t, "foobar", examples) +} + +func TestDecoderIssue4779(t *testing.T) { + encoded := `JRXXEZLNEBUXA43VNUQGI33MN5ZCA43JOQQGC3LFOQWCAY3PNZZWKY3UMV2HK4 +RAMFSGS4DJONUWG2LOM4QGK3DJOQWCA43FMQQGI3YKMVUXK43NN5SCA5DFNVYG64RANFXGG2LENFSH +K3TUEB2XIIDMMFRG64TFEBSXIIDEN5WG64TFEBWWCZ3OMEQGC3DJOF2WCLRAKV2CAZLONFWQUYLEEB +WWS3TJNUQHMZLONFQW2LBAOF2WS4ZANZXXG5DSOVSCAZLYMVZGG2LUMF2GS33OEB2WY3DBNVRW6IDM +MFRG64TJOMQG42LTNEQHK5AKMFWGS4LVNFYCAZLYEBSWCIDDN5WW233EN4QGG33OONSXC5LBOQXCAR +DVNFZSAYLVORSSA2LSOVZGKIDEN5WG64RANFXAU4TFOBZGK2DFNZSGK4TJOQQGS3RAOZXWY5LQORQX +IZJAOZSWY2LUEBSXG43FEBRWS3DMOVWSAZDPNRXXEZJAMV2SAZTVM5UWC5BANZ2WY3DBBJYGC4TJMF +2HK4ROEBCXQY3FOB2GK5LSEBZWS3TUEBXWGY3BMVRWC5BAMN2XA2LEMF2GC5BANZXW4IDQOJXWSZDF +NZ2CYIDTOVXHIIDJNYFGG5LMOBQSA4LVNEQG6ZTGNFRWSYJAMRSXGZLSOVXHIIDNN5WGY2LUEBQW42 +LNEBUWIIDFON2CA3DBMJXXE5LNFY== +====` + encodedShort := strings.Replace(encoded, "\n", "", -1) + + dec := NewDecoder(StdEncoding, strings.NewReader(encoded)) + res1, err := ioutil.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + dec = NewDecoder(StdEncoding, strings.NewReader(encodedShort)) + var res2 []byte + res2, err = ioutil.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + if !bytes.Equal(res1, res2) { + t.Error("Decoded results not equal") + } +} + +func BenchmarkEncodeToString(b *testing.B) { + data := make([]byte, 8192) + b.SetBytes(int64(len(data))) + for i := 0; i < b.N; i++ { + StdEncoding.EncodeToString(data) + } +} + +func BenchmarkDecodeString(b *testing.B) { + data := StdEncoding.EncodeToString(make([]byte, 8192)) + b.SetBytes(int64(len(data))) + for i := 0; i < b.N; i++ { + StdEncoding.DecodeString(data) + } +} diff --git a/src/encoding/base32/example_test.go b/src/encoding/base32/example_test.go new file mode 100644 index 000000000..f6128d900 --- /dev/null +++ b/src/encoding/base32/example_test.go @@ -0,0 +1,45 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Keep in sync with ../base64/example_test.go. + +package base32_test + +import ( + "encoding/base32" + "fmt" + "os" +) + +func ExampleEncoding_EncodeToString() { + data := []byte("any + old & data") + str := base32.StdEncoding.EncodeToString(data) + fmt.Println(str) + // Output: + // MFXHSIBLEBXWYZBAEYQGIYLUME====== +} + +func ExampleEncoding_DecodeString() { + str := "ONXW2ZJAMRQXIYJAO5UXI2BAAAQGC3TEEDX3XPY=" + data, err := base32.StdEncoding.DecodeString(str) + if err != nil { + fmt.Println("error:", err) + return + } + fmt.Printf("%q\n", data) + // Output: + // "some data with \x00 and \ufeff" +} + +func ExampleNewEncoder() { + input := []byte("foo\x00bar") + encoder := base32.NewEncoder(base32.StdEncoding, os.Stdout) + encoder.Write(input) + // Must close the encoder when finished to flush any partial blocks. + // If you comment out the following line, the last partial block "r" + // won't be encoded. + encoder.Close() + // Output: + // MZXW6ADCMFZA==== +} diff --git a/src/encoding/base64/base64.go b/src/encoding/base64/base64.go new file mode 100644 index 000000000..ad3abe662 --- /dev/null +++ b/src/encoding/base64/base64.go @@ -0,0 +1,391 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package base64 implements base64 encoding as specified by RFC 4648. +package base64 + +import ( + "bytes" + "io" + "strconv" + "strings" +) + +/* + * Encodings + */ + +// An Encoding is a radix 64 encoding/decoding scheme, defined by a +// 64-character alphabet. The most common encoding is the "base64" +// encoding defined in RFC 4648 and used in MIME (RFC 2045) and PEM +// (RFC 1421). RFC 4648 also defines an alternate encoding, which is +// the standard encoding with - and _ substituted for + and /. +type Encoding struct { + encode string + decodeMap [256]byte +} + +const encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" +const encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" + +// NewEncoding returns a new Encoding defined by the given alphabet, +// which must be a 64-byte string. +func NewEncoding(encoder string) *Encoding { + e := new(Encoding) + e.encode = encoder + for i := 0; i < len(e.decodeMap); i++ { + e.decodeMap[i] = 0xFF + } + for i := 0; i < len(encoder); i++ { + e.decodeMap[encoder[i]] = byte(i) + } + return e +} + +// StdEncoding is the standard base64 encoding, as defined in +// RFC 4648. +var StdEncoding = NewEncoding(encodeStd) + +// URLEncoding is the alternate base64 encoding defined in RFC 4648. +// It is typically used in URLs and file names. +var URLEncoding = NewEncoding(encodeURL) + +var removeNewlinesMapper = func(r rune) rune { + if r == '\r' || r == '\n' { + return -1 + } + return r +} + +/* + * Encoder + */ + +// Encode encodes src using the encoding enc, writing +// EncodedLen(len(src)) bytes to dst. +// +// The encoding pads the output to a multiple of 4 bytes, +// so Encode is not appropriate for use on individual blocks +// of a large data stream. Use NewEncoder() instead. +func (enc *Encoding) Encode(dst, src []byte) { + if len(src) == 0 { + return + } + + for len(src) > 0 { + var b0, b1, b2, b3 byte + + // Unpack 4x 6-bit source blocks into a 4 byte + // destination quantum + switch len(src) { + default: + b3 = src[2] & 0x3F + b2 = src[2] >> 6 + fallthrough + case 2: + b2 |= (src[1] << 2) & 0x3F + b1 = src[1] >> 4 + fallthrough + case 1: + b1 |= (src[0] << 4) & 0x3F + b0 = src[0] >> 2 + } + + // Encode 6-bit blocks using the base64 alphabet + dst[0] = enc.encode[b0] + dst[1] = enc.encode[b1] + dst[2] = enc.encode[b2] + dst[3] = enc.encode[b3] + + // Pad the final quantum + if len(src) < 3 { + dst[3] = '=' + if len(src) < 2 { + dst[2] = '=' + } + break + } + + src = src[3:] + dst = dst[4:] + } +} + +// EncodeToString returns the base64 encoding of src. +func (enc *Encoding) EncodeToString(src []byte) string { + buf := make([]byte, enc.EncodedLen(len(src))) + enc.Encode(buf, src) + return string(buf) +} + +type encoder struct { + err error + enc *Encoding + w io.Writer + buf [3]byte // buffered data waiting to be encoded + nbuf int // number of bytes in buf + out [1024]byte // output buffer +} + +func (e *encoder) Write(p []byte) (n int, err error) { + if e.err != nil { + return 0, e.err + } + + // Leading fringe. + if e.nbuf > 0 { + var i int + for i = 0; i < len(p) && e.nbuf < 3; i++ { + e.buf[e.nbuf] = p[i] + e.nbuf++ + } + n += i + p = p[i:] + if e.nbuf < 3 { + return + } + e.enc.Encode(e.out[0:], e.buf[0:]) + if _, e.err = e.w.Write(e.out[0:4]); e.err != nil { + return n, e.err + } + e.nbuf = 0 + } + + // Large interior chunks. + for len(p) >= 3 { + nn := len(e.out) / 4 * 3 + if nn > len(p) { + nn = len(p) + nn -= nn % 3 + } + e.enc.Encode(e.out[0:], p[0:nn]) + if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil { + return n, e.err + } + n += nn + p = p[nn:] + } + + // Trailing fringe. + for i := 0; i < len(p); i++ { + e.buf[i] = p[i] + } + e.nbuf = len(p) + n += len(p) + return +} + +// Close flushes any pending output from the encoder. +// It is an error to call Write after calling Close. +func (e *encoder) Close() error { + // If there's anything left in the buffer, flush it out + if e.err == nil && e.nbuf > 0 { + e.enc.Encode(e.out[0:], e.buf[0:e.nbuf]) + e.nbuf = 0 + _, e.err = e.w.Write(e.out[0:4]) + } + return e.err +} + +// NewEncoder returns a new base64 stream encoder. Data written to +// the returned writer will be encoded using enc and then written to w. +// Base64 encodings operate in 4-byte blocks; when finished +// writing, the caller must Close the returned encoder to flush any +// partially written blocks. +func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser { + return &encoder{enc: enc, w: w} +} + +// EncodedLen returns the length in bytes of the base64 encoding +// of an input buffer of length n. +func (enc *Encoding) EncodedLen(n int) int { return (n + 2) / 3 * 4 } + +/* + * Decoder + */ + +type CorruptInputError int64 + +func (e CorruptInputError) Error() string { + return "illegal base64 data at input byte " + strconv.FormatInt(int64(e), 10) +} + +// decode is like Decode but returns an additional 'end' value, which +// indicates if end-of-message padding was encountered and thus any +// additional data is an error. This method assumes that src has been +// stripped of all supported whitespace ('\r' and '\n'). +func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { + olen := len(src) + for len(src) > 0 && !end { + // Decode quantum using the base64 alphabet + var dbuf [4]byte + dlen := 4 + + for j := range dbuf { + if len(src) == 0 { + return n, false, CorruptInputError(olen - len(src) - j) + } + in := src[0] + src = src[1:] + if in == '=' { + // We've reached the end and there's padding + switch j { + case 0, 1: + // incorrect padding + return n, false, CorruptInputError(olen - len(src) - 1) + case 2: + // "==" is expected, the first "=" is already consumed. + if len(src) == 0 { + // not enough padding + return n, false, CorruptInputError(olen) + } + if src[0] != '=' { + // incorrect padding + return n, false, CorruptInputError(olen - len(src) - 1) + } + src = src[1:] + } + if len(src) > 0 { + // trailing garbage + err = CorruptInputError(olen - len(src)) + } + dlen, end = j, true + break + } + dbuf[j] = enc.decodeMap[in] + if dbuf[j] == 0xFF { + return n, false, CorruptInputError(olen - len(src) - 1) + } + } + + // Pack 4x 6-bit source blocks into 3 byte destination + // quantum + switch dlen { + case 4: + dst[2] = dbuf[2]<<6 | dbuf[3] + fallthrough + case 3: + dst[1] = dbuf[1]<<4 | dbuf[2]>>2 + fallthrough + case 2: + dst[0] = dbuf[0]<<2 | dbuf[1]>>4 + } + dst = dst[3:] + n += dlen - 1 + } + + return n, end, err +} + +// Decode decodes src using the encoding enc. It writes at most +// DecodedLen(len(src)) bytes to dst and returns the number of bytes +// written. If src contains invalid base64 data, it will return the +// number of bytes successfully written and CorruptInputError. +// New line characters (\r and \n) are ignored. +func (enc *Encoding) Decode(dst, src []byte) (n int, err error) { + src = bytes.Map(removeNewlinesMapper, src) + n, _, err = enc.decode(dst, src) + return +} + +// DecodeString returns the bytes represented by the base64 string s. +func (enc *Encoding) DecodeString(s string) ([]byte, error) { + s = strings.Map(removeNewlinesMapper, s) + dbuf := make([]byte, enc.DecodedLen(len(s))) + n, _, err := enc.decode(dbuf, []byte(s)) + return dbuf[:n], err +} + +type decoder struct { + err error + enc *Encoding + r io.Reader + end bool // saw end of message + buf [1024]byte // leftover input + nbuf int + out []byte // leftover decoded output + outbuf [1024 / 4 * 3]byte +} + +func (d *decoder) Read(p []byte) (n int, err error) { + if d.err != nil { + return 0, d.err + } + + // Use leftover decoded output from last read. + if len(d.out) > 0 { + n = copy(p, d.out) + d.out = d.out[n:] + return n, nil + } + + // Read a chunk. + nn := len(p) / 3 * 4 + if nn < 4 { + nn = 4 + } + if nn > len(d.buf) { + nn = len(d.buf) + } + nn, d.err = io.ReadAtLeast(d.r, d.buf[d.nbuf:nn], 4-d.nbuf) + d.nbuf += nn + if d.err != nil || d.nbuf < 4 { + return 0, d.err + } + + // Decode chunk into p, or d.out and then p if p is too small. + nr := d.nbuf / 4 * 4 + nw := d.nbuf / 4 * 3 + if nw > len(p) { + nw, d.end, d.err = d.enc.decode(d.outbuf[0:], d.buf[0:nr]) + d.out = d.outbuf[0:nw] + n = copy(p, d.out) + d.out = d.out[n:] + } else { + n, d.end, d.err = d.enc.decode(p, d.buf[0:nr]) + } + d.nbuf -= nr + for i := 0; i < d.nbuf; i++ { + d.buf[i] = d.buf[i+nr] + } + + if d.err == nil { + d.err = err + } + return n, d.err +} + +type newlineFilteringReader struct { + wrapped io.Reader +} + +func (r *newlineFilteringReader) Read(p []byte) (int, error) { + n, err := r.wrapped.Read(p) + for n > 0 { + offset := 0 + for i, b := range p[0:n] { + if b != '\r' && b != '\n' { + if i != offset { + p[offset] = b + } + offset++ + } + } + if offset > 0 { + return offset, err + } + // Previous buffer entirely whitespace, read again + n, err = r.wrapped.Read(p) + } + return n, err +} + +// NewDecoder constructs a new base64 stream decoder. +func NewDecoder(enc *Encoding, r io.Reader) io.Reader { + return &decoder{enc: enc, r: &newlineFilteringReader{r}} +} + +// DecodedLen returns the maximum length in bytes of the decoded data +// corresponding to n bytes of base64-encoded data. +func (enc *Encoding) DecodedLen(n int) int { return n / 4 * 3 } diff --git a/src/encoding/base64/base64_test.go b/src/encoding/base64/base64_test.go new file mode 100644 index 000000000..7d199bfa0 --- /dev/null +++ b/src/encoding/base64/base64_test.go @@ -0,0 +1,360 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package base64 + +import ( + "bytes" + "errors" + "io" + "io/ioutil" + "reflect" + "strings" + "testing" + "time" +) + +type testpair struct { + decoded, encoded string +} + +var pairs = []testpair{ + // RFC 3548 examples + {"\x14\xfb\x9c\x03\xd9\x7e", "FPucA9l+"}, + {"\x14\xfb\x9c\x03\xd9", "FPucA9k="}, + {"\x14\xfb\x9c\x03", "FPucAw=="}, + + // RFC 4648 examples + {"", ""}, + {"f", "Zg=="}, + {"fo", "Zm8="}, + {"foo", "Zm9v"}, + {"foob", "Zm9vYg=="}, + {"fooba", "Zm9vYmE="}, + {"foobar", "Zm9vYmFy"}, + + // Wikipedia examples + {"sure.", "c3VyZS4="}, + {"sure", "c3VyZQ=="}, + {"sur", "c3Vy"}, + {"su", "c3U="}, + {"leasure.", "bGVhc3VyZS4="}, + {"easure.", "ZWFzdXJlLg=="}, + {"asure.", "YXN1cmUu"}, + {"sure.", "c3VyZS4="}, +} + +var bigtest = testpair{ + "Twas brillig, and the slithy toves", + "VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==", +} + +func testEqual(t *testing.T, msg string, args ...interface{}) bool { + if args[len(args)-2] != args[len(args)-1] { + t.Errorf(msg, args...) + return false + } + return true +} + +func TestEncode(t *testing.T) { + for _, p := range pairs { + got := StdEncoding.EncodeToString([]byte(p.decoded)) + testEqual(t, "Encode(%q) = %q, want %q", p.decoded, got, p.encoded) + } +} + +func TestEncoder(t *testing.T) { + for _, p := range pairs { + bb := &bytes.Buffer{} + encoder := NewEncoder(StdEncoding, bb) + encoder.Write([]byte(p.decoded)) + encoder.Close() + testEqual(t, "Encode(%q) = %q, want %q", p.decoded, bb.String(), p.encoded) + } +} + +func TestEncoderBuffering(t *testing.T) { + input := []byte(bigtest.decoded) + for bs := 1; bs <= 12; bs++ { + bb := &bytes.Buffer{} + encoder := NewEncoder(StdEncoding, bb) + for pos := 0; pos < len(input); pos += bs { + end := pos + bs + if end > len(input) { + end = len(input) + } + n, err := encoder.Write(input[pos:end]) + testEqual(t, "Write(%q) gave error %v, want %v", input[pos:end], err, error(nil)) + testEqual(t, "Write(%q) gave length %v, want %v", input[pos:end], n, end-pos) + } + err := encoder.Close() + testEqual(t, "Close gave error %v, want %v", err, error(nil)) + testEqual(t, "Encoding/%d of %q = %q, want %q", bs, bigtest.decoded, bb.String(), bigtest.encoded) + } +} + +func TestDecode(t *testing.T) { + for _, p := range pairs { + dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded))) + count, end, err := StdEncoding.decode(dbuf, []byte(p.encoded)) + testEqual(t, "Decode(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, "Decode(%q) = length %v, want %v", p.encoded, count, len(p.decoded)) + if len(p.encoded) > 0 { + testEqual(t, "Decode(%q) = end %v, want %v", p.encoded, end, (p.encoded[len(p.encoded)-1] == '=')) + } + testEqual(t, "Decode(%q) = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded) + + dbuf, err = StdEncoding.DecodeString(p.encoded) + testEqual(t, "DecodeString(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, "DecodeString(%q) = %q, want %q", string(dbuf), p.decoded) + } +} + +func TestDecoder(t *testing.T) { + for _, p := range pairs { + decoder := NewDecoder(StdEncoding, strings.NewReader(p.encoded)) + dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded))) + count, err := decoder.Read(dbuf) + if err != nil && err != io.EOF { + t.Fatal("Read failed", err) + } + testEqual(t, "Read from %q = length %v, want %v", p.encoded, count, len(p.decoded)) + testEqual(t, "Decoding of %q = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded) + if err != io.EOF { + count, err = decoder.Read(dbuf) + } + testEqual(t, "Read from %q = %v, want %v", p.encoded, err, io.EOF) + } +} + +func TestDecoderBuffering(t *testing.T) { + for bs := 1; bs <= 12; bs++ { + decoder := NewDecoder(StdEncoding, strings.NewReader(bigtest.encoded)) + buf := make([]byte, len(bigtest.decoded)+12) + var total int + for total = 0; total < len(bigtest.decoded); { + n, err := decoder.Read(buf[total : total+bs]) + testEqual(t, "Read from %q at pos %d = %d, %v, want _, %v", bigtest.encoded, total, n, err, error(nil)) + total += n + } + testEqual(t, "Decoding/%d of %q = %q, want %q", bs, bigtest.encoded, string(buf[0:total]), bigtest.decoded) + } +} + +func TestDecodeCorrupt(t *testing.T) { + testCases := []struct { + input string + offset int // -1 means no corruption. + }{ + {"", -1}, + {"!!!!", 0}, + {"====", 0}, + {"x===", 1}, + {"=AAA", 0}, + {"A=AA", 1}, + {"AA=A", 2}, + {"AA==A", 4}, + {"AAA=AAAA", 4}, + {"AAAAA", 4}, + {"AAAAAA", 4}, + {"A=", 1}, + {"A==", 1}, + {"AA=", 3}, + {"AA==", -1}, + {"AAA=", -1}, + {"AAAA", -1}, + {"AAAAAA=", 7}, + {"YWJjZA=====", 8}, + } + for _, tc := range testCases { + dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input))) + _, err := StdEncoding.Decode(dbuf, []byte(tc.input)) + if tc.offset == -1 { + if err != nil { + t.Error("Decoder wrongly detected coruption in", tc.input) + } + continue + } + switch err := err.(type) { + case CorruptInputError: + testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset) + default: + t.Error("Decoder failed to detect corruption in", tc) + } + } +} + +func TestBig(t *testing.T) { + n := 3*1000 + 1 + raw := make([]byte, n) + const alpha = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + for i := 0; i < n; i++ { + raw[i] = alpha[i%len(alpha)] + } + encoded := new(bytes.Buffer) + w := NewEncoder(StdEncoding, encoded) + nn, err := w.Write(raw) + if nn != n || err != nil { + t.Fatalf("Encoder.Write(raw) = %d, %v want %d, nil", nn, err, n) + } + err = w.Close() + if err != nil { + t.Fatalf("Encoder.Close() = %v want nil", err) + } + decoded, err := ioutil.ReadAll(NewDecoder(StdEncoding, encoded)) + if err != nil { + t.Fatalf("ioutil.ReadAll(NewDecoder(...)): %v", err) + } + + if !bytes.Equal(raw, decoded) { + var i int + for i = 0; i < len(decoded) && i < len(raw); i++ { + if decoded[i] != raw[i] { + break + } + } + t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i) + } +} + +func TestNewLineCharacters(t *testing.T) { + // Each of these should decode to the string "sure", without errors. + const expected = "sure" + examples := []string{ + "c3VyZQ==", + "c3VyZQ==\r", + "c3VyZQ==\n", + "c3VyZQ==\r\n", + "c3VyZ\r\nQ==", + "c3V\ryZ\nQ==", + "c3V\nyZ\rQ==", + "c3VyZ\nQ==", + "c3VyZQ\n==", + "c3VyZQ=\n=", + "c3VyZQ=\r\n\r\n=", + } + for _, e := range examples { + buf, err := StdEncoding.DecodeString(e) + if err != nil { + t.Errorf("Decode(%q) failed: %v", e, err) + continue + } + if s := string(buf); s != expected { + t.Errorf("Decode(%q) = %q, want %q", e, s, expected) + } + } +} + +type nextRead struct { + n int // bytes to return + err error // error to return +} + +// faultInjectReader returns data from source, rate-limited +// and with the errors as written to nextc. +type faultInjectReader struct { + source string + nextc <-chan nextRead +} + +func (r *faultInjectReader) Read(p []byte) (int, error) { + nr := <-r.nextc + if len(p) > nr.n { + p = p[:nr.n] + } + n := copy(p, r.source) + r.source = r.source[n:] + return n, nr.err +} + +// tests that we don't ignore errors from our underlying reader +func TestDecoderIssue3577(t *testing.T) { + next := make(chan nextRead, 10) + wantErr := errors.New("my error") + next <- nextRead{5, nil} + next <- nextRead{10, wantErr} + next <- nextRead{0, wantErr} + d := NewDecoder(StdEncoding, &faultInjectReader{ + source: "VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==", // twas brillig... + nextc: next, + }) + errc := make(chan error) + go func() { + _, err := ioutil.ReadAll(d) + errc <- err + }() + select { + case err := <-errc: + if err != wantErr { + t.Errorf("got error %v; want %v", err, wantErr) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout; Decoder blocked without returning an error") + } +} + +func TestDecoderIssue4779(t *testing.T) { + encoded := `CP/EAT8AAAEF +AQEBAQEBAAAAAAAAAAMAAQIEBQYHCAkKCwEAAQUBAQEBAQEAAAAAAAAAAQACAwQFBgcICQoLEAAB +BAEDAgQCBQcGCAUDDDMBAAIRAwQhEjEFQVFhEyJxgTIGFJGhsUIjJBVSwWIzNHKC0UMHJZJT8OHx +Y3M1FqKygyZEk1RkRcKjdDYX0lXiZfKzhMPTdePzRieUpIW0lcTU5PSltcXV5fVWZnaGlqa2xtbm +9jdHV2d3h5ent8fX5/cRAAICAQIEBAMEBQYHBwYFNQEAAhEDITESBEFRYXEiEwUygZEUobFCI8FS +0fAzJGLhcoKSQ1MVY3M08SUGFqKygwcmNcLSRJNUoxdkRVU2dGXi8rOEw9N14/NGlKSFtJXE1OT0 +pbXF1eX1VmZ2hpamtsbW5vYnN0dXZ3eHl6e3x//aAAwDAQACEQMRAD8A9VSSSSUpJJJJSkkkJ+Tj +1kiy1jCJJDnAcCTykpKkuQ6p/jN6FgmxlNduXawwAzaGH+V6jn/R/wCt71zdn+N/qL3kVYFNYB4N +ji6PDVjWpKp9TSXnvTf8bFNjg3qOEa2n6VlLpj/rT/pf567DpX1i6L1hs9Py67X8mqdtg/rUWbbf ++gkp0kkkklKSSSSUpJJJJT//0PVUkkklKVLq3WMDpGI7KzrNjADtYNXvI/Mqr/Pd/q9W3vaxjnvM +NaCXE9gNSvGPrf8AWS3qmba5jjsJhoB0DAf0NDf6sevf+/lf8Hj0JJATfWT6/dV6oXU1uOLQeKKn +EQP+Hubtfe/+R7Mf/g7f5xcocp++Z11JMCJPgFBxOg7/AOuqDx8I/ikpkXkmSdU8mJIJA/O8EMAy +j+mSARB/17pKVXYWHXjsj7yIex0PadzXMO1zT5KHoNA3HT8ietoGhgjsfA+CSnvvqh/jJtqsrwOv +2b6NGNzXfTYexzJ+nU7/ALkf4P8Awv6P9KvTQQ4AgyDqCF85Pho3CTB7eHwXoH+LT65uZbX9X+o2 +bqbPb06551Y4 +` + encodedShort := strings.Replace(encoded, "\n", "", -1) + + dec := NewDecoder(StdEncoding, strings.NewReader(encoded)) + res1, err := ioutil.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + dec = NewDecoder(StdEncoding, strings.NewReader(encodedShort)) + var res2 []byte + res2, err = ioutil.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + if !bytes.Equal(res1, res2) { + t.Error("Decoded results not equal") + } +} + +func TestDecoderIssue7733(t *testing.T) { + s, err := StdEncoding.DecodeString("YWJjZA=====") + want := CorruptInputError(8) + if !reflect.DeepEqual(want, err) { + t.Errorf("Error = %v; want CorruptInputError(8)", err) + } + if string(s) != "abcd" { + t.Errorf("DecodeString = %q; want abcd", s) + } +} + +func BenchmarkEncodeToString(b *testing.B) { + data := make([]byte, 8192) + b.SetBytes(int64(len(data))) + for i := 0; i < b.N; i++ { + StdEncoding.EncodeToString(data) + } +} + +func BenchmarkDecodeString(b *testing.B) { + data := StdEncoding.EncodeToString(make([]byte, 8192)) + b.SetBytes(int64(len(data))) + for i := 0; i < b.N; i++ { + StdEncoding.DecodeString(data) + } +} diff --git a/src/encoding/base64/example_test.go b/src/encoding/base64/example_test.go new file mode 100644 index 000000000..d18b856a0 --- /dev/null +++ b/src/encoding/base64/example_test.go @@ -0,0 +1,45 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Keep in sync with ../base32/example_test.go. + +package base64_test + +import ( + "encoding/base64" + "fmt" + "os" +) + +func ExampleEncoding_EncodeToString() { + data := []byte("any + old & data") + str := base64.StdEncoding.EncodeToString(data) + fmt.Println(str) + // Output: + // YW55ICsgb2xkICYgZGF0YQ== +} + +func ExampleEncoding_DecodeString() { + str := "c29tZSBkYXRhIHdpdGggACBhbmQg77u/" + data, err := base64.StdEncoding.DecodeString(str) + if err != nil { + fmt.Println("error:", err) + return + } + fmt.Printf("%q\n", data) + // Output: + // "some data with \x00 and \ufeff" +} + +func ExampleNewEncoder() { + input := []byte("foo\x00bar") + encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout) + encoder.Write(input) + // Must close the encoder when finished to flush any partial blocks. + // If you comment out the following line, the last partial block "r" + // won't be encoded. + encoder.Close() + // Output: + // Zm9vAGJhcg== +} diff --git a/src/encoding/binary/binary.go b/src/encoding/binary/binary.go new file mode 100644 index 000000000..466bf97c9 --- /dev/null +++ b/src/encoding/binary/binary.go @@ -0,0 +1,634 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package binary implements simple translation between numbers and byte +// sequences and encoding and decoding of varints. +// +// Numbers are translated by reading and writing fixed-size values. +// A fixed-size value is either a fixed-size arithmetic +// type (int8, uint8, int16, float32, complex64, ...) +// or an array or struct containing only fixed-size values. +// +// The varint functions encode and decode single integer values using +// a variable-length encoding; smaller values require fewer bytes. +// For a specification, see +// http://code.google.com/apis/protocolbuffers/docs/encoding.html. +// +// This package favors simplicity over efficiency. Clients that require +// high-performance serialization, especially for large data structures, +// should look at more advanced solutions such as the encoding/gob +// package or protocol buffers. +package binary + +import ( + "errors" + "io" + "math" + "reflect" +) + +// A ByteOrder specifies how to convert byte sequences into +// 16-, 32-, or 64-bit unsigned integers. +type ByteOrder interface { + Uint16([]byte) uint16 + Uint32([]byte) uint32 + Uint64([]byte) uint64 + PutUint16([]byte, uint16) + PutUint32([]byte, uint32) + PutUint64([]byte, uint64) + String() string +} + +// LittleEndian is the little-endian implementation of ByteOrder. +var LittleEndian littleEndian + +// BigEndian is the big-endian implementation of ByteOrder. +var BigEndian bigEndian + +type littleEndian struct{} + +func (littleEndian) Uint16(b []byte) uint16 { return uint16(b[0]) | uint16(b[1])<<8 } + +func (littleEndian) PutUint16(b []byte, v uint16) { + b[0] = byte(v) + b[1] = byte(v >> 8) +} + +func (littleEndian) Uint32(b []byte) uint32 { + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func (littleEndian) PutUint32(b []byte, v uint32) { + b[0] = byte(v) + b[1] = byte(v >> 8) + b[2] = byte(v >> 16) + b[3] = byte(v >> 24) +} + +func (littleEndian) Uint64(b []byte) uint64 { + return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | + uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 +} + +func (littleEndian) PutUint64(b []byte, v uint64) { + b[0] = byte(v) + b[1] = byte(v >> 8) + b[2] = byte(v >> 16) + b[3] = byte(v >> 24) + b[4] = byte(v >> 32) + b[5] = byte(v >> 40) + b[6] = byte(v >> 48) + b[7] = byte(v >> 56) +} + +func (littleEndian) String() string { return "LittleEndian" } + +func (littleEndian) GoString() string { return "binary.LittleEndian" } + +type bigEndian struct{} + +func (bigEndian) Uint16(b []byte) uint16 { return uint16(b[1]) | uint16(b[0])<<8 } + +func (bigEndian) PutUint16(b []byte, v uint16) { + b[0] = byte(v >> 8) + b[1] = byte(v) +} + +func (bigEndian) Uint32(b []byte) uint32 { + return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 +} + +func (bigEndian) PutUint32(b []byte, v uint32) { + b[0] = byte(v >> 24) + b[1] = byte(v >> 16) + b[2] = byte(v >> 8) + b[3] = byte(v) +} + +func (bigEndian) Uint64(b []byte) uint64 { + return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | + uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56 +} + +func (bigEndian) PutUint64(b []byte, v uint64) { + b[0] = byte(v >> 56) + b[1] = byte(v >> 48) + b[2] = byte(v >> 40) + b[3] = byte(v >> 32) + b[4] = byte(v >> 24) + b[5] = byte(v >> 16) + b[6] = byte(v >> 8) + b[7] = byte(v) +} + +func (bigEndian) String() string { return "BigEndian" } + +func (bigEndian) GoString() string { return "binary.BigEndian" } + +// Read reads structured binary data from r into data. +// Data must be a pointer to a fixed-size value or a slice +// of fixed-size values. +// Bytes read from r are decoded using the specified byte order +// and written to successive fields of the data. +// When reading into structs, the field data for fields with +// blank (_) field names is skipped; i.e., blank field names +// may be used for padding. +// When reading into a struct, all non-blank fields must be exported. +func Read(r io.Reader, order ByteOrder, data interface{}) error { + // Fast path for basic types and slices. + if n := intDataSize(data); n != 0 { + var b [8]byte + var bs []byte + if n > len(b) { + bs = make([]byte, n) + } else { + bs = b[:n] + } + if _, err := io.ReadFull(r, bs); err != nil { + return err + } + switch data := data.(type) { + case *int8: + *data = int8(b[0]) + case *uint8: + *data = b[0] + case *int16: + *data = int16(order.Uint16(bs)) + case *uint16: + *data = order.Uint16(bs) + case *int32: + *data = int32(order.Uint32(bs)) + case *uint32: + *data = order.Uint32(bs) + case *int64: + *data = int64(order.Uint64(bs)) + case *uint64: + *data = order.Uint64(bs) + case []int8: + for i, x := range bs { // Easier to loop over the input for 8-bit values. + data[i] = int8(x) + } + case []uint8: + copy(data, bs) + case []int16: + for i := range data { + data[i] = int16(order.Uint16(bs[2*i:])) + } + case []uint16: + for i := range data { + data[i] = order.Uint16(bs[2*i:]) + } + case []int32: + for i := range data { + data[i] = int32(order.Uint32(bs[4*i:])) + } + case []uint32: + for i := range data { + data[i] = order.Uint32(bs[4*i:]) + } + case []int64: + for i := range data { + data[i] = int64(order.Uint64(bs[8*i:])) + } + case []uint64: + for i := range data { + data[i] = order.Uint64(bs[8*i:]) + } + } + return nil + } + + // Fallback to reflect-based decoding. + v := reflect.ValueOf(data) + size := -1 + switch v.Kind() { + case reflect.Ptr: + v = v.Elem() + size = dataSize(v) + case reflect.Slice: + size = dataSize(v) + } + if size < 0 { + return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String()) + } + d := &decoder{order: order, buf: make([]byte, size)} + if _, err := io.ReadFull(r, d.buf); err != nil { + return err + } + d.value(v) + return nil +} + +// Write writes the binary representation of data into w. +// Data must be a fixed-size value or a slice of fixed-size +// values, or a pointer to such data. +// Bytes written to w are encoded using the specified byte order +// and read from successive fields of the data. +// When writing structs, zero values are written for fields +// with blank (_) field names. +func Write(w io.Writer, order ByteOrder, data interface{}) error { + // Fast path for basic types and slices. + if n := intDataSize(data); n != 0 { + var b [8]byte + var bs []byte + if n > len(b) { + bs = make([]byte, n) + } else { + bs = b[:n] + } + switch v := data.(type) { + case *int8: + bs = b[:1] + b[0] = byte(*v) + case int8: + bs = b[:1] + b[0] = byte(v) + case []int8: + for i, x := range v { + bs[i] = byte(x) + } + case *uint8: + bs = b[:1] + b[0] = *v + case uint8: + bs = b[:1] + b[0] = byte(v) + case []uint8: + bs = v + case *int16: + bs = b[:2] + order.PutUint16(bs, uint16(*v)) + case int16: + bs = b[:2] + order.PutUint16(bs, uint16(v)) + case []int16: + for i, x := range v { + order.PutUint16(bs[2*i:], uint16(x)) + } + case *uint16: + bs = b[:2] + order.PutUint16(bs, *v) + case uint16: + bs = b[:2] + order.PutUint16(bs, v) + case []uint16: + for i, x := range v { + order.PutUint16(bs[2*i:], x) + } + case *int32: + bs = b[:4] + order.PutUint32(bs, uint32(*v)) + case int32: + bs = b[:4] + order.PutUint32(bs, uint32(v)) + case []int32: + for i, x := range v { + order.PutUint32(bs[4*i:], uint32(x)) + } + case *uint32: + bs = b[:4] + order.PutUint32(bs, *v) + case uint32: + bs = b[:4] + order.PutUint32(bs, v) + case []uint32: + for i, x := range v { + order.PutUint32(bs[4*i:], x) + } + case *int64: + bs = b[:8] + order.PutUint64(bs, uint64(*v)) + case int64: + bs = b[:8] + order.PutUint64(bs, uint64(v)) + case []int64: + for i, x := range v { + order.PutUint64(bs[8*i:], uint64(x)) + } + case *uint64: + bs = b[:8] + order.PutUint64(bs, *v) + case uint64: + bs = b[:8] + order.PutUint64(bs, v) + case []uint64: + for i, x := range v { + order.PutUint64(bs[8*i:], x) + } + } + _, err := w.Write(bs) + return err + } + + // Fallback to reflect-based encoding. + v := reflect.Indirect(reflect.ValueOf(data)) + size := dataSize(v) + if size < 0 { + return errors.New("binary.Write: invalid type " + reflect.TypeOf(data).String()) + } + buf := make([]byte, size) + e := &encoder{order: order, buf: buf} + e.value(v) + _, err := w.Write(buf) + return err +} + +// Size returns how many bytes Write would generate to encode the value v, which +// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data. +// If v is neither of these, Size returns -1. +func Size(v interface{}) int { + return dataSize(reflect.Indirect(reflect.ValueOf(v))) +} + +// dataSize returns the number of bytes the actual data represented by v occupies in memory. +// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice +// it returns the length of the slice times the element size and does not count the memory +// occupied by the header. If the type of v is not acceptable, dataSize returns -1. +func dataSize(v reflect.Value) int { + if v.Kind() == reflect.Slice { + if s := sizeof(v.Type().Elem()); s >= 0 { + return s * v.Len() + } + return -1 + } + return sizeof(v.Type()) +} + +// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable. +func sizeof(t reflect.Type) int { + switch t.Kind() { + case reflect.Array: + if s := sizeof(t.Elem()); s >= 0 { + return s * t.Len() + } + + case reflect.Struct: + sum := 0 + for i, n := 0, t.NumField(); i < n; i++ { + s := sizeof(t.Field(i).Type) + if s < 0 { + return -1 + } + sum += s + } + return sum + + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return int(t.Size()) + } + + return -1 +} + +type coder struct { + order ByteOrder + buf []byte +} + +type decoder coder +type encoder coder + +func (d *decoder) uint8() uint8 { + x := d.buf[0] + d.buf = d.buf[1:] + return x +} + +func (e *encoder) uint8(x uint8) { + e.buf[0] = x + e.buf = e.buf[1:] +} + +func (d *decoder) uint16() uint16 { + x := d.order.Uint16(d.buf[0:2]) + d.buf = d.buf[2:] + return x +} + +func (e *encoder) uint16(x uint16) { + e.order.PutUint16(e.buf[0:2], x) + e.buf = e.buf[2:] +} + +func (d *decoder) uint32() uint32 { + x := d.order.Uint32(d.buf[0:4]) + d.buf = d.buf[4:] + return x +} + +func (e *encoder) uint32(x uint32) { + e.order.PutUint32(e.buf[0:4], x) + e.buf = e.buf[4:] +} + +func (d *decoder) uint64() uint64 { + x := d.order.Uint64(d.buf[0:8]) + d.buf = d.buf[8:] + return x +} + +func (e *encoder) uint64(x uint64) { + e.order.PutUint64(e.buf[0:8], x) + e.buf = e.buf[8:] +} + +func (d *decoder) int8() int8 { return int8(d.uint8()) } + +func (e *encoder) int8(x int8) { e.uint8(uint8(x)) } + +func (d *decoder) int16() int16 { return int16(d.uint16()) } + +func (e *encoder) int16(x int16) { e.uint16(uint16(x)) } + +func (d *decoder) int32() int32 { return int32(d.uint32()) } + +func (e *encoder) int32(x int32) { e.uint32(uint32(x)) } + +func (d *decoder) int64() int64 { return int64(d.uint64()) } + +func (e *encoder) int64(x int64) { e.uint64(uint64(x)) } + +func (d *decoder) value(v reflect.Value) { + switch v.Kind() { + case reflect.Array: + l := v.Len() + for i := 0; i < l; i++ { + d.value(v.Index(i)) + } + + case reflect.Struct: + t := v.Type() + l := v.NumField() + for i := 0; i < l; i++ { + // Note: Calling v.CanSet() below is an optimization. + // It would be sufficient to check the field name, + // but creating the StructField info for each field is + // costly (run "go test -bench=ReadStruct" and compare + // results when making changes to this code). + if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" { + d.value(v) + } else { + d.skip(v) + } + } + + case reflect.Slice: + l := v.Len() + for i := 0; i < l; i++ { + d.value(v.Index(i)) + } + + case reflect.Int8: + v.SetInt(int64(d.int8())) + case reflect.Int16: + v.SetInt(int64(d.int16())) + case reflect.Int32: + v.SetInt(int64(d.int32())) + case reflect.Int64: + v.SetInt(d.int64()) + + case reflect.Uint8: + v.SetUint(uint64(d.uint8())) + case reflect.Uint16: + v.SetUint(uint64(d.uint16())) + case reflect.Uint32: + v.SetUint(uint64(d.uint32())) + case reflect.Uint64: + v.SetUint(d.uint64()) + + case reflect.Float32: + v.SetFloat(float64(math.Float32frombits(d.uint32()))) + case reflect.Float64: + v.SetFloat(math.Float64frombits(d.uint64())) + + case reflect.Complex64: + v.SetComplex(complex( + float64(math.Float32frombits(d.uint32())), + float64(math.Float32frombits(d.uint32())), + )) + case reflect.Complex128: + v.SetComplex(complex( + math.Float64frombits(d.uint64()), + math.Float64frombits(d.uint64()), + )) + } +} + +func (e *encoder) value(v reflect.Value) { + switch v.Kind() { + case reflect.Array: + l := v.Len() + for i := 0; i < l; i++ { + e.value(v.Index(i)) + } + + case reflect.Struct: + t := v.Type() + l := v.NumField() + for i := 0; i < l; i++ { + // see comment for corresponding code in decoder.value() + if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" { + e.value(v) + } else { + e.skip(v) + } + } + + case reflect.Slice: + l := v.Len() + for i := 0; i < l; i++ { + e.value(v.Index(i)) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch v.Type().Kind() { + case reflect.Int8: + e.int8(int8(v.Int())) + case reflect.Int16: + e.int16(int16(v.Int())) + case reflect.Int32: + e.int32(int32(v.Int())) + case reflect.Int64: + e.int64(v.Int()) + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch v.Type().Kind() { + case reflect.Uint8: + e.uint8(uint8(v.Uint())) + case reflect.Uint16: + e.uint16(uint16(v.Uint())) + case reflect.Uint32: + e.uint32(uint32(v.Uint())) + case reflect.Uint64: + e.uint64(v.Uint()) + } + + case reflect.Float32, reflect.Float64: + switch v.Type().Kind() { + case reflect.Float32: + e.uint32(math.Float32bits(float32(v.Float()))) + case reflect.Float64: + e.uint64(math.Float64bits(v.Float())) + } + + case reflect.Complex64, reflect.Complex128: + switch v.Type().Kind() { + case reflect.Complex64: + x := v.Complex() + e.uint32(math.Float32bits(float32(real(x)))) + e.uint32(math.Float32bits(float32(imag(x)))) + case reflect.Complex128: + x := v.Complex() + e.uint64(math.Float64bits(real(x))) + e.uint64(math.Float64bits(imag(x))) + } + } +} + +func (d *decoder) skip(v reflect.Value) { + d.buf = d.buf[dataSize(v):] +} + +func (e *encoder) skip(v reflect.Value) { + n := dataSize(v) + for i := range e.buf[0:n] { + e.buf[i] = 0 + } + e.buf = e.buf[n:] +} + +// intDataSize returns the size of the data required to represent the data when encoded. +// It returns zero if the type cannot be implemented by the fast path in Read or Write. +func intDataSize(data interface{}) int { + switch data := data.(type) { + case int8, *int8, *uint8: + return 1 + case []int8: + return len(data) + case []uint8: + return len(data) + case int16, *int16, *uint16: + return 2 + case []int16: + return 2 * len(data) + case []uint16: + return 2 * len(data) + case int32, *int32, *uint32: + return 4 + case []int32: + return 4 * len(data) + case []uint32: + return 4 * len(data) + case int64, *int64, *uint64: + return 8 + case []int64: + return 8 * len(data) + case []uint64: + return 8 * len(data) + } + return 0 +} diff --git a/src/encoding/binary/binary_test.go b/src/encoding/binary/binary_test.go new file mode 100644 index 000000000..8ee595fa4 --- /dev/null +++ b/src/encoding/binary/binary_test.go @@ -0,0 +1,416 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package binary + +import ( + "bytes" + "io" + "math" + "reflect" + "strings" + "testing" +) + +type Struct struct { + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Complex64 complex64 + Complex128 complex128 + Array [4]uint8 +} + +type T struct { + Int int + Uint uint + Uintptr uintptr + Array [4]int +} + +var s = Struct{ + 0x01, + 0x0203, + 0x04050607, + 0x08090a0b0c0d0e0f, + 0x10, + 0x1112, + 0x13141516, + 0x1718191a1b1c1d1e, + + math.Float32frombits(0x1f202122), + math.Float64frombits(0x232425262728292a), + complex( + math.Float32frombits(0x2b2c2d2e), + math.Float32frombits(0x2f303132), + ), + complex( + math.Float64frombits(0x333435363738393a), + math.Float64frombits(0x3b3c3d3e3f404142), + ), + + [4]uint8{0x43, 0x44, 0x45, 0x46}, +} + +var big = []byte{ + 1, + 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, + 17, 18, + 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, + + 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, + 43, 44, 45, 46, 47, 48, 49, 50, + 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, + + 67, 68, 69, 70, +} + +var little = []byte{ + 1, + 3, 2, + 7, 6, 5, 4, + 15, 14, 13, 12, 11, 10, 9, 8, + 16, + 18, 17, + 22, 21, 20, 19, + 30, 29, 28, 27, 26, 25, 24, 23, + + 34, 33, 32, 31, + 42, 41, 40, 39, 38, 37, 36, 35, + 46, 45, 44, 43, 50, 49, 48, 47, + 58, 57, 56, 55, 54, 53, 52, 51, 66, 65, 64, 63, 62, 61, 60, 59, + + 67, 68, 69, 70, +} + +var src = []byte{1, 2, 3, 4, 5, 6, 7, 8} +var res = []int32{0x01020304, 0x05060708} + +func checkResult(t *testing.T, dir string, order ByteOrder, err error, have, want interface{}) { + if err != nil { + t.Errorf("%v %v: %v", dir, order, err) + return + } + if !reflect.DeepEqual(have, want) { + t.Errorf("%v %v:\n\thave %+v\n\twant %+v", dir, order, have, want) + } +} + +func testRead(t *testing.T, order ByteOrder, b []byte, s1 interface{}) { + var s2 Struct + err := Read(bytes.NewReader(b), order, &s2) + checkResult(t, "Read", order, err, s2, s1) +} + +func testWrite(t *testing.T, order ByteOrder, b []byte, s1 interface{}) { + buf := new(bytes.Buffer) + err := Write(buf, order, s1) + checkResult(t, "Write", order, err, buf.Bytes(), b) +} + +func TestLittleEndianRead(t *testing.T) { testRead(t, LittleEndian, little, s) } +func TestLittleEndianWrite(t *testing.T) { testWrite(t, LittleEndian, little, s) } +func TestLittleEndianPtrWrite(t *testing.T) { testWrite(t, LittleEndian, little, &s) } + +func TestBigEndianRead(t *testing.T) { testRead(t, BigEndian, big, s) } +func TestBigEndianWrite(t *testing.T) { testWrite(t, BigEndian, big, s) } +func TestBigEndianPtrWrite(t *testing.T) { testWrite(t, BigEndian, big, &s) } + +func TestReadSlice(t *testing.T) { + slice := make([]int32, 2) + err := Read(bytes.NewReader(src), BigEndian, slice) + checkResult(t, "ReadSlice", BigEndian, err, slice, res) +} + +func TestWriteSlice(t *testing.T) { + buf := new(bytes.Buffer) + err := Write(buf, BigEndian, res) + checkResult(t, "WriteSlice", BigEndian, err, buf.Bytes(), src) +} + +// Addresses of arrays are easier to manipulate with reflection than are slices. +var intArrays = []interface{}{ + &[100]int8{}, + &[100]int16{}, + &[100]int32{}, + &[100]int64{}, + &[100]uint8{}, + &[100]uint16{}, + &[100]uint32{}, + &[100]uint64{}, +} + +func TestSliceRoundTrip(t *testing.T) { + buf := new(bytes.Buffer) + for _, array := range intArrays { + src := reflect.ValueOf(array).Elem() + unsigned := false + switch src.Index(0).Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + unsigned = true + } + for i := 0; i < src.Len(); i++ { + if unsigned { + src.Index(i).SetUint(uint64(i * 0x07654321)) + } else { + src.Index(i).SetInt(int64(i * 0x07654321)) + } + } + buf.Reset() + srcSlice := src.Slice(0, src.Len()) + err := Write(buf, BigEndian, srcSlice.Interface()) + if err != nil { + t.Fatal(err) + } + dst := reflect.New(src.Type()).Elem() + dstSlice := dst.Slice(0, dst.Len()) + err = Read(buf, BigEndian, dstSlice.Interface()) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(src.Interface(), dst.Interface()) { + t.Fatal(src) + } + } +} + +func TestWriteT(t *testing.T) { + buf := new(bytes.Buffer) + ts := T{} + if err := Write(buf, BigEndian, ts); err == nil { + t.Errorf("WriteT: have err == nil, want non-nil") + } + + tv := reflect.Indirect(reflect.ValueOf(ts)) + for i, n := 0, tv.NumField(); i < n; i++ { + typ := tv.Field(i).Type().String() + if typ == "[4]int" { + typ = "int" // the problem is int, not the [4] + } + if err := Write(buf, BigEndian, tv.Field(i).Interface()); err == nil { + t.Errorf("WriteT.%v: have err == nil, want non-nil", tv.Field(i).Type()) + } else if !strings.Contains(err.Error(), typ) { + t.Errorf("WriteT: have err == %q, want it to mention %s", err, typ) + } + } +} + +type BlankFields struct { + A uint32 + _ int32 + B float64 + _ [4]int16 + C byte + _ [7]byte + _ struct { + f [8]float32 + } +} + +type BlankFieldsProbe struct { + A uint32 + P0 int32 + B float64 + P1 [4]int16 + C byte + P2 [7]byte + P3 struct { + F [8]float32 + } +} + +func TestBlankFields(t *testing.T) { + buf := new(bytes.Buffer) + b1 := BlankFields{A: 1234567890, B: 2.718281828, C: 42} + if err := Write(buf, LittleEndian, &b1); err != nil { + t.Error(err) + } + + // zero values must have been written for blank fields + var p BlankFieldsProbe + if err := Read(buf, LittleEndian, &p); err != nil { + t.Error(err) + } + + // quick test: only check first value of slices + if p.P0 != 0 || p.P1[0] != 0 || p.P2[0] != 0 || p.P3.F[0] != 0 { + t.Errorf("non-zero values for originally blank fields: %#v", p) + } + + // write p and see if we can probe only some fields + if err := Write(buf, LittleEndian, &p); err != nil { + t.Error(err) + } + + // read should ignore blank fields in b2 + var b2 BlankFields + if err := Read(buf, LittleEndian, &b2); err != nil { + t.Error(err) + } + if b1.A != b2.A || b1.B != b2.B || b1.C != b2.C { + t.Errorf("%#v != %#v", b1, b2) + } +} + +// An attempt to read into a struct with an unexported field will +// panic. This is probably not the best choice, but at this point +// anything else would be an API change. + +type Unexported struct { + a int32 +} + +func TestUnexportedRead(t *testing.T) { + var buf bytes.Buffer + u1 := Unexported{a: 1} + if err := Write(&buf, LittleEndian, &u1); err != nil { + t.Fatal(err) + } + + defer func() { + if recover() == nil { + t.Fatal("did not panic") + } + }() + var u2 Unexported + Read(&buf, LittleEndian, &u2) +} + +func TestReadErrorMsg(t *testing.T) { + var buf bytes.Buffer + read := func(data interface{}) { + err := Read(&buf, LittleEndian, data) + want := "binary.Read: invalid type " + reflect.TypeOf(data).String() + if err == nil { + t.Errorf("%T: got no error; want %q", data, want) + return + } + if got := err.Error(); got != want { + t.Errorf("%T: got %q; want %q", data, got, want) + } + } + read(0) + s := new(struct{}) + read(&s) + p := &s + read(&p) +} + +type byteSliceReader struct { + remain []byte +} + +func (br *byteSliceReader) Read(p []byte) (int, error) { + n := copy(p, br.remain) + br.remain = br.remain[n:] + return n, nil +} + +func BenchmarkReadSlice1000Int32s(b *testing.B) { + bsr := &byteSliceReader{} + slice := make([]int32, 1000) + buf := make([]byte, len(slice)*4) + b.SetBytes(int64(len(buf))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + bsr.remain = buf + Read(bsr, BigEndian, slice) + } +} + +func BenchmarkReadStruct(b *testing.B) { + bsr := &byteSliceReader{} + var buf bytes.Buffer + Write(&buf, BigEndian, &s) + b.SetBytes(int64(dataSize(reflect.ValueOf(s)))) + t := s + b.ResetTimer() + for i := 0; i < b.N; i++ { + bsr.remain = buf.Bytes() + Read(bsr, BigEndian, &t) + } + b.StopTimer() + if !reflect.DeepEqual(s, t) { + b.Fatal("no match") + } +} + +func BenchmarkReadInts(b *testing.B) { + var ls Struct + bsr := &byteSliceReader{} + var r io.Reader = bsr + b.SetBytes(2 * (1 + 2 + 4 + 8)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + bsr.remain = big + Read(r, BigEndian, &ls.Int8) + Read(r, BigEndian, &ls.Int16) + Read(r, BigEndian, &ls.Int32) + Read(r, BigEndian, &ls.Int64) + Read(r, BigEndian, &ls.Uint8) + Read(r, BigEndian, &ls.Uint16) + Read(r, BigEndian, &ls.Uint32) + Read(r, BigEndian, &ls.Uint64) + } + + want := s + want.Float32 = 0 + want.Float64 = 0 + want.Complex64 = 0 + want.Complex128 = 0 + for i := range want.Array { + want.Array[i] = 0 + } + b.StopTimer() + if !reflect.DeepEqual(ls, want) { + panic("no match") + } +} + +func BenchmarkWriteInts(b *testing.B) { + buf := new(bytes.Buffer) + var w io.Writer = buf + b.SetBytes(2 * (1 + 2 + 4 + 8)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + Write(w, BigEndian, s.Int8) + Write(w, BigEndian, s.Int16) + Write(w, BigEndian, s.Int32) + Write(w, BigEndian, s.Int64) + Write(w, BigEndian, s.Uint8) + Write(w, BigEndian, s.Uint16) + Write(w, BigEndian, s.Uint32) + Write(w, BigEndian, s.Uint64) + } + b.StopTimer() + if !bytes.Equal(buf.Bytes(), big[:30]) { + b.Fatalf("first half doesn't match: %x %x", buf.Bytes(), big[:30]) + } +} + +func BenchmarkWriteSlice1000Int32s(b *testing.B) { + slice := make([]int32, 1000) + buf := new(bytes.Buffer) + var w io.Writer = buf + b.SetBytes(4 * 1000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + Write(w, BigEndian, slice) + } + b.StopTimer() +} diff --git a/src/encoding/binary/example_test.go b/src/encoding/binary/example_test.go new file mode 100644 index 000000000..067cf553b --- /dev/null +++ b/src/encoding/binary/example_test.go @@ -0,0 +1,52 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package binary_test + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" +) + +func ExampleWrite() { + buf := new(bytes.Buffer) + var pi float64 = math.Pi + err := binary.Write(buf, binary.LittleEndian, pi) + if err != nil { + fmt.Println("binary.Write failed:", err) + } + fmt.Printf("% x", buf.Bytes()) + // Output: 18 2d 44 54 fb 21 09 40 +} + +func ExampleWrite_multi() { + buf := new(bytes.Buffer) + var data = []interface{}{ + uint16(61374), + int8(-54), + uint8(254), + } + for _, v := range data { + err := binary.Write(buf, binary.LittleEndian, v) + if err != nil { + fmt.Println("binary.Write failed:", err) + } + } + fmt.Printf("%x", buf.Bytes()) + // Output: beefcafe +} + +func ExampleRead() { + var pi float64 + b := []byte{0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40} + buf := bytes.NewReader(b) + err := binary.Read(buf, binary.LittleEndian, &pi) + if err != nil { + fmt.Println("binary.Read failed:", err) + } + fmt.Print(pi) + // Output: 3.141592653589793 +} diff --git a/src/encoding/binary/varint.go b/src/encoding/binary/varint.go new file mode 100644 index 000000000..3a2dfa3c7 --- /dev/null +++ b/src/encoding/binary/varint.go @@ -0,0 +1,133 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package binary + +// This file implements "varint" encoding of 64-bit integers. +// The encoding is: +// - unsigned integers are serialized 7 bits at a time, starting with the +// least significant bits +// - the most significant bit (msb) in each output byte indicates if there +// is a continuation byte (msb = 1) +// - signed integers are mapped to unsigned integers using "zig-zag" +// encoding: Positive values x are written as 2*x + 0, negative values +// are written as 2*(^x) + 1; that is, negative numbers are complemented +// and whether to complement is encoded in bit 0. +// +// Design note: +// At most 10 bytes are needed for 64-bit values. The encoding could +// be more dense: a full 64-bit value needs an extra byte just to hold bit 63. +// Instead, the msb of the previous byte could be used to hold bit 63 since we +// know there can't be more than 64 bits. This is a trivial improvement and +// would reduce the maximum encoding length to 9 bytes. However, it breaks the +// invariant that the msb is always the "continuation bit" and thus makes the +// format incompatible with a varint encoding for larger numbers (say 128-bit). + +import ( + "errors" + "io" +) + +// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer. +const ( + MaxVarintLen16 = 3 + MaxVarintLen32 = 5 + MaxVarintLen64 = 10 +) + +// PutUvarint encodes a uint64 into buf and returns the number of bytes written. +// If the buffer is too small, PutUvarint will panic. +func PutUvarint(buf []byte, x uint64) int { + i := 0 + for x >= 0x80 { + buf[i] = byte(x) | 0x80 + x >>= 7 + i++ + } + buf[i] = byte(x) + return i + 1 +} + +// Uvarint decodes a uint64 from buf and returns that value and the +// number of bytes read (> 0). If an error occurred, the value is 0 +// and the number of bytes n is <= 0 meaning: +// +// n == 0: buf too small +// n < 0: value larger than 64 bits (overflow) +// and -n is the number of bytes read +// +func Uvarint(buf []byte) (uint64, int) { + var x uint64 + var s uint + for i, b := range buf { + if b < 0x80 { + if i > 9 || i == 9 && b > 1 { + return 0, -(i + 1) // overflow + } + return x | uint64(b)<<s, i + 1 + } + x |= uint64(b&0x7f) << s + s += 7 + } + return 0, 0 +} + +// PutVarint encodes an int64 into buf and returns the number of bytes written. +// If the buffer is too small, PutVarint will panic. +func PutVarint(buf []byte, x int64) int { + ux := uint64(x) << 1 + if x < 0 { + ux = ^ux + } + return PutUvarint(buf, ux) +} + +// Varint decodes an int64 from buf and returns that value and the +// number of bytes read (> 0). If an error occurred, the value is 0 +// and the number of bytes n is <= 0 with the following meaning: +// +// n == 0: buf too small +// n < 0: value larger than 64 bits (overflow) +// and -n is the number of bytes read +// +func Varint(buf []byte) (int64, int) { + ux, n := Uvarint(buf) // ok to continue in presence of error + x := int64(ux >> 1) + if ux&1 != 0 { + x = ^x + } + return x, n +} + +var overflow = errors.New("binary: varint overflows a 64-bit integer") + +// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64. +func ReadUvarint(r io.ByteReader) (uint64, error) { + var x uint64 + var s uint + for i := 0; ; i++ { + b, err := r.ReadByte() + if err != nil { + return x, err + } + if b < 0x80 { + if i > 9 || i == 9 && b > 1 { + return x, overflow + } + return x | uint64(b)<<s, nil + } + x |= uint64(b&0x7f) << s + s += 7 + } +} + +// ReadVarint reads an encoded signed integer from r and returns it as an int64. +func ReadVarint(r io.ByteReader) (int64, error) { + ux, err := ReadUvarint(r) // ok to continue in presence of error + x := int64(ux >> 1) + if ux&1 != 0 { + x = ^x + } + return x, err +} diff --git a/src/encoding/binary/varint_test.go b/src/encoding/binary/varint_test.go new file mode 100644 index 000000000..ca411ecbd --- /dev/null +++ b/src/encoding/binary/varint_test.go @@ -0,0 +1,168 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package binary + +import ( + "bytes" + "io" + "testing" +) + +func testConstant(t *testing.T, w uint, max int) { + buf := make([]byte, MaxVarintLen64) + n := PutUvarint(buf, 1<<w-1) + if n != max { + t.Errorf("MaxVarintLen%d = %d; want %d", w, max, n) + } +} + +func TestConstants(t *testing.T) { + testConstant(t, 16, MaxVarintLen16) + testConstant(t, 32, MaxVarintLen32) + testConstant(t, 64, MaxVarintLen64) +} + +func testVarint(t *testing.T, x int64) { + buf := make([]byte, MaxVarintLen64) + n := PutVarint(buf, x) + y, m := Varint(buf[0:n]) + if x != y { + t.Errorf("Varint(%d): got %d", x, y) + } + if n != m { + t.Errorf("Varint(%d): got n = %d; want %d", x, m, n) + } + + y, err := ReadVarint(bytes.NewReader(buf)) + if err != nil { + t.Errorf("ReadVarint(%d): %s", x, err) + } + if x != y { + t.Errorf("ReadVarint(%d): got %d", x, y) + } +} + +func testUvarint(t *testing.T, x uint64) { + buf := make([]byte, MaxVarintLen64) + n := PutUvarint(buf, x) + y, m := Uvarint(buf[0:n]) + if x != y { + t.Errorf("Uvarint(%d): got %d", x, y) + } + if n != m { + t.Errorf("Uvarint(%d): got n = %d; want %d", x, m, n) + } + + y, err := ReadUvarint(bytes.NewReader(buf)) + if err != nil { + t.Errorf("ReadUvarint(%d): %s", x, err) + } + if x != y { + t.Errorf("ReadUvarint(%d): got %d", x, y) + } +} + +var tests = []int64{ + -1 << 63, + -1<<63 + 1, + -1, + 0, + 1, + 2, + 10, + 20, + 63, + 64, + 65, + 127, + 128, + 129, + 255, + 256, + 257, + 1<<63 - 1, +} + +func TestVarint(t *testing.T) { + for _, x := range tests { + testVarint(t, x) + testVarint(t, -x) + } + for x := int64(0x7); x != 0; x <<= 1 { + testVarint(t, x) + testVarint(t, -x) + } +} + +func TestUvarint(t *testing.T) { + for _, x := range tests { + testUvarint(t, uint64(x)) + } + for x := uint64(0x7); x != 0; x <<= 1 { + testUvarint(t, x) + } +} + +func TestBufferTooSmall(t *testing.T) { + buf := []byte{0x80, 0x80, 0x80, 0x80} + for i := 0; i <= len(buf); i++ { + buf := buf[0:i] + x, n := Uvarint(buf) + if x != 0 || n != 0 { + t.Errorf("Uvarint(%v): got x = %d, n = %d", buf, x, n) + } + + x, err := ReadUvarint(bytes.NewReader(buf)) + if x != 0 || err != io.EOF { + t.Errorf("ReadUvarint(%v): got x = %d, err = %s", buf, x, err) + } + } +} + +func testOverflow(t *testing.T, buf []byte, n0 int, err0 error) { + x, n := Uvarint(buf) + if x != 0 || n != n0 { + t.Errorf("Uvarint(%v): got x = %d, n = %d; want 0, %d", buf, x, n, n0) + } + + x, err := ReadUvarint(bytes.NewReader(buf)) + if x != 0 || err != err0 { + t.Errorf("ReadUvarint(%v): got x = %d, err = %s; want 0, %s", buf, x, err, err0) + } +} + +func TestOverflow(t *testing.T) { + testOverflow(t, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x2}, -10, overflow) + testOverflow(t, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x1, 0, 0}, -13, overflow) +} + +func TestNonCanonicalZero(t *testing.T) { + buf := []byte{0x80, 0x80, 0x80, 0} + x, n := Uvarint(buf) + if x != 0 || n != 4 { + t.Errorf("Uvarint(%v): got x = %d, n = %d; want 0, 4", buf, x, n) + + } +} + +func BenchmarkPutUvarint32(b *testing.B) { + buf := make([]byte, MaxVarintLen32) + b.SetBytes(4) + for i := 0; i < b.N; i++ { + for j := uint(0); j < MaxVarintLen32; j++ { + PutUvarint(buf, 1<<(j*7)) + } + } +} + +func BenchmarkPutUvarint64(b *testing.B) { + buf := make([]byte, MaxVarintLen64) + b.SetBytes(8) + for i := 0; i < b.N; i++ { + for j := uint(0); j < MaxVarintLen64; j++ { + PutUvarint(buf, 1<<(j*7)) + } + } +} diff --git a/src/encoding/csv/reader.go b/src/encoding/csv/reader.go new file mode 100644 index 000000000..d9432954a --- /dev/null +++ b/src/encoding/csv/reader.go @@ -0,0 +1,337 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package csv reads and writes comma-separated values (CSV) files. +// +// A csv file contains zero or more records of one or more fields per record. +// Each record is separated by the newline character. The final record may +// optionally be followed by a newline character. +// +// field1,field2,field3 +// +// White space is considered part of a field. +// +// Carriage returns before newline characters are silently removed. +// +// Blank lines are ignored. A line with only whitespace characters (excluding +// the ending newline character) is not considered a blank line. +// +// Fields which start and stop with the quote character " are called +// quoted-fields. The beginning and ending quote are not part of the +// field. +// +// The source: +// +// normal string,"quoted-field" +// +// results in the fields +// +// {`normal string`, `quoted-field`} +// +// Within a quoted-field a quote character followed by a second quote +// character is considered a single quote. +// +// "the ""word"" is true","a ""quoted-field""" +// +// results in +// +// {`the "word" is true`, `a "quoted-field"`} +// +// Newlines and commas may be included in a quoted-field +// +// "Multi-line +// field","comma is ," +// +// results in +// +// {`Multi-line +// field`, `comma is ,`} +package csv + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "unicode" +) + +// A ParseError is returned for parsing errors. +// The first line is 1. The first column is 0. +type ParseError struct { + Line int // Line where the error occurred + Column int // Column (rune index) where the error occurred + Err error // The actual error +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("line %d, column %d: %s", e.Line, e.Column, e.Err) +} + +// These are the errors that can be returned in ParseError.Error +var ( + ErrTrailingComma = errors.New("extra delimiter at end of line") // no longer used + ErrBareQuote = errors.New("bare \" in non-quoted-field") + ErrQuote = errors.New("extraneous \" in field") + ErrFieldCount = errors.New("wrong number of fields in line") +) + +// A Reader reads records from a CSV-encoded file. +// +// As returned by NewReader, a Reader expects input conforming to RFC 4180. +// The exported fields can be changed to customize the details before the +// first call to Read or ReadAll. +// +// Comma is the field delimiter. It defaults to ','. +// +// Comment, if not 0, is the comment character. Lines beginning with the +// Comment character are ignored. +// +// If FieldsPerRecord is positive, Read requires each record to +// have the given number of fields. If FieldsPerRecord is 0, Read sets it to +// the number of fields in the first record, so that future records must +// have the same field count. If FieldsPerRecord is negative, no check is +// made and records may have a variable number of fields. +// +// If LazyQuotes is true, a quote may appear in an unquoted field and a +// non-doubled quote may appear in a quoted field. +// +// If TrimLeadingSpace is true, leading white space in a field is ignored. +type Reader struct { + Comma rune // field delimiter (set to ',' by NewReader) + Comment rune // comment character for start of line + FieldsPerRecord int // number of expected fields per record + LazyQuotes bool // allow lazy quotes + TrailingComma bool // ignored; here for backwards compatibility + TrimLeadingSpace bool // trim leading space + line int + column int + r *bufio.Reader + field bytes.Buffer +} + +// NewReader returns a new Reader that reads from r. +func NewReader(r io.Reader) *Reader { + return &Reader{ + Comma: ',', + r: bufio.NewReader(r), + } +} + +// error creates a new ParseError based on err. +func (r *Reader) error(err error) error { + return &ParseError{ + Line: r.line, + Column: r.column, + Err: err, + } +} + +// Read reads one record from r. The record is a slice of strings with each +// string representing one field. +func (r *Reader) Read() (record []string, err error) { + for { + record, err = r.parseRecord() + if record != nil { + break + } + if err != nil { + return nil, err + } + } + + if r.FieldsPerRecord > 0 { + if len(record) != r.FieldsPerRecord { + r.column = 0 // report at start of record + return record, r.error(ErrFieldCount) + } + } else if r.FieldsPerRecord == 0 { + r.FieldsPerRecord = len(record) + } + return record, nil +} + +// ReadAll reads all the remaining records from r. +// Each record is a slice of fields. +// A successful call returns err == nil, not err == EOF. Because ReadAll is +// defined to read until EOF, it does not treat end of file as an error to be +// reported. +func (r *Reader) ReadAll() (records [][]string, err error) { + for { + record, err := r.Read() + if err == io.EOF { + return records, nil + } + if err != nil { + return nil, err + } + records = append(records, record) + } +} + +// readRune reads one rune from r, folding \r\n to \n and keeping track +// of how far into the line we have read. r.column will point to the start +// of this rune, not the end of this rune. +func (r *Reader) readRune() (rune, error) { + r1, _, err := r.r.ReadRune() + + // Handle \r\n here. We make the simplifying assumption that + // anytime \r is followed by \n that it can be folded to \n. + // We will not detect files which contain both \r\n and bare \n. + if r1 == '\r' { + r1, _, err = r.r.ReadRune() + if err == nil { + if r1 != '\n' { + r.r.UnreadRune() + r1 = '\r' + } + } + } + r.column++ + return r1, err +} + +// skip reads runes up to and including the rune delim or until error. +func (r *Reader) skip(delim rune) error { + for { + r1, err := r.readRune() + if err != nil { + return err + } + if r1 == delim { + return nil + } + } +} + +// parseRecord reads and parses a single csv record from r. +func (r *Reader) parseRecord() (fields []string, err error) { + // Each record starts on a new line. We increment our line + // number (lines start at 1, not 0) and set column to -1 + // so as we increment in readRune it points to the character we read. + r.line++ + r.column = -1 + + // Peek at the first rune. If it is an error we are done. + // If we are support comments and it is the comment character + // then skip to the end of line. + + r1, _, err := r.r.ReadRune() + if err != nil { + return nil, err + } + + if r.Comment != 0 && r1 == r.Comment { + return nil, r.skip('\n') + } + r.r.UnreadRune() + + // At this point we have at least one field. + for { + haveField, delim, err := r.parseField() + if haveField { + fields = append(fields, r.field.String()) + } + if delim == '\n' || err == io.EOF { + return fields, err + } else if err != nil { + return nil, err + } + } +} + +// parseField parses the next field in the record. The read field is +// located in r.field. Delim is the first character not part of the field +// (r.Comma or '\n'). +func (r *Reader) parseField() (haveField bool, delim rune, err error) { + r.field.Reset() + + r1, err := r.readRune() + for err == nil && r.TrimLeadingSpace && r1 != '\n' && unicode.IsSpace(r1) { + r1, err = r.readRune() + } + + if err == io.EOF && r.column != 0 { + return true, 0, err + } + if err != nil { + return false, 0, err + } + + switch r1 { + case r.Comma: + // will check below + + case '\n': + // We are a trailing empty field or a blank line + if r.column == 0 { + return false, r1, nil + } + return true, r1, nil + + case '"': + // quoted field + Quoted: + for { + r1, err = r.readRune() + if err != nil { + if err == io.EOF { + if r.LazyQuotes { + return true, 0, err + } + return false, 0, r.error(ErrQuote) + } + return false, 0, err + } + switch r1 { + case '"': + r1, err = r.readRune() + if err != nil || r1 == r.Comma { + break Quoted + } + if r1 == '\n' { + return true, r1, nil + } + if r1 != '"' { + if !r.LazyQuotes { + r.column-- + return false, 0, r.error(ErrQuote) + } + // accept the bare quote + r.field.WriteRune('"') + } + case '\n': + r.line++ + r.column = -1 + } + r.field.WriteRune(r1) + } + + default: + // unquoted field + for { + r.field.WriteRune(r1) + r1, err = r.readRune() + if err != nil || r1 == r.Comma { + break + } + if r1 == '\n' { + return true, r1, nil + } + if !r.LazyQuotes && r1 == '"' { + return false, 0, r.error(ErrBareQuote) + } + } + } + + if err != nil { + if err == io.EOF { + return true, 0, err + } + return false, 0, err + } + + return true, r1, nil +} diff --git a/src/encoding/csv/reader_test.go b/src/encoding/csv/reader_test.go new file mode 100644 index 000000000..123df06bc --- /dev/null +++ b/src/encoding/csv/reader_test.go @@ -0,0 +1,284 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package csv + +import ( + "reflect" + "strings" + "testing" +) + +var readTests = []struct { + Name string + Input string + Output [][]string + UseFieldsPerRecord bool // false (default) means FieldsPerRecord is -1 + + // These fields are copied into the Reader + Comma rune + Comment rune + FieldsPerRecord int + LazyQuotes bool + TrailingComma bool + TrimLeadingSpace bool + + Error string + Line int // Expected error line if != 0 + Column int // Expected error column if line != 0 +}{ + { + Name: "Simple", + Input: "a,b,c\n", + Output: [][]string{{"a", "b", "c"}}, + }, + { + Name: "CRLF", + Input: "a,b\r\nc,d\r\n", + Output: [][]string{{"a", "b"}, {"c", "d"}}, + }, + { + Name: "BareCR", + Input: "a,b\rc,d\r\n", + Output: [][]string{{"a", "b\rc", "d"}}, + }, + { + Name: "RFC4180test", + UseFieldsPerRecord: true, + Input: `#field1,field2,field3 +"aaa","bb +b","ccc" +"a,a","b""bb","ccc" +zzz,yyy,xxx +`, + Output: [][]string{ + {"#field1", "field2", "field3"}, + {"aaa", "bb\nb", "ccc"}, + {"a,a", `b"bb`, "ccc"}, + {"zzz", "yyy", "xxx"}, + }, + }, + { + Name: "NoEOLTest", + Input: "a,b,c", + Output: [][]string{{"a", "b", "c"}}, + }, + { + Name: "Semicolon", + Comma: ';', + Input: "a;b;c\n", + Output: [][]string{{"a", "b", "c"}}, + }, + { + Name: "MultiLine", + Input: `"two +line","one line","three +line +field"`, + Output: [][]string{{"two\nline", "one line", "three\nline\nfield"}}, + }, + { + Name: "BlankLine", + Input: "a,b,c\n\nd,e,f\n\n", + Output: [][]string{ + {"a", "b", "c"}, + {"d", "e", "f"}, + }, + }, + { + Name: "TrimSpace", + Input: " a, b, c\n", + TrimLeadingSpace: true, + Output: [][]string{{"a", "b", "c"}}, + }, + { + Name: "LeadingSpace", + Input: " a, b, c\n", + Output: [][]string{{" a", " b", " c"}}, + }, + { + Name: "Comment", + Comment: '#', + Input: "#1,2,3\na,b,c\n#comment", + Output: [][]string{{"a", "b", "c"}}, + }, + { + Name: "NoComment", + Input: "#1,2,3\na,b,c", + Output: [][]string{{"#1", "2", "3"}, {"a", "b", "c"}}, + }, + { + Name: "LazyQuotes", + LazyQuotes: true, + Input: `a "word","1"2",a","b`, + Output: [][]string{{`a "word"`, `1"2`, `a"`, `b`}}, + }, + { + Name: "BareQuotes", + LazyQuotes: true, + Input: `a "word","1"2",a"`, + Output: [][]string{{`a "word"`, `1"2`, `a"`}}, + }, + { + Name: "BareDoubleQuotes", + LazyQuotes: true, + Input: `a""b,c`, + Output: [][]string{{`a""b`, `c`}}, + }, + { + Name: "BadDoubleQuotes", + Input: `a""b,c`, + Error: `bare " in non-quoted-field`, Line: 1, Column: 1, + }, + { + Name: "TrimQuote", + Input: ` "a"," b",c`, + TrimLeadingSpace: true, + Output: [][]string{{"a", " b", "c"}}, + }, + { + Name: "BadBareQuote", + Input: `a "word","b"`, + Error: `bare " in non-quoted-field`, Line: 1, Column: 2, + }, + { + Name: "BadTrailingQuote", + Input: `"a word",b"`, + Error: `bare " in non-quoted-field`, Line: 1, Column: 10, + }, + { + Name: "ExtraneousQuote", + Input: `"a "word","b"`, + Error: `extraneous " in field`, Line: 1, Column: 3, + }, + { + Name: "BadFieldCount", + UseFieldsPerRecord: true, + Input: "a,b,c\nd,e", + Error: "wrong number of fields", Line: 2, + }, + { + Name: "BadFieldCount1", + UseFieldsPerRecord: true, + FieldsPerRecord: 2, + Input: `a,b,c`, + Error: "wrong number of fields", Line: 1, + }, + { + Name: "FieldCount", + Input: "a,b,c\nd,e", + Output: [][]string{{"a", "b", "c"}, {"d", "e"}}, + }, + { + Name: "TrailingCommaEOF", + Input: "a,b,c,", + Output: [][]string{{"a", "b", "c", ""}}, + }, + { + Name: "TrailingCommaEOL", + Input: "a,b,c,\n", + Output: [][]string{{"a", "b", "c", ""}}, + }, + { + Name: "TrailingCommaSpaceEOF", + TrimLeadingSpace: true, + Input: "a,b,c, ", + Output: [][]string{{"a", "b", "c", ""}}, + }, + { + Name: "TrailingCommaSpaceEOL", + TrimLeadingSpace: true, + Input: "a,b,c, \n", + Output: [][]string{{"a", "b", "c", ""}}, + }, + { + Name: "TrailingCommaLine3", + TrimLeadingSpace: true, + Input: "a,b,c\nd,e,f\ng,hi,", + Output: [][]string{{"a", "b", "c"}, {"d", "e", "f"}, {"g", "hi", ""}}, + }, + { + Name: "NotTrailingComma3", + Input: "a,b,c, \n", + Output: [][]string{{"a", "b", "c", " "}}, + }, + { + Name: "CommaFieldTest", + TrailingComma: true, + Input: `x,y,z,w +x,y,z, +x,y,, +x,,, +,,, +"x","y","z","w" +"x","y","z","" +"x","y","","" +"x","","","" +"","","","" +`, + Output: [][]string{ + {"x", "y", "z", "w"}, + {"x", "y", "z", ""}, + {"x", "y", "", ""}, + {"x", "", "", ""}, + {"", "", "", ""}, + {"x", "y", "z", "w"}, + {"x", "y", "z", ""}, + {"x", "y", "", ""}, + {"x", "", "", ""}, + {"", "", "", ""}, + }, + }, + { + Name: "TrailingCommaIneffective1", + TrailingComma: true, + TrimLeadingSpace: true, + Input: "a,b,\nc,d,e", + Output: [][]string{ + {"a", "b", ""}, + {"c", "d", "e"}, + }, + }, + { + Name: "TrailingCommaIneffective2", + TrailingComma: false, + TrimLeadingSpace: true, + Input: "a,b,\nc,d,e", + Output: [][]string{ + {"a", "b", ""}, + {"c", "d", "e"}, + }, + }, +} + +func TestRead(t *testing.T) { + for _, tt := range readTests { + r := NewReader(strings.NewReader(tt.Input)) + r.Comment = tt.Comment + if tt.UseFieldsPerRecord { + r.FieldsPerRecord = tt.FieldsPerRecord + } else { + r.FieldsPerRecord = -1 + } + r.LazyQuotes = tt.LazyQuotes + r.TrailingComma = tt.TrailingComma + r.TrimLeadingSpace = tt.TrimLeadingSpace + if tt.Comma != 0 { + r.Comma = tt.Comma + } + out, err := r.ReadAll() + perr, _ := err.(*ParseError) + if tt.Error != "" { + if err == nil || !strings.Contains(err.Error(), tt.Error) { + t.Errorf("%s: error %v, want error %q", tt.Name, err, tt.Error) + } else if tt.Line != 0 && (tt.Line != perr.Line || tt.Column != perr.Column) { + t.Errorf("%s: error at %d:%d expected %d:%d", tt.Name, perr.Line, perr.Column, tt.Line, tt.Column) + } + } else if err != nil { + t.Errorf("%s: unexpected error %v", tt.Name, err) + } else if !reflect.DeepEqual(out, tt.Output) { + t.Errorf("%s: out=%q want %q", tt.Name, out, tt.Output) + } + } +} diff --git a/src/encoding/csv/writer.go b/src/encoding/csv/writer.go new file mode 100644 index 000000000..17e7bb7f5 --- /dev/null +++ b/src/encoding/csv/writer.go @@ -0,0 +1,139 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package csv + +import ( + "bufio" + "io" + "strings" + "unicode" + "unicode/utf8" +) + +// A Writer writes records to a CSV encoded file. +// +// As returned by NewWriter, a Writer writes records terminated by a +// newline and uses ',' as the field delimiter. The exported fields can be +// changed to customize the details before the first call to Write or WriteAll. +// +// Comma is the field delimiter. +// +// If UseCRLF is true, the Writer ends each record with \r\n instead of \n. +type Writer struct { + Comma rune // Field delimiter (set to ',' by NewWriter) + UseCRLF bool // True to use \r\n as the line terminator + w *bufio.Writer +} + +// NewWriter returns a new Writer that writes to w. +func NewWriter(w io.Writer) *Writer { + return &Writer{ + Comma: ',', + w: bufio.NewWriter(w), + } +} + +// Writer writes a single CSV record to w along with any necessary quoting. +// A record is a slice of strings with each string being one field. +func (w *Writer) Write(record []string) (err error) { + for n, field := range record { + if n > 0 { + if _, err = w.w.WriteRune(w.Comma); err != nil { + return + } + } + + // If we don't have to have a quoted field then just + // write out the field and continue to the next field. + if !w.fieldNeedsQuotes(field) { + if _, err = w.w.WriteString(field); err != nil { + return + } + continue + } + if err = w.w.WriteByte('"'); err != nil { + return + } + + for _, r1 := range field { + switch r1 { + case '"': + _, err = w.w.WriteString(`""`) + case '\r': + if !w.UseCRLF { + err = w.w.WriteByte('\r') + } + case '\n': + if w.UseCRLF { + _, err = w.w.WriteString("\r\n") + } else { + err = w.w.WriteByte('\n') + } + default: + _, err = w.w.WriteRune(r1) + } + if err != nil { + return + } + } + + if err = w.w.WriteByte('"'); err != nil { + return + } + } + if w.UseCRLF { + _, err = w.w.WriteString("\r\n") + } else { + err = w.w.WriteByte('\n') + } + return +} + +// Flush writes any buffered data to the underlying io.Writer. +// To check if an error occurred during the Flush, call Error. +func (w *Writer) Flush() { + w.w.Flush() +} + +// Error reports any error that has occurred during a previous Write or Flush. +func (w *Writer) Error() error { + _, err := w.w.Write(nil) + return err +} + +// WriteAll writes multiple CSV records to w using Write and then calls Flush. +func (w *Writer) WriteAll(records [][]string) (err error) { + for _, record := range records { + err = w.Write(record) + if err != nil { + return err + } + } + return w.w.Flush() +} + +// fieldNeedsQuotes returns true if our field must be enclosed in quotes. +// Fields with a Comma, fields with a quote or newline, and +// fields which start with a space must be enclosed in quotes. +// We used to quote empty strings, but we do not anymore (as of Go 1.4). +// The two representations should be equivalent, but Postgres distinguishes +// quoted vs non-quoted empty string during database imports, and it has +// an option to force the quoted behavior for non-quoted CSV but it has +// no option to force the non-quoted behavior for quoted CSV, making +// CSV with quoted empty strings strictly less useful. +// Not quoting the empty string also makes this package match the behavior +// of Microsoft Excel and Google Drive. +// For Postgres, quote the data termating string `\.`. +func (w *Writer) fieldNeedsQuotes(field string) bool { + if field == "" { + return false + } + if field == `\.` || strings.IndexRune(field, w.Comma) >= 0 || strings.IndexAny(field, "\"\r\n") >= 0 { + return true + } + + r1, _ := utf8.DecodeRuneInString(field) + return unicode.IsSpace(r1) +} diff --git a/src/encoding/csv/writer_test.go b/src/encoding/csv/writer_test.go new file mode 100644 index 000000000..8ddca0abe --- /dev/null +++ b/src/encoding/csv/writer_test.go @@ -0,0 +1,85 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package csv + +import ( + "bytes" + "errors" + "testing" +) + +var writeTests = []struct { + Input [][]string + Output string + UseCRLF bool +}{ + {Input: [][]string{{"abc"}}, Output: "abc\n"}, + {Input: [][]string{{"abc"}}, Output: "abc\r\n", UseCRLF: true}, + {Input: [][]string{{`"abc"`}}, Output: `"""abc"""` + "\n"}, + {Input: [][]string{{`a"b`}}, Output: `"a""b"` + "\n"}, + {Input: [][]string{{`"a"b"`}}, Output: `"""a""b"""` + "\n"}, + {Input: [][]string{{" abc"}}, Output: `" abc"` + "\n"}, + {Input: [][]string{{"abc,def"}}, Output: `"abc,def"` + "\n"}, + {Input: [][]string{{"abc", "def"}}, Output: "abc,def\n"}, + {Input: [][]string{{"abc"}, {"def"}}, Output: "abc\ndef\n"}, + {Input: [][]string{{"abc\ndef"}}, Output: "\"abc\ndef\"\n"}, + {Input: [][]string{{"abc\ndef"}}, Output: "\"abc\r\ndef\"\r\n", UseCRLF: true}, + {Input: [][]string{{"abc\rdef"}}, Output: "\"abcdef\"\r\n", UseCRLF: true}, + {Input: [][]string{{"abc\rdef"}}, Output: "\"abc\rdef\"\n", UseCRLF: false}, + {Input: [][]string{{""}}, Output: "\n"}, + {Input: [][]string{{"", ""}}, Output: ",\n"}, + {Input: [][]string{{"", "", ""}}, Output: ",,\n"}, + {Input: [][]string{{"", "", "a"}}, Output: ",,a\n"}, + {Input: [][]string{{"", "a", ""}}, Output: ",a,\n"}, + {Input: [][]string{{"", "a", "a"}}, Output: ",a,a\n"}, + {Input: [][]string{{"a", "", ""}}, Output: "a,,\n"}, + {Input: [][]string{{"a", "", "a"}}, Output: "a,,a\n"}, + {Input: [][]string{{"a", "a", ""}}, Output: "a,a,\n"}, + {Input: [][]string{{"a", "a", "a"}}, Output: "a,a,a\n"}, + {Input: [][]string{{`\.`}}, Output: "\"\\.\"\n"}, +} + +func TestWrite(t *testing.T) { + for n, tt := range writeTests { + b := &bytes.Buffer{} + f := NewWriter(b) + f.UseCRLF = tt.UseCRLF + err := f.WriteAll(tt.Input) + if err != nil { + t.Errorf("Unexpected error: %s\n", err) + } + out := b.String() + if out != tt.Output { + t.Errorf("#%d: out=%q want %q", n, out, tt.Output) + } + } +} + +type errorWriter struct{} + +func (e errorWriter) Write(b []byte) (int, error) { + return 0, errors.New("Test") +} + +func TestError(t *testing.T) { + b := &bytes.Buffer{} + f := NewWriter(b) + f.Write([]string{"abc"}) + f.Flush() + err := f.Error() + + if err != nil { + t.Errorf("Unexpected error: %s\n", err) + } + + f = NewWriter(errorWriter{}) + f.Write([]string{"abc"}) + f.Flush() + err = f.Error() + + if err == nil { + t.Error("Error should not be nil") + } +} diff --git a/src/encoding/encoding.go b/src/encoding/encoding.go new file mode 100644 index 000000000..6d218071b --- /dev/null +++ b/src/encoding/encoding.go @@ -0,0 +1,48 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package encoding defines interfaces shared by other packages that +// convert data to and from byte-level and textual representations. +// Packages that check for these interfaces include encoding/gob, +// encoding/json, and encoding/xml. As a result, implementing an +// interface once can make a type useful in multiple encodings. +// Standard types that implement these interfaces include time.Time and net.IP. +// The interfaces come in pairs that produce and consume encoded data. +package encoding + +// BinaryMarshaler is the interface implemented by an object that can +// marshal itself into a binary form. +// +// MarshalBinary encodes the receiver into a binary form and returns the result. +type BinaryMarshaler interface { + MarshalBinary() (data []byte, err error) +} + +// BinaryUnmarshaler is the interface implemented by an object that can +// unmarshal a binary representation of itself. +// +// UnmarshalBinary must be able to decode the form generated by MarshalBinary. +// UnmarshalBinary must copy the data if it wishes to retain the data +// after returning. +type BinaryUnmarshaler interface { + UnmarshalBinary(data []byte) error +} + +// TextMarshaler is the interface implemented by an object that can +// marshal itself into a textual form. +// +// MarshalText encodes the receiver into UTF-8-encoded text and returns the result. +type TextMarshaler interface { + MarshalText() (text []byte, err error) +} + +// TextUnmarshaler is the interface implemented by an object that can +// unmarshal a textual representation of itself. +// +// UnmarshalText must be able to decode the form generated by MarshalText. +// UnmarshalText must copy the text if it wishes to retain the text +// after returning. +type TextUnmarshaler interface { + UnmarshalText(text []byte) error +} diff --git a/src/encoding/gob/codec_test.go b/src/encoding/gob/codec_test.go new file mode 100644 index 000000000..56a7298fa --- /dev/null +++ b/src/encoding/gob/codec_test.go @@ -0,0 +1,1475 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "bytes" + "errors" + "flag" + "math" + "math/rand" + "reflect" + "strings" + "testing" + "time" +) + +var doFuzzTests = flag.Bool("gob.fuzz", false, "run the fuzz tests, which are large and very slow") + +// Guarantee encoding format by comparing some encodings to hand-written values +type EncodeT struct { + x uint64 + b []byte +} + +var encodeT = []EncodeT{ + {0x00, []byte{0x00}}, + {0x0F, []byte{0x0F}}, + {0xFF, []byte{0xFF, 0xFF}}, + {0xFFFF, []byte{0xFE, 0xFF, 0xFF}}, + {0xFFFFFF, []byte{0xFD, 0xFF, 0xFF, 0xFF}}, + {0xFFFFFFFF, []byte{0xFC, 0xFF, 0xFF, 0xFF, 0xFF}}, + {0xFFFFFFFFFF, []byte{0xFB, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}}, + {0xFFFFFFFFFFFF, []byte{0xFA, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}}, + {0xFFFFFFFFFFFFFF, []byte{0xF9, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}}, + {0xFFFFFFFFFFFFFFFF, []byte{0xF8, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}}, + {0x1111, []byte{0xFE, 0x11, 0x11}}, + {0x1111111111111111, []byte{0xF8, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}}, + {0x8888888888888888, []byte{0xF8, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88}}, + {1 << 63, []byte{0xF8, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, +} + +// testError is meant to be used as a deferred function to turn a panic(gobError) into a +// plain test.Error call. +func testError(t *testing.T) { + if e := recover(); e != nil { + t.Error(e.(gobError).err) // Will re-panic if not one of our errors, such as a runtime error. + } + return +} + +func newDecBuffer(data []byte) *decBuffer { + return &decBuffer{ + data: data, + } +} + +// Test basic encode/decode routines for unsigned integers +func TestUintCodec(t *testing.T) { + defer testError(t) + b := new(encBuffer) + encState := newEncoderState(b) + for _, tt := range encodeT { + b.Reset() + encState.encodeUint(tt.x) + if !bytes.Equal(tt.b, b.Bytes()) { + t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes()) + } + } + for u := uint64(0); ; u = (u + 1) * 7 { + b.Reset() + encState.encodeUint(u) + decState := newDecodeState(newDecBuffer(b.Bytes())) + v := decState.decodeUint() + if u != v { + t.Errorf("Encode/Decode: sent %#x received %#x", u, v) + } + if u&(1<<63) != 0 { + break + } + } +} + +func verifyInt(i int64, t *testing.T) { + defer testError(t) + var b = new(encBuffer) + encState := newEncoderState(b) + encState.encodeInt(i) + decState := newDecodeState(newDecBuffer(b.Bytes())) + decState.buf = make([]byte, 8) + j := decState.decodeInt() + if i != j { + t.Errorf("Encode/Decode: sent %#x received %#x", uint64(i), uint64(j)) + } +} + +// Test basic encode/decode routines for signed integers +func TestIntCodec(t *testing.T) { + for u := uint64(0); ; u = (u + 1) * 7 { + // Do positive and negative values + i := int64(u) + verifyInt(i, t) + verifyInt(-i, t) + verifyInt(^i, t) + if u&(1<<63) != 0 { + break + } + } + verifyInt(-1<<63, t) // a tricky case +} + +// The result of encoding a true boolean with field number 7 +var boolResult = []byte{0x07, 0x01} + +// The result of encoding a number 17 with field number 7 +var signedResult = []byte{0x07, 2 * 17} +var unsignedResult = []byte{0x07, 17} +var floatResult = []byte{0x07, 0xFE, 0x31, 0x40} + +// The result of encoding a number 17+19i with field number 7 +var complexResult = []byte{0x07, 0xFE, 0x31, 0x40, 0xFE, 0x33, 0x40} + +// The result of encoding "hello" with field number 7 +var bytesResult = []byte{0x07, 0x05, 'h', 'e', 'l', 'l', 'o'} + +func newDecodeState(buf *decBuffer) *decoderState { + d := new(decoderState) + d.b = buf + d.buf = make([]byte, uint64Size) + return d +} + +func newEncoderState(b *encBuffer) *encoderState { + b.Reset() + state := &encoderState{enc: nil, b: b} + state.fieldnum = -1 + return state +} + +// Test instruction execution for encoding. +// Do not run the machine yet; instead do individual instructions crafted by hand. +func TestScalarEncInstructions(t *testing.T) { + var b = new(encBuffer) + + // bool + { + var data bool = true + instr := &encInstr{encBool, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(boolResult, b.Bytes()) { + t.Errorf("bool enc instructions: expected % x got % x", boolResult, b.Bytes()) + } + } + + // int + { + b.Reset() + var data int = 17 + instr := &encInstr{encInt, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(signedResult, b.Bytes()) { + t.Errorf("int enc instructions: expected % x got % x", signedResult, b.Bytes()) + } + } + + // uint + { + b.Reset() + var data uint = 17 + instr := &encInstr{encUint, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(unsignedResult, b.Bytes()) { + t.Errorf("uint enc instructions: expected % x got % x", unsignedResult, b.Bytes()) + } + } + + // int8 + { + b.Reset() + var data int8 = 17 + instr := &encInstr{encInt, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(signedResult, b.Bytes()) { + t.Errorf("int8 enc instructions: expected % x got % x", signedResult, b.Bytes()) + } + } + + // uint8 + { + b.Reset() + var data uint8 = 17 + instr := &encInstr{encUint, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(unsignedResult, b.Bytes()) { + t.Errorf("uint8 enc instructions: expected % x got % x", unsignedResult, b.Bytes()) + } + } + + // int16 + { + b.Reset() + var data int16 = 17 + instr := &encInstr{encInt, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(signedResult, b.Bytes()) { + t.Errorf("int16 enc instructions: expected % x got % x", signedResult, b.Bytes()) + } + } + + // uint16 + { + b.Reset() + var data uint16 = 17 + instr := &encInstr{encUint, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(unsignedResult, b.Bytes()) { + t.Errorf("uint16 enc instructions: expected % x got % x", unsignedResult, b.Bytes()) + } + } + + // int32 + { + b.Reset() + var data int32 = 17 + instr := &encInstr{encInt, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(signedResult, b.Bytes()) { + t.Errorf("int32 enc instructions: expected % x got % x", signedResult, b.Bytes()) + } + } + + // uint32 + { + b.Reset() + var data uint32 = 17 + instr := &encInstr{encUint, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(unsignedResult, b.Bytes()) { + t.Errorf("uint32 enc instructions: expected % x got % x", unsignedResult, b.Bytes()) + } + } + + // int64 + { + b.Reset() + var data int64 = 17 + instr := &encInstr{encInt, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(signedResult, b.Bytes()) { + t.Errorf("int64 enc instructions: expected % x got % x", signedResult, b.Bytes()) + } + } + + // uint64 + { + b.Reset() + var data uint64 = 17 + instr := &encInstr{encUint, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(unsignedResult, b.Bytes()) { + t.Errorf("uint64 enc instructions: expected % x got % x", unsignedResult, b.Bytes()) + } + } + + // float32 + { + b.Reset() + var data float32 = 17 + instr := &encInstr{encFloat, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(floatResult, b.Bytes()) { + t.Errorf("float32 enc instructions: expected % x got % x", floatResult, b.Bytes()) + } + } + + // float64 + { + b.Reset() + var data float64 = 17 + instr := &encInstr{encFloat, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(floatResult, b.Bytes()) { + t.Errorf("float64 enc instructions: expected % x got % x", floatResult, b.Bytes()) + } + } + + // bytes == []uint8 + { + b.Reset() + data := []byte("hello") + instr := &encInstr{encUint8Array, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(bytesResult, b.Bytes()) { + t.Errorf("bytes enc instructions: expected % x got % x", bytesResult, b.Bytes()) + } + } + + // string + { + b.Reset() + var data string = "hello" + instr := &encInstr{encString, 6, nil, 0} + state := newEncoderState(b) + instr.op(instr, state, reflect.ValueOf(data)) + if !bytes.Equal(bytesResult, b.Bytes()) { + t.Errorf("string enc instructions: expected % x got % x", bytesResult, b.Bytes()) + } + } +} + +func execDec(typ string, instr *decInstr, state *decoderState, t *testing.T, value reflect.Value) { + defer testError(t) + v := int(state.decodeUint()) + if v+state.fieldnum != 6 { + t.Fatalf("decoding field number %d, got %d", 6, v+state.fieldnum) + } + instr.op(instr, state, value.Elem()) + state.fieldnum = 6 +} + +func newDecodeStateFromData(data []byte) *decoderState { + b := newDecBuffer(data) + state := newDecodeState(b) + state.fieldnum = -1 + return state +} + +// Test instruction execution for decoding. +// Do not run the machine yet; instead do individual instructions crafted by hand. +func TestScalarDecInstructions(t *testing.T) { + ovfl := errors.New("overflow") + + // bool + { + var data bool + instr := &decInstr{decBool, 6, nil, ovfl} + state := newDecodeStateFromData(boolResult) + execDec("bool", instr, state, t, reflect.ValueOf(&data)) + if data != true { + t.Errorf("bool a = %v not true", data) + } + } + // int + { + var data int + instr := &decInstr{decOpTable[reflect.Int], 6, nil, ovfl} + state := newDecodeStateFromData(signedResult) + execDec("int", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("int a = %v not 17", data) + } + } + + // uint + { + var data uint + instr := &decInstr{decOpTable[reflect.Uint], 6, nil, ovfl} + state := newDecodeStateFromData(unsignedResult) + execDec("uint", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("uint a = %v not 17", data) + } + } + + // int8 + { + var data int8 + instr := &decInstr{decInt8, 6, nil, ovfl} + state := newDecodeStateFromData(signedResult) + execDec("int8", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("int8 a = %v not 17", data) + } + } + + // uint8 + { + var data uint8 + instr := &decInstr{decUint8, 6, nil, ovfl} + state := newDecodeStateFromData(unsignedResult) + execDec("uint8", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("uint8 a = %v not 17", data) + } + } + + // int16 + { + var data int16 + instr := &decInstr{decInt16, 6, nil, ovfl} + state := newDecodeStateFromData(signedResult) + execDec("int16", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("int16 a = %v not 17", data) + } + } + + // uint16 + { + var data uint16 + instr := &decInstr{decUint16, 6, nil, ovfl} + state := newDecodeStateFromData(unsignedResult) + execDec("uint16", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("uint16 a = %v not 17", data) + } + } + + // int32 + { + var data int32 + instr := &decInstr{decInt32, 6, nil, ovfl} + state := newDecodeStateFromData(signedResult) + execDec("int32", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("int32 a = %v not 17", data) + } + } + + // uint32 + { + var data uint32 + instr := &decInstr{decUint32, 6, nil, ovfl} + state := newDecodeStateFromData(unsignedResult) + execDec("uint32", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("uint32 a = %v not 17", data) + } + } + + // uintptr + { + var data uintptr + instr := &decInstr{decOpTable[reflect.Uintptr], 6, nil, ovfl} + state := newDecodeStateFromData(unsignedResult) + execDec("uintptr", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("uintptr a = %v not 17", data) + } + } + + // int64 + { + var data int64 + instr := &decInstr{decInt64, 6, nil, ovfl} + state := newDecodeStateFromData(signedResult) + execDec("int64", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("int64 a = %v not 17", data) + } + } + + // uint64 + { + var data uint64 + instr := &decInstr{decUint64, 6, nil, ovfl} + state := newDecodeStateFromData(unsignedResult) + execDec("uint64", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("uint64 a = %v not 17", data) + } + } + + // float32 + { + var data float32 + instr := &decInstr{decFloat32, 6, nil, ovfl} + state := newDecodeStateFromData(floatResult) + execDec("float32", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("float32 a = %v not 17", data) + } + } + + // float64 + { + var data float64 + instr := &decInstr{decFloat64, 6, nil, ovfl} + state := newDecodeStateFromData(floatResult) + execDec("float64", instr, state, t, reflect.ValueOf(&data)) + if data != 17 { + t.Errorf("float64 a = %v not 17", data) + } + } + + // complex64 + { + var data complex64 + instr := &decInstr{decOpTable[reflect.Complex64], 6, nil, ovfl} + state := newDecodeStateFromData(complexResult) + execDec("complex", instr, state, t, reflect.ValueOf(&data)) + if data != 17+19i { + t.Errorf("complex a = %v not 17+19i", data) + } + } + + // complex128 + { + var data complex128 + instr := &decInstr{decOpTable[reflect.Complex128], 6, nil, ovfl} + state := newDecodeStateFromData(complexResult) + execDec("complex", instr, state, t, reflect.ValueOf(&data)) + if data != 17+19i { + t.Errorf("complex a = %v not 17+19i", data) + } + } + + // bytes == []uint8 + { + var data []byte + instr := &decInstr{decUint8Slice, 6, nil, ovfl} + state := newDecodeStateFromData(bytesResult) + execDec("bytes", instr, state, t, reflect.ValueOf(&data)) + if string(data) != "hello" { + t.Errorf(`bytes a = %q not "hello"`, string(data)) + } + } + + // string + { + var data string + instr := &decInstr{decString, 6, nil, ovfl} + state := newDecodeStateFromData(bytesResult) + execDec("bytes", instr, state, t, reflect.ValueOf(&data)) + if data != "hello" { + t.Errorf(`bytes a = %q not "hello"`, data) + } + } +} + +func TestEndToEnd(t *testing.T) { + type T2 struct { + T string + } + s1 := "string1" + s2 := "string2" + type T1 struct { + A, B, C int + M map[string]*float64 + EmptyMap map[string]int // to check that we receive a non-nil map. + N *[3]float64 + Strs *[2]string + Int64s *[]int64 + RI complex64 + S string + Y []byte + T *T2 + } + pi := 3.14159 + e := 2.71828 + t1 := &T1{ + A: 17, + B: 18, + C: -5, + M: map[string]*float64{"pi": &pi, "e": &e}, + EmptyMap: make(map[string]int), + N: &[3]float64{1.5, 2.5, 3.5}, + Strs: &[2]string{s1, s2}, + Int64s: &[]int64{77, 89, 123412342134}, + RI: 17 - 23i, + S: "Now is the time", + Y: []byte("hello, sailor"), + T: &T2{"this is T2"}, + } + b := new(bytes.Buffer) + err := NewEncoder(b).Encode(t1) + if err != nil { + t.Error("encode:", err) + } + var _t1 T1 + err = NewDecoder(b).Decode(&_t1) + if err != nil { + t.Fatal("decode:", err) + } + if !reflect.DeepEqual(t1, &_t1) { + t.Errorf("encode expected %v got %v", *t1, _t1) + } + // Be absolutely sure the received map is non-nil. + if t1.EmptyMap == nil { + t.Errorf("nil map sent") + } + if _t1.EmptyMap == nil { + t.Errorf("nil map received") + } +} + +func TestOverflow(t *testing.T) { + type inputT struct { + Maxi int64 + Mini int64 + Maxu uint64 + Maxf float64 + Minf float64 + Maxc complex128 + Minc complex128 + } + var it inputT + var err error + b := new(bytes.Buffer) + enc := NewEncoder(b) + dec := NewDecoder(b) + + // int8 + b.Reset() + it = inputT{ + Maxi: math.MaxInt8 + 1, + } + type outi8 struct { + Maxi int8 + Mini int8 + } + var o1 outi8 + enc.Encode(it) + err = dec.Decode(&o1) + if err == nil || err.Error() != `value for "Maxi" out of range` { + t.Error("wrong overflow error for int8:", err) + } + it = inputT{ + Mini: math.MinInt8 - 1, + } + b.Reset() + enc.Encode(it) + err = dec.Decode(&o1) + if err == nil || err.Error() != `value for "Mini" out of range` { + t.Error("wrong underflow error for int8:", err) + } + + // int16 + b.Reset() + it = inputT{ + Maxi: math.MaxInt16 + 1, + } + type outi16 struct { + Maxi int16 + Mini int16 + } + var o2 outi16 + enc.Encode(it) + err = dec.Decode(&o2) + if err == nil || err.Error() != `value for "Maxi" out of range` { + t.Error("wrong overflow error for int16:", err) + } + it = inputT{ + Mini: math.MinInt16 - 1, + } + b.Reset() + enc.Encode(it) + err = dec.Decode(&o2) + if err == nil || err.Error() != `value for "Mini" out of range` { + t.Error("wrong underflow error for int16:", err) + } + + // int32 + b.Reset() + it = inputT{ + Maxi: math.MaxInt32 + 1, + } + type outi32 struct { + Maxi int32 + Mini int32 + } + var o3 outi32 + enc.Encode(it) + err = dec.Decode(&o3) + if err == nil || err.Error() != `value for "Maxi" out of range` { + t.Error("wrong overflow error for int32:", err) + } + it = inputT{ + Mini: math.MinInt32 - 1, + } + b.Reset() + enc.Encode(it) + err = dec.Decode(&o3) + if err == nil || err.Error() != `value for "Mini" out of range` { + t.Error("wrong underflow error for int32:", err) + } + + // uint8 + b.Reset() + it = inputT{ + Maxu: math.MaxUint8 + 1, + } + type outu8 struct { + Maxu uint8 + } + var o4 outu8 + enc.Encode(it) + err = dec.Decode(&o4) + if err == nil || err.Error() != `value for "Maxu" out of range` { + t.Error("wrong overflow error for uint8:", err) + } + + // uint16 + b.Reset() + it = inputT{ + Maxu: math.MaxUint16 + 1, + } + type outu16 struct { + Maxu uint16 + } + var o5 outu16 + enc.Encode(it) + err = dec.Decode(&o5) + if err == nil || err.Error() != `value for "Maxu" out of range` { + t.Error("wrong overflow error for uint16:", err) + } + + // uint32 + b.Reset() + it = inputT{ + Maxu: math.MaxUint32 + 1, + } + type outu32 struct { + Maxu uint32 + } + var o6 outu32 + enc.Encode(it) + err = dec.Decode(&o6) + if err == nil || err.Error() != `value for "Maxu" out of range` { + t.Error("wrong overflow error for uint32:", err) + } + + // float32 + b.Reset() + it = inputT{ + Maxf: math.MaxFloat32 * 2, + } + type outf32 struct { + Maxf float32 + Minf float32 + } + var o7 outf32 + enc.Encode(it) + err = dec.Decode(&o7) + if err == nil || err.Error() != `value for "Maxf" out of range` { + t.Error("wrong overflow error for float32:", err) + } + + // complex64 + b.Reset() + it = inputT{ + Maxc: complex(math.MaxFloat32*2, math.MaxFloat32*2), + } + type outc64 struct { + Maxc complex64 + Minc complex64 + } + var o8 outc64 + enc.Encode(it) + err = dec.Decode(&o8) + if err == nil || err.Error() != `value for "Maxc" out of range` { + t.Error("wrong overflow error for complex64:", err) + } +} + +func TestNesting(t *testing.T) { + type RT struct { + A string + Next *RT + } + rt := new(RT) + rt.A = "level1" + rt.Next = new(RT) + rt.Next.A = "level2" + b := new(bytes.Buffer) + NewEncoder(b).Encode(rt) + var drt RT + dec := NewDecoder(b) + err := dec.Decode(&drt) + if err != nil { + t.Fatal("decoder error:", err) + } + if drt.A != rt.A { + t.Errorf("nesting: encode expected %v got %v", *rt, drt) + } + if drt.Next == nil { + t.Errorf("nesting: recursion failed") + } + if drt.Next.A != rt.Next.A { + t.Errorf("nesting: encode expected %v got %v", *rt.Next, *drt.Next) + } +} + +// These three structures have the same data with different indirections +type T0 struct { + A int + B int + C int + D int +} +type T1 struct { + A int + B *int + C **int + D ***int +} +type T2 struct { + A ***int + B **int + C *int + D int +} + +func TestAutoIndirection(t *testing.T) { + // First transfer t1 into t0 + var t1 T1 + t1.A = 17 + t1.B = new(int) + *t1.B = 177 + t1.C = new(*int) + *t1.C = new(int) + **t1.C = 1777 + t1.D = new(**int) + *t1.D = new(*int) + **t1.D = new(int) + ***t1.D = 17777 + b := new(bytes.Buffer) + enc := NewEncoder(b) + enc.Encode(t1) + dec := NewDecoder(b) + var t0 T0 + dec.Decode(&t0) + if t0.A != 17 || t0.B != 177 || t0.C != 1777 || t0.D != 17777 { + t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0) + } + + // Now transfer t2 into t0 + var t2 T2 + t2.D = 17777 + t2.C = new(int) + *t2.C = 1777 + t2.B = new(*int) + *t2.B = new(int) + **t2.B = 177 + t2.A = new(**int) + *t2.A = new(*int) + **t2.A = new(int) + ***t2.A = 17 + b.Reset() + enc.Encode(t2) + t0 = T0{} + dec.Decode(&t0) + if t0.A != 17 || t0.B != 177 || t0.C != 1777 || t0.D != 17777 { + t.Errorf("t2->t0 expected {17 177 1777 17777}; got %v", t0) + } + + // Now transfer t0 into t1 + t0 = T0{17, 177, 1777, 17777} + b.Reset() + enc.Encode(t0) + t1 = T1{} + dec.Decode(&t1) + if t1.A != 17 || *t1.B != 177 || **t1.C != 1777 || ***t1.D != 17777 { + t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.A, *t1.B, **t1.C, ***t1.D) + } + + // Now transfer t0 into t2 + b.Reset() + enc.Encode(t0) + t2 = T2{} + dec.Decode(&t2) + if ***t2.A != 17 || **t2.B != 177 || *t2.C != 1777 || t2.D != 17777 { + t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.A, **t2.B, *t2.C, t2.D) + } + + // Now do t2 again but without pre-allocated pointers. + b.Reset() + enc.Encode(t0) + ***t2.A = 0 + **t2.B = 0 + *t2.C = 0 + t2.D = 0 + dec.Decode(&t2) + if ***t2.A != 17 || **t2.B != 177 || *t2.C != 1777 || t2.D != 17777 { + t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.A, **t2.B, *t2.C, t2.D) + } +} + +type RT0 struct { + A int + B string + C float64 +} +type RT1 struct { + C float64 + B string + A int + NotSet string +} + +func TestReorderedFields(t *testing.T) { + var rt0 RT0 + rt0.A = 17 + rt0.B = "hello" + rt0.C = 3.14159 + b := new(bytes.Buffer) + NewEncoder(b).Encode(rt0) + dec := NewDecoder(b) + var rt1 RT1 + // Wire type is RT0, local type is RT1. + err := dec.Decode(&rt1) + if err != nil { + t.Fatal("decode error:", err) + } + if rt0.A != rt1.A || rt0.B != rt1.B || rt0.C != rt1.C { + t.Errorf("rt1->rt0: expected %v; got %v", rt0, rt1) + } +} + +// Like an RT0 but with fields we'll ignore on the decode side. +type IT0 struct { + A int64 + B string + Ignore_d []int + Ignore_e [3]float64 + Ignore_f bool + Ignore_g string + Ignore_h []byte + Ignore_i *RT1 + Ignore_m map[string]int + C float64 +} + +func TestIgnoredFields(t *testing.T) { + var it0 IT0 + it0.A = 17 + it0.B = "hello" + it0.C = 3.14159 + it0.Ignore_d = []int{1, 2, 3} + it0.Ignore_e[0] = 1.0 + it0.Ignore_e[1] = 2.0 + it0.Ignore_e[2] = 3.0 + it0.Ignore_f = true + it0.Ignore_g = "pay no attention" + it0.Ignore_h = []byte("to the curtain") + it0.Ignore_i = &RT1{3.1, "hi", 7, "hello"} + it0.Ignore_m = map[string]int{"one": 1, "two": 2} + + b := new(bytes.Buffer) + NewEncoder(b).Encode(it0) + dec := NewDecoder(b) + var rt1 RT1 + // Wire type is IT0, local type is RT1. + err := dec.Decode(&rt1) + if err != nil { + t.Error("error: ", err) + } + if int(it0.A) != rt1.A || it0.B != rt1.B || it0.C != rt1.C { + t.Errorf("rt0->rt1: expected %v; got %v", it0, rt1) + } +} + +func TestBadRecursiveType(t *testing.T) { + type Rec ***Rec + var rec Rec + b := new(bytes.Buffer) + err := NewEncoder(b).Encode(&rec) + if err == nil { + t.Error("expected error; got none") + } else if strings.Index(err.Error(), "recursive") < 0 { + t.Error("expected recursive type error; got", err) + } + // Can't test decode easily because we can't encode one, so we can't pass one to a Decoder. +} + +type Indirect struct { + A ***[3]int + S ***[]int + M ****map[string]int +} + +type Direct struct { + A [3]int + S []int + M map[string]int +} + +func TestIndirectSliceMapArray(t *testing.T) { + // Marshal indirect, unmarshal to direct. + i := new(Indirect) + i.A = new(**[3]int) + *i.A = new(*[3]int) + **i.A = new([3]int) + ***i.A = [3]int{1, 2, 3} + i.S = new(**[]int) + *i.S = new(*[]int) + **i.S = new([]int) + ***i.S = []int{4, 5, 6} + i.M = new(***map[string]int) + *i.M = new(**map[string]int) + **i.M = new(*map[string]int) + ***i.M = new(map[string]int) + ****i.M = map[string]int{"one": 1, "two": 2, "three": 3} + b := new(bytes.Buffer) + NewEncoder(b).Encode(i) + dec := NewDecoder(b) + var d Direct + err := dec.Decode(&d) + if err != nil { + t.Error("error: ", err) + } + if len(d.A) != 3 || d.A[0] != 1 || d.A[1] != 2 || d.A[2] != 3 { + t.Errorf("indirect to direct: d.A is %v not %v", d.A, ***i.A) + } + if len(d.S) != 3 || d.S[0] != 4 || d.S[1] != 5 || d.S[2] != 6 { + t.Errorf("indirect to direct: d.S is %v not %v", d.S, ***i.S) + } + if len(d.M) != 3 || d.M["one"] != 1 || d.M["two"] != 2 || d.M["three"] != 3 { + t.Errorf("indirect to direct: d.M is %v not %v", d.M, ***i.M) + } + // Marshal direct, unmarshal to indirect. + d.A = [3]int{11, 22, 33} + d.S = []int{44, 55, 66} + d.M = map[string]int{"four": 4, "five": 5, "six": 6} + i = new(Indirect) + b.Reset() + NewEncoder(b).Encode(d) + dec = NewDecoder(b) + err = dec.Decode(&i) + if err != nil { + t.Fatal("error: ", err) + } + if len(***i.A) != 3 || (***i.A)[0] != 11 || (***i.A)[1] != 22 || (***i.A)[2] != 33 { + t.Errorf("direct to indirect: ***i.A is %v not %v", ***i.A, d.A) + } + if len(***i.S) != 3 || (***i.S)[0] != 44 || (***i.S)[1] != 55 || (***i.S)[2] != 66 { + t.Errorf("direct to indirect: ***i.S is %v not %v", ***i.S, ***i.S) + } + if len(****i.M) != 3 || (****i.M)["four"] != 4 || (****i.M)["five"] != 5 || (****i.M)["six"] != 6 { + t.Errorf("direct to indirect: ****i.M is %v not %v", ****i.M, d.M) + } +} + +// An interface with several implementations +type Squarer interface { + Square() int +} + +type Int int + +func (i Int) Square() int { + return int(i * i) +} + +type Float float64 + +func (f Float) Square() int { + return int(f * f) +} + +type Vector []int + +func (v Vector) Square() int { + sum := 0 + for _, x := range v { + sum += x * x + } + return sum +} + +type Point struct { + X, Y int +} + +func (p Point) Square() int { + return p.X*p.X + p.Y*p.Y +} + +// A struct with interfaces in it. +type InterfaceItem struct { + I int + Sq1, Sq2, Sq3 Squarer + F float64 + Sq []Squarer +} + +// The same struct without interfaces +type NoInterfaceItem struct { + I int + F float64 +} + +func TestInterface(t *testing.T) { + iVal := Int(3) + fVal := Float(5) + // Sending a Vector will require that the receiver define a type in the middle of + // receiving the value for item2. + vVal := Vector{1, 2, 3} + b := new(bytes.Buffer) + item1 := &InterfaceItem{1, iVal, fVal, vVal, 11.5, []Squarer{iVal, fVal, nil, vVal}} + // Register the types. + Register(Int(0)) + Register(Float(0)) + Register(Vector{}) + err := NewEncoder(b).Encode(item1) + if err != nil { + t.Error("expected no encode error; got", err) + } + + item2 := InterfaceItem{} + err = NewDecoder(b).Decode(&item2) + if err != nil { + t.Fatal("decode:", err) + } + if item2.I != item1.I { + t.Error("normal int did not decode correctly") + } + if item2.Sq1 == nil || item2.Sq1.Square() != iVal.Square() { + t.Error("Int did not decode correctly") + } + if item2.Sq2 == nil || item2.Sq2.Square() != fVal.Square() { + t.Error("Float did not decode correctly") + } + if item2.Sq3 == nil || item2.Sq3.Square() != vVal.Square() { + t.Error("Vector did not decode correctly") + } + if item2.F != item1.F { + t.Error("normal float did not decode correctly") + } + // Now check that we received a slice of Squarers correctly, including a nil element + if len(item1.Sq) != len(item2.Sq) { + t.Fatalf("[]Squarer length wrong: got %d; expected %d", len(item2.Sq), len(item1.Sq)) + } + for i, v1 := range item1.Sq { + v2 := item2.Sq[i] + if v1 == nil || v2 == nil { + if v1 != nil || v2 != nil { + t.Errorf("item %d inconsistent nils", i) + } + } else if v1.Square() != v2.Square() { + t.Errorf("item %d inconsistent values: %v %v", i, v1, v2) + } + } +} + +// A struct with all basic types, stored in interfaces. +type BasicInterfaceItem struct { + Int, Int8, Int16, Int32, Int64 interface{} + Uint, Uint8, Uint16, Uint32, Uint64 interface{} + Float32, Float64 interface{} + Complex64, Complex128 interface{} + Bool interface{} + String interface{} + Bytes interface{} +} + +func TestInterfaceBasic(t *testing.T) { + b := new(bytes.Buffer) + item1 := &BasicInterfaceItem{ + int(1), int8(1), int16(1), int32(1), int64(1), + uint(1), uint8(1), uint16(1), uint32(1), uint64(1), + float32(1), 1.0, + complex64(1i), complex128(1i), + true, + "hello", + []byte("sailor"), + } + err := NewEncoder(b).Encode(item1) + if err != nil { + t.Error("expected no encode error; got", err) + } + + item2 := &BasicInterfaceItem{} + err = NewDecoder(b).Decode(&item2) + if err != nil { + t.Fatal("decode:", err) + } + if !reflect.DeepEqual(item1, item2) { + t.Errorf("encode expected %v got %v", item1, item2) + } + // Hand check a couple for correct types. + if v, ok := item2.Bool.(bool); !ok || !v { + t.Error("boolean should be true") + } + if v, ok := item2.String.(string); !ok || v != item1.String.(string) { + t.Errorf("string should be %v is %v", item1.String, v) + } +} + +type String string + +type PtrInterfaceItem struct { + Str1 interface{} // basic + Str2 interface{} // derived +} + +// We'll send pointers; should receive values. +// Also check that we can register T but send *T. +func TestInterfacePointer(t *testing.T) { + b := new(bytes.Buffer) + str1 := "howdy" + str2 := String("kiddo") + item1 := &PtrInterfaceItem{ + &str1, + &str2, + } + // Register the type. + Register(str2) + err := NewEncoder(b).Encode(item1) + if err != nil { + t.Error("expected no encode error; got", err) + } + + item2 := &PtrInterfaceItem{} + err = NewDecoder(b).Decode(&item2) + if err != nil { + t.Fatal("decode:", err) + } + // Hand test for correct types and values. + if v, ok := item2.Str1.(string); !ok || v != str1 { + t.Errorf("basic string failed: %q should be %q", v, str1) + } + if v, ok := item2.Str2.(String); !ok || v != str2 { + t.Errorf("derived type String failed: %q should be %q", v, str2) + } +} + +func TestIgnoreInterface(t *testing.T) { + iVal := Int(3) + fVal := Float(5) + // Sending a Point will require that the receiver define a type in the middle of + // receiving the value for item2. + pVal := Point{2, 3} + b := new(bytes.Buffer) + item1 := &InterfaceItem{1, iVal, fVal, pVal, 11.5, nil} + // Register the types. + Register(Int(0)) + Register(Float(0)) + Register(Point{}) + err := NewEncoder(b).Encode(item1) + if err != nil { + t.Error("expected no encode error; got", err) + } + + item2 := NoInterfaceItem{} + err = NewDecoder(b).Decode(&item2) + if err != nil { + t.Fatal("decode:", err) + } + if item2.I != item1.I { + t.Error("normal int did not decode correctly") + } + if item2.F != item2.F { + t.Error("normal float did not decode correctly") + } +} + +type U struct { + A int + B string + c float64 + D uint +} + +func TestUnexportedFields(t *testing.T) { + var u0 U + u0.A = 17 + u0.B = "hello" + u0.c = 3.14159 + u0.D = 23 + b := new(bytes.Buffer) + NewEncoder(b).Encode(u0) + dec := NewDecoder(b) + var u1 U + u1.c = 1234. + err := dec.Decode(&u1) + if err != nil { + t.Fatal("decode error:", err) + } + if u0.A != u0.A || u0.B != u1.B || u0.D != u1.D { + t.Errorf("u1->u0: expected %v; got %v", u0, u1) + } + if u1.c != 1234. { + t.Error("u1.c modified") + } +} + +var singletons = []interface{}{ + true, + 7, + 3.2, + "hello", + [3]int{11, 22, 33}, + []float32{0.5, 0.25, 0.125}, + map[string]int{"one": 1, "two": 2}, +} + +func TestDebugSingleton(t *testing.T) { + if debugFunc == nil { + return + } + b := new(bytes.Buffer) + // Accumulate a number of values and print them out all at once. + for _, x := range singletons { + err := NewEncoder(b).Encode(x) + if err != nil { + t.Fatal("encode:", err) + } + } + debugFunc(b) +} + +// A type that won't be defined in the gob until we send it in an interface value. +type OnTheFly struct { + A int +} + +type DT struct { + // X OnTheFly + A int + B string + C float64 + I interface{} + J interface{} + I_nil interface{} + M map[string]int + T [3]int + S []string +} + +func newDT() DT { + var dt DT + dt.A = 17 + dt.B = "hello" + dt.C = 3.14159 + dt.I = 271828 + dt.J = OnTheFly{3} + dt.I_nil = nil + dt.M = map[string]int{"one": 1, "two": 2} + dt.T = [3]int{11, 22, 33} + dt.S = []string{"hi", "joe"} + return dt +} + +func TestDebugStruct(t *testing.T) { + if debugFunc == nil { + return + } + Register(OnTheFly{}) + dt := newDT() + b := new(bytes.Buffer) + err := NewEncoder(b).Encode(dt) + if err != nil { + t.Fatal("encode:", err) + } + debugBuffer := bytes.NewBuffer(b.Bytes()) + dt2 := &DT{} + err = NewDecoder(b).Decode(&dt2) + if err != nil { + t.Error("decode:", err) + } + debugFunc(debugBuffer) +} + +func encFuzzDec(rng *rand.Rand, in interface{}) error { + buf := new(bytes.Buffer) + enc := NewEncoder(buf) + if err := enc.Encode(&in); err != nil { + return err + } + + b := buf.Bytes() + for i, bi := range b { + if rng.Intn(10) < 3 { + b[i] = bi + uint8(rng.Intn(256)) + } + } + + dec := NewDecoder(buf) + var e interface{} + if err := dec.Decode(&e); err != nil { + return err + } + return nil +} + +// This does some "fuzz testing" by attempting to decode a sequence of random bytes. +func TestFuzz(t *testing.T) { + if !*doFuzzTests { + t.Logf("disabled; run with -gob.fuzz to enable") + return + } + + // all possible inputs + input := []interface{}{ + new(int), + new(float32), + new(float64), + new(complex128), + &ByteStruct{255}, + &ArrayStruct{}, + &StringStruct{"hello"}, + &GobTest1{0, &StringStruct{"hello"}}, + } + testFuzz(t, time.Now().UnixNano(), 100, input...) +} + +func TestFuzzRegressions(t *testing.T) { + if !*doFuzzTests { + t.Logf("disabled; run with -gob.fuzz to enable") + return + } + + // An instance triggering a type name of length ~102 GB. + testFuzz(t, 1328492090837718000, 100, new(float32)) + // An instance triggering a type name of 1.6 GB. + // Note: can take several minutes to run. + testFuzz(t, 1330522872628565000, 100, new(int)) +} + +func testFuzz(t *testing.T, seed int64, n int, input ...interface{}) { + for _, e := range input { + t.Logf("seed=%d n=%d e=%T", seed, n, e) + rng := rand.New(rand.NewSource(seed)) + for i := 0; i < n; i++ { + encFuzzDec(rng, e) + } + } +} + +// TestFuzzOneByte tries to decode corrupted input sequences +// and checks that no panic occurs. +func TestFuzzOneByte(t *testing.T) { + buf := new(bytes.Buffer) + Register(OnTheFly{}) + dt := newDT() + if err := NewEncoder(buf).Encode(dt); err != nil { + t.Fatal(err) + } + s := buf.String() + + indices := make([]int, 0, len(s)) + for i := 0; i < len(s); i++ { + switch i { + case 14, 167, 231, 265: // a slice length, corruptions are not handled yet. + continue + } + indices = append(indices, i) + } + if testing.Short() { + indices = []int{1, 111, 178} // known fixed panics + } + for _, i := range indices { + for j := 0; j < 256; j += 3 { + b := []byte(s) + b[i] ^= byte(j) + var e DT + func() { + defer func() { + if p := recover(); p != nil { + t.Errorf("crash for b[%d] ^= 0x%x", i, j) + panic(p) + } + }() + err := NewDecoder(bytes.NewReader(b)).Decode(&e) + _ = err + }() + } + } +} diff --git a/src/encoding/gob/debug.go b/src/encoding/gob/debug.go new file mode 100644 index 000000000..536bbdb5a --- /dev/null +++ b/src/encoding/gob/debug.go @@ -0,0 +1,705 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Delete the next line to include in the gob package. +// +build ignore + +package gob + +// This file is not normally included in the gob package. Used only for debugging the package itself. +// Except for reading uints, it is an implementation of a reader that is independent of +// the one implemented by Decoder. +// To enable the Debug function, delete the +build ignore line above and do +// go install + +import ( + "bytes" + "fmt" + "io" + "os" + "strings" + "sync" +) + +var dumpBytes = false // If true, print the remaining bytes in the input buffer at each item. + +// Init installs the debugging facility. If this file is not compiled in the +// package, the tests in codec_test.go are no-ops. +func init() { + debugFunc = Debug +} + +var ( + blanks = bytes.Repeat([]byte{' '}, 3*10) + empty = []byte(": <empty>\n") + tabs = strings.Repeat("\t", 100) +) + +// tab indents itself when printed. +type tab int + +func (t tab) String() string { + n := int(t) + if n > len(tabs) { + n = len(tabs) + } + return tabs[0:n] +} + +func (t tab) print() { + fmt.Fprint(os.Stderr, t) +} + +// A peekReader wraps an io.Reader, allowing one to peek ahead to see +// what's coming without stealing the data from the client of the Reader. +type peekReader struct { + r io.Reader + data []byte // read-ahead data +} + +// newPeekReader returns a peekReader that wraps r. +func newPeekReader(r io.Reader) *peekReader { + return &peekReader{r: r} +} + +// Read is the usual method. It will first take data that has been read ahead. +func (p *peekReader) Read(b []byte) (n int, err error) { + if len(p.data) == 0 { + return p.r.Read(b) + } + // Satisfy what's possible from the read-ahead data. + n = copy(b, p.data) + // Move data down to beginning of slice, to avoid endless growth + copy(p.data, p.data[n:]) + p.data = p.data[:len(p.data)-n] + return +} + +// peek returns as many bytes as possible from the unread +// portion of the stream, up to the length of b. +func (p *peekReader) peek(b []byte) (n int, err error) { + if len(p.data) > 0 { + n = copy(b, p.data) + if n == len(b) { + return + } + b = b[n:] + } + if len(b) == 0 { + return + } + m, e := io.ReadFull(p.r, b) + if m > 0 { + p.data = append(p.data, b[:m]...) + } + n += m + if e == io.ErrUnexpectedEOF { + // That means m > 0 but we reached EOF. If we got data + // we won't complain about not being able to peek enough. + if n > 0 { + e = nil + } else { + e = io.EOF + } + } + return n, e +} + +type debugger struct { + mutex sync.Mutex + remain int // the number of bytes known to remain in the input + remainingKnown bool // the value of 'remain' is valid + r *peekReader + wireType map[typeId]*wireType + tmp []byte // scratch space for decoding uints. +} + +// dump prints the next nBytes of the input. +// It arranges to print the output aligned from call to +// call, to make it easy to see what has been consumed. +func (deb *debugger) dump(format string, args ...interface{}) { + if !dumpBytes { + return + } + fmt.Fprintf(os.Stderr, format+" ", args...) + if !deb.remainingKnown { + return + } + if deb.remain < 0 { + fmt.Fprintf(os.Stderr, "remaining byte count is negative! %d\n", deb.remain) + return + } + data := make([]byte, deb.remain) + n, _ := deb.r.peek(data) + if n == 0 { + os.Stderr.Write(empty) + return + } + b := new(bytes.Buffer) + fmt.Fprintf(b, "[%d]{\n", deb.remain) + // Blanks until first byte + lineLength := 0 + if n := len(data); n%10 != 0 { + lineLength = 10 - n%10 + fmt.Fprintf(b, "\t%s", blanks[:lineLength*3]) + } + // 10 bytes per line + for len(data) > 0 { + if lineLength == 0 { + fmt.Fprint(b, "\t") + } + m := 10 - lineLength + lineLength = 0 + if m > len(data) { + m = len(data) + } + fmt.Fprintf(b, "% x\n", data[:m]) + data = data[m:] + } + fmt.Fprint(b, "}\n") + os.Stderr.Write(b.Bytes()) +} + +// Debug prints a human-readable representation of the gob data read from r. +// It is a no-op unless debugging was enabled when the package was built. +func Debug(r io.Reader) { + err := debug(r) + if err != nil { + fmt.Fprintf(os.Stderr, "gob debug: %s\n", err) + } +} + +// debug implements Debug, but catches panics and returns +// them as errors to be printed by Debug. +func debug(r io.Reader) (err error) { + defer catchError(&err) + fmt.Fprintln(os.Stderr, "Start of debugging") + deb := &debugger{ + r: newPeekReader(r), + wireType: make(map[typeId]*wireType), + tmp: make([]byte, 16), + } + if b, ok := r.(*bytes.Buffer); ok { + deb.remain = b.Len() + deb.remainingKnown = true + } + deb.gobStream() + return +} + +// note that we've consumed some bytes +func (deb *debugger) consumed(n int) { + if deb.remainingKnown { + deb.remain -= n + } +} + +// int64 decodes and returns the next integer, which must be present. +// Don't call this if you could be at EOF. +func (deb *debugger) int64() int64 { + return toInt(deb.uint64()) +} + +// uint64 returns and decodes the next unsigned integer, which must be present. +// Don't call this if you could be at EOF. +// TODO: handle errors better. +func (deb *debugger) uint64() uint64 { + n, w, err := decodeUintReader(deb.r, deb.tmp) + if err != nil { + errorf("debug: read error: %s", err) + } + deb.consumed(w) + return n +} + +// GobStream: +// DelimitedMessage* (until EOF) +func (deb *debugger) gobStream() { + // Make sure we're single-threaded through here. + deb.mutex.Lock() + defer deb.mutex.Unlock() + + for deb.delimitedMessage(0) { + } +} + +// DelimitedMessage: +// uint(lengthOfMessage) Message +func (deb *debugger) delimitedMessage(indent tab) bool { + for { + n := deb.loadBlock(true) + if n < 0 { + return false + } + deb.dump("Delimited message of length %d", n) + deb.message(indent) + } + return true +} + +// loadBlock preps us to read a message +// of the length specified next in the input. It returns +// the length of the block. The argument tells whether +// an EOF is acceptable now. If it is and one is found, +// the return value is negative. +func (deb *debugger) loadBlock(eofOK bool) int { + n64, w, err := decodeUintReader(deb.r, deb.tmp) // deb.uint64 will error at EOF + if err != nil { + if eofOK && err == io.EOF { + return -1 + } + errorf("debug: unexpected error: %s", err) + } + deb.consumed(w) + n := int(n64) + if n < 0 { + errorf("huge value for message length: %d", n64) + } + return int(n) +} + +// Message: +// TypeSequence TypedValue +// TypeSequence +// (TypeDefinition DelimitedTypeDefinition*)? +// DelimitedTypeDefinition: +// uint(lengthOfTypeDefinition) TypeDefinition +// TypedValue: +// int(typeId) Value +func (deb *debugger) message(indent tab) bool { + for { + // Convert the uint64 to a signed integer typeId + uid := deb.int64() + id := typeId(uid) + deb.dump("type id=%d", id) + if id < 0 { + deb.typeDefinition(indent, -id) + n := deb.loadBlock(false) + deb.dump("Message of length %d", n) + continue + } else { + deb.value(indent, id) + break + } + } + return true +} + +// Helper methods to make it easy to scan a type descriptor. + +// common returns the CommonType at the input point. +func (deb *debugger) common() CommonType { + fieldNum := -1 + name := "" + id := typeId(0) + for { + delta := deb.delta(-1) + if delta == 0 { + break + } + fieldNum += delta + switch fieldNum { + case 0: + name = deb.string() + case 1: + // Id typeId + id = deb.typeId() + default: + errorf("corrupted CommonType, delta is %d fieldNum is %d", delta, fieldNum) + } + } + return CommonType{name, id} +} + +// uint returns the unsigned int at the input point, as a uint (not uint64). +func (deb *debugger) uint() uint { + return uint(deb.uint64()) +} + +// int returns the signed int at the input point, as an int (not int64). +func (deb *debugger) int() int { + return int(deb.int64()) +} + +// typeId returns the type id at the input point. +func (deb *debugger) typeId() typeId { + return typeId(deb.int64()) +} + +// string returns the string at the input point. +func (deb *debugger) string() string { + x := int(deb.uint64()) + b := make([]byte, x) + nb, _ := deb.r.Read(b) + if nb != x { + errorf("corrupted type") + } + deb.consumed(nb) + return string(b) +} + +// delta returns the field delta at the input point. The expect argument, +// if non-negative, identifies what the value should be. +func (deb *debugger) delta(expect int) int { + delta := int(deb.uint64()) + if delta < 0 || (expect >= 0 && delta != expect) { + errorf("decode: corrupted type: delta %d expected %d", delta, expect) + } + return delta +} + +// TypeDefinition: +// [int(-typeId) (already read)] encodingOfWireType +func (deb *debugger) typeDefinition(indent tab, id typeId) { + deb.dump("type definition for id %d", id) + // Encoding is of a wireType. Decode the structure as usual + fieldNum := -1 + wire := new(wireType) + // A wireType defines a single field. + delta := deb.delta(-1) + fieldNum += delta + switch fieldNum { + case 0: // array type, one field of {{Common}, elem, length} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + // Field number 1 is type Id of elem + deb.delta(1) + id := deb.typeId() + // Field number 3 is length + deb.delta(1) + length := deb.int() + wire.ArrayT = &arrayType{com, id, length} + + case 1: // slice type, one field of {{Common}, elem} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + // Field number 1 is type Id of elem + deb.delta(1) + id := deb.typeId() + wire.SliceT = &sliceType{com, id} + + case 2: // struct type, one field of {{Common}, []fieldType} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + // Field number 1 is slice of FieldType + deb.delta(1) + numField := int(deb.uint()) + field := make([]*fieldType, numField) + for i := 0; i < numField; i++ { + field[i] = new(fieldType) + deb.delta(1) // field 0 of fieldType: name + field[i].Name = deb.string() + deb.delta(1) // field 1 of fieldType: id + field[i].Id = deb.typeId() + deb.delta(0) // end of fieldType + } + wire.StructT = &structType{com, field} + + case 3: // map type, one field of {{Common}, key, elem} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + // Field number 1 is type Id of key + deb.delta(1) + keyId := deb.typeId() + // Field number 2 is type Id of elem + deb.delta(1) + elemId := deb.typeId() + wire.MapT = &mapType{com, keyId, elemId} + case 4: // GobEncoder type, one field of {{Common}} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + wire.GobEncoderT = &gobEncoderType{com} + case 5: // BinaryMarshaler type, one field of {{Common}} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + wire.BinaryMarshalerT = &gobEncoderType{com} + case 6: // TextMarshaler type, one field of {{Common}} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + wire.TextMarshalerT = &gobEncoderType{com} + default: + errorf("bad field in type %d", fieldNum) + } + deb.printWireType(indent, wire) + deb.delta(0) // end inner type (arrayType, etc.) + deb.delta(0) // end wireType + // Remember we've seen this type. + deb.wireType[id] = wire +} + +// Value: +// SingletonValue | StructValue +func (deb *debugger) value(indent tab, id typeId) { + wire, ok := deb.wireType[id] + if ok && wire.StructT != nil { + deb.structValue(indent, id) + } else { + deb.singletonValue(indent, id) + } +} + +// SingletonValue: +// uint(0) FieldValue +func (deb *debugger) singletonValue(indent tab, id typeId) { + deb.dump("Singleton value") + // is it a builtin type? + wire := deb.wireType[id] + _, ok := builtinIdToType[id] + if !ok && wire == nil { + errorf("type id %d not defined", id) + } + m := deb.uint64() + if m != 0 { + errorf("expected zero; got %d", m) + } + deb.fieldValue(indent, id) +} + +// InterfaceValue: +// NilInterfaceValue | NonNilInterfaceValue +func (deb *debugger) interfaceValue(indent tab) { + deb.dump("Start of interface value") + if nameLen := deb.uint64(); nameLen == 0 { + deb.nilInterfaceValue(indent) + } else { + deb.nonNilInterfaceValue(indent, int(nameLen)) + } +} + +// NilInterfaceValue: +// uint(0) [already read] +func (deb *debugger) nilInterfaceValue(indent tab) int { + fmt.Fprintf(os.Stderr, "%snil interface\n", indent) + return 0 +} + +// NonNilInterfaceValue: +// ConcreteTypeName TypeSequence InterfaceContents +// ConcreteTypeName: +// uint(lengthOfName) [already read=n] name +// InterfaceContents: +// int(concreteTypeId) DelimitedValue +// DelimitedValue: +// uint(length) Value +func (deb *debugger) nonNilInterfaceValue(indent tab, nameLen int) { + // ConcreteTypeName + b := make([]byte, nameLen) + deb.r.Read(b) // TODO: CHECK THESE READS!! + deb.consumed(nameLen) + name := string(b) + + for { + id := deb.typeId() + if id < 0 { + deb.typeDefinition(indent, -id) + n := deb.loadBlock(false) + deb.dump("Nested message of length %d", n) + } else { + // DelimitedValue + x := deb.uint64() // in case we want to ignore the value; we don't. + fmt.Fprintf(os.Stderr, "%sinterface value, type %q id=%d; valueLength %d\n", indent, name, id, x) + deb.value(indent, id) + break + } + } +} + +// printCommonType prints a common type; used by printWireType. +func (deb *debugger) printCommonType(indent tab, kind string, common *CommonType) { + indent.print() + fmt.Fprintf(os.Stderr, "%s %q id=%d\n", kind, common.Name, common.Id) +} + +// printWireType prints the contents of a wireType. +func (deb *debugger) printWireType(indent tab, wire *wireType) { + fmt.Fprintf(os.Stderr, "%stype definition {\n", indent) + indent++ + switch { + case wire.ArrayT != nil: + deb.printCommonType(indent, "array", &wire.ArrayT.CommonType) + fmt.Fprintf(os.Stderr, "%slen %d\n", indent+1, wire.ArrayT.Len) + fmt.Fprintf(os.Stderr, "%selemid %d\n", indent+1, wire.ArrayT.Elem) + case wire.MapT != nil: + deb.printCommonType(indent, "map", &wire.MapT.CommonType) + fmt.Fprintf(os.Stderr, "%skey id=%d\n", indent+1, wire.MapT.Key) + fmt.Fprintf(os.Stderr, "%selem id=%d\n", indent+1, wire.MapT.Elem) + case wire.SliceT != nil: + deb.printCommonType(indent, "slice", &wire.SliceT.CommonType) + fmt.Fprintf(os.Stderr, "%selem id=%d\n", indent+1, wire.SliceT.Elem) + case wire.StructT != nil: + deb.printCommonType(indent, "struct", &wire.StructT.CommonType) + for i, field := range wire.StructT.Field { + fmt.Fprintf(os.Stderr, "%sfield %d:\t%s\tid=%d\n", indent+1, i, field.Name, field.Id) + } + case wire.GobEncoderT != nil: + deb.printCommonType(indent, "GobEncoder", &wire.GobEncoderT.CommonType) + } + indent-- + fmt.Fprintf(os.Stderr, "%s}\n", indent) +} + +// fieldValue prints a value of any type, such as a struct field. +// FieldValue: +// builtinValue | ArrayValue | MapValue | SliceValue | StructValue | InterfaceValue +func (deb *debugger) fieldValue(indent tab, id typeId) { + _, ok := builtinIdToType[id] + if ok { + if id == tInterface { + deb.interfaceValue(indent) + } else { + deb.printBuiltin(indent, id) + } + return + } + wire, ok := deb.wireType[id] + if !ok { + errorf("type id %d not defined", id) + } + switch { + case wire.ArrayT != nil: + deb.arrayValue(indent, wire) + case wire.MapT != nil: + deb.mapValue(indent, wire) + case wire.SliceT != nil: + deb.sliceValue(indent, wire) + case wire.StructT != nil: + deb.structValue(indent, id) + case wire.GobEncoderT != nil: + deb.gobEncoderValue(indent, id) + default: + panic("bad wire type for field") + } +} + +// printBuiltin prints a value not of a fundamental type, that is, +// one whose type is known to gobs at bootstrap time. +func (deb *debugger) printBuiltin(indent tab, id typeId) { + switch id { + case tBool: + x := deb.int64() + if x == 0 { + fmt.Fprintf(os.Stderr, "%sfalse\n", indent) + } else { + fmt.Fprintf(os.Stderr, "%strue\n", indent) + } + case tInt: + x := deb.int64() + fmt.Fprintf(os.Stderr, "%s%d\n", indent, x) + case tUint: + x := deb.int64() + fmt.Fprintf(os.Stderr, "%s%d\n", indent, x) + case tFloat: + x := deb.uint64() + fmt.Fprintf(os.Stderr, "%s%g\n", indent, float64FromBits(x)) + case tComplex: + r := deb.uint64() + i := deb.uint64() + fmt.Fprintf(os.Stderr, "%s%g+%gi\n", indent, float64FromBits(r), float64FromBits(i)) + case tBytes: + x := int(deb.uint64()) + b := make([]byte, x) + deb.r.Read(b) + deb.consumed(x) + fmt.Fprintf(os.Stderr, "%s{% x}=%q\n", indent, b, b) + case tString: + x := int(deb.uint64()) + b := make([]byte, x) + deb.r.Read(b) + deb.consumed(x) + fmt.Fprintf(os.Stderr, "%s%q\n", indent, b) + default: + panic("unknown builtin") + } +} + +// ArrayValue: +// uint(n) FieldValue*n +func (deb *debugger) arrayValue(indent tab, wire *wireType) { + elemId := wire.ArrayT.Elem + u := deb.uint64() + length := int(u) + for i := 0; i < length; i++ { + deb.fieldValue(indent, elemId) + } + if length != wire.ArrayT.Len { + fmt.Fprintf(os.Stderr, "%s(wrong length for array: %d should be %d)\n", indent, length, wire.ArrayT.Len) + } +} + +// MapValue: +// uint(n) (FieldValue FieldValue)*n [n (key, value) pairs] +func (deb *debugger) mapValue(indent tab, wire *wireType) { + keyId := wire.MapT.Key + elemId := wire.MapT.Elem + u := deb.uint64() + length := int(u) + for i := 0; i < length; i++ { + deb.fieldValue(indent+1, keyId) + deb.fieldValue(indent+1, elemId) + } +} + +// SliceValue: +// uint(n) (n FieldValue) +func (deb *debugger) sliceValue(indent tab, wire *wireType) { + elemId := wire.SliceT.Elem + u := deb.uint64() + length := int(u) + deb.dump("Start of slice of length %d", length) + + for i := 0; i < length; i++ { + deb.fieldValue(indent, elemId) + } +} + +// StructValue: +// (uint(fieldDelta) FieldValue)* +func (deb *debugger) structValue(indent tab, id typeId) { + deb.dump("Start of struct value of %q id=%d\n<<\n", id.name(), id) + fmt.Fprintf(os.Stderr, "%s%s struct {\n", indent, id.name()) + wire, ok := deb.wireType[id] + if !ok { + errorf("type id %d not defined", id) + } + strct := wire.StructT + fieldNum := -1 + indent++ + for { + delta := deb.uint64() + if delta == 0 { // struct terminator is zero delta fieldnum + break + } + fieldNum += int(delta) + if fieldNum < 0 || fieldNum >= len(strct.Field) { + deb.dump("field number out of range: prevField=%d delta=%d", fieldNum-int(delta), delta) + break + } + fmt.Fprintf(os.Stderr, "%sfield %d:\t%s\n", indent, fieldNum, wire.StructT.Field[fieldNum].Name) + deb.fieldValue(indent+1, strct.Field[fieldNum].Id) + } + indent-- + fmt.Fprintf(os.Stderr, "%s} // end %s struct\n", indent, id.name()) + deb.dump(">> End of struct value of type %d %q", id, id.name()) +} + +// GobEncoderValue: +// uint(n) byte*n +func (deb *debugger) gobEncoderValue(indent tab, id typeId) { + len := deb.uint64() + deb.dump("GobEncoder value of %q id=%d, length %d\n", id.name(), id, len) + fmt.Fprintf(os.Stderr, "%s%s (implements GobEncoder)\n", indent, id.name()) + data := make([]byte, len) + _, err := deb.r.Read(data) + if err != nil { + errorf("gobEncoder data read: %s", err) + } + fmt.Fprintf(os.Stderr, "%s[% .2x]\n", indent+1, data) +} diff --git a/src/encoding/gob/dec_helpers.go b/src/encoding/gob/dec_helpers.go new file mode 100644 index 000000000..a1b67661d --- /dev/null +++ b/src/encoding/gob/dec_helpers.go @@ -0,0 +1,468 @@ +// Created by decgen --output dec_helpers.go; DO NOT EDIT + +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "math" + "reflect" +) + +var decArrayHelper = map[reflect.Kind]decHelper{ + reflect.Bool: decBoolArray, + reflect.Complex64: decComplex64Array, + reflect.Complex128: decComplex128Array, + reflect.Float32: decFloat32Array, + reflect.Float64: decFloat64Array, + reflect.Int: decIntArray, + reflect.Int16: decInt16Array, + reflect.Int32: decInt32Array, + reflect.Int64: decInt64Array, + reflect.Int8: decInt8Array, + reflect.String: decStringArray, + reflect.Uint: decUintArray, + reflect.Uint16: decUint16Array, + reflect.Uint32: decUint32Array, + reflect.Uint64: decUint64Array, + reflect.Uintptr: decUintptrArray, +} + +var decSliceHelper = map[reflect.Kind]decHelper{ + reflect.Bool: decBoolSlice, + reflect.Complex64: decComplex64Slice, + reflect.Complex128: decComplex128Slice, + reflect.Float32: decFloat32Slice, + reflect.Float64: decFloat64Slice, + reflect.Int: decIntSlice, + reflect.Int16: decInt16Slice, + reflect.Int32: decInt32Slice, + reflect.Int64: decInt64Slice, + reflect.Int8: decInt8Slice, + reflect.String: decStringSlice, + reflect.Uint: decUintSlice, + reflect.Uint16: decUint16Slice, + reflect.Uint32: decUint32Slice, + reflect.Uint64: decUint64Slice, + reflect.Uintptr: decUintptrSlice, +} + +func decBoolArray(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decBoolSlice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decBoolSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]bool) + if !ok { + // It is kind bool but not type bool. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding bool array or slice: length exceeds input size (%d elements)", length) + } + slice[i] = state.decodeUint() != 0 + } + return true +} + +func decComplex64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decComplex64Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decComplex64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]complex64) + if !ok { + // It is kind complex64 but not type complex64. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding complex64 array or slice: length exceeds input size (%d elements)", length) + } + real := float32FromBits(state.decodeUint(), ovfl) + imag := float32FromBits(state.decodeUint(), ovfl) + slice[i] = complex(float32(real), float32(imag)) + } + return true +} + +func decComplex128Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decComplex128Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decComplex128Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]complex128) + if !ok { + // It is kind complex128 but not type complex128. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding complex128 array or slice: length exceeds input size (%d elements)", length) + } + real := float64FromBits(state.decodeUint()) + imag := float64FromBits(state.decodeUint()) + slice[i] = complex(real, imag) + } + return true +} + +func decFloat32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decFloat32Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decFloat32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]float32) + if !ok { + // It is kind float32 but not type float32. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding float32 array or slice: length exceeds input size (%d elements)", length) + } + slice[i] = float32(float32FromBits(state.decodeUint(), ovfl)) + } + return true +} + +func decFloat64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decFloat64Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decFloat64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]float64) + if !ok { + // It is kind float64 but not type float64. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding float64 array or slice: length exceeds input size (%d elements)", length) + } + slice[i] = float64FromBits(state.decodeUint()) + } + return true +} + +func decIntArray(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decIntSlice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decIntSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]int) + if !ok { + // It is kind int but not type int. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding int array or slice: length exceeds input size (%d elements)", length) + } + x := state.decodeInt() + // MinInt and MaxInt + if x < ^int64(^uint(0)>>1) || int64(^uint(0)>>1) < x { + error_(ovfl) + } + slice[i] = int(x) + } + return true +} + +func decInt16Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decInt16Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decInt16Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]int16) + if !ok { + // It is kind int16 but not type int16. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding int16 array or slice: length exceeds input size (%d elements)", length) + } + x := state.decodeInt() + if x < math.MinInt16 || math.MaxInt16 < x { + error_(ovfl) + } + slice[i] = int16(x) + } + return true +} + +func decInt32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decInt32Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decInt32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]int32) + if !ok { + // It is kind int32 but not type int32. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding int32 array or slice: length exceeds input size (%d elements)", length) + } + x := state.decodeInt() + if x < math.MinInt32 || math.MaxInt32 < x { + error_(ovfl) + } + slice[i] = int32(x) + } + return true +} + +func decInt64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decInt64Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decInt64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]int64) + if !ok { + // It is kind int64 but not type int64. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding int64 array or slice: length exceeds input size (%d elements)", length) + } + slice[i] = state.decodeInt() + } + return true +} + +func decInt8Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decInt8Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decInt8Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]int8) + if !ok { + // It is kind int8 but not type int8. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding int8 array or slice: length exceeds input size (%d elements)", length) + } + x := state.decodeInt() + if x < math.MinInt8 || math.MaxInt8 < x { + error_(ovfl) + } + slice[i] = int8(x) + } + return true +} + +func decStringArray(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decStringSlice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decStringSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]string) + if !ok { + // It is kind string but not type string. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding string array or slice: length exceeds input size (%d elements)", length) + } + u := state.decodeUint() + n := int(u) + if n < 0 || uint64(n) != u || n > state.b.Len() { + errorf("length of string exceeds input size (%d bytes)", u) + } + if n > state.b.Len() { + errorf("string data too long for buffer: %d", n) + } + // Read the data. + data := make([]byte, n) + if _, err := state.b.Read(data); err != nil { + errorf("error decoding string: %s", err) + } + slice[i] = string(data) + } + return true +} + +func decUintArray(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decUintSlice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decUintSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]uint) + if !ok { + // It is kind uint but not type uint. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding uint array or slice: length exceeds input size (%d elements)", length) + } + x := state.decodeUint() + /*TODO if math.MaxUint32 < x { + error_(ovfl) + }*/ + slice[i] = uint(x) + } + return true +} + +func decUint16Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decUint16Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decUint16Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]uint16) + if !ok { + // It is kind uint16 but not type uint16. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding uint16 array or slice: length exceeds input size (%d elements)", length) + } + x := state.decodeUint() + if math.MaxUint16 < x { + error_(ovfl) + } + slice[i] = uint16(x) + } + return true +} + +func decUint32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decUint32Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decUint32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]uint32) + if !ok { + // It is kind uint32 but not type uint32. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding uint32 array or slice: length exceeds input size (%d elements)", length) + } + x := state.decodeUint() + if math.MaxUint32 < x { + error_(ovfl) + } + slice[i] = uint32(x) + } + return true +} + +func decUint64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decUint64Slice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decUint64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]uint64) + if !ok { + // It is kind uint64 but not type uint64. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding uint64 array or slice: length exceeds input size (%d elements)", length) + } + slice[i] = state.decodeUint() + } + return true +} + +func decUintptrArray(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return decUintptrSlice(state, v.Slice(0, v.Len()), length, ovfl) +} + +func decUintptrSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]uintptr) + if !ok { + // It is kind uintptr but not type uintptr. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding uintptr array or slice: length exceeds input size (%d elements)", length) + } + x := state.decodeUint() + if uint64(^uintptr(0)) < x { + error_(ovfl) + } + slice[i] = uintptr(x) + } + return true +} diff --git a/src/encoding/gob/decgen.go b/src/encoding/gob/decgen.go new file mode 100644 index 000000000..da41a899e --- /dev/null +++ b/src/encoding/gob/decgen.go @@ -0,0 +1,240 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +// encgen writes the helper functions for encoding. Intended to be +// used with go generate; see the invocation in encode.go. + +// TODO: We could do more by being unsafe. Add a -unsafe flag? + +package main + +import ( + "bytes" + "flag" + "fmt" + "go/format" + "log" + "os" +) + +var output = flag.String("output", "dec_helpers.go", "file name to write") + +type Type struct { + lower string + upper string + decoder string +} + +var types = []Type{ + { + "bool", + "Bool", + `slice[i] = state.decodeUint() != 0`, + }, + { + "complex64", + "Complex64", + `real := float32FromBits(state.decodeUint(), ovfl) + imag := float32FromBits(state.decodeUint(), ovfl) + slice[i] = complex(float32(real), float32(imag))`, + }, + { + "complex128", + "Complex128", + `real := float64FromBits(state.decodeUint()) + imag := float64FromBits(state.decodeUint()) + slice[i] = complex(real, imag)`, + }, + { + "float32", + "Float32", + `slice[i] = float32(float32FromBits(state.decodeUint(), ovfl))`, + }, + { + "float64", + "Float64", + `slice[i] = float64FromBits(state.decodeUint())`, + }, + { + "int", + "Int", + `x := state.decodeInt() + // MinInt and MaxInt + if x < ^int64(^uint(0)>>1) || int64(^uint(0)>>1) < x { + error_(ovfl) + } + slice[i] = int(x)`, + }, + { + "int16", + "Int16", + `x := state.decodeInt() + if x < math.MinInt16 || math.MaxInt16 < x { + error_(ovfl) + } + slice[i] = int16(x)`, + }, + { + "int32", + "Int32", + `x := state.decodeInt() + if x < math.MinInt32 || math.MaxInt32 < x { + error_(ovfl) + } + slice[i] = int32(x)`, + }, + { + "int64", + "Int64", + `slice[i] = state.decodeInt()`, + }, + { + "int8", + "Int8", + `x := state.decodeInt() + if x < math.MinInt8 || math.MaxInt8 < x { + error_(ovfl) + } + slice[i] = int8(x)`, + }, + { + "string", + "String", + `u := state.decodeUint() + n := int(u) + if n < 0 || uint64(n) != u || n > state.b.Len() { + errorf("length of string exceeds input size (%d bytes)", u) + } + if n > state.b.Len() { + errorf("string data too long for buffer: %d", n) + } + // Read the data. + data := make([]byte, n) + if _, err := state.b.Read(data); err != nil { + errorf("error decoding string: %s", err) + } + slice[i] = string(data)`, + }, + { + "uint", + "Uint", + `x := state.decodeUint() + /*TODO if math.MaxUint32 < x { + error_(ovfl) + }*/ + slice[i] = uint(x)`, + }, + { + "uint16", + "Uint16", + `x := state.decodeUint() + if math.MaxUint16 < x { + error_(ovfl) + } + slice[i] = uint16(x)`, + }, + { + "uint32", + "Uint32", + `x := state.decodeUint() + if math.MaxUint32 < x { + error_(ovfl) + } + slice[i] = uint32(x)`, + }, + { + "uint64", + "Uint64", + `slice[i] = state.decodeUint()`, + }, + { + "uintptr", + "Uintptr", + `x := state.decodeUint() + if uint64(^uintptr(0)) < x { + error_(ovfl) + } + slice[i] = uintptr(x)`, + }, + // uint8 Handled separately. +} + +func main() { + log.SetFlags(0) + log.SetPrefix("decgen: ") + flag.Parse() + if flag.NArg() != 0 { + log.Fatal("usage: decgen [--output filename]") + } + var b bytes.Buffer + fmt.Fprintf(&b, "// Created by decgen --output %s; DO NOT EDIT\n", *output) + fmt.Fprint(&b, header) + printMaps(&b, "Array") + fmt.Fprint(&b, "\n") + printMaps(&b, "Slice") + for _, t := range types { + fmt.Fprintf(&b, arrayHelper, t.lower, t.upper) + fmt.Fprintf(&b, sliceHelper, t.lower, t.upper, t.decoder) + } + source, err := format.Source(b.Bytes()) + if err != nil { + log.Fatal("source format error:", err) + } + fd, err := os.Create(*output) + _, err = fd.Write(source) + if err != nil { + log.Fatal(err) + } +} + +func printMaps(b *bytes.Buffer, upperClass string) { + fmt.Fprintf(b, "var dec%sHelper = map[reflect.Kind]decHelper{\n", upperClass) + for _, t := range types { + fmt.Fprintf(b, "reflect.%s: dec%s%s,\n", t.upper, t.upper, upperClass) + } + fmt.Fprintf(b, "}\n") +} + +const header = ` +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "math" + "reflect" +) + +` + +const arrayHelper = ` +func dec%[2]sArray(state *decoderState, v reflect.Value, length int, ovfl error) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return dec%[2]sSlice(state, v.Slice(0, v.Len()), length, ovfl) +} +` + +const sliceHelper = ` +func dec%[2]sSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool { + slice, ok := v.Interface().([]%[1]s) + if !ok { + // It is kind %[1]s but not type %[1]s. TODO: We can handle this unsafely. + return false + } + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding %[1]s array or slice: length exceeds input size (%%d elements)", length) + } + %[3]s + } + return true +} +` diff --git a/src/encoding/gob/decode.go b/src/encoding/gob/decode.go new file mode 100644 index 000000000..a5bef9314 --- /dev/null +++ b/src/encoding/gob/decode.go @@ -0,0 +1,1217 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:generate go run decgen.go -output dec_helpers.go + +package gob + +import ( + "encoding" + "errors" + "io" + "math" + "reflect" +) + +var ( + errBadUint = errors.New("gob: encoded unsigned integer out of range") + errBadType = errors.New("gob: unknown type id or corrupted data") + errRange = errors.New("gob: bad data: field numbers out of bounds") +) + +type decHelper func(state *decoderState, v reflect.Value, length int, ovfl error) bool + +// decoderState is the execution state of an instance of the decoder. A new state +// is created for nested objects. +type decoderState struct { + dec *Decoder + // The buffer is stored with an extra indirection because it may be replaced + // if we load a type during decode (when reading an interface value). + b *decBuffer + fieldnum int // the last field number read. + buf []byte + next *decoderState // for free list +} + +// decBuffer is an extremely simple, fast implementation of a read-only byte buffer. +// It is initialized by calling Size and then copying the data into the slice returned by Bytes(). +type decBuffer struct { + data []byte + offset int // Read offset. +} + +func (d *decBuffer) Read(p []byte) (int, error) { + n := copy(p, d.data[d.offset:]) + if n == 0 && len(p) != 0 { + return 0, io.EOF + } + d.offset += n + return n, nil +} + +func (d *decBuffer) Drop(n int) { + if n > d.Len() { + panic("drop") + } + d.offset += n +} + +// Size grows the buffer to exactly n bytes, so d.Bytes() will +// return a slice of length n. Existing data is first discarded. +func (d *decBuffer) Size(n int) { + d.Reset() + if cap(d.data) < n { + d.data = make([]byte, n) + } else { + d.data = d.data[0:n] + } +} + +func (d *decBuffer) ReadByte() (byte, error) { + if d.offset >= len(d.data) { + return 0, io.EOF + } + c := d.data[d.offset] + d.offset++ + return c, nil +} + +func (d *decBuffer) Len() int { + return len(d.data) - d.offset +} + +func (d *decBuffer) Bytes() []byte { + return d.data[d.offset:] +} + +func (d *decBuffer) Reset() { + d.data = d.data[0:0] + d.offset = 0 +} + +// We pass the bytes.Buffer separately for easier testing of the infrastructure +// without requiring a full Decoder. +func (dec *Decoder) newDecoderState(buf *decBuffer) *decoderState { + d := dec.freeList + if d == nil { + d = new(decoderState) + d.dec = dec + d.buf = make([]byte, uint64Size) + } else { + dec.freeList = d.next + } + d.b = buf + return d +} + +func (dec *Decoder) freeDecoderState(d *decoderState) { + d.next = dec.freeList + dec.freeList = d +} + +func overflow(name string) error { + return errors.New(`value for "` + name + `" out of range`) +} + +// decodeUintReader reads an encoded unsigned integer from an io.Reader. +// Used only by the Decoder to read the message length. +func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err error) { + width = 1 + n, err := io.ReadFull(r, buf[0:width]) + if n == 0 { + return + } + b := buf[0] + if b <= 0x7f { + return uint64(b), width, nil + } + n = -int(int8(b)) + if n > uint64Size { + err = errBadUint + return + } + width, err = io.ReadFull(r, buf[0:n]) + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return + } + // Could check that the high byte is zero but it's not worth it. + for _, b := range buf[0:width] { + x = x<<8 | uint64(b) + } + width++ // +1 for length byte + return +} + +// decodeUint reads an encoded unsigned integer from state.r. +// Does not check for overflow. +func (state *decoderState) decodeUint() (x uint64) { + b, err := state.b.ReadByte() + if err != nil { + error_(err) + } + if b <= 0x7f { + return uint64(b) + } + n := -int(int8(b)) + if n > uint64Size { + error_(errBadUint) + } + width, err := state.b.Read(state.buf[0:n]) + if err != nil { + error_(err) + } + // Don't need to check error; it's safe to loop regardless. + // Could check that the high byte is zero but it's not worth it. + for _, b := range state.buf[0:width] { + x = x<<8 | uint64(b) + } + return x +} + +// decodeInt reads an encoded signed integer from state.r. +// Does not check for overflow. +func (state *decoderState) decodeInt() int64 { + x := state.decodeUint() + if x&1 != 0 { + return ^int64(x >> 1) + } + return int64(x >> 1) +} + +// decOp is the signature of a decoding operator for a given type. +type decOp func(i *decInstr, state *decoderState, v reflect.Value) + +// The 'instructions' of the decoding machine +type decInstr struct { + op decOp + field int // field number of the wire type + index []int // field access indices for destination type + ovfl error // error message for overflow/underflow (for arrays, of the elements) +} + +// ignoreUint discards a uint value with no destination. +func ignoreUint(i *decInstr, state *decoderState, v reflect.Value) { + state.decodeUint() +} + +// ignoreTwoUints discards a uint value with no destination. It's used to skip +// complex values. +func ignoreTwoUints(i *decInstr, state *decoderState, v reflect.Value) { + state.decodeUint() + state.decodeUint() +} + +// Since the encoder writes no zeros, if we arrive at a decoder we have +// a value to extract and store. The field number has already been read +// (it's how we knew to call this decoder). +// Each decoder is responsible for handling any indirections associated +// with the data structure. If any pointer so reached is nil, allocation must +// be done. + +// decAlloc takes a value and returns a settable value that can +// be assigned to. If the value is a pointer, decAlloc guarantees it points to storage. +// The callers to the individual decoders are expected to have used decAlloc. +// The individual decoders don't need to it. +func decAlloc(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} + +// decBool decodes a uint and stores it as a boolean in value. +func decBool(i *decInstr, state *decoderState, value reflect.Value) { + value.SetBool(state.decodeUint() != 0) +} + +// decInt8 decodes an integer and stores it as an int8 in value. +func decInt8(i *decInstr, state *decoderState, value reflect.Value) { + v := state.decodeInt() + if v < math.MinInt8 || math.MaxInt8 < v { + error_(i.ovfl) + } + value.SetInt(v) +} + +// decUint8 decodes an unsigned integer and stores it as a uint8 in value. +func decUint8(i *decInstr, state *decoderState, value reflect.Value) { + v := state.decodeUint() + if math.MaxUint8 < v { + error_(i.ovfl) + } + value.SetUint(v) +} + +// decInt16 decodes an integer and stores it as an int16 in value. +func decInt16(i *decInstr, state *decoderState, value reflect.Value) { + v := state.decodeInt() + if v < math.MinInt16 || math.MaxInt16 < v { + error_(i.ovfl) + } + value.SetInt(v) +} + +// decUint16 decodes an unsigned integer and stores it as a uint16 in value. +func decUint16(i *decInstr, state *decoderState, value reflect.Value) { + v := state.decodeUint() + if math.MaxUint16 < v { + error_(i.ovfl) + } + value.SetUint(v) +} + +// decInt32 decodes an integer and stores it as an int32 in value. +func decInt32(i *decInstr, state *decoderState, value reflect.Value) { + v := state.decodeInt() + if v < math.MinInt32 || math.MaxInt32 < v { + error_(i.ovfl) + } + value.SetInt(v) +} + +// decUint32 decodes an unsigned integer and stores it as a uint32 in value. +func decUint32(i *decInstr, state *decoderState, value reflect.Value) { + v := state.decodeUint() + if math.MaxUint32 < v { + error_(i.ovfl) + } + value.SetUint(v) +} + +// decInt64 decodes an integer and stores it as an int64 in value. +func decInt64(i *decInstr, state *decoderState, value reflect.Value) { + v := state.decodeInt() + value.SetInt(v) +} + +// decUint64 decodes an unsigned integer and stores it as a uint64 in value. +func decUint64(i *decInstr, state *decoderState, value reflect.Value) { + v := state.decodeUint() + value.SetUint(v) +} + +// Floating-point numbers are transmitted as uint64s holding the bits +// of the underlying representation. They are sent byte-reversed, with +// the exponent end coming out first, so integer floating point numbers +// (for example) transmit more compactly. This routine does the +// unswizzling. +func float64FromBits(u uint64) float64 { + var v uint64 + for i := 0; i < 8; i++ { + v <<= 8 + v |= u & 0xFF + u >>= 8 + } + return math.Float64frombits(v) +} + +// float32FromBits decodes an unsigned integer, treats it as a 32-bit floating-point +// number, and returns it. It's a helper function for float32 and complex64. +// It returns a float64 because that's what reflection needs, but its return +// value is known to be accurately representable in a float32. +func float32FromBits(u uint64, ovfl error) float64 { + v := float64FromBits(u) + av := v + if av < 0 { + av = -av + } + // +Inf is OK in both 32- and 64-bit floats. Underflow is always OK. + if math.MaxFloat32 < av && av <= math.MaxFloat64 { + error_(ovfl) + } + return v +} + +// decFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point +// number, and stores it in value. +func decFloat32(i *decInstr, state *decoderState, value reflect.Value) { + value.SetFloat(float32FromBits(state.decodeUint(), i.ovfl)) +} + +// decFloat64 decodes an unsigned integer, treats it as a 64-bit floating-point +// number, and stores it in value. +func decFloat64(i *decInstr, state *decoderState, value reflect.Value) { + value.SetFloat(float64FromBits(state.decodeUint())) +} + +// decComplex64 decodes a pair of unsigned integers, treats them as a +// pair of floating point numbers, and stores them as a complex64 in value. +// The real part comes first. +func decComplex64(i *decInstr, state *decoderState, value reflect.Value) { + real := float32FromBits(state.decodeUint(), i.ovfl) + imag := float32FromBits(state.decodeUint(), i.ovfl) + value.SetComplex(complex(real, imag)) +} + +// decComplex128 decodes a pair of unsigned integers, treats them as a +// pair of floating point numbers, and stores them as a complex128 in value. +// The real part comes first. +func decComplex128(i *decInstr, state *decoderState, value reflect.Value) { + real := float64FromBits(state.decodeUint()) + imag := float64FromBits(state.decodeUint()) + value.SetComplex(complex(real, imag)) +} + +// decUint8Slice decodes a byte slice and stores in value a slice header +// describing the data. +// uint8 slices are encoded as an unsigned count followed by the raw bytes. +func decUint8Slice(i *decInstr, state *decoderState, value reflect.Value) { + u := state.decodeUint() + n := int(u) + if n < 0 || uint64(n) != u { + errorf("length of %s exceeds input size (%d bytes)", value.Type(), u) + } + if n > state.b.Len() { + errorf("%s data too long for buffer: %d", value.Type(), n) + } + if n > tooBig { + errorf("byte slice too big: %d", n) + } + if value.Cap() < n { + value.Set(reflect.MakeSlice(value.Type(), n, n)) + } else { + value.Set(value.Slice(0, n)) + } + if _, err := state.b.Read(value.Bytes()); err != nil { + errorf("error decoding []byte: %s", err) + } +} + +// decString decodes byte array and stores in value a string header +// describing the data. +// Strings are encoded as an unsigned count followed by the raw bytes. +func decString(i *decInstr, state *decoderState, value reflect.Value) { + u := state.decodeUint() + n := int(u) + if n < 0 || uint64(n) != u || n > state.b.Len() { + errorf("length of %s exceeds input size (%d bytes)", value.Type(), u) + } + if n > state.b.Len() { + errorf("%s data too long for buffer: %d", value.Type(), n) + } + // Read the data. + data := make([]byte, n) + if _, err := state.b.Read(data); err != nil { + errorf("error decoding string: %s", err) + } + value.SetString(string(data)) +} + +// ignoreUint8Array skips over the data for a byte slice value with no destination. +func ignoreUint8Array(i *decInstr, state *decoderState, value reflect.Value) { + b := make([]byte, state.decodeUint()) + state.b.Read(b) +} + +// Execution engine + +// The encoder engine is an array of instructions indexed by field number of the incoming +// decoder. It is executed with random access according to field number. +type decEngine struct { + instr []decInstr + numInstr int // the number of active instructions +} + +// decodeSingle decodes a top-level value that is not a struct and stores it in value. +// Such values are preceded by a zero, making them have the memory layout of a +// struct field (although with an illegal field number). +func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, value reflect.Value) { + state := dec.newDecoderState(&dec.buf) + defer dec.freeDecoderState(state) + state.fieldnum = singletonField + if state.decodeUint() != 0 { + errorf("decode: corrupted data: non-zero delta for singleton") + } + instr := &engine.instr[singletonField] + instr.op(instr, state, value) +} + +// decodeStruct decodes a top-level struct and stores it in value. +// Indir is for the value, not the type. At the time of the call it may +// differ from ut.indir, which was computed when the engine was built. +// This state cannot arise for decodeSingle, which is called directly +// from the user's value, not from the innards of an engine. +func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, value reflect.Value) { + state := dec.newDecoderState(&dec.buf) + defer dec.freeDecoderState(state) + state.fieldnum = -1 + for state.b.Len() > 0 { + delta := int(state.decodeUint()) + if delta < 0 { + errorf("decode: corrupted data: negative delta") + } + if delta == 0 { // struct terminator is zero delta fieldnum + break + } + fieldnum := state.fieldnum + delta + if fieldnum >= len(engine.instr) { + error_(errRange) + break + } + instr := &engine.instr[fieldnum] + var field reflect.Value + if instr.index != nil { + // Otherwise the field is unknown to us and instr.op is an ignore op. + field = value.FieldByIndex(instr.index) + if field.Kind() == reflect.Ptr { + field = decAlloc(field) + } + } + instr.op(instr, state, field) + state.fieldnum = fieldnum + } +} + +var noValue reflect.Value + +// ignoreStruct discards the data for a struct with no destination. +func (dec *Decoder) ignoreStruct(engine *decEngine) { + state := dec.newDecoderState(&dec.buf) + defer dec.freeDecoderState(state) + state.fieldnum = -1 + for state.b.Len() > 0 { + delta := int(state.decodeUint()) + if delta < 0 { + errorf("ignore decode: corrupted data: negative delta") + } + if delta == 0 { // struct terminator is zero delta fieldnum + break + } + fieldnum := state.fieldnum + delta + if fieldnum >= len(engine.instr) { + error_(errRange) + } + instr := &engine.instr[fieldnum] + instr.op(instr, state, noValue) + state.fieldnum = fieldnum + } +} + +// ignoreSingle discards the data for a top-level non-struct value with no +// destination. It's used when calling Decode with a nil value. +func (dec *Decoder) ignoreSingle(engine *decEngine) { + state := dec.newDecoderState(&dec.buf) + defer dec.freeDecoderState(state) + state.fieldnum = singletonField + delta := int(state.decodeUint()) + if delta != 0 { + errorf("decode: corrupted data: non-zero delta for singleton") + } + instr := &engine.instr[singletonField] + instr.op(instr, state, noValue) +} + +// decodeArrayHelper does the work for decoding arrays and slices. +func (dec *Decoder) decodeArrayHelper(state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error, helper decHelper) { + if helper != nil && helper(state, value, length, ovfl) { + return + } + instr := &decInstr{elemOp, 0, nil, ovfl} + isPtr := value.Type().Elem().Kind() == reflect.Ptr + for i := 0; i < length; i++ { + if state.b.Len() == 0 { + errorf("decoding array or slice: length exceeds input size (%d elements)", length) + } + v := value.Index(i) + if isPtr { + v = decAlloc(v) + } + elemOp(instr, state, v) + } +} + +// decodeArray decodes an array and stores it in value. +// The length is an unsigned integer preceding the elements. Even though the length is redundant +// (it's part of the type), it's a useful check and is included in the encoding. +func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error, helper decHelper) { + if n := state.decodeUint(); n != uint64(length) { + errorf("length mismatch in decodeArray") + } + dec.decodeArrayHelper(state, value, elemOp, length, ovfl, helper) +} + +// decodeIntoValue is a helper for map decoding. +func decodeIntoValue(state *decoderState, op decOp, isPtr bool, value reflect.Value, ovfl error) reflect.Value { + instr := &decInstr{op, 0, nil, ovfl} + v := value + if isPtr { + v = decAlloc(value) + } + op(instr, state, v) + return value +} + +// decodeMap decodes a map and stores it in value. +// Maps are encoded as a length followed by key:value pairs. +// Because the internals of maps are not visible to us, we must +// use reflection rather than pointer magic. +func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, value reflect.Value, keyOp, elemOp decOp, ovfl error) { + if value.IsNil() { + // Allocate map. + value.Set(reflect.MakeMap(mtyp)) + } + n := int(state.decodeUint()) + keyIsPtr := mtyp.Key().Kind() == reflect.Ptr + elemIsPtr := mtyp.Elem().Kind() == reflect.Ptr + for i := 0; i < n; i++ { + key := decodeIntoValue(state, keyOp, keyIsPtr, allocValue(mtyp.Key()), ovfl) + elem := decodeIntoValue(state, elemOp, elemIsPtr, allocValue(mtyp.Elem()), ovfl) + value.SetMapIndex(key, elem) + } +} + +// ignoreArrayHelper does the work for discarding arrays and slices. +func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length int) { + instr := &decInstr{elemOp, 0, nil, errors.New("no error")} + for i := 0; i < length; i++ { + elemOp(instr, state, noValue) + } +} + +// ignoreArray discards the data for an array value with no destination. +func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) { + if n := state.decodeUint(); n != uint64(length) { + errorf("length mismatch in ignoreArray") + } + dec.ignoreArrayHelper(state, elemOp, length) +} + +// ignoreMap discards the data for a map value with no destination. +func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) { + n := int(state.decodeUint()) + keyInstr := &decInstr{keyOp, 0, nil, errors.New("no error")} + elemInstr := &decInstr{elemOp, 0, nil, errors.New("no error")} + for i := 0; i < n; i++ { + keyOp(keyInstr, state, noValue) + elemOp(elemInstr, state, noValue) + } +} + +// decodeSlice decodes a slice and stores it in value. +// Slices are encoded as an unsigned length followed by the elements. +func (dec *Decoder) decodeSlice(state *decoderState, value reflect.Value, elemOp decOp, ovfl error, helper decHelper) { + u := state.decodeUint() + typ := value.Type() + size := uint64(typ.Elem().Size()) + nBytes := u * size + n := int(u) + // Take care with overflow in this calculation. + if n < 0 || uint64(n) != u || nBytes > tooBig || (size > 0 && nBytes/size != u) { + // We don't check n against buffer length here because if it's a slice + // of interfaces, there will be buffer reloads. + errorf("%s slice too big: %d elements of %d bytes", typ.Elem(), u, size) + } + if value.Cap() < n { + value.Set(reflect.MakeSlice(typ, n, n)) + } else { + value.Set(value.Slice(0, n)) + } + dec.decodeArrayHelper(state, value, elemOp, n, ovfl, helper) +} + +// ignoreSlice skips over the data for a slice value with no destination. +func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) { + dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint())) +} + +// decodeInterface decodes an interface value and stores it in value. +// Interfaces are encoded as the name of a concrete type followed by a value. +// If the name is empty, the value is nil and no value is sent. +func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, value reflect.Value) { + // Read the name of the concrete type. + nr := state.decodeUint() + if nr < 0 || nr > 1<<31 { // zero is permissible for anonymous types + errorf("invalid type name length %d", nr) + } + if nr > uint64(state.b.Len()) { + errorf("invalid type name length %d: exceeds input size", nr) + } + b := make([]byte, nr) + state.b.Read(b) + name := string(b) + // Allocate the destination interface value. + if name == "" { + // Copy the nil interface value to the target. + value.Set(reflect.Zero(value.Type())) + return + } + if len(name) > 1024 { + errorf("name too long (%d bytes): %.20q...", len(name), name) + } + // The concrete type must be registered. + registerLock.RLock() + typ, ok := nameToConcreteType[name] + registerLock.RUnlock() + if !ok { + errorf("name not registered for interface: %q", name) + } + // Read the type id of the concrete value. + concreteId := dec.decodeTypeSequence(true) + if concreteId < 0 { + error_(dec.err) + } + // Byte count of value is next; we don't care what it is (it's there + // in case we want to ignore the value by skipping it completely). + state.decodeUint() + // Read the concrete value. + v := allocValue(typ) + dec.decodeValue(concreteId, v) + if dec.err != nil { + error_(dec.err) + } + // Assign the concrete value to the interface. + // Tread carefully; it might not satisfy the interface. + if !typ.AssignableTo(ityp) { + errorf("%s is not assignable to type %s", typ, ityp) + } + // Copy the interface value to the target. + value.Set(v) +} + +// ignoreInterface discards the data for an interface value with no destination. +func (dec *Decoder) ignoreInterface(state *decoderState) { + // Read the name of the concrete type. + b := make([]byte, state.decodeUint()) + _, err := state.b.Read(b) + if err != nil { + error_(err) + } + id := dec.decodeTypeSequence(true) + if id < 0 { + error_(dec.err) + } + // At this point, the decoder buffer contains a delimited value. Just toss it. + state.b.Drop(int(state.decodeUint())) +} + +// decodeGobDecoder decodes something implementing the GobDecoder interface. +// The data is encoded as a byte slice. +func (dec *Decoder) decodeGobDecoder(ut *userTypeInfo, state *decoderState, value reflect.Value) { + // Read the bytes for the value. + b := make([]byte, state.decodeUint()) + _, err := state.b.Read(b) + if err != nil { + error_(err) + } + // We know it's one of these. + switch ut.externalDec { + case xGob: + err = value.Interface().(GobDecoder).GobDecode(b) + case xBinary: + err = value.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(b) + case xText: + err = value.Interface().(encoding.TextUnmarshaler).UnmarshalText(b) + } + if err != nil { + error_(err) + } +} + +// ignoreGobDecoder discards the data for a GobDecoder value with no destination. +func (dec *Decoder) ignoreGobDecoder(state *decoderState) { + // Read the bytes for the value. + b := make([]byte, state.decodeUint()) + _, err := state.b.Read(b) + if err != nil { + error_(err) + } +} + +// Index by Go types. +var decOpTable = [...]decOp{ + reflect.Bool: decBool, + reflect.Int8: decInt8, + reflect.Int16: decInt16, + reflect.Int32: decInt32, + reflect.Int64: decInt64, + reflect.Uint8: decUint8, + reflect.Uint16: decUint16, + reflect.Uint32: decUint32, + reflect.Uint64: decUint64, + reflect.Float32: decFloat32, + reflect.Float64: decFloat64, + reflect.Complex64: decComplex64, + reflect.Complex128: decComplex128, + reflect.String: decString, +} + +// Indexed by gob types. tComplex will be added during type.init(). +var decIgnoreOpMap = map[typeId]decOp{ + tBool: ignoreUint, + tInt: ignoreUint, + tUint: ignoreUint, + tFloat: ignoreUint, + tBytes: ignoreUint8Array, + tString: ignoreUint8Array, + tComplex: ignoreTwoUints, +} + +// decOpFor returns the decoding op for the base type under rt and +// the indirection count to reach it. +func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) *decOp { + ut := userType(rt) + // If the type implements GobEncoder, we handle it without further processing. + if ut.externalDec != 0 { + return dec.gobDecodeOpFor(ut) + } + + // If this type is already in progress, it's a recursive type (e.g. map[string]*T). + // Return the pointer to the op we're already building. + if opPtr := inProgress[rt]; opPtr != nil { + return opPtr + } + typ := ut.base + var op decOp + k := typ.Kind() + if int(k) < len(decOpTable) { + op = decOpTable[k] + } + if op == nil { + inProgress[rt] = &op + // Special cases + switch t := typ; t.Kind() { + case reflect.Array: + name = "element of " + name + elemId := dec.wireType[wireId].ArrayT.Elem + elemOp := dec.decOpFor(elemId, t.Elem(), name, inProgress) + ovfl := overflow(name) + helper := decArrayHelper[t.Elem().Kind()] + op = func(i *decInstr, state *decoderState, value reflect.Value) { + state.dec.decodeArray(t, state, value, *elemOp, t.Len(), ovfl, helper) + } + + case reflect.Map: + keyId := dec.wireType[wireId].MapT.Key + elemId := dec.wireType[wireId].MapT.Elem + keyOp := dec.decOpFor(keyId, t.Key(), "key of "+name, inProgress) + elemOp := dec.decOpFor(elemId, t.Elem(), "element of "+name, inProgress) + ovfl := overflow(name) + op = func(i *decInstr, state *decoderState, value reflect.Value) { + state.dec.decodeMap(t, state, value, *keyOp, *elemOp, ovfl) + } + + case reflect.Slice: + name = "element of " + name + if t.Elem().Kind() == reflect.Uint8 { + op = decUint8Slice + break + } + var elemId typeId + if tt, ok := builtinIdToType[wireId]; ok { + elemId = tt.(*sliceType).Elem + } else { + elemId = dec.wireType[wireId].SliceT.Elem + } + elemOp := dec.decOpFor(elemId, t.Elem(), name, inProgress) + ovfl := overflow(name) + helper := decSliceHelper[t.Elem().Kind()] + op = func(i *decInstr, state *decoderState, value reflect.Value) { + state.dec.decodeSlice(state, value, *elemOp, ovfl, helper) + } + + case reflect.Struct: + // Generate a closure that calls out to the engine for the nested type. + ut := userType(typ) + enginePtr, err := dec.getDecEnginePtr(wireId, ut) + if err != nil { + error_(err) + } + op = func(i *decInstr, state *decoderState, value reflect.Value) { + // indirect through enginePtr to delay evaluation for recursive structs. + dec.decodeStruct(*enginePtr, ut, value) + } + case reflect.Interface: + op = func(i *decInstr, state *decoderState, value reflect.Value) { + state.dec.decodeInterface(t, state, value) + } + } + } + if op == nil { + errorf("decode can't handle type %s", rt) + } + return &op +} + +// decIgnoreOpFor returns the decoding op for a field that has no destination. +func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { + op, ok := decIgnoreOpMap[wireId] + if !ok { + if wireId == tInterface { + // Special case because it's a method: the ignored item might + // define types and we need to record their state in the decoder. + op = func(i *decInstr, state *decoderState, value reflect.Value) { + state.dec.ignoreInterface(state) + } + return op + } + // Special cases + wire := dec.wireType[wireId] + switch { + case wire == nil: + errorf("bad data: undefined type %s", wireId.string()) + case wire.ArrayT != nil: + elemId := wire.ArrayT.Elem + elemOp := dec.decIgnoreOpFor(elemId) + op = func(i *decInstr, state *decoderState, value reflect.Value) { + state.dec.ignoreArray(state, elemOp, wire.ArrayT.Len) + } + + case wire.MapT != nil: + keyId := dec.wireType[wireId].MapT.Key + elemId := dec.wireType[wireId].MapT.Elem + keyOp := dec.decIgnoreOpFor(keyId) + elemOp := dec.decIgnoreOpFor(elemId) + op = func(i *decInstr, state *decoderState, value reflect.Value) { + state.dec.ignoreMap(state, keyOp, elemOp) + } + + case wire.SliceT != nil: + elemId := wire.SliceT.Elem + elemOp := dec.decIgnoreOpFor(elemId) + op = func(i *decInstr, state *decoderState, value reflect.Value) { + state.dec.ignoreSlice(state, elemOp) + } + + case wire.StructT != nil: + // Generate a closure that calls out to the engine for the nested type. + enginePtr, err := dec.getIgnoreEnginePtr(wireId) + if err != nil { + error_(err) + } + op = func(i *decInstr, state *decoderState, value reflect.Value) { + // indirect through enginePtr to delay evaluation for recursive structs + state.dec.ignoreStruct(*enginePtr) + } + + case wire.GobEncoderT != nil, wire.BinaryMarshalerT != nil, wire.TextMarshalerT != nil: + op = func(i *decInstr, state *decoderState, value reflect.Value) { + state.dec.ignoreGobDecoder(state) + } + } + } + if op == nil { + errorf("bad data: ignore can't handle type %s", wireId.string()) + } + return op +} + +// gobDecodeOpFor returns the op for a type that is known to implement +// GobDecoder. +func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) *decOp { + rcvrType := ut.user + if ut.decIndir == -1 { + rcvrType = reflect.PtrTo(rcvrType) + } else if ut.decIndir > 0 { + for i := int8(0); i < ut.decIndir; i++ { + rcvrType = rcvrType.Elem() + } + } + var op decOp + op = func(i *decInstr, state *decoderState, value reflect.Value) { + // We now have the base type. We need its address if the receiver is a pointer. + if value.Kind() != reflect.Ptr && rcvrType.Kind() == reflect.Ptr { + value = value.Addr() + } + state.dec.decodeGobDecoder(ut, state, value) + } + return &op +} + +// compatibleType asks: Are these two gob Types compatible? +// Answers the question for basic types, arrays, maps and slices, plus +// GobEncoder/Decoder pairs. +// Structs are considered ok; fields will be checked later. +func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool { + if rhs, ok := inProgress[fr]; ok { + return rhs == fw + } + inProgress[fr] = fw + ut := userType(fr) + wire, ok := dec.wireType[fw] + // If wire was encoded with an encoding method, fr must have that method. + // And if not, it must not. + // At most one of the booleans in ut is set. + // We could possibly relax this constraint in the future in order to + // choose the decoding method using the data in the wireType. + // The parentheses look odd but are correct. + if (ut.externalDec == xGob) != (ok && wire.GobEncoderT != nil) || + (ut.externalDec == xBinary) != (ok && wire.BinaryMarshalerT != nil) || + (ut.externalDec == xText) != (ok && wire.TextMarshalerT != nil) { + return false + } + if ut.externalDec != 0 { // This test trumps all others. + return true + } + switch t := ut.base; t.Kind() { + default: + // chan, etc: cannot handle. + return false + case reflect.Bool: + return fw == tBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return fw == tInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return fw == tUint + case reflect.Float32, reflect.Float64: + return fw == tFloat + case reflect.Complex64, reflect.Complex128: + return fw == tComplex + case reflect.String: + return fw == tString + case reflect.Interface: + return fw == tInterface + case reflect.Array: + if !ok || wire.ArrayT == nil { + return false + } + array := wire.ArrayT + return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress) + case reflect.Map: + if !ok || wire.MapT == nil { + return false + } + MapType := wire.MapT + return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress) + case reflect.Slice: + // Is it an array of bytes? + if t.Elem().Kind() == reflect.Uint8 { + return fw == tBytes + } + // Extract and compare element types. + var sw *sliceType + if tt, ok := builtinIdToType[fw]; ok { + sw, _ = tt.(*sliceType) + } else if wire != nil { + sw = wire.SliceT + } + elem := userType(t.Elem()).base + return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress) + case reflect.Struct: + return true + } +} + +// typeString returns a human-readable description of the type identified by remoteId. +func (dec *Decoder) typeString(remoteId typeId) string { + if t := idToType[remoteId]; t != nil { + // globally known type. + return t.string() + } + return dec.wireType[remoteId].string() +} + +// compileSingle compiles the decoder engine for a non-struct top-level value, including +// GobDecoders. +func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err error) { + rt := ut.user + engine = new(decEngine) + engine.instr = make([]decInstr, 1) // one item + name := rt.String() // best we can do + if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) { + remoteType := dec.typeString(remoteId) + // Common confusing case: local interface type, remote concrete type. + if ut.base.Kind() == reflect.Interface && remoteId != tInterface { + return nil, errors.New("gob: local interface type " + name + " can only be decoded from remote interface type; received concrete type " + remoteType) + } + return nil, errors.New("gob: decoding into local type " + name + ", received remote type " + remoteType) + } + op := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp)) + ovfl := errors.New(`value for "` + name + `" out of range`) + engine.instr[singletonField] = decInstr{*op, singletonField, nil, ovfl} + engine.numInstr = 1 + return +} + +// compileIgnoreSingle compiles the decoder engine for a non-struct top-level value that will be discarded. +func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err error) { + engine = new(decEngine) + engine.instr = make([]decInstr, 1) // one item + op := dec.decIgnoreOpFor(remoteId) + ovfl := overflow(dec.typeString(remoteId)) + engine.instr[0] = decInstr{op, 0, nil, ovfl} + engine.numInstr = 1 + return +} + +// compileDec compiles the decoder engine for a value. If the value is not a struct, +// it calls out to compileSingle. +func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err error) { + rt := ut.base + srt := rt + if srt.Kind() != reflect.Struct || ut.externalDec != 0 { + return dec.compileSingle(remoteId, ut) + } + var wireStruct *structType + // Builtin types can come from global pool; the rest must be defined by the decoder. + // Also we know we're decoding a struct now, so the client must have sent one. + if t, ok := builtinIdToType[remoteId]; ok { + wireStruct, _ = t.(*structType) + } else { + wire := dec.wireType[remoteId] + if wire == nil { + error_(errBadType) + } + wireStruct = wire.StructT + } + if wireStruct == nil { + errorf("type mismatch in decoder: want struct type %s; got non-struct", rt) + } + engine = new(decEngine) + engine.instr = make([]decInstr, len(wireStruct.Field)) + seen := make(map[reflect.Type]*decOp) + // Loop over the fields of the wire type. + for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ { + wireField := wireStruct.Field[fieldnum] + if wireField.Name == "" { + errorf("empty name for remote field of type %s", wireStruct.Name) + } + ovfl := overflow(wireField.Name) + // Find the field of the local type with the same name. + localField, present := srt.FieldByName(wireField.Name) + // TODO(r): anonymous names + if !present || !isExported(wireField.Name) { + op := dec.decIgnoreOpFor(wireField.Id) + engine.instr[fieldnum] = decInstr{op, fieldnum, nil, ovfl} + continue + } + if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) { + errorf("wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) + } + op := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen) + engine.instr[fieldnum] = decInstr{*op, fieldnum, localField.Index, ovfl} + engine.numInstr++ + } + return +} + +// getDecEnginePtr returns the engine for the specified type. +func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePtr **decEngine, err error) { + rt := ut.user + decoderMap, ok := dec.decoderCache[rt] + if !ok { + decoderMap = make(map[typeId]**decEngine) + dec.decoderCache[rt] = decoderMap + } + if enginePtr, ok = decoderMap[remoteId]; !ok { + // To handle recursive types, mark this engine as underway before compiling. + enginePtr = new(*decEngine) + decoderMap[remoteId] = enginePtr + *enginePtr, err = dec.compileDec(remoteId, ut) + if err != nil { + delete(decoderMap, remoteId) + } + } + return +} + +// emptyStruct is the type we compile into when ignoring a struct value. +type emptyStruct struct{} + +var emptyStructType = reflect.TypeOf(emptyStruct{}) + +// getDecEnginePtr returns the engine for the specified type when the value is to be discarded. +func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err error) { + var ok bool + if enginePtr, ok = dec.ignorerCache[wireId]; !ok { + // To handle recursive types, mark this engine as underway before compiling. + enginePtr = new(*decEngine) + dec.ignorerCache[wireId] = enginePtr + wire := dec.wireType[wireId] + if wire != nil && wire.StructT != nil { + *enginePtr, err = dec.compileDec(wireId, userType(emptyStructType)) + } else { + *enginePtr, err = dec.compileIgnoreSingle(wireId) + } + if err != nil { + delete(dec.ignorerCache, wireId) + } + } + return +} + +// decodeValue decodes the data stream representing a value and stores it in value. +func (dec *Decoder) decodeValue(wireId typeId, value reflect.Value) { + defer catchError(&dec.err) + // If the value is nil, it means we should just ignore this item. + if !value.IsValid() { + dec.decodeIgnoredValue(wireId) + return + } + // Dereference down to the underlying type. + ut := userType(value.Type()) + base := ut.base + var enginePtr **decEngine + enginePtr, dec.err = dec.getDecEnginePtr(wireId, ut) + if dec.err != nil { + return + } + value = decAlloc(value) + engine := *enginePtr + if st := base; st.Kind() == reflect.Struct && ut.externalDec == 0 { + if engine.numInstr == 0 && st.NumField() > 0 && + dec.wireType[wireId] != nil && len(dec.wireType[wireId].StructT.Field) > 0 { + name := base.Name() + errorf("type mismatch: no fields matched compiling decoder for %s", name) + } + dec.decodeStruct(engine, ut, value) + } else { + dec.decodeSingle(engine, ut, value) + } +} + +// decodeIgnoredValue decodes the data stream representing a value of the specified type and discards it. +func (dec *Decoder) decodeIgnoredValue(wireId typeId) { + var enginePtr **decEngine + enginePtr, dec.err = dec.getIgnoreEnginePtr(wireId) + if dec.err != nil { + return + } + wire := dec.wireType[wireId] + if wire != nil && wire.StructT != nil { + dec.ignoreStruct(*enginePtr) + } else { + dec.ignoreSingle(*enginePtr) + } +} + +func init() { + var iop, uop decOp + switch reflect.TypeOf(int(0)).Bits() { + case 32: + iop = decInt32 + uop = decUint32 + case 64: + iop = decInt64 + uop = decUint64 + default: + panic("gob: unknown size of int/uint") + } + decOpTable[reflect.Int] = iop + decOpTable[reflect.Uint] = uop + + // Finally uintptr + switch reflect.TypeOf(uintptr(0)).Bits() { + case 32: + uop = decUint32 + case 64: + uop = decUint64 + default: + panic("gob: unknown size of uintptr") + } + decOpTable[reflect.Uintptr] = uop +} + +// Gob depends on being able to take the address +// of zeroed Values it creates, so use this wrapper instead +// of the standard reflect.Zero. +// Each call allocates once. +func allocValue(t reflect.Type) reflect.Value { + return reflect.New(t).Elem() +} diff --git a/src/encoding/gob/decoder.go b/src/encoding/gob/decoder.go new file mode 100644 index 000000000..c453e9ba3 --- /dev/null +++ b/src/encoding/gob/decoder.go @@ -0,0 +1,218 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "bufio" + "errors" + "io" + "reflect" + "sync" +) + +// tooBig provides a sanity check for sizes; used in several places. +// Upper limit of 1GB, allowing room to grow a little without overflow. +// TODO: make this adjustable? +const tooBig = 1 << 30 + +// A Decoder manages the receipt of type and data information read from the +// remote side of a connection. +type Decoder struct { + mutex sync.Mutex // each item must be received atomically + r io.Reader // source of the data + buf decBuffer // buffer for more efficient i/o from r + wireType map[typeId]*wireType // map from remote ID to local description + decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines + ignorerCache map[typeId]**decEngine // ditto for ignored objects + freeList *decoderState // list of free decoderStates; avoids reallocation + countBuf []byte // used for decoding integers while parsing messages + err error +} + +// NewDecoder returns a new decoder that reads from the io.Reader. +// If r does not also implement io.ByteReader, it will be wrapped in a +// bufio.Reader. +func NewDecoder(r io.Reader) *Decoder { + dec := new(Decoder) + // We use the ability to read bytes as a plausible surrogate for buffering. + if _, ok := r.(io.ByteReader); !ok { + r = bufio.NewReader(r) + } + dec.r = r + dec.wireType = make(map[typeId]*wireType) + dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine) + dec.ignorerCache = make(map[typeId]**decEngine) + dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes + + return dec +} + +// recvType loads the definition of a type. +func (dec *Decoder) recvType(id typeId) { + // Have we already seen this type? That's an error + if id < firstUserId || dec.wireType[id] != nil { + dec.err = errors.New("gob: duplicate type received") + return + } + + // Type: + wire := new(wireType) + dec.decodeValue(tWireType, reflect.ValueOf(wire)) + if dec.err != nil { + return + } + // Remember we've seen this type. + dec.wireType[id] = wire +} + +var errBadCount = errors.New("invalid message length") + +// recvMessage reads the next count-delimited item from the input. It is the converse +// of Encoder.writeMessage. It returns false on EOF or other error reading the message. +func (dec *Decoder) recvMessage() bool { + // Read a count. + nbytes, _, err := decodeUintReader(dec.r, dec.countBuf) + if err != nil { + dec.err = err + return false + } + if nbytes >= tooBig { + dec.err = errBadCount + return false + } + dec.readMessage(int(nbytes)) + return dec.err == nil +} + +// readMessage reads the next nbytes bytes from the input. +func (dec *Decoder) readMessage(nbytes int) { + if dec.buf.Len() != 0 { + // The buffer should always be empty now. + panic("non-empty decoder buffer") + } + // Read the data + dec.buf.Size(nbytes) + _, dec.err = io.ReadFull(dec.r, dec.buf.Bytes()) + if dec.err != nil { + if dec.err == io.EOF { + dec.err = io.ErrUnexpectedEOF + } + } +} + +// toInt turns an encoded uint64 into an int, according to the marshaling rules. +func toInt(x uint64) int64 { + i := int64(x >> 1) + if x&1 != 0 { + i = ^i + } + return i +} + +func (dec *Decoder) nextInt() int64 { + n, _, err := decodeUintReader(&dec.buf, dec.countBuf) + if err != nil { + dec.err = err + } + return toInt(n) +} + +func (dec *Decoder) nextUint() uint64 { + n, _, err := decodeUintReader(&dec.buf, dec.countBuf) + if err != nil { + dec.err = err + } + return n +} + +// decodeTypeSequence parses: +// TypeSequence +// (TypeDefinition DelimitedTypeDefinition*)? +// and returns the type id of the next value. It returns -1 at +// EOF. Upon return, the remainder of dec.buf is the value to be +// decoded. If this is an interface value, it can be ignored by +// resetting that buffer. +func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId { + for dec.err == nil { + if dec.buf.Len() == 0 { + if !dec.recvMessage() { + break + } + } + // Receive a type id. + id := typeId(dec.nextInt()) + if id >= 0 { + // Value follows. + return id + } + // Type definition for (-id) follows. + dec.recvType(-id) + // When decoding an interface, after a type there may be a + // DelimitedValue still in the buffer. Skip its count. + // (Alternatively, the buffer is empty and the byte count + // will be absorbed by recvMessage.) + if dec.buf.Len() > 0 { + if !isInterface { + dec.err = errors.New("extra data in buffer") + break + } + dec.nextUint() + } + } + return -1 +} + +// Decode reads the next value from the input stream and stores +// it in the data represented by the empty interface value. +// If e is nil, the value will be discarded. Otherwise, +// the value underlying e must be a pointer to the +// correct type for the next data item received. +// If the input is at EOF, Decode returns io.EOF and +// does not modify e. +func (dec *Decoder) Decode(e interface{}) error { + if e == nil { + return dec.DecodeValue(reflect.Value{}) + } + value := reflect.ValueOf(e) + // If e represents a value as opposed to a pointer, the answer won't + // get back to the caller. Make sure it's a pointer. + if value.Type().Kind() != reflect.Ptr { + dec.err = errors.New("gob: attempt to decode into a non-pointer") + return dec.err + } + return dec.DecodeValue(value) +} + +// DecodeValue reads the next value from the input stream. +// If v is the zero reflect.Value (v.Kind() == Invalid), DecodeValue discards the value. +// Otherwise, it stores the value into v. In that case, v must represent +// a non-nil pointer to data or be an assignable reflect.Value (v.CanSet()) +// If the input is at EOF, DecodeValue returns io.EOF and +// does not modify v. +func (dec *Decoder) DecodeValue(v reflect.Value) error { + if v.IsValid() { + if v.Kind() == reflect.Ptr && !v.IsNil() { + // That's okay, we'll store through the pointer. + } else if !v.CanSet() { + return errors.New("gob: DecodeValue of unassignable value") + } + } + // Make sure we're single-threaded through here. + dec.mutex.Lock() + defer dec.mutex.Unlock() + + dec.buf.Reset() // In case data lingers from previous invocation. + dec.err = nil + id := dec.decodeTypeSequence(false) + if dec.err == nil { + dec.decodeValue(id, v) + } + return dec.err +} + +// If debug.go is compiled into the program , debugFunc prints a human-readable +// representation of the gob data read from r by calling that file's Debug function. +// Otherwise it is nil. +var debugFunc func(io.Reader) diff --git a/src/encoding/gob/doc.go b/src/encoding/gob/doc.go new file mode 100644 index 000000000..d0acaba1a --- /dev/null +++ b/src/encoding/gob/doc.go @@ -0,0 +1,386 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package gob manages streams of gobs - binary values exchanged between an +Encoder (transmitter) and a Decoder (receiver). A typical use is transporting +arguments and results of remote procedure calls (RPCs) such as those provided by +package "rpc". + +The implementation compiles a custom codec for each data type in the stream and +is most efficient when a single Encoder is used to transmit a stream of values, +amortizing the cost of compilation. + +Basics + +A stream of gobs is self-describing. Each data item in the stream is preceded by +a specification of its type, expressed in terms of a small set of predefined +types. Pointers are not transmitted, but the things they point to are +transmitted; that is, the values are flattened. Recursive types work fine, but +recursive values (data with cycles) are problematic. This may change. + +To use gobs, create an Encoder and present it with a series of data items as +values or addresses that can be dereferenced to values. The Encoder makes sure +all type information is sent before it is needed. At the receive side, a +Decoder retrieves values from the encoded stream and unpacks them into local +variables. + +Types and Values + +The source and destination values/types need not correspond exactly. For structs, +fields (identified by name) that are in the source but absent from the receiving +variable will be ignored. Fields that are in the receiving variable but missing +from the transmitted type or value will be ignored in the destination. If a field +with the same name is present in both, their types must be compatible. Both the +receiver and transmitter will do all necessary indirection and dereferencing to +convert between gobs and actual Go values. For instance, a gob type that is +schematically, + + struct { A, B int } + +can be sent from or received into any of these Go types: + + struct { A, B int } // the same + *struct { A, B int } // extra indirection of the struct + struct { *A, **B int } // extra indirection of the fields + struct { A, B int64 } // different concrete value type; see below + +It may also be received into any of these: + + struct { A, B int } // the same + struct { B, A int } // ordering doesn't matter; matching is by name + struct { A, B, C int } // extra field (C) ignored + struct { B int } // missing field (A) ignored; data will be dropped + struct { B, C int } // missing field (A) ignored; extra field (C) ignored. + +Attempting to receive into these types will draw a decode error: + + struct { A int; B uint } // change of signedness for B + struct { A int; B float } // change of type for B + struct { } // no field names in common + struct { C, D int } // no field names in common + +Integers are transmitted two ways: arbitrary precision signed integers or +arbitrary precision unsigned integers. There is no int8, int16 etc. +discrimination in the gob format; there are only signed and unsigned integers. As +described below, the transmitter sends the value in a variable-length encoding; +the receiver accepts the value and stores it in the destination variable. +Floating-point numbers are always sent using IEEE-754 64-bit precision (see +below). + +Signed integers may be received into any signed integer variable: int, int16, etc.; +unsigned integers may be received into any unsigned integer variable; and floating +point values may be received into any floating point variable. However, +the destination variable must be able to represent the value or the decode +operation will fail. + +Structs, arrays and slices are also supported. Structs encode and decode only +exported fields. Strings and arrays of bytes are supported with a special, +efficient representation (see below). When a slice is decoded, if the existing +slice has capacity the slice will be extended in place; if not, a new array is +allocated. Regardless, the length of the resulting slice reports the number of +elements decoded. + +Functions and channels will not be sent in a gob. Attempting to encode such a value +at top the level will fail. A struct field of chan or func type is treated exactly +like an unexported field and is ignored. + +Gob can encode a value of any type implementing the GobEncoder or +encoding.BinaryMarshaler interfaces by calling the corresponding method, +in that order of preference. + +Gob can decode a value of any type implementing the GobDecoder or +encoding.BinaryUnmarshaler interfaces by calling the corresponding method, +again in that order of preference. + +Encoding Details + +This section documents the encoding, details that are not important for most +users. Details are presented bottom-up. + +An unsigned integer is sent one of two ways. If it is less than 128, it is sent +as a byte with that value. Otherwise it is sent as a minimal-length big-endian +(high byte first) byte stream holding the value, preceded by one byte holding the +byte count, negated. Thus 0 is transmitted as (00), 7 is transmitted as (07) and +256 is transmitted as (FE 01 00). + +A boolean is encoded within an unsigned integer: 0 for false, 1 for true. + +A signed integer, i, is encoded within an unsigned integer, u. Within u, bits 1 +upward contain the value; bit 0 says whether they should be complemented upon +receipt. The encode algorithm looks like this: + + uint u; + if i < 0 { + u = (^i << 1) | 1 // complement i, bit 0 is 1 + } else { + u = (i << 1) // do not complement i, bit 0 is 0 + } + encodeUnsigned(u) + +The low bit is therefore analogous to a sign bit, but making it the complement bit +instead guarantees that the largest negative integer is not a special case. For +example, -129=^128=(^256>>1) encodes as (FE 01 01). + +Floating-point numbers are always sent as a representation of a float64 value. +That value is converted to a uint64 using math.Float64bits. The uint64 is then +byte-reversed and sent as a regular unsigned integer. The byte-reversal means the +exponent and high-precision part of the mantissa go first. Since the low bits are +often zero, this can save encoding bytes. For instance, 17.0 is encoded in only +three bytes (FE 31 40). + +Strings and slices of bytes are sent as an unsigned count followed by that many +uninterpreted bytes of the value. + +All other slices and arrays are sent as an unsigned count followed by that many +elements using the standard gob encoding for their type, recursively. + +Maps are sent as an unsigned count followed by that many key, element +pairs. Empty but non-nil maps are sent, so if the sender has allocated +a map, the receiver will allocate a map even if no elements are +transmitted. + +Structs are sent as a sequence of (field number, field value) pairs. The field +value is sent using the standard gob encoding for its type, recursively. If a +field has the zero value for its type, it is omitted from the transmission. The +field number is defined by the type of the encoded struct: the first field of the +encoded type is field 0, the second is field 1, etc. When encoding a value, the +field numbers are delta encoded for efficiency and the fields are always sent in +order of increasing field number; the deltas are therefore unsigned. The +initialization for the delta encoding sets the field number to -1, so an unsigned +integer field 0 with value 7 is transmitted as unsigned delta = 1, unsigned value += 7 or (01 07). Finally, after all the fields have been sent a terminating mark +denotes the end of the struct. That mark is a delta=0 value, which has +representation (00). + +Interface types are not checked for compatibility; all interface types are +treated, for transmission, as members of a single "interface" type, analogous to +int or []byte - in effect they're all treated as interface{}. Interface values +are transmitted as a string identifying the concrete type being sent (a name +that must be pre-defined by calling Register), followed by a byte count of the +length of the following data (so the value can be skipped if it cannot be +stored), followed by the usual encoding of concrete (dynamic) value stored in +the interface value. (A nil interface value is identified by the empty string +and transmits no value.) Upon receipt, the decoder verifies that the unpacked +concrete item satisfies the interface of the receiving variable. + +The representation of types is described below. When a type is defined on a given +connection between an Encoder and Decoder, it is assigned a signed integer type +id. When Encoder.Encode(v) is called, it makes sure there is an id assigned for +the type of v and all its elements and then it sends the pair (typeid, encoded-v) +where typeid is the type id of the encoded type of v and encoded-v is the gob +encoding of the value v. + +To define a type, the encoder chooses an unused, positive type id and sends the +pair (-type id, encoded-type) where encoded-type is the gob encoding of a wireType +description, constructed from these types: + + type wireType struct { + ArrayT *ArrayType + SliceT *SliceType + StructT *StructType + MapT *MapType + } + type arrayType struct { + CommonType + Elem typeId + Len int + } + type CommonType struct { + Name string // the name of the struct type + Id int // the id of the type, repeated so it's inside the type + } + type sliceType struct { + CommonType + Elem typeId + } + type structType struct { + CommonType + Field []*fieldType // the fields of the struct. + } + type fieldType struct { + Name string // the name of the field. + Id int // the type id of the field, which must be already defined + } + type mapType struct { + CommonType + Key typeId + Elem typeId + } + +If there are nested type ids, the types for all inner type ids must be defined +before the top-level type id is used to describe an encoded-v. + +For simplicity in setup, the connection is defined to understand these types a +priori, as well as the basic gob types int, uint, etc. Their ids are: + + bool 1 + int 2 + uint 3 + float 4 + []byte 5 + string 6 + complex 7 + interface 8 + // gap for reserved ids. + WireType 16 + ArrayType 17 + CommonType 18 + SliceType 19 + StructType 20 + FieldType 21 + // 22 is slice of fieldType. + MapType 23 + +Finally, each message created by a call to Encode is preceded by an encoded +unsigned integer count of the number of bytes remaining in the message. After +the initial type name, interface values are wrapped the same way; in effect, the +interface value acts like a recursive invocation of Encode. + +In summary, a gob stream looks like + + (byteCount (-type id, encoding of a wireType)* (type id, encoding of a value))* + +where * signifies zero or more repetitions and the type id of a value must +be predefined or be defined before the value in the stream. + +See "Gobs of data" for a design discussion of the gob wire format: +http://golang.org/doc/articles/gobs_of_data.html +*/ +package gob + +/* +Grammar: + +Tokens starting with a lower case letter are terminals; int(n) +and uint(n) represent the signed/unsigned encodings of the value n. + +GobStream: + DelimitedMessage* +DelimitedMessage: + uint(lengthOfMessage) Message +Message: + TypeSequence TypedValue +TypeSequence + (TypeDefinition DelimitedTypeDefinition*)? +DelimitedTypeDefinition: + uint(lengthOfTypeDefinition) TypeDefinition +TypedValue: + int(typeId) Value +TypeDefinition: + int(-typeId) encodingOfWireType +Value: + SingletonValue | StructValue +SingletonValue: + uint(0) FieldValue +FieldValue: + builtinValue | ArrayValue | MapValue | SliceValue | StructValue | InterfaceValue +InterfaceValue: + NilInterfaceValue | NonNilInterfaceValue +NilInterfaceValue: + uint(0) +NonNilInterfaceValue: + ConcreteTypeName TypeSequence InterfaceContents +ConcreteTypeName: + uint(lengthOfName) [already read=n] name +InterfaceContents: + int(concreteTypeId) DelimitedValue +DelimitedValue: + uint(length) Value +ArrayValue: + uint(n) FieldValue*n [n elements] +MapValue: + uint(n) (FieldValue FieldValue)*n [n (key, value) pairs] +SliceValue: + uint(n) FieldValue*n [n elements] +StructValue: + (uint(fieldDelta) FieldValue)* +*/ + +/* +For implementers and the curious, here is an encoded example. Given + type Point struct {X, Y int} +and the value + p := Point{22, 33} +the bytes transmitted that encode p will be: + 1f ff 81 03 01 01 05 50 6f 69 6e 74 01 ff 82 00 + 01 02 01 01 58 01 04 00 01 01 59 01 04 00 00 00 + 07 ff 82 01 2c 01 42 00 +They are determined as follows. + +Since this is the first transmission of type Point, the type descriptor +for Point itself must be sent before the value. This is the first type +we've sent on this Encoder, so it has type id 65 (0 through 64 are +reserved). + + 1f // This item (a type descriptor) is 31 bytes long. + ff 81 // The negative of the id for the type we're defining, -65. + // This is one byte (indicated by FF = -1) followed by + // ^-65<<1 | 1. The low 1 bit signals to complement the + // rest upon receipt. + + // Now we send a type descriptor, which is itself a struct (wireType). + // The type of wireType itself is known (it's built in, as is the type of + // all its components), so we just need to send a *value* of type wireType + // that represents type "Point". + // Here starts the encoding of that value. + // Set the field number implicitly to -1; this is done at the beginning + // of every struct, including nested structs. + 03 // Add 3 to field number; now 2 (wireType.structType; this is a struct). + // structType starts with an embedded CommonType, which appears + // as a regular structure here too. + 01 // add 1 to field number (now 0); start of embedded CommonType. + 01 // add 1 to field number (now 0, the name of the type) + 05 // string is (unsigned) 5 bytes long + 50 6f 69 6e 74 // wireType.structType.CommonType.name = "Point" + 01 // add 1 to field number (now 1, the id of the type) + ff 82 // wireType.structType.CommonType._id = 65 + 00 // end of embedded wiretype.structType.CommonType struct + 01 // add 1 to field number (now 1, the field array in wireType.structType) + 02 // There are two fields in the type (len(structType.field)) + 01 // Start of first field structure; add 1 to get field number 0: field[0].name + 01 // 1 byte + 58 // structType.field[0].name = "X" + 01 // Add 1 to get field number 1: field[0].id + 04 // structType.field[0].typeId is 2 (signed int). + 00 // End of structType.field[0]; start structType.field[1]; set field number to -1. + 01 // Add 1 to get field number 0: field[1].name + 01 // 1 byte + 59 // structType.field[1].name = "Y" + 01 // Add 1 to get field number 1: field[1].id + 04 // struct.Type.field[1].typeId is 2 (signed int). + 00 // End of structType.field[1]; end of structType.field. + 00 // end of wireType.structType structure + 00 // end of wireType structure + +Now we can send the Point value. Again the field number resets to -1: + + 07 // this value is 7 bytes long + ff 82 // the type number, 65 (1 byte (-FF) followed by 65<<1) + 01 // add one to field number, yielding field 0 + 2c // encoding of signed "22" (0x22 = 44 = 22<<1); Point.x = 22 + 01 // add one to field number, yielding field 1 + 42 // encoding of signed "33" (0x42 = 66 = 33<<1); Point.y = 33 + 00 // end of structure + +The type encoding is long and fairly intricate but we send it only once. +If p is transmitted a second time, the type is already known so the +output will be just: + + 07 ff 82 01 2c 01 42 00 + +A single non-struct value at top level is transmitted like a field with +delta tag 0. For instance, a signed integer with value 3 presented as +the argument to Encode will emit: + + 03 04 00 06 + +Which represents: + + 03 // this value is 3 bytes long + 04 // the type number, 2, represents an integer + 00 // tag delta 0 + 06 // value 3 + +*/ diff --git a/src/encoding/gob/dump.go b/src/encoding/gob/dump.go new file mode 100644 index 000000000..17238c98d --- /dev/null +++ b/src/encoding/gob/dump.go @@ -0,0 +1,29 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +package main + +// Need to compile package gob with debug.go to build this program. +// See comments in debug.go for how to do this. + +import ( + "encoding/gob" + "fmt" + "os" +) + +func main() { + var err error + file := os.Stdin + if len(os.Args) > 1 { + file, err = os.Open(os.Args[1]) + if err != nil { + fmt.Fprintf(os.Stderr, "dump: %s\n", err) + os.Exit(1) + } + } + gob.Debug(file) +} diff --git a/src/encoding/gob/enc_helpers.go b/src/encoding/gob/enc_helpers.go new file mode 100644 index 000000000..804e539d8 --- /dev/null +++ b/src/encoding/gob/enc_helpers.go @@ -0,0 +1,414 @@ +// Created by encgen --output enc_helpers.go; DO NOT EDIT + +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "reflect" +) + +var encArrayHelper = map[reflect.Kind]encHelper{ + reflect.Bool: encBoolArray, + reflect.Complex64: encComplex64Array, + reflect.Complex128: encComplex128Array, + reflect.Float32: encFloat32Array, + reflect.Float64: encFloat64Array, + reflect.Int: encIntArray, + reflect.Int16: encInt16Array, + reflect.Int32: encInt32Array, + reflect.Int64: encInt64Array, + reflect.Int8: encInt8Array, + reflect.String: encStringArray, + reflect.Uint: encUintArray, + reflect.Uint16: encUint16Array, + reflect.Uint32: encUint32Array, + reflect.Uint64: encUint64Array, + reflect.Uintptr: encUintptrArray, +} + +var encSliceHelper = map[reflect.Kind]encHelper{ + reflect.Bool: encBoolSlice, + reflect.Complex64: encComplex64Slice, + reflect.Complex128: encComplex128Slice, + reflect.Float32: encFloat32Slice, + reflect.Float64: encFloat64Slice, + reflect.Int: encIntSlice, + reflect.Int16: encInt16Slice, + reflect.Int32: encInt32Slice, + reflect.Int64: encInt64Slice, + reflect.Int8: encInt8Slice, + reflect.String: encStringSlice, + reflect.Uint: encUintSlice, + reflect.Uint16: encUint16Slice, + reflect.Uint32: encUint32Slice, + reflect.Uint64: encUint64Slice, + reflect.Uintptr: encUintptrSlice, +} + +func encBoolArray(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encBoolSlice(state, v.Slice(0, v.Len())) +} + +func encBoolSlice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]bool) + if !ok { + // It is kind bool but not type bool. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != false || state.sendZero { + if x { + state.encodeUint(1) + } else { + state.encodeUint(0) + } + } + } + return true +} + +func encComplex64Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encComplex64Slice(state, v.Slice(0, v.Len())) +} + +func encComplex64Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]complex64) + if !ok { + // It is kind complex64 but not type complex64. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0+0i || state.sendZero { + rpart := floatBits(float64(real(x))) + ipart := floatBits(float64(imag(x))) + state.encodeUint(rpart) + state.encodeUint(ipart) + } + } + return true +} + +func encComplex128Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encComplex128Slice(state, v.Slice(0, v.Len())) +} + +func encComplex128Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]complex128) + if !ok { + // It is kind complex128 but not type complex128. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0+0i || state.sendZero { + rpart := floatBits(real(x)) + ipart := floatBits(imag(x)) + state.encodeUint(rpart) + state.encodeUint(ipart) + } + } + return true +} + +func encFloat32Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encFloat32Slice(state, v.Slice(0, v.Len())) +} + +func encFloat32Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]float32) + if !ok { + // It is kind float32 but not type float32. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + bits := floatBits(float64(x)) + state.encodeUint(bits) + } + } + return true +} + +func encFloat64Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encFloat64Slice(state, v.Slice(0, v.Len())) +} + +func encFloat64Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]float64) + if !ok { + // It is kind float64 but not type float64. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + bits := floatBits(x) + state.encodeUint(bits) + } + } + return true +} + +func encIntArray(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encIntSlice(state, v.Slice(0, v.Len())) +} + +func encIntSlice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]int) + if !ok { + // It is kind int but not type int. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + state.encodeInt(int64(x)) + } + } + return true +} + +func encInt16Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encInt16Slice(state, v.Slice(0, v.Len())) +} + +func encInt16Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]int16) + if !ok { + // It is kind int16 but not type int16. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + state.encodeInt(int64(x)) + } + } + return true +} + +func encInt32Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encInt32Slice(state, v.Slice(0, v.Len())) +} + +func encInt32Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]int32) + if !ok { + // It is kind int32 but not type int32. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + state.encodeInt(int64(x)) + } + } + return true +} + +func encInt64Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encInt64Slice(state, v.Slice(0, v.Len())) +} + +func encInt64Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]int64) + if !ok { + // It is kind int64 but not type int64. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + state.encodeInt(x) + } + } + return true +} + +func encInt8Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encInt8Slice(state, v.Slice(0, v.Len())) +} + +func encInt8Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]int8) + if !ok { + // It is kind int8 but not type int8. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + state.encodeInt(int64(x)) + } + } + return true +} + +func encStringArray(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encStringSlice(state, v.Slice(0, v.Len())) +} + +func encStringSlice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]string) + if !ok { + // It is kind string but not type string. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != "" || state.sendZero { + state.encodeUint(uint64(len(x))) + state.b.WriteString(x) + } + } + return true +} + +func encUintArray(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encUintSlice(state, v.Slice(0, v.Len())) +} + +func encUintSlice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]uint) + if !ok { + // It is kind uint but not type uint. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + state.encodeUint(uint64(x)) + } + } + return true +} + +func encUint16Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encUint16Slice(state, v.Slice(0, v.Len())) +} + +func encUint16Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]uint16) + if !ok { + // It is kind uint16 but not type uint16. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + state.encodeUint(uint64(x)) + } + } + return true +} + +func encUint32Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encUint32Slice(state, v.Slice(0, v.Len())) +} + +func encUint32Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]uint32) + if !ok { + // It is kind uint32 but not type uint32. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + state.encodeUint(uint64(x)) + } + } + return true +} + +func encUint64Array(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encUint64Slice(state, v.Slice(0, v.Len())) +} + +func encUint64Slice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]uint64) + if !ok { + // It is kind uint64 but not type uint64. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + state.encodeUint(x) + } + } + return true +} + +func encUintptrArray(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return encUintptrSlice(state, v.Slice(0, v.Len())) +} + +func encUintptrSlice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]uintptr) + if !ok { + // It is kind uintptr but not type uintptr. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != 0 || state.sendZero { + state.encodeUint(uint64(x)) + } + } + return true +} diff --git a/src/encoding/gob/encgen.go b/src/encoding/gob/encgen.go new file mode 100644 index 000000000..efdd92829 --- /dev/null +++ b/src/encoding/gob/encgen.go @@ -0,0 +1,218 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +// encgen writes the helper functions for encoding. Intended to be +// used with go generate; see the invocation in encode.go. + +// TODO: We could do more by being unsafe. Add a -unsafe flag? + +package main + +import ( + "bytes" + "flag" + "fmt" + "go/format" + "log" + "os" +) + +var output = flag.String("output", "enc_helpers.go", "file name to write") + +type Type struct { + lower string + upper string + zero string + encoder string +} + +var types = []Type{ + { + "bool", + "Bool", + "false", + `if x { + state.encodeUint(1) + } else { + state.encodeUint(0) + }`, + }, + { + "complex64", + "Complex64", + "0+0i", + `rpart := floatBits(float64(real(x))) + ipart := floatBits(float64(imag(x))) + state.encodeUint(rpart) + state.encodeUint(ipart)`, + }, + { + "complex128", + "Complex128", + "0+0i", + `rpart := floatBits(real(x)) + ipart := floatBits(imag(x)) + state.encodeUint(rpart) + state.encodeUint(ipart)`, + }, + { + "float32", + "Float32", + "0", + `bits := floatBits(float64(x)) + state.encodeUint(bits)`, + }, + { + "float64", + "Float64", + "0", + `bits := floatBits(x) + state.encodeUint(bits)`, + }, + { + "int", + "Int", + "0", + `state.encodeInt(int64(x))`, + }, + { + "int16", + "Int16", + "0", + `state.encodeInt(int64(x))`, + }, + { + "int32", + "Int32", + "0", + `state.encodeInt(int64(x))`, + }, + { + "int64", + "Int64", + "0", + `state.encodeInt(x)`, + }, + { + "int8", + "Int8", + "0", + `state.encodeInt(int64(x))`, + }, + { + "string", + "String", + `""`, + `state.encodeUint(uint64(len(x))) + state.b.WriteString(x)`, + }, + { + "uint", + "Uint", + "0", + `state.encodeUint(uint64(x))`, + }, + { + "uint16", + "Uint16", + "0", + `state.encodeUint(uint64(x))`, + }, + { + "uint32", + "Uint32", + "0", + `state.encodeUint(uint64(x))`, + }, + { + "uint64", + "Uint64", + "0", + `state.encodeUint(x)`, + }, + { + "uintptr", + "Uintptr", + "0", + `state.encodeUint(uint64(x))`, + }, + // uint8 Handled separately. +} + +func main() { + log.SetFlags(0) + log.SetPrefix("encgen: ") + flag.Parse() + if flag.NArg() != 0 { + log.Fatal("usage: encgen [--output filename]") + } + var b bytes.Buffer + fmt.Fprintf(&b, "// Created by encgen --output %s; DO NOT EDIT\n", *output) + fmt.Fprint(&b, header) + printMaps(&b, "Array") + fmt.Fprint(&b, "\n") + printMaps(&b, "Slice") + for _, t := range types { + fmt.Fprintf(&b, arrayHelper, t.lower, t.upper) + fmt.Fprintf(&b, sliceHelper, t.lower, t.upper, t.zero, t.encoder) + } + source, err := format.Source(b.Bytes()) + if err != nil { + log.Fatal("source format error:", err) + } + fd, err := os.Create(*output) + _, err = fd.Write(source) + if err != nil { + log.Fatal(err) + } +} + +func printMaps(b *bytes.Buffer, upperClass string) { + fmt.Fprintf(b, "var enc%sHelper = map[reflect.Kind]encHelper{\n", upperClass) + for _, t := range types { + fmt.Fprintf(b, "reflect.%s: enc%s%s,\n", t.upper, t.upper, upperClass) + } + fmt.Fprintf(b, "}\n") +} + +const header = ` +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "reflect" +) + +` + +const arrayHelper = ` +func enc%[2]sArray(state *encoderState, v reflect.Value) bool { + // Can only slice if it is addressable. + if !v.CanAddr() { + return false + } + return enc%[2]sSlice(state, v.Slice(0, v.Len())) +} +` + +const sliceHelper = ` +func enc%[2]sSlice(state *encoderState, v reflect.Value) bool { + slice, ok := v.Interface().([]%[1]s) + if !ok { + // It is kind %[1]s but not type %[1]s. TODO: We can handle this unsafely. + return false + } + for _, x := range slice { + if x != %[3]s || state.sendZero { + %[4]s + } + } + return true +} +` diff --git a/src/encoding/gob/encode.go b/src/encoding/gob/encode.go new file mode 100644 index 000000000..f66279f14 --- /dev/null +++ b/src/encoding/gob/encode.go @@ -0,0 +1,696 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:generate go run encgen.go -output enc_helpers.go + +package gob + +import ( + "encoding" + "math" + "reflect" +) + +const uint64Size = 8 + +type encHelper func(state *encoderState, v reflect.Value) bool + +// encoderState is the global execution state of an instance of the encoder. +// Field numbers are delta encoded and always increase. The field +// number is initialized to -1 so 0 comes out as delta(1). A delta of +// 0 terminates the structure. +type encoderState struct { + enc *Encoder + b *encBuffer + sendZero bool // encoding an array element or map key/value pair; send zero values + fieldnum int // the last field number written. + buf [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation. + next *encoderState // for free list +} + +// encBuffer is an extremely simple, fast implementation of a write-only byte buffer. +// It never returns a non-nil error, but Write returns an error value so it matches io.Writer. +type encBuffer struct { + data []byte + scratch [64]byte +} + +func (e *encBuffer) WriteByte(c byte) { + e.data = append(e.data, c) +} + +func (e *encBuffer) Write(p []byte) (int, error) { + e.data = append(e.data, p...) + return len(p), nil +} + +func (e *encBuffer) WriteString(s string) { + e.data = append(e.data, s...) +} + +func (e *encBuffer) Len() int { + return len(e.data) +} + +func (e *encBuffer) Bytes() []byte { + return e.data +} + +func (e *encBuffer) Reset() { + e.data = e.data[0:0] +} + +func (enc *Encoder) newEncoderState(b *encBuffer) *encoderState { + e := enc.freeList + if e == nil { + e = new(encoderState) + e.enc = enc + } else { + enc.freeList = e.next + } + e.sendZero = false + e.fieldnum = 0 + e.b = b + if len(b.data) == 0 { + b.data = b.scratch[0:0] + } + return e +} + +func (enc *Encoder) freeEncoderState(e *encoderState) { + e.next = enc.freeList + enc.freeList = e +} + +// Unsigned integers have a two-state encoding. If the number is less +// than 128 (0 through 0x7F), its value is written directly. +// Otherwise the value is written in big-endian byte order preceded +// by the byte length, negated. + +// encodeUint writes an encoded unsigned integer to state.b. +func (state *encoderState) encodeUint(x uint64) { + if x <= 0x7F { + state.b.WriteByte(uint8(x)) + return + } + i := uint64Size + for x > 0 { + state.buf[i] = uint8(x) + x >>= 8 + i-- + } + state.buf[i] = uint8(i - uint64Size) // = loop count, negated + state.b.Write(state.buf[i : uint64Size+1]) +} + +// encodeInt writes an encoded signed integer to state.w. +// The low bit of the encoding says whether to bit complement the (other bits of the) +// uint to recover the int. +func (state *encoderState) encodeInt(i int64) { + var x uint64 + if i < 0 { + x = uint64(^i<<1) | 1 + } else { + x = uint64(i << 1) + } + state.encodeUint(uint64(x)) +} + +// encOp is the signature of an encoding operator for a given type. +type encOp func(i *encInstr, state *encoderState, v reflect.Value) + +// The 'instructions' of the encoding machine +type encInstr struct { + op encOp + field int // field number in input + index []int // struct index + indir int // how many pointer indirections to reach the value in the struct +} + +// update emits a field number and updates the state to record its value for delta encoding. +// If the instruction pointer is nil, it does nothing +func (state *encoderState) update(instr *encInstr) { + if instr != nil { + state.encodeUint(uint64(instr.field - state.fieldnum)) + state.fieldnum = instr.field + } +} + +// Each encoder for a composite is responsible for handling any +// indirections associated with the elements of the data structure. +// If any pointer so reached is nil, no bytes are written. If the +// data item is zero, no bytes are written. Single values - ints, +// strings etc. - are indirected before calling their encoders. +// Otherwise, the output (for a scalar) is the field number, as an +// encoded integer, followed by the field data in its appropriate +// format. + +// encIndirect dereferences pv indir times and returns the result. +func encIndirect(pv reflect.Value, indir int) reflect.Value { + for ; indir > 0; indir-- { + if pv.IsNil() { + break + } + pv = pv.Elem() + } + return pv +} + +// encBool encodes the bool referenced by v as an unsigned 0 or 1. +func encBool(i *encInstr, state *encoderState, v reflect.Value) { + b := v.Bool() + if b || state.sendZero { + state.update(i) + if b { + state.encodeUint(1) + } else { + state.encodeUint(0) + } + } +} + +// encInt encodes the signed integer (int int8 int16 int32 int64) referenced by v. +func encInt(i *encInstr, state *encoderState, v reflect.Value) { + value := v.Int() + if value != 0 || state.sendZero { + state.update(i) + state.encodeInt(value) + } +} + +// encUint encodes the unsigned integer (uint uint8 uint16 uint32 uint64 uintptr) referenced by v. +func encUint(i *encInstr, state *encoderState, v reflect.Value) { + value := v.Uint() + if value != 0 || state.sendZero { + state.update(i) + state.encodeUint(value) + } +} + +// floatBits returns a uint64 holding the bits of a floating-point number. +// Floating-point numbers are transmitted as uint64s holding the bits +// of the underlying representation. They are sent byte-reversed, with +// the exponent end coming out first, so integer floating point numbers +// (for example) transmit more compactly. This routine does the +// swizzling. +func floatBits(f float64) uint64 { + u := math.Float64bits(f) + var v uint64 + for i := 0; i < 8; i++ { + v <<= 8 + v |= u & 0xFF + u >>= 8 + } + return v +} + +// encFloat encodes the floating point value (float32 float64) referenced by v. +func encFloat(i *encInstr, state *encoderState, v reflect.Value) { + f := v.Float() + if f != 0 || state.sendZero { + bits := floatBits(f) + state.update(i) + state.encodeUint(bits) + } +} + +// encComplex encodes the complex value (complex64 complex128) referenced by v. +// Complex numbers are just a pair of floating-point numbers, real part first. +func encComplex(i *encInstr, state *encoderState, v reflect.Value) { + c := v.Complex() + if c != 0+0i || state.sendZero { + rpart := floatBits(real(c)) + ipart := floatBits(imag(c)) + state.update(i) + state.encodeUint(rpart) + state.encodeUint(ipart) + } +} + +// encUint8Array encodes the byte array referenced by v. +// Byte arrays are encoded as an unsigned count followed by the raw bytes. +func encUint8Array(i *encInstr, state *encoderState, v reflect.Value) { + b := v.Bytes() + if len(b) > 0 || state.sendZero { + state.update(i) + state.encodeUint(uint64(len(b))) + state.b.Write(b) + } +} + +// encString encodes the string referenced by v. +// Strings are encoded as an unsigned count followed by the raw bytes. +func encString(i *encInstr, state *encoderState, v reflect.Value) { + s := v.String() + if len(s) > 0 || state.sendZero { + state.update(i) + state.encodeUint(uint64(len(s))) + state.b.WriteString(s) + } +} + +// encStructTerminator encodes the end of an encoded struct +// as delta field number of 0. +func encStructTerminator(i *encInstr, state *encoderState, v reflect.Value) { + state.encodeUint(0) +} + +// Execution engine + +// encEngine an array of instructions indexed by field number of the encoding +// data, typically a struct. It is executed top to bottom, walking the struct. +type encEngine struct { + instr []encInstr +} + +const singletonField = 0 + +// valid reports whether the value is valid and a non-nil pointer. +// (Slices, maps, and chans take care of themselves.) +func valid(v reflect.Value) bool { + switch v.Kind() { + case reflect.Invalid: + return false + case reflect.Ptr: + return !v.IsNil() + } + return true +} + +// encodeSingle encodes a single top-level non-struct value. +func (enc *Encoder) encodeSingle(b *encBuffer, engine *encEngine, value reflect.Value) { + state := enc.newEncoderState(b) + defer enc.freeEncoderState(state) + state.fieldnum = singletonField + // There is no surrounding struct to frame the transmission, so we must + // generate data even if the item is zero. To do this, set sendZero. + state.sendZero = true + instr := &engine.instr[singletonField] + if instr.indir > 0 { + value = encIndirect(value, instr.indir) + } + if valid(value) { + instr.op(instr, state, value) + } +} + +// encodeStruct encodes a single struct value. +func (enc *Encoder) encodeStruct(b *encBuffer, engine *encEngine, value reflect.Value) { + if !valid(value) { + return + } + state := enc.newEncoderState(b) + defer enc.freeEncoderState(state) + state.fieldnum = -1 + for i := 0; i < len(engine.instr); i++ { + instr := &engine.instr[i] + if i >= value.NumField() { + // encStructTerminator + instr.op(instr, state, reflect.Value{}) + break + } + field := value.FieldByIndex(instr.index) + if instr.indir > 0 { + field = encIndirect(field, instr.indir) + // TODO: Is field guaranteed valid? If so we could avoid this check. + if !valid(field) { + continue + } + } + instr.op(instr, state, field) + } +} + +// encodeArray encodes an array. +func (enc *Encoder) encodeArray(b *encBuffer, value reflect.Value, op encOp, elemIndir int, length int, helper encHelper) { + state := enc.newEncoderState(b) + defer enc.freeEncoderState(state) + state.fieldnum = -1 + state.sendZero = true + state.encodeUint(uint64(length)) + if helper != nil && helper(state, value) { + return + } + for i := 0; i < length; i++ { + elem := value.Index(i) + if elemIndir > 0 { + elem = encIndirect(elem, elemIndir) + // TODO: Is elem guaranteed valid? If so we could avoid this check. + if !valid(elem) { + errorf("encodeArray: nil element") + } + } + op(nil, state, elem) + } +} + +// encodeReflectValue is a helper for maps. It encodes the value v. +func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) { + for i := 0; i < indir && v.IsValid(); i++ { + v = reflect.Indirect(v) + } + if !v.IsValid() { + errorf("encodeReflectValue: nil element") + } + op(nil, state, v) +} + +// encodeMap encodes a map as unsigned count followed by key:value pairs. +func (enc *Encoder) encodeMap(b *encBuffer, mv reflect.Value, keyOp, elemOp encOp, keyIndir, elemIndir int) { + state := enc.newEncoderState(b) + state.fieldnum = -1 + state.sendZero = true + keys := mv.MapKeys() + state.encodeUint(uint64(len(keys))) + for _, key := range keys { + encodeReflectValue(state, key, keyOp, keyIndir) + encodeReflectValue(state, mv.MapIndex(key), elemOp, elemIndir) + } + enc.freeEncoderState(state) +} + +// encodeInterface encodes the interface value iv. +// To send an interface, we send a string identifying the concrete type, followed +// by the type identifier (which might require defining that type right now), followed +// by the concrete value. A nil value gets sent as the empty string for the name, +// followed by no value. +func (enc *Encoder) encodeInterface(b *encBuffer, iv reflect.Value) { + // Gobs can encode nil interface values but not typed interface + // values holding nil pointers, since nil pointers point to no value. + elem := iv.Elem() + if elem.Kind() == reflect.Ptr && elem.IsNil() { + errorf("gob: cannot encode nil pointer of type %s inside interface", iv.Elem().Type()) + } + state := enc.newEncoderState(b) + state.fieldnum = -1 + state.sendZero = true + if iv.IsNil() { + state.encodeUint(0) + return + } + + ut := userType(iv.Elem().Type()) + registerLock.RLock() + name, ok := concreteTypeToName[ut.base] + registerLock.RUnlock() + if !ok { + errorf("type not registered for interface: %s", ut.base) + } + // Send the name. + state.encodeUint(uint64(len(name))) + state.b.WriteString(name) + // Define the type id if necessary. + enc.sendTypeDescriptor(enc.writer(), state, ut) + // Send the type id. + enc.sendTypeId(state, ut) + // Encode the value into a new buffer. Any nested type definitions + // should be written to b, before the encoded value. + enc.pushWriter(b) + data := new(encBuffer) + data.Write(spaceForLength) + enc.encode(data, elem, ut) + if enc.err != nil { + error_(enc.err) + } + enc.popWriter() + enc.writeMessage(b, data) + if enc.err != nil { + error_(enc.err) + } + enc.freeEncoderState(state) +} + +// isZero reports whether the value is the zero of its type. +func isZero(val reflect.Value) bool { + switch val.Kind() { + case reflect.Array: + for i := 0; i < val.Len(); i++ { + if !isZero(val.Index(i)) { + return false + } + } + return true + case reflect.Map, reflect.Slice, reflect.String: + return val.Len() == 0 + case reflect.Bool: + return !val.Bool() + case reflect.Complex64, reflect.Complex128: + return val.Complex() == 0 + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Ptr: + return val.IsNil() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return val.Int() == 0 + case reflect.Float32, reflect.Float64: + return val.Float() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return val.Uint() == 0 + case reflect.Struct: + for i := 0; i < val.NumField(); i++ { + if !isZero(val.Field(i)) { + return false + } + } + return true + } + panic("unknown type in isZero " + val.Type().String()) +} + +// encGobEncoder encodes a value that implements the GobEncoder interface. +// The data is sent as a byte array. +func (enc *Encoder) encodeGobEncoder(b *encBuffer, ut *userTypeInfo, v reflect.Value) { + // TODO: should we catch panics from the called method? + + var data []byte + var err error + // We know it's one of these. + switch ut.externalEnc { + case xGob: + data, err = v.Interface().(GobEncoder).GobEncode() + case xBinary: + data, err = v.Interface().(encoding.BinaryMarshaler).MarshalBinary() + case xText: + data, err = v.Interface().(encoding.TextMarshaler).MarshalText() + } + if err != nil { + error_(err) + } + state := enc.newEncoderState(b) + state.fieldnum = -1 + state.encodeUint(uint64(len(data))) + state.b.Write(data) + enc.freeEncoderState(state) +} + +var encOpTable = [...]encOp{ + reflect.Bool: encBool, + reflect.Int: encInt, + reflect.Int8: encInt, + reflect.Int16: encInt, + reflect.Int32: encInt, + reflect.Int64: encInt, + reflect.Uint: encUint, + reflect.Uint8: encUint, + reflect.Uint16: encUint, + reflect.Uint32: encUint, + reflect.Uint64: encUint, + reflect.Uintptr: encUint, + reflect.Float32: encFloat, + reflect.Float64: encFloat, + reflect.Complex64: encComplex, + reflect.Complex128: encComplex, + reflect.String: encString, +} + +// encOpFor returns (a pointer to) the encoding op for the base type under rt and +// the indirection count to reach it. +func encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp, building map[*typeInfo]bool) (*encOp, int) { + ut := userType(rt) + // If the type implements GobEncoder, we handle it without further processing. + if ut.externalEnc != 0 { + return gobEncodeOpFor(ut) + } + // If this type is already in progress, it's a recursive type (e.g. map[string]*T). + // Return the pointer to the op we're already building. + if opPtr := inProgress[rt]; opPtr != nil { + return opPtr, ut.indir + } + typ := ut.base + indir := ut.indir + k := typ.Kind() + var op encOp + if int(k) < len(encOpTable) { + op = encOpTable[k] + } + if op == nil { + inProgress[rt] = &op + // Special cases + switch t := typ; t.Kind() { + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { + op = encUint8Array + break + } + // Slices have a header; we decode it to find the underlying array. + elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building) + helper := encSliceHelper[t.Elem().Kind()] + op = func(i *encInstr, state *encoderState, slice reflect.Value) { + if !state.sendZero && slice.Len() == 0 { + return + } + state.update(i) + state.enc.encodeArray(state.b, slice, *elemOp, elemIndir, slice.Len(), helper) + } + case reflect.Array: + // True arrays have size in the type. + elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building) + helper := encArrayHelper[t.Elem().Kind()] + op = func(i *encInstr, state *encoderState, array reflect.Value) { + state.update(i) + state.enc.encodeArray(state.b, array, *elemOp, elemIndir, array.Len(), helper) + } + case reflect.Map: + keyOp, keyIndir := encOpFor(t.Key(), inProgress, building) + elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building) + op = func(i *encInstr, state *encoderState, mv reflect.Value) { + // We send zero-length (but non-nil) maps because the + // receiver might want to use the map. (Maps don't use append.) + if !state.sendZero && mv.IsNil() { + return + } + state.update(i) + state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir) + } + case reflect.Struct: + // Generate a closure that calls out to the engine for the nested type. + getEncEngine(userType(typ), building) + info := mustGetTypeInfo(typ) + op = func(i *encInstr, state *encoderState, sv reflect.Value) { + state.update(i) + // indirect through info to delay evaluation for recursive structs + enc := info.encoder.Load().(*encEngine) + state.enc.encodeStruct(state.b, enc, sv) + } + case reflect.Interface: + op = func(i *encInstr, state *encoderState, iv reflect.Value) { + if !state.sendZero && (!iv.IsValid() || iv.IsNil()) { + return + } + state.update(i) + state.enc.encodeInterface(state.b, iv) + } + } + } + if op == nil { + errorf("can't happen: encode type %s", rt) + } + return &op, indir +} + +// gobEncodeOpFor returns the op for a type that is known to implement GobEncoder. +func gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { + rt := ut.user + if ut.encIndir == -1 { + rt = reflect.PtrTo(rt) + } else if ut.encIndir > 0 { + for i := int8(0); i < ut.encIndir; i++ { + rt = rt.Elem() + } + } + var op encOp + op = func(i *encInstr, state *encoderState, v reflect.Value) { + if ut.encIndir == -1 { + // Need to climb up one level to turn value into pointer. + if !v.CanAddr() { + errorf("unaddressable value of type %s", rt) + } + v = v.Addr() + } + if !state.sendZero && isZero(v) { + return + } + state.update(i) + state.enc.encodeGobEncoder(state.b, ut, v) + } + return &op, int(ut.encIndir) // encIndir: op will get called with p == address of receiver. +} + +// compileEnc returns the engine to compile the type. +func compileEnc(ut *userTypeInfo, building map[*typeInfo]bool) *encEngine { + srt := ut.base + engine := new(encEngine) + seen := make(map[reflect.Type]*encOp) + rt := ut.base + if ut.externalEnc != 0 { + rt = ut.user + } + if ut.externalEnc == 0 && srt.Kind() == reflect.Struct { + for fieldNum, wireFieldNum := 0, 0; fieldNum < srt.NumField(); fieldNum++ { + f := srt.Field(fieldNum) + if !isSent(&f) { + continue + } + op, indir := encOpFor(f.Type, seen, building) + engine.instr = append(engine.instr, encInstr{*op, wireFieldNum, f.Index, indir}) + wireFieldNum++ + } + if srt.NumField() > 0 && len(engine.instr) == 0 { + errorf("type %s has no exported fields", rt) + } + engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, nil, 0}) + } else { + engine.instr = make([]encInstr, 1) + op, indir := encOpFor(rt, seen, building) + engine.instr[0] = encInstr{*op, singletonField, nil, indir} + } + return engine +} + +// getEncEngine returns the engine to compile the type. +func getEncEngine(ut *userTypeInfo, building map[*typeInfo]bool) *encEngine { + info, err := getTypeInfo(ut) + if err != nil { + error_(err) + } + enc, ok := info.encoder.Load().(*encEngine) + if !ok { + enc = buildEncEngine(info, ut, building) + } + return enc +} + +func buildEncEngine(info *typeInfo, ut *userTypeInfo, building map[*typeInfo]bool) *encEngine { + // Check for recursive types. + if building != nil && building[info] { + return nil + } + info.encInit.Lock() + defer info.encInit.Unlock() + enc, ok := info.encoder.Load().(*encEngine) + if !ok { + if building == nil { + building = make(map[*typeInfo]bool) + } + building[info] = true + enc = compileEnc(ut, building) + info.encoder.Store(enc) + } + return enc +} + +func (enc *Encoder) encode(b *encBuffer, value reflect.Value, ut *userTypeInfo) { + defer catchError(&enc.err) + engine := getEncEngine(ut, nil) + indir := ut.indir + if ut.externalEnc != 0 { + indir = int(ut.encIndir) + } + for i := 0; i < indir; i++ { + value = reflect.Indirect(value) + } + if ut.externalEnc == 0 && value.Type().Kind() == reflect.Struct { + enc.encodeStruct(b, engine, value) + } else { + enc.encodeSingle(b, engine, value) + } +} diff --git a/src/encoding/gob/encoder.go b/src/encoding/gob/encoder.go new file mode 100644 index 000000000..a340e47b5 --- /dev/null +++ b/src/encoding/gob/encoder.go @@ -0,0 +1,248 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "io" + "reflect" + "sync" +) + +// An Encoder manages the transmission of type and data information to the +// other side of a connection. +type Encoder struct { + mutex sync.Mutex // each item must be sent atomically + w []io.Writer // where to send the data + sent map[reflect.Type]typeId // which types we've already sent + countState *encoderState // stage for writing counts + freeList *encoderState // list of free encoderStates; avoids reallocation + byteBuf encBuffer // buffer for top-level encoderState + err error +} + +// Before we encode a message, we reserve space at the head of the +// buffer in which to encode its length. This means we can use the +// buffer to assemble the message without another allocation. +const maxLength = 9 // Maximum size of an encoded length. +var spaceForLength = make([]byte, maxLength) + +// NewEncoder returns a new encoder that will transmit on the io.Writer. +func NewEncoder(w io.Writer) *Encoder { + enc := new(Encoder) + enc.w = []io.Writer{w} + enc.sent = make(map[reflect.Type]typeId) + enc.countState = enc.newEncoderState(new(encBuffer)) + return enc +} + +// writer() returns the innermost writer the encoder is using +func (enc *Encoder) writer() io.Writer { + return enc.w[len(enc.w)-1] +} + +// pushWriter adds a writer to the encoder. +func (enc *Encoder) pushWriter(w io.Writer) { + enc.w = append(enc.w, w) +} + +// popWriter pops the innermost writer. +func (enc *Encoder) popWriter() { + enc.w = enc.w[0 : len(enc.w)-1] +} + +func (enc *Encoder) setError(err error) { + if enc.err == nil { // remember the first. + enc.err = err + } +} + +// writeMessage sends the data item preceded by a unsigned count of its length. +func (enc *Encoder) writeMessage(w io.Writer, b *encBuffer) { + // Space has been reserved for the length at the head of the message. + // This is a little dirty: we grab the slice from the bytes.Buffer and massage + // it by hand. + message := b.Bytes() + messageLen := len(message) - maxLength + // Encode the length. + enc.countState.b.Reset() + enc.countState.encodeUint(uint64(messageLen)) + // Copy the length to be a prefix of the message. + offset := maxLength - enc.countState.b.Len() + copy(message[offset:], enc.countState.b.Bytes()) + // Write the data. + _, err := w.Write(message[offset:]) + // Drain the buffer and restore the space at the front for the count of the next message. + b.Reset() + b.Write(spaceForLength) + if err != nil { + enc.setError(err) + } +} + +// sendActualType sends the requested type, without further investigation, unless +// it's been sent before. +func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) { + if _, alreadySent := enc.sent[actual]; alreadySent { + return false + } + info, err := getTypeInfo(ut) + if err != nil { + enc.setError(err) + return + } + // Send the pair (-id, type) + // Id: + state.encodeInt(-int64(info.id)) + // Type: + enc.encode(state.b, reflect.ValueOf(info.wire), wireTypeUserInfo) + enc.writeMessage(w, state.b) + if enc.err != nil { + return + } + + // Remember we've sent this type, both what the user gave us and the base type. + enc.sent[ut.base] = info.id + if ut.user != ut.base { + enc.sent[ut.user] = info.id + } + // Now send the inner types + switch st := actual; st.Kind() { + case reflect.Struct: + for i := 0; i < st.NumField(); i++ { + if isExported(st.Field(i).Name) { + enc.sendType(w, state, st.Field(i).Type) + } + } + case reflect.Array, reflect.Slice: + enc.sendType(w, state, st.Elem()) + case reflect.Map: + enc.sendType(w, state, st.Key()) + enc.sendType(w, state, st.Elem()) + } + return true +} + +// sendType sends the type info to the other side, if necessary. +func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { + ut := userType(origt) + if ut.externalEnc != 0 { + // The rules are different: regardless of the underlying type's representation, + // we need to tell the other side that the base type is a GobEncoder. + return enc.sendActualType(w, state, ut, ut.base) + } + + // It's a concrete value, so drill down to the base type. + switch rt := ut.base; rt.Kind() { + default: + // Basic types and interfaces do not need to be described. + return + case reflect.Slice: + // If it's []uint8, don't send; it's considered basic. + if rt.Elem().Kind() == reflect.Uint8 { + return + } + // Otherwise we do send. + break + case reflect.Array: + // arrays must be sent so we know their lengths and element types. + break + case reflect.Map: + // maps must be sent so we know their lengths and key/value types. + break + case reflect.Struct: + // structs must be sent so we know their fields. + break + case reflect.Chan, reflect.Func: + // If we get here, it's a field of a struct; ignore it. + return + } + + return enc.sendActualType(w, state, ut, ut.base) +} + +// Encode transmits the data item represented by the empty interface value, +// guaranteeing that all necessary type information has been transmitted first. +func (enc *Encoder) Encode(e interface{}) error { + return enc.EncodeValue(reflect.ValueOf(e)) +} + +// sendTypeDescriptor makes sure the remote side knows about this type. +// It will send a descriptor if this is the first time the type has been +// sent. +func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) { + // Make sure the type is known to the other side. + // First, have we already sent this type? + rt := ut.base + if ut.externalEnc != 0 { + rt = ut.user + } + if _, alreadySent := enc.sent[rt]; !alreadySent { + // No, so send it. + sent := enc.sendType(w, state, rt) + if enc.err != nil { + return + } + // If the type info has still not been transmitted, it means we have + // a singleton basic type (int, []byte etc.) at top level. We don't + // need to send the type info but we do need to update enc.sent. + if !sent { + info, err := getTypeInfo(ut) + if err != nil { + enc.setError(err) + return + } + enc.sent[rt] = info.id + } + } +} + +// sendTypeId sends the id, which must have already been defined. +func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) { + // Identify the type of this top-level value. + state.encodeInt(int64(enc.sent[ut.base])) +} + +// EncodeValue transmits the data item represented by the reflection value, +// guaranteeing that all necessary type information has been transmitted first. +func (enc *Encoder) EncodeValue(value reflect.Value) error { + // Gobs contain values. They cannot represent nil pointers, which + // have no value to encode. + if value.Kind() == reflect.Ptr && value.IsNil() { + panic("gob: cannot encode nil pointer of type " + value.Type().String()) + } + + // Make sure we're single-threaded through here, so multiple + // goroutines can share an encoder. + enc.mutex.Lock() + defer enc.mutex.Unlock() + + // Remove any nested writers remaining due to previous errors. + enc.w = enc.w[0:1] + + ut, err := validUserType(value.Type()) + if err != nil { + return err + } + + enc.err = nil + enc.byteBuf.Reset() + enc.byteBuf.Write(spaceForLength) + state := enc.newEncoderState(&enc.byteBuf) + + enc.sendTypeDescriptor(enc.writer(), state, ut) + enc.sendTypeId(state, ut) + if enc.err != nil { + return enc.err + } + + // Encode the object. + enc.encode(state.b, value, ut) + if enc.err == nil { + enc.writeMessage(enc.writer(), state.b) + } + + enc.freeEncoderState(state) + return enc.err +} diff --git a/src/encoding/gob/encoder_test.go b/src/encoding/gob/encoder_test.go new file mode 100644 index 000000000..0ea4c0ec8 --- /dev/null +++ b/src/encoding/gob/encoder_test.go @@ -0,0 +1,956 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "bytes" + "fmt" + "io" + "reflect" + "strings" + "testing" +) + +// Test basic operations in a safe manner. +func TestBasicEncoderDecoder(t *testing.T) { + var values = []interface{}{ + true, + int(123), + int8(123), + int16(-12345), + int32(123456), + int64(-1234567), + uint(123), + uint8(123), + uint16(12345), + uint32(123456), + uint64(1234567), + uintptr(12345678), + float32(1.2345), + float64(1.2345678), + complex64(1.2345 + 2.3456i), + complex128(1.2345678 + 2.3456789i), + []byte("hello"), + string("hello"), + } + for _, value := range values { + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(value) + if err != nil { + t.Error("encoder fail:", err) + } + dec := NewDecoder(b) + result := reflect.New(reflect.TypeOf(value)) + err = dec.Decode(result.Interface()) + if err != nil { + t.Fatalf("error decoding %T: %v:", reflect.TypeOf(value), err) + } + if !reflect.DeepEqual(value, result.Elem().Interface()) { + t.Fatalf("%T: expected %v got %v", value, value, result.Elem().Interface()) + } + } +} + +type ET0 struct { + A int + B string +} + +type ET2 struct { + X string +} + +type ET1 struct { + A int + Et2 *ET2 + Next *ET1 +} + +// Like ET1 but with a different name for a field +type ET3 struct { + A int + Et2 *ET2 + DifferentNext *ET1 +} + +// Like ET1 but with a different type for a field +type ET4 struct { + A int + Et2 float64 + Next int +} + +func TestEncoderDecoder(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + et0 := new(ET0) + et0.A = 7 + et0.B = "gobs of fun" + err := enc.Encode(et0) + if err != nil { + t.Error("encoder fail:", err) + } + //fmt.Printf("% x %q\n", b, b) + //Debug(b) + dec := NewDecoder(b) + newEt0 := new(ET0) + err = dec.Decode(newEt0) + if err != nil { + t.Fatal("error decoding ET0:", err) + } + + if !reflect.DeepEqual(et0, newEt0) { + t.Fatalf("invalid data for et0: expected %+v; got %+v", *et0, *newEt0) + } + if b.Len() != 0 { + t.Error("not at eof;", b.Len(), "bytes left") + } + // t.FailNow() + + b = new(bytes.Buffer) + enc = NewEncoder(b) + et1 := new(ET1) + et1.A = 7 + et1.Et2 = new(ET2) + err = enc.Encode(et1) + if err != nil { + t.Error("encoder fail:", err) + } + dec = NewDecoder(b) + newEt1 := new(ET1) + err = dec.Decode(newEt1) + if err != nil { + t.Fatal("error decoding ET1:", err) + } + + if !reflect.DeepEqual(et1, newEt1) { + t.Fatalf("invalid data for et1: expected %+v; got %+v", *et1, *newEt1) + } + if b.Len() != 0 { + t.Error("not at eof;", b.Len(), "bytes left") + } + + enc.Encode(et1) + newEt1 = new(ET1) + err = dec.Decode(newEt1) + if err != nil { + t.Fatal("round 2: error decoding ET1:", err) + } + if !reflect.DeepEqual(et1, newEt1) { + t.Fatalf("round 2: invalid data for et1: expected %+v; got %+v", *et1, *newEt1) + } + if b.Len() != 0 { + t.Error("round 2: not at eof;", b.Len(), "bytes left") + } + + // Now test with a running encoder/decoder pair that we recognize a type mismatch. + err = enc.Encode(et1) + if err != nil { + t.Error("round 3: encoder fail:", err) + } + newEt2 := new(ET2) + err = dec.Decode(newEt2) + if err == nil { + t.Fatal("round 3: expected `bad type' error decoding ET2") + } +} + +// Run one value through the encoder/decoder, but use the wrong type. +// Input is always an ET1; we compare it to whatever is under 'e'. +func badTypeCheck(e interface{}, shouldFail bool, msg string, t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + et1 := new(ET1) + et1.A = 7 + et1.Et2 = new(ET2) + err := enc.Encode(et1) + if err != nil { + t.Error("encoder fail:", err) + } + dec := NewDecoder(b) + err = dec.Decode(e) + if shouldFail && err == nil { + t.Error("expected error for", msg) + } + if !shouldFail && err != nil { + t.Error("unexpected error for", msg, err) + } +} + +// Test that we recognize a bad type the first time. +func TestWrongTypeDecoder(t *testing.T) { + badTypeCheck(new(ET2), true, "no fields in common", t) + badTypeCheck(new(ET3), false, "different name of field", t) + badTypeCheck(new(ET4), true, "different type of field", t) +} + +func corruptDataCheck(s string, err error, t *testing.T) { + b := bytes.NewBufferString(s) + dec := NewDecoder(b) + err1 := dec.Decode(new(ET2)) + if err1 != err { + t.Errorf("from %q expected error %s; got %s", s, err, err1) + } +} + +// Check that we survive bad data. +func TestBadData(t *testing.T) { + corruptDataCheck("", io.EOF, t) + corruptDataCheck("\x7Fhi", io.ErrUnexpectedEOF, t) + corruptDataCheck("\x03now is the time for all good men", errBadType, t) + // issue 6323. + corruptDataCheck("\x04\x24foo", errRange, t) +} + +// Types not supported at top level by the Encoder. +var unsupportedValues = []interface{}{ + make(chan int), + func(a int) bool { return true }, +} + +func TestUnsupported(t *testing.T) { + var b bytes.Buffer + enc := NewEncoder(&b) + for _, v := range unsupportedValues { + err := enc.Encode(v) + if err == nil { + t.Errorf("expected error for %T; got none", v) + } + } +} + +func encAndDec(in, out interface{}) error { + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(in) + if err != nil { + return err + } + dec := NewDecoder(b) + err = dec.Decode(out) + if err != nil { + return err + } + return nil +} + +func TestTypeToPtrType(t *testing.T) { + // Encode a T, decode a *T + type Type0 struct { + A int + } + t0 := Type0{7} + t0p := new(Type0) + if err := encAndDec(t0, t0p); err != nil { + t.Error(err) + } +} + +func TestPtrTypeToType(t *testing.T) { + // Encode a *T, decode a T + type Type1 struct { + A uint + } + t1p := &Type1{17} + var t1 Type1 + if err := encAndDec(t1, t1p); err != nil { + t.Error(err) + } +} + +func TestTypeToPtrPtrPtrPtrType(t *testing.T) { + type Type2 struct { + A ****float64 + } + t2 := Type2{} + t2.A = new(***float64) + *t2.A = new(**float64) + **t2.A = new(*float64) + ***t2.A = new(float64) + ****t2.A = 27.4 + t2pppp := new(***Type2) + if err := encAndDec(t2, t2pppp); err != nil { + t.Fatal(err) + } + if ****(****t2pppp).A != ****t2.A { + t.Errorf("wrong value after decode: %g not %g", ****(****t2pppp).A, ****t2.A) + } +} + +func TestSlice(t *testing.T) { + type Type3 struct { + A []string + } + t3p := &Type3{[]string{"hello", "world"}} + var t3 Type3 + if err := encAndDec(t3, t3p); err != nil { + t.Error(err) + } +} + +func TestValueError(t *testing.T) { + // Encode a *T, decode a T + type Type4 struct { + A int + } + t4p := &Type4{3} + var t4 Type4 // note: not a pointer. + if err := encAndDec(t4p, t4); err == nil || strings.Index(err.Error(), "pointer") < 0 { + t.Error("expected error about pointer; got", err) + } +} + +func TestArray(t *testing.T) { + type Type5 struct { + A [3]string + B [3]byte + } + type Type6 struct { + A [2]string // can't hold t5.a + } + t5 := Type5{[3]string{"hello", ",", "world"}, [3]byte{1, 2, 3}} + var t5p Type5 + if err := encAndDec(t5, &t5p); err != nil { + t.Error(err) + } + var t6 Type6 + if err := encAndDec(t5, &t6); err == nil { + t.Error("should fail with mismatched array sizes") + } +} + +func TestRecursiveMapType(t *testing.T) { + type recursiveMap map[string]recursiveMap + r1 := recursiveMap{"A": recursiveMap{"B": nil, "C": nil}, "D": nil} + r2 := make(recursiveMap) + if err := encAndDec(r1, &r2); err != nil { + t.Error(err) + } +} + +func TestRecursiveSliceType(t *testing.T) { + type recursiveSlice []recursiveSlice + r1 := recursiveSlice{0: recursiveSlice{0: nil}, 1: nil} + r2 := make(recursiveSlice, 0) + if err := encAndDec(r1, &r2); err != nil { + t.Error(err) + } +} + +// Regression test for bug: must send zero values inside arrays +func TestDefaultsInArray(t *testing.T) { + type Type7 struct { + B []bool + I []int + S []string + F []float64 + } + t7 := Type7{ + []bool{false, false, true}, + []int{0, 0, 1}, + []string{"hi", "", "there"}, + []float64{0, 0, 1}, + } + var t7p Type7 + if err := encAndDec(t7, &t7p); err != nil { + t.Error(err) + } +} + +var testInt int +var testFloat32 float32 +var testString string +var testSlice []string +var testMap map[string]int +var testArray [7]int + +type SingleTest struct { + in interface{} + out interface{} + err string +} + +var singleTests = []SingleTest{ + {17, &testInt, ""}, + {float32(17.5), &testFloat32, ""}, + {"bike shed", &testString, ""}, + {[]string{"bike", "shed", "paint", "color"}, &testSlice, ""}, + {map[string]int{"seven": 7, "twelve": 12}, &testMap, ""}, + {[7]int{4, 55, 0, 0, 0, 0, 0}, &testArray, ""}, // case that once triggered a bug + {[7]int{4, 55, 1, 44, 22, 66, 1234}, &testArray, ""}, + + // Decode errors + {172, &testFloat32, "type"}, +} + +func TestSingletons(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + dec := NewDecoder(b) + for _, test := range singleTests { + b.Reset() + err := enc.Encode(test.in) + if err != nil { + t.Errorf("error encoding %v: %s", test.in, err) + continue + } + err = dec.Decode(test.out) + switch { + case err != nil && test.err == "": + t.Errorf("error decoding %v: %s", test.in, err) + continue + case err == nil && test.err != "": + t.Errorf("expected error decoding %v: %s", test.in, test.err) + continue + case err != nil && test.err != "": + if strings.Index(err.Error(), test.err) < 0 { + t.Errorf("wrong error decoding %v: wanted %s, got %v", test.in, test.err, err) + } + continue + } + // Get rid of the pointer in the rhs + val := reflect.ValueOf(test.out).Elem().Interface() + if !reflect.DeepEqual(test.in, val) { + t.Errorf("decoding singleton: expected %v got %v", test.in, val) + } + } +} + +func TestStructNonStruct(t *testing.T) { + type Struct struct { + A string + } + type NonStruct string + s := Struct{"hello"} + var sp Struct + if err := encAndDec(s, &sp); err != nil { + t.Error(err) + } + var ns NonStruct + if err := encAndDec(s, &ns); err == nil { + t.Error("should get error for struct/non-struct") + } else if strings.Index(err.Error(), "type") < 0 { + t.Error("for struct/non-struct expected type error; got", err) + } + // Now try the other way + var nsp NonStruct + if err := encAndDec(ns, &nsp); err != nil { + t.Error(err) + } + if err := encAndDec(ns, &s); err == nil { + t.Error("should get error for non-struct/struct") + } else if strings.Index(err.Error(), "type") < 0 { + t.Error("for non-struct/struct expected type error; got", err) + } +} + +type interfaceIndirectTestI interface { + F() bool +} + +type interfaceIndirectTestT struct{} + +func (this *interfaceIndirectTestT) F() bool { + return true +} + +// A version of a bug reported on golang-nuts. Also tests top-level +// slice of interfaces. The issue was registering *T caused T to be +// stored as the concrete type. +func TestInterfaceIndirect(t *testing.T) { + Register(&interfaceIndirectTestT{}) + b := new(bytes.Buffer) + w := []interfaceIndirectTestI{&interfaceIndirectTestT{}} + err := NewEncoder(b).Encode(w) + if err != nil { + t.Fatal("encode error:", err) + } + + var r []interfaceIndirectTestI + err = NewDecoder(b).Decode(&r) + if err != nil { + t.Fatal("decode error:", err) + } +} + +// Now follow various tests that decode into things that can't represent the +// encoded value, all of which should be legal. + +// Also, when the ignored object contains an interface value, it may define +// types. Make sure that skipping the value still defines the types by using +// the encoder/decoder pair to send a value afterwards. If an interface +// is sent, its type in the test is always NewType0, so this checks that the +// encoder and decoder don't skew with respect to type definitions. + +type Struct0 struct { + I interface{} +} + +type NewType0 struct { + S string +} + +type ignoreTest struct { + in, out interface{} +} + +var ignoreTests = []ignoreTest{ + // Decode normal struct into an empty struct + {&struct{ A int }{23}, &struct{}{}}, + // Decode normal struct into a nil. + {&struct{ A int }{23}, nil}, + // Decode singleton string into a nil. + {"hello, world", nil}, + // Decode singleton slice into a nil. + {[]int{1, 2, 3, 4}, nil}, + // Decode struct containing an interface into a nil. + {&Struct0{&NewType0{"value0"}}, nil}, + // Decode singleton slice of interfaces into a nil. + {[]interface{}{"hi", &NewType0{"value1"}, 23}, nil}, +} + +func TestDecodeIntoNothing(t *testing.T) { + Register(new(NewType0)) + for i, test := range ignoreTests { + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(test.in) + if err != nil { + t.Errorf("%d: encode error %s:", i, err) + continue + } + dec := NewDecoder(b) + err = dec.Decode(test.out) + if err != nil { + t.Errorf("%d: decode error: %s", i, err) + continue + } + // Now see if the encoder and decoder are in a consistent state. + str := fmt.Sprintf("Value %d", i) + err = enc.Encode(&NewType0{str}) + if err != nil { + t.Fatalf("%d: NewType0 encode error: %s", i, err) + } + ns := new(NewType0) + err = dec.Decode(ns) + if err != nil { + t.Fatalf("%d: NewType0 decode error: %s", i, err) + } + if ns.S != str { + t.Fatalf("%d: expected %q got %q", i, str, ns.S) + } + } +} + +// Another bug from golang-nuts, involving nested interfaces. +type Bug0Outer struct { + Bug0Field interface{} +} + +type Bug0Inner struct { + A int +} + +func TestNestedInterfaces(t *testing.T) { + var buf bytes.Buffer + e := NewEncoder(&buf) + d := NewDecoder(&buf) + Register(new(Bug0Outer)) + Register(new(Bug0Inner)) + f := &Bug0Outer{&Bug0Outer{&Bug0Inner{7}}} + var v interface{} = f + err := e.Encode(&v) + if err != nil { + t.Fatal("Encode:", err) + } + err = d.Decode(&v) + if err != nil { + t.Fatal("Decode:", err) + } + // Make sure it decoded correctly. + outer1, ok := v.(*Bug0Outer) + if !ok { + t.Fatalf("v not Bug0Outer: %T", v) + } + outer2, ok := outer1.Bug0Field.(*Bug0Outer) + if !ok { + t.Fatalf("v.Bug0Field not Bug0Outer: %T", outer1.Bug0Field) + } + inner, ok := outer2.Bug0Field.(*Bug0Inner) + if !ok { + t.Fatalf("v.Bug0Field.Bug0Field not Bug0Inner: %T", outer2.Bug0Field) + } + if inner.A != 7 { + t.Fatalf("final value %d; expected %d", inner.A, 7) + } +} + +// The bugs keep coming. We forgot to send map subtypes before the map. + +type Bug1Elem struct { + Name string + Id int +} + +type Bug1StructMap map[string]Bug1Elem + +func bug1EncDec(in Bug1StructMap, out *Bug1StructMap) error { + return nil +} + +func TestMapBug1(t *testing.T) { + in := make(Bug1StructMap) + in["val1"] = Bug1Elem{"elem1", 1} + in["val2"] = Bug1Elem{"elem2", 2} + + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(in) + if err != nil { + t.Fatal("encode:", err) + } + dec := NewDecoder(b) + out := make(Bug1StructMap) + err = dec.Decode(&out) + if err != nil { + t.Fatal("decode:", err) + } + if !reflect.DeepEqual(in, out) { + t.Errorf("mismatch: %v %v", in, out) + } +} + +func TestGobMapInterfaceEncode(t *testing.T) { + m := map[string]interface{}{ + "up": uintptr(0), + "i0": []int{-1}, + "i1": []int8{-1}, + "i2": []int16{-1}, + "i3": []int32{-1}, + "i4": []int64{-1}, + "u0": []uint{1}, + "u1": []uint8{1}, + "u2": []uint16{1}, + "u3": []uint32{1}, + "u4": []uint64{1}, + "f0": []float32{1}, + "f1": []float64{1}, + "c0": []complex64{complex(2, -2)}, + "c1": []complex128{complex(2, float64(-2))}, + "us": []uintptr{0}, + "bo": []bool{false}, + "st": []string{"s"}, + } + enc := NewEncoder(new(bytes.Buffer)) + err := enc.Encode(m) + if err != nil { + t.Errorf("encode map: %s", err) + } +} + +func TestSliceReusesMemory(t *testing.T) { + buf := new(bytes.Buffer) + // Bytes + { + x := []byte("abcd") + enc := NewEncoder(buf) + err := enc.Encode(x) + if err != nil { + t.Errorf("bytes: encode: %s", err) + } + // Decode into y, which is big enough. + y := []byte("ABCDE") + addr := &y[0] + dec := NewDecoder(buf) + err = dec.Decode(&y) + if err != nil { + t.Fatal("bytes: decode:", err) + } + if !bytes.Equal(x, y) { + t.Errorf("bytes: expected %q got %q\n", x, y) + } + if addr != &y[0] { + t.Errorf("bytes: unnecessary reallocation") + } + } + // general slice + { + x := []rune("abcd") + enc := NewEncoder(buf) + err := enc.Encode(x) + if err != nil { + t.Errorf("ints: encode: %s", err) + } + // Decode into y, which is big enough. + y := []rune("ABCDE") + addr := &y[0] + dec := NewDecoder(buf) + err = dec.Decode(&y) + if err != nil { + t.Fatal("ints: decode:", err) + } + if !reflect.DeepEqual(x, y) { + t.Errorf("ints: expected %q got %q\n", x, y) + } + if addr != &y[0] { + t.Errorf("ints: unnecessary reallocation") + } + } +} + +// Used to crash: negative count in recvMessage. +func TestBadCount(t *testing.T) { + b := []byte{0xfb, 0xa5, 0x82, 0x2f, 0xca, 0x1} + if err := NewDecoder(bytes.NewReader(b)).Decode(nil); err == nil { + t.Error("expected error from bad count") + } else if err.Error() != errBadCount.Error() { + t.Error("expected bad count error; got", err) + } +} + +// Verify that sequential Decoders built on a single input will +// succeed if the input implements ReadByte and there is no +// type information in the stream. +func TestSequentialDecoder(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + const count = 10 + for i := 0; i < count; i++ { + s := fmt.Sprintf("%d", i) + if err := enc.Encode(s); err != nil { + t.Error("encoder fail:", err) + } + } + for i := 0; i < count; i++ { + dec := NewDecoder(b) + var s string + if err := dec.Decode(&s); err != nil { + t.Fatal("decoder fail:", err) + } + if s != fmt.Sprintf("%d", i) { + t.Fatalf("decode expected %d got %s", i, s) + } + } +} + +// Should be able to have unrepresentable fields (chan, func, *chan etc.); we just ignore them. +type Bug2 struct { + A int + C chan int + CP *chan int + F func() + FPP **func() +} + +func TestChanFuncIgnored(t *testing.T) { + c := make(chan int) + f := func() {} + fp := &f + b0 := Bug2{23, c, &c, f, &fp} + var buf bytes.Buffer + enc := NewEncoder(&buf) + if err := enc.Encode(b0); err != nil { + t.Fatal("error encoding:", err) + } + var b1 Bug2 + err := NewDecoder(&buf).Decode(&b1) + if err != nil { + t.Fatal("decode:", err) + } + if b1.A != b0.A { + t.Fatalf("got %d want %d", b1.A, b0.A) + } + if b1.C != nil || b1.CP != nil || b1.F != nil || b1.FPP != nil { + t.Fatal("unexpected value for chan or func") + } +} + +func TestSliceIncompatibility(t *testing.T) { + var in = []byte{1, 2, 3} + var out []int + if err := encAndDec(in, &out); err == nil { + t.Error("expected compatibility error") + } +} + +// Mutually recursive slices of structs caused problems. +type Bug3 struct { + Num int + Children []*Bug3 +} + +func TestGobPtrSlices(t *testing.T) { + in := []*Bug3{ + {1, nil}, + {2, nil}, + } + b := new(bytes.Buffer) + err := NewEncoder(b).Encode(&in) + if err != nil { + t.Fatal("encode:", err) + } + + var out []*Bug3 + err = NewDecoder(b).Decode(&out) + if err != nil { + t.Fatal("decode:", err) + } + if !reflect.DeepEqual(in, out) { + t.Fatalf("got %v; wanted %v", out, in) + } +} + +// getDecEnginePtr cached engine for ut.base instead of ut.user so we passed +// a *map and then tried to reuse its engine to decode the inner map. +func TestPtrToMapOfMap(t *testing.T) { + Register(make(map[string]interface{})) + subdata := make(map[string]interface{}) + subdata["bar"] = "baz" + data := make(map[string]interface{}) + data["foo"] = subdata + + b := new(bytes.Buffer) + err := NewEncoder(b).Encode(data) + if err != nil { + t.Fatal("encode:", err) + } + var newData map[string]interface{} + err = NewDecoder(b).Decode(&newData) + if err != nil { + t.Fatal("decode:", err) + } + if !reflect.DeepEqual(data, newData) { + t.Fatalf("expected %v got %v", data, newData) + } +} + +// A top-level nil pointer generates a panic with a helpful string-valued message. +func TestTopLevelNilPointer(t *testing.T) { + errMsg := topLevelNilPanic(t) + if errMsg == "" { + t.Fatal("top-level nil pointer did not panic") + } + if !strings.Contains(errMsg, "nil pointer") { + t.Fatal("expected nil pointer error, got:", errMsg) + } +} + +func topLevelNilPanic(t *testing.T) (panicErr string) { + defer func() { + e := recover() + if err, ok := e.(string); ok { + panicErr = err + } + }() + var ip *int + buf := new(bytes.Buffer) + if err := NewEncoder(buf).Encode(ip); err != nil { + t.Fatal("error in encode:", err) + } + return +} + +func TestNilPointerInsideInterface(t *testing.T) { + var ip *int + si := struct { + I interface{} + }{ + I: ip, + } + buf := new(bytes.Buffer) + err := NewEncoder(buf).Encode(si) + if err == nil { + t.Fatal("expected error, got none") + } + errMsg := err.Error() + if !strings.Contains(errMsg, "nil pointer") || !strings.Contains(errMsg, "interface") { + t.Fatal("expected error about nil pointer and interface, got:", errMsg) + } +} + +type Bug4Public struct { + Name string + Secret Bug4Secret +} + +type Bug4Secret struct { + a int // error: no exported fields. +} + +// Test that a failed compilation doesn't leave around an executable encoder. +// Issue 3273. +func TestMutipleEncodingsOfBadType(t *testing.T) { + x := Bug4Public{ + Name: "name", + Secret: Bug4Secret{1}, + } + buf := new(bytes.Buffer) + enc := NewEncoder(buf) + err := enc.Encode(x) + if err == nil { + t.Fatal("first encoding: expected error") + } + buf.Reset() + enc = NewEncoder(buf) + err = enc.Encode(x) + if err == nil { + t.Fatal("second encoding: expected error") + } + if !strings.Contains(err.Error(), "no exported fields") { + t.Errorf("expected error about no exported fields; got %v", err) + } +} + +// There was an error check comparing the length of the input with the +// length of the slice being decoded. It was wrong because the next +// thing in the input might be a type definition, which would lead to +// an incorrect length check. This test reproduces the corner case. + +type Z struct { +} + +func Test29ElementSlice(t *testing.T) { + Register(Z{}) + src := make([]interface{}, 100) // Size needs to be bigger than size of type definition. + for i := range src { + src[i] = Z{} + } + buf := new(bytes.Buffer) + err := NewEncoder(buf).Encode(src) + if err != nil { + t.Fatalf("encode: %v", err) + return + } + + var dst []interface{} + err = NewDecoder(buf).Decode(&dst) + if err != nil { + t.Errorf("decode: %v", err) + return + } +} + +// Don't crash, just give error when allocating a huge slice. +// Issue 8084. +func TestErrorForHugeSlice(t *testing.T) { + // Encode an int slice. + buf := new(bytes.Buffer) + slice := []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + err := NewEncoder(buf).Encode(slice) + if err != nil { + t.Fatal("encode:", err) + } + // Reach into the buffer and smash the count to make the encoded slice very long. + buf.Bytes()[buf.Len()-len(slice)-1] = 0xfa + // Decode and see error. + err = NewDecoder(buf).Decode(&slice) + if err == nil { + t.Fatal("decode: no error") + } + if !strings.Contains(err.Error(), "slice too big") { + t.Fatal("decode: expected slice too big error, got %s", err.Error()) + } +} diff --git a/src/encoding/gob/error.go b/src/encoding/gob/error.go new file mode 100644 index 000000000..92cc0c615 --- /dev/null +++ b/src/encoding/gob/error.go @@ -0,0 +1,43 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import "fmt" + +// Errors in decoding and encoding are handled using panic and recover. +// Panics caused by user error (that is, everything except run-time panics +// such as "index out of bounds" errors) do not leave the file that caused +// them, but are instead turned into plain error returns. Encoding and +// decoding functions and methods that do not return an error either use +// panic to report an error or are guaranteed error-free. + +// A gobError is used to distinguish errors (panics) generated in this package. +type gobError struct { + err error +} + +// errorf is like error_ but takes Printf-style arguments to construct an error. +// It always prefixes the message with "gob: ". +func errorf(format string, args ...interface{}) { + error_(fmt.Errorf("gob: "+format, args...)) +} + +// error wraps the argument error and uses it as the argument to panic. +func error_(err error) { + panic(gobError{err}) +} + +// catchError is meant to be used as a deferred function to turn a panic(gobError) into a +// plain error. It overwrites the error return of the function that deferred its call. +func catchError(err *error) { + if e := recover(); e != nil { + ge, ok := e.(gobError) + if !ok { + panic(e) + } + *err = ge.err + } + return +} diff --git a/src/encoding/gob/example_encdec_test.go b/src/encoding/gob/example_encdec_test.go new file mode 100644 index 000000000..e45ad4ccf --- /dev/null +++ b/src/encoding/gob/example_encdec_test.go @@ -0,0 +1,61 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob_test + +import ( + "bytes" + "encoding/gob" + "fmt" + "log" +) + +// The Vector type has unexported fields, which the package cannot access. +// We therefore write a BinaryMarshal/BinaryUnmarshal method pair to allow us +// to send and receive the type with the gob package. These interfaces are +// defined in the "encoding" package. +// We could equivalently use the locally defined GobEncode/GobDecoder +// interfaces. +type Vector struct { + x, y, z int +} + +func (v Vector) MarshalBinary() ([]byte, error) { + // A simple encoding: plain text. + var b bytes.Buffer + fmt.Fprintln(&b, v.x, v.y, v.z) + return b.Bytes(), nil +} + +// UnmarshalBinary modifies the receiver so it must take a pointer receiver. +func (v *Vector) UnmarshalBinary(data []byte) error { + // A simple encoding: plain text. + b := bytes.NewBuffer(data) + _, err := fmt.Fscanln(b, &v.x, &v.y, &v.z) + return err +} + +// This example transmits a value that implements the custom encoding and decoding methods. +func Example_encodeDecode() { + var network bytes.Buffer // Stand-in for the network. + + // Create an encoder and send a value. + enc := gob.NewEncoder(&network) + err := enc.Encode(Vector{3, 4, 5}) + if err != nil { + log.Fatal("encode:", err) + } + + // Create a decoder and receive a value. + dec := gob.NewDecoder(&network) + var v Vector + err = dec.Decode(&v) + if err != nil { + log.Fatal("decode:", err) + } + fmt.Println(v) + + // Output: + // {3 4 5} +} diff --git a/src/encoding/gob/example_interface_test.go b/src/encoding/gob/example_interface_test.go new file mode 100644 index 000000000..4681e6307 --- /dev/null +++ b/src/encoding/gob/example_interface_test.go @@ -0,0 +1,81 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob_test + +import ( + "bytes" + "encoding/gob" + "fmt" + "log" + "math" +) + +type Point struct { + X, Y int +} + +func (p Point) Hypotenuse() float64 { + return math.Hypot(float64(p.X), float64(p.Y)) +} + +type Pythagoras interface { + Hypotenuse() float64 +} + +// This example shows how to encode an interface value. The key +// distinction from regular types is to register the concrete type that +// implements the interface. +func Example_interface() { + var network bytes.Buffer // Stand-in for the network. + + // We must register the concrete type for the encoder and decoder (which would + // normally be on a separate machine from the encoder). On each end, this tells the + // engine which concrete type is being sent that implements the interface. + gob.Register(Point{}) + + // Create an encoder and send some values. + enc := gob.NewEncoder(&network) + for i := 1; i <= 3; i++ { + interfaceEncode(enc, Point{3 * i, 4 * i}) + } + + // Create a decoder and receive some values. + dec := gob.NewDecoder(&network) + for i := 1; i <= 3; i++ { + result := interfaceDecode(dec) + fmt.Println(result.Hypotenuse()) + } + + // Output: + // 5 + // 10 + // 15 +} + +// interfaceEncode encodes the interface value into the encoder. +func interfaceEncode(enc *gob.Encoder, p Pythagoras) { + // The encode will fail unless the concrete type has been + // registered. We registered it in the calling function. + + // Pass pointer to interface so Encode sees (and hence sends) a value of + // interface type. If we passed p directly it would see the concrete type instead. + // See the blog post, "The Laws of Reflection" for background. + err := enc.Encode(&p) + if err != nil { + log.Fatal("encode:", err) + } +} + +// interfaceDecode decodes the next interface value from the stream and returns it. +func interfaceDecode(dec *gob.Decoder) Pythagoras { + // The decode will fail unless the concrete type on the wire has been + // registered. We registered it in the calling function. + var p Pythagoras + err := dec.Decode(&p) + if err != nil { + log.Fatal("decode:", err) + } + return p +} diff --git a/src/encoding/gob/example_test.go b/src/encoding/gob/example_test.go new file mode 100644 index 000000000..020352cee --- /dev/null +++ b/src/encoding/gob/example_test.go @@ -0,0 +1,60 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob_test + +import ( + "bytes" + "encoding/gob" + "fmt" + "log" +) + +type P struct { + X, Y, Z int + Name string +} + +type Q struct { + X, Y *int32 + Name string +} + +// This example shows the basic usage of the package: Create an encoder, +// transmit some values, receive them with a decoder. +func Example_basic() { + // Initialize the encoder and decoder. Normally enc and dec would be + // bound to network connections and the encoder and decoder would + // run in different processes. + var network bytes.Buffer // Stand-in for a network connection + enc := gob.NewEncoder(&network) // Will write to network. + dec := gob.NewDecoder(&network) // Will read from network. + + // Encode (send) some values. + err := enc.Encode(P{3, 4, 5, "Pythagoras"}) + if err != nil { + log.Fatal("encode error:", err) + } + err = enc.Encode(P{1782, 1841, 1922, "Treehouse"}) + if err != nil { + log.Fatal("encode error:", err) + } + + // Decode (receive) and print the values. + var q Q + err = dec.Decode(&q) + if err != nil { + log.Fatal("decode error 1:", err) + } + fmt.Printf("%q: {%d, %d}\n", q.Name, *q.X, *q.Y) + err = dec.Decode(&q) + if err != nil { + log.Fatal("decode error 2:", err) + } + fmt.Printf("%q: {%d, %d}\n", q.Name, *q.X, *q.Y) + + // Output: + // "Pythagoras": {3, 4} + // "Treehouse": {1782, 1841} +} diff --git a/src/encoding/gob/gobencdec_test.go b/src/encoding/gob/gobencdec_test.go new file mode 100644 index 000000000..eb76b481d --- /dev/null +++ b/src/encoding/gob/gobencdec_test.go @@ -0,0 +1,798 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file contains tests of the GobEncoder/GobDecoder support. + +package gob + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "strings" + "testing" + "time" +) + +// Types that implement the GobEncoder/Decoder interfaces. + +type ByteStruct struct { + a byte // not an exported field +} + +type StringStruct struct { + s string // not an exported field +} + +type ArrayStruct struct { + a [8192]byte // not an exported field +} + +type Gobber int + +type ValueGobber string // encodes with a value, decodes with a pointer. + +type BinaryGobber int + +type BinaryValueGobber string + +type TextGobber int + +type TextValueGobber string + +// The relevant methods + +func (g *ByteStruct) GobEncode() ([]byte, error) { + b := make([]byte, 3) + b[0] = g.a + b[1] = g.a + 1 + b[2] = g.a + 2 + return b, nil +} + +func (g *ByteStruct) GobDecode(data []byte) error { + if g == nil { + return errors.New("NIL RECEIVER") + } + // Expect N sequential-valued bytes. + if len(data) == 0 { + return io.EOF + } + g.a = data[0] + for i, c := range data { + if c != g.a+byte(i) { + return errors.New("invalid data sequence") + } + } + return nil +} + +func (g *StringStruct) GobEncode() ([]byte, error) { + return []byte(g.s), nil +} + +func (g *StringStruct) GobDecode(data []byte) error { + // Expect N sequential-valued bytes. + if len(data) == 0 { + return io.EOF + } + a := data[0] + for i, c := range data { + if c != a+byte(i) { + return errors.New("invalid data sequence") + } + } + g.s = string(data) + return nil +} + +func (a *ArrayStruct) GobEncode() ([]byte, error) { + return a.a[:], nil +} + +func (a *ArrayStruct) GobDecode(data []byte) error { + if len(data) != len(a.a) { + return errors.New("wrong length in array decode") + } + copy(a.a[:], data) + return nil +} + +func (g *Gobber) GobEncode() ([]byte, error) { + return []byte(fmt.Sprintf("VALUE=%d", *g)), nil +} + +func (g *Gobber) GobDecode(data []byte) error { + _, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g)) + return err +} + +func (g *BinaryGobber) MarshalBinary() ([]byte, error) { + return []byte(fmt.Sprintf("VALUE=%d", *g)), nil +} + +func (g *BinaryGobber) UnmarshalBinary(data []byte) error { + _, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g)) + return err +} + +func (g *TextGobber) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("VALUE=%d", *g)), nil +} + +func (g *TextGobber) UnmarshalText(data []byte) error { + _, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g)) + return err +} + +func (v ValueGobber) GobEncode() ([]byte, error) { + return []byte(fmt.Sprintf("VALUE=%s", v)), nil +} + +func (v *ValueGobber) GobDecode(data []byte) error { + _, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v)) + return err +} + +func (v BinaryValueGobber) MarshalBinary() ([]byte, error) { + return []byte(fmt.Sprintf("VALUE=%s", v)), nil +} + +func (v *BinaryValueGobber) UnmarshalBinary(data []byte) error { + _, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v)) + return err +} + +func (v TextValueGobber) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("VALUE=%s", v)), nil +} + +func (v *TextValueGobber) UnmarshalText(data []byte) error { + _, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v)) + return err +} + +// Structs that include GobEncodable fields. + +type GobTest0 struct { + X int // guarantee we have something in common with GobTest* + G *ByteStruct +} + +type GobTest1 struct { + X int // guarantee we have something in common with GobTest* + G *StringStruct +} + +type GobTest2 struct { + X int // guarantee we have something in common with GobTest* + G string // not a GobEncoder - should give us errors +} + +type GobTest3 struct { + X int // guarantee we have something in common with GobTest* + G *Gobber + B *BinaryGobber + T *TextGobber +} + +type GobTest4 struct { + X int // guarantee we have something in common with GobTest* + V ValueGobber + BV BinaryValueGobber + TV TextValueGobber +} + +type GobTest5 struct { + X int // guarantee we have something in common with GobTest* + V *ValueGobber + BV *BinaryValueGobber + TV *TextValueGobber +} + +type GobTest6 struct { + X int // guarantee we have something in common with GobTest* + V ValueGobber + W *ValueGobber + BV BinaryValueGobber + BW *BinaryValueGobber + TV TextValueGobber + TW *TextValueGobber +} + +type GobTest7 struct { + X int // guarantee we have something in common with GobTest* + V *ValueGobber + W ValueGobber + BV *BinaryValueGobber + BW BinaryValueGobber + TV *TextValueGobber + TW TextValueGobber +} + +type GobTestIgnoreEncoder struct { + X int // guarantee we have something in common with GobTest* +} + +type GobTestValueEncDec struct { + X int // guarantee we have something in common with GobTest* + G StringStruct // not a pointer. +} + +type GobTestIndirectEncDec struct { + X int // guarantee we have something in common with GobTest* + G ***StringStruct // indirections to the receiver. +} + +type GobTestArrayEncDec struct { + X int // guarantee we have something in common with GobTest* + A ArrayStruct // not a pointer. +} + +type GobTestIndirectArrayEncDec struct { + X int // guarantee we have something in common with GobTest* + A ***ArrayStruct // indirections to a large receiver. +} + +func TestGobEncoderField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + err := enc.Encode(GobTest0{17, &ByteStruct{'A'}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest0) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.G.a != 'A' { + t.Errorf("expected 'A' got %c", x.G.a) + } + // Now a field that's not a structure. + b.Reset() + gobber := Gobber(23) + bgobber := BinaryGobber(24) + tgobber := TextGobber(25) + err = enc.Encode(GobTest3{17, &gobber, &bgobber, &tgobber}) + if err != nil { + t.Fatal("encode error:", err) + } + y := new(GobTest3) + err = dec.Decode(y) + if err != nil { + t.Fatal("decode error:", err) + } + if *y.G != 23 || *y.B != 24 || *y.T != 25 { + t.Errorf("expected '23 got %d", *y.G) + } +} + +// Even though the field is a value, we can still take its address +// and should be able to call the methods. +func TestGobEncoderValueField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + err := enc.Encode(&GobTestValueEncDec{17, StringStruct{"HIJKL"}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestValueEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.G.s != "HIJKL" { + t.Errorf("expected `HIJKL` got %s", x.G.s) + } +} + +// GobEncode/Decode should work even if the value is +// more indirect than the receiver. +func TestGobEncoderIndirectField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + s := &StringStruct{"HIJKL"} + sp := &s + err := enc.Encode(GobTestIndirectEncDec{17, &sp}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIndirectEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if (***x.G).s != "HIJKL" { + t.Errorf("expected `HIJKL` got %s", (***x.G).s) + } +} + +// Test with a large field with methods. +func TestGobEncoderArrayField(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + var a GobTestArrayEncDec + a.X = 17 + for i := range a.A.a { + a.A.a[i] = byte(i) + } + err := enc.Encode(&a) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestArrayEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + for i, v := range x.A.a { + if v != byte(i) { + t.Errorf("expected %x got %x", byte(i), v) + break + } + } +} + +// Test an indirection to a large field with methods. +func TestGobEncoderIndirectArrayField(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + var a GobTestIndirectArrayEncDec + a.X = 17 + var array ArrayStruct + ap := &array + app := &ap + a.A = &app + for i := range array.a { + array.a[i] = byte(i) + } + err := enc.Encode(a) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIndirectArrayEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + for i, v := range (***x.A).a { + if v != byte(i) { + t.Errorf("expected %x got %x", byte(i), v) + break + } + } +} + +// As long as the fields have the same name and implement the +// interface, we can cross-connect them. Not sure it's useful +// and may even be bad but it works and it's hard to prevent +// without exposing the contents of the object, which would +// defeat the purpose. +func TestGobEncoderFieldsOfDifferentType(t *testing.T) { + // first, string in field to byte in field + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest0) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.G.a != 'A' { + t.Errorf("expected 'A' got %c", x.G.a) + } + // now the other direction, byte in field to string in field + b.Reset() + err = enc.Encode(GobTest0{17, &ByteStruct{'X'}}) + if err != nil { + t.Fatal("encode error:", err) + } + y := new(GobTest1) + err = dec.Decode(y) + if err != nil { + t.Fatal("decode error:", err) + } + if y.G.s != "XYZ" { + t.Fatalf("expected `XYZ` got %q", y.G.s) + } +} + +// Test that we can encode a value and decode into a pointer. +func TestGobEncoderValueEncoder(t *testing.T) { + // first, string in field to byte in field + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(GobTest4{17, ValueGobber("hello"), BinaryValueGobber("Καλημέρα"), TextValueGobber("こんにちは")}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest5) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if *x.V != "hello" || *x.BV != "Καλημέρα" || *x.TV != "こんにちは" { + t.Errorf("expected `hello` got %s", *x.V) + } +} + +// Test that we can use a value then a pointer type of a GobEncoder +// in the same encoded value. Bug 4647. +func TestGobEncoderValueThenPointer(t *testing.T) { + v := ValueGobber("forty-two") + w := ValueGobber("six-by-nine") + bv := BinaryValueGobber("1nanocentury") + bw := BinaryValueGobber("πseconds") + tv := TextValueGobber("gravitationalacceleration") + tw := TextValueGobber("π²ft/s²") + + // this was a bug: encoding a GobEncoder by value before a GobEncoder + // pointer would cause duplicate type definitions to be sent. + + b := new(bytes.Buffer) + enc := NewEncoder(b) + if err := enc.Encode(GobTest6{42, v, &w, bv, &bw, tv, &tw}); err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest6) + if err := dec.Decode(x); err != nil { + t.Fatal("decode error:", err) + } + + if got, want := x.V, v; got != want { + t.Errorf("v = %q, want %q", got, want) + } + if got, want := x.W, w; got == nil { + t.Errorf("w = nil, want %q", want) + } else if *got != want { + t.Errorf("w = %q, want %q", *got, want) + } + + if got, want := x.BV, bv; got != want { + t.Errorf("bv = %q, want %q", got, want) + } + if got, want := x.BW, bw; got == nil { + t.Errorf("bw = nil, want %q", want) + } else if *got != want { + t.Errorf("bw = %q, want %q", *got, want) + } + + if got, want := x.TV, tv; got != want { + t.Errorf("tv = %q, want %q", got, want) + } + if got, want := x.TW, tw; got == nil { + t.Errorf("tw = nil, want %q", want) + } else if *got != want { + t.Errorf("tw = %q, want %q", *got, want) + } +} + +// Test that we can use a pointer then a value type of a GobEncoder +// in the same encoded value. +func TestGobEncoderPointerThenValue(t *testing.T) { + v := ValueGobber("forty-two") + w := ValueGobber("six-by-nine") + bv := BinaryValueGobber("1nanocentury") + bw := BinaryValueGobber("πseconds") + tv := TextValueGobber("gravitationalacceleration") + tw := TextValueGobber("π²ft/s²") + + b := new(bytes.Buffer) + enc := NewEncoder(b) + if err := enc.Encode(GobTest7{42, &v, w, &bv, bw, &tv, tw}); err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest7) + if err := dec.Decode(x); err != nil { + t.Fatal("decode error:", err) + } + + if got, want := x.V, v; got == nil { + t.Errorf("v = nil, want %q", want) + } else if *got != want { + t.Errorf("v = %q, want %q", *got, want) + } + if got, want := x.W, w; got != want { + t.Errorf("w = %q, want %q", got, want) + } + + if got, want := x.BV, bv; got == nil { + t.Errorf("bv = nil, want %q", want) + } else if *got != want { + t.Errorf("bv = %q, want %q", *got, want) + } + if got, want := x.BW, bw; got != want { + t.Errorf("bw = %q, want %q", got, want) + } + + if got, want := x.TV, tv; got == nil { + t.Errorf("tv = nil, want %q", want) + } else if *got != want { + t.Errorf("tv = %q, want %q", *got, want) + } + if got, want := x.TW, tw; got != want { + t.Errorf("tw = %q, want %q", got, want) + } +} + +func TestGobEncoderFieldTypeError(t *testing.T) { + // GobEncoder to non-decoder: error + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := &GobTest2{} + err = dec.Decode(x) + if err == nil { + t.Fatal("expected decode error for mismatched fields (encoder to non-decoder)") + } + if strings.Index(err.Error(), "type") < 0 { + t.Fatal("expected type error; got", err) + } + // Non-encoder to GobDecoder: error + b.Reset() + err = enc.Encode(GobTest2{17, "ABC"}) + if err != nil { + t.Fatal("encode error:", err) + } + y := &GobTest1{} + err = dec.Decode(y) + if err == nil { + t.Fatal("expected decode error for mismatched fields (non-encoder to decoder)") + } + if strings.Index(err.Error(), "type") < 0 { + t.Fatal("expected type error; got", err) + } +} + +// Even though ByteStruct is a struct, it's treated as a singleton at the top level. +func TestGobEncoderStructSingleton(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(&ByteStruct{'A'}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(ByteStruct) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.a != 'A' { + t.Errorf("expected 'A' got %c", x.a) + } +} + +func TestGobEncoderNonStructSingleton(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + var g Gobber = 1234 + err := enc.Encode(&g) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + var x Gobber + err = dec.Decode(&x) + if err != nil { + t.Fatal("decode error:", err) + } + if x != 1234 { + t.Errorf("expected 1234 got %d", x) + } +} + +func TestGobEncoderIgnoreStructField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + err := enc.Encode(GobTest0{17, &ByteStruct{'A'}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIgnoreEncoder) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.X != 17 { + t.Errorf("expected 17 got %c", x.X) + } +} + +func TestGobEncoderIgnoreNonStructField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + gobber := Gobber(23) + bgobber := BinaryGobber(24) + tgobber := TextGobber(25) + err := enc.Encode(GobTest3{17, &gobber, &bgobber, &tgobber}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIgnoreEncoder) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.X != 17 { + t.Errorf("expected 17 got %c", x.X) + } +} + +func TestGobEncoderIgnoreNilEncoder(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + err := enc.Encode(GobTest0{X: 18}) // G is nil + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest0) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.X != 18 { + t.Errorf("expected x.X = 18, got %v", x.X) + } + if x.G != nil { + t.Errorf("expected x.G = nil, got %v", x.G) + } +} + +type gobDecoderBug0 struct { + foo, bar string +} + +func (br *gobDecoderBug0) String() string { + return br.foo + "-" + br.bar +} + +func (br *gobDecoderBug0) GobEncode() ([]byte, error) { + return []byte(br.String()), nil +} + +func (br *gobDecoderBug0) GobDecode(b []byte) error { + br.foo = "foo" + br.bar = "bar" + return nil +} + +// This was a bug: the receiver has a different indirection level +// than the variable. +func TestGobEncoderExtraIndirect(t *testing.T) { + gdb := &gobDecoderBug0{"foo", "bar"} + buf := new(bytes.Buffer) + e := NewEncoder(buf) + if err := e.Encode(gdb); err != nil { + t.Fatalf("encode: %v", err) + } + d := NewDecoder(buf) + var got *gobDecoderBug0 + if err := d.Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got.foo != gdb.foo || got.bar != gdb.bar { + t.Errorf("got = %q, want %q", got, gdb) + } +} + +// Another bug: this caused a crash with the new Go1 Time type. +// We throw in a gob-encoding array, to test another case of isZero, +// and a struct containing an nil interface, to test a third. +type isZeroBug struct { + T time.Time + S string + I int + A isZeroBugArray + F isZeroBugInterface +} + +type isZeroBugArray [2]uint8 + +// Receiver is value, not pointer, to test isZero of array. +func (a isZeroBugArray) GobEncode() (b []byte, e error) { + b = append(b, a[:]...) + return b, nil +} + +func (a *isZeroBugArray) GobDecode(data []byte) error { + if len(data) != len(a) { + return io.EOF + } + a[0] = data[0] + a[1] = data[1] + return nil +} + +type isZeroBugInterface struct { + I interface{} +} + +func (i isZeroBugInterface) GobEncode() (b []byte, e error) { + return []byte{}, nil +} + +func (i *isZeroBugInterface) GobDecode(data []byte) error { + return nil +} + +func TestGobEncodeIsZero(t *testing.T) { + x := isZeroBug{time.Now(), "hello", -55, isZeroBugArray{1, 2}, isZeroBugInterface{}} + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(x) + if err != nil { + t.Fatal("encode:", err) + } + var y isZeroBug + dec := NewDecoder(b) + err = dec.Decode(&y) + if err != nil { + t.Fatal("decode:", err) + } + if x != y { + t.Fatalf("%v != %v", x, y) + } +} + +func TestGobEncodePtrError(t *testing.T) { + var err error + b := new(bytes.Buffer) + enc := NewEncoder(b) + err = enc.Encode(&err) + if err != nil { + t.Fatal("encode:", err) + } + dec := NewDecoder(b) + err2 := fmt.Errorf("foo") + err = dec.Decode(&err2) + if err != nil { + t.Fatal("decode:", err) + } + if err2 != nil { + t.Fatalf("expected nil, got %v", err2) + } +} + +func TestNetIP(t *testing.T) { + // Encoding of net.IP{1,2,3,4} in Go 1.1. + enc := []byte{0x07, 0x0a, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04} + + var ip net.IP + err := NewDecoder(bytes.NewReader(enc)).Decode(&ip) + if err != nil { + t.Fatalf("decode: %v", err) + } + if ip.String() != "1.2.3.4" { + t.Errorf("decoded to %v, want 1.2.3.4", ip.String()) + } +} diff --git a/src/encoding/gob/timing_test.go b/src/encoding/gob/timing_test.go new file mode 100644 index 000000000..940e5ad41 --- /dev/null +++ b/src/encoding/gob/timing_test.go @@ -0,0 +1,325 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "bytes" + "io" + "os" + "runtime" + "testing" +) + +type Bench struct { + A int + B float64 + C string + D []byte +} + +func benchmarkEndToEnd(b *testing.B, ctor func() interface{}, pipe func() (r io.Reader, w io.Writer, err error)) { + b.RunParallel(func(pb *testing.PB) { + r, w, err := pipe() + if err != nil { + b.Fatal("can't get pipe:", err) + } + v := ctor() + enc := NewEncoder(w) + dec := NewDecoder(r) + for pb.Next() { + if err := enc.Encode(v); err != nil { + b.Fatal("encode error:", err) + } + if err := dec.Decode(v); err != nil { + b.Fatal("decode error:", err) + } + } + }) +} + +func BenchmarkEndToEndPipe(b *testing.B) { + benchmarkEndToEnd(b, func() interface{} { + return &Bench{7, 3.2, "now is the time", bytes.Repeat([]byte("for all good men"), 100)} + }, func() (r io.Reader, w io.Writer, err error) { + r, w, err = os.Pipe() + return + }) +} + +func BenchmarkEndToEndByteBuffer(b *testing.B) { + benchmarkEndToEnd(b, func() interface{} { + return &Bench{7, 3.2, "now is the time", bytes.Repeat([]byte("for all good men"), 100)} + }, func() (r io.Reader, w io.Writer, err error) { + var buf bytes.Buffer + return &buf, &buf, nil + }) +} + +func BenchmarkEndToEndSliceByteBuffer(b *testing.B) { + benchmarkEndToEnd(b, func() interface{} { + v := &Bench{7, 3.2, "now is the time", nil} + Register(v) + arr := make([]interface{}, 100) + for i := range arr { + arr[i] = v + } + return &arr + }, func() (r io.Reader, w io.Writer, err error) { + var buf bytes.Buffer + return &buf, &buf, nil + }) +} + +func TestCountEncodeMallocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping malloc count in short mode") + } + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } + + const N = 1000 + + var buf bytes.Buffer + enc := NewEncoder(&buf) + bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")} + + allocs := testing.AllocsPerRun(N, func() { + err := enc.Encode(bench) + if err != nil { + t.Fatal("encode:", err) + } + }) + if allocs != 0 { + t.Fatalf("mallocs per encode of type Bench: %v; wanted 0\n", allocs) + } +} + +func TestCountDecodeMallocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping malloc count in short mode") + } + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } + + const N = 1000 + + var buf bytes.Buffer + enc := NewEncoder(&buf) + bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")} + + // Fill the buffer with enough to decode + testing.AllocsPerRun(N, func() { + err := enc.Encode(bench) + if err != nil { + t.Fatal("encode:", err) + } + }) + + dec := NewDecoder(&buf) + allocs := testing.AllocsPerRun(N, func() { + *bench = Bench{} + err := dec.Decode(&bench) + if err != nil { + t.Fatal("decode:", err) + } + }) + if allocs != 4 { + t.Fatalf("mallocs per decode of type Bench: %v; wanted 4\n", allocs) + } +} + +func BenchmarkEncodeComplex128Slice(b *testing.B) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + a := make([]complex128, 1000) + for i := range a { + a[i] = 1.2 + 3.4i + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + err := enc.Encode(a) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeFloat64Slice(b *testing.B) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + a := make([]float64, 1000) + for i := range a { + a[i] = 1.23e4 + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + err := enc.Encode(a) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeInt32Slice(b *testing.B) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + a := make([]int32, 1000) + for i := range a { + a[i] = 1234 + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + err := enc.Encode(a) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeStringSlice(b *testing.B) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + a := make([]string, 1000) + for i := range a { + a[i] = "now is the time" + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + err := enc.Encode(a) + if err != nil { + b.Fatal(err) + } + } +} + +// benchmarkBuf is a read buffer we can reset +type benchmarkBuf struct { + offset int + data []byte +} + +func (b *benchmarkBuf) Read(p []byte) (n int, err error) { + n = copy(p, b.data[b.offset:]) + if n == 0 { + return 0, io.EOF + } + b.offset += n + return +} + +func (b *benchmarkBuf) ReadByte() (c byte, err error) { + if b.offset >= len(b.data) { + return 0, io.EOF + } + c = b.data[b.offset] + b.offset++ + return +} + +func (b *benchmarkBuf) reset() { + b.offset = 0 +} + +func BenchmarkDecodeComplex128Slice(b *testing.B) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + a := make([]complex128, 1000) + for i := range a { + a[i] = 1.2 + 3.4i + } + err := enc.Encode(a) + if err != nil { + b.Fatal(err) + } + x := make([]complex128, 1000) + bbuf := benchmarkBuf{data: buf.Bytes()} + b.ResetTimer() + for i := 0; i < b.N; i++ { + bbuf.reset() + dec := NewDecoder(&bbuf) + err := dec.Decode(&x) + if err != nil { + b.Fatal(i, err) + } + } +} + +func BenchmarkDecodeFloat64Slice(b *testing.B) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + a := make([]float64, 1000) + for i := range a { + a[i] = 1.23e4 + } + err := enc.Encode(a) + if err != nil { + b.Fatal(err) + } + x := make([]float64, 1000) + bbuf := benchmarkBuf{data: buf.Bytes()} + b.ResetTimer() + for i := 0; i < b.N; i++ { + bbuf.reset() + dec := NewDecoder(&bbuf) + err := dec.Decode(&x) + if err != nil { + b.Fatal(i, err) + } + } +} + +func BenchmarkDecodeInt32Slice(b *testing.B) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + a := make([]int32, 1000) + for i := range a { + a[i] = 1234 + } + err := enc.Encode(a) + if err != nil { + b.Fatal(err) + } + x := make([]int32, 1000) + bbuf := benchmarkBuf{data: buf.Bytes()} + b.ResetTimer() + for i := 0; i < b.N; i++ { + bbuf.reset() + dec := NewDecoder(&bbuf) + err := dec.Decode(&x) + if err != nil { + b.Fatal(i, err) + } + } +} + +func BenchmarkDecodeStringSlice(b *testing.B) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + a := make([]string, 1000) + for i := range a { + a[i] = "now is the time" + } + err := enc.Encode(a) + if err != nil { + b.Fatal(err) + } + x := make([]string, 1000) + bbuf := benchmarkBuf{data: buf.Bytes()} + b.ResetTimer() + for i := 0; i < b.N; i++ { + bbuf.reset() + dec := NewDecoder(&bbuf) + err := dec.Decode(&x) + if err != nil { + b.Fatal(i, err) + } + } +} diff --git a/src/encoding/gob/type.go b/src/encoding/gob/type.go new file mode 100644 index 000000000..a49b71a86 --- /dev/null +++ b/src/encoding/gob/type.go @@ -0,0 +1,923 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "encoding" + "errors" + "fmt" + "os" + "reflect" + "sync" + "sync/atomic" + "unicode" + "unicode/utf8" +) + +// userTypeInfo stores the information associated with a type the user has handed +// to the package. It's computed once and stored in a map keyed by reflection +// type. +type userTypeInfo struct { + user reflect.Type // the type the user handed us + base reflect.Type // the base type after all indirections + indir int // number of indirections to reach the base type + externalEnc int // xGob, xBinary, or xText + externalDec int // xGob, xBinary or xText + encIndir int8 // number of indirections to reach the receiver type; may be negative + decIndir int8 // number of indirections to reach the receiver type; may be negative +} + +// externalEncoding bits +const ( + xGob = 1 + iota // GobEncoder or GobDecoder + xBinary // encoding.BinaryMarshaler or encoding.BinaryUnmarshaler + xText // encoding.TextMarshaler or encoding.TextUnmarshaler +) + +var ( + // Protected by an RWMutex because we read it a lot and write + // it only when we see a new type, typically when compiling. + userTypeLock sync.RWMutex + userTypeCache = make(map[reflect.Type]*userTypeInfo) +) + +// validType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, err will be non-nil. To be used when the error handler +// is not set up. +func validUserType(rt reflect.Type) (ut *userTypeInfo, err error) { + userTypeLock.RLock() + ut = userTypeCache[rt] + userTypeLock.RUnlock() + if ut != nil { + return + } + // Now set the value under the write lock. + userTypeLock.Lock() + defer userTypeLock.Unlock() + if ut = userTypeCache[rt]; ut != nil { + // Lost the race; not a problem. + return + } + ut = new(userTypeInfo) + ut.base = rt + ut.user = rt + // A type that is just a cycle of pointers (such as type T *T) cannot + // be represented in gobs, which need some concrete data. We use a + // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6, + // pp 539-540. As we step through indirections, run another type at + // half speed. If they meet up, there's a cycle. + slowpoke := ut.base // walks half as fast as ut.base + for { + pt := ut.base + if pt.Kind() != reflect.Ptr { + break + } + ut.base = pt.Elem() + if ut.base == slowpoke { // ut.base lapped slowpoke + // recursive pointer type. + return nil, errors.New("can't represent recursive pointer type " + ut.base.String()) + } + if ut.indir%2 == 0 { + slowpoke = slowpoke.Elem() + } + ut.indir++ + } + + if ok, indir := implementsInterface(ut.user, gobEncoderInterfaceType); ok { + ut.externalEnc, ut.encIndir = xGob, indir + } else if ok, indir := implementsInterface(ut.user, binaryMarshalerInterfaceType); ok { + ut.externalEnc, ut.encIndir = xBinary, indir + } + + // NOTE(rsc): Would like to allow MarshalText here, but results in incompatibility + // with older encodings for net.IP. See golang.org/issue/6760. + // } else if ok, indir := implementsInterface(ut.user, textMarshalerInterfaceType); ok { + // ut.externalEnc, ut.encIndir = xText, indir + // } + + if ok, indir := implementsInterface(ut.user, gobDecoderInterfaceType); ok { + ut.externalDec, ut.decIndir = xGob, indir + } else if ok, indir := implementsInterface(ut.user, binaryUnmarshalerInterfaceType); ok { + ut.externalDec, ut.decIndir = xBinary, indir + } + + // See note above. + // } else if ok, indir := implementsInterface(ut.user, textUnmarshalerInterfaceType); ok { + // ut.externalDec, ut.decIndir = xText, indir + // } + + userTypeCache[rt] = ut + return +} + +var ( + gobEncoderInterfaceType = reflect.TypeOf((*GobEncoder)(nil)).Elem() + gobDecoderInterfaceType = reflect.TypeOf((*GobDecoder)(nil)).Elem() + binaryMarshalerInterfaceType = reflect.TypeOf((*encoding.BinaryMarshaler)(nil)).Elem() + binaryUnmarshalerInterfaceType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem() + textMarshalerInterfaceType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + textUnmarshalerInterfaceType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + +// implementsInterface reports whether the type implements the +// gobEncoder/gobDecoder interface. +// It also returns the number of indirections required to get to the +// implementation. +func implementsInterface(typ, gobEncDecType reflect.Type) (success bool, indir int8) { + if typ == nil { + return + } + rt := typ + // The type might be a pointer and we need to keep + // dereferencing to the base type until we find an implementation. + for { + if rt.Implements(gobEncDecType) { + return true, indir + } + if p := rt; p.Kind() == reflect.Ptr { + indir++ + if indir > 100 { // insane number of indirections + return false, 0 + } + rt = p.Elem() + continue + } + break + } + // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy. + if typ.Kind() != reflect.Ptr { + // Not a pointer, but does the pointer work? + if reflect.PtrTo(typ).Implements(gobEncDecType) { + return true, -1 + } + } + return false, 0 +} + +// userType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, it calls error. +func userType(rt reflect.Type) *userTypeInfo { + ut, err := validUserType(rt) + if err != nil { + error_(err) + } + return ut +} + +// A typeId represents a gob Type as an integer that can be passed on the wire. +// Internally, typeIds are used as keys to a map to recover the underlying type info. +type typeId int32 + +var nextId typeId // incremented for each new type we build +var typeLock sync.Mutex // set while building a type +const firstUserId = 64 // lowest id number granted to user + +type gobType interface { + id() typeId + setId(id typeId) + name() string + string() string // not public; only for debugging + safeString(seen map[typeId]bool) string +} + +var types = make(map[reflect.Type]gobType) +var idToType = make(map[typeId]gobType) +var builtinIdToType map[typeId]gobType // set in init() after builtins are established + +func setTypeId(typ gobType) { + // When building recursive types, someone may get there before us. + if typ.id() != 0 { + return + } + nextId++ + typ.setId(nextId) + idToType[nextId] = typ +} + +func (t typeId) gobType() gobType { + if t == 0 { + return nil + } + return idToType[t] +} + +// string returns the string representation of the type associated with the typeId. +func (t typeId) string() string { + if t.gobType() == nil { + return "<nil>" + } + return t.gobType().string() +} + +// Name returns the name of the type associated with the typeId. +func (t typeId) name() string { + if t.gobType() == nil { + return "<nil>" + } + return t.gobType().name() +} + +// CommonType holds elements of all types. +// It is a historical artifact, kept for binary compatibility and exported +// only for the benefit of the package's encoding of type descriptors. It is +// not intended for direct use by clients. +type CommonType struct { + Name string + Id typeId +} + +func (t *CommonType) id() typeId { return t.Id } + +func (t *CommonType) setId(id typeId) { t.Id = id } + +func (t *CommonType) string() string { return t.Name } + +func (t *CommonType) safeString(seen map[typeId]bool) string { + return t.Name +} + +func (t *CommonType) name() string { return t.Name } + +// Create and check predefined types +// The string for tBytes is "bytes" not "[]byte" to signify its specialness. + +var ( + // Primordial types, needed during initialization. + // Always passed as pointers so the interface{} type + // goes through without losing its interfaceness. + tBool = bootstrapType("bool", (*bool)(nil), 1) + tInt = bootstrapType("int", (*int)(nil), 2) + tUint = bootstrapType("uint", (*uint)(nil), 3) + tFloat = bootstrapType("float", (*float64)(nil), 4) + tBytes = bootstrapType("bytes", (*[]byte)(nil), 5) + tString = bootstrapType("string", (*string)(nil), 6) + tComplex = bootstrapType("complex", (*complex128)(nil), 7) + tInterface = bootstrapType("interface", (*interface{})(nil), 8) + // Reserve some Ids for compatible expansion + tReserved7 = bootstrapType("_reserved1", (*struct{ r7 int })(nil), 9) + tReserved6 = bootstrapType("_reserved1", (*struct{ r6 int })(nil), 10) + tReserved5 = bootstrapType("_reserved1", (*struct{ r5 int })(nil), 11) + tReserved4 = bootstrapType("_reserved1", (*struct{ r4 int })(nil), 12) + tReserved3 = bootstrapType("_reserved1", (*struct{ r3 int })(nil), 13) + tReserved2 = bootstrapType("_reserved1", (*struct{ r2 int })(nil), 14) + tReserved1 = bootstrapType("_reserved1", (*struct{ r1 int })(nil), 15) +) + +// Predefined because it's needed by the Decoder +var tWireType = mustGetTypeInfo(reflect.TypeOf(wireType{})).id +var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType) + +func init() { + // Some magic numbers to make sure there are no surprises. + checkId(16, tWireType) + checkId(17, mustGetTypeInfo(reflect.TypeOf(arrayType{})).id) + checkId(18, mustGetTypeInfo(reflect.TypeOf(CommonType{})).id) + checkId(19, mustGetTypeInfo(reflect.TypeOf(sliceType{})).id) + checkId(20, mustGetTypeInfo(reflect.TypeOf(structType{})).id) + checkId(21, mustGetTypeInfo(reflect.TypeOf(fieldType{})).id) + checkId(23, mustGetTypeInfo(reflect.TypeOf(mapType{})).id) + + builtinIdToType = make(map[typeId]gobType) + for k, v := range idToType { + builtinIdToType[k] = v + } + + // Move the id space upwards to allow for growth in the predefined world + // without breaking existing files. + if nextId > firstUserId { + panic(fmt.Sprintln("nextId too large:", nextId)) + } + nextId = firstUserId + registerBasics() + wireTypeUserInfo = userType(reflect.TypeOf((*wireType)(nil))) +} + +// Array type +type arrayType struct { + CommonType + Elem typeId + Len int +} + +func newArrayType(name string) *arrayType { + a := &arrayType{CommonType{Name: name}, 0, 0} + return a +} + +func (a *arrayType) init(elem gobType, len int) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(a) + a.Elem = elem.id() + a.Len = len +} + +func (a *arrayType) safeString(seen map[typeId]bool) string { + if seen[a.Id] { + return a.Name + } + seen[a.Id] = true + return fmt.Sprintf("[%d]%s", a.Len, a.Elem.gobType().safeString(seen)) +} + +func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) } + +// GobEncoder type (something that implements the GobEncoder interface) +type gobEncoderType struct { + CommonType +} + +func newGobEncoderType(name string) *gobEncoderType { + g := &gobEncoderType{CommonType{Name: name}} + setTypeId(g) + return g +} + +func (g *gobEncoderType) safeString(seen map[typeId]bool) string { + return g.Name +} + +func (g *gobEncoderType) string() string { return g.Name } + +// Map type +type mapType struct { + CommonType + Key typeId + Elem typeId +} + +func newMapType(name string) *mapType { + m := &mapType{CommonType{Name: name}, 0, 0} + return m +} + +func (m *mapType) init(key, elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(m) + m.Key = key.id() + m.Elem = elem.id() +} + +func (m *mapType) safeString(seen map[typeId]bool) string { + if seen[m.Id] { + return m.Name + } + seen[m.Id] = true + key := m.Key.gobType().safeString(seen) + elem := m.Elem.gobType().safeString(seen) + return fmt.Sprintf("map[%s]%s", key, elem) +} + +func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) } + +// Slice type +type sliceType struct { + CommonType + Elem typeId +} + +func newSliceType(name string) *sliceType { + s := &sliceType{CommonType{Name: name}, 0} + return s +} + +func (s *sliceType) init(elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(s) + // See the comments about ids in newTypeObject. Only slices and + // structs have mutual recursion. + if elem.id() == 0 { + setTypeId(elem) + } + s.Elem = elem.id() +} + +func (s *sliceType) safeString(seen map[typeId]bool) string { + if seen[s.Id] { + return s.Name + } + seen[s.Id] = true + return fmt.Sprintf("[]%s", s.Elem.gobType().safeString(seen)) +} + +func (s *sliceType) string() string { return s.safeString(make(map[typeId]bool)) } + +// Struct type +type fieldType struct { + Name string + Id typeId +} + +type structType struct { + CommonType + Field []*fieldType +} + +func (s *structType) safeString(seen map[typeId]bool) string { + if s == nil { + return "<nil>" + } + if _, ok := seen[s.Id]; ok { + return s.Name + } + seen[s.Id] = true + str := s.Name + " = struct { " + for _, f := range s.Field { + str += fmt.Sprintf("%s %s; ", f.Name, f.Id.gobType().safeString(seen)) + } + str += "}" + return str +} + +func (s *structType) string() string { return s.safeString(make(map[typeId]bool)) } + +func newStructType(name string) *structType { + s := &structType{CommonType{Name: name}, nil} + // For historical reasons we set the id here rather than init. + // See the comment in newTypeObject for details. + setTypeId(s) + return s +} + +// newTypeObject allocates a gobType for the reflection type rt. +// Unless ut represents a GobEncoder, rt should be the base type +// of ut. +// This is only called from the encoding side. The decoding side +// works through typeIds and userTypeInfos alone. +func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, error) { + // Does this type implement GobEncoder? + if ut.externalEnc != 0 { + return newGobEncoderType(name), nil + } + var err error + var type0, type1 gobType + defer func() { + if err != nil { + delete(types, rt) + } + }() + // Install the top-level type before the subtypes (e.g. struct before + // fields) so recursive types can be constructed safely. + switch t := rt; t.Kind() { + // All basic types are easy: they are predefined. + case reflect.Bool: + return tBool.gobType(), nil + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return tInt.gobType(), nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return tUint.gobType(), nil + + case reflect.Float32, reflect.Float64: + return tFloat.gobType(), nil + + case reflect.Complex64, reflect.Complex128: + return tComplex.gobType(), nil + + case reflect.String: + return tString.gobType(), nil + + case reflect.Interface: + return tInterface.gobType(), nil + + case reflect.Array: + at := newArrayType(name) + types[rt] = at + type0, err = getBaseType("", t.Elem()) + if err != nil { + return nil, err + } + // Historical aside: + // For arrays, maps, and slices, we set the type id after the elements + // are constructed. This is to retain the order of type id allocation after + // a fix made to handle recursive types, which changed the order in + // which types are built. Delaying the setting in this way preserves + // type ids while allowing recursive types to be described. Structs, + // done below, were already handling recursion correctly so they + // assign the top-level id before those of the field. + at.init(type0, t.Len()) + return at, nil + + case reflect.Map: + mt := newMapType(name) + types[rt] = mt + type0, err = getBaseType("", t.Key()) + if err != nil { + return nil, err + } + type1, err = getBaseType("", t.Elem()) + if err != nil { + return nil, err + } + mt.init(type0, type1) + return mt, nil + + case reflect.Slice: + // []byte == []uint8 is a special case + if t.Elem().Kind() == reflect.Uint8 { + return tBytes.gobType(), nil + } + st := newSliceType(name) + types[rt] = st + type0, err = getBaseType(t.Elem().Name(), t.Elem()) + if err != nil { + return nil, err + } + st.init(type0) + return st, nil + + case reflect.Struct: + st := newStructType(name) + types[rt] = st + idToType[st.id()] = st + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !isSent(&f) { + continue + } + typ := userType(f.Type).base + tname := typ.Name() + if tname == "" { + t := userType(f.Type).base + tname = t.String() + } + gt, err := getBaseType(tname, f.Type) + if err != nil { + return nil, err + } + // Some mutually recursive types can cause us to be here while + // still defining the element. Fix the element type id here. + // We could do this more neatly by setting the id at the start of + // building every type, but that would break binary compatibility. + if gt.id() == 0 { + setTypeId(gt) + } + st.Field = append(st.Field, &fieldType{f.Name, gt.id()}) + } + return st, nil + + default: + return nil, errors.New("gob NewTypeObject can't handle type: " + rt.String()) + } +} + +// isExported reports whether this is an exported - upper case - name. +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// isSent reports whether this struct field is to be transmitted. +// It will be transmitted only if it is exported and not a chan or func field +// or pointer to chan or func. +func isSent(field *reflect.StructField) bool { + if !isExported(field.Name) { + return false + } + // If the field is a chan or func or pointer thereto, don't send it. + // That is, treat it like an unexported field. + typ := field.Type + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + if typ.Kind() == reflect.Chan || typ.Kind() == reflect.Func { + return false + } + return true +} + +// getBaseType returns the Gob type describing the given reflect.Type's base type. +// typeLock must be held. +func getBaseType(name string, rt reflect.Type) (gobType, error) { + ut := userType(rt) + return getType(name, ut, ut.base) +} + +// getType returns the Gob type describing the given reflect.Type. +// Should be called only when handling GobEncoders/Decoders, +// which may be pointers. All other types are handled through the +// base type, never a pointer. +// typeLock must be held. +func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, error) { + typ, present := types[rt] + if present { + return typ, nil + } + typ, err := newTypeObject(name, ut, rt) + if err == nil { + types[rt] = typ + } + return typ, err +} + +func checkId(want, got typeId) { + if want != got { + fmt.Fprintf(os.Stderr, "checkId: %d should be %d\n", int(got), int(want)) + panic("bootstrap type wrong id: " + got.name() + " " + got.string() + " not " + want.string()) + } +} + +// used for building the basic types; called only from init(). the incoming +// interface always refers to a pointer. +func bootstrapType(name string, e interface{}, expect typeId) typeId { + rt := reflect.TypeOf(e).Elem() + _, present := types[rt] + if present { + panic("bootstrap type already present: " + name + ", " + rt.String()) + } + typ := &CommonType{Name: name} + types[rt] = typ + setTypeId(typ) + checkId(expect, nextId) + userType(rt) // might as well cache it now + return nextId +} + +// Representation of the information we send and receive about this type. +// Each value we send is preceded by its type definition: an encoded int. +// However, the very first time we send the value, we first send the pair +// (-id, wireType). +// For bootstrapping purposes, we assume that the recipient knows how +// to decode a wireType; it is exactly the wireType struct here, interpreted +// using the gob rules for sending a structure, except that we assume the +// ids for wireType and structType etc. are known. The relevant pieces +// are built in encode.go's init() function. +// To maintain binary compatibility, if you extend this type, always put +// the new fields last. +type wireType struct { + ArrayT *arrayType + SliceT *sliceType + StructT *structType + MapT *mapType + GobEncoderT *gobEncoderType + BinaryMarshalerT *gobEncoderType + TextMarshalerT *gobEncoderType +} + +func (w *wireType) string() string { + const unknown = "unknown type" + if w == nil { + return unknown + } + switch { + case w.ArrayT != nil: + return w.ArrayT.Name + case w.SliceT != nil: + return w.SliceT.Name + case w.StructT != nil: + return w.StructT.Name + case w.MapT != nil: + return w.MapT.Name + case w.GobEncoderT != nil: + return w.GobEncoderT.Name + case w.BinaryMarshalerT != nil: + return w.BinaryMarshalerT.Name + case w.TextMarshalerT != nil: + return w.TextMarshalerT.Name + } + return unknown +} + +type typeInfo struct { + id typeId + encInit sync.Mutex // protects creation of encoder + encoder atomic.Value // *encEngine + wire *wireType +} + +// typeInfoMap is an atomic pointer to map[reflect.Type]*typeInfo. +// It's updated copy-on-write. Readers just do an atomic load +// to get the current version of the map. Writers make a full copy of +// the map and atomically update the pointer to point to the new map. +// Under heavy read contention, this is significantly faster than a map +// protected by a mutex. +var typeInfoMap atomic.Value + +func lookupTypeInfo(rt reflect.Type) *typeInfo { + m, _ := typeInfoMap.Load().(map[reflect.Type]*typeInfo) + return m[rt] +} + +func getTypeInfo(ut *userTypeInfo) (*typeInfo, error) { + rt := ut.base + if ut.externalEnc != 0 { + // We want the user type, not the base type. + rt = ut.user + } + if info := lookupTypeInfo(rt); info != nil { + return info, nil + } + return buildTypeInfo(ut, rt) +} + +// buildTypeInfo constructs the type information for the type +// and stores it in the type info map. +func buildTypeInfo(ut *userTypeInfo, rt reflect.Type) (*typeInfo, error) { + typeLock.Lock() + defer typeLock.Unlock() + + if info := lookupTypeInfo(rt); info != nil { + return info, nil + } + + gt, err := getBaseType(rt.Name(), rt) + if err != nil { + return nil, err + } + info := &typeInfo{id: gt.id()} + + if ut.externalEnc != 0 { + userType, err := getType(rt.Name(), ut, rt) + if err != nil { + return nil, err + } + gt := userType.id().gobType().(*gobEncoderType) + switch ut.externalEnc { + case xGob: + info.wire = &wireType{GobEncoderT: gt} + case xBinary: + info.wire = &wireType{BinaryMarshalerT: gt} + case xText: + info.wire = &wireType{TextMarshalerT: gt} + } + rt = ut.user + } else { + t := info.id.gobType() + switch typ := rt; typ.Kind() { + case reflect.Array: + info.wire = &wireType{ArrayT: t.(*arrayType)} + case reflect.Map: + info.wire = &wireType{MapT: t.(*mapType)} + case reflect.Slice: + // []byte == []uint8 is a special case handled separately + if typ.Elem().Kind() != reflect.Uint8 { + info.wire = &wireType{SliceT: t.(*sliceType)} + } + case reflect.Struct: + info.wire = &wireType{StructT: t.(*structType)} + } + } + + // Create new map with old contents plus new entry. + newm := make(map[reflect.Type]*typeInfo) + m, _ := typeInfoMap.Load().(map[reflect.Type]*typeInfo) + for k, v := range m { + newm[k] = v + } + newm[rt] = info + typeInfoMap.Store(newm) + return info, nil +} + +// Called only when a panic is acceptable and unexpected. +func mustGetTypeInfo(rt reflect.Type) *typeInfo { + t, err := getTypeInfo(userType(rt)) + if err != nil { + panic("getTypeInfo: " + err.Error()) + } + return t +} + +// GobEncoder is the interface describing data that provides its own +// representation for encoding values for transmission to a GobDecoder. +// A type that implements GobEncoder and GobDecoder has complete +// control over the representation of its data and may therefore +// contain things such as private fields, channels, and functions, +// which are not usually transmissible in gob streams. +// +// Note: Since gobs can be stored permanently, It is good design +// to guarantee the encoding used by a GobEncoder is stable as the +// software evolves. For instance, it might make sense for GobEncode +// to include a version number in the encoding. +type GobEncoder interface { + // GobEncode returns a byte slice representing the encoding of the + // receiver for transmission to a GobDecoder, usually of the same + // concrete type. + GobEncode() ([]byte, error) +} + +// GobDecoder is the interface describing data that provides its own +// routine for decoding transmitted values sent by a GobEncoder. +type GobDecoder interface { + // GobDecode overwrites the receiver, which must be a pointer, + // with the value represented by the byte slice, which was written + // by GobEncode, usually for the same concrete type. + GobDecode([]byte) error +} + +var ( + registerLock sync.RWMutex + nameToConcreteType = make(map[string]reflect.Type) + concreteTypeToName = make(map[reflect.Type]string) +) + +// RegisterName is like Register but uses the provided name rather than the +// type's default. +func RegisterName(name string, value interface{}) { + if name == "" { + // reserved for nil + panic("attempt to register empty name") + } + registerLock.Lock() + defer registerLock.Unlock() + ut := userType(reflect.TypeOf(value)) + // Check for incompatible duplicates. The name must refer to the + // same user type, and vice versa. + if t, ok := nameToConcreteType[name]; ok && t != ut.user { + panic(fmt.Sprintf("gob: registering duplicate types for %q: %s != %s", name, t, ut.user)) + } + if n, ok := concreteTypeToName[ut.base]; ok && n != name { + panic(fmt.Sprintf("gob: registering duplicate names for %s: %q != %q", ut.user, n, name)) + } + // Store the name and type provided by the user.... + nameToConcreteType[name] = reflect.TypeOf(value) + // but the flattened type in the type table, since that's what decode needs. + concreteTypeToName[ut.base] = name +} + +// Register records a type, identified by a value for that type, under its +// internal type name. That name will identify the concrete type of a value +// sent or received as an interface variable. Only types that will be +// transferred as implementations of interface values need to be registered. +// Expecting to be used only during initialization, it panics if the mapping +// between types and names is not a bijection. +func Register(value interface{}) { + // Default to printed representation for unnamed types + rt := reflect.TypeOf(value) + name := rt.String() + + // But for named types (or pointers to them), qualify with import path (but see inner comment). + // Dereference one pointer looking for a named type. + star := "" + if rt.Name() == "" { + if pt := rt; pt.Kind() == reflect.Ptr { + star = "*" + // NOTE: The following line should be rt = pt.Elem() to implement + // what the comment above claims, but fixing it would break compatibility + // with existing gobs. + // + // Given package p imported as "full/p" with these definitions: + // package p + // type T1 struct { ... } + // this table shows the intended and actual strings used by gob to + // name the types: + // + // Type Correct string Actual string + // + // T1 full/p.T1 full/p.T1 + // *T1 *full/p.T1 *p.T1 + // + // The missing full path cannot be fixed without breaking existing gob decoders. + rt = pt + } + } + if rt.Name() != "" { + if rt.PkgPath() == "" { + name = star + rt.Name() + } else { + name = star + rt.PkgPath() + "." + rt.Name() + } + } + + RegisterName(name, value) +} + +func registerBasics() { + Register(int(0)) + Register(int8(0)) + Register(int16(0)) + Register(int32(0)) + Register(int64(0)) + Register(uint(0)) + Register(uint8(0)) + Register(uint16(0)) + Register(uint32(0)) + Register(uint64(0)) + Register(float32(0)) + Register(float64(0)) + Register(complex64(0i)) + Register(complex128(0i)) + Register(uintptr(0)) + Register(false) + Register("") + Register([]byte(nil)) + Register([]int(nil)) + Register([]int8(nil)) + Register([]int16(nil)) + Register([]int32(nil)) + Register([]int64(nil)) + Register([]uint(nil)) + Register([]uint8(nil)) + Register([]uint16(nil)) + Register([]uint32(nil)) + Register([]uint64(nil)) + Register([]float32(nil)) + Register([]float64(nil)) + Register([]complex64(nil)) + Register([]complex128(nil)) + Register([]uintptr(nil)) + Register([]bool(nil)) + Register([]string(nil)) +} diff --git a/src/encoding/gob/type_test.go b/src/encoding/gob/type_test.go new file mode 100644 index 000000000..e230d22d4 --- /dev/null +++ b/src/encoding/gob/type_test.go @@ -0,0 +1,222 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "bytes" + "reflect" + "testing" +) + +type typeT struct { + id typeId + str string +} + +var basicTypes = []typeT{ + {tBool, "bool"}, + {tInt, "int"}, + {tUint, "uint"}, + {tFloat, "float"}, + {tBytes, "bytes"}, + {tString, "string"}, +} + +func getTypeUnlocked(name string, rt reflect.Type) gobType { + typeLock.Lock() + defer typeLock.Unlock() + t, err := getBaseType(name, rt) + if err != nil { + panic("getTypeUnlocked: " + err.Error()) + } + return t +} + +// Sanity checks +func TestBasic(t *testing.T) { + for _, tt := range basicTypes { + if tt.id.string() != tt.str { + t.Errorf("checkType: expected %q got %s", tt.str, tt.id.string()) + } + if tt.id == 0 { + t.Errorf("id for %q is zero", tt.str) + } + } +} + +// Reregister some basic types to check registration is idempotent. +func TestReregistration(t *testing.T) { + newtyp := getTypeUnlocked("int", reflect.TypeOf(int(0))) + if newtyp != tInt.gobType() { + t.Errorf("reregistration of %s got new type", newtyp.string()) + } + newtyp = getTypeUnlocked("uint", reflect.TypeOf(uint(0))) + if newtyp != tUint.gobType() { + t.Errorf("reregistration of %s got new type", newtyp.string()) + } + newtyp = getTypeUnlocked("string", reflect.TypeOf("hello")) + if newtyp != tString.gobType() { + t.Errorf("reregistration of %s got new type", newtyp.string()) + } +} + +func TestArrayType(t *testing.T) { + var a3 [3]int + a3int := getTypeUnlocked("foo", reflect.TypeOf(a3)) + newa3int := getTypeUnlocked("bar", reflect.TypeOf(a3)) + if a3int != newa3int { + t.Errorf("second registration of [3]int creates new type") + } + var a4 [4]int + a4int := getTypeUnlocked("goo", reflect.TypeOf(a4)) + if a3int == a4int { + t.Errorf("registration of [3]int creates same type as [4]int") + } + var b3 [3]bool + a3bool := getTypeUnlocked("", reflect.TypeOf(b3)) + if a3int == a3bool { + t.Errorf("registration of [3]bool creates same type as [3]int") + } + str := a3bool.string() + expected := "[3]bool" + if str != expected { + t.Errorf("array printed as %q; expected %q", str, expected) + } +} + +func TestSliceType(t *testing.T) { + var s []int + sint := getTypeUnlocked("slice", reflect.TypeOf(s)) + var news []int + newsint := getTypeUnlocked("slice1", reflect.TypeOf(news)) + if sint != newsint { + t.Errorf("second registration of []int creates new type") + } + var b []bool + sbool := getTypeUnlocked("", reflect.TypeOf(b)) + if sbool == sint { + t.Errorf("registration of []bool creates same type as []int") + } + str := sbool.string() + expected := "[]bool" + if str != expected { + t.Errorf("slice printed as %q; expected %q", str, expected) + } +} + +func TestMapType(t *testing.T) { + var m map[string]int + mapStringInt := getTypeUnlocked("map", reflect.TypeOf(m)) + var newm map[string]int + newMapStringInt := getTypeUnlocked("map1", reflect.TypeOf(newm)) + if mapStringInt != newMapStringInt { + t.Errorf("second registration of map[string]int creates new type") + } + var b map[string]bool + mapStringBool := getTypeUnlocked("", reflect.TypeOf(b)) + if mapStringBool == mapStringInt { + t.Errorf("registration of map[string]bool creates same type as map[string]int") + } + str := mapStringBool.string() + expected := "map[string]bool" + if str != expected { + t.Errorf("map printed as %q; expected %q", str, expected) + } +} + +type Bar struct { + X string +} + +// This structure has pointers and refers to itself, making it a good test case. +type Foo struct { + A int + B int32 // will become int + C string + D []byte + E *float64 // will become float64 + F ****float64 // will become float64 + G *Bar + H *Bar // should not interpolate the definition of Bar again + I *Foo // will not explode +} + +func TestStructType(t *testing.T) { + sstruct := getTypeUnlocked("Foo", reflect.TypeOf(Foo{})) + str := sstruct.string() + // If we can print it correctly, we built it correctly. + expected := "Foo = struct { A int; B int; C string; D bytes; E float; F float; G Bar = struct { X string; }; H Bar; I Foo; }" + if str != expected { + t.Errorf("struct printed as %q; expected %q", str, expected) + } +} + +// Should be OK to register the same type multiple times, as long as they're +// at the same level of indirection. +func TestRegistration(t *testing.T) { + type T struct{ a int } + Register(new(T)) + Register(new(T)) +} + +type N1 struct{} +type N2 struct{} + +// See comment in type.go/Register. +func TestRegistrationNaming(t *testing.T) { + testCases := []struct { + t interface{} + name string + }{ + {&N1{}, "*gob.N1"}, + {N2{}, "encoding/gob.N2"}, + } + + for _, tc := range testCases { + Register(tc.t) + + tct := reflect.TypeOf(tc.t) + registerLock.RLock() + ct := nameToConcreteType[tc.name] + registerLock.RUnlock() + if ct != tct { + t.Errorf("nameToConcreteType[%q] = %v, want %v", tc.name, ct, tct) + } + // concreteTypeToName is keyed off the base type. + if tct.Kind() == reflect.Ptr { + tct = tct.Elem() + } + if n := concreteTypeToName[tct]; n != tc.name { + t.Errorf("concreteTypeToName[%v] got %v, want %v", tct, n, tc.name) + } + } +} + +func TestStressParallel(t *testing.T) { + type T2 struct{ A int } + c := make(chan bool) + const N = 10 + for i := 0; i < N; i++ { + go func() { + p := new(T2) + Register(p) + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(p) + if err != nil { + t.Error("encoder fail:", err) + } + dec := NewDecoder(b) + err = dec.Decode(p) + if err != nil { + t.Error("decoder fail:", err) + } + c <- true + }() + } + for i := 0; i < N; i++ { + <-c + } +} diff --git a/src/encoding/hex/hex.go b/src/encoding/hex/hex.go new file mode 100644 index 000000000..d1fc7024a --- /dev/null +++ b/src/encoding/hex/hex.go @@ -0,0 +1,216 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package hex implements hexadecimal encoding and decoding. +package hex + +import ( + "bytes" + "errors" + "fmt" + "io" +) + +const hextable = "0123456789abcdef" + +// EncodedLen returns the length of an encoding of n source bytes. +func EncodedLen(n int) int { return n * 2 } + +// Encode encodes src into EncodedLen(len(src)) +// bytes of dst. As a convenience, it returns the number +// of bytes written to dst, but this value is always EncodedLen(len(src)). +// Encode implements hexadecimal encoding. +func Encode(dst, src []byte) int { + for i, v := range src { + dst[i*2] = hextable[v>>4] + dst[i*2+1] = hextable[v&0x0f] + } + + return len(src) * 2 +} + +// ErrLength results from decoding an odd length slice. +var ErrLength = errors.New("encoding/hex: odd length hex string") + +// InvalidByteError values describe errors resulting from an invalid byte in a hex string. +type InvalidByteError byte + +func (e InvalidByteError) Error() string { + return fmt.Sprintf("encoding/hex: invalid byte: %#U", rune(e)) +} + +func DecodedLen(x int) int { return x / 2 } + +// Decode decodes src into DecodedLen(len(src)) bytes, returning the actual +// number of bytes written to dst. +// +// If Decode encounters invalid input, it returns an error describing the failure. +func Decode(dst, src []byte) (int, error) { + if len(src)%2 == 1 { + return 0, ErrLength + } + + for i := 0; i < len(src)/2; i++ { + a, ok := fromHexChar(src[i*2]) + if !ok { + return 0, InvalidByteError(src[i*2]) + } + b, ok := fromHexChar(src[i*2+1]) + if !ok { + return 0, InvalidByteError(src[i*2+1]) + } + dst[i] = (a << 4) | b + } + + return len(src) / 2, nil +} + +// fromHexChar converts a hex character into its value and a success flag. +func fromHexChar(c byte) (byte, bool) { + switch { + case '0' <= c && c <= '9': + return c - '0', true + case 'a' <= c && c <= 'f': + return c - 'a' + 10, true + case 'A' <= c && c <= 'F': + return c - 'A' + 10, true + } + + return 0, false +} + +// EncodeToString returns the hexadecimal encoding of src. +func EncodeToString(src []byte) string { + dst := make([]byte, EncodedLen(len(src))) + Encode(dst, src) + return string(dst) +} + +// DecodeString returns the bytes represented by the hexadecimal string s. +func DecodeString(s string) ([]byte, error) { + src := []byte(s) + dst := make([]byte, DecodedLen(len(src))) + _, err := Decode(dst, src) + if err != nil { + return nil, err + } + return dst, nil +} + +// Dump returns a string that contains a hex dump of the given data. The format +// of the hex dump matches the output of `hexdump -C` on the command line. +func Dump(data []byte) string { + var buf bytes.Buffer + dumper := Dumper(&buf) + dumper.Write(data) + dumper.Close() + return string(buf.Bytes()) +} + +// Dumper returns a WriteCloser that writes a hex dump of all written data to +// w. The format of the dump matches the output of `hexdump -C` on the command +// line. +func Dumper(w io.Writer) io.WriteCloser { + return &dumper{w: w} +} + +type dumper struct { + w io.Writer + rightChars [18]byte + buf [14]byte + used int // number of bytes in the current line + n uint // number of bytes, total +} + +func toChar(b byte) byte { + if b < 32 || b > 126 { + return '.' + } + return b +} + +func (h *dumper) Write(data []byte) (n int, err error) { + // Output lines look like: + // 00000010 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d |./0123456789:;<=| + // ^ offset ^ extra space ^ ASCII of line. + for i := range data { + if h.used == 0 { + // At the beginning of a line we print the current + // offset in hex. + h.buf[0] = byte(h.n >> 24) + h.buf[1] = byte(h.n >> 16) + h.buf[2] = byte(h.n >> 8) + h.buf[3] = byte(h.n) + Encode(h.buf[4:], h.buf[:4]) + h.buf[12] = ' ' + h.buf[13] = ' ' + _, err = h.w.Write(h.buf[4:]) + if err != nil { + return + } + } + Encode(h.buf[:], data[i:i+1]) + h.buf[2] = ' ' + l := 3 + if h.used == 7 { + // There's an additional space after the 8th byte. + h.buf[3] = ' ' + l = 4 + } else if h.used == 15 { + // At the end of the line there's an extra space and + // the bar for the right column. + h.buf[3] = ' ' + h.buf[4] = '|' + l = 5 + } + _, err = h.w.Write(h.buf[:l]) + if err != nil { + return + } + n++ + h.rightChars[h.used] = toChar(data[i]) + h.used++ + h.n++ + if h.used == 16 { + h.rightChars[16] = '|' + h.rightChars[17] = '\n' + _, err = h.w.Write(h.rightChars[:]) + if err != nil { + return + } + h.used = 0 + } + } + return +} + +func (h *dumper) Close() (err error) { + // See the comments in Write() for the details of this format. + if h.used == 0 { + return + } + h.buf[0] = ' ' + h.buf[1] = ' ' + h.buf[2] = ' ' + h.buf[3] = ' ' + h.buf[4] = '|' + nBytes := h.used + for h.used < 16 { + l := 3 + if h.used == 7 { + l = 4 + } else if h.used == 15 { + l = 5 + } + _, err = h.w.Write(h.buf[:l]) + if err != nil { + return + } + h.used++ + } + h.rightChars[nBytes] = '|' + h.rightChars[nBytes+1] = '\n' + _, err = h.w.Write(h.rightChars[:nBytes+2]) + return +} diff --git a/src/encoding/hex/hex_test.go b/src/encoding/hex/hex_test.go new file mode 100644 index 000000000..b969636cd --- /dev/null +++ b/src/encoding/hex/hex_test.go @@ -0,0 +1,153 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hex + +import ( + "bytes" + "testing" +) + +type encDecTest struct { + enc string + dec []byte +} + +var encDecTests = []encDecTest{ + {"", []byte{}}, + {"0001020304050607", []byte{0, 1, 2, 3, 4, 5, 6, 7}}, + {"08090a0b0c0d0e0f", []byte{8, 9, 10, 11, 12, 13, 14, 15}}, + {"f0f1f2f3f4f5f6f7", []byte{0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7}}, + {"f8f9fafbfcfdfeff", []byte{0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff}}, + {"67", []byte{'g'}}, + {"e3a1", []byte{0xe3, 0xa1}}, +} + +func TestEncode(t *testing.T) { + for i, test := range encDecTests { + dst := make([]byte, EncodedLen(len(test.dec))) + n := Encode(dst, test.dec) + if n != len(dst) { + t.Errorf("#%d: bad return value: got: %d want: %d", i, n, len(dst)) + } + if string(dst) != test.enc { + t.Errorf("#%d: got: %#v want: %#v", i, dst, test.enc) + } + } +} + +func TestDecode(t *testing.T) { + // Case for decoding uppercase hex characters, since + // Encode always uses lowercase. + decTests := append(encDecTests, encDecTest{"F8F9FAFBFCFDFEFF", []byte{0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff}}) + for i, test := range decTests { + dst := make([]byte, DecodedLen(len(test.enc))) + n, err := Decode(dst, []byte(test.enc)) + if err != nil { + t.Errorf("#%d: bad return value: got:%d want:%d", i, n, len(dst)) + } else if !bytes.Equal(dst, test.dec) { + t.Errorf("#%d: got: %#v want: %#v", i, dst, test.dec) + } + } +} + +func TestEncodeToString(t *testing.T) { + for i, test := range encDecTests { + s := EncodeToString(test.dec) + if s != test.enc { + t.Errorf("#%d got:%s want:%s", i, s, test.enc) + } + } +} + +func TestDecodeString(t *testing.T) { + for i, test := range encDecTests { + dst, err := DecodeString(test.enc) + if err != nil { + t.Errorf("#%d: unexpected err value: %s", i, err) + continue + } + if !bytes.Equal(dst, test.dec) { + t.Errorf("#%d: got: %#v want: #%v", i, dst, test.dec) + } + } +} + +type errTest struct { + in string + err string +} + +var errTests = []errTest{ + {"0", "encoding/hex: odd length hex string"}, + {"0g", "encoding/hex: invalid byte: U+0067 'g'"}, + {"00gg", "encoding/hex: invalid byte: U+0067 'g'"}, + {"0\x01", "encoding/hex: invalid byte: U+0001"}, +} + +func TestInvalidErr(t *testing.T) { + for i, test := range errTests { + dst := make([]byte, DecodedLen(len(test.in))) + _, err := Decode(dst, []byte(test.in)) + if err == nil { + t.Errorf("#%d: expected error; got none", i) + } else if err.Error() != test.err { + t.Errorf("#%d: got: %v want: %v", i, err, test.err) + } + } +} + +func TestInvalidStringErr(t *testing.T) { + for i, test := range errTests { + _, err := DecodeString(test.in) + if err == nil { + t.Errorf("#%d: expected error; got none", i) + } else if err.Error() != test.err { + t.Errorf("#%d: got: %v want: %v", i, err, test.err) + } + } +} + +func TestDumper(t *testing.T) { + var in [40]byte + for i := range in { + in[i] = byte(i + 30) + } + + for stride := 1; stride < len(in); stride++ { + var out bytes.Buffer + dumper := Dumper(&out) + done := 0 + for done < len(in) { + todo := done + stride + if todo > len(in) { + todo = len(in) + } + dumper.Write(in[done:todo]) + done = todo + } + + dumper.Close() + if !bytes.Equal(out.Bytes(), expectedHexDump) { + t.Errorf("stride: %d failed. got:\n%s\nwant:\n%s", stride, out.Bytes(), expectedHexDump) + } + } +} + +func TestDump(t *testing.T) { + var in [40]byte + for i := range in { + in[i] = byte(i + 30) + } + + out := []byte(Dump(in[:])) + if !bytes.Equal(out, expectedHexDump) { + t.Errorf("got:\n%s\nwant:\n%s", out, expectedHexDump) + } +} + +var expectedHexDump = []byte(`00000000 1e 1f 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d |.. !"#$%&'()*+,-| +00000010 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d |./0123456789:;<=| +00000020 3e 3f 40 41 42 43 44 45 |>?@ABCDE| +`) diff --git a/src/encoding/json/bench_test.go b/src/encoding/json/bench_test.go new file mode 100644 index 000000000..29dbc26d4 --- /dev/null +++ b/src/encoding/json/bench_test.go @@ -0,0 +1,189 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Large data benchmark. +// The JSON data is a summary of agl's changes in the +// go, webkit, and chromium open source projects. +// We benchmark converting between the JSON form +// and in-memory data structures. + +package json + +import ( + "bytes" + "compress/gzip" + "io/ioutil" + "os" + "testing" +) + +type codeResponse struct { + Tree *codeNode `json:"tree"` + Username string `json:"username"` +} + +type codeNode struct { + Name string `json:"name"` + Kids []*codeNode `json:"kids"` + CLWeight float64 `json:"cl_weight"` + Touches int `json:"touches"` + MinT int64 `json:"min_t"` + MaxT int64 `json:"max_t"` + MeanT int64 `json:"mean_t"` +} + +var codeJSON []byte +var codeStruct codeResponse + +func codeInit() { + f, err := os.Open("testdata/code.json.gz") + if err != nil { + panic(err) + } + defer f.Close() + gz, err := gzip.NewReader(f) + if err != nil { + panic(err) + } + data, err := ioutil.ReadAll(gz) + if err != nil { + panic(err) + } + + codeJSON = data + + if err := Unmarshal(codeJSON, &codeStruct); err != nil { + panic("unmarshal code.json: " + err.Error()) + } + + if data, err = Marshal(&codeStruct); err != nil { + panic("marshal code.json: " + err.Error()) + } + + if !bytes.Equal(data, codeJSON) { + println("different lengths", len(data), len(codeJSON)) + for i := 0; i < len(data) && i < len(codeJSON); i++ { + if data[i] != codeJSON[i] { + println("re-marshal: changed at byte", i) + println("orig: ", string(codeJSON[i-10:i+10])) + println("new: ", string(data[i-10:i+10])) + break + } + } + panic("re-marshal code.json: different result") + } +} + +func BenchmarkCodeEncoder(b *testing.B) { + if codeJSON == nil { + b.StopTimer() + codeInit() + b.StartTimer() + } + enc := NewEncoder(ioutil.Discard) + for i := 0; i < b.N; i++ { + if err := enc.Encode(&codeStruct); err != nil { + b.Fatal("Encode:", err) + } + } + b.SetBytes(int64(len(codeJSON))) +} + +func BenchmarkCodeMarshal(b *testing.B) { + if codeJSON == nil { + b.StopTimer() + codeInit() + b.StartTimer() + } + for i := 0; i < b.N; i++ { + if _, err := Marshal(&codeStruct); err != nil { + b.Fatal("Marshal:", err) + } + } + b.SetBytes(int64(len(codeJSON))) +} + +func BenchmarkCodeDecoder(b *testing.B) { + if codeJSON == nil { + b.StopTimer() + codeInit() + b.StartTimer() + } + var buf bytes.Buffer + dec := NewDecoder(&buf) + var r codeResponse + for i := 0; i < b.N; i++ { + buf.Write(codeJSON) + // hide EOF + buf.WriteByte('\n') + buf.WriteByte('\n') + buf.WriteByte('\n') + if err := dec.Decode(&r); err != nil { + b.Fatal("Decode:", err) + } + } + b.SetBytes(int64(len(codeJSON))) +} + +func BenchmarkCodeUnmarshal(b *testing.B) { + if codeJSON == nil { + b.StopTimer() + codeInit() + b.StartTimer() + } + for i := 0; i < b.N; i++ { + var r codeResponse + if err := Unmarshal(codeJSON, &r); err != nil { + b.Fatal("Unmmarshal:", err) + } + } + b.SetBytes(int64(len(codeJSON))) +} + +func BenchmarkCodeUnmarshalReuse(b *testing.B) { + if codeJSON == nil { + b.StopTimer() + codeInit() + b.StartTimer() + } + var r codeResponse + for i := 0; i < b.N; i++ { + if err := Unmarshal(codeJSON, &r); err != nil { + b.Fatal("Unmmarshal:", err) + } + } +} + +func BenchmarkUnmarshalString(b *testing.B) { + data := []byte(`"hello, world"`) + var s string + + for i := 0; i < b.N; i++ { + if err := Unmarshal(data, &s); err != nil { + b.Fatal("Unmarshal:", err) + } + } +} + +func BenchmarkUnmarshalFloat64(b *testing.B) { + var f float64 + data := []byte(`3.14`) + + for i := 0; i < b.N; i++ { + if err := Unmarshal(data, &f); err != nil { + b.Fatal("Unmarshal:", err) + } + } +} + +func BenchmarkUnmarshalInt64(b *testing.B) { + var x int64 + data := []byte(`3`) + + for i := 0; i < b.N; i++ { + if err := Unmarshal(data, &x); err != nil { + b.Fatal("Unmarshal:", err) + } + } +} diff --git a/src/encoding/json/decode.go b/src/encoding/json/decode.go new file mode 100644 index 000000000..705bc2e17 --- /dev/null +++ b/src/encoding/json/decode.go @@ -0,0 +1,1084 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps. + +package json + +import ( + "bytes" + "encoding" + "encoding/base64" + "errors" + "fmt" + "reflect" + "runtime" + "strconv" + "unicode" + "unicode/utf16" + "unicode/utf8" +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. +// +// Unmarshal uses the inverse of the encodings that +// Marshal uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal JSON into a pointer, Unmarshal first handles the case of +// the JSON being the JSON literal null. In that case, Unmarshal sets +// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into +// the value pointed at by the pointer. If the pointer is nil, Unmarshal +// allocates a new value for it to point to. +// +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by Marshal (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. +// +// To unmarshal JSON into an interface value, +// Unmarshal stores one of these in the interface value: +// +// bool, for JSON booleans +// float64, for JSON numbers +// string, for JSON strings +// []interface{}, for JSON arrays +// map[string]interface{}, for JSON objects +// nil for JSON null +// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshalling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an UnmarshalTypeError describing the earliest such error. +// +// The JSON null value unmarshals into an interface, map, pointer, or slice +// by setting that Go value to nil. Because null is often used in JSON to mean +// ``not present,'' unmarshaling a JSON null into any other Go type has no effect +// on the value and produces no error. +// +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +// +func Unmarshal(data []byte, v interface{}) error { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + var d decodeState + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + d.init(data) + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by objects +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid encoding of +// a JSON value. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +type Unmarshaler interface { + UnmarshalJSON([]byte) error +} + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. +type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to +} + +func (e *UnmarshalTypeError) Error() string { + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +// (No longer used; kept for compatibility.) +type UnmarshalFieldError struct { + Key string + Type reflect.Type + Field reflect.StructField +} + +func (e *UnmarshalFieldError) Error() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal. +// (The argument to Unmarshal must be a non-nil pointer.) +type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Ptr { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(error) + } + }() + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} + } + + d.scan.reset() + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + d.value(rv) + return d.savedError +} + +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + +// decodeState represents the state while decoding a JSON value. +type decodeState struct { + data []byte + off int // read offset in data + scan scanner + nextscan scanner // for calls to nextValue + savedError error + useNumber bool +} + +// errPhase is used for errors that should not happen unless +// there is a bug in the JSON decoder or something is editing +// the data slice while the decoder executes. +var errPhase = errors.New("JSON decoder out of sync - data changing underfoot?") + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + return d +} + +// error aborts the decoding by panicking with err. +func (d *decodeState) error(err error) { + panic(err) +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. +func (d *decodeState) saveError(err error) { + if d.savedError == nil { + d.savedError = err + } +} + +// next cuts off and returns the next full JSON value in d.data[d.off:]. +// The next value is known to be an object or array, not a literal. +func (d *decodeState) next() []byte { + c := d.data[d.off] + item, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // Our scanner has seen the opening brace/bracket + // and thinks we're still in the middle of the object. + // invent a closing brace/bracket to get it out. + if c == '{' { + d.scan.step(&d.scan, '}') + } else { + d.scan.step(&d.scan, ']') + } + + return item +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. +// It updates d.off and returns the new scan code. +func (d *decodeState) scanWhile(op int) int { + var newOp int + for { + if d.off >= len(d.data) { + newOp = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } else { + c := int(d.data[d.off]) + d.off++ + newOp = d.scan.step(&d.scan, c) + } + if newOp != op { + break + } + } + return newOp +} + +// value decodes a JSON value from d.data[d.off:] into the value. +// it updates d.off to point past the decoded value. +func (d *decodeState) value(v reflect.Value) { + if !v.IsValid() { + _, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // d.scan thinks we're still at the beginning of the item. + // Feed in an empty string - the shortest, simplest value - + // so that it knows we got to the end of the value. + if d.scan.redo { + // rewind. + d.scan.redo = false + d.scan.step = stateBeginValue + } + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + + n := len(d.scan.parseState) + if n > 0 && d.scan.parseState[n-1] == parseObjectKey { + // d.scan thinks we just read an object key; finish the object + d.scan.step(&d.scan, ':') + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '}') + } + + return + } + + switch op := d.scanWhile(scanSkipSpace); op { + default: + d.error(errPhase) + + case scanBeginArray: + d.array(v) + + case scanBeginObject: + d.object(v) + + case scanBeginLiteral: + d.literal(v) + } +} + +type unquotedValue struct{} + +// valueQuoted is like value but decodes a +// quoted string literal or literal null into an interface value. +// If it finds anything other than a quoted string literal or null, +// valueQuoted returns unquotedValue{}. +func (d *decodeState) valueQuoted() interface{} { + switch op := d.scanWhile(scanSkipSpace); op { + default: + d.error(errPhase) + + case scanBeginArray: + d.array(reflect.Value{}) + + case scanBeginObject: + d.object(reflect.Value{}) + + case scanBeginLiteral: + switch v := d.literalInterface().(type) { + case nil, string: + return v + } + } + return unquotedValue{} +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// if it encounters an Unmarshaler, indirect stops and returns that. +// if decodingNull is true, indirect stops at the last pointer so it can be set to nil. +func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { + v = e + continue + } + } + + if v.Kind() != reflect.Ptr { + break + } + + if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 { + if u, ok := v.Interface().(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} + } + } + v = v.Elem() + } + return nil, nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into the value v. +// the first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) { + // Check for unmarshaler. + u, ut, pv := d.indirect(v, false) + if u != nil { + d.off-- + err := u.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + d.saveError(&UnmarshalTypeError{"array", v.Type()}) + d.off-- + d.next() + return + } + + v = pv + + // Check type of target. + switch v.Kind() { + case reflect.Interface: + if v.NumMethod() == 0 { + // Decoding into nil interface? Switch to non-reflect code. + v.Set(reflect.ValueOf(d.arrayInterface())) + return + } + // Otherwise it's invalid. + fallthrough + default: + d.saveError(&UnmarshalTypeError{"array", v.Type()}) + d.off-- + d.next() + return + case reflect.Array: + case reflect.Slice: + break + } + + i := 0 + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + // Get element of array, growing if necessary. + if v.Kind() == reflect.Slice { + // Grow slice if necessary + if i >= v.Cap() { + newcap := v.Cap() + v.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newv := reflect.MakeSlice(v.Type(), v.Len(), newcap) + reflect.Copy(newv, v) + v.Set(newv) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + // Decode into element. + d.value(v.Index(i)) + } else { + // Ran out of fixed array: skip. + d.value(reflect.Value{}) + } + i++ + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + + if i < v.Len() { + if v.Kind() == reflect.Array { + // Array. Zero the rest. + z := reflect.Zero(v.Type().Elem()) + for ; i < v.Len(); i++ { + v.Index(i).Set(z) + } + } else { + v.SetLen(i) + } + } + if i == 0 && v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } +} + +var nullLiteral = []byte("null") + +// object consumes an object from d.data[d.off-1:], decoding into the value v. +// the first byte ('{') of the object has been read already. +func (d *decodeState) object(v reflect.Value) { + // Check for unmarshaler. + u, ut, pv := d.indirect(v, false) + if u != nil { + d.off-- + err := u.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + d.saveError(&UnmarshalTypeError{"object", v.Type()}) + d.off-- + d.next() // skip over { } in input + return + } + v = pv + + // Decoding into nil interface? Switch to non-reflect code. + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + v.Set(reflect.ValueOf(d.objectInterface())) + return + } + + // Check type of target: struct or map[string]T + switch v.Kind() { + case reflect.Map: + // map must have string kind + t := v.Type() + if t.Key().Kind() != reflect.String { + d.saveError(&UnmarshalTypeError{"object", v.Type()}) + d.off-- + d.next() // skip over { } in input + return + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + + default: + d.saveError(&UnmarshalTypeError{"object", v.Type()}) + d.off-- + d.next() // skip over { } in input + return + } + + var mapElem reflect.Value + + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquoteBytes(item) + if !ok { + d.error(errPhase) + } + + // Figure out field corresponding to key. + var subv reflect.Value + destring := false // whether the value is wrapped in a string to be decoded first + + if v.Kind() == reflect.Map { + elemType := v.Type().Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.Set(reflect.Zero(elemType)) + } + subv = mapElem + } else { + var f *field + fields := cachedTypeFields(v.Type()) + for i := range fields { + ff := &fields[i] + if bytes.Equal(ff.nameBytes, key) { + f = ff + break + } + if f == nil && ff.equalFold(ff.nameBytes, key) { + f = ff + } + } + if f != nil { + subv = v + destring = f.quoted + for _, i := range f.index { + if subv.Kind() == reflect.Ptr { + if subv.IsNil() { + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + subv = subv.Field(i) + } + } + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + if destring { + switch qv := d.valueQuoted().(type) { + case nil: + d.literalStore(nullLiteral, subv, false) + case string: + d.literalStore([]byte(qv), subv, true) + default: + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", item, v.Type())) + } + } else { + d.value(subv) + } + + // Write value back to map; + // if using struct, subv points into struct already. + if v.Kind() == reflect.Map { + kv := reflect.ValueOf(key).Convert(v.Type().Key()) + v.SetMapIndex(kv, subv) + } + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } +} + +// literal consumes a literal from d.data[d.off-1:], decoding into the value v. +// The first byte of the literal has been read already +// (that's how the caller knows it's a literal). +func (d *decodeState) literal(v reflect.Value) { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + + d.literalStore(d.data[start:d.off], v, false) +} + +// convertNumber converts the number literal s to a float64 or a Number +// depending on the setting of d.useNumber. +func (d *decodeState) convertNumber(s string) (interface{}, error) { + if d.useNumber { + return Number(s), nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, &UnmarshalTypeError{"number " + s, reflect.TypeOf(0.0)} + } + return f, nil +} + +var numberType = reflect.TypeOf(Number("")) + +// literalStore decodes a literal stored in item into v. +// +// fromQuoted indicates whether this literal came from unwrapping a +// string from the ",string" struct tag option. this is used only to +// produce more helpful error messages. +func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) { + // Check for unmarshaler. + if len(item) == 0 { + //Empty string given + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return + } + wantptr := item[0] == 'n' // null + u, ut, pv := d.indirect(v, wantptr) + if u != nil { + err := u.UnmarshalJSON(item) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + if item[0] != '"' { + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{"string", v.Type()}) + } + } + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + err := ut.UnmarshalText(s) + if err != nil { + d.error(err) + } + return + } + + v = pv + + switch c := item[0]; c { + case 'n': // null + switch v.Kind() { + case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: + v.Set(reflect.Zero(v.Type())) + // otherwise, ignore null for primitives/string + } + case 't', 'f': // true, false + value := c == 't' + switch v.Kind() { + default: + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{"bool", v.Type()}) + } + case reflect.Bool: + v.SetBool(value) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(value)) + } else { + d.saveError(&UnmarshalTypeError{"bool", v.Type()}) + } + } + + case '"': // string + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{"string", v.Type()}) + case reflect.Slice: + if v.Type() != byteSliceType { + d.saveError(&UnmarshalTypeError{"string", v.Type()}) + break + } + b := make([]byte, base64.StdEncoding.DecodedLen(len(s))) + n, err := base64.StdEncoding.Decode(b, s) + if err != nil { + d.saveError(err) + break + } + v.Set(reflect.ValueOf(b[0:n])) + case reflect.String: + v.SetString(string(s)) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(string(s))) + } else { + d.saveError(&UnmarshalTypeError{"string", v.Type()}) + } + } + + default: // number + if c != '-' && (c < '0' || c > '9') { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + s := string(item) + switch v.Kind() { + default: + if v.Kind() == reflect.String && v.Type() == numberType { + v.SetString(s) + break + } + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(&UnmarshalTypeError{"number", v.Type()}) + } + case reflect.Interface: + n, err := d.convertNumber(s) + if err != nil { + d.saveError(err) + break + } + if v.NumMethod() != 0 { + d.saveError(&UnmarshalTypeError{"number", v.Type()}) + break + } + v.Set(reflect.ValueOf(n)) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(s, 10, 64) + if err != nil || v.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.SetInt(n) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(s, 10, 64) + if err != nil || v.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.SetUint(n) + + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(s, v.Type().Bits()) + if err != nil || v.OverflowFloat(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.SetFloat(n) + } + } +} + +// The xxxInterface routines build up a value to be stored +// in an empty interface. They are not strictly necessary, +// but they avoid the weight of reflection in this common case. + +// valueInterface is like value but returns interface{} +func (d *decodeState) valueInterface() interface{} { + switch d.scanWhile(scanSkipSpace) { + default: + d.error(errPhase) + panic("unreachable") + case scanBeginArray: + return d.arrayInterface() + case scanBeginObject: + return d.objectInterface() + case scanBeginLiteral: + return d.literalInterface() + } +} + +// arrayInterface is like array but returns []interface{}. +func (d *decodeState) arrayInterface() []interface{} { + var v = make([]interface{}, 0) + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + v = append(v, d.valueInterface()) + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + return v +} + +// objectInterface is like object but returns map[string]interface{}. +func (d *decodeState) objectInterface() map[string]interface{} { + m := make(map[string]interface{}) + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read string key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + m[key] = d.valueInterface() + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } + return m +} + +// literalInterface is like literal but returns an interface value. +func (d *decodeState) literalInterface() interface{} { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + item := d.data[start:d.off] + + switch c := item[0]; c { + case 'n': // null + return nil + + case 't', 'f': // true, false + return c == 't' + + case '"': // string + s, ok := unquote(item) + if !ok { + d.error(errPhase) + } + return s + + default: // number + if c != '-' && (c < '0' || c > '9') { + d.error(errPhase) + } + n, err := d.convertNumber(string(item)) + if err != nil { + d.saveError(err) + } + return n + } +} + +// getu4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getu4(s []byte) rune { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + r, err := strconv.ParseUint(string(s[2:6]), 16, 64) + if err != nil { + return -1 + } + return rune(r) +} + +// unquote converts a quoted JSON string literal s into an actual string t. +// The rules are different than for Go, so cannot use strconv.Unquote. +func unquote(s []byte) (t string, ok bool) { + s, ok = unquoteBytes(s) + t = string(s) + return +} + +func unquoteBytes(s []byte) (t []byte, ok bool) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return + } + s = s[1 : len(s)-1] + + // Check for unusual characters. If there are none, + // then no unquoting is needed, so return a slice of the + // original bytes. + r := 0 + for r < len(s) { + c := s[r] + if c == '\\' || c == '"' || c < ' ' { + break + } + if c < utf8.RuneSelf { + r++ + continue + } + rr, size := utf8.DecodeRune(s[r:]) + if rr == utf8.RuneError && size == 1 { + break + } + r += size + } + if r == len(s) { + return s, true + } + + b := make([]byte, len(s)+2*utf8.UTFMax) + w := copy(b, s[0:r]) + for r < len(s) { + // Out of room? Can only happen if s is full of + // malformed UTF-8 and we're replacing each + // byte with RuneError. + if w >= len(b)-2*utf8.UTFMax { + nb := make([]byte, (len(b)+utf8.UTFMax)*2) + copy(nb, b[0:w]) + b = nb + } + switch c := s[r]; { + case c == '\\': + r++ + if r >= len(s) { + return + } + switch s[r] { + default: + return + case '"', '\\', '/', '\'': + b[w] = s[r] + r++ + w++ + case 'b': + b[w] = '\b' + r++ + w++ + case 'f': + b[w] = '\f' + r++ + w++ + case 'n': + b[w] = '\n' + r++ + w++ + case 'r': + b[w] = '\r' + r++ + w++ + case 't': + b[w] = '\t' + r++ + w++ + case 'u': + r-- + rr := getu4(s[r:]) + if rr < 0 { + return + } + r += 6 + if utf16.IsSurrogate(rr) { + rr1 := getu4(s[r:]) + if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar { + // A valid pair; consume. + r += 6 + w += utf8.EncodeRune(b[w:], dec) + break + } + // Invalid surrogate; fall back to replacement rune. + rr = unicode.ReplacementChar + } + w += utf8.EncodeRune(b[w:], rr) + } + + // Quote, control characters are invalid. + case c == '"', c < ' ': + return + + // ASCII + case c < utf8.RuneSelf: + b[w] = c + r++ + w++ + + // Coerce to well-formed UTF-8. + default: + rr, size := utf8.DecodeRune(s[r:]) + r += size + w += utf8.EncodeRune(b[w:], rr) + } + } + return b[0:w], true +} diff --git a/src/encoding/json/decode_test.go b/src/encoding/json/decode_test.go new file mode 100644 index 000000000..7235969b9 --- /dev/null +++ b/src/encoding/json/decode_test.go @@ -0,0 +1,1373 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import ( + "bytes" + "encoding" + "fmt" + "image" + "reflect" + "strings" + "testing" + "time" +) + +type T struct { + X string + Y int + Z int `json:"-"` +} + +type U struct { + Alphabet string `json:"alpha"` +} + +type V struct { + F1 interface{} + F2 int32 + F3 Number +} + +// ifaceNumAsFloat64/ifaceNumAsNumber are used to test unmarshaling with and +// without UseNumber +var ifaceNumAsFloat64 = map[string]interface{}{ + "k1": float64(1), + "k2": "s", + "k3": []interface{}{float64(1), float64(2.0), float64(3e-3)}, + "k4": map[string]interface{}{"kk1": "s", "kk2": float64(2)}, +} + +var ifaceNumAsNumber = map[string]interface{}{ + "k1": Number("1"), + "k2": "s", + "k3": []interface{}{Number("1"), Number("2.0"), Number("3e-3")}, + "k4": map[string]interface{}{"kk1": "s", "kk2": Number("2")}, +} + +type tx struct { + x int +} + +// A type that can unmarshal itself. + +type unmarshaler struct { + T bool +} + +func (u *unmarshaler) UnmarshalJSON(b []byte) error { + *u = unmarshaler{true} // All we need to see that UnmarshalJSON is called. + return nil +} + +type ustruct struct { + M unmarshaler +} + +type unmarshalerText struct { + T bool +} + +// needed for re-marshaling tests +func (u *unmarshalerText) MarshalText() ([]byte, error) { + return []byte(""), nil +} + +func (u *unmarshalerText) UnmarshalText(b []byte) error { + *u = unmarshalerText{true} // All we need to see that UnmarshalText is called. + return nil +} + +var _ encoding.TextUnmarshaler = (*unmarshalerText)(nil) + +type ustructText struct { + M unmarshalerText +} + +var ( + um0, um1 unmarshaler // target2 of unmarshaling + ump = &um1 + umtrue = unmarshaler{true} + umslice = []unmarshaler{{true}} + umslicep = new([]unmarshaler) + umstruct = ustruct{unmarshaler{true}} + + um0T, um1T unmarshalerText // target2 of unmarshaling + umpT = &um1T + umtrueT = unmarshalerText{true} + umsliceT = []unmarshalerText{{true}} + umslicepT = new([]unmarshalerText) + umstructT = ustructText{unmarshalerText{true}} +) + +// Test data structures for anonymous fields. + +type Point struct { + Z int +} + +type Top struct { + Level0 int + Embed0 + *Embed0a + *Embed0b `json:"e,omitempty"` // treated as named + Embed0c `json:"-"` // ignored + Loop + Embed0p // has Point with X, Y, used + Embed0q // has Point with Z, used +} + +type Embed0 struct { + Level1a int // overridden by Embed0a's Level1a with json tag + Level1b int // used because Embed0a's Level1b is renamed + Level1c int // used because Embed0a's Level1c is ignored + Level1d int // annihilated by Embed0a's Level1d + Level1e int `json:"x"` // annihilated by Embed0a.Level1e +} + +type Embed0a struct { + Level1a int `json:"Level1a,omitempty"` + Level1b int `json:"LEVEL1B,omitempty"` + Level1c int `json:"-"` + Level1d int // annihilated by Embed0's Level1d + Level1f int `json:"x"` // annihilated by Embed0's Level1e +} + +type Embed0b Embed0 + +type Embed0c Embed0 + +type Embed0p struct { + image.Point +} + +type Embed0q struct { + Point +} + +type Loop struct { + Loop1 int `json:",omitempty"` + Loop2 int `json:",omitempty"` + *Loop +} + +// From reflect test: +// The X in S6 and S7 annihilate, but they also block the X in S8.S9. +type S5 struct { + S6 + S7 + S8 +} + +type S6 struct { + X int +} + +type S7 S6 + +type S8 struct { + S9 +} + +type S9 struct { + X int + Y int +} + +// From reflect test: +// The X in S11.S6 and S12.S6 annihilate, but they also block the X in S13.S8.S9. +type S10 struct { + S11 + S12 + S13 +} + +type S11 struct { + S6 +} + +type S12 struct { + S6 +} + +type S13 struct { + S8 +} + +type unmarshalTest struct { + in string + ptr interface{} + out interface{} + err error + useNumber bool +} + +type Ambig struct { + // Given "hello", the first match should win. + First int `json:"HELLO"` + Second int `json:"Hello"` +} + +type XYZ struct { + X interface{} + Y interface{} + Z interface{} +} + +var unmarshalTests = []unmarshalTest{ + // basic types + {in: `true`, ptr: new(bool), out: true}, + {in: `1`, ptr: new(int), out: 1}, + {in: `1.2`, ptr: new(float64), out: 1.2}, + {in: `-5`, ptr: new(int16), out: int16(-5)}, + {in: `2`, ptr: new(Number), out: Number("2"), useNumber: true}, + {in: `2`, ptr: new(Number), out: Number("2")}, + {in: `2`, ptr: new(interface{}), out: float64(2.0)}, + {in: `2`, ptr: new(interface{}), out: Number("2"), useNumber: true}, + {in: `"a\u1234"`, ptr: new(string), out: "a\u1234"}, + {in: `"http:\/\/"`, ptr: new(string), out: "http://"}, + {in: `"g-clef: \uD834\uDD1E"`, ptr: new(string), out: "g-clef: \U0001D11E"}, + {in: `"invalid: \uD834x\uDD1E"`, ptr: new(string), out: "invalid: \uFFFDx\uFFFD"}, + {in: "null", ptr: new(interface{}), out: nil}, + {in: `{"X": [1,2,3], "Y": 4}`, ptr: new(T), out: T{Y: 4}, err: &UnmarshalTypeError{"array", reflect.TypeOf("")}}, + {in: `{"x": 1}`, ptr: new(tx), out: tx{}}, + {in: `{"F1":1,"F2":2,"F3":3}`, ptr: new(V), out: V{F1: float64(1), F2: int32(2), F3: Number("3")}}, + {in: `{"F1":1,"F2":2,"F3":3}`, ptr: new(V), out: V{F1: Number("1"), F2: int32(2), F3: Number("3")}, useNumber: true}, + {in: `{"k1":1,"k2":"s","k3":[1,2.0,3e-3],"k4":{"kk1":"s","kk2":2}}`, ptr: new(interface{}), out: ifaceNumAsFloat64}, + {in: `{"k1":1,"k2":"s","k3":[1,2.0,3e-3],"k4":{"kk1":"s","kk2":2}}`, ptr: new(interface{}), out: ifaceNumAsNumber, useNumber: true}, + + // raw values with whitespace + {in: "\n true ", ptr: new(bool), out: true}, + {in: "\t 1 ", ptr: new(int), out: 1}, + {in: "\r 1.2 ", ptr: new(float64), out: 1.2}, + {in: "\t -5 \n", ptr: new(int16), out: int16(-5)}, + {in: "\t \"a\\u1234\" \n", ptr: new(string), out: "a\u1234"}, + + // Z has a "-" tag. + {in: `{"Y": 1, "Z": 2}`, ptr: new(T), out: T{Y: 1}}, + + {in: `{"alpha": "abc", "alphabet": "xyz"}`, ptr: new(U), out: U{Alphabet: "abc"}}, + {in: `{"alpha": "abc"}`, ptr: new(U), out: U{Alphabet: "abc"}}, + {in: `{"alphabet": "xyz"}`, ptr: new(U), out: U{}}, + + // syntax errors + {in: `{"X": "foo", "Y"}`, err: &SyntaxError{"invalid character '}' after object key", 17}}, + {in: `[1, 2, 3+]`, err: &SyntaxError{"invalid character '+' after array element", 9}}, + {in: `{"X":12x}`, err: &SyntaxError{"invalid character 'x' after object key:value pair", 8}, useNumber: true}, + + // raw value errors + {in: "\x01 42", err: &SyntaxError{"invalid character '\\x01' looking for beginning of value", 1}}, + {in: " 42 \x01", err: &SyntaxError{"invalid character '\\x01' after top-level value", 5}}, + {in: "\x01 true", err: &SyntaxError{"invalid character '\\x01' looking for beginning of value", 1}}, + {in: " false \x01", err: &SyntaxError{"invalid character '\\x01' after top-level value", 8}}, + {in: "\x01 1.2", err: &SyntaxError{"invalid character '\\x01' looking for beginning of value", 1}}, + {in: " 3.4 \x01", err: &SyntaxError{"invalid character '\\x01' after top-level value", 6}}, + {in: "\x01 \"string\"", err: &SyntaxError{"invalid character '\\x01' looking for beginning of value", 1}}, + {in: " \"string\" \x01", err: &SyntaxError{"invalid character '\\x01' after top-level value", 11}}, + + // array tests + {in: `[1, 2, 3]`, ptr: new([3]int), out: [3]int{1, 2, 3}}, + {in: `[1, 2, 3]`, ptr: new([1]int), out: [1]int{1}}, + {in: `[1, 2, 3]`, ptr: new([5]int), out: [5]int{1, 2, 3, 0, 0}}, + + // empty array to interface test + {in: `[]`, ptr: new([]interface{}), out: []interface{}{}}, + {in: `null`, ptr: new([]interface{}), out: []interface{}(nil)}, + {in: `{"T":[]}`, ptr: new(map[string]interface{}), out: map[string]interface{}{"T": []interface{}{}}}, + {in: `{"T":null}`, ptr: new(map[string]interface{}), out: map[string]interface{}{"T": interface{}(nil)}}, + + // composite tests + {in: allValueIndent, ptr: new(All), out: allValue}, + {in: allValueCompact, ptr: new(All), out: allValue}, + {in: allValueIndent, ptr: new(*All), out: &allValue}, + {in: allValueCompact, ptr: new(*All), out: &allValue}, + {in: pallValueIndent, ptr: new(All), out: pallValue}, + {in: pallValueCompact, ptr: new(All), out: pallValue}, + {in: pallValueIndent, ptr: new(*All), out: &pallValue}, + {in: pallValueCompact, ptr: new(*All), out: &pallValue}, + + // unmarshal interface test + {in: `{"T":false}`, ptr: &um0, out: umtrue}, // use "false" so test will fail if custom unmarshaler is not called + {in: `{"T":false}`, ptr: &ump, out: &umtrue}, + {in: `[{"T":false}]`, ptr: &umslice, out: umslice}, + {in: `[{"T":false}]`, ptr: &umslicep, out: &umslice}, + {in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct}, + + // UnmarshalText interface test + {in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called + {in: `"X"`, ptr: &umpT, out: &umtrueT}, + {in: `["X"]`, ptr: &umsliceT, out: umsliceT}, + {in: `["X"]`, ptr: &umslicepT, out: &umsliceT}, + {in: `{"M":"X"}`, ptr: &umstructT, out: umstructT}, + + { + in: `{ + "Level0": 1, + "Level1b": 2, + "Level1c": 3, + "x": 4, + "Level1a": 5, + "LEVEL1B": 6, + "e": { + "Level1a": 8, + "Level1b": 9, + "Level1c": 10, + "Level1d": 11, + "x": 12 + }, + "Loop1": 13, + "Loop2": 14, + "X": 15, + "Y": 16, + "Z": 17 + }`, + ptr: new(Top), + out: Top{ + Level0: 1, + Embed0: Embed0{ + Level1b: 2, + Level1c: 3, + }, + Embed0a: &Embed0a{ + Level1a: 5, + Level1b: 6, + }, + Embed0b: &Embed0b{ + Level1a: 8, + Level1b: 9, + Level1c: 10, + Level1d: 11, + Level1e: 12, + }, + Loop: Loop{ + Loop1: 13, + Loop2: 14, + }, + Embed0p: Embed0p{ + Point: image.Point{X: 15, Y: 16}, + }, + Embed0q: Embed0q{ + Point: Point{Z: 17}, + }, + }, + }, + { + in: `{"hello": 1}`, + ptr: new(Ambig), + out: Ambig{First: 1}, + }, + + { + in: `{"X": 1,"Y":2}`, + ptr: new(S5), + out: S5{S8: S8{S9: S9{Y: 2}}}, + }, + { + in: `{"X": 1,"Y":2}`, + ptr: new(S10), + out: S10{S13: S13{S8: S8{S9: S9{Y: 2}}}}, + }, + + // invalid UTF-8 is coerced to valid UTF-8. + { + in: "\"hello\xffworld\"", + ptr: new(string), + out: "hello\ufffdworld", + }, + { + in: "\"hello\xc2\xc2world\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\xc2\xffworld\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\\ud800world\"", + ptr: new(string), + out: "hello\ufffdworld", + }, + { + in: "\"hello\\ud800\\ud800world\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\\ud800\\ud800world\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\xed\xa0\x80\xed\xb0\x80world\"", + ptr: new(string), + out: "hello\ufffd\ufffd\ufffd\ufffd\ufffd\ufffdworld", + }, + + // issue 8305 + { + in: `{"2009-11-10T23:00:00Z": "hello world"}`, + ptr: &map[time.Time]string{}, + err: &UnmarshalTypeError{"object", reflect.TypeOf(map[time.Time]string{})}, + }, +} + +func TestMarshal(t *testing.T) { + b, err := Marshal(allValue) + if err != nil { + t.Fatalf("Marshal allValue: %v", err) + } + if string(b) != allValueCompact { + t.Errorf("Marshal allValueCompact") + diff(t, b, []byte(allValueCompact)) + return + } + + b, err = Marshal(pallValue) + if err != nil { + t.Fatalf("Marshal pallValue: %v", err) + } + if string(b) != pallValueCompact { + t.Errorf("Marshal pallValueCompact") + diff(t, b, []byte(pallValueCompact)) + return + } +} + +var badUTF8 = []struct { + in, out string +}{ + {"hello\xffworld", `"hello\ufffdworld"`}, + {"", `""`}, + {"\xff", `"\ufffd"`}, + {"\xff\xff", `"\ufffd\ufffd"`}, + {"a\xffb", `"a\ufffdb"`}, + {"\xe6\x97\xa5\xe6\x9c\xac\xff\xaa\x9e", `"日本\ufffd\ufffd\ufffd"`}, +} + +func TestMarshalBadUTF8(t *testing.T) { + for _, tt := range badUTF8 { + b, err := Marshal(tt.in) + if string(b) != tt.out || err != nil { + t.Errorf("Marshal(%q) = %#q, %v, want %#q, nil", tt.in, b, err, tt.out) + } + } +} + +func TestMarshalNumberZeroVal(t *testing.T) { + var n Number + out, err := Marshal(n) + if err != nil { + t.Fatal(err) + } + outStr := string(out) + if outStr != "0" { + t.Fatalf("Invalid zero val for Number: %q", outStr) + } +} + +func TestMarshalEmbeds(t *testing.T) { + top := &Top{ + Level0: 1, + Embed0: Embed0{ + Level1b: 2, + Level1c: 3, + }, + Embed0a: &Embed0a{ + Level1a: 5, + Level1b: 6, + }, + Embed0b: &Embed0b{ + Level1a: 8, + Level1b: 9, + Level1c: 10, + Level1d: 11, + Level1e: 12, + }, + Loop: Loop{ + Loop1: 13, + Loop2: 14, + }, + Embed0p: Embed0p{ + Point: image.Point{X: 15, Y: 16}, + }, + Embed0q: Embed0q{ + Point: Point{Z: 17}, + }, + } + b, err := Marshal(top) + if err != nil { + t.Fatal(err) + } + want := "{\"Level0\":1,\"Level1b\":2,\"Level1c\":3,\"Level1a\":5,\"LEVEL1B\":6,\"e\":{\"Level1a\":8,\"Level1b\":9,\"Level1c\":10,\"Level1d\":11,\"x\":12},\"Loop1\":13,\"Loop2\":14,\"X\":15,\"Y\":16,\"Z\":17}" + if string(b) != want { + t.Errorf("Wrong marshal result.\n got: %q\nwant: %q", b, want) + } +} + +func TestUnmarshal(t *testing.T) { + for i, tt := range unmarshalTests { + var scan scanner + in := []byte(tt.in) + if err := checkValid(in, &scan); err != nil { + if !reflect.DeepEqual(err, tt.err) { + t.Errorf("#%d: checkValid: %#v", i, err) + continue + } + } + if tt.ptr == nil { + continue + } + + // v = new(right-type) + v := reflect.New(reflect.TypeOf(tt.ptr).Elem()) + dec := NewDecoder(bytes.NewReader(in)) + if tt.useNumber { + dec.UseNumber() + } + if err := dec.Decode(v.Interface()); !reflect.DeepEqual(err, tt.err) { + t.Errorf("#%d: %v, want %v", i, err, tt.err) + continue + } else if err != nil { + continue + } + if !reflect.DeepEqual(v.Elem().Interface(), tt.out) { + t.Errorf("#%d: mismatch\nhave: %#+v\nwant: %#+v", i, v.Elem().Interface(), tt.out) + data, _ := Marshal(v.Elem().Interface()) + println(string(data)) + data, _ = Marshal(tt.out) + println(string(data)) + continue + } + + // Check round trip. + if tt.err == nil { + enc, err := Marshal(v.Interface()) + if err != nil { + t.Errorf("#%d: error re-marshaling: %v", i, err) + continue + } + vv := reflect.New(reflect.TypeOf(tt.ptr).Elem()) + dec = NewDecoder(bytes.NewReader(enc)) + if tt.useNumber { + dec.UseNumber() + } + if err := dec.Decode(vv.Interface()); err != nil { + t.Errorf("#%d: error re-unmarshaling %#q: %v", i, enc, err) + continue + } + if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) { + t.Errorf("#%d: mismatch\nhave: %#+v\nwant: %#+v", i, v.Elem().Interface(), vv.Elem().Interface()) + t.Errorf(" In: %q", strings.Map(noSpace, string(in))) + t.Errorf("Marshal: %q", strings.Map(noSpace, string(enc))) + continue + } + } + } +} + +func TestUnmarshalMarshal(t *testing.T) { + initBig() + var v interface{} + if err := Unmarshal(jsonBig, &v); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + b, err := Marshal(v) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if !bytes.Equal(jsonBig, b) { + t.Errorf("Marshal jsonBig") + diff(t, b, jsonBig) + return + } +} + +var numberTests = []struct { + in string + i int64 + intErr string + f float64 + floatErr string +}{ + {in: "-1.23e1", intErr: "strconv.ParseInt: parsing \"-1.23e1\": invalid syntax", f: -1.23e1}, + {in: "-12", i: -12, f: -12.0}, + {in: "1e1000", intErr: "strconv.ParseInt: parsing \"1e1000\": invalid syntax", floatErr: "strconv.ParseFloat: parsing \"1e1000\": value out of range"}, +} + +// Independent of Decode, basic coverage of the accessors in Number +func TestNumberAccessors(t *testing.T) { + for _, tt := range numberTests { + n := Number(tt.in) + if s := n.String(); s != tt.in { + t.Errorf("Number(%q).String() is %q", tt.in, s) + } + if i, err := n.Int64(); err == nil && tt.intErr == "" && i != tt.i { + t.Errorf("Number(%q).Int64() is %d", tt.in, i) + } else if (err == nil && tt.intErr != "") || (err != nil && err.Error() != tt.intErr) { + t.Errorf("Number(%q).Int64() wanted error %q but got: %v", tt.in, tt.intErr, err) + } + if f, err := n.Float64(); err == nil && tt.floatErr == "" && f != tt.f { + t.Errorf("Number(%q).Float64() is %g", tt.in, f) + } else if (err == nil && tt.floatErr != "") || (err != nil && err.Error() != tt.floatErr) { + t.Errorf("Number(%q).Float64() wanted error %q but got: %v", tt.in, tt.floatErr, err) + } + } +} + +func TestLargeByteSlice(t *testing.T) { + s0 := make([]byte, 2000) + for i := range s0 { + s0[i] = byte(i) + } + b, err := Marshal(s0) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var s1 []byte + if err := Unmarshal(b, &s1); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !bytes.Equal(s0, s1) { + t.Errorf("Marshal large byte slice") + diff(t, s0, s1) + } +} + +type Xint struct { + X int +} + +func TestUnmarshalInterface(t *testing.T) { + var xint Xint + var i interface{} = &xint + if err := Unmarshal([]byte(`{"X":1}`), &i); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if xint.X != 1 { + t.Fatalf("Did not write to xint") + } +} + +func TestUnmarshalPtrPtr(t *testing.T) { + var xint Xint + pxint := &xint + if err := Unmarshal([]byte(`{"X":1}`), &pxint); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if xint.X != 1 { + t.Fatalf("Did not write to xint") + } +} + +func TestEscape(t *testing.T) { + const input = `"foobar"<html>` + " [\u2028 \u2029]" + const expected = `"\"foobar\"\u003chtml\u003e [\u2028 \u2029]"` + b, err := Marshal(input) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + if s := string(b); s != expected { + t.Errorf("Encoding of [%s]:\n got [%s]\nwant [%s]", input, s, expected) + } +} + +// WrongString is a struct that's misusing the ,string modifier. +type WrongString struct { + Message string `json:"result,string"` +} + +type wrongStringTest struct { + in, err string +} + +var wrongStringTests = []wrongStringTest{ + {`{"result":"x"}`, `json: invalid use of ,string struct tag, trying to unmarshal "x" into string`}, + {`{"result":"foo"}`, `json: invalid use of ,string struct tag, trying to unmarshal "foo" into string`}, + {`{"result":"123"}`, `json: invalid use of ,string struct tag, trying to unmarshal "123" into string`}, +} + +// If people misuse the ,string modifier, the error message should be +// helpful, telling the user that they're doing it wrong. +func TestErrorMessageFromMisusedString(t *testing.T) { + for n, tt := range wrongStringTests { + r := strings.NewReader(tt.in) + var s WrongString + err := NewDecoder(r).Decode(&s) + got := fmt.Sprintf("%v", err) + if got != tt.err { + t.Errorf("%d. got err = %q, want %q", n, got, tt.err) + } + } +} + +func noSpace(c rune) rune { + if isSpace(c) { + return -1 + } + return c +} + +type All struct { + Bool bool + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Uintptr uintptr + Float32 float32 + Float64 float64 + + Foo string `json:"bar"` + Foo2 string `json:"bar2,dummyopt"` + + IntStr int64 `json:",string"` + + PBool *bool + PInt *int + PInt8 *int8 + PInt16 *int16 + PInt32 *int32 + PInt64 *int64 + PUint *uint + PUint8 *uint8 + PUint16 *uint16 + PUint32 *uint32 + PUint64 *uint64 + PUintptr *uintptr + PFloat32 *float32 + PFloat64 *float64 + + String string + PString *string + + Map map[string]Small + MapP map[string]*Small + PMap *map[string]Small + PMapP *map[string]*Small + + EmptyMap map[string]Small + NilMap map[string]Small + + Slice []Small + SliceP []*Small + PSlice *[]Small + PSliceP *[]*Small + + EmptySlice []Small + NilSlice []Small + + StringSlice []string + ByteSlice []byte + + Small Small + PSmall *Small + PPSmall **Small + + Interface interface{} + PInterface *interface{} + + unexported int +} + +type Small struct { + Tag string +} + +var allValue = All{ + Bool: true, + Int: 2, + Int8: 3, + Int16: 4, + Int32: 5, + Int64: 6, + Uint: 7, + Uint8: 8, + Uint16: 9, + Uint32: 10, + Uint64: 11, + Uintptr: 12, + Float32: 14.1, + Float64: 15.1, + Foo: "foo", + Foo2: "foo2", + IntStr: 42, + String: "16", + Map: map[string]Small{ + "17": {Tag: "tag17"}, + "18": {Tag: "tag18"}, + }, + MapP: map[string]*Small{ + "19": {Tag: "tag19"}, + "20": nil, + }, + EmptyMap: map[string]Small{}, + Slice: []Small{{Tag: "tag20"}, {Tag: "tag21"}}, + SliceP: []*Small{{Tag: "tag22"}, nil, {Tag: "tag23"}}, + EmptySlice: []Small{}, + StringSlice: []string{"str24", "str25", "str26"}, + ByteSlice: []byte{27, 28, 29}, + Small: Small{Tag: "tag30"}, + PSmall: &Small{Tag: "tag31"}, + Interface: 5.2, +} + +var pallValue = All{ + PBool: &allValue.Bool, + PInt: &allValue.Int, + PInt8: &allValue.Int8, + PInt16: &allValue.Int16, + PInt32: &allValue.Int32, + PInt64: &allValue.Int64, + PUint: &allValue.Uint, + PUint8: &allValue.Uint8, + PUint16: &allValue.Uint16, + PUint32: &allValue.Uint32, + PUint64: &allValue.Uint64, + PUintptr: &allValue.Uintptr, + PFloat32: &allValue.Float32, + PFloat64: &allValue.Float64, + PString: &allValue.String, + PMap: &allValue.Map, + PMapP: &allValue.MapP, + PSlice: &allValue.Slice, + PSliceP: &allValue.SliceP, + PPSmall: &allValue.PSmall, + PInterface: &allValue.Interface, +} + +var allValueIndent = `{ + "Bool": true, + "Int": 2, + "Int8": 3, + "Int16": 4, + "Int32": 5, + "Int64": 6, + "Uint": 7, + "Uint8": 8, + "Uint16": 9, + "Uint32": 10, + "Uint64": 11, + "Uintptr": 12, + "Float32": 14.1, + "Float64": 15.1, + "bar": "foo", + "bar2": "foo2", + "IntStr": "42", + "PBool": null, + "PInt": null, + "PInt8": null, + "PInt16": null, + "PInt32": null, + "PInt64": null, + "PUint": null, + "PUint8": null, + "PUint16": null, + "PUint32": null, + "PUint64": null, + "PUintptr": null, + "PFloat32": null, + "PFloat64": null, + "String": "16", + "PString": null, + "Map": { + "17": { + "Tag": "tag17" + }, + "18": { + "Tag": "tag18" + } + }, + "MapP": { + "19": { + "Tag": "tag19" + }, + "20": null + }, + "PMap": null, + "PMapP": null, + "EmptyMap": {}, + "NilMap": null, + "Slice": [ + { + "Tag": "tag20" + }, + { + "Tag": "tag21" + } + ], + "SliceP": [ + { + "Tag": "tag22" + }, + null, + { + "Tag": "tag23" + } + ], + "PSlice": null, + "PSliceP": null, + "EmptySlice": [], + "NilSlice": null, + "StringSlice": [ + "str24", + "str25", + "str26" + ], + "ByteSlice": "Gxwd", + "Small": { + "Tag": "tag30" + }, + "PSmall": { + "Tag": "tag31" + }, + "PPSmall": null, + "Interface": 5.2, + "PInterface": null +}` + +var allValueCompact = strings.Map(noSpace, allValueIndent) + +var pallValueIndent = `{ + "Bool": false, + "Int": 0, + "Int8": 0, + "Int16": 0, + "Int32": 0, + "Int64": 0, + "Uint": 0, + "Uint8": 0, + "Uint16": 0, + "Uint32": 0, + "Uint64": 0, + "Uintptr": 0, + "Float32": 0, + "Float64": 0, + "bar": "", + "bar2": "", + "IntStr": "0", + "PBool": true, + "PInt": 2, + "PInt8": 3, + "PInt16": 4, + "PInt32": 5, + "PInt64": 6, + "PUint": 7, + "PUint8": 8, + "PUint16": 9, + "PUint32": 10, + "PUint64": 11, + "PUintptr": 12, + "PFloat32": 14.1, + "PFloat64": 15.1, + "String": "", + "PString": "16", + "Map": null, + "MapP": null, + "PMap": { + "17": { + "Tag": "tag17" + }, + "18": { + "Tag": "tag18" + } + }, + "PMapP": { + "19": { + "Tag": "tag19" + }, + "20": null + }, + "EmptyMap": null, + "NilMap": null, + "Slice": null, + "SliceP": null, + "PSlice": [ + { + "Tag": "tag20" + }, + { + "Tag": "tag21" + } + ], + "PSliceP": [ + { + "Tag": "tag22" + }, + null, + { + "Tag": "tag23" + } + ], + "EmptySlice": null, + "NilSlice": null, + "StringSlice": null, + "ByteSlice": null, + "Small": { + "Tag": "" + }, + "PSmall": null, + "PPSmall": { + "Tag": "tag31" + }, + "Interface": null, + "PInterface": 5.2 +}` + +var pallValueCompact = strings.Map(noSpace, pallValueIndent) + +func TestRefUnmarshal(t *testing.T) { + type S struct { + // Ref is defined in encode_test.go. + R0 Ref + R1 *Ref + R2 RefText + R3 *RefText + } + want := S{ + R0: 12, + R1: new(Ref), + R2: 13, + R3: new(RefText), + } + *want.R1 = 12 + *want.R3 = 13 + + var got S + if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref","R2":"ref","R3":"ref"}`), &got); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} + +// Test that the empty string doesn't panic decoding when ,string is specified +// Issue 3450 +func TestEmptyString(t *testing.T) { + type T2 struct { + Number1 int `json:",string"` + Number2 int `json:",string"` + } + data := `{"Number1":"1", "Number2":""}` + dec := NewDecoder(strings.NewReader(data)) + var t2 T2 + err := dec.Decode(&t2) + if err == nil { + t.Fatal("Decode: did not return error") + } + if t2.Number1 != 1 { + t.Fatal("Decode: did not set Number1") + } +} + +// Test that a null for ,string is not replaced with the previous quoted string (issue 7046). +// It should also not be an error (issue 2540, issue 8587). +func TestNullString(t *testing.T) { + type T struct { + A int `json:",string"` + B int `json:",string"` + C *int `json:",string"` + } + data := []byte(`{"A": "1", "B": null, "C": null}`) + var s T + s.B = 1 + s.C = new(int) + *s.C = 2 + err := Unmarshal(data, &s) + if err != nil { + t.Fatalf("Unmarshal: %v") + } + if s.B != 1 || s.C != nil { + t.Fatalf("after Unmarshal, s.B=%d, s.C=%p, want 1, nil", s.B, s.C) + } +} + +func intp(x int) *int { + p := new(int) + *p = x + return p +} + +func intpp(x *int) **int { + pp := new(*int) + *pp = x + return pp +} + +var interfaceSetTests = []struct { + pre interface{} + json string + post interface{} +}{ + {"foo", `"bar"`, "bar"}, + {"foo", `2`, 2.0}, + {"foo", `true`, true}, + {"foo", `null`, nil}, + + {nil, `null`, nil}, + {new(int), `null`, nil}, + {(*int)(nil), `null`, nil}, + {new(*int), `null`, new(*int)}, + {(**int)(nil), `null`, nil}, + {intp(1), `null`, nil}, + {intpp(nil), `null`, intpp(nil)}, + {intpp(intp(1)), `null`, intpp(nil)}, +} + +func TestInterfaceSet(t *testing.T) { + for _, tt := range interfaceSetTests { + b := struct{ X interface{} }{tt.pre} + blob := `{"X":` + tt.json + `}` + if err := Unmarshal([]byte(blob), &b); err != nil { + t.Errorf("Unmarshal %#q: %v", blob, err) + continue + } + if !reflect.DeepEqual(b.X, tt.post) { + t.Errorf("Unmarshal %#q into %#v: X=%#v, want %#v", blob, tt.pre, b.X, tt.post) + } + } +} + +// JSON null values should be ignored for primitives and string values instead of resulting in an error. +// Issue 2540 +func TestUnmarshalNulls(t *testing.T) { + jsonData := []byte(`{ + "Bool" : null, + "Int" : null, + "Int8" : null, + "Int16" : null, + "Int32" : null, + "Int64" : null, + "Uint" : null, + "Uint8" : null, + "Uint16" : null, + "Uint32" : null, + "Uint64" : null, + "Float32" : null, + "Float64" : null, + "String" : null}`) + + nulls := All{ + Bool: true, + Int: 2, + Int8: 3, + Int16: 4, + Int32: 5, + Int64: 6, + Uint: 7, + Uint8: 8, + Uint16: 9, + Uint32: 10, + Uint64: 11, + Float32: 12.1, + Float64: 13.1, + String: "14"} + + err := Unmarshal(jsonData, &nulls) + if err != nil { + t.Errorf("Unmarshal of null values failed: %v", err) + } + if !nulls.Bool || nulls.Int != 2 || nulls.Int8 != 3 || nulls.Int16 != 4 || nulls.Int32 != 5 || nulls.Int64 != 6 || + nulls.Uint != 7 || nulls.Uint8 != 8 || nulls.Uint16 != 9 || nulls.Uint32 != 10 || nulls.Uint64 != 11 || + nulls.Float32 != 12.1 || nulls.Float64 != 13.1 || nulls.String != "14" { + + t.Errorf("Unmarshal of null values affected primitives") + } +} + +func TestStringKind(t *testing.T) { + type stringKind string + + var m1, m2 map[stringKind]int + m1 = map[stringKind]int{ + "foo": 42, + } + + data, err := Marshal(m1) + if err != nil { + t.Errorf("Unexpected error marshalling: %v", err) + } + + err = Unmarshal(data, &m2) + if err != nil { + t.Errorf("Unexpected error unmarshalling: %v", err) + } + + if !reflect.DeepEqual(m1, m2) { + t.Error("Items should be equal after encoding and then decoding") + } + +} + +var decodeTypeErrorTests = []struct { + dest interface{} + src string +}{ + {new(string), `{"user": "name"}`}, // issue 4628. + {new(error), `{}`}, // issue 4222 + {new(error), `[]`}, + {new(error), `""`}, + {new(error), `123`}, + {new(error), `true`}, +} + +func TestUnmarshalTypeError(t *testing.T) { + for _, item := range decodeTypeErrorTests { + err := Unmarshal([]byte(item.src), item.dest) + if _, ok := err.(*UnmarshalTypeError); !ok { + t.Errorf("expected type error for Unmarshal(%q, type %T): got %T", + item.src, item.dest, err) + } + } +} + +var unmarshalSyntaxTests = []string{ + "tru", + "fals", + "nul", + "123e", + `"hello`, + `[1,2,3`, + `{"key":1`, + `{"key":1,`, +} + +func TestUnmarshalSyntax(t *testing.T) { + var x interface{} + for _, src := range unmarshalSyntaxTests { + err := Unmarshal([]byte(src), &x) + if _, ok := err.(*SyntaxError); !ok { + t.Errorf("expected syntax error for Unmarshal(%q): got %T", src, err) + } + } +} + +// Test handling of unexported fields that should be ignored. +// Issue 4660 +type unexportedFields struct { + Name string + m map[string]interface{} `json:"-"` + m2 map[string]interface{} `json:"abcd"` +} + +func TestUnmarshalUnexported(t *testing.T) { + input := `{"Name": "Bob", "m": {"x": 123}, "m2": {"y": 456}, "abcd": {"z": 789}}` + want := &unexportedFields{Name: "Bob"} + + out := &unexportedFields{} + err := Unmarshal([]byte(input), out) + if err != nil { + t.Errorf("got error %v, expected nil", err) + } + if !reflect.DeepEqual(out, want) { + t.Errorf("got %q, want %q", out, want) + } +} + +// Time3339 is a time.Time which encodes to and from JSON +// as an RFC 3339 time in UTC. +type Time3339 time.Time + +func (t *Time3339) UnmarshalJSON(b []byte) error { + if len(b) < 2 || b[0] != '"' || b[len(b)-1] != '"' { + return fmt.Errorf("types: failed to unmarshal non-string value %q as an RFC 3339 time", b) + } + tm, err := time.Parse(time.RFC3339, string(b[1:len(b)-1])) + if err != nil { + return err + } + *t = Time3339(tm) + return nil +} + +func TestUnmarshalJSONLiteralError(t *testing.T) { + var t3 Time3339 + err := Unmarshal([]byte(`"0000-00-00T00:00:00Z"`), &t3) + if err == nil { + t.Fatalf("expected error; got time %v", time.Time(t3)) + } + if !strings.Contains(err.Error(), "range") { + t.Errorf("got err = %v; want out of range error", err) + } +} + +// Test that extra object elements in an array do not result in a +// "data changing underfoot" error. +// Issue 3717 +func TestSkipArrayObjects(t *testing.T) { + json := `[{}]` + var dest [0]interface{} + + err := Unmarshal([]byte(json), &dest) + if err != nil { + t.Errorf("got error %q, want nil", err) + } +} + +// Test semantics of pre-filled struct fields and pre-filled map fields. +// Issue 4900. +func TestPrefilled(t *testing.T) { + ptrToMap := func(m map[string]interface{}) *map[string]interface{} { return &m } + + // Values here change, cannot reuse table across runs. + var prefillTests = []struct { + in string + ptr interface{} + out interface{} + }{ + { + in: `{"X": 1, "Y": 2}`, + ptr: &XYZ{X: float32(3), Y: int16(4), Z: 1.5}, + out: &XYZ{X: float64(1), Y: float64(2), Z: 1.5}, + }, + { + in: `{"X": 1, "Y": 2}`, + ptr: ptrToMap(map[string]interface{}{"X": float32(3), "Y": int16(4), "Z": 1.5}), + out: ptrToMap(map[string]interface{}{"X": float64(1), "Y": float64(2), "Z": 1.5}), + }, + } + + for _, tt := range prefillTests { + ptrstr := fmt.Sprintf("%v", tt.ptr) + err := Unmarshal([]byte(tt.in), tt.ptr) // tt.ptr edited here + if err != nil { + t.Errorf("Unmarshal: %v", err) + } + if !reflect.DeepEqual(tt.ptr, tt.out) { + t.Errorf("Unmarshal(%#q, %s): have %v, want %v", tt.in, ptrstr, tt.ptr, tt.out) + } + } +} + +var invalidUnmarshalTests = []struct { + v interface{} + want string +}{ + {nil, "json: Unmarshal(nil)"}, + {struct{}{}, "json: Unmarshal(non-pointer struct {})"}, + {(*int)(nil), "json: Unmarshal(nil *int)"}, +} + +func TestInvalidUnmarshal(t *testing.T) { + buf := []byte(`{"a":"1"}`) + for _, tt := range invalidUnmarshalTests { + err := Unmarshal(buf, tt.v) + if err == nil { + t.Errorf("Unmarshal expecting error, got nil") + continue + } + if got := err.Error(); got != tt.want { + t.Errorf("Unmarshal = %q; want %q", got, tt.want) + } + } +} diff --git a/src/encoding/json/encode.go b/src/encoding/json/encode.go new file mode 100644 index 000000000..fca2a0980 --- /dev/null +++ b/src/encoding/json/encode.go @@ -0,0 +1,1183 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package json implements encoding and decoding of JSON objects as defined in +// RFC 4627. The mapping between JSON objects and Go values is described +// in the documentation for the Marshal and Unmarshal functions. +// +// See "JSON and Go" for an introduction to this package: +// http://golang.org/doc/articles/json_and_go.html +package json + +import ( + "bytes" + "encoding" + "encoding/base64" + "math" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements the Marshaler interface +// and is not a nil pointer, Marshal calls its MarshalJSON method +// to produce JSON. The nil pointer exception is not strictly necessary +// but mimics a similar, necessary exception in the behavior of +// UnmarshalJSON. +// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point, integer, and Number values encode as JSON numbers. +// +// String values encode as JSON strings coerced to valid UTF-8, +// replacing invalid bytes with the Unicode replacement rune. +// The angle brackets "<" and ">" are escaped to "\u003c" and "\u003e" +// to keep some browsers from misinterpreting JSON output as HTML. +// Ampersand "&" is also escaped to "\u0026" for the same reason. +// +// Array and slice values encode as JSON arrays, except that +// []byte encodes as a base64-encoded string, and a nil slice +// encodes as the null JSON object. +// +// Struct values encode as JSON objects. Each exported struct field +// becomes a member of the object unless +// - the field's tag is "-", or +// - the field is empty and its tag specifies the "omitempty" option. +// The empty values are false, 0, any +// nil pointer or interface value, and any array, slice, map, or string of +// length zero. The object's default key string is the struct field name +// but can be specified in the struct field's tag value. The "json" key in +// the struct field's tag value is the key name, followed by an optional comma +// and options. Examples: +// +// // Field is ignored by this package. +// Field int `json:"-"` +// +// // Field appears in JSON as key "myName". +// Field int `json:"myName"` +// +// // Field appears in JSON as key "myName" and +// // the field is omitted from the object if its value is empty, +// // as defined above. +// Field int `json:"myName,omitempty"` +// +// // Field appears in JSON as key "Field" (the default), but +// // the field is skipped if empty. +// // Note the leading comma. +// Field int `json:",omitempty"` +// +// The "string" option signals that a field is stored as JSON inside a +// JSON-encoded string. It applies only to fields of string, floating point, +// or integer types. This extra level of encoding is sometimes used when +// communicating with JavaScript programs: +// +// Int64String int64 `json:",string"` +// +// The key name will be used if it's a non-empty string consisting of +// only Unicode letters, digits, dollar signs, percent signs, hyphens, +// underscores and slashes. +// +// Anonymous struct fields are usually marshaled as if their inner exported fields +// were fields in the outer struct, subject to the usual Go visibility rules amended +// as described in the next paragraph. +// An anonymous struct field with a name given in its JSON tag is treated as +// having that name, rather than being anonymous. +// An anonymous struct field of interface type is treated the same as having +// that type as its name, rather than being anonymous. +// +// The Go visibility rules for struct fields are amended for JSON when +// deciding which field to marshal or unmarshal. If there are +// multiple fields at the same level, and that level is the least +// nested (and would therefore be the nesting level selected by the +// usual Go rules), the following extra rules apply: +// +// 1) Of those fields, if any are JSON-tagged, only tagged fields are considered, +// even if there are multiple untagged fields that would otherwise conflict. +// 2) If there is exactly one field (tagged or not according to the first rule), that is selected. +// 3) Otherwise there are multiple fields, and all are ignored; no error occurs. +// +// Handling of anonymous struct fields is new in Go 1.1. +// Prior to Go 1.1, anonymous struct fields were ignored. To force ignoring of +// an anonymous struct field in both current and earlier versions, give the field +// a JSON tag of "-". +// +// Map values encode as JSON objects. +// The map's key type must be string; the object keys are used directly +// as map keys. +// +// Pointer values encode as the value pointed to. +// A nil pointer encodes as the null JSON object. +// +// Interface values encode as the value contained in the interface. +// A nil interface value encodes as the null JSON object. +// +// Channel, complex, and function values cannot be encoded in JSON. +// Attempting to encode such a value causes Marshal to return +// an UnsupportedTypeError. +// +// JSON cannot represent cyclic data structures and Marshal does not +// handle them. Passing cyclic structures to Marshal will result in +// an infinite recursion. +// +func Marshal(v interface{}) ([]byte, error) { + e := &encodeState{} + err := e.marshal(v) + if err != nil { + return nil, err + } + return e.Bytes(), nil +} + +// MarshalIndent is like Marshal but applies Indent to format the output. +func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) { + b, err := Marshal(v) + if err != nil { + return nil, err + } + var buf bytes.Buffer + err = Indent(&buf, b, prefix, indent) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029 +// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029 +// so that the JSON will be safe to embed inside HTML <script> tags. +// For historical reasons, web browsers don't honor standard HTML +// escaping within <script> tags, so an alternative JSON encoding must +// be used. +func HTMLEscape(dst *bytes.Buffer, src []byte) { + // The characters can only appear in string literals, + // so just scan the string one byte at a time. + start := 0 + for i, c := range src { + if c == '<' || c == '>' || c == '&' { + if start < i { + dst.Write(src[start:i]) + } + dst.WriteString(`\u00`) + dst.WriteByte(hex[c>>4]) + dst.WriteByte(hex[c&0xF]) + start = i + 1 + } + // Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9). + if c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 { + if start < i { + dst.Write(src[start:i]) + } + dst.WriteString(`\u202`) + dst.WriteByte(hex[src[i+2]&0xF]) + start = i + 3 + } + } + if start < len(src) { + dst.Write(src[start:]) + } +} + +// Marshaler is the interface implemented by objects that +// can marshal themselves into valid JSON. +type Marshaler interface { + MarshalJSON() ([]byte, error) +} + +// An UnsupportedTypeError is returned by Marshal when attempting +// to encode an unsupported value type. +type UnsupportedTypeError struct { + Type reflect.Type +} + +func (e *UnsupportedTypeError) Error() string { + return "json: unsupported type: " + e.Type.String() +} + +type UnsupportedValueError struct { + Value reflect.Value + Str string +} + +func (e *UnsupportedValueError) Error() string { + return "json: unsupported value: " + e.Str +} + +// Before Go 1.2, an InvalidUTF8Error was returned by Marshal when +// attempting to encode a string value with invalid UTF-8 sequences. +// As of Go 1.2, Marshal instead coerces the string to valid UTF-8 by +// replacing invalid bytes with the Unicode replacement rune U+FFFD. +// This error is no longer generated but is kept for backwards compatibility +// with programs that might mention it. +type InvalidUTF8Error struct { + S string // the whole string value that caused the error +} + +func (e *InvalidUTF8Error) Error() string { + return "json: invalid UTF-8 in string: " + strconv.Quote(e.S) +} + +type MarshalerError struct { + Type reflect.Type + Err error +} + +func (e *MarshalerError) Error() string { + return "json: error calling MarshalJSON for type " + e.Type.String() + ": " + e.Err.Error() +} + +var hex = "0123456789abcdef" + +// An encodeState encodes JSON into a bytes.Buffer. +type encodeState struct { + bytes.Buffer // accumulated output + scratch [64]byte +} + +var encodeStatePool sync.Pool + +func newEncodeState() *encodeState { + if v := encodeStatePool.Get(); v != nil { + e := v.(*encodeState) + e.Reset() + return e + } + return new(encodeState) +} + +func (e *encodeState) marshal(v interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + if s, ok := r.(string); ok { + panic(s) + } + err = r.(error) + } + }() + e.reflectValue(reflect.ValueOf(v)) + return nil +} + +func (e *encodeState) error(err error) { + panic(err) +} + +var byteSliceType = reflect.TypeOf([]byte(nil)) + +func isEmptyValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return false +} + +func (e *encodeState) reflectValue(v reflect.Value) { + valueEncoder(v)(e, v, false) +} + +type encoderFunc func(e *encodeState, v reflect.Value, quoted bool) + +var encoderCache struct { + sync.RWMutex + m map[reflect.Type]encoderFunc +} + +func valueEncoder(v reflect.Value) encoderFunc { + if !v.IsValid() { + return invalidValueEncoder + } + return typeEncoder(v.Type()) +} + +func typeEncoder(t reflect.Type) encoderFunc { + encoderCache.RLock() + f := encoderCache.m[t] + encoderCache.RUnlock() + if f != nil { + return f + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + encoderCache.Lock() + if encoderCache.m == nil { + encoderCache.m = make(map[reflect.Type]encoderFunc) + } + var wg sync.WaitGroup + wg.Add(1) + encoderCache.m[t] = func(e *encodeState, v reflect.Value, quoted bool) { + wg.Wait() + f(e, v, quoted) + } + encoderCache.Unlock() + + // Compute fields without lock. + // Might duplicate effort but won't hold other computations back. + f = newTypeEncoder(t, true) + wg.Done() + encoderCache.Lock() + encoderCache.m[t] = f + encoderCache.Unlock() + return f +} + +var ( + marshalerType = reflect.TypeOf(new(Marshaler)).Elem() + textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() +) + +// newTypeEncoder constructs an encoderFunc for a type. +// The returned encoder only checks CanAddr when allowAddr is true. +func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { + if t.Implements(marshalerType) { + return marshalerEncoder + } + if t.Kind() != reflect.Ptr && allowAddr { + if reflect.PtrTo(t).Implements(marshalerType) { + return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) + } + } + + if t.Implements(textMarshalerType) { + return textMarshalerEncoder + } + if t.Kind() != reflect.Ptr && allowAddr { + if reflect.PtrTo(t).Implements(textMarshalerType) { + return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false)) + } + } + + switch t.Kind() { + case reflect.Bool: + return boolEncoder + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return intEncoder + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintEncoder + case reflect.Float32: + return float32Encoder + case reflect.Float64: + return float64Encoder + case reflect.String: + return stringEncoder + case reflect.Interface: + return interfaceEncoder + case reflect.Struct: + return newStructEncoder(t) + case reflect.Map: + return newMapEncoder(t) + case reflect.Slice: + return newSliceEncoder(t) + case reflect.Array: + return newArrayEncoder(t) + case reflect.Ptr: + return newPtrEncoder(t) + default: + return unsupportedTypeEncoder + } +} + +func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) { + e.WriteString("null") +} + +func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { + if v.Kind() == reflect.Ptr && v.IsNil() { + e.WriteString("null") + return + } + m := v.Interface().(Marshaler) + b, err := m.MarshalJSON() + if err == nil { + // copy JSON into buffer, checking validity. + err = compact(&e.Buffer, b, true) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err}) + } +} + +func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { + va := v.Addr() + if va.IsNil() { + e.WriteString("null") + return + } + m := va.Interface().(Marshaler) + b, err := m.MarshalJSON() + if err == nil { + // copy JSON into buffer, checking validity. + err = compact(&e.Buffer, b, true) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err}) + } +} + +func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { + if v.Kind() == reflect.Ptr && v.IsNil() { + e.WriteString("null") + return + } + m := v.Interface().(encoding.TextMarshaler) + b, err := m.MarshalText() + if err == nil { + _, err = e.stringBytes(b) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err}) + } +} + +func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { + va := v.Addr() + if va.IsNil() { + e.WriteString("null") + return + } + m := va.Interface().(encoding.TextMarshaler) + b, err := m.MarshalText() + if err == nil { + _, err = e.stringBytes(b) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err}) + } +} + +func boolEncoder(e *encodeState, v reflect.Value, quoted bool) { + if quoted { + e.WriteByte('"') + } + if v.Bool() { + e.WriteString("true") + } else { + e.WriteString("false") + } + if quoted { + e.WriteByte('"') + } +} + +func intEncoder(e *encodeState, v reflect.Value, quoted bool) { + b := strconv.AppendInt(e.scratch[:0], v.Int(), 10) + if quoted { + e.WriteByte('"') + } + e.Write(b) + if quoted { + e.WriteByte('"') + } +} + +func uintEncoder(e *encodeState, v reflect.Value, quoted bool) { + b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10) + if quoted { + e.WriteByte('"') + } + e.Write(b) + if quoted { + e.WriteByte('"') + } +} + +type floatEncoder int // number of bits + +func (bits floatEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { + f := v.Float() + if math.IsInf(f, 0) || math.IsNaN(f) { + e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, int(bits))}) + } + b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, int(bits)) + if quoted { + e.WriteByte('"') + } + e.Write(b) + if quoted { + e.WriteByte('"') + } +} + +var ( + float32Encoder = (floatEncoder(32)).encode + float64Encoder = (floatEncoder(64)).encode +) + +func stringEncoder(e *encodeState, v reflect.Value, quoted bool) { + if v.Type() == numberType { + numStr := v.String() + if numStr == "" { + numStr = "0" // Number's zero-val + } + e.WriteString(numStr) + return + } + if quoted { + sb, err := Marshal(v.String()) + if err != nil { + e.error(err) + } + e.string(string(sb)) + } else { + e.string(v.String()) + } +} + +func interfaceEncoder(e *encodeState, v reflect.Value, quoted bool) { + if v.IsNil() { + e.WriteString("null") + return + } + e.reflectValue(v.Elem()) +} + +func unsupportedTypeEncoder(e *encodeState, v reflect.Value, quoted bool) { + e.error(&UnsupportedTypeError{v.Type()}) +} + +type structEncoder struct { + fields []field + fieldEncs []encoderFunc +} + +func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { + e.WriteByte('{') + first := true + for i, f := range se.fields { + fv := fieldByIndex(v, f.index) + if !fv.IsValid() || f.omitEmpty && isEmptyValue(fv) { + continue + } + if first { + first = false + } else { + e.WriteByte(',') + } + e.string(f.name) + e.WriteByte(':') + se.fieldEncs[i](e, fv, f.quoted) + } + e.WriteByte('}') +} + +func newStructEncoder(t reflect.Type) encoderFunc { + fields := cachedTypeFields(t) + se := &structEncoder{ + fields: fields, + fieldEncs: make([]encoderFunc, len(fields)), + } + for i, f := range fields { + se.fieldEncs[i] = typeEncoder(typeByIndex(t, f.index)) + } + return se.encode +} + +type mapEncoder struct { + elemEnc encoderFunc +} + +func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) { + if v.IsNil() { + e.WriteString("null") + return + } + e.WriteByte('{') + var sv stringValues = v.MapKeys() + sort.Sort(sv) + for i, k := range sv { + if i > 0 { + e.WriteByte(',') + } + e.string(k.String()) + e.WriteByte(':') + me.elemEnc(e, v.MapIndex(k), false) + } + e.WriteByte('}') +} + +func newMapEncoder(t reflect.Type) encoderFunc { + if t.Key().Kind() != reflect.String { + return unsupportedTypeEncoder + } + me := &mapEncoder{typeEncoder(t.Elem())} + return me.encode +} + +func encodeByteSlice(e *encodeState, v reflect.Value, _ bool) { + if v.IsNil() { + e.WriteString("null") + return + } + s := v.Bytes() + e.WriteByte('"') + if len(s) < 1024 { + // for small buffers, using Encode directly is much faster. + dst := make([]byte, base64.StdEncoding.EncodedLen(len(s))) + base64.StdEncoding.Encode(dst, s) + e.Write(dst) + } else { + // for large buffers, avoid unnecessary extra temporary + // buffer space. + enc := base64.NewEncoder(base64.StdEncoding, e) + enc.Write(s) + enc.Close() + } + e.WriteByte('"') +} + +// sliceEncoder just wraps an arrayEncoder, checking to make sure the value isn't nil. +type sliceEncoder struct { + arrayEnc encoderFunc +} + +func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, _ bool) { + if v.IsNil() { + e.WriteString("null") + return + } + se.arrayEnc(e, v, false) +} + +func newSliceEncoder(t reflect.Type) encoderFunc { + // Byte slices get special treatment; arrays don't. + if t.Elem().Kind() == reflect.Uint8 { + return encodeByteSlice + } + enc := &sliceEncoder{newArrayEncoder(t)} + return enc.encode +} + +type arrayEncoder struct { + elemEnc encoderFunc +} + +func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, _ bool) { + e.WriteByte('[') + n := v.Len() + for i := 0; i < n; i++ { + if i > 0 { + e.WriteByte(',') + } + ae.elemEnc(e, v.Index(i), false) + } + e.WriteByte(']') +} + +func newArrayEncoder(t reflect.Type) encoderFunc { + enc := &arrayEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +type ptrEncoder struct { + elemEnc encoderFunc +} + +func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { + if v.IsNil() { + e.WriteString("null") + return + } + pe.elemEnc(e, v.Elem(), quoted) +} + +func newPtrEncoder(t reflect.Type) encoderFunc { + enc := &ptrEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +type condAddrEncoder struct { + canAddrEnc, elseEnc encoderFunc +} + +func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { + if v.CanAddr() { + ce.canAddrEnc(e, v, quoted) + } else { + ce.elseEnc(e, v, quoted) + } +} + +// newCondAddrEncoder returns an encoder that checks whether its value +// CanAddr and delegates to canAddrEnc if so, else to elseEnc. +func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc { + enc := &condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc} + return enc.encode +} + +func isValidTag(s string) bool { + if s == "" { + return false + } + for _, c := range s { + switch { + case strings.ContainsRune("!#$%&()*+-./:<=>?@[]^_{|}~ ", c): + // Backslash and quote chars are reserved, but + // otherwise any punctuation chars are allowed + // in a tag name. + default: + if !unicode.IsLetter(c) && !unicode.IsDigit(c) { + return false + } + } + } + return true +} + +func fieldByIndex(v reflect.Value, index []int) reflect.Value { + for _, i := range index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return reflect.Value{} + } + v = v.Elem() + } + v = v.Field(i) + } + return v +} + +func typeByIndex(t reflect.Type, index []int) reflect.Type { + for _, i := range index { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + t = t.Field(i).Type + } + return t +} + +// stringValues is a slice of reflect.Value holding *reflect.StringValue. +// It implements the methods to sort by string. +type stringValues []reflect.Value + +func (sv stringValues) Len() int { return len(sv) } +func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } +func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) } +func (sv stringValues) get(i int) string { return sv[i].String() } + +// NOTE: keep in sync with stringBytes below. +func (e *encodeState) string(s string) (int, error) { + len0 := e.Len() + e.WriteByte('"') + start := 0 + for i := 0; i < len(s); { + if b := s[i]; b < utf8.RuneSelf { + if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' { + i++ + continue + } + if start < i { + e.WriteString(s[start:i]) + } + switch b { + case '\\', '"': + e.WriteByte('\\') + e.WriteByte(b) + case '\n': + e.WriteByte('\\') + e.WriteByte('n') + case '\r': + e.WriteByte('\\') + e.WriteByte('r') + case '\t': + e.WriteByte('\\') + e.WriteByte('t') + default: + // This encodes bytes < 0x20 except for \n and \r, + // as well as <, > and &. The latter are escaped because they + // can lead to security holes when user-controlled strings + // are rendered into JSON and served to some browsers. + e.WriteString(`\u00`) + e.WriteByte(hex[b>>4]) + e.WriteByte(hex[b&0xF]) + } + i++ + start = i + continue + } + c, size := utf8.DecodeRuneInString(s[i:]) + if c == utf8.RuneError && size == 1 { + if start < i { + e.WriteString(s[start:i]) + } + e.WriteString(`\ufffd`) + i += size + start = i + continue + } + // U+2028 is LINE SEPARATOR. + // U+2029 is PARAGRAPH SEPARATOR. + // They are both technically valid characters in JSON strings, + // but don't work in JSONP, which has to be evaluated as JavaScript, + // and can lead to security holes there. It is valid JSON to + // escape them, so we do so unconditionally. + // See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion. + if c == '\u2028' || c == '\u2029' { + if start < i { + e.WriteString(s[start:i]) + } + e.WriteString(`\u202`) + e.WriteByte(hex[c&0xF]) + i += size + start = i + continue + } + i += size + } + if start < len(s) { + e.WriteString(s[start:]) + } + e.WriteByte('"') + return e.Len() - len0, nil +} + +// NOTE: keep in sync with string above. +func (e *encodeState) stringBytes(s []byte) (int, error) { + len0 := e.Len() + e.WriteByte('"') + start := 0 + for i := 0; i < len(s); { + if b := s[i]; b < utf8.RuneSelf { + if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' { + i++ + continue + } + if start < i { + e.Write(s[start:i]) + } + switch b { + case '\\', '"': + e.WriteByte('\\') + e.WriteByte(b) + case '\n': + e.WriteByte('\\') + e.WriteByte('n') + case '\r': + e.WriteByte('\\') + e.WriteByte('r') + case '\t': + e.WriteByte('\\') + e.WriteByte('t') + default: + // This encodes bytes < 0x20 except for \n and \r, + // as well as <, >, and &. The latter are escaped because they + // can lead to security holes when user-controlled strings + // are rendered into JSON and served to some browsers. + e.WriteString(`\u00`) + e.WriteByte(hex[b>>4]) + e.WriteByte(hex[b&0xF]) + } + i++ + start = i + continue + } + c, size := utf8.DecodeRune(s[i:]) + if c == utf8.RuneError && size == 1 { + if start < i { + e.Write(s[start:i]) + } + e.WriteString(`\ufffd`) + i += size + start = i + continue + } + // U+2028 is LINE SEPARATOR. + // U+2029 is PARAGRAPH SEPARATOR. + // They are both technically valid characters in JSON strings, + // but don't work in JSONP, which has to be evaluated as JavaScript, + // and can lead to security holes there. It is valid JSON to + // escape them, so we do so unconditionally. + // See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion. + if c == '\u2028' || c == '\u2029' { + if start < i { + e.Write(s[start:i]) + } + e.WriteString(`\u202`) + e.WriteByte(hex[c&0xF]) + i += size + start = i + continue + } + i += size + } + if start < len(s) { + e.Write(s[start:]) + } + e.WriteByte('"') + return e.Len() - len0, nil +} + +// A field represents a single field found in a struct. +type field struct { + name string + nameBytes []byte // []byte(name) + equalFold func(s, t []byte) bool // bytes.EqualFold or equivalent + + tag bool + index []int + typ reflect.Type + omitEmpty bool + quoted bool +} + +func fillField(f field) field { + f.nameBytes = []byte(f.name) + f.equalFold = foldFunc(f.nameBytes) + return f +} + +// byName sorts field by name, breaking ties with depth, +// then breaking ties with "name came from json tag", then +// breaking ties with index sequence. +type byName []field + +func (x byName) Len() int { return len(x) } + +func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +func (x byName) Less(i, j int) bool { + if x[i].name != x[j].name { + return x[i].name < x[j].name + } + if len(x[i].index) != len(x[j].index) { + return len(x[i].index) < len(x[j].index) + } + if x[i].tag != x[j].tag { + return x[i].tag + } + return byIndex(x).Less(i, j) +} + +// byIndex sorts field by index sequence. +type byIndex []field + +func (x byIndex) Len() int { return len(x) } + +func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +func (x byIndex) Less(i, j int) bool { + for k, xik := range x[i].index { + if k >= len(x[j].index) { + return false + } + if xik != x[j].index[k] { + return xik < x[j].index[k] + } + } + return len(x[i].index) < len(x[j].index) +} + +// typeFields returns a list of fields that JSON should recognize for the given type. +// The algorithm is breadth-first search over the set of structs to include - the top struct +// and then any reachable anonymous structs. +func typeFields(t reflect.Type) []field { + // Anonymous fields to explore at the current level and the next. + current := []field{} + next := []field{{typ: t}} + + // Count of queued names for current level and the next. + count := map[reflect.Type]int{} + nextCount := map[reflect.Type]int{} + + // Types already visited at an earlier level. + visited := map[reflect.Type]bool{} + + // Fields found. + var fields []field + + for len(next) > 0 { + current, next = next, current[:0] + count, nextCount = nextCount, map[reflect.Type]int{} + + for _, f := range current { + if visited[f.typ] { + continue + } + visited[f.typ] = true + + // Scan f.typ for fields to include. + for i := 0; i < f.typ.NumField(); i++ { + sf := f.typ.Field(i) + if sf.PkgPath != "" { // unexported + continue + } + tag := sf.Tag.Get("json") + if tag == "-" { + continue + } + name, opts := parseTag(tag) + if !isValidTag(name) { + name = "" + } + index := make([]int, len(f.index)+1) + copy(index, f.index) + index[len(f.index)] = i + + ft := sf.Type + if ft.Name() == "" && ft.Kind() == reflect.Ptr { + // Follow pointer. + ft = ft.Elem() + } + + // Record found field and index sequence. + if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { + tagged := name != "" + if name == "" { + name = sf.Name + } + fields = append(fields, fillField(field{ + name: name, + tag: tagged, + index: index, + typ: ft, + omitEmpty: opts.Contains("omitempty"), + quoted: opts.Contains("string"), + })) + if count[f.typ] > 1 { + // If there were multiple instances, add a second, + // so that the annihilation code will see a duplicate. + // It only cares about the distinction between 1 or 2, + // so don't bother generating any more copies. + fields = append(fields, fields[len(fields)-1]) + } + continue + } + + // Record new anonymous struct to explore in next round. + nextCount[ft]++ + if nextCount[ft] == 1 { + next = append(next, fillField(field{name: ft.Name(), index: index, typ: ft})) + } + } + } + } + + sort.Sort(byName(fields)) + + // Delete all fields that are hidden by the Go rules for embedded fields, + // except that fields with JSON tags are promoted. + + // The fields are sorted in primary order of name, secondary order + // of field index length. Loop over names; for each name, delete + // hidden fields by choosing the one dominant field that survives. + out := fields[:0] + for advance, i := 0, 0; i < len(fields); i += advance { + // One iteration per name. + // Find the sequence of fields with the name of this first field. + fi := fields[i] + name := fi.name + for advance = 1; i+advance < len(fields); advance++ { + fj := fields[i+advance] + if fj.name != name { + break + } + } + if advance == 1 { // Only one field with this name + out = append(out, fi) + continue + } + dominant, ok := dominantField(fields[i : i+advance]) + if ok { + out = append(out, dominant) + } + } + + fields = out + sort.Sort(byIndex(fields)) + + return fields +} + +// dominantField looks through the fields, all of which are known to +// have the same name, to find the single field that dominates the +// others using Go's embedding rules, modified by the presence of +// JSON tags. If there are multiple top-level fields, the boolean +// will be false: This condition is an error in Go and we skip all +// the fields. +func dominantField(fields []field) (field, bool) { + // The fields are sorted in increasing index-length order. The winner + // must therefore be one with the shortest index length. Drop all + // longer entries, which is easy: just truncate the slice. + length := len(fields[0].index) + tagged := -1 // Index of first tagged field. + for i, f := range fields { + if len(f.index) > length { + fields = fields[:i] + break + } + if f.tag { + if tagged >= 0 { + // Multiple tagged fields at the same level: conflict. + // Return no field. + return field{}, false + } + tagged = i + } + } + if tagged >= 0 { + return fields[tagged], true + } + // All remaining fields have the same length. If there's more than one, + // we have a conflict (two fields named "X" at the same level) and we + // return no field. + if len(fields) > 1 { + return field{}, false + } + return fields[0], true +} + +var fieldCache struct { + sync.RWMutex + m map[reflect.Type][]field +} + +// cachedTypeFields is like typeFields but uses a cache to avoid repeated work. +func cachedTypeFields(t reflect.Type) []field { + fieldCache.RLock() + f := fieldCache.m[t] + fieldCache.RUnlock() + if f != nil { + return f + } + + // Compute fields without lock. + // Might duplicate effort but won't hold other computations back. + f = typeFields(t) + if f == nil { + f = []field{} + } + + fieldCache.Lock() + if fieldCache.m == nil { + fieldCache.m = map[reflect.Type][]field{} + } + fieldCache.m[t] = f + fieldCache.Unlock() + return f +} diff --git a/src/encoding/json/encode_test.go b/src/encoding/json/encode_test.go new file mode 100644 index 000000000..7abfa85db --- /dev/null +++ b/src/encoding/json/encode_test.go @@ -0,0 +1,532 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import ( + "bytes" + "math" + "reflect" + "testing" + "unicode" +) + +type Optionals struct { + Sr string `json:"sr"` + So string `json:"so,omitempty"` + Sw string `json:"-"` + + Ir int `json:"omitempty"` // actually named omitempty, not an option + Io int `json:"io,omitempty"` + + Slr []string `json:"slr,random"` + Slo []string `json:"slo,omitempty"` + + Mr map[string]interface{} `json:"mr"` + Mo map[string]interface{} `json:",omitempty"` + + Fr float64 `json:"fr"` + Fo float64 `json:"fo,omitempty"` + + Br bool `json:"br"` + Bo bool `json:"bo,omitempty"` + + Ur uint `json:"ur"` + Uo uint `json:"uo,omitempty"` + + Str struct{} `json:"str"` + Sto struct{} `json:"sto,omitempty"` +} + +var optionalsExpected = `{ + "sr": "", + "omitempty": 0, + "slr": null, + "mr": {}, + "fr": 0, + "br": false, + "ur": 0, + "str": {}, + "sto": {} +}` + +func TestOmitEmpty(t *testing.T) { + var o Optionals + o.Sw = "something" + o.Mr = map[string]interface{}{} + o.Mo = map[string]interface{}{} + + got, err := MarshalIndent(&o, "", " ") + if err != nil { + t.Fatal(err) + } + if got := string(got); got != optionalsExpected { + t.Errorf(" got: %s\nwant: %s\n", got, optionalsExpected) + } +} + +type StringTag struct { + BoolStr bool `json:",string"` + IntStr int64 `json:",string"` + StrStr string `json:",string"` +} + +var stringTagExpected = `{ + "BoolStr": "true", + "IntStr": "42", + "StrStr": "\"xzbit\"" +}` + +func TestStringTag(t *testing.T) { + var s StringTag + s.BoolStr = true + s.IntStr = 42 + s.StrStr = "xzbit" + got, err := MarshalIndent(&s, "", " ") + if err != nil { + t.Fatal(err) + } + if got := string(got); got != stringTagExpected { + t.Fatalf(" got: %s\nwant: %s\n", got, stringTagExpected) + } + + // Verify that it round-trips. + var s2 StringTag + err = NewDecoder(bytes.NewReader(got)).Decode(&s2) + if err != nil { + t.Fatalf("Decode: %v", err) + } + if !reflect.DeepEqual(s, s2) { + t.Fatalf("decode didn't match.\nsource: %#v\nEncoded as:\n%s\ndecode: %#v", s, string(got), s2) + } +} + +// byte slices are special even if they're renamed types. +type renamedByte byte +type renamedByteSlice []byte +type renamedRenamedByteSlice []renamedByte + +func TestEncodeRenamedByteSlice(t *testing.T) { + s := renamedByteSlice("abc") + result, err := Marshal(s) + if err != nil { + t.Fatal(err) + } + expect := `"YWJj"` + if string(result) != expect { + t.Errorf(" got %s want %s", result, expect) + } + r := renamedRenamedByteSlice("abc") + result, err = Marshal(r) + if err != nil { + t.Fatal(err) + } + if string(result) != expect { + t.Errorf(" got %s want %s", result, expect) + } +} + +var unsupportedValues = []interface{}{ + math.NaN(), + math.Inf(-1), + math.Inf(1), +} + +func TestUnsupportedValues(t *testing.T) { + for _, v := range unsupportedValues { + if _, err := Marshal(v); err != nil { + if _, ok := err.(*UnsupportedValueError); !ok { + t.Errorf("for %v, got %T want UnsupportedValueError", v, err) + } + } else { + t.Errorf("for %v, expected error", v) + } + } +} + +// Ref has Marshaler and Unmarshaler methods with pointer receiver. +type Ref int + +func (*Ref) MarshalJSON() ([]byte, error) { + return []byte(`"ref"`), nil +} + +func (r *Ref) UnmarshalJSON([]byte) error { + *r = 12 + return nil +} + +// Val has Marshaler methods with value receiver. +type Val int + +func (Val) MarshalJSON() ([]byte, error) { + return []byte(`"val"`), nil +} + +// RefText has Marshaler and Unmarshaler methods with pointer receiver. +type RefText int + +func (*RefText) MarshalText() ([]byte, error) { + return []byte(`"ref"`), nil +} + +func (r *RefText) UnmarshalText([]byte) error { + *r = 13 + return nil +} + +// ValText has Marshaler methods with value receiver. +type ValText int + +func (ValText) MarshalText() ([]byte, error) { + return []byte(`"val"`), nil +} + +func TestRefValMarshal(t *testing.T) { + var s = struct { + R0 Ref + R1 *Ref + R2 RefText + R3 *RefText + V0 Val + V1 *Val + V2 ValText + V3 *ValText + }{ + R0: 12, + R1: new(Ref), + R2: 14, + R3: new(RefText), + V0: 13, + V1: new(Val), + V2: 15, + V3: new(ValText), + } + const want = `{"R0":"ref","R1":"ref","R2":"\"ref\"","R3":"\"ref\"","V0":"val","V1":"val","V2":"\"val\"","V3":"\"val\""}` + b, err := Marshal(&s) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if got := string(b); got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +// C implements Marshaler and returns unescaped JSON. +type C int + +func (C) MarshalJSON() ([]byte, error) { + return []byte(`"<&>"`), nil +} + +// CText implements Marshaler and returns unescaped text. +type CText int + +func (CText) MarshalText() ([]byte, error) { + return []byte(`"<&>"`), nil +} + +func TestMarshalerEscaping(t *testing.T) { + var c C + want := `"\u003c\u0026\u003e"` + b, err := Marshal(c) + if err != nil { + t.Fatalf("Marshal(c): %v", err) + } + if got := string(b); got != want { + t.Errorf("Marshal(c) = %#q, want %#q", got, want) + } + + var ct CText + want = `"\"\u003c\u0026\u003e\""` + b, err = Marshal(ct) + if err != nil { + t.Fatalf("Marshal(ct): %v", err) + } + if got := string(b); got != want { + t.Errorf("Marshal(ct) = %#q, want %#q", got, want) + } +} + +type IntType int + +type MyStruct struct { + IntType +} + +func TestAnonymousNonstruct(t *testing.T) { + var i IntType = 11 + a := MyStruct{i} + const want = `{"IntType":11}` + + b, err := Marshal(a) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if got := string(b); got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +type BugA struct { + S string +} + +type BugB struct { + BugA + S string +} + +type BugC struct { + S string +} + +// Legal Go: We never use the repeated embedded field (S). +type BugX struct { + A int + BugA + BugB +} + +// Issue 5245. +func TestEmbeddedBug(t *testing.T) { + v := BugB{ + BugA{"A"}, + "B", + } + b, err := Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{"S":"B"}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } + // Now check that the duplicate field, S, does not appear. + x := BugX{ + A: 23, + } + b, err = Marshal(x) + if err != nil { + t.Fatal("Marshal:", err) + } + want = `{"A":23}` + got = string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} + +type BugD struct { // Same as BugA after tagging. + XXX string `json:"S"` +} + +// BugD's tagged S field should dominate BugA's. +type BugY struct { + BugA + BugD +} + +// Test that a field with a tag dominates untagged fields. +func TestTaggedFieldDominates(t *testing.T) { + v := BugY{ + BugA{"BugA"}, + BugD{"BugD"}, + } + b, err := Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{"S":"BugD"}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} + +// There are no tags here, so S should not appear. +type BugZ struct { + BugA + BugC + BugY // Contains a tagged S field through BugD; should not dominate. +} + +func TestDuplicatedFieldDisappears(t *testing.T) { + v := BugZ{ + BugA{"BugA"}, + BugC{"BugC"}, + BugY{ + BugA{"nested BugA"}, + BugD{"nested BugD"}, + }, + } + b, err := Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} + +func TestStringBytes(t *testing.T) { + // Test that encodeState.stringBytes and encodeState.string use the same encoding. + es := &encodeState{} + var r []rune + for i := '\u0000'; i <= unicode.MaxRune; i++ { + r = append(r, i) + } + s := string(r) + "\xff\xff\xffhello" // some invalid UTF-8 too + _, err := es.string(s) + if err != nil { + t.Fatal(err) + } + + esBytes := &encodeState{} + _, err = esBytes.stringBytes([]byte(s)) + if err != nil { + t.Fatal(err) + } + + enc := es.Buffer.String() + encBytes := esBytes.Buffer.String() + if enc != encBytes { + i := 0 + for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] { + i++ + } + enc = enc[i:] + encBytes = encBytes[i:] + i = 0 + for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] { + i++ + } + enc = enc[:len(enc)-i] + encBytes = encBytes[:len(encBytes)-i] + + if len(enc) > 20 { + enc = enc[:20] + "..." + } + if len(encBytes) > 20 { + encBytes = encBytes[:20] + "..." + } + + t.Errorf("encodings differ at %#q vs %#q", enc, encBytes) + } +} + +func TestIssue6458(t *testing.T) { + type Foo struct { + M RawMessage + } + x := Foo{RawMessage(`"foo"`)} + + b, err := Marshal(&x) + if err != nil { + t.Fatal(err) + } + if want := `{"M":"foo"}`; string(b) != want { + t.Errorf("Marshal(&x) = %#q; want %#q", b, want) + } + + b, err = Marshal(x) + if err != nil { + t.Fatal(err) + } + + if want := `{"M":"ImZvbyI="}`; string(b) != want { + t.Errorf("Marshal(x) = %#q; want %#q", b, want) + } +} + +func TestHTMLEscape(t *testing.T) { + var b, want bytes.Buffer + m := `{"M":"<html>foo &` + "\xe2\x80\xa8 \xe2\x80\xa9" + `</html>"}` + want.Write([]byte(`{"M":"\u003chtml\u003efoo \u0026\u2028 \u2029\u003c/html\u003e"}`)) + HTMLEscape(&b, []byte(m)) + if !bytes.Equal(b.Bytes(), want.Bytes()) { + t.Errorf("HTMLEscape(&b, []byte(m)) = %s; want %s", b.Bytes(), want.Bytes()) + } +} + +// golang.org/issue/8582 +func TestEncodePointerString(t *testing.T) { + type stringPointer struct { + N *int64 `json:"n,string"` + } + var n int64 = 42 + b, err := Marshal(stringPointer{N: &n}) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if got, want := string(b), `{"n":"42"}`; got != want { + t.Errorf("Marshal = %s, want %s", got, want) + } + var back stringPointer + err = Unmarshal(b, &back) + if err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if back.N == nil { + t.Fatalf("Unmarshalled nil N field") + } + if *back.N != 42 { + t.Fatalf("*N = %d; want 42", *back.N) + } +} + +var encodeStringTests = []struct { + in string + out string +}{ + {"\x00", `"\u0000"`}, + {"\x01", `"\u0001"`}, + {"\x02", `"\u0002"`}, + {"\x03", `"\u0003"`}, + {"\x04", `"\u0004"`}, + {"\x05", `"\u0005"`}, + {"\x06", `"\u0006"`}, + {"\x07", `"\u0007"`}, + {"\x08", `"\u0008"`}, + {"\x09", `"\t"`}, + {"\x0a", `"\n"`}, + {"\x0b", `"\u000b"`}, + {"\x0c", `"\u000c"`}, + {"\x0d", `"\r"`}, + {"\x0e", `"\u000e"`}, + {"\x0f", `"\u000f"`}, + {"\x10", `"\u0010"`}, + {"\x11", `"\u0011"`}, + {"\x12", `"\u0012"`}, + {"\x13", `"\u0013"`}, + {"\x14", `"\u0014"`}, + {"\x15", `"\u0015"`}, + {"\x16", `"\u0016"`}, + {"\x17", `"\u0017"`}, + {"\x18", `"\u0018"`}, + {"\x19", `"\u0019"`}, + {"\x1a", `"\u001a"`}, + {"\x1b", `"\u001b"`}, + {"\x1c", `"\u001c"`}, + {"\x1d", `"\u001d"`}, + {"\x1e", `"\u001e"`}, + {"\x1f", `"\u001f"`}, +} + +func TestEncodeString(t *testing.T) { + for _, tt := range encodeStringTests { + b, err := Marshal(tt.in) + if err != nil { + t.Errorf("Marshal(%q): %v", tt.in, err) + continue + } + out := string(b) + if out != tt.out { + t.Errorf("Marshal(%q) = %#q, want %#q", tt.in, out, tt.out) + } + } +} diff --git a/src/encoding/json/example_test.go b/src/encoding/json/example_test.go new file mode 100644 index 000000000..ca4e5ae68 --- /dev/null +++ b/src/encoding/json/example_test.go @@ -0,0 +1,161 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json_test + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "os" + "strings" +) + +func ExampleMarshal() { + type ColorGroup struct { + ID int + Name string + Colors []string + } + group := ColorGroup{ + ID: 1, + Name: "Reds", + Colors: []string{"Crimson", "Red", "Ruby", "Maroon"}, + } + b, err := json.Marshal(group) + if err != nil { + fmt.Println("error:", err) + } + os.Stdout.Write(b) + // Output: + // {"ID":1,"Name":"Reds","Colors":["Crimson","Red","Ruby","Maroon"]} +} + +func ExampleUnmarshal() { + var jsonBlob = []byte(`[ + {"Name": "Platypus", "Order": "Monotremata"}, + {"Name": "Quoll", "Order": "Dasyuromorphia"} + ]`) + type Animal struct { + Name string + Order string + } + var animals []Animal + err := json.Unmarshal(jsonBlob, &animals) + if err != nil { + fmt.Println("error:", err) + } + fmt.Printf("%+v", animals) + // Output: + // [{Name:Platypus Order:Monotremata} {Name:Quoll Order:Dasyuromorphia}] +} + +// This example uses a Decoder to decode a stream of distinct JSON values. +func ExampleDecoder() { + const jsonStream = ` + {"Name": "Ed", "Text": "Knock knock."} + {"Name": "Sam", "Text": "Who's there?"} + {"Name": "Ed", "Text": "Go fmt."} + {"Name": "Sam", "Text": "Go fmt who?"} + {"Name": "Ed", "Text": "Go fmt yourself!"} + ` + type Message struct { + Name, Text string + } + dec := json.NewDecoder(strings.NewReader(jsonStream)) + for { + var m Message + if err := dec.Decode(&m); err == io.EOF { + break + } else if err != nil { + log.Fatal(err) + } + fmt.Printf("%s: %s\n", m.Name, m.Text) + } + // Output: + // Ed: Knock knock. + // Sam: Who's there? + // Ed: Go fmt. + // Sam: Go fmt who? + // Ed: Go fmt yourself! +} + +// This example uses RawMessage to delay parsing part of a JSON message. +func ExampleRawMessage() { + type Color struct { + Space string + Point json.RawMessage // delay parsing until we know the color space + } + type RGB struct { + R uint8 + G uint8 + B uint8 + } + type YCbCr struct { + Y uint8 + Cb int8 + Cr int8 + } + + var j = []byte(`[ + {"Space": "YCbCr", "Point": {"Y": 255, "Cb": 0, "Cr": -10}}, + {"Space": "RGB", "Point": {"R": 98, "G": 218, "B": 255}} + ]`) + var colors []Color + err := json.Unmarshal(j, &colors) + if err != nil { + log.Fatalln("error:", err) + } + + for _, c := range colors { + var dst interface{} + switch c.Space { + case "RGB": + dst = new(RGB) + case "YCbCr": + dst = new(YCbCr) + } + err := json.Unmarshal(c.Point, dst) + if err != nil { + log.Fatalln("error:", err) + } + fmt.Println(c.Space, dst) + } + // Output: + // YCbCr &{255 0 -10} + // RGB &{98 218 255} +} + +func ExampleIndent() { + type Road struct { + Name string + Number int + } + roads := []Road{ + {"Diamond Fork", 29}, + {"Sheep Creek", 51}, + } + + b, err := json.Marshal(roads) + if err != nil { + log.Fatal(err) + } + + var out bytes.Buffer + json.Indent(&out, b, "=", "\t") + out.WriteTo(os.Stdout) + // Output: + // [ + // = { + // = "Name": "Diamond Fork", + // = "Number": 29 + // = }, + // = { + // = "Name": "Sheep Creek", + // = "Number": 51 + // = } + // =] +} diff --git a/src/encoding/json/fold.go b/src/encoding/json/fold.go new file mode 100644 index 000000000..d6f77c93e --- /dev/null +++ b/src/encoding/json/fold.go @@ -0,0 +1,143 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import ( + "bytes" + "unicode/utf8" +) + +const ( + caseMask = ^byte(0x20) // Mask to ignore case in ASCII. + kelvin = '\u212a' + smallLongEss = '\u017f' +) + +// foldFunc returns one of four different case folding equivalence +// functions, from most general (and slow) to fastest: +// +// 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8 +// 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S') +// 3) asciiEqualFold, no special, but includes non-letters (including _) +// 4) simpleLetterEqualFold, no specials, no non-letters. +// +// The letters S and K are special because they map to 3 runes, not just 2: +// * S maps to s and to U+017F 'ſ' Latin small letter long s +// * k maps to K and to U+212A 'K' Kelvin sign +// See http://play.golang.org/p/tTxjOc0OGo +// +// The returned function is specialized for matching against s and +// should only be given s. It's not curried for performance reasons. +func foldFunc(s []byte) func(s, t []byte) bool { + nonLetter := false + special := false // special letter + for _, b := range s { + if b >= utf8.RuneSelf { + return bytes.EqualFold + } + upper := b & caseMask + if upper < 'A' || upper > 'Z' { + nonLetter = true + } else if upper == 'K' || upper == 'S' { + // See above for why these letters are special. + special = true + } + } + if special { + return equalFoldRight + } + if nonLetter { + return asciiEqualFold + } + return simpleLetterEqualFold +} + +// equalFoldRight is a specialization of bytes.EqualFold when s is +// known to be all ASCII (including punctuation), but contains an 's', +// 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t. +// See comments on foldFunc. +func equalFoldRight(s, t []byte) bool { + for _, sb := range s { + if len(t) == 0 { + return false + } + tb := t[0] + if tb < utf8.RuneSelf { + if sb != tb { + sbUpper := sb & caseMask + if 'A' <= sbUpper && sbUpper <= 'Z' { + if sbUpper != tb&caseMask { + return false + } + } else { + return false + } + } + t = t[1:] + continue + } + // sb is ASCII and t is not. t must be either kelvin + // sign or long s; sb must be s, S, k, or K. + tr, size := utf8.DecodeRune(t) + switch sb { + case 's', 'S': + if tr != smallLongEss { + return false + } + case 'k', 'K': + if tr != kelvin { + return false + } + default: + return false + } + t = t[size:] + + } + if len(t) > 0 { + return false + } + return true +} + +// asciiEqualFold is a specialization of bytes.EqualFold for use when +// s is all ASCII (but may contain non-letters) and contains no +// special-folding letters. +// See comments on foldFunc. +func asciiEqualFold(s, t []byte) bool { + if len(s) != len(t) { + return false + } + for i, sb := range s { + tb := t[i] + if sb == tb { + continue + } + if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') { + if sb&caseMask != tb&caseMask { + return false + } + } else { + return false + } + } + return true +} + +// simpleLetterEqualFold is a specialization of bytes.EqualFold for +// use when s is all ASCII letters (no underscores, etc) and also +// doesn't contain 'k', 'K', 's', or 'S'. +// See comments on foldFunc. +func simpleLetterEqualFold(s, t []byte) bool { + if len(s) != len(t) { + return false + } + for i, b := range s { + if b&caseMask != t[i]&caseMask { + return false + } + } + return true +} diff --git a/src/encoding/json/fold_test.go b/src/encoding/json/fold_test.go new file mode 100644 index 000000000..9fb94646a --- /dev/null +++ b/src/encoding/json/fold_test.go @@ -0,0 +1,116 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import ( + "bytes" + "strings" + "testing" + "unicode/utf8" +) + +var foldTests = []struct { + fn func(s, t []byte) bool + s, t string + want bool +}{ + {equalFoldRight, "", "", true}, + {equalFoldRight, "a", "a", true}, + {equalFoldRight, "", "a", false}, + {equalFoldRight, "a", "", false}, + {equalFoldRight, "a", "A", true}, + {equalFoldRight, "AB", "ab", true}, + {equalFoldRight, "AB", "ac", false}, + {equalFoldRight, "sbkKc", "ſbKKc", true}, + {equalFoldRight, "SbKkc", "ſbKKc", true}, + {equalFoldRight, "SbKkc", "ſbKK", false}, + {equalFoldRight, "e", "é", false}, + {equalFoldRight, "s", "S", true}, + + {simpleLetterEqualFold, "", "", true}, + {simpleLetterEqualFold, "abc", "abc", true}, + {simpleLetterEqualFold, "abc", "ABC", true}, + {simpleLetterEqualFold, "abc", "ABCD", false}, + {simpleLetterEqualFold, "abc", "xxx", false}, + + {asciiEqualFold, "a_B", "A_b", true}, + {asciiEqualFold, "aa@", "aa`", false}, // verify 0x40 and 0x60 aren't case-equivalent +} + +func TestFold(t *testing.T) { + for i, tt := range foldTests { + if got := tt.fn([]byte(tt.s), []byte(tt.t)); got != tt.want { + t.Errorf("%d. %q, %q = %v; want %v", i, tt.s, tt.t, got, tt.want) + } + truth := strings.EqualFold(tt.s, tt.t) + if truth != tt.want { + t.Errorf("strings.EqualFold doesn't agree with case %d", i) + } + } +} + +func TestFoldAgainstUnicode(t *testing.T) { + const bufSize = 5 + buf1 := make([]byte, 0, bufSize) + buf2 := make([]byte, 0, bufSize) + var runes []rune + for i := 0x20; i <= 0x7f; i++ { + runes = append(runes, rune(i)) + } + runes = append(runes, kelvin, smallLongEss) + + funcs := []struct { + name string + fold func(s, t []byte) bool + letter bool // must be ASCII letter + simple bool // must be simple ASCII letter (not 'S' or 'K') + }{ + { + name: "equalFoldRight", + fold: equalFoldRight, + }, + { + name: "asciiEqualFold", + fold: asciiEqualFold, + simple: true, + }, + { + name: "simpleLetterEqualFold", + fold: simpleLetterEqualFold, + simple: true, + letter: true, + }, + } + + for _, ff := range funcs { + for _, r := range runes { + if r >= utf8.RuneSelf { + continue + } + if ff.letter && !isASCIILetter(byte(r)) { + continue + } + if ff.simple && (r == 's' || r == 'S' || r == 'k' || r == 'K') { + continue + } + for _, r2 := range runes { + buf1 := append(buf1[:0], 'x') + buf2 := append(buf2[:0], 'x') + buf1 = buf1[:1+utf8.EncodeRune(buf1[1:bufSize], r)] + buf2 = buf2[:1+utf8.EncodeRune(buf2[1:bufSize], r2)] + buf1 = append(buf1, 'x') + buf2 = append(buf2, 'x') + want := bytes.EqualFold(buf1, buf2) + if got := ff.fold(buf1, buf2); got != want { + t.Errorf("%s(%q, %q) = %v; want %v", ff.name, buf1, buf2, got, want) + } + } + } + } +} + +func isASCIILetter(b byte) bool { + return ('A' <= b && b <= 'Z') || ('a' <= b && b <= 'z') +} diff --git a/src/encoding/json/indent.go b/src/encoding/json/indent.go new file mode 100644 index 000000000..e1bacafd6 --- /dev/null +++ b/src/encoding/json/indent.go @@ -0,0 +1,137 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import "bytes" + +// Compact appends to dst the JSON-encoded src with +// insignificant space characters elided. +func Compact(dst *bytes.Buffer, src []byte) error { + return compact(dst, src, false) +} + +func compact(dst *bytes.Buffer, src []byte, escape bool) error { + origLen := dst.Len() + var scan scanner + scan.reset() + start := 0 + for i, c := range src { + if escape && (c == '<' || c == '>' || c == '&') { + if start < i { + dst.Write(src[start:i]) + } + dst.WriteString(`\u00`) + dst.WriteByte(hex[c>>4]) + dst.WriteByte(hex[c&0xF]) + start = i + 1 + } + // Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9). + if c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 { + if start < i { + dst.Write(src[start:i]) + } + dst.WriteString(`\u202`) + dst.WriteByte(hex[src[i+2]&0xF]) + start = i + 3 + } + v := scan.step(&scan, int(c)) + if v >= scanSkipSpace { + if v == scanError { + break + } + if start < i { + dst.Write(src[start:i]) + } + start = i + 1 + } + } + if scan.eof() == scanError { + dst.Truncate(origLen) + return scan.err + } + if start < len(src) { + dst.Write(src[start:]) + } + return nil +} + +func newline(dst *bytes.Buffer, prefix, indent string, depth int) { + dst.WriteByte('\n') + dst.WriteString(prefix) + for i := 0; i < depth; i++ { + dst.WriteString(indent) + } +} + +// Indent appends to dst an indented form of the JSON-encoded src. +// Each element in a JSON object or array begins on a new, +// indented line beginning with prefix followed by one or more +// copies of indent according to the indentation nesting. +// The data appended to dst does not begin with the prefix nor +// any indentation, and has no trailing newline, to make it +// easier to embed inside other formatted JSON data. +func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error { + origLen := dst.Len() + var scan scanner + scan.reset() + needIndent := false + depth := 0 + for _, c := range src { + scan.bytes++ + v := scan.step(&scan, int(c)) + if v == scanSkipSpace { + continue + } + if v == scanError { + break + } + if needIndent && v != scanEndObject && v != scanEndArray { + needIndent = false + depth++ + newline(dst, prefix, indent, depth) + } + + // Emit semantically uninteresting bytes + // (in particular, punctuation in strings) unmodified. + if v == scanContinue { + dst.WriteByte(c) + continue + } + + // Add spacing around real punctuation. + switch c { + case '{', '[': + // delay indent so that empty object and array are formatted as {} and []. + needIndent = true + dst.WriteByte(c) + + case ',': + dst.WriteByte(c) + newline(dst, prefix, indent, depth) + + case ':': + dst.WriteByte(c) + dst.WriteByte(' ') + + case '}', ']': + if needIndent { + // suppress indent in empty object/array + needIndent = false + } else { + depth-- + newline(dst, prefix, indent, depth) + } + dst.WriteByte(c) + + default: + dst.WriteByte(c) + } + } + if scan.eof() == scanError { + dst.Truncate(origLen) + return scan.err + } + return nil +} diff --git a/src/encoding/json/scanner.go b/src/encoding/json/scanner.go new file mode 100644 index 000000000..a4609c895 --- /dev/null +++ b/src/encoding/json/scanner.go @@ -0,0 +1,623 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +// JSON value parser state machine. +// Just about at the limit of what is reasonable to write by hand. +// Some parts are a bit tedious, but overall it nicely factors out the +// otherwise common code from the multiple scanning functions +// in this package (Compact, Indent, checkValid, nextValue, etc). +// +// This file starts with two simple examples using the scanner +// before diving into the scanner itself. + +import "strconv" + +// checkValid verifies that data is valid JSON-encoded data. +// scan is passed in for use by checkValid to avoid an allocation. +func checkValid(data []byte, scan *scanner) error { + scan.reset() + for _, c := range data { + scan.bytes++ + if scan.step(scan, int(c)) == scanError { + return scan.err + } + } + if scan.eof() == scanError { + return scan.err + } + return nil +} + +// nextValue splits data after the next whole JSON value, +// returning that value and the bytes that follow it as separate slices. +// scan is passed in for use by nextValue to avoid an allocation. +func nextValue(data []byte, scan *scanner) (value, rest []byte, err error) { + scan.reset() + for i, c := range data { + v := scan.step(scan, int(c)) + if v >= scanEnd { + switch v { + case scanError: + return nil, nil, scan.err + case scanEnd: + return data[0:i], data[i:], nil + } + } + } + if scan.eof() == scanError { + return nil, nil, scan.err + } + return data, nil, nil +} + +// A SyntaxError is a description of a JSON syntax error. +type SyntaxError struct { + msg string // description of error + Offset int64 // error occurred after reading Offset bytes +} + +func (e *SyntaxError) Error() string { return e.msg } + +// A scanner is a JSON scanning state machine. +// Callers call scan.reset() and then pass bytes in one at a time +// by calling scan.step(&scan, c) for each byte. +// The return value, referred to as an opcode, tells the +// caller about significant parsing events like beginning +// and ending literals, objects, and arrays, so that the +// caller can follow along if it wishes. +// The return value scanEnd indicates that a single top-level +// JSON value has been completed, *before* the byte that +// just got passed in. (The indication must be delayed in order +// to recognize the end of numbers: is 123 a whole value or +// the beginning of 12345e+6?). +type scanner struct { + // The step is a func to be called to execute the next transition. + // Also tried using an integer constant and a single func + // with a switch, but using the func directly was 10% faster + // on a 64-bit Mac Mini, and it's nicer to read. + step func(*scanner, int) int + + // Reached end of top-level value. + endTop bool + + // Stack of what we're in the middle of - array values, object keys, object values. + parseState []int + + // Error that happened, if any. + err error + + // 1-byte redo (see undo method) + redo bool + redoCode int + redoState func(*scanner, int) int + + // total bytes consumed, updated by decoder.Decode + bytes int64 +} + +// These values are returned by the state transition functions +// assigned to scanner.state and the method scanner.eof. +// They give details about the current state of the scan that +// callers might be interested to know about. +// It is okay to ignore the return value of any particular +// call to scanner.state: if one call returns scanError, +// every subsequent call will return scanError too. +const ( + // Continue. + scanContinue = iota // uninteresting byte + scanBeginLiteral // end implied by next result != scanContinue + scanBeginObject // begin object + scanObjectKey // just finished object key (string) + scanObjectValue // just finished non-last object value + scanEndObject // end object (implies scanObjectValue if possible) + scanBeginArray // begin array + scanArrayValue // just finished array value + scanEndArray // end array (implies scanArrayValue if possible) + scanSkipSpace // space byte; can skip; known to be last "continue" result + + // Stop. + scanEnd // top-level value ended *before* this byte; known to be first "stop" result + scanError // hit an error, scanner.err. +) + +// These values are stored in the parseState stack. +// They give the current state of a composite value +// being scanned. If the parser is inside a nested value +// the parseState describes the nested state, outermost at entry 0. +const ( + parseObjectKey = iota // parsing object key (before colon) + parseObjectValue // parsing object value (after colon) + parseArrayValue // parsing array value +) + +// reset prepares the scanner for use. +// It must be called before calling s.step. +func (s *scanner) reset() { + s.step = stateBeginValue + s.parseState = s.parseState[0:0] + s.err = nil + s.redo = false + s.endTop = false +} + +// eof tells the scanner that the end of input has been reached. +// It returns a scan status just as s.step does. +func (s *scanner) eof() int { + if s.err != nil { + return scanError + } + if s.endTop { + return scanEnd + } + s.step(s, ' ') + if s.endTop { + return scanEnd + } + if s.err == nil { + s.err = &SyntaxError{"unexpected end of JSON input", s.bytes} + } + return scanError +} + +// pushParseState pushes a new parse state p onto the parse stack. +func (s *scanner) pushParseState(p int) { + s.parseState = append(s.parseState, p) +} + +// popParseState pops a parse state (already obtained) off the stack +// and updates s.step accordingly. +func (s *scanner) popParseState() { + n := len(s.parseState) - 1 + s.parseState = s.parseState[0:n] + s.redo = false + if n == 0 { + s.step = stateEndTop + s.endTop = true + } else { + s.step = stateEndValue + } +} + +func isSpace(c rune) bool { + return c == ' ' || c == '\t' || c == '\r' || c == '\n' +} + +// stateBeginValueOrEmpty is the state after reading `[`. +func stateBeginValueOrEmpty(s *scanner, c int) int { + if c <= ' ' && isSpace(rune(c)) { + return scanSkipSpace + } + if c == ']' { + return stateEndValue(s, c) + } + return stateBeginValue(s, c) +} + +// stateBeginValue is the state at the beginning of the input. +func stateBeginValue(s *scanner, c int) int { + if c <= ' ' && isSpace(rune(c)) { + return scanSkipSpace + } + switch c { + case '{': + s.step = stateBeginStringOrEmpty + s.pushParseState(parseObjectKey) + return scanBeginObject + case '[': + s.step = stateBeginValueOrEmpty + s.pushParseState(parseArrayValue) + return scanBeginArray + case '"': + s.step = stateInString + return scanBeginLiteral + case '-': + s.step = stateNeg + return scanBeginLiteral + case '0': // beginning of 0.123 + s.step = state0 + return scanBeginLiteral + case 't': // beginning of true + s.step = stateT + return scanBeginLiteral + case 'f': // beginning of false + s.step = stateF + return scanBeginLiteral + case 'n': // beginning of null + s.step = stateN + return scanBeginLiteral + } + if '1' <= c && c <= '9' { // beginning of 1234.5 + s.step = state1 + return scanBeginLiteral + } + return s.error(c, "looking for beginning of value") +} + +// stateBeginStringOrEmpty is the state after reading `{`. +func stateBeginStringOrEmpty(s *scanner, c int) int { + if c <= ' ' && isSpace(rune(c)) { + return scanSkipSpace + } + if c == '}' { + n := len(s.parseState) + s.parseState[n-1] = parseObjectValue + return stateEndValue(s, c) + } + return stateBeginString(s, c) +} + +// stateBeginString is the state after reading `{"key": value,`. +func stateBeginString(s *scanner, c int) int { + if c <= ' ' && isSpace(rune(c)) { + return scanSkipSpace + } + if c == '"' { + s.step = stateInString + return scanBeginLiteral + } + return s.error(c, "looking for beginning of object key string") +} + +// stateEndValue is the state after completing a value, +// such as after reading `{}` or `true` or `["x"`. +func stateEndValue(s *scanner, c int) int { + n := len(s.parseState) + if n == 0 { + // Completed top-level before the current byte. + s.step = stateEndTop + s.endTop = true + return stateEndTop(s, c) + } + if c <= ' ' && isSpace(rune(c)) { + s.step = stateEndValue + return scanSkipSpace + } + ps := s.parseState[n-1] + switch ps { + case parseObjectKey: + if c == ':' { + s.parseState[n-1] = parseObjectValue + s.step = stateBeginValue + return scanObjectKey + } + return s.error(c, "after object key") + case parseObjectValue: + if c == ',' { + s.parseState[n-1] = parseObjectKey + s.step = stateBeginString + return scanObjectValue + } + if c == '}' { + s.popParseState() + return scanEndObject + } + return s.error(c, "after object key:value pair") + case parseArrayValue: + if c == ',' { + s.step = stateBeginValue + return scanArrayValue + } + if c == ']' { + s.popParseState() + return scanEndArray + } + return s.error(c, "after array element") + } + return s.error(c, "") +} + +// stateEndTop is the state after finishing the top-level value, +// such as after reading `{}` or `[1,2,3]`. +// Only space characters should be seen now. +func stateEndTop(s *scanner, c int) int { + if c != ' ' && c != '\t' && c != '\r' && c != '\n' { + // Complain about non-space byte on next call. + s.error(c, "after top-level value") + } + return scanEnd +} + +// stateInString is the state after reading `"`. +func stateInString(s *scanner, c int) int { + if c == '"' { + s.step = stateEndValue + return scanContinue + } + if c == '\\' { + s.step = stateInStringEsc + return scanContinue + } + if c < 0x20 { + return s.error(c, "in string literal") + } + return scanContinue +} + +// stateInStringEsc is the state after reading `"\` during a quoted string. +func stateInStringEsc(s *scanner, c int) int { + switch c { + case 'b', 'f', 'n', 'r', 't', '\\', '/', '"': + s.step = stateInString + return scanContinue + } + if c == 'u' { + s.step = stateInStringEscU + return scanContinue + } + return s.error(c, "in string escape code") +} + +// stateInStringEscU is the state after reading `"\u` during a quoted string. +func stateInStringEscU(s *scanner, c int) int { + if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' { + s.step = stateInStringEscU1 + return scanContinue + } + // numbers + return s.error(c, "in \\u hexadecimal character escape") +} + +// stateInStringEscU1 is the state after reading `"\u1` during a quoted string. +func stateInStringEscU1(s *scanner, c int) int { + if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' { + s.step = stateInStringEscU12 + return scanContinue + } + // numbers + return s.error(c, "in \\u hexadecimal character escape") +} + +// stateInStringEscU12 is the state after reading `"\u12` during a quoted string. +func stateInStringEscU12(s *scanner, c int) int { + if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' { + s.step = stateInStringEscU123 + return scanContinue + } + // numbers + return s.error(c, "in \\u hexadecimal character escape") +} + +// stateInStringEscU123 is the state after reading `"\u123` during a quoted string. +func stateInStringEscU123(s *scanner, c int) int { + if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' { + s.step = stateInString + return scanContinue + } + // numbers + return s.error(c, "in \\u hexadecimal character escape") +} + +// stateNeg is the state after reading `-` during a number. +func stateNeg(s *scanner, c int) int { + if c == '0' { + s.step = state0 + return scanContinue + } + if '1' <= c && c <= '9' { + s.step = state1 + return scanContinue + } + return s.error(c, "in numeric literal") +} + +// state1 is the state after reading a non-zero integer during a number, +// such as after reading `1` or `100` but not `0`. +func state1(s *scanner, c int) int { + if '0' <= c && c <= '9' { + s.step = state1 + return scanContinue + } + return state0(s, c) +} + +// state0 is the state after reading `0` during a number. +func state0(s *scanner, c int) int { + if c == '.' { + s.step = stateDot + return scanContinue + } + if c == 'e' || c == 'E' { + s.step = stateE + return scanContinue + } + return stateEndValue(s, c) +} + +// stateDot is the state after reading the integer and decimal point in a number, +// such as after reading `1.`. +func stateDot(s *scanner, c int) int { + if '0' <= c && c <= '9' { + s.step = stateDot0 + return scanContinue + } + return s.error(c, "after decimal point in numeric literal") +} + +// stateDot0 is the state after reading the integer, decimal point, and subsequent +// digits of a number, such as after reading `3.14`. +func stateDot0(s *scanner, c int) int { + if '0' <= c && c <= '9' { + s.step = stateDot0 + return scanContinue + } + if c == 'e' || c == 'E' { + s.step = stateE + return scanContinue + } + return stateEndValue(s, c) +} + +// stateE is the state after reading the mantissa and e in a number, +// such as after reading `314e` or `0.314e`. +func stateE(s *scanner, c int) int { + if c == '+' { + s.step = stateESign + return scanContinue + } + if c == '-' { + s.step = stateESign + return scanContinue + } + return stateESign(s, c) +} + +// stateESign is the state after reading the mantissa, e, and sign in a number, +// such as after reading `314e-` or `0.314e+`. +func stateESign(s *scanner, c int) int { + if '0' <= c && c <= '9' { + s.step = stateE0 + return scanContinue + } + return s.error(c, "in exponent of numeric literal") +} + +// stateE0 is the state after reading the mantissa, e, optional sign, +// and at least one digit of the exponent in a number, +// such as after reading `314e-2` or `0.314e+1` or `3.14e0`. +func stateE0(s *scanner, c int) int { + if '0' <= c && c <= '9' { + s.step = stateE0 + return scanContinue + } + return stateEndValue(s, c) +} + +// stateT is the state after reading `t`. +func stateT(s *scanner, c int) int { + if c == 'r' { + s.step = stateTr + return scanContinue + } + return s.error(c, "in literal true (expecting 'r')") +} + +// stateTr is the state after reading `tr`. +func stateTr(s *scanner, c int) int { + if c == 'u' { + s.step = stateTru + return scanContinue + } + return s.error(c, "in literal true (expecting 'u')") +} + +// stateTru is the state after reading `tru`. +func stateTru(s *scanner, c int) int { + if c == 'e' { + s.step = stateEndValue + return scanContinue + } + return s.error(c, "in literal true (expecting 'e')") +} + +// stateF is the state after reading `f`. +func stateF(s *scanner, c int) int { + if c == 'a' { + s.step = stateFa + return scanContinue + } + return s.error(c, "in literal false (expecting 'a')") +} + +// stateFa is the state after reading `fa`. +func stateFa(s *scanner, c int) int { + if c == 'l' { + s.step = stateFal + return scanContinue + } + return s.error(c, "in literal false (expecting 'l')") +} + +// stateFal is the state after reading `fal`. +func stateFal(s *scanner, c int) int { + if c == 's' { + s.step = stateFals + return scanContinue + } + return s.error(c, "in literal false (expecting 's')") +} + +// stateFals is the state after reading `fals`. +func stateFals(s *scanner, c int) int { + if c == 'e' { + s.step = stateEndValue + return scanContinue + } + return s.error(c, "in literal false (expecting 'e')") +} + +// stateN is the state after reading `n`. +func stateN(s *scanner, c int) int { + if c == 'u' { + s.step = stateNu + return scanContinue + } + return s.error(c, "in literal null (expecting 'u')") +} + +// stateNu is the state after reading `nu`. +func stateNu(s *scanner, c int) int { + if c == 'l' { + s.step = stateNul + return scanContinue + } + return s.error(c, "in literal null (expecting 'l')") +} + +// stateNul is the state after reading `nul`. +func stateNul(s *scanner, c int) int { + if c == 'l' { + s.step = stateEndValue + return scanContinue + } + return s.error(c, "in literal null (expecting 'l')") +} + +// stateError is the state after reaching a syntax error, +// such as after reading `[1}` or `5.1.2`. +func stateError(s *scanner, c int) int { + return scanError +} + +// error records an error and switches to the error state. +func (s *scanner) error(c int, context string) int { + s.step = stateError + s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes} + return scanError +} + +// quoteChar formats c as a quoted character literal +func quoteChar(c int) string { + // special cases - different from quoted strings + if c == '\'' { + return `'\''` + } + if c == '"' { + return `'"'` + } + + // use quoted string with different quotation marks + s := strconv.Quote(string(c)) + return "'" + s[1:len(s)-1] + "'" +} + +// undo causes the scanner to return scanCode from the next state transition. +// This gives callers a simple 1-byte undo mechanism. +func (s *scanner) undo(scanCode int) { + if s.redo { + panic("json: invalid use of scanner") + } + s.redoCode = scanCode + s.redoState = s.step + s.step = stateRedo + s.redo = true +} + +// stateRedo helps implement the scanner's 1-byte undo. +func stateRedo(s *scanner, c int) int { + s.redo = false + s.step = s.redoState + return s.redoCode +} diff --git a/src/encoding/json/scanner_test.go b/src/encoding/json/scanner_test.go new file mode 100644 index 000000000..788034290 --- /dev/null +++ b/src/encoding/json/scanner_test.go @@ -0,0 +1,315 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import ( + "bytes" + "math" + "math/rand" + "reflect" + "testing" +) + +// Tests of simple examples. + +type example struct { + compact string + indent string +} + +var examples = []example{ + {`1`, `1`}, + {`{}`, `{}`}, + {`[]`, `[]`}, + {`{"":2}`, "{\n\t\"\": 2\n}"}, + {`[3]`, "[\n\t3\n]"}, + {`[1,2,3]`, "[\n\t1,\n\t2,\n\t3\n]"}, + {`{"x":1}`, "{\n\t\"x\": 1\n}"}, + {ex1, ex1i}, +} + +var ex1 = `[true,false,null,"x",1,1.5,0,-5e+2]` + +var ex1i = `[ + true, + false, + null, + "x", + 1, + 1.5, + 0, + -5e+2 +]` + +func TestCompact(t *testing.T) { + var buf bytes.Buffer + for _, tt := range examples { + buf.Reset() + if err := Compact(&buf, []byte(tt.compact)); err != nil { + t.Errorf("Compact(%#q): %v", tt.compact, err) + } else if s := buf.String(); s != tt.compact { + t.Errorf("Compact(%#q) = %#q, want original", tt.compact, s) + } + + buf.Reset() + if err := Compact(&buf, []byte(tt.indent)); err != nil { + t.Errorf("Compact(%#q): %v", tt.indent, err) + continue + } else if s := buf.String(); s != tt.compact { + t.Errorf("Compact(%#q) = %#q, want %#q", tt.indent, s, tt.compact) + } + } +} + +func TestCompactSeparators(t *testing.T) { + // U+2028 and U+2029 should be escaped inside strings. + // They should not appear outside strings. + tests := []struct { + in, compact string + }{ + {"{\"\u2028\": 1}", `{"\u2028":1}`}, + {"{\"\u2029\" :2}", `{"\u2029":2}`}, + } + for _, tt := range tests { + var buf bytes.Buffer + if err := Compact(&buf, []byte(tt.in)); err != nil { + t.Errorf("Compact(%q): %v", tt.in, err) + } else if s := buf.String(); s != tt.compact { + t.Errorf("Compact(%q) = %q, want %q", tt.in, s, tt.compact) + } + } +} + +func TestIndent(t *testing.T) { + var buf bytes.Buffer + for _, tt := range examples { + buf.Reset() + if err := Indent(&buf, []byte(tt.indent), "", "\t"); err != nil { + t.Errorf("Indent(%#q): %v", tt.indent, err) + } else if s := buf.String(); s != tt.indent { + t.Errorf("Indent(%#q) = %#q, want original", tt.indent, s) + } + + buf.Reset() + if err := Indent(&buf, []byte(tt.compact), "", "\t"); err != nil { + t.Errorf("Indent(%#q): %v", tt.compact, err) + continue + } else if s := buf.String(); s != tt.indent { + t.Errorf("Indent(%#q) = %#q, want %#q", tt.compact, s, tt.indent) + } + } +} + +// Tests of a large random structure. + +func TestCompactBig(t *testing.T) { + initBig() + var buf bytes.Buffer + if err := Compact(&buf, jsonBig); err != nil { + t.Fatalf("Compact: %v", err) + } + b := buf.Bytes() + if !bytes.Equal(b, jsonBig) { + t.Error("Compact(jsonBig) != jsonBig") + diff(t, b, jsonBig) + return + } +} + +func TestIndentBig(t *testing.T) { + initBig() + var buf bytes.Buffer + if err := Indent(&buf, jsonBig, "", "\t"); err != nil { + t.Fatalf("Indent1: %v", err) + } + b := buf.Bytes() + if len(b) == len(jsonBig) { + // jsonBig is compact (no unnecessary spaces); + // indenting should make it bigger + t.Fatalf("Indent(jsonBig) did not get bigger") + } + + // should be idempotent + var buf1 bytes.Buffer + if err := Indent(&buf1, b, "", "\t"); err != nil { + t.Fatalf("Indent2: %v", err) + } + b1 := buf1.Bytes() + if !bytes.Equal(b1, b) { + t.Error("Indent(Indent(jsonBig)) != Indent(jsonBig)") + diff(t, b1, b) + return + } + + // should get back to original + buf1.Reset() + if err := Compact(&buf1, b); err != nil { + t.Fatalf("Compact: %v", err) + } + b1 = buf1.Bytes() + if !bytes.Equal(b1, jsonBig) { + t.Error("Compact(Indent(jsonBig)) != jsonBig") + diff(t, b1, jsonBig) + return + } +} + +type indentErrorTest struct { + in string + err error +} + +var indentErrorTests = []indentErrorTest{ + {`{"X": "foo", "Y"}`, &SyntaxError{"invalid character '}' after object key", 17}}, + {`{"X": "foo" "Y": "bar"}`, &SyntaxError{"invalid character '\"' after object key:value pair", 13}}, +} + +func TestIndentErrors(t *testing.T) { + for i, tt := range indentErrorTests { + slice := make([]uint8, 0) + buf := bytes.NewBuffer(slice) + if err := Indent(buf, []uint8(tt.in), "", ""); err != nil { + if !reflect.DeepEqual(err, tt.err) { + t.Errorf("#%d: Indent: %#v", i, err) + continue + } + } + } +} + +func TestNextValueBig(t *testing.T) { + initBig() + var scan scanner + item, rest, err := nextValue(jsonBig, &scan) + if err != nil { + t.Fatalf("nextValue: %s", err) + } + if len(item) != len(jsonBig) || &item[0] != &jsonBig[0] { + t.Errorf("invalid item: %d %d", len(item), len(jsonBig)) + } + if len(rest) != 0 { + t.Errorf("invalid rest: %d", len(rest)) + } + + item, rest, err = nextValue(append(jsonBig, "HELLO WORLD"...), &scan) + if err != nil { + t.Fatalf("nextValue extra: %s", err) + } + if len(item) != len(jsonBig) { + t.Errorf("invalid item: %d %d", len(item), len(jsonBig)) + } + if string(rest) != "HELLO WORLD" { + t.Errorf("invalid rest: %d", len(rest)) + } +} + +var benchScan scanner + +func BenchmarkSkipValue(b *testing.B) { + initBig() + for i := 0; i < b.N; i++ { + nextValue(jsonBig, &benchScan) + } + b.SetBytes(int64(len(jsonBig))) +} + +func diff(t *testing.T, a, b []byte) { + for i := 0; ; i++ { + if i >= len(a) || i >= len(b) || a[i] != b[i] { + j := i - 10 + if j < 0 { + j = 0 + } + t.Errorf("diverge at %d: «%s» vs «%s»", i, trim(a[j:]), trim(b[j:])) + return + } + } +} + +func trim(b []byte) []byte { + if len(b) > 20 { + return b[0:20] + } + return b +} + +// Generate a random JSON object. + +var jsonBig []byte + +func initBig() { + n := 10000 + if testing.Short() { + n = 100 + } + b, err := Marshal(genValue(n)) + if err != nil { + panic(err) + } + jsonBig = b +} + +func genValue(n int) interface{} { + if n > 1 { + switch rand.Intn(2) { + case 0: + return genArray(n) + case 1: + return genMap(n) + } + } + switch rand.Intn(3) { + case 0: + return rand.Intn(2) == 0 + case 1: + return rand.NormFloat64() + case 2: + return genString(30) + } + panic("unreachable") +} + +func genString(stddev float64) string { + n := int(math.Abs(rand.NormFloat64()*stddev + stddev/2)) + c := make([]rune, n) + for i := range c { + f := math.Abs(rand.NormFloat64()*64 + 32) + if f > 0x10ffff { + f = 0x10ffff + } + c[i] = rune(f) + } + return string(c) +} + +func genArray(n int) []interface{} { + f := int(math.Abs(rand.NormFloat64()) * math.Min(10, float64(n/2))) + if f > n { + f = n + } + if f < 1 { + f = 1 + } + x := make([]interface{}, f) + for i := range x { + x[i] = genValue(((i+1)*n)/f - (i*n)/f) + } + return x +} + +func genMap(n int) map[string]interface{} { + f := int(math.Abs(rand.NormFloat64()) * math.Min(10, float64(n/2))) + if f > n { + f = n + } + if n > 0 && f == 0 { + f = 1 + } + x := make(map[string]interface{}) + for i := 0; i < f; i++ { + x[genString(10)] = genValue(((i+1)*n)/f - (i*n)/f) + } + return x +} diff --git a/src/encoding/json/stream.go b/src/encoding/json/stream.go new file mode 100644 index 000000000..9566ecadc --- /dev/null +++ b/src/encoding/json/stream.go @@ -0,0 +1,200 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import ( + "bytes" + "errors" + "io" +) + +// A Decoder reads and decodes JSON objects from an input stream. +type Decoder struct { + r io.Reader + buf []byte + d decodeState + scan scanner + err error +} + +// NewDecoder returns a new decoder that reads from r. +// +// The decoder introduces its own buffering and may +// read data from r beyond the JSON values requested. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r: r} +} + +// UseNumber causes the Decoder to unmarshal a number into an interface{} as a +// Number instead of as a float64. +func (dec *Decoder) UseNumber() { dec.d.useNumber = true } + +// Decode reads the next JSON-encoded value from its +// input and stores it in the value pointed to by v. +// +// See the documentation for Unmarshal for details about +// the conversion of JSON into a Go value. +func (dec *Decoder) Decode(v interface{}) error { + if dec.err != nil { + return dec.err + } + + n, err := dec.readValue() + if err != nil { + return err + } + + // Don't save err from unmarshal into dec.err: + // the connection is still usable since we read a complete JSON + // object from it before the error happened. + dec.d.init(dec.buf[0:n]) + err = dec.d.unmarshal(v) + + // Slide rest of data down. + rest := copy(dec.buf, dec.buf[n:]) + dec.buf = dec.buf[0:rest] + + return err +} + +// Buffered returns a reader of the data remaining in the Decoder's +// buffer. The reader is valid until the next call to Decode. +func (dec *Decoder) Buffered() io.Reader { + return bytes.NewReader(dec.buf) +} + +// readValue reads a JSON value into dec.buf. +// It returns the length of the encoding. +func (dec *Decoder) readValue() (int, error) { + dec.scan.reset() + + scanp := 0 + var err error +Input: + for { + // Look in the buffer for a new value. + for i, c := range dec.buf[scanp:] { + dec.scan.bytes++ + v := dec.scan.step(&dec.scan, int(c)) + if v == scanEnd { + scanp += i + break Input + } + // scanEnd is delayed one byte. + // We might block trying to get that byte from src, + // so instead invent a space byte. + if (v == scanEndObject || v == scanEndArray) && dec.scan.step(&dec.scan, ' ') == scanEnd { + scanp += i + 1 + break Input + } + if v == scanError { + dec.err = dec.scan.err + return 0, dec.scan.err + } + } + scanp = len(dec.buf) + + // Did the last read have an error? + // Delayed until now to allow buffer scan. + if err != nil { + if err == io.EOF { + if dec.scan.step(&dec.scan, ' ') == scanEnd { + break Input + } + if nonSpace(dec.buf) { + err = io.ErrUnexpectedEOF + } + } + dec.err = err + return 0, err + } + + // Make room to read more into the buffer. + const minRead = 512 + if cap(dec.buf)-len(dec.buf) < minRead { + newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead) + copy(newBuf, dec.buf) + dec.buf = newBuf + } + + // Read. Delay error for next iteration (after scan). + var n int + n, err = dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)]) + dec.buf = dec.buf[0 : len(dec.buf)+n] + } + return scanp, nil +} + +func nonSpace(b []byte) bool { + for _, c := range b { + if !isSpace(rune(c)) { + return true + } + } + return false +} + +// An Encoder writes JSON objects to an output stream. +type Encoder struct { + w io.Writer + err error +} + +// NewEncoder returns a new encoder that writes to w. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{w: w} +} + +// Encode writes the JSON encoding of v to the stream, +// followed by a newline character. +// +// See the documentation for Marshal for details about the +// conversion of Go values to JSON. +func (enc *Encoder) Encode(v interface{}) error { + if enc.err != nil { + return enc.err + } + e := newEncodeState() + err := e.marshal(v) + if err != nil { + return err + } + + // Terminate each value with a newline. + // This makes the output look a little nicer + // when debugging, and some kind of space + // is required if the encoded value was a number, + // so that the reader knows there aren't more + // digits coming. + e.WriteByte('\n') + + if _, err = enc.w.Write(e.Bytes()); err != nil { + enc.err = err + } + encodeStatePool.Put(e) + return err +} + +// RawMessage is a raw encoded JSON object. +// It implements Marshaler and Unmarshaler and can +// be used to delay JSON decoding or precompute a JSON encoding. +type RawMessage []byte + +// MarshalJSON returns *m as the JSON encoding of m. +func (m *RawMessage) MarshalJSON() ([]byte, error) { + return *m, nil +} + +// UnmarshalJSON sets *m to a copy of data. +func (m *RawMessage) UnmarshalJSON(data []byte) error { + if m == nil { + return errors.New("json.RawMessage: UnmarshalJSON on nil pointer") + } + *m = append((*m)[0:0], data...) + return nil +} + +var _ Marshaler = (*RawMessage)(nil) +var _ Unmarshaler = (*RawMessage)(nil) diff --git a/src/encoding/json/stream_test.go b/src/encoding/json/stream_test.go new file mode 100644 index 000000000..b562e8769 --- /dev/null +++ b/src/encoding/json/stream_test.go @@ -0,0 +1,206 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import ( + "bytes" + "io/ioutil" + "net" + "reflect" + "strings" + "testing" +) + +// Test values for the stream test. +// One of each JSON kind. +var streamTest = []interface{}{ + 0.1, + "hello", + nil, + true, + false, + []interface{}{"a", "b", "c"}, + map[string]interface{}{"K": "Kelvin", "ß": "long s"}, + 3.14, // another value to make sure something can follow map +} + +var streamEncoded = `0.1 +"hello" +null +true +false +["a","b","c"] +{"ß":"long s","K":"Kelvin"} +3.14 +` + +func TestEncoder(t *testing.T) { + for i := 0; i <= len(streamTest); i++ { + var buf bytes.Buffer + enc := NewEncoder(&buf) + for j, v := range streamTest[0:i] { + if err := enc.Encode(v); err != nil { + t.Fatalf("encode #%d: %v", j, err) + } + } + if have, want := buf.String(), nlines(streamEncoded, i); have != want { + t.Errorf("encoding %d items: mismatch", i) + diff(t, []byte(have), []byte(want)) + break + } + } +} + +func TestDecoder(t *testing.T) { + for i := 0; i <= len(streamTest); i++ { + // Use stream without newlines as input, + // just to stress the decoder even more. + // Our test input does not include back-to-back numbers. + // Otherwise stripping the newlines would + // merge two adjacent JSON values. + var buf bytes.Buffer + for _, c := range nlines(streamEncoded, i) { + if c != '\n' { + buf.WriteRune(c) + } + } + out := make([]interface{}, i) + dec := NewDecoder(&buf) + for j := range out { + if err := dec.Decode(&out[j]); err != nil { + t.Fatalf("decode #%d/%d: %v", j, i, err) + } + } + if !reflect.DeepEqual(out, streamTest[0:i]) { + t.Errorf("decoding %d items: mismatch", i) + for j := range out { + if !reflect.DeepEqual(out[j], streamTest[j]) { + t.Errorf("#%d: have %v want %v", j, out[j], streamTest[j]) + } + } + break + } + } +} + +func TestDecoderBuffered(t *testing.T) { + r := strings.NewReader(`{"Name": "Gopher"} extra `) + var m struct { + Name string + } + d := NewDecoder(r) + err := d.Decode(&m) + if err != nil { + t.Fatal(err) + } + if m.Name != "Gopher" { + t.Errorf("Name = %q; want Gopher", m.Name) + } + rest, err := ioutil.ReadAll(d.Buffered()) + if err != nil { + t.Fatal(err) + } + if g, w := string(rest), " extra "; g != w { + t.Errorf("Remaining = %q; want %q", g, w) + } +} + +func nlines(s string, n int) string { + if n <= 0 { + return "" + } + for i, c := range s { + if c == '\n' { + if n--; n == 0 { + return s[0 : i+1] + } + } + } + return s +} + +func TestRawMessage(t *testing.T) { + // TODO(rsc): Should not need the * in *RawMessage + var data struct { + X float64 + Id *RawMessage + Y float32 + } + const raw = `["\u0056",null]` + const msg = `{"X":0.1,"Id":["\u0056",null],"Y":0.2}` + err := Unmarshal([]byte(msg), &data) + if err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if string([]byte(*data.Id)) != raw { + t.Fatalf("Raw mismatch: have %#q want %#q", []byte(*data.Id), raw) + } + b, err := Marshal(&data) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if string(b) != msg { + t.Fatalf("Marshal: have %#q want %#q", b, msg) + } +} + +func TestNullRawMessage(t *testing.T) { + // TODO(rsc): Should not need the * in *RawMessage + var data struct { + X float64 + Id *RawMessage + Y float32 + } + data.Id = new(RawMessage) + const msg = `{"X":0.1,"Id":null,"Y":0.2}` + err := Unmarshal([]byte(msg), &data) + if err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if data.Id != nil { + t.Fatalf("Raw mismatch: have non-nil, want nil") + } + b, err := Marshal(&data) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if string(b) != msg { + t.Fatalf("Marshal: have %#q want %#q", b, msg) + } +} + +var blockingTests = []string{ + `{"x": 1}`, + `[1, 2, 3]`, +} + +func TestBlocking(t *testing.T) { + for _, enc := range blockingTests { + r, w := net.Pipe() + go w.Write([]byte(enc)) + var val interface{} + + // If Decode reads beyond what w.Write writes above, + // it will block, and the test will deadlock. + if err := NewDecoder(r).Decode(&val); err != nil { + t.Errorf("decoding %s: %v", enc, err) + } + r.Close() + w.Close() + } +} + +func BenchmarkEncoderEncode(b *testing.B) { + b.ReportAllocs() + type T struct { + X, Y string + } + v := &T{"foo", "bar"} + for i := 0; i < b.N; i++ { + if err := NewEncoder(ioutil.Discard).Encode(v); err != nil { + b.Fatal(err) + } + } +} diff --git a/src/encoding/json/tagkey_test.go b/src/encoding/json/tagkey_test.go new file mode 100644 index 000000000..23e71c752 --- /dev/null +++ b/src/encoding/json/tagkey_test.go @@ -0,0 +1,115 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import ( + "testing" +) + +type basicLatin2xTag struct { + V string `json:"$%-/"` +} + +type basicLatin3xTag struct { + V string `json:"0123456789"` +} + +type basicLatin4xTag struct { + V string `json:"ABCDEFGHIJKLMO"` +} + +type basicLatin5xTag struct { + V string `json:"PQRSTUVWXYZ_"` +} + +type basicLatin6xTag struct { + V string `json:"abcdefghijklmno"` +} + +type basicLatin7xTag struct { + V string `json:"pqrstuvwxyz"` +} + +type miscPlaneTag struct { + V string `json:"色は匂へど"` +} + +type percentSlashTag struct { + V string `json:"text/html%"` // http://golang.org/issue/2718 +} + +type punctuationTag struct { + V string `json:"!#$%&()*+-./:<=>?@[]^_{|}~"` // http://golang.org/issue/3546 +} + +type emptyTag struct { + W string +} + +type misnamedTag struct { + X string `jsom:"Misnamed"` +} + +type badFormatTag struct { + Y string `:"BadFormat"` +} + +type badCodeTag struct { + Z string `json:" !\"#&'()*+,."` +} + +type spaceTag struct { + Q string `json:"With space"` +} + +type unicodeTag struct { + W string `json:"Ελλάδα"` +} + +var structTagObjectKeyTests = []struct { + raw interface{} + value string + key string +}{ + {basicLatin2xTag{"2x"}, "2x", "$%-/"}, + {basicLatin3xTag{"3x"}, "3x", "0123456789"}, + {basicLatin4xTag{"4x"}, "4x", "ABCDEFGHIJKLMO"}, + {basicLatin5xTag{"5x"}, "5x", "PQRSTUVWXYZ_"}, + {basicLatin6xTag{"6x"}, "6x", "abcdefghijklmno"}, + {basicLatin7xTag{"7x"}, "7x", "pqrstuvwxyz"}, + {miscPlaneTag{"いろはにほへと"}, "いろはにほへと", "色は匂へど"}, + {emptyTag{"Pour Moi"}, "Pour Moi", "W"}, + {misnamedTag{"Animal Kingdom"}, "Animal Kingdom", "X"}, + {badFormatTag{"Orfevre"}, "Orfevre", "Y"}, + {badCodeTag{"Reliable Man"}, "Reliable Man", "Z"}, + {percentSlashTag{"brut"}, "brut", "text/html%"}, + {punctuationTag{"Union Rags"}, "Union Rags", "!#$%&()*+-./:<=>?@[]^_{|}~"}, + {spaceTag{"Perreddu"}, "Perreddu", "With space"}, + {unicodeTag{"Loukanikos"}, "Loukanikos", "Ελλάδα"}, +} + +func TestStructTagObjectKey(t *testing.T) { + for _, tt := range structTagObjectKeyTests { + b, err := Marshal(tt.raw) + if err != nil { + t.Fatalf("Marshal(%#q) failed: %v", tt.raw, err) + } + var f interface{} + err = Unmarshal(b, &f) + if err != nil { + t.Fatalf("Unmarshal(%#q) failed: %v", b, err) + } + for i, v := range f.(map[string]interface{}) { + switch i { + case tt.key: + if s, ok := v.(string); !ok || s != tt.value { + t.Fatalf("Unexpected value: %#q, want %v", s, tt.value) + } + default: + t.Fatalf("Unexpected key: %#q, from %#q", i, b) + } + } + } +} diff --git a/src/encoding/json/tags.go b/src/encoding/json/tags.go new file mode 100644 index 000000000..c38fd5102 --- /dev/null +++ b/src/encoding/json/tags.go @@ -0,0 +1,44 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import ( + "strings" +) + +// tagOptions is the string following a comma in a struct field's "json" +// tag, or the empty string. It does not include the leading comma. +type tagOptions string + +// parseTag splits a struct field's json tag into its name and +// comma-separated options. +func parseTag(tag string) (string, tagOptions) { + if idx := strings.Index(tag, ","); idx != -1 { + return tag[:idx], tagOptions(tag[idx+1:]) + } + return tag, tagOptions("") +} + +// Contains reports whether a comma-separated list of options +// contains a particular substr flag. substr must be surrounded by a +// string boundary or commas. +func (o tagOptions) Contains(optionName string) bool { + if len(o) == 0 { + return false + } + s := string(o) + for s != "" { + var next string + i := strings.Index(s, ",") + if i >= 0 { + s, next = s[:i], s[i+1:] + } + if s == optionName { + return true + } + s = next + } + return false +} diff --git a/src/encoding/json/tags_test.go b/src/encoding/json/tags_test.go new file mode 100644 index 000000000..91fb18831 --- /dev/null +++ b/src/encoding/json/tags_test.go @@ -0,0 +1,28 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import ( + "testing" +) + +func TestTagParsing(t *testing.T) { + name, opts := parseTag("field,foobar,foo") + if name != "field" { + t.Fatalf("name = %q, want field", name) + } + for _, tt := range []struct { + opt string + want bool + }{ + {"foobar", true}, + {"foo", true}, + {"bar", false}, + } { + if opts.Contains(tt.opt) != tt.want { + t.Errorf("Contains(%q) = %v", tt.opt, !tt.want) + } + } +} diff --git a/src/encoding/json/testdata/code.json.gz b/src/encoding/json/testdata/code.json.gz Binary files differnew file mode 100644 index 000000000..0e2895b53 --- /dev/null +++ b/src/encoding/json/testdata/code.json.gz diff --git a/src/encoding/pem/pem.go b/src/encoding/pem/pem.go new file mode 100644 index 000000000..8ff7ee8c3 --- /dev/null +++ b/src/encoding/pem/pem.go @@ -0,0 +1,277 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package pem implements the PEM data encoding, which originated in Privacy +// Enhanced Mail. The most common use of PEM encoding today is in TLS keys and +// certificates. See RFC 1421. +package pem + +import ( + "bytes" + "encoding/base64" + "io" + "sort" +) + +// A Block represents a PEM encoded structure. +// +// The encoded form is: +// -----BEGIN Type----- +// Headers +// base64-encoded Bytes +// -----END Type----- +// where Headers is a possibly empty sequence of Key: Value lines. +type Block struct { + Type string // The type, taken from the preamble (i.e. "RSA PRIVATE KEY"). + Headers map[string]string // Optional headers. + Bytes []byte // The decoded bytes of the contents. Typically a DER encoded ASN.1 structure. +} + +// getLine results the first \r\n or \n delineated line from the given byte +// array. The line does not include trailing whitespace or the trailing new +// line bytes. The remainder of the byte array (also not including the new line +// bytes) is also returned and this will always be smaller than the original +// argument. +func getLine(data []byte) (line, rest []byte) { + i := bytes.Index(data, []byte{'\n'}) + var j int + if i < 0 { + i = len(data) + j = i + } else { + j = i + 1 + if i > 0 && data[i-1] == '\r' { + i-- + } + } + return bytes.TrimRight(data[0:i], " \t"), data[j:] +} + +// removeWhitespace returns a copy of its input with all spaces, tab and +// newline characters removed. +func removeWhitespace(data []byte) []byte { + result := make([]byte, len(data)) + n := 0 + + for _, b := range data { + if b == ' ' || b == '\t' || b == '\r' || b == '\n' { + continue + } + result[n] = b + n++ + } + + return result[0:n] +} + +var pemStart = []byte("\n-----BEGIN ") +var pemEnd = []byte("\n-----END ") +var pemEndOfLine = []byte("-----") + +// Decode will find the next PEM formatted block (certificate, private key +// etc) in the input. It returns that block and the remainder of the input. If +// no PEM data is found, p is nil and the whole of the input is returned in +// rest. +func Decode(data []byte) (p *Block, rest []byte) { + // pemStart begins with a newline. However, at the very beginning of + // the byte array, we'll accept the start string without it. + rest = data + if bytes.HasPrefix(data, pemStart[1:]) { + rest = rest[len(pemStart)-1 : len(data)] + } else if i := bytes.Index(data, pemStart); i >= 0 { + rest = rest[i+len(pemStart) : len(data)] + } else { + return nil, data + } + + typeLine, rest := getLine(rest) + if !bytes.HasSuffix(typeLine, pemEndOfLine) { + return decodeError(data, rest) + } + typeLine = typeLine[0 : len(typeLine)-len(pemEndOfLine)] + + p = &Block{ + Headers: make(map[string]string), + Type: string(typeLine), + } + + for { + // This loop terminates because getLine's second result is + // always smaller than its argument. + if len(rest) == 0 { + return nil, data + } + line, next := getLine(rest) + + i := bytes.Index(line, []byte{':'}) + if i == -1 { + break + } + + // TODO(agl): need to cope with values that spread across lines. + key, val := line[0:i], line[i+1:] + key = bytes.TrimSpace(key) + val = bytes.TrimSpace(val) + p.Headers[string(key)] = string(val) + rest = next + } + + i := bytes.Index(rest, pemEnd) + if i < 0 { + return decodeError(data, rest) + } + base64Data := removeWhitespace(rest[0:i]) + + p.Bytes = make([]byte, base64.StdEncoding.DecodedLen(len(base64Data))) + n, err := base64.StdEncoding.Decode(p.Bytes, base64Data) + if err != nil { + return decodeError(data, rest) + } + p.Bytes = p.Bytes[0:n] + + _, rest = getLine(rest[i+len(pemEnd):]) + + return +} + +func decodeError(data, rest []byte) (*Block, []byte) { + // If we get here then we have rejected a likely looking, but + // ultimately invalid PEM block. We need to start over from a new + // position. We have consumed the preamble line and will have consumed + // any lines which could be header lines. However, a valid preamble + // line is not a valid header line, therefore we cannot have consumed + // the preamble line for the any subsequent block. Thus, we will always + // find any valid block, no matter what bytes precede it. + // + // For example, if the input is + // + // -----BEGIN MALFORMED BLOCK----- + // junk that may look like header lines + // or data lines, but no END line + // + // -----BEGIN ACTUAL BLOCK----- + // realdata + // -----END ACTUAL BLOCK----- + // + // we've failed to parse using the first BEGIN line + // and now will try again, using the second BEGIN line. + p, rest := Decode(rest) + if p == nil { + rest = data + } + return p, rest +} + +const pemLineLength = 64 + +type lineBreaker struct { + line [pemLineLength]byte + used int + out io.Writer +} + +func (l *lineBreaker) Write(b []byte) (n int, err error) { + if l.used+len(b) < pemLineLength { + copy(l.line[l.used:], b) + l.used += len(b) + return len(b), nil + } + + n, err = l.out.Write(l.line[0:l.used]) + if err != nil { + return + } + excess := pemLineLength - l.used + l.used = 0 + + n, err = l.out.Write(b[0:excess]) + if err != nil { + return + } + + n, err = l.out.Write([]byte{'\n'}) + if err != nil { + return + } + + return l.Write(b[excess:]) +} + +func (l *lineBreaker) Close() (err error) { + if l.used > 0 { + _, err = l.out.Write(l.line[0:l.used]) + if err != nil { + return + } + _, err = l.out.Write([]byte{'\n'}) + } + + return +} + +func writeHeader(out io.Writer, k, v string) error { + _, err := out.Write([]byte(k + ": " + v + "\n")) + return err +} + +func Encode(out io.Writer, b *Block) error { + if _, err := out.Write(pemStart[1:]); err != nil { + return err + } + if _, err := out.Write([]byte(b.Type + "-----\n")); err != nil { + return err + } + + if len(b.Headers) > 0 { + const procType = "Proc-Type" + h := make([]string, 0, len(b.Headers)) + hasProcType := false + for k := range b.Headers { + if k == procType { + hasProcType = true + continue + } + h = append(h, k) + } + // The Proc-Type header must be written first. + // See RFC 1421, section 4.6.1.1 + if hasProcType { + if err := writeHeader(out, procType, b.Headers[procType]); err != nil { + return err + } + } + // For consistency of output, write other headers sorted by key. + sort.Strings(h) + for _, k := range h { + if err := writeHeader(out, k, b.Headers[k]); err != nil { + return err + } + } + if _, err := out.Write([]byte{'\n'}); err != nil { + return err + } + } + + var breaker lineBreaker + breaker.out = out + + b64 := base64.NewEncoder(base64.StdEncoding, &breaker) + if _, err := b64.Write(b.Bytes); err != nil { + return err + } + b64.Close() + breaker.Close() + + if _, err := out.Write(pemEnd[1:]); err != nil { + return err + } + _, err := out.Write([]byte(b.Type + "-----\n")) + return err +} + +func EncodeToMemory(b *Block) []byte { + var buf bytes.Buffer + Encode(&buf, b) + return buf.Bytes() +} diff --git a/src/encoding/pem/pem_test.go b/src/encoding/pem/pem_test.go new file mode 100644 index 000000000..ccce42cf1 --- /dev/null +++ b/src/encoding/pem/pem_test.go @@ -0,0 +1,404 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pem + +import ( + "bytes" + "reflect" + "testing" +) + +type GetLineTest struct { + in, out1, out2 string +} + +var getLineTests = []GetLineTest{ + {"abc", "abc", ""}, + {"abc\r", "abc\r", ""}, + {"abc\n", "abc", ""}, + {"abc\r\n", "abc", ""}, + {"abc\nd", "abc", "d"}, + {"abc\r\nd", "abc", "d"}, + {"\nabc", "", "abc"}, + {"\r\nabc", "", "abc"}, +} + +func TestGetLine(t *testing.T) { + for i, test := range getLineTests { + x, y := getLine([]byte(test.in)) + if string(x) != test.out1 || string(y) != test.out2 { + t.Errorf("#%d got:%+v,%+v want:%s,%s", i, x, y, test.out1, test.out2) + } + } +} + +func TestDecode(t *testing.T) { + result, remainder := Decode([]byte(pemData)) + if !reflect.DeepEqual(result, certificate) { + t.Errorf("#0 got:%#v want:%#v", result, certificate) + } + result, remainder = Decode(remainder) + if !reflect.DeepEqual(result, privateKey) { + t.Errorf("#1 got:%#v want:%#v", result, privateKey) + } + result, _ = Decode([]byte(pemPrivateKey2)) + if !reflect.DeepEqual(result, privateKey2) { + t.Errorf("#2 got:%#v want:%#v", result, privateKey2) + } +} + +func TestEncode(t *testing.T) { + r := EncodeToMemory(privateKey2) + if string(r) != pemPrivateKey2 { + t.Errorf("got:%s want:%s", r, pemPrivateKey2) + } +} + +type lineBreakerTest struct { + in, out string +} + +const sixtyFourCharString = "0123456789012345678901234567890123456789012345678901234567890123" + +var lineBreakerTests = []lineBreakerTest{ + {"", ""}, + {"a", "a\n"}, + {"ab", "ab\n"}, + {sixtyFourCharString, sixtyFourCharString + "\n"}, + {sixtyFourCharString + "X", sixtyFourCharString + "\nX\n"}, + {sixtyFourCharString + sixtyFourCharString, sixtyFourCharString + "\n" + sixtyFourCharString + "\n"}, +} + +func TestLineBreaker(t *testing.T) { + for i, test := range lineBreakerTests { + buf := new(bytes.Buffer) + var breaker lineBreaker + breaker.out = buf + _, err := breaker.Write([]byte(test.in)) + if err != nil { + t.Errorf("#%d: error from Write: %s", i, err) + continue + } + err = breaker.Close() + if err != nil { + t.Errorf("#%d: error from Close: %s", i, err) + continue + } + + if string(buf.Bytes()) != test.out { + t.Errorf("#%d: got:%s want:%s", i, string(buf.Bytes()), test.out) + } + } + + for i, test := range lineBreakerTests { + buf := new(bytes.Buffer) + var breaker lineBreaker + breaker.out = buf + + for i := 0; i < len(test.in); i++ { + _, err := breaker.Write([]byte(test.in[i : i+1])) + if err != nil { + t.Errorf("#%d: error from Write (byte by byte): %s", i, err) + continue + } + } + err := breaker.Close() + if err != nil { + t.Errorf("#%d: error from Close (byte by byte): %s", i, err) + continue + } + + if string(buf.Bytes()) != test.out { + t.Errorf("#%d: (byte by byte) got:%s want:%s", i, string(buf.Bytes()), test.out) + } + } +} + +var pemData = `verify return:0 +-----BEGIN CERTIFICATE----- +sdlfkjskldfj + -----BEGIN CERTIFICATE----- +--- +Certificate chain + 0 s:/C=AU/ST=Somewhere/L=Someplace/O=Foo Bar/CN=foo.example.com + i:/C=ZA/O=CA Inc./CN=CA Inc +-----BEGIN CERTIFICATE----- +testing +-----BEGIN CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIID6TCCA1ICAQEwDQYJKoZIhvcNAQEFBQAwgYsxCzAJBgNVBAYTAlVTMRMwEQYD +VQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRQwEgYDVQQK +EwtHb29nbGUgSW5jLjEMMAoGA1UECxMDRW5nMQwwCgYDVQQDEwNhZ2wxHTAbBgkq +hkiG9w0BCQEWDmFnbEBnb29nbGUuY29tMB4XDTA5MDkwOTIyMDU0M1oXDTEwMDkw +OTIyMDU0M1owajELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAf +BgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEjMCEGA1UEAxMaZXVyb3Bh +LnNmby5jb3JwLmdvb2dsZS5jb20wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIK +AoICAQC6pgYt7/EibBDumASF+S0qvqdL/f+nouJw2T1Qc8GmXF/iiUcrsgzh/Fd8 +pDhz/T96Qg9IyR4ztuc2MXrmPra+zAuSf5bevFReSqvpIt8Duv0HbDbcqs/XKPfB +uMDe+of7a9GCywvAZ4ZUJcp0thqD9fKTTjUWOBzHY1uNE4RitrhmJCrbBGXbJ249 +bvgmb7jgdInH2PU7PT55hujvOoIsQW2osXBFRur4pF1wmVh4W4lTLD6pjfIMUcML +ICHEXEN73PDic8KS3EtNYCwoIld+tpIBjE1QOb1KOyuJBNW6Esw9ALZn7stWdYcE +qAwvv20egN2tEXqj7Q4/1ccyPZc3PQgC3FJ8Be2mtllM+80qf4dAaQ/fWvCtOrQ5 +pnfe9juQvCo8Y0VGlFcrSys/MzSg9LJ/24jZVgzQved/Qupsp89wVidwIzjt+WdS +fyWfH0/v1aQLvu5cMYuW//C0W2nlYziL5blETntM8My2ybNARy3ICHxCBv2RNtPI +WQVm+E9/W5rwh2IJR4DHn2LHwUVmT/hHNTdBLl5Uhwr4Wc7JhE7AVqb14pVNz1lr +5jxsp//ncIwftb7mZQ3DF03Yna+jJhpzx8CQoeLT6aQCHyzmH68MrHHT4MALPyUs +Pomjn71GNTtDeWAXibjCgdL6iHACCF6Htbl0zGlG0OAK+bdn0QIDAQABMA0GCSqG +SIb3DQEBBQUAA4GBAOKnQDtqBV24vVqvesL5dnmyFpFPXBn3WdFfwD6DzEb21UVG +5krmJiu+ViipORJPGMkgoL6BjU21XI95VQbun5P8vvg8Z+FnFsvRFY3e1CCzAVQY +ZsUkLw2I7zI/dNlWdB8Xp7v+3w9sX5N3J/WuJ1KOO5m26kRlHQo7EzT3974g +-----END CERTIFICATE----- + 1 s:/C=ZA/O=Ca Inc./CN=CA Inc + +-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: DES-EDE3-CBC,80C7C7A09690757A + +eQp5ZkH6CyHBz7BZfUPxyLCCmftsBJ7HlqGb8Ld21cSwnzWZ4/SIlhyrUtsfw7VR +2TTwA+odo9ex7GdxOTaH8oZFumIRoiEjHsk8U7Bhntp+ekkPP79xunnN7hb7hkhr +yGDQZgA7s2cQHQ71v3gwT2BACAft26jCjbM1wgNzBnJ8M0Rzn68YWqaPtdBu8qb/ +zVR5JB1mnqvTSbFsfF5yMc6o2WQ9jJCl6KypnMl+BpL+dlvdjYVK4l9lYsB1Hs3d ++zDBbWxos818zzhS8/y6eIfiSG27cqrbhURbmgiSfDXjncK4m/pLcQ7mmBL6mFOr +3Pj4jepzgOiFRL6MKE//h62fZvI1ErYr8VunHEykgKNhChDvb1RO6LEfqKBu+Ivw +TB6fBhW3TCLMnVPYVoYwA+fHNTmZZm8BEonlIMfI+KktjWUg4Oia+NI6vKcPpFox +hSnlGgCtvfEaq5/H4kHJp95eOpnFsLviw2seHNkz/LxJMRP1X428+DpYW/QD/0JU +tJSuC/q9FUHL6RI3u/Asrv8pCb4+D7i1jW/AMIdJTtycOGsbPxQA7yHMWujHmeb1 +BTiHcL3s3KrJu1vDVrshvxfnz71KTeNnZH8UbOqT5i7fPGyXtY1XJddcbI/Q6tXf +wHFsZc20TzSdsVLBtwksUacpbDogcEVMctnNrB8FIrB3vZEv9Q0Z1VeY7nmTpF+6 +a+z2P7acL7j6A6Pr3+q8P9CPiPC7zFonVzuVPyB8GchGR2hytyiOVpuD9+k8hcuw +ZWAaUoVtWIQ52aKS0p19G99hhb+IVANC4akkdHV4SP8i7MVNZhfUmg== +-----END RSA PRIVATE KEY-----` + +var certificate = &Block{Type: "CERTIFICATE", + Headers: map[string]string{}, + Bytes: []uint8{0x30, 0x82, 0x3, 0xe9, 0x30, 0x82, 0x3, 0x52, 0x2, 0x1, + 0x1, 0x30, 0xd, 0x6, 0x9, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0xd, + 0x1, 0x1, 0x5, 0x5, 0x0, 0x30, 0x81, 0x8b, 0x31, 0xb, 0x30, + 0x9, 0x6, 0x3, 0x55, 0x4, 0x6, 0x13, 0x2, 0x55, 0x53, 0x31, + 0x13, 0x30, 0x11, 0x6, 0x3, 0x55, 0x4, 0x8, 0x13, 0xa, 0x43, + 0x61, 0x6c, 0x69, 0x66, 0x6f, 0x72, 0x6e, 0x69, 0x61, 0x31, + 0x16, 0x30, 0x14, 0x6, 0x3, 0x55, 0x4, 0x7, 0x13, 0xd, 0x53, + 0x61, 0x6e, 0x20, 0x46, 0x72, 0x61, 0x6e, 0x63, 0x69, 0x73, + 0x63, 0x6f, 0x31, 0x14, 0x30, 0x12, 0x6, 0x3, 0x55, 0x4, 0xa, + 0x13, 0xb, 0x47, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x20, 0x49, + 0x6e, 0x63, 0x2e, 0x31, 0xc, 0x30, 0xa, 0x6, 0x3, 0x55, 0x4, + 0xb, 0x13, 0x3, 0x45, 0x6e, 0x67, 0x31, 0xc, 0x30, 0xa, 0x6, + 0x3, 0x55, 0x4, 0x3, 0x13, 0x3, 0x61, 0x67, 0x6c, 0x31, 0x1d, + 0x30, 0x1b, 0x6, 0x9, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0xd, 0x1, + 0x9, 0x1, 0x16, 0xe, 0x61, 0x67, 0x6c, 0x40, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x30, 0x1e, 0x17, + 0xd, 0x30, 0x39, 0x30, 0x39, 0x30, 0x39, 0x32, 0x32, 0x30, + 0x35, 0x34, 0x33, 0x5a, 0x17, 0xd, 0x31, 0x30, 0x30, 0x39, + 0x30, 0x39, 0x32, 0x32, 0x30, 0x35, 0x34, 0x33, 0x5a, 0x30, + 0x6a, 0x31, 0xb, 0x30, 0x9, 0x6, 0x3, 0x55, 0x4, 0x6, 0x13, + 0x2, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x6, 0x3, 0x55, 0x4, + 0x8, 0x13, 0xa, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x6, 0x3, 0x55, 0x4, 0xa, + 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, + 0x20, 0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, + 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x31, 0x23, 0x30, 0x21, + 0x6, 0x3, 0x55, 0x4, 0x3, 0x13, 0x1a, 0x65, 0x75, 0x72, 0x6f, + 0x70, 0x61, 0x2e, 0x73, 0x66, 0x6f, 0x2e, 0x63, 0x6f, 0x72, + 0x70, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x63, + 0x6f, 0x6d, 0x30, 0x82, 0x2, 0x22, 0x30, 0xd, 0x6, 0x9, 0x2a, + 0x86, 0x48, 0x86, 0xf7, 0xd, 0x1, 0x1, 0x1, 0x5, 0x0, 0x3, + 0x82, 0x2, 0xf, 0x0, 0x30, 0x82, 0x2, 0xa, 0x2, 0x82, 0x2, 0x1, + 0x0, 0xba, 0xa6, 0x6, 0x2d, 0xef, 0xf1, 0x22, 0x6c, 0x10, 0xee, + 0x98, 0x4, 0x85, 0xf9, 0x2d, 0x2a, 0xbe, 0xa7, 0x4b, 0xfd, + 0xff, 0xa7, 0xa2, 0xe2, 0x70, 0xd9, 0x3d, 0x50, 0x73, 0xc1, + 0xa6, 0x5c, 0x5f, 0xe2, 0x89, 0x47, 0x2b, 0xb2, 0xc, 0xe1, + 0xfc, 0x57, 0x7c, 0xa4, 0x38, 0x73, 0xfd, 0x3f, 0x7a, 0x42, + 0xf, 0x48, 0xc9, 0x1e, 0x33, 0xb6, 0xe7, 0x36, 0x31, 0x7a, + 0xe6, 0x3e, 0xb6, 0xbe, 0xcc, 0xb, 0x92, 0x7f, 0x96, 0xde, + 0xbc, 0x54, 0x5e, 0x4a, 0xab, 0xe9, 0x22, 0xdf, 0x3, 0xba, + 0xfd, 0x7, 0x6c, 0x36, 0xdc, 0xaa, 0xcf, 0xd7, 0x28, 0xf7, + 0xc1, 0xb8, 0xc0, 0xde, 0xfa, 0x87, 0xfb, 0x6b, 0xd1, 0x82, + 0xcb, 0xb, 0xc0, 0x67, 0x86, 0x54, 0x25, 0xca, 0x74, 0xb6, + 0x1a, 0x83, 0xf5, 0xf2, 0x93, 0x4e, 0x35, 0x16, 0x38, 0x1c, + 0xc7, 0x63, 0x5b, 0x8d, 0x13, 0x84, 0x62, 0xb6, 0xb8, 0x66, + 0x24, 0x2a, 0xdb, 0x4, 0x65, 0xdb, 0x27, 0x6e, 0x3d, 0x6e, + 0xf8, 0x26, 0x6f, 0xb8, 0xe0, 0x74, 0x89, 0xc7, 0xd8, 0xf5, + 0x3b, 0x3d, 0x3e, 0x79, 0x86, 0xe8, 0xef, 0x3a, 0x82, 0x2c, + 0x41, 0x6d, 0xa8, 0xb1, 0x70, 0x45, 0x46, 0xea, 0xf8, 0xa4, + 0x5d, 0x70, 0x99, 0x58, 0x78, 0x5b, 0x89, 0x53, 0x2c, 0x3e, + 0xa9, 0x8d, 0xf2, 0xc, 0x51, 0xc3, 0xb, 0x20, 0x21, 0xc4, 0x5c, + 0x43, 0x7b, 0xdc, 0xf0, 0xe2, 0x73, 0xc2, 0x92, 0xdc, 0x4b, + 0x4d, 0x60, 0x2c, 0x28, 0x22, 0x57, 0x7e, 0xb6, 0x92, 0x1, + 0x8c, 0x4d, 0x50, 0x39, 0xbd, 0x4a, 0x3b, 0x2b, 0x89, 0x4, + 0xd5, 0xba, 0x12, 0xcc, 0x3d, 0x0, 0xb6, 0x67, 0xee, 0xcb, + 0x56, 0x75, 0x87, 0x4, 0xa8, 0xc, 0x2f, 0xbf, 0x6d, 0x1e, 0x80, + 0xdd, 0xad, 0x11, 0x7a, 0xa3, 0xed, 0xe, 0x3f, 0xd5, 0xc7, + 0x32, 0x3d, 0x97, 0x37, 0x3d, 0x8, 0x2, 0xdc, 0x52, 0x7c, 0x5, + 0xed, 0xa6, 0xb6, 0x59, 0x4c, 0xfb, 0xcd, 0x2a, 0x7f, 0x87, + 0x40, 0x69, 0xf, 0xdf, 0x5a, 0xf0, 0xad, 0x3a, 0xb4, 0x39, + 0xa6, 0x77, 0xde, 0xf6, 0x3b, 0x90, 0xbc, 0x2a, 0x3c, 0x63, + 0x45, 0x46, 0x94, 0x57, 0x2b, 0x4b, 0x2b, 0x3f, 0x33, 0x34, + 0xa0, 0xf4, 0xb2, 0x7f, 0xdb, 0x88, 0xd9, 0x56, 0xc, 0xd0, + 0xbd, 0xe7, 0x7f, 0x42, 0xea, 0x6c, 0xa7, 0xcf, 0x70, 0x56, + 0x27, 0x70, 0x23, 0x38, 0xed, 0xf9, 0x67, 0x52, 0x7f, 0x25, + 0x9f, 0x1f, 0x4f, 0xef, 0xd5, 0xa4, 0xb, 0xbe, 0xee, 0x5c, + 0x31, 0x8b, 0x96, 0xff, 0xf0, 0xb4, 0x5b, 0x69, 0xe5, 0x63, + 0x38, 0x8b, 0xe5, 0xb9, 0x44, 0x4e, 0x7b, 0x4c, 0xf0, 0xcc, + 0xb6, 0xc9, 0xb3, 0x40, 0x47, 0x2d, 0xc8, 0x8, 0x7c, 0x42, 0x6, + 0xfd, 0x91, 0x36, 0xd3, 0xc8, 0x59, 0x5, 0x66, 0xf8, 0x4f, + 0x7f, 0x5b, 0x9a, 0xf0, 0x87, 0x62, 0x9, 0x47, 0x80, 0xc7, + 0x9f, 0x62, 0xc7, 0xc1, 0x45, 0x66, 0x4f, 0xf8, 0x47, 0x35, + 0x37, 0x41, 0x2e, 0x5e, 0x54, 0x87, 0xa, 0xf8, 0x59, 0xce, + 0xc9, 0x84, 0x4e, 0xc0, 0x56, 0xa6, 0xf5, 0xe2, 0x95, 0x4d, + 0xcf, 0x59, 0x6b, 0xe6, 0x3c, 0x6c, 0xa7, 0xff, 0xe7, 0x70, + 0x8c, 0x1f, 0xb5, 0xbe, 0xe6, 0x65, 0xd, 0xc3, 0x17, 0x4d, + 0xd8, 0x9d, 0xaf, 0xa3, 0x26, 0x1a, 0x73, 0xc7, 0xc0, 0x90, + 0xa1, 0xe2, 0xd3, 0xe9, 0xa4, 0x2, 0x1f, 0x2c, 0xe6, 0x1f, + 0xaf, 0xc, 0xac, 0x71, 0xd3, 0xe0, 0xc0, 0xb, 0x3f, 0x25, 0x2c, + 0x3e, 0x89, 0xa3, 0x9f, 0xbd, 0x46, 0x35, 0x3b, 0x43, 0x79, + 0x60, 0x17, 0x89, 0xb8, 0xc2, 0x81, 0xd2, 0xfa, 0x88, 0x70, + 0x2, 0x8, 0x5e, 0x87, 0xb5, 0xb9, 0x74, 0xcc, 0x69, 0x46, 0xd0, + 0xe0, 0xa, 0xf9, 0xb7, 0x67, 0xd1, 0x2, 0x3, 0x1, 0x0, 0x1, + 0x30, 0xd, 0x6, 0x9, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0xd, 0x1, + 0x1, 0x5, 0x5, 0x0, 0x3, 0x81, 0x81, 0x0, 0xe2, 0xa7, 0x40, + 0x3b, 0x6a, 0x5, 0x5d, 0xb8, 0xbd, 0x5a, 0xaf, 0x7a, 0xc2, + 0xf9, 0x76, 0x79, 0xb2, 0x16, 0x91, 0x4f, 0x5c, 0x19, 0xf7, + 0x59, 0xd1, 0x5f, 0xc0, 0x3e, 0x83, 0xcc, 0x46, 0xf6, 0xd5, + 0x45, 0x46, 0xe6, 0x4a, 0xe6, 0x26, 0x2b, 0xbe, 0x56, 0x28, + 0xa9, 0x39, 0x12, 0x4f, 0x18, 0xc9, 0x20, 0xa0, 0xbe, 0x81, + 0x8d, 0x4d, 0xb5, 0x5c, 0x8f, 0x79, 0x55, 0x6, 0xee, 0x9f, + 0x93, 0xfc, 0xbe, 0xf8, 0x3c, 0x67, 0xe1, 0x67, 0x16, 0xcb, + 0xd1, 0x15, 0x8d, 0xde, 0xd4, 0x20, 0xb3, 0x1, 0x54, 0x18, + 0x66, 0xc5, 0x24, 0x2f, 0xd, 0x88, 0xef, 0x32, 0x3f, 0x74, + 0xd9, 0x56, 0x74, 0x1f, 0x17, 0xa7, 0xbb, 0xfe, 0xdf, 0xf, + 0x6c, 0x5f, 0x93, 0x77, 0x27, 0xf5, 0xae, 0x27, 0x52, 0x8e, + 0x3b, 0x99, 0xb6, 0xea, 0x44, 0x65, 0x1d, 0xa, 0x3b, 0x13, + 0x34, 0xf7, 0xf7, 0xbe, 0x20, + }, +} + +var privateKey = &Block{Type: "RSA PRIVATE KEY", + Headers: map[string]string{"DEK-Info": "DES-EDE3-CBC,80C7C7A09690757A", "Proc-Type": "4,ENCRYPTED"}, + Bytes: []uint8{0x79, 0xa, 0x79, 0x66, 0x41, 0xfa, 0xb, + 0x21, 0xc1, 0xcf, 0xb0, 0x59, 0x7d, 0x43, 0xf1, 0xc8, 0xb0, + 0x82, 0x99, 0xfb, 0x6c, 0x4, 0x9e, 0xc7, 0x96, 0xa1, 0x9b, + 0xf0, 0xb7, 0x76, 0xd5, 0xc4, 0xb0, 0x9f, 0x35, 0x99, 0xe3, + 0xf4, 0x88, 0x96, 0x1c, 0xab, 0x52, 0xdb, 0x1f, 0xc3, 0xb5, + 0x51, 0xd9, 0x34, 0xf0, 0x3, 0xea, 0x1d, 0xa3, 0xd7, 0xb1, + 0xec, 0x67, 0x71, 0x39, 0x36, 0x87, 0xf2, 0x86, 0x45, 0xba, + 0x62, 0x11, 0xa2, 0x21, 0x23, 0x1e, 0xc9, 0x3c, 0x53, 0xb0, + 0x61, 0x9e, 0xda, 0x7e, 0x7a, 0x49, 0xf, 0x3f, 0xbf, 0x71, + 0xba, 0x79, 0xcd, 0xee, 0x16, 0xfb, 0x86, 0x48, 0x6b, 0xc8, + 0x60, 0xd0, 0x66, 0x0, 0x3b, 0xb3, 0x67, 0x10, 0x1d, 0xe, + 0xf5, 0xbf, 0x78, 0x30, 0x4f, 0x60, 0x40, 0x8, 0x7, 0xed, + 0xdb, 0xa8, 0xc2, 0x8d, 0xb3, 0x35, 0xc2, 0x3, 0x73, 0x6, + 0x72, 0x7c, 0x33, 0x44, 0x73, 0x9f, 0xaf, 0x18, 0x5a, 0xa6, + 0x8f, 0xb5, 0xd0, 0x6e, 0xf2, 0xa6, 0xff, 0xcd, 0x54, 0x79, + 0x24, 0x1d, 0x66, 0x9e, 0xab, 0xd3, 0x49, 0xb1, 0x6c, 0x7c, + 0x5e, 0x72, 0x31, 0xce, 0xa8, 0xd9, 0x64, 0x3d, 0x8c, 0x90, + 0xa5, 0xe8, 0xac, 0xa9, 0x9c, 0xc9, 0x7e, 0x6, 0x92, 0xfe, + 0x76, 0x5b, 0xdd, 0x8d, 0x85, 0x4a, 0xe2, 0x5f, 0x65, 0x62, + 0xc0, 0x75, 0x1e, 0xcd, 0xdd, 0xfb, 0x30, 0xc1, 0x6d, 0x6c, + 0x68, 0xb3, 0xcd, 0x7c, 0xcf, 0x38, 0x52, 0xf3, 0xfc, 0xba, + 0x78, 0x87, 0xe2, 0x48, 0x6d, 0xbb, 0x72, 0xaa, 0xdb, 0x85, + 0x44, 0x5b, 0x9a, 0x8, 0x92, 0x7c, 0x35, 0xe3, 0x9d, 0xc2, + 0xb8, 0x9b, 0xfa, 0x4b, 0x71, 0xe, 0xe6, 0x98, 0x12, 0xfa, + 0x98, 0x53, 0xab, 0xdc, 0xf8, 0xf8, 0x8d, 0xea, 0x73, 0x80, + 0xe8, 0x85, 0x44, 0xbe, 0x8c, 0x28, 0x4f, 0xff, 0x87, 0xad, + 0x9f, 0x66, 0xf2, 0x35, 0x12, 0xb6, 0x2b, 0xf1, 0x5b, 0xa7, + 0x1c, 0x4c, 0xa4, 0x80, 0xa3, 0x61, 0xa, 0x10, 0xef, 0x6f, + 0x54, 0x4e, 0xe8, 0xb1, 0x1f, 0xa8, 0xa0, 0x6e, 0xf8, 0x8b, + 0xf0, 0x4c, 0x1e, 0x9f, 0x6, 0x15, 0xb7, 0x4c, 0x22, 0xcc, + 0x9d, 0x53, 0xd8, 0x56, 0x86, 0x30, 0x3, 0xe7, 0xc7, 0x35, + 0x39, 0x99, 0x66, 0x6f, 0x1, 0x12, 0x89, 0xe5, 0x20, 0xc7, + 0xc8, 0xf8, 0xa9, 0x2d, 0x8d, 0x65, 0x20, 0xe0, 0xe8, 0x9a, + 0xf8, 0xd2, 0x3a, 0xbc, 0xa7, 0xf, 0xa4, 0x5a, 0x31, 0x85, + 0x29, 0xe5, 0x1a, 0x0, 0xad, 0xbd, 0xf1, 0x1a, 0xab, 0x9f, + 0xc7, 0xe2, 0x41, 0xc9, 0xa7, 0xde, 0x5e, 0x3a, 0x99, 0xc5, + 0xb0, 0xbb, 0xe2, 0xc3, 0x6b, 0x1e, 0x1c, 0xd9, 0x33, 0xfc, + 0xbc, 0x49, 0x31, 0x13, 0xf5, 0x5f, 0x8d, 0xbc, 0xf8, 0x3a, + 0x58, 0x5b, 0xf4, 0x3, 0xff, 0x42, 0x54, 0xb4, 0x94, 0xae, + 0xb, 0xfa, 0xbd, 0x15, 0x41, 0xcb, 0xe9, 0x12, 0x37, 0xbb, + 0xf0, 0x2c, 0xae, 0xff, 0x29, 0x9, 0xbe, 0x3e, 0xf, 0xb8, + 0xb5, 0x8d, 0x6f, 0xc0, 0x30, 0x87, 0x49, 0x4e, 0xdc, 0x9c, + 0x38, 0x6b, 0x1b, 0x3f, 0x14, 0x0, 0xef, 0x21, 0xcc, 0x5a, + 0xe8, 0xc7, 0x99, 0xe6, 0xf5, 0x5, 0x38, 0x87, 0x70, 0xbd, + 0xec, 0xdc, 0xaa, 0xc9, 0xbb, 0x5b, 0xc3, 0x56, 0xbb, 0x21, + 0xbf, 0x17, 0xe7, 0xcf, 0xbd, 0x4a, 0x4d, 0xe3, 0x67, 0x64, + 0x7f, 0x14, 0x6c, 0xea, 0x93, 0xe6, 0x2e, 0xdf, 0x3c, 0x6c, + 0x97, 0xb5, 0x8d, 0x57, 0x25, 0xd7, 0x5c, 0x6c, 0x8f, 0xd0, + 0xea, 0xd5, 0xdf, 0xc0, 0x71, 0x6c, 0x65, 0xcd, 0xb4, 0x4f, + 0x34, 0x9d, 0xb1, 0x52, 0xc1, 0xb7, 0x9, 0x2c, 0x51, 0xa7, + 0x29, 0x6c, 0x3a, 0x20, 0x70, 0x45, 0x4c, 0x72, 0xd9, 0xcd, + 0xac, 0x1f, 0x5, 0x22, 0xb0, 0x77, 0xbd, 0x91, 0x2f, 0xf5, + 0xd, 0x19, 0xd5, 0x57, 0x98, 0xee, 0x79, 0x93, 0xa4, 0x5f, + 0xba, 0x6b, 0xec, 0xf6, 0x3f, 0xb6, 0x9c, 0x2f, 0xb8, 0xfa, + 0x3, 0xa3, 0xeb, 0xdf, 0xea, 0xbc, 0x3f, 0xd0, 0x8f, 0x88, + 0xf0, 0xbb, 0xcc, 0x5a, 0x27, 0x57, 0x3b, 0x95, 0x3f, 0x20, + 0x7c, 0x19, 0xc8, 0x46, 0x47, 0x68, 0x72, 0xb7, 0x28, 0x8e, + 0x56, 0x9b, 0x83, 0xf7, 0xe9, 0x3c, 0x85, 0xcb, 0xb0, 0x65, + 0x60, 0x1a, 0x52, 0x85, 0x6d, 0x58, 0x84, 0x39, 0xd9, 0xa2, + 0x92, 0xd2, 0x9d, 0x7d, 0x1b, 0xdf, 0x61, 0x85, 0xbf, 0x88, + 0x54, 0x3, 0x42, 0xe1, 0xa9, 0x24, 0x74, 0x75, 0x78, 0x48, + 0xff, 0x22, 0xec, 0xc5, 0x4d, 0x66, 0x17, 0xd4, 0x9a, + }, +} + +var privateKey2 = &Block{ + Type: "RSA PRIVATE KEY", + Headers: map[string]string{ + "Proc-Type": "4,ENCRYPTED", + "DEK-Info": "AES-128-CBC,BFCD243FEDBB40A4AA6DDAA1335473A4", + "Content-Domain": "RFC822", + }, + Bytes: []uint8{ + 0xa8, 0x35, 0xcc, 0x2b, 0xb9, 0xcb, 0x21, 0xab, 0xc0, + 0x9d, 0x76, 0x61, 0x0, 0xf4, 0x81, 0xad, 0x69, 0xd2, + 0xc0, 0x42, 0x41, 0x3b, 0xe4, 0x3c, 0xaf, 0x59, 0x5e, + 0x6d, 0x2a, 0x3c, 0x9c, 0xa1, 0xa4, 0x5e, 0x68, 0x37, + 0xc4, 0x8c, 0x70, 0x1c, 0xa9, 0x18, 0xe6, 0xc2, 0x2b, + 0x8a, 0x91, 0xdc, 0x2d, 0x1f, 0x8, 0x23, 0x39, 0xf1, + 0x4b, 0x8b, 0x1b, 0x2f, 0x46, 0xb, 0xb2, 0x26, 0xba, + 0x4f, 0x40, 0x80, 0x39, 0xc4, 0xb1, 0xcb, 0x3b, 0xb4, + 0x65, 0x3f, 0x1b, 0xb2, 0xf7, 0x8, 0xd2, 0xc6, 0xd5, + 0xa8, 0x9f, 0x23, 0x69, 0xb6, 0x3d, 0xf9, 0xac, 0x1c, + 0xb3, 0x13, 0x87, 0x64, 0x4, 0x37, 0xdb, 0x40, 0xc8, + 0x82, 0xc, 0xd0, 0xf8, 0x21, 0x7c, 0xdc, 0xbd, 0x9, 0x4, + 0x20, 0x16, 0xb0, 0x97, 0xe2, 0x6d, 0x56, 0x1d, 0xe3, + 0xec, 0xf0, 0xfc, 0xe2, 0x56, 0xad, 0xa4, 0x3, 0x70, + 0x6d, 0x63, 0x3c, 0x1, 0xbe, 0x3e, 0x28, 0x38, 0x6f, + 0xc0, 0xe6, 0xfd, 0x85, 0xd1, 0x53, 0xa8, 0x9b, 0xcb, + 0xd4, 0x4, 0xb1, 0x73, 0xb9, 0x73, 0x32, 0xd6, 0x7a, + 0xc6, 0x29, 0x25, 0xa5, 0xda, 0x17, 0x93, 0x7a, 0x10, + 0xe8, 0x41, 0xfb, 0xa5, 0x17, 0x20, 0xf8, 0x4e, 0xe9, + 0xe3, 0x8f, 0x51, 0x20, 0x13, 0xbb, 0xde, 0xb7, 0x93, + 0xae, 0x13, 0x8a, 0xf6, 0x9, 0xf4, 0xa6, 0x41, 0xe0, + 0x2b, 0x51, 0x1a, 0x30, 0x38, 0xd, 0xb1, 0x3b, 0x67, + 0x87, 0x64, 0xf5, 0xca, 0x32, 0x67, 0xd1, 0xc8, 0xa5, + 0x3d, 0x23, 0x72, 0xc4, 0x6, 0xaf, 0x8f, 0x7b, 0x26, + 0xac, 0x3c, 0x75, 0x91, 0xa1, 0x0, 0x13, 0xc6, 0x5c, + 0x49, 0xd5, 0x3c, 0xe7, 0xb2, 0xb2, 0x99, 0xe0, 0xd5, + 0x25, 0xfa, 0xe2, 0x12, 0x80, 0x37, 0x85, 0xcf, 0x92, + 0xca, 0x1b, 0x9f, 0xf3, 0x4e, 0xd8, 0x80, 0xef, 0x3c, + 0xce, 0xcd, 0xf5, 0x90, 0x9e, 0xf9, 0xa7, 0xb2, 0xc, + 0x49, 0x4, 0xf1, 0x9, 0x8f, 0xea, 0x63, 0xd2, 0x70, + 0xbb, 0x86, 0xbf, 0x34, 0xab, 0xb2, 0x3, 0xb1, 0x59, + 0x33, 0x16, 0x17, 0xb0, 0xdb, 0x77, 0x38, 0xf4, 0xb4, + 0x94, 0xb, 0x25, 0x16, 0x7e, 0x22, 0xd4, 0xf9, 0x22, + 0xb9, 0x78, 0xa3, 0x4, 0x84, 0x4, 0xd2, 0xda, 0x84, + 0x2d, 0x63, 0xdd, 0xf8, 0x50, 0x6a, 0xf6, 0xe3, 0xf5, + 0x65, 0x40, 0x7c, 0xa9, + }, +} + +var pemPrivateKey2 = `-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +Content-Domain: RFC822 +DEK-Info: AES-128-CBC,BFCD243FEDBB40A4AA6DDAA1335473A4 + +qDXMK7nLIavAnXZhAPSBrWnSwEJBO+Q8r1lebSo8nKGkXmg3xIxwHKkY5sIripHc +LR8IIznxS4sbL0YLsia6T0CAOcSxyzu0ZT8bsvcI0sbVqJ8jabY9+awcsxOHZAQ3 +20DIggzQ+CF83L0JBCAWsJfibVYd4+zw/OJWraQDcG1jPAG+Pig4b8Dm/YXRU6ib +y9QEsXO5czLWesYpJaXaF5N6EOhB+6UXIPhO6eOPUSATu963k64TivYJ9KZB4CtR +GjA4DbE7Z4dk9coyZ9HIpT0jcsQGr497Jqw8dZGhABPGXEnVPOeyspng1SX64hKA +N4XPksobn/NO2IDvPM7N9ZCe+aeyDEkE8QmP6mPScLuGvzSrsgOxWTMWF7Dbdzj0 +tJQLJRZ+ItT5Irl4owSEBNLahC1j3fhQavbj9WVAfKk= +-----END RSA PRIVATE KEY----- +` diff --git a/src/encoding/xml/atom_test.go b/src/encoding/xml/atom_test.go new file mode 100644 index 000000000..a71284312 --- /dev/null +++ b/src/encoding/xml/atom_test.go @@ -0,0 +1,56 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import "time" + +var atomValue = &Feed{ + XMLName: Name{"http://www.w3.org/2005/Atom", "feed"}, + Title: "Example Feed", + Link: []Link{{Href: "http://example.org/"}}, + Updated: ParseTime("2003-12-13T18:30:02Z"), + Author: Person{Name: "John Doe"}, + Id: "urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6", + + Entry: []Entry{ + { + Title: "Atom-Powered Robots Run Amok", + Link: []Link{{Href: "http://example.org/2003/12/13/atom03"}}, + Id: "urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a", + Updated: ParseTime("2003-12-13T18:30:02Z"), + Summary: NewText("Some text."), + }, + }, +} + +var atomXml = `` + + `<feed xmlns="http://www.w3.org/2005/Atom" updated="2003-12-13T18:30:02Z">` + + `<title>Example Feed</title>` + + `<id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</id>` + + `<link href="http://example.org/"></link>` + + `<author><name>John Doe</name><uri></uri><email></email></author>` + + `<entry>` + + `<title>Atom-Powered Robots Run Amok</title>` + + `<id>urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a</id>` + + `<link href="http://example.org/2003/12/13/atom03"></link>` + + `<updated>2003-12-13T18:30:02Z</updated>` + + `<author><name></name><uri></uri><email></email></author>` + + `<summary>Some text.</summary>` + + `</entry>` + + `</feed>` + +func ParseTime(str string) time.Time { + t, err := time.Parse(time.RFC3339, str) + if err != nil { + panic(err) + } + return t +} + +func NewText(text string) Text { + return Text{ + Body: text, + } +} diff --git a/src/encoding/xml/example_test.go b/src/encoding/xml/example_test.go new file mode 100644 index 000000000..becedd583 --- /dev/null +++ b/src/encoding/xml/example_test.go @@ -0,0 +1,151 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml_test + +import ( + "encoding/xml" + "fmt" + "os" +) + +func ExampleMarshalIndent() { + type Address struct { + City, State string + } + type Person struct { + XMLName xml.Name `xml:"person"` + Id int `xml:"id,attr"` + FirstName string `xml:"name>first"` + LastName string `xml:"name>last"` + Age int `xml:"age"` + Height float32 `xml:"height,omitempty"` + Married bool + Address + Comment string `xml:",comment"` + } + + v := &Person{Id: 13, FirstName: "John", LastName: "Doe", Age: 42} + v.Comment = " Need more details. " + v.Address = Address{"Hanga Roa", "Easter Island"} + + output, err := xml.MarshalIndent(v, " ", " ") + if err != nil { + fmt.Printf("error: %v\n", err) + } + + os.Stdout.Write(output) + // Output: + // <person id="13"> + // <name> + // <first>John</first> + // <last>Doe</last> + // </name> + // <age>42</age> + // <Married>false</Married> + // <City>Hanga Roa</City> + // <State>Easter Island</State> + // <!-- Need more details. --> + // </person> +} + +func ExampleEncoder() { + type Address struct { + City, State string + } + type Person struct { + XMLName xml.Name `xml:"person"` + Id int `xml:"id,attr"` + FirstName string `xml:"name>first"` + LastName string `xml:"name>last"` + Age int `xml:"age"` + Height float32 `xml:"height,omitempty"` + Married bool + Address + Comment string `xml:",comment"` + } + + v := &Person{Id: 13, FirstName: "John", LastName: "Doe", Age: 42} + v.Comment = " Need more details. " + v.Address = Address{"Hanga Roa", "Easter Island"} + + enc := xml.NewEncoder(os.Stdout) + enc.Indent(" ", " ") + if err := enc.Encode(v); err != nil { + fmt.Printf("error: %v\n", err) + } + + // Output: + // <person id="13"> + // <name> + // <first>John</first> + // <last>Doe</last> + // </name> + // <age>42</age> + // <Married>false</Married> + // <City>Hanga Roa</City> + // <State>Easter Island</State> + // <!-- Need more details. --> + // </person> +} + +// This example demonstrates unmarshaling an XML excerpt into a value with +// some preset fields. Note that the Phone field isn't modified and that +// the XML <Company> element is ignored. Also, the Groups field is assigned +// considering the element path provided in its tag. +func ExampleUnmarshal() { + type Email struct { + Where string `xml:"where,attr"` + Addr string + } + type Address struct { + City, State string + } + type Result struct { + XMLName xml.Name `xml:"Person"` + Name string `xml:"FullName"` + Phone string + Email []Email + Groups []string `xml:"Group>Value"` + Address + } + v := Result{Name: "none", Phone: "none"} + + data := ` + <Person> + <FullName>Grace R. Emlin</FullName> + <Company>Example Inc.</Company> + <Email where="home"> + <Addr>gre@example.com</Addr> + </Email> + <Email where='work'> + <Addr>gre@work.com</Addr> + </Email> + <Group> + <Value>Friends</Value> + <Value>Squash</Value> + </Group> + <City>Hanga Roa</City> + <State>Easter Island</State> + </Person> + ` + err := xml.Unmarshal([]byte(data), &v) + if err != nil { + fmt.Printf("error: %v", err) + return + } + fmt.Printf("XMLName: %#v\n", v.XMLName) + fmt.Printf("Name: %q\n", v.Name) + fmt.Printf("Phone: %q\n", v.Phone) + fmt.Printf("Email: %v\n", v.Email) + fmt.Printf("Groups: %v\n", v.Groups) + fmt.Printf("Address: %v\n", v.Address) + // Output: + // XMLName: xml.Name{Space:"", Local:"Person"} + // Name: "Grace R. Emlin" + // Phone: "none" + // Email: [{home gre@example.com} {work gre@work.com}] + // Groups: [Friends Squash] + // Address: {Hanga Roa Easter Island} +} diff --git a/src/encoding/xml/marshal.go b/src/encoding/xml/marshal.go new file mode 100644 index 000000000..8c6342013 --- /dev/null +++ b/src/encoding/xml/marshal.go @@ -0,0 +1,938 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import ( + "bufio" + "bytes" + "encoding" + "fmt" + "io" + "reflect" + "strconv" + "strings" +) + +const ( + // A generic XML header suitable for use with the output of Marshal. + // This is not automatically added to any output of this package, + // it is provided as a convenience. + Header = `<?xml version="1.0" encoding="UTF-8"?>` + "\n" +) + +// Marshal returns the XML encoding of v. +// +// Marshal handles an array or slice by marshalling each of the elements. +// Marshal handles a pointer by marshalling the value it points at or, if the +// pointer is nil, by writing nothing. Marshal handles an interface value by +// marshalling the value it contains or, if the interface value is nil, by +// writing nothing. Marshal handles all other data by writing one or more XML +// elements containing the data. +// +// The name for the XML elements is taken from, in order of preference: +// - the tag on the XMLName field, if the data is a struct +// - the value of the XMLName field of type xml.Name +// - the tag of the struct field used to obtain the data +// - the name of the struct field used to obtain the data +// - the name of the marshalled type +// +// The XML element for a struct contains marshalled elements for each of the +// exported fields of the struct, with these exceptions: +// - the XMLName field, described above, is omitted. +// - a field with tag "-" is omitted. +// - a field with tag "name,attr" becomes an attribute with +// the given name in the XML element. +// - a field with tag ",attr" becomes an attribute with the +// field name in the XML element. +// - a field with tag ",chardata" is written as character data, +// not as an XML element. +// - a field with tag ",innerxml" is written verbatim, not subject +// to the usual marshalling procedure. +// - a field with tag ",comment" is written as an XML comment, not +// subject to the usual marshalling procedure. It must not contain +// the "--" string within it. +// - a field with a tag including the "omitempty" option is omitted +// if the field value is empty. The empty values are false, 0, any +// nil pointer or interface value, and any array, slice, map, or +// string of length zero. +// - an anonymous struct field is handled as if the fields of its +// value were part of the outer struct. +// +// If a field uses a tag "a>b>c", then the element c will be nested inside +// parent elements a and b. Fields that appear next to each other that name +// the same parent will be enclosed in one XML element. +// +// See MarshalIndent for an example. +// +// Marshal will return an error if asked to marshal a channel, function, or map. +func Marshal(v interface{}) ([]byte, error) { + var b bytes.Buffer + if err := NewEncoder(&b).Encode(v); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// Marshaler is the interface implemented by objects that can marshal +// themselves into valid XML elements. +// +// MarshalXML encodes the receiver as zero or more XML elements. +// By convention, arrays or slices are typically encoded as a sequence +// of elements, one per entry. +// Using start as the element tag is not required, but doing so +// will enable Unmarshal to match the XML elements to the correct +// struct field. +// One common implementation strategy is to construct a separate +// value with a layout corresponding to the desired XML and then +// to encode it using e.EncodeElement. +// Another common strategy is to use repeated calls to e.EncodeToken +// to generate the XML output one token at a time. +// The sequence of encoded tokens must make up zero or more valid +// XML elements. +type Marshaler interface { + MarshalXML(e *Encoder, start StartElement) error +} + +// MarshalerAttr is the interface implemented by objects that can marshal +// themselves into valid XML attributes. +// +// MarshalXMLAttr returns an XML attribute with the encoded value of the receiver. +// Using name as the attribute name is not required, but doing so +// will enable Unmarshal to match the attribute to the correct +// struct field. +// If MarshalXMLAttr returns the zero attribute Attr{}, no attribute +// will be generated in the output. +// MarshalXMLAttr is used only for struct fields with the +// "attr" option in the field tag. +type MarshalerAttr interface { + MarshalXMLAttr(name Name) (Attr, error) +} + +// MarshalIndent works like Marshal, but each XML element begins on a new +// indented line that starts with prefix and is followed by one or more +// copies of indent according to the nesting depth. +func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) { + var b bytes.Buffer + enc := NewEncoder(&b) + enc.Indent(prefix, indent) + if err := enc.Encode(v); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// An Encoder writes XML data to an output stream. +type Encoder struct { + p printer +} + +// NewEncoder returns a new encoder that writes to w. +func NewEncoder(w io.Writer) *Encoder { + e := &Encoder{printer{Writer: bufio.NewWriter(w)}} + e.p.encoder = e + return e +} + +// Indent sets the encoder to generate XML in which each element +// begins on a new indented line that starts with prefix and is followed by +// one or more copies of indent according to the nesting depth. +func (enc *Encoder) Indent(prefix, indent string) { + enc.p.prefix = prefix + enc.p.indent = indent +} + +// Encode writes the XML encoding of v to the stream. +// +// See the documentation for Marshal for details about the conversion +// of Go values to XML. +// +// Encode calls Flush before returning. +func (enc *Encoder) Encode(v interface{}) error { + err := enc.p.marshalValue(reflect.ValueOf(v), nil, nil) + if err != nil { + return err + } + return enc.p.Flush() +} + +// EncodeElement writes the XML encoding of v to the stream, +// using start as the outermost tag in the encoding. +// +// See the documentation for Marshal for details about the conversion +// of Go values to XML. +// +// EncodeElement calls Flush before returning. +func (enc *Encoder) EncodeElement(v interface{}, start StartElement) error { + err := enc.p.marshalValue(reflect.ValueOf(v), nil, &start) + if err != nil { + return err + } + return enc.p.Flush() +} + +var ( + endComment = []byte("-->") + endProcInst = []byte("?>") + endDirective = []byte(">") +) + +// EncodeToken writes the given XML token to the stream. +// It returns an error if StartElement and EndElement tokens are not properly matched. +// +// EncodeToken does not call Flush, because usually it is part of a larger operation +// such as Encode or EncodeElement (or a custom Marshaler's MarshalXML invoked +// during those), and those will call Flush when finished. +// Callers that create an Encoder and then invoke EncodeToken directly, without +// using Encode or EncodeElement, need to call Flush when finished to ensure +// that the XML is written to the underlying writer. +// +// EncodeToken allows writing a ProcInst with Target set to "xml" only as the first token +// in the stream. +func (enc *Encoder) EncodeToken(t Token) error { + p := &enc.p + switch t := t.(type) { + case StartElement: + if err := p.writeStart(&t); err != nil { + return err + } + case EndElement: + if err := p.writeEnd(t.Name); err != nil { + return err + } + case CharData: + EscapeText(p, t) + case Comment: + if bytes.Contains(t, endComment) { + return fmt.Errorf("xml: EncodeToken of Comment containing --> marker") + } + p.WriteString("<!--") + p.Write(t) + p.WriteString("-->") + return p.cachedWriteError() + case ProcInst: + // First token to be encoded which is also a ProcInst with target of xml + // is the xml declaration. The only ProcInst where target of xml is allowed. + if t.Target == "xml" && p.Buffered() != 0 { + return fmt.Errorf("xml: EncodeToken of ProcInst xml target only valid for xml declaration, first token encoded") + } + if !isNameString(t.Target) { + return fmt.Errorf("xml: EncodeToken of ProcInst with invalid Target") + } + if bytes.Contains(t.Inst, endProcInst) { + return fmt.Errorf("xml: EncodeToken of ProcInst containing ?> marker") + } + p.WriteString("<?") + p.WriteString(t.Target) + if len(t.Inst) > 0 { + p.WriteByte(' ') + p.Write(t.Inst) + } + p.WriteString("?>") + case Directive: + if bytes.Contains(t, endDirective) { + return fmt.Errorf("xml: EncodeToken of Directive containing > marker") + } + p.WriteString("<!") + p.Write(t) + p.WriteString(">") + } + return p.cachedWriteError() +} + +// Flush flushes any buffered XML to the underlying writer. +// See the EncodeToken documentation for details about when it is necessary. +func (enc *Encoder) Flush() error { + return enc.p.Flush() +} + +type printer struct { + *bufio.Writer + encoder *Encoder + seq int + indent string + prefix string + depth int + indentedIn bool + putNewline bool + attrNS map[string]string // map prefix -> name space + attrPrefix map[string]string // map name space -> prefix + prefixes []string + tags []Name +} + +// createAttrPrefix finds the name space prefix attribute to use for the given name space, +// defining a new prefix if necessary. It returns the prefix. +func (p *printer) createAttrPrefix(url string) string { + if prefix := p.attrPrefix[url]; prefix != "" { + return prefix + } + + // The "http://www.w3.org/XML/1998/namespace" name space is predefined as "xml" + // and must be referred to that way. + // (The "http://www.w3.org/2000/xmlns/" name space is also predefined as "xmlns", + // but users should not be trying to use that one directly - that's our job.) + if url == xmlURL { + return "xml" + } + + // Need to define a new name space. + if p.attrPrefix == nil { + p.attrPrefix = make(map[string]string) + p.attrNS = make(map[string]string) + } + + // Pick a name. We try to use the final element of the path + // but fall back to _. + prefix := strings.TrimRight(url, "/") + if i := strings.LastIndex(prefix, "/"); i >= 0 { + prefix = prefix[i+1:] + } + if prefix == "" || !isName([]byte(prefix)) || strings.Contains(prefix, ":") { + prefix = "_" + } + if strings.HasPrefix(prefix, "xml") { + // xmlanything is reserved. + prefix = "_" + prefix + } + if p.attrNS[prefix] != "" { + // Name is taken. Find a better one. + for p.seq++; ; p.seq++ { + if id := prefix + "_" + strconv.Itoa(p.seq); p.attrNS[id] == "" { + prefix = id + break + } + } + } + + p.attrPrefix[url] = prefix + p.attrNS[prefix] = url + + p.WriteString(`xmlns:`) + p.WriteString(prefix) + p.WriteString(`="`) + EscapeText(p, []byte(url)) + p.WriteString(`" `) + + p.prefixes = append(p.prefixes, prefix) + + return prefix +} + +// deleteAttrPrefix removes an attribute name space prefix. +func (p *printer) deleteAttrPrefix(prefix string) { + delete(p.attrPrefix, p.attrNS[prefix]) + delete(p.attrNS, prefix) +} + +func (p *printer) markPrefix() { + p.prefixes = append(p.prefixes, "") +} + +func (p *printer) popPrefix() { + for len(p.prefixes) > 0 { + prefix := p.prefixes[len(p.prefixes)-1] + p.prefixes = p.prefixes[:len(p.prefixes)-1] + if prefix == "" { + break + } + p.deleteAttrPrefix(prefix) + } +} + +var ( + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + marshalerAttrType = reflect.TypeOf((*MarshalerAttr)(nil)).Elem() + textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() +) + +// marshalValue writes one or more XML elements representing val. +// If val was obtained from a struct field, finfo must have its details. +func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplate *StartElement) error { + if startTemplate != nil && startTemplate.Name.Local == "" { + return fmt.Errorf("xml: EncodeElement of StartElement with missing name") + } + + if !val.IsValid() { + return nil + } + if finfo != nil && finfo.flags&fOmitEmpty != 0 && isEmptyValue(val) { + return nil + } + + // Drill into interfaces and pointers. + // This can turn into an infinite loop given a cyclic chain, + // but it matches the Go 1 behavior. + for val.Kind() == reflect.Interface || val.Kind() == reflect.Ptr { + if val.IsNil() { + return nil + } + val = val.Elem() + } + + kind := val.Kind() + typ := val.Type() + + // Check for marshaler. + if val.CanInterface() && typ.Implements(marshalerType) { + return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate)) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(marshalerType) { + return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate)) + } + } + + // Check for text marshaler. + if val.CanInterface() && typ.Implements(textMarshalerType) { + return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate)) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate)) + } + } + + // Slices and arrays iterate over the elements. They do not have an enclosing tag. + if (kind == reflect.Slice || kind == reflect.Array) && typ.Elem().Kind() != reflect.Uint8 { + for i, n := 0, val.Len(); i < n; i++ { + if err := p.marshalValue(val.Index(i), finfo, startTemplate); err != nil { + return err + } + } + return nil + } + + tinfo, err := getTypeInfo(typ) + if err != nil { + return err + } + + // Create start element. + // Precedence for the XML element name is: + // 0. startTemplate + // 1. XMLName field in underlying struct; + // 2. field name/tag in the struct field; and + // 3. type name + var start StartElement + + if startTemplate != nil { + start.Name = startTemplate.Name + start.Attr = append(start.Attr, startTemplate.Attr...) + } else if tinfo.xmlname != nil { + xmlname := tinfo.xmlname + if xmlname.name != "" { + start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name + } else if v, ok := xmlname.value(val).Interface().(Name); ok && v.Local != "" { + start.Name = v + } + } + if start.Name.Local == "" && finfo != nil { + start.Name.Space, start.Name.Local = finfo.xmlns, finfo.name + } + if start.Name.Local == "" { + name := typ.Name() + if name == "" { + return &UnsupportedTypeError{typ} + } + start.Name.Local = name + } + + // Attributes + for i := range tinfo.fields { + finfo := &tinfo.fields[i] + if finfo.flags&fAttr == 0 { + continue + } + fv := finfo.value(val) + name := Name{Space: finfo.xmlns, Local: finfo.name} + + if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) { + continue + } + + if fv.Kind() == reflect.Interface && fv.IsNil() { + continue + } + + if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) { + attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name) + if err != nil { + return err + } + if attr.Name.Local != "" { + start.Attr = append(start.Attr, attr) + } + continue + } + + if fv.CanAddr() { + pv := fv.Addr() + if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) { + attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name) + if err != nil { + return err + } + if attr.Name.Local != "" { + start.Attr = append(start.Attr, attr) + } + continue + } + } + + if fv.CanInterface() && fv.Type().Implements(textMarshalerType) { + text, err := fv.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return err + } + start.Attr = append(start.Attr, Attr{name, string(text)}) + continue + } + + if fv.CanAddr() { + pv := fv.Addr() + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + text, err := pv.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return err + } + start.Attr = append(start.Attr, Attr{name, string(text)}) + continue + } + } + + // Dereference or skip nil pointer, interface values. + switch fv.Kind() { + case reflect.Ptr, reflect.Interface: + if fv.IsNil() { + continue + } + fv = fv.Elem() + } + + s, b, err := p.marshalSimple(fv.Type(), fv) + if err != nil { + return err + } + if b != nil { + s = string(b) + } + start.Attr = append(start.Attr, Attr{name, s}) + } + + if err := p.writeStart(&start); err != nil { + return err + } + + if val.Kind() == reflect.Struct { + err = p.marshalStruct(tinfo, val) + } else { + s, b, err1 := p.marshalSimple(typ, val) + if err1 != nil { + err = err1 + } else if b != nil { + EscapeText(p, b) + } else { + p.EscapeString(s) + } + } + if err != nil { + return err + } + + if err := p.writeEnd(start.Name); err != nil { + return err + } + + return p.cachedWriteError() +} + +// defaultStart returns the default start element to use, +// given the reflect type, field info, and start template. +func defaultStart(typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) StartElement { + var start StartElement + // Precedence for the XML element name is as above, + // except that we do not look inside structs for the first field. + if startTemplate != nil { + start.Name = startTemplate.Name + start.Attr = append(start.Attr, startTemplate.Attr...) + } else if finfo != nil && finfo.name != "" { + start.Name.Local = finfo.name + start.Name.Space = finfo.xmlns + } else if typ.Name() != "" { + start.Name.Local = typ.Name() + } else { + // Must be a pointer to a named type, + // since it has the Marshaler methods. + start.Name.Local = typ.Elem().Name() + } + return start +} + +// marshalInterface marshals a Marshaler interface value. +func (p *printer) marshalInterface(val Marshaler, start StartElement) error { + // Push a marker onto the tag stack so that MarshalXML + // cannot close the XML tags that it did not open. + p.tags = append(p.tags, Name{}) + n := len(p.tags) + + err := val.MarshalXML(p.encoder, start) + if err != nil { + return err + } + + // Make sure MarshalXML closed all its tags. p.tags[n-1] is the mark. + if len(p.tags) > n { + return fmt.Errorf("xml: %s.MarshalXML wrote invalid XML: <%s> not closed", receiverType(val), p.tags[len(p.tags)-1].Local) + } + p.tags = p.tags[:n-1] + return nil +} + +// marshalTextInterface marshals a TextMarshaler interface value. +func (p *printer) marshalTextInterface(val encoding.TextMarshaler, start StartElement) error { + if err := p.writeStart(&start); err != nil { + return err + } + text, err := val.MarshalText() + if err != nil { + return err + } + EscapeText(p, text) + return p.writeEnd(start.Name) +} + +// writeStart writes the given start element. +func (p *printer) writeStart(start *StartElement) error { + if start.Name.Local == "" { + return fmt.Errorf("xml: start tag with no name") + } + + p.tags = append(p.tags, start.Name) + p.markPrefix() + + p.writeIndent(1) + p.WriteByte('<') + p.WriteString(start.Name.Local) + + if start.Name.Space != "" { + p.WriteString(` xmlns="`) + p.EscapeString(start.Name.Space) + p.WriteByte('"') + } + + // Attributes + for _, attr := range start.Attr { + name := attr.Name + if name.Local == "" { + continue + } + p.WriteByte(' ') + if name.Space != "" { + p.WriteString(p.createAttrPrefix(name.Space)) + p.WriteByte(':') + } + p.WriteString(name.Local) + p.WriteString(`="`) + p.EscapeString(attr.Value) + p.WriteByte('"') + } + p.WriteByte('>') + return nil +} + +func (p *printer) writeEnd(name Name) error { + if name.Local == "" { + return fmt.Errorf("xml: end tag with no name") + } + if len(p.tags) == 0 || p.tags[len(p.tags)-1].Local == "" { + return fmt.Errorf("xml: end tag </%s> without start tag", name.Local) + } + if top := p.tags[len(p.tags)-1]; top != name { + if top.Local != name.Local { + return fmt.Errorf("xml: end tag </%s> does not match start tag <%s>", name.Local, top.Local) + } + return fmt.Errorf("xml: end tag </%s> in namespace %s does not match start tag <%s> in namespace %s", name.Local, name.Space, top.Local, top.Space) + } + p.tags = p.tags[:len(p.tags)-1] + + p.writeIndent(-1) + p.WriteByte('<') + p.WriteByte('/') + p.WriteString(name.Local) + p.WriteByte('>') + p.popPrefix() + return nil +} + +func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []byte, error) { + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(val.Int(), 10), nil, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return strconv.FormatUint(val.Uint(), 10), nil, nil + case reflect.Float32, reflect.Float64: + return strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits()), nil, nil + case reflect.String: + return val.String(), nil, nil + case reflect.Bool: + return strconv.FormatBool(val.Bool()), nil, nil + case reflect.Array: + if typ.Elem().Kind() != reflect.Uint8 { + break + } + // [...]byte + var bytes []byte + if val.CanAddr() { + bytes = val.Slice(0, val.Len()).Bytes() + } else { + bytes = make([]byte, val.Len()) + reflect.Copy(reflect.ValueOf(bytes), val) + } + return "", bytes, nil + case reflect.Slice: + if typ.Elem().Kind() != reflect.Uint8 { + break + } + // []byte + return "", val.Bytes(), nil + } + return "", nil, &UnsupportedTypeError{typ} +} + +var ddBytes = []byte("--") + +func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { + s := parentStack{p: p} + for i := range tinfo.fields { + finfo := &tinfo.fields[i] + if finfo.flags&fAttr != 0 { + continue + } + vf := finfo.value(val) + + // Dereference or skip nil pointer, interface values. + switch vf.Kind() { + case reflect.Ptr, reflect.Interface: + if !vf.IsNil() { + vf = vf.Elem() + } + } + + switch finfo.flags & fMode { + case fCharData: + if vf.CanInterface() && vf.Type().Implements(textMarshalerType) { + data, err := vf.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return err + } + Escape(p, data) + continue + } + if vf.CanAddr() { + pv := vf.Addr() + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + data, err := pv.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return err + } + Escape(p, data) + continue + } + } + var scratch [64]byte + switch vf.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + Escape(p, strconv.AppendInt(scratch[:0], vf.Int(), 10)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + Escape(p, strconv.AppendUint(scratch[:0], vf.Uint(), 10)) + case reflect.Float32, reflect.Float64: + Escape(p, strconv.AppendFloat(scratch[:0], vf.Float(), 'g', -1, vf.Type().Bits())) + case reflect.Bool: + Escape(p, strconv.AppendBool(scratch[:0], vf.Bool())) + case reflect.String: + if err := EscapeText(p, []byte(vf.String())); err != nil { + return err + } + case reflect.Slice: + if elem, ok := vf.Interface().([]byte); ok { + if err := EscapeText(p, elem); err != nil { + return err + } + } + } + continue + + case fComment: + k := vf.Kind() + if !(k == reflect.String || k == reflect.Slice && vf.Type().Elem().Kind() == reflect.Uint8) { + return fmt.Errorf("xml: bad type for comment field of %s", val.Type()) + } + if vf.Len() == 0 { + continue + } + p.writeIndent(0) + p.WriteString("<!--") + dashDash := false + dashLast := false + switch k { + case reflect.String: + s := vf.String() + dashDash = strings.Index(s, "--") >= 0 + dashLast = s[len(s)-1] == '-' + if !dashDash { + p.WriteString(s) + } + case reflect.Slice: + b := vf.Bytes() + dashDash = bytes.Index(b, ddBytes) >= 0 + dashLast = b[len(b)-1] == '-' + if !dashDash { + p.Write(b) + } + default: + panic("can't happen") + } + if dashDash { + return fmt.Errorf(`xml: comments must not contain "--"`) + } + if dashLast { + // "--->" is invalid grammar. Make it "- -->" + p.WriteByte(' ') + } + p.WriteString("-->") + continue + + case fInnerXml: + iface := vf.Interface() + switch raw := iface.(type) { + case []byte: + p.Write(raw) + continue + case string: + p.WriteString(raw) + continue + } + + case fElement, fElement | fAny: + if err := s.trim(finfo.parents); err != nil { + return err + } + if len(finfo.parents) > len(s.stack) { + if vf.Kind() != reflect.Ptr && vf.Kind() != reflect.Interface || !vf.IsNil() { + if err := s.push(finfo.parents[len(s.stack):]); err != nil { + return err + } + } + } + } + if err := p.marshalValue(vf, finfo, nil); err != nil { + return err + } + } + s.trim(nil) + return p.cachedWriteError() +} + +// return the bufio Writer's cached write error +func (p *printer) cachedWriteError() error { + _, err := p.Write(nil) + return err +} + +func (p *printer) writeIndent(depthDelta int) { + if len(p.prefix) == 0 && len(p.indent) == 0 { + return + } + if depthDelta < 0 { + p.depth-- + if p.indentedIn { + p.indentedIn = false + return + } + p.indentedIn = false + } + if p.putNewline { + p.WriteByte('\n') + } else { + p.putNewline = true + } + if len(p.prefix) > 0 { + p.WriteString(p.prefix) + } + if len(p.indent) > 0 { + for i := 0; i < p.depth; i++ { + p.WriteString(p.indent) + } + } + if depthDelta > 0 { + p.depth++ + p.indentedIn = true + } +} + +type parentStack struct { + p *printer + stack []string +} + +// trim updates the XML context to match the longest common prefix of the stack +// and the given parents. A closing tag will be written for every parent +// popped. Passing a zero slice or nil will close all the elements. +func (s *parentStack) trim(parents []string) error { + split := 0 + for ; split < len(parents) && split < len(s.stack); split++ { + if parents[split] != s.stack[split] { + break + } + } + for i := len(s.stack) - 1; i >= split; i-- { + if err := s.p.writeEnd(Name{Local: s.stack[i]}); err != nil { + return err + } + } + s.stack = parents[:split] + return nil +} + +// push adds parent elements to the stack and writes open tags. +func (s *parentStack) push(parents []string) error { + for i := 0; i < len(parents); i++ { + if err := s.p.writeStart(&StartElement{Name: Name{Local: parents[i]}}); err != nil { + return err + } + } + s.stack = append(s.stack, parents...) + return nil +} + +// A MarshalXMLError is returned when Marshal encounters a type +// that cannot be converted into XML. +type UnsupportedTypeError struct { + Type reflect.Type +} + +func (e *UnsupportedTypeError) Error() string { + return "xml: unsupported type: " + e.Type.String() +} + +func isEmptyValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return false +} diff --git a/src/encoding/xml/marshal_test.go b/src/encoding/xml/marshal_test.go new file mode 100644 index 000000000..14f73a75d --- /dev/null +++ b/src/encoding/xml/marshal_test.go @@ -0,0 +1,1266 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "testing" + "time" +) + +type DriveType int + +const ( + HyperDrive DriveType = iota + ImprobabilityDrive +) + +type Passenger struct { + Name []string `xml:"name"` + Weight float32 `xml:"weight"` +} + +type Ship struct { + XMLName struct{} `xml:"spaceship"` + + Name string `xml:"name,attr"` + Pilot string `xml:"pilot,attr"` + Drive DriveType `xml:"drive"` + Age uint `xml:"age"` + Passenger []*Passenger `xml:"passenger"` + secret string +} + +type NamedType string + +type Port struct { + XMLName struct{} `xml:"port"` + Type string `xml:"type,attr,omitempty"` + Comment string `xml:",comment"` + Number string `xml:",chardata"` +} + +type Domain struct { + XMLName struct{} `xml:"domain"` + Country string `xml:",attr,omitempty"` + Name []byte `xml:",chardata"` + Comment []byte `xml:",comment"` +} + +type Book struct { + XMLName struct{} `xml:"book"` + Title string `xml:",chardata"` +} + +type Event struct { + XMLName struct{} `xml:"event"` + Year int `xml:",chardata"` +} + +type Movie struct { + XMLName struct{} `xml:"movie"` + Length uint `xml:",chardata"` +} + +type Pi struct { + XMLName struct{} `xml:"pi"` + Approximation float32 `xml:",chardata"` +} + +type Universe struct { + XMLName struct{} `xml:"universe"` + Visible float64 `xml:",chardata"` +} + +type Particle struct { + XMLName struct{} `xml:"particle"` + HasMass bool `xml:",chardata"` +} + +type Departure struct { + XMLName struct{} `xml:"departure"` + When time.Time `xml:",chardata"` +} + +type SecretAgent struct { + XMLName struct{} `xml:"agent"` + Handle string `xml:"handle,attr"` + Identity string + Obfuscate string `xml:",innerxml"` +} + +type NestedItems struct { + XMLName struct{} `xml:"result"` + Items []string `xml:">item"` + Item1 []string `xml:"Items>item1"` +} + +type NestedOrder struct { + XMLName struct{} `xml:"result"` + Field1 string `xml:"parent>c"` + Field2 string `xml:"parent>b"` + Field3 string `xml:"parent>a"` +} + +type MixedNested struct { + XMLName struct{} `xml:"result"` + A string `xml:"parent1>a"` + B string `xml:"b"` + C string `xml:"parent1>parent2>c"` + D string `xml:"parent1>d"` +} + +type NilTest struct { + A interface{} `xml:"parent1>parent2>a"` + B interface{} `xml:"parent1>b"` + C interface{} `xml:"parent1>parent2>c"` +} + +type Service struct { + XMLName struct{} `xml:"service"` + Domain *Domain `xml:"host>domain"` + Port *Port `xml:"host>port"` + Extra1 interface{} + Extra2 interface{} `xml:"host>extra2"` +} + +var nilStruct *Ship + +type EmbedA struct { + EmbedC + EmbedB EmbedB + FieldA string +} + +type EmbedB struct { + FieldB string + *EmbedC +} + +type EmbedC struct { + FieldA1 string `xml:"FieldA>A1"` + FieldA2 string `xml:"FieldA>A2"` + FieldB string + FieldC string +} + +type NameCasing struct { + XMLName struct{} `xml:"casing"` + Xy string + XY string + XyA string `xml:"Xy,attr"` + XYA string `xml:"XY,attr"` +} + +type NamePrecedence struct { + XMLName Name `xml:"Parent"` + FromTag XMLNameWithoutTag `xml:"InTag"` + FromNameVal XMLNameWithoutTag + FromNameTag XMLNameWithTag + InFieldName string +} + +type XMLNameWithTag struct { + XMLName Name `xml:"InXMLNameTag"` + Value string `xml:",chardata"` +} + +type XMLNameWithoutTag struct { + XMLName Name + Value string `xml:",chardata"` +} + +type NameInField struct { + Foo Name `xml:"ns foo"` +} + +type AttrTest struct { + Int int `xml:",attr"` + Named int `xml:"int,attr"` + Float float64 `xml:",attr"` + Uint8 uint8 `xml:",attr"` + Bool bool `xml:",attr"` + Str string `xml:",attr"` + Bytes []byte `xml:",attr"` +} + +type OmitAttrTest struct { + Int int `xml:",attr,omitempty"` + Named int `xml:"int,attr,omitempty"` + Float float64 `xml:",attr,omitempty"` + Uint8 uint8 `xml:",attr,omitempty"` + Bool bool `xml:",attr,omitempty"` + Str string `xml:",attr,omitempty"` + Bytes []byte `xml:",attr,omitempty"` +} + +type OmitFieldTest struct { + Int int `xml:",omitempty"` + Named int `xml:"int,omitempty"` + Float float64 `xml:",omitempty"` + Uint8 uint8 `xml:",omitempty"` + Bool bool `xml:",omitempty"` + Str string `xml:",omitempty"` + Bytes []byte `xml:",omitempty"` + Ptr *PresenceTest `xml:",omitempty"` +} + +type AnyTest struct { + XMLName struct{} `xml:"a"` + Nested string `xml:"nested>value"` + AnyField AnyHolder `xml:",any"` +} + +type AnyOmitTest struct { + XMLName struct{} `xml:"a"` + Nested string `xml:"nested>value"` + AnyField *AnyHolder `xml:",any,omitempty"` +} + +type AnySliceTest struct { + XMLName struct{} `xml:"a"` + Nested string `xml:"nested>value"` + AnyField []AnyHolder `xml:",any"` +} + +type AnyHolder struct { + XMLName Name + XML string `xml:",innerxml"` +} + +type RecurseA struct { + A string + B *RecurseB +} + +type RecurseB struct { + A *RecurseA + B string +} + +type PresenceTest struct { + Exists *struct{} +} + +type IgnoreTest struct { + PublicSecret string `xml:"-"` +} + +type MyBytes []byte + +type Data struct { + Bytes []byte + Attr []byte `xml:",attr"` + Custom MyBytes +} + +type Plain struct { + V interface{} +} + +type MyInt int + +type EmbedInt struct { + MyInt +} + +type Strings struct { + X []string `xml:"A>B,omitempty"` +} + +type PointerFieldsTest struct { + XMLName Name `xml:"dummy"` + Name *string `xml:"name,attr"` + Age *uint `xml:"age,attr"` + Empty *string `xml:"empty,attr"` + Contents *string `xml:",chardata"` +} + +type ChardataEmptyTest struct { + XMLName Name `xml:"test"` + Contents *string `xml:",chardata"` +} + +type MyMarshalerTest struct { +} + +var _ Marshaler = (*MyMarshalerTest)(nil) + +func (m *MyMarshalerTest) MarshalXML(e *Encoder, start StartElement) error { + e.EncodeToken(start) + e.EncodeToken(CharData([]byte("hello world"))) + e.EncodeToken(EndElement{start.Name}) + return nil +} + +type MyMarshalerAttrTest struct { +} + +var _ MarshalerAttr = (*MyMarshalerAttrTest)(nil) + +func (m *MyMarshalerAttrTest) MarshalXMLAttr(name Name) (Attr, error) { + return Attr{name, "hello world"}, nil +} + +type MarshalerStruct struct { + Foo MyMarshalerAttrTest `xml:",attr"` +} + +type InnerStruct struct { + XMLName Name `xml:"testns outer"` +} + +type OuterStruct struct { + InnerStruct + IntAttr int `xml:"int,attr"` +} + +type OuterNamedStruct struct { + InnerStruct + XMLName Name `xml:"outerns test"` + IntAttr int `xml:"int,attr"` +} + +type OuterNamedOrderedStruct struct { + XMLName Name `xml:"outerns test"` + InnerStruct + IntAttr int `xml:"int,attr"` +} + +type OuterOuterStruct struct { + OuterStruct +} + +func ifaceptr(x interface{}) interface{} { + return &x +} + +var ( + nameAttr = "Sarah" + ageAttr = uint(12) + contentsAttr = "lorem ipsum" +) + +// Unless explicitly stated as such (or *Plain), all of the +// tests below are two-way tests. When introducing new tests, +// please try to make them two-way as well to ensure that +// marshalling and unmarshalling are as symmetrical as feasible. +var marshalTests = []struct { + Value interface{} + ExpectXML string + MarshalOnly bool + UnmarshalOnly bool +}{ + // Test nil marshals to nothing + {Value: nil, ExpectXML: ``, MarshalOnly: true}, + {Value: nilStruct, ExpectXML: ``, MarshalOnly: true}, + + // Test value types + {Value: &Plain{true}, ExpectXML: `<Plain><V>true</V></Plain>`}, + {Value: &Plain{false}, ExpectXML: `<Plain><V>false</V></Plain>`}, + {Value: &Plain{int(42)}, ExpectXML: `<Plain><V>42</V></Plain>`}, + {Value: &Plain{int8(42)}, ExpectXML: `<Plain><V>42</V></Plain>`}, + {Value: &Plain{int16(42)}, ExpectXML: `<Plain><V>42</V></Plain>`}, + {Value: &Plain{int32(42)}, ExpectXML: `<Plain><V>42</V></Plain>`}, + {Value: &Plain{uint(42)}, ExpectXML: `<Plain><V>42</V></Plain>`}, + {Value: &Plain{uint8(42)}, ExpectXML: `<Plain><V>42</V></Plain>`}, + {Value: &Plain{uint16(42)}, ExpectXML: `<Plain><V>42</V></Plain>`}, + {Value: &Plain{uint32(42)}, ExpectXML: `<Plain><V>42</V></Plain>`}, + {Value: &Plain{float32(1.25)}, ExpectXML: `<Plain><V>1.25</V></Plain>`}, + {Value: &Plain{float64(1.25)}, ExpectXML: `<Plain><V>1.25</V></Plain>`}, + {Value: &Plain{uintptr(0xFFDD)}, ExpectXML: `<Plain><V>65501</V></Plain>`}, + {Value: &Plain{"gopher"}, ExpectXML: `<Plain><V>gopher</V></Plain>`}, + {Value: &Plain{[]byte("gopher")}, ExpectXML: `<Plain><V>gopher</V></Plain>`}, + {Value: &Plain{"</>"}, ExpectXML: `<Plain><V></></V></Plain>`}, + {Value: &Plain{[]byte("</>")}, ExpectXML: `<Plain><V></></V></Plain>`}, + {Value: &Plain{[3]byte{'<', '/', '>'}}, ExpectXML: `<Plain><V></></V></Plain>`}, + {Value: &Plain{NamedType("potato")}, ExpectXML: `<Plain><V>potato</V></Plain>`}, + {Value: &Plain{[]int{1, 2, 3}}, ExpectXML: `<Plain><V>1</V><V>2</V><V>3</V></Plain>`}, + {Value: &Plain{[3]int{1, 2, 3}}, ExpectXML: `<Plain><V>1</V><V>2</V><V>3</V></Plain>`}, + {Value: ifaceptr(true), MarshalOnly: true, ExpectXML: `<bool>true</bool>`}, + + // Test time. + { + Value: &Plain{time.Unix(1e9, 123456789).UTC()}, + ExpectXML: `<Plain><V>2001-09-09T01:46:40.123456789Z</V></Plain>`, + }, + + // A pointer to struct{} may be used to test for an element's presence. + { + Value: &PresenceTest{new(struct{})}, + ExpectXML: `<PresenceTest><Exists></Exists></PresenceTest>`, + }, + { + Value: &PresenceTest{}, + ExpectXML: `<PresenceTest></PresenceTest>`, + }, + + // A pointer to struct{} may be used to test for an element's presence. + { + Value: &PresenceTest{new(struct{})}, + ExpectXML: `<PresenceTest><Exists></Exists></PresenceTest>`, + }, + { + Value: &PresenceTest{}, + ExpectXML: `<PresenceTest></PresenceTest>`, + }, + + // A []byte field is only nil if the element was not found. + { + Value: &Data{}, + ExpectXML: `<Data></Data>`, + UnmarshalOnly: true, + }, + { + Value: &Data{Bytes: []byte{}, Custom: MyBytes{}, Attr: []byte{}}, + ExpectXML: `<Data Attr=""><Bytes></Bytes><Custom></Custom></Data>`, + UnmarshalOnly: true, + }, + + // Check that []byte works, including named []byte types. + { + Value: &Data{Bytes: []byte("ab"), Custom: MyBytes("cd"), Attr: []byte{'v'}}, + ExpectXML: `<Data Attr="v"><Bytes>ab</Bytes><Custom>cd</Custom></Data>`, + }, + + // Test innerxml + { + Value: &SecretAgent{ + Handle: "007", + Identity: "James Bond", + Obfuscate: "<redacted/>", + }, + ExpectXML: `<agent handle="007"><Identity>James Bond</Identity><redacted/></agent>`, + MarshalOnly: true, + }, + { + Value: &SecretAgent{ + Handle: "007", + Identity: "James Bond", + Obfuscate: "<Identity>James Bond</Identity><redacted/>", + }, + ExpectXML: `<agent handle="007"><Identity>James Bond</Identity><redacted/></agent>`, + UnmarshalOnly: true, + }, + + // Test structs + {Value: &Port{Type: "ssl", Number: "443"}, ExpectXML: `<port type="ssl">443</port>`}, + {Value: &Port{Number: "443"}, ExpectXML: `<port>443</port>`}, + {Value: &Port{Type: "<unix>"}, ExpectXML: `<port type="<unix>"></port>`}, + {Value: &Port{Number: "443", Comment: "https"}, ExpectXML: `<port><!--https-->443</port>`}, + {Value: &Port{Number: "443", Comment: "add space-"}, ExpectXML: `<port><!--add space- -->443</port>`, MarshalOnly: true}, + {Value: &Domain{Name: []byte("google.com&friends")}, ExpectXML: `<domain>google.com&friends</domain>`}, + {Value: &Domain{Name: []byte("google.com"), Comment: []byte(" &friends ")}, ExpectXML: `<domain>google.com<!-- &friends --></domain>`}, + {Value: &Book{Title: "Pride & Prejudice"}, ExpectXML: `<book>Pride & Prejudice</book>`}, + {Value: &Event{Year: -3114}, ExpectXML: `<event>-3114</event>`}, + {Value: &Movie{Length: 13440}, ExpectXML: `<movie>13440</movie>`}, + {Value: &Pi{Approximation: 3.14159265}, ExpectXML: `<pi>3.1415927</pi>`}, + {Value: &Universe{Visible: 9.3e13}, ExpectXML: `<universe>9.3e+13</universe>`}, + {Value: &Particle{HasMass: true}, ExpectXML: `<particle>true</particle>`}, + {Value: &Departure{When: ParseTime("2013-01-09T00:15:00-09:00")}, ExpectXML: `<departure>2013-01-09T00:15:00-09:00</departure>`}, + {Value: atomValue, ExpectXML: atomXml}, + { + Value: &Ship{ + Name: "Heart of Gold", + Pilot: "Computer", + Age: 1, + Drive: ImprobabilityDrive, + Passenger: []*Passenger{ + { + Name: []string{"Zaphod", "Beeblebrox"}, + Weight: 7.25, + }, + { + Name: []string{"Trisha", "McMillen"}, + Weight: 5.5, + }, + { + Name: []string{"Ford", "Prefect"}, + Weight: 7, + }, + { + Name: []string{"Arthur", "Dent"}, + Weight: 6.75, + }, + }, + }, + ExpectXML: `<spaceship name="Heart of Gold" pilot="Computer">` + + `<drive>` + strconv.Itoa(int(ImprobabilityDrive)) + `</drive>` + + `<age>1</age>` + + `<passenger>` + + `<name>Zaphod</name>` + + `<name>Beeblebrox</name>` + + `<weight>7.25</weight>` + + `</passenger>` + + `<passenger>` + + `<name>Trisha</name>` + + `<name>McMillen</name>` + + `<weight>5.5</weight>` + + `</passenger>` + + `<passenger>` + + `<name>Ford</name>` + + `<name>Prefect</name>` + + `<weight>7</weight>` + + `</passenger>` + + `<passenger>` + + `<name>Arthur</name>` + + `<name>Dent</name>` + + `<weight>6.75</weight>` + + `</passenger>` + + `</spaceship>`, + }, + + // Test a>b + { + Value: &NestedItems{Items: nil, Item1: nil}, + ExpectXML: `<result>` + + `<Items>` + + `</Items>` + + `</result>`, + }, + { + Value: &NestedItems{Items: []string{}, Item1: []string{}}, + ExpectXML: `<result>` + + `<Items>` + + `</Items>` + + `</result>`, + MarshalOnly: true, + }, + { + Value: &NestedItems{Items: nil, Item1: []string{"A"}}, + ExpectXML: `<result>` + + `<Items>` + + `<item1>A</item1>` + + `</Items>` + + `</result>`, + }, + { + Value: &NestedItems{Items: []string{"A", "B"}, Item1: nil}, + ExpectXML: `<result>` + + `<Items>` + + `<item>A</item>` + + `<item>B</item>` + + `</Items>` + + `</result>`, + }, + { + Value: &NestedItems{Items: []string{"A", "B"}, Item1: []string{"C"}}, + ExpectXML: `<result>` + + `<Items>` + + `<item>A</item>` + + `<item>B</item>` + + `<item1>C</item1>` + + `</Items>` + + `</result>`, + }, + { + Value: &NestedOrder{Field1: "C", Field2: "B", Field3: "A"}, + ExpectXML: `<result>` + + `<parent>` + + `<c>C</c>` + + `<b>B</b>` + + `<a>A</a>` + + `</parent>` + + `</result>`, + }, + { + Value: &NilTest{A: "A", B: nil, C: "C"}, + ExpectXML: `<NilTest>` + + `<parent1>` + + `<parent2><a>A</a></parent2>` + + `<parent2><c>C</c></parent2>` + + `</parent1>` + + `</NilTest>`, + MarshalOnly: true, // Uses interface{} + }, + { + Value: &MixedNested{A: "A", B: "B", C: "C", D: "D"}, + ExpectXML: `<result>` + + `<parent1><a>A</a></parent1>` + + `<b>B</b>` + + `<parent1>` + + `<parent2><c>C</c></parent2>` + + `<d>D</d>` + + `</parent1>` + + `</result>`, + }, + { + Value: &Service{Port: &Port{Number: "80"}}, + ExpectXML: `<service><host><port>80</port></host></service>`, + }, + { + Value: &Service{}, + ExpectXML: `<service></service>`, + }, + { + Value: &Service{Port: &Port{Number: "80"}, Extra1: "A", Extra2: "B"}, + ExpectXML: `<service>` + + `<host><port>80</port></host>` + + `<Extra1>A</Extra1>` + + `<host><extra2>B</extra2></host>` + + `</service>`, + MarshalOnly: true, + }, + { + Value: &Service{Port: &Port{Number: "80"}, Extra2: "example"}, + ExpectXML: `<service>` + + `<host><port>80</port></host>` + + `<host><extra2>example</extra2></host>` + + `</service>`, + MarshalOnly: true, + }, + + // Test struct embedding + { + Value: &EmbedA{ + EmbedC: EmbedC{ + FieldA1: "", // Shadowed by A.A + FieldA2: "", // Shadowed by A.A + FieldB: "A.C.B", + FieldC: "A.C.C", + }, + EmbedB: EmbedB{ + FieldB: "A.B.B", + EmbedC: &EmbedC{ + FieldA1: "A.B.C.A1", + FieldA2: "A.B.C.A2", + FieldB: "", // Shadowed by A.B.B + FieldC: "A.B.C.C", + }, + }, + FieldA: "A.A", + }, + ExpectXML: `<EmbedA>` + + `<FieldB>A.C.B</FieldB>` + + `<FieldC>A.C.C</FieldC>` + + `<EmbedB>` + + `<FieldB>A.B.B</FieldB>` + + `<FieldA>` + + `<A1>A.B.C.A1</A1>` + + `<A2>A.B.C.A2</A2>` + + `</FieldA>` + + `<FieldC>A.B.C.C</FieldC>` + + `</EmbedB>` + + `<FieldA>A.A</FieldA>` + + `</EmbedA>`, + }, + + // Test that name casing matters + { + Value: &NameCasing{Xy: "mixed", XY: "upper", XyA: "mixedA", XYA: "upperA"}, + ExpectXML: `<casing Xy="mixedA" XY="upperA"><Xy>mixed</Xy><XY>upper</XY></casing>`, + }, + + // Test the order in which the XML element name is chosen + { + Value: &NamePrecedence{ + FromTag: XMLNameWithoutTag{Value: "A"}, + FromNameVal: XMLNameWithoutTag{XMLName: Name{Local: "InXMLName"}, Value: "B"}, + FromNameTag: XMLNameWithTag{Value: "C"}, + InFieldName: "D", + }, + ExpectXML: `<Parent>` + + `<InTag>A</InTag>` + + `<InXMLName>B</InXMLName>` + + `<InXMLNameTag>C</InXMLNameTag>` + + `<InFieldName>D</InFieldName>` + + `</Parent>`, + MarshalOnly: true, + }, + { + Value: &NamePrecedence{ + XMLName: Name{Local: "Parent"}, + FromTag: XMLNameWithoutTag{XMLName: Name{Local: "InTag"}, Value: "A"}, + FromNameVal: XMLNameWithoutTag{XMLName: Name{Local: "FromNameVal"}, Value: "B"}, + FromNameTag: XMLNameWithTag{XMLName: Name{Local: "InXMLNameTag"}, Value: "C"}, + InFieldName: "D", + }, + ExpectXML: `<Parent>` + + `<InTag>A</InTag>` + + `<FromNameVal>B</FromNameVal>` + + `<InXMLNameTag>C</InXMLNameTag>` + + `<InFieldName>D</InFieldName>` + + `</Parent>`, + UnmarshalOnly: true, + }, + + // xml.Name works in a plain field as well. + { + Value: &NameInField{Name{Space: "ns", Local: "foo"}}, + ExpectXML: `<NameInField><foo xmlns="ns"></foo></NameInField>`, + }, + { + Value: &NameInField{Name{Space: "ns", Local: "foo"}}, + ExpectXML: `<NameInField><foo xmlns="ns"><ignore></ignore></foo></NameInField>`, + UnmarshalOnly: true, + }, + + // Marshaling zero xml.Name uses the tag or field name. + { + Value: &NameInField{}, + ExpectXML: `<NameInField><foo xmlns="ns"></foo></NameInField>`, + MarshalOnly: true, + }, + + // Test attributes + { + Value: &AttrTest{ + Int: 8, + Named: 9, + Float: 23.5, + Uint8: 255, + Bool: true, + Str: "str", + Bytes: []byte("byt"), + }, + ExpectXML: `<AttrTest Int="8" int="9" Float="23.5" Uint8="255"` + + ` Bool="true" Str="str" Bytes="byt"></AttrTest>`, + }, + { + Value: &AttrTest{Bytes: []byte{}}, + ExpectXML: `<AttrTest Int="0" int="0" Float="0" Uint8="0"` + + ` Bool="false" Str="" Bytes=""></AttrTest>`, + }, + { + Value: &OmitAttrTest{ + Int: 8, + Named: 9, + Float: 23.5, + Uint8: 255, + Bool: true, + Str: "str", + Bytes: []byte("byt"), + }, + ExpectXML: `<OmitAttrTest Int="8" int="9" Float="23.5" Uint8="255"` + + ` Bool="true" Str="str" Bytes="byt"></OmitAttrTest>`, + }, + { + Value: &OmitAttrTest{}, + ExpectXML: `<OmitAttrTest></OmitAttrTest>`, + }, + + // pointer fields + { + Value: &PointerFieldsTest{Name: &nameAttr, Age: &ageAttr, Contents: &contentsAttr}, + ExpectXML: `<dummy name="Sarah" age="12">lorem ipsum</dummy>`, + MarshalOnly: true, + }, + + // empty chardata pointer field + { + Value: &ChardataEmptyTest{}, + ExpectXML: `<test></test>`, + MarshalOnly: true, + }, + + // omitempty on fields + { + Value: &OmitFieldTest{ + Int: 8, + Named: 9, + Float: 23.5, + Uint8: 255, + Bool: true, + Str: "str", + Bytes: []byte("byt"), + Ptr: &PresenceTest{}, + }, + ExpectXML: `<OmitFieldTest>` + + `<Int>8</Int>` + + `<int>9</int>` + + `<Float>23.5</Float>` + + `<Uint8>255</Uint8>` + + `<Bool>true</Bool>` + + `<Str>str</Str>` + + `<Bytes>byt</Bytes>` + + `<Ptr></Ptr>` + + `</OmitFieldTest>`, + }, + { + Value: &OmitFieldTest{}, + ExpectXML: `<OmitFieldTest></OmitFieldTest>`, + }, + + // Test ",any" + { + ExpectXML: `<a><nested><value>known</value></nested><other><sub>unknown</sub></other></a>`, + Value: &AnyTest{ + Nested: "known", + AnyField: AnyHolder{ + XMLName: Name{Local: "other"}, + XML: "<sub>unknown</sub>", + }, + }, + }, + { + Value: &AnyTest{Nested: "known", + AnyField: AnyHolder{ + XML: "<unknown/>", + XMLName: Name{Local: "AnyField"}, + }, + }, + ExpectXML: `<a><nested><value>known</value></nested><AnyField><unknown/></AnyField></a>`, + }, + { + ExpectXML: `<a><nested><value>b</value></nested></a>`, + Value: &AnyOmitTest{ + Nested: "b", + }, + }, + { + ExpectXML: `<a><nested><value>b</value></nested><c><d>e</d></c><g xmlns="f"><h>i</h></g></a>`, + Value: &AnySliceTest{ + Nested: "b", + AnyField: []AnyHolder{ + { + XMLName: Name{Local: "c"}, + XML: "<d>e</d>", + }, + { + XMLName: Name{Space: "f", Local: "g"}, + XML: "<h>i</h>", + }, + }, + }, + }, + { + ExpectXML: `<a><nested><value>b</value></nested></a>`, + Value: &AnySliceTest{ + Nested: "b", + }, + }, + + // Test recursive types. + { + Value: &RecurseA{ + A: "a1", + B: &RecurseB{ + A: &RecurseA{"a2", nil}, + B: "b1", + }, + }, + ExpectXML: `<RecurseA><A>a1</A><B><A><A>a2</A></A><B>b1</B></B></RecurseA>`, + }, + + // Test ignoring fields via "-" tag + { + ExpectXML: `<IgnoreTest></IgnoreTest>`, + Value: &IgnoreTest{}, + }, + { + ExpectXML: `<IgnoreTest></IgnoreTest>`, + Value: &IgnoreTest{PublicSecret: "can't tell"}, + MarshalOnly: true, + }, + { + ExpectXML: `<IgnoreTest><PublicSecret>ignore me</PublicSecret></IgnoreTest>`, + Value: &IgnoreTest{}, + UnmarshalOnly: true, + }, + + // Test escaping. + { + ExpectXML: `<a><nested><value>dquote: "; squote: '; ampersand: &; less: <; greater: >;</value></nested><empty></empty></a>`, + Value: &AnyTest{ + Nested: `dquote: "; squote: '; ampersand: &; less: <; greater: >;`, + AnyField: AnyHolder{XMLName: Name{Local: "empty"}}, + }, + }, + { + ExpectXML: `<a><nested><value>newline: 
; cr: 
; tab: 	;</value></nested><AnyField></AnyField></a>`, + Value: &AnyTest{ + Nested: "newline: \n; cr: \r; tab: \t;", + AnyField: AnyHolder{XMLName: Name{Local: "AnyField"}}, + }, + }, + { + ExpectXML: "<a><nested><value>1\r2\r\n3\n\r4\n5</value></nested></a>", + Value: &AnyTest{ + Nested: "1\n2\n3\n\n4\n5", + }, + UnmarshalOnly: true, + }, + { + ExpectXML: `<EmbedInt><MyInt>42</MyInt></EmbedInt>`, + Value: &EmbedInt{ + MyInt: 42, + }, + }, + // Test omitempty with parent chain; see golang.org/issue/4168. + { + ExpectXML: `<Strings><A></A></Strings>`, + Value: &Strings{}, + }, + // Custom marshalers. + { + ExpectXML: `<MyMarshalerTest>hello world</MyMarshalerTest>`, + Value: &MyMarshalerTest{}, + }, + { + ExpectXML: `<MarshalerStruct Foo="hello world"></MarshalerStruct>`, + Value: &MarshalerStruct{}, + }, + { + ExpectXML: `<outer xmlns="testns" int="10"></outer>`, + Value: &OuterStruct{IntAttr: 10}, + }, + { + ExpectXML: `<test xmlns="outerns" int="10"></test>`, + Value: &OuterNamedStruct{XMLName: Name{Space: "outerns", Local: "test"}, IntAttr: 10}, + }, + { + ExpectXML: `<test xmlns="outerns" int="10"></test>`, + Value: &OuterNamedOrderedStruct{XMLName: Name{Space: "outerns", Local: "test"}, IntAttr: 10}, + }, + { + ExpectXML: `<outer xmlns="testns" int="10"></outer>`, + Value: &OuterOuterStruct{OuterStruct{IntAttr: 10}}, + }, +} + +func TestMarshal(t *testing.T) { + for idx, test := range marshalTests { + if test.UnmarshalOnly { + continue + } + data, err := Marshal(test.Value) + if err != nil { + t.Errorf("#%d: Error: %s", idx, err) + continue + } + if got, want := string(data), test.ExpectXML; got != want { + if strings.Contains(want, "\n") { + t.Errorf("#%d: marshal(%#v):\nHAVE:\n%s\nWANT:\n%s", idx, test.Value, got, want) + } else { + t.Errorf("#%d: marshal(%#v):\nhave %#q\nwant %#q", idx, test.Value, got, want) + } + } + } +} + +type AttrParent struct { + X string `xml:"X>Y,attr"` +} + +type BadAttr struct { + Name []string `xml:"name,attr"` +} + +var marshalErrorTests = []struct { + Value interface{} + Err string + Kind reflect.Kind +}{ + { + Value: make(chan bool), + Err: "xml: unsupported type: chan bool", + Kind: reflect.Chan, + }, + { + Value: map[string]string{ + "question": "What do you get when you multiply six by nine?", + "answer": "42", + }, + Err: "xml: unsupported type: map[string]string", + Kind: reflect.Map, + }, + { + Value: map[*Ship]bool{nil: false}, + Err: "xml: unsupported type: map[*xml.Ship]bool", + Kind: reflect.Map, + }, + { + Value: &Domain{Comment: []byte("f--bar")}, + Err: `xml: comments must not contain "--"`, + }, + // Reject parent chain with attr, never worked; see golang.org/issue/5033. + { + Value: &AttrParent{}, + Err: `xml: X>Y chain not valid with attr flag`, + }, + { + Value: BadAttr{[]string{"X", "Y"}}, + Err: `xml: unsupported type: []string`, + }, +} + +var marshalIndentTests = []struct { + Value interface{} + Prefix string + Indent string + ExpectXML string +}{ + { + Value: &SecretAgent{ + Handle: "007", + Identity: "James Bond", + Obfuscate: "<redacted/>", + }, + Prefix: "", + Indent: "\t", + ExpectXML: fmt.Sprintf("<agent handle=\"007\">\n\t<Identity>James Bond</Identity><redacted/>\n</agent>"), + }, +} + +func TestMarshalErrors(t *testing.T) { + for idx, test := range marshalErrorTests { + data, err := Marshal(test.Value) + if err == nil { + t.Errorf("#%d: marshal(%#v) = [success] %q, want error %v", idx, test.Value, data, test.Err) + continue + } + if err.Error() != test.Err { + t.Errorf("#%d: marshal(%#v) = [error] %v, want %v", idx, test.Value, err, test.Err) + } + if test.Kind != reflect.Invalid { + if kind := err.(*UnsupportedTypeError).Type.Kind(); kind != test.Kind { + t.Errorf("#%d: marshal(%#v) = [error kind] %s, want %s", idx, test.Value, kind, test.Kind) + } + } + } +} + +// Do invertibility testing on the various structures that we test +func TestUnmarshal(t *testing.T) { + for i, test := range marshalTests { + if test.MarshalOnly { + continue + } + if _, ok := test.Value.(*Plain); ok { + continue + } + + vt := reflect.TypeOf(test.Value) + dest := reflect.New(vt.Elem()).Interface() + err := Unmarshal([]byte(test.ExpectXML), dest) + + switch fix := dest.(type) { + case *Feed: + fix.Author.InnerXML = "" + for i := range fix.Entry { + fix.Entry[i].Author.InnerXML = "" + } + } + + if err != nil { + t.Errorf("#%d: unexpected error: %#v", i, err) + } else if got, want := dest, test.Value; !reflect.DeepEqual(got, want) { + t.Errorf("#%d: unmarshal(%q):\nhave %#v\nwant %#v", i, test.ExpectXML, got, want) + } + } +} + +func TestMarshalIndent(t *testing.T) { + for i, test := range marshalIndentTests { + data, err := MarshalIndent(test.Value, test.Prefix, test.Indent) + if err != nil { + t.Errorf("#%d: Error: %s", i, err) + continue + } + if got, want := string(data), test.ExpectXML; got != want { + t.Errorf("#%d: MarshalIndent:\nGot:%s\nWant:\n%s", i, got, want) + } + } +} + +type limitedBytesWriter struct { + w io.Writer + remain int // until writes fail +} + +func (lw *limitedBytesWriter) Write(p []byte) (n int, err error) { + if lw.remain <= 0 { + println("error") + return 0, errors.New("write limit hit") + } + if len(p) > lw.remain { + p = p[:lw.remain] + n, _ = lw.w.Write(p) + lw.remain = 0 + return n, errors.New("write limit hit") + } + n, err = lw.w.Write(p) + lw.remain -= n + return n, err +} + +func TestMarshalWriteErrors(t *testing.T) { + var buf bytes.Buffer + const writeCap = 1024 + w := &limitedBytesWriter{&buf, writeCap} + enc := NewEncoder(w) + var err error + var i int + const n = 4000 + for i = 1; i <= n; i++ { + err = enc.Encode(&Passenger{ + Name: []string{"Alice", "Bob"}, + Weight: 5, + }) + if err != nil { + break + } + } + if err == nil { + t.Error("expected an error") + } + if i == n { + t.Errorf("expected to fail before the end") + } + if buf.Len() != writeCap { + t.Errorf("buf.Len() = %d; want %d", buf.Len(), writeCap) + } +} + +func TestMarshalWriteIOErrors(t *testing.T) { + enc := NewEncoder(errWriter{}) + + expectErr := "unwritable" + err := enc.Encode(&Passenger{}) + if err == nil || err.Error() != expectErr { + t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr) + } +} + +func TestMarshalFlush(t *testing.T) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + if err := enc.EncodeToken(CharData("hello world")); err != nil { + t.Fatalf("enc.EncodeToken: %v", err) + } + if buf.Len() > 0 { + t.Fatalf("enc.EncodeToken caused actual write: %q", buf.Bytes()) + } + if err := enc.Flush(); err != nil { + t.Fatalf("enc.Flush: %v", err) + } + if buf.String() != "hello world" { + t.Fatalf("after enc.Flush, buf.String() = %q, want %q", buf.String(), "hello world") + } +} + +func BenchmarkMarshal(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(atomValue) + } +} + +func BenchmarkUnmarshal(b *testing.B) { + xml := []byte(atomXml) + for i := 0; i < b.N; i++ { + Unmarshal(xml, &Feed{}) + } +} + +// golang.org/issue/6556 +func TestStructPointerMarshal(t *testing.T) { + type A struct { + XMLName string `xml:"a"` + B []interface{} + } + type C struct { + XMLName Name + Value string `xml:"value"` + } + + a := new(A) + a.B = append(a.B, &C{ + XMLName: Name{Local: "c"}, + Value: "x", + }) + + b, err := Marshal(a) + if err != nil { + t.Fatal(err) + } + if x := string(b); x != "<a><c><value>x</value></c></a>" { + t.Fatal(x) + } + var v A + err = Unmarshal(b, &v) + if err != nil { + t.Fatal(err) + } +} + +var encodeTokenTests = []struct { + tok Token + want string + ok bool +}{ + {StartElement{Name{"space", "local"}, nil}, "<local xmlns=\"space\">", true}, + {StartElement{Name{"space", ""}, nil}, "", false}, + {EndElement{Name{"space", ""}}, "", false}, + {CharData("foo"), "foo", true}, + {Comment("foo"), "<!--foo-->", true}, + {Comment("foo-->"), "", false}, + {ProcInst{"Target", []byte("Instruction")}, "<?Target Instruction?>", true}, + {ProcInst{"", []byte("Instruction")}, "", false}, + {ProcInst{"Target", []byte("Instruction?>")}, "", false}, + {Directive("foo"), "<!foo>", true}, + {Directive("foo>"), "", false}, +} + +func TestEncodeToken(t *testing.T) { + for _, tt := range encodeTokenTests { + var buf bytes.Buffer + enc := NewEncoder(&buf) + err := enc.EncodeToken(tt.tok) + switch { + case !tt.ok && err == nil: + t.Errorf("enc.EncodeToken(%#v): expected error; got none", tt.tok) + case tt.ok && err != nil: + t.Fatalf("enc.EncodeToken: %v", err) + case !tt.ok && err != nil: + // expected error, got one + } + if err := enc.Flush(); err != nil { + t.Fatalf("enc.EncodeToken: %v", err) + } + if got := buf.String(); got != tt.want { + t.Errorf("enc.EncodeToken = %s; want: %s", got, tt.want) + } + } +} + +func TestProcInstEncodeToken(t *testing.T) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + + if err := enc.EncodeToken(ProcInst{"xml", []byte("Instruction")}); err != nil { + t.Fatalf("enc.EncodeToken: expected to be able to encode xml target ProcInst as first token, %s", err) + } + + if err := enc.EncodeToken(ProcInst{"Target", []byte("Instruction")}); err != nil { + t.Fatalf("enc.EncodeToken: expected to be able to add non-xml target ProcInst") + } + + if err := enc.EncodeToken(ProcInst{"xml", []byte("Instruction")}); err == nil { + t.Fatalf("enc.EncodeToken: expected to not be allowed to encode xml target ProcInst when not first token") + } +} + +func TestDecodeEncode(t *testing.T) { + var in, out bytes.Buffer + in.WriteString(`<?xml version="1.0" encoding="UTF-8"?> +<?Target Instruction?> +<root> +</root> +`) + dec := NewDecoder(&in) + enc := NewEncoder(&out) + for tok, err := dec.Token(); err == nil; tok, err = dec.Token() { + err = enc.EncodeToken(tok) + if err != nil { + t.Fatalf("enc.EncodeToken: Unable to encode token (%#v), %v", tok, err) + } + } +} diff --git a/src/encoding/xml/read.go b/src/encoding/xml/read.go new file mode 100644 index 000000000..75b9f2ba1 --- /dev/null +++ b/src/encoding/xml/read.go @@ -0,0 +1,692 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import ( + "bytes" + "encoding" + "errors" + "fmt" + "reflect" + "strconv" + "strings" +) + +// BUG(rsc): Mapping between XML elements and data structures is inherently flawed: +// an XML element is an order-dependent collection of anonymous +// values, while a data structure is an order-independent collection +// of named values. +// See package json for a textual representation more suitable +// to data structures. + +// Unmarshal parses the XML-encoded data and stores the result in +// the value pointed to by v, which must be an arbitrary struct, +// slice, or string. Well-formed data that does not fit into v is +// discarded. +// +// Because Unmarshal uses the reflect package, it can only assign +// to exported (upper case) fields. Unmarshal uses a case-sensitive +// comparison to match XML element names to tag values and struct +// field names. +// +// Unmarshal maps an XML element to a struct using the following rules. +// In the rules, the tag of a field refers to the value associated with the +// key 'xml' in the struct field's tag (see the example above). +// +// * If the struct has a field of type []byte or string with tag +// ",innerxml", Unmarshal accumulates the raw XML nested inside the +// element in that field. The rest of the rules still apply. +// +// * If the struct has a field named XMLName of type xml.Name, +// Unmarshal records the element name in that field. +// +// * If the XMLName field has an associated tag of the form +// "name" or "namespace-URL name", the XML element must have +// the given name (and, optionally, name space) or else Unmarshal +// returns an error. +// +// * If the XML element has an attribute whose name matches a +// struct field name with an associated tag containing ",attr" or +// the explicit name in a struct field tag of the form "name,attr", +// Unmarshal records the attribute value in that field. +// +// * If the XML element contains character data, that data is +// accumulated in the first struct field that has tag ",chardata". +// The struct field may have type []byte or string. +// If there is no such field, the character data is discarded. +// +// * If the XML element contains comments, they are accumulated in +// the first struct field that has tag ",comment". The struct +// field may have type []byte or string. If there is no such +// field, the comments are discarded. +// +// * If the XML element contains a sub-element whose name matches +// the prefix of a tag formatted as "a" or "a>b>c", unmarshal +// will descend into the XML structure looking for elements with the +// given names, and will map the innermost elements to that struct +// field. A tag starting with ">" is equivalent to one starting +// with the field name followed by ">". +// +// * If the XML element contains a sub-element whose name matches +// a struct field's XMLName tag and the struct field has no +// explicit name tag as per the previous rule, unmarshal maps +// the sub-element to that struct field. +// +// * If the XML element contains a sub-element whose name matches a +// field without any mode flags (",attr", ",chardata", etc), Unmarshal +// maps the sub-element to that struct field. +// +// * If the XML element contains a sub-element that hasn't matched any +// of the above rules and the struct has a field with tag ",any", +// unmarshal maps the sub-element to that struct field. +// +// * An anonymous struct field is handled as if the fields of its +// value were part of the outer struct. +// +// * A struct field with tag "-" is never unmarshalled into. +// +// Unmarshal maps an XML element to a string or []byte by saving the +// concatenation of that element's character data in the string or +// []byte. The saved []byte is never nil. +// +// Unmarshal maps an attribute value to a string or []byte by saving +// the value in the string or slice. +// +// Unmarshal maps an XML element to a slice by extending the length of +// the slice and mapping the element to the newly created value. +// +// Unmarshal maps an XML element or attribute value to a bool by +// setting it to the boolean value represented by the string. +// +// Unmarshal maps an XML element or attribute value to an integer or +// floating-point field by setting the field to the result of +// interpreting the string value in decimal. There is no check for +// overflow. +// +// Unmarshal maps an XML element to an xml.Name by recording the +// element name. +// +// Unmarshal maps an XML element to a pointer by setting the pointer +// to a freshly allocated value and then mapping the element to that value. +// +func Unmarshal(data []byte, v interface{}) error { + return NewDecoder(bytes.NewReader(data)).Decode(v) +} + +// Decode works like xml.Unmarshal, except it reads the decoder +// stream to find the start element. +func (d *Decoder) Decode(v interface{}) error { + return d.DecodeElement(v, nil) +} + +// DecodeElement works like xml.Unmarshal except that it takes +// a pointer to the start XML element to decode into v. +// It is useful when a client reads some raw XML tokens itself +// but also wants to defer to Unmarshal for some elements. +func (d *Decoder) DecodeElement(v interface{}, start *StartElement) error { + val := reflect.ValueOf(v) + if val.Kind() != reflect.Ptr { + return errors.New("non-pointer passed to Unmarshal") + } + return d.unmarshal(val.Elem(), start) +} + +// An UnmarshalError represents an error in the unmarshalling process. +type UnmarshalError string + +func (e UnmarshalError) Error() string { return string(e) } + +// Unmarshaler is the interface implemented by objects that can unmarshal +// an XML element description of themselves. +// +// UnmarshalXML decodes a single XML element +// beginning with the given start element. +// If it returns an error, the outer call to Unmarshal stops and +// returns that error. +// UnmarshalXML must consume exactly one XML element. +// One common implementation strategy is to unmarshal into +// a separate value with a layout matching the expected XML +// using d.DecodeElement, and then to copy the data from +// that value into the receiver. +// Another common strategy is to use d.Token to process the +// XML object one token at a time. +// UnmarshalXML may not use d.RawToken. +type Unmarshaler interface { + UnmarshalXML(d *Decoder, start StartElement) error +} + +// UnmarshalerAttr is the interface implemented by objects that can unmarshal +// an XML attribute description of themselves. +// +// UnmarshalXMLAttr decodes a single XML attribute. +// If it returns an error, the outer call to Unmarshal stops and +// returns that error. +// UnmarshalXMLAttr is used only for struct fields with the +// "attr" option in the field tag. +type UnmarshalerAttr interface { + UnmarshalXMLAttr(attr Attr) error +} + +// receiverType returns the receiver type to use in an expression like "%s.MethodName". +func receiverType(val interface{}) string { + t := reflect.TypeOf(val) + if t.Name() != "" { + return t.String() + } + return "(" + t.String() + ")" +} + +// unmarshalInterface unmarshals a single XML element into val. +// start is the opening tag of the element. +func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error { + // Record that decoder must stop at end tag corresponding to start. + p.pushEOF() + + p.unmarshalDepth++ + err := val.UnmarshalXML(p, *start) + p.unmarshalDepth-- + if err != nil { + p.popEOF() + return err + } + + if !p.popEOF() { + return fmt.Errorf("xml: %s.UnmarshalXML did not consume entire <%s> element", receiverType(val), start.Name.Local) + } + + return nil +} + +// unmarshalTextInterface unmarshals a single XML element into val. +// The chardata contained in the element (but not its children) +// is passed to the text unmarshaler. +func (p *Decoder) unmarshalTextInterface(val encoding.TextUnmarshaler, start *StartElement) error { + var buf []byte + depth := 1 + for depth > 0 { + t, err := p.Token() + if err != nil { + return err + } + switch t := t.(type) { + case CharData: + if depth == 1 { + buf = append(buf, t...) + } + case StartElement: + depth++ + case EndElement: + depth-- + } + } + return val.UnmarshalText(buf) +} + +// unmarshalAttr unmarshals a single XML attribute into val. +func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { + if val.Kind() == reflect.Ptr { + if val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) + } + val = val.Elem() + } + + if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) { + // This is an unmarshaler with a non-pointer receiver, + // so it's likely to be incorrect, but we do what we're told. + return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) { + return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr) + } + } + + // Not an UnmarshalerAttr; try encoding.TextUnmarshaler. + if val.CanInterface() && val.Type().Implements(textUnmarshalerType) { + // This is an unmarshaler with a non-pointer receiver, + // so it's likely to be incorrect, but we do what we're told. + return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value)) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { + return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value)) + } + } + + copyValue(val, []byte(attr.Value)) + return nil +} + +var ( + unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() + unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem() + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + +// Unmarshal a single XML element into val. +func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { + // Find start element if we need it. + if start == nil { + for { + tok, err := p.Token() + if err != nil { + return err + } + if t, ok := tok.(StartElement); ok { + start = &t + break + } + } + } + + // Load value from interface, but only if the result will be + // usefully addressable. + if val.Kind() == reflect.Interface && !val.IsNil() { + e := val.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() { + val = e + } + } + + if val.Kind() == reflect.Ptr { + if val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) + } + val = val.Elem() + } + + if val.CanInterface() && val.Type().Implements(unmarshalerType) { + // This is an unmarshaler with a non-pointer receiver, + // so it's likely to be incorrect, but we do what we're told. + return p.unmarshalInterface(val.Interface().(Unmarshaler), start) + } + + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(unmarshalerType) { + return p.unmarshalInterface(pv.Interface().(Unmarshaler), start) + } + } + + if val.CanInterface() && val.Type().Implements(textUnmarshalerType) { + return p.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler), start) + } + + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { + return p.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler), start) + } + } + + var ( + data []byte + saveData reflect.Value + comment []byte + saveComment reflect.Value + saveXML reflect.Value + saveXMLIndex int + saveXMLData []byte + saveAny reflect.Value + sv reflect.Value + tinfo *typeInfo + err error + ) + + switch v := val; v.Kind() { + default: + return errors.New("unknown type " + v.Type().String()) + + case reflect.Interface: + // TODO: For now, simply ignore the field. In the near + // future we may choose to unmarshal the start + // element on it, if not nil. + return p.Skip() + + case reflect.Slice: + typ := v.Type() + if typ.Elem().Kind() == reflect.Uint8 { + // []byte + saveData = v + break + } + + // Slice of element values. + // Grow slice. + n := v.Len() + if n >= v.Cap() { + ncap := 2 * n + if ncap < 4 { + ncap = 4 + } + new := reflect.MakeSlice(typ, n, ncap) + reflect.Copy(new, v) + v.Set(new) + } + v.SetLen(n + 1) + + // Recur to read element into slice. + if err := p.unmarshal(v.Index(n), start); err != nil { + v.SetLen(n) + return err + } + return nil + + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String: + saveData = v + + case reflect.Struct: + typ := v.Type() + if typ == nameType { + v.Set(reflect.ValueOf(start.Name)) + break + } + + sv = v + tinfo, err = getTypeInfo(typ) + if err != nil { + return err + } + + // Validate and assign element name. + if tinfo.xmlname != nil { + finfo := tinfo.xmlname + if finfo.name != "" && finfo.name != start.Name.Local { + return UnmarshalError("expected element type <" + finfo.name + "> but have <" + start.Name.Local + ">") + } + if finfo.xmlns != "" && finfo.xmlns != start.Name.Space { + e := "expected element <" + finfo.name + "> in name space " + finfo.xmlns + " but have " + if start.Name.Space == "" { + e += "no name space" + } else { + e += start.Name.Space + } + return UnmarshalError(e) + } + fv := finfo.value(sv) + if _, ok := fv.Interface().(Name); ok { + fv.Set(reflect.ValueOf(start.Name)) + } + } + + // Assign attributes. + // Also, determine whether we need to save character data or comments. + for i := range tinfo.fields { + finfo := &tinfo.fields[i] + switch finfo.flags & fMode { + case fAttr: + strv := finfo.value(sv) + // Look for attribute. + for _, a := range start.Attr { + if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) { + if err := p.unmarshalAttr(strv, a); err != nil { + return err + } + break + } + } + + case fCharData: + if !saveData.IsValid() { + saveData = finfo.value(sv) + } + + case fComment: + if !saveComment.IsValid() { + saveComment = finfo.value(sv) + } + + case fAny, fAny | fElement: + if !saveAny.IsValid() { + saveAny = finfo.value(sv) + } + + case fInnerXml: + if !saveXML.IsValid() { + saveXML = finfo.value(sv) + if p.saved == nil { + saveXMLIndex = 0 + p.saved = new(bytes.Buffer) + } else { + saveXMLIndex = p.savedOffset() + } + } + } + } + } + + // Find end element. + // Process sub-elements along the way. +Loop: + for { + var savedOffset int + if saveXML.IsValid() { + savedOffset = p.savedOffset() + } + tok, err := p.Token() + if err != nil { + return err + } + switch t := tok.(type) { + case StartElement: + consumed := false + if sv.IsValid() { + consumed, err = p.unmarshalPath(tinfo, sv, nil, &t) + if err != nil { + return err + } + if !consumed && saveAny.IsValid() { + consumed = true + if err := p.unmarshal(saveAny, &t); err != nil { + return err + } + } + } + if !consumed { + if err := p.Skip(); err != nil { + return err + } + } + + case EndElement: + if saveXML.IsValid() { + saveXMLData = p.saved.Bytes()[saveXMLIndex:savedOffset] + if saveXMLIndex == 0 { + p.saved = nil + } + } + break Loop + + case CharData: + if saveData.IsValid() { + data = append(data, t...) + } + + case Comment: + if saveComment.IsValid() { + comment = append(comment, t...) + } + } + } + + if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) { + if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil { + return err + } + saveData = reflect.Value{} + } + + if saveData.IsValid() && saveData.CanAddr() { + pv := saveData.Addr() + if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { + if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil { + return err + } + saveData = reflect.Value{} + } + } + + if err := copyValue(saveData, data); err != nil { + return err + } + + switch t := saveComment; t.Kind() { + case reflect.String: + t.SetString(string(comment)) + case reflect.Slice: + t.Set(reflect.ValueOf(comment)) + } + + switch t := saveXML; t.Kind() { + case reflect.String: + t.SetString(string(saveXMLData)) + case reflect.Slice: + t.Set(reflect.ValueOf(saveXMLData)) + } + + return nil +} + +func copyValue(dst reflect.Value, src []byte) (err error) { + dst0 := dst + + if dst.Kind() == reflect.Ptr { + if dst.IsNil() { + dst.Set(reflect.New(dst.Type().Elem())) + } + dst = dst.Elem() + } + + // Save accumulated data. + switch dst.Kind() { + case reflect.Invalid: + // Probably a comment. + default: + return errors.New("cannot unmarshal into " + dst0.Type().String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + itmp, err := strconv.ParseInt(string(src), 10, dst.Type().Bits()) + if err != nil { + return err + } + dst.SetInt(itmp) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + utmp, err := strconv.ParseUint(string(src), 10, dst.Type().Bits()) + if err != nil { + return err + } + dst.SetUint(utmp) + case reflect.Float32, reflect.Float64: + ftmp, err := strconv.ParseFloat(string(src), dst.Type().Bits()) + if err != nil { + return err + } + dst.SetFloat(ftmp) + case reflect.Bool: + value, err := strconv.ParseBool(strings.TrimSpace(string(src))) + if err != nil { + return err + } + dst.SetBool(value) + case reflect.String: + dst.SetString(string(src)) + case reflect.Slice: + if len(src) == 0 { + // non-nil to flag presence + src = []byte{} + } + dst.SetBytes(src) + } + return nil +} + +// unmarshalPath walks down an XML structure looking for wanted +// paths, and calls unmarshal on them. +// The consumed result tells whether XML elements have been consumed +// from the Decoder until start's matching end element, or if it's +// still untouched because start is uninteresting for sv's fields. +func (p *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement) (consumed bool, err error) { + recurse := false +Loop: + for i := range tinfo.fields { + finfo := &tinfo.fields[i] + if finfo.flags&fElement == 0 || len(finfo.parents) < len(parents) || finfo.xmlns != "" && finfo.xmlns != start.Name.Space { + continue + } + for j := range parents { + if parents[j] != finfo.parents[j] { + continue Loop + } + } + if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local { + // It's a perfect match, unmarshal the field. + return true, p.unmarshal(finfo.value(sv), start) + } + if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local { + // It's a prefix for the field. Break and recurse + // since it's not ok for one field path to be itself + // the prefix for another field path. + recurse = true + + // We can reuse the same slice as long as we + // don't try to append to it. + parents = finfo.parents[:len(parents)+1] + break + } + } + if !recurse { + // We have no business with this element. + return false, nil + } + // The element is not a perfect match for any field, but one + // or more fields have the path to this element as a parent + // prefix. Recurse and attempt to match these. + for { + var tok Token + tok, err = p.Token() + if err != nil { + return true, err + } + switch t := tok.(type) { + case StartElement: + consumed2, err := p.unmarshalPath(tinfo, sv, parents, &t) + if err != nil { + return true, err + } + if !consumed2 { + if err := p.Skip(); err != nil { + return true, err + } + } + case EndElement: + return true, nil + } + } +} + +// Skip reads tokens until it has consumed the end element +// matching the most recent start element already consumed. +// It recurs if it encounters a start element, so it can be used to +// skip nested structures. +// It returns nil if it finds an end element matching the start +// element; otherwise it returns an error describing the problem. +func (d *Decoder) Skip() error { + for { + tok, err := d.Token() + if err != nil { + return err + } + switch tok.(type) { + case StartElement: + if err := d.Skip(); err != nil { + return err + } + case EndElement: + return nil + } + } +} diff --git a/src/encoding/xml/read_test.go b/src/encoding/xml/read_test.go new file mode 100644 index 000000000..01f55d0dd --- /dev/null +++ b/src/encoding/xml/read_test.go @@ -0,0 +1,714 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import ( + "io" + "reflect" + "strings" + "testing" + "time" +) + +// Stripped down Atom feed data structures. + +func TestUnmarshalFeed(t *testing.T) { + var f Feed + if err := Unmarshal([]byte(atomFeedString), &f); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if !reflect.DeepEqual(f, atomFeed) { + t.Fatalf("have %#v\nwant %#v", f, atomFeed) + } +} + +// hget http://codereview.appspot.com/rss/mine/rsc +const atomFeedString = ` +<?xml version="1.0" encoding="utf-8"?> +<feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us" updated="2009-10-04T01:35:58+00:00"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><link href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></link><id>http://codereview.appspot.com/</id><author><name>rietveld<></name></author><entry><title>rietveld: an attempt at pubsubhubbub +</title><link href="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html"> + An attempt at adding pubsubhubbub support to Rietveld. +http://code.google.com/p/pubsubhubbub +http://code.google.com/p/rietveld/issues/detail?id=155 + +The server side of the protocol is trivial: + 1. add a &lt;link rel=&quot;hub&quot; href=&quot;hub-server&quot;&gt; tag to all + feeds that will be pubsubhubbubbed. + 2. every time one of those feeds changes, tell the hub + with a simple POST request. + +I have tested this by adding debug prints to a local hub +server and checking that the server got the right publish +requests. + +I can&#39;t quite get the server to work, but I think the bug +is not in my code. I think that the server expects to be +able to grab the feed and see the feed&#39;s actual URL in +the link rel=&quot;self&quot;, but the default value for that drops +the :port from the URL, and I cannot for the life of me +figure out how to get the Atom generator deep inside +django not to do that, or even where it is doing that, +or even what code is running to generate the Atom feed. +(I thought I knew but I added some assert False statements +and it kept running!) + +Ignoring that particular problem, I would appreciate +feedback on the right way to get the two values at +the top of feeds.py marked NOTE(rsc). + + +</summary></entry><entry><title>rietveld: correct tab handling +</title><link href="http://codereview.appspot.com/124106" rel="alternate"></link><updated>2009-10-03T23:02:17+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:0a2a4f19bb815101f0ba2904aed7c35a</id><summary type="html"> + This fixes the buggy tab rendering that can be seen at +http://codereview.appspot.com/116075/diff/1/2 + +The fundamental problem was that the tab code was +not being told what column the text began in, so it +didn&#39;t know where to put the tab stops. Another problem +was that some of the code assumed that string byte +offsets were the same as column offsets, which is only +true if there are no tabs. + +In the process of fixing this, I cleaned up the arguments +to Fold and ExpandTabs and renamed them Break and +_ExpandTabs so that I could be sure that I found all the +call sites. I also wanted to verify that ExpandTabs was +not being used from outside intra_region_diff.py. + + +</summary></entry></feed> ` + +type Feed struct { + XMLName Name `xml:"http://www.w3.org/2005/Atom feed"` + Title string `xml:"title"` + Id string `xml:"id"` + Link []Link `xml:"link"` + Updated time.Time `xml:"updated,attr"` + Author Person `xml:"author"` + Entry []Entry `xml:"entry"` +} + +type Entry struct { + Title string `xml:"title"` + Id string `xml:"id"` + Link []Link `xml:"link"` + Updated time.Time `xml:"updated"` + Author Person `xml:"author"` + Summary Text `xml:"summary"` +} + +type Link struct { + Rel string `xml:"rel,attr,omitempty"` + Href string `xml:"href,attr"` +} + +type Person struct { + Name string `xml:"name"` + URI string `xml:"uri"` + Email string `xml:"email"` + InnerXML string `xml:",innerxml"` +} + +type Text struct { + Type string `xml:"type,attr,omitempty"` + Body string `xml:",chardata"` +} + +var atomFeed = Feed{ + XMLName: Name{"http://www.w3.org/2005/Atom", "feed"}, + Title: "Code Review - My issues", + Link: []Link{ + {Rel: "alternate", Href: "http://codereview.appspot.com/"}, + {Rel: "self", Href: "http://codereview.appspot.com/rss/mine/rsc"}, + }, + Id: "http://codereview.appspot.com/", + Updated: ParseTime("2009-10-04T01:35:58+00:00"), + Author: Person{ + Name: "rietveld<>", + InnerXML: "<name>rietveld<></name>", + }, + Entry: []Entry{ + { + Title: "rietveld: an attempt at pubsubhubbub\n", + Link: []Link{ + {Rel: "alternate", Href: "http://codereview.appspot.com/126085"}, + }, + Updated: ParseTime("2009-10-04T01:35:58+00:00"), + Author: Person{ + Name: "email-address-removed", + InnerXML: "<name>email-address-removed</name>", + }, + Id: "urn:md5:134d9179c41f806be79b3a5f7877d19a", + Summary: Text{ + Type: "html", + Body: ` + An attempt at adding pubsubhubbub support to Rietveld. +http://code.google.com/p/pubsubhubbub +http://code.google.com/p/rietveld/issues/detail?id=155 + +The server side of the protocol is trivial: + 1. add a <link rel="hub" href="hub-server"> tag to all + feeds that will be pubsubhubbubbed. + 2. every time one of those feeds changes, tell the hub + with a simple POST request. + +I have tested this by adding debug prints to a local hub +server and checking that the server got the right publish +requests. + +I can't quite get the server to work, but I think the bug +is not in my code. I think that the server expects to be +able to grab the feed and see the feed's actual URL in +the link rel="self", but the default value for that drops +the :port from the URL, and I cannot for the life of me +figure out how to get the Atom generator deep inside +django not to do that, or even where it is doing that, +or even what code is running to generate the Atom feed. +(I thought I knew but I added some assert False statements +and it kept running!) + +Ignoring that particular problem, I would appreciate +feedback on the right way to get the two values at +the top of feeds.py marked NOTE(rsc). + + +`, + }, + }, + { + Title: "rietveld: correct tab handling\n", + Link: []Link{ + {Rel: "alternate", Href: "http://codereview.appspot.com/124106"}, + }, + Updated: ParseTime("2009-10-03T23:02:17+00:00"), + Author: Person{ + Name: "email-address-removed", + InnerXML: "<name>email-address-removed</name>", + }, + Id: "urn:md5:0a2a4f19bb815101f0ba2904aed7c35a", + Summary: Text{ + Type: "html", + Body: ` + This fixes the buggy tab rendering that can be seen at +http://codereview.appspot.com/116075/diff/1/2 + +The fundamental problem was that the tab code was +not being told what column the text began in, so it +didn't know where to put the tab stops. Another problem +was that some of the code assumed that string byte +offsets were the same as column offsets, which is only +true if there are no tabs. + +In the process of fixing this, I cleaned up the arguments +to Fold and ExpandTabs and renamed them Break and +_ExpandTabs so that I could be sure that I found all the +call sites. I also wanted to verify that ExpandTabs was +not being used from outside intra_region_diff.py. + + +`, + }, + }, + }, +} + +const pathTestString = ` +<Result> + <Before>1</Before> + <Items> + <Item1> + <Value>A</Value> + </Item1> + <Item2> + <Value>B</Value> + </Item2> + <Item1> + <Value>C</Value> + <Value>D</Value> + </Item1> + <_> + <Value>E</Value> + </_> + </Items> + <After>2</After> +</Result> +` + +type PathTestItem struct { + Value string +} + +type PathTestA struct { + Items []PathTestItem `xml:">Item1"` + Before, After string +} + +type PathTestB struct { + Other []PathTestItem `xml:"Items>Item1"` + Before, After string +} + +type PathTestC struct { + Values1 []string `xml:"Items>Item1>Value"` + Values2 []string `xml:"Items>Item2>Value"` + Before, After string +} + +type PathTestSet struct { + Item1 []PathTestItem +} + +type PathTestD struct { + Other PathTestSet `xml:"Items"` + Before, After string +} + +type PathTestE struct { + Underline string `xml:"Items>_>Value"` + Before, After string +} + +var pathTests = []interface{}{ + &PathTestA{Items: []PathTestItem{{"A"}, {"D"}}, Before: "1", After: "2"}, + &PathTestB{Other: []PathTestItem{{"A"}, {"D"}}, Before: "1", After: "2"}, + &PathTestC{Values1: []string{"A", "C", "D"}, Values2: []string{"B"}, Before: "1", After: "2"}, + &PathTestD{Other: PathTestSet{Item1: []PathTestItem{{"A"}, {"D"}}}, Before: "1", After: "2"}, + &PathTestE{Underline: "E", Before: "1", After: "2"}, +} + +func TestUnmarshalPaths(t *testing.T) { + for _, pt := range pathTests { + v := reflect.New(reflect.TypeOf(pt).Elem()).Interface() + if err := Unmarshal([]byte(pathTestString), v); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if !reflect.DeepEqual(v, pt) { + t.Fatalf("have %#v\nwant %#v", v, pt) + } + } +} + +type BadPathTestA struct { + First string `xml:"items>item1"` + Other string `xml:"items>item2"` + Second string `xml:"items"` +} + +type BadPathTestB struct { + Other string `xml:"items>item2>value"` + First string `xml:"items>item1"` + Second string `xml:"items>item1>value"` +} + +type BadPathTestC struct { + First string + Second string `xml:"First"` +} + +type BadPathTestD struct { + BadPathEmbeddedA + BadPathEmbeddedB +} + +type BadPathEmbeddedA struct { + First string +} + +type BadPathEmbeddedB struct { + Second string `xml:"First"` +} + +var badPathTests = []struct { + v, e interface{} +}{ + {&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items"}}, + {&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}}, + {&BadPathTestC{}, &TagPathError{reflect.TypeOf(BadPathTestC{}), "First", "", "Second", "First"}}, + {&BadPathTestD{}, &TagPathError{reflect.TypeOf(BadPathTestD{}), "First", "", "Second", "First"}}, +} + +func TestUnmarshalBadPaths(t *testing.T) { + for _, tt := range badPathTests { + err := Unmarshal([]byte(pathTestString), tt.v) + if !reflect.DeepEqual(err, tt.e) { + t.Fatalf("Unmarshal with %#v didn't fail properly:\nhave %#v,\nwant %#v", tt.v, err, tt.e) + } + } +} + +const OK = "OK" +const withoutNameTypeData = ` +<?xml version="1.0" charset="utf-8"?> +<Test3 Attr="OK" />` + +type TestThree struct { + XMLName Name `xml:"Test3"` + Attr string `xml:",attr"` +} + +func TestUnmarshalWithoutNameType(t *testing.T) { + var x TestThree + if err := Unmarshal([]byte(withoutNameTypeData), &x); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if x.Attr != OK { + t.Fatalf("have %v\nwant %v", x.Attr, OK) + } +} + +func TestUnmarshalAttr(t *testing.T) { + type ParamVal struct { + Int int `xml:"int,attr"` + } + + type ParamPtr struct { + Int *int `xml:"int,attr"` + } + + type ParamStringPtr struct { + Int *string `xml:"int,attr"` + } + + x := []byte(`<Param int="1" />`) + + p1 := &ParamPtr{} + if err := Unmarshal(x, p1); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if p1.Int == nil { + t.Fatalf("Unmarshal failed in to *int field") + } else if *p1.Int != 1 { + t.Fatalf("Unmarshal with %s failed:\nhave %#v,\n want %#v", x, p1.Int, 1) + } + + p2 := &ParamVal{} + if err := Unmarshal(x, p2); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if p2.Int != 1 { + t.Fatalf("Unmarshal with %s failed:\nhave %#v,\n want %#v", x, p2.Int, 1) + } + + p3 := &ParamStringPtr{} + if err := Unmarshal(x, p3); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if p3.Int == nil { + t.Fatalf("Unmarshal failed in to *string field") + } else if *p3.Int != "1" { + t.Fatalf("Unmarshal with %s failed:\nhave %#v,\n want %#v", x, p3.Int, 1) + } +} + +type Tables struct { + HTable string `xml:"http://www.w3.org/TR/html4/ table"` + FTable string `xml:"http://www.w3schools.com/furniture table"` +} + +var tables = []struct { + xml string + tab Tables + ns string +}{ + { + xml: `<Tables>` + + `<table xmlns="http://www.w3.org/TR/html4/">hello</table>` + + `<table xmlns="http://www.w3schools.com/furniture">world</table>` + + `</Tables>`, + tab: Tables{"hello", "world"}, + }, + { + xml: `<Tables>` + + `<table xmlns="http://www.w3schools.com/furniture">world</table>` + + `<table xmlns="http://www.w3.org/TR/html4/">hello</table>` + + `</Tables>`, + tab: Tables{"hello", "world"}, + }, + { + xml: `<Tables xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/">` + + `<f:table>world</f:table>` + + `<h:table>hello</h:table>` + + `</Tables>`, + tab: Tables{"hello", "world"}, + }, + { + xml: `<Tables>` + + `<table>bogus</table>` + + `</Tables>`, + tab: Tables{}, + }, + { + xml: `<Tables>` + + `<table>only</table>` + + `</Tables>`, + tab: Tables{HTable: "only"}, + ns: "http://www.w3.org/TR/html4/", + }, + { + xml: `<Tables>` + + `<table>only</table>` + + `</Tables>`, + tab: Tables{FTable: "only"}, + ns: "http://www.w3schools.com/furniture", + }, + { + xml: `<Tables>` + + `<table>only</table>` + + `</Tables>`, + tab: Tables{}, + ns: "something else entirely", + }, +} + +func TestUnmarshalNS(t *testing.T) { + for i, tt := range tables { + var dst Tables + var err error + if tt.ns != "" { + d := NewDecoder(strings.NewReader(tt.xml)) + d.DefaultSpace = tt.ns + err = d.Decode(&dst) + } else { + err = Unmarshal([]byte(tt.xml), &dst) + } + if err != nil { + t.Errorf("#%d: Unmarshal: %v", i, err) + continue + } + want := tt.tab + if dst != want { + t.Errorf("#%d: dst=%+v, want %+v", i, dst, want) + } + } +} + +func TestMarshalNS(t *testing.T) { + dst := Tables{"hello", "world"} + data, err := Marshal(&dst) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + want := `<Tables><table xmlns="http://www.w3.org/TR/html4/">hello</table><table xmlns="http://www.w3schools.com/furniture">world</table></Tables>` + str := string(data) + if str != want { + t.Errorf("have: %q\nwant: %q\n", str, want) + } +} + +type TableAttrs struct { + TAttr TAttr +} + +type TAttr struct { + HTable string `xml:"http://www.w3.org/TR/html4/ table,attr"` + FTable string `xml:"http://www.w3schools.com/furniture table,attr"` + Lang string `xml:"http://www.w3.org/XML/1998/namespace lang,attr,omitempty"` + Other1 string `xml:"http://golang.org/xml/ other,attr,omitempty"` + Other2 string `xml:"http://golang.org/xmlfoo/ other,attr,omitempty"` + Other3 string `xml:"http://golang.org/json/ other,attr,omitempty"` + Other4 string `xml:"http://golang.org/2/json/ other,attr,omitempty"` +} + +var tableAttrs = []struct { + xml string + tab TableAttrs + ns string +}{ + { + xml: `<TableAttrs xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/"><TAttr ` + + `h:table="hello" f:table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "hello", FTable: "world"}}, + }, + { + xml: `<TableAttrs><TAttr xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/" ` + + `h:table="hello" f:table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "hello", FTable: "world"}}, + }, + { + xml: `<TableAttrs><TAttr ` + + `h:table="hello" f:table="world" xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "hello", FTable: "world"}}, + }, + { + // Default space does not apply to attribute names. + xml: `<TableAttrs xmlns="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/"><TAttr ` + + `h:table="hello" table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "hello", FTable: ""}}, + }, + { + // Default space does not apply to attribute names. + xml: `<TableAttrs xmlns:f="http://www.w3schools.com/furniture"><TAttr xmlns="http://www.w3.org/TR/html4/" ` + + `table="hello" f:table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "", FTable: "world"}}, + }, + { + xml: `<TableAttrs><TAttr ` + + `table="bogus" ` + + `/></TableAttrs>`, + tab: TableAttrs{}, + }, + { + // Default space does not apply to attribute names. + xml: `<TableAttrs xmlns:h="http://www.w3.org/TR/html4/"><TAttr ` + + `h:table="hello" table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "hello", FTable: ""}}, + ns: "http://www.w3schools.com/furniture", + }, + { + // Default space does not apply to attribute names. + xml: `<TableAttrs xmlns:f="http://www.w3schools.com/furniture"><TAttr ` + + `table="hello" f:table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "", FTable: "world"}}, + ns: "http://www.w3.org/TR/html4/", + }, + { + xml: `<TableAttrs><TAttr ` + + `table="bogus" ` + + `/></TableAttrs>`, + tab: TableAttrs{}, + ns: "something else entirely", + }, +} + +func TestUnmarshalNSAttr(t *testing.T) { + for i, tt := range tableAttrs { + var dst TableAttrs + var err error + if tt.ns != "" { + d := NewDecoder(strings.NewReader(tt.xml)) + d.DefaultSpace = tt.ns + err = d.Decode(&dst) + } else { + err = Unmarshal([]byte(tt.xml), &dst) + } + if err != nil { + t.Errorf("#%d: Unmarshal: %v", i, err) + continue + } + want := tt.tab + if dst != want { + t.Errorf("#%d: dst=%+v, want %+v", i, dst, want) + } + } +} + +func TestMarshalNSAttr(t *testing.T) { + src := TableAttrs{TAttr{"hello", "world", "en_US", "other1", "other2", "other3", "other4"}} + data, err := Marshal(&src) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + want := `<TableAttrs><TAttr xmlns:html4="http://www.w3.org/TR/html4/" html4:table="hello" xmlns:furniture="http://www.w3schools.com/furniture" furniture:table="world" xml:lang="en_US" xmlns:_xml="http://golang.org/xml/" _xml:other="other1" xmlns:_xmlfoo="http://golang.org/xmlfoo/" _xmlfoo:other="other2" xmlns:json="http://golang.org/json/" json:other="other3" xmlns:json_1="http://golang.org/2/json/" json_1:other="other4"></TAttr></TableAttrs>` + str := string(data) + if str != want { + t.Errorf("Marshal:\nhave: %#q\nwant: %#q\n", str, want) + } + + var dst TableAttrs + if err := Unmarshal(data, &dst); err != nil { + t.Errorf("Unmarshal: %v", err) + } + + if dst != src { + t.Errorf("Unmarshal = %q, want %q", dst, src) + } +} + +type MyCharData struct { + body string +} + +func (m *MyCharData) UnmarshalXML(d *Decoder, start StartElement) error { + for { + t, err := d.Token() + if err == io.EOF { // found end of element + break + } + if err != nil { + return err + } + if char, ok := t.(CharData); ok { + m.body += string(char) + } + } + return nil +} + +var _ Unmarshaler = (*MyCharData)(nil) + +func (m *MyCharData) UnmarshalXMLAttr(attr Attr) error { + panic("must not call") +} + +type MyAttr struct { + attr string +} + +func (m *MyAttr) UnmarshalXMLAttr(attr Attr) error { + m.attr = attr.Value + return nil +} + +var _ UnmarshalerAttr = (*MyAttr)(nil) + +type MyStruct struct { + Data *MyCharData + Attr *MyAttr `xml:",attr"` + + Data2 MyCharData + Attr2 MyAttr `xml:",attr"` +} + +func TestUnmarshaler(t *testing.T) { + xml := `<?xml version="1.0" encoding="utf-8"?> + <MyStruct Attr="attr1" Attr2="attr2"> + <Data>hello <!-- comment -->world</Data> + <Data2>howdy <!-- comment -->world</Data2> + </MyStruct> + ` + + var m MyStruct + if err := Unmarshal([]byte(xml), &m); err != nil { + t.Fatal(err) + } + + if m.Data == nil || m.Attr == nil || m.Data.body != "hello world" || m.Attr.attr != "attr1" || m.Data2.body != "howdy world" || m.Attr2.attr != "attr2" { + t.Errorf("m=%#+v\n", m) + } +} + +type Pea struct { + Cotelydon string +} + +type Pod struct { + Pea interface{} `xml:"Pea"` +} + +// https://code.google.com/p/go/issues/detail?id=6836 +func TestUnmarshalIntoInterface(t *testing.T) { + pod := new(Pod) + pod.Pea = new(Pea) + xml := `<Pod><Pea><Cotelydon>Green stuff</Cotelydon></Pea></Pod>` + err := Unmarshal([]byte(xml), pod) + if err != nil { + t.Fatalf("failed to unmarshal %q: %v", xml, err) + } + pea, ok := pod.Pea.(*Pea) + if !ok { + t.Fatalf("unmarshalled into wrong type: have %T want *Pea", pod.Pea) + } + have, want := pea.Cotelydon, "Green stuff" + if have != want { + t.Errorf("failed to unmarshal into interface, have %q want %q", have, want) + } +} diff --git a/src/encoding/xml/typeinfo.go b/src/encoding/xml/typeinfo.go new file mode 100644 index 000000000..22248d20a --- /dev/null +++ b/src/encoding/xml/typeinfo.go @@ -0,0 +1,363 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import ( + "fmt" + "reflect" + "strings" + "sync" +) + +// typeInfo holds details for the xml representation of a type. +type typeInfo struct { + xmlname *fieldInfo + fields []fieldInfo +} + +// fieldInfo holds details for the xml representation of a single field. +type fieldInfo struct { + idx []int + name string + xmlns string + flags fieldFlags + parents []string +} + +type fieldFlags int + +const ( + fElement fieldFlags = 1 << iota + fAttr + fCharData + fInnerXml + fComment + fAny + + fOmitEmpty + + fMode = fElement | fAttr | fCharData | fInnerXml | fComment | fAny +) + +var tinfoMap = make(map[reflect.Type]*typeInfo) +var tinfoLock sync.RWMutex + +var nameType = reflect.TypeOf(Name{}) + +// getTypeInfo returns the typeInfo structure with details necessary +// for marshalling and unmarshalling typ. +func getTypeInfo(typ reflect.Type) (*typeInfo, error) { + tinfoLock.RLock() + tinfo, ok := tinfoMap[typ] + tinfoLock.RUnlock() + if ok { + return tinfo, nil + } + tinfo = &typeInfo{} + if typ.Kind() == reflect.Struct && typ != nameType { + n := typ.NumField() + for i := 0; i < n; i++ { + f := typ.Field(i) + if f.PkgPath != "" || f.Tag.Get("xml") == "-" { + continue // Private field + } + + // For embedded structs, embed its fields. + if f.Anonymous { + t := f.Type + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + inner, err := getTypeInfo(t) + if err != nil { + return nil, err + } + if tinfo.xmlname == nil { + tinfo.xmlname = inner.xmlname + } + for _, finfo := range inner.fields { + finfo.idx = append([]int{i}, finfo.idx...) + if err := addFieldInfo(typ, tinfo, &finfo); err != nil { + return nil, err + } + } + continue + } + } + + finfo, err := structFieldInfo(typ, &f) + if err != nil { + return nil, err + } + + if f.Name == "XMLName" { + tinfo.xmlname = finfo + continue + } + + // Add the field if it doesn't conflict with other fields. + if err := addFieldInfo(typ, tinfo, finfo); err != nil { + return nil, err + } + } + } + tinfoLock.Lock() + tinfoMap[typ] = tinfo + tinfoLock.Unlock() + return tinfo, nil +} + +// structFieldInfo builds and returns a fieldInfo for f. +func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, error) { + finfo := &fieldInfo{idx: f.Index} + + // Split the tag from the xml namespace if necessary. + tag := f.Tag.Get("xml") + if i := strings.Index(tag, " "); i >= 0 { + finfo.xmlns, tag = tag[:i], tag[i+1:] + } + + // Parse flags. + tokens := strings.Split(tag, ",") + if len(tokens) == 1 { + finfo.flags = fElement + } else { + tag = tokens[0] + for _, flag := range tokens[1:] { + switch flag { + case "attr": + finfo.flags |= fAttr + case "chardata": + finfo.flags |= fCharData + case "innerxml": + finfo.flags |= fInnerXml + case "comment": + finfo.flags |= fComment + case "any": + finfo.flags |= fAny + case "omitempty": + finfo.flags |= fOmitEmpty + } + } + + // Validate the flags used. + valid := true + switch mode := finfo.flags & fMode; mode { + case 0: + finfo.flags |= fElement + case fAttr, fCharData, fInnerXml, fComment, fAny: + if f.Name == "XMLName" || tag != "" && mode != fAttr { + valid = false + } + default: + // This will also catch multiple modes in a single field. + valid = false + } + if finfo.flags&fMode == fAny { + finfo.flags |= fElement + } + if finfo.flags&fOmitEmpty != 0 && finfo.flags&(fElement|fAttr) == 0 { + valid = false + } + if !valid { + return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q", + f.Name, typ, f.Tag.Get("xml")) + } + } + + // Use of xmlns without a name is not allowed. + if finfo.xmlns != "" && tag == "" { + return nil, fmt.Errorf("xml: namespace without name in field %s of type %s: %q", + f.Name, typ, f.Tag.Get("xml")) + } + + if f.Name == "XMLName" { + // The XMLName field records the XML element name. Don't + // process it as usual because its name should default to + // empty rather than to the field name. + finfo.name = tag + return finfo, nil + } + + if tag == "" { + // If the name part of the tag is completely empty, get + // default from XMLName of underlying struct if feasible, + // or field name otherwise. + if xmlname := lookupXMLName(f.Type); xmlname != nil { + finfo.xmlns, finfo.name = xmlname.xmlns, xmlname.name + } else { + finfo.name = f.Name + } + return finfo, nil + } + + // Prepare field name and parents. + parents := strings.Split(tag, ">") + if parents[0] == "" { + parents[0] = f.Name + } + if parents[len(parents)-1] == "" { + return nil, fmt.Errorf("xml: trailing '>' in field %s of type %s", f.Name, typ) + } + finfo.name = parents[len(parents)-1] + if len(parents) > 1 { + if (finfo.flags & fElement) == 0 { + return nil, fmt.Errorf("xml: %s chain not valid with %s flag", tag, strings.Join(tokens[1:], ",")) + } + finfo.parents = parents[:len(parents)-1] + } + + // If the field type has an XMLName field, the names must match + // so that the behavior of both marshalling and unmarshalling + // is straightforward and unambiguous. + if finfo.flags&fElement != 0 { + ftyp := f.Type + xmlname := lookupXMLName(ftyp) + if xmlname != nil && xmlname.name != finfo.name { + return nil, fmt.Errorf("xml: name %q in tag of %s.%s conflicts with name %q in %s.XMLName", + finfo.name, typ, f.Name, xmlname.name, ftyp) + } + } + return finfo, nil +} + +// lookupXMLName returns the fieldInfo for typ's XMLName field +// in case it exists and has a valid xml field tag, otherwise +// it returns nil. +func lookupXMLName(typ reflect.Type) (xmlname *fieldInfo) { + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + if typ.Kind() != reflect.Struct { + return nil + } + for i, n := 0, typ.NumField(); i < n; i++ { + f := typ.Field(i) + if f.Name != "XMLName" { + continue + } + finfo, err := structFieldInfo(typ, &f) + if finfo.name != "" && err == nil { + return finfo + } + // Also consider errors as a non-existent field tag + // and let getTypeInfo itself report the error. + break + } + return nil +} + +func min(a, b int) int { + if a <= b { + return a + } + return b +} + +// addFieldInfo adds finfo to tinfo.fields if there are no +// conflicts, or if conflicts arise from previous fields that were +// obtained from deeper embedded structures than finfo. In the latter +// case, the conflicting entries are dropped. +// A conflict occurs when the path (parent + name) to a field is +// itself a prefix of another path, or when two paths match exactly. +// It is okay for field paths to share a common, shorter prefix. +func addFieldInfo(typ reflect.Type, tinfo *typeInfo, newf *fieldInfo) error { + var conflicts []int +Loop: + // First, figure all conflicts. Most working code will have none. + for i := range tinfo.fields { + oldf := &tinfo.fields[i] + if oldf.flags&fMode != newf.flags&fMode { + continue + } + if oldf.xmlns != "" && newf.xmlns != "" && oldf.xmlns != newf.xmlns { + continue + } + minl := min(len(newf.parents), len(oldf.parents)) + for p := 0; p < minl; p++ { + if oldf.parents[p] != newf.parents[p] { + continue Loop + } + } + if len(oldf.parents) > len(newf.parents) { + if oldf.parents[len(newf.parents)] == newf.name { + conflicts = append(conflicts, i) + } + } else if len(oldf.parents) < len(newf.parents) { + if newf.parents[len(oldf.parents)] == oldf.name { + conflicts = append(conflicts, i) + } + } else { + if newf.name == oldf.name { + conflicts = append(conflicts, i) + } + } + } + // Without conflicts, add the new field and return. + if conflicts == nil { + tinfo.fields = append(tinfo.fields, *newf) + return nil + } + + // If any conflict is shallower, ignore the new field. + // This matches the Go field resolution on embedding. + for _, i := range conflicts { + if len(tinfo.fields[i].idx) < len(newf.idx) { + return nil + } + } + + // Otherwise, if any of them is at the same depth level, it's an error. + for _, i := range conflicts { + oldf := &tinfo.fields[i] + if len(oldf.idx) == len(newf.idx) { + f1 := typ.FieldByIndex(oldf.idx) + f2 := typ.FieldByIndex(newf.idx) + return &TagPathError{typ, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")} + } + } + + // Otherwise, the new field is shallower, and thus takes precedence, + // so drop the conflicting fields from tinfo and append the new one. + for c := len(conflicts) - 1; c >= 0; c-- { + i := conflicts[c] + copy(tinfo.fields[i:], tinfo.fields[i+1:]) + tinfo.fields = tinfo.fields[:len(tinfo.fields)-1] + } + tinfo.fields = append(tinfo.fields, *newf) + return nil +} + +// A TagPathError represents an error in the unmarshalling process +// caused by the use of field tags with conflicting paths. +type TagPathError struct { + Struct reflect.Type + Field1, Tag1 string + Field2, Tag2 string +} + +func (e *TagPathError) Error() string { + return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2) +} + +// value returns v's field value corresponding to finfo. +// It's equivalent to v.FieldByIndex(finfo.idx), but initializes +// and dereferences pointers as necessary. +func (finfo *fieldInfo) value(v reflect.Value) reflect.Value { + for i, x := range finfo.idx { + if i > 0 { + t := v.Type() + if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + } + v = v.Field(x) + } + return v +} diff --git a/src/encoding/xml/xml.go b/src/encoding/xml/xml.go new file mode 100644 index 000000000..8c15b98c3 --- /dev/null +++ b/src/encoding/xml/xml.go @@ -0,0 +1,1945 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package xml implements a simple XML 1.0 parser that +// understands XML name spaces. +package xml + +// References: +// Annotated XML spec: http://www.xml.com/axml/testaxml.htm +// XML name spaces: http://www.w3.org/TR/REC-xml-names/ + +// TODO(rsc): +// Test error handling. + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "strconv" + "strings" + "unicode" + "unicode/utf8" +) + +// A SyntaxError represents a syntax error in the XML input stream. +type SyntaxError struct { + Msg string + Line int +} + +func (e *SyntaxError) Error() string { + return "XML syntax error on line " + strconv.Itoa(e.Line) + ": " + e.Msg +} + +// A Name represents an XML name (Local) annotated +// with a name space identifier (Space). +// In tokens returned by Decoder.Token, the Space identifier +// is given as a canonical URL, not the short prefix used +// in the document being parsed. +type Name struct { + Space, Local string +} + +// An Attr represents an attribute in an XML element (Name=Value). +type Attr struct { + Name Name + Value string +} + +// A Token is an interface holding one of the token types: +// StartElement, EndElement, CharData, Comment, ProcInst, or Directive. +type Token interface{} + +// A StartElement represents an XML start element. +type StartElement struct { + Name Name + Attr []Attr +} + +func (e StartElement) Copy() StartElement { + attrs := make([]Attr, len(e.Attr)) + copy(attrs, e.Attr) + e.Attr = attrs + return e +} + +// End returns the corresponding XML end element. +func (e StartElement) End() EndElement { + return EndElement{e.Name} +} + +// An EndElement represents an XML end element. +type EndElement struct { + Name Name +} + +// A CharData represents XML character data (raw text), +// in which XML escape sequences have been replaced by +// the characters they represent. +type CharData []byte + +func makeCopy(b []byte) []byte { + b1 := make([]byte, len(b)) + copy(b1, b) + return b1 +} + +func (c CharData) Copy() CharData { return CharData(makeCopy(c)) } + +// A Comment represents an XML comment of the form <!--comment-->. +// The bytes do not include the <!-- and --> comment markers. +type Comment []byte + +func (c Comment) Copy() Comment { return Comment(makeCopy(c)) } + +// A ProcInst represents an XML processing instruction of the form <?target inst?> +type ProcInst struct { + Target string + Inst []byte +} + +func (p ProcInst) Copy() ProcInst { + p.Inst = makeCopy(p.Inst) + return p +} + +// A Directive represents an XML directive of the form <!text>. +// The bytes do not include the <! and > markers. +type Directive []byte + +func (d Directive) Copy() Directive { return Directive(makeCopy(d)) } + +// CopyToken returns a copy of a Token. +func CopyToken(t Token) Token { + switch v := t.(type) { + case CharData: + return v.Copy() + case Comment: + return v.Copy() + case Directive: + return v.Copy() + case ProcInst: + return v.Copy() + case StartElement: + return v.Copy() + } + return t +} + +// A Decoder represents an XML parser reading a particular input stream. +// The parser assumes that its input is encoded in UTF-8. +type Decoder struct { + // Strict defaults to true, enforcing the requirements + // of the XML specification. + // If set to false, the parser allows input containing common + // mistakes: + // * If an element is missing an end tag, the parser invents + // end tags as necessary to keep the return values from Token + // properly balanced. + // * In attribute values and character data, unknown or malformed + // character entities (sequences beginning with &) are left alone. + // + // Setting: + // + // d.Strict = false; + // d.AutoClose = HTMLAutoClose; + // d.Entity = HTMLEntity + // + // creates a parser that can handle typical HTML. + // + // Strict mode does not enforce the requirements of the XML name spaces TR. + // In particular it does not reject name space tags using undefined prefixes. + // Such tags are recorded with the unknown prefix as the name space URL. + Strict bool + + // When Strict == false, AutoClose indicates a set of elements to + // consider closed immediately after they are opened, regardless + // of whether an end element is present. + AutoClose []string + + // Entity can be used to map non-standard entity names to string replacements. + // The parser behaves as if these standard mappings are present in the map, + // regardless of the actual map content: + // + // "lt": "<", + // "gt": ">", + // "amp": "&", + // "apos": "'", + // "quot": `"`, + Entity map[string]string + + // CharsetReader, if non-nil, defines a function to generate + // charset-conversion readers, converting from the provided + // non-UTF-8 charset into UTF-8. If CharsetReader is nil or + // returns an error, parsing stops with an error. One of the + // the CharsetReader's result values must be non-nil. + CharsetReader func(charset string, input io.Reader) (io.Reader, error) + + // DefaultSpace sets the default name space used for unadorned tags, + // as if the entire XML stream were wrapped in an element containing + // the attribute xmlns="DefaultSpace". + DefaultSpace string + + r io.ByteReader + buf bytes.Buffer + saved *bytes.Buffer + stk *stack + free *stack + needClose bool + toClose Name + nextToken Token + nextByte int + ns map[string]string + err error + line int + offset int64 + unmarshalDepth int +} + +// NewDecoder creates a new XML parser reading from r. +// If r does not implement io.ByteReader, NewDecoder will +// do its own buffering. +func NewDecoder(r io.Reader) *Decoder { + d := &Decoder{ + ns: make(map[string]string), + nextByte: -1, + line: 1, + Strict: true, + } + d.switchToReader(r) + return d +} + +// Token returns the next XML token in the input stream. +// At the end of the input stream, Token returns nil, io.EOF. +// +// Slices of bytes in the returned token data refer to the +// parser's internal buffer and remain valid only until the next +// call to Token. To acquire a copy of the bytes, call CopyToken +// or the token's Copy method. +// +// Token expands self-closing elements such as <br/> +// into separate start and end elements returned by successive calls. +// +// Token guarantees that the StartElement and EndElement +// tokens it returns are properly nested and matched: +// if Token encounters an unexpected end element, +// it will return an error. +// +// Token implements XML name spaces as described by +// http://www.w3.org/TR/REC-xml-names/. Each of the +// Name structures contained in the Token has the Space +// set to the URL identifying its name space when known. +// If Token encounters an unrecognized name space prefix, +// it uses the prefix as the Space rather than report an error. +func (d *Decoder) Token() (t Token, err error) { + if d.stk != nil && d.stk.kind == stkEOF { + err = io.EOF + return + } + if d.nextToken != nil { + t = d.nextToken + d.nextToken = nil + } else if t, err = d.rawToken(); err != nil { + return + } + + if !d.Strict { + if t1, ok := d.autoClose(t); ok { + d.nextToken = t + t = t1 + } + } + switch t1 := t.(type) { + case StartElement: + // In XML name spaces, the translations listed in the + // attributes apply to the element name and + // to the other attribute names, so process + // the translations first. + for _, a := range t1.Attr { + if a.Name.Space == "xmlns" { + v, ok := d.ns[a.Name.Local] + d.pushNs(a.Name.Local, v, ok) + d.ns[a.Name.Local] = a.Value + } + if a.Name.Space == "" && a.Name.Local == "xmlns" { + // Default space for untagged names + v, ok := d.ns[""] + d.pushNs("", v, ok) + d.ns[""] = a.Value + } + } + + d.translate(&t1.Name, true) + for i := range t1.Attr { + d.translate(&t1.Attr[i].Name, false) + } + d.pushElement(t1.Name) + t = t1 + + case EndElement: + d.translate(&t1.Name, true) + if !d.popElement(&t1) { + return nil, d.err + } + t = t1 + } + return +} + +const xmlURL = "http://www.w3.org/XML/1998/namespace" + +// Apply name space translation to name n. +// The default name space (for Space=="") +// applies only to element names, not to attribute names. +func (d *Decoder) translate(n *Name, isElementName bool) { + switch { + case n.Space == "xmlns": + return + case n.Space == "" && !isElementName: + return + case n.Space == "xml": + n.Space = xmlURL + case n.Space == "" && n.Local == "xmlns": + return + } + if v, ok := d.ns[n.Space]; ok { + n.Space = v + } else if n.Space == "" { + n.Space = d.DefaultSpace + } +} + +func (d *Decoder) switchToReader(r io.Reader) { + // Get efficient byte at a time reader. + // Assume that if reader has its own + // ReadByte, it's efficient enough. + // Otherwise, use bufio. + if rb, ok := r.(io.ByteReader); ok { + d.r = rb + } else { + d.r = bufio.NewReader(r) + } +} + +// Parsing state - stack holds old name space translations +// and the current set of open elements. The translations to pop when +// ending a given tag are *below* it on the stack, which is +// more work but forced on us by XML. +type stack struct { + next *stack + kind int + name Name + ok bool +} + +const ( + stkStart = iota + stkNs + stkEOF +) + +func (d *Decoder) push(kind int) *stack { + s := d.free + if s != nil { + d.free = s.next + } else { + s = new(stack) + } + s.next = d.stk + s.kind = kind + d.stk = s + return s +} + +func (d *Decoder) pop() *stack { + s := d.stk + if s != nil { + d.stk = s.next + s.next = d.free + d.free = s + } + return s +} + +// Record that after the current element is finished +// (that element is already pushed on the stack) +// Token should return EOF until popEOF is called. +func (d *Decoder) pushEOF() { + // Walk down stack to find Start. + // It might not be the top, because there might be stkNs + // entries above it. + start := d.stk + for start.kind != stkStart { + start = start.next + } + // The stkNs entries below a start are associated with that + // element too; skip over them. + for start.next != nil && start.next.kind == stkNs { + start = start.next + } + s := d.free + if s != nil { + d.free = s.next + } else { + s = new(stack) + } + s.kind = stkEOF + s.next = start.next + start.next = s +} + +// Undo a pushEOF. +// The element must have been finished, so the EOF should be at the top of the stack. +func (d *Decoder) popEOF() bool { + if d.stk == nil || d.stk.kind != stkEOF { + return false + } + d.pop() + return true +} + +// Record that we are starting an element with the given name. +func (d *Decoder) pushElement(name Name) { + s := d.push(stkStart) + s.name = name +} + +// Record that we are changing the value of ns[local]. +// The old value is url, ok. +func (d *Decoder) pushNs(local string, url string, ok bool) { + s := d.push(stkNs) + s.name.Local = local + s.name.Space = url + s.ok = ok +} + +// Creates a SyntaxError with the current line number. +func (d *Decoder) syntaxError(msg string) error { + return &SyntaxError{Msg: msg, Line: d.line} +} + +// Record that we are ending an element with the given name. +// The name must match the record at the top of the stack, +// which must be a pushElement record. +// After popping the element, apply any undo records from +// the stack to restore the name translations that existed +// before we saw this element. +func (d *Decoder) popElement(t *EndElement) bool { + s := d.pop() + name := t.Name + switch { + case s == nil || s.kind != stkStart: + d.err = d.syntaxError("unexpected end element </" + name.Local + ">") + return false + case s.name.Local != name.Local: + if !d.Strict { + d.needClose = true + d.toClose = t.Name + t.Name = s.name + return true + } + d.err = d.syntaxError("element <" + s.name.Local + "> closed by </" + name.Local + ">") + return false + case s.name.Space != name.Space: + d.err = d.syntaxError("element <" + s.name.Local + "> in space " + s.name.Space + + "closed by </" + name.Local + "> in space " + name.Space) + return false + } + + // Pop stack until a Start or EOF is on the top, undoing the + // translations that were associated with the element we just closed. + for d.stk != nil && d.stk.kind != stkStart && d.stk.kind != stkEOF { + s := d.pop() + if s.ok { + d.ns[s.name.Local] = s.name.Space + } else { + delete(d.ns, s.name.Local) + } + } + + return true +} + +// If the top element on the stack is autoclosing and +// t is not the end tag, invent the end tag. +func (d *Decoder) autoClose(t Token) (Token, bool) { + if d.stk == nil || d.stk.kind != stkStart { + return nil, false + } + name := strings.ToLower(d.stk.name.Local) + for _, s := range d.AutoClose { + if strings.ToLower(s) == name { + // This one should be auto closed if t doesn't close it. + et, ok := t.(EndElement) + if !ok || et.Name.Local != name { + return EndElement{d.stk.name}, true + } + break + } + } + return nil, false +} + +var errRawToken = errors.New("xml: cannot use RawToken from UnmarshalXML method") + +// RawToken is like Token but does not verify that +// start and end elements match and does not translate +// name space prefixes to their corresponding URLs. +func (d *Decoder) RawToken() (Token, error) { + if d.unmarshalDepth > 0 { + return nil, errRawToken + } + return d.rawToken() +} + +func (d *Decoder) rawToken() (Token, error) { + if d.err != nil { + return nil, d.err + } + if d.needClose { + // The last element we read was self-closing and + // we returned just the StartElement half. + // Return the EndElement half now. + d.needClose = false + return EndElement{d.toClose}, nil + } + + b, ok := d.getc() + if !ok { + return nil, d.err + } + + if b != '<' { + // Text section. + d.ungetc(b) + data := d.text(-1, false) + if data == nil { + return nil, d.err + } + return CharData(data), nil + } + + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + switch b { + case '/': + // </: End element + var name Name + if name, ok = d.nsname(); !ok { + if d.err == nil { + d.err = d.syntaxError("expected element name after </") + } + return nil, d.err + } + d.space() + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if b != '>' { + d.err = d.syntaxError("invalid characters between </" + name.Local + " and >") + return nil, d.err + } + return EndElement{name}, nil + + case '?': + // <?: Processing instruction. + // TODO(rsc): Should parse the <?xml declaration to make sure the version is 1.0. + var target string + if target, ok = d.name(); !ok { + if d.err == nil { + d.err = d.syntaxError("expected target name after <?") + } + return nil, d.err + } + d.space() + d.buf.Reset() + var b0 byte + for { + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + d.buf.WriteByte(b) + if b0 == '?' && b == '>' { + break + } + b0 = b + } + data := d.buf.Bytes() + data = data[0 : len(data)-2] // chop ?> + + if target == "xml" { + enc := procInstEncoding(string(data)) + if enc != "" && enc != "utf-8" && enc != "UTF-8" { + if d.CharsetReader == nil { + d.err = fmt.Errorf("xml: encoding %q declared but Decoder.CharsetReader is nil", enc) + return nil, d.err + } + newr, err := d.CharsetReader(enc, d.r.(io.Reader)) + if err != nil { + d.err = fmt.Errorf("xml: opening charset %q: %v", enc, err) + return nil, d.err + } + if newr == nil { + panic("CharsetReader returned a nil Reader for charset " + enc) + } + d.switchToReader(newr) + } + } + return ProcInst{target, data}, nil + + case '!': + // <!: Maybe comment, maybe CDATA. + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + switch b { + case '-': // <!- + // Probably <!-- for a comment. + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if b != '-' { + d.err = d.syntaxError("invalid sequence <!- not part of <!--") + return nil, d.err + } + // Look for terminator. + d.buf.Reset() + var b0, b1 byte + for { + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + d.buf.WriteByte(b) + if b0 == '-' && b1 == '-' && b == '>' { + break + } + b0, b1 = b1, b + } + data := d.buf.Bytes() + data = data[0 : len(data)-3] // chop --> + return Comment(data), nil + + case '[': // <![ + // Probably <![CDATA[. + for i := 0; i < 6; i++ { + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if b != "CDATA["[i] { + d.err = d.syntaxError("invalid <![ sequence") + return nil, d.err + } + } + // Have <![CDATA[. Read text until ]]>. + data := d.text(-1, true) + if data == nil { + return nil, d.err + } + return CharData(data), nil + } + + // Probably a directive: <!DOCTYPE ...>, <!ENTITY ...>, etc. + // We don't care, but accumulate for caller. Quoted angle + // brackets do not count for nesting. + d.buf.Reset() + d.buf.WriteByte(b) + inquote := uint8(0) + depth := 0 + for { + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if inquote == 0 && b == '>' && depth == 0 { + break + } + HandleB: + d.buf.WriteByte(b) + switch { + case b == inquote: + inquote = 0 + + case inquote != 0: + // in quotes, no special action + + case b == '\'' || b == '"': + inquote = b + + case b == '>' && inquote == 0: + depth-- + + case b == '<' && inquote == 0: + // Look for <!-- to begin comment. + s := "!--" + for i := 0; i < len(s); i++ { + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if b != s[i] { + for j := 0; j < i; j++ { + d.buf.WriteByte(s[j]) + } + depth++ + goto HandleB + } + } + + // Remove < that was written above. + d.buf.Truncate(d.buf.Len() - 1) + + // Look for terminator. + var b0, b1 byte + for { + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if b0 == '-' && b1 == '-' && b == '>' { + break + } + b0, b1 = b1, b + } + } + } + return Directive(d.buf.Bytes()), nil + } + + // Must be an open element like <a href="foo"> + d.ungetc(b) + + var ( + name Name + empty bool + attr []Attr + ) + if name, ok = d.nsname(); !ok { + if d.err == nil { + d.err = d.syntaxError("expected element name after <") + } + return nil, d.err + } + + attr = make([]Attr, 0, 4) + for { + d.space() + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if b == '/' { + empty = true + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if b != '>' { + d.err = d.syntaxError("expected /> in element") + return nil, d.err + } + break + } + if b == '>' { + break + } + d.ungetc(b) + + n := len(attr) + if n >= cap(attr) { + nattr := make([]Attr, n, 2*cap(attr)) + copy(nattr, attr) + attr = nattr + } + attr = attr[0 : n+1] + a := &attr[n] + if a.Name, ok = d.nsname(); !ok { + if d.err == nil { + d.err = d.syntaxError("expected attribute name in element") + } + return nil, d.err + } + d.space() + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if b != '=' { + if d.Strict { + d.err = d.syntaxError("attribute name without = in element") + return nil, d.err + } else { + d.ungetc(b) + a.Value = a.Name.Local + } + } else { + d.space() + data := d.attrval() + if data == nil { + return nil, d.err + } + a.Value = string(data) + } + } + if empty { + d.needClose = true + d.toClose = name + } + return StartElement{name, attr}, nil +} + +func (d *Decoder) attrval() []byte { + b, ok := d.mustgetc() + if !ok { + return nil + } + // Handle quoted attribute values + if b == '"' || b == '\'' { + return d.text(int(b), false) + } + // Handle unquoted attribute values for strict parsers + if d.Strict { + d.err = d.syntaxError("unquoted or missing attribute value in element") + return nil + } + // Handle unquoted attribute values for unstrict parsers + d.ungetc(b) + d.buf.Reset() + for { + b, ok = d.mustgetc() + if !ok { + return nil + } + // http://www.w3.org/TR/REC-html40/intro/sgmltut.html#h-3.2.2 + if 'a' <= b && b <= 'z' || 'A' <= b && b <= 'Z' || + '0' <= b && b <= '9' || b == '_' || b == ':' || b == '-' { + d.buf.WriteByte(b) + } else { + d.ungetc(b) + break + } + } + return d.buf.Bytes() +} + +// Skip spaces if any +func (d *Decoder) space() { + for { + b, ok := d.getc() + if !ok { + return + } + switch b { + case ' ', '\r', '\n', '\t': + default: + d.ungetc(b) + return + } + } +} + +// Read a single byte. +// If there is no byte to read, return ok==false +// and leave the error in d.err. +// Maintain line number. +func (d *Decoder) getc() (b byte, ok bool) { + if d.err != nil { + return 0, false + } + if d.nextByte >= 0 { + b = byte(d.nextByte) + d.nextByte = -1 + } else { + b, d.err = d.r.ReadByte() + if d.err != nil { + return 0, false + } + if d.saved != nil { + d.saved.WriteByte(b) + } + } + if b == '\n' { + d.line++ + } + d.offset++ + return b, true +} + +// InputOffset returns the input stream byte offset of the current decoder position. +// The offset gives the location of the end of the most recently returned token +// and the beginning of the next token. +func (d *Decoder) InputOffset() int64 { + return d.offset +} + +// Return saved offset. +// If we did ungetc (nextByte >= 0), have to back up one. +func (d *Decoder) savedOffset() int { + n := d.saved.Len() + if d.nextByte >= 0 { + n-- + } + return n +} + +// Must read a single byte. +// If there is no byte to read, +// set d.err to SyntaxError("unexpected EOF") +// and return ok==false +func (d *Decoder) mustgetc() (b byte, ok bool) { + if b, ok = d.getc(); !ok { + if d.err == io.EOF { + d.err = d.syntaxError("unexpected EOF") + } + } + return +} + +// Unread a single byte. +func (d *Decoder) ungetc(b byte) { + if b == '\n' { + d.line-- + } + d.nextByte = int(b) + d.offset-- +} + +var entity = map[string]int{ + "lt": '<', + "gt": '>', + "amp": '&', + "apos": '\'', + "quot": '"', +} + +// Read plain text section (XML calls it character data). +// If quote >= 0, we are in a quoted string and need to find the matching quote. +// If cdata == true, we are in a <![CDATA[ section and need to find ]]>. +// On failure return nil and leave the error in d.err. +func (d *Decoder) text(quote int, cdata bool) []byte { + var b0, b1 byte + var trunc int + d.buf.Reset() +Input: + for { + b, ok := d.getc() + if !ok { + if cdata { + if d.err == io.EOF { + d.err = d.syntaxError("unexpected EOF in CDATA section") + } + return nil + } + break Input + } + + // <![CDATA[ section ends with ]]>. + // It is an error for ]]> to appear in ordinary text. + if b0 == ']' && b1 == ']' && b == '>' { + if cdata { + trunc = 2 + break Input + } + d.err = d.syntaxError("unescaped ]]> not in CDATA section") + return nil + } + + // Stop reading text if we see a <. + if b == '<' && !cdata { + if quote >= 0 { + d.err = d.syntaxError("unescaped < inside quoted string") + return nil + } + d.ungetc('<') + break Input + } + if quote >= 0 && b == byte(quote) { + break Input + } + if b == '&' && !cdata { + // Read escaped character expression up to semicolon. + // XML in all its glory allows a document to define and use + // its own character names with <!ENTITY ...> directives. + // Parsers are required to recognize lt, gt, amp, apos, and quot + // even if they have not been declared. + before := d.buf.Len() + d.buf.WriteByte('&') + var ok bool + var text string + var haveText bool + if b, ok = d.mustgetc(); !ok { + return nil + } + if b == '#' { + d.buf.WriteByte(b) + if b, ok = d.mustgetc(); !ok { + return nil + } + base := 10 + if b == 'x' { + base = 16 + d.buf.WriteByte(b) + if b, ok = d.mustgetc(); !ok { + return nil + } + } + start := d.buf.Len() + for '0' <= b && b <= '9' || + base == 16 && 'a' <= b && b <= 'f' || + base == 16 && 'A' <= b && b <= 'F' { + d.buf.WriteByte(b) + if b, ok = d.mustgetc(); !ok { + return nil + } + } + if b != ';' { + d.ungetc(b) + } else { + s := string(d.buf.Bytes()[start:]) + d.buf.WriteByte(';') + n, err := strconv.ParseUint(s, base, 64) + if err == nil && n <= unicode.MaxRune { + text = string(n) + haveText = true + } + } + } else { + d.ungetc(b) + if !d.readName() { + if d.err != nil { + return nil + } + ok = false + } + if b, ok = d.mustgetc(); !ok { + return nil + } + if b != ';' { + d.ungetc(b) + } else { + name := d.buf.Bytes()[before+1:] + d.buf.WriteByte(';') + if isName(name) { + s := string(name) + if r, ok := entity[s]; ok { + text = string(r) + haveText = true + } else if d.Entity != nil { + text, haveText = d.Entity[s] + } + } + } + } + + if haveText { + d.buf.Truncate(before) + d.buf.Write([]byte(text)) + b0, b1 = 0, 0 + continue Input + } + if !d.Strict { + b0, b1 = 0, 0 + continue Input + } + ent := string(d.buf.Bytes()[before:]) + if ent[len(ent)-1] != ';' { + ent += " (no semicolon)" + } + d.err = d.syntaxError("invalid character entity " + ent) + return nil + } + + // We must rewrite unescaped \r and \r\n into \n. + if b == '\r' { + d.buf.WriteByte('\n') + } else if b1 == '\r' && b == '\n' { + // Skip \r\n--we already wrote \n. + } else { + d.buf.WriteByte(b) + } + + b0, b1 = b1, b + } + data := d.buf.Bytes() + data = data[0 : len(data)-trunc] + + // Inspect each rune for being a disallowed character. + buf := data + for len(buf) > 0 { + r, size := utf8.DecodeRune(buf) + if r == utf8.RuneError && size == 1 { + d.err = d.syntaxError("invalid UTF-8") + return nil + } + buf = buf[size:] + if !isInCharacterRange(r) { + d.err = d.syntaxError(fmt.Sprintf("illegal character code %U", r)) + return nil + } + } + + return data +} + +// Decide whether the given rune is in the XML Character Range, per +// the Char production of http://www.xml.com/axml/testaxml.htm, +// Section 2.2 Characters. +func isInCharacterRange(r rune) (inrange bool) { + return r == 0x09 || + r == 0x0A || + r == 0x0D || + r >= 0x20 && r <= 0xDF77 || + r >= 0xE000 && r <= 0xFFFD || + r >= 0x10000 && r <= 0x10FFFF +} + +// Get name space name: name with a : stuck in the middle. +// The part before the : is the name space identifier. +func (d *Decoder) nsname() (name Name, ok bool) { + s, ok := d.name() + if !ok { + return + } + i := strings.Index(s, ":") + if i < 0 { + name.Local = s + } else { + name.Space = s[0:i] + name.Local = s[i+1:] + } + return name, true +} + +// Get name: /first(first|second)*/ +// Do not set d.err if the name is missing (unless unexpected EOF is received): +// let the caller provide better context. +func (d *Decoder) name() (s string, ok bool) { + d.buf.Reset() + if !d.readName() { + return "", false + } + + // Now we check the characters. + s = d.buf.String() + if !isName([]byte(s)) { + d.err = d.syntaxError("invalid XML name: " + s) + return "", false + } + return s, true +} + +// Read a name and append its bytes to d.buf. +// The name is delimited by any single-byte character not valid in names. +// All multi-byte characters are accepted; the caller must check their validity. +func (d *Decoder) readName() (ok bool) { + var b byte + if b, ok = d.mustgetc(); !ok { + return + } + if b < utf8.RuneSelf && !isNameByte(b) { + d.ungetc(b) + return false + } + d.buf.WriteByte(b) + + for { + if b, ok = d.mustgetc(); !ok { + return + } + if b < utf8.RuneSelf && !isNameByte(b) { + d.ungetc(b) + break + } + d.buf.WriteByte(b) + } + return true +} + +func isNameByte(c byte) bool { + return 'A' <= c && c <= 'Z' || + 'a' <= c && c <= 'z' || + '0' <= c && c <= '9' || + c == '_' || c == ':' || c == '.' || c == '-' +} + +func isName(s []byte) bool { + if len(s) == 0 { + return false + } + c, n := utf8.DecodeRune(s) + if c == utf8.RuneError && n == 1 { + return false + } + if !unicode.Is(first, c) { + return false + } + for n < len(s) { + s = s[n:] + c, n = utf8.DecodeRune(s) + if c == utf8.RuneError && n == 1 { + return false + } + if !unicode.Is(first, c) && !unicode.Is(second, c) { + return false + } + } + return true +} + +func isNameString(s string) bool { + if len(s) == 0 { + return false + } + c, n := utf8.DecodeRuneInString(s) + if c == utf8.RuneError && n == 1 { + return false + } + if !unicode.Is(first, c) { + return false + } + for n < len(s) { + s = s[n:] + c, n = utf8.DecodeRuneInString(s) + if c == utf8.RuneError && n == 1 { + return false + } + if !unicode.Is(first, c) && !unicode.Is(second, c) { + return false + } + } + return true +} + +// These tables were generated by cut and paste from Appendix B of +// the XML spec at http://www.xml.com/axml/testaxml.htm +// and then reformatting. First corresponds to (Letter | '_' | ':') +// and second corresponds to NameChar. + +var first = &unicode.RangeTable{ + R16: []unicode.Range16{ + {0x003A, 0x003A, 1}, + {0x0041, 0x005A, 1}, + {0x005F, 0x005F, 1}, + {0x0061, 0x007A, 1}, + {0x00C0, 0x00D6, 1}, + {0x00D8, 0x00F6, 1}, + {0x00F8, 0x00FF, 1}, + {0x0100, 0x0131, 1}, + {0x0134, 0x013E, 1}, + {0x0141, 0x0148, 1}, + {0x014A, 0x017E, 1}, + {0x0180, 0x01C3, 1}, + {0x01CD, 0x01F0, 1}, + {0x01F4, 0x01F5, 1}, + {0x01FA, 0x0217, 1}, + {0x0250, 0x02A8, 1}, + {0x02BB, 0x02C1, 1}, + {0x0386, 0x0386, 1}, + {0x0388, 0x038A, 1}, + {0x038C, 0x038C, 1}, + {0x038E, 0x03A1, 1}, + {0x03A3, 0x03CE, 1}, + {0x03D0, 0x03D6, 1}, + {0x03DA, 0x03E0, 2}, + {0x03E2, 0x03F3, 1}, + {0x0401, 0x040C, 1}, + {0x040E, 0x044F, 1}, + {0x0451, 0x045C, 1}, + {0x045E, 0x0481, 1}, + {0x0490, 0x04C4, 1}, + {0x04C7, 0x04C8, 1}, + {0x04CB, 0x04CC, 1}, + {0x04D0, 0x04EB, 1}, + {0x04EE, 0x04F5, 1}, + {0x04F8, 0x04F9, 1}, + {0x0531, 0x0556, 1}, + {0x0559, 0x0559, 1}, + {0x0561, 0x0586, 1}, + {0x05D0, 0x05EA, 1}, + {0x05F0, 0x05F2, 1}, + {0x0621, 0x063A, 1}, + {0x0641, 0x064A, 1}, + {0x0671, 0x06B7, 1}, + {0x06BA, 0x06BE, 1}, + {0x06C0, 0x06CE, 1}, + {0x06D0, 0x06D3, 1}, + {0x06D5, 0x06D5, 1}, + {0x06E5, 0x06E6, 1}, + {0x0905, 0x0939, 1}, + {0x093D, 0x093D, 1}, + {0x0958, 0x0961, 1}, + {0x0985, 0x098C, 1}, + {0x098F, 0x0990, 1}, + {0x0993, 0x09A8, 1}, + {0x09AA, 0x09B0, 1}, + {0x09B2, 0x09B2, 1}, + {0x09B6, 0x09B9, 1}, + {0x09DC, 0x09DD, 1}, + {0x09DF, 0x09E1, 1}, + {0x09F0, 0x09F1, 1}, + {0x0A05, 0x0A0A, 1}, + {0x0A0F, 0x0A10, 1}, + {0x0A13, 0x0A28, 1}, + {0x0A2A, 0x0A30, 1}, + {0x0A32, 0x0A33, 1}, + {0x0A35, 0x0A36, 1}, + {0x0A38, 0x0A39, 1}, + {0x0A59, 0x0A5C, 1}, + {0x0A5E, 0x0A5E, 1}, + {0x0A72, 0x0A74, 1}, + {0x0A85, 0x0A8B, 1}, + {0x0A8D, 0x0A8D, 1}, + {0x0A8F, 0x0A91, 1}, + {0x0A93, 0x0AA8, 1}, + {0x0AAA, 0x0AB0, 1}, + {0x0AB2, 0x0AB3, 1}, + {0x0AB5, 0x0AB9, 1}, + {0x0ABD, 0x0AE0, 0x23}, + {0x0B05, 0x0B0C, 1}, + {0x0B0F, 0x0B10, 1}, + {0x0B13, 0x0B28, 1}, + {0x0B2A, 0x0B30, 1}, + {0x0B32, 0x0B33, 1}, + {0x0B36, 0x0B39, 1}, + {0x0B3D, 0x0B3D, 1}, + {0x0B5C, 0x0B5D, 1}, + {0x0B5F, 0x0B61, 1}, + {0x0B85, 0x0B8A, 1}, + {0x0B8E, 0x0B90, 1}, + {0x0B92, 0x0B95, 1}, + {0x0B99, 0x0B9A, 1}, + {0x0B9C, 0x0B9C, 1}, + {0x0B9E, 0x0B9F, 1}, + {0x0BA3, 0x0BA4, 1}, + {0x0BA8, 0x0BAA, 1}, + {0x0BAE, 0x0BB5, 1}, + {0x0BB7, 0x0BB9, 1}, + {0x0C05, 0x0C0C, 1}, + {0x0C0E, 0x0C10, 1}, + {0x0C12, 0x0C28, 1}, + {0x0C2A, 0x0C33, 1}, + {0x0C35, 0x0C39, 1}, + {0x0C60, 0x0C61, 1}, + {0x0C85, 0x0C8C, 1}, + {0x0C8E, 0x0C90, 1}, + {0x0C92, 0x0CA8, 1}, + {0x0CAA, 0x0CB3, 1}, + {0x0CB5, 0x0CB9, 1}, + {0x0CDE, 0x0CDE, 1}, + {0x0CE0, 0x0CE1, 1}, + {0x0D05, 0x0D0C, 1}, + {0x0D0E, 0x0D10, 1}, + {0x0D12, 0x0D28, 1}, + {0x0D2A, 0x0D39, 1}, + {0x0D60, 0x0D61, 1}, + {0x0E01, 0x0E2E, 1}, + {0x0E30, 0x0E30, 1}, + {0x0E32, 0x0E33, 1}, + {0x0E40, 0x0E45, 1}, + {0x0E81, 0x0E82, 1}, + {0x0E84, 0x0E84, 1}, + {0x0E87, 0x0E88, 1}, + {0x0E8A, 0x0E8D, 3}, + {0x0E94, 0x0E97, 1}, + {0x0E99, 0x0E9F, 1}, + {0x0EA1, 0x0EA3, 1}, + {0x0EA5, 0x0EA7, 2}, + {0x0EAA, 0x0EAB, 1}, + {0x0EAD, 0x0EAE, 1}, + {0x0EB0, 0x0EB0, 1}, + {0x0EB2, 0x0EB3, 1}, + {0x0EBD, 0x0EBD, 1}, + {0x0EC0, 0x0EC4, 1}, + {0x0F40, 0x0F47, 1}, + {0x0F49, 0x0F69, 1}, + {0x10A0, 0x10C5, 1}, + {0x10D0, 0x10F6, 1}, + {0x1100, 0x1100, 1}, + {0x1102, 0x1103, 1}, + {0x1105, 0x1107, 1}, + {0x1109, 0x1109, 1}, + {0x110B, 0x110C, 1}, + {0x110E, 0x1112, 1}, + {0x113C, 0x1140, 2}, + {0x114C, 0x1150, 2}, + {0x1154, 0x1155, 1}, + {0x1159, 0x1159, 1}, + {0x115F, 0x1161, 1}, + {0x1163, 0x1169, 2}, + {0x116D, 0x116E, 1}, + {0x1172, 0x1173, 1}, + {0x1175, 0x119E, 0x119E - 0x1175}, + {0x11A8, 0x11AB, 0x11AB - 0x11A8}, + {0x11AE, 0x11AF, 1}, + {0x11B7, 0x11B8, 1}, + {0x11BA, 0x11BA, 1}, + {0x11BC, 0x11C2, 1}, + {0x11EB, 0x11F0, 0x11F0 - 0x11EB}, + {0x11F9, 0x11F9, 1}, + {0x1E00, 0x1E9B, 1}, + {0x1EA0, 0x1EF9, 1}, + {0x1F00, 0x1F15, 1}, + {0x1F18, 0x1F1D, 1}, + {0x1F20, 0x1F45, 1}, + {0x1F48, 0x1F4D, 1}, + {0x1F50, 0x1F57, 1}, + {0x1F59, 0x1F5B, 0x1F5B - 0x1F59}, + {0x1F5D, 0x1F5D, 1}, + {0x1F5F, 0x1F7D, 1}, + {0x1F80, 0x1FB4, 1}, + {0x1FB6, 0x1FBC, 1}, + {0x1FBE, 0x1FBE, 1}, + {0x1FC2, 0x1FC4, 1}, + {0x1FC6, 0x1FCC, 1}, + {0x1FD0, 0x1FD3, 1}, + {0x1FD6, 0x1FDB, 1}, + {0x1FE0, 0x1FEC, 1}, + {0x1FF2, 0x1FF4, 1}, + {0x1FF6, 0x1FFC, 1}, + {0x2126, 0x2126, 1}, + {0x212A, 0x212B, 1}, + {0x212E, 0x212E, 1}, + {0x2180, 0x2182, 1}, + {0x3007, 0x3007, 1}, + {0x3021, 0x3029, 1}, + {0x3041, 0x3094, 1}, + {0x30A1, 0x30FA, 1}, + {0x3105, 0x312C, 1}, + {0x4E00, 0x9FA5, 1}, + {0xAC00, 0xD7A3, 1}, + }, +} + +var second = &unicode.RangeTable{ + R16: []unicode.Range16{ + {0x002D, 0x002E, 1}, + {0x0030, 0x0039, 1}, + {0x00B7, 0x00B7, 1}, + {0x02D0, 0x02D1, 1}, + {0x0300, 0x0345, 1}, + {0x0360, 0x0361, 1}, + {0x0387, 0x0387, 1}, + {0x0483, 0x0486, 1}, + {0x0591, 0x05A1, 1}, + {0x05A3, 0x05B9, 1}, + {0x05BB, 0x05BD, 1}, + {0x05BF, 0x05BF, 1}, + {0x05C1, 0x05C2, 1}, + {0x05C4, 0x0640, 0x0640 - 0x05C4}, + {0x064B, 0x0652, 1}, + {0x0660, 0x0669, 1}, + {0x0670, 0x0670, 1}, + {0x06D6, 0x06DC, 1}, + {0x06DD, 0x06DF, 1}, + {0x06E0, 0x06E4, 1}, + {0x06E7, 0x06E8, 1}, + {0x06EA, 0x06ED, 1}, + {0x06F0, 0x06F9, 1}, + {0x0901, 0x0903, 1}, + {0x093C, 0x093C, 1}, + {0x093E, 0x094C, 1}, + {0x094D, 0x094D, 1}, + {0x0951, 0x0954, 1}, + {0x0962, 0x0963, 1}, + {0x0966, 0x096F, 1}, + {0x0981, 0x0983, 1}, + {0x09BC, 0x09BC, 1}, + {0x09BE, 0x09BF, 1}, + {0x09C0, 0x09C4, 1}, + {0x09C7, 0x09C8, 1}, + {0x09CB, 0x09CD, 1}, + {0x09D7, 0x09D7, 1}, + {0x09E2, 0x09E3, 1}, + {0x09E6, 0x09EF, 1}, + {0x0A02, 0x0A3C, 0x3A}, + {0x0A3E, 0x0A3F, 1}, + {0x0A40, 0x0A42, 1}, + {0x0A47, 0x0A48, 1}, + {0x0A4B, 0x0A4D, 1}, + {0x0A66, 0x0A6F, 1}, + {0x0A70, 0x0A71, 1}, + {0x0A81, 0x0A83, 1}, + {0x0ABC, 0x0ABC, 1}, + {0x0ABE, 0x0AC5, 1}, + {0x0AC7, 0x0AC9, 1}, + {0x0ACB, 0x0ACD, 1}, + {0x0AE6, 0x0AEF, 1}, + {0x0B01, 0x0B03, 1}, + {0x0B3C, 0x0B3C, 1}, + {0x0B3E, 0x0B43, 1}, + {0x0B47, 0x0B48, 1}, + {0x0B4B, 0x0B4D, 1}, + {0x0B56, 0x0B57, 1}, + {0x0B66, 0x0B6F, 1}, + {0x0B82, 0x0B83, 1}, + {0x0BBE, 0x0BC2, 1}, + {0x0BC6, 0x0BC8, 1}, + {0x0BCA, 0x0BCD, 1}, + {0x0BD7, 0x0BD7, 1}, + {0x0BE7, 0x0BEF, 1}, + {0x0C01, 0x0C03, 1}, + {0x0C3E, 0x0C44, 1}, + {0x0C46, 0x0C48, 1}, + {0x0C4A, 0x0C4D, 1}, + {0x0C55, 0x0C56, 1}, + {0x0C66, 0x0C6F, 1}, + {0x0C82, 0x0C83, 1}, + {0x0CBE, 0x0CC4, 1}, + {0x0CC6, 0x0CC8, 1}, + {0x0CCA, 0x0CCD, 1}, + {0x0CD5, 0x0CD6, 1}, + {0x0CE6, 0x0CEF, 1}, + {0x0D02, 0x0D03, 1}, + {0x0D3E, 0x0D43, 1}, + {0x0D46, 0x0D48, 1}, + {0x0D4A, 0x0D4D, 1}, + {0x0D57, 0x0D57, 1}, + {0x0D66, 0x0D6F, 1}, + {0x0E31, 0x0E31, 1}, + {0x0E34, 0x0E3A, 1}, + {0x0E46, 0x0E46, 1}, + {0x0E47, 0x0E4E, 1}, + {0x0E50, 0x0E59, 1}, + {0x0EB1, 0x0EB1, 1}, + {0x0EB4, 0x0EB9, 1}, + {0x0EBB, 0x0EBC, 1}, + {0x0EC6, 0x0EC6, 1}, + {0x0EC8, 0x0ECD, 1}, + {0x0ED0, 0x0ED9, 1}, + {0x0F18, 0x0F19, 1}, + {0x0F20, 0x0F29, 1}, + {0x0F35, 0x0F39, 2}, + {0x0F3E, 0x0F3F, 1}, + {0x0F71, 0x0F84, 1}, + {0x0F86, 0x0F8B, 1}, + {0x0F90, 0x0F95, 1}, + {0x0F97, 0x0F97, 1}, + {0x0F99, 0x0FAD, 1}, + {0x0FB1, 0x0FB7, 1}, + {0x0FB9, 0x0FB9, 1}, + {0x20D0, 0x20DC, 1}, + {0x20E1, 0x3005, 0x3005 - 0x20E1}, + {0x302A, 0x302F, 1}, + {0x3031, 0x3035, 1}, + {0x3099, 0x309A, 1}, + {0x309D, 0x309E, 1}, + {0x30FC, 0x30FE, 1}, + }, +} + +// HTMLEntity is an entity map containing translations for the +// standard HTML entity characters. +var HTMLEntity = htmlEntity + +var htmlEntity = map[string]string{ + /* + hget http://www.w3.org/TR/html4/sgml/entities.html | + ssam ' + ,y /\>/ x/\<(.|\n)+/ s/\n/ /g + ,x v/^\<!ENTITY/d + ,s/\<!ENTITY ([^ ]+) .*U\+([0-9A-F][0-9A-F][0-9A-F][0-9A-F]) .+/ "\1": "\\u\2",/g + ' + */ + "nbsp": "\u00A0", + "iexcl": "\u00A1", + "cent": "\u00A2", + "pound": "\u00A3", + "curren": "\u00A4", + "yen": "\u00A5", + "brvbar": "\u00A6", + "sect": "\u00A7", + "uml": "\u00A8", + "copy": "\u00A9", + "ordf": "\u00AA", + "laquo": "\u00AB", + "not": "\u00AC", + "shy": "\u00AD", + "reg": "\u00AE", + "macr": "\u00AF", + "deg": "\u00B0", + "plusmn": "\u00B1", + "sup2": "\u00B2", + "sup3": "\u00B3", + "acute": "\u00B4", + "micro": "\u00B5", + "para": "\u00B6", + "middot": "\u00B7", + "cedil": "\u00B8", + "sup1": "\u00B9", + "ordm": "\u00BA", + "raquo": "\u00BB", + "frac14": "\u00BC", + "frac12": "\u00BD", + "frac34": "\u00BE", + "iquest": "\u00BF", + "Agrave": "\u00C0", + "Aacute": "\u00C1", + "Acirc": "\u00C2", + "Atilde": "\u00C3", + "Auml": "\u00C4", + "Aring": "\u00C5", + "AElig": "\u00C6", + "Ccedil": "\u00C7", + "Egrave": "\u00C8", + "Eacute": "\u00C9", + "Ecirc": "\u00CA", + "Euml": "\u00CB", + "Igrave": "\u00CC", + "Iacute": "\u00CD", + "Icirc": "\u00CE", + "Iuml": "\u00CF", + "ETH": "\u00D0", + "Ntilde": "\u00D1", + "Ograve": "\u00D2", + "Oacute": "\u00D3", + "Ocirc": "\u00D4", + "Otilde": "\u00D5", + "Ouml": "\u00D6", + "times": "\u00D7", + "Oslash": "\u00D8", + "Ugrave": "\u00D9", + "Uacute": "\u00DA", + "Ucirc": "\u00DB", + "Uuml": "\u00DC", + "Yacute": "\u00DD", + "THORN": "\u00DE", + "szlig": "\u00DF", + "agrave": "\u00E0", + "aacute": "\u00E1", + "acirc": "\u00E2", + "atilde": "\u00E3", + "auml": "\u00E4", + "aring": "\u00E5", + "aelig": "\u00E6", + "ccedil": "\u00E7", + "egrave": "\u00E8", + "eacute": "\u00E9", + "ecirc": "\u00EA", + "euml": "\u00EB", + "igrave": "\u00EC", + "iacute": "\u00ED", + "icirc": "\u00EE", + "iuml": "\u00EF", + "eth": "\u00F0", + "ntilde": "\u00F1", + "ograve": "\u00F2", + "oacute": "\u00F3", + "ocirc": "\u00F4", + "otilde": "\u00F5", + "ouml": "\u00F6", + "divide": "\u00F7", + "oslash": "\u00F8", + "ugrave": "\u00F9", + "uacute": "\u00FA", + "ucirc": "\u00FB", + "uuml": "\u00FC", + "yacute": "\u00FD", + "thorn": "\u00FE", + "yuml": "\u00FF", + "fnof": "\u0192", + "Alpha": "\u0391", + "Beta": "\u0392", + "Gamma": "\u0393", + "Delta": "\u0394", + "Epsilon": "\u0395", + "Zeta": "\u0396", + "Eta": "\u0397", + "Theta": "\u0398", + "Iota": "\u0399", + "Kappa": "\u039A", + "Lambda": "\u039B", + "Mu": "\u039C", + "Nu": "\u039D", + "Xi": "\u039E", + "Omicron": "\u039F", + "Pi": "\u03A0", + "Rho": "\u03A1", + "Sigma": "\u03A3", + "Tau": "\u03A4", + "Upsilon": "\u03A5", + "Phi": "\u03A6", + "Chi": "\u03A7", + "Psi": "\u03A8", + "Omega": "\u03A9", + "alpha": "\u03B1", + "beta": "\u03B2", + "gamma": "\u03B3", + "delta": "\u03B4", + "epsilon": "\u03B5", + "zeta": "\u03B6", + "eta": "\u03B7", + "theta": "\u03B8", + "iota": "\u03B9", + "kappa": "\u03BA", + "lambda": "\u03BB", + "mu": "\u03BC", + "nu": "\u03BD", + "xi": "\u03BE", + "omicron": "\u03BF", + "pi": "\u03C0", + "rho": "\u03C1", + "sigmaf": "\u03C2", + "sigma": "\u03C3", + "tau": "\u03C4", + "upsilon": "\u03C5", + "phi": "\u03C6", + "chi": "\u03C7", + "psi": "\u03C8", + "omega": "\u03C9", + "thetasym": "\u03D1", + "upsih": "\u03D2", + "piv": "\u03D6", + "bull": "\u2022", + "hellip": "\u2026", + "prime": "\u2032", + "Prime": "\u2033", + "oline": "\u203E", + "frasl": "\u2044", + "weierp": "\u2118", + "image": "\u2111", + "real": "\u211C", + "trade": "\u2122", + "alefsym": "\u2135", + "larr": "\u2190", + "uarr": "\u2191", + "rarr": "\u2192", + "darr": "\u2193", + "harr": "\u2194", + "crarr": "\u21B5", + "lArr": "\u21D0", + "uArr": "\u21D1", + "rArr": "\u21D2", + "dArr": "\u21D3", + "hArr": "\u21D4", + "forall": "\u2200", + "part": "\u2202", + "exist": "\u2203", + "empty": "\u2205", + "nabla": "\u2207", + "isin": "\u2208", + "notin": "\u2209", + "ni": "\u220B", + "prod": "\u220F", + "sum": "\u2211", + "minus": "\u2212", + "lowast": "\u2217", + "radic": "\u221A", + "prop": "\u221D", + "infin": "\u221E", + "ang": "\u2220", + "and": "\u2227", + "or": "\u2228", + "cap": "\u2229", + "cup": "\u222A", + "int": "\u222B", + "there4": "\u2234", + "sim": "\u223C", + "cong": "\u2245", + "asymp": "\u2248", + "ne": "\u2260", + "equiv": "\u2261", + "le": "\u2264", + "ge": "\u2265", + "sub": "\u2282", + "sup": "\u2283", + "nsub": "\u2284", + "sube": "\u2286", + "supe": "\u2287", + "oplus": "\u2295", + "otimes": "\u2297", + "perp": "\u22A5", + "sdot": "\u22C5", + "lceil": "\u2308", + "rceil": "\u2309", + "lfloor": "\u230A", + "rfloor": "\u230B", + "lang": "\u2329", + "rang": "\u232A", + "loz": "\u25CA", + "spades": "\u2660", + "clubs": "\u2663", + "hearts": "\u2665", + "diams": "\u2666", + "quot": "\u0022", + "amp": "\u0026", + "lt": "\u003C", + "gt": "\u003E", + "OElig": "\u0152", + "oelig": "\u0153", + "Scaron": "\u0160", + "scaron": "\u0161", + "Yuml": "\u0178", + "circ": "\u02C6", + "tilde": "\u02DC", + "ensp": "\u2002", + "emsp": "\u2003", + "thinsp": "\u2009", + "zwnj": "\u200C", + "zwj": "\u200D", + "lrm": "\u200E", + "rlm": "\u200F", + "ndash": "\u2013", + "mdash": "\u2014", + "lsquo": "\u2018", + "rsquo": "\u2019", + "sbquo": "\u201A", + "ldquo": "\u201C", + "rdquo": "\u201D", + "bdquo": "\u201E", + "dagger": "\u2020", + "Dagger": "\u2021", + "permil": "\u2030", + "lsaquo": "\u2039", + "rsaquo": "\u203A", + "euro": "\u20AC", +} + +// HTMLAutoClose is the set of HTML elements that +// should be considered to close automatically. +var HTMLAutoClose = htmlAutoClose + +var htmlAutoClose = []string{ + /* + hget http://www.w3.org/TR/html4/loose.dtd | + 9 sed -n 's/<!ELEMENT ([^ ]*) +- O EMPTY.+/ "\1",/p' | tr A-Z a-z + */ + "basefont", + "br", + "area", + "link", + "img", + "param", + "hr", + "input", + "col", + "frame", + "isindex", + "base", + "meta", +} + +var ( + esc_quot = []byte(""") // shorter than """ + esc_apos = []byte("'") // shorter than "'" + esc_amp = []byte("&") + esc_lt = []byte("<") + esc_gt = []byte(">") + esc_tab = []byte("	") + esc_nl = []byte("
") + esc_cr = []byte("
") + esc_fffd = []byte("\uFFFD") // Unicode replacement character +) + +// EscapeText writes to w the properly escaped XML equivalent +// of the plain text data s. +func EscapeText(w io.Writer, s []byte) error { + var esc []byte + last := 0 + for i := 0; i < len(s); { + r, width := utf8.DecodeRune(s[i:]) + i += width + switch r { + case '"': + esc = esc_quot + case '\'': + esc = esc_apos + case '&': + esc = esc_amp + case '<': + esc = esc_lt + case '>': + esc = esc_gt + case '\t': + esc = esc_tab + case '\n': + esc = esc_nl + case '\r': + esc = esc_cr + default: + if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) { + esc = esc_fffd + break + } + continue + } + if _, err := w.Write(s[last : i-width]); err != nil { + return err + } + if _, err := w.Write(esc); err != nil { + return err + } + last = i + } + if _, err := w.Write(s[last:]); err != nil { + return err + } + return nil +} + +// EscapeString writes to p the properly escaped XML equivalent +// of the plain text data s. +func (p *printer) EscapeString(s string) { + var esc []byte + last := 0 + for i := 0; i < len(s); { + r, width := utf8.DecodeRuneInString(s[i:]) + i += width + switch r { + case '"': + esc = esc_quot + case '\'': + esc = esc_apos + case '&': + esc = esc_amp + case '<': + esc = esc_lt + case '>': + esc = esc_gt + case '\t': + esc = esc_tab + case '\n': + esc = esc_nl + case '\r': + esc = esc_cr + default: + if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) { + esc = esc_fffd + break + } + continue + } + p.WriteString(s[last : i-width]) + p.Write(esc) + last = i + } + p.WriteString(s[last:]) +} + +// Escape is like EscapeText but omits the error return value. +// It is provided for backwards compatibility with Go 1.0. +// Code targeting Go 1.1 or later should use EscapeText. +func Escape(w io.Writer, s []byte) { + EscapeText(w, s) +} + +// procInstEncoding parses the `encoding="..."` or `encoding='...'` +// value out of the provided string, returning "" if not found. +func procInstEncoding(s string) string { + // TODO: this parsing is somewhat lame and not exact. + // It works for all actual cases, though. + idx := strings.Index(s, "encoding=") + if idx == -1 { + return "" + } + v := s[idx+len("encoding="):] + if v == "" { + return "" + } + if v[0] != '\'' && v[0] != '"' { + return "" + } + idx = strings.IndexRune(v[1:], rune(v[0])) + if idx == -1 { + return "" + } + return v[1 : idx+1] +} diff --git a/src/encoding/xml/xml_test.go b/src/encoding/xml/xml_test.go new file mode 100644 index 000000000..be995c0d5 --- /dev/null +++ b/src/encoding/xml/xml_test.go @@ -0,0 +1,749 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import ( + "bytes" + "fmt" + "io" + "reflect" + "strings" + "testing" + "unicode/utf8" +) + +const testInput = ` +<?xml version="1.0" encoding="UTF-8"?> +<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> +<body xmlns:foo="ns1" xmlns="ns2" xmlns:tag="ns3" ` + + "\r\n\t" + ` > + <hello lang="en">World <>'" 白鵬翔</hello> + <query>&何; &is-it;</query> + <goodbye /> + <outer foo:attr="value" xmlns:tag="ns4"> + <inner/> + </outer> + <tag:name> + <![CDATA[Some text here.]]> + </tag:name> +</body><!-- missing final newline -->` + +var testEntity = map[string]string{"何": "What", "is-it": "is it?"} + +var rawTokens = []Token{ + CharData("\n"), + ProcInst{"xml", []byte(`version="1.0" encoding="UTF-8"`)}, + CharData("\n"), + Directive(`DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"`), + CharData("\n"), + StartElement{Name{"", "body"}, []Attr{{Name{"xmlns", "foo"}, "ns1"}, {Name{"", "xmlns"}, "ns2"}, {Name{"xmlns", "tag"}, "ns3"}}}, + CharData("\n "), + StartElement{Name{"", "hello"}, []Attr{{Name{"", "lang"}, "en"}}}, + CharData("World <>'\" 白鵬翔"), + EndElement{Name{"", "hello"}}, + CharData("\n "), + StartElement{Name{"", "query"}, []Attr{}}, + CharData("What is it?"), + EndElement{Name{"", "query"}}, + CharData("\n "), + StartElement{Name{"", "goodbye"}, []Attr{}}, + EndElement{Name{"", "goodbye"}}, + CharData("\n "), + StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}}, + CharData("\n "), + StartElement{Name{"", "inner"}, []Attr{}}, + EndElement{Name{"", "inner"}}, + CharData("\n "), + EndElement{Name{"", "outer"}}, + CharData("\n "), + StartElement{Name{"tag", "name"}, []Attr{}}, + CharData("\n "), + CharData("Some text here."), + CharData("\n "), + EndElement{Name{"tag", "name"}}, + CharData("\n"), + EndElement{Name{"", "body"}}, + Comment(" missing final newline "), +} + +var cookedTokens = []Token{ + CharData("\n"), + ProcInst{"xml", []byte(`version="1.0" encoding="UTF-8"`)}, + CharData("\n"), + Directive(`DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"`), + CharData("\n"), + StartElement{Name{"ns2", "body"}, []Attr{{Name{"xmlns", "foo"}, "ns1"}, {Name{"", "xmlns"}, "ns2"}, {Name{"xmlns", "tag"}, "ns3"}}}, + CharData("\n "), + StartElement{Name{"ns2", "hello"}, []Attr{{Name{"", "lang"}, "en"}}}, + CharData("World <>'\" 白鵬翔"), + EndElement{Name{"ns2", "hello"}}, + CharData("\n "), + StartElement{Name{"ns2", "query"}, []Attr{}}, + CharData("What is it?"), + EndElement{Name{"ns2", "query"}}, + CharData("\n "), + StartElement{Name{"ns2", "goodbye"}, []Attr{}}, + EndElement{Name{"ns2", "goodbye"}}, + CharData("\n "), + StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}}, + CharData("\n "), + StartElement{Name{"ns2", "inner"}, []Attr{}}, + EndElement{Name{"ns2", "inner"}}, + CharData("\n "), + EndElement{Name{"ns2", "outer"}}, + CharData("\n "), + StartElement{Name{"ns3", "name"}, []Attr{}}, + CharData("\n "), + CharData("Some text here."), + CharData("\n "), + EndElement{Name{"ns3", "name"}}, + CharData("\n"), + EndElement{Name{"ns2", "body"}}, + Comment(" missing final newline "), +} + +const testInputAltEncoding = ` +<?xml version="1.0" encoding="x-testing-uppercase"?> +<TAG>VALUE</TAG>` + +var rawTokensAltEncoding = []Token{ + CharData("\n"), + ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("value"), + EndElement{Name{"", "tag"}}, +} + +var xmlInput = []string{ + // unexpected EOF cases + "<", + "<t", + "<t ", + "<t/", + "<!", + "<!-", + "<!--", + "<!--c-", + "<!--c--", + "<!d", + "<t></", + "<t></t", + "<?", + "<?p", + "<t a", + "<t a=", + "<t a='", + "<t a=''", + "<t/><![", + "<t/><![C", + "<t/><![CDATA[d", + "<t/><![CDATA[d]", + "<t/><![CDATA[d]]", + + // other Syntax errors + "<>", + "<t/a", + "<0 />", + "<?0 >", + // "<!0 >", // let the Token() caller handle + "</0>", + "<t 0=''>", + "<t a='&'>", + "<t a='<'>", + "<t> c;</t>", + "<t a>", + "<t a=>", + "<t a=v>", + // "<![CDATA[d]]>", // let the Token() caller handle + "<t></e>", + "<t></>", + "<t></t!", + "<t>cdata]]></t>", +} + +func TestRawToken(t *testing.T) { + d := NewDecoder(strings.NewReader(testInput)) + d.Entity = testEntity + testRawToken(t, d, testInput, rawTokens) +} + +const nonStrictInput = ` +<tag>non&entity</tag> +<tag>&unknown;entity</tag> +<tag>{</tag> +<tag>&#zzz;</tag> +<tag>&なまえ3;</tag> +<tag><-gt;</tag> +<tag>&;</tag> +<tag>&0a;</tag> +` + +var nonStringEntity = map[string]string{"": "oops!", "0a": "oops!"} + +var nonStrictTokens = []Token{ + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("non&entity"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&unknown;entity"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("{"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&#zzz;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&なまえ3;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("<-gt;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&0a;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), +} + +func TestNonStrictRawToken(t *testing.T) { + d := NewDecoder(strings.NewReader(nonStrictInput)) + d.Strict = false + testRawToken(t, d, nonStrictInput, nonStrictTokens) +} + +type downCaser struct { + t *testing.T + r io.ByteReader +} + +func (d *downCaser) ReadByte() (c byte, err error) { + c, err = d.r.ReadByte() + if c >= 'A' && c <= 'Z' { + c += 'a' - 'A' + } + return +} + +func (d *downCaser) Read(p []byte) (int, error) { + d.t.Fatalf("unexpected Read call on downCaser reader") + panic("unreachable") +} + +func TestRawTokenAltEncoding(t *testing.T) { + d := NewDecoder(strings.NewReader(testInputAltEncoding)) + d.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) { + if charset != "x-testing-uppercase" { + t.Fatalf("unexpected charset %q", charset) + } + return &downCaser{t, input.(io.ByteReader)}, nil + } + testRawToken(t, d, testInputAltEncoding, rawTokensAltEncoding) +} + +func TestRawTokenAltEncodingNoConverter(t *testing.T) { + d := NewDecoder(strings.NewReader(testInputAltEncoding)) + token, err := d.RawToken() + if token == nil { + t.Fatalf("expected a token on first RawToken call") + } + if err != nil { + t.Fatal(err) + } + token, err = d.RawToken() + if token != nil { + t.Errorf("expected a nil token; got %#v", token) + } + if err == nil { + t.Fatalf("expected an error on second RawToken call") + } + const encoding = "x-testing-uppercase" + if !strings.Contains(err.Error(), encoding) { + t.Errorf("expected error to contain %q; got error: %v", + encoding, err) + } +} + +func testRawToken(t *testing.T, d *Decoder, raw string, rawTokens []Token) { + lastEnd := int64(0) + for i, want := range rawTokens { + start := d.InputOffset() + have, err := d.RawToken() + end := d.InputOffset() + if err != nil { + t.Fatalf("token %d: unexpected error: %s", i, err) + } + if !reflect.DeepEqual(have, want) { + var shave, swant string + if _, ok := have.(CharData); ok { + shave = fmt.Sprintf("CharData(%q)", have) + } else { + shave = fmt.Sprintf("%#v", have) + } + if _, ok := want.(CharData); ok { + swant = fmt.Sprintf("CharData(%q)", want) + } else { + swant = fmt.Sprintf("%#v", want) + } + t.Errorf("token %d = %s, want %s", i, shave, swant) + } + + // Check that InputOffset returned actual token. + switch { + case start < lastEnd: + t.Errorf("token %d: position [%d,%d) for %T is before previous token", i, start, end, have) + case start >= end: + // Special case: EndElement can be synthesized. + if start == end && end == lastEnd { + break + } + t.Errorf("token %d: position [%d,%d) for %T is empty", i, start, end, have) + case end > int64(len(raw)): + t.Errorf("token %d: position [%d,%d) for %T extends beyond input", i, start, end, have) + default: + text := raw[start:end] + if strings.ContainsAny(text, "<>") && (!strings.HasPrefix(text, "<") || !strings.HasSuffix(text, ">")) { + t.Errorf("token %d: misaligned raw token %#q for %T", i, text, have) + } + } + lastEnd = end + } +} + +// Ensure that directives (specifically !DOCTYPE) include the complete +// text of any nested directives, noting that < and > do not change +// nesting depth if they are in single or double quotes. + +var nestedDirectivesInput = ` +<!DOCTYPE [<!ENTITY rdf "http://www.w3.org/1999/02/22-rdf-syntax-ns#">]> +<!DOCTYPE [<!ENTITY xlt ">">]> +<!DOCTYPE [<!ENTITY xlt "<">]> +<!DOCTYPE [<!ENTITY xlt '>'>]> +<!DOCTYPE [<!ENTITY xlt '<'>]> +<!DOCTYPE [<!ENTITY xlt '">'>]> +<!DOCTYPE [<!ENTITY xlt "'<">]> +` + +var nestedDirectivesTokens = []Token{ + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY rdf "http://www.w3.org/1999/02/22-rdf-syntax-ns#">]`), + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY xlt ">">]`), + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY xlt "<">]`), + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY xlt '>'>]`), + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY xlt '<'>]`), + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY xlt '">'>]`), + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY xlt "'<">]`), + CharData("\n"), +} + +func TestNestedDirectives(t *testing.T) { + d := NewDecoder(strings.NewReader(nestedDirectivesInput)) + + for i, want := range nestedDirectivesTokens { + have, err := d.Token() + if err != nil { + t.Fatalf("token %d: unexpected error: %s", i, err) + } + if !reflect.DeepEqual(have, want) { + t.Errorf("token %d = %#v want %#v", i, have, want) + } + } +} + +func TestToken(t *testing.T) { + d := NewDecoder(strings.NewReader(testInput)) + d.Entity = testEntity + + for i, want := range cookedTokens { + have, err := d.Token() + if err != nil { + t.Fatalf("token %d: unexpected error: %s", i, err) + } + if !reflect.DeepEqual(have, want) { + t.Errorf("token %d = %#v want %#v", i, have, want) + } + } +} + +func TestSyntax(t *testing.T) { + for i := range xmlInput { + d := NewDecoder(strings.NewReader(xmlInput[i])) + var err error + for _, err = d.Token(); err == nil; _, err = d.Token() { + } + if _, ok := err.(*SyntaxError); !ok { + t.Fatalf(`xmlInput "%s": expected SyntaxError not received`, xmlInput[i]) + } + } +} + +type allScalars struct { + True1 bool + True2 bool + False1 bool + False2 bool + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint int + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Uintptr uintptr + Float32 float32 + Float64 float64 + String string + PtrString *string +} + +var all = allScalars{ + True1: true, + True2: true, + False1: false, + False2: false, + Int: 1, + Int8: -2, + Int16: 3, + Int32: -4, + Int64: 5, + Uint: 6, + Uint8: 7, + Uint16: 8, + Uint32: 9, + Uint64: 10, + Uintptr: 11, + Float32: 13.0, + Float64: 14.0, + String: "15", + PtrString: &sixteen, +} + +var sixteen = "16" + +const testScalarsInput = `<allscalars> + <True1>true</True1> + <True2>1</True2> + <False1>false</False1> + <False2>0</False2> + <Int>1</Int> + <Int8>-2</Int8> + <Int16>3</Int16> + <Int32>-4</Int32> + <Int64>5</Int64> + <Uint>6</Uint> + <Uint8>7</Uint8> + <Uint16>8</Uint16> + <Uint32>9</Uint32> + <Uint64>10</Uint64> + <Uintptr>11</Uintptr> + <Float>12.0</Float> + <Float32>13.0</Float32> + <Float64>14.0</Float64> + <String>15</String> + <PtrString>16</PtrString> +</allscalars>` + +func TestAllScalars(t *testing.T) { + var a allScalars + err := Unmarshal([]byte(testScalarsInput), &a) + + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(a, all) { + t.Errorf("have %+v want %+v", a, all) + } +} + +type item struct { + Field_a string +} + +func TestIssue569(t *testing.T) { + data := `<item><Field_a>abcd</Field_a></item>` + var i item + err := Unmarshal([]byte(data), &i) + + if err != nil || i.Field_a != "abcd" { + t.Fatal("Expecting abcd") + } +} + +func TestUnquotedAttrs(t *testing.T) { + data := "<tag attr=azAZ09:-_\t>" + d := NewDecoder(strings.NewReader(data)) + d.Strict = false + token, err := d.Token() + if _, ok := err.(*SyntaxError); ok { + t.Errorf("Unexpected error: %v", err) + } + if token.(StartElement).Name.Local != "tag" { + t.Errorf("Unexpected tag name: %v", token.(StartElement).Name.Local) + } + attr := token.(StartElement).Attr[0] + if attr.Value != "azAZ09:-_" { + t.Errorf("Unexpected attribute value: %v", attr.Value) + } + if attr.Name.Local != "attr" { + t.Errorf("Unexpected attribute name: %v", attr.Name.Local) + } +} + +func TestValuelessAttrs(t *testing.T) { + tests := [][3]string{ + {"<p nowrap>", "p", "nowrap"}, + {"<p nowrap >", "p", "nowrap"}, + {"<input checked/>", "input", "checked"}, + {"<input checked />", "input", "checked"}, + } + for _, test := range tests { + d := NewDecoder(strings.NewReader(test[0])) + d.Strict = false + token, err := d.Token() + if _, ok := err.(*SyntaxError); ok { + t.Errorf("Unexpected error: %v", err) + } + if token.(StartElement).Name.Local != test[1] { + t.Errorf("Unexpected tag name: %v", token.(StartElement).Name.Local) + } + attr := token.(StartElement).Attr[0] + if attr.Value != test[2] { + t.Errorf("Unexpected attribute value: %v", attr.Value) + } + if attr.Name.Local != test[2] { + t.Errorf("Unexpected attribute name: %v", attr.Name.Local) + } + } +} + +func TestCopyTokenCharData(t *testing.T) { + data := []byte("same data") + var tok1 Token = CharData(data) + tok2 := CopyToken(tok1) + if !reflect.DeepEqual(tok1, tok2) { + t.Error("CopyToken(CharData) != CharData") + } + data[1] = 'o' + if reflect.DeepEqual(tok1, tok2) { + t.Error("CopyToken(CharData) uses same buffer.") + } +} + +func TestCopyTokenStartElement(t *testing.T) { + elt := StartElement{Name{"", "hello"}, []Attr{{Name{"", "lang"}, "en"}}} + var tok1 Token = elt + tok2 := CopyToken(tok1) + if tok1.(StartElement).Attr[0].Value != "en" { + t.Error("CopyToken overwrote Attr[0]") + } + if !reflect.DeepEqual(tok1, tok2) { + t.Error("CopyToken(StartElement) != StartElement") + } + tok1.(StartElement).Attr[0] = Attr{Name{"", "lang"}, "de"} + if reflect.DeepEqual(tok1, tok2) { + t.Error("CopyToken(CharData) uses same buffer.") + } +} + +func TestSyntaxErrorLineNum(t *testing.T) { + testInput := "<P>Foo<P>\n\n<P>Bar</>\n" + d := NewDecoder(strings.NewReader(testInput)) + var err error + for _, err = d.Token(); err == nil; _, err = d.Token() { + } + synerr, ok := err.(*SyntaxError) + if !ok { + t.Error("Expected SyntaxError.") + } + if synerr.Line != 3 { + t.Error("SyntaxError didn't have correct line number.") + } +} + +func TestTrailingRawToken(t *testing.T) { + input := `<FOO></FOO> ` + d := NewDecoder(strings.NewReader(input)) + var err error + for _, err = d.RawToken(); err == nil; _, err = d.RawToken() { + } + if err != io.EOF { + t.Fatalf("d.RawToken() = _, %v, want _, io.EOF", err) + } +} + +func TestTrailingToken(t *testing.T) { + input := `<FOO></FOO> ` + d := NewDecoder(strings.NewReader(input)) + var err error + for _, err = d.Token(); err == nil; _, err = d.Token() { + } + if err != io.EOF { + t.Fatalf("d.Token() = _, %v, want _, io.EOF", err) + } +} + +func TestEntityInsideCDATA(t *testing.T) { + input := `<test><![CDATA[ &val=foo ]]></test>` + d := NewDecoder(strings.NewReader(input)) + var err error + for _, err = d.Token(); err == nil; _, err = d.Token() { + } + if err != io.EOF { + t.Fatalf("d.Token() = _, %v, want _, io.EOF", err) + } +} + +var characterTests = []struct { + in string + err string +}{ + {"\x12<doc/>", "illegal character code U+0012"}, + {"<?xml version=\"1.0\"?>\x0b<doc/>", "illegal character code U+000B"}, + {"\xef\xbf\xbe<doc/>", "illegal character code U+FFFE"}, + {"<?xml version=\"1.0\"?><doc>\r\n<hiya/>\x07<toots/></doc>", "illegal character code U+0007"}, + {"<?xml version=\"1.0\"?><doc \x12='value'>what's up</doc>", "expected attribute name in element"}, + {"<doc>&abc\x01;</doc>", "invalid character entity &abc (no semicolon)"}, + {"<doc>&\x01;</doc>", "invalid character entity & (no semicolon)"}, + {"<doc>&\xef\xbf\xbe;</doc>", "invalid character entity &\uFFFE;"}, + {"<doc>&hello;</doc>", "invalid character entity &hello;"}, +} + +func TestDisallowedCharacters(t *testing.T) { + + for i, tt := range characterTests { + d := NewDecoder(strings.NewReader(tt.in)) + var err error + + for err == nil { + _, err = d.Token() + } + synerr, ok := err.(*SyntaxError) + if !ok { + t.Fatalf("input %d d.Token() = _, %v, want _, *SyntaxError", i, err) + } + if synerr.Msg != tt.err { + t.Fatalf("input %d synerr.Msg wrong: want %q, got %q", i, tt.err, synerr.Msg) + } + } +} + +type procInstEncodingTest struct { + expect, got string +} + +var procInstTests = []struct { + input, expect string +}{ + {`version="1.0" encoding="utf-8"`, "utf-8"}, + {`version="1.0" encoding='utf-8'`, "utf-8"}, + {`version="1.0" encoding='utf-8' `, "utf-8"}, + {`version="1.0" encoding=utf-8`, ""}, + {`encoding="FOO" `, "FOO"}, +} + +func TestProcInstEncoding(t *testing.T) { + for _, test := range procInstTests { + got := procInstEncoding(test.input) + if got != test.expect { + t.Errorf("procInstEncoding(%q) = %q; want %q", test.input, got, test.expect) + } + } +} + +// Ensure that directives with comments include the complete +// text of any nested directives. + +var directivesWithCommentsInput = ` +<!DOCTYPE [<!-- a comment --><!ENTITY rdf "http://www.w3.org/1999/02/22-rdf-syntax-ns#">]> +<!DOCTYPE [<!ENTITY go "Golang"><!-- a comment-->]> +<!DOCTYPE <!-> <!> <!----> <!-->--> <!--->--> [<!ENTITY go "Golang"><!-- a comment-->]> +` + +var directivesWithCommentsTokens = []Token{ + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY rdf "http://www.w3.org/1999/02/22-rdf-syntax-ns#">]`), + CharData("\n"), + Directive(`DOCTYPE [<!ENTITY go "Golang">]`), + CharData("\n"), + Directive(`DOCTYPE <!-> <!> [<!ENTITY go "Golang">]`), + CharData("\n"), +} + +func TestDirectivesWithComments(t *testing.T) { + d := NewDecoder(strings.NewReader(directivesWithCommentsInput)) + + for i, want := range directivesWithCommentsTokens { + have, err := d.Token() + if err != nil { + t.Fatalf("token %d: unexpected error: %s", i, err) + } + if !reflect.DeepEqual(have, want) { + t.Errorf("token %d = %#v want %#v", i, have, want) + } + } +} + +// Writer whose Write method always returns an error. +type errWriter struct{} + +func (errWriter) Write(p []byte) (n int, err error) { return 0, fmt.Errorf("unwritable") } + +func TestEscapeTextIOErrors(t *testing.T) { + expectErr := "unwritable" + err := EscapeText(errWriter{}, []byte{'A'}) + + if err == nil || err.Error() != expectErr { + t.Errorf("have %v, want %v", err, expectErr) + } +} + +func TestEscapeTextInvalidChar(t *testing.T) { + input := []byte("A \x00 terminated string.") + expected := "A \uFFFD terminated string." + + buff := new(bytes.Buffer) + if err := EscapeText(buff, input); err != nil { + t.Fatalf("have %v, want nil", err) + } + text := buff.String() + + if text != expected { + t.Errorf("have %v, want %v", text, expected) + } +} + +func TestIssue5880(t *testing.T) { + type T []byte + data, err := Marshal(T{192, 168, 0, 1}) + if err != nil { + t.Errorf("Marshal error: %v", err) + } + if !utf8.Valid(data) { + t.Errorf("Marshal generated invalid UTF-8: %x", data) + } +} |