From c3a14037297827d36b77a49cf0cbd0849295e79a Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Thu, 24 Jun 2010 15:07:28 -0700 Subject: gob: add support for complex numbers R=rsc CC=golang-dev http://codereview.appspot.com/1708048 --- src/pkg/gob/codec_test.go | 63 ++++++++++++++++++++++- src/pkg/gob/decode.go | 126 ++++++++++++++++++++++++++++++---------------- src/pkg/gob/encode.go | 75 ++++++++++++++++++++------- src/pkg/gob/type.go | 29 ++++++++--- 4 files changed, 224 insertions(+), 69 deletions(-) diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index dad0ac48c..d8bdf2d2f 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -112,7 +112,9 @@ var boolResult = []byte{0x07, 0x01} var signedResult = []byte{0x07, 2 * 17} var unsignedResult = []byte{0x07, 17} var floatResult = []byte{0x07, 0xFE, 0x31, 0x40} -// The result of encoding "hello" with field number 6 +// The result of encoding a number 17+19i with field number 7 +var complexResult = []byte{0x07, 0xFE, 0x31, 0x40, 0xFE, 0x33, 0x40} +// The result of encoding "hello" with field number 7 var bytesResult = []byte{0x07, 0x05, 'h', 'e', 'l', 'l', 'o'} func newencoderState(b *bytes.Buffer) *encoderState { @@ -537,6 +539,45 @@ func TestScalarDecInstructions(t *testing.T) { } } + // complex + { + var data struct { + a complex + } + instr := &decInstr{decOpMap[reflect.Complex], 6, 0, 0, ovfl} + state := newDecodeStateFromData(complexResult) + execDec("complex", instr, state, t, unsafe.Pointer(&data)) + if data.a != 17+19i { + t.Errorf("complex a = %v not 17+19i", data.a) + } + } + + // complex64 + { + var data struct { + a complex64 + } + instr := &decInstr{decOpMap[reflect.Complex64], 6, 0, 0, ovfl} + state := newDecodeStateFromData(complexResult) + execDec("complex", instr, state, t, unsafe.Pointer(&data)) + if data.a != 17+19i { + t.Errorf("complex a = %v not 17+19i", data.a) + } + } + + // complex128 + { + var data struct { + a complex128 + } + instr := &decInstr{decOpMap[reflect.Complex128], 6, 0, 0, ovfl} + state := newDecodeStateFromData(complexResult) + execDec("complex", instr, state, t, unsafe.Pointer(&data)) + if data.a != 17+19i { + t.Errorf("complex a = %v not 17+19i", data.a) + } + } + // bytes == []uint8 { var data struct { @@ -576,6 +617,7 @@ func TestEndToEnd(t *testing.T) { n *[3]float strs *[2]string int64s *[]int64 + ri complex64 s string y []byte t *T2 @@ -590,6 +632,7 @@ func TestEndToEnd(t *testing.T) { n: &[3]float{1.5, 2.5, 3.5}, strs: &[2]string{s1, s2}, int64s: &[]int64{77, 89, 123412342134}, + ri: 17 - 23i, s: "Now is the time", y: []byte("hello, sailor"), t: &T2{"this is T2"}, @@ -616,6 +659,8 @@ func TestOverflow(t *testing.T) { maxu uint64 maxf float64 minf float64 + maxc complex128 + minc complex128 } var it inputT var err os.Error @@ -758,6 +803,22 @@ func TestOverflow(t *testing.T) { if err == nil || err.String() != `value for "maxf" out of range` { t.Error("wrong overflow error for float32:", err) } + + // complex64 + b.Reset() + it = inputT{ + maxc: cmplx(math.MaxFloat32*2, math.MaxFloat32*2), + } + type outc64 struct { + maxc complex64 + minc complex64 + } + var o8 outc64 + enc.Encode(it) + err = dec.Decode(&o8) + if err == nil || err.String() != `value for "maxc" out of range` { + t.Error("wrong overflow error for complex64:", err) + } } diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 8f5c383ea..0dbf81488 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -151,6 +151,11 @@ func ignoreUint(i *decInstr, state *decodeState, p unsafe.Pointer) { decodeUint(state) } +func ignoreTwoUints(i *decInstr, state *decodeState, p unsafe.Pointer) { + decodeUint(state) + decodeUint(state) +} + func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { @@ -286,25 +291,30 @@ func floatFromBits(u uint64) float64 { return math.Float64frombits(v) } -func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { - if i.indir > 0 { - if *(*unsafe.Pointer)(p) == nil { - *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32)) - } - p = *(*unsafe.Pointer)(p) - } +func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { v := floatFromBits(decodeUint(state)) av := v if av < 0 { av = -av } - if math.MaxFloat32 < av { // underflow is OK + // +Inf is OK in both 32- and 64-bit floats. Underflow is always OK. + if math.MaxFloat32 < av && av <= math.MaxFloat64 { state.err = i.ovfl } else { *(*float32)(p) = float32(v) } } +func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { + if i.indir > 0 { + if *(*unsafe.Pointer)(p) == nil { + *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32)) + } + p = *(*unsafe.Pointer)(p) + } + storeFloat32(i, state, p) +} + func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { @@ -315,6 +325,30 @@ func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*float64)(p) = floatFromBits(uint64(decodeUint(state))) } +// Complex numbers are just a pair of floating-point numbers, real part first. +func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) { + if i.indir > 0 { + if *(*unsafe.Pointer)(p) == nil { + *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex64)) + } + p = *(*unsafe.Pointer)(p) + } + storeFloat32(i, state, p) + storeFloat32(i, state, unsafe.Pointer(uintptr(p)+uintptr(unsafe.Sizeof(float(0))))) +} + +func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) { + if i.indir > 0 { + if *(*unsafe.Pointer)(p) == nil { + *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex128)) + } + p = *(*unsafe.Pointer)(p) + } + real := floatFromBits(uint64(decodeUint(state))) + imag := floatFromBits(uint64(decodeUint(state))) + *(*complex128)(p) = cmplx(real, imag) +} + // uint8 arrays are encoded as an unsigned count followed by the raw bytes. func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { @@ -540,21 +574,25 @@ func ignoreSlice(state *decodeState, elemOp decOp) os.Error { return ignoreArrayHelper(state, elemOp, int(decodeUint(state))) } +// Index by Go types. var decOpMap = []decOp{ - reflect.Bool: decBool, - reflect.Int8: decInt8, - reflect.Int16: decInt16, - reflect.Int32: decInt32, - reflect.Int64: decInt64, - reflect.Uint8: decUint8, - reflect.Uint16: decUint16, - reflect.Uint32: decUint32, - reflect.Uint64: decUint64, - reflect.Float32: decFloat32, - reflect.Float64: decFloat64, - reflect.String: decString, -} - + reflect.Bool: decBool, + reflect.Int8: decInt8, + reflect.Int16: decInt16, + reflect.Int32: decInt32, + reflect.Int64: decInt64, + reflect.Uint8: decUint8, + reflect.Uint16: decUint16, + reflect.Uint32: decUint32, + reflect.Uint64: decUint64, + reflect.Float32: decFloat32, + reflect.Float64: decFloat64, + reflect.Complex64: decComplex64, + reflect.Complex128: decComplex128, + reflect.String: decString, +} + +// Indexed by gob types. tComplex will be added during type.init(). var decIgnoreOpMap = map[typeId]decOp{ tBool: ignoreUint, tInt: ignoreUint, @@ -652,6 +690,8 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { // Special cases wire := dec.wireType[wireId] switch { + case wire == nil: + panic("internal error: can't find ignore op for type " + wireId.string()) case wire.arrayT != nil: elemId := wire.arrayT.Elem elemOp, err := dec.decIgnoreOpFor(elemId) @@ -728,6 +768,8 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { return fw == tUint case *reflect.FloatType: return fw == tFloat + case *reflect.ComplexType: + return fw == tComplex case *reflect.StringType: return fw == tString case *reflect.ArrayType: @@ -866,39 +908,39 @@ func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error { } func init() { - // We assume that the size of float is sufficient to tell us whether it is - // equivalent to float32 or to float64. This is very unlikely to be wrong. - var op decOp - switch unsafe.Sizeof(float(0)) { - case unsafe.Sizeof(float32(0)): - op = decFloat32 - case unsafe.Sizeof(float64(0)): - op = decFloat64 + var fop, cop decOp + switch reflect.Typeof(float(0)).Bits() { + case 32: + fop = decFloat32 + cop = decComplex64 + case 64: + fop = decFloat64 + cop = decComplex128 default: panic("gob: unknown size of float") } - decOpMap[reflect.Float] = op + decOpMap[reflect.Float] = fop + decOpMap[reflect.Complex] = cop - // A similar assumption about int and uint. Also assume int and uint have the same size. - var uop decOp - switch unsafe.Sizeof(int(0)) { - case unsafe.Sizeof(int32(0)): - op = decInt32 + var iop, uop decOp + switch reflect.Typeof(int(0)).Bits() { + case 32: + iop = decInt32 uop = decUint32 - case unsafe.Sizeof(int64(0)): - op = decInt64 + case 64: + iop = decInt64 uop = decUint64 default: panic("gob: unknown size of int/uint") } - decOpMap[reflect.Int] = op + decOpMap[reflect.Int] = iop decOpMap[reflect.Uint] = uop // Finally uintptr - switch unsafe.Sizeof(uintptr(0)) { - case unsafe.Sizeof(uint32(0)): + switch reflect.Typeof(uintptr(0)).Bits() { + case 32: uop = decUint32 - case unsafe.Sizeof(uint64(0)): + case 64: uop = decUint64 default: panic("gob: unknown size of uintptr") diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index 76032389e..b48c1f698 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -467,7 +467,7 @@ func floatBits(f float64) uint64 { } func encFloat(i *encInstr, state *encoderState, p unsafe.Pointer) { - f := float(*(*float)(p)) + f := *(*float)(p) if f != 0 || state.inArray { v := floatBits(float64(f)) state.update(i) @@ -476,7 +476,7 @@ func encFloat(i *encInstr, state *encoderState, p unsafe.Pointer) { } func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) { - f := float32(*(*float32)(p)) + f := *(*float32)(p) if f != 0 || state.inArray { v := floatBits(float64(f)) state.update(i) @@ -493,6 +493,40 @@ func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// Complex numbers are just a pair of floating-point numbers, real part first. +func encComplex(i *encInstr, state *encoderState, p unsafe.Pointer) { + c := *(*complex)(p) + if c != 0+0i || state.inArray { + rpart := floatBits(float64(real(c))) + ipart := floatBits(float64(imag(c))) + state.update(i) + encodeUint(state, rpart) + encodeUint(state, ipart) + } +} + +func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) { + c := *(*complex64)(p) + if c != 0+0i || state.inArray { + rpart := floatBits(float64(real(c))) + ipart := floatBits(float64(imag(c))) + state.update(i) + encodeUint(state, rpart) + encodeUint(state, ipart) + } +} + +func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) { + c := *(*complex128)(p) + if c != 0+0i || state.inArray { + rpart := floatBits(real(c)) + ipart := floatBits(imag(c)) + state.update(i) + encodeUint(state, rpart) + encodeUint(state, ipart) + } +} + // Byte arrays are encoded as an unsigned count followed by the raw bytes. func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { b := *(*[]byte)(p) @@ -602,22 +636,25 @@ func encodeMap(b *bytes.Buffer, rt reflect.Type, p uintptr, keyOp, elemOp encOp, } var encOpMap = []encOp{ - reflect.Bool: encBool, - reflect.Int: encInt, - reflect.Int8: encInt8, - reflect.Int16: encInt16, - reflect.Int32: encInt32, - reflect.Int64: encInt64, - reflect.Uint: encUint, - reflect.Uint8: encUint8, - reflect.Uint16: encUint16, - reflect.Uint32: encUint32, - reflect.Uint64: encUint64, - reflect.Uintptr: encUintptr, - reflect.Float: encFloat, - reflect.Float32: encFloat32, - reflect.Float64: encFloat64, - reflect.String: encString, + reflect.Bool: encBool, + reflect.Int: encInt, + reflect.Int8: encInt8, + reflect.Int16: encInt16, + reflect.Int32: encInt32, + reflect.Int64: encInt64, + reflect.Uint: encUint, + reflect.Uint8: encUint8, + reflect.Uint16: encUint16, + reflect.Uint32: encUint32, + reflect.Uint64: encUint64, + reflect.Uintptr: encUintptr, + reflect.Float: encFloat, + reflect.Float32: encFloat32, + reflect.Float64: encFloat64, + reflect.Complex: encComplex, + reflect.Complex64: encComplex64, + reflect.Complex128: encComplex128, + reflect.String: encString, } // Return the encoding op for the base type under rt and @@ -688,7 +725,7 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) { } } if op == nil { - return op, indir, os.ErrorString("gob enc: can't happen: encode type" + rt.String()) + return op, indir, os.ErrorString("gob enc: can't happen: encode type " + rt.String()) } return op, indir, nil } diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index 6a3e6ba65..2ad36ae65 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -78,12 +78,17 @@ func (t *commonType) Name() string { return t.name } // Create and check predefined types // The string for tBytes is "bytes" not "[]byte" to signify its specialness. -var tBool = bootstrapType("bool", false, 1) -var tInt = bootstrapType("int", int(0), 2) -var tUint = bootstrapType("uint", uint(0), 3) -var tFloat = bootstrapType("float", float64(0), 4) -var tBytes = bootstrapType("bytes", make([]byte, 0), 5) -var tString = bootstrapType("string", "", 6) +var ( + // Primordial types, needed during initialization. + tBool = bootstrapType("bool", false, 1) + tInt = bootstrapType("int", int(0), 2) + tUint = bootstrapType("uint", uint(0), 3) + tFloat = bootstrapType("float", float64(0), 4) + tBytes = bootstrapType("bytes", make([]byte, 0), 5) + tString = bootstrapType("string", "", 6) + // Types added to the language later, not needed during initialization. + tComplex typeId +) // Predefined because it's needed by the Decoder var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id @@ -94,10 +99,17 @@ func init() { checkId(9, mustGetTypeInfo(reflect.Typeof(commonType{})).id) checkId(11, mustGetTypeInfo(reflect.Typeof(structType{})).id) checkId(12, mustGetTypeInfo(reflect.Typeof(fieldType{})).id) + + // Complex was added after gob was written, so appears after the + // fundamental types are built. + tComplex = bootstrapType("complex", 0+0i, 15) + decIgnoreOpMap[tComplex] = ignoreTwoUints + builtinIdToType = make(map[typeId]gobType) for k, v := range idToType { builtinIdToType[k] = v } + // Move the id space upwards to allow for growth in the predefined world // without breaking existing files. if nextId > firstUserId { @@ -241,6 +253,9 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { case *reflect.FloatType: return tFloat.gobType(), nil + case *reflect.ComplexType: + return tComplex.gobType(), nil + case *reflect.StringType: return tString.gobType(), nil @@ -335,7 +350,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { rt := reflect.Typeof(e) _, present := types[rt] if present { - panic("bootstrap type already present: " + name) + panic("bootstrap type already present: " + name + ", " + rt.String()) } typ := &commonType{name: name} types[rt] = typ -- cgit v1.2.3