diff options
Diffstat (limited to 'src/pkg/net')
175 files changed, 22537 insertions, 1873 deletions
diff --git a/src/pkg/net/Makefile b/src/pkg/net/Makefile index eba9e26d9..a02798c73 100644 --- a/src/pkg/net/Makefile +++ b/src/pkg/net/Makefile @@ -9,6 +9,7 @@ GOFILES=\ dial.go\ dnsclient.go\ dnsmsg.go\ + doc.go\ hosts.go\ interface.go\ ip.go\ @@ -21,14 +22,14 @@ GOFILES=\ udpsock.go\ unixsock.go\ -GOFILES_freebsd=\ +GOFILES_darwin=\ dnsclient_unix.go\ dnsconfig.go\ fd.go\ fd_$(GOOS).go\ file.go\ interface_bsd.go\ - interface_freebsd.go\ + interface_darwin.go\ iprawsock_posix.go\ ipsock_posix.go\ lookup_unix.go\ @@ -37,26 +38,31 @@ GOFILES_freebsd=\ sendfile_stub.go\ sock.go\ sock_bsd.go\ + sockopt.go\ + sockopt_bsd.go\ + sockoptip.go\ + sockoptip_bsd.go\ + sockoptip_darwin.go\ tcpsock_posix.go\ udpsock_posix.go\ unixsock_posix.go\ ifeq ($(CGO_ENABLED),1) -CGOFILES_freebsd=\ +CGOFILES_darwin=\ cgo_bsd.go\ cgo_unix.go else -GOFILES_freebsd+=cgo_stub.go +GOFILES_darwin+=cgo_stub.go endif -GOFILES_darwin=\ +GOFILES_freebsd=\ dnsclient_unix.go\ dnsconfig.go\ fd.go\ fd_$(GOOS).go\ file.go\ interface_bsd.go\ - interface_darwin.go\ + interface_freebsd.go\ iprawsock_posix.go\ ipsock_posix.go\ lookup_unix.go\ @@ -65,16 +71,21 @@ GOFILES_darwin=\ sendfile_stub.go\ sock.go\ sock_bsd.go\ + sockopt.go\ + sockopt_bsd.go\ + sockoptip.go\ + sockoptip_bsd.go\ + sockoptip_freebsd.go\ tcpsock_posix.go\ udpsock_posix.go\ unixsock_posix.go\ ifeq ($(CGO_ENABLED),1) -CGOFILES_darwin=\ +CGOFILES_freebsd=\ cgo_bsd.go\ cgo_unix.go else -GOFILES_darwin+=cgo_stub.go +GOFILES_freebsd+=cgo_stub.go endif GOFILES_linux=\ @@ -92,6 +103,10 @@ GOFILES_linux=\ sendfile_linux.go\ sock.go\ sock_linux.go\ + sockopt.go\ + sockopt_linux.go\ + sockoptip.go\ + sockoptip_linux.go\ tcpsock_posix.go\ udpsock_posix.go\ unixsock_posix.go\ @@ -104,6 +119,32 @@ else GOFILES_linux+=cgo_stub.go endif +GOFILES_netbsd=\ + dnsclient_unix.go\ + dnsconfig.go\ + fd.go\ + fd_$(GOOS).go\ + file.go\ + interface_bsd.go\ + interface_netbsd.go\ + iprawsock_posix.go\ + ipsock_posix.go\ + lookup_unix.go\ + newpollserver.go\ + port.go\ + sendfile_stub.go\ + sock.go\ + sock_bsd.go\ + sockopt.go\ + sockopt_bsd.go\ + sockoptip.go\ + sockoptip_bsd.go\ + sockoptip_netbsd.go\ + tcpsock_posix.go\ + udpsock_posix.go\ + unixsock_posix.go\ + cgo_stub.go\ + GOFILES_openbsd=\ dnsclient_unix.go\ dnsconfig.go\ @@ -120,6 +161,11 @@ GOFILES_openbsd=\ sendfile_stub.go\ sock.go\ sock_bsd.go\ + sockopt.go\ + sockopt_bsd.go\ + sockoptip.go\ + sockoptip_bsd.go\ + sockoptip_openbsd.go\ tcpsock_posix.go\ udpsock_posix.go\ unixsock_posix.go\ @@ -145,6 +191,10 @@ GOFILES_windows=\ sendfile_windows.go\ sock.go\ sock_windows.go\ + sockopt.go\ + sockopt_windows.go\ + sockoptip.go\ + sockoptip_windows.go\ tcpsock_posix.go\ udpsock_posix.go\ unixsock_posix.go\ diff --git a/src/pkg/net/cgo_stub.go b/src/pkg/net/cgo_stub.go index 565cbe7fe..52e57d740 100644 --- a/src/pkg/net/cgo_stub.go +++ b/src/pkg/net/cgo_stub.go @@ -2,26 +2,24 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build openbsd +// +build !cgo // Stub cgo routines for systems that do not use cgo to do network lookups. package net -import "os" - -func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) { +func cgoLookupHost(name string) (addrs []string, err error, completed bool) { return nil, nil, false } -func cgoLookupPort(network, service string) (port int, err os.Error, completed bool) { +func cgoLookupPort(network, service string) (port int, err error, completed bool) { return 0, nil, false } -func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { +func cgoLookupIP(name string) (addrs []IP, err error, completed bool) { return nil, nil, false } -func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) { +func cgoLookupCNAME(name string) (cname string, err error, completed bool) { return "", nil, false } diff --git a/src/pkg/net/cgo_unix.go b/src/pkg/net/cgo_unix.go index ec2a393e8..36a3f3d34 100644 --- a/src/pkg/net/cgo_unix.go +++ b/src/pkg/net/cgo_unix.go @@ -18,12 +18,11 @@ package net import "C" import ( - "os" "syscall" "unsafe" ) -func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) { +func cgoLookupHost(name string) (addrs []string, err error, completed bool) { ip, err, completed := cgoLookupIP(name) for _, p := range ip { addrs = append(addrs, p.String()) @@ -31,7 +30,7 @@ func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) { return } -func cgoLookupPort(net, service string) (port int, err os.Error, completed bool) { +func cgoLookupPort(net, service string) (port int, err error, completed bool) { var res *C.struct_addrinfo var hints C.struct_addrinfo @@ -78,7 +77,7 @@ func cgoLookupPort(net, service string) (port int, err os.Error, completed bool) return 0, &AddrError{"unknown port", net + "/" + service}, true } -func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err os.Error, completed bool) { +func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err error, completed bool) { var res *C.struct_addrinfo var hints C.struct_addrinfo @@ -98,11 +97,11 @@ func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err os.Error, comp if gerrno == C.EAI_NONAME { str = noSuchHost } else if gerrno == C.EAI_SYSTEM { - str = err.String() + str = err.Error() } else { str = C.GoString(C.gai_strerror(gerrno)) } - return nil, "", &DNSError{Error: str, Name: name}, true + return nil, "", &DNSError{Err: str, Name: name}, true } defer C.freeaddrinfo(res) if res != nil { @@ -133,12 +132,12 @@ func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err os.Error, comp return addrs, cname, nil, true } -func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { +func cgoLookupIP(name string) (addrs []IP, err error, completed bool) { addrs, _, err, completed = cgoLookupIPCNAME(name) return } -func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) { +func cgoLookupCNAME(name string) (cname string, err error, completed bool) { _, cname, err, completed = cgoLookupIPCNAME(name) return } diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index 10c67dcc4..5d596bcb6 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -4,26 +4,63 @@ package net -import "os" +import ( + "time" +) -func resolveNetAddr(op, net, addr string) (a Addr, err os.Error) { - if addr == "" { - return nil, &OpError{op, net, nil, errMissingAddress} +func parseDialNetwork(net string) (afnet string, proto int, err error) { + i := last(net, ':') + if i < 0 { // no colon + switch net { + case "tcp", "tcp4", "tcp6": + case "udp", "udp4", "udp6": + case "unix", "unixgram", "unixpacket": + default: + return "", 0, UnknownNetworkError(net) + } + return net, 0, nil } - switch net { - case "tcp", "tcp4", "tcp6": - a, err = ResolveTCPAddr(net, addr) - case "udp", "udp4", "udp6": - a, err = ResolveUDPAddr(net, addr) - case "unix", "unixgram", "unixpacket": - a, err = ResolveUnixAddr(net, addr) + afnet = net[:i] + switch afnet { case "ip", "ip4", "ip6": - a, err = ResolveIPAddr(net, addr) - default: - err = UnknownNetworkError(net) + protostr := net[i+1:] + proto, i, ok := dtoi(protostr, 0) + if !ok || i != len(protostr) { + proto, err = lookupProtocol(protostr) + if err != nil { + return "", 0, err + } + } + return afnet, proto, nil } + return "", 0, UnknownNetworkError(net) +} + +func resolveNetAddr(op, net, addr string) (afnet string, a Addr, err error) { + afnet, _, err = parseDialNetwork(net) if err != nil { - return nil, &OpError{op, net + " " + addr, nil, err} + return "", nil, &OpError{op, net, nil, err} + } + if op == "dial" && addr == "" { + return "", nil, &OpError{op, net, nil, errMissingAddress} + } + switch afnet { + case "tcp", "tcp4", "tcp6": + if addr != "" { + a, err = ResolveTCPAddr(afnet, addr) + } + case "udp", "udp4", "udp6": + if addr != "" { + a, err = ResolveUDPAddr(afnet, addr) + } + case "ip", "ip4", "ip6": + if addr != "" { + a, err = ResolveIPAddr(afnet, addr) + } + case "unix", "unixgram", "unixpacket": + if addr != "" { + a, err = ResolveUnixAddr(afnet, addr) + } } return } @@ -32,122 +69,160 @@ func resolveNetAddr(op, net, addr string) (a Addr, err os.Error) { // // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), // "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" -// (IPv4-only), "ip6" (IPv6-only), "unix" and "unixgram". +// (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and "unixpacket". // -// For IP networks, addresses have the form host:port. If host is -// a literal IPv6 address, it must be enclosed in square brackets. -// The functions JoinHostPort and SplitHostPort manipulate -// addresses in this form. +// For TCP and UDP networks, addresses have the form host:port. +// If host is a literal IPv6 address, it must be enclosed +// in square brackets. The functions JoinHostPort and SplitHostPort +// manipulate addresses in this form. // // Examples: // Dial("tcp", "12.34.56.78:80") // Dial("tcp", "google.com:80") // Dial("tcp", "[de:ad:be:ef::ca:fe]:80") // -func Dial(net, addr string) (c Conn, err os.Error) { - addri, err := resolveNetAddr("dial", net, addr) +// For IP networks, addr must be "ip", "ip4" or "ip6" followed +// by a colon and a protocol number or name. +// +// Examples: +// Dial("ip4:1", "127.0.0.1") +// Dial("ip6:ospf", "::1") +// +func Dial(net, addr string) (Conn, error) { + _, addri, err := resolveNetAddr("dial", net, addr) if err != nil { return nil, err } + return dialAddr(net, addr, addri) +} + +func dialAddr(net, addr string, addri Addr) (c Conn, err error) { switch ra := addri.(type) { case *TCPAddr: c, err = DialTCP(net, nil, ra) case *UDPAddr: c, err = DialUDP(net, nil, ra) - case *UnixAddr: - c, err = DialUnix(net, nil, ra) case *IPAddr: c, err = DialIP(net, nil, ra) + case *UnixAddr: + c, err = DialUnix(net, nil, ra) default: - err = UnknownNetworkError(net) + err = &OpError{"dial", net + " " + addr, nil, UnknownNetworkError(net)} } if err != nil { - return nil, &OpError{"dial", net + " " + addr, nil, err} + return nil, err } return } +// DialTimeout acts like Dial but takes a timeout. +// The timeout includes name resolution, if required. +func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) { + // TODO(bradfitz): the timeout should be pushed down into the + // net package's event loop, so on timeout to dead hosts we + // don't have a goroutine sticking around for the default of + // ~3 minutes. + t := time.NewTimer(timeout) + defer t.Stop() + type pair struct { + Conn + error + } + ch := make(chan pair, 1) + resolvedAddr := make(chan Addr, 1) + go func() { + _, addri, err := resolveNetAddr("dial", net, addr) + if err != nil { + ch <- pair{nil, err} + return + } + resolvedAddr <- addri // in case we need it for OpError + c, err := dialAddr(net, addr, addri) + ch <- pair{c, err} + }() + select { + case <-t.C: + // Try to use the real Addr in our OpError, if we resolved it + // before the timeout. Otherwise we just use stringAddr. + var addri Addr + select { + case a := <-resolvedAddr: + addri = a + default: + addri = &stringAddr{net, addr} + } + err := &OpError{ + Op: "dial", + Net: net, + Addr: addri, + Err: &timeoutError{}, + } + return nil, err + case p := <-ch: + return p.Conn, p.error + } + panic("unreachable") +} + +type stringAddr struct { + net, addr string +} + +func (a stringAddr) Network() string { return a.net } +func (a stringAddr) String() string { return a.addr } + // Listen announces on the local network address laddr. -// The network string net must be a stream-oriented -// network: "tcp", "tcp4", "tcp6", or "unix", or "unixpacket". -func Listen(net, laddr string) (l Listener, err os.Error) { - switch net { +// The network string net must be a stream-oriented network: +// "tcp", "tcp4", "tcp6", or "unix", or "unixpacket". +func Listen(net, laddr string) (Listener, error) { + afnet, a, err := resolveNetAddr("listen", net, laddr) + if err != nil { + return nil, err + } + switch afnet { case "tcp", "tcp4", "tcp6": var la *TCPAddr - if laddr != "" { - if la, err = ResolveTCPAddr(net, laddr); err != nil { - return nil, err - } - } - l, err := ListenTCP(net, la) - if err != nil { - return nil, err + if a != nil { + la = a.(*TCPAddr) } - return l, nil + return ListenTCP(afnet, la) case "unix", "unixpacket": var la *UnixAddr - if laddr != "" { - if la, err = ResolveUnixAddr(net, laddr); err != nil { - return nil, err - } - } - l, err := ListenUnix(net, la) - if err != nil { - return nil, err + if a != nil { + la = a.(*UnixAddr) } - return l, nil + return ListenUnix(net, la) } return nil, UnknownNetworkError(net) } // ListenPacket announces on the local network address laddr. // The network string net must be a packet-oriented network: -// "udp", "udp4", "udp6", or "unixgram". -func ListenPacket(net, laddr string) (c PacketConn, err os.Error) { - switch net { +// "udp", "udp4", "udp6", "ip", "ip4", "ip6" or "unixgram". +func ListenPacket(net, addr string) (PacketConn, error) { + afnet, a, err := resolveNetAddr("listen", net, addr) + if err != nil { + return nil, err + } + switch afnet { case "udp", "udp4", "udp6": var la *UDPAddr - if laddr != "" { - if la, err = ResolveUDPAddr(net, laddr); err != nil { - return nil, err - } + if a != nil { + la = a.(*UDPAddr) } - c, err := ListenUDP(net, la) - if err != nil { - return nil, err + return ListenUDP(net, la) + case "ip", "ip4", "ip6": + var la *IPAddr + if a != nil { + la = a.(*IPAddr) } - return c, nil + return ListenIP(net, la) case "unixgram": var la *UnixAddr - if laddr != "" { - if la, err = ResolveUnixAddr(net, laddr); err != nil { - return nil, err - } - } - c, err := DialUnix(net, la, nil) - if err != nil { - return nil, err - } - return c, nil - } - - var rawnet string - if rawnet, _, err = splitNetProto(net); err != nil { - switch rawnet { - case "ip", "ip4", "ip6": - var la *IPAddr - if laddr != "" { - if la, err = ResolveIPAddr(rawnet, laddr); err != nil { - return nil, err - } - } - c, err := ListenIP(net, la) - if err != nil { - return nil, err - } - return c, nil + if a != nil { + la = a.(*UnixAddr) } + return DialUnix(net, la, nil) } - return nil, UnknownNetworkError(net) } diff --git a/src/pkg/net/dial_test.go b/src/pkg/net/dial_test.go new file mode 100644 index 000000000..16b726311 --- /dev/null +++ b/src/pkg/net/dial_test.go @@ -0,0 +1,88 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "runtime" + "testing" + "time" +) + +func newLocalListener(t *testing.T) Listener { + ln, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} + +func TestDialTimeout(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + errc := make(chan error) + + const SOMAXCONN = 0x80 // copied from syscall, but not always available + const numConns = SOMAXCONN + 10 + + // TODO(bradfitz): It's hard to test this in a portable + // way. This is unforunate, but works for now. + switch runtime.GOOS { + case "linux": + // The kernel will start accepting TCP connections before userspace + // gets a chance to not accept them, so fire off a bunch to fill up + // the kernel's backlog. Then we test we get a failure after that. + for i := 0; i < numConns; i++ { + go func() { + _, err := DialTimeout("tcp", ln.Addr().String(), 200*time.Millisecond) + errc <- err + }() + } + case "darwin": + // At least OS X 10.7 seems to accept any number of + // connections, ignoring listen's backlog, so resort + // to connecting to a hopefully-dead 127/8 address. + go func() { + _, err := DialTimeout("tcp", "127.0.71.111:80", 200*time.Millisecond) + errc <- err + }() + default: + // TODO(bradfitz): this probably doesn't work on + // Windows? SOMAXCONN is huge there. I'm not sure how + // listen works there. + // OpenBSD may have a reject route to 10/8. + // FreeBSD likely works, but is untested. + t.Logf("skipping test on %q; untested.", runtime.GOOS) + return + } + + connected := 0 + for { + select { + case <-time.After(15 * time.Second): + t.Fatal("too slow") + case err := <-errc: + if err == nil { + connected++ + if connected == numConns { + t.Fatal("all connections connected; expected some to time out") + } + } else { + terr, ok := err.(timeout) + if !ok { + t.Fatalf("got error %q; want error with timeout interface", err) + } + if !terr.Timeout() { + t.Fatalf("got error %q; not a timeout", err) + } + // Pass. We saw a timeout error. + return + } + } + } +} diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go index 9ad1770da..81750a3d7 100644 --- a/src/pkg/net/dialgoogle_test.go +++ b/src/pkg/net/dialgoogle_test.go @@ -19,7 +19,7 @@ var ipv6 = flag.Bool("ipv6", false, "assume ipv6 tunnel is present") // fd is already connected to the destination, port 80. // Run an HTTP request to fetch the appropriate page. func fetchGoogle(t *testing.T, fd Conn, network, addr string) { - req := []byte("GET /intl/en/privacy/ HTTP/1.0\r\nHost: www.google.com\r\n\r\n") + req := []byte("GET /robots.txt HTTP/1.0\r\nHost: www.google.com\r\n\r\n") n, err := fd.Write(req) buf := make([]byte, 1000) diff --git a/src/pkg/net/dict/Makefile b/src/pkg/net/dict/Makefile deleted file mode 100644 index eaa9e6531..000000000 --- a/src/pkg/net/dict/Makefile +++ /dev/null @@ -1,7 +0,0 @@ -include ../../../Make.inc - -TARG=net/dict -GOFILES=\ - dict.go\ - -include ../../../Make.pkg diff --git a/src/pkg/net/dict/dict.go b/src/pkg/net/dict/dict.go deleted file mode 100644 index b146ea212..000000000 --- a/src/pkg/net/dict/dict.go +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package dict implements the Dictionary Server Protocol -// as defined in RFC 2229. -package dict - -import ( - "net/textproto" - "os" - "strconv" - "strings" -) - -// A Client represents a client connection to a dictionary server. -type Client struct { - text *textproto.Conn -} - -// Dial returns a new client connected to a dictionary server at -// addr on the given network. -func Dial(network, addr string) (*Client, os.Error) { - text, err := textproto.Dial(network, addr) - if err != nil { - return nil, err - } - _, _, err = text.ReadCodeLine(220) - if err != nil { - text.Close() - return nil, err - } - return &Client{text: text}, nil -} - -// Close closes the connection to the dictionary server. -func (c *Client) Close() os.Error { - return c.text.Close() -} - -// A Dict represents a dictionary available on the server. -type Dict struct { - Name string // short name of dictionary - Desc string // long description -} - -// Dicts returns a list of the dictionaries available on the server. -func (c *Client) Dicts() ([]Dict, os.Error) { - id, err := c.text.Cmd("SHOW DB") - if err != nil { - return nil, err - } - - c.text.StartResponse(id) - defer c.text.EndResponse(id) - - _, _, err = c.text.ReadCodeLine(110) - if err != nil { - return nil, err - } - lines, err := c.text.ReadDotLines() - if err != nil { - return nil, err - } - _, _, err = c.text.ReadCodeLine(250) - - dicts := make([]Dict, len(lines)) - for i := range dicts { - d := &dicts[i] - a, _ := fields(lines[i]) - if len(a) < 2 { - return nil, textproto.ProtocolError("invalid dictionary: " + lines[i]) - } - d.Name = a[0] - d.Desc = a[1] - } - return dicts, err -} - -// A Defn represents a definition. -type Defn struct { - Dict Dict // Dict where definition was found - Word string // Word being defined - Text []byte // Definition text, typically multiple lines -} - -// Define requests the definition of the given word. -// The argument dict names the dictionary to use, -// the Name field of a Dict returned by Dicts. -// -// The special dictionary name "*" means to look in all the -// server's dictionaries. -// The special dictionary name "!" means to look in all the -// server's dictionaries in turn, stopping after finding the word -// in one of them. -func (c *Client) Define(dict, word string) ([]*Defn, os.Error) { - id, err := c.text.Cmd("DEFINE %s %q", dict, word) - if err != nil { - return nil, err - } - - c.text.StartResponse(id) - defer c.text.EndResponse(id) - - _, line, err := c.text.ReadCodeLine(150) - if err != nil { - return nil, err - } - a, _ := fields(line) - if len(a) < 1 { - return nil, textproto.ProtocolError("malformed response: " + line) - } - n, err := strconv.Atoi(a[0]) - if err != nil { - return nil, textproto.ProtocolError("invalid definition count: " + a[0]) - } - def := make([]*Defn, n) - for i := 0; i < n; i++ { - _, line, err = c.text.ReadCodeLine(151) - if err != nil { - return nil, err - } - a, _ := fields(line) - if len(a) < 3 { - // skip it, to keep protocol in sync - i-- - n-- - def = def[0:n] - continue - } - d := &Defn{Word: a[0], Dict: Dict{a[1], a[2]}} - d.Text, err = c.text.ReadDotBytes() - if err != nil { - return nil, err - } - def[i] = d - } - _, _, err = c.text.ReadCodeLine(250) - return def, err -} - -// Fields returns the fields in s. -// Fields are space separated unquoted words -// or quoted with single or double quote. -func fields(s string) ([]string, os.Error) { - var v []string - i := 0 - for { - for i < len(s) && (s[i] == ' ' || s[i] == '\t') { - i++ - } - if i >= len(s) { - break - } - if s[i] == '"' || s[i] == '\'' { - q := s[i] - // quoted string - var j int - for j = i + 1; ; j++ { - if j >= len(s) { - return nil, textproto.ProtocolError("malformed quoted string") - } - if s[j] == '\\' { - j++ - continue - } - if s[j] == q { - j++ - break - } - } - v = append(v, unquote(s[i+1:j-1])) - i = j - } else { - // atom - var j int - for j = i; j < len(s); j++ { - if s[j] == ' ' || s[j] == '\t' || s[j] == '\\' || s[j] == '"' || s[j] == '\'' { - break - } - } - v = append(v, s[i:j]) - i = j - } - if i < len(s) { - c := s[i] - if c != ' ' && c != '\t' { - return nil, textproto.ProtocolError("quotes not on word boundaries") - } - } - } - return v, nil -} - -func unquote(s string) string { - if strings.Index(s, "\\") < 0 { - return s - } - b := []byte(s) - w := 0 - for r := 0; r < len(b); r++ { - c := b[r] - if c == '\\' { - r++ - c = b[r] - } - b[w] = c - w++ - } - return string(b[0:w]) -} diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go index 93c04f6b5..f4ed8b87c 100644 --- a/src/pkg/net/dnsclient.go +++ b/src/pkg/net/dnsclient.go @@ -7,20 +7,19 @@ package net import ( "bytes" "fmt" - "os" - "rand" + "math/rand" "sort" ) // DNSError represents a DNS lookup error. type DNSError struct { - Error string // description of the error + Err string // description of the error Name string // name looked for Server string // server used IsTimeout bool } -func (e *DNSError) String() string { +func (e *DNSError) Error() string { if e == nil { return "<nil>" } @@ -28,7 +27,7 @@ func (e *DNSError) String() string { if e.Server != "" { s += " on " + e.Server } - s += ": " + e.Error + s += ": " + e.Err return s } @@ -40,10 +39,10 @@ const noSuchHost = "no such host" // reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP // address addr suitable for rDNS (PTR) record lookup or an error if it fails // to parse the IP address. -func reverseaddr(addr string) (arpa string, err os.Error) { +func reverseaddr(addr string) (arpa string, err error) { ip := ParseIP(addr) if ip == nil { - return "", &DNSError{Error: "unrecognized address", Name: addr} + return "", &DNSError{Err: "unrecognized address", Name: addr} } if ip.To4() != nil { return fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa.", ip[15], ip[14], ip[13], ip[12]), nil @@ -64,18 +63,18 @@ func reverseaddr(addr string) (arpa string, err os.Error) { // Find answer for name in dns message. // On return, if err == nil, addrs != nil. -func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err os.Error) { +func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err error) { addrs = make([]dnsRR, 0, len(dns.answer)) if dns.rcode == dnsRcodeNameError && dns.recursion_available { - return "", nil, &DNSError{Error: noSuchHost, Name: name} + return "", nil, &DNSError{Err: noSuchHost, Name: name} } if dns.rcode != dnsRcodeSuccess { // None of the error codes make sense // for the query we sent. If we didn't get // a name error and we didn't get success, // the server is behaving incorrectly. - return "", nil, &DNSError{Error: "server misbehaving", Name: name, Server: server} + return "", nil, &DNSError{Err: "server misbehaving", Name: name, Server: server} } // Look for the name. @@ -107,12 +106,12 @@ Cname: } } if len(addrs) == 0 { - return "", nil, &DNSError{Error: noSuchHost, Name: name, Server: server} + return "", nil, &DNSError{Err: noSuchHost, Name: name, Server: server} } return name, addrs, nil } - return "", nil, &DNSError{Error: "too many redirects", Name: name, Server: server} + return "", nil, &DNSError{Err: "too many redirects", Name: name, Server: server} } func isDomainName(s string) bool { diff --git a/src/pkg/net/dnsclient_unix.go b/src/pkg/net/dnsclient_unix.go index eb7db5e27..18c39360e 100644 --- a/src/pkg/net/dnsclient_unix.go +++ b/src/pkg/net/dnsclient_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd +// +build darwin freebsd linux netbsd openbsd // DNS client: see RFC 1035. // Has to be linked into package net for Dial. @@ -17,27 +17,26 @@ package net import ( - "os" - "rand" + "math/rand" "sync" "time" ) // Send a request on the connection and hope for a reply. // Up to cfg.attempts attempts. -func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, os.Error) { +func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, error) { if len(name) >= 256 { - return nil, &DNSError{Error: "name too long", Name: name} + return nil, &DNSError{Err: "name too long", Name: name} } out := new(dnsMsg) - out.id = uint16(rand.Int()) ^ uint16(time.Nanoseconds()) + out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) out.question = []dnsQuestion{ {name, qtype, dnsClassINET}, } out.recursion_desired = true msg, ok := out.Pack() if !ok { - return nil, &DNSError{Error: "internal error - cannot pack message", Name: name} + return nil, &DNSError{Err: "internal error - cannot pack message", Name: name} } for attempt := 0; attempt < cfg.attempts; attempt++ { @@ -46,7 +45,11 @@ func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, os.Er return nil, err } - c.SetReadTimeout(int64(cfg.timeout) * 1e9) // nanoseconds + if cfg.timeout == 0 { + c.SetReadDeadline(time.Time{}) + } else { + c.SetReadDeadline(time.Now().Add(time.Duration(cfg.timeout) * time.Second)) + } buf := make([]byte, 2000) // More than enough. n, err = c.Read(buf) @@ -67,14 +70,14 @@ func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, os.Er if a := c.RemoteAddr(); a != nil { server = a.String() } - return nil, &DNSError{Error: "no answer from server", Name: name, Server: server, IsTimeout: true} + return nil, &DNSError{Err: "no answer from server", Name: name, Server: server, IsTimeout: true} } // Do a lookup for a single name, which must be rooted // (otherwise answer will not find the answers). -func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs []dnsRR, err os.Error) { +func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs []dnsRR, err error) { if len(cfg.servers) == 0 { - return "", nil, &DNSError{Error: "no DNS servers", Name: name} + return "", nil, &DNSError{Err: "no DNS servers", Name: name} } for i := 0; i < len(cfg.servers); i++ { // Calling Dial here is scary -- we have to be sure @@ -96,7 +99,7 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs continue } cname, addrs, err = answer(name, server, msg, qtype) - if err == nil || err.(*DNSError).Error == noSuchHost { + if err == nil || err.(*DNSError).Err == noSuchHost { break } } @@ -123,15 +126,15 @@ func convertRR_AAAA(records []dnsRR) []IP { } var cfg *dnsConfig -var dnserr os.Error +var dnserr error func loadConfig() { cfg, dnserr = dnsReadConfig() } var onceLoadConfig sync.Once -func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Error) { +func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err error) { if !isDomainName(name) { - return name, nil, &DNSError{Error: "invalid domain name", Name: name} + return name, nil, &DNSError{Err: "invalid domain name", Name: name} } onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { @@ -186,7 +189,7 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Erro // Normally we let cgo use the C library resolver instead of // depending on our lookup code, so that Go and C get the same // answers. -func goLookupHost(name string) (addrs []string, err os.Error) { +func goLookupHost(name string) (addrs []string, err error) { // Use entries from /etc/hosts if they match. addrs = lookupStaticHost(name) if len(addrs) > 0 { @@ -214,7 +217,7 @@ func goLookupHost(name string) (addrs []string, err os.Error) { // Normally we let cgo use the C library resolver instead of // depending on our lookup code, so that Go and C get the same // answers. -func goLookupIP(name string) (addrs []IP, err os.Error) { +func goLookupIP(name string) (addrs []IP, err error) { // Use entries from /etc/hosts if possible. haddrs := lookupStaticHost(name) if len(haddrs) > 0 { @@ -260,7 +263,7 @@ func goLookupIP(name string) (addrs []IP, err os.Error) { // Normally we let cgo use the C library resolver instead of // depending on our lookup code, so that Go and C get the same // answers. -func goLookupCNAME(name string) (cname string, err os.Error) { +func goLookupCNAME(name string) (cname string, err error) { onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { err = dnserr diff --git a/src/pkg/net/dnsconfig.go b/src/pkg/net/dnsconfig.go index afc059917..c0ab80288 100644 --- a/src/pkg/net/dnsconfig.go +++ b/src/pkg/net/dnsconfig.go @@ -2,14 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd +// +build darwin freebsd linux netbsd openbsd // Read system DNS config from /etc/resolv.conf package net -import "os" - type dnsConfig struct { servers []string // servers to use search []string // suffixes to append to local name @@ -19,14 +17,14 @@ type dnsConfig struct { rotate bool // round robin among servers } -var dnsconfigError os.Error +var dnsconfigError error type DNSConfigError struct { - Error os.Error + Err error } -func (e *DNSConfigError) String() string { - return "error reading DNS config: " + e.Error.String() +func (e *DNSConfigError) Error() string { + return "error reading DNS config: " + e.Err.Error() } func (e *DNSConfigError) Timeout() bool { return false } @@ -36,7 +34,7 @@ func (e *DNSConfigError) Temporary() bool { return false } // TODO(rsc): Supposed to call uname() and chop the beginning // of the host name to get the default search domain. // We assume it's in resolv.conf anyway. -func dnsReadConfig() (*dnsConfig, os.Error) { +func dnsReadConfig() (*dnsConfig, error) { file, err := open("/etc/resolv.conf") if err != nil { return nil, &DNSConfigError{err} diff --git a/src/pkg/net/doc.go b/src/pkg/net/doc.go new file mode 100644 index 000000000..3a44e528e --- /dev/null +++ b/src/pkg/net/doc.go @@ -0,0 +1,59 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +// LookupHost looks up the given host using the local resolver. +// It returns an array of that host's addresses. +func LookupHost(host string) (addrs []string, err error) { + return lookupHost(host) +} + +// LookupIP looks up host using the local resolver. +// It returns an array of that host's IPv4 and IPv6 addresses. +func LookupIP(host string) (addrs []IP, err error) { + return lookupIP(host) +} + +// LookupPort looks up the port for the given network and service. +func LookupPort(network, service string) (port int, err error) { + return lookupPort(network, service) +} + +// LookupCNAME returns the canonical DNS host for the given name. +// Callers that do not care about the canonical name can call +// LookupHost or LookupIP directly; both take care of resolving +// the canonical name as part of the lookup. +func LookupCNAME(name string) (cname string, err error) { + return lookupCNAME(name) +} + +// LookupSRV tries to resolve an SRV query of the given service, +// protocol, and domain name. The proto is "tcp" or "udp". +// The returned records are sorted by priority and randomized +// by weight within a priority. +// +// LookupSRV constructs the DNS name to look up following RFC 2782. +// That is, it looks up _service._proto.name. To accommodate services +// publishing SRV records under non-standard names, if both service +// and proto are empty strings, LookupSRV looks up name directly. +func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { + return lookupSRV(service, proto, name) +} + +// LookupMX returns the DNS MX records for the given domain name sorted by preference. +func LookupMX(name string) (mx []*MX, err error) { + return lookupMX(name) +} + +// LookupTXT returns the DNS TXT records for the given domain name. +func LookupTXT(name string) (txt []string, err error) { + return lookupTXT(name) +} + +// LookupAddr performs a reverse lookup for the given address, returning a list +// of names mapping to that address. +func LookupAddr(addr string) (name []string, err error) { + return lookupAddr(addr) +} diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd.go index 9084e8875..495ef007f 100644 --- a/src/pkg/net/fd.go +++ b/src/pkg/net/fd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd +// +build darwin freebsd linux netbsd openbsd package net @@ -22,23 +22,22 @@ type netFD struct { closing bool // immutable until Close - sysfd int - family int - proto int - sysfile *os.File - cr chan bool - cw chan bool - net string - laddr Addr - raddr Addr + sysfd int + family int + sotype int + isConnected bool + sysfile *os.File + cr chan bool + cw chan bool + net string + laddr Addr + raddr Addr // owned by client - rdeadline_delta int64 - rdeadline int64 - rio sync.Mutex - wdeadline_delta int64 - wdeadline int64 - wio sync.Mutex + rdeadline int64 + rio sync.Mutex + wdeadline int64 + wio sync.Mutex // owned by fd wait server ncr, ncw int @@ -46,7 +45,7 @@ type netFD struct { type InvalidConnError struct{} -func (e *InvalidConnError) String() string { return "invalid net.Conn" } +func (e *InvalidConnError) Error() string { return "invalid net.Conn" } func (e *InvalidConnError) Temporary() bool { return false } func (e *InvalidConnError) Timeout() bool { return false } @@ -126,7 +125,7 @@ func (s *pollServer) AddFD(fd *netFD, mode int) { wake, err := s.poll.AddFD(intfd, mode, false) if err != nil { - panic("pollServer AddFD " + err.String()) + panic("pollServer AddFD " + err.Error()) } if wake { doWakeup = true @@ -152,7 +151,7 @@ func (s *pollServer) LookupFD(fd int, mode int) *netFD { if !ok { return nil } - s.pending[key] = nil, false + delete(s.pending, key) return netfd } @@ -171,7 +170,7 @@ func (s *pollServer) WakeFD(fd *netFD, mode int) { } func (s *pollServer) Now() int64 { - return time.Nanoseconds() + return time.Now().UnixNano() } func (s *pollServer) CheckDeadlines() { @@ -195,7 +194,7 @@ func (s *pollServer) CheckDeadlines() { } if t > 0 { if t <= now { - s.pending[key] = nil, false + delete(s.pending, key) if mode == 'r' { s.poll.DelFD(fd.sysfd, mode) fd.rdeadline = -1 @@ -227,7 +226,7 @@ func (s *pollServer) Run() { } fd, mode, err := s.poll.WaitFD(s, t) if err != nil { - print("pollServer WaitFD: ", err.String(), "\n") + print("pollServer WaitFD: ", err.Error(), "\n") return } if fd < 0 { @@ -271,20 +270,20 @@ var onceStartServer sync.Once func startServer() { p, err := newPollServer() if err != nil { - print("Start pollServer: ", err.String(), "\n") + print("Start pollServer: ", err.Error(), "\n") } pollserver = p } -func newFD(fd, family, proto int, net string) (f *netFD, err os.Error) { +func newFD(fd, family, sotype int, net string) (f *netFD, err error) { onceStartServer.Do(startServer) - if e := syscall.SetNonblock(fd, true); e != 0 { - return nil, os.Errno(e) + if e := syscall.SetNonblock(fd, true); e != nil { + return nil, e } f = &netFD{ sysfd: fd, family: family, - proto: proto, + sotype: sotype, net: net, } f.cr = make(chan bool, 1) @@ -305,20 +304,20 @@ func (fd *netFD) setAddr(laddr, raddr Addr) { fd.sysfile = os.NewFile(fd.sysfd, fd.net+":"+ls+"->"+rs) } -func (fd *netFD) connect(ra syscall.Sockaddr) (err os.Error) { - e := syscall.Connect(fd.sysfd, ra) - if e == syscall.EINPROGRESS { - var errno int +func (fd *netFD) connect(ra syscall.Sockaddr) (err error) { + err = syscall.Connect(fd.sysfd, ra) + if err == syscall.EINPROGRESS { pollserver.WaitWrite(fd) - e, errno = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) - if errno != 0 { - return os.NewSyscallError("getsockopt", errno) + var e int + e, err = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) + if err != nil { + return os.NewSyscallError("getsockopt", err) + } + if e != 0 { + err = syscall.Errno(e) } } - if e != 0 { - return os.Errno(e) - } - return nil + return err } // Add a reference to this fd. @@ -346,7 +345,7 @@ func (fd *netFD) decref() { fd.sysmu.Unlock() } -func (fd *netFD) Close() os.Error { +func (fd *netFD) Close() error { if fd == nil || fd.sysfile == nil { return os.EINVAL } @@ -358,7 +357,26 @@ func (fd *netFD) Close() os.Error { return nil } -func (fd *netFD) Read(p []byte) (n int, err os.Error) { +func (fd *netFD) shutdown(how int) error { + if fd == nil || fd.sysfile == nil { + return os.EINVAL + } + err := syscall.Shutdown(fd.sysfd, how) + if err != nil { + return &OpError{"shutdown", fd.net, fd.laddr, err} + } + return nil +} + +func (fd *netFD) CloseRead() error { + return fd.shutdown(syscall.SHUT_RD) +} + +func (fd *netFD) CloseWrite() error { + return fd.shutdown(syscall.SHUT_WR) +} + +func (fd *netFD) Read(p []byte) (n int, err error) { if fd == nil { return 0, os.EINVAL } @@ -369,34 +387,29 @@ func (fd *netFD) Read(p []byte) (n int, err os.Error) { if fd.sysfile == nil { return 0, os.EINVAL } - if fd.rdeadline_delta > 0 { - fd.rdeadline = pollserver.Now() + fd.rdeadline_delta - } else { - fd.rdeadline = 0 - } - var oserr os.Error for { - var errno int - n, errno = syscall.Read(fd.sysfile.Fd(), p) - if errno == syscall.EAGAIN && fd.rdeadline >= 0 { - pollserver.WaitRead(fd) - continue + n, err = syscall.Read(fd.sysfile.Fd(), p) + if err == syscall.EAGAIN { + if fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + err = errTimeout } - if errno != 0 { + if err != nil { n = 0 - oserr = os.Errno(errno) - } else if n == 0 && errno == 0 && fd.proto != syscall.SOCK_DGRAM { - err = os.EOF + } else if n == 0 && err == nil && fd.sotype != syscall.SOCK_DGRAM { + err = io.EOF } break } - if oserr != nil { - err = &OpError{"read", fd.net, fd.raddr, oserr} + if err != nil && err != io.EOF { + err = &OpError{"read", fd.net, fd.raddr, err} } return } -func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) { +func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { if fd == nil || fd.sysfile == nil { return 0, nil, os.EINVAL } @@ -404,32 +417,27 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) { defer fd.rio.Unlock() fd.incref() defer fd.decref() - if fd.rdeadline_delta > 0 { - fd.rdeadline = pollserver.Now() + fd.rdeadline_delta - } else { - fd.rdeadline = 0 - } - var oserr os.Error for { - var errno int - n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0) - if errno == syscall.EAGAIN && fd.rdeadline >= 0 { - pollserver.WaitRead(fd) - continue + n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0) + if err == syscall.EAGAIN { + if fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + err = errTimeout } - if errno != 0 { + if err != nil { n = 0 - oserr = os.Errno(errno) } break } - if oserr != nil { - err = &OpError{"read", fd.net, fd.laddr, oserr} + if err != nil { + err = &OpError{"read", fd.net, fd.laddr, err} } return } -func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) { +func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { if fd == nil || fd.sysfile == nil { return 0, 0, 0, nil, os.EINVAL } @@ -437,35 +445,28 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S defer fd.rio.Unlock() fd.incref() defer fd.decref() - if fd.rdeadline_delta > 0 { - fd.rdeadline = pollserver.Now() + fd.rdeadline_delta - } else { - fd.rdeadline = 0 - } - var oserr os.Error for { - var errno int - n, oobn, flags, sa, errno = syscall.Recvmsg(fd.sysfd, p, oob, 0) - if errno == syscall.EAGAIN && fd.rdeadline >= 0 { - pollserver.WaitRead(fd) - continue - } - if errno != 0 { - oserr = os.Errno(errno) + n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0) + if err == syscall.EAGAIN { + if fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + err = errTimeout } - if n == 0 { - oserr = os.EOF + if err == nil && n == 0 { + err = io.EOF } break } - if oserr != nil { - err = &OpError{"read", fd.net, fd.laddr, oserr} + if err != nil && err != io.EOF { + err = &OpError{"read", fd.net, fd.laddr, err} return } return } -func (fd *netFD) Write(p []byte) (n int, err os.Error) { +func (fd *netFD) Write(p []byte) (n int, err error) { if fd == nil { return 0, os.EINVAL } @@ -476,43 +477,40 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) { if fd.sysfile == nil { return 0, os.EINVAL } - if fd.wdeadline_delta > 0 { - fd.wdeadline = pollserver.Now() + fd.wdeadline_delta - } else { - fd.wdeadline = 0 - } nn := 0 - var oserr os.Error for { - n, errno := syscall.Write(fd.sysfile.Fd(), p[nn:]) + var n int + n, err = syscall.Write(fd.sysfile.Fd(), p[nn:]) if n > 0 { nn += n } if nn == len(p) { break } - if errno == syscall.EAGAIN && fd.wdeadline >= 0 { - pollserver.WaitWrite(fd) - continue + if err == syscall.EAGAIN { + if fd.wdeadline >= 0 { + pollserver.WaitWrite(fd) + continue + } + err = errTimeout } - if errno != 0 { + if err != nil { n = 0 - oserr = os.Errno(errno) break } if n == 0 { - oserr = io.ErrUnexpectedEOF + err = io.ErrUnexpectedEOF break } } - if oserr != nil { - err = &OpError{"write", fd.net, fd.raddr, oserr} + if err != nil { + err = &OpError{"write", fd.net, fd.raddr, err} } return nn, err } -func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) { +func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { if fd == nil || fd.sysfile == nil { return 0, os.EINVAL } @@ -520,32 +518,26 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) { defer fd.wio.Unlock() fd.incref() defer fd.decref() - if fd.wdeadline_delta > 0 { - fd.wdeadline = pollserver.Now() + fd.wdeadline_delta - } else { - fd.wdeadline = 0 - } - var oserr os.Error for { - errno := syscall.Sendto(fd.sysfd, p, 0, sa) - if errno == syscall.EAGAIN && fd.wdeadline >= 0 { - pollserver.WaitWrite(fd) - continue - } - if errno != 0 { - oserr = os.Errno(errno) + err = syscall.Sendto(fd.sysfd, p, 0, sa) + if err == syscall.EAGAIN { + if fd.wdeadline >= 0 { + pollserver.WaitWrite(fd) + continue + } + err = errTimeout } break } - if oserr == nil { + if err == nil { n = len(p) } else { - err = &OpError{"write", fd.net, fd.raddr, oserr} + err = &OpError{"write", fd.net, fd.raddr, err} } return } -func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) { +func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { if fd == nil || fd.sysfile == nil { return 0, 0, os.EINVAL } @@ -553,73 +545,62 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob defer fd.wio.Unlock() fd.incref() defer fd.decref() - if fd.wdeadline_delta > 0 { - fd.wdeadline = pollserver.Now() + fd.wdeadline_delta - } else { - fd.wdeadline = 0 - } - var oserr os.Error for { - var errno int - errno = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) - if errno == syscall.EAGAIN && fd.wdeadline >= 0 { - pollserver.WaitWrite(fd) - continue - } - if errno != 0 { - oserr = os.Errno(errno) + err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) + if err == syscall.EAGAIN { + if fd.wdeadline >= 0 { + pollserver.WaitWrite(fd) + continue + } + err = errTimeout } break } - if oserr == nil { + if err == nil { n = len(p) oobn = len(oob) } else { - err = &OpError{"write", fd.net, fd.raddr, oserr} + err = &OpError{"write", fd.net, fd.raddr, err} } return } -func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) { +func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err error) { if fd == nil || fd.sysfile == nil { return nil, os.EINVAL } fd.incref() defer fd.decref() - if fd.rdeadline_delta > 0 { - fd.rdeadline = pollserver.Now() + fd.rdeadline_delta - } else { - fd.rdeadline = 0 - } // See ../syscall/exec.go for description of ForkLock. // It is okay to hold the lock across syscall.Accept // because we have put fd.sysfd into non-blocking mode. - syscall.ForkLock.RLock() - var s, e int + var s int var rsa syscall.Sockaddr for { if fd.closing { - syscall.ForkLock.RUnlock() return nil, os.EINVAL } - s, rsa, e = syscall.Accept(fd.sysfd) - if e != syscall.EAGAIN || fd.rdeadline < 0 { - break - } - syscall.ForkLock.RUnlock() - pollserver.WaitRead(fd) syscall.ForkLock.RLock() - } - if e != 0 { - syscall.ForkLock.RUnlock() - return nil, &OpError{"accept", fd.net, fd.laddr, os.Errno(e)} + s, rsa, err = syscall.Accept(fd.sysfd) + if err != nil { + syscall.ForkLock.RUnlock() + if err == syscall.EAGAIN { + if fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + err = errTimeout + } + return nil, &OpError{"accept", fd.net, fd.laddr, err} + } + break } syscall.CloseOnExec(s) syscall.ForkLock.RUnlock() - if nfd, err = newFD(s, fd.family, fd.proto, fd.net); err != nil { + if nfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil { syscall.Close(s) return nil, err } @@ -628,20 +609,20 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. return nfd, nil } -func (fd *netFD) dup() (f *os.File, err os.Error) { - ns, e := syscall.Dup(fd.sysfd) - if e != 0 { - return nil, &OpError{"dup", fd.net, fd.laddr, os.Errno(e)} +func (fd *netFD) dup() (f *os.File, err error) { + ns, err := syscall.Dup(fd.sysfd) + if err != nil { + return nil, &OpError{"dup", fd.net, fd.laddr, err} } // We want blocking mode for the new fd, hence the double negative. - if e = syscall.SetNonblock(ns, false); e != 0 { - return nil, &OpError{"setnonblock", fd.net, fd.laddr, os.Errno(e)} + if err = syscall.SetNonblock(ns, false); err != nil { + return nil, &OpError{"setnonblock", fd.net, fd.laddr, err} } return os.NewFile(ns, fd.sysfile.Name()), nil } -func closesocket(s int) (errno int) { +func closesocket(s int) error { return syscall.Close(s) } diff --git a/src/pkg/net/fd_darwin.go b/src/pkg/net/fd_darwin.go index 7e3d549eb..c6db083c4 100644 --- a/src/pkg/net/fd_darwin.go +++ b/src/pkg/net/fd_darwin.go @@ -7,6 +7,7 @@ package net import ( + "errors" "os" "syscall" ) @@ -21,17 +22,17 @@ type pollster struct { kbuf [1]syscall.Kevent_t } -func newpollster() (p *pollster, err os.Error) { +func newpollster() (p *pollster, err error) { p = new(pollster) - var e int - if p.kq, e = syscall.Kqueue(); e != 0 { - return nil, os.NewSyscallError("kqueue", e) + if p.kq, err = syscall.Kqueue(); err != nil { + return nil, os.NewSyscallError("kqueue", err) } + syscall.CloseOnExec(p.kq) p.events = p.eventbuf[0:0] return p, nil } -func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { // pollServer is locked. var kmode int @@ -51,15 +52,15 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { } syscall.SetKevent(ev, fd, kmode, flags) - n, e := syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil) - if e != 0 { - return false, os.NewSyscallError("kevent", e) + n, err := syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil) + if err != nil { + return false, os.NewSyscallError("kevent", err) } if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { - return false, os.NewError("kqueue phase error") + return false, errors.New("kqueue phase error") } if ev.Data != 0 { - return false, os.Errno(int(ev.Data)) + return false, syscall.Errno(ev.Data) } return false, nil } @@ -81,7 +82,7 @@ func (p *pollster) DelFD(fd int, mode int) { syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil) } -func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { var t *syscall.Timespec for len(p.events) == 0 { if nsec > 0 { @@ -95,11 +96,11 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[0:], t) s.Lock() - if e != 0 { + if e != nil { if e == syscall.EINTR { continue } - return -1, 0, os.NewSyscallError("kevent", e) + return -1, 0, os.NewSyscallError("kevent", nil) } if nn == 0 { return -1, 0, nil @@ -117,4 +118,4 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E return fd, mode, nil } -func (p *pollster) Close() os.Error { return os.NewSyscallError("close", syscall.Close(p.kq)) } +func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/src/pkg/net/fd_freebsd.go b/src/pkg/net/fd_freebsd.go index e50883e94..31d0744e2 100644 --- a/src/pkg/net/fd_freebsd.go +++ b/src/pkg/net/fd_freebsd.go @@ -21,17 +21,17 @@ type pollster struct { kbuf [1]syscall.Kevent_t } -func newpollster() (p *pollster, err os.Error) { +func newpollster() (p *pollster, err error) { p = new(pollster) - var e int - if p.kq, e = syscall.Kqueue(); e != 0 { - return nil, os.NewSyscallError("kqueue", e) + if p.kq, err = syscall.Kqueue(); err != nil { + return nil, os.NewSyscallError("kqueue", err) } + syscall.CloseOnExec(p.kq) p.events = p.eventbuf[0:0] return p, nil } -func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { // pollServer is locked. var kmode int @@ -50,14 +50,14 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { syscall.SetKevent(ev, fd, kmode, flags) n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil) - if e != 0 { + if e != nil { return false, os.NewSyscallError("kevent", e) } if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { return false, os.NewSyscallError("kqueue phase error", e) } if ev.Data != 0 { - return false, os.Errno(int(ev.Data)) + return false, syscall.Errno(int(ev.Data)) } return false, nil } @@ -77,7 +77,7 @@ func (p *pollster) DelFD(fd int, mode int) { syscall.Kevent(p.kq, p.kbuf[:], nil, nil) } -func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { var t *syscall.Timespec for len(p.events) == 0 { if nsec > 0 { @@ -91,7 +91,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) s.Lock() - if e != 0 { + if e != nil { if e == syscall.EINTR { continue } @@ -113,4 +113,4 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E return fd, mode, nil } -func (p *pollster) Close() os.Error { return os.NewSyscallError("close", syscall.Close(p.kq)) } +func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/src/pkg/net/fd_linux.go b/src/pkg/net/fd_linux.go index 70fc344b2..c8df9c932 100644 --- a/src/pkg/net/fd_linux.go +++ b/src/pkg/net/fd_linux.go @@ -33,21 +33,27 @@ type pollster struct { ctlEvent syscall.EpollEvent } -func newpollster() (p *pollster, err os.Error) { +func newpollster() (p *pollster, err error) { p = new(pollster) - var e int + var e error - // The arg to epoll_create is a hint to the kernel - // about the number of FDs we will care about. - // We don't know, and since 2.6.8 the kernel ignores it anyhow. - if p.epfd, e = syscall.EpollCreate(16); e != 0 { - return nil, os.NewSyscallError("epoll_create", e) + if p.epfd, e = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC); e != nil { + if e != syscall.ENOSYS { + return nil, os.NewSyscallError("epoll_create1", e) + } + // The arg to epoll_create is a hint to the kernel + // about the number of FDs we will care about. + // We don't know, and since 2.6.8 the kernel ignores it anyhow. + if p.epfd, e = syscall.EpollCreate(16); e != nil { + return nil, os.NewSyscallError("epoll_create", e) + } + syscall.CloseOnExec(p.epfd) } p.events = make(map[int]uint32) return p, nil } -func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { // pollServer is locked. var already bool @@ -68,7 +74,7 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { } else { op = syscall.EPOLL_CTL_ADD } - if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != 0 { + if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != nil { return false, os.NewSyscallError("epoll_ctl", e) } p.events[fd] = p.ctlEvent.Events @@ -97,15 +103,15 @@ func (p *pollster) StopWaiting(fd int, bits uint) { if int32(events)&^syscall.EPOLLONESHOT != 0 { p.ctlEvent.Fd = int32(fd) p.ctlEvent.Events = events - if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != 0 { - print("Epoll modify fd=", fd, ": ", os.Errno(e).String(), "\n") + if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != nil { + print("Epoll modify fd=", fd, ": ", e.Error(), "\n") } p.events[fd] = events } else { - if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != 0 { - print("Epoll delete fd=", fd, ": ", os.Errno(e).String(), "\n") + if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != nil { + print("Epoll delete fd=", fd, ": ", e.Error(), "\n") } - p.events[fd] = 0, false + delete(p.events, fd) } } @@ -130,7 +136,7 @@ func (p *pollster) DelFD(fd int, mode int) { } } -func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { for len(p.waitEvents) == 0 { var msec int = -1 if nsec > 0 { @@ -141,7 +147,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E n, e := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec) s.Lock() - if e != 0 { + if e != nil { if e == syscall.EAGAIN || e == syscall.EINTR { continue } @@ -177,6 +183,6 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E return fd, 'r', nil } -func (p *pollster) Close() os.Error { +func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.epfd)) } diff --git a/src/pkg/net/fd_netbsd.go b/src/pkg/net/fd_netbsd.go new file mode 100644 index 000000000..31d0744e2 --- /dev/null +++ b/src/pkg/net/fd_netbsd.go @@ -0,0 +1,116 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Waiting for FDs via kqueue/kevent. + +package net + +import ( + "os" + "syscall" +) + +type pollster struct { + kq int + eventbuf [10]syscall.Kevent_t + events []syscall.Kevent_t + + // An event buffer for AddFD/DelFD. + // Must hold pollServer lock. + kbuf [1]syscall.Kevent_t +} + +func newpollster() (p *pollster, err error) { + p = new(pollster) + if p.kq, err = syscall.Kqueue(); err != nil { + return nil, os.NewSyscallError("kqueue", err) + } + syscall.CloseOnExec(p.kq) + p.events = p.eventbuf[0:0] + return p, nil +} + +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { + // pollServer is locked. + + var kmode int + if mode == 'r' { + kmode = syscall.EVFILT_READ + } else { + kmode = syscall.EVFILT_WRITE + } + ev := &p.kbuf[0] + // EV_ADD - add event to kqueue list + // EV_ONESHOT - delete the event the first time it triggers + flags := syscall.EV_ADD + if !repeat { + flags |= syscall.EV_ONESHOT + } + syscall.SetKevent(ev, fd, kmode, flags) + + n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil) + if e != nil { + return false, os.NewSyscallError("kevent", e) + } + if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { + return false, os.NewSyscallError("kqueue phase error", e) + } + if ev.Data != 0 { + return false, syscall.Errno(int(ev.Data)) + } + return false, nil +} + +func (p *pollster) DelFD(fd int, mode int) { + // pollServer is locked. + + var kmode int + if mode == 'r' { + kmode = syscall.EVFILT_READ + } else { + kmode = syscall.EVFILT_WRITE + } + ev := &p.kbuf[0] + // EV_DELETE - delete event from kqueue list + syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE) + syscall.Kevent(p.kq, p.kbuf[:], nil, nil) +} + +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { + var t *syscall.Timespec + for len(p.events) == 0 { + if nsec > 0 { + if t == nil { + t = new(syscall.Timespec) + } + *t = syscall.NsecToTimespec(nsec) + } + + s.Unlock() + nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) + s.Lock() + + if e != nil { + if e == syscall.EINTR { + continue + } + return -1, 0, os.NewSyscallError("kevent", e) + } + if nn == 0 { + return -1, 0, nil + } + p.events = p.eventbuf[0:nn] + } + ev := &p.events[0] + p.events = p.events[1:] + fd = int(ev.Ident) + if ev.Filter == syscall.EVFILT_READ { + mode = 'r' + } else { + mode = 'w' + } + return fd, mode, nil +} + +func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/src/pkg/net/fd_openbsd.go b/src/pkg/net/fd_openbsd.go index e50883e94..31d0744e2 100644 --- a/src/pkg/net/fd_openbsd.go +++ b/src/pkg/net/fd_openbsd.go @@ -21,17 +21,17 @@ type pollster struct { kbuf [1]syscall.Kevent_t } -func newpollster() (p *pollster, err os.Error) { +func newpollster() (p *pollster, err error) { p = new(pollster) - var e int - if p.kq, e = syscall.Kqueue(); e != 0 { - return nil, os.NewSyscallError("kqueue", e) + if p.kq, err = syscall.Kqueue(); err != nil { + return nil, os.NewSyscallError("kqueue", err) } + syscall.CloseOnExec(p.kq) p.events = p.eventbuf[0:0] return p, nil } -func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { +func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { // pollServer is locked. var kmode int @@ -50,14 +50,14 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, os.Error) { syscall.SetKevent(ev, fd, kmode, flags) n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil) - if e != 0 { + if e != nil { return false, os.NewSyscallError("kevent", e) } if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { return false, os.NewSyscallError("kqueue phase error", e) } if ev.Data != 0 { - return false, os.Errno(int(ev.Data)) + return false, syscall.Errno(int(ev.Data)) } return false, nil } @@ -77,7 +77,7 @@ func (p *pollster) DelFD(fd int, mode int) { syscall.Kevent(p.kq, p.kbuf[:], nil, nil) } -func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) { +func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { var t *syscall.Timespec for len(p.events) == 0 { if nsec > 0 { @@ -91,7 +91,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) s.Lock() - if e != 0 { + if e != nil { if e == syscall.EINTR { continue } @@ -113,4 +113,4 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.E return fd, mode, nil } -func (p *pollster) Close() os.Error { return os.NewSyscallError("close", syscall.Close(p.kq)) } +func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index b025bddea..f00459f0b 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -5,6 +5,7 @@ package net import ( + "io" "os" "runtime" "sync" @@ -15,21 +16,21 @@ import ( type InvalidConnError struct{} -func (e *InvalidConnError) String() string { return "invalid net.Conn" } +func (e *InvalidConnError) Error() string { return "invalid net.Conn" } func (e *InvalidConnError) Temporary() bool { return false } func (e *InvalidConnError) Timeout() bool { return false } -var initErr os.Error +var initErr error func init() { var d syscall.WSAData e := syscall.WSAStartup(uint32(0x202), &d) - if e != 0 { + if e != nil { initErr = os.NewSyscallError("WSAStartup", e) } } -func closesocket(s syscall.Handle) (errno int) { +func closesocket(s syscall.Handle) (err error) { return syscall.Closesocket(s) } @@ -37,13 +38,13 @@ func closesocket(s syscall.Handle) (errno int) { type anOpIface interface { Op() *anOp Name() string - Submit() (errno int) + Submit() (err error) } // IO completion result parameters. type ioResult struct { qty uint32 - err int + err error } // anOp implements functionality common to all io operations. @@ -53,7 +54,7 @@ type anOp struct { o syscall.Overlapped resultc chan ioResult - errnoc chan int + errnoc chan error fd *netFD } @@ -70,7 +71,7 @@ func (o *anOp) Init(fd *netFD, mode int) { } o.resultc = fd.resultc[i] if fd.errnoc[i] == nil { - fd.errnoc[i] = make(chan int) + fd.errnoc[i] = make(chan error) } o.errnoc = fd.errnoc[i] } @@ -110,14 +111,14 @@ func (s *resultSrv) Run() { for { r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE) switch { - case r.err == 0: + case r.err == nil: // Dequeued successfully completed io packet. - case r.err == syscall.WAIT_TIMEOUT && o == nil: + case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil: // Wait has timed out (should not happen now, but might be used in the future). panic("GetQueuedCompletionStatus timed out") case o == nil: // Failed to dequeue anything -> report the error. - panic("GetQueuedCompletionStatus failed " + syscall.Errstr(r.err)) + panic("GetQueuedCompletionStatus failed " + r.err.Error()) default: // Dequeued failed io packet. } @@ -149,12 +150,13 @@ func (s *ioSrv) ProcessRemoteIO() { } // ExecIO executes a single io operation. It either executes it -// inline, or, if timeouts are employed, passes the request onto +// inline, or, if a deadline is employed, passes the request onto // a special goroutine and waits for completion or cancels request. -func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err os.Error) { - var e int +// deadline is unix nanos. +func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (n int, err error) { + var e error o := oi.Op() - if deadline_delta > 0 { + if deadline != 0 { // Send request to a special dedicated thread, // so it can stop the io with CancelIO later. s.submchan <- oi @@ -163,19 +165,25 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err os.Error) e = oi.Submit() } switch e { - case 0: + case nil: // IO completed immediately, but we need to get our completion message anyway. case syscall.ERROR_IO_PENDING: // IO started, and we have to wait for its completion. default: - return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(e)} + return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, e} } // Wait for our request to complete. var r ioResult - if deadline_delta > 0 { + if deadline != 0 { + dt := deadline - time.Now().UnixNano() + if dt < 1 { + dt = 1 + } + timer := time.NewTimer(time.Duration(dt) * time.Nanosecond) + defer timer.Stop() select { case r = <-o.resultc: - case <-time.After(deadline_delta): + case <-timer.C: s.canchan <- oi <-o.errnoc r = <-o.resultc @@ -186,8 +194,8 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err os.Error) } else { r = <-o.resultc } - if r.err != 0 { - err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(r.err)} + if r.err != nil { + err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err} } return int(r.qty), err } @@ -199,10 +207,10 @@ var onceStartServer sync.Once func startServer() { resultsrv = new(resultSrv) - var errno int - resultsrv.iocp, errno = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1) - if errno != 0 { - panic("CreateIoCompletionPort failed " + syscall.Errstr(errno)) + var err error + resultsrv.iocp, err = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1) + if err != nil { + panic("CreateIoCompletionPort: " + err.Error()) } go resultsrv.Run() @@ -220,43 +228,42 @@ type netFD struct { closing bool // immutable until Close - sysfd syscall.Handle - family int - proto int - net string - laddr Addr - raddr Addr - resultc [2]chan ioResult // read/write completion results - errnoc [2]chan int // read/write submit or cancel operation errors + sysfd syscall.Handle + family int + sotype int + isConnected bool + net string + laddr Addr + raddr Addr + resultc [2]chan ioResult // read/write completion results + errnoc [2]chan error // read/write submit or cancel operation errors // owned by client - rdeadline_delta int64 - rdeadline int64 - rio sync.Mutex - wdeadline_delta int64 - wdeadline int64 - wio sync.Mutex + rdeadline int64 + rio sync.Mutex + wdeadline int64 + wio sync.Mutex } -func allocFD(fd syscall.Handle, family, proto int, net string) (f *netFD) { +func allocFD(fd syscall.Handle, family, sotype int, net string) (f *netFD) { f = &netFD{ sysfd: fd, family: family, - proto: proto, + sotype: sotype, net: net, } runtime.SetFinalizer(f, (*netFD).Close) return f } -func newFD(fd syscall.Handle, family, proto int, net string) (f *netFD, err os.Error) { +func newFD(fd syscall.Handle, family, proto int, net string) (f *netFD, err error) { if initErr != nil { return nil, initErr } onceStartServer.Do(startServer) // Associate our socket with resultsrv.iocp. - if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != 0 { - return nil, os.Errno(e) + if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != nil { + return nil, e } return allocFD(fd, family, proto, net), nil } @@ -266,12 +273,8 @@ func (fd *netFD) setAddr(laddr, raddr Addr) { fd.raddr = raddr } -func (fd *netFD) connect(ra syscall.Sockaddr) (err os.Error) { - e := syscall.Connect(fd.sysfd, ra) - if e != 0 { - return os.Errno(e) - } - return nil +func (fd *netFD) connect(ra syscall.Sockaddr) (err error) { + return syscall.Connect(fd.sysfd, ra) } // Add a reference to this fd. @@ -300,7 +303,7 @@ func (fd *netFD) decref() { fd.sysmu.Unlock() } -func (fd *netFD) Close() os.Error { +func (fd *netFD) Close() error { if fd == nil || fd.sysfd == syscall.InvalidHandle { return os.EINVAL } @@ -312,13 +315,32 @@ func (fd *netFD) Close() os.Error { return nil } +func (fd *netFD) shutdown(how int) error { + if fd == nil || fd.sysfd == syscall.InvalidHandle { + return os.EINVAL + } + err := syscall.Shutdown(fd.sysfd, how) + if err != nil { + return &OpError{"shutdown", fd.net, fd.laddr, err} + } + return nil +} + +func (fd *netFD) CloseRead() error { + return fd.shutdown(syscall.SHUT_RD) +} + +func (fd *netFD) CloseWrite() error { + return fd.shutdown(syscall.SHUT_WR) +} + // Read from network. type readOp struct { bufOp } -func (o *readOp) Submit() (errno int) { +func (o *readOp) Submit() (err error) { var d, f uint32 return syscall.WSARecv(syscall.Handle(o.fd.sysfd), &o.buf, 1, &d, &f, &o.o, nil) } @@ -327,7 +349,7 @@ func (o *readOp) Name() string { return "WSARecv" } -func (fd *netFD) Read(buf []byte) (n int, err os.Error) { +func (fd *netFD) Read(buf []byte) (n int, err error) { if fd == nil { return 0, os.EINVAL } @@ -340,9 +362,9 @@ func (fd *netFD) Read(buf []byte) (n int, err os.Error) { } var o readOp o.Init(fd, buf, 'r') - n, err = iosrv.ExecIO(&o, fd.rdeadline_delta) + n, err = iosrv.ExecIO(&o, fd.rdeadline) if err == nil && n == 0 { - err = os.EOF + err = io.EOF } return } @@ -355,7 +377,7 @@ type readFromOp struct { rsan int32 } -func (o *readFromOp) Submit() (errno int) { +func (o *readFromOp) Submit() (err error) { var d, f uint32 return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &d, &f, &o.rsa, &o.rsan, &o.o, nil) } @@ -364,7 +386,7 @@ func (o *readFromOp) Name() string { return "WSARecvFrom" } -func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err os.Error) { +func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) { if fd == nil { return 0, nil, os.EINVAL } @@ -381,7 +403,7 @@ func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err os.Error) var o readFromOp o.Init(fd, buf, 'r') o.rsan = int32(unsafe.Sizeof(o.rsa)) - n, err = iosrv.ExecIO(&o, fd.rdeadline_delta) + n, err = iosrv.ExecIO(&o, fd.rdeadline) if err != nil { return 0, nil, err } @@ -395,7 +417,7 @@ type writeOp struct { bufOp } -func (o *writeOp) Submit() (errno int) { +func (o *writeOp) Submit() (err error) { var d uint32 return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &d, 0, &o.o, nil) } @@ -404,7 +426,7 @@ func (o *writeOp) Name() string { return "WSASend" } -func (fd *netFD) Write(buf []byte) (n int, err os.Error) { +func (fd *netFD) Write(buf []byte) (n int, err error) { if fd == nil { return 0, os.EINVAL } @@ -417,7 +439,7 @@ func (fd *netFD) Write(buf []byte) (n int, err os.Error) { } var o writeOp o.Init(fd, buf, 'w') - return iosrv.ExecIO(&o, fd.wdeadline_delta) + return iosrv.ExecIO(&o, fd.wdeadline) } // WriteTo to network. @@ -427,7 +449,7 @@ type writeToOp struct { sa syscall.Sockaddr } -func (o *writeToOp) Submit() (errno int) { +func (o *writeToOp) Submit() (err error) { var d uint32 return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &d, 0, o.sa, &o.o, nil) } @@ -436,7 +458,7 @@ func (o *writeToOp) Name() string { return "WSASendto" } -func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (n int, err os.Error) { +func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (n int, err error) { if fd == nil { return 0, os.EINVAL } @@ -453,7 +475,7 @@ func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (n int, err os.Error) var o writeToOp o.Init(fd, buf, 'w') o.sa = sa - return iosrv.ExecIO(&o, fd.wdeadline_delta) + return iosrv.ExecIO(&o, fd.wdeadline) } // Accept new network connections. @@ -464,7 +486,7 @@ type acceptOp struct { attrs [2]syscall.RawSockaddrAny // space for local and remote address only } -func (o *acceptOp) Submit() (errno int) { +func (o *acceptOp) Submit() (err error) { var d uint32 l := uint32(unsafe.Sizeof(o.attrs[0])) return syscall.AcceptEx(o.fd.sysfd, o.newsock, @@ -475,7 +497,7 @@ func (o *acceptOp) Name() string { return "AcceptEx" } -func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) { +func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err error) { if fd == nil || fd.sysfd == syscall.InvalidHandle { return nil, os.EINVAL } @@ -485,18 +507,18 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. // Get new socket. // See ../syscall/exec.go for description of ForkLock. syscall.ForkLock.RLock() - s, e := syscall.Socket(fd.family, fd.proto, 0) - if e != 0 { + s, e := syscall.Socket(fd.family, fd.sotype, 0) + if e != nil { syscall.ForkLock.RUnlock() - return nil, os.Errno(e) + return nil, e } syscall.CloseOnExec(s) syscall.ForkLock.RUnlock() // Associate our new socket with IOCP. onceStartServer.Do(startServer) - if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != 0 { - return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, os.Errno(e)} + if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != nil { + return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, e} } // Submit accept request. @@ -511,9 +533,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. // Inherit properties of the listening socket. e = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))) - if e != 0 { + if e != nil { closesocket(s) - return nil, err + return nil, e } // Get local and peer addr out of AcceptEx buffer. @@ -525,22 +547,22 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. lsa, _ := lrsa.Sockaddr() rsa, _ := rrsa.Sockaddr() - nfd = allocFD(s, fd.family, fd.proto, fd.net) + nfd = allocFD(s, fd.family, fd.sotype, fd.net) nfd.setAddr(toAddr(lsa), toAddr(rsa)) return nfd, nil } // Unimplemented functions. -func (fd *netFD) dup() (f *os.File, err os.Error) { +func (fd *netFD) dup() (f *os.File, err error) { // TODO: Implement this return nil, os.NewSyscallError("dup", syscall.EWINDOWS) } -func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) { +func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { return 0, 0, 0, nil, os.EAFNOSUPPORT } -func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) { +func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { return 0, 0, os.EAFNOSUPPORT } diff --git a/src/pkg/net/file.go b/src/pkg/net/file.go index d8528e41b..4ac280bd1 100644 --- a/src/pkg/net/file.go +++ b/src/pkg/net/file.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd +// +build darwin freebsd linux netbsd openbsd package net @@ -11,17 +11,18 @@ import ( "syscall" ) -func newFileFD(f *os.File) (nfd *netFD, err os.Error) { +func newFileFD(f *os.File) (nfd *netFD, err error) { fd, errno := syscall.Dup(f.Fd()) - if errno != 0 { + if errno != nil { return nil, os.NewSyscallError("dup", errno) } proto, errno := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE) - if errno != 0 { + if errno != nil { return nil, os.NewSyscallError("getsockopt", errno) } + family := syscall.AF_UNSPEC toAddr := sockaddrToTCP sa, _ := syscall.Getsockname(fd) switch sa.(type) { @@ -29,18 +30,21 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) { closesocket(fd) return nil, os.EINVAL case *syscall.SockaddrInet4: + family = syscall.AF_INET if proto == syscall.SOCK_DGRAM { toAddr = sockaddrToUDP } else if proto == syscall.SOCK_RAW { toAddr = sockaddrToIP } case *syscall.SockaddrInet6: + family = syscall.AF_INET6 if proto == syscall.SOCK_DGRAM { toAddr = sockaddrToUDP } else if proto == syscall.SOCK_RAW { toAddr = sockaddrToIP } case *syscall.SockaddrUnix: + family = syscall.AF_UNIX toAddr = sockaddrToUnix if proto == syscall.SOCK_DGRAM { toAddr = sockaddrToUnixgram @@ -52,7 +56,7 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) { sa, _ = syscall.Getpeername(fd) raddr := toAddr(sa) - if nfd, err = newFD(fd, 0, proto, laddr.Network()); err != nil { + if nfd, err = newFD(fd, family, proto, laddr.Network()); err != nil { return nil, err } nfd.setAddr(laddr, raddr) @@ -63,7 +67,7 @@ func newFileFD(f *os.File) (nfd *netFD, err os.Error) { // the open file f. It is the caller's responsibility to close f when // finished. Closing c does not affect f, and closing f does not // affect c. -func FileConn(f *os.File) (c Conn, err os.Error) { +func FileConn(f *os.File) (c Conn, err error) { fd, err := newFileFD(f) if err != nil { return nil, err @@ -86,7 +90,7 @@ func FileConn(f *os.File) (c Conn, err os.Error) { // to the open file f. It is the caller's responsibility to close l // when finished. Closing c does not affect l, and closing l does not // affect c. -func FileListener(f *os.File) (l Listener, err os.Error) { +func FileListener(f *os.File) (l Listener, err error) { fd, err := newFileFD(f) if err != nil { return nil, err @@ -105,7 +109,7 @@ func FileListener(f *os.File) (l Listener, err os.Error) { // corresponding to the open file f. It is the caller's // responsibility to close f when finished. Closing c does not affect // f, and closing f does not affect c. -func FilePacketConn(f *os.File) (c PacketConn, err os.Error) { +func FilePacketConn(f *os.File) (c PacketConn, err error) { fd, err := newFileFD(f) if err != nil { return nil, err diff --git a/src/pkg/net/file_plan9.go b/src/pkg/net/file_plan9.go index a07e74331..06d7cc898 100644 --- a/src/pkg/net/file_plan9.go +++ b/src/pkg/net/file_plan9.go @@ -12,7 +12,7 @@ import ( // the open file f. It is the caller's responsibility to close f when // finished. Closing c does not affect f, and closing f does not // affect c. -func FileConn(f *os.File) (c Conn, err os.Error) { +func FileConn(f *os.File) (c Conn, err error) { return nil, os.EPLAN9 } @@ -20,7 +20,7 @@ func FileConn(f *os.File) (c Conn, err os.Error) { // to the open file f. It is the caller's responsibility to close l // when finished. Closing c does not affect l, and closing l does not // affect c. -func FileListener(f *os.File) (l Listener, err os.Error) { +func FileListener(f *os.File) (l Listener, err error) { return nil, os.EPLAN9 } @@ -28,6 +28,6 @@ func FileListener(f *os.File) (l Listener, err os.Error) { // corresponding to the open file f. It is the caller's // responsibility to close f when finished. Closing c does not affect // f, and closing f does not affect c. -func FilePacketConn(f *os.File) (c PacketConn, err os.Error) { +func FilePacketConn(f *os.File) (c PacketConn, err error) { return nil, os.EPLAN9 } diff --git a/src/pkg/net/file_test.go b/src/pkg/net/file_test.go index 9a8c2dcbc..868388efa 100644 --- a/src/pkg/net/file_test.go +++ b/src/pkg/net/file_test.go @@ -8,23 +8,22 @@ import ( "os" "reflect" "runtime" - "syscall" "testing" ) type listenerFile interface { Listener - File() (f *os.File, err os.Error) + File() (f *os.File, err error) } type packetConnFile interface { PacketConn - File() (f *os.File, err os.Error) + File() (f *os.File, err error) } type connFile interface { Conn - File() (f *os.File, err os.Error) + File() (f *os.File, err error) } func testFileListener(t *testing.T, net, laddr string) { @@ -67,13 +66,13 @@ func TestFileListener(t *testing.T) { testFileListener(t, "tcp", "127.0.0.1") testFileListener(t, "tcp", "[::ffff:127.0.0.1]") } - if syscall.OS == "linux" { + if runtime.GOOS == "linux" { testFileListener(t, "unix", "@gotest/net") testFileListener(t, "unixpacket", "@gotest/net") } } -func testFilePacketConn(t *testing.T, pcf packetConnFile) { +func testFilePacketConn(t *testing.T, pcf packetConnFile, listen bool) { f, err := pcf.File() if err != nil { t.Fatalf("File failed: %v", err) @@ -85,6 +84,11 @@ func testFilePacketConn(t *testing.T, pcf packetConnFile) { if !reflect.DeepEqual(pcf.LocalAddr(), c.LocalAddr()) { t.Fatalf("LocalAddrs not equal: %#v != %#v", pcf.LocalAddr(), c.LocalAddr()) } + if listen { + if _, err := c.WriteTo([]byte{}, c.LocalAddr()); err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + } if err := c.Close(); err != nil { t.Fatalf("Close failed: %v", err) } @@ -98,7 +102,7 @@ func testFilePacketConnListen(t *testing.T, net, laddr string) { if err != nil { t.Fatalf("Listen failed: %v", err) } - testFilePacketConn(t, l.(packetConnFile)) + testFilePacketConn(t, l.(packetConnFile), true) if err := l.Close(); err != nil { t.Fatalf("Close failed: %v", err) } @@ -109,7 +113,7 @@ func testFilePacketConnDial(t *testing.T, net, raddr string) { if err != nil { t.Fatalf("Dial failed: %v", err) } - testFilePacketConn(t, c.(packetConnFile)) + testFilePacketConn(t, c.(packetConnFile), false) if err := c.Close(); err != nil { t.Fatalf("Close failed: %v", err) } @@ -127,7 +131,7 @@ func TestFilePacketConn(t *testing.T) { if supportsIPv6 && supportsIPv4map { testFilePacketConnDial(t, "udp", "[::ffff:127.0.0.1]:12345") } - if syscall.OS == "linux" { + if runtime.GOOS == "linux" { testFilePacketConnListen(t, "unixgram", "@gotest1/net") } } diff --git a/src/pkg/net/file_windows.go b/src/pkg/net/file_windows.go index 94aa58375..c50c32e21 100644 --- a/src/pkg/net/file_windows.go +++ b/src/pkg/net/file_windows.go @@ -9,17 +9,17 @@ import ( "syscall" ) -func FileConn(f *os.File) (c Conn, err os.Error) { +func FileConn(f *os.File) (c Conn, err error) { // TODO: Implement this return nil, os.NewSyscallError("FileConn", syscall.EWINDOWS) } -func FileListener(f *os.File) (l Listener, err os.Error) { +func FileListener(f *os.File) (l Listener, err error) { // TODO: Implement this return nil, os.NewSyscallError("FileListener", syscall.EWINDOWS) } -func FilePacketConn(f *os.File) (c PacketConn, err os.Error) { +func FilePacketConn(f *os.File) (c PacketConn, err error) { // TODO: Implement this return nil, os.NewSyscallError("FilePacketConn", syscall.EWINDOWS) } diff --git a/src/pkg/net/hosts.go b/src/pkg/net/hosts.go index d75e9e038..e6674ba34 100644 --- a/src/pkg/net/hosts.go +++ b/src/pkg/net/hosts.go @@ -7,11 +7,11 @@ package net import ( - "os" "sync" + "time" ) -const cacheMaxAge = int64(300) // 5 minutes. +const cacheMaxAge = 5 * time.Minute // hostsPath points to the file with static IP/address entries. var hostsPath = "/etc/hosts" @@ -21,14 +21,14 @@ var hosts struct { sync.Mutex byName map[string][]string byAddr map[string][]string - time int64 + expire time.Time path string } func readHosts() { - now, _, _ := os.Time() + now := time.Now() hp := hostsPath - if len(hosts.byName) == 0 || hosts.time+cacheMaxAge <= now || hosts.path != hp { + if len(hosts.byName) == 0 || now.After(hosts.expire) || hosts.path != hp { hs := make(map[string][]string) is := make(map[string][]string) var file *file @@ -51,7 +51,7 @@ func readHosts() { } } // Update the data cache. - hosts.time, _, _ = os.Time() + hosts.expire = time.Now().Add(cacheMaxAge) hosts.path = hp hosts.byName = hs hosts.byAddr = is diff --git a/src/pkg/net/http/Makefile b/src/pkg/net/http/Makefile new file mode 100644 index 000000000..5c351b0c4 --- /dev/null +++ b/src/pkg/net/http/Makefile @@ -0,0 +1,25 @@ +# Copyright 2009 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=net/http +GOFILES=\ + chunked.go\ + client.go\ + cookie.go\ + filetransport.go\ + fs.go\ + header.go\ + jar.go\ + lex.go\ + request.go\ + response.go\ + server.go\ + sniff.go\ + status.go\ + transfer.go\ + transport.go\ + +include ../../../Make.pkg diff --git a/src/pkg/net/http/cgi/Makefile b/src/pkg/net/http/cgi/Makefile new file mode 100644 index 000000000..0d6be0180 --- /dev/null +++ b/src/pkg/net/http/cgi/Makefile @@ -0,0 +1,12 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../../Make.inc + +TARG=net/http/cgi +GOFILES=\ + child.go\ + host.go\ + +include ../../../../Make.pkg diff --git a/src/pkg/net/http/cgi/child.go b/src/pkg/net/http/cgi/child.go new file mode 100644 index 000000000..e6c3ef911 --- /dev/null +++ b/src/pkg/net/http/cgi/child.go @@ -0,0 +1,192 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements CGI from the perspective of a child +// process. + +package cgi + +import ( + "bufio" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "os" + "strconv" + "strings" +) + +// Request returns the HTTP request as represented in the current +// environment. This assumes the current program is being run +// by a web server in a CGI environment. +// The returned Request's Body is populated, if applicable. +func Request() (*http.Request, error) { + r, err := RequestFromMap(envMap(os.Environ())) + if err != nil { + return nil, err + } + if r.ContentLength > 0 { + r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, r.ContentLength)) + } + return r, nil +} + +func envMap(env []string) map[string]string { + m := make(map[string]string) + for _, kv := range env { + if idx := strings.Index(kv, "="); idx != -1 { + m[kv[:idx]] = kv[idx+1:] + } + } + return m +} + +// RequestFromMap creates an http.Request from CGI variables. +// The returned Request's Body field is not populated. +func RequestFromMap(params map[string]string) (*http.Request, error) { + r := new(http.Request) + r.Method = params["REQUEST_METHOD"] + if r.Method == "" { + return nil, errors.New("cgi: no REQUEST_METHOD in environment") + } + + r.Proto = params["SERVER_PROTOCOL"] + var ok bool + r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto) + if !ok { + return nil, errors.New("cgi: invalid SERVER_PROTOCOL version") + } + + r.Close = true + r.Trailer = http.Header{} + r.Header = http.Header{} + + r.Host = params["HTTP_HOST"] + + if lenstr := params["CONTENT_LENGTH"]; lenstr != "" { + clen, err := strconv.ParseInt(lenstr, 10, 64) + if err != nil { + return nil, errors.New("cgi: bad CONTENT_LENGTH in environment: " + lenstr) + } + r.ContentLength = clen + } + + if ct := params["CONTENT_TYPE"]; ct != "" { + r.Header.Set("Content-Type", ct) + } + + // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers + for k, v := range params { + if !strings.HasPrefix(k, "HTTP_") || k == "HTTP_HOST" { + continue + } + r.Header.Add(strings.Replace(k[5:], "_", "-", -1), v) + } + + // TODO: cookies. parsing them isn't exported, though. + + if r.Host != "" { + // Hostname is provided, so we can reasonably construct a URL, + // even if we have to assume 'http' for the scheme. + rawurl := "http://" + r.Host + params["REQUEST_URI"] + url, err := url.Parse(rawurl) + if err != nil { + return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl) + } + r.URL = url + } + // Fallback logic if we don't have a Host header or the URL + // failed to parse + if r.URL == nil { + uriStr := params["REQUEST_URI"] + url, err := url.Parse(uriStr) + if err != nil { + return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr) + } + r.URL = url + } + + // There's apparently a de-facto standard for this. + // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 + if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" { + r.TLS = &tls.ConnectionState{HandshakeComplete: true} + } + + // Request.RemoteAddr has its port set by Go's standard http + // server, so we do here too. We don't have one, though, so we + // use a dummy one. + r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], "0") + + return r, nil +} + +// Serve executes the provided Handler on the currently active CGI +// request, if any. If there's no current CGI environment +// an error is returned. The provided handler may be nil to use +// http.DefaultServeMux. +func Serve(handler http.Handler) error { + req, err := Request() + if err != nil { + return err + } + if handler == nil { + handler = http.DefaultServeMux + } + rw := &response{ + req: req, + header: make(http.Header), + bufw: bufio.NewWriter(os.Stdout), + } + handler.ServeHTTP(rw, req) + if err = rw.bufw.Flush(); err != nil { + return err + } + return nil +} + +type response struct { + req *http.Request + header http.Header + bufw *bufio.Writer + headerSent bool +} + +func (r *response) Flush() { + r.bufw.Flush() +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(p []byte) (n int, err error) { + if !r.headerSent { + r.WriteHeader(http.StatusOK) + } + return r.bufw.Write(p) +} + +func (r *response) WriteHeader(code int) { + if r.headerSent { + // Note: explicitly using Stderr, as Stdout is our HTTP output. + fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL) + return + } + r.headerSent = true + fmt.Fprintf(r.bufw, "Status: %d %s\r\n", code, http.StatusText(code)) + + // Set a default Content-Type + if _, hasType := r.header["Content-Type"]; !hasType { + r.header.Add("Content-Type", "text/html; charset=utf-8") + } + + r.header.Write(r.bufw) + r.bufw.WriteString("\r\n") + r.bufw.Flush() +} diff --git a/src/pkg/net/http/cgi/child_test.go b/src/pkg/net/http/cgi/child_test.go new file mode 100644 index 000000000..ec53ab851 --- /dev/null +++ b/src/pkg/net/http/cgi/child_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for CGI (the child process perspective) + +package cgi + +import ( + "testing" +) + +func TestRequest(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "REQUEST_METHOD": "GET", + "HTTP_HOST": "example.com", + "HTTP_REFERER": "elsewhere", + "HTTP_USER_AGENT": "goclient", + "HTTP_FOO_BAR": "baz", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", + "CONTENT_TYPE": "text/xml", + "HTTPS": "1", + "REMOTE_ADDR": "5.6.7.8", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if g, e := req.UserAgent(), "goclient"; e != g { + t.Errorf("expected UserAgent %q; got %q", e, g) + } + if g, e := req.Method, "GET"; e != g { + t.Errorf("expected Method %q; got %q", e, g) + } + if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g { + t.Errorf("expected Content-Type %q; got %q", e, g) + } + if g, e := req.ContentLength, int64(123); e != g { + t.Errorf("expected ContentLength %d; got %d", e, g) + } + if g, e := req.Referer(), "elsewhere"; e != g { + t.Errorf("expected Referer %q; got %q", e, g) + } + if req.Header == nil { + t.Fatalf("unexpected nil Header") + } + if g, e := req.Header.Get("Foo-Bar"), "baz"; e != g { + t.Errorf("expected Foo-Bar %q; got %q", e, g) + } + if g, e := req.URL.String(), "http://example.com/path?a=b"; e != g { + t.Errorf("expected URL %q; got %q", e, g) + } + if g, e := req.FormValue("a"), "b"; e != g { + t.Errorf("expected FormValue(a) %q; got %q", e, g) + } + if req.Trailer == nil { + t.Errorf("unexpected nil Trailer") + } + if req.TLS == nil { + t.Errorf("expected non-nil TLS") + } + if e, g := "5.6.7.8:0", req.RemoteAddr; e != g { + t.Errorf("RemoteAddr: got %q; want %q", g, e) + } +} + +func TestRequestWithoutHost(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "", + "REQUEST_METHOD": "GET", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if req.URL == nil { + t.Fatalf("unexpected nil URL") + } + if g, e := req.URL.String(), "/path?a=b"; e != g { + t.Errorf("expected URL %q; got %q", e, g) + } +} diff --git a/src/pkg/net/http/cgi/host.go b/src/pkg/net/http/cgi/host.go new file mode 100644 index 000000000..73a9b6ea6 --- /dev/null +++ b/src/pkg/net/http/cgi/host.go @@ -0,0 +1,350 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements the host side of CGI (being the webserver +// parent process). + +// Package cgi implements CGI (Common Gateway Interface) as specified +// in RFC 3875. +// +// Note that using CGI means starting a new process to handle each +// request, which is typically less efficient than using a +// long-running server. This package is intended primarily for +// compatibility with existing systems. +package cgi + +import ( + "bufio" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" +) + +var trailingPort = regexp.MustCompile(`:([0-9]+)$`) + +var osDefaultInheritEnv = map[string][]string{ + "darwin": {"DYLD_LIBRARY_PATH"}, + "freebsd": {"LD_LIBRARY_PATH"}, + "hpux": {"LD_LIBRARY_PATH", "SHLIB_PATH"}, + "irix": {"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"}, + "linux": {"LD_LIBRARY_PATH"}, + "openbsd": {"LD_LIBRARY_PATH"}, + "solaris": {"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"}, + "windows": {"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}, +} + +// Handler runs an executable in a subprocess with a CGI environment. +type Handler struct { + Path string // path to the CGI executable + Root string // root URI prefix of handler or empty for "/" + + // Dir specifies the CGI executable's working directory. + // If Dir is empty, the base directory of Path is used. + // If Path has no base directory, the current working + // directory is used. + Dir string + + Env []string // extra environment variables to set, if any, as "key=value" + InheritEnv []string // environment variables to inherit from host, as "key" + Logger *log.Logger // optional log for errors or nil to use log.Print + Args []string // optional arguments to pass to child process + + // PathLocationHandler specifies the root http Handler that + // should handle internal redirects when the CGI process + // returns a Location header value starting with a "/", as + // specified in RFC 3875 § 6.3.2. This will likely be + // http.DefaultServeMux. + // + // If nil, a CGI response with a local URI path is instead sent + // back to the client and not redirected internally. + PathLocationHandler http.Handler +} + +// removeLeadingDuplicates remove leading duplicate in environments. +// It's possible to override environment like following. +// cgi.Handler{ +// ... +// Env: []string{"SCRIPT_FILENAME=foo.php"}, +// } +func removeLeadingDuplicates(env []string) (ret []string) { + n := len(env) + for i := 0; i < n; i++ { + e := env[i] + s := strings.SplitN(e, "=", 2)[0] + found := false + for j := i + 1; j < n; j++ { + if s == strings.SplitN(env[j], "=", 2)[0] { + found = true + break + } + } + if !found { + ret = append(ret, e) + } + } + return +} + +func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + root := h.Root + if root == "" { + root = "/" + } + + if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" { + rw.WriteHeader(http.StatusBadRequest) + rw.Write([]byte("Chunked request bodies are not supported by CGI.")) + return + } + + pathInfo := req.URL.Path + if root != "/" && strings.HasPrefix(pathInfo, root) { + pathInfo = pathInfo[len(root):] + } + + port := "80" + if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 { + port = matches[1] + } + + env := []string{ + "SERVER_SOFTWARE=go", + "SERVER_NAME=" + req.Host, + "SERVER_PROTOCOL=HTTP/1.1", + "HTTP_HOST=" + req.Host, + "GATEWAY_INTERFACE=CGI/1.1", + "REQUEST_METHOD=" + req.Method, + "QUERY_STRING=" + req.URL.RawQuery, + "REQUEST_URI=" + req.URL.RequestURI(), + "PATH_INFO=" + pathInfo, + "SCRIPT_NAME=" + root, + "SCRIPT_FILENAME=" + h.Path, + "REMOTE_ADDR=" + req.RemoteAddr, + "REMOTE_HOST=" + req.RemoteAddr, + "SERVER_PORT=" + port, + } + + if req.TLS != nil { + env = append(env, "HTTPS=on") + } + + for k, v := range req.Header { + k = strings.Map(upperCaseAndUnderscore, k) + joinStr := ", " + if k == "COOKIE" { + joinStr = "; " + } + env = append(env, "HTTP_"+k+"="+strings.Join(v, joinStr)) + } + + if req.ContentLength > 0 { + env = append(env, fmt.Sprintf("CONTENT_LENGTH=%d", req.ContentLength)) + } + if ctype := req.Header.Get("Content-Type"); ctype != "" { + env = append(env, "CONTENT_TYPE="+ctype) + } + + if h.Env != nil { + env = append(env, h.Env...) + } + + envPath := os.Getenv("PATH") + if envPath == "" { + envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin" + } + env = append(env, "PATH="+envPath) + + for _, e := range h.InheritEnv { + if v := os.Getenv(e); v != "" { + env = append(env, e+"="+v) + } + } + + for _, e := range osDefaultInheritEnv[runtime.GOOS] { + if v := os.Getenv(e); v != "" { + env = append(env, e+"="+v) + } + } + + env = removeLeadingDuplicates(env) + + var cwd, path string + if h.Dir != "" { + path = h.Path + cwd = h.Dir + } else { + cwd, path = filepath.Split(h.Path) + } + if cwd == "" { + cwd = "." + } + + internalError := func(err error) { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("CGI error: %v", err) + } + + cmd := &exec.Cmd{ + Path: path, + Args: append([]string{h.Path}, h.Args...), + Dir: cwd, + Env: env, + Stderr: os.Stderr, // for now + } + if req.ContentLength != 0 { + cmd.Stdin = req.Body + } + stdoutRead, err := cmd.StdoutPipe() + if err != nil { + internalError(err) + return + } + + err = cmd.Start() + if err != nil { + internalError(err) + return + } + defer cmd.Wait() + defer stdoutRead.Close() + + linebody, _ := bufio.NewReaderSize(stdoutRead, 1024) + headers := make(http.Header) + statusCode := 0 + for { + line, isPrefix, err := linebody.ReadLine() + if isPrefix { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: long header line from subprocess.") + return + } + if err == io.EOF { + break + } + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: error reading headers: %v", err) + return + } + if len(line) == 0 { + break + } + parts := strings.SplitN(string(line), ":", 2) + if len(parts) < 2 { + h.printf("cgi: bogus header line: %s", string(line)) + continue + } + header, val := parts[0], parts[1] + header = strings.TrimSpace(header) + val = strings.TrimSpace(val) + switch { + case header == "Status": + if len(val) < 3 { + h.printf("cgi: bogus status (short): %q", val) + return + } + code, err := strconv.Atoi(val[0:3]) + if err != nil { + h.printf("cgi: bogus status: %q", val) + h.printf("cgi: line was %q", line) + return + } + statusCode = code + default: + headers.Add(header, val) + } + } + + if loc := headers.Get("Location"); loc != "" { + if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil { + h.handleInternalRedirect(rw, req, loc) + return + } + if statusCode == 0 { + statusCode = http.StatusFound + } + } + + if statusCode == 0 { + statusCode = http.StatusOK + } + + // Copy headers to rw's headers, after we've decided not to + // go into handleInternalRedirect, which won't want its rw + // headers to have been touched. + for k, vv := range headers { + for _, v := range vv { + rw.Header().Add(k, v) + } + } + + rw.WriteHeader(statusCode) + + _, err = io.Copy(rw, linebody) + if err != nil { + h.printf("cgi: copy error: %v", err) + } +} + +func (h *Handler) printf(format string, v ...interface{}) { + if h.Logger != nil { + h.Logger.Printf(format, v...) + } else { + log.Printf(format, v...) + } +} + +func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) { + url, err := req.URL.Parse(path) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: error resolving local URI path %q: %v", path, err) + return + } + // TODO: RFC 3875 isn't clear if only GET is supported, but it + // suggests so: "Note that any message-body attached to the + // request (such as for a POST request) may not be available + // to the resource that is the target of the redirect." We + // should do some tests against Apache to see how it handles + // POST, HEAD, etc. Does the internal redirect get the same + // method or just GET? What about incoming headers? + // (e.g. Cookies) Which headers, if any, are copied into the + // second request? + newReq := &http.Request{ + Method: "GET", + URL: url, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: url.Host, + RemoteAddr: req.RemoteAddr, + TLS: req.TLS, + } + h.PathLocationHandler.ServeHTTP(rw, newReq) +} + +func upperCaseAndUnderscore(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r - ('a' - 'A') + case r == '-': + return '_' + case r == '=': + // Maybe not part of the CGI 'spec' but would mess up + // the environment in any case, as Go represents the + // environment as a slice of "key=value" strings. + return '_' + } + // TODO: other transformations in spec or practice? + return r +} diff --git a/src/pkg/net/http/cgi/host_test.go b/src/pkg/net/http/cgi/host_test.go new file mode 100644 index 000000000..9ef80ea5e --- /dev/null +++ b/src/pkg/net/http/cgi/host_test.go @@ -0,0 +1,477 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for package cgi + +package cgi + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "testing" + "time" +) + +func newRequest(httpreq string) *http.Request { + buf := bufio.NewReader(strings.NewReader(httpreq)) + req, err := http.ReadRequest(buf) + if err != nil { + panic("cgi: bogus http request in test: " + httpreq) + } + req.RemoteAddr = "1.2.3.4" + return req +} + +func runCgiTest(t *testing.T, h *Handler, httpreq string, expectedMap map[string]string) *httptest.ResponseRecorder { + rw := httptest.NewRecorder() + req := newRequest(httpreq) + h.ServeHTTP(rw, req) + + // Make a map to hold the test map that the CGI returns. + m := make(map[string]string) + linesRead := 0 +readlines: + for { + line, err := rw.Body.ReadString('\n') + switch { + case err == io.EOF: + break readlines + case err != nil: + t.Fatalf("unexpected error reading from CGI: %v", err) + } + linesRead++ + trimmedLine := strings.TrimRight(line, "\r\n") + split := strings.SplitN(trimmedLine, "=", 2) + if len(split) != 2 { + t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v", + len(split), linesRead, line, m) + } + m[split[0]] = split[1] + } + + for key, expected := range expectedMap { + if got := m[key]; got != expected { + t.Errorf("for key %q got %q; expected %q", key, got, expected) + } + } + return rw +} + +var cgiTested = false +var cgiWorks bool + +func skipTest(t *testing.T) bool { + if !cgiTested { + cgiTested = true + cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil + } + if !cgiWorks { + // No Perl on Windows, needed by test.cgi + // TODO: make the child process be Go, not Perl. + t.Logf("Skipping test: test.cgi failed.") + return true + } + return false +} + +func TestCGIBasicGet(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "test": "Hello CGI", + "param-a": "b", + "param-foo": "bar", + "env-GATEWAY_INTERFACE": "CGI/1.1", + "env-HTTP_HOST": "example.com", + "env-PATH_INFO": "", + "env-QUERY_STRING": "foo=bar&a=b", + "env-REMOTE_ADDR": "1.2.3.4", + "env-REMOTE_HOST": "1.2.3.4", + "env-REQUEST_METHOD": "GET", + "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + "env-SERVER_NAME": "example.com", + "env-SERVER_PORT": "80", + "env-SERVER_SOFTWARE": "go", + } + replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) + + if expected, got := "text/html", replay.Header().Get("Content-Type"); got != expected { + t.Errorf("got a Content-Type of %q; expected %q", got, expected) + } + if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { + t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) + } +} + +func TestCGIBasicGetAbsPath(t *testing.T) { + if skipTest(t) { + return + } + pwd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd error: %v", err) + } + h := &Handler{ + Path: pwd + "/testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "env-REQUEST_URI": "/test.cgi?foo=bar&a=b", + "env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + } + runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestPathInfo(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "param-a": "b", + "env-PATH_INFO": "/extrapath", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/test.cgi/extrapath?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/test.cgi", + } + runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestPathInfoDirRoot(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/myscript/", + } + expectedMap := map[string]string{ + "env-PATH_INFO": "bar", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/myscript/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/myscript/", + } + runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestDupHeaders(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + } + expectedMap := map[string]string{ + "env-REQUEST_URI": "/myscript/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-HTTP_COOKIE": "nom=NOM; yum=YUM", + "env-HTTP_X_FOO": "val1, val2", + } + runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+ + "Cookie: nom=NOM\n"+ + "Cookie: yum=YUM\n"+ + "X-Foo: val1\n"+ + "X-Foo: val2\n"+ + "Host: example.com\n\n", + expectedMap) +} + +func TestPathInfoNoRoot(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "", + } + expectedMap := map[string]string{ + "env-PATH_INFO": "/bar", + "env-QUERY_STRING": "a=b", + "env-REQUEST_URI": "/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-SCRIPT_NAME": "/", + } + runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestCGIBasicPost(t *testing.T) { + if skipTest(t) { + return + } + postReq := `POST /test.cgi?a=b HTTP/1.0 +Host: example.com +Content-Type: application/x-www-form-urlencoded +Content-Length: 15 + +postfoo=postbar` + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{ + "test": "Hello CGI", + "param-postfoo": "postbar", + "env-REQUEST_METHOD": "POST", + "env-CONTENT_LENGTH": "15", + "env-REQUEST_URI": "/test.cgi?a=b", + } + runCgiTest(t, h, postReq, expectedMap) +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +// The CGI spec doesn't allow chunked requests. +func TestCGIPostChunked(t *testing.T) { + if skipTest(t) { + return + } + postReq := `POST /test.cgi?a=b HTTP/1.1 +Host: example.com +Content-Type: application/x-www-form-urlencoded +Transfer-Encoding: chunked + +` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("") + + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap := map[string]string{} + resp := runCgiTest(t, h, postReq, expectedMap) + if got, expected := resp.Code, http.StatusBadRequest; got != expected { + t.Fatalf("Expected %v response code from chunked request body; got %d", + expected, got) + } +} + +func TestRedirect(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil) + if e, g := 302, rec.Code; e != g { + t.Errorf("expected status code %d; got %d", e, g) + } + if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g { + t.Errorf("expected Location header of %q; got %q", e, g) + } +} + +func TestInternalRedirect(t *testing.T) { + if skipTest(t) { + return + } + baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path) + fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr) + }) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + PathLocationHandler: baseHandler, + } + expectedMap := map[string]string{ + "basepath": "/foo", + "remoteaddr": "1.2.3.4", + } + runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +// TestCopyError tests that we kill the process if there's an error copying +// its output. (for example, from the client having gone away) +func TestCopyError(t *testing.T) { + if skipTest(t) || runtime.GOOS == "windows" { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + ts := httptest.NewServer(h) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + req, _ := http.NewRequest("GET", "http://example.com/test.cgi?bigresponse=1", nil) + err = req.Write(conn) + if err != nil { + t.Fatalf("Write: %v", err) + } + + res, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + t.Fatalf("ReadResponse: %v", err) + } + + pidstr := res.Header.Get("X-CGI-Pid") + if pidstr == "" { + t.Fatalf("expected an X-CGI-Pid header in response") + } + pid, err := strconv.Atoi(pidstr) + if err != nil { + t.Fatalf("invalid X-CGI-Pid value") + } + + var buf [5000]byte + n, err := io.ReadFull(res.Body, buf[:]) + if err != nil { + t.Fatalf("ReadFull: %d bytes, %v", n, err) + } + + childRunning := func() bool { + p, err := os.FindProcess(pid) + if err != nil { + return false + } + return p.Signal(os.UnixSignal(0)) == nil + } + + if !childRunning() { + t.Fatalf("pre-conn.Close, expected child to be running") + } + conn.Close() + + tries := 0 + for tries < 25 && childRunning() { + time.Sleep(50 * time.Millisecond * time.Duration(tries)) + tries++ + } + if childRunning() { + t.Fatalf("post-conn.Close, expected child to be gone") + } +} + +func TestDirUnix(t *testing.T) { + if skipTest(t) || runtime.GOOS == "windows" { + return + } + + cwd, _ := os.Getwd() + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + Dir: cwd, + } + expectedMap := map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) + + cwd, _ = os.Getwd() + cwd = filepath.Join(cwd, "testdata") + h = &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + expectedMap = map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestDirWindows(t *testing.T) { + if skipTest(t) || runtime.GOOS != "windows" { + return + } + + cgifile, _ := filepath.Abs("testdata/test.cgi") + + var perl string + var err error + perl, err = exec.LookPath("perl") + if err != nil { + return + } + perl, _ = filepath.Abs(perl) + + cwd, _ := os.Getwd() + h := &Handler{ + Path: perl, + Root: "/test.cgi", + Dir: cwd, + Args: []string{cgifile}, + Env: []string{"SCRIPT_FILENAME=" + cgifile}, + } + expectedMap := map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) + + // If not specify Dir on windows, working directory should be + // base directory of perl. + cwd, _ = filepath.Split(perl) + if cwd != "" && cwd[len(cwd)-1] == filepath.Separator { + cwd = cwd[:len(cwd)-1] + } + h = &Handler{ + Path: perl, + Root: "/test.cgi", + Args: []string{cgifile}, + Env: []string{"SCRIPT_FILENAME=" + cgifile}, + } + expectedMap = map[string]string{ + "cwd": cwd, + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} + +func TestEnvOverride(t *testing.T) { + cgifile, _ := filepath.Abs("testdata/test.cgi") + + var perl string + var err error + perl, err = exec.LookPath("perl") + if err != nil { + return + } + perl, _ = filepath.Abs(perl) + + cwd, _ := os.Getwd() + h := &Handler{ + Path: perl, + Root: "/test.cgi", + Dir: cwd, + Args: []string{cgifile}, + Env: []string{ + "SCRIPT_FILENAME=" + cgifile, + "REQUEST_URI=/foo/bar"}, + } + expectedMap := map[string]string{ + "cwd": cwd, + "env-SCRIPT_FILENAME": cgifile, + "env-REQUEST_URI": "/foo/bar", + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} diff --git a/src/pkg/net/http/cgi/matryoshka_test.go b/src/pkg/net/http/cgi/matryoshka_test.go new file mode 100644 index 000000000..1a44df204 --- /dev/null +++ b/src/pkg/net/http/cgi/matryoshka_test.go @@ -0,0 +1,74 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests a Go CGI program running under a Go CGI host process. +// Further, the two programs are the same binary, just checking +// their environment to figure out what mode to run in. + +package cgi + +import ( + "fmt" + "net/http" + "os" + "testing" +) + +// This test is a CGI host (testing host.go) that runs its own binary +// as a child process testing the other half of CGI (child.go). +func TestHostingOurselves(t *testing.T) { + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "test": "Hello CGI-in-CGI", + "param-a": "b", + "param-foo": "bar", + "env-GATEWAY_INTERFACE": "CGI/1.1", + "env-HTTP_HOST": "example.com", + "env-PATH_INFO": "", + "env-QUERY_STRING": "foo=bar&a=b", + "env-REMOTE_ADDR": "1.2.3.4", + "env-REMOTE_HOST": "1.2.3.4", + "env-REQUEST_METHOD": "GET", + "env-REQUEST_URI": "/test.go?foo=bar&a=b", + "env-SCRIPT_FILENAME": os.Args[0], + "env-SCRIPT_NAME": "/test.go", + "env-SERVER_NAME": "example.com", + "env-SERVER_PORT": "80", + "env-SERVER_SOFTWARE": "go", + } + replay := runCgiTest(t, h, "GET /test.go?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) + + if expected, got := "text/html; charset=utf-8", replay.Header().Get("Content-Type"); got != expected { + t.Errorf("got a Content-Type of %q; expected %q", got, expected) + } + if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { + t.Errorf("got a X-Test-Header of %q; expected %q", got, expected) + } +} + +// Note: not actually a test. +func TestBeChildCGIProcess(t *testing.T) { + if os.Getenv("REQUEST_METHOD") == "" { + // Not in a CGI environment; skipping test. + return + } + Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Test-Header", "X-Test-Value") + fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n") + req.ParseForm() + for k, vv := range req.Form { + for _, v := range vv { + fmt.Fprintf(rw, "param-%s=%s\n", k, v) + } + } + for _, kv := range os.Environ() { + fmt.Fprintf(rw, "env-%s\n", kv) + } + })) + os.Exit(0) +} diff --git a/src/pkg/net/http/cgi/testdata/test.cgi b/src/pkg/net/http/cgi/testdata/test.cgi new file mode 100755 index 000000000..b46b1330f --- /dev/null +++ b/src/pkg/net/http/cgi/testdata/test.cgi @@ -0,0 +1,96 @@ +#!/usr/bin/perl +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. +# +# Test script run as a child process under cgi_test.go + +use strict; +use Cwd; + +my $q = MiniCGI->new; +my $params = $q->Vars; + +if ($params->{"loc"}) { + print "Location: $params->{loc}\r\n\r\n"; + exit(0); +} + +my $NL = "\r\n"; +$NL = "\n" if $params->{mode} eq "NL"; + +my $p = sub { + print "$_[0]$NL"; +}; + +# With carriage returns +$p->("Content-Type: text/html"); +$p->("X-CGI-Pid: $$"); +$p->("X-Test-Header: X-Test-Value"); +$p->(""); + +if ($params->{"bigresponse"}) { + for (1..1024) { + print "A" x 1024, "\n"; + } + exit 0; +} + +print "test=Hello CGI\n"; + +foreach my $k (sort keys %$params) { + print "param-$k=$params->{$k}\n"; +} + +foreach my $k (sort keys %ENV) { + my $clean_env = $ENV{$k}; + $clean_env =~ s/[\n\r]//g; + print "env-$k=$clean_env\n"; +} + +# NOTE: don't call getcwd() for windows. +# msys return /c/go/src/... not C:\go\... +my $dir; +if ($^O eq 'MSWin32' || $^O eq 'msys') { + my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe'; + $cmd =~ s!\\!/!g; + $dir = `$cmd /c cd`; + chomp $dir; +} else { + $dir = getcwd(); +} +print "cwd=$dir\n"; + + +# A minimal version of CGI.pm, for people without the perl-modules +# package installed. (CGI.pm used to be part of the Perl core, but +# some distros now bundle perl-base and perl-modules separately...) +package MiniCGI; + +sub new { + my $class = shift; + return bless {}, $class; +} + +sub Vars { + my $self = shift; + my $pairs; + if ($ENV{CONTENT_LENGTH}) { + $pairs = do { local $/; <STDIN> }; + } else { + $pairs = $ENV{QUERY_STRING}; + } + my $vars = {}; + foreach my $kv (split(/&/, $pairs)) { + my ($k, $v) = split(/=/, $kv, 2); + $vars->{_urldecode($k)} = _urldecode($v); + } + return $vars; +} + +sub _urldecode { + my $v = shift; + $v =~ tr/+/ /; + $v =~ s/%([a-fA-F0-9][a-fA-F0-9])/pack("C", hex($1))/eg; + return $v; +} diff --git a/src/pkg/net/http/chunked.go b/src/pkg/net/http/chunked.go new file mode 100644 index 000000000..60a478fd8 --- /dev/null +++ b/src/pkg/net/http/chunked.go @@ -0,0 +1,170 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The wire protocol for HTTP's "chunked" Transfer-Encoding. + +// This code is duplicated in httputil/chunked.go. +// Please make any changes in both files. + +package http + +import ( + "bufio" + "bytes" + "errors" + "io" + "strconv" +) + +const maxLineLength = 4096 // assumed <= bufio.defaultBufSize + +var ErrLineTooLong = errors.New("header line too long") + +// newChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// newChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func newChunkedReader(r io.Reader) io.Reader { + br, ok := r.(*bufio.Reader) + if !ok { + br = bufio.NewReader(r) + } + return &chunkedReader{r: br} +} + +type chunkedReader struct { + r *bufio.Reader + n uint64 // unread bytes in chunk + err error +} + +func (cr *chunkedReader) beginChunk() { + // chunk-size CRLF + var line string + line, cr.err = readLine(cr.r) + if cr.err != nil { + return + } + cr.n, cr.err = strconv.ParseUint(line, 16, 64) + if cr.err != nil { + return + } + if cr.n == 0 { + cr.err = io.EOF + } +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + if cr.err != nil { + return 0, cr.err + } + if cr.n == 0 { + cr.beginChunk() + if cr.err != nil { + return 0, cr.err + } + } + if uint64(len(b)) > cr.n { + b = b[0:cr.n] + } + n, cr.err = cr.r.Read(b) + cr.n -= uint64(n) + if cr.n == 0 && cr.err == nil { + // end of chunk (CRLF) + b := make([]byte, 2) + if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil { + if b[0] != '\r' || b[1] != '\n' { + cr.err = errors.New("malformed chunked encoding") + } + } + } + return n, cr.err +} + +// Read a line of bytes (up to \n) from b. +// Give up if the line exceeds maxLineLength. +// The returned bytes are a pointer into storage in +// the bufio, so they are only valid until the next bufio read. +func readLineBytes(b *bufio.Reader) (p []byte, err error) { + if p, err = b.ReadSlice('\n'); err != nil { + // We always know when EOF is coming. + // If the caller asked for a line, there should be a line. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } else if err == bufio.ErrBufferFull { + err = ErrLineTooLong + } + return nil, err + } + if len(p) >= maxLineLength { + return nil, ErrLineTooLong + } + + // Chop off trailing white space. + p = bytes.TrimRight(p, " \r\t\n") + + return p, nil +} + +// readLineBytes, but convert the bytes into a string. +func readLine(b *bufio.Reader) (s string, err error) { + p, e := readLineBytes(b) + if e != nil { + return "", e + } + return string(p), nil +} + +// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream. +// +// newChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using newChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func newChunkedWriter(w io.Writer) io.WriteCloser { + return &chunkedWriter{w} +} + +// Writing to chunkedWriter translates to writing in HTTP chunked Transfer +// Encoding wire format to the underlying Wire chunkedWriter. +type chunkedWriter struct { + Wire io.Writer +} + +// Write the contents of data as one chunk to Wire. +// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has +// a bug since it does not check for success of io.WriteString +func (cw *chunkedWriter) Write(data []byte) (n int, err error) { + + // Don't send 0-length data. It looks like EOF for chunked encoding. + if len(data) == 0 { + return 0, nil + } + + head := strconv.FormatInt(int64(len(data)), 16) + "\r\n" + + if _, err = io.WriteString(cw.Wire, head); err != nil { + return 0, err + } + if n, err = cw.Wire.Write(data); err != nil { + return + } + if n != len(data) { + err = io.ErrShortWrite + return + } + _, err = io.WriteString(cw.Wire, "\r\n") + + return +} + +func (cw *chunkedWriter) Close() error { + _, err := io.WriteString(cw.Wire, "0\r\n") + return err +} diff --git a/src/pkg/net/http/chunked_test.go b/src/pkg/net/http/chunked_test.go new file mode 100644 index 000000000..b77ee2ff2 --- /dev/null +++ b/src/pkg/net/http/chunked_test.go @@ -0,0 +1,39 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code is duplicated in httputil/chunked_test.go. +// Please make any changes in both files. + +package http + +import ( + "bytes" + "io/ioutil" + "testing" +) + +func TestChunk(t *testing.T) { + var b bytes.Buffer + + w := newChunkedWriter(&b) + const chunk1 = "hello, " + const chunk2 = "world! 0123456789abcdef" + w.Write([]byte(chunk1)) + w.Write([]byte(chunk2)) + w.Close() + + if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e { + t.Fatalf("chunk writer wrote %q; want %q", g, e) + } + + r := newChunkedReader(&b) + data, err := ioutil.ReadAll(r) + if err != nil { + t.Logf(`data: "%s"`, data) + t.Fatalf("ReadAll from reader: %v", err) + } + if g, e := string(data), chunk1+chunk2; g != e { + t.Errorf("chunk reader read %q; want %q", g, e) + } +} diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go new file mode 100644 index 000000000..c9f024017 --- /dev/null +++ b/src/pkg/net/http/client.go @@ -0,0 +1,326 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP client. See RFC 2616. +// +// This is the high-level Client interface. +// The low-level implementation is in transport.go. + +package http + +import ( + "encoding/base64" + "errors" + "fmt" + "io" + "net/url" + "strings" +) + +// A Client is an HTTP client. Its zero value (DefaultClient) is a usable client +// that uses DefaultTransport. +// +// The Client's Transport typically has internal state (cached +// TCP connections), so Clients should be reused instead of created as +// needed. Clients are safe for concurrent use by multiple goroutines. +type Client struct { + // Transport specifies the mechanism by which individual + // HTTP requests are made. + // If nil, DefaultTransport is used. + Transport RoundTripper + + // CheckRedirect specifies the policy for handling redirects. + // If CheckRedirect is not nil, the client calls it before + // following an HTTP redirect. The arguments req and via + // are the upcoming request and the requests made already, + // oldest first. If CheckRedirect returns an error, the client + // returns that error instead of issue the Request req. + // + // If CheckRedirect is nil, the Client uses its default policy, + // which is to stop after 10 consecutive requests. + CheckRedirect func(req *Request, via []*Request) error + + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + Jar CookieJar +} + +// DefaultClient is the default Client and is used by Get, Head, and Post. +var DefaultClient = &Client{} + +// RoundTripper is an interface representing the ability to execute a +// single HTTP transaction, obtaining the Response for a given Request. +// +// A RoundTripper must be safe for concurrent use by multiple +// goroutines. +type RoundTripper interface { + // RoundTrip executes a single HTTP transaction, returning + // the Response for the request req. RoundTrip should not + // attempt to interpret the response. In particular, + // RoundTrip must return err == nil if it obtained a response, + // regardless of the response's HTTP status code. A non-nil + // err should be reserved for failure to obtain a response. + // Similarly, RoundTrip should not attempt to handle + // higher-level protocol details such as redirects, + // authentication, or cookies. + // + // RoundTrip should not modify the request, except for + // consuming the Body. The request's URL and Header fields + // are guaranteed to be initialized. + RoundTrip(*Request) (*Response, error) +} + +// Given a string of the form "host", "host:port", or "[ipv6::address]:port", +// return true if the string includes a port. +func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } + +// Used in Send to implement io.ReadCloser by bundling together the +// bufio.Reader through which we read the response, and the underlying +// network connection. +type readClose struct { + io.Reader + io.Closer +} + +// Do sends an HTTP request and returns an HTTP response, following +// policy (e.g. redirects, cookies, auth) as configured on the client. +// +// A non-nil response always contains a non-nil resp.Body. +// +// Callers should close resp.Body when done reading from it. If +// resp.Body is not closed, the Client's underlying RoundTripper +// (typically Transport) may not be able to re-use a persistent TCP +// connection to the server for a subsequent "keep-alive" request. +// +// Generally Get, Post, or PostForm will be used instead of Do. +func (c *Client) Do(req *Request) (resp *Response, err error) { + if req.Method == "GET" || req.Method == "HEAD" { + return c.doFollowingRedirects(req) + } + return send(req, c.Transport) +} + +// send issues an HTTP request. Caller should close resp.Body when done reading from it. +func send(req *Request, t RoundTripper) (resp *Response, err error) { + if t == nil { + t = DefaultTransport + if t == nil { + err = errors.New("http: no Client.Transport or DefaultTransport") + return + } + } + + if req.URL == nil { + return nil, errors.New("http: nil Request.URL") + } + + if req.RequestURI != "" { + return nil, errors.New("http: Request.RequestURI can't be set in client requests.") + } + + // Most the callers of send (Get, Post, et al) don't need + // Headers, leaving it uninitialized. We guarantee to the + // Transport that this has been initialized, though. + if req.Header == nil { + req.Header = make(Header) + } + + if u := req.URL.User; u != nil { + req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(u.String()))) + } + return t.RoundTrip(req) +} + +// True if the specified HTTP status code is one for which the Get utility should +// automatically redirect. +func shouldRedirect(statusCode int) bool { + switch statusCode { + case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect: + return true + } + return false +} + +// Get issues a GET to the specified URL. If the response is one of the following +// redirect codes, Get follows the redirect, up to a maximum of 10 redirects: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// +// Caller should close r.Body when done reading from it. +// +// Get is a wrapper around DefaultClient.Get. +func Get(url string) (r *Response, err error) { + return DefaultClient.Get(url) +} + +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// Client's CheckRedirect function. +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// +// Caller should close r.Body when done reading from it. +func (c *Client) Get(url string) (r *Response, err error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.doFollowingRedirects(req) +} + +func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { + // TODO: if/when we add cookie support, the redirected request shouldn't + // necessarily supply the same cookies as the original. + var base *url.URL + redirectChecker := c.CheckRedirect + if redirectChecker == nil { + redirectChecker = defaultCheckRedirect + } + var via []*Request + + if ireq.URL == nil { + return nil, errors.New("http: nil Request.URL") + } + + jar := c.Jar + if jar == nil { + jar = blackHoleJar{} + } + + req := ireq + urlStr := "" // next relative or absolute URL to fetch (after first request) + for redirect := 0; ; redirect++ { + if redirect != 0 { + req = new(Request) + req.Method = ireq.Method + req.Header = make(Header) + req.URL, err = base.Parse(urlStr) + if err != nil { + break + } + if len(via) > 0 { + // Add the Referer header. + lastReq := via[len(via)-1] + if lastReq.URL.Scheme != "https" { + req.Header.Set("Referer", lastReq.URL.String()) + } + + err = redirectChecker(req, via) + if err != nil { + break + } + } + } + + for _, cookie := range jar.Cookies(req.URL) { + req.AddCookie(cookie) + } + urlStr = req.URL.String() + if r, err = send(req, c.Transport); err != nil { + break + } + if c := r.Cookies(); len(c) > 0 { + jar.SetCookies(req.URL, c) + } + + if shouldRedirect(r.StatusCode) { + r.Body.Close() + if urlStr = r.Header.Get("Location"); urlStr == "" { + err = errors.New(fmt.Sprintf("%d response missing Location header", r.StatusCode)) + break + } + base = req.URL + via = append(via, req) + continue + } + return + } + + method := ireq.Method + err = &url.Error{method[0:1] + strings.ToLower(method[1:]), urlStr, err} + return +} + +func defaultCheckRedirect(req *Request, via []*Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil +} + +// Post issues a POST to the specified URL. +// +// Caller should close r.Body when done reading from it. +// +// Post is a wrapper around DefaultClient.Post +func Post(url string, bodyType string, body io.Reader) (r *Response, err error) { + return DefaultClient.Post(url, bodyType, body) +} + +// Post issues a POST to the specified URL. +// +// Caller should close r.Body when done reading from it. +func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err error) { + req, err := NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", bodyType) + return send(req, c.Transport) +} + +// PostForm issues a POST to the specified URL, +// with data's keys and values urlencoded as the request body. +// +// Caller should close r.Body when done reading from it. +// +// PostForm is a wrapper around DefaultClient.PostForm +func PostForm(url string, data url.Values) (r *Response, err error) { + return DefaultClient.PostForm(url, data) +} + +// PostForm issues a POST to the specified URL, +// with data's keys and values urlencoded as the request body. +// +// Caller should close r.Body when done reading from it. +func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) { + return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} + +// Head issues a HEAD to the specified URL. If the response is one of the +// following redirect codes, Head follows the redirect after calling the +// Client's CheckRedirect function. +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// +// Head is a wrapper around DefaultClient.Head +func Head(url string) (r *Response, err error) { + return DefaultClient.Head(url) +} + +// Head issues a HEAD to the specified URL. If the response is one of the +// following redirect codes, Head follows the redirect after calling the +// Client's CheckRedirect function. +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +func (c *Client) Head(url string) (r *Response, err error) { + req, err := NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return c.doFollowingRedirects(req) +} diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go new file mode 100644 index 000000000..aa0bf4be6 --- /dev/null +++ b/src/pkg/net/http/client_test.go @@ -0,0 +1,442 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for client.go + +package http_test + +import ( + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + . "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "sync" + "testing" +) + +var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Last-Modified", "sometime") + fmt.Fprintf(w, "User-agent: go\nDisallow: /something/") +}) + +// pedanticReadAll works like ioutil.ReadAll but additionally +// verifies that r obeys the documented io.Reader contract. +func pedanticReadAll(r io.Reader) (b []byte, err error) { + var bufa [64]byte + buf := bufa[:] + for { + n, err := r.Read(buf) + if n == 0 && err == nil { + return nil, fmt.Errorf("Read: n=0 with err=nil") + } + b = append(b, buf[:n]...) + if err == io.EOF { + n, err := r.Read(buf) + if n != 0 || err != io.EOF { + return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err) + } + return b, nil + } + if err != nil { + return b, err + } + } + panic("unreachable") +} + +func TestClient(t *testing.T) { + ts := httptest.NewServer(robotsTxtHandler) + defer ts.Close() + + r, err := Get(ts.URL) + var b []byte + if err == nil { + b, err = pedanticReadAll(r.Body) + r.Body.Close() + } + if err != nil { + t.Error(err) + } else if s := string(b); !strings.HasPrefix(s, "User-agent:") { + t.Errorf("Incorrect page body (did not begin with User-agent): %q", s) + } +} + +func TestClientHead(t *testing.T) { + ts := httptest.NewServer(robotsTxtHandler) + defer ts.Close() + + r, err := Head(ts.URL) + if err != nil { + t.Fatal(err) + } + if _, ok := r.Header["Last-Modified"]; !ok { + t.Error("Last-Modified header not found.") + } +} + +type recordingTransport struct { + req *Request +} + +func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) { + t.req = req + return nil, errors.New("dummy impl") +} + +func TestGetRequestFormat(t *testing.T) { + tr := &recordingTransport{} + client := &Client{Transport: tr} + url := "http://dummy.faketld/" + client.Get(url) // Note: doesn't hit network + if tr.req.Method != "GET" { + t.Errorf("expected method %q; got %q", "GET", tr.req.Method) + } + if tr.req.URL.String() != url { + t.Errorf("expected URL %q; got %q", url, tr.req.URL.String()) + } + if tr.req.Header == nil { + t.Errorf("expected non-nil request Header") + } +} + +func TestPostRequestFormat(t *testing.T) { + tr := &recordingTransport{} + client := &Client{Transport: tr} + + url := "http://dummy.faketld/" + json := `{"key":"value"}` + b := strings.NewReader(json) + client.Post(url, "application/json", b) // Note: doesn't hit network + + if tr.req.Method != "POST" { + t.Errorf("got method %q, want %q", tr.req.Method, "POST") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + if tr.req.Close { + t.Error("got Close true, want false") + } + if g, e := tr.req.ContentLength, int64(len(json)); g != e { + t.Errorf("got ContentLength %d, want %d", g, e) + } +} + +func TestPostFormRequestFormat(t *testing.T) { + tr := &recordingTransport{} + client := &Client{Transport: tr} + + urlStr := "http://dummy.faketld/" + form := make(url.Values) + form.Set("foo", "bar") + form.Add("foo", "bar2") + form.Set("bar", "baz") + client.PostForm(urlStr, form) // Note: doesn't hit network + + if tr.req.Method != "POST" { + t.Errorf("got method %q, want %q", tr.req.Method, "POST") + } + if tr.req.URL.String() != urlStr { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), urlStr) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e { + t.Errorf("got Content-Type %q, want %q", g, e) + } + if tr.req.Close { + t.Error("got Close true, want false") + } + // Depending on map iteration, body can be either of these. + expectedBody := "foo=bar&foo=bar2&bar=baz" + expectedBody1 := "bar=baz&foo=bar&foo=bar2" + if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e { + t.Errorf("got ContentLength %d, want %d", g, e) + } + bodyb, err := ioutil.ReadAll(tr.req.Body) + if err != nil { + t.Fatalf("ReadAll on req.Body: %v", err) + } + if g := string(bodyb); g != expectedBody && g != expectedBody1 { + t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1) + } +} + +func TestRedirects(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + n, _ := strconv.Atoi(r.FormValue("n")) + // Test Referer header. (7 is arbitrary position to test at) + if n == 7 { + if g, e := r.Referer(), ts.URL+"/?n=6"; e != g { + t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g) + } + } + if n < 15 { + Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound) + return + } + fmt.Fprintf(w, "n=%d", n) + })) + defer ts.Close() + + c := &Client{} + _, err := c.Get(ts.URL) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Get, expected error %q, got %q", e, g) + } + + // HEAD request should also have the ability to follow redirects. + _, err = c.Head(ts.URL) + if e, g := "Head /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Head, expected error %q, got %q", e, g) + } + + // Do should also follow redirects. + greq, _ := NewRequest("GET", ts.URL, nil) + _, err = c.Do(greq) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Do, expected error %q, got %q", e, g) + } + + var checkErr error + var lastVia []*Request + c = &Client{CheckRedirect: func(_ *Request, via []*Request) error { + lastVia = via + return checkErr + }} + res, err := c.Get(ts.URL) + finalUrl := res.Request.URL.String() + if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { + t.Errorf("with custom client, expected error %q, got %q", e, g) + } + if !strings.HasSuffix(finalUrl, "/?n=15") { + t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl) + } + if e, g := 15, len(lastVia); e != g { + t.Errorf("expected lastVia to have contained %d elements; got %d", e, g) + } + + checkErr = errors.New("no redirects allowed") + res, err = c.Get(ts.URL) + finalUrl = res.Request.URL.String() + if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { + t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) + } +} + +var expectedCookies = []*Cookie{ + &Cookie{Name: "ChocolateChip", Value: "tasty"}, + &Cookie{Name: "First", Value: "Hit"}, + &Cookie{Name: "Second", Value: "Hit"}, +} + +var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + for _, cookie := range r.Cookies() { + SetCookie(w, cookie) + } + if r.URL.Path == "/" { + SetCookie(w, expectedCookies[1]) + Redirect(w, r, "/second", StatusMovedPermanently) + } else { + SetCookie(w, expectedCookies[2]) + w.Write([]byte("hello")) + } +}) + +// Just enough correctness for our redirect tests. Uses the URL.Host as the +// scope of all cookies. +type TestJar struct { + m sync.Mutex + perURL map[string][]*Cookie +} + +func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) { + j.m.Lock() + defer j.m.Unlock() + j.perURL[u.Host] = cookies +} + +func (j *TestJar) Cookies(u *url.URL) []*Cookie { + j.m.Lock() + defer j.m.Unlock() + return j.perURL[u.Host] +} + +func TestRedirectCookiesOnRequest(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(echoCookiesRedirectHandler) + defer ts.Close() + c := &Client{} + req, _ := NewRequest("GET", ts.URL, nil) + req.AddCookie(expectedCookies[0]) + // TODO: Uncomment when an implementation of a RFC6265 cookie jar lands. + _ = c + // resp, _ := c.Do(req) + // matchReturnedCookies(t, expectedCookies, resp.Cookies()) + + req, _ = NewRequest("GET", ts.URL, nil) + // resp, _ = c.Do(req) + // matchReturnedCookies(t, expectedCookies[1:], resp.Cookies()) +} + +func TestRedirectCookiesJar(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(echoCookiesRedirectHandler) + defer ts.Close() + c := &Client{} + c.Jar = &TestJar{perURL: make(map[string][]*Cookie)} + u, _ := url.Parse(ts.URL) + c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) + resp, _ := c.Get(ts.URL) + matchReturnedCookies(t, expectedCookies, resp.Cookies()) +} + +func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { + t.Logf("Received cookies: %v", given) + if len(given) != len(expected) { + t.Errorf("Expected %d cookies, got %d", len(expected), len(given)) + } + for _, ec := range expected { + foundC := false + for _, c := range given { + if ec.Name == c.Name && ec.Value == c.Value { + foundC = true + break + } + } + if !foundC { + t.Errorf("Missing cookie %v", ec) + } + } +} + +func TestStreamingGet(t *testing.T) { + say := make(chan string) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.(Flusher).Flush() + for str := range say { + w.Write([]byte(str)) + w.(Flusher).Flush() + } + })) + defer ts.Close() + + c := &Client{} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + var buf [10]byte + for _, str := range []string{"i", "am", "also", "known", "as", "comet"} { + say <- str + n, err := io.ReadFull(res.Body, buf[0:len(str)]) + if err != nil { + t.Fatalf("ReadFull on %q: %v", str, err) + } + if n != len(str) { + t.Fatalf("Receiving %q, only read %d bytes", str, n) + } + got := string(buf[0:n]) + if got != str { + t.Fatalf("Expected %q, got %q", str, got) + } + } + close(say) + _, err = io.ReadFull(res.Body, buf[0:1]) + if err != io.EOF { + t.Fatalf("at end expected EOF, got %v", err) + } +} + +type writeCountingConn struct { + net.Conn + count *int +} + +func (c *writeCountingConn) Write(p []byte) (int, error) { + *c.count++ + return c.Conn.Write(p) +} + +// TestClientWrites verifies that client requests are buffered and we +// don't send a TCP packet per line of the http request + body. +func TestClientWrites(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })) + defer ts.Close() + + writes := 0 + dialer := func(netz string, addr string) (net.Conn, error) { + c, err := net.Dial(netz, addr) + if err == nil { + c = &writeCountingConn{c, &writes} + } + return c, err + } + c := &Client{Transport: &Transport{Dial: dialer}} + + _, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if writes != 1 { + t.Errorf("Get request did %d Write calls, want 1", writes) + } + + writes = 0 + _, err = c.PostForm(ts.URL, url.Values{"foo": {"bar"}}) + if err != nil { + t.Fatal(err) + } + if writes != 1 { + t.Errorf("Post request did %d Write calls, want 1", writes) + } +} + +func TestClientInsecureTransport(t *testing.T) { + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello")) + })) + defer ts.Close() + + // TODO(bradfitz): add tests for skipping hostname checks too? + // would require a new cert for testing, and probably + // redundant with these tests. + for _, insecure := range []bool{true, false} { + tr := &Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: insecure, + }, + } + c := &Client{Transport: tr} + _, err := c.Get(ts.URL) + if (err == nil) != insecure { + t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) + } + } +} + +func TestClientErrorWithRequestURI(t *testing.T) { + req, _ := NewRequest("GET", "http://localhost:1234/", nil) + req.RequestURI = "/this/field/is/illegal/and/should/error/" + _, err := DefaultClient.Do(req) + if err == nil { + t.Fatalf("expected an error") + } + if !strings.Contains(err.Error(), "RequestURI") { + t.Errorf("wanted error mentioning RequestURI; got error: %v", err) + } +} diff --git a/src/pkg/net/http/cookie.go b/src/pkg/net/http/cookie.go new file mode 100644 index 000000000..2e30bbff1 --- /dev/null +++ b/src/pkg/net/http/cookie.go @@ -0,0 +1,267 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "fmt" + "strconv" + "strings" + "time" +) + +// This implementation is done according to RFC 6265: +// +// http://tools.ietf.org/html/rfc6265 + +// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an +// HTTP response or the Cookie header of an HTTP request. +type Cookie struct { + Name string + Value string + Path string + Domain string + Expires time.Time + RawExpires string + + // MaxAge=0 means no 'Max-Age' attribute specified. + // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' + // MaxAge>0 means Max-Age attribute present and given in seconds + MaxAge int + Secure bool + HttpOnly bool + Raw string + Unparsed []string // Raw text of unparsed attribute-value pairs +} + +// readSetCookies parses all "Set-Cookie" values from +// the header h and returns the successfully parsed Cookies. +func readSetCookies(h Header) []*Cookie { + cookies := []*Cookie{} + for _, line := range h["Set-Cookie"] { + parts := strings.Split(strings.TrimSpace(line), ";") + if len(parts) == 1 && parts[0] == "" { + continue + } + parts[0] = strings.TrimSpace(parts[0]) + j := strings.Index(parts[0], "=") + if j < 0 { + continue + } + name, value := parts[0][:j], parts[0][j+1:] + if !isCookieNameValid(name) { + continue + } + value, success := parseCookieValue(value) + if !success { + continue + } + c := &Cookie{ + Name: name, + Value: value, + Raw: line, + } + for i := 1; i < len(parts); i++ { + parts[i] = strings.TrimSpace(parts[i]) + if len(parts[i]) == 0 { + continue + } + + attr, val := parts[i], "" + if j := strings.Index(attr, "="); j >= 0 { + attr, val = attr[:j], attr[j+1:] + } + lowerAttr := strings.ToLower(attr) + parseCookieValueFn := parseCookieValue + if lowerAttr == "expires" { + parseCookieValueFn = parseCookieExpiresValue + } + val, success = parseCookieValueFn(val) + if !success { + c.Unparsed = append(c.Unparsed, parts[i]) + continue + } + switch lowerAttr { + case "secure": + c.Secure = true + continue + case "httponly": + c.HttpOnly = true + continue + case "domain": + c.Domain = val + // TODO: Add domain parsing + continue + case "max-age": + secs, err := strconv.Atoi(val) + if err != nil || secs != 0 && val[0] == '0' { + break + } + if secs <= 0 { + c.MaxAge = -1 + } else { + c.MaxAge = secs + } + continue + case "expires": + c.RawExpires = val + exptime, err := time.Parse(time.RFC1123, val) + if err != nil { + exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val) + if err != nil { + c.Expires = time.Time{} + break + } + } + c.Expires = exptime.UTC() + continue + case "path": + c.Path = val + // TODO: Add path parsing + continue + } + c.Unparsed = append(c.Unparsed, parts[i]) + } + cookies = append(cookies, c) + } + return cookies +} + +// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers. +func SetCookie(w ResponseWriter, cookie *Cookie) { + w.Header().Add("Set-Cookie", cookie.String()) +} + +// String returns the serialization of the cookie for use in a Cookie +// header (if only Name and Value are set) or a Set-Cookie response +// header (if other fields are set). +func (c *Cookie) String() string { + var b bytes.Buffer + fmt.Fprintf(&b, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) + if len(c.Path) > 0 { + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(c.Path)) + } + if len(c.Domain) > 0 { + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(c.Domain)) + } + if c.Expires.Unix() > 0 { + fmt.Fprintf(&b, "; Expires=%s", c.Expires.UTC().Format(time.RFC1123)) + } + if c.MaxAge > 0 { + fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge) + } else if c.MaxAge < 0 { + fmt.Fprintf(&b, "; Max-Age=0") + } + if c.HttpOnly { + fmt.Fprintf(&b, "; HttpOnly") + } + if c.Secure { + fmt.Fprintf(&b, "; Secure") + } + return b.String() +} + +// readCookies parses all "Cookie" values from the header h and +// returns the successfully parsed Cookies. +// +// if filter isn't empty, only cookies of that name are returned +func readCookies(h Header, filter string) []*Cookie { + cookies := []*Cookie{} + lines, ok := h["Cookie"] + if !ok { + return cookies + } + + for _, line := range lines { + parts := strings.Split(strings.TrimSpace(line), ";") + if len(parts) == 1 && parts[0] == "" { + continue + } + // Per-line attributes + parsedPairs := 0 + for i := 0; i < len(parts); i++ { + parts[i] = strings.TrimSpace(parts[i]) + if len(parts[i]) == 0 { + continue + } + name, val := parts[i], "" + if j := strings.Index(name, "="); j >= 0 { + name, val = name[:j], name[j+1:] + } + if !isCookieNameValid(name) { + continue + } + if filter != "" && filter != name { + continue + } + val, success := parseCookieValue(val) + if !success { + continue + } + cookies = append(cookies, &Cookie{Name: name, Value: val}) + parsedPairs++ + } + } + return cookies +} + +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +func sanitizeName(n string) string { + return cookieNameSanitizer.Replace(n) +} + +var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ") + +func sanitizeValue(v string) string { + return cookieValueSanitizer.Replace(v) +} + +func unquoteCookieValue(v string) string { + if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' { + return v[1 : len(v)-1] + } + return v +} + +func isCookieByte(c byte) bool { + switch { + case c == 0x21, 0x23 <= c && c <= 0x2b, 0x2d <= c && c <= 0x3a, + 0x3c <= c && c <= 0x5b, 0x5d <= c && c <= 0x7e: + return true + } + return false +} + +func isCookieExpiresByte(c byte) (ok bool) { + return isCookieByte(c) || c == ',' || c == ' ' +} + +func parseCookieValue(raw string) (string, bool) { + return parseCookieValueUsing(raw, isCookieByte) +} + +func parseCookieExpiresValue(raw string) (string, bool) { + return parseCookieValueUsing(raw, isCookieExpiresByte) +} + +func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool) { + raw = unquoteCookieValue(raw) + for i := 0; i < len(raw); i++ { + if !validByte(raw[i]) { + return "", false + } + } + return raw, true +} + +func isCookieNameValid(raw string) bool { + for _, c := range raw { + if !isToken(byte(c)) { + return false + } + } + return true +} diff --git a/src/pkg/net/http/cookie_test.go b/src/pkg/net/http/cookie_test.go new file mode 100644 index 000000000..712350dfc --- /dev/null +++ b/src/pkg/net/http/cookie_test.go @@ -0,0 +1,200 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + "time" +) + +var writeSetCookiesTests = []struct { + Cookie *Cookie + Raw string +}{ + { + &Cookie{Name: "cookie-1", Value: "v$1"}, + "cookie-1=v$1", + }, + { + &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}, + "cookie-2=two; Max-Age=3600", + }, + { + &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"}, + "cookie-3=three; Domain=.example.com", + }, + { + &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"}, + "cookie-4=four; Path=/restricted/", + }, +} + +func TestWriteSetCookies(t *testing.T) { + for i, tt := range writeSetCookiesTests { + if g, e := tt.Cookie.String(), tt.Raw; g != e { + t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, e, g) + continue + } + } +} + +type headerOnlyResponseWriter Header + +func (ho headerOnlyResponseWriter) Header() Header { + return Header(ho) +} + +func (ho headerOnlyResponseWriter) Write([]byte) (int, error) { + panic("NOIMPL") +} + +func (ho headerOnlyResponseWriter) WriteHeader(int) { + panic("NOIMPL") +} + +func TestSetCookie(t *testing.T) { + m := make(Header) + SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-1", Value: "one", Path: "/restricted/"}) + SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}) + if l := len(m["Set-Cookie"]); l != 2 { + t.Fatalf("expected %d cookies, got %d", 2, l) + } + if g, e := m["Set-Cookie"][0], "cookie-1=one; Path=/restricted/"; g != e { + t.Errorf("cookie #1: want %q, got %q", e, g) + } + if g, e := m["Set-Cookie"][1], "cookie-2=two; Max-Age=3600"; g != e { + t.Errorf("cookie #2: want %q, got %q", e, g) + } +} + +var addCookieTests = []struct { + Cookies []*Cookie + Raw string +}{ + { + []*Cookie{}, + "", + }, + { + []*Cookie{{Name: "cookie-1", Value: "v$1"}}, + "cookie-1=v$1", + }, + { + []*Cookie{ + {Name: "cookie-1", Value: "v$1"}, + {Name: "cookie-2", Value: "v$2"}, + {Name: "cookie-3", Value: "v$3"}, + }, + "cookie-1=v$1; cookie-2=v$2; cookie-3=v$3", + }, +} + +func TestAddCookie(t *testing.T) { + for i, tt := range addCookieTests { + req, _ := NewRequest("GET", "http://example.com/", nil) + for _, c := range tt.Cookies { + req.AddCookie(c) + } + if g := req.Header.Get("Cookie"); g != tt.Raw { + t.Errorf("Test %d:\nwant: %s\n got: %s\n", i, tt.Raw, g) + continue + } + } +} + +var readSetCookiesTests = []struct { + Header Header + Cookies []*Cookie +}{ + { + Header{"Set-Cookie": {"Cookie-1=v$1"}}, + []*Cookie{{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}}, + }, + { + Header{"Set-Cookie": {"NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly"}}, + []*Cookie{{ + Name: "NID", + Value: "99=YsDT5i3E-CXax-", + Path: "/", + Domain: ".google.ch", + HttpOnly: true, + Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC), + RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", + Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", + }}, + }, +} + +func toJSON(v interface{}) string { + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%#v", v) + } + return string(b) +} + +func TestReadSetCookies(t *testing.T) { + for i, tt := range readSetCookiesTests { + for n := 0; n < 2; n++ { // to verify readSetCookies doesn't mutate its input + c := readSetCookies(tt.Header) + if !reflect.DeepEqual(c, tt.Cookies) { + t.Errorf("#%d readSetCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies)) + continue + } + } + } +} + +var readCookiesTests = []struct { + Header Header + Filter string + Cookies []*Cookie +}{ + { + Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}}, + "", + []*Cookie{ + {Name: "Cookie-1", Value: "v$1"}, + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}}, + "c2", + []*Cookie{ + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1; c2=v2"}}, + "", + []*Cookie{ + {Name: "Cookie-1", Value: "v$1"}, + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {"Cookie-1=v$1; c2=v2"}}, + "c2", + []*Cookie{ + {Name: "c2", Value: "v2"}, + }, + }, +} + +func TestReadCookies(t *testing.T) { + for i, tt := range readCookiesTests { + for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input + c := readCookies(tt.Header, tt.Filter) + if !reflect.DeepEqual(c, tt.Cookies) { + t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies)) + continue + } + } + } +} diff --git a/src/pkg/net/http/doc.go b/src/pkg/net/http/doc.go new file mode 100644 index 000000000..8962ed31e --- /dev/null +++ b/src/pkg/net/http/doc.go @@ -0,0 +1,80 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package http provides HTTP client and server implementations. + +Get, Head, Post, and PostForm make HTTP requests: + + resp, err := http.Get("http://example.com/") + ... + resp, err := http.Post("http://example.com/upload", "image/jpeg", &buf) + ... + resp, err := http.PostForm("http://example.com/form", + url.Values{"key": {"Value"}, "id": {"123"}}) + +The client must close the response body when finished with it: + + resp, err := http.Get("http://example.com/") + if err != nil { + // handle error + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + // ... + +For control over HTTP client headers, redirect policy, and other +settings, create a Client: + + client := &http.Client{ + CheckRedirect: redirectPolicyFunc, + } + + resp, err := client.Get("http://example.com") + // ... + + req, err := http.NewRequest("GET", "http://example.com", nil) + // ... + req.Header.Add("If-None-Match", `W/"wyzzy"`) + resp, err := client.Do(req) + // ... + +For control over proxies, TLS configuration, keep-alives, +compression, and other settings, create a Transport: + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: pool}, + DisableCompression: true, + } + client := &http.Client{Transport: tr} + resp, err := client.Get("https://example.com") + +Clients and Transports are safe for concurrent use by multiple +goroutines and for efficiency should only be created once and re-used. + +ListenAndServe starts an HTTP server with a given address and handler. +The handler is usually nil, which means to use DefaultServeMux. +Handle and HandleFunc add handlers to DefaultServeMux: + + http.Handle("/foo", fooHandler) + + http.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.RawPath)) + }) + + log.Fatal(http.ListenAndServe(":8080", nil)) + +More control over the server's behavior is available by creating a +custom Server: + + s := &http.Server{ + Addr: ":8080", + Handler: myHandler, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + MaxHeaderBytes: 1 << 20, + } + log.Fatal(s.ListenAndServe()) +*/ +package http diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go new file mode 100644 index 000000000..13640ca85 --- /dev/null +++ b/src/pkg/net/http/export_test.go @@ -0,0 +1,43 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Bridge package to expose http internals to tests in the http_test +// package. + +package http + +import "time" + +func (t *Transport) IdleConnKeysForTesting() (keys []string) { + keys = make([]string, 0) + t.lk.Lock() + defer t.lk.Unlock() + if t.idleConn == nil { + return + } + for key := range t.idleConn { + keys = append(keys, key) + } + return +} + +func (t *Transport) IdleConnCountForTesting(cacheKey string) int { + t.lk.Lock() + defer t.lk.Unlock() + if t.idleConn == nil { + return 0 + } + conns, ok := t.idleConn[cacheKey] + if !ok { + return 0 + } + return len(conns) +} + +func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { + f := func() <-chan time.Time { + return ch + } + return &timeoutHandler{handler, f, ""} +} diff --git a/src/pkg/net/http/fcgi/Makefile b/src/pkg/net/http/fcgi/Makefile new file mode 100644 index 000000000..9a75f1a80 --- /dev/null +++ b/src/pkg/net/http/fcgi/Makefile @@ -0,0 +1,12 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../../Make.inc + +TARG=net/http/fcgi +GOFILES=\ + child.go\ + fcgi.go\ + +include ../../../../Make.pkg diff --git a/src/pkg/net/http/fcgi/child.go b/src/pkg/net/http/fcgi/child.go new file mode 100644 index 000000000..c94b9a7b2 --- /dev/null +++ b/src/pkg/net/http/fcgi/child.go @@ -0,0 +1,271 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +// This file implements FastCGI from the perspective of a child process. + +import ( + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/cgi" + "os" + "time" +) + +// request holds the state for an in-progress request. As soon as it's complete, +// it's converted to an http.Request. +type request struct { + pw *io.PipeWriter + reqId uint16 + params map[string]string + buf [1024]byte + rawParams []byte + keepConn bool +} + +func newRequest(reqId uint16, flags uint8) *request { + r := &request{ + reqId: reqId, + params: map[string]string{}, + keepConn: flags&flagKeepConn != 0, + } + r.rawParams = r.buf[:0] + return r +} + +// parseParams reads an encoded []byte into Params. +func (r *request) parseParams() { + text := r.rawParams + r.rawParams = nil + for len(text) > 0 { + keyLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + valLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + key := readString(text, keyLen) + text = text[keyLen:] + val := readString(text, valLen) + text = text[valLen:] + r.params[key] = val + } +} + +// response implements http.ResponseWriter. +type response struct { + req *request + header http.Header + w *bufWriter + wroteHeader bool +} + +func newResponse(c *child, req *request) *response { + return &response{ + req: req, + header: http.Header{}, + w: newWriter(c.conn, typeStdout, req.reqId), + } +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(data []byte) (int, error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + return r.w.Write(data) +} + +func (r *response) WriteHeader(code int) { + if r.wroteHeader { + return + } + r.wroteHeader = true + if code == http.StatusNotModified { + // Must not have body. + r.header.Del("Content-Type") + r.header.Del("Content-Length") + r.header.Del("Transfer-Encoding") + } else if r.header.Get("Content-Type") == "" { + r.header.Set("Content-Type", "text/html; charset=utf-8") + } + + if r.header.Get("Date") == "" { + r.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) + } + + fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code)) + r.header.Write(r.w) + r.w.WriteString("\r\n") +} + +func (r *response) Flush() { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + r.w.Flush() +} + +func (r *response) Close() error { + r.Flush() + return r.w.Close() +} + +type child struct { + conn *conn + handler http.Handler + requests map[uint16]*request // keyed by request ID +} + +func newChild(rwc io.ReadWriteCloser, handler http.Handler) *child { + return &child{ + conn: newConn(rwc), + handler: handler, + requests: make(map[uint16]*request), + } +} + +func (c *child) serve() { + defer c.conn.Close() + var rec record + for { + if err := rec.read(c.conn.rwc); err != nil { + return + } + if err := c.handleRecord(&rec); err != nil { + return + } + } +} + +var errCloseConn = errors.New("fcgi: connection should be closed") + +func (c *child) handleRecord(rec *record) error { + req, ok := c.requests[rec.h.Id] + if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { + // The spec says to ignore unknown request IDs. + return nil + } + if ok && rec.h.Type == typeBeginRequest { + // The server is trying to begin a request with the same ID + // as an in-progress request. This is an error. + return errors.New("fcgi: received ID that is already in-flight") + } + + switch rec.h.Type { + case typeBeginRequest: + var br beginRequest + if err := br.read(rec.content()); err != nil { + return err + } + if br.role != roleResponder { + c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole) + return nil + } + c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) + case typeParams: + // NOTE(eds): Technically a key-value pair can straddle the boundary + // between two packets. We buffer until we've received all parameters. + if len(rec.content()) > 0 { + req.rawParams = append(req.rawParams, rec.content()...) + return nil + } + req.parseParams() + case typeStdin: + content := rec.content() + if req.pw == nil { + var body io.ReadCloser + if len(content) > 0 { + // body could be an io.LimitReader, but it shouldn't matter + // as long as both sides are behaving. + body, req.pw = io.Pipe() + } + go c.serveRequest(req, body) + } + if len(content) > 0 { + // TODO(eds): This blocks until the handler reads from the pipe. + // If the handler takes a long time, it might be a problem. + req.pw.Write(content) + } else if req.pw != nil { + req.pw.Close() + } + case typeGetValues: + values := map[string]string{"FCGI_MPXS_CONNS": "1"} + c.conn.writePairs(typeGetValuesResult, 0, values) + case typeData: + // If the filter role is implemented, read the data stream here. + case typeAbortRequest: + delete(c.requests, rec.h.Id) + c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) + if !req.keepConn { + // connection will close upon return + return errCloseConn + } + default: + b := make([]byte, 8) + b[0] = byte(rec.h.Type) + c.conn.writeRecord(typeUnknownType, 0, b) + } + return nil +} + +func (c *child) serveRequest(req *request, body io.ReadCloser) { + r := newResponse(c, req) + httpReq, err := cgi.RequestFromMap(req.params) + if err != nil { + // there was an error reading the request + r.WriteHeader(http.StatusInternalServerError) + c.conn.writeRecord(typeStderr, req.reqId, []byte(err.Error())) + } else { + httpReq.Body = body + c.handler.ServeHTTP(r, httpReq) + } + if body != nil { + body.Close() + } + r.Close() + c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) + if !req.keepConn { + c.conn.Close() + } +} + +// Serve accepts incoming FastCGI connections on the listener l, creating a new +// service thread for each. The service threads read requests and then call handler +// to reply to them. +// If l is nil, Serve accepts connections on stdin. +// If handler is nil, http.DefaultServeMux is used. +func Serve(l net.Listener, handler http.Handler) error { + if l == nil { + var err error + l, err = net.FileListener(os.Stdin) + if err != nil { + return err + } + defer l.Close() + } + if handler == nil { + handler = http.DefaultServeMux + } + for { + rw, err := l.Accept() + if err != nil { + return err + } + c := newChild(rw, handler) + go c.serve() + } + panic("unreachable") +} diff --git a/src/pkg/net/http/fcgi/fcgi.go b/src/pkg/net/http/fcgi/fcgi.go new file mode 100644 index 000000000..d35aa84d2 --- /dev/null +++ b/src/pkg/net/http/fcgi/fcgi.go @@ -0,0 +1,274 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fcgi implements the FastCGI protocol. +// Currently only the responder role is supported. +// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22 +package fcgi + +// This file defines the raw protocol and some utilities used by the child and +// the host. + +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "io" + "sync" +) + +// recType is a record type, as defined by +// http://www.fastcgi.com/devkit/doc/fcgi-spec.html#S8 +type recType uint8 + +const ( + typeBeginRequest recType = 1 + typeAbortRequest recType = 2 + typeEndRequest recType = 3 + typeParams recType = 4 + typeStdin recType = 5 + typeStdout recType = 6 + typeStderr recType = 7 + typeData recType = 8 + typeGetValues recType = 9 + typeGetValuesResult recType = 10 + typeUnknownType recType = 11 +) + +// keep the connection between web-server and responder open after request +const flagKeepConn = 1 + +const ( + maxWrite = 65535 // maximum record body + maxPad = 255 +) + +const ( + roleResponder = iota + 1 // only Responders are implemented. + roleAuthorizer + roleFilter +) + +const ( + statusRequestComplete = iota + statusCantMultiplex + statusOverloaded + statusUnknownRole +) + +const headerLen = 8 + +type header struct { + Version uint8 + Type recType + Id uint16 + ContentLength uint16 + PaddingLength uint8 + Reserved uint8 +} + +type beginRequest struct { + role uint16 + flags uint8 + reserved [5]uint8 +} + +func (br *beginRequest) read(content []byte) error { + if len(content) != 8 { + return errors.New("fcgi: invalid begin request record") + } + br.role = binary.BigEndian.Uint16(content) + br.flags = content[2] + return nil +} + +// for padding so we don't have to allocate all the time +// not synchronized because we don't care what the contents are +var pad [maxPad]byte + +func (h *header) init(recType recType, reqId uint16, contentLength int) { + h.Version = 1 + h.Type = recType + h.Id = reqId + h.ContentLength = uint16(contentLength) + h.PaddingLength = uint8(-contentLength & 7) +} + +// conn sends records over rwc +type conn struct { + mutex sync.Mutex + rwc io.ReadWriteCloser + + // to avoid allocations + buf bytes.Buffer + h header +} + +func newConn(rwc io.ReadWriteCloser) *conn { + return &conn{rwc: rwc} +} + +func (c *conn) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.rwc.Close() +} + +type record struct { + h header + buf [maxWrite + maxPad]byte +} + +func (rec *record) read(r io.Reader) (err error) { + if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { + return err + } + if rec.h.Version != 1 { + return errors.New("fcgi: invalid header version") + } + n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) + if _, err = io.ReadFull(r, rec.buf[:n]); err != nil { + return err + } + return nil +} + +func (r *record) content() []byte { + return r.buf[:r.h.ContentLength] +} + +// writeRecord writes and sends a single record. +func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.buf.Reset() + c.h.init(recType, reqId, len(b)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { + return err + } + if _, err := c.buf.Write(b); err != nil { + return err + } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { + return err + } + _, err := c.rwc.Write(c.buf.Bytes()) + return err +} + +func (c *conn) writeBeginRequest(reqId uint16, role uint16, flags uint8) error { + b := [8]byte{byte(role >> 8), byte(role), flags} + return c.writeRecord(typeBeginRequest, reqId, b[:]) +} + +func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) error { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b, uint32(appStatus)) + b[4] = protocolStatus + return c.writeRecord(typeEndRequest, reqId, b) +} + +func (c *conn) writePairs(recType recType, reqId uint16, pairs map[string]string) error { + w := newWriter(c, recType, reqId) + b := make([]byte, 8) + for k, v := range pairs { + n := encodeSize(b, uint32(len(k))) + n += encodeSize(b[n:], uint32(len(v))) + if _, err := w.Write(b[:n]); err != nil { + return err + } + if _, err := w.WriteString(k); err != nil { + return err + } + if _, err := w.WriteString(v); err != nil { + return err + } + } + w.Close() + return nil +} + +func readSize(s []byte) (uint32, int) { + if len(s) == 0 { + return 0, 0 + } + size, n := uint32(s[0]), 1 + if size&(1<<7) != 0 { + if len(s) < 4 { + return 0, 0 + } + n = 4 + size = binary.BigEndian.Uint32(s) + size &^= 1 << 31 + } + return size, n +} + +func readString(s []byte, size uint32) string { + if size > uint32(len(s)) { + return "" + } + return string(s[:size]) +} + +func encodeSize(b []byte, size uint32) int { + if size > 127 { + size |= 1 << 31 + binary.BigEndian.PutUint32(b, size) + return 4 + } + b[0] = byte(size) + return 1 +} + +// bufWriter encapsulates bufio.Writer but also closes the underlying stream when +// Closed. +type bufWriter struct { + closer io.Closer + *bufio.Writer +} + +func (w *bufWriter) Close() error { + if err := w.Writer.Flush(); err != nil { + w.closer.Close() + return err + } + return w.closer.Close() +} + +func newWriter(c *conn, recType recType, reqId uint16) *bufWriter { + s := &streamWriter{c: c, recType: recType, reqId: reqId} + w, _ := bufio.NewWriterSize(s, maxWrite) + return &bufWriter{s, w} +} + +// streamWriter abstracts out the separation of a stream into discrete records. +// It only writes maxWrite bytes at a time. +type streamWriter struct { + c *conn + recType recType + reqId uint16 +} + +func (w *streamWriter) Write(p []byte) (int, error) { + nn := 0 + for len(p) > 0 { + n := len(p) + if n > maxWrite { + n = maxWrite + } + if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil { + return nn, err + } + nn += n + p = p[n:] + } + return nn, nil +} + +func (w *streamWriter) Close() error { + // send empty record to close the stream + return w.c.writeRecord(w.recType, w.reqId, nil) +} diff --git a/src/pkg/net/http/fcgi/fcgi_test.go b/src/pkg/net/http/fcgi/fcgi_test.go new file mode 100644 index 000000000..6c7e1a9ce --- /dev/null +++ b/src/pkg/net/http/fcgi/fcgi_test.go @@ -0,0 +1,150 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +import ( + "bytes" + "errors" + "io" + "testing" +) + +var sizeTests = []struct { + size uint32 + bytes []byte +}{ + {0, []byte{0x00}}, + {127, []byte{0x7F}}, + {128, []byte{0x80, 0x00, 0x00, 0x80}}, + {1000, []byte{0x80, 0x00, 0x03, 0xE8}}, + {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}}, +} + +func TestSize(t *testing.T) { + b := make([]byte, 4) + for i, test := range sizeTests { + n := encodeSize(b, test.size) + if !bytes.Equal(b[:n], test.bytes) { + t.Errorf("%d expected %x, encoded %x", i, test.bytes, b) + } + size, n := readSize(test.bytes) + if size != test.size { + t.Errorf("%d expected %d, read %d", i, test.size, size) + } + if len(test.bytes) != n { + t.Errorf("%d did not consume all the bytes", i) + } + } +} + +var streamTests = []struct { + desc string + recType recType + reqId uint16 + content []byte + raw []byte +}{ + {"single record", typeStdout, 1, nil, + []byte{1, byte(typeStdout), 0, 1, 0, 0, 0, 0}, + }, + // this data will have to be split into two records + {"two records", typeStdin, 300, make([]byte, 66000), + bytes.Join([][]byte{ + // header for the first record + {1, byte(typeStdin), 0x01, 0x2C, 0xFF, 0xFF, 1, 0}, + make([]byte, 65536), + // header for the second + {1, byte(typeStdin), 0x01, 0x2C, 0x01, 0xD1, 7, 0}, + make([]byte, 472), + // header for the empty record + {1, byte(typeStdin), 0x01, 0x2C, 0, 0, 0, 0}, + }, + nil), + }, +} + +type nilCloser struct { + io.ReadWriter +} + +func (c *nilCloser) Close() error { return nil } + +func TestStreams(t *testing.T) { + var rec record +outer: + for _, test := range streamTests { + buf := bytes.NewBuffer(test.raw) + var content []byte + for buf.Len() > 0 { + if err := rec.read(buf); err != nil { + t.Errorf("%s: error reading record: %v", test.desc, err) + continue outer + } + content = append(content, rec.content()...) + } + if rec.h.Type != test.recType { + t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType) + continue + } + if rec.h.Id != test.reqId { + t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId) + continue + } + if !bytes.Equal(content, test.content) { + t.Errorf("%s: read wrong content", test.desc) + continue + } + buf.Reset() + c := newConn(&nilCloser{buf}) + w := newWriter(c, test.recType, test.reqId) + if _, err := w.Write(test.content); err != nil { + t.Errorf("%s: error writing record: %v", test.desc, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: error closing stream: %v", test.desc, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.raw) { + t.Errorf("%s: wrote wrong content", test.desc) + } + } +} + +type writeOnlyConn struct { + buf []byte +} + +func (c *writeOnlyConn) Write(p []byte) (int, error) { + c.buf = append(c.buf, p...) + return len(p), nil +} + +func (c *writeOnlyConn) Read(p []byte) (int, error) { + return 0, errors.New("conn is write-only") +} + +func (c *writeOnlyConn) Close() error { + return nil +} + +func TestGetValues(t *testing.T) { + var rec record + rec.h.Type = typeGetValues + + wc := new(writeOnlyConn) + c := newChild(wc, nil) + err := c.handleRecord(&rec) + if err != nil { + t.Fatalf("handleRecord: %v", err) + } + + const want = "\x01\n\x00\x00\x00\x12\x06\x00" + + "\x0f\x01FCGI_MPXS_CONNS1" + + "\x00\x00\x00\x00\x00\x00\x01\n\x00\x00\x00\x00\x00\x00" + if got := string(wc.buf); got != want { + t.Errorf(" got: %q\nwant: %q\n", got, want) + } +} diff --git a/src/pkg/net/http/filetransport.go b/src/pkg/net/http/filetransport.go new file mode 100644 index 000000000..821787e0c --- /dev/null +++ b/src/pkg/net/http/filetransport.go @@ -0,0 +1,123 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "fmt" + "io" +) + +// fileTransport implements RoundTripper for the 'file' protocol. +type fileTransport struct { + fh fileHandler +} + +// NewFileTransport returns a new RoundTripper, serving the provided +// FileSystem. The returned RoundTripper ignores the URL host in its +// incoming requests, as well as most other properties of the +// request. +// +// The typical use case for NewFileTransport is to register the "file" +// protocol with a Transport, as in: +// +// t := &http.Transport{} +// t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/"))) +// c := &http.Client{Transport: t} +// res, err := c.Get("file:///etc/passwd") +// ... +func NewFileTransport(fs FileSystem) RoundTripper { + return fileTransport{fileHandler{fs}} +} + +func (t fileTransport) RoundTrip(req *Request) (resp *Response, err error) { + // We start ServeHTTP in a goroutine, which may take a long + // time if the file is large. The newPopulateResponseWriter + // call returns a channel which either ServeHTTP or finish() + // sends our *Response on, once the *Response itself has been + // populated (even if the body itself is still being + // written to the res.Body, a pipe) + rw, resc := newPopulateResponseWriter() + go func() { + t.fh.ServeHTTP(rw, req) + rw.finish() + }() + return <-resc, nil +} + +func newPopulateResponseWriter() (*populateResponse, <-chan *Response) { + pr, pw := io.Pipe() + rw := &populateResponse{ + ch: make(chan *Response), + pw: pw, + res: &Response{ + Proto: "HTTP/1.0", + ProtoMajor: 1, + Header: make(Header), + Close: true, + Body: pr, + }, + } + return rw, rw.ch +} + +// populateResponse is a ResponseWriter that populates the *Response +// in res, and writes its body to a pipe connected to the response +// body. Once writes begin or finish() is called, the response is sent +// on ch. +type populateResponse struct { + res *Response + ch chan *Response + wroteHeader bool + hasContent bool + sentResponse bool + pw *io.PipeWriter +} + +func (pr *populateResponse) finish() { + if !pr.wroteHeader { + pr.WriteHeader(500) + } + if !pr.sentResponse { + pr.sendResponse() + } + pr.pw.Close() +} + +func (pr *populateResponse) sendResponse() { + if pr.sentResponse { + return + } + pr.sentResponse = true + + if pr.hasContent { + pr.res.ContentLength = -1 + } + pr.ch <- pr.res +} + +func (pr *populateResponse) Header() Header { + return pr.res.Header +} + +func (pr *populateResponse) WriteHeader(code int) { + if pr.wroteHeader { + return + } + pr.wroteHeader = true + + pr.res.StatusCode = code + pr.res.Status = fmt.Sprintf("%d %s", code, StatusText(code)) +} + +func (pr *populateResponse) Write(p []byte) (n int, err error) { + if !pr.wroteHeader { + pr.WriteHeader(StatusOK) + } + pr.hasContent = true + if !pr.sentResponse { + pr.sendResponse() + } + return pr.pw.Write(p) +} diff --git a/src/pkg/net/http/filetransport_test.go b/src/pkg/net/http/filetransport_test.go new file mode 100644 index 000000000..039926b53 --- /dev/null +++ b/src/pkg/net/http/filetransport_test.go @@ -0,0 +1,65 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "io/ioutil" + "net/http" + "os" + "path/filepath" + "testing" +) + +func checker(t *testing.T) func(string, error) { + return func(call string, err error) { + if err == nil { + return + } + t.Fatalf("%s: %v", call, err) + } +} + +func TestFileTransport(t *testing.T) { + check := checker(t) + + dname, err := ioutil.TempDir("", "") + check("TempDir", err) + fname := filepath.Join(dname, "foo.txt") + err = ioutil.WriteFile(fname, []byte("Bar"), 0644) + check("WriteFile", err) + defer os.Remove(dname) + defer os.Remove(fname) + + tr := &http.Transport{} + tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(dname))) + c := &http.Client{Transport: tr} + + fooURLs := []string{"file:///foo.txt", "file://../foo.txt"} + for _, urlstr := range fooURLs { + res, err := c.Get(urlstr) + check("Get "+urlstr, err) + if res.StatusCode != 200 { + t.Errorf("for %s, StatusCode = %d, want 200", urlstr, res.StatusCode) + } + if res.ContentLength != -1 { + t.Errorf("for %s, ContentLength = %d, want -1", urlstr, res.ContentLength) + } + if res.Body == nil { + t.Fatalf("for %s, nil Body", urlstr) + } + slurp, err := ioutil.ReadAll(res.Body) + check("ReadAll "+urlstr, err) + if string(slurp) != "Bar" { + t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar") + } + } + + const badURL = "file://../no-exist.txt" + res, err := c.Get(badURL) + check("Get "+badURL, err) + if res.StatusCode != 404 { + t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode) + } +} diff --git a/src/pkg/net/http/fs.go b/src/pkg/net/http/fs.go new file mode 100644 index 000000000..1392ca68a --- /dev/null +++ b/src/pkg/net/http/fs.go @@ -0,0 +1,330 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP file system request handler + +package http + +import ( + "errors" + "fmt" + "io" + "mime" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +// A Dir implements http.FileSystem using the native file +// system restricted to a specific directory tree. +// +// An empty Dir is treated as ".". +type Dir string + +func (d Dir) Open(name string) (File, error) { + if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 { + return nil, errors.New("http: invalid character in file path") + } + dir := string(d) + if dir == "" { + dir = "." + } + f, err := os.Open(filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name)))) + if err != nil { + return nil, err + } + return f, nil +} + +// A FileSystem implements access to a collection of named files. +// The elements in a file path are separated by slash ('/', U+002F) +// characters, regardless of host operating system convention. +type FileSystem interface { + Open(name string) (File, error) +} + +// A File is returned by a FileSystem's Open method and can be +// served by the FileServer implementation. +type File interface { + Close() error + Stat() (os.FileInfo, error) + Readdir(count int) ([]os.FileInfo, error) + Read([]byte) (int, error) + Seek(offset int64, whence int) (int64, error) +} + +// Heuristic: b is text if it is valid UTF-8 and doesn't +// contain any unprintable ASCII or Unicode characters. +func isText(b []byte) bool { + for len(b) > 0 && utf8.FullRune(b) { + rune, size := utf8.DecodeRune(b) + if size == 1 && rune == utf8.RuneError { + // decoding error + return false + } + if 0x7F <= rune && rune <= 0x9F { + return false + } + if rune < ' ' { + switch rune { + case '\n', '\r', '\t': + // okay + default: + // binary garbage + return false + } + } + b = b[size:] + } + return true +} + +func dirList(w ResponseWriter, f File) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprintf(w, "<pre>\n") + for { + dirs, err := f.Readdir(100) + if err != nil || len(dirs) == 0 { + break + } + for _, d := range dirs { + name := d.Name() + if d.IsDir() { + name += "/" + } + // TODO htmlescape + fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", name, name) + } + } + fmt.Fprintf(w, "</pre>\n") +} + +// name is '/'-separated, not filepath.Separator. +func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) { + const indexPage = "/index.html" + + // redirect .../index.html to .../ + // can't use Redirect() because that would make the path absolute, + // which would be a problem running under StripPrefix + if strings.HasSuffix(r.URL.Path, indexPage) { + localRedirect(w, r, "./") + return + } + + f, err := fs.Open(name) + if err != nil { + // TODO expose actual error? + NotFound(w, r) + return + } + defer f.Close() + + d, err1 := f.Stat() + if err1 != nil { + // TODO expose actual error? + NotFound(w, r) + return + } + + if redirect { + // redirect to canonical path: / at end of directory url + // r.URL.Path always begins with / + url := r.URL.Path + if d.IsDir() { + if url[len(url)-1] != '/' { + localRedirect(w, r, path.Base(url)+"/") + return + } + } else { + if url[len(url)-1] == '/' { + localRedirect(w, r, "../"+path.Base(url)) + return + } + } + } + + if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && !d.ModTime().After(t) { + w.WriteHeader(StatusNotModified) + return + } + w.Header().Set("Last-Modified", d.ModTime().UTC().Format(TimeFormat)) + + // use contents of index.html for directory, if present + if d.IsDir() { + index := name + indexPage + ff, err := fs.Open(index) + if err == nil { + defer ff.Close() + dd, err := ff.Stat() + if err == nil { + name = index + d = dd + f = ff + } + } + } + + if d.IsDir() { + dirList(w, f) + return + } + + // serve file + size := d.Size() + code := StatusOK + + // If Content-Type isn't set, use the file's extension to find it. + if w.Header().Get("Content-Type") == "" { + ctype := mime.TypeByExtension(filepath.Ext(name)) + if ctype == "" { + // read a chunk to decide between utf-8 text and binary + var buf [1024]byte + n, _ := io.ReadFull(f, buf[:]) + b := buf[:n] + if isText(b) { + ctype = "text/plain; charset=utf-8" + } else { + // generic binary + ctype = "application/octet-stream" + } + f.Seek(0, os.SEEK_SET) // rewind to output whole file + } + w.Header().Set("Content-Type", ctype) + } + + // handle Content-Range header. + // TODO(adg): handle multiple ranges + ranges, err := parseRange(r.Header.Get("Range"), size) + if err == nil && len(ranges) > 1 { + err = errors.New("multiple ranges not supported") + } + if err != nil { + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + if len(ranges) == 1 { + ra := ranges[0] + if _, err := f.Seek(ra.start, os.SEEK_SET); err != nil { + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + size = ra.length + code = StatusPartialContent + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, d.Size())) + } + + w.Header().Set("Accept-Ranges", "bytes") + if w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) + } + + w.WriteHeader(code) + + if r.Method != "HEAD" { + io.CopyN(w, f, size) + } +} + +// localRedirect gives a Moved Permanently response. +// It does not convert relative paths to absolute paths like Redirect does. +func localRedirect(w ResponseWriter, r *Request, newPath string) { + if q := r.URL.RawQuery; q != "" { + newPath += "?" + q + } + w.Header().Set("Location", newPath) + w.WriteHeader(StatusMovedPermanently) +} + +// ServeFile replies to the request with the contents of the named file or directory. +func ServeFile(w ResponseWriter, r *Request, name string) { + dir, file := filepath.Split(name) + serveFile(w, r, Dir(dir), file, false) +} + +type fileHandler struct { + root FileSystem +} + +// FileServer returns a handler that serves HTTP requests +// with the contents of the file system rooted at root. +// +// To use the operating system's file system implementation, +// use http.Dir: +// +// http.Handle("/", http.FileServer(http.Dir("/tmp"))) +func FileServer(root FileSystem) Handler { + return &fileHandler{root} +} + +func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) { + upath := r.URL.Path + if !strings.HasPrefix(upath, "/") { + upath = "/" + upath + r.URL.Path = upath + } + serveFile(w, r, f.root, path.Clean(upath), true) +} + +// httpRange specifies the byte range to be sent to the client. +type httpRange struct { + start, length int64 +} + +// parseRange parses a Range header string as per RFC 2616. +func parseRange(s string, size int64) ([]httpRange, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges []httpRange + for _, ra := range strings.Split(s[len(b):], ",") { + i := strings.Index(ra, "-") + if i < 0 { + return nil, errors.New("invalid range") + } + start, end := ra[:i], ra[i+1:] + var r httpRange + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file. + i, err := strconv.ParseInt(end, 10, 64) + if err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r.start = size - i + r.length = size - r.start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i > size || i < 0 { + return nil, errors.New("invalid range") + } + r.start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.length = size - r.start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.start > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r.length = i - r.start + 1 + } + } + ranges = append(ranges, r) + } + return ranges, nil +} diff --git a/src/pkg/net/http/fs_test.go b/src/pkg/net/http/fs_test.go new file mode 100644 index 000000000..85cad3ec7 --- /dev/null +++ b/src/pkg/net/http/fs_test.go @@ -0,0 +1,334 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "fmt" + "io/ioutil" + . "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "strings" + "testing" +) + +const ( + testFile = "testdata/file" + testFileLength = 11 +) + +var ServeFileRangeTests = []struct { + start, end int + r string + code int +}{ + {0, testFileLength, "", StatusOK}, + {0, 5, "0-4", StatusPartialContent}, + {2, testFileLength, "2-", StatusPartialContent}, + {testFileLength - 5, testFileLength, "-5", StatusPartialContent}, + {3, 8, "3-7", StatusPartialContent}, + {0, 0, "20-", StatusRequestedRangeNotSatisfiable}, +} + +func TestServeFile(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + + var err error + + file, err := ioutil.ReadFile(testFile) + if err != nil { + t.Fatal("reading file:", err) + } + + // set up the Request (re-used for all tests) + var req Request + req.Header = make(Header) + if req.URL, err = url.Parse(ts.URL); err != nil { + t.Fatal("ParseURL:", err) + } + req.Method = "GET" + + // straight GET + _, body := getBody(t, req) + if !equal(body, file) { + t.Fatalf("body mismatch: got %q, want %q", body, file) + } + + // Range tests + for _, rt := range ServeFileRangeTests { + req.Header.Set("Range", "bytes="+rt.r) + if rt.r == "" { + req.Header["Range"] = nil + } + r, body := getBody(t, req) + if r.StatusCode != rt.code { + t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, r.StatusCode, rt.code) + } + if rt.code == StatusRequestedRangeNotSatisfiable { + continue + } + h := fmt.Sprintf("bytes %d-%d/%d", rt.start, rt.end-1, testFileLength) + if rt.r == "" { + h = "" + } + cr := r.Header.Get("Content-Range") + if cr != h { + t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h) + } + if !equal(body, file[rt.start:rt.end]) { + t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end]) + } + } +} + +var fsRedirectTestData = []struct { + original, redirect string +}{ + {"/test/index.html", "/test/"}, + {"/test/testdata", "/test/testdata/"}, + {"/test/testdata/file/", "/test/testdata/file"}, +} + +func TestFSRedirect(t *testing.T) { + ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) + defer ts.Close() + + for _, data := range fsRedirectTestData { + res, err := Get(ts.URL + data.original) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if g, e := res.Request.URL.Path, data.redirect; g != e { + t.Errorf("redirect from %s: got %s, want %s", data.original, g, e) + } + } +} + +type testFileSystem struct { + open func(name string) (File, error) +} + +func (fs *testFileSystem) Open(name string) (File, error) { + return fs.open(name) +} + +func TestFileServerCleans(t *testing.T) { + ch := make(chan string, 1) + fs := FileServer(&testFileSystem{func(name string) (File, error) { + ch <- name + return nil, os.ENOENT + }}) + tests := []struct { + reqPath, openArg string + }{ + {"/foo.txt", "/foo.txt"}, + {"//foo.txt", "/foo.txt"}, + {"/../foo.txt", "/foo.txt"}, + } + req, _ := NewRequest("GET", "http://example.com", nil) + for n, test := range tests { + rec := httptest.NewRecorder() + req.URL.Path = test.reqPath + fs.ServeHTTP(rec, req) + if got := <-ch; got != test.openArg { + t.Errorf("test %d: got %q, want %q", n, got, test.openArg) + } + } +} + +func TestFileServerImplicitLeadingSlash(t *testing.T) { + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("TempDir: %v", err) + } + defer os.RemoveAll(tempDir) + if err := ioutil.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + ts := httptest.NewServer(StripPrefix("/bar/", FileServer(Dir(tempDir)))) + defer ts.Close() + get := func(suffix string) string { + res, err := Get(ts.URL + suffix) + if err != nil { + t.Fatalf("Get %s: %v", suffix, err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("ReadAll %s: %v", suffix, err) + } + return string(b) + } + if s := get("/bar/"); !strings.Contains(s, ">foo.txt<") { + t.Logf("expected a directory listing with foo.txt, got %q", s) + } + if s := get("/bar/foo.txt"); s != "Hello world" { + t.Logf("expected %q, got %q", "Hello world", s) + } +} + +func TestDirJoin(t *testing.T) { + wfi, err := os.Stat("/etc/hosts") + if err != nil { + t.Logf("skipping test; no /etc/hosts file") + return + } + test := func(d Dir, name string) { + f, err := d.Open(name) + if err != nil { + t.Fatalf("open of %s: %v", name, err) + } + defer f.Close() + gfi, err := f.Stat() + if err != nil { + t.Fatalf("stat of %s: %v", name, err) + } + if !gfi.(*os.FileStat).SameFile(wfi.(*os.FileStat)) { + t.Errorf("%s got different file", name) + } + } + test(Dir("/etc/"), "/hosts") + test(Dir("/etc/"), "hosts") + test(Dir("/etc/"), "../../../../hosts") + test(Dir("/etc"), "/hosts") + test(Dir("/etc"), "hosts") + test(Dir("/etc"), "../../../../hosts") + + // Not really directories, but since we use this trick in + // ServeFile, test it: + test(Dir("/etc/hosts"), "") + test(Dir("/etc/hosts"), "/") + test(Dir("/etc/hosts"), "../") +} + +func TestEmptyDirOpenCWD(t *testing.T) { + test := func(d Dir) { + name := "fs_test.go" + f, err := d.Open(name) + if err != nil { + t.Fatalf("open of %s: %v", name, err) + } + defer f.Close() + } + test(Dir("")) + test(Dir(".")) + test(Dir("./")) +} + +func TestServeFileContentType(t *testing.T) { + const ctype = "icecream/chocolate" + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.FormValue("override") == "1" { + w.Header().Set("Content-Type", ctype) + } + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + get := func(override, want string) { + resp, err := Get(ts.URL + "?override=" + override) + if err != nil { + t.Fatal(err) + } + if h := resp.Header.Get("Content-Type"); h != want { + t.Errorf("Content-Type mismatch: got %q, want %q", h, want) + } + } + get("0", "text/plain; charset=utf-8") + get("1", ctype) +} + +func TestServeFileMimeType(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/style.css") + })) + defer ts.Close() + resp, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + want := "text/css; charset=utf-8" + if h := resp.Header.Get("Content-Type"); h != want { + t.Errorf("Content-Type mismatch: got %q, want %q", h, want) + } +} + +func TestServeFileFromCWD(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "fs_test.go") + })) + defer ts.Close() + r, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if r.StatusCode != 200 { + t.Fatalf("expected 200 OK, got %s", r.Status) + } +} + +func TestServeFileWithContentEncoding(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "foo") + ServeFile(w, r, "testdata/file") + })) + defer ts.Close() + resp, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if g, e := resp.ContentLength, int64(-1); g != e { + t.Errorf("Content-Length mismatch: got %d, want %d", g, e) + } +} + +func TestServeIndexHtml(t *testing.T) { + const want = "index.html says hello\n" + ts := httptest.NewServer(FileServer(Dir("."))) + defer ts.Close() + + for _, path := range []string{"/testdata/", "/testdata/index.html"} { + res, err := Get(ts.URL + path) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if s := string(b); s != want { + t.Errorf("for path %q got %q, want %q", path, s, want) + } + } +} + +func getBody(t *testing.T, req Request) (*Response, []byte) { + r, err := DefaultClient.Do(&req) + if err != nil { + t.Fatal(req.URL.String(), "send:", err) + } + b, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + return r, b +} + +func equal(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/src/pkg/net/http/header.go b/src/pkg/net/http/header.go new file mode 100644 index 000000000..b107c312d --- /dev/null +++ b/src/pkg/net/http/header.go @@ -0,0 +1,78 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "fmt" + "io" + "net/textproto" + "sort" + "strings" +) + +// A Header represents the key-value pairs in an HTTP header. +type Header map[string][]string + +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +func (h Header) Add(key, value string) { + textproto.MIMEHeader(h).Add(key, value) +} + +// Set sets the header entries associated with key to +// the single element value. It replaces any existing +// values associated with key. +func (h Header) Set(key, value string) { + textproto.MIMEHeader(h).Set(key, value) +} + +// Get gets the first value associated with the given key. +// If there are no values associated with the key, Get returns "". +// To access multiple values of a key, access the map directly +// with CanonicalHeaderKey. +func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +// Del deletes the values associated with key. +func (h Header) Del(key string) { + textproto.MIMEHeader(h).Del(key) +} + +// Write writes a header in wire format. +func (h Header) Write(w io.Writer) error { + return h.WriteSubset(w, nil) +} + +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + +// WriteSubset writes a header in wire format. +// If exclude is not nil, keys where exclude[key] == true are not written. +func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { + keys := make([]string, 0, len(h)) + for k := range h { + if exclude == nil || !exclude[k] { + keys = append(keys, k) + } + } + sort.Strings(keys) + for _, k := range keys { + for _, v := range h[k] { + v = headerNewlineToSpace.Replace(v) + v = strings.TrimSpace(v) + if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { + return err + } + } + } + return nil +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } diff --git a/src/pkg/net/http/header_test.go b/src/pkg/net/http/header_test.go new file mode 100644 index 000000000..ccdee8a97 --- /dev/null +++ b/src/pkg/net/http/header_test.go @@ -0,0 +1,81 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "testing" +) + +var headerWriteTests = []struct { + h Header + exclude map[string]bool + expected string +}{ + {Header{}, nil, ""}, + { + Header{ + "Content-Type": {"text/html; charset=UTF-8"}, + "Content-Length": {"0"}, + }, + nil, + "Content-Length: 0\r\nContent-Type: text/html; charset=UTF-8\r\n", + }, + { + Header{ + "Content-Length": {"0", "1", "2"}, + }, + nil, + "Content-Length: 0\r\nContent-Length: 1\r\nContent-Length: 2\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0", "1", "2"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true, "Expires": true, "Content-Encoding": true}, + "", + }, + { + Header{ + "Nil": nil, + "Empty": {}, + "Blank": {""}, + "Double-Blank": {"", ""}, + }, + nil, + "Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n", + }, +} + +func TestHeaderWrite(t *testing.T) { + var buf bytes.Buffer + for i, test := range headerWriteTests { + test.h.WriteSubset(&buf, test.exclude) + if buf.String() != test.expected { + t.Errorf("#%d:\n got: %q\nwant: %q", i, buf.String(), test.expected) + } + buf.Reset() + } +} diff --git a/src/pkg/net/http/httptest/Makefile b/src/pkg/net/http/httptest/Makefile new file mode 100644 index 000000000..3bb445419 --- /dev/null +++ b/src/pkg/net/http/httptest/Makefile @@ -0,0 +1,12 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../../Make.inc + +TARG=net/http/httptest +GOFILES=\ + recorder.go\ + server.go\ + +include ../../../../Make.pkg diff --git a/src/pkg/net/http/httptest/recorder.go b/src/pkg/net/http/httptest/recorder.go new file mode 100644 index 000000000..9aa0d510b --- /dev/null +++ b/src/pkg/net/http/httptest/recorder.go @@ -0,0 +1,58 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httptest provides utilities for HTTP testing. +package httptest + +import ( + "bytes" + "net/http" +) + +// ResponseRecorder is an implementation of http.ResponseWriter that +// records its mutations for later inspection in tests. +type ResponseRecorder struct { + Code int // the HTTP response code from WriteHeader + HeaderMap http.Header // the HTTP response headers + Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to + Flushed bool +} + +// NewRecorder returns an initialized ResponseRecorder. +func NewRecorder() *ResponseRecorder { + return &ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + } +} + +// DefaultRemoteAddr is the default remote address to return in RemoteAddr if +// an explicit DefaultRemoteAddr isn't set on ResponseRecorder. +const DefaultRemoteAddr = "1.2.3.4" + +// Header returns the response headers. +func (rw *ResponseRecorder) Header() http.Header { + return rw.HeaderMap +} + +// Write always succeeds and writes to rw.Body, if not nil. +func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + if rw.Body != nil { + rw.Body.Write(buf) + } + if rw.Code == 0 { + rw.Code = http.StatusOK + } + return len(buf), nil +} + +// WriteHeader sets rw.Code. +func (rw *ResponseRecorder) WriteHeader(code int) { + rw.Code = code +} + +// Flush sets rw.Flushed to true. +func (rw *ResponseRecorder) Flush() { + rw.Flushed = true +} diff --git a/src/pkg/net/http/httptest/server.go b/src/pkg/net/http/httptest/server.go new file mode 100644 index 000000000..5b02e143d --- /dev/null +++ b/src/pkg/net/http/httptest/server.go @@ -0,0 +1,173 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Implementation of Server + +package httptest + +import ( + "crypto/tls" + "flag" + "fmt" + "net" + "net/http" + "os" +) + +// A Server is an HTTP server listening on a system-chosen port on the +// local loopback interface, for use in end-to-end HTTP tests. +type Server struct { + URL string // base URL of form http://ipaddr:port with no trailing slash + Listener net.Listener + TLS *tls.Config // nil if not using using TLS + + // Config may be changed after calling NewUnstartedServer and + // before Start or StartTLS. + Config *http.Server +} + +// historyListener keeps track of all connections that it's ever +// accepted. +type historyListener struct { + net.Listener + history []net.Conn +} + +func (hs *historyListener) Accept() (c net.Conn, err error) { + c, err = hs.Listener.Accept() + if err == nil { + hs.history = append(hs.history, c) + } + return +} + +func newLocalListener() net.Listener { + if *serve != "" { + l, err := net.Listen("tcp", *serve) + if err != nil { + panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err)) + } + return l + } + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) + } + } + return l +} + +// When debugging a particular http server-based test, +// this flag lets you run +// gotest -run=BrokenTest -httptest.serve=127.0.0.1:8000 +// to start the broken server so you can interact with it manually. +var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks") + +// NewServer starts and returns a new Server. +// The caller should call Close when finished, to shut it down. +func NewServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.Start() + return ts +} + +// NewUnstartedServer returns a new Server but doesn't start it. +// +// After changing its configuration, the caller should call Start or +// StartTLS. +// +// The caller should call Close when finished, to shut it down. +func NewUnstartedServer(handler http.Handler) *Server { + return &Server{ + Listener: newLocalListener(), + Config: &http.Server{Handler: handler}, + } +} + +// Start starts a server from NewUnstartedServer. +func (s *Server) Start() { + if s.URL != "" { + panic("Server already started") + } + s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)} + s.URL = "http://" + s.Listener.Addr().String() + go s.Config.Serve(s.Listener) + if *serve != "" { + fmt.Println(os.Stderr, "httptest: serving on", s.URL) + select {} + } +} + +// StartTLS starts TLS on a server from NewUnstartedServer. +func (s *Server) StartTLS() { + if s.URL != "" { + panic("Server already started") + } + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + + s.TLS = &tls.Config{ + NextProtos: []string{"http/1.1"}, + Certificates: []tls.Certificate{cert}, + } + tlsListener := tls.NewListener(s.Listener, s.TLS) + + s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)} + s.URL = "https://" + s.Listener.Addr().String() + go s.Config.Serve(s.Listener) +} + +// NewTLSServer starts and returns a new Server using TLS. +// The caller should call Close when finished, to shut it down. +func NewTLSServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.StartTLS() + return ts +} + +// Close shuts down the server. +func (s *Server) Close() { + s.Listener.Close() +} + +// CloseClientConnections closes any currently open HTTP connections +// to the test Server. +func (s *Server) CloseClientConnections() { + hl, ok := s.Listener.(*historyListener) + if !ok { + return + } + for _, conn := range hl.history { + conn.Close() + } +} + +// localhostCert is a PEM-encoded TLS cert with SAN DNS names +// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end +// of ASN.1 time). +var localhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX +DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7 +qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL +8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw +DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG +CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7 ++l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM= +-----END CERTIFICATE----- +`) + +// localhostKey is the private key for localhostCert. +var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBPQIBAAJBALLgOZgBTI+kO6qAc3LysyKuJM7k+XqUqdgJHEH8gR5uytd1rO7v +tG+VW/YKk3+XAIiCnK7a11apC/ItVEBegM8CAwEAAQJBAI5sxq7naeR9ahyqRkJi +SIv2iMxLuPEHaezf5CYOPWjSjBPyVhyRevkhtqEjF/WkgL7C2nWpYHsUcBDBQVF0 +3KECIQDtEGB2ulnkZAahl3WuJziXGLB+p8Wgx7wzSM6bHu1c6QIhAMEp++CaS+SJ +/TrU0zwY/fW4SvQeb49BPZUF3oqR8Xz3AiEA1rAJHBzBgdOQKdE3ksMUPcnvNJSN +poCcELmz2clVXtkCIQCLytuLV38XHToTipR4yMl6O+6arzAjZ56uq7m7ZRV0TwIh +AM65XAOw8Dsg9Kq78aYXiOEDc5DL0sbFUu/SlmRcCg93 +-----END RSA PRIVATE KEY----- +`) diff --git a/src/pkg/net/http/httputil/Makefile b/src/pkg/net/http/httputil/Makefile new file mode 100644 index 000000000..8bfc7a022 --- /dev/null +++ b/src/pkg/net/http/httputil/Makefile @@ -0,0 +1,14 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../../Make.inc + +TARG=net/http/httputil +GOFILES=\ + chunked.go\ + dump.go\ + persist.go\ + reverseproxy.go\ + +include ../../../../Make.pkg diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go new file mode 100644 index 000000000..29eaf3475 --- /dev/null +++ b/src/pkg/net/http/httputil/chunked.go @@ -0,0 +1,172 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The wire protocol for HTTP's "chunked" Transfer-Encoding. + +// This code is a duplicate of ../chunked.go with these edits: +// s/newChunked/NewChunked/g +// s/package http/package httputil/ +// Please make any changes in both files. + +package httputil + +import ( + "bufio" + "bytes" + "errors" + "io" + "strconv" +) + +const maxLineLength = 4096 // assumed <= bufio.defaultBufSize + +var ErrLineTooLong = errors.New("header line too long") + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + br, ok := r.(*bufio.Reader) + if !ok { + br = bufio.NewReader(r) + } + return &chunkedReader{r: br} +} + +type chunkedReader struct { + r *bufio.Reader + n uint64 // unread bytes in chunk + err error +} + +func (cr *chunkedReader) beginChunk() { + // chunk-size CRLF + var line string + line, cr.err = readLine(cr.r) + if cr.err != nil { + return + } + cr.n, cr.err = strconv.ParseUint(line, 16, 64) + if cr.err != nil { + return + } + if cr.n == 0 { + cr.err = io.EOF + } +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + if cr.err != nil { + return 0, cr.err + } + if cr.n == 0 { + cr.beginChunk() + if cr.err != nil { + return 0, cr.err + } + } + if uint64(len(b)) > cr.n { + b = b[0:cr.n] + } + n, cr.err = cr.r.Read(b) + cr.n -= uint64(n) + if cr.n == 0 && cr.err == nil { + // end of chunk (CRLF) + b := make([]byte, 2) + if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil { + if b[0] != '\r' || b[1] != '\n' { + cr.err = errors.New("malformed chunked encoding") + } + } + } + return n, cr.err +} + +// Read a line of bytes (up to \n) from b. +// Give up if the line exceeds maxLineLength. +// The returned bytes are a pointer into storage in +// the bufio, so they are only valid until the next bufio read. +func readLineBytes(b *bufio.Reader) (p []byte, err error) { + if p, err = b.ReadSlice('\n'); err != nil { + // We always know when EOF is coming. + // If the caller asked for a line, there should be a line. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } else if err == bufio.ErrBufferFull { + err = ErrLineTooLong + } + return nil, err + } + if len(p) >= maxLineLength { + return nil, ErrLineTooLong + } + + // Chop off trailing white space. + p = bytes.TrimRight(p, " \r\t\n") + + return p, nil +} + +// readLineBytes, but convert the bytes into a string. +func readLine(b *bufio.Reader) (s string, err error) { + p, e := readLineBytes(b) + if e != nil { + return "", e + } + return string(p), nil +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using NewChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return &chunkedWriter{w} +} + +// Writing to chunkedWriter translates to writing in HTTP chunked Transfer +// Encoding wire format to the underlying Wire chunkedWriter. +type chunkedWriter struct { + Wire io.Writer +} + +// Write the contents of data as one chunk to Wire. +// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has +// a bug since it does not check for success of io.WriteString +func (cw *chunkedWriter) Write(data []byte) (n int, err error) { + + // Don't send 0-length data. It looks like EOF for chunked encoding. + if len(data) == 0 { + return 0, nil + } + + head := strconv.FormatInt(int64(len(data)), 16) + "\r\n" + + if _, err = io.WriteString(cw.Wire, head); err != nil { + return 0, err + } + if n, err = cw.Wire.Write(data); err != nil { + return + } + if n != len(data) { + err = io.ErrShortWrite + return + } + _, err = io.WriteString(cw.Wire, "\r\n") + + return +} + +func (cw *chunkedWriter) Close() error { + _, err := io.WriteString(cw.Wire, "0\r\n") + return err +} diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go new file mode 100644 index 000000000..155a32bdf --- /dev/null +++ b/src/pkg/net/http/httputil/chunked_test.go @@ -0,0 +1,41 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code is a duplicate of ../chunked_test.go with these edits: +// s/newChunked/NewChunked/g +// s/package http/package httputil/ +// Please make any changes in both files. + +package httputil + +import ( + "bytes" + "io/ioutil" + "testing" +) + +func TestChunk(t *testing.T) { + var b bytes.Buffer + + w := NewChunkedWriter(&b) + const chunk1 = "hello, " + const chunk2 = "world! 0123456789abcdef" + w.Write([]byte(chunk1)) + w.Write([]byte(chunk2)) + w.Close() + + if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e { + t.Fatalf("chunk writer wrote %q; want %q", g, e) + } + + r := NewChunkedReader(&b) + data, err := ioutil.ReadAll(r) + if err != nil { + t.Logf(`data: "%s"`, data) + t.Fatalf("ReadAll from reader: %v", err) + } + if g, e := string(data), chunk1+chunk2; g != e { + t.Errorf("chunk reader read %q; want %q", g, e) + } +} diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go new file mode 100644 index 000000000..b8a98ee42 --- /dev/null +++ b/src/pkg/net/http/httputil/dump.go @@ -0,0 +1,196 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "strings" + "time" +) + +// One of the copies, say from b to r2, could be avoided by using a more +// elaborate trick where the other copy is made during Request/Response.Write. +// This would complicate things too much, given that these functions are for +// debugging only. +func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { + var buf bytes.Buffer + if _, err = buf.ReadFrom(b); err != nil { + return nil, nil, err + } + if err = b.Close(); err != nil { + return nil, nil, err + } + return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewBuffer(buf.Bytes())), nil +} + +// dumpConn is a net.Conn which writes to Writer and reads from Reader +type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } + +// DumpRequestOut is like DumpRequest but includes +// headers that the standard http.Transport adds, +// such as User-Agent. +func DumpRequestOut(req *http.Request, body bool) (dump []byte, err error) { + save := req.Body + if !body || req.Body == nil { + req.Body = nil + } else { + save, req.Body, err = drainBody(req.Body) + if err != nil { + return + } + } + + var b bytes.Buffer + dialed := false + t := &http.Transport{ + Dial: func(net, addr string) (c net.Conn, err error) { + if dialed { + return nil, errors.New("unexpected second dial") + } + c = &dumpConn{ + Writer: &b, + Reader: strings.NewReader("HTTP/1.1 500 Fake Error\r\n\r\n"), + } + return + }, + } + + _, err = t.RoundTrip(req) + + req.Body = save + if err != nil { + return + } + dump = b.Bytes() + return +} + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} + +var reqWriteExcludeHeaderDump = map[string]bool{ + "Host": true, // not in Header map anyway + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// dumpAsReceived writes req to w in the form as it was received, or +// at least as accurately as possible from the information retained in +// the request. +func dumpAsReceived(req *http.Request, w io.Writer) error { + return nil +} + +// DumpRequest returns the as-received wire representation of req, +// optionally including the request body, for debugging. +// DumpRequest is semantically a no-op, but in order to +// dump the body, it reads the body data into memory and +// changes req.Body to refer to the in-memory copy. +// The documentation for http.Request.Write details which fields +// of req are used. +func DumpRequest(req *http.Request, body bool) (dump []byte, err error) { + save := req.Body + if !body || req.Body == nil { + req.Body = nil + } else { + save, req.Body, err = drainBody(req.Body) + if err != nil { + return + } + } + + var b bytes.Buffer + + fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"), + req.URL.RequestURI(), req.ProtoMajor, req.ProtoMinor) + + host := req.Host + if host == "" && req.URL != nil { + host = req.URL.Host + } + if host != "" { + fmt.Fprintf(&b, "Host: %s\r\n", host) + } + + chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" + if len(req.TransferEncoding) > 0 { + fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ",")) + } + if req.Close { + fmt.Fprintf(&b, "Connection: close\r\n") + } + + err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump) + if err != nil { + return + } + + io.WriteString(&b, "\r\n") + + if req.Body != nil { + var dest io.Writer = &b + if chunked { + dest = NewChunkedWriter(dest) + } + _, err = io.Copy(dest, req.Body) + if chunked { + dest.(io.Closer).Close() + io.WriteString(&b, "\r\n") + } + } + + req.Body = save + if err != nil { + return + } + dump = b.Bytes() + return +} + +// DumpResponse is like DumpRequest but dumps a response. +func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) { + var b bytes.Buffer + save := resp.Body + savecl := resp.ContentLength + if !body || resp.Body == nil { + resp.Body = nil + resp.ContentLength = 0 + } else { + save, resp.Body, err = drainBody(resp.Body) + if err != nil { + return + } + } + err = resp.Write(&b) + resp.Body = save + resp.ContentLength = savecl + if err != nil { + return + } + dump = b.Bytes() + return +} diff --git a/src/pkg/net/http/httputil/dump_test.go b/src/pkg/net/http/httputil/dump_test.go new file mode 100644 index 000000000..819efb584 --- /dev/null +++ b/src/pkg/net/http/httputil/dump_test.go @@ -0,0 +1,140 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "testing" +) + +type dumpTest struct { + Req http.Request + Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + + WantDump string + WantDumpOut string +} + +var dumpTests = []dumpTest{ + + // HTTP/1.1 => chunked coding; body; empty trailer + { + Req: http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + TransferEncoding: []string{"chunked"}, + }, + + Body: []byte("abcdef"), + + WantDump: "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + }, + + // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, + // and doesn't add a User-Agent. + { + Req: http.Request{ + Method: "GET", + URL: mustParseURL("/foo"), + ProtoMajor: 1, + ProtoMinor: 0, + Header: http.Header{ + "X-Foo": []string{"X-Bar"}, + }, + }, + + WantDump: "GET /foo HTTP/1.0\r\n" + + "X-Foo: X-Bar\r\n\r\n", + }, + + { + Req: *mustNewRequest("GET", "http://example.com/foo", nil), + + WantDumpOut: "GET /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, +} + +func TestDumpRequest(t *testing.T) { + for i, tt := range dumpTests { + setBody := func() { + if tt.Body == nil { + return + } + switch b := tt.Body.(type) { + case []byte: + tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b)) + case func() io.ReadCloser: + tt.Req.Body = b() + } + } + setBody() + if tt.Req.Header == nil { + tt.Req.Header = make(http.Header) + } + + if tt.WantDump != "" { + setBody() + dump, err := DumpRequest(&tt.Req, true) + if err != nil { + t.Errorf("DumpRequest #%d: %s", i, err) + continue + } + if string(dump) != tt.WantDump { + t.Errorf("DumpRequest %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDump, string(dump)) + continue + } + } + + if tt.WantDumpOut != "" { + setBody() + dump, err := DumpRequestOut(&tt.Req, true) + if err != nil { + t.Errorf("DumpRequestOut #%d: %s", i, err) + continue + } + if string(dump) != tt.WantDumpOut { + t.Errorf("DumpRequestOut %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDumpOut, string(dump)) + continue + } + } + } +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(fmt.Sprintf("Error parsing URL %q: %v", s, err)) + } + return u +} + +func mustNewRequest(method, url string, body io.Reader) *http.Request { + req, err := http.NewRequest(method, url, body) + if err != nil { + panic(fmt.Sprintf("NewRequest(%q, %q, %p) err = %v", method, url, body, err)) + } + return req +} diff --git a/src/pkg/net/http/httputil/persist.go b/src/pkg/net/http/httputil/persist.go new file mode 100644 index 000000000..1266bd3ad --- /dev/null +++ b/src/pkg/net/http/httputil/persist.go @@ -0,0 +1,422 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httputil provides HTTP utility functions, complementing the +// more common ones in the net/http package. +package httputil + +import ( + "bufio" + "errors" + "io" + "net" + "net/http" + "net/textproto" + "os" + "sync" +) + +var ( + ErrPersistEOF = &http.ProtocolError{"persistent connection closed"} + ErrPipeline = &http.ProtocolError{"pipeline error"} +) + +// This is an API usage error - the local side is closed. +// ErrPersistEOF (above) reports that the remote side is closed. +var errClosed = errors.New("i/o operation on closed connection") + +// A ServerConn reads requests and sends responses over an underlying +// connection, until the HTTP keepalive logic commands an end. ServerConn +// also allows hijacking the underlying connection by calling Hijack +// to regain control over the connection. ServerConn supports pipe-lining, +// i.e. requests can be read out of sync (but in the same order) while the +// respective responses are sent. +// +// ServerConn is low-level and should not be needed by most applications. +// See Server. +type ServerConn struct { + lk sync.Mutex // read-write protects the following fields + c net.Conn + r *bufio.Reader + re, we error // read/write errors + lastbody io.ReadCloser + nread, nwritten int + pipereq map[*http.Request]uint + + pipe textproto.Pipeline +} + +// NewServerConn returns a new ServerConn reading and writing c. If r is not +// nil, it is the buffer to use when reading c. +func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { + if r == nil { + r = bufio.NewReader(c) + } + return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)} +} + +// Hijack detaches the ServerConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be +// called before Read has signaled the end of the keep-alive logic. The user +// should not call Hijack while Read or Write is in progress. +func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) { + sc.lk.Lock() + defer sc.lk.Unlock() + c = sc.c + r = sc.r + sc.c = nil + sc.r = nil + return +} + +// Close calls Hijack and then also closes the underlying connection +func (sc *ServerConn) Close() error { + c, _ := sc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + +// Read returns the next request on the wire. An ErrPersistEOF is returned if +// it is gracefully determined that there are no more requests (e.g. after the +// first request on an HTTP/1.0 connection, or after a Connection:close on a +// HTTP/1.1 connection). +func (sc *ServerConn) Read() (req *http.Request, err error) { + + // Ensure ordered execution of Reads and Writes + id := sc.pipe.Next() + sc.pipe.StartRequest(id) + defer func() { + sc.pipe.EndRequest(id) + if req == nil { + sc.pipe.StartResponse(id) + sc.pipe.EndResponse(id) + } else { + // Remember the pipeline id of this request + sc.lk.Lock() + sc.pipereq[req] = id + sc.lk.Unlock() + } + }() + + sc.lk.Lock() + if sc.we != nil { // no point receiving if write-side broken or closed + defer sc.lk.Unlock() + return nil, sc.we + } + if sc.re != nil { + defer sc.lk.Unlock() + return nil, sc.re + } + if sc.r == nil { // connection closed by user in the meantime + defer sc.lk.Unlock() + return nil, errClosed + } + r := sc.r + lastbody := sc.lastbody + sc.lastbody = nil + sc.lk.Unlock() + + // Make sure body is fully consumed, even if user does not call body.Close + if lastbody != nil { + // body.Close is assumed to be idempotent and multiple calls to + // it should return the error that its first invocation + // returned. + err = lastbody.Close() + if err != nil { + sc.lk.Lock() + defer sc.lk.Unlock() + sc.re = err + return nil, err + } + } + + req, err = http.ReadRequest(r) + sc.lk.Lock() + defer sc.lk.Unlock() + if err != nil { + if err == io.ErrUnexpectedEOF { + // A close from the opposing client is treated as a + // graceful close, even if there was some unparse-able + // data before the close. + sc.re = ErrPersistEOF + return nil, sc.re + } else { + sc.re = err + return req, err + } + } + sc.lastbody = req.Body + sc.nread++ + if req.Close { + sc.re = ErrPersistEOF + return req, sc.re + } + return req, err +} + +// Pending returns the number of unanswered requests +// that have been received on the connection. +func (sc *ServerConn) Pending() int { + sc.lk.Lock() + defer sc.lk.Unlock() + return sc.nread - sc.nwritten +} + +// Write writes resp in response to req. To close the connection gracefully, set the +// Response.Close field to true. Write should be considered operational until +// it returns an error, regardless of any errors returned on the Read side. +func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { + + // Retrieve the pipeline ID of this request/response pair + sc.lk.Lock() + id, ok := sc.pipereq[req] + delete(sc.pipereq, req) + if !ok { + sc.lk.Unlock() + return ErrPipeline + } + sc.lk.Unlock() + + // Ensure pipeline order + sc.pipe.StartResponse(id) + defer sc.pipe.EndResponse(id) + + sc.lk.Lock() + if sc.we != nil { + defer sc.lk.Unlock() + return sc.we + } + if sc.c == nil { // connection closed by user in the meantime + defer sc.lk.Unlock() + return os.EBADF + } + c := sc.c + if sc.nread <= sc.nwritten { + defer sc.lk.Unlock() + return errors.New("persist server pipe count") + } + if resp.Close { + // After signaling a keep-alive close, any pipelined unread + // requests will be lost. It is up to the user to drain them + // before signaling. + sc.re = ErrPersistEOF + } + sc.lk.Unlock() + + err := resp.Write(c) + sc.lk.Lock() + defer sc.lk.Unlock() + if err != nil { + sc.we = err + return err + } + sc.nwritten++ + + return nil +} + +// A ClientConn sends request and receives headers over an underlying +// connection, while respecting the HTTP keepalive logic. ClientConn +// supports hijacking the connection calling Hijack to +// regain control of the underlying net.Conn and deal with it as desired. +// +// ClientConn is low-level and should not be needed by most applications. +// See Client. +type ClientConn struct { + lk sync.Mutex // read-write protects the following fields + c net.Conn + r *bufio.Reader + re, we error // read/write errors + lastbody io.ReadCloser + nread, nwritten int + pipereq map[*http.Request]uint + + pipe textproto.Pipeline + writeReq func(*http.Request, io.Writer) error +} + +// NewClientConn returns a new ClientConn reading and writing c. If r is not +// nil, it is the buffer to use when reading c. +func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + if r == nil { + r = bufio.NewReader(c) + } + return &ClientConn{ + c: c, + r: r, + pipereq: make(map[*http.Request]uint), + writeReq: (*http.Request).Write, + } +} + +// NewProxyClientConn works like NewClientConn but writes Requests +// using Request's WriteProxy method. +func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + cc := NewClientConn(c, r) + cc.writeReq = (*http.Request).WriteProxy + return cc +} + +// Hijack detaches the ClientConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be +// called before the user or Read have signaled the end of the keep-alive +// logic. The user should not call Hijack while Read or Write is in progress. +func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { + cc.lk.Lock() + defer cc.lk.Unlock() + c = cc.c + r = cc.r + cc.c = nil + cc.r = nil + return +} + +// Close calls Hijack and then also closes the underlying connection +func (cc *ClientConn) Close() error { + c, _ := cc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + +// Write writes a request. An ErrPersistEOF error is returned if the connection +// has been closed in an HTTP keepalive sense. If req.Close equals true, the +// keepalive connection is logically closed after this request and the opposing +// server is informed. An ErrUnexpectedEOF indicates the remote closed the +// underlying TCP connection, which is usually considered as graceful close. +func (cc *ClientConn) Write(req *http.Request) (err error) { + + // Ensure ordered execution of Writes + id := cc.pipe.Next() + cc.pipe.StartRequest(id) + defer func() { + cc.pipe.EndRequest(id) + if err != nil { + cc.pipe.StartResponse(id) + cc.pipe.EndResponse(id) + } else { + // Remember the pipeline id of this request + cc.lk.Lock() + cc.pipereq[req] = id + cc.lk.Unlock() + } + }() + + cc.lk.Lock() + if cc.re != nil { // no point sending if read-side closed or broken + defer cc.lk.Unlock() + return cc.re + } + if cc.we != nil { + defer cc.lk.Unlock() + return cc.we + } + if cc.c == nil { // connection closed by user in the meantime + defer cc.lk.Unlock() + return errClosed + } + c := cc.c + if req.Close { + // We write the EOF to the write-side error, because there + // still might be some pipelined reads + cc.we = ErrPersistEOF + } + cc.lk.Unlock() + + err = cc.writeReq(req, c) + cc.lk.Lock() + defer cc.lk.Unlock() + if err != nil { + cc.we = err + return err + } + cc.nwritten++ + + return nil +} + +// Pending returns the number of unanswered requests +// that have been sent on the connection. +func (cc *ClientConn) Pending() int { + cc.lk.Lock() + defer cc.lk.Unlock() + return cc.nwritten - cc.nread +} + +// Read reads the next response from the wire. A valid response might be +// returned together with an ErrPersistEOF, which means that the remote +// requested that this be the last request serviced. Read can be called +// concurrently with Write, but not with another Read. +func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) { + // Retrieve the pipeline ID of this request/response pair + cc.lk.Lock() + id, ok := cc.pipereq[req] + delete(cc.pipereq, req) + if !ok { + cc.lk.Unlock() + return nil, ErrPipeline + } + cc.lk.Unlock() + + // Ensure pipeline order + cc.pipe.StartResponse(id) + defer cc.pipe.EndResponse(id) + + cc.lk.Lock() + if cc.re != nil { + defer cc.lk.Unlock() + return nil, cc.re + } + if cc.r == nil { // connection closed by user in the meantime + defer cc.lk.Unlock() + return nil, errClosed + } + r := cc.r + lastbody := cc.lastbody + cc.lastbody = nil + cc.lk.Unlock() + + // Make sure body is fully consumed, even if user does not call body.Close + if lastbody != nil { + // body.Close is assumed to be idempotent and multiple calls to + // it should return the error that its first invokation + // returned. + err = lastbody.Close() + if err != nil { + cc.lk.Lock() + defer cc.lk.Unlock() + cc.re = err + return nil, err + } + } + + resp, err = http.ReadResponse(r, req) + cc.lk.Lock() + defer cc.lk.Unlock() + if err != nil { + cc.re = err + return resp, err + } + cc.lastbody = resp.Body + + cc.nread++ + + if resp.Close { + cc.re = ErrPersistEOF // don't send any more requests + return resp, cc.re + } + return resp, err +} + +// Do is convenience method that writes a request and reads a response. +func (cc *ClientConn) Do(req *http.Request) (resp *http.Response, err error) { + err = cc.Write(req) + if err != nil { + return + } + return cc.Read(req) +} diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go new file mode 100644 index 000000000..1072e2e34 --- /dev/null +++ b/src/pkg/net/http/httputil/reverseproxy.go @@ -0,0 +1,167 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package httputil + +import ( + "io" + "log" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + FlushInterval time.Duration +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + director := func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + req.URL.RawQuery = target.RawQuery + } + return &ReverseProxy{Director: director} +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + outreq := new(http.Request) + *outreq = *req // includes shallow copies of maps, but okay + + p.Director(outreq) + outreq.Proto = "HTTP/1.1" + outreq.ProtoMajor = 1 + outreq.ProtoMinor = 1 + outreq.Close = false + + // Remove the connection header to the backend. We want a + // persistent connection, regardless of what the client sent + // to us. This is modifying the same underlying map from req + // (shallow copied above) so we only copy it if necessary. + if outreq.Header.Get("Connection") != "" { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, req.Header) + outreq.Header.Del("Connection") + } + + if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + outreq.Header.Set("X-Forwarded-For", clientIp) + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + log.Printf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusInternalServerError) + return + } + + copyHeader(rw.Header(), res.Header) + + rw.WriteHeader(res.StatusCode) + + if res.Body != nil { + var dst io.Writer = rw + if p.FlushInterval != 0 { + if wf, ok := rw.(writeFlusher); ok { + dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval} + } + } + io.Copy(dst, res.Body) + } +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration + + lk sync.Mutex // protects init of done, as well Write + Flush + done chan bool +} + +func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { + m.lk.Lock() + defer m.lk.Unlock() + if m.done == nil { + m.done = make(chan bool) + go m.flushLoop() + } + n, err = m.dst.Write(p) + if err != nil { + m.done <- true + } + return +} + +func (m *maxLatencyWriter) flushLoop() { + t := time.NewTicker(m.latency) + defer t.Stop() + for { + select { + case <-t.C: + m.lk.Lock() + m.dst.Flush() + m.lk.Unlock() + case <-m.done: + return + } + } + panic("unreached") +} diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go new file mode 100644 index 000000000..655784b30 --- /dev/null +++ b/src/pkg/net/http/httputil/reverseproxy_test.go @@ -0,0 +1,71 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. + +package httputil + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(r.TransferEncoding) > 0 { + t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) + } + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got Connection header value %q", c) + } + if g, e := r.Host, "some-name"; g != e { + t.Errorf("backend got Host header %q, want %q", g, e) + } + w.Header().Set("X-Foo", "bar") + http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Close = true + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + if g, e := len(res.Header["Set-Cookie"]), 1; g != e { + t.Fatalf("got %d SetCookies, want %d", g, e) + } + if cookie := res.Cookies()[0]; cookie.Name != "flavor" { + t.Errorf("unexpected cookie %q", cookie.Name) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} diff --git a/src/pkg/net/http/jar.go b/src/pkg/net/http/jar.go new file mode 100644 index 000000000..2c2caa251 --- /dev/null +++ b/src/pkg/net/http/jar.go @@ -0,0 +1,30 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/url" +) + +// A CookieJar manages storage and use of cookies in HTTP requests. +// +// Implementations of CookieJar must be safe for concurrent use by multiple +// goroutines. +type CookieJar interface { + // SetCookies handles the receipt of the cookies in a reply for the + // given URL. It may or may not choose to save the cookies, depending + // on the jar's policy and implementation. + SetCookies(u *url.URL, cookies []*Cookie) + + // Cookies returns the cookies to send in a request for the given URL. + // It is up to the implementation to honor the standard cookie use + // restrictions such as in RFC 6265. + Cookies(u *url.URL) []*Cookie +} + +type blackHoleJar struct{} + +func (blackHoleJar) SetCookies(u *url.URL, cookies []*Cookie) {} +func (blackHoleJar) Cookies(u *url.URL) []*Cookie { return nil } diff --git a/src/pkg/net/http/lex.go b/src/pkg/net/http/lex.go new file mode 100644 index 000000000..93b67e701 --- /dev/null +++ b/src/pkg/net/http/lex.go @@ -0,0 +1,144 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +// This file deals with lexical matters of HTTP + +func isSeparator(c byte) bool { + switch c { + case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t': + return true + } + return false +} + +func isSpace(c byte) bool { + switch c { + case ' ', '\t', '\r', '\n': + return true + } + return false +} + +func isCtl(c byte) bool { return (0 <= c && c <= 31) || c == 127 } + +func isChar(c byte) bool { return 0 <= c && c <= 127 } + +func isAnyText(c byte) bool { return !isCtl(c) } + +func isQdText(c byte) bool { return isAnyText(c) && c != '"' } + +func isToken(c byte) bool { return isChar(c) && !isCtl(c) && !isSeparator(c) } + +// Valid escaped sequences are not specified in RFC 2616, so for now, we assume +// that they coincide with the common sense ones used by GO. Malformed +// characters should probably not be treated as errors by a robust (forgiving) +// parser, so we replace them with the '?' character. +func httpUnquotePair(b byte) byte { + // skip the first byte, which should always be '\' + switch b { + case 'a': + return '\a' + case 'b': + return '\b' + case 'f': + return '\f' + case 'n': + return '\n' + case 'r': + return '\r' + case 't': + return '\t' + case 'v': + return '\v' + case '\\': + return '\\' + case '\'': + return '\'' + case '"': + return '"' + } + return '?' +} + +// raw must begin with a valid quoted string. Only the first quoted string is +// parsed and is unquoted in result. eaten is the number of bytes parsed, or -1 +// upon failure. +func httpUnquote(raw []byte) (eaten int, result string) { + buf := make([]byte, len(raw)) + if raw[0] != '"' { + return -1, "" + } + eaten = 1 + j := 0 // # of bytes written in buf + for i := 1; i < len(raw); i++ { + switch b := raw[i]; b { + case '"': + eaten++ + buf = buf[0:j] + return i + 1, string(buf) + case '\\': + if len(raw) < i+2 { + return -1, "" + } + buf[j] = httpUnquotePair(raw[i+1]) + eaten += 2 + j++ + i++ + default: + if isQdText(b) { + buf[j] = b + } else { + buf[j] = '?' + } + eaten++ + j++ + } + } + return -1, "" +} + +// This is a best effort parse, so errors are not returned, instead not all of +// the input string might be parsed. result is always non-nil. +func httpSplitFieldValue(fv string) (eaten int, result []string) { + result = make([]string, 0, len(fv)) + raw := []byte(fv) + i := 0 + chunk := "" + for i < len(raw) { + b := raw[i] + switch { + case b == '"': + eaten, unq := httpUnquote(raw[i:len(raw)]) + if eaten < 0 { + return i, result + } else { + i += eaten + chunk += unq + } + case isSeparator(b): + if chunk != "" { + result = result[0 : len(result)+1] + result[len(result)-1] = chunk + chunk = "" + } + i++ + case isToken(b): + chunk += string(b) + i++ + case b == '\n' || b == '\r': + i++ + default: + chunk += "?" + i++ + } + } + if chunk != "" { + result = result[0 : len(result)+1] + result[len(result)-1] = chunk + chunk = "" + } + return i, result +} diff --git a/src/pkg/net/http/lex_test.go b/src/pkg/net/http/lex_test.go new file mode 100644 index 000000000..5386f7534 --- /dev/null +++ b/src/pkg/net/http/lex_test.go @@ -0,0 +1,70 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "testing" +) + +type lexTest struct { + Raw string + Parsed int // # of parsed characters + Result []string +} + +var lexTests = []lexTest{ + { + Raw: `"abc"def,:ghi`, + Parsed: 13, + Result: []string{"abcdef", "ghi"}, + }, + // My understanding of the RFC is that escape sequences outside of + // quotes are not interpreted? + { + Raw: `"\t"\t"\t"`, + Parsed: 10, + Result: []string{"\t", "t\t"}, + }, + { + Raw: `"\yab"\r\n`, + Parsed: 10, + Result: []string{"?ab", "r", "n"}, + }, + { + Raw: "ab\f", + Parsed: 3, + Result: []string{"ab?"}, + }, + { + Raw: "\"ab \" c,de f, gh, ij\n\t\r", + Parsed: 23, + Result: []string{"ab ", "c", "de", "f", "gh", "ij"}, + }, +} + +func min(x, y int) int { + if x <= y { + return x + } + return y +} + +func TestSplitFieldValue(t *testing.T) { + for k, l := range lexTests { + parsed, result := httpSplitFieldValue(l.Raw) + if parsed != l.Parsed { + t.Errorf("#%d: Parsed %d, expected %d", k, parsed, l.Parsed) + } + if len(result) != len(l.Result) { + t.Errorf("#%d: Result len %d, expected %d", k, len(result), len(l.Result)) + } + for i := 0; i < min(len(result), len(l.Result)); i++ { + if result[i] != l.Result[i] { + t.Errorf("#%d: %d-th entry mismatch. Have {%s}, expect {%s}", + k, i, result[i], l.Result[i]) + } + } + } +} diff --git a/src/pkg/net/http/pprof/Makefile b/src/pkg/net/http/pprof/Makefile new file mode 100644 index 000000000..b78fce8e4 --- /dev/null +++ b/src/pkg/net/http/pprof/Makefile @@ -0,0 +1,11 @@ +# Copyright 2010 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../../Make.inc + +TARG=net/http/pprof +GOFILES=\ + pprof.go\ + +include ../../../../Make.pkg diff --git a/src/pkg/net/http/pprof/pprof.go b/src/pkg/net/http/pprof/pprof.go new file mode 100644 index 000000000..21eac4743 --- /dev/null +++ b/src/pkg/net/http/pprof/pprof.go @@ -0,0 +1,133 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package pprof serves via its HTTP server runtime profiling data +// in the format expected by the pprof visualization tool. +// For more information about pprof, see +// http://code.google.com/p/google-perftools/. +// +// The package is typically only imported for the side effect of +// registering its HTTP handlers. +// The handled paths all begin with /debug/pprof/. +// +// To use pprof, link this package into your program: +// import _ "http/pprof" +// +// Then use the pprof tool to look at the heap profile: +// +// pprof http://localhost:6060/debug/pprof/heap +// +// Or to look at a 30-second CPU profile: +// +// pprof http://localhost:6060/debug/pprof/profile +// +package pprof + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/http" + "os" + "runtime" + "runtime/pprof" + "strconv" + "strings" + "time" +) + +func init() { + http.Handle("/debug/pprof/cmdline", http.HandlerFunc(Cmdline)) + http.Handle("/debug/pprof/profile", http.HandlerFunc(Profile)) + http.Handle("/debug/pprof/heap", http.HandlerFunc(Heap)) + http.Handle("/debug/pprof/symbol", http.HandlerFunc(Symbol)) +} + +// Cmdline responds with the running program's +// command line, with arguments separated by NUL bytes. +// The package initialization registers it as /debug/pprof/cmdline. +func Cmdline(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprintf(w, strings.Join(os.Args, "\x00")) +} + +// Heap responds with the pprof-formatted heap profile. +// The package initialization registers it as /debug/pprof/heap. +func Heap(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + pprof.WriteHeapProfile(w) +} + +// Profile responds with the pprof-formatted cpu profile. +// The package initialization registers it as /debug/pprof/profile. +func Profile(w http.ResponseWriter, r *http.Request) { + sec, _ := strconv.ParseInt(r.FormValue("seconds"), 10, 64) + if sec == 0 { + sec = 30 + } + + // Set Content Type assuming StartCPUProfile will work, + // because if it does it starts writing. + w.Header().Set("Content-Type", "application/octet-stream") + if err := pprof.StartCPUProfile(w); err != nil { + // StartCPUProfile failed, so no writes yet. + // Can change header back to text content + // and send error code. + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) + return + } + time.Sleep(time.Duration(sec) * time.Second) + pprof.StopCPUProfile() +} + +// Symbol looks up the program counters listed in the request, +// responding with a table mapping program counters to function names. +// The package initialization registers it as /debug/pprof/symbol. +func Symbol(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + + // We have to read the whole POST body before + // writing any output. Buffer the output here. + var buf bytes.Buffer + + // We don't know how many symbols we have, but we + // do have symbol information. Pprof only cares whether + // this number is 0 (no symbols available) or > 0. + fmt.Fprintf(&buf, "num_symbols: 1\n") + + var b *bufio.Reader + if r.Method == "POST" { + b = bufio.NewReader(r.Body) + } else { + b = bufio.NewReader(strings.NewReader(r.URL.RawQuery)) + } + + for { + word, err := b.ReadSlice('+') + if err == nil { + word = word[0 : len(word)-1] // trim + + } + pc, _ := strconv.ParseUint(string(word), 0, 64) + if pc != 0 { + f := runtime.FuncForPC(uintptr(pc)) + if f != nil { + fmt.Fprintf(&buf, "%#x %s\n", pc, f.Name()) + } + } + + // Wait until here to check for err; the last + // symbol will have an err because it doesn't end in +. + if err != nil { + if err != io.EOF { + fmt.Fprintf(&buf, "reading request: %v\n", err) + } + break + } + } + + w.Write(buf.Bytes()) +} diff --git a/src/pkg/net/http/proxy_test.go b/src/pkg/net/http/proxy_test.go new file mode 100644 index 000000000..9b320b3aa --- /dev/null +++ b/src/pkg/net/http/proxy_test.go @@ -0,0 +1,48 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "os" + "testing" +) + +// TODO(mattn): +// test ProxyAuth + +var UseProxyTests = []struct { + host string + match bool +}{ + // Never proxy localhost: + {"localhost:80", false}, + {"127.0.0.1", false}, + {"127.0.0.2", false}, + {"[::1]", false}, + {"[::2]", true}, // not a loopback address + + {"barbaz.net", false}, // match as .barbaz.net + {"foobar.com", false}, // have a port but match + {"foofoobar.com", true}, // not match as a part of foobar.com + {"baz.com", true}, // not match as a part of barbaz.com + {"localhost.net", true}, // not match as suffix of address + {"local.localhost", true}, // not match as prefix as address + {"barbarbaz.net", true}, // not match because NO_PROXY have a '.' + {"www.foobar.com", true}, // not match because NO_PROXY is not .foobar.com +} + +func TestUseProxy(t *testing.T) { + oldenv := os.Getenv("NO_PROXY") + defer os.Setenv("NO_PROXY", oldenv) + + no_proxy := "foobar.com, .barbaz.net" + os.Setenv("NO_PROXY", no_proxy) + + for _, test := range UseProxyTests { + if useProxy(test.host+":80") != test.match { + t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) + } + } +} diff --git a/src/pkg/net/http/range_test.go b/src/pkg/net/http/range_test.go new file mode 100644 index 000000000..5274a81fa --- /dev/null +++ b/src/pkg/net/http/range_test.go @@ -0,0 +1,57 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "testing" +) + +var ParseRangeTests = []struct { + s string + length int64 + r []httpRange +}{ + {"", 0, nil}, + {"foo", 0, nil}, + {"bytes=", 0, nil}, + {"bytes=5-4", 10, nil}, + {"bytes=0-2,5-4", 10, nil}, + {"bytes=0-9", 10, []httpRange{{0, 10}}}, + {"bytes=0-", 10, []httpRange{{0, 10}}}, + {"bytes=5-", 10, []httpRange{{5, 5}}}, + {"bytes=0-20", 10, []httpRange{{0, 10}}}, + {"bytes=15-,0-5", 10, nil}, + {"bytes=-5", 10, []httpRange{{5, 5}}}, + {"bytes=-15", 10, []httpRange{{0, 10}}}, + {"bytes=0-499", 10000, []httpRange{{0, 500}}}, + {"bytes=500-999", 10000, []httpRange{{500, 500}}}, + {"bytes=-500", 10000, []httpRange{{9500, 500}}}, + {"bytes=9500-", 10000, []httpRange{{9500, 500}}}, + {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}}, + {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}}, + {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}}, +} + +func TestParseRange(t *testing.T) { + for _, test := range ParseRangeTests { + r := test.r + ranges, err := parseRange(test.s, test.length) + if err != nil && r != nil { + t.Errorf("parseRange(%q) returned error %q", test.s, err) + } + if len(ranges) != len(r) { + t.Errorf("len(parseRange(%q)) = %d, want %d", test.s, len(ranges), len(r)) + continue + } + for i := range r { + if ranges[i].start != r[i].start { + t.Errorf("parseRange(%q)[%d].start = %d, want %d", test.s, i, ranges[i].start, r[i].start) + } + if ranges[i].length != r[i].length { + t.Errorf("parseRange(%q)[%d].length = %d, want %d", test.s, i, ranges[i].length, r[i].length) + } + } + } +} diff --git a/src/pkg/net/http/readrequest_test.go b/src/pkg/net/http/readrequest_test.go new file mode 100644 index 000000000..2e03c658a --- /dev/null +++ b/src/pkg/net/http/readrequest_test.go @@ -0,0 +1,283 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/url" + "reflect" + "testing" +) + +type reqTest struct { + Raw string + Req *Request + Body string + Trailer Header + Error string +} + +var noError = "" +var noBody = "" +var noTrailer Header = nil + +var reqTests = []reqTest{ + // Baseline test; All Request fields included for template use + { + "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + + "Host: www.techcrunch.com\r\n" + + "User-Agent: Fake\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + + "Accept-Language: en-us,en;q=0.5\r\n" + + "Accept-Encoding: gzip,deflate\r\n" + + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + + "Keep-Alive: 300\r\n" + + "Content-Length: 7\r\n" + + "Proxy-Connection: keep-alive\r\n\r\n" + + "abcdef\n???", + + &Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.techcrunch.com", + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, + "Accept-Language": {"en-us,en;q=0.5"}, + "Accept-Encoding": {"gzip,deflate"}, + "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"}, + "Keep-Alive": {"300"}, + "Proxy-Connection": {"keep-alive"}, + "Content-Length": {"7"}, + "User-Agent": {"Fake"}, + }, + Close: false, + ContentLength: 7, + Host: "www.techcrunch.com", + RequestURI: "http://www.techcrunch.com/", + }, + + "abcdef\n", + + noTrailer, + noError, + }, + + // GET request with no body (the normal case) + { + "GET / HTTP/1.1\r\n" + + "Host: foo.com\r\n\r\n", + + &Request{ + Method: "GET", + URL: &url.URL{ + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: false, + ContentLength: 0, + Host: "foo.com", + RequestURI: "/", + }, + + noBody, + noTrailer, + noError, + }, + + // Tests that we don't parse a path that looks like a + // scheme-relative URI as a scheme-relative URI. + { + "GET //user@host/is/actually/a/path/ HTTP/1.1\r\n" + + "Host: test\r\n\r\n", + + &Request{ + Method: "GET", + URL: &url.URL{ + Path: "//user@host/is/actually/a/path/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: false, + ContentLength: 0, + Host: "test", + RequestURI: "//user@host/is/actually/a/path/", + }, + + noBody, + noTrailer, + noError, + }, + + // Tests a bogus abs_path on the Request-Line (RFC 2616 section 5.1.2) + { + "GET ../../../../etc/passwd HTTP/1.1\r\n" + + "Host: test\r\n\r\n", + nil, + noBody, + noTrailer, + "parse ../../../../etc/passwd: invalid URI for request", + }, + + // Tests missing URL: + { + "GET HTTP/1.1\r\n" + + "Host: test\r\n\r\n", + nil, + noBody, + noTrailer, + "parse : empty url", + }, + + // Tests chunked body with trailer: + { + "POST / HTTP/1.1\r\n" + + "Host: foo.com\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "3\r\nfoo\r\n" + + "3\r\nbar\r\n" + + "0\r\n" + + "Trailer-Key: Trailer-Value\r\n" + + "\r\n", + &Request{ + Method: "POST", + URL: &url.URL{ + Path: "/", + }, + TransferEncoding: []string{"chunked"}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + ContentLength: -1, + Host: "foo.com", + RequestURI: "/", + }, + + "foobar", + Header{ + "Trailer-Key": {"Trailer-Value"}, + }, + noError, + }, + + // CONNECT request with domain name: + { + "CONNECT www.google.com:443 HTTP/1.1\r\n\r\n", + + &Request{ + Method: "CONNECT", + URL: &url.URL{ + Host: "www.google.com:443", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: false, + ContentLength: 0, + Host: "www.google.com:443", + RequestURI: "www.google.com:443", + }, + + noBody, + noTrailer, + noError, + }, + + // CONNECT request with IP address: + { + "CONNECT 127.0.0.1:6060 HTTP/1.1\r\n\r\n", + + &Request{ + Method: "CONNECT", + URL: &url.URL{ + Host: "127.0.0.1:6060", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: false, + ContentLength: 0, + Host: "127.0.0.1:6060", + RequestURI: "127.0.0.1:6060", + }, + + noBody, + noTrailer, + noError, + }, + + // CONNECT request for RPC: + { + "CONNECT /_goRPC_ HTTP/1.1\r\n\r\n", + + &Request{ + Method: "CONNECT", + URL: &url.URL{ + Path: "/_goRPC_", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: false, + ContentLength: 0, + Host: "", + RequestURI: "/_goRPC_", + }, + + noBody, + noTrailer, + noError, + }, +} + +func TestReadRequest(t *testing.T) { + for i := range reqTests { + tt := &reqTests[i] + var braw bytes.Buffer + braw.WriteString(tt.Raw) + req, err := ReadRequest(bufio.NewReader(&braw)) + if err != nil { + if err.Error() != tt.Error { + t.Errorf("#%d: error %q, want error %q", i, err.Error(), tt.Error) + } + continue + } + rbody := req.Body + req.Body = nil + diff(t, fmt.Sprintf("#%d Request", i), req, tt.Req) + var bout bytes.Buffer + if rbody != nil { + _, err := io.Copy(&bout, rbody) + if err != nil { + t.Fatalf("#%d. copying body: %v", i, err) + } + rbody.Close() + } + body := bout.String() + if body != tt.Body { + t.Errorf("#%d: Body = %q want %q", i, body, tt.Body) + } + if !reflect.DeepEqual(tt.Trailer, req.Trailer) { + t.Errorf("#%d. Trailers differ.\n got: %v\nwant: %v", i, req.Trailer, tt.Trailer) + } + } +} diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go new file mode 100644 index 000000000..5f8c00086 --- /dev/null +++ b/src/pkg/net/http/request.go @@ -0,0 +1,741 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP Request reading and parsing. + +package http + +import ( + "bufio" + "bytes" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "mime" + "mime/multipart" + "net/textproto" + "net/url" + "strings" +) + +const ( + maxValueLength = 4096 + maxHeaderLines = 1024 + chunkSize = 4 << 10 // 4 KB chunks + defaultMaxMemory = 32 << 20 // 32 MB +) + +// ErrMissingFile is returned by FormFile when the provided file field name +// is either not present in the request or not a file field. +var ErrMissingFile = errors.New("http: no such file") + +// HTTP request parsing errors. +type ProtocolError struct { + ErrorString string +} + +func (err *ProtocolError) Error() string { return err.ErrorString } + +var ( + ErrHeaderTooLong = &ProtocolError{"header too long"} + ErrShortBody = &ProtocolError{"entity body too short"} + ErrNotSupported = &ProtocolError{"feature not supported"} + ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"} + ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"} + ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"} + ErrMissingBoundary = &ProtocolError{"no multipart boundary param Content-Type"} +) + +type badStringError struct { + what string + str string +} + +func (e *badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) } + +// Headers that Request.Write handles itself and should be skipped. +var reqWriteExcludeHeader = map[string]bool{ + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// A Request represents an HTTP request received by a server +// or to be sent by a client. +type Request struct { + Method string // GET, POST, PUT, etc. + URL *url.URL + + // The protocol version for incoming requests. + // Outgoing requests always use HTTP/1.1. + Proto string // "HTTP/1.0" + ProtoMajor int // 1 + ProtoMinor int // 0 + + // A header maps request lines to their values. + // If the header says + // + // accept-encoding: gzip, deflate + // Accept-Language: en-us + // Connection: keep-alive + // + // then + // + // Header = map[string][]string{ + // "Accept-Encoding": {"gzip, deflate"}, + // "Accept-Language": {"en-us"}, + // "Connection": {"keep-alive"}, + // } + // + // HTTP defines that header names are case-insensitive. + // The request parser implements this by canonicalizing the + // name, making the first character and any characters + // following a hyphen uppercase and the rest lowercase. + Header Header + + // The message body. + Body io.ReadCloser + + // ContentLength records the length of the associated content. + // The value -1 indicates that the length is unknown. + // Values >= 0 indicate that the given number of bytes may + // be read from Body. + // For outgoing requests, a value of 0 means unknown if Body is not nil. + ContentLength int64 + + // TransferEncoding lists the transfer encodings from outermost to + // innermost. An empty list denotes the "identity" encoding. + // TransferEncoding can usually be ignored; chunked encoding is + // automatically added and removed as necessary when sending and + // receiving requests. + TransferEncoding []string + + // Close indicates whether to close the connection after + // replying to this request. + Close bool + + // The host on which the URL is sought. + // Per RFC 2616, this is either the value of the Host: header + // or the host name given in the URL itself. + Host string + + // Form contains the parsed form data, including both the URL + // field's query parameters and the POST or PUT form data. + // This field is only available after ParseForm is called. + // The HTTP client ignores Form and uses Body instead. + Form url.Values + + // MultipartForm is the parsed multipart form, including file uploads. + // This field is only available after ParseMultipartForm is called. + // The HTTP client ignores MultipartForm and uses Body instead. + MultipartForm *multipart.Form + + // Trailer maps trailer keys to values. Like for Header, if the + // response has multiple trailer lines with the same key, they will be + // concatenated, delimited by commas. + // For server requests, Trailer is only populated after Body has been + // closed or fully consumed. + // Trailer support is only partially complete. + Trailer Header + + // RemoteAddr allows HTTP servers and other software to record + // the network address that sent the request, usually for + // logging. This field is not filled in by ReadRequest and + // has no defined format. The HTTP server in this package + // sets RemoteAddr to an "IP:port" address before invoking a + // handler. + // This field is ignored by the HTTP client. + RemoteAddr string + + // RequestURI is the unmodified Request-URI of the + // Request-Line (RFC 2616, Section 5.1) as sent by the client + // to a server. Usually the URL field should be used instead. + // It is an error to set this field in an HTTP client request. + RequestURI string + + // TLS allows HTTP servers and other software to record + // information about the TLS connection on which the request + // was received. This field is not filled in by ReadRequest. + // The HTTP server in this package sets the field for + // TLS-enabled connections before invoking a handler; + // otherwise it leaves the field nil. + // This field is ignored by the HTTP client. + TLS *tls.ConnectionState +} + +// ProtoAtLeast returns whether the HTTP protocol used +// in the request is at least major.minor. +func (r *Request) ProtoAtLeast(major, minor int) bool { + return r.ProtoMajor > major || + r.ProtoMajor == major && r.ProtoMinor >= minor +} + +// UserAgent returns the client's User-Agent, if sent in the request. +func (r *Request) UserAgent() string { + return r.Header.Get("User-Agent") +} + +// Cookies parses and returns the HTTP cookies sent with the request. +func (r *Request) Cookies() []*Cookie { + return readCookies(r.Header, "") +} + +var ErrNoCookie = errors.New("http: named cookied not present") + +// Cookie returns the named cookie provided in the request or +// ErrNoCookie if not found. +func (r *Request) Cookie(name string) (*Cookie, error) { + for _, c := range readCookies(r.Header, name) { + return c, nil + } + return nil, ErrNoCookie +} + +// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, +// AddCookie does not attach more than one Cookie header field. That +// means all cookies, if any, are written into the same line, +// separated by semicolon. +func (r *Request) AddCookie(c *Cookie) { + s := fmt.Sprintf("%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) + if c := r.Header.Get("Cookie"); c != "" { + r.Header.Set("Cookie", c+"; "+s) + } else { + r.Header.Set("Cookie", s) + } +} + +// Referer returns the referring URL, if sent in the request. +// +// Referer is misspelled as in the request itself, a mistake from the +// earliest days of HTTP. This value can also be fetched from the +// Header map as Header["Referer"]; the benefit of making it available +// as a method is that the compiler can diagnose programs that use the +// alternate (correct English) spelling req.Referrer() but cannot +// diagnose programs that use Header["Referrer"]. +func (r *Request) Referer() string { + return r.Header.Get("Referer") +} + +// multipartByReader is a sentinel value. +// Its presence in Request.MultipartForm indicates that parsing of the request +// body has been handed off to a MultipartReader instead of ParseMultipartFrom. +var multipartByReader = &multipart.Form{ + Value: make(map[string][]string), + File: make(map[string][]*multipart.FileHeader), +} + +// MultipartReader returns a MIME multipart reader if this is a +// multipart/form-data POST request, else returns nil and an error. +// Use this function instead of ParseMultipartForm to +// process the request body as a stream. +func (r *Request) MultipartReader() (*multipart.Reader, error) { + if r.MultipartForm == multipartByReader { + return nil, errors.New("http: MultipartReader called twice") + } + if r.MultipartForm != nil { + return nil, errors.New("http: multipart handled by ParseMultipartForm") + } + r.MultipartForm = multipartByReader + return r.multipartReader() +} + +func (r *Request) multipartReader() (*multipart.Reader, error) { + v := r.Header.Get("Content-Type") + if v == "" { + return nil, ErrNotMultipart + } + d, params, err := mime.ParseMediaType(v) + if err != nil || d != "multipart/form-data" { + return nil, ErrNotMultipart + } + boundary, ok := params["boundary"] + if !ok { + return nil, ErrMissingBoundary + } + return multipart.NewReader(r.Body, boundary), nil +} + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} + +const defaultUserAgent = "Go http package" + +// Write writes an HTTP/1.1 request -- header and body -- in wire format. +// This method consults the following fields of req: +// Host +// URL +// Method (defaults to "GET") +// Header +// ContentLength +// TransferEncoding +// Body +// +// If Body is present, Content-Length is <= 0 and TransferEncoding +// hasn't been set to "identity", Write adds "Transfer-Encoding: +// chunked" to the header. Body is closed after it is sent. +func (req *Request) Write(w io.Writer) error { + return req.write(w, false, nil) +} + +// WriteProxy is like Write but writes the request in the form +// expected by an HTTP proxy. In particular, WriteProxy writes the +// initial Request-URI line of the request with an absolute URI, per +// section 5.1.2 of RFC 2616, including the scheme and host. In +// either case, WriteProxy also writes a Host header, using either +// req.Host or req.URL.Host. +func (req *Request) WriteProxy(w io.Writer) error { + return req.write(w, true, nil) +} + +// extraHeaders may be nil +func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) error { + host := req.Host + if host == "" { + if req.URL == nil { + return errors.New("http: Request.Write on Request with no Host or URL set") + } + host = req.URL.Host + } + + ruri := req.URL.RequestURI() + if usingProxy && req.URL.Scheme != "" && req.URL.Opaque == "" { + ruri = req.URL.Scheme + "://" + host + ruri + } else if req.Method == "CONNECT" && req.URL.Path == "" { + // CONNECT requests normally give just the host and port, not a full URL. + ruri = host + } + // TODO(bradfitz): escape at least newlines in ruri? + + bw := bufio.NewWriter(w) + fmt.Fprintf(bw, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) + + // Header lines + fmt.Fprintf(bw, "Host: %s\r\n", host) + + // Use the defaultUserAgent unless the Header contains one, which + // may be blank to not send the header. + userAgent := defaultUserAgent + if req.Header != nil { + if ua := req.Header["User-Agent"]; len(ua) > 0 { + userAgent = ua[0] + } + } + if userAgent != "" { + fmt.Fprintf(bw, "User-Agent: %s\r\n", userAgent) + } + + // Process Body,ContentLength,Close,Trailer + tw, err := newTransferWriter(req) + if err != nil { + return err + } + err = tw.WriteHeader(bw) + if err != nil { + return err + } + + // TODO: split long values? (If so, should share code with Conn.Write) + err = req.Header.WriteSubset(bw, reqWriteExcludeHeader) + if err != nil { + return err + } + + if extraHeaders != nil { + err = extraHeaders.Write(bw) + if err != nil { + return err + } + } + + io.WriteString(bw, "\r\n") + + // Write body and trailer + err = tw.WriteBody(bw) + if err != nil { + return err + } + + return bw.Flush() +} + +// Convert decimal at s[i:len(s)] to integer, +// returning value, string position where the digits stopped, +// and whether there was a valid number (digits, not too big). +func atoi(s string, i int) (n, i1 int, ok bool) { + const Big = 1000000 + if i >= len(s) || s[i] < '0' || s[i] > '9' { + return 0, 0, false + } + n = 0 + for ; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { + n = n*10 + int(s[i]-'0') + if n > Big { + return 0, 0, false + } + } + return n, i, true +} + +// ParseHTTPVersion parses a HTTP version string. +// "HTTP/1.0" returns (1, 0, true). +func ParseHTTPVersion(vers string) (major, minor int, ok bool) { + if len(vers) < 5 || vers[0:5] != "HTTP/" { + return 0, 0, false + } + major, i, ok := atoi(vers, 5) + if !ok || i >= len(vers) || vers[i] != '.' { + return 0, 0, false + } + minor, i, ok = atoi(vers, i+1) + if !ok || i != len(vers) { + return 0, 0, false + } + return major, minor, true +} + +// NewRequest returns a new Request given a method, URL, and optional body. +func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { + u, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + rc, ok := body.(io.ReadCloser) + if !ok && body != nil { + rc = ioutil.NopCloser(body) + } + req := &Request{ + Method: method, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + Body: rc, + Host: u.Host, + } + if body != nil { + switch v := body.(type) { + case *strings.Reader: + req.ContentLength = int64(v.Len()) + case *bytes.Buffer: + req.ContentLength = int64(v.Len()) + } + } + + return req, nil +} + +// SetBasicAuth sets the request's Authorization header to use HTTP +// Basic Authentication with the provided username and password. +// +// With HTTP Basic Authentication the provided username and password +// are not encrypted. +func (r *Request) SetBasicAuth(username, password string) { + s := username + ":" + password + r.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) +} + +// ReadRequest reads and parses a request from b. +func ReadRequest(b *bufio.Reader) (req *Request, err error) { + + tp := textproto.NewReader(b) + req = new(Request) + + // First line: GET /index.html HTTP/1.0 + var s string + if s, err = tp.ReadLine(); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + + var f []string + if f = strings.SplitN(s, " ", 3); len(f) < 3 { + return nil, &badStringError{"malformed HTTP request", s} + } + req.Method, req.RequestURI, req.Proto = f[0], f[1], f[2] + rawurl := req.RequestURI + var ok bool + if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok { + return nil, &badStringError{"malformed HTTP version", req.Proto} + } + + // CONNECT requests are used two different ways, and neither uses a full URL: + // The standard use is to tunnel HTTPS through an HTTP proxy. + // It looks like "CONNECT www.google.com:443 HTTP/1.1", and the parameter is + // just the authority section of a URL. This information should go in req.URL.Host. + // + // The net/rpc package also uses CONNECT, but there the parameter is a path + // that starts with a slash. It can be parsed with the regular URL parser, + // and the path will end up in req.URL.Path, where it needs to be in order for + // RPC to work. + justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/") + if justAuthority { + rawurl = "http://" + rawurl + } + + if req.URL, err = url.ParseRequest(rawurl); err != nil { + return nil, err + } + + if justAuthority { + // Strip the bogus "http://" back off. + req.URL.Scheme = "" + } + + // Subsequent lines: Key: value. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + return nil, err + } + req.Header = Header(mimeHeader) + + // RFC2616: Must treat + // GET /index.html HTTP/1.1 + // Host: www.google.com + // and + // GET http://www.google.com/index.html HTTP/1.1 + // Host: doesntmatter + // the same. In the second case, any Host line is ignored. + req.Host = req.URL.Host + if req.Host == "" { + req.Host = req.Header.Get("Host") + } + req.Header.Del("Host") + + fixPragmaCacheControl(req.Header) + + // TODO: Parse specific header values: + // Accept + // Accept-Encoding + // Accept-Language + // Authorization + // Cache-Control + // Connection + // Date + // Expect + // From + // If-Match + // If-Modified-Since + // If-None-Match + // If-Range + // If-Unmodified-Since + // Max-Forwards + // Proxy-Authorization + // Referer [sic] + // TE (transfer-codings) + // Trailer + // Transfer-Encoding + // Upgrade + // User-Agent + // Via + // Warning + + err = readTransfer(req, b) + if err != nil { + return nil, err + } + + return req, nil +} + +// MaxBytesReader is similar to io.LimitReader but is intended for +// limiting the size of incoming request bodies. In contrast to +// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a +// non-EOF error for a Read beyond the limit, and Closes the +// underlying reader when its Close method is called. +// +// MaxBytesReader prevents clients from accidentally or maliciously +// sending a large request and wasting server resources. +func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { + return &maxBytesReader{w: w, r: r, n: n} +} + +type maxBytesReader struct { + w ResponseWriter + r io.ReadCloser // underlying reader + n int64 // max bytes remaining + stopped bool +} + +func (l *maxBytesReader) Read(p []byte) (n int, err error) { + if l.n <= 0 { + if !l.stopped { + l.stopped = true + if res, ok := l.w.(*response); ok { + res.requestTooLarge() + } + } + return 0, errors.New("http: request body too large") + } + if int64(len(p)) > l.n { + p = p[:l.n] + } + n, err = l.r.Read(p) + l.n -= int64(n) + return +} + +func (l *maxBytesReader) Close() error { + return l.r.Close() +} + +// ParseForm parses the raw query from the URL. +// +// For POST or PUT requests, it also parses the request body as a form. +// If the request Body's size has not already been limited by MaxBytesReader, +// the size is capped at 10MB. +// +// ParseMultipartForm calls ParseForm automatically. +// It is idempotent. +func (r *Request) ParseForm() (err error) { + if r.Form != nil { + return + } + if r.URL != nil { + r.Form, err = url.ParseQuery(r.URL.RawQuery) + } + if r.Method == "POST" || r.Method == "PUT" { + if r.Body == nil { + return errors.New("missing form body") + } + ct := r.Header.Get("Content-Type") + ct, _, err = mime.ParseMediaType(ct) + switch { + case ct == "application/x-www-form-urlencoded": + var reader io.Reader = r.Body + maxFormSize := int64(1<<63 - 1) + if _, ok := r.Body.(*maxBytesReader); !ok { + maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + reader = io.LimitReader(r.Body, maxFormSize+1) + } + b, e := ioutil.ReadAll(reader) + if e != nil { + if err == nil { + err = e + } + break + } + if int64(len(b)) > maxFormSize { + return errors.New("http: POST too large") + } + var newValues url.Values + newValues, e = url.ParseQuery(string(b)) + if err == nil { + err = e + } + if r.Form == nil { + r.Form = make(url.Values) + } + // Copy values into r.Form. TODO: make this smoother. + for k, vs := range newValues { + for _, value := range vs { + r.Form.Add(k, value) + } + } + case ct == "multipart/form-data": + // handled by ParseMultipartForm (which is calling us, or should be) + // TODO(bradfitz): there are too many possible + // orders to call too many functions here. + // Clean this up and write more tests. + // request_test.go contains the start of this, + // in TestRequestMultipartCallOrder. + } + } + return err +} + +// ParseMultipartForm parses a request body as multipart/form-data. +// The whole request body is parsed and up to a total of maxMemory bytes of +// its file parts are stored in memory, with the remainder stored on +// disk in temporary files. +// ParseMultipartForm calls ParseForm if necessary. +// After one call to ParseMultipartForm, subsequent calls have no effect. +func (r *Request) ParseMultipartForm(maxMemory int64) error { + if r.MultipartForm == multipartByReader { + return errors.New("http: multipart handled by MultipartReader") + } + if r.Form == nil { + err := r.ParseForm() + if err != nil { + return err + } + } + if r.MultipartForm != nil { + return nil + } + + mr, err := r.multipartReader() + if err == ErrNotMultipart { + return nil + } else if err != nil { + return err + } + + f, err := mr.ReadForm(maxMemory) + if err != nil { + return err + } + for k, v := range f.Value { + r.Form[k] = append(r.Form[k], v...) + } + r.MultipartForm = f + + return nil +} + +// FormValue returns the first value for the named component of the query. +// FormValue calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) FormValue(key string) string { + if r.Form == nil { + r.ParseMultipartForm(defaultMaxMemory) + } + if vs := r.Form[key]; len(vs) > 0 { + return vs[0] + } + return "" +} + +// FormFile returns the first file for the provided form key. +// FormFile calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) { + if r.MultipartForm == multipartByReader { + return nil, nil, errors.New("http: multipart handled by MultipartReader") + } + if r.MultipartForm == nil { + err := r.ParseMultipartForm(defaultMaxMemory) + if err != nil { + return nil, nil, err + } + } + if r.MultipartForm != nil && r.MultipartForm.File != nil { + if fhs := r.MultipartForm.File[key]; len(fhs) > 0 { + f, err := fhs[0].Open() + return f, fhs[0], err + } + } + return nil, nil, ErrMissingFile +} + +func (r *Request) expectsContinue() bool { + return strings.ToLower(r.Header.Get("Expect")) == "100-continue" +} + +func (r *Request) wantsHttp10KeepAlive() bool { + if r.ProtoMajor != 1 || r.ProtoMinor != 0 { + return false + } + return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "keep-alive") +} diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go new file mode 100644 index 000000000..7a3556d03 --- /dev/null +++ b/src/pkg/net/http/request_test.go @@ -0,0 +1,283 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "mime/multipart" + . "net/http" + "net/http/httptest" + "net/url" + "os" + "reflect" + "regexp" + "strings" + "testing" +) + +func TestQuery(t *testing.T) { + req := &Request{Method: "GET"} + req.URL, _ = url.Parse("http://www.google.com/search?q=foo&q=bar") + if q := req.FormValue("q"); q != "foo" { + t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) + } +} + +func TestPostQuery(t *testing.T) { + req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x", + strings.NewReader("z=post&both=y")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + if q := req.FormValue("q"); q != "foo" { + t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) + } + if z := req.FormValue("z"); z != "post" { + t.Errorf(`req.FormValue("z") = %q, want "post"`, z) + } + if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"x", "y"}) { + t.Errorf(`req.FormValue("both") = %q, want ["x", "y"]`, both) + } +} + +type stringMap map[string][]string +type parseContentTypeTest struct { + shouldError bool + contentType stringMap +} + +var parseContentTypeTests = []parseContentTypeTest{ + {false, stringMap{"Content-Type": {"text/plain"}}}, + // Non-existent keys are not placed. The value nil is illegal. + {true, stringMap{}}, + {true, stringMap{"Content-Type": {"text/plain; boundary="}}}, + {false, stringMap{"Content-Type": {"application/unknown"}}}, +} + +func TestParseFormUnknownContentType(t *testing.T) { + for i, test := range parseContentTypeTests { + req := &Request{ + Method: "POST", + Header: Header(test.contentType), + Body: ioutil.NopCloser(bytes.NewBufferString("body")), + } + err := req.ParseForm() + switch { + case err == nil && test.shouldError: + t.Errorf("test %d should have returned error", i) + case err != nil && !test.shouldError: + t.Errorf("test %d should not have returned error, got %v", i, err) + } + } +} + +func TestMultipartReader(t *testing.T) { + req := &Request{ + Method: "POST", + Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, + Body: ioutil.NopCloser(new(bytes.Buffer)), + } + multipart, err := req.MultipartReader() + if multipart == nil { + t.Errorf("expected multipart; error: %v", err) + } + + req.Header = Header{"Content-Type": {"text/plain"}} + multipart, err = req.MultipartReader() + if multipart != nil { + t.Errorf("unexpected multipart for text/plain") + } +} + +func TestRedirect(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + switch r.URL.Path { + case "/": + w.Header().Set("Location", "/foo/") + w.WriteHeader(StatusSeeOther) + case "/foo/": + fmt.Fprintf(w, "foo") + default: + w.WriteHeader(StatusBadRequest) + } + })) + defer ts.Close() + + var end = regexp.MustCompile("/foo/$") + r, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + r.Body.Close() + url := r.Request.URL.String() + if r.StatusCode != 200 || !end.MatchString(url) { + t.Fatalf("Get got status %d at %q, want 200 matching /foo/$", r.StatusCode, url) + } +} + +func TestSetBasicAuth(t *testing.T) { + r, _ := NewRequest("GET", "http://example.com/", nil) + r.SetBasicAuth("Aladdin", "open sesame") + if g, e := r.Header.Get("Authorization"), "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="; g != e { + t.Errorf("got header %q, want %q", g, e) + } +} + +func TestMultipartRequest(t *testing.T) { + // Test that we can read the values and files of a + // multipart request with FormValue and FormFile, + // and that ParseMultipartForm can be called multiple times. + req := newTestMultipartRequest(t) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm first call:", err) + } + defer req.MultipartForm.RemoveAll() + validateTestMultipartContents(t, req, false) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm second call:", err) + } + validateTestMultipartContents(t, req, false) +} + +func TestMultipartRequestAuto(t *testing.T) { + // Test that FormValue and FormFile automatically invoke + // ParseMultipartForm and return the right values. + req := newTestMultipartRequest(t) + defer func() { + if req.MultipartForm != nil { + req.MultipartForm.RemoveAll() + } + }() + validateTestMultipartContents(t, req, true) +} + +func TestEmptyMultipartRequest(t *testing.T) { + // Test that FormValue and FormFile automatically invoke + // ParseMultipartForm and return the right values. + req, err := NewRequest("GET", "/", nil) + if err != nil { + t.Errorf("NewRequest err = %q", err) + } + testMissingFile(t, req) +} + +func TestRequestMultipartCallOrder(t *testing.T) { + req := newTestMultipartRequest(t) + _, err := req.MultipartReader() + if err != nil { + t.Fatalf("MultipartReader: %v", err) + } + err = req.ParseMultipartForm(1024) + if err == nil { + t.Errorf("expected an error from ParseMultipartForm after call to MultipartReader") + } +} + +func testMissingFile(t *testing.T, req *Request) { + f, fh, err := req.FormFile("missing") + if f != nil { + t.Errorf("FormFile file = %q, want nil", f) + } + if fh != nil { + t.Errorf("FormFile file header = %q, want nil", fh) + } + if err != ErrMissingFile { + t.Errorf("FormFile err = %q, want ErrMissingFile", err) + } +} + +func newTestMultipartRequest(t *testing.T) *Request { + b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1)) + req, err := NewRequest("POST", "/", b) + if err != nil { + t.Fatal("NewRequest:", err) + } + ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary) + req.Header.Set("Content-type", ctype) + return req +} + +func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) { + if g, e := req.FormValue("texta"), textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g, e := req.FormValue("textb"), textbValue; g != e { + t.Errorf("textb value = %q, want %q", g, e) + } + if g := req.FormValue("missing"); g != "" { + t.Errorf("missing value = %q, want empty string", g) + } + + assertMem := func(n string, fd multipart.File) { + if _, ok := fd.(*os.File); ok { + t.Error(n, " is *os.File, should not be") + } + } + fda := testMultipartFile(t, req, "filea", "filea.txt", fileaContents) + defer fda.Close() + assertMem("filea", fda) + fdb := testMultipartFile(t, req, "fileb", "fileb.txt", filebContents) + defer fdb.Close() + if allMem { + assertMem("fileb", fdb) + } else { + if _, ok := fdb.(*os.File); !ok { + t.Errorf("fileb has unexpected underlying type %T", fdb) + } + } + + testMissingFile(t, req) +} + +func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File { + f, fh, err := req.FormFile(key) + if err != nil { + t.Fatalf("FormFile(%q): %q", key, err) + } + if fh.Filename != expectFilename { + t.Errorf("filename = %q, want %q", fh.Filename, expectFilename) + } + var b bytes.Buffer + _, err = io.Copy(&b, f) + if err != nil { + t.Fatal("copying contents:", err) + } + if g := b.String(); g != expectContent { + t.Errorf("contents = %q, want %q", g, expectContent) + } + return f +} + +const ( + fileaContents = "This is a test file." + filebContents = "Another test file." + textaValue = "foo" + textbValue = "bar" + boundary = `MyBoundary` +) + +const message = ` +--MyBoundary +Content-Disposition: form-data; name="filea"; filename="filea.txt" +Content-Type: text/plain + +` + fileaContents + ` +--MyBoundary +Content-Disposition: form-data; name="fileb"; filename="fileb.txt" +Content-Type: text/plain + +` + filebContents + ` +--MyBoundary +Content-Disposition: form-data; name="texta" + +` + textaValue + ` +--MyBoundary +Content-Disposition: form-data; name="textb" + +` + textbValue + ` +--MyBoundary-- +` diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go new file mode 100644 index 000000000..fc3186f0c --- /dev/null +++ b/src/pkg/net/http/requestwrite_test.go @@ -0,0 +1,438 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net/url" + "strings" + "testing" +) + +type reqWriteTest struct { + Req Request + Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + + // Any of these three may be empty to skip that test. + WantWrite string // Request.Write + WantProxy string // Request.WriteProxy + + WantError error // wanted error from Request.Write +} + +var reqWriteTests = []reqWriteTest{ + // HTTP/1.1 => chunked coding; no body; no trailer + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.techcrunch.com", + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, + "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"}, + "Accept-Encoding": {"gzip,deflate"}, + "Accept-Language": {"en-us,en;q=0.5"}, + "Keep-Alive": {"300"}, + "Proxy-Connection": {"keep-alive"}, + "User-Agent": {"Fake"}, + }, + Body: nil, + Close: false, + Host: "www.techcrunch.com", + Form: map[string][]string{}, + }, + + WantWrite: "GET / HTTP/1.1\r\n" + + "Host: www.techcrunch.com\r\n" + + "User-Agent: Fake\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + + "Accept-Encoding: gzip,deflate\r\n" + + "Accept-Language: en-us,en;q=0.5\r\n" + + "Keep-Alive: 300\r\n" + + "Proxy-Connection: keep-alive\r\n\r\n", + + WantProxy: "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + + "Host: www.techcrunch.com\r\n" + + "User-Agent: Fake\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + + "Accept-Encoding: gzip,deflate\r\n" + + "Accept-Language: en-us,en;q=0.5\r\n" + + "Keep-Alive: 300\r\n" + + "Proxy-Connection: keep-alive\r\n\r\n", + }, + // HTTP/1.1 => chunked coding; body; empty trailer + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + TransferEncoding: []string{"chunked"}, + }, + + Body: []byte("abcdef"), + + WantWrite: "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + + WantProxy: "GET http://www.google.com/search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + }, + // HTTP/1.1 POST => chunked coding; body; empty trailer + { + Req: Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: true, + TransferEncoding: []string{"chunked"}, + }, + + Body: []byte("abcdef"), + + WantWrite: "POST /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + + WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + }, + + // HTTP/1.1 POST with Content-Length, no chunking + { + Req: Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: true, + ContentLength: 6, + }, + + Body: []byte("abcdef"), + + WantWrite: "POST /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + + WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + }, + + // HTTP/1.1 POST with Content-Length in headers + { + Req: Request{ + Method: "POST", + URL: mustParseURL("http://example.com/"), + Host: "example.com", + Header: Header{ + "Content-Length": []string{"10"}, // ignored + }, + ContentLength: 6, + }, + + Body: []byte("abcdef"), + + WantWrite: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + + WantProxy: "POST http://example.com/ HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + }, + + // default to HTTP/1.1 + { + Req: Request{ + Method: "GET", + URL: mustParseURL("/search"), + Host: "www.google.com", + }, + + WantWrite: "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "\r\n", + }, + + // Request with a 0 ContentLength and a 0 byte body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, + + // RFC 2616 Section 14.13 says Content-Length should be specified + // unless body is prohibited by the request method. + // Also, nginx expects it for POST and PUT. + WantWrite: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + + WantProxy: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + }, + + // Request with a 0 ContentLength and a 1 byte body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) }, + + WantWrite: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("x") + chunk(""), + + WantProxy: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go http package\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("x") + chunk(""), + }, + + // Request with a ContentLength of 10 but a 5 byte body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 10, // but we're going to send only 5 bytes + }, + Body: []byte("12345"), + WantError: errors.New("http: Request.ContentLength=10 with Body length 5"), + }, + + // Request with a ContentLength of 4 but an 8 byte body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 4, // but we're going to try to send 8 bytes + }, + Body: []byte("12345678"), + WantError: errors.New("http: Request.ContentLength=4 with Body length 8"), + }, + + // Request with a 5 ContentLength and nil body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 5, // but we'll omit the body + }, + WantError: errors.New("http: Request.ContentLength=5 with nil Body"), + }, + + // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, + // and doesn't add a User-Agent. + { + Req: Request{ + Method: "GET", + URL: mustParseURL("/foo"), + ProtoMajor: 1, + ProtoMinor: 0, + Header: Header{ + "X-Foo": []string{"X-Bar"}, + }, + }, + + WantWrite: "GET /foo HTTP/1.1\r\n" + + "Host: \r\n" + + "User-Agent: Go http package\r\n" + + "X-Foo: X-Bar\r\n\r\n", + }, +} + +func TestRequestWrite(t *testing.T) { + for i := range reqWriteTests { + tt := &reqWriteTests[i] + + setBody := func() { + if tt.Body == nil { + return + } + switch b := tt.Body.(type) { + case []byte: + tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b)) + case func() io.ReadCloser: + tt.Req.Body = b() + } + } + setBody() + if tt.Req.Header == nil { + tt.Req.Header = make(Header) + } + + var braw bytes.Buffer + err := tt.Req.Write(&braw) + if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.WantError); g != e { + t.Errorf("writing #%d, err = %q, want %q", i, g, e) + continue + } + if err != nil { + continue + } + + if tt.WantWrite != "" { + sraw := braw.String() + if sraw != tt.WantWrite { + t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantWrite, sraw) + continue + } + } + + if tt.WantProxy != "" { + setBody() + var praw bytes.Buffer + err = tt.Req.WriteProxy(&praw) + if err != nil { + t.Errorf("WriteProxy #%d: %s", i, err) + continue + } + sraw := praw.String() + if sraw != tt.WantProxy { + t.Errorf("Test Proxy %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantProxy, sraw) + continue + } + } + } +} + +type closeChecker struct { + io.Reader + closed bool +} + +func (rc *closeChecker) Close() error { + rc.closed = true + return nil +} + +// TestRequestWriteClosesBody tests that Request.Write does close its request.Body. +// It also indirectly tests NewRequest and that it doesn't wrap an existing Closer +// inside a NopCloser, and that it serializes it correctly. +func TestRequestWriteClosesBody(t *testing.T) { + rc := &closeChecker{Reader: strings.NewReader("my body")} + req, _ := NewRequest("POST", "http://foo.com/", rc) + if req.ContentLength != 0 { + t.Errorf("got req.ContentLength %d, want 0", req.ContentLength) + } + buf := new(bytes.Buffer) + req.Write(buf) + if !rc.closed { + t.Error("body not closed after write") + } + expected := "POST / HTTP/1.1\r\n" + + "Host: foo.com\r\n" + + "User-Agent: Go http package\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + // TODO: currently we don't buffer before chunking, so we get a + // single "m" chunk before the other chunks, as this was the 1-byte + // read from our MultiReader where we stiched the Body back together + // after sniffing whether the Body was 0 bytes or not. + chunk("m") + + chunk("y body") + + chunk("") + if buf.String() != expected { + t.Errorf("write:\n got: %s\nwant: %s", buf.String(), expected) + } +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(fmt.Sprintf("Error parsing URL %q: %v", s, err)) + } + return u +} diff --git a/src/pkg/net/http/response.go b/src/pkg/net/http/response.go new file mode 100644 index 000000000..ae314b5ac --- /dev/null +++ b/src/pkg/net/http/response.go @@ -0,0 +1,236 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP Response reading and parsing. + +package http + +import ( + "bufio" + "errors" + "io" + "net/textproto" + "net/url" + "strconv" + "strings" +) + +var respExcludeHeader = map[string]bool{ + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// Response represents the response from an HTTP request. +// +type Response struct { + Status string // e.g. "200 OK" + StatusCode int // e.g. 200 + Proto string // e.g. "HTTP/1.0" + ProtoMajor int // e.g. 1 + ProtoMinor int // e.g. 0 + + // Header maps header keys to values. If the response had multiple + // headers with the same key, they will be concatenated, with comma + // delimiters. (Section 4.2 of RFC 2616 requires that multiple headers + // be semantically equivalent to a comma-delimited sequence.) Values + // duplicated by other fields in this struct (e.g., ContentLength) are + // omitted from Header. + // + // Keys in the map are canonicalized (see CanonicalHeaderKey). + Header Header + + // Body represents the response body. + // + // The http Client and Transport guarantee that Body is always + // non-nil, even on responses without a body or responses with + // a zero-lengthed body. + Body io.ReadCloser + + // ContentLength records the length of the associated content. The + // value -1 indicates that the length is unknown. Unless RequestMethod + // is "HEAD", values >= 0 indicate that the given number of bytes may + // be read from Body. + ContentLength int64 + + // Contains transfer encodings from outer-most to inner-most. Value is + // nil, means that "identity" encoding is used. + TransferEncoding []string + + // Close records whether the header directed that the connection be + // closed after reading Body. The value is advice for clients: neither + // ReadResponse nor Response.Write ever closes a connection. + Close bool + + // Trailer maps trailer keys to values, in the same + // format as the header. + Trailer Header + + // The Request that was sent to obtain this Response. + // Request's Body is nil (having already been consumed). + // This is only populated for Client requests. + Request *Request +} + +// Cookies parses and returns the cookies set in the Set-Cookie headers. +func (r *Response) Cookies() []*Cookie { + return readSetCookies(r.Header) +} + +var ErrNoLocation = errors.New("http: no Location header in response") + +// Location returns the URL of the response's "Location" header, +// if present. Relative redirects are resolved relative to +// the Response's Request. ErrNoLocation is returned if no +// Location header is present. +func (r *Response) Location() (*url.URL, error) { + lv := r.Header.Get("Location") + if lv == "" { + return nil, ErrNoLocation + } + if r.Request != nil && r.Request.URL != nil { + return r.Request.URL.Parse(lv) + } + return url.Parse(lv) +} + +// ReadResponse reads and returns an HTTP response from r. The +// req parameter specifies the Request that corresponds to +// this Response. Clients must call resp.Body.Close when finished +// reading resp.Body. After that call, clients can inspect +// resp.Trailer to find key/value pairs included in the response +// trailer. +func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err error) { + + tp := textproto.NewReader(r) + resp = new(Response) + + resp.Request = req + resp.Request.Method = strings.ToUpper(resp.Request.Method) + + // Parse the first line of the response. + line, err := tp.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + f := strings.SplitN(line, " ", 3) + if len(f) < 2 { + return nil, &badStringError{"malformed HTTP response", line} + } + reasonPhrase := "" + if len(f) > 2 { + reasonPhrase = f[2] + } + resp.Status = f[1] + " " + reasonPhrase + resp.StatusCode, err = strconv.Atoi(f[1]) + if err != nil { + return nil, &badStringError{"malformed HTTP status code", f[1]} + } + + resp.Proto = f[0] + var ok bool + if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok { + return nil, &badStringError{"malformed HTTP version", resp.Proto} + } + + // Parse the response headers. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + return nil, err + } + resp.Header = Header(mimeHeader) + + fixPragmaCacheControl(resp.Header) + + err = readTransfer(resp, r) + if err != nil { + return nil, err + } + + return resp, nil +} + +// RFC2616: Should treat +// Pragma: no-cache +// like +// Cache-Control: no-cache +func fixPragmaCacheControl(header Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} + +// ProtoAtLeast returns whether the HTTP protocol used +// in the response is at least major.minor. +func (r *Response) ProtoAtLeast(major, minor int) bool { + return r.ProtoMajor > major || + r.ProtoMajor == major && r.ProtoMinor >= minor +} + +// Writes the response (header, body and trailer) in wire format. This method +// consults the following fields of resp: +// +// StatusCode +// ProtoMajor +// ProtoMinor +// RequestMethod +// TransferEncoding +// Trailer +// Body +// ContentLength +// Header, values for non-canonical keys will have unpredictable behavior +// +func (resp *Response) Write(w io.Writer) error { + + // RequestMethod should be upper-case + if resp.Request != nil { + resp.Request.Method = strings.ToUpper(resp.Request.Method) + } + + // Status line + text := resp.Status + if text == "" { + var ok bool + text, ok = statusText[resp.StatusCode] + if !ok { + text = "status code " + strconv.Itoa(resp.StatusCode) + } + } + io.WriteString(w, "HTTP/"+strconv.Itoa(resp.ProtoMajor)+".") + io.WriteString(w, strconv.Itoa(resp.ProtoMinor)+" ") + io.WriteString(w, strconv.Itoa(resp.StatusCode)+" "+text+"\r\n") + + // Process Body,ContentLength,Close,Trailer + tw, err := newTransferWriter(resp) + if err != nil { + return err + } + err = tw.WriteHeader(w) + if err != nil { + return err + } + + // Rest of header + err = resp.Header.WriteSubset(w, respExcludeHeader) + if err != nil { + return err + } + + // End-of-header + io.WriteString(w, "\r\n") + + // Write body and trailer + err = tw.WriteBody(w) + if err != nil { + return err + } + + // Success + return nil +} diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go new file mode 100644 index 000000000..e5d01698e --- /dev/null +++ b/src/pkg/net/http/response_test.go @@ -0,0 +1,448 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "bytes" + "compress/gzip" + "crypto/rand" + "fmt" + "io" + "io/ioutil" + "net/url" + "reflect" + "testing" +) + +type respTest struct { + Raw string + Resp Response + Body string +} + +func dummyReq(method string) *Request { + return &Request{Method: method} +} + +var respTests = []respTest{ + // Unchunked response without Content-Length. + { + "HTTP/1.0 200 OK\r\n" + + "Connection: close\r\n" + + "\r\n" + + "Body here\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{ + "Connection": {"close"}, // TODO(rsc): Delete? + }, + Close: true, + ContentLength: -1, + }, + + "Body here\n", + }, + + // Unchunked HTTP/1.1 response without Content-Length or + // Connection headers. + { + "HTTP/1.1 200 OK\r\n" + + "\r\n" + + "Body here\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Request: dummyReq("GET"), + Close: true, + ContentLength: -1, + }, + + "Body here\n", + }, + + // Unchunked HTTP/1.1 204 response without Content-Length. + { + "HTTP/1.1 204 No Content\r\n" + + "\r\n" + + "Body should not be read!\n", + + Response{ + Status: "204 No Content", + StatusCode: 204, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Request: dummyReq("GET"), + Close: false, + ContentLength: 0, + }, + + "", + }, + + // Unchunked response with Content-Length. + { + "HTTP/1.0 200 OK\r\n" + + "Content-Length: 10\r\n" + + "Connection: close\r\n" + + "\r\n" + + "Body here\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{ + "Connection": {"close"}, // TODO(rsc): Delete? + "Content-Length": {"10"}, // TODO(rsc): Delete? + }, + Close: true, + ContentLength: 10, + }, + + "Body here\n", + }, + + // Chunked response without Content-Length. + { + "HTTP/1.0 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0a\r\n" + + "Body here\n\r\n" + + "09\r\n" + + "continued\r\n" + + "0\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Close: true, + ContentLength: -1, + TransferEncoding: []string{"chunked"}, + }, + + "Body here\ncontinued", + }, + + // Chunked response with Content-Length. + { + "HTTP/1.0 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "Content-Length: 10\r\n" + + "\r\n" + + "0a\r\n" + + "Body here\n" + + "0\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Close: true, + ContentLength: -1, // TODO(rsc): Fix? + TransferEncoding: []string{"chunked"}, + }, + + "Body here\n", + }, + + // Chunked response in response to a HEAD request (the "chunked" should + // be ignored, as HEAD responses never have bodies) + { + "HTTP/1.0 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("HEAD"), + Header: Header{}, + Close: true, + ContentLength: 0, + }, + + "", + }, + + // explicit Content-Length of 0. + { + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Content-Length": {"0"}, + }, + Close: false, + ContentLength: 0, + }, + + "", + }, + + // Status line without a Reason-Phrase, but trailing space. + // (permitted by RFC 2616) + { + "HTTP/1.0 303 \r\n\r\n", + Response{ + Status: "303 ", + StatusCode: 303, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Close: true, + ContentLength: -1, + }, + + "", + }, + + // Status line without a Reason-Phrase, and no trailing space. + // (not permitted by RFC 2616, but we'll accept it anyway) + { + "HTTP/1.0 303\r\n\r\n", + Response{ + Status: "303 ", + StatusCode: 303, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Close: true, + ContentLength: -1, + }, + + "", + }, +} + +func TestReadResponse(t *testing.T) { + for i := range respTests { + tt := &respTests[i] + var braw bytes.Buffer + braw.WriteString(tt.Raw) + resp, err := ReadResponse(bufio.NewReader(&braw), tt.Resp.Request) + if err != nil { + t.Errorf("#%d: %s", i, err) + continue + } + rbody := resp.Body + resp.Body = nil + diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp) + var bout bytes.Buffer + if rbody != nil { + io.Copy(&bout, rbody) + rbody.Close() + } + body := bout.String() + if body != tt.Body { + t.Errorf("#%d: Body = %q want %q", i, body, tt.Body) + } + } +} + +var readResponseCloseInMiddleTests = []struct { + chunked, compressed bool +}{ + {false, false}, + {true, false}, + {true, true}, +} + +// TestReadResponseCloseInMiddle tests that closing a body after +// reading only part of its contents advances the read to the end of +// the request, right up until the next request. +func TestReadResponseCloseInMiddle(t *testing.T) { + for _, test := range readResponseCloseInMiddleTests { + fatalf := func(format string, args ...interface{}) { + args = append([]interface{}{test.chunked, test.compressed}, args...) + t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...) + } + checkErr := func(err error, msg string) { + if err == nil { + return + } + fatalf(msg+": %v", err) + } + var buf bytes.Buffer + buf.WriteString("HTTP/1.1 200 OK\r\n") + if test.chunked { + buf.WriteString("Transfer-Encoding: chunked\r\n") + } else { + buf.WriteString("Content-Length: 1000000\r\n") + } + var wr io.Writer = &buf + if test.chunked { + wr = newChunkedWriter(wr) + } + if test.compressed { + buf.WriteString("Content-Encoding: gzip\r\n") + var err error + wr, err = gzip.NewWriter(wr) + checkErr(err, "gzip.NewWriter") + } + buf.WriteString("\r\n") + + chunk := bytes.Repeat([]byte{'x'}, 1000) + for i := 0; i < 1000; i++ { + if test.compressed { + // Otherwise this compresses too well. + _, err := io.ReadFull(rand.Reader, chunk) + checkErr(err, "rand.Reader ReadFull") + } + wr.Write(chunk) + } + if test.compressed { + err := wr.(*gzip.Compressor).Close() + checkErr(err, "compressor close") + } + if test.chunked { + buf.WriteString("0\r\n\r\n") + } + buf.WriteString("Next Request Here") + + bufr := bufio.NewReader(&buf) + resp, err := ReadResponse(bufr, dummyReq("GET")) + checkErr(err, "ReadResponse") + expectedLength := int64(-1) + if !test.chunked { + expectedLength = 1000000 + } + if resp.ContentLength != expectedLength { + fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength) + } + if resp.Body == nil { + fatalf("nil body") + } + if test.compressed { + gzReader, err := gzip.NewReader(resp.Body) + checkErr(err, "gzip.NewReader") + resp.Body = &readFirstCloseBoth{gzReader, resp.Body} + } + + rbuf := make([]byte, 2500) + n, err := io.ReadFull(resp.Body, rbuf) + checkErr(err, "2500 byte ReadFull") + if n != 2500 { + fatalf("ReadFull only read %d bytes", n) + } + if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) { + fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf)) + } + resp.Body.Close() + + rest, err := ioutil.ReadAll(bufr) + checkErr(err, "ReadAll on remainder") + if e, g := "Next Request Here", string(rest); e != g { + fatalf("remainder = %q, expected %q", g, e) + } + } +} + +func diff(t *testing.T, prefix string, have, want interface{}) { + hv := reflect.ValueOf(have).Elem() + wv := reflect.ValueOf(want).Elem() + if hv.Type() != wv.Type() { + t.Errorf("%s: type mismatch %v want %v", prefix, hv.Type(), wv.Type()) + } + for i := 0; i < hv.NumField(); i++ { + hf := hv.Field(i).Interface() + wf := wv.Field(i).Interface() + if !reflect.DeepEqual(hf, wf) { + t.Errorf("%s: %s = %v want %v", prefix, hv.Type().Field(i).Name, hf, wf) + } + } +} + +type responseLocationTest struct { + location string // Response's Location header or "" + requrl string // Response.Request.URL or "" + want string + wantErr error +} + +var responseLocationTests = []responseLocationTest{ + {"/foo", "http://bar.com/baz", "http://bar.com/foo", nil}, + {"http://foo.com/", "http://bar.com/baz", "http://foo.com/", nil}, + {"", "http://bar.com/baz", "", ErrNoLocation}, +} + +func TestLocationResponse(t *testing.T) { + for i, tt := range responseLocationTests { + res := new(Response) + res.Header = make(Header) + res.Header.Set("Location", tt.location) + if tt.requrl != "" { + res.Request = &Request{} + var err error + res.Request.URL, err = url.Parse(tt.requrl) + if err != nil { + t.Fatalf("bad test URL %q: %v", tt.requrl, err) + } + } + + got, err := res.Location() + if tt.wantErr != nil { + if err == nil { + t.Errorf("%d. err=nil; want %q", i, tt.wantErr) + continue + } + if g, e := err.Error(), tt.wantErr.Error(); g != e { + t.Errorf("%d. err=%q; want %q", i, g, e) + continue + } + continue + } + if err != nil { + t.Errorf("%d. err=%q", i, err) + continue + } + if g, e := got.String(), tt.want; g != e { + t.Errorf("%d. Location=%q; want %q", i, g, e) + } + } +} diff --git a/src/pkg/net/http/responsewrite_test.go b/src/pkg/net/http/responsewrite_test.go new file mode 100644 index 000000000..f8e63acf4 --- /dev/null +++ b/src/pkg/net/http/responsewrite_test.go @@ -0,0 +1,109 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "io/ioutil" + "testing" +) + +type respWriteTest struct { + Resp Response + Raw string +} + +var respWriteTests = []respWriteTest{ + // HTTP/1.0, identity coding; no trailer + { + Response{ + StatusCode: 503, + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + ContentLength: 6, + }, + + "HTTP/1.0 503 Service Unavailable\r\n" + + "Content-Length: 6\r\n\r\n" + + "abcdef", + }, + // Unchunked response without Content-Length. + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + ContentLength: -1, + }, + "HTTP/1.0 200 OK\r\n" + + "\r\n" + + "abcdef", + }, + // HTTP/1.1, chunked coding; empty trailer; close + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + ContentLength: 6, + TransferEncoding: []string{"chunked"}, + Close: true, + }, + + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "6\r\nabcdef\r\n0\r\n\r\n", + }, + + // Header value with a newline character (Issue 914). + // Also tests removal of leading and trailing whitespace. + { + Response{ + StatusCode: 204, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Foo": []string{" Bar\nBaz "}, + }, + Body: nil, + ContentLength: 0, + TransferEncoding: []string{"chunked"}, + Close: true, + }, + + "HTTP/1.1 204 No Content\r\n" + + "Connection: close\r\n" + + "Foo: Bar Baz\r\n" + + "\r\n", + }, +} + +func TestResponseWrite(t *testing.T) { + for i := range respWriteTests { + tt := &respWriteTests[i] + var braw bytes.Buffer + err := tt.Resp.Write(&braw) + if err != nil { + t.Errorf("error writing #%d: %s", i, err) + continue + } + sraw := braw.String() + if sraw != tt.Raw { + t.Errorf("Test %d, expecting:\n%q\nGot:\n%q\n", i, tt.Raw, sraw) + continue + } + } +} diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go new file mode 100644 index 000000000..147c216ec --- /dev/null +++ b/src/pkg/net/http/serve_test.go @@ -0,0 +1,1182 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// End-to-end serving tests + +package http_test + +import ( + "bufio" + "bytes" + "crypto/tls" + "fmt" + "io" + "io/ioutil" + "log" + "net" + . "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "os" + "reflect" + "strings" + "syscall" + "testing" + "time" +) + +type dummyAddr string +type oneConnListener struct { + conn net.Conn +} + +func (l *oneConnListener) Accept() (c net.Conn, err error) { + c = l.conn + if c == nil { + err = io.EOF + return + } + err = nil + l.conn = nil + return +} + +func (l *oneConnListener) Close() error { + return nil +} + +func (l *oneConnListener) Addr() net.Addr { + return dummyAddr("test-address") +} + +func (a dummyAddr) Network() string { + return string(a) +} + +func (a dummyAddr) String() string { + return string(a) +} + +type testConn struct { + readBuf bytes.Buffer + writeBuf bytes.Buffer +} + +func (c *testConn) Read(b []byte) (int, error) { + return c.readBuf.Read(b) +} + +func (c *testConn) Write(b []byte) (int, error) { + return c.writeBuf.Write(b) +} + +func (c *testConn) Close() error { + return nil +} + +func (c *testConn) LocalAddr() net.Addr { + return dummyAddr("local-addr") +} + +func (c *testConn) RemoteAddr() net.Addr { + return dummyAddr("remote-addr") +} + +func (c *testConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *testConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *testConn) SetWriteDeadline(t time.Time) error { + return nil +} + +func TestConsumingBodyOnNextConn(t *testing.T) { + conn := new(testConn) + for i := 0; i < 2; i++ { + conn.readBuf.Write([]byte( + "POST / HTTP/1.1\r\n" + + "Host: test\r\n" + + "Content-Length: 11\r\n" + + "\r\n" + + "foo=1&bar=1")) + } + + reqNum := 0 + ch := make(chan *Request) + servech := make(chan error) + listener := &oneConnListener{conn} + handler := func(res ResponseWriter, req *Request) { + reqNum++ + ch <- req + } + + go func() { + servech <- Serve(listener, HandlerFunc(handler)) + }() + + var req *Request + req = <-ch + if req == nil { + t.Fatal("Got nil first request.") + } + if req.Method != "POST" { + t.Errorf("For request #1's method, got %q; expected %q", + req.Method, "POST") + } + + req = <-ch + if req == nil { + t.Fatal("Got nil first request.") + } + if req.Method != "POST" { + t.Errorf("For request #2's method, got %q; expected %q", + req.Method, "POST") + } + + if serveerr := <-servech; serveerr != io.EOF { + t.Errorf("Serve returned %q; expected EOF", serveerr) + } +} + +type stringHandler string + +func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) { + w.Header().Set("Result", string(s)) +} + +var handlers = []struct { + pattern string + msg string +}{ + {"/", "Default"}, + {"/someDir/", "someDir"}, + {"someHost.com/someDir/", "someHost.com/someDir"}, +} + +var vtests = []struct { + url string + expected string +}{ + {"http://localhost/someDir/apage", "someDir"}, + {"http://localhost/otherDir/apage", "Default"}, + {"http://someHost.com/someDir/apage", "someHost.com/someDir"}, + {"http://otherHost.com/someDir/apage", "someDir"}, + {"http://otherHost.com/aDir/apage", "Default"}, +} + +func TestHostHandlers(t *testing.T) { + for _, h := range handlers { + Handle(h.pattern, stringHandler(h.msg)) + } + ts := httptest.NewServer(nil) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + cc := httputil.NewClientConn(conn, nil) + for _, vt := range vtests { + var r *Response + var req Request + if req.URL, err = url.Parse(vt.url); err != nil { + t.Errorf("cannot parse url: %v", err) + continue + } + if err := cc.Write(&req); err != nil { + t.Errorf("writing request: %v", err) + continue + } + r, err := cc.Read(&req) + if err != nil { + t.Errorf("reading response: %v", err) + continue + } + s := r.Header.Get("Result") + if s != vt.expected { + t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + } + } +} + +// Tests for http://code.google.com/p/go/issues/detail?id=900 +func TestMuxRedirectLeadingSlashes(t *testing.T) { + paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"} + for _, path := range paths { + req, err := ReadRequest(bufio.NewReader(bytes.NewBufferString("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n"))) + if err != nil { + t.Errorf("%s", err) + } + mux := NewServeMux() + resp := httptest.NewRecorder() + + mux.ServeHTTP(resp, req) + + if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected { + t.Errorf("Expected Location header set to %q; got %q", expected, loc) + return + } + + if code, expected := resp.Code, StatusMovedPermanently; code != expected { + t.Errorf("Expected response code of StatusMovedPermanently; got %d", code) + return + } + } +} + +func TestServerTimeouts(t *testing.T) { + // TODO(bradfitz): convert this to use httptest.Server + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen error: %v", err) + } + addr, _ := l.Addr().(*net.TCPAddr) + + reqNum := 0 + handler := HandlerFunc(func(res ResponseWriter, req *Request) { + reqNum++ + fmt.Fprintf(res, "req=%d", reqNum) + }) + + const second = 1000000000 /* nanos */ + server := &Server{Handler: handler, ReadTimeout: 0.25 * second, WriteTimeout: 0.25 * second} + go server.Serve(l) + + url := fmt.Sprintf("http://%s/", addr) + + // Hit the HTTP server successfully. + tr := &Transport{DisableKeepAlives: true} // they interfere with this test + c := &Client{Transport: tr} + r, err := c.Get(url) + if err != nil { + t.Fatalf("http Get #1: %v", err) + } + got, _ := ioutil.ReadAll(r.Body) + expected := "req=1" + if string(got) != expected { + t.Errorf("Unexpected response for request #1; got %q; expected %q", + string(got), expected) + } + + // Slow client that should timeout. + t1 := time.Now() + conn, err := net.Dial("tcp", addr.String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + buf := make([]byte, 1) + n, err := conn.Read(buf) + latency := time.Now().Sub(t1) + if n != 0 || err != io.EOF { + t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) + } + if latency < 200*time.Millisecond /* fudge from 0.25 above */ { + t.Errorf("got EOF after %s, want >= %s", latency, 200*time.Millisecond) + } + + // Hit the HTTP server successfully again, verifying that the + // previous slow connection didn't run our handler. (that we + // get "req=2", not "req=3") + r, err = Get(url) + if err != nil { + t.Fatalf("http Get #2: %v", err) + } + got, _ = ioutil.ReadAll(r.Body) + expected = "req=2" + if string(got) != expected { + t.Errorf("Get #2 got %q, want %q", string(got), expected) + } + + l.Close() +} + +// TestIdentityResponse verifies that a handler can unset +func TestIdentityResponse(t *testing.T) { + handler := HandlerFunc(func(rw ResponseWriter, req *Request) { + rw.Header().Set("Content-Length", "3") + rw.Header().Set("Transfer-Encoding", req.FormValue("te")) + switch { + case req.FormValue("overwrite") == "1": + _, err := rw.Write([]byte("foo TOO LONG")) + if err != ErrContentLength { + t.Errorf("expected ErrContentLength; got %v", err) + } + case req.FormValue("underwrite") == "1": + rw.Header().Set("Content-Length", "500") + rw.Write([]byte("too short")) + default: + rw.Write([]byte("foo")) + } + }) + + ts := httptest.NewServer(handler) + defer ts.Close() + + // Note: this relies on the assumption (which is true) that + // Get sends HTTP/1.1 or greater requests. Otherwise the + // server wouldn't have the choice to send back chunked + // responses. + for _, te := range []string{"", "identity"} { + url := ts.URL + "/?te=" + te + res, err := Get(url) + if err != nil { + t.Fatalf("error with Get of %s: %v", url, err) + } + if cl, expected := res.ContentLength, int64(3); cl != expected { + t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl) + } + if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected { + t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl) + } + if tl, expected := len(res.TransferEncoding), 0; tl != expected { + t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)", + url, expected, tl, res.TransferEncoding) + } + res.Body.Close() + } + + // Verify that ErrContentLength is returned + url := ts.URL + "/?overwrite=1" + _, err := Get(url) + if err != nil { + t.Fatalf("error with Get of %s: %v", url, err) + } + // Verify that the connection is closed when the declared Content-Length + // is larger than what the handler wrote. + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("error writing: %v", err) + } + + // The ReadAll will hang for a failing test, so use a Timer to + // fail explicitly. + goTimeout(t, 2*time.Second, func() { + got, _ := ioutil.ReadAll(conn) + expectedSuffix := "\r\n\r\ntoo short" + if !strings.HasSuffix(string(got), expectedSuffix) { + t.Errorf("Expected output to end with %q; got response body %q", + expectedSuffix, string(got)) + } + }) +} + +func testTcpConnectionCloses(t *testing.T, req string, h Handler) { + s := httptest.NewServer(h) + defer s.Close() + + conn, err := net.Dial("tcp", s.Listener.Addr().String()) + if err != nil { + t.Fatal("dial error:", err) + } + defer conn.Close() + + _, err = fmt.Fprint(conn, req) + if err != nil { + t.Fatal("print error:", err) + } + + r := bufio.NewReader(conn) + _, err = ReadResponse(r, &Request{Method: "GET"}) + if err != nil { + t.Fatal("ReadResponse error:", err) + } + + success := make(chan bool) + go func() { + select { + case <-time.After(5 * time.Second): + t.Fatal("body not closed after 5s") + case <-success: + } + }() + + _, err = ioutil.ReadAll(r) + if err != nil { + t.Fatal("read error:", err) + } + + success <- true +} + +// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. +func TestServeHTTP10Close(t *testing.T) { + testTcpConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, "testdata/file") + })) +} + +// TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close, +// even for HTTP/1.1 requests. +func TestHandlersCanSetConnectionClose11(t *testing.T) { + testTcpConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + })) +} + +func TestHandlersCanSetConnectionClose10(t *testing.T) { + testTcpConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + })) +} + +func TestSetsRemoteAddr(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%s", r.RemoteAddr) + })) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + ip := string(body) + if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") { + t.Fatalf("Expected local addr; got %q", ip) + } +} + +func TestChunkedResponseHeaders(t *testing.T) { + log.SetOutput(ioutil.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted + fmt.Fprintf(w, "I am a chunked response.") + })) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + if g, e := res.ContentLength, int64(-1); g != e { + t.Errorf("expected ContentLength of %d; got %d", e, g) + } + if g, e := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(g, e) { + t.Errorf("expected TransferEncoding of %v; got %v", e, g) + } + if _, haveCL := res.Header["Content-Length"]; haveCL { + t.Errorf("Unexpected Content-Length") + } +} + +// Test304Responses verifies that 304s don't declare that they're +// chunking in their response headers and aren't allowed to produce +// output. +func Test304Responses(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNotModified) + _, err := w.Write([]byte("illegal body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + } + })) + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if len(res.TransferEncoding) > 0 { + t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if len(body) > 0 { + t.Errorf("got unexpected body %q", string(body)) + } +} + +// TestHeadResponses verifies that responses to HEAD requests don't +// declare that they're chunking in their response headers and aren't +// allowed to produce output. +func TestHeadResponses(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("Ignored body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + } + + // Also exercise the ReaderFrom path + _, err = io.Copy(w, strings.NewReader("Ignored body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Copy, expected ErrBodyNotAllowed, got %v", err) + } + })) + defer ts.Close() + res, err := Head(ts.URL) + if err != nil { + t.Error(err) + } + if len(res.TransferEncoding) > 0 { + t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if len(body) > 0 { + t.Errorf("got unexpected body %q", string(body)) + } +} + +func TestTLSHandshakeTimeout(t *testing.T) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.StartTLS() + defer ts.Close() + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + goTimeout(t, 10*time.Second, func() { + var buf [1]byte + n, err := conn.Read(buf[:]) + if err == nil || n != 0 { + t.Errorf("Read = %d, %v; want an error and no bytes", n, err) + } + }) +} + +func TestTLSServer(t *testing.T) { + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.TLS != nil { + w.Header().Set("X-TLS-Set", "true") + if r.TLS.HandshakeComplete { + w.Header().Set("X-TLS-HandshakeComplete", "true") + } + } + })) + defer ts.Close() + + // Connect an idle TCP connection to this server before we run + // our real tests. This idle connection used to block forever + // in the TLS handshake, preventing future connections from + // being accepted. It may prevent future accidental blocking + // in newConn. + idleConn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer idleConn.Close() + goTimeout(t, 10*time.Second, func() { + if !strings.HasPrefix(ts.URL, "https://") { + t.Errorf("expected test TLS server to start with https://, got %q", ts.URL) + return + } + noVerifyTransport := &Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + client := &Client{Transport: noVerifyTransport} + res, err := client.Get(ts.URL) + if err != nil { + t.Error(err) + return + } + if res == nil { + t.Errorf("got nil Response") + return + } + defer res.Body.Close() + if res.Header.Get("X-TLS-Set") != "true" { + t.Errorf("expected X-TLS-Set response header") + return + } + if res.Header.Get("X-TLS-HandshakeComplete") != "true" { + t.Errorf("expected X-TLS-HandshakeComplete header") + } + }) +} + +type serverExpectTest struct { + contentLength int // of request body + expectation string // e.g. "100-continue" + readBody bool // whether handler should read the body (if false, sends StatusUnauthorized) + expectedResponse string // expected substring in first line of http response +} + +var serverExpectTests = []serverExpectTest{ + // Normal 100-continues, case-insensitive. + {100, "100-continue", true, "100 Continue"}, + {100, "100-cOntInUE", true, "100 Continue"}, + + // No 100-continue. + {100, "", true, "200 OK"}, + + // 100-continue but requesting client to deny us, + // so it never reads the body. + {100, "100-continue", false, "401 Unauthorized"}, + // Likewise without 100-continue: + {100, "", false, "401 Unauthorized"}, + + // Non-standard expectations are failures + {0, "a-pony", false, "417 Expectation Failed"}, + + // Expect-100 requested but no body + {0, "100-continue", true, "400 Bad Request"}, +} + +// Tests that the server responds to the "Expect" request header +// correctly. +func TestServerExpect(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // Note using r.FormValue("readbody") because for POST + // requests that would read from r.Body, which we only + // conditionally want to do. + if strings.Contains(r.URL.RawQuery, "readbody=true") { + ioutil.ReadAll(r.Body) + w.Write([]byte("Hi")) + } else { + w.WriteHeader(StatusUnauthorized) + } + })) + defer ts.Close() + + runTest := func(test serverExpectTest) { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + sendf := func(format string, args ...interface{}) { + _, err := fmt.Fprintf(conn, format, args...) + if err != nil { + t.Fatalf("On test %#v, error writing %q: %v", test, format, err) + } + } + go func() { + sendf("POST /?readbody=%v HTTP/1.1\r\n"+ + "Connection: close\r\n"+ + "Content-Length: %d\r\n"+ + "Expect: %s\r\nHost: foo\r\n\r\n", + test.readBody, test.contentLength, test.expectation) + if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" { + body := strings.Repeat("A", test.contentLength) + sendf(body) + } + }() + bufr := bufio.NewReader(conn) + line, err := bufr.ReadString('\n') + if err != nil { + t.Fatalf("ReadString: %v", err) + } + if !strings.Contains(line, test.expectedResponse) { + t.Errorf("for test %#v got first line=%q", test, line) + } + } + + for _, test := range serverExpectTests { + runTest(test) + } +} + +// Under a ~256KB (maxPostHandlerReadBytes) threshold, the server +// should consume client request bodies that a handler didn't read. +func TestServerUnreadRequestBodyLittle(t *testing.T) { + conn := new(testConn) + body := strings.Repeat("x", 100<<10) + conn.readBuf.Write([]byte(fmt.Sprintf( + "POST / HTTP/1.1\r\n"+ + "Host: test\r\n"+ + "Content-Length: %d\r\n"+ + "\r\n", len(body)))) + conn.readBuf.Write([]byte(body)) + + done := make(chan bool) + + ls := &oneConnListener{conn} + go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + defer close(done) + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len()) + } + rw.WriteHeader(200) + if g, e := conn.readBuf.Len(), 0; g != e { + t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) + } + if c := rw.Header().Get("Connection"); c != "" { + t.Errorf(`Connection header = %q; want ""`, c) + } + })) + <-done +} + +// Over a ~256KB (maxPostHandlerReadBytes) threshold, the server +// should ignore client request bodies that a handler didn't read +// and close the connection. +func TestServerUnreadRequestBodyLarge(t *testing.T) { + conn := new(testConn) + body := strings.Repeat("x", 1<<20) + conn.readBuf.Write([]byte(fmt.Sprintf( + "POST / HTTP/1.1\r\n"+ + "Host: test\r\n"+ + "Content-Length: %d\r\n"+ + "\r\n", len(body)))) + conn.readBuf.Write([]byte(body)) + + done := make(chan bool) + + ls := &oneConnListener{conn} + go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + defer close(done) + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) + } + rw.WriteHeader(200) + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) + } + if c := rw.Header().Get("Connection"); c != "close" { + t.Errorf(`Connection header = %q; want "close"`, c) + } + })) + <-done +} + +func TestTimeoutHandler(t *testing.T) { + sendHi := make(chan bool, 1) + writeErrors := make(chan error, 1) + sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + <-sendHi + _, werr := w.Write([]byte("hi")) + writeErrors <- werr + }) + timeout := make(chan time.Time, 1) // write to this to force timeouts + ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout)) + defer ts.Close() + + // Succeed without timing out: + sendHi <- true + res, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusOK; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := ioutil.ReadAll(res.Body) + if g, e := string(body), "hi"; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g := <-writeErrors; g != nil { + t.Errorf("got unexpected Write error on first request: %v", g) + } + + // Times out: + timeout <- time.Time{} + res, err = Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ = ioutil.ReadAll(res.Body) + if !strings.Contains(string(body), "<title>Timeout</title>") { + t.Errorf("expected timeout body; got %q", string(body)) + } + + // Now make the previously-timed out handler speak again, + // which verifies the panic is handled: + sendHi <- true + if g, e := <-writeErrors, ErrHandlerTimeout; g != e { + t.Errorf("expected Write error of %v; got %v", e, g) + } +} + +// Verifies we don't path.Clean() on the wrong parts in redirects. +func TestRedirectMunging(t *testing.T) { + req, _ := NewRequest("GET", "http://example.com/", nil) + + resp := httptest.NewRecorder() + Redirect(resp, req, "/foo?next=http://bar.com/", 302) + if g, e := resp.Header().Get("Location"), "/foo?next=http://bar.com/"; g != e { + t.Errorf("Location header was %q; want %q", g, e) + } + + resp = httptest.NewRecorder() + Redirect(resp, req, "http://localhost:8080/_ah/login?continue=http://localhost:8080/", 302) + if g, e := resp.Header().Get("Location"), "http://localhost:8080/_ah/login?continue=http://localhost:8080/"; g != e { + t.Errorf("Location header was %q; want %q", g, e) + } +} + +// TestZeroLengthPostAndResponse exercises an optimization done by the Transport: +// when there is no body (either because the method doesn't permit a body, or an +// explicit Content-Length of zero is present), then the transport can re-use the +// connection immediately. But when it re-uses the connection, it typically closes +// the previous request's body, which is not optimal for zero-lengthed bodies, +// as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. +func TestZeroLengthPostAndResponse(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + all, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("handler ReadAll: %v", err) + } + if len(all) != 0 { + t.Errorf("handler got %d bytes; expected 0", len(all)) + } + rw.Header().Set("Content-Length", "0") + })) + defer ts.Close() + + req, err := NewRequest("POST", ts.URL, strings.NewReader("")) + if err != nil { + t.Fatal(err) + } + req.ContentLength = 0 + + var resp [5]*Response + for i := range resp { + resp[i], err = DefaultClient.Do(req) + if err != nil { + t.Fatalf("client post #%d: %v", i, err) + } + } + + for i := range resp { + all, err := ioutil.ReadAll(resp[i].Body) + if err != nil { + t.Fatalf("req #%d: client ReadAll: %v", i, err) + } + if len(all) != 0 { + t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all)) + } + } +} + +func TestHandlerPanic(t *testing.T) { + testHandlerPanic(t, false) +} + +func TestHandlerPanicWithHijack(t *testing.T) { + testHandlerPanic(t, true) +} + +func testHandlerPanic(t *testing.T, withHijack bool) { + // Unlike the other tests that set the log output to ioutil.Discard + // to quiet the output, this test uses a pipe. The pipe serves three + // purposes: + // + // 1) The log.Print from the http server (generated by the caught + // panic) will go to the pipe instead of stderr, making the + // output quiet. + // + // 2) We read from the pipe to verify that the handler + // actually caught the panic and logged something. + // + // 3) The blocking Read call prevents this TestHandlerPanic + // function from exiting before the HTTP server handler + // finishes crashing. If this text function exited too + // early (and its defer log.SetOutput(os.Stderr) ran), + // then the crash output could spill into the next test. + pr, pw := io.Pipe() + log.SetOutput(pw) + defer log.SetOutput(os.Stderr) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if withHijack { + rwc, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Logf("unexpected error: %v", err) + } + defer rwc.Close() + } + panic("intentional death for testing") + })) + defer ts.Close() + + // Do a blocking read on the log output pipe so its logging + // doesn't bleed into the next test. But wait only 5 seconds + // for it. + done := make(chan bool, 1) + go func() { + buf := make([]byte, 4<<10) + _, err := pr.Read(buf) + pr.Close() + if err != nil { + t.Fatal(err) + } + done <- true + }() + + _, err := Get(ts.URL) + if err == nil { + t.Logf("expected an error") + } + + select { + case <-done: + return + case <-time.After(5 * time.Second): + t.Fatal("expected server handler to log an error") + } +} + +func TestNoDate(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header()["Date"] = nil + })) + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + _, present := res.Header["Date"] + if present { + t.Fatalf("Expected no Date header; got %v", res.Header["Date"]) + } +} + +func TestStripPrefix(t *testing.T) { + h := HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("X-Path", r.URL.Path) + }) + ts := httptest.NewServer(StripPrefix("/foo", h)) + defer ts.Close() + + res, err := Get(ts.URL + "/foo/bar") + if err != nil { + t.Fatal(err) + } + if g, e := res.Header.Get("X-Path"), "/bar"; g != e { + t.Errorf("test 1: got %s, want %s", g, e) + } + + res, err = Get(ts.URL + "/bar") + if err != nil { + t.Fatal(err) + } + if g, e := res.StatusCode, 404; g != e { + t.Errorf("test 2: got status %v, want %v", g, e) + } +} + +func TestRequestLimit(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + t.Fatalf("didn't expect to get request in Handler") + })) + defer ts.Close() + req, _ := NewRequest("GET", ts.URL, nil) + var bytesPerHeader = len("header12345: val12345\r\n") + for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ { + req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i)) + } + res, err := DefaultClient.Do(req) + if err != nil { + // Some HTTP clients may fail on this undefined behavior (server replying and + // closing the connection while the request is still being written), but + // we do support it (at least currently), so we expect a response below. + t.Fatalf("Do: %v", err) + } + if res.StatusCode != 413 { + t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +type countReader struct { + r io.Reader + n *int64 +} + +func (cr countReader) Read(p []byte) (n int, err error) { + n, err = cr.r.Read(p) + *cr.n += int64(n) + return +} + +func TestRequestBodyLimit(t *testing.T) { + const limit = 1 << 20 + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + r.Body = MaxBytesReader(w, r.Body, limit) + n, err := io.Copy(ioutil.Discard, r.Body) + if err == nil { + t.Errorf("expected error from io.Copy") + } + if n != limit { + t.Errorf("io.Copy = %d, want %d", n, limit) + } + })) + defer ts.Close() + + nWritten := int64(0) + req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), &nWritten}, limit*200)) + + // Send the POST, but don't care it succeeds or not. The + // remote side is going to reply and then close the TCP + // connection, and HTTP doesn't really define if that's + // allowed or not. Some HTTP clients will get the response + // and some (like ours, currently) will complain that the + // request write failed, without reading the response. + // + // But that's okay, since what we're really testing is that + // the remote side hung up on us before we wrote too much. + _, _ = DefaultClient.Do(req) + + if nWritten > limit*100 { + t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d", + limit, nWritten) + } +} + +// TestClientWriteShutdown tests that if the client shuts down the write +// side of their TCP connection, the server doesn't send a 400 Bad Request. +func TestClientWriteShutdown(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + err = conn.(*net.TCPConn).CloseWrite() + if err != nil { + t.Fatalf("Dial: %v", err) + } + donec := make(chan bool) + go func() { + defer close(donec) + bs, err := ioutil.ReadAll(conn) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + got := string(bs) + if got != "" { + t.Errorf("read %q from server; want nothing", got) + } + }() + select { + case <-donec: + case <-time.After(10 * time.Second): + t.Fatalf("timeout") + } +} + +// Tests that chunked server responses that write 1 byte at a time are +// buffered before chunk headers are added, not after chunk headers. +func TestServerBufferedChunking(t *testing.T) { + if true { + t.Logf("Skipping known broken test; see Issue 2357") + return + } + conn := new(testConn) + conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n")) + done := make(chan bool) + ls := &oneConnListener{conn} + go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + defer close(done) + rw.Header().Set("Content-Type", "text/plain") // prevent sniffing, which buffers + rw.Write([]byte{'x'}) + rw.Write([]byte{'y'}) + rw.Write([]byte{'z'}) + })) + <-done + if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) { + t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q", + conn.writeBuf.Bytes()) + } +} + +// goTimeout runs f, failing t if f takes more than ns to complete. +func goTimeout(t *testing.T, d time.Duration, f func()) { + ch := make(chan bool, 2) + timer := time.AfterFunc(d, func() { + t.Errorf("Timeout expired after %v", d) + ch <- true + }) + defer timer.Stop() + go func() { + defer func() { ch <- true }() + f() + }() + <-ch +} + +type errorListener struct { + errs []error +} + +func (l *errorListener) Accept() (c net.Conn, err error) { + if len(l.errs) == 0 { + return nil, io.EOF + } + err = l.errs[0] + l.errs = l.errs[1:] + return +} + +func (l *errorListener) Close() error { + return nil +} + +func (l *errorListener) Addr() net.Addr { + return dummyAddr("test-address") +} + +func TestAcceptMaxFds(t *testing.T) { + log.SetOutput(ioutil.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + + ln := &errorListener{[]error{ + &net.OpError{ + Op: "accept", + Err: syscall.EMFILE, + }}} + err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {}))) + if err != io.EOF { + t.Errorf("got error %v, want EOF", err) + } +} + +func BenchmarkClientServer(b *testing.B) { + b.StopTimer() + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + fmt.Fprintf(rw, "Hello world.\n") + })) + defer ts.Close() + b.StartTimer() + + for i := 0; i < b.N; i++ { + res, err := Get(ts.URL) + if err != nil { + b.Fatal("Get:", err) + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + b.Fatal("ReadAll:", err) + } + body := string(all) + if body != "Hello world.\n" { + b.Fatal("Got body:", body) + } + } + + b.StopTimer() +} diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go new file mode 100644 index 000000000..bad3bcb28 --- /dev/null +++ b/src/pkg/net/http/server.go @@ -0,0 +1,1191 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP server. See RFC 2616. + +// TODO(rsc): +// logging + +package http + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/url" + "path" + "runtime/debug" + "strconv" + "strings" + "sync" + "time" +) + +// Errors introduced by the HTTP server. +var ( + ErrWriteAfterFlush = errors.New("Conn.Write called after Flush") + ErrBodyNotAllowed = errors.New("http: response status code does not allow body") + ErrHijacked = errors.New("Conn has been hijacked") + ErrContentLength = errors.New("Conn.Write wrote more than the declared Content-Length") +) + +// Objects implementing the Handler interface can be +// registered to serve a particular path or subtree +// in the HTTP server. +// +// ServeHTTP should write reply headers and data to the ResponseWriter +// and then return. Returning signals that the request is finished +// and that the HTTP server can move on to the next request on +// the connection. +type Handler interface { + ServeHTTP(ResponseWriter, *Request) +} + +// A ResponseWriter interface is used by an HTTP handler to +// construct an HTTP response. +type ResponseWriter interface { + // Header returns the header map that will be sent by WriteHeader. + // Changing the header after a call to WriteHeader (or Write) has + // no effect. + Header() Header + + // Write writes the data to the connection as part of an HTTP reply. + // If WriteHeader has not yet been called, Write calls WriteHeader(http.StatusOK) + // before writing the data. + Write([]byte) (int, error) + + // WriteHeader sends an HTTP response header with status code. + // If WriteHeader is not called explicitly, the first call to Write + // will trigger an implicit WriteHeader(http.StatusOK). + // Thus explicit calls to WriteHeader are mainly used to + // send error codes. + WriteHeader(int) +} + +// The Flusher interface is implemented by ResponseWriters that allow +// an HTTP handler to flush buffered data to the client. +// +// Note that even for ResponseWriters that support Flush, +// if the client is connected through an HTTP proxy, +// the buffered data may not reach the client until the response +// completes. +type Flusher interface { + // Flush sends any buffered data to the client. + Flush() +} + +// The Hijacker interface is implemented by ResponseWriters that allow +// an HTTP handler to take over the connection. +type Hijacker interface { + // Hijack lets the caller take over the connection. + // After a call to Hijack(), the HTTP server library + // will not do anything else with the connection. + // It becomes the caller's responsibility to manage + // and close the connection. + Hijack() (net.Conn, *bufio.ReadWriter, error) +} + +// A conn represents the server side of an HTTP connection. +type conn struct { + remoteAddr string // network address of remote side + server *Server // the Server on which the connection arrived + rwc net.Conn // i/o connection + lr *io.LimitedReader // io.LimitReader(rwc) + buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->rwc + hijacked bool // connection has been hijacked by handler + tlsState *tls.ConnectionState // or nil when not using TLS + body []byte +} + +// A response represents the server side of an HTTP response. +type response struct { + conn *conn + req *Request // request for this response + chunking bool // using chunked transfer encoding for reply body + wroteHeader bool // reply header has been written + wroteContinue bool // 100 Continue response was written + header Header // reply header parameters + written int64 // number of bytes written in body + contentLength int64 // explicitly-declared Content-Length; or -1 + status int // status code passed to WriteHeader + needSniff bool // need to sniff to find Content-Type + + // close connection after this reply. set on request and + // updated after response from handler if there's a + // "Connection: keep-alive" response header and a + // Content-Length. + closeAfterReply bool + + // requestBodyLimitHit is set by requestTooLarge when + // maxBytesReader hits its max size. It is checked in + // WriteHeader, to make sure we don't consume the the + // remaining request body to try to advance to the next HTTP + // request. Instead, when this is set, we stop doing + // subsequent requests on this connection and stop reading + // input from it. + requestBodyLimitHit bool +} + +// requestTooLarge is called by maxBytesReader when too much input has +// been read from the client. +func (w *response) requestTooLarge() { + w.closeAfterReply = true + w.requestBodyLimitHit = true + if !w.wroteHeader { + w.Header().Set("Connection", "close") + } +} + +type writerOnly struct { + io.Writer +} + +func (w *response) ReadFrom(src io.Reader) (n int64, err error) { + // Call WriteHeader before checking w.chunking if it hasn't + // been called yet, since WriteHeader is what sets w.chunking. + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + if !w.chunking && w.bodyAllowed() && !w.needSniff { + w.Flush() + if rf, ok := w.conn.rwc.(io.ReaderFrom); ok { + n, err = rf.ReadFrom(src) + w.written += n + return + } + } + // Fall back to default io.Copy implementation. + // Use wrapper to hide w.ReadFrom from io.Copy. + return io.Copy(writerOnly{w}, src) +} + +// noLimit is an effective infinite upper bound for io.LimitedReader +const noLimit int64 = (1 << 63) - 1 + +// Create new connection from rwc. +func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { + c = new(conn) + c.remoteAddr = rwc.RemoteAddr().String() + c.server = srv + c.rwc = rwc + c.body = make([]byte, sniffLen) + c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader) + br := bufio.NewReader(c.lr) + bw := bufio.NewWriter(rwc) + c.buf = bufio.NewReadWriter(br, bw) + return c, nil +} + +// DefaultMaxHeaderBytes is the maximum permitted size of the headers +// in an HTTP request. +// This can be overridden by setting Server.MaxHeaderBytes. +const DefaultMaxHeaderBytes = 1 << 20 // 1 MB + +func (srv *Server) maxHeaderBytes() int { + if srv.MaxHeaderBytes > 0 { + return srv.MaxHeaderBytes + } + return DefaultMaxHeaderBytes +} + +// wrapper around io.ReaderCloser which on first read, sends an +// HTTP/1.1 100 Continue header +type expectContinueReader struct { + resp *response + readCloser io.ReadCloser + closed bool +} + +func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { + if ecr.closed { + return 0, errors.New("http: Read after Close on request Body") + } + if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked { + ecr.resp.wroteContinue = true + io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n") + ecr.resp.conn.buf.Flush() + } + return ecr.readCloser.Read(p) +} + +func (ecr *expectContinueReader) Close() error { + ecr.closed = true + return ecr.readCloser.Close() +} + +// TimeFormat is the time format to use with +// time.Parse and time.Time.Format when parsing +// or generating times in HTTP headers. +// It is like time.RFC1123 but hard codes GMT as the time zone. +const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT" + +var errTooLarge = errors.New("http: request too large") + +// Read next request from connection. +func (c *conn) readRequest() (w *response, err error) { + if c.hijacked { + return nil, ErrHijacked + } + c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ + var req *Request + if req, err = ReadRequest(c.buf.Reader); err != nil { + if c.lr.N == 0 { + return nil, errTooLarge + } + return nil, err + } + c.lr.N = noLimit + + req.RemoteAddr = c.remoteAddr + req.TLS = c.tlsState + + w = new(response) + w.conn = c + w.req = req + w.header = make(Header) + w.contentLength = -1 + c.body = c.body[:0] + return w, nil +} + +func (w *response) Header() Header { + return w.header +} + +// maxPostHandlerReadBytes is the max number of Request.Body bytes not +// consumed by a handler that the server will read from the client +// in order to keep a connection alive. If there are more bytes than +// this then the server to be paranoid instead sends a "Connection: +// close" response. +// +// This number is approximately what a typical machine's TCP buffer +// size is anyway. (if we have the bytes on the machine, we might as +// well read them) +const maxPostHandlerReadBytes = 256 << 10 + +func (w *response) WriteHeader(code int) { + if w.conn.hijacked { + log.Print("http: response.WriteHeader on hijacked connection") + return + } + if w.wroteHeader { + log.Print("http: multiple response.WriteHeader calls") + return + } + w.wroteHeader = true + w.status = code + + // Check for a explicit (and valid) Content-Length header. + var hasCL bool + var contentLength int64 + if clenStr := w.header.Get("Content-Length"); clenStr != "" { + var err error + contentLength, err = strconv.ParseInt(clenStr, 10, 64) + if err == nil { + hasCL = true + } else { + log.Printf("http: invalid Content-Length of %q sent", clenStr) + w.header.Del("Content-Length") + } + } + + if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { + _, connectionHeaderSet := w.header["Connection"] + if !connectionHeaderSet { + w.header.Set("Connection", "keep-alive") + } + } else if !w.req.ProtoAtLeast(1, 1) { + // Client did not ask to keep connection alive. + w.closeAfterReply = true + } + + if w.header.Get("Connection") == "close" { + w.closeAfterReply = true + } + + // Per RFC 2616, we should consume the request body before + // replying, if the handler hasn't already done so. But we + // don't want to do an unbounded amount of reading here for + // DoS reasons, so we only try up to a threshold. + if w.req.ContentLength != 0 && !w.closeAfterReply { + ecr, isExpecter := w.req.Body.(*expectContinueReader) + if !isExpecter || ecr.resp.wroteContinue { + n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1) + if n >= maxPostHandlerReadBytes { + w.requestTooLarge() + w.header.Set("Connection", "close") + } else { + w.req.Body.Close() + } + } + } + + if code == StatusNotModified { + // Must not have body. + for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { + if w.header.Get(header) != "" { + // TODO: return an error if WriteHeader gets a return parameter + // or set a flag on w to make future Writes() write an error page? + // for now just log and drop the header. + log.Printf("http: StatusNotModified response with header %q defined", header) + w.header.Del(header) + } + } + } else { + // If no content type, apply sniffing algorithm to body. + if w.header.Get("Content-Type") == "" { + w.needSniff = true + } + } + + if _, ok := w.header["Date"]; !ok { + w.Header().Set("Date", time.Now().UTC().Format(TimeFormat)) + } + + te := w.header.Get("Transfer-Encoding") + hasTE := te != "" + if hasCL && hasTE && te != "identity" { + // TODO: return an error if WriteHeader gets a return parameter + // For now just ignore the Content-Length. + log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", + te, contentLength) + w.header.Del("Content-Length") + hasCL = false + } + + if w.req.Method == "HEAD" || code == StatusNotModified { + // do nothing + } else if hasCL { + w.contentLength = contentLength + w.header.Del("Transfer-Encoding") + } else if w.req.ProtoAtLeast(1, 1) { + // HTTP/1.1 or greater: use chunked transfer encoding + // to avoid closing the connection at EOF. + // TODO: this blows away any custom or stacked Transfer-Encoding they + // might have set. Deal with that as need arises once we have a valid + // use case. + w.chunking = true + w.header.Set("Transfer-Encoding", "chunked") + } else { + // HTTP version < 1.1: cannot do chunked transfer + // encoding and we don't know the Content-Length so + // signal EOF by closing connection. + w.closeAfterReply = true + w.header.Del("Transfer-Encoding") // in case already set + } + + // Cannot use Content-Length with non-identity Transfer-Encoding. + if w.chunking { + w.header.Del("Content-Length") + } + if !w.req.ProtoAtLeast(1, 0) { + return + } + proto := "HTTP/1.0" + if w.req.ProtoAtLeast(1, 1) { + proto = "HTTP/1.1" + } + codestring := strconv.Itoa(code) + text, ok := statusText[code] + if !ok { + text = "status code " + codestring + } + io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n") + w.header.Write(w.conn.buf) + + // If we need to sniff the body, leave the header open. + // Otherwise, end it here. + if !w.needSniff { + io.WriteString(w.conn.buf, "\r\n") + } +} + +// sniff uses the first block of written data, +// stored in w.conn.body, to decide the Content-Type +// for the HTTP body. +func (w *response) sniff() { + if !w.needSniff { + return + } + w.needSniff = false + + data := w.conn.body + fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n\r\n", DetectContentType(data)) + + if len(data) == 0 { + return + } + if w.chunking { + fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) + } + _, err := w.conn.buf.Write(data) + if w.chunking && err == nil { + io.WriteString(w.conn.buf, "\r\n") + } +} + +// bodyAllowed returns true if a Write is allowed for this response type. +// It's illegal to call this before the header has been flushed. +func (w *response) bodyAllowed() bool { + if !w.wroteHeader { + panic("") + } + return w.status != StatusNotModified && w.req.Method != "HEAD" +} + +func (w *response) Write(data []byte) (n int, err error) { + if w.conn.hijacked { + log.Print("http: response.Write on hijacked connection") + return 0, ErrHijacked + } + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + if len(data) == 0 { + return 0, nil + } + if !w.bodyAllowed() { + return 0, ErrBodyNotAllowed + } + + w.written += int64(len(data)) // ignoring errors, for errorKludge + if w.contentLength != -1 && w.written > w.contentLength { + return 0, ErrContentLength + } + + var m int + if w.needSniff { + // We need to sniff the beginning of the output to + // determine the content type. Accumulate the + // initial writes in w.conn.body. + // Cap m so that append won't allocate. + m = cap(w.conn.body) - len(w.conn.body) + if m > len(data) { + m = len(data) + } + w.conn.body = append(w.conn.body, data[:m]...) + data = data[m:] + if len(data) == 0 { + // Copied everything into the buffer. + // Wait for next write. + return m, nil + } + + // Filled the buffer; more data remains. + // Sniff the content (flushes the buffer) + // and then proceed with the remainder + // of the data as a normal Write. + // Calling sniff clears needSniff. + w.sniff() + } + + // TODO(rsc): if chunking happened after the buffering, + // then there would be fewer chunk headers. + // On the other hand, it would make hijacking more difficult. + if w.chunking { + fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) // TODO(rsc): use strconv not fmt + } + n, err = w.conn.buf.Write(data) + if err == nil && w.chunking { + if n != len(data) { + err = io.ErrShortWrite + } + if err == nil { + io.WriteString(w.conn.buf, "\r\n") + } + } + + return m + n, err +} + +func (w *response) finishRequest() { + // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length + // back, we can make this a keep-alive response ... + if w.req.wantsHttp10KeepAlive() { + sentLength := w.header.Get("Content-Length") != "" + if sentLength && w.header.Get("Connection") == "keep-alive" { + w.closeAfterReply = false + } + } + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + if w.needSniff { + w.sniff() + } + if w.chunking { + io.WriteString(w.conn.buf, "0\r\n") + // trailer key/value pairs, followed by blank line + io.WriteString(w.conn.buf, "\r\n") + } + w.conn.buf.Flush() + // Close the body, unless we're about to close the whole TCP connection + // anyway. + if !w.closeAfterReply { + w.req.Body.Close() + } + if w.req.MultipartForm != nil { + w.req.MultipartForm.RemoveAll() + } + + if w.contentLength != -1 && w.contentLength != w.written { + // Did not write enough. Avoid getting out of sync. + w.closeAfterReply = true + } +} + +func (w *response) Flush() { + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + w.sniff() + w.conn.buf.Flush() +} + +// Close the connection. +func (c *conn) close() { + if c.buf != nil { + c.buf.Flush() + c.buf = nil + } + if c.rwc != nil { + c.rwc.Close() + c.rwc = nil + } +} + +// Serve a new connection. +func (c *conn) serve() { + defer func() { + err := recover() + if err == nil { + return + } + + var buf bytes.Buffer + fmt.Fprintf(&buf, "http: panic serving %v: %v\n", c.remoteAddr, err) + buf.Write(debug.Stack()) + log.Print(buf.String()) + + if c.rwc != nil { // may be nil if connection hijacked + c.rwc.Close() + } + }() + + if tlsConn, ok := c.rwc.(*tls.Conn); ok { + if err := tlsConn.Handshake(); err != nil { + c.close() + return + } + c.tlsState = new(tls.ConnectionState) + *c.tlsState = tlsConn.ConnectionState() + } + + for { + w, err := c.readRequest() + if err != nil { + msg := "400 Bad Request" + if err == errTooLarge { + // Their HTTP client may or may not be + // able to read this if we're + // responding to them and hanging up + // while they're still writing their + // request. Undefined behavior. + msg = "413 Request Entity Too Large" + } else if err == io.ErrUnexpectedEOF { + break // Don't reply + } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + break // Don't reply + } + fmt.Fprintf(c.rwc, "HTTP/1.1 %s\r\n\r\n", msg) + break + } + + // Expect 100 Continue support + req := w.req + if req.expectsContinue() { + if req.ProtoAtLeast(1, 1) { + // Wrap the Body reader with one that replies on the connection + req.Body = &expectContinueReader{readCloser: req.Body, resp: w} + } + if req.ContentLength == 0 { + w.Header().Set("Connection", "close") + w.WriteHeader(StatusBadRequest) + w.finishRequest() + break + } + req.Header.Del("Expect") + } else if req.Header.Get("Expect") != "" { + // TODO(bradfitz): let ServeHTTP handlers handle + // requests with non-standard expectation[s]? Seems + // theoretical at best, and doesn't fit into the + // current ServeHTTP model anyway. We'd need to + // make the ResponseWriter an optional + // "ExpectReplier" interface or something. + // + // For now we'll just obey RFC 2616 14.20 which says + // "If a server receives a request containing an + // Expect field that includes an expectation- + // extension that it does not support, it MUST + // respond with a 417 (Expectation Failed) status." + w.Header().Set("Connection", "close") + w.WriteHeader(StatusExpectationFailed) + w.finishRequest() + break + } + + handler := c.server.Handler + if handler == nil { + handler = DefaultServeMux + } + + // HTTP cannot have multiple simultaneous active requests.[*] + // Until the server replies to this request, it can't read another, + // so we might as well run the handler in this goroutine. + // [*] Not strictly true: HTTP pipelining. We could let them all process + // in parallel even if their responses need to be serialized. + handler.ServeHTTP(w, w.req) + if c.hijacked { + return + } + w.finishRequest() + if w.closeAfterReply { + break + } + } + c.close() +} + +// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter +// and a Hijacker. +func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { + if w.conn.hijacked { + return nil, nil, ErrHijacked + } + w.conn.hijacked = true + rwc = w.conn.rwc + buf = w.conn.buf + w.conn.rwc = nil + w.conn.buf = nil + return +} + +// The HandlerFunc type is an adapter to allow the use of +// ordinary functions as HTTP handlers. If f is a function +// with the appropriate signature, HandlerFunc(f) is a +// Handler object that calls f. +type HandlerFunc func(ResponseWriter, *Request) + +// ServeHTTP calls f(w, r). +func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { + f(w, r) +} + +// Helper handlers + +// Error replies to the request with the specified error message and HTTP code. +func Error(w ResponseWriter, error string, code int) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(code) + fmt.Fprintln(w, error) +} + +// NotFound replies to the request with an HTTP 404 not found error. +func NotFound(w ResponseWriter, r *Request) { Error(w, "404 page not found", StatusNotFound) } + +// NotFoundHandler returns a simple request handler +// that replies to each request with a ``404 page not found'' reply. +func NotFoundHandler() Handler { return HandlerFunc(NotFound) } + +// StripPrefix returns a handler that serves HTTP requests +// by removing the given prefix from the request URL's Path +// and invoking the handler h. StripPrefix handles a +// request for a path that doesn't begin with prefix by +// replying with an HTTP 404 not found error. +func StripPrefix(prefix string, h Handler) Handler { + return HandlerFunc(func(w ResponseWriter, r *Request) { + if !strings.HasPrefix(r.URL.Path, prefix) { + NotFound(w, r) + return + } + r.URL.Path = r.URL.Path[len(prefix):] + h.ServeHTTP(w, r) + }) +} + +// Redirect replies to the request with a redirect to url, +// which may be a path relative to the request path. +func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { + if u, err := url.Parse(urlStr); err == nil { + // If url was relative, make absolute by + // combining with request path. + // The browser would probably do this for us, + // but doing it ourselves is more reliable. + + // NOTE(rsc): RFC 2616 says that the Location + // line must be an absolute URI, like + // "http://www.google.com/redirect/", + // not a path like "/redirect/". + // Unfortunately, we don't know what to + // put in the host name section to get the + // client to connect to us again, so we can't + // know the right absolute URI to send back. + // Because of this problem, no one pays attention + // to the RFC; they all send back just a new path. + // So do we. + oldpath := r.URL.Path + if oldpath == "" { // should not happen, but avoid a crash if it does + oldpath = "/" + } + if u.Scheme == "" { + // no leading http://server + if urlStr == "" || urlStr[0] != '/' { + // make relative path absolute + olddir, _ := path.Split(oldpath) + urlStr = olddir + urlStr + } + + var query string + if i := strings.Index(urlStr, "?"); i != -1 { + urlStr, query = urlStr[:i], urlStr[i:] + } + + // clean up but preserve trailing slash + trailing := urlStr[len(urlStr)-1] == '/' + urlStr = path.Clean(urlStr) + if trailing && urlStr[len(urlStr)-1] != '/' { + urlStr += "/" + } + urlStr += query + } + } + + w.Header().Set("Location", urlStr) + w.WriteHeader(code) + + // RFC2616 recommends that a short note "SHOULD" be included in the + // response because older user agents may not understand 301/307. + // Shouldn't send the response for POST or HEAD; that leaves GET. + if r.Method == "GET" { + note := "<a href=\"" + htmlEscape(urlStr) + "\">" + statusText[code] + "</a>.\n" + fmt.Fprintln(w, note) + } +} + +var htmlReplacer = strings.NewReplacer( + "&", "&", + "<", "<", + ">", ">", + `"`, """, + "'", "'", +) + +func htmlEscape(s string) string { + return htmlReplacer.Replace(s) +} + +// Redirect to a fixed URL +type redirectHandler struct { + url string + code int +} + +func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) { + Redirect(w, r, rh.url, rh.code) +} + +// RedirectHandler returns a request handler that redirects +// each request it receives to the given url using the given +// status code. +func RedirectHandler(url string, code int) Handler { + return &redirectHandler{url, code} +} + +// ServeMux is an HTTP request multiplexer. +// It matches the URL of each incoming request against a list of registered +// patterns and calls the handler for the pattern that +// most closely matches the URL. +// +// Patterns named fixed, rooted paths, like "/favicon.ico", +// or rooted subtrees, like "/images/" (note the trailing slash). +// Longer patterns take precedence over shorter ones, so that +// if there are handlers registered for both "/images/" +// and "/images/thumbnails/", the latter handler will be +// called for paths beginning "/images/thumbnails/" and the +// former will receiver requests for any other paths in the +// "/images/" subtree. +// +// Patterns may optionally begin with a host name, restricting matches to +// URLs on that host only. Host-specific patterns take precedence over +// general patterns, so that a handler might register for the two patterns +// "/codesearch" and "codesearch.google.com/" without also taking over +// requests for "http://www.google.com/". +// +// ServeMux also takes care of sanitizing the URL request path, +// redirecting any request containing . or .. elements to an +// equivalent .- and ..-free URL. +type ServeMux struct { + m map[string]Handler +} + +// NewServeMux allocates and returns a new ServeMux. +func NewServeMux() *ServeMux { return &ServeMux{make(map[string]Handler)} } + +// DefaultServeMux is the default ServeMux used by Serve. +var DefaultServeMux = NewServeMux() + +// Does path match pattern? +func pathMatch(pattern, path string) bool { + if len(pattern) == 0 { + // should not happen + return false + } + n := len(pattern) + if pattern[n-1] != '/' { + return pattern == path + } + return len(path) >= n && path[0:n] == pattern +} + +// Return the canonical path for p, eliminating . and .. elements. +func cleanPath(p string) string { + if p == "" { + return "/" + } + if p[0] != '/' { + p = "/" + p + } + np := path.Clean(p) + // path.Clean removes trailing slash except for root; + // put the trailing slash back if necessary. + if p[len(p)-1] == '/' && np != "/" { + np += "/" + } + return np +} + +// Find a handler on a handler map given a path string +// Most-specific (longest) pattern wins +func (mux *ServeMux) match(path string) Handler { + var h Handler + var n = 0 + for k, v := range mux.m { + if !pathMatch(k, path) { + continue + } + if h == nil || len(k) > n { + n = len(k) + h = v + } + } + return h +} + +// ServeHTTP dispatches the request to the handler whose +// pattern most closely matches the request URL. +func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { + // Clean path to canonical form and redirect. + if p := cleanPath(r.URL.Path); p != r.URL.Path { + w.Header().Set("Location", p) + w.WriteHeader(StatusMovedPermanently) + return + } + // Host-specific pattern takes precedence over generic ones + h := mux.match(r.Host + r.URL.Path) + if h == nil { + h = mux.match(r.URL.Path) + } + if h == nil { + h = NotFoundHandler() + } + h.ServeHTTP(w, r) +} + +// Handle registers the handler for the given pattern. +func (mux *ServeMux) Handle(pattern string, handler Handler) { + if pattern == "" { + panic("http: invalid pattern " + pattern) + } + + mux.m[pattern] = handler + + // Helpful behavior: + // If pattern is /tree/, insert permanent redirect for /tree. + n := len(pattern) + if n > 0 && pattern[n-1] == '/' { + mux.m[pattern[0:n-1]] = RedirectHandler(pattern, StatusMovedPermanently) + } +} + +// HandleFunc registers the handler function for the given pattern. +func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + mux.Handle(pattern, HandlerFunc(handler)) +} + +// Handle registers the handler for the given pattern +// in the DefaultServeMux. +// The documentation for ServeMux explains how patterns are matched. +func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } + +// HandleFunc registers the handler function for the given pattern +// in the DefaultServeMux. +// The documentation for ServeMux explains how patterns are matched. +func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + DefaultServeMux.HandleFunc(pattern, handler) +} + +// Serve accepts incoming HTTP connections on the listener l, +// creating a new service thread for each. The service threads +// read requests and then call handler to reply to them. +// Handler is typically nil, in which case the DefaultServeMux is used. +func Serve(l net.Listener, handler Handler) error { + srv := &Server{Handler: handler} + return srv.Serve(l) +} + +// A Server defines parameters for running an HTTP server. +type Server struct { + Addr string // TCP address to listen on, ":http" if empty + Handler Handler // handler to invoke, http.DefaultServeMux if nil + ReadTimeout time.Duration // maximum duration before timing out read of the request + WriteTimeout time.Duration // maximum duration before timing out write of the response + MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0 +} + +// ListenAndServe listens on the TCP network address srv.Addr and then +// calls Serve to handle requests on incoming connections. If +// srv.Addr is blank, ":http" is used. +func (srv *Server) ListenAndServe() error { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + l, e := net.Listen("tcp", addr) + if e != nil { + return e + } + return srv.Serve(l) +} + +// Serve accepts incoming connections on the Listener l, creating a +// new service thread for each. The service threads read requests and +// then call srv.Handler to reply to them. +func (srv *Server) Serve(l net.Listener) error { + defer l.Close() + for { + rw, e := l.Accept() + if e != nil { + if ne, ok := e.(net.Error); ok && ne.Temporary() { + log.Printf("http: Accept error: %v", e) + continue + } + return e + } + if srv.ReadTimeout != 0 { + rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) + } + if srv.WriteTimeout != 0 { + rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout)) + } + c, err := srv.newConn(rw) + if err != nil { + continue + } + go c.serve() + } + panic("not reached") +} + +// ListenAndServe listens on the TCP network address addr +// and then calls Serve with handler to handle requests +// on incoming connections. Handler is typically nil, +// in which case the DefaultServeMux is used. +// +// A trivial example server is: +// +// package main +// +// import ( +// "io" +// "net/http" +// "log" +// ) +// +// // hello world, the web server +// func HelloServer(w http.ResponseWriter, req *http.Request) { +// io.WriteString(w, "hello, world!\n") +// } +// +// func main() { +// http.HandleFunc("/hello", HelloServer) +// err := http.ListenAndServe(":12345", nil) +// if err != nil { +// log.Fatal("ListenAndServe: ", err) +// } +// } +func ListenAndServe(addr string, handler Handler) error { + server := &Server{Addr: addr, Handler: handler} + return server.ListenAndServe() +} + +// ListenAndServeTLS acts identically to ListenAndServe, except that it +// expects HTTPS connections. Additionally, files containing a certificate and +// matching private key for the server must be provided. If the certificate +// is signed by a certificate authority, the certFile should be the concatenation +// of the server's certificate followed by the CA's certificate. +// +// A trivial example server is: +// +// import ( +// "log" +// "net/http" +// ) +// +// func handler(w http.ResponseWriter, req *http.Request) { +// w.Header().Set("Content-Type", "text/plain") +// w.Write([]byte("This is an example server.\n")) +// } +// +// func main() { +// http.HandleFunc("/", handler) +// log.Printf("About to listen on 10443. Go to https://127.0.0.1:10443/") +// err := http.ListenAndServeTLS(":10443", "cert.pem", "key.pem", nil) +// if err != nil { +// log.Fatal(err) +// } +// } +// +// One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem. +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Handler) error { + server := &Server{Addr: addr, Handler: handler} + return server.ListenAndServeTLS(certFile, keyFile) +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and +// then calls Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for +// the server must be provided. If the certificate is signed by a +// certificate authority, the certFile should be the concatenation +// of the server's certificate followed by the CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { + addr := s.Addr + if addr == "" { + addr = ":https" + } + config := &tls.Config{ + Rand: rand.Reader, + NextProtos: []string{"http/1.1"}, + } + + var err error + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + + conn, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + tlsListener := tls.NewListener(conn, config) + return s.Serve(tlsListener) +} + +// TimeoutHandler returns a Handler that runs h with the given time limit. +// +// The new Handler calls h.ServeHTTP to handle each request, but if a +// call runs for more than ns nanoseconds, the handler responds with +// a 503 Service Unavailable error and the given message in its body. +// (If msg is empty, a suitable default message will be sent.) +// After such a timeout, writes by h to its ResponseWriter will return +// ErrHandlerTimeout. +func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler { + f := func() <-chan time.Time { + return time.After(dt) + } + return &timeoutHandler{h, f, msg} +} + +// ErrHandlerTimeout is returned on ResponseWriter Write calls +// in handlers which have timed out. +var ErrHandlerTimeout = errors.New("http: Handler timeout") + +type timeoutHandler struct { + handler Handler + timeout func() <-chan time.Time // returns channel producing a timeout + body string +} + +func (h *timeoutHandler) errorBody() string { + if h.body != "" { + return h.body + } + return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>" +} + +func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { + done := make(chan bool) + tw := &timeoutWriter{w: w} + go func() { + h.handler.ServeHTTP(tw, r) + done <- true + }() + select { + case <-done: + return + case <-h.timeout(): + tw.mu.Lock() + defer tw.mu.Unlock() + if !tw.wroteHeader { + tw.w.WriteHeader(StatusServiceUnavailable) + tw.w.Write([]byte(h.errorBody())) + } + tw.timedOut = true + } +} + +type timeoutWriter struct { + w ResponseWriter + + mu sync.Mutex + timedOut bool + wroteHeader bool +} + +func (tw *timeoutWriter) Header() Header { + return tw.w.Header() +} + +func (tw *timeoutWriter) Write(p []byte) (int, error) { + tw.mu.Lock() + timedOut := tw.timedOut + tw.mu.Unlock() + if timedOut { + return 0, ErrHandlerTimeout + } + return tw.w.Write(p) +} + +func (tw *timeoutWriter) WriteHeader(code int) { + tw.mu.Lock() + if tw.timedOut || tw.wroteHeader { + tw.mu.Unlock() + return + } + tw.wroteHeader = true + tw.mu.Unlock() + tw.w.WriteHeader(code) +} diff --git a/src/pkg/net/http/sniff.go b/src/pkg/net/http/sniff.go new file mode 100644 index 000000000..c1c78e241 --- /dev/null +++ b/src/pkg/net/http/sniff.go @@ -0,0 +1,214 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "encoding/binary" +) + +// Content-type sniffing algorithm. +// References in this file refer to this draft specification: +// http://mimesniff.spec.whatwg.org/ + +// The algorithm prefers to use sniffLen bytes to make its decision. +const sniffLen = 512 + +// DetectContentType returns the sniffed Content-Type string +// for the given data. This function always returns a valid MIME type. +func DetectContentType(data []byte) string { + if len(data) > sniffLen { + data = data[:sniffLen] + } + + // Index of the first non-whitespace byte in data. + firstNonWS := 0 + for ; firstNonWS < len(data) && isWS(data[firstNonWS]); firstNonWS++ { + } + + for _, sig := range sniffSignatures { + if ct := sig.match(data, firstNonWS); ct != "" { + return ct + } + } + + return "application/octet-stream" // fallback +} + +func isWS(b byte) bool { + return bytes.IndexByte([]byte("\t\n\x0C\r "), b) != -1 +} + +type sniffSig interface { + // match returns the MIME type of the data, or "" if unknown. + match(data []byte, firstNonWS int) string +} + +// Data matching the table in section 6. +var sniffSignatures = []sniffSig{ + htmlSig("<!DOCTYPE HTML"), + htmlSig("<HTML"), + htmlSig("<HEAD"), + htmlSig("<SCRIPT"), + htmlSig("<IFRAME"), + htmlSig("<H1"), + htmlSig("<DIV"), + htmlSig("<FONT"), + htmlSig("<TABLE"), + htmlSig("<A"), + htmlSig("<STYLE"), + htmlSig("<TITLE"), + htmlSig("<B"), + htmlSig("<BODY"), + htmlSig("<BR"), + htmlSig("<P"), + htmlSig("<!--"), + + &maskedSig{mask: []byte("\xFF\xFF\xFF\xFF\xFF"), pat: []byte("<?xml"), skipWS: true, ct: "text/xml; charset=utf-8"}, + + &exactSig{[]byte("%PDF-"), "application/pdf"}, + &exactSig{[]byte("%!PS-Adobe-"), "application/postscript"}, + + // UTF BOMs. + &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFE\xFF\x00\x00"), ct: "text/plain; charset=utf-16be"}, + &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFF\xFE\x00\x00"), ct: "text/plain; charset=utf-16le"}, + &maskedSig{mask: []byte("\xFF\xFF\xFF\x00"), pat: []byte("\xEF\xBB\xBF\x00"), ct: "text/plain; charset=utf-8"}, + + &exactSig{[]byte("GIF87a"), "image/gif"}, + &exactSig{[]byte("GIF89a"), "image/gif"}, + &exactSig{[]byte("\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"), "image/png"}, + &exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"}, + &exactSig{[]byte("BM"), "image/bmp"}, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"), + pat: []byte("RIFF\x00\x00\x00\x00WEBPVP"), + ct: "image/webp", + }, + &exactSig{[]byte("\x00\x00\x01\x00"), "image/vnd.microsoft.icon"}, + &exactSig{[]byte("\x4F\x67\x67\x53\x00"), "application/ogg"}, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"), + pat: []byte("RIFF\x00\x00\x00\x00WAVE"), + ct: "audio/wave", + }, + &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"}, + &exactSig{[]byte("\x52\x61\x72\x20\x1A\x07\x00"), "application/x-rar-compressed"}, + &exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"}, + &exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"}, + + // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4. + //mp4Sig(0), + + textSig(0), // should be last +} + +type exactSig struct { + sig []byte + ct string +} + +func (e *exactSig) match(data []byte, firstNonWS int) string { + if bytes.HasPrefix(data, e.sig) { + return e.ct + } + return "" +} + +type maskedSig struct { + mask, pat []byte + skipWS bool + ct string +} + +func (m *maskedSig) match(data []byte, firstNonWS int) string { + if m.skipWS { + data = data[firstNonWS:] + } + if len(data) < len(m.mask) { + return "" + } + for i, mask := range m.mask { + db := data[i] & mask + if db != m.pat[i] { + return "" + } + } + return m.ct +} + +type htmlSig []byte + +func (h htmlSig) match(data []byte, firstNonWS int) string { + data = data[firstNonWS:] + if len(data) < len(h)+1 { + return "" + } + for i, b := range h { + db := data[i] + if 'A' <= b && b <= 'Z' { + db &= 0xDF + } + if b != db { + return "" + } + } + // Next byte must be space or right angle bracket. + if db := data[len(h)]; db != ' ' && db != '>' { + return "" + } + return "text/html; charset=utf-8" +} + +type mp4Sig int + +func (mp4Sig) match(data []byte, firstNonWS int) string { + // c.f. section 6.1. + if len(data) < 8 { + return "" + } + boxSize := int(binary.BigEndian.Uint32(data[:4])) + if boxSize%4 != 0 || len(data) < boxSize { + return "" + } + if !bytes.Equal(data[4:8], []byte("ftyp")) { + return "" + } + for st := 8; st < boxSize; st += 4 { + if st == 12 { + // minor version number + continue + } + seg := string(data[st : st+3]) + switch seg { + case "mp4", "iso", "M4V", "M4P", "M4B": + return "video/mp4" + /* The remainder are not in the spec. + case "M4A": + return "audio/mp4" + case "3gp": + return "video/3gpp" + case "jp2": + return "image/jp2" // JPEG 2000 + */ + } + } + return "" +} + +type textSig int + +func (textSig) match(data []byte, firstNonWS int) string { + // c.f. section 5, step 4. + for _, b := range data[firstNonWS:] { + switch { + case 0x00 <= b && b <= 0x08, + b == 0x0B, + 0x0E <= b && b <= 0x1A, + 0x1C <= b && b <= 0x1F: + return "" + } + } + return "text/plain; charset=utf-8" +} diff --git a/src/pkg/net/http/sniff_test.go b/src/pkg/net/http/sniff_test.go new file mode 100644 index 000000000..6efa8ce1c --- /dev/null +++ b/src/pkg/net/http/sniff_test.go @@ -0,0 +1,137 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + . "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" +) + +var sniffTests = []struct { + desc string + data []byte + contentType string +}{ + // Some nonsense. + {"Empty", []byte{}, "text/plain; charset=utf-8"}, + {"Binary", []byte{1, 2, 3}, "application/octet-stream"}, + + {"HTML document #1", []byte(`<HtMl><bOdY>blah blah blah</body></html>`), "text/html; charset=utf-8"}, + {"HTML document #2", []byte(`<HTML></HTML>`), "text/html; charset=utf-8"}, + {"HTML document #3 (leading whitespace)", []byte(` <!DOCTYPE HTML>...`), "text/html; charset=utf-8"}, + {"HTML document #4 (leading CRLF)", []byte("\r\n<html>..."), "text/html; charset=utf-8"}, + + {"Plain text", []byte(`This is not HTML. It has ☃ though.`), "text/plain; charset=utf-8"}, + + {"XML", []byte("\n<?xml!"), "text/xml; charset=utf-8"}, + + // Image types. + {"GIF 87a", []byte(`GIF87a`), "image/gif"}, + {"GIF 89a", []byte(`GIF89a...`), "image/gif"}, + + // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4. + //{"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"}, + //{"MP4 audio", []byte("\x00\x00\x00\x20ftypM4A \x00\x00\x00\x00M4A mp42isom\x00\x00\x00\x00"), "audio/mp4"}, +} + +func TestDetectContentType(t *testing.T) { + for _, tt := range sniffTests { + ct := DetectContentType(tt.data) + if ct != tt.contentType { + t.Errorf("%v: DetectContentType = %q, want %q", tt.desc, ct, tt.contentType) + } + } +} + +func TestServerContentType(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + i, _ := strconv.Atoi(r.FormValue("i")) + tt := sniffTests[i] + n, err := w.Write(tt.data) + if n != len(tt.data) || err != nil { + log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data)) + } + })) + defer ts.Close() + + for i, tt := range sniffTests { + resp, err := Get(ts.URL + "/?i=" + strconv.Itoa(i)) + if err != nil { + t.Errorf("%v: %v", tt.desc, err) + continue + } + if ct := resp.Header.Get("Content-Type"); ct != tt.contentType { + t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, tt.contentType) + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("%v: reading body: %v", tt.desc, err) + } else if !bytes.Equal(data, tt.data) { + t.Errorf("%v: data is %q, want %q", tt.desc, data, tt.data) + } + resp.Body.Close() + } +} + +func TestContentTypeWithCopy(t *testing.T) { + const ( + input = "\n<html>\n\t<head>\n" + expected = "text/html; charset=utf-8" + ) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // Use io.Copy from a bytes.Buffer to trigger ReadFrom. + buf := bytes.NewBuffer([]byte(input)) + n, err := io.Copy(w, buf) + if int(n) != len(input) || err != nil { + t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input)) + } + })) + defer ts.Close() + + resp, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + if ct := resp.Header.Get("Content-Type"); ct != expected { + t.Errorf("Content-Type = %q, want %q", ct, expected) + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("reading body: %v", err) + } else if !bytes.Equal(data, []byte(input)) { + t.Errorf("data is %q, want %q", data, input) + } + resp.Body.Close() +} + +func TestSniffWriteSize(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + size, _ := strconv.Atoi(r.FormValue("size")) + written, err := io.WriteString(w, strings.Repeat("a", size)) + if err != nil { + t.Errorf("write of %d bytes: %v", size, err) + return + } + if written != size { + t.Errorf("write of %d bytes wrote %d bytes", size, written) + } + })) + defer ts.Close() + for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} { + _, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size)) + if err != nil { + t.Fatalf("size %d: %v", size, err) + } + } +} diff --git a/src/pkg/net/http/status.go b/src/pkg/net/http/status.go new file mode 100644 index 000000000..b6e2d65c6 --- /dev/null +++ b/src/pkg/net/http/status.go @@ -0,0 +1,106 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +// HTTP status codes, defined in RFC 2616. +const ( + StatusContinue = 100 + StatusSwitchingProtocols = 101 + + StatusOK = 200 + StatusCreated = 201 + StatusAccepted = 202 + StatusNonAuthoritativeInfo = 203 + StatusNoContent = 204 + StatusResetContent = 205 + StatusPartialContent = 206 + + StatusMultipleChoices = 300 + StatusMovedPermanently = 301 + StatusFound = 302 + StatusSeeOther = 303 + StatusNotModified = 304 + StatusUseProxy = 305 + StatusTemporaryRedirect = 307 + + StatusBadRequest = 400 + StatusUnauthorized = 401 + StatusPaymentRequired = 402 + StatusForbidden = 403 + StatusNotFound = 404 + StatusMethodNotAllowed = 405 + StatusNotAcceptable = 406 + StatusProxyAuthRequired = 407 + StatusRequestTimeout = 408 + StatusConflict = 409 + StatusGone = 410 + StatusLengthRequired = 411 + StatusPreconditionFailed = 412 + StatusRequestEntityTooLarge = 413 + StatusRequestURITooLong = 414 + StatusUnsupportedMediaType = 415 + StatusRequestedRangeNotSatisfiable = 416 + StatusExpectationFailed = 417 + + StatusInternalServerError = 500 + StatusNotImplemented = 501 + StatusBadGateway = 502 + StatusServiceUnavailable = 503 + StatusGatewayTimeout = 504 + StatusHTTPVersionNotSupported = 505 +) + +var statusText = map[int]string{ + StatusContinue: "Continue", + StatusSwitchingProtocols: "Switching Protocols", + + StatusOK: "OK", + StatusCreated: "Created", + StatusAccepted: "Accepted", + StatusNonAuthoritativeInfo: "Non-Authoritative Information", + StatusNoContent: "No Content", + StatusResetContent: "Reset Content", + StatusPartialContent: "Partial Content", + + StatusMultipleChoices: "Multiple Choices", + StatusMovedPermanently: "Moved Permanently", + StatusFound: "Found", + StatusSeeOther: "See Other", + StatusNotModified: "Not Modified", + StatusUseProxy: "Use Proxy", + StatusTemporaryRedirect: "Temporary Redirect", + + StatusBadRequest: "Bad Request", + StatusUnauthorized: "Unauthorized", + StatusPaymentRequired: "Payment Required", + StatusForbidden: "Forbidden", + StatusNotFound: "Not Found", + StatusMethodNotAllowed: "Method Not Allowed", + StatusNotAcceptable: "Not Acceptable", + StatusProxyAuthRequired: "Proxy Authentication Required", + StatusRequestTimeout: "Request Timeout", + StatusConflict: "Conflict", + StatusGone: "Gone", + StatusLengthRequired: "Length Required", + StatusPreconditionFailed: "Precondition Failed", + StatusRequestEntityTooLarge: "Request Entity Too Large", + StatusRequestURITooLong: "Request URI Too Long", + StatusUnsupportedMediaType: "Unsupported Media Type", + StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable", + StatusExpectationFailed: "Expectation Failed", + + StatusInternalServerError: "Internal Server Error", + StatusNotImplemented: "Not Implemented", + StatusBadGateway: "Bad Gateway", + StatusServiceUnavailable: "Service Unavailable", + StatusGatewayTimeout: "Gateway Timeout", + StatusHTTPVersionNotSupported: "HTTP Version Not Supported", +} + +// StatusText returns a text for the HTTP status code. It returns the empty +// string if the code is unknown. +func StatusText(code int) string { + return statusText[code] +} diff --git a/src/pkg/net/http/testdata/file b/src/pkg/net/http/testdata/file new file mode 100644 index 000000000..11f11f9be --- /dev/null +++ b/src/pkg/net/http/testdata/file @@ -0,0 +1 @@ +0123456789 diff --git a/src/pkg/net/http/testdata/index.html b/src/pkg/net/http/testdata/index.html new file mode 100644 index 000000000..da8e1e93d --- /dev/null +++ b/src/pkg/net/http/testdata/index.html @@ -0,0 +1 @@ +index.html says hello diff --git a/src/pkg/net/http/testdata/style.css b/src/pkg/net/http/testdata/style.css new file mode 100644 index 000000000..208d16d42 --- /dev/null +++ b/src/pkg/net/http/testdata/style.css @@ -0,0 +1 @@ +body {} diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go new file mode 100644 index 000000000..ef9564af9 --- /dev/null +++ b/src/pkg/net/http/transfer.go @@ -0,0 +1,630 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net/textproto" + "strconv" + "strings" +) + +// transferWriter inspects the fields of a user-supplied Request or Response, +// sanitizes them without changing the user object and provides methods for +// writing the respective header, body and trailer in wire format. +type transferWriter struct { + Method string + Body io.Reader + BodyCloser io.Closer + ResponseToHEAD bool + ContentLength int64 // -1 means unknown, 0 means exactly none + Close bool + TransferEncoding []string + Trailer Header +} + +func newTransferWriter(r interface{}) (t *transferWriter, err error) { + t = &transferWriter{} + + // Extract relevant fields + atLeastHTTP11 := false + switch rr := r.(type) { + case *Request: + if rr.ContentLength != 0 && rr.Body == nil { + return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength) + } + t.Method = rr.Method + t.Body = rr.Body + t.BodyCloser = rr.Body + t.ContentLength = rr.ContentLength + t.Close = rr.Close + t.TransferEncoding = rr.TransferEncoding + t.Trailer = rr.Trailer + atLeastHTTP11 = rr.ProtoAtLeast(1, 1) + if t.Body != nil && len(t.TransferEncoding) == 0 && atLeastHTTP11 { + if t.ContentLength == 0 { + // Test to see if it's actually zero or just unset. + var buf [1]byte + n, _ := io.ReadFull(t.Body, buf[:]) + if n == 1 { + // Oh, guess there is data in this Body Reader after all. + // The ContentLength field just wasn't set. + // Stich the Body back together again, re-attaching our + // consumed byte. + t.ContentLength = -1 + t.Body = io.MultiReader(bytes.NewBuffer(buf[:]), t.Body) + } else { + // Body is actually empty. + t.Body = nil + t.BodyCloser = nil + } + } + if t.ContentLength < 0 { + t.TransferEncoding = []string{"chunked"} + } + } + case *Response: + t.Method = rr.Request.Method + t.Body = rr.Body + t.BodyCloser = rr.Body + t.ContentLength = rr.ContentLength + t.Close = rr.Close + t.TransferEncoding = rr.TransferEncoding + t.Trailer = rr.Trailer + atLeastHTTP11 = rr.ProtoAtLeast(1, 1) + t.ResponseToHEAD = noBodyExpected(rr.Request.Method) + } + + // Sanitize Body,ContentLength,TransferEncoding + if t.ResponseToHEAD { + t.Body = nil + t.TransferEncoding = nil + // ContentLength is expected to hold Content-Length + if t.ContentLength < 0 { + return nil, ErrMissingContentLength + } + } else { + if !atLeastHTTP11 || t.Body == nil { + t.TransferEncoding = nil + } + if chunked(t.TransferEncoding) { + t.ContentLength = -1 + } else if t.Body == nil { // no chunking, no body + t.ContentLength = 0 + } + } + + // Sanitize Trailer + if !chunked(t.TransferEncoding) { + t.Trailer = nil + } + + return t, nil +} + +func noBodyExpected(requestMethod string) bool { + return requestMethod == "HEAD" +} + +func (t *transferWriter) shouldSendContentLength() bool { + if chunked(t.TransferEncoding) { + return false + } + if t.ContentLength > 0 { + return true + } + if t.ResponseToHEAD { + return true + } + // Many servers expect a Content-Length for these methods + if t.Method == "POST" || t.Method == "PUT" { + return true + } + if t.ContentLength == 0 && isIdentity(t.TransferEncoding) { + return true + } + + return false +} + +func (t *transferWriter) WriteHeader(w io.Writer) (err error) { + if t.Close { + _, err = io.WriteString(w, "Connection: close\r\n") + if err != nil { + return + } + } + + // Write Content-Length and/or Transfer-Encoding whose values are a + // function of the sanitized field triple (Body, ContentLength, + // TransferEncoding) + if t.shouldSendContentLength() { + io.WriteString(w, "Content-Length: ") + _, err = io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n") + if err != nil { + return + } + } else if chunked(t.TransferEncoding) { + _, err = io.WriteString(w, "Transfer-Encoding: chunked\r\n") + if err != nil { + return + } + } + + // Write Trailer header + if t.Trailer != nil { + // TODO: At some point, there should be a generic mechanism for + // writing long headers, using HTTP line splitting + io.WriteString(w, "Trailer: ") + needComma := false + for k := range t.Trailer { + k = CanonicalHeaderKey(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return &badStringError{"invalid Trailer key", k} + } + if needComma { + io.WriteString(w, ",") + } + io.WriteString(w, k) + needComma = true + } + _, err = io.WriteString(w, "\r\n") + } + + return +} + +func (t *transferWriter) WriteBody(w io.Writer) (err error) { + var ncopy int64 + + // Write body + if t.Body != nil { + if chunked(t.TransferEncoding) { + cw := newChunkedWriter(w) + _, err = io.Copy(cw, t.Body) + if err == nil { + err = cw.Close() + } + } else if t.ContentLength == -1 { + ncopy, err = io.Copy(w, t.Body) + } else { + ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) + nextra, err := io.Copy(ioutil.Discard, t.Body) + if err != nil { + return err + } + ncopy += nextra + } + if err != nil { + return err + } + if err = t.BodyCloser.Close(); err != nil { + return err + } + } + + if t.ContentLength != -1 && t.ContentLength != ncopy { + return fmt.Errorf("http: Request.ContentLength=%d with Body length %d", + t.ContentLength, ncopy) + } + + // TODO(petar): Place trailer writer code here. + if chunked(t.TransferEncoding) { + // Last chunk, empty trailer + _, err = io.WriteString(w, "\r\n") + } + + return +} + +type transferReader struct { + // Input + Header Header + StatusCode int + RequestMethod string + ProtoMajor int + ProtoMinor int + // Output + Body io.ReadCloser + ContentLength int64 + TransferEncoding []string + Close bool + Trailer Header +} + +// bodyAllowedForStatus returns whether a given response status code +// permits a body. See RFC2616, section 4.4. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +// msg is *Request or *Response. +func readTransfer(msg interface{}, r *bufio.Reader) (err error) { + t := &transferReader{} + + // Unify input + isResponse := false + switch rr := msg.(type) { + case *Response: + t.Header = rr.Header + t.StatusCode = rr.StatusCode + t.RequestMethod = rr.Request.Method + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header) + isResponse = true + case *Request: + t.Header = rr.Header + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + // Transfer semantics for Requests are exactly like those for + // Responses with status code 200, responding to a GET method + t.StatusCode = 200 + t.RequestMethod = "GET" + default: + panic("unexpected type") + } + + // Default to HTTP/1.1 + if t.ProtoMajor == 0 && t.ProtoMinor == 0 { + t.ProtoMajor, t.ProtoMinor = 1, 1 + } + + // Transfer encoding, content length + t.TransferEncoding, err = fixTransferEncoding(t.RequestMethod, t.Header) + if err != nil { + return err + } + + t.ContentLength, err = fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding) + if err != nil { + return err + } + + // Trailer + t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding) + if err != nil { + return err + } + + // If there is no Content-Length or chunked Transfer-Encoding on a *Response + // and the status is not 1xx, 204 or 304, then the body is unbounded. + // See RFC2616, section 4.4. + switch msg.(type) { + case *Response: + if t.ContentLength == -1 && + !chunked(t.TransferEncoding) && + bodyAllowedForStatus(t.StatusCode) { + // Unbounded body. + t.Close = true + } + } + + // Prepare body reader. ContentLength < 0 means chunked encoding + // or close connection when finished, since multipart is not supported yet + switch { + case chunked(t.TransferEncoding): + t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} + case t.ContentLength >= 0: + // TODO: limit the Content-Length. This is an easy DoS vector. + t.Body = &body{Reader: io.LimitReader(r, t.ContentLength), closing: t.Close} + default: + // t.ContentLength < 0, i.e. "Content-Length" not mentioned in header + if t.Close { + // Close semantics (i.e. HTTP/1.0) + t.Body = &body{Reader: r, closing: t.Close} + } else { + // Persistent connection (i.e. HTTP/1.1) + t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close} + } + } + + // Unify output + switch rr := msg.(type) { + case *Request: + rr.Body = t.Body + rr.ContentLength = t.ContentLength + rr.TransferEncoding = t.TransferEncoding + rr.Close = t.Close + rr.Trailer = t.Trailer + case *Response: + rr.Body = t.Body + rr.ContentLength = t.ContentLength + rr.TransferEncoding = t.TransferEncoding + rr.Close = t.Close + rr.Trailer = t.Trailer + } + + return nil +} + +// Checks whether chunked is part of the encodings stack +func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } + +// Checks whether the encoding is explicitly "identity". +func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" } + +// Sanitize transfer encoding +func fixTransferEncoding(requestMethod string, header Header) ([]string, error) { + raw, present := header["Transfer-Encoding"] + if !present { + return nil, nil + } + + delete(header, "Transfer-Encoding") + + // Head responses have no bodies, so the transfer encoding + // should be ignored. + if requestMethod == "HEAD" { + return nil, nil + } + + encodings := strings.Split(raw[0], ",") + te := make([]string, 0, len(encodings)) + // TODO: Even though we only support "identity" and "chunked" + // encodings, the loop below is designed with foresight. One + // invariant that must be maintained is that, if present, + // chunked encoding must always come first. + for _, encoding := range encodings { + encoding = strings.ToLower(strings.TrimSpace(encoding)) + // "identity" encoding is not recored + if encoding == "identity" { + break + } + if encoding != "chunked" { + return nil, &badStringError{"unsupported transfer encoding", encoding} + } + te = te[0 : len(te)+1] + te[len(te)-1] = encoding + } + if len(te) > 1 { + return nil, &badStringError{"too many transfer encodings", strings.Join(te, ",")} + } + if len(te) > 0 { + // Chunked encoding trumps Content-Length. See RFC 2616 + // Section 4.4. Currently len(te) > 0 implies chunked + // encoding. + delete(header, "Content-Length") + return te, nil + } + + return nil, nil +} + +// Determine the expected body length, using RFC 2616 Section 4.4. This +// function is not a method, because ultimately it should be shared by +// ReadResponse and ReadRequest. +func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, error) { + + // Logic based on response type or status + if noBodyExpected(requestMethod) { + return 0, nil + } + if status/100 == 1 { + return 0, nil + } + switch status { + case 204, 304: + return 0, nil + } + + // Logic based on Transfer-Encoding + if chunked(te) { + return -1, nil + } + + // Logic based on Content-Length + cl := strings.TrimSpace(header.Get("Content-Length")) + if cl != "" { + n, err := strconv.ParseInt(cl, 10, 64) + if err != nil || n < 0 { + return -1, &badStringError{"bad Content-Length", cl} + } + return n, nil + } else { + header.Del("Content-Length") + } + + if !isResponse && requestMethod == "GET" { + // RFC 2616 doesn't explicitly permit nor forbid an + // entity-body on a GET request so we permit one if + // declared, but we default to 0 here (not -1 below) + // if there's no mention of a body. + return 0, nil + } + + // Logic based on media type. The purpose of the following code is just + // to detect whether the unsupported "multipart/byteranges" is being + // used. A proper Content-Type parser is needed in the future. + if strings.Contains(strings.ToLower(header.Get("Content-Type")), "multipart/byteranges") { + return -1, ErrNotSupported + } + + // Body-EOF logic based on other methods (like closing, or chunked coding) + return -1, nil +} + +// Determine whether to hang up after sending a request and body, or +// receiving a response and body +// 'header' is the request headers +func shouldClose(major, minor int, header Header) bool { + if major < 1 { + return true + } else if major == 1 && minor == 0 { + if !strings.Contains(strings.ToLower(header.Get("Connection")), "keep-alive") { + return true + } + return false + } else { + // TODO: Should split on commas, toss surrounding white space, + // and check each field. + if strings.ToLower(header.Get("Connection")) == "close" { + header.Del("Connection") + return true + } + } + return false +} + +// Parse the trailer header +func fixTrailer(header Header, te []string) (Header, error) { + raw := header.Get("Trailer") + if raw == "" { + return nil, nil + } + + header.Del("Trailer") + trailer := make(Header) + keys := strings.Split(raw, ",") + for _, key := range keys { + key = CanonicalHeaderKey(strings.TrimSpace(key)) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + return nil, &badStringError{"bad trailer key", key} + } + trailer.Del(key) + } + if len(trailer) == 0 { + return nil, nil + } + if !chunked(te) { + // Trailer and no chunking + return nil, ErrUnexpectedTrailer + } + return trailer, nil +} + +// body turns a Reader into a ReadCloser. +// Close ensures that the body has been fully read +// and then reads the trailer if necessary. +type body struct { + io.Reader + hdr interface{} // non-nil (Response or Request) value means read trailer + r *bufio.Reader // underlying wire-format reader for the trailer + closing bool // is the connection to be closed after reading body? + closed bool + + res *response // response writer for server requests, else nil +} + +// ErrBodyReadAfterClose is returned when reading a Request Body after +// the body has been closed. This typically happens when the body is +// read after an HTTP Handler calls WriteHeader or Write on its +// ResponseWriter. +var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed request Body") + +func (b *body) Read(p []byte) (n int, err error) { + if b.closed { + return 0, ErrBodyReadAfterClose + } + n, err = b.Reader.Read(p) + + // Read the final trailer once we hit EOF. + if err == io.EOF && b.hdr != nil { + if e := b.readTrailer(); e != nil { + err = e + } + b.hdr = nil + } + return n, err +} + +var ( + singleCRLF = []byte("\r\n") + doubleCRLF = []byte("\r\n\r\n") +) + +func seeUpcomingDoubleCRLF(r *bufio.Reader) bool { + for peekSize := 4; ; peekSize++ { + // This loop stops when Peek returns an error, + // which it does when r's buffer has been filled. + buf, err := r.Peek(peekSize) + if bytes.HasSuffix(buf, doubleCRLF) { + return true + } + if err != nil { + break + } + } + return false +} + +func (b *body) readTrailer() error { + // The common case, since nobody uses trailers. + buf, _ := b.r.Peek(2) + if bytes.Equal(buf, singleCRLF) { + b.r.ReadByte() + b.r.ReadByte() + return nil + } + + // Make sure there's a header terminator coming up, to prevent + // a DoS with an unbounded size Trailer. It's not easy to + // slip in a LimitReader here, as textproto.NewReader requires + // a concrete *bufio.Reader. Also, we can't get all the way + // back up to our conn's LimitedReader that *might* be backing + // this bufio.Reader. Instead, a hack: we iteratively Peek up + // to the bufio.Reader's max size, looking for a double CRLF. + // This limits the trailer to the underlying buffer size, typically 4kB. + if !seeUpcomingDoubleCRLF(b.r) { + return errors.New("http: suspiciously long trailer after chunked body") + } + + hdr, err := textproto.NewReader(b.r).ReadMIMEHeader() + if err != nil { + return err + } + switch rr := b.hdr.(type) { + case *Request: + rr.Trailer = Header(hdr) + case *Response: + rr.Trailer = Header(hdr) + } + return nil +} + +func (b *body) Close() error { + if b.closed { + return nil + } + defer func() { + b.closed = true + }() + if b.hdr == nil && b.closing { + // no trailer and closing the connection next. + // no point in reading to EOF. + return nil + } + + // In a server request, don't continue reading from the client + // if we've already hit the maximum body size set by the + // handler. If this is set, that also means the TCP connection + // is about to be closed, so getting to the next HTTP request + // in the stream is not necessary. + if b.res != nil && b.res.requestBodyLimitHit { + return nil + } + + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + if _, err := io.Copy(ioutil.Discard, b); err != nil { + return err + } + return nil +} diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go new file mode 100644 index 000000000..4de070f01 --- /dev/null +++ b/src/pkg/net/http/transport.go @@ -0,0 +1,736 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP client implementation. See RFC 2616. +// +// This is the low-level Transport implementation of RoundTripper. +// The high-level interface is in client.go. + +package http + +import ( + "bufio" + "compress/gzip" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/url" + "os" + "strings" + "sync" +) + +// DefaultTransport is the default implementation of Transport and is +// used by DefaultClient. It establishes a new network connection for +// each call to Do and uses HTTP proxies as directed by the +// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy) +// environment variables. +var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment} + +// DefaultMaxIdleConnsPerHost is the default value of Transport's +// MaxIdleConnsPerHost. +const DefaultMaxIdleConnsPerHost = 2 + +// Transport is an implementation of RoundTripper that supports http, +// https, and http proxies (for either http or https with CONNECT). +// Transport can also cache connections for future re-use. +type Transport struct { + lk sync.Mutex + idleConn map[string][]*persistConn + altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper + + // TODO: tunable on global max cached connections + // TODO: tunable on timeout on cached connections + // TODO: optional pipelining + + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*Request) (*url.URL, error) + + // Dial specifies the dial function for creating TCP + // connections. + // If Dial is nil, net.Dial is used. + Dial func(net, addr string) (c net.Conn, err error) + + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + + DisableKeepAlives bool + DisableCompression bool + + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle + // (keep-alive) to keep to keep per-host. If zero, + // DefaultMaxIdleConnsPerHost is used. + MaxIdleConnsPerHost int +} + +// ProxyFromEnvironment returns the URL of the proxy to use for a +// given request, as indicated by the environment variables +// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy). +// Either URL or an error is returned. +func ProxyFromEnvironment(req *Request) (*url.URL, error) { + proxy := getenvEitherCase("HTTP_PROXY") + if proxy == "" { + return nil, nil + } + if !useProxy(canonicalAddr(req.URL)) { + return nil, nil + } + proxyURL, err := url.ParseRequest(proxy) + if err != nil { + return nil, errors.New("invalid proxy address") + } + if proxyURL.Host == "" { + proxyURL, err = url.ParseRequest("http://" + proxy) + if err != nil { + return nil, errors.New("invalid proxy address") + } + } + return proxyURL, nil +} + +// ProxyURL returns a proxy function (for use in a Transport) +// that always returns the same URL. +func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) { + return func(*Request) (*url.URL, error) { + return fixedURL, nil + } +} + +// transportRequest is a wrapper around a *Request that adds +// optional extra headers to write. +type transportRequest struct { + *Request // original request, not to be mutated + extra Header // extra headers to write, or nil +} + +func (tr *transportRequest) extraHeaders() Header { + if tr.extra == nil { + tr.extra = make(Header) + } + return tr.extra +} + +// RoundTrip implements the RoundTripper interface. +func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { + if req.URL == nil { + return nil, errors.New("http: nil Request.URL") + } + if req.Header == nil { + return nil, errors.New("http: nil Request.Header") + } + if req.URL.Scheme != "http" && req.URL.Scheme != "https" { + t.lk.Lock() + var rt RoundTripper + if t.altProto != nil { + rt = t.altProto[req.URL.Scheme] + } + t.lk.Unlock() + if rt == nil { + return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} + } + return rt.RoundTrip(req) + } + treq := &transportRequest{Request: req} + cm, err := t.connectMethodForRequest(treq) + if err != nil { + return nil, err + } + + // Get the cached or newly-created connection to either the + // host (for http or https), the http proxy, or the http proxy + // pre-CONNECTed to https server. In any case, we'll be ready + // to send it requests. + pconn, err := t.getConn(cm) + if err != nil { + return nil, err + } + + return pconn.roundTrip(treq) +} + +// RegisterProtocol registers a new protocol with scheme. +// The Transport will pass requests using the given scheme to rt. +// It is rt's responsibility to simulate HTTP request semantics. +// +// RegisterProtocol can be used by other packages to provide +// implementations of protocol schemes like "ftp" or "file". +func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { + if scheme == "http" || scheme == "https" { + panic("protocol " + scheme + " already registered") + } + t.lk.Lock() + defer t.lk.Unlock() + if t.altProto == nil { + t.altProto = make(map[string]RoundTripper) + } + if _, exists := t.altProto[scheme]; exists { + panic("protocol " + scheme + " already registered") + } + t.altProto[scheme] = rt +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in +// a "keep-alive" state. It does not interrupt any connections currently +// in use. +func (t *Transport) CloseIdleConnections() { + t.lk.Lock() + defer t.lk.Unlock() + if t.idleConn == nil { + return + } + for _, conns := range t.idleConn { + for _, pconn := range conns { + pconn.close() + } + } + t.idleConn = nil +} + +// +// Private implementation past this point. +// + +func getenvEitherCase(k string) string { + if v := os.Getenv(strings.ToUpper(k)); v != "" { + return v + } + return os.Getenv(strings.ToLower(k)) +} + +func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, error) { + cm := &connectMethod{ + targetScheme: treq.URL.Scheme, + targetAddr: canonicalAddr(treq.URL), + } + if t.Proxy != nil { + var err error + cm.proxyURL, err = t.Proxy(treq.Request) + if err != nil { + return nil, err + } + } + return cm, nil +} + +// proxyAuth returns the Proxy-Authorization header to set +// on requests, if applicable. +func (cm *connectMethod) proxyAuth() string { + if cm.proxyURL == nil { + return "" + } + if u := cm.proxyURL.User; u != nil { + return "Basic " + base64.URLEncoding.EncodeToString([]byte(u.String())) + } + return "" +} + +func (t *Transport) putIdleConn(pconn *persistConn) { + t.lk.Lock() + defer t.lk.Unlock() + if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { + pconn.close() + return + } + if pconn.isBroken() { + return + } + key := pconn.cacheKey + max := t.MaxIdleConnsPerHost + if max == 0 { + max = DefaultMaxIdleConnsPerHost + } + if len(t.idleConn[key]) >= max { + pconn.close() + return + } + t.idleConn[key] = append(t.idleConn[key], pconn) +} + +func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { + t.lk.Lock() + defer t.lk.Unlock() + if t.idleConn == nil { + t.idleConn = make(map[string][]*persistConn) + } + key := cm.String() + for { + pconns, ok := t.idleConn[key] + if !ok { + return nil + } + if len(pconns) == 1 { + pconn = pconns[0] + delete(t.idleConn, key) + } else { + // 2 or more cached connections; pop last + // TODO: queue? + pconn = pconns[len(pconns)-1] + t.idleConn[key] = pconns[0 : len(pconns)-1] + } + if !pconn.isBroken() { + return + } + } + return +} + +func (t *Transport) dial(network, addr string) (c net.Conn, err error) { + if t.Dial != nil { + return t.Dial(network, addr) + } + return net.Dial(network, addr) +} + +// getConn dials and creates a new persistConn to the target as +// specified in the connectMethod. This includes doing a proxy CONNECT +// and/or setting up TLS. If this doesn't return an error, the persistConn +// is ready to write requests to. +func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { + if pc := t.getIdleConn(cm); pc != nil { + return pc, nil + } + + conn, err := t.dial("tcp", cm.addr()) + if err != nil { + if cm.proxyURL != nil { + err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) + } + return nil, err + } + + pa := cm.proxyAuth() + + pconn := &persistConn{ + t: t, + cacheKey: cm.String(), + conn: conn, + reqch: make(chan requestAndChan, 50), + } + + switch { + case cm.proxyURL == nil: + // Do nothing. + case cm.targetScheme == "http": + pconn.isProxy = true + if pa != "" { + pconn.mutateHeaderFunc = func(h Header) { + h.Set("Proxy-Authorization", pa) + } + } + case cm.targetScheme == "https": + connectReq := &Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: cm.targetAddr}, + Host: cm.targetAddr, + Header: make(Header), + } + if pa != "" { + connectReq.Header.Set("Proxy-Authorization", pa) + } + connectReq.Write(conn) + + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(conn) + resp, err := ReadResponse(br, connectReq) + if err != nil { + conn.Close() + return nil, err + } + if resp.StatusCode != 200 { + f := strings.SplitN(resp.Status, " ", 2) + conn.Close() + return nil, errors.New(f[1]) + } + } + + if cm.targetScheme == "https" { + // Initiate TLS and check remote host name against certificate. + conn = tls.Client(conn, t.TLSClientConfig) + if err = conn.(*tls.Conn).Handshake(); err != nil { + return nil, err + } + if t.TLSClientConfig == nil || !t.TLSClientConfig.InsecureSkipVerify { + if err = conn.(*tls.Conn).VerifyHostname(cm.tlsHost()); err != nil { + return nil, err + } + } + pconn.conn = conn + } + + pconn.br = bufio.NewReader(pconn.conn) + pconn.bw = bufio.NewWriter(pconn.conn) + go pconn.readLoop() + return pconn, nil +} + +// useProxy returns true if requests to addr should use a proxy, +// according to the NO_PROXY or no_proxy environment variable. +// addr is always a canonicalAddr with a host and port. +func useProxy(addr string) bool { + if len(addr) == 0 { + return true + } + host, _, err := net.SplitHostPort(addr) + if err != nil { + return false + } + if host == "localhost" { + return false + } + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() { + return false + } + } + + no_proxy := getenvEitherCase("NO_PROXY") + if no_proxy == "*" { + return false + } + + addr = strings.ToLower(strings.TrimSpace(addr)) + if hasPort(addr) { + addr = addr[:strings.LastIndex(addr, ":")] + } + + for _, p := range strings.Split(no_proxy, ",") { + p = strings.ToLower(strings.TrimSpace(p)) + if len(p) == 0 { + continue + } + if hasPort(p) { + p = p[:strings.LastIndex(p, ":")] + } + if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) { + return false + } + } + return true +} + +// connectMethod is the map key (in its String form) for keeping persistent +// TCP connections alive for subsequent HTTP requests. +// +// A connect method may be of the following types: +// +// Cache key form Description +// ----------------- ------------------------- +// ||http|foo.com http directly to server, no proxy +// ||https|foo.com https directly to server, no proxy +// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com +// http://proxy.com|http http to proxy, http to anywhere after that +// +// Note: no support to https to the proxy yet. +// +type connectMethod struct { + proxyURL *url.URL // nil for no proxy, else full proxy URL + targetScheme string // "http" or "https" + targetAddr string // Not used if proxy + http targetScheme (4th example in table) +} + +func (ck *connectMethod) String() string { + proxyStr := "" + if ck.proxyURL != nil { + proxyStr = ck.proxyURL.String() + } + return strings.Join([]string{proxyStr, ck.targetScheme, ck.targetAddr}, "|") +} + +// addr returns the first hop "host:port" to which we need to TCP connect. +func (cm *connectMethod) addr() string { + if cm.proxyURL != nil { + return canonicalAddr(cm.proxyURL) + } + return cm.targetAddr +} + +// tlsHost returns the host name to match against the peer's +// TLS certificate. +func (cm *connectMethod) tlsHost() string { + h := cm.targetAddr + if hasPort(h) { + h = h[:strings.LastIndex(h, ":")] + } + return h +} + +// persistConn wraps a connection, usually a persistent one +// (but may be used for non-keep-alive requests as well) +type persistConn struct { + t *Transport + cacheKey string // its connectMethod.String() + conn net.Conn + br *bufio.Reader // from conn + bw *bufio.Writer // to conn + reqch chan requestAndChan // written by roundTrip(); read by readLoop() + isProxy bool + + // mutateHeaderFunc is an optional func to modify extra + // headers on each outbound request before it's written. (the + // original Request given to RoundTrip is not modified) + mutateHeaderFunc func(Header) + + lk sync.Mutex // guards numExpectedResponses and broken + numExpectedResponses int + broken bool // an error has happened on this connection; marked broken so it's not reused. +} + +func (pc *persistConn) isBroken() bool { + pc.lk.Lock() + defer pc.lk.Unlock() + return pc.broken +} + +var remoteSideClosedFunc func(error) bool // or nil to use default + +func remoteSideClosed(err error) bool { + if err == io.EOF { + return true + } + if remoteSideClosedFunc != nil { + return remoteSideClosedFunc(err) + } + return false +} + +func (pc *persistConn) readLoop() { + alive := true + var lastbody io.ReadCloser // last response body, if any, read on this connection + + for alive { + pb, err := pc.br.Peek(1) + + pc.lk.Lock() + if pc.numExpectedResponses == 0 { + pc.closeLocked() + pc.lk.Unlock() + if len(pb) > 0 { + log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", + string(pb), err) + } + return + } + pc.lk.Unlock() + + rc := <-pc.reqch + + // Advance past the previous response's body, if the + // caller hasn't done so. + if lastbody != nil { + lastbody.Close() // assumed idempotent + lastbody = nil + } + resp, err := ReadResponse(pc.br, rc.req) + + if err == nil { + hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 + if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" { + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + gzReader, zerr := gzip.NewReader(resp.Body) + if zerr != nil { + pc.close() + err = zerr + } else { + resp.Body = &readFirstCloseBoth{&discardOnCloseReadCloser{gzReader}, resp.Body} + } + } + resp.Body = &bodyEOFSignal{body: resp.Body} + } + + if err != nil || resp.Close || rc.req.Close { + alive = false + } + + hasBody := resp != nil && resp.ContentLength != 0 + var waitForBodyRead chan bool + if alive { + if hasBody { + lastbody = resp.Body + waitForBodyRead = make(chan bool) + resp.Body.(*bodyEOFSignal).fn = func() { + pc.t.putIdleConn(pc) + waitForBodyRead <- true + } + } else { + // When there's no response body, we immediately + // reuse the TCP connection (putIdleConn), but + // we need to prevent ClientConn.Read from + // closing the Response.Body on the next + // loop, otherwise it might close the body + // before the client code has had a chance to + // read it (even though it'll just be 0, EOF). + lastbody = nil + + pc.t.putIdleConn(pc) + } + } + + rc.ch <- responseAndError{resp, err} + + // Wait for the just-returned response body to be fully consumed + // before we race and peek on the underlying bufio reader. + if waitForBodyRead != nil { + <-waitForBodyRead + } + } +} + +type responseAndError struct { + res *Response + err error +} + +type requestAndChan struct { + req *Request + ch chan responseAndError + + // did the Transport (as opposed to the client code) add an + // Accept-Encoding gzip header? only if it we set it do + // we transparently decode the gzip. + addedGzip bool +} + +func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { + if pc.mutateHeaderFunc != nil { + pc.mutateHeaderFunc(req.extraHeaders()) + } + + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempted to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 + requestedGzip = true + req.extraHeaders().Set("Accept-Encoding", "gzip") + } + + pc.lk.Lock() + pc.numExpectedResponses++ + pc.lk.Unlock() + + err = req.Request.write(pc.bw, pc.isProxy, req.extra) + if err != nil { + pc.close() + return + } + pc.bw.Flush() + + ch := make(chan responseAndError, 1) + pc.reqch <- requestAndChan{req.Request, ch, requestedGzip} + re := <-ch + pc.lk.Lock() + pc.numExpectedResponses-- + pc.lk.Unlock() + + return re.res, re.err +} + +func (pc *persistConn) close() { + pc.lk.Lock() + defer pc.lk.Unlock() + pc.closeLocked() +} + +func (pc *persistConn) closeLocked() { + pc.broken = true + pc.conn.Close() + pc.mutateHeaderFunc = nil +} + +var portMap = map[string]string{ + "http": "80", + "https": "443", +} + +// canonicalAddr returns url.Host but always with a ":port" suffix +func canonicalAddr(url *url.URL) string { + addr := url.Host + if !hasPort(addr) { + return addr + ":" + portMap[url.Scheme] + } + return addr +} + +func responseIsKeepAlive(res *Response) bool { + // TODO: implement. for now just always shutting down the connection. + return false +} + +// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most +// once, right before the final Read() or Close() call returns, but after +// EOF has been seen. +type bodyEOFSignal struct { + body io.ReadCloser + fn func() + isClosed bool +} + +func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { + n, err = es.body.Read(p) + if es.isClosed && n > 0 { + panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725") + } + if err == io.EOF && es.fn != nil { + es.fn() + es.fn = nil + } + return +} + +func (es *bodyEOFSignal) Close() (err error) { + if es.isClosed { + return nil + } + es.isClosed = true + err = es.body.Close() + if err == nil && es.fn != nil { + es.fn() + es.fn = nil + } + return +} + +type readFirstCloseBoth struct { + io.ReadCloser + io.Closer +} + +func (r *readFirstCloseBoth) Close() error { + if err := r.ReadCloser.Close(); err != nil { + r.Closer.Close() + return err + } + if err := r.Closer.Close(); err != nil { + return err + } + return nil +} + +// discardOnCloseReadCloser consumes all its input on Close. +type discardOnCloseReadCloser struct { + io.ReadCloser +} + +func (d *discardOnCloseReadCloser) Close() error { + io.Copy(ioutil.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed + return d.ReadCloser.Close() +} diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go new file mode 100644 index 000000000..321da52e2 --- /dev/null +++ b/src/pkg/net/http/transport_test.go @@ -0,0 +1,695 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for transport.go + +package http_test + +import ( + "bytes" + "compress/gzip" + "crypto/rand" + "fmt" + "io" + "io/ioutil" + . "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "testing" + "time" +) + +// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close +// and then verify that the final 2 responses get errors back. + +// hostPortHandler writes back the client's "host:port". +var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + if r.FormValue("close") == "true" { + w.Header().Set("Connection", "close") + } + w.Write([]byte(r.RemoteAddr)) +}) + +// Two subsequent requests and verify their response is the same. +// The response from the server is our own IP:port +func TestTransportKeepAlives(t *testing.T) { + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + for _, disableKeepAlive := range []bool{false, true} { + tr := &Transport{DisableKeepAlives: disableKeepAlive} + c := &Client{Transport: tr} + + fetch := func(n int) string { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + + bodiesDiffer := body1 != body2 + if bodiesDiffer != disableKeepAlive { + t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + disableKeepAlive, bodiesDiffer, body1, body2) + } + } +} + +func TestTransportConnectionCloseOnResponse(t *testing.T) { + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + for _, connectionClose := range []bool{false, true} { + tr := &Transport{} + c := &Client{Transport: tr} + + fetch := func(n int) string { + req := new(Request) + var err error + req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) + if err != nil { + t.Fatalf("URL parse error: %v", err) + } + req.Method = "GET" + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + + res, err := c.Do(req) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) + } + body, err := ioutil.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + bodiesDiffer := body1 != body2 + if bodiesDiffer != connectionClose { + t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + connectionClose, bodiesDiffer, body1, body2) + } + } +} + +func TestTransportConnectionCloseOnRequest(t *testing.T) { + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + for _, connectionClose := range []bool{false, true} { + tr := &Transport{} + c := &Client{Transport: tr} + + fetch := func(n int) string { + req := new(Request) + var err error + req.URL, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("URL parse error: %v", err) + } + req.Method = "GET" + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + req.Close = connectionClose + + res, err := c.Do(req) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + bodiesDiffer := body1 != body2 + if bodiesDiffer != connectionClose { + t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + connectionClose, bodiesDiffer, body1, body2) + } + } +} + +func TestTransportIdleCacheKeys(t *testing.T) { + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) + } + + resp, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + } + ioutil.ReadAll(resp.Body) + + keys := tr.IdleConnKeysForTesting() + if e, g := 1, len(keys); e != g { + t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) + } + + if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { + t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) + } + + tr.CloseIdleConnections() + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) + } +} + +func TestTransportMaxPerHostIdleConns(t *testing.T) { + resch := make(chan string) + gotReq := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + gotReq <- true + msg := <-resch + _, err := w.Write([]byte(msg)) + if err != nil { + t.Fatalf("Write: %v", err) + } + })) + defer ts.Close() + maxIdleConns := 2 + tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns} + c := &Client{Transport: tr} + + // Start 3 outstanding requests and wait for the server to get them. + // Their responses will hang until we we write to resch, though. + donech := make(chan bool) + doReq := func() { + resp, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + donech <- true + } + go doReq() + <-gotReq + go doReq() + <-gotReq + go doReq() + <-gotReq + + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) + } + + resch <- "res1" + <-donech + keys := tr.IdleConnKeysForTesting() + if e, g := 1, len(keys); e != g { + t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) + } + cacheKey := "|http|" + ts.Listener.Addr().String() + if keys[0] != cacheKey { + t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) + } + if e, g := 1, tr.IdleConnCountForTesting(cacheKey); e != g { + t.Errorf("after first response, expected %d idle conns; got %d", e, g) + } + + resch <- "res2" + <-donech + if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g { + t.Errorf("after second response, expected %d idle conns; got %d", e, g) + } + + resch <- "res3" + <-donech + if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g { + t.Errorf("after third response, still expected %d idle conns; got %d", e, g) + } +} + +func TestTransportServerClosingUnexpectedly(t *testing.T) { + ts := httptest.NewServer(hostPortHandler) + defer ts.Close() + + tr := &Transport{} + c := &Client{Transport: tr} + + fetch := func(n, retries int) string { + condFatalf := func(format string, arg ...interface{}) { + if retries <= 0 { + t.Fatalf(format, arg...) + } + t.Logf("retrying shortly after expected error: "+format, arg...) + time.Sleep(time.Second / time.Duration(retries)) + } + for retries >= 0 { + retries-- + res, err := c.Get(ts.URL) + if err != nil { + condFatalf("error in req #%d, GET: %v", n, err) + continue + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + condFatalf("error in req #%d, ReadAll: %v", n, err) + continue + } + res.Body.Close() + return string(body) + } + panic("unreachable") + } + + body1 := fetch(1, 0) + body2 := fetch(2, 0) + + ts.CloseClientConnections() // surprise! + + // This test has an expected race. Sleeping for 25 ms prevents + // it on most fast machines, causing the next fetch() call to + // succeed quickly. But if we do get errors, fetch() will retry 5 + // times with some delays between. + time.Sleep(25 * time.Millisecond) + + body3 := fetch(3, 5) + + if body1 != body2 { + t.Errorf("expected body1 and body2 to be equal") + } + if body2 == body3 { + t.Errorf("expected body2 and body3 to be different") + } +} + +// Test for http://golang.org/issue/2616 (appropriate issue number) +// This fails pretty reliably with GOMAXPROCS=100 or something high. +func TestStressSurpriseServerCloses(t *testing.T) { + if testing.Short() { + t.Logf("skipping test in short mode") + return + } + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "5") + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("Hello")) + w.(Flusher).Flush() + conn, buf, _ := w.(Hijacker).Hijack() + buf.Flush() + conn.Close() + })) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + + // Do a bunch of traffic from different goroutines. Send to activityc + // after each request completes, regardless of whether it failed. + const ( + numClients = 50 + reqsPerClient = 250 + ) + activityc := make(chan bool) + for i := 0; i < numClients; i++ { + go func() { + for i := 0; i < reqsPerClient; i++ { + res, err := c.Get(ts.URL) + if err == nil { + // We expect errors since the server is + // hanging up on us after telling us to + // send more requests, so we don't + // actually care what the error is. + // But we want to close the body in cases + // where we won the race. + res.Body.Close() + } + activityc <- true + } + }() + } + + // Make sure all the request come back, one way or another. + for i := 0; i < numClients*reqsPerClient; i++ { + select { + case <-activityc: + case <-time.After(5 * time.Second): + t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile") + } + } +} + +// TestTransportHeadResponses verifies that we deal with Content-Lengths +// with no bodies properly +func TestTransportHeadResponses(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + panic("expected HEAD; got " + r.Method) + } + w.Header().Set("Content-Length", "123") + w.WriteHeader(200) + })) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + for i := 0; i < 2; i++ { + res, err := c.Head(ts.URL) + if err != nil { + t.Errorf("error on loop %d: %v", i, err) + } + if e, g := "123", res.Header.Get("Content-Length"); e != g { + t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) + } + if e, g := int64(0), res.ContentLength; e != g { + t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) + } + } +} + +// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding +// on responses to HEAD requests. +func TestTransportHeadChunkedResponse(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + panic("expected HEAD; got " + r.Method) + } + w.Header().Set("Transfer-Encoding", "chunked") // client should ignore + w.Header().Set("x-client-ipport", r.RemoteAddr) + w.WriteHeader(200) + })) + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + c := &Client{Transport: tr} + + res1, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("request 1 error: %v", err) + } + res2, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("request 2 error: %v", err) + } + if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 { + t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2) + } +} + +var roundTripTests = []struct { + accept string + expectAccept string + compressed bool +}{ + // Requests with no accept-encoding header use transparent compression + {"", "gzip", false}, + // Requests with other accept-encoding should pass through unmodified + {"foo", "foo", false}, + // Requests with accept-encoding == gzip should be passed through + {"gzip", "gzip", true}, +} + +// Test that the modification made to the Request by the RoundTripper is cleaned up +func TestRoundTripGzip(t *testing.T) { + const responseBody = "test response body" + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + accept := req.Header.Get("Accept-Encoding") + if expect := req.FormValue("expect_accept"); accept != expect { + t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", + req.FormValue("testnum"), accept, expect) + } + if accept == "gzip" { + rw.Header().Set("Content-Encoding", "gzip") + gz, _ := gzip.NewWriter(rw) + gz.Write([]byte(responseBody)) + gz.Close() + } else { + rw.Header().Set("Content-Encoding", accept) + rw.Write([]byte(responseBody)) + } + })) + defer ts.Close() + + for i, test := range roundTripTests { + // Test basic request (no accept-encoding) + req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) + if test.accept != "" { + req.Header.Set("Accept-Encoding", test.accept) + } + res, err := DefaultTransport.RoundTrip(req) + var body []byte + if test.compressed { + gzip, _ := gzip.NewReader(res.Body) + body, err = ioutil.ReadAll(gzip) + res.Body.Close() + } else { + body, err = ioutil.ReadAll(res.Body) + } + if err != nil { + t.Errorf("%d. Error: %q", i, err) + continue + } + if g, e := string(body), responseBody; g != e { + t.Errorf("%d. body = %q; want %q", i, g, e) + } + if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { + t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) + } + if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { + t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) + } + } + +} + +func TestTransportGzip(t *testing.T) { + const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + const nRandBytes = 1024 * 1024 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { + t.Errorf("Accept-Encoding = %q, want %q", g, e) + } + rw.Header().Set("Content-Encoding", "gzip") + if req.Method == "HEAD" { + return + } + + var w io.Writer = rw + var buf bytes.Buffer + if req.FormValue("chunked") == "0" { + w = &buf + defer io.Copy(rw, &buf) + defer func() { + rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) + }() + } + gz, _ := gzip.NewWriter(w) + gz.Write([]byte(testString)) + if req.FormValue("body") == "large" { + io.CopyN(gz, rand.Reader, nRandBytes) + } + gz.Close() + })) + defer ts.Close() + + for _, chunked := range []string{"1", "0"} { + c := &Client{Transport: &Transport{}} + + // First fetch something large, but only read some of it. + res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) + if err != nil { + t.Fatalf("large get: %v", err) + } + buf := make([]byte, len(testString)) + n, err := io.ReadFull(res.Body, buf) + if err != nil { + t.Fatalf("partial read of large response: size=%d, %v", n, err) + } + if e, g := testString, string(buf); e != g { + t.Errorf("partial read got %q, expected %q", g, e) + } + res.Body.Close() + // Read on the body, even though it's closed + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) + } + + // Then something small. + res, err = c.Get(ts.URL + "/?chunked=" + chunked) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if g, e := string(body), testString; g != e { + t.Fatalf("body = %q; want %q", g, e) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } + + // Read on the body after it's been fully read: + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) + } + res.Body.Close() + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after Close; got %d, %v", n, err) + } + } + + // And a HEAD request too, because they're always weird. + c := &Client{Transport: &Transport{}} + res, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("Head: %v", err) + } + if res.StatusCode != 200 { + t.Errorf("Head status=%d; want=200", res.StatusCode) + } +} + +func TestTransportProxy(t *testing.T) { + ch := make(chan string, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ch <- "real server" + })) + defer ts.Close() + proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ch <- "proxy for " + r.URL.String() + })) + defer proxy.Close() + + pu, err := url.Parse(proxy.URL) + if err != nil { + t.Fatal(err) + } + c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}} + c.Head(ts.URL) + got := <-ch + want := "proxy for " + ts.URL + "/" + if got != want { + t.Errorf("want %q, got %q", want, got) + } +} + +// TestTransportGzipRecursive sends a gzip quine and checks that the +// client gets the same value back. This is more cute than anything, +// but checks that we don't recurse forever, and checks that +// Content-Encoding is removed. +func TestTransportGzipRecursive(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write(rgz) + })) + defer ts.Close() + + c := &Client{Transport: &Transport{}} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(body, rgz) { + t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", + body, rgz) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } +} + +type fooProto struct{} + +func (fooProto) RoundTrip(req *Request) (*Response, error) { + res := &Response{ + Status: "200 OK", + StatusCode: 200, + Header: make(Header), + Body: ioutil.NopCloser(strings.NewReader("You wanted " + req.URL.String())), + } + return res, nil +} + +func TestTransportAltProto(t *testing.T) { + tr := &Transport{} + c := &Client{Transport: tr} + tr.RegisterProtocol("foo", fooProto{}) + res, err := c.Get("foo://bar.com/path") + if err != nil { + t.Fatal(err) + } + bodyb, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + body := string(bodyb) + if e := "You wanted foo://bar.com/path"; body != e { + t.Errorf("got response %q, want %q", body, e) + } +} + +// rgz is a gzip quine that uncompresses to itself. +var rgz = []byte{ + 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, + 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, + 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, + 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, + 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, + 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, + 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, + 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, + 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, + 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, + 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, + 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, + 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, + 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, + 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, + 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, + 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, +} diff --git a/src/pkg/net/http/triv.go b/src/pkg/net/http/triv.go new file mode 100644 index 000000000..994fc0e32 --- /dev/null +++ b/src/pkg/net/http/triv.go @@ -0,0 +1,149 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "expvar" + "flag" + "fmt" + "io" + "log" + "net/http" + "os" + "strconv" +) + +// hello world, the web server +var helloRequests = expvar.NewInt("hello-requests") + +func HelloServer(w http.ResponseWriter, req *http.Request) { + helloRequests.Add(1) + io.WriteString(w, "hello, world!\n") +} + +// Simple counter server. POSTing to it will set the value. +type Counter struct { + n int +} + +// This makes Counter satisfy the expvar.Var interface, so we can export +// it directly. +func (ctr *Counter) String() string { return fmt.Sprintf("%d", ctr.n) } + +func (ctr *Counter) ServeHTTP(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case "GET": + ctr.n++ + case "POST": + buf := new(bytes.Buffer) + io.Copy(buf, req.Body) + body := buf.String() + if n, err := strconv.Atoi(body); err != nil { + fmt.Fprintf(w, "bad POST: %v\nbody: [%v]\n", err, body) + } else { + ctr.n = n + fmt.Fprint(w, "counter reset\n") + } + } + fmt.Fprintf(w, "counter = %d\n", ctr.n) +} + +// simple flag server +var booleanflag = flag.Bool("boolean", true, "another flag for testing") + +func FlagServer(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, "Flags:\n") + flag.VisitAll(func(f *flag.Flag) { + if f.Value.String() != f.DefValue { + fmt.Fprintf(w, "%s = %s [default = %s]\n", f.Name, f.Value.String(), f.DefValue) + } else { + fmt.Fprintf(w, "%s = %s\n", f.Name, f.Value.String()) + } + }) +} + +// simple argument server +func ArgServer(w http.ResponseWriter, req *http.Request) { + for _, s := range os.Args { + fmt.Fprint(w, s, " ") + } +} + +// a channel (just for the fun of it) +type Chan chan int + +func ChanCreate() Chan { + c := make(Chan) + go func(c Chan) { + for x := 0; ; x++ { + c <- x + } + }(c) + return c +} + +func (ch Chan) ServeHTTP(w http.ResponseWriter, req *http.Request) { + io.WriteString(w, fmt.Sprintf("channel send #%d\n", <-ch)) +} + +// exec a program, redirecting output +func DateServer(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("Content-Type", "text/plain; charset=utf-8") + r, w, err := os.Pipe() + if err != nil { + fmt.Fprintf(rw, "pipe: %s\n", err) + return + } + + p, err := os.StartProcess("/bin/date", []string{"date"}, &os.ProcAttr{Files: []*os.File{nil, w, w}}) + defer r.Close() + w.Close() + if err != nil { + fmt.Fprintf(rw, "fork/exec: %s\n", err) + return + } + defer p.Release() + io.Copy(rw, r) + wait, err := p.Wait(0) + if err != nil { + fmt.Fprintf(rw, "wait: %s\n", err) + return + } + if !wait.Exited() || wait.ExitStatus() != 0 { + fmt.Fprintf(rw, "date: %v\n", wait) + return + } +} + +func Logger(w http.ResponseWriter, req *http.Request) { + log.Print(req.URL.Raw) + w.WriteHeader(404) + w.Write([]byte("oops")) +} + +var webroot = flag.String("root", "/home/rsc", "web root directory") + +func main() { + flag.Parse() + + // The counter is published as a variable directly. + ctr := new(Counter) + http.Handle("/counter", ctr) + expvar.Publish("counter", ctr) + + http.Handle("/", http.HandlerFunc(Logger)) + http.Handle("/go/", http.StripPrefix("/go/", http.FileServer(http.Dir(*webroot)))) + http.Handle("/flags", http.HandlerFunc(FlagServer)) + http.Handle("/args", http.HandlerFunc(ArgServer)) + http.Handle("/go/hello", http.HandlerFunc(HelloServer)) + http.Handle("/chan", ChanCreate()) + http.Handle("/date", http.HandlerFunc(DateServer)) + err := http.ListenAndServe(":12345", nil) + if err != nil { + log.Panicln("ListenAndServe:", err) + } +} diff --git a/src/pkg/net/interface.go b/src/pkg/net/interface.go index 2696b7f4c..5e7b352ed 100644 --- a/src/pkg/net/interface.go +++ b/src/pkg/net/interface.go @@ -8,8 +8,16 @@ package net import ( "bytes" + "errors" "fmt" - "os" +) + +var ( + errInvalidInterface = errors.New("net: invalid interface") + errInvalidInterfaceIndex = errors.New("net: invalid interface index") + errInvalidInterfaceName = errors.New("net: invalid interface name") + errNoSuchInterface = errors.New("net: no such interface") + errNoSuchMulticastInterface = errors.New("net: no such multicast interface") ) // A HardwareAddr represents a physical hardware address. @@ -34,7 +42,7 @@ func (a HardwareAddr) String() string { // 01-23-45-67-89-ab-cd-ef // 0123.4567.89ab // 0123.4567.89ab.cdef -func ParseMAC(s string) (hw HardwareAddr, err os.Error) { +func ParseMAC(s string) (hw HardwareAddr, err error) { if len(s) < 14 { goto error } @@ -80,7 +88,7 @@ func ParseMAC(s string) (hw HardwareAddr, err os.Error) { return hw, nil error: - return nil, os.NewError("invalid MAC address: " + s) + return nil, errors.New("invalid MAC address: " + s) } // Interface represents a mapping between network interface name @@ -129,37 +137,37 @@ func (f Flags) String() string { } // Addrs returns interface addresses for a specific interface. -func (ifi *Interface) Addrs() ([]Addr, os.Error) { +func (ifi *Interface) Addrs() ([]Addr, error) { if ifi == nil { - return nil, os.NewError("net: invalid interface") + return nil, errInvalidInterface } return interfaceAddrTable(ifi.Index) } // MulticastAddrs returns multicast, joined group addresses for // a specific interface. -func (ifi *Interface) MulticastAddrs() ([]Addr, os.Error) { +func (ifi *Interface) MulticastAddrs() ([]Addr, error) { if ifi == nil { - return nil, os.NewError("net: invalid interface") + return nil, errInvalidInterface } return interfaceMulticastAddrTable(ifi.Index) } // Interfaces returns a list of the systems's network interfaces. -func Interfaces() ([]Interface, os.Error) { +func Interfaces() ([]Interface, error) { return interfaceTable(0) } // InterfaceAddrs returns a list of the system's network interface // addresses. -func InterfaceAddrs() ([]Addr, os.Error) { +func InterfaceAddrs() ([]Addr, error) { return interfaceAddrTable(0) } // InterfaceByIndex returns the interface specified by index. -func InterfaceByIndex(index int) (*Interface, os.Error) { +func InterfaceByIndex(index int) (*Interface, error) { if index <= 0 { - return nil, os.NewError("net: invalid interface index") + return nil, errInvalidInterfaceIndex } ift, err := interfaceTable(index) if err != nil { @@ -168,13 +176,13 @@ func InterfaceByIndex(index int) (*Interface, os.Error) { for _, ifi := range ift { return &ifi, nil } - return nil, os.NewError("net: no such interface") + return nil, errNoSuchInterface } // InterfaceByName returns the interface specified by name. -func InterfaceByName(name string) (*Interface, os.Error) { +func InterfaceByName(name string) (*Interface, error) { if name == "" { - return nil, os.NewError("net: invalid interface name") + return nil, errInvalidInterfaceName } ift, err := interfaceTable(0) if err != nil { @@ -185,5 +193,5 @@ func InterfaceByName(name string) (*Interface, os.Error) { return &ifi, nil } } - return nil, os.NewError("net: no such interface") + return nil, errNoSuchInterface } diff --git a/src/pkg/net/interface_bsd.go b/src/pkg/net/interface_bsd.go index 9171827d2..907f80a80 100644 --- a/src/pkg/net/interface_bsd.go +++ b/src/pkg/net/interface_bsd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd openbsd +// +build darwin freebsd netbsd openbsd // Network interface identification for BSD variants @@ -17,22 +17,17 @@ import ( // If the ifindex is zero, interfaceTable returns mappings of all // network interfaces. Otheriwse it returns a mapping of a specific // interface. -func interfaceTable(ifindex int) ([]Interface, os.Error) { - var ( - tab []byte - e int - msgs []syscall.RoutingMessage - ift []Interface - ) - - tab, e = syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex) - if e != 0 { - return nil, os.NewSyscallError("route rib", e) +func interfaceTable(ifindex int) ([]Interface, error) { + var ift []Interface + + tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex) + if err != nil { + return nil, os.NewSyscallError("route rib", err) } - msgs, e = syscall.ParseRoutingMessage(tab) - if e != 0 { - return nil, os.NewSyscallError("route message", e) + msgs, err := syscall.ParseRoutingMessage(tab) + if err != nil { + return nil, os.NewSyscallError("route message", err) } for _, m := range msgs { @@ -51,12 +46,12 @@ func interfaceTable(ifindex int) ([]Interface, os.Error) { return ift, nil } -func newLink(m *syscall.InterfaceMessage) ([]Interface, os.Error) { +func newLink(m *syscall.InterfaceMessage) ([]Interface, error) { var ift []Interface - sas, e := syscall.ParseRoutingSockaddr(m) - if e != 0 { - return nil, os.NewSyscallError("route sockaddr", e) + sas, err := syscall.ParseRoutingSockaddr(m) + if err != nil { + return nil, os.NewSyscallError("route sockaddr", err) } for _, s := range sas { @@ -107,22 +102,17 @@ func linkFlags(rawFlags int32) Flags { // If the ifindex is zero, interfaceAddrTable returns addresses // for all network interfaces. Otherwise it returns addresses // for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { - var ( - tab []byte - e int - msgs []syscall.RoutingMessage - ifat []Addr - ) - - tab, e = syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex) - if e != 0 { - return nil, os.NewSyscallError("route rib", e) +func interfaceAddrTable(ifindex int) ([]Addr, error) { + var ifat []Addr + + tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex) + if err != nil { + return nil, os.NewSyscallError("route rib", err) } - msgs, e = syscall.ParseRoutingMessage(tab) - if e != 0 { - return nil, os.NewSyscallError("route message", e) + msgs, err := syscall.ParseRoutingMessage(tab) + if err != nil { + return nil, os.NewSyscallError("route message", err) } for _, m := range msgs { @@ -133,7 +123,7 @@ func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { if err != nil { return nil, err } - ifat = append(ifat, ifa...) + ifat = append(ifat, ifa) } } } @@ -141,33 +131,41 @@ func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { return ifat, nil } -func newAddr(m *syscall.InterfaceAddrMessage) ([]Addr, os.Error) { - var ifat []Addr +func newAddr(m *syscall.InterfaceAddrMessage) (Addr, error) { + ifa := &IPNet{} - sas, e := syscall.ParseRoutingSockaddr(m) - if e != 0 { - return nil, os.NewSyscallError("route sockaddr", e) + sas, err := syscall.ParseRoutingSockaddr(m) + if err != nil { + return nil, os.NewSyscallError("route sockaddr", err) } - for _, s := range sas { - + for i, s := range sas { switch v := s.(type) { case *syscall.SockaddrInet4: - ifa := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])} - ifat = append(ifat, ifa.toAddr()) + switch i { + case 0: + ifa.Mask = IPv4Mask(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3]) + case 1: + ifa.IP = IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3]) + } case *syscall.SockaddrInet6: - ifa := &IPAddr{IP: make(IP, IPv6len)} - copy(ifa.IP, v.Addr[:]) - // NOTE: KAME based IPv6 protcol stack usually embeds - // the interface index in the interface-local or link- - // local address as the kernel-internal form. - if ifa.IP.IsLinkLocalUnicast() { - // remove embedded scope zone ID - ifa.IP[2], ifa.IP[3] = 0, 0 + switch i { + case 0: + ifa.Mask = make(IPMask, IPv6len) + copy(ifa.Mask, v.Addr[:]) + case 1: + ifa.IP = make(IP, IPv6len) + copy(ifa.IP, v.Addr[:]) + // NOTE: KAME based IPv6 protcol stack usually embeds + // the interface index in the interface-local or link- + // local address as the kernel-internal form. + if ifa.IP.IsLinkLocalUnicast() { + // remove embedded scope zone ID + ifa.IP[2], ifa.IP[3] = 0, 0 + } } - ifat = append(ifat, ifa.toAddr()) } } - return ifat, nil + return ifa, nil } diff --git a/src/pkg/net/interface_darwin.go b/src/pkg/net/interface_darwin.go index a7b68ad7f..2da447adc 100644 --- a/src/pkg/net/interface_darwin.go +++ b/src/pkg/net/interface_darwin.go @@ -14,21 +14,21 @@ import ( // If the ifindex is zero, interfaceMulticastAddrTable returns // addresses for all network interfaces. Otherwise it returns // addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { +func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { var ( tab []byte - e int + e error msgs []syscall.RoutingMessage ifmat []Addr ) tab, e = syscall.RouteRIB(syscall.NET_RT_IFLIST2, ifindex) - if e != 0 { + if e != nil { return nil, os.NewSyscallError("route rib", e) } msgs, e = syscall.ParseRoutingMessage(tab) - if e != 0 { + if e != nil { return nil, os.NewSyscallError("route message", e) } @@ -48,11 +48,11 @@ func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { return ifmat, nil } -func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, os.Error) { +func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) { var ifmat []Addr sas, e := syscall.ParseRoutingSockaddr(m) - if e != 0 { + if e != nil { return nil, os.NewSyscallError("route sockaddr", e) } diff --git a/src/pkg/net/interface_freebsd.go b/src/pkg/net/interface_freebsd.go index 20f506b08..a12877e25 100644 --- a/src/pkg/net/interface_freebsd.go +++ b/src/pkg/net/interface_freebsd.go @@ -14,21 +14,21 @@ import ( // If the ifindex is zero, interfaceMulticastAddrTable returns // addresses for all network interfaces. Otherwise it returns // addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { +func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { var ( tab []byte - e int + e error msgs []syscall.RoutingMessage ifmat []Addr ) tab, e = syscall.RouteRIB(syscall.NET_RT_IFMALIST, ifindex) - if e != 0 { + if e != nil { return nil, os.NewSyscallError("route rib", e) } msgs, e = syscall.ParseRoutingMessage(tab) - if e != 0 { + if e != nil { return nil, os.NewSyscallError("route message", e) } @@ -48,11 +48,11 @@ func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { return ifmat, nil } -func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, os.Error) { +func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) { var ifmat []Addr sas, e := syscall.ParseRoutingSockaddr(m) - if e != 0 { + if e != nil { return nil, os.NewSyscallError("route sockaddr", e) } diff --git a/src/pkg/net/interface_linux.go b/src/pkg/net/interface_linux.go index 3d2a0bb9f..c0887c57e 100644 --- a/src/pkg/net/interface_linux.go +++ b/src/pkg/net/interface_linux.go @@ -16,22 +16,17 @@ import ( // If the ifindex is zero, interfaceTable returns mappings of all // network interfaces. Otheriwse it returns a mapping of a specific // interface. -func interfaceTable(ifindex int) ([]Interface, os.Error) { - var ( - ift []Interface - tab []byte - msgs []syscall.NetlinkMessage - e int - ) +func interfaceTable(ifindex int) ([]Interface, error) { + var ift []Interface - tab, e = syscall.NetlinkRIB(syscall.RTM_GETLINK, syscall.AF_UNSPEC) - if e != 0 { - return nil, os.NewSyscallError("netlink rib", e) + tab, err := syscall.NetlinkRIB(syscall.RTM_GETLINK, syscall.AF_UNSPEC) + if err != nil { + return nil, os.NewSyscallError("netlink rib", err) } - msgs, e = syscall.ParseNetlinkMessage(tab) - if e != 0 { - return nil, os.NewSyscallError("netlink message", e) + msgs, err := syscall.ParseNetlinkMessage(tab) + if err != nil { + return nil, os.NewSyscallError("netlink message", err) } for _, m := range msgs { @@ -41,11 +36,11 @@ func interfaceTable(ifindex int) ([]Interface, os.Error) { case syscall.RTM_NEWLINK: ifim := (*syscall.IfInfomsg)(unsafe.Pointer(&m.Data[0])) if ifindex == 0 || ifindex == int(ifim.Index) { - attrs, e := syscall.ParseNetlinkRouteAttr(&m) - if e != 0 { - return nil, os.NewSyscallError("netlink routeattr", e) + attrs, err := syscall.ParseNetlinkRouteAttr(&m) + if err != nil { + return nil, os.NewSyscallError("netlink routeattr", err) } - ifi := newLink(attrs, ifim) + ifi := newLink(ifim, attrs) ift = append(ift, ifi) } } @@ -55,7 +50,7 @@ done: return ift, nil } -func newLink(attrs []syscall.NetlinkRouteAttr, ifim *syscall.IfInfomsg) Interface { +func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) Interface { ifi := Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)} for _, a := range attrs { switch a.Attr.Type { @@ -101,47 +96,26 @@ func linkFlags(rawFlags uint32) Flags { // If the ifindex is zero, interfaceAddrTable returns addresses // for all network interfaces. Otherwise it returns addresses // for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { - var ( - tab []byte - e int - err os.Error - ifat4 []Addr - ifat6 []Addr - msgs4 []syscall.NetlinkMessage - msgs6 []syscall.NetlinkMessage - ) - - tab, e = syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_INET) - if e != 0 { - return nil, os.NewSyscallError("netlink rib", e) - } - msgs4, e = syscall.ParseNetlinkMessage(tab) - if e != 0 { - return nil, os.NewSyscallError("netlink message", e) - } - ifat4, err = addrTable(msgs4, ifindex) +func interfaceAddrTable(ifindex int) ([]Addr, error) { + tab, err := syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC) if err != nil { - return nil, err + return nil, os.NewSyscallError("netlink rib", err) } - tab, e = syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_INET6) - if e != 0 { - return nil, os.NewSyscallError("netlink rib", e) - } - msgs6, e = syscall.ParseNetlinkMessage(tab) - if e != 0 { - return nil, os.NewSyscallError("netlink message", e) + msgs, err := syscall.ParseNetlinkMessage(tab) + if err != nil { + return nil, os.NewSyscallError("netlink message", err) } - ifat6, err = addrTable(msgs6, ifindex) + + ifat, err := addrTable(msgs, ifindex) if err != nil { return nil, err } - return append(ifat4, ifat6...), nil + return ifat, nil } -func addrTable(msgs []syscall.NetlinkMessage, ifindex int) ([]Addr, os.Error) { +func addrTable(msgs []syscall.NetlinkMessage, ifindex int) ([]Addr, error) { var ifat []Addr for _, m := range msgs { @@ -151,11 +125,11 @@ func addrTable(msgs []syscall.NetlinkMessage, ifindex int) ([]Addr, os.Error) { case syscall.RTM_NEWADDR: ifam := (*syscall.IfAddrmsg)(unsafe.Pointer(&m.Data[0])) if ifindex == 0 || ifindex == int(ifam.Index) { - attrs, e := syscall.ParseNetlinkRouteAttr(&m) - if e != 0 { - return nil, os.NewSyscallError("netlink routeattr", e) + attrs, err := syscall.ParseNetlinkRouteAttr(&m) + if err != nil { + return nil, os.NewSyscallError("netlink routeattr", err) } - ifat = append(ifat, newAddr(attrs, int(ifam.Family))...) + ifat = append(ifat, newAddr(attrs, int(ifam.Family), int(ifam.Prefixlen))) } } } @@ -164,34 +138,32 @@ done: return ifat, nil } -func newAddr(attrs []syscall.NetlinkRouteAttr, family int) []Addr { - var ifat []Addr - +func newAddr(attrs []syscall.NetlinkRouteAttr, family, pfxlen int) Addr { + ifa := &IPNet{} for _, a := range attrs { switch a.Attr.Type { case syscall.IFA_ADDRESS: switch family { case syscall.AF_INET: - ifa := &IPAddr{IP: IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3])} - ifat = append(ifat, ifa.toAddr()) + ifa.IP = IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]) + ifa.Mask = CIDRMask(pfxlen, 8*IPv4len) case syscall.AF_INET6: - ifa := &IPAddr{IP: make(IP, IPv6len)} + ifa.IP = make(IP, IPv6len) copy(ifa.IP, a.Value[:]) - ifat = append(ifat, ifa.toAddr()) + ifa.Mask = CIDRMask(pfxlen, 8*IPv6len) } } } - - return ifat + return ifa } // If the ifindex is zero, interfaceMulticastAddrTable returns // addresses for all network interfaces. Otherwise it returns // addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { +func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { var ( + err error ifi *Interface - err os.Error ) if ifindex > 0 { diff --git a/src/pkg/net/interface_netbsd.go b/src/pkg/net/interface_netbsd.go new file mode 100644 index 000000000..4150e9ad5 --- /dev/null +++ b/src/pkg/net/interface_netbsd.go @@ -0,0 +1,14 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Network interface identification for NetBSD + +package net + +// If the ifindex is zero, interfaceMulticastAddrTable returns +// addresses for all network interfaces. Otherwise it returns +// addresses for a specific interface. +func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { + return nil, nil +} diff --git a/src/pkg/net/interface_openbsd.go b/src/pkg/net/interface_openbsd.go index f18149393..d8adb4676 100644 --- a/src/pkg/net/interface_openbsd.go +++ b/src/pkg/net/interface_openbsd.go @@ -6,11 +6,9 @@ package net -import "os" - // If the ifindex is zero, interfaceMulticastAddrTable returns // addresses for all network interfaces. Otherwise it returns // addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { +func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { return nil, nil } diff --git a/src/pkg/net/interface_stub.go b/src/pkg/net/interface_stub.go index 282b38b5e..4876b3af3 100644 --- a/src/pkg/net/interface_stub.go +++ b/src/pkg/net/interface_stub.go @@ -8,25 +8,23 @@ package net -import "os" - // If the ifindex is zero, interfaceTable returns mappings of all // network interfaces. Otheriwse it returns a mapping of a specific // interface. -func interfaceTable(ifindex int) ([]Interface, os.Error) { +func interfaceTable(ifindex int) ([]Interface, error) { return nil, nil } // If the ifindex is zero, interfaceAddrTable returns addresses // for all network interfaces. Otherwise it returns addresses // for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { +func interfaceAddrTable(ifindex int) ([]Addr, error) { return nil, nil } // If the ifindex is zero, interfaceMulticastAddrTable returns // addresses for all network interfaces. Otherwise it returns // addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { +func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { return nil, nil } diff --git a/src/pkg/net/interface_test.go b/src/pkg/net/interface_test.go index c918f247f..4ce01dc90 100644 --- a/src/pkg/net/interface_test.go +++ b/src/pkg/net/interface_test.go @@ -6,7 +6,6 @@ package net import ( "bytes" - "os" "reflect" "strings" "testing" @@ -25,7 +24,7 @@ func sameInterface(i, j *Interface) bool { func TestInterfaces(t *testing.T) { ift, err := Interfaces() if err != nil { - t.Fatalf("Interfaces() failed: %v", err) + t.Fatalf("Interfaces failed: %v", err) } t.Logf("table: len/cap = %v/%v\n", len(ift), cap(ift)) @@ -44,34 +43,57 @@ func TestInterfaces(t *testing.T) { if !sameInterface(ifxn, &ifi) { t.Fatalf("InterfaceByName(%#q) = %v, want %v", ifi.Name, *ifxn, ifi) } - ifat, err := ifi.Addrs() - if err != nil { - t.Fatalf("Interface.Addrs() failed: %v", err) - } - ifmat, err := ifi.MulticastAddrs() - if err != nil { - t.Fatalf("Interface.MulticastAddrs() failed: %v", err) - } t.Logf("%q: flags %q, ifindex %v, mtu %v\n", ifi.Name, ifi.Flags.String(), ifi.Index, ifi.MTU) - for _, ifa := range ifat { - t.Logf("\tinterface address %q\n", ifa.String()) - } - for _, ifma := range ifmat { - t.Logf("\tjoined group address %q\n", ifma.String()) - } t.Logf("\thardware address %q", ifi.HardwareAddr.String()) + testInterfaceAddrs(t, &ifi) + testInterfaceMulticastAddrs(t, &ifi) } } func TestInterfaceAddrs(t *testing.T) { ifat, err := InterfaceAddrs() if err != nil { - t.Fatalf("InterfaceAddrs() failed: %v", err) + t.Fatalf("InterfaceAddrs failed: %v", err) } t.Logf("table: len/cap = %v/%v\n", len(ifat), cap(ifat)) + testAddrs(t, ifat) +} + +func testInterfaceAddrs(t *testing.T, ifi *Interface) { + ifat, err := ifi.Addrs() + if err != nil { + t.Fatalf("Interface.Addrs failed: %v", err) + } + testAddrs(t, ifat) +} + +func testInterfaceMulticastAddrs(t *testing.T, ifi *Interface) { + ifmat, err := ifi.MulticastAddrs() + if err != nil { + t.Fatalf("Interface.MulticastAddrs failed: %v", err) + } + testMulticastAddrs(t, ifmat) +} +func testAddrs(t *testing.T, ifat []Addr) { for _, ifa := range ifat { - t.Logf("interface address %q\n", ifa.String()) + switch ifa.(type) { + case *IPAddr, *IPNet: + t.Logf("\tinterface address %q\n", ifa.String()) + default: + t.Errorf("\tunexpected type: %T", ifa) + } + } +} + +func testMulticastAddrs(t *testing.T, ifmat []Addr) { + for _, ifma := range ifmat { + switch ifma.(type) { + case *IPAddr: + t.Logf("\tjoined group address %q\n", ifma.String()) + default: + t.Errorf("\tunexpected type: %T", ifma) + } } } @@ -101,11 +123,11 @@ var mactests = []struct { {"0123.4567.89AB.CDEF", HardwareAddr{1, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}, ""}, } -func match(err os.Error, s string) bool { +func match(err error, s string) bool { if s == "" { return err == nil } - return err != nil && strings.Contains(err.String(), s) + return err != nil && strings.Contains(err.Error(), s) } func TestParseMAC(t *testing.T) { diff --git a/src/pkg/net/interface_windows.go b/src/pkg/net/interface_windows.go index 7f5169c87..add3dd3b9 100644 --- a/src/pkg/net/interface_windows.go +++ b/src/pkg/net/interface_windows.go @@ -21,7 +21,7 @@ func bytePtrToString(p *uint8) string { return string(a[:i]) } -func getAdapterList() (*syscall.IpAdapterInfo, os.Error) { +func getAdapterList() (*syscall.IpAdapterInfo, error) { b := make([]byte, 1000) l := uint32(len(b)) a := (*syscall.IpAdapterInfo)(unsafe.Pointer(&b[0])) @@ -31,15 +31,15 @@ func getAdapterList() (*syscall.IpAdapterInfo, os.Error) { a = (*syscall.IpAdapterInfo)(unsafe.Pointer(&b[0])) e = syscall.GetAdaptersInfo(a, &l) } - if e != 0 { + if e != nil { return nil, os.NewSyscallError("GetAdaptersInfo", e) } return a, nil } -func getInterfaceList() ([]syscall.InterfaceInfo, os.Error) { +func getInterfaceList() ([]syscall.InterfaceInfo, error) { s, e := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) - if e != 0 { + if e != nil { return nil, os.NewSyscallError("Socket", e) } defer syscall.Closesocket(s) @@ -48,7 +48,7 @@ func getInterfaceList() ([]syscall.InterfaceInfo, os.Error) { ret := uint32(0) size := uint32(unsafe.Sizeof(ii)) e = syscall.WSAIoctl(s, syscall.SIO_GET_INTERFACE_LIST, nil, 0, (*byte)(unsafe.Pointer(&ii[0])), size, &ret, nil, 0) - if e != 0 { + if e != nil { return nil, os.NewSyscallError("WSAIoctl", e) } c := ret / uint32(unsafe.Sizeof(ii[0])) @@ -58,7 +58,7 @@ func getInterfaceList() ([]syscall.InterfaceInfo, os.Error) { // If the ifindex is zero, interfaceTable returns mappings of all // network interfaces. Otheriwse it returns a mapping of a specific // interface. -func interfaceTable(ifindex int) ([]Interface, os.Error) { +func interfaceTable(ifindex int) ([]Interface, error) { ai, e := getAdapterList() if e != nil { return nil, e @@ -77,7 +77,7 @@ func interfaceTable(ifindex int) ([]Interface, os.Error) { row := syscall.MibIfRow{Index: index} e := syscall.GetIfEntry(&row) - if e != 0 { + if e != nil { return nil, os.NewSyscallError("GetIfEntry", e) } @@ -129,7 +129,7 @@ func interfaceTable(ifindex int) ([]Interface, os.Error) { // If the ifindex is zero, interfaceAddrTable returns addresses // for all network interfaces. Otherwise it returns addresses // for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { +func interfaceAddrTable(ifindex int) ([]Addr, error) { ai, e := getAdapterList() if e != nil { return nil, e @@ -153,6 +153,6 @@ func interfaceAddrTable(ifindex int) ([]Addr, os.Error) { // If the ifindex is zero, interfaceMulticastAddrTable returns // addresses for all network interfaces. Otherwise it returns // addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, os.Error) { +func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { return nil, nil } diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go index 61dc3be90..979d7acd5 100644 --- a/src/pkg/net/ip.go +++ b/src/pkg/net/ip.go @@ -12,8 +12,6 @@ package net -import "os" - // IP address lengths (bytes). const ( IPv4len = 4 @@ -452,6 +450,9 @@ func (n *IPNet) String() string { return nn.String() + "/" + itod(uint(l)) } +// Network returns the address's network name, "ip+net". +func (n *IPNet) Network() string { return "ip+net" } + // Parse IPv4 address (d.d.d.d). func parseIPv4(s string) IP { var p [IPv4len]byte @@ -594,7 +595,7 @@ type ParseError struct { Text string } -func (e *ParseError) String() string { +func (e *ParseError) Error() string { return "invalid " + e.Type + ": " + e.Text } @@ -627,7 +628,7 @@ func ParseIP(s string) IP { // It returns the IP address and the network implied by the IP // and mask. For example, ParseCIDR("192.168.100.1/16") returns // the IP address 192.168.100.1 and the network 192.168.0.0/16. -func ParseCIDR(s string) (IP, *IPNet, os.Error) { +func ParseCIDR(s string) (IP, *IPNet, error) { i := byteIndex(s, '/') if i < 0 { return nil, nil, &ParseError{"CIDR address", s} diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go index 07e627aef..df647ef73 100644 --- a/src/pkg/net/ip_test.go +++ b/src/pkg/net/ip_test.go @@ -7,9 +7,8 @@ package net import ( "bytes" "reflect" - "testing" - "os" "runtime" + "testing" ) func isEqual(a, b []byte) bool { @@ -113,7 +112,7 @@ var parsecidrtests = []struct { in string ip IP net *IPNet - err os.Error + err error }{ {"135.104.0.0/32", IPv4(135, 104, 0, 0), &IPNet{IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 255)}, nil}, {"0.0.0.0/24", IPv4(0, 0, 0, 0), &IPNet{IPv4(0, 0, 0, 0), IPv4Mask(255, 255, 255, 0)}, nil}, diff --git a/src/pkg/net/ipraw_test.go b/src/pkg/net/ipraw_test.go index 6894ce656..f9401c110 100644 --- a/src/pkg/net/ipraw_test.go +++ b/src/pkg/net/ipraw_test.go @@ -2,119 +2,191 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TODO(cw): ListenPacket test, Read() test, ipv6 test & -// Dial()/Listen() level tests - package net import ( "bytes" - "flag" "os" "testing" + "time" ) -const ICMP_ECHO_REQUEST = 8 -const ICMP_ECHO_REPLY = 0 - -// returns a suitable 'ping request' packet, with id & seq and a -// payload length of pktlen -func makePingRequest(id, seq, pktlen int, filler []byte) []byte { - p := make([]byte, pktlen) - copy(p[8:], bytes.Repeat(filler, (pktlen-8)/len(filler)+1)) - - p[0] = ICMP_ECHO_REQUEST // type - p[1] = 0 // code - p[2] = 0 // cksum - p[3] = 0 // cksum - p[4] = uint8(id >> 8) // id - p[5] = uint8(id & 0xff) // id - p[6] = uint8(seq >> 8) // sequence - p[7] = uint8(seq & 0xff) // sequence - - // calculate icmp checksum - cklen := len(p) - s := uint32(0) - for i := 0; i < (cklen - 1); i += 2 { - s += uint32(p[i+1])<<8 | uint32(p[i]) - } - if cklen&1 == 1 { - s += uint32(p[cklen-1]) - } - s = (s >> 16) + (s & 0xffff) - s = s + (s >> 16) - - // place checksum back in header; using ^= avoids the - // assumption the checksum bytes are zero - p[2] ^= uint8(^s & 0xff) - p[3] ^= uint8(^s >> 8) - - return p -} - -func parsePingReply(p []byte) (id, seq int) { - id = int(p[4])<<8 | int(p[5]) - seq = int(p[6])<<8 | int(p[7]) - return +var icmpTests = []struct { + net string + laddr string + raddr string + ipv6 bool +}{ + {"ip4:icmp", "", "127.0.0.1", false}, + {"ip6:icmp", "", "::1", true}, } -var srchost = flag.String("srchost", "", "Source of the ICMP ECHO request") -// 127.0.0.1 because this is an IPv4-specific test. -var dsthost = flag.String("dsthost", "127.0.0.1", "Destination for the ICMP ECHO request") - -// test (raw) IP socket using ICMP func TestICMP(t *testing.T) { if os.Getuid() != 0 { t.Logf("test disabled; must be root") return } - var ( - laddr *IPAddr - err os.Error - ) - if *srchost != "" { - laddr, err = ResolveIPAddr("ip4", *srchost) - if err != nil { - t.Fatalf(`net.ResolveIPAddr("ip4", %v") = %v, %v`, *srchost, laddr, err) + seqnum := 61455 + for _, tt := range icmpTests { + if tt.ipv6 && !supportsIPv6 { + continue } + id := os.Getpid() & 0xffff + seqnum++ + echo := newICMPEchoRequest(tt.ipv6, id, seqnum, 128, []byte("Go Go Gadget Ping!!!")) + exchangeICMPEcho(t, tt.net, tt.laddr, tt.raddr, tt.ipv6, echo) } +} - raddr, err := ResolveIPAddr("ip4", *dsthost) +func exchangeICMPEcho(t *testing.T, net, laddr, raddr string, ipv6 bool, echo []byte) { + c, err := ListenPacket(net, laddr) if err != nil { - t.Fatalf(`net.ResolveIPAddr("ip4", %v") = %v, %v`, *dsthost, raddr, err) + t.Errorf("ListenPacket(%#q, %#q) failed: %v", net, laddr, err) + return } + c.SetDeadline(time.Now().Add(100 * time.Millisecond)) + defer c.Close() - c, err := ListenIP("ip4:icmp", laddr) + ra, err := ResolveIPAddr(net, raddr) if err != nil { - t.Fatalf(`net.ListenIP("ip4:icmp", %v) = %v, %v`, *srchost, c, err) + t.Errorf("ResolveIPAddr(%#q, %#q) failed: %v", net, raddr, err) + return } - sendid := os.Getpid() & 0xffff - const sendseq = 61455 - const pingpktlen = 128 - sendpkt := makePingRequest(sendid, sendseq, pingpktlen, []byte("Go Go Gadget Ping!!!")) + waitForReady := make(chan bool) + go icmpEchoTransponder(t, net, raddr, ipv6, waitForReady) + <-waitForReady - n, err := c.WriteToIP(sendpkt, raddr) - if err != nil || n != pingpktlen { - t.Fatalf(`net.WriteToIP(..., %v) = %v, %v`, raddr, n, err) + _, err = c.WriteTo(echo, ra) + if err != nil { + t.Errorf("WriteTo failed: %v", err) + return } - c.SetTimeout(100e6) - resp := make([]byte, 1024) + reply := make([]byte, 256) for { - n, from, err := c.ReadFrom(resp) + _, _, err := c.ReadFrom(reply) if err != nil { - t.Fatalf(`ReadFrom(...) = %v, %v, %v`, n, from, err) + t.Errorf("ReadFrom failed: %v", err) + return } - if resp[0] != ICMP_ECHO_REPLY { + if !ipv6 && reply[0] != ICMP4_ECHO_REPLY { continue } - rcvid, rcvseq := parsePingReply(resp) - if rcvid != sendid || rcvseq != sendseq { - t.Fatalf(`Ping reply saw id,seq=0x%x,0x%x (expected 0x%x, 0x%x)`, rcvid, rcvseq, sendid, sendseq) + if ipv6 && reply[0] != ICMP6_ECHO_REPLY { + continue + } + xid, xseqnum := parseICMPEchoReply(echo) + rid, rseqnum := parseICMPEchoReply(reply) + if rid != xid || rseqnum != xseqnum { + t.Errorf("ID = %v, Seqnum = %v, want ID = %v, Seqnum = %v", rid, rseqnum, xid, xseqnum) + return } + break + } +} + +func icmpEchoTransponder(t *testing.T, net, raddr string, ipv6 bool, waitForReady chan bool) { + c, err := Dial(net, raddr) + if err != nil { + waitForReady <- true + t.Errorf("Dial(%#q, %#q) failed: %v", net, raddr, err) return } - t.Fatalf("saw no ping return") + c.SetDeadline(time.Now().Add(100 * time.Millisecond)) + defer c.Close() + waitForReady <- true + + echo := make([]byte, 256) + var nr int + for { + nr, err = c.Read(echo) + if err != nil { + t.Errorf("Read failed: %v", err) + return + } + if !ipv6 && echo[0] != ICMP4_ECHO_REQUEST { + continue + } + if ipv6 && echo[0] != ICMP6_ECHO_REQUEST { + continue + } + break + } + + if !ipv6 { + echo[0] = ICMP4_ECHO_REPLY + } else { + echo[0] = ICMP6_ECHO_REPLY + } + + _, err = c.Write(echo[:nr]) + if err != nil { + t.Errorf("Write failed: %v", err) + return + } +} + +const ( + ICMP4_ECHO_REQUEST = 8 + ICMP4_ECHO_REPLY = 0 + ICMP6_ECHO_REQUEST = 128 + ICMP6_ECHO_REPLY = 129 +) + +func newICMPEchoRequest(ipv6 bool, id, seqnum, msglen int, filler []byte) []byte { + if !ipv6 { + return newICMPv4EchoRequest(id, seqnum, msglen, filler) + } + return newICMPv6EchoRequest(id, seqnum, msglen, filler) +} + +func newICMPv4EchoRequest(id, seqnum, msglen int, filler []byte) []byte { + b := newICMPInfoMessage(id, seqnum, msglen, filler) + b[0] = ICMP4_ECHO_REQUEST + + // calculate ICMP checksum + cklen := len(b) + s := uint32(0) + for i := 0; i < cklen-1; i += 2 { + s += uint32(b[i+1])<<8 | uint32(b[i]) + } + if cklen&1 == 1 { + s += uint32(b[cklen-1]) + } + s = (s >> 16) + (s & 0xffff) + s = s + (s >> 16) + // place checksum back in header; using ^= avoids the + // assumption the checksum bytes are zero + b[2] ^= uint8(^s & 0xff) + b[3] ^= uint8(^s >> 8) + + return b +} + +func newICMPv6EchoRequest(id, seqnum, msglen int, filler []byte) []byte { + b := newICMPInfoMessage(id, seqnum, msglen, filler) + b[0] = ICMP6_ECHO_REQUEST + return b +} + +func newICMPInfoMessage(id, seqnum, msglen int, filler []byte) []byte { + b := make([]byte, msglen) + copy(b[8:], bytes.Repeat(filler, (msglen-8)/len(filler)+1)) + b[0] = 0 // type + b[1] = 0 // code + b[2] = 0 // checksum + b[3] = 0 // checksum + b[4] = uint8(id >> 8) // identifier + b[5] = uint8(id & 0xff) // identifier + b[6] = uint8(seqnum >> 8) // sequence number + b[7] = uint8(seqnum & 0xff) // sequence number + return b +} + +func parseICMPEchoReply(b []byte) (id, seqnum int) { + id = int(b[4])<<8 | int(b[5]) + seqnum = int(b[6])<<8 | int(b[7]) + return } diff --git a/src/pkg/net/iprawsock.go b/src/pkg/net/iprawsock.go index 662b9f57b..b23213ee1 100644 --- a/src/pkg/net/iprawsock.go +++ b/src/pkg/net/iprawsock.go @@ -6,10 +6,6 @@ package net -import ( - "os" -) - // IPAddr represents the address of a IP end point. type IPAddr struct { IP IP @@ -29,7 +25,7 @@ func (a *IPAddr) String() string { // names to numeric addresses on the network net, which must be // "ip", "ip4" or "ip6". A literal IPv6 host address must be // enclosed in square brackets, as in "[::]". -func ResolveIPAddr(net, addr string) (*IPAddr, os.Error) { +func ResolveIPAddr(net, addr string) (*IPAddr, error) { ip, err := hostToIP(net, addr) if err != nil { return nil, err @@ -38,7 +34,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, os.Error) { } // Convert "host" into IP address. -func hostToIP(net, host string) (ip IP, err os.Error) { +func hostToIP(net, host string) (ip IP, err error) { var addr IP // Try as an IP address. addr = ParseIP(host) diff --git a/src/pkg/net/iprawsock_plan9.go b/src/pkg/net/iprawsock_plan9.go index 808e17974..859153c2a 100644 --- a/src/pkg/net/iprawsock_plan9.go +++ b/src/pkg/net/iprawsock_plan9.go @@ -8,26 +8,42 @@ package net import ( "os" + "time" ) // IPConn is the implementation of the Conn and PacketConn // interfaces for IP network connections. type IPConn bool +// SetDeadline implements the net.Conn SetDeadline method. +func (c *IPConn) SetDeadline(t time.Time) error { + return os.EPLAN9 +} + +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (c *IPConn) SetReadDeadline(t time.Time) error { + return os.EPLAN9 +} + +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (c *IPConn) SetWriteDeadline(t time.Time) error { + return os.EPLAN9 +} + // Implementation of the Conn interface - see Conn for documentation. -// Read implements the net.Conn Read method. -func (c *IPConn) Read(b []byte) (n int, err os.Error) { +// Read implements the Conn Read method. +func (c *IPConn) Read(b []byte) (int, error) { return 0, os.EPLAN9 } -// Write implements the net.Conn Write method. -func (c *IPConn) Write(b []byte) (n int, err os.Error) { +// Write implements the Conn Write method. +func (c *IPConn) Write(b []byte) (int, error) { return 0, os.EPLAN9 } // Close closes the IP connection. -func (c *IPConn) Close() os.Error { +func (c *IPConn) Close() error { return os.EPLAN9 } @@ -41,52 +57,42 @@ func (c *IPConn) RemoteAddr() Addr { return nil } -// SetTimeout implements the net.Conn SetTimeout method. -func (c *IPConn) SetTimeout(nsec int64) os.Error { - return os.EPLAN9 -} - -// SetReadTimeout implements the net.Conn SetReadTimeout method. -func (c *IPConn) SetReadTimeout(nsec int64) os.Error { - return os.EPLAN9 -} +// IP-specific methods. -// SetWriteTimeout implements the net.Conn SetWriteTimeout method. -func (c *IPConn) SetWriteTimeout(nsec int64) os.Error { - return os.EPLAN9 +// ReadFromIP reads a IP packet from c, copying the payload into b. +// It returns the number of bytes copied into b and the return address +// that was on the packet. +// +// ReadFromIP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetReadDeadline. +func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) { + return 0, nil, os.EPLAN9 } -// IP-specific methods. - -// ReadFrom implements the net.PacketConn ReadFrom method. -func (c *IPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { - err = os.EPLAN9 - return +// ReadFrom implements the PacketConn ReadFrom method. +func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { + return 0, nil, os.EPLAN9 } // WriteToIP writes a IP packet to addr via c, copying the payload from b. // // WriteToIP can be made to time out and return // an error with Timeout() == true after a fixed time limit; -// see SetTimeout and SetWriteTimeout. +// see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. -func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (n int, err os.Error) { +func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { return 0, os.EPLAN9 } -// WriteTo implements the net.PacketConn WriteTo method. -func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { +// WriteTo implements the PacketConn WriteTo method. +func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) { return 0, os.EPLAN9 } -func splitNetProto(netProto string) (net string, proto int, err os.Error) { - err = os.EPLAN9 - return -} - -// DialIP connects to the remote address raddr on the network net, -// which must be "ip", "ip4", or "ip6". -func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) { +// DialIP connects to the remote address raddr on the network protocol netProto, +// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name. +func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { return nil, os.EPLAN9 } @@ -94,6 +100,6 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) { // local address laddr. The returned connection c's ReadFrom // and WriteTo methods can be used to receive and send IP // packets with per-packet addressing. -func ListenIP(netProto string, laddr *IPAddr) (c *IPConn, err os.Error) { +func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { return nil, os.EPLAN9 } diff --git a/src/pkg/net/iprawsock_posix.go b/src/pkg/net/iprawsock_posix.go index 35aceb223..c34ffeb12 100644 --- a/src/pkg/net/iprawsock_posix.go +++ b/src/pkg/net/iprawsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd windows +// +build darwin freebsd linux netbsd openbsd windows // (Raw) IP sockets @@ -10,12 +10,10 @@ package net import ( "os" - "sync" "syscall" + "time" ) -var onceReadProtocols sync.Once - func sockaddrToIP(sa syscall.Sockaddr) Addr { switch sa := sa.(type) { case *syscall.SockaddrInet4: @@ -36,7 +34,7 @@ func (a *IPAddr) family() int { return syscall.AF_INET6 } -func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) { +func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) { return ipToSockaddr(family, a.IP, 0) } @@ -59,14 +57,14 @@ func (c *IPConn) ok() bool { return c != nil && c.fd != nil } // Implementation of the Conn interface - see Conn for documentation. -// Read implements the net.Conn Read method. -func (c *IPConn) Read(b []byte) (n int, err os.Error) { - n, _, err = c.ReadFrom(b) - return +// Read implements the Conn Read method. +func (c *IPConn) Read(b []byte) (int, error) { + n, _, err := c.ReadFrom(b) + return n, err } -// Write implements the net.Conn Write method. -func (c *IPConn) Write(b []byte) (n int, err os.Error) { +// Write implements the Conn Write method. +func (c *IPConn) Write(b []byte) (int, error) { if !c.ok() { return 0, os.EINVAL } @@ -74,7 +72,7 @@ func (c *IPConn) Write(b []byte) (n int, err os.Error) { } // Close closes the IP connection. -func (c *IPConn) Close() os.Error { +func (c *IPConn) Close() error { if !c.ok() { return os.EINVAL } @@ -99,33 +97,33 @@ func (c *IPConn) RemoteAddr() Addr { return c.fd.raddr } -// SetTimeout implements the net.Conn SetTimeout method. -func (c *IPConn) SetTimeout(nsec int64) os.Error { +// SetDeadline implements the Conn SetDeadline method. +func (c *IPConn) SetDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setTimeout(c.fd, nsec) + return setDeadline(c.fd, t) } -// SetReadTimeout implements the net.Conn SetReadTimeout method. -func (c *IPConn) SetReadTimeout(nsec int64) os.Error { +// SetReadDeadline implements the Conn SetReadDeadline method. +func (c *IPConn) SetReadDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setReadTimeout(c.fd, nsec) + return setReadDeadline(c.fd, t) } -// SetWriteTimeout implements the net.Conn SetWriteTimeout method. -func (c *IPConn) SetWriteTimeout(nsec int64) os.Error { +// SetWriteDeadline implements the Conn SetWriteDeadline method. +func (c *IPConn) SetWriteDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setWriteTimeout(c.fd, nsec) + return setWriteDeadline(c.fd, t) } // SetReadBuffer sets the size of the operating system's // receive buffer associated with the connection. -func (c *IPConn) SetReadBuffer(bytes int) os.Error { +func (c *IPConn) SetReadBuffer(bytes int) error { if !c.ok() { return os.EINVAL } @@ -134,7 +132,7 @@ func (c *IPConn) SetReadBuffer(bytes int) os.Error { // SetWriteBuffer sets the size of the operating system's // transmit buffer associated with the connection. -func (c *IPConn) SetWriteBuffer(bytes int) os.Error { +func (c *IPConn) SetWriteBuffer(bytes int) error { if !c.ok() { return os.EINVAL } @@ -148,14 +146,15 @@ func (c *IPConn) SetWriteBuffer(bytes int) os.Error { // that was on the packet. // // ReadFromIP can be made to time out and return an error with -// Timeout() == true after a fixed time limit; see SetTimeout and -// SetReadTimeout. -func (c *IPConn) ReadFromIP(b []byte) (n int, addr *IPAddr, err os.Error) { +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetReadDeadline. +func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) { if !c.ok() { return 0, nil, os.EINVAL } // TODO(cw,rsc): consider using readv if we know the family // type to avoid the header trim/copy + var addr *IPAddr n, sa, err := c.fd.ReadFrom(b) switch sa := sa.(type) { case *syscall.SockaddrInet4: @@ -168,11 +167,11 @@ func (c *IPConn) ReadFromIP(b []byte) (n int, addr *IPAddr, err os.Error) { case *syscall.SockaddrInet6: addr = &IPAddr{sa.Addr[0:]} } - return + return n, addr, err } -// ReadFrom implements the net.PacketConn ReadFrom method. -func (c *IPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { +// ReadFrom implements the PacketConn ReadFrom method. +func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { if !c.ok() { return 0, nil, os.EINVAL } @@ -184,81 +183,37 @@ func (c *IPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { // // WriteToIP can be made to time out and return // an error with Timeout() == true after a fixed time limit; -// see SetTimeout and SetWriteTimeout. +// see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. -func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (n int, err os.Error) { +func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { if !c.ok() { return 0, os.EINVAL } - sa, err1 := addr.sockaddr(c.fd.family) - if err1 != nil { - return 0, &OpError{Op: "write", Net: "ip", Addr: addr, Error: err1} + sa, err := addr.sockaddr(c.fd.family) + if err != nil { + return 0, &OpError{"write", c.fd.net, addr, err} } return c.fd.WriteTo(b, sa) } -// WriteTo implements the net.PacketConn WriteTo method. -func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { +// WriteTo implements the PacketConn WriteTo method. +func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) { if !c.ok() { return 0, os.EINVAL } a, ok := addr.(*IPAddr) if !ok { - return 0, &OpError{"writeto", "ip", addr, os.EINVAL} + return 0, &OpError{"write", c.fd.net, addr, os.EINVAL} } return c.WriteToIP(b, a) } -var protocols map[string]int - -func readProtocols() { - protocols = make(map[string]int) - if file, err := open("/etc/protocols"); err == nil { - for line, ok := file.readLine(); ok; line, ok = file.readLine() { - // tcp 6 TCP # transmission control protocol - if i := byteIndex(line, '#'); i >= 0 { - line = line[0:i] - } - f := getFields(line) - if len(f) < 2 { - continue - } - if proto, _, ok := dtoi(f[1], 0); ok { - protocols[f[0]] = proto - for _, alias := range f[2:] { - protocols[alias] = proto - } - } - } - file.close() - } -} - -func splitNetProto(netProto string) (net string, proto int, err os.Error) { - onceReadProtocols.Do(readProtocols) - i := last(netProto, ':') - if i < 0 { // no colon - return "", 0, os.NewError("no IP protocol specified") - } - net = netProto[0:i] - protostr := netProto[i+1:] - proto, i, ok := dtoi(protostr, 0) - if !ok || i != len(protostr) { - // lookup by name - proto, ok = protocols[protostr] - if ok { - return - } - } - return -} - -// DialIP connects to the remote address raddr on the network net, -// which must be "ip", "ip4", or "ip6". -func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) { - net, proto, err := splitNetProto(netProto) +// DialIP connects to the remote address raddr on the network protocol netProto, +// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name. +func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { + net, proto, err := parseDialNetwork(netProto) if err != nil { - return + return nil, err } switch net { case "ip", "ip4", "ip6": @@ -266,11 +221,11 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) { return nil, UnknownNetworkError(net) } if raddr == nil { - return nil, &OpError{"dial", "ip", nil, errMissingAddress} + return nil, &OpError{"dial", netProto, nil, errMissingAddress} } - fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_RAW, proto, "dial", sockaddrToIP) - if e != nil { - return nil, e + fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_RAW, proto, "dial", sockaddrToIP) + if err != nil { + return nil, err } return newIPConn(fd), nil } @@ -279,29 +234,24 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) { // local address laddr. The returned connection c's ReadFrom // and WriteTo methods can be used to receive and send IP // packets with per-packet addressing. -func ListenIP(netProto string, laddr *IPAddr) (c *IPConn, err os.Error) { - net, proto, err := splitNetProto(netProto) +func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { + net, proto, err := parseDialNetwork(netProto) if err != nil { - return + return nil, err } switch net { case "ip", "ip4", "ip6": default: return nil, UnknownNetworkError(net) } - fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_RAW, proto, "dial", sockaddrToIP) - if e != nil { - return nil, e + fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_RAW, proto, "listen", sockaddrToIP) + if err != nil { + return nil, err } return newIPConn(fd), nil } -// BindToDevice binds an IPConn to a network interface. -func (c *IPConn) BindToDevice(device string) os.Error { - if !c.ok() { - return os.EINVAL - } - c.fd.incref() - defer c.fd.decref() - return os.NewSyscallError("setsockopt", syscall.BindToDevice(c.fd.sysfd, device)) -} +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *IPConn) File() (f *os.File, err error) { return c.fd.dup() } diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go index 4e2a5622b..9234f5aff 100644 --- a/src/pkg/net/ipsock.go +++ b/src/pkg/net/ipsock.go @@ -6,14 +6,10 @@ package net -import ( - "os" -) - var supportsIPv6, supportsIPv4map = probeIPv6Stack() func firstFavoriteAddr(filter func(IP) IP, addrs []string) (addr IP) { - if filter == anyaddr { + if filter == nil { // We'll take any IP address, but since the dialing code // does not yet try multiple addresses, prefer to use // an IPv4 address if possible. This is especially relevant @@ -61,14 +57,14 @@ func ipv6only(x IP) IP { type InvalidAddrError string -func (e InvalidAddrError) String() string { return string(e) } +func (e InvalidAddrError) Error() string { return string(e) } func (e InvalidAddrError) Timeout() bool { return false } func (e InvalidAddrError) Temporary() bool { return false } // SplitHostPort splits a network address of the form // "host:port" or "[host]:port" into host and port. // The latter form must be used when host contains a colon. -func SplitHostPort(hostport string) (host, port string, err os.Error) { +func SplitHostPort(hostport string) (host, port string, err error) { // The port starts after the last colon. i := last(hostport, ':') if i < 0 { @@ -102,7 +98,7 @@ func JoinHostPort(host, port string) string { } // Convert "host:port" into IP address and port. -func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) { +func hostPortToIP(net, hostport string) (ip IP, iport int, err error) { var ( addr IP p, i int @@ -117,7 +113,7 @@ func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) { // Try as an IP address. addr = ParseIP(host) if addr == nil { - filter := anyaddr + var filter func(IP) IP if net != "" && net[len(net)-1] == '4' { filter = ipv4only } diff --git a/src/pkg/net/ipsock_plan9.go b/src/pkg/net/ipsock_plan9.go index 9e5da6d38..09d8d6b4e 100644 --- a/src/pkg/net/ipsock_plan9.go +++ b/src/pkg/net/ipsock_plan9.go @@ -7,7 +7,10 @@ package net import ( + "errors" + "io" "os" + "time" ) // probeIPv6Stack returns two boolean values. If the first boolean value is @@ -18,7 +21,7 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { } // parsePlan9Addr parses address of the form [ip!]port (e.g. 127.0.0.1!80). -func parsePlan9Addr(s string) (ip IP, iport int, err os.Error) { +func parsePlan9Addr(s string) (ip IP, iport int, err error) { var ( addr IP p, i int @@ -29,13 +32,13 @@ func parsePlan9Addr(s string) (ip IP, iport int, err os.Error) { if i >= 0 { addr = ParseIP(s[:i]) if addr == nil { - err = os.NewError("net: parsing IP failed") + err = errors.New("net: parsing IP failed") goto Error } } p, _, ok = dtoi(s[i+1:], 0) if !ok { - err = os.NewError("net: parsing port failed") + err = errors.New("net: parsing port failed") goto Error } if p < 0 || p > 0xFFFF { @@ -48,7 +51,7 @@ Error: return nil, 0, err } -func readPlan9Addr(proto, filename string) (addr Addr, err os.Error) { +func readPlan9Addr(proto, filename string) (addr Addr, err error) { var buf [128]byte f, err := os.Open(filename) @@ -69,7 +72,7 @@ func readPlan9Addr(proto, filename string) (addr Addr, err os.Error) { case "udp": addr = &UDPAddr{ip, port} default: - return nil, os.NewError("unknown protocol " + proto) + return nil, errors.New("unknown protocol " + proto) } return addr, nil } @@ -89,7 +92,7 @@ func (c *plan9Conn) ok() bool { return c != nil && c.ctl != nil } // Implementation of the Conn interface - see Conn for documentation. // Read implements the net.Conn Read method. -func (c *plan9Conn) Read(b []byte) (n int, err os.Error) { +func (c *plan9Conn) Read(b []byte) (n int, err error) { if !c.ok() { return 0, os.EINVAL } @@ -100,7 +103,7 @@ func (c *plan9Conn) Read(b []byte) (n int, err os.Error) { } } n, err = c.data.Read(b) - if c.proto == "udp" && err == os.EOF { + if c.proto == "udp" && err == io.EOF { n = 0 err = nil } @@ -108,7 +111,7 @@ func (c *plan9Conn) Read(b []byte) (n int, err os.Error) { } // Write implements the net.Conn Write method. -func (c *plan9Conn) Write(b []byte) (n int, err os.Error) { +func (c *plan9Conn) Write(b []byte) (n int, err error) { if !c.ok() { return 0, os.EINVAL } @@ -122,7 +125,7 @@ func (c *plan9Conn) Write(b []byte) (n int, err os.Error) { } // Close closes the connection. -func (c *plan9Conn) Close() os.Error { +func (c *plan9Conn) Close() error { if !c.ok() { return os.EINVAL } @@ -154,31 +157,22 @@ func (c *plan9Conn) RemoteAddr() Addr { return c.raddr } -// SetTimeout implements the net.Conn SetTimeout method. -func (c *plan9Conn) SetTimeout(nsec int64) os.Error { - if !c.ok() { - return os.EINVAL - } +// SetDeadline implements the net.Conn SetDeadline method. +func (c *plan9Conn) SetDeadline(t time.Time) error { return os.EPLAN9 } -// SetReadTimeout implements the net.Conn SetReadTimeout method. -func (c *plan9Conn) SetReadTimeout(nsec int64) os.Error { - if !c.ok() { - return os.EINVAL - } +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (c *plan9Conn) SetReadDeadline(t time.Time) error { return os.EPLAN9 } -// SetWriteTimeout implements the net.Conn SetWriteTimeout method. -func (c *plan9Conn) SetWriteTimeout(nsec int64) os.Error { - if !c.ok() { - return os.EINVAL - } +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (c *plan9Conn) SetWriteDeadline(t time.Time) error { return os.EPLAN9 } -func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, err os.Error) { +func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) { var ( ip IP port int @@ -213,7 +207,7 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, return f, dest, proto, string(buf[:n]), nil } -func dialPlan9(net string, laddr, raddr Addr) (c *plan9Conn, err os.Error) { +func dialPlan9(net string, laddr, raddr Addr) (c *plan9Conn, err error) { f, dest, proto, name, err := startPlan9(net, raddr) if err != nil { return @@ -239,7 +233,7 @@ type plan9Listener struct { laddr Addr } -func listenPlan9(net string, laddr Addr) (l *plan9Listener, err os.Error) { +func listenPlan9(net string, laddr Addr) (l *plan9Listener, err error) { f, dest, proto, name, err := startPlan9(net, laddr) if err != nil { return @@ -265,7 +259,7 @@ func (l *plan9Listener) plan9Conn() *plan9Conn { return newPlan9Conn(l.proto, l.name, l.ctl, l.laddr, nil) } -func (l *plan9Listener) acceptPlan9() (c *plan9Conn, err os.Error) { +func (l *plan9Listener) acceptPlan9() (c *plan9Conn, err error) { f, err := os.Open(l.dir + "/listen") if err != nil { return @@ -287,7 +281,7 @@ func (l *plan9Listener) acceptPlan9() (c *plan9Conn, err os.Error) { return newPlan9Conn(l.proto, name, f, laddr, raddr), nil } -func (l *plan9Listener) Accept() (c Conn, err os.Error) { +func (l *plan9Listener) Accept() (c Conn, err error) { c1, err := l.acceptPlan9() if err != nil { return @@ -295,7 +289,7 @@ func (l *plan9Listener) Accept() (c Conn, err os.Error) { return c1, nil } -func (l *plan9Listener) Close() os.Error { +func (l *plan9Listener) Close() error { if l == nil || l.ctl == nil { return os.EINVAL } diff --git a/src/pkg/net/ipsock_posix.go b/src/pkg/net/ipsock_posix.go index 049df9ea4..3a059f516 100644 --- a/src/pkg/net/ipsock_posix.go +++ b/src/pkg/net/ipsock_posix.go @@ -2,14 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd windows +// +build darwin freebsd linux netbsd openbsd windows package net -import ( - "os" - "syscall" -) +import "syscall" // Should we try to use the IPv4 socket interface if we're // only dealing with IPv4 sockets? As long as the host system @@ -36,8 +33,8 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { } for i := range probes { - s, errno := syscall.Socket(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) - if errno != 0 { + s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + if err != nil { continue } defer closesocket(s) @@ -45,8 +42,8 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { if err != nil { continue } - errno = syscall.Bind(s, sa) - if errno != 0 { + err = syscall.Bind(s, sa) + if err != nil { continue } probes[i].ok = true @@ -94,23 +91,18 @@ func favoriteAddrFamily(net string, raddr, laddr sockaddr, mode string) int { return syscall.AF_INET6 } -// TODO(rsc): if syscall.OS == "linux", we're supposed to read -// /proc/sys/net/core/somaxconn, -// to take advantage of kernels that have raised the limit. -func listenBacklog() int { return syscall.SOMAXCONN } - // Internet sockets (TCP, UDP) // A sockaddr represents a TCP or UDP network address that can // be converted into a syscall.Sockaddr. type sockaddr interface { Addr - sockaddr(family int) (syscall.Sockaddr, os.Error) + sockaddr(family int) (syscall.Sockaddr, error) family() int } -func internetSocket(net string, laddr, raddr sockaddr, socktype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err os.Error) { - var oserr os.Error +func internetSocket(net string, laddr, raddr sockaddr, sotype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { + var oserr error var la, ra syscall.Sockaddr family := favoriteAddrFamily(net, raddr, laddr, mode) if laddr != nil { @@ -123,7 +115,7 @@ func internetSocket(net string, laddr, raddr sockaddr, socktype, proto int, mode goto Error } } - fd, oserr = socket(net, family, socktype, proto, la, ra, toAddr) + fd, oserr = socket(net, family, sotype, proto, la, ra, toAddr) if oserr != nil { goto Error } @@ -137,7 +129,7 @@ Error: return nil, &OpError{mode, net, addr, oserr} } -func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, os.Error) { +func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, error) { switch family { case syscall.AF_INET: if len(ip) == 0 { diff --git a/src/pkg/net/lookup_plan9.go b/src/pkg/net/lookup_plan9.go index ee0c9e879..c0bb9225a 100644 --- a/src/pkg/net/lookup_plan9.go +++ b/src/pkg/net/lookup_plan9.go @@ -5,10 +5,11 @@ package net import ( + "errors" "os" ) -func query(filename, query string, bufSize int) (res []string, err os.Error) { +func query(filename, query string, bufSize int) (res []string, err error) { file, err := os.OpenFile(filename, os.O_RDWR, 0) if err != nil { return @@ -34,7 +35,7 @@ func query(filename, query string, bufSize int) (res []string, err os.Error) { return } -func queryCS(net, host, service string) (res []string, err os.Error) { +func queryCS(net, host, service string) (res []string, err error) { switch net { case "tcp4", "tcp6": net = "tcp" @@ -47,9 +48,9 @@ func queryCS(net, host, service string) (res []string, err os.Error) { return query("/net/cs", net+"!"+host+"!"+service, 128) } -func queryCS1(net string, ip IP, port int) (clone, dest string, err os.Error) { +func queryCS1(net string, ip IP, port int) (clone, dest string, err error) { ips := "*" - if !ip.IsUnspecified() { + if len(ip) != 0 && !ip.IsUnspecified() { ips = ip.String() } lines, err := queryCS(net, ips, itoa(port)) @@ -58,19 +59,22 @@ func queryCS1(net string, ip IP, port int) (clone, dest string, err os.Error) { } f := getFields(lines[0]) if len(f) < 2 { - return "", "", os.NewError("net: bad response from ndb/cs") + return "", "", errors.New("net: bad response from ndb/cs") } clone, dest = f[0], f[1] return } -func queryDNS(addr string, typ string) (res []string, err os.Error) { +func queryDNS(addr string, typ string) (res []string, err error) { return query("/net/dns", addr+" "+typ, 1024) } -// LookupHost looks up the given host using the local resolver. -// It returns an array of that host's addresses. -func LookupHost(host string) (addrs []string, err os.Error) { +func lookupProtocol(name string) (proto int, err error) { + // TODO: Implement this + return 0, os.EPLAN9 +} + +func lookupHost(host string) (addrs []string, err error) { // Use /net/cs insead of /net/dns because cs knows about // host names in local network (e.g. from /lib/ndb/local) lines, err := queryCS("tcp", host, "1") @@ -94,9 +98,7 @@ func LookupHost(host string) (addrs []string, err os.Error) { return } -// LookupIP looks up host using the local resolver. -// It returns an array of that host's IPv4 and IPv6 addresses. -func LookupIP(host string) (ips []IP, err os.Error) { +func lookupIP(host string) (ips []IP, err error) { addrs, err := LookupHost(host) if err != nil { return @@ -109,8 +111,7 @@ func LookupIP(host string) (ips []IP, err os.Error) { return } -// LookupPort looks up the port for the given network and service. -func LookupPort(network, service string) (port int, err os.Error) { +func lookupPort(network, service string) (port int, err error) { switch network { case "tcp4", "tcp6": network = "tcp" @@ -139,11 +140,7 @@ func LookupPort(network, service string) (port int, err os.Error) { return 0, unknownPortError } -// LookupCNAME returns the canonical DNS host for the given name. -// Callers that do not care about the canonical name can call -// LookupHost or LookupIP directly; both take care of resolving -// the canonical name as part of the lookup. -func LookupCNAME(name string) (cname string, err os.Error) { +func lookupCNAME(name string) (cname string, err error) { lines, err := queryDNS(name, "cname") if err != nil { return @@ -153,16 +150,16 @@ func LookupCNAME(name string) (cname string, err os.Error) { return f[2] + ".", nil } } - return "", os.NewError("net: bad response from ndb/dns") + return "", errors.New("net: bad response from ndb/dns") } -// LookupSRV tries to resolve an SRV query of the given service, -// protocol, and domain name, as specified in RFC 2782. In most cases -// the proto argument can be the same as the corresponding -// Addr.Network(). The returned records are sorted by priority -// and randomized by weight within a priority. -func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) { - target := "_" + service + "._" + proto + "." + name +func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { + var target string + if service == "" && proto == "" { + target = name + } else { + target = "_" + service + "._" + proto + "." + name + } lines, err := queryDNS(target, "srv") if err != nil { return @@ -185,8 +182,7 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os. return } -// LookupMX returns the DNS MX records for the given domain name sorted by preference. -func LookupMX(name string) (mx []*MX, err os.Error) { +func lookupMX(name string) (mx []*MX, err error) { lines, err := queryDNS(name, "mx") if err != nil { return @@ -204,14 +200,20 @@ func LookupMX(name string) (mx []*MX, err os.Error) { return } -// LookupTXT returns the DNS TXT records for the given domain name. -func LookupTXT(name string) (txt []string, err os.Error) { - return nil, os.NewError("net.LookupTXT is not implemented on Plan 9") +func lookupTXT(name string) (txt []string, err error) { + lines, err := queryDNS(name, "txt") + if err != nil { + return + } + for _, line := range lines { + if i := byteIndex(line, '\t'); i >= 0 { + txt = append(txt, line[i+1:]) + } + } + return } -// LookupAddr performs a reverse lookup for the given address, returning a list -// of names mapping to that address. -func LookupAddr(addr string) (name []string, err os.Error) { +func lookupAddr(addr string) (name []string, err error) { arpa, err := reverseaddr(addr) if err != nil { return diff --git a/src/pkg/net/lookup_test.go b/src/pkg/net/lookup_test.go index 41066fe48..9a39ca8a1 100644 --- a/src/pkg/net/lookup_test.go +++ b/src/pkg/net/lookup_test.go @@ -26,6 +26,15 @@ func TestGoogleSRV(t *testing.T) { if len(addrs) == 0 { t.Errorf("no results") } + + // Non-standard back door. + _, addrs, err = LookupSRV("", "", "_xmpp-server._tcp.google.com") + if err != nil { + t.Errorf("back door failed: %s", err) + } + if len(addrs) == 0 { + t.Errorf("back door no results") + } } func TestGmailMX(t *testing.T) { @@ -43,10 +52,6 @@ func TestGmailMX(t *testing.T) { } func TestGmailTXT(t *testing.T) { - if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { - t.Logf("LookupTXT is not implemented on Windows or Plan 9") - return - } if testing.Short() || avoidMacFirewall { t.Logf("skipping test to avoid external network") return diff --git a/src/pkg/net/lookup_unix.go b/src/pkg/net/lookup_unix.go index 7368b751e..d500a1240 100644 --- a/src/pkg/net/lookup_unix.go +++ b/src/pkg/net/lookup_unix.go @@ -2,17 +2,57 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd +// +build darwin freebsd linux netbsd openbsd package net import ( - "os" + "errors" + "sync" ) -// LookupHost looks up the given host using the local resolver. -// It returns an array of that host's addresses. -func LookupHost(host string) (addrs []string, err os.Error) { +var ( + protocols map[string]int + onceReadProtocols sync.Once +) + +// readProtocols loads contents of /etc/protocols into protocols map +// for quick access. +func readProtocols() { + protocols = make(map[string]int) + if file, err := open("/etc/protocols"); err == nil { + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + // tcp 6 TCP # transmission control protocol + if i := byteIndex(line, '#'); i >= 0 { + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 { + continue + } + if proto, _, ok := dtoi(f[1], 0); ok { + protocols[f[0]] = proto + for _, alias := range f[2:] { + protocols[alias] = proto + } + } + } + file.close() + } +} + +// lookupProtocol looks up IP protocol name in /etc/protocols and +// returns correspondent protocol number. +func lookupProtocol(name string) (proto int, err error) { + onceReadProtocols.Do(readProtocols) + proto, found := protocols[name] + if !found { + return 0, errors.New("unknown IP protocol specified: " + name) + } + return +} + +func lookupHost(host string) (addrs []string, err error) { addrs, err, ok := cgoLookupHost(host) if !ok { addrs, err = goLookupHost(host) @@ -20,9 +60,7 @@ func LookupHost(host string) (addrs []string, err os.Error) { return } -// LookupIP looks up host using the local resolver. -// It returns an array of that host's IPv4 and IPv6 addresses. -func LookupIP(host string) (addrs []IP, err os.Error) { +func lookupIP(host string) (addrs []IP, err error) { addrs, err, ok := cgoLookupIP(host) if !ok { addrs, err = goLookupIP(host) @@ -30,8 +68,7 @@ func LookupIP(host string) (addrs []IP, err os.Error) { return } -// LookupPort looks up the port for the given network and service. -func LookupPort(network, service string) (port int, err os.Error) { +func lookupPort(network, service string) (port int, err error) { port, err, ok := cgoLookupPort(network, service) if !ok { port, err = goLookupPort(network, service) @@ -39,11 +76,7 @@ func LookupPort(network, service string) (port int, err os.Error) { return } -// LookupCNAME returns the canonical DNS host for the given name. -// Callers that do not care about the canonical name can call -// LookupHost or LookupIP directly; both take care of resolving -// the canonical name as part of the lookup. -func LookupCNAME(name string) (cname string, err os.Error) { +func lookupCNAME(name string) (cname string, err error) { cname, err, ok := cgoLookupCNAME(name) if !ok { cname, err = goLookupCNAME(name) @@ -51,13 +84,13 @@ func LookupCNAME(name string) (cname string, err os.Error) { return } -// LookupSRV tries to resolve an SRV query of the given service, -// protocol, and domain name, as specified in RFC 2782. In most cases -// the proto argument can be the same as the corresponding -// Addr.Network(). The returned records are sorted by priority -// and randomized by weight within a priority. -func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) { - target := "_" + service + "._" + proto + "." + name +func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { + var target string + if service == "" && proto == "" { + target = name + } else { + target = "_" + service + "._" + proto + "." + name + } var records []dnsRR cname, records, err = lookup(target, dnsTypeSRV) if err != nil { @@ -72,8 +105,7 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os. return } -// LookupMX returns the DNS MX records for the given domain name sorted by preference. -func LookupMX(name string) (mx []*MX, err os.Error) { +func lookupMX(name string) (mx []*MX, err error) { _, records, err := lookup(name, dnsTypeMX) if err != nil { return @@ -87,8 +119,7 @@ func LookupMX(name string) (mx []*MX, err os.Error) { return } -// LookupTXT returns the DNS TXT records for the given domain name. -func LookupTXT(name string) (txt []string, err os.Error) { +func lookupTXT(name string) (txt []string, err error) { _, records, err := lookup(name, dnsTypeTXT) if err != nil { return @@ -100,9 +131,7 @@ func LookupTXT(name string) (txt []string, err os.Error) { return } -// LookupAddr performs a reverse lookup for the given address, returning a list -// of names mapping to that address. -func LookupAddr(addr string) (name []string, err os.Error) { +func lookupAddr(addr string) (name []string, err error) { name = lookupStaticAddr(addr) if len(name) > 0 { return diff --git a/src/pkg/net/lookup_windows.go b/src/pkg/net/lookup_windows.go index b33c7f949..dfe2ff6f1 100644 --- a/src/pkg/net/lookup_windows.go +++ b/src/pkg/net/lookup_windows.go @@ -5,16 +5,30 @@ package net import ( - "syscall" - "unsafe" "os" "sync" + "syscall" + "unsafe" ) -var hostentLock sync.Mutex -var serventLock sync.Mutex +var ( + protoentLock sync.Mutex + hostentLock sync.Mutex + serventLock sync.Mutex +) + +// lookupProtocol looks up IP protocol name and returns correspondent protocol number. +func lookupProtocol(name string) (proto int, err error) { + protoentLock.Lock() + defer protoentLock.Unlock() + p, e := syscall.GetProtoByName(name) + if e != nil { + return 0, os.NewSyscallError("GetProtoByName", e) + } + return int(p.Proto), nil +} -func LookupHost(name string) (addrs []string, err os.Error) { +func lookupHost(name string) (addrs []string, err error) { ips, err := LookupIP(name) if err != nil { return @@ -26,11 +40,11 @@ func LookupHost(name string) (addrs []string, err os.Error) { return } -func LookupIP(name string) (addrs []IP, err os.Error) { +func lookupIP(name string) (addrs []IP, err error) { hostentLock.Lock() defer hostentLock.Unlock() h, e := syscall.GetHostByName(name) - if e != 0 { + if e != nil { return nil, os.NewSyscallError("GetHostByName", e) } switch h.AddrType { @@ -47,7 +61,7 @@ func LookupIP(name string) (addrs []IP, err os.Error) { return addrs, nil } -func LookupPort(network, service string) (port int, err os.Error) { +func lookupPort(network, service string) (port int, err error) { switch network { case "tcp4", "tcp6": network = "tcp" @@ -57,17 +71,17 @@ func LookupPort(network, service string) (port int, err os.Error) { serventLock.Lock() defer serventLock.Unlock() s, e := syscall.GetServByName(service, network) - if e != 0 { + if e != nil { return 0, os.NewSyscallError("GetServByName", e) } return int(syscall.Ntohs(s.Port)), nil } -func LookupCNAME(name string) (cname string, err os.Error) { +func lookupCNAME(name string) (cname string, err error) { var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) - if int(e) != 0 { - return "", os.NewSyscallError("LookupCNAME", int(e)) + if e != nil { + return "", os.NewSyscallError("LookupCNAME", e) } defer syscall.DnsRecordListFree(r, 1) if r != nil && r.Type == syscall.DNS_TYPE_CNAME { @@ -77,12 +91,17 @@ func LookupCNAME(name string) (cname string, err os.Error) { return } -func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) { +func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { + var target string + if service == "" && proto == "" { + target = name + } else { + target = "_" + service + "._" + proto + "." + name + } var r *syscall.DNSRecord - target := "_" + service + "._" + proto + "." + name e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil) - if int(e) != 0 { - return "", nil, os.NewSyscallError("LookupSRV", int(e)) + if e != nil { + return "", nil, os.NewSyscallError("LookupSRV", e) } defer syscall.DnsRecordListFree(r, 1) addrs = make([]*SRV, 0, 10) @@ -94,11 +113,11 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os. return name, addrs, nil } -func LookupMX(name string) (mx []*MX, err os.Error) { +func lookupMX(name string) (mx []*MX, err error) { var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil) - if int(e) != 0 { - return nil, os.NewSyscallError("LookupMX", int(e)) + if e != nil { + return nil, os.NewSyscallError("LookupMX", e) } defer syscall.DnsRecordListFree(r, 1) mx = make([]*MX, 0, 10) @@ -110,19 +129,33 @@ func LookupMX(name string) (mx []*MX, err os.Error) { return mx, nil } -func LookupTXT(name string) (txt []string, err os.Error) { - return nil, os.NewError("net.LookupTXT is not implemented on Windows") +func lookupTXT(name string) (txt []string, err error) { + var r *syscall.DNSRecord + e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil) + if e != nil { + return nil, os.NewSyscallError("LookupTXT", e) + } + defer syscall.DnsRecordListFree(r, 1) + txt = make([]string, 0, 10) + if r != nil && r.Type == syscall.DNS_TYPE_TEXT { + d := (*syscall.DNSTXTData)(unsafe.Pointer(&r.Data[0])) + for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] { + s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:]) + txt = append(txt, s) + } + } + return } -func LookupAddr(addr string) (name []string, err os.Error) { +func lookupAddr(addr string) (name []string, err error) { arpa, err := reverseaddr(addr) if err != nil { return nil, err } var r *syscall.DNSRecord e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil) - if int(e) != 0 { - return nil, os.NewSyscallError("LookupAddr", int(e)) + if e != nil { + return nil, os.NewSyscallError("LookupAddr", e) } defer syscall.DnsRecordListFree(r, 1) name = make([]string, 0, 10) diff --git a/src/pkg/net/mail/Makefile b/src/pkg/net/mail/Makefile new file mode 100644 index 000000000..acb1c2a6d --- /dev/null +++ b/src/pkg/net/mail/Makefile @@ -0,0 +1,11 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=net/mail +GOFILES=\ + message.go\ + +include ../../../Make.pkg diff --git a/src/pkg/net/mail/message.go b/src/pkg/net/mail/message.go new file mode 100644 index 000000000..bf22c711e --- /dev/null +++ b/src/pkg/net/mail/message.go @@ -0,0 +1,524 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package mail implements parsing of mail messages. + +For the most part, this package follows the syntax as specified by RFC 5322. +Notable divergences: + * Obsolete address formats are not parsed, including addresses with + embedded route information. + * Group addresses are not parsed. + * The full range of spacing (the CFWS syntax element) is not supported, + such as breaking addresses across lines. +*/ +package mail + +import ( + "bufio" + "bytes" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net/textproto" + "strconv" + "strings" + "time" +) + +var debug = debugT(false) + +type debugT bool + +func (d debugT) Printf(format string, args ...interface{}) { + if d { + log.Printf(format, args...) + } +} + +// A Message represents a parsed mail message. +type Message struct { + Header Header + Body io.Reader +} + +// ReadMessage reads a message from r. +// The headers are parsed, and the body of the message will be reading from r. +func ReadMessage(r io.Reader) (msg *Message, err error) { + tp := textproto.NewReader(bufio.NewReader(r)) + + hdr, err := tp.ReadMIMEHeader() + if err != nil { + return nil, err + } + + return &Message{ + Header: Header(hdr), + Body: tp.R, + }, nil +} + +// Layouts suitable for passing to time.Parse. +// These are tried in order. +var dateLayouts []string + +func init() { + // Generate layouts based on RFC 5322, section 3.3. + + dows := [...]string{"", "Mon, "} // day-of-week + days := [...]string{"2", "02"} // day = 1*2DIGIT + years := [...]string{"2006", "06"} // year = 4*DIGIT / 2*DIGIT + seconds := [...]string{":05", ""} // second + zones := [...]string{"-0700", "MST"} // zone = (("+" / "-") 4DIGIT) / "GMT" / ... + + for _, dow := range dows { + for _, day := range days { + for _, year := range years { + for _, second := range seconds { + for _, zone := range zones { + s := dow + day + " Jan " + year + " 15:04" + second + " " + zone + dateLayouts = append(dateLayouts, s) + } + } + } + } + } +} + +func parseDate(date string) (time.Time, error) { + for _, layout := range dateLayouts { + t, err := time.Parse(layout, date) + if err == nil { + return t, nil + } + } + return time.Time{}, errors.New("mail: header could not be parsed") +} + +// A Header represents the key-value pairs in a mail message header. +type Header map[string][]string + +// Get gets the first value associated with the given key. +// If there are no values associated with the key, Get returns "". +func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +var ErrHeaderNotPresent = errors.New("mail: header not in message") + +// Date parses the Date header field. +func (h Header) Date() (time.Time, error) { + hdr := h.Get("Date") + if hdr == "" { + return time.Time{}, ErrHeaderNotPresent + } + return parseDate(hdr) +} + +// AddressList parses the named header field as a list of addresses. +func (h Header) AddressList(key string) ([]*Address, error) { + hdr := h.Get(key) + if hdr == "" { + return nil, ErrHeaderNotPresent + } + return newAddrParser(hdr).parseAddressList() +} + +// Address represents a single mail address. +// An address such as "Barry Gibbs <bg@example.com>" is represented +// as Address{Name: "Barry Gibbs", Address: "bg@example.com"}. +type Address struct { + Name string // Proper name; may be empty. + Address string // user@domain +} + +// String formats the address as a valid RFC 5322 address. +// If the address's name contains non-ASCII characters +// the name will be rendered according to RFC 2047. +func (a *Address) String() string { + s := "<" + a.Address + ">" + if a.Name == "" { + return s + } + // If every character is printable ASCII, quoting is simple. + allPrintable := true + for i := 0; i < len(a.Name); i++ { + if !isVchar(a.Name[i]) { + allPrintable = false + break + } + } + if allPrintable { + b := bytes.NewBufferString(`"`) + for i := 0; i < len(a.Name); i++ { + if !isQtext(a.Name[i]) { + b.WriteByte('\\') + } + b.WriteByte(a.Name[i]) + } + b.WriteString(`" `) + b.WriteString(s) + return b.String() + } + + // UTF-8 "Q" encoding + b := bytes.NewBufferString("=?utf-8?q?") + for i := 0; i < len(a.Name); i++ { + switch c := a.Name[i]; { + case c == ' ': + b.WriteByte('_') + case isVchar(c) && c != '=' && c != '?' && c != '_': + b.WriteByte(c) + default: + fmt.Fprintf(b, "=%02X", c) + } + } + b.WriteString("?= ") + b.WriteString(s) + return b.String() +} + +type addrParser []byte + +func newAddrParser(s string) *addrParser { + p := addrParser(s) + return &p +} + +func (p *addrParser) parseAddressList() ([]*Address, error) { + var list []*Address + for { + p.skipSpace() + addr, err := p.parseAddress() + if err != nil { + return nil, err + } + list = append(list, addr) + + p.skipSpace() + if p.empty() { + break + } + if !p.consume(',') { + return nil, errors.New("mail: expected comma") + } + } + return list, nil +} + +// parseAddress parses a single RFC 5322 address at the start of p. +func (p *addrParser) parseAddress() (addr *Address, err error) { + debug.Printf("parseAddress: %q", *p) + p.skipSpace() + if p.empty() { + return nil, errors.New("mail: no address") + } + + // address = name-addr / addr-spec + // TODO(dsymonds): Support parsing group address. + + // addr-spec has a more restricted grammar than name-addr, + // so try parsing it first, and fallback to name-addr. + // TODO(dsymonds): Is this really correct? + spec, err := p.consumeAddrSpec() + if err == nil { + return &Address{ + Address: spec, + }, err + } + debug.Printf("parseAddress: not an addr-spec: %v", err) + debug.Printf("parseAddress: state is now %q", *p) + + // display-name + var displayName string + if p.peek() != '<' { + displayName, err = p.consumePhrase() + if err != nil { + return nil, err + } + } + debug.Printf("parseAddress: displayName=%q", displayName) + + // angle-addr = "<" addr-spec ">" + p.skipSpace() + if !p.consume('<') { + return nil, errors.New("mail: no angle-addr") + } + spec, err = p.consumeAddrSpec() + if err != nil { + return nil, err + } + if !p.consume('>') { + return nil, errors.New("mail: unclosed angle-addr") + } + debug.Printf("parseAddress: spec=%q", spec) + + return &Address{ + Name: displayName, + Address: spec, + }, nil +} + +// consumeAddrSpec parses a single RFC 5322 addr-spec at the start of p. +func (p *addrParser) consumeAddrSpec() (spec string, err error) { + debug.Printf("consumeAddrSpec: %q", *p) + + orig := *p + defer func() { + if err != nil { + *p = orig + } + }() + + // local-part = dot-atom / quoted-string + var localPart string + p.skipSpace() + if p.empty() { + return "", errors.New("mail: no addr-spec") + } + if p.peek() == '"' { + // quoted-string + debug.Printf("consumeAddrSpec: parsing quoted-string") + localPart, err = p.consumeQuotedString() + } else { + // dot-atom + debug.Printf("consumeAddrSpec: parsing dot-atom") + localPart, err = p.consumeAtom(true) + } + if err != nil { + debug.Printf("consumeAddrSpec: failed: %v", err) + return "", err + } + + if !p.consume('@') { + return "", errors.New("mail: missing @ in addr-spec") + } + + // domain = dot-atom / domain-literal + var domain string + p.skipSpace() + if p.empty() { + return "", errors.New("mail: no domain in addr-spec") + } + // TODO(dsymonds): Handle domain-literal + domain, err = p.consumeAtom(true) + if err != nil { + return "", err + } + + return localPart + "@" + domain, nil +} + +// consumePhrase parses the RFC 5322 phrase at the start of p. +func (p *addrParser) consumePhrase() (phrase string, err error) { + debug.Printf("consumePhrase: [%s]", *p) + // phrase = 1*word + var words []string + for { + // word = atom / quoted-string + var word string + p.skipSpace() + if p.empty() { + return "", errors.New("mail: missing phrase") + } + if p.peek() == '"' { + // quoted-string + word, err = p.consumeQuotedString() + } else { + // atom + word, err = p.consumeAtom(false) + } + + // RFC 2047 encoded-word starts with =?, ends with ?=, and has two other ?s. + if err == nil && strings.HasPrefix(word, "=?") && strings.HasSuffix(word, "?=") && strings.Count(word, "?") == 4 { + word, err = decodeRFC2047Word(word) + } + + if err != nil { + break + } + debug.Printf("consumePhrase: consumed %q", word) + words = append(words, word) + } + // Ignore any error if we got at least one word. + if err != nil && len(words) == 0 { + debug.Printf("consumePhrase: hit err: %v", err) + return "", errors.New("mail: missing word in phrase") + } + phrase = strings.Join(words, " ") + return phrase, nil +} + +// consumeQuotedString parses the quoted string at the start of p. +func (p *addrParser) consumeQuotedString() (qs string, err error) { + // Assume first byte is '"'. + i := 1 + qsb := make([]byte, 0, 10) +Loop: + for { + if i >= p.len() { + return "", errors.New("mail: unclosed quoted-string") + } + switch c := (*p)[i]; { + case c == '"': + break Loop + case c == '\\': + if i+1 == p.len() { + return "", errors.New("mail: unclosed quoted-string") + } + qsb = append(qsb, (*p)[i+1]) + i += 2 + case isQtext(c), c == ' ' || c == '\t': + // qtext (printable US-ASCII excluding " and \), or + // FWS (almost; we're ignoring CRLF) + qsb = append(qsb, c) + i++ + default: + return "", fmt.Errorf("mail: bad character in quoted-string: %q", c) + } + } + *p = (*p)[i+1:] + return string(qsb), nil +} + +// consumeAtom parses an RFC 5322 atom at the start of p. +// If dot is true, consumeAtom parses an RFC 5322 dot-atom instead. +func (p *addrParser) consumeAtom(dot bool) (atom string, err error) { + if !isAtext(p.peek(), false) { + return "", errors.New("mail: invalid string") + } + i := 1 + for ; i < p.len() && isAtext((*p)[i], dot); i++ { + } + // TODO(dsymonds): Remove the []byte() conversion here when 6g doesn't need it. + atom, *p = string([]byte((*p)[:i])), (*p)[i:] + return atom, nil +} + +func (p *addrParser) consume(c byte) bool { + if p.empty() || p.peek() != c { + return false + } + *p = (*p)[1:] + return true +} + +// skipSpace skips the leading space and tab characters. +func (p *addrParser) skipSpace() { + *p = bytes.TrimLeft(*p, " \t") +} + +func (p *addrParser) peek() byte { + return (*p)[0] +} + +func (p *addrParser) empty() bool { + return p.len() == 0 +} + +func (p *addrParser) len() int { + return len(*p) +} + +func decodeRFC2047Word(s string) (string, error) { + fields := strings.Split(s, "?") + if len(fields) != 5 || fields[0] != "=" || fields[4] != "=" { + return "", errors.New("mail: address not RFC 2047 encoded") + } + charset, enc := strings.ToLower(fields[1]), strings.ToLower(fields[2]) + if charset != "iso-8859-1" && charset != "utf-8" { + return "", fmt.Errorf("mail: charset not supported: %q", charset) + } + + in := bytes.NewBufferString(fields[3]) + var r io.Reader + switch enc { + case "b": + r = base64.NewDecoder(base64.StdEncoding, in) + case "q": + r = qDecoder{r: in} + default: + return "", fmt.Errorf("mail: RFC 2047 encoding not supported: %q", enc) + } + + dec, err := ioutil.ReadAll(r) + if err != nil { + return "", err + } + + switch charset { + case "iso-8859-1": + b := new(bytes.Buffer) + for _, c := range dec { + b.WriteRune(rune(c)) + } + return b.String(), nil + case "utf-8": + return string(dec), nil + } + panic("unreachable") +} + +type qDecoder struct { + r io.Reader + scratch [2]byte +} + +func (qd qDecoder) Read(p []byte) (n int, err error) { + // This method writes at most one byte into p. + if len(p) == 0 { + return 0, nil + } + if _, err := qd.r.Read(qd.scratch[:1]); err != nil { + return 0, err + } + switch c := qd.scratch[0]; { + case c == '=': + if _, err := io.ReadFull(qd.r, qd.scratch[:2]); err != nil { + return 0, err + } + x, err := strconv.ParseInt(string(qd.scratch[:2]), 16, 64) + if err != nil { + return 0, fmt.Errorf("mail: invalid RFC 2047 encoding: %q", qd.scratch[:2]) + } + p[0] = byte(x) + case c == '_': + p[0] = ' ' + default: + p[0] = c + } + return 1, nil +} + +var atextChars = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + "abcdefghijklmnopqrstuvwxyz" + + "0123456789" + + "!#$%&'*+-/=?^_`{|}~") + +// isAtext returns true if c is an RFC 5322 atext character. +// If dot is true, period is included. +func isAtext(c byte, dot bool) bool { + if dot && c == '.' { + return true + } + return bytes.IndexByte(atextChars, c) >= 0 +} + +// isQtext returns true if c is an RFC 5322 qtest character. +func isQtext(c byte) bool { + // Printable US-ASCII, excluding backslash or quote. + if c == '\\' || c == '"' { + return false + } + return '!' <= c && c <= '~' +} + +// isVchar returns true if c is an RFC 5322 VCHAR character. +func isVchar(c byte) bool { + // Visible (printing) characters. + return '!' <= c && c <= '~' +} diff --git a/src/pkg/net/mail/message_test.go b/src/pkg/net/mail/message_test.go new file mode 100644 index 000000000..671ff2efa --- /dev/null +++ b/src/pkg/net/mail/message_test.go @@ -0,0 +1,261 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mail + +import ( + "bytes" + "io/ioutil" + "reflect" + "testing" + "time" +) + +var parseTests = []struct { + in string + header Header + body string +}{ + { + // RFC 5322, Appendix A.1.1 + in: `From: John Doe <jdoe@machine.example> +To: Mary Smith <mary@example.net> +Subject: Saying Hello +Date: Fri, 21 Nov 1997 09:55:06 -0600 +Message-ID: <1234@local.machine.example> + +This is a message just to say hello. +So, "Hello". +`, + header: Header{ + "From": []string{"John Doe <jdoe@machine.example>"}, + "To": []string{"Mary Smith <mary@example.net>"}, + "Subject": []string{"Saying Hello"}, + "Date": []string{"Fri, 21 Nov 1997 09:55:06 -0600"}, + "Message-Id": []string{"<1234@local.machine.example>"}, + }, + body: "This is a message just to say hello.\nSo, \"Hello\".\n", + }, +} + +func TestParsing(t *testing.T) { + for i, test := range parseTests { + msg, err := ReadMessage(bytes.NewBuffer([]byte(test.in))) + if err != nil { + t.Errorf("test #%d: Failed parsing message: %v", i, err) + continue + } + if !headerEq(msg.Header, test.header) { + t.Errorf("test #%d: Incorrectly parsed message header.\nGot:\n%+v\nWant:\n%+v", + i, msg.Header, test.header) + } + body, err := ioutil.ReadAll(msg.Body) + if err != nil { + t.Errorf("test #%d: Failed reading body: %v", i, err) + continue + } + bodyStr := string(body) + if bodyStr != test.body { + t.Errorf("test #%d: Incorrectly parsed message body.\nGot:\n%+v\nWant:\n%+v", + i, bodyStr, test.body) + } + } +} + +func headerEq(a, b Header) bool { + if len(a) != len(b) { + return false + } + for k, as := range a { + bs, ok := b[k] + if !ok { + return false + } + if !reflect.DeepEqual(as, bs) { + return false + } + } + return true +} + +func TestDateParsing(t *testing.T) { + tests := []struct { + dateStr string + exp time.Time + }{ + // RFC 5322, Appendix A.1.1 + { + "Fri, 21 Nov 1997 09:55:06 -0600", + time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)), + }, + // RFC5322, Appendix A.6.2 + // Obsolete date. + { + "21 Nov 97 09:55:06 GMT", + time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("GMT", 0)), + }, + } + for _, test := range tests { + hdr := Header{ + "Date": []string{test.dateStr}, + } + date, err := hdr.Date() + if err != nil { + t.Errorf("Failed parsing %q: %v", test.dateStr, err) + continue + } + if !date.Equal(test.exp) { + t.Errorf("Parse of %q: got %+v, want %+v", test.dateStr, date, test.exp) + } + } +} + +func TestAddressParsing(t *testing.T) { + tests := []struct { + addrsStr string + exp []*Address + }{ + // Bare address + { + `jdoe@machine.example`, + []*Address{{ + Address: "jdoe@machine.example", + }}, + }, + // RFC 5322, Appendix A.1.1 + { + `John Doe <jdoe@machine.example>`, + []*Address{{ + Name: "John Doe", + Address: "jdoe@machine.example", + }}, + }, + // RFC 5322, Appendix A.1.2 + { + `"Joe Q. Public" <john.q.public@example.com>`, + []*Address{{ + Name: "Joe Q. Public", + Address: "john.q.public@example.com", + }}, + }, + { + `Mary Smith <mary@x.test>, jdoe@example.org, Who? <one@y.test>`, + []*Address{ + { + Name: "Mary Smith", + Address: "mary@x.test", + }, + { + Address: "jdoe@example.org", + }, + { + Name: "Who?", + Address: "one@y.test", + }, + }, + }, + { + `<boss@nil.test>, "Giant; \"Big\" Box" <sysservices@example.net>`, + []*Address{ + { + Address: "boss@nil.test", + }, + { + Name: `Giant; "Big" Box`, + Address: "sysservices@example.net", + }, + }, + }, + // RFC 5322, Appendix A.1.3 + // TODO(dsymonds): Group addresses. + + // RFC 2047 "Q"-encoded ISO-8859-1 address. + { + `=?iso-8859-1?q?J=F6rg_Doe?= <joerg@example.com>`, + []*Address{ + { + Name: `Jörg Doe`, + Address: "joerg@example.com", + }, + }, + }, + // RFC 2047 "Q"-encoded UTF-8 address. + { + `=?utf-8?q?J=C3=B6rg_Doe?= <joerg@example.com>`, + []*Address{ + { + Name: `Jörg Doe`, + Address: "joerg@example.com", + }, + }, + }, + // RFC 2047, Section 8. + { + `=?ISO-8859-1?Q?Andr=E9?= Pirard <PIRARD@vm1.ulg.ac.be>`, + []*Address{ + { + Name: `André Pirard`, + Address: "PIRARD@vm1.ulg.ac.be", + }, + }, + }, + // Custom example of RFC 2047 "B"-encoded ISO-8859-1 address. + { + `=?ISO-8859-1?B?SvZyZw==?= <joerg@example.com>`, + []*Address{ + { + Name: `Jörg`, + Address: "joerg@example.com", + }, + }, + }, + // Custom example of RFC 2047 "B"-encoded UTF-8 address. + { + `=?UTF-8?B?SsO2cmc=?= <joerg@example.com>`, + []*Address{ + { + Name: `Jörg`, + Address: "joerg@example.com", + }, + }, + }, + } + for _, test := range tests { + addrs, err := newAddrParser(test.addrsStr).parseAddressList() + if err != nil { + t.Errorf("Failed parsing %q: %v", test.addrsStr, err) + continue + } + if !reflect.DeepEqual(addrs, test.exp) { + t.Errorf("Parse of %q: got %+v, want %+v", test.addrsStr, addrs, test.exp) + } + } +} + +func TestAddressFormatting(t *testing.T) { + tests := []struct { + addr *Address + exp string + }{ + { + &Address{Address: "bob@example.com"}, + "<bob@example.com>", + }, + { + &Address{Name: "Bob", Address: "bob@example.com"}, + `"Bob" <bob@example.com>`, + }, + { + // note the ö (o with an umlaut) + &Address{Name: "Böb", Address: "bob@example.com"}, + `=?utf-8?q?B=C3=B6b?= <bob@example.com>`, + }, + } + for _, test := range tests { + s := test.addr.String() + if s != test.exp { + t.Errorf("Address%+v.String() = %v, want %v", *test.addr, s, test.exp) + } + } +} diff --git a/src/pkg/net/multicast_test.go b/src/pkg/net/multicast_test.go index a66250c84..183d5a8ab 100644 --- a/src/pkg/net/multicast_test.go +++ b/src/pkg/net/multicast_test.go @@ -13,7 +13,7 @@ import ( var multicast = flag.Bool("multicast", false, "enable multicast tests") -var joinAndLeaveGroupUDPTests = []struct { +var multicastUDPTests = []struct { net string laddr IP gaddr IP @@ -32,8 +32,8 @@ var joinAndLeaveGroupUDPTests = []struct { {"udp6", IPv6unspecified, ParseIP("ff0e::114"), (FlagUp | FlagLoopback), true}, } -func TestJoinAndLeaveGroupUDP(t *testing.T) { - if runtime.GOOS == "windows" { +func TestMulticastUDP(t *testing.T) { + if runtime.GOOS == "plan9" || runtime.GOOS == "windows" { return } if !*multicast { @@ -41,7 +41,7 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) { return } - for _, tt := range joinAndLeaveGroupUDPTests { + for _, tt := range multicastUDPTests { var ( ifi *Interface found bool @@ -51,7 +51,7 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) { } ift, err := Interfaces() if err != nil { - t.Fatalf("Interfaces() failed: %v", err) + t.Fatalf("Interfaces failed: %v", err) } for _, x := range ift { if x.Flags&tt.flags == tt.flags { @@ -65,15 +65,20 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) { } c, err := ListenUDP(tt.net, &UDPAddr{IP: tt.laddr}) if err != nil { - t.Fatal(err) + t.Fatalf("ListenUDP failed: %v", err) } defer c.Close() if err := c.JoinGroup(ifi, tt.gaddr); err != nil { - t.Fatal(err) + t.Fatalf("JoinGroup failed: %v", err) + } + if !tt.ipv6 { + testIPv4MulticastSocketOptions(t, c.fd, ifi) + } else { + testIPv6MulticastSocketOptions(t, c.fd, ifi) } ifmat, err := ifi.MulticastAddrs() if err != nil { - t.Fatalf("MulticastAddrs() failed: %v", err) + t.Fatalf("MulticastAddrs failed: %v", err) } for _, ifma := range ifmat { if ifma.(*IPAddr).IP.Equal(tt.gaddr) { @@ -85,7 +90,114 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) { t.Fatalf("%q not found in RIB", tt.gaddr.String()) } if err := c.LeaveGroup(ifi, tt.gaddr); err != nil { - t.Fatal(err) + t.Fatalf("LeaveGroup failed: %v", err) + } + } +} + +func TestSimpleMulticastUDP(t *testing.T) { + if runtime.GOOS == "plan9" { + return + } + if !*multicast { + t.Logf("test disabled; use --multicast to enable") + return + } + + for _, tt := range multicastUDPTests { + var ifi *Interface + if tt.ipv6 { + continue + } + tt.flags = FlagUp | FlagMulticast + ift, err := Interfaces() + if err != nil { + t.Fatalf("Interfaces failed: %v", err) + } + for _, x := range ift { + if x.Flags&tt.flags == tt.flags { + ifi = &x + break + } + } + if ifi == nil { + t.Logf("an appropriate multicast interface not found") + return + } + c, err := ListenUDP(tt.net, &UDPAddr{IP: tt.laddr}) + if err != nil { + t.Fatalf("ListenUDP failed: %v", err) + } + defer c.Close() + if err := c.JoinGroup(ifi, tt.gaddr); err != nil { + t.Fatalf("JoinGroup failed: %v", err) + } + if err := c.LeaveGroup(ifi, tt.gaddr); err != nil { + t.Fatalf("LeaveGroup failed: %v", err) } } } + +func testIPv4MulticastSocketOptions(t *testing.T, fd *netFD, ifi *Interface) { + ifmc, err := ipv4MulticastInterface(fd) + if err != nil { + t.Fatalf("ipv4MulticastInterface failed: %v", err) + } + t.Logf("IPv4 multicast interface: %v", ifmc) + err = setIPv4MulticastInterface(fd, ifi) + if err != nil { + t.Fatalf("setIPv4MulticastInterface failed: %v", err) + } + + ttl, err := ipv4MulticastTTL(fd) + if err != nil { + t.Fatalf("ipv4MulticastTTL failed: %v", err) + } + t.Logf("IPv4 multicast TTL: %v", ttl) + err = setIPv4MulticastTTL(fd, 1) + if err != nil { + t.Fatalf("setIPv4MulticastTTL failed: %v", err) + } + + loop, err := ipv4MulticastLoopback(fd) + if err != nil { + t.Fatalf("ipv4MulticastLoopback failed: %v", err) + } + t.Logf("IPv4 multicast loopback: %v", loop) + err = setIPv4MulticastLoopback(fd, false) + if err != nil { + t.Fatalf("setIPv4MulticastLoopback failed: %v", err) + } +} + +func testIPv6MulticastSocketOptions(t *testing.T, fd *netFD, ifi *Interface) { + ifmc, err := ipv6MulticastInterface(fd) + if err != nil { + t.Fatalf("ipv6MulticastInterface failed: %v", err) + } + t.Logf("IPv6 multicast interface: %v", ifmc) + err = setIPv6MulticastInterface(fd, ifi) + if err != nil { + t.Fatalf("setIPv6MulticastInterface failed: %v", err) + } + + hoplim, err := ipv6MulticastHopLimit(fd) + if err != nil { + t.Fatalf("ipv6MulticastHopLimit failed: %v", err) + } + t.Logf("IPv6 multicast hop limit: %v", hoplim) + err = setIPv6MulticastHopLimit(fd, 1) + if err != nil { + t.Fatalf("setIPv6MulticastHopLimit failed: %v", err) + } + + loop, err := ipv6MulticastLoopback(fd) + if err != nil { + t.Fatalf("ipv6MulticastLoopback failed: %v", err) + } + t.Logf("IPv6 multicast loopback: %v", loop) + err = setIPv6MulticastLoopback(fd, false) + if err != nil { + t.Fatalf("setIPv6MulticastLoopback failed: %v", err) + } +} diff --git a/src/pkg/net/net.go b/src/pkg/net/net.go index 5c84d3434..609fee242 100644 --- a/src/pkg/net/net.go +++ b/src/pkg/net/net.go @@ -9,7 +9,10 @@ package net // TODO(rsc): // support for raw ethernet sockets -import "os" +import ( + "errors" + "time" +) // Addr represents a network end point address. type Addr interface { @@ -21,16 +24,16 @@ type Addr interface { type Conn interface { // Read reads data from the connection. // Read can be made to time out and return a net.Error with Timeout() == true - // after a fixed time limit; see SetTimeout and SetReadTimeout. - Read(b []byte) (n int, err os.Error) + // after a fixed time limit; see SetDeadline and SetReadDeadline. + Read(b []byte) (n int, err error) // Write writes data to the connection. // Write can be made to time out and return a net.Error with Timeout() == true - // after a fixed time limit; see SetTimeout and SetWriteTimeout. - Write(b []byte) (n int, err os.Error) + // after a fixed time limit; see SetDeadline and SetWriteDeadline. + Write(b []byte) (n int, err error) // Close closes the connection. - Close() os.Error + Close() error // LocalAddr returns the local network address. LocalAddr() Addr @@ -38,26 +41,28 @@ type Conn interface { // RemoteAddr returns the remote network address. RemoteAddr() Addr - // SetTimeout sets the read and write deadlines associated + // SetDeadline sets the read and write deadlines associated // with the connection. - SetTimeout(nsec int64) os.Error - - // SetReadTimeout sets the time (in nanoseconds) that - // Read will wait for data before returning an error with Timeout() == true. - // Setting nsec == 0 (the default) disables the deadline. - SetReadTimeout(nsec int64) os.Error - - // SetWriteTimeout sets the time (in nanoseconds) that - // Write will wait to send its data before returning an error with Timeout() == true. - // Setting nsec == 0 (the default) disables the deadline. + SetDeadline(t time.Time) error + + // SetReadDeadline sets the deadline for all Read calls to return. + // If the deadline is reached, Read will fail with a timeout + // (see type Error) instead of blocking. + // A zero value for t means Read will not time out. + SetReadDeadline(t time.Time) error + + // SetWriteDeadline sets the deadline for all Write calls to return. + // If the deadline is reached, Write will fail with a timeout + // (see type Error) instead of blocking. + // A zero value for t means Write will not time out. // Even if write times out, it may return n > 0, indicating that // some of the data was successfully written. - SetWriteTimeout(nsec int64) os.Error + SetWriteDeadline(t time.Time) error } // An Error represents a network error. type Error interface { - os.Error + error Timeout() bool // Is the error a timeout? Temporary() bool // Is the error temporary? } @@ -70,61 +75,63 @@ type PacketConn interface { // was on the packet. // ReadFrom can be made to time out and return // an error with Timeout() == true after a fixed time limit; - // see SetTimeout and SetReadTimeout. - ReadFrom(b []byte) (n int, addr Addr, err os.Error) + // see SetDeadline and SetReadDeadline. + ReadFrom(b []byte) (n int, addr Addr, err error) // WriteTo writes a packet with payload b to addr. // WriteTo can be made to time out and return // an error with Timeout() == true after a fixed time limit; - // see SetTimeout and SetWriteTimeout. + // see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. - WriteTo(b []byte, addr Addr) (n int, err os.Error) + WriteTo(b []byte, addr Addr) (n int, err error) // Close closes the connection. - Close() os.Error + Close() error // LocalAddr returns the local network address. LocalAddr() Addr - // SetTimeout sets the read and write deadlines associated + // SetDeadline sets the read and write deadlines associated // with the connection. - SetTimeout(nsec int64) os.Error - - // SetReadTimeout sets the time (in nanoseconds) that - // Read will wait for data before returning an error with Timeout() == true. - // Setting nsec == 0 (the default) disables the deadline. - SetReadTimeout(nsec int64) os.Error - - // SetWriteTimeout sets the time (in nanoseconds) that - // Write will wait to send its data before returning an error with Timeout() == true. - // Setting nsec == 0 (the default) disables the deadline. + SetDeadline(t time.Time) error + + // SetReadDeadline sets the deadline for all Read calls to return. + // If the deadline is reached, Read will fail with a timeout + // (see type Error) instead of blocking. + // A zero value for t means Read will not time out. + SetReadDeadline(t time.Time) error + + // SetWriteDeadline sets the deadline for all Write calls to return. + // If the deadline is reached, Write will fail with a timeout + // (see type Error) instead of blocking. + // A zero value for t means Write will not time out. // Even if write times out, it may return n > 0, indicating that // some of the data was successfully written. - SetWriteTimeout(nsec int64) os.Error + SetWriteDeadline(t time.Time) error } // A Listener is a generic network listener for stream-oriented protocols. type Listener interface { // Accept waits for and returns the next connection to the listener. - Accept() (c Conn, err os.Error) + Accept() (c Conn, err error) // Close closes the listener. - Close() os.Error + Close() error // Addr returns the listener's network address. Addr() Addr } -var errMissingAddress = os.NewError("missing address") +var errMissingAddress = errors.New("missing address") type OpError struct { - Op string - Net string - Addr Addr - Error os.Error + Op string + Net string + Addr Addr + Err error } -func (e *OpError) String() string { +func (e *OpError) Error() string { if e == nil { return "<nil>" } @@ -135,7 +142,7 @@ func (e *OpError) String() string { if e.Addr != nil { s += " " + e.Addr.String() } - s += ": " + e.Error.String() + s += ": " + e.Err.Error() return s } @@ -144,7 +151,7 @@ type temporary interface { } func (e *OpError) Temporary() bool { - t, ok := e.Error.(temporary) + t, ok := e.Err.(temporary) return ok && t.Temporary() } @@ -153,20 +160,28 @@ type timeout interface { } func (e *OpError) Timeout() bool { - t, ok := e.Error.(timeout) + t, ok := e.Err.(timeout) return ok && t.Timeout() } +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "i/o timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + +var errTimeout error = &timeoutError{} + type AddrError struct { - Error string - Addr string + Err string + Addr string } -func (e *AddrError) String() string { +func (e *AddrError) Error() string { if e == nil { return "<nil>" } - s := e.Error + s := e.Err if e.Addr != "" { s += " " + e.Addr } @@ -183,6 +198,6 @@ func (e *AddrError) Timeout() bool { type UnknownNetworkError string -func (e UnknownNetworkError) String() string { return "unknown network " + string(e) } +func (e UnknownNetworkError) Error() string { return "unknown network " + string(e) } func (e UnknownNetworkError) Temporary() bool { return false } func (e UnknownNetworkError) Timeout() bool { return false } diff --git a/src/pkg/net/net_test.go b/src/pkg/net/net_test.go index 698a84527..0dc86698e 100644 --- a/src/pkg/net/net_test.go +++ b/src/pkg/net/net_test.go @@ -6,7 +6,9 @@ package net import ( "flag" + "io" "regexp" + "runtime" "testing" ) @@ -61,6 +63,8 @@ var dialErrorTests = []DialErrorTest{ }, } +var duplicateErrorPattern = `dial (.*) dial (.*)` + func TestDialError(t *testing.T) { if !*runErrorTest { t.Logf("test disabled; use --run_error_test to enable") @@ -75,11 +79,15 @@ func TestDialError(t *testing.T) { t.Errorf("#%d: nil error, want match for %#q", i, tt.Pattern) continue } - s := e.String() + s := e.Error() match, _ := regexp.MatchString(tt.Pattern, s) if !match { t.Errorf("#%d: %q, want match for %#q", i, s, tt.Pattern) } + match, _ = regexp.MatchString(duplicateErrorPattern, s) + if match { + t.Errorf("#%d: %q, duplicate error return from Dial", i, s) + } } } @@ -111,11 +119,57 @@ func TestReverseAddress(t *testing.T) { if len(tt.ErrPrefix) == 0 && e != nil { t.Errorf("#%d: expected <nil>, got %q (error)", i, e) } - if e != nil && e.(*DNSError).Error != tt.ErrPrefix { - t.Errorf("#%d: expected %q, got %q (mismatched error)", i, tt.ErrPrefix, e.(*DNSError).Error) + if e != nil && e.(*DNSError).Err != tt.ErrPrefix { + t.Errorf("#%d: expected %q, got %q (mismatched error)", i, tt.ErrPrefix, e.(*DNSError).Err) } if a != tt.Reverse { t.Errorf("#%d: expected %q, got %q (reverse address)", i, tt.Reverse, a) } } } + +func TestShutdown(t *testing.T) { + if runtime.GOOS == "plan9" { + return + } + l, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = Listen("tcp6", "[::1]:0"); err != nil { + t.Fatalf("ListenTCP on :0: %v", err) + } + } + + go func() { + c, err := l.Accept() + if err != nil { + t.Fatalf("Accept: %v", err) + } + var buf [10]byte + n, err := c.Read(buf[:]) + if n != 0 || err != io.EOF { + t.Fatalf("server Read = %d, %v; want 0, io.EOF", n, err) + } + c.Write([]byte("response")) + c.Close() + }() + + c, err := Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + + err = c.(*TCPConn).CloseWrite() + if err != nil { + t.Fatalf("CloseWrite: %v", err) + } + var buf [10]byte + n, err := c.Read(buf[:]) + if err != nil { + t.Fatalf("client Read: %d, %v", n, err) + } + got := string(buf[:n]) + if got != "response" { + t.Errorf("read = %q, want \"response\"", got) + } +} diff --git a/src/pkg/net/newpollserver.go b/src/pkg/net/newpollserver.go index 3c9a6da53..a410bb6ce 100644 --- a/src/pkg/net/newpollserver.go +++ b/src/pkg/net/newpollserver.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd +// +build darwin freebsd linux netbsd openbsd package net @@ -11,18 +11,17 @@ import ( "syscall" ) -func newPollServer() (s *pollServer, err os.Error) { +func newPollServer() (s *pollServer, err error) { s = new(pollServer) s.cr = make(chan *netFD, 1) s.cw = make(chan *netFD, 1) if s.pr, s.pw, err = os.Pipe(); err != nil { return nil, err } - var e int - if e = syscall.SetNonblock(s.pr.Fd(), true); e != 0 { + if err = syscall.SetNonblock(s.pr.Fd(), true); err != nil { goto Errno } - if e = syscall.SetNonblock(s.pw.Fd(), true); e != 0 { + if err = syscall.SetNonblock(s.pw.Fd(), true); err != nil { goto Errno } if s.poll, err = newpollster(); err != nil { @@ -37,7 +36,7 @@ func newPollServer() (s *pollServer, err os.Error) { return s, nil Errno: - err = &os.PathError{"setnonblock", s.pr.Name(), os.Errno(e)} + err = &os.PathError{"setnonblock", s.pr.Name(), err} Error: s.pr.Close() s.pw.Close() diff --git a/src/pkg/net/parse.go b/src/pkg/net/parse.go index 0d30a7ac6..4c4200a49 100644 --- a/src/pkg/net/parse.go +++ b/src/pkg/net/parse.go @@ -54,7 +54,7 @@ func (f *file) readLine() (s string, ok bool) { if n >= 0 { f.data = f.data[0 : ln+n] } - if err == os.EOF { + if err == io.EOF { f.atEOF = true } } @@ -62,7 +62,7 @@ func (f *file) readLine() (s string, ok bool) { return } -func open(name string) (*file, os.Error) { +func open(name string) (*file, error) { fd, err := os.Open(name) if err != nil { return nil, err diff --git a/src/pkg/net/parse_test.go b/src/pkg/net/parse_test.go index 8d51eba18..dfbaba4d9 100644 --- a/src/pkg/net/parse_test.go +++ b/src/pkg/net/parse_test.go @@ -7,8 +7,8 @@ package net import ( "bufio" "os" - "testing" "runtime" + "testing" ) func TestReadLine(t *testing.T) { diff --git a/src/pkg/net/pipe.go b/src/pkg/net/pipe.go index c0bbd356b..f1a2eca4e 100644 --- a/src/pkg/net/pipe.go +++ b/src/pkg/net/pipe.go @@ -1,8 +1,13 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package net import ( + "errors" "io" - "os" + "time" ) // Pipe creates a synchronous, in-memory, full duplex @@ -32,7 +37,7 @@ func (pipeAddr) String() string { return "pipe" } -func (p *pipe) Close() os.Error { +func (p *pipe) Close() error { err := p.PipeReader.Close() err1 := p.PipeWriter.Close() if err == nil { @@ -49,14 +54,14 @@ func (p *pipe) RemoteAddr() Addr { return pipeAddr(0) } -func (p *pipe) SetTimeout(nsec int64) os.Error { - return os.NewError("net.Pipe does not support timeouts") +func (p *pipe) SetDeadline(t time.Time) error { + return errors.New("net.Pipe does not support deadlines") } -func (p *pipe) SetReadTimeout(nsec int64) os.Error { - return os.NewError("net.Pipe does not support timeouts") +func (p *pipe) SetReadDeadline(t time.Time) error { + return errors.New("net.Pipe does not support deadlines") } -func (p *pipe) SetWriteTimeout(nsec int64) os.Error { - return os.NewError("net.Pipe does not support timeouts") +func (p *pipe) SetWriteDeadline(t time.Time) error { + return errors.New("net.Pipe does not support deadlines") } diff --git a/src/pkg/net/pipe_test.go b/src/pkg/net/pipe_test.go index 7e4c6db44..afe4f2408 100644 --- a/src/pkg/net/pipe_test.go +++ b/src/pkg/net/pipe_test.go @@ -7,7 +7,6 @@ package net import ( "bytes" "io" - "os" "testing" ) @@ -22,7 +21,7 @@ func checkWrite(t *testing.T, w io.Writer, data []byte, c chan int) { c <- 0 } -func checkRead(t *testing.T, r io.Reader, data []byte, wantErr os.Error) { +func checkRead(t *testing.T, r io.Reader, data []byte, wantErr error) { buf := make([]byte, len(data)+10) n, err := r.Read(buf) if err != wantErr { @@ -52,6 +51,6 @@ func TestPipe(t *testing.T) { checkRead(t, srv, []byte("a third line"), nil) <-c go srv.Close() - checkRead(t, cli, nil, os.EOF) + checkRead(t, cli, nil, io.EOF) cli.Close() } diff --git a/src/pkg/net/port.go b/src/pkg/net/port.go index a8ca60c60..16780da11 100644 --- a/src/pkg/net/port.go +++ b/src/pkg/net/port.go @@ -2,19 +2,16 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd +// +build darwin freebsd linux netbsd openbsd // Read system port mappings from /etc/services package net -import ( - "os" - "sync" -) +import "sync" var services map[string]map[string]int -var servicesError os.Error +var servicesError error var onceReadServices sync.Once func readServices() { @@ -53,7 +50,7 @@ func readServices() { } // goLookupPort is the native Go implementation of LookupPort. -func goLookupPort(network, service string) (port int, err os.Error) { +func goLookupPort(network, service string) (port int, err error) { onceReadServices.Do(readServices) switch network { diff --git a/src/pkg/net/rpc/Makefile b/src/pkg/net/rpc/Makefile new file mode 100644 index 000000000..0e6c9846b --- /dev/null +++ b/src/pkg/net/rpc/Makefile @@ -0,0 +1,13 @@ +# Copyright 2009 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=net/rpc +GOFILES=\ + client.go\ + debug.go\ + server.go\ + +include ../../../Make.pkg diff --git a/src/pkg/net/rpc/client.go b/src/pkg/net/rpc/client.go new file mode 100644 index 000000000..abc1e59cd --- /dev/null +++ b/src/pkg/net/rpc/client.go @@ -0,0 +1,288 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +import ( + "bufio" + "encoding/gob" + "errors" + "io" + "log" + "net" + "net/http" + "sync" +) + +// ServerError represents an error that has been returned from +// the remote side of the RPC connection. +type ServerError string + +func (e ServerError) Error() string { + return string(e) +} + +var ErrShutdown = errors.New("connection is shut down") + +// Call represents an active RPC. +type Call struct { + ServiceMethod string // The name of the service and method to call. + Args interface{} // The argument to the function (*struct). + Reply interface{} // The reply from the function (*struct). + Error error // After completion, the error status. + Done chan *Call // Strobes when call is complete; value is the error status. + seq uint64 +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client. +type Client struct { + mutex sync.Mutex // protects pending, seq, request + sending sync.Mutex + request Request + seq uint64 + codec ClientCodec + pending map[uint64]*Call + closing bool + shutdown bool +} + +// A ClientCodec implements writing of RPC requests and +// reading of RPC responses for the client side of an RPC session. +// The client calls WriteRequest to write a request to the connection +// and calls ReadResponseHeader and ReadResponseBody in pairs +// to read responses. The client calls Close when finished with the +// connection. ReadResponseBody may be called with a nil +// argument to force the body of the response to be read and then +// discarded. +type ClientCodec interface { + WriteRequest(*Request, interface{}) error + ReadResponseHeader(*Response) error + ReadResponseBody(interface{}) error + + Close() error +} + +func (client *Client) send(c *Call) { + // Register this call. + client.mutex.Lock() + if client.shutdown { + c.Error = ErrShutdown + client.mutex.Unlock() + c.done() + return + } + c.seq = client.seq + client.seq++ + client.pending[c.seq] = c + client.mutex.Unlock() + + // Encode and send the request. + client.sending.Lock() + defer client.sending.Unlock() + client.request.Seq = c.seq + client.request.ServiceMethod = c.ServiceMethod + if err := client.codec.WriteRequest(&client.request, c.Args); err != nil { + c.Error = err + c.done() + } +} + +func (client *Client) input() { + var err error + var response Response + for err == nil { + response = Response{} + err = client.codec.ReadResponseHeader(&response) + if err != nil { + if err == io.EOF && !client.closing { + err = io.ErrUnexpectedEOF + } + break + } + seq := response.Seq + client.mutex.Lock() + c := client.pending[seq] + delete(client.pending, seq) + client.mutex.Unlock() + + if response.Error == "" { + err = client.codec.ReadResponseBody(c.Reply) + if err != nil { + c.Error = errors.New("reading body " + err.Error()) + } + } else { + // We've got an error response. Give this to the request; + // any subsequent requests will get the ReadResponseBody + // error if there is one. + c.Error = ServerError(response.Error) + err = client.codec.ReadResponseBody(nil) + if err != nil { + err = errors.New("reading error body: " + err.Error()) + } + } + c.done() + } + // Terminate pending calls. + client.mutex.Lock() + client.shutdown = true + for _, call := range client.pending { + call.Error = err + call.done() + } + client.mutex.Unlock() + if err != io.EOF || !client.closing { + log.Println("rpc: client protocol error:", err) + } +} + +func (call *Call) done() { + select { + case call.Done <- call: + // ok + default: + // We don't want to block here. It is the caller's responsibility to make + // sure the channel has enough buffer space. See comment in Go(). + log.Println("rpc: discarding Call reply due to insufficient Done chan capacity") + } +} + +// NewClient returns a new Client to handle requests to the +// set of services at the other end of the connection. +// It adds a buffer to the write side of the connection so +// the header and payload are sent as a unit. +func NewClient(conn io.ReadWriteCloser) *Client { + encBuf := bufio.NewWriter(conn) + client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf} + return NewClientWithCodec(client) +} + +// NewClientWithCodec is like NewClient but uses the specified +// codec to encode requests and decode responses. +func NewClientWithCodec(codec ClientCodec) *Client { + client := &Client{ + codec: codec, + pending: make(map[uint64]*Call), + } + go client.input() + return client +} + +type gobClientCodec struct { + rwc io.ReadWriteCloser + dec *gob.Decoder + enc *gob.Encoder + encBuf *bufio.Writer +} + +func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) { + if err = c.enc.Encode(r); err != nil { + return + } + if err = c.enc.Encode(body); err != nil { + return + } + return c.encBuf.Flush() +} + +func (c *gobClientCodec) ReadResponseHeader(r *Response) error { + return c.dec.Decode(r) +} + +func (c *gobClientCodec) ReadResponseBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *gobClientCodec) Close() error { + return c.rwc.Close() +} + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string) (*Client, error) { + return DialHTTPPath(network, address, DefaultRPCPath) +} + +// DialHTTPPath connects to an HTTP RPC server +// at the specified network address and path. +func DialHTTPPath(network, address, path string) (*Client, error) { + var err error + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n") + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err == nil && resp.Status == connected { + return NewClient(conn), nil + } + if err == nil { + err = errors.New("unexpected HTTP response: " + resp.Status) + } + conn.Close() + return nil, &net.OpError{"dial-http", network + " " + address, nil, err} +} + +// Dial connects to an RPC server at the specified network address. +func Dial(network, address string) (*Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return NewClient(conn), nil +} + +func (client *Client) Close() error { + client.mutex.Lock() + if client.shutdown || client.closing { + client.mutex.Unlock() + return ErrShutdown + } + client.closing = true + client.mutex.Unlock() + return client.codec.Close() +} + +// Go invokes the function asynchronously. It returns the Call structure representing +// the invocation. The done channel will signal when the call is complete by returning +// the same Call object. If done is nil, Go will allocate a new channel. +// If non-nil, done must be buffered or Go will deliberately crash. +func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call { + call := new(Call) + call.ServiceMethod = serviceMethod + call.Args = args + call.Reply = reply + if done == nil { + done = make(chan *Call, 10) // buffered. + } else { + // If caller passes done != nil, it must arrange that + // done has enough buffer for the number of simultaneous + // RPCs that will be using that channel. If the channel + // is totally unbuffered, it's best not to run at all. + if cap(done) == 0 { + log.Panic("rpc: done channel is unbuffered") + } + } + call.Done = done + if client.shutdown { + call.Error = ErrShutdown + call.done() + return call + } + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, and returns its error status. +func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error { + if client.shutdown { + return ErrShutdown + } + call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done + return call.Error +} diff --git a/src/pkg/net/rpc/debug.go b/src/pkg/net/rpc/debug.go new file mode 100644 index 000000000..663663fe9 --- /dev/null +++ b/src/pkg/net/rpc/debug.go @@ -0,0 +1,90 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +/* + Some HTML presented at http://machine:port/debug/rpc + Lists services, their methods, and some statistics, still rudimentary. +*/ + +import ( + "fmt" + "net/http" + "sort" + "text/template" +) + +const debugText = `<html> + <body> + <title>Services</title> + {{range .}} + <hr> + Service {{.Name}} + <hr> + <table> + <th align=center>Method</th><th align=center>Calls</th> + {{range .Method}} + <tr> + <td align=left font=fixed>{{.Name}}({{.Type.ArgType}}, {{.Type.ReplyType}}) error</td> + <td align=center>{{.Type.NumCalls}}</td> + </tr> + {{end}} + </table> + {{end}} + </body> + </html>` + +var debug = template.Must(template.New("RPC debug").Parse(debugText)) + +type debugMethod struct { + Type *methodType + Name string +} + +type methodArray []debugMethod + +type debugService struct { + Service *service + Name string + Method methodArray +} + +type serviceArray []debugService + +func (s serviceArray) Len() int { return len(s) } +func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name } +func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func (m methodArray) Len() int { return len(m) } +func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name } +func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] } + +type debugHTTP struct { + *Server +} + +// Runs at /debug/rpc +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Build a sorted version of the data. + var services = make(serviceArray, len(server.serviceMap)) + i := 0 + server.mu.Lock() + for sname, service := range server.serviceMap { + services[i] = debugService{service, sname, make(methodArray, len(service.method))} + j := 0 + for mname, method := range service.method { + services[i].Method[j] = debugMethod{method, mname} + j++ + } + sort.Sort(services[i].Method) + i++ + } + server.mu.Unlock() + sort.Sort(services) + err := debug.Execute(w, services) + if err != nil { + fmt.Fprintln(w, "rpc: error executing template:", err.Error()) + } +} diff --git a/src/pkg/net/rpc/jsonrpc/Makefile b/src/pkg/net/rpc/jsonrpc/Makefile new file mode 100644 index 000000000..c5ea5373d --- /dev/null +++ b/src/pkg/net/rpc/jsonrpc/Makefile @@ -0,0 +1,12 @@ +# Copyright 2010 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../../Make.inc + +TARG=net/rpc/jsonrpc +GOFILES=\ + client.go\ + server.go\ + +include ../../../../Make.pkg diff --git a/src/pkg/net/rpc/jsonrpc/all_test.go b/src/pkg/net/rpc/jsonrpc/all_test.go new file mode 100644 index 000000000..e6c7441f0 --- /dev/null +++ b/src/pkg/net/rpc/jsonrpc/all_test.go @@ -0,0 +1,221 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jsonrpc + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/rpc" + "testing" +) + +type Args struct { + A, B int +} + +type Reply struct { + C int +} + +type Arith int + +func (t *Arith) Add(args *Args, reply *Reply) error { + reply.C = args.A + args.B + return nil +} + +func (t *Arith) Mul(args *Args, reply *Reply) error { + reply.C = args.A * args.B + return nil +} + +func (t *Arith) Div(args *Args, reply *Reply) error { + if args.B == 0 { + return errors.New("divide by zero") + } + reply.C = args.A / args.B + return nil +} + +func (t *Arith) Error(args *Args, reply *Reply) error { + panic("ERROR") +} + +func init() { + rpc.Register(new(Arith)) +} + +func TestServer(t *testing.T) { + type addResp struct { + Id interface{} `json:"id"` + Result Reply `json:"result"` + Error interface{} `json:"error"` + } + + cli, srv := net.Pipe() + defer cli.Close() + go ServeConn(srv) + dec := json.NewDecoder(cli) + + // Send hand-coded requests to server, parse responses. + for i := 0; i < 10; i++ { + fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1) + var resp addResp + err := dec.Decode(&resp) + if err != nil { + t.Fatalf("Decode: %s", err) + } + if resp.Error != nil { + t.Fatalf("resp.Error: %s", resp.Error) + } + if resp.Id.(string) != string(i) { + t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(i)) + } + if resp.Result.C != 2*i+1 { + t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C) + } + } + + fmt.Fprintf(cli, "{}\n") + var resp addResp + if err := dec.Decode(&resp); err != nil { + t.Fatalf("Decode after empty: %s", err) + } + if resp.Error == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestClient(t *testing.T) { + // Assume server is okay (TestServer is above). + // Test client against server. + cli, srv := net.Pipe() + go ServeConn(srv) + + client := NewClient(cli) + defer client.Close() + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err := client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } + + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("Arith.Mul", args, reply) + if err != nil { + t.Errorf("Mul: expected no error but got string %q", err.Error()) + } + if reply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) + } + + // Out of order. + args = &Args{7, 8} + mulReply := new(Reply) + mulCall := client.Go("Arith.Mul", args, mulReply, nil) + addReply := new(Reply) + addCall := client.Go("Arith.Add", args, addReply, nil) + + addCall = <-addCall.Done + if addCall.Error != nil { + t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) + } + if addReply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) + } + + mulCall = <-mulCall.Done + if mulCall.Error != nil { + t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) + } + if mulReply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) + } + + // Error test + args = &Args{7, 0} + reply = new(Reply) + err = client.Call("Arith.Div", args, reply) + // expect an error: zero divide + if err == nil { + t.Error("Div: expected error") + } else if err.Error() != "divide by zero" { + t.Error("Div: expected divide by zero error; got", err) + } +} + +func TestMalformedInput(t *testing.T) { + cli, srv := net.Pipe() + go cli.Write([]byte(`{id:1}`)) // invalid json + ServeConn(srv) // must return, not loop +} + +func TestUnexpectedError(t *testing.T) { + cli, srv := myPipe() + go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error + ServeConn(srv) // must return, not loop +} + +// Copied from package net. +func myPipe() (*pipe, *pipe) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + return &pipe{r1, w2}, &pipe{r2, w1} +} + +type pipe struct { + *io.PipeReader + *io.PipeWriter +} + +type pipeAddr int + +func (pipeAddr) Network() string { + return "pipe" +} + +func (pipeAddr) String() string { + return "pipe" +} + +func (p *pipe) Close() error { + err := p.PipeReader.Close() + err1 := p.PipeWriter.Close() + if err == nil { + err = err1 + } + return err +} + +func (p *pipe) LocalAddr() net.Addr { + return pipeAddr(0) +} + +func (p *pipe) RemoteAddr() net.Addr { + return pipeAddr(0) +} + +func (p *pipe) SetTimeout(nsec int64) error { + return errors.New("net.Pipe does not support timeouts") +} + +func (p *pipe) SetReadTimeout(nsec int64) error { + return errors.New("net.Pipe does not support timeouts") +} + +func (p *pipe) SetWriteTimeout(nsec int64) error { + return errors.New("net.Pipe does not support timeouts") +} diff --git a/src/pkg/net/rpc/jsonrpc/client.go b/src/pkg/net/rpc/jsonrpc/client.go new file mode 100644 index 000000000..3fa8cbf08 --- /dev/null +++ b/src/pkg/net/rpc/jsonrpc/client.go @@ -0,0 +1,123 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package jsonrpc implements a JSON-RPC ClientCodec and ServerCodec +// for the rpc package. +package jsonrpc + +import ( + "encoding/json" + "fmt" + "io" + "net" + "net/rpc" + "sync" +) + +type clientCodec struct { + dec *json.Decoder // for reading JSON values + enc *json.Encoder // for writing JSON values + c io.Closer + + // temporary work space + req clientRequest + resp clientResponse + + // JSON-RPC responses include the request id but not the request method. + // Package rpc expects both. + // We save the request method in pending when sending a request + // and then look it up by request ID when filling out the rpc Response. + mutex sync.Mutex // protects pending + pending map[uint64]string // map request id to method name +} + +// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn. +func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { + return &clientCodec{ + dec: json.NewDecoder(conn), + enc: json.NewEncoder(conn), + c: conn, + pending: make(map[uint64]string), + } +} + +type clientRequest struct { + Method string `json:"method"` + Params [1]interface{} `json:"params"` + Id uint64 `json:"id"` +} + +func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) error { + c.mutex.Lock() + c.pending[r.Seq] = r.ServiceMethod + c.mutex.Unlock() + c.req.Method = r.ServiceMethod + c.req.Params[0] = param + c.req.Id = r.Seq + return c.enc.Encode(&c.req) +} + +type clientResponse struct { + Id uint64 `json:"id"` + Result *json.RawMessage `json:"result"` + Error interface{} `json:"error"` +} + +func (r *clientResponse) reset() { + r.Id = 0 + r.Result = nil + r.Error = nil +} + +func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error { + c.resp.reset() + if err := c.dec.Decode(&c.resp); err != nil { + return err + } + + c.mutex.Lock() + r.ServiceMethod = c.pending[c.resp.Id] + delete(c.pending, c.resp.Id) + c.mutex.Unlock() + + r.Error = "" + r.Seq = c.resp.Id + if c.resp.Error != nil { + x, ok := c.resp.Error.(string) + if !ok { + return fmt.Errorf("invalid error %v", c.resp.Error) + } + if x == "" { + x = "unspecified error" + } + r.Error = x + } + return nil +} + +func (c *clientCodec) ReadResponseBody(x interface{}) error { + if x == nil { + return nil + } + return json.Unmarshal(*c.resp.Result, x) +} + +func (c *clientCodec) Close() error { + return c.c.Close() +} + +// NewClient returns a new rpc.Client to handle requests to the +// set of services at the other end of the connection. +func NewClient(conn io.ReadWriteCloser) *rpc.Client { + return rpc.NewClientWithCodec(NewClientCodec(conn)) +} + +// Dial connects to a JSON-RPC server at the specified network address. +func Dial(network, address string) (*rpc.Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return NewClient(conn), err +} diff --git a/src/pkg/net/rpc/jsonrpc/server.go b/src/pkg/net/rpc/jsonrpc/server.go new file mode 100644 index 000000000..4c54553a7 --- /dev/null +++ b/src/pkg/net/rpc/jsonrpc/server.go @@ -0,0 +1,136 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jsonrpc + +import ( + "encoding/json" + "errors" + "io" + "net/rpc" + "sync" +) + +type serverCodec struct { + dec *json.Decoder // for reading JSON values + enc *json.Encoder // for writing JSON values + c io.Closer + + // temporary work space + req serverRequest + resp serverResponse + + // JSON-RPC clients can use arbitrary json values as request IDs. + // Package rpc expects uint64 request IDs. + // We assign uint64 sequence numbers to incoming requests + // but save the original request ID in the pending map. + // When rpc responds, we use the sequence number in + // the response to find the original request ID. + mutex sync.Mutex // protects seq, pending + seq uint64 + pending map[uint64]*json.RawMessage +} + +// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn. +func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec { + return &serverCodec{ + dec: json.NewDecoder(conn), + enc: json.NewEncoder(conn), + c: conn, + pending: make(map[uint64]*json.RawMessage), + } +} + +type serverRequest struct { + Method string `json:"method"` + Params *json.RawMessage `json:"params"` + Id *json.RawMessage `json:"id"` +} + +func (r *serverRequest) reset() { + r.Method = "" + if r.Params != nil { + *r.Params = (*r.Params)[0:0] + } + if r.Id != nil { + *r.Id = (*r.Id)[0:0] + } +} + +type serverResponse struct { + Id *json.RawMessage `json:"id"` + Result interface{} `json:"result"` + Error interface{} `json:"error"` +} + +func (c *serverCodec) ReadRequestHeader(r *rpc.Request) error { + c.req.reset() + if err := c.dec.Decode(&c.req); err != nil { + return err + } + r.ServiceMethod = c.req.Method + + // JSON request id can be any JSON value; + // RPC package expects uint64. Translate to + // internal uint64 and save JSON on the side. + c.mutex.Lock() + c.seq++ + c.pending[c.seq] = c.req.Id + c.req.Id = nil + r.Seq = c.seq + c.mutex.Unlock() + + return nil +} + +func (c *serverCodec) ReadRequestBody(x interface{}) error { + if x == nil { + return nil + } + // JSON params is array value. + // RPC params is struct. + // Unmarshal into array containing struct for now. + // Should think about making RPC more general. + var params [1]interface{} + params[0] = x + return json.Unmarshal(*c.req.Params, ¶ms) +} + +var null = json.RawMessage([]byte("null")) + +func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) error { + var resp serverResponse + c.mutex.Lock() + b, ok := c.pending[r.Seq] + if !ok { + c.mutex.Unlock() + return errors.New("invalid sequence number in response") + } + delete(c.pending, r.Seq) + c.mutex.Unlock() + + if b == nil { + // Invalid request so no id. Use JSON null. + b = &null + } + resp.Id = b + resp.Result = x + if r.Error == "" { + resp.Error = nil + } else { + resp.Error = r.Error + } + return c.enc.Encode(resp) +} + +func (c *serverCodec) Close() error { + return c.c.Close() +} + +// ServeConn runs the JSON-RPC server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +func ServeConn(conn io.ReadWriteCloser) { + rpc.ServeCodec(NewServerCodec(conn)) +} diff --git a/src/pkg/net/rpc/server.go b/src/pkg/net/rpc/server.go new file mode 100644 index 000000000..920ae9137 --- /dev/null +++ b/src/pkg/net/rpc/server.go @@ -0,0 +1,640 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + Package rpc provides access to the exported methods of an object across a + network or other I/O connection. A server registers an object, making it visible + as a service with the name of the type of the object. After registration, exported + methods of the object will be accessible remotely. A server may register multiple + objects (services) of different types but it is an error to register multiple + objects of the same type. + + Only methods that satisfy these criteria will be made available for remote access; + other methods will be ignored: + + - the method name is exported, that is, begins with an upper case letter. + - the method receiver is exported or local (defined in the package + registering the service). + - the method has two arguments, both exported or local types. + - the method's second argument is a pointer. + - the method has return type error. + + The method's first argument represents the arguments provided by the caller; the + second argument represents the result parameters to be returned to the caller. + The method's return value, if non-nil, is passed back as a string that the client + sees as if created by errors.New. + + The server may handle requests on a single connection by calling ServeConn. More + typically it will create a network listener and call Accept or, for an HTTP + listener, HandleHTTP and http.Serve. + + A client wishing to use the service establishes a connection and then invokes + NewClient on the connection. The convenience function Dial (DialHTTP) performs + both steps for a raw network connection (an HTTP connection). The resulting + Client object has two methods, Call and Go, that specify the service and method to + call, a pointer containing the arguments, and a pointer to receive the result + parameters. + + Call waits for the remote call to complete; Go launches the call asynchronously + and returns a channel that will signal completion. + + Package "gob" is used to transport the data. + + Here is a simple example. A server wishes to export an object of type Arith: + + package server + + type Args struct { + A, B int + } + + type Quotient struct { + Quo, Rem int + } + + type Arith int + + func (t *Arith) Multiply(args *Args, reply *int) error { + *reply = args.A * args.B + return nil + } + + func (t *Arith) Divide(args *Args, quo *Quotient) error { + if args.B == 0 { + return errors.New("divide by zero") + } + quo.Quo = args.A / args.B + quo.Rem = args.A % args.B + return nil + } + + The server calls (for HTTP service): + + arith := new(Arith) + rpc.Register(arith) + rpc.HandleHTTP() + l, e := net.Listen("tcp", ":1234") + if e != nil { + log.Fatal("listen error:", e) + } + go http.Serve(l, nil) + + At this point, clients can see a service "Arith" with methods "Arith.Multiply" and + "Arith.Divide". To invoke one, a client first dials the server: + + client, err := rpc.DialHTTP("tcp", serverAddress + ":1234") + if err != nil { + log.Fatal("dialing:", err) + } + + Then it can make a remote call: + + // Synchronous call + args := &server.Args{7,8} + var reply int + err = client.Call("Arith.Multiply", args, &reply) + if err != nil { + log.Fatal("arith error:", err) + } + fmt.Printf("Arith: %d*%d=%d", args.A, args.B, reply) + + or + + // Asynchronous call + quotient := new(Quotient) + divCall := client.Go("Arith.Divide", args, "ient, nil) + replyCall := <-divCall.Done // will be equal to divCall + // check errors, print, etc. + + A server implementation will often provide a simple, type-safe wrapper for the + client. +*/ +package rpc + +import ( + "bufio" + "encoding/gob" + "errors" + "io" + "log" + "net" + "net/http" + "reflect" + "strings" + "sync" + "unicode" + "unicode/utf8" +) + +const ( + // Defaults used by HandleHTTP + DefaultRPCPath = "/_goRPC_" + DefaultDebugPath = "/debug/rpc" +) + +// Precompute the reflect type for error. Can't use error directly +// because Typeof takes an empty interface value. This is annoying. +var typeOfError = reflect.TypeOf((*error)(nil)).Elem() + +type methodType struct { + sync.Mutex // protects counters + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint +} + +type service struct { + name string // name of service + rcvr reflect.Value // receiver of methods for the service + typ reflect.Type // type of the receiver + method map[string]*methodType // registered methods +} + +// Request is a header written before every RPC call. It is used internally +// but documented here as an aid to debugging, such as when analyzing +// network traffic. +type Request struct { + ServiceMethod string // format: "Service.Method" + Seq uint64 // sequence number chosen by client + next *Request // for free list in Server +} + +// Response is a header written before every RPC return. It is used internally +// but documented here as an aid to debugging, such as when analyzing +// network traffic. +type Response struct { + ServiceMethod string // echoes that of the Request + Seq uint64 // echoes that of the request + Error string // error, if any. + next *Response // for free list in Server +} + +// Server represents an RPC Server. +type Server struct { + mu sync.Mutex // protects the serviceMap + serviceMap map[string]*service + reqLock sync.Mutex // protects freeReq + freeReq *Request + respLock sync.Mutex // protects freeResp + freeResp *Response +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{serviceMap: make(map[string]*service)} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// Is this an exported - upper case - name? +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// Is this type exported or a builtin? +func isExportedOrBuiltinType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + // PkgPath will be non-empty even for an exported type, + // so we need to check the type name as well. + return isExported(t.Name()) || t.PkgPath() == "" +} + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method +// - two arguments, both pointers to exported structs +// - one return value, of type error +// It returns an error if the receiver is not an exported type or has no +// suitable methods. +// The client accesses each method using a string of the form "Type.Method", +// where Type is the receiver's concrete type. +func (server *Server) Register(rcvr interface{}) error { + return server.register(rcvr, "", false) +} + +// RegisterName is like Register but uses the provided name for the type +// instead of the receiver's concrete type. +func (server *Server) RegisterName(name string, rcvr interface{}) error { + return server.register(rcvr, name, true) +} + +func (server *Server) register(rcvr interface{}, name string, useName bool) error { + server.mu.Lock() + defer server.mu.Unlock() + if server.serviceMap == nil { + server.serviceMap = make(map[string]*service) + } + s := new(service) + s.typ = reflect.TypeOf(rcvr) + s.rcvr = reflect.ValueOf(rcvr) + sname := reflect.Indirect(s.rcvr).Type().Name() + if useName { + sname = name + } + if sname == "" { + log.Fatal("rpc: no service name for type", s.typ.String()) + } + if !isExported(sname) && !useName { + s := "rpc Register: type " + sname + " is not exported" + log.Print(s) + return errors.New(s) + } + if _, present := server.serviceMap[sname]; present { + return errors.New("rpc: service already defined: " + sname) + } + s.name = sname + s.method = make(map[string]*methodType) + + // Install the methods + for m := 0; m < s.typ.NumMethod(); m++ { + method := s.typ.Method(m) + mtype := method.Type + mname := method.Name + if method.PkgPath != "" { + continue + } + // Method needs three ins: receiver, *args, *reply. + if mtype.NumIn() != 3 { + log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) + continue + } + // First arg need not be a pointer. + argType := mtype.In(1) + if !isExportedOrBuiltinType(argType) { + log.Println(mname, "argument type not exported or local:", argType) + continue + } + // Second arg must be a pointer. + replyType := mtype.In(2) + if replyType.Kind() != reflect.Ptr { + log.Println("method", mname, "reply type not a pointer:", replyType) + continue + } + if !isExportedOrBuiltinType(replyType) { + log.Println("method", mname, "reply type not exported or local:", replyType) + continue + } + // Method needs one out: error. + if mtype.NumOut() != 1 { + log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) + continue + } + if returnType := mtype.Out(0); returnType != typeOfError { + log.Println("method", mname, "returns", returnType.String(), "not error") + continue + } + s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType} + } + + if len(s.method) == 0 { + s := "rpc Register: type " + sname + " has no exported methods of suitable type" + log.Print(s) + return errors.New(s) + } + server.serviceMap[s.name] = s + return nil +} + +// A value sent as a placeholder for the response when the server receives an invalid request. +type InvalidRequest struct{} + +var invalidRequest = InvalidRequest{} + +func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) { + resp := server.getResponse() + // Encode the response header + resp.ServiceMethod = req.ServiceMethod + if errmsg != "" { + resp.Error = errmsg + reply = invalidRequest + } + resp.Seq = req.Seq + sending.Lock() + err := codec.WriteResponse(resp, reply) + if err != nil { + log.Println("rpc: writing response:", err) + } + sending.Unlock() + server.freeResponse(resp) +} + +func (m *methodType) NumCalls() (n uint) { + m.Lock() + n = m.numCalls + m.Unlock() + return n +} + +func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { + mtype.Lock() + mtype.numCalls++ + mtype.Unlock() + function := mtype.method.Func + // Invoke the method, providing a new value for the reply. + returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv}) + // The return value for the method is an error. + errInter := returnValues[0].Interface() + errmsg := "" + if errInter != nil { + errmsg = errInter.(error).Error() + } + server.sendResponse(sending, req, replyv.Interface(), codec, errmsg) + server.freeRequest(req) +} + +type gobServerCodec struct { + rwc io.ReadWriteCloser + dec *gob.Decoder + enc *gob.Encoder + encBuf *bufio.Writer +} + +func (c *gobServerCodec) ReadRequestHeader(r *Request) error { + return c.dec.Decode(r) +} + +func (c *gobServerCodec) ReadRequestBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err error) { + if err = c.enc.Encode(r); err != nil { + return + } + if err = c.enc.Encode(body); err != nil { + return + } + return c.encBuf.Flush() +} + +func (c *gobServerCodec) Close() error { + return c.rwc.Close() +} + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +// ServeConn uses the gob wire format (see package gob) on the +// connection. To use an alternate codec, use ServeCodec. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + buf := bufio.NewWriter(conn) + srv := &gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(buf), buf} + server.ServeCodec(srv) +} + +// ServeCodec is like ServeConn but uses the specified codec to +// decode requests and encode responses. +func (server *Server) ServeCodec(codec ServerCodec) { + sending := new(sync.Mutex) + for { + service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) + if err != nil { + if err != io.EOF { + log.Println("rpc:", err) + } + if !keepReading { + break + } + // send a response if we actually managed to read a header. + if req != nil { + server.sendResponse(sending, req, invalidRequest, codec, err.Error()) + server.freeRequest(req) + } + continue + } + go service.call(server, sending, mtype, req, argv, replyv, codec) + } + codec.Close() +} + +// ServeRequest is like ServeCodec but synchronously serves a single request. +// It does not close the codec upon completion. +func (server *Server) ServeRequest(codec ServerCodec) error { + sending := new(sync.Mutex) + service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) + if err != nil { + if !keepReading { + return err + } + // send a response if we actually managed to read a header. + if req != nil { + server.sendResponse(sending, req, invalidRequest, codec, err.Error()) + server.freeRequest(req) + } + return err + } + service.call(server, sending, mtype, req, argv, replyv, codec) + return nil +} + +func (server *Server) getRequest() *Request { + server.reqLock.Lock() + req := server.freeReq + if req == nil { + req = new(Request) + } else { + server.freeReq = req.next + *req = Request{} + } + server.reqLock.Unlock() + return req +} + +func (server *Server) freeRequest(req *Request) { + server.reqLock.Lock() + req.next = server.freeReq + server.freeReq = req + server.reqLock.Unlock() +} + +func (server *Server) getResponse() *Response { + server.respLock.Lock() + resp := server.freeResp + if resp == nil { + resp = new(Response) + } else { + server.freeResp = resp.next + *resp = Response{} + } + server.respLock.Unlock() + return resp +} + +func (server *Server) freeResponse(resp *Response) { + server.respLock.Lock() + resp.next = server.freeResp + server.freeResp = resp + server.respLock.Unlock() +} + +func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) { + service, mtype, req, keepReading, err = server.readRequestHeader(codec) + if err != nil { + if !keepReading { + return + } + // discard body + codec.ReadRequestBody(nil) + return + } + + // Decode the argument value. + argIsValue := false // if true, need to indirect before calling. + if mtype.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(mtype.ArgType.Elem()) + } else { + argv = reflect.New(mtype.ArgType) + argIsValue = true + } + // argv guaranteed to be a pointer now. + if err = codec.ReadRequestBody(argv.Interface()); err != nil { + return + } + if argIsValue { + argv = argv.Elem() + } + + replyv = reflect.New(mtype.ReplyType.Elem()) + return +} + +func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, keepReading bool, err error) { + // Grab the request header. + req = server.getRequest() + err = codec.ReadRequestHeader(req) + if err != nil { + req = nil + if err == io.EOF || err == io.ErrUnexpectedEOF { + return + } + err = errors.New("rpc: server cannot decode request: " + err.Error()) + return + } + + // We read the header successfully. If we see an error now, + // we can still recover and move on to the next request. + keepReading = true + + serviceMethod := strings.Split(req.ServiceMethod, ".") + if len(serviceMethod) != 2 { + err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod) + return + } + // Look up the request. + server.mu.Lock() + service = server.serviceMap[serviceMethod[0]] + server.mu.Unlock() + if service == nil { + err = errors.New("rpc: can't find service " + req.ServiceMethod) + return + } + mtype = service.method[serviceMethod[1]] + if mtype == nil { + err = errors.New("rpc: can't find method " + req.ServiceMethod) + } + return +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. Accept blocks; the caller typically +// invokes it in a go statement. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Fatal("rpc.Serve: accept:", err.Error()) // TODO(r): exit? + } + go server.ServeConn(conn) + } +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } + +// RegisterName is like Register but uses the provided name for the type +// instead of the receiver's concrete type. +func RegisterName(name string, rcvr interface{}) error { + return DefaultServer.RegisterName(name, rcvr) +} + +// A ServerCodec implements reading of RPC requests and writing of +// RPC responses for the server side of an RPC session. +// The server calls ReadRequestHeader and ReadRequestBody in pairs +// to read requests from the connection, and it calls WriteResponse to +// write a response back. The server calls Close when finished with the +// connection. ReadRequestBody may be called with a nil +// argument to force the body of the request to be read and discarded. +type ServerCodec interface { + ReadRequestHeader(*Request) error + ReadRequestBody(interface{}) error + WriteResponse(*Response, interface{}) error + + Close() error +} + +// ServeConn runs the DefaultServer on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +// ServeConn uses the gob wire format (see package gob) on the +// connection. To use an alternate codec, use ServeCodec. +func ServeConn(conn io.ReadWriteCloser) { + DefaultServer.ServeConn(conn) +} + +// ServeCodec is like ServeConn but uses the specified codec to +// decode requests and encode responses. +func ServeCodec(codec ServerCodec) { + DefaultServer.ServeCodec(codec) +} + +// ServeRequest is like ServeCodec but synchronously serves a single request. +// It does not close the codec upon completion. +func ServeRequest(codec ServerCodec) error { + return DefaultServer.ServeRequest(codec) +} + +// Accept accepts connections on the listener and serves requests +// to DefaultServer for each incoming connection. +// Accept blocks; the caller typically invokes it in a go statement. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Can connect to RPC service using HTTP CONNECT to rpcPath. +var connected = "200 Connected to Go RPC" + +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + io.WriteString(w, "405 must CONNECT\n") + return + } + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) + return + } + io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath, +// and a debugging handler on debugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP(rpcPath, debugPath string) { + http.Handle(rpcPath, server) + http.Handle(debugPath, debugHTTP{server}) +} + +// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer +// on DefaultRPCPath and a debugging handler on DefaultDebugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func HandleHTTP() { + DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath) +} diff --git a/src/pkg/net/rpc/server_test.go b/src/pkg/net/rpc/server_test.go new file mode 100644 index 000000000..b05c63c05 --- /dev/null +++ b/src/pkg/net/rpc/server_test.go @@ -0,0 +1,596 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +import ( + "errors" + "fmt" + "io" + "log" + "net" + "net/http/httptest" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +var ( + newServer *Server + serverAddr, newServerAddr string + httpServerAddr string + once, newOnce, httpOnce sync.Once +) + +const ( + newHttpPath = "/foo" +) + +type Args struct { + A, B int +} + +type Reply struct { + C int +} + +type Arith int + +// Some of Arith's methods have value args, some have pointer args. That's deliberate. + +func (t *Arith) Add(args Args, reply *Reply) error { + reply.C = args.A + args.B + return nil +} + +func (t *Arith) Mul(args *Args, reply *Reply) error { + reply.C = args.A * args.B + return nil +} + +func (t *Arith) Div(args Args, reply *Reply) error { + if args.B == 0 { + return errors.New("divide by zero") + } + reply.C = args.A / args.B + return nil +} + +func (t *Arith) String(args *Args, reply *string) error { + *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) + return nil +} + +func (t *Arith) Scan(args string, reply *Reply) (err error) { + _, err = fmt.Sscan(args, &reply.C) + return +} + +func (t *Arith) Error(args *Args, reply *Reply) error { + panic("ERROR") +} + +func listenTCP() (net.Listener, string) { + l, e := net.Listen("tcp", "127.0.0.1:0") // any available address + if e != nil { + log.Fatalf("net.Listen tcp :0: %v", e) + } + return l, l.Addr().String() +} + +func startServer() { + Register(new(Arith)) + + var l net.Listener + l, serverAddr = listenTCP() + log.Println("Test RPC server listening on", serverAddr) + go Accept(l) + + HandleHTTP() + httpOnce.Do(startHttpServer) +} + +func startNewServer() { + newServer = NewServer() + newServer.Register(new(Arith)) + + var l net.Listener + l, newServerAddr = listenTCP() + log.Println("NewServer test RPC server listening on", newServerAddr) + go Accept(l) + + newServer.HandleHTTP(newHttpPath, "/bar") + httpOnce.Do(startHttpServer) +} + +func startHttpServer() { + server := httptest.NewServer(nil) + httpServerAddr = server.Listener.Addr().String() + log.Println("Test HTTP RPC server listening on", httpServerAddr) +} + +func TestRPC(t *testing.T) { + once.Do(startServer) + testRPC(t, serverAddr) + newOnce.Do(startNewServer) + testRPC(t, newServerAddr) +} + +func testRPC(t *testing.T, addr string) { + client, err := Dial("tcp", addr) + if err != nil { + t.Fatal("dialing", err) + } + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err = client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } + + // Nonexistent method + args = &Args{7, 0} + reply = new(Reply) + err = client.Call("Arith.BadOperation", args, reply) + // expect an error + if err == nil { + t.Error("BadOperation: expected error") + } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") { + t.Errorf("BadOperation: expected can't find method error; got %q", err) + } + + // Unknown service + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("Arith.Unknown", args, reply) + if err == nil { + t.Error("expected error calling unknown service") + } else if strings.Index(err.Error(), "method") < 0 { + t.Error("expected error about method; got", err) + } + + // Out of order. + args = &Args{7, 8} + mulReply := new(Reply) + mulCall := client.Go("Arith.Mul", args, mulReply, nil) + addReply := new(Reply) + addCall := client.Go("Arith.Add", args, addReply, nil) + + addCall = <-addCall.Done + if addCall.Error != nil { + t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) + } + if addReply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) + } + + mulCall = <-mulCall.Done + if mulCall.Error != nil { + t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) + } + if mulReply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) + } + + // Error test + args = &Args{7, 0} + reply = new(Reply) + err = client.Call("Arith.Div", args, reply) + // expect an error: zero divide + if err == nil { + t.Error("Div: expected error") + } else if err.Error() != "divide by zero" { + t.Error("Div: expected divide by zero error; got", err) + } + + // Bad type. + reply = new(Reply) + err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use + if err == nil { + t.Error("expected error calling Arith.Add with wrong arg type") + } else if strings.Index(err.Error(), "type") < 0 { + t.Error("expected error about type; got", err) + } + + // Non-struct argument + const Val = 12345 + str := fmt.Sprint(Val) + reply = new(Reply) + err = client.Call("Arith.Scan", &str, reply) + if err != nil { + t.Errorf("Scan: expected no error but got string %q", err.Error()) + } else if reply.C != Val { + t.Errorf("Scan: expected %d got %d", Val, reply.C) + } + + // Non-struct reply + args = &Args{27, 35} + str = "" + err = client.Call("Arith.String", args, &str) + if err != nil { + t.Errorf("String: expected no error but got string %q", err.Error()) + } + expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) + if str != expect { + t.Errorf("String: expected %s got %s", expect, str) + } + + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("Arith.Mul", args, reply) + if err != nil { + t.Errorf("Mul: expected no error but got string %q", err.Error()) + } + if reply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) + } +} + +func TestHTTP(t *testing.T) { + once.Do(startServer) + testHTTPRPC(t, "") + newOnce.Do(startNewServer) + testHTTPRPC(t, newHttpPath) +} + +func testHTTPRPC(t *testing.T, path string) { + var client *Client + var err error + if path == "" { + client, err = DialHTTP("tcp", httpServerAddr) + } else { + client, err = DialHTTPPath("tcp", httpServerAddr, path) + } + if err != nil { + t.Fatal("dialing", err) + } + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err = client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } +} + +// CodecEmulator provides a client-like api and a ServerCodec interface. +// Can be used to test ServeRequest. +type CodecEmulator struct { + server *Server + serviceMethod string + args *Args + reply *Reply + err error +} + +func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error { + codec.serviceMethod = serviceMethod + codec.args = args + codec.reply = reply + codec.err = nil + var serverError error + if codec.server == nil { + serverError = ServeRequest(codec) + } else { + serverError = codec.server.ServeRequest(codec) + } + if codec.err == nil && serverError != nil { + codec.err = serverError + } + return codec.err +} + +func (codec *CodecEmulator) ReadRequestHeader(req *Request) error { + req.ServiceMethod = codec.serviceMethod + req.Seq = 0 + return nil +} + +func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error { + if codec.args == nil { + return io.ErrUnexpectedEOF + } + *(argv.(*Args)) = *codec.args + return nil +} + +func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error { + if resp.Error != "" { + codec.err = errors.New(resp.Error) + } else { + *codec.reply = *(reply.(*Reply)) + } + return nil +} + +func (codec *CodecEmulator) Close() error { + return nil +} + +func TestServeRequest(t *testing.T) { + once.Do(startServer) + testServeRequest(t, nil) + newOnce.Do(startNewServer) + testServeRequest(t, newServer) +} + +func testServeRequest(t *testing.T, server *Server) { + client := CodecEmulator{server: server} + + args := &Args{7, 8} + reply := new(Reply) + err := client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } + + err = client.Call("Arith.Add", nil, reply) + if err == nil { + t.Errorf("expected error calling Arith.Add with nil arg") + } +} + +type ReplyNotPointer int +type ArgNotPublic int +type ReplyNotPublic int +type local struct{} + +func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error { + return nil +} + +func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error { + return nil +} + +func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error { + return nil +} + +// Check that registration handles lots of bad methods and a type with no suitable methods. +func TestRegistrationError(t *testing.T) { + err := Register(new(ReplyNotPointer)) + if err == nil { + t.Errorf("expected error registering ReplyNotPointer") + } + err = Register(new(ArgNotPublic)) + if err == nil { + t.Errorf("expected error registering ArgNotPublic") + } + err = Register(new(ReplyNotPublic)) + if err == nil { + t.Errorf("expected error registering ReplyNotPublic") + } +} + +type WriteFailCodec int + +func (WriteFailCodec) WriteRequest(*Request, interface{}) error { + // the panic caused by this error used to not unlock a lock. + return errors.New("fail") +} + +func (WriteFailCodec) ReadResponseHeader(*Response) error { + time.Sleep(120 * time.Second) + panic("unreachable") +} + +func (WriteFailCodec) ReadResponseBody(interface{}) error { + time.Sleep(120 * time.Second) + panic("unreachable") +} + +func (WriteFailCodec) Close() error { + return nil +} + +func TestSendDeadlock(t *testing.T) { + client := NewClientWithCodec(WriteFailCodec(0)) + + done := make(chan bool) + go func() { + testSendDeadlock(client) + testSendDeadlock(client) + done <- true + }() + select { + case <-done: + return + case <-time.After(5 * time.Second): + t.Fatal("deadlock") + } +} + +func testSendDeadlock(client *Client) { + defer func() { + recover() + }() + args := &Args{7, 8} + reply := new(Reply) + client.Call("Arith.Add", args, reply) +} + +func dialDirect() (*Client, error) { + return Dial("tcp", serverAddr) +} + +func dialHTTP() (*Client, error) { + return DialHTTP("tcp", httpServerAddr) +} + +func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { + once.Do(startServer) + client, err := dial() + if err != nil { + t.Fatal("error dialing", err) + } + args := &Args{7, 8} + reply := new(Reply) + runtime.UpdateMemStats() + mallocs := 0 - runtime.MemStats.Mallocs + const count = 100 + for i := 0; i < count; i++ { + err := client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } + } + runtime.UpdateMemStats() + mallocs += runtime.MemStats.Mallocs + return mallocs / count +} + +func TestCountMallocs(t *testing.T) { + fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t)) +} + +func TestCountMallocsOverHTTP(t *testing.T) { + fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t)) +} + +type writeCrasher struct { + done chan bool +} + +func (writeCrasher) Close() error { + return nil +} + +func (w *writeCrasher) Read(p []byte) (int, error) { + <-w.done + return 0, io.EOF +} + +func (writeCrasher) Write(p []byte) (int, error) { + return 0, errors.New("fake write failure") +} + +func TestClientWriteError(t *testing.T) { + w := &writeCrasher{done: make(chan bool)} + c := NewClient(w) + res := false + err := c.Call("foo", 1, &res) + if err == nil { + t.Fatal("expected error") + } + if err.Error() != "fake write failure" { + t.Error("unexpected value of error:", err) + } + w.done <- true +} + +func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { + b.StopTimer() + once.Do(startServer) + client, err := dial() + if err != nil { + b.Fatal("error dialing:", err) + } + + // Synchronous calls + args := &Args{7, 8} + procs := runtime.GOMAXPROCS(-1) + N := int32(b.N) + var wg sync.WaitGroup + wg.Add(procs) + b.StartTimer() + + for p := 0; p < procs; p++ { + go func() { + reply := new(Reply) + for atomic.AddInt32(&N, -1) >= 0 { + err := client.Call("Arith.Add", args, reply) + if err != nil { + b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B) + } + } + wg.Done() + }() + } + wg.Wait() +} + +func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { + const MaxConcurrentCalls = 100 + b.StopTimer() + once.Do(startServer) + client, err := dial() + if err != nil { + b.Fatal("error dialing:", err) + } + + // Asynchronous calls + args := &Args{7, 8} + procs := 4 * runtime.GOMAXPROCS(-1) + send := int32(b.N) + recv := int32(b.N) + var wg sync.WaitGroup + wg.Add(procs) + gate := make(chan bool, MaxConcurrentCalls) + res := make(chan *Call, MaxConcurrentCalls) + b.StartTimer() + + for p := 0; p < procs; p++ { + go func() { + for atomic.AddInt32(&send, -1) >= 0 { + gate <- true + reply := new(Reply) + client.Go("Arith.Add", args, reply, res) + } + }() + go func() { + for call := range res { + A := call.Args.(*Args).A + B := call.Args.(*Args).B + C := call.Reply.(*Reply).C + if A+B != C { + b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C) + } + <-gate + if atomic.AddInt32(&recv, -1) == 0 { + close(res) + } + } + wg.Done() + }() + } + wg.Wait() +} + +func BenchmarkEndToEnd(b *testing.B) { + benchmarkEndToEnd(dialDirect, b) +} + +func BenchmarkEndToEndHTTP(b *testing.B) { + benchmarkEndToEnd(dialHTTP, b) +} + +func BenchmarkEndToEndAsync(b *testing.B) { + benchmarkEndToEndAsync(dialDirect, b) +} + +func BenchmarkEndToEndAsyncHTTP(b *testing.B) { + benchmarkEndToEndAsync(dialHTTP, b) +} diff --git a/src/pkg/net/sendfile_linux.go b/src/pkg/net/sendfile_linux.go index 6a5a06c8c..e9ab06666 100644 --- a/src/pkg/net/sendfile_linux.go +++ b/src/pkg/net/sendfile_linux.go @@ -21,7 +21,7 @@ const maxSendfileSize int = 4 << 20 // non-EOF error. // // if handled == false, sendFile performed no work. -func sendFile(c *netFD, r io.Reader) (written int64, err os.Error, handled bool) { +func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { var remain int64 = 1 << 62 // by default, copy until EOF lr, ok := r.(*io.LimitedReader) @@ -40,15 +40,6 @@ func sendFile(c *netFD, r io.Reader) (written int64, err os.Error, handled bool) defer c.wio.Unlock() c.incref() defer c.decref() - if c.wdeadline_delta > 0 { - // This is a little odd that we're setting the timeout - // for the entire file but Write has the same issue - // (if one slurps the whole file into memory and - // do one large Write). At least they're consistent. - c.wdeadline = pollserver.Now() + c.wdeadline_delta - } else { - c.wdeadline = 0 - } dst := c.sysfd src := f.Fd() @@ -62,18 +53,18 @@ func sendFile(c *netFD, r io.Reader) (written int64, err os.Error, handled bool) written += int64(n) remain -= int64(n) } - if n == 0 && errno == 0 { + if n == 0 && errno == nil { break } if errno == syscall.EAGAIN && c.wdeadline >= 0 { pollserver.WaitWrite(c) continue } - if errno != 0 { + if errno != nil { // This includes syscall.ENOSYS (no kernel // support) and syscall.EINVAL (fd types which // don't implement sendfile together) - err = &OpError{"sendfile", c.net, c.raddr, os.Errno(errno)} + err = &OpError{"sendfile", c.net, c.raddr, errno} break } } diff --git a/src/pkg/net/sendfile_stub.go b/src/pkg/net/sendfile_stub.go index c55be6c08..ff76ab9cf 100644 --- a/src/pkg/net/sendfile_stub.go +++ b/src/pkg/net/sendfile_stub.go @@ -2,15 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd openbsd +// +build darwin freebsd netbsd openbsd package net -import ( - "io" - "os" -) +import "io" -func sendFile(c *netFD, r io.Reader) (n int64, err os.Error, handled bool) { +func sendFile(c *netFD, r io.Reader) (n int64, err error, handled bool) { return 0, nil, false } diff --git a/src/pkg/net/sendfile_windows.go b/src/pkg/net/sendfile_windows.go index d9c2f537a..ee7ff8b98 100644 --- a/src/pkg/net/sendfile_windows.go +++ b/src/pkg/net/sendfile_windows.go @@ -16,7 +16,7 @@ type sendfileOp struct { n uint32 } -func (o *sendfileOp) Submit() (errno int) { +func (o *sendfileOp) Submit() (err error) { return syscall.TransmitFile(o.fd.sysfd, o.src, o.n, 0, &o.o, nil, syscall.TF_WRITE_BEHIND) } @@ -33,7 +33,7 @@ func (o *sendfileOp) Name() string { // if handled == false, sendFile performed no work. // // Note that sendfile for windows does not suppport >2GB file. -func sendFile(c *netFD, r io.Reader) (written int64, err os.Error, handled bool) { +func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { var n int64 = 0 // by default, copy until EOF lr, ok := r.(*io.LimitedReader) diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go index a2ff218e7..b0b546be3 100644 --- a/src/pkg/net/server_test.go +++ b/src/pkg/net/server_test.go @@ -8,10 +8,10 @@ import ( "flag" "io" "os" + "runtime" "strings" - "syscall" "testing" - "runtime" + "time" ) // Do not test empty datagrams by default. @@ -55,7 +55,7 @@ func runServe(t *testing.T, network, addr string, listening chan<- string, done func connect(t *testing.T, network, addr string, isEmpty bool) { var fd Conn - var err os.Error + var err error if network == "unixgram" { fd, err = DialUnix(network, &UnixAddr{addr + ".local", network}, &UnixAddr{addr, network}) } else { @@ -64,7 +64,7 @@ func connect(t *testing.T, network, addr string, isEmpty bool) { if err != nil { t.Fatalf("net.Dial(%q, %q) = _, %v", network, addr, err) } - fd.SetReadTimeout(1e9) // 1s + fd.SetReadDeadline(time.Now().Add(1 * time.Second)) var b []byte if !isEmpty { @@ -92,7 +92,7 @@ func connect(t *testing.T, network, addr string, isEmpty bool) { } func doTest(t *testing.T, network, listenaddr, dialaddr string) { - t.Logf("Test %q %q %q\n", network, listenaddr, dialaddr) + t.Logf("Test %q %q %q", network, listenaddr, dialaddr) switch listenaddr { case "", "0.0.0.0", "[::]", "[::ffff:0.0.0.0]": if testing.Short() || avoidMacFirewall { @@ -115,7 +115,7 @@ func doTest(t *testing.T, network, listenaddr, dialaddr string) { } func TestTCPServer(t *testing.T) { - if syscall.OS != "openbsd" { + if runtime.GOOS != "openbsd" { doTest(t, "tcp", "", "127.0.0.1") } doTest(t, "tcp", "0.0.0.0", "127.0.0.1") @@ -155,7 +155,7 @@ func TestUnixServer(t *testing.T) { os.Remove("/tmp/gotest.net") doTest(t, "unix", "/tmp/gotest.net", "/tmp/gotest.net") os.Remove("/tmp/gotest.net") - if syscall.OS == "linux" { + if runtime.GOOS == "linux" { doTest(t, "unixpacket", "/tmp/gotest.net", "/tmp/gotest.net") os.Remove("/tmp/gotest.net") // Test abstract unix domain socket, a Linux-ism @@ -170,10 +170,10 @@ func runPacket(t *testing.T, network, addr string, listening chan<- string, done t.Fatalf("net.ListenPacket(%q, %q) = _, %v", network, addr, err) } listening <- c.LocalAddr().String() - c.SetReadTimeout(10e6) // 10ms var buf [1000]byte Run: for { + c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) n, addr, err := c.ReadFrom(buf[0:]) if e, ok := err.(Error); ok && e.Timeout() { select { @@ -195,7 +195,7 @@ Run: } func doTestPacket(t *testing.T, network, listenaddr, dialaddr string, isEmpty bool) { - t.Logf("TestPacket %s %s %s\n", network, listenaddr, dialaddr) + t.Logf("TestPacket %q %q %q", network, listenaddr, dialaddr) listening := make(chan string) done := make(chan int) if network == "udp" { @@ -237,7 +237,7 @@ func TestUnixDatagramServer(t *testing.T) { doTestPacket(t, "unixgram", "/tmp/gotest1.net", "/tmp/gotest1.net", isEmpty) os.Remove("/tmp/gotest1.net") os.Remove("/tmp/gotest1.net.local") - if syscall.OS == "linux" { + if runtime.GOOS == "linux" { // Test abstract unix domain socket, a Linux-ism doTestPacket(t, "unixgram", "@gotest1/net", "@gotest1/net", isEmpty) } diff --git a/src/pkg/net/smtp/Makefile b/src/pkg/net/smtp/Makefile new file mode 100644 index 000000000..d9812d5cb --- /dev/null +++ b/src/pkg/net/smtp/Makefile @@ -0,0 +1,12 @@ +# Copyright 2010 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=net/smtp +GOFILES=\ + auth.go\ + smtp.go\ + +include ../../../Make.pkg diff --git a/src/pkg/net/smtp/auth.go b/src/pkg/net/smtp/auth.go new file mode 100644 index 000000000..d401e3c21 --- /dev/null +++ b/src/pkg/net/smtp/auth.go @@ -0,0 +1,98 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package smtp + +import ( + "crypto/hmac" + "crypto/md5" + "errors" + "fmt" +) + +// Auth is implemented by an SMTP authentication mechanism. +type Auth interface { + // Start begins an authentication with a server. + // It returns the name of the authentication protocol + // and optionally data to include in the initial AUTH message + // sent to the server. It can return proto == "" to indicate + // that the authentication should be skipped. + // If it returns a non-nil error, the SMTP client aborts + // the authentication attempt and closes the connection. + Start(server *ServerInfo) (proto string, toServer []byte, err error) + + // Next continues the authentication. The server has just sent + // the fromServer data. If more is true, the server expects a + // response, which Next should return as toServer; otherwise + // Next should return toServer == nil. + // If Next returns a non-nil error, the SMTP client aborts + // the authentication attempt and closes the connection. + Next(fromServer []byte, more bool) (toServer []byte, err error) +} + +// ServerInfo records information about an SMTP server. +type ServerInfo struct { + Name string // SMTP server name + TLS bool // using TLS, with valid certificate for Name + Auth []string // advertised authentication mechanisms +} + +type plainAuth struct { + identity, username, password string + host string +} + +// PlainAuth returns an Auth that implements the PLAIN authentication +// mechanism as defined in RFC 4616. +// The returned Auth uses the given username and password to authenticate +// on TLS connections to host and act as identity. Usually identity will be +// left blank to act as username. +func PlainAuth(identity, username, password, host string) Auth { + return &plainAuth{identity, username, password, host} +} + +func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { + if !server.TLS { + return "", nil, errors.New("unencrypted connection") + } + if server.Name != a.host { + return "", nil, errors.New("wrong host name") + } + resp := []byte(a.identity + "\x00" + a.username + "\x00" + a.password) + return "PLAIN", resp, nil +} + +func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + // We've already sent everything. + return nil, errors.New("unexpected server challenge") + } + return nil, nil +} + +type cramMD5Auth struct { + username, secret string +} + +// CRAMMD5Auth returns an Auth that implements the CRAM-MD5 authentication +// mechanism as defined in RFC 2195. +// The returned Auth uses the given username and secret to authenticate +// to the server using the challenge-response mechanism. +func CRAMMD5Auth(username, secret string) Auth { + return &cramMD5Auth{username, secret} +} + +func (a *cramMD5Auth) Start(server *ServerInfo) (string, []byte, error) { + return "CRAM-MD5", nil, nil +} + +func (a *cramMD5Auth) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + d := hmac.New(md5.New, []byte(a.secret)) + d.Write(fromServer) + s := make([]byte, 0, d.Size()) + return []byte(fmt.Sprintf("%s %x", a.username, d.Sum(s))), nil + } + return nil, nil +} diff --git a/src/pkg/net/smtp/smtp.go b/src/pkg/net/smtp/smtp.go new file mode 100644 index 000000000..8d935ffb7 --- /dev/null +++ b/src/pkg/net/smtp/smtp.go @@ -0,0 +1,293 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package smtp implements the Simple Mail Transfer Protocol as defined in RFC 5321. +// It also implements the following extensions: +// 8BITMIME RFC 1652 +// AUTH RFC 2554 +// STARTTLS RFC 3207 +// Additional extensions may be handled by clients. +package smtp + +import ( + "crypto/tls" + "encoding/base64" + "io" + "net" + "net/textproto" + "strings" +) + +// A Client represents a client connection to an SMTP server. +type Client struct { + // Text is the textproto.Conn used by the Client. It is exported to allow for + // clients to add extensions. + Text *textproto.Conn + // keep a reference to the connection so it can be used to create a TLS + // connection later + conn net.Conn + // whether the Client is using TLS + tls bool + serverName string + // map of supported extensions + ext map[string]string + // supported auth mechanisms + auth []string +} + +// Dial returns a new Client connected to an SMTP server at addr. +func Dial(addr string) (*Client, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + host := addr[:strings.Index(addr, ":")] + return NewClient(conn, host) +} + +// NewClient returns a new Client using an existing connection and host as a +// server name to be used when authenticating. +func NewClient(conn net.Conn, host string) (*Client, error) { + text := textproto.NewConn(conn) + _, msg, err := text.ReadResponse(220) + if err != nil { + text.Close() + return nil, err + } + c := &Client{Text: text, conn: conn, serverName: host} + if strings.Contains(msg, "ESMTP") { + err = c.ehlo() + } else { + err = c.helo() + } + return c, err +} + +// cmd is a convenience function that sends a command and returns the response +func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { + id, err := c.Text.Cmd(format, args...) + if err != nil { + return 0, "", err + } + c.Text.StartResponse(id) + defer c.Text.EndResponse(id) + code, msg, err := c.Text.ReadResponse(expectCode) + return code, msg, err +} + +// helo sends the HELO greeting to the server. It should be used only when the +// server does not support ehlo. +func (c *Client) helo() error { + c.ext = nil + _, _, err := c.cmd(250, "HELO localhost") + return err +} + +// ehlo sends the EHLO (extended hello) greeting to the server. It +// should be the preferred greeting for servers that support it. +func (c *Client) ehlo() error { + _, msg, err := c.cmd(250, "EHLO localhost") + if err != nil { + return err + } + ext := make(map[string]string) + extList := strings.Split(msg, "\n") + if len(extList) > 1 { + extList = extList[1:] + for _, line := range extList { + args := strings.SplitN(line, " ", 2) + if len(args) > 1 { + ext[args[0]] = args[1] + } else { + ext[args[0]] = "" + } + } + } + if mechs, ok := ext["AUTH"]; ok { + c.auth = strings.Split(mechs, " ") + } + c.ext = ext + return err +} + +// StartTLS sends the STARTTLS command and encrypts all further communication. +// Only servers that advertise the STARTTLS extension support this function. +func (c *Client) StartTLS(config *tls.Config) error { + _, _, err := c.cmd(220, "STARTTLS") + if err != nil { + return err + } + c.conn = tls.Client(c.conn, config) + c.Text = textproto.NewConn(c.conn) + c.tls = true + return c.ehlo() +} + +// Verify checks the validity of an email address on the server. +// If Verify returns nil, the address is valid. A non-nil return +// does not necessarily indicate an invalid address. Many servers +// will not verify addresses for security reasons. +func (c *Client) Verify(addr string) error { + _, _, err := c.cmd(250, "VRFY %s", addr) + return err +} + +// Auth authenticates a client using the provided authentication mechanism. +// A failed authentication closes the connection. +// Only servers that advertise the AUTH extension support this function. +func (c *Client) Auth(a Auth) error { + encoding := base64.StdEncoding + mech, resp, err := a.Start(&ServerInfo{c.serverName, c.tls, c.auth}) + if err != nil { + c.Quit() + return err + } + resp64 := make([]byte, encoding.EncodedLen(len(resp))) + encoding.Encode(resp64, resp) + code, msg64, err := c.cmd(0, "AUTH %s %s", mech, resp64) + for err == nil { + var msg []byte + switch code { + case 334: + msg, err = encoding.DecodeString(msg64) + case 235: + // the last message isn't base64 because it isn't a challenge + msg = []byte(msg64) + default: + err = &textproto.Error{code, msg64} + } + resp, err = a.Next(msg, code == 334) + if err != nil { + // abort the AUTH + c.cmd(501, "*") + c.Quit() + break + } + if resp == nil { + break + } + resp64 = make([]byte, encoding.EncodedLen(len(resp))) + encoding.Encode(resp64, resp) + code, msg64, err = c.cmd(0, string(resp64)) + } + return err +} + +// Mail issues a MAIL command to the server using the provided email address. +// If the server supports the 8BITMIME extension, Mail adds the BODY=8BITMIME +// parameter. +// This initiates a mail transaction and is followed by one or more Rcpt calls. +func (c *Client) Mail(from string) error { + cmdStr := "MAIL FROM:<%s>" + if c.ext != nil { + if _, ok := c.ext["8BITMIME"]; ok { + cmdStr += " BODY=8BITMIME" + } + } + _, _, err := c.cmd(250, cmdStr, from) + return err +} + +// Rcpt issues a RCPT command to the server using the provided email address. +// A call to Rcpt must be preceded by a call to Mail and may be followed by +// a Data call or another Rcpt call. +func (c *Client) Rcpt(to string) error { + _, _, err := c.cmd(25, "RCPT TO:<%s>", to) + return err +} + +type dataCloser struct { + c *Client + io.WriteCloser +} + +func (d *dataCloser) Close() error { + d.WriteCloser.Close() + _, _, err := d.c.Text.ReadResponse(250) + return err +} + +// Data issues a DATA command to the server and returns a writer that +// can be used to write the data. The caller should close the writer +// before calling any more methods on c. +// A call to Data must be preceded by one or more calls to Rcpt. +func (c *Client) Data() (io.WriteCloser, error) { + _, _, err := c.cmd(354, "DATA") + if err != nil { + return nil, err + } + return &dataCloser{c, c.Text.DotWriter()}, nil +} + +// SendMail connects to the server at addr, switches to TLS if possible, +// authenticates with mechanism a if possible, and then sends an email from +// address from, to addresses to, with message msg. +func SendMail(addr string, a Auth, from string, to []string, msg []byte) error { + c, err := Dial(addr) + if err != nil { + return err + } + if ok, _ := c.Extension("STARTTLS"); ok { + if err = c.StartTLS(nil); err != nil { + return err + } + } + if a != nil && c.ext != nil { + if _, ok := c.ext["AUTH"]; ok { + if err = c.Auth(a); err != nil { + return err + } + } + } + if err = c.Mail(from); err != nil { + return err + } + for _, addr := range to { + if err = c.Rcpt(addr); err != nil { + return err + } + } + w, err := c.Data() + if err != nil { + return err + } + _, err = w.Write(msg) + if err != nil { + return err + } + err = w.Close() + if err != nil { + return err + } + return c.Quit() +} + +// Extension reports whether an extension is support by the server. +// The extension name is case-insensitive. If the extension is supported, +// Extension also returns a string that contains any parameters the +// server specifies for the extension. +func (c *Client) Extension(ext string) (bool, string) { + if c.ext == nil { + return false, "" + } + ext = strings.ToUpper(ext) + param, ok := c.ext[ext] + return ok, param +} + +// Reset sends the RSET command to the server, aborting the current mail +// transaction. +func (c *Client) Reset() error { + _, _, err := c.cmd(250, "RSET") + return err +} + +// Quit sends the QUIT command and closes the connection to the server. +func (c *Client) Quit() error { + _, _, err := c.cmd(221, "QUIT") + if err != nil { + return err + } + return c.Text.Close() +} diff --git a/src/pkg/net/smtp/smtp_test.go b/src/pkg/net/smtp/smtp_test.go new file mode 100644 index 000000000..ce8878205 --- /dev/null +++ b/src/pkg/net/smtp/smtp_test.go @@ -0,0 +1,182 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package smtp + +import ( + "bufio" + "bytes" + "io" + "net/textproto" + "strings" + "testing" +) + +type authTest struct { + auth Auth + challenges []string + name string + responses []string +} + +var authTests = []authTest{ + {PlainAuth("", "user", "pass", "testserver"), []string{}, "PLAIN", []string{"\x00user\x00pass"}}, + {PlainAuth("foo", "bar", "baz", "testserver"), []string{}, "PLAIN", []string{"foo\x00bar\x00baz"}}, + {CRAMMD5Auth("user", "pass"), []string{"<123456.1322876914@testserver>"}, "CRAM-MD5", []string{"", "user 287eb355114cf5c471c26a875f1ca4ae"}}, +} + +func TestAuth(t *testing.T) { +testLoop: + for i, test := range authTests { + name, resp, err := test.auth.Start(&ServerInfo{"testserver", true, nil}) + if name != test.name { + t.Errorf("#%d got name %s, expected %s", i, name, test.name) + } + if !bytes.Equal(resp, []byte(test.responses[0])) { + t.Errorf("#%d got response %s, expected %s", i, resp, test.responses[0]) + } + if err != nil { + t.Errorf("#%d error: %s", i, err) + } + for j := range test.challenges { + challenge := []byte(test.challenges[j]) + expected := []byte(test.responses[j+1]) + resp, err := test.auth.Next(challenge, true) + if err != nil { + t.Errorf("#%d error: %s", i, err) + continue testLoop + } + if !bytes.Equal(resp, expected) { + t.Errorf("#%d got %s, expected %s", i, resp, expected) + continue testLoop + } + } + } +} + +type faker struct { + io.ReadWriter +} + +func (f faker) Close() error { + return nil +} + +func TestBasic(t *testing.T) { + basicServer = strings.Join(strings.Split(basicServer, "\n"), "\r\n") + basicClient = strings.Join(strings.Split(basicClient, "\n"), "\r\n") + + var cmdbuf bytes.Buffer + bcmdbuf := bufio.NewWriter(&cmdbuf) + var fake faker + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(basicServer)), bcmdbuf) + c := &Client{Text: textproto.NewConn(fake)} + + if err := c.helo(); err != nil { + t.Fatalf("HELO failed: %s", err) + } + if err := c.ehlo(); err == nil { + t.Fatalf("Expected first EHLO to fail") + } + if err := c.ehlo(); err != nil { + t.Fatalf("Second EHLO failed: %s", err) + } + + if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" { + t.Fatalf("Expected AUTH supported") + } + if ok, _ := c.Extension("DSN"); ok { + t.Fatalf("Shouldn't support DSN") + } + + if err := c.Mail("user@gmail.com"); err == nil { + t.Fatalf("MAIL should require authentication") + } + + if err := c.Verify("user1@gmail.com"); err == nil { + t.Fatalf("First VRFY: expected no verification") + } + if err := c.Verify("user2@gmail.com"); err != nil { + t.Fatalf("Second VRFY: expected verification, got %s", err) + } + + // fake TLS so authentication won't complain + c.tls = true + c.serverName = "smtp.google.com" + if err := c.Auth(PlainAuth("", "user", "pass", "smtp.google.com")); err != nil { + t.Fatalf("AUTH failed: %s", err) + } + + if err := c.Mail("user@gmail.com"); err != nil { + t.Fatalf("MAIL failed: %s", err) + } + if err := c.Rcpt("golang-nuts@googlegroups.com"); err != nil { + t.Fatalf("RCPT failed: %s", err) + } + msg := `From: user@gmail.com +To: golang-nuts@googlegroups.com +Subject: Hooray for Go + +Line 1 +.Leading dot line . +Goodbye.` + w, err := c.Data() + if err != nil { + t.Fatalf("DATA failed: %s", err) + } + if _, err := w.Write([]byte(msg)); err != nil { + t.Fatalf("Data write failed: %s", err) + } + if err := w.Close(); err != nil { + t.Fatalf("Bad data response: %s", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + if basicClient != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, basicClient) + } +} + +var basicServer = `250 mx.google.com at your service +502 Unrecognized command. +250-mx.google.com at your service +250-SIZE 35651584 +250-AUTH LOGIN PLAIN +250 8BITMIME +530 Authentication required +252 Send some mail, I'll try my best +250 User is valid +235 Accepted +250 Sender OK +250 Receiver OK +354 Go ahead +250 Data OK +221 OK +` + +var basicClient = `HELO localhost +EHLO localhost +EHLO localhost +MAIL FROM:<user@gmail.com> BODY=8BITMIME +VRFY user1@gmail.com +VRFY user2@gmail.com +AUTH PLAIN AHVzZXIAcGFzcw== +MAIL FROM:<user@gmail.com> BODY=8BITMIME +RCPT TO:<golang-nuts@googlegroups.com> +DATA +From: user@gmail.com +To: golang-nuts@googlegroups.com +Subject: Hooray for Go + +Line 1 +..Leading dot line . +Goodbye. +. +QUIT +` diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go index 366e050ff..867e328f1 100644 --- a/src/pkg/net/sock.go +++ b/src/pkg/net/sock.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd windows +// +build darwin freebsd linux netbsd openbsd windows // Sockets @@ -10,51 +10,46 @@ package net import ( "io" - "os" "reflect" "syscall" ) -// Boolean to int. -func boolint(b bool) int { - if b { - return 1 - } - return 0 -} +var listenerBacklog = maxListenerBacklog() // Generic socket creation. -func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err os.Error) { +func socket(net string, f, t, p int, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { // See ../syscall/exec.go for description of ForkLock. syscall.ForkLock.RLock() - s, e := syscall.Socket(f, p, t) - if e != 0 { + s, err := syscall.Socket(f, t, p) + if err != nil { syscall.ForkLock.RUnlock() - return nil, os.Errno(e) + return nil, err } syscall.CloseOnExec(s) syscall.ForkLock.RUnlock() - setKernelSpecificSockopt(s, f) + setDefaultSockopts(s, f, t) if la != nil { - e = syscall.Bind(s, la) - if e != 0 { + err = syscall.Bind(s, la) + if err != nil { closesocket(s) - return nil, os.Errno(e) + return nil, err } } - if fd, err = newFD(s, f, p, net); err != nil { + if fd, err = newFD(s, f, t, net); err != nil { closesocket(s) return nil, err } if ra != nil { if err = fd.connect(ra); err != nil { + closesocket(s) fd.Close() return nil, err } + fd.isConnected = true } sa, _ := syscall.Getsockname(s) @@ -66,93 +61,11 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal return fd, nil } -func setsockoptInt(fd *netFD, level, opt int, value int) os.Error { - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, level, opt, value)) -} - -func setsockoptNsec(fd *netFD, level, opt int, nsec int64) os.Error { - var tv = syscall.NsecToTimeval(nsec) - return os.NewSyscallError("setsockopt", syscall.SetsockoptTimeval(fd.sysfd, level, opt, &tv)) -} - -func setReadBuffer(fd *netFD, bytes int) os.Error { - fd.incref() - defer fd.decref() - return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes) -} - -func setWriteBuffer(fd *netFD, bytes int) os.Error { - fd.incref() - defer fd.decref() - return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes) -} - -func setReadTimeout(fd *netFD, nsec int64) os.Error { - fd.rdeadline_delta = nsec - return nil -} - -func setWriteTimeout(fd *netFD, nsec int64) os.Error { - fd.wdeadline_delta = nsec - return nil -} - -func setTimeout(fd *netFD, nsec int64) os.Error { - if e := setReadTimeout(fd, nsec); e != nil { - return e - } - return setWriteTimeout(fd, nsec) -} - -func setReuseAddr(fd *netFD, reuse bool) os.Error { - fd.incref() - defer fd.decref() - return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse)) -} - -func bindToDevice(fd *netFD, dev string) os.Error { - // TODO(rsc): call setsockopt with null-terminated string pointer - return os.EINVAL -} - -func setDontRoute(fd *netFD, dontroute bool) os.Error { - fd.incref() - defer fd.decref() - return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute)) -} - -func setKeepAlive(fd *netFD, keepalive bool) os.Error { - fd.incref() - defer fd.decref() - return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive)) -} - -func setNoDelay(fd *netFD, noDelay bool) os.Error { - fd.incref() - defer fd.decref() - return setsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay)) -} - -func setLinger(fd *netFD, sec int) os.Error { - var l syscall.Linger - if sec >= 0 { - l.Onoff = 1 - l.Linger = int32(sec) - } else { - l.Onoff = 0 - l.Linger = 0 - } - fd.incref() - defer fd.decref() - e := syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l) - return os.NewSyscallError("setsockopt", e) -} - type UnknownSocketError struct { sa syscall.Sockaddr } -func (e *UnknownSocketError) String() string { +func (e *UnknownSocketError) Error() string { return "unknown socket address type " + reflect.TypeOf(e.sa).String() } @@ -162,7 +75,7 @@ type writerOnly struct { // Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't // applicable. -func genericReadFrom(w io.Writer, r io.Reader) (n int64, err os.Error) { +func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) { // Use wrapper to hide existing r.ReadFrom from io.Copy. return io.Copy(writerOnly{w}, r) } diff --git a/src/pkg/net/sock_bsd.go b/src/pkg/net/sock_bsd.go index c59802fec..630a91ed9 100644 --- a/src/pkg/net/sock_bsd.go +++ b/src/pkg/net/sock_bsd.go @@ -1,33 +1,33 @@ -// Copyright 2011 The Go Authors. All rights reserved. +// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd +// +build darwin freebsd netbsd openbsd // Sockets for BSD variants package net import ( + "runtime" "syscall" ) -func setKernelSpecificSockopt(s, f int) { - // Allow reuse of recently-used addresses. - syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - - // Allow reuse of recently-used ports. - // This option is supported only in descendants of 4.4BSD, - // to make an effective multicast application and an application - // that requires quick draw possible. - syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1) - - // Allow broadcast. - syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) - - if f == syscall.AF_INET6 { - // using ip, tcp, udp, etc. - // allow both protocols even if the OS default is otherwise. - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) +func maxListenerBacklog() int { + var ( + n uint32 + err error + ) + switch runtime.GOOS { + case "darwin", "freebsd": + n, err = syscall.SysctlUint32("kern.ipc.somaxconn") + case "netbsd": + // NOTE: NetBSD has no somaxconn-like kernel state so far + case "openbsd": + n, err = syscall.SysctlUint32("kern.somaxconn") + } + if n == 0 || err != nil { + return syscall.SOMAXCONN } + return int(n) } diff --git a/src/pkg/net/sock_linux.go b/src/pkg/net/sock_linux.go index ec31e803b..2cbc34f24 100644 --- a/src/pkg/net/sock_linux.go +++ b/src/pkg/net/sock_linux.go @@ -1,4 +1,4 @@ -// Copyright 2011 The Go Authors. All rights reserved. +// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -6,20 +6,22 @@ package net -import ( - "syscall" -) +import "syscall" -func setKernelSpecificSockopt(s, f int) { - // Allow reuse of recently-used addresses. - syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - - // Allow broadcast. - syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) - - if f == syscall.AF_INET6 { - // using ip, tcp, udp, etc. - // allow both protocols even if the OS default is otherwise. - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) +func maxListenerBacklog() int { + fd, err := open("/proc/sys/net/core/somaxconn") + if err != nil { + return syscall.SOMAXCONN + } + defer fd.close() + l, ok := fd.readLine() + if !ok { + return syscall.SOMAXCONN + } + f := getFields(l) + n, _, ok := dtoi(f[0], 0) + if n == 0 || !ok { + return syscall.SOMAXCONN } + return n } diff --git a/src/pkg/net/sock_windows.go b/src/pkg/net/sock_windows.go index c6dbd0465..2d803de1f 100644 --- a/src/pkg/net/sock_windows.go +++ b/src/pkg/net/sock_windows.go @@ -1,4 +1,4 @@ -// Copyright 2011 The Go Authors. All rights reserved. +// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -6,20 +6,9 @@ package net -import ( - "syscall" -) +import "syscall" -func setKernelSpecificSockopt(s syscall.Handle, f int) { - // Allow reuse of recently-used addresses and ports. - syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - - // Allow broadcast. - syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) - - if f == syscall.AF_INET6 { - // using ip, tcp, udp, etc. - // allow both protocols even if the OS default is otherwise. - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) - } +func maxListenerBacklog() int { + // TODO: Implement this + return syscall.SOMAXCONN } diff --git a/src/pkg/net/sockopt.go b/src/pkg/net/sockopt.go new file mode 100644 index 000000000..3d0f8dd7a --- /dev/null +++ b/src/pkg/net/sockopt.go @@ -0,0 +1,180 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd windows + +// Socket options + +package net + +import ( + "bytes" + "os" + "syscall" + "time" +) + +// Boolean to int. +func boolint(b bool) int { + if b { + return 1 + } + return 0 +} + +func ipv4AddrToInterface(ip IP) (*Interface, error) { + ift, err := Interfaces() + if err != nil { + return nil, err + } + for _, ifi := range ift { + ifat, err := ifi.Addrs() + if err != nil { + return nil, err + } + for _, ifa := range ifat { + switch v := ifa.(type) { + case *IPAddr: + if ip.Equal(v.IP) { + return &ifi, nil + } + case *IPNet: + if ip.Equal(v.IP) { + return &ifi, nil + } + } + } + } + if ip.Equal(IPv4zero) { + return nil, nil + } + return nil, errNoSuchInterface +} + +func interfaceToIPv4Addr(ifi *Interface) (IP, error) { + if ifi == nil { + return IPv4zero, nil + } + ifat, err := ifi.Addrs() + if err != nil { + return nil, err + } + for _, ifa := range ifat { + switch v := ifa.(type) { + case *IPAddr: + if v.IP.To4() != nil { + return v.IP, nil + } + case *IPNet: + if v.IP.To4() != nil { + return v.IP, nil + } + } + } + return nil, errNoSuchInterface +} + +func setIPv4MreqToInterface(mreq *syscall.IPMreq, ifi *Interface) error { + if ifi == nil { + return nil + } + ifat, err := ifi.Addrs() + if err != nil { + return err + } + for _, ifa := range ifat { + switch v := ifa.(type) { + case *IPAddr: + if a := v.IP.To4(); a != nil { + copy(mreq.Interface[:], a) + goto done + } + case *IPNet: + if a := v.IP.To4(); a != nil { + copy(mreq.Interface[:], a) + goto done + } + } + } +done: + if bytes.Equal(mreq.Multiaddr[:], IPv4zero.To4()) { + return errNoSuchMulticastInterface + } + return nil +} + +func setReadBuffer(fd *netFD, bytes int) error { + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes)) +} + +func setWriteBuffer(fd *netFD, bytes int) error { + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)) +} + +func setReadDeadline(fd *netFD, t time.Time) error { + if t.IsZero() { + fd.rdeadline = 0 + } else { + fd.rdeadline = t.UnixNano() + } + return nil +} + +func setWriteDeadline(fd *netFD, t time.Time) error { + if t.IsZero() { + fd.wdeadline = 0 + } else { + fd.wdeadline = t.UnixNano() + } + return nil +} + +func setDeadline(fd *netFD, t time.Time) error { + if e := setReadDeadline(fd, t); e != nil { + return e + } + return setWriteDeadline(fd, t) +} + +func setReuseAddr(fd *netFD, reuse bool) error { + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse))) +} + +func setDontRoute(fd *netFD, dontroute bool) error { + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute))) +} + +func setKeepAlive(fd *netFD, keepalive bool) error { + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive))) +} + +func setNoDelay(fd *netFD, noDelay bool) error { + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay))) +} + +func setLinger(fd *netFD, sec int) error { + var l syscall.Linger + if sec >= 0 { + l.Onoff = 1 + l.Linger = int32(sec) + } else { + l.Onoff = 0 + l.Linger = 0 + } + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l)) +} diff --git a/src/pkg/net/sockopt_bsd.go b/src/pkg/net/sockopt_bsd.go new file mode 100644 index 000000000..2093e0812 --- /dev/null +++ b/src/pkg/net/sockopt_bsd.go @@ -0,0 +1,45 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd netbsd openbsd + +// Socket options for BSD variants + +package net + +import ( + "syscall" +) + +func setDefaultSockopts(s, f, t int) { + switch f { + case syscall.AF_INET6: + // Allow both IP versions even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } + + if f == syscall.AF_UNIX || + (f == syscall.AF_INET || f == syscall.AF_INET6) && t == syscall.SOCK_STREAM { + // Allow reuse of recently-used addresses. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + + // Allow reuse of recently-used ports. + // This option is supported only in descendants of 4.4BSD, + // to make an effective multicast application and an application + // that requires quick draw possible. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1) + } + + // Allow broadcast. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) +} + +func setDefaultMulticastSockopts(fd *netFD) { + fd.incref() + defer fd.decref() + // Allow multicast UDP and raw IP datagram sockets to listen + // concurrently across multiple listeners. + syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1) +} diff --git a/src/pkg/net/sockopt_linux.go b/src/pkg/net/sockopt_linux.go new file mode 100644 index 000000000..9dbb4e5dd --- /dev/null +++ b/src/pkg/net/sockopt_linux.go @@ -0,0 +1,37 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Socket options for Linux + +package net + +import ( + "syscall" +) + +func setDefaultSockopts(s, f, t int) { + switch f { + case syscall.AF_INET6: + // Allow both IP versions even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } + + if f == syscall.AF_UNIX || + (f == syscall.AF_INET || f == syscall.AF_INET6) && t == syscall.SOCK_STREAM { + // Allow reuse of recently-used addresses. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + } + + // Allow broadcast. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + +} + +func setDefaultMulticastSockopts(fd *netFD) { + fd.incref() + defer fd.decref() + // Allow multicast UDP and raw IP datagram sockets to listen + // concurrently across multiple listeners. + syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) +} diff --git a/src/pkg/net/sockopt_windows.go b/src/pkg/net/sockopt_windows.go new file mode 100644 index 000000000..a7b5606d8 --- /dev/null +++ b/src/pkg/net/sockopt_windows.go @@ -0,0 +1,38 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Socket options for Windows + +package net + +import ( + "syscall" +) + +func setDefaultSockopts(s syscall.Handle, f, t int) { + switch f { + case syscall.AF_INET6: + // Allow both IP versions even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } + + // Windows will reuse recently-used addresses by default. + // SO_REUSEADDR should not be used here, as it allows + // a socket to forcibly bind to a port in use by another socket. + // This could lead to a non-deterministic behavior, where + // connection requests over the port cannot be guaranteed + // to be handled by the correct socket. + + // Allow broadcast. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + +} + +func setDefaultMulticastSockopts(fd *netFD) { + fd.incref() + defer fd.decref() + // Allow multicast UDP and raw IP datagram sockets to listen + // concurrently across multiple listeners. + syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) +} diff --git a/src/pkg/net/sockoptip.go b/src/pkg/net/sockoptip.go new file mode 100644 index 000000000..90b6f751e --- /dev/null +++ b/src/pkg/net/sockoptip.go @@ -0,0 +1,187 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd windows + +// IP-level socket options + +package net + +import ( + "os" + "syscall" +) + +func ipv4TOS(fd *netFD) (int, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TOS) + if err != nil { + return -1, os.NewSyscallError("getsockopt", err) + } + return v, nil +} + +func setIPv4TOS(fd *netFD, v int) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TOS, v) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv4TTL(fd *netFD) (int, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TTL) + if err != nil { + return -1, os.NewSyscallError("getsockopt", err) + } + return v, nil +} + +func setIPv4TTL(fd *netFD, v int) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TTL, v) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error { + mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}} + if err := setIPv4MreqToInterface(mreq, ifi); err != nil { + return err + } + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)) +} + +func leaveIPv4Group(fd *netFD, ifi *Interface, ip IP) error { + mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}} + if err := setIPv4MreqToInterface(mreq, ifi); err != nil { + return err + } + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_DROP_MEMBERSHIP, mreq)) +} + +func ipv6HopLimit(fd *netFD) (int, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS) + if err != nil { + return -1, os.NewSyscallError("getsockopt", err) + } + return v, nil +} + +func setIPv6HopLimit(fd *netFD, v int) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, v) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv6MulticastInterface(fd *netFD) (*Interface, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF) + if err != nil { + return nil, os.NewSyscallError("getsockopt", err) + } + if v == 0 { + return nil, nil + } + ifi, err := InterfaceByIndex(v) + if err != nil { + return nil, err + } + return ifi, nil +} + +func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error { + var v int + if ifi != nil { + v = ifi.Index + } + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv6MulticastHopLimit(fd *netFD) (int, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS) + if err != nil { + return -1, os.NewSyscallError("getsockopt", err) + } + return v, nil +} + +func setIPv6MulticastHopLimit(fd *netFD, v int) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS, v) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv6MulticastLoopback(fd *netFD) (bool, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP) + if err != nil { + return false, os.NewSyscallError("getsockopt", err) + } + return v == 1, nil +} + +func setIPv6MulticastLoopback(fd *netFD, v bool) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v)) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error { + mreq := &syscall.IPv6Mreq{} + copy(mreq.Multiaddr[:], ip) + if ifi != nil { + mreq.Interface = uint32(ifi.Index) + } + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq)) +} + +func leaveIPv6Group(fd *netFD, ifi *Interface, ip IP) error { + mreq := &syscall.IPv6Mreq{} + copy(mreq.Multiaddr[:], ip) + if ifi != nil { + mreq.Interface = uint32(ifi.Index) + } + fd.incref() + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_LEAVE_GROUP, mreq)) +} diff --git a/src/pkg/net/sockoptip_bsd.go b/src/pkg/net/sockoptip_bsd.go new file mode 100644 index 000000000..5f7dff248 --- /dev/null +++ b/src/pkg/net/sockoptip_bsd.go @@ -0,0 +1,54 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd netbsd openbsd + +// IP-level socket options for BSD variants + +package net + +import ( + "os" + "syscall" +) + +func ipv4MulticastTTL(fd *netFD) (int, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL) + if err != nil { + return -1, os.NewSyscallError("getsockopt", err) + } + return int(v), nil +} + +func setIPv4MulticastTTL(fd *netFD, v int) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, byte(v)) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv6TrafficClass(fd *netFD) (int, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS) + if err != nil { + return -1, os.NewSyscallError("getsockopt", err) + } + return v, nil +} + +func setIPv6TrafficClass(fd *netFD, v int) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, v) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} diff --git a/src/pkg/net/sockoptip_darwin.go b/src/pkg/net/sockoptip_darwin.go new file mode 100644 index 000000000..dedfd6f4c --- /dev/null +++ b/src/pkg/net/sockoptip_darwin.go @@ -0,0 +1,78 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IP-level socket options for Darwin + +package net + +import ( + "os" + "syscall" +) + +func ipv4MulticastInterface(fd *netFD) (*Interface, error) { + fd.incref() + defer fd.decref() + a, err := syscall.GetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF) + if err != nil { + return nil, os.NewSyscallError("getsockopt", err) + } + return ipv4AddrToInterface(IPv4(a[0], a[1], a[2], a[3])) +} + +func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { + ip, err := interfaceToIPv4Addr(ifi) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + var x [4]byte + copy(x[:], ip.To4()) + fd.incref() + defer fd.decref() + err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv4MulticastLoopback(fd *netFD) (bool, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP) + if err != nil { + return false, os.NewSyscallError("getsockopt", err) + } + return v == 1, nil +} + +func setIPv4MulticastLoopback(fd *netFD, v bool) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v)) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv4ReceiveInterface(fd *netFD) (bool, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF) + if err != nil { + return false, os.NewSyscallError("getsockopt", err) + } + return v == 1, nil +} + +func setIPv4ReceiveInterface(fd *netFD, v bool) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v)) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} diff --git a/src/pkg/net/sockoptip_freebsd.go b/src/pkg/net/sockoptip_freebsd.go new file mode 100644 index 000000000..55f7b1a60 --- /dev/null +++ b/src/pkg/net/sockoptip_freebsd.go @@ -0,0 +1,80 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IP-level socket options for FreeBSD + +package net + +import ( + "os" + "syscall" +) + +func ipv4MulticastInterface(fd *netFD) (*Interface, error) { + fd.incref() + defer fd.decref() + mreq, err := syscall.GetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF) + if err != nil { + return nil, os.NewSyscallError("getsockopt", err) + } + if int(mreq.Ifindex) == 0 { + return nil, nil + } + return InterfaceByIndex(int(mreq.Ifindex)) +} + +func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { + var v int32 + if ifi != nil { + v = int32(ifi.Index) + } + mreq := &syscall.IPMreqn{Ifindex: v} + fd.incref() + defer fd.decref() + err := syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv4MulticastLoopback(fd *netFD) (bool, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP) + if err != nil { + return false, os.NewSyscallError("getsockopt", err) + } + return v == 1, nil +} + +func setIPv4MulticastLoopback(fd *netFD, v bool) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v)) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv4ReceiveInterface(fd *netFD) (bool, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF) + if err != nil { + return false, os.NewSyscallError("getsockopt", err) + } + return v == 1, nil +} + +func setIPv4ReceiveInterface(fd *netFD, v bool) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v)) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} diff --git a/src/pkg/net/sockoptip_linux.go b/src/pkg/net/sockoptip_linux.go new file mode 100644 index 000000000..360f8dea6 --- /dev/null +++ b/src/pkg/net/sockoptip_linux.go @@ -0,0 +1,120 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IP-level socket options for Linux + +package net + +import ( + "os" + "syscall" +) + +func ipv4MulticastInterface(fd *netFD) (*Interface, error) { + fd.incref() + defer fd.decref() + mreq, err := syscall.GetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF) + if err != nil { + return nil, os.NewSyscallError("getsockopt", err) + } + if int(mreq.Ifindex) == 0 { + return nil, nil + } + return InterfaceByIndex(int(mreq.Ifindex)) +} + +func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { + var v int32 + if ifi != nil { + v = int32(ifi.Index) + } + mreq := &syscall.IPMreqn{Ifindex: v} + fd.incref() + defer fd.decref() + err := syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv4MulticastTTL(fd *netFD) (int, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL) + if err != nil { + return -1, os.NewSyscallError("getsockopt", err) + } + return v, nil +} + +func setIPv4MulticastTTL(fd *netFD, v int) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, v) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv4MulticastLoopback(fd *netFD) (bool, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP) + if err != nil { + return false, os.NewSyscallError("getsockopt", err) + } + return v == 1, nil +} + +func setIPv4MulticastLoopback(fd *netFD, v bool) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v)) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv4ReceiveInterface(fd *netFD) (bool, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_PKTINFO) + if err != nil { + return false, os.NewSyscallError("getsockopt", err) + } + return v == 1, nil +} + +func setIPv4ReceiveInterface(fd *netFD, v bool) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_PKTINFO, boolint(v)) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv6TrafficClass(fd *netFD) (int, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS) + if err != nil { + return -1, os.NewSyscallError("getsockopt", err) + } + return v, nil +} + +func setIPv6TrafficClass(fd *netFD, v int) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, v) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} diff --git a/src/pkg/net/sockoptip_openbsd.go b/src/pkg/net/sockoptip_openbsd.go new file mode 100644 index 000000000..89b8e4592 --- /dev/null +++ b/src/pkg/net/sockoptip_openbsd.go @@ -0,0 +1,78 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IP-level socket options for OpenBSD + +package net + +import ( + "os" + "syscall" +) + +func ipv4MulticastInterface(fd *netFD) (*Interface, error) { + fd.incref() + defer fd.decref() + a, err := syscall.GetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF) + if err != nil { + return nil, os.NewSyscallError("getsockopt", err) + } + return ipv4AddrToInterface(IPv4(a[0], a[1], a[2], a[3])) +} + +func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { + ip, err := interfaceToIPv4Addr(ifi) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + var x [4]byte + copy(x[:], ip.To4()) + fd.incref() + defer fd.decref() + err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv4MulticastLoopback(fd *netFD) (bool, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP) + if err != nil { + return false, os.NewSyscallError("getsockopt", err) + } + return v == 1, nil +} + +func setIPv4MulticastLoopback(fd *netFD, v bool) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v))) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func ipv4ReceiveInterface(fd *netFD) (bool, error) { + fd.incref() + defer fd.decref() + v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF) + if err != nil { + return false, os.NewSyscallError("getsockopt", err) + } + return v == 1, nil +} + +func setIPv4ReceiveInterface(fd *netFD, v bool) error { + fd.incref() + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v)) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} diff --git a/src/pkg/net/sockoptip_windows.go b/src/pkg/net/sockoptip_windows.go new file mode 100644 index 000000000..3320e76bd --- /dev/null +++ b/src/pkg/net/sockoptip_windows.go @@ -0,0 +1,61 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IP-level socket options for Windows + +package net + +import ( + "syscall" +) + +func ipv4MulticastInterface(fd *netFD) (*Interface, error) { + // TODO: Implement this + return nil, syscall.EWINDOWS +} + +func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { + // TODO: Implement this + return syscall.EWINDOWS +} + +func ipv4MulticastTTL(fd *netFD) (int, error) { + // TODO: Implement this + return -1, syscall.EWINDOWS +} + +func setIPv4MulticastTTL(fd *netFD, v int) error { + // TODO: Implement this + return syscall.EWINDOWS +} + +func ipv4MulticastLoopback(fd *netFD) (bool, error) { + // TODO: Implement this + return false, syscall.EWINDOWS +} + +func setIPv4MulticastLoopback(fd *netFD, v bool) error { + // TODO: Implement this + return syscall.EWINDOWS +} + +func ipv4ReceiveInterface(fd *netFD) (bool, error) { + // TODO: Implement this + return false, syscall.EWINDOWS +} + +func setIPv4ReceiveInterface(fd *netFD, v bool) error { + // TODO: Implement this + return syscall.EWINDOWS +} + +func ipv6TrafficClass(fd *netFD) (int, error) { + // TODO: Implement this + return 0, syscall.EWINDOWS +} + +func setIPv6TrafficClass(fd *netFD, v int) error { + // TODO: Implement this + return syscall.EWINDOWS +} diff --git a/src/pkg/net/tcpsock.go b/src/pkg/net/tcpsock.go index f5c0a2781..47fbf2919 100644 --- a/src/pkg/net/tcpsock.go +++ b/src/pkg/net/tcpsock.go @@ -6,10 +6,6 @@ package net -import ( - "os" -) - // TCPAddr represents the address of a TCP end point. type TCPAddr struct { IP IP @@ -31,7 +27,7 @@ func (a *TCPAddr) String() string { // numeric addresses on the network net, which must be "tcp", // "tcp4" or "tcp6". A literal IPv6 host address must be // enclosed in square brackets, as in "[::]:80". -func ResolveTCPAddr(net, addr string) (*TCPAddr, os.Error) { +func ResolveTCPAddr(net, addr string) (*TCPAddr, error) { ip, port, err := hostPortToIP(net, addr) if err != nil { return nil, err diff --git a/src/pkg/net/tcpsock_plan9.go b/src/pkg/net/tcpsock_plan9.go index f4f6e9fee..f2444a4d9 100644 --- a/src/pkg/net/tcpsock_plan9.go +++ b/src/pkg/net/tcpsock_plan9.go @@ -8,6 +8,7 @@ package net import ( "os" + "time" ) // TCPConn is an implementation of the Conn interface @@ -16,17 +17,50 @@ type TCPConn struct { plan9Conn } +// SetDeadline implements the net.Conn SetDeadline method. +func (c *TCPConn) SetDeadline(t time.Time) error { + return os.EPLAN9 +} + +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (c *TCPConn) SetReadDeadline(t time.Time) error { + return os.EPLAN9 +} + +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (c *TCPConn) SetWriteDeadline(t time.Time) error { + return os.EPLAN9 +} + +// CloseRead shuts down the reading side of the TCP connection. +// Most callers should just use Close. +func (c *TCPConn) CloseRead() error { + if !c.ok() { + return os.EINVAL + } + return os.EPLAN9 +} + +// CloseWrite shuts down the writing side of the TCP connection. +// Most callers should just use Close. +func (c *TCPConn) CloseWrite() error { + if !c.ok() { + return os.EINVAL + } + return os.EPLAN9 +} + // DialTCP connects to the remote address raddr on the network net, // which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used // as the local address for the connection. -func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err os.Error) { +func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err error) { switch net { case "tcp", "tcp4", "tcp6": default: return nil, UnknownNetworkError(net) } if raddr == nil { - return nil, &OpError{"dial", "tcp", nil, errMissingAddress} + return nil, &OpError{"dial", net, nil, errMissingAddress} } c1, err := dialPlan9(net, laddr, raddr) if err != nil { @@ -46,14 +80,14 @@ type TCPListener struct { // Net must be "tcp", "tcp4", or "tcp6". // If laddr has a port of 0, it means to listen on some available port. // The caller can use l.Addr() to retrieve the chosen address. -func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) { +func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err error) { switch net { case "tcp", "tcp4", "tcp6": default: return nil, UnknownNetworkError(net) } if laddr == nil { - return nil, &OpError{"listen", "tcp", nil, errMissingAddress} + return nil, &OpError{"listen", net, nil, errMissingAddress} } l1, err := listenPlan9(net, laddr) if err != nil { diff --git a/src/pkg/net/tcpsock_posix.go b/src/pkg/net/tcpsock_posix.go index 35d536c31..65ec49303 100644 --- a/src/pkg/net/tcpsock_posix.go +++ b/src/pkg/net/tcpsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd windows +// +build darwin freebsd linux netbsd openbsd windows // TCP sockets @@ -12,6 +12,7 @@ import ( "io" "os" "syscall" + "time" ) // BUG(rsc): On OpenBSD, listening on the "tcp" network does not listen for @@ -39,7 +40,7 @@ func (a *TCPAddr) family() int { return syscall.AF_INET6 } -func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) { +func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, error) { return ipToSockaddr(family, a.IP, a.Port) } @@ -67,7 +68,7 @@ func (c *TCPConn) ok() bool { return c != nil && c.fd != nil } // Implementation of the Conn interface - see Conn for documentation. // Read implements the net.Conn Read method. -func (c *TCPConn) Read(b []byte) (n int, err os.Error) { +func (c *TCPConn) Read(b []byte) (n int, err error) { if !c.ok() { return 0, os.EINVAL } @@ -75,7 +76,7 @@ func (c *TCPConn) Read(b []byte) (n int, err os.Error) { } // ReadFrom implements the io.ReaderFrom ReadFrom method. -func (c *TCPConn) ReadFrom(r io.Reader) (int64, os.Error) { +func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) { if n, err, handled := sendFile(c.fd, r); handled { return n, err } @@ -83,7 +84,7 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, os.Error) { } // Write implements the net.Conn Write method. -func (c *TCPConn) Write(b []byte) (n int, err os.Error) { +func (c *TCPConn) Write(b []byte) (n int, err error) { if !c.ok() { return 0, os.EINVAL } @@ -91,7 +92,7 @@ func (c *TCPConn) Write(b []byte) (n int, err os.Error) { } // Close closes the TCP connection. -func (c *TCPConn) Close() os.Error { +func (c *TCPConn) Close() error { if !c.ok() { return os.EINVAL } @@ -100,6 +101,24 @@ func (c *TCPConn) Close() os.Error { return err } +// CloseRead shuts down the reading side of the TCP connection. +// Most callers should just use Close. +func (c *TCPConn) CloseRead() error { + if !c.ok() { + return os.EINVAL + } + return c.fd.CloseRead() +} + +// CloseWrite shuts down the writing side of the TCP connection. +// Most callers should just use Close. +func (c *TCPConn) CloseWrite() error { + if !c.ok() { + return os.EINVAL + } + return c.fd.CloseWrite() +} + // LocalAddr returns the local network address, a *TCPAddr. func (c *TCPConn) LocalAddr() Addr { if !c.ok() { @@ -116,33 +135,33 @@ func (c *TCPConn) RemoteAddr() Addr { return c.fd.raddr } -// SetTimeout implements the net.Conn SetTimeout method. -func (c *TCPConn) SetTimeout(nsec int64) os.Error { +// SetDeadline implements the net.Conn SetDeadline method. +func (c *TCPConn) SetDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setTimeout(c.fd, nsec) + return setDeadline(c.fd, t) } -// SetReadTimeout implements the net.Conn SetReadTimeout method. -func (c *TCPConn) SetReadTimeout(nsec int64) os.Error { +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (c *TCPConn) SetReadDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setReadTimeout(c.fd, nsec) + return setReadDeadline(c.fd, t) } -// SetWriteTimeout implements the net.Conn SetWriteTimeout method. -func (c *TCPConn) SetWriteTimeout(nsec int64) os.Error { +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (c *TCPConn) SetWriteDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setWriteTimeout(c.fd, nsec) + return setWriteDeadline(c.fd, t) } // SetReadBuffer sets the size of the operating system's // receive buffer associated with the connection. -func (c *TCPConn) SetReadBuffer(bytes int) os.Error { +func (c *TCPConn) SetReadBuffer(bytes int) error { if !c.ok() { return os.EINVAL } @@ -151,7 +170,7 @@ func (c *TCPConn) SetReadBuffer(bytes int) os.Error { // SetWriteBuffer sets the size of the operating system's // transmit buffer associated with the connection. -func (c *TCPConn) SetWriteBuffer(bytes int) os.Error { +func (c *TCPConn) SetWriteBuffer(bytes int) error { if !c.ok() { return os.EINVAL } @@ -169,7 +188,7 @@ func (c *TCPConn) SetWriteBuffer(bytes int) os.Error { // // If sec > 0, Close blocks for at most sec seconds waiting for // data to be sent and acknowledged. -func (c *TCPConn) SetLinger(sec int) os.Error { +func (c *TCPConn) SetLinger(sec int) error { if !c.ok() { return os.EINVAL } @@ -178,7 +197,7 @@ func (c *TCPConn) SetLinger(sec int) os.Error { // SetKeepAlive sets whether the operating system should send // keepalive messages on the connection. -func (c *TCPConn) SetKeepAlive(keepalive bool) os.Error { +func (c *TCPConn) SetKeepAlive(keepalive bool) error { if !c.ok() { return os.EINVAL } @@ -189,7 +208,7 @@ func (c *TCPConn) SetKeepAlive(keepalive bool) os.Error { // packet transmission in hopes of sending fewer packets // (Nagle's algorithm). The default is true (no delay), meaning // that data is sent as soon as possible after a Write. -func (c *TCPConn) SetNoDelay(noDelay bool) os.Error { +func (c *TCPConn) SetNoDelay(noDelay bool) error { if !c.ok() { return os.EINVAL } @@ -199,14 +218,14 @@ func (c *TCPConn) SetNoDelay(noDelay bool) os.Error { // File returns a copy of the underlying os.File, set to blocking mode. // It is the caller's responsibility to close f when finished. // Closing c does not affect f, and closing f does not affect c. -func (c *TCPConn) File() (f *os.File, err os.Error) { return c.fd.dup() } +func (c *TCPConn) File() (f *os.File, err error) { return c.fd.dup() } // DialTCP connects to the remote address raddr on the network net, // which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used // as the local address for the connection. -func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err os.Error) { +func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err error) { if raddr == nil { - return nil, &OpError{"dial", "tcp", nil, errMissingAddress} + return nil, &OpError{"dial", net, nil, errMissingAddress} } fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) if e != nil { @@ -226,15 +245,15 @@ type TCPListener struct { // Net must be "tcp", "tcp4", or "tcp6". // If laddr has a port of 0, it means to listen on some available port. // The caller can use l.Addr() to retrieve the chosen address. -func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) { +func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err error) { fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_STREAM, 0, "listen", sockaddrToTCP) if err != nil { return nil, err } - errno := syscall.Listen(fd.sysfd, listenBacklog()) - if errno != 0 { + err = syscall.Listen(fd.sysfd, listenerBacklog) + if err != nil { closesocket(fd.sysfd) - return nil, &OpError{"listen", "tcp", laddr, os.Errno(errno)} + return nil, &OpError{"listen", net, laddr, err} } l = new(TCPListener) l.fd = fd @@ -243,7 +262,7 @@ func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) { // AcceptTCP accepts the next incoming call and returns the new connection // and the remote address. -func (l *TCPListener) AcceptTCP() (c *TCPConn, err os.Error) { +func (l *TCPListener) AcceptTCP() (c *TCPConn, err error) { if l == nil || l.fd == nil || l.fd.sysfd < 0 { return nil, os.EINVAL } @@ -256,7 +275,7 @@ func (l *TCPListener) AcceptTCP() (c *TCPConn, err os.Error) { // Accept implements the Accept method in the Listener interface; // it waits for the next call and returns a generic Conn. -func (l *TCPListener) Accept() (c Conn, err os.Error) { +func (l *TCPListener) Accept() (c Conn, err error) { c1, err := l.AcceptTCP() if err != nil { return nil, err @@ -266,7 +285,7 @@ func (l *TCPListener) Accept() (c Conn, err os.Error) { // Close stops listening on the TCP address. // Already Accepted connections are not closed. -func (l *TCPListener) Close() os.Error { +func (l *TCPListener) Close() error { if l == nil || l.fd == nil { return os.EINVAL } @@ -276,15 +295,16 @@ func (l *TCPListener) Close() os.Error { // Addr returns the listener's network address, a *TCPAddr. func (l *TCPListener) Addr() Addr { return l.fd.laddr } -// SetTimeout sets the deadline associated with the listener -func (l *TCPListener) SetTimeout(nsec int64) os.Error { +// SetDeadline sets the deadline associated with the listener. +// A zero time value disables the deadline. +func (l *TCPListener) SetDeadline(t time.Time) error { if l == nil || l.fd == nil { return os.EINVAL } - return setTimeout(l.fd, nsec) + return setDeadline(l.fd, t) } // File returns a copy of the underlying os.File, set to blocking mode. // It is the caller's responsibility to close f when finished. // Closing c does not affect f, and closing f does not affect c. -func (l *TCPListener) File() (f *os.File, err os.Error) { return l.fd.dup() } +func (l *TCPListener) File() (f *os.File, err error) { return l.fd.dup() } diff --git a/src/pkg/net/textproto/header.go b/src/pkg/net/textproto/header.go index 288deb2ce..7fb32f804 100644 --- a/src/pkg/net/textproto/header.go +++ b/src/pkg/net/textproto/header.go @@ -39,5 +39,5 @@ func (h MIMEHeader) Get(key string) string { // Del deletes the values associated with key. func (h MIMEHeader) Del(key string) { - h[CanonicalMIMEHeaderKey(key)] = nil, false + delete(h, CanonicalMIMEHeaderKey(key)) } diff --git a/src/pkg/net/textproto/pipeline.go b/src/pkg/net/textproto/pipeline.go index 8c25884b3..ca50eddac 100644 --- a/src/pkg/net/textproto/pipeline.go +++ b/src/pkg/net/textproto/pipeline.go @@ -108,7 +108,7 @@ func (s *sequencer) End(id uint) { } c, ok := s.wait[id] if ok { - s.wait[id] = nil, false + delete(s.wait, id) } s.mu.Unlock() if ok { diff --git a/src/pkg/net/textproto/reader.go b/src/pkg/net/textproto/reader.go index a404f4758..862cd536c 100644 --- a/src/pkg/net/textproto/reader.go +++ b/src/pkg/net/textproto/reader.go @@ -9,7 +9,6 @@ import ( "bytes" "io" "io/ioutil" - "os" "strconv" "strings" ) @@ -23,6 +22,7 @@ import ( type Reader struct { R *bufio.Reader dot *dotReader + buf []byte // a re-usable buffer for readContinuedLineSlice } // NewReader returns a new Reader reading from r. @@ -32,13 +32,13 @@ func NewReader(r *bufio.Reader) *Reader { // ReadLine reads a single line from r, // eliding the final \n or \r\n from the returned string. -func (r *Reader) ReadLine() (string, os.Error) { +func (r *Reader) ReadLine() (string, error) { line, err := r.readLineSlice() return string(line), err } // ReadLineBytes is like ReadLine but returns a []byte instead of a string. -func (r *Reader) ReadLineBytes() ([]byte, os.Error) { +func (r *Reader) ReadLineBytes() ([]byte, error) { line, err := r.readLineSlice() if line != nil { buf := make([]byte, len(line)) @@ -48,10 +48,24 @@ func (r *Reader) ReadLineBytes() ([]byte, os.Error) { return line, err } -func (r *Reader) readLineSlice() ([]byte, os.Error) { +func (r *Reader) readLineSlice() ([]byte, error) { r.closeDot() - line, _, err := r.R.ReadLine() - return line, err + var line []byte + for { + l, more, err := r.R.ReadLine() + if err != nil { + return nil, err + } + // Avoid the copy if the first call produced a full line. + if line == nil && !more { + return l, nil + } + line = append(line, l...) + if !more { + break + } + } + return line, nil } // ReadContinuedLine reads a possibly continued line from r, @@ -73,7 +87,7 @@ func (r *Reader) readLineSlice() ([]byte, os.Error) { // // A line consisting of only white space is never continued. // -func (r *Reader) ReadContinuedLine() (string, os.Error) { +func (r *Reader) ReadContinuedLine() (string, error) { line, err := r.readContinuedLineSlice() return string(line), err } @@ -94,7 +108,7 @@ func trim(s []byte) []byte { // ReadContinuedLineBytes is like ReadContinuedLine but // returns a []byte instead of a string. -func (r *Reader) ReadContinuedLineBytes() ([]byte, os.Error) { +func (r *Reader) ReadContinuedLineBytes() ([]byte, error) { line, err := r.readContinuedLineSlice() if line != nil { buf := make([]byte, len(line)) @@ -104,81 +118,51 @@ func (r *Reader) ReadContinuedLineBytes() ([]byte, os.Error) { return line, err } -func (r *Reader) readContinuedLineSlice() ([]byte, os.Error) { +func (r *Reader) readContinuedLineSlice() ([]byte, error) { // Read the first line. line, err := r.readLineSlice() if err != nil { - return line, err + return nil, err } if len(line) == 0 { // blank line - no continuation return line, nil } - line = trim(line) - copied := false - if r.R.Buffered() < 1 { - // ReadByte will flush the buffer; make a copy of the slice. - copied = true - line = append([]byte(nil), line...) - } - - // Look for a continuation line. - c, err := r.R.ReadByte() - if err != nil { - // Delay err until we read the byte next time. - return line, nil - } - if c != ' ' && c != '\t' { - // Not a continuation. - r.R.UnreadByte() - return line, nil - } - - if !copied { - // The next readLineSlice will invalidate the previous one. - line = append(make([]byte, 0, len(line)*2), line...) - } + // ReadByte or the next readLineSlice will flush the read buffer; + // copy the slice into buf. + r.buf = append(r.buf[:0], trim(line)...) // Read continuation lines. - for { - // Consume leading spaces; one already gone. - for { - c, err = r.R.ReadByte() - if err != nil { - break - } - if c != ' ' && c != '\t' { - r.R.UnreadByte() - break - } - } - var cont []byte - cont, err = r.readLineSlice() - cont = trim(cont) - line = append(line, ' ') - line = append(line, cont...) + for r.skipSpace() > 0 { + line, err := r.readLineSlice() if err != nil { break } + r.buf = append(r.buf, ' ') + r.buf = append(r.buf, line...) + } + return r.buf, nil +} - // Check for leading space on next line. - if c, err = r.R.ReadByte(); err != nil { +// skipSpace skips R over all spaces and returns the number of bytes skipped. +func (r *Reader) skipSpace() int { + n := 0 + for { + c, err := r.R.ReadByte() + if err != nil { + // Bufio will keep err until next read. break } if c != ' ' && c != '\t' { r.R.UnreadByte() break } + n++ } - - // Delay error until next call. - if len(line) > 0 { - err = nil - } - return line, err + return n } -func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err os.Error) { +func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) { line, err := r.ReadLine() if err != nil { return @@ -186,7 +170,7 @@ func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message return parseCodeLine(line, expectCode) } -func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err os.Error) { +func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) { if len(line) < 4 || line[3] != ' ' && line[3] != '-' { err = ProtocolError("short response: " + line) return @@ -221,7 +205,7 @@ func parseCodeLine(line string, expectCode int) (code int, continued bool, messa // // An expectCode <= 0 disables the check of the status code. // -func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err os.Error) { +func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) { code, continued, message, err := r.readCodeLine(expectCode) if err == nil && continued { err = ProtocolError("unexpected multi-line response: " + message) @@ -251,12 +235,12 @@ func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err os. // // An expectCode <= 0 disables the check of the status code. // -func (r *Reader) ReadResponse(expectCode int) (code int, message string, err os.Error) { +func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) { code, continued, message, err := r.readCodeLine(expectCode) for err == nil && continued { line, err := r.ReadLine() if err != nil { - return + return 0, "", err } var code2 int @@ -286,7 +270,7 @@ func (r *Reader) ReadResponse(expectCode int) (code int, message string, err os. // // The decoded form returned by the Reader's Read method // rewrites the "\r\n" line endings into the simpler "\n", -// removes leading dot escapes if present, and stops with error os.EOF +// removes leading dot escapes if present, and stops with error io.EOF // after consuming (and discarding) the end-of-sequence line. func (r *Reader) DotReader() io.Reader { r.closeDot() @@ -300,7 +284,7 @@ type dotReader struct { } // Read satisfies reads by decoding dot-encoded data read from d.r. -func (d *dotReader) Read(b []byte) (n int, err os.Error) { +func (d *dotReader) Read(b []byte) (n int, err error) { // Run data through a simple state machine to // elide leading dots, rewrite trailing \r\n into \n, // and detect ending .\r\n line. @@ -317,7 +301,7 @@ func (d *dotReader) Read(b []byte) (n int, err os.Error) { var c byte c, err = br.ReadByte() if err != nil { - if err == os.EOF { + if err == io.EOF { err = io.ErrUnexpectedEOF } break @@ -379,7 +363,7 @@ func (d *dotReader) Read(b []byte) (n int, err os.Error) { n++ } if err == nil && d.state == stateEOF { - err = os.EOF + err = io.EOF } if err != nil && d.r.dot == d { d.r.dot = nil @@ -404,7 +388,7 @@ func (r *Reader) closeDot() { // ReadDotBytes reads a dot-encoding and returns the decoded data. // // See the documentation for the DotReader method for details about dot-encoding. -func (r *Reader) ReadDotBytes() ([]byte, os.Error) { +func (r *Reader) ReadDotBytes() ([]byte, error) { return ioutil.ReadAll(r.DotReader()) } @@ -412,17 +396,17 @@ func (r *Reader) ReadDotBytes() ([]byte, os.Error) { // containing the decoded lines, with the final \r\n or \n elided from each. // // See the documentation for the DotReader method for details about dot-encoding. -func (r *Reader) ReadDotLines() ([]string, os.Error) { +func (r *Reader) ReadDotLines() ([]string, error) { // We could use ReadDotBytes and then Split it, // but reading a line at a time avoids needing a // large contiguous block of memory and is simpler. var v []string - var err os.Error + var err error for { var line string line, err = r.ReadLine() if err != nil { - if err == os.EOF { + if err == io.EOF { err = io.ErrUnexpectedEOF } break @@ -460,7 +444,7 @@ func (r *Reader) ReadDotLines() ([]string, os.Error) { // "Long-Key": {"Even Longer Value"}, // } // -func (r *Reader) ReadMIMEHeader() (MIMEHeader, os.Error) { +func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { m := make(MIMEHeader) for { kv, err := r.readContinuedLineSlice() diff --git a/src/pkg/net/textproto/reader_test.go b/src/pkg/net/textproto/reader_test.go index 23ebc3f61..4d0369148 100644 --- a/src/pkg/net/textproto/reader_test.go +++ b/src/pkg/net/textproto/reader_test.go @@ -7,7 +7,6 @@ package textproto import ( "bufio" "io" - "os" "reflect" "strings" "testing" @@ -49,7 +48,7 @@ func TestReadLine(t *testing.T) { t.Fatalf("Line 2: %s, %v", s, err) } s, err = r.ReadLine() - if s != "" || err != os.EOF { + if s != "" || err != io.EOF { t.Fatalf("EOF: %s, %v", s, err) } } @@ -69,7 +68,7 @@ func TestReadContinuedLine(t *testing.T) { t.Fatalf("Line 3: %s, %v", s, err) } s, err = r.ReadContinuedLine() - if s != "" || err != os.EOF { + if s != "" || err != io.EOF { t.Fatalf("EOF: %s, %v", s, err) } } @@ -92,7 +91,7 @@ func TestReadCodeLine(t *testing.T) { t.Fatalf("Line 3: wrong error %v\n", err) } code, msg, err = r.ReadCodeLine(1) - if code != 0 || msg != "" || err != os.EOF { + if code != 0 || msg != "" || err != io.EOF { t.Fatalf("EOF: %d, %s, %v", code, msg, err) } } @@ -139,6 +138,32 @@ func TestReadMIMEHeader(t *testing.T) { } } +func TestReadMIMEHeaderSingle(t *testing.T) { + r := reader("Foo: bar\n\n") + m, err := r.ReadMIMEHeader() + want := MIMEHeader{"Foo": {"bar"}} + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want) + } +} + +func TestLargeReadMIMEHeader(t *testing.T) { + data := make([]byte, 16*1024) + for i := 0; i < len(data); i++ { + data[i] = 'x' + } + sdata := string(data) + r := reader("Cookie: " + sdata + "\r\n\n") + m, err := r.ReadMIMEHeader() + if err != nil { + t.Fatalf("ReadMIMEHeader: %v", err) + } + cookie := m.Get("Cookie") + if cookie != sdata { + t.Fatalf("ReadMIMEHeader: %v bytes, want %v bytes", len(cookie), len(sdata)) + } +} + type readResponseTest struct { in string inCode int @@ -187,7 +212,7 @@ func TestRFC959Lines(t *testing.T) { t.Errorf("#%d: code=%d, want %d", i, code, tt.wantCode) } if msg != tt.wantMsg { - t.Errorf("%#d: msg=%q, want %q", i, msg, tt.wantMsg) + t.Errorf("#%d: msg=%q, want %q", i, msg, tt.wantMsg) } } } diff --git a/src/pkg/net/textproto/textproto.go b/src/pkg/net/textproto/textproto.go index 9f19b5495..317ec72b0 100644 --- a/src/pkg/net/textproto/textproto.go +++ b/src/pkg/net/textproto/textproto.go @@ -27,7 +27,6 @@ import ( "fmt" "io" "net" - "os" ) // An Error represents a numeric error response from a server. @@ -36,7 +35,7 @@ type Error struct { Msg string } -func (e *Error) String() string { +func (e *Error) Error() string { return fmt.Sprintf("%03d %s", e.Code, e.Msg) } @@ -44,7 +43,7 @@ func (e *Error) String() string { // as an invalid response or a hung-up connection. type ProtocolError string -func (p ProtocolError) String() string { +func (p ProtocolError) Error() string { return string(p) } @@ -70,13 +69,13 @@ func NewConn(conn io.ReadWriteCloser) *Conn { } // Close closes the connection. -func (c *Conn) Close() os.Error { +func (c *Conn) Close() error { return c.conn.Close() } // Dial connects to the given address on the given network using net.Dial // and then returns a new Conn for the connection. -func Dial(network, addr string) (*Conn, os.Error) { +func Dial(network, addr string) (*Conn, error) { c, err := net.Dial(network, addr) if err != nil { return nil, err @@ -109,7 +108,7 @@ func Dial(network, addr string) (*Conn, os.Error) { // } // return c.ReadCodeLine(250) // -func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err os.Error) { +func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err error) { id = c.Next() c.StartRequest(id) err = c.PrintfLine(format, args...) diff --git a/src/pkg/net/textproto/writer.go b/src/pkg/net/textproto/writer.go index 4e705f6c3..03e2fd658 100644 --- a/src/pkg/net/textproto/writer.go +++ b/src/pkg/net/textproto/writer.go @@ -8,7 +8,6 @@ import ( "bufio" "fmt" "io" - "os" ) // A Writer implements convenience methods for writing @@ -27,7 +26,7 @@ var crnl = []byte{'\r', '\n'} var dotcrnl = []byte{'.', '\r', '\n'} // PrintfLine writes the formatted output followed by \r\n. -func (w *Writer) PrintfLine(format string, args ...interface{}) os.Error { +func (w *Writer) PrintfLine(format string, args ...interface{}) error { w.closeDot() fmt.Fprintf(w.W, format, args...) w.W.Write(crnl) @@ -64,7 +63,7 @@ const ( wstateData // writing data in middle of line ) -func (d *dotWriter) Write(b []byte) (n int, err os.Error) { +func (d *dotWriter) Write(b []byte) (n int, err error) { bw := d.w.W for n < len(b) { c := b[n] @@ -100,7 +99,7 @@ func (d *dotWriter) Write(b []byte) (n int, err os.Error) { return } -func (d *dotWriter) Close() os.Error { +func (d *dotWriter) Close() error { if d.w.dot == d { d.w.dot = nil } diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go index 0dbab5846..bae37c86b 100644 --- a/src/pkg/net/timeout_test.go +++ b/src/pkg/net/timeout_test.go @@ -5,7 +5,8 @@ package net import ( - "os" + "fmt" + "runtime" "testing" "time" ) @@ -17,35 +18,56 @@ func testTimeout(t *testing.T, network, addr string, readFrom bool) { return } defer fd.Close() - t0 := time.Nanoseconds() - fd.SetReadTimeout(1e8) // 100ms - var b [100]byte - var n int - var err1 os.Error - if readFrom { - n, _, err1 = fd.(PacketConn).ReadFrom(b[0:]) - } else { - n, err1 = fd.Read(b[0:]) - } - t1 := time.Nanoseconds() what := "Read" if readFrom { what = "ReadFrom" } - if n != 0 || err1 == nil || !err1.(Error).Timeout() { - t.Errorf("fd.%s on %s %s did not return 0, timeout: %v, %v", what, network, addr, n, err1) - } - if t1-t0 < 0.5e8 || t1-t0 > 1.5e8 { - t.Errorf("fd.%s on %s %s took %f seconds, expected 0.1", what, network, addr, float64(t1-t0)/1e9) + + errc := make(chan error, 1) + go func() { + t0 := time.Now() + fd.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + var b [100]byte + var n int + var err1 error + if readFrom { + n, _, err1 = fd.(PacketConn).ReadFrom(b[0:]) + } else { + n, err1 = fd.Read(b[0:]) + } + t1 := time.Now() + if n != 0 || err1 == nil || !err1.(Error).Timeout() { + errc <- fmt.Errorf("fd.%s on %s %s did not return 0, timeout: %v, %v", what, network, addr, n, err1) + return + } + if dt := t1.Sub(t0); dt < 50*time.Millisecond || dt > 250*time.Millisecond { + errc <- fmt.Errorf("fd.%s on %s %s took %s, expected 0.1s", what, network, addr, dt) + return + } + errc <- nil + }() + select { + case err := <-errc: + if err != nil { + t.Error(err) + } + case <-time.After(1 * time.Second): + t.Errorf("%s on %s %s took over 1 second, expected 0.1s", what, network, addr) } } func TestTimeoutUDP(t *testing.T) { + if runtime.GOOS == "plan9" { + return + } testTimeout(t, "udp", "127.0.0.1:53", false) testTimeout(t, "udp", "127.0.0.1:53", true) } func TestTimeoutTCP(t *testing.T) { + if runtime.GOOS == "plan9" { + return + } // set up a listener that won't talk back listening := make(chan string) done := make(chan int) @@ -55,3 +77,30 @@ func TestTimeoutTCP(t *testing.T) { testTimeout(t, "tcp", addr, false) <-done } + +func TestDeadlineReset(t *testing.T) { + if runtime.GOOS == "plan9" { + return + } + ln, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + tl := ln.(*TCPListener) + tl.SetDeadline(time.Now().Add(1 * time.Minute)) + tl.SetDeadline(time.Time{}) // reset it + errc := make(chan error, 1) + go func() { + _, err := ln.Accept() + errc <- err + }() + select { + case <-time.After(50 * time.Millisecond): + // Pass. + case err := <-errc: + // Accept should never return; we never + // connected to it. + t.Errorf("unexpected return from Accept; err=%v", err) + } +} diff --git a/src/pkg/net/udp_test.go b/src/pkg/net/udp_test.go new file mode 100644 index 000000000..6ba762b1f --- /dev/null +++ b/src/pkg/net/udp_test.go @@ -0,0 +1,87 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "runtime" + "testing" +) + +func TestWriteToUDP(t *testing.T) { + if runtime.GOOS == "plan9" { + return + } + + l, err := ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer l.Close() + + testWriteToConn(t, l.LocalAddr().String()) + testWriteToPacketConn(t, l.LocalAddr().String()) +} + +func testWriteToConn(t *testing.T, raddr string) { + c, err := Dial("udp", raddr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c.Close() + + ra, err := ResolveUDPAddr("udp", raddr) + if err != nil { + t.Fatalf("ResolveUDPAddr failed: %v", err) + } + + _, err = c.(*UDPConn).WriteToUDP([]byte("Connection-oriented mode socket"), ra) + if err == nil { + t.Fatal("WriteToUDP should be failed") + } + if err != nil && err.(*OpError).Err != ErrWriteToConnected { + t.Fatalf("WriteToUDP should be failed as ErrWriteToConnected: %v", err) + } + + _, err = c.(*UDPConn).WriteTo([]byte("Connection-oriented mode socket"), ra) + if err == nil { + t.Fatal("WriteTo should be failed") + } + if err != nil && err.(*OpError).Err != ErrWriteToConnected { + t.Fatalf("WriteTo should be failed as ErrWriteToConnected: %v", err) + } + + _, err = c.Write([]byte("Connection-oriented mode socket")) + if err != nil { + t.Fatalf("Write failed: %v", err) + } +} + +func testWriteToPacketConn(t *testing.T, raddr string) { + c, err := ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenPacket failed: %v", err) + } + defer c.Close() + + ra, err := ResolveUDPAddr("udp", raddr) + if err != nil { + t.Fatalf("ResolveUDPAddr failed: %v", err) + } + + _, err = c.(*UDPConn).WriteToUDP([]byte("Connection-less mode socket"), ra) + if err != nil { + t.Fatalf("WriteToUDP failed: %v", err) + } + + _, err = c.WriteTo([]byte("Connection-less mode socket"), ra) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + + _, err = c.(*UDPConn).Write([]byte("Connection-less mode socket")) + if err == nil { + t.Fatal("Write should be failed") + } +} diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go index 3dfa71675..b3520cf09 100644 --- a/src/pkg/net/udpsock.go +++ b/src/pkg/net/udpsock.go @@ -6,10 +6,6 @@ package net -import ( - "os" -) - // UDPAddr represents the address of a UDP end point. type UDPAddr struct { IP IP @@ -31,7 +27,7 @@ func (a *UDPAddr) String() string { // numeric addresses on the network net, which must be "udp", // "udp4" or "udp6". A literal IPv6 host address must be // enclosed in square brackets, as in "[::]:80". -func ResolveUDPAddr(net, addr string) (*UDPAddr, os.Error) { +func ResolveUDPAddr(net, addr string) (*UDPAddr, error) { ip, port, err := hostPortToIP(net, addr) if err != nil { return nil, err diff --git a/src/pkg/net/udpsock_plan9.go b/src/pkg/net/udpsock_plan9.go index d5c6ccb90..573438f85 100644 --- a/src/pkg/net/udpsock_plan9.go +++ b/src/pkg/net/udpsock_plan9.go @@ -7,7 +7,9 @@ package net import ( + "errors" "os" + "time" ) // UDPConn is the implementation of the Conn and PacketConn @@ -16,6 +18,21 @@ type UDPConn struct { plan9Conn } +// SetDeadline implements the net.Conn SetDeadline method. +func (c *UDPConn) SetDeadline(t time.Time) error { + return os.EPLAN9 +} + +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (c *UDPConn) SetReadDeadline(t time.Time) error { + return os.EPLAN9 +} + +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (c *UDPConn) SetWriteDeadline(t time.Time) error { + return os.EPLAN9 +} + // UDP-specific methods. // ReadFromUDP reads a UDP packet from c, copying the payload into b. @@ -23,8 +40,8 @@ type UDPConn struct { // that was on the packet. // // ReadFromUDP can be made to time out and return an error with Timeout() == true -// after a fixed time limit; see SetTimeout and SetReadTimeout. -func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) { +// after a fixed time limit; see SetDeadline and SetReadDeadline. +func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) { if !c.ok() { return 0, nil, os.EINVAL } @@ -40,7 +57,7 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) { return } if m < udpHeaderSize { - return 0, nil, os.NewError("short read reading UDP header") + return 0, nil, errors.New("short read reading UDP header") } buf = buf[:m] @@ -50,7 +67,7 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) { } // ReadFrom implements the net.PacketConn ReadFrom method. -func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { +func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err error) { if !c.ok() { return 0, nil, os.EINVAL } @@ -61,9 +78,9 @@ func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { // // WriteToUDP can be made to time out and return // an error with Timeout() == true after a fixed time limit; -// see SetTimeout and SetWriteTimeout. +// see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. -func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err os.Error) { +func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err error) { if !c.ok() { return 0, os.EINVAL } @@ -87,13 +104,13 @@ func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err os.Error) { } // WriteTo implements the net.PacketConn WriteTo method. -func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { +func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err error) { if !c.ok() { return 0, os.EINVAL } a, ok := addr.(*UDPAddr) if !ok { - return 0, &OpError{"writeto", "udp", addr, os.EINVAL} + return 0, &OpError{"write", c.dir, addr, os.EINVAL} } return c.WriteToUDP(b, a) } @@ -101,14 +118,14 @@ func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { // DialUDP connects to the remote address raddr on the network net, // which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used // as the local address for the connection. -func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err os.Error) { +func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err error) { switch net { case "udp", "udp4", "udp6": default: return nil, UnknownNetworkError(net) } if raddr == nil { - return nil, &OpError{"dial", "udp", nil, errMissingAddress} + return nil, &OpError{"dial", net, nil, errMissingAddress} } c1, err := dialPlan9(net, laddr, raddr) if err != nil { @@ -149,14 +166,14 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) { // local address laddr. The returned connection c's ReadFrom // and WriteTo methods can be used to receive and send UDP // packets with per-packet addressing. -func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err os.Error) { +func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err error) { switch net { case "udp", "udp4", "udp6": default: return nil, UnknownNetworkError(net) } if laddr == nil { - return nil, &OpError{"listen", "udp", nil, errMissingAddress} + return nil, &OpError{"listen", net, nil, errMissingAddress} } l, err := listenPlan9(net, laddr) if err != nil { @@ -172,7 +189,7 @@ func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err os.Error) { // JoinGroup joins the IP multicast group named by addr on ifi, // which specifies the interface to join. JoinGroup uses the // default multicast interface if ifi is nil. -func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) os.Error { +func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) error { if !c.ok() { return os.EINVAL } @@ -180,7 +197,7 @@ func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) os.Error { } // LeaveGroup exits the IP multicast group named by addr on ifi. -func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) os.Error { +func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) error { if !c.ok() { return os.EINVAL } diff --git a/src/pkg/net/udpsock_posix.go b/src/pkg/net/udpsock_posix.go index 06298ee40..fa3d29adf 100644 --- a/src/pkg/net/udpsock_posix.go +++ b/src/pkg/net/udpsock_posix.go @@ -2,18 +2,21 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd windows +// +build darwin freebsd linux netbsd openbsd windows // UDP sockets package net import ( - "bytes" + "errors" "os" "syscall" + "time" ) +var ErrWriteToConnected = errors.New("use of WriteTo with pre-connected UDP") + func sockaddrToUDP(sa syscall.Sockaddr) Addr { switch sa := sa.(type) { case *syscall.SockaddrInet4: @@ -34,7 +37,7 @@ func (a *UDPAddr) family() int { return syscall.AF_INET6 } -func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) { +func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, error) { return ipToSockaddr(family, a.IP, a.Port) } @@ -58,7 +61,7 @@ func (c *UDPConn) ok() bool { return c != nil && c.fd != nil } // Implementation of the Conn interface - see Conn for documentation. // Read implements the net.Conn Read method. -func (c *UDPConn) Read(b []byte) (n int, err os.Error) { +func (c *UDPConn) Read(b []byte) (n int, err error) { if !c.ok() { return 0, os.EINVAL } @@ -66,7 +69,7 @@ func (c *UDPConn) Read(b []byte) (n int, err os.Error) { } // Write implements the net.Conn Write method. -func (c *UDPConn) Write(b []byte) (n int, err os.Error) { +func (c *UDPConn) Write(b []byte) (n int, err error) { if !c.ok() { return 0, os.EINVAL } @@ -74,7 +77,7 @@ func (c *UDPConn) Write(b []byte) (n int, err os.Error) { } // Close closes the UDP connection. -func (c *UDPConn) Close() os.Error { +func (c *UDPConn) Close() error { if !c.ok() { return os.EINVAL } @@ -99,33 +102,33 @@ func (c *UDPConn) RemoteAddr() Addr { return c.fd.raddr } -// SetTimeout implements the net.Conn SetTimeout method. -func (c *UDPConn) SetTimeout(nsec int64) os.Error { +// SetDeadline implements the net.Conn SetDeadline method. +func (c *UDPConn) SetDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setTimeout(c.fd, nsec) + return setDeadline(c.fd, t) } -// SetReadTimeout implements the net.Conn SetReadTimeout method. -func (c *UDPConn) SetReadTimeout(nsec int64) os.Error { +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (c *UDPConn) SetReadDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setReadTimeout(c.fd, nsec) + return setReadDeadline(c.fd, t) } -// SetWriteTimeout implements the net.Conn SetWriteTimeout method. -func (c *UDPConn) SetWriteTimeout(nsec int64) os.Error { +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (c *UDPConn) SetWriteDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setWriteTimeout(c.fd, nsec) + return setWriteDeadline(c.fd, t) } // SetReadBuffer sets the size of the operating system's // receive buffer associated with the connection. -func (c *UDPConn) SetReadBuffer(bytes int) os.Error { +func (c *UDPConn) SetReadBuffer(bytes int) error { if !c.ok() { return os.EINVAL } @@ -134,7 +137,7 @@ func (c *UDPConn) SetReadBuffer(bytes int) os.Error { // SetWriteBuffer sets the size of the operating system's // transmit buffer associated with the connection. -func (c *UDPConn) SetWriteBuffer(bytes int) os.Error { +func (c *UDPConn) SetWriteBuffer(bytes int) error { if !c.ok() { return os.EINVAL } @@ -148,8 +151,8 @@ func (c *UDPConn) SetWriteBuffer(bytes int) os.Error { // that was on the packet. // // ReadFromUDP can be made to time out and return an error with Timeout() == true -// after a fixed time limit; see SetTimeout and SetReadTimeout. -func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) { +// after a fixed time limit; see SetDeadline and SetReadDeadline. +func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) { if !c.ok() { return 0, nil, os.EINVAL } @@ -164,7 +167,7 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) { } // ReadFrom implements the net.PacketConn ReadFrom method. -func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { +func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err error) { if !c.ok() { return 0, nil, os.EINVAL } @@ -176,27 +179,30 @@ func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { // // WriteToUDP can be made to time out and return // an error with Timeout() == true after a fixed time limit; -// see SetTimeout and SetWriteTimeout. +// see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. -func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err os.Error) { +func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { if !c.ok() { return 0, os.EINVAL } - sa, err1 := addr.sockaddr(c.fd.family) - if err1 != nil { - return 0, &OpError{Op: "write", Net: "udp", Addr: addr, Error: err1} + if c.fd.isConnected { + return 0, &OpError{"write", c.fd.net, addr, ErrWriteToConnected} + } + sa, err := addr.sockaddr(c.fd.family) + if err != nil { + return 0, &OpError{"write", c.fd.net, addr, err} } return c.fd.WriteTo(b, sa) } // WriteTo implements the net.PacketConn WriteTo method. -func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { +func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) { if !c.ok() { return 0, os.EINVAL } a, ok := addr.(*UDPAddr) if !ok { - return 0, &OpError{"writeto", "udp", addr, os.EINVAL} + return 0, &OpError{"write", c.fd.net, addr, os.EINVAL} } return c.WriteToUDP(b, a) } @@ -204,14 +210,14 @@ func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { // DialUDP connects to the remote address raddr on the network net, // which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used // as the local address for the connection. -func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err os.Error) { +func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err error) { switch net { case "udp", "udp4", "udp6": default: return nil, UnknownNetworkError(net) } if raddr == nil { - return nil, &OpError{"dial", "udp", nil, errMissingAddress} + return nil, &OpError{"dial", net, nil, errMissingAddress} } fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP) if e != nil { @@ -224,44 +230,35 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err os.Error) { // local address laddr. The returned connection c's ReadFrom // and WriteTo methods can be used to receive and send UDP // packets with per-packet addressing. -func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err os.Error) { +func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { switch net { case "udp", "udp4", "udp6": default: return nil, UnknownNetworkError(net) } if laddr == nil { - return nil, &OpError{"listen", "udp", nil, errMissingAddress} + return nil, &OpError{"listen", net, nil, errMissingAddress} } - fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP) - if e != nil { - return nil, e + fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP) + if err != nil { + return nil, err } return newUDPConn(fd), nil } -// BindToDevice binds a UDPConn to a network interface. -func (c *UDPConn) BindToDevice(device string) os.Error { - if !c.ok() { - return os.EINVAL - } - c.fd.incref() - defer c.fd.decref() - return os.NewSyscallError("setsockopt", syscall.BindToDevice(c.fd.sysfd, device)) -} - // File returns a copy of the underlying os.File, set to blocking mode. // It is the caller's responsibility to close f when finished. // Closing c does not affect f, and closing f does not affect c. -func (c *UDPConn) File() (f *os.File, err os.Error) { return c.fd.dup() } +func (c *UDPConn) File() (f *os.File, err error) { return c.fd.dup() } // JoinGroup joins the IP multicast group named by addr on ifi, // which specifies the interface to join. JoinGroup uses the // default multicast interface if ifi is nil. -func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) os.Error { +func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) error { if !c.ok() { return os.EINVAL } + setDefaultMulticastSockopts(c.fd) ip := addr.To4() if ip != nil { return joinIPv4GroupUDP(c, ifi, ip) @@ -270,7 +267,7 @@ func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) os.Error { } // LeaveGroup exits the IP multicast group named by addr on ifi. -func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) os.Error { +func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) error { if !c.ok() { return os.EINVAL } @@ -281,68 +278,34 @@ func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) os.Error { return leaveIPv6GroupUDP(c, ifi, addr) } -func joinIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) os.Error { - mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}} - if err := setIPv4InterfaceToJoin(mreq, ifi); err != nil { - return &OpError{"joinipv4group", "udp", &IPAddr{ip}, err} - } - if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(c.fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)); err != nil { - return &OpError{"joinipv4group", "udp", &IPAddr{ip}, err} - } - return nil -} - -func leaveIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) os.Error { - mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}} - if err := setIPv4InterfaceToJoin(mreq, ifi); err != nil { - return &OpError{"leaveipv4group", "udp", &IPAddr{ip}, err} - } - if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(c.fd.sysfd, syscall.IPPROTO_IP, syscall.IP_DROP_MEMBERSHIP, mreq)); err != nil { - return &OpError{"leaveipv4group", "udp", &IPAddr{ip}, err} +func joinIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) error { + err := joinIPv4Group(c.fd, ifi, ip) + if err != nil { + return &OpError{"joinipv4group", c.fd.net, &IPAddr{ip}, err} } return nil } -func setIPv4InterfaceToJoin(mreq *syscall.IPMreq, ifi *Interface) os.Error { - if ifi == nil { - return nil - } - ifat, err := ifi.Addrs() +func leaveIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) error { + err := leaveIPv4Group(c.fd, ifi, ip) if err != nil { - return err - } - for _, ifa := range ifat { - if x := ifa.(*IPAddr).IP.To4(); x != nil { - copy(mreq.Interface[:], x) - break - } - } - if bytes.Equal(mreq.Multiaddr[:], IPv4zero) { - return os.EINVAL + return &OpError{"leaveipv4group", c.fd.net, &IPAddr{ip}, err} } return nil } -func joinIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) os.Error { - mreq := &syscall.IPv6Mreq{} - copy(mreq.Multiaddr[:], ip) - if ifi != nil { - mreq.Interface = uint32(ifi.Index) - } - if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(c.fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq)); err != nil { - return &OpError{"joinipv6group", "udp", &IPAddr{ip}, err} +func joinIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) error { + err := joinIPv6Group(c.fd, ifi, ip) + if err != nil { + return &OpError{"joinipv6group", c.fd.net, &IPAddr{ip}, err} } return nil } -func leaveIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) os.Error { - mreq := &syscall.IPv6Mreq{} - copy(mreq.Multiaddr[:], ip) - if ifi != nil { - mreq.Interface = uint32(ifi.Index) - } - if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(c.fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_LEAVE_GROUP, mreq)); err != nil { - return &OpError{"leaveipv6group", "udp", &IPAddr{ip}, err} +func leaveIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) error { + err := leaveIPv6Group(c.fd, ifi, ip) + if err != nil { + return &OpError{"leaveipv6group", c.fd.net, &IPAddr{ip}, err} } return nil } diff --git a/src/pkg/net/unicast_test.go b/src/pkg/net/unicast_test.go new file mode 100644 index 000000000..297276d3a --- /dev/null +++ b/src/pkg/net/unicast_test.go @@ -0,0 +1,111 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "io" + "runtime" + "testing" +) + +var unicastTests = []struct { + net string + laddr string + ipv6 bool + packet bool +}{ + {net: "tcp4", laddr: "127.0.0.1:0"}, + {net: "tcp4", laddr: "previous"}, + {net: "tcp6", laddr: "[::1]:0", ipv6: true}, + {net: "tcp6", laddr: "previous", ipv6: true}, + {net: "udp4", laddr: "127.0.0.1:0", packet: true}, + {net: "udp6", laddr: "[::1]:0", ipv6: true, packet: true}, +} + +func TestUnicastTCPAndUDP(t *testing.T) { + if runtime.GOOS == "plan9" || runtime.GOOS == "windows" { + return + } + + prevladdr := "" + for _, tt := range unicastTests { + if tt.ipv6 && !supportsIPv6 { + continue + } + var ( + fd *netFD + closer io.Closer + ) + if !tt.packet { + if tt.laddr == "previous" { + tt.laddr = prevladdr + } + l, err := Listen(tt.net, tt.laddr) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + prevladdr = l.Addr().String() + closer = l + fd = l.(*TCPListener).fd + } else { + c, err := ListenPacket(tt.net, tt.laddr) + if err != nil { + t.Fatalf("ListenPacket failed: %v", err) + } + closer = c + fd = c.(*UDPConn).fd + } + if !tt.ipv6 { + testIPv4UnicastSocketOptions(t, fd) + } else { + testIPv6UnicastSocketOptions(t, fd) + } + closer.Close() + } +} + +func testIPv4UnicastSocketOptions(t *testing.T, fd *netFD) { + tos, err := ipv4TOS(fd) + if err != nil { + t.Fatalf("ipv4TOS failed: %v", err) + } + t.Logf("IPv4 TOS: %v", tos) + err = setIPv4TOS(fd, 1) + if err != nil { + t.Fatalf("setIPv4TOS failed: %v", err) + } + + ttl, err := ipv4TTL(fd) + if err != nil { + t.Fatalf("ipv4TTL failed: %v", err) + } + t.Logf("IPv4 TTL: %v", ttl) + err = setIPv4TTL(fd, 1) + if err != nil { + t.Fatalf("setIPv4TTL failed: %v", err) + } +} + +func testIPv6UnicastSocketOptions(t *testing.T, fd *netFD) { + tos, err := ipv6TrafficClass(fd) + if err != nil { + t.Fatalf("ipv6TrafficClass failed: %v", err) + } + t.Logf("IPv6 TrafficClass: %v", tos) + err = setIPv6TrafficClass(fd, 1) + if err != nil { + t.Fatalf("setIPv6TrafficClass failed: %v", err) + } + + hoplim, err := ipv6HopLimit(fd) + if err != nil { + t.Fatalf("ipv6HopLimit failed: %v", err) + } + t.Logf("IPv6 HopLimit: %v", hoplim) + err = setIPv6HopLimit(fd, 1) + if err != nil { + t.Fatalf("setIPv6HopLimit failed: %v", err) + } +} diff --git a/src/pkg/net/unixsock.go b/src/pkg/net/unixsock.go index d5040f9a2..ae0956958 100644 --- a/src/pkg/net/unixsock.go +++ b/src/pkg/net/unixsock.go @@ -6,10 +6,6 @@ package net -import ( - "os" -) - // UnixAddr represents the address of a Unix domain socket end point. type UnixAddr struct { Name string @@ -38,7 +34,7 @@ func (a *UnixAddr) toAddr() Addr { // ResolveUnixAddr parses addr as a Unix domain socket address. // The string net gives the network name, "unix", "unixgram" or // "unixpacket". -func ResolveUnixAddr(net, addr string) (*UnixAddr, os.Error) { +func ResolveUnixAddr(net, addr string) (*UnixAddr, error) { switch net { case "unix": case "unixpacket": diff --git a/src/pkg/net/unixsock_plan9.go b/src/pkg/net/unixsock_plan9.go index 7e212df8a..e8087d09a 100644 --- a/src/pkg/net/unixsock_plan9.go +++ b/src/pkg/net/unixsock_plan9.go @@ -8,6 +8,7 @@ package net import ( "os" + "time" ) // UnixConn is an implementation of the Conn interface @@ -17,17 +18,17 @@ type UnixConn bool // Implementation of the Conn interface - see Conn for documentation. // Read implements the net.Conn Read method. -func (c *UnixConn) Read(b []byte) (n int, err os.Error) { +func (c *UnixConn) Read(b []byte) (n int, err error) { return 0, os.EPLAN9 } // Write implements the net.Conn Write method. -func (c *UnixConn) Write(b []byte) (n int, err os.Error) { +func (c *UnixConn) Write(b []byte) (n int, err error) { return 0, os.EPLAN9 } // Close closes the Unix domain connection. -func (c *UnixConn) Close() os.Error { +func (c *UnixConn) Close() error { return os.EPLAN9 } @@ -44,29 +45,29 @@ func (c *UnixConn) RemoteAddr() Addr { return nil } -// SetTimeout implements the net.Conn SetTimeout method. -func (c *UnixConn) SetTimeout(nsec int64) os.Error { +// SetDeadline implements the net.Conn SetDeadline method. +func (c *UnixConn) SetDeadline(t time.Time) error { return os.EPLAN9 } -// SetReadTimeout implements the net.Conn SetReadTimeout method. -func (c *UnixConn) SetReadTimeout(nsec int64) os.Error { +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (c *UnixConn) SetReadDeadline(t time.Time) error { return os.EPLAN9 } -// SetWriteTimeout implements the net.Conn SetWriteTimeout method. -func (c *UnixConn) SetWriteTimeout(nsec int64) os.Error { +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (c *UnixConn) SetWriteDeadline(t time.Time) error { return os.EPLAN9 } // ReadFrom implements the net.PacketConn ReadFrom method. -func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { +func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err error) { err = os.EPLAN9 return } // WriteTo implements the net.PacketConn WriteTo method. -func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { +func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err error) { err = os.EPLAN9 return } @@ -74,7 +75,7 @@ func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { // DialUnix connects to the remote address raddr on the network net, // which must be "unix" or "unixgram". If laddr is not nil, it is used // as the local address for the connection. -func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err os.Error) { +func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err error) { return nil, os.EPLAN9 } @@ -85,19 +86,19 @@ type UnixListener bool // ListenUnix announces on the Unix domain socket laddr and returns a Unix listener. // Net must be "unix" (stream sockets). -func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) { +func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err error) { return nil, os.EPLAN9 } // Accept implements the Accept method in the Listener interface; // it waits for the next call and returns a generic Conn. -func (l *UnixListener) Accept() (c Conn, err os.Error) { +func (l *UnixListener) Accept() (c Conn, err error) { return nil, os.EPLAN9 } // Close stops listening on the Unix address. // Already accepted connections are not closed. -func (l *UnixListener) Close() os.Error { +func (l *UnixListener) Close() error { return os.EPLAN9 } diff --git a/src/pkg/net/unixsock_posix.go b/src/pkg/net/unixsock_posix.go index fccf0189c..e500ddb4e 100644 --- a/src/pkg/net/unixsock_posix.go +++ b/src/pkg/net/unixsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux openbsd windows +// +build darwin freebsd linux netbsd openbsd windows // Unix domain sockets @@ -11,19 +11,20 @@ package net import ( "os" "syscall" + "time" ) -func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err os.Error) { - var proto int +func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err error) { + var sotype int switch net { default: return nil, UnknownNetworkError(net) case "unix": - proto = syscall.SOCK_STREAM + sotype = syscall.SOCK_STREAM case "unixgram": - proto = syscall.SOCK_DGRAM + sotype = syscall.SOCK_DGRAM case "unixpacket": - proto = syscall.SOCK_SEQPACKET + sotype = syscall.SOCK_SEQPACKET } var la, ra syscall.Sockaddr @@ -37,8 +38,8 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err } if raddr != nil { ra = &syscall.SockaddrUnix{Name: raddr.Name} - } else if proto != syscall.SOCK_DGRAM || laddr == nil { - return nil, &OpError{Op: mode, Net: net, Error: errMissingAddress} + } else if sotype != syscall.SOCK_DGRAM || laddr == nil { + return nil, &OpError{Op: mode, Net: net, Err: errMissingAddress} } case "listen": @@ -47,18 +48,18 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err } la = &syscall.SockaddrUnix{Name: laddr.Name} if raddr != nil { - return nil, &OpError{Op: mode, Net: net, Addr: raddr, Error: &AddrError{Error: "unexpected remote address", Addr: raddr.String()}} + return nil, &OpError{Op: mode, Net: net, Addr: raddr, Err: &AddrError{Err: "unexpected remote address", Addr: raddr.String()}} } } f := sockaddrToUnix - if proto == syscall.SOCK_DGRAM { + if sotype == syscall.SOCK_DGRAM { f = sockaddrToUnixgram - } else if proto == syscall.SOCK_SEQPACKET { + } else if sotype == syscall.SOCK_SEQPACKET { f = sockaddrToUnixpacket } - fd, oserr := socket(net, syscall.AF_UNIX, proto, 0, la, ra, f) + fd, oserr := socket(net, syscall.AF_UNIX, sotype, 0, la, ra, f) if oserr != nil { goto Error } @@ -69,7 +70,7 @@ Error: if mode == "listen" { addr = laddr } - return nil, &OpError{Op: mode, Net: net, Addr: addr, Error: oserr} + return nil, &OpError{Op: mode, Net: net, Addr: addr, Err: oserr} } func sockaddrToUnix(sa syscall.Sockaddr) Addr { @@ -93,8 +94,8 @@ func sockaddrToUnixpacket(sa syscall.Sockaddr) Addr { return nil } -func protoToNet(proto int) string { - switch proto { +func sotypeToNet(sotype int) string { + switch sotype { case syscall.SOCK_STREAM: return "unix" case syscall.SOCK_SEQPACKET: @@ -102,7 +103,7 @@ func protoToNet(proto int) string { case syscall.SOCK_DGRAM: return "unixgram" default: - panic("protoToNet unknown protocol") + panic("sotypeToNet unknown socket type") } return "" } @@ -120,7 +121,7 @@ func (c *UnixConn) ok() bool { return c != nil && c.fd != nil } // Implementation of the Conn interface - see Conn for documentation. // Read implements the net.Conn Read method. -func (c *UnixConn) Read(b []byte) (n int, err os.Error) { +func (c *UnixConn) Read(b []byte) (n int, err error) { if !c.ok() { return 0, os.EINVAL } @@ -128,7 +129,7 @@ func (c *UnixConn) Read(b []byte) (n int, err os.Error) { } // Write implements the net.Conn Write method. -func (c *UnixConn) Write(b []byte) (n int, err os.Error) { +func (c *UnixConn) Write(b []byte) (n int, err error) { if !c.ok() { return 0, os.EINVAL } @@ -136,7 +137,7 @@ func (c *UnixConn) Write(b []byte) (n int, err os.Error) { } // Close closes the Unix domain connection. -func (c *UnixConn) Close() os.Error { +func (c *UnixConn) Close() error { if !c.ok() { return os.EINVAL } @@ -164,33 +165,33 @@ func (c *UnixConn) RemoteAddr() Addr { return c.fd.raddr } -// SetTimeout implements the net.Conn SetTimeout method. -func (c *UnixConn) SetTimeout(nsec int64) os.Error { +// SetDeadline implements the net.Conn SetDeadline method. +func (c *UnixConn) SetDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setTimeout(c.fd, nsec) + return setDeadline(c.fd, t) } -// SetReadTimeout implements the net.Conn SetReadTimeout method. -func (c *UnixConn) SetReadTimeout(nsec int64) os.Error { +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (c *UnixConn) SetReadDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setReadTimeout(c.fd, nsec) + return setReadDeadline(c.fd, t) } -// SetWriteTimeout implements the net.Conn SetWriteTimeout method. -func (c *UnixConn) SetWriteTimeout(nsec int64) os.Error { +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (c *UnixConn) SetWriteDeadline(t time.Time) error { if !c.ok() { return os.EINVAL } - return setWriteTimeout(c.fd, nsec) + return setWriteDeadline(c.fd, t) } // SetReadBuffer sets the size of the operating system's // receive buffer associated with the connection. -func (c *UnixConn) SetReadBuffer(bytes int) os.Error { +func (c *UnixConn) SetReadBuffer(bytes int) error { if !c.ok() { return os.EINVAL } @@ -199,7 +200,7 @@ func (c *UnixConn) SetReadBuffer(bytes int) os.Error { // SetWriteBuffer sets the size of the operating system's // transmit buffer associated with the connection. -func (c *UnixConn) SetWriteBuffer(bytes int) os.Error { +func (c *UnixConn) SetWriteBuffer(bytes int) error { if !c.ok() { return os.EINVAL } @@ -212,21 +213,21 @@ func (c *UnixConn) SetWriteBuffer(bytes int) os.Error { // // ReadFromUnix can be made to time out and return // an error with Timeout() == true after a fixed time limit; -// see SetTimeout and SetReadTimeout. -func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err os.Error) { +// see SetDeadline and SetReadDeadline. +func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err error) { if !c.ok() { return 0, nil, os.EINVAL } n, sa, err := c.fd.ReadFrom(b) switch sa := sa.(type) { case *syscall.SockaddrUnix: - addr = &UnixAddr{sa.Name, protoToNet(c.fd.proto)} + addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)} } return } // ReadFrom implements the net.PacketConn ReadFrom method. -func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { +func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err error) { if !c.ok() { return 0, nil, os.EINVAL } @@ -238,13 +239,13 @@ func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { // // WriteToUnix can be made to time out and return // an error with Timeout() == true after a fixed time limit; -// see SetTimeout and SetWriteTimeout. +// see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. -func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err os.Error) { +func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err error) { if !c.ok() { return 0, os.EINVAL } - if addr.Net != protoToNet(c.fd.proto) { + if addr.Net != sotypeToNet(c.fd.sotype) { return 0, os.EAFNOSUPPORT } sa := &syscall.SockaddrUnix{Name: addr.Name} @@ -252,35 +253,35 @@ func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err os.Error) { } // WriteTo implements the net.PacketConn WriteTo method. -func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { +func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err error) { if !c.ok() { return 0, os.EINVAL } a, ok := addr.(*UnixAddr) if !ok { - return 0, &OpError{"writeto", "unix", addr, os.EINVAL} + return 0, &OpError{"write", c.fd.net, addr, os.EINVAL} } return c.WriteToUnix(b, a) } -func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err os.Error) { +func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) { if !c.ok() { return 0, 0, 0, nil, os.EINVAL } n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob) switch sa := sa.(type) { case *syscall.SockaddrUnix: - addr = &UnixAddr{sa.Name, protoToNet(c.fd.proto)} + addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)} } return } -func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err os.Error) { +func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) { if !c.ok() { return 0, 0, os.EINVAL } if addr != nil { - if addr.Net != protoToNet(c.fd.proto) { + if addr.Net != sotypeToNet(c.fd.sotype) { return 0, 0, os.EAFNOSUPPORT } sa := &syscall.SockaddrUnix{Name: addr.Name} @@ -292,12 +293,12 @@ func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err // File returns a copy of the underlying os.File, set to blocking mode. // It is the caller's responsibility to close f when finished. // Closing c does not affect f, and closing f does not affect c. -func (c *UnixConn) File() (f *os.File, err os.Error) { return c.fd.dup() } +func (c *UnixConn) File() (f *os.File, err error) { return c.fd.dup() } // DialUnix connects to the remote address raddr on the network net, // which must be "unix" or "unixgram". If laddr is not nil, it is used // as the local address for the connection. -func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err os.Error) { +func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err error) { fd, e := unixSocket(net, laddr, raddr, "dial") if e != nil { return nil, e @@ -315,7 +316,7 @@ type UnixListener struct { // ListenUnix announces on the Unix domain socket laddr and returns a Unix listener. // Net must be "unix" (stream sockets). -func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) { +func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { if net != "unix" && net != "unixgram" && net != "unixpacket" { return nil, UnknownNetworkError(net) } @@ -326,17 +327,17 @@ func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) { if err != nil { return nil, err } - e1 := syscall.Listen(fd.sysfd, 8) // listenBacklog()); - if e1 != 0 { + err = syscall.Listen(fd.sysfd, listenerBacklog) + if err != nil { closesocket(fd.sysfd) - return nil, &OpError{Op: "listen", Net: "unix", Addr: laddr, Error: os.Errno(e1)} + return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: err} } return &UnixListener{fd, laddr.Name}, nil } // AcceptUnix accepts the next incoming call and returns the new connection // and the remote address. -func (l *UnixListener) AcceptUnix() (c *UnixConn, err os.Error) { +func (l *UnixListener) AcceptUnix() (c *UnixConn, err error) { if l == nil || l.fd == nil { return nil, os.EINVAL } @@ -350,7 +351,7 @@ func (l *UnixListener) AcceptUnix() (c *UnixConn, err os.Error) { // Accept implements the Accept method in the Listener interface; // it waits for the next call and returns a generic Conn. -func (l *UnixListener) Accept() (c Conn, err os.Error) { +func (l *UnixListener) Accept() (c Conn, err error) { c1, err := l.AcceptUnix() if err != nil { return nil, err @@ -360,7 +361,7 @@ func (l *UnixListener) Accept() (c Conn, err os.Error) { // Close stops listening on the Unix address. // Already accepted connections are not closed. -func (l *UnixListener) Close() os.Error { +func (l *UnixListener) Close() error { if l == nil || l.fd == nil { return os.EINVAL } @@ -386,31 +387,32 @@ func (l *UnixListener) Close() os.Error { // Addr returns the listener's network address. func (l *UnixListener) Addr() Addr { return l.fd.laddr } -// SetTimeout sets the deadline associated wuth the listener -func (l *UnixListener) SetTimeout(nsec int64) (err os.Error) { +// SetDeadline sets the deadline associated with the listener. +// A zero time value disables the deadline. +func (l *UnixListener) SetDeadline(t time.Time) (err error) { if l == nil || l.fd == nil { return os.EINVAL } - return setTimeout(l.fd, nsec) + return setDeadline(l.fd, t) } // File returns a copy of the underlying os.File, set to blocking mode. // It is the caller's responsibility to close f when finished. // Closing c does not affect f, and closing f does not affect c. -func (l *UnixListener) File() (f *os.File, err os.Error) { return l.fd.dup() } +func (l *UnixListener) File() (f *os.File, err error) { return l.fd.dup() } // ListenUnixgram listens for incoming Unix datagram packets addressed to the // local address laddr. The returned connection c's ReadFrom // and WriteTo methods can be used to receive and send UDP // packets with per-packet addressing. The network net must be "unixgram". -func ListenUnixgram(net string, laddr *UnixAddr) (c *UDPConn, err os.Error) { +func ListenUnixgram(net string, laddr *UnixAddr) (c *UDPConn, err error) { switch net { case "unixgram": default: return nil, UnknownNetworkError(net) } if laddr == nil { - return nil, &OpError{"listen", "unixgram", nil, errMissingAddress} + return nil, &OpError{"listen", net, nil, errMissingAddress} } fd, e := unixSocket(net, laddr, nil, "listen") if e != nil { diff --git a/src/pkg/net/url/Makefile b/src/pkg/net/url/Makefile new file mode 100644 index 000000000..bef0647a4 --- /dev/null +++ b/src/pkg/net/url/Makefile @@ -0,0 +1,11 @@ +# Copyright 2009 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=net/url +GOFILES=\ + url.go\ + +include ../../../Make.pkg diff --git a/src/pkg/net/url/url.go b/src/pkg/net/url/url.go new file mode 100644 index 000000000..0068e98af --- /dev/null +++ b/src/pkg/net/url/url.go @@ -0,0 +1,664 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package URL parses URLs and implements query escaping. +// See RFC 3986. +package url + +import ( + "errors" + "strconv" + "strings" +) + +// Error reports an error and the operation and URL that caused it. +type Error struct { + Op string + URL string + Err error +} + +func (e *Error) Error() string { return e.Op + " " + e.URL + ": " + e.Err.Error() } + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} + +type encoding int + +const ( + encodePath encoding = 1 + iota + encodeUserPassword + encodeQueryComponent + encodeFragment +) + +type EscapeError string + +func (e EscapeError) Error() string { + return "invalid URL escape " + strconv.Quote(string(e)) +} + +// Return true if the specified character should be escaped when +// appearing in a URL string, according to RFC 2396. +// When 'all' is true the full range of reserved characters are matched. +func shouldEscape(c byte, mode encoding) bool { + // RFC 2396 §2.3 Unreserved characters (alphanum) + if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' { + return false + } + // TODO: Update the character sets after RFC 3986. + switch c { + case '-', '_', '.', '!', '~', '*', '\'', '(', ')': // §2.3 Unreserved characters (mark) + return false + + case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved) + // Different sections of the URL allow a few of + // the reserved characters to appear unescaped. + switch mode { + case encodePath: // §3.3 + // The RFC allows : @ & = + $ but saves / ; , for assigning + // meaning to individual path segments. This package + // only manipulates the path as a whole, so we allow those + // last two as well. That leaves only ? to escape. + return c == '?' + + case encodeUserPassword: // §3.2.2 + // The RFC allows ; : & = + $ , in userinfo, so we must escape only @ and /. + // The parsing of userinfo treats : as special so we must escape that too. + return c == '@' || c == '/' || c == ':' + + case encodeQueryComponent: // §3.4 + // The RFC reserves (so we must escape) everything. + return true + + case encodeFragment: // §4.1 + // The RFC text is silent but the grammar allows + // everything, so escape nothing. + return false + } + } + + // Everything else must be escaped. + return true +} + +// QueryUnescape does the inverse transformation of QueryEscape, converting +// %AB into the byte 0xAB and '+' into ' ' (space). It returns an error if +// any % is not followed by two hexadecimal digits. +func QueryUnescape(s string) (string, error) { + return unescape(s, encodeQueryComponent) +} + +// unescape unescapes a string; the mode specifies +// which section of the URL string is being unescaped. +func unescape(s string, mode encoding) (string, error) { + // Count %, check that they're well-formed. + n := 0 + hasPlus := false + for i := 0; i < len(s); { + switch s[i] { + case '%': + n++ + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + s = s[i:] + if len(s) > 3 { + s = s[0:3] + } + return "", EscapeError(s) + } + i += 3 + case '+': + hasPlus = mode == encodeQueryComponent + i++ + default: + i++ + } + } + + if n == 0 && !hasPlus { + return s, nil + } + + t := make([]byte, len(s)-2*n) + j := 0 + for i := 0; i < len(s); { + switch s[i] { + case '%': + t[j] = unhex(s[i+1])<<4 | unhex(s[i+2]) + j++ + i += 3 + case '+': + if mode == encodeQueryComponent { + t[j] = ' ' + } else { + t[j] = '+' + } + j++ + i++ + default: + t[j] = s[i] + j++ + i++ + } + } + return string(t), nil +} + +// QueryEscape escapes the string so it can be safely placed +// inside a URL query. +func QueryEscape(s string) string { + return escape(s, encodeQueryComponent) +} + +func escape(s string, mode encoding) string { + spaceCount, hexCount := 0, 0 + for i := 0; i < len(s); i++ { + c := s[i] + if shouldEscape(c, mode) { + if c == ' ' && mode == encodeQueryComponent { + spaceCount++ + } else { + hexCount++ + } + } + } + + if spaceCount == 0 && hexCount == 0 { + return s + } + + t := make([]byte, len(s)+2*hexCount) + j := 0 + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case c == ' ' && mode == encodeQueryComponent: + t[j] = '+' + j++ + case shouldEscape(c, mode): + t[j] = '%' + t[j+1] = "0123456789ABCDEF"[c>>4] + t[j+2] = "0123456789ABCDEF"[c&15] + j += 3 + default: + t[j] = s[i] + j++ + } + } + return string(t) +} + +// A URL represents a parsed URL (technically, a URI reference). +// The general form represented is: +// +// scheme://[userinfo@]host/path[?query][#fragment] +// +// URLs that do not start with a slash after the scheme are interpreted as: +// +// scheme:opaque[?query][#fragment] +// +type URL struct { + Scheme string + Opaque string // encoded opaque data + User *Userinfo // username and password information + Host string + Path string + RawQuery string // encoded query values, without '?' + Fragment string // fragment for references, without '#' +} + +// User returns a Userinfo containing the provided username +// and no password set. +func User(username string) *Userinfo { + return &Userinfo{username, "", false} +} + +// UserPassword returns a Userinfo containing the provided username +// and password. +// This functionality should only be used with legacy web sites. +// RFC 2396 warns that interpreting Userinfo this way +// ``is NOT RECOMMENDED, because the passing of authentication +// information in clear text (such as URI) has proven to be a +// security risk in almost every case where it has been used.'' +func UserPassword(username, password string) *Userinfo { + return &Userinfo{username, password, true} +} + +// The Userinfo type is an immutable encapsulation of username and +// password details for a URL. An existing Userinfo value is guaranteed +// to have a username set (potentially empty, as allowed by RFC 2396), +// and optionally a password. +type Userinfo struct { + username string + password string + passwordSet bool +} + +// Username returns the username. +func (u *Userinfo) Username() string { + return u.username +} + +// Password returns the password in case it is set, and whether it is set. +func (u *Userinfo) Password() (string, bool) { + if u.passwordSet { + return u.password, true + } + return "", false +} + +// String returns the encoded userinfo information in the standard form +// of "username[:password]". +func (u *Userinfo) String() string { + s := escape(u.username, encodeUserPassword) + if u.passwordSet { + s += ":" + escape(u.password, encodeUserPassword) + } + return s +} + +// Maybe rawurl is of the form scheme:path. +// (Scheme must be [a-zA-Z][a-zA-Z0-9+-.]*) +// If so, return scheme, path; else return "", rawurl. +func getscheme(rawurl string) (scheme, path string, err error) { + for i := 0; i < len(rawurl); i++ { + c := rawurl[i] + switch { + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + // do nothing + case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.': + if i == 0 { + return "", rawurl, nil + } + case c == ':': + if i == 0 { + return "", "", errors.New("missing protocol scheme") + } + return rawurl[0:i], rawurl[i+1:], nil + default: + // we have encountered an invalid character, + // so there is no valid scheme + return "", rawurl, nil + } + } + return "", rawurl, nil +} + +// Maybe s is of the form t c u. +// If so, return t, c u (or t, u if cutc == true). +// If not, return s, "". +func split(s string, c byte, cutc bool) (string, string) { + for i := 0; i < len(s); i++ { + if s[i] == c { + if cutc { + return s[0:i], s[i+1:] + } + return s[0:i], s[i:] + } + } + return s, "" +} + +// Parse parses rawurl into a URL structure. +// The string rawurl is assumed not to have a #fragment suffix. +// (Web browsers strip #fragment before sending the URL to a web server.) +// The rawurl may be relative or absolute. +func Parse(rawurl string) (url *URL, err error) { + return parse(rawurl, false) +} + +// ParseRequest parses rawurl into a URL structure. It assumes that +// rawurl was received from an HTTP request, so the rawurl is interpreted +// only as an absolute URI or an absolute path. +// The string rawurl is assumed not to have a #fragment suffix. +// (Web browsers strip #fragment before sending the URL to a web server.) +func ParseRequest(rawurl string) (url *URL, err error) { + return parse(rawurl, true) +} + +// parse parses a URL from a string in one of two contexts. If +// viaRequest is true, the URL is assumed to have arrived via an HTTP request, +// in which case only absolute URLs or path-absolute relative URLs are allowed. +// If viaRequest is false, all forms of relative URLs are allowed. +func parse(rawurl string, viaRequest bool) (url *URL, err error) { + var rest string + + if rawurl == "" { + err = errors.New("empty url") + goto Error + } + url = new(URL) + + // Split off possible leading "http:", "mailto:", etc. + // Cannot contain escaped characters. + if url.Scheme, rest, err = getscheme(rawurl); err != nil { + goto Error + } + + rest, url.RawQuery = split(rest, '?', true) + + if !strings.HasPrefix(rest, "/") { + if url.Scheme != "" { + // We consider rootless paths per RFC 3986 as opaque. + url.Opaque = rest + return url, nil + } + if viaRequest { + err = errors.New("invalid URI for request") + goto Error + } + } + + if (url.Scheme != "" || !viaRequest) && strings.HasPrefix(rest, "//") && !strings.HasPrefix(rest, "///") { + var authority string + authority, rest = split(rest[2:], '/', false) + url.User, url.Host, err = parseAuthority(authority) + if err != nil { + goto Error + } + if strings.Contains(url.Host, "%") { + err = errors.New("hexadecimal escape in host") + goto Error + } + } + if url.Path, err = unescape(rest, encodePath); err != nil { + goto Error + } + return url, nil + +Error: + return nil, &Error{"parse", rawurl, err} +} + +func parseAuthority(authority string) (user *Userinfo, host string, err error) { + if strings.Index(authority, "@") < 0 { + host = authority + return + } + userinfo, host := split(authority, '@', true) + if strings.Index(userinfo, ":") < 0 { + if userinfo, err = unescape(userinfo, encodeUserPassword); err != nil { + return + } + user = User(userinfo) + } else { + username, password := split(userinfo, ':', true) + if username, err = unescape(username, encodeUserPassword); err != nil { + return + } + if password, err = unescape(password, encodeUserPassword); err != nil { + return + } + user = UserPassword(username, password) + } + return +} + +// ParseWithReference is like Parse but allows a trailing #fragment. +func ParseWithReference(rawurlref string) (url *URL, err error) { + // Cut off #frag + rawurl, frag := split(rawurlref, '#', true) + if url, err = Parse(rawurl); err != nil { + return nil, err + } + if frag == "" { + return url, nil + } + if url.Fragment, err = unescape(frag, encodeFragment); err != nil { + return nil, &Error{"parse", rawurlref, err} + } + return url, nil +} + +// String reassembles url into a valid URL string. +func (url *URL) String() string { + // TODO: Rewrite to use bytes.Buffer + result := "" + if url.Scheme != "" { + result += url.Scheme + ":" + } + if url.Opaque != "" { + result += url.Opaque + } else { + if url.Host != "" || url.User != nil { + result += "//" + if u := url.User; u != nil { + result += u.String() + "@" + } + result += url.Host + } + result += escape(url.Path, encodePath) + } + if url.RawQuery != "" { + result += "?" + url.RawQuery + } + if url.Fragment != "" { + result += "#" + escape(url.Fragment, encodeFragment) + } + return result +} + +// Values maps a string key to a list of values. +// It is typically used for query parameters and form values. +// Unlike in the http.Header map, the keys in a Values map +// are case-sensitive. +type Values map[string][]string + +// Get gets the first value associated with the given key. +// If there are no values associated with the key, Get returns +// the empty string. To access multiple values, use the map +// directly. +func (v Values) Get(key string) string { + if v == nil { + return "" + } + vs, ok := v[key] + if !ok || len(vs) == 0 { + return "" + } + return vs[0] +} + +// Set sets the key to value. It replaces any existing +// values. +func (v Values) Set(key, value string) { + v[key] = []string{value} +} + +// Add adds the key to value. It appends to any existing +// values associated with key. +func (v Values) Add(key, value string) { + v[key] = append(v[key], value) +} + +// Del deletes the values associated with key. +func (v Values) Del(key string) { + delete(v, key) +} + +// ParseQuery parses the URL-encoded query string and returns +// a map listing the values specified for each key. +// ParseQuery always returns a non-nil map containing all the +// valid query parameters found; err describes the first decoding error +// encountered, if any. +func ParseQuery(query string) (m Values, err error) { + m = make(Values) + err = parseQuery(m, query) + return +} + +func parseQuery(m Values, query string) (err error) { + for query != "" { + key := query + if i := strings.IndexAny(key, "&;"); i >= 0 { + key, query = key[:i], key[i+1:] + } else { + query = "" + } + if key == "" { + continue + } + value := "" + if i := strings.Index(key, "="); i >= 0 { + key, value = key[:i], key[i+1:] + } + key, err1 := QueryUnescape(key) + if err1 != nil { + err = err1 + continue + } + value, err1 = QueryUnescape(value) + if err1 != nil { + err = err1 + continue + } + m[key] = append(m[key], value) + } + return err +} + +// Encode encodes the values into ``URL encoded'' form. +// e.g. "foo=bar&bar=baz" +func (v Values) Encode() string { + if v == nil { + return "" + } + parts := make([]string, 0, len(v)) // will be large enough for most uses + for k, vs := range v { + prefix := QueryEscape(k) + "=" + for _, v := range vs { + parts = append(parts, prefix+QueryEscape(v)) + } + } + return strings.Join(parts, "&") +} + +// resolvePath applies special path segments from refs and applies +// them to base, per RFC 2396. +func resolvePath(basepath string, refpath string) string { + base := strings.Split(basepath, "/") + refs := strings.Split(refpath, "/") + if len(base) == 0 { + base = []string{""} + } + for idx, ref := range refs { + switch { + case ref == ".": + base[len(base)-1] = "" + case ref == "..": + newLen := len(base) - 1 + if newLen < 1 { + newLen = 1 + } + base = base[0:newLen] + base[len(base)-1] = "" + default: + if idx == 0 || base[len(base)-1] == "" { + base[len(base)-1] = ref + } else { + base = append(base, ref) + } + } + } + return strings.Join(base, "/") +} + +// IsAbs returns true if the URL is absolute. +func (url *URL) IsAbs() bool { + return url.Scheme != "" +} + +// Parse parses a URL in the context of a base URL. The URL in ref +// may be relative or absolute. Parse returns nil, err on parse +// failure, otherwise its return value is the same as ResolveReference. +func (base *URL) Parse(ref string) (*URL, error) { + refurl, err := Parse(ref) + if err != nil { + return nil, err + } + return base.ResolveReference(refurl), nil +} + +// ResolveReference resolves a URI reference to an absolute URI from +// an absolute base URI, per RFC 2396 Section 5.2. The URI reference +// may be relative or absolute. ResolveReference always returns a new +// URL instance, even if the returned URL is identical to either the +// base or reference. If ref is an absolute URL, then ResolveReference +// ignores base and returns a copy of ref. +func (base *URL) ResolveReference(ref *URL) *URL { + if ref.IsAbs() { + url := *ref + return &url + } + // relativeURI = ( net_path | abs_path | rel_path ) [ "?" query ] + url := *base + url.RawQuery = ref.RawQuery + url.Fragment = ref.Fragment + if ref.Opaque != "" { + url.Opaque = ref.Opaque + url.User = nil + url.Host = "" + url.Path = "" + return &url + } + if ref.Host != "" || ref.User != nil { + // The "net_path" case. + url.Host = ref.Host + url.User = ref.User + } + if strings.HasPrefix(ref.Path, "/") { + // The "abs_path" case. + url.Path = ref.Path + } else { + // The "rel_path" case. + path := resolvePath(base.Path, ref.Path) + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + url.Path = path + } + return &url +} + +// Query parses RawQuery and returns the corresponding values. +func (u *URL) Query() Values { + v, _ := ParseQuery(u.RawQuery) + return v +} + +// RequestURI returns the encoded path?query or opaque?query +// string that would be used in an HTTP request for u. +func (u *URL) RequestURI() string { + result := u.Opaque + if result == "" { + result = escape(u.Path, encodePath) + if result == "" { + result = "/" + } + } + if u.RawQuery != "" { + result += "?" + u.RawQuery + } + return result +} diff --git a/src/pkg/net/url/url_test.go b/src/pkg/net/url/url_test.go new file mode 100644 index 000000000..9fe5ff886 --- /dev/null +++ b/src/pkg/net/url/url_test.go @@ -0,0 +1,771 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package url + +import ( + "fmt" + "reflect" + "testing" +) + +type URLTest struct { + in string + out *URL + roundtrip string // expected result of reserializing the URL; empty means same as "in". +} + +var urltests = []URLTest{ + // no path + { + "http://www.google.com", + &URL{ + Scheme: "http", + Host: "www.google.com", + }, + "", + }, + // path + { + "http://www.google.com/", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + }, + "", + }, + // path with hex escaping + { + "http://www.google.com/file%20one%26two", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/file one&two", + }, + "http://www.google.com/file%20one&two", + }, + // user + { + "ftp://webmaster@www.google.com/", + &URL{ + Scheme: "ftp", + User: User("webmaster"), + Host: "www.google.com", + Path: "/", + }, + "", + }, + // escape sequence in username + { + "ftp://john%20doe@www.google.com/", + &URL{ + Scheme: "ftp", + User: User("john doe"), + Host: "www.google.com", + Path: "/", + }, + "ftp://john%20doe@www.google.com/", + }, + // query + { + "http://www.google.com/?q=go+language", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + RawQuery: "q=go+language", + }, + "", + }, + // query with hex escaping: NOT parsed + { + "http://www.google.com/?q=go%20language", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + RawQuery: "q=go%20language", + }, + "", + }, + // %20 outside query + { + "http://www.google.com/a%20b?q=c+d", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/a b", + RawQuery: "q=c+d", + }, + "", + }, + // path without leading /, so no parsing + { + "http:www.google.com/?q=go+language", + &URL{ + Scheme: "http", + Opaque: "www.google.com/", + RawQuery: "q=go+language", + }, + "http:www.google.com/?q=go+language", + }, + // path without leading /, so no parsing + { + "http:%2f%2fwww.google.com/?q=go+language", + &URL{ + Scheme: "http", + Opaque: "%2f%2fwww.google.com/", + RawQuery: "q=go+language", + }, + "http:%2f%2fwww.google.com/?q=go+language", + }, + // non-authority + { + "mailto:/webmaster@golang.org", + &URL{ + Scheme: "mailto", + Path: "/webmaster@golang.org", + }, + "", + }, + // non-authority + { + "mailto:webmaster@golang.org", + &URL{ + Scheme: "mailto", + Opaque: "webmaster@golang.org", + }, + "", + }, + // unescaped :// in query should not create a scheme + { + "/foo?query=http://bad", + &URL{ + Path: "/foo", + RawQuery: "query=http://bad", + }, + "", + }, + // leading // without scheme should create an authority + { + "//foo", + &URL{ + Host: "foo", + }, + "", + }, + // leading // without scheme, with userinfo, path, and query + { + "//user@foo/path?a=b", + &URL{ + User: User("user"), + Host: "foo", + Path: "/path", + RawQuery: "a=b", + }, + "", + }, + // Three leading slashes isn't an authority, but doesn't return an error. + // (We can't return an error, as this code is also used via + // ServeHTTP -> ReadRequest -> Parse, which is arguably a + // different URL parsing context, but currently shares the + // same codepath) + { + "///threeslashes", + &URL{ + Path: "///threeslashes", + }, + "", + }, + { + "http://user:password@google.com", + &URL{ + Scheme: "http", + User: UserPassword("user", "password"), + Host: "google.com", + }, + "http://user:password@google.com", + }, +} + +var urlnofragtests = []URLTest{ + { + "http://www.google.com/?q=go+language#foo", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + RawQuery: "q=go+language#foo", + }, + "", + }, +} + +var urlfragtests = []URLTest{ + { + "http://www.google.com/?q=go+language#foo", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + RawQuery: "q=go+language", + Fragment: "foo", + }, + "", + }, + { + "http://www.google.com/?q=go+language#foo%26bar", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + RawQuery: "q=go+language", + Fragment: "foo&bar", + }, + "http://www.google.com/?q=go+language#foo&bar", + }, +} + +// more useful string for debugging than fmt's struct printer +func ufmt(u *URL) string { + var user, pass interface{} + if u.User != nil { + user = u.User.Username() + if p, ok := u.User.Password(); ok { + pass = p + } + } + return fmt.Sprintf("opaque=%q, scheme=%q, user=%#v, pass=%#v, host=%q, path=%q, rawq=%q, frag=%q", + u.Opaque, u.Scheme, user, pass, u.Host, u.Path, u.RawQuery, u.Fragment) +} + +func DoTest(t *testing.T, parse func(string) (*URL, error), name string, tests []URLTest) { + for _, tt := range tests { + u, err := parse(tt.in) + if err != nil { + t.Errorf("%s(%q) returned error %s", name, tt.in, err) + continue + } + if !reflect.DeepEqual(u, tt.out) { + t.Errorf("%s(%q):\n\thave %v\n\twant %v\n", + name, tt.in, ufmt(u), ufmt(tt.out)) + } + } +} + +func TestParse(t *testing.T) { + DoTest(t, Parse, "Parse", urltests) + DoTest(t, Parse, "Parse", urlnofragtests) +} + +func TestParseWithReference(t *testing.T) { + DoTest(t, ParseWithReference, "ParseWithReference", urltests) + DoTest(t, ParseWithReference, "ParseWithReference", urlfragtests) +} + +const pathThatLooksSchemeRelative = "//not.a.user@not.a.host/just/a/path" + +var parseRequestUrlTests = []struct { + url string + expectedValid bool +}{ + {"http://foo.com", true}, + {"http://foo.com/", true}, + {"http://foo.com/path", true}, + {"/", true}, + {pathThatLooksSchemeRelative, true}, + {"//not.a.user@%66%6f%6f.com/just/a/path/also", true}, + {"foo.html", false}, + {"../dir/", false}, +} + +func TestParseRequest(t *testing.T) { + for _, test := range parseRequestUrlTests { + _, err := ParseRequest(test.url) + valid := err == nil + if valid != test.expectedValid { + t.Errorf("Expected valid=%v for %q; got %v", test.expectedValid, test.url, valid) + } + } + + url, err := ParseRequest(pathThatLooksSchemeRelative) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + if url.Path != pathThatLooksSchemeRelative { + t.Errorf("Expected path %q; got %q", pathThatLooksSchemeRelative, url.Path) + } +} + +func DoTestString(t *testing.T, parse func(string) (*URL, error), name string, tests []URLTest) { + for _, tt := range tests { + u, err := parse(tt.in) + if err != nil { + t.Errorf("%s(%q) returned error %s", name, tt.in, err) + continue + } + expected := tt.in + if len(tt.roundtrip) > 0 { + expected = tt.roundtrip + } + s := u.String() + if s != expected { + t.Errorf("%s(%q).String() == %q (expected %q)", name, tt.in, s, expected) + } + } +} + +func TestURLString(t *testing.T) { + DoTestString(t, Parse, "Parse", urltests) + DoTestString(t, Parse, "Parse", urlnofragtests) + DoTestString(t, ParseWithReference, "ParseWithReference", urltests) + DoTestString(t, ParseWithReference, "ParseWithReference", urlfragtests) +} + +type EscapeTest struct { + in string + out string + err error +} + +var unescapeTests = []EscapeTest{ + { + "", + "", + nil, + }, + { + "abc", + "abc", + nil, + }, + { + "1%41", + "1A", + nil, + }, + { + "1%41%42%43", + "1ABC", + nil, + }, + { + "%4a", + "J", + nil, + }, + { + "%6F", + "o", + nil, + }, + { + "%", // not enough characters after % + "", + EscapeError("%"), + }, + { + "%a", // not enough characters after % + "", + EscapeError("%a"), + }, + { + "%1", // not enough characters after % + "", + EscapeError("%1"), + }, + { + "123%45%6", // not enough characters after % + "", + EscapeError("%6"), + }, + { + "%zzzzz", // invalid hex digits + "", + EscapeError("%zz"), + }, +} + +func TestUnescape(t *testing.T) { + for _, tt := range unescapeTests { + actual, err := QueryUnescape(tt.in) + if actual != tt.out || (err != nil) != (tt.err != nil) { + t.Errorf("QueryUnescape(%q) = %q, %s; want %q, %s", tt.in, actual, err, tt.out, tt.err) + } + } +} + +var escapeTests = []EscapeTest{ + { + "", + "", + nil, + }, + { + "abc", + "abc", + nil, + }, + { + "one two", + "one+two", + nil, + }, + { + "10%", + "10%25", + nil, + }, + { + " ?&=#+%!<>#\"{}|\\^[]`☺\t", + "+%3F%26%3D%23%2B%25!%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09", + nil, + }, +} + +func TestEscape(t *testing.T) { + for _, tt := range escapeTests { + actual := QueryEscape(tt.in) + if tt.out != actual { + t.Errorf("QueryEscape(%q) = %q, want %q", tt.in, actual, tt.out) + } + + // for bonus points, verify that escape:unescape is an identity. + roundtrip, err := QueryUnescape(actual) + if roundtrip != tt.in || err != nil { + t.Errorf("QueryUnescape(%q) = %q, %s; want %q, %s", actual, roundtrip, err, tt.in, "[no error]") + } + } +} + +//var userinfoTests = []UserinfoTest{ +// {"user", "password", "user:password"}, +// {"foo:bar", "~!@#$%^&*()_+{}|[]\\-=`:;'\"<>?,./", +// "foo%3Abar:~!%40%23$%25%5E&*()_+%7B%7D%7C%5B%5D%5C-=%60%3A;'%22%3C%3E?,.%2F"}, +//} + +type EncodeQueryTest struct { + m Values + expected string + expected1 string +} + +var encodeQueryTests = []EncodeQueryTest{ + {nil, "", ""}, + {Values{"q": {"puppies"}, "oe": {"utf8"}}, "q=puppies&oe=utf8", "oe=utf8&q=puppies"}, + {Values{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7", "q=dogs&q=%26&q=7"}, +} + +func TestEncodeQuery(t *testing.T) { + for _, tt := range encodeQueryTests { + if q := tt.m.Encode(); q != tt.expected && q != tt.expected1 { + t.Errorf(`EncodeQuery(%+v) = %q, want %q`, tt.m, q, tt.expected) + } + } +} + +var resolvePathTests = []struct { + base, ref, expected string +}{ + {"a/b", ".", "a/"}, + {"a/b", "c", "a/c"}, + {"a/b", "..", ""}, + {"a/", "..", ""}, + {"a/", "../..", ""}, + {"a/b/c", "..", "a/"}, + {"a/b/c", "../d", "a/d"}, + {"a/b/c", ".././d", "a/d"}, + {"a/b", "./..", ""}, + {"a/./b", ".", "a/./"}, + {"a/../", ".", "a/../"}, + {"a/.././b", "c", "a/.././c"}, +} + +func TestResolvePath(t *testing.T) { + for _, test := range resolvePathTests { + got := resolvePath(test.base, test.ref) + if got != test.expected { + t.Errorf("For %q + %q got %q; expected %q", test.base, test.ref, got, test.expected) + } + } +} + +var resolveReferenceTests = []struct { + base, rel, expected string +}{ + // Absolute URL references + {"http://foo.com?a=b", "https://bar.com/", "https://bar.com/"}, + {"http://foo.com/", "https://bar.com/?a=b", "https://bar.com/?a=b"}, + {"http://foo.com/bar", "mailto:foo@example.com", "mailto:foo@example.com"}, + + // Path-absolute references + {"http://foo.com/bar", "/baz", "http://foo.com/baz"}, + {"http://foo.com/bar?a=b#f", "/baz", "http://foo.com/baz"}, + {"http://foo.com/bar?a=b", "/baz?c=d", "http://foo.com/baz?c=d"}, + + // Scheme-relative + {"https://foo.com/bar?a=b", "//bar.com/quux", "https://bar.com/quux"}, + + // Path-relative references: + + // ... current directory + {"http://foo.com", ".", "http://foo.com/"}, + {"http://foo.com/bar", ".", "http://foo.com/"}, + {"http://foo.com/bar/", ".", "http://foo.com/bar/"}, + + // ... going down + {"http://foo.com", "bar", "http://foo.com/bar"}, + {"http://foo.com/", "bar", "http://foo.com/bar"}, + {"http://foo.com/bar/baz", "quux", "http://foo.com/bar/quux"}, + + // ... going up + {"http://foo.com/bar/baz", "../quux", "http://foo.com/quux"}, + {"http://foo.com/bar/baz", "../../../../../quux", "http://foo.com/quux"}, + {"http://foo.com/bar", "..", "http://foo.com/"}, + {"http://foo.com/bar/baz", "./..", "http://foo.com/"}, + + // "." and ".." in the base aren't special + {"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/./dotdot/../baz"}, + + // Triple dot isn't special + {"http://foo.com/bar", "...", "http://foo.com/..."}, + + // Fragment + {"http://foo.com/bar", ".#frag", "http://foo.com/#frag"}, +} + +func TestResolveReference(t *testing.T) { + mustParse := func(url string) *URL { + u, err := ParseWithReference(url) + if err != nil { + t.Fatalf("Expected URL to parse: %q, got error: %v", url, err) + } + return u + } + for _, test := range resolveReferenceTests { + base := mustParse(test.base) + rel := mustParse(test.rel) + url := base.ResolveReference(rel) + urlStr := url.String() + if urlStr != test.expected { + t.Errorf("Resolving %q + %q != %q; got %q", test.base, test.rel, test.expected, urlStr) + } + } + + // Test that new instances are returned. + base := mustParse("http://foo.com/") + abs := base.ResolveReference(mustParse(".")) + if base == abs { + t.Errorf("Expected no-op reference to return new URL instance.") + } + barRef := mustParse("http://bar.com/") + abs = base.ResolveReference(barRef) + if abs == barRef { + t.Errorf("Expected resolution of absolute reference to return new URL instance.") + } + + // Test the convenience wrapper too + base = mustParse("http://foo.com/path/one/") + abs, _ = base.Parse("../two") + expected := "http://foo.com/path/two" + if abs.String() != expected { + t.Errorf("Parse wrapper got %q; expected %q", abs.String(), expected) + } + _, err := base.Parse("") + if err == nil { + t.Errorf("Expected an error from Parse wrapper parsing an empty string.") + } + + // Ensure Opaque resets the URL. + base = mustParse("scheme://user@foo.com/bar") + abs = base.ResolveReference(&URL{Opaque: "opaque"}) + want := mustParse("scheme:opaque") + if *abs != *want { + t.Errorf("ResolveReference failed to resolve opaque URL: want %#v, got %#v", abs, want) + } +} + +func TestResolveReferenceOpaque(t *testing.T) { + mustParse := func(url string) *URL { + u, err := ParseWithReference(url) + if err != nil { + t.Fatalf("Expected URL to parse: %q, got error: %v", url, err) + } + return u + } + for _, test := range resolveReferenceTests { + base := mustParse(test.base) + rel := mustParse(test.rel) + url := base.ResolveReference(rel) + urlStr := url.String() + if urlStr != test.expected { + t.Errorf("Resolving %q + %q != %q; got %q", test.base, test.rel, test.expected, urlStr) + } + } + + // Test that new instances are returned. + base := mustParse("http://foo.com/") + abs := base.ResolveReference(mustParse(".")) + if base == abs { + t.Errorf("Expected no-op reference to return new URL instance.") + } + barRef := mustParse("http://bar.com/") + abs = base.ResolveReference(barRef) + if abs == barRef { + t.Errorf("Expected resolution of absolute reference to return new URL instance.") + } + + // Test the convenience wrapper too + base = mustParse("http://foo.com/path/one/") + abs, _ = base.Parse("../two") + expected := "http://foo.com/path/two" + if abs.String() != expected { + t.Errorf("Parse wrapper got %q; expected %q", abs.String(), expected) + } + _, err := base.Parse("") + if err == nil { + t.Errorf("Expected an error from Parse wrapper parsing an empty string.") + } + +} + +func TestQueryValues(t *testing.T) { + u, _ := Parse("http://x.com?foo=bar&bar=1&bar=2") + v := u.Query() + if len(v) != 2 { + t.Errorf("got %d keys in Query values, want 2", len(v)) + } + if g, e := v.Get("foo"), "bar"; g != e { + t.Errorf("Get(foo) = %q, want %q", g, e) + } + // Case sensitive: + if g, e := v.Get("Foo"), ""; g != e { + t.Errorf("Get(Foo) = %q, want %q", g, e) + } + if g, e := v.Get("bar"), "1"; g != e { + t.Errorf("Get(bar) = %q, want %q", g, e) + } + if g, e := v.Get("baz"), ""; g != e { + t.Errorf("Get(baz) = %q, want %q", g, e) + } + v.Del("bar") + if g, e := v.Get("bar"), ""; g != e { + t.Errorf("second Get(bar) = %q, want %q", g, e) + } +} + +type parseTest struct { + query string + out Values +} + +var parseTests = []parseTest{ + { + query: "a=1&b=2", + out: Values{"a": []string{"1"}, "b": []string{"2"}}, + }, + { + query: "a=1&a=2&a=banana", + out: Values{"a": []string{"1", "2", "banana"}}, + }, + { + query: "ascii=%3Ckey%3A+0x90%3E", + out: Values{"ascii": []string{"<key: 0x90>"}}, + }, + { + query: "a=1;b=2", + out: Values{"a": []string{"1"}, "b": []string{"2"}}, + }, + { + query: "a=1&a=2;a=banana", + out: Values{"a": []string{"1", "2", "banana"}}, + }, +} + +func TestParseQuery(t *testing.T) { + for i, test := range parseTests { + form, err := ParseQuery(test.query) + if err != nil { + t.Errorf("test %d: Unexpected error: %v", i, err) + continue + } + if len(form) != len(test.out) { + t.Errorf("test %d: len(form) = %d, want %d", i, len(form), len(test.out)) + } + for k, evs := range test.out { + vs, ok := form[k] + if !ok { + t.Errorf("test %d: Missing key %q", i, k) + continue + } + if len(vs) != len(evs) { + t.Errorf("test %d: len(form[%q]) = %d, want %d", i, k, len(vs), len(evs)) + continue + } + for j, ev := range evs { + if v := vs[j]; v != ev { + t.Errorf("test %d: form[%q][%d] = %q, want %q", i, k, j, v, ev) + } + } + } + } +} + +type RequestURITest struct { + url *URL + out string +} + +var requritests = []RequestURITest{ + { + &URL{ + Scheme: "http", + Host: "example.com", + Path: "", + }, + "/", + }, + { + &URL{ + Scheme: "http", + Host: "example.com", + Path: "/a b", + }, + "/a%20b", + }, + { + &URL{ + Scheme: "http", + Host: "example.com", + Path: "/a b", + RawQuery: "q=go+language", + }, + "/a%20b?q=go+language", + }, + { + &URL{ + Scheme: "myschema", + Opaque: "opaque", + }, + "opaque", + }, + { + &URL{ + Scheme: "myschema", + Opaque: "opaque", + RawQuery: "q=go+language", + }, + "opaque?q=go+language", + }, +} + +func TestRequestURI(t *testing.T) { + for _, tt := range requritests { + s := tt.url.RequestURI() + if s != tt.out { + t.Errorf("%#v.RequestURI() == %q (expected %q)", tt.url, s, tt.out) + } + } +} |