diff options
author | Rob Pike <r@golang.org> | 2009-11-16 23:32:30 -0800 |
---|---|---|
committer | Rob Pike <r@golang.org> | 2009-11-16 23:32:30 -0800 |
commit | 92ba369b0adaa73e09b9b2764f821d234696df6b (patch) | |
tree | 725ed68c8e8549c4c6e26f6ef1e76aa2a5381879 | |
parent | 7e6b32703d34fa3913817fde6f0a935021fb1ab4 (diff) | |
download | golang-92ba369b0adaa73e09b9b2764f821d234696df6b.tar.gz |
Rework gobs to fix bad bug related to sharing of id's between encoder and decoder side.
Fix is to move all decoder state into the decoder object.
Fixes issue 215.
R=rsc
CC=golang-dev
http://codereview.appspot.com/155077
-rw-r--r-- | src/pkg/gob/codec_test.go | 102 | ||||
-rw-r--r-- | src/pkg/gob/decode.go | 80 | ||||
-rw-r--r-- | src/pkg/gob/decoder.go | 31 | ||||
-rw-r--r-- | src/pkg/gob/encoder.go | 29 | ||||
-rw-r--r-- | src/pkg/gob/encoder_test.go | 116 | ||||
-rw-r--r-- | src/pkg/gob/type.go | 5 |
6 files changed, 138 insertions, 225 deletions
diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index a1491a4a1..c5d070155 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -37,7 +37,6 @@ var encodeT = []EncodeT{ EncodeT{1 << 63, []byte{0xF8, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, } - // Test basic encode/decode routines for unsigned integers func TestUintCodec(t *testing.T) { b := new(bytes.Buffer); @@ -592,9 +591,15 @@ func TestEndToEnd(t *testing.T) { t: &T2{"this is T2"}, }; b := new(bytes.Buffer); - encode(b, t1); + err := NewEncoder(b).Encode(t1); + if err != nil { + t.Error("encode:", err) + } var _t1 T1; - decode(b, getTypeInfoNoError(reflect.Typeof(_t1)).id, &_t1); + err = NewDecoder(b).Decode(&_t1); + if err != nil { + t.Fatal("decode:", err) + } if !reflect.DeepEqual(t1, &_t1) { t.Errorf("encode expected %v got %v", *t1, _t1) } @@ -610,8 +615,9 @@ func TestOverflow(t *testing.T) { } var it inputT; var err os.Error; - id := getTypeInfoNoError(reflect.Typeof(it)).id; b := new(bytes.Buffer); + enc := NewEncoder(b); + dec := NewDecoder(b); // int8 b.Reset(); @@ -623,8 +629,8 @@ func TestOverflow(t *testing.T) { mini int8; } var o1 outi8; - encode(b, it); - err = decode(b, id, &o1); + enc.Encode(it); + err = dec.Decode(&o1); if err == nil || err.String() != `value for "maxi" out of range` { t.Error("wrong overflow error for int8:", err) } @@ -632,8 +638,8 @@ func TestOverflow(t *testing.T) { mini: math.MinInt8 - 1, }; b.Reset(); - encode(b, it); - err = decode(b, id, &o1); + enc.Encode(it); + err = dec.Decode(&o1); if err == nil || err.String() != `value for "mini" out of range` { t.Error("wrong underflow error for int8:", err) } @@ -648,8 +654,8 @@ func TestOverflow(t *testing.T) { mini int16; } var o2 outi16; - encode(b, it); - err = decode(b, id, &o2); + enc.Encode(it); + err = dec.Decode(&o2); if err == nil || err.String() != `value for "maxi" out of range` { t.Error("wrong overflow error for int16:", err) } @@ -657,8 +663,8 @@ func TestOverflow(t *testing.T) { mini: math.MinInt16 - 1, }; b.Reset(); - encode(b, it); - err = decode(b, id, &o2); + enc.Encode(it); + err = dec.Decode(&o2); if err == nil || err.String() != `value for "mini" out of range` { t.Error("wrong underflow error for int16:", err) } @@ -673,8 +679,8 @@ func TestOverflow(t *testing.T) { mini int32; } var o3 outi32; - encode(b, it); - err = decode(b, id, &o3); + enc.Encode(it); + err = dec.Decode(&o3); if err == nil || err.String() != `value for "maxi" out of range` { t.Error("wrong overflow error for int32:", err) } @@ -682,8 +688,8 @@ func TestOverflow(t *testing.T) { mini: math.MinInt32 - 1, }; b.Reset(); - encode(b, it); - err = decode(b, id, &o3); + enc.Encode(it); + err = dec.Decode(&o3); if err == nil || err.String() != `value for "mini" out of range` { t.Error("wrong underflow error for int32:", err) } @@ -697,8 +703,8 @@ func TestOverflow(t *testing.T) { maxu uint8; } var o4 outu8; - encode(b, it); - err = decode(b, id, &o4); + enc.Encode(it); + err = dec.Decode(&o4); if err == nil || err.String() != `value for "maxu" out of range` { t.Error("wrong overflow error for uint8:", err) } @@ -712,8 +718,8 @@ func TestOverflow(t *testing.T) { maxu uint16; } var o5 outu16; - encode(b, it); - err = decode(b, id, &o5); + enc.Encode(it); + err = dec.Decode(&o5); if err == nil || err.String() != `value for "maxu" out of range` { t.Error("wrong overflow error for uint16:", err) } @@ -727,8 +733,8 @@ func TestOverflow(t *testing.T) { maxu uint32; } var o6 outu32; - encode(b, it); - err = decode(b, id, &o6); + enc.Encode(it); + err = dec.Decode(&o6); if err == nil || err.String() != `value for "maxu" out of range` { t.Error("wrong overflow error for uint32:", err) } @@ -743,8 +749,8 @@ func TestOverflow(t *testing.T) { minf float32; } var o7 outf32; - encode(b, it); - err = decode(b, id, &o7); + enc.Encode(it); + err = dec.Decode(&o7); if err == nil || err.String() != `value for "maxf" out of range` { t.Error("wrong overflow error for float32:", err) } @@ -761,9 +767,13 @@ func TestNesting(t *testing.T) { rt.next = new(RT); rt.next.a = "level2"; b := new(bytes.Buffer); - encode(b, rt); + NewEncoder(b).Encode(rt); var drt RT; - decode(b, getTypeInfoNoError(reflect.Typeof(drt)).id, &drt); + dec := NewDecoder(b); + err := dec.Decode(&drt); + if err != nil { + t.Errorf("decoder error:", err) + } if drt.a != rt.a { t.Errorf("nesting: encode expected %v got %v", *rt, drt) } @@ -809,10 +819,11 @@ func TestAutoIndirection(t *testing.T) { **t1.d = new(int); ***t1.d = 17777; b := new(bytes.Buffer); - encode(b, t1); + enc := NewEncoder(b); + enc.Encode(t1); + dec := NewDecoder(b); var t0 T0; - t0Id := getTypeInfoNoError(reflect.Typeof(t0)).id; - decode(b, t0Id, &t0); + dec.Decode(&t0); if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 { t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0) } @@ -830,9 +841,9 @@ func TestAutoIndirection(t *testing.T) { **t2.a = new(int); ***t2.a = 17; b.Reset(); - encode(b, t2); + enc.Encode(t2); t0 = T0{}; - decode(b, t0Id, &t0); + dec.Decode(&t0); if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 { t.Errorf("t2->t0 expected {17 177 1777 17777}; got %v", t0) } @@ -840,32 +851,30 @@ func TestAutoIndirection(t *testing.T) { // Now transfer t0 into t1 t0 = T0{17, 177, 1777, 17777}; b.Reset(); - encode(b, t0); + enc.Encode(t0); t1 = T1{}; - t1Id := getTypeInfoNoError(reflect.Typeof(t1)).id; - decode(b, t1Id, &t1); + dec.Decode(&t1); if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 { t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d) } // Now transfer t0 into t2 b.Reset(); - encode(b, t0); + enc.Encode(t0); t2 = T2{}; - t2Id := getTypeInfoNoError(reflect.Typeof(t2)).id; - decode(b, t2Id, &t2); + dec.Decode(&t2); if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 { t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d) } // Now do t2 again but without pre-allocated pointers. b.Reset(); - encode(b, t0); + enc.Encode(t0); ***t2.a = 0; **t2.b = 0; *t2.c = 0; t2.d = 0; - decode(b, t2Id, &t2); + dec.Decode(&t2); if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 { t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d) } @@ -889,11 +898,14 @@ func TestReorderedFields(t *testing.T) { rt0.b = "hello"; rt0.c = 3.14159; b := new(bytes.Buffer); - encode(b, rt0); - rt0Id := getTypeInfoNoError(reflect.Typeof(rt0)).id; + NewEncoder(b).Encode(rt0); + dec := NewDecoder(b); var rt1 RT1; // Wire type is RT0, local type is RT1. - decode(b, rt0Id, &rt1); + err := dec.Decode(&rt1); + if err != nil { + t.Error("decode error:", err) + } if rt0.a != rt1.a || rt0.b != rt1.b || rt0.c != rt1.c { t.Errorf("rt1->rt0: expected %v; got %v", rt0, rt1) } @@ -927,11 +939,11 @@ func TestIgnoredFields(t *testing.T) { it0.ignore_i = &RT1{3.1, "hi", 7, "hello"}; b := new(bytes.Buffer); - encode(b, it0); - rt0Id := getTypeInfoNoError(reflect.Typeof(it0)).id; + NewEncoder(b).Encode(it0); + dec := NewDecoder(b); var rt1 RT1; // Wire type is IT0, local type is RT1. - err := decode(b, rt0Id, &rt1); + err := dec.Decode(&rt1); if err != nil { t.Error("error: ", err) } diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index a9cdbe684..b8400480c 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -348,7 +348,7 @@ func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { // Execution engine // The encoder engine is an array of instructions indexed by field number of the incoming -// data. It is executed with random access according to field number. +// decoder. It is executed with random access according to field number. type decEngine struct { instr []decInstr; numInstr int; // the number of active instructions @@ -515,7 +515,7 @@ var decIgnoreOpMap = map[typeId]decOp{ // Return the decoding op for the base type under rt and // the indirection count to reach it. -func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) { +func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) { typ, indir := indirect(rt); op, ok := decOpMap[reflect.Typeof(typ)]; if !ok { @@ -528,7 +528,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error break; } elemId := wireId.gobType().(*sliceType).Elem; - elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name); + elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name); if err != nil { return nil, 0, err } @@ -540,7 +540,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error case *reflect.ArrayType: name = "element of " + name; elemId := wireId.gobType().(*arrayType).Elem; - elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name); + elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name); if err != nil { return nil, 0, err } @@ -551,7 +551,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. - enginePtr, err := getDecEnginePtr(wireId, typ); + enginePtr, err := dec.getDecEnginePtr(wireId, typ); if err != nil { return nil, 0, err } @@ -568,14 +568,14 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error } // Return the decoding op for a field that has no destination. -func decIgnoreOpFor(wireId typeId) (decOp, os.Error) { +func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { op, ok := decIgnoreOpMap[wireId]; if !ok { // Special cases switch t := wireId.gobType().(type) { case *sliceType: elemId := wireId.gobType().(*sliceType).Elem; - elemOp, err := decIgnoreOpFor(elemId); + elemOp, err := dec.decIgnoreOpFor(elemId); if err != nil { return nil, err } @@ -585,7 +585,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) { case *arrayType: elemId := wireId.gobType().(*arrayType).Elem; - elemOp, err := decIgnoreOpFor(elemId); + elemOp, err := dec.decIgnoreOpFor(elemId); if err != nil { return nil, err } @@ -595,7 +595,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) { case *structType: // Generate a closure that calls out to the engine for the nested type. - enginePtr, err := getIgnoreEnginePtr(wireId); + enginePtr, err := dec.getIgnoreEnginePtr(wireId); if err != nil { return nil, err } @@ -676,11 +676,18 @@ func compatibleType(fr reflect.Type, fw typeId) bool { return true; } -func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { +func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { srt, ok1 := rt.(*reflect.StructType); - wireStruct, ok2 := wireId.gobType().(*structType); - if !ok1 || !ok2 { - return nil, errNotStruct + var wireStruct *structType; + // Builtin types can come from global pool; the rest must be defined by the decoder + if t, ok := builtinIdToType[remoteId]; ok { + wireStruct = t.(*structType) + } else { + w, ok2 := dec.wireType[remoteId]; + if !ok1 || !ok2 { + return nil, errNotStruct + } + wireStruct = w.s; } engine = new(decEngine); engine.instr = make([]decInstr, len(wireStruct.field)); @@ -692,7 +699,7 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error ovfl := overflow(wireField.name); // TODO(r): anonymous names if !present { - op, err := decIgnoreOpFor(wireField.id); + op, err := dec.decIgnoreOpFor(wireField.id); if err != nil { return nil, err } @@ -700,10 +707,10 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error continue; } if !compatibleType(localField.Type, wireField.id) { - details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + wireId.Name(); + details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + remoteId.Name(); return nil, os.ErrorString("gob: wrong type for field " + wireField.name + details); } - op, indir, err := decOpFor(wireField.id, localField.Type, localField.Name); + op, indir, err := dec.decOpFor(wireField.id, localField.Type, localField.Name); if err != nil { return nil, err } @@ -713,23 +720,19 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error return; } -var decoderCache = make(map[reflect.Type]map[typeId]**decEngine) -var ignorerCache = make(map[typeId]**decEngine) - -// typeLock must be held. -func getDecEnginePtr(wireId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) { - decoderMap, ok := decoderCache[rt]; +func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) { + decoderMap, ok := dec.decoderCache[rt]; if !ok { decoderMap = make(map[typeId]**decEngine); - decoderCache[rt] = decoderMap; + dec.decoderCache[rt] = decoderMap; } - if enginePtr, ok = decoderMap[wireId]; !ok { + if enginePtr, ok = decoderMap[remoteId]; !ok { // To handle recursive types, mark this engine as underway before compiling. enginePtr = new(*decEngine); - decoderMap[wireId] = enginePtr; - *enginePtr, err = compileDec(wireId, rt); + decoderMap[remoteId] = enginePtr; + *enginePtr, err = dec.compileDec(remoteId, rt); if err != nil { - decoderMap[wireId] = nil, false + decoderMap[remoteId] = nil, false } } return; @@ -740,35 +743,28 @@ type emptyStruct struct{} var emptyStructType = reflect.Typeof(emptyStruct{}) -// typeLock must be held. -func getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) { +func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) { var ok bool; - if enginePtr, ok = ignorerCache[wireId]; !ok { + if enginePtr, ok = dec.ignorerCache[wireId]; !ok { // To handle recursive types, mark this engine as underway before compiling. enginePtr = new(*decEngine); - ignorerCache[wireId] = enginePtr; - *enginePtr, err = compileDec(wireId, emptyStructType); + dec.ignorerCache[wireId] = enginePtr; + *enginePtr, err = dec.compileDec(wireId, emptyStructType); if err != nil { - ignorerCache[wireId] = nil, false + dec.ignorerCache[wireId] = nil, false } } return; } -func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error { +func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error { // Dereference down to the underlying struct type. rt, indir := indirect(reflect.Typeof(e)); st, ok := rt.(*reflect.StructType); if !ok { return os.ErrorString("gob: decode can't handle " + rt.String()) } - typeLock.Lock(); - if _, ok := idToType[wireId]; !ok { - typeLock.Unlock(); - return errBadType; - } - enginePtr, err := getDecEnginePtr(wireId, rt); - typeLock.Unlock(); + enginePtr, err := dec.getDecEnginePtr(wireId, rt); if err != nil { return err } @@ -777,7 +773,7 @@ func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error { name := rt.Name(); return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name); } - return decodeStruct(engine, st, b, uintptr(reflect.NewValue(e).Addr()), indir); + return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir); } func init() { diff --git a/src/pkg/gob/decoder.go b/src/pkg/gob/decoder.go index dde5d823f..1713a3e59 100644 --- a/src/pkg/gob/decoder.go +++ b/src/pkg/gob/decoder.go @@ -8,17 +8,20 @@ import ( "bytes"; "io"; "os"; + "reflect"; "sync"; ) // A Decoder manages the receipt of type and data information read from the // remote side of a connection. type Decoder struct { - mutex sync.Mutex; // each item must be received atomically - r io.Reader; // source of the data - seen map[typeId]*wireType; // which types we've already seen described - state *decodeState; // reads data from in-memory buffer - countState *decodeState; // reads counts from wire + mutex sync.Mutex; // each item must be received atomically + r io.Reader; // source of the data + wireType map[typeId]*wireType; // map from remote ID to local description + decoderCache map[reflect.Type]map[typeId]**decEngine; // cache of compiled engines + ignorerCache map[typeId]**decEngine; // ditto for ignored objects + state *decodeState; // reads data from in-memory buffer + countState *decodeState; // reads counts from wire buf []byte; oneByte []byte; } @@ -27,8 +30,10 @@ type Decoder struct { func NewDecoder(r io.Reader) *Decoder { dec := new(Decoder); dec.r = r; - dec.seen = make(map[typeId]*wireType); + dec.wireType = make(map[typeId]*wireType); dec.state = newDecodeState(nil); // buffer set in Decode(); rest is unimportant + dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine); + dec.ignorerCache = make(map[typeId]**decEngine); dec.oneByte = make([]byte, 1); return dec; @@ -36,16 +41,16 @@ func NewDecoder(r io.Reader) *Decoder { func (dec *Decoder) recvType(id typeId) { // Have we already seen this type? That's an error - if _, alreadySeen := dec.seen[id]; alreadySeen { + if _, alreadySeen := dec.wireType[id]; alreadySeen { dec.state.err = os.ErrorString("gob: duplicate type received"); return; } // Type: wire := new(wireType); - decode(dec.state.b, tWireType, wire); + dec.state.err = dec.decode(tWireType, wire); // Remember we've seen this type. - dec.seen[id] = wire; + dec.wireType[id] = wire; } // Decode reads the next value from the connection and stores @@ -97,7 +102,13 @@ func (dec *Decoder) Decode(e interface{}) os.Error { } // No, it's a value. - dec.state.err = decode(dec.state.b, id, e); + // Make sure the type has been defined already. + _, ok := dec.wireType[id]; + if !ok { + dec.state.err = errBadType; + break; + } + dec.state.err = dec.decode(id, e); break; } return dec.state.err; diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index a59309e33..28ecaec93 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -235,19 +235,15 @@ func (enc *Encoder) send() { enc.w.Write(enc.buf[0:total]); } -func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) { +func (enc *Encoder) sendType(origt reflect.Type) { // Drill down to the base type. rt, _ := indirect(origt); // We only send structs - everything else is basic or an error - switch rt.(type) { + switch rt := rt.(type) { default: - // Basic types do not need to be described, but if this is a top-level - // type, it's a user error, at least for now. - if topLevel { - enc.badType(rt) - } - return; + // Basic types do not need to be described. + return case *reflect.StructType: // Structs do need to be described. break @@ -255,10 +251,9 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) { // Probably a bad field in a struct. enc.badType(rt); return; - case *reflect.ArrayType, *reflect.SliceType: - // Array and slice types are not sent, only their element types. - // If we see one here it's user error; probably a bad top-level value. - enc.badType(rt); + // Array and slice types are not sent, only their element types. + case reflect.ArrayOrSliceType: + enc.sendType(rt.Elem()); return; } @@ -289,7 +284,7 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) { // Now send the inner types st := rt.(*reflect.StructType); for i := 0; i < st.NumField(); i++ { - enc.sendType(st.Field(i).Type, false) + enc.sendType(st.Field(i).Type) } return; } @@ -301,6 +296,12 @@ func (enc *Encoder) Encode(e interface{}) os.Error { panicln("Encoder: buffer not empty") } rt, _ := indirect(reflect.Typeof(e)); + // Must be a struct + if _, ok := rt.(*reflect.StructType); !ok { + enc.badType(rt); + return enc.state.err; + } + // Make sure we're single-threaded through here. enc.mutex.Lock(); @@ -310,7 +311,7 @@ func (enc *Encoder) Encode(e interface{}) os.Error { // First, have we already sent this type? if _, alreadySent := enc.sent[rt]; !alreadySent { // No, so send it. - enc.sendType(rt, true); + enc.sendType(rt); if enc.state.err != nil { enc.state.b.Reset(); enc.countState.b.Reset(); diff --git a/src/pkg/gob/encoder_test.go b/src/pkg/gob/encoder_test.go index e850bceae..43d3e72ed 100644 --- a/src/pkg/gob/encoder_test.go +++ b/src/pkg/gob/encoder_test.go @@ -32,122 +32,10 @@ type ET3 struct { // Like ET1 but with a different type for a field type ET4 struct { a int; - et2 *ET1; + et2 float; next int; } -func TestBasicEncoder(t *testing.T) { - b := new(bytes.Buffer); - enc := NewEncoder(b); - et1 := new(ET1); - et1.a = 7; - et1.et2 = new(ET2); - enc.Encode(et1); - if enc.state.err != nil { - t.Error("encoder fail:", enc.state.err) - } - - // Decode the result by hand to verify; - state := newDecodeState(b); - // The output should be: - // 0) The length, 38. - length := decodeUint(state); - if length != 38 { - t.Fatal("0. expected length 38; got", length) - } - // 1) -7: the type id of ET1 - id1 := decodeInt(state); - if id1 >= 0 { - t.Fatal("expected ET1 negative id; got", id1) - } - // 2) The wireType for ET1 - wire1 := new(wireType); - err := decode(b, tWireType, wire1); - if err != nil { - t.Fatal("error decoding ET1 type:", err) - } - info := getTypeInfoNoError(reflect.Typeof(ET1{})); - trueWire1 := &wireType{s: info.id.gobType().(*structType)}; - if !reflect.DeepEqual(wire1, trueWire1) { - t.Fatalf("invalid wireType for ET1: expected %+v; got %+v\n", *trueWire1, *wire1) - } - // 3) The length, 21. - length = decodeUint(state); - if length != 21 { - t.Fatal("3. expected length 21; got", length) - } - // 4) -8: the type id of ET2 - id2 := decodeInt(state); - if id2 >= 0 { - t.Fatal("expected ET2 negative id; got", id2) - } - // 5) The wireType for ET2 - wire2 := new(wireType); - err = decode(b, tWireType, wire2); - if err != nil { - t.Fatal("error decoding ET2 type:", err) - } - info = getTypeInfoNoError(reflect.Typeof(ET2{})); - trueWire2 := &wireType{s: info.id.gobType().(*structType)}; - if !reflect.DeepEqual(wire2, trueWire2) { - t.Fatalf("invalid wireType for ET2: expected %+v; got %+v\n", *trueWire2, *wire2) - } - // 6) The length, 6. - length = decodeUint(state); - if length != 6 { - t.Fatal("6. expected length 6; got", length) - } - // 7) The type id for the et1 value - newId1 := decodeInt(state); - if newId1 != -id1 { - t.Fatal("expected Et1 id", -id1, "got", newId1) - } - // 8) The value of et1 - newEt1 := new(ET1); - et1Id := getTypeInfoNoError(reflect.Typeof(*newEt1)).id; - err = decode(b, et1Id, newEt1); - if err != nil { - t.Fatal("error decoding ET1 value:", err) - } - if !reflect.DeepEqual(et1, newEt1) { - t.Fatalf("invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1) - } - // 9) EOF - if b.Len() != 0 { - t.Error("not at eof;", b.Len(), "bytes left") - } - - // Now do it again. This time we should see only the type id and value. - b.Reset(); - enc.Encode(et1); - if enc.state.err != nil { - t.Error("2nd round: encoder fail:", enc.state.err) - } - // The length. - length = decodeUint(state); - if length != 6 { - t.Fatal("6. expected length 6; got", length) - } - // 5a) The type id for the et1 value - newId1 = decodeInt(state); - if newId1 != -id1 { - t.Fatal("2nd round: expected Et1 id", -id1, "got", newId1) - } - // 6a) The value of et1 - newEt1 = new(ET1); - err = decode(b, et1Id, newEt1); - if err != nil { - t.Fatal("2nd round: error decoding ET1 value:", err) - } - if !reflect.DeepEqual(et1, newEt1) { - t.Fatalf("2nd round: invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1) - } - // 7a) EOF - if b.Len() != 0 { - t.Error("2nd round: not at eof;", b.Len(), "bytes left") - } -} - func TestEncoderDecoder(t *testing.T) { b := new(bytes.Buffer); enc := NewEncoder(b); @@ -215,7 +103,7 @@ func badTypeCheck(e interface{}, shouldFail bool, msg string, t *testing.T) { t.Error("expected error for", msg) } if !shouldFail && (dec.state.err != nil) { - t.Error("unexpected error for", msg) + t.Error("unexpected error for", msg, dec.state.err) } } diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index 92db7cef3..ffff0541d 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -48,6 +48,7 @@ type gobType interface { var types = make(map[reflect.Type]gobType) var idToType = make(map[typeId]gobType) +var builtinIdToType map[typeId]gobType // set in init() after builtins are established func setTypeId(typ gobType) { nextId++; @@ -104,6 +105,10 @@ func init() { checkId(8, getTypeInfoNoError(reflect.Typeof(structType{})).id); checkId(9, getTypeInfoNoError(reflect.Typeof(commonType{})).id); checkId(10, getTypeInfoNoError(reflect.Typeof(fieldType{})).id); + builtinIdToType = make(map[typeId]gobType); + for k, v := range idToType { + builtinIdToType[k] = v + } } // Array type |