diff options
Diffstat (limited to 'src/pkg/net/lookup_windows.go')
-rw-r--r-- | src/pkg/net/lookup_windows.go | 155 |
1 files changed, 140 insertions, 15 deletions
diff --git a/src/pkg/net/lookup_windows.go b/src/pkg/net/lookup_windows.go index 99783e975..3b29724f2 100644 --- a/src/pkg/net/lookup_windows.go +++ b/src/pkg/net/lookup_windows.go @@ -6,21 +6,17 @@ package net import ( "os" - "sync" + "runtime" "syscall" "unsafe" ) var ( - protoentLock sync.Mutex - hostentLock sync.Mutex - serventLock sync.Mutex + lookupPort = oldLookupPort + lookupIP = oldLookupIP ) -// lookupProtocol looks up IP protocol name and returns correspondent protocol number. -func lookupProtocol(name string) (proto int, err error) { - protoentLock.Lock() - defer protoentLock.Unlock() +func getprotobyname(name string) (proto int, err error) { p, err := syscall.GetProtoByName(name) if err != nil { return 0, os.NewSyscallError("GetProtoByName", err) @@ -28,6 +24,25 @@ func lookupProtocol(name string) (proto int, err error) { return int(p.Proto), nil } +// lookupProtocol looks up IP protocol name and returns correspondent protocol number. +func lookupProtocol(name string) (proto int, err error) { + // GetProtoByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + proto int + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + proto, err := getprotobyname(name) + ch <- result{proto: proto, err: err} + }() + r := <-ch + return r.proto, r.err +} + func lookupHost(name string) (addrs []string, err error) { ips, err := LookupIP(name) if err != nil { @@ -40,9 +55,7 @@ func lookupHost(name string) (addrs []string, err error) { return } -func lookupIP(name string) (addrs []IP, err error) { - hostentLock.Lock() - defer hostentLock.Unlock() +func gethostbyname(name string) (addrs []IP, err error) { h, err := syscall.GetHostByName(name) if err != nil { return nil, os.NewSyscallError("GetHostByName", err) @@ -56,20 +69,65 @@ func lookupIP(name string) (addrs []IP, err error) { } addrs = addrs[0:i] default: // TODO(vcc): Implement non IPv4 address lookups. - return nil, os.NewSyscallError("LookupHost", syscall.EWINDOWS) + return nil, os.NewSyscallError("LookupIP", syscall.EWINDOWS) } return addrs, nil } -func lookupPort(network, service string) (port int, err error) { +func oldLookupIP(name string) (addrs []IP, err error) { + // GetHostByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + addrs []IP + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + addrs, err := gethostbyname(name) + ch <- result{addrs: addrs, err: err} + }() + r := <-ch + return r.addrs, r.err +} + +func newLookupIP(name string) (addrs []IP, err error) { + hints := syscall.AddrinfoW{ + Family: syscall.AF_UNSPEC, + Socktype: syscall.SOCK_STREAM, + Protocol: syscall.IPPROTO_IP, + } + var result *syscall.AddrinfoW + e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result) + if e != nil { + return nil, os.NewSyscallError("GetAddrInfoW", e) + } + defer syscall.FreeAddrInfoW(result) + addrs = make([]IP, 0, 5) + for ; result != nil; result = result.Next { + addr := unsafe.Pointer(result.Addr) + switch result.Family { + case syscall.AF_INET: + a := (*syscall.RawSockaddrInet4)(addr).Addr + addrs = append(addrs, IPv4(a[0], a[1], a[2], a[3])) + case syscall.AF_INET6: + a := (*syscall.RawSockaddrInet6)(addr).Addr + addrs = append(addrs, IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}) + default: + return nil, os.NewSyscallError("LookupIP", syscall.EWINDOWS) + } + } + return addrs, nil +} + +func getservbyname(network, service string) (port int, err error) { switch network { case "tcp4", "tcp6": network = "tcp" case "udp4", "udp6": network = "udp" } - serventLock.Lock() - defer serventLock.Unlock() s, err := syscall.GetServByName(service, network) if err != nil { return 0, os.NewSyscallError("GetServByName", err) @@ -77,6 +135,58 @@ func lookupPort(network, service string) (port int, err error) { return int(syscall.Ntohs(s.Port)), nil } +func oldLookupPort(network, service string) (port int, err error) { + // GetServByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + port int + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + port, err := getservbyname(network, service) + ch <- result{port: port, err: err} + }() + r := <-ch + return r.port, r.err +} + +func newLookupPort(network, service string) (port int, err error) { + var stype int32 + switch network { + case "tcp4", "tcp6": + stype = syscall.SOCK_STREAM + case "udp4", "udp6": + stype = syscall.SOCK_DGRAM + } + hints := syscall.AddrinfoW{ + Family: syscall.AF_UNSPEC, + Socktype: stype, + Protocol: syscall.IPPROTO_IP, + } + var result *syscall.AddrinfoW + e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result) + if e != nil { + return 0, os.NewSyscallError("GetAddrInfoW", e) + } + defer syscall.FreeAddrInfoW(result) + if result == nil { + return 0, os.NewSyscallError("LookupPort", syscall.EINVAL) + } + addr := unsafe.Pointer(result.Addr) + switch result.Family { + case syscall.AF_INET: + a := (*syscall.RawSockaddrInet4)(addr) + return int(syscall.Ntohs(a.Port)), nil + case syscall.AF_INET6: + a := (*syscall.RawSockaddrInet6)(addr) + return int(syscall.Ntohs(a.Port)), nil + } + return 0, os.NewSyscallError("LookupPort", syscall.EINVAL) +} + func lookupCNAME(name string) (cname string, err error) { var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) @@ -129,6 +239,21 @@ func lookupMX(name string) (mx []*MX, err error) { return mx, nil } +func lookupNS(name string) (ns []*NS, err error) { + var r *syscall.DNSRecord + e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil) + if e != nil { + return nil, os.NewSyscallError("LookupNS", e) + } + defer syscall.DnsRecordListFree(r, 1) + ns = make([]*NS, 0, 10) + for p := r; p != nil && p.Type == syscall.DNS_TYPE_NS; p = p.Next { + v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) + ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."}) + } + return ns, nil +} + func lookupTXT(name string) (txt []string, err error) { var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil) |