diff options
Diffstat (limited to 'src/pkg/gob')
-rw-r--r-- | src/pkg/gob/codec_test.go | 38 | ||||
-rw-r--r-- | src/pkg/gob/debug.go | 35 | ||||
-rw-r--r-- | src/pkg/gob/decode.go | 410 | ||||
-rw-r--r-- | src/pkg/gob/decoder.go | 2 | ||||
-rw-r--r-- | src/pkg/gob/encode.go | 202 | ||||
-rw-r--r-- | src/pkg/gob/encoder.go | 120 | ||||
-rw-r--r-- | src/pkg/gob/encoder_test.go | 18 | ||||
-rw-r--r-- | src/pkg/gob/gobencdec_test.go | 331 | ||||
-rw-r--r-- | src/pkg/gob/type.go | 437 | ||||
-rw-r--r-- | src/pkg/gob/type_test.go | 24 |
10 files changed, 1276 insertions, 341 deletions
diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index fe1f60ba7..4562e1930 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -303,7 +303,7 @@ func TestScalarEncInstructions(t *testing.T) { } } -func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p unsafe.Pointer) { +func execDec(typ string, instr *decInstr, state *decoderState, t *testing.T, p unsafe.Pointer) { defer testError(t) v := int(state.decodeUint()) if v+state.fieldnum != 6 { @@ -313,7 +313,7 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un state.fieldnum = 6 } -func newDecodeStateFromData(data []byte) *decodeState { +func newDecodeStateFromData(data []byte) *decoderState { b := bytes.NewBuffer(data) state := newDecodeState(nil, b) state.fieldnum = -1 @@ -342,7 +342,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a int } - instr := &decInstr{decOpMap[reflect.Int], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Int], 6, 0, 0, ovfl} state := newDecodeStateFromData(signedResult) execDec("int", instr, state, t, unsafe.Pointer(&data)) if data.a != 17 { @@ -355,7 +355,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a uint } - instr := &decInstr{decOpMap[reflect.Uint], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Uint], 6, 0, 0, ovfl} state := newDecodeStateFromData(unsignedResult) execDec("uint", instr, state, t, unsafe.Pointer(&data)) if data.a != 17 { @@ -446,7 +446,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a uintptr } - instr := &decInstr{decOpMap[reflect.Uintptr], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Uintptr], 6, 0, 0, ovfl} state := newDecodeStateFromData(unsignedResult) execDec("uintptr", instr, state, t, unsafe.Pointer(&data)) if data.a != 17 { @@ -511,7 +511,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a complex64 } - instr := &decInstr{decOpMap[reflect.Complex64], 6, 0, 0, 
ovfl} + instr := &decInstr{decOpTable[reflect.Complex64], 6, 0, 0, ovfl} state := newDecodeStateFromData(complexResult) execDec("complex", instr, state, t, unsafe.Pointer(&data)) if data.a != 17+19i { @@ -524,7 +524,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a complex128 } - instr := &decInstr{decOpMap[reflect.Complex128], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Complex128], 6, 0, 0, ovfl} state := newDecodeStateFromData(complexResult) execDec("complex", instr, state, t, unsafe.Pointer(&data)) if data.a != 17+19i { @@ -973,18 +973,32 @@ func TestIgnoredFields(t *testing.T) { } } + +func TestBadRecursiveType(t *testing.T) { + type Rec ***Rec + var rec Rec + b := new(bytes.Buffer) + err := NewEncoder(b).Encode(&rec) + if err == nil { + t.Error("expected error; got none") + } else if strings.Index(err.String(), "recursive") < 0 { + t.Error("expected recursive type error; got", err) + } + // Can't test decode easily because we can't encode one, so we can't pass one to a Decoder. +} + type Bad0 struct { - ch chan int - c float64 + CH chan int + C float64 } -var nilEncoder *Encoder func TestInvalidField(t *testing.T) { var bad0 Bad0 - bad0.ch = make(chan int) + bad0.CH = make(chan int) b := new(bytes.Buffer) - err := nilEncoder.encode(b, reflect.NewValue(&bad0)) + var nilEncoder *Encoder + err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0))) if err == nil { t.Error("expected error; got none") } else if strings.Index(err.String(), "type") < 0 { diff --git a/src/pkg/gob/debug.go b/src/pkg/gob/debug.go index e4583901e..69c83bda7 100644 --- a/src/pkg/gob/debug.go +++ b/src/pkg/gob/debug.go @@ -155,6 +155,16 @@ func (deb *debugger) dump(format string, args ...interface{}) { // Debug prints a human-readable representation of the gob data read from r. 
func Debug(r io.Reader) { + err := debug(r) + if err != nil { + fmt.Fprintf(os.Stderr, "gob debug: %s\n", err) + } +} + +// debug implements Debug, but catches panics and returns +// them as errors to be printed by Debug. +func debug(r io.Reader) (err os.Error) { + defer catchError(&err) fmt.Fprintln(os.Stderr, "Start of debugging") deb := &debugger{ r: newPeekReader(r), @@ -166,6 +176,7 @@ func Debug(r io.Reader) { deb.remainingKnown = true } deb.gobStream() + return } // note that we've consumed some bytes @@ -386,11 +397,15 @@ func (deb *debugger) typeDefinition(indent tab, id typeId) { // Field number 1 is type Id of key deb.delta(1) keyId := deb.typeId() - wire.SliceT = &sliceType{com, id} // Field number 2 is type Id of elem deb.delta(1) elemId := deb.typeId() wire.MapT = &mapType{com, keyId, elemId} + case 4: // GobEncoder type, one field of {{Common}} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + wire.GobEncoderT = &gobEncoderType{com} default: errorf("bad field in type %d", fieldNum) } @@ -507,6 +522,8 @@ func (deb *debugger) printWireType(indent tab, wire *wireType) { for i, field := range wire.StructT.Field { fmt.Fprintf(os.Stderr, "%sfield %d:\t%s\tid=%d\n", indent+1, i, field.Name, field.Id) } + case wire.GobEncoderT != nil: + deb.printCommonType(indent, "GobEncoder", &wire.GobEncoderT.CommonType) } indent-- fmt.Fprintf(os.Stderr, "%s}\n", indent) @@ -538,6 +555,8 @@ func (deb *debugger) fieldValue(indent tab, id typeId) { deb.sliceValue(indent, wire) case wire.StructT != nil: deb.structValue(indent, id) + case wire.GobEncoderT != nil: + deb.gobEncoderValue(indent, id) default: panic("bad wire type for field") } @@ -654,3 +673,17 @@ func (deb *debugger) structValue(indent tab, id typeId) { fmt.Fprintf(os.Stderr, "%s} // end %s struct\n", indent, id.name()) deb.dump(">> End of struct value of type %d %q", id, id.name()) } + +// GobEncoderValue: +// uint(n) byte*n +func (deb *debugger) gobEncoderValue(indent tab, id typeId) { + 
len := deb.uint64() + deb.dump("GobEncoder value of %q id=%d, length %d\n", id.name(), id, len) + fmt.Fprintf(os.Stderr, "%s%s (implements GobEncoder)\n", indent, id.name()) + data := make([]byte, len) + _, err := deb.r.Read(data) + if err != nil { + errorf("gobEncoder data read: %s", err) + } + fmt.Fprintf(os.Stderr, "%s[% .2x]\n", indent+1, data) +} diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 9667f6157..b7ae78200 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -13,9 +13,7 @@ import ( "math" "os" "reflect" - "unicode" "unsafe" - "utf8" ) var ( @@ -24,9 +22,9 @@ var ( errRange = os.ErrorString("gob: internal error: field numbers out of bounds") ) -// The execution state of an instance of the decoder. A new state +// decoderState is the execution state of an instance of the decoder. A new state // is created for nested objects. -type decodeState struct { +type decoderState struct { dec *Decoder // The buffer is stored with an extra indirection because it may be replaced // if we load a type during decode (when reading an interface value). @@ -37,8 +35,8 @@ type decodeState struct { // We pass the bytes.Buffer separately for easier testing of the infrastructure // without requiring a full Decoder. -func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decodeState { - d := new(decodeState) +func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decoderState { + d := new(decoderState) d.dec = dec d.b = buf d.buf = make([]byte, uint64Size) @@ -85,7 +83,7 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err os.Erro // decodeUint reads an encoded unsigned integer from state.r. // Does not check for overflow. -func (state *decodeState) decodeUint() (x uint64) { +func (state *decoderState) decodeUint() (x uint64) { b, err := state.b.ReadByte() if err != nil { error(err) @@ -112,7 +110,7 @@ func (state *decodeState) decodeUint() (x uint64) { // decodeInt reads an encoded signed integer from state.r. 
// Does not check for overflow. -func (state *decodeState) decodeInt() int64 { +func (state *decoderState) decodeInt() int64 { x := state.decodeUint() if x&1 != 0 { return ^int64(x >> 1) @@ -120,7 +118,8 @@ func (state *decodeState) decodeInt() int64 { return int64(x >> 1) } -type decOp func(i *decInstr, state *decodeState, p unsafe.Pointer) +// decOp is the signature of a decoding operator for a given type. +type decOp func(i *decInstr, state *decoderState, p unsafe.Pointer) // The 'instructions' of the decoding machine type decInstr struct { @@ -150,26 +149,31 @@ func decIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { return p } -func ignoreUint(i *decInstr, state *decodeState, p unsafe.Pointer) { +// ignoreUint discards a uint value with no destination. +func ignoreUint(i *decInstr, state *decoderState, p unsafe.Pointer) { state.decodeUint() } -func ignoreTwoUints(i *decInstr, state *decodeState, p unsafe.Pointer) { +// ignoreTwoUints discards two uint values with no destination. It's used to skip +// complex values. +func ignoreTwoUints(i *decInstr, state *decoderState, p unsafe.Pointer) { state.decodeUint() state.decodeUint() } -func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decBool decodes a uint and stores it as a boolean through p. +func decBool(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(bool)) } p = *(*unsafe.Pointer)(p) } - *(*bool)(p) = state.decodeInt() != 0 + *(*bool)(p) = state.decodeUint() != 0 } -func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decInt8 decodes an integer and stores it as an int8 through p. 
+func decInt8(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int8)) @@ -184,7 +188,8 @@ func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decUint8 decodes an unsigned integer and stores it as a uint8 through p. +func decUint8(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint8)) @@ -199,7 +204,8 @@ func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decInt16 decodes an integer and stores it as an int16 through p. +func decInt16(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int16)) @@ -214,7 +220,8 @@ func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decUint16 decodes an unsigned integer and stores it as a uint16 through p. +func decUint16(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint16)) @@ -229,7 +236,8 @@ func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decInt32 decodes an integer and stores it as an int32 through p. 
+func decInt32(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int32)) @@ -244,7 +252,8 @@ func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decUint32 decodes an unsigned integer and stores it as a uint32 through p. +func decUint32(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint32)) @@ -259,7 +268,8 @@ func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decInt64 decodes an integer and stores it as an int64 through p. +func decInt64(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int64)) @@ -269,7 +279,8 @@ func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*int64)(p) = int64(state.decodeInt()) } -func decUint64(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decUint64 decodes an unsigned integer and stores it as a uint64 through p. +func decUint64(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint64)) @@ -294,7 +305,9 @@ func floatFromBits(u uint64) float64 { return math.Float64frombits(v) } -func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { +// storeFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point +// number, and stores it through p. It's a helper function for float32 and complex64. 
+func storeFloat32(i *decInstr, state *decoderState, p unsafe.Pointer) { v := floatFromBits(state.decodeUint()) av := v if av < 0 { @@ -308,7 +321,9 @@ func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point +// number, and stores it through p. +func decFloat32(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32)) @@ -318,7 +333,9 @@ func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { storeFloat32(i, state, p) } -func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decFloat64 decodes an unsigned integer, treats it as a 64-bit floating-point +// number, and stores it through p. +func decFloat64(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float64)) @@ -328,8 +345,10 @@ func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*float64)(p) = floatFromBits(uint64(state.decodeUint())) } -// Complex numbers are just a pair of floating-point numbers, real part first. -func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decComplex64 decodes a pair of unsigned integers, treats them as a +// pair of floating point numbers, and stores them as a complex64 through p. +// The real part comes first. 
+func decComplex64(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex64)) @@ -340,7 +359,10 @@ func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) { storeFloat32(i, state, unsafe.Pointer(uintptr(p)+uintptr(unsafe.Sizeof(float32(0))))) } -func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decComplex128 decodes a pair of unsigned integers, treats them as a +// pair of floating point numbers, and stores them as a complex128 through p. +// The real part comes first. +func decComplex128(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex128)) @@ -352,8 +374,10 @@ func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*complex128)(p) = complex(real, imag) } +// decUint8Array decodes byte array and stores through p a slice header +// describing the data. // uint8 arrays are encoded as an unsigned count followed by the raw bytes. -func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { +func decUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]uint8)) @@ -365,8 +389,10 @@ func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*[]uint8)(p) = b } +// decString decodes byte array and stores through p a string header +// describing the data. // Strings are encoded as an unsigned count followed by the raw bytes. 
-func decString(i *decInstr, state *decodeState, p unsafe.Pointer) { +func decString(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]byte)) @@ -378,7 +404,8 @@ func decString(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*string)(p) = string(b) } -func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { +// ignoreUint8Array skips over the data for a byte slice value with no destination. +func ignoreUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) { b := make([]byte, state.decodeUint()) state.b.Read(b) } @@ -409,9 +436,15 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { return *(*uintptr)(up) } -func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, p uintptr, indir int) (err os.Error) { - defer catchError(&err) - p = allocate(rtyp, p, indir) +// decodeSingle decodes a top-level value that is not a struct and stores it through p. +// Such values are preceded by a zero, making them have the memory layout of a +// struct field (although with an illegal field number). +func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) { + indir := ut.indir + if ut.isGobDecoder { + indir = int(ut.decIndir) + } + p = allocate(ut.base, p, indir) state := newDecodeState(dec, &dec.buf) state.fieldnum = singletonField basep := p @@ -428,9 +461,13 @@ func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, p uintptr return nil } -func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, p uintptr, indir int) (err os.Error) { - defer catchError(&err) - p = allocate(rtyp, p, indir) +// decodeStruct decodes a top-level struct and stores it through p. +// Indir is for the value, not the type. At the time of the call it may +// differ from ut.indir, which was computed when the engine was built. 
+// This state cannot arise for decodeSingle, which is called directly +// from the user's value, not from the innards of an engine. +func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) { + p = allocate(ut.base.(*reflect.StructType), p, indir) state := newDecodeState(dec, &dec.buf) state.fieldnum = -1 basep := p @@ -458,8 +495,8 @@ func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, p return nil } +// ignoreStruct discards the data for a struct with no destination. func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { - defer catchError(&err) state := newDecodeState(dec, &dec.buf) state.fieldnum = -1 for state.b.Len() > 0 { @@ -481,8 +518,9 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { return nil } +// ignoreSingle discards the data for a top-level non-struct value with no +// destination. It's used when calling Decode with a nil value. func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) { - defer catchError(&err) state := newDecodeState(dec, &dec.buf) state.fieldnum = singletonField delta := int(state.decodeUint()) @@ -494,7 +532,8 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) { return nil } -func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) { +// decodeArrayHelper does the work for decoding arrays and slices. 
+func (dec *Decoder) decodeArrayHelper(state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) { instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl} for i := 0; i < length; i++ { up := unsafe.Pointer(p) @@ -506,7 +545,10 @@ func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decO } } -func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) { +// decodeArray decodes an array and stores it through p, that is, p points to the zeroth element. +// The length is an unsigned integer preceding the elements. Even though the length is redundant +// (it's part of the type), it's a useful check and is included in the encoding. +func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) { if indir > 0 { p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect } @@ -516,9 +558,11 @@ func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p u dec.decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) } -func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { +// decodeIntoValue is a helper for map decoding. Since maps are decoded using reflection, +// unlike the other items we can't use a pointer directly. 
+func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { instr := &decInstr{op, 0, indir, 0, ovfl} - up := unsafe.Pointer(v.Addr()) + up := unsafe.Pointer(v.UnsafeAddr()) if indir > 1 { up = decIndirect(up, indir) } @@ -526,7 +570,11 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o return v } -func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) { +// decodeMap decodes a map and stores its header through p. +// Maps are encoded as a length followed by key:value pairs. +// Because the internals of maps are not visible to us, we must +// use reflection rather than pointer magic. +func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decoderState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) { if indir > 0 { p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect } @@ -538,7 +586,7 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for // the iteration. - v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue) + v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p))).(*reflect.MapValue) n := int(state.decodeUint()) for i := 0; i < n; i++ { key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl) @@ -547,21 +595,24 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp } } -func (dec *Decoder) ignoreArrayHelper(state *decodeState, elemOp decOp, length int) { +// ignoreArrayHelper does the work for discarding arrays and slices. 
+func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length int) { instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} for i := 0; i < length; i++ { elemOp(instr, state, nil) } } -func (dec *Decoder) ignoreArray(state *decodeState, elemOp decOp, length int) { +// ignoreArray discards the data for an array value with no destination. +func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) { if n := state.decodeUint(); n != uint64(length) { errorf("gob: length mismatch in ignoreArray") } dec.ignoreArrayHelper(state, elemOp, length) } -func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) { +// ignoreMap discards the data for a map value with no destination. +func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) { n := int(state.decodeUint()) keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")} elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} @@ -571,7 +622,9 @@ func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) { } } -func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) { +// decodeSlice decodes a slice and stores the slice header through p. +// Slices are encoded as an unsigned length followed by the elements. +func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) { n := int(uintptr(state.decodeUint())) if indir > 0 { up := unsafe.Pointer(p) @@ -590,7 +643,8 @@ func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p u dec.decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl) } -func (dec *Decoder) ignoreSlice(state *decodeState, elemOp decOp) { +// ignoreSlice skips over the data for a slice value with no destination. 
+func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) { dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint())) } @@ -609,9 +663,10 @@ func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) { ivalue.Set(value) } -// decodeInterface receives the name of a concrete type followed by its value. +// decodeInterface decodes an interface value and stores it through p. +// Interfaces are encoded as the name of a concrete type followed by a value. // If the name is empty, the value is nil and no value is sent. -func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) { +func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decoderState, p uintptr, indir int) { // Create an interface reflect.Value. We need one even for the nil case. ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue) // Read the name of the concrete type. @@ -655,7 +710,8 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get() } -func (dec *Decoder) ignoreInterface(state *decodeState) { +// ignoreInterface discards the data for an interface value with no destination. +func (dec *Decoder) ignoreInterface(state *decoderState) { // Read the name of the concrete type. b := make([]byte, state.decodeUint()) _, err := state.b.Read(b) @@ -670,8 +726,34 @@ func (dec *Decoder) ignoreInterface(state *decodeState) { state.b.Next(int(state.decodeUint())) } +// decodeGobDecoder decodes something implementing the GobDecoder interface. +// The data is encoded as a byte slice. +func (dec *Decoder) decodeGobDecoder(state *decoderState, v reflect.Value, index int) { + // Read the bytes for the value. + b := make([]byte, state.decodeUint()) + _, err := state.b.Read(b) + if err != nil { + error(err) + } + // We know it's a GobDecoder, so just call the method directly. 
+ err = v.Interface().(GobDecoder).GobDecode(b) + if err != nil { + error(err) + } +} + +// ignoreGobDecoder discards the data for a GobDecoder value with no destination. +func (dec *Decoder) ignoreGobDecoder(state *decoderState) { + // Read the bytes for the value. + b := make([]byte, state.decodeUint()) + _, err := state.b.Read(b) + if err != nil { + error(err) + } +} + // Index by Go types. -var decOpMap = []decOp{ +var decOpTable = [...]decOp{ reflect.Bool: decBool, reflect.Int8: decInt8, reflect.Int16: decInt16, @@ -699,37 +781,49 @@ var decIgnoreOpMap = map[typeId]decOp{ tComplex: ignoreTwoUints, } -// Return the decoding op for the base type under rt and +// decOpFor returns the decoding op for the base type under rt and // the indirection count to reach it. -func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) { - typ, indir := indirect(rt) +func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) { + ut := userType(rt) + // If the type implements GobDecoder, we handle it without further processing. + if ut.isGobDecoder { + return dec.gobDecodeOpFor(ut) + } + // If this type is already in progress, it's a recursive type (e.g. map[string]*T). + // Return the pointer to the op we're already building. 
+ if opPtr := inProgress[rt]; opPtr != nil { + return opPtr, ut.indir + } + typ := ut.base + indir := ut.indir var op decOp k := typ.Kind() - if int(k) < len(decOpMap) { - op = decOpMap[k] + if int(k) < len(decOpTable) { + op = decOpTable[k] } if op == nil { + inProgress[rt] = &op // Special cases switch t := typ.(type) { case *reflect.ArrayType: name = "element of " + name elemId := dec.wireType[wireId].ArrayT.Elem - elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.dec.decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) } case *reflect.MapType: name = "element of " + name keyId := dec.wireType[wireId].MapT.Key elemId := dec.wireType[wireId].MapT.Elem - keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name) - elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) + keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress) + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { up := unsafe.Pointer(p) - state.dec.decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl) + state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl) } case *reflect.SliceType: @@ -744,46 +838,46 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp } else { elemId = dec.wireType[wireId].SliceT.Elem } - elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl 
:= overflow(name) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.dec.decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) } case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. - enginePtr, err := dec.getDecEnginePtr(wireId, typ) + enginePtr, err := dec.getDecEnginePtr(wireId, userType(typ)) if err != nil { error(err) } - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - // indirect through enginePtr to delay evaluation for recursive structs - err = dec.decodeStruct(*enginePtr, t, uintptr(p), i.indir) + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + // indirect through enginePtr to delay evaluation for recursive structs. + err = dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir) if err != nil { error(err) } } case *reflect.InterfaceType: - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - dec.decodeInterface(t, state, uintptr(p), i.indir) + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + state.dec.decodeInterface(t, state, uintptr(p), i.indir) } } } if op == nil { errorf("gob: decode can't handle type %s", rt.String()) } - return op, indir + return &op, indir } -// Return the decoding op for a field that has no destination. +// decIgnoreOpFor returns the decoding op for a field that has no destination. func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { op, ok := decIgnoreOpMap[wireId] if !ok { if wireId == tInterface { // Special case because it's a method: the ignored item might // define types and we need to record their state in the decoder. 
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - dec.ignoreInterface(state) + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + state.dec.ignoreInterface(state) } return op } @@ -795,7 +889,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { case wire.ArrayT != nil: elemId := wire.ArrayT.Elem elemOp := dec.decIgnoreOpFor(elemId) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { state.dec.ignoreArray(state, elemOp, wire.ArrayT.Len) } @@ -804,14 +898,14 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { elemId := dec.wireType[wireId].MapT.Elem keyOp := dec.decIgnoreOpFor(keyId) elemOp := dec.decIgnoreOpFor(elemId) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { state.dec.ignoreMap(state, keyOp, elemOp) } case wire.SliceT != nil: elemId := wire.SliceT.Elem elemOp := dec.decIgnoreOpFor(elemId) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { state.dec.ignoreSlice(state, elemOp) } @@ -821,10 +915,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { if err != nil { error(err) } - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { // indirect through enginePtr to delay evaluation for recursive structs state.dec.ignoreStruct(*enginePtr) } + + case wire.GobEncoderT != nil: + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + state.dec.ignoreGobDecoder(state) + } } } if op == nil { @@ -833,14 +932,58 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { return op } -// Are these two gob Types compatible? -// Answers the question for basic types, arrays, and slices. +// gobDecodeOpFor returns the op for a type that is known to implement +// GobDecoder. 
+func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) { + rt := ut.user + if ut.decIndir != 0 { + errorf("gob: TODO: can't handle indirection to reach GobDecoder") + } + index := -1 + for i := 0; i < rt.NumMethod(); i++ { + if rt.Method(i).Name == gobDecodeMethodName { + index = i + break + } + } + if index < 0 { + panic("can't find GobDecode method") + } + var op decOp + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + // Allocate the underlying data, but hold on to the address we have, + // since it's known to be the receiver's address. + // TODO: fix this up when decIndir can be non-zero. + allocate(ut.base, uintptr(p), ut.indir) + v := reflect.NewValue(unsafe.Unreflect(rt, p)) + state.dec.decodeGobDecoder(state, v, index) + } + return &op, int(ut.decIndir) + +} + +// compatibleType asks: Are these two gob Types compatible? +// Answers the question for basic types, arrays, maps and slices, plus +// GobEncoder/Decoder pairs. // Structs are considered ok; fields will be checked later. -func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { - fr, _ = indirect(fr) - switch t := fr.(type) { +func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool { + if rhs, ok := inProgress[fr]; ok { + return rhs == fw + } + inProgress[fr] = fw + ut := userType(fr) + wire, ok := dec.wireType[fw] + // If fr is a GobDecoder, the wire type must be GobEncoder. + // And if fr is not a GobDecoder, the wire type must not be either. + if ut.isGobDecoder != (ok && wire.GobEncoderT != nil) { // the parentheses look odd but are correct. + return false + } + if ut.isGobDecoder { // This test trumps all others. + return true + } + switch t := ut.base.(type) { default: - // map, chan, etc: cannot handle. + // chan, etc: cannot handle. 
return false case *reflect.BoolType: return fw == tBool @@ -857,19 +1000,17 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { case *reflect.InterfaceType: return fw == tInterface case *reflect.ArrayType: - wire, ok := dec.wireType[fw] if !ok || wire.ArrayT == nil { return false } array := wire.ArrayT - return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem) + return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress) case *reflect.MapType: - wire, ok := dec.wireType[fw] if !ok || wire.MapT == nil { return false } MapType := wire.MapT - return dec.compatibleType(t.Key(), MapType.Key) && dec.compatibleType(t.Elem(), MapType.Elem) + return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress) case *reflect.SliceType: // Is it an array of bytes? if t.Elem().Kind() == reflect.Uint8 { @@ -882,8 +1023,8 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { } else { sw = dec.wireType[fw].SliceT } - elem, _ := indirect(t.Elem()) - return sw != nil && dec.compatibleType(elem, sw.Elem) + elem := userType(t.Elem()).base + return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress) case *reflect.StructType: return true } @@ -899,21 +1040,27 @@ func (dec *Decoder) typeString(remoteId typeId) string { return dec.wireType[remoteId].string() } - -func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { +// compileSingle compiles the decoder engine for a non-struct top-level value, including +// GobDecoders. 
+func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) { + rt := ut.base + if ut.isGobDecoder { + rt = ut.user + } engine = new(decEngine) engine.instr = make([]decInstr, 1) // one item name := rt.String() // best we can do - if !dec.compatibleType(rt, remoteId) { + if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) { return nil, os.ErrorString("gob: wrong type received for local value " + name + ": " + dec.typeString(remoteId)) } - op, indir := dec.decOpFor(remoteId, rt, name) + op, indir := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp)) ovfl := os.ErrorString(`value for "` + name + `" out of range`) - engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl} + engine.instr[singletonField] = decInstr{*op, singletonField, indir, 0, ovfl} engine.numInstr = 1 return } +// compileIgnoreSingle compiles the decoder engine for a non-struct top-level value that will be discarded. func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err os.Error) { engine = new(decEngine) engine.instr = make([]decInstr, 1) // one item @@ -924,17 +1071,13 @@ func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err return } -// Is this an exported - upper case - name? -func isExported(name string) bool { - rune, _ := utf8.DecodeRuneInString(name) - return unicode.IsUpper(rune) -} - -func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { - defer catchError(&err) +// compileDec compiles the decoder engine for a value. If the value is not a struct, +// it calls out to compileSingle. 
+func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) { + rt := ut.base srt, ok := rt.(*reflect.StructType) - if !ok { - return dec.compileSingle(remoteId, rt) + if !ok || ut.isGobDecoder { + return dec.compileSingle(remoteId, ut) } var wireStruct *structType // Builtin types can come from global pool; the rest must be defined by the decoder. @@ -953,6 +1096,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng } engine = new(decEngine) engine.instr = make([]decInstr, len(wireStruct.Field)) + seen := make(map[reflect.Type]*decOp) // Loop over the fields of the wire type. for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ { wireField := wireStruct.Field[fieldnum] @@ -968,17 +1112,19 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl} continue } - if !dec.compatibleType(localField.Type, wireField.Id) { + if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) { errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) } - op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name) - engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl} + op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen) + engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl} engine.numInstr++ } return } -func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) { +// getDecEnginePtr returns the engine for the specified type. 
+func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePtr **decEngine, err os.Error) { + rt := ut.base decoderMap, ok := dec.decoderCache[rt] if !ok { decoderMap = make(map[typeId]**decEngine) @@ -988,7 +1134,7 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr // To handle recursive types, mark this engine as underway before compiling. enginePtr = new(*decEngine) decoderMap[remoteId] = enginePtr - *enginePtr, err = dec.compileDec(remoteId, rt) + *enginePtr, err = dec.compileDec(remoteId, ut) if err != nil { decoderMap[remoteId] = nil, false } @@ -996,11 +1142,12 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr return } -// When ignoring struct data, in effect we compile it into this type +// emptyStruct is the type we compile into when ignoring a struct value. type emptyStruct struct{} var emptyStructType = reflect.Typeof(emptyStruct{}) +// getIgnoreEnginePtr returns the engine for the specified type when the value is to be discarded. func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) { var ok bool if enginePtr, ok = dec.ignorerCache[wireId]; !ok { @@ -1009,7 +1156,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er dec.ignorerCache[wireId] = enginePtr wire := dec.wireType[wireId] if wire != nil && wire.StructT != nil { - *enginePtr, err = dec.compileDec(wireId, emptyStructType) + *enginePtr, err = dec.compileDec(wireId, userType(emptyStructType)) } else { *enginePtr, err = dec.compileIgnoreSingle(wireId) } @@ -1020,28 +1167,39 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er return } -func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error { +// decodeValue decodes the data stream representing a value and stores it in val. 
+func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) { + defer catchError(&err) // If the value is nil, it means we should just ignore this item. if val == nil { return dec.decodeIgnoredValue(wireId) } // Dereference down to the underlying struct type. - rt, indir := indirect(val.Type()) - enginePtr, err := dec.getDecEnginePtr(wireId, rt) + ut := userType(val.Type()) + base := ut.base + indir := ut.indir + if ut.isGobDecoder { + indir = int(ut.decIndir) + if indir != 0 { + errorf("TODO: can't handle indirection in GobDecoder value") + } + } + enginePtr, err := dec.getDecEnginePtr(wireId, ut) if err != nil { return err } engine := *enginePtr - if st, ok := rt.(*reflect.StructType); ok { + if st, ok := base.(*reflect.StructType); ok && !ut.isGobDecoder { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { - name := rt.Name() + name := base.Name() return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) } - return dec.decodeStruct(engine, st, uintptr(val.Addr()), indir) + return dec.decodeStruct(engine, ut, uintptr(val.UnsafeAddr()), indir) } - return dec.decodeSingle(engine, rt, uintptr(val.Addr()), indir) + return dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr())) } +// decodeIgnoredValue decodes the data stream representing a value of the specified type and discards it. 
func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error { enginePtr, err := dec.getIgnoreEnginePtr(wireId) if err != nil { @@ -1066,8 +1224,8 @@ func init() { default: panic("gob: unknown size of int/uint") } - decOpMap[reflect.Int] = iop - decOpMap[reflect.Uint] = uop + decOpTable[reflect.Int] = iop + decOpTable[reflect.Uint] = uop // Finally uintptr switch reflect.Typeof(uintptr(0)).Bits() { @@ -1078,5 +1236,5 @@ func init() { default: panic("gob: unknown size of uintptr") } - decOpMap[reflect.Uintptr] = uop + decOpTable[reflect.Uintptr] = uop } diff --git a/src/pkg/gob/decoder.go b/src/pkg/gob/decoder.go index f7c994ffa..719274583 100644 --- a/src/pkg/gob/decoder.go +++ b/src/pkg/gob/decoder.go @@ -21,7 +21,7 @@ type Decoder struct { wireType map[typeId]*wireType // map from remote ID to local description decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines ignorerCache map[typeId]**decEngine // ditto for ignored objects - countState *decodeState // reads counts from wire + countState *decoderState // reads counts from wire countBuf []byte // used for decoding integers while parsing messages tmp []byte // temporary storage for i/o; saves reallocating err os.Error diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index 2e5ba2487..9190d9203 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -15,7 +15,7 @@ import ( const uint64Size = unsafe.Sizeof(uint64(0)) -// The global execution state of an instance of the encoder. +// encoderState is the global execution state of an instance of the encoder. // Field numbers are delta encoded and always increase. The field // number is initialized to -1 so 0 comes out as delta(1). A delta of // 0 terminates the structure. @@ -72,6 +72,7 @@ func (state *encoderState) encodeInt(i int64) { state.encodeUint(uint64(x)) } +// encOp is the signature of an encoding operator for a given type. 
type encOp func(i *encInstr, state *encoderState, p unsafe.Pointer) // The 'instructions' of the encoding machine @@ -82,8 +83,8 @@ type encInstr struct { offset uintptr // offset in the structure of the field to encode } -// Emit a field number and update the state to record its value for delta encoding. -// If the instruction pointer is nil, do nothing +// update emits a field number and updates the state to record its value for delta encoding. +// If the instruction pointer is nil, it does nothing func (state *encoderState) update(instr *encInstr) { if instr != nil { state.encodeUint(uint64(instr.field - state.fieldnum)) @@ -97,6 +98,7 @@ func (state *encoderState) update(instr *encInstr) { // Otherwise, the output (for a scalar) is the field number, as an encoded integer, // followed by the field data in its appropriate format. +// encIndirect dereferences p indir times and returns the result. func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { for ; indir > 0; indir-- { p = *(*unsafe.Pointer)(p) @@ -107,6 +109,7 @@ func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { return p } +// encBool encodes the bool with address p as an unsigned 0 or 1. func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) { b := *(*bool)(p) if b || state.sendZero { @@ -119,6 +122,7 @@ func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt encodes the int with address p. func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int)(p)) if v != 0 || state.sendZero { @@ -127,6 +131,7 @@ func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint encodes the uint with address p. func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint)(p)) if v != 0 || state.sendZero { @@ -135,6 +140,7 @@ func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt8 encodes the int8 with address p. 
func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int8)(p)) if v != 0 || state.sendZero { @@ -143,6 +149,7 @@ func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint8 encodes the uint8 with address p. func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint8)(p)) if v != 0 || state.sendZero { @@ -151,6 +158,7 @@ func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt16 encodes the int16 with address p. func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int16)(p)) if v != 0 || state.sendZero { @@ -159,6 +167,7 @@ func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint16 encodes the uint16 with address p. func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint16)(p)) if v != 0 || state.sendZero { @@ -167,6 +176,7 @@ func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt32 encodes the int32 with address p. func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int32)(p)) if v != 0 || state.sendZero { @@ -175,6 +185,7 @@ func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint32 encodes the uint32 with address p. func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint32)(p)) if v != 0 || state.sendZero { @@ -183,6 +194,7 @@ func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt64 encodes the int64 with address p. func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) { v := *(*int64)(p) if v != 0 || state.sendZero { @@ -191,6 +203,7 @@ func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint64 encodes the uint64 with address p. 
func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) { v := *(*uint64)(p) if v != 0 || state.sendZero { @@ -199,6 +212,7 @@ func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUintptr encodes the uintptr with address p. func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uintptr)(p)) if v != 0 || state.sendZero { @@ -207,6 +221,7 @@ func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// floatBits returns a uint64 holding the bits of a floating-point number. // Floating-point numbers are transmitted as uint64s holding the bits // of the underlying representation. They are sent byte-reversed, with // the exponent end coming out first, so integer floating point numbers @@ -223,6 +238,7 @@ func floatBits(f float64) uint64 { return v } +// encFloat32 encodes the float32 with address p. func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) { f := *(*float32)(p) if f != 0 || state.sendZero { @@ -232,6 +248,7 @@ func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encFloat64 encodes the float64 with address p. func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) { f := *(*float64)(p) if f != 0 || state.sendZero { @@ -241,6 +258,7 @@ func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encComplex64 encodes the complex64 with address p. // Complex numbers are just a pair of floating-point numbers, real part first. func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) { c := *(*complex64)(p) @@ -253,6 +271,7 @@ func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encComplex128 encodes the complex128 with address p. 
func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) { c := *(*complex128)(p) if c != 0+0i || state.sendZero { @@ -264,6 +283,7 @@ func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint8Array encodes the byte slice whose header has address p. // Byte arrays are encoded as an unsigned count followed by the raw bytes. func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { b := *(*[]byte)(p) @@ -274,6 +294,7 @@ func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encString encodes the string whose header has address p. // Strings are encoded as an unsigned count followed by the raw bytes. func encString(i *encInstr, state *encoderState, p unsafe.Pointer) { s := *(*string)(p) @@ -284,14 +305,15 @@ func encString(i *encInstr, state *encoderState, p unsafe.Pointer) { } } -// The end of a struct is marked by a delta field number of 0. +// encStructTerminator encodes the end of an encoded struct +// as a delta field number of 0. func encStructTerminator(i *encInstr, state *encoderState, p unsafe.Pointer) { state.encodeUint(0) } // Execution engine -// The encoder engine is an array of instructions indexed by field number of the encoding +// encEngine is an array of instructions indexed by field number of the encoding // data, typically a struct. It is executed top to bottom, walking the struct. type encEngine struct { instr []encInstr @@ -299,6 +321,7 @@ type encEngine struct { const singletonField = 0 +// encodeSingle encodes a single top-level non-struct value. func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintptr) { state := newEncoderState(enc, b) state.fieldnum = singletonField @@ -315,6 +338,7 @@ func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintp instr.op(instr, state, p) } +// encodeStruct encodes a single struct value. 
func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintptr) { state := newEncoderState(enc, b) state.fieldnum = -1 @@ -330,6 +354,7 @@ func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintp } } +// encodeArray encodes the array whose 0th element is at p. func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) { state := newEncoderState(enc, b) state.fieldnum = -1 @@ -349,6 +374,7 @@ func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid ui } } +// encodeReflectValue is a helper for maps. It encodes the value v. func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) { for i := 0; i < indir && v != nil; i++ { v = reflect.Indirect(v) @@ -356,9 +382,12 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in if v == nil { errorf("gob: encodeReflectValue: nil element") } - op(nil, state, unsafe.Pointer(v.Addr())) + op(nil, state, unsafe.Pointer(v.UnsafeAddr())) } +// encodeMap encodes a map as unsigned count followed by key:value pairs. +// Because map internals are not exposed, we must use reflection rather than +// addresses. func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) { state := newEncoderState(enc, b) state.fieldnum = -1 @@ -371,6 +400,7 @@ func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elem } } +// encodeInterface encodes the interface value iv. // To send an interface, we send a string identifying the concrete type, followed // by the type identifier (which might require defining that type right now), followed // by the concrete value. 
A nil value gets sent as the empty string for the name, @@ -384,10 +414,10 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) return } - typ, _ := indirect(iv.Elem().Type()) - name, ok := concreteTypeToName[typ] + ut := userType(iv.Elem().Type()) + name, ok := concreteTypeToName[ut.base] if !ok { - errorf("gob: type not registered for interface: %s", typ) + errorf("gob: type not registered for interface: %s", ut.base) } // Send the name. state.encodeUint(uint64(len(name))) @@ -396,14 +426,14 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) error(err) } // Define the type id if necessary. - enc.sendTypeDescriptor(enc.writer(), state, typ) + enc.sendTypeDescriptor(enc.writer(), state, ut) // Send the type id. - enc.sendTypeId(state, typ) + enc.sendTypeId(state, ut) // Encode the value into a new buffer. Any nested type definitions // should be written to b, before the encoded value. enc.pushWriter(b) data := new(bytes.Buffer) - err = enc.encode(data, iv.Elem()) + err = enc.encode(data, iv.Elem(), ut) if err != nil { error(err) } @@ -414,7 +444,22 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) } } -var encOpMap = []encOp{ +// encodeGobEncoder encodes a value that implements the GobEncoder interface. +// The data is sent as a byte array. +func (enc *Encoder) encodeGobEncoder(b *bytes.Buffer, v reflect.Value, index int) { + // TODO: should we catch panics from the called method? + // We know it's a GobEncoder, so just call the method directly. 
+ data, err := v.Interface().(GobEncoder).GobEncode() + if err != nil { + error(err) + } + state := newEncoderState(enc, b) + state.fieldnum = -1 + state.encodeUint(uint64(len(data))) + state.b.Write(data) +} + +var encOpTable = [...]encOp{ reflect.Bool: encBool, reflect.Int: encInt, reflect.Int8: encInt8, @@ -434,16 +479,28 @@ var encOpMap = []encOp{ reflect.String: encString, } -// Return the encoding op for the base type under rt and +// encOpFor returns (a pointer to) the encoding op for the base type under rt and // the indirection count to reach it. -func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { - typ, indir := indirect(rt) - var op encOp +func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) { + ut := userType(rt) + // If the type implements GobEncoder, we handle it without further processing. + if ut.isGobEncoder { + return enc.gobEncodeOpFor(ut) + } + // If this type is already in progress, it's a recursive type (e.g. map[string]*T). + // Return the pointer to the op we're already building. + if opPtr := inProgress[rt]; opPtr != nil { + return opPtr, ut.indir + } + typ := ut.base + indir := ut.indir k := typ.Kind() - if int(k) < len(encOpMap) { - op = encOpMap[k] + var op encOp + if int(k) < len(encOpTable) { + op = encOpTable[k] } if op == nil { + inProgress[rt] = &op // Special cases switch t := typ.(type) { case *reflect.SliceType: @@ -452,40 +509,40 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { break } // Slices have a header; we decode it to find the underlying array. 
- elemOp, indir := enc.encOpFor(t.Elem()) + elemOp, indir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { slice := (*reflect.SliceHeader)(p) if !state.sendZero && slice.Len == 0 { return } state.update(i) - state.enc.encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len)) + state.enc.encodeArray(state.b, slice.Data, *elemOp, t.Elem().Size(), indir, int(slice.Len)) } case *reflect.ArrayType: // True arrays have size in the type. - elemOp, indir := enc.encOpFor(t.Elem()) + elemOp, indir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i) - state.enc.encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len()) + state.enc.encodeArray(state.b, uintptr(p), *elemOp, t.Elem().Size(), indir, t.Len()) } case *reflect.MapType: - keyOp, keyIndir := enc.encOpFor(t.Key()) - elemOp, elemIndir := enc.encOpFor(t.Elem()) + keyOp, keyIndir := enc.encOpFor(t.Key(), inProgress) + elemOp, elemIndir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for // the iteration. - v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p)))) + v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p))) mv := reflect.Indirect(v).(*reflect.MapValue) if !state.sendZero && mv.Len() == 0 { return } state.update(i) - state.enc.encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir) + state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir) } case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. 
- enc.getEncEngine(typ) + enc.getEncEngine(userType(typ)) info := mustGetTypeInfo(typ) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i) @@ -496,7 +553,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { // Interfaces transmit the name and contents of the concrete // value they contain. - v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p)))) + v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p))) iv := reflect.Indirect(v).(*reflect.InterfaceValue) if !state.sendZero && (iv == nil || iv.IsNil()) { return @@ -509,21 +566,54 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { if op == nil { errorf("gob enc: can't happen: encode type %s", rt.String()) } - return op, indir + return &op, indir } -// The local Type was compiled from the actual value, so we know it's compatible. -func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine { - srt, isStruct := rt.(*reflect.StructType) +// gobEncodeOpFor returns the op for a type that is known to implement +// GobEncoder. +func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { + rt := ut.user + if ut.encIndir != 0 { + errorf("gob: TODO: can't handle indirection to reach GobEncoder") + } + index := -1 + for i := 0; i < rt.NumMethod(); i++ { + if rt.Method(i).Name == gobEncodeMethodName { + index = i + break + } + } + if index < 0 { + panic("can't find GobEncode method") + } + var op encOp + op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { + // TODO: this will need fixing when ut.encIndr != 0. + v := reflect.NewValue(unsafe.Unreflect(rt, p)) + state.update(i) + state.enc.encodeGobEncoder(state.b, v, index) + } + return &op, int(ut.encIndir) +} + +// compileEnc returns the engine to compile the type. 
+func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine { + srt, isStruct := ut.base.(*reflect.StructType) engine := new(encEngine) - if isStruct { - for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ { + seen := make(map[reflect.Type]*encOp) + rt := ut.base + if ut.isGobEncoder { + rt = ut.user + } + if !ut.isGobEncoder && isStruct { + for fieldNum, wireFieldNum := 0, 0; fieldNum < srt.NumField(); fieldNum++ { f := srt.Field(fieldNum) if !isExported(f.Name) { continue } - op, indir := enc.encOpFor(f.Type) - engine.instr = append(engine.instr, encInstr{op, fieldNum, indir, uintptr(f.Offset)}) + op, indir := enc.encOpFor(f.Type, seen) + engine.instr = append(engine.instr, encInstr{*op, wireFieldNum, indir, uintptr(f.Offset)}) + wireFieldNum++ } if srt.NumField() > 0 && len(engine.instr) == 0 { errorf("type %s has no exported fields", rt) @@ -531,46 +621,52 @@ func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine { engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0}) } else { engine.instr = make([]encInstr, 1) - op, indir := enc.encOpFor(rt) - engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero + op, indir := enc.encOpFor(rt, seen) + engine.instr[0] = encInstr{*op, singletonField, indir, 0} // offset is zero } return engine } +// getEncEngine returns the engine to compile the type. // typeLock must be held (or we're in initialization and guaranteed single-threaded). -// The reflection type must have all its indirections processed out. -func (enc *Encoder) getEncEngine(rt reflect.Type) *encEngine { - info, err1 := getTypeInfo(rt) +func (enc *Encoder) getEncEngine(ut *userTypeInfo) *encEngine { + info, err1 := getTypeInfo(ut) if err1 != nil { error(err1) } if info.encoder == nil { // mark this engine as underway before compiling to handle recursive types. 
info.encoder = new(encEngine) - info.encoder = enc.compileEnc(rt) + info.encoder = enc.compileEnc(ut) } return info.encoder } -// Put this in a function so we can hold the lock only while compiling, not when encoding. -func (enc *Encoder) lockAndGetEncEngine(rt reflect.Type) *encEngine { +// lockAndGetEncEngine is a function that locks and compiles. +// This lets us hold the lock only while compiling, not when encoding. +func (enc *Encoder) lockAndGetEncEngine(ut *userTypeInfo) *encEngine { typeLock.Lock() defer typeLock.Unlock() - return enc.getEncEngine(rt) + return enc.getEncEngine(ut) } -func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value) (err os.Error) { +func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInfo) (err os.Error) { defer catchError(&err) - // Dereference down to the underlying object. - rt, indir := indirect(value.Type()) + engine := enc.lockAndGetEncEngine(ut) + indir := ut.indir + if ut.isGobEncoder { + indir = int(ut.encIndir) + if indir != 0 { + errorf("TODO: can't handle indirection in GobEncoder value") + } + } for i := 0; i < indir; i++ { value = reflect.Indirect(value) } - engine := enc.lockAndGetEncEngine(rt) - if value.Type().Kind() == reflect.Struct { - enc.encodeStruct(b, engine, value.Addr()) + if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct { + enc.encodeStruct(b, engine, value.UnsafeAddr()) } else { - enc.encodeSingle(b, engine, value.Addr()) + enc.encodeSingle(b, engine, value.UnsafeAddr()) } return nil } diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index 29ba44057..4bfcf15c7 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -78,11 +78,57 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) { } } +// sendActualType sends the requested type, without further investigation, unless +// it's been sent before. 
+func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) { + if _, alreadySent := enc.sent[actual]; alreadySent { + return false + } + typeLock.Lock() + info, err := getTypeInfo(ut) + typeLock.Unlock() + if err != nil { + enc.setError(err) + return + } + // Send the pair (-id, type) + // Id: + state.encodeInt(-int64(info.id)) + // Type: + enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo) + enc.writeMessage(w, state.b) + if enc.err != nil { + return + } + + // Remember we've sent this type, both what the user gave us and the base type. + enc.sent[ut.base] = info.id + if ut.user != ut.base { + enc.sent[ut.user] = info.id + } + // Now send the inner types + switch st := actual.(type) { + case *reflect.StructType: + for i := 0; i < st.NumField(); i++ { + enc.sendType(w, state, st.Field(i).Type) + } + case reflect.ArrayOrSliceType: + enc.sendType(w, state, st.Elem()) + } + return true +} + +// sendType sends the type info to the other side, if necessary. func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { - // Drill down to the base type. - rt, _ := indirect(origt) + ut := userType(origt) + if ut.isGobEncoder { + // The rules are different: regardless of the underlying type's representation, + // we need to tell the other side that this exact type is a GobEncoder. + return enc.sendActualType(w, state, ut, ut.user) + } - switch rt := rt.(type) { + // It's a concrete value, so drill down to the base type. + switch rt := ut.base.(type) { default: // Basic types and interfaces do not need to be described. return @@ -108,43 +154,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ return } - // Have we already sent this type? This time we ask about the base type. - if _, alreadySent := enc.sent[rt]; alreadySent { - return - } - - // Need to send it. 
- typeLock.Lock() - info, err := getTypeInfo(rt) - typeLock.Unlock() - if err != nil { - enc.setError(err) - return - } - // Send the pair (-id, type) - // Id: - state.encodeInt(-int64(info.id)) - // Type: - enc.encode(state.b, reflect.NewValue(info.wire)) - enc.writeMessage(w, state.b) - if enc.err != nil { - return - } - - // Remember we've sent this type. - enc.sent[rt] = info.id - // Remember we've sent the top-level, possibly indirect type too. - enc.sent[origt] = info.id - // Now send the inner types - switch st := rt.(type) { - case *reflect.StructType: - for i := 0; i < st.NumField(); i++ { - enc.sendType(w, state, st.Field(i).Type) - } - case reflect.ArrayOrSliceType: - enc.sendType(w, state, st.Elem()) - } - return true + return enc.sendActualType(w, state, ut, ut.base) } // Encode transmits the data item represented by the empty interface value, @@ -153,12 +163,19 @@ func (enc *Encoder) Encode(e interface{}) os.Error { return enc.EncodeValue(reflect.NewValue(e)) } -// sendTypeId makes sure the remote side knows about this type. +// sendTypeDescriptor makes sure the remote side knows about this type. // It will send a descriptor if this is the first time the type has been // sent. -func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt reflect.Type) { +func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) { // Make sure the type is known to the other side. // First, have we already sent this type? + rt := ut.base + if ut.isGobEncoder { + rt = ut.user + if ut.encIndir != 0 { + panic("TODO: can't handle non-zero encIndir") + } + } if _, alreadySent := enc.sent[rt]; !alreadySent { // No, so send it. sent := enc.sendType(w, state, rt) @@ -170,7 +187,7 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt refl // need to send the type info but we do need to update enc.sent. 
if !sent { typeLock.Lock() - info, err := getTypeInfo(rt) + info, err := getTypeInfo(ut) typeLock.Unlock() if err != nil { enc.setError(err) @@ -182,9 +199,9 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt refl } // sendTypeId sends the id, which must have already been defined. -func (enc *Encoder) sendTypeId(state *encoderState, rt reflect.Type) { +func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) { // Identify the type of this top-level value. - state.encodeInt(int64(enc.sent[rt])) + state.encodeInt(int64(enc.sent[ut.base])) } // EncodeValue transmits the data item represented by the reflection value, @@ -198,19 +215,22 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { // Remove any nested writers remaining due to previous errors. enc.w = enc.w[0:1] - enc.err = nil - rt, _ := indirect(value.Type()) + ut, err := validUserType(value.Type()) + if err != nil { + return err + } + enc.err = nil state := newEncoderState(enc, new(bytes.Buffer)) - enc.sendTypeDescriptor(enc.writer(), state, rt) - enc.sendTypeId(state, rt) + enc.sendTypeDescriptor(enc.writer(), state, ut) + enc.sendTypeId(state, ut) if enc.err != nil { return enc.err } // Encode the object. 
- err := enc.encode(state.b, value) + err = enc.encode(state.b, value, ut) if err != nil { enc.setError(err) } else { diff --git a/src/pkg/gob/encoder_test.go b/src/pkg/gob/encoder_test.go index 3e06db727..a0c713b81 100644 --- a/src/pkg/gob/encoder_test.go +++ b/src/pkg/gob/encoder_test.go @@ -249,6 +249,24 @@ func TestArray(t *testing.T) { } } +func TestRecursiveMapType(t *testing.T) { + type recursiveMap map[string]recursiveMap + r1 := recursiveMap{"A": recursiveMap{"B": nil, "C": nil}, "D": nil} + r2 := make(recursiveMap) + if err := encAndDec(r1, &r2); err != nil { + t.Error(err) + } +} + +func TestRecursiveSliceType(t *testing.T) { + type recursiveSlice []recursiveSlice + r1 := recursiveSlice{0: recursiveSlice{0: nil}, 1: nil} + r2 := make(recursiveSlice, 0) + if err := encAndDec(r1, &r2); err != nil { + t.Error(err) + } +} + // Regression test for bug: must send zero values inside arrays func TestDefaultsInArray(t *testing.T) { type Type7 struct { diff --git a/src/pkg/gob/gobencdec_test.go b/src/pkg/gob/gobencdec_test.go new file mode 100644 index 000000000..82ca68170 --- /dev/null +++ b/src/pkg/gob/gobencdec_test.go @@ -0,0 +1,331 @@ +// Copyright 20011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file contains tests of the GobEncoder/GobDecoder support. + +package gob + +import ( + "bytes" + "fmt" + "os" + "strings" + "testing" +) + +// Types that implement the GobEncoder/Decoder interfaces. + +type ByteStruct struct { + a byte // not an exported field +} + +type StringStruct struct { + s string // not an exported field +} + +type Gobber int + +type ValueGobber string // encodes with a value, decodes with a pointer. 
+ +// The relevant methods + +func (g *ByteStruct) GobEncode() ([]byte, os.Error) { + b := make([]byte, 3) + b[0] = g.a + b[1] = g.a + 1 + b[2] = g.a + 2 + return b, nil +} + +func (g *ByteStruct) GobDecode(data []byte) os.Error { + if g == nil { + return os.ErrorString("NIL RECEIVER") + } + // Expect N sequential-valued bytes. + if len(data) == 0 { + return os.EOF + } + g.a = data[0] + for i, c := range data { + if c != g.a+byte(i) { + return os.ErrorString("invalid data sequence") + } + } + return nil +} + +func (g *StringStruct) GobEncode() ([]byte, os.Error) { + return []byte(g.s), nil +} + +func (g *StringStruct) GobDecode(data []byte) os.Error { + // Expect N sequential-valued bytes. + if len(data) == 0 { + return os.EOF + } + a := data[0] + for i, c := range data { + if c != a+byte(i) { + return os.ErrorString("invalid data sequence") + } + } + g.s = string(data) + return nil +} + +func (g *Gobber) GobEncode() ([]byte, os.Error) { + return []byte(fmt.Sprintf("VALUE=%d", *g)), nil +} + +func (g *Gobber) GobDecode(data []byte) os.Error { + _, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g)) + return err +} + +func (v ValueGobber) GobEncode() ([]byte, os.Error) { + return []byte(fmt.Sprintf("VALUE=%s", v)), nil +} + +func (v *ValueGobber) GobDecode(data []byte) os.Error { + _, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v)) + return err +} + +// Structs that include GobEncodable fields. 
+ +type GobTest0 struct { + X int // guarantee we have something in common with GobTest* + G *ByteStruct +} + +type GobTest1 struct { + X int // guarantee we have something in common with GobTest* + G *StringStruct +} + +type GobTest2 struct { + X int // guarantee we have something in common with GobTest* + G string // not a GobEncoder - should give us errors +} + +type GobTest3 struct { + X int // guarantee we have something in common with GobTest* + G *Gobber // TODO: should be able to satisfy interface without a pointer +} + +type GobTest4 struct { + X int // guarantee we have something in common with GobTest* + V ValueGobber +} + +type GobTest5 struct { + X int // guarantee we have something in common with GobTest* + V *ValueGobber +} + +type GobTestIgnoreEncoder struct { + X int // guarantee we have something in common with GobTest* +} + +func TestGobEncoderField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + err := enc.Encode(GobTest0{17, &ByteStruct{'A'}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest0) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.G.a != 'A' { + t.Errorf("expected 'A' got %c", x.G.a) + } + // Now a field that's not a structure. + b.Reset() + gobber := Gobber(23) + err = enc.Encode(GobTest3{17, &gobber}) + if err != nil { + t.Fatal("encode error:", err) + } + y := new(GobTest3) + err = dec.Decode(y) + if err != nil { + t.Fatal("decode error:", err) + } + if *y.G != 23 { + t.Errorf("expected '23 got %d", *y.G) + } +} + +// As long as the fields have the same name and implement the +// interface, we can cross-connect them. Not sure it's useful +// and may even be bad but it works and it's hard to prevent +// without exposing the contents of the object, which would +// defeat the purpose. 
+func TestGobEncoderFieldsOfDifferentType(t *testing.T) { + // first, string in field to byte in field + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest0) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.G.a != 'A' { + t.Errorf("expected 'A' got %c", x.G.a) + } + // now the other direction, byte in field to string in field + b.Reset() + err = enc.Encode(GobTest0{17, &ByteStruct{'X'}}) + if err != nil { + t.Fatal("encode error:", err) + } + y := new(GobTest1) + err = dec.Decode(y) + if err != nil { + t.Fatal("decode error:", err) + } + if y.G.s != "XYZ" { + t.Fatalf("expected `XYZ` got %c", y.G.s) + } +} + +// Test that we can encode a value and decode into a pointer. +func TestGobEncoderValueEncoder(t *testing.T) { + // first, string in field to byte in field + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(GobTest4{17, ValueGobber("hello")}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest5) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if *x.V != "hello" { + t.Errorf("expected `hello` got %s", x.V) + } +} + +func TestGobEncoderFieldTypeError(t *testing.T) { + // GobEncoder to non-decoder: error + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := &GobTest2{} + err = dec.Decode(x) + if err == nil { + t.Fatal("expected decode error for mismatched fields (encoder to non-decoder)") + } + if strings.Index(err.String(), "type") < 0 { + t.Fatal("expected type error; got", err) + } + // Non-encoder to GobDecoder: error + b.Reset() + err = enc.Encode(GobTest2{17, "ABC"}) + if err != nil { + t.Fatal("encode error:", err) + } + y := &GobTest1{} + err = 
dec.Decode(y) + if err == nil { + t.Fatal("expected decode error for mistmatched fields (non-encoder to decoder)") + } + if strings.Index(err.String(), "type") < 0 { + t.Fatal("expected type error; got", err) + } +} + +// Even though ByteStruct is a struct, it's treated as a singleton at the top level. +func TestGobEncoderStructSingleton(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(&ByteStruct{'A'}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(ByteStruct) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.a != 'A' { + t.Errorf("expected 'A' got %c", x.a) + } +} + +func TestGobEncoderNonStructSingleton(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + g := Gobber(1234) // TODO: shouldn't need to take the address here. + err := enc.Encode(&g) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + var x Gobber + err = dec.Decode(&x) + if err != nil { + t.Fatal("decode error:", err) + } + if x != 1234 { + t.Errorf("expected 1234 got %c", x) + } +} + +func TestGobEncoderIgnoreStructField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + err := enc.Encode(GobTest0{17, &ByteStruct{'A'}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIgnoreEncoder) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.X != 17 { + t.Errorf("expected 17 got %c", x.X) + } +} + +func TestGobEncoderIgnoreNonStructField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. 
+ enc := NewEncoder(b) + gobber := Gobber(23) + err := enc.Encode(GobTest3{17, &gobber}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIgnoreEncoder) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.X != 17 { + t.Errorf("expected 17 got %c", x.X) + } +} diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index f613f6e8a..a43813941 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -9,15 +9,157 @@ import ( "os" "reflect" "sync" + "unicode" + "utf8" ) -// Reflection types are themselves interface values holding structs -// describing the type. Each type has a different struct so that struct can -// be the kind. For example, if typ is the reflect type for an int8, typ is -// a pointer to a reflect.Int8Type struct; if typ is the reflect type for a -// function, typ is a pointer to a reflect.FuncType struct; we use the type -// of that pointer as the kind. +// userTypeInfo stores the information associated with a type the user has handed +// to the package. It's computed once and stored in a map keyed by reflection +// type. +type userTypeInfo struct { + user reflect.Type // the type the user handed us + base reflect.Type // the base type after all indirections + indir int // number of indirections to reach the base type + isGobEncoder bool // does the type implement GobEncoder? + isGobDecoder bool // does the type implement GobDecoder? + encIndir int8 // number of indirections to reach the receiver type; may be negative + decIndir int8 // number of indirections to reach the receiver type; may be negative +} + +var ( + // Protected by an RWMutex because we read it a lot and write + // it only when we see a new type, typically when compiling. + userTypeLock sync.RWMutex + userTypeCache = make(map[reflect.Type]*userTypeInfo) +) + +// validType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, err will be non-nil. 
To be used when the error handler +// is not set up. +func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { + userTypeLock.RLock() + ut = userTypeCache[rt] + userTypeLock.RUnlock() + if ut != nil { + return + } + // Now set the value under the write lock. + userTypeLock.Lock() + defer userTypeLock.Unlock() + if ut = userTypeCache[rt]; ut != nil { + // Lost the race; not a problem. + return + } + ut = new(userTypeInfo) + ut.base = rt + ut.user = rt + // A type that is just a cycle of pointers (such as type T *T) cannot + // be represented in gobs, which need some concrete data. We use a + // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6, + // pp 539-540. As we step through indirections, run another type at + // half speed. If they meet up, there's a cycle. + slowpoke := ut.base // walks half as fast as ut.base + for { + pt, ok := ut.base.(*reflect.PtrType) + if !ok { + break + } + ut.base = pt.Elem() + if ut.base == slowpoke { // ut.base lapped slowpoke + // recursive pointer type. + return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String()) + } + if ut.indir%2 == 0 { + slowpoke = slowpoke.(*reflect.PtrType).Elem() + } + ut.indir++ + } + ut.isGobEncoder, ut.encIndir = implementsGobEncoder(ut.user) + ut.isGobDecoder, ut.decIndir = implementsGobDecoder(ut.user) + userTypeCache[rt] = ut + if ut.encIndir != 0 || ut.decIndir != 0 { + // There are checks in lots of other places, but putting this here means we won't even + // attempt to encode/decode this type. + // TODO: make it possible to handle types that are indirect to the implementation, + // such as a structure field of type T when *T implements GobDecoder. + return nil, os.ErrorString("TODO: gob can't handle indirections to GobEncoder/Decoder") + } + return +} + +const ( + gobEncodeMethodName = "GobEncode" + gobDecodeMethodName = "GobDecode" +) + +// implementsGobEncoder reports whether the type implements the interface. 
It also +// returns the number of indirections required to get to the implementation. +// TODO: when reflection makes it possible, should also be prepared to climb up +// one level if we're not on a pointer (implementation could be on *T for our T). +// That will mean that indir could be < 0, which is sure to cause problems, but +// we ignore them now as indir is always >= 0 now. +func implementsGobEncoder(rt reflect.Type) (implements bool, indir int8) { + if rt == nil { + return + } + // The type might be a pointer, or it might not, and we need to keep + // dereferencing to the base type until we find an implementation. + for { + if rt.NumMethod() > 0 { // avoid allocations etc. unless there's some chance + if _, ok := reflect.MakeZero(rt).Interface().(GobEncoder); ok { + return true, indir + } + } + if p, ok := rt.(*reflect.PtrType); ok { + indir++ + if indir > 100 { // insane number of indirections + return false, 0 + } + rt = p.Elem() + continue + } + break + } + return false, 0 +} + +// implementsGobDecoder reports whether the type implements the interface. It also +// returns the number of indirections required to get to the implementation. +// TODO: see comment on implementsGobEncoder. +func implementsGobDecoder(rt reflect.Type) (implements bool, indir int8) { + if rt == nil { + return + } + // The type might be a pointer, or it might not, and we need to keep + // dereferencing to the base type until we find an implementation. + for { + if rt.NumMethod() > 0 { // avoid allocations etc. unless there's some chance + if _, ok := reflect.MakeZero(rt).Interface().(GobDecoder); ok { + return true, indir + } + } + if p, ok := rt.(*reflect.PtrType); ok { + indir++ + if indir > 100 { // insane number of indirections + return false, 0 + } + rt = p.Elem() + continue + } + break + } + return false, 0 +} +// userType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, it calls error. 
+func userType(rt reflect.Type) *userTypeInfo { + ut, err := validUserType(rt) + if err != nil { + error(err) + } + return ut +} // A typeId represents a gob Type as an integer that can be passed on the wire. // Internally, typeIds are used as keys to a map to recover the underlying type info. type typeId int32 @@ -110,6 +252,7 @@ var ( // Predefined because it's needed by the Decoder var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id +var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType) func init() { // Some magic numbers to make sure there are no surprises. @@ -133,6 +276,7 @@ func init() { } nextId = firstUserId registerBasics() + wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil))) } // Array type @@ -142,12 +286,18 @@ type arrayType struct { Len int } -func newArrayType(name string, elem gobType, length int) *arrayType { - a := &arrayType{CommonType{Name: name}, elem.id(), length} - setTypeId(a) +func newArrayType(name string) *arrayType { + a := &arrayType{CommonType{Name: name}, 0, 0} return a } +func (a *arrayType) init(elem gobType, len int) { + // Set our type id before evaluating the element's, in case it's our own. 
+ setTypeId(a) + a.Elem = elem.id() + a.Len = len +} + func (a *arrayType) safeString(seen map[typeId]bool) string { if seen[a.Id] { return a.Name @@ -158,6 +308,23 @@ func (a *arrayType) safeString(seen map[typeId]bool) string { func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) } +// GobEncoder type (something that implements the GobEncoder interface) +type gobEncoderType struct { + CommonType +} + +func newGobEncoderType(name string) *gobEncoderType { + g := &gobEncoderType{CommonType{Name: name}} + setTypeId(g) + return g +} + +func (g *gobEncoderType) safeString(seen map[typeId]bool) string { + return g.Name +} + +func (g *gobEncoderType) string() string { return g.Name } + // Map type type mapType struct { CommonType @@ -165,12 +332,18 @@ type mapType struct { Elem typeId } -func newMapType(name string, key, elem gobType) *mapType { - m := &mapType{CommonType{Name: name}, key.id(), elem.id()} - setTypeId(m) +func newMapType(name string) *mapType { + m := &mapType{CommonType{Name: name}, 0, 0} return m } +func (m *mapType) init(key, elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(m) + m.Key = key.id() + m.Elem = elem.id() +} + func (m *mapType) safeString(seen map[typeId]bool) string { if seen[m.Id] { return m.Name @@ -189,12 +362,17 @@ type sliceType struct { Elem typeId } -func newSliceType(name string, elem gobType) *sliceType { - s := &sliceType{CommonType{Name: name}, elem.id()} - setTypeId(s) +func newSliceType(name string) *sliceType { + s := &sliceType{CommonType{Name: name}, 0} return s } +func (s *sliceType) init(elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. 
+ setTypeId(s) + s.Elem = elem.id() +} + func (s *sliceType) safeString(seen map[typeId]bool) string { if seen[s.Id] { return s.Name @@ -236,26 +414,31 @@ func (s *structType) string() string { return s.safeString(make(map[typeId]bool) func newStructType(name string) *structType { s := &structType{CommonType{Name: name}, nil} + // For historical reasons we set the id here rather than init. + // Se the comment in newTypeObject for details. setTypeId(s) return s } -// Step through the indirections on a type to discover the base type. -// Return the base type and the number of indirections. -func indirect(t reflect.Type) (rt reflect.Type, count int) { - rt = t - for { - pt, ok := rt.(*reflect.PtrType) - if !ok { - break - } - rt = pt.Elem() - count++ +// newTypeObject allocates a gobType for the reflection type rt. +// Unless ut represents a GobEncoder, rt should be the base type +// of ut. +// This is only called from the encoding side. The decoding side +// works through typeIds and userTypeInfos alone. +func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) { + // Does this type implement GobEncoder? + if ut.isGobEncoder { + return newGobEncoderType(name), nil } - return -} - -func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { + var err os.Error + var type0, type1 gobType + defer func() { + if err != nil { + types[rt] = nil, false + } + }() + // Install the top-level type before the subtypes (e.g. struct before + // fields) so recursive types can be constructed safely. switch t := rt.(type) { // All basic types are easy: they are predefined. 
case *reflect.BoolType: @@ -280,57 +463,73 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { return tInterface.gobType(), nil case *reflect.ArrayType: - gt, err := getType("", t.Elem()) + at := newArrayType(name) + types[rt] = at + type0, err = getBaseType("", t.Elem()) if err != nil { return nil, err } - return newArrayType(name, gt, t.Len()), nil + // Historical aside: + // For arrays, maps, and slices, we set the type id after the elements + // are constructed. This is to retain the order of type id allocation after + // a fix made to handle recursive types, which changed the order in + // which types are built. Delaying the setting in this way preserves + // type ids while allowing recursive types to be described. Structs, + // done below, were already handling recursion correctly so they + // assign the top-level id before those of the field. + at.init(type0, t.Len()) + return at, nil case *reflect.MapType: - kt, err := getType("", t.Key()) + mt := newMapType(name) + types[rt] = mt + type0, err = getBaseType("", t.Key()) if err != nil { return nil, err } - vt, err := getType("", t.Elem()) + type1, err = getBaseType("", t.Elem()) if err != nil { return nil, err } - return newMapType(name, kt, vt), nil + mt.init(type0, type1) + return mt, nil case *reflect.SliceType: // []byte == []uint8 is a special case if t.Elem().Kind() == reflect.Uint8 { return tBytes.gobType(), nil } - gt, err := getType(t.Elem().Name(), t.Elem()) + st := newSliceType(name) + types[rt] = st + type0, err = getBaseType(t.Elem().Name(), t.Elem()) if err != nil { return nil, err } - return newSliceType(name, gt), nil + st.init(type0) + return st, nil case *reflect.StructType: - // Install the struct type itself before the fields so recursive - // structures can be constructed safely. 
- strType := newStructType(name) - types[rt] = strType - idToType[strType.id()] = strType - field := make([]*fieldType, t.NumField()) + st := newStructType(name) + types[rt] = st + idToType[st.id()] = st for i := 0; i < t.NumField(); i++ { f := t.Field(i) - typ, _ := indirect(f.Type) + if !isExported(f.Name) { + continue + } + typ := userType(f.Type).base tname := typ.Name() if tname == "" { - t, _ := indirect(f.Type) + t := userType(f.Type).base tname = t.String() } - gt, err := getType(tname, f.Type) + gt, err := getBaseType(tname, f.Type) if err != nil { return nil, err } - field[i] = &fieldType{f.Name, gt.id()} + st.Field = append(st.Field, &fieldType{f.Name, gt.id()}) } - strType.Field = field - return strType, nil + return st, nil default: return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String()) @@ -338,15 +537,30 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { return nil, nil } +// isExported reports whether this is an exported - upper case - name. +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// getBaseType returns the Gob type describing the given reflect.Type's base type. +// typeLock must be held. +func getBaseType(name string, rt reflect.Type) (gobType, os.Error) { + ut := userType(rt) + return getType(name, ut, ut.base) +} + // getType returns the Gob type describing the given reflect.Type. +// Should be called only when handling GobEncoders/Decoders, +// which may be pointers. All other types are handled through the +// base type, never a pointer. // typeLock must be held. 
-func getType(name string, rt reflect.Type) (gobType, os.Error) { - rt, _ = indirect(rt) +func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) { typ, present := types[rt] if present { return typ, nil } - typ, err := newTypeObject(name, rt) + typ, err := newTypeObject(name, ut, rt) if err == nil { types[rt] = typ } @@ -371,6 +585,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { types[rt] = typ setTypeId(typ) checkId(expect, nextId) + userType(rt) // might as well cache it now return nextId } @@ -381,15 +596,16 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { // For bootstrapping purposes, we assume that the recipient knows how // to decode a wireType; it is exactly the wireType struct here, interpreted // using the gob rules for sending a structure, except that we assume the -// ids for wireType and structType are known. The relevant pieces +// ids for wireType and structType etc. are known. The relevant pieces // are built in encode.go's init() function. // To maintain binary compatibility, if you extend this type, always put // the new fields last. type wireType struct { - ArrayT *arrayType - SliceT *sliceType - StructT *structType - MapT *mapType + ArrayT *arrayType + SliceT *sliceType + StructT *structType + MapT *mapType + GobEncoderT *gobEncoderType } func (w *wireType) string() string { @@ -406,6 +622,8 @@ func (w *wireType) string() string { return w.StructT.Name case w.MapT != nil: return w.MapT.Name + case w.GobEncoderT != nil: + return w.GobEncoderT.Name } return unknown } @@ -418,49 +636,96 @@ type typeInfo struct { var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock -// The reflection type must have all its indirections processed out. // typeLock must be held. 
-func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) { - if rt.Kind() == reflect.Ptr { - panic("pointer type in getTypeInfo: " + rt.String()) +func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) { + rt := ut.base + if ut.isGobEncoder { + // We want the user type, not the base type. + rt = ut.user } info, ok := typeInfoMap[rt] - if !ok { - info = new(typeInfo) - name := rt.Name() - gt, err := getType(name, rt) + if ok { + return info, nil + } + info = new(typeInfo) + gt, err := getBaseType(rt.Name(), rt) + if err != nil { + return nil, err + } + info.id = gt.id() + + if ut.isGobEncoder { + userType, err := getType(rt.Name(), ut, rt) if err != nil { return nil, err } - info.id = gt.id() - t := info.id.gobType() - switch typ := rt.(type) { - case *reflect.ArrayType: - info.wire = &wireType{ArrayT: t.(*arrayType)} - case *reflect.MapType: - info.wire = &wireType{MapT: t.(*mapType)} - case *reflect.SliceType: - // []byte == []uint8 is a special case handled separately - if typ.Elem().Kind() != reflect.Uint8 { - info.wire = &wireType{SliceT: t.(*sliceType)} - } - case *reflect.StructType: - info.wire = &wireType{StructT: t.(*structType)} + info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)} + typeInfoMap[ut.user] = info + return info, nil + } + + t := info.id.gobType() + switch typ := rt.(type) { + case *reflect.ArrayType: + info.wire = &wireType{ArrayT: t.(*arrayType)} + case *reflect.MapType: + info.wire = &wireType{MapT: t.(*mapType)} + case *reflect.SliceType: + // []byte == []uint8 is a special case handled separately + if typ.Elem().Kind() != reflect.Uint8 { + info.wire = &wireType{SliceT: t.(*sliceType)} } - typeInfoMap[rt] = info + case *reflect.StructType: + info.wire = &wireType{StructT: t.(*structType)} } + typeInfoMap[rt] = info return info, nil } // Called only when a panic is acceptable and unexpected. 
func mustGetTypeInfo(rt reflect.Type) *typeInfo { - t, err := getTypeInfo(rt) + t, err := getTypeInfo(userType(rt)) if err != nil { panic("getTypeInfo: " + err.String()) } return t } +// GobEncoder is the interface describing data that provides its own +// representation for encoding values for transmission to a GobDecoder. +// A type that implements GobEncoder and GobDecoder has complete +// control over the representation of its data and may therefore +// contain things such as private fields, channels, and functions, +// which are not usually transmissible in gob streams. +// +// Note: Since gobs can be stored permanently, it is good design +// to guarantee the encoding used by a GobEncoder is stable as the +// software evolves. For instance, it might make sense for GobEncode +// to include a version number in the encoding. +// +// Note: At the moment, the type implementing GobEncoder must +// be exactly the type passed to Encode. For example, if *T implements +// GobEncoder, the data item must be of type *T, not T or **T. +type GobEncoder interface { + // GobEncode returns a byte slice representing the encoding of the + // receiver for transmission to a GobDecoder, usually of the same + // concrete type. + GobEncode() ([]byte, os.Error) +} + +// GobDecoder is the interface describing data that provides its own +// routine for decoding transmitted values sent by a GobEncoder. +// +// Note: At the moment, the type implementing GobDecoder must +// be exactly the type passed to Decode. For example, if *T implements +// GobDecoder, the data item must be of type *T, not T or **T. +type GobDecoder interface { + // GobDecode overwrites the receiver, which must be a pointer, + // with the value represented by the byte slice, which was written + // by GobEncode, usually for the same concrete type. 
+ GobDecode([]byte) os.Error +} + var ( nameToConcreteType = make(map[string]reflect.Type) concreteTypeToName = make(map[reflect.Type]string) @@ -473,18 +738,18 @@ func RegisterName(name string, value interface{}) { // reserved for nil panic("attempt to register empty name") } - rt, _ := indirect(reflect.Typeof(value)) + base := userType(reflect.Typeof(value)).base // Check for incompatible duplicates. - if t, ok := nameToConcreteType[name]; ok && t != rt { + if t, ok := nameToConcreteType[name]; ok && t != base { panic("gob: registering duplicate types for " + name) } - if n, ok := concreteTypeToName[rt]; ok && n != name { - panic("gob: registering duplicate names for " + rt.String()) + if n, ok := concreteTypeToName[base]; ok && n != name { + panic("gob: registering duplicate names for " + base.String()) } // Store the name and type provided by the user.... nameToConcreteType[name] = reflect.Typeof(value) // but the flattened type in the type table, since that's what decode needs. - concreteTypeToName[rt] = name + concreteTypeToName[base] = name } // Register records a type, identified by a value for that type, under its diff --git a/src/pkg/gob/type_test.go b/src/pkg/gob/type_test.go index 5aecde103..ffd1345e5 100644 --- a/src/pkg/gob/type_test.go +++ b/src/pkg/gob/type_test.go @@ -26,7 +26,7 @@ var basicTypes = []typeT{ func getTypeUnlocked(name string, rt reflect.Type) gobType { typeLock.Lock() defer typeLock.Unlock() - t, err := getType(name, rt) + t, err := getBaseType(name, rt) if err != nil { panic("getTypeUnlocked: " + err.String()) } @@ -126,27 +126,27 @@ func TestMapType(t *testing.T) { } type Bar struct { - x string + X string } // This structure has pointers and refers to itself, making it a good test case. 
type Foo struct { - a int - b int32 // will become int - c string - d []byte - e *float64 // will become float64 - f ****float64 // will become float64 - g *Bar - h *Bar // should not interpolate the definition of Bar again - i *Foo // will not explode + A int + B int32 // will become int + C string + D []byte + E *float64 // will become float64 + F ****float64 // will become float64 + G *Bar + H *Bar // should not interpolate the definition of Bar again + I *Foo // will not explode } func TestStructType(t *testing.T) { sstruct := getTypeUnlocked("Foo", reflect.Typeof(Foo{})) str := sstruct.string() // If we can print it correctly, we built it correctly. - expected := "Foo = struct { a int; b int; c string; d bytes; e float; f float; g Bar = struct { x string; }; h Bar; i Foo; }" + expected := "Foo = struct { A int; B int; C string; D bytes; E float; F float; G Bar = struct { X string; }; H Bar; I Foo; }" if str != expected { t.Errorf("struct printed as %q; expected %q", str, expected) } |