diff options
author | Ondřej Surý <ondrej@sury.org> | 2011-09-13 13:13:40 +0200 |
---|---|---|
committer | Ondřej Surý <ondrej@sury.org> | 2011-09-13 13:13:40 +0200 |
commit | 5ff4c17907d5b19510a62e08fd8d3b11e62b431d (patch) | |
tree | c0650497e988f47be9c6f2324fa692a52dea82e1 /src/pkg/net | |
parent | 80f18fc933cf3f3e829c5455a1023d69f7b86e52 (diff) | |
download | golang-upstream/60.tar.gz |
Imported Upstream version 60upstream/60
Diffstat (limited to 'src/pkg/net')
86 files changed, 12459 insertions, 0 deletions
diff --git a/src/pkg/net/Makefile b/src/pkg/net/Makefile new file mode 100644 index 000000000..eba9e26d9 --- /dev/null +++ b/src/pkg/net/Makefile @@ -0,0 +1,157 @@ +# Copyright 2009 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../Make.inc + +TARG=net +GOFILES=\ + dial.go\ + dnsclient.go\ + dnsmsg.go\ + hosts.go\ + interface.go\ + ip.go\ + iprawsock.go\ + ipsock.go\ + net.go\ + parse.go\ + pipe.go\ + tcpsock.go\ + udpsock.go\ + unixsock.go\ + +GOFILES_freebsd=\ + dnsclient_unix.go\ + dnsconfig.go\ + fd.go\ + fd_$(GOOS).go\ + file.go\ + interface_bsd.go\ + interface_freebsd.go\ + iprawsock_posix.go\ + ipsock_posix.go\ + lookup_unix.go\ + newpollserver.go\ + port.go\ + sendfile_stub.go\ + sock.go\ + sock_bsd.go\ + tcpsock_posix.go\ + udpsock_posix.go\ + unixsock_posix.go\ + +ifeq ($(CGO_ENABLED),1) +CGOFILES_freebsd=\ + cgo_bsd.go\ + cgo_unix.go +else +GOFILES_freebsd+=cgo_stub.go +endif + +GOFILES_darwin=\ + dnsclient_unix.go\ + dnsconfig.go\ + fd.go\ + fd_$(GOOS).go\ + file.go\ + interface_bsd.go\ + interface_darwin.go\ + iprawsock_posix.go\ + ipsock_posix.go\ + lookup_unix.go\ + newpollserver.go\ + port.go\ + sendfile_stub.go\ + sock.go\ + sock_bsd.go\ + tcpsock_posix.go\ + udpsock_posix.go\ + unixsock_posix.go\ + +ifeq ($(CGO_ENABLED),1) +CGOFILES_darwin=\ + cgo_bsd.go\ + cgo_unix.go +else +GOFILES_darwin+=cgo_stub.go +endif + +GOFILES_linux=\ + dnsclient_unix.go\ + dnsconfig.go\ + fd.go\ + fd_$(GOOS).go\ + file.go\ + interface_linux.go\ + iprawsock_posix.go\ + ipsock_posix.go\ + lookup_unix.go\ + newpollserver.go\ + port.go\ + sendfile_linux.go\ + sock.go\ + sock_linux.go\ + tcpsock_posix.go\ + udpsock_posix.go\ + unixsock_posix.go\ + +ifeq ($(CGO_ENABLED),1) +CGOFILES_linux=\ + cgo_linux.go\ + cgo_unix.go +else +GOFILES_linux+=cgo_stub.go +endif + +GOFILES_openbsd=\ + dnsclient_unix.go\ + dnsconfig.go\ + fd.go\ + fd_$(GOOS).go\ + file.go\ + interface_bsd.go\ + interface_openbsd.go\ + iprawsock_posix.go\ + ipsock_posix.go\ + lookup_unix.go\ + newpollserver.go\ + port.go\ + sendfile_stub.go\ + sock.go\ + sock_bsd.go\ + tcpsock_posix.go\ + udpsock_posix.go\ + unixsock_posix.go\ + cgo_stub.go\ + +GOFILES_plan9=\ + file_plan9.go\ + interface_stub.go\ + iprawsock_plan9.go\ + ipsock_plan9.go\ + lookup_plan9.go\ + tcpsock_plan9.go\ + udpsock_plan9.go\ + unixsock_plan9.go\ + +GOFILES_windows=\ + fd_$(GOOS).go\ + file_windows.go\ + interface_windows.go\ + iprawsock_posix.go\ + ipsock_posix.go\ + lookup_windows.go\ + sendfile_windows.go\ + sock.go\ + sock_windows.go\ + tcpsock_posix.go\ + udpsock_posix.go\ + unixsock_posix.go\ + +GOFILES+=$(GOFILES_$(GOOS)) +ifneq ($(CGOFILES_$(GOOS)),) +CGOFILES+=$(CGOFILES_$(GOOS)) +endif + +include ../../Make.pkg diff --git a/src/pkg/net/cgo_bsd.go b/src/pkg/net/cgo_bsd.go new file mode 100644 index 000000000..4984df4a2 --- /dev/null +++ b/src/pkg/net/cgo_bsd.go @@ -0,0 +1,14 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <netdb.h> +*/ +import "C" + +func cgoAddrInfoMask() C.int { + return C.AI_MASK +} diff --git a/src/pkg/net/cgo_linux.go b/src/pkg/net/cgo_linux.go new file mode 100644 index 000000000..8d4413d2d --- /dev/null +++ b/src/pkg/net/cgo_linux.go @@ -0,0 +1,14 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <netdb.h> +*/ +import "C" + +func cgoAddrInfoMask() C.int { + return C.AI_CANONNAME | C.AI_V4MAPPED | C.AI_ALL +} diff --git a/src/pkg/net/cgo_stub.go b/src/pkg/net/cgo_stub.go new file mode 100644 index 000000000..c6277cb65 --- /dev/null +++ b/src/pkg/net/cgo_stub.go @@ -0,0 +1,25 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Stub cgo routines for systems that do not use cgo to do network lookups. + +package net + +import "os" + +func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) { + return nil, nil, false +} + +func cgoLookupPort(network, service string) (port int, err os.Error, completed bool) { + return 0, nil, false +} + +func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { + return nil, nil, false +} + +func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) { + return "", nil, false +} diff --git a/src/pkg/net/cgo_unix.go b/src/pkg/net/cgo_unix.go new file mode 100644 index 000000000..a3711d601 --- /dev/null +++ b/src/pkg/net/cgo_unix.go @@ -0,0 +1,148 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <sys/types.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <netdb.h> +#include <stdlib.h> +#include <unistd.h> +#include <string.h> +*/ +import "C" + +import ( + "os" + "syscall" + "unsafe" +) + +func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) { + ip, err, completed := cgoLookupIP(name) + for _, p := range ip { + addrs = append(addrs, p.String()) + } + return +} + +func cgoLookupPort(net, service string) (port int, err os.Error, completed bool) { + var res *C.struct_addrinfo + var hints C.struct_addrinfo + + switch net { + case "": + // no hints + case "tcp", "tcp4", "tcp6": + hints.ai_socktype = C.SOCK_STREAM + hints.ai_protocol = C.IPPROTO_TCP + case "udp", "udp4", "udp6": + hints.ai_socktype = C.SOCK_DGRAM + hints.ai_protocol = C.IPPROTO_UDP + default: + return 0, UnknownNetworkError(net), true + } + if len(net) >= 4 { + switch net[3] { + case '4': + hints.ai_family = C.AF_INET + case '6': + hints.ai_family = C.AF_INET6 + } + } + + s := C.CString(service) + defer C.free(unsafe.Pointer(s)) + if C.getaddrinfo(nil, s, &hints, &res) == 0 { + defer C.freeaddrinfo(res) + for r := res; r != nil; r = r.ai_next { + switch r.ai_family { + default: + continue + case C.AF_INET: + sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.ai_addr)) + p := (*[2]byte)(unsafe.Pointer(&sa.Port)) + return int(p[0])<<8 | int(p[1]), nil, true + case C.AF_INET6: + sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.ai_addr)) + p := (*[2]byte)(unsafe.Pointer(&sa.Port)) + return int(p[0])<<8 | int(p[1]), nil, true + } + } + } + return 0, &AddrError{"unknown port", net + "/" + service}, true +} + +func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err os.Error, completed bool) { + var res *C.struct_addrinfo + var hints C.struct_addrinfo + + // NOTE(rsc): In theory there are approximately balanced + // arguments for and against including AI_ADDRCONFIG + // in the flags (it includes IPv4 results only on IPv4 systems, + // and similarly for IPv6), but in practice setting it causes + // getaddrinfo to return the wrong canonical name on Linux. + // So definitely leave it out. + hints.ai_flags = (C.AI_ALL | C.AI_V4MAPPED | C.AI_CANONNAME) & cgoAddrInfoMask() + + h := C.CString(name) + defer C.free(unsafe.Pointer(h)) + gerrno, err := C.getaddrinfo(h, nil, &hints, &res) + if gerrno != 0 { + var str string + if gerrno == C.EAI_NONAME { + str = noSuchHost + } else if gerrno == C.EAI_SYSTEM { + str = err.String() + } else { + str = C.GoString(C.gai_strerror(gerrno)) + } + return nil, "", &DNSError{Error: str, Name: name}, true + } + defer C.freeaddrinfo(res) + if res != nil { + cname = C.GoString(res.ai_canonname) + if cname == "" { + cname = name + } + if len(cname) > 0 && cname[len(cname)-1] != '.' { + cname += "." + } + } + for r := res; r != nil; r = r.ai_next { + // Everything comes back twice, once for UDP and once for TCP. + if r.ai_socktype != C.SOCK_STREAM { + continue + } + switch r.ai_family { + default: + continue + case C.AF_INET: + sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.ai_addr)) + addrs = append(addrs, copyIP(sa.Addr[:])) + case C.AF_INET6: + sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.ai_addr)) + addrs = append(addrs, copyIP(sa.Addr[:])) + } + } + return addrs, cname, nil, true +} + +func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { + addrs, _, err, completed = cgoLookupIPCNAME(name) + return +} + +func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) { + _, cname, err, completed = cgoLookupIPCNAME(name) + return +} + +func copyIP(x IP) IP { + y := make(IP, len(x)) + copy(y, x) + return y +} diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go new file mode 100644 index 000000000..10c67dcc4 --- /dev/null +++ b/src/pkg/net/dial.go @@ -0,0 +1,153 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import "os" + +func resolveNetAddr(op, net, addr string) (a Addr, err os.Error) { + if addr == "" { + return nil, &OpError{op, net, nil, errMissingAddress} + } + switch net { + case "tcp", "tcp4", "tcp6": + a, err = ResolveTCPAddr(net, addr) + case "udp", "udp4", "udp6": + a, err = ResolveUDPAddr(net, addr) + case "unix", "unixgram", "unixpacket": + a, err = ResolveUnixAddr(net, addr) + case "ip", "ip4", "ip6": + a, err = ResolveIPAddr(net, addr) + default: + err = UnknownNetworkError(net) + } + if err != nil { + return nil, &OpError{op, net + " " + addr, nil, err} + } + return +} + +// Dial connects to the address addr on the network net. +// +// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), +// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" +// (IPv4-only), "ip6" (IPv6-only), "unix" and "unixgram". +// +// For IP networks, addresses have the form host:port. If host is +// a literal IPv6 address, it must be enclosed in square brackets. +// The functions JoinHostPort and SplitHostPort manipulate +// addresses in this form. +// +// Examples: +// Dial("tcp", "12.34.56.78:80") +// Dial("tcp", "google.com:80") +// Dial("tcp", "[de:ad:be:ef::ca:fe]:80") +// +func Dial(net, addr string) (c Conn, err os.Error) { + addri, err := resolveNetAddr("dial", net, addr) + if err != nil { + return nil, err + } + switch ra := addri.(type) { + case *TCPAddr: + c, err = DialTCP(net, nil, ra) + case *UDPAddr: + c, err = DialUDP(net, nil, ra) + case *UnixAddr: + c, err = DialUnix(net, nil, ra) + case *IPAddr: + c, err = DialIP(net, nil, ra) + default: + err = UnknownNetworkError(net) + } + if err != nil { + return nil, &OpError{"dial", net + " " + addr, nil, err} + } + return +} + +// Listen announces on the local network address laddr. +// The network string net must be a stream-oriented +// network: "tcp", "tcp4", "tcp6", or "unix", or "unixpacket". +func Listen(net, laddr string) (l Listener, err os.Error) { + switch net { + case "tcp", "tcp4", "tcp6": + var la *TCPAddr + if laddr != "" { + if la, err = ResolveTCPAddr(net, laddr); err != nil { + return nil, err + } + } + l, err := ListenTCP(net, la) + if err != nil { + return nil, err + } + return l, nil + case "unix", "unixpacket": + var la *UnixAddr + if laddr != "" { + if la, err = ResolveUnixAddr(net, laddr); err != nil { + return nil, err + } + } + l, err := ListenUnix(net, la) + if err != nil { + return nil, err + } + return l, nil + } + return nil, UnknownNetworkError(net) +} + +// ListenPacket announces on the local network address laddr. +// The network string net must be a packet-oriented network: +// "udp", "udp4", "udp6", or "unixgram". +func ListenPacket(net, laddr string) (c PacketConn, err os.Error) { + switch net { + case "udp", "udp4", "udp6": + var la *UDPAddr + if laddr != "" { + if la, err = ResolveUDPAddr(net, laddr); err != nil { + return nil, err + } + } + c, err := ListenUDP(net, la) + if err != nil { + return nil, err + } + return c, nil + case "unixgram": + var la *UnixAddr + if laddr != "" { + if la, err = ResolveUnixAddr(net, laddr); err != nil { + return nil, err + } + } + c, err := DialUnix(net, la, nil) + if err != nil { + return nil, err + } + return c, nil + } + + var rawnet string + if rawnet, _, err = splitNetProto(net); err != nil { + switch rawnet { + case "ip", "ip4", "ip6": + var la *IPAddr + if laddr != "" { + if la, err = ResolveIPAddr(rawnet, laddr); err != nil { + return nil, err + } + } + c, err := ListenIP(net, la) + if err != nil { + return nil, err + } + return c, nil + } + } + + return nil, UnknownNetworkError(net) +} diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go new file mode 100644 index 000000000..9ad1770da --- /dev/null +++ b/src/pkg/net/dialgoogle_test.go @@ -0,0 +1,169 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "flag" + "fmt" + "io" + "strings" + "syscall" + "testing" +) + +// If an IPv6 tunnel is running, we can try dialing a real IPv6 address. +var ipv6 = flag.Bool("ipv6", false, "assume ipv6 tunnel is present") + +// fd is already connected to the destination, port 80. +// Run an HTTP request to fetch the appropriate page. +func fetchGoogle(t *testing.T, fd Conn, network, addr string) { + req := []byte("GET /intl/en/privacy/ HTTP/1.0\r\nHost: www.google.com\r\n\r\n") + n, err := fd.Write(req) + + buf := make([]byte, 1000) + n, err = io.ReadFull(fd, buf) + + if n < 1000 { + t.Errorf("fetchGoogle: short HTTP read from %s %s - %v", network, addr, err) + return + } +} + +func doDial(t *testing.T, network, addr string) { + fd, err := Dial(network, addr) + if err != nil { + t.Errorf("Dial(%q, %q, %q) = _, %v", network, "", addr, err) + return + } + fetchGoogle(t, fd, network, addr) + fd.Close() +} + +func TestLookupCNAME(t *testing.T) { + if testing.Short() { + // Don't use external network. + t.Logf("skipping external network test during -short") + return + } + cname, err := LookupCNAME("www.google.com") + if !strings.HasSuffix(cname, ".l.google.com.") || err != nil { + t.Errorf(`LookupCNAME("www.google.com.") = %q, %v, want "*.l.google.com.", nil`, cname, err) + } +} + +var googleaddrsipv4 = []string{ + "%d.%d.%d.%d:80", + "www.google.com:80", + "%d.%d.%d.%d:http", + "www.google.com:http", + "%03d.%03d.%03d.%03d:0080", + "[::ffff:%d.%d.%d.%d]:80", + "[::ffff:%02x%02x:%02x%02x]:80", + "[0:0:0:0:0000:ffff:%d.%d.%d.%d]:80", + "[0:0:0:0:000000:ffff:%d.%d.%d.%d]:80", + "[0:0:0:0:0:ffff::%d.%d.%d.%d]:80", +} + +func TestDialGoogleIPv4(t *testing.T) { + if testing.Short() { + // Don't use external network. + t.Logf("skipping external network test during -short") + return + } + + // Insert an actual IPv4 address for google.com + // into the table. + addrs, err := LookupIP("www.google.com") + if err != nil { + t.Fatalf("lookup www.google.com: %v", err) + } + var ip IP + for _, addr := range addrs { + if x := addr.To4(); x != nil { + ip = x + break + } + } + if ip == nil { + t.Fatalf("no IPv4 addresses for www.google.com") + } + + for i, s := range googleaddrsipv4 { + if strings.Contains(s, "%") { + googleaddrsipv4[i] = fmt.Sprintf(s, ip[0], ip[1], ip[2], ip[3]) + } + } + + for i := 0; i < len(googleaddrsipv4); i++ { + addr := googleaddrsipv4[i] + if addr == "" { + continue + } + t.Logf("-- %s --", addr) + doDial(t, "tcp", addr) + if addr[0] != '[' { + doDial(t, "tcp4", addr) + if supportsIPv6 { + // make sure syscall.SocketDisableIPv6 flag works. + syscall.SocketDisableIPv6 = true + doDial(t, "tcp", addr) + doDial(t, "tcp4", addr) + syscall.SocketDisableIPv6 = false + } + } + } +} + +var googleaddrsipv6 = []string{ + "[%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x]:80", + "ipv6.google.com:80", + "[%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x]:http", + "ipv6.google.com:http", +} + +func TestDialGoogleIPv6(t *testing.T) { + if testing.Short() { + // Don't use external network. + t.Logf("skipping external network test during -short") + return + } + // Only run tcp6 if the kernel will take it. + if !*ipv6 || !supportsIPv6 { + return + } + + // Insert an actual IPv6 address for ipv6.google.com + // into the table. + addrs, err := LookupIP("ipv6.google.com") + if err != nil { + t.Fatalf("lookup ipv6.google.com: %v", err) + } + var ip IP + for _, addr := range addrs { + if x := addr.To16(); x != nil { + ip = x + break + } + } + if ip == nil { + t.Fatalf("no IPv6 addresses for ipv6.google.com") + } + + for i, s := range googleaddrsipv6 { + if strings.Contains(s, "%") { + googleaddrsipv6[i] = fmt.Sprintf(s, ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], ip[8], ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15]) + } + } + + for i := 0; i < len(googleaddrsipv6); i++ { + addr := googleaddrsipv6[i] + if addr == "" { + continue + } + t.Logf("-- %s --", addr) + doDial(t, "tcp", addr) + doDial(t, "tcp6", addr) + } +} diff --git a/src/pkg/net/dict/Makefile b/src/pkg/net/dict/Makefile new file mode 100644 index 000000000..eaa9e6531 --- /dev/null +++ b/src/pkg/net/dict/Makefile @@ -0,0 +1,7 @@ +include ../../../Make.inc + +TARG=net/dict +GOFILES=\ + dict.go\ + +include ../../../Make.pkg diff --git a/src/pkg/net/dict/dict.go b/src/pkg/net/dict/dict.go new file mode 100644 index 000000000..b146ea212 --- /dev/null +++ b/src/pkg/net/dict/dict.go @@ -0,0 +1,211 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package dict implements the Dictionary Server Protocol +// as defined in RFC 2229. +package dict + +import ( + "net/textproto" + "os" + "strconv" + "strings" +) + +// A Client represents a client connection to a dictionary server. +type Client struct { + text *textproto.Conn +} + +// Dial returns a new client connected to a dictionary server at +// addr on the given network. +func Dial(network, addr string) (*Client, os.Error) { + text, err := textproto.Dial(network, addr) + if err != nil { + return nil, err + } + _, _, err = text.ReadCodeLine(220) + if err != nil { + text.Close() + return nil, err + } + return &Client{text: text}, nil +} + +// Close closes the connection to the dictionary server. +func (c *Client) Close() os.Error { + return c.text.Close() +} + +// A Dict represents a dictionary available on the server. +type Dict struct { + Name string // short name of dictionary + Desc string // long description +} + +// Dicts returns a list of the dictionaries available on the server. +func (c *Client) Dicts() ([]Dict, os.Error) { + id, err := c.text.Cmd("SHOW DB") + if err != nil { + return nil, err + } + + c.text.StartResponse(id) + defer c.text.EndResponse(id) + + _, _, err = c.text.ReadCodeLine(110) + if err != nil { + return nil, err + } + lines, err := c.text.ReadDotLines() + if err != nil { + return nil, err + } + _, _, err = c.text.ReadCodeLine(250) + + dicts := make([]Dict, len(lines)) + for i := range dicts { + d := &dicts[i] + a, _ := fields(lines[i]) + if len(a) < 2 { + return nil, textproto.ProtocolError("invalid dictionary: " + lines[i]) + } + d.Name = a[0] + d.Desc = a[1] + } + return dicts, err +} + +// A Defn represents a definition. +type Defn struct { + Dict Dict // Dict where definition was found + Word string // Word being defined + Text []byte // Definition text, typically multiple lines +} + +// Define requests the definition of the given word. +// The argument dict names the dictionary to use, +// the Name field of a Dict returned by Dicts. +// +// The special dictionary name "*" means to look in all the +// server's dictionaries. +// The special dictionary name "!" means to look in all the +// server's dictionaries in turn, stopping after finding the word +// in one of them. +func (c *Client) Define(dict, word string) ([]*Defn, os.Error) { + id, err := c.text.Cmd("DEFINE %s %q", dict, word) + if err != nil { + return nil, err + } + + c.text.StartResponse(id) + defer c.text.EndResponse(id) + + _, line, err := c.text.ReadCodeLine(150) + if err != nil { + return nil, err + } + a, _ := fields(line) + if len(a) < 1 { + return nil, textproto.ProtocolError("malformed response: " + line) + } + n, err := strconv.Atoi(a[0]) + if err != nil { + return nil, textproto.ProtocolError("invalid definition count: " + a[0]) + } + def := make([]*Defn, n) + for i := 0; i < n; i++ { + _, line, err = c.text.ReadCodeLine(151) + if err != nil { + return nil, err + } + a, _ := fields(line) + if len(a) < 3 { + // skip it, to keep protocol in sync + i-- + n-- + def = def[0:n] + continue + } + d := &Defn{Word: a[0], Dict: Dict{a[1], a[2]}} + d.Text, err = c.text.ReadDotBytes() + if err != nil { + return nil, err + } + def[i] = d + } + _, _, err = c.text.ReadCodeLine(250) + return def, err +} + +// Fields returns the fields in s. +// Fields are space separated unquoted words +// or quoted with single or double quote. +func fields(s string) ([]string, os.Error) { + var v []string + i := 0 + for { + for i < len(s) && (s[i] == ' ' || s[i] == '\t') { + i++ + } + if i >= len(s) { + break + } + if s[i] == '"' || s[i] == '\'' { + q := s[i] + // quoted string + var j int + for j = i + 1; ; j++ { + if j >= len(s) { + return nil, textproto.ProtocolError("malformed quoted string") + } + if s[j] == '\\' { + j++ + continue + } + if s[j] == q { + j++ + break + } + } + v = append(v, unquote(s[i+1:j-1])) + i = j + } else { + // atom + var j int + for j = i; j < len(s); j++ { + if s[j] == ' ' || s[j] == '\t' || s[j] == '\\' || s[j] == '"' || s[j] == '\'' { + break + } + } + v = append(v, s[i:j]) + i = j + } + if i < len(s) { + c := s[i] + if c != ' ' && c != '\t' { + return nil, textproto.ProtocolError("quotes not on word boundaries") + } + } + } + return v, nil +} + +func unquote(s string) string { + if strings.Index(s, "\\") < 0 { + return s + } + b := []byte(s) + w := 0 + for r := 0; r < len(b); r++ { + c := b[r] + if c == '\\' { + r++ + c = b[r] + } + b[w] = c + w++ + } + return string(b[0:w]) +} diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go new file mode 100644 index 000000000..93c04f6b5 --- /dev/null +++ b/src/pkg/net/dnsclient.go @@ -0,0 +1,247 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bytes" + "fmt" + "os" + "rand" + "sort" +) + +// DNSError represents a DNS lookup error. +type DNSError struct { + Error string // description of the error + Name string // name looked for + Server string // server used + IsTimeout bool +} + +func (e *DNSError) String() string { + if e == nil { + return "<nil>" + } + s := "lookup " + e.Name + if e.Server != "" { + s += " on " + e.Server + } + s += ": " + e.Error + return s +} + +func (e *DNSError) Timeout() bool { return e.IsTimeout } +func (e *DNSError) Temporary() bool { return e.IsTimeout } + +const noSuchHost = "no such host" + +// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP +// address addr suitable for rDNS (PTR) record lookup or an error if it fails +// to parse the IP address. +func reverseaddr(addr string) (arpa string, err os.Error) { + ip := ParseIP(addr) + if ip == nil { + return "", &DNSError{Error: "unrecognized address", Name: addr} + } + if ip.To4() != nil { + return fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa.", ip[15], ip[14], ip[13], ip[12]), nil + } + // Must be IPv6 + var buf bytes.Buffer + // Add it, in reverse, to the buffer + for i := len(ip) - 1; i >= 0; i-- { + s := fmt.Sprintf("%02x", ip[i]) + buf.WriteByte(s[1]) + buf.WriteByte('.') + buf.WriteByte(s[0]) + buf.WriteByte('.') + } + // Append "ip6.arpa." and return (buf already has the final .) + return buf.String() + "ip6.arpa.", nil +} + +// Find answer for name in dns message. +// On return, if err == nil, addrs != nil. +func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err os.Error) { + addrs = make([]dnsRR, 0, len(dns.answer)) + + if dns.rcode == dnsRcodeNameError && dns.recursion_available { + return "", nil, &DNSError{Error: noSuchHost, Name: name} + } + if dns.rcode != dnsRcodeSuccess { + // None of the error codes make sense + // for the query we sent. If we didn't get + // a name error and we didn't get success, + // the server is behaving incorrectly. + return "", nil, &DNSError{Error: "server misbehaving", Name: name, Server: server} + } + + // Look for the name. + // Presotto says it's okay to assume that servers listed in + // /etc/resolv.conf are recursive resolvers. + // We asked for recursion, so it should have included + // all the answers we need in this one packet. +Cname: + for cnameloop := 0; cnameloop < 10; cnameloop++ { + addrs = addrs[0:0] + for _, rr := range dns.answer { + if _, justHeader := rr.(*dnsRR_Header); justHeader { + // Corrupt record: we only have a + // header. That header might say it's + // of type qtype, but we don't + // actually have it. Skip. + continue + } + h := rr.Header() + if h.Class == dnsClassINET && h.Name == name { + switch h.Rrtype { + case qtype: + addrs = append(addrs, rr) + case dnsTypeCNAME: + // redirect to cname + name = rr.(*dnsRR_CNAME).Cname + continue Cname + } + } + } + if len(addrs) == 0 { + return "", nil, &DNSError{Error: noSuchHost, Name: name, Server: server} + } + return name, addrs, nil + } + + return "", nil, &DNSError{Error: "too many redirects", Name: name, Server: server} +} + +func isDomainName(s string) bool { + // See RFC 1035, RFC 3696. + if len(s) == 0 { + return false + } + if len(s) > 255 { + return false + } + if s[len(s)-1] != '.' { // simplify checking loop: make name end in dot + s += "." + } + + last := byte('.') + ok := false // ok once we've seen a letter + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + ok = true + partlen++ + case '0' <= c && c <= '9': + // fine + partlen++ + case c == '-': + // byte before dash cannot be dot + if last == '.' { + return false + } + partlen++ + case c == '.': + // byte before dot cannot be dot, dash + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + + return ok +} + +// An SRV represents a single DNS SRV record. +type SRV struct { + Target string + Port uint16 + Priority uint16 + Weight uint16 +} + +// byPriorityWeight sorts SRV records by ascending priority and weight. +type byPriorityWeight []*SRV + +func (s byPriorityWeight) Len() int { return len(s) } + +func (s byPriorityWeight) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func (s byPriorityWeight) Less(i, j int) bool { + return s[i].Priority < s[j].Priority || + (s[i].Priority == s[j].Priority && s[i].Weight < s[j].Weight) +} + +// shuffleByWeight shuffles SRV records by weight using the algorithm +// described in RFC 2782. +func (addrs byPriorityWeight) shuffleByWeight() { + sum := 0 + for _, addr := range addrs { + sum += int(addr.Weight) + } + for sum > 0 && len(addrs) > 1 { + s := 0 + n := rand.Intn(sum + 1) + for i := range addrs { + s += int(addrs[i].Weight) + if s >= n { + if i > 0 { + t := addrs[i] + copy(addrs[1:i+1], addrs[0:i]) + addrs[0] = t + } + break + } + } + sum -= int(addrs[0].Weight) + addrs = addrs[1:] + } +} + +// sort reorders SRV records as specified in RFC 2782. +func (addrs byPriorityWeight) sort() { + sort.Sort(addrs) + i := 0 + for j := 1; j < len(addrs); j++ { + if addrs[i].Priority != addrs[j].Priority { + addrs[i:j].shuffleByWeight() + i = j + } + } + addrs[i:].shuffleByWeight() +} + +// An MX represents a single DNS MX record. +type MX struct { + Host string + Pref uint16 +} + +// byPref implements sort.Interface to sort MX records by preference +type byPref []*MX + +func (s byPref) Len() int { return len(s) } + +func (s byPref) Less(i, j int) bool { return s[i].Pref < s[j].Pref } + +func (s byPref) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// sort reorders MX records as specified in RFC 5321. +func (s byPref) sort() { + for i := range s { + j := rand.Intn(i + 1) + s[i], s[j] = s[j], s[i] + } + sort.Sort(s) +} diff --git a/src/pkg/net/dnsclient_unix.go b/src/pkg/net/dnsclient_unix.go new file mode 100644 index 000000000..f407b1778 --- /dev/null +++ b/src/pkg/net/dnsclient_unix.go @@ -0,0 +1,261 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DNS client: see RFC 1035. +// Has to be linked into package net for Dial. + +// TODO(rsc): +// Check periodically whether /etc/resolv.conf has changed. +// Could potentially handle many outstanding lookups faster. +// Could have a small cache. +// Random UDP source port (net.Dial should do that for us). +// Random request IDs. + +package net + +import ( + "os" + "rand" + "sync" + "time" +) + +// Send a request on the connection and hope for a reply. +// Up to cfg.attempts attempts. +func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, os.Error) { + if len(name) >= 256 { + return nil, &DNSError{Error: "name too long", Name: name} + } + out := new(dnsMsg) + out.id = uint16(rand.Int()) ^ uint16(time.Nanoseconds()) + out.question = []dnsQuestion{ + {name, qtype, dnsClassINET}, + } + out.recursion_desired = true + msg, ok := out.Pack() + if !ok { + return nil, &DNSError{Error: "internal error - cannot pack message", Name: name} + } + + for attempt := 0; attempt < cfg.attempts; attempt++ { + n, err := c.Write(msg) + if err != nil { + return nil, err + } + + c.SetReadTimeout(int64(cfg.timeout) * 1e9) // nanoseconds + + buf := make([]byte, 2000) // More than enough. + n, err = c.Read(buf) + if err != nil { + if e, ok := err.(Error); ok && e.Timeout() { + continue + } + return nil, err + } + buf = buf[0:n] + in := new(dnsMsg) + if !in.Unpack(buf) || in.id != out.id { + continue + } + return in, nil + } + var server string + if a := c.RemoteAddr(); a != nil { + server = a.String() + } + return nil, &DNSError{Error: "no answer from server", Name: name, Server: server, IsTimeout: true} +} + +// Do a lookup for a single name, which must be rooted +// (otherwise answer will not find the answers). +func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs []dnsRR, err os.Error) { + if len(cfg.servers) == 0 { + return "", nil, &DNSError{Error: "no DNS servers", Name: name} + } + for i := 0; i < len(cfg.servers); i++ { + // Calling Dial here is scary -- we have to be sure + // not to dial a name that will require a DNS lookup, + // or Dial will call back here to translate it. + // The DNS config parser has already checked that + // all the cfg.servers[i] are IP addresses, which + // Dial will use without a DNS lookup. + server := cfg.servers[i] + ":53" + c, cerr := Dial("udp", server) + if cerr != nil { + err = cerr + continue + } + msg, merr := exchange(cfg, c, name, qtype) + c.Close() + if merr != nil { + err = merr + continue + } + cname, addrs, err = answer(name, server, msg, qtype) + if err == nil || err.(*DNSError).Error == noSuchHost { + break + } + } + return +} + +func convertRR_A(records []dnsRR) []IP { + addrs := make([]IP, len(records)) + for i, rr := range records { + a := rr.(*dnsRR_A).A + addrs[i] = IPv4(byte(a>>24), byte(a>>16), byte(a>>8), byte(a)) + } + return addrs +} + +func convertRR_AAAA(records []dnsRR) []IP { + addrs := make([]IP, len(records)) + for i, rr := range records { + a := make(IP, 16) + copy(a, rr.(*dnsRR_AAAA).AAAA[:]) + addrs[i] = a + } + return addrs +} + +var cfg *dnsConfig +var dnserr os.Error + +func loadConfig() { cfg, dnserr = dnsReadConfig() } + +var onceLoadConfig sync.Once + +func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Error) { + if !isDomainName(name) { + return name, nil, &DNSError{Error: "invalid domain name", Name: name} + } + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr + return + } + // If name is rooted (trailing dot) or has enough dots, + // try it by itself first. + rooted := len(name) > 0 && name[len(name)-1] == '.' + if rooted || count(name, '.') >= cfg.ndots { + rname := name + if !rooted { + rname += "." + } + // Can try as ordinary name. + cname, addrs, err = tryOneName(cfg, rname, qtype) + if err == nil { + return + } + } + if rooted { + return + } + + // Otherwise, try suffixes. + for i := 0; i < len(cfg.search); i++ { + rname := name + "." + cfg.search[i] + if rname[len(rname)-1] != '.' { + rname += "." + } + cname, addrs, err = tryOneName(cfg, rname, qtype) + if err == nil { + return + } + } + + // Last ditch effort: try unsuffixed. + rname := name + if !rooted { + rname += "." + } + cname, addrs, err = tryOneName(cfg, rname, qtype) + if err == nil { + return + } + return +} + +// goLookupHost is the native Go implementation of LookupHost. +// Used only if cgoLookupHost refuses to handle the request +// (that is, only if cgoLookupHost is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. +func goLookupHost(name string) (addrs []string, err os.Error) { + // Use entries from /etc/hosts if they match. + addrs = lookupStaticHost(name) + if len(addrs) > 0 { + return + } + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr + return + } + ips, err := goLookupIP(name) + if err != nil { + return + } + addrs = make([]string, 0, len(ips)) + for _, ip := range ips { + addrs = append(addrs, ip.String()) + } + return +} + +// goLookupIP is the native Go implementation of LookupIP. +// Used only if cgoLookupIP refuses to handle the request +// (that is, only if cgoLookupIP is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. +func goLookupIP(name string) (addrs []IP, err os.Error) { + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr + return + } + var records []dnsRR + var cname string + cname, records, err = lookup(name, dnsTypeA) + if err != nil { + return + } + addrs = convertRR_A(records) + if cname != "" { + name = cname + } + _, records, err = lookup(name, dnsTypeAAAA) + if err != nil && len(addrs) > 0 { + // Ignore error because A lookup succeeded. + err = nil + } + if err != nil { + return + } + addrs = append(addrs, convertRR_AAAA(records)...) + return +} + +// goLookupCNAME is the native Go implementation of LookupCNAME. +// Used only if cgoLookupCNAME refuses to handle the request +// (that is, only if cgoLookupCNAME is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. +func goLookupCNAME(name string) (cname string, err os.Error) { + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr + return + } + _, rr, err := lookup(name, dnsTypeCNAME) + if err != nil { + return + } + cname = rr[0].(*dnsRR_CNAME).Cname + return +} diff --git a/src/pkg/net/dnsconfig.go b/src/pkg/net/dnsconfig.go new file mode 100644 index 000000000..54e334342 --- /dev/null +++ b/src/pkg/net/dnsconfig.go @@ -0,0 +1,119 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Read system DNS config from /etc/resolv.conf + +package net + +import "os" + +type dnsConfig struct { + servers []string // servers to use + search []string // suffixes to append to local name + ndots int // number of dots in name to trigger absolute lookup + timeout int // seconds before giving up on packet + attempts int // lost packets before giving up on server + rotate bool // round robin among servers +} + +var dnsconfigError os.Error + +type DNSConfigError struct { + Error os.Error +} + +func (e *DNSConfigError) String() string { + return "error reading DNS config: " + e.Error.String() +} + +func (e *DNSConfigError) Timeout() bool { return false } +func (e *DNSConfigError) Temporary() bool { return false } + +// See resolv.conf(5) on a Linux machine. +// TODO(rsc): Supposed to call uname() and chop the beginning +// of the host name to get the default search domain. +// We assume it's in resolv.conf anyway. +func dnsReadConfig() (*dnsConfig, os.Error) { + file, err := open("/etc/resolv.conf") + if err != nil { + return nil, &DNSConfigError{err} + } + conf := new(dnsConfig) + conf.servers = make([]string, 3)[0:0] // small, but the standard limit + conf.search = make([]string, 0) + conf.ndots = 1 + conf.timeout = 5 + conf.attempts = 2 + conf.rotate = false + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + f := getFields(line) + if len(f) < 1 { + continue + } + switch f[0] { + case "nameserver": // add one name server + a := conf.servers + n := len(a) + if len(f) > 1 && n < cap(a) { + // One more check: make sure server name is + // just an IP address. Otherwise we need DNS + // to look it up. + name := f[1] + switch len(ParseIP(name)) { + case 16: + name = "[" + name + "]" + fallthrough + case 4: + a = a[0 : n+1] + a[n] = name + conf.servers = a + } + } + + case "domain": // set search path to just this domain + if len(f) > 1 { + conf.search = make([]string, 1) + conf.search[0] = f[1] + } else { + conf.search = make([]string, 0) + } + + case "search": // set search path to given servers + conf.search = make([]string, len(f)-1) + for i := 0; i < len(conf.search); i++ { + conf.search[i] = f[i+1] + } + + case "options": // magic options + for i := 1; i < len(f); i++ { + s := f[i] + switch { + case len(s) >= 6 && s[0:6] == "ndots:": + n, _, _ := dtoi(s, 6) + if n < 1 { + n = 1 + } + conf.ndots = n + case len(s) >= 8 && s[0:8] == "timeout:": + n, _, _ := dtoi(s, 8) + if n < 1 { + n = 1 + } + conf.timeout = n + case len(s) >= 8 && s[0:9] == "attempts:": + n, _, _ := dtoi(s, 9) + if n < 1 { + n = 1 + } + conf.attempts = n + case s == "rotate": + conf.rotate = true + } + } + } + } + file.close() + + return conf, nil +} diff --git a/src/pkg/net/dnsmsg.go b/src/pkg/net/dnsmsg.go new file mode 100644 index 000000000..7595aa2ab --- /dev/null +++ b/src/pkg/net/dnsmsg.go @@ -0,0 +1,779 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DNS packet assembly. See RFC 1035. +// +// This is intended to support name resolution during net.Dial. +// It doesn't have to be blazing fast. +// +// Rather than write the usual handful of routines to pack and +// unpack every message that can appear on the wire, we use +// reflection to write a generic pack/unpack for structs and then +// use it. Thus, if in the future we need to define new message +// structs, no new pack/unpack/printing code needs to be written. +// +// The first half of this file defines the DNS message formats. +// The second half implements the conversion to and from wire format. +// A few of the structure elements have string tags to aid the +// generic pack/unpack routines. +// +// TODO(rsc): There are enough names defined in this file that they're all +// prefixed with dns. Perhaps put this in its own package later. + +package net + +import ( + "fmt" + "os" + "reflect" +) + +// Packet formats + +// Wire constants. +const ( + // valid dnsRR_Header.Rrtype and dnsQuestion.qtype + dnsTypeA = 1 + dnsTypeNS = 2 + dnsTypeMD = 3 + dnsTypeMF = 4 + dnsTypeCNAME = 5 + dnsTypeSOA = 6 + dnsTypeMB = 7 + dnsTypeMG = 8 + dnsTypeMR = 9 + dnsTypeNULL = 10 + dnsTypeWKS = 11 + dnsTypePTR = 12 + dnsTypeHINFO = 13 + dnsTypeMINFO = 14 + dnsTypeMX = 15 + dnsTypeTXT = 16 + dnsTypeAAAA = 28 + dnsTypeSRV = 33 + + // valid dnsQuestion.qtype only + dnsTypeAXFR = 252 + dnsTypeMAILB = 253 + dnsTypeMAILA = 254 + dnsTypeALL = 255 + + // valid dnsQuestion.qclass + dnsClassINET = 1 + dnsClassCSNET = 2 + dnsClassCHAOS = 3 + dnsClassHESIOD = 4 + dnsClassANY = 255 + + // dnsMsg.rcode + dnsRcodeSuccess = 0 + dnsRcodeFormatError = 1 + dnsRcodeServerFailure = 2 + dnsRcodeNameError = 3 + dnsRcodeNotImplemented = 4 + dnsRcodeRefused = 5 +) + +// The wire format for the DNS packet header. +type dnsHeader struct { + Id uint16 + Bits uint16 + Qdcount, Ancount, Nscount, Arcount uint16 +} + +const ( + // dnsHeader.Bits + _QR = 1 << 15 // query/response (response=1) + _AA = 1 << 10 // authoritative + _TC = 1 << 9 // truncated + _RD = 1 << 8 // recursion desired + _RA = 1 << 7 // recursion available +) + +// DNS queries. +type dnsQuestion struct { + Name string `net:"domain-name"` // `net:"domain-name"` specifies encoding; see packers below + Qtype uint16 + Qclass uint16 +} + +// DNS responses (resource records). +// There are many types of messages, +// but they all share the same header. +type dnsRR_Header struct { + Name string `net:"domain-name"` + Rrtype uint16 + Class uint16 + Ttl uint32 + Rdlength uint16 // length of data after header +} + +func (h *dnsRR_Header) Header() *dnsRR_Header { + return h +} + +type dnsRR interface { + Header() *dnsRR_Header +} + +// Specific DNS RR formats for each query type. + +type dnsRR_CNAME struct { + Hdr dnsRR_Header + Cname string `net:"domain-name"` +} + +func (rr *dnsRR_CNAME) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_HINFO struct { + Hdr dnsRR_Header + Cpu string + Os string +} + +func (rr *dnsRR_HINFO) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MB struct { + Hdr dnsRR_Header + Mb string `net:"domain-name"` +} + +func (rr *dnsRR_MB) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MG struct { + Hdr dnsRR_Header + Mg string `net:"domain-name"` +} + +func (rr *dnsRR_MG) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MINFO struct { + Hdr dnsRR_Header + Rmail string `net:"domain-name"` + Email string `net:"domain-name"` +} + +func (rr *dnsRR_MINFO) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MR struct { + Hdr dnsRR_Header + Mr string `net:"domain-name"` +} + +func (rr *dnsRR_MR) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MX struct { + Hdr dnsRR_Header + Pref uint16 + Mx string `net:"domain-name"` +} + +func (rr *dnsRR_MX) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_NS struct { + Hdr dnsRR_Header + Ns string `net:"domain-name"` +} + +func (rr *dnsRR_NS) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_PTR struct { + Hdr dnsRR_Header + Ptr string `net:"domain-name"` +} + +func (rr *dnsRR_PTR) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_SOA struct { + Hdr dnsRR_Header + Ns string `net:"domain-name"` + Mbox string `net:"domain-name"` + Serial uint32 + Refresh uint32 + Retry uint32 + Expire uint32 + Minttl uint32 +} + +func (rr *dnsRR_SOA) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_TXT struct { + Hdr dnsRR_Header + Txt string // not domain name +} + +func (rr *dnsRR_TXT) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_SRV struct { + Hdr dnsRR_Header + Priority uint16 + Weight uint16 + Port uint16 + Target string `net:"domain-name"` +} + +func (rr *dnsRR_SRV) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_A struct { + Hdr dnsRR_Header + A uint32 `net:"ipv4"` +} + +func (rr *dnsRR_A) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_AAAA struct { + Hdr dnsRR_Header + AAAA [16]byte `net:"ipv6"` +} + +func (rr *dnsRR_AAAA) Header() *dnsRR_Header { + return &rr.Hdr +} + +// Packing and unpacking. +// +// All the packers and unpackers take a (msg []byte, off int) +// and return (off1 int, ok bool). If they return ok==false, they +// also return off1==len(msg), so that the next unpacker will +// also fail. This lets us avoid checks of ok until the end of a +// packing sequence. + +// Map of constructors for each RR wire type. +var rr_mk = map[int]func() dnsRR{ + dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) }, + dnsTypeHINFO: func() dnsRR { return new(dnsRR_HINFO) }, + dnsTypeMB: func() dnsRR { return new(dnsRR_MB) }, + dnsTypeMG: func() dnsRR { return new(dnsRR_MG) }, + dnsTypeMINFO: func() dnsRR { return new(dnsRR_MINFO) }, + dnsTypeMR: func() dnsRR { return new(dnsRR_MR) }, + dnsTypeMX: func() dnsRR { return new(dnsRR_MX) }, + dnsTypeNS: func() dnsRR { return new(dnsRR_NS) }, + dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) }, + dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) }, + dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) }, + dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) }, + dnsTypeA: func() dnsRR { return new(dnsRR_A) }, + dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) }, +} + +// Pack a domain name s into msg[off:]. +// Domain names are a sequence of counted strings +// split at the dots. They end with a zero-length string. +func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) { + // Add trailing dot to canonicalize name. + if n := len(s); n == 0 || s[n-1] != '.' { + s += "." + } + + // Each dot ends a segment of the name. + // We trade each dot byte for a length byte. + // There is also a trailing zero. + // Check that we have all the space we need. + tot := len(s) + 1 + if off+tot > len(msg) { + return len(msg), false + } + + // Emit sequence of counted strings, chopping at dots. + begin := 0 + for i := 0; i < len(s); i++ { + if s[i] == '.' { + if i-begin >= 1<<6 { // top two bits of length must be clear + return len(msg), false + } + msg[off] = byte(i - begin) + off++ + for j := begin; j < i; j++ { + msg[off] = s[j] + off++ + } + begin = i + 1 + } + } + msg[off] = 0 + off++ + return off, true +} + +// Unpack a domain name. +// In addition to the simple sequences of counted strings above, +// domain names are allowed to refer to strings elsewhere in the +// packet, to avoid repeating common suffixes when returning +// many entries in a single domain. The pointers are marked +// by a length byte with the top two bits set. Ignoring those +// two bits, that byte and the next give a 14 bit offset from msg[0] +// where we should pick up the trail. +// Note that if we jump elsewhere in the packet, +// we return off1 == the offset after the first pointer we found, +// which is where the next record will start. +// In theory, the pointers are only allowed to jump backward. +// We let them jump anywhere and stop jumping after a while. +func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) { + s = "" + ptr := 0 // number of pointers followed +Loop: + for { + if off >= len(msg) { + return "", len(msg), false + } + c := int(msg[off]) + off++ + switch c & 0xC0 { + case 0x00: + if c == 0x00 { + // end of name + break Loop + } + // literal string + if off+c > len(msg) { + return "", len(msg), false + } + s += string(msg[off:off+c]) + "." + off += c + case 0xC0: + // pointer to somewhere else in msg. + // remember location after first ptr, + // since that's how many bytes we consumed. + // also, don't follow too many pointers -- + // maybe there's a loop. + if off >= len(msg) { + return "", len(msg), false + } + c1 := msg[off] + off++ + if ptr == 0 { + off1 = off + } + if ptr++; ptr > 10 { + return "", len(msg), false + } + off = (c^0xC0)<<8 | int(c1) + default: + // 0x80 and 0x40 are reserved + return "", len(msg), false + } + } + if ptr == 0 { + off1 = off + } + return s, off1, true +} + +// TODO(rsc): Move into generic library? +// Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string, +// [n]byte, and other (often anonymous) structs. +func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) { + for i := 0; i < val.NumField(); i++ { + f := val.Type().Field(i) + switch fv := val.Field(i); fv.Kind() { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + case reflect.Struct: + off, ok = packStructValue(fv, msg, off) + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + i := fv.Uint() + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := fv.Uint() + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 + case reflect.Array: + if fv.Type().Elem().Kind() != reflect.Uint8 { + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + } + n := fv.Len() + if off+n > len(msg) { + return len(msg), false + } + reflect.Copy(reflect.ValueOf(msg[off:off+n]), fv) + off += n + case reflect.String: + // There are multiple string encodings. + // The tag distinguishes ordinary strings from domain names. + s := fv.String() + switch f.Tag { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) + return len(msg), false + case `net:"domain-name"`: + off, ok = packDomainName(s, msg, off) + if !ok { + return len(msg), false + } + case "": + // Counted string: 1 byte length. + if len(s) > 255 || off+1+len(s) > len(msg) { + return len(msg), false + } + msg[off] = byte(len(s)) + off++ + off += copy(msg[off:], s) + } + } + } + return off, true +} + +func structValue(any interface{}) reflect.Value { + return reflect.ValueOf(any).Elem() +} + +func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { + off, ok = packStructValue(structValue(any), msg, off) + return off, ok +} + +// TODO(rsc): Move into generic library? +// Unpack a reflect.StructValue from msg. +// Same restrictions as packStructValue. +func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) { + for i := 0; i < val.NumField(); i++ { + f := val.Type().Field(i) + switch fv := val.Field(i); fv.Kind() { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + case reflect.Struct: + off, ok = unpackStructValue(fv, msg, off) + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + i := uint16(msg[off])<<8 | uint16(msg[off+1]) + fv.SetUint(uint64(i)) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) + fv.SetUint(uint64(i)) + off += 4 + case reflect.Array: + if fv.Type().Elem().Kind() != reflect.Uint8 { + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + } + n := fv.Len() + if off+n > len(msg) { + return len(msg), false + } + reflect.Copy(fv, reflect.ValueOf(msg[off:off+n])) + off += n + case reflect.String: + var s string + switch f.Tag { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) + return len(msg), false + case `net:"domain-name"`: + s, off, ok = unpackDomainName(msg, off) + if !ok { + return len(msg), false + } + case "": + if off >= len(msg) || off+1+int(msg[off]) > len(msg) { + return len(msg), false + } + n := int(msg[off]) + off++ + b := make([]byte, n) + for i := 0; i < n; i++ { + b[i] = msg[off+i] + } + off += n + s = string(b) + } + fv.SetString(s) + } + } + return off, true +} + +func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { + off, ok = unpackStructValue(structValue(any), msg, off) + return off, ok +} + +// Generic struct printer. +// Doesn't care about the string tag `net:"domain-name"`, +// but does look for an `net:"ipv4"` tag on uint32 variables +// and the `net:"ipv6"` tag on array variables, +// printing them as IP addresses. +func printStructValue(val reflect.Value) string { + s := "{" + for i := 0; i < val.NumField(); i++ { + if i > 0 { + s += ", " + } + f := val.Type().Field(i) + if !f.Anonymous { + s += f.Name + "=" + } + fval := val.Field(i) + if fv := fval; fv.Kind() == reflect.Struct { + s += printStructValue(fv) + } else if fv := fval; (fv.Kind() == reflect.Uint || fv.Kind() == reflect.Uint8 || fv.Kind() == reflect.Uint16 || fv.Kind() == reflect.Uint32 || fv.Kind() == reflect.Uint64 || fv.Kind() == reflect.Uintptr) && f.Tag == `net:"ipv4"` { + i := fv.Uint() + s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() + } else if fv := fval; fv.Kind() == reflect.Array && f.Tag == `net:"ipv6"` { + i := fv.Interface().([]byte) + s += IP(i).String() + } else { + s += fmt.Sprint(fval.Interface()) + } + } + s += "}" + return s +} + +func printStruct(any interface{}) string { return printStructValue(structValue(any)) } + +// Resource record packer. +func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) { + var off1 int + // pack twice, once to find end of header + // and again to find end of packet. + // a bit inefficient but this doesn't need to be fast. + // off1 is end of header + // off2 is end of rr + off1, ok = packStruct(rr.Header(), msg, off) + off2, ok = packStruct(rr, msg, off) + if !ok { + return len(msg), false + } + // pack a third time; redo header with correct data length + rr.Header().Rdlength = uint16(off2 - off1) + packStruct(rr.Header(), msg, off) + return off2, true +} + +// Resource record unpacker. +func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) { + // unpack just the header, to find the rr type and length + var h dnsRR_Header + off0 := off + if off, ok = unpackStruct(&h, msg, off); !ok { + return nil, len(msg), false + } + end := off + int(h.Rdlength) + + // make an rr of that type and re-unpack. + // again inefficient but doesn't need to be fast. + mk, known := rr_mk[int(h.Rrtype)] + if !known { + return &h, end, true + } + rr = mk() + off, ok = unpackStruct(rr, msg, off0) + if off != end { + return &h, end, true + } + return rr, off, ok +} + +// Usable representation of a DNS packet. + +// A manually-unpacked version of (id, bits). +// This is in its own struct for easy printing. +type dnsMsgHdr struct { + id uint16 + response bool + opcode int + authoritative bool + truncated bool + recursion_desired bool + recursion_available bool + rcode int +} + +type dnsMsg struct { + dnsMsgHdr + question []dnsQuestion + answer []dnsRR + ns []dnsRR + extra []dnsRR +} + +func (dns *dnsMsg) Pack() (msg []byte, ok bool) { + var dh dnsHeader + + // Convert convenient dnsMsg into wire-like dnsHeader. + dh.Id = dns.id + dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode) + if dns.recursion_available { + dh.Bits |= _RA + } + if dns.recursion_desired { + dh.Bits |= _RD + } + if dns.truncated { + dh.Bits |= _TC + } + if dns.authoritative { + dh.Bits |= _AA + } + if dns.response { + dh.Bits |= _QR + } + + // Prepare variable sized arrays. + question := dns.question + answer := dns.answer + ns := dns.ns + extra := dns.extra + + dh.Qdcount = uint16(len(question)) + dh.Ancount = uint16(len(answer)) + dh.Nscount = uint16(len(ns)) + dh.Arcount = uint16(len(extra)) + + // Could work harder to calculate message size, + // but this is far more than we need and not + // big enough to hurt the allocator. + msg = make([]byte, 2000) + + // Pack it in: header and then the pieces. + off := 0 + off, ok = packStruct(&dh, msg, off) + for i := 0; i < len(question); i++ { + off, ok = packStruct(&question[i], msg, off) + } + for i := 0; i < len(answer); i++ { + off, ok = packRR(answer[i], msg, off) + } + for i := 0; i < len(ns); i++ { + off, ok = packRR(ns[i], msg, off) + } + for i := 0; i < len(extra); i++ { + off, ok = packRR(extra[i], msg, off) + } + if !ok { + return nil, false + } + return msg[0:off], true +} + +func (dns *dnsMsg) Unpack(msg []byte) bool { + // Header. + var dh dnsHeader + off := 0 + var ok bool + if off, ok = unpackStruct(&dh, msg, off); !ok { + return false + } + dns.id = dh.Id + dns.response = (dh.Bits & _QR) != 0 + dns.opcode = int(dh.Bits>>11) & 0xF + dns.authoritative = (dh.Bits & _AA) != 0 + dns.truncated = (dh.Bits & _TC) != 0 + dns.recursion_desired = (dh.Bits & _RD) != 0 + dns.recursion_available = (dh.Bits & _RA) != 0 + dns.rcode = int(dh.Bits & 0xF) + + // Arrays. + dns.question = make([]dnsQuestion, dh.Qdcount) + dns.answer = make([]dnsRR, 0, dh.Ancount) + dns.ns = make([]dnsRR, 0, dh.Nscount) + dns.extra = make([]dnsRR, 0, dh.Arcount) + + var rec dnsRR + + for i := 0; i < len(dns.question); i++ { + off, ok = unpackStruct(&dns.question[i], msg, off) + } + for i := 0; i < int(dh.Ancount); i++ { + rec, off, ok = unpackRR(msg, off) + if !ok { + return false + } + dns.answer = append(dns.answer, rec) + } + for i := 0; i < int(dh.Nscount); i++ { + rec, off, ok = unpackRR(msg, off) + if !ok { + return false + } + dns.ns = append(dns.ns, rec) + } + for i := 0; i < int(dh.Arcount); i++ { + rec, off, ok = unpackRR(msg, off) + if !ok { + return false + } + dns.extra = append(dns.extra, rec) + } + // if off != len(msg) { + // println("extra bytes in dns packet", off, "<", len(msg)); + // } + return true +} + +func (dns *dnsMsg) String() string { + s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n" + if len(dns.question) > 0 { + s += "-- Questions\n" + for i := 0; i < len(dns.question); i++ { + s += printStruct(&dns.question[i]) + "\n" + } + } + if len(dns.answer) > 0 { + s += "-- Answers\n" + for i := 0; i < len(dns.answer); i++ { + s += printStruct(dns.answer[i]) + "\n" + } + } + if len(dns.ns) > 0 { + s += "-- Name servers\n" + for i := 0; i < len(dns.ns); i++ { + s += printStruct(dns.ns[i]) + "\n" + } + } + if len(dns.extra) > 0 { + s += "-- Extra\n" + for i := 0; i < len(dns.extra); i++ { + s += printStruct(dns.extra[i]) + "\n" + } + } + return s +} diff --git a/src/pkg/net/dnsmsg_test.go b/src/pkg/net/dnsmsg_test.go new file mode 100644 index 000000000..06152a01a --- /dev/null +++ b/src/pkg/net/dnsmsg_test.go @@ -0,0 +1,100 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "encoding/hex" + "testing" +) + +func TestDNSParseSRVReply(t *testing.T) { + data, err := hex.DecodeString(dnsSRVReply) + if err != nil { + t.Fatal(err) + } + msg := new(dnsMsg) + ok := msg.Unpack(data) + if !ok { + t.Fatalf("unpacking packet failed") + } + if g, e := len(msg.answer), 5; g != e { + t.Errorf("len(msg.answer) = %d; want %d", g, e) + } + for idx, rr := range msg.answer { + if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e { + t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e) + } + if _, ok := rr.(*dnsRR_SRV); !ok { + t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr) + } + } + _, addrs, err := answer("_xmpp-server._tcp.google.com.", "foo:53", msg, uint16(dnsTypeSRV)) + if err != nil { + t.Fatalf("answer: %v", err) + } + if g, e := len(addrs), 5; g != e { + t.Errorf("len(addrs) = %d; want %d", g, e) + t.Logf("addrs = %#v", addrs) + } +} + +func TestDNSParseCorruptSRVReply(t *testing.T) { + data, err := hex.DecodeString(dnsSRVCorruptReply) + if err != nil { + t.Fatal(err) + } + msg := new(dnsMsg) + ok := msg.Unpack(data) + if !ok { + t.Fatalf("unpacking packet failed") + } + if g, e := len(msg.answer), 5; g != e { + t.Errorf("len(msg.answer) = %d; want %d", g, e) + } + for idx, rr := range msg.answer { + if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e { + t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e) + } + if idx == 4 { + if _, ok := rr.(*dnsRR_Header); !ok { + t.Errorf("answer[%d] = %T; want *dnsRR_Header", idx, rr) + } + } else { + if _, ok := rr.(*dnsRR_SRV); !ok { + t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr) + } + } + } + _, addrs, err := answer("_xmpp-server._tcp.google.com.", "foo:53", msg, uint16(dnsTypeSRV)) + if err != nil { + t.Fatalf("answer: %v", err) + } + if g, e := len(addrs), 4; g != e { + t.Errorf("len(addrs) = %d; want %d", g, e) + t.Logf("addrs = %#v", addrs) + } +} + +// Valid DNS SRV reply +const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" + + "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" + + "73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" + + "000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" + + "00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" + + "6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" + + "72016c06676f6f676c6503636f6d00c00c002100010000012c00210014000014950c78" + + "6d70702d73657276657231016c06676f6f676c6503636f6d00" + +// Corrupt DNS SRV reply, with its final RR having a bogus length +// (perhaps it was truncated, or it's malicious) The mutation is the +// capital "FF" below, instead of the proper "21". +const dnsSRVCorruptReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" + + "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" + + "73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" + + "000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" + + "00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" + + "6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" + + "72016c06676f6f676c6503636f6d00c00c002100010000012c00FF0014000014950c78" + + "6d70702d73657276657231016c06676f6f676c6503636f6d00" diff --git a/src/pkg/net/dnsname_test.go b/src/pkg/net/dnsname_test.go new file mode 100644 index 000000000..70df693f7 --- /dev/null +++ b/src/pkg/net/dnsname_test.go @@ -0,0 +1,65 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "testing" +) + +type testCase struct { + name string + result bool +} + +var tests = []testCase{ + // RFC2181, section 11. + {"_xmpp-server._tcp.google.com", true}, + {"_xmpp-server._tcp.google.com", true}, + {"foo.com", true}, + {"1foo.com", true}, + {"26.0.0.73.com", true}, + {"fo-o.com", true}, + {"fo1o.com", true}, + {"foo1.com", true}, + {"a.b..com", false}, +} + +func getTestCases(ch chan<- testCase) { + defer close(ch) + var char59 = "" + var char63 = "" + var char64 = "" + for i := 0; i < 59; i++ { + char59 += "a" + } + char63 = char59 + "aaaa" + char64 = char63 + "a" + + for _, tc := range tests { + ch <- tc + } + + ch <- testCase{char63 + ".com", true} + ch <- testCase{char64 + ".com", false} + // 255 char name is fine: + ch <- testCase{char59 + "." + char63 + "." + char63 + "." + + char63 + ".com", + true} + // 256 char name is bad: + ch <- testCase{char59 + "a." + char63 + "." + char63 + "." + + char63 + ".com", + false} +} + +func TestDNSNames(t *testing.T) { + ch := make(chan testCase) + go getTestCases(ch) + for tc := range ch { + if isDomainName(tc.name) != tc.result { + t.Errorf("isDomainName(%v) failed: Should be %v", + tc.name, tc.result) + } + } +} diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd.go new file mode 100644 index 000000000..707dccaa4 --- /dev/null +++ b/src/pkg/net/fd.go @@ -0,0 +1,645 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "io" + "os" + "sync" + "syscall" + "time" +) + +// Network file descriptor. +type netFD struct { + // locking/lifetime of sysfd + sysmu sync.Mutex + sysref int + closing bool + + // immutable until Close + sysfd int + family int + proto int + sysfile *os.File + cr chan bool + cw chan bool + net string + laddr Addr + raddr Addr + + // owned by client + rdeadline_delta int64 + rdeadline int64 + rio sync.Mutex + wdeadline_delta int64 + wdeadline int64 + wio sync.Mutex + + // owned by fd wait server + ncr, ncw int +} + +type InvalidConnError struct{} + +func (e *InvalidConnError) String() string { return "invalid net.Conn" } +func (e *InvalidConnError) Temporary() bool { return false } +func (e *InvalidConnError) Timeout() bool { return false } + +// A pollServer helps FDs determine when to retry a non-blocking +// read or write after they get EAGAIN. When an FD needs to wait, +// send the fd on s.cr (for a read) or s.cw (for a write) to pass the +// request to the poll server. Then receive on fd.cr/fd.cw. +// When the pollServer finds that i/o on FD should be possible +// again, it will send fd on fd.cr/fd.cw to wake any waiting processes. +// This protocol is implemented as s.WaitRead() and s.WaitWrite(). +// +// There is one subtlety: when sending on s.cr/s.cw, the +// poll server is probably in a system call, waiting for an fd +// to become ready. It's not looking at the request channels. +// To resolve this, the poll server waits not just on the FDs it has +// been given but also its own pipe. After sending on the +// buffered channel s.cr/s.cw, WaitRead/WaitWrite writes a +// byte to the pipe, causing the pollServer's poll system call to +// return. In response to the pipe being readable, the pollServer +// re-polls its request channels. +// +// Note that the ordering is "send request" and then "wake up server". +// If the operations were reversed, there would be a race: the poll +// server might wake up and look at the request channel, see that it +// was empty, and go back to sleep, all before the requester managed +// to send the request. Because the send must complete before the wakeup, +// the request channel must be buffered. A buffer of size 1 is sufficient +// for any request load. If many processes are trying to submit requests, +// one will succeed, the pollServer will read the request, and then the +// channel will be empty for the next process's request. A larger buffer +// might help batch requests. +// +// To avoid races in closing, all fd operations are locked and +// refcounted. when netFD.Close() is called, it calls syscall.Shutdown +// and sets a closing flag. Only when the last reference is removed +// will the fd be closed. + +type pollServer struct { + cr, cw chan *netFD // buffered >= 1 + pr, pw *os.File + poll *pollster // low-level OS hooks + sync.Mutex // controls pending and deadline + pending map[int]*netFD + deadline int64 // next deadline (nsec since 1970) +} + +func (s *pollServer) AddFD(fd *netFD, mode int) { + intfd := fd.sysfd + if intfd < 0 { + // fd closed underfoot + if mode == 'r' { + fd.cr <- true + } else { + fd.cw <- true + } + return + } + + s.Lock() + + var t int64 + key := intfd << 1 + if mode == 'r' { + fd.ncr++ + t = fd.rdeadline + } else { + fd.ncw++ + key++ + t = fd.wdeadline + } + s.pending[key] = fd + doWakeup := false + if t > 0 && (s.deadline == 0 || t < s.deadline) { + s.deadline = t + doWakeup = true + } + + wake, err := s.poll.AddFD(intfd, mode, false) + if err != nil { + panic("pollServer AddFD " + err.String()) + } + if wake { + doWakeup = true + } + + s.Unlock() + + if doWakeup { + s.Wakeup() + } +} + +var wakeupbuf [1]byte + +func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) } + +func (s *pollServer) LookupFD(fd int, mode int) *netFD { + key := fd << 1 + if mode == 'w' { + key++ + } + netfd, ok := s.pending[key] + if !ok { + return nil + } + s.pending[key] = nil, false + return netfd +} + +func (s *pollServer) WakeFD(fd *netFD, mode int) { + if mode == 'r' { + for fd.ncr > 0 { + fd.ncr-- + fd.cr <- true + } + } else { + for fd.ncw > 0 { + fd.ncw-- + fd.cw <- true + } + } +} + +func (s *pollServer) Now() int64 { + return time.Nanoseconds() +} + +func (s *pollServer) CheckDeadlines() { + now := s.Now() + // TODO(rsc): This will need to be handled more efficiently, + // probably with a heap indexed by wakeup time. + + var next_deadline int64 + for key, fd := range s.pending { + var t int64 + var mode int + if key&1 == 0 { + mode = 'r' + } else { + mode = 'w' + } + if mode == 'r' { + t = fd.rdeadline + } else { + t = fd.wdeadline + } + if t > 0 { + if t <= now { + s.pending[key] = nil, false + if mode == 'r' { + s.poll.DelFD(fd.sysfd, mode) + fd.rdeadline = -1 + } else { + s.poll.DelFD(fd.sysfd, mode) + fd.wdeadline = -1 + } + s.WakeFD(fd, mode) + } else if next_deadline == 0 || t < next_deadline { + next_deadline = t + } + } + } + s.deadline = next_deadline +} + +func (s *pollServer) Run() { + var scratch [100]byte + s.Lock() + defer s.Unlock() + for { + var t = s.deadline + if t > 0 { + t = t - s.Now() + if t <= 0 { + s.CheckDeadlines() + continue + } + } + fd, mode, err := s.poll.WaitFD(s, t) + if err != nil { + print("pollServer WaitFD: ", err.String(), "\n") + return + } + if fd < 0 { + // Timeout happened. + s.CheckDeadlines() + continue + } + if fd == s.pr.Fd() { + // Drain our wakeup pipe (we could loop here, + // but it's unlikely that there are more than + // len(scratch) wakeup calls). + s.pr.Read(scratch[0:]) + s.CheckDeadlines() + } else { + netfd := s.LookupFD(fd, mode) + if netfd == nil { + print("pollServer: unexpected wakeup for fd=", fd, " mode=", string(mode), "\n") + continue + } + s.WakeFD(netfd, mode) + } + } +} + +func (s *pollServer) WaitRead(fd *netFD) { + s.AddFD(fd, 'r') + <-fd.cr +} + +func (s *pollServer) WaitWrite(fd *netFD) { + s.AddFD(fd, 'w') + <-fd.cw +} + +// Network FD methods. +// All the network FDs use a single pollServer. + +var pollserver *pollServer +var onceStartServer sync.Once + +func startServer() { + p, err := newPollServer() + if err != nil { + print("Start pollServer: ", err.String(), "\n") + } + pollserver = p +} + +func newFD(fd, family, proto int, net string) (f *netFD, err os.Error) { + onceStartServer.Do(startServer) + if e := syscall.SetNonblock(fd, true); e != 0 { + return nil, os.Errno(e) + } + f = &netFD{ + sysfd: fd, + family: family, + proto: proto, + net: net, + } + f.cr = make(chan bool, 1) + f.cw = make(chan bool, 1) + return f, nil +} + +func (fd *netFD) setAddr(laddr, raddr Addr) { + fd.laddr = laddr + fd.raddr = raddr + var ls, rs string + if laddr != nil { + ls = laddr.String() + } + if raddr != nil { + rs = raddr.String() + } + fd.sysfile = os.NewFile(fd.sysfd, fd.net+":"+ls+"->"+rs) +} + +func (fd *netFD) connect(ra syscall.Sockaddr) (err os.Error) { + e := syscall.Connect(fd.sysfd, ra) + if e == syscall.EINPROGRESS { + var errno int + pollserver.WaitWrite(fd) + e, errno = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) + if errno != 0 { + return os.NewSyscallError("getsockopt", errno) + } + } + if e != 0 { + return os.Errno(e) + } + return nil +} + +// Add a reference to this fd. +func (fd *netFD) incref() { + fd.sysmu.Lock() + fd.sysref++ + fd.sysmu.Unlock() +} + +// Remove a reference to this FD and close if we've been asked to do so (and +// there are no references left. +func (fd *netFD) decref() { + fd.sysmu.Lock() + fd.sysref-- + if fd.closing && fd.sysref == 0 && fd.sysfd >= 0 { + // In case the user has set linger, switch to blocking mode so + // the close blocks. As long as this doesn't happen often, we + // can handle the extra OS processes. Otherwise we'll need to + // use the pollserver for Close too. Sigh. + syscall.SetNonblock(fd.sysfd, false) + fd.sysfile.Close() + fd.sysfile = nil + fd.sysfd = -1 + } + fd.sysmu.Unlock() +} + +func (fd *netFD) Close() os.Error { + if fd == nil || fd.sysfile == nil { + return os.EINVAL + } + + fd.incref() + syscall.Shutdown(fd.sysfd, syscall.SHUT_RDWR) + fd.closing = true + fd.decref() + return nil +} + +func (fd *netFD) Read(p []byte) (n int, err os.Error) { + if fd == nil { + return 0, os.EINVAL + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfile == nil { + return 0, os.EINVAL + } + if fd.rdeadline_delta > 0 { + fd.rdeadline = pollserver.Now() + fd.rdeadline_delta + } else { + fd.rdeadline = 0 + } + var oserr os.Error + for { + var errno int + n, errno = syscall.Read(fd.sysfile.Fd(), p) + if errno == syscall.EAGAIN && fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + if errno != 0 { + n = 0 + oserr = os.Errno(errno) + } else if n == 0 && errno == 0 && fd.proto != syscall.SOCK_DGRAM { + err = os.EOF + } + break + } + if oserr != nil { + err = &OpError{"read", fd.net, fd.raddr, oserr} + } + return +} + +func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) { + if fd == nil || fd.sysfile == nil { + return 0, nil, os.EINVAL + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.rdeadline_delta > 0 { + fd.rdeadline = pollserver.Now() + fd.rdeadline_delta + } else { + fd.rdeadline = 0 + } + var oserr os.Error + for { + var errno int + n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0) + if errno == syscall.EAGAIN && fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + if errno != 0 { + n = 0 + oserr = os.Errno(errno) + } + break + } + if oserr != nil { + err = &OpError{"read", fd.net, fd.laddr, oserr} + } + return +} + +func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) { + if fd == nil || fd.sysfile == nil { + return 0, 0, 0, nil, os.EINVAL + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.rdeadline_delta > 0 { + fd.rdeadline = pollserver.Now() + fd.rdeadline_delta + } else { + fd.rdeadline = 0 + } + var oserr os.Error + for { + var errno int + n, oobn, flags, sa, errno = syscall.Recvmsg(fd.sysfd, p, oob, 0) + if errno == syscall.EAGAIN && fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + if errno != 0 { + oserr = os.Errno(errno) + } + if n == 0 { + oserr = os.EOF + } + break + } + if oserr != nil { + err = &OpError{"read", fd.net, fd.laddr, oserr} + return + } + return +} + +func (fd *netFD) Write(p []byte) (n int, err os.Error) { + if fd == nil { + return 0, os.EINVAL + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfile == nil { + return 0, os.EINVAL + } + if fd.wdeadline_delta > 0 { + fd.wdeadline = pollserver.Now() + fd.wdeadline_delta + } else { + fd.wdeadline = 0 + } + nn := 0 + var oserr os.Error + + for { + n, errno := syscall.Write(fd.sysfile.Fd(), p[nn:]) + if n > 0 { + nn += n + } + if nn == len(p) { + break + } + if errno == syscall.EAGAIN && fd.wdeadline >= 0 { + pollserver.WaitWrite(fd) + continue + } + if errno != 0 { + n = 0 + oserr = os.Errno(errno) + break + } + if n == 0 { + oserr = io.ErrUnexpectedEOF + break + } + } + if oserr != nil { + err = &OpError{"write", fd.net, fd.raddr, oserr} + } + return nn, err +} + +func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) { + if fd == nil || fd.sysfile == nil { + return 0, os.EINVAL + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.wdeadline_delta > 0 { + fd.wdeadline = pollserver.Now() + fd.wdeadline_delta + } else { + fd.wdeadline = 0 + } + var oserr os.Error + for { + errno := syscall.Sendto(fd.sysfd, p, 0, sa) + if errno == syscall.EAGAIN && fd.wdeadline >= 0 { + pollserver.WaitWrite(fd) + continue + } + if errno != 0 { + oserr = os.Errno(errno) + } + break + } + if oserr == nil { + n = len(p) + } else { + err = &OpError{"write", fd.net, fd.raddr, oserr} + } + return +} + +func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) { + if fd == nil || fd.sysfile == nil { + return 0, 0, os.EINVAL + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.wdeadline_delta > 0 { + fd.wdeadline = pollserver.Now() + fd.wdeadline_delta + } else { + fd.wdeadline = 0 + } + var oserr os.Error + for { + var errno int + errno = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) + if errno == syscall.EAGAIN && fd.wdeadline >= 0 { + pollserver.WaitWrite(fd) + continue + } + if errno != 0 { + oserr = os.Errno(errno) + } + break + } + if oserr == nil { + n = len(p) + oobn = len(oob) + } else { + err = &OpError{"write", fd.net, fd.raddr, oserr} + } + return +} + +func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) { + if fd == nil || fd.sysfile == nil { + return nil, os.EINVAL + } + + fd.incref() + defer fd.decref() + if fd.rdeadline_delta > 0 { + fd.rdeadline = pollserver.Now() + fd.rdeadline_delta + } else { + fd.rdeadline = 0 + } + + // See ../syscall/exec.go for description of ForkLock. + // It is okay to hold the lock across syscall.Accept + // because we have put fd.sysfd into non-blocking mode. + syscall.ForkLock.RLock() + var s, e int + var rsa syscall.Sockaddr + for { + if fd.closing { + syscall.ForkLock.RUnlock() + return nil, os.EINVAL + } + s, rsa, e = syscall.Accept(fd.sysfd) + if e != syscall.EAGAIN || fd.rdeadline < 0 { + break + } + syscall.ForkLock.RUnlock() + pollserver.WaitRead(fd) + syscall.ForkLock.RLock() + } + if e != 0 { + syscall.ForkLock.RUnlock() + return nil, &OpError{"accept", fd.net, fd.laddr, os.Errno(e)} + } + syscall.CloseOnExec(s) + syscall.ForkLock.RUnlock() + + if nfd, err = newFD(s, fd.family, fd.proto, fd.net); err != nil { + syscall.Close(s) + return nil, err + } + lsa, _ := syscall.Getsockname(nfd.sysfd) + nfd.setAddr(toAddr(lsa), toAddr(rsa)) + return nfd, nil +} + +func (fd *netFD) dup() (f *os.File, err os.Error) { + ns, e := syscall.Dup(fd.sysfd) + if e != 0 { + return nil, &OpError{"dup", fd.net, fd.laddr, os.Errno(e)} + } + + // We want blocking mode for the new fd, hence the double negative. + if e = syscall.SetNonblock(ns, false); e != 0 { + return nil, &OpError{"setnonblock", fd.net, fd.laddr, os.Errno(e)} + } + + return os.NewFile(ns, fd.sysfile.Name()), nil +} + +func closesocket(s int) (errno int) { + return syscall.Close(s) +} diff --git a/src/pkg/net/fd_darwin.go b/src/pkg/net/fd_darwin.go new file mode 100644 index 000000000..7e3d549eb --- /dev/null +++ b/src/pkg/net/fd_darwin.go @@ -0,0 +1,120 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Waiting for FDs via kqueue/kevent. + +package net + +import ( + "os" + "syscall" +) + +type pollster struct { + kq int + eventbuf [10]syscall.Kevent_t + events []syscall.Kevent_t + + // An event buffer for AddFD/DelFD. + // Must hold pollServer lock. + kbuf [1]syscall.Kevent_t +} + +func newpollster() (p *pollster, err os.Error) { + p = new(pollster) + var e int + if p.kq, e = syscall.Kqueue(); e != 0 { + return nil, os.NewSyscallError("kqueue", e) + } + p.events = p.eventbuf[0:0] + return p, nil +} + +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { + // pollServer is locked. + + var kmode int + if mode == 'r' { + kmode = syscall.EVFILT_READ + } else { + kmode = syscall.EVFILT_WRITE + } + ev := &p.kbuf[0] + // EV_ADD - add event to kqueue list + // EV_RECEIPT - generate fake EV_ERROR as result of add, + // rather than waiting for real event + // EV_ONESHOT - delete the event the first time it triggers + flags := syscall.EV_ADD | syscall.EV_RECEIPT + if !repeat { + flags |= syscall.EV_ONESHOT + } + syscall.SetKevent(ev, fd, kmode, flags) + + n, e := syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil) + if e != 0 { + return false, os.NewSyscallError("kevent", e) + } + if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { + return false, os.NewError("kqueue phase error") + } + if ev.Data != 0 { + return false, os.Errno(int(ev.Data)) + } + return false, nil +} + +func (p *pollster) DelFD(fd int, mode int) { + // pollServer is locked. + + var kmode int + if mode == 'r' { + kmode = syscall.EVFILT_READ + } else { + kmode = syscall.EVFILT_WRITE + } + ev := &p.kbuf[0] + // EV_DELETE - delete event from kqueue list + // EV_RECEIPT - generate fake EV_ERROR as result of add, + // rather than waiting for real event + syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE|syscall.EV_RECEIPT) + syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil) +} + +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { + var t *syscall.Timespec + for len(p.events) == 0 { + if nsec > 0 { + if t == nil { + t = new(syscall.Timespec) + } + *t = syscall.NsecToTimespec(nsec) + } + + s.Unlock() + nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[0:], t) + s.Lock() + + if e != 0 { + if e == syscall.EINTR { + continue + } + return -1, 0, os.NewSyscallError("kevent", e) + } + if nn == 0 { + return -1, 0, nil + } + p.events = p.eventbuf[0:nn] + } + ev := &p.events[0] + p.events = p.events[1:] + fd = int(ev.Ident) + if ev.Filter == syscall.EVFILT_READ { + mode = 'r' + } else { + mode = 'w' + } + return fd, mode, nil +} + +func (p *pollster) Close() os.Error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/src/pkg/net/fd_freebsd.go b/src/pkg/net/fd_freebsd.go new file mode 100644 index 000000000..e50883e94 --- /dev/null +++ b/src/pkg/net/fd_freebsd.go @@ -0,0 +1,116 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Waiting for FDs via kqueue/kevent. + +package net + +import ( + "os" + "syscall" +) + +type pollster struct { + kq int + eventbuf [10]syscall.Kevent_t + events []syscall.Kevent_t + + // An event buffer for AddFD/DelFD. + // Must hold pollServer lock. + kbuf [1]syscall.Kevent_t +} + +func newpollster() (p *pollster, err os.Error) { + p = new(pollster) + var e int + if p.kq, e = syscall.Kqueue(); e != 0 { + return nil, os.NewSyscallError("kqueue", e) + } + p.events = p.eventbuf[0:0] + return p, nil +} + +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { + // pollServer is locked. + + var kmode int + if mode == 'r' { + kmode = syscall.EVFILT_READ + } else { + kmode = syscall.EVFILT_WRITE + } + ev := &p.kbuf[0] + // EV_ADD - add event to kqueue list + // EV_ONESHOT - delete the event the first time it triggers + flags := syscall.EV_ADD + if !repeat { + flags |= syscall.EV_ONESHOT + } + syscall.SetKevent(ev, fd, kmode, flags) + + n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil) + if e != 0 { + return false, os.NewSyscallError("kevent", e) + } + if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { + return false, os.NewSyscallError("kqueue phase error", e) + } + if ev.Data != 0 { + return false, os.Errno(int(ev.Data)) + } + return false, nil +} + +func (p *pollster) DelFD(fd int, mode int) { + // pollServer is locked. + + var kmode int + if mode == 'r' { + kmode = syscall.EVFILT_READ + } else { + kmode = syscall.EVFILT_WRITE + } + ev := &p.kbuf[0] + // EV_DELETE - delete event from kqueue list + syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE) + syscall.Kevent(p.kq, p.kbuf[:], nil, nil) +} + +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { + var t *syscall.Timespec + for len(p.events) == 0 { + if nsec > 0 { + if t == nil { + t = new(syscall.Timespec) + } + *t = syscall.NsecToTimespec(nsec) + } + + s.Unlock() + nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) + s.Lock() + + if e != 0 { + if e == syscall.EINTR { + continue + } + return -1, 0, os.NewSyscallError("kevent", e) + } + if nn == 0 { + return -1, 0, nil + } + p.events = p.eventbuf[0:nn] + } + ev := &p.events[0] + p.events = p.events[1:] + fd = int(ev.Ident) + if ev.Filter == syscall.EVFILT_READ { + mode = 'r' + } else { + mode = 'w' + } + return fd, mode, nil +} + +func (p *pollster) Close() os.Error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/src/pkg/net/fd_linux.go b/src/pkg/net/fd_linux.go new file mode 100644 index 000000000..70fc344b2 --- /dev/null +++ b/src/pkg/net/fd_linux.go @@ -0,0 +1,182 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Waiting for FDs via epoll(7). + +package net + +import ( + "os" + "syscall" +) + +const ( + readFlags = syscall.EPOLLIN | syscall.EPOLLRDHUP + writeFlags = syscall.EPOLLOUT +) + +type pollster struct { + epfd int + + // Events we're already waiting for + // Must hold pollServer lock + events map[int]uint32 + + // An event buffer for EpollWait. + // Used without a lock, may only be used by WaitFD. + waitEventBuf [10]syscall.EpollEvent + waitEvents []syscall.EpollEvent + + // An event buffer for EpollCtl, to avoid a malloc. + // Must hold pollServer lock. + ctlEvent syscall.EpollEvent +} + +func newpollster() (p *pollster, err os.Error) { + p = new(pollster) + var e int + + // The arg to epoll_create is a hint to the kernel + // about the number of FDs we will care about. + // We don't know, and since 2.6.8 the kernel ignores it anyhow. + if p.epfd, e = syscall.EpollCreate(16); e != 0 { + return nil, os.NewSyscallError("epoll_create", e) + } + p.events = make(map[int]uint32) + return p, nil +} + +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { + // pollServer is locked. + + var already bool + p.ctlEvent.Fd = int32(fd) + p.ctlEvent.Events, already = p.events[fd] + if !repeat { + p.ctlEvent.Events |= syscall.EPOLLONESHOT + } + if mode == 'r' { + p.ctlEvent.Events |= readFlags + } else { + p.ctlEvent.Events |= writeFlags + } + + var op int + if already { + op = syscall.EPOLL_CTL_MOD + } else { + op = syscall.EPOLL_CTL_ADD + } + if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != 0 { + return false, os.NewSyscallError("epoll_ctl", e) + } + p.events[fd] = p.ctlEvent.Events + return false, nil +} + +func (p *pollster) StopWaiting(fd int, bits uint) { + // pollServer is locked. + + events, already := p.events[fd] + if !already { + print("Epoll unexpected fd=", fd, "\n") + return + } + + // If syscall.EPOLLONESHOT is not set, the wait + // is a repeating wait, so don't change it. + if events&syscall.EPOLLONESHOT == 0 { + return + } + + // Disable the given bits. + // If we're still waiting for other events, modify the fd + // event in the kernel. Otherwise, delete it. + events &= ^uint32(bits) + if int32(events)&^syscall.EPOLLONESHOT != 0 { + p.ctlEvent.Fd = int32(fd) + p.ctlEvent.Events = events + if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != 0 { + print("Epoll modify fd=", fd, ": ", os.Errno(e).String(), "\n") + } + p.events[fd] = events + } else { + if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != 0 { + print("Epoll delete fd=", fd, ": ", os.Errno(e).String(), "\n") + } + p.events[fd] = 0, false + } +} + +func (p *pollster) DelFD(fd int, mode int) { + // pollServer is locked. + + if mode == 'r' { + p.StopWaiting(fd, readFlags) + } else { + p.StopWaiting(fd, writeFlags) + } + + // Discard any queued up events. + i := 0 + for i < len(p.waitEvents) { + if fd == int(p.waitEvents[i].Fd) { + copy(p.waitEvents[i:], p.waitEvents[i+1:]) + p.waitEvents = p.waitEvents[:len(p.waitEvents)-1] + } else { + i++ + } + } +} + +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { + for len(p.waitEvents) == 0 { + var msec int = -1 + if nsec > 0 { + msec = int((nsec + 1e6 - 1) / 1e6) + } + + s.Unlock() + n, e := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec) + s.Lock() + + if e != 0 { + if e == syscall.EAGAIN || e == syscall.EINTR { + continue + } + return -1, 0, os.NewSyscallError("epoll_wait", e) + } + if n == 0 { + return -1, 0, nil + } + p.waitEvents = p.waitEventBuf[0:n] + } + + ev := &p.waitEvents[0] + p.waitEvents = p.waitEvents[1:] + + fd = int(ev.Fd) + + if ev.Events&writeFlags != 0 { + p.StopWaiting(fd, writeFlags) + return fd, 'w', nil + } + if ev.Events&readFlags != 0 { + p.StopWaiting(fd, readFlags) + return fd, 'r', nil + } + + // Other events are error conditions - wake whoever is waiting. + events, _ := p.events[fd] + if events&writeFlags != 0 { + p.StopWaiting(fd, writeFlags) + return fd, 'w', nil + } + p.StopWaiting(fd, readFlags) + return fd, 'r', nil +} + +func (p *pollster) Close() os.Error { + return os.NewSyscallError("close", syscall.Close(p.epfd)) +} diff --git a/src/pkg/net/fd_openbsd.go b/src/pkg/net/fd_openbsd.go new file mode 100644 index 000000000..e50883e94 --- /dev/null +++ b/src/pkg/net/fd_openbsd.go @@ -0,0 +1,116 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Waiting for FDs via kqueue/kevent. + +package net + +import ( + "os" + "syscall" +) + +type pollster struct { + kq int + eventbuf [10]syscall.Kevent_t + events []syscall.Kevent_t + + // An event buffer for AddFD/DelFD. + // Must hold pollServer lock. + kbuf [1]syscall.Kevent_t +} + +func newpollster() (p *pollster, err os.Error) { + p = new(pollster) + var e int + if p.kq, e = syscall.Kqueue(); e != 0 { + return nil, os.NewSyscallError("kqueue", e) + } + p.events = p.eventbuf[0:0] + return p, nil +} + +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { + // pollServer is locked. + + var kmode int + if mode == 'r' { + kmode = syscall.EVFILT_READ + } else { + kmode = syscall.EVFILT_WRITE + } + ev := &p.kbuf[0] + // EV_ADD - add event to kqueue list + // EV_ONESHOT - delete the event the first time it triggers + flags := syscall.EV_ADD + if !repeat { + flags |= syscall.EV_ONESHOT + } + syscall.SetKevent(ev, fd, kmode, flags) + + n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil) + if e != 0 { + return false, os.NewSyscallError("kevent", e) + } + if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { + return false, os.NewSyscallError("kqueue phase error", e) + } + if ev.Data != 0 { + return false, os.Errno(int(ev.Data)) + } + return false, nil +} + +func (p *pollster) DelFD(fd int, mode int) { + // pollServer is locked. + + var kmode int + if mode == 'r' { + kmode = syscall.EVFILT_READ + } else { + kmode = syscall.EVFILT_WRITE + } + ev := &p.kbuf[0] + // EV_DELETE - delete event from kqueue list + syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE) + syscall.Kevent(p.kq, p.kbuf[:], nil, nil) +} + +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { + var t *syscall.Timespec + for len(p.events) == 0 { + if nsec > 0 { + if t == nil { + t = new(syscall.Timespec) + } + *t = syscall.NsecToTimespec(nsec) + } + + s.Unlock() + nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) + s.Lock() + + if e != 0 { + if e == syscall.EINTR { + continue + } + return -1, 0, os.NewSyscallError("kevent", e) + } + if nn == 0 { + return -1, 0, nil + } + p.events = p.eventbuf[0:nn] + } + ev := &p.events[0] + p.events = p.events[1:] + fd = int(ev.Ident) + if ev.Filter == syscall.EVFILT_READ { + mode = 'r' + } else { + mode = 'w' + } + return fd, mode, nil +} + +func (p *pollster) Close() os.Error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go new file mode 100644 index 000000000..3757e143d --- /dev/null +++ b/src/pkg/net/fd_windows.go @@ -0,0 +1,532 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "runtime" + "sync" + "syscall" + "time" + "unsafe" +) + +type InvalidConnError struct{} + +func (e *InvalidConnError) String() string { return "invalid net.Conn" } +func (e *InvalidConnError) Temporary() bool { return false } +func (e *InvalidConnError) Timeout() bool { return false } + +var initErr os.Error + +func init() { + var d syscall.WSAData + e := syscall.WSAStartup(uint32(0x101), &d) + if e != 0 { + initErr = os.NewSyscallError("WSAStartup", e) + } +} + +func closesocket(s syscall.Handle) (errno int) { + return syscall.Closesocket(s) +} + +// Interface for all io operations. +type anOpIface interface { + Op() *anOp + Name() string + Submit() (errno int) +} + +// IO completion result parameters. +type ioResult struct { + qty uint32 + err int +} + +// anOp implements functionality common to all io operations. +type anOp struct { + // Used by IOCP interface, it must be first field + // of the struct, as our code rely on it. + o syscall.Overlapped + + resultc chan ioResult // io completion results + errnoc chan int // io submit / cancel operation errors + fd *netFD +} + +func (o *anOp) Init(fd *netFD) { + o.fd = fd + o.resultc = make(chan ioResult, 1) + o.errnoc = make(chan int) +} + +func (o *anOp) Op() *anOp { + return o +} + +// bufOp is used by io operations that read / write +// data from / to client buffer. +type bufOp struct { + anOp + buf syscall.WSABuf +} + +func (o *bufOp) Init(fd *netFD, buf []byte) { + o.anOp.Init(fd) + o.buf.Len = uint32(len(buf)) + if len(buf) == 0 { + o.buf.Buf = nil + } else { + o.buf.Buf = (*byte)(unsafe.Pointer(&buf[0])) + } +} + +// resultSrv will retrieve all io completion results from +// iocp and send them to the correspondent waiting client +// goroutine via channel supplied in the request. +type resultSrv struct { + iocp syscall.Handle +} + +func (s *resultSrv) Run() { + var o *syscall.Overlapped + var key uint32 + var r ioResult + for { + r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE) + switch { + case r.err == 0: + // Dequeued successfully completed io packet. + case r.err == syscall.WAIT_TIMEOUT && o == nil: + // Wait has timed out (should not happen now, but might be used in the future). + panic("GetQueuedCompletionStatus timed out") + case o == nil: + // Failed to dequeue anything -> report the error. + panic("GetQueuedCompletionStatus failed " + syscall.Errstr(r.err)) + default: + // Dequeued failed io packet. + } + (*anOp)(unsafe.Pointer(o)).resultc <- r + } +} + +// ioSrv executes net io requests. +type ioSrv struct { + submchan chan anOpIface // submit io requests + canchan chan anOpIface // cancel io requests +} + +// ProcessRemoteIO will execute submit io requests on behalf +// of other goroutines, all on a single os thread, so it can +// cancel them later. Results of all operations will be sent +// back to their requesters via channel supplied in request. +func (s *ioSrv) ProcessRemoteIO() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + for { + select { + case o := <-s.submchan: + o.Op().errnoc <- o.Submit() + case o := <-s.canchan: + o.Op().errnoc <- syscall.CancelIo(syscall.Handle(o.Op().fd.sysfd)) + } + } +} + +// ExecIO executes a single io operation. It either executes it +// inline, or, if timeouts are employed, passes the request onto +// a special goroutine and waits for completion or cancels request. +func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err os.Error) { + var e int + o := oi.Op() + if deadline_delta > 0 { + // Send request to a special dedicated thread, + // so it can stop the io with CancelIO later. + s.submchan <- oi + e = <-o.errnoc + } else { + e = oi.Submit() + } + switch e { + case 0: + // IO completed immediately, but we need to get our completion message anyway. + case syscall.ERROR_IO_PENDING: + // IO started, and we have to wait for its completion. + default: + return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(e)} + } + // Wait for our request to complete. + var r ioResult + if deadline_delta > 0 { + select { + case r = <-o.resultc: + case <-time.After(deadline_delta): + s.canchan <- oi + <-o.errnoc + r = <-o.resultc + if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled + r.err = syscall.EWOULDBLOCK + } + } + } else { + r = <-o.resultc + } + if r.err != 0 { + err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(r.err)} + } + return int(r.qty), err +} + +// Start helper goroutines. +var resultsrv *resultSrv +var iosrv *ioSrv +var onceStartServer sync.Once + +func startServer() { + resultsrv = new(resultSrv) + var errno int + resultsrv.iocp, errno = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1) + if errno != 0 { + panic("CreateIoCompletionPort failed " + syscall.Errstr(errno)) + } + go resultsrv.Run() + + iosrv = new(ioSrv) + iosrv.submchan = make(chan anOpIface) + iosrv.canchan = make(chan anOpIface) + go iosrv.ProcessRemoteIO() +} + +// Network file descriptor. +type netFD struct { + // locking/lifetime of sysfd + sysmu sync.Mutex + sysref int + closing bool + + // immutable until Close + sysfd syscall.Handle + family int + proto int + net string + laddr Addr + raddr Addr + + // owned by client + rdeadline_delta int64 + rdeadline int64 + rio sync.Mutex + wdeadline_delta int64 + wdeadline int64 + wio sync.Mutex +} + +func allocFD(fd syscall.Handle, family, proto int, net string) (f *netFD) { + f = &netFD{ + sysfd: fd, + family: family, + proto: proto, + net: net, + } + runtime.SetFinalizer(f, (*netFD).Close) + return f +} + +func newFD(fd syscall.Handle, family, proto int, net string) (f *netFD, err os.Error) { + if initErr != nil { + return nil, initErr + } + onceStartServer.Do(startServer) + // Associate our socket with resultsrv.iocp. + if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != 0 { + return nil, os.Errno(e) + } + return allocFD(fd, family, proto, net), nil +} + +func (fd *netFD) setAddr(laddr, raddr Addr) { + fd.laddr = laddr + fd.raddr = raddr +} + +func (fd *netFD) connect(ra syscall.Sockaddr) (err os.Error) { + e := syscall.Connect(fd.sysfd, ra) + if e != 0 { + return os.Errno(e) + } + return nil +} + +// Add a reference to this fd. +func (fd *netFD) incref() { + fd.sysmu.Lock() + fd.sysref++ + fd.sysmu.Unlock() +} + +// Remove a reference to this FD and close if we've been asked to do so (and +// there are no references left. +func (fd *netFD) decref() { + fd.sysmu.Lock() + fd.sysref-- + if fd.closing && fd.sysref == 0 && fd.sysfd != syscall.InvalidHandle { + // In case the user has set linger, switch to blocking mode so + // the close blocks. As long as this doesn't happen often, we + // can handle the extra OS processes. Otherwise we'll need to + // use the resultsrv for Close too. Sigh. + syscall.SetNonblock(fd.sysfd, false) + closesocket(fd.sysfd) + fd.sysfd = syscall.InvalidHandle + // no need for a finalizer anymore + runtime.SetFinalizer(fd, nil) + } + fd.sysmu.Unlock() +} + +func (fd *netFD) Close() os.Error { + if fd == nil || fd.sysfd == syscall.InvalidHandle { + return os.EINVAL + } + + fd.incref() + syscall.Shutdown(fd.sysfd, syscall.SHUT_RDWR) + fd.closing = true + fd.decref() + return nil +} + +// Read from network. + +type readOp struct { + bufOp +} + +func (o *readOp) Submit() (errno int) { + var d, f uint32 + return syscall.WSARecv(syscall.Handle(o.fd.sysfd), &o.buf, 1, &d, &f, &o.o, nil) +} + +func (o *readOp) Name() string { + return "WSARecv" +} + +func (fd *netFD) Read(buf []byte) (n int, err os.Error) { + if fd == nil { + return 0, os.EINVAL + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfd == syscall.InvalidHandle { + return 0, os.EINVAL + } + var o readOp + o.Init(fd, buf) + n, err = iosrv.ExecIO(&o, fd.rdeadline_delta) + if err == nil && n == 0 { + err = os.EOF + } + return +} + +// ReadFrom from network. + +type readFromOp struct { + bufOp + rsa syscall.RawSockaddrAny + rsan int32 +} + +func (o *readFromOp) Submit() (errno int) { + var d, f uint32 + return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &d, &f, &o.rsa, &o.rsan, &o.o, nil) +} + +func (o *readFromOp) Name() string { + return "WSARecvFrom" +} + +func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err os.Error) { + if fd == nil { + return 0, nil, os.EINVAL + } + if len(buf) == 0 { + return 0, nil, nil + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfd == syscall.InvalidHandle { + return 0, nil, os.EINVAL + } + var o readFromOp + o.Init(fd, buf) + o.rsan = int32(unsafe.Sizeof(o.rsa)) + n, err = iosrv.ExecIO(&o, fd.rdeadline_delta) + if err != nil { + return 0, nil, err + } + sa, _ = o.rsa.Sockaddr() + return +} + +// Write to network. + +type writeOp struct { + bufOp +} + +func (o *writeOp) Submit() (errno int) { + var d uint32 + return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &d, 0, &o.o, nil) +} + +func (o *writeOp) Name() string { + return "WSASend" +} + +func (fd *netFD) Write(buf []byte) (n int, err os.Error) { + if fd == nil { + return 0, os.EINVAL + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfd == syscall.InvalidHandle { + return 0, os.EINVAL + } + var o writeOp + o.Init(fd, buf) + return iosrv.ExecIO(&o, fd.wdeadline_delta) +} + +// WriteTo to network. + +type writeToOp struct { + bufOp + sa syscall.Sockaddr +} + +func (o *writeToOp) Submit() (errno int) { + var d uint32 + return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &d, 0, o.sa, &o.o, nil) +} + +func (o *writeToOp) Name() string { + return "WSASendto" +} + +func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (n int, err os.Error) { + if fd == nil { + return 0, os.EINVAL + } + if len(buf) == 0 { + return 0, nil + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfd == syscall.InvalidHandle { + return 0, os.EINVAL + } + var o writeToOp + o.Init(fd, buf) + o.sa = sa + return iosrv.ExecIO(&o, fd.wdeadline_delta) +} + +// Accept new network connections. + +type acceptOp struct { + anOp + newsock syscall.Handle + attrs [2]syscall.RawSockaddrAny // space for local and remote address only +} + +func (o *acceptOp) Submit() (errno int) { + var d uint32 + l := uint32(unsafe.Sizeof(o.attrs[0])) + return syscall.AcceptEx(o.fd.sysfd, o.newsock, + (*byte)(unsafe.Pointer(&o.attrs[0])), 0, l, l, &d, &o.o) +} + +func (o *acceptOp) Name() string { + return "AcceptEx" +} + +func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) { + if fd == nil || fd.sysfd == syscall.InvalidHandle { + return nil, os.EINVAL + } + fd.incref() + defer fd.decref() + + // Get new socket. + // See ../syscall/exec.go for description of ForkLock. + syscall.ForkLock.RLock() + s, e := syscall.Socket(fd.family, fd.proto, 0) + if e != 0 { + syscall.ForkLock.RUnlock() + return nil, os.Errno(e) + } + syscall.CloseOnExec(s) + syscall.ForkLock.RUnlock() + + // Associate our new socket with IOCP. + onceStartServer.Do(startServer) + if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != 0 { + return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, os.Errno(e)} + } + + // Submit accept request. + var o acceptOp + o.Init(fd) + o.newsock = s + _, err = iosrv.ExecIO(&o, 0) + if err != nil { + closesocket(s) + return nil, err + } + + // Inherit properties of the listening socket. + e = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))) + if e != 0 { + closesocket(s) + return nil, err + } + + // Get local and peer addr out of AcceptEx buffer. + var lrsa, rrsa *syscall.RawSockaddrAny + var llen, rlen int32 + l := uint32(unsafe.Sizeof(*lrsa)) + syscall.GetAcceptExSockaddrs((*byte)(unsafe.Pointer(&o.attrs[0])), + 0, l, l, &lrsa, &llen, &rrsa, &rlen) + lsa, _ := lrsa.Sockaddr() + rsa, _ := rrsa.Sockaddr() + + nfd = allocFD(s, fd.family, fd.proto, fd.net) + nfd.setAddr(toAddr(lsa), toAddr(rsa)) + return nfd, nil +} + +// Unimplemented functions. + +func (fd *netFD) dup() (f *os.File, err os.Error) { + // TODO: Implement this + return nil, os.NewSyscallError("dup", syscall.EWINDOWS) +} + +func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) { + return 0, 0, 0, nil, os.EAFNOSUPPORT +} + +func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) { + return 0, 0, os.EAFNOSUPPORT +} diff --git a/src/pkg/net/file.go b/src/pkg/net/file.go new file mode 100644 index 000000000..0e411a192 --- /dev/null +++ b/src/pkg/net/file.go @@ -0,0 +1,119 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "syscall" +) + +func newFileFD(f *os.File) (nfd *netFD, err os.Error) { + fd, errno := syscall.Dup(f.Fd()) + if errno != 0 { + return nil, os.NewSyscallError("dup", errno) + } + + proto, errno := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE) + if errno != 0 { + return nil, os.NewSyscallError("getsockopt", errno) + } + + toAddr := sockaddrToTCP + sa, _ := syscall.Getsockname(fd) + switch sa.(type) { + default: + closesocket(fd) + return nil, os.EINVAL + case *syscall.SockaddrInet4: + if proto == syscall.SOCK_DGRAM { + toAddr = sockaddrToUDP + } else if proto == syscall.SOCK_RAW { + toAddr = sockaddrToIP + } + case *syscall.SockaddrInet6: + if proto == syscall.SOCK_DGRAM { + toAddr = sockaddrToUDP + } else if proto == syscall.SOCK_RAW { + toAddr = sockaddrToIP + } + case *syscall.SockaddrUnix: + toAddr = sockaddrToUnix + if proto == syscall.SOCK_DGRAM { + toAddr = sockaddrToUnixgram + } else if proto == syscall.SOCK_SEQPACKET { + toAddr = sockaddrToUnixpacket + } + } + laddr := toAddr(sa) + sa, _ = syscall.Getpeername(fd) + raddr := toAddr(sa) + + if nfd, err = newFD(fd, 0, proto, laddr.Network()); err != nil { + return nil, err + } + nfd.setAddr(laddr, raddr) + return nfd, nil +} + +// FileConn returns a copy of the network connection corresponding to +// the open file f. It is the caller's responsibility to close f when +// finished. Closing c does not affect f, and closing f does not +// affect c. +func FileConn(f *os.File) (c Conn, err os.Error) { + fd, err := newFileFD(f) + if err != nil { + return nil, err + } + switch fd.laddr.(type) { + case *TCPAddr: + return newTCPConn(fd), nil + case *UDPAddr: + return newUDPConn(fd), nil + case *UnixAddr: + return newUnixConn(fd), nil + case *IPAddr: + return newIPConn(fd), nil + } + fd.Close() + return nil, os.EINVAL +} + +// FileListener returns a copy of the network listener corresponding +// to the open file f. It is the caller's responsibility to close l +// when finished. Closing c does not affect l, and closing l does not +// affect c. +func FileListener(f *os.File) (l Listener, err os.Error) { + fd, err := newFileFD(f) + if err != nil { + return nil, err + } + switch laddr := fd.laddr.(type) { + case *TCPAddr: + return &TCPListener{fd}, nil + case *UnixAddr: + return &UnixListener{fd, laddr.Name}, nil + } + fd.Close() + return nil, os.EINVAL +} + +// FilePacketConn returns a copy of the packet network connection +// corresponding to the open file f. It is the caller's +// responsibility to close f when finished. Closing c does not affect +// f, and closing f does not affect c. +func FilePacketConn(f *os.File) (c PacketConn, err os.Error) { + fd, err := newFileFD(f) + if err != nil { + return nil, err + } + switch fd.laddr.(type) { + case *UDPAddr: + return newUDPConn(fd), nil + case *UnixAddr: + return newUnixConn(fd), nil + } + fd.Close() + return nil, os.EINVAL +} diff --git a/src/pkg/net/file_plan9.go b/src/pkg/net/file_plan9.go new file mode 100644 index 000000000..a07e74331 --- /dev/null +++ b/src/pkg/net/file_plan9.go @@ -0,0 +1,33 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" +) + +// FileConn returns a copy of the network connection corresponding to +// the open file f. It is the caller's responsibility to close f when +// finished. Closing c does not affect f, and closing f does not +// affect c. +func FileConn(f *os.File) (c Conn, err os.Error) { + return nil, os.EPLAN9 +} + +// FileListener returns a copy of the network listener corresponding +// to the open file f. It is the caller's responsibility to close l +// when finished. Closing c does not affect l, and closing l does not +// affect c. +func FileListener(f *os.File) (l Listener, err os.Error) { + return nil, os.EPLAN9 +} + +// FilePacketConn returns a copy of the packet network connection +// corresponding to the open file f. It is the caller's +// responsibility to close f when finished. Closing c does not affect +// f, and closing f does not affect c. +func FilePacketConn(f *os.File) (c PacketConn, err os.Error) { + return nil, os.EPLAN9 +} diff --git a/src/pkg/net/file_test.go b/src/pkg/net/file_test.go new file mode 100644 index 000000000..9a8c2dcbc --- /dev/null +++ b/src/pkg/net/file_test.go @@ -0,0 +1,133 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "reflect" + "runtime" + "syscall" + "testing" +) + +type listenerFile interface { + Listener + File() (f *os.File, err os.Error) +} + +type packetConnFile interface { + PacketConn + File() (f *os.File, err os.Error) +} + +type connFile interface { + Conn + File() (f *os.File, err os.Error) +} + +func testFileListener(t *testing.T, net, laddr string) { + if net == "tcp" { + laddr += ":0" // any available port + } + l, err := Listen(net, laddr) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer l.Close() + lf := l.(listenerFile) + f, err := lf.File() + if err != nil { + t.Fatalf("File failed: %v", err) + } + c, err := FileListener(f) + if err != nil { + t.Fatalf("FileListener failed: %v", err) + } + if !reflect.DeepEqual(l.Addr(), c.Addr()) { + t.Fatalf("Addrs not equal: %#v != %#v", l.Addr(), c.Addr()) + } + if err := c.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } +} + +func TestFileListener(t *testing.T) { + if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { + return + } + testFileListener(t, "tcp", "127.0.0.1") + testFileListener(t, "tcp", "127.0.0.1") + if supportsIPv6 && supportsIPv4map { + testFileListener(t, "tcp", "[::ffff:127.0.0.1]") + testFileListener(t, "tcp", "127.0.0.1") + testFileListener(t, "tcp", "[::ffff:127.0.0.1]") + } + if syscall.OS == "linux" { + testFileListener(t, "unix", "@gotest/net") + testFileListener(t, "unixpacket", "@gotest/net") + } +} + +func testFilePacketConn(t *testing.T, pcf packetConnFile) { + f, err := pcf.File() + if err != nil { + t.Fatalf("File failed: %v", err) + } + c, err := FilePacketConn(f) + if err != nil { + t.Fatalf("FilePacketConn failed: %v", err) + } + if !reflect.DeepEqual(pcf.LocalAddr(), c.LocalAddr()) { + t.Fatalf("LocalAddrs not equal: %#v != %#v", pcf.LocalAddr(), c.LocalAddr()) + } + if err := c.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } +} + +func testFilePacketConnListen(t *testing.T, net, laddr string) { + l, err := ListenPacket(net, laddr) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + testFilePacketConn(t, l.(packetConnFile)) + if err := l.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } +} + +func testFilePacketConnDial(t *testing.T, net, raddr string) { + c, err := Dial(net, raddr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + testFilePacketConn(t, c.(packetConnFile)) + if err := c.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } +} + +func TestFilePacketConn(t *testing.T) { + if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { + return + } + testFilePacketConnListen(t, "udp", "127.0.0.1:0") + testFilePacketConnDial(t, "udp", "127.0.0.1:12345") + if supportsIPv6 { + testFilePacketConnListen(t, "udp", "[::1]:0") + } + if supportsIPv6 && supportsIPv4map { + testFilePacketConnDial(t, "udp", "[::ffff:127.0.0.1]:12345") + } + if syscall.OS == "linux" { + testFilePacketConnListen(t, "unixgram", "@gotest1/net") + } +} diff --git a/src/pkg/net/file_windows.go b/src/pkg/net/file_windows.go new file mode 100644 index 000000000..94aa58375 --- /dev/null +++ b/src/pkg/net/file_windows.go @@ -0,0 +1,25 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "syscall" +) + +func FileConn(f *os.File) (c Conn, err os.Error) { + // TODO: Implement this + return nil, os.NewSyscallError("FileConn", syscall.EWINDOWS) +} + +func FileListener(f *os.File) (l Listener, err os.Error) { + // TODO: Implement this + return nil, os.NewSyscallError("FileListener", syscall.EWINDOWS) +} + +func FilePacketConn(f *os.File) (c PacketConn, err os.Error) { + // TODO: Implement this + return nil, os.NewSyscallError("FilePacketConn", syscall.EWINDOWS) +} diff --git a/src/pkg/net/hosts.go b/src/pkg/net/hosts.go new file mode 100644 index 000000000..d75e9e038 --- /dev/null +++ b/src/pkg/net/hosts.go @@ -0,0 +1,86 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Read static host/IP entries from /etc/hosts. + +package net + +import ( + "os" + "sync" +) + +const cacheMaxAge = int64(300) // 5 minutes. + +// hostsPath points to the file with static IP/address entries. +var hostsPath = "/etc/hosts" + +// Simple cache. +var hosts struct { + sync.Mutex + byName map[string][]string + byAddr map[string][]string + time int64 + path string +} + +func readHosts() { + now, _, _ := os.Time() + hp := hostsPath + if len(hosts.byName) == 0 || hosts.time+cacheMaxAge <= now || hosts.path != hp { + hs := make(map[string][]string) + is := make(map[string][]string) + var file *file + if file, _ = open(hp); file == nil { + return + } + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + if i := byteIndex(line, '#'); i >= 0 { + // Discard comments. + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 || ParseIP(f[0]) == nil { + continue + } + for i := 1; i < len(f); i++ { + h := f[i] + hs[h] = append(hs[h], f[0]) + is[f[0]] = append(is[f[0]], h) + } + } + // Update the data cache. + hosts.time, _, _ = os.Time() + hosts.path = hp + hosts.byName = hs + hosts.byAddr = is + file.close() + } +} + +// lookupStaticHost looks up the addresses for the given host from /etc/hosts. +func lookupStaticHost(host string) []string { + hosts.Lock() + defer hosts.Unlock() + readHosts() + if len(hosts.byName) != 0 { + if ips, ok := hosts.byName[host]; ok { + return ips + } + } + return nil +} + +// lookupStaticAddr looks up the hosts for the given address from /etc/hosts. +func lookupStaticAddr(addr string) []string { + hosts.Lock() + defer hosts.Unlock() + readHosts() + if len(hosts.byAddr) != 0 { + if hosts, ok := hosts.byAddr[addr]; ok { + return hosts + } + } + return nil +} diff --git a/src/pkg/net/hosts_test.go b/src/pkg/net/hosts_test.go new file mode 100644 index 000000000..1bd00541c --- /dev/null +++ b/src/pkg/net/hosts_test.go @@ -0,0 +1,68 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "sort" + "testing" +) + +type hostTest struct { + host string + ips []IP +} + +var hosttests = []hostTest{ + {"odin", []IP{ + IPv4(127, 0, 0, 2), + IPv4(127, 0, 0, 3), + ParseIP("::2"), + }}, + {"thor", []IP{ + IPv4(127, 1, 1, 1), + }}, + {"loki", []IP{}}, + {"ullr", []IP{ + IPv4(127, 1, 1, 2), + }}, + {"ullrhost", []IP{ + IPv4(127, 1, 1, 2), + }}, +} + +func TestLookupStaticHost(t *testing.T) { + p := hostsPath + hostsPath = "hosts_testdata" + for i := 0; i < len(hosttests); i++ { + tt := hosttests[i] + ips := lookupStaticHost(tt.host) + if len(ips) != len(tt.ips) { + t.Errorf("# of hosts = %v; want %v", + len(ips), len(tt.ips)) + return + } + for k, v := range ips { + if tt.ips[k].String() != v { + t.Errorf("lookupStaticHost(%q) = %v; want %v", + tt.host, v, tt.ips[k]) + } + } + } + hostsPath = p +} + +func TestLookupHost(t *testing.T) { + // Can't depend on this to return anything in particular, + // but if it does return something, make sure it doesn't + // duplicate addresses (a common bug due to the way + // getaddrinfo works). + addrs, _ := LookupHost("localhost") + sort.Strings(addrs) + for i := 0; i+1 < len(addrs); i++ { + if addrs[i] == addrs[i+1] { + t.Fatalf("LookupHost(\"localhost\") = %v, has duplicate addresses", addrs) + } + } +} diff --git a/src/pkg/net/hosts_testdata b/src/pkg/net/hosts_testdata new file mode 100644 index 000000000..b60176389 --- /dev/null +++ b/src/pkg/net/hosts_testdata @@ -0,0 +1,12 @@ +255.255.255.255 broadcasthost +127.0.0.2 odin +127.0.0.3 odin # inline comment +::2 odin +127.1.1.1 thor +# aliases +127.1.1.2 ullr ullrhost +# Bogus entries that must be ignored. +123.123.123 loki +321.321.321.321 +# TODO(yvesj): Should we be able to parse this? From a Darwin system. +fe80::1%lo0 localhost diff --git a/src/pkg/net/interface.go b/src/pkg/net/interface.go new file mode 100644 index 000000000..8a14cb232 --- /dev/null +++ b/src/pkg/net/interface.go @@ -0,0 +1,132 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Network interface identification + +package net + +import ( + "bytes" + "fmt" + "os" +) + +// A HardwareAddr represents a physical hardware address. +type HardwareAddr []byte + +func (a HardwareAddr) String() string { + var buf bytes.Buffer + for i, b := range a { + if i > 0 { + buf.WriteByte(':') + } + fmt.Fprintf(&buf, "%02x", b) + } + return buf.String() +} + +// Interface represents a mapping between network interface name +// and index. It also represents network interface facility +// information. +type Interface struct { + Index int // positive integer that starts at one, zero is never used + MTU int // maximum transmission unit + Name string // e.g., "en0", "lo0", "eth0.100" + HardwareAddr HardwareAddr // IEEE MAC-48, EUI-48 and EUI-64 form + Flags Flags // e.g., FlagUp, FlagLoopback, FlagMulticast +} + +type Flags uint + +const ( + FlagUp Flags = 1 << iota // interface is up + FlagBroadcast // interface supports broadcast access capability + FlagLoopback // interface is a loopback interface + FlagPointToPoint // interface belongs to a point-to-point link + FlagMulticast // interface supports multicast access capability +) + +var flagNames = []string{ + "up", + "broadcast", + "loopback", + "pointtopoint", + "multicast", +} + +func (f Flags) String() string { + s := "" + for i, name := range flagNames { + if f&(1<<uint(i)) != 0 { + if s != "" { + s += "|" + } + s += name + } + } + if s == "" { + s = "0" + } + return s +} + +// Addrs returns interface addresses for a specific interface. +func (ifi *Interface) Addrs() ([]Addr, os.Error) { + if ifi == nil { + return nil, os.NewError("net: invalid interface") + } + return interfaceAddrTable(ifi.Index) +} + +// MulticastAddrs returns multicast, joined group addresses for +// a specific interface. +func (ifi *Interface) MulticastAddrs() ([]Addr, os.Error) { + if ifi == nil { + return nil, os.NewError("net: invalid interface") + } + return interfaceMulticastAddrTable(ifi.Index) +} + +// Interfaces returns a list of the systems's network interfaces. +func Interfaces() ([]Interface, os.Error) { + return interfaceTable(0) +} + +// InterfaceAddrs returns a list of the system's network interface +// addresses. +func InterfaceAddrs() ([]Addr, os.Error) { + return interfaceAddrTable(0) +} + +// InterfaceByIndex returns the interface specified by index. +func InterfaceByIndex(index int) (*Interface, os.Error) { + if index <= 0 { + return nil, os.NewError("net: invalid interface index") + } + ift, err := interfaceTable(index) + if err != nil { + return nil, err + } + for _, ifi := range ift { + return &ifi, nil + } + return nil, os.NewError("net: no such interface") +} + +// InterfaceByName returns the interface specified by name. +func InterfaceByName(name string) (*Interface, os.Error) { + if name == "" { + return nil, os.NewError("net: invalid interface name") + } + ift, err := interfaceTable(0) + if err != nil { + return nil, err + } + for _, ifi := range ift { + if name == ifi.Name { + return &ifi, nil + } + } + return nil, os.NewError("net: no such interface") +} diff --git a/src/pkg/net/interface_bsd.go b/src/pkg/net/interface_bsd.go new file mode 100644 index 000000000..2675f94b9 --- /dev/null +++ b/src/pkg/net/interface_bsd.go @@ -0,0 +1,171 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Network interface identification for BSD variants + +package net + +import ( + "os" + "syscall" + "unsafe" +) + +// If the ifindex is zero, interfaceTable returns mappings of all +// network interfaces. Otheriwse it returns a mapping of a specific +// interface. +func interfaceTable(ifindex int) ([]Interface, os.Error) { + var ( + tab []byte + e int + msgs []syscall.RoutingMessage + ift []Interface + ) + + tab, e = syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex) + if e != 0 { + return nil, os.NewSyscallError("route rib", e) + } + + msgs, e = syscall.ParseRoutingMessage(tab) + if e != 0 { + return nil, os.NewSyscallError("route message", e) + } + + for _, m := range msgs { + switch v := m.(type) { + case *syscall.InterfaceMessage: + if ifindex == 0 || ifindex == int(v.Header.Index) { + ifi, err := newLink(v) + if err != nil { + return nil, err + } + ift = append(ift, ifi...) + } + } + } + + return ift, nil +} + +func newLink(m *syscall.InterfaceMessage) ([]Interface, os.Error) { + var ift []Interface + + sas, e := syscall.ParseRoutingSockaddr(m) + if e != 0 { + return nil, os.NewSyscallError("route sockaddr", e) + } + + for _, s := range sas { + switch v := s.(type) { + case *syscall.SockaddrDatalink: + // NOTE: SockaddrDatalink.Data is minimum work area, + // can be larger. + m.Data = m.Data[unsafe.Offsetof(v.Data):] + ifi := Interface{Index: int(m.Header.Index), Flags: linkFlags(m.Header.Flags)} + var name [syscall.IFNAMSIZ]byte + for i := 0; i < int(v.Nlen); i++ { + name[i] = byte(m.Data[i]) + } + ifi.Name = string(name[:v.Nlen]) + ifi.MTU = int(m.Header.Data.Mtu) + addr := make([]byte, v.Alen) + for i := 0; i < int(v.Alen); i++ { + addr[i] = byte(m.Data[int(v.Nlen)+i]) + } + ifi.HardwareAddr = addr[:v.Alen] + ift = append(ift, ifi) + } + } + + return ift, nil +} + +func linkFlags(rawFlags int32) Flags { + var f Flags + if rawFlags&syscall.IFF_UP != 0 { + f |= FlagUp + } + if rawFlags&syscall.IFF_BROADCAST != 0 { + f |= FlagBroadcast + } + if rawFlags&syscall.IFF_LOOPBACK != 0 { + f |= FlagLoopback + } + if rawFlags&syscall.IFF_POINTOPOINT != 0 { + f |= FlagPointToPoint + } + if rawFlags&syscall.IFF_MULTICAST != 0 { + f |= FlagMulticast + } + return f +} + +// If the ifindex is zero, interfaceAddrTable returns addresses +// for all network interfaces. Otherwise it returns addresses +// for a specific interface. +func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { + var ( + tab []byte + e int + msgs []syscall.RoutingMessage + ifat []Addr + ) + + tab, e = syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex) + if e != 0 { + return nil, os.NewSyscallError("route rib", e) + } + + msgs, e = syscall.ParseRoutingMessage(tab) + if e != 0 { + return nil, os.NewSyscallError("route message", e) + } + + for _, m := range msgs { + switch v := m.(type) { + case *syscall.InterfaceAddrMessage: + if ifindex == 0 || ifindex == int(v.Header.Index) { + ifa, err := newAddr(v) + if err != nil { + return nil, err + } + ifat = append(ifat, ifa...) + } + } + } + + return ifat, nil +} + +func newAddr(m *syscall.InterfaceAddrMessage) ([]Addr, os.Error) { + var ifat []Addr + + sas, e := syscall.ParseRoutingSockaddr(m) + if e != 0 { + return nil, os.NewSyscallError("route sockaddr", e) + } + + for _, s := range sas { + + switch v := s.(type) { + case *syscall.SockaddrInet4: + ifa := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])} + ifat = append(ifat, ifa.toAddr()) + case *syscall.SockaddrInet6: + ifa := &IPAddr{IP: make(IP, IPv6len)} + copy(ifa.IP, v.Addr[:]) + // NOTE: KAME based IPv6 protcol stack usually embeds + // the interface index in the interface-local or link- + // local address as the kernel-internal form. + if ifa.IP.IsLinkLocalUnicast() { + // remove embedded scope zone ID + ifa.IP[2], ifa.IP[3] = 0, 0 + } + ifat = append(ifat, ifa.toAddr()) + } + } + + return ifat, nil +} diff --git a/src/pkg/net/interface_darwin.go b/src/pkg/net/interface_darwin.go new file mode 100644 index 000000000..a7b68ad7f --- /dev/null +++ b/src/pkg/net/interface_darwin.go @@ -0,0 +1,80 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Network interface identification for Darwin + +package net + +import ( + "os" + "syscall" +) + +// If the ifindex is zero, interfaceMulticastAddrTable returns +// addresses for all network interfaces. Otherwise it returns +// addresses for a specific interface. +func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { + var ( + tab []byte + e int + msgs []syscall.RoutingMessage + ifmat []Addr + ) + + tab, e = syscall.RouteRIB(syscall.NET_RT_IFLIST2, ifindex) + if e != 0 { + return nil, os.NewSyscallError("route rib", e) + } + + msgs, e = syscall.ParseRoutingMessage(tab) + if e != 0 { + return nil, os.NewSyscallError("route message", e) + } + + for _, m := range msgs { + switch v := m.(type) { + case *syscall.InterfaceMulticastAddrMessage: + if ifindex == 0 || ifindex == int(v.Header.Index) { + ifma, err := newMulticastAddr(v) + if err != nil { + return nil, err + } + ifmat = append(ifmat, ifma...) + } + } + } + + return ifmat, nil +} + +func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, os.Error) { + var ifmat []Addr + + sas, e := syscall.ParseRoutingSockaddr(m) + if e != 0 { + return nil, os.NewSyscallError("route sockaddr", e) + } + + for _, s := range sas { + switch v := s.(type) { + case *syscall.SockaddrInet4: + ifma := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])} + ifmat = append(ifmat, ifma.toAddr()) + case *syscall.SockaddrInet6: + ifma := &IPAddr{IP: make(IP, IPv6len)} + copy(ifma.IP, v.Addr[:]) + // NOTE: KAME based IPv6 protcol stack usually embeds + // the interface index in the interface-local or link- + // local address as the kernel-internal form. + if ifma.IP.IsInterfaceLocalMulticast() || + ifma.IP.IsLinkLocalMulticast() { + // remove embedded scope zone ID + ifma.IP[2], ifma.IP[3] = 0, 0 + } + ifmat = append(ifmat, ifma.toAddr()) + } + } + + return ifmat, nil +} diff --git a/src/pkg/net/interface_freebsd.go b/src/pkg/net/interface_freebsd.go new file mode 100644 index 000000000..20f506b08 --- /dev/null +++ b/src/pkg/net/interface_freebsd.go @@ -0,0 +1,80 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Network interface identification for FreeBSD + +package net + +import ( + "os" + "syscall" +) + +// If the ifindex is zero, interfaceMulticastAddrTable returns +// addresses for all network interfaces. Otherwise it returns +// addresses for a specific interface. +func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { + var ( + tab []byte + e int + msgs []syscall.RoutingMessage + ifmat []Addr + ) + + tab, e = syscall.RouteRIB(syscall.NET_RT_IFMALIST, ifindex) + if e != 0 { + return nil, os.NewSyscallError("route rib", e) + } + + msgs, e = syscall.ParseRoutingMessage(tab) + if e != 0 { + return nil, os.NewSyscallError("route message", e) + } + + for _, m := range msgs { + switch v := m.(type) { + case *syscall.InterfaceMulticastAddrMessage: + if ifindex == 0 || ifindex == int(v.Header.Index) { + ifma, err := newMulticastAddr(v) + if err != nil { + return nil, err + } + ifmat = append(ifmat, ifma...) + } + } + } + + return ifmat, nil +} + +func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, os.Error) { + var ifmat []Addr + + sas, e := syscall.ParseRoutingSockaddr(m) + if e != 0 { + return nil, os.NewSyscallError("route sockaddr", e) + } + + for _, s := range sas { + switch v := s.(type) { + case *syscall.SockaddrInet4: + ifma := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])} + ifmat = append(ifmat, ifma.toAddr()) + case *syscall.SockaddrInet6: + ifma := &IPAddr{IP: make(IP, IPv6len)} + copy(ifma.IP, v.Addr[:]) + // NOTE: KAME based IPv6 protcol stack usually embeds + // the interface index in the interface-local or link- + // local address as the kernel-internal form. + if ifma.IP.IsInterfaceLocalMulticast() || + ifma.IP.IsLinkLocalMulticast() { + // remove embedded scope zone ID + ifma.IP[2], ifma.IP[3] = 0, 0 + } + ifmat = append(ifmat, ifma.toAddr()) + } + } + + return ifmat, nil +} diff --git a/src/pkg/net/interface_linux.go b/src/pkg/net/interface_linux.go new file mode 100644 index 000000000..3d2a0bb9f --- /dev/null +++ b/src/pkg/net/interface_linux.go @@ -0,0 +1,262 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Network interface identification for Linux + +package net + +import ( + "fmt" + "os" + "syscall" + "unsafe" +) + +// If the ifindex is zero, interfaceTable returns mappings of all +// network interfaces. Otheriwse it returns a mapping of a specific +// interface. +func interfaceTable(ifindex int) ([]Interface, os.Error) { + var ( + ift []Interface + tab []byte + msgs []syscall.NetlinkMessage + e int + ) + + tab, e = syscall.NetlinkRIB(syscall.RTM_GETLINK, syscall.AF_UNSPEC) + if e != 0 { + return nil, os.NewSyscallError("netlink rib", e) + } + + msgs, e = syscall.ParseNetlinkMessage(tab) + if e != 0 { + return nil, os.NewSyscallError("netlink message", e) + } + + for _, m := range msgs { + switch m.Header.Type { + case syscall.NLMSG_DONE: + goto done + case syscall.RTM_NEWLINK: + ifim := (*syscall.IfInfomsg)(unsafe.Pointer(&m.Data[0])) + if ifindex == 0 || ifindex == int(ifim.Index) { + attrs, e := syscall.ParseNetlinkRouteAttr(&m) + if e != 0 { + return nil, os.NewSyscallError("netlink routeattr", e) + } + ifi := newLink(attrs, ifim) + ift = append(ift, ifi) + } + } + } + +done: + return ift, nil +} + +func newLink(attrs []syscall.NetlinkRouteAttr, ifim *syscall.IfInfomsg) Interface { + ifi := Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)} + for _, a := range attrs { + switch a.Attr.Type { + case syscall.IFLA_ADDRESS: + var nonzero bool + for _, b := range a.Value { + if b != 0 { + nonzero = true + } + } + if nonzero { + ifi.HardwareAddr = a.Value[:] + } + case syscall.IFLA_IFNAME: + ifi.Name = string(a.Value[:len(a.Value)-1]) + case syscall.IFLA_MTU: + ifi.MTU = int(uint32(a.Value[3])<<24 | uint32(a.Value[2])<<16 | uint32(a.Value[1])<<8 | uint32(a.Value[0])) + } + } + return ifi +} + +func linkFlags(rawFlags uint32) Flags { + var f Flags + if rawFlags&syscall.IFF_UP != 0 { + f |= FlagUp + } + if rawFlags&syscall.IFF_BROADCAST != 0 { + f |= FlagBroadcast + } + if rawFlags&syscall.IFF_LOOPBACK != 0 { + f |= FlagLoopback + } + if rawFlags&syscall.IFF_POINTOPOINT != 0 { + f |= FlagPointToPoint + } + if rawFlags&syscall.IFF_MULTICAST != 0 { + f |= FlagMulticast + } + return f +} + +// If the ifindex is zero, interfaceAddrTable returns addresses +// for all network interfaces. Otherwise it returns addresses +// for a specific interface. +func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { + var ( + tab []byte + e int + err os.Error + ifat4 []Addr + ifat6 []Addr + msgs4 []syscall.NetlinkMessage + msgs6 []syscall.NetlinkMessage + ) + + tab, e = syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_INET) + if e != 0 { + return nil, os.NewSyscallError("netlink rib", e) + } + msgs4, e = syscall.ParseNetlinkMessage(tab) + if e != 0 { + return nil, os.NewSyscallError("netlink message", e) + } + ifat4, err = addrTable(msgs4, ifindex) + if err != nil { + return nil, err + } + + tab, e = syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_INET6) + if e != 0 { + return nil, os.NewSyscallError("netlink rib", e) + } + msgs6, e = syscall.ParseNetlinkMessage(tab) + if e != 0 { + return nil, os.NewSyscallError("netlink message", e) + } + ifat6, err = addrTable(msgs6, ifindex) + if err != nil { + return nil, err + } + + return append(ifat4, ifat6...), nil +} + +func addrTable(msgs []syscall.NetlinkMessage, ifindex int) ([]Addr, os.Error) { + var ifat []Addr + + for _, m := range msgs { + switch m.Header.Type { + case syscall.NLMSG_DONE: + goto done + case syscall.RTM_NEWADDR: + ifam := (*syscall.IfAddrmsg)(unsafe.Pointer(&m.Data[0])) + if ifindex == 0 || ifindex == int(ifam.Index) { + attrs, e := syscall.ParseNetlinkRouteAttr(&m) + if e != 0 { + return nil, os.NewSyscallError("netlink routeattr", e) + } + ifat = append(ifat, newAddr(attrs, int(ifam.Family))...) + } + } + } + +done: + return ifat, nil +} + +func newAddr(attrs []syscall.NetlinkRouteAttr, family int) []Addr { + var ifat []Addr + + for _, a := range attrs { + switch a.Attr.Type { + case syscall.IFA_ADDRESS: + switch family { + case syscall.AF_INET: + ifa := &IPAddr{IP: IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3])} + ifat = append(ifat, ifa.toAddr()) + case syscall.AF_INET6: + ifa := &IPAddr{IP: make(IP, IPv6len)} + copy(ifa.IP, a.Value[:]) + ifat = append(ifat, ifa.toAddr()) + } + } + } + + return ifat +} + +// If the ifindex is zero, interfaceMulticastAddrTable returns +// addresses for all network interfaces. Otherwise it returns +// addresses for a specific interface. +func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { + var ( + ifi *Interface + err os.Error + ) + + if ifindex > 0 { + ifi, err = InterfaceByIndex(ifindex) + if err != nil { + return nil, err + } + } + + ifmat4 := parseProcNetIGMP(ifi) + ifmat6 := parseProcNetIGMP6(ifi) + + return append(ifmat4, ifmat6...), nil +} + +func parseProcNetIGMP(ifi *Interface) []Addr { + var ( + ifmat []Addr + name string + ) + + fd, err := open("/proc/net/igmp") + if err != nil { + return nil + } + defer fd.close() + + fd.readLine() // skip first line + b := make([]byte, IPv4len) + for l, ok := fd.readLine(); ok; l, ok = fd.readLine() { + f := getFields(l) + switch len(f) { + case 4: + if ifi == nil || name == ifi.Name { + fmt.Sscanf(f[0], "%08x", &b) + ifma := IPAddr{IP: IPv4(b[3], b[2], b[1], b[0])} + ifmat = append(ifmat, ifma.toAddr()) + } + case 5: + name = f[1] + } + } + + return ifmat +} + +func parseProcNetIGMP6(ifi *Interface) []Addr { + var ifmat []Addr + + fd, err := open("/proc/net/igmp6") + if err != nil { + return nil + } + defer fd.close() + + b := make([]byte, IPv6len) + for l, ok := fd.readLine(); ok; l, ok = fd.readLine() { + f := getFields(l) + if ifi == nil || f[1] == ifi.Name { + fmt.Sscanf(f[2], "%32x", &b) + ifma := IPAddr{IP: IP{b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]}} + ifmat = append(ifmat, ifma.toAddr()) + + } + } + + return ifmat +} diff --git a/src/pkg/net/interface_openbsd.go b/src/pkg/net/interface_openbsd.go new file mode 100644 index 000000000..f18149393 --- /dev/null +++ b/src/pkg/net/interface_openbsd.go @@ -0,0 +1,16 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Network interface identification for OpenBSD + +package net + +import "os" + +// If the ifindex is zero, interfaceMulticastAddrTable returns +// addresses for all network interfaces. Otherwise it returns +// addresses for a specific interface. +func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { + return nil, nil +} diff --git a/src/pkg/net/interface_stub.go b/src/pkg/net/interface_stub.go new file mode 100644 index 000000000..950de6c59 --- /dev/null +++ b/src/pkg/net/interface_stub.go @@ -0,0 +1,30 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Network interface identification + +package net + +import "os" + +// If the ifindex is zero, interfaceTable returns mappings of all +// network interfaces. Otheriwse it returns a mapping of a specific +// interface. +func interfaceTable(ifindex int) ([]Interface, os.Error) { + return nil, nil +} + +// If the ifindex is zero, interfaceAddrTable returns addresses +// for all network interfaces. Otherwise it returns addresses +// for a specific interface. +func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { + return nil, nil +} + +// If the ifindex is zero, interfaceMulticastAddrTable returns +// addresses for all network interfaces. Otherwise it returns +// addresses for a specific interface. +func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { + return nil, nil +} diff --git a/src/pkg/net/interface_test.go b/src/pkg/net/interface_test.go new file mode 100644 index 000000000..0e4089abf --- /dev/null +++ b/src/pkg/net/interface_test.go @@ -0,0 +1,73 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bytes" + "testing" +) + +func sameInterface(i, j *Interface) bool { + if i == nil || j == nil { + return false + } + if i.Index == j.Index && i.Name == j.Name && bytes.Equal(i.HardwareAddr, j.HardwareAddr) { + return true + } + return false +} + +func TestInterfaces(t *testing.T) { + ift, err := Interfaces() + if err != nil { + t.Fatalf("Interfaces() failed: %v", err) + } + t.Logf("table: len/cap = %v/%v\n", len(ift), cap(ift)) + + for _, ifi := range ift { + ifxi, err := InterfaceByIndex(ifi.Index) + if err != nil { + t.Fatalf("InterfaceByIndex(%#q) failed: %v", ifi.Index, err) + } + if !sameInterface(ifxi, &ifi) { + t.Fatalf("InterfaceByIndex(%#q) = %v, want %v", ifi.Index, *ifxi, ifi) + } + ifxn, err := InterfaceByName(ifi.Name) + if err != nil { + t.Fatalf("InterfaceByName(%#q) failed: %v", ifi.Name, err) + } + if !sameInterface(ifxn, &ifi) { + t.Fatalf("InterfaceByName(%#q) = %v, want %v", ifi.Name, *ifxn, ifi) + } + ifat, err := ifi.Addrs() + if err != nil { + t.Fatalf("Interface.Addrs() failed: %v", err) + } + ifmat, err := ifi.MulticastAddrs() + if err != nil { + t.Fatalf("Interface.MulticastAddrs() failed: %v", err) + } + t.Logf("%q: flags %q, ifindex %v, mtu %v\n", ifi.Name, ifi.Flags.String(), ifi.Index, ifi.MTU) + for _, ifa := range ifat { + t.Logf("\tinterface address %q\n", ifa.String()) + } + for _, ifma := range ifmat { + t.Logf("\tjoined group address %q\n", ifma.String()) + } + t.Logf("\thardware address %q", ifi.HardwareAddr.String()) + } +} + +func TestInterfaceAddrs(t *testing.T) { + ifat, err := InterfaceAddrs() + if err != nil { + t.Fatalf("InterfaceAddrs() failed: %v", err) + } + t.Logf("table: len/cap = %v/%v\n", len(ifat), cap(ifat)) + + for _, ifa := range ifat { + t.Logf("interface address %q\n", ifa.String()) + } +} diff --git a/src/pkg/net/interface_windows.go b/src/pkg/net/interface_windows.go new file mode 100644 index 000000000..7f5169c87 --- /dev/null +++ b/src/pkg/net/interface_windows.go @@ -0,0 +1,158 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Network interface identification for Windows + +package net + +import ( + "os" + "syscall" + "unsafe" +) + +func bytePtrToString(p *uint8) string { + a := (*[10000]uint8)(unsafe.Pointer(p)) + i := 0 + for a[i] != 0 { + i++ + } + return string(a[:i]) +} + +func getAdapterList() (*syscall.IpAdapterInfo, os.Error) { + b := make([]byte, 1000) + l := uint32(len(b)) + a := (*syscall.IpAdapterInfo)(unsafe.Pointer(&b[0])) + e := syscall.GetAdaptersInfo(a, &l) + if e == syscall.ERROR_BUFFER_OVERFLOW { + b = make([]byte, l) + a = (*syscall.IpAdapterInfo)(unsafe.Pointer(&b[0])) + e = syscall.GetAdaptersInfo(a, &l) + } + if e != 0 { + return nil, os.NewSyscallError("GetAdaptersInfo", e) + } + return a, nil +} + +func getInterfaceList() ([]syscall.InterfaceInfo, os.Error) { + s, e := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + if e != 0 { + return nil, os.NewSyscallError("Socket", e) + } + defer syscall.Closesocket(s) + + ii := [20]syscall.InterfaceInfo{} + ret := uint32(0) + size := uint32(unsafe.Sizeof(ii)) + e = syscall.WSAIoctl(s, syscall.SIO_GET_INTERFACE_LIST, nil, 0, (*byte)(unsafe.Pointer(&ii[0])), size, &ret, nil, 0) + if e != 0 { + return nil, os.NewSyscallError("WSAIoctl", e) + } + c := ret / uint32(unsafe.Sizeof(ii[0])) + return ii[:c-1], nil +} + +// If the ifindex is zero, interfaceTable returns mappings of all +// network interfaces. Otheriwse it returns a mapping of a specific +// interface. +func interfaceTable(ifindex int) ([]Interface, os.Error) { + ai, e := getAdapterList() + if e != nil { + return nil, e + } + + ii, e := getInterfaceList() + if e != nil { + return nil, e + } + + var ift []Interface + for ; ai != nil; ai = ai.Next { + index := ai.Index + if ifindex == 0 || ifindex == int(index) { + var flags Flags + + row := syscall.MibIfRow{Index: index} + e := syscall.GetIfEntry(&row) + if e != 0 { + return nil, os.NewSyscallError("GetIfEntry", e) + } + + for _, ii := range ii { + ip := (*syscall.RawSockaddrInet4)(unsafe.Pointer(&ii.Address)).Addr + ipv4 := IPv4(ip[0], ip[1], ip[2], ip[3]) + ipl := &ai.IpAddressList + for ipl != nil { + ips := bytePtrToString(&ipl.IpAddress.String[0]) + if ipv4.Equal(parseIPv4(ips)) { + break + } + ipl = ipl.Next + } + if ipl == nil { + continue + } + if ii.Flags&syscall.IFF_UP != 0 { + flags |= FlagUp + } + if ii.Flags&syscall.IFF_LOOPBACK != 0 { + flags |= FlagLoopback + } + if ii.Flags&syscall.IFF_BROADCAST != 0 { + flags |= FlagBroadcast + } + if ii.Flags&syscall.IFF_POINTTOPOINT != 0 { + flags |= FlagPointToPoint + } + if ii.Flags&syscall.IFF_MULTICAST != 0 { + flags |= FlagMulticast + } + } + + name := bytePtrToString(&ai.AdapterName[0]) + + ifi := Interface{ + Index: int(index), + MTU: int(row.Mtu), + Name: name, + HardwareAddr: HardwareAddr(row.PhysAddr[:row.PhysAddrLen]), + Flags: flags} + ift = append(ift, ifi) + } + } + return ift, nil +} + +// If the ifindex is zero, interfaceAddrTable returns addresses +// for all network interfaces. Otherwise it returns addresses +// for a specific interface. +func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { + ai, e := getAdapterList() + if e != nil { + return nil, e + } + + var ifat []Addr + for ; ai != nil; ai = ai.Next { + index := ai.Index + if ifindex == 0 || ifindex == int(index) { + ipl := &ai.IpAddressList + for ; ipl != nil; ipl = ipl.Next { + ifa := IPAddr{} + ifa.IP = parseIPv4(bytePtrToString(&ipl.IpAddress.String[0])) + ifat = append(ifat, ifa.toAddr()) + } + } + } + return ifat, nil +} + +// If the ifindex is zero, interfaceMulticastAddrTable returns +// addresses for all network interfaces. Otherwise it returns +// addresses for a specific interface. +func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { + return nil, nil +} diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go new file mode 100644 index 000000000..b0e2c4205 --- /dev/null +++ b/src/pkg/net/ip.go @@ -0,0 +1,614 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IP address manipulations +// +// IPv4 addresses are 4 bytes; IPv6 addresses are 16 bytes. +// An IPv4 address can be converted to an IPv6 address by +// adding a canonical prefix (10 zeros, 2 0xFFs). +// This library accepts either size of byte array but always +// returns 16-byte addresses. + +package net + +import "os" + +// IP address lengths (bytes). +const ( + IPv4len = 4 + IPv6len = 16 +) + +// An IP is a single IP address, an array of bytes. +// Functions in this package accept either 4-byte (IP v4) +// or 16-byte (IP v6) arrays as input. Unless otherwise +// specified, functions in this package always return +// IP addresses in 16-byte form using the canonical +// embedding. +// +// Note that in this documentation, referring to an +// IP address as an IPv4 address or an IPv6 address +// is a semantic property of the address, not just the +// length of the byte array: a 16-byte array can still +// be an IPv4 address. +type IP []byte + +// An IP mask is an IP address. +type IPMask []byte + +// IPv4 returns the IP address (in 16-byte form) of the +// IPv4 address a.b.c.d. +func IPv4(a, b, c, d byte) IP { + p := make(IP, IPv6len) + copy(p, v4InV6Prefix) + p[12] = a + p[13] = b + p[14] = c + p[15] = d + return p +} + +var v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff} + +// IPv4Mask returns the IP mask (in 16-byte form) of the +// IPv4 mask a.b.c.d. +func IPv4Mask(a, b, c, d byte) IPMask { + p := make(IPMask, IPv6len) + for i := 0; i < 12; i++ { + p[i] = 0xff + } + p[12] = a + p[13] = b + p[14] = c + p[15] = d + return p +} + +// Well-known IPv4 addresses +var ( + IPv4bcast = IPv4(255, 255, 255, 255) // broadcast + IPv4allsys = IPv4(224, 0, 0, 1) // all systems + IPv4allrouter = IPv4(224, 0, 0, 2) // all routers + IPv4zero = IPv4(0, 0, 0, 0) // all zeros +) + +// Well-known IPv6 addresses +var ( + IPv6zero = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + IPv6unspecified = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + IPv6loopback = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + IPv6interfacelocalallnodes = IP{0xff, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} + IPv6linklocalallnodes = IP{0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} + IPv6linklocalallrouters = IP{0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x02} +) + +// IsUnspecified returns true if ip is an unspecified address. +func (ip IP) IsUnspecified() bool { + if ip.Equal(IPv4zero) || ip.Equal(IPv6unspecified) { + return true + } + return false +} + +// IsLoopback returns true if ip is a loopback address. +func (ip IP) IsLoopback() bool { + if ip4 := ip.To4(); ip4 != nil && ip4[0] == 127 { + return true + } + return ip.Equal(IPv6loopback) +} + +// IsMulticast returns true if ip is a multicast address. +func (ip IP) IsMulticast() bool { + if ip4 := ip.To4(); ip4 != nil && ip4[0]&0xf0 == 0xe0 { + return true + } + return ip[0] == 0xff +} + +// IsInterfaceLinkLocalMulticast returns true if ip is +// an interface-local multicast address. +func (ip IP) IsInterfaceLocalMulticast() bool { + return len(ip) == IPv6len && ip[0] == 0xff && ip[1]&0x0f == 0x01 +} + +// IsLinkLocalMulticast returns true if ip is a link-local +// multicast address. +func (ip IP) IsLinkLocalMulticast() bool { + if ip4 := ip.To4(); ip4 != nil && ip4[0] == 224 && ip4[1] == 0 && ip4[2] == 0 { + return true + } + return ip[0] == 0xff && ip[1]&0x0f == 0x02 +} + +// IsLinkLocalUnicast returns true if ip is a link-local +// unicast address. +func (ip IP) IsLinkLocalUnicast() bool { + if ip4 := ip.To4(); ip4 != nil && ip4[0] == 169 && ip4[1] == 254 { + return true + } + return ip[0] == 0xfe && ip[1]&0xc0 == 0x80 +} + +// IsGlobalUnicast returns true if ip is a global unicast +// address. +func (ip IP) IsGlobalUnicast() bool { + return !ip.IsUnspecified() && + !ip.IsLoopback() && + !ip.IsMulticast() && + !ip.IsLinkLocalUnicast() +} + +// Is p all zeros? +func isZeros(p IP) bool { + for i := 0; i < len(p); i++ { + if p[i] != 0 { + return false + } + } + return true +} + +// To4 converts the IPv4 address ip to a 4-byte representation. +// If ip is not an IPv4 address, To4 returns nil. +func (ip IP) To4() IP { + if len(ip) == IPv4len { + return ip + } + if len(ip) == IPv6len && + isZeros(ip[0:10]) && + ip[10] == 0xff && + ip[11] == 0xff { + return ip[12:16] + } + return nil +} + +// To16 converts the IP address ip to a 16-byte representation. +// If ip is not an IP address (it is the wrong length), To16 returns nil. +func (ip IP) To16() IP { + if len(ip) == IPv4len { + return IPv4(ip[0], ip[1], ip[2], ip[3]) + } + if len(ip) == IPv6len { + return ip + } + return nil +} + +// Default route masks for IPv4. +var ( + classAMask = IPv4Mask(0xff, 0, 0, 0) + classBMask = IPv4Mask(0xff, 0xff, 0, 0) + classCMask = IPv4Mask(0xff, 0xff, 0xff, 0) +) + +// DefaultMask returns the default IP mask for the IP address ip. +// Only IPv4 addresses have default masks; DefaultMask returns +// nil if ip is not a valid IPv4 address. +func (ip IP) DefaultMask() IPMask { + if ip = ip.To4(); ip == nil { + return nil + } + switch true { + case ip[0] < 0x80: + return classAMask + case ip[0] < 0xC0: + return classBMask + default: + return classCMask + } + return nil // not reached +} + +func allFF(b []byte) bool { + for _, c := range b { + if c != 0xff { + return false + } + } + return true +} + +// Mask returns the result of masking the IP address ip with mask. +func (ip IP) Mask(mask IPMask) IP { + n := len(ip) + if len(mask) == 16 && len(ip) == 4 && allFF(mask[:12]) { + mask = mask[12:] + } + if len(mask) == 4 && len(ip) == 16 && bytesEqual(ip[:12], v4InV6Prefix) { + ip = ip[12:] + } + if n != len(mask) { + return nil + } + out := make(IP, n) + for i := 0; i < n; i++ { + out[i] = ip[i] & mask[i] + } + return out +} + +// Convert i to decimal string. +func itod(i uint) string { + if i == 0 { + return "0" + } + + // Assemble decimal in reverse order. + var b [32]byte + bp := len(b) + for ; i > 0; i /= 10 { + bp-- + b[bp] = byte(i%10) + '0' + } + + return string(b[bp:]) +} + +// Convert i to hexadecimal string. +func itox(i uint) string { + if i == 0 { + return "0" + } + + // Assemble hexadecimal in reverse order. + var b [32]byte + bp := len(b) + for ; i > 0; i /= 16 { + bp-- + b[bp] = "0123456789abcdef"[byte(i%16)] + } + + return string(b[bp:]) +} + +// String returns the string form of the IP address ip. +// If the address is an IPv4 address, the string representation +// is dotted decimal ("74.125.19.99"). Otherwise the representation +// is IPv6 ("2001:4860:0:2001::68"). +func (ip IP) String() string { + p := ip + + if len(ip) == 0 { + return "" + } + + // If IPv4, use dotted notation. + if p4 := p.To4(); len(p4) == 4 { + return itod(uint(p4[0])) + "." + + itod(uint(p4[1])) + "." + + itod(uint(p4[2])) + "." + + itod(uint(p4[3])) + } + if len(p) != IPv6len { + return "?" + } + + // Find longest run of zeros. + e0 := -1 + e1 := -1 + for i := 0; i < 16; i += 2 { + j := i + for j < 16 && p[j] == 0 && p[j+1] == 0 { + j += 2 + } + if j > i && j-i > e1-e0 { + e0 = i + e1 = j + } + } + // The symbol "::" MUST NOT be used to shorten just one 16 bit 0 field. + if e1-e0 <= 2 { + e0 = -1 + e1 = -1 + } + + // Print with possible :: in place of run of zeros + var s string + for i := 0; i < 16; i += 2 { + if i == e0 { + s += "::" + i = e1 + if i >= 16 { + break + } + } else if i > 0 { + s += ":" + } + s += itox((uint(p[i]) << 8) | uint(p[i+1])) + } + return s +} + +// Equal returns true if ip and x are the same IP address. +// An IPv4 address and that same address in IPv6 form are +// considered to be equal. +func (ip IP) Equal(x IP) bool { + if len(ip) == len(x) { + return bytesEqual(ip, x) + } + if len(ip) == 4 && len(x) == 16 { + return bytesEqual(x[0:12], v4InV6Prefix) && bytesEqual(ip, x[12:]) + } + if len(ip) == 16 && len(x) == 4 { + return bytesEqual(ip[0:12], v4InV6Prefix) && bytesEqual(ip[12:], x) + } + return false +} + +func bytesEqual(x, y []byte) bool { + if len(x) != len(y) { + return false + } + for i, b := range x { + if y[i] != b { + return false + } + } + return true +} + +// If mask is a sequence of 1 bits followed by 0 bits, +// return the number of 1 bits. +func simpleMaskLength(mask IPMask) int { + var n int + for i, v := range mask { + if v == 0xff { + n += 8 + continue + } + // found non-ff byte + // count 1 bits + for v&0x80 != 0 { + n++ + v <<= 1 + } + // rest must be 0 bits + if v != 0 { + return -1 + } + for i++; i < len(mask); i++ { + if mask[i] != 0 { + return -1 + } + } + break + } + return n +} + +// String returns the string representation of mask. +// If the mask is in the canonical form--ones followed by zeros--the +// string representation is just the decimal number of ones. +// If the mask is in a non-canonical form, it is formatted +// as an IP address. +func (mask IPMask) String() string { + switch len(mask) { + case 4: + n := simpleMaskLength(mask) + if n >= 0 { + return itod(uint(n + (IPv6len-IPv4len)*8)) + } + case 16: + n := simpleMaskLength(mask) + if n >= 12*8 { + return itod(uint(n - 12*8)) + } + } + return IP(mask).String() +} + +// Parse IPv4 address (d.d.d.d). +func parseIPv4(s string) IP { + var p [IPv4len]byte + i := 0 + for j := 0; j < IPv4len; j++ { + if i >= len(s) { + // Missing octets. + return nil + } + if j > 0 { + if s[i] != '.' { + return nil + } + i++ + } + var ( + n int + ok bool + ) + n, i, ok = dtoi(s, i) + if !ok || n > 0xFF { + return nil + } + p[j] = byte(n) + } + if i != len(s) { + return nil + } + return IPv4(p[0], p[1], p[2], p[3]) +} + +// Parse IPv6 address. Many forms. +// The basic form is a sequence of eight colon-separated +// 16-bit hex numbers separated by colons, +// as in 0123:4567:89ab:cdef:0123:4567:89ab:cdef. +// Two exceptions: +// * A run of zeros can be replaced with "::". +// * The last 32 bits can be in IPv4 form. +// Thus, ::ffff:1.2.3.4 is the IPv4 address 1.2.3.4. +func parseIPv6(s string) IP { + p := make(IP, 16) + ellipsis := -1 // position of ellipsis in p + i := 0 // index in string s + + // Might have leading ellipsis + if len(s) >= 2 && s[0] == ':' && s[1] == ':' { + ellipsis = 0 + i = 2 + // Might be only ellipsis + if i == len(s) { + return p + } + } + + // Loop, parsing hex numbers followed by colon. + j := 0 + for j < IPv6len { + // Hex number. + n, i1, ok := xtoi(s, i) + if !ok || n > 0xFFFF { + return nil + } + + // If followed by dot, might be in trailing IPv4. + if i1 < len(s) && s[i1] == '.' { + if ellipsis < 0 && j != IPv6len-IPv4len { + // Not the right place. + return nil + } + if j+IPv4len > IPv6len { + // Not enough room. + return nil + } + p4 := parseIPv4(s[i:]) + if p4 == nil { + return nil + } + p[j] = p4[12] + p[j+1] = p4[13] + p[j+2] = p4[14] + p[j+3] = p4[15] + i = len(s) + j += 4 + break + } + + // Save this 16-bit chunk. + p[j] = byte(n >> 8) + p[j+1] = byte(n) + j += 2 + + // Stop at end of string. + i = i1 + if i == len(s) { + break + } + + // Otherwise must be followed by colon and more. + if s[i] != ':' || i+1 == len(s) { + return nil + } + i++ + + // Look for ellipsis. + if s[i] == ':' { + if ellipsis >= 0 { // already have one + return nil + } + ellipsis = j + if i++; i == len(s) { // can be at end + break + } + } + } + + // Must have used entire string. + if i != len(s) { + return nil + } + + // If didn't parse enough, expand ellipsis. + if j < IPv6len { + if ellipsis < 0 { + return nil + } + n := IPv6len - j + for k := j - 1; k >= ellipsis; k-- { + p[k+n] = p[k] + } + for k := ellipsis + n - 1; k >= ellipsis; k-- { + p[k] = 0 + } + } + return p +} + +// A ParseError represents a malformed text string and the type of string that was expected. +type ParseError struct { + Type string + Text string +} + +func (e *ParseError) String() string { + return "invalid " + e.Type + ": " + e.Text +} + +func parseIP(s string) IP { + if p := parseIPv4(s); p != nil { + return p + } + if p := parseIPv6(s); p != nil { + return p + } + return nil +} + +// ParseIP parses s as an IP address, returning the result. +// The string s can be in dotted decimal ("74.125.19.99") +// or IPv6 ("2001:4860:0:2001::68") form. +// If s is not a valid textual representation of an IP address, +// ParseIP returns nil. +func ParseIP(s string) IP { + if p := parseIPv4(s); p != nil { + return p + } + return parseIPv6(s) +} + +// ParseCIDR parses s as a CIDR notation IP address and mask, +// like "192.168.100.1/24", "2001:DB8::/48", as defined in +// RFC 4632 and RFC 4291. +func ParseCIDR(s string) (ip IP, mask IPMask, err os.Error) { + i := byteIndex(s, '/') + if i < 0 { + return nil, nil, &ParseError{"CIDR address", s} + } + ipstr, maskstr := s[:i], s[i+1:] + iplen := 4 + ip = parseIPv4(ipstr) + if ip == nil { + iplen = 16 + ip = parseIPv6(ipstr) + } + nn, i, ok := dtoi(maskstr, 0) + if ip == nil || !ok || i != len(maskstr) || nn < 0 || nn > 8*iplen { + return nil, nil, &ParseError{"CIDR address", s} + } + n := uint(nn) + if iplen == 4 { + v4mask := ^uint32(0xffffffff >> n) + mask = IPv4Mask(byte(v4mask>>24), byte(v4mask>>16), byte(v4mask>>8), byte(v4mask)) + } else { + mask = make(IPMask, 16) + for i := 0; i < 16; i++ { + if n >= 8 { + mask[i] = 0xff + n -= 8 + continue + } + mask[i] = ^byte(0xff >> n) + n = 0 + + } + } + // address must not have any bits not in mask + for i := range ip { + if ip[i]&^mask[i] != 0 { + return nil, nil, &ParseError{"CIDR address", s} + } + } + return ip, mask, nil +} diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go new file mode 100644 index 000000000..b189b10c4 --- /dev/null +++ b/src/pkg/net/ip_test.go @@ -0,0 +1,215 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bytes" + "reflect" + "testing" + "os" + "runtime" +) + +func isEqual(a, b []byte) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return bytes.Equal(a, b) +} + +var parseiptests = []struct { + in string + out IP +}{ + {"127.0.1.2", IPv4(127, 0, 1, 2)}, + {"127.0.0.1", IPv4(127, 0, 0, 1)}, + {"127.0.0.256", nil}, + {"abc", nil}, + {"123:", nil}, + {"::ffff:127.0.0.1", IPv4(127, 0, 0, 1)}, + {"2001:4860:0:2001::68", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}}, + {"::ffff:4a7d:1363", IPv4(74, 125, 19, 99)}, +} + +func TestParseIP(t *testing.T) { + for _, tt := range parseiptests { + if out := ParseIP(tt.in); !isEqual(out, tt.out) { + t.Errorf("ParseIP(%#q) = %v, want %v", tt.in, out, tt.out) + } + } +} + +var ipstringtests = []struct { + in IP + out string +}{ + // cf. RFC 5952 (A Recommendation for IPv6 Address Text Representation) + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0x1, 0x23, 0, 0x12, 0, 0x1}, + "2001:db8::123:12:1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1}, + "2001:db8::1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0x1, 0, 0, 0, 0x1, 0, 0, 0, 0x1}, + "2001:db8:0:1:0:1:0:1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0x1, 0, 0, 0, 0x1, 0, 0, 0, 0x1, 0, 0}, + "2001:db8:1:0:1:0:1:0"}, + {IP{0x20, 0x1, 0, 0, 0, 0, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0x1}, + "2001::1:0:0:1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0}, + "2001:db8:0:0:1::"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0x1}, + "2001:db8::1:0:0:1"}, + {IP{0x20, 0x1, 0xD, 0xB8, 0, 0, 0, 0, 0, 0xA, 0, 0xB, 0, 0xC, 0, 0xD}, + "2001:db8::a:b:c:d"}, +} + +func TestIPString(t *testing.T) { + for _, tt := range ipstringtests { + if out := tt.in.String(); out != tt.out { + t.Errorf("IP.String(%v) = %#q, want %#q", tt.in, out, tt.out) + } + } +} + +var parsecidrtests = []struct { + in string + ip IP + mask IPMask + err os.Error +}{ + {"135.104.0.0/32", IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 255), nil}, + {"0.0.0.0/24", IPv4(0, 0, 0, 0), IPv4Mask(255, 255, 255, 0), nil}, + {"135.104.0.0/24", IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 0), nil}, + {"135.104.0.1/32", IPv4(135, 104, 0, 1), IPv4Mask(255, 255, 255, 255), nil}, + {"135.104.0.1/24", nil, nil, &ParseError{"CIDR address", "135.104.0.1/24"}}, + {"::1/128", ParseIP("::1"), IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")), nil}, + {"abcd:2345::/127", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe")), nil}, + {"abcd:2345::/65", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff:8000::")), nil}, + {"abcd:2345::/64", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff::")), nil}, + {"abcd:2345::/63", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:fffe::")), nil}, + {"abcd:2345::/33", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:8000::")), nil}, + {"abcd:2345::/32", ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff::")), nil}, + {"abcd:2344::/31", ParseIP("abcd:2344::"), IPMask(ParseIP("ffff:fffe::")), nil}, + {"abcd:2300::/24", ParseIP("abcd:2300::"), IPMask(ParseIP("ffff:ff00::")), nil}, + {"abcd:2345::/24", nil, nil, &ParseError{"CIDR address", "abcd:2345::/24"}}, + {"2001:DB8::/48", ParseIP("2001:DB8::"), IPMask(ParseIP("ffff:ffff:ffff::")), nil}, +} + +func TestParseCIDR(t *testing.T) { + for _, tt := range parsecidrtests { + if ip, mask, err := ParseCIDR(tt.in); !isEqual(ip, tt.ip) || !isEqual(mask, tt.mask) || !reflect.DeepEqual(err, tt.err) { + t.Errorf("ParseCIDR(%q) = %v, %v, %v; want %v, %v, %v", tt.in, ip, mask, err, tt.ip, tt.mask, tt.err) + } + } +} + +var splitjointests = []struct { + Host string + Port string + Join string +}{ + {"www.google.com", "80", "www.google.com:80"}, + {"127.0.0.1", "1234", "127.0.0.1:1234"}, + {"::1", "80", "[::1]:80"}, +} + +func TestSplitHostPort(t *testing.T) { + for _, tt := range splitjointests { + if host, port, err := SplitHostPort(tt.Join); host != tt.Host || port != tt.Port || err != nil { + t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.Join, host, port, err, tt.Host, tt.Port) + } + } +} + +func TestJoinHostPort(t *testing.T) { + for _, tt := range splitjointests { + if join := JoinHostPort(tt.Host, tt.Port); join != tt.Join { + t.Errorf("JoinHostPort(%q, %q) = %q; want %q", tt.Host, tt.Port, join, tt.Join) + } + } +} + +var ipaftests = []struct { + in IP + af4 bool + af6 bool +}{ + {IPv4bcast, true, false}, + {IPv4allsys, true, false}, + {IPv4allrouter, true, false}, + {IPv4zero, true, false}, + {IPv4(224, 0, 0, 1), true, false}, + {IPv4(127, 0, 0, 1), true, false}, + {IPv4(240, 0, 0, 1), true, false}, + {IPv6unspecified, false, true}, + {IPv6loopback, false, true}, + {IPv6interfacelocalallnodes, false, true}, + {IPv6linklocalallnodes, false, true}, + {IPv6linklocalallrouters, false, true}, + {ParseIP("ff05::a:b:c:d"), false, true}, + {ParseIP("fe80::1:2:3:4"), false, true}, + {ParseIP("2001:db8::123:12:1"), false, true}, +} + +func TestIPAddrFamily(t *testing.T) { + for _, tt := range ipaftests { + if af := tt.in.To4() != nil; af != tt.af4 { + t.Errorf("verifying IPv4 address family for %#q = %v, want %v", tt.in, af, tt.af4) + } + if af := len(tt.in) == IPv6len && tt.in.To4() == nil; af != tt.af6 { + t.Errorf("verifying IPv6 address family for %#q = %v, want %v", tt.in, af, tt.af6) + } + } +} + +var ipscopetests = []struct { + scope func(IP) bool + in IP + ok bool +}{ + {IP.IsUnspecified, IPv4zero, true}, + {IP.IsUnspecified, IPv4(127, 0, 0, 1), false}, + {IP.IsUnspecified, IPv6unspecified, true}, + {IP.IsUnspecified, IPv6interfacelocalallnodes, false}, + {IP.IsLoopback, IPv4(127, 0, 0, 1), true}, + {IP.IsLoopback, IPv4(127, 255, 255, 254), true}, + {IP.IsLoopback, IPv4(128, 1, 2, 3), false}, + {IP.IsLoopback, IPv6loopback, true}, + {IP.IsLoopback, IPv6linklocalallrouters, false}, + {IP.IsMulticast, IPv4(224, 0, 0, 0), true}, + {IP.IsMulticast, IPv4(239, 0, 0, 0), true}, + {IP.IsMulticast, IPv4(240, 0, 0, 0), false}, + {IP.IsMulticast, IPv6linklocalallnodes, true}, + {IP.IsMulticast, IP{0xff, 0x05, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true}, + {IP.IsMulticast, IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false}, + {IP.IsLinkLocalMulticast, IPv4(224, 0, 0, 0), true}, + {IP.IsLinkLocalMulticast, IPv4(239, 0, 0, 0), false}, + {IP.IsLinkLocalMulticast, IPv6linklocalallrouters, true}, + {IP.IsLinkLocalMulticast, IP{0xff, 0x05, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false}, + {IP.IsLinkLocalUnicast, IPv4(169, 254, 0, 0), true}, + {IP.IsLinkLocalUnicast, IPv4(169, 255, 0, 0), false}, + {IP.IsLinkLocalUnicast, IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true}, + {IP.IsLinkLocalUnicast, IP{0xfe, 0xc0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false}, + {IP.IsGlobalUnicast, IPv4(240, 0, 0, 0), true}, + {IP.IsGlobalUnicast, IPv4(232, 0, 0, 0), false}, + {IP.IsGlobalUnicast, IPv4(169, 254, 0, 0), false}, + {IP.IsGlobalUnicast, IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0x1, 0x23, 0, 0x12, 0, 0x1}, true}, + {IP.IsGlobalUnicast, IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false}, + {IP.IsGlobalUnicast, IP{0xff, 0x05, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false}, +} + +func name(f interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() +} + +func TestIPAddrScope(t *testing.T) { + for _, tt := range ipscopetests { + if ok := tt.scope(tt.in); ok != tt.ok { + t.Errorf("%s(%#q) = %v, want %v", name(tt.scope), tt.in, ok, tt.ok) + } + } +} diff --git a/src/pkg/net/ipraw_test.go b/src/pkg/net/ipraw_test.go new file mode 100644 index 000000000..6894ce656 --- /dev/null +++ b/src/pkg/net/ipraw_test.go @@ -0,0 +1,120 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO(cw): ListenPacket test, Read() test, ipv6 test & +// Dial()/Listen() level tests + +package net + +import ( + "bytes" + "flag" + "os" + "testing" +) + +const ICMP_ECHO_REQUEST = 8 +const ICMP_ECHO_REPLY = 0 + +// returns a suitable 'ping request' packet, with id & seq and a +// payload length of pktlen +func makePingRequest(id, seq, pktlen int, filler []byte) []byte { + p := make([]byte, pktlen) + copy(p[8:], bytes.Repeat(filler, (pktlen-8)/len(filler)+1)) + + p[0] = ICMP_ECHO_REQUEST // type + p[1] = 0 // code + p[2] = 0 // cksum + p[3] = 0 // cksum + p[4] = uint8(id >> 8) // id + p[5] = uint8(id & 0xff) // id + p[6] = uint8(seq >> 8) // sequence + p[7] = uint8(seq & 0xff) // sequence + + // calculate icmp checksum + cklen := len(p) + s := uint32(0) + for i := 0; i < (cklen - 1); i += 2 { + s += uint32(p[i+1])<<8 | uint32(p[i]) + } + if cklen&1 == 1 { + s += uint32(p[cklen-1]) + } + s = (s >> 16) + (s & 0xffff) + s = s + (s >> 16) + + // place checksum back in header; using ^= avoids the + // assumption the checksum bytes are zero + p[2] ^= uint8(^s & 0xff) + p[3] ^= uint8(^s >> 8) + + return p +} + +func parsePingReply(p []byte) (id, seq int) { + id = int(p[4])<<8 | int(p[5]) + seq = int(p[6])<<8 | int(p[7]) + return +} + +var srchost = flag.String("srchost", "", "Source of the ICMP ECHO request") +// 127.0.0.1 because this is an IPv4-specific test. +var dsthost = flag.String("dsthost", "127.0.0.1", "Destination for the ICMP ECHO request") + +// test (raw) IP socket using ICMP +func TestICMP(t *testing.T) { + if os.Getuid() != 0 { + t.Logf("test disabled; must be root") + return + } + + var ( + laddr *IPAddr + err os.Error + ) + if *srchost != "" { + laddr, err = ResolveIPAddr("ip4", *srchost) + if err != nil { + t.Fatalf(`net.ResolveIPAddr("ip4", %v") = %v, %v`, *srchost, laddr, err) + } + } + + raddr, err := ResolveIPAddr("ip4", *dsthost) + if err != nil { + t.Fatalf(`net.ResolveIPAddr("ip4", %v") = %v, %v`, *dsthost, raddr, err) + } + + c, err := ListenIP("ip4:icmp", laddr) + if err != nil { + t.Fatalf(`net.ListenIP("ip4:icmp", %v) = %v, %v`, *srchost, c, err) + } + + sendid := os.Getpid() & 0xffff + const sendseq = 61455 + const pingpktlen = 128 + sendpkt := makePingRequest(sendid, sendseq, pingpktlen, []byte("Go Go Gadget Ping!!!")) + + n, err := c.WriteToIP(sendpkt, raddr) + if err != nil || n != pingpktlen { + t.Fatalf(`net.WriteToIP(..., %v) = %v, %v`, raddr, n, err) + } + + c.SetTimeout(100e6) + resp := make([]byte, 1024) + for { + n, from, err := c.ReadFrom(resp) + if err != nil { + t.Fatalf(`ReadFrom(...) = %v, %v, %v`, n, from, err) + } + if resp[0] != ICMP_ECHO_REPLY { + continue + } + rcvid, rcvseq := parsePingReply(resp) + if rcvid != sendid || rcvseq != sendseq { + t.Fatalf(`Ping reply saw id,seq=0x%x,0x%x (expected 0x%x, 0x%x)`, rcvid, rcvseq, sendid, sendseq) + } + return + } + t.Fatalf("saw no ping return") +} diff --git a/src/pkg/net/iprawsock.go b/src/pkg/net/iprawsock.go new file mode 100644 index 000000000..662b9f57b --- /dev/null +++ b/src/pkg/net/iprawsock.go @@ -0,0 +1,69 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// (Raw) IP sockets + +package net + +import ( + "os" +) + +// IPAddr represents the address of a IP end point. +type IPAddr struct { + IP IP +} + +// Network returns the address's network name, "ip". +func (a *IPAddr) Network() string { return "ip" } + +func (a *IPAddr) String() string { + if a == nil { + return "<nil>" + } + return a.IP.String() +} + +// ResolveIPAddr parses addr as a IP address and resolves domain +// names to numeric addresses on the network net, which must be +// "ip", "ip4" or "ip6". A literal IPv6 host address must be +// enclosed in square brackets, as in "[::]". +func ResolveIPAddr(net, addr string) (*IPAddr, os.Error) { + ip, err := hostToIP(net, addr) + if err != nil { + return nil, err + } + return &IPAddr{ip}, nil +} + +// Convert "host" into IP address. +func hostToIP(net, host string) (ip IP, err os.Error) { + var addr IP + // Try as an IP address. + addr = ParseIP(host) + if addr == nil { + filter := anyaddr + if net != "" && net[len(net)-1] == '4' { + filter = ipv4only + } + if net != "" && net[len(net)-1] == '6' { + filter = ipv6only + } + // Not an IP address. Try as a DNS name. + addrs, err1 := LookupHost(host) + if err1 != nil { + err = err1 + goto Error + } + addr = firstFavoriteAddr(filter, addrs) + if addr == nil { + // should not happen + err = &AddrError{"LookupHost returned no suitable address", addrs[0]} + goto Error + } + } + return addr, nil +Error: + return nil, err +} diff --git a/src/pkg/net/iprawsock_plan9.go b/src/pkg/net/iprawsock_plan9.go new file mode 100644 index 000000000..808e17974 --- /dev/null +++ b/src/pkg/net/iprawsock_plan9.go @@ -0,0 +1,99 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// (Raw) IP sockets stubs for Plan 9 + +package net + +import ( + "os" +) + +// IPConn is the implementation of the Conn and PacketConn +// interfaces for IP network connections. +type IPConn bool + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *IPConn) Read(b []byte) (n int, err os.Error) { + return 0, os.EPLAN9 +} + +// Write implements the net.Conn Write method. +func (c *IPConn) Write(b []byte) (n int, err os.Error) { + return 0, os.EPLAN9 +} + +// Close closes the IP connection. +func (c *IPConn) Close() os.Error { + return os.EPLAN9 +} + +// LocalAddr returns the local network address. +func (c *IPConn) LocalAddr() Addr { + return nil +} + +// RemoteAddr returns the remote network address, a *IPAddr. +func (c *IPConn) RemoteAddr() Addr { + return nil +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *IPConn) SetTimeout(nsec int64) os.Error { + return os.EPLAN9 +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *IPConn) SetReadTimeout(nsec int64) os.Error { + return os.EPLAN9 +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *IPConn) SetWriteTimeout(nsec int64) os.Error { + return os.EPLAN9 +} + +// IP-specific methods. + +// ReadFrom implements the net.PacketConn ReadFrom method. +func (c *IPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { + err = os.EPLAN9 + return +} + +// WriteToIP writes a IP packet to addr via c, copying the payload from b. +// +// WriteToIP can be made to time out and return +// an error with Timeout() == true after a fixed time limit; +// see SetTimeout and SetWriteTimeout. +// On packet-oriented connections, write timeouts are rare. +func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (n int, err os.Error) { + return 0, os.EPLAN9 +} + +// WriteTo implements the net.PacketConn WriteTo method. +func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { + return 0, os.EPLAN9 +} + +func splitNetProto(netProto string) (net string, proto int, err os.Error) { + err = os.EPLAN9 + return +} + +// DialIP connects to the remote address raddr on the network net, +// which must be "ip", "ip4", or "ip6". +func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) { + return nil, os.EPLAN9 +} + +// ListenIP listens for incoming IP packets addressed to the +// local address laddr. The returned connection c's ReadFrom +// and WriteTo methods can be used to receive and send IP +// packets with per-packet addressing. +func ListenIP(netProto string, laddr *IPAddr) (c *IPConn, err os.Error) { + return nil, os.EPLAN9 +} diff --git a/src/pkg/net/iprawsock_posix.go b/src/pkg/net/iprawsock_posix.go new file mode 100644 index 000000000..4e1151800 --- /dev/null +++ b/src/pkg/net/iprawsock_posix.go @@ -0,0 +1,305 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// (Raw) IP sockets + +package net + +import ( + "os" + "sync" + "syscall" +) + +var onceReadProtocols sync.Once + +func sockaddrToIP(sa syscall.Sockaddr) Addr { + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + return &IPAddr{sa.Addr[0:]} + case *syscall.SockaddrInet6: + return &IPAddr{sa.Addr[0:]} + } + return nil +} + +func (a *IPAddr) family() int { + if a == nil || len(a.IP) <= 4 { + return syscall.AF_INET + } + if a.IP.To4() != nil { + return syscall.AF_INET + } + return syscall.AF_INET6 +} + +func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) { + return ipToSockaddr(family, a.IP, 0) +} + +func (a *IPAddr) toAddr() sockaddr { + if a == nil { // nil *IPAddr + return nil // nil interface + } + return a +} + +// IPConn is the implementation of the Conn and PacketConn +// interfaces for IP network connections. +type IPConn struct { + fd *netFD +} + +func newIPConn(fd *netFD) *IPConn { return &IPConn{fd} } + +func (c *IPConn) ok() bool { return c != nil && c.fd != nil } + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *IPConn) Read(b []byte) (n int, err os.Error) { + n, _, err = c.ReadFrom(b) + return +} + +// Write implements the net.Conn Write method. +func (c *IPConn) Write(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Write(b) +} + +// Close closes the IP connection. +func (c *IPConn) Close() os.Error { + if !c.ok() { + return os.EINVAL + } + err := c.fd.Close() + c.fd = nil + return err +} + +// LocalAddr returns the local network address. +func (c *IPConn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.laddr +} + +// RemoteAddr returns the remote network address, a *IPAddr. +func (c *IPConn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.raddr +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *IPConn) SetTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setTimeout(c.fd, nsec) +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *IPConn) SetReadTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadTimeout(c.fd, nsec) +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *IPConn) SetWriteTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteTimeout(c.fd, nsec) +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *IPConn) SetReadBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadBuffer(c.fd, bytes) +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *IPConn) SetWriteBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteBuffer(c.fd, bytes) +} + +// IP-specific methods. + +// ReadFromIP reads a IP packet from c, copying the payload into b. +// It returns the number of bytes copied into b and the return address +// that was on the packet. +// +// ReadFromIP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetTimeout and +// SetReadTimeout. +func (c *IPConn) ReadFromIP(b []byte) (n int, addr *IPAddr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + // TODO(cw,rsc): consider using readv if we know the family + // type to avoid the header trim/copy + n, sa, err := c.fd.ReadFrom(b) + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + addr = &IPAddr{sa.Addr[0:]} + if len(b) >= 4 { // discard ipv4 header + hsize := (int(b[0]) & 0xf) * 4 + copy(b, b[hsize:]) + n -= hsize + } + case *syscall.SockaddrInet6: + addr = &IPAddr{sa.Addr[0:]} + } + return +} + +// ReadFrom implements the net.PacketConn ReadFrom method. +func (c *IPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + n, uaddr, err := c.ReadFromIP(b) + return n, uaddr.toAddr(), err +} + +// WriteToIP writes a IP packet to addr via c, copying the payload from b. +// +// WriteToIP can be made to time out and return +// an error with Timeout() == true after a fixed time limit; +// see SetTimeout and SetWriteTimeout. +// On packet-oriented connections, write timeouts are rare. +func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + sa, err1 := addr.sockaddr(c.fd.family) + if err1 != nil { + return 0, &OpError{Op: "write", Net: "ip", Addr: addr, Error: err1} + } + return c.fd.WriteTo(b, sa) +} + +// WriteTo implements the net.PacketConn WriteTo method. +func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + a, ok := addr.(*IPAddr) + if !ok { + return 0, &OpError{"writeto", "ip", addr, os.EINVAL} + } + return c.WriteToIP(b, a) +} + +var protocols map[string]int + +func readProtocols() { + protocols = make(map[string]int) + if file, err := open("/etc/protocols"); err == nil { + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + // tcp 6 TCP # transmission control protocol + if i := byteIndex(line, '#'); i >= 0 { + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 { + continue + } + if proto, _, ok := dtoi(f[1], 0); ok { + protocols[f[0]] = proto + for _, alias := range f[2:] { + protocols[alias] = proto + } + } + } + file.close() + } +} + +func splitNetProto(netProto string) (net string, proto int, err os.Error) { + onceReadProtocols.Do(readProtocols) + i := last(netProto, ':') + if i < 0 { // no colon + return "", 0, os.NewError("no IP protocol specified") + } + net = netProto[0:i] + protostr := netProto[i+1:] + proto, i, ok := dtoi(protostr, 0) + if !ok || i != len(protostr) { + // lookup by name + proto, ok = protocols[protostr] + if ok { + return + } + } + return +} + +// DialIP connects to the remote address raddr on the network net, +// which must be "ip", "ip4", or "ip6". +func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) { + net, proto, err := splitNetProto(netProto) + if err != nil { + return + } + switch net { + case "ip", "ip4", "ip6": + default: + return nil, UnknownNetworkError(net) + } + if raddr == nil { + return nil, &OpError{"dial", "ip", nil, errMissingAddress} + } + fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_RAW, proto, "dial", sockaddrToIP) + if e != nil { + return nil, e + } + return newIPConn(fd), nil +} + +// ListenIP listens for incoming IP packets addressed to the +// local address laddr. The returned connection c's ReadFrom +// and WriteTo methods can be used to receive and send IP +// packets with per-packet addressing. +func ListenIP(netProto string, laddr *IPAddr) (c *IPConn, err os.Error) { + net, proto, err := splitNetProto(netProto) + if err != nil { + return + } + switch net { + case "ip", "ip4", "ip6": + default: + return nil, UnknownNetworkError(net) + } + fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_RAW, proto, "dial", sockaddrToIP) + if e != nil { + return nil, e + } + return newIPConn(fd), nil +} + +// BindToDevice binds an IPConn to a network interface. +func (c *IPConn) BindToDevice(device string) os.Error { + if !c.ok() { + return os.EINVAL + } + c.fd.incref() + defer c.fd.decref() + return os.NewSyscallError("setsockopt", syscall.BindToDevice(c.fd.sysfd, device)) +} diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go new file mode 100644 index 000000000..4e2a5622b --- /dev/null +++ b/src/pkg/net/ipsock.go @@ -0,0 +1,158 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IP sockets + +package net + +import ( + "os" +) + +var supportsIPv6, supportsIPv4map = probeIPv6Stack() + +func firstFavoriteAddr(filter func(IP) IP, addrs []string) (addr IP) { + if filter == anyaddr { + // We'll take any IP address, but since the dialing code + // does not yet try multiple addresses, prefer to use + // an IPv4 address if possible. This is especially relevant + // if localhost resolves to [ipv6-localhost, ipv4-localhost]. + // Too much code assumes localhost == ipv4-localhost. + addr = firstSupportedAddr(ipv4only, addrs) + if addr == nil { + addr = firstSupportedAddr(anyaddr, addrs) + } + } else { + addr = firstSupportedAddr(filter, addrs) + } + return +} + +func firstSupportedAddr(filter func(IP) IP, addrs []string) IP { + for _, s := range addrs { + if addr := filter(ParseIP(s)); addr != nil { + return addr + } + } + return nil +} + +func anyaddr(x IP) IP { + if x4 := x.To4(); x4 != nil { + return x4 + } + if supportsIPv6 { + return x + } + return nil +} + +func ipv4only(x IP) IP { return x.To4() } + +func ipv6only(x IP) IP { + // Only return addresses that we can use + // with the kernel's IPv6 addressing modes. + if len(x) == IPv6len && x.To4() == nil && supportsIPv6 { + return x + } + return nil +} + +type InvalidAddrError string + +func (e InvalidAddrError) String() string { return string(e) } +func (e InvalidAddrError) Timeout() bool { return false } +func (e InvalidAddrError) Temporary() bool { return false } + +// SplitHostPort splits a network address of the form +// "host:port" or "[host]:port" into host and port. +// The latter form must be used when host contains a colon. +func SplitHostPort(hostport string) (host, port string, err os.Error) { + // The port starts after the last colon. + i := last(hostport, ':') + if i < 0 { + err = &AddrError{"missing port in address", hostport} + return + } + + host, port = hostport[0:i], hostport[i+1:] + + // Can put brackets around host ... + if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { + host = host[1 : len(host)-1] + } else { + // ... but if there are no brackets, no colons. + if byteIndex(host, ':') >= 0 { + err = &AddrError{"too many colons in address", hostport} + return + } + } + return +} + +// JoinHostPort combines host and port into a network address +// of the form "host:port" or, if host contains a colon, "[host]:port". +func JoinHostPort(host, port string) string { + // If host has colons, have to bracket it. + if byteIndex(host, ':') >= 0 { + return "[" + host + "]:" + port + } + return host + ":" + port +} + +// Convert "host:port" into IP address and port. +func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) { + var ( + addr IP + p, i int + ok bool + ) + host, port, err := SplitHostPort(hostport) + if err != nil { + goto Error + } + + if host != "" { + // Try as an IP address. + addr = ParseIP(host) + if addr == nil { + filter := anyaddr + if net != "" && net[len(net)-1] == '4' { + filter = ipv4only + } + if net != "" && net[len(net)-1] == '6' { + filter = ipv6only + } + // Not an IP address. Try as a DNS name. + addrs, err1 := LookupHost(host) + if err1 != nil { + err = err1 + goto Error + } + addr = firstFavoriteAddr(filter, addrs) + if addr == nil { + // should not happen + err = &AddrError{"LookupHost returned no suitable address", addrs[0]} + goto Error + } + } + } + + p, i, ok = dtoi(port, 0) + if !ok || i != len(port) { + p, err = LookupPort(net, port) + if err != nil { + goto Error + } + } + if p < 0 || p > 0xFFFF { + err = &AddrError{"invalid port", port} + goto Error + } + + return addr, p, nil + +Error: + return nil, 0, err +} diff --git a/src/pkg/net/ipsock_plan9.go b/src/pkg/net/ipsock_plan9.go new file mode 100644 index 000000000..9e5da6d38 --- /dev/null +++ b/src/pkg/net/ipsock_plan9.go @@ -0,0 +1,305 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IP sockets stubs for Plan 9 + +package net + +import ( + "os" +) + +// probeIPv6Stack returns two boolean values. If the first boolean value is +// true, kernel supports basic IPv6 functionality. If the second +// boolean value is true, kernel supports IPv6 IPv4-mapping. +func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { + return false, false +} + +// parsePlan9Addr parses address of the form [ip!]port (e.g. 127.0.0.1!80). +func parsePlan9Addr(s string) (ip IP, iport int, err os.Error) { + var ( + addr IP + p, i int + ok bool + ) + addr = IPv4zero // address contains port only + i = byteIndex(s, '!') + if i >= 0 { + addr = ParseIP(s[:i]) + if addr == nil { + err = os.NewError("net: parsing IP failed") + goto Error + } + } + p, _, ok = dtoi(s[i+1:], 0) + if !ok { + err = os.NewError("net: parsing port failed") + goto Error + } + if p < 0 || p > 0xFFFF { + err = &AddrError{"invalid port", string(p)} + goto Error + } + return addr, p, nil + +Error: + return nil, 0, err +} + +func readPlan9Addr(proto, filename string) (addr Addr, err os.Error) { + var buf [128]byte + + f, err := os.Open(filename) + if err != nil { + return + } + n, err := f.Read(buf[:]) + if err != nil { + return + } + ip, port, err := parsePlan9Addr(string(buf[:n])) + if err != nil { + return + } + switch proto { + case "tcp": + addr = &TCPAddr{ip, port} + case "udp": + addr = &UDPAddr{ip, port} + default: + return nil, os.NewError("unknown protocol " + proto) + } + return addr, nil +} + +type plan9Conn struct { + proto, name, dir string + ctl, data *os.File + laddr, raddr Addr +} + +func newPlan9Conn(proto, name string, ctl *os.File, laddr, raddr Addr) *plan9Conn { + return &plan9Conn{proto, name, "/net/" + proto + "/" + name, ctl, nil, laddr, raddr} +} + +func (c *plan9Conn) ok() bool { return c != nil && c.ctl != nil } + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *plan9Conn) Read(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + if c.data == nil { + c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0) + if err != nil { + return 0, err + } + } + n, err = c.data.Read(b) + if c.proto == "udp" && err == os.EOF { + n = 0 + err = nil + } + return +} + +// Write implements the net.Conn Write method. +func (c *plan9Conn) Write(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + if c.data == nil { + c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0) + if err != nil { + return 0, err + } + } + return c.data.Write(b) +} + +// Close closes the connection. +func (c *plan9Conn) Close() os.Error { + if !c.ok() { + return os.EINVAL + } + err := c.ctl.Close() + if err != nil { + return err + } + if c.data != nil { + err = c.data.Close() + } + c.ctl = nil + c.data = nil + return err +} + +// LocalAddr returns the local network address. +func (c *plan9Conn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.laddr +} + +// RemoteAddr returns the remote network address. +func (c *plan9Conn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.raddr +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *plan9Conn) SetTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return os.EPLAN9 +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *plan9Conn) SetReadTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return os.EPLAN9 +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *plan9Conn) SetWriteTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return os.EPLAN9 +} + +func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, err os.Error) { + var ( + ip IP + port int + ) + switch a := addr.(type) { + case *TCPAddr: + proto = "tcp" + ip = a.IP + port = a.Port + case *UDPAddr: + proto = "udp" + ip = a.IP + port = a.Port + default: + err = UnknownNetworkError(net) + return + } + + clone, dest, err := queryCS1(proto, ip, port) + if err != nil { + return + } + f, err := os.OpenFile(clone, os.O_RDWR, 0) + if err != nil { + return + } + var buf [16]byte + n, err := f.Read(buf[:]) + if err != nil { + return + } + return f, dest, proto, string(buf[:n]), nil +} + +func dialPlan9(net string, laddr, raddr Addr) (c *plan9Conn, err os.Error) { + f, dest, proto, name, err := startPlan9(net, raddr) + if err != nil { + return + } + _, err = f.WriteString("connect " + dest) + if err != nil { + return + } + laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local") + if err != nil { + return + } + raddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/remote") + if err != nil { + return + } + return newPlan9Conn(proto, name, f, laddr, raddr), nil +} + +type plan9Listener struct { + proto, name, dir string + ctl *os.File + laddr Addr +} + +func listenPlan9(net string, laddr Addr) (l *plan9Listener, err os.Error) { + f, dest, proto, name, err := startPlan9(net, laddr) + if err != nil { + return + } + _, err = f.WriteString("announce " + dest) + if err != nil { + return + } + laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local") + if err != nil { + return + } + l = new(plan9Listener) + l.proto = proto + l.name = name + l.dir = "/net/" + proto + "/" + name + l.ctl = f + l.laddr = laddr + return l, nil +} + +func (l *plan9Listener) plan9Conn() *plan9Conn { + return newPlan9Conn(l.proto, l.name, l.ctl, l.laddr, nil) +} + +func (l *plan9Listener) acceptPlan9() (c *plan9Conn, err os.Error) { + f, err := os.Open(l.dir + "/listen") + if err != nil { + return + } + var buf [16]byte + n, err := f.Read(buf[:]) + if err != nil { + return + } + name := string(buf[:n]) + laddr, err := readPlan9Addr(l.proto, l.dir+"/local") + if err != nil { + return + } + raddr, err := readPlan9Addr(l.proto, l.dir+"/remote") + if err != nil { + return + } + return newPlan9Conn(l.proto, name, f, laddr, raddr), nil +} + +func (l *plan9Listener) Accept() (c Conn, err os.Error) { + c1, err := l.acceptPlan9() + if err != nil { + return + } + return c1, nil +} + +func (l *plan9Listener) Close() os.Error { + if l == nil || l.ctl == nil { + return os.EINVAL + } + return l.ctl.Close() +} + +func (l *plan9Listener) Addr() Addr { return l.laddr } diff --git a/src/pkg/net/ipsock_posix.go b/src/pkg/net/ipsock_posix.go new file mode 100644 index 000000000..0c522fb7f --- /dev/null +++ b/src/pkg/net/ipsock_posix.go @@ -0,0 +1,174 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "syscall" +) + +// Should we try to use the IPv4 socket interface if we're +// only dealing with IPv4 sockets? As long as the host system +// understands IPv6, it's okay to pass IPv4 addresses to the IPv6 +// interface. That simplifies our code and is most general. +// Unfortunately, we need to run on kernels built without IPv6 +// support too. So probe the kernel to figure it out. +// +// probeIPv6Stack probes both basic IPv6 capability and IPv6 IPv4- +// mapping capability which is controlled by IPV6_V6ONLY socket +// option and/or kernel state "net.inet6.ip6.v6only". +// It returns two boolean values. If the first boolean value is +// true, kernel supports basic IPv6 functionality. If the second +// boolean value is true, kernel supports IPv6 IPv4-mapping. +func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { + var probes = []struct { + la TCPAddr + ok bool + }{ + // IPv6 communication capability + {TCPAddr{IP: ParseIP("::1")}, false}, + // IPv6 IPv4-mapped address communication capability + {TCPAddr{IP: IPv4(127, 0, 0, 1)}, false}, + } + + for i := range probes { + s, errno := syscall.Socket(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + if errno != 0 { + continue + } + defer closesocket(s) + sa, err := probes[i].la.toAddr().sockaddr(syscall.AF_INET6) + if err != nil { + continue + } + errno = syscall.Bind(s, sa) + if errno != 0 { + continue + } + probes[i].ok = true + } + + return probes[0].ok, probes[1].ok +} + +// favoriteAddrFamily returns the appropriate address family to +// the given net, raddr, laddr and mode. At first it figures +// address family out from the net. If mode indicates "listen" +// and laddr.(type).IP is nil, it assumes that the user wants to +// make a passive connection with wildcard address family, both +// INET and INET6, and wildcard address. Otherwise guess: if the +// addresses are IPv4 then returns INET, or else returns INET6. +func favoriteAddrFamily(net string, raddr, laddr sockaddr, mode string) int { + switch net[len(net)-1] { + case '4': + return syscall.AF_INET + case '6': + return syscall.AF_INET6 + } + + if mode == "listen" { + switch a := laddr.(type) { + case *TCPAddr: + if a.IP == nil && supportsIPv6 { + return syscall.AF_INET6 + } + case *UDPAddr: + if a.IP == nil && supportsIPv6 { + return syscall.AF_INET6 + } + case *IPAddr: + if a.IP == nil && supportsIPv6 { + return syscall.AF_INET6 + } + } + } + + if (laddr == nil || laddr.family() == syscall.AF_INET) && + (raddr == nil || raddr.family() == syscall.AF_INET) { + return syscall.AF_INET + } + return syscall.AF_INET6 +} + +// TODO(rsc): if syscall.OS == "linux", we're supposed to read +// /proc/sys/net/core/somaxconn, +// to take advantage of kernels that have raised the limit. +func listenBacklog() int { return syscall.SOMAXCONN } + +// Internet sockets (TCP, UDP) + +// A sockaddr represents a TCP or UDP network address that can +// be converted into a syscall.Sockaddr. +type sockaddr interface { + Addr + sockaddr(family int) (syscall.Sockaddr, os.Error) + family() int +} + +func internetSocket(net string, laddr, raddr sockaddr, socktype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err os.Error) { + var oserr os.Error + var la, ra syscall.Sockaddr + family := favoriteAddrFamily(net, raddr, laddr, mode) + if laddr != nil { + if la, oserr = laddr.sockaddr(family); oserr != nil { + goto Error + } + } + if raddr != nil { + if ra, oserr = raddr.sockaddr(family); oserr != nil { + goto Error + } + } + fd, oserr = socket(net, family, socktype, proto, la, ra, toAddr) + if oserr != nil { + goto Error + } + return fd, nil + +Error: + addr := raddr + if mode == "listen" { + addr = laddr + } + return nil, &OpError{mode, net, addr, oserr} +} + +func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, os.Error) { + switch family { + case syscall.AF_INET: + if len(ip) == 0 { + ip = IPv4zero + } + if ip = ip.To4(); ip == nil { + return nil, InvalidAddrError("non-IPv4 address") + } + s := new(syscall.SockaddrInet4) + for i := 0; i < IPv4len; i++ { + s.Addr[i] = ip[i] + } + s.Port = port + return s, nil + case syscall.AF_INET6: + if len(ip) == 0 { + ip = IPv6zero + } + // IPv4 callers use 0.0.0.0 to mean "announce on any available address". + // In IPv6 mode, Linux treats that as meaning "announce on 0.0.0.0", + // which it refuses to do. Rewrite to the IPv6 all zeros. + if ip.Equal(IPv4zero) { + ip = IPv6zero + } + if ip = ip.To16(); ip == nil { + return nil, InvalidAddrError("non-IPv6 address") + } + s := new(syscall.SockaddrInet6) + for i := 0; i < IPv6len; i++ { + s.Addr[i] = ip[i] + } + s.Port = port + return s, nil + } + return nil, InvalidAddrError("unexpected socket family") +} diff --git a/src/pkg/net/lookup_plan9.go b/src/pkg/net/lookup_plan9.go new file mode 100644 index 000000000..37d6b8e31 --- /dev/null +++ b/src/pkg/net/lookup_plan9.go @@ -0,0 +1,226 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" +) + +func query(filename, query string, bufSize int) (res []string, err os.Error) { + file, err := os.OpenFile(filename, os.O_RDWR, 0) + if err != nil { + return + } + defer file.Close() + + _, err = file.WriteString(query) + if err != nil { + return + } + _, err = file.Seek(0, 0) + if err != nil { + return + } + buf := make([]byte, bufSize) + for { + n, _ := file.Read(buf) + if n <= 0 { + break + } + res = append(res, string(buf[:n])) + } + return +} + +func queryCS(net, host, service string) (res []string, err os.Error) { + switch net { + case "tcp4", "tcp6": + net = "tcp" + case "udp4", "udp6": + net = "udp" + } + if host == "" { + host = "*" + } + return query("/net/cs", net+"!"+host+"!"+service, 128) +} + +func queryCS1(net string, ip IP, port int) (clone, dest string, err os.Error) { + ips := "*" + if !ip.IsUnspecified() { + ips = ip.String() + } + lines, err := queryCS(net, ips, itoa(port)) + if err != nil { + return + } + f := getFields(lines[0]) + if len(f) < 2 { + return "", "", os.NewError("net: bad response from ndb/cs") + } + clone, dest = f[0], f[1] + return +} + +func queryDNS(addr string, typ string) (res []string, err os.Error) { + return query("/net/dns", addr+" "+typ, 1024) +} + +// LookupHost looks up the given host using the local resolver. +// It returns an array of that host's addresses. +func LookupHost(host string) (addrs []string, err os.Error) { + // Use /net/cs insead of /net/dns because cs knows about + // host names in local network (e.g. from /lib/ndb/local) + lines, err := queryCS("tcp", host, "1") + if err != nil { + return + } + for _, line := range lines { + f := getFields(line) + if len(f) < 2 { + continue + } + addr := f[1] + if i := byteIndex(addr, '!'); i >= 0 { + addr = addr[:i] // remove port + } + if ParseIP(addr) == nil { + continue + } + addrs = append(addrs, addr) + } + return +} + +// LookupIP looks up host using the local resolver. +// It returns an array of that host's IPv4 and IPv6 addresses. +func LookupIP(host string) (ips []IP, err os.Error) { + addrs, err := LookupHost(host) + if err != nil { + return + } + for _, addr := range addrs { + if ip := ParseIP(addr); ip != nil { + ips = append(ips, ip) + } + } + return +} + +// LookupPort looks up the port for the given network and service. +func LookupPort(network, service string) (port int, err os.Error) { + switch network { + case "tcp4", "tcp6": + network = "tcp" + case "udp4", "udp6": + network = "udp" + } + lines, err := queryCS(network, "127.0.0.1", service) + if err != nil { + return + } + unknownPortError := &AddrError{"unknown port", network + "/" + service} + if len(lines) == 0 { + return 0, unknownPortError + } + f := getFields(lines[0]) + if len(f) < 2 { + return 0, unknownPortError + } + s := f[1] + if i := byteIndex(s, '!'); i >= 0 { + s = s[i+1:] // remove address + } + if n, _, ok := dtoi(s, 0); ok { + return n, nil + } + return 0, unknownPortError +} + +// LookupCNAME returns the canonical DNS host for the given name. +// Callers that do not care about the canonical name can call +// LookupHost or LookupIP directly; both take care of resolving +// the canonical name as part of the lookup. +func LookupCNAME(name string) (cname string, err os.Error) { + lines, err := queryDNS(name, "cname") + if err != nil { + return + } + if len(lines) > 0 { + if f := getFields(lines[0]); len(f) >= 3 { + return f[2] + ".", nil + } + } + return "", os.NewError("net: bad response from ndb/dns") +} + +// LookupSRV tries to resolve an SRV query of the given service, +// protocol, and domain name, as specified in RFC 2782. In most cases +// the proto argument can be the same as the corresponding +// Addr.Network(). The returned records are sorted by priority +// and randomized by weight within a priority. +func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) { + target := "_" + service + "._" + proto + "." + name + lines, err := queryDNS(target, "srv") + if err != nil { + return + } + for _, line := range lines { + f := getFields(line) + if len(f) < 6 { + continue + } + port, _, portOk := dtoi(f[2], 0) + priority, _, priorityOk := dtoi(f[3], 0) + weight, _, weightOk := dtoi(f[4], 0) + if !(portOk && priorityOk && weightOk) { + continue + } + addrs = append(addrs, &SRV{f[5], uint16(port), uint16(priority), uint16(weight)}) + cname = f[0] + } + byPriorityWeight(addrs).sort() + return +} + +// LookupMX returns the DNS MX records for the given domain name sorted by preference. +func LookupMX(name string) (mx []*MX, err os.Error) { + lines, err := queryDNS(name, "mx") + if err != nil { + return + } + for _, line := range lines { + f := getFields(line) + if len(f) < 4 { + continue + } + if pref, _, ok := dtoi(f[2], 0); ok { + mx = append(mx, &MX{f[3], uint16(pref)}) + } + } + byPref(mx).sort() + return +} + +// LookupAddr performs a reverse lookup for the given address, returning a list +// of names mapping to that address. +func LookupAddr(addr string) (name []string, err os.Error) { + arpa, err := reverseaddr(addr) + if err != nil { + return + } + lines, err := queryDNS(arpa, "ptr") + if err != nil { + return + } + for _, line := range lines { + f := getFields(line) + if len(f) < 3 { + continue + } + name = append(name, f[2]) + } + return +} diff --git a/src/pkg/net/lookup_test.go b/src/pkg/net/lookup_test.go new file mode 100644 index 000000000..995ab03d0 --- /dev/null +++ b/src/pkg/net/lookup_test.go @@ -0,0 +1,57 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO It would be nice to use a mock DNS server, to eliminate +// external dependencies. + +package net + +import ( + "runtime" + "testing" +) + +var avoidMacFirewall = runtime.GOOS == "darwin" + +func TestGoogleSRV(t *testing.T) { + if testing.Short() || avoidMacFirewall { + t.Logf("skipping test to avoid external network") + return + } + _, addrs, err := LookupSRV("xmpp-server", "tcp", "google.com") + if err != nil { + t.Errorf("failed: %s", err) + } + if len(addrs) == 0 { + t.Errorf("no results") + } +} + +func TestGmailMX(t *testing.T) { + if testing.Short() || avoidMacFirewall { + t.Logf("skipping test to avoid external network") + return + } + mx, err := LookupMX("gmail.com") + if err != nil { + t.Errorf("failed: %s", err) + } + if len(mx) == 0 { + t.Errorf("no results") + } +} + +func TestGoogleDNSAddr(t *testing.T) { + if testing.Short() || avoidMacFirewall { + t.Logf("skipping test to avoid external network") + return + } + names, err := LookupAddr("8.8.8.8") + if err != nil { + t.Errorf("failed: %s", err) + } + if len(names) == 0 { + t.Errorf("no results") + } +} diff --git a/src/pkg/net/lookup_unix.go b/src/pkg/net/lookup_unix.go new file mode 100644 index 000000000..8f5e66212 --- /dev/null +++ b/src/pkg/net/lookup_unix.go @@ -0,0 +1,111 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" +) + +// LookupHost looks up the given host using the local resolver. +// It returns an array of that host's addresses. +func LookupHost(host string) (addrs []string, err os.Error) { + addrs, err, ok := cgoLookupHost(host) + if !ok { + addrs, err = goLookupHost(host) + } + return +} + +// LookupIP looks up host using the local resolver. +// It returns an array of that host's IPv4 and IPv6 addresses. +func LookupIP(host string) (addrs []IP, err os.Error) { + addrs, err, ok := cgoLookupIP(host) + if !ok { + addrs, err = goLookupIP(host) + } + return +} + +// LookupPort looks up the port for the given network and service. +func LookupPort(network, service string) (port int, err os.Error) { + port, err, ok := cgoLookupPort(network, service) + if !ok { + port, err = goLookupPort(network, service) + } + return +} + +// LookupCNAME returns the canonical DNS host for the given name. +// Callers that do not care about the canonical name can call +// LookupHost or LookupIP directly; both take care of resolving +// the canonical name as part of the lookup. +func LookupCNAME(name string) (cname string, err os.Error) { + cname, err, ok := cgoLookupCNAME(name) + if !ok { + cname, err = goLookupCNAME(name) + } + return +} + +// LookupSRV tries to resolve an SRV query of the given service, +// protocol, and domain name, as specified in RFC 2782. In most cases +// the proto argument can be the same as the corresponding +// Addr.Network(). The returned records are sorted by priority +// and randomized by weight within a priority. +func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) { + target := "_" + service + "._" + proto + "." + name + var records []dnsRR + cname, records, err = lookup(target, dnsTypeSRV) + if err != nil { + return + } + addrs = make([]*SRV, len(records)) + for i, rr := range records { + r := rr.(*dnsRR_SRV) + addrs[i] = &SRV{r.Target, r.Port, r.Priority, r.Weight} + } + byPriorityWeight(addrs).sort() + return +} + +// LookupMX returns the DNS MX records for the given domain name sorted by preference. +func LookupMX(name string) (mx []*MX, err os.Error) { + _, rr, err := lookup(name, dnsTypeMX) + if err != nil { + return + } + mx = make([]*MX, len(rr)) + for i := range rr { + r := rr[i].(*dnsRR_MX) + mx[i] = &MX{r.Mx, r.Pref} + } + byPref(mx).sort() + return +} + +// LookupAddr performs a reverse lookup for the given address, returning a list +// of names mapping to that address. +func LookupAddr(addr string) (name []string, err os.Error) { + name = lookupStaticAddr(addr) + if len(name) > 0 { + return + } + var arpa string + arpa, err = reverseaddr(addr) + if err != nil { + return + } + var records []dnsRR + _, records, err = lookup(arpa, dnsTypePTR) + if err != nil { + return + } + name = make([]string, len(records)) + for i := range records { + r := records[i].(*dnsRR_PTR) + name[i] = r.Ptr + } + return +} diff --git a/src/pkg/net/lookup_windows.go b/src/pkg/net/lookup_windows.go new file mode 100644 index 000000000..fa3ad7c7f --- /dev/null +++ b/src/pkg/net/lookup_windows.go @@ -0,0 +1,130 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "syscall" + "unsafe" + "os" + "sync" +) + +var hostentLock sync.Mutex +var serventLock sync.Mutex + +func LookupHost(name string) (addrs []string, err os.Error) { + ips, err := LookupIP(name) + if err != nil { + return + } + addrs = make([]string, 0, len(ips)) + for _, ip := range ips { + addrs = append(addrs, ip.String()) + } + return +} + +func LookupIP(name string) (addrs []IP, err os.Error) { + hostentLock.Lock() + defer hostentLock.Unlock() + h, e := syscall.GetHostByName(name) + if e != 0 { + return nil, os.NewSyscallError("GetHostByName", e) + } + switch h.AddrType { + case syscall.AF_INET: + i := 0 + addrs = make([]IP, 100) // plenty of room to grow + for p := (*[100](*[4]byte))(unsafe.Pointer(h.AddrList)); i < cap(addrs) && p[i] != nil; i++ { + addrs[i] = IPv4(p[i][0], p[i][1], p[i][2], p[i][3]) + } + addrs = addrs[0:i] + default: // TODO(vcc): Implement non IPv4 address lookups. + return nil, os.NewSyscallError("LookupHost", syscall.EWINDOWS) + } + return addrs, nil +} + +func LookupPort(network, service string) (port int, err os.Error) { + switch network { + case "tcp4", "tcp6": + network = "tcp" + case "udp4", "udp6": + network = "udp" + } + serventLock.Lock() + defer serventLock.Unlock() + s, e := syscall.GetServByName(service, network) + if e != 0 { + return 0, os.NewSyscallError("GetServByName", e) + } + return int(syscall.Ntohs(s.Port)), nil +} + +func LookupCNAME(name string) (cname string, err os.Error) { + var r *syscall.DNSRecord + e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) + if int(e) != 0 { + return "", os.NewSyscallError("LookupCNAME", int(e)) + } + defer syscall.DnsRecordListFree(r, 1) + if r != nil && r.Type == syscall.DNS_TYPE_CNAME { + v := (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])) + cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "." + } + return +} + +func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) { + var r *syscall.DNSRecord + target := "_" + service + "._" + proto + "." + name + e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil) + if int(e) != 0 { + return "", nil, os.NewSyscallError("LookupSRV", int(e)) + } + defer syscall.DnsRecordListFree(r, 1) + addrs = make([]*SRV, 0, 10) + for p := r; p != nil && p.Type == syscall.DNS_TYPE_SRV; p = p.Next { + v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0])) + addrs = append(addrs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight}) + } + byPriorityWeight(addrs).sort() + return name, addrs, nil +} + +func LookupMX(name string) (mx []*MX, err os.Error) { + var r *syscall.DNSRecord + e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil) + if int(e) != 0 { + return nil, os.NewSyscallError("LookupMX", int(e)) + } + defer syscall.DnsRecordListFree(r, 1) + mx = make([]*MX, 0, 10) + for p := r; p != nil && p.Type == syscall.DNS_TYPE_MX; p = p.Next { + v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0])) + mx = append(mx, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference}) + } + byPref(mx).sort() + return mx, nil +} + +func LookupAddr(addr string) (name []string, err os.Error) { + arpa, err := reverseaddr(addr) + if err != nil { + return nil, err + } + var r *syscall.DNSRecord + e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil) + if int(e) != 0 { + return nil, os.NewSyscallError("LookupAddr", int(e)) + } + defer syscall.DnsRecordListFree(r, 1) + name = make([]string, 0, 10) + for p := r; p != nil && p.Type == syscall.DNS_TYPE_PTR; p = p.Next { + v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) + name = append(name, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:])) + } + return name, nil +} diff --git a/src/pkg/net/multicast_test.go b/src/pkg/net/multicast_test.go new file mode 100644 index 000000000..be6dbf2dc --- /dev/null +++ b/src/pkg/net/multicast_test.go @@ -0,0 +1,73 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "flag" + "runtime" + "testing" +) + +var multicast = flag.Bool("multicast", false, "enable multicast tests") + +func TestMulticastJoinAndLeave(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + if !*multicast { + t.Logf("test disabled; use --multicast to enable") + return + } + + addr := &UDPAddr{ + IP: IPv4zero, + Port: 0, + } + // open a UDPConn + conn, err := ListenUDP("udp4", addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // try to join group + mcast := IPv4(224, 0, 0, 254) + err = conn.JoinGroup(mcast) + if err != nil { + t.Fatal(err) + } + + // try to leave group + err = conn.LeaveGroup(mcast) + if err != nil { + t.Fatal(err) + } +} + +func TestJoinFailureWithIPv6Address(t *testing.T) { + if !*multicast { + t.Logf("test disabled; use --multicast to enable") + return + } + addr := &UDPAddr{ + IP: IPv4zero, + Port: 0, + } + + // open a UDPConn + conn, err := ListenUDP("udp4", addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // try to join group + mcast := ParseIP("ff02::1") + err = conn.JoinGroup(mcast) + if err == nil { + t.Fatal("JoinGroup succeeded, should fail") + } + t.Logf("%s", err) +} diff --git a/src/pkg/net/net.go b/src/pkg/net/net.go new file mode 100644 index 000000000..5c84d3434 --- /dev/null +++ b/src/pkg/net/net.go @@ -0,0 +1,188 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package net provides a portable interface to Unix networks sockets, +// including TCP/IP, UDP, domain name resolution, and Unix domain sockets. +package net + +// TODO(rsc): +// support for raw ethernet sockets + +import "os" + +// Addr represents a network end point address. +type Addr interface { + Network() string // name of the network + String() string // string form of address +} + +// Conn is a generic stream-oriented network connection. +type Conn interface { + // Read reads data from the connection. + // Read can be made to time out and return a net.Error with Timeout() == true + // after a fixed time limit; see SetTimeout and SetReadTimeout. + Read(b []byte) (n int, err os.Error) + + // Write writes data to the connection. + // Write can be made to time out and return a net.Error with Timeout() == true + // after a fixed time limit; see SetTimeout and SetWriteTimeout. + Write(b []byte) (n int, err os.Error) + + // Close closes the connection. + Close() os.Error + + // LocalAddr returns the local network address. + LocalAddr() Addr + + // RemoteAddr returns the remote network address. + RemoteAddr() Addr + + // SetTimeout sets the read and write deadlines associated + // with the connection. + SetTimeout(nsec int64) os.Error + + // SetReadTimeout sets the time (in nanoseconds) that + // Read will wait for data before returning an error with Timeout() == true. + // Setting nsec == 0 (the default) disables the deadline. + SetReadTimeout(nsec int64) os.Error + + // SetWriteTimeout sets the time (in nanoseconds) that + // Write will wait to send its data before returning an error with Timeout() == true. + // Setting nsec == 0 (the default) disables the deadline. + // Even if write times out, it may return n > 0, indicating that + // some of the data was successfully written. + SetWriteTimeout(nsec int64) os.Error +} + +// An Error represents a network error. +type Error interface { + os.Error + Timeout() bool // Is the error a timeout? + Temporary() bool // Is the error temporary? +} + +// PacketConn is a generic packet-oriented network connection. +type PacketConn interface { + // ReadFrom reads a packet from the connection, + // copying the payload into b. It returns the number of + // bytes copied into b and the return address that + // was on the packet. + // ReadFrom can be made to time out and return + // an error with Timeout() == true after a fixed time limit; + // see SetTimeout and SetReadTimeout. + ReadFrom(b []byte) (n int, addr Addr, err os.Error) + + // WriteTo writes a packet with payload b to addr. + // WriteTo can be made to time out and return + // an error with Timeout() == true after a fixed time limit; + // see SetTimeout and SetWriteTimeout. + // On packet-oriented connections, write timeouts are rare. + WriteTo(b []byte, addr Addr) (n int, err os.Error) + + // Close closes the connection. + Close() os.Error + + // LocalAddr returns the local network address. + LocalAddr() Addr + + // SetTimeout sets the read and write deadlines associated + // with the connection. + SetTimeout(nsec int64) os.Error + + // SetReadTimeout sets the time (in nanoseconds) that + // Read will wait for data before returning an error with Timeout() == true. + // Setting nsec == 0 (the default) disables the deadline. + SetReadTimeout(nsec int64) os.Error + + // SetWriteTimeout sets the time (in nanoseconds) that + // Write will wait to send its data before returning an error with Timeout() == true. + // Setting nsec == 0 (the default) disables the deadline. + // Even if write times out, it may return n > 0, indicating that + // some of the data was successfully written. + SetWriteTimeout(nsec int64) os.Error +} + +// A Listener is a generic network listener for stream-oriented protocols. +type Listener interface { + // Accept waits for and returns the next connection to the listener. + Accept() (c Conn, err os.Error) + + // Close closes the listener. + Close() os.Error + + // Addr returns the listener's network address. + Addr() Addr +} + +var errMissingAddress = os.NewError("missing address") + +type OpError struct { + Op string + Net string + Addr Addr + Error os.Error +} + +func (e *OpError) String() string { + if e == nil { + return "<nil>" + } + s := e.Op + if e.Net != "" { + s += " " + e.Net + } + if e.Addr != nil { + s += " " + e.Addr.String() + } + s += ": " + e.Error.String() + return s +} + +type temporary interface { + Temporary() bool +} + +func (e *OpError) Temporary() bool { + t, ok := e.Error.(temporary) + return ok && t.Temporary() +} + +type timeout interface { + Timeout() bool +} + +func (e *OpError) Timeout() bool { + t, ok := e.Error.(timeout) + return ok && t.Timeout() +} + +type AddrError struct { + Error string + Addr string +} + +func (e *AddrError) String() string { + if e == nil { + return "<nil>" + } + s := e.Error + if e.Addr != "" { + s += " " + e.Addr + } + return s +} + +func (e *AddrError) Temporary() bool { + return false +} + +func (e *AddrError) Timeout() bool { + return false +} + +type UnknownNetworkError string + +func (e UnknownNetworkError) String() string { return "unknown network " + string(e) } +func (e UnknownNetworkError) Temporary() bool { return false } +func (e UnknownNetworkError) Timeout() bool { return false } diff --git a/src/pkg/net/net_test.go b/src/pkg/net/net_test.go new file mode 100644 index 000000000..698a84527 --- /dev/null +++ b/src/pkg/net/net_test.go @@ -0,0 +1,121 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "flag" + "regexp" + "testing" +) + +var runErrorTest = flag.Bool("run_error_test", false, "let TestDialError check for dns errors") + +type DialErrorTest struct { + Net string + Raddr string + Pattern string +} + +var dialErrorTests = []DialErrorTest{ + { + "datakit", "mh/astro/r70", + "dial datakit mh/astro/r70: unknown network datakit", + }, + { + "tcp", "127.0.0.1:☺", + "dial tcp 127.0.0.1:☺: unknown port tcp/☺", + }, + { + "tcp", "no-such-name.google.com.:80", + "dial tcp no-such-name.google.com.:80: lookup no-such-name.google.com.( on .*)?: no (.*)", + }, + { + "tcp", "no-such-name.no-such-top-level-domain.:80", + "dial tcp no-such-name.no-such-top-level-domain.:80: lookup no-such-name.no-such-top-level-domain.( on .*)?: no (.*)", + }, + { + "tcp", "no-such-name:80", + `dial tcp no-such-name:80: lookup no-such-name\.(.*\.)?( on .*)?: no (.*)`, + }, + { + "tcp", "mh/astro/r70:http", + "dial tcp mh/astro/r70:http: lookup mh/astro/r70: invalid domain name", + }, + { + "unix", "/etc/file-not-found", + "dial unix /etc/file-not-found: no such file or directory", + }, + { + "unix", "/etc/", + "dial unix /etc/: (permission denied|socket operation on non-socket|connection refused)", + }, + { + "unixpacket", "/etc/file-not-found", + "dial unixpacket /etc/file-not-found: no such file or directory", + }, + { + "unixpacket", "/etc/", + "dial unixpacket /etc/: (permission denied|socket operation on non-socket|connection refused)", + }, +} + +func TestDialError(t *testing.T) { + if !*runErrorTest { + t.Logf("test disabled; use --run_error_test to enable") + return + } + for i, tt := range dialErrorTests { + c, e := Dial(tt.Net, tt.Raddr) + if c != nil { + c.Close() + } + if e == nil { + t.Errorf("#%d: nil error, want match for %#q", i, tt.Pattern) + continue + } + s := e.String() + match, _ := regexp.MatchString(tt.Pattern, s) + if !match { + t.Errorf("#%d: %q, want match for %#q", i, s, tt.Pattern) + } + } +} + +var revAddrTests = []struct { + Addr string + Reverse string + ErrPrefix string +}{ + {"1.2.3.4", "4.3.2.1.in-addr.arpa.", ""}, + {"245.110.36.114", "114.36.110.245.in-addr.arpa.", ""}, + {"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa.", ""}, + {"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", ""}, + {"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa.", ""}, + {"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa.", ""}, + {"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa.", ""}, + {"1.2.3", "", "unrecognized address"}, + {"1.2.3.4.5", "", "unrecognized address"}, + {"1234:567:bcbca::89a:bcde", "", "unrecognized address"}, + {"1234:567::bcbc:adad::89a:bcde", "", "unrecognized address"}, +} + +func TestReverseAddress(t *testing.T) { + for i, tt := range revAddrTests { + a, e := reverseaddr(tt.Addr) + if len(tt.ErrPrefix) > 0 && e == nil { + t.Errorf("#%d: expected %q, got <nil> (error)", i, tt.ErrPrefix) + continue + } + if len(tt.ErrPrefix) == 0 && e != nil { + t.Errorf("#%d: expected <nil>, got %q (error)", i, e) + } + if e != nil && e.(*DNSError).Error != tt.ErrPrefix { + t.Errorf("#%d: expected %q, got %q (mismatched error)", i, tt.ErrPrefix, e.(*DNSError).Error) + } + if a != tt.Reverse { + t.Errorf("#%d: expected %q, got %q (reverse address)", i, tt.Reverse, a) + } + } +} diff --git a/src/pkg/net/newpollserver.go b/src/pkg/net/newpollserver.go new file mode 100644 index 000000000..427208701 --- /dev/null +++ b/src/pkg/net/newpollserver.go @@ -0,0 +1,43 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "syscall" +) + +func newPollServer() (s *pollServer, err os.Error) { + s = new(pollServer) + s.cr = make(chan *netFD, 1) + s.cw = make(chan *netFD, 1) + if s.pr, s.pw, err = os.Pipe(); err != nil { + return nil, err + } + var e int + if e = syscall.SetNonblock(s.pr.Fd(), true); e != 0 { + goto Errno + } + if e = syscall.SetNonblock(s.pw.Fd(), true); e != 0 { + goto Errno + } + if s.poll, err = newpollster(); err != nil { + goto Error + } + if _, err = s.poll.AddFD(s.pr.Fd(), 'r', true); err != nil { + s.poll.Close() + goto Error + } + s.pending = make(map[int]*netFD) + go s.Run() + return s, nil + +Errno: + err = &os.PathError{"setnonblock", s.pr.Name(), os.Errno(e)} +Error: + s.pr.Close() + s.pw.Close() + return nil, err +} diff --git a/src/pkg/net/parse.go b/src/pkg/net/parse.go new file mode 100644 index 000000000..de46830d2 --- /dev/null +++ b/src/pkg/net/parse.go @@ -0,0 +1,204 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Simple file i/o and string manipulation, to avoid +// depending on strconv and bufio and strings. + +package net + +import ( + "io" + "os" +) + +type file struct { + file *os.File + data []byte + atEOF bool +} + +func (f *file) close() { f.file.Close() } + +func (f *file) getLineFromData() (s string, ok bool) { + data := f.data + i := 0 + for i = 0; i < len(data); i++ { + if data[i] == '\n' { + s = string(data[0:i]) + ok = true + // move data + i++ + n := len(data) - i + copy(data[0:], data[i:]) + f.data = data[0:n] + return + } + } + if f.atEOF && len(f.data) > 0 { + // EOF, return all we have + s = string(data) + f.data = f.data[0:0] + ok = true + } + return +} + +func (f *file) readLine() (s string, ok bool) { + if s, ok = f.getLineFromData(); ok { + return + } + if len(f.data) < cap(f.data) { + ln := len(f.data) + n, err := io.ReadFull(f.file, f.data[ln:cap(f.data)]) + if n >= 0 { + f.data = f.data[0 : ln+n] + } + if err == os.EOF { + f.atEOF = true + } + } + s, ok = f.getLineFromData() + return +} + +func open(name string) (*file, os.Error) { + fd, err := os.Open(name) + if err != nil { + return nil, err + } + return &file{fd, make([]byte, 1024)[0:0], false}, nil +} + +func byteIndex(s string, c byte) int { + for i := 0; i < len(s); i++ { + if s[i] == c { + return i + } + } + return -1 +} + +// Count occurrences in s of any bytes in t. +func countAnyByte(s string, t string) int { + n := 0 + for i := 0; i < len(s); i++ { + if byteIndex(t, s[i]) >= 0 { + n++ + } + } + return n +} + +// Split s at any bytes in t. +func splitAtBytes(s string, t string) []string { + a := make([]string, 1+countAnyByte(s, t)) + n := 0 + last := 0 + for i := 0; i < len(s); i++ { + if byteIndex(t, s[i]) >= 0 { + if last < i { + a[n] = string(s[last:i]) + n++ + } + last = i + 1 + } + } + if last < len(s) { + a[n] = string(s[last:]) + n++ + } + return a[0:n] +} + +func getFields(s string) []string { return splitAtBytes(s, " \r\t\n") } + +// Bigger than we need, not too big to worry about overflow +const big = 0xFFFFFF + +// Decimal to integer starting at &s[i0]. +// Returns number, new offset, success. +func dtoi(s string, i0 int) (n int, i int, ok bool) { + n = 0 + for i = i0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { + n = n*10 + int(s[i]-'0') + if n >= big { + return 0, i, false + } + } + if i == i0 { + return 0, i, false + } + return n, i, true +} + +// Hexadecimal to integer starting at &s[i0]. +// Returns number, new offset, success. +func xtoi(s string, i0 int) (n int, i int, ok bool) { + n = 0 + for i = i0; i < len(s); i++ { + if '0' <= s[i] && s[i] <= '9' { + n *= 16 + n += int(s[i] - '0') + } else if 'a' <= s[i] && s[i] <= 'f' { + n *= 16 + n += int(s[i]-'a') + 10 + } else if 'A' <= s[i] && s[i] <= 'F' { + n *= 16 + n += int(s[i]-'A') + 10 + } else { + break + } + if n >= big { + return 0, i, false + } + } + if i == i0 { + return 0, i, false + } + return n, i, true +} + +// Integer to decimal. +func itoa(i int) string { + var buf [30]byte + n := len(buf) + neg := false + if i < 0 { + i = -i + neg = true + } + ui := uint(i) + for ui > 0 || n == len(buf) { + n-- + buf[n] = byte('0' + ui%10) + ui /= 10 + } + if neg { + n-- + buf[n] = '-' + } + return string(buf[n:]) +} + +// Number of occurrences of b in s. +func count(s string, b byte) int { + n := 0 + for i := 0; i < len(s); i++ { + if s[i] == b { + n++ + } + } + return n +} + +// Index of rightmost occurrence of b in s. +func last(s string, b byte) int { + i := len(s) + for i--; i >= 0; i-- { + if s[i] == b { + break + } + } + return i +} diff --git a/src/pkg/net/parse_test.go b/src/pkg/net/parse_test.go new file mode 100644 index 000000000..8d51eba18 --- /dev/null +++ b/src/pkg/net/parse_test.go @@ -0,0 +1,50 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bufio" + "os" + "testing" + "runtime" +) + +func TestReadLine(t *testing.T) { + // /etc/services file does not exist on windows and Plan 9. + if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { + return + } + filename := "/etc/services" // a nice big file + + fd, err := os.Open(filename) + if err != nil { + t.Fatalf("open %s: %v", filename, err) + } + br := bufio.NewReader(fd) + + file, err := open(filename) + if file == nil { + t.Fatalf("net.open(%s) = nil", filename) + } + + lineno := 1 + byteno := 0 + for { + bline, berr := br.ReadString('\n') + if n := len(bline); n > 0 { + bline = bline[0 : n-1] + } + line, ok := file.readLine() + if (berr != nil) != !ok || bline != line { + t.Fatalf("%s:%d (#%d)\nbufio => %q, %v\nnet => %q, %v", + filename, lineno, byteno, bline, berr, line, ok) + } + if !ok { + break + } + lineno++ + byteno += len(line) + 1 + } +} diff --git a/src/pkg/net/pipe.go b/src/pkg/net/pipe.go new file mode 100644 index 000000000..c0bbd356b --- /dev/null +++ b/src/pkg/net/pipe.go @@ -0,0 +1,62 @@ +package net + +import ( + "io" + "os" +) + +// Pipe creates a synchronous, in-memory, full duplex +// network connection; both ends implement the Conn interface. +// Reads on one end are matched with writes on the other, +// copying data directly between the two; there is no internal +// buffering. +func Pipe() (Conn, Conn) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + return &pipe{r1, w2}, &pipe{r2, w1} +} + +type pipe struct { + *io.PipeReader + *io.PipeWriter +} + +type pipeAddr int + +func (pipeAddr) Network() string { + return "pipe" +} + +func (pipeAddr) String() string { + return "pipe" +} + +func (p *pipe) Close() os.Error { + err := p.PipeReader.Close() + err1 := p.PipeWriter.Close() + if err == nil { + err = err1 + } + return err +} + +func (p *pipe) LocalAddr() Addr { + return pipeAddr(0) +} + +func (p *pipe) RemoteAddr() Addr { + return pipeAddr(0) +} + +func (p *pipe) SetTimeout(nsec int64) os.Error { + return os.NewError("net.Pipe does not support timeouts") +} + +func (p *pipe) SetReadTimeout(nsec int64) os.Error { + return os.NewError("net.Pipe does not support timeouts") +} + +func (p *pipe) SetWriteTimeout(nsec int64) os.Error { + return os.NewError("net.Pipe does not support timeouts") +} diff --git a/src/pkg/net/pipe_test.go b/src/pkg/net/pipe_test.go new file mode 100644 index 000000000..7e4c6db44 --- /dev/null +++ b/src/pkg/net/pipe_test.go @@ -0,0 +1,57 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bytes" + "io" + "os" + "testing" +) + +func checkWrite(t *testing.T, w io.Writer, data []byte, c chan int) { + n, err := w.Write(data) + if err != nil { + t.Errorf("write: %v", err) + } + if n != len(data) { + t.Errorf("short write: %d != %d", n, len(data)) + } + c <- 0 +} + +func checkRead(t *testing.T, r io.Reader, data []byte, wantErr os.Error) { + buf := make([]byte, len(data)+10) + n, err := r.Read(buf) + if err != wantErr { + t.Errorf("read: %v", err) + return + } + if n != len(data) || !bytes.Equal(buf[0:n], data) { + t.Errorf("bad read: got %q", buf[0:n]) + return + } +} + +// Test a simple read/write/close sequence. +// Assumes that the underlying io.Pipe implementation +// is solid and we're just testing the net wrapping. + +func TestPipe(t *testing.T) { + c := make(chan int) + cli, srv := Pipe() + go checkWrite(t, cli, []byte("hello, world"), c) + checkRead(t, srv, []byte("hello, world"), nil) + <-c + go checkWrite(t, srv, []byte("line 2"), c) + checkRead(t, cli, []byte("line 2"), nil) + <-c + go checkWrite(t, cli, []byte("a third line"), c) + checkRead(t, srv, []byte("a third line"), nil) + <-c + go srv.Close() + checkRead(t, cli, nil, os.EOF) + cli.Close() +} diff --git a/src/pkg/net/port.go b/src/pkg/net/port.go new file mode 100644 index 000000000..8f8327a37 --- /dev/null +++ b/src/pkg/net/port.go @@ -0,0 +1,70 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Read system port mappings from /etc/services + +package net + +import ( + "os" + "sync" +) + +var services map[string]map[string]int +var servicesError os.Error +var onceReadServices sync.Once + +func readServices() { + services = make(map[string]map[string]int) + var file *file + if file, servicesError = open("/etc/services"); servicesError != nil { + return + } + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + // "http 80/tcp www www-http # World Wide Web HTTP" + if i := byteIndex(line, '#'); i >= 0 { + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 { + continue + } + portnet := f[1] // "tcp/80" + port, j, ok := dtoi(portnet, 0) + if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' { + continue + } + netw := portnet[j+1:] // "tcp" + m, ok1 := services[netw] + if !ok1 { + m = make(map[string]int) + services[netw] = m + } + for i := 0; i < len(f); i++ { + if i != 1 { // f[1] was port/net + m[f[i]] = port + } + } + } + file.close() +} + +// goLookupPort is the native Go implementation of LookupPort. +func goLookupPort(network, service string) (port int, err os.Error) { + onceReadServices.Do(readServices) + + switch network { + case "tcp4", "tcp6": + network = "tcp" + case "udp4", "udp6": + network = "udp" + } + + if m, ok := services[network]; ok { + if port, ok = m[service]; ok { + return + } + } + return 0, &AddrError{"unknown port", network + "/" + service} +} diff --git a/src/pkg/net/port_test.go b/src/pkg/net/port_test.go new file mode 100644 index 000000000..329b169f3 --- /dev/null +++ b/src/pkg/net/port_test.go @@ -0,0 +1,53 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "testing" +) + +type portTest struct { + netw string + name string + port int + ok bool +} + +var porttests = []portTest{ + {"tcp", "echo", 7, true}, + {"tcp", "discard", 9, true}, + {"tcp", "systat", 11, true}, + {"tcp", "daytime", 13, true}, + {"tcp", "chargen", 19, true}, + {"tcp", "ftp-data", 20, true}, + {"tcp", "ftp", 21, true}, + {"tcp", "telnet", 23, true}, + {"tcp", "smtp", 25, true}, + {"tcp", "time", 37, true}, + {"tcp", "domain", 53, true}, + {"tcp", "finger", 79, true}, + + {"udp", "echo", 7, true}, + {"udp", "tftp", 69, true}, + {"udp", "bootpc", 68, true}, + {"udp", "bootps", 67, true}, + {"udp", "domain", 53, true}, + {"udp", "ntp", 123, true}, + {"udp", "snmp", 161, true}, + {"udp", "syslog", 514, true}, + + {"--badnet--", "zzz", 0, false}, + {"tcp", "--badport--", 0, false}, +} + +func TestLookupPort(t *testing.T) { + for i := 0; i < len(porttests); i++ { + tt := porttests[i] + if port, err := LookupPort(tt.netw, tt.name); port != tt.port || (err == nil) != tt.ok { + t.Errorf("LookupPort(%q, %q) = %v, %s; want %v", + tt.netw, tt.name, port, err, tt.port) + } + } +} diff --git a/src/pkg/net/sendfile_linux.go b/src/pkg/net/sendfile_linux.go new file mode 100644 index 000000000..6a5a06c8c --- /dev/null +++ b/src/pkg/net/sendfile_linux.go @@ -0,0 +1,84 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "io" + "os" + "syscall" +) + +// maxSendfileSize is the largest chunk size we ask the kernel to copy +// at a time. +const maxSendfileSize int = 4 << 20 + +// sendFile copies the contents of r to c using the sendfile +// system call to minimize copies. +// +// if handled == true, sendFile returns the number of bytes copied and any +// non-EOF error. +// +// if handled == false, sendFile performed no work. +func sendFile(c *netFD, r io.Reader) (written int64, err os.Error, handled bool) { + var remain int64 = 1 << 62 // by default, copy until EOF + + lr, ok := r.(*io.LimitedReader) + if ok { + remain, r = lr.N, lr.R + if remain <= 0 { + return 0, nil, true + } + } + f, ok := r.(*os.File) + if !ok { + return 0, nil, false + } + + c.wio.Lock() + defer c.wio.Unlock() + c.incref() + defer c.decref() + if c.wdeadline_delta > 0 { + // This is a little odd that we're setting the timeout + // for the entire file but Write has the same issue + // (if one slurps the whole file into memory and + // do one large Write). At least they're consistent. + c.wdeadline = pollserver.Now() + c.wdeadline_delta + } else { + c.wdeadline = 0 + } + + dst := c.sysfd + src := f.Fd() + for remain > 0 { + n := maxSendfileSize + if int64(n) > remain { + n = int(remain) + } + n, errno := syscall.Sendfile(dst, src, nil, n) + if n > 0 { + written += int64(n) + remain -= int64(n) + } + if n == 0 && errno == 0 { + break + } + if errno == syscall.EAGAIN && c.wdeadline >= 0 { + pollserver.WaitWrite(c) + continue + } + if errno != 0 { + // This includes syscall.ENOSYS (no kernel + // support) and syscall.EINVAL (fd types which + // don't implement sendfile together) + err = &OpError{"sendfile", c.net, c.raddr, os.Errno(errno)} + break + } + } + if lr != nil { + lr.N = remain + } + return written, err, written > 0 +} diff --git a/src/pkg/net/sendfile_stub.go b/src/pkg/net/sendfile_stub.go new file mode 100644 index 000000000..43e8104e9 --- /dev/null +++ b/src/pkg/net/sendfile_stub.go @@ -0,0 +1,14 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "io" + "os" +) + +func sendFile(c *netFD, r io.Reader) (n int64, err os.Error, handled bool) { + return 0, nil, false +} diff --git a/src/pkg/net/sendfile_windows.go b/src/pkg/net/sendfile_windows.go new file mode 100644 index 000000000..3772eee24 --- /dev/null +++ b/src/pkg/net/sendfile_windows.go @@ -0,0 +1,68 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "io" + "os" + "syscall" +) + +type sendfileOp struct { + anOp + src syscall.Handle // source + n uint32 +} + +func (o *sendfileOp) Submit() (errno int) { + return syscall.TransmitFile(o.fd.sysfd, o.src, o.n, 0, &o.o, nil, syscall.TF_WRITE_BEHIND) +} + +func (o *sendfileOp) Name() string { + return "TransmitFile" +} + +// sendFile copies the contents of r to c using the TransmitFile +// system call to minimize copies. +// +// if handled == true, sendFile returns the number of bytes copied and any +// non-EOF error. +// +// if handled == false, sendFile performed no work. +// +// Note that sendfile for windows does not suppport >2GB file. +func sendFile(c *netFD, r io.Reader) (written int64, err os.Error, handled bool) { + var n int64 = 0 // by default, copy until EOF + + lr, ok := r.(*io.LimitedReader) + if ok { + n, r = lr.N, lr.R + if n <= 0 { + return 0, nil, true + } + } + f, ok := r.(*os.File) + if !ok { + return 0, nil, false + } + + c.wio.Lock() + defer c.wio.Unlock() + c.incref() + defer c.decref() + + var o sendfileOp + o.Init(c) + o.n = uint32(n) + o.src = f.Fd() + done, err := iosrv.ExecIO(&o, 0) + if err != nil { + return 0, err, false + } + if lr != nil { + lr.N -= int64(done) + } + return int64(done), nil, true +} diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go new file mode 100644 index 000000000..7d7f7fc01 --- /dev/null +++ b/src/pkg/net/server_test.go @@ -0,0 +1,243 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "flag" + "io" + "os" + "strings" + "syscall" + "testing" + "runtime" +) + +// Do not test empty datagrams by default. +// It causes unexplained timeouts on some systems, +// including Snow Leopard. I think that the kernel +// doesn't quite expect them. +var testUDP = flag.Bool("udp", false, "whether to test UDP datagrams") + +func runEcho(fd io.ReadWriter, done chan<- int) { + var buf [1024]byte + + for { + n, err := fd.Read(buf[0:]) + if err != nil || n == 0 || string(buf[:n]) == "END" { + break + } + fd.Write(buf[0:n]) + } + done <- 1 +} + +func runServe(t *testing.T, network, addr string, listening chan<- string, done chan<- int) { + l, err := Listen(network, addr) + if err != nil { + t.Fatalf("net.Listen(%q, %q) = _, %v", network, addr, err) + } + listening <- l.Addr().String() + + for { + fd, err := l.Accept() + if err != nil { + break + } + echodone := make(chan int) + go runEcho(fd, echodone) + <-echodone // make sure Echo stops + l.Close() + } + done <- 1 +} + +func connect(t *testing.T, network, addr string, isEmpty bool) { + var fd Conn + var err os.Error + if network == "unixgram" { + fd, err = DialUnix(network, &UnixAddr{addr + ".local", network}, &UnixAddr{addr, network}) + } else { + fd, err = Dial(network, addr) + } + if err != nil { + t.Fatalf("net.Dial(%q, %q) = _, %v", network, addr, err) + } + fd.SetReadTimeout(1e9) // 1s + + var b []byte + if !isEmpty { + b = []byte("hello, world\n") + } + var b1 [100]byte + + n, err1 := fd.Write(b) + if n != len(b) { + t.Fatalf("fd.Write(%q) = %d, %v", b, n, err1) + } + + n, err1 = fd.Read(b1[0:]) + if n != len(b) || err1 != nil { + t.Fatalf("fd.Read() = %d, %v (want %d, nil)", n, err1, len(b)) + } + + // Send explicit ending for unixpacket. + // Older Linux kernels do stop reads on close. + if network == "unixpacket" { + fd.Write([]byte("END")) + } + + fd.Close() +} + +func doTest(t *testing.T, network, listenaddr, dialaddr string) { + t.Logf("Test %q %q %q\n", network, listenaddr, dialaddr) + switch listenaddr { + case "", "0.0.0.0", "[::]", "[::ffff:0.0.0.0]": + if testing.Short() || avoidMacFirewall { + t.Logf("skip wildcard listen during short test") + return + } + } + listening := make(chan string) + done := make(chan int) + if network == "tcp" || network == "tcp4" || network == "tcp6" { + listenaddr += ":0" // any available port + } + go runServe(t, network, listenaddr, listening, done) + addr := <-listening // wait for server to start + if network == "tcp" || network == "tcp4" || network == "tcp6" { + dialaddr += addr[strings.LastIndex(addr, ":"):] + } + connect(t, network, dialaddr, false) + <-done // make sure server stopped +} + +func TestTCPServer(t *testing.T) { + doTest(t, "tcp", "", "127.0.0.1") + doTest(t, "tcp", "0.0.0.0", "127.0.0.1") + doTest(t, "tcp", "127.0.0.1", "127.0.0.1") + doTest(t, "tcp4", "", "127.0.0.1") + doTest(t, "tcp4", "0.0.0.0", "127.0.0.1") + doTest(t, "tcp4", "127.0.0.1", "127.0.0.1") + if supportsIPv6 { + doTest(t, "tcp", "", "[::1]") + doTest(t, "tcp", "[::]", "[::1]") + doTest(t, "tcp", "[::1]", "[::1]") + doTest(t, "tcp6", "", "[::1]") + doTest(t, "tcp6", "[::]", "[::1]") + doTest(t, "tcp6", "[::1]", "[::1]") + } + if supportsIPv6 && supportsIPv4map { + doTest(t, "tcp", "[::ffff:0.0.0.0]", "127.0.0.1") + doTest(t, "tcp", "[::]", "127.0.0.1") + doTest(t, "tcp4", "[::ffff:0.0.0.0]", "127.0.0.1") + doTest(t, "tcp6", "", "127.0.0.1") + doTest(t, "tcp6", "[::ffff:0.0.0.0]", "127.0.0.1") + doTest(t, "tcp6", "[::]", "127.0.0.1") + doTest(t, "tcp", "127.0.0.1", "[::ffff:127.0.0.1]") + doTest(t, "tcp", "[::ffff:127.0.0.1]", "127.0.0.1") + doTest(t, "tcp4", "127.0.0.1", "[::ffff:127.0.0.1]") + doTest(t, "tcp4", "[::ffff:127.0.0.1]", "127.0.0.1") + doTest(t, "tcp6", "127.0.0.1", "[::ffff:127.0.0.1]") + doTest(t, "tcp6", "[::ffff:127.0.0.1]", "127.0.0.1") + } +} + +func TestUnixServer(t *testing.T) { + // "unix" sockets are not supported on windows and Plan 9. + if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { + return + } + os.Remove("/tmp/gotest.net") + doTest(t, "unix", "/tmp/gotest.net", "/tmp/gotest.net") + os.Remove("/tmp/gotest.net") + if syscall.OS == "linux" { + doTest(t, "unixpacket", "/tmp/gotest.net", "/tmp/gotest.net") + os.Remove("/tmp/gotest.net") + // Test abstract unix domain socket, a Linux-ism + doTest(t, "unix", "@gotest/net", "@gotest/net") + doTest(t, "unixpacket", "@gotest/net", "@gotest/net") + } +} + +func runPacket(t *testing.T, network, addr string, listening chan<- string, done chan<- int) { + c, err := ListenPacket(network, addr) + if err != nil { + t.Fatalf("net.ListenPacket(%q, %q) = _, %v", network, addr, err) + } + listening <- c.LocalAddr().String() + c.SetReadTimeout(10e6) // 10ms + var buf [1000]byte +Run: + for { + n, addr, err := c.ReadFrom(buf[0:]) + if e, ok := err.(Error); ok && e.Timeout() { + select { + case done <- 1: + break Run + default: + continue Run + } + } + if err != nil { + break + } + if _, err = c.WriteTo(buf[0:n], addr); err != nil { + t.Fatalf("WriteTo %v: %v", addr, err) + } + } + c.Close() + done <- 1 +} + +func doTestPacket(t *testing.T, network, listenaddr, dialaddr string, isEmpty bool) { + t.Logf("TestPacket %s %s %s\n", network, listenaddr, dialaddr) + listening := make(chan string) + done := make(chan int) + if network == "udp" { + listenaddr += ":0" // any available port + } + go runPacket(t, network, listenaddr, listening, done) + addr := <-listening // wait for server to start + if network == "udp" { + dialaddr += addr[strings.LastIndex(addr, ":"):] + } + connect(t, network, dialaddr, isEmpty) + <-done // tell server to stop + <-done // wait for stop +} + +func TestUDPServer(t *testing.T) { + if !*testUDP { + return + } + for _, isEmpty := range []bool{false, true} { + doTestPacket(t, "udp", "0.0.0.0", "127.0.0.1", isEmpty) + doTestPacket(t, "udp", "", "127.0.0.1", isEmpty) + if supportsIPv6 && supportsIPv4map { + doTestPacket(t, "udp", "[::]", "[::ffff:127.0.0.1]", isEmpty) + doTestPacket(t, "udp", "[::]", "127.0.0.1", isEmpty) + doTestPacket(t, "udp", "0.0.0.0", "[::ffff:127.0.0.1]", isEmpty) + } + } +} + +func TestUnixDatagramServer(t *testing.T) { + // "unix" sockets are not supported on windows and Plan 9. + if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { + return + } + for _, isEmpty := range []bool{false} { + os.Remove("/tmp/gotest1.net") + os.Remove("/tmp/gotest1.net.local") + doTestPacket(t, "unixgram", "/tmp/gotest1.net", "/tmp/gotest1.net", isEmpty) + os.Remove("/tmp/gotest1.net") + os.Remove("/tmp/gotest1.net.local") + if syscall.OS == "linux" { + // Test abstract unix domain socket, a Linux-ism + doTestPacket(t, "unixgram", "@gotest1/net", "@gotest1/net", isEmpty) + } + } +} diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go new file mode 100644 index 000000000..821716e43 --- /dev/null +++ b/src/pkg/net/sock.go @@ -0,0 +1,166 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Sockets + +package net + +import ( + "io" + "os" + "reflect" + "syscall" +) + +// Boolean to int. +func boolint(b bool) int { + if b { + return 1 + } + return 0 +} + +// Generic socket creation. +func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err os.Error) { + // See ../syscall/exec.go for description of ForkLock. + syscall.ForkLock.RLock() + s, e := syscall.Socket(f, p, t) + if e != 0 { + syscall.ForkLock.RUnlock() + return nil, os.Errno(e) + } + syscall.CloseOnExec(s) + syscall.ForkLock.RUnlock() + + setKernelSpecificSockopt(s, f) + + if la != nil { + e = syscall.Bind(s, la) + if e != 0 { + closesocket(s) + return nil, os.Errno(e) + } + } + + if fd, err = newFD(s, f, p, net); err != nil { + closesocket(s) + return nil, err + } + + if ra != nil { + if err = fd.connect(ra); err != nil { + fd.Close() + return nil, err + } + } + + sa, _ := syscall.Getsockname(s) + laddr := toAddr(sa) + sa, _ = syscall.Getpeername(s) + raddr := toAddr(sa) + + fd.setAddr(laddr, raddr) + return fd, nil +} + +func setsockoptInt(fd *netFD, level, opt int, value int) os.Error { + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, level, opt, value)) +} + +func setsockoptNsec(fd *netFD, level, opt int, nsec int64) os.Error { + var tv = syscall.NsecToTimeval(nsec) + return os.NewSyscallError("setsockopt", syscall.SetsockoptTimeval(fd.sysfd, level, opt, &tv)) +} + +func setReadBuffer(fd *netFD, bytes int) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes) +} + +func setWriteBuffer(fd *netFD, bytes int) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes) +} + +func setReadTimeout(fd *netFD, nsec int64) os.Error { + fd.rdeadline_delta = nsec + return nil +} + +func setWriteTimeout(fd *netFD, nsec int64) os.Error { + fd.wdeadline_delta = nsec + return nil +} + +func setTimeout(fd *netFD, nsec int64) os.Error { + if e := setReadTimeout(fd, nsec); e != nil { + return e + } + return setWriteTimeout(fd, nsec) +} + +func setReuseAddr(fd *netFD, reuse bool) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse)) +} + +func bindToDevice(fd *netFD, dev string) os.Error { + // TODO(rsc): call setsockopt with null-terminated string pointer + return os.EINVAL +} + +func setDontRoute(fd *netFD, dontroute bool) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute)) +} + +func setKeepAlive(fd *netFD, keepalive bool) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive)) +} + +func setNoDelay(fd *netFD, noDelay bool) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay)) +} + +func setLinger(fd *netFD, sec int) os.Error { + var l syscall.Linger + if sec >= 0 { + l.Onoff = 1 + l.Linger = int32(sec) + } else { + l.Onoff = 0 + l.Linger = 0 + } + fd.incref() + defer fd.decref() + e := syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l) + return os.NewSyscallError("setsockopt", e) +} + +type UnknownSocketError struct { + sa syscall.Sockaddr +} + +func (e *UnknownSocketError) String() string { + return "unknown socket address type " + reflect.TypeOf(e.sa).String() +} + +type writerOnly struct { + io.Writer +} + +// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't +// applicable. +func genericReadFrom(w io.Writer, r io.Reader) (n int64, err os.Error) { + // Use wrapper to hide existing r.ReadFrom from io.Copy. + return io.Copy(writerOnly{w}, r) +} diff --git a/src/pkg/net/sock_bsd.go b/src/pkg/net/sock_bsd.go new file mode 100644 index 000000000..5fd52074a --- /dev/null +++ b/src/pkg/net/sock_bsd.go @@ -0,0 +1,31 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Sockets for BSD variants + +package net + +import ( + "syscall" +) + +func setKernelSpecificSockopt(s, f int) { + // Allow reuse of recently-used addresses. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + + // Allow reuse of recently-used ports. + // This option is supported only in descendants of 4.4BSD, + // to make an effective multicast application and an application + // that requires quick draw possible. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1) + + // Allow broadcast. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + + if f == syscall.AF_INET6 { + // using ip, tcp, udp, etc. + // allow both protocols even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } +} diff --git a/src/pkg/net/sock_linux.go b/src/pkg/net/sock_linux.go new file mode 100644 index 000000000..ec31e803b --- /dev/null +++ b/src/pkg/net/sock_linux.go @@ -0,0 +1,25 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Sockets for Linux + +package net + +import ( + "syscall" +) + +func setKernelSpecificSockopt(s, f int) { + // Allow reuse of recently-used addresses. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + + // Allow broadcast. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + + if f == syscall.AF_INET6 { + // using ip, tcp, udp, etc. + // allow both protocols even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } +} diff --git a/src/pkg/net/sock_windows.go b/src/pkg/net/sock_windows.go new file mode 100644 index 000000000..c6dbd0465 --- /dev/null +++ b/src/pkg/net/sock_windows.go @@ -0,0 +1,25 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Sockets for Windows + +package net + +import ( + "syscall" +) + +func setKernelSpecificSockopt(s syscall.Handle, f int) { + // Allow reuse of recently-used addresses and ports. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + + // Allow broadcast. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + + if f == syscall.AF_INET6 { + // using ip, tcp, udp, etc. + // allow both protocols even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } +} diff --git a/src/pkg/net/tcpsock.go b/src/pkg/net/tcpsock.go new file mode 100644 index 000000000..f5c0a2781 --- /dev/null +++ b/src/pkg/net/tcpsock.go @@ -0,0 +1,40 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TCP sockets + +package net + +import ( + "os" +) + +// TCPAddr represents the address of a TCP end point. +type TCPAddr struct { + IP IP + Port int +} + +// Network returns the address's network name, "tcp". +func (a *TCPAddr) Network() string { return "tcp" } + +func (a *TCPAddr) String() string { + if a == nil { + return "<nil>" + } + return JoinHostPort(a.IP.String(), itoa(a.Port)) +} + +// ResolveTCPAddr parses addr as a TCP address of the form +// host:port and resolves domain names or port names to +// numeric addresses on the network net, which must be "tcp", +// "tcp4" or "tcp6". A literal IPv6 host address must be +// enclosed in square brackets, as in "[::]:80". +func ResolveTCPAddr(net, addr string) (*TCPAddr, os.Error) { + ip, port, err := hostPortToIP(net, addr) + if err != nil { + return nil, err + } + return &TCPAddr{ip, port}, nil +} diff --git a/src/pkg/net/tcpsock_plan9.go b/src/pkg/net/tcpsock_plan9.go new file mode 100644 index 000000000..f4f6e9fee --- /dev/null +++ b/src/pkg/net/tcpsock_plan9.go @@ -0,0 +1,63 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TCP for Plan 9 + +package net + +import ( + "os" +) + +// TCPConn is an implementation of the Conn interface +// for TCP network connections. +type TCPConn struct { + plan9Conn +} + +// DialTCP connects to the remote address raddr on the network net, +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used +// as the local address for the connection. +func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err os.Error) { + switch net { + case "tcp", "tcp4", "tcp6": + default: + return nil, UnknownNetworkError(net) + } + if raddr == nil { + return nil, &OpError{"dial", "tcp", nil, errMissingAddress} + } + c1, err := dialPlan9(net, laddr, raddr) + if err != nil { + return + } + return &TCPConn{*c1}, nil +} + +// TCPListener is a TCP network listener. +// Clients should typically use variables of type Listener +// instead of assuming TCP. +type TCPListener struct { + plan9Listener +} + +// ListenTCP announces on the TCP address laddr and returns a TCP listener. +// Net must be "tcp", "tcp4", or "tcp6". +// If laddr has a port of 0, it means to listen on some available port. +// The caller can use l.Addr() to retrieve the chosen address. +func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) { + switch net { + case "tcp", "tcp4", "tcp6": + default: + return nil, UnknownNetworkError(net) + } + if laddr == nil { + return nil, &OpError{"listen", "tcp", nil, errMissingAddress} + } + l1, err := listenPlan9(net, laddr) + if err != nil { + return + } + return &TCPListener{*l1}, nil +} diff --git a/src/pkg/net/tcpsock_posix.go b/src/pkg/net/tcpsock_posix.go new file mode 100644 index 000000000..5560301b4 --- /dev/null +++ b/src/pkg/net/tcpsock_posix.go @@ -0,0 +1,283 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TCP sockets + +package net + +import ( + "io" + "os" + "syscall" +) + +func sockaddrToTCP(sa syscall.Sockaddr) Addr { + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + return &TCPAddr{sa.Addr[0:], sa.Port} + case *syscall.SockaddrInet6: + return &TCPAddr{sa.Addr[0:], sa.Port} + } + return nil +} + +func (a *TCPAddr) family() int { + if a == nil || len(a.IP) <= 4 { + return syscall.AF_INET + } + if a.IP.To4() != nil { + return syscall.AF_INET + } + return syscall.AF_INET6 +} + +func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) { + return ipToSockaddr(family, a.IP, a.Port) +} + +func (a *TCPAddr) toAddr() sockaddr { + if a == nil { // nil *TCPAddr + return nil // nil interface + } + return a +} + +// TCPConn is an implementation of the Conn interface +// for TCP network connections. +type TCPConn struct { + fd *netFD +} + +func newTCPConn(fd *netFD) *TCPConn { + c := &TCPConn{fd} + c.SetNoDelay(true) + return c +} + +func (c *TCPConn) ok() bool { return c != nil && c.fd != nil } + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *TCPConn) Read(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Read(b) +} + +// ReadFrom implements the io.ReaderFrom ReadFrom method. +func (c *TCPConn) ReadFrom(r io.Reader) (int64, os.Error) { + if n, err, handled := sendFile(c.fd, r); handled { + return n, err + } + return genericReadFrom(c, r) +} + +// Write implements the net.Conn Write method. +func (c *TCPConn) Write(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Write(b) +} + +// Close closes the TCP connection. +func (c *TCPConn) Close() os.Error { + if !c.ok() { + return os.EINVAL + } + err := c.fd.Close() + c.fd = nil + return err +} + +// LocalAddr returns the local network address, a *TCPAddr. +func (c *TCPConn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.laddr +} + +// RemoteAddr returns the remote network address, a *TCPAddr. +func (c *TCPConn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.raddr +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *TCPConn) SetTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setTimeout(c.fd, nsec) +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *TCPConn) SetReadTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadTimeout(c.fd, nsec) +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *TCPConn) SetWriteTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteTimeout(c.fd, nsec) +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *TCPConn) SetReadBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadBuffer(c.fd, bytes) +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *TCPConn) SetWriteBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteBuffer(c.fd, bytes) +} + +// SetLinger sets the behavior of Close() on a connection +// which still has data waiting to be sent or to be acknowledged. +// +// If sec < 0 (the default), Close returns immediately and +// the operating system finishes sending the data in the background. +// +// If sec == 0, Close returns immediately and the operating system +// discards any unsent or unacknowledged data. +// +// If sec > 0, Close blocks for at most sec seconds waiting for +// data to be sent and acknowledged. +func (c *TCPConn) SetLinger(sec int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setLinger(c.fd, sec) +} + +// SetKeepAlive sets whether the operating system should send +// keepalive messages on the connection. +func (c *TCPConn) SetKeepAlive(keepalive bool) os.Error { + if !c.ok() { + return os.EINVAL + } + return setKeepAlive(c.fd, keepalive) +} + +// SetNoDelay controls whether the operating system should delay +// packet transmission in hopes of sending fewer packets +// (Nagle's algorithm). The default is true (no delay), meaning +// that data is sent as soon as possible after a Write. +func (c *TCPConn) SetNoDelay(noDelay bool) os.Error { + if !c.ok() { + return os.EINVAL + } + return setNoDelay(c.fd, noDelay) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *TCPConn) File() (f *os.File, err os.Error) { return c.fd.dup() } + +// DialTCP connects to the remote address raddr on the network net, +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used +// as the local address for the connection. +func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err os.Error) { + if raddr == nil { + return nil, &OpError{"dial", "tcp", nil, errMissingAddress} + } + fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) + if e != nil { + return nil, e + } + return newTCPConn(fd), nil +} + +// TCPListener is a TCP network listener. +// Clients should typically use variables of type Listener +// instead of assuming TCP. +type TCPListener struct { + fd *netFD +} + +// ListenTCP announces on the TCP address laddr and returns a TCP listener. +// Net must be "tcp", "tcp4", or "tcp6". +// If laddr has a port of 0, it means to listen on some available port. +// The caller can use l.Addr() to retrieve the chosen address. +func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) { + fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_STREAM, 0, "listen", sockaddrToTCP) + if err != nil { + return nil, err + } + errno := syscall.Listen(fd.sysfd, listenBacklog()) + if errno != 0 { + closesocket(fd.sysfd) + return nil, &OpError{"listen", "tcp", laddr, os.Errno(errno)} + } + l = new(TCPListener) + l.fd = fd + return l, nil +} + +// AcceptTCP accepts the next incoming call and returns the new connection +// and the remote address. +func (l *TCPListener) AcceptTCP() (c *TCPConn, err os.Error) { + if l == nil || l.fd == nil || l.fd.sysfd < 0 { + return nil, os.EINVAL + } + fd, err := l.fd.accept(sockaddrToTCP) + if err != nil { + return nil, err + } + return newTCPConn(fd), nil +} + +// Accept implements the Accept method in the Listener interface; +// it waits for the next call and returns a generic Conn. +func (l *TCPListener) Accept() (c Conn, err os.Error) { + c1, err := l.AcceptTCP() + if err != nil { + return nil, err + } + return c1, nil +} + +// Close stops listening on the TCP address. +// Already Accepted connections are not closed. +func (l *TCPListener) Close() os.Error { + if l == nil || l.fd == nil { + return os.EINVAL + } + return l.fd.Close() +} + +// Addr returns the listener's network address, a *TCPAddr. +func (l *TCPListener) Addr() Addr { return l.fd.laddr } + +// SetTimeout sets the deadline associated with the listener +func (l *TCPListener) SetTimeout(nsec int64) os.Error { + if l == nil || l.fd == nil { + return os.EINVAL + } + return setTimeout(l.fd, nsec) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (l *TCPListener) File() (f *os.File, err os.Error) { return l.fd.dup() } diff --git a/src/pkg/net/textproto/Makefile b/src/pkg/net/textproto/Makefile new file mode 100644 index 000000000..cadf3ab69 --- /dev/null +++ b/src/pkg/net/textproto/Makefile @@ -0,0 +1,15 @@ +# Copyright 2010 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=net/textproto +GOFILES=\ + header.go\ + pipeline.go\ + reader.go\ + textproto.go\ + writer.go\ + +include ../../../Make.pkg diff --git a/src/pkg/net/textproto/header.go b/src/pkg/net/textproto/header.go new file mode 100644 index 000000000..288deb2ce --- /dev/null +++ b/src/pkg/net/textproto/header.go @@ -0,0 +1,43 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +// A MIMEHeader represents a MIME-style header mapping +// keys to sets of values. +type MIMEHeader map[string][]string + +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +func (h MIMEHeader) Add(key, value string) { + key = CanonicalMIMEHeaderKey(key) + h[key] = append(h[key], value) +} + +// Set sets the header entries associated with key to +// the single element value. It replaces any existing +// values associated with key. +func (h MIMEHeader) Set(key, value string) { + h[CanonicalMIMEHeaderKey(key)] = []string{value} +} + +// Get gets the first value associated with the given key. +// If there are no values associated with the key, Get returns "". +// Get is a convenience method. For more complex queries, +// access the map directly. +func (h MIMEHeader) Get(key string) string { + if h == nil { + return "" + } + v := h[CanonicalMIMEHeaderKey(key)] + if len(v) == 0 { + return "" + } + return v[0] +} + +// Del deletes the values associated with key. +func (h MIMEHeader) Del(key string) { + h[CanonicalMIMEHeaderKey(key)] = nil, false +} diff --git a/src/pkg/net/textproto/pipeline.go b/src/pkg/net/textproto/pipeline.go new file mode 100644 index 000000000..8c25884b3 --- /dev/null +++ b/src/pkg/net/textproto/pipeline.go @@ -0,0 +1,117 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "sync" +) + +// A Pipeline manages a pipelined in-order request/response sequence. +// +// To use a Pipeline p to manage multiple clients on a connection, +// each client should run: +// +// id := p.Next() // take a number +// +// p.StartRequest(id) // wait for turn to send request +// «send request» +// p.EndRequest(id) // notify Pipeline that request is sent +// +// p.StartResponse(id) // wait for turn to read response +// «read response» +// p.EndResponse(id) // notify Pipeline that response is read +// +// A pipelined server can use the same calls to ensure that +// responses computed in parallel are written in the correct order. +type Pipeline struct { + mu sync.Mutex + id uint + request sequencer + response sequencer +} + +// Next returns the next id for a request/response pair. +func (p *Pipeline) Next() uint { + p.mu.Lock() + id := p.id + p.id++ + p.mu.Unlock() + return id +} + +// StartRequest blocks until it is time to send (or, if this is a server, receive) +// the request with the given id. +func (p *Pipeline) StartRequest(id uint) { + p.request.Start(id) +} + +// EndRequest notifies p that the request with the given id has been sent +// (or, if this is a server, received). +func (p *Pipeline) EndRequest(id uint) { + p.request.End(id) +} + +// StartResponse blocks until it is time to receive (or, if this is a server, send) +// the request with the given id. +func (p *Pipeline) StartResponse(id uint) { + p.response.Start(id) +} + +// EndResponse notifies p that the response with the given id has been received +// (or, if this is a server, sent). +func (p *Pipeline) EndResponse(id uint) { + p.response.End(id) +} + +// A sequencer schedules a sequence of numbered events that must +// happen in order, one after the other. The event numbering must start +// at 0 and increment without skipping. The event number wraps around +// safely as long as there are not 2^32 simultaneous events pending. +type sequencer struct { + mu sync.Mutex + id uint + wait map[uint]chan uint +} + +// Start waits until it is time for the event numbered id to begin. +// That is, except for the first event, it waits until End(id-1) has +// been called. +func (s *sequencer) Start(id uint) { + s.mu.Lock() + if s.id == id { + s.mu.Unlock() + return + } + c := make(chan uint) + if s.wait == nil { + s.wait = make(map[uint]chan uint) + } + s.wait[id] = c + s.mu.Unlock() + <-c +} + +// End notifies the sequencer that the event numbered id has completed, +// allowing it to schedule the event numbered id+1. It is a run-time error +// to call End with an id that is not the number of the active event. +func (s *sequencer) End(id uint) { + s.mu.Lock() + if s.id != id { + panic("out of sync") + } + id++ + s.id = id + if s.wait == nil { + s.wait = make(map[uint]chan uint) + } + c, ok := s.wait[id] + if ok { + s.wait[id] = nil, false + } + s.mu.Unlock() + if ok { + c <- 1 + } +} diff --git a/src/pkg/net/textproto/reader.go b/src/pkg/net/textproto/reader.go new file mode 100644 index 000000000..ce0ddc73f --- /dev/null +++ b/src/pkg/net/textproto/reader.go @@ -0,0 +1,514 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "bytes" + "io" + "io/ioutil" + "os" + "strconv" +) + +// BUG(rsc): To let callers manage exposure to denial of service +// attacks, Reader should allow them to set and reset a limit on +// the number of bytes read from the connection. + +// A Reader implements convenience methods for reading requests +// or responses from a text protocol network connection. +type Reader struct { + R *bufio.Reader + dot *dotReader +} + +// NewReader returns a new Reader reading from r. +func NewReader(r *bufio.Reader) *Reader { + return &Reader{R: r} +} + +// ReadLine reads a single line from r, +// eliding the final \n or \r\n from the returned string. +func (r *Reader) ReadLine() (string, os.Error) { + line, err := r.readLineSlice() + return string(line), err +} + +// ReadLineBytes is like ReadLine but returns a []byte instead of a string. +func (r *Reader) ReadLineBytes() ([]byte, os.Error) { + line, err := r.readLineSlice() + if line != nil { + buf := make([]byte, len(line)) + copy(buf, line) + line = buf + } + return line, err +} + +func (r *Reader) readLineSlice() ([]byte, os.Error) { + r.closeDot() + line, _, err := r.R.ReadLine() + return line, err +} + +// ReadContinuedLine reads a possibly continued line from r, +// eliding the final trailing ASCII white space. +// Lines after the first are considered continuations if they +// begin with a space or tab character. In the returned data, +// continuation lines are separated from the previous line +// only by a single space: the newline and leading white space +// are removed. +// +// For example, consider this input: +// +// Line 1 +// continued... +// Line 2 +// +// The first call to ReadContinuedLine will return "Line 1 continued..." +// and the second will return "Line 2". +// +// A line consisting of only white space is never continued. +// +func (r *Reader) ReadContinuedLine() (string, os.Error) { + line, err := r.readContinuedLineSlice() + return string(line), err +} + +// trim returns s with leading and trailing spaces and tabs removed. +// It does not assume Unicode or UTF-8. +func trim(s []byte) []byte { + i := 0 + for i < len(s) && (s[i] == ' ' || s[i] == '\t') { + i++ + } + n := len(s) + for n > i && (s[n-1] == ' ' || s[n-1] == '\t') { + n-- + } + return s[i:n] +} + +// ReadContinuedLineBytes is like ReadContinuedLine but +// returns a []byte instead of a string. +func (r *Reader) ReadContinuedLineBytes() ([]byte, os.Error) { + line, err := r.readContinuedLineSlice() + if line != nil { + buf := make([]byte, len(line)) + copy(buf, line) + line = buf + } + return line, err +} + +func (r *Reader) readContinuedLineSlice() ([]byte, os.Error) { + // Read the first line. + line, err := r.readLineSlice() + if err != nil { + return line, err + } + if len(line) == 0 { // blank line - no continuation + return line, nil + } + line = trim(line) + + copied := false + if r.R.Buffered() < 1 { + // ReadByte will flush the buffer; make a copy of the slice. + copied = true + line = append([]byte(nil), line...) + } + + // Look for a continuation line. + c, err := r.R.ReadByte() + if err != nil { + // Delay err until we read the byte next time. + return line, nil + } + if c != ' ' && c != '\t' { + // Not a continuation. + r.R.UnreadByte() + return line, nil + } + + if !copied { + // The next readLineSlice will invalidate the previous one. + line = append(make([]byte, 0, len(line)*2), line...) + } + + // Read continuation lines. + for { + // Consume leading spaces; one already gone. + for { + c, err = r.R.ReadByte() + if err != nil { + break + } + if c != ' ' && c != '\t' { + r.R.UnreadByte() + break + } + } + var cont []byte + cont, err = r.readLineSlice() + cont = trim(cont) + line = append(line, ' ') + line = append(line, cont...) + if err != nil { + break + } + + // Check for leading space on next line. + if c, err = r.R.ReadByte(); err != nil { + break + } + if c != ' ' && c != '\t' { + r.R.UnreadByte() + break + } + } + + // Delay error until next call. + if len(line) > 0 { + err = nil + } + return line, err +} + +func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err os.Error) { + line, err := r.ReadLine() + if err != nil { + return + } + if len(line) < 4 || line[3] != ' ' && line[3] != '-' { + err = ProtocolError("short response: " + line) + return + } + continued = line[3] == '-' + code, err = strconv.Atoi(line[0:3]) + if err != nil || code < 100 { + err = ProtocolError("invalid response code: " + line) + return + } + message = line[4:] + if 1 <= expectCode && expectCode < 10 && code/100 != expectCode || + 10 <= expectCode && expectCode < 100 && code/10 != expectCode || + 100 <= expectCode && expectCode < 1000 && code != expectCode { + err = &Error{code, message} + } + return +} + +// ReadCodeLine reads a response code line of the form +// code message +// where code is a 3-digit status code and the message +// extends to the rest of the line. An example of such a line is: +// 220 plan9.bell-labs.com ESMTP +// +// If the prefix of the status does not match the digits in expectCode, +// ReadCodeLine returns with err set to &Error{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// If the response is multi-line, ReadCodeLine returns an error. +// +// An expectCode <= 0 disables the check of the status code. +// +func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err os.Error) { + code, continued, message, err := r.readCodeLine(expectCode) + if err == nil && continued { + err = ProtocolError("unexpected multi-line response: " + message) + } + return +} + +// ReadResponse reads a multi-line response of the form +// code-message line 1 +// code-message line 2 +// ... +// code message line n +// where code is a 3-digit status code. Each line should have the same code. +// The response is terminated by a line that uses a space between the code and +// the message line rather than a dash. Each line in message is separated by +// a newline (\n). +// +// If the prefix of the status does not match the digits in expectCode, +// ReadResponse returns with err set to &Error{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// An expectCode <= 0 disables the check of the status code. +// +func (r *Reader) ReadResponse(expectCode int) (code int, message string, err os.Error) { + code, continued, message, err := r.readCodeLine(expectCode) + for err == nil && continued { + var code2 int + var moreMessage string + code2, continued, moreMessage, err = r.readCodeLine(expectCode) + if code != code2 { + err = ProtocolError("status code mismatch: " + strconv.Itoa(code) + ", " + strconv.Itoa(code2)) + } + message += "\n" + moreMessage + } + return +} + +// DotReader returns a new Reader that satisfies Reads using the +// decoded text of a dot-encoded block read from r. +// The returned Reader is only valid until the next call +// to a method on r. +// +// Dot encoding is a common framing used for data blocks +// in text protocols such as SMTP. The data consists of a sequence +// of lines, each of which ends in "\r\n". The sequence itself +// ends at a line containing just a dot: ".\r\n". Lines beginning +// with a dot are escaped with an additional dot to avoid +// looking like the end of the sequence. +// +// The decoded form returned by the Reader's Read method +// rewrites the "\r\n" line endings into the simpler "\n", +// removes leading dot escapes if present, and stops with error os.EOF +// after consuming (and discarding) the end-of-sequence line. +func (r *Reader) DotReader() io.Reader { + r.closeDot() + r.dot = &dotReader{r: r} + return r.dot +} + +type dotReader struct { + r *Reader + state int +} + +// Read satisfies reads by decoding dot-encoded data read from d.r. +func (d *dotReader) Read(b []byte) (n int, err os.Error) { + // Run data through a simple state machine to + // elide leading dots, rewrite trailing \r\n into \n, + // and detect ending .\r\n line. + const ( + stateBeginLine = iota // beginning of line; initial state; must be zero + stateDot // read . at beginning of line + stateDotCR // read .\r at beginning of line + stateCR // read \r (possibly at end of line) + stateData // reading data in middle of line + stateEOF // reached .\r\n end marker line + ) + br := d.r.R + for n < len(b) && d.state != stateEOF { + var c byte + c, err = br.ReadByte() + if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + break + } + switch d.state { + case stateBeginLine: + if c == '.' { + d.state = stateDot + continue + } + if c == '\r' { + d.state = stateCR + continue + } + d.state = stateData + + case stateDot: + if c == '\r' { + d.state = stateDotCR + continue + } + if c == '\n' { + d.state = stateEOF + continue + } + d.state = stateData + + case stateDotCR: + if c == '\n' { + d.state = stateEOF + continue + } + // Not part of .\r\n. + // Consume leading dot and emit saved \r. + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateCR: + if c == '\n' { + d.state = stateBeginLine + break + } + // Not part of \r\n. Emit saved \r + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateData: + if c == '\r' { + d.state = stateCR + continue + } + if c == '\n' { + d.state = stateBeginLine + } + } + b[n] = c + n++ + } + if err == nil && d.state == stateEOF { + err = os.EOF + } + if err != nil && d.r.dot == d { + d.r.dot = nil + } + return +} + +// closeDot drains the current DotReader if any, +// making sure that it reads until the ending dot line. +func (r *Reader) closeDot() { + if r.dot == nil { + return + } + buf := make([]byte, 128) + for r.dot != nil { + // When Read reaches EOF or an error, + // it will set r.dot == nil. + r.dot.Read(buf) + } +} + +// ReadDotBytes reads a dot-encoding and returns the decoded data. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *Reader) ReadDotBytes() ([]byte, os.Error) { + return ioutil.ReadAll(r.DotReader()) +} + +// ReadDotLines reads a dot-encoding and returns a slice +// containing the decoded lines, with the final \r\n or \n elided from each. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *Reader) ReadDotLines() ([]string, os.Error) { + // We could use ReadDotBytes and then Split it, + // but reading a line at a time avoids needing a + // large contiguous block of memory and is simpler. + var v []string + var err os.Error + for { + var line string + line, err = r.ReadLine() + if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + break + } + + // Dot by itself marks end; otherwise cut one dot. + if len(line) > 0 && line[0] == '.' { + if len(line) == 1 { + break + } + line = line[1:] + } + v = append(v, line) + } + return v, err +} + +// ReadMIMEHeader reads a MIME-style header from r. +// The header is a sequence of possibly continued Key: Value lines +// ending in a blank line. +// The returned map m maps CanonicalMIMEHeaderKey(key) to a +// sequence of values in the same order encountered in the input. +// +// For example, consider this input: +// +// My-Key: Value 1 +// Long-Key: Even +// Longer Value +// My-Key: Value 2 +// +// Given that input, ReadMIMEHeader returns the map: +// +// map[string][]string{ +// "My-Key": {"Value 1", "Value 2"}, +// "Long-Key": {"Even Longer Value"}, +// } +// +func (r *Reader) ReadMIMEHeader() (MIMEHeader, os.Error) { + m := make(MIMEHeader) + for { + kv, err := r.readContinuedLineSlice() + if len(kv) == 0 { + return m, err + } + + // Key ends at first colon; must not have spaces. + i := bytes.IndexByte(kv, ':') + if i < 0 || bytes.IndexByte(kv[0:i], ' ') >= 0 { + return m, ProtocolError("malformed MIME header line: " + string(kv)) + } + key := CanonicalMIMEHeaderKey(string(kv[0:i])) + + // Skip initial spaces in value. + i++ // skip colon + for i < len(kv) && (kv[i] == ' ' || kv[i] == '\t') { + i++ + } + value := string(kv[i:]) + + m[key] = append(m[key], value) + + if err != nil { + return m, err + } + } + panic("unreachable") +} + +// CanonicalMIMEHeaderKey returns the canonical format of the +// MIME header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +func CanonicalMIMEHeaderKey(s string) string { + // Quick check for canonical encoding. + needUpper := true + for i := 0; i < len(s); i++ { + c := s[i] + if needUpper && 'a' <= c && c <= 'z' { + goto MustRewrite + } + if !needUpper && 'A' <= c && c <= 'Z' { + goto MustRewrite + } + needUpper = c == '-' + } + return s + +MustRewrite: + // Canonicalize: first letter upper case + // and upper case after each dash. + // (Host, User-Agent, If-Modified-Since). + // MIME headers are ASCII only, so no Unicode issues. + a := []byte(s) + upper := true + for i, v := range a { + if upper && 'a' <= v && v <= 'z' { + a[i] = v + 'A' - 'a' + } + if !upper && 'A' <= v && v <= 'Z' { + a[i] = v + 'a' - 'A' + } + upper = v == '-' + } + return string(a) +} diff --git a/src/pkg/net/textproto/reader_test.go b/src/pkg/net/textproto/reader_test.go new file mode 100644 index 000000000..0658e58b8 --- /dev/null +++ b/src/pkg/net/textproto/reader_test.go @@ -0,0 +1,140 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "io" + "os" + "reflect" + "strings" + "testing" +) + +type canonicalHeaderKeyTest struct { + in, out string +} + +var canonicalHeaderKeyTests = []canonicalHeaderKeyTest{ + {"a-b-c", "A-B-C"}, + {"a-1-c", "A-1-C"}, + {"User-Agent", "User-Agent"}, + {"uSER-aGENT", "User-Agent"}, + {"user-agent", "User-Agent"}, + {"USER-AGENT", "User-Agent"}, +} + +func TestCanonicalMIMEHeaderKey(t *testing.T) { + for _, tt := range canonicalHeaderKeyTests { + if s := CanonicalMIMEHeaderKey(tt.in); s != tt.out { + t.Errorf("CanonicalMIMEHeaderKey(%q) = %q, want %q", tt.in, s, tt.out) + } + } +} + +func reader(s string) *Reader { + return NewReader(bufio.NewReader(strings.NewReader(s))) +} + +func TestReadLine(t *testing.T) { + r := reader("line1\nline2\n") + s, err := r.ReadLine() + if s != "line1" || err != nil { + t.Fatalf("Line 1: %s, %v", s, err) + } + s, err = r.ReadLine() + if s != "line2" || err != nil { + t.Fatalf("Line 2: %s, %v", s, err) + } + s, err = r.ReadLine() + if s != "" || err != os.EOF { + t.Fatalf("EOF: %s, %v", s, err) + } +} + +func TestReadContinuedLine(t *testing.T) { + r := reader("line1\nline\n 2\nline3\n") + s, err := r.ReadContinuedLine() + if s != "line1" || err != nil { + t.Fatalf("Line 1: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "line 2" || err != nil { + t.Fatalf("Line 2: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "line3" || err != nil { + t.Fatalf("Line 3: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "" || err != os.EOF { + t.Fatalf("EOF: %s, %v", s, err) + } +} + +func TestReadCodeLine(t *testing.T) { + r := reader("123 hi\n234 bye\n345 no way\n") + code, msg, err := r.ReadCodeLine(0) + if code != 123 || msg != "hi" || err != nil { + t.Fatalf("Line 1: %d, %s, %v", code, msg, err) + } + code, msg, err = r.ReadCodeLine(23) + if code != 234 || msg != "bye" || err != nil { + t.Fatalf("Line 2: %d, %s, %v", code, msg, err) + } + code, msg, err = r.ReadCodeLine(346) + if code != 345 || msg != "no way" || err == nil { + t.Fatalf("Line 3: %d, %s, %v", code, msg, err) + } + if e, ok := err.(*Error); !ok || e.Code != code || e.Msg != msg { + t.Fatalf("Line 3: wrong error %v\n", err) + } + code, msg, err = r.ReadCodeLine(1) + if code != 0 || msg != "" || err != os.EOF { + t.Fatalf("EOF: %d, %s, %v", code, msg, err) + } +} + +func TestReadDotLines(t *testing.T) { + r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanother\n") + s, err := r.ReadDotLines() + want := []string{"dotlines", "foo", ".bar", "..baz", "quux", ""} + if !reflect.DeepEqual(s, want) || err != nil { + t.Fatalf("ReadDotLines: %v, %v", s, err) + } + + s, err = r.ReadDotLines() + want = []string{"another"} + if !reflect.DeepEqual(s, want) || err != io.ErrUnexpectedEOF { + t.Fatalf("ReadDotLines2: %v, %v", s, err) + } +} + +func TestReadDotBytes(t *testing.T) { + r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanot.her\r\n") + b, err := r.ReadDotBytes() + want := []byte("dotlines\nfoo\n.bar\n..baz\nquux\n\n") + if !reflect.DeepEqual(b, want) || err != nil { + t.Fatalf("ReadDotBytes: %q, %v", b, err) + } + + b, err = r.ReadDotBytes() + want = []byte("anot.her\n") + if !reflect.DeepEqual(b, want) || err != io.ErrUnexpectedEOF { + t.Fatalf("ReadDotBytes2: %q, %v", b, err) + } +} + +func TestReadMIMEHeader(t *testing.T) { + r := reader("my-key: Value 1 \r\nLong-key: Even \n Longer Value\r\nmy-Key: Value 2\r\n\n") + m, err := r.ReadMIMEHeader() + want := MIMEHeader{ + "My-Key": {"Value 1", "Value 2"}, + "Long-Key": {"Even Longer Value"}, + } + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want) + } +} diff --git a/src/pkg/net/textproto/textproto.go b/src/pkg/net/textproto/textproto.go new file mode 100644 index 000000000..9f19b5495 --- /dev/null +++ b/src/pkg/net/textproto/textproto.go @@ -0,0 +1,121 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package textproto implements generic support for text-based request/response +// protocols in the style of HTTP, NNTP, and SMTP. +// +// The package provides: +// +// Error, which represents a numeric error response from +// a server. +// +// Pipeline, to manage pipelined requests and responses +// in a client. +// +// Reader, to read numeric response code lines, +// key: value headers, lines wrapped with leading spaces +// on continuation lines, and whole text blocks ending +// with a dot on a line by itself. +// +// Writer, to write dot-encoded text blocks. +// +package textproto + +import ( + "bufio" + "fmt" + "io" + "net" + "os" +) + +// An Error represents a numeric error response from a server. +type Error struct { + Code int + Msg string +} + +func (e *Error) String() string { + return fmt.Sprintf("%03d %s", e.Code, e.Msg) +} + +// A ProtocolError describes a protocol violation such +// as an invalid response or a hung-up connection. +type ProtocolError string + +func (p ProtocolError) String() string { + return string(p) +} + +// A Conn represents a textual network protocol connection. +// It consists of a Reader and Writer to manage I/O +// and a Pipeline to sequence concurrent requests on the connection. +// These embedded types carry methods with them; +// see the documentation of those types for details. +type Conn struct { + Reader + Writer + Pipeline + conn io.ReadWriteCloser +} + +// NewConn returns a new Conn using conn for I/O. +func NewConn(conn io.ReadWriteCloser) *Conn { + return &Conn{ + Reader: Reader{R: bufio.NewReader(conn)}, + Writer: Writer{W: bufio.NewWriter(conn)}, + conn: conn, + } +} + +// Close closes the connection. +func (c *Conn) Close() os.Error { + return c.conn.Close() +} + +// Dial connects to the given address on the given network using net.Dial +// and then returns a new Conn for the connection. +func Dial(network, addr string) (*Conn, os.Error) { + c, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + return NewConn(c), nil +} + +// Cmd is a convenience method that sends a command after +// waiting its turn in the pipeline. The command text is the +// result of formatting format with args and appending \r\n. +// Cmd returns the id of the command, for use with StartResponse and EndResponse. +// +// For example, a client might run a HELP command that returns a dot-body +// by using: +// +// id, err := c.Cmd("HELP") +// if err != nil { +// return nil, err +// } +// +// c.StartResponse(id) +// defer c.EndResponse(id) +// +// if _, _, err = c.ReadCodeLine(110); err != nil { +// return nil, err +// } +// text, err := c.ReadDotAll() +// if err != nil { +// return nil, err +// } +// return c.ReadCodeLine(250) +// +func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err os.Error) { + id = c.Next() + c.StartRequest(id) + err = c.PrintfLine(format, args...) + c.EndRequest(id) + if err != nil { + return 0, err + } + return id, nil +} diff --git a/src/pkg/net/textproto/writer.go b/src/pkg/net/textproto/writer.go new file mode 100644 index 000000000..4e705f6c3 --- /dev/null +++ b/src/pkg/net/textproto/writer.go @@ -0,0 +1,119 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "fmt" + "io" + "os" +) + +// A Writer implements convenience methods for writing +// requests or responses to a text protocol network connection. +type Writer struct { + W *bufio.Writer + dot *dotWriter +} + +// NewWriter returns a new Writer writing to w. +func NewWriter(w *bufio.Writer) *Writer { + return &Writer{W: w} +} + +var crnl = []byte{'\r', '\n'} +var dotcrnl = []byte{'.', '\r', '\n'} + +// PrintfLine writes the formatted output followed by \r\n. +func (w *Writer) PrintfLine(format string, args ...interface{}) os.Error { + w.closeDot() + fmt.Fprintf(w.W, format, args...) + w.W.Write(crnl) + return w.W.Flush() +} + +// DotWriter returns a writer that can be used to write a dot-encoding to w. +// It takes care of inserting leading dots when necessary, +// translating line-ending \n into \r\n, and adding the final .\r\n line +// when the DotWriter is closed. The caller should close the +// DotWriter before the next call to a method on w. +// +// See the documentation for Reader's DotReader method for details about dot-encoding. +func (w *Writer) DotWriter() io.WriteCloser { + w.closeDot() + w.dot = &dotWriter{w: w} + return w.dot +} + +func (w *Writer) closeDot() { + if w.dot != nil { + w.dot.Close() // sets w.dot = nil + } +} + +type dotWriter struct { + w *Writer + state int +} + +const ( + wstateBeginLine = iota // beginning of line; initial state; must be zero + wstateCR // wrote \r (possibly at end of line) + wstateData // writing data in middle of line +) + +func (d *dotWriter) Write(b []byte) (n int, err os.Error) { + bw := d.w.W + for n < len(b) { + c := b[n] + switch d.state { + case wstateBeginLine: + d.state = wstateData + if c == '.' { + // escape leading dot + bw.WriteByte('.') + } + fallthrough + + case wstateData: + if c == '\r' { + d.state = wstateCR + } + if c == '\n' { + bw.WriteByte('\r') + d.state = wstateBeginLine + } + + case wstateCR: + d.state = wstateData + if c == '\n' { + d.state = wstateBeginLine + } + } + if err = bw.WriteByte(c); err != nil { + break + } + n++ + } + return +} + +func (d *dotWriter) Close() os.Error { + if d.w.dot == d { + d.w.dot = nil + } + bw := d.w.W + switch d.state { + default: + bw.WriteByte('\r') + fallthrough + case wstateCR: + bw.WriteByte('\n') + fallthrough + case wstateBeginLine: + bw.Write(dotcrnl) + } + return bw.Flush() +} diff --git a/src/pkg/net/textproto/writer_test.go b/src/pkg/net/textproto/writer_test.go new file mode 100644 index 000000000..e03ab5e15 --- /dev/null +++ b/src/pkg/net/textproto/writer_test.go @@ -0,0 +1,35 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "bytes" + "testing" +) + +func TestPrintfLine(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(bufio.NewWriter(&buf)) + err := w.PrintfLine("foo %d", 123) + if s := buf.String(); s != "foo 123\r\n" || err != nil { + t.Fatalf("s=%q; err=%s", s, err) + } +} + +func TestDotWriter(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(bufio.NewWriter(&buf)) + d := w.DotWriter() + n, err := d.Write([]byte("abc\n.def\n..ghi\n.jkl\n.")) + if n != 21 || err != nil { + t.Fatalf("Write: %d, %s", n, err) + } + d.Close() + want := "abc\r\n..def\r\n...ghi\r\n..jkl\r\n..\r\n.\r\n" + if s := buf.String(); s != want { + t.Fatalf("wrote %q", s) + } +} diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go new file mode 100644 index 000000000..0dbab5846 --- /dev/null +++ b/src/pkg/net/timeout_test.go @@ -0,0 +1,57 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "testing" + "time" +) + +func testTimeout(t *testing.T, network, addr string, readFrom bool) { + fd, err := Dial(network, addr) + if err != nil { + t.Errorf("dial %s %s failed: %v", network, addr, err) + return + } + defer fd.Close() + t0 := time.Nanoseconds() + fd.SetReadTimeout(1e8) // 100ms + var b [100]byte + var n int + var err1 os.Error + if readFrom { + n, _, err1 = fd.(PacketConn).ReadFrom(b[0:]) + } else { + n, err1 = fd.Read(b[0:]) + } + t1 := time.Nanoseconds() + what := "Read" + if readFrom { + what = "ReadFrom" + } + if n != 0 || err1 == nil || !err1.(Error).Timeout() { + t.Errorf("fd.%s on %s %s did not return 0, timeout: %v, %v", what, network, addr, n, err1) + } + if t1-t0 < 0.5e8 || t1-t0 > 1.5e8 { + t.Errorf("fd.%s on %s %s took %f seconds, expected 0.1", what, network, addr, float64(t1-t0)/1e9) + } +} + +func TestTimeoutUDP(t *testing.T) { + testTimeout(t, "udp", "127.0.0.1:53", false) + testTimeout(t, "udp", "127.0.0.1:53", true) +} + +func TestTimeoutTCP(t *testing.T) { + // set up a listener that won't talk back + listening := make(chan string) + done := make(chan int) + go runServe(t, "tcp", "127.0.0.1:0", listening, done) + addr := <-listening + + testTimeout(t, "tcp", addr, false) + <-done +} diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go new file mode 100644 index 000000000..3dfa71675 --- /dev/null +++ b/src/pkg/net/udpsock.go @@ -0,0 +1,40 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// UDP sockets + +package net + +import ( + "os" +) + +// UDPAddr represents the address of a UDP end point. +type UDPAddr struct { + IP IP + Port int +} + +// Network returns the address's network name, "udp". +func (a *UDPAddr) Network() string { return "udp" } + +func (a *UDPAddr) String() string { + if a == nil { + return "<nil>" + } + return JoinHostPort(a.IP.String(), itoa(a.Port)) +} + +// ResolveUDPAddr parses addr as a UDP address of the form +// host:port and resolves domain names or port names to +// numeric addresses on the network net, which must be "udp", +// "udp4" or "udp6". A literal IPv6 host address must be +// enclosed in square brackets, as in "[::]:80". +func ResolveUDPAddr(net, addr string) (*UDPAddr, os.Error) { + ip, port, err := hostPortToIP(net, addr) + if err != nil { + return nil, err + } + return &UDPAddr{ip, port}, nil +} diff --git a/src/pkg/net/udpsock_plan9.go b/src/pkg/net/udpsock_plan9.go new file mode 100644 index 000000000..bb7196041 --- /dev/null +++ b/src/pkg/net/udpsock_plan9.go @@ -0,0 +1,187 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// UDP for Plan 9 + +package net + +import ( + "os" +) + +// UDPConn is the implementation of the Conn and PacketConn +// interfaces for UDP network connections. +type UDPConn struct { + plan9Conn +} + +// UDP-specific methods. + +// ReadFromUDP reads a UDP packet from c, copying the payload into b. +// It returns the number of bytes copied into b and the return address +// that was on the packet. +// +// ReadFromUDP can be made to time out and return an error with Timeout() == true +// after a fixed time limit; see SetTimeout and SetReadTimeout. +func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + if c.data == nil { + c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0) + if err != nil { + return 0, nil, err + } + } + buf := make([]byte, udpHeaderSize+len(b)) + m, err := c.data.Read(buf) + if err != nil { + return + } + if m < udpHeaderSize { + return 0, nil, os.NewError("short read reading UDP header") + } + buf = buf[:m] + + h, buf := unmarshalUDPHeader(buf) + n = copy(b, buf) + return n, &UDPAddr{h.raddr, int(h.rport)}, nil +} + +// ReadFrom implements the net.PacketConn ReadFrom method. +func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + return c.ReadFromUDP(b) +} + +// WriteToUDP writes a UDP packet to addr via c, copying the payload from b. +// +// WriteToUDP can be made to time out and return +// an error with Timeout() == true after a fixed time limit; +// see SetTimeout and SetWriteTimeout. +// On packet-oriented connections, write timeouts are rare. +func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + if c.data == nil { + c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0) + if err != nil { + return 0, err + } + } + h := new(udpHeader) + h.raddr = addr.IP.To16() + h.laddr = c.laddr.(*UDPAddr).IP.To16() + h.ifcaddr = IPv6zero // ignored (receive only) + h.rport = uint16(addr.Port) + h.lport = uint16(c.laddr.(*UDPAddr).Port) + + buf := make([]byte, udpHeaderSize+len(b)) + i := copy(buf, h.Bytes()) + copy(buf[i:], b) + return c.data.Write(buf) +} + +// WriteTo implements the net.PacketConn WriteTo method. +func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + a, ok := addr.(*UDPAddr) + if !ok { + return 0, &OpError{"writeto", "udp", addr, os.EINVAL} + } + return c.WriteToUDP(b, a) +} + +// DialUDP connects to the remote address raddr on the network net, +// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used +// as the local address for the connection. +func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err os.Error) { + switch net { + case "udp", "udp4", "udp6": + default: + return nil, UnknownNetworkError(net) + } + if raddr == nil { + return nil, &OpError{"dial", "udp", nil, errMissingAddress} + } + c1, err := dialPlan9(net, laddr, raddr) + if err != nil { + return + } + return &UDPConn{*c1}, nil +} + +const udpHeaderSize = 16*3 + 2*2 + +type udpHeader struct { + raddr, laddr, ifcaddr IP + rport, lport uint16 +} + +func (h *udpHeader) Bytes() []byte { + b := make([]byte, udpHeaderSize) + i := 0 + i += copy(b[i:i+16], h.raddr) + i += copy(b[i:i+16], h.laddr) + i += copy(b[i:i+16], h.ifcaddr) + b[i], b[i+1], i = byte(h.rport>>8), byte(h.rport), i+2 + b[i], b[i+1], i = byte(h.lport>>8), byte(h.lport), i+2 + return b +} + +func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) { + h := new(udpHeader) + h.raddr, b = IP(b[:16]), b[16:] + h.laddr, b = IP(b[:16]), b[16:] + h.ifcaddr, b = IP(b[:16]), b[16:] + h.rport, b = uint16(b[0])<<8|uint16(b[1]), b[2:] + h.lport, b = uint16(b[0])<<8|uint16(b[1]), b[2:] + return h, b +} + +// ListenUDP listens for incoming UDP packets addressed to the +// local address laddr. The returned connection c's ReadFrom +// and WriteTo methods can be used to receive and send UDP +// packets with per-packet addressing. +func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err os.Error) { + switch net { + case "udp", "udp4", "udp6": + default: + return nil, UnknownNetworkError(net) + } + if laddr == nil { + return nil, &OpError{"listen", "udp", nil, errMissingAddress} + } + l, err := listenPlan9(net, laddr) + if err != nil { + return + } + _, err = l.ctl.WriteString("headers") + if err != nil { + return + } + return &UDPConn{*l.plan9Conn()}, nil +} + +// JoinGroup joins the IPv4 multicast group named by addr. +// The UDPConn must use the "udp4" network. +func (c *UDPConn) JoinGroup(addr IP) os.Error { + if !c.ok() { + return os.EINVAL + } + return os.EPLAN9 +} + +// LeaveGroup exits the IPv4 multicast group named by addr. +func (c *UDPConn) LeaveGroup(addr IP) os.Error { + if !c.ok() { + return os.EINVAL + } + return os.EPLAN9 +} diff --git a/src/pkg/net/udpsock_posix.go b/src/pkg/net/udpsock_posix.go new file mode 100644 index 000000000..d4ea056f3 --- /dev/null +++ b/src/pkg/net/udpsock_posix.go @@ -0,0 +1,294 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// UDP sockets + +package net + +import ( + "os" + "syscall" +) + +func sockaddrToUDP(sa syscall.Sockaddr) Addr { + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + return &UDPAddr{sa.Addr[0:], sa.Port} + case *syscall.SockaddrInet6: + return &UDPAddr{sa.Addr[0:], sa.Port} + } + return nil +} + +func (a *UDPAddr) family() int { + if a == nil || len(a.IP) <= 4 { + return syscall.AF_INET + } + if a.IP.To4() != nil { + return syscall.AF_INET + } + return syscall.AF_INET6 +} + +func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) { + return ipToSockaddr(family, a.IP, a.Port) +} + +func (a *UDPAddr) toAddr() sockaddr { + if a == nil { // nil *UDPAddr + return nil // nil interface + } + return a +} + +// UDPConn is the implementation of the Conn and PacketConn +// interfaces for UDP network connections. +type UDPConn struct { + fd *netFD +} + +func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{fd} } + +func (c *UDPConn) ok() bool { return c != nil && c.fd != nil } + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *UDPConn) Read(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Read(b) +} + +// Write implements the net.Conn Write method. +func (c *UDPConn) Write(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Write(b) +} + +// Close closes the UDP connection. +func (c *UDPConn) Close() os.Error { + if !c.ok() { + return os.EINVAL + } + err := c.fd.Close() + c.fd = nil + return err +} + +// LocalAddr returns the local network address. +func (c *UDPConn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.laddr +} + +// RemoteAddr returns the remote network address, a *UDPAddr. +func (c *UDPConn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.raddr +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *UDPConn) SetTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setTimeout(c.fd, nsec) +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *UDPConn) SetReadTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadTimeout(c.fd, nsec) +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *UDPConn) SetWriteTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteTimeout(c.fd, nsec) +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *UDPConn) SetReadBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadBuffer(c.fd, bytes) +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *UDPConn) SetWriteBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteBuffer(c.fd, bytes) +} + +// UDP-specific methods. + +// ReadFromUDP reads a UDP packet from c, copying the payload into b. +// It returns the number of bytes copied into b and the return address +// that was on the packet. +// +// ReadFromUDP can be made to time out and return an error with Timeout() == true +// after a fixed time limit; see SetTimeout and SetReadTimeout. +func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + n, sa, err := c.fd.ReadFrom(b) + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + addr = &UDPAddr{sa.Addr[0:], sa.Port} + case *syscall.SockaddrInet6: + addr = &UDPAddr{sa.Addr[0:], sa.Port} + } + return +} + +// ReadFrom implements the net.PacketConn ReadFrom method. +func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + n, uaddr, err := c.ReadFromUDP(b) + return n, uaddr.toAddr(), err +} + +// WriteToUDP writes a UDP packet to addr via c, copying the payload from b. +// +// WriteToUDP can be made to time out and return +// an error with Timeout() == true after a fixed time limit; +// see SetTimeout and SetWriteTimeout. +// On packet-oriented connections, write timeouts are rare. +func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + sa, err1 := addr.sockaddr(c.fd.family) + if err1 != nil { + return 0, &OpError{Op: "write", Net: "udp", Addr: addr, Error: err1} + } + return c.fd.WriteTo(b, sa) +} + +// WriteTo implements the net.PacketConn WriteTo method. +func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + a, ok := addr.(*UDPAddr) + if !ok { + return 0, &OpError{"writeto", "udp", addr, os.EINVAL} + } + return c.WriteToUDP(b, a) +} + +// DialUDP connects to the remote address raddr on the network net, +// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used +// as the local address for the connection. +func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err os.Error) { + switch net { + case "udp", "udp4", "udp6": + default: + return nil, UnknownNetworkError(net) + } + if raddr == nil { + return nil, &OpError{"dial", "udp", nil, errMissingAddress} + } + fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP) + if e != nil { + return nil, e + } + return newUDPConn(fd), nil +} + +// ListenUDP listens for incoming UDP packets addressed to the +// local address laddr. The returned connection c's ReadFrom +// and WriteTo methods can be used to receive and send UDP +// packets with per-packet addressing. +func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err os.Error) { + switch net { + case "udp", "udp4", "udp6": + default: + return nil, UnknownNetworkError(net) + } + if laddr == nil { + return nil, &OpError{"listen", "udp", nil, errMissingAddress} + } + fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP) + if e != nil { + return nil, e + } + return newUDPConn(fd), nil +} + +// BindToDevice binds a UDPConn to a network interface. +func (c *UDPConn) BindToDevice(device string) os.Error { + if !c.ok() { + return os.EINVAL + } + c.fd.incref() + defer c.fd.decref() + return os.NewSyscallError("setsockopt", syscall.BindToDevice(c.fd.sysfd, device)) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *UDPConn) File() (f *os.File, err os.Error) { return c.fd.dup() } + +var errInvalidMulticast = os.NewError("invalid IPv4 multicast address") + +// JoinGroup joins the IPv4 multicast group named by addr. +// The UDPConn must use the "udp4" network. +func (c *UDPConn) JoinGroup(addr IP) os.Error { + if !c.ok() { + return os.EINVAL + } + ip := addr.To4() + if ip == nil { + return &OpError{"joingroup", "udp", &IPAddr{ip}, errInvalidMulticast} + } + mreq := &syscall.IPMreq{ + Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}, + } + err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(c.fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)) + if err != nil { + return &OpError{"joingroup", "udp", &IPAddr{ip}, err} + } + return nil +} + +// LeaveGroup exits the IPv4 multicast group named by addr. +func (c *UDPConn) LeaveGroup(addr IP) os.Error { + if !c.ok() { + return os.EINVAL + } + ip := addr.To4() + if ip == nil { + return &OpError{"leavegroup", "udp", &IPAddr{ip}, errInvalidMulticast} + } + mreq := &syscall.IPMreq{ + Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}, + } + err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(c.fd.sysfd, syscall.IPPROTO_IP, syscall.IP_DROP_MEMBERSHIP, mreq)) + if err != nil { + return &OpError{"leavegroup", "udp", &IPAddr{ip}, err} + } + return nil +} diff --git a/src/pkg/net/unixsock.go b/src/pkg/net/unixsock.go new file mode 100644 index 000000000..d5040f9a2 --- /dev/null +++ b/src/pkg/net/unixsock.go @@ -0,0 +1,50 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Unix domain sockets + +package net + +import ( + "os" +) + +// UnixAddr represents the address of a Unix domain socket end point. +type UnixAddr struct { + Name string + Net string +} + +// Network returns the address's network name, "unix" or "unixgram". +func (a *UnixAddr) Network() string { + return a.Net +} + +func (a *UnixAddr) String() string { + if a == nil { + return "<nil>" + } + return a.Name +} + +func (a *UnixAddr) toAddr() Addr { + if a == nil { // nil *UnixAddr + return nil // nil interface + } + return a +} + +// ResolveUnixAddr parses addr as a Unix domain socket address. +// The string net gives the network name, "unix", "unixgram" or +// "unixpacket". +func ResolveUnixAddr(net, addr string) (*UnixAddr, os.Error) { + switch net { + case "unix": + case "unixpacket": + case "unixgram": + default: + return nil, UnknownNetworkError(net) + } + return &UnixAddr{addr, net}, nil +} diff --git a/src/pkg/net/unixsock_plan9.go b/src/pkg/net/unixsock_plan9.go new file mode 100644 index 000000000..7e212df8a --- /dev/null +++ b/src/pkg/net/unixsock_plan9.go @@ -0,0 +1,105 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Unix domain sockets stubs for Plan 9 + +package net + +import ( + "os" +) + +// UnixConn is an implementation of the Conn interface +// for connections to Unix domain sockets. +type UnixConn bool + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *UnixConn) Read(b []byte) (n int, err os.Error) { + return 0, os.EPLAN9 +} + +// Write implements the net.Conn Write method. +func (c *UnixConn) Write(b []byte) (n int, err os.Error) { + return 0, os.EPLAN9 +} + +// Close closes the Unix domain connection. +func (c *UnixConn) Close() os.Error { + return os.EPLAN9 +} + +// LocalAddr returns the local network address, a *UnixAddr. +// Unlike in other protocols, LocalAddr is usually nil for dialed connections. +func (c *UnixConn) LocalAddr() Addr { + return nil +} + +// RemoteAddr returns the remote network address, a *UnixAddr. +// Unlike in other protocols, RemoteAddr is usually nil for connections +// accepted by a listener. +func (c *UnixConn) RemoteAddr() Addr { + return nil +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *UnixConn) SetTimeout(nsec int64) os.Error { + return os.EPLAN9 +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *UnixConn) SetReadTimeout(nsec int64) os.Error { + return os.EPLAN9 +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *UnixConn) SetWriteTimeout(nsec int64) os.Error { + return os.EPLAN9 +} + +// ReadFrom implements the net.PacketConn ReadFrom method. +func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { + err = os.EPLAN9 + return +} + +// WriteTo implements the net.PacketConn WriteTo method. +func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { + err = os.EPLAN9 + return +} + +// DialUnix connects to the remote address raddr on the network net, +// which must be "unix" or "unixgram". If laddr is not nil, it is used +// as the local address for the connection. +func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err os.Error) { + return nil, os.EPLAN9 +} + +// UnixListener is a Unix domain socket listener. +// Clients should typically use variables of type Listener +// instead of assuming Unix domain sockets. +type UnixListener bool + +// ListenUnix announces on the Unix domain socket laddr and returns a Unix listener. +// Net must be "unix" (stream sockets). +func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) { + return nil, os.EPLAN9 +} + +// Accept implements the Accept method in the Listener interface; +// it waits for the next call and returns a generic Conn. +func (l *UnixListener) Accept() (c Conn, err os.Error) { + return nil, os.EPLAN9 +} + +// Close stops listening on the Unix address. +// Already accepted connections are not closed. +func (l *UnixListener) Close() os.Error { + return os.EPLAN9 +} + +// Addr returns the listener's network address. +func (l *UnixListener) Addr() Addr { return nil } diff --git a/src/pkg/net/unixsock_posix.go b/src/pkg/net/unixsock_posix.go new file mode 100644 index 000000000..38c6fe9eb --- /dev/null +++ b/src/pkg/net/unixsock_posix.go @@ -0,0 +1,418 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Unix domain sockets + +package net + +import ( + "os" + "syscall" +) + +func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err os.Error) { + var proto int + switch net { + default: + return nil, UnknownNetworkError(net) + case "unix": + proto = syscall.SOCK_STREAM + case "unixgram": + proto = syscall.SOCK_DGRAM + case "unixpacket": + proto = syscall.SOCK_SEQPACKET + } + + var la, ra syscall.Sockaddr + switch mode { + default: + panic("unixSocket mode " + mode) + + case "dial": + if laddr != nil { + la = &syscall.SockaddrUnix{Name: laddr.Name} + } + if raddr != nil { + ra = &syscall.SockaddrUnix{Name: raddr.Name} + } else if proto != syscall.SOCK_DGRAM || laddr == nil { + return nil, &OpError{Op: mode, Net: net, Error: errMissingAddress} + } + + case "listen": + if laddr == nil { + return nil, &OpError{mode, net, nil, errMissingAddress} + } + la = &syscall.SockaddrUnix{Name: laddr.Name} + if raddr != nil { + return nil, &OpError{Op: mode, Net: net, Addr: raddr, Error: &AddrError{Error: "unexpected remote address", Addr: raddr.String()}} + } + } + + f := sockaddrToUnix + if proto == syscall.SOCK_DGRAM { + f = sockaddrToUnixgram + } else if proto == syscall.SOCK_SEQPACKET { + f = sockaddrToUnixpacket + } + + fd, oserr := socket(net, syscall.AF_UNIX, proto, 0, la, ra, f) + if oserr != nil { + goto Error + } + return fd, nil + +Error: + addr := raddr + if mode == "listen" { + addr = laddr + } + return nil, &OpError{Op: mode, Net: net, Addr: addr, Error: oserr} +} + +func sockaddrToUnix(sa syscall.Sockaddr) Addr { + if s, ok := sa.(*syscall.SockaddrUnix); ok { + return &UnixAddr{s.Name, "unix"} + } + return nil +} + +func sockaddrToUnixgram(sa syscall.Sockaddr) Addr { + if s, ok := sa.(*syscall.SockaddrUnix); ok { + return &UnixAddr{s.Name, "unixgram"} + } + return nil +} + +func sockaddrToUnixpacket(sa syscall.Sockaddr) Addr { + if s, ok := sa.(*syscall.SockaddrUnix); ok { + return &UnixAddr{s.Name, "unixpacket"} + } + return nil +} + +func protoToNet(proto int) string { + switch proto { + case syscall.SOCK_STREAM: + return "unix" + case syscall.SOCK_SEQPACKET: + return "unixpacket" + case syscall.SOCK_DGRAM: + return "unixgram" + default: + panic("protoToNet unknown protocol") + } + return "" +} + +// UnixConn is an implementation of the Conn interface +// for connections to Unix domain sockets. +type UnixConn struct { + fd *netFD +} + +func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{fd} } + +func (c *UnixConn) ok() bool { return c != nil && c.fd != nil } + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *UnixConn) Read(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Read(b) +} + +// Write implements the net.Conn Write method. +func (c *UnixConn) Write(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Write(b) +} + +// Close closes the Unix domain connection. +func (c *UnixConn) Close() os.Error { + if !c.ok() { + return os.EINVAL + } + err := c.fd.Close() + c.fd = nil + return err +} + +// LocalAddr returns the local network address, a *UnixAddr. +// Unlike in other protocols, LocalAddr is usually nil for dialed connections. +func (c *UnixConn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.laddr +} + +// RemoteAddr returns the remote network address, a *UnixAddr. +// Unlike in other protocols, RemoteAddr is usually nil for connections +// accepted by a listener. +func (c *UnixConn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.raddr +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *UnixConn) SetTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setTimeout(c.fd, nsec) +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *UnixConn) SetReadTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadTimeout(c.fd, nsec) +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *UnixConn) SetWriteTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteTimeout(c.fd, nsec) +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *UnixConn) SetReadBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadBuffer(c.fd, bytes) +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *UnixConn) SetWriteBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteBuffer(c.fd, bytes) +} + +// ReadFromUnix reads a packet from c, copying the payload into b. +// It returns the number of bytes copied into b and the return address +// that was on the packet. +// +// ReadFromUnix can be made to time out and return +// an error with Timeout() == true after a fixed time limit; +// see SetTimeout and SetReadTimeout. +func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + n, sa, err := c.fd.ReadFrom(b) + switch sa := sa.(type) { + case *syscall.SockaddrUnix: + addr = &UnixAddr{sa.Name, protoToNet(c.fd.proto)} + } + return +} + +// ReadFrom implements the net.PacketConn ReadFrom method. +func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + n, uaddr, err := c.ReadFromUnix(b) + return n, uaddr.toAddr(), err +} + +// WriteToUnix writes a packet to addr via c, copying the payload from b. +// +// WriteToUnix can be made to time out and return +// an error with Timeout() == true after a fixed time limit; +// see SetTimeout and SetWriteTimeout. +// On packet-oriented connections, write timeouts are rare. +func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + if addr.Net != protoToNet(c.fd.proto) { + return 0, os.EAFNOSUPPORT + } + sa := &syscall.SockaddrUnix{Name: addr.Name} + return c.fd.WriteTo(b, sa) +} + +// WriteTo implements the net.PacketConn WriteTo method. +func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + a, ok := addr.(*UnixAddr) + if !ok { + return 0, &OpError{"writeto", "unix", addr, os.EINVAL} + } + return c.WriteToUnix(b, a) +} + +func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err os.Error) { + if !c.ok() { + return 0, 0, 0, nil, os.EINVAL + } + n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob) + switch sa := sa.(type) { + case *syscall.SockaddrUnix: + addr = &UnixAddr{sa.Name, protoToNet(c.fd.proto)} + } + return +} + +func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err os.Error) { + if !c.ok() { + return 0, 0, os.EINVAL + } + if addr != nil { + if addr.Net != protoToNet(c.fd.proto) { + return 0, 0, os.EAFNOSUPPORT + } + sa := &syscall.SockaddrUnix{Name: addr.Name} + return c.fd.WriteMsg(b, oob, sa) + } + return c.fd.WriteMsg(b, oob, nil) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *UnixConn) File() (f *os.File, err os.Error) { return c.fd.dup() } + +// DialUnix connects to the remote address raddr on the network net, +// which must be "unix" or "unixgram". If laddr is not nil, it is used +// as the local address for the connection. +func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err os.Error) { + fd, e := unixSocket(net, laddr, raddr, "dial") + if e != nil { + return nil, e + } + return newUnixConn(fd), nil +} + +// UnixListener is a Unix domain socket listener. +// Clients should typically use variables of type Listener +// instead of assuming Unix domain sockets. +type UnixListener struct { + fd *netFD + path string +} + +// ListenUnix announces on the Unix domain socket laddr and returns a Unix listener. +// Net must be "unix" (stream sockets). +func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) { + if net != "unix" && net != "unixgram" && net != "unixpacket" { + return nil, UnknownNetworkError(net) + } + if laddr != nil { + laddr = &UnixAddr{laddr.Name, net} // make our own copy + } + fd, err := unixSocket(net, laddr, nil, "listen") + if err != nil { + return nil, err + } + e1 := syscall.Listen(fd.sysfd, 8) // listenBacklog()); + if e1 != 0 { + closesocket(fd.sysfd) + return nil, &OpError{Op: "listen", Net: "unix", Addr: laddr, Error: os.Errno(e1)} + } + return &UnixListener{fd, laddr.Name}, nil +} + +// AcceptUnix accepts the next incoming call and returns the new connection +// and the remote address. +func (l *UnixListener) AcceptUnix() (c *UnixConn, err os.Error) { + if l == nil || l.fd == nil { + return nil, os.EINVAL + } + fd, e := l.fd.accept(sockaddrToUnix) + if e != nil { + return nil, e + } + c = newUnixConn(fd) + return c, nil +} + +// Accept implements the Accept method in the Listener interface; +// it waits for the next call and returns a generic Conn. +func (l *UnixListener) Accept() (c Conn, err os.Error) { + c1, err := l.AcceptUnix() + if err != nil { + return nil, err + } + return c1, nil +} + +// Close stops listening on the Unix address. +// Already accepted connections are not closed. +func (l *UnixListener) Close() os.Error { + if l == nil || l.fd == nil { + return os.EINVAL + } + + // The operating system doesn't clean up + // the file that announcing created, so + // we have to clean it up ourselves. + // There's a race here--we can't know for + // sure whether someone else has come along + // and replaced our socket name already-- + // but this sequence (remove then close) + // is at least compatible with the auto-remove + // sequence in ListenUnix. It's only non-Go + // programs that can mess us up. + if l.path[0] != '@' { + syscall.Unlink(l.path) + } + err := l.fd.Close() + l.fd = nil + return err +} + +// Addr returns the listener's network address. +func (l *UnixListener) Addr() Addr { return l.fd.laddr } + +// SetTimeout sets the deadline associated wuth the listener +func (l *UnixListener) SetTimeout(nsec int64) (err os.Error) { + if l == nil || l.fd == nil { + return os.EINVAL + } + return setTimeout(l.fd, nsec) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (l *UnixListener) File() (f *os.File, err os.Error) { return l.fd.dup() } + +// ListenUnixgram listens for incoming Unix datagram packets addressed to the +// local address laddr. The returned connection c's ReadFrom +// and WriteTo methods can be used to receive and send UDP +// packets with per-packet addressing. The network net must be "unixgram". +func ListenUnixgram(net string, laddr *UnixAddr) (c *UDPConn, err os.Error) { + switch net { + case "unixgram": + default: + return nil, UnknownNetworkError(net) + } + if laddr == nil { + return nil, &OpError{"listen", "unixgram", nil, errMissingAddress} + } + fd, e := unixSocket(net, laddr, nil, "listen") + if e != nil { + return nil, e + } + return newUDPConn(fd), nil +} |