diff options
author | Rob Pike <r@golang.org> | 2010-05-05 16:46:39 -0700 |
---|---|---|
committer | Rob Pike <r@golang.org> | 2010-05-05 16:46:39 -0700 |
commit | 69dd710bc51d7feded67d58132cbc3a7055f16a3 (patch) | |
tree | 9fb90ef889824db02c2794c93dbed35ea8b969b9 | |
parent | 1dd79d8f4bf45a7ebdb7b006bb02188b5561ce16 (diff) | |
download | golang-69dd710bc51d7feded67d58132cbc3a7055f16a3.tar.gz |
gob: add support for maps.
Because maps are mostly a hidden type, they must be
implemented using reflection values and will not be as
efficient as arrays and slices.
R=rsc
CC=golang-dev
http://codereview.appspot.com/1127041
-rw-r--r-- | src/pkg/gob/codec_test.go | 6 | ||||
-rw-r--r-- | src/pkg/gob/decode.go | 148 | ||||
-rw-r--r-- | src/pkg/gob/encode.go | 55 | ||||
-rw-r--r-- | src/pkg/gob/encoder.go | 10 | ||||
-rw-r--r-- | src/pkg/gob/type.go | 58 | ||||
-rw-r--r-- | src/pkg/gob/type_test.go | 20 |
6 files changed, 254 insertions, 43 deletions
diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index df82a5b6b..447b199cb 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -572,6 +572,7 @@ func TestEndToEnd(t *testing.T) { s2 := "string2" type T1 struct { a, b, c int + m map[string]*float n *[3]float strs *[2]string int64s *[]int64 @@ -579,10 +580,13 @@ func TestEndToEnd(t *testing.T) { y []byte t *T2 } + pi := 3.14159 + e := 2.71828 t1 := &T1{ a: 17, b: 18, c: -5, + m: map[string]*float{"pi": &pi, "e": &e}, n: &[3]float{1.5, 2.5, 3.5}, strs: &[2]string{s1, s2}, int64s: &[]int64{77, 89, 123412342134}, @@ -921,6 +925,7 @@ type IT0 struct { ignore_g string ignore_h []byte ignore_i *RT1 + ignore_m map[string]int c float } @@ -937,6 +942,7 @@ func TestIgnoredFields(t *testing.T) { it0.ignore_g = "pay no attention" it0.ignore_h = []byte("to the curtain") it0.ignore_i = &RT1{3.1, "hi", 7, "hello"} + it0.ignore_m = map[string]int{"one": 1, "two": 2} b := new(bytes.Buffer) NewEncoder(b).Encode(it0) diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 3b14841af..fb1e99367 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -447,6 +447,49 @@ func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) } +func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { + instr := &decInstr{op, 0, indir, 0, ovfl} + up := unsafe.Pointer(v.Addr()) + if indir > 1 { + up = decIndirect(up, indir) + } + op(instr, state, up) + return v +} + +func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error { + if indir > 0 { + up := unsafe.Pointer(p) + if *(*unsafe.Pointer)(up) == nil { + // Allocate object. + *(*unsafe.Pointer)(up) = unsafe.New(mtyp) + } + p = *(*uintptr)(up) + } + up := unsafe.Pointer(p) + if *(*unsafe.Pointer)(up) == nil { // maps are represented as a pointer in the runtime + // Allocate map. + *(*unsafe.Pointer)(up) = unsafe.Pointer(reflect.MakeMap(mtyp).Get()) + } + // Maps cannot be accessed by moving addresses around the way + // that slices etc. can. We must recover a full reflection value for + // the iteration. + v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue) + n := int(decodeUint(state)) + for i := 0; i < n && state.err == nil; i++ { + key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl) + if state.err != nil { + break + } + elem := decodeIntoValue(state, elemOp, elemIndir, reflect.MakeZero(mtyp.Elem()), ovfl) + if state.err != nil { + break + } + v.SetElem(key, elem) + } + return state.err +} + func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error { instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} for i := 0; i < length && state.err == nil; i++ { @@ -462,6 +505,18 @@ func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error { return ignoreArrayHelper(state, elemOp, length) } +func ignoreMap(state *decodeState, keyOp, elemOp decOp) os.Error { + n := int(decodeUint(state)) + keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")} + elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} + for i := 0; i < n && state.err == nil; i++ { + keyOp(keyInstr, state, nil) + elemOp(elemInstr, state, nil) + } + return state.err +} + + func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error { n := int(uintptr(decodeUint(state))) if indir > 0 { @@ -517,17 +572,25 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp if !ok { // Special cases switch t := typ.(type) { - case *reflect.SliceType: + case *reflect.ArrayType: name = "element of " + name - if _, ok := t.Elem().(*reflect.Uint8Type); ok { - op = decUint8Array - break + elemId := dec.wireType[wireId].arrayT.Elem + elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) + if err != nil { + return nil, 0, err } - var elemId typeId - if tt, ok := builtinIdToType[wireId]; ok { - elemId = tt.(*sliceType).Elem - } else { - elemId = dec.wireType[wireId].slice.Elem + ovfl := overflow(name) + op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) + } + + case *reflect.MapType: + name = "element of " + name + keyId := dec.wireType[wireId].mapT.Key + elemId := dec.wireType[wireId].mapT.Elem + keyOp, keyIndir, err := dec.decOpFor(keyId, t.Key(), name) + if err != nil { + return nil, 0, err } elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) if err != nil { @@ -535,19 +598,32 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp } ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) + up := unsafe.Pointer(p) + if indir > 1 { + up = decIndirect(up, indir) + } + state.err = decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl) } - case *reflect.ArrayType: + case *reflect.SliceType: name = "element of " + name - elemId := dec.wireType[wireId].array.Elem + if _, ok := t.Elem().(*reflect.Uint8Type); ok { + op = decUint8Array + break + } + var elemId typeId + if tt, ok := builtinIdToType[wireId]; ok { + elemId = tt.(*sliceType).Elem + } else { + elemId = dec.wireType[wireId].sliceT.Elem + } elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name) if err != nil { return nil, 0, err } ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) + state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) } case *reflect.StructType: @@ -575,18 +651,33 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { // Special cases wire := dec.wireType[wireId] switch { - case wire.array != nil: - elemId := wire.array.Elem + case wire.arrayT != nil: + elemId := wire.arrayT.Elem elemOp, err := dec.decIgnoreOpFor(elemId) if err != nil { return nil, err } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = ignoreArray(state, elemOp, wire.array.Len) + state.err = ignoreArray(state, elemOp, wire.arrayT.Len) } - case wire.slice != nil: - elemId := wire.slice.Elem + case wire.mapT != nil: + keyId := dec.wireType[wireId].mapT.Key + elemId := dec.wireType[wireId].mapT.Elem + keyOp, err := dec.decIgnoreOpFor(keyId) + if err != nil { + return nil, err + } + elemOp, err := dec.decIgnoreOpFor(elemId) + if err != nil { + return nil, err + } + op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + state.err = ignoreMap(state, keyOp, elemOp) + } + + case wire.sliceT != nil: + elemId := wire.sliceT.Elem elemOp, err := dec.decIgnoreOpFor(elemId) if err != nil { return nil, err @@ -595,7 +686,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { state.err = ignoreSlice(state, elemOp) } - case wire.strct != nil: + case wire.structT != nil: // Generate a closure that calls out to the engine for the nested type. enginePtr, err := dec.getIgnoreEnginePtr(wireId) if err != nil { @@ -640,11 +731,18 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { return fw == tString case *reflect.ArrayType: wire, ok := dec.wireType[fw] - if !ok || wire.array == nil { + if !ok || wire.arrayT == nil { + return false + } + array := wire.arrayT + return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem) + case *reflect.MapType: + wire, ok := dec.wireType[fw] + if !ok || wire.mapT == nil { return false } - array := wire.array - return ok && t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem) + mapType := wire.mapT + return dec.compatibleType(t.Key(), mapType.Key) && dec.compatibleType(t.Elem(), mapType.Elem) case *reflect.SliceType: // Is it an array of bytes? et := t.Elem() @@ -656,7 +754,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { if tt, ok := builtinIdToType[fw]; ok { sw = tt.(*sliceType) } else { - sw = dec.wireType[fw].slice + sw = dec.wireType[fw].sliceT } elem, _ := indirect(t.Elem()) return sw != nil && dec.compatibleType(elem, sw.Elem) @@ -677,7 +775,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng if !ok1 || !ok2 { return nil, errNotStruct } - wireStruct = w.strct + wireStruct = w.structT } engine = new(decEngine) engine.instr = make([]decInstr, len(wireStruct.field)) @@ -760,7 +858,7 @@ func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error { return err } engine := *enginePtr - if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].strct.field) > 0 { + if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 { name := rt.Name() return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) } diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index 195d6c647..fbea891b9 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -22,7 +22,7 @@ const uint64Size = unsafe.Sizeof(uint64(0)) type encoderState struct { b *bytes.Buffer err os.Error // error encountered during encoding. - inArray bool // encoding an array element + inArray bool // encoding an array element or map key/value pair fieldnum int // the last field number written. buf [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation. } @@ -297,7 +297,7 @@ func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error { return state.err } -func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, length int, elemIndir int) os.Error { +func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) os.Error { state := new(encoderState) state.b = b state.fieldnum = -1 @@ -319,6 +319,39 @@ func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, length i return state.err } +func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) { + for i := 0; i < indir && v != nil; i++ { + v = reflect.Indirect(v) + } + if v == nil { + state.err = os.ErrorString("gob: encodeMap: nil element") + return + } + op(nil, state, unsafe.Pointer(v.Addr())) +} + +func encodeMap(b *bytes.Buffer, rt reflect.Type, p uintptr, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error { + state := new(encoderState) + state.b = b + state.fieldnum = -1 + state.inArray = true + // Maps cannot be accessed by moving addresses around the way + // that slices etc. can. We must recover a full reflection value for + // the iteration. + v := reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer((p)))) + mv := reflect.Indirect(v).(*reflect.MapValue) + keys := mv.Keys() + encodeUint(state, uint64(len(keys))) + for _, key := range keys { + if state.err != nil { + break + } + encodeReflectValue(state, key, keyOp, keyIndir) + encodeReflectValue(state, mv.Elem(key), elemOp, elemIndir) + } + return state.err +} + var encOpMap = map[reflect.Type]encOp{ valueKind(false): encBool, valueKind(int(0)): encInt, @@ -344,7 +377,6 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) { typ, indir := indirect(rt) op, ok := encOpMap[reflect.Typeof(typ)] if !ok { - typ, _ := indirect(rt) // Special cases switch t := typ.(type) { case *reflect.SliceType: @@ -363,7 +395,7 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) { return } state.update(i) - state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), int(slice.Len), indir) + state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len)) } case *reflect.ArrayType: // True arrays have size in the type. @@ -373,7 +405,20 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) { } op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i) - state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), t.Len(), indir) + state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len()) + } + case *reflect.MapType: + keyOp, keyIndir, err := encOpFor(t.Key()) + if err != nil { + return nil, 0, err + } + elemOp, elemIndir, err := encOpFor(t.Elem()) + if err != nil { + return nil, 0, err + } + op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { + state.update(i) + state.err = encodeMap(state.b, typ, uintptr(p), keyOp, elemOp, keyIndir, elemIndir) } case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index 8ba503138..d65a71080 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -71,9 +71,8 @@ Structs, arrays and slices are also supported. Strings and arrays of bytes are supported with a special, efficient representation (see below). - Maps are not supported yet, but they will be. Interfaces, functions, and channels - cannot be sent in a gob. Attempting to encode a value that contains one will - fail. + Interfaces, functions, and channels cannot be sent in a gob. Attempting + to encode a value that contains one will fail. The rest of this comment documents the encoding, details that are not important for most users. Details are presented bottom-up. @@ -263,10 +262,13 @@ func (enc *Encoder) sendType(origt reflect.Type) { case *reflect.ArrayType: // arrays must be sent so we know their lengths and element types. break + case *reflect.MapType: + // maps must be sent so we know their lengths and key/value types. + break case *reflect.StructType: // structs must be sent so we know their fields. break - case *reflect.ChanType, *reflect.FuncType, *reflect.MapType, *reflect.InterfaceType: + case *reflect.ChanType, *reflect.FuncType, *reflect.InterfaceType: // Probably a bad field in a struct. enc.badType(rt) return diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index 2a178af04..78793ba44 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -142,6 +142,31 @@ func (a *arrayType) safeString(seen map[typeId]bool) string { func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) } +// Map type +type mapType struct { + commonType + Key typeId + Elem typeId +} + +func newMapType(name string, key, elem gobType) *mapType { + m := &mapType{commonType{name: name}, key.id(), elem.id()} + setTypeId(m) + return m +} + +func (m *mapType) safeString(seen map[typeId]bool) string { + if seen[m._id] { + return m.name + } + seen[m._id] = true + key := m.Key.gobType().safeString(seen) + elem := m.Elem.gobType().safeString(seen) + return fmt.Sprintf("map[%s]%s", key, elem) +} + +func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) } + // Slice type type sliceType struct { commonType @@ -239,6 +264,17 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { } return newArrayType(name, gt, t.Len()), nil + case *reflect.MapType: + kt, err := getType("", t.Key()) + if err != nil { + return nil, err + } + vt, err := getType("", t.Elem()) + if err != nil { + return nil, err + } + return newMapType(name, kt, vt), nil + case *reflect.SliceType: // []byte == []uint8 is a special case if _, ok := t.Elem().(*reflect.Uint8Type); ok { @@ -330,16 +366,18 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { // using the gob rules for sending a structure, except that we assume the // ids for wireType and structType are known. The relevant pieces // are built in encode.go's init() function. - +// To maintain binary compatibility, if you extend this type, always put +// the new fields last. type wireType struct { - array *arrayType - slice *sliceType - strct *structType + arrayT *arrayType + sliceT *sliceType + structT *structType + mapT *mapType } func (w *wireType) name() string { - if w.strct != nil { - return w.strct.name + if w.structT != nil { + return w.structT.name } return "unknown" } @@ -370,14 +408,16 @@ func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) { t := info.id.gobType() switch typ := rt.(type) { case *reflect.ArrayType: - info.wire = &wireType{array: t.(*arrayType)} + info.wire = &wireType{arrayT: t.(*arrayType)} + case *reflect.MapType: + info.wire = &wireType{mapT: t.(*mapType)} case *reflect.SliceType: // []byte == []uint8 is a special case handled separately if _, ok := typ.Elem().(*reflect.Uint8Type); !ok { - info.wire = &wireType{slice: t.(*sliceType)} + info.wire = &wireType{sliceT: t.(*sliceType)} } case *reflect.StructType: - info.wire = &wireType{strct: t.(*structType)} + info.wire = &wireType{structT: t.(*structType)} } typeInfoMap[rt] = info } diff --git a/src/pkg/gob/type_test.go b/src/pkg/gob/type_test.go index 3d4871f1d..6acfa7135 100644 --- a/src/pkg/gob/type_test.go +++ b/src/pkg/gob/type_test.go @@ -105,6 +105,26 @@ func TestSliceType(t *testing.T) { } } +func TestMapType(t *testing.T) { + var m map[string]int + mapStringInt := getTypeUnlocked("map", reflect.Typeof(m)) + var newm map[string]int + newMapStringInt := getTypeUnlocked("map1", reflect.Typeof(newm)) + if mapStringInt != newMapStringInt { + t.Errorf("second registration of map[string]int creates new type") + } + var b map[string]bool + mapStringBool := getTypeUnlocked("", reflect.Typeof(b)) + if mapStringBool == mapStringInt { + t.Errorf("registration of map[string]bool creates same type as map[string]int") + } + str := mapStringBool.string() + expected := "map[string]bool" + if str != expected { + t.Errorf("map printed as %q; expected %q", str, expected) + } +} + type Bar struct { x string } |