diff options
author | Ondřej Surý <ondrej@sury.org> | 2011-04-26 09:55:32 +0200 |
---|---|---|
committer | Ondřej Surý <ondrej@sury.org> | 2011-04-26 09:55:32 +0200 |
commit | 7b15ed9ef455b6b66c6b376898a88aef5d6a9970 (patch) | |
tree | 3ef530baa80cdf29436ba981f5783be6b4d2202b /src/pkg/net | |
parent | 50104cc32a498f7517a51c8dc93106c51c7a54b4 (diff) | |
download | golang-7b15ed9ef455b6b66c6b376898a88aef5d6a9970.tar.gz |
Imported Upstream version 2011.04.13upstream/2011.04.13
Diffstat (limited to 'src/pkg/net')
36 files changed, 1027 insertions, 329 deletions
diff --git a/src/pkg/net/Makefile b/src/pkg/net/Makefile index 6b6d7c0e3..7ce650279 100644 --- a/src/pkg/net/Makefile +++ b/src/pkg/net/Makefile @@ -6,6 +6,7 @@ include ../../Make.inc TARG=net GOFILES=\ + cgo_stub.go\ dial.go\ dnsmsg.go\ fd_$(GOOS).go\ @@ -13,6 +14,7 @@ GOFILES=\ ip.go\ ipsock.go\ iprawsock.go\ + lookup.go\ net.go\ parse.go\ pipe.go\ @@ -24,6 +26,7 @@ GOFILES=\ GOFILES_freebsd=\ newpollserver.go\ fd.go\ + file.go\ dnsconfig.go\ dnsclient.go\ port.go\ @@ -31,6 +34,7 @@ GOFILES_freebsd=\ GOFILES_darwin=\ newpollserver.go\ fd.go\ + file.go\ dnsconfig.go\ dnsclient.go\ port.go\ @@ -38,12 +42,14 @@ GOFILES_darwin=\ GOFILES_linux=\ newpollserver.go\ fd.go\ + file.go\ dnsconfig.go\ dnsclient.go\ port.go\ GOFILES_windows=\ resolv_windows.go\ + file_windows.go\ GOFILES+=$(GOFILES_$(GOOS)) diff --git a/src/pkg/net/cgo_stub.go b/src/pkg/net/cgo_stub.go new file mode 100644 index 000000000..e28f6622e --- /dev/null +++ b/src/pkg/net/cgo_stub.go @@ -0,0 +1,21 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Stub cgo routines for systems that do not use cgo to do network lookups. + +package net + +import "os" + +func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) { + return nil, nil, false +} + +func cgoLookupPort(network, service string) (port int, err os.Error, completed bool) { + return 0, nil, false +} + +func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { + return nil, nil, false +} diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index 1cf8e7915..66cb09b19 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -6,9 +6,7 @@ package net import "os" -// Dial connects to the remote address raddr on the network net. -// If the string laddr is not empty, it is used as the local address -// for the connection. +// Dial connects to the address addr on the network net. // // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), // "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" @@ -16,79 +14,56 @@ import "os" // // For IP networks, addresses have the form host:port. If host is // a literal IPv6 address, it must be enclosed in square brackets. +// The functions JoinHostPort and SplitHostPort manipulate +// addresses in this form. // // Examples: -// Dial("tcp", "", "12.34.56.78:80") -// Dial("tcp", "", "google.com:80") -// Dial("tcp", "", "[de:ad:be:ef::ca:fe]:80") -// Dial("tcp", "127.0.0.1:123", "127.0.0.1:88") +// Dial("tcp", "12.34.56.78:80") +// Dial("tcp", "google.com:80") +// Dial("tcp", "[de:ad:be:ef::ca:fe]:80") // -func Dial(net, laddr, raddr string) (c Conn, err os.Error) { +func Dial(net, addr string) (c Conn, err os.Error) { + raddr := addr + if raddr == "" { + return nil, &OpError{"dial", net, nil, errMissingAddress} + } switch net { case "tcp", "tcp4", "tcp6": - var la, ra *TCPAddr - if laddr != "" { - if la, err = ResolveTCPAddr(laddr); err != nil { - goto Error - } - } - if raddr != "" { - if ra, err = ResolveTCPAddr(raddr); err != nil { - goto Error - } + var ra *TCPAddr + if ra, err = ResolveTCPAddr(raddr); err != nil { + goto Error } - c, err := DialTCP(net, la, ra) + c, err := DialTCP(net, nil, ra) if err != nil { return nil, err } return c, nil case "udp", "udp4", "udp6": - var la, ra *UDPAddr - if laddr != "" { - if la, err = ResolveUDPAddr(laddr); err != nil { - goto Error - } - } - if raddr != "" { - if ra, err = ResolveUDPAddr(raddr); err != nil { - goto Error - } + var ra *UDPAddr + if ra, err = ResolveUDPAddr(raddr); err != nil { + goto Error } - c, err := DialUDP(net, la, ra) + c, err := DialUDP(net, nil, ra) if err != nil { return nil, err } return c, nil case "unix", "unixgram", "unixpacket": - var la, ra *UnixAddr - if raddr != "" { - if ra, err = ResolveUnixAddr(net, raddr); err != nil { - goto Error - } - } - if laddr != "" { - if la, err = ResolveUnixAddr(net, laddr); err != nil { - goto Error - } + var ra *UnixAddr + if ra, err = ResolveUnixAddr(net, raddr); err != nil { + goto Error } - c, err = DialUnix(net, la, ra) + c, err = DialUnix(net, nil, ra) if err != nil { return nil, err } return c, nil case "ip", "ip4", "ip6": - var la, ra *IPAddr - if laddr != "" { - if la, err = ResolveIPAddr(laddr); err != nil { - goto Error - } - } - if raddr != "" { - if ra, err = ResolveIPAddr(raddr); err != nil { - goto Error - } + var ra *IPAddr + if ra, err = ResolveIPAddr(raddr); err != nil { + goto Error } - c, err := DialIP(net, la, ra) + c, err := DialIP(net, nil, ra) if err != nil { return nil, err } diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go index a432800cf..9a9c02ebd 100644 --- a/src/pkg/net/dialgoogle_test.go +++ b/src/pkg/net/dialgoogle_test.go @@ -32,7 +32,7 @@ func fetchGoogle(t *testing.T, fd Conn, network, addr string) { } func doDial(t *testing.T, network, addr string) { - fd, err := Dial(network, "", addr) + fd, err := Dial(network, addr) if err != nil { t.Errorf("Dial(%q, %q, %q) = _, %v", network, "", addr, err) return @@ -55,6 +55,13 @@ var googleaddrs = []string{ "[2001:4860:0:2001::68]:80", // ipv6.google.com; removed if ipv6 flag not set } +func TestLookupCNAME(t *testing.T) { + cname, err := LookupCNAME("www.google.com") + if cname != "www.l.google.com." || err != nil { + t.Errorf(`LookupCNAME("www.google.com.") = %q, %v, want "www.l.google.com.", nil`, cname, err) + } +} + func TestDialGoogle(t *testing.T) { // If no ipv6 tunnel, don't try the last address. if !*ipv6 { @@ -64,14 +71,14 @@ func TestDialGoogle(t *testing.T) { // Insert an actual IP address for google.com // into the table. - _, addrs, err := LookupHost("www.google.com") + addrs, err := LookupIP("www.google.com") if err != nil { t.Fatalf("lookup www.google.com: %v", err) } if len(addrs) == 0 { t.Fatalf("no addresses for www.google.com") } - ip := ParseIP(addrs[0]).To4() + ip := addrs[0].To4() for i, s := range googleaddrs { if strings.Contains(s, "%") { diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go index 3252dd454..c3e727bce 100644 --- a/src/pkg/net/dnsclient.go +++ b/src/pkg/net/dnsclient.go @@ -21,6 +21,7 @@ import ( "rand" "sync" "time" + "sort" ) // DNSError represents a DNS lookup error. @@ -159,7 +160,7 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs // all the cfg.servers[i] are IP addresses, which // Dial will use without a DNS lookup. server := cfg.servers[i] + ":53" - c, cerr := Dial("udp", "", server) + c, cerr := Dial("udp", server) if cerr != nil { err = cerr continue @@ -178,12 +179,23 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs return } -func convertRR_A(records []dnsRR) []string { - addrs := make([]string, len(records)) +func convertRR_A(records []dnsRR) []IP { + addrs := make([]IP, len(records)) for i := 0; i < len(records); i++ { rr := records[i] a := rr.(*dnsRR_A).A - addrs[i] = IPv4(byte(a>>24), byte(a>>16), byte(a>>8), byte(a)).String() + addrs[i] = IPv4(byte(a>>24), byte(a>>16), byte(a>>8), byte(a)) + } + return addrs +} + +func convertRR_AAAA(records []dnsRR) []IP { + addrs := make([]IP, len(records)) + for i := 0; i < len(records); i++ { + rr := records[i] + a := make(IP, 16) + copy(a, rr.(*dnsRR_AAAA).AAAA[:]) + addrs[i] = a } return addrs } @@ -294,10 +306,8 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Erro return } -// LookupHost looks for name using the local hosts file and DNS resolver. -// It returns the canonical name for the host and an array of that -// host's addresses. -func LookupHost(name string) (cname string, addrs []string, err os.Error) { +// goLookupHost is the native Go implementation of LookupHost. +func goLookupHost(name string) (addrs []string, err os.Error) { onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { err = dnserr @@ -306,18 +316,69 @@ func LookupHost(name string) (cname string, addrs []string, err os.Error) { // Use entries from /etc/hosts if they match. addrs = lookupStaticHost(name) if len(addrs) > 0 { - cname = name + return + } + ips, err := goLookupIP(name) + if err != nil { + return + } + addrs = make([]string, 0, len(ips)) + for _, ip := range ips { + addrs = append(addrs, ip.String()) + } + return +} + +// goLookupIP is the native Go implementation of LookupIP. +func goLookupIP(name string) (addrs []IP, err os.Error) { + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr return } var records []dnsRR + var cname string cname, records, err = lookup(name, dnsTypeA) if err != nil { return } addrs = convertRR_A(records) + if cname != "" { + name = cname + } + _, records, err = lookup(name, dnsTypeAAAA) + if err != nil && len(addrs) > 0 { + // Ignore error because A lookup succeeded. + err = nil + } + if err != nil { + return + } + addrs = append(addrs, convertRR_AAAA(records)...) return } +// LookupCNAME returns the canonical DNS host for the given name. +// Callers that do not care about the canonical name can call +// LookupHost or LookupIP directly; both take care of resolving +// the canonical name as part of the lookup. +func LookupCNAME(name string) (cname string, err os.Error) { + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr + return + } + _, rr, err := lookup(name, dnsTypeCNAME) + if err != nil { + return + } + if len(rr) >= 0 { + cname = rr[0].(*dnsRR_CNAME).Cname + } + return +} + +// An SRV represents a single DNS SRV record. type SRV struct { Target string Port uint16 @@ -344,22 +405,38 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os. return } +// An MX represents a single DNS MX record. type MX struct { Host string Pref uint16 } -func LookupMX(name string) (entries []*MX, err os.Error) { - var records []dnsRR - _, records, err = lookup(name, dnsTypeMX) +// byPref implements sort.Interface to sort MX records by preference +type byPref []*MX + +func (s byPref) Len() int { return len(s) } + +func (s byPref) Less(i, j int) bool { return s[i].Pref < s[j].Pref } + +func (s byPref) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// LookupMX returns the DNS MX records for the given domain name sorted by preference. +func LookupMX(name string) (mx []*MX, err os.Error) { + _, rr, err := lookup(name, dnsTypeMX) if err != nil { return } - entries = make([]*MX, len(records)) - for i := range records { - r := records[i].(*dnsRR_MX) - entries[i] = &MX{r.Mx, r.Pref} + mx = make([]*MX, len(rr)) + for i := range rr { + r := rr[i].(*dnsRR_MX) + mx[i] = &MX{r.Mx, r.Pref} + } + // Shuffle the records to match RFC 5321 when sorted + for i := range mx { + j := rand.Intn(i + 1) + mx[i], mx[j] = mx[j], mx[i] } + sort.Sort(byPref(mx)) return } diff --git a/src/pkg/net/dnsmsg.go b/src/pkg/net/dnsmsg.go index dc195caf8..e8eb8d958 100644 --- a/src/pkg/net/dnsmsg.go +++ b/src/pkg/net/dnsmsg.go @@ -50,6 +50,7 @@ const ( dnsTypeMINFO = 14 dnsTypeMX = 15 dnsTypeTXT = 16 + dnsTypeAAAA = 28 dnsTypeSRV = 33 // valid dnsQuestion.qtype only @@ -244,8 +245,18 @@ type dnsRR_A struct { A uint32 "ipv4" } -func (rr *dnsRR_A) Header() *dnsRR_Header { return &rr.Hdr } +func (rr *dnsRR_A) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_AAAA struct { + Hdr dnsRR_Header + AAAA [16]byte "ipv6" +} +func (rr *dnsRR_AAAA) Header() *dnsRR_Header { + return &rr.Hdr +} // Packing and unpacking. // @@ -270,6 +281,7 @@ var rr_mk = map[int]func() dnsRR{ dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) }, dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) }, dnsTypeA: func() dnsRR { return new(dnsRR_A) }, + dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) }, } // Pack a domain name s into msg[off:]. @@ -377,43 +389,49 @@ Loop: // TODO(rsc): Move into generic library? // Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string, -// and other (often anonymous) structs. -func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { +// [n]byte, and other (often anonymous) structs. +func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) { for i := 0; i < val.NumField(); i++ { - f := val.Type().(*reflect.StructType).Field(i) - switch fv := val.Field(i).(type) { + f := val.Type().Field(i) + switch fv := val.Field(i); fv.Kind() { default: BadType: fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) return len(msg), false - case *reflect.StructValue: + case reflect.Struct: off, ok = packStructValue(fv, msg, off) - case *reflect.UintValue: - i := fv.Get() - switch fv.Type().Kind() { - default: + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + i := fv.Uint() + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := fv.Uint() + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 + case reflect.Array: + if fv.Type().Elem().Kind() != reflect.Uint8 { goto BadType - case reflect.Uint16: - if off+2 > len(msg) { - return len(msg), false - } - msg[off] = byte(i >> 8) - msg[off+1] = byte(i) - off += 2 - case reflect.Uint32: - if off+4 > len(msg) { - return len(msg), false - } - msg[off] = byte(i >> 24) - msg[off+1] = byte(i >> 16) - msg[off+2] = byte(i >> 8) - msg[off+3] = byte(i) - off += 4 } - case *reflect.StringValue: + n := fv.Len() + if off+n > len(msg) { + return len(msg), false + } + reflect.Copy(reflect.NewValue(msg[off:off+n]), fv) + off += n + case reflect.String: // There are multiple string encodings. // The tag distinguishes ordinary strings from domain names. - s := fv.Get() + s := fv.String() switch f.Tag { default: fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) @@ -437,8 +455,8 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o return off, true } -func structValue(any interface{}) *reflect.StructValue { - return reflect.NewValue(any).(*reflect.PtrValue).Elem().(*reflect.StructValue) +func structValue(any interface{}) reflect.Value { + return reflect.NewValue(any).Elem() } func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { @@ -449,36 +467,41 @@ func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { // TODO(rsc): Move into generic library? // Unpack a reflect.StructValue from msg. // Same restrictions as packStructValue. -func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { +func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) { for i := 0; i < val.NumField(); i++ { - f := val.Type().(*reflect.StructType).Field(i) - switch fv := val.Field(i).(type) { + f := val.Type().Field(i) + switch fv := val.Field(i); fv.Kind() { default: BadType: fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) return len(msg), false - case *reflect.StructValue: + case reflect.Struct: off, ok = unpackStructValue(fv, msg, off) - case *reflect.UintValue: - switch fv.Type().Kind() { - default: + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + i := uint16(msg[off])<<8 | uint16(msg[off+1]) + fv.SetUint(uint64(i)) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) + fv.SetUint(uint64(i)) + off += 4 + case reflect.Array: + if fv.Type().Elem().Kind() != reflect.Uint8 { goto BadType - case reflect.Uint16: - if off+2 > len(msg) { - return len(msg), false - } - i := uint16(msg[off])<<8 | uint16(msg[off+1]) - fv.Set(uint64(i)) - off += 2 - case reflect.Uint32: - if off+4 > len(msg) { - return len(msg), false - } - i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) - fv.Set(uint64(i)) - off += 4 } - case *reflect.StringValue: + n := fv.Len() + if off+n > len(msg) { + return len(msg), false + } + reflect.Copy(fv, reflect.NewValue(msg[off:off+n])) + off += n + case reflect.String: var s string switch f.Tag { default: @@ -502,7 +525,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, off += n s = string(b) } - fv.Set(s) + fv.SetString(s) } } return off, true @@ -515,24 +538,28 @@ func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { // Generic struct printer. // Doesn't care about the string tag "domain-name", -// but does look for an "ipv4" tag on uint32 variables, +// but does look for an "ipv4" tag on uint32 variables +// and the "ipv6" tag on array variables, // printing them as IP addresses. -func printStructValue(val *reflect.StructValue) string { +func printStructValue(val reflect.Value) string { s := "{" for i := 0; i < val.NumField(); i++ { if i > 0 { s += ", " } - f := val.Type().(*reflect.StructType).Field(i) + f := val.Type().Field(i) if !f.Anonymous { s += f.Name + "=" } fval := val.Field(i) - if fv, ok := fval.(*reflect.StructValue); ok { + if fv := fval; fv.Kind() == reflect.Struct { s += printStructValue(fv) - } else if fv, ok := fval.(*reflect.UintValue); ok && f.Tag == "ipv4" { - i := fv.Get() + } else if fv := fval; (fv.Kind() == reflect.Uint || fv.Kind() == reflect.Uint8 || fv.Kind() == reflect.Uint16 || fv.Kind() == reflect.Uint32 || fv.Kind() == reflect.Uint64 || fv.Kind() == reflect.Uintptr) && f.Tag == "ipv4" { + i := fv.Uint() s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() + } else if fv := fval; fv.Kind() == reflect.Array && f.Tag == "ipv6" { + i := fv.Interface().([]byte) + s += IP(i).String() } else { s += fmt.Sprint(fval.Interface()) } diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd.go index 2ba9296f3..cd1a21dc3 100644 --- a/src/pkg/net/fd.go +++ b/src/pkg/net/fd.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TODO(rsc): All the prints in this file should go to standard error. - package net import ( @@ -85,11 +83,12 @@ func (e *InvalidConnError) Timeout() bool { return false } // will the fd be closed. type pollServer struct { - cr, cw chan *netFD // buffered >= 1 - pr, pw *os.File - pending map[int]*netFD - poll *pollster // low-level OS hooks - deadline int64 // next deadline (nsec since 1970) + cr, cw chan *netFD // buffered >= 1 + pr, pw *os.File + poll *pollster // low-level OS hooks + sync.Mutex // controls pending and deadline + pending map[int]*netFD + deadline int64 // next deadline (nsec since 1970) } func (s *pollServer) AddFD(fd *netFD, mode int) { @@ -103,10 +102,8 @@ func (s *pollServer) AddFD(fd *netFD, mode int) { } return } - if err := s.poll.AddFD(intfd, mode, false); err != nil { - panic("pollServer AddFD " + err.String()) - return - } + + s.Lock() var t int64 key := intfd << 1 @@ -119,11 +116,31 @@ func (s *pollServer) AddFD(fd *netFD, mode int) { t = fd.wdeadline } s.pending[key] = fd + doWakeup := false if t > 0 && (s.deadline == 0 || t < s.deadline) { s.deadline = t + doWakeup = true + } + + wake, err := s.poll.AddFD(intfd, mode, false) + if err != nil { + panic("pollServer AddFD " + err.String()) + } + if wake { + doWakeup = true + } + + s.Unlock() + + if doWakeup { + s.Wakeup() } } +var wakeupbuf [1]byte + +func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) } + func (s *pollServer) LookupFD(fd int, mode int) *netFD { key := fd << 1 if mode == 'w' { @@ -195,6 +212,8 @@ func (s *pollServer) CheckDeadlines() { func (s *pollServer) Run() { var scratch [100]byte + s.Lock() + defer s.Unlock() for { var t = s.deadline if t > 0 { @@ -204,7 +223,7 @@ func (s *pollServer) Run() { continue } } - fd, mode, err := s.poll.WaitFD(t) + fd, mode, err := s.poll.WaitFD(s, t) if err != nil { print("pollServer WaitFD: ", err.String(), "\n") return @@ -215,22 +234,11 @@ func (s *pollServer) Run() { continue } if fd == s.pr.Fd() { - // Drain our wakeup pipe. - for nn, _ := s.pr.Read(scratch[0:]); nn > 0; { - nn, _ = s.pr.Read(scratch[0:]) - } - // Read from channels - Update: - for { - select { - case fd := <-s.cr: - s.AddFD(fd, 'r') - case fd := <-s.cw: - s.AddFD(fd, 'w') - default: - break Update - } - } + // Drain our wakeup pipe (we could loop here, + // but it's unlikely that there are more than + // len(scratch) wakeup calls). + s.pr.Read(scratch[0:]) + s.CheckDeadlines() } else { netfd := s.LookupFD(fd, mode) if netfd == nil { @@ -242,19 +250,13 @@ func (s *pollServer) Run() { } } -var wakeupbuf [1]byte - -func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) } - func (s *pollServer) WaitRead(fd *netFD) { - s.cr <- fd - s.Wakeup() + s.AddFD(fd, 'r') <-fd.cr } func (s *pollServer) WaitWrite(fd *netFD) { - s.cw <- fd - s.Wakeup() + s.AddFD(fd, 'w') <-fd.cw } @@ -272,19 +274,25 @@ func startServer() { pollserver = p } -func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err os.Error) { +func newFD(fd, family, proto int, net string) (f *netFD, err os.Error) { onceStartServer.Do(startServer) if e := syscall.SetNonblock(fd, true); e != 0 { - return nil, &OpError{"setnonblock", net, laddr, os.Errno(e)} + return nil, os.Errno(e) } f = &netFD{ sysfd: fd, family: family, proto: proto, net: net, - laddr: laddr, - raddr: raddr, } + f.cr = make(chan bool, 1) + f.cw = make(chan bool, 1) + return f, nil +} + +func (fd *netFD) setAddr(laddr, raddr Addr) { + fd.laddr = laddr + fd.raddr = raddr var ls, rs string if laddr != nil { ls = laddr.String() @@ -292,10 +300,23 @@ func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err if raddr != nil { rs = raddr.String() } - f.sysfile = os.NewFile(fd, net+":"+ls+"->"+rs) - f.cr = make(chan bool, 1) - f.cw = make(chan bool, 1) - return f, nil + fd.sysfile = os.NewFile(fd.sysfd, fd.net+":"+ls+"->"+rs) +} + +func (fd *netFD) connect(ra syscall.Sockaddr) (err os.Error) { + e := syscall.Connect(fd.sysfd, ra) + if e == syscall.EINPROGRESS { + var errno int + pollserver.WaitWrite(fd) + e, errno = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) + if errno != 0 { + return os.NewSyscallError("getsockopt", errno) + } + } + if e != 0 { + return os.Errno(e) + } + return nil } // Add a reference to this fd. @@ -591,10 +612,11 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. syscall.CloseOnExec(s) syscall.ForkLock.RUnlock() - if nfd, err = newFD(s, fd.family, fd.proto, fd.net, fd.laddr, toAddr(sa)); err != nil { + if nfd, err = newFD(s, fd.family, fd.proto, fd.net); err != nil { syscall.Close(s) return nil, err } + nfd.setAddr(fd.laddr, toAddr(sa)) return nfd, nil } diff --git a/src/pkg/net/fd_darwin.go b/src/pkg/net/fd_darwin.go index cd0738753..00a049bfd 100644 --- a/src/pkg/net/fd_darwin.go +++ b/src/pkg/net/fd_darwin.go @@ -15,6 +15,10 @@ type pollster struct { kq int eventbuf [10]syscall.Kevent_t events []syscall.Kevent_t + + // An event buffer for AddFD/DelFD. + // Must hold pollServer lock. + kbuf [1]syscall.Kevent_t } func newpollster() (p *pollster, err os.Error) { @@ -27,15 +31,16 @@ func newpollster() (p *pollster, err os.Error) { return p, nil } -func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error { +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { + // pollServer is locked. + var kmode int if mode == 'r' { kmode = syscall.EVFILT_READ } else { kmode = syscall.EVFILT_WRITE } - var events [1]syscall.Kevent_t - ev := &events[0] + ev := &p.kbuf[0] // EV_ADD - add event to kqueue list // EV_RECEIPT - generate fake EV_ERROR as result of add, // rather than waiting for real event @@ -46,36 +51,37 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error { } syscall.SetKevent(ev, fd, kmode, flags) - n, e := syscall.Kevent(p.kq, events[0:], events[0:], nil) + n, e := syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil) if e != 0 { - return os.NewSyscallError("kevent", e) + return false, os.NewSyscallError("kevent", e) } if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { - return os.ErrorString("kqueue phase error") + return false, os.ErrorString("kqueue phase error") } if ev.Data != 0 { - return os.Errno(int(ev.Data)) + return false, os.Errno(int(ev.Data)) } - return nil + return false, nil } func (p *pollster) DelFD(fd int, mode int) { + // pollServer is locked. + var kmode int if mode == 'r' { kmode = syscall.EVFILT_READ } else { kmode = syscall.EVFILT_WRITE } - var events [1]syscall.Kevent_t - ev := &events[0] + ev := &p.kbuf[0] // EV_DELETE - delete event from kqueue list // EV_RECEIPT - generate fake EV_ERROR as result of add, // rather than waiting for real event syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE|syscall.EV_RECEIPT) - syscall.Kevent(p.kq, events[0:], events[0:], nil) + syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil) } -func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) { +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { var t *syscall.Timespec for len(p.events) == 0 { if nsec > 0 { @@ -84,7 +90,11 @@ func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) { } *t = syscall.NsecToTimespec(nsec) } + + s.Unlock() nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[0:], t) + s.Lock() + if e != 0 { if e == syscall.EINTR { continue diff --git a/src/pkg/net/fd_freebsd.go b/src/pkg/net/fd_freebsd.go index 4c5e93424..e50883e94 100644 --- a/src/pkg/net/fd_freebsd.go +++ b/src/pkg/net/fd_freebsd.go @@ -15,6 +15,10 @@ type pollster struct { kq int eventbuf [10]syscall.Kevent_t events []syscall.Kevent_t + + // An event buffer for AddFD/DelFD. + // Must hold pollServer lock. + kbuf [1]syscall.Kevent_t } func newpollster() (p *pollster, err os.Error) { @@ -27,15 +31,16 @@ func newpollster() (p *pollster, err os.Error) { return p, nil } -func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error { +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { + // pollServer is locked. + var kmode int if mode == 'r' { kmode = syscall.EVFILT_READ } else { kmode = syscall.EVFILT_WRITE } - var events [1]syscall.Kevent_t - ev := &events[0] + ev := &p.kbuf[0] // EV_ADD - add event to kqueue list // EV_ONESHOT - delete the event the first time it triggers flags := syscall.EV_ADD @@ -44,34 +49,35 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error { } syscall.SetKevent(ev, fd, kmode, flags) - n, e := syscall.Kevent(p.kq, events[:], nil, nil) + n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil) if e != 0 { - return os.NewSyscallError("kevent", e) + return false, os.NewSyscallError("kevent", e) } if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { - return os.NewSyscallError("kqueue phase error", e) + return false, os.NewSyscallError("kqueue phase error", e) } if ev.Data != 0 { - return os.Errno(int(ev.Data)) + return false, os.Errno(int(ev.Data)) } - return nil + return false, nil } func (p *pollster) DelFD(fd int, mode int) { + // pollServer is locked. + var kmode int if mode == 'r' { kmode = syscall.EVFILT_READ } else { kmode = syscall.EVFILT_WRITE } - var events [1]syscall.Kevent_t - ev := &events[0] + ev := &p.kbuf[0] // EV_DELETE - delete event from kqueue list syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE) - syscall.Kevent(p.kq, events[:], nil, nil) + syscall.Kevent(p.kq, p.kbuf[:], nil, nil) } -func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) { +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { var t *syscall.Timespec for len(p.events) == 0 { if nsec > 0 { @@ -80,7 +86,11 @@ func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) { } *t = syscall.NsecToTimespec(nsec) } + + s.Unlock() nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) + s.Lock() + if e != 0 { if e == syscall.EINTR { continue diff --git a/src/pkg/net/fd_linux.go b/src/pkg/net/fd_linux.go index ef86cb17f..dcf65c014 100644 --- a/src/pkg/net/fd_linux.go +++ b/src/pkg/net/fd_linux.go @@ -20,7 +20,17 @@ type pollster struct { epfd int // Events we're already waiting for + // Must hold pollServer lock events map[int]uint32 + + // An event buffer for EpollWait. + // Used without a lock, may only be used by WaitFD. + waitEventBuf [10]syscall.EpollEvent + waitEvents []syscall.EpollEvent + + // An event buffer for EpollCtl, to avoid a malloc. + // Must hold pollServer lock. + ctlEvent syscall.EpollEvent } func newpollster() (p *pollster, err os.Error) { @@ -29,7 +39,7 @@ func newpollster() (p *pollster, err os.Error) { // The arg to epoll_create is a hint to the kernel // about the number of FDs we will care about. - // We don't know. + // We don't know, and since 2.6.8 the kernel ignores it anyhow. if p.epfd, e = syscall.EpollCreate(16); e != 0 { return nil, os.NewSyscallError("epoll_create", e) } @@ -37,18 +47,19 @@ func newpollster() (p *pollster, err os.Error) { return p, nil } -func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error { - var ev syscall.EpollEvent +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { + // pollServer is locked. + var already bool - ev.Fd = int32(fd) - ev.Events, already = p.events[fd] + p.ctlEvent.Fd = int32(fd) + p.ctlEvent.Events, already = p.events[fd] if !repeat { - ev.Events |= syscall.EPOLLONESHOT + p.ctlEvent.Events |= syscall.EPOLLONESHOT } if mode == 'r' { - ev.Events |= readFlags + p.ctlEvent.Events |= readFlags } else { - ev.Events |= writeFlags + p.ctlEvent.Events |= writeFlags } var op int @@ -57,14 +68,16 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error { } else { op = syscall.EPOLL_CTL_ADD } - if e := syscall.EpollCtl(p.epfd, op, fd, &ev); e != 0 { - return os.NewSyscallError("epoll_ctl", e) + if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != 0 { + return false, os.NewSyscallError("epoll_ctl", e) } - p.events[fd] = ev.Events - return nil + p.events[fd] = p.ctlEvent.Events + return false, nil } func (p *pollster) StopWaiting(fd int, bits uint) { + // pollServer is locked. + events, already := p.events[fd] if !already { print("Epoll unexpected fd=", fd, "\n") @@ -82,10 +95,9 @@ func (p *pollster) StopWaiting(fd int, bits uint) { // event in the kernel. Otherwise, delete it. events &= ^uint32(bits) if int32(events)&^syscall.EPOLLONESHOT != 0 { - var ev syscall.EpollEvent - ev.Fd = int32(fd) - ev.Events = events - if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &ev); e != 0 { + p.ctlEvent.Fd = int32(fd) + p.ctlEvent.Events = events + if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != 0 { print("Epoll modify fd=", fd, ": ", os.Errno(e).String(), "\n") } p.events[fd] = events @@ -98,6 +110,8 @@ func (p *pollster) StopWaiting(fd int, bits uint) { } func (p *pollster) DelFD(fd int, mode int) { + // pollServer is locked. + if mode == 'r' { p.StopWaiting(fd, readFlags) } else { @@ -105,24 +119,32 @@ func (p *pollster) DelFD(fd int, mode int) { } } -func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) { - // Get an event. - var evarray [1]syscall.EpollEvent - ev := &evarray[0] - var msec int = -1 - if nsec > 0 { - msec = int((nsec + 1e6 - 1) / 1e6) - } - n, e := syscall.EpollWait(p.epfd, evarray[0:], msec) - for e == syscall.EAGAIN || e == syscall.EINTR { - n, e = syscall.EpollWait(p.epfd, evarray[0:], msec) - } - if e != 0 { - return -1, 0, os.NewSyscallError("epoll_wait", e) - } - if n == 0 { - return -1, 0, nil +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { + for len(p.waitEvents) == 0 { + var msec int = -1 + if nsec > 0 { + msec = int((nsec + 1e6 - 1) / 1e6) + } + + s.Unlock() + n, e := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec) + s.Lock() + + if e != 0 { + if e == syscall.EAGAIN || e == syscall.EINTR { + continue + } + return -1, 0, os.NewSyscallError("epoll_wait", e) + } + if n == 0 { + return -1, 0, nil + } + p.waitEvents = p.waitEventBuf[0:n] } + + ev := &p.waitEvents[0] + p.waitEvents = p.waitEvents[1:] + fd = int(ev.Fd) if ev.Events&writeFlags != 0 { diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index 63a8fbc44..c2f736cc1 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -225,29 +225,40 @@ type netFD struct { wio sync.Mutex } -func allocFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD) { +func allocFD(fd, family, proto int, net string) (f *netFD) { f = &netFD{ sysfd: fd, family: family, proto: proto, net: net, - laddr: laddr, - raddr: raddr, } runtime.SetFinalizer(f, (*netFD).Close) return f } -func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err os.Error) { +func newFD(fd, family, proto int, net string) (f *netFD, err os.Error) { if initErr != nil { return nil, initErr } onceStartServer.Do(startServer) // Associate our socket with resultsrv.iocp. if _, e := syscall.CreateIoCompletionPort(int32(fd), resultsrv.iocp, 0, 0); e != 0 { - return nil, &OpError{"CreateIoCompletionPort", net, laddr, os.Errno(e)} + return nil, os.Errno(e) + } + return allocFD(fd, family, proto, net), nil +} + +func (fd *netFD) setAddr(laddr, raddr Addr) { + fd.laddr = laddr + fd.raddr = raddr +} + +func (fd *netFD) connect(ra syscall.Sockaddr) (err os.Error) { + e := syscall.Connect(fd.sysfd, ra) + if e != 0 { + return os.Errno(e) } - return allocFD(fd, family, proto, net, laddr, raddr), nil + return nil } // Add a reference to this fd. @@ -497,7 +508,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. lsa, _ := lrsa.Sockaddr() rsa, _ := rrsa.Sockaddr() - return allocFD(s, fd.family, fd.proto, fd.net, toAddr(lsa), toAddr(rsa)), nil + nfd = allocFD(s, fd.family, fd.proto, fd.net) + nfd.setAddr(toAddr(lsa), toAddr(rsa)) + return nfd, nil } // Not implemeted functions. diff --git a/src/pkg/net/file.go b/src/pkg/net/file.go new file mode 100644 index 000000000..0e411a192 --- /dev/null +++ b/src/pkg/net/file.go @@ -0,0 +1,119 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "syscall" +) + +func newFileFD(f *os.File) (nfd *netFD, err os.Error) { + fd, errno := syscall.Dup(f.Fd()) + if errno != 0 { + return nil, os.NewSyscallError("dup", errno) + } + + proto, errno := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE) + if errno != 0 { + return nil, os.NewSyscallError("getsockopt", errno) + } + + toAddr := sockaddrToTCP + sa, _ := syscall.Getsockname(fd) + switch sa.(type) { + default: + closesocket(fd) + return nil, os.EINVAL + case *syscall.SockaddrInet4: + if proto == syscall.SOCK_DGRAM { + toAddr = sockaddrToUDP + } else if proto == syscall.SOCK_RAW { + toAddr = sockaddrToIP + } + case *syscall.SockaddrInet6: + if proto == syscall.SOCK_DGRAM { + toAddr = sockaddrToUDP + } else if proto == syscall.SOCK_RAW { + toAddr = sockaddrToIP + } + case *syscall.SockaddrUnix: + toAddr = sockaddrToUnix + if proto == syscall.SOCK_DGRAM { + toAddr = sockaddrToUnixgram + } else if proto == syscall.SOCK_SEQPACKET { + toAddr = sockaddrToUnixpacket + } + } + laddr := toAddr(sa) + sa, _ = syscall.Getpeername(fd) + raddr := toAddr(sa) + + if nfd, err = newFD(fd, 0, proto, laddr.Network()); err != nil { + return nil, err + } + nfd.setAddr(laddr, raddr) + return nfd, nil +} + +// FileConn returns a copy of the network connection corresponding to +// the open file f. It is the caller's responsibility to close f when +// finished. Closing c does not affect f, and closing f does not +// affect c. +func FileConn(f *os.File) (c Conn, err os.Error) { + fd, err := newFileFD(f) + if err != nil { + return nil, err + } + switch fd.laddr.(type) { + case *TCPAddr: + return newTCPConn(fd), nil + case *UDPAddr: + return newUDPConn(fd), nil + case *UnixAddr: + return newUnixConn(fd), nil + case *IPAddr: + return newIPConn(fd), nil + } + fd.Close() + return nil, os.EINVAL +} + +// FileListener returns a copy of the network listener corresponding +// to the open file f. It is the caller's responsibility to close l +// when finished. Closing c does not affect l, and closing l does not +// affect c. +func FileListener(f *os.File) (l Listener, err os.Error) { + fd, err := newFileFD(f) + if err != nil { + return nil, err + } + switch laddr := fd.laddr.(type) { + case *TCPAddr: + return &TCPListener{fd}, nil + case *UnixAddr: + return &UnixListener{fd, laddr.Name}, nil + } + fd.Close() + return nil, os.EINVAL +} + +// FilePacketConn returns a copy of the packet network connection +// corresponding to the open file f. It is the caller's +// responsibility to close f when finished. Closing c does not affect +// f, and closing f does not affect c. +func FilePacketConn(f *os.File) (c PacketConn, err os.Error) { + fd, err := newFileFD(f) + if err != nil { + return nil, err + } + switch fd.laddr.(type) { + case *UDPAddr: + return newUDPConn(fd), nil + case *UnixAddr: + return newUnixConn(fd), nil + } + fd.Close() + return nil, os.EINVAL +} diff --git a/src/pkg/net/file_test.go b/src/pkg/net/file_test.go new file mode 100644 index 000000000..1ec05fdee --- /dev/null +++ b/src/pkg/net/file_test.go @@ -0,0 +1,131 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "reflect" + "runtime" + "syscall" + "testing" +) + +type listenerFile interface { + Listener + File() (f *os.File, err os.Error) +} + +type packetConnFile interface { + PacketConn + File() (f *os.File, err os.Error) +} + +type connFile interface { + Conn + File() (f *os.File, err os.Error) +} + +func testFileListener(t *testing.T, net, laddr string) { + if net == "tcp" { + laddr += ":0" // any available port + } + l, err := Listen(net, laddr) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer l.Close() + lf := l.(listenerFile) + f, err := lf.File() + if err != nil { + t.Fatalf("File failed: %v", err) + } + c, err := FileListener(f) + if err != nil { + t.Fatalf("FileListener failed: %v", err) + } + if !reflect.DeepEqual(l.Addr(), c.Addr()) { + t.Fatalf("Addrs not equal: %#v != %#v", l.Addr(), c.Addr()) + } + if err := c.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } +} + +func TestFileListener(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + testFileListener(t, "tcp", "127.0.0.1") + testFileListener(t, "tcp", "127.0.0.1") + if kernelSupportsIPv6() { + testFileListener(t, "tcp", "[::ffff:127.0.0.1]") + testFileListener(t, "tcp", "127.0.0.1") + testFileListener(t, "tcp", "[::ffff:127.0.0.1]") + } + if syscall.OS == "linux" { + testFileListener(t, "unix", "@gotest/net") + testFileListener(t, "unixpacket", "@gotest/net") + } +} + +func testFilePacketConn(t *testing.T, pcf packetConnFile) { + f, err := pcf.File() + if err != nil { + t.Fatalf("File failed: %v", err) + } + c, err := FilePacketConn(f) + if err != nil { + t.Fatalf("FilePacketConn failed: %v", err) + } + if !reflect.DeepEqual(pcf.LocalAddr(), c.LocalAddr()) { + t.Fatalf("LocalAddrs not equal: %#v != %#v", pcf.LocalAddr(), c.LocalAddr()) + } + if err := c.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } +} + +func testFilePacketConnListen(t *testing.T, net, laddr string) { + l, err := ListenPacket(net, laddr) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + testFilePacketConn(t, l.(packetConnFile)) + if err := l.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } +} + +func testFilePacketConnDial(t *testing.T, net, raddr string) { + c, err := Dial(net, raddr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + testFilePacketConn(t, c.(packetConnFile)) + if err := c.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } +} + +func TestFilePacketConn(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + testFilePacketConnListen(t, "udp", "127.0.0.1:0") + testFilePacketConnDial(t, "udp", "127.0.0.1:12345") + if kernelSupportsIPv6() { + testFilePacketConnListen(t, "udp", "[::1]:0") + testFilePacketConnDial(t, "udp", "[::ffff:127.0.0.1]:12345") + } + if syscall.OS == "linux" { + testFilePacketConnListen(t, "unixgram", "@gotest1/net") + } +} diff --git a/src/pkg/net/file_windows.go b/src/pkg/net/file_windows.go new file mode 100644 index 000000000..94aa58375 --- /dev/null +++ b/src/pkg/net/file_windows.go @@ -0,0 +1,25 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "syscall" +) + +func FileConn(f *os.File) (c Conn, err os.Error) { + // TODO: Implement this + return nil, os.NewSyscallError("FileConn", syscall.EWINDOWS) +} + +func FileListener(f *os.File) (l Listener, err os.Error) { + // TODO: Implement this + return nil, os.NewSyscallError("FileListener", syscall.EWINDOWS) +} + +func FilePacketConn(f *os.File) (c PacketConn, err os.Error) { + // TODO: Implement this + return nil, os.NewSyscallError("FilePacketConn", syscall.EWINDOWS) +} diff --git a/src/pkg/net/hosts.go b/src/pkg/net/hosts.go index 8525f578d..d75e9e038 100644 --- a/src/pkg/net/hosts.go +++ b/src/pkg/net/hosts.go @@ -59,7 +59,7 @@ func readHosts() { } } -// lookupStaticHosts looks up the addresses for the given host from /etc/hosts. +// lookupStaticHost looks up the addresses for the given host from /etc/hosts. func lookupStaticHost(host string) []string { hosts.Lock() defer hosts.Unlock() @@ -72,7 +72,7 @@ func lookupStaticHost(host string) []string { return nil } -// rlookupStaticHosts looks up the hosts for the given address from /etc/hosts. +// lookupStaticAddr looks up the hosts for the given address from /etc/hosts. func lookupStaticAddr(addr string) []string { hosts.Lock() defer hosts.Unlock() diff --git a/src/pkg/net/hosts_test.go b/src/pkg/net/hosts_test.go index 84cd92e37..470e35f78 100644 --- a/src/pkg/net/hosts_test.go +++ b/src/pkg/net/hosts_test.go @@ -13,7 +13,6 @@ type hostTest struct { ips []IP } - var hosttests = []hostTest{ {"odin", []IP{ IPv4(127, 0, 0, 2), diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go index e82224a28..12bb6f351 100644 --- a/src/pkg/net/ip.go +++ b/src/pkg/net/ip.go @@ -12,6 +12,8 @@ package net +import "os" + // IP address lengths (bytes). const ( IPv4len = 4 @@ -39,11 +41,7 @@ type IPMask []byte // IPv4 address a.b.c.d. func IPv4(a, b, c, d byte) IP { p := make(IP, IPv6len) - for i := 0; i < 10; i++ { - p[i] = 0 - } - p[10] = 0xff - p[11] = 0xff + copy(p, v4InV6Prefix) p[12] = a p[13] = b p[14] = c @@ -51,6 +49,8 @@ func IPv4(a, b, c, d byte) IP { return p } +var v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff} + // IPv4Mask returns the IP mask (in 16-byte form) of the // IPv4 mask a.b.c.d. func IPv4Mask(a, b, c, d byte) IPMask { @@ -140,9 +140,24 @@ func (ip IP) DefaultMask() IPMask { return nil // not reached } +func allFF(b []byte) bool { + for _, c := range b { + if c != 0xff { + return false + } + } + return true +} + // Mask returns the result of masking the IP address ip with mask. func (ip IP) Mask(mask IPMask) IP { n := len(ip) + if len(mask) == 16 && len(ip) == 4 && allFF(mask[:12]) { + mask = mask[12:] + } + if len(mask) == 4 && len(ip) == 16 && bytesEqual(ip[:12], v4InV6Prefix) { + ip = ip[12:] + } if n != len(mask) { return nil } @@ -245,6 +260,34 @@ func (ip IP) String() string { return s } +// Equal returns true if ip and x are the same IP address. +// An IPv4 address and that same address in IPv6 form are +// considered to be equal. +func (ip IP) Equal(x IP) bool { + if len(ip) == len(x) { + return bytesEqual(ip, x) + } + if len(ip) == 4 && len(x) == 16 { + return bytesEqual(x[0:12], v4InV6Prefix) && bytesEqual(ip, x[12:]) + } + if len(ip) == 16 && len(x) == 4 { + return bytesEqual(ip[0:12], v4InV6Prefix) && bytesEqual(ip[12:], x) + } + return false +} + +func bytesEqual(x, y []byte) bool { + if len(x) != len(y) { + return false + } + for i, b := range x { + if y[i] != b { + return false + } + } + return true +} + // If mask is a sequence of 1 bits followed by 0 bits, // return the number of 1 bits. func simpleMaskLength(mask IPMask) int { @@ -351,7 +394,6 @@ func parseIPv6(s string) IP { // Loop, parsing hex numbers followed by colon. j := 0 -L: for j < IPv6len { // Hex number. n, i1, ok := xtoi(s, i) @@ -432,15 +474,79 @@ L: return p } +// A ParseError represents a malformed text string and the type of string that was expected. +type ParseError struct { + Type string + Text string +} + +func (e *ParseError) String() string { + return "invalid " + e.Type + ": " + e.Text +} + +func parseIP(s string) IP { + if p := parseIPv4(s); p != nil { + return p + } + if p := parseIPv6(s); p != nil { + return p + } + return nil +} + // ParseIP parses s as an IP address, returning the result. // The string s can be in dotted decimal ("74.125.19.99") // or IPv6 ("2001:4860:0:2001::68") form. // If s is not a valid textual representation of an IP address, // ParseIP returns nil. func ParseIP(s string) IP { - p := parseIPv4(s) - if p != nil { + if p := parseIPv4(s); p != nil { return p } return parseIPv6(s) } + +// ParseCIDR parses s as a CIDR notation IP address and mask, +// like "192.168.100.1/24", "2001:DB8::/48", as defined in +// RFC 4632 and RFC 4291. +func ParseCIDR(s string) (ip IP, mask IPMask, err os.Error) { + i := byteIndex(s, '/') + if i < 0 { + return nil, nil, &ParseError{"CIDR address", s} + } + ipstr, maskstr := s[:i], s[i+1:] + iplen := 4 + ip = parseIPv4(ipstr) + if ip == nil { + iplen = 16 + ip = parseIPv6(ipstr) + } + nn, i, ok := dtoi(maskstr, 0) + if ip == nil || !ok || i != len(maskstr) || nn < 0 || nn > 8*iplen { + return nil, nil, &ParseError{"CIDR address", s} + } + n := uint(nn) + if iplen == 4 { + v4mask := ^uint32(0xffffffff >> n) + mask = IPv4Mask(byte(v4mask>>24), byte(v4mask>>16), byte(v4mask>>8), byte(v4mask)) + } else { + mask = make(IPMask, 16) + for i := 0; i < 16; i++ { + if n >= 8 { + mask[i] = 0xff + n -= 8 + continue + } + mask[i] = ^byte(0xff >> n) + n = 0 + + } + } + // address must not have any bits not in mask + for i := range ip { + if ip[i]&^mask[i] != 0 { + return nil, nil, &ParseError{"CIDR address", s} + } + } + return ip, mask, nil +} diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go index e29c3021d..f1a4716d2 100644 --- a/src/pkg/net/ip_test.go +++ b/src/pkg/net/ip_test.go @@ -5,30 +5,26 @@ package net import ( + "bytes" + "reflect" "testing" + "os" ) -func isEqual(a, b IP) bool { +func isEqual(a, b []byte) bool { if a == nil && b == nil { return true } - if a == nil || b == nil || len(a) != len(b) { + if a == nil || b == nil { return false } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false - } - } - return true + return bytes.Equal(a, b) } -type parseIPTest struct { +var parseiptests = []struct { in string out IP -} - -var parseiptests = []parseIPTest{ +}{ {"127.0.1.2", IPv4(127, 0, 1, 2)}, {"127.0.0.1", IPv4(127, 0, 0, 1)}, {"127.0.0.256", nil}, @@ -43,20 +39,17 @@ var parseiptests = []parseIPTest{ } func TestParseIP(t *testing.T) { - for i := 0; i < len(parseiptests); i++ { - tt := parseiptests[i] + for _, tt := range parseiptests { if out := ParseIP(tt.in); !isEqual(out, tt.out) { t.Errorf("ParseIP(%#q) = %v, want %v", tt.in, out, tt.out) } } } -type ipStringTest struct { +var ipstringtests = []struct { in IP out string -} - -var ipstringtests = []ipStringTest{ +}{ // cf. RFC 5952 (A Recommendation for IPv6 Address Text Representation) {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0x1, 0x23, 0, 0x12, 0, 0x1}, @@ -85,10 +78,67 @@ var ipstringtests = []ipStringTest{ } func TestIPString(t *testing.T) { - for i := 0; i < len(ipstringtests); i++ { - tt := ipstringtests[i] + for _, tt := range ipstringtests { if out := tt.in.String(); out != tt.out { t.Errorf("IP.String(%v) = %#q, want %#q", tt.in, out, tt.out) } } } + +var parsecidrtests = []struct { + in string + ip IP + mask IPMask + err os.Error +}{ + {"135.104.0.0/32", IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 255), nil}, + {"0.0.0.0/24", IPv4(0, 0, 0, 0), IPv4Mask(255, 255, 255, 0), nil}, + {"135.104.0.0/24", IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 0), nil}, + {"135.104.0.1/32", IPv4(135, 104, 0, 1), IPv4Mask(255, 255, 255, 255), nil}, + {"135.104.0.1/24", nil, nil, &ParseError{"CIDR address", "135.104.0.1/24"}}, + {"::1/128", ParseIP("::1"), IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")), nil}, + {"abcd:2345::/127", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe")), nil}, + {"abcd:2345::/65", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff:8000::")), nil}, + {"abcd:2345::/64", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff::")), nil}, + {"abcd:2345::/63", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:fffe::")), nil}, + {"abcd:2345::/33", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:8000::")), nil}, + {"abcd:2345::/32", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff::")), nil}, + {"abcd:2344::/31", ParseIP("abcd:2344::"), IPMask(ParseIP("ffff:fffe::")), nil}, + {"abcd:2300::/24", ParseIP("abcd:2300::"), IPMask(ParseIP("ffff:ff00::")), nil}, + {"abcd:2345::/24", nil, nil, &ParseError{"CIDR address", "abcd:2345::/24"}}, + {"2001:DB8::/48", ParseIP("2001:DB8::"), IPMask(ParseIP("ffff:ffff:ffff::")), nil}, +} + +func TestParseCIDR(t *testing.T) { + for _, tt := range parsecidrtests { + if ip, mask, err := ParseCIDR(tt.in); !isEqual(ip, tt.ip) || !isEqual(mask, tt.mask) || !reflect.DeepEqual(err, tt.err) { + t.Errorf("ParseCIDR(%q) = %v, %v, %v; want %v, %v, %v", tt.in, ip, mask, err, tt.ip, tt.mask, tt.err) + } + } +} + +var splitjointests = []struct { + Host string + Port string + Join string +}{ + {"www.google.com", "80", "www.google.com:80"}, + {"127.0.0.1", "1234", "127.0.0.1:1234"}, + {"::1", "80", "[::1]:80"}, +} + +func TestSplitHostPort(t *testing.T) { + for _, tt := range splitjointests { + if host, port, err := SplitHostPort(tt.Join); host != tt.Host || port != tt.Port || err != nil { + t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.Join, host, port, err, tt.Host, tt.Port) + } + } +} + +func TestJoinHostPort(t *testing.T) { + for _, tt := range splitjointests { + if join := JoinHostPort(tt.Host, tt.Port); join != tt.Join { + t.Errorf("JoinHostPort(%q, %q) = %q; want %q", tt.Host, tt.Port, join, tt.Join) + } + } +} diff --git a/src/pkg/net/ipraw_test.go b/src/pkg/net/ipraw_test.go index 562298bdf..ee8c71fc1 100644 --- a/src/pkg/net/ipraw_test.go +++ b/src/pkg/net/ipraw_test.go @@ -69,9 +69,12 @@ func TestICMP(t *testing.T) { return } - var laddr *IPAddr + var ( + laddr *IPAddr + err os.Error + ) if *srchost != "" { - laddr, err := ResolveIPAddr(*srchost) + laddr, err = ResolveIPAddr(*srchost) if err != nil { t.Fatalf(`net.ResolveIPAddr("%v") = %v, %v`, *srchost, laddr, err) } diff --git a/src/pkg/net/iprawsock.go b/src/pkg/net/iprawsock.go index 81a918ce5..60433303a 100644 --- a/src/pkg/net/iprawsock.go +++ b/src/pkg/net/iprawsock.go @@ -240,7 +240,7 @@ func hostToIP(host string) (ip IP, err os.Error) { addr = ParseIP(host) if addr == nil { // Not an IP address. Try as a DNS name. - _, addrs, err1 := LookupHost(host) + addrs, err1 := LookupHost(host) if err1 != nil { err = err1 goto Error diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go index ae4204b48..80bc3eea5 100644 --- a/src/pkg/net/ipsock.go +++ b/src/pkg/net/ipsock.go @@ -170,9 +170,10 @@ func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, os.Error) { return nil, InvalidAddrError("unexpected socket family") } -// Split "host:port" into "host" and "port". -// Host cannot contain colons unless it is bracketed. -func splitHostPort(hostport string) (host, port string, err os.Error) { +// SplitHostPort splits a network address of the form +// "host:port" or "[host]:port" into host and port. +// The latter form must be used when host contains a colon. +func SplitHostPort(hostport string) (host, port string, err os.Error) { // The port starts after the last colon. i := last(hostport, ':') if i < 0 { @@ -195,9 +196,9 @@ func splitHostPort(hostport string) (host, port string, err os.Error) { return } -// Join "host" and "port" into "host:port". -// If host contains colons, will join into "[host]:port". -func joinHostPort(host, port string) string { +// JoinHostPort combines host and port into a network address +// of the form "host:port" or, if host contains a colon, "[host]:port". +func JoinHostPort(host, port string) string { // If host has colons, have to bracket it. if byteIndex(host, ':') >= 0 { return "[" + host + "]:" + port @@ -207,7 +208,7 @@ func joinHostPort(host, port string) string { // Convert "host:port" into IP address and port. func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) { - host, port, err := splitHostPort(hostport) + host, port, err := SplitHostPort(hostport) if err != nil { goto Error } @@ -218,7 +219,7 @@ func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) { addr = ParseIP(host) if addr == nil { // Not an IP address. Try as a DNS name. - _, addrs, err1 := LookupHost(host) + addrs, err1 := LookupHost(host) if err1 != nil { err = err1 goto Error diff --git a/src/pkg/net/lookup.go b/src/pkg/net/lookup.go new file mode 100644 index 000000000..7b2185ed4 --- /dev/null +++ b/src/pkg/net/lookup.go @@ -0,0 +1,38 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" +) + +// LookupHost looks up the given host using the local resolver. +// It returns an array of that host's addresses. +func LookupHost(host string) (addrs []string, err os.Error) { + addrs, err, ok := cgoLookupHost(host) + if !ok { + addrs, err = goLookupHost(host) + } + return +} + +// LookupIP looks up host using the local resolver. +// It returns an array of that host's IPv4 and IPv6 addresses. +func LookupIP(host string) (addrs []IP, err os.Error) { + addrs, err, ok := cgoLookupIP(host) + if !ok { + addrs, err = goLookupIP(host) + } + return +} + +// LookupPort looks up the port for the given network and service. +func LookupPort(network, service string) (port int, err os.Error) { + port, err, ok := cgoLookupPort(network, service) + if !ok { + port, err = goLookupPort(network, service) + } + return +} diff --git a/src/pkg/net/multicast_test.go b/src/pkg/net/multicast_test.go index 32fdec85b..be6dbf2dc 100644 --- a/src/pkg/net/multicast_test.go +++ b/src/pkg/net/multicast_test.go @@ -5,14 +5,21 @@ package net import ( + "flag" "runtime" "testing" ) +var multicast = flag.Bool("multicast", false, "enable multicast tests") + func TestMulticastJoinAndLeave(t *testing.T) { if runtime.GOOS == "windows" { return } + if !*multicast { + t.Logf("test disabled; use --multicast to enable") + return + } addr := &UDPAddr{ IP: IPv4zero, @@ -40,6 +47,10 @@ func TestMulticastJoinAndLeave(t *testing.T) { } func TestJoinFailureWithIPv6Address(t *testing.T) { + if !*multicast { + t.Logf("test disabled; use --multicast to enable") + return + } addr := &UDPAddr{ IP: IPv4zero, Port: 0, diff --git a/src/pkg/net/net_test.go b/src/pkg/net/net_test.go index 1e6e99eec..f7eae56fe 100644 --- a/src/pkg/net/net_test.go +++ b/src/pkg/net/net_test.go @@ -15,50 +15,49 @@ var runErrorTest = flag.Bool("run_error_test", false, "let TestDialError check f type DialErrorTest struct { Net string - Laddr string Raddr string Pattern string } var dialErrorTests = []DialErrorTest{ { - "datakit", "", "mh/astro/r70", + "datakit", "mh/astro/r70", "dial datakit mh/astro/r70: unknown network datakit", }, { - "tcp", "", "127.0.0.1:☺", + "tcp", "127.0.0.1:☺", "dial tcp 127.0.0.1:☺: unknown port tcp/☺", }, { - "tcp", "", "no-such-name.google.com.:80", + "tcp", "no-such-name.google.com.:80", "dial tcp no-such-name.google.com.:80: lookup no-such-name.google.com.( on .*)?: no (.*)", }, { - "tcp", "", "no-such-name.no-such-top-level-domain.:80", + "tcp", "no-such-name.no-such-top-level-domain.:80", "dial tcp no-such-name.no-such-top-level-domain.:80: lookup no-such-name.no-such-top-level-domain.( on .*)?: no (.*)", }, { - "tcp", "", "no-such-name:80", + "tcp", "no-such-name:80", `dial tcp no-such-name:80: lookup no-such-name\.(.*\.)?( on .*)?: no (.*)`, }, { - "tcp", "", "mh/astro/r70:http", + "tcp", "mh/astro/r70:http", "dial tcp mh/astro/r70:http: lookup mh/astro/r70: invalid domain name", }, { - "unix", "", "/etc/file-not-found", + "unix", "/etc/file-not-found", "dial unix /etc/file-not-found: no such file or directory", }, { - "unix", "", "/etc/", + "unix", "/etc/", "dial unix /etc/: (permission denied|socket operation on non-socket|connection refused)", }, { - "unixpacket", "", "/etc/file-not-found", + "unixpacket", "/etc/file-not-found", "dial unixpacket /etc/file-not-found: no such file or directory", }, { - "unixpacket", "", "/etc/", + "unixpacket", "/etc/", "dial unixpacket /etc/: (permission denied|socket operation on non-socket|connection refused)", }, } @@ -69,7 +68,7 @@ func TestDialError(t *testing.T) { return } for i, tt := range dialErrorTests { - c, e := Dial(tt.Net, tt.Laddr, tt.Raddr) + c, e := Dial(tt.Net, tt.Raddr) if c != nil { c.Close() } diff --git a/src/pkg/net/newpollserver.go b/src/pkg/net/newpollserver.go index 820e70b46..fff54dba7 100644 --- a/src/pkg/net/newpollserver.go +++ b/src/pkg/net/newpollserver.go @@ -31,7 +31,7 @@ func newPollServer() (s *pollServer, err os.Error) { if s.poll, err = newpollster(); err != nil { goto Error } - if err = s.poll.AddFD(s.pr.Fd(), 'r', true); err != nil { + if _, err = s.poll.AddFD(s.pr.Fd(), 'r', true); err != nil { s.poll.Close() goto Error } diff --git a/src/pkg/net/parse.go b/src/pkg/net/parse.go index 2bc0db465..de46830d2 100644 --- a/src/pkg/net/parse.go +++ b/src/pkg/net/parse.go @@ -63,7 +63,7 @@ func (f *file) readLine() (s string, ok bool) { } func open(name string) (*file, os.Error) { - fd, err := os.Open(name, os.O_RDONLY, 0) + fd, err := os.Open(name) if err != nil { return nil, err } diff --git a/src/pkg/net/parse_test.go b/src/pkg/net/parse_test.go index 2b7784eee..226f354d3 100644 --- a/src/pkg/net/parse_test.go +++ b/src/pkg/net/parse_test.go @@ -18,7 +18,7 @@ func TestReadLine(t *testing.T) { } filename := "/etc/services" // a nice big file - fd, err := os.Open(filename, os.O_RDONLY, 0) + fd, err := os.Open(filename) if err != nil { t.Fatalf("open %s: %v", filename, err) } diff --git a/src/pkg/net/port.go b/src/pkg/net/port.go index 7d25058b2..8f8327a37 100644 --- a/src/pkg/net/port.go +++ b/src/pkg/net/port.go @@ -50,8 +50,8 @@ func readServices() { file.close() } -// LookupPort looks up the port for the given network and service. -func LookupPort(network, service string) (port int, err os.Error) { +// goLookupPort is the native Go implementation of LookupPort. +func goLookupPort(network, service string) (port int, err os.Error) { onceReadServices.Do(readServices) switch network { diff --git a/src/pkg/net/port_test.go b/src/pkg/net/port_test.go index 1b7eaf231..329b169f3 100644 --- a/src/pkg/net/port_test.go +++ b/src/pkg/net/port_test.go @@ -27,9 +27,7 @@ var porttests = []portTest{ {"tcp", "smtp", 25, true}, {"tcp", "time", 37, true}, {"tcp", "domain", 53, true}, - {"tcp", "gopher", 70, true}, {"tcp", "finger", 79, true}, - {"tcp", "http", 80, true}, {"udp", "echo", 7, true}, {"udp", "tftp", 69, true}, diff --git a/src/pkg/net/resolv_windows.go b/src/pkg/net/resolv_windows.go index f3d854ff2..000c30659 100644 --- a/src/pkg/net/resolv_windows.go +++ b/src/pkg/net/resolv_windows.go @@ -14,26 +14,51 @@ import ( var hostentLock sync.Mutex var serventLock sync.Mutex -func LookupHost(name string) (cname string, addrs []string, err os.Error) { +func goLookupHost(name string) (addrs []string, err os.Error) { + ips, err := goLookupIP(name) + if err != nil { + return + } + addrs = make([]string, 0, len(ips)) + for _, ip := range ips { + addrs = append(addrs, ip.String()) + } + return +} + +func goLookupIP(name string) (addrs []IP, err os.Error) { hostentLock.Lock() defer hostentLock.Unlock() h, e := syscall.GetHostByName(name) if e != 0 { - return "", nil, os.NewSyscallError("GetHostByName", e) + return nil, os.NewSyscallError("GetHostByName", e) } - cname = name switch h.AddrType { case syscall.AF_INET: i := 0 - addrs = make([]string, 100) // plenty of room to grow + addrs = make([]IP, 100) // plenty of room to grow for p := (*[100](*[4]byte))(unsafe.Pointer(h.AddrList)); i < cap(addrs) && p[i] != nil; i++ { - addrs[i] = IPv4(p[i][0], p[i][1], p[i][2], p[i][3]).String() + addrs[i] = IPv4(p[i][0], p[i][1], p[i][2], p[i][3]) } addrs = addrs[0:i] default: // TODO(vcc): Implement non IPv4 address lookups. - return "", nil, os.NewSyscallError("LookupHost", syscall.EWINDOWS) + return nil, os.NewSyscallError("LookupHost", syscall.EWINDOWS) + } + return addrs, nil +} + +func LookupCNAME(name string) (cname string, err os.Error) { + var r *syscall.DNSRecord + e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) + if int(e) != 0 { + return "", os.NewSyscallError("LookupCNAME", int(e)) + } + defer syscall.DnsRecordListFree(r, 1) + if r != nil && r.Type == syscall.DNS_TYPE_CNAME { + v := (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])) + cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "." } - return cname, addrs, nil + return } type SRV struct { @@ -62,7 +87,7 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os. return name, addrs, nil } -func LookupPort(network, service string) (port int, err os.Error) { +func goLookupPort(network, service string) (port int, err os.Error) { switch network { case "tcp4", "tcp6": network = "tcp" diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go index 3dda500e5..37695a068 100644 --- a/src/pkg/net/server_test.go +++ b/src/pkg/net/server_test.go @@ -54,13 +54,15 @@ func runServe(t *testing.T, network, addr string, listening chan<- string, done } func connect(t *testing.T, network, addr string, isEmpty bool) { - var laddr string + var fd Conn + var err os.Error if network == "unixgram" { - laddr = addr + ".local" + fd, err = DialUnix(network, &UnixAddr{addr + ".local", network}, &UnixAddr{addr, network}) + } else { + fd, err = Dial(network, addr) } - fd, err := Dial(network, laddr, addr) if err != nil { - t.Fatalf("net.Dial(%q, %q, %q) = _, %v", network, laddr, addr, err) + t.Fatalf("net.Dial(%q, %q) = _, %v", network, addr, err) } fd.SetReadTimeout(1e9) // 1s diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go index 8ad3548ad..933700af1 100644 --- a/src/pkg/net/sock.go +++ b/src/pkg/net/sock.go @@ -52,11 +52,16 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal } } + if fd, err = newFD(s, f, p, net); err != nil { + closesocket(s) + return nil, err + } + if ra != nil { - e = syscall.Connect(s, ra) - if e != 0 { + if err = fd.connect(ra); err != nil { + fd.sysfd = -1 closesocket(s) - return nil, os.Errno(e) + return nil, err } } @@ -65,12 +70,7 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal sa, _ = syscall.Getpeername(s) raddr := toAddr(sa) - fd, err = newFD(s, f, p, net, laddr, raddr) - if err != nil { - closesocket(s) - return nil, err - } - + fd.setAddr(laddr, raddr) return fd, nil } @@ -167,9 +167,9 @@ func (e *UnknownSocketError) String() string { func sockaddrToString(sa syscall.Sockaddr) (name string, err os.Error) { switch a := sa.(type) { case *syscall.SockaddrInet4: - return joinHostPort(IP(a.Addr[0:]).String(), itoa(a.Port)), nil + return JoinHostPort(IP(a.Addr[0:]).String(), itoa(a.Port)), nil case *syscall.SockaddrInet6: - return joinHostPort(IP(a.Addr[0:]).String(), itoa(a.Port)), nil + return JoinHostPort(IP(a.Addr[0:]).String(), itoa(a.Port)), nil case *syscall.SockaddrUnix: return a.Name, nil } diff --git a/src/pkg/net/tcpsock.go b/src/pkg/net/tcpsock.go index a4bca11bb..b484be20b 100644 --- a/src/pkg/net/tcpsock.go +++ b/src/pkg/net/tcpsock.go @@ -34,7 +34,7 @@ func (a *TCPAddr) String() string { if a == nil { return "<nil>" } - return joinHostPort(a.IP.String(), itoa(a.Port)) + return JoinHostPort(a.IP.String(), itoa(a.Port)) } func (a *TCPAddr) family() int { @@ -213,8 +213,9 @@ func (c *TCPConn) SetNoDelay(noDelay bool) os.Error { // Closing c does not affect f, and closing f does not affect c. func (c *TCPConn) File() (f *os.File, err os.Error) { return c.fd.dup() } -// DialTCP is like Dial but can only connect to TCP networks -// and returns a TCPConn structure. +// DialTCP connects to the remote address raddr on the network net, +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used +// as the local address for the connection. func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err os.Error) { if raddr == nil { return nil, &OpError{"dial", "tcp", nil, errMissingAddress} diff --git a/src/pkg/net/textproto/textproto.go b/src/pkg/net/textproto/textproto.go index f62009c52..fbfad9d61 100644 --- a/src/pkg/net/textproto/textproto.go +++ b/src/pkg/net/textproto/textproto.go @@ -78,7 +78,7 @@ func (c *Conn) Close() os.Error { // Dial connects to the given address on the given network using net.Dial // and then returns a new Conn for the connection. func Dial(network, addr string) (*Conn, os.Error) { - c, err := net.Dial(network, "", addr) + c, err := net.Dial(network, addr) if err != nil { return nil, err } diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go index 09a257dc8..0dbab5846 100644 --- a/src/pkg/net/timeout_test.go +++ b/src/pkg/net/timeout_test.go @@ -11,7 +11,7 @@ import ( ) func testTimeout(t *testing.T, network, addr string, readFrom bool) { - fd, err := Dial(network, "", addr) + fd, err := Dial(network, addr) if err != nil { t.Errorf("dial %s %s failed: %v", network, addr, err) return diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go index f9274493e..44d618dab 100644 --- a/src/pkg/net/udpsock.go +++ b/src/pkg/net/udpsock.go @@ -34,7 +34,7 @@ func (a *UDPAddr) String() string { if a == nil { return "<nil>" } - return joinHostPort(a.IP.String(), itoa(a.Port)) + return JoinHostPort(a.IP.String(), itoa(a.Port)) } func (a *UDPAddr) family() int { |