diff options
Diffstat (limited to 'src/pkg/gob/decode.go')
-rw-r--r-- | src/pkg/gob/decode.go | 148 |
1 files changed, 123 insertions, 25 deletions
diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 3b14841af..fb1e99367 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -447,6 +447,49 @@ func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) } +func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { + instr := &decInstr{op, 0, indir, 0, ovfl} + up := unsafe.Pointer(v.Addr()) + if indir > 1 { + up = decIndirect(up, indir) + } + op(instr, state, up) + return v +} + +func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error { + if indir > 0 { + up := unsafe.Pointer(p) + if *(*unsafe.Pointer)(up) == nil { + // Allocate object. + *(*unsafe.Pointer)(up) = unsafe.New(mtyp) + } + p = *(*uintptr)(up) + } + up := unsafe.Pointer(p) + if *(*unsafe.Pointer)(up) == nil { // maps are represented as a pointer in the runtime + // Allocate map. + *(*unsafe.Pointer)(up) = unsafe.Pointer(reflect.MakeMap(mtyp).Get()) + } + // Maps cannot be accessed by moving addresses around the way + // that slices etc. can. We must recover a full reflection value for + // the iteration. + v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue) + n := int(decodeUint(state)) + for i := 0; i < n && state.err == nil; i++ { + key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl) + if state.err != nil { + break + } + elem := decodeIntoValue(state, elemOp, elemIndir, reflect.MakeZero(mtyp.Elem()), ovfl) + if state.err != nil { + break + } + v.SetElem(key, elem) + } + return state.err +} + func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error { instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} for i := 0; i < length && state.err == nil; i++ { @@ -462,6 +505,18 @@ func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error { return ignoreArrayHelper(state, elemOp, length) } +func ignoreMap(state *decodeState, keyOp, elemOp decOp) os.Error { + n := int(decodeUint(state)) + keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")} + elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} + for i := 0; i < n && state.err == nil; i++ { + keyOp(keyInstr, state, nil) + elemOp(elemInstr, state, nil) + } + return state.err +} + + func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error { n := int(uintptr(decodeUint(state))) if indir > 0 { @@ -517,17 +572,25 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp if !ok { // Special cases switch t := typ.(type) { - case *reflect.SliceType: + case *reflect.ArrayType: name = "element of " + name - if _, ok := t.Elem().(*reflect.Uint8Type); ok { - op = decUint8Array - break + elemId := dec.wireType[wireId].arrayT.Elem + elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) + if err != nil { + return nil, 0, err } - var elemId typeId - if tt, ok := builtinIdToType[wireId]; ok { - elemId = tt.(*sliceType).Elem - } else { - elemId = dec.wireType[wireId].slice.Elem + ovfl := overflow(name) + op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) + } + + case *reflect.MapType: + name = "element of " + name + keyId := dec.wireType[wireId].mapT.Key + elemId := dec.wireType[wireId].mapT.Elem + keyOp, keyIndir, err := dec.decOpFor(keyId, t.Key(), name) + if err != nil { + return nil, 0, err } elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) if err != nil { @@ -535,19 +598,32 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp } ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) + up := unsafe.Pointer(p) + if indir > 1 { + up = decIndirect(up, indir) + } + state.err = decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl) } - case *reflect.ArrayType: + case *reflect.SliceType: name = "element of " + name - elemId := dec.wireType[wireId].array.Elem + if _, ok := t.Elem().(*reflect.Uint8Type); ok { + op = decUint8Array + break + } + var elemId typeId + if tt, ok := builtinIdToType[wireId]; ok { + elemId = tt.(*sliceType).Elem + } else { + elemId = dec.wireType[wireId].sliceT.Elem + } elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) if err != nil { return nil, 0, err } ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) + state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) } case *reflect.StructType: @@ -575,18 +651,33 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { // Special cases wire := dec.wireType[wireId] switch { - case wire.array != nil: - elemId := wire.array.Elem + case wire.arrayT != nil: + elemId := wire.arrayT.Elem elemOp, err := dec.decIgnoreOpFor(elemId) if err != nil { return nil, err } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = ignoreArray(state, elemOp, wire.array.Len) + state.err = ignoreArray(state, elemOp, wire.arrayT.Len) } - case wire.slice != nil: - elemId := wire.slice.Elem + case wire.mapT != nil: + keyId := dec.wireType[wireId].mapT.Key + elemId := dec.wireType[wireId].mapT.Elem + keyOp, err := dec.decIgnoreOpFor(keyId) + if err != nil { + return nil, err + } + elemOp, err := dec.decIgnoreOpFor(elemId) + if err != nil { + return nil, err + } + op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + state.err = ignoreMap(state, keyOp, elemOp) + } + + case wire.sliceT != nil: + elemId := wire.sliceT.Elem elemOp, err := dec.decIgnoreOpFor(elemId) if err != nil { return nil, err @@ -595,7 +686,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { state.err = ignoreSlice(state, elemOp) } - case wire.strct != nil: + case wire.structT != nil: // Generate a closure that calls out to the engine for the nested type. enginePtr, err := dec.getIgnoreEnginePtr(wireId) if err != nil { @@ -640,11 +731,18 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { return fw == tString case *reflect.ArrayType: wire, ok := dec.wireType[fw] - if !ok || wire.array == nil { + if !ok || wire.arrayT == nil { + return false + } + array := wire.arrayT + return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem) + case *reflect.MapType: + wire, ok := dec.wireType[fw] + if !ok || wire.mapT == nil { return false } - array := wire.array - return ok && t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem) + mapType := wire.mapT + return dec.compatibleType(t.Key(), mapType.Key) && dec.compatibleType(t.Elem(), mapType.Elem) case *reflect.SliceType: // Is it an array of bytes? et := t.Elem() @@ -656,7 +754,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { if tt, ok := builtinIdToType[fw]; ok { sw = tt.(*sliceType) } else { - sw = dec.wireType[fw].slice + sw = dec.wireType[fw].sliceT } elem, _ := indirect(t.Elem()) return sw != nil && dec.compatibleType(elem, sw.Elem) @@ -677,7 +775,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng if !ok1 || !ok2 { return nil, errNotStruct } - wireStruct = w.strct + wireStruct = w.structT } engine = new(decEngine) engine.instr = make([]decInstr, len(wireStruct.field)) @@ -760,7 +858,7 @@ func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error { return err } engine := *enginePtr - if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].strct.field) > 0 { + if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 { name := rt.Name() return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) } |