diff options
author | Rob Pike <r@golang.org> | 2010-06-28 14:09:47 -0700 |
---|---|---|
committer | Rob Pike <r@golang.org> | 2010-06-28 14:09:47 -0700 |
commit | 38ce83e85d620f191bd6f5f7de1fe0e22ed538cf (patch) | |
tree | b0ff633bd6547899a6df0747d8beacdf735e43b6 /src/pkg | |
parent | 0f603c351f4cfffac3fb0b637be5664ab130083a (diff) | |
download | golang-38ce83e85d620f191bd6f5f7de1fe0e22ed538cf.tar.gz |
gob: allow transmission of things other than structs at the top level.
also fix a bug handling nil maps: before, would needlessly send empty map
R=rsc
CC=golang-dev
http://codereview.appspot.com/1739043
Diffstat (limited to 'src/pkg')
-rw-r--r-- | src/pkg/gob/codec_test.go | 19 | ||||
-rw-r--r-- | src/pkg/gob/decode.go | 113 | ||||
-rw-r--r-- | src/pkg/gob/decoder.go | 5 | ||||
-rw-r--r-- | src/pkg/gob/encode.go | 123 | ||||
-rw-r--r-- | src/pkg/gob/encoder.go | 22 | ||||
-rw-r--r-- | src/pkg/gob/encoder_test.go | 65 |
6 files changed, 234 insertions, 113 deletions
diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index d8bdf2d2f..49a13e84d 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -1039,7 +1039,7 @@ func TestInvalidField(t *testing.T) { type Indirect struct { a ***[3]int s ***[]int - m ***map[string]int + m ****map[string]int } type Direct struct { @@ -1059,10 +1059,11 @@ func TestIndirectSliceMapArray(t *testing.T) { *i.s = new(*[]int) **i.s = new([]int) ***i.s = []int{4, 5, 6} - i.m = new(**map[string]int) - *i.m = new(*map[string]int) - **i.m = new(map[string]int) - ***i.m = map[string]int{"one": 1, "two": 2, "three": 3} + i.m = new(***map[string]int) + *i.m = new(**map[string]int) + **i.m = new(*map[string]int) + ***i.m = new(map[string]int) + ****i.m = map[string]int{"one": 1, "two": 2, "three": 3} b := new(bytes.Buffer) NewEncoder(b).Encode(i) dec := NewDecoder(b) @@ -1093,12 +1094,12 @@ func TestIndirectSliceMapArray(t *testing.T) { t.Error("error: ", err) } if len(***i.a) != 3 || (***i.a)[0] != 11 || (***i.a)[1] != 22 || (***i.a)[2] != 33 { - t.Errorf("indirect to direct: ***i.a is %v not %v", ***i.a, d.a) + t.Errorf("direct to indirect: ***i.a is %v not %v", ***i.a, d.a) } if len(***i.s) != 3 || (***i.s)[0] != 44 || (***i.s)[1] != 55 || (***i.s)[2] != 66 { - t.Errorf("indirect to direct: ***i.s is %v not %v", ***i.s, ***i.s) + t.Errorf("direct to indirect: ***i.s is %v not %v", ***i.s, ***i.s) } - if len(***i.m) != 3 || (***i.m)["four"] != 4 || (***i.m)["five"] != 5 || (***i.m)["six"] != 6 { - t.Errorf("indirect to direct: ***i.m is %v not %v", ***i.m, d.m) + if len(****i.m) != 3 || (****i.m)["four"] != 4 || (****i.m)["five"] != 5 || (****i.m)["six"] != 6 { + t.Errorf("direct to indirect: ****i.m is %v not %v", ****i.m, d.m) } } diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 0dbf81488..51e439900 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -13,15 +13,13 @@ import ( "math" "os" "reflect" - "runtime" "unsafe" ) var ( - errBadUint = os.ErrorString("gob: encoded unsigned integer out of range") - errBadType = os.ErrorString("gob: unknown type id or corrupted data") - errRange = os.ErrorString("gob: internal error: field numbers out of bounds") - errNotStruct = os.ErrorString("gob: TODO: can only handle structs") + errBadUint = os.ErrorString("gob: encoded unsigned integer out of range") + errBadType = os.ErrorString("gob: unknown type id or corrupted data") + errRange = os.ErrorString("gob: internal error: field numbers out of bounds") ) // The global execution state of an instance of the decoder. @@ -389,18 +387,44 @@ type decEngine struct { numInstr int // the number of active instructions } -func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error { - if indir > 0 { - up := unsafe.Pointer(p) - if indir > 1 { - up = decIndirect(up, indir) - } - if *(*unsafe.Pointer)(up) == nil { - // Allocate object. - *(*unsafe.Pointer)(up) = unsafe.New((*runtime.StructType)(unsafe.Pointer(rtyp))) - } - p = *(*uintptr)(up) +// allocate makes sure storage is available for an object of underlying type rtyp +// that is indir levels of indirection through p. +func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { + if indir == 0 { + return p + } + up := unsafe.Pointer(p) + if indir > 1 { + up = decIndirect(up, indir) } + if *(*unsafe.Pointer)(up) == nil { + // Allocate object. + *(*unsafe.Pointer)(up) = unsafe.New(rtyp) + } + return *(*uintptr)(up) +} + +func decodeSingle(engine *decEngine, rtyp reflect.Type, b *bytes.Buffer, p uintptr, indir int) os.Error { + p = allocate(rtyp, p, indir) + state := newDecodeState(b) + state.fieldnum = singletonField + basep := p + delta := int(decodeUint(state)) + if delta != 0 { + state.err = os.ErrorString("gob decode: corrupted data: non-zero delta for singleton") + return state.err + } + instr := &engine.instr[singletonField] + ptr := unsafe.Pointer(basep) // offset will be zero + if instr.indir > 1 { + ptr = decIndirect(ptr, instr.indir) + } + instr.op(instr, state, ptr) + return state.err +} + +func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error { + p = allocate(rtyp, p, indir) state := newDecodeState(b) state.fieldnum = -1 basep := p @@ -468,12 +492,7 @@ func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uint func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) os.Error { if indir > 0 { - up := unsafe.Pointer(p) - if *(*unsafe.Pointer)(up) == nil { - // Allocate object. - *(*unsafe.Pointer)(up) = unsafe.New(atyp) - } - p = *(*uintptr)(up) + p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect } if n := decodeUint(state); n != uint64(length) { return os.ErrorString("gob: length mismatch in decodeArray") @@ -493,12 +512,7 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error { if indir > 0 { - up := unsafe.Pointer(p) - if *(*unsafe.Pointer)(up) == nil { - // Allocate object. - *(*unsafe.Pointer)(up) = unsafe.New(mtyp) - } - p = *(*uintptr)(up) + p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect } up := unsafe.Pointer(p) if *(*unsafe.Pointer)(up) == nil { // maps are represented as a pointer in the runtime @@ -806,18 +820,34 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { return true } +func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { + engine = new(decEngine) + engine.instr = make([]decInstr, 1) // one item + name := rt.String() // best we can do + if !dec.compatibleType(rt, remoteId) { + return nil, os.ErrorString("gob: wrong type received for local value " + name) + } + op, indir, err := dec.decOpFor(remoteId, rt, name) + if err != nil { + return nil, err + } + ovfl := os.ErrorString(`value for "` + name + `" out of range`) + engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl} + engine.numInstr = 1 + return +} + func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { - srt, ok1 := rt.(*reflect.StructType) + srt, ok := rt.(*reflect.StructType) + if !ok { + return dec.compileSingle(remoteId, rt) + } var wireStruct *structType // Builtin types can come from global pool; the rest must be defined by the decoder if t, ok := builtinIdToType[remoteId]; ok { wireStruct = t.(*structType) } else { - w, ok2 := dec.wireType[remoteId] - if !ok1 || !ok2 { - return nil, errNotStruct - } - wireStruct = w.structT + wireStruct = dec.wireType[remoteId].structT } engine = new(decEngine) engine.instr = make([]decInstr, len(wireStruct.field)) @@ -891,20 +921,19 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error { // Dereference down to the underlying struct type. rt, indir := indirect(reflect.Typeof(e)) - st, ok := rt.(*reflect.StructType) - if !ok { - return os.ErrorString("gob: decode can't handle " + rt.String()) - } enginePtr, err := dec.getDecEnginePtr(wireId, rt) if err != nil { return err } engine := *enginePtr - if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 { - name := rt.Name() - return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) + if st, ok := rt.(*reflect.StructType); ok { + if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 { + name := rt.Name() + return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) + } + return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir) } - return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir) + return decodeSingle(engine, rt, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir) } func init() { diff --git a/src/pkg/gob/decoder.go b/src/pkg/gob/decoder.go index 90dc2e34c..caec51712 100644 --- a/src/pkg/gob/decoder.go +++ b/src/pkg/gob/decoder.go @@ -108,8 +108,9 @@ func (dec *Decoder) Decode(e interface{}) os.Error { } // No, it's a value. - // Make sure the type has been defined already. - if dec.wireType[id] == nil { + // Make sure the type has been defined already or is a builtin type (for + // top-level singleton values). + if dec.wireType[id] == nil && builtinIdToType[id] == nil { dec.state.err = errBadType break } diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index b48c1f698..a7d44ecc2 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -271,7 +271,7 @@ const uint64Size = unsafe.Sizeof(uint64(0)) type encoderState struct { b *bytes.Buffer err os.Error // error encountered during encoding. - inArray bool // encoding an array element or map key/value pair + sendZero bool // encoding an array element or map key/value pair; send zero values fieldnum int // the last field number written. buf [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation. } @@ -352,7 +352,7 @@ func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) { b := *(*bool)(p) - if b || state.inArray { + if b || state.sendZero { state.update(i) if b { encodeUint(state, 1) @@ -364,7 +364,7 @@ func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) { func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int)(p)) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeInt(state, v) } @@ -372,7 +372,7 @@ func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint)(p)) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeUint(state, v) } @@ -380,7 +380,7 @@ func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) { func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int8)(p)) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeInt(state, v) } @@ -388,7 +388,7 @@ func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint8)(p)) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeUint(state, v) } @@ -396,7 +396,7 @@ func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) { func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int16)(p)) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeInt(state, v) } @@ -404,7 +404,7 @@ func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint16)(p)) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeUint(state, v) } @@ -412,7 +412,7 @@ func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) { func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int32)(p)) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeInt(state, v) } @@ -420,7 +420,7 @@ func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint32)(p)) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeUint(state, v) } @@ -428,7 +428,7 @@ func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) { func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) { v := *(*int64)(p) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeInt(state, v) } @@ -436,7 +436,7 @@ func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) { v := *(*uint64)(p) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeUint(state, v) } @@ -444,7 +444,7 @@ func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) { func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uintptr)(p)) - if v != 0 || state.inArray { + if v != 0 || state.sendZero { state.update(i) encodeUint(state, v) } @@ -468,7 +468,7 @@ func floatBits(f float64) uint64 { func encFloat(i *encInstr, state *encoderState, p unsafe.Pointer) { f := *(*float)(p) - if f != 0 || state.inArray { + if f != 0 || state.sendZero { v := floatBits(float64(f)) state.update(i) encodeUint(state, v) @@ -477,7 +477,7 @@ func encFloat(i *encInstr, state *encoderState, p unsafe.Pointer) { func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) { f := *(*float32)(p) - if f != 0 || state.inArray { + if f != 0 || state.sendZero { v := floatBits(float64(f)) state.update(i) encodeUint(state, v) @@ -486,7 +486,7 @@ func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) { func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) { f := *(*float64)(p) - if f != 0 || state.inArray { + if f != 0 || state.sendZero { state.update(i) v := floatBits(f) encodeUint(state, v) @@ -496,7 +496,7 @@ func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) { // Complex numbers are just a pair of floating-point numbers, real part first. func encComplex(i *encInstr, state *encoderState, p unsafe.Pointer) { c := *(*complex)(p) - if c != 0+0i || state.inArray { + if c != 0+0i || state.sendZero { rpart := floatBits(float64(real(c))) ipart := floatBits(float64(imag(c))) state.update(i) @@ -507,7 +507,7 @@ func encComplex(i *encInstr, state *encoderState, p unsafe.Pointer) { func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) { c := *(*complex64)(p) - if c != 0+0i || state.inArray { + if c != 0+0i || state.sendZero { rpart := floatBits(float64(real(c))) ipart := floatBits(float64(imag(c))) state.update(i) @@ -518,7 +518,7 @@ func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) { func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) { c := *(*complex128)(p) - if c != 0+0i || state.inArray { + if c != 0+0i || state.sendZero { rpart := floatBits(real(c)) ipart := floatBits(imag(c)) state.update(i) @@ -530,7 +530,7 @@ func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) { // Byte arrays are encoded as an unsigned count followed by the raw bytes. func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { b := *(*[]byte)(p) - if len(b) > 0 || state.inArray { + if len(b) > 0 || state.sendZero { state.update(i) encodeUint(state, uint64(len(b))) state.b.Write(b) @@ -540,7 +540,7 @@ func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { // Strings are encoded as an unsigned count followed by the raw bytes. func encString(i *encInstr, state *encoderState, p unsafe.Pointer) { s := *(*string)(p) - if len(s) > 0 || state.inArray { + if len(s) > 0 || state.sendZero { state.update(i) encodeUint(state, uint64(len(s))) io.WriteString(state.b, s) @@ -560,6 +560,26 @@ type encEngine struct { instr []encInstr } +const singletonField = 0 + +func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error { + state := new(encoderState) + state.b = b + state.fieldnum = singletonField + // There is no surrounding struct to frame the transmission, so we must + // generate data even if the item is zero. To do this, set sendZero. + state.sendZero = true + instr := &engine.instr[singletonField] + p := unsafe.Pointer(basep) // offset will be zero + if instr.indir > 0 { + if p = encIndirect(p, instr.indir); p == nil { + return nil + } + } + instr.op(instr, state, p) + return state.err +} + func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error { state := new(encoderState) state.b = b @@ -584,7 +604,7 @@ func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndi state := new(encoderState) state.b = b state.fieldnum = -1 - state.inArray = true + state.sendZero = true encodeUint(state, uint64(length)) for i := 0; i < length && state.err == nil; i++ { elemp := p @@ -607,22 +627,17 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in v = reflect.Indirect(v) } if v == nil { - state.err = os.ErrorString("gob: encodeMap: nil element") + state.err = os.ErrorString("gob: encodeReflectValue: nil element") return } op(nil, state, unsafe.Pointer(v.Addr())) } -func encodeMap(b *bytes.Buffer, rt reflect.Type, p uintptr, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error { +func encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error { state := new(encoderState) state.b = b state.fieldnum = -1 - state.inArray = true - // Maps cannot be accessed by moving addresses around the way - // that slices etc. can. We must recover a full reflection value for - // the iteration. - v := reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer((p)))) - mv := reflect.Indirect(v).(*reflect.MapValue) + state.sendZero = true keys := mv.Keys() encodeUint(state, uint64(len(keys))) for _, key := range keys { @@ -694,6 +709,10 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) { return nil, 0, err } op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { + slice := (*reflect.SliceHeader)(p) + if slice.Len == 0 { + return + } state.update(i) state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len()) } @@ -707,8 +726,16 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) { return nil, 0, err } op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { + // Maps cannot be accessed by moving addresses around the way + // that slices etc. can. We must recover a full reflection value for + // the iteration. + v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p)))) + mv := reflect.Indirect(v).(*reflect.MapValue) + if mv.Len() == 0 { + return + } state.update(i) - state.err = encodeMap(state.b, typ, uintptr(p), keyOp, elemOp, keyIndir, elemIndir) + state.err = encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir) } case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. @@ -732,21 +759,27 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) { // The local Type was compiled from the actual value, so we know it's compatible. func compileEnc(rt reflect.Type) (*encEngine, os.Error) { - srt, ok := rt.(*reflect.StructType) - if !ok { - panic("can't happen: non-struct") - } + srt, isStruct := rt.(*reflect.StructType) engine := new(encEngine) - engine.instr = make([]encInstr, srt.NumField()+1) // +1 for terminator - for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ { - f := srt.Field(fieldnum) - op, indir, err := encOpFor(f.Type) + if isStruct { + engine.instr = make([]encInstr, srt.NumField()+1) // +1 for terminator + for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ { + f := srt.Field(fieldnum) + op, indir, err := encOpFor(f.Type) + if err != nil { + return nil, err + } + engine.instr[fieldnum] = encInstr{op, fieldnum, indir, uintptr(f.Offset)} + } + engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0} + } else { + engine.instr = make([]encInstr, 1) + op, indir, err := encOpFor(rt) if err != nil { return nil, err } - engine.instr[fieldnum] = encInstr{op, fieldnum, indir, uintptr(f.Offset)} + engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero } - engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0} return engine, nil } @@ -772,14 +805,14 @@ func encode(b *bytes.Buffer, e interface{}) os.Error { for i := 0; i < indir; i++ { v = reflect.Indirect(v) } - if _, ok := v.(*reflect.StructValue); !ok { - return os.ErrorString("gob: encode can't handle " + v.Type().String()) - } typeLock.Lock() engine, err := getEncEngine(rt) typeLock.Unlock() if err != nil { return err } - return encodeStruct(engine, b, v.Addr()) + if _, ok := v.(*reflect.StructValue); ok { + return encodeStruct(engine, b, v.Addr()) + } + return encodeSingle(engine, b, v.Addr()) } diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index e24c18d20..28cf6f6e0 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -68,7 +68,7 @@ func (enc *Encoder) send() { } } -func (enc *Encoder) sendType(origt reflect.Type) { +func (enc *Encoder) sendType(origt reflect.Type) (sent bool) { // Drill down to the base type. rt, _ := indirect(origt) @@ -147,11 +147,6 @@ func (enc *Encoder) Encode(e interface{}) os.Error { enc.state.err = nil rt, _ := indirect(reflect.Typeof(e)) - // Must be a struct - if _, ok := rt.(*reflect.StructType); !ok { - enc.badType(rt) - return enc.state.err - } // Sanity check only: encoder should never come in with data present. if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 { @@ -163,10 +158,23 @@ func (enc *Encoder) Encode(e interface{}) os.Error { // First, have we already sent this type? if _, alreadySent := enc.sent[rt]; !alreadySent { // No, so send it. - enc.sendType(rt) + sent := enc.sendType(rt) if enc.state.err != nil { return enc.state.err } + // If the type info has still not been transmitted, it means we have + // a singleton basic type (int, []byte etc.) at top level. We don't + // need to send the type info but we do need to update enc.sent. + if !sent { + typeLock.Lock() + info, err := getTypeInfo(rt) + typeLock.Unlock() + if err != nil { + enc.setError(err) + return err + } + enc.sent[rt] = info.id + } } // Identify the type of this top-level value. diff --git a/src/pkg/gob/encoder_test.go b/src/pkg/gob/encoder_test.go index 4250b8a9d..b578cd0f8 100644 --- a/src/pkg/gob/encoder_test.go +++ b/src/pkg/gob/encoder_test.go @@ -131,17 +131,10 @@ func TestBadData(t *testing.T) { corruptDataCheck("\x03now is the time for all good men", errBadType, t) } -// Types not supported by the Encoder (only structs work at the top level). -// Basic types work implicitly. +// Types not supported by the Encoder. var unsupportedValues = []interface{}{ - 3, - "hi", - 7.2, - []int{1, 2, 3}, - [3]int{1, 2, 3}, make(chan int), func(a int) bool { return true }, - make(map[string]int), new(interface{}), } @@ -275,3 +268,59 @@ func TestDefaultsInArray(t *testing.T) { t.Error(err) } } + +var testInt int +var testFloat32 float32 +var testString string +var testSlice []string +var testMap map[string]int + +type SingleTest struct { + in interface{} + out interface{} + err string +} + +var singleTests = []SingleTest{ + SingleTest{17, &testInt, ""}, + SingleTest{float32(17.5), &testFloat32, ""}, + SingleTest{"bike shed", &testString, ""}, + SingleTest{[]string{"bike", "shed", "paint", "color"}, &testSlice, ""}, + SingleTest{map[string]int{"seven": 7, "twelve": 12}, &testMap, ""}, + + // Decode errors + SingleTest{172, &testFloat32, "wrong type"}, +} + +func TestSingletons(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + dec := NewDecoder(b) + for _, test := range singleTests { + b.Reset() + err := enc.Encode(test.in) + if err != nil { + t.Errorf("error encoding %v: %s", test.in, err) + continue + } + err = dec.Decode(test.out) + switch { + case err != nil && test.err == "": + t.Errorf("error decoding %v: %s", test.in, err) + continue + case err == nil && test.err != "": + t.Errorf("expected error decoding %v: %s", test.in, test.err) + continue + case err != nil && test.err != "": + if strings.Index(err.String(), test.err) < 0 { + t.Errorf("wrong error decoding %v: wanted %s, got %v", test.in, test.err, err) + } + continue + } + // Get rid of the pointer in the rhs + val := reflect.NewValue(test.out).(*reflect.PtrValue).Elem().Interface() + if !reflect.DeepEqual(test.in, val) { + t.Errorf("decoding int: expected %v got %v", test.in, val) + } + } +} |