diff options
Diffstat (limited to 'src/pkg/net/dnsmsg.go')
| -rw-r--r-- | src/pkg/net/dnsmsg.go | 71 |
1 files changed, 41 insertions, 30 deletions
diff --git a/src/pkg/net/dnsmsg.go b/src/pkg/net/dnsmsg.go index 630dbd1e9..f136b8c08 100644 --- a/src/pkg/net/dnsmsg.go +++ b/src/pkg/net/dnsmsg.go @@ -369,28 +369,33 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o f := val.Type().(*reflect.StructType).Field(i) switch fv := val.Field(i).(type) { default: + BadType: fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) return len(msg), false case *reflect.StructValue: off, ok = packStructValue(fv, msg, off) - case *reflect.Uint16Value: + case *reflect.UintValue: i := fv.Get() - if off+2 > len(msg) { - return len(msg), false - } - msg[off] = byte(i >> 8) - msg[off+1] = byte(i) - off += 2 - case *reflect.Uint32Value: - i := fv.Get() - if off+4 > len(msg) { - return len(msg), false + switch fv.Type().Kind() { + default: + goto BadType + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 } - msg[off] = byte(i >> 24) - msg[off+1] = byte(i >> 16) - msg[off+2] = byte(i >> 8) - msg[off+3] = byte(i) - off += 4 case *reflect.StringValue: // There are multiple string encodings. // The tag distinguishes ordinary strings from domain names. @@ -438,24 +443,30 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, f := val.Type().(*reflect.StructType).Field(i) switch fv := val.Field(i).(type) { default: + BadType: fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) return len(msg), false case *reflect.StructValue: off, ok = unpackStructValue(fv, msg, off) - case *reflect.Uint16Value: - if off+2 > len(msg) { - return len(msg), false - } - i := uint16(msg[off])<<8 | uint16(msg[off+1]) - fv.Set(i) - off += 2 - case *reflect.Uint32Value: - if off+4 > len(msg) { - return len(msg), false + case *reflect.UintValue: + switch fv.Type().Kind() { + default: + goto BadType + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + i := uint16(msg[off])<<8 | uint16(msg[off+1]) + fv.Set(uint64(i)) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) + fv.Set(uint64(i)) + off += 4 } - i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) - fv.Set(i) - off += 4 case *reflect.StringValue: var s string switch f.Tag { @@ -508,7 +519,7 @@ func printStructValue(val *reflect.StructValue) string { fval := val.Field(i) if fv, ok := fval.(*reflect.StructValue); ok { s += printStructValue(fv) - } else if fv, ok := fval.(*reflect.Uint32Value); ok && f.Tag == "ipv4" { + } else if fv, ok := fval.(*reflect.UintValue); ok && f.Tag == "ipv4" { i := fv.Get() s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() } else { |
