summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRob Pike <r@golang.org>2010-05-05 16:46:39 -0700
committerRob Pike <r@golang.org>2010-05-05 16:46:39 -0700
commit69dd710bc51d7feded67d58132cbc3a7055f16a3 (patch)
tree9fb90ef889824db02c2794c93dbed35ea8b969b9
parent1dd79d8f4bf45a7ebdb7b006bb02188b5561ce16 (diff)
downloadgolang-69dd710bc51d7feded67d58132cbc3a7055f16a3.tar.gz
gob: add support for maps.
Because maps are mostly a hidden type, they must be implemented using reflection values and will not be as efficient as arrays and slices. R=rsc CC=golang-dev http://codereview.appspot.com/1127041
-rw-r--r--src/pkg/gob/codec_test.go6
-rw-r--r--src/pkg/gob/decode.go148
-rw-r--r--src/pkg/gob/encode.go55
-rw-r--r--src/pkg/gob/encoder.go10
-rw-r--r--src/pkg/gob/type.go58
-rw-r--r--src/pkg/gob/type_test.go20
6 files changed, 254 insertions, 43 deletions
diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go
index df82a5b6b..447b199cb 100644
--- a/src/pkg/gob/codec_test.go
+++ b/src/pkg/gob/codec_test.go
@@ -572,6 +572,7 @@ func TestEndToEnd(t *testing.T) {
s2 := "string2"
type T1 struct {
a, b, c int
+ m map[string]*float
n *[3]float
strs *[2]string
int64s *[]int64
@@ -579,10 +580,13 @@ func TestEndToEnd(t *testing.T) {
y []byte
t *T2
}
+ pi := 3.14159
+ e := 2.71828
t1 := &T1{
a: 17,
b: 18,
c: -5,
+ m: map[string]*float{"pi": &pi, "e": &e},
n: &[3]float{1.5, 2.5, 3.5},
strs: &[2]string{s1, s2},
int64s: &[]int64{77, 89, 123412342134},
@@ -921,6 +925,7 @@ type IT0 struct {
ignore_g string
ignore_h []byte
ignore_i *RT1
+ ignore_m map[string]int
c float
}
@@ -937,6 +942,7 @@ func TestIgnoredFields(t *testing.T) {
it0.ignore_g = "pay no attention"
it0.ignore_h = []byte("to the curtain")
it0.ignore_i = &RT1{3.1, "hi", 7, "hello"}
+ it0.ignore_m = map[string]int{"one": 1, "two": 2}
b := new(bytes.Buffer)
NewEncoder(b).Encode(it0)
diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go
index 3b14841af..fb1e99367 100644
--- a/src/pkg/gob/decode.go
+++ b/src/pkg/gob/decode.go
@@ -447,6 +447,49 @@ func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp
return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl)
}
+func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
+ instr := &decInstr{op, 0, indir, 0, ovfl}
+ up := unsafe.Pointer(v.Addr())
+ if indir > 1 {
+ up = decIndirect(up, indir)
+ }
+ op(instr, state, up)
+ return v
+}
+
+func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error {
+ if indir > 0 {
+ up := unsafe.Pointer(p)
+ if *(*unsafe.Pointer)(up) == nil {
+ // Allocate object.
+ *(*unsafe.Pointer)(up) = unsafe.New(mtyp)
+ }
+ p = *(*uintptr)(up)
+ }
+ up := unsafe.Pointer(p)
+ if *(*unsafe.Pointer)(up) == nil { // maps are represented as a pointer in the runtime
+ // Allocate map.
+ *(*unsafe.Pointer)(up) = unsafe.Pointer(reflect.MakeMap(mtyp).Get())
+ }
+ // Maps cannot be accessed by moving addresses around the way
+ // that slices etc. can. We must recover a full reflection value for
+ // the iteration.
+ v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue)
+ n := int(decodeUint(state))
+ for i := 0; i < n && state.err == nil; i++ {
+ key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl)
+ if state.err != nil {
+ break
+ }
+ elem := decodeIntoValue(state, elemOp, elemIndir, reflect.MakeZero(mtyp.Elem()), ovfl)
+ if state.err != nil {
+ break
+ }
+ v.SetElem(key, elem)
+ }
+ return state.err
+}
+
func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error {
instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
for i := 0; i < length && state.err == nil; i++ {
@@ -462,6 +505,18 @@ func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error {
return ignoreArrayHelper(state, elemOp, length)
}
+func ignoreMap(state *decodeState, keyOp, elemOp decOp) os.Error {
+ n := int(decodeUint(state))
+ keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")}
+ elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
+ for i := 0; i < n && state.err == nil; i++ {
+ keyOp(keyInstr, state, nil)
+ elemOp(elemInstr, state, nil)
+ }
+ return state.err
+}
+
+
func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error {
n := int(uintptr(decodeUint(state)))
if indir > 0 {
@@ -517,17 +572,25 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
if !ok {
// Special cases
switch t := typ.(type) {
- case *reflect.SliceType:
+ case *reflect.ArrayType:
name = "element of " + name
- if _, ok := t.Elem().(*reflect.Uint8Type); ok {
- op = decUint8Array
- break
+ elemId := dec.wireType[wireId].arrayT.Elem
+ elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name)
+ if err != nil {
+ return nil, 0, err
}
- var elemId typeId
- if tt, ok := builtinIdToType[wireId]; ok {
- elemId = tt.(*sliceType).Elem
- } else {
- elemId = dec.wireType[wireId].slice.Elem
+ ovfl := overflow(name)
+ op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+ state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
+ }
+
+ case *reflect.MapType:
+ name = "element of " + name
+ keyId := dec.wireType[wireId].mapT.Key
+ elemId := dec.wireType[wireId].mapT.Elem
+ keyOp, keyIndir, err := dec.decOpFor(keyId, t.Key(), name)
+ if err != nil {
+ return nil, 0, err
}
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name)
if err != nil {
@@ -535,19 +598,32 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
}
ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
- state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
+ up := unsafe.Pointer(p)
+ if indir > 1 {
+ up = decIndirect(up, indir)
+ }
+ state.err = decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl)
}
- case *reflect.ArrayType:
+ case *reflect.SliceType:
name = "element of " + name
- elemId := dec.wireType[wireId].array.Elem
+ if _, ok := t.Elem().(*reflect.Uint8Type); ok {
+ op = decUint8Array
+ break
+ }
+ var elemId typeId
+ if tt, ok := builtinIdToType[wireId]; ok {
+ elemId = tt.(*sliceType).Elem
+ } else {
+ elemId = dec.wireType[wireId].sliceT.Elem
+ }
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name)
if err != nil {
return nil, 0, err
}
ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
- state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
+ state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
}
case *reflect.StructType:
@@ -575,18 +651,33 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
// Special cases
wire := dec.wireType[wireId]
switch {
- case wire.array != nil:
- elemId := wire.array.Elem
+ case wire.arrayT != nil:
+ elemId := wire.arrayT.Elem
elemOp, err := dec.decIgnoreOpFor(elemId)
if err != nil {
return nil, err
}
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
- state.err = ignoreArray(state, elemOp, wire.array.Len)
+ state.err = ignoreArray(state, elemOp, wire.arrayT.Len)
}
- case wire.slice != nil:
- elemId := wire.slice.Elem
+ case wire.mapT != nil:
+ keyId := dec.wireType[wireId].mapT.Key
+ elemId := dec.wireType[wireId].mapT.Elem
+ keyOp, err := dec.decIgnoreOpFor(keyId)
+ if err != nil {
+ return nil, err
+ }
+ elemOp, err := dec.decIgnoreOpFor(elemId)
+ if err != nil {
+ return nil, err
+ }
+ op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+ state.err = ignoreMap(state, keyOp, elemOp)
+ }
+
+ case wire.sliceT != nil:
+ elemId := wire.sliceT.Elem
elemOp, err := dec.decIgnoreOpFor(elemId)
if err != nil {
return nil, err
@@ -595,7 +686,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
state.err = ignoreSlice(state, elemOp)
}
- case wire.strct != nil:
+ case wire.structT != nil:
// Generate a closure that calls out to the engine for the nested type.
enginePtr, err := dec.getIgnoreEnginePtr(wireId)
if err != nil {
@@ -640,11 +731,18 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
return fw == tString
case *reflect.ArrayType:
wire, ok := dec.wireType[fw]
- if !ok || wire.array == nil {
+ if !ok || wire.arrayT == nil {
+ return false
+ }
+ array := wire.arrayT
+ return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem)
+ case *reflect.MapType:
+ wire, ok := dec.wireType[fw]
+ if !ok || wire.mapT == nil {
return false
}
- array := wire.array
- return ok && t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem)
+ mapType := wire.mapT
+ return dec.compatibleType(t.Key(), mapType.Key) && dec.compatibleType(t.Elem(), mapType.Elem)
case *reflect.SliceType:
// Is it an array of bytes?
et := t.Elem()
@@ -656,7 +754,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
if tt, ok := builtinIdToType[fw]; ok {
sw = tt.(*sliceType)
} else {
- sw = dec.wireType[fw].slice
+ sw = dec.wireType[fw].sliceT
}
elem, _ := indirect(t.Elem())
return sw != nil && dec.compatibleType(elem, sw.Elem)
@@ -677,7 +775,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
if !ok1 || !ok2 {
return nil, errNotStruct
}
- wireStruct = w.strct
+ wireStruct = w.structT
}
engine = new(decEngine)
engine.instr = make([]decInstr, len(wireStruct.field))
@@ -760,7 +858,7 @@ func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error {
return err
}
engine := *enginePtr
- if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].strct.field) > 0 {
+ if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 {
name := rt.Name()
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
}
diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go
index 195d6c647..fbea891b9 100644
--- a/src/pkg/gob/encode.go
+++ b/src/pkg/gob/encode.go
@@ -22,7 +22,7 @@ const uint64Size = unsafe.Sizeof(uint64(0))
type encoderState struct {
b *bytes.Buffer
err os.Error // error encountered during encoding.
- inArray bool // encoding an array element
+ inArray bool // encoding an array element or map key/value pair
fieldnum int // the last field number written.
buf [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation.
}
@@ -297,7 +297,7 @@ func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
return state.err
}
-func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, length int, elemIndir int) os.Error {
+func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) os.Error {
state := new(encoderState)
state.b = b
state.fieldnum = -1
@@ -319,6 +319,39 @@ func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, length i
return state.err
}
+func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) {
+ for i := 0; i < indir && v != nil; i++ {
+ v = reflect.Indirect(v)
+ }
+ if v == nil {
+ state.err = os.ErrorString("gob: encodeMap: nil element")
+ return
+ }
+ op(nil, state, unsafe.Pointer(v.Addr()))
+}
+
+func encodeMap(b *bytes.Buffer, rt reflect.Type, p uintptr, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error {
+ state := new(encoderState)
+ state.b = b
+ state.fieldnum = -1
+ state.inArray = true
+ // Maps cannot be accessed by moving addresses around the way
+ // that slices etc. can. We must recover a full reflection value for
+ // the iteration.
+ v := reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer((p))))
+ mv := reflect.Indirect(v).(*reflect.MapValue)
+ keys := mv.Keys()
+ encodeUint(state, uint64(len(keys)))
+ for _, key := range keys {
+ if state.err != nil {
+ break
+ }
+ encodeReflectValue(state, key, keyOp, keyIndir)
+ encodeReflectValue(state, mv.Elem(key), elemOp, elemIndir)
+ }
+ return state.err
+}
+
var encOpMap = map[reflect.Type]encOp{
valueKind(false): encBool,
valueKind(int(0)): encInt,
@@ -344,7 +377,6 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
typ, indir := indirect(rt)
op, ok := encOpMap[reflect.Typeof(typ)]
if !ok {
- typ, _ := indirect(rt)
// Special cases
switch t := typ.(type) {
case *reflect.SliceType:
@@ -363,7 +395,7 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
return
}
state.update(i)
- state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), int(slice.Len), indir)
+ state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len))
}
case *reflect.ArrayType:
// True arrays have size in the type.
@@ -373,7 +405,20 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
}
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i)
- state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), t.Len(), indir)
+ state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
+ }
+ case *reflect.MapType:
+ keyOp, keyIndir, err := encOpFor(t.Key())
+ if err != nil {
+ return nil, 0, err
+ }
+ elemOp, elemIndir, err := encOpFor(t.Elem())
+ if err != nil {
+ return nil, 0, err
+ }
+ op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
+ state.update(i)
+ state.err = encodeMap(state.b, typ, uintptr(p), keyOp, elemOp, keyIndir, elemIndir)
}
case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type.
diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go
index 8ba503138..d65a71080 100644
--- a/src/pkg/gob/encoder.go
+++ b/src/pkg/gob/encoder.go
@@ -71,9 +71,8 @@
Structs, arrays and slices are also supported. Strings and arrays of bytes are
supported with a special, efficient representation (see below).
- Maps are not supported yet, but they will be. Interfaces, functions, and channels
- cannot be sent in a gob. Attempting to encode a value that contains one will
- fail.
+ Interfaces, functions, and channels cannot be sent in a gob. Attempting
+ to encode a value that contains one will fail.
The rest of this comment documents the encoding, details that are not important
for most users. Details are presented bottom-up.
@@ -263,10 +262,13 @@ func (enc *Encoder) sendType(origt reflect.Type) {
case *reflect.ArrayType:
// arrays must be sent so we know their lengths and element types.
break
+ case *reflect.MapType:
+ // maps must be sent so we know their lengths and key/value types.
+ break
case *reflect.StructType:
// structs must be sent so we know their fields.
break
- case *reflect.ChanType, *reflect.FuncType, *reflect.MapType, *reflect.InterfaceType:
+ case *reflect.ChanType, *reflect.FuncType, *reflect.InterfaceType:
// Probably a bad field in a struct.
enc.badType(rt)
return
diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go
index 2a178af04..78793ba44 100644
--- a/src/pkg/gob/type.go
+++ b/src/pkg/gob/type.go
@@ -142,6 +142,31 @@ func (a *arrayType) safeString(seen map[typeId]bool) string {
func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
+// Map type
+type mapType struct {
+ commonType
+ Key typeId
+ Elem typeId
+}
+
+func newMapType(name string, key, elem gobType) *mapType {
+ m := &mapType{commonType{name: name}, key.id(), elem.id()}
+ setTypeId(m)
+ return m
+}
+
+func (m *mapType) safeString(seen map[typeId]bool) string {
+ if seen[m._id] {
+ return m.name
+ }
+ seen[m._id] = true
+ key := m.Key.gobType().safeString(seen)
+ elem := m.Elem.gobType().safeString(seen)
+ return fmt.Sprintf("map[%s]%s", key, elem)
+}
+
+func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) }
+
// Slice type
type sliceType struct {
commonType
@@ -239,6 +264,17 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
}
return newArrayType(name, gt, t.Len()), nil
+ case *reflect.MapType:
+ kt, err := getType("", t.Key())
+ if err != nil {
+ return nil, err
+ }
+ vt, err := getType("", t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return newMapType(name, kt, vt), nil
+
case *reflect.SliceType:
// []byte == []uint8 is a special case
if _, ok := t.Elem().(*reflect.Uint8Type); ok {
@@ -330,16 +366,18 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId {
// using the gob rules for sending a structure, except that we assume the
// ids for wireType and structType are known. The relevant pieces
// are built in encode.go's init() function.
-
+// To maintain binary compatibility, if you extend this type, always put
+// the new fields last.
type wireType struct {
- array *arrayType
- slice *sliceType
- strct *structType
+ arrayT *arrayType
+ sliceT *sliceType
+ structT *structType
+ mapT *mapType
}
func (w *wireType) name() string {
- if w.strct != nil {
- return w.strct.name
+ if w.structT != nil {
+ return w.structT.name
}
return "unknown"
}
@@ -370,14 +408,16 @@ func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) {
t := info.id.gobType()
switch typ := rt.(type) {
case *reflect.ArrayType:
- info.wire = &wireType{array: t.(*arrayType)}
+ info.wire = &wireType{arrayT: t.(*arrayType)}
+ case *reflect.MapType:
+ info.wire = &wireType{mapT: t.(*mapType)}
case *reflect.SliceType:
// []byte == []uint8 is a special case handled separately
if _, ok := typ.Elem().(*reflect.Uint8Type); !ok {
- info.wire = &wireType{slice: t.(*sliceType)}
+ info.wire = &wireType{sliceT: t.(*sliceType)}
}
case *reflect.StructType:
- info.wire = &wireType{strct: t.(*structType)}
+ info.wire = &wireType{structT: t.(*structType)}
}
typeInfoMap[rt] = info
}
diff --git a/src/pkg/gob/type_test.go b/src/pkg/gob/type_test.go
index 3d4871f1d..6acfa7135 100644
--- a/src/pkg/gob/type_test.go
+++ b/src/pkg/gob/type_test.go
@@ -105,6 +105,26 @@ func TestSliceType(t *testing.T) {
}
}
+func TestMapType(t *testing.T) {
+ var m map[string]int
+ mapStringInt := getTypeUnlocked("map", reflect.Typeof(m))
+ var newm map[string]int
+ newMapStringInt := getTypeUnlocked("map1", reflect.Typeof(newm))
+ if mapStringInt != newMapStringInt {
+ t.Errorf("second registration of map[string]int creates new type")
+ }
+ var b map[string]bool
+ mapStringBool := getTypeUnlocked("", reflect.Typeof(b))
+ if mapStringBool == mapStringInt {
+ t.Errorf("registration of map[string]bool creates same type as map[string]int")
+ }
+ str := mapStringBool.string()
+ expected := "map[string]bool"
+ if str != expected {
+ t.Errorf("map printed as %q; expected %q", str, expected)
+ }
+}
+
type Bar struct {
x string
}