diff options
Diffstat (limited to 'src/pkg/gob/decode.go')
-rw-r--r-- | src/pkg/gob/decode.go | 488 |
1 files changed, 272 insertions, 216 deletions
diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index a70799e9a..f88ca72da 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -13,7 +13,9 @@ import ( "math" "os" "reflect" + "unicode" "unsafe" + "utf8" ) var ( @@ -22,16 +24,20 @@ var ( errRange = os.ErrorString("gob: internal error: field numbers out of bounds") ) -// The global execution state of an instance of the decoder. +// The execution state of an instance of the decoder. A new state +// is created for nested objects. type decodeState struct { - b *bytes.Buffer - err os.Error + dec *Decoder + // The buffer is stored with an extra indirection because it may be replaced + // if we load a type during decode (when reading an interface value). + b **bytes.Buffer fieldnum int // the last field number read. buf []byte } -func newDecodeState(b *bytes.Buffer) *decodeState { +func newDecodeState(dec *Decoder, b **bytes.Buffer) *decodeState { d := new(decodeState) + d.dec = dec d.b = b d.buf = make([]byte, uint64Size) return d @@ -74,24 +80,23 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, err os.Error) { } // decodeUint reads an encoded unsigned integer from state.r. -// Sets state.err. If state.err is already non-nil, it does nothing. // Does not check for overflow. -func decodeUint(state *decodeState) (x uint64) { - if state.err != nil { - return +func (state *decodeState) decodeUint() (x uint64) { + b, err := state.b.ReadByte() + if err != nil { + error(err) } - var b uint8 - b, state.err = state.b.ReadByte() - if b <= 0x7f { // includes state.err != nil + if b <= 0x7f { return uint64(b) } nb := -int(int8(b)) if nb > uint64Size { - state.err = errBadUint - return + error(errBadUint) + } + n, err := state.b.Read(state.buf[0:nb]) + if err != nil { + error(err) } - var n int - n, state.err = state.b.Read(state.buf[0:nb]) // Don't need to check error; it's safe to loop regardless. // Could check that the high byte is zero but it's not worth it. for i := 0; i < n; i++ { @@ -102,13 +107,9 @@ func decodeUint(state *decodeState) (x uint64) { } // decodeInt reads an encoded signed integer from state.r. -// Sets state.err. If state.err is already non-nil, it does nothing. // Does not check for overflow. -func decodeInt(state *decodeState) int64 { - x := decodeUint(state) - if state.err != nil { - return 0 - } +func (state *decodeState) decodeInt() int64 { + x := state.decodeUint() if x&1 != 0 { return ^int64(x >> 1) } @@ -146,12 +147,12 @@ func decIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { } func ignoreUint(i *decInstr, state *decodeState, p unsafe.Pointer) { - decodeUint(state) + state.decodeUint() } func ignoreTwoUints(i *decInstr, state *decodeState, p unsafe.Pointer) { - decodeUint(state) - decodeUint(state) + state.decodeUint() + state.decodeUint() } func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) { @@ -161,7 +162,7 @@ func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - *(*bool)(p) = decodeInt(state) != 0 + *(*bool)(p) = state.decodeInt() != 0 } func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) { @@ -171,9 +172,9 @@ func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - v := decodeInt(state) + v := state.decodeInt() if v < math.MinInt8 || math.MaxInt8 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*int8)(p) = int8(v) } @@ -186,9 +187,9 @@ func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - v := decodeUint(state) + v := state.decodeUint() if math.MaxUint8 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*uint8)(p) = uint8(v) } @@ -201,9 +202,9 @@ func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - v := decodeInt(state) + v := state.decodeInt() if v < math.MinInt16 || math.MaxInt16 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*int16)(p) = int16(v) } @@ -216,9 +217,9 @@ func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - v := decodeUint(state) + v := state.decodeUint() if math.MaxUint16 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*uint16)(p) = uint16(v) } @@ -231,9 +232,9 @@ func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - v := decodeInt(state) + v := state.decodeInt() if v < math.MinInt32 || math.MaxInt32 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*int32)(p) = int32(v) } @@ -246,9 +247,9 @@ func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - v := decodeUint(state) + v := state.decodeUint() if math.MaxUint32 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*uint32)(p) = uint32(v) } @@ -261,7 +262,7 @@ func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - *(*int64)(p) = int64(decodeInt(state)) + *(*int64)(p) = int64(state.decodeInt()) } func decUint64(i *decInstr, state *decodeState, p unsafe.Pointer) { @@ -271,7 +272,7 @@ func decUint64(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - *(*uint64)(p) = uint64(decodeUint(state)) + *(*uint64)(p) = uint64(state.decodeUint()) } // Floating-point numbers are transmitted as uint64s holding the bits @@ -290,14 +291,14 @@ func floatFromBits(u uint64) float64 { } func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { - v := floatFromBits(decodeUint(state)) + v := floatFromBits(state.decodeUint()) av := v if av < 0 { av = -av } // +Inf is OK in both 32- and 64-bit floats. Underflow is always OK. if math.MaxFloat32 < av && av <= math.MaxFloat64 { - state.err = i.ovfl + error(i.ovfl) } else { *(*float32)(p) = float32(v) } @@ -320,7 +321,7 @@ func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - *(*float64)(p) = floatFromBits(uint64(decodeUint(state))) + *(*float64)(p) = floatFromBits(uint64(state.decodeUint())) } // Complex numbers are just a pair of floating-point numbers, real part first. @@ -342,8 +343,8 @@ func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - real := floatFromBits(uint64(decodeUint(state))) - imag := floatFromBits(uint64(decodeUint(state))) + real := floatFromBits(uint64(state.decodeUint())) + imag := floatFromBits(uint64(state.decodeUint())) *(*complex128)(p) = cmplx(real, imag) } @@ -355,7 +356,7 @@ func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - b := make([]uint8, decodeUint(state)) + b := make([]uint8, state.decodeUint()) state.b.Read(b) *(*[]uint8)(p) = b } @@ -368,13 +369,13 @@ func decString(i *decInstr, state *decodeState, p unsafe.Pointer) { } p = *(*unsafe.Pointer)(p) } - b := make([]byte, decodeUint(state)) + b := make([]byte, state.decodeUint()) state.b.Read(b) *(*string)(p) = string(b) } func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { - b := make([]byte, decodeUint(state)) + b := make([]byte, state.decodeUint()) state.b.Read(b) } @@ -404,15 +405,15 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { return *(*uintptr)(up) } -func decodeSingle(engine *decEngine, rtyp reflect.Type, b *bytes.Buffer, p uintptr, indir int) os.Error { +func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uintptr, indir int) (err os.Error) { + defer catchError(&err) p = allocate(rtyp, p, indir) - state := newDecodeState(b) + state := newDecodeState(dec, b) state.fieldnum = singletonField basep := p - delta := int(decodeUint(state)) + delta := int(state.decodeUint()) if delta != 0 { - state.err = os.ErrorString("gob decode: corrupted data: non-zero delta for singleton") - return state.err + errorf("gob decode: corrupted data: non-zero delta for singleton") } instr := &engine.instr[singletonField] ptr := unsafe.Pointer(basep) // offset will be zero @@ -420,26 +421,26 @@ func decodeSingle(engine *decEngine, rtyp reflect.Type, b *bytes.Buffer, p uintp ptr = decIndirect(ptr, instr.indir) } instr.op(instr, state, ptr) - return state.err + return nil } -func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error { +func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b **bytes.Buffer, p uintptr, indir int) (err os.Error) { + defer catchError(&err) p = allocate(rtyp, p, indir) - state := newDecodeState(b) + state := newDecodeState(dec, b) state.fieldnum = -1 basep := p - for state.err == nil { - delta := int(decodeUint(state)) + for state.b.Len() > 0 { + delta := int(state.decodeUint()) if delta < 0 { - state.err = os.ErrorString("gob decode: corrupted data: negative delta") - break + errorf("gob decode: corrupted data: negative delta") } - if state.err != nil || delta == 0 { // struct terminator is zero delta fieldnum + if delta == 0 { // struct terminator is zero delta fieldnum break } fieldnum := state.fieldnum + delta if fieldnum >= len(engine.instr) { - state.err = errRange + error(errRange) break } instr := &engine.instr[fieldnum] @@ -450,36 +451,35 @@ func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, instr.op(instr, state, p) state.fieldnum = fieldnum } - return state.err + return nil } -func ignoreStruct(engine *decEngine, b *bytes.Buffer) os.Error { - state := newDecodeState(b) +func (dec *Decoder) ignoreStruct(engine *decEngine, b **bytes.Buffer) (err os.Error) { + defer catchError(&err) + state := newDecodeState(dec, b) state.fieldnum = -1 - for state.err == nil { - delta := int(decodeUint(state)) + for state.b.Len() > 0 { + delta := int(state.decodeUint()) if delta < 0 { - state.err = os.ErrorString("gob ignore decode: corrupted data: negative delta") - break + errorf("gob ignore decode: corrupted data: negative delta") } - if state.err != nil || delta == 0 { // struct terminator is zero delta fieldnum + if delta == 0 { // struct terminator is zero delta fieldnum break } fieldnum := state.fieldnum + delta if fieldnum >= len(engine.instr) { - state.err = errRange - break + error(errRange) } instr := &engine.instr[fieldnum] instr.op(instr, state, unsafe.Pointer(nil)) state.fieldnum = fieldnum } - return state.err + return nil } -func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) os.Error { +func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) { instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl} - for i := 0; i < length && state.err == nil; i++ { + for i := 0; i < length; i++ { up := unsafe.Pointer(p) if elemIndir > 1 { up = decIndirect(up, elemIndir) @@ -487,17 +487,16 @@ func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uint elemOp(instr, state, up) p += uintptr(elemWid) } - return state.err } -func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) os.Error { +func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) { if indir > 0 { p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect } - if n := decodeUint(state); n != uint64(length) { - return os.ErrorString("gob: length mismatch in decodeArray") + if n := state.decodeUint(); n != uint64(length) { + errorf("gob: length mismatch in decodeArray") } - return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) + dec.decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) } func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { @@ -510,7 +509,7 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o return v } -func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error { +func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) { if indir > 0 { p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect } @@ -523,50 +522,40 @@ func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elem // that slices etc. can. We must recover a full reflection value for // the iteration. v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue) - n := int(decodeUint(state)) - for i := 0; i < n && state.err == nil; i++ { + n := int(state.decodeUint()) + for i := 0; i < n; i++ { key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl) - if state.err != nil { - break - } elem := decodeIntoValue(state, elemOp, elemIndir, reflect.MakeZero(mtyp.Elem()), ovfl) - if state.err != nil { - break - } v.SetElem(key, elem) } - return state.err } -func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error { +func (dec *Decoder) ignoreArrayHelper(state *decodeState, elemOp decOp, length int) { instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} - for i := 0; i < length && state.err == nil; i++ { + for i := 0; i < length; i++ { elemOp(instr, state, nil) } - return state.err } -func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error { - if n := decodeUint(state); n != uint64(length) { - return os.ErrorString("gob: length mismatch in ignoreArray") +func (dec *Decoder) ignoreArray(state *decodeState, elemOp decOp, length int) { + if n := state.decodeUint(); n != uint64(length) { + errorf("gob: length mismatch in ignoreArray") } - return ignoreArrayHelper(state, elemOp, length) + dec.ignoreArrayHelper(state, elemOp, length) } -func ignoreMap(state *decodeState, keyOp, elemOp decOp) os.Error { - n := int(decodeUint(state)) +func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) { + n := int(state.decodeUint()) keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")} elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} - for i := 0; i < n && state.err == nil; i++ { + for i := 0; i < n; i++ { keyOp(keyInstr, state, nil) elemOp(elemInstr, state, nil) } - return state.err } - -func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error { - n := int(uintptr(decodeUint(state))) +func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) { + n := int(uintptr(state.decodeUint())) if indir > 0 { up := unsafe.Pointer(p) if *(*unsafe.Pointer)(up) == nil { @@ -581,11 +570,77 @@ func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp hdrp.Data = uintptr(unsafe.NewArray(atyp.Elem(), n)) hdrp.Len = n hdrp.Cap = n - return decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl) + dec.decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl) } -func ignoreSlice(state *decodeState, elemOp decOp) os.Error { - return ignoreArrayHelper(state, elemOp, int(decodeUint(state))) +func (dec *Decoder) ignoreSlice(state *decodeState, elemOp decOp) { + dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint())) +} + +// setInterfaceValue sets an interface value to a concrete value through +// reflection. If the concrete value does not implement the interface, the +// setting will panic. This routine turns the panic into an error return. +// This dance avoids manually checking that the value satisfies the +// interface. +// TODO(rsc): avoid panic+recover after fixing issue 327. +func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) { + defer func() { + if e := recover(); e != nil { + error(e.(os.Error)) + } + }() + ivalue.Set(value) +} + +// decodeInterface receives the name of a concrete type followed by its value. +// If the name is empty, the value is nil and no value is sent. +func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) { + // Create an interface reflect.Value. We need one even for the nil case. + ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue) + // Read the name of the concrete type. + b := make([]byte, state.decodeUint()) + state.b.Read(b) + name := string(b) + if name == "" { + // Copy the representation of the nil interface value to the target. + // This is horribly unsafe and special. + *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get() + return + } + // The concrete type must be registered. + typ, ok := nameToConcreteType[name] + if !ok { + errorf("gob: name not registered for interface: %q", name) + } + // Read the concrete value. + value := reflect.MakeZero(typ) + dec.decodeValueFromBuffer(value, false, true) + if dec.err != nil { + error(dec.err) + } + // Allocate the destination interface value. + if indir > 0 { + p = allocate(ityp, p, 1) // All but the last level has been allocated by dec.Indirect + } + // Assign the concrete value to the interface. + // Tread carefully; it might not satisfy the interface. + setInterfaceValue(ivalue, value) + // Copy the representation of the interface value to the target. + // This is horribly unsafe and special. + *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get() +} + +func (dec *Decoder) ignoreInterface(state *decodeState) { + // Read the name of the concrete type. + b := make([]byte, state.decodeUint()) + _, err := state.b.Read(b) + if err != nil { + error(err) + } + dec.decodeValueFromBuffer(nil, true, true) + if dec.err != nil { + error(err) + } } // Index by Go types. @@ -608,17 +663,18 @@ var decOpMap = []decOp{ // Indexed by gob types. tComplex will be added during type.init(). var decIgnoreOpMap = map[typeId]decOp{ - tBool: ignoreUint, - tInt: ignoreUint, - tUint: ignoreUint, - tFloat: ignoreUint, - tBytes: ignoreUint8Array, - tString: ignoreUint8Array, + tBool: ignoreUint, + tInt: ignoreUint, + tUint: ignoreUint, + tFloat: ignoreUint, + tBytes: ignoreUint8Array, + tString: ignoreUint8Array, + tComplex: ignoreTwoUints, } // Return the decoding op for the base type under rt and // the indirection count to reach it. -func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) { +func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) { typ, indir := indirect(rt) var op decOp k := typ.Kind() @@ -630,32 +686,23 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp switch t := typ.(type) { case *reflect.ArrayType: name = "element of " + name - elemId := dec.wireType[wireId].arrayT.Elem - elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) - if err != nil { - return nil, 0, err - } + elemId := dec.wireType[wireId].ArrayT.Elem + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) + state.dec.decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) } case *reflect.MapType: name = "element of " + name - keyId := dec.wireType[wireId].mapT.Key - elemId := dec.wireType[wireId].mapT.Elem - keyOp, keyIndir, err := dec.decOpFor(keyId, t.Key(), name) - if err != nil { - return nil, 0, err - } - elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) - if err != nil { - return nil, 0, err - } + keyId := dec.wireType[wireId].MapT.Key + elemId := dec.wireType[wireId].MapT.Elem + keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name) + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { up := unsafe.Pointer(p) - state.err = decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl) + state.dec.decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl) } case *reflect.SliceType: @@ -668,111 +715,105 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp if tt, ok := builtinIdToType[wireId]; ok { elemId = tt.(*sliceType).Elem } else { - elemId = dec.wireType[wireId].sliceT.Elem - } - elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) - if err != nil { - return nil, 0, err + elemId = dec.wireType[wireId].SliceT.Elem } + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) + state.dec.decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) } case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. enginePtr, err := dec.getDecEnginePtr(wireId, typ) if err != nil { - return nil, 0, err + error(err) } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { // indirect through enginePtr to delay evaluation for recursive structs - state.err = decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir) + err = dec.decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir) + if err != nil { + error(err) + } + } + case *reflect.InterfaceType: + op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + dec.decodeInterface(t, state, uintptr(p), i.indir) } } } if op == nil { - return nil, 0, os.ErrorString("gob: decode can't handle type " + rt.String()) + errorf("gob: decode can't handle type %s", rt.String()) } - return op, indir, nil + return op, indir } // Return the decoding op for a field that has no destination. -func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { +func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { op, ok := decIgnoreOpMap[wireId] if !ok { + if wireId == tInterface { + // Special case because it's a method: the ignored item might + // define types and we need to record their state in the decoder. + op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + dec.ignoreInterface(state) + } + return op + } // Special cases wire := dec.wireType[wireId] switch { case wire == nil: panic("internal error: can't find ignore op for type " + wireId.string()) - case wire.arrayT != nil: - elemId := wire.arrayT.Elem - elemOp, err := dec.decIgnoreOpFor(elemId) - if err != nil { - return nil, err - } + case wire.ArrayT != nil: + elemId := wire.ArrayT.Elem + elemOp := dec.decIgnoreOpFor(elemId) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = ignoreArray(state, elemOp, wire.arrayT.Len) + state.dec.ignoreArray(state, elemOp, wire.ArrayT.Len) } - case wire.mapT != nil: - keyId := dec.wireType[wireId].mapT.Key - elemId := dec.wireType[wireId].mapT.Elem - keyOp, err := dec.decIgnoreOpFor(keyId) - if err != nil { - return nil, err - } - elemOp, err := dec.decIgnoreOpFor(elemId) - if err != nil { - return nil, err - } + case wire.MapT != nil: + keyId := dec.wireType[wireId].MapT.Key + elemId := dec.wireType[wireId].MapT.Elem + keyOp := dec.decIgnoreOpFor(keyId) + elemOp := dec.decIgnoreOpFor(elemId) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = ignoreMap(state, keyOp, elemOp) + state.dec.ignoreMap(state, keyOp, elemOp) } - case wire.sliceT != nil: - elemId := wire.sliceT.Elem - elemOp, err := dec.decIgnoreOpFor(elemId) - if err != nil { - return nil, err - } + case wire.SliceT != nil: + elemId := wire.SliceT.Elem + elemOp := dec.decIgnoreOpFor(elemId) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = ignoreSlice(state, elemOp) + state.dec.ignoreSlice(state, elemOp) } - case wire.structT != nil: + case wire.StructT != nil: // Generate a closure that calls out to the engine for the nested type. enginePtr, err := dec.getIgnoreEnginePtr(wireId) if err != nil { - return nil, err + error(err) } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { // indirect through enginePtr to delay evaluation for recursive structs - state.err = ignoreStruct(*enginePtr, state.b) + state.dec.ignoreStruct(*enginePtr, state.b) } } } if op == nil { - return nil, os.ErrorString("ignore can't handle type " + wireId.string()) + errorf("ignore can't handle type %s", wireId.string()) } - return op, nil + return op } // Are these two gob Types compatible? // Answers the question for basic types, arrays, and slices. // Structs are considered ok; fields will be checked later. func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { - for { - if pt, ok := fr.(*reflect.PtrType); ok { - fr = pt.Elem() - continue - } - break - } + fr, _ = indirect(fr) switch t := fr.(type) { default: - // interface, map, chan, etc: cannot handle. + // map, chan, etc: cannot handle. return false case *reflect.BoolType: return fw == tBool @@ -786,20 +827,22 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { return fw == tComplex case *reflect.StringType: return fw == tString + case *reflect.InterfaceType: + return fw == tInterface case *reflect.ArrayType: wire, ok := dec.wireType[fw] - if !ok || wire.arrayT == nil { + if !ok || wire.ArrayT == nil { return false } - array := wire.arrayT + array := wire.ArrayT return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem) case *reflect.MapType: wire, ok := dec.wireType[fw] - if !ok || wire.mapT == nil { + if !ok || wire.MapT == nil { return false } - mapType := wire.mapT - return dec.compatibleType(t.Key(), mapType.Key) && dec.compatibleType(t.Elem(), mapType.Elem) + MapType := wire.MapT + return dec.compatibleType(t.Key(), MapType.Key) && dec.compatibleType(t.Elem(), MapType.Elem) case *reflect.SliceType: // Is it an array of bytes? if t.Elem().Kind() == reflect.Uint8 { @@ -810,7 +853,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { if tt, ok := builtinIdToType[fw]; ok { sw = tt.(*sliceType) } else { - sw = dec.wireType[fw].sliceT + sw = dec.wireType[fw].SliceT } elem, _ := indirect(t.Elem()) return sw != nil && dec.compatibleType(elem, sw.Elem) @@ -820,61 +863,74 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { return true } +// typeString returns a human-readable description of the type identified by remoteId. +func (dec *Decoder) typeString(remoteId typeId) string { + if t := idToType[remoteId]; t != nil { + // globally known type. + return t.string() + } + return dec.wireType[remoteId].string() +} + + func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { engine = new(decEngine) engine.instr = make([]decInstr, 1) // one item name := rt.String() // best we can do if !dec.compatibleType(rt, remoteId) { - return nil, os.ErrorString("gob: wrong type received for local value " + name) - } - op, indir, err := dec.decOpFor(remoteId, rt, name) - if err != nil { - return nil, err + return nil, os.ErrorString("gob: wrong type received for local value " + name + ": " + dec.typeString(remoteId)) } + op, indir := dec.decOpFor(remoteId, rt, name) ovfl := os.ErrorString(`value for "` + name + `" out of range`) engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl} engine.numInstr = 1 return } +// Is this an exported - upper case - name? +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { + defer catchError(&err) srt, ok := rt.(*reflect.StructType) if !ok { return dec.compileSingle(remoteId, rt) } var wireStruct *structType - // Builtin types can come from global pool; the rest must be defined by the decoder + // Builtin types can come from global pool; the rest must be defined by the decoder. + // Also we know we're decoding a struct now, so the client must have sent one. if t, ok := builtinIdToType[remoteId]; ok { - wireStruct = t.(*structType) + wireStruct, _ = t.(*structType) } else { - wireStruct = dec.wireType[remoteId].structT + wireStruct = dec.wireType[remoteId].StructT + } + if wireStruct == nil { + errorf("gob: type mismatch in decoder: want struct type %s; got non-struct", rt.String()) } engine = new(decEngine) - engine.instr = make([]decInstr, len(wireStruct.field)) + engine.instr = make([]decInstr, len(wireStruct.Field)) // Loop over the fields of the wire type. - for fieldnum := 0; fieldnum < len(wireStruct.field); fieldnum++ { - wireField := wireStruct.field[fieldnum] + for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ { + wireField := wireStruct.Field[fieldnum] + if wireField.Name == "" { + errorf("gob: empty name for remote field of type %s", wireStruct.Name) + } + ovfl := overflow(wireField.Name) // Find the field of the local type with the same name. - localField, present := srt.FieldByName(wireField.name) - ovfl := overflow(wireField.name) + localField, present := srt.FieldByName(wireField.Name) // TODO(r): anonymous names - if !present { - op, err := dec.decIgnoreOpFor(wireField.id) - if err != nil { - return nil, err - } + if !present || !isExported(wireField.Name) { + op := dec.decIgnoreOpFor(wireField.Id) engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl} continue } - if !dec.compatibleType(localField.Type, wireField.id) { - return nil, os.ErrorString("gob: wrong type (" + - localField.Type.String() + ") for received field " + - wireStruct.name + "." + wireField.name) - } - op, indir, err := dec.decOpFor(wireField.id, localField.Type, localField.Name) - if err != nil { - return nil, err + if !dec.compatibleType(localField.Type, wireField.Id) { + errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) } + op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name) engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl} engine.numInstr++ } @@ -899,7 +955,7 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr return } -// When ignoring data, in effect we compile it into this type +// When ignoring struct data, in effect we compile it into this type type emptyStruct struct{} var emptyStructType = reflect.Typeof(emptyStruct{}) @@ -927,13 +983,13 @@ func (dec *Decoder) decode(wireId typeId, val reflect.Value) os.Error { } engine := *enginePtr if st, ok := rt.(*reflect.StructType); ok { - if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 { + if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { name := rt.Name() return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) } - return decodeStruct(engine, st, dec.state.b, uintptr(val.Addr()), indir) + return dec.decodeStruct(engine, st, dec.state.b, uintptr(val.Addr()), indir) } - return decodeSingle(engine, rt, dec.state.b, uintptr(val.Addr()), indir) + return dec.decodeSingle(engine, rt, dec.state.b, uintptr(val.Addr()), indir) } func init() { |