diff options
Diffstat (limited to 'src/pkg/net')
169 files changed, 13730 insertions, 4841 deletions
diff --git a/src/pkg/net/cgo_bsd.go b/src/pkg/net/cgo_bsd.go index 63750f7a3..3b38e3d83 100644 --- a/src/pkg/net/cgo_bsd.go +++ b/src/pkg/net/cgo_bsd.go @@ -11,6 +11,6 @@ package net */ import "C" -func cgoAddrInfoMask() C.int { - return C.AI_MASK +func cgoAddrInfoFlags() C.int { + return (C.AI_CANONNAME | C.AI_V4MAPPED | C.AI_ALL) & C.AI_MASK } diff --git a/src/pkg/net/cgo_linux.go b/src/pkg/net/cgo_linux.go index 8d4413d2d..f6cefa89a 100644 --- a/src/pkg/net/cgo_linux.go +++ b/src/pkg/net/cgo_linux.go @@ -9,6 +9,12 @@ package net */ import "C" -func cgoAddrInfoMask() C.int { +func cgoAddrInfoFlags() C.int { + // NOTE(rsc): In theory there are approximately balanced + // arguments for and against including AI_ADDRCONFIG + // in the flags (it includes IPv4 results only on IPv4 systems, + // and similarly for IPv6), but in practice setting it causes + // getaddrinfo to return the wrong canonical name on Linux. + // So definitely leave it out. return C.AI_CANONNAME | C.AI_V4MAPPED | C.AI_ALL } diff --git a/src/pkg/net/cgo_netbsd.go b/src/pkg/net/cgo_netbsd.go new file mode 100644 index 000000000..aeaf8e568 --- /dev/null +++ b/src/pkg/net/cgo_netbsd.go @@ -0,0 +1,14 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <netdb.h> +*/ +import "C" + +func cgoAddrInfoFlags() C.int { + return C.AI_CANONNAME +} diff --git a/src/pkg/net/cgo_openbsd.go b/src/pkg/net/cgo_openbsd.go new file mode 100644 index 000000000..aeaf8e568 --- /dev/null +++ b/src/pkg/net/cgo_openbsd.go @@ -0,0 +1,14 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <netdb.h> +*/ +import "C" + +func cgoAddrInfoFlags() C.int { + return C.AI_CANONNAME +} diff --git a/src/pkg/net/cgo_unix.go b/src/pkg/net/cgo_unix.go index 36a3f3d34..7476140eb 100644 --- a/src/pkg/net/cgo_unix.go +++ b/src/pkg/net/cgo_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux +// +build darwin freebsd linux netbsd openbsd package net @@ -81,13 +81,7 @@ func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err error, complet var res *C.struct_addrinfo var hints C.struct_addrinfo - // NOTE(rsc): In theory there are approximately balanced - // arguments for and against including AI_ADDRCONFIG - // in the flags (it includes IPv4 results only on IPv4 systems, - // and similarly for IPv6), but in practice setting it causes - // getaddrinfo to return the wrong canonical name on Linux. - // So definitely leave it out. - hints.ai_flags = (C.AI_ALL | C.AI_V4MAPPED | C.AI_CANONNAME) & cgoAddrInfoMask() + hints.ai_flags = cgoAddrInfoFlags() h := C.CString(name) defer C.free(unsafe.Pointer(h)) diff --git a/src/pkg/net/conn_test.go b/src/pkg/net/conn_test.go new file mode 100644 index 000000000..fdb90862f --- /dev/null +++ b/src/pkg/net/conn_test.go @@ -0,0 +1,114 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements API tests across platforms and will never have a build +// tag. + +package net + +import ( + "os" + "runtime" + "testing" + "time" +) + +var connTests = []struct { + net string + addr string +}{ + {"tcp", "127.0.0.1:0"}, + {"unix", testUnixAddr()}, + {"unixpacket", testUnixAddr()}, +} + +// someTimeout is used just to test that net.Conn implementations +// don't explode when their SetFooDeadline methods are called. +// It isn't actually used for testing timeouts. +const someTimeout = 10 * time.Second + +func TestConnAndListener(t *testing.T) { + for _, tt := range connTests { + switch tt.net { + case "unix", "unixpacket": + switch runtime.GOOS { + case "plan9", "windows": + continue + } + if tt.net == "unixpacket" && runtime.GOOS != "linux" { + continue + } + } + + ln, err := Listen(tt.net, tt.addr) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer func(ln Listener, net, addr string) { + ln.Close() + switch net { + case "unix", "unixpacket": + os.Remove(addr) + } + }(ln, tt.net, tt.addr) + ln.Addr() + + done := make(chan int) + go transponder(t, ln, done) + + c, err := Dial(tt.net, ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c.Close() + c.LocalAddr() + c.RemoteAddr() + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) + + if _, err := c.Write([]byte("CONN TEST")); err != nil { + t.Fatalf("Conn.Write failed: %v", err) + } + rb := make([]byte, 128) + if _, err := c.Read(rb); err != nil { + t.Fatalf("Conn.Read failed: %v", err) + } + + <-done + } +} + +func transponder(t *testing.T, ln Listener, done chan<- int) { + defer func() { done <- 1 }() + + switch ln := ln.(type) { + case *TCPListener: + ln.SetDeadline(time.Now().Add(someTimeout)) + case *UnixListener: + ln.SetDeadline(time.Now().Add(someTimeout)) + } + c, err := ln.Accept() + if err != nil { + t.Errorf("Listener.Accept failed: %v", err) + return + } + defer c.Close() + c.LocalAddr() + c.RemoteAddr() + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) + + b := make([]byte, 128) + n, err := c.Read(b) + if err != nil { + t.Errorf("Conn.Read failed: %v", err) + return + } + if _, err := c.Write(b[:n]); err != nil { + t.Errorf("Conn.Write failed: %v", err) + return + } +} diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index 10ca5faf7..22e1e7dd8 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -5,15 +5,91 @@ package net import ( + "errors" "time" ) -func parseDialNetwork(net string) (afnet string, proto int, err error) { +// A DialOption modifies a DialOpt call. +type DialOption interface { + dialOption() +} + +var ( + // TCP is a dial option to dial with TCP (over IPv4 or IPv6). + TCP = Network("tcp") + + // UDP is a dial option to dial with UDP (over IPv4 or IPv6). + UDP = Network("udp") +) + +// Network returns a DialOption to dial using the given network. +// +// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), +// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" +// (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and +// "unixpacket". +// +// For IP networks, net must be "ip", "ip4" or "ip6" followed +// by a colon and a protocol number or name, such as +// "ipv4:1" or "ip6:ospf". +func Network(net string) DialOption { + return dialNetwork(net) +} + +type dialNetwork string + +func (dialNetwork) dialOption() {} + +// Deadline returns a DialOption to fail a dial that doesn't +// complete before t. +func Deadline(t time.Time) DialOption { + return dialDeadline(t) +} + +// Timeout returns a DialOption to fail a dial that doesn't +// complete within the provided duration. +func Timeout(d time.Duration) DialOption { + return dialDeadline(time.Now().Add(d)) +} + +type dialDeadline time.Time + +func (dialDeadline) dialOption() {} + +type tcpFastOpen struct{} + +func (tcpFastOpen) dialOption() {} + +// TODO(bradfitz): implement this (golang.org/issue/4842) and unexport this. +// +// TCPFastTimeout returns an option to use TCP Fast Open (TFO) when +// doing this dial. It is only valid for use with TCP connections. +// Data sent over a TFO connection may be processed by the peer +// multiple times, so should be used with caution. +func todo_TCPFastTimeout() DialOption { + return tcpFastOpen{} +} + +type localAddrOption struct { + la Addr +} + +func (localAddrOption) dialOption() {} + +// LocalAddress returns a dial option to perform a dial with the +// provided local address. The address must be of a compatible type +// for the network being dialed. +func LocalAddress(addr Addr) DialOption { + return localAddrOption{addr} +} + +func parseNetwork(net string) (afnet string, proto int, err error) { i := last(net, ':') if i < 0 { // no colon switch net { case "tcp", "tcp4", "tcp6": case "udp", "udp4", "udp6": + case "ip", "ip4", "ip6": case "unix", "unixgram", "unixpacket": default: return "", 0, UnknownNetworkError(net) @@ -36,40 +112,27 @@ func parseDialNetwork(net string) (afnet string, proto int, err error) { return "", 0, UnknownNetworkError(net) } -func resolveNetAddr(op, net, addr string) (afnet string, a Addr, err error) { - afnet, _, err = parseDialNetwork(net) +func resolveAddr(op, net, addr string, deadline time.Time) (Addr, error) { + afnet, _, err := parseNetwork(net) if err != nil { - return "", nil, &OpError{op, net, nil, err} + return nil, &OpError{op, net, nil, err} } if op == "dial" && addr == "" { - return "", nil, &OpError{op, net, nil, errMissingAddress} + return nil, &OpError{op, net, nil, errMissingAddress} } switch afnet { - case "tcp", "tcp4", "tcp6": - if addr != "" { - a, err = ResolveTCPAddr(afnet, addr) - } - case "udp", "udp4", "udp6": - if addr != "" { - a, err = ResolveUDPAddr(afnet, addr) - } - case "ip", "ip4", "ip6": - if addr != "" { - a, err = ResolveIPAddr(afnet, addr) - } case "unix", "unixgram", "unixpacket": - if addr != "" { - a, err = ResolveUnixAddr(afnet, addr) - } + return ResolveUnixAddr(afnet, addr) } - return + return resolveInternetAddr(afnet, addr, deadline) } // Dial connects to the address addr on the network net. // // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), // "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" -// (IPv4-only), "ip6" (IPv6-only), "unix" and "unixpacket". +// (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and +// "unixpacket". // // For TCP and UDP networks, addresses have the form host:port. // If host is a literal IPv6 address, it must be enclosed @@ -81,7 +144,7 @@ func resolveNetAddr(op, net, addr string) (afnet string, a Addr, err error) { // Dial("tcp", "google.com:80") // Dial("tcp", "[de:ad:be:ef::ca:fe]:80") // -// For IP networks, addr must be "ip", "ip4" or "ip6" followed +// For IP networks, net must be "ip", "ip4" or "ip6" followed // by a colon and a protocol number or name. // // Examples: @@ -89,25 +152,71 @@ func resolveNetAddr(op, net, addr string) (afnet string, a Addr, err error) { // Dial("ip6:ospf", "::1") // func Dial(net, addr string) (Conn, error) { - _, addri, err := resolveNetAddr("dial", net, addr) + return DialOpt(addr, dialNetwork(net)) +} + +func netFromOptions(opts []DialOption) string { + for _, opt := range opts { + if p, ok := opt.(dialNetwork); ok { + return string(p) + } + } + return "tcp" +} + +func deadlineFromOptions(opts []DialOption) time.Time { + for _, opt := range opts { + if d, ok := opt.(dialDeadline); ok { + return time.Time(d) + } + } + return noDeadline +} + +var noLocalAddr Addr // nil + +func localAddrFromOptions(opts []DialOption) Addr { + for _, opt := range opts { + if o, ok := opt.(localAddrOption); ok { + return o.la + } + } + return noLocalAddr +} + +// DialOpt dials addr using the provided options. +// If no options are provided, DialOpt(addr) is equivalent +// to Dial("tcp", addr). See Dial for the syntax of addr. +func DialOpt(addr string, opts ...DialOption) (Conn, error) { + net := netFromOptions(opts) + deadline := deadlineFromOptions(opts) + la := localAddrFromOptions(opts) + ra, err := resolveAddr("dial", net, addr, deadline) if err != nil { return nil, err } - return dialAddr(net, addr, addri) + return dial(net, addr, la, ra, deadline) } -func dialAddr(net, addr string, addri Addr) (c Conn, err error) { - switch ra := addri.(type) { +func dial(net, addr string, la, ra Addr, deadline time.Time) (c Conn, err error) { + if la != nil && la.Network() != ra.Network() { + return nil, &OpError{"dial", net, ra, errors.New("mismatched local addr type " + la.Network())} + } + switch ra := ra.(type) { case *TCPAddr: - c, err = DialTCP(net, nil, ra) + la, _ := la.(*TCPAddr) + c, err = dialTCP(net, la, ra, deadline) case *UDPAddr: - c, err = DialUDP(net, nil, ra) + la, _ := la.(*UDPAddr) + c, err = dialUDP(net, la, ra, deadline) case *IPAddr: - c, err = DialIP(net, nil, ra) + la, _ := la.(*IPAddr) + c, err = dialIP(net, la, ra, deadline) case *UnixAddr: - c, err = DialUnix(net, nil, ra) + la, _ := la.(*UnixAddr) + c, err = dialUnix(net, la, ra, deadline) default: - err = &OpError{"dial", net + " " + addr, nil, UnknownNetworkError(net)} + err = &OpError{"dial", net + " " + addr, ra, UnknownNetworkError(net)} } if err != nil { return nil, err @@ -118,10 +227,14 @@ func dialAddr(net, addr string, addri Addr) (c Conn, err error) { // DialTimeout acts like Dial but takes a timeout. // The timeout includes name resolution, if required. func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) { - // TODO(bradfitz): the timeout should be pushed down into the - // net package's event loop, so on timeout to dead hosts we - // don't have a goroutine sticking around for the default of - // ~3 minutes. + return dialTimeout(net, addr, timeout) +} + +// dialTimeoutRace is the old implementation of DialTimeout, still used +// on operating systems where the deadline hasn't been pushed down +// into the pollserver. +// TODO: fix this on plan9. +func dialTimeoutRace(net, addr string, timeout time.Duration) (Conn, error) { t := time.NewTimer(timeout) defer t.Stop() type pair struct { @@ -131,30 +244,30 @@ func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) { ch := make(chan pair, 1) resolvedAddr := make(chan Addr, 1) go func() { - _, addri, err := resolveNetAddr("dial", net, addr) + ra, err := resolveAddr("dial", net, addr, noDeadline) if err != nil { ch <- pair{nil, err} return } - resolvedAddr <- addri // in case we need it for OpError - c, err := dialAddr(net, addr, addri) + resolvedAddr <- ra // in case we need it for OpError + c, err := dial(net, addr, noLocalAddr, ra, noDeadline) ch <- pair{c, err} }() select { case <-t.C: // Try to use the real Addr in our OpError, if we resolved it // before the timeout. Otherwise we just use stringAddr. - var addri Addr + var ra Addr select { case a := <-resolvedAddr: - addri = a + ra = a default: - addri = &stringAddr{net, addr} + ra = &stringAddr{net, addr} } err := &OpError{ Op: "dial", Net: net, - Addr: addri, + Addr: ra, Err: &timeoutError{}, } return nil, err @@ -173,24 +286,16 @@ func (a stringAddr) String() string { return a.addr } // Listen announces on the local network address laddr. // The network string net must be a stream-oriented network: -// "tcp", "tcp4", "tcp6", or "unix", or "unixpacket". +// "tcp", "tcp4", "tcp6", "unix" or "unixpacket". func Listen(net, laddr string) (Listener, error) { - afnet, a, err := resolveNetAddr("listen", net, laddr) + la, err := resolveAddr("listen", net, laddr, noDeadline) if err != nil { return nil, err } - switch afnet { - case "tcp", "tcp4", "tcp6": - var la *TCPAddr - if a != nil { - la = a.(*TCPAddr) - } + switch la := la.(type) { + case *TCPAddr: return ListenTCP(net, la) - case "unix", "unixpacket": - var la *UnixAddr - if a != nil { - la = a.(*UnixAddr) - } + case *UnixAddr: return ListenUnix(net, la) } return nil, UnknownNetworkError(net) @@ -199,30 +304,18 @@ func Listen(net, laddr string) (Listener, error) { // ListenPacket announces on the local network address laddr. // The network string net must be a packet-oriented network: // "udp", "udp4", "udp6", "ip", "ip4", "ip6" or "unixgram". -func ListenPacket(net, addr string) (PacketConn, error) { - afnet, a, err := resolveNetAddr("listen", net, addr) +func ListenPacket(net, laddr string) (PacketConn, error) { + la, err := resolveAddr("listen", net, laddr, noDeadline) if err != nil { return nil, err } - switch afnet { - case "udp", "udp4", "udp6": - var la *UDPAddr - if a != nil { - la = a.(*UDPAddr) - } + switch la := la.(type) { + case *UDPAddr: return ListenUDP(net, la) - case "ip", "ip4", "ip6": - var la *IPAddr - if a != nil { - la = a.(*IPAddr) - } + case *IPAddr: return ListenIP(net, la) - case "unixgram": - var la *UnixAddr - if a != nil { - la = a.(*UnixAddr) - } - return DialUnix(net, la, nil) + case *UnixAddr: + return ListenUnixgram(net, la) } return nil, UnknownNetworkError(net) } diff --git a/src/pkg/net/dial_test.go b/src/pkg/net/dial_test.go index 7212087fe..2303e8fa4 100644 --- a/src/pkg/net/dial_test.go +++ b/src/pkg/net/dial_test.go @@ -7,6 +7,9 @@ package net import ( "flag" "fmt" + "io" + "os" + "reflect" "regexp" "runtime" "testing" @@ -55,7 +58,7 @@ func TestDialTimeout(t *testing.T) { // on our 386 builder, this Dial succeeds, connecting // to an IIS web server somewhere. The data center // or VM or firewall must be stealing the TCP connection. - // + // // IANA Service Name and Transport Protocol Port Number Registry // <http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xml> go func() { @@ -72,8 +75,7 @@ func TestDialTimeout(t *testing.T) { // by default. FreeBSD likely works, but is untested. // TODO(rsc): // The timeout never happens on Windows. Why? Issue 3016. - t.Logf("skipping test on %q; untested.", runtime.GOOS) - return + t.Skipf("skipping test on %q; untested.", runtime.GOOS) } connected := 0 @@ -105,8 +107,7 @@ func TestDialTimeout(t *testing.T) { func TestSelfConnect(t *testing.T) { if runtime.GOOS == "windows" { // TODO(brainman): do not know why it hangs. - t.Logf("skipping known-broken test on windows") - return + t.Skip("skipping known-broken test on windows") } // Test that Dial does not honor self-connects. // See the comment in DialTCP. @@ -130,7 +131,7 @@ func TestSelfConnect(t *testing.T) { n = 1000 } switch runtime.GOOS { - case "darwin", "freebsd", "openbsd", "windows": + case "darwin", "freebsd", "netbsd", "openbsd", "plan9", "windows": // Non-Linux systems take a long time to figure // out that there is nothing listening on localhost. n = 100 @@ -222,3 +223,104 @@ func TestDialError(t *testing.T) { } } } + +var invalidDialAndListenArgTests = []struct { + net string + addr string + err error +}{ + {"foo", "bar", &OpError{Op: "dial", Net: "foo", Addr: nil, Err: UnknownNetworkError("foo")}}, + {"baz", "", &OpError{Op: "listen", Net: "baz", Addr: nil, Err: UnknownNetworkError("baz")}}, + {"tcp", "", &OpError{Op: "dial", Net: "tcp", Addr: nil, Err: errMissingAddress}}, +} + +func TestInvalidDialAndListenArgs(t *testing.T) { + for _, tt := range invalidDialAndListenArgTests { + var err error + switch tt.err.(*OpError).Op { + case "dial": + _, err = Dial(tt.net, tt.addr) + case "listen": + _, err = Listen(tt.net, tt.addr) + } + if !reflect.DeepEqual(tt.err, err) { + t.Fatalf("got %#v; expected %#v", err, tt.err) + } + } +} + +func TestDialTimeoutFDLeak(t *testing.T) { + if runtime.GOOS != "linux" { + // TODO(bradfitz): test on other platforms + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t) + defer ln.Close() + + type connErr struct { + conn Conn + err error + } + dials := listenerBacklog + 100 + // used to be listenerBacklog + 5, but was found to be unreliable, issue 4384. + maxGoodConnect := listenerBacklog + runtime.NumCPU()*10 + resc := make(chan connErr) + for i := 0; i < dials; i++ { + go func() { + conn, err := DialTimeout("tcp", ln.Addr().String(), 500*time.Millisecond) + resc <- connErr{conn, err} + }() + } + + var firstErr string + var ngood int + var toClose []io.Closer + for i := 0; i < dials; i++ { + ce := <-resc + if ce.err == nil { + ngood++ + if ngood > maxGoodConnect { + t.Errorf("%d good connects; expected at most %d", ngood, maxGoodConnect) + } + toClose = append(toClose, ce.conn) + continue + } + err := ce.err + if firstErr == "" { + firstErr = err.Error() + } else if err.Error() != firstErr { + t.Fatalf("inconsistent error messages: first was %q, then later %q", firstErr, err) + } + } + for _, c := range toClose { + c.Close() + } + for i := 0; i < 100; i++ { + if got := numFD(); got < dials { + // Test passes. + return + } + time.Sleep(10 * time.Millisecond) + } + if got := numFD(); got >= dials { + t.Errorf("num fds after %d timeouts = %d; want <%d", dials, got, dials) + } +} + +func numFD() int { + if runtime.GOOS == "linux" { + f, err := os.Open("/proc/self/fd") + if err != nil { + panic(err) + } + defer f.Close() + names, err := f.Readdirnames(0) + if err != nil { + panic(err) + } + return len(names) + } + // All tests using this should be skipped anyway, but: + panic("numFDs not implemented on " + runtime.GOOS) +} diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go index 03c449972..73a94f5bf 100644 --- a/src/pkg/net/dialgoogle_test.go +++ b/src/pkg/net/dialgoogle_test.go @@ -41,17 +41,6 @@ func doDial(t *testing.T, network, addr string) { fd.Close() } -func TestLookupCNAME(t *testing.T) { - if testing.Short() || !*testExternal { - t.Logf("skipping test to avoid external network") - return - } - cname, err := LookupCNAME("www.google.com") - if !strings.HasSuffix(cname, ".l.google.com.") || err != nil { - t.Errorf(`LookupCNAME("www.google.com.") = %q, %v, want "*.l.google.com.", nil`, cname, err) - } -} - var googleaddrsipv4 = []string{ "%d.%d.%d.%d:80", "www.google.com:80", @@ -67,8 +56,7 @@ var googleaddrsipv4 = []string{ func TestDialGoogleIPv4(t *testing.T) { if testing.Short() || !*testExternal { - t.Logf("skipping test to avoid external network") - return + t.Skip("skipping test to avoid external network") } // Insert an actual IPv4 address for google.com @@ -123,12 +111,14 @@ var googleaddrsipv6 = []string{ func TestDialGoogleIPv6(t *testing.T) { if testing.Short() || !*testExternal { - t.Logf("skipping test to avoid external network") - return + t.Skip("skipping test to avoid external network") } // Only run tcp6 if the kernel will take it. - if !*testIPv6 || !supportsIPv6 { - return + if !supportsIPv6 { + t.Skip("skipping test; ipv6 is not supported") + } + if !*testIPv6 { + t.Skip("test disabled; use -ipv6 to enable") } // Insert an actual IPv6 address for ipv6.google.com diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go index e69cb3188..76b192645 100644 --- a/src/pkg/net/dnsclient.go +++ b/src/pkg/net/dnsclient.go @@ -183,7 +183,7 @@ func (s byPriorityWeight) Less(i, j int) bool { } // shuffleByWeight shuffles SRV records by weight using the algorithm -// described in RFC 2782. +// described in RFC 2782. func (addrs byPriorityWeight) shuffleByWeight() { sum := 0 for _, addr := range addrs { @@ -244,3 +244,8 @@ func (s byPref) sort() { } sort.Sort(s) } + +// An NS represents a single DNS NS record. +type NS struct { + Host string +} diff --git a/src/pkg/net/dnsclient_unix.go b/src/pkg/net/dnsclient_unix.go index 18c39360e..9e21bb4a0 100644 --- a/src/pkg/net/dnsclient_unix.go +++ b/src/pkg/net/dnsclient_unix.go @@ -237,24 +237,30 @@ func goLookupIP(name string) (addrs []IP, err error) { } var records []dnsRR var cname string - cname, records, err = lookup(name, dnsTypeA) - if err != nil { - return - } + var err4, err6 error + cname, records, err4 = lookup(name, dnsTypeA) addrs = convertRR_A(records) if cname != "" { name = cname } - _, records, err = lookup(name, dnsTypeAAAA) - if err != nil && len(addrs) > 0 { - // Ignore error because A lookup succeeded. - err = nil + _, records, err6 = lookup(name, dnsTypeAAAA) + if err4 != nil && err6 == nil { + // Ignore A error because AAAA lookup succeeded. + err4 = nil } - if err != nil { - return + if err6 != nil && len(addrs) > 0 { + // Ignore AAAA error because A lookup succeeded. + err6 = nil } + if err4 != nil { + return nil, err4 + } + if err6 != nil { + return nil, err6 + } + addrs = append(addrs, convertRR_AAAA(records)...) - return + return addrs, nil } // goLookupCNAME is the native Go implementation of LookupCNAME. diff --git a/src/pkg/net/dnsconfig.go b/src/pkg/net/dnsconfig_unix.go index bb46cc900..bb46cc900 100644 --- a/src/pkg/net/dnsconfig.go +++ b/src/pkg/net/dnsconfig_unix.go diff --git a/src/pkg/net/dnsmsg.go b/src/pkg/net/dnsmsg.go index b6ebe1173..161afb2a5 100644 --- a/src/pkg/net/dnsmsg.go +++ b/src/pkg/net/dnsmsg.go @@ -618,7 +618,7 @@ func printStruct(any dnsStruct) string { s += name + "=" switch tag { case "ipv4": - i := val.(uint32) + i := *val.(*uint32) s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() case "ipv6": i := val.([]byte) diff --git a/src/pkg/net/example_test.go b/src/pkg/net/example_test.go index 1a1c2edfe..eefe84fa7 100644 --- a/src/pkg/net/example_test.go +++ b/src/pkg/net/example_test.go @@ -17,7 +17,7 @@ func ExampleListener() { log.Fatal(err) } for { - // Wait for a connection. + // Wait for a connection. conn, err := l.Accept() if err != nil { log.Fatal(err) diff --git a/src/pkg/net/fd_openbsd.go b/src/pkg/net/fd_bsd.go index 35d84c30e..8bb1ae538 100644 --- a/src/pkg/net/fd_openbsd.go +++ b/src/pkg/net/fd_bsd.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// +build freebsd netbsd openbsd + // Waiting for FDs via kqueue/kevent. package net @@ -31,6 +33,8 @@ func newpollster() (p *pollster, err error) { return p, nil } +// First return value is whether the pollServer should be woken up. +// This version always returns false. func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { // pollServer is locked. @@ -62,7 +66,9 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { return false, nil } -func (p *pollster) DelFD(fd int, mode int) { +// Return value is whether the pollServer should be woken up. +// This version always returns false. +func (p *pollster) DelFD(fd int, mode int) bool { // pollServer is locked. var kmode int @@ -75,6 +81,7 @@ func (p *pollster) DelFD(fd int, mode int) { // EV_DELETE - delete event from kqueue list syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE) syscall.Kevent(p.kq, p.kbuf[:], nil, nil) + return false } func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { diff --git a/src/pkg/net/fd_darwin.go b/src/pkg/net/fd_darwin.go index 3dd33edc2..382465ba6 100644 --- a/src/pkg/net/fd_darwin.go +++ b/src/pkg/net/fd_darwin.go @@ -32,6 +32,8 @@ func newpollster() (p *pollster, err error) { return p, nil } +// First return value is whether the pollServer should be woken up. +// This version always returns false. func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { // pollServer is locked. @@ -65,7 +67,9 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { return false, nil } -func (p *pollster) DelFD(fd int, mode int) { +// Return value is whether the pollServer should be woken up. +// This version always returns false. +func (p *pollster) DelFD(fd int, mode int) bool { // pollServer is locked. var kmode int @@ -80,6 +84,7 @@ func (p *pollster) DelFD(fd int, mode int) { // rather than waiting for real event syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE|syscall.EV_RECEIPT) syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil) + return false } func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { diff --git a/src/pkg/net/fd_freebsd.go b/src/pkg/net/fd_freebsd.go deleted file mode 100644 index 35d84c30e..000000000 --- a/src/pkg/net/fd_freebsd.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Waiting for FDs via kqueue/kevent. - -package net - -import ( - "os" - "syscall" -) - -type pollster struct { - kq int - eventbuf [10]syscall.Kevent_t - events []syscall.Kevent_t - - // An event buffer for AddFD/DelFD. - // Must hold pollServer lock. - kbuf [1]syscall.Kevent_t -} - -func newpollster() (p *pollster, err error) { - p = new(pollster) - if p.kq, err = syscall.Kqueue(); err != nil { - return nil, os.NewSyscallError("kqueue", err) - } - syscall.CloseOnExec(p.kq) - p.events = p.eventbuf[0:0] - return p, nil -} - -func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { - // pollServer is locked. - - var kmode int - if mode == 'r' { - kmode = syscall.EVFILT_READ - } else { - kmode = syscall.EVFILT_WRITE - } - ev := &p.kbuf[0] - // EV_ADD - add event to kqueue list - // EV_ONESHOT - delete the event the first time it triggers - flags := syscall.EV_ADD - if !repeat { - flags |= syscall.EV_ONESHOT - } - syscall.SetKevent(ev, fd, kmode, flags) - - n, err := syscall.Kevent(p.kq, p.kbuf[:], nil, nil) - if err != nil { - return false, os.NewSyscallError("kevent", err) - } - if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { - return false, os.NewSyscallError("kqueue phase error", err) - } - if ev.Data != 0 { - return false, syscall.Errno(int(ev.Data)) - } - return false, nil -} - -func (p *pollster) DelFD(fd int, mode int) { - // pollServer is locked. - - var kmode int - if mode == 'r' { - kmode = syscall.EVFILT_READ - } else { - kmode = syscall.EVFILT_WRITE - } - ev := &p.kbuf[0] - // EV_DELETE - delete event from kqueue list - syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE) - syscall.Kevent(p.kq, p.kbuf[:], nil, nil) -} - -func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { - var t *syscall.Timespec - for len(p.events) == 0 { - if nsec > 0 { - if t == nil { - t = new(syscall.Timespec) - } - *t = syscall.NsecToTimespec(nsec) - } - - s.Unlock() - n, err := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) - s.Lock() - - if err != nil { - if err == syscall.EINTR { - continue - } - return -1, 0, os.NewSyscallError("kevent", err) - } - if n == 0 { - return -1, 0, nil - } - p.events = p.eventbuf[:n] - } - ev := &p.events[0] - p.events = p.events[1:] - fd = int(ev.Ident) - if ev.Filter == syscall.EVFILT_READ { - mode = 'r' - } else { - mode = 'w' - } - return fd, mode, nil -} - -func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/src/pkg/net/fd_linux.go b/src/pkg/net/fd_linux.go index 085e42307..03679196d 100644 --- a/src/pkg/net/fd_linux.go +++ b/src/pkg/net/fd_linux.go @@ -51,6 +51,8 @@ func newpollster() (p *pollster, err error) { return p, nil } +// First return value is whether the pollServer should be woken up. +// This version always returns false. func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { // pollServer is locked. @@ -114,7 +116,9 @@ func (p *pollster) StopWaiting(fd int, bits uint) { } } -func (p *pollster) DelFD(fd int, mode int) { +// Return value is whether the pollServer should be woken up. +// This version always returns false. +func (p *pollster) DelFD(fd int, mode int) bool { // pollServer is locked. if mode == 'r' { @@ -133,6 +137,7 @@ func (p *pollster) DelFD(fd int, mode int) { i++ } } + return false } func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { diff --git a/src/pkg/net/fd_netbsd.go b/src/pkg/net/fd_netbsd.go deleted file mode 100644 index 35d84c30e..000000000 --- a/src/pkg/net/fd_netbsd.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Waiting for FDs via kqueue/kevent. - -package net - -import ( - "os" - "syscall" -) - -type pollster struct { - kq int - eventbuf [10]syscall.Kevent_t - events []syscall.Kevent_t - - // An event buffer for AddFD/DelFD. - // Must hold pollServer lock. - kbuf [1]syscall.Kevent_t -} - -func newpollster() (p *pollster, err error) { - p = new(pollster) - if p.kq, err = syscall.Kqueue(); err != nil { - return nil, os.NewSyscallError("kqueue", err) - } - syscall.CloseOnExec(p.kq) - p.events = p.eventbuf[0:0] - return p, nil -} - -func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { - // pollServer is locked. - - var kmode int - if mode == 'r' { - kmode = syscall.EVFILT_READ - } else { - kmode = syscall.EVFILT_WRITE - } - ev := &p.kbuf[0] - // EV_ADD - add event to kqueue list - // EV_ONESHOT - delete the event the first time it triggers - flags := syscall.EV_ADD - if !repeat { - flags |= syscall.EV_ONESHOT - } - syscall.SetKevent(ev, fd, kmode, flags) - - n, err := syscall.Kevent(p.kq, p.kbuf[:], nil, nil) - if err != nil { - return false, os.NewSyscallError("kevent", err) - } - if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { - return false, os.NewSyscallError("kqueue phase error", err) - } - if ev.Data != 0 { - return false, syscall.Errno(int(ev.Data)) - } - return false, nil -} - -func (p *pollster) DelFD(fd int, mode int) { - // pollServer is locked. - - var kmode int - if mode == 'r' { - kmode = syscall.EVFILT_READ - } else { - kmode = syscall.EVFILT_WRITE - } - ev := &p.kbuf[0] - // EV_DELETE - delete event from kqueue list - syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE) - syscall.Kevent(p.kq, p.kbuf[:], nil, nil) -} - -func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { - var t *syscall.Timespec - for len(p.events) == 0 { - if nsec > 0 { - if t == nil { - t = new(syscall.Timespec) - } - *t = syscall.NsecToTimespec(nsec) - } - - s.Unlock() - n, err := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) - s.Lock() - - if err != nil { - if err == syscall.EINTR { - continue - } - return -1, 0, os.NewSyscallError("kevent", err) - } - if n == 0 { - return -1, 0, nil - } - p.events = p.eventbuf[:n] - } - ev := &p.events[0] - p.events = p.events[1:] - fd = int(ev.Ident) - if ev.Filter == syscall.EVFILT_READ { - mode = 'r' - } else { - mode = 'w' - } - return fd, mode, nil -} - -func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/src/pkg/net/fd_plan9.go b/src/pkg/net/fd_plan9.go new file mode 100644 index 000000000..169087999 --- /dev/null +++ b/src/pkg/net/fd_plan9.go @@ -0,0 +1,129 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "io" + "os" + "syscall" + "time" +) + +// Network file descritor. +type netFD struct { + proto, name, dir string + ctl, data *os.File + laddr, raddr Addr +} + +var canCancelIO = true // used for testing current package + +func sysInit() { +} + +func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { + // On plan9, use the relatively inefficient + // goroutine-racing implementation. + return dialTimeoutRace(net, addr, timeout) +} + +func newFD(proto, name string, ctl, data *os.File, laddr, raddr Addr) *netFD { + return &netFD{proto, name, "/net/" + proto + "/" + name, ctl, data, laddr, raddr} +} + +func (fd *netFD) ok() bool { return fd != nil && fd.ctl != nil } + +func (fd *netFD) Read(b []byte) (n int, err error) { + if !fd.ok() || fd.data == nil { + return 0, syscall.EINVAL + } + n, err = fd.data.Read(b) + if fd.proto == "udp" && err == io.EOF { + n = 0 + err = nil + } + return +} + +func (fd *netFD) Write(b []byte) (n int, err error) { + if !fd.ok() || fd.data == nil { + return 0, syscall.EINVAL + } + return fd.data.Write(b) +} + +func (fd *netFD) CloseRead() error { + if !fd.ok() { + return syscall.EINVAL + } + return syscall.EPLAN9 +} + +func (fd *netFD) CloseWrite() error { + if !fd.ok() { + return syscall.EINVAL + } + return syscall.EPLAN9 +} + +func (fd *netFD) Close() error { + if !fd.ok() { + return syscall.EINVAL + } + err := fd.ctl.Close() + if fd.data != nil { + if err1 := fd.data.Close(); err1 != nil && err == nil { + err = err1 + } + } + fd.ctl = nil + fd.data = nil + return err +} + +// This method is only called via Conn. +func (fd *netFD) dup() (*os.File, error) { + if !fd.ok() || fd.data == nil { + return nil, syscall.EINVAL + } + return fd.file(fd.data, fd.dir+"/data") +} + +func (l *TCPListener) dup() (*os.File, error) { + if !l.fd.ok() { + return nil, syscall.EINVAL + } + return l.fd.file(l.fd.ctl, l.fd.dir+"/ctl") +} + +func (fd *netFD) file(f *os.File, s string) (*os.File, error) { + syscall.ForkLock.RLock() + dfd, err := syscall.Dup(int(f.Fd()), -1) + syscall.ForkLock.RUnlock() + if err != nil { + return nil, &OpError{"dup", s, fd.laddr, err} + } + return os.NewFile(uintptr(dfd), s), nil +} + +func setDeadline(fd *netFD, t time.Time) error { + return syscall.EPLAN9 +} + +func setReadDeadline(fd *netFD, t time.Time) error { + return syscall.EPLAN9 +} + +func setWriteDeadline(fd *netFD, t time.Time) error { + return syscall.EPLAN9 +} + +func setReadBuffer(fd *netFD, bytes int) error { + return syscall.EPLAN9 +} + +func setWriteBuffer(fd *netFD, bytes int) error { + return syscall.EPLAN9 +} diff --git a/src/pkg/net/fd_posix_test.go b/src/pkg/net/fd_posix_test.go new file mode 100644 index 000000000..8be0335d6 --- /dev/null +++ b/src/pkg/net/fd_posix_test.go @@ -0,0 +1,57 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd windows + +package net + +import ( + "testing" + "time" +) + +var deadlineSetTimeTests = []struct { + input time.Time + expected int64 +}{ + {time.Time{}, 0}, + {time.Date(2009, 11, 10, 23, 00, 00, 00, time.UTC), 1257894000000000000}, // 2009-11-10 23:00:00 +0000 UTC +} + +func TestDeadlineSetTime(t *testing.T) { + for _, tt := range deadlineSetTimeTests { + var d deadline + d.setTime(tt.input) + actual := d.value() + expected := int64(0) + if !tt.input.IsZero() { + expected = tt.input.UnixNano() + } + if actual != expected { + t.Errorf("set/value failed: expected %v, actual %v", expected, actual) + } + } +} + +var deadlineExpiredTests = []struct { + deadline time.Time + expired bool +}{ + // note, times are relative to the start of the test run, not + // the start of TestDeadlineExpired + {time.Now().Add(5 * time.Minute), false}, + {time.Now().Add(-5 * time.Minute), true}, + {time.Time{}, false}, // no deadline set +} + +func TestDeadlineExpired(t *testing.T) { + for _, tt := range deadlineExpiredTests { + var d deadline + d.set(tt.deadline.UnixNano()) + expired := d.expired() + if expired != tt.expired { + t.Errorf("expire failed: expected %v, actual %v", tt.expired, expired) + } + } +} diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd_unix.go index 76c953b9b..0540df825 100644 --- a/src/pkg/net/fd.go +++ b/src/pkg/net/fd_unix.go @@ -7,9 +7,9 @@ package net import ( - "errors" "io" "os" + "runtime" "sync" "syscall" "time" @@ -37,44 +37,24 @@ type netFD struct { laddr Addr raddr Addr - // owned by client - rdeadline int64 - rio sync.Mutex - wdeadline int64 - wio sync.Mutex + // serialize access to Read and Write methods + rio, wio sync.Mutex + + // read and write deadlines + rdeadline, wdeadline deadline // owned by fd wait server ncr, ncw int + + // wait server + pollServer *pollServer } // A pollServer helps FDs determine when to retry a non-blocking // read or write after they get EAGAIN. When an FD needs to wait, -// send the fd on s.cr (for a read) or s.cw (for a write) to pass the -// request to the poll server. Then receive on fd.cr/fd.cw. +// call s.WaitRead() or s.WaitWrite() to pass the request to the poll server. // When the pollServer finds that i/o on FD should be possible -// again, it will send fd on fd.cr/fd.cw to wake any waiting processes. -// This protocol is implemented as s.WaitRead() and s.WaitWrite(). -// -// There is one subtlety: when sending on s.cr/s.cw, the -// poll server is probably in a system call, waiting for an fd -// to become ready. It's not looking at the request channels. -// To resolve this, the poll server waits not just on the FDs it has -// been given but also its own pipe. After sending on the -// buffered channel s.cr/s.cw, WaitRead/WaitWrite writes a -// byte to the pipe, causing the pollServer's poll system call to -// return. In response to the pipe being readable, the pollServer -// re-polls its request channels. -// -// Note that the ordering is "send request" and then "wake up server". -// If the operations were reversed, there would be a race: the poll -// server might wake up and look at the request channel, see that it -// was empty, and go back to sleep, all before the requester managed -// to send the request. Because the send must complete before the wakeup, -// the request channel must be buffered. A buffer of size 1 is sufficient -// for any request load. If many processes are trying to submit requests, -// one will succeed, the pollServer will read the request, and then the -// channel will be empty for the next process's request. A larger buffer -// might help batch requests. +// again, it will send on fd.cr/fd.cw to wake any waiting goroutines. // // To avoid races in closing, all fd operations are locked and // refcounted. when netFD.Close() is called, it calls syscall.Shutdown @@ -82,7 +62,6 @@ type netFD struct { // will the fd be closed. type pollServer struct { - cr, cw chan *netFD // buffered >= 1 pr, pw *os.File poll *pollster // low-level OS hooks sync.Mutex // controls pending and deadline @@ -103,11 +82,11 @@ func (s *pollServer) AddFD(fd *netFD, mode int) error { key := intfd << 1 if mode == 'r' { fd.ncr++ - t = fd.rdeadline + t = fd.rdeadline.value() } else { fd.ncw++ key++ - t = fd.wdeadline + t = fd.wdeadline.value() } s.pending[key] = fd doWakeup := false @@ -117,15 +96,11 @@ func (s *pollServer) AddFD(fd *netFD, mode int) error { } wake, err := s.poll.AddFD(intfd, mode, false) + s.Unlock() if err != nil { - panic("pollServer AddFD " + err.Error()) - } - if wake { - doWakeup = true + return &OpError{"addfd", fd.net, fd.laddr, err} } - s.Unlock() - - if doWakeup { + if wake || doWakeup { s.Wakeup() } return nil @@ -134,17 +109,24 @@ func (s *pollServer) AddFD(fd *netFD, mode int) error { // Evict evicts fd from the pending list, unblocking // any I/O running on fd. The caller must have locked // pollserver. -func (s *pollServer) Evict(fd *netFD) { +// Return value is whether the pollServer should be woken up. +func (s *pollServer) Evict(fd *netFD) bool { + doWakeup := false if s.pending[fd.sysfd<<1] == fd { s.WakeFD(fd, 'r', errClosing) - s.poll.DelFD(fd.sysfd, 'r') + if s.poll.DelFD(fd.sysfd, 'r') { + doWakeup = true + } delete(s.pending, fd.sysfd<<1) } if s.pending[fd.sysfd<<1|1] == fd { s.WakeFD(fd, 'w', errClosing) - s.poll.DelFD(fd.sysfd, 'w') + if s.poll.DelFD(fd.sysfd, 'w') { + doWakeup = true + } delete(s.pending, fd.sysfd<<1|1) } + return doWakeup } var wakeupbuf [1]byte @@ -178,16 +160,12 @@ func (s *pollServer) WakeFD(fd *netFD, mode int, err error) { } } -func (s *pollServer) Now() int64 { - return time.Now().UnixNano() -} - func (s *pollServer) CheckDeadlines() { - now := s.Now() + now := time.Now().UnixNano() // TODO(rsc): This will need to be handled more efficiently, // probably with a heap indexed by wakeup time. - var next_deadline int64 + var nextDeadline int64 for key, fd := range s.pending { var t int64 var mode int @@ -197,27 +175,21 @@ func (s *pollServer) CheckDeadlines() { mode = 'w' } if mode == 'r' { - t = fd.rdeadline + t = fd.rdeadline.value() } else { - t = fd.wdeadline + t = fd.wdeadline.value() } if t > 0 { if t <= now { delete(s.pending, key) - if mode == 'r' { - s.poll.DelFD(fd.sysfd, mode) - fd.rdeadline = -1 - } else { - s.poll.DelFD(fd.sysfd, mode) - fd.wdeadline = -1 - } - s.WakeFD(fd, mode, nil) - } else if next_deadline == 0 || t < next_deadline { - next_deadline = t + s.poll.DelFD(fd.sysfd, mode) + s.WakeFD(fd, mode, errTimeout) + } else if nextDeadline == 0 || t < nextDeadline { + nextDeadline = t } } } - s.deadline = next_deadline + s.deadline = nextDeadline } func (s *pollServer) Run() { @@ -225,15 +197,15 @@ func (s *pollServer) Run() { s.Lock() defer s.Unlock() for { - var t = s.deadline - if t > 0 { - t = t - s.Now() - if t <= 0 { + var timeout int64 // nsec to wait for or 0 for none + if s.deadline > 0 { + timeout = s.deadline - time.Now().UnixNano() + if timeout <= 0 { s.CheckDeadlines() continue } } - fd, mode, err := s.poll.WaitFD(s, t) + fd, mode, err := s.poll.WaitFD(s, timeout) if err != nil { print("pollServer WaitFD: ", err.Error(), "\n") return @@ -279,24 +251,56 @@ func (s *pollServer) WaitWrite(fd *netFD) error { } // Network FD methods. -// All the network FDs use a single pollServer. +// Spread network FDs over several pollServers. + +var pollMaxN int +var pollservers []*pollServer +var startServersOnce []func() + +var canCancelIO = true // used for testing current package -var pollserver *pollServer -var onceStartServer sync.Once +func sysInit() { + pollMaxN = runtime.NumCPU() + if pollMaxN > 8 { + pollMaxN = 8 // No improvement then. + } + pollservers = make([]*pollServer, pollMaxN) + startServersOnce = make([]func(), pollMaxN) + for i := 0; i < pollMaxN; i++ { + k := i + once := new(sync.Once) + startServersOnce[i] = func() { once.Do(func() { startServer(k) }) } + } +} -func startServer() { +func startServer(k int) { p, err := newPollServer() if err != nil { - print("Start pollServer: ", err.Error(), "\n") + panic(err) } - pollserver = p + pollservers[k] = p } -func newFD(fd, family, sotype int, net string) (*netFD, error) { - onceStartServer.Do(startServer) - if err := syscall.SetNonblock(fd, true); err != nil { +func server(fd int) *pollServer { + pollN := runtime.GOMAXPROCS(0) + if pollN > pollMaxN { + pollN = pollMaxN + } + k := fd % pollN + startServersOnce[k]() + return pollservers[k] +} + +func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { + deadline := time.Now().Add(timeout) + ra, err := resolveAddr("dial", net, addr, deadline) + if err != nil { return nil, err } + return dial(net, addr, noLocalAddr, ra, deadline) +} + +func newFD(fd, family, sotype int, net string) (*netFD, error) { netfd := &netFD{ sysfd: fd, family: family, @@ -305,26 +309,31 @@ func newFD(fd, family, sotype int, net string) (*netFD, error) { } netfd.cr = make(chan error, 1) netfd.cw = make(chan error, 1) + netfd.pollServer = server(fd) return netfd, nil } func (fd *netFD) setAddr(laddr, raddr Addr) { fd.laddr = laddr fd.raddr = raddr + fd.sysfile = os.NewFile(uintptr(fd.sysfd), fd.net) +} + +func (fd *netFD) name() string { var ls, rs string - if laddr != nil { - ls = laddr.String() + if fd.laddr != nil { + ls = fd.laddr.String() } - if raddr != nil { - rs = raddr.String() + if fd.raddr != nil { + rs = fd.raddr.String() } - fd.sysfile = os.NewFile(uintptr(fd.sysfd), fd.net+":"+ls+"->"+rs) + return fd.net + ":" + ls + "->" + rs } func (fd *netFD) connect(ra syscall.Sockaddr) error { err := syscall.Connect(fd.sysfd, ra) if err == syscall.EINPROGRESS { - if err = pollserver.WaitWrite(fd); err != nil { + if err = fd.pollServer.WaitWrite(fd); err != nil { return err } var e int @@ -339,15 +348,10 @@ func (fd *netFD) connect(ra syscall.Sockaddr) error { return err } -var errClosing = errors.New("use of closed network connection") - // Add a reference to this fd. // If closing==true, pollserver must be locked; mark the fd as closing. // Returns an error if the fd cannot be used. func (fd *netFD) incref(closing bool) error { - if fd == nil { - return errClosing - } fd.sysmu.Lock() if fd.closing { fd.sysmu.Unlock() @@ -364,9 +368,6 @@ func (fd *netFD) incref(closing bool) error { // Remove a reference to this FD and close if we've been asked to do so (and // there are no references left. func (fd *netFD) decref() { - if fd == nil { - return - } fd.sysmu.Lock() fd.sysref-- if fd.closing && fd.sysref == 0 && fd.sysfile != nil { @@ -378,9 +379,9 @@ func (fd *netFD) decref() { } func (fd *netFD) Close() error { - pollserver.Lock() // needed for both fd.incref(true) and pollserver.Evict - defer pollserver.Unlock() + fd.pollServer.Lock() // needed for both fd.incref(true) and pollserver.Evict if err := fd.incref(true); err != nil { + fd.pollServer.Unlock() return err } // Unblock any I/O. Once it all unblocks and returns, @@ -388,8 +389,12 @@ func (fd *netFD) Close() error { // the final decref will close fd.sysfd. This should happen // fairly quickly, since all the I/O is non-blocking, and any // attempts to block in the pollserver will return errClosing. - pollserver.Evict(fd) + doWakeup := fd.pollServer.Evict(fd) + fd.pollServer.Unlock() fd.decref() + if doWakeup { + fd.pollServer.Wakeup() + } return nil } @@ -421,20 +426,20 @@ func (fd *netFD) Read(p []byte) (n int, err error) { } defer fd.decref() for { - n, err = syscall.Read(int(fd.sysfd), p) - if err == syscall.EAGAIN { + if fd.rdeadline.expired() { err = errTimeout - if fd.rdeadline >= 0 { - if err = pollserver.WaitRead(fd); err == nil { - continue - } - } + break } + n, err = syscall.Read(int(fd.sysfd), p) if err != nil { n = 0 - } else if n == 0 && err == nil && fd.sotype != syscall.SOCK_DGRAM { - err = io.EOF + if err == syscall.EAGAIN { + if err = fd.pollServer.WaitRead(fd); err == nil { + continue + } + } } + err = chkReadErr(n, err, fd) break } if err != nil && err != io.EOF { @@ -451,18 +456,20 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { } defer fd.decref() for { - n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0) - if err == syscall.EAGAIN { + if fd.rdeadline.expired() { err = errTimeout - if fd.rdeadline >= 0 { - if err = pollserver.WaitRead(fd); err == nil { - continue - } - } + break } + n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0) if err != nil { n = 0 + if err == syscall.EAGAIN { + if err = fd.pollServer.WaitRead(fd); err == nil { + continue + } + } } + err = chkReadErr(n, err, fd) break } if err != nil && err != io.EOF { @@ -479,41 +486,47 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S } defer fd.decref() for { - n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0) - if err == syscall.EAGAIN { + if fd.rdeadline.expired() { err = errTimeout - if fd.rdeadline >= 0 { - if err = pollserver.WaitRead(fd); err == nil { + break + } + n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0) + if err != nil { + // TODO(dfc) should n and oobn be set to 0 + if err == syscall.EAGAIN { + if err = fd.pollServer.WaitRead(fd); err == nil { continue } } } - if err == nil && n == 0 { - err = io.EOF - } + err = chkReadErr(n, err, fd) break } if err != nil && err != io.EOF { err = &OpError{"read", fd.net, fd.laddr, err} - return } return } -func (fd *netFD) Write(p []byte) (int, error) { +func chkReadErr(n int, err error, fd *netFD) error { + if n == 0 && err == nil && fd.sotype != syscall.SOCK_DGRAM && fd.sotype != syscall.SOCK_RAW { + return io.EOF + } + return err +} + +func (fd *netFD) Write(p []byte) (nn int, err error) { fd.wio.Lock() defer fd.wio.Unlock() if err := fd.incref(false); err != nil { return 0, err } defer fd.decref() - if fd.sysfile == nil { - return 0, syscall.EINVAL - } - - var err error - nn := 0 for { + if fd.wdeadline.expired() { + err = errTimeout + break + } var n int n, err = syscall.Write(int(fd.sysfd), p[nn:]) if n > 0 { @@ -523,11 +536,8 @@ func (fd *netFD) Write(p []byte) (int, error) { break } if err == syscall.EAGAIN { - err = errTimeout - if fd.wdeadline >= 0 { - if err = pollserver.WaitWrite(fd); err == nil { - continue - } + if err = fd.pollServer.WaitWrite(fd); err == nil { + continue } } if err != nil { @@ -553,13 +563,14 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { } defer fd.decref() for { + if fd.wdeadline.expired() { + err = errTimeout + break + } err = syscall.Sendto(fd.sysfd, p, 0, sa) if err == syscall.EAGAIN { - err = errTimeout - if fd.wdeadline >= 0 { - if err = pollserver.WaitWrite(fd); err == nil { - continue - } + if err = fd.pollServer.WaitWrite(fd); err == nil { + continue } } break @@ -580,13 +591,14 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob } defer fd.decref() for { + if fd.wdeadline.expired() { + err = errTimeout + break + } err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) if err == syscall.EAGAIN { - err = errTimeout - if fd.wdeadline >= 0 { - if err = pollserver.WaitWrite(fd); err == nil { - continue - } + if err = fd.pollServer.WaitWrite(fd); err == nil { + continue } } break @@ -606,22 +618,14 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e } defer fd.decref() - // See ../syscall/exec.go for description of ForkLock. - // It is okay to hold the lock across syscall.Accept - // because we have put fd.sysfd into non-blocking mode. var s int var rsa syscall.Sockaddr for { - syscall.ForkLock.RLock() - s, rsa, err = syscall.Accept(fd.sysfd) + s, rsa, err = accept(fd.sysfd) if err != nil { - syscall.ForkLock.RUnlock() if err == syscall.EAGAIN { - err = errTimeout - if fd.rdeadline >= 0 { - if err = pollserver.WaitRead(fd); err == nil { - continue - } + if err = fd.pollServer.WaitRead(fd); err == nil { + continue } } else if err == syscall.ECONNABORTED { // This means that a socket on the listen queue was closed @@ -632,11 +636,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e } break } - syscall.CloseOnExec(s) - syscall.ForkLock.RUnlock() if netfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil { - syscall.Close(s) + closesocket(s) return nil, err } lsa, _ := syscall.Getsockname(netfd.sysfd) @@ -645,17 +647,24 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e } func (fd *netFD) dup() (f *os.File, err error) { + syscall.ForkLock.RLock() ns, err := syscall.Dup(fd.sysfd) if err != nil { + syscall.ForkLock.RUnlock() return nil, &OpError{"dup", fd.net, fd.laddr, err} } + syscall.CloseOnExec(ns) + syscall.ForkLock.RUnlock() // We want blocking mode for the new fd, hence the double negative. + // This also puts the old fd into blocking mode, meaning that + // I/O will block the thread instead of letting us use the epoll server. + // Everything will still work, just with more threads. if err = syscall.SetNonblock(ns, false); err != nil { return nil, &OpError{"setnonblock", fd.net, fd.laddr, err} } - return os.NewFile(uintptr(ns), fd.sysfile.Name()), nil + return os.NewFile(uintptr(ns), fd.name()), nil } func closesocket(s int) error { diff --git a/src/pkg/net/fd_unix_test.go b/src/pkg/net/fd_unix_test.go new file mode 100644 index 000000000..664ef1bf1 --- /dev/null +++ b/src/pkg/net/fd_unix_test.go @@ -0,0 +1,58 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd + +package net + +import ( + "io" + "syscall" + "testing" +) + +var chkReadErrTests = []struct { + n int + err error + fd *netFD + expected error +}{ + + {100, nil, &netFD{sotype: syscall.SOCK_STREAM}, nil}, + {100, io.EOF, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF}, + {100, errClosing, &netFD{sotype: syscall.SOCK_STREAM}, errClosing}, + {0, nil, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF}, + {0, io.EOF, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF}, + {0, errClosing, &netFD{sotype: syscall.SOCK_STREAM}, errClosing}, + + {100, nil, &netFD{sotype: syscall.SOCK_DGRAM}, nil}, + {100, io.EOF, &netFD{sotype: syscall.SOCK_DGRAM}, io.EOF}, + {100, errClosing, &netFD{sotype: syscall.SOCK_DGRAM}, errClosing}, + {0, nil, &netFD{sotype: syscall.SOCK_DGRAM}, nil}, + {0, io.EOF, &netFD{sotype: syscall.SOCK_DGRAM}, io.EOF}, + {0, errClosing, &netFD{sotype: syscall.SOCK_DGRAM}, errClosing}, + + {100, nil, &netFD{sotype: syscall.SOCK_SEQPACKET}, nil}, + {100, io.EOF, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF}, + {100, errClosing, &netFD{sotype: syscall.SOCK_SEQPACKET}, errClosing}, + {0, nil, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF}, + {0, io.EOF, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF}, + {0, errClosing, &netFD{sotype: syscall.SOCK_SEQPACKET}, errClosing}, + + {100, nil, &netFD{sotype: syscall.SOCK_RAW}, nil}, + {100, io.EOF, &netFD{sotype: syscall.SOCK_RAW}, io.EOF}, + {100, errClosing, &netFD{sotype: syscall.SOCK_RAW}, errClosing}, + {0, nil, &netFD{sotype: syscall.SOCK_RAW}, nil}, + {0, io.EOF, &netFD{sotype: syscall.SOCK_RAW}, io.EOF}, + {0, errClosing, &netFD{sotype: syscall.SOCK_RAW}, errClosing}, +} + +func TestChkReadErr(t *testing.T) { + for _, tt := range chkReadErrTests { + actual := chkReadErr(tt.n, tt.err, tt.fd) + if actual != tt.expected { + t.Errorf("chkReadError(%v, %v, %v): expected %v, actual %v", tt.n, tt.err, tt.fd.sotype, tt.expected, actual) + } + } +} diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index 45f5c2d88..0e331b44d 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -17,19 +17,58 @@ import ( var initErr error -func init() { +// CancelIo Windows API cancels all outstanding IO for a particular +// socket on current thread. To overcome that limitation, we run +// special goroutine, locked to OS single thread, that both starts +// and cancels IO. It means, there are 2 unavoidable thread switches +// for every IO. +// Some newer versions of Windows has new CancelIoEx API, that does +// not have that limitation and can be used from any thread. This +// package uses CancelIoEx API, if present, otherwise it fallback +// to CancelIo. + +var canCancelIO bool // determines if CancelIoEx API is present + +func sysInit() { var d syscall.WSAData e := syscall.WSAStartup(uint32(0x202), &d) if e != nil { initErr = os.NewSyscallError("WSAStartup", e) } + canCancelIO = syscall.LoadCancelIoEx() == nil + if syscall.LoadGetAddrInfo() == nil { + lookupPort = newLookupPort + lookupIP = newLookupIP + } } func closesocket(s syscall.Handle) error { return syscall.Closesocket(s) } -// Interface for all io operations. +func canUseConnectEx(net string) bool { + if net == "udp" || net == "udp4" || net == "udp6" { + // ConnectEx windows API does not support connectionless sockets. + return false + } + return syscall.LoadConnectEx() == nil +} + +func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { + if !canUseConnectEx(net) { + // Use the relatively inefficient goroutine-racing + // implementation of DialTimeout. + return dialTimeoutRace(net, addr, timeout) + } + deadline := time.Now().Add(timeout) + ra, err := resolveAddr("dial", net, addr, deadline) + if err != nil { + return nil, err + } + return dial(net, addr, noLocalAddr, ra, deadline) +} + +// Interface for all IO operations. type anOpIface interface { Op() *anOp Name() string @@ -42,7 +81,7 @@ type ioResult struct { err error } -// anOp implements functionality common to all io operations. +// anOp implements functionality common to all IO operations. type anOp struct { // Used by IOCP interface, it must be first field // of the struct, as our code rely on it. @@ -75,7 +114,7 @@ func (o *anOp) Op() *anOp { return o } -// bufOp is used by io operations that read / write +// bufOp is used by IO operations that read / write // data from / to client buffer. type bufOp struct { anOp @@ -92,7 +131,7 @@ func (o *bufOp) Init(fd *netFD, buf []byte, mode int) { } } -// resultSrv will retrieve all io completion results from +// resultSrv will retrieve all IO completion results from // iocp and send them to the correspondent waiting client // goroutine via channel supplied in the request. type resultSrv struct { @@ -107,7 +146,7 @@ func (s *resultSrv) Run() { r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE) switch { case r.err == nil: - // Dequeued successfully completed io packet. + // Dequeued successfully completed IO packet. case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil: // Wait has timed out (should not happen now, but might be used in the future). panic("GetQueuedCompletionStatus timed out") @@ -115,22 +154,23 @@ func (s *resultSrv) Run() { // Failed to dequeue anything -> report the error. panic("GetQueuedCompletionStatus failed " + r.err.Error()) default: - // Dequeued failed io packet. + // Dequeued failed IO packet. } (*anOp)(unsafe.Pointer(o)).resultc <- r } } -// ioSrv executes net io requests. +// ioSrv executes net IO requests. type ioSrv struct { - submchan chan anOpIface // submit io requests - canchan chan anOpIface // cancel io requests + submchan chan anOpIface // submit IO requests + canchan chan anOpIface // cancel IO requests } -// ProcessRemoteIO will execute submit io requests on behalf +// ProcessRemoteIO will execute submit IO requests on behalf // of other goroutines, all on a single os thread, so it can // cancel them later. Results of all operations will be sent // back to their requesters via channel supplied in request. +// It is used only when the CancelIoEx API is unavailable. func (s *ioSrv) ProcessRemoteIO() { runtime.LockOSThread() defer runtime.UnlockOSThread() @@ -144,20 +184,30 @@ func (s *ioSrv) ProcessRemoteIO() { } } -// ExecIO executes a single io operation. It either executes it -// inline, or, if a deadline is employed, passes the request onto +// ExecIO executes a single IO operation oi. It submits and cancels +// IO in the current thread for systems where Windows CancelIoEx API +// is available. Alternatively, it passes the request onto // a special goroutine and waits for completion or cancels request. // deadline is unix nanos. func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) { var err error o := oi.Op() + // Calculate timeout delta. + var delta int64 if deadline != 0 { + delta = deadline - time.Now().UnixNano() + if delta <= 0 { + return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, errTimeout} + } + } + // Start IO. + if canCancelIO { + err = oi.Submit() + } else { // Send request to a special dedicated thread, - // so it can stop the io with CancelIO later. + // so it can stop the IO with CancelIO later. s.submchan <- oi err = <-o.errnoc - } else { - err = oi.Submit() } switch err { case nil: @@ -168,27 +218,46 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) { default: return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, err} } + // Setup timer, if deadline is given. + var timer <-chan time.Time + if delta > 0 { + t := time.NewTimer(time.Duration(delta) * time.Nanosecond) + defer t.Stop() + timer = t.C + } // Wait for our request to complete. var r ioResult - if deadline != 0 { - dt := deadline - time.Now().UnixNano() - if dt < 1 { - dt = 1 - } - timer := time.NewTimer(time.Duration(dt) * time.Nanosecond) - defer timer.Stop() - select { - case r = <-o.resultc: - case <-timer.C: + var cancelled, timeout bool + select { + case r = <-o.resultc: + case <-timer: + cancelled = true + timeout = true + case <-o.fd.closec: + cancelled = true + } + if cancelled { + // Cancel it. + if canCancelIO { + err := syscall.CancelIoEx(syscall.Handle(o.Op().fd.sysfd), &o.o) + // Assuming ERROR_NOT_FOUND is returned, if IO is completed. + if err != nil && err != syscall.ERROR_NOT_FOUND { + // TODO(brainman): maybe do something else, but panic. + panic(err) + } + } else { s.canchan <- oi <-o.errnoc - r = <-o.resultc - if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled - r.err = syscall.EWOULDBLOCK - } } - } else { + // Wait for IO to be canceled or complete successfully. r = <-o.resultc + if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled + if timeout { + r.err = errTimeout + } else { + r.err = errClosing + } + } } if r.err != nil { err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err} @@ -211,9 +280,13 @@ func startServer() { go resultsrv.Run() iosrv = new(ioSrv) - iosrv.submchan = make(chan anOpIface) - iosrv.canchan = make(chan anOpIface) - go iosrv.ProcessRemoteIO() + if !canCancelIO { + // Only CancelIo API is available. Lets start special goroutine + // locked to an OS thread, that both starts and cancels IO. + iosrv.submchan = make(chan anOpIface) + iosrv.canchan = make(chan anOpIface) + go iosrv.ProcessRemoteIO() + } } // Network file descriptor. @@ -233,12 +306,13 @@ type netFD struct { raddr Addr resultc [2]chan ioResult // read/write completion results errnoc [2]chan error // read/write submit or cancel operation errors + closec chan bool // used by Close to cancel pending IO + + // serialize access to Read and Write methods + rio, wio sync.Mutex - // owned by client - rdeadline int64 - rio sync.Mutex - wdeadline int64 - wio sync.Mutex + // read and write deadlines + rdeadline, wdeadline deadline } func allocFD(fd syscall.Handle, family, sotype int, net string) *netFD { @@ -247,8 +321,8 @@ func allocFD(fd syscall.Handle, family, sotype int, net string) *netFD { family: family, sotype: sotype, net: net, + closec: make(chan bool), } - runtime.SetFinalizer(netfd, (*netFD).Close) return netfd } @@ -267,13 +341,52 @@ func newFD(fd syscall.Handle, family, proto int, net string) (*netFD, error) { func (fd *netFD) setAddr(laddr, raddr Addr) { fd.laddr = laddr fd.raddr = raddr + runtime.SetFinalizer(fd, (*netFD).closesocket) } -func (fd *netFD) connect(ra syscall.Sockaddr) error { - return syscall.Connect(fd.sysfd, ra) +// Make new connection. + +type connectOp struct { + anOp + ra syscall.Sockaddr +} + +func (o *connectOp) Submit() error { + return syscall.ConnectEx(o.fd.sysfd, o.ra, nil, 0, nil, &o.o) } -var errClosing = errors.New("use of closed network connection") +func (o *connectOp) Name() string { + return "ConnectEx" +} + +func (fd *netFD) connect(ra syscall.Sockaddr) error { + if !canUseConnectEx(fd.net) { + return syscall.Connect(fd.sysfd, ra) + } + // ConnectEx windows API requires an unconnected, previously bound socket. + var la syscall.Sockaddr + switch ra.(type) { + case *syscall.SockaddrInet4: + la = &syscall.SockaddrInet4{} + case *syscall.SockaddrInet6: + la = &syscall.SockaddrInet6{} + default: + panic("unexpected type in connect") + } + if err := syscall.Bind(fd.sysfd, la); err != nil { + return err + } + // Call ConnectEx API. + var o connectOp + o.Init(fd, 'w') + o.ra = ra + _, err := iosrv.ExecIO(&o, fd.wdeadline.value()) + if err != nil { + return err + } + // Refresh socket properties. + return syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))) +} // Add a reference to this fd. // If closing==true, mark the fd as closing. @@ -299,24 +412,12 @@ func (fd *netFD) incref(closing bool) error { // Remove a reference to this FD and close if we've been asked to do so (and // there are no references left. func (fd *netFD) decref() { + if fd == nil { + return + } fd.sysmu.Lock() fd.sysref-- - // NOTE(rsc): On Unix we check fd.sysref == 0 here before closing, - // but on Windows we have no way to wake up the blocked I/O other - // than closing the socket (or calling Shutdown, which breaks other - // programs that might have a reference to the socket). So there is - // a small race here that we might close fd.sysfd and then some other - // goroutine might start a read of fd.sysfd (having read it before we - // write InvalidHandle to it), which might refer to some other file - // if the specific handle value gets reused. I think handle values on - // Windows are not reused as aggressively as file descriptors on Unix, - // so this might be tolerable. - if fd.closing && fd.sysfd != syscall.InvalidHandle { - // In case the user has set linger, switch to blocking mode so - // the close blocks. As long as this doesn't happen often, we - // can handle the extra OS processes. Otherwise we'll need to - // use the resultsrv for Close too. Sigh. - syscall.SetNonblock(fd.sysfd, false) + if fd.closing && fd.sysref == 0 && fd.sysfd != syscall.InvalidHandle { closesocket(fd.sysfd) fd.sysfd = syscall.InvalidHandle // no need for a finalizer anymore @@ -329,14 +430,22 @@ func (fd *netFD) Close() error { if err := fd.incref(true); err != nil { return err } - fd.decref() + defer fd.decref() + // unblock pending reader and writer + close(fd.closec) + // wait for both reader and writer to exit + fd.rio.Lock() + defer fd.rio.Unlock() + fd.wio.Lock() + defer fd.wio.Unlock() return nil } func (fd *netFD) shutdown(how int) error { - if fd == nil || fd.sysfd == syscall.InvalidHandle { - return syscall.EINVAL + if err := fd.incref(false); err != nil { + return err } + defer fd.decref() err := syscall.Shutdown(fd.sysfd, how) if err != nil { return &OpError{"shutdown", fd.net, fd.laddr, err} @@ -352,6 +461,10 @@ func (fd *netFD) CloseWrite() error { return fd.shutdown(syscall.SHUT_WR) } +func (fd *netFD) closesocket() error { + return closesocket(fd.sysfd) +} + // Read from network. type readOp struct { @@ -368,21 +481,15 @@ func (o *readOp) Name() string { } func (fd *netFD) Read(buf []byte) (int, error) { - if fd == nil { - return 0, syscall.EINVAL - } - fd.rio.Lock() - defer fd.rio.Unlock() if err := fd.incref(false); err != nil { return 0, err } defer fd.decref() - if fd.sysfd == syscall.InvalidHandle { - return 0, syscall.EINVAL - } + fd.rio.Lock() + defer fd.rio.Unlock() var o readOp o.Init(fd, buf, 'r') - n, err := iosrv.ExecIO(&o, fd.rdeadline) + n, err := iosrv.ExecIO(&o, fd.rdeadline.value()) if err == nil && n == 0 { err = io.EOF } @@ -407,22 +514,19 @@ func (o *readFromOp) Name() string { } func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) { - if fd == nil { - return 0, nil, syscall.EINVAL - } if len(buf) == 0 { return 0, nil, nil } - fd.rio.Lock() - defer fd.rio.Unlock() if err := fd.incref(false); err != nil { return 0, nil, err } defer fd.decref() + fd.rio.Lock() + defer fd.rio.Unlock() var o readFromOp o.Init(fd, buf, 'r') o.rsan = int32(unsafe.Sizeof(o.rsa)) - n, err = iosrv.ExecIO(&o, fd.rdeadline) + n, err = iosrv.ExecIO(&o, fd.rdeadline.value()) if err != nil { return 0, nil, err } @@ -446,18 +550,15 @@ func (o *writeOp) Name() string { } func (fd *netFD) Write(buf []byte) (int, error) { - if fd == nil { - return 0, syscall.EINVAL - } - fd.wio.Lock() - defer fd.wio.Unlock() if err := fd.incref(false); err != nil { return 0, err } defer fd.decref() + fd.wio.Lock() + defer fd.wio.Unlock() var o writeOp o.Init(fd, buf, 'w') - return iosrv.ExecIO(&o, fd.wdeadline) + return iosrv.ExecIO(&o, fd.wdeadline.value()) } // WriteTo to network. @@ -477,25 +578,19 @@ func (o *writeToOp) Name() string { } func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (int, error) { - if fd == nil { - return 0, syscall.EINVAL - } if len(buf) == 0 { return 0, nil } - fd.wio.Lock() - defer fd.wio.Unlock() if err := fd.incref(false); err != nil { return 0, err } defer fd.decref() - if fd.sysfd == syscall.InvalidHandle { - return 0, syscall.EINVAL - } + fd.wio.Lock() + defer fd.wio.Unlock() var o writeToOp o.Init(fd, buf, 'w') o.sa = sa - return iosrv.ExecIO(&o, fd.wdeadline) + return iosrv.ExecIO(&o, fd.wdeadline.value()) } // Accept new network connections. @@ -524,19 +619,15 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) { defer fd.decref() // Get new socket. - // See ../syscall/exec.go for description of ForkLock. - syscall.ForkLock.RLock() - s, err := syscall.Socket(fd.family, fd.sotype, 0) + s, err := sysSocket(fd.family, fd.sotype, 0) if err != nil { - syscall.ForkLock.RUnlock() - return nil, err + return nil, &OpError{"socket", fd.net, fd.laddr, err} } - syscall.CloseOnExec(s) - syscall.ForkLock.RUnlock() // Associate our new socket with IOCP. onceStartServer.Do(startServer) if _, err := syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); err != nil { + closesocket(s) return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, err} } @@ -544,7 +635,7 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) { var o acceptOp o.Init(fd, 'r') o.newsock = s - _, err = iosrv.ExecIO(&o, 0) + _, err = iosrv.ExecIO(&o, fd.rdeadline.value()) if err != nil { closesocket(s) return nil, err @@ -554,7 +645,7 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) { err = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))) if err != nil { closesocket(s) - return nil, err + return nil, &OpError{"Setsockopt", fd.net, fd.laddr, err} } // Get local and peer addr out of AcceptEx buffer. diff --git a/src/pkg/net/file_plan9.go b/src/pkg/net/file_plan9.go index 04f7ee040..f6ee1c29e 100644 --- a/src/pkg/net/file_plan9.go +++ b/src/pkg/net/file_plan9.go @@ -5,24 +5,147 @@ package net import ( + "errors" + "io" "os" "syscall" ) +func (fd *netFD) status(ln int) (string, error) { + if !fd.ok() { + return "", syscall.EINVAL + } + + status, err := os.Open(fd.dir + "/status") + if err != nil { + return "", err + } + defer status.Close() + buf := make([]byte, ln) + n, err := io.ReadFull(status, buf[:]) + if err != nil { + return "", err + } + return string(buf[:n]), nil +} + +func newFileFD(f *os.File) (net *netFD, err error) { + var ctl *os.File + close := func(fd int) { + if err != nil { + syscall.Close(fd) + } + } + + path, err := syscall.Fd2path(int(f.Fd())) + if err != nil { + return nil, os.NewSyscallError("fd2path", err) + } + comp := splitAtBytes(path, "/") + n := len(comp) + if n < 3 || comp[0] != "net" { + return nil, syscall.EPLAN9 + } + + name := comp[2] + switch file := comp[n-1]; file { + case "ctl", "clone": + syscall.ForkLock.RLock() + fd, err := syscall.Dup(int(f.Fd()), -1) + syscall.ForkLock.RUnlock() + if err != nil { + return nil, os.NewSyscallError("dup", err) + } + defer close(fd) + + dir := "/net/" + comp[n-2] + ctl = os.NewFile(uintptr(fd), dir+"/"+file) + ctl.Seek(0, 0) + var buf [16]byte + n, err := ctl.Read(buf[:]) + if err != nil { + return nil, err + } + name = string(buf[:n]) + default: + if len(comp) < 4 { + return nil, errors.New("could not find control file for connection") + } + dir := "/net/" + comp[1] + "/" + name + ctl, err = os.OpenFile(dir+"/ctl", os.O_RDWR, 0) + if err != nil { + return nil, err + } + defer close(int(ctl.Fd())) + } + dir := "/net/" + comp[1] + "/" + name + laddr, err := readPlan9Addr(comp[1], dir+"/local") + if err != nil { + return nil, err + } + return newFD(comp[1], name, ctl, nil, laddr, nil), nil +} + +func newFileConn(f *os.File) (c Conn, err error) { + fd, err := newFileFD(f) + if err != nil { + return nil, err + } + if !fd.ok() { + return nil, syscall.EINVAL + } + + fd.data, err = os.OpenFile(fd.dir+"/data", os.O_RDWR, 0) + if err != nil { + return nil, err + } + + switch fd.laddr.(type) { + case *TCPAddr: + return newTCPConn(fd), nil + case *UDPAddr: + return newUDPConn(fd), nil + } + return nil, syscall.EPLAN9 +} + +func newFileListener(f *os.File) (l Listener, err error) { + fd, err := newFileFD(f) + if err != nil { + return nil, err + } + switch fd.laddr.(type) { + case *TCPAddr: + default: + return nil, syscall.EPLAN9 + } + + // check that file corresponds to a listener + s, err := fd.status(len("Listen")) + if err != nil { + return nil, err + } + if s != "Listen" { + return nil, errors.New("file does not represent a listener") + } + + return &TCPListener{fd}, nil +} + // FileConn returns a copy of the network connection corresponding to // the open file f. It is the caller's responsibility to close f when // finished. Closing c does not affect f, and closing f does not // affect c. func FileConn(f *os.File) (c Conn, err error) { - return nil, syscall.EPLAN9 + return newFileConn(f) } // FileListener returns a copy of the network listener corresponding // to the open file f. It is the caller's responsibility to close l -// when finished. Closing c does not affect l, and closing l does not -// affect c. +// when finished. Closing l does not affect f, and closing f does not +// affect l. func FileListener(f *os.File) (l Listener, err error) { - return nil, syscall.EPLAN9 + return newFileListener(f) } // FilePacketConn returns a copy of the packet network connection diff --git a/src/pkg/net/file_test.go b/src/pkg/net/file_test.go index 95c0b6699..acaf18851 100644 --- a/src/pkg/net/file_test.go +++ b/src/pkg/net/file_test.go @@ -89,9 +89,8 @@ var fileListenerTests = []struct { func TestFileListener(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": - t.Logf("skipping test on %q", runtime.GOOS) - return + case "windows": + t.Skipf("skipping test on %q", runtime.GOOS) } for _, tt := range fileListenerTests { @@ -181,8 +180,7 @@ var filePacketConnTests = []struct { func TestFilePacketConn(t *testing.T) { switch runtime.GOOS { case "plan9", "windows": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } for _, tt := range filePacketConnTests { diff --git a/src/pkg/net/file.go b/src/pkg/net/file_unix.go index fc6c6fad8..4c8403e40 100644 --- a/src/pkg/net/file.go +++ b/src/pkg/net/file_unix.go @@ -12,52 +12,62 @@ import ( ) func newFileFD(f *os.File) (*netFD, error) { + syscall.ForkLock.RLock() fd, err := syscall.Dup(int(f.Fd())) if err != nil { + syscall.ForkLock.RUnlock() return nil, os.NewSyscallError("dup", err) } + syscall.CloseOnExec(fd) + syscall.ForkLock.RUnlock() + if err = syscall.SetNonblock(fd, true); err != nil { + closesocket(fd) + return nil, err + } - proto, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE) + sotype, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE) if err != nil { + closesocket(fd) return nil, os.NewSyscallError("getsockopt", err) } family := syscall.AF_UNSPEC toAddr := sockaddrToTCP - sa, _ := syscall.Getsockname(fd) - switch sa.(type) { + lsa, _ := syscall.Getsockname(fd) + switch lsa.(type) { default: closesocket(fd) return nil, syscall.EINVAL case *syscall.SockaddrInet4: family = syscall.AF_INET - if proto == syscall.SOCK_DGRAM { + if sotype == syscall.SOCK_DGRAM { toAddr = sockaddrToUDP - } else if proto == syscall.SOCK_RAW { + } else if sotype == syscall.SOCK_RAW { toAddr = sockaddrToIP } case *syscall.SockaddrInet6: family = syscall.AF_INET6 - if proto == syscall.SOCK_DGRAM { + if sotype == syscall.SOCK_DGRAM { toAddr = sockaddrToUDP - } else if proto == syscall.SOCK_RAW { + } else if sotype == syscall.SOCK_RAW { toAddr = sockaddrToIP } case *syscall.SockaddrUnix: family = syscall.AF_UNIX toAddr = sockaddrToUnix - if proto == syscall.SOCK_DGRAM { + if sotype == syscall.SOCK_DGRAM { toAddr = sockaddrToUnixgram - } else if proto == syscall.SOCK_SEQPACKET { + } else if sotype == syscall.SOCK_SEQPACKET { toAddr = sockaddrToUnixpacket } } - laddr := toAddr(sa) - sa, _ = syscall.Getpeername(fd) - raddr := toAddr(sa) + laddr := toAddr(lsa) + rsa, _ := syscall.Getpeername(fd) + raddr := toAddr(rsa) - netfd, err := newFD(fd, family, proto, laddr.Network()) + netfd, err := newFD(fd, family, sotype, laddr.Network()) if err != nil { + closesocket(fd) return nil, err } netfd.setAddr(laddr, raddr) @@ -78,10 +88,10 @@ func FileConn(f *os.File) (c Conn, err error) { return newTCPConn(fd), nil case *UDPAddr: return newUDPConn(fd), nil - case *UnixAddr: - return newUnixConn(fd), nil case *IPAddr: return newIPConn(fd), nil + case *UnixAddr: + return newUnixConn(fd), nil } fd.Close() return nil, syscall.EINVAL diff --git a/src/pkg/net/http/cgi/child.go b/src/pkg/net/http/cgi/child.go index 1ba7bec5f..100b8b777 100644 --- a/src/pkg/net/http/cgi/child.go +++ b/src/pkg/net/http/cgi/child.go @@ -91,10 +91,19 @@ func RequestFromMap(params map[string]string) (*http.Request, error) { // TODO: cookies. parsing them isn't exported, though. + uriStr := params["REQUEST_URI"] + if uriStr == "" { + // Fallback to SCRIPT_NAME, PATH_INFO and QUERY_STRING. + uriStr = params["SCRIPT_NAME"] + params["PATH_INFO"] + s := params["QUERY_STRING"] + if s != "" { + uriStr += "?" + s + } + } if r.Host != "" { // Hostname is provided, so we can reasonably construct a URL, // even if we have to assume 'http' for the scheme. - rawurl := "http://" + r.Host + params["REQUEST_URI"] + rawurl := "http://" + r.Host + uriStr url, err := url.Parse(rawurl) if err != nil { return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl) @@ -104,7 +113,6 @@ func RequestFromMap(params map[string]string) (*http.Request, error) { // Fallback logic if we don't have a Host header or the URL // failed to parse if r.URL == nil { - uriStr := params["REQUEST_URI"] url, err := url.Parse(uriStr) if err != nil { return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr) diff --git a/src/pkg/net/http/cgi/child_test.go b/src/pkg/net/http/cgi/child_test.go index ec53ab851..74e068014 100644 --- a/src/pkg/net/http/cgi/child_test.go +++ b/src/pkg/net/http/cgi/child_test.go @@ -82,6 +82,28 @@ func TestRequestWithoutHost(t *testing.T) { t.Fatalf("unexpected nil URL") } if g, e := req.URL.String(), "/path?a=b"; e != g { - t.Errorf("expected URL %q; got %q", e, g) + t.Errorf("URL = %q; want %q", g, e) + } +} + +func TestRequestWithoutRequestURI(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "example.com", + "REQUEST_METHOD": "GET", + "SCRIPT_NAME": "/dir/scriptname", + "PATH_INFO": "/p1/p2", + "QUERY_STRING": "a=1&b=2", + "CONTENT_LENGTH": "123", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if req.URL == nil { + t.Fatalf("unexpected nil URL") + } + if g, e := req.URL.String(), "http://example.com/dir/scriptname/p1/p2?a=1&b=2"; e != g { + t.Errorf("URL = %q; want %q", g, e) } } diff --git a/src/pkg/net/http/cgi/host_test.go b/src/pkg/net/http/cgi/host_test.go index 4db3d850c..8c16e6897 100644 --- a/src/pkg/net/http/cgi/host_test.go +++ b/src/pkg/net/http/cgi/host_test.go @@ -19,7 +19,6 @@ import ( "runtime" "strconv" "strings" - "syscall" "testing" "time" ) @@ -63,17 +62,25 @@ readlines: } for key, expected := range expectedMap { - if got := m[key]; got != expected { + got := m[key] + if key == "cwd" { + // For Windows. golang.org/issue/4645. + fi1, _ := os.Stat(got) + fi2, _ := os.Stat(expected) + if os.SameFile(fi1, fi2) { + got = expected + } + } + if got != expected { t.Errorf("for key %q got %q; expected %q", key, got, expected) } } return rw } -var cgiTested = false -var cgiWorks bool +var cgiTested, cgiWorks bool -func skipTest(t *testing.T) bool { +func check(t *testing.T) { if !cgiTested { cgiTested = true cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil @@ -81,16 +88,12 @@ func skipTest(t *testing.T) bool { if !cgiWorks { // No Perl on Windows, needed by test.cgi // TODO: make the child process be Go, not Perl. - t.Logf("Skipping test: test.cgi failed.") - return true + t.Skip("Skipping test: test.cgi failed.") } - return false } func TestCGIBasicGet(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", Root: "/test.cgi", @@ -124,9 +127,7 @@ func TestCGIBasicGet(t *testing.T) { } func TestCGIBasicGetAbsPath(t *testing.T) { - if skipTest(t) { - return - } + check(t) pwd, err := os.Getwd() if err != nil { t.Fatalf("getwd error: %v", err) @@ -144,9 +145,7 @@ func TestCGIBasicGetAbsPath(t *testing.T) { } func TestPathInfo(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", Root: "/test.cgi", @@ -163,9 +162,7 @@ func TestPathInfo(t *testing.T) { } func TestPathInfoDirRoot(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", Root: "/myscript/", @@ -181,9 +178,7 @@ func TestPathInfoDirRoot(t *testing.T) { } func TestDupHeaders(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", } @@ -203,9 +198,7 @@ func TestDupHeaders(t *testing.T) { } func TestPathInfoNoRoot(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", Root: "", @@ -221,9 +214,7 @@ func TestPathInfoNoRoot(t *testing.T) { } func TestCGIBasicPost(t *testing.T) { - if skipTest(t) { - return - } + check(t) postReq := `POST /test.cgi?a=b HTTP/1.0 Host: example.com Content-Type: application/x-www-form-urlencoded @@ -250,9 +241,7 @@ func chunk(s string) string { // The CGI spec doesn't allow chunked requests. func TestCGIPostChunked(t *testing.T) { - if skipTest(t) { - return - } + check(t) postReq := `POST /test.cgi?a=b HTTP/1.1 Host: example.com Content-Type: application/x-www-form-urlencoded @@ -273,9 +262,7 @@ Transfer-Encoding: chunked } func TestRedirect(t *testing.T) { - if skipTest(t) { - return - } + check(t) h := &Handler{ Path: "testdata/test.cgi", Root: "/test.cgi", @@ -290,9 +277,7 @@ func TestRedirect(t *testing.T) { } func TestInternalRedirect(t *testing.T) { - if skipTest(t) { - return - } + check(t) baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path) fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr) @@ -312,8 +297,9 @@ func TestInternalRedirect(t *testing.T) { // TestCopyError tests that we kill the process if there's an error copying // its output. (for example, from the client having gone away) func TestCopyError(t *testing.T) { - if skipTest(t) || runtime.GOOS == "windows" { - return + check(t) + if runtime.GOOS == "windows" { + t.Skipf("skipping test on %q", runtime.GOOS) } h := &Handler{ Path: "testdata/test.cgi", @@ -353,11 +339,7 @@ func TestCopyError(t *testing.T) { } childRunning := func() bool { - p, err := os.FindProcess(pid) - if err != nil { - return false - } - return p.Signal(syscall.Signal(0)) == nil + return isProcessRunning(t, pid) } if !childRunning() { @@ -376,10 +358,10 @@ func TestCopyError(t *testing.T) { } func TestDirUnix(t *testing.T) { - if skipTest(t) || runtime.GOOS == "windows" { - return + check(t) + if runtime.GOOS == "windows" { + t.Skipf("skipping test on %q", runtime.GOOS) } - cwd, _ := os.Getwd() h := &Handler{ Path: "testdata/test.cgi", @@ -404,8 +386,8 @@ func TestDirUnix(t *testing.T) { } func TestDirWindows(t *testing.T) { - if skipTest(t) || runtime.GOOS != "windows" { - return + if runtime.GOOS != "windows" { + t.Skip("Skipping windows specific test.") } cgifile, _ := filepath.Abs("testdata/test.cgi") @@ -414,7 +396,7 @@ func TestDirWindows(t *testing.T) { var err error perl, err = exec.LookPath("perl") if err != nil { - return + t.Skip("Skipping test: perl not found.") } perl, _ = filepath.Abs(perl) @@ -456,7 +438,7 @@ func TestEnvOverride(t *testing.T) { var err error perl, err = exec.LookPath("perl") if err != nil { - return + t.Skipf("Skipping test: perl not found.") } perl, _ = filepath.Abs(perl) diff --git a/src/pkg/net/http/cgi/plan9_test.go b/src/pkg/net/http/cgi/plan9_test.go new file mode 100644 index 000000000..c8235831b --- /dev/null +++ b/src/pkg/net/http/cgi/plan9_test.go @@ -0,0 +1,18 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build plan9 + +package cgi + +import ( + "os" + "strconv" + "testing" +) + +func isProcessRunning(t *testing.T, pid int) bool { + _, err := os.Stat("/proc/" + strconv.Itoa(pid)) + return err == nil +} diff --git a/src/pkg/net/http/cgi/posix_test.go b/src/pkg/net/http/cgi/posix_test.go new file mode 100644 index 000000000..5ff9e7d5e --- /dev/null +++ b/src/pkg/net/http/cgi/posix_test.go @@ -0,0 +1,21 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9 + +package cgi + +import ( + "os" + "syscall" + "testing" +) + +func isProcessRunning(t *testing.T, pid int) bool { + p, err := os.FindProcess(pid) + if err != nil { + return false + } + return p.Signal(syscall.Signal(0)) == nil +} diff --git a/src/pkg/net/http/cgi/testdata/test.cgi b/src/pkg/net/http/cgi/testdata/test.cgi index b46b1330f..3214df6f0 100755 --- a/src/pkg/net/http/cgi/testdata/test.cgi +++ b/src/pkg/net/http/cgi/testdata/test.cgi @@ -8,6 +8,8 @@ use strict; use Cwd; +binmode STDOUT; + my $q = MiniCGI->new; my $params = $q->Vars; @@ -16,51 +18,44 @@ if ($params->{"loc"}) { exit(0); } -my $NL = "\r\n"; -$NL = "\n" if $params->{mode} eq "NL"; - -my $p = sub { - print "$_[0]$NL"; -}; - -# With carriage returns -$p->("Content-Type: text/html"); -$p->("X-CGI-Pid: $$"); -$p->("X-Test-Header: X-Test-Value"); -$p->(""); +print "Content-Type: text/html\r\n"; +print "X-CGI-Pid: $$\r\n"; +print "X-Test-Header: X-Test-Value\r\n"; +print "\r\n"; if ($params->{"bigresponse"}) { - for (1..1024) { - print "A" x 1024, "\n"; + # 17 MB, for OS X: golang.org/issue/4958 + for (1..(17 * 1024)) { + print "A" x 1024, "\r\n"; } exit 0; } -print "test=Hello CGI\n"; +print "test=Hello CGI\r\n"; foreach my $k (sort keys %$params) { - print "param-$k=$params->{$k}\n"; + print "param-$k=$params->{$k}\r\n"; } foreach my $k (sort keys %ENV) { - my $clean_env = $ENV{$k}; - $clean_env =~ s/[\n\r]//g; - print "env-$k=$clean_env\n"; + my $clean_env = $ENV{$k}; + $clean_env =~ s/[\n\r]//g; + print "env-$k=$clean_env\r\n"; } -# NOTE: don't call getcwd() for windows. -# msys return /c/go/src/... not C:\go\... -my $dir; +# NOTE: msys perl returns /c/go/src/... not C:\go\.... +my $dir = getcwd(); if ($^O eq 'MSWin32' || $^O eq 'msys') { - my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe'; - $cmd =~ s!\\!/!g; - $dir = `$cmd /c cd`; - chomp $dir; -} else { - $dir = getcwd(); + if ($dir =~ /^.:/) { + $dir =~ s!/!\\!g; + } else { + my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe'; + $cmd =~ s!\\!/!g; + $dir = `$cmd /c cd`; + chomp $dir; + } } -print "cwd=$dir\n"; - +print "cwd=$dir\r\n"; # A minimal version of CGI.pm, for people without the perl-modules # package installed. (CGI.pm used to be part of the Perl core, but diff --git a/src/pkg/net/http/chunked.go b/src/pkg/net/http/chunked.go index 60a478fd8..91db01724 100644 --- a/src/pkg/net/http/chunked.go +++ b/src/pkg/net/http/chunked.go @@ -11,10 +11,9 @@ package http import ( "bufio" - "bytes" "errors" + "fmt" "io" - "strconv" ) const maxLineLength = 4096 // assumed <= bufio.defaultBufSize @@ -22,7 +21,7 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize var ErrLineTooLong = errors.New("header line too long") // newChunkedReader returns a new chunkedReader that translates the data read from r -// out of HTTP "chunked" format before returning it. +// out of HTTP "chunked" format before returning it. // The chunkedReader returns io.EOF when the final 0-length chunk is read. // // newChunkedReader is not needed by normal applications. The http package @@ -39,16 +38,17 @@ type chunkedReader struct { r *bufio.Reader n uint64 // unread bytes in chunk err error + buf [2]byte } func (cr *chunkedReader) beginChunk() { // chunk-size CRLF - var line string + var line []byte line, cr.err = readLine(cr.r) if cr.err != nil { return } - cr.n, cr.err = strconv.ParseUint(line, 16, 64) + cr.n, cr.err = parseHexUint(line) if cr.err != nil { return } @@ -74,9 +74,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { cr.n -= uint64(n) if cr.n == 0 && cr.err == nil { // end of chunk (CRLF) - b := make([]byte, 2) - if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil { - if b[0] != '\r' || b[1] != '\n' { + if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil { + if cr.buf[0] != '\r' || cr.buf[1] != '\n' { cr.err = errors.New("malformed chunked encoding") } } @@ -88,7 +87,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { // Give up if the line exceeds maxLineLength. // The returned bytes are a pointer into storage in // the bufio, so they are only valid until the next bufio read. -func readLineBytes(b *bufio.Reader) (p []byte, err error) { +func readLine(b *bufio.Reader) (p []byte, err error) { if p, err = b.ReadSlice('\n'); err != nil { // We always know when EOF is coming. // If the caller asked for a line, there should be a line. @@ -102,20 +101,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) { if len(p) >= maxLineLength { return nil, ErrLineTooLong } - - // Chop off trailing white space. - p = bytes.TrimRight(p, " \r\t\n") - - return p, nil + return trimTrailingWhitespace(p), nil } -// readLineBytes, but convert the bytes into a string. -func readLine(b *bufio.Reader) (s string, err error) { - p, e := readLineBytes(b) - if e != nil { - return "", e +func trimTrailingWhitespace(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] } - return string(p), nil + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' } // newChunkedWriter returns a new chunkedWriter that translates writes into HTTP @@ -147,9 +144,7 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) { return 0, nil } - head := strconv.FormatInt(int64(len(data)), 16) + "\r\n" - - if _, err = io.WriteString(cw.Wire, head); err != nil { + if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil { return 0, err } if n, err = cw.Wire.Write(data); err != nil { @@ -168,3 +163,21 @@ func (cw *chunkedWriter) Close() error { _, err := io.WriteString(cw.Wire, "0\r\n") return err } + +func parseHexUint(v []byte) (n uint64, err error) { + for _, b := range v { + n <<= 4 + switch { + case '0' <= b && b <= '9': + b = b - '0' + case 'a' <= b && b <= 'f': + b = b - 'a' + 10 + case 'A' <= b && b <= 'F': + b = b - 'A' + 10 + default: + return 0, errors.New("invalid byte in chunk length") + } + n |= uint64(b) + } + return +} diff --git a/src/pkg/net/http/chunked_test.go b/src/pkg/net/http/chunked_test.go index b77ee2ff2..0b18c7b55 100644 --- a/src/pkg/net/http/chunked_test.go +++ b/src/pkg/net/http/chunked_test.go @@ -9,7 +9,10 @@ package http import ( "bytes" + "fmt" + "io" "io/ioutil" + "runtime" "testing" ) @@ -37,3 +40,54 @@ func TestChunk(t *testing.T) { t.Errorf("chunk reader read %q; want %q", g, e) } } + +func TestChunkReaderAllocs(t *testing.T) { + // temporarily set GOMAXPROCS to 1 as we are testing memory allocations + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + var buf bytes.Buffer + w := newChunkedWriter(&buf) + a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") + w.Write(a) + w.Write(b) + w.Write(c) + w.Close() + + r := newChunkedReader(&buf) + readBuf := make([]byte, len(a)+len(b)+len(c)+1) + + var ms runtime.MemStats + runtime.ReadMemStats(&ms) + m0 := ms.Mallocs + + n, err := io.ReadFull(r, readBuf) + + runtime.ReadMemStats(&ms) + mallocs := ms.Mallocs - m0 + if mallocs > 1 { + t.Errorf("%d mallocs; want <= 1", mallocs) + } + + if n != len(readBuf)-1 { + t.Errorf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Errorf("read error = %v; want ErrUnexpectedEOF", err) + } +} + +func TestParseHexUint(t *testing.T) { + for i := uint64(0); i <= 1234; i++ { + line := []byte(fmt.Sprintf("%x", i)) + got, err := parseHexUint(line) + if err != nil { + t.Fatalf("on %d: %v", i, err) + } + if got != i { + t.Errorf("for input %q = %d; want %d", line, got, i) + } + } + _, err := parseHexUint([]byte("bogus")) + if err == nil { + t.Error("expected error on bogus input") + } +} diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go index 54564e098..5ee0804c7 100644 --- a/src/pkg/net/http/client.go +++ b/src/pkg/net/http/client.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. // HTTP client. See RFC 2616. -// +// // This is the high-level Client interface. // The low-level implementation is in transport.go. @@ -14,6 +14,7 @@ import ( "errors" "fmt" "io" + "log" "net/url" "strings" ) @@ -32,17 +33,19 @@ type Client struct { // CheckRedirect specifies the policy for handling redirects. // If CheckRedirect is not nil, the client calls it before - // following an HTTP redirect. The arguments req and via - // are the upcoming request and the requests made already, - // oldest first. If CheckRedirect returns an error, the client - // returns that error instead of issue the Request req. + // following an HTTP redirect. The arguments req and via are + // the upcoming request and the requests made already, oldest + // first. If CheckRedirect returns an error, the Client's Get + // method returns both the previous Response and + // CheckRedirect's error (wrapped in a url.Error) instead of + // issuing the Request req. // // If CheckRedirect is nil, the Client uses its default policy, // which is to stop after 10 consecutive requests. CheckRedirect func(req *Request, via []*Request) error - // Jar specifies the cookie jar. - // If Jar is nil, cookies are not sent in requests and ignored + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar CookieJar } @@ -84,10 +87,32 @@ type readClose struct { io.Closer } +func (c *Client) send(req *Request) (*Response, error) { + if c.Jar != nil { + for _, cookie := range c.Jar.Cookies(req.URL) { + req.AddCookie(cookie) + } + } + resp, err := send(req, c.Transport) + if err != nil { + return nil, err + } + if c.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + c.Jar.SetCookies(req.URL, rc) + } + } + return resp, err +} + // Do sends an HTTP request and returns an HTTP response, following // policy (e.g. redirects, cookies, auth) as configured on the client. // -// A non-nil response always contains a non-nil resp.Body. +// An error is returned if caused by client policy (such as +// CheckRedirect), or if there was an HTTP protocol error. +// A non-2xx response doesn't cause an error. +// +// When err is nil, resp always contains a non-nil resp.Body. // // Callers should close resp.Body when done reading from it. If // resp.Body is not closed, the Client's underlying RoundTripper @@ -97,12 +122,16 @@ type readClose struct { // Generally Get, Post, or PostForm will be used instead of Do. func (c *Client) Do(req *Request) (resp *Response, err error) { if req.Method == "GET" || req.Method == "HEAD" { - return c.doFollowingRedirects(req) + return c.doFollowingRedirects(req, shouldRedirectGet) } - return send(req, c.Transport) + if req.Method == "POST" || req.Method == "PUT" { + return c.doFollowingRedirects(req, shouldRedirectPost) + } + return c.send(req) } -// send issues an HTTP request. Caller should close resp.Body when done reading from it. +// send issues an HTTP request. +// Caller should close resp.Body when done reading from it. func send(req *Request, t RoundTripper) (resp *Response, err error) { if t == nil { t = DefaultTransport @@ -130,12 +159,19 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { if u := req.URL.User; u != nil { req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(u.String()))) } - return t.RoundTrip(req) + resp, err = t.RoundTrip(req) + if err != nil { + if resp != nil { + log.Printf("RoundTripper returned a response & error; ignoring response") + } + return nil, err + } + return resp, nil } // True if the specified HTTP status code is one for which the Get utility should // automatically redirect. -func shouldRedirect(statusCode int) bool { +func shouldRedirectGet(statusCode int) bool { switch statusCode { case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect: return true @@ -143,6 +179,16 @@ func shouldRedirect(statusCode int) bool { return false } +// True if the specified HTTP status code is one for which the Post utility should +// automatically redirect. +func shouldRedirectPost(statusCode int) bool { + switch statusCode { + case StatusFound, StatusSeeOther: + return true + } + return false +} + // Get issues a GET to the specified URL. If the response is one of the following // redirect codes, Get follows the redirect, up to a maximum of 10 redirects: // @@ -151,10 +197,15 @@ func shouldRedirect(statusCode int) bool { // 303 (See Other) // 307 (Temporary Redirect) // -// Caller should close r.Body when done reading from it. +// An error is returned if there were too many redirects or if there +// was an HTTP protocol error. A non-2xx response doesn't cause an +// error. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. // // Get is a wrapper around DefaultClient.Get. -func Get(url string) (r *Response, err error) { +func Get(url string) (resp *Response, err error) { return DefaultClient.Get(url) } @@ -167,18 +218,21 @@ func Get(url string) (r *Response, err error) { // 303 (See Other) // 307 (Temporary Redirect) // -// Caller should close r.Body when done reading from it. -func (c *Client) Get(url string) (r *Response, err error) { +// An error is returned if the Client's CheckRedirect function fails +// or if there was an HTTP protocol error. A non-2xx response doesn't +// cause an error. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +func (c *Client) Get(url string) (resp *Response, err error) { req, err := NewRequest("GET", url, nil) if err != nil { return nil, err } - return c.doFollowingRedirects(req) + return c.doFollowingRedirects(req, shouldRedirectGet) } -func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { - // TODO: if/when we add cookie support, the redirected request shouldn't - // necessarily supply the same cookies as the original. +func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) { var base *url.URL redirectChecker := c.CheckRedirect if redirectChecker == nil { @@ -190,17 +244,16 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { return nil, errors.New("http: nil Request.URL") } - jar := c.Jar - if jar == nil { - jar = blackHoleJar{} - } - req := ireq urlStr := "" // next relative or absolute URL to fetch (after first request) + redirectFailed := false for redirect := 0; ; redirect++ { if redirect != 0 { req = new(Request) req.Method = ireq.Method + if ireq.Method == "POST" || ireq.Method == "PUT" { + req.Method = "GET" + } req.Header = make(Header) req.URL, err = base.Parse(urlStr) if err != nil { @@ -215,26 +268,21 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { err = redirectChecker(req, via) if err != nil { + redirectFailed = true break } } } - for _, cookie := range jar.Cookies(req.URL) { - req.AddCookie(cookie) - } urlStr = req.URL.String() - if r, err = send(req, c.Transport); err != nil { + if resp, err = c.send(req); err != nil { break } - if c := r.Cookies(); len(c) > 0 { - jar.SetCookies(req.URL, c) - } - if shouldRedirect(r.StatusCode) { - r.Body.Close() - if urlStr = r.Header.Get("Location"); urlStr == "" { - err = errors.New(fmt.Sprintf("%d response missing Location header", r.StatusCode)) + if shouldRedirect(resp.StatusCode) { + resp.Body.Close() + if urlStr = resp.Header.Get("Location"); urlStr == "" { + err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode)) break } base = req.URL @@ -245,12 +293,23 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { } method := ireq.Method - err = &url.Error{ + urlErr := &url.Error{ Op: method[0:1] + strings.ToLower(method[1:]), URL: urlStr, Err: err, } - return + + if redirectFailed { + // Special case for Go 1 compatibility: return both the response + // and an error if the CheckRedirect function failed. + // See http://golang.org/issue/3795 + return resp, urlErr + } + + if resp != nil { + resp.Body.Close() + } + return nil, urlErr } func defaultCheckRedirect(req *Request, via []*Request) error { @@ -262,49 +321,42 @@ func defaultCheckRedirect(req *Request, via []*Request) error { // Post issues a POST to the specified URL. // -// Caller should close r.Body when done reading from it. +// Caller should close resp.Body when done reading from it. // // Post is a wrapper around DefaultClient.Post -func Post(url string, bodyType string, body io.Reader) (r *Response, err error) { +func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { return DefaultClient.Post(url, bodyType, body) } // Post issues a POST to the specified URL. // -// Caller should close r.Body when done reading from it. -func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err error) { +// Caller should close resp.Body when done reading from it. +func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { req, err := NewRequest("POST", url, body) if err != nil { return nil, err } req.Header.Set("Content-Type", bodyType) - if c.Jar != nil { - for _, cookie := range c.Jar.Cookies(req.URL) { - req.AddCookie(cookie) - } - } - r, err = send(req, c.Transport) - if err == nil && c.Jar != nil { - c.Jar.SetCookies(req.URL, r.Cookies()) - } - return r, err + return c.doFollowingRedirects(req, shouldRedirectPost) } -// PostForm issues a POST to the specified URL, -// with data's keys and values urlencoded as the request body. +// PostForm issues a POST to the specified URL, with data's keys and +// values URL-encoded as the request body. // -// Caller should close r.Body when done reading from it. +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. // // PostForm is a wrapper around DefaultClient.PostForm -func PostForm(url string, data url.Values) (r *Response, err error) { +func PostForm(url string, data url.Values) (resp *Response, err error) { return DefaultClient.PostForm(url, data) } -// PostForm issues a POST to the specified URL, +// PostForm issues a POST to the specified URL, // with data's keys and values urlencoded as the request body. // -// Caller should close r.Body when done reading from it. -func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) { +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) } @@ -318,7 +370,7 @@ func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) // 307 (Temporary Redirect) // // Head is a wrapper around DefaultClient.Head -func Head(url string) (r *Response, err error) { +func Head(url string) (resp *Response, err error) { return DefaultClient.Head(url) } @@ -330,10 +382,10 @@ func Head(url string) (r *Response, err error) { // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) -func (c *Client) Head(url string) (r *Response, err error) { +func (c *Client) Head(url string) (resp *Response, err error) { req, err := NewRequest("HEAD", url, nil) if err != nil { return nil, err } - return c.doFollowingRedirects(req) + return c.doFollowingRedirects(req, shouldRedirectGet) } diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go index 9b4261b9f..88649bb16 100644 --- a/src/pkg/net/http/client_test.go +++ b/src/pkg/net/http/client_test.go @@ -7,7 +7,9 @@ package http_test import ( + "bytes" "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -53,6 +55,7 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { } func TestClient(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -70,6 +73,7 @@ func TestClient(t *testing.T) { } func TestClientHead(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -92,6 +96,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) } func TestGetRequestFormat(t *testing.T) { + defer checkLeakedTransports(t) tr := &recordingTransport{} client := &Client{Transport: tr} url := "http://dummy.faketld/" @@ -108,6 +113,7 @@ func TestGetRequestFormat(t *testing.T) { } func TestPostRequestFormat(t *testing.T) { + defer checkLeakedTransports(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -134,6 +140,7 @@ func TestPostRequestFormat(t *testing.T) { } func TestPostFormRequestFormat(t *testing.T) { + defer checkLeakedTransports(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -175,6 +182,7 @@ func TestPostFormRequestFormat(t *testing.T) { } func TestRedirects(t *testing.T) { + defer checkLeakedTransports(t) var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { n, _ := strconv.Atoi(r.FormValue("n")) @@ -218,6 +226,10 @@ func TestRedirects(t *testing.T) { return checkErr }} res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + res.Body.Close() finalUrl := res.Request.URL.String() if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { t.Errorf("with custom client, expected error %q, got %q", e, g) @@ -231,9 +243,63 @@ func TestRedirects(t *testing.T) { checkErr = errors.New("no redirects allowed") res, err = c.Get(ts.URL) - finalUrl = res.Request.URL.String() - if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { - t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) + if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr { + t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err) + } + if res == nil { + t.Fatalf("Expected a non-nil Response on CheckRedirect failure (http://golang.org/issue/3795)") + } + res.Body.Close() + if res.Header.Get("Location") == "" { + t.Errorf("no Location header in Response") + } +} + +func TestPostRedirects(t *testing.T) { + defer checkLeakedTransports(t) + var log struct { + sync.Mutex + bytes.Buffer + } + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + log.Lock() + fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI) + log.Unlock() + if v := r.URL.Query().Get("code"); v != "" { + code, _ := strconv.Atoi(v) + if code/100 == 3 { + w.Header().Set("Location", ts.URL) + } + w.WriteHeader(code) + } + })) + defer ts.Close() + tests := []struct { + suffix string + want int // response code + }{ + {"/", 200}, + {"/?code=301", 301}, + {"/?code=302", 200}, + {"/?code=303", 200}, + {"/?code=404", 404}, + } + for _, tt := range tests { + res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content")) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != tt.want { + t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want) + } + } + log.Lock() + got := log.String() + log.Unlock() + want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 " + if got != want { + t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want) } } @@ -279,6 +345,10 @@ func TestClientSendsCookieFromJar(t *testing.T) { req, _ := NewRequest("GET", us, nil) client.Do(req) // Note: doesn't hit network matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) + + req, _ = NewRequest("POST", us, nil) + client.Do(req) // Note: doesn't hit network + matchReturnedCookies(t, expectedCookies, tr.req.Cookies()) } // Just enough correctness for our redirect tests. Uses the URL.Host as the @@ -291,6 +361,9 @@ type TestJar struct { func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) { j.m.Lock() defer j.m.Unlock() + if j.perURL == nil { + j.perURL = make(map[string][]*Cookie) + } j.perURL[u.Host] = cookies } @@ -301,6 +374,7 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { } func TestRedirectCookiesOnRequest(t *testing.T) { + defer checkLeakedTransports(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() @@ -318,14 +392,20 @@ func TestRedirectCookiesOnRequest(t *testing.T) { } func TestRedirectCookiesJar(t *testing.T) { + defer checkLeakedTransports(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() - c := &Client{} - c.Jar = &TestJar{perURL: make(map[string][]*Cookie)} + c := &Client{ + Jar: new(TestJar), + } u, _ := url.Parse(ts.URL) c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) - resp, _ := c.Get(ts.URL) + resp, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + resp.Body.Close() matchReturnedCookies(t, expectedCookies, resp.Cookies()) } @@ -348,7 +428,72 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { } } +func TestJarCalls(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + pathSuffix := r.RequestURI[1:] + if r.RequestURI == "/nosetcookie" { + return // dont set cookies for this path + } + SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix}) + if r.RequestURI == "/" { + Redirect(w, r, "http://secondhost.fake/secondpath", 302) + } + })) + defer ts.Close() + jar := new(RecordingJar) + c := &Client{ + Jar: jar, + Transport: &Transport{ + Dial: func(_ string, _ string) (net.Conn, error) { + return net.Dial("tcp", ts.Listener.Addr().String()) + }, + }, + } + _, err := c.Get("http://firsthost.fake/") + if err != nil { + t.Fatal(err) + } + _, err = c.Get("http://firsthost.fake/nosetcookie") + if err != nil { + t.Fatal(err) + } + got := jar.log.String() + want := `Cookies("http://firsthost.fake/") +SetCookie("http://firsthost.fake/", [name=val]) +Cookies("http://secondhost.fake/secondpath") +SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath]) +Cookies("http://firsthost.fake/nosetcookie") +` + if got != want { + t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want) + } +} + +// RecordingJar keeps a log of calls made to it, without +// tracking any cookies. +type RecordingJar struct { + mu sync.Mutex + log bytes.Buffer +} + +func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) { + j.logf("SetCookie(%q, %v)\n", u, cookies) +} + +func (j *RecordingJar) Cookies(u *url.URL) []*Cookie { + j.logf("Cookies(%q)\n", u) + return nil +} + +func (j *RecordingJar) logf(format string, args ...interface{}) { + j.mu.Lock() + defer j.mu.Unlock() + fmt.Fprintf(&j.log, format, args...) +} + func TestStreamingGet(t *testing.T) { + defer checkLeakedTransports(t) say := make(chan string) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() @@ -399,6 +544,7 @@ func (c *writeCountingConn) Write(p []byte) (int, error) { // TestClientWrites verifies that client requests are buffered and we // don't send a TCP packet per line of the http request + body. func TestClientWrites(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() @@ -432,6 +578,7 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) })) @@ -446,15 +593,20 @@ func TestClientInsecureTransport(t *testing.T) { InsecureSkipVerify: insecure, }, } + defer tr.CloseIdleConnections() c := &Client{Transport: tr} - _, err := c.Get(ts.URL) + res, err := c.Get(ts.URL) if (err == nil) != insecure { t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) } + if res != nil { + res.Body.Close() + } } } func TestClientErrorWithRequestURI(t *testing.T) { + defer checkLeakedTransports(t) req, _ := NewRequest("GET", "http://localhost:1234/", nil) req.RequestURI = "/this/field/is/illegal/and/should/error/" _, err := DefaultClient.Do(req) @@ -465,3 +617,87 @@ func TestClientErrorWithRequestURI(t *testing.T) { t.Errorf("wanted error mentioning RequestURI; got error: %v", err) } } + +func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { + certs := x509.NewCertPool() + for _, c := range ts.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + return &Transport{ + TLSClientConfig: &tls.Config{RootCAs: certs}, + } +} + +func TestClientWithCorrectTLSServerName(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.TLS.ServerName != "127.0.0.1" { + t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName) + } + })) + defer ts.Close() + + c := &Client{Transport: newTLSTransport(t, ts)} + if _, err := c.Get(ts.URL); err != nil { + t.Fatalf("expected successful TLS connection, got error: %v", err) + } +} + +func TestClientWithIncorrectTLSServerName(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + + trans := newTLSTransport(t, ts) + trans.TLSClientConfig.ServerName = "badserver" + c := &Client{Transport: trans} + _, err := c.Get(ts.URL) + if err == nil { + t.Fatalf("expected an error") + } + if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { + t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) + } +} + +// Verify Response.ContentLength is populated. http://golang.org/issue/4126 +func TestClientHeadContentLength(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if v := r.FormValue("cl"); v != "" { + w.Header().Set("Content-Length", v) + } + })) + defer ts.Close() + tests := []struct { + suffix string + want int64 + }{ + {"/?cl=1234", 1234}, + {"/?cl=0", 0}, + {"", -1}, + } + for _, tt := range tests { + req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil) + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if res.ContentLength != tt.want { + t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want) + } + bs, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != 0 { + t.Errorf("Unexpected content: %q", bs) + } + } +} diff --git a/src/pkg/net/http/cookie.go b/src/pkg/net/http/cookie.go index 2e30bbff1..155b09223 100644 --- a/src/pkg/net/http/cookie.go +++ b/src/pkg/net/http/cookie.go @@ -26,7 +26,7 @@ type Cookie struct { Expires time.Time RawExpires string - // MaxAge=0 means no 'Max-Age' attribute specified. + // MaxAge=0 means no 'Max-Age' attribute specified. // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' // MaxAge>0 means Max-Age attribute present and given in seconds MaxAge int @@ -258,10 +258,5 @@ func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool) } func isCookieNameValid(raw string) bool { - for _, c := range raw { - if !isToken(byte(c)) { - return false - } - } - return true + return strings.IndexFunc(raw, isNotToken) < 0 } diff --git a/src/pkg/net/http/cookie_test.go b/src/pkg/net/http/cookie_test.go index 1e9186a05..f84f73936 100644 --- a/src/pkg/net/http/cookie_test.go +++ b/src/pkg/net/http/cookie_test.go @@ -217,7 +217,7 @@ var readCookiesTests = []struct { func TestReadCookies(t *testing.T) { for i, tt := range readCookiesTests { - for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input + for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input c := readCookies(tt.Header, tt.Filter) if !reflect.DeepEqual(c, tt.Cookies) { t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies)) diff --git a/src/pkg/net/http/cookiejar/jar.go b/src/pkg/net/http/cookiejar/jar.go new file mode 100644 index 000000000..5d1aeb87f --- /dev/null +++ b/src/pkg/net/http/cookiejar/jar.go @@ -0,0 +1,494 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar. +package cookiejar + +import ( + "errors" + "fmt" + "net" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" +) + +// PublicSuffixList provides the public suffix of a domain. For example: +// - the public suffix of "example.com" is "com", +// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and +// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us". +// +// Implementations of PublicSuffixList must be safe for concurrent use by +// multiple goroutines. +// +// An implementation that always returns "" is valid and may be useful for +// testing but it is not secure: it means that the HTTP server for foo.com can +// set a cookie for bar.com. +type PublicSuffixList interface { + // PublicSuffix returns the public suffix of domain. + // + // TODO: specify which of the caller and callee is responsible for IP + // addresses, for leading and trailing dots, for case sensitivity, and + // for IDN/Punycode. + PublicSuffix(domain string) string + + // String returns a description of the source of this public suffix + // list. The description will typically contain something like a time + // stamp or version number. + String() string +} + +// Options are the options for creating a new Jar. +type Options struct { + // PublicSuffixList is the public suffix list that determines whether + // an HTTP server can set a cookie for a domain. + // + // A nil value is valid and may be useful for testing but it is not + // secure: it means that the HTTP server for foo.co.uk can set a cookie + // for bar.co.uk. + PublicSuffixList PublicSuffixList +} + +// Jar implements the http.CookieJar interface from the net/http package. +type Jar struct { + psList PublicSuffixList + + // mu locks the remaining fields. + mu sync.Mutex + + // entries is a set of entries, keyed by their eTLD+1 and subkeyed by + // their name/domain/path. + entries map[string]map[string]entry + + // nextSeqNum is the next sequence number assigned to a new cookie + // created SetCookies. + nextSeqNum uint64 +} + +// New returns a new cookie jar. A nil *Options is equivalent to a zero +// Options. +func New(o *Options) (*Jar, error) { + jar := &Jar{ + entries: make(map[string]map[string]entry), + } + if o != nil { + jar.psList = o.PublicSuffixList + } + return jar, nil +} + +// entry is the internal representation of a cookie. +// +// This struct type is not used outside of this package per se, but the exported +// fields are those of RFC 6265. +type entry struct { + Name string + Value string + Domain string + Path string + Secure bool + HttpOnly bool + Persistent bool + HostOnly bool + Expires time.Time + Creation time.Time + LastAccess time.Time + + // seqNum is a sequence number so that Cookies returns cookies in a + // deterministic order, even for cookies that have equal Path length and + // equal Creation time. This simplifies testing. + seqNum uint64 +} + +// Id returns the domain;path;name triple of e as an id. +func (e *entry) id() string { + return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name) +} + +// shouldSend determines whether e's cookie qualifies to be included in a +// request to host/path. It is the caller's responsibility to check if the +// cookie is expired. +func (e *entry) shouldSend(https bool, host, path string) bool { + return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure) +} + +// domainMatch implements "domain-match" of RFC 6265 section 5.1.3. +func (e *entry) domainMatch(host string) bool { + if e.Domain == host { + return true + } + return !e.HostOnly && hasDotSuffix(host, e.Domain) +} + +// pathMatch implements "path-match" according to RFC 6265 section 5.1.4. +func (e *entry) pathMatch(requestPath string) bool { + if requestPath == e.Path { + return true + } + if strings.HasPrefix(requestPath, e.Path) { + if e.Path[len(e.Path)-1] == '/' { + return true // The "/any/" matches "/any/path" case. + } else if requestPath[len(e.Path)] == '/' { + return true // The "/any" matches "/any/path" case. + } + } + return false +} + +// hasDotSuffix returns whether s ends in "."+suffix. +func hasDotSuffix(s, suffix string) bool { + return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix +} + +// byPathLength is a []entry sort.Interface that sorts according to RFC 6265 +// section 5.4 point 2: by longest path and then by earliest creation time. +type byPathLength []entry + +func (s byPathLength) Len() int { return len(s) } + +func (s byPathLength) Less(i, j int) bool { + if len(s[i].Path) != len(s[j].Path) { + return len(s[i].Path) > len(s[j].Path) + } + if !s[i].Creation.Equal(s[j].Creation) { + return s[i].Creation.Before(s[j].Creation) + } + return s[i].seqNum < s[j].seqNum +} + +func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Cookies implements the Cookies method of the http.CookieJar interface. +// +// It returns an empty slice if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) { + return j.cookies(u, time.Now()) +} + +// cookies is like Cookies but takes the current time as a parameter. +func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { + if u.Scheme != "http" && u.Scheme != "https" { + return cookies + } + host, err := canonicalHost(u.Host) + if err != nil { + return cookies + } + key := jarKey(host, j.psList) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + if submap == nil { + return cookies + } + + https := u.Scheme == "https" + path := u.Path + if path == "" { + path = "/" + } + + modified := false + var selected []entry + for id, e := range submap { + if e.Persistent && !e.Expires.After(now) { + delete(submap, id) + modified = true + continue + } + if !e.shouldSend(https, host, path) { + continue + } + e.LastAccess = now + submap[id] = e + selected = append(selected, e) + modified = true + } + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } + + sort.Sort(byPathLength(selected)) + for _, e := range selected { + cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value}) + } + + return cookies +} + +// SetCookies implements the SetCookies method of the http.CookieJar interface. +// +// It does nothing if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.setCookies(u, cookies, time.Now()) +} + +// setCookies is like SetCookies but takes the current time as parameter. +func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) { + if len(cookies) == 0 { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + host, err := canonicalHost(u.Host) + if err != nil { + return + } + key := jarKey(host, j.psList) + defPath := defaultPath(u.Path) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + + modified := false + for _, cookie := range cookies { + e, remove, err := j.newEntry(cookie, now, defPath, host) + if err != nil { + continue + } + id := e.id() + if remove { + if submap != nil { + if _, ok := submap[id]; ok { + delete(submap, id) + modified = true + } + } + continue + } + if submap == nil { + submap = make(map[string]entry) + } + + if old, ok := submap[id]; ok { + e.Creation = old.Creation + e.seqNum = old.seqNum + } else { + e.Creation = now + e.seqNum = j.nextSeqNum + j.nextSeqNum++ + } + e.LastAccess = now + submap[id] = e + modified = true + } + + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } +} + +// canonicalHost strips port from host if present and returns the canonicalized +// host name. +func canonicalHost(host string) (string, error) { + var err error + host = strings.ToLower(host) + if hasPort(host) { + host, _, err = net.SplitHostPort(host) + if err != nil { + return "", err + } + } + if strings.HasSuffix(host, ".") { + // Strip trailing dot from fully qualified domain names. + host = host[:len(host)-1] + } + return toASCII(host) +} + +// hasPort returns whether host contains a port number. host may be a host +// name, an IPv4 or an IPv6 address. +func hasPort(host string) bool { + colons := strings.Count(host, ":") + if colons == 0 { + return false + } + if colons == 1 { + return true + } + return host[0] == '[' && strings.Contains(host, "]:") +} + +// jarKey returns the key to use for a jar. +func jarKey(host string, psl PublicSuffixList) string { + if isIP(host) { + return host + } + + var i int + if psl == nil { + i = strings.LastIndex(host, ".") + if i == -1 { + return host + } + } else { + suffix := psl.PublicSuffix(host) + if suffix == host { + return host + } + i = len(host) - len(suffix) + if i <= 0 || host[i-1] != '.' { + // The provided public suffix list psl is broken. + // Storing cookies under host is a safe stopgap. + return host + } + } + prevDot := strings.LastIndex(host[:i-1], ".") + return host[prevDot+1:] +} + +// isIP returns whether host is an IP address. +func isIP(host string) bool { + return net.ParseIP(host) != nil +} + +// defaultPath returns the directory part of an URL's path according to +// RFC 6265 section 5.1.4. +func defaultPath(path string) string { + if len(path) == 0 || path[0] != '/' { + return "/" // Path is empty or malformed. + } + + i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1. + if i == 0 { + return "/" // Path has the form "/abc". + } + return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/". +} + +// newEntry creates an entry from a http.Cookie c. now is the current time and +// is compared to c.Expires to determine deletion of c. defPath and host are the +// default-path and the canonical host name of the URL c was received from. +// +// remove is whether the jar should delete this cookie, as it has already +// expired with respect to now. In this case, e may be incomplete, but it will +// be valid to call e.id (which depends on e's Name, Domain and Path). +// +// A malformed c.Domain will result in an error. +func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) { + e.Name = c.Name + + if c.Path == "" || c.Path[0] != '/' { + e.Path = defPath + } else { + e.Path = c.Path + } + + e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain) + if err != nil { + return e, false, err + } + + // MaxAge takes precedence over Expires. + if c.MaxAge < 0 { + return e, true, nil + } else if c.MaxAge > 0 { + e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second) + e.Persistent = true + } else { + if c.Expires.IsZero() { + e.Expires = endOfTime + e.Persistent = false + } else { + if !c.Expires.After(now) { + return e, true, nil + } + e.Expires = c.Expires + e.Persistent = true + } + } + + e.Value = c.Value + e.Secure = c.Secure + e.HttpOnly = c.HttpOnly + + return e, false, nil +} + +var ( + errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute") + errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute") + errNoHostname = errors.New("cookiejar: no host name available (IP only)") +) + +// endOfTime is the time when session (non-persistent) cookies expire. +// This instant is representable in most date/time formats (not just +// Go's time.Time) and should be far enough in the future. +var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) + +// domainAndType determines the cookie's domain and hostOnly attribute. +func (j *Jar) domainAndType(host, domain string) (string, bool, error) { + if domain == "" { + // No domain attribute in the SetCookie header indicates a + // host cookie. + return host, true, nil + } + + if isIP(host) { + // According to RFC 6265 domain-matching includes not being + // an IP address. + // TODO: This might be relaxed as in common browsers. + return "", false, errNoHostname + } + + // From here on: If the cookie is valid, it is a domain cookie (with + // the one exception of a public suffix below). + // See RFC 6265 section 5.2.3. + if domain[0] == '.' { + domain = domain[1:] + } + + if len(domain) == 0 || domain[0] == '.' { + // Received either "Domain=." or "Domain=..some.thing", + // both are illegal. + return "", false, errMalformedDomain + } + domain = strings.ToLower(domain) + + if domain[len(domain)-1] == '.' { + // We received stuff like "Domain=www.example.com.". + // Browsers do handle such stuff (actually differently) but + // RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in + // requiring a reject. 4.1.2.3 is not normative, but + // "Domain Matching" (5.1.3) and "Canonicalized Host Names" + // (5.1.2) are. + return "", false, errMalformedDomain + } + + // See RFC 6265 section 5.3 #5. + if j.psList != nil { + if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) { + if host == domain { + // This is the one exception in which a cookie + // with a domain attribute is a host cookie. + return host, true, nil + } + return "", false, errIllegalDomain + } + } + + // The domain must domain-match host: www.mycompany.com cannot + // set cookies for .ourcompetitors.com. + if host != domain && !hasDotSuffix(host, domain) { + return "", false, errIllegalDomain + } + + return domain, false, nil +} diff --git a/src/pkg/net/http/cookiejar/jar_test.go b/src/pkg/net/http/cookiejar/jar_test.go new file mode 100644 index 000000000..3aa601586 --- /dev/null +++ b/src/pkg/net/http/cookiejar/jar_test.go @@ -0,0 +1,1267 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "testing" + "time" +) + +// tNow is the synthetic current time used as now during testing. +var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC) + +// testPSL implements PublicSuffixList with just two rules: "co.uk" +// and the default rule "*". +type testPSL struct{} + +func (testPSL) String() string { + return "testPSL" +} +func (testPSL) PublicSuffix(d string) string { + if d == "co.uk" || strings.HasSuffix(d, ".co.uk") { + return "co.uk" + } + return d[strings.LastIndex(d, ".")+1:] +} + +// newTestJar creates an empty Jar with testPSL as the public suffix list. +func newTestJar() *Jar { + jar, err := New(&Options{PublicSuffixList: testPSL{}}) + if err != nil { + panic(err) + } + return jar +} + +var hasDotSuffixTests = [...]struct { + s, suffix string +}{ + {"", ""}, + {"", "."}, + {"", "x"}, + {".", ""}, + {".", "."}, + {".", ".."}, + {".", "x"}, + {".", "x."}, + {".", ".x"}, + {".", ".x."}, + {"x", ""}, + {"x", "."}, + {"x", ".."}, + {"x", "x"}, + {"x", "x."}, + {"x", ".x"}, + {"x", ".x."}, + {".x", ""}, + {".x", "."}, + {".x", ".."}, + {".x", "x"}, + {".x", "x."}, + {".x", ".x"}, + {".x", ".x."}, + {"x.", ""}, + {"x.", "."}, + {"x.", ".."}, + {"x.", "x"}, + {"x.", "x."}, + {"x.", ".x"}, + {"x.", ".x."}, + {"com", ""}, + {"com", "m"}, + {"com", "om"}, + {"com", "com"}, + {"com", ".com"}, + {"com", "x.com"}, + {"com", "xcom"}, + {"com", "xorg"}, + {"com", "org"}, + {"com", "rg"}, + {"foo.com", ""}, + {"foo.com", "m"}, + {"foo.com", "om"}, + {"foo.com", "com"}, + {"foo.com", ".com"}, + {"foo.com", "o.com"}, + {"foo.com", "oo.com"}, + {"foo.com", "foo.com"}, + {"foo.com", ".foo.com"}, + {"foo.com", "x.foo.com"}, + {"foo.com", "xfoo.com"}, + {"foo.com", "xfoo.org"}, + {"foo.com", "foo.org"}, + {"foo.com", "oo.org"}, + {"foo.com", "o.org"}, + {"foo.com", ".org"}, + {"foo.com", "org"}, + {"foo.com", "rg"}, +} + +func TestHasDotSuffix(t *testing.T) { + for _, tc := range hasDotSuffixTests { + got := hasDotSuffix(tc.s, tc.suffix) + want := strings.HasSuffix(tc.s, "."+tc.suffix) + if got != want { + t.Errorf("s=%q, suffix=%q: got %v, want %v", tc.s, tc.suffix, got, want) + } + } +} + +var canonicalHostTests = map[string]string{ + "www.example.com": "www.example.com", + "WWW.EXAMPLE.COM": "www.example.com", + "wWw.eXAmple.CoM": "www.example.com", + "www.example.com:80": "www.example.com", + "192.168.0.10": "192.168.0.10", + "192.168.0.5:8080": "192.168.0.5", + "2001:4860:0:2001::68": "2001:4860:0:2001::68", + "[2001:4860:0:::68]:8080": "2001:4860:0:::68", + "www.bücher.de": "www.xn--bcher-kva.de", + "www.example.com.": "www.example.com", + "[bad.unmatched.bracket:": "error", +} + +func TestCanonicalHost(t *testing.T) { + for h, want := range canonicalHostTests { + got, err := canonicalHost(h) + if want == "error" { + if err == nil { + t.Errorf("%q: got nil error, want non-nil", h) + } + continue + } + if err != nil { + t.Errorf("%q: %v", h, err) + continue + } + if got != want { + t.Errorf("%q: got %q, want %q", h, got, want) + continue + } + } +} + +var hasPortTests = map[string]bool{ + "www.example.com": false, + "www.example.com:80": true, + "127.0.0.1": false, + "127.0.0.1:8080": true, + "2001:4860:0:2001::68": false, + "[2001::0:::68]:80": true, +} + +func TestHasPort(t *testing.T) { + for host, want := range hasPortTests { + if got := hasPort(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var jarKeyTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "bbc.co.uk", + "www.bbc.co.uk": "bbc.co.uk", + "bbc.co.uk": "bbc.co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", +} + +func TestJarKey(t *testing.T) { + for host, want := range jarKeyTests { + if got := jarKey(host, testPSL{}); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var jarKeyNilPSLTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "co.uk", + "www.bbc.co.uk": "co.uk", + "bbc.co.uk": "co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", +} + +func TestJarKeyNilPSL(t *testing.T) { + for host, want := range jarKeyNilPSLTests { + if got := jarKey(host, nil); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var isIPTests = map[string]bool{ + "127.0.0.1": true, + "1.2.3.4": true, + "2001:4860:0:2001::68": true, + "example.com": false, + "1.1.1.300": false, + "www.foo.bar.net": false, + "123.foo.bar.net": false, +} + +func TestIsIP(t *testing.T) { + for host, want := range isIPTests { + if got := isIP(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var defaultPathTests = map[string]string{ + "/": "/", + "/abc": "/", + "/abc/": "/abc", + "/abc/xyz": "/abc", + "/abc/xyz/": "/abc/xyz", + "/a/b/c.html": "/a/b", + "": "/", + "strange": "/", + "//": "/", + "/a//b": "/a/", + "/a/./b": "/a/.", + "/a/../b": "/a/..", +} + +func TestDefaultPath(t *testing.T) { + for path, want := range defaultPathTests { + if got := defaultPath(path); got != want { + t.Errorf("%q: got %q, want %q", path, got, want) + } + } +} + +var domainAndTypeTests = [...]struct { + host string // host Set-Cookie header was received from + domain string // domain attribute in Set-Cookie header + wantDomain string // expected domain of cookie + wantHostOnly bool // expected host-cookie flag + wantErr error // expected error +}{ + {"www.example.com", "", "www.example.com", true, nil}, + {"127.0.0.1", "", "127.0.0.1", true, nil}, + {"2001:4860:0:2001::68", "", "2001:4860:0:2001::68", true, nil}, + {"www.example.com", "example.com", "example.com", false, nil}, + {"www.example.com", ".example.com", "example.com", false, nil}, + {"www.example.com", "www.example.com", "www.example.com", false, nil}, + {"www.example.com", ".www.example.com", "www.example.com", false, nil}, + {"foo.sso.example.com", "sso.example.com", "sso.example.com", false, nil}, + {"bar.co.uk", "bar.co.uk", "bar.co.uk", false, nil}, + {"foo.bar.co.uk", ".bar.co.uk", "bar.co.uk", false, nil}, + {"127.0.0.1", "127.0.0.1", "", false, errNoHostname}, + {"2001:4860:0:2001::68", "2001:4860:0:2001::68", "2001:4860:0:2001::68", false, errNoHostname}, + {"www.example.com", ".", "", false, errMalformedDomain}, + {"www.example.com", "..", "", false, errMalformedDomain}, + {"www.example.com", "other.com", "", false, errIllegalDomain}, + {"www.example.com", "com", "", false, errIllegalDomain}, + {"www.example.com", ".com", "", false, errIllegalDomain}, + {"foo.bar.co.uk", ".co.uk", "", false, errIllegalDomain}, + {"127.www.0.0.1", "127.0.0.1", "", false, errIllegalDomain}, + {"com", "", "com", true, nil}, + {"com", "com", "com", true, nil}, + {"com", ".com", "com", true, nil}, + {"co.uk", "", "co.uk", true, nil}, + {"co.uk", "co.uk", "co.uk", true, nil}, + {"co.uk", ".co.uk", "co.uk", true, nil}, +} + +func TestDomainAndType(t *testing.T) { + jar := newTestJar() + for _, tc := range domainAndTypeTests { + domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain) + if err != tc.wantErr { + t.Errorf("%q/%q: got %q error, want %q", + tc.host, tc.domain, err, tc.wantErr) + continue + } + if err != nil { + continue + } + if domain != tc.wantDomain || hostOnly != tc.wantHostOnly { + t.Errorf("%q/%q: got %q/%t want %q/%t", + tc.host, tc.domain, domain, hostOnly, + tc.wantDomain, tc.wantHostOnly) + } + } +} + +// expiresIn creates an expires attribute delta seconds from tNow. +func expiresIn(delta int) string { + t := tNow.Add(time.Duration(delta) * time.Second) + return "expires=" + t.Format(time.RFC1123) +} + +// mustParseURL parses s to an URL and panics on error. +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil || u.Scheme == "" || u.Host == "" { + panic(fmt.Sprintf("Unable to parse URL %s.", s)) + } + return u +} + +// jarTest encapsulates the following actions on a jar: +// 1. Perform SetCookies with fromURL and the cookies from setCookies. +// (Done at time tNow + 0 ms.) +// 2. Check that the entries in the jar matches content. +// (Done at time tNow + 1001 ms.) +// 3. For each query in tests: Check that Cookies with toURL yields the +// cookies in want. +// (Query n done at tNow + (n+2)*1001 ms.) +type jarTest struct { + description string // The description of what this test is supposed to test + fromURL string // The full URL of the request from which Set-Cookie headers where received + setCookies []string // All the cookies received from fromURL + content string // The whole (non-expired) content of the jar + queries []query // Queries to test the Jar.Cookies method +} + +// query contains one test of the cookies returned from Jar.Cookies. +type query struct { + toURL string // the URL in the Cookies call + want string // the expected list of cookies (order matters) +} + +// run runs the jarTest. +func (test jarTest) run(t *testing.T, jar *Jar) { + now := tNow + + // Populate jar with cookies. + setCookies := make([]*http.Cookie, len(test.setCookies)) + for i, cs := range test.setCookies { + cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies() + if len(cookies) != 1 { + panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies)) + } + setCookies[i] = cookies[0] + } + jar.setCookies(mustParseURL(test.fromURL), setCookies, now) + now = now.Add(1001 * time.Millisecond) + + // Serialize non-expired entries in the form "name1=val1 name2=val2". + var cs []string + for _, submap := range jar.entries { + for _, cookie := range submap { + if !cookie.Expires.After(now) { + continue + } + cs = append(cs, cookie.Name+"="+cookie.Value) + } + } + sort.Strings(cs) + got := strings.Join(cs, " ") + + // Make sure jar content matches our expectations. + if got != test.content { + t.Errorf("Test %q Content\ngot %q\nwant %q", + test.description, got, test.content) + } + + // Test different calls to Cookies. + for i, query := range test.queries { + now = now.Add(1001 * time.Millisecond) + var s []string + for _, c := range jar.cookies(mustParseURL(query.toURL), now) { + s = append(s, c.Name+"="+c.Value) + } + if got := strings.Join(s, " "); got != query.want { + t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want) + } + } +} + +// basicsTests contains fundamental tests. Each jarTest has to be performed on +// a fresh, empty Jar. +var basicsTests = [...]jarTest{ + { + "Retrieval of a plain host cookie.", + "http://www.host.test/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + {"ftp://www.host.test", ""}, + {"ftp://www.host.test/", ""}, + {"ftp://www.host.test/some/path", ""}, + {"http://www.other.org", ""}, + {"http://sibling.host.test", ""}, + {"http://deep.www.host.test", ""}, + }, + }, + { + "Secure cookies are not returned to http.", + "http://www.host.test/", + []string{"A=a; secure"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some/path", ""}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + }, + }, + { + "Explicit path.", + "http://www.host.test/", + []string{"A=a; path=/some/path"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #1: path is a directory.", + "http://www.host.test/some/path/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #2: path is not a directory.", + "http://www.host.test/some/path/index.html", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #3: no path in URL at all.", + "http://www.host.test", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + }, + }, + { + "Cookies are sorted by path length.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "D=d; path=/foo"}, + "A=a B=b C=c D=d", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "B=b C=c A=a D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c A=a D=d"}, + {"http://www.host.test/foo/bar", "A=a D=d"}, + }, + }, + { + "Creation time determines sorting on same length paths.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "X=x; path=/foo/bar", + "Y=y; path=/foo/bar/baz/qux", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "W=w; path=/foo/bar/baz", + "Z=z; path=/foo", + "D=d; path=/foo"}, + "A=a B=b C=c D=d W=w X=x Y=y Z=z", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "Y=y B=b C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar", "A=a X=x Z=z D=d"}, + }, + }, + { + "Sorting of same-name cookies.", + "http://www.host.test/", + []string{ + "A=1; path=/", + "A=2; path=/path", + "A=3; path=/quux", + "A=4; path=/path/foo", + "A=5; domain=.host.test; path=/path", + "A=6; domain=.host.test; path=/quux", + "A=7; domain=.host.test; path=/path/foo", + }, + "A=1 A=2 A=3 A=4 A=5 A=6 A=7", + []query{ + {"http://www.host.test/path", "A=2 A=5 A=1"}, + {"http://www.host.test/path/foo", "A=4 A=7 A=2 A=5 A=1"}, + }, + }, + { + "Disallow domain cookie on public suffix.", + "http://www.bbc.co.uk", + []string{ + "a=1", + "b=2; domain=co.uk", + }, + "a=1", + []query{{"http://www.bbc.co.uk", "a=1"}}, + }, + { + "Host cookie on IP.", + "http://192.168.0.10", + []string{"a=1"}, + "a=1", + []query{{"http://192.168.0.10", "a=1"}}, + }, + { + "Port is ignored #1.", + "http://www.host.test/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + }, + }, + { + "Port is ignored #2.", + "http://www.host.test:8080/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + {"http://www.host.test:1234/", "a=1"}, + }, + }, +} + +func TestBasics(t *testing.T) { + for _, test := range basicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// updateAndDeleteTests contains jarTests which must be performed on the same +// Jar. +var updateAndDeleteTests = [...]jarTest{ + { + "Set initial cookies.", + "http://www.host.test", + []string{ + "a=1", + "b=2; secure", + "c=3; httponly", + "d=4; secure; httponly"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://www.host.test", "a=1 c=3"}, + {"https://www.host.test", "a=1 b=2 c=3 d=4"}, + }, + }, + { + "Update value via http.", + "http://www.host.test", + []string{ + "a=w", + "b=x; secure", + "c=y; httponly", + "d=z; secure; httponly"}, + "a=w b=x c=y d=z", + []query{ + {"http://www.host.test", "a=w c=y"}, + {"https://www.host.test", "a=w b=x c=y d=z"}, + }, + }, + { + "Clear Secure flag from a http.", + "http://www.host.test/", + []string{ + "b=xx", + "d=zz; httponly"}, + "a=w b=xx c=y d=zz", + []query{{"http://www.host.test", "a=w b=xx c=y d=zz"}}, + }, + { + "Delete all.", + "http://www.host.test/", + []string{ + "a=1; max-Age=-1", // delete via MaxAge + "b=2; " + expiresIn(-10), // delete via Expires + "c=2; max-age=-1; " + expiresIn(-10), // delete via both + "d=4; max-age=-1; " + expiresIn(10)}, // MaxAge takes precedence + "", + []query{{"http://www.host.test", ""}}, + }, + { + "Refill #1.", + "http://www.host.test", + []string{ + "A=1", + "A=2; path=/foo", + "A=3; domain=.host.test", + "A=4; path=/foo; domain=.host.test"}, + "A=1 A=2 A=3 A=4", + []query{{"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}}, + }, + { + "Refill #2.", + "http://www.google.com", + []string{ + "A=6", + "A=7; path=/foo", + "A=8; domain=.google.com", + "A=9; path=/foo; domain=.google.com"}, + "A=1 A=2 A=3 A=4 A=6 A=7 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=7 A=9 A=6 A=8"}, + }, + }, + { + "Delete A7.", + "http://www.google.com", + []string{"A=; path=/foo; max-age=-1"}, + "A=1 A=2 A=3 A=4 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A4.", + "http://www.host.test", + []string{"A=; path=/foo; domain=host.test; max-age=-1"}, + "A=1 A=2 A=3 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A6.", + "http://www.google.com", + []string{"A=; max-age=-1"}, + "A=1 A=2 A=3 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A3.", + "http://www.host.test", + []string{"A=; domain=host.test; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "No cross-domain delete.", + "http://www.host.test", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A8 and A9.", + "http://www.google.com", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", ""}, + }, + }, +} + +func TestUpdateAndDelete(t *testing.T) { + jar := newTestJar() + for _, test := range updateAndDeleteTests { + test.run(t, jar) + } +} + +func TestExpiration(t *testing.T) { + jar := newTestJar() + jarTest{ + "Expiration.", + "http://www.host.test", + []string{ + "a=1", + "b=2; max-age=3", + "c=3; " + expiresIn(3), + "d=4; max-age=5", + "e=5; " + expiresIn(5), + "f=6; max-age=100", + }, + "a=1 b=2 c=3 d=4 e=5 f=6", // executed at t0 + 1001 ms + []query{ + {"http://www.host.test", "a=1 b=2 c=3 d=4 e=5 f=6"}, // t0 + 2002 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 3003 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 4004 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 5005 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 6006 ms + }, + }.run(t, jar) +} + +// +// Tests derived from Chromium's cookie_store_unittest.h. +// + +// See http://src.chromium.org/viewvc/chrome/trunk/src/net/cookies/cookie_store_unittest.h?revision=159685&content-type=text/plain +// Some of the original tests are in a bad condition (e.g. +// DomainWithTrailingDotTest) or are not RFC 6265 conforming (e.g. +// TestNonDottedAndTLD #1 and #6) and have not been ported. + +// chromiumBasicsTests contains fundamental tests. Each jarTest has to be +// performed on a fresh, empty Jar. +var chromiumBasicsTests = [...]jarTest{ + { + "DomainWithTrailingDotTest.", + "http://www.google.com/", + []string{ + "a=1; domain=.www.google.com.", + "b=2; domain=.www.google.com.."}, + "", + []query{ + {"http://www.google.com", ""}, + }, + }, + { + "ValidSubdomainTest #1.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://a.b.c.d.com", "a=1 b=2 c=3 d=4"}, + {"http://b.c.d.com", "b=2 c=3 d=4"}, + {"http://c.d.com", "c=3 d=4"}, + {"http://d.com", "d=4"}, + }, + }, + { + "ValidSubdomainTest #2.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com", + "X=bcd; domain=.b.c.d.com", + "X=cd; domain=.c.d.com"}, + "X=bcd X=cd a=1 b=2 c=3 d=4", + []query{ + {"http://b.c.d.com", "b=2 c=3 d=4 X=bcd X=cd"}, + {"http://c.d.com", "c=3 d=4 X=cd"}, + }, + }, + { + "InvalidDomainTest #1.", + "http://foo.bar.com", + []string{ + "a=1; domain=.yo.foo.bar.com", + "b=2; domain=.foo.com", + "c=3; domain=.bar.foo.com", + "d=4; domain=.foo.bar.com.net", + "e=5; domain=ar.com", + "f=6; domain=.", + "g=7; domain=/", + "h=8; domain=http://foo.bar.com", + "i=9; domain=..foo.bar.com", + "j=10; domain=..bar.com", + "k=11; domain=.foo.bar.com?blah", + "l=12; domain=.foo.bar.com/blah", + "m=12; domain=.foo.bar.com:80", + "n=14; domain=.foo.bar.com:", + "o=15; domain=.foo.bar.com#sup", + }, + "", // Jar is empty. + []query{{"http://foo.bar.com", ""}}, + }, + { + "InvalidDomainTest #2.", + "http://foo.com.com", + []string{"a=1; domain=.foo.com.com.com"}, + "", + []query{{"http://foo.bar.com", ""}}, + }, + { + "DomainWithoutLeadingDotTest #1.", + "http://manage.hosted.filefront.com", + []string{"a=1; domain=filefront.com"}, + "a=1", + []query{{"http://www.filefront.com", "a=1"}}, + }, + { + "DomainWithoutLeadingDotTest #2.", + "http://www.google.com", + []string{"a=1; domain=www.google.com"}, + "a=1", + []query{ + {"http://www.google.com", "a=1"}, + {"http://sub.www.google.com", "a=1"}, + {"http://something-else.com", ""}, + }, + }, + { + "CaseInsensitiveDomainTest.", + "http://www.google.com", + []string{ + "a=1; domain=.GOOGLE.COM", + "b=2; domain=.www.gOOgLE.coM"}, + "a=1 b=2", + []query{{"http://www.google.com", "a=1 b=2"}}, + }, + { + "TestIpAddress #1.", + "http://1.2.3.4/foo", + []string{"a=1; path=/"}, + "a=1", + []query{{"http://1.2.3.4/foo", "a=1"}}, + }, + { + "TestIpAddress #2.", + "http://1.2.3.4/foo", + []string{ + "a=1; domain=.1.2.3.4", + "b=2; domain=.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestIpAddress #3.", + "http://1.2.3.4/foo", + []string{"a=1; domain=1.2.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestNonDottedAndTLD #2.", + "http://com./index.html", + []string{"a=1"}, + "a=1", + []query{ + {"http://com./index.html", "a=1"}, + {"http://no-cookies.com./index.html", ""}, + }, + }, + { + "TestNonDottedAndTLD #3.", + "http://a.b", + []string{ + "a=1; domain=.b", + "b=2; domain=b"}, + "", + []query{{"http://bar.foo", ""}}, + }, + { + "TestNonDottedAndTLD #4.", + "http://google.com", + []string{ + "a=1; domain=.com", + "b=2; domain=com"}, + "", + []query{{"http://google.com", ""}}, + }, + { + "TestNonDottedAndTLD #5.", + "http://google.co.uk", + []string{ + "a=1; domain=.co.uk", + "b=2; domain=.uk"}, + "", + []query{ + {"http://google.co.uk", ""}, + {"http://else.co.com", ""}, + {"http://else.uk", ""}, + }, + }, + { + "TestHostEndsWithDot.", + "http://www.google.com", + []string{ + "a=1", + "b=2; domain=.www.google.com."}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "PathTest", + "http://www.google.izzle", + []string{"a=1; path=/wee"}, + "a=1", + []query{ + {"http://www.google.izzle/wee", "a=1"}, + {"http://www.google.izzle/wee/", "a=1"}, + {"http://www.google.izzle/wee/war", "a=1"}, + {"http://www.google.izzle/wee/war/more/more", "a=1"}, + {"http://www.google.izzle/weehee", ""}, + {"http://www.google.izzle/", ""}, + }, + }, +} + +func TestChromiumBasics(t *testing.T) { + for _, test := range chromiumBasicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// chromiumDomainTests contains jarTests which must be executed all on the +// same Jar. +var chromiumDomainTests = [...]jarTest{ + { + "Fill #1.", + "http://www.google.izzle", + []string{"A=B"}, + "A=B", + []query{{"http://www.google.izzle", "A=B"}}, + }, + { + "Fill #2.", + "http://www.google.izzle", + []string{"C=D; domain=.google.izzle"}, + "A=B C=D", + []query{{"http://www.google.izzle", "A=B C=D"}}, + }, + { + "Verify A is a host cookie and not accessible from subdomain.", + "http://unused.nil", + []string{}, + "A=B C=D", + []query{{"http://foo.www.google.izzle", "C=D"}}, + }, + { + "Verify domain cookies are found on proper domain.", + "http://www.google.izzle", + []string{"E=F; domain=.www.google.izzle"}, + "A=B C=D E=F", + []query{{"http://www.google.izzle", "A=B C=D E=F"}}, + }, + { + "Leading dots in domain attributes are optional.", + "http://www.google.izzle", + []string{"G=H; domain=www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #1.", + "http://www.google.izzle", + []string{"K=L; domain=.bar.www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://bar.www.google.izzle", "C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #2.", + "http://unused.nil", + []string{}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, +} + +func TestChromiumDomain(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDomainTests { + test.run(t, jar) + } + +} + +// chromiumDeletionTests must be performed all on the same Jar. +var chromiumDeletionTests = [...]jarTest{ + { + "Create session cookie a1.", + "http://www.google.com", + []string{"a=1"}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "Delete sc a1 via MaxAge.", + "http://www.google.com", + []string{"a=1; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create session cookie b2.", + "http://www.google.com", + []string{"b=2"}, + "b=2", + []query{{"http://www.google.com", "b=2"}}, + }, + { + "Delete sc b2 via Expires.", + "http://www.google.com", + []string{"b=2; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie c3.", + "http://www.google.com", + []string{"c=3; max-age=3600"}, + "c=3", + []query{{"http://www.google.com", "c=3"}}, + }, + { + "Delete pc c3 via MaxAge.", + "http://www.google.com", + []string{"c=3; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie d4.", + "http://www.google.com", + []string{"d=4; max-age=3600"}, + "d=4", + []query{{"http://www.google.com", "d=4"}}, + }, + { + "Delete pc d4 via Expires.", + "http://www.google.com", + []string{"d=4; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, +} + +func TestChromiumDeletion(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDeletionTests { + test.run(t, jar) + } +} + +// domainHandlingTests tests and documents the rules for domain handling. +// Each test must be performed on an empty new Jar. +var domainHandlingTests = [...]jarTest{ + { + "Host cookie", + "http://www.host.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", ""}, + {"http://bar.host.test", ""}, + {"http://foo.www.host.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #1", + "http://www.host.test", + []string{"a=1; domain=host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #2", + "http://www.host.test", + []string{"a=1; domain=.host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #1", + "http://www.bücher.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bücher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bücher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bücher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bücher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #1", + "http://www.bücher.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bücher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bücher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bücher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bücher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on TLD.", + "http://com", + []string{"a=1"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Domain cookie on TLD becomes a host cookie.", + "http://com", + []string{"a=1; domain=com"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Host cookie on public suffix.", + "http://co.uk", + []string{"a=1"}, + "a=1", + []query{ + {"http://co.uk", "a=1"}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, + { + "Domain cookie on public suffix is ignored.", + "http://some.co.uk", + []string{"a=1; domain=co.uk"}, + "", + []query{ + {"http://co.uk", ""}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, +} + +func TestDomainHandling(t *testing.T) { + for _, test := range domainHandlingTests { + jar := newTestJar() + test.run(t, jar) + } +} diff --git a/src/pkg/net/http/cookiejar/punycode.go b/src/pkg/net/http/cookiejar/punycode.go new file mode 100644 index 000000000..ea7ceb5ef --- /dev/null +++ b/src/pkg/net/http/cookiejar/punycode.go @@ -0,0 +1,159 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +// This file implements the Punycode algorithm from RFC 3492. + +import ( + "fmt" + "strings" + "unicode/utf8" +) + +// These parameter values are specified in section 5. +// +// All computation is done with int32s, so that overflow behavior is identical +// regardless of whether int is 32-bit or 64-bit. +const ( + base int32 = 36 + damp int32 = 700 + initialBias int32 = 72 + initialN int32 = 128 + skew int32 = 38 + tmax int32 = 26 + tmin int32 = 1 +) + +// encode encodes a string as specified in section 6.3 and prepends prefix to +// the result. +// +// The "while h < length(input)" line in the specification becomes "for +// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes. +func encode(prefix, s string) (string, error) { + output := make([]byte, len(prefix), len(prefix)+1+2*len(s)) + copy(output, prefix) + delta, n, bias := int32(0), initialN, initialBias + b, remaining := int32(0), int32(0) + for _, r := range s { + if r < 0x80 { + b++ + output = append(output, byte(r)) + } else { + remaining++ + } + } + h := b + if b > 0 { + output = append(output, '-') + } + for remaining != 0 { + m := int32(0x7fffffff) + for _, r := range s { + if m > r && r >= n { + m = r + } + } + delta += (m - n) * (h + 1) + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + n = m + for _, r := range s { + if r < n { + delta++ + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + continue + } + if r > n { + continue + } + q := delta + for k := base; ; k += base { + t := k - bias + if t < tmin { + t = tmin + } else if t > tmax { + t = tmax + } + if q < t { + break + } + output = append(output, encodeDigit(t+(q-t)%(base-t))) + q = (q - t) / (base - t) + } + output = append(output, encodeDigit(q)) + bias = adapt(delta, h+1, h == b) + delta = 0 + h++ + remaining-- + } + delta++ + n++ + } + return string(output), nil +} + +func encodeDigit(digit int32) byte { + switch { + case 0 <= digit && digit < 26: + return byte(digit + 'a') + case 26 <= digit && digit < 36: + return byte(digit + ('0' - 26)) + } + panic("cookiejar: internal error in punycode encoding") +} + +// adapt is the bias adaptation function specified in section 6.1. +func adapt(delta, numPoints int32, firstTime bool) int32 { + if firstTime { + delta /= damp + } else { + delta /= 2 + } + delta += delta / numPoints + k := int32(0) + for delta > ((base-tmin)*tmax)/2 { + delta /= base - tmin + k += base + } + return k + (base-tmin+1)*delta/(delta+skew) +} + +// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and +// friends) and not Punycode (RFC 3492) per se. + +// acePrefix is the ASCII Compatible Encoding prefix. +const acePrefix = "xn--" + +// toASCII converts a domain or domain label to its ASCII form. For example, +// toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and +// toASCII("golang") is "golang". +func toASCII(s string) (string, error) { + if ascii(s) { + return s, nil + } + labels := strings.Split(s, ".") + for i, label := range labels { + if !ascii(label) { + a, err := encode(acePrefix, label) + if err != nil { + return "", err + } + labels[i] = a + } + } + return strings.Join(labels, "."), nil +} + +func ascii(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} diff --git a/src/pkg/net/http/cookiejar/punycode_test.go b/src/pkg/net/http/cookiejar/punycode_test.go new file mode 100644 index 000000000..0301de14e --- /dev/null +++ b/src/pkg/net/http/cookiejar/punycode_test.go @@ -0,0 +1,161 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "testing" +) + +var punycodeTestCases = [...]struct { + s, encoded string +}{ + {"", ""}, + {"-", "--"}, + {"-a", "-a-"}, + {"-a-", "-a--"}, + {"a", "a-"}, + {"a-", "a--"}, + {"a-b", "a-b-"}, + {"books", "books-"}, + {"bücher", "bcher-kva"}, + {"Hello世界", "Hello-ck1hg65u"}, + {"ü", "tda"}, + {"üý", "tdac"}, + + // The test cases below come from RFC 3492 section 7.1 with Errata 3026. + { + // (A) Arabic (Egyptian). + "\u0644\u064A\u0647\u0645\u0627\u0628\u062A\u0643\u0644" + + "\u0645\u0648\u0634\u0639\u0631\u0628\u064A\u061F", + "egbpdaj6bu4bxfgehfvwxn", + }, + { + // (B) Chinese (simplified). + "\u4ED6\u4EEC\u4E3A\u4EC0\u4E48\u4E0D\u8BF4\u4E2D\u6587", + "ihqwcrb4cv8a8dqg056pqjye", + }, + { + // (C) Chinese (traditional). + "\u4ED6\u5011\u7232\u4EC0\u9EBD\u4E0D\u8AAA\u4E2D\u6587", + "ihqwctvzc91f659drss3x8bo0yb", + }, + { + // (D) Czech. + "\u0050\u0072\u006F\u010D\u0070\u0072\u006F\u0073\u0074" + + "\u011B\u006E\u0065\u006D\u006C\u0075\u0076\u00ED\u010D" + + "\u0065\u0073\u006B\u0079", + "Proprostnemluvesky-uyb24dma41a", + }, + { + // (E) Hebrew. + "\u05DC\u05DE\u05D4\u05D4\u05DD\u05E4\u05E9\u05D5\u05D8" + + "\u05DC\u05D0\u05DE\u05D3\u05D1\u05E8\u05D9\u05DD\u05E2" + + "\u05D1\u05E8\u05D9\u05EA", + "4dbcagdahymbxekheh6e0a7fei0b", + }, + { + // (F) Hindi (Devanagari). + "\u092F\u0939\u0932\u094B\u0917\u0939\u093F\u0928\u094D" + + "\u0926\u0940\u0915\u094D\u092F\u094B\u0902\u0928\u0939" + + "\u0940\u0902\u092C\u094B\u0932\u0938\u0915\u0924\u0947" + + "\u0939\u0948\u0902", + "i1baa7eci9glrd9b2ae1bj0hfcgg6iyaf8o0a1dig0cd", + }, + { + // (G) Japanese (kanji and hiragana). + "\u306A\u305C\u307F\u3093\u306A\u65E5\u672C\u8A9E\u3092" + + "\u8A71\u3057\u3066\u304F\u308C\u306A\u3044\u306E\u304B", + "n8jok5ay5dzabd5bym9f0cm5685rrjetr6pdxa", + }, + { + // (H) Korean (Hangul syllables). + "\uC138\uACC4\uC758\uBAA8\uB4E0\uC0AC\uB78C\uB4E4\uC774" + + "\uD55C\uAD6D\uC5B4\uB97C\uC774\uD574\uD55C\uB2E4\uBA74" + + "\uC5BC\uB9C8\uB098\uC88B\uC744\uAE4C", + "989aomsvi5e83db1d2a355cv1e0vak1dwrv93d5xbh15a0dt30a5j" + + "psd879ccm6fea98c", + }, + { + // (I) Russian (Cyrillic). + "\u043F\u043E\u0447\u0435\u043C\u0443\u0436\u0435\u043E" + + "\u043D\u0438\u043D\u0435\u0433\u043E\u0432\u043E\u0440" + + "\u044F\u0442\u043F\u043E\u0440\u0443\u0441\u0441\u043A" + + "\u0438", + "b1abfaaepdrnnbgefbadotcwatmq2g4l", + }, + { + // (J) Spanish. + "\u0050\u006F\u0072\u0071\u0075\u00E9\u006E\u006F\u0070" + + "\u0075\u0065\u0064\u0065\u006E\u0073\u0069\u006D\u0070" + + "\u006C\u0065\u006D\u0065\u006E\u0074\u0065\u0068\u0061" + + "\u0062\u006C\u0061\u0072\u0065\u006E\u0045\u0073\u0070" + + "\u0061\u00F1\u006F\u006C", + "PorqunopuedensimplementehablarenEspaol-fmd56a", + }, + { + // (K) Vietnamese. + "\u0054\u1EA1\u0069\u0073\u0061\u006F\u0068\u1ECD\u006B" + + "\u0068\u00F4\u006E\u0067\u0074\u0068\u1EC3\u0063\u0068" + + "\u1EC9\u006E\u00F3\u0069\u0074\u0069\u1EBF\u006E\u0067" + + "\u0056\u0069\u1EC7\u0074", + "TisaohkhngthchnitingVit-kjcr8268qyxafd2f1b9g", + }, + { + // (L) 3<nen>B<gumi><kinpachi><sensei>. + "\u0033\u5E74\u0042\u7D44\u91D1\u516B\u5148\u751F", + "3B-ww4c5e180e575a65lsy2b", + }, + { + // (M) <amuro><namie>-with-SUPER-MONKEYS. + "\u5B89\u5BA4\u5948\u7F8E\u6075\u002D\u0077\u0069\u0074" + + "\u0068\u002D\u0053\u0055\u0050\u0045\u0052\u002D\u004D" + + "\u004F\u004E\u004B\u0045\u0059\u0053", + "-with-SUPER-MONKEYS-pc58ag80a8qai00g7n9n", + }, + { + // (N) Hello-Another-Way-<sorezore><no><basho>. + "\u0048\u0065\u006C\u006C\u006F\u002D\u0041\u006E\u006F" + + "\u0074\u0068\u0065\u0072\u002D\u0057\u0061\u0079\u002D" + + "\u305D\u308C\u305E\u308C\u306E\u5834\u6240", + "Hello-Another-Way--fc4qua05auwb3674vfr0b", + }, + { + // (O) <hitotsu><yane><no><shita>2. + "\u3072\u3068\u3064\u5C4B\u6839\u306E\u4E0B\u0032", + "2-u9tlzr9756bt3uc0v", + }, + { + // (P) Maji<de>Koi<suru>5<byou><mae> + "\u004D\u0061\u006A\u0069\u3067\u004B\u006F\u0069\u3059" + + "\u308B\u0035\u79D2\u524D", + "MajiKoi5-783gue6qz075azm5e", + }, + { + // (Q) <pafii>de<runba> + "\u30D1\u30D5\u30A3\u30FC\u0064\u0065\u30EB\u30F3\u30D0", + "de-jg4avhby1noc0d", + }, + { + // (R) <sono><supiido><de> + "\u305D\u306E\u30B9\u30D4\u30FC\u30C9\u3067", + "d9juau41awczczp", + }, + { + // (S) -> $1.00 <- + "\u002D\u003E\u0020\u0024\u0031\u002E\u0030\u0030\u0020" + + "\u003C\u002D", + "-> $1.00 <--", + }, +} + +func TestPunycode(t *testing.T) { + for _, tc := range punycodeTestCases { + if got, err := encode("", tc.s); err != nil { + t.Errorf(`encode("", %q): %v`, tc.s, err) + } else if got != tc.encoded { + t.Errorf(`encode("", %q): got %q, want %q`, tc.s, got, tc.encoded) + } + } +} diff --git a/src/pkg/net/http/example_test.go b/src/pkg/net/http/example_test.go index ec814407d..22073eaf7 100644 --- a/src/pkg/net/http/example_test.go +++ b/src/pkg/net/http/example_test.go @@ -43,10 +43,10 @@ func ExampleGet() { log.Fatal(err) } robots, err := ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { log.Fatal(err) } - res.Body.Close() fmt.Printf("%s", robots) } diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go index 13640ca85..a7bca20a0 100644 --- a/src/pkg/net/http/export_test.go +++ b/src/pkg/net/http/export_test.go @@ -7,12 +7,25 @@ package http -import "time" +import ( + "net" + "time" +) + +func NewLoggingConn(baseName string, c net.Conn) net.Conn { + return newLoggingConn(baseName, c) +} + +func (t *Transport) NumPendingRequestsForTesting() int { + t.reqMu.Lock() + defer t.reqMu.Unlock() + return len(t.reqConn) +} func (t *Transport) IdleConnKeysForTesting() (keys []string) { keys = make([]string, 0) - t.lk.Lock() - defer t.lk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return } @@ -23,8 +36,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { } func (t *Transport) IdleConnCountForTesting(cacheKey string) int { - t.lk.Lock() - defer t.lk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return 0 } diff --git a/src/pkg/net/http/filetransport_test.go b/src/pkg/net/http/filetransport_test.go index 039926b53..6f1a537e2 100644 --- a/src/pkg/net/http/filetransport_test.go +++ b/src/pkg/net/http/filetransport_test.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http_test +package http import ( "io/ioutil" - "net/http" "os" "path/filepath" "testing" @@ -32,9 +31,9 @@ func TestFileTransport(t *testing.T) { defer os.Remove(dname) defer os.Remove(fname) - tr := &http.Transport{} - tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(dname))) - c := &http.Client{Transport: tr} + tr := &Transport{} + tr.RegisterProtocol("file", NewFileTransport(Dir(dname))) + c := &Client{Transport: tr} fooURLs := []string{"file:///foo.txt", "file://../foo.txt"} for _, urlstr := range fooURLs { @@ -62,4 +61,5 @@ func TestFileTransport(t *testing.T) { if res.StatusCode != 404 { t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode) } + res.Body.Close() } diff --git a/src/pkg/net/http/fs.go b/src/pkg/net/http/fs.go index f35dd32c3..b6bea0dfa 100644 --- a/src/pkg/net/http/fs.go +++ b/src/pkg/net/http/fs.go @@ -11,6 +11,8 @@ import ( "fmt" "io" "mime" + "mime/multipart" + "net/textproto" "os" "path" "path/filepath" @@ -26,7 +28,8 @@ import ( type Dir string func (d Dir) Open(name string) (File, error) { - if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 { + if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 || + strings.Contains(name, "\x00") { return nil, errors.New("http: invalid character in file path") } dir := string(d) @@ -97,6 +100,9 @@ func dirList(w ResponseWriter, f File) { // The content's Seek method must work: ServeContent uses // a seek to the end of the content to determine its size. // +// If the caller has set w's ETag header, ServeContent uses it to +// handle requests using If-Range and If-None-Match. +// // Note that *os.File implements the io.ReadSeeker interface. func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { size, err := content.Seek(0, os.SEEK_END) @@ -119,12 +125,17 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, if checkLastModified(w, r, modtime) { return } + rangeReq, done := checkETag(w, r) + if done { + return + } code := StatusOK // If Content-Type isn't set, use the file's extension to find it. - if w.Header().Get("Content-Type") == "" { - ctype := mime.TypeByExtension(filepath.Ext(name)) + ctype := w.Header().Get("Content-Type") + if ctype == "" { + ctype = mime.TypeByExtension(filepath.Ext(name)) if ctype == "" { // read a chunk to decide between utf-8 text and binary var buf [1024]byte @@ -141,18 +152,34 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } // handle Content-Range header. - // TODO(adg): handle multiple ranges sendSize := size + var sendContent io.Reader = content if size >= 0 { - ranges, err := parseRange(r.Header.Get("Range"), size) - if err == nil && len(ranges) > 1 { - err = errors.New("multiple ranges not supported") - } + ranges, err := parseRange(rangeReq, size) if err != nil { Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) return } - if len(ranges) == 1 { + if sumRangesSize(ranges) >= size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 2616, Section 14.16: + // "When an HTTP message includes the content of a single + // range (for example, a response to a request for a + // single range, or to a request for a set of ranges + // that overlap without any holes), this content is + // transmitted with a Content-Range header, and a + // Content-Length header showing the number of bytes + // actually transferred. + // ... + // A response to a request for a single range MUST NOT + // be sent using the multipart/byteranges media type." ra := ranges[0] if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) @@ -160,7 +187,41 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } sendSize = ra.length code = StatusPartialContent - w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, size)) + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + for _, ra := range ranges { + if ra.start > size { + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + } + sendSize = rangesMIMESize(ranges, ctype, size) + code = StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, content, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() } w.Header().Set("Accept-Ranges", "bytes") @@ -172,11 +233,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, w.WriteHeader(code) if r.Method != "HEAD" { - if sendSize == -1 { - io.Copy(w, content) - } else { - io.CopyN(w, content, sendSize) - } + io.CopyN(w, sendContent, sendSize) } } @@ -190,6 +247,9 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool { // The Date-Modified header truncates sub-second precision, so // use mtime < t+1s instead of mtime <= t to check for unmodified. if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") w.WriteHeader(StatusNotModified) return true } @@ -197,6 +257,58 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool { return false } +// checkETag implements If-None-Match and If-Range checks. +// The ETag must have been previously set in the ResponseWriter's headers. +// +// The return value is the effective request "Range" header to use and +// whether this request is now considered done. +func checkETag(w ResponseWriter, r *Request) (rangeReq string, done bool) { + etag := w.Header().get("Etag") + rangeReq = r.Header.get("Range") + + // Invalidate the range request if the entity doesn't match the one + // the client was expecting. + // "If-Range: version" means "ignore the Range: header unless version matches the + // current file." + // We only support ETag versions. + // The caller must have set the ETag on the response already. + if ir := r.Header.get("If-Range"); ir != "" && ir != etag { + // TODO(bradfitz): handle If-Range requests with Last-Modified + // times instead of ETags? I'd rather not, at least for + // now. That seems like a bug/compromise in the RFC 2616, and + // I've never heard of anybody caring about that (yet). + rangeReq = "" + } + + if inm := r.Header.get("If-None-Match"); inm != "" { + // Must know ETag. + if etag == "" { + return rangeReq, false + } + + // TODO(bradfitz): non-GET/HEAD requests require more work: + // sending a different status code on matches, and + // also can't use weak cache validators (those with a "W/ + // prefix). But most users of ServeContent will be using + // it on GET or HEAD, so only support those for now. + if r.Method != "GET" && r.Method != "HEAD" { + return rangeReq, false + } + + // TODO(bradfitz): deal with comma-separated or multiple-valued + // list of If-None-match values. For now just handle the common + // case of a single item. + if inm == etag || inm == "*" { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + w.WriteHeader(StatusNotModified) + return "", true + } + } + return rangeReq, false +} + // name is '/'-separated, not filepath.Separator. func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) { const indexPage = "/index.html" @@ -243,9 +355,6 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec // use contents of index.html for directory, if present if d.IsDir() { - if checkLastModified(w, r, d.ModTime()) { - return - } index := name + indexPage ff, err := fs.Open(index) if err == nil { @@ -259,11 +368,16 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec } } + // Still a directory? (we didn't find an index.html file) if d.IsDir() { + if checkLastModified(w, r, d.ModTime()) { + return + } dirList(w, f) return } + // serverContent will check modification time serveContent(w, r, d.Name(), d.ModTime(), d.Size(), f) } @@ -312,6 +426,17 @@ type httpRange struct { start, length int64 } +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + // parseRange parses a Range header string as per RFC 2616. func parseRange(s string, size int64) ([]httpRange, error) { if s == "" { @@ -323,11 +448,15 @@ func parseRange(s string, size int64) ([]httpRange, error) { } var ranges []httpRange for _, ra := range strings.Split(s[len(b):], ",") { + ra = strings.TrimSpace(ra) + if ra == "" { + continue + } i := strings.Index(ra, "-") if i < 0 { return nil, errors.New("invalid range") } - start, end := ra[:i], ra[i+1:] + start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:]) var r httpRange if start == "" { // If no start is specified, end specifies the @@ -365,3 +494,32 @@ func parseRange(s string, size int64) ([]httpRange, error) { } return ranges, nil } + +// countingWriter counts how many bytes have been written to it. +type countingWriter int64 + +func (w *countingWriter) Write(p []byte) (n int, err error) { + *w += countingWriter(len(p)) + return len(p), nil +} + +// rangesMIMESize returns the nunber of bytes it takes to encode the +// provided ranges as a multipart response. +func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { + var w countingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + encSize += ra.length + } + mw.Close() + encSize += int64(w) + return +} + +func sumRangesSize(ranges []httpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +} diff --git a/src/pkg/net/http/fs_test.go b/src/pkg/net/http/fs_test.go index 5aa93ce58..0dd6d0df9 100644 --- a/src/pkg/net/http/fs_test.go +++ b/src/pkg/net/http/fs_test.go @@ -10,12 +10,15 @@ import ( "fmt" "io" "io/ioutil" + "mime" + "mime/multipart" "net" . "net/http" "net/http/httptest" "net/url" "os" "os/exec" + "path" "path/filepath" "regexp" "runtime" @@ -25,24 +28,33 @@ import ( ) const ( - testFile = "testdata/file" - testFileLength = 11 + testFile = "testdata/file" + testFileLen = 11 ) +type wantRange struct { + start, end int64 // range [start,end) +} + var ServeFileRangeTests = []struct { - start, end int - r string - code int + r string + code int + ranges []wantRange }{ - {0, testFileLength, "", StatusOK}, - {0, 5, "0-4", StatusPartialContent}, - {2, testFileLength, "2-", StatusPartialContent}, - {testFileLength - 5, testFileLength, "-5", StatusPartialContent}, - {3, 8, "3-7", StatusPartialContent}, - {0, 0, "20-", StatusRequestedRangeNotSatisfiable}, + {r: "", code: StatusOK}, + {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}}, + {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}}, + {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}}, + {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}}, + {r: "bytes=20-", code: StatusRequestedRangeNotSatisfiable}, + {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}}, + {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}}, + {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}}, + {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request } func TestServeFile(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") })) @@ -65,33 +77,86 @@ func TestServeFile(t *testing.T) { // straight GET _, body := getBody(t, "straight get", req) - if !equal(body, file) { + if !bytes.Equal(body, file) { t.Fatalf("body mismatch: got %q, want %q", body, file) } // Range tests - for i, rt := range ServeFileRangeTests { - req.Header.Set("Range", "bytes="+rt.r) - if rt.r == "" { - req.Header["Range"] = nil +Cases: + for _, rt := range ServeFileRangeTests { + if rt.r != "" { + req.Header.Set("Range", rt.r) } - r, body := getBody(t, fmt.Sprintf("test %d", i), req) - if r.StatusCode != rt.code { - t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, r.StatusCode, rt.code) + resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req) + if resp.StatusCode != rt.code { + t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code) } if rt.code == StatusRequestedRangeNotSatisfiable { continue } - h := fmt.Sprintf("bytes %d-%d/%d", rt.start, rt.end-1, testFileLength) - if rt.r == "" { - h = "" + wantContentRange := "" + if len(rt.ranges) == 1 { + rng := rt.ranges[0] + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + } + cr := resp.Header.Get("Content-Range") + if cr != wantContentRange { + t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange) } - cr := r.Header.Get("Content-Range") - if cr != h { - t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h) + ct := resp.Header.Get("Content-Type") + if len(rt.ranges) == 1 { + rng := rt.ranges[0] + wantBody := file[rng.start:rng.end] + if !bytes.Equal(body, wantBody) { + t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) + } + if strings.HasPrefix(ct, "multipart/byteranges") { + t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r, ct) + } } - if !equal(body, file[rt.start:rt.end]) { - t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end]) + if len(rt.ranges) > 1 { + typ, params, err := mime.ParseMediaType(ct) + if err != nil { + t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err) + continue + } + if typ != "multipart/byteranges" { + t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r, typ) + continue + } + if params["boundary"] == "" { + t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct) + continue + } + if g, w := resp.ContentLength, int64(len(body)); g != w { + t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w) + continue + } + mr := multipart.NewReader(bytes.NewReader(body), params["boundary"]) + for ri, rng := range rt.ranges { + part, err := mr.NextPart() + if err != nil { + t.Errorf("range=%q, reading part index %d: %v", rt.r, ri, err) + continue Cases + } + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w { + t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w) + } + body, err := ioutil.ReadAll(part) + if err != nil { + t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err) + continue Cases + } + wantBody := file[rng.start:rng.end] + if !bytes.Equal(body, wantBody) { + t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) + } + } + _, err = mr.NextPart() + if err != io.EOF { + t.Errorf("range=%q; expected final error io.EOF; got %v", rt.r, err) + } } } } @@ -105,6 +170,7 @@ var fsRedirectTestData = []struct { } func TestFSRedirect(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) defer ts.Close() @@ -129,6 +195,7 @@ func (fs *testFileSystem) Open(name string) (File, error) { } func TestFileServerCleans(t *testing.T) { + defer checkLeakedTransports(t) ch := make(chan string, 1) fs := FileServer(&testFileSystem{func(name string) (File, error) { ch <- name @@ -160,6 +227,7 @@ func mustRemoveAll(dir string) { } func TestFileServerImplicitLeadingSlash(t *testing.T) { + defer checkLeakedTransports(t) tempDir, err := ioutil.TempDir("", "") if err != nil { t.Fatalf("TempDir: %v", err) @@ -193,8 +261,7 @@ func TestFileServerImplicitLeadingSlash(t *testing.T) { func TestDirJoin(t *testing.T) { wfi, err := os.Stat("/etc/hosts") if err != nil { - t.Logf("skipping test; no /etc/hosts file") - return + t.Skip("skipping test; no /etc/hosts file") } test := func(d Dir, name string) { f, err := d.Open(name) @@ -239,6 +306,7 @@ func TestEmptyDirOpenCWD(t *testing.T) { } func TestServeFileContentType(t *testing.T) { + defer checkLeakedTransports(t) const ctype = "icecream/chocolate" ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.FormValue("override") == "1" { @@ -255,12 +323,14 @@ func TestServeFileContentType(t *testing.T) { if h := resp.Header.Get("Content-Type"); h != want { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) } + resp.Body.Close() } get("0", "text/plain; charset=utf-8") get("1", ctype) } func TestServeFileMimeType(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/style.css") })) @@ -269,6 +339,7 @@ func TestServeFileMimeType(t *testing.T) { if err != nil { t.Fatal(err) } + resp.Body.Close() want := "text/css; charset=utf-8" if h := resp.Header.Get("Content-Type"); h != want { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) @@ -276,6 +347,7 @@ func TestServeFileMimeType(t *testing.T) { } func TestServeFileFromCWD(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "fs_test.go") })) @@ -284,12 +356,14 @@ func TestServeFileFromCWD(t *testing.T) { if err != nil { t.Fatal(err) } + r.Body.Close() if r.StatusCode != 200 { t.Fatalf("expected 200 OK, got %s", r.Status) } } func TestServeFileWithContentEncoding(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "foo") ServeFile(w, r, "testdata/file") @@ -299,12 +373,14 @@ func TestServeFileWithContentEncoding(t *testing.T) { if err != nil { t.Fatal(err) } + resp.Body.Close() if g, e := resp.ContentLength, int64(-1); g != e { t.Errorf("Content-Length mismatch: got %d, want %d", g, e) } } func TestServeIndexHtml(t *testing.T) { + defer checkLeakedTransports(t) const want = "index.html says hello\n" ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() @@ -325,64 +401,289 @@ func TestServeIndexHtml(t *testing.T) { } } -func TestServeContent(t *testing.T) { - type req struct { - name string - modtime time.Time - content io.ReadSeeker - } - ch := make(chan req, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - p := <-ch - ServeContent(w, r, p.name, p.modtime, p.content) - })) +func TestFileServerZeroByte(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() - css, err := os.Open("testdata/style.css") + res, err := Get(ts.URL + "/..\x00") if err != nil { t.Fatal(err) } - defer css.Close() + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if res.StatusCode == 200 { + t.Errorf("got status 200; want an error. Body is:\n%s", string(b)) + } +} + +type fakeFileInfo struct { + dir bool + basename string + modtime time.Time + ents []*fakeFileInfo + contents string +} + +func (f *fakeFileInfo) Name() string { return f.basename } +func (f *fakeFileInfo) Sys() interface{} { return nil } +func (f *fakeFileInfo) ModTime() time.Time { return f.modtime } +func (f *fakeFileInfo) IsDir() bool { return f.dir } +func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) } +func (f *fakeFileInfo) Mode() os.FileMode { + if f.dir { + return 0755 | os.ModeDir + } + return 0644 +} + +type fakeFile struct { + io.ReadSeeker + fi *fakeFileInfo + path string // as opened +} + +func (f *fakeFile) Close() error { return nil } +func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil } +func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { + if !f.fi.dir { + return nil, os.ErrInvalid + } + var fis []os.FileInfo + for _, fi := range f.fi.ents { + fis = append(fis, fi) + } + return fis, nil +} + +type fakeFS map[string]*fakeFileInfo + +func (fs fakeFS) Open(name string) (File, error) { + name = path.Clean(name) + f, ok := fs[name] + if !ok { + println("fake filesystem didn't find file", name) + return nil, os.ErrNotExist + } + return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil +} + +func TestDirectoryIfNotModified(t *testing.T) { + defer checkLeakedTransports(t) + const indexContents = "I am a fake index.html file" + fileMod := time.Unix(1000000000, 0).UTC() + fileModStr := fileMod.Format(TimeFormat) + dirMod := time.Unix(123, 0).UTC() + indexFile := &fakeFileInfo{ + basename: "index.html", + modtime: fileMod, + contents: indexContents, + } + fs := fakeFS{ + "/": &fakeFileInfo{ + dir: true, + modtime: dirMod, + ents: []*fakeFileInfo{indexFile}, + }, + "/index.html": indexFile, + } + + ts := httptest.NewServer(FileServer(fs)) + defer ts.Close() - ch <- req{"style.css", time.Time{}, css} res, err := Get(ts.URL) if err != nil { t.Fatal(err) } - if g, e := res.Header.Get("Content-Type"), "text/css; charset=utf-8"; g != e { - t.Errorf("style.css: content type = %q, want %q", g, e) + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) } - if g := res.Header.Get("Last-Modified"); g != "" { - t.Errorf("want empty Last-Modified; got %q", g) + if string(b) != indexContents { + t.Fatalf("Got body %q; want %q", b, indexContents) } + res.Body.Close() + + lastMod := res.Header.Get("Last-Modified") + if lastMod != fileModStr { + t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr) + } + + req, _ := NewRequest("GET", ts.URL, nil) + req.Header.Set("If-Modified-Since", lastMod) - fi, err := css.Stat() + res, err = DefaultClient.Do(req) if err != nil { t.Fatal(err) } - ch <- req{"style.html", fi.ModTime(), css} - res, err = Get(ts.URL) + if res.StatusCode != 304 { + t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode) + } + res.Body.Close() + + // Advance the index.html file's modtime, but not the directory's. + indexFile.modtime = indexFile.modtime.Add(1 * time.Hour) + + res, err = DefaultClient.Do(req) if err != nil { t.Fatal(err) } - if g, e := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != e { - t.Errorf("style.html: content type = %q, want %q", g, e) + if res.StatusCode != 200 { + t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res) } - if g := res.Header.Get("Last-Modified"); g == "" { - t.Errorf("want non-empty last-modified") + res.Body.Close() +} + +func mustStat(t *testing.T, fileName string) os.FileInfo { + fi, err := os.Stat(fileName) + if err != nil { + t.Fatal(err) + } + return fi +} + +func TestServeContent(t *testing.T) { + defer checkLeakedTransports(t) + type serveParam struct { + name string + modtime time.Time + content io.ReadSeeker + contentType string + etag string + } + servec := make(chan serveParam, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + p := <-servec + if p.etag != "" { + w.Header().Set("ETag", p.etag) + } + if p.contentType != "" { + w.Header().Set("Content-Type", p.contentType) + } + ServeContent(w, r, p.name, p.modtime, p.content) + })) + defer ts.Close() + + type testCase struct { + file string + modtime time.Time + serveETag string // optional + serveContentType string // optional + reqHeader map[string]string + wantLastMod string + wantContentType string + wantStatus int + } + htmlModTime := mustStat(t, "testdata/index.html").ModTime() + tests := map[string]testCase{ + "no_last_modified": { + file: "testdata/style.css", + wantContentType: "text/css; charset=utf-8", + wantStatus: 200, + }, + "with_last_modified": { + file: "testdata/index.html", + wantContentType: "text/html; charset=utf-8", + modtime: htmlModTime, + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + wantStatus: 200, + }, + "not_modified_modtime": { + file: "testdata/style.css", + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 304, + }, + "not_modified_modtime_with_contenttype": { + file: "testdata/style.css", + serveContentType: "text/css", // explicit content type + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 304, + }, + "not_modified_etag": { + file: "testdata/style.css", + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `"foo"`, + }, + wantStatus: 304, + }, + "range_good": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + }, + // An If-Range resource for entity "A", but entity "B" is now current. + // The Range request should be ignored. + "range_no_match": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `"B"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + } + for testName, tt := range tests { + f, err := os.Open(tt.file) + if err != nil { + t.Fatalf("test %q: %v", testName, err) + } + defer f.Close() + + servec <- serveParam{ + name: filepath.Base(tt.file), + content: f, + modtime: tt.modtime, + etag: tt.serveETag, + contentType: tt.serveContentType, + } + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + for k, v := range tt.reqHeader { + req.Header.Set(k, v) + } + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + if res.StatusCode != tt.wantStatus { + t.Errorf("test %q: status = %d; want %d", testName, res.StatusCode, tt.wantStatus) + } + if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e { + t.Errorf("test %q: content-type = %q, want %q", testName, g, e) + } + if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e { + t.Errorf("test %q: last-modified = %q, want %q", testName, g, e) + } } } // verifies that sendfile is being used on Linux func TestLinuxSendfile(t *testing.T) { + defer checkLeakedTransports(t) if runtime.GOOS != "linux" { - t.Logf("skipping; linux-only test") - return + t.Skip("skipping; linux-only test") } - _, err := exec.LookPath("strace") - if err != nil { - t.Logf("skipping; strace not found in path") - return + if _, err := exec.LookPath("strace"); err != nil { + t.Skip("skipping; strace not found in path") } ln, err := net.Listen("tcp", "127.0.0.1:0") @@ -401,10 +702,8 @@ func TestLinuxSendfile(t *testing.T) { child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...) child.Stdout = &buf child.Stderr = &buf - err = child.Start() - if err != nil { - t.Logf("skipping; failed to start straced child: %v", err) - return + if err := child.Start(); err != nil { + t.Skipf("skipping; failed to start straced child: %v", err) } res, err := Get(fmt.Sprintf("http://%s/", ln.Addr())) @@ -464,15 +763,3 @@ func TestLinuxSendfileChild(*testing.T) { panic(err) } } - -func equal(a, b []byte) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} diff --git a/src/pkg/net/http/header.go b/src/pkg/net/http/header.go index b107c312d..f479b7b4e 100644 --- a/src/pkg/net/http/header.go +++ b/src/pkg/net/http/header.go @@ -5,11 +5,11 @@ package http import ( - "fmt" "io" "net/textproto" "sort" "strings" + "time" ) // A Header represents the key-value pairs in an HTTP header. @@ -36,6 +36,14 @@ func (h Header) Get(key string) string { return textproto.MIMEHeader(h).Get(key) } +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} + // Del deletes the values associated with key. func (h Header) Del(key string) { textproto.MIMEHeader(h).Del(key) @@ -46,24 +54,87 @@ func (h Header) Write(w io.Writer) error { return h.WriteSubset(w, nil) } +func (h Header) clone() Header { + h2 := make(Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +var timeFormats = []string{ + TimeFormat, + time.RFC850, + time.ANSIC, +} + +// ParseTime parses a time header (such as the Date: header), +// trying each of the three formats allowed by HTTP/1.1: +// TimeFormat, time.RFC850, and time.ANSIC. +func ParseTime(text string) (t time.Time, err error) { + for _, layout := range timeFormats { + t, err = time.Parse(layout, text) + if err == nil { + return + } + } + return +} + var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") +type writeStringer interface { + WriteString(string) (int, error) +} + +// stringWriter implements WriteString on a Writer. +type stringWriter struct { + w io.Writer +} + +func (w stringWriter) WriteString(s string) (n int, err error) { + return w.w.Write([]byte(s)) +} + +type keyValues struct { + key string + values []string +} + +type byKey []keyValues + +func (s byKey) Len() int { return len(s) } +func (s byKey) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s byKey) Less(i, j int) bool { return s[i].key < s[j].key } + +func (h Header) sortedKeyValues(exclude map[string]bool) []keyValues { + kvs := make([]keyValues, 0, len(h)) + for k, vv := range h { + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + } + sort.Sort(byKey(kvs)) + return kvs +} + // WriteSubset writes a header in wire format. // If exclude is not nil, keys where exclude[key] == true are not written. func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { - keys := make([]string, 0, len(h)) - for k := range h { - if exclude == nil || !exclude[k] { - keys = append(keys, k) - } + ws, ok := w.(writeStringer) + if !ok { + ws = stringWriter{w} } - sort.Strings(keys) - for _, k := range keys { - for _, v := range h[k] { + for _, kv := range h.sortedKeyValues(exclude) { + for _, v := range kv.values { v = headerNewlineToSpace.Replace(v) - v = strings.TrimSpace(v) - if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { - return err + v = textproto.TrimString(v) + for _, s := range []string{kv.key, ": ", v, "\r\n"} { + if _, err := ws.WriteString(s); err != nil { + return err + } } } } @@ -76,3 +147,43 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { // the rest are converted to lowercase. For example, the // canonical key for "accept-encoding" is "Accept-Encoding". func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + +// hasToken returns whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if strings.EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} diff --git a/src/pkg/net/http/header_test.go b/src/pkg/net/http/header_test.go index ccdee8a97..2313b5549 100644 --- a/src/pkg/net/http/header_test.go +++ b/src/pkg/net/http/header_test.go @@ -7,6 +7,7 @@ package http import ( "bytes" "testing" + "time" ) var headerWriteTests = []struct { @@ -67,6 +68,24 @@ var headerWriteTests = []struct { nil, "Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n", }, + // Tests header sorting when over the insertion sort threshold side: + { + Header{ + "k1": {"1a", "1b"}, + "k2": {"2a", "2b"}, + "k3": {"3a", "3b"}, + "k4": {"4a", "4b"}, + "k5": {"5a", "5b"}, + "k6": {"6a", "6b"}, + "k7": {"7a", "7b"}, + "k8": {"8a", "8b"}, + "k9": {"9a", "9b"}, + }, + map[string]bool{"k5": true}, + "k1: 1a\r\nk1: 1b\r\nk2: 2a\r\nk2: 2b\r\nk3: 3a\r\nk3: 3b\r\n" + + "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" + + "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n", + }, } func TestHeaderWrite(t *testing.T) { @@ -79,3 +98,107 @@ func TestHeaderWrite(t *testing.T) { buf.Reset() } } + +var parseTimeTests = []struct { + h Header + err bool +}{ + {Header{"Date": {""}}, true}, + {Header{"Date": {"invalid"}}, true}, + {Header{"Date": {"1994-11-06T08:49:37Z00:00"}}, true}, + {Header{"Date": {"Sun, 06 Nov 1994 08:49:37 GMT"}}, false}, + {Header{"Date": {"Sunday, 06-Nov-94 08:49:37 GMT"}}, false}, + {Header{"Date": {"Sun Nov 6 08:49:37 1994"}}, false}, +} + +func TestParseTime(t *testing.T) { + expect := time.Date(1994, 11, 6, 8, 49, 37, 0, time.UTC) + for i, test := range parseTimeTests { + d, err := ParseTime(test.h.Get("Date")) + if err != nil { + if !test.err { + t.Errorf("#%d:\n got err: %v", i, err) + } + continue + } + if test.err { + t.Errorf("#%d:\n should err", i) + continue + } + if !expect.Equal(d) { + t.Errorf("#%d:\n got: %v\nwant: %v", i, d, expect) + } + } +} + +type hasTokenTest struct { + header string + token string + want bool +} + +var hasTokenTests = []hasTokenTest{ + {"", "", false}, + {"", "foo", false}, + {"foo", "foo", true}, + {"foo ", "foo", true}, + {" foo", "foo", true}, + {" foo ", "foo", true}, + {"foo,bar", "foo", true}, + {"bar,foo", "foo", true}, + {"bar, foo", "foo", true}, + {"bar,foo, baz", "foo", true}, + {"bar, foo,baz", "foo", true}, + {"bar,foo, baz", "foo", true}, + {"bar, foo, baz", "foo", true}, + {"FOO", "foo", true}, + {"FOO ", "foo", true}, + {" FOO", "foo", true}, + {" FOO ", "foo", true}, + {"FOO,BAR", "foo", true}, + {"BAR,FOO", "foo", true}, + {"BAR, FOO", "foo", true}, + {"BAR,FOO, baz", "foo", true}, + {"BAR, FOO,BAZ", "foo", true}, + {"BAR,FOO, BAZ", "foo", true}, + {"BAR, FOO, BAZ", "foo", true}, + {"foobar", "foo", false}, + {"barfoo ", "foo", false}, +} + +func TestHasToken(t *testing.T) { + for _, tt := range hasTokenTests { + if hasToken(tt.header, tt.token) != tt.want { + t.Errorf("hasToken(%q, %q) = %v; want %v", tt.header, tt.token, !tt.want, tt.want) + } + } +} + +var testHeader = Header{ + "Content-Length": {"123"}, + "Content-Type": {"text/plain"}, + "Date": {"some date at some time Z"}, + "Server": {"Go http package"}, +} + +var buf bytes.Buffer + +func BenchmarkHeaderWriteSubset(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + } +} + +func TestHeaderWriteSubsetMallocs(t *testing.T) { + n := testing.AllocsPerRun(100, func() { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + }) + if n > 1 { + // TODO(bradfitz,rsc): once we can sort without allocating, + // make this an error. See http://golang.org/issue/3761 + // t.Errorf("got %v allocs, want <= %v", n, 1) + } +} diff --git a/src/pkg/net/http/httptest/example_test.go b/src/pkg/net/http/httptest/example_test.go new file mode 100644 index 000000000..239470d97 --- /dev/null +++ b/src/pkg/net/http/httptest/example_test.go @@ -0,0 +1,50 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest_test + +import ( + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" +) + +func ExampleRecorder() { + handler := func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "something failed", http.StatusInternalServerError) + } + + req, err := http.NewRequest("GET", "http://example.com/foo", nil) + if err != nil { + log.Fatal(err) + } + + w := httptest.NewRecorder() + handler(w, req) + + fmt.Printf("%d - %s", w.Code, w.Body.String()) + // Output: 500 - something failed +} + +func ExampleServer() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + res, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", greeting) + // Output: Hello, client +} diff --git a/src/pkg/net/http/httptest/recorder.go b/src/pkg/net/http/httptest/recorder.go index 9aa0d510b..5451f5423 100644 --- a/src/pkg/net/http/httptest/recorder.go +++ b/src/pkg/net/http/httptest/recorder.go @@ -17,6 +17,8 @@ type ResponseRecorder struct { HeaderMap http.Header // the HTTP response headers Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to Flushed bool + + wroteHeader bool } // NewRecorder returns an initialized ResponseRecorder. @@ -24,6 +26,7 @@ func NewRecorder() *ResponseRecorder { return &ResponseRecorder{ HeaderMap: make(http.Header), Body: new(bytes.Buffer), + Code: 200, } } @@ -33,26 +36,37 @@ const DefaultRemoteAddr = "1.2.3.4" // Header returns the response headers. func (rw *ResponseRecorder) Header() http.Header { - return rw.HeaderMap + m := rw.HeaderMap + if m == nil { + m = make(http.Header) + rw.HeaderMap = m + } + return m } // Write always succeeds and writes to rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + if !rw.wroteHeader { + rw.WriteHeader(200) + } if rw.Body != nil { rw.Body.Write(buf) } - if rw.Code == 0 { - rw.Code = http.StatusOK - } return len(buf), nil } // WriteHeader sets rw.Code. func (rw *ResponseRecorder) WriteHeader(code int) { - rw.Code = code + if !rw.wroteHeader { + rw.Code = code + } + rw.wroteHeader = true } // Flush sets rw.Flushed to true. func (rw *ResponseRecorder) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(200) + } rw.Flushed = true } diff --git a/src/pkg/net/http/httptest/recorder_test.go b/src/pkg/net/http/httptest/recorder_test.go new file mode 100644 index 000000000..2b563260c --- /dev/null +++ b/src/pkg/net/http/httptest/recorder_test.go @@ -0,0 +1,90 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "fmt" + "net/http" + "testing" +) + +func TestRecorder(t *testing.T) { + type checkFunc func(*ResponseRecorder) error + check := func(fns ...checkFunc) []checkFunc { return fns } + + hasStatus := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Code != wantCode { + return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) + } + return nil + } + } + hasContents := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Body.String() != want { + return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) + } + return nil + } + } + hasFlush := func(want bool) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Flushed != want { + return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) + } + return nil + } + } + + tests := []struct { + name string + h func(w http.ResponseWriter, r *http.Request) + checks []checkFunc + }{ + { + "200 default", + func(w http.ResponseWriter, r *http.Request) {}, + check(hasStatus(200), hasContents("")), + }, + { + "first code only", + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + w.WriteHeader(202) + w.Write([]byte("hi")) + }, + check(hasStatus(201), hasContents("hi")), + }, + { + "write sends 200", + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi first")) + w.WriteHeader(201) + w.WriteHeader(202) + }, + check(hasStatus(200), hasContents("hi first"), hasFlush(false)), + }, + { + "flush", + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() // also sends a 200 + w.WriteHeader(201) + }, + check(hasStatus(200), hasFlush(true)), + }, + } + r, _ := http.NewRequest("GET", "http://foo.com/", nil) + for _, tt := range tests { + h := http.HandlerFunc(tt.h) + rec := NewRecorder() + h.ServeHTTP(rec, r) + for _, check := range tt.checks { + if err := check(rec); err != nil { + t.Errorf("%s: %v", tt.name, err) + } + } + } +} diff --git a/src/pkg/net/http/httptest/server.go b/src/pkg/net/http/httptest/server.go index 57cf0c941..7f265552f 100644 --- a/src/pkg/net/http/httptest/server.go +++ b/src/pkg/net/http/httptest/server.go @@ -21,7 +21,11 @@ import ( type Server struct { URL string // base URL of form http://ipaddr:port with no trailing slash Listener net.Listener - TLS *tls.Config // nil if not using using TLS + + // TLS is the optional TLS configuration, populated with a new config + // after TLS is started. If set on an unstarted server before StartTLS + // is called, existing fields are copied into the new config. + TLS *tls.Config // Config may be changed after calling NewUnstartedServer and // before Start or StartTLS. @@ -36,13 +40,16 @@ type Server struct { // accepted. type historyListener struct { net.Listener - history []net.Conn + sync.Mutex // protects history + history []net.Conn } func (hs *historyListener) Accept() (c net.Conn, err error) { c, err = hs.Listener.Accept() if err == nil { + hs.Lock() hs.history = append(hs.history, c) + hs.Unlock() } return } @@ -96,7 +103,7 @@ func (s *Server) Start() { if s.URL != "" { panic("Server already started") } - s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)} + s.Listener = &historyListener{Listener: s.Listener} s.URL = "http://" + s.Listener.Addr().String() s.wrapHandler() go s.Config.Serve(s.Listener) @@ -116,13 +123,20 @@ func (s *Server) StartTLS() { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) } - s.TLS = &tls.Config{ - NextProtos: []string{"http/1.1"}, - Certificates: []tls.Certificate{cert}, + existingConfig := s.TLS + s.TLS = new(tls.Config) + if existingConfig != nil { + *s.TLS = *existingConfig + } + if s.TLS.NextProtos == nil { + s.TLS.NextProtos = []string{"http/1.1"} + } + if len(s.TLS.Certificates) == 0 { + s.TLS.Certificates = []tls.Certificate{cert} } tlsListener := tls.NewListener(s.Listener, s.TLS) - s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)} + s.Listener = &historyListener{Listener: tlsListener} s.URL = "https://" + s.Listener.Addr().String() s.wrapHandler() go s.Config.Serve(s.Listener) @@ -152,6 +166,10 @@ func NewTLSServer(handler http.Handler) *Server { func (s *Server) Close() { s.Listener.Close() s.wg.Wait() + s.CloseClientConnections() + if t, ok := http.DefaultTransport.(*http.Transport); ok { + t.CloseIdleConnections() + } } // CloseClientConnections closes any currently open HTTP connections @@ -161,9 +179,11 @@ func (s *Server) CloseClientConnections() { if !ok { return } + hl.Lock() for _, conn := range hl.history { conn.Close() } + hl.Unlock() } // waitGroupHandler wraps a handler, incrementing and decrementing a @@ -180,28 +200,29 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.h.ServeHTTP(w, r) } -// localhostCert is a PEM-encoded TLS cert with SAN DNS names +// localhostCert is a PEM-encoded TLS cert with SAN IPs // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end // of ASN.1 time). +// generated from src/pkg/crypto/tls: +// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h var localhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX -DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7 -qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL -8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw -DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG -CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7 -+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM= ------END CERTIFICATE----- -`) +MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD +bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj +bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa +IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA +AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud +EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA +AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk +Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA== +-----END CERTIFICATE-----`) // localhostKey is the private key for localhostCert. var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- -MIIBPQIBAAJBALLgOZgBTI+kO6qAc3LysyKuJM7k+XqUqdgJHEH8gR5uytd1rO7v -tG+VW/YKk3+XAIiCnK7a11apC/ItVEBegM8CAwEAAQJBAI5sxq7naeR9ahyqRkJi -SIv2iMxLuPEHaezf5CYOPWjSjBPyVhyRevkhtqEjF/WkgL7C2nWpYHsUcBDBQVF0 -3KECIQDtEGB2ulnkZAahl3WuJziXGLB+p8Wgx7wzSM6bHu1c6QIhAMEp++CaS+SJ -/TrU0zwY/fW4SvQeb49BPZUF3oqR8Xz3AiEA1rAJHBzBgdOQKdE3ksMUPcnvNJSN -poCcELmz2clVXtkCIQCLytuLV38XHToTipR4yMl6O+6arzAjZ56uq7m7ZRV0TwIh -AM65XAOw8Dsg9Kq78aYXiOEDc5DL0sbFUu/SlmRcCg93 ------END RSA PRIVATE KEY----- -`) +MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0 +0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV +NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d +AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW +MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD +EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA +1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE= +-----END RSA PRIVATE KEY-----`) diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go index 29eaf3475..b66d40951 100644 --- a/src/pkg/net/http/httputil/chunked.go +++ b/src/pkg/net/http/httputil/chunked.go @@ -13,10 +13,9 @@ package httputil import ( "bufio" - "bytes" "errors" + "fmt" "io" - "strconv" ) const maxLineLength = 4096 // assumed <= bufio.defaultBufSize @@ -24,7 +23,7 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize var ErrLineTooLong = errors.New("header line too long") // NewChunkedReader returns a new chunkedReader that translates the data read from r -// out of HTTP "chunked" format before returning it. +// out of HTTP "chunked" format before returning it. // The chunkedReader returns io.EOF when the final 0-length chunk is read. // // NewChunkedReader is not needed by normal applications. The http package @@ -41,16 +40,17 @@ type chunkedReader struct { r *bufio.Reader n uint64 // unread bytes in chunk err error + buf [2]byte } func (cr *chunkedReader) beginChunk() { // chunk-size CRLF - var line string + var line []byte line, cr.err = readLine(cr.r) if cr.err != nil { return } - cr.n, cr.err = strconv.ParseUint(line, 16, 64) + cr.n, cr.err = parseHexUint(line) if cr.err != nil { return } @@ -76,9 +76,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { cr.n -= uint64(n) if cr.n == 0 && cr.err == nil { // end of chunk (CRLF) - b := make([]byte, 2) - if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil { - if b[0] != '\r' || b[1] != '\n' { + if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil { + if cr.buf[0] != '\r' || cr.buf[1] != '\n' { cr.err = errors.New("malformed chunked encoding") } } @@ -90,7 +89,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { // Give up if the line exceeds maxLineLength. // The returned bytes are a pointer into storage in // the bufio, so they are only valid until the next bufio read. -func readLineBytes(b *bufio.Reader) (p []byte, err error) { +func readLine(b *bufio.Reader) (p []byte, err error) { if p, err = b.ReadSlice('\n'); err != nil { // We always know when EOF is coming. // If the caller asked for a line, there should be a line. @@ -104,20 +103,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) { if len(p) >= maxLineLength { return nil, ErrLineTooLong } - - // Chop off trailing white space. - p = bytes.TrimRight(p, " \r\t\n") - - return p, nil + return trimTrailingWhitespace(p), nil } -// readLineBytes, but convert the bytes into a string. -func readLine(b *bufio.Reader) (s string, err error) { - p, e := readLineBytes(b) - if e != nil { - return "", e +func trimTrailingWhitespace(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] } - return string(p), nil + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' } // NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP @@ -149,9 +146,7 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) { return 0, nil } - head := strconv.FormatInt(int64(len(data)), 16) + "\r\n" - - if _, err = io.WriteString(cw.Wire, head); err != nil { + if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil { return 0, err } if n, err = cw.Wire.Write(data); err != nil { @@ -170,3 +165,21 @@ func (cw *chunkedWriter) Close() error { _, err := io.WriteString(cw.Wire, "0\r\n") return err } + +func parseHexUint(v []byte) (n uint64, err error) { + for _, b := range v { + n <<= 4 + switch { + case '0' <= b && b <= '9': + b = b - '0' + case 'a' <= b && b <= 'f': + b = b - 'a' + 10 + case 'A' <= b && b <= 'F': + b = b - 'A' + 10 + default: + return 0, errors.New("invalid byte in chunk length") + } + n |= uint64(b) + } + return +} diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go index 155a32bdf..a06bffad5 100644 --- a/src/pkg/net/http/httputil/chunked_test.go +++ b/src/pkg/net/http/httputil/chunked_test.go @@ -11,7 +11,10 @@ package httputil import ( "bytes" + "fmt" + "io" "io/ioutil" + "runtime" "testing" ) @@ -39,3 +42,54 @@ func TestChunk(t *testing.T) { t.Errorf("chunk reader read %q; want %q", g, e) } } + +func TestChunkReaderAllocs(t *testing.T) { + // temporarily set GOMAXPROCS to 1 as we are testing memory allocations + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + var buf bytes.Buffer + w := NewChunkedWriter(&buf) + a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") + w.Write(a) + w.Write(b) + w.Write(c) + w.Close() + + r := NewChunkedReader(&buf) + readBuf := make([]byte, len(a)+len(b)+len(c)+1) + + var ms runtime.MemStats + runtime.ReadMemStats(&ms) + m0 := ms.Mallocs + + n, err := io.ReadFull(r, readBuf) + + runtime.ReadMemStats(&ms) + mallocs := ms.Mallocs - m0 + if mallocs > 1 { + t.Errorf("%d mallocs; want <= 1", mallocs) + } + + if n != len(readBuf)-1 { + t.Errorf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Errorf("read error = %v; want ErrUnexpectedEOF", err) + } +} + +func TestParseHexUint(t *testing.T) { + for i := uint64(0); i <= 1234; i++ { + line := []byte(fmt.Sprintf("%x", i)) + got, err := parseHexUint(line) + if err != nil { + t.Fatalf("on %d: %v", i, err) + } + if got != i { + t.Errorf("for input %q = %d; want %d", line, got, i) + } + } + _, err := parseHexUint([]byte("bogus")) + if err == nil { + t.Error("expected error on bogus input") + } +} diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go index 892ef4ede..0b0035661 100644 --- a/src/pkg/net/http/httputil/dump.go +++ b/src/pkg/net/http/httputil/dump.go @@ -75,7 +75,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { // Use the actual Transport code to record what we would send // on the wire, but not using TCP. Use a Transport with a - // customer dialer that returns a fake net.Conn that waits + // custom dialer that returns a fake net.Conn that waits // for the full input (and recording it), and then responds // with a dummy response. var buf bytes.Buffer // records the output @@ -89,7 +89,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { t := &http.Transport{ Dial: func(net, addr string) (net.Conn, error) { - return &dumpConn{io.MultiWriter(pw, &buf), dr}, nil + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil }, } diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go index 9c4bd6e09..134c45299 100644 --- a/src/pkg/net/http/httputil/reverseproxy.go +++ b/src/pkg/net/http/httputil/reverseproxy.go @@ -17,6 +17,10 @@ import ( "time" ) +// onExitFlushLoop is a callback set by tests to detect the state of the +// flushLoop() goroutine. +var onExitFlushLoop func() + // ReverseProxy is an HTTP Handler that takes an incoming request and // sends it to another server, proxying the response back to the // client. @@ -102,8 +106,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header.Del("Connection") } - if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - outreq.Header.Set("X-Forwarded-For", clientIp) + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) } res, err := transport.RoundTrip(outreq) @@ -112,20 +122,29 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) return } + defer res.Body.Close() copyHeader(rw.Header(), res.Header) rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) +} - if res.Body != nil { - var dst io.Writer = rw - if p.FlushInterval != 0 { - if wf, ok := rw.(writeFlusher); ok { - dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval} +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw } - io.Copy(dst, res.Body) } + + io.Copy(dst, src) } type writeFlusher interface { @@ -137,22 +156,14 @@ type maxLatencyWriter struct { dst writeFlusher latency time.Duration - lk sync.Mutex // protects init of done, as well Write + Flush + lk sync.Mutex // protects Write + Flush done chan bool } -func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { +func (m *maxLatencyWriter) Write(p []byte) (int, error) { m.lk.Lock() defer m.lk.Unlock() - if m.done == nil { - m.done = make(chan bool) - go m.flushLoop() - } - n, err = m.dst.Write(p) - if err != nil { - m.done <- true - } - return + return m.dst.Write(p) } func (m *maxLatencyWriter) flushLoop() { @@ -160,13 +171,18 @@ func (m *maxLatencyWriter) flushLoop() { defer t.Stop() for { select { + case <-m.done: + if onExitFlushLoop != nil { + onExitFlushLoop() + } + return case <-t.C: m.lk.Lock() m.dst.Flush() m.lk.Unlock() - case <-m.done: - return } } panic("unreached") } + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go index 28e9c90ad..863927162 100644 --- a/src/pkg/net/http/httputil/reverseproxy_test.go +++ b/src/pkg/net/http/httputil/reverseproxy_test.go @@ -11,7 +11,9 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "time" ) func TestReverseProxy(t *testing.T) { @@ -70,6 +72,47 @@ func TestReverseProxy(t *testing.T) { } } +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Header.Set("X-Forwarded-For", prevForwardedFor) + getReq.Close = true + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + var proxyQueryTests = []struct { baseSuffix string // suffix to add to backend URL reqSuffix string // suffix to add to frontend's request URL @@ -107,3 +150,44 @@ func TestReverseProxyQuery(t *testing.T) { frontend.Close() } } + +func TestReverseProxyFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + done := make(chan bool) + onExitFlushLoop = func() { done <- true } + defer func() { onExitFlushLoop = nil }() + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + + select { + case <-done: + // OK + case <-time.After(5 * time.Second): + t.Error("maxLatencyWriter flushLoop() never exited") + } +} diff --git a/src/pkg/net/http/jar.go b/src/pkg/net/http/jar.go index 2c2caa251..5c3de0dad 100644 --- a/src/pkg/net/http/jar.go +++ b/src/pkg/net/http/jar.go @@ -8,23 +8,20 @@ import ( "net/url" ) -// A CookieJar manages storage and use of cookies in HTTP requests. +// A CookieJar manages storage and use of cookies in HTTP requests. // // Implementations of CookieJar must be safe for concurrent use by multiple // goroutines. +// +// The net/http/cookiejar package provides a CookieJar implementation. type CookieJar interface { - // SetCookies handles the receipt of the cookies in a reply for the - // given URL. It may or may not choose to save the cookies, depending - // on the jar's policy and implementation. + // SetCookies handles the receipt of the cookies in a reply for the + // given URL. It may or may not choose to save the cookies, depending + // on the jar's policy and implementation. SetCookies(u *url.URL, cookies []*Cookie) // Cookies returns the cookies to send in a request for the given URL. - // It is up to the implementation to honor the standard cookie use - // restrictions such as in RFC 6265. + // It is up to the implementation to honor the standard cookie use + // restrictions such as in RFC 6265. Cookies(u *url.URL) []*Cookie } - -type blackHoleJar struct{} - -func (blackHoleJar) SetCookies(u *url.URL, cookies []*Cookie) {} -func (blackHoleJar) Cookies(u *url.URL) []*Cookie { return nil } diff --git a/src/pkg/net/http/lex.go b/src/pkg/net/http/lex.go index ffb393ccf..cb33318f4 100644 --- a/src/pkg/net/http/lex.go +++ b/src/pkg/net/http/lex.go @@ -6,131 +6,91 @@ package http // This file deals with lexical matters of HTTP -func isSeparator(c byte) bool { - switch c { - case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t': - return true - } - return false +var isTokenTable = [127]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, } -func isCtl(c byte) bool { return (0 <= c && c <= 31) || c == 127 } - -func isChar(c byte) bool { return 0 <= c && c <= 127 } - -func isAnyText(c byte) bool { return !isCtl(c) } - -func isQdText(c byte) bool { return isAnyText(c) && c != '"' } - -func isToken(c byte) bool { return isChar(c) && !isCtl(c) && !isSeparator(c) } - -// Valid escaped sequences are not specified in RFC 2616, so for now, we assume -// that they coincide with the common sense ones used by GO. Malformed -// characters should probably not be treated as errors by a robust (forgiving) -// parser, so we replace them with the '?' character. -func httpUnquotePair(b byte) byte { - // skip the first byte, which should always be '\' - switch b { - case 'a': - return '\a' - case 'b': - return '\b' - case 'f': - return '\f' - case 'n': - return '\n' - case 'r': - return '\r' - case 't': - return '\t' - case 'v': - return '\v' - case '\\': - return '\\' - case '\'': - return '\'' - case '"': - return '"' - } - return '?' -} - -// raw must begin with a valid quoted string. Only the first quoted string is -// parsed and is unquoted in result. eaten is the number of bytes parsed, or -1 -// upon failure. -func httpUnquote(raw []byte) (eaten int, result string) { - buf := make([]byte, len(raw)) - if raw[0] != '"' { - return -1, "" - } - eaten = 1 - j := 0 // # of bytes written in buf - for i := 1; i < len(raw); i++ { - switch b := raw[i]; b { - case '"': - eaten++ - buf = buf[0:j] - return i + 1, string(buf) - case '\\': - if len(raw) < i+2 { - return -1, "" - } - buf[j] = httpUnquotePair(raw[i+1]) - eaten += 2 - j++ - i++ - default: - if isQdText(b) { - buf[j] = b - } else { - buf[j] = '?' - } - eaten++ - j++ - } - } - return -1, "" +func isToken(r rune) bool { + i := int(r) + return i < len(isTokenTable) && isTokenTable[i] } -// This is a best effort parse, so errors are not returned, instead not all of -// the input string might be parsed. result is always non-nil. -func httpSplitFieldValue(fv string) (eaten int, result []string) { - result = make([]string, 0, len(fv)) - raw := []byte(fv) - i := 0 - chunk := "" - for i < len(raw) { - b := raw[i] - switch { - case b == '"': - eaten, unq := httpUnquote(raw[i:len(raw)]) - if eaten < 0 { - return i, result - } else { - i += eaten - chunk += unq - } - case isSeparator(b): - if chunk != "" { - result = result[0 : len(result)+1] - result[len(result)-1] = chunk - chunk = "" - } - i++ - case isToken(b): - chunk += string(b) - i++ - case b == '\n' || b == '\r': - i++ - default: - chunk += "?" - i++ - } - } - if chunk != "" { - result = result[0 : len(result)+1] - result[len(result)-1] = chunk - chunk = "" - } - return i, result +func isNotToken(r rune) bool { + return !isToken(r) } diff --git a/src/pkg/net/http/lex_test.go b/src/pkg/net/http/lex_test.go index 5386f7534..6d9d294f7 100644 --- a/src/pkg/net/http/lex_test.go +++ b/src/pkg/net/http/lex_test.go @@ -8,63 +8,24 @@ import ( "testing" ) -type lexTest struct { - Raw string - Parsed int // # of parsed characters - Result []string -} +func isChar(c rune) bool { return c <= 127 } -var lexTests = []lexTest{ - { - Raw: `"abc"def,:ghi`, - Parsed: 13, - Result: []string{"abcdef", "ghi"}, - }, - // My understanding of the RFC is that escape sequences outside of - // quotes are not interpreted? - { - Raw: `"\t"\t"\t"`, - Parsed: 10, - Result: []string{"\t", "t\t"}, - }, - { - Raw: `"\yab"\r\n`, - Parsed: 10, - Result: []string{"?ab", "r", "n"}, - }, - { - Raw: "ab\f", - Parsed: 3, - Result: []string{"ab?"}, - }, - { - Raw: "\"ab \" c,de f, gh, ij\n\t\r", - Parsed: 23, - Result: []string{"ab ", "c", "de", "f", "gh", "ij"}, - }, -} +func isCtl(c rune) bool { return c <= 31 || c == 127 } -func min(x, y int) int { - if x <= y { - return x +func isSeparator(c rune) bool { + switch c { + case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t': + return true } - return y + return false } -func TestSplitFieldValue(t *testing.T) { - for k, l := range lexTests { - parsed, result := httpSplitFieldValue(l.Raw) - if parsed != l.Parsed { - t.Errorf("#%d: Parsed %d, expected %d", k, parsed, l.Parsed) - } - if len(result) != len(l.Result) { - t.Errorf("#%d: Result len %d, expected %d", k, len(result), len(l.Result)) - } - for i := 0; i < min(len(result), len(l.Result)); i++ { - if result[i] != l.Result[i] { - t.Errorf("#%d: %d-th entry mismatch. Have {%s}, expect {%s}", - k, i, result[i], l.Result[i]) - } +func TestIsToken(t *testing.T) { + for i := 0; i <= 130; i++ { + r := rune(i) + expected := isChar(r) && !isCtl(r) && !isSeparator(r) + if isToken(r) != expected { + t.Errorf("isToken(0x%x) = %v", r, !expected) } } } diff --git a/src/pkg/net/http/npn_test.go b/src/pkg/net/http/npn_test.go new file mode 100644 index 000000000..98b8930d0 --- /dev/null +++ b/src/pkg/net/http/npn_test.go @@ -0,0 +1,118 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bufio" + "crypto/tls" + "fmt" + "io" + "io/ioutil" + . "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestNextProtoUpgrade(t *testing.T) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "path=%s,proto=", r.URL.Path) + if r.TLS != nil { + w.Write([]byte(r.TLS.NegotiatedProtocol)) + } + if r.RemoteAddr == "" { + t.Error("request with no RemoteAddr") + } + if r.Body == nil { + t.Errorf("request with nil Body") + } + })) + ts.TLS = &tls.Config{ + NextProtos: []string{"unhandled-proto", "tls-0.9"}, + } + ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){ + "tls-0.9": handleTLSProtocol09, + } + ts.StartTLS() + defer ts.Close() + + tr := newTLSTransport(t, ts) + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + // Normal request, without NPN. + { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if want := "path=/,proto="; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } + + // Request to an advertised but unhandled NPN protocol. + // Server will hang up. + { + tr.CloseIdleConnections() + tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"} + _, err := c.Get(ts.URL) + if err == nil { + t.Errorf("expected error on unhandled-proto request") + } + } + + // Request using the "tls-0.9" protocol, which we register here. + // It is HTTP/0.9 over TLS. + { + tlsConfig := newTLSTransport(t, ts).TLSClientConfig + tlsConfig.NextProtos = []string{"tls-0.9"} + conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) + if err != nil { + t.Fatal(err) + } + conn.Write([]byte("GET /foo\n")) + body, err := ioutil.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if want := "path=/foo,proto=tls-0.9"; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } +} + +// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the +// TestNextProtoUpgrade test. +func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) { + br := bufio.NewReader(conn) + line, err := br.ReadString('\n') + if err != nil { + return + } + line = strings.TrimSpace(line) + path := strings.TrimPrefix(line, "GET ") + if path == line { + return + } + req, _ := NewRequest("GET", path, nil) + req.Proto = "HTTP/0.9" + req.ProtoMajor = 0 + req.ProtoMinor = 9 + rw := &http09Writer{conn, make(Header)} + h.ServeHTTP(rw, req) +} + +type http09Writer struct { + io.Writer + h Header +} + +func (w http09Writer) Header() Header { return w.h } +func (w http09Writer) WriteHeader(int) {} // no headers diff --git a/src/pkg/net/http/pprof/pprof.go b/src/pkg/net/http/pprof/pprof.go index 06fcde144..0c7548e3e 100644 --- a/src/pkg/net/http/pprof/pprof.go +++ b/src/pkg/net/http/pprof/pprof.go @@ -14,6 +14,14 @@ // To use pprof, link this package into your program: // import _ "net/http/pprof" // +// If your application is not already running an http server, you +// need to start one. Add "net/http" and "log" to your imports and +// the following code to your main function: +// +// go func() { +// log.Println(http.ListenAndServe("localhost:6060", nil)) +// }() +// // Then use the pprof tool to look at the heap profile: // // go tool pprof http://localhost:6060/debug/pprof/heap @@ -22,9 +30,12 @@ // // go tool pprof http://localhost:6060/debug/pprof/profile // -// Or to view all available profiles: +// Or to look at the goroutine blocking profile: +// +// go tool pprof http://localhost:6060/debug/pprof/block // -// go tool pprof http://localhost:6060/debug/pprof/ +// To view all available profiles, open http://localhost:6060/debug/pprof/ +// in your browser. // // For a study of the facility in action, visit // @@ -161,7 +172,7 @@ func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // listing the available profiles. func Index(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.URL.Path, "/debug/pprof/") { - name := r.URL.Path[len("/debug/pprof/"):] + name := strings.TrimPrefix(r.URL.Path, "/debug/pprof/") if name != "" { handler(name).ServeHTTP(w, r) return diff --git a/src/pkg/net/http/proxy_test.go b/src/pkg/net/http/proxy_test.go index 5ecffafac..449ccaeea 100644 --- a/src/pkg/net/http/proxy_test.go +++ b/src/pkg/net/http/proxy_test.go @@ -25,13 +25,13 @@ var UseProxyTests = []struct { {"[::2]", true}, // not a loopback address {"barbaz.net", false}, // match as .barbaz.net - {"foobar.com", false}, // have a port but match + {"foobar.com", false}, // have a port but match {"foofoobar.com", true}, // not match as a part of foobar.com {"baz.com", true}, // not match as a part of barbaz.com {"localhost.net", true}, // not match as suffix of address {"local.localhost", true}, // not match as prefix as address {"barbarbaz.net", true}, // not match because NO_PROXY have a '.' - {"www.foobar.com", true}, // not match because NO_PROXY is not .foobar.com + {"www.foobar.com", false}, // match because NO_PROXY includes "foobar.com" } func TestUseProxy(t *testing.T) { diff --git a/src/pkg/net/http/range_test.go b/src/pkg/net/http/range_test.go index 5274a81fa..ef911af7b 100644 --- a/src/pkg/net/http/range_test.go +++ b/src/pkg/net/http/range_test.go @@ -14,15 +14,34 @@ var ParseRangeTests = []struct { r []httpRange }{ {"", 0, nil}, + {"", 1000, nil}, {"foo", 0, nil}, {"bytes=", 0, nil}, + {"bytes=7", 10, nil}, + {"bytes= 7 ", 10, nil}, + {"bytes=1-", 0, nil}, {"bytes=5-4", 10, nil}, {"bytes=0-2,5-4", 10, nil}, + {"bytes=2-5,4-3", 10, nil}, + {"bytes=--5,4--3", 10, nil}, + {"bytes=A-", 10, nil}, + {"bytes=A- ", 10, nil}, + {"bytes=A-Z", 10, nil}, + {"bytes= -Z", 10, nil}, + {"bytes=5-Z", 10, nil}, + {"bytes=Ran-dom, garbage", 10, nil}, + {"bytes=0x01-0x02", 10, nil}, + {"bytes= ", 10, nil}, + {"bytes= , , , ", 10, nil}, + {"bytes=0-9", 10, []httpRange{{0, 10}}}, {"bytes=0-", 10, []httpRange{{0, 10}}}, {"bytes=5-", 10, []httpRange{{5, 5}}}, {"bytes=0-20", 10, []httpRange{{0, 10}}}, {"bytes=15-,0-5", 10, nil}, + {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}}, + {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}}, + {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}}, {"bytes=-5", 10, []httpRange{{5, 5}}}, {"bytes=-15", 10, []httpRange{{0, 10}}}, {"bytes=0-499", 10000, []httpRange{{0, 500}}}, @@ -32,6 +51,9 @@ var ParseRangeTests = []struct { {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}}, {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}}, {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}}, + + // Match Apache laxity: + {"bytes= 1 -2 , 4- 5, 7 - 8 , ,,", 11, []httpRange{{1, 2}, {4, 2}, {7, 2}}}, } func TestParseRange(t *testing.T) { diff --git a/src/pkg/net/http/readrequest_test.go b/src/pkg/net/http/readrequest_test.go index 2e03c658a..ffdd6a892 100644 --- a/src/pkg/net/http/readrequest_test.go +++ b/src/pkg/net/http/readrequest_test.go @@ -247,6 +247,54 @@ var reqTests = []reqTest{ noTrailer, noError, }, + + // SSDP Notify request. golang.org/issue/3692 + { + "NOTIFY * HTTP/1.1\r\nServer: foo\r\n\r\n", + &Request{ + Method: "NOTIFY", + URL: &url.URL{ + Path: "*", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Server": []string{"foo"}, + }, + Close: false, + ContentLength: 0, + RequestURI: "*", + }, + + noBody, + noTrailer, + noError, + }, + + // OPTIONS request. Similar to golang.org/issue/3692 + { + "OPTIONS * HTTP/1.1\r\nServer: foo\r\n\r\n", + &Request{ + Method: "OPTIONS", + URL: &url.URL{ + Path: "*", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Server": []string{"foo"}, + }, + Close: false, + ContentLength: 0, + RequestURI: "*", + }, + + noBody, + noTrailer, + noError, + }, } func TestReadRequest(t *testing.T) { diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go index f5bc6eb91..217f35b48 100644 --- a/src/pkg/net/http/request.go +++ b/src/pkg/net/http/request.go @@ -19,6 +19,7 @@ import ( "mime/multipart" "net/textproto" "net/url" + "strconv" "strings" ) @@ -70,7 +71,13 @@ var reqWriteExcludeHeader = map[string]bool{ // or to be sent by a client. type Request struct { Method string // GET, POST, PUT, etc. - URL *url.URL + + // URL is created from the URI supplied on the Request-Line + // as stored in RequestURI. + // + // For most requests, fields other than Path and RawQuery + // will be empty. (See RFC 2616, Section 5.1.2) + URL *url.URL // The protocol version for incoming requests. // Outgoing requests always use HTTP/1.1. @@ -123,6 +130,7 @@ type Request struct { // The host on which the URL is sought. // Per RFC 2616, this is either the value of the Host: header // or the host name given in the URL itself. + // It may be of the form "host:port". Host string // Form contains the parsed form data, including both the URL @@ -131,6 +139,12 @@ type Request struct { // The HTTP client ignores Form and uses Body instead. Form url.Values + // PostForm contains the parsed form data from POST or PUT + // body parameters. + // This field is only available after ParseForm is called. + // The HTTP client ignores PostForm and uses Body instead. + PostForm url.Values + // MultipartForm is the parsed multipart form, including file uploads. // This field is only available after ParseMultipartForm is called. // The HTTP client ignores MultipartForm and uses Body instead. @@ -317,11 +331,20 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err } // TODO(bradfitz): escape at least newlines in ruri? - bw := bufio.NewWriter(w) - fmt.Fprintf(bw, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) + // Wrap the writer in a bufio Writer if it's not already buffered. + // Don't always call NewWriter, as that forces a bytes.Buffer + // and other small bufio Writers to have a minimum 4k buffer + // size. + var bw *bufio.Writer + if _, ok := w.(io.ByteWriter); !ok { + bw = bufio.NewWriter(w) + w = bw + } + + fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) // Header lines - fmt.Fprintf(bw, "Host: %s\r\n", host) + fmt.Fprintf(w, "Host: %s\r\n", host) // Use the defaultUserAgent unless the Header contains one, which // may be blank to not send the header. @@ -332,7 +355,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err } } if userAgent != "" { - fmt.Fprintf(bw, "User-Agent: %s\r\n", userAgent) + fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) } // Process Body,ContentLength,Close,Trailer @@ -340,65 +363,61 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err if err != nil { return err } - err = tw.WriteHeader(bw) + err = tw.WriteHeader(w) if err != nil { return err } // TODO: split long values? (If so, should share code with Conn.Write) - err = req.Header.WriteSubset(bw, reqWriteExcludeHeader) + err = req.Header.WriteSubset(w, reqWriteExcludeHeader) if err != nil { return err } if extraHeaders != nil { - err = extraHeaders.Write(bw) + err = extraHeaders.Write(w) if err != nil { return err } } - io.WriteString(bw, "\r\n") + io.WriteString(w, "\r\n") // Write body and trailer - err = tw.WriteBody(bw) + err = tw.WriteBody(w) if err != nil { return err } - return bw.Flush() -} - -// Convert decimal at s[i:len(s)] to integer, -// returning value, string position where the digits stopped, -// and whether there was a valid number (digits, not too big). -func atoi(s string, i int) (n, i1 int, ok bool) { - const Big = 1000000 - if i >= len(s) || s[i] < '0' || s[i] > '9' { - return 0, 0, false - } - n = 0 - for ; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { - n = n*10 + int(s[i]-'0') - if n > Big { - return 0, 0, false - } + if bw != nil { + return bw.Flush() } - return n, i, true + return nil } // ParseHTTPVersion parses a HTTP version string. // "HTTP/1.0" returns (1, 0, true). func ParseHTTPVersion(vers string) (major, minor int, ok bool) { - if len(vers) < 5 || vers[0:5] != "HTTP/" { + const Big = 1000000 // arbitrary upper bound + switch vers { + case "HTTP/1.1": + return 1, 1, true + case "HTTP/1.0": + return 1, 0, true + } + if !strings.HasPrefix(vers, "HTTP/") { return 0, 0, false } - major, i, ok := atoi(vers, 5) - if !ok || i >= len(vers) || vers[i] != '.' { + dot := strings.Index(vers, ".") + if dot < 0 { return 0, 0, false } - minor, i, ok = atoi(vers, i+1) - if !ok || i != len(vers) { + major, err := strconv.Atoi(vers[5:dot]) + if err != nil || major < 0 || major > Big { + return 0, 0, false + } + minor, err = strconv.Atoi(vers[dot+1:]) + if err != nil || minor < 0 || minor > Big { return 0, 0, false } return major, minor, true @@ -426,10 +445,12 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { } if body != nil { switch v := body.(type) { - case *strings.Reader: - req.ContentLength = int64(v.Len()) case *bytes.Buffer: req.ContentLength = int64(v.Len()) + case *bytes.Reader: + req.ContentLength = int64(v.Len()) + case *strings.Reader: + req.ContentLength = int64(v.Len()) } } @@ -513,9 +534,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { // the same. In the second case, any Host line is ignored. req.Host = req.URL.Host if req.Host == "" { - req.Host = req.Header.Get("Host") + req.Host = req.Header.get("Host") } - req.Header.Del("Host") + delete(req.Header, "Host") fixPragmaCacheControl(req.Header) @@ -594,66 +615,97 @@ func (l *maxBytesReader) Close() error { return l.r.Close() } -// ParseForm parses the raw query from the URL. +func copyValues(dst, src url.Values) { + for k, vs := range src { + for _, value := range vs { + dst.Add(k, value) + } + } +} + +func parsePostForm(r *Request) (vs url.Values, err error) { + if r.Body == nil { + err = errors.New("missing form body") + return + } + ct := r.Header.Get("Content-Type") + ct, _, err = mime.ParseMediaType(ct) + switch { + case ct == "application/x-www-form-urlencoded": + var reader io.Reader = r.Body + maxFormSize := int64(1<<63 - 1) + if _, ok := r.Body.(*maxBytesReader); !ok { + maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + reader = io.LimitReader(r.Body, maxFormSize+1) + } + b, e := ioutil.ReadAll(reader) + if e != nil { + if err == nil { + err = e + } + break + } + if int64(len(b)) > maxFormSize { + err = errors.New("http: POST too large") + return + } + vs, e = url.ParseQuery(string(b)) + if err == nil { + err = e + } + case ct == "multipart/form-data": + // handled by ParseMultipartForm (which is calling us, or should be) + // TODO(bradfitz): there are too many possible + // orders to call too many functions here. + // Clean this up and write more tests. + // request_test.go contains the start of this, + // in TestRequestMultipartCallOrder. + } + return +} + +// ParseForm parses the raw query from the URL and updates r.Form. +// +// For POST or PUT requests, it also parses the request body as a form and +// put the results into both r.PostForm and r.Form. +// POST and PUT body parameters take precedence over URL query string values +// in r.Form. // -// For POST or PUT requests, it also parses the request body as a form. // If the request Body's size has not already been limited by MaxBytesReader, // the size is capped at 10MB. // // ParseMultipartForm calls ParseForm automatically. // It is idempotent. -func (r *Request) ParseForm() (err error) { - if r.Form != nil { - return - } - if r.URL != nil { - r.Form, err = url.ParseQuery(r.URL.RawQuery) +func (r *Request) ParseForm() error { + var err error + if r.PostForm == nil { + if r.Method == "POST" || r.Method == "PUT" { + r.PostForm, err = parsePostForm(r) + } + if r.PostForm == nil { + r.PostForm = make(url.Values) + } } - if r.Method == "POST" || r.Method == "PUT" { - if r.Body == nil { - return errors.New("missing form body") + if r.Form == nil { + if len(r.PostForm) > 0 { + r.Form = make(url.Values) + copyValues(r.Form, r.PostForm) } - ct := r.Header.Get("Content-Type") - ct, _, err = mime.ParseMediaType(ct) - switch { - case ct == "application/x-www-form-urlencoded": - var reader io.Reader = r.Body - maxFormSize := int64(1<<63 - 1) - if _, ok := r.Body.(*maxBytesReader); !ok { - maxFormSize = int64(10 << 20) // 10 MB is a lot of text. - reader = io.LimitReader(r.Body, maxFormSize+1) - } - b, e := ioutil.ReadAll(reader) - if e != nil { - if err == nil { - err = e - } - break - } - if int64(len(b)) > maxFormSize { - return errors.New("http: POST too large") - } - var newValues url.Values - newValues, e = url.ParseQuery(string(b)) + var newValues url.Values + if r.URL != nil { + var e error + newValues, e = url.ParseQuery(r.URL.RawQuery) if err == nil { err = e } - if r.Form == nil { - r.Form = make(url.Values) - } - // Copy values into r.Form. TODO: make this smoother. - for k, vs := range newValues { - for _, value := range vs { - r.Form.Add(k, value) - } - } - case ct == "multipart/form-data": - // handled by ParseMultipartForm (which is calling us, or should be) - // TODO(bradfitz): there are too many possible - // orders to call too many functions here. - // Clean this up and write more tests. - // request_test.go contains the start of this, - // in TestRequestMultipartCallOrder. + } + if newValues == nil { + newValues = make(url.Values) + } + if r.Form == nil { + r.Form = newValues + } else { + copyValues(r.Form, newValues) } } return err @@ -699,7 +751,9 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error { } // FormValue returns the first value for the named component of the query. +// POST and PUT body parameters take precedence over URL query string values. // FormValue calls ParseMultipartForm and ParseForm if necessary. +// To access multiple values of the same key use ParseForm. func (r *Request) FormValue(key string) string { if r.Form == nil { r.ParseMultipartForm(defaultMaxMemory) @@ -710,6 +764,19 @@ func (r *Request) FormValue(key string) string { return "" } +// PostFormValue returns the first value for the named component of the POST +// or PUT request body. URL query parameters are ignored. +// PostFormValue calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) PostFormValue(key string) string { + if r.PostForm == nil { + r.ParseMultipartForm(defaultMaxMemory) + } + if vs := r.PostForm[key]; len(vs) > 0 { + return vs[0] + } + return "" +} + // FormFile returns the first file for the provided form key. // FormFile calls ParseMultipartForm and ParseForm if necessary. func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) { @@ -732,12 +799,16 @@ func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, e } func (r *Request) expectsContinue() bool { - return strings.ToLower(r.Header.Get("Expect")) == "100-continue" + return hasToken(r.Header.get("Expect"), "100-continue") } func (r *Request) wantsHttp10KeepAlive() bool { if r.ProtoMajor != 1 || r.ProtoMinor != 0 { return false } - return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "keep-alive") + return hasToken(r.Header.get("Connection"), "keep-alive") +} + +func (r *Request) wantsClose() bool { + return hasToken(r.Header.get("Connection"), "close") } diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go index 6e00b9bfd..00ad791de 100644 --- a/src/pkg/net/http/request_test.go +++ b/src/pkg/net/http/request_test.go @@ -30,8 +30,8 @@ func TestQuery(t *testing.T) { } func TestPostQuery(t *testing.T) { - req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x", - strings.NewReader("z=post&both=y")) + req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", + strings.NewReader("z=post&both=y&prio=2&empty=")) req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") if q := req.FormValue("q"); q != "foo" { @@ -40,8 +40,23 @@ func TestPostQuery(t *testing.T) { if z := req.FormValue("z"); z != "post" { t.Errorf(`req.FormValue("z") = %q, want "post"`, z) } - if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"x", "y"}) { - t.Errorf(`req.FormValue("both") = %q, want ["x", "y"]`, both) + if bq, found := req.PostForm["q"]; found { + t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq) + } + if bz := req.PostFormValue("z"); bz != "post" { + t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz) + } + if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) { + t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs) + } + if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) { + t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both) + } + if prio := req.FormValue("prio"); prio != "2" { + t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) + } + if empty := req.FormValue("empty"); empty != "" { + t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) } } @@ -76,6 +91,23 @@ func TestParseFormUnknownContentType(t *testing.T) { } } +func TestParseFormInitializeOnError(t *testing.T) { + nilBody, _ := NewRequest("POST", "http://www.google.com/search?q=foo", nil) + tests := []*Request{ + nilBody, + {Method: "GET", URL: nil}, + } + for i, req := range tests { + err := req.ParseForm() + if req.Form == nil { + t.Errorf("%d. Form not initialized, error %v", i, err) + } + if req.PostForm == nil { + t.Errorf("%d. PostForm not initialized, error %v", i, err) + } + } +} + func TestMultipartReader(t *testing.T) { req := &Request{ Method: "POST", @@ -129,7 +161,7 @@ func TestSetBasicAuth(t *testing.T) { } func TestMultipartRequest(t *testing.T) { - // Test that we can read the values and files of a + // Test that we can read the values and files of a // multipart request with FormValue and FormFile, // and that ParseMultipartForm can be called multiple times. req := newTestMultipartRequest(t) @@ -196,6 +228,75 @@ func TestReadRequestErrors(t *testing.T) { } } +func TestNewRequestHost(t *testing.T) { + req, err := NewRequest("GET", "http://localhost:1234/", nil) + if err != nil { + t.Fatal(err) + } + if req.Host != "localhost:1234" { + t.Errorf("Host = %q; want localhost:1234", req.Host) + } +} + +func TestNewRequestContentLength(t *testing.T) { + readByte := func(r io.Reader) io.Reader { + var b [1]byte + r.Read(b[:]) + return r + } + tests := []struct { + r io.Reader + want int64 + }{ + {bytes.NewReader([]byte("123")), 3}, + {bytes.NewBuffer([]byte("1234")), 4}, + {strings.NewReader("12345"), 5}, + // Not detected: + {struct{ io.Reader }{strings.NewReader("xyz")}, 0}, + {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0}, + {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0}, + } + for _, tt := range tests { + req, err := NewRequest("POST", "http://localhost/", tt.r) + if err != nil { + t.Fatal(err) + } + if req.ContentLength != tt.want { + t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want) + } + } +} + +type logWrites struct { + t *testing.T + dst *[]string +} + +func (l logWrites) WriteByte(c byte) error { + l.t.Fatalf("unexpected WriteByte call") + return nil +} + +func (l logWrites) Write(p []byte) (n int, err error) { + *l.dst = append(*l.dst, string(p)) + return len(p), nil +} + +func TestRequestWriteBufferedWriter(t *testing.T) { + got := []string{} + req, _ := NewRequest("GET", "http://foo.com/", nil) + req.Write(logWrites{t, &got}) + want := []string{ + "GET / HTTP/1.1\r\n", + "Host: foo.com\r\n", + "User-Agent: Go http package\r\n", + "\r\n", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Writes = %q\n Want = %q", got, want) + } +} + func testMissingFile(t *testing.T, req *Request) { f, fh, err := req.FormFile("missing") if f != nil { @@ -300,3 +401,81 @@ Content-Disposition: form-data; name="textb" ` + textbValue + ` --MyBoundary-- ` + +func benchmarkReadRequest(b *testing.B, request string) { + request = request + "\n" // final \n + request = strings.Replace(request, "\n", "\r\n", -1) // expand \n to \r\n + b.SetBytes(int64(len(request))) + r := bufio.NewReader(&infiniteReader{buf: []byte(request)}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := ReadRequest(r) + if err != nil { + b.Fatalf("failed to read request: %v", err) + } + } +} + +// infiniteReader satisfies Read requests as if the contents of buf +// loop indefinitely. +type infiniteReader struct { + buf []byte + offset int +} + +func (r *infiniteReader) Read(b []byte) (int, error) { + n := copy(b, r.buf[r.offset:]) + r.offset = (r.offset + n) % len(r.buf) + return n, nil +} + +func BenchmarkReadRequestChrome(b *testing.B) { + // https://github.com/felixge/node-http-perf/blob/master/fixtures/get.http + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +Connection: keep-alive +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 +Cookie: __utma=1.1978842379.1323102373.1323102373.1323102373.1; EPi:NumberOfVisits=1,2012-02-28T13:42:18; CrmSession=5b707226b9563e1bc69084d07a107c98; plushContainerWidth=100%25; plushNoTopMenu=0; hudson_auto_refresh=false +`) +} + +func BenchmarkReadRequestCurl(b *testing.B) { + // curl http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +User-Agent: curl/7.27.0 +Host: localhost:8080 +Accept: */* +`) +} + +func BenchmarkReadRequestApachebench(b *testing.B) { + // ab -n 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.0 +Host: localhost:8080 +User-Agent: ApacheBench/2.3 +Accept: */* +`) +} + +func BenchmarkReadRequestSiege(b *testing.B) { + // siege -r 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +Accept: */* +Accept-Encoding: gzip +User-Agent: JoeDog/1.00 [en] (X11; I; Siege 2.70) +Connection: keep-alive +`) +} + +func BenchmarkReadRequestWrk(b *testing.B) { + // wrk -t 1 -r 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +`) +} diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go index fc3186f0c..bc637f18b 100644 --- a/src/pkg/net/http/requestwrite_test.go +++ b/src/pkg/net/http/requestwrite_test.go @@ -328,6 +328,69 @@ var reqWriteTests = []reqWriteTest{ "User-Agent: Go http package\r\n" + "X-Foo: X-Bar\r\n\r\n", }, + + // If no Request.Host and no Request.URL.Host, we send + // an empty Host header, and don't use + // Request.Header["Host"]. This is just testing that + // we don't change Go 1.0 behavior. + { + Req: Request{ + Method: "GET", + Host: "", + URL: &url.URL{ + Scheme: "http", + Host: "", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Host": []string{"bad.example.com"}, + }, + }, + + WantWrite: "GET /search HTTP/1.1\r\n" + + "Host: \r\n" + + "User-Agent: Go http package\r\n\r\n", + }, + + // Opaque test #1 from golang.org/issue/4860 + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Opaque: "/%2F/%2F/", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + }, + + WantWrite: "GET /%2F/%2F/ HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n\r\n", + }, + + // Opaque test #2 from golang.org/issue/4860 + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "x.google.com", + Opaque: "//y.google.com/%2F/%2F/", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + }, + + WantWrite: "GET http://y.google.com/%2F/%2F/ HTTP/1.1\r\n" + + "Host: x.google.com\r\n" + + "User-Agent: Go http package\r\n\r\n", + }, } func TestRequestWrite(t *testing.T) { diff --git a/src/pkg/net/http/response.go b/src/pkg/net/http/response.go index 945ecd8a4..391ebbf6d 100644 --- a/src/pkg/net/http/response.go +++ b/src/pkg/net/http/response.go @@ -49,7 +49,7 @@ type Response struct { Body io.ReadCloser // ContentLength records the length of the associated content. The - // value -1 indicates that the length is unknown. Unless RequestMethod + // value -1 indicates that the length is unknown. Unless Request.Method // is "HEAD", values >= 0 indicate that the given number of bytes may // be read from Body. ContentLength int64 @@ -107,7 +107,6 @@ func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err error) { resp = new(Response) resp.Request = req - resp.Request.Method = strings.ToUpper(resp.Request.Method) // Parse the first line of the response. line, err := tp.ReadLine() @@ -179,7 +178,7 @@ func (r *Response) ProtoAtLeast(major, minor int) bool { // StatusCode // ProtoMajor // ProtoMinor -// RequestMethod +// Request.Method // TransferEncoding // Trailer // Body @@ -188,11 +187,6 @@ func (r *Response) ProtoAtLeast(major, minor int) bool { // func (r *Response) Write(w io.Writer) error { - // RequestMethod should be upper-case - if r.Request != nil { - r.Request.Method = strings.ToUpper(r.Request.Method) - } - // Status line text := r.Status if text == "" { @@ -204,9 +198,7 @@ func (r *Response) Write(w io.Writer) error { } protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor) statusCode := strconv.Itoa(r.StatusCode) + " " - if strings.HasPrefix(text, statusCode) { - text = text[len(statusCode):] - } + text = strings.TrimPrefix(text, statusCode) io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n") // Process Body,ContentLength,Close,Trailer diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go index 6eed4887d..2f5f77369 100644 --- a/src/pkg/net/http/response_test.go +++ b/src/pkg/net/http/response_test.go @@ -124,7 +124,7 @@ var respTests = []respTest{ // Chunked response without Content-Length. { - "HTTP/1.0 200 OK\r\n" + + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0a\r\n" + @@ -137,12 +137,12 @@ var respTests = []respTest{ Response{ Status: "200 OK", StatusCode: 200, - Proto: "HTTP/1.0", + Proto: "HTTP/1.1", ProtoMajor: 1, - ProtoMinor: 0, + ProtoMinor: 1, Request: dummyReq("GET"), Header: Header{}, - Close: true, + Close: false, ContentLength: -1, TransferEncoding: []string{"chunked"}, }, @@ -152,24 +152,24 @@ var respTests = []respTest{ // Chunked response with Content-Length. { - "HTTP/1.0 200 OK\r\n" + + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "Content-Length: 10\r\n" + "\r\n" + "0a\r\n" + - "Body here\n" + + "Body here\n\r\n" + "0\r\n" + "\r\n", Response{ Status: "200 OK", StatusCode: 200, - Proto: "HTTP/1.0", + Proto: "HTTP/1.1", ProtoMajor: 1, - ProtoMinor: 0, + ProtoMinor: 1, Request: dummyReq("GET"), Header: Header{}, - Close: true, + Close: false, ContentLength: -1, // TODO(rsc): Fix? TransferEncoding: []string{"chunked"}, }, @@ -177,23 +177,88 @@ var respTests = []respTest{ "Body here\n", }, - // Chunked response in response to a HEAD request (the "chunked" should - // be ignored, as HEAD responses never have bodies) + // Chunked response in response to a HEAD request { - "HTTP/1.0 200 OK\r\n" + + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n", Response{ - Status: "200 OK", - StatusCode: 200, - Proto: "HTTP/1.0", - ProtoMajor: 1, - ProtoMinor: 0, - Request: dummyReq("HEAD"), - Header: Header{}, - Close: true, - ContentLength: 0, + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("HEAD"), + Header: Header{}, + TransferEncoding: []string{"chunked"}, + Close: false, + ContentLength: -1, + }, + + "", + }, + + // Content-Length in response to a HEAD request + { + "HTTP/1.0 200 OK\r\n" + + "Content-Length: 256\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("HEAD"), + Header: Header{"Content-Length": {"256"}}, + TransferEncoding: nil, + Close: true, + ContentLength: 256, + }, + + "", + }, + + // Content-Length in response to a HEAD request with HTTP/1.1 + { + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 256\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("HEAD"), + Header: Header{"Content-Length": {"256"}}, + TransferEncoding: nil, + Close: false, + ContentLength: 256, + }, + + "", + }, + + // No Content-Length or Chunked in response to a HEAD request + { + "HTTP/1.0 200 OK\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("HEAD"), + Header: Header{}, + TransferEncoding: nil, + Close: true, + ContentLength: -1, }, "", @@ -259,16 +324,37 @@ var respTests = []respTest{ "", }, + + // golang.org/issue/4767: don't special-case multipart/byteranges responses + { + `HTTP/1.1 206 Partial Content +Connection: close +Content-Type: multipart/byteranges; boundary=18a75608c8f47cef + +some body`, + Response{ + Status: "206 Partial Content", + StatusCode: 206, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Content-Type": []string{"multipart/byteranges; boundary=18a75608c8f47cef"}, + }, + Close: true, + ContentLength: -1, + }, + + "some body", + }, } func TestReadResponse(t *testing.T) { - for i := range respTests { - tt := &respTests[i] - var braw bytes.Buffer - braw.WriteString(tt.Raw) - resp, err := ReadResponse(bufio.NewReader(&braw), tt.Resp.Request) + for i, tt := range respTests { + resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) if err != nil { - t.Errorf("#%d: %s", i, err) + t.Errorf("#%d: %v", i, err) continue } rbody := resp.Body @@ -276,7 +362,11 @@ func TestReadResponse(t *testing.T) { diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp) var bout bytes.Buffer if rbody != nil { - io.Copy(&bout, rbody) + _, err = io.Copy(&bout, rbody) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } rbody.Close() } body := bout.String() @@ -286,6 +376,22 @@ func TestReadResponse(t *testing.T) { } } +func TestWriteResponse(t *testing.T) { + for i, tt := range respTests { + resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + bout := bytes.NewBuffer(nil) + err = resp.Write(bout) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + } +} + var readResponseCloseInMiddleTests = []struct { chunked, compressed bool }{ diff --git a/src/pkg/net/http/responsewrite_test.go b/src/pkg/net/http/responsewrite_test.go index f8e63acf4..5c10e2161 100644 --- a/src/pkg/net/http/responsewrite_test.go +++ b/src/pkg/net/http/responsewrite_test.go @@ -15,83 +15,83 @@ type respWriteTest struct { Raw string } -var respWriteTests = []respWriteTest{ - // HTTP/1.0, identity coding; no trailer - { - Response{ - StatusCode: 503, - ProtoMajor: 1, - ProtoMinor: 0, - Request: dummyReq("GET"), - Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), - ContentLength: 6, - }, +func TestResponseWrite(t *testing.T) { + respWriteTests := []respWriteTest{ + // HTTP/1.0, identity coding; no trailer + { + Response{ + StatusCode: 503, + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + ContentLength: 6, + }, - "HTTP/1.0 503 Service Unavailable\r\n" + - "Content-Length: 6\r\n\r\n" + - "abcdef", - }, - // Unchunked response without Content-Length. - { - Response{ - StatusCode: 200, - ProtoMajor: 1, - ProtoMinor: 0, - Request: dummyReq("GET"), - Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), - ContentLength: -1, + "HTTP/1.0 503 Service Unavailable\r\n" + + "Content-Length: 6\r\n\r\n" + + "abcdef", }, - "HTTP/1.0 200 OK\r\n" + - "\r\n" + - "abcdef", - }, - // HTTP/1.1, chunked coding; empty trailer; close - { - Response{ - StatusCode: 200, - ProtoMajor: 1, - ProtoMinor: 1, - Request: dummyReq("GET"), - Header: Header{}, - Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), - ContentLength: 6, - TransferEncoding: []string{"chunked"}, - Close: true, + // Unchunked response without Content-Length. + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + ContentLength: -1, + }, + "HTTP/1.0 200 OK\r\n" + + "\r\n" + + "abcdef", }, + // HTTP/1.1, chunked coding; empty trailer; close + { + Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{}, + Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), + ContentLength: 6, + TransferEncoding: []string{"chunked"}, + Close: true, + }, - "HTTP/1.1 200 OK\r\n" + - "Connection: close\r\n" + - "Transfer-Encoding: chunked\r\n\r\n" + - "6\r\nabcdef\r\n0\r\n\r\n", - }, + "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "6\r\nabcdef\r\n0\r\n\r\n", + }, - // Header value with a newline character (Issue 914). - // Also tests removal of leading and trailing whitespace. - { - Response{ - StatusCode: 204, - ProtoMajor: 1, - ProtoMinor: 1, - Request: dummyReq("GET"), - Header: Header{ - "Foo": []string{" Bar\nBaz "}, + // Header value with a newline character (Issue 914). + // Also tests removal of leading and trailing whitespace. + { + Response{ + StatusCode: 204, + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Foo": []string{" Bar\nBaz "}, + }, + Body: nil, + ContentLength: 0, + TransferEncoding: []string{"chunked"}, + Close: true, }, - Body: nil, - ContentLength: 0, - TransferEncoding: []string{"chunked"}, - Close: true, - }, - "HTTP/1.1 204 No Content\r\n" + - "Connection: close\r\n" + - "Foo: Bar Baz\r\n" + - "\r\n", - }, -} + "HTTP/1.1 204 No Content\r\n" + + "Connection: close\r\n" + + "Foo: Bar Baz\r\n" + + "\r\n", + }, + } -func TestResponseWrite(t *testing.T) { for i := range respWriteTests { tt := &respWriteTests[i] var braw bytes.Buffer diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go index b6a6b4c77..3300fef59 100644 --- a/src/pkg/net/http/serve_test.go +++ b/src/pkg/net/http/serve_test.go @@ -20,8 +20,13 @@ import ( "net/http/httputil" "net/url" "os" + "os/exec" "reflect" + "runtime" + "strconv" "strings" + "sync" + "sync/atomic" "syscall" "testing" "time" @@ -62,6 +67,7 @@ func (a dummyAddr) String() string { type testConn struct { readBuf bytes.Buffer writeBuf bytes.Buffer + closec chan bool // if non-nil, send value to it on close } func (c *testConn) Read(b []byte) (int, error) { @@ -73,6 +79,10 @@ func (c *testConn) Write(b []byte) (int, error) { } func (c *testConn) Close() error { + select { + case c.closec <- true: + default: + } return nil } @@ -168,13 +178,18 @@ var vtests = []struct { {"http://someHost.com/someDir/apage", "someHost.com/someDir"}, {"http://otherHost.com/someDir/apage", "someDir"}, {"http://otherHost.com/aDir/apage", "Default"}, + // redirections for trees + {"http://localhost/someDir", "/someDir/"}, + {"http://someHost.com/someDir", "/someDir/"}, } func TestHostHandlers(t *testing.T) { + defer checkLeakedTransports(t) + mux := NewServeMux() for _, h := range handlers { - Handle(h.pattern, stringHandler(h.msg)) + mux.Handle(h.pattern, stringHandler(h.msg)) } - ts := httptest.NewServer(nil) + ts := httptest.NewServer(mux) defer ts.Close() conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -199,9 +214,19 @@ func TestHostHandlers(t *testing.T) { t.Errorf("reading response: %v", err) continue } - s := r.Header.Get("Result") - if s != vt.expected { - t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + switch r.StatusCode { + case StatusOK: + s := r.Header.Get("Result") + if s != vt.expected { + t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + } + case StatusMovedPermanently: + s := r.Header.Get("Location") + if s != vt.expected { + t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + } + default: + t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode) } } } @@ -232,28 +257,22 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { - // TODO(bradfitz): convert this to use httptest.Server - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen error: %v", err) - } - addr, _ := l.Addr().(*net.TCPAddr) - + defer checkLeakedTransports(t) reqNum := 0 - handler := HandlerFunc(func(res ResponseWriter, req *Request) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) - }) - - server := &Server{Handler: handler, ReadTimeout: 250 * time.Millisecond, WriteTimeout: 250 * time.Millisecond} - go server.Serve(l) - - url := fmt.Sprintf("http://%s/", addr) + })) + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.WriteTimeout = 250 * time.Millisecond + ts.Start() + defer ts.Close() // Hit the HTTP server successfully. tr := &Transport{DisableKeepAlives: true} // they interfere with this test + defer tr.CloseIdleConnections() c := &Client{Transport: tr} - r, err := c.Get(url) + r, err := c.Get(ts.URL) if err != nil { t.Fatalf("http Get #1: %v", err) } @@ -266,13 +285,13 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Now() - conn, err := net.Dial("tcp", addr.String()) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) } buf := make([]byte, 1) n, err := conn.Read(buf) - latency := time.Now().Sub(t1) + latency := time.Since(t1) if n != 0 || err != io.EOF { t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) } @@ -283,7 +302,7 @@ func TestServerTimeouts(t *testing.T) { // Hit the HTTP server successfully again, verifying that the // previous slow connection didn't run our handler. (that we // get "req=2", not "req=3") - r, err = Get(url) + r, err = Get(ts.URL) if err != nil { t.Fatalf("http Get #2: %v", err) } @@ -293,11 +312,87 @@ func TestServerTimeouts(t *testing.T) { t.Errorf("Get #2 got %q, want %q", string(got), expected) } - l.Close() + if !testing.Short() { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + go io.Copy(ioutil.Discard, conn) + for i := 0; i < 5; i++ { + _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("on write %d: %v", i, err) + } + time.Sleep(ts.Config.ReadTimeout / 2) + } + } +} + +// golang.org/issue/4741 -- setting only a write timeout that triggers +// shouldn't cause a handler to block forever on reads (next HTTP +// request) that will never happen. +func TestOnlyWriteTimeout(t *testing.T) { + defer checkLeakedTransports(t) + var conn net.Conn + var afterTimeoutErrc = make(chan error, 1) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + buf := make([]byte, 512<<10) + _, err := w.Write(buf) + if err != nil { + t.Errorf("handler Write error: %v", err) + return + } + conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) + _, err = w.Write(buf) + afterTimeoutErrc <- err + })) + ts.Listener = trackLastConnListener{ts.Listener, &conn} + ts.Start() + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + errc := make(chan error) + go func() { + res, err := c.Get(ts.URL) + if err != nil { + errc <- err + return + } + _, err = io.Copy(ioutil.Discard, res.Body) + errc <- err + }() + select { + case err := <-errc: + if err == nil { + t.Errorf("expected an error from Get request") + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for Get error") + } + if err := <-afterTimeoutErrc; err == nil { + t.Error("expected write error after timeout") + } +} + +// trackLastConnListener tracks the last net.Conn that was accepted. +type trackLastConnListener struct { + net.Listener + last *net.Conn // destination } -// TestIdentityResponse verifies that a handler can unset +func (l trackLastConnListener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + *l.last = c + return +} + +// TestIdentityResponse verifies that a handler can unset func TestIdentityResponse(t *testing.T) { + defer checkLeakedTransports(t) handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") rw.Header().Set("Transfer-Encoding", req.FormValue("te")) @@ -343,10 +438,12 @@ func TestIdentityResponse(t *testing.T) { // Verify that ErrContentLength is returned url := ts.URL + "/?overwrite=1" - _, err := Get(url) + res, err := Get(url) if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } + res.Body.Close() + // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -370,7 +467,8 @@ func TestIdentityResponse(t *testing.T) { }) } -func testTcpConnectionCloses(t *testing.T, req string, h Handler) { +func testTCPConnectionCloses(t *testing.T, req string, h Handler) { + defer checkLeakedTransports(t) s := httptest.NewServer(h) defer s.Close() @@ -386,17 +484,18 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) { } r := bufio.NewReader(conn) - _, err = ReadResponse(r, &Request{Method: "GET"}) + res, err := ReadResponse(r, &Request{Method: "GET"}) if err != nil { t.Fatal("ReadResponse error:", err) } - success := make(chan bool) + didReadAll := make(chan bool, 1) go func() { select { case <-time.After(5 * time.Second): - t.Fatal("body not closed after 5s") - case <-success: + t.Error("body not closed after 5s") + return + case <-didReadAll: } }() @@ -404,32 +503,43 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) { if err != nil { t.Fatal("read error:", err) } + didReadAll <- true - success <- true + if !res.Close { + t.Errorf("Response.Close = false; want true") + } } // TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. func TestServeHTTP10Close(t *testing.T) { - testTcpConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") })) } +// TestClientCanClose verifies that clients can also force a connection to close. +func TestClientCanClose(t *testing.T) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + // Nothing. + })) +} + // TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close, // even for HTTP/1.1 requests. func TestHandlersCanSetConnectionClose11(t *testing.T) { - testTcpConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") })) } func TestHandlersCanSetConnectionClose10(t *testing.T) { - testTcpConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") })) } func TestSetsRemoteAddr(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) @@ -450,11 +560,13 @@ func TestSetsRemoteAddr(t *testing.T) { } func TestChunkedResponseHeaders(t *testing.T) { + defer checkLeakedTransports(t) log.SetOutput(ioutil.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted + w.(Flusher).Flush() fmt.Fprintf(w, "I am a chunked response.") })) defer ts.Close() @@ -463,6 +575,7 @@ func TestChunkedResponseHeaders(t *testing.T) { if err != nil { t.Fatalf("Get error: %v", err) } + defer res.Body.Close() if g, e := res.ContentLength, int64(-1); g != e { t.Errorf("expected ContentLength of %d; got %d", e, g) } @@ -478,6 +591,7 @@ func TestChunkedResponseHeaders(t *testing.T) { // chunking in their response headers and aren't allowed to produce // output. func Test304Responses(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) _, err := w.Write([]byte("illegal body")) @@ -507,6 +621,7 @@ func Test304Responses(t *testing.T) { // allowed to produce output, and don't set a Content-Type since // the real type of the body data cannot be inferred. func TestHeadResponses(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("Ignored body")) if err != ErrBodyNotAllowed { @@ -541,6 +656,7 @@ func TestHeadResponses(t *testing.T) { } func TestTLSHandshakeTimeout(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) ts.Config.ReadTimeout = 250 * time.Millisecond ts.StartTLS() @@ -560,6 +676,7 @@ func TestTLSHandshakeTimeout(t *testing.T) { } func TestTLSServer(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { w.Header().Set("X-TLS-Set", "true") @@ -642,6 +759,7 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. func TestServerExpect(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST // requests that would read from r.Body, which we only @@ -661,30 +779,51 @@ func TestServerExpect(t *testing.T) { t.Fatalf("Dial: %v", err) } defer conn.Close() - sendf := func(format string, args ...interface{}) { - _, err := fmt.Fprintf(conn, format, args...) - if err != nil { - t.Fatalf("On test %#v, error writing %q: %v", test, format, err) - } - } + + // Only send the body immediately if we're acting like an HTTP client + // that doesn't send 100-continue expectations. + writeBody := test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" + go func() { - sendf("POST /?readbody=%v HTTP/1.1\r\n"+ + _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+ "Connection: close\r\n"+ "Content-Length: %d\r\n"+ "Expect: %s\r\nHost: foo\r\n\r\n", test.readBody, test.contentLength, test.expectation) - if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" { + if err != nil { + t.Errorf("On test %#v, error writing request headers: %v", test, err) + return + } + if writeBody { body := strings.Repeat("A", test.contentLength) - sendf(body) + _, err = fmt.Fprint(conn, body) + if err != nil { + if !test.readBody { + // Server likely already hung up on us. + // See larger comment below. + t.Logf("On test %#v, acceptable error writing request body: %v", test, err) + return + } + t.Errorf("On test %#v, error writing request body: %v", test, err) + } } }() bufr := bufio.NewReader(conn) line, err := bufr.ReadString('\n') if err != nil { - t.Fatalf("ReadString: %v", err) + if writeBody && !test.readBody { + // This is an acceptable failure due to a possible TCP race: + // We were still writing data and the server hung up on us. A TCP + // implementation may send a RST if our request body data was known + // to be lost, which may trigger our reads to fail. + // See RFC 1122 page 88. + t.Logf("On test %#v, acceptable error from ReadString: %v", test, err) + return + } + t.Fatalf("On test %#v, ReadString: %v", test, err) } if !strings.Contains(line, test.expectedResponse) { - t.Errorf("for test %#v got first line=%q", test, line) + t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse) } } @@ -714,6 +853,7 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) { t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len()) } rw.WriteHeader(200) + rw.(Flusher).Flush() if g, e := conn.readBuf.Len(), 0; g != e { t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) } @@ -736,27 +876,28 @@ func TestServerUnreadRequestBodyLarge(t *testing.T) { "Content-Length: %d\r\n"+ "\r\n", len(body)))) conn.readBuf.Write([]byte(body)) - - done := make(chan bool) + conn.closec = make(chan bool, 1) ls := &oneConnListener{conn} go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { - defer close(done) if conn.readBuf.Len() < len(body)/2 { t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) } rw.WriteHeader(200) + rw.(Flusher).Flush() if conn.readBuf.Len() < len(body)/2 { t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) } - if c := rw.Header().Get("Connection"); c != "close" { - t.Errorf(`Connection header = %q; want "close"`, c) - } })) - <-done + <-conn.closec + + if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") { + t.Errorf("Expected a Connection: close header; got response: %s", res) + } } func TestTimeoutHandler(t *testing.T) { + defer checkLeakedTransports(t) sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -831,6 +972,7 @@ func TestRedirectMunging(t *testing.T) { // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. func TestZeroLengthPostAndResponse(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := ioutil.ReadAll(r.Body) if err != nil { @@ -868,15 +1010,20 @@ func TestZeroLengthPostAndResponse(t *testing.T) { } } +func TestHandlerPanicNil(t *testing.T) { + testHandlerPanic(t, false, nil) +} + func TestHandlerPanic(t *testing.T) { - testHandlerPanic(t, false) + testHandlerPanic(t, false, "intentional death for testing") } func TestHandlerPanicWithHijack(t *testing.T) { - testHandlerPanic(t, true) + testHandlerPanic(t, true, "intentional death for testing") } -func testHandlerPanic(t *testing.T, withHijack bool) { +func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { + defer checkLeakedTransports(t) // Unlike the other tests that set the log output to ioutil.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: @@ -896,6 +1043,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) { pr, pw := io.Pipe() log.SetOutput(pw) defer log.SetOutput(os.Stderr) + defer pw.Close() ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if withHijack { @@ -905,7 +1053,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) { } defer rwc.Close() } - panic("intentional death for testing") + panic(panicValue) })) defer ts.Close() @@ -917,8 +1065,8 @@ func testHandlerPanic(t *testing.T, withHijack bool) { buf := make([]byte, 4<<10) _, err := pr.Read(buf) pr.Close() - if err != nil { - t.Fatal(err) + if err != nil && err != io.EOF { + t.Error(err) } done <- true }() @@ -928,6 +1076,10 @@ func testHandlerPanic(t *testing.T, withHijack bool) { t.Logf("expected an error") } + if panicValue == nil { + return + } + select { case <-done: return @@ -937,6 +1089,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) { } func TestNoDate(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()["Date"] = nil })) @@ -952,6 +1105,7 @@ func TestNoDate(t *testing.T) { } func TestStripPrefix(t *testing.T) { + defer checkLeakedTransports(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) }) @@ -965,6 +1119,7 @@ func TestStripPrefix(t *testing.T) { if g, e := res.Header.Get("X-Path"), "/bar"; g != e { t.Errorf("test 1: got %s, want %s", g, e) } + res.Body.Close() res, err = Get(ts.URL + "/bar") if err != nil { @@ -973,9 +1128,11 @@ func TestStripPrefix(t *testing.T) { if g, e := res.StatusCode, 404; g != e { t.Errorf("test 2: got status %v, want %v", g, e) } + res.Body.Close() } func TestRequestLimit(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") })) @@ -992,6 +1149,7 @@ func TestRequestLimit(t *testing.T) { // we do support it (at least currently), so we expect a response below. t.Fatalf("Do: %v", err) } + defer res.Body.Close() if res.StatusCode != 413 { t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status) } @@ -1013,11 +1171,12 @@ type countReader struct { func (cr countReader) Read(p []byte) (n int, err error) { n, err = cr.r.Read(p) - *cr.n += int64(n) + atomic.AddInt64(cr.n, int64(n)) return } func TestRequestBodyLimit(t *testing.T) { + defer checkLeakedTransports(t) const limit = 1 << 20 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) @@ -1031,8 +1190,8 @@ func TestRequestBodyLimit(t *testing.T) { })) defer ts.Close() - nWritten := int64(0) - req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), &nWritten}, limit*200)) + nWritten := new(int64) + req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) // Send the POST, but don't care it succeeds or not. The // remote side is going to reply and then close the TCP @@ -1045,7 +1204,7 @@ func TestRequestBodyLimit(t *testing.T) { // the remote side hung up on us before we wrote too much. _, _ = DefaultClient.Do(req) - if nWritten > limit*100 { + if atomic.LoadInt64(nWritten) > limit*100 { t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d", limit, nWritten) } @@ -1054,6 +1213,7 @@ func TestRequestBodyLimit(t *testing.T) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. func TestClientWriteShutdown(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -1086,28 +1246,207 @@ func TestClientWriteShutdown(t *testing.T) { // Tests that chunked server responses that write 1 byte at a time are // buffered before chunk headers are added, not after chunk headers. func TestServerBufferedChunking(t *testing.T) { - if true { - t.Logf("Skipping known broken test; see Issue 2357") - return - } conn := new(testConn) conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n")) - done := make(chan bool) + conn.closec = make(chan bool, 1) ls := &oneConnListener{conn} go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { - defer close(done) - rw.Header().Set("Content-Type", "text/plain") // prevent sniffing, which buffers + rw.(Flusher).Flush() // force the Header to be sent, in chunking mode, not counting the length rw.Write([]byte{'x'}) rw.Write([]byte{'y'}) rw.Write([]byte{'z'}) })) - <-done + <-conn.closec if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) { t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q", conn.writeBuf.Bytes()) } } +// Tests that the server flushes its response headers out when it's +// ignoring the response body and waits a bit before forcefully +// closing the TCP connection, causing the client to get a RST. +// See http://golang.org/issue/3595 +func TestServerGracefulClose(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, "bye", StatusUnauthorized) + })) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + const bodySize = 5 << 20 + req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize)) + for i := 0; i < bodySize; i++ { + req = append(req, 'x') + } + writeErr := make(chan error) + go func() { + _, err := conn.Write(req) + writeErr <- err + }() + br := bufio.NewReader(conn) + lineNum := 0 + for { + line, err := br.ReadString('\n') + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("ReadLine: %v", err) + } + lineNum++ + if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") { + t.Errorf("Response line = %q; want a 401", line) + } + } + // Wait for write to finish. This is a broken pipe on both + // Darwin and Linux, but checking this isn't the point of + // the test. + <-writeErr +} + +func TestCaseSensitiveMethod(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "get" { + t.Errorf(`Got method %q; want "get"`, r.Method) + } + })) + defer ts.Close() + req, _ := NewRequest("get", ts.URL, nil) + res, err := DefaultClient.Do(req) + if err != nil { + t.Error(err) + return + } + res.Body.Close() +} + +// TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1 +// request (both keep-alive), when a Handler never writes any +// response, the net/http package adds a "Content-Length: 0" response +// header. +func TestContentLengthZero(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {})) + defer ts.Close() + + for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version) + if err != nil { + t.Fatalf("error writing: %v", err) + } + req, _ := NewRequest("GET", "/", nil) + res, err := ReadResponse(bufio.NewReader(conn), req) + if err != nil { + t.Fatalf("error reading response: %v", err) + } + if te := res.TransferEncoding; len(te) > 0 { + t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te) + } + if cl := res.ContentLength; cl != 0 { + t.Errorf("For version %q, Content-Length = %v; want 0", version, cl) + } + conn.Close() + } +} + +func TestCloseNotifier(t *testing.T) { + gotReq := make(chan bool, 1) + sawClose := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + gotReq <- true + cc := rw.(CloseNotifier).CloseNotify() + <-cc + sawClose <- true + })) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + diec := make(chan bool) + go func() { + _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n") + if err != nil { + t.Fatal(err) + } + <-diec + conn.Close() + }() +For: + for { + select { + case <-gotReq: + diec <- true + case <-sawClose: + break For + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + } + ts.Close() +} + +func TestOptions(t *testing.T) { + uric := make(chan string, 2) // only expect 1, but leave space for 2 + mux := NewServeMux() + mux.HandleFunc("/", func(w ResponseWriter, r *Request) { + uric <- r.RequestURI + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // An OPTIONS * request should succeed. + _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n")) + if err != nil { + t.Fatal(err) + } + br := bufio.NewReader(conn) + res, err := ReadResponse(br, &Request{Method: "OPTIONS"}) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 200 { + t.Errorf("Got non-200 response to OPTIONS *: %#v", res) + } + + // A GET * request on a ServeMux should fail. + _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n")) + if err != nil { + t.Fatal(err) + } + res, err = ReadResponse(br, &Request{Method: "GET"}) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 400 { + t.Errorf("Got non-400 response to GET *: %#v", res) + } + + res, err = Get(ts.URL + "/second") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if got := <-uric; got != "/second" { + t.Errorf("Handler saw request for %q; want /second", got) + } +} + // goTimeout runs f, failing t if f takes more than ns to complete. func goTimeout(t *testing.T, d time.Duration, f func()) { ch := make(chan bool, 2) @@ -1184,3 +1523,100 @@ func BenchmarkClientServer(b *testing.B) { b.StopTimer() } + +func BenchmarkClientServerParallel4(b *testing.B) { + benchmarkClientServerParallel(b, 4) +} + +func BenchmarkClientServerParallel64(b *testing.B) { + benchmarkClientServerParallel(b, 64) +} + +func benchmarkClientServerParallel(b *testing.B, conc int) { + b.StopTimer() + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + fmt.Fprintf(rw, "Hello world.\n") + })) + defer ts.Close() + b.StartTimer() + + numProcs := runtime.GOMAXPROCS(-1) * conc + var wg sync.WaitGroup + wg.Add(numProcs) + n := int32(b.N) + for p := 0; p < numProcs; p++ { + go func() { + for atomic.AddInt32(&n, -1) >= 0 { + res, err := Get(ts.URL) + if err != nil { + b.Logf("Get: %v", err) + continue + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + b.Logf("ReadAll: %v", err) + continue + } + body := string(all) + if body != "Hello world.\n" { + panic("Got body: " + body) + } + } + wg.Done() + }() + } + wg.Wait() +} + +// A benchmark for profiling the server without the HTTP client code. +// The client code runs in a subprocess. +// +// For use like: +// $ go test -c +// $ ./http.test -test.run=XX -test.bench=BenchmarkServer -test.benchtime=15s -test.cpuprofile=http.prof +// $ go tool pprof http.test http.prof +// (pprof) web +func BenchmarkServer(b *testing.B) { + // Child process mode; + if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" { + n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N")) + if err != nil { + panic(err) + } + for i := 0; i < n; i++ { + res, err := Get(url) + if err != nil { + log.Panicf("Get: %v", err) + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + log.Panicf("ReadAll: %v", err) + } + body := string(all) + if body != "Hello world.\n" { + log.Panicf("Got body: %q", body) + } + } + os.Exit(0) + return + } + + var res = []byte("Hello world.\n") + b.StopTimer() + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + })) + defer ts.Close() + b.StartTimer() + + cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer") + cmd.Env = append([]string{ + fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N), + fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL), + }, os.Environ()...) + out, err := cmd.CombinedOutput() + if err != nil { + b.Errorf("Test failure: %v, with output: %s", err, out) + } +} diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go index 0572b4ae3..b6ab78228 100644 --- a/src/pkg/net/http/server.go +++ b/src/pkg/net/http/server.go @@ -11,7 +11,6 @@ package http import ( "bufio" - "bytes" "crypto/tls" "errors" "fmt" @@ -21,7 +20,7 @@ import ( "net" "net/url" "path" - "runtime/debug" + "runtime" "strconv" "strings" "sync" @@ -94,30 +93,188 @@ type Hijacker interface { Hijack() (net.Conn, *bufio.ReadWriter, error) } +// The CloseNotifier interface is implemented by ResponseWriters which +// allow detecting when the underlying connection has gone away. +// +// This mechanism can be used to cancel long operations on the server +// if the client has disconnected before the response is ready. +type CloseNotifier interface { + // CloseNotify returns a channel that receives a single value + // when the client connection has gone away. + CloseNotify() <-chan bool +} + // A conn represents the server side of an HTTP connection. type conn struct { remoteAddr string // network address of remote side server *Server // the Server on which the connection arrived rwc net.Conn // i/o connection - lr *io.LimitedReader // io.LimitReader(rwc) - buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->rwc - hijacked bool // connection has been hijacked by handler + sr switchReader // where the LimitReader reads from; usually the rwc + lr *io.LimitedReader // io.LimitReader(sr) + buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc tlsState *tls.ConnectionState // or nil when not using TLS - body []byte + + mu sync.Mutex // guards the following + clientGone bool // if client has disconnected mid-request + closeNotifyc chan bool // made lazily + hijackedv bool // connection has been hijacked by handler +} + +func (c *conn) hijacked() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.hijackedv +} + +func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.hijackedv { + return nil, nil, ErrHijacked + } + if c.closeNotifyc != nil { + return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier") + } + c.hijackedv = true + rwc = c.rwc + buf = c.buf + c.rwc = nil + c.buf = nil + return +} + +func (c *conn) closeNotify() <-chan bool { + c.mu.Lock() + defer c.mu.Unlock() + if c.closeNotifyc == nil { + c.closeNotifyc = make(chan bool) + if c.hijackedv { + // to obey the function signature, even though + // it'll never receive a value. + return c.closeNotifyc + } + pr, pw := io.Pipe() + + readSource := c.sr.r + c.sr.Lock() + c.sr.r = pr + c.sr.Unlock() + go func() { + _, err := io.Copy(pw, readSource) + if err == nil { + err = io.EOF + } + pw.CloseWithError(err) + c.noteClientGone() + }() + } + return c.closeNotifyc +} + +func (c *conn) noteClientGone() { + c.mu.Lock() + defer c.mu.Unlock() + if c.closeNotifyc != nil && !c.clientGone { + c.closeNotifyc <- true + } + c.clientGone = true +} + +type switchReader struct { + sync.Mutex + r io.Reader +} + +func (sr *switchReader) Read(p []byte) (n int, err error) { + sr.Lock() + r := sr.r + sr.Unlock() + return r.Read(p) +} + +// This should be >= 512 bytes for DetectContentType, +// but otherwise it's somewhat arbitrary. +const bufferBeforeChunkingSize = 2048 + +// chunkWriter writes to a response's conn buffer, and is the writer +// wrapped by the response.bufw buffered writer. +// +// chunkWriter also is responsible for finalizing the Header, including +// conditionally setting the Content-Type and setting a Content-Length +// in cases where the handler's final output is smaller than the buffer +// size. It also conditionally adds chunk headers, when in chunking mode. +// +// See the comment above (*response).Write for the entire write flow. +type chunkWriter struct { + res *response + header Header // a deep copy of r.Header, once WriteHeader is called + wroteHeader bool // whether the header's been sent + + // set by the writeHeader method: + chunking bool // using chunked transfer encoding for reply body +} + +var crlf = []byte("\r\n") + +func (cw *chunkWriter) Write(p []byte) (n int, err error) { + if !cw.wroteHeader { + cw.writeHeader(p) + } + if cw.chunking { + _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) + if err != nil { + cw.res.conn.rwc.Close() + return + } + } + n, err = cw.res.conn.buf.Write(p) + if cw.chunking && err == nil { + _, err = cw.res.conn.buf.Write(crlf) + } + if err != nil { + cw.res.conn.rwc.Close() + } + return +} + +func (cw *chunkWriter) flush() { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + cw.res.conn.buf.Flush() +} + +func (cw *chunkWriter) close() { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + if cw.chunking { + // zero EOF chunk, trailer key/value pairs (currently + // unsupported in Go's server), followed by a blank + // line. + io.WriteString(cw.res.conn.buf, "0\r\n\r\n") + } } // A response represents the server side of an HTTP response. type response struct { conn *conn req *Request // request for this response - chunking bool // using chunked transfer encoding for reply body - wroteHeader bool // reply header has been written + wroteHeader bool // reply header has been (logically) written wroteContinue bool // 100 Continue response was written - header Header // reply header parameters - written int64 // number of bytes written in body - contentLength int64 // explicitly-declared Content-Length; or -1 - status int // status code passed to WriteHeader - needSniff bool // need to sniff to find Content-Type + + w *bufio.Writer // buffers output in chunks to chunkWriter + cw *chunkWriter + + // handlerHeader is the Header that Handlers get access to, + // which may be retained and mutated even after WriteHeader. + // handlerHeader is copied into cw.header at WriteHeader + // time, and privately mutated thereafter. + handlerHeader Header + + written int64 // number of bytes written in body + contentLength int64 // explicitly-declared Content-Length; or -1 + status int // status code passed to WriteHeader // close connection after this reply. set on request and // updated after response from handler if there's a @@ -127,12 +284,14 @@ type response struct { // requestBodyLimitHit is set by requestTooLarge when // maxBytesReader hits its max size. It is checked in - // WriteHeader, to make sure we don't consume the the + // WriteHeader, to make sure we don't consume the // remaining request body to try to advance to the next HTTP - // request. Instead, when this is set, we stop doing + // request. Instead, when this is set, we stop reading // subsequent requests on this connection and stop reading // input from it. requestBodyLimitHit bool + + handlerDone bool // set true when the handler exits } // requestTooLarge is called by maxBytesReader when too much input has @@ -145,42 +304,68 @@ func (w *response) requestTooLarge() { } } +// needsSniff returns whether a Content-Type still needs to be sniffed. +func (w *response) needsSniff() bool { + return !w.cw.wroteHeader && w.handlerHeader.Get("Content-Type") == "" && w.written < sniffLen +} + type writerOnly struct { io.Writer } func (w *response) ReadFrom(src io.Reader) (n int64, err error) { - // Call WriteHeader before checking w.chunking if it hasn't - // been called yet, since WriteHeader is what sets w.chunking. if !w.wroteHeader { w.WriteHeader(StatusOK) } - if !w.chunking && w.bodyAllowed() && !w.needSniff { - w.Flush() + + if w.needsSniff() { + n0, err := io.Copy(writerOnly{w}, io.LimitReader(src, sniffLen)) + n += n0 + if err != nil { + return n, err + } + } + + w.w.Flush() // get rid of any previous writes + w.cw.flush() // make sure Header is written; flush data to rwc + + // Now that cw has been flushed, its chunking field is guaranteed initialized. + if !w.cw.chunking && w.bodyAllowed() { if rf, ok := w.conn.rwc.(io.ReaderFrom); ok { - n, err = rf.ReadFrom(src) - w.written += n - return + n0, err := rf.ReadFrom(src) + n += n0 + w.written += n0 + return n, err } } + // Fall back to default io.Copy implementation. // Use wrapper to hide w.ReadFrom from io.Copy. - return io.Copy(writerOnly{w}, src) + n0, err := io.Copy(writerOnly{w}, src) + n += n0 + return n, err } // noLimit is an effective infinite upper bound for io.LimitedReader const noLimit int64 = (1 << 63) - 1 +// debugServerConnections controls whether all server connections are wrapped +// with a verbose logging wrapper. +const debugServerConnections = false + // Create new connection from rwc. func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { c = new(conn) c.remoteAddr = rwc.RemoteAddr().String() c.server = srv c.rwc = rwc - c.body = make([]byte, sniffLen) - c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader) + if debugServerConnections { + c.rwc = newLoggingConn("server", c.rwc) + } + c.sr = switchReader{r: c.rwc} + c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) br := bufio.NewReader(c.lr) - bw := bufio.NewWriter(rwc) + bw := bufio.NewWriter(c.rwc) c.buf = bufio.NewReadWriter(br, bw) return c, nil } @@ -207,9 +392,9 @@ type expectContinueReader struct { func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { if ecr.closed { - return 0, errors.New("http: Read after Close on request Body") + return 0, ErrBodyReadAfterClose } - if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked { + if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() { ecr.resp.wroteContinue = true io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n") ecr.resp.conn.buf.Flush() @@ -232,9 +417,19 @@ var errTooLarge = errors.New("http: request too large") // Read next request from connection. func (c *conn) readRequest() (w *response, err error) { - if c.hijacked { + if c.hijacked() { return nil, ErrHijacked } + + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + defer func() { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + }() + } + c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ var req *Request if req, err = ReadRequest(c.buf.Reader); err != nil { @@ -248,17 +443,20 @@ func (c *conn) readRequest() (w *response, err error) { req.RemoteAddr = c.remoteAddr req.TLS = c.tlsState - w = new(response) - w.conn = c - w.req = req - w.header = make(Header) - w.contentLength = -1 - c.body = c.body[:0] + w = &response{ + conn: c, + req: req, + handlerHeader: make(Header), + contentLength: -1, + cw: new(chunkWriter), + } + w.cw.res = w + w.w = bufio.NewWriterSize(w.cw, bufferBeforeChunkingSize) return w, nil } func (w *response) Header() Header { - return w.header + return w.handlerHeader } // maxPostHandlerReadBytes is the max number of Request.Body bytes not @@ -273,7 +471,7 @@ func (w *response) Header() Header { const maxPostHandlerReadBytes = 256 << 10 func (w *response) WriteHeader(code int) { - if w.conn.hijacked { + if w.conn.hijacked() { log.Print("http: response.WriteHeader on hijacked connection") return } @@ -284,31 +482,68 @@ func (w *response) WriteHeader(code int) { w.wroteHeader = true w.status = code - // Check for a explicit (and valid) Content-Length header. - var hasCL bool - var contentLength int64 - if clenStr := w.header.Get("Content-Length"); clenStr != "" { - var err error - contentLength, err = strconv.ParseInt(clenStr, 10, 64) - if err == nil { - hasCL = true + w.cw.header = w.handlerHeader.clone() + + if cl := w.cw.header.get("Content-Length"); cl != "" { + v, err := strconv.ParseInt(cl, 10, 64) + if err == nil && v >= 0 { + w.contentLength = v } else { - log.Printf("http: invalid Content-Length of %q sent", clenStr) - w.header.Del("Content-Length") + log.Printf("http: invalid Content-Length of %q", cl) + w.cw.header.Del("Content-Length") + } + } +} + +// writeHeader finalizes the header sent to the client and writes it +// to cw.res.conn.buf. +// +// p is not written by writeHeader, but is the first chunk of the body +// that will be written. It is sniffed for a Content-Type if none is +// set explicitly. It's also used to set the Content-Length, if the +// total body size was small and the handler has already finished +// running. +func (cw *chunkWriter) writeHeader(p []byte) { + if cw.wroteHeader { + return + } + cw.wroteHeader = true + + w := cw.res + code := w.status + done := w.handlerDone + + // If the handler is done but never sent a Content-Length + // response header and this is our first (and last) write, set + // it, even to zero. This helps HTTP/1.0 clients keep their + // "keep-alive" connections alive. + if done && cw.header.get("Content-Length") == "" && w.req.Method != "HEAD" { + w.contentLength = int64(len(p)) + cw.header.Set("Content-Length", strconv.Itoa(len(p))) + } + + // If this was an HTTP/1.0 request with keep-alive and we sent a + // Content-Length back, we can make this a keep-alive response ... + if w.req.wantsHttp10KeepAlive() { + sentLength := cw.header.get("Content-Length") != "" + if sentLength && cw.header.get("Connection") == "keep-alive" { + w.closeAfterReply = false } } + // Check for a explicit (and valid) Content-Length header. + hasCL := w.contentLength != -1 + if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { - _, connectionHeaderSet := w.header["Connection"] + _, connectionHeaderSet := cw.header["Connection"] if !connectionHeaderSet { - w.header.Set("Connection", "keep-alive") + cw.header.Set("Connection", "keep-alive") } - } else if !w.req.ProtoAtLeast(1, 1) { - // Client did not ask to keep connection alive. + } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() { w.closeAfterReply = true } - if w.header.Get("Connection") == "close" { + if cw.header.get("Connection") == "close" { w.closeAfterReply = true } @@ -322,7 +557,7 @@ func (w *response) WriteHeader(code int) { n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1) if n >= maxPostHandlerReadBytes { w.requestTooLarge() - w.header.Set("Connection", "close") + cw.header.Set("Connection", "close") } else { w.req.Body.Close() } @@ -332,64 +567,67 @@ func (w *response) WriteHeader(code int) { if code == StatusNotModified { // Must not have body. for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { - if w.header.Get(header) != "" { - // TODO: return an error if WriteHeader gets a return parameter - // or set a flag on w to make future Writes() write an error page? - // for now just log and drop the header. - log.Printf("http: StatusNotModified response with header %q defined", header) - w.header.Del(header) + // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" + if cw.header.get(header) != "" { + cw.header.Del(header) } } } else { // If no content type, apply sniffing algorithm to body. - if w.header.Get("Content-Type") == "" && w.req.Method != "HEAD" { - w.needSniff = true + if cw.header.get("Content-Type") == "" && w.req.Method != "HEAD" { + cw.header.Set("Content-Type", DetectContentType(p)) } } - if _, ok := w.header["Date"]; !ok { - w.Header().Set("Date", time.Now().UTC().Format(TimeFormat)) + if _, ok := cw.header["Date"]; !ok { + cw.header.Set("Date", time.Now().UTC().Format(TimeFormat)) } - te := w.header.Get("Transfer-Encoding") + te := cw.header.get("Transfer-Encoding") hasTE := te != "" if hasCL && hasTE && te != "identity" { // TODO: return an error if WriteHeader gets a return parameter // For now just ignore the Content-Length. log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", - te, contentLength) - w.header.Del("Content-Length") + te, w.contentLength) + cw.header.Del("Content-Length") hasCL = false } if w.req.Method == "HEAD" || code == StatusNotModified { // do nothing + } else if code == StatusNoContent { + cw.header.Del("Transfer-Encoding") } else if hasCL { - w.contentLength = contentLength - w.header.Del("Transfer-Encoding") + cw.header.Del("Transfer-Encoding") } else if w.req.ProtoAtLeast(1, 1) { // HTTP/1.1 or greater: use chunked transfer encoding // to avoid closing the connection at EOF. // TODO: this blows away any custom or stacked Transfer-Encoding they // might have set. Deal with that as need arises once we have a valid // use case. - w.chunking = true - w.header.Set("Transfer-Encoding", "chunked") + cw.chunking = true + cw.header.Set("Transfer-Encoding", "chunked") } else { // HTTP version < 1.1: cannot do chunked transfer // encoding and we don't know the Content-Length so // signal EOF by closing connection. w.closeAfterReply = true - w.header.Del("Transfer-Encoding") // in case already set + cw.header.Del("Transfer-Encoding") // in case already set } // Cannot use Content-Length with non-identity Transfer-Encoding. - if w.chunking { - w.header.Del("Content-Length") + if cw.chunking { + cw.header.Del("Content-Length") } if !w.req.ProtoAtLeast(1, 0) { return } + + if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") { + cw.header.Set("Connection", "close") + } + proto := "HTTP/1.0" if w.req.ProtoAtLeast(1, 1) { proto = "HTTP/1.1" @@ -400,37 +638,8 @@ func (w *response) WriteHeader(code int) { text = "status code " + codestring } io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n") - w.header.Write(w.conn.buf) - - // If we need to sniff the body, leave the header open. - // Otherwise, end it here. - if !w.needSniff { - io.WriteString(w.conn.buf, "\r\n") - } -} - -// sniff uses the first block of written data, -// stored in w.conn.body, to decide the Content-Type -// for the HTTP body. -func (w *response) sniff() { - if !w.needSniff { - return - } - w.needSniff = false - - data := w.conn.body - fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n\r\n", DetectContentType(data)) - - if len(data) == 0 { - return - } - if w.chunking { - fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) - } - _, err := w.conn.buf.Write(data) - if w.chunking && err == nil { - io.WriteString(w.conn.buf, "\r\n") - } + cw.header.Write(w.conn.buf) + w.conn.buf.Write(crlf) } // bodyAllowed returns true if a Write is allowed for this response type. @@ -442,8 +651,40 @@ func (w *response) bodyAllowed() bool { return w.status != StatusNotModified && w.req.Method != "HEAD" } +// The Life Of A Write is like this: +// +// Handler starts. No header has been sent. The handler can either +// write a header, or just start writing. Writing before sending a header +// sends an implicity empty 200 OK header. +// +// If the handler didn't declare a Content-Length up front, we either +// go into chunking mode or, if the handler finishes running before +// the chunking buffer size, we compute a Content-Length and send that +// in the header instead. +// +// Likewise, if the handler didn't set a Content-Type, we sniff that +// from the initial chunk of output. +// +// The Writers are wired together like: +// +// 1. *response (the ResponseWriter) -> +// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes +// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type) +// and which writes the chunk headers, if needed. +// 4. conn.buf, a bufio.Writer of default (4kB) bytes +// 5. the rwc, the net.Conn. +// +// TODO(bradfitz): short-circuit some of the buffering when the +// initial header contains both a Content-Type and Content-Length. +// Also short-circuit in (1) when the header's been sent and not in +// chunking mode, writing directly to (4) instead, if (2) has no +// buffered data. More generally, we could short-circuit from (1) to +// (3) even in chunking mode if the write size from (1) is over some +// threshold and nothing is in (2). The answer might be mostly making +// bufferBeforeChunkingSize smaller and having bufio's fast-paths deal +// with this instead. func (w *response) Write(data []byte) (n int, err error) { - if w.conn.hijacked { + if w.conn.hijacked() { log.Print("http: response.Write on hijacked connection") return 0, ErrHijacked } @@ -461,73 +702,20 @@ func (w *response) Write(data []byte) (n int, err error) { if w.contentLength != -1 && w.written > w.contentLength { return 0, ErrContentLength } - - var m int - if w.needSniff { - // We need to sniff the beginning of the output to - // determine the content type. Accumulate the - // initial writes in w.conn.body. - // Cap m so that append won't allocate. - m = cap(w.conn.body) - len(w.conn.body) - if m > len(data) { - m = len(data) - } - w.conn.body = append(w.conn.body, data[:m]...) - data = data[m:] - if len(data) == 0 { - // Copied everything into the buffer. - // Wait for next write. - return m, nil - } - - // Filled the buffer; more data remains. - // Sniff the content (flushes the buffer) - // and then proceed with the remainder - // of the data as a normal Write. - // Calling sniff clears needSniff. - w.sniff() - } - - // TODO(rsc): if chunking happened after the buffering, - // then there would be fewer chunk headers. - // On the other hand, it would make hijacking more difficult. - if w.chunking { - fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) // TODO(rsc): use strconv not fmt - } - n, err = w.conn.buf.Write(data) - if err == nil && w.chunking { - if n != len(data) { - err = io.ErrShortWrite - } - if err == nil { - io.WriteString(w.conn.buf, "\r\n") - } - } - - return m + n, err + return w.w.Write(data) } func (w *response) finishRequest() { - // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length - // back, we can make this a keep-alive response ... - if w.req.wantsHttp10KeepAlive() { - sentLength := w.header.Get("Content-Length") != "" - if sentLength && w.header.Get("Connection") == "keep-alive" { - w.closeAfterReply = false - } - } + w.handlerDone = true + if !w.wroteHeader { w.WriteHeader(StatusOK) } - if w.needSniff { - w.sniff() - } - if w.chunking { - io.WriteString(w.conn.buf, "0\r\n") - // trailer key/value pairs, followed by blank line - io.WriteString(w.conn.buf, "\r\n") - } + + w.w.Flush() + w.cw.close() w.conn.buf.Flush() + // Close the body, unless we're about to close the whole TCP connection // anyway. if !w.closeAfterReply { @@ -537,7 +725,7 @@ func (w *response) finishRequest() { w.req.MultipartForm.RemoveAll() } - if w.contentLength != -1 && w.contentLength != w.written { + if w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written { // Did not write enough. Avoid getting out of sync. w.closeAfterReply = true } @@ -547,66 +735,114 @@ func (w *response) Flush() { if !w.wroteHeader { w.WriteHeader(StatusOK) } - w.sniff() - w.conn.buf.Flush() + w.w.Flush() + w.cw.flush() } -// Close the connection. -func (c *conn) close() { +func (c *conn) finalFlush() { if c.buf != nil { c.buf.Flush() c.buf = nil } +} + +// Close the connection. +func (c *conn) close() { + c.finalFlush() if c.rwc != nil { c.rwc.Close() c.rwc = nil } } +// rstAvoidanceDelay is the amount of time we sleep after closing the +// write side of a TCP connection before closing the entire socket. +// By sleeping, we increase the chances that the client sees our FIN +// and processes its final data before they process the subsequent RST +// from closing a connection with known unread data. +// This RST seems to occur mostly on BSD systems. (And Windows?) +// This timeout is somewhat arbitrary (~latency around the planet). +const rstAvoidanceDelay = 500 * time.Millisecond + +// closeWrite flushes any outstanding data and sends a FIN packet (if +// client is connected via TCP), signalling that we're done. We then +// pause for a bit, hoping the client processes it before `any +// subsequent RST. +// +// See http://golang.org/issue/3595 +func (c *conn) closeWriteAndWait() { + c.finalFlush() + if tcp, ok := c.rwc.(*net.TCPConn); ok { + tcp.CloseWrite() + } + time.Sleep(rstAvoidanceDelay) +} + +// validNPN returns whether the proto is not a blacklisted Next +// Protocol Negotiation protocol. Empty and built-in protocol types +// are blacklisted and can't be overridden with alternate +// implementations. +func validNPN(proto string) bool { + switch proto { + case "", "http/1.1", "http/1.0": + return false + } + return true +} + // Serve a new connection. func (c *conn) serve() { defer func() { - err := recover() - if err == nil { - return + if err := recover(); err != nil { + const size = 4096 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) } - - var buf bytes.Buffer - fmt.Fprintf(&buf, "http: panic serving %v: %v\n", c.remoteAddr, err) - buf.Write(debug.Stack()) - log.Print(buf.String()) - - if c.rwc != nil { // may be nil if connection hijacked - c.rwc.Close() + if !c.hijacked() { + c.close() } }() if tlsConn, ok := c.rwc.(*tls.Conn); ok { + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + } if err := tlsConn.Handshake(); err != nil { - c.close() return } c.tlsState = new(tls.ConnectionState) *c.tlsState = tlsConn.ConnectionState() + if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) { + if fn := c.server.TLSNextProto[proto]; fn != nil { + h := initNPNRequest{tlsConn, serverHandler{c.server}} + fn(c.server, tlsConn, h) + } + return + } } for { w, err := c.readRequest() if err != nil { - msg := "400 Bad Request" if err == errTooLarge { // Their HTTP client may or may not be // able to read this if we're // responding to them and hanging up // while they're still writing their // request. Undefined behavior. - msg = "413 Request Entity Too Large" + io.WriteString(c.rwc, "HTTP/1.1 413 Request Entity Too Large\r\n\r\n") + c.closeWriteAndWait() + break } else if err == io.EOF { break // Don't reply } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() { break // Don't reply } - fmt.Fprintf(c.rwc, "HTTP/1.1 %s\r\n\r\n", msg) + io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\n\r\n") break } @@ -624,59 +860,59 @@ func (c *conn) serve() { break } req.Header.Del("Expect") - } else if req.Header.Get("Expect") != "" { - // TODO(bradfitz): let ServeHTTP handlers handle - // requests with non-standard expectation[s]? Seems - // theoretical at best, and doesn't fit into the - // current ServeHTTP model anyway. We'd need to - // make the ResponseWriter an optional - // "ExpectReplier" interface or something. - // - // For now we'll just obey RFC 2616 14.20 which says - // "If a server receives a request containing an - // Expect field that includes an expectation- - // extension that it does not support, it MUST - // respond with a 417 (Expectation Failed) status." - w.Header().Set("Connection", "close") - w.WriteHeader(StatusExpectationFailed) - w.finishRequest() + } else if req.Header.get("Expect") != "" { + w.sendExpectationFailed() break } - handler := c.server.Handler - if handler == nil { - handler = DefaultServeMux - } - // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. - handler.ServeHTTP(w, w.req) - if c.hijacked { + serverHandler{c.server}.ServeHTTP(w, w.req) + if c.hijacked() { return } w.finishRequest() if w.closeAfterReply { + if w.requestBodyLimitHit { + c.closeWriteAndWait() + } break } } - c.close() +} + +func (w *response) sendExpectationFailed() { + // TODO(bradfitz): let ServeHTTP handlers handle + // requests with non-standard expectation[s]? Seems + // theoretical at best, and doesn't fit into the + // current ServeHTTP model anyway. We'd need to + // make the ResponseWriter an optional + // "ExpectReplier" interface or something. + // + // For now we'll just obey RFC 2616 14.20 which says + // "If a server receives a request containing an + // Expect field that includes an expectation- + // extension that it does not support, it MUST + // respond with a 417 (Expectation Failed) status." + w.Header().Set("Connection", "close") + w.WriteHeader(StatusExpectationFailed) + w.finishRequest() } // Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter // and a Hijacker. func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { - if w.conn.hijacked { - return nil, nil, ErrHijacked + if w.wroteHeader { + w.cw.flush() } - w.conn.hijacked = true - rwc = w.conn.rwc - buf = w.conn.buf - w.conn.rwc = nil - w.conn.buf = nil - return + return w.conn.hijack() +} + +func (w *response) CloseNotify() <-chan bool { + return w.conn.closeNotify() } // The HandlerFunc type is an adapter to allow the use of @@ -817,13 +1053,13 @@ func RedirectHandler(url string, code int) Handler { // patterns and calls the handler for the pattern that // most closely matches the URL. // -// Patterns named fixed, rooted paths, like "/favicon.ico", +// Patterns name fixed, rooted paths, like "/favicon.ico", // or rooted subtrees, like "/images/" (note the trailing slash). // Longer patterns take precedence over shorter ones, so that // if there are handlers registered for both "/images/" // and "/images/thumbnails/", the latter handler will be // called for paths beginning "/images/thumbnails/" and the -// former will receiver requests for any other paths in the +// former will receive requests for any other paths in the // "/images/" subtree. // // Patterns may optionally begin with a host name, restricting matches to @@ -836,13 +1072,15 @@ func RedirectHandler(url string, code int) Handler { // redirecting any request containing . or .. elements to an // equivalent .- and ..-free URL. type ServeMux struct { - mu sync.RWMutex - m map[string]muxEntry + mu sync.RWMutex + m map[string]muxEntry + hosts bool // whether any patterns contain hostnames } type muxEntry struct { explicit bool h Handler + pattern string } // NewServeMux allocates and returns a new ServeMux. @@ -883,8 +1121,7 @@ func cleanPath(p string) string { // Find a handler on a handler map given a path string // Most-specific (longest) pattern wins -func (mux *ServeMux) match(path string) Handler { - var h Handler +func (mux *ServeMux) match(path string) (h Handler, pattern string) { var n = 0 for k, v := range mux.m { if !pathMatch(k, path) { @@ -893,37 +1130,64 @@ func (mux *ServeMux) match(path string) Handler { if h == nil || len(k) > n { n = len(k) h = v.h + pattern = v.pattern + } + } + return +} + +// Handler returns the handler to use for the given request, +// consulting r.Method, r.Host, and r.URL.Path. It always returns +// a non-nil handler. If the path is not in its canonical form, the +// handler will be an internally-generated handler that redirects +// to the canonical path. +// +// Handler also returns the registered pattern that matches the +// request or, in the case of internally-generated redirects, +// the pattern that will match after following the redirect. +// +// If there is no registered handler that applies to the request, +// Handler returns a ``page not found'' handler and an empty pattern. +func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { + if r.Method != "CONNECT" { + if p := cleanPath(r.URL.Path); p != r.URL.Path { + _, pattern = mux.handler(r.Host, p) + return RedirectHandler(p, StatusMovedPermanently), pattern } } - return h + + return mux.handler(r.Host, r.URL.Path) } -// handler returns the handler to use for the request r. -func (mux *ServeMux) handler(r *Request) Handler { +// handler is the main implementation of Handler. +// The path is known to be in canonical form, except for CONNECT methods. +func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) { mux.mu.RLock() defer mux.mu.RUnlock() // Host-specific pattern takes precedence over generic ones - h := mux.match(r.Host + r.URL.Path) + if mux.hosts { + h, pattern = mux.match(host + path) + } if h == nil { - h = mux.match(r.URL.Path) + h, pattern = mux.match(path) } if h == nil { - h = NotFoundHandler() + h, pattern = NotFoundHandler(), "" } - return h + return } // ServeHTTP dispatches the request to the handler whose // pattern most closely matches the request URL. func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { - // Clean path to canonical form and redirect. - if p := cleanPath(r.URL.Path); p != r.URL.Path { - w.Header().Set("Location", p) - w.WriteHeader(StatusMovedPermanently) + if r.RequestURI == "*" { + w.Header().Set("Connection", "close") + w.WriteHeader(StatusBadRequest) return } - mux.handler(r).ServeHTTP(w, r) + h, _ := mux.Handler(r) + h.ServeHTTP(w, r) } // Handle registers the handler for the given pattern. @@ -942,14 +1206,26 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) { panic("http: multiple registrations for " + pattern) } - mux.m[pattern] = muxEntry{explicit: true, h: handler} + mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern} + + if pattern[0] != '/' { + mux.hosts = true + } // Helpful behavior: // If pattern is /tree/, insert an implicit permanent redirect for /tree. // It can be overridden by an explicit registration. n := len(pattern) if n > 0 && pattern[n-1] == '/' && !mux.m[pattern[0:n-1]].explicit { - mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(pattern, StatusMovedPermanently)} + // If pattern contains a host name, strip it and use remaining + // path for redirect. + path := pattern + if pattern[0] != '/' { + // In pattern, at least the last character is a '/', so + // strings.Index can't be -1. + path = pattern[strings.Index(pattern, "/"):] + } + mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(path, StatusMovedPermanently), pattern: pattern} } } @@ -971,7 +1247,7 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { } // Serve accepts incoming HTTP connections on the listener l, -// creating a new service thread for each. The service threads +// creating a new service goroutine for each. The service goroutines // read requests and then call handler to reply to them. // Handler is typically nil, in which case the DefaultServeMux is used. func Serve(l net.Listener, handler Handler) error { @@ -987,6 +1263,32 @@ type Server struct { WriteTimeout time.Duration // maximum duration before timing out write of the response MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0 TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + + // TLSNextProto optionally specifies a function to take over + // ownership of the provided TLS connection when an NPN + // protocol upgrade has occured. The map key is the protocol + // name negotiated. The Handler argument should be used to + // handle HTTP requests and will initialize the Request's TLS + // and RemoteAddr if not already set. The connection is + // automatically closed when the function returns. + TLSNextProto map[string]func(*Server, *tls.Conn, Handler) +} + +// serverHandler delegates to either the server's Handler or +// DefaultServeMux and also handles "OPTIONS *" requests. +type serverHandler struct { + srv *Server +} + +func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { + handler := sh.srv.Handler + if handler == nil { + handler = DefaultServeMux + } + if req.RequestURI == "*" && req.Method == "OPTIONS" { + handler = globalOptionsHandler{} + } + handler.ServeHTTP(rw, req) } // ListenAndServe listens on the TCP network address srv.Addr and then @@ -1005,7 +1307,7 @@ func (srv *Server) ListenAndServe() error { } // Serve accepts incoming connections on the Listener l, creating a -// new service thread for each. The service threads read requests and +// new service goroutine for each. The service goroutines read requests and // then call srv.Handler to reply to them. func (srv *Server) Serve(l net.Listener) error { defer l.Close() @@ -1029,12 +1331,6 @@ func (srv *Server) Serve(l net.Listener) error { return e } tempDelay = 0 - if srv.ReadTimeout != 0 { - rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) - } - if srv.WriteTimeout != 0 { - rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout)) - } c, err := srv.newConn(rw) if err != nil { continue @@ -1150,7 +1446,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { // TimeoutHandler returns a Handler that runs h with the given time limit. // // The new Handler calls h.ServeHTTP to handle each request, but if a -// call runs for more than ns nanoseconds, the handler responds with +// call runs for longer than its time limit, the handler responds with // a 503 Service Unavailable error and the given message in its body. // (If msg is empty, a suitable default message will be sent.) // After such a timeout, writes by h to its ResponseWriter will return @@ -1180,7 +1476,7 @@ func (h *timeoutHandler) errorBody() string { } func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { - done := make(chan bool) + done := make(chan bool, 1) tw := &timeoutWriter{w: w} go func() { h.handler.ServeHTTP(tw, r) @@ -1232,3 +1528,86 @@ func (tw *timeoutWriter) WriteHeader(code int) { tw.mu.Unlock() tw.w.WriteHeader(code) } + +// globalOptionsHandler responds to "OPTIONS *" requests. +type globalOptionsHandler struct{} + +func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "0") + if r.ContentLength != 0 { + // Read up to 4KB of OPTIONS body (as mentioned in the + // spec as being reserved for future use), but anything + // over that is considered a waste of server resources + // (or an attack) and we abort and close the connection, + // courtesy of MaxBytesReader's EOF behavior. + mb := MaxBytesReader(w, r.Body, 4<<10) + io.Copy(ioutil.Discard, mb) + } +} + +// eofReader is a non-nil io.ReadCloser that always returns EOF. +var eofReader = ioutil.NopCloser(strings.NewReader("")) + +// initNPNRequest is an HTTP handler that initializes certain +// uninitialized fields in its *Request. Such partially-initialized +// Requests come from NPN protocol handlers. +type initNPNRequest struct { + c *tls.Conn + h serverHandler +} + +func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) { + if req.TLS == nil { + req.TLS = &tls.ConnectionState{} + *req.TLS = h.c.ConnectionState() + } + if req.Body == nil { + req.Body = eofReader + } + if req.RemoteAddr == "" { + req.RemoteAddr = h.c.RemoteAddr().String() + } + h.h.ServeHTTP(rw, req) +} + +// loggingConn is used for debugging. +type loggingConn struct { + name string + net.Conn +} + +var ( + uniqNameMu sync.Mutex + uniqNameNext = make(map[string]int) +) + +func newLoggingConn(baseName string, c net.Conn) net.Conn { + uniqNameMu.Lock() + defer uniqNameMu.Unlock() + uniqNameNext[baseName]++ + return &loggingConn{ + name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]), + Conn: c, + } +} + +func (c *loggingConn) Write(p []byte) (n int, err error) { + log.Printf("%s.Write(%d) = ....", c.name, len(p)) + n, err = c.Conn.Write(p) + log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Read(p []byte) (n int, err error) { + log.Printf("%s.Read(%d) = ....", c.name, len(p)) + n, err = c.Conn.Read(p) + log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Close() (err error) { + log.Printf("%s.Close() = ...", c.name) + err = c.Conn.Close() + log.Printf("%s.Close() = %v", c.name, err) + return +} diff --git a/src/pkg/net/http/server_test.go b/src/pkg/net/http/server_test.go new file mode 100644 index 000000000..8b4e8c6d6 --- /dev/null +++ b/src/pkg/net/http/server_test.go @@ -0,0 +1,95 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/url" + "testing" +) + +var serveMuxRegister = []struct { + pattern string + h Handler +}{ + {"/dir/", serve(200)}, + {"/search", serve(201)}, + {"codesearch.google.com/search", serve(202)}, + {"codesearch.google.com/", serve(203)}, +} + +// serve returns a handler that sends a response with the given code. +func serve(code int) HandlerFunc { + return func(w ResponseWriter, r *Request) { + w.WriteHeader(code) + } +} + +var serveMuxTests = []struct { + method string + host string + path string + code int + pattern string +}{ + {"GET", "google.com", "/", 404, ""}, + {"GET", "google.com", "/dir", 301, "/dir/"}, + {"GET", "google.com", "/dir/", 200, "/dir/"}, + {"GET", "google.com", "/dir/file", 200, "/dir/"}, + {"GET", "google.com", "/search", 201, "/search"}, + {"GET", "google.com", "/search/", 404, ""}, + {"GET", "google.com", "/search/foo", 404, ""}, + {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"}, + {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"}, + {"GET", "images.google.com", "/search", 201, "/search"}, + {"GET", "images.google.com", "/search/", 404, ""}, + {"GET", "images.google.com", "/search/foo", 404, ""}, + {"GET", "google.com", "/../search", 301, "/search"}, + {"GET", "google.com", "/dir/..", 301, ""}, + {"GET", "google.com", "/dir/..", 301, ""}, + {"GET", "google.com", "/dir/./file", 301, "/dir/"}, + + // The /foo -> /foo/ redirect applies to CONNECT requests + // but the path canonicalization does not. + {"CONNECT", "google.com", "/dir", 301, "/dir/"}, + {"CONNECT", "google.com", "/../search", 404, ""}, + {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, + {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, + {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"}, +} + +func TestServeMuxHandler(t *testing.T) { + mux := NewServeMux() + for _, e := range serveMuxRegister { + mux.Handle(e.pattern, e.h) + } + + for _, tt := range serveMuxTests { + r := &Request{ + Method: tt.method, + Host: tt.host, + URL: &url.URL{ + Path: tt.path, + }, + } + h, pattern := mux.Handler(r) + cs := &codeSaver{h: Header{}} + h.ServeHTTP(cs, r) + if pattern != tt.pattern || cs.code != tt.code { + t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, cs.code, pattern, tt.code, tt.pattern) + } + } +} + +// A codeSaver is a ResponseWriter that saves the code passed to WriteHeader. +type codeSaver struct { + h Header + code int +} + +func (cs *codeSaver) Header() Header { return cs.h } +func (cs *codeSaver) Write(p []byte) (int, error) { return len(p), nil } +func (cs *codeSaver) WriteHeader(code int) { cs.code = code } diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go index 9e9d84172..43c6023a3 100644 --- a/src/pkg/net/http/transfer.go +++ b/src/pkg/net/http/transfer.go @@ -87,10 +87,8 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { // Sanitize Body,ContentLength,TransferEncoding if t.ResponseToHEAD { t.Body = nil - t.TransferEncoding = nil - // ContentLength is expected to hold Content-Length - if t.ContentLength < 0 { - return nil, ErrMissingContentLength + if chunked(t.TransferEncoding) { + t.ContentLength = -1 } } else { if !atLeastHTTP11 || t.Body == nil { @@ -122,9 +120,6 @@ func (t *transferWriter) shouldSendContentLength() bool { if t.ContentLength > 0 { return true } - if t.ResponseToHEAD { - return true - } // Many servers expect a Content-Length for these methods if t.Method == "POST" || t.Method == "PUT" { return true @@ -199,10 +194,11 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { ncopy, err = io.Copy(w, t.Body) } else { ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) - nextra, err := io.Copy(ioutil.Discard, t.Body) if err != nil { return err } + var nextra int64 + nextra, err = io.Copy(ioutil.Discard, t.Body) ncopy += nextra } if err != nil { @@ -213,7 +209,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { } } - if t.ContentLength != -1 && t.ContentLength != ncopy { + if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy { return fmt.Errorf("http: Request.ContentLength=%d with Body length %d", t.ContentLength, ncopy) } @@ -294,10 +290,19 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { return err } - t.ContentLength, err = fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding) + realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding) if err != nil { return err } + if isResponse && t.RequestMethod == "HEAD" { + if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { + return err + } else { + t.ContentLength = n + } + } else { + t.ContentLength = realLength + } // Trailer t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding) @@ -310,7 +315,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // See RFC2616, section 4.4. switch msg.(type) { case *Response: - if t.ContentLength == -1 && + if realLength == -1 && !chunked(t.TransferEncoding) && bodyAllowedForStatus(t.StatusCode) { // Unbounded body. @@ -322,12 +327,16 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // or close connection when finished, since multipart is not supported yet switch { case chunked(t.TransferEncoding): - t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} - case t.ContentLength >= 0: + if noBodyExpected(t.RequestMethod) { + t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close} + } else { + t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} + } + case realLength >= 0: // TODO: limit the Content-Length. This is an easy DoS vector. - t.Body = &body{Reader: io.LimitReader(r, t.ContentLength), closing: t.Close} + t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close} default: - // t.ContentLength < 0, i.e. "Content-Length" not mentioned in header + // realLength < 0, i.e. "Content-Length" not mentioned in header if t.Close { // Close semantics (i.e. HTTP/1.0) t.Body = &body{Reader: r, closing: t.Close} @@ -371,12 +380,6 @@ func fixTransferEncoding(requestMethod string, header Header) ([]string, error) delete(header, "Transfer-Encoding") - // Head responses have no bodies, so the transfer encoding - // should be ignored. - if requestMethod == "HEAD" { - return nil, nil - } - encodings := strings.Split(raw[0], ",") te := make([]string, 0, len(encodings)) // TODO: Even though we only support "identity" and "chunked" @@ -432,11 +435,11 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, } // Logic based on Content-Length - cl := strings.TrimSpace(header.Get("Content-Length")) + cl := strings.TrimSpace(header.get("Content-Length")) if cl != "" { - n, err := strconv.ParseInt(cl, 10, 64) - if err != nil || n < 0 { - return -1, &badStringError{"bad Content-Length", cl} + n, err := parseContentLength(cl) + if err != nil { + return -1, err } return n, nil } else { @@ -451,13 +454,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, return 0, nil } - // Logic based on media type. The purpose of the following code is just - // to detect whether the unsupported "multipart/byteranges" is being - // used. A proper Content-Type parser is needed in the future. - if strings.Contains(strings.ToLower(header.Get("Content-Type")), "multipart/byteranges") { - return -1, ErrNotSupported - } - // Body-EOF logic based on other methods (like closing, or chunked coding) return -1, nil } @@ -469,14 +465,14 @@ func shouldClose(major, minor int, header Header) bool { if major < 1 { return true } else if major == 1 && minor == 0 { - if !strings.Contains(strings.ToLower(header.Get("Connection")), "keep-alive") { + if !strings.Contains(strings.ToLower(header.get("Connection")), "keep-alive") { return true } return false } else { // TODO: Should split on commas, toss surrounding white space, // and check each field. - if strings.ToLower(header.Get("Connection")) == "close" { + if strings.ToLower(header.get("Connection")) == "close" { header.Del("Connection") return true } @@ -486,7 +482,7 @@ func shouldClose(major, minor int, header Header) bool { // Parse the trailer header func fixTrailer(header Header, te []string) (Header, error) { - raw := header.Get("Trailer") + raw := header.get("Trailer") if raw == "" { return nil, nil } @@ -525,11 +521,11 @@ type body struct { res *response // response writer for server requests, else nil } -// ErrBodyReadAfterClose is returned when reading a Request Body after -// the body has been closed. This typically happens when the body is +// ErrBodyReadAfterClose is returned when reading a Request or Response +// Body after the body has been closed. This typically happens when the body is // read after an HTTP Handler calls WriteHeader or Write on its // ResponseWriter. -var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed request Body") +var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") func (b *body) Read(p []byte) (n int, err error) { if b.closed { @@ -567,14 +563,22 @@ func seeUpcomingDoubleCRLF(r *bufio.Reader) bool { return false } +var errTrailerEOF = errors.New("http: unexpected EOF reading trailer") + func (b *body) readTrailer() error { // The common case, since nobody uses trailers. - buf, _ := b.r.Peek(2) + buf, err := b.r.Peek(2) if bytes.Equal(buf, singleCRLF) { b.r.ReadByte() b.r.ReadByte() return nil } + if len(buf) < 2 { + return errTrailerEOF + } + if err != nil { + return err + } // Make sure there's a header terminator coming up, to prevent // a DoS with an unbounded size Trailer. It's not easy to @@ -590,6 +594,9 @@ func (b *body) readTrailer() error { hdr, err := textproto.NewReader(b.r).ReadMIMEHeader() if err != nil { + if err == io.EOF { + return errTrailerEOF + } return err } switch rr := b.hdr.(type) { @@ -630,3 +637,18 @@ func (b *body) Close() error { } return nil } + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +func parseContentLength(cl string) (int64, error) { + cl = strings.TrimSpace(cl) + if cl == "" { + return -1, nil + } + n, err := strconv.ParseInt(cl, 10, 64) + if err != nil || n < 0 { + return 0, &badStringError{"bad Content-Length", cl} + } + return n, nil + +} diff --git a/src/pkg/net/http/transfer_test.go b/src/pkg/net/http/transfer_test.go new file mode 100644 index 000000000..8627a374c --- /dev/null +++ b/src/pkg/net/http/transfer_test.go @@ -0,0 +1,37 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "strings" + "testing" +) + +func TestBodyReadBadTrailer(t *testing.T) { + b := &body{ + Reader: strings.NewReader("foobar"), + hdr: true, // force reading the trailer + r: bufio.NewReader(strings.NewReader("")), + } + buf := make([]byte, 7) + n, err := b.Read(buf[:3]) + got := string(buf[:n]) + if got != "foo" || err != nil { + t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err) + } + + n, err = b.Read(buf[:]) + got = string(buf[:n]) + if got != "bar" || err != nil { + t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err) + } + + n, err = b.Read(buf[:]) + got = string(buf[:n]) + if err == nil { + t.Errorf("final Read was successful (%q), expected error from trailer read", got) + } +} diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index 6efe191eb..685d7d56c 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. // HTTP client implementation. See RFC 2616. -// +// // This is the low-level Transport implementation of RoundTripper. // The high-level interface is in client.go. @@ -24,13 +24,14 @@ import ( "os" "strings" "sync" + "time" ) // DefaultTransport is the default implementation of Transport and is -// used by DefaultClient. It establishes a new network connection for -// each call to Do and uses HTTP proxies as directed by the -// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy) -// environment variables. +// used by DefaultClient. It establishes network connections as needed +// and caches them for reuse by subsequent calls. It uses HTTP proxies +// as directed by the $HTTP_PROXY and $NO_PROXY (or $http_proxy and +// $no_proxy) environment variables. var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment} // DefaultMaxIdleConnsPerHost is the default value of Transport's @@ -41,8 +42,11 @@ const DefaultMaxIdleConnsPerHost = 2 // https, and http proxies (for either http or https with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { - lk sync.Mutex + idleMu sync.Mutex idleConn map[string][]*persistConn + reqMu sync.Mutex + reqConn map[*Request]*persistConn + altMu sync.RWMutex altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // TODO: tunable on global max cached connections @@ -68,9 +72,15 @@ type Transport struct { DisableCompression bool // MaxIdleConnsPerHost, if non-zero, controls the maximum idle - // (keep-alive) to keep to keep per-host. If zero, + // (keep-alive) to keep per-host. If zero, // DefaultMaxIdleConnsPerHost is used. MaxIdleConnsPerHost int + + // ResponseHeaderTimeout, if non-zero, specifies the amount of + // time to wait for a server's response headers after fully + // writing the request (including its body, if any). This + // time does not include the time to read the response body. + ResponseHeaderTimeout time.Duration } // ProxyFromEnvironment returns the URL of the proxy to use for a @@ -88,7 +98,7 @@ func ProxyFromEnvironment(req *Request) (*url.URL, error) { return nil, nil } proxyURL, err := url.Parse(proxy) - if err != nil || proxyURL.Scheme == "" { + if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") { if u, err := url.Parse("http://" + proxy); err == nil { proxyURL = u err = nil @@ -131,17 +141,20 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { return nil, errors.New("http: nil Request.Header") } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - t.lk.Lock() + t.altMu.RLock() var rt RoundTripper if t.altProto != nil { rt = t.altProto[req.URL.Scheme] } - t.lk.Unlock() + t.altMu.RUnlock() if rt == nil { return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} } return rt.RoundTrip(req) } + if req.URL.Host == "" { + return nil, errors.New("http: no Host in request URL") + } treq := &transportRequest{Request: req} cm, err := t.connectMethodForRequest(treq) if err != nil { @@ -170,8 +183,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { if scheme == "http" || scheme == "https" { panic("protocol " + scheme + " already registered") } - t.lk.Lock() - defer t.lk.Unlock() + t.altMu.Lock() + defer t.altMu.Unlock() if t.altProto == nil { t.altProto = make(map[string]RoundTripper) } @@ -186,17 +199,29 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { - t.lk.Lock() - defer t.lk.Unlock() - if t.idleConn == nil { + t.idleMu.Lock() + m := t.idleConn + t.idleConn = nil + t.idleMu.Unlock() + if m == nil { return } - for _, conns := range t.idleConn { + for _, conns := range m { for _, pconn := range conns { pconn.close() } } - t.idleConn = make(map[string][]*persistConn) +} + +// CancelRequest cancels an in-flight request by closing its +// connection. +func (t *Transport) CancelRequest(req *Request) { + t.reqMu.Lock() + pc := t.reqConn[req] + t.reqMu.Unlock() + if pc != nil { + pc.conn.Close() + } } // @@ -242,8 +267,6 @@ func (cm *connectMethod) proxyAuth() string { // If pconn is no longer needed or not in a good state, putIdleConn // returns false. func (t *Transport) putIdleConn(pconn *persistConn) bool { - t.lk.Lock() - defer t.lk.Unlock() if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { pconn.close() return false @@ -256,21 +279,32 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { if max == 0 { max = DefaultMaxIdleConnsPerHost } + t.idleMu.Lock() + if t.idleConn == nil { + t.idleConn = make(map[string][]*persistConn) + } if len(t.idleConn[key]) >= max { + t.idleMu.Unlock() pconn.close() return false } + for _, exist := range t.idleConn[key] { + if exist == pconn { + log.Fatalf("dup idle pconn %p in freelist", pconn) + } + } t.idleConn[key] = append(t.idleConn[key], pconn) + t.idleMu.Unlock() return true } func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { - t.lk.Lock() - defer t.lk.Unlock() + key := cm.String() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { - t.idleConn = make(map[string][]*persistConn) + return nil } - key := cm.String() for { pconns, ok := t.idleConn[key] if !ok { @@ -289,7 +323,20 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { return } } - return + panic("unreachable") +} + +func (t *Transport) setReqConn(r *Request, pc *persistConn) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqConn == nil { + t.reqConn = make(map[*Request]*persistConn) + } + if pc != nil { + t.reqConn[r] = pc + } else { + delete(t.reqConn, r) + } } func (t *Transport) dial(network, addr string) (c net.Conn, err error) { @@ -323,6 +370,8 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { cacheKey: cm.String(), conn: conn, reqch: make(chan requestAndChan, 50), + writech: make(chan writeRequest, 50), + closech: make(chan struct{}), } switch { @@ -365,7 +414,18 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { if cm.targetScheme == "https" { // Initiate TLS and check remote host name against certificate. - conn = tls.Client(conn, t.TLSClientConfig) + cfg := t.TLSClientConfig + if cfg == nil || cfg.ServerName == "" { + host := cm.tlsHost() + if cfg == nil { + cfg = &tls.Config{ServerName: host} + } else { + clone := *cfg // shallow clone + clone.ServerName = host + cfg = &clone + } + } + conn = tls.Client(conn, cfg) if err = conn.(*tls.Conn).Handshake(); err != nil { return nil, err } @@ -380,6 +440,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { pconn.br = bufio.NewReader(pconn.conn) pconn.bw = bufio.NewWriter(pconn.conn) go pconn.readLoop() + go pconn.writeLoop() return pconn, nil } @@ -421,7 +482,15 @@ func useProxy(addr string) bool { if hasPort(p) { p = p[:strings.LastIndex(p, ":")] } - if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) { + if addr == p { + return false + } + if p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:]) { + // no_proxy ".foo.com" matches "bar.foo.com" or "foo.com" + return false + } + if p[0] != '.' && strings.HasSuffix(addr, p) && addr[len(addr)-len(p)-1] == '.' { + // no_proxy "foo.com" matches "bar.foo.com" return false } } @@ -484,25 +553,28 @@ type persistConn struct { t *Transport cacheKey string // its connectMethod.String() conn net.Conn + closed bool // whether conn has been closed br *bufio.Reader // from conn bw *bufio.Writer // to conn - reqch chan requestAndChan // written by roundTrip(); read by readLoop() + reqch chan requestAndChan // written by roundTrip; read by readLoop + writech chan writeRequest // written by roundTrip; read by writeLoop + closech chan struct{} // broadcast close when readLoop (TCP connection) closes isProxy bool + lk sync.Mutex // guards following 3 fields + numExpectedResponses int + broken bool // an error has happened on this connection; marked broken so it's not reused. // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the // original Request given to RoundTrip is not modified) mutateHeaderFunc func(Header) - - lk sync.Mutex // guards numExpectedResponses and broken - numExpectedResponses int - broken bool // an error has happened on this connection; marked broken so it's not reused. } func (pc *persistConn) isBroken() bool { pc.lk.Lock() - defer pc.lk.Unlock() - return pc.broken + b := pc.broken + pc.lk.Unlock() + return b } var remoteSideClosedFunc func(error) bool // or nil to use default @@ -518,6 +590,7 @@ func remoteSideClosed(err error) bool { } func (pc *persistConn) readLoop() { + defer close(pc.closech) alive := true var lastbody io.ReadCloser // last response body, if any, read on this connection @@ -544,12 +617,16 @@ func (pc *persistConn) readLoop() { lastbody.Close() // assumed idempotent lastbody = nil } - resp, err := ReadResponse(pc.br, rc.req) + + var resp *Response + if err == nil { + resp, err = ReadResponse(pc.br, rc.req) + } + hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 if err != nil { pc.close() } else { - hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" { resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") @@ -569,31 +646,37 @@ func (pc *persistConn) readLoop() { alive = false } - hasBody := resp != nil && resp.ContentLength != 0 var waitForBodyRead chan bool - if alive { - if hasBody { - lastbody = resp.Body - waitForBodyRead = make(chan bool) - resp.Body.(*bodyEOFSignal).fn = func() { - if !pc.t.putIdleConn(pc) { - alive = false - } - waitForBodyRead <- true + if hasBody { + lastbody = resp.Body + waitForBodyRead = make(chan bool, 1) + resp.Body.(*bodyEOFSignal).fn = func(err error) { + alive1 := alive + if err != nil { + alive1 = false } - } else { - // When there's no response body, we immediately - // reuse the TCP connection (putIdleConn), but - // we need to prevent ClientConn.Read from - // closing the Response.Body on the next - // loop, otherwise it might close the body - // before the client code has had a chance to - // read it (even though it'll just be 0, EOF). - lastbody = nil - - if !pc.t.putIdleConn(pc) { - alive = false + if alive1 && !pc.t.putIdleConn(pc) { + alive1 = false + } + if !alive1 || pc.isBroken() { + pc.close() } + waitForBodyRead <- alive1 + } + } + + if alive && !hasBody { + // When there's no response body, we immediately + // reuse the TCP connection (putIdleConn), but + // we need to prevent ClientConn.Read from + // closing the Response.Body on the next + // loop, otherwise it might close the body + // before the client code has had a chance to + // read it (even though it'll just be 0, EOF). + lastbody = nil + + if !pc.t.putIdleConn(pc) { + alive = false } } @@ -602,7 +685,35 @@ func (pc *persistConn) readLoop() { // Wait for the just-returned response body to be fully consumed // before we race and peek on the underlying bufio reader. if waitForBodyRead != nil { - <-waitForBodyRead + alive = <-waitForBodyRead + } + + pc.t.setReqConn(rc.req, nil) + + if !alive { + pc.close() + } + } +} + +func (pc *persistConn) writeLoop() { + for { + select { + case wr := <-pc.writech: + if pc.isBroken() { + wr.ch <- errors.New("http: can't write HTTP request on broken connection") + continue + } + err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra) + if err == nil { + err = pc.bw.Flush() + } + if err != nil { + pc.markBroken() + } + wr.ch <- err + case <-pc.closech: + return } } } @@ -622,9 +733,24 @@ type requestAndChan struct { addedGzip bool } +// A writeRequest is sent by the readLoop's goroutine to the +// writeLoop's goroutine to write a request while the read loop +// concurrently waits on both the write response and the server's +// reply. +type writeRequest struct { + req *transportRequest + ch chan<- error +} + func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { - if pc.mutateHeaderFunc != nil { - pc.mutateHeaderFunc(req.extraHeaders()) + pc.t.setReqConn(req.Request, pc) + pc.lk.Lock() + pc.numExpectedResponses++ + headerFn := pc.mutateHeaderFunc + pc.lk.Unlock() + + if headerFn != nil { + headerFn(req.extraHeaders()) } // Ask for a compressed version if the caller didn't set their @@ -633,34 +759,84 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // requested it. requestedGzip := false if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" { - // Request gzip only, not deflate. Deflate is ambiguous and + // Request gzip only, not deflate. Deflate is ambiguous and // not as universally supported anyway. // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 requestedGzip = true req.extraHeaders().Set("Accept-Encoding", "gzip") } - pc.lk.Lock() - pc.numExpectedResponses++ - pc.lk.Unlock() + // Write the request concurrently with waiting for a response, + // in case the server decides to reply before reading our full + // request body. + writeErrCh := make(chan error, 1) + pc.writech <- writeRequest{req, writeErrCh} - err = req.Request.write(pc.bw, pc.isProxy, req.extra) - if err != nil { - pc.close() - return + resc := make(chan responseAndError, 1) + pc.reqch <- requestAndChan{req.Request, resc, requestedGzip} + + var re responseAndError + var pconnDeadCh = pc.closech + var failTicker <-chan time.Time + var respHeaderTimer <-chan time.Time +WaitResponse: + for { + select { + case err := <-writeErrCh: + if err != nil { + re = responseAndError{nil, err} + pc.close() + break WaitResponse + } + if d := pc.t.ResponseHeaderTimeout; d > 0 { + respHeaderTimer = time.After(d) + } + case <-pconnDeadCh: + // The persist connection is dead. This shouldn't + // usually happen (only with Connection: close responses + // with no response bodies), but if it does happen it + // means either a) the remote server hung up on us + // prematurely, or b) the readLoop sent us a response & + // closed its closech at roughly the same time, and we + // selected this case first, in which case a response + // might still be coming soon. + // + // We can't avoid the select race in b) by using a unbuffered + // resc channel instead, because then goroutines can + // leak if we exit due to other errors. + pconnDeadCh = nil // avoid spinning + failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc + case <-failTicker: + re = responseAndError{err: errors.New("net/http: transport closed before response was received")} + break WaitResponse + case <-respHeaderTimer: + pc.close() + re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")} + break WaitResponse + case re = <-resc: + break WaitResponse + } } - pc.bw.Flush() - ch := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{req.Request, ch, requestedGzip} - re := <-ch pc.lk.Lock() pc.numExpectedResponses-- pc.lk.Unlock() + if re.err != nil { + pc.t.setReqConn(req.Request, nil) + } return re.res, re.err } +// markBroken marks a connection as broken (so it's not reused). +// It differs from close in that it doesn't close the underlying +// connection for use when it's still being read. +func (pc *persistConn) markBroken() { + pc.lk.Lock() + defer pc.lk.Unlock() + pc.broken = true +} + func (pc *persistConn) close() { pc.lk.Lock() defer pc.lk.Unlock() @@ -669,7 +845,10 @@ func (pc *persistConn) close() { func (pc *persistConn) closeLocked() { pc.broken = true - pc.conn.Close() + if !pc.closed { + pc.conn.Close() + pc.closed = true + } pc.mutateHeaderFunc = nil } @@ -687,43 +866,62 @@ func canonicalAddr(url *url.URL) string { return addr } -func responseIsKeepAlive(res *Response) bool { - // TODO: implement. for now just always shutting down the connection. - return false -} - // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most -// once, right before the final Read() or Close() call returns, but after -// EOF has been seen. +// once, right before its final (error-producing) Read or Close call +// returns. type bodyEOFSignal struct { - body io.ReadCloser - fn func() - isClosed bool + body io.ReadCloser + mu sync.Mutex // guards closed, rerr and fn + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) // error will be nil on Read io.EOF } func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { - n, err = es.body.Read(p) - if es.isClosed && n > 0 { - panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725") + es.mu.Lock() + closed, rerr := es.closed, es.rerr + es.mu.Unlock() + if closed { + return 0, errors.New("http: read on closed response body") + } + if rerr != nil { + return 0, rerr } - if err == io.EOF && es.fn != nil { - es.fn() - es.fn = nil + + n, err = es.body.Read(p) + if err != nil { + es.mu.Lock() + defer es.mu.Unlock() + if es.rerr == nil { + es.rerr = err + } + es.condfn(err) } return } -func (es *bodyEOFSignal) Close() (err error) { - if es.isClosed { +func (es *bodyEOFSignal) Close() error { + es.mu.Lock() + defer es.mu.Unlock() + if es.closed { return nil } - es.isClosed = true - err = es.body.Close() - if err == nil && es.fn != nil { - es.fn() - es.fn = nil + es.closed = true + err := es.body.Close() + es.condfn(err) + return err +} + +// caller must hold es.mu. +func (es *bodyEOFSignal) condfn(err error) { + if es.fn == nil { + return } - return + if err == io.EOF { + err = nil + } + es.fn(err) + es.fn = nil } type readFirstCloseBoth struct { diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index a9e401de5..68010e68b 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "io/ioutil" + "net" . "net/http" "net/http/httptest" "net/url" @@ -20,6 +21,7 @@ import ( "runtime" "strconv" "strings" + "sync" "testing" "time" ) @@ -35,14 +37,78 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte(r.RemoteAddr)) }) +// testCloseConn is a net.Conn tracked by a testConnSet. +type testCloseConn struct { + net.Conn + set *testConnSet +} + +func (c *testCloseConn) Close() error { + c.set.remove(c) + return c.Conn.Close() +} + +// testConnSet tracks a set of TCP connections and whether they've +// been closed. +type testConnSet struct { + t *testing.T + closed map[net.Conn]bool + list []net.Conn // in order created + mutex sync.Mutex +} + +func (tcs *testConnSet) insert(c net.Conn) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + tcs.closed[c] = false + tcs.list = append(tcs.list, c) +} + +func (tcs *testConnSet) remove(c net.Conn) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + tcs.closed[c] = true +} + +// some tests use this to manage raw tcp connections for later inspection +func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) { + connSet := &testConnSet{ + t: t, + closed: make(map[net.Conn]bool), + } + dial := func(n, addr string) (net.Conn, error) { + c, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + tc := &testCloseConn{c, connSet} + connSet.insert(tc) + return tc, nil + } + return connSet, dial +} + +func (tcs *testConnSet) check(t *testing.T) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + + for i, c := range tcs.list { + if !tcs.closed[c] { + t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) + } + } +} + // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port func TestTransportKeepAlives(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() for _, disableKeepAlive := range []bool{false, true} { tr := &Transport{DisableKeepAlives: disableKeepAlive} + defer tr.CloseIdleConnections() c := &Client{Transport: tr} fetch := func(n int) string { @@ -69,11 +135,16 @@ func TestTransportKeepAlives(t *testing.T) { } func TestTransportConnectionCloseOnResponse(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() + connSet, testDial := makeTestDial(t) + for _, connectionClose := range []bool{false, true} { - tr := &Transport{} + tr := &Transport{ + Dial: testDial, + } c := &Client{Transport: tr} fetch := func(n int) string { @@ -92,8 +163,8 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) } - body, err := ioutil.ReadAll(res.Body) defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) } @@ -107,15 +178,24 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } + + tr.CloseIdleConnections() } + + connSet.check(t) } func TestTransportConnectionCloseOnRequest(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() + connSet, testDial := makeTestDial(t) + for _, connectionClose := range []bool{false, true} { - tr := &Transport{} + tr := &Transport{ + Dial: testDial, + } c := &Client{Transport: tr} fetch := func(n int) string { @@ -149,10 +229,15 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } + + tr.CloseIdleConnections() } + + connSet.check(t) } func TestTransportIdleCacheKeys(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -185,6 +270,7 @@ func TestTransportIdleCacheKeys(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { + defer checkLeakedTransports(t) resch := make(chan string) gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -201,7 +287,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { c := &Client{Transport: tr} // Start 3 outstanding requests and wait for the server to get them. - // Their responses will hang until we we write to resch, though. + // Their responses will hang until we write to resch, though. donech := make(chan bool) doReq := func() { resp, err := c.Get(ts.URL) @@ -253,6 +339,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportServerClosingUnexpectedly(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -309,9 +396,9 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { // Test for http://golang.org/issue/2616 (appropriate issue number) // This fails pretty reliably with GOMAXPROCS=100 or something high. func TestStressSurpriseServerCloses(t *testing.T) { + defer checkLeakedTransports(t) if testing.Short() { - t.Logf("skipping test in short mode") - return + t.Skip("skipping test in short mode") } ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "5") @@ -365,6 +452,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { // TestTransportHeadResponses verifies that we deal with Content-Lengths // with no bodies properly func TestTransportHeadResponses(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) @@ -384,7 +472,7 @@ func TestTransportHeadResponses(t *testing.T) { if e, g := "123", res.Header.Get("Content-Length"); e != g { t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) } - if e, g := int64(0), res.ContentLength; e != g { + if e, g := int64(123), res.ContentLength; e != g { t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) } } @@ -393,6 +481,7 @@ func TestTransportHeadResponses(t *testing.T) { // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding // on responses to HEAD requests. func TestTransportHeadChunkedResponse(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) @@ -434,6 +523,7 @@ var roundTripTests = []struct { // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { + defer checkLeakedTransports(t) const responseBody = "test response body" ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") @@ -490,6 +580,7 @@ func TestRoundTripGzip(t *testing.T) { } func TestTransportGzip(t *testing.T) { + defer checkLeakedTransports(t) const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -582,6 +673,7 @@ func TestTransportGzip(t *testing.T) { } func TestTransportProxy(t *testing.T) { + defer checkLeakedTransports(t) ch := make(chan string, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ch <- "real server" @@ -610,6 +702,7 @@ func TestTransportProxy(t *testing.T) { // but checks that we don't recurse forever, and checks that // Content-Encoding is removed. func TestTransportGzipRecursive(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write(rgz) @@ -636,6 +729,7 @@ func TestTransportGzipRecursive(t *testing.T) { // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { + defer checkLeakedTransports(t) gotReqCh := make(chan bool) unblockCh := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -698,8 +792,49 @@ func TestTransportPersistConnLeak(t *testing.T) { } } +// golang.org/issue/4531: Transport leaks goroutines when +// request.ContentLength is explicitly short +func TestTransportPersistConnLeakShortBody(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })) + defer ts.Close() + + tr := &Transport{} + c := &Client{Transport: tr} + + n0 := runtime.NumGoroutine() + body := []byte("Hello") + for i := 0; i < 20; i++ { + req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.ContentLength = int64(len(body) - 2) // explicitly short + _, err = c.Do(req) + if err == nil { + t.Fatal("Expect an error from writing too long of a body.") + } + } + nhigh := runtime.NumGoroutine() + tr.CloseIdleConnections() + time.Sleep(50 * time.Millisecond) + runtime.GC() + nfinal := runtime.NumGoroutine() + + growth := nfinal - n0 + + // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. + // Previously we were leaking one per numReq. + t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) + if int(growth) > 5 { + t.Error("too many new goroutines") + } +} + // This used to crash; http://golang.org/issue/3266 func TestTransportIdleConnCrash(t *testing.T) { + defer checkLeakedTransports(t) tr := &Transport{} c := &Client{Transport: tr} @@ -724,6 +859,361 @@ func TestTransportIdleConnCrash(t *testing.T) { <-didreq } +// Test that the transport doesn't close the TCP connection early, +// before the response body has been read. This was a regression +// which sadly lacked a triggering test. The large response body made +// the old race easier to trigger. +func TestIssue3644(t *testing.T) { + defer checkLeakedTransports(t) + const numFoos = 5000 + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + for i := 0; i < numFoos; i++ { + w.Write([]byte("foo ")) + } + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + bs, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != numFoos*len("foo ") { + t.Errorf("unexpected response length") + } +} + +// Test that a client receives a server's reply, even if the server doesn't read +// the entire request body. +func TestIssue3595(t *testing.T) { + defer checkLeakedTransports(t) + const deniedMsg = "sorry, denied." + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, deniedMsg, StatusUnauthorized) + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) + if err != nil { + t.Errorf("Post: %v", err) + return + } + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("Body ReadAll: %v", err) + } + if !strings.Contains(string(got), deniedMsg) { + t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) + } +} + +// From http://golang.org/issue/4454 , +// "client fails to handle requests with no body and chunked encoding" +func TestChunkedNoContent(t *testing.T) { + defer checkLeakedTransports(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNoContent) + })) + defer ts.Close() + + for _, closeBody := range []bool{true, false} { + c := &Client{Transport: &Transport{}} + const n = 4 + for i := 1; i <= n; i++ { + res, err := c.Get(ts.URL) + if err != nil { + t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err) + } else { + if closeBody { + res.Body.Close() + } + } + } + } +} + +func TestTransportConcurrency(t *testing.T) { + defer checkLeakedTransports(t) + const maxProcs = 16 + const numReqs = 500 + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%v", r.FormValue("echo")) + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + reqs := make(chan string) + defer close(reqs) + + var wg sync.WaitGroup + wg.Add(numReqs) + for i := 0; i < maxProcs*2; i++ { + go func() { + for req := range reqs { + res, err := c.Get(ts.URL + "/?echo=" + req) + if err != nil { + t.Errorf("error on req %s: %v", req, err) + wg.Done() + continue + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("read error on req %s: %v", req, err) + wg.Done() + continue + } + if string(all) != req { + t.Errorf("body of req %s = %q; want %q", req, all, req) + } + wg.Done() + res.Body.Close() + } + }() + } + for i := 0; i < numReqs; i++ { + reqs <- fmt.Sprintf("request-%d", i) + } + wg.Wait() +} + +func TestIssue4191_InfiniteGetTimeout(t *testing.T) { + defer checkLeakedTransports(t) + const debug = false + mux := NewServeMux() + mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { + io.Copy(w, neverEnding('a')) + }) + ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond + + client := &Client{ + Transport: &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil + }, + DisableKeepAlives: true, + }, + } + + getFailed := false + nRuns := 5 + if testing.Short() { + nRuns = 1 + } + for i := 0; i < nRuns; i++ { + if debug { + println("run", i+1, "of", nRuns) + } + sres, err := client.Get(ts.URL + "/get") + if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } + t.Errorf("Error issuing GET: %v", err) + break + } + _, err = io.Copy(ioutil.Discard, sres.Body) + if err == nil { + t.Errorf("Unexpected successful copy") + break + } + } + if debug { + println("tests complete; waiting for handlers to finish") + } + ts.Close() +} + +func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { + defer checkLeakedTransports(t) + const debug = false + mux := NewServeMux() + mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { + io.Copy(w, neverEnding('a')) + }) + mux.HandleFunc("/put", func(w ResponseWriter, r *Request) { + defer r.Body.Close() + io.Copy(ioutil.Discard, r.Body) + }) + ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond + + client := &Client{ + Transport: &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil + }, + DisableKeepAlives: true, + }, + } + + getFailed := false + nRuns := 5 + if testing.Short() { + nRuns = 1 + } + for i := 0; i < nRuns; i++ { + if debug { + println("run", i+1, "of", nRuns) + } + sres, err := client.Get(ts.URL + "/get") + if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } + t.Errorf("Error issuing GET: %v", err) + break + } + req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body) + _, err = client.Do(req) + if err == nil { + sres.Body.Close() + t.Errorf("Unexpected successful PUT") + break + } + sres.Body.Close() + } + if debug { + println("tests complete; waiting for handlers to finish") + } + ts.Close() +} + +func TestTransportResponseHeaderTimeout(t *testing.T) { + defer checkLeakedTransports(t) + if testing.Short() { + t.Skip("skipping timeout test in -short mode") + } + mux := NewServeMux() + mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {}) + mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { + time.Sleep(2 * time.Second) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + tr := &Transport{ + ResponseHeaderTimeout: 500 * time.Millisecond, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + tests := []struct { + path string + want int + wantErr string + }{ + {path: "/fast", want: 200}, + {path: "/slow", wantErr: "timeout awaiting response headers"}, + {path: "/fast", want: 200}, + } + for i, tt := range tests { + res, err := c.Get(ts.URL + tt.path) + if err != nil { + if strings.Contains(err.Error(), tt.wantErr) { + continue + } + t.Errorf("%d. unexpected error: %v", i, err) + continue + } + if tt.wantErr != "" { + t.Errorf("%d. no error. expected error: %v", i, tt.wantErr) + continue + } + if res.StatusCode != tt.want { + t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) + } + } +} + +func TestTransportCancelRequest(t *testing.T) { + defer checkLeakedTransports(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello") + w.(Flusher).Flush() // send headers and some body + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + go func() { + time.Sleep(1 * time.Second) + tr.CancelRequest(req) + }() + t0 := time.Now() + body, err := ioutil.ReadAll(res.Body) + d := time.Since(t0) + + if err == nil { + t.Error("expected an error reading the body") + } + if string(body) != "Hello" { + t.Errorf("Body = %q; want Hello", body) + } + if d < 500*time.Millisecond { + t.Errorf("expected ~1 second delay; got %v", d) + } + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + for tries := 3; tries > 0; tries-- { + n := tr.NumPendingRequestsForTesting() + if n == 0 { + break + } + time.Sleep(100 * time.Millisecond) + if tries == 1 { + t.Errorf("pending requests = %d; want 0", n) + } + } +} + type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) { @@ -737,6 +1227,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) { } func TestTransportAltProto(t *testing.T) { + defer checkLeakedTransports(t) tr := &Transport{} c := &Client{Transport: tr} tr.RegisterProtocol("foo", fooProto{}) @@ -754,15 +1245,58 @@ func TestTransportAltProto(t *testing.T) { } } -var proxyFromEnvTests = []struct { +func TestTransportNoHost(t *testing.T) { + defer checkLeakedTransports(t) + tr := &Transport{} + _, err := tr.RoundTrip(&Request{ + Header: make(Header), + URL: &url.URL{ + Scheme: "http", + }, + }) + want := "http: no Host in request URL" + if got := fmt.Sprint(err); got != want { + t.Errorf("error = %v; want %q", err, want) + } +} + +type proxyFromEnvTest struct { + req string // URL to fetch; blank means "http://example.com" env string - wanturl string + noenv string + want string wanterr error -}{ - {"127.0.0.1:8080", "http://127.0.0.1:8080", nil}, - {"http://127.0.0.1:8080", "http://127.0.0.1:8080", nil}, - {"https://127.0.0.1:8080", "https://127.0.0.1:8080", nil}, - {"", "<nil>", nil}, +} + +func (t proxyFromEnvTest) String() string { + var buf bytes.Buffer + if t.env != "" { + fmt.Fprintf(&buf, "http_proxy=%q", t.env) + } + if t.noenv != "" { + fmt.Fprintf(&buf, " no_proxy=%q", t.noenv) + } + req := "http://example.com" + if t.req != "" { + req = t.req + } + fmt.Fprintf(&buf, " req=%q", req) + return strings.TrimSpace(buf.String()) +} + +var proxyFromEnvTests = []proxyFromEnvTest{ + {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"}, + {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"}, + {env: "cache.corp.example.com", want: "http://cache.corp.example.com"}, + {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, + {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, + {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, + {want: "<nil>"}, + {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"}, + {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "<nil>"}, + {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, + {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"}, + {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, } func TestProxyFromEnvironment(t *testing.T) { @@ -770,16 +1304,21 @@ func TestProxyFromEnvironment(t *testing.T) { os.Setenv("http_proxy", "") os.Setenv("NO_PROXY", "") os.Setenv("no_proxy", "") - for i, tt := range proxyFromEnvTests { + for _, tt := range proxyFromEnvTests { os.Setenv("HTTP_PROXY", tt.env) - req, _ := NewRequest("GET", "http://example.com", nil) + os.Setenv("NO_PROXY", tt.noenv) + reqURL := tt.req + if reqURL == "" { + reqURL = "http://example.com" + } + req, _ := NewRequest("GET", reqURL, nil) url, err := ProxyFromEnvironment(req) if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { - t.Errorf("%d. got error = %q, want %q", i, g, e) + t.Errorf("%v: got error = %q, want %q", tt, g, e) continue } - if got := fmt.Sprintf("%s", url); got != tt.wanturl { - t.Errorf("%d. got URL = %q, want %q", i, url, tt.wanturl) + if got := fmt.Sprintf("%s", url); got != tt.want { + t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) } } } diff --git a/src/pkg/net/http/z_last_test.go b/src/pkg/net/http/z_last_test.go new file mode 100644 index 000000000..44095a8d9 --- /dev/null +++ b/src/pkg/net/http/z_last_test.go @@ -0,0 +1,60 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "net/http" + "runtime" + "strings" + "testing" + "time" +) + +// Verify the other tests didn't leave any goroutines running. +// This is in a file named z_last_test.go so it sorts at the end. +func TestGoroutinesRunning(t *testing.T) { + n := runtime.NumGoroutine() + t.Logf("num goroutines = %d", n) + if n > 20 { + // Currently 14 on Linux (blocked in epoll_wait, + // waiting for on fds that are closed?), but give some + // slop for now. + buf := make([]byte, 1<<20) + buf = buf[:runtime.Stack(buf, true)] + t.Errorf("Too many goroutines:\n%s", buf) + } +} + +func checkLeakedTransports(t *testing.T) { + http.DefaultTransport.(*http.Transport).CloseIdleConnections() + if testing.Short() { + return + } + buf := make([]byte, 1<<20) + var stacks string + var bad string + badSubstring := map[string]string{ + ").readLoop(": "a Transport", + ").writeLoop(": "a Transport", + "created by net/http/httptest.(*Server).Start": "an httptest.Server", + "timeoutHandler": "a TimeoutHandler", + } + for i := 0; i < 4; i++ { + bad = "" + stacks = string(buf[:runtime.Stack(buf, true)]) + for substr, what := range badSubstring { + if strings.Contains(stacks, substr) { + bad = what + } + } + if bad == "" { + return + } + // Bad stuff found, but goroutines might just still be + // shutting down, so give it some time. + time.Sleep(250 * time.Millisecond) + } + t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) +} diff --git a/src/pkg/net/interface.go b/src/pkg/net/interface.go index ee23570a9..0713e9cd6 100644 --- a/src/pkg/net/interface.go +++ b/src/pkg/net/interface.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification - package net import "errors" @@ -66,7 +64,7 @@ func (ifi *Interface) Addrs() ([]Addr, error) { if ifi == nil { return nil, errInvalidInterface } - return interfaceAddrTable(ifi.Index) + return interfaceAddrTable(ifi) } // MulticastAddrs returns multicast, joined group addresses for @@ -75,7 +73,7 @@ func (ifi *Interface) MulticastAddrs() ([]Addr, error) { if ifi == nil { return nil, errInvalidInterface } - return interfaceMulticastAddrTable(ifi.Index) + return interfaceMulticastAddrTable(ifi) } // Interfaces returns a list of the system's network interfaces. @@ -86,7 +84,7 @@ func Interfaces() ([]Interface, error) { // InterfaceAddrs returns a list of the system's network interface // addresses. func InterfaceAddrs() ([]Addr, error) { - return interfaceAddrTable(0) + return interfaceAddrTable(nil) } // InterfaceByIndex returns the interface specified by index. @@ -98,8 +96,14 @@ func InterfaceByIndex(index int) (*Interface, error) { if err != nil { return nil, err } + return interfaceByIndex(ift, index) +} + +func interfaceByIndex(ift []Interface, index int) (*Interface, error) { for _, ifi := range ift { - return &ifi, nil + if index == ifi.Index { + return &ifi, nil + } } return nil, errNoSuchInterface } diff --git a/src/pkg/net/interface_bsd.go b/src/pkg/net/interface_bsd.go index 7f090d8d4..f58065a85 100644 --- a/src/pkg/net/interface_bsd.go +++ b/src/pkg/net/interface_bsd.go @@ -4,8 +4,6 @@ // +build darwin freebsd netbsd openbsd -// Network interface identification for BSD variants - package net import ( @@ -22,57 +20,60 @@ func interfaceTable(ifindex int) ([]Interface, error) { if err != nil { return nil, os.NewSyscallError("route rib", err) } - msgs, err := syscall.ParseRoutingMessage(tab) if err != nil { return nil, os.NewSyscallError("route message", err) } + return parseInterfaceTable(ifindex, msgs) +} +func parseInterfaceTable(ifindex int, msgs []syscall.RoutingMessage) ([]Interface, error) { var ift []Interface +loop: for _, m := range msgs { - switch v := m.(type) { + switch m := m.(type) { case *syscall.InterfaceMessage: - if ifindex == 0 || ifindex == int(v.Header.Index) { - ifi, err := newLink(v) + if ifindex == 0 || ifindex == int(m.Header.Index) { + ifi, err := newLink(m) if err != nil { return nil, err } - ift = append(ift, ifi...) + ift = append(ift, *ifi) + if ifindex == int(m.Header.Index) { + break loop + } } } } return ift, nil } -func newLink(m *syscall.InterfaceMessage) ([]Interface, error) { +func newLink(m *syscall.InterfaceMessage) (*Interface, error) { sas, err := syscall.ParseRoutingSockaddr(m) if err != nil { return nil, os.NewSyscallError("route sockaddr", err) } - - var ift []Interface - for _, s := range sas { - switch v := s.(type) { + ifi := &Interface{Index: int(m.Header.Index), Flags: linkFlags(m.Header.Flags)} + for _, sa := range sas { + switch sa := sa.(type) { case *syscall.SockaddrDatalink: // NOTE: SockaddrDatalink.Data is minimum work area, // can be larger. - m.Data = m.Data[unsafe.Offsetof(v.Data):] - ifi := Interface{Index: int(m.Header.Index), Flags: linkFlags(m.Header.Flags)} + m.Data = m.Data[unsafe.Offsetof(sa.Data):] var name [syscall.IFNAMSIZ]byte - for i := 0; i < int(v.Nlen); i++ { + for i := 0; i < int(sa.Nlen); i++ { name[i] = byte(m.Data[i]) } - ifi.Name = string(name[:v.Nlen]) + ifi.Name = string(name[:sa.Nlen]) ifi.MTU = int(m.Header.Data.Mtu) - addr := make([]byte, v.Alen) - for i := 0; i < int(v.Alen); i++ { - addr[i] = byte(m.Data[int(v.Nlen)+i]) + addr := make([]byte, sa.Alen) + for i := 0; i < int(sa.Alen); i++ { + addr[i] = byte(m.Data[int(sa.Nlen)+i]) } - ifi.HardwareAddr = addr[:v.Alen] - ift = append(ift, ifi) + ifi.HardwareAddr = addr[:sa.Alen] } } - return ift, nil + return ifi, nil } func linkFlags(rawFlags int32) Flags { @@ -95,68 +96,87 @@ func linkFlags(rawFlags int32) Flags { return f } -// If the ifindex is zero, interfaceAddrTable returns addresses -// for all network interfaces. Otherwise it returns addresses -// for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, error) { - tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex) +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *Interface) ([]Addr, error) { + index := 0 + if ifi != nil { + index = ifi.Index + } + tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST, index) if err != nil { return nil, os.NewSyscallError("route rib", err) } - msgs, err := syscall.ParseRoutingMessage(tab) if err != nil { return nil, os.NewSyscallError("route message", err) } - + var ift []Interface + if index == 0 { + ift, err = parseInterfaceTable(index, msgs) + if err != nil { + return nil, err + } + } var ifat []Addr for _, m := range msgs { - switch v := m.(type) { + switch m := m.(type) { case *syscall.InterfaceAddrMessage: - if ifindex == 0 || ifindex == int(v.Header.Index) { - ifa, err := newAddr(v) + if index == 0 || index == int(m.Header.Index) { + if index == 0 { + var err error + ifi, err = interfaceByIndex(ift, int(m.Header.Index)) + if err != nil { + return nil, err + } + } + ifa, err := newAddr(ifi, m) if err != nil { return nil, err } - ifat = append(ifat, ifa) + if ifa != nil { + ifat = append(ifat, ifa) + } } } } return ifat, nil } -func newAddr(m *syscall.InterfaceAddrMessage) (Addr, error) { +func newAddr(ifi *Interface, m *syscall.InterfaceAddrMessage) (Addr, error) { sas, err := syscall.ParseRoutingSockaddr(m) if err != nil { return nil, os.NewSyscallError("route sockaddr", err) } - ifa := &IPNet{} - for i, s := range sas { - switch v := s.(type) { + for i, sa := range sas { + switch sa := sa.(type) { case *syscall.SockaddrInet4: switch i { case 0: - ifa.Mask = IPv4Mask(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3]) + ifa.Mask = IPv4Mask(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]) case 1: - ifa.IP = IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3]) + ifa.IP = IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]) } case *syscall.SockaddrInet6: switch i { case 0: ifa.Mask = make(IPMask, IPv6len) - copy(ifa.Mask, v.Addr[:]) + copy(ifa.Mask, sa.Addr[:]) case 1: ifa.IP = make(IP, IPv6len) - copy(ifa.IP, v.Addr[:]) + copy(ifa.IP, sa.Addr[:]) // NOTE: KAME based IPv6 protcol stack usually embeds // the interface index in the interface-local or link- // local address as the kernel-internal form. if ifa.IP.IsLinkLocalUnicast() { - // remove embedded scope zone ID + ifa.Zone = ifi.Name ifa.IP[2], ifa.IP[3] = 0, 0 } } + default: // Sockaddrs contain syscall.SockaddrDatalink on NetBSD + return nil, nil } } return ifa, nil diff --git a/src/pkg/net/interface_bsd_test.go b/src/pkg/net/interface_bsd_test.go new file mode 100644 index 000000000..aa1141903 --- /dev/null +++ b/src/pkg/net/interface_bsd_test.go @@ -0,0 +1,52 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd netbsd openbsd + +package net + +import ( + "fmt" + "os/exec" +) + +func (ti *testInterface) setBroadcast(suffix int) error { + ti.name = fmt.Sprintf("vlan%d", suffix) + xname, err := exec.LookPath("ifconfig") + if err != nil { + return err + } + ti.setupCmds = append(ti.setupCmds, &exec.Cmd{ + Path: xname, + Args: []string{"ifconfig", ti.name, "create"}, + }) + ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{ + Path: xname, + Args: []string{"ifconfig", ti.name, "destroy"}, + }) + return nil +} + +func (ti *testInterface) setPointToPoint(suffix int, local, remote string) error { + ti.name = fmt.Sprintf("gif%d", suffix) + ti.local = local + ti.remote = remote + xname, err := exec.LookPath("ifconfig") + if err != nil { + return err + } + ti.setupCmds = append(ti.setupCmds, &exec.Cmd{ + Path: xname, + Args: []string{"ifconfig", ti.name, "create"}, + }) + ti.setupCmds = append(ti.setupCmds, &exec.Cmd{ + Path: xname, + Args: []string{"ifconfig", ti.name, "inet", ti.local, ti.remote}, + }) + ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{ + Path: xname, + Args: []string{"ifconfig", ti.name, "destroy"}, + }) + return nil +} diff --git a/src/pkg/net/interface_darwin.go b/src/pkg/net/interface_darwin.go index 0b5fb5fb9..83e483ba2 100644 --- a/src/pkg/net/interface_darwin.go +++ b/src/pkg/net/interface_darwin.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for Darwin - package net import ( @@ -11,26 +9,23 @@ import ( "syscall" ) -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { - tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST2, ifindex) +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST2, ifi.Index) if err != nil { return nil, os.NewSyscallError("route rib", err) } - msgs, err := syscall.ParseRoutingMessage(tab) if err != nil { return nil, os.NewSyscallError("route message", err) } - var ifmat []Addr for _, m := range msgs { - switch v := m.(type) { + switch m := m.(type) { case *syscall.InterfaceMulticastAddrMessage: - if ifindex == 0 || ifindex == int(v.Header.Index) { - ifma, err := newMulticastAddr(v) + if ifi.Index == int(m.Header.Index) { + ifma, err := newMulticastAddr(ifi, m) if err != nil { return nil, err } @@ -41,27 +36,25 @@ func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { return ifmat, nil } -func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) { +func newMulticastAddr(ifi *Interface, m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) { sas, err := syscall.ParseRoutingSockaddr(m) if err != nil { return nil, os.NewSyscallError("route sockaddr", err) } - var ifmat []Addr - for _, s := range sas { - switch v := s.(type) { + for _, sa := range sas { + switch sa := sa.(type) { case *syscall.SockaddrInet4: - ifma := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])} + ifma := &IPAddr{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])} ifmat = append(ifmat, ifma.toAddr()) case *syscall.SockaddrInet6: ifma := &IPAddr{IP: make(IP, IPv6len)} - copy(ifma.IP, v.Addr[:]) + copy(ifma.IP, sa.Addr[:]) // NOTE: KAME based IPv6 protcol stack usually embeds // the interface index in the interface-local or link- // local address as the kernel-internal form. - if ifma.IP.IsInterfaceLocalMulticast() || - ifma.IP.IsLinkLocalMulticast() { - // remove embedded scope zone ID + if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() { + ifma.Zone = ifi.Name ifma.IP[2], ifma.IP[3] = 0, 0 } ifmat = append(ifmat, ifma.toAddr()) diff --git a/src/pkg/net/interface_freebsd.go b/src/pkg/net/interface_freebsd.go index 3cba28fc6..1bf5ae72b 100644 --- a/src/pkg/net/interface_freebsd.go +++ b/src/pkg/net/interface_freebsd.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for FreeBSD - package net import ( @@ -11,26 +9,23 @@ import ( "syscall" ) -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { - tab, err := syscall.RouteRIB(syscall.NET_RT_IFMALIST, ifindex) +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + tab, err := syscall.RouteRIB(syscall.NET_RT_IFMALIST, ifi.Index) if err != nil { return nil, os.NewSyscallError("route rib", err) } - msgs, err := syscall.ParseRoutingMessage(tab) if err != nil { return nil, os.NewSyscallError("route message", err) } - var ifmat []Addr for _, m := range msgs { - switch v := m.(type) { + switch m := m.(type) { case *syscall.InterfaceMulticastAddrMessage: - if ifindex == 0 || ifindex == int(v.Header.Index) { - ifma, err := newMulticastAddr(v) + if ifi.Index == int(m.Header.Index) { + ifma, err := newMulticastAddr(ifi, m) if err != nil { return nil, err } @@ -41,27 +36,25 @@ func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { return ifmat, nil } -func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) { +func newMulticastAddr(ifi *Interface, m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) { sas, err := syscall.ParseRoutingSockaddr(m) if err != nil { return nil, os.NewSyscallError("route sockaddr", err) } - var ifmat []Addr - for _, s := range sas { - switch v := s.(type) { + for _, sa := range sas { + switch sa := sa.(type) { case *syscall.SockaddrInet4: - ifma := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])} + ifma := &IPAddr{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])} ifmat = append(ifmat, ifma.toAddr()) case *syscall.SockaddrInet6: ifma := &IPAddr{IP: make(IP, IPv6len)} - copy(ifma.IP, v.Addr[:]) + copy(ifma.IP, sa.Addr[:]) // NOTE: KAME based IPv6 protcol stack usually embeds // the interface index in the interface-local or link- // local address as the kernel-internal form. - if ifma.IP.IsInterfaceLocalMulticast() || - ifma.IP.IsLinkLocalMulticast() { - // remove embedded scope zone ID + if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() { + ifma.Zone = ifi.Name ifma.IP[2], ifma.IP[3] = 0, 0 } ifmat = append(ifmat, ifma.toAddr()) diff --git a/src/pkg/net/interface_linux.go b/src/pkg/net/interface_linux.go index 825b20227..e66daef06 100644 --- a/src/pkg/net/interface_linux.go +++ b/src/pkg/net/interface_linux.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for Linux - package net import ( @@ -20,17 +18,16 @@ func interfaceTable(ifindex int) ([]Interface, error) { if err != nil { return nil, os.NewSyscallError("netlink rib", err) } - msgs, err := syscall.ParseNetlinkMessage(tab) if err != nil { return nil, os.NewSyscallError("netlink message", err) } - var ift []Interface +loop: for _, m := range msgs { switch m.Header.Type { case syscall.NLMSG_DONE: - goto done + break loop case syscall.RTM_NEWLINK: ifim := (*syscall.IfInfomsg)(unsafe.Pointer(&m.Data[0])) if ifindex == 0 || ifindex == int(ifim.Index) { @@ -38,17 +35,18 @@ func interfaceTable(ifindex int) ([]Interface, error) { if err != nil { return nil, os.NewSyscallError("netlink routeattr", err) } - ifi := newLink(ifim, attrs) - ift = append(ift, ifi) + ift = append(ift, *newLink(ifim, attrs)) + if ifindex == int(ifim.Index) { + break loop + } } } } -done: return ift, nil } -func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) Interface { - ifi := Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)} +func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) *Interface { + ifi := &Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)} for _, a := range attrs { switch a.Attr.Type { case syscall.IFLA_ADDRESS: @@ -64,7 +62,7 @@ func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) Interfac case syscall.IFLA_IFNAME: ifi.Name = string(a.Value[:len(a.Value)-1]) case syscall.IFLA_MTU: - ifi.MTU = int(uint32(a.Value[3])<<24 | uint32(a.Value[2])<<16 | uint32(a.Value[1])<<8 | uint32(a.Value[0])) + ifi.MTU = int(*(*uint32)(unsafe.Pointer(&a.Value[:4][0]))) } } return ifi @@ -90,81 +88,87 @@ func linkFlags(rawFlags uint32) Flags { return f } -// If the ifindex is zero, interfaceAddrTable returns addresses -// for all network interfaces. Otherwise it returns addresses -// for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, error) { +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *Interface) ([]Addr, error) { tab, err := syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC) if err != nil { return nil, os.NewSyscallError("netlink rib", err) } - msgs, err := syscall.ParseNetlinkMessage(tab) if err != nil { return nil, os.NewSyscallError("netlink message", err) } - - ifat, err := addrTable(msgs, ifindex) + var ift []Interface + if ifi == nil { + var err error + ift, err = interfaceTable(0) + if err != nil { + return nil, err + } + } + ifat, err := addrTable(ift, ifi, msgs) if err != nil { return nil, err } return ifat, nil } -func addrTable(msgs []syscall.NetlinkMessage, ifindex int) ([]Addr, error) { +func addrTable(ift []Interface, ifi *Interface, msgs []syscall.NetlinkMessage) ([]Addr, error) { var ifat []Addr +loop: for _, m := range msgs { switch m.Header.Type { case syscall.NLMSG_DONE: - goto done + break loop case syscall.RTM_NEWADDR: ifam := (*syscall.IfAddrmsg)(unsafe.Pointer(&m.Data[0])) - if ifindex == 0 || ifindex == int(ifam.Index) { + if len(ift) != 0 || ifi.Index == int(ifam.Index) { + if len(ift) != 0 { + var err error + ifi, err = interfaceByIndex(ift, int(ifam.Index)) + if err != nil { + return nil, err + } + } attrs, err := syscall.ParseNetlinkRouteAttr(&m) if err != nil { return nil, os.NewSyscallError("netlink routeattr", err) } - ifat = append(ifat, newAddr(attrs, int(ifam.Family), int(ifam.Prefixlen))) + ifa := newAddr(ifi, ifam, attrs) + if ifa != nil { + ifat = append(ifat, ifa) + } } } } -done: return ifat, nil } -func newAddr(attrs []syscall.NetlinkRouteAttr, family, pfxlen int) Addr { - ifa := &IPNet{} +func newAddr(ifi *Interface, ifam *syscall.IfAddrmsg, attrs []syscall.NetlinkRouteAttr) Addr { for _, a := range attrs { - switch a.Attr.Type { - case syscall.IFA_ADDRESS: - switch family { + if ifi.Flags&FlagPointToPoint != 0 && a.Attr.Type == syscall.IFA_LOCAL || + ifi.Flags&FlagPointToPoint == 0 && a.Attr.Type == syscall.IFA_ADDRESS { + switch ifam.Family { case syscall.AF_INET: - ifa.IP = IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]) - ifa.Mask = CIDRMask(pfxlen, 8*IPv4len) + return &IPNet{IP: IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv4len)} case syscall.AF_INET6: - ifa.IP = make(IP, IPv6len) + ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv6len)} copy(ifa.IP, a.Value[:]) - ifa.Mask = CIDRMask(pfxlen, 8*IPv6len) + if ifam.Scope == syscall.RT_SCOPE_HOST || ifam.Scope == syscall.RT_SCOPE_LINK { + ifa.Zone = ifi.Name + } + return ifa } } } - return ifa + return nil } -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { - var ( - err error - ifi *Interface - ) - if ifindex > 0 { - ifi, err = InterfaceByIndex(ifindex) - if err != nil { - return nil, err - } - } +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { ifmat4 := parseProcNetIGMP("/proc/net/igmp", ifi) ifmat6 := parseProcNetIGMP6("/proc/net/igmp6", ifi) return append(ifmat4, ifmat6...), nil @@ -176,7 +180,6 @@ func parseProcNetIGMP(path string, ifi *Interface) []Addr { return nil } defer fd.close() - var ( ifmat []Addr name string @@ -193,10 +196,14 @@ func parseProcNetIGMP(path string, ifi *Interface) []Addr { name = f[1] case len(f[0]) == 8: if ifi == nil || name == ifi.Name { + // The Linux kernel puts the IP + // address in /proc/net/igmp in native + // endianness. for i := 0; i+1 < len(f[0]); i += 2 { b[i/2], _ = xtoi2(f[0][i:i+2], 0) } - ifma := IPAddr{IP: IPv4(b[3], b[2], b[1], b[0])} + i := *(*uint32)(unsafe.Pointer(&b[:4][0])) + ifma := IPAddr{IP: IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i))} ifmat = append(ifmat, ifma.toAddr()) } } @@ -210,7 +217,6 @@ func parseProcNetIGMP6(path string, ifi *Interface) []Addr { return nil } defer fd.close() - var ifmat []Addr b := make([]byte, IPv6len) for l, ok := fd.readLine(); ok; l, ok = fd.readLine() { @@ -223,6 +229,9 @@ func parseProcNetIGMP6(path string, ifi *Interface) []Addr { b[i/2], _ = xtoi2(f[2][i:i+2], 0) } ifma := IPAddr{IP: IP{b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]}} + if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() { + ifma.Zone = ifi.Name + } ifmat = append(ifmat, ifma.toAddr()) } } diff --git a/src/pkg/net/interface_linux_test.go b/src/pkg/net/interface_linux_test.go index f14d1fe06..085d3de9d 100644 --- a/src/pkg/net/interface_linux_test.go +++ b/src/pkg/net/interface_linux_test.go @@ -4,7 +4,55 @@ package net -import "testing" +import ( + "fmt" + "os/exec" + "testing" +) + +func (ti *testInterface) setBroadcast(suffix int) error { + ti.name = fmt.Sprintf("gotest%d", suffix) + xname, err := exec.LookPath("ip") + if err != nil { + return err + } + ti.setupCmds = append(ti.setupCmds, &exec.Cmd{ + Path: xname, + Args: []string{"ip", "link", "add", ti.name, "type", "dummy"}, + }) + ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{ + Path: xname, + Args: []string{"ip", "link", "delete", ti.name, "type", "dummy"}, + }) + return nil +} + +func (ti *testInterface) setPointToPoint(suffix int, local, remote string) error { + ti.name = fmt.Sprintf("gotest%d", suffix) + ti.local = local + ti.remote = remote + xname, err := exec.LookPath("ip") + if err != nil { + return err + } + ti.setupCmds = append(ti.setupCmds, &exec.Cmd{ + Path: xname, + Args: []string{"ip", "tunnel", "add", ti.name, "mode", "gre", "local", local, "remote", remote}, + }) + ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{ + Path: xname, + Args: []string{"ip", "tunnel", "del", ti.name, "mode", "gre", "local", local, "remote", remote}, + }) + xname, err = exec.LookPath("ifconfig") + if err != nil { + return err + } + ti.setupCmds = append(ti.setupCmds, &exec.Cmd{ + Path: xname, + Args: []string{"ifconfig", ti.name, "inet", local, "dstaddr", remote}, + }) + return nil +} const ( numOfTestIPv4MCAddrs = 14 diff --git a/src/pkg/net/interface_netbsd.go b/src/pkg/net/interface_netbsd.go index 4150e9ad5..c9ce5a7ac 100644 --- a/src/pkg/net/interface_netbsd.go +++ b/src/pkg/net/interface_netbsd.go @@ -2,13 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for NetBSD - package net -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + // TODO(mikio): Implement this like other platforms. return nil, nil } diff --git a/src/pkg/net/interface_openbsd.go b/src/pkg/net/interface_openbsd.go index d8adb4676..c9ce5a7ac 100644 --- a/src/pkg/net/interface_openbsd.go +++ b/src/pkg/net/interface_openbsd.go @@ -2,13 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for OpenBSD - package net -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + // TODO(mikio): Implement this like other platforms. return nil, nil } diff --git a/src/pkg/net/interface_stub.go b/src/pkg/net/interface_stub.go index d4d7ce9c7..a4eb731da 100644 --- a/src/pkg/net/interface_stub.go +++ b/src/pkg/net/interface_stub.go @@ -4,8 +4,6 @@ // +build plan9 -// Network interface identification - package net // If the ifindex is zero, interfaceTable returns mappings of all @@ -15,16 +13,15 @@ func interfaceTable(ifindex int) ([]Interface, error) { return nil, nil } -// If the ifindex is zero, interfaceAddrTable returns addresses -// for all network interfaces. Otherwise it returns addresses -// for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, error) { +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *Interface) ([]Addr, error) { return nil, nil } -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { return nil, nil } diff --git a/src/pkg/net/interface_test.go b/src/pkg/net/interface_test.go index 0a33bfdb5..7fb342818 100644 --- a/src/pkg/net/interface_test.go +++ b/src/pkg/net/interface_test.go @@ -5,18 +5,24 @@ package net import ( - "bytes" + "reflect" "testing" ) -func sameInterface(i, j *Interface) bool { - if i == nil || j == nil { - return false +// loopbackInterface returns an available logical network interface +// for loopback tests. It returns nil if no suitable interface is +// found. +func loopbackInterface() *Interface { + ift, err := Interfaces() + if err != nil { + return nil } - if i.Index == j.Index && i.Name == j.Name && bytes.Equal(i.HardwareAddr, j.HardwareAddr) { - return true + for _, ifi := range ift { + if ifi.Flags&FlagLoopback != 0 && ifi.Flags&FlagUp != 0 { + return &ifi + } } - return false + return nil } func TestInterfaces(t *testing.T) { @@ -24,24 +30,24 @@ func TestInterfaces(t *testing.T) { if err != nil { t.Fatalf("Interfaces failed: %v", err) } - t.Logf("table: len/cap = %v/%v\n", len(ift), cap(ift)) + t.Logf("table: len/cap = %v/%v", len(ift), cap(ift)) for _, ifi := range ift { ifxi, err := InterfaceByIndex(ifi.Index) if err != nil { - t.Fatalf("InterfaceByIndex(%q) failed: %v", ifi.Index, err) + t.Fatalf("InterfaceByIndex(%v) failed: %v", ifi.Index, err) } - if !sameInterface(ifxi, &ifi) { - t.Fatalf("InterfaceByIndex(%q) = %v, want %v", ifi.Index, *ifxi, ifi) + if !reflect.DeepEqual(ifxi, &ifi) { + t.Fatalf("InterfaceByIndex(%v) = %v, want %v", ifi.Index, ifxi, ifi) } ifxn, err := InterfaceByName(ifi.Name) if err != nil { t.Fatalf("InterfaceByName(%q) failed: %v", ifi.Name, err) } - if !sameInterface(ifxn, &ifi) { - t.Fatalf("InterfaceByName(%q) = %v, want %v", ifi.Name, *ifxn, ifi) + if !reflect.DeepEqual(ifxn, &ifi) { + t.Fatalf("InterfaceByName(%q) = %v, want %v", ifi.Name, ifxn, ifi) } - t.Logf("%q: flags %q, ifindex %v, mtu %v\n", ifi.Name, ifi.Flags.String(), ifi.Index, ifi.MTU) + t.Logf("%q: flags %q, ifindex %v, mtu %v", ifi.Name, ifi.Flags.String(), ifi.Index, ifi.MTU) t.Logf("\thardware address %q", ifi.HardwareAddr.String()) testInterfaceAddrs(t, &ifi) testInterfaceMulticastAddrs(t, &ifi) @@ -53,7 +59,7 @@ func TestInterfaceAddrs(t *testing.T) { if err != nil { t.Fatalf("InterfaceAddrs failed: %v", err) } - t.Logf("table: len/cap = %v/%v\n", len(ifat), cap(ifat)) + t.Logf("table: len/cap = %v/%v", len(ifat), cap(ifat)) testAddrs(t, ifat) } @@ -75,9 +81,13 @@ func testInterfaceMulticastAddrs(t *testing.T, ifi *Interface) { func testAddrs(t *testing.T, ifat []Addr) { for _, ifa := range ifat { - switch ifa.(type) { + switch v := ifa.(type) { case *IPAddr, *IPNet: - t.Logf("\tinterface address %q\n", ifa.String()) + if v == nil { + t.Errorf("\tunexpected value: %v", ifa) + } else { + t.Logf("\tinterface address %q", ifa.String()) + } default: t.Errorf("\tunexpected type: %T", ifa) } @@ -86,11 +96,79 @@ func testAddrs(t *testing.T, ifat []Addr) { func testMulticastAddrs(t *testing.T, ifmat []Addr) { for _, ifma := range ifmat { - switch ifma.(type) { + switch v := ifma.(type) { case *IPAddr: - t.Logf("\tjoined group address %q\n", ifma.String()) + if v == nil { + t.Errorf("\tunexpected value: %v", ifma) + } else { + t.Logf("\tjoined group address %q", ifma.String()) + } default: t.Errorf("\tunexpected type: %T", ifma) } } } + +func BenchmarkInterfaces(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := Interfaces(); err != nil { + b.Fatalf("Interfaces failed: %v", err) + } + } +} + +func BenchmarkInterfaceByIndex(b *testing.B) { + ifi := loopbackInterface() + if ifi == nil { + b.Skip("loopback interface not found") + } + for i := 0; i < b.N; i++ { + if _, err := InterfaceByIndex(ifi.Index); err != nil { + b.Fatalf("InterfaceByIndex failed: %v", err) + } + } +} + +func BenchmarkInterfaceByName(b *testing.B) { + ifi := loopbackInterface() + if ifi == nil { + b.Skip("loopback interface not found") + } + for i := 0; i < b.N; i++ { + if _, err := InterfaceByName(ifi.Name); err != nil { + b.Fatalf("InterfaceByName failed: %v", err) + } + } +} + +func BenchmarkInterfaceAddrs(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := InterfaceAddrs(); err != nil { + b.Fatalf("InterfaceAddrs failed: %v", err) + } + } +} + +func BenchmarkInterfacesAndAddrs(b *testing.B) { + ifi := loopbackInterface() + if ifi == nil { + b.Skip("loopback interface not found") + } + for i := 0; i < b.N; i++ { + if _, err := ifi.Addrs(); err != nil { + b.Fatalf("Interface.Addrs failed: %v", err) + } + } +} + +func BenchmarkInterfacesAndMulticastAddrs(b *testing.B) { + ifi := loopbackInterface() + if ifi == nil { + b.Skip("loopback interface not found") + } + for i := 0; i < b.N; i++ { + if _, err := ifi.MulticastAddrs(); err != nil { + b.Fatalf("Interface.MulticastAddrs failed: %v", err) + } + } +} diff --git a/src/pkg/net/interface_unix_test.go b/src/pkg/net/interface_unix_test.go new file mode 100644 index 000000000..6dbd6e6e7 --- /dev/null +++ b/src/pkg/net/interface_unix_test.go @@ -0,0 +1,145 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd + +package net + +import ( + "os" + "os/exec" + "runtime" + "testing" + "time" +) + +type testInterface struct { + name string + local string + remote string + setupCmds []*exec.Cmd + teardownCmds []*exec.Cmd +} + +func (ti *testInterface) setup() error { + for _, cmd := range ti.setupCmds { + if err := cmd.Run(); err != nil { + return err + } + } + return nil +} + +func (ti *testInterface) teardown() error { + for _, cmd := range ti.teardownCmds { + if err := cmd.Run(); err != nil { + return err + } + } + return nil +} + +func TestPointToPointInterface(t *testing.T) { + switch runtime.GOOS { + case "darwin": + t.Skipf("skipping read test on %q", runtime.GOOS) + } + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") + } + + local, remote := "169.254.0.1", "169.254.0.254" + ip := ParseIP(remote) + for i := 0; i < 3; i++ { + ti := &testInterface{} + if err := ti.setPointToPoint(5963+i, local, remote); err != nil { + t.Skipf("test requries external command: %v", err) + } + if err := ti.setup(); err != nil { + t.Fatalf("testInterface.setup failed: %v", err) + } else { + time.Sleep(3 * time.Millisecond) + } + ift, err := Interfaces() + if err != nil { + ti.teardown() + t.Fatalf("Interfaces failed: %v", err) + } + for _, ifi := range ift { + if ti.name == ifi.Name { + ifat, err := ifi.Addrs() + if err != nil { + ti.teardown() + t.Fatalf("Interface.Addrs failed: %v", err) + } + for _, ifa := range ifat { + if ip.Equal(ifa.(*IPNet).IP) { + ti.teardown() + t.Fatalf("got %v; want %v", ip, local) + } + } + } + } + if err := ti.teardown(); err != nil { + t.Fatalf("testInterface.teardown failed: %v", err) + } else { + time.Sleep(3 * time.Millisecond) + } + } +} + +func TestInterfaceArrivalAndDeparture(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") + } + + for i := 0; i < 3; i++ { + ift1, err := Interfaces() + if err != nil { + t.Fatalf("Interfaces failed: %v", err) + } + ti := &testInterface{} + if err := ti.setBroadcast(5682 + i); err != nil { + t.Skipf("test requires external command: %v", err) + } + if err := ti.setup(); err != nil { + t.Fatalf("testInterface.setup failed: %v", err) + } else { + time.Sleep(3 * time.Millisecond) + } + ift2, err := Interfaces() + if err != nil { + ti.teardown() + t.Fatalf("Interfaces failed: %v", err) + } + if len(ift2) <= len(ift1) { + for _, ifi := range ift1 { + t.Logf("before: %v", ifi) + } + for _, ifi := range ift2 { + t.Logf("after: %v", ifi) + } + ti.teardown() + t.Fatalf("got %v; want gt %v", len(ift2), len(ift1)) + } + if err := ti.teardown(); err != nil { + t.Fatalf("testInterface.teardown failed: %v", err) + } else { + time.Sleep(3 * time.Millisecond) + } + ift3, err := Interfaces() + if err != nil { + t.Fatalf("Interfaces failed: %v", err) + } + if len(ift3) >= len(ift2) { + for _, ifi := range ift2 { + t.Logf("before: %v", ifi) + } + for _, ifi := range ift3 { + t.Logf("after: %v", ifi) + } + t.Fatalf("got %v; want lt %v", len(ift3), len(ift2)) + } + } +} diff --git a/src/pkg/net/interface_windows.go b/src/pkg/net/interface_windows.go index 4368b3306..0759dc255 100644 --- a/src/pkg/net/interface_windows.go +++ b/src/pkg/net/interface_windows.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for Windows - package net import ( @@ -25,6 +23,9 @@ func getAdapterList() (*syscall.IpAdapterInfo, error) { b := make([]byte, 1000) l := uint32(len(b)) a := (*syscall.IpAdapterInfo)(unsafe.Pointer(&b[0])) + // TODO(mikio): GetAdaptersInfo returns IP_ADAPTER_INFO that + // contains IPv4 address list only. We should use another API + // for fetching IPv6 stuff from the kernel. err := syscall.GetAdaptersInfo(a, &l) if err == syscall.ERROR_BUFFER_OVERFLOW { b = make([]byte, l) @@ -38,7 +39,7 @@ func getAdapterList() (*syscall.IpAdapterInfo, error) { } func getInterfaceList() ([]syscall.InterfaceInfo, error) { - s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + s, err := sysSocket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) if err != nil { return nil, os.NewSyscallError("Socket", err) } @@ -126,10 +127,10 @@ func interfaceTable(ifindex int) ([]Interface, error) { return ift, nil } -// If the ifindex is zero, interfaceAddrTable returns addresses -// for all network interfaces. Otherwise it returns addresses -// for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, error) { +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *Interface) ([]Addr, error) { ai, err := getAdapterList() if err != nil { return nil, err @@ -138,11 +139,10 @@ func interfaceAddrTable(ifindex int) ([]Addr, error) { var ifat []Addr for ; ai != nil; ai = ai.Next { index := ai.Index - if ifindex == 0 || ifindex == int(index) { + if ifi == nil || ifi.Index == int(index) { ipl := &ai.IpAddressList for ; ipl != nil; ipl = ipl.Next { - ifa := IPAddr{} - ifa.IP = parseIPv4(bytePtrToString(&ipl.IpAddress.String[0])) + ifa := IPAddr{IP: parseIPv4(bytePtrToString(&ipl.IpAddress.String[0]))} ifat = append(ifat, ifa.toAddr()) } } @@ -150,9 +150,9 @@ func interfaceAddrTable(ifindex int) ([]Addr, error) { return ifat, nil } -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + // TODO(mikio): Implement this like other platforms. return nil, nil } diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go index 979d7acd5..d588e3a42 100644 --- a/src/pkg/net/ip.go +++ b/src/pkg/net/ip.go @@ -7,7 +7,7 @@ // IPv4 addresses are 4 bytes; IPv6 addresses are 16 bytes. // An IPv4 address can be converted to an IPv6 address by // adding a canonical prefix (10 zeros, 2 0xFFs). -// This library accepts either size of byte array but always +// This library accepts either size of byte slice but always // returns 16-byte addresses. package net @@ -18,14 +18,14 @@ const ( IPv6len = 16 ) -// An IP is a single IP address, an array of bytes. +// An IP is a single IP address, a slice of bytes. // Functions in this package accept either 4-byte (IPv4) -// or 16-byte (IPv6) arrays as input. +// or 16-byte (IPv6) slices as input. // // Note that in this documentation, referring to an // IP address as an IPv4 address or an IPv6 address // is a semantic property of the address, not just the -// length of the byte array: a 16-byte array can still +// length of the byte slice: a 16-byte slice can still // be an IPv4 address. type IP []byte @@ -36,6 +36,7 @@ type IPMask []byte type IPNet struct { IP IP // network number Mask IPMask // network mask + Zone string // IPv6 scoped addressing zone } // IPv4 returns the IP address (in 16-byte form) of the @@ -645,5 +646,5 @@ func ParseCIDR(s string) (IP, *IPNet, error) { return nil, nil, &ParseError{"CIDR address", s} } m := CIDRMask(n, 8*iplen) - return ip, &IPNet{ip.Mask(m), m}, nil + return ip, &IPNet{IP: ip.Mask(m), Mask: m}, nil } diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go index df647ef73..f8b7f067f 100644 --- a/src/pkg/net/ip_test.go +++ b/src/pkg/net/ip_test.go @@ -114,23 +114,23 @@ var parsecidrtests = []struct { net *IPNet err error }{ - {"135.104.0.0/32", IPv4(135, 104, 0, 0), &IPNet{IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 255)}, nil}, - {"0.0.0.0/24", IPv4(0, 0, 0, 0), &IPNet{IPv4(0, 0, 0, 0), IPv4Mask(255, 255, 255, 0)}, nil}, - {"135.104.0.0/24", IPv4(135, 104, 0, 0), &IPNet{IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 0)}, nil}, - {"135.104.0.1/32", IPv4(135, 104, 0, 1), &IPNet{IPv4(135, 104, 0, 1), IPv4Mask(255, 255, 255, 255)}, nil}, - {"135.104.0.1/24", IPv4(135, 104, 0, 1), &IPNet{IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 0)}, nil}, - {"::1/128", ParseIP("::1"), &IPNet{ParseIP("::1"), IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"))}, nil}, - {"abcd:2345::/127", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe"))}, nil}, - {"abcd:2345::/65", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff:8000::"))}, nil}, - {"abcd:2345::/64", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff::"))}, nil}, - {"abcd:2345::/63", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:fffe::"))}, nil}, - {"abcd:2345::/33", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:8000::"))}, nil}, - {"abcd:2345::/32", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff::"))}, nil}, - {"abcd:2344::/31", ParseIP("abcd:2344::"), &IPNet{ParseIP("abcd:2344::"), IPMask(ParseIP("ffff:fffe::"))}, nil}, - {"abcd:2300::/24", ParseIP("abcd:2300::"), &IPNet{ParseIP("abcd:2300::"), IPMask(ParseIP("ffff:ff00::"))}, nil}, - {"abcd:2345::/24", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2300::"), IPMask(ParseIP("ffff:ff00::"))}, nil}, - {"2001:DB8::/48", ParseIP("2001:DB8::"), &IPNet{ParseIP("2001:DB8::"), IPMask(ParseIP("ffff:ffff:ffff::"))}, nil}, - {"2001:DB8::1/48", ParseIP("2001:DB8::1"), &IPNet{ParseIP("2001:DB8::"), IPMask(ParseIP("ffff:ffff:ffff::"))}, nil}, + {"135.104.0.0/32", IPv4(135, 104, 0, 0), &IPNet{IP: IPv4(135, 104, 0, 0), Mask: IPv4Mask(255, 255, 255, 255)}, nil}, + {"0.0.0.0/24", IPv4(0, 0, 0, 0), &IPNet{IP: IPv4(0, 0, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)}, nil}, + {"135.104.0.0/24", IPv4(135, 104, 0, 0), &IPNet{IP: IPv4(135, 104, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)}, nil}, + {"135.104.0.1/32", IPv4(135, 104, 0, 1), &IPNet{IP: IPv4(135, 104, 0, 1), Mask: IPv4Mask(255, 255, 255, 255)}, nil}, + {"135.104.0.1/24", IPv4(135, 104, 0, 1), &IPNet{IP: IPv4(135, 104, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)}, nil}, + {"::1/128", ParseIP("::1"), &IPNet{IP: ParseIP("::1"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"))}, nil}, + {"abcd:2345::/127", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe"))}, nil}, + {"abcd:2345::/65", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:8000::"))}, nil}, + {"abcd:2345::/64", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff::"))}, nil}, + {"abcd:2345::/63", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:fffe::"))}, nil}, + {"abcd:2345::/33", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:8000::"))}, nil}, + {"abcd:2345::/32", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff::"))}, nil}, + {"abcd:2344::/31", ParseIP("abcd:2344::"), &IPNet{IP: ParseIP("abcd:2344::"), Mask: IPMask(ParseIP("ffff:fffe::"))}, nil}, + {"abcd:2300::/24", ParseIP("abcd:2300::"), &IPNet{IP: ParseIP("abcd:2300::"), Mask: IPMask(ParseIP("ffff:ff00::"))}, nil}, + {"abcd:2345::/24", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2300::"), Mask: IPMask(ParseIP("ffff:ff00::"))}, nil}, + {"2001:DB8::/48", ParseIP("2001:DB8::"), &IPNet{IP: ParseIP("2001:DB8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff::"))}, nil}, + {"2001:DB8::1/48", ParseIP("2001:DB8::1"), &IPNet{IP: ParseIP("2001:DB8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff::"))}, nil}, {"192.168.1.1/255.255.255.0", nil, nil, &ParseError{"CIDR address", "192.168.1.1/255.255.255.0"}}, {"192.168.1.1/35", nil, nil, &ParseError{"CIDR address", "192.168.1.1/35"}}, {"2001:db8::1/-1", nil, nil, &ParseError{"CIDR address", "2001:db8::1/-1"}}, @@ -154,14 +154,14 @@ var ipnetcontainstests = []struct { net *IPNet ok bool }{ - {IPv4(172, 16, 1, 1), &IPNet{IPv4(172, 16, 0, 0), CIDRMask(12, 32)}, true}, - {IPv4(172, 24, 0, 1), &IPNet{IPv4(172, 16, 0, 0), CIDRMask(13, 32)}, false}, - {IPv4(192, 168, 0, 3), &IPNet{IPv4(192, 168, 0, 0), IPv4Mask(0, 0, 255, 252)}, true}, - {IPv4(192, 168, 0, 4), &IPNet{IPv4(192, 168, 0, 0), IPv4Mask(0, 255, 0, 252)}, false}, - {ParseIP("2001:db8:1:2::1"), &IPNet{ParseIP("2001:db8:1::"), CIDRMask(47, 128)}, true}, - {ParseIP("2001:db8:1:2::1"), &IPNet{ParseIP("2001:db8:2::"), CIDRMask(47, 128)}, false}, - {ParseIP("2001:db8:1:2::1"), &IPNet{ParseIP("2001:db8:1::"), IPMask(ParseIP("ffff:0:ffff::"))}, true}, - {ParseIP("2001:db8:1:2::1"), &IPNet{ParseIP("2001:db8:1::"), IPMask(ParseIP("0:0:0:ffff::"))}, false}, + {IPv4(172, 16, 1, 1), &IPNet{IP: IPv4(172, 16, 0, 0), Mask: CIDRMask(12, 32)}, true}, + {IPv4(172, 24, 0, 1), &IPNet{IP: IPv4(172, 16, 0, 0), Mask: CIDRMask(13, 32)}, false}, + {IPv4(192, 168, 0, 3), &IPNet{IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(0, 0, 255, 252)}, true}, + {IPv4(192, 168, 0, 4), &IPNet{IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(0, 255, 0, 252)}, false}, + {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:1::"), Mask: CIDRMask(47, 128)}, true}, + {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:2::"), Mask: CIDRMask(47, 128)}, false}, + {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:1::"), Mask: IPMask(ParseIP("ffff:0:ffff::"))}, true}, + {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:1::"), Mask: IPMask(ParseIP("0:0:0:ffff::"))}, false}, } func TestIPNetContains(t *testing.T) { @@ -176,10 +176,10 @@ var ipnetstringtests = []struct { in *IPNet out string }{ - {&IPNet{IPv4(192, 168, 1, 0), CIDRMask(26, 32)}, "192.168.1.0/26"}, - {&IPNet{IPv4(192, 168, 1, 0), IPv4Mask(255, 0, 255, 0)}, "192.168.1.0/ff00ff00"}, - {&IPNet{ParseIP("2001:db8::"), CIDRMask(55, 128)}, "2001:db8::/55"}, - {&IPNet{ParseIP("2001:db8::"), IPMask(ParseIP("8000:f123:0:cafe::"))}, "2001:db8::/8000f1230000cafe0000000000000000"}, + {&IPNet{IP: IPv4(192, 168, 1, 0), Mask: CIDRMask(26, 32)}, "192.168.1.0/26"}, + {&IPNet{IP: IPv4(192, 168, 1, 0), Mask: IPv4Mask(255, 0, 255, 0)}, "192.168.1.0/ff00ff00"}, + {&IPNet{IP: ParseIP("2001:db8::"), Mask: CIDRMask(55, 128)}, "2001:db8::/55"}, + {&IPNet{IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("8000:f123:0:cafe::"))}, "2001:db8::/8000f1230000cafe0000000000000000"}, } func TestIPNetString(t *testing.T) { @@ -233,27 +233,27 @@ var networknumberandmasktests = []struct { in IPNet out IPNet }{ - {IPNet{v4addr, v4mask}, IPNet{v4addr, v4mask}}, - {IPNet{v4addr, v4mappedv6mask}, IPNet{v4addr, v4mask}}, - {IPNet{v4mappedv6addr, v4mappedv6mask}, IPNet{v4addr, v4mask}}, - {IPNet{v4mappedv6addr, v6mask}, IPNet{v4addr, v4maskzero}}, - {IPNet{v4addr, v6mask}, IPNet{v4addr, v4maskzero}}, - {IPNet{v6addr, v6mask}, IPNet{v6addr, v6mask}}, - {IPNet{v6addr, v4mappedv6mask}, IPNet{v6addr, v4mappedv6mask}}, - {in: IPNet{v6addr, v4mask}}, - {in: IPNet{v4addr, badmask}}, - {in: IPNet{v4mappedv6addr, badmask}}, - {in: IPNet{v6addr, badmask}}, - {in: IPNet{badaddr, v4mask}}, - {in: IPNet{badaddr, v4mappedv6mask}}, - {in: IPNet{badaddr, v6mask}}, - {in: IPNet{badaddr, badmask}}, + {IPNet{IP: v4addr, Mask: v4mask}, IPNet{IP: v4addr, Mask: v4mask}}, + {IPNet{IP: v4addr, Mask: v4mappedv6mask}, IPNet{IP: v4addr, Mask: v4mask}}, + {IPNet{IP: v4mappedv6addr, Mask: v4mappedv6mask}, IPNet{IP: v4addr, Mask: v4mask}}, + {IPNet{IP: v4mappedv6addr, Mask: v6mask}, IPNet{IP: v4addr, Mask: v4maskzero}}, + {IPNet{IP: v4addr, Mask: v6mask}, IPNet{IP: v4addr, Mask: v4maskzero}}, + {IPNet{IP: v6addr, Mask: v6mask}, IPNet{IP: v6addr, Mask: v6mask}}, + {IPNet{IP: v6addr, Mask: v4mappedv6mask}, IPNet{IP: v6addr, Mask: v4mappedv6mask}}, + {in: IPNet{IP: v6addr, Mask: v4mask}}, + {in: IPNet{IP: v4addr, Mask: badmask}}, + {in: IPNet{IP: v4mappedv6addr, Mask: badmask}}, + {in: IPNet{IP: v6addr, Mask: badmask}}, + {in: IPNet{IP: badaddr, Mask: v4mask}}, + {in: IPNet{IP: badaddr, Mask: v4mappedv6mask}}, + {in: IPNet{IP: badaddr, Mask: v6mask}}, + {in: IPNet{IP: badaddr, Mask: badmask}}, } func TestNetworkNumberAndMask(t *testing.T) { for _, tt := range networknumberandmasktests { ip, m := networkNumberAndMask(&tt.in) - out := &IPNet{ip, m} + out := &IPNet{IP: ip, Mask: m} if !reflect.DeepEqual(&tt.out, out) { t.Errorf("networkNumberAndMask(%v) = %v; want %v", tt.in, out, &tt.out) } @@ -268,6 +268,29 @@ var splitjointests = []struct { {"www.google.com", "80", "www.google.com:80"}, {"127.0.0.1", "1234", "127.0.0.1:1234"}, {"::1", "80", "[::1]:80"}, + {"google.com", "https%foo", "google.com:https%foo"}, // Go 1.0 behavior + {"", "0", ":0"}, + {"127.0.0.1", "", "127.0.0.1:"}, // Go 1.0 behaviour + {"www.google.com", "", "www.google.com:"}, // Go 1.0 behaviour +} + +var splitfailuretests = []struct { + HostPort string + Err string +}{ + {"www.google.com", "missing port in address"}, + {"127.0.0.1", "missing port in address"}, + {"[::1]", "missing port in address"}, + {"::1", "too many colons in address"}, + + // Test cases that didn't fail in Go 1.0 + {"[foo:bar]", "missing port in address"}, + {"[foo:bar]baz", "missing port in address"}, + {"[foo]:[bar]:baz", "too many colons in address"}, + {"[foo]bar:baz", "missing port in address"}, + {"[foo]:[bar]baz", "unexpected '[' in address"}, + {"foo[bar]:baz", "unexpected '[' in address"}, + {"foo]bar:baz", "unexpected ']' in address"}, } func TestSplitHostPort(t *testing.T) { @@ -276,6 +299,16 @@ func TestSplitHostPort(t *testing.T) { t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.Join, host, port, err, tt.Host, tt.Port) } } + for _, tt := range splitfailuretests { + if _, _, err := SplitHostPort(tt.HostPort); err == nil { + t.Errorf("SplitHostPort(%q) should have failed", tt.HostPort) + } else { + e := err.(*AddrError) + if e.Err != tt.Err { + t.Errorf("SplitHostPort(%q) = _, _, %q; want %q", tt.HostPort, e.Err, tt.Err) + } + } + } } func TestJoinHostPort(t *testing.T) { diff --git a/src/pkg/net/ipraw_test.go b/src/pkg/net/ipraw_test.go index 613620272..65defc7ea 100644 --- a/src/pkg/net/ipraw_test.go +++ b/src/pkg/net/ipraw_test.go @@ -2,205 +2,340 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// +build !plan9 + package net import ( "bytes" + "errors" "os" - "syscall" + "reflect" "testing" "time" ) -var icmpTests = []struct { +var resolveIPAddrTests = []struct { + net string + litAddr string + addr *IPAddr + err error +}{ + {"ip", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, + {"ip4", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, + {"ip4:icmp", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, + + {"ip", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, + {"ip6", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, + {"ip6:icmp", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, + + {"", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, // Go 1.0 behavior + {"", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, // Go 1.0 behavior + + {"l2tp", "127.0.0.1", nil, UnknownNetworkError("l2tp")}, + {"l2tp:gre", "127.0.0.1", nil, UnknownNetworkError("l2tp:gre")}, + {"tcp", "1.2.3.4:123", nil, UnknownNetworkError("tcp")}, +} + +func TestResolveIPAddr(t *testing.T) { + for _, tt := range resolveIPAddrTests { + addr, err := ResolveIPAddr(tt.net, tt.litAddr) + if err != tt.err { + t.Fatalf("ResolveIPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err) + } + if !reflect.DeepEqual(addr, tt.addr) { + t.Fatalf("got %#v; expected %#v", addr, tt.addr) + } + } +} + +var icmpEchoTests = []struct { net string laddr string raddr string - ipv6 bool // test with underlying AF_INET6 socket }{ - {"ip4:icmp", "", "127.0.0.1", false}, - {"ip6:icmp", "", "::1", true}, + {"ip4:icmp", "0.0.0.0", "127.0.0.1"}, + {"ip6:ipv6-icmp", "::", "::1"}, } -func TestICMP(t *testing.T) { +func TestConnICMPEcho(t *testing.T) { if os.Getuid() != 0 { - t.Logf("test disabled; must be root") - return + t.Skip("skipping test; must be root") } - seqnum := 61455 - for _, tt := range icmpTests { - if tt.ipv6 && !supportsIPv6 { + for i, tt := range icmpEchoTests { + net, _, err := parseNetwork(tt.net) + if err != nil { + t.Fatalf("parseNetwork failed: %v", err) + } + if net == "ip6" && !supportsIPv6 { continue } - id := os.Getpid() & 0xffff - seqnum++ - echo := newICMPEchoRequest(tt.net, id, seqnum, 128, []byte("Go Go Gadget Ping!!!")) - exchangeICMPEcho(t, tt.net, tt.laddr, tt.raddr, echo) - } -} - -func exchangeICMPEcho(t *testing.T, net, laddr, raddr string, echo []byte) { - c, err := ListenPacket(net, laddr) - if err != nil { - t.Errorf("ListenPacket(%q, %q) failed: %v", net, laddr, err) - return - } - c.SetDeadline(time.Now().Add(100 * time.Millisecond)) - defer c.Close() - - ra, err := ResolveIPAddr(net, raddr) - if err != nil { - t.Errorf("ResolveIPAddr(%q, %q) failed: %v", net, raddr, err) - return - } - waitForReady := make(chan bool) - go icmpEchoTransponder(t, net, raddr, waitForReady) - <-waitForReady - - _, err = c.WriteTo(echo, ra) - if err != nil { - t.Errorf("WriteTo failed: %v", err) - return - } + c, err := Dial(tt.net, tt.raddr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + c.SetDeadline(time.Now().Add(100 * time.Millisecond)) + defer c.Close() - reply := make([]byte, 256) - for { - _, _, err := c.ReadFrom(reply) + typ := icmpv4EchoRequest + if net == "ip6" { + typ = icmpv6EchoRequest + } + xid, xseq := os.Getpid()&0xffff, i+1 + b, err := (&icmpMessage{ + Type: typ, Code: 0, + Body: &icmpEcho{ + ID: xid, Seq: xseq, + Data: bytes.Repeat([]byte("Go Go Gadget Ping!!!"), 3), + }, + }).Marshal() if err != nil { - t.Errorf("ReadFrom failed: %v", err) - return + t.Fatalf("icmpMessage.Marshal failed: %v", err) } - switch c.(*IPConn).fd.family { - case syscall.AF_INET: - if reply[0] != ICMP4_ECHO_REPLY { - continue + if _, err := c.Write(b); err != nil { + t.Fatalf("Conn.Write failed: %v", err) + } + var m *icmpMessage + for { + if _, err := c.Read(b); err != nil { + t.Fatalf("Conn.Read failed: %v", err) + } + if net == "ip4" { + b = ipv4Payload(b) } - case syscall.AF_INET6: - if reply[0] != ICMP6_ECHO_REPLY { + if m, err = parseICMPMessage(b); err != nil { + t.Fatalf("parseICMPMessage failed: %v", err) + } + switch m.Type { + case icmpv4EchoRequest, icmpv6EchoRequest: continue } + break } - xid, xseqnum := parseICMPEchoReply(echo) - rid, rseqnum := parseICMPEchoReply(reply) - if rid != xid || rseqnum != xseqnum { - t.Errorf("ID = %v, Seqnum = %v, want ID = %v, Seqnum = %v", rid, rseqnum, xid, xseqnum) - return + switch p := m.Body.(type) { + case *icmpEcho: + if p.ID != xid || p.Seq != xseq { + t.Fatalf("got id=%v, seqnum=%v; expected id=%v, seqnum=%v", p.ID, p.Seq, xid, xseq) + } + default: + t.Fatalf("got type=%v, code=%v; expected type=%v, code=%v", m.Type, m.Code, typ, 0) } - break } } -func icmpEchoTransponder(t *testing.T, net, raddr string, waitForReady chan bool) { - c, err := Dial(net, raddr) - if err != nil { - waitForReady <- true - t.Errorf("Dial(%q, %q) failed: %v", net, raddr, err) - return +func TestPacketConnICMPEcho(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") } - c.SetDeadline(time.Now().Add(100 * time.Millisecond)) - defer c.Close() - waitForReady <- true - echo := make([]byte, 256) - var nr int - for { - nr, err = c.Read(echo) + for i, tt := range icmpEchoTests { + net, _, err := parseNetwork(tt.net) if err != nil { - t.Errorf("Read failed: %v", err) - return + t.Fatalf("parseNetwork failed: %v", err) } - switch c.(*IPConn).fd.family { - case syscall.AF_INET: - if echo[0] != ICMP4_ECHO_REQUEST { - continue + if net == "ip6" && !supportsIPv6 { + continue + } + + c, err := ListenPacket(tt.net, tt.laddr) + if err != nil { + t.Fatalf("ListenPacket failed: %v", err) + } + c.SetDeadline(time.Now().Add(100 * time.Millisecond)) + defer c.Close() + + ra, err := ResolveIPAddr(tt.net, tt.raddr) + if err != nil { + t.Fatalf("ResolveIPAddr failed: %v", err) + } + typ := icmpv4EchoRequest + if net == "ip6" { + typ = icmpv6EchoRequest + } + xid, xseq := os.Getpid()&0xffff, i+1 + b, err := (&icmpMessage{ + Type: typ, Code: 0, + Body: &icmpEcho{ + ID: xid, Seq: xseq, + Data: bytes.Repeat([]byte("Go Go Gadget Ping!!!"), 3), + }, + }).Marshal() + if err != nil { + t.Fatalf("icmpMessage.Marshal failed: %v", err) + } + if _, err := c.WriteTo(b, ra); err != nil { + t.Fatalf("PacketConn.WriteTo failed: %v", err) + } + var m *icmpMessage + for { + if _, _, err := c.ReadFrom(b); err != nil { + t.Fatalf("PacketConn.ReadFrom failed: %v", err) } - case syscall.AF_INET6: - if echo[0] != ICMP6_ECHO_REQUEST { + // TODO: fix issue 3944 + //if net == "ip4" { + // b = ipv4Payload(b) + //} + if m, err = parseICMPMessage(b); err != nil { + t.Fatalf("parseICMPMessage failed: %v", err) + } + switch m.Type { + case icmpv4EchoRequest, icmpv6EchoRequest: continue } + break + } + switch p := m.Body.(type) { + case *icmpEcho: + if p.ID != xid || p.Seq != xseq { + t.Fatalf("got id=%v, seqnum=%v; expected id=%v, seqnum=%v", p.ID, p.Seq, xid, xseq) + } + default: + t.Fatalf("got type=%v, code=%v; expected type=%v, code=%v", m.Type, m.Code, typ, 0) } - break - } - - switch c.(*IPConn).fd.family { - case syscall.AF_INET: - echo[0] = ICMP4_ECHO_REPLY - case syscall.AF_INET6: - echo[0] = ICMP6_ECHO_REPLY } +} - _, err = c.Write(echo[:nr]) - if err != nil { - t.Errorf("Write failed: %v", err) - return +func ipv4Payload(b []byte) []byte { + if len(b) < 20 { + return b } + hdrlen := int(b[0]&0x0f) << 2 + return b[hdrlen:] } const ( - ICMP4_ECHO_REQUEST = 8 - ICMP4_ECHO_REPLY = 0 - ICMP6_ECHO_REQUEST = 128 - ICMP6_ECHO_REPLY = 129 + icmpv4EchoRequest = 8 + icmpv4EchoReply = 0 + icmpv6EchoRequest = 128 + icmpv6EchoReply = 129 ) -func newICMPEchoRequest(net string, id, seqnum, msglen int, filler []byte) []byte { - afnet, _, _ := parseDialNetwork(net) - switch afnet { - case "ip4": - return newICMPv4EchoRequest(id, seqnum, msglen, filler) - case "ip6": - return newICMPv6EchoRequest(id, seqnum, msglen, filler) - } - return nil +// icmpMessage represents an ICMP message. +type icmpMessage struct { + Type int // type + Code int // code + Checksum int // checksum + Body icmpMessageBody // body } -func newICMPv4EchoRequest(id, seqnum, msglen int, filler []byte) []byte { - b := newICMPInfoMessage(id, seqnum, msglen, filler) - b[0] = ICMP4_ECHO_REQUEST +// icmpMessageBody represents an ICMP message body. +type icmpMessageBody interface { + Len() int + Marshal() ([]byte, error) +} - // calculate ICMP checksum - cklen := len(b) +// Marshal returns the binary enconding of the ICMP echo request or +// reply message m. +func (m *icmpMessage) Marshal() ([]byte, error) { + b := []byte{byte(m.Type), byte(m.Code), 0, 0} + if m.Body != nil && m.Body.Len() != 0 { + mb, err := m.Body.Marshal() + if err != nil { + return nil, err + } + b = append(b, mb...) + } + switch m.Type { + case icmpv6EchoRequest, icmpv6EchoReply: + return b, nil + } + csumcv := len(b) - 1 // checksum coverage s := uint32(0) - for i := 0; i < cklen-1; i += 2 { + for i := 0; i < csumcv; i += 2 { s += uint32(b[i+1])<<8 | uint32(b[i]) } - if cklen&1 == 1 { - s += uint32(b[cklen-1]) + if csumcv&1 == 0 { + s += uint32(b[csumcv]) } - s = (s >> 16) + (s & 0xffff) - s = s + (s >> 16) - // place checksum back in header; using ^= avoids the - // assumption the checksum bytes are zero - b[2] ^= uint8(^s & 0xff) - b[3] ^= uint8(^s >> 8) + s = s>>16 + s&0xffff + s = s + s>>16 + // Place checksum back in header; using ^= avoids the + // assumption the checksum bytes are zero. + b[2] ^= byte(^s & 0xff) + b[3] ^= byte(^s >> 8) + return b, nil +} - return b +// parseICMPMessage parses b as an ICMP message. +func parseICMPMessage(b []byte) (*icmpMessage, error) { + msglen := len(b) + if msglen < 4 { + return nil, errors.New("message too short") + } + m := &icmpMessage{Type: int(b[0]), Code: int(b[1]), Checksum: int(b[2])<<8 | int(b[3])} + if msglen > 4 { + var err error + switch m.Type { + case icmpv4EchoRequest, icmpv4EchoReply, icmpv6EchoRequest, icmpv6EchoReply: + m.Body, err = parseICMPEcho(b[4:]) + if err != nil { + return nil, err + } + } + } + return m, nil +} + +// imcpEcho represenets an ICMP echo request or reply message body. +type icmpEcho struct { + ID int // identifier + Seq int // sequence number + Data []byte // data +} + +func (p *icmpEcho) Len() int { + if p == nil { + return 0 + } + return 4 + len(p.Data) } -func newICMPv6EchoRequest(id, seqnum, msglen int, filler []byte) []byte { - b := newICMPInfoMessage(id, seqnum, msglen, filler) - b[0] = ICMP6_ECHO_REQUEST - return b +// Marshal returns the binary enconding of the ICMP echo request or +// reply message body p. +func (p *icmpEcho) Marshal() ([]byte, error) { + b := make([]byte, 4+len(p.Data)) + b[0], b[1] = byte(p.ID>>8), byte(p.ID&0xff) + b[2], b[3] = byte(p.Seq>>8), byte(p.Seq&0xff) + copy(b[4:], p.Data) + return b, nil } -func newICMPInfoMessage(id, seqnum, msglen int, filler []byte) []byte { - b := make([]byte, msglen) - copy(b[8:], bytes.Repeat(filler, (msglen-8)/len(filler)+1)) - b[0] = 0 // type - b[1] = 0 // code - b[2] = 0 // checksum - b[3] = 0 // checksum - b[4] = uint8(id >> 8) // identifier - b[5] = uint8(id & 0xff) // identifier - b[6] = uint8(seqnum >> 8) // sequence number - b[7] = uint8(seqnum & 0xff) // sequence number - return b +// parseICMPEcho parses b as an ICMP echo request or reply message +// body. +func parseICMPEcho(b []byte) (*icmpEcho, error) { + bodylen := len(b) + p := &icmpEcho{ID: int(b[0])<<8 | int(b[1]), Seq: int(b[2])<<8 | int(b[3])} + if bodylen > 4 { + p.Data = make([]byte, bodylen-4) + copy(p.Data, b[4:]) + } + return p, nil } -func parseICMPEchoReply(b []byte) (id, seqnum int) { - id = int(b[4])<<8 | int(b[5]) - seqnum = int(b[6])<<8 | int(b[7]) - return +var ipConnLocalNameTests = []struct { + net string + laddr *IPAddr +}{ + {"ip4:icmp", &IPAddr{IP: IPv4(127, 0, 0, 1)}}, + {"ip4:icmp", &IPAddr{}}, + {"ip4:icmp", nil}, +} + +func TestIPConnLocalName(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") + } + + for _, tt := range ipConnLocalNameTests { + c, err := ListenIP(tt.net, tt.laddr) + if err != nil { + t.Fatalf("ListenIP failed: %v", err) + } + defer c.Close() + if la := c.LocalAddr(); la == nil { + t.Fatal("IPConn.LocalAddr failed") + } + } } diff --git a/src/pkg/net/iprawsock.go b/src/pkg/net/iprawsock.go index b23213ee1..daccba366 100644 --- a/src/pkg/net/iprawsock.go +++ b/src/pkg/net/iprawsock.go @@ -2,13 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// (Raw) IP sockets +// Raw IP sockets package net -// IPAddr represents the address of a IP end point. +// IPAddr represents the address of an IP end point. type IPAddr struct { - IP IP + IP IP + Zone string // IPv6 scoped addressing zone } // Network returns the address's network name, "ip". @@ -21,45 +22,25 @@ func (a *IPAddr) String() string { return a.IP.String() } -// ResolveIPAddr parses addr as a IP address and resolves domain +// ResolveIPAddr parses addr as an IP address and resolves domain // names to numeric addresses on the network net, which must be -// "ip", "ip4" or "ip6". A literal IPv6 host address must be -// enclosed in square brackets, as in "[::]". +// "ip", "ip4" or "ip6". func ResolveIPAddr(net, addr string) (*IPAddr, error) { - ip, err := hostToIP(net, addr) + if net == "" { // a hint wildcard for Go 1.0 undocumented behavior + net = "ip" + } + afnet, _, err := parseNetwork(net) if err != nil { return nil, err } - return &IPAddr{ip}, nil -} - -// Convert "host" into IP address. -func hostToIP(net, host string) (ip IP, err error) { - var addr IP - // Try as an IP address. - addr = ParseIP(host) - if addr == nil { - filter := anyaddr - if net != "" && net[len(net)-1] == '4' { - filter = ipv4only - } - if net != "" && net[len(net)-1] == '6' { - filter = ipv6only - } - // Not an IP address. Try as a DNS name. - addrs, err1 := LookupHost(host) - if err1 != nil { - err = err1 - goto Error - } - addr = firstFavoriteAddr(filter, addrs) - if addr == nil { - // should not happen - err = &AddrError{"LookupHost returned no suitable address", addrs[0]} - goto Error - } + switch afnet { + case "ip", "ip4", "ip6": + default: + return nil, UnknownNetworkError(net) + } + a, err := resolveInternetAddr(afnet, addr, noDeadline) + if err != nil { + return nil, err } - return addr, nil -Error: - return nil, err + return a.(*IPAddr), nil } diff --git a/src/pkg/net/iprawsock_plan9.go b/src/pkg/net/iprawsock_plan9.go index 43719fc99..88e3b2c60 100644 --- a/src/pkg/net/iprawsock_plan9.go +++ b/src/pkg/net/iprawsock_plan9.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// (Raw) IP sockets stubs for Plan 9 +// Raw IP sockets for Plan 9 package net @@ -11,55 +11,13 @@ import ( "time" ) -// IPConn is the implementation of the Conn and PacketConn -// interfaces for IP network connections. -type IPConn bool - -// SetDeadline implements the Conn SetDeadline method. -func (c *IPConn) SetDeadline(t time.Time) error { - return syscall.EPLAN9 -} - -// SetReadDeadline implements the Conn SetReadDeadline method. -func (c *IPConn) SetReadDeadline(t time.Time) error { - return syscall.EPLAN9 -} - -// SetWriteDeadline implements the Conn SetWriteDeadline method. -func (c *IPConn) SetWriteDeadline(t time.Time) error { - return syscall.EPLAN9 -} - -// Implementation of the Conn interface - see Conn for documentation. - -// Read implements the Conn Read method. -func (c *IPConn) Read(b []byte) (int, error) { - return 0, syscall.EPLAN9 -} - -// Write implements the Conn Write method. -func (c *IPConn) Write(b []byte) (int, error) { - return 0, syscall.EPLAN9 -} - -// Close closes the IP connection. -func (c *IPConn) Close() error { - return syscall.EPLAN9 -} - -// LocalAddr returns the local network address. -func (c *IPConn) LocalAddr() Addr { - return nil +// IPConn is the implementation of the Conn and PacketConn interfaces +// for IP network connections. +type IPConn struct { + conn } -// RemoteAddr returns the remote network address, a *IPAddr. -func (c *IPConn) RemoteAddr() Addr { - return nil -} - -// IP-specific methods. - -// ReadFromIP reads a IP packet from c, copying the payload into b. +// ReadFromIP reads an IP packet from c, copying the payload into b. // It returns the number of bytes copied into b and the return address // that was on the packet. // @@ -75,12 +33,21 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { return 0, nil, syscall.EPLAN9 } -// WriteToIP writes a IP packet to addr via c, copying the payload from b. +// ReadMsgIP reads a packet from c, copying the payload into b and the +// associdated out-of-band data into oob. It returns the number of +// bytes copied into b, the number of bytes copied into oob, the flags +// that were set on the packet and the source address of the packet. +func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) { + return 0, 0, 0, nil, syscall.EPLAN9 +} + +// WriteToIP writes an IP packet to addr via c, copying the payload +// from b. // -// WriteToIP can be made to time out and return -// an error with Timeout() == true after a fixed time limit; -// see SetDeadline and SetWriteDeadline. -// On packet-oriented connections, write timeouts are rare. +// WriteToIP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetWriteDeadline. On packet-oriented connections, write timeouts +// are rare. func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { return 0, syscall.EPLAN9 } @@ -90,16 +57,28 @@ func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) { return 0, syscall.EPLAN9 } -// DialIP connects to the remote address raddr on the network protocol netProto, -// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name. +// WriteMsgIP writes a packet to addr via c, copying the payload from +// b and the associated out-of-band data from oob. It returns the +// number of payload and out-of-band bytes written. +func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error) { + return 0, 0, syscall.EPLAN9 +} + +// DialIP connects to the remote address raddr on the network protocol +// netProto, which must be "ip", "ip4", or "ip6" followed by a colon +// and a protocol number or name. func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { + return dialIP(netProto, laddr, raddr, noDeadline) +} + +func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) { return nil, syscall.EPLAN9 } -// ListenIP listens for incoming IP packets addressed to the -// local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send IP -// packets with per-packet addressing. +// ListenIP listens for incoming IP packets addressed to the local +// address laddr. The returned connection c's ReadFrom and WriteTo +// methods can be used to receive and send IP packets with per-packet +// addressing. func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { return nil, syscall.EPLAN9 } diff --git a/src/pkg/net/iprawsock_posix.go b/src/pkg/net/iprawsock_posix.go index 9fc7ecdb9..2ef4db19c 100644 --- a/src/pkg/net/iprawsock_posix.go +++ b/src/pkg/net/iprawsock_posix.go @@ -4,12 +4,11 @@ // +build darwin freebsd linux netbsd openbsd windows -// (Raw) IP sockets +// Raw IP sockets for POSIX package net import ( - "os" "syscall" "time" ) @@ -17,9 +16,9 @@ import ( func sockaddrToIP(sa syscall.Sockaddr) Addr { switch sa := sa.(type) { case *syscall.SockaddrInet4: - return &IPAddr{sa.Addr[0:]} + return &IPAddr{IP: sa.Addr[0:]} case *syscall.SockaddrInet6: - return &IPAddr{sa.Addr[0:]} + return &IPAddr{IP: sa.Addr[0:], Zone: zoneToString(int(sa.ZoneId))} } return nil } @@ -42,7 +41,7 @@ func (a *IPAddr) isWildcard() bool { } func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) { - return ipToSockaddr(family, a.IP, 0) + return ipToSockaddr(family, a.IP, 0, a.Zone) } func (a *IPAddr) toAddr() sockaddr { @@ -55,98 +54,12 @@ func (a *IPAddr) toAddr() sockaddr { // IPConn is the implementation of the Conn and PacketConn // interfaces for IP network connections. type IPConn struct { - fd *netFD + conn } -func newIPConn(fd *netFD) *IPConn { return &IPConn{fd} } +func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} } -func (c *IPConn) ok() bool { return c != nil && c.fd != nil } - -// Implementation of the Conn interface - see Conn for documentation. - -// Read implements the Conn Read method. -func (c *IPConn) Read(b []byte) (int, error) { - n, _, err := c.ReadFrom(b) - return n, err -} - -// Write implements the Conn Write method. -func (c *IPConn) Write(b []byte) (int, error) { - if !c.ok() { - return 0, syscall.EINVAL - } - return c.fd.Write(b) -} - -// Close closes the IP connection. -func (c *IPConn) Close() error { - if !c.ok() { - return syscall.EINVAL - } - return c.fd.Close() -} - -// LocalAddr returns the local network address. -func (c *IPConn) LocalAddr() Addr { - if !c.ok() { - return nil - } - return c.fd.laddr -} - -// RemoteAddr returns the remote network address, a *IPAddr. -func (c *IPConn) RemoteAddr() Addr { - if !c.ok() { - return nil - } - return c.fd.raddr -} - -// SetDeadline implements the Conn SetDeadline method. -func (c *IPConn) SetDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setDeadline(c.fd, t) -} - -// SetReadDeadline implements the Conn SetReadDeadline method. -func (c *IPConn) SetReadDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setReadDeadline(c.fd, t) -} - -// SetWriteDeadline implements the Conn SetWriteDeadline method. -func (c *IPConn) SetWriteDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setWriteDeadline(c.fd, t) -} - -// SetReadBuffer sets the size of the operating system's -// receive buffer associated with the connection. -func (c *IPConn) SetReadBuffer(bytes int) error { - if !c.ok() { - return syscall.EINVAL - } - return setReadBuffer(c.fd, bytes) -} - -// SetWriteBuffer sets the size of the operating system's -// transmit buffer associated with the connection. -func (c *IPConn) SetWriteBuffer(bytes int) error { - if !c.ok() { - return syscall.EINVAL - } - return setWriteBuffer(c.fd, bytes) -} - -// IP-specific methods. - -// ReadFromIP reads a IP packet from c, copying the payload into b. +// ReadFromIP reads an IP packet from c, copying the payload into b. // It returns the number of bytes copied into b and the return address // that was on the packet. // @@ -163,14 +76,14 @@ func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) { n, sa, err := c.fd.ReadFrom(b) switch sa := sa.(type) { case *syscall.SockaddrInet4: - addr = &IPAddr{sa.Addr[0:]} + addr = &IPAddr{IP: sa.Addr[0:]} if len(b) >= IPv4len { // discard ipv4 header hsize := (int(b[0]) & 0xf) * 4 copy(b, b[hsize:]) n -= hsize } case *syscall.SockaddrInet6: - addr = &IPAddr{sa.Addr[0:]} + addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneToString(int(sa.ZoneId))} } return n, addr, err } @@ -180,11 +93,30 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { if !c.ok() { return 0, nil, syscall.EINVAL } - n, uaddr, err := c.ReadFromIP(b) - return n, uaddr.toAddr(), err + n, addr, err := c.ReadFromIP(b) + return n, addr.toAddr(), err +} + +// ReadMsgIP reads a packet from c, copying the payload into b and the +// associdated out-of-band data into oob. It returns the number of +// bytes copied into b, the number of bytes copied into oob, the flags +// that were set on the packet and the source address of the packet. +func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) { + if !c.ok() { + return 0, 0, 0, nil, syscall.EINVAL + } + var sa syscall.Sockaddr + n, oobn, flags, sa, err = c.fd.ReadMsg(b, oob) + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + addr = &IPAddr{IP: sa.Addr[0:]} + case *syscall.SockaddrInet6: + addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneToString(int(sa.ZoneId))} + } + return } -// WriteToIP writes a IP packet to addr via c, copying the payload from b. +// WriteToIP writes an IP packet to addr via c, copying the payload from b. // // WriteToIP can be made to time out and return // an error with Timeout() == true after a fixed time limit; @@ -213,22 +145,40 @@ func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) { return c.WriteToIP(b, a) } +// WriteMsgIP writes a packet to addr via c, copying the payload from +// b and the associated out-of-band data from oob. It returns the +// number of payload and out-of-band bytes written. +func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error) { + if !c.ok() { + return 0, 0, syscall.EINVAL + } + sa, err := addr.sockaddr(c.fd.family) + if err != nil { + return 0, 0, &OpError{"write", c.fd.net, addr, err} + } + return c.fd.WriteMsg(b, oob, sa) +} + // DialIP connects to the remote address raddr on the network protocol netProto, // which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name. func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { - net, proto, err := parseDialNetwork(netProto) + return dialIP(netProto, laddr, raddr, noDeadline) +} + +func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) { + net, proto, err := parseNetwork(netProto) if err != nil { return nil, err } switch net { case "ip", "ip4", "ip6": default: - return nil, UnknownNetworkError(net) + return nil, UnknownNetworkError(netProto) } if raddr == nil { return nil, &OpError{"dial", netProto, nil, errMissingAddress} } - fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_RAW, proto, "dial", sockaddrToIP) + fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_RAW, proto, "dial", sockaddrToIP) if err != nil { return nil, err } @@ -240,23 +190,18 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { // and WriteTo methods can be used to receive and send IP // packets with per-packet addressing. func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { - net, proto, err := parseDialNetwork(netProto) + net, proto, err := parseNetwork(netProto) if err != nil { return nil, err } switch net { case "ip", "ip4", "ip6": default: - return nil, UnknownNetworkError(net) + return nil, UnknownNetworkError(netProto) } - fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_RAW, proto, "listen", sockaddrToIP) + fd, err := internetSocket(net, laddr.toAddr(), nil, noDeadline, syscall.SOCK_RAW, proto, "listen", sockaddrToIP) if err != nil { return nil, err } return newIPConn(fd), nil } - -// File returns a copy of the underlying os.File, set to blocking mode. -// It is the caller's responsibility to close f when finished. -// Closing c does not affect f, and closing f does not affect c. -func (c *IPConn) File() (f *os.File, err error) { return c.fd.dup() } diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go index bfbce18a4..1ef489289 100644 --- a/src/pkg/net/ipsock.go +++ b/src/pkg/net/ipsock.go @@ -2,11 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// IP sockets +// Internet protocol family sockets package net -var supportsIPv6, supportsIPv4map = probeIPv6Stack() +import "time" + +var supportsIPv6, supportsIPv4map bool + +func init() { + sysInit() + supportsIPv6, supportsIPv4map = probeIPv6Stack() +} func firstFavoriteAddr(filter func(IP) IP, addrs []string) (addr IP) { if filter == nil { @@ -65,25 +72,67 @@ func (e InvalidAddrError) Temporary() bool { return false } // "host:port" or "[host]:port" into host and port. // The latter form must be used when host contains a colon. func SplitHostPort(hostport string) (host, port string, err error) { + host, port, _, err = splitHostPort(hostport) + return +} + +func splitHostPort(hostport string) (host, port, zone string, err error) { + j, k := 0, 0 + // The port starts after the last colon. i := last(hostport, ':') if i < 0 { - err = &AddrError{"missing port in address", hostport} - return + goto missingPort } - host, port = hostport[0:i], hostport[i+1:] - - // Can put brackets around host ... - if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { - host = host[1 : len(host)-1] + if hostport[0] == '[' { + // Expect the first ']' just before the last ':'. + end := byteIndex(hostport, ']') + if end < 0 { + err = &AddrError{"missing ']' in address", hostport} + return + } + switch end + 1 { + case len(hostport): + // There can't be a ':' behind the ']' now. + goto missingPort + case i: + // The expected result. + default: + // Either ']' isn't followed by a colon, or it is + // followed by a colon that is not the last one. + if hostport[end+1] == ':' { + goto tooManyColons + } + goto missingPort + } + host = hostport[1:end] + j, k = 1, end+1 // there can't be a '[' resp. ']' before these positions } else { - // ... but if there are no brackets, no colons. + host = hostport[:i] + if byteIndex(host, ':') >= 0 { - err = &AddrError{"too many colons in address", hostport} - return + goto tooManyColons } } + if byteIndex(hostport[j:], '[') >= 0 { + err = &AddrError{"unexpected '[' in address", hostport} + return + } + if byteIndex(hostport[k:], ']') >= 0 { + err = &AddrError{"unexpected ']' in address", hostport} + return + } + + port = hostport[i+1:] + return + +missingPort: + err = &AddrError{"missing port in address", hostport} + return + +tooManyColons: + err = &AddrError{"too many colons in address", hostport} return } @@ -97,49 +146,84 @@ func JoinHostPort(host, port string) string { return host + ":" + port } -// Convert "host:port" into IP address and port. -func hostPortToIP(net, hostport string) (ip IP, iport int, err error) { - host, port, err := SplitHostPort(hostport) - if err != nil { - return nil, 0, err - } - - var addr IP - if host != "" { - // Try as an IP address. - addr = ParseIP(host) - if addr == nil { - var filter func(IP) IP - if net != "" && net[len(net)-1] == '4' { - filter = ipv4only +func resolveInternetAddr(net, addr string, deadline time.Time) (Addr, error) { + var ( + err error + host, port, zone string + portnum int + ) + switch net { + case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": + if addr != "" { + if host, port, zone, err = splitHostPort(addr); err != nil { + return nil, err } - if net != "" && net[len(net)-1] == '6' { - filter = ipv6only - } - // Not an IP address. Try as a DNS name. - addrs, err := LookupHost(host) - if err != nil { - return nil, 0, err - } - addr = firstFavoriteAddr(filter, addrs) - if addr == nil { - // should not happen - return nil, 0, &AddrError{"LookupHost returned no suitable address", addrs[0]} + if portnum, err = parsePort(net, port); err != nil { + return nil, err } } + case "ip", "ip4", "ip6": + if addr != "" { + host = addr + } + default: + return nil, UnknownNetworkError(net) } - - p, i, ok := dtoi(port, 0) - if !ok || i != len(port) { - p, err = LookupPort(net, port) - if err != nil { - return nil, 0, err + inetaddr := func(net string, ip IP, port int, zone string) Addr { + switch net { + case "tcp", "tcp4", "tcp6": + return &TCPAddr{IP: ip, Port: port, Zone: zone} + case "udp", "udp4", "udp6": + return &UDPAddr{IP: ip, Port: port, Zone: zone} + case "ip", "ip4", "ip6": + return &IPAddr{IP: ip, Zone: zone} } + return nil + } + if host == "" { + return inetaddr(net, nil, portnum, zone), nil } - if p < 0 || p > 0xFFFF { - return nil, 0, &AddrError{"invalid port", port} + // Try as an IP address. + if ip := ParseIP(host); ip != nil { + return inetaddr(net, ip, portnum, zone), nil } + var filter func(IP) IP + if net != "" && net[len(net)-1] == '4' { + filter = ipv4only + } + if net != "" && net[len(net)-1] == '6' { + filter = ipv6only + } + // Try as a DNS name. + addrs, err := lookupHostDeadline(host, deadline) + if err != nil { + return nil, err + } + ip := firstFavoriteAddr(filter, addrs) + if ip == nil { + // should not happen + return nil, &AddrError{"LookupHost returned no suitable address", addrs[0]} + } + return inetaddr(net, ip, portnum, zone), nil +} - return addr, p, nil +func zoneToString(zone int) string { + if zone == 0 { + return "" + } + if ifi, err := InterfaceByIndex(zone); err == nil { + return ifi.Name + } + return itod(uint(zone)) +} +func zoneToInt(zone string) int { + if zone == "" { + return 0 + } + if ifi, err := InterfaceByName(zone); err == nil { + return ifi.Index + } + n, _, _ := dtoi(zone, 0) + return n } diff --git a/src/pkg/net/ipsock_plan9.go b/src/pkg/net/ipsock_plan9.go index eab0bf3e8..c7d542dab 100644 --- a/src/pkg/net/ipsock_plan9.go +++ b/src/pkg/net/ipsock_plan9.go @@ -2,21 +2,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// IP sockets stubs for Plan 9 +// Internet protocol family sockets for Plan 9 package net import ( "errors" - "io" "os" "syscall" - "time" ) -// probeIPv6Stack returns two boolean values. If the first boolean value is -// true, kernel supports basic IPv6 functionality. If the second -// boolean value is true, kernel supports IPv6 IPv4-mapping. +// /sys/include/ape/sys/socket.h:/SOMAXCONN +var listenerBacklog = 5 + +// probeIPv6Stack returns two boolean values. If the first boolean +// value is true, kernel supports basic IPv6 functionality. If the +// second boolean value is true, kernel supports IPv6 IPv4-mapping. func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { return false, false } @@ -48,6 +49,7 @@ func readPlan9Addr(proto, filename string) (addr Addr, err error) { if err != nil { return } + defer f.Close() n, err := f.Read(buf[:]) if err != nil { return @@ -58,110 +60,15 @@ func readPlan9Addr(proto, filename string) (addr Addr, err error) { } switch proto { case "tcp": - addr = &TCPAddr{ip, port} + addr = &TCPAddr{IP: ip, Port: port} case "udp": - addr = &UDPAddr{ip, port} + addr = &UDPAddr{IP: ip, Port: port} default: return nil, errors.New("unknown protocol " + proto) } return addr, nil } -type plan9Conn struct { - proto, name, dir string - ctl, data *os.File - laddr, raddr Addr -} - -func newPlan9Conn(proto, name string, ctl *os.File, laddr, raddr Addr) *plan9Conn { - return &plan9Conn{proto, name, "/net/" + proto + "/" + name, ctl, nil, laddr, raddr} -} - -func (c *plan9Conn) ok() bool { return c != nil && c.ctl != nil } - -// Implementation of the Conn interface - see Conn for documentation. - -// Read implements the Conn Read method. -func (c *plan9Conn) Read(b []byte) (n int, err error) { - if !c.ok() { - return 0, syscall.EINVAL - } - if c.data == nil { - c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0) - if err != nil { - return 0, err - } - } - n, err = c.data.Read(b) - if c.proto == "udp" && err == io.EOF { - n = 0 - err = nil - } - return -} - -// Write implements the Conn Write method. -func (c *plan9Conn) Write(b []byte) (n int, err error) { - if !c.ok() { - return 0, syscall.EINVAL - } - if c.data == nil { - c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0) - if err != nil { - return 0, err - } - } - return c.data.Write(b) -} - -// Close closes the connection. -func (c *plan9Conn) Close() error { - if !c.ok() { - return syscall.EINVAL - } - err := c.ctl.Close() - if err != nil { - return err - } - if c.data != nil { - err = c.data.Close() - } - c.ctl = nil - c.data = nil - return err -} - -// LocalAddr returns the local network address. -func (c *plan9Conn) LocalAddr() Addr { - if !c.ok() { - return nil - } - return c.laddr -} - -// RemoteAddr returns the remote network address. -func (c *plan9Conn) RemoteAddr() Addr { - if !c.ok() { - return nil - } - return c.raddr -} - -// SetDeadline implements the Conn SetDeadline method. -func (c *plan9Conn) SetDeadline(t time.Time) error { - return syscall.EPLAN9 -} - -// SetReadDeadline implements the Conn SetReadDeadline method. -func (c *plan9Conn) SetReadDeadline(t time.Time) error { - return syscall.EPLAN9 -} - -// SetWriteDeadline implements the Conn SetWriteDeadline method. -func (c *plan9Conn) SetWriteDeadline(t time.Time) error { - return syscall.EPLAN9 -} - func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) { var ( ip IP @@ -192,98 +99,95 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, var buf [16]byte n, err := f.Read(buf[:]) if err != nil { + f.Close() return } return f, dest, proto, string(buf[:n]), nil } -func dialPlan9(net string, laddr, raddr Addr) (c *plan9Conn, err error) { +func netErr(e error) { + oe, ok := e.(*OpError) + if !ok { + return + } + if pe, ok := oe.Err.(*os.PathError); ok { + if _, ok = pe.Err.(syscall.ErrorString); ok { + oe.Err = pe.Err + } + } +} + +func dialPlan9(net string, laddr, raddr Addr) (fd *netFD, err error) { + defer func() { netErr(err) }() f, dest, proto, name, err := startPlan9(net, raddr) if err != nil { - return + return nil, &OpError{"dial", net, raddr, err} } _, err = f.WriteString("connect " + dest) if err != nil { - return + f.Close() + return nil, &OpError{"dial", f.Name(), raddr, err} } - laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local") + data, err := os.OpenFile("/net/"+proto+"/"+name+"/data", os.O_RDWR, 0) if err != nil { - return + f.Close() + return nil, &OpError{"dial", net, raddr, err} } - raddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/remote") + laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local") if err != nil { - return + data.Close() + f.Close() + return nil, &OpError{"dial", proto, raddr, err} } - return newPlan9Conn(proto, name, f, laddr, raddr), nil -} - -type plan9Listener struct { - proto, name, dir string - ctl *os.File - laddr Addr + return newFD(proto, name, f, data, laddr, raddr), nil } -func listenPlan9(net string, laddr Addr) (l *plan9Listener, err error) { +func listenPlan9(net string, laddr Addr) (fd *netFD, err error) { + defer func() { netErr(err) }() f, dest, proto, name, err := startPlan9(net, laddr) if err != nil { - return + return nil, &OpError{"listen", net, laddr, err} } _, err = f.WriteString("announce " + dest) if err != nil { - return + f.Close() + return nil, &OpError{"announce", proto, laddr, err} } laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local") if err != nil { - return + f.Close() + return nil, &OpError{Op: "listen", Net: net, Err: err} } - l = new(plan9Listener) - l.proto = proto - l.name = name - l.dir = "/net/" + proto + "/" + name - l.ctl = f - l.laddr = laddr - return l, nil + return newFD(proto, name, f, nil, laddr, nil), nil } -func (l *plan9Listener) plan9Conn() *plan9Conn { - return newPlan9Conn(l.proto, l.name, l.ctl, l.laddr, nil) +func (l *netFD) netFD() *netFD { + return newFD(l.proto, l.name, l.ctl, l.data, l.laddr, l.raddr) } -func (l *plan9Listener) acceptPlan9() (c *plan9Conn, err error) { +func (l *netFD) acceptPlan9() (fd *netFD, err error) { + defer func() { netErr(err) }() f, err := os.Open(l.dir + "/listen") if err != nil { - return + return nil, &OpError{"accept", l.dir + "/listen", l.laddr, err} } var buf [16]byte n, err := f.Read(buf[:]) if err != nil { - return + f.Close() + return nil, &OpError{"accept", l.dir + "/listen", l.laddr, err} } name := string(buf[:n]) - laddr, err := readPlan9Addr(l.proto, l.dir+"/local") - if err != nil { - return - } - raddr, err := readPlan9Addr(l.proto, l.dir+"/remote") + data, err := os.OpenFile("/net/"+l.proto+"/"+name+"/data", os.O_RDWR, 0) if err != nil { - return + f.Close() + return nil, &OpError{"accept", l.proto, l.laddr, err} } - return newPlan9Conn(l.proto, name, f, laddr, raddr), nil -} - -func (l *plan9Listener) Accept() (c Conn, err error) { - c1, err := l.acceptPlan9() + raddr, err := readPlan9Addr(l.proto, "/net/"+l.proto+"/"+name+"/remote") if err != nil { - return + data.Close() + f.Close() + return nil, &OpError{"accept", l.proto, l.laddr, err} } - return c1, nil + return newFD(l.proto, name, f, data, l.laddr, raddr), nil } - -func (l *plan9Listener) Close() error { - if l == nil || l.ctl == nil { - return syscall.EINVAL - } - return l.ctl.Close() -} - -func (l *plan9Listener) Addr() Addr { return l.laddr } diff --git a/src/pkg/net/ipsock_posix.go b/src/pkg/net/ipsock_posix.go index ed313195c..4c37616ec 100644 --- a/src/pkg/net/ipsock_posix.go +++ b/src/pkg/net/ipsock_posix.go @@ -4,9 +4,14 @@ // +build darwin freebsd linux netbsd openbsd windows +// Internet protocol family sockets for POSIX + package net -import "syscall" +import ( + "syscall" + "time" +) // Should we try to use the IPv4 socket interface if we're // only dealing with IPv4 sockets? As long as the host system @@ -97,10 +102,13 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family return syscall.AF_INET6, true } - if mode == "listen" && laddr.isWildcard() { + if mode == "listen" && (laddr == nil || laddr.isWildcard()) { if supportsIPv4map { return syscall.AF_INET6, false } + if laddr == nil { + return syscall.AF_INET, false + } return laddr.family(), false } @@ -122,7 +130,7 @@ type sockaddr interface { sockaddr(family int) (syscall.Sockaddr, error) } -func internetSocket(net string, laddr, raddr sockaddr, sotype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { +func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { var la, ra syscall.Sockaddr family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode) if laddr != nil { @@ -135,7 +143,7 @@ func internetSocket(net string, laddr, raddr sockaddr, sotype, proto int, mode s goto Error } } - fd, err = socket(net, family, sotype, proto, ipv6only, la, ra, toAddr) + fd, err = socket(net, family, sotype, proto, ipv6only, la, ra, deadline, toAddr) if err != nil { goto Error } @@ -149,7 +157,7 @@ Error: return nil, &OpError{mode, net, addr, err} } -func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, error) { +func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) { switch family { case syscall.AF_INET: if len(ip) == 0 { @@ -158,12 +166,12 @@ func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, error) { if ip = ip.To4(); ip == nil { return nil, InvalidAddrError("non-IPv4 address") } - s := new(syscall.SockaddrInet4) + sa := new(syscall.SockaddrInet4) for i := 0; i < IPv4len; i++ { - s.Addr[i] = ip[i] + sa.Addr[i] = ip[i] } - s.Port = port - return s, nil + sa.Port = port + return sa, nil case syscall.AF_INET6: if len(ip) == 0 { ip = IPv6zero @@ -177,12 +185,13 @@ func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, error) { if ip = ip.To16(); ip == nil { return nil, InvalidAddrError("non-IPv6 address") } - s := new(syscall.SockaddrInet6) + sa := new(syscall.SockaddrInet6) for i := 0; i < IPv6len; i++ { - s.Addr[i] = ip[i] + sa.Addr[i] = ip[i] } - s.Port = port - return s, nil + sa.Port = port + sa.ZoneId = uint32(zoneToInt(zone)) + return sa, nil } return nil, InvalidAddrError("unexpected socket family") } diff --git a/src/pkg/net/doc.go b/src/pkg/net/lookup.go index 3a44e528e..bec93ec08 100644 --- a/src/pkg/net/doc.go +++ b/src/pkg/net/lookup.go @@ -4,12 +4,53 @@ package net +import ( + "time" +) + // LookupHost looks up the given host using the local resolver. // It returns an array of that host's addresses. func LookupHost(host string) (addrs []string, err error) { return lookupHost(host) } +func lookupHostDeadline(host string, deadline time.Time) (addrs []string, err error) { + if deadline.IsZero() { + return lookupHost(host) + } + + // TODO(bradfitz): consider pushing the deadline down into the + // name resolution functions. But that involves fixing it for + // the native Go resolver, cgo, Windows, etc. + // + // In the meantime, just use a goroutine. Most users affected + // by http://golang.org/issue/2631 are due to TCP connections + // to unresponsive hosts, not DNS. + timeout := deadline.Sub(time.Now()) + if timeout <= 0 { + err = errTimeout + return + } + t := time.NewTimer(timeout) + defer t.Stop() + type res struct { + addrs []string + err error + } + resc := make(chan res, 1) + go func() { + a, err := lookupHost(host) + resc <- res{a, err} + }() + select { + case <-t.C: + err = errTimeout + case r := <-resc: + addrs, err = r.addrs, r.err + } + return +} + // LookupIP looks up host using the local resolver. // It returns an array of that host's IPv4 and IPv6 addresses. func LookupIP(host string) (addrs []IP, err error) { @@ -47,6 +88,11 @@ func LookupMX(name string) (mx []*MX, err error) { return lookupMX(name) } +// LookupNS returns the DNS NS records for the given domain name. +func LookupNS(name string) (ns []*NS, err error) { + return lookupNS(name) +} + // LookupTXT returns the DNS TXT records for the given domain name. func LookupTXT(name string) (txt []string, err error) { return lookupTXT(name) diff --git a/src/pkg/net/lookup_plan9.go b/src/pkg/net/lookup_plan9.go index 2c698304b..ae7cf7942 100644 --- a/src/pkg/net/lookup_plan9.go +++ b/src/pkg/net/lookup_plan9.go @@ -201,6 +201,21 @@ func lookupMX(name string) (mx []*MX, err error) { return } +func lookupNS(name string) (ns []*NS, err error) { + lines, err := queryDNS(name, "ns") + if err != nil { + return + } + for _, line := range lines { + f := getFields(line) + if len(f) < 4 { + continue + } + ns = append(ns, &NS{f[3]}) + } + return +} + func lookupTXT(name string) (txt []string, err error) { lines, err := queryDNS(name, "txt") if err != nil { diff --git a/src/pkg/net/lookup_test.go b/src/pkg/net/lookup_test.go index 3a61dfb29..3355e4694 100644 --- a/src/pkg/net/lookup_test.go +++ b/src/pkg/net/lookup_test.go @@ -9,6 +9,7 @@ package net import ( "flag" + "strings" "testing" ) @@ -16,8 +17,7 @@ var testExternal = flag.Bool("external", true, "allow use of external networks d func TestGoogleSRV(t *testing.T) { if testing.Short() || !*testExternal { - t.Logf("skipping test to avoid external network") - return + t.Skip("skipping test to avoid external network") } _, addrs, err := LookupSRV("xmpp-server", "tcp", "google.com") if err != nil { @@ -39,8 +39,7 @@ func TestGoogleSRV(t *testing.T) { func TestGmailMX(t *testing.T) { if testing.Short() || !*testExternal { - t.Logf("skipping test to avoid external network") - return + t.Skip("skipping test to avoid external network") } mx, err := LookupMX("gmail.com") if err != nil { @@ -51,10 +50,22 @@ func TestGmailMX(t *testing.T) { } } +func TestGmailNS(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + ns, err := LookupNS("gmail.com") + if err != nil { + t.Errorf("failed: %s", err) + } + if len(ns) == 0 { + t.Errorf("no results") + } +} + func TestGmailTXT(t *testing.T) { if testing.Short() || !*testExternal { - t.Logf("skipping test to avoid external network") - return + t.Skip("skipping test to avoid external network") } txt, err := LookupTXT("gmail.com") if err != nil { @@ -67,8 +78,7 @@ func TestGmailTXT(t *testing.T) { func TestGoogleDNSAddr(t *testing.T) { if testing.Short() || !*testExternal { - t.Logf("skipping test to avoid external network") - return + t.Skip("skipping test to avoid external network") } names, err := LookupAddr("8.8.8.8") if err != nil { @@ -79,6 +89,16 @@ func TestGoogleDNSAddr(t *testing.T) { } } +func TestLookupIANACNAME(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + cname, err := LookupCNAME("www.iana.org") + if !strings.HasSuffix(cname, ".icann.org.") || err != nil { + t.Errorf(`LookupCNAME("www.iana.org.") = %q, %v, want "*.icann.org.", nil`, cname, err) + } +} + var revAddrTests = []struct { Addr string Reverse string diff --git a/src/pkg/net/lookup_unix.go b/src/pkg/net/lookup_unix.go index d500a1240..fa98eed5f 100644 --- a/src/pkg/net/lookup_unix.go +++ b/src/pkg/net/lookup_unix.go @@ -119,6 +119,19 @@ func lookupMX(name string) (mx []*MX, err error) { return } +func lookupNS(name string) (ns []*NS, err error) { + _, records, err := lookup(name, dnsTypeNS) + if err != nil { + return + } + ns = make([]*NS, len(records)) + for i, r := range records { + r := r.(*dnsRR_NS) + ns[i] = &NS{r.Ns} + } + return +} + func lookupTXT(name string) (txt []string, err error) { _, records, err := lookup(name, dnsTypeTXT) if err != nil { diff --git a/src/pkg/net/lookup_windows.go b/src/pkg/net/lookup_windows.go index 99783e975..3b29724f2 100644 --- a/src/pkg/net/lookup_windows.go +++ b/src/pkg/net/lookup_windows.go @@ -6,21 +6,17 @@ package net import ( "os" - "sync" + "runtime" "syscall" "unsafe" ) var ( - protoentLock sync.Mutex - hostentLock sync.Mutex - serventLock sync.Mutex + lookupPort = oldLookupPort + lookupIP = oldLookupIP ) -// lookupProtocol looks up IP protocol name and returns correspondent protocol number. -func lookupProtocol(name string) (proto int, err error) { - protoentLock.Lock() - defer protoentLock.Unlock() +func getprotobyname(name string) (proto int, err error) { p, err := syscall.GetProtoByName(name) if err != nil { return 0, os.NewSyscallError("GetProtoByName", err) @@ -28,6 +24,25 @@ func lookupProtocol(name string) (proto int, err error) { return int(p.Proto), nil } +// lookupProtocol looks up IP protocol name and returns correspondent protocol number. +func lookupProtocol(name string) (proto int, err error) { + // GetProtoByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + proto int + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + proto, err := getprotobyname(name) + ch <- result{proto: proto, err: err} + }() + r := <-ch + return r.proto, r.err +} + func lookupHost(name string) (addrs []string, err error) { ips, err := LookupIP(name) if err != nil { @@ -40,9 +55,7 @@ func lookupHost(name string) (addrs []string, err error) { return } -func lookupIP(name string) (addrs []IP, err error) { - hostentLock.Lock() - defer hostentLock.Unlock() +func gethostbyname(name string) (addrs []IP, err error) { h, err := syscall.GetHostByName(name) if err != nil { return nil, os.NewSyscallError("GetHostByName", err) @@ -56,20 +69,65 @@ func lookupIP(name string) (addrs []IP, err error) { } addrs = addrs[0:i] default: // TODO(vcc): Implement non IPv4 address lookups. - return nil, os.NewSyscallError("LookupHost", syscall.EWINDOWS) + return nil, os.NewSyscallError("LookupIP", syscall.EWINDOWS) } return addrs, nil } -func lookupPort(network, service string) (port int, err error) { +func oldLookupIP(name string) (addrs []IP, err error) { + // GetHostByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + addrs []IP + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + addrs, err := gethostbyname(name) + ch <- result{addrs: addrs, err: err} + }() + r := <-ch + return r.addrs, r.err +} + +func newLookupIP(name string) (addrs []IP, err error) { + hints := syscall.AddrinfoW{ + Family: syscall.AF_UNSPEC, + Socktype: syscall.SOCK_STREAM, + Protocol: syscall.IPPROTO_IP, + } + var result *syscall.AddrinfoW + e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result) + if e != nil { + return nil, os.NewSyscallError("GetAddrInfoW", e) + } + defer syscall.FreeAddrInfoW(result) + addrs = make([]IP, 0, 5) + for ; result != nil; result = result.Next { + addr := unsafe.Pointer(result.Addr) + switch result.Family { + case syscall.AF_INET: + a := (*syscall.RawSockaddrInet4)(addr).Addr + addrs = append(addrs, IPv4(a[0], a[1], a[2], a[3])) + case syscall.AF_INET6: + a := (*syscall.RawSockaddrInet6)(addr).Addr + addrs = append(addrs, IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}) + default: + return nil, os.NewSyscallError("LookupIP", syscall.EWINDOWS) + } + } + return addrs, nil +} + +func getservbyname(network, service string) (port int, err error) { switch network { case "tcp4", "tcp6": network = "tcp" case "udp4", "udp6": network = "udp" } - serventLock.Lock() - defer serventLock.Unlock() s, err := syscall.GetServByName(service, network) if err != nil { return 0, os.NewSyscallError("GetServByName", err) @@ -77,6 +135,58 @@ func lookupPort(network, service string) (port int, err error) { return int(syscall.Ntohs(s.Port)), nil } +func oldLookupPort(network, service string) (port int, err error) { + // GetServByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + port int + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + port, err := getservbyname(network, service) + ch <- result{port: port, err: err} + }() + r := <-ch + return r.port, r.err +} + +func newLookupPort(network, service string) (port int, err error) { + var stype int32 + switch network { + case "tcp4", "tcp6": + stype = syscall.SOCK_STREAM + case "udp4", "udp6": + stype = syscall.SOCK_DGRAM + } + hints := syscall.AddrinfoW{ + Family: syscall.AF_UNSPEC, + Socktype: stype, + Protocol: syscall.IPPROTO_IP, + } + var result *syscall.AddrinfoW + e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result) + if e != nil { + return 0, os.NewSyscallError("GetAddrInfoW", e) + } + defer syscall.FreeAddrInfoW(result) + if result == nil { + return 0, os.NewSyscallError("LookupPort", syscall.EINVAL) + } + addr := unsafe.Pointer(result.Addr) + switch result.Family { + case syscall.AF_INET: + a := (*syscall.RawSockaddrInet4)(addr) + return int(syscall.Ntohs(a.Port)), nil + case syscall.AF_INET6: + a := (*syscall.RawSockaddrInet6)(addr) + return int(syscall.Ntohs(a.Port)), nil + } + return 0, os.NewSyscallError("LookupPort", syscall.EINVAL) +} + func lookupCNAME(name string) (cname string, err error) { var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) @@ -129,6 +239,21 @@ func lookupMX(name string) (mx []*MX, err error) { return mx, nil } +func lookupNS(name string) (ns []*NS, err error) { + var r *syscall.DNSRecord + e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil) + if e != nil { + return nil, os.NewSyscallError("LookupNS", e) + } + defer syscall.DnsRecordListFree(r, 1) + ns = make([]*NS, 0, 10) + for p := r; p != nil && p.Type == syscall.DNS_TYPE_NS; p = p.Next { + v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) + ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."}) + } + return ns, nil +} + func lookupTXT(name string) (txt []string, err error) { var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil) diff --git a/src/pkg/net/mail/message.go b/src/pkg/net/mail/message.go index b610ccf3f..96c796e78 100644 --- a/src/pkg/net/mail/message.go +++ b/src/pkg/net/mail/message.go @@ -47,7 +47,8 @@ type Message struct { } // ReadMessage reads a message from r. -// The headers are parsed, and the body of the message will be reading from r. +// The headers are parsed, and the body of the message will be available +// for reading from r. func ReadMessage(r io.Reader) (msg *Message, err error) { tp := textproto.NewReader(bufio.NewReader(r)) @@ -126,7 +127,7 @@ func (h Header) AddressList(key string) ([]*Address, error) { if hdr == "" { return nil, ErrHeaderNotPresent } - return newAddrParser(hdr).parseAddressList() + return ParseAddressList(hdr) } // Address represents a single mail address. @@ -137,6 +138,16 @@ type Address struct { Address string // user@domain } +// Parses a single RFC 5322 address, e.g. "Barry Gibbs <bg@example.com>" +func ParseAddress(address string) (*Address, error) { + return newAddrParser(address).parseAddress() +} + +// ParseAddressList parses the given string as a list of addresses. +func ParseAddressList(list string) ([]*Address, error) { + return newAddrParser(list).parseAddressList() +} + // String formats the address as a valid RFC 5322 address. // If the address's name contains non-ASCII characters // the name will be rendered according to RFC 2047. diff --git a/src/pkg/net/mail/message_test.go b/src/pkg/net/mail/message_test.go index fd17eb414..2e746f4a7 100644 --- a/src/pkg/net/mail/message_test.go +++ b/src/pkg/net/mail/message_test.go @@ -227,13 +227,24 @@ func TestAddressParsing(t *testing.T) { }, } for _, test := range tests { - addrs, err := newAddrParser(test.addrsStr).parseAddressList() + if len(test.exp) == 1 { + addr, err := ParseAddress(test.addrsStr) + if err != nil { + t.Errorf("Failed parsing (single) %q: %v", test.addrsStr, err) + continue + } + if !reflect.DeepEqual([]*Address{addr}, test.exp) { + t.Errorf("Parse (single) of %q: got %+v, want %+v", test.addrsStr, addr, test.exp) + } + } + + addrs, err := ParseAddressList(test.addrsStr) if err != nil { - t.Errorf("Failed parsing %q: %v", test.addrsStr, err) + t.Errorf("Failed parsing (list) %q: %v", test.addrsStr, err) continue } if !reflect.DeepEqual(addrs, test.exp) { - t.Errorf("Parse of %q: got %+v, want %+v", test.addrsStr, addrs, test.exp) + t.Errorf("Parse (list) of %q: got %+v, want %+v", test.addrsStr, addrs, test.exp) } } } diff --git a/src/pkg/net/multicast_posix_test.go b/src/pkg/net/multicast_posix_test.go new file mode 100644 index 000000000..ff1edaf83 --- /dev/null +++ b/src/pkg/net/multicast_posix_test.go @@ -0,0 +1,180 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9 + +package net + +import ( + "errors" + "os" + "runtime" + "testing" +) + +var multicastListenerTests = []struct { + net string + gaddr *UDPAddr + flags Flags + ipv6 bool // test with underlying AF_INET6 socket +}{ + // cf. RFC 4727: Experimental Values in IPv4, IPv6, ICMPv4, ICMPv6, UDP, and TCP Headers + + {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, FlagUp | FlagLoopback, false}, + {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, 0, false}, + {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, FlagUp | FlagLoopback, true}, + {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, 0, true}, + + {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, FlagUp | FlagLoopback, false}, + {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, 0, false}, + + {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}, FlagUp | FlagLoopback, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}, 0, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}, FlagUp | FlagLoopback, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}, 0, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}, FlagUp | FlagLoopback, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}, 0, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}, FlagUp | FlagLoopback, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}, 0, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}, FlagUp | FlagLoopback, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}, 0, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, FlagUp | FlagLoopback, true}, + {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, 0, true}, +} + +// TestMulticastListener tests both single and double listen to a test +// listener with same address family, same group address and same port. +func TestMulticastListener(t *testing.T) { + switch runtime.GOOS { + case "netbsd", "openbsd", "plan9", "solaris", "windows": + t.Skipf("skipping test on %q", runtime.GOOS) + case "linux": + if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" { + t.Skipf("skipping test on %q/%q", runtime.GOOS, runtime.GOARCH) + } + } + + for _, tt := range multicastListenerTests { + if tt.ipv6 && (!*testIPv6 || !supportsIPv6 || os.Getuid() != 0) { + continue + } + ifi, err := availMulticastInterface(t, tt.flags) + if err != nil { + continue + } + c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) + if err != nil { + t.Fatalf("First ListenMulticastUDP failed: %v", err) + } + checkMulticastListener(t, err, c1, tt.gaddr) + c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) + if err != nil { + t.Fatalf("Second ListenMulticastUDP failed: %v", err) + } + checkMulticastListener(t, err, c2, tt.gaddr) + c2.Close() + c1.Close() + } +} + +func TestSimpleMulticastListener(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + case "windows": + if testing.Short() || !*testExternal { + t.Skip("skipping test on windows to avoid firewall") + } + } + + for _, tt := range multicastListenerTests { + if tt.ipv6 { + continue + } + tt.flags = FlagUp | FlagMulticast // for windows testing + ifi, err := availMulticastInterface(t, tt.flags) + if err != nil { + continue + } + c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) + if err != nil { + t.Fatalf("First ListenMulticastUDP failed: %v", err) + } + checkSimpleMulticastListener(t, err, c1, tt.gaddr) + c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) + if err != nil { + t.Fatalf("Second ListenMulticastUDP failed: %v", err) + } + checkSimpleMulticastListener(t, err, c2, tt.gaddr) + c2.Close() + c1.Close() + } +} + +func checkMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) { + if !multicastRIBContains(t, gaddr.IP) { + t.Errorf("%q not found in RIB", gaddr.String()) + return + } + la := c.LocalAddr() + if la == nil { + t.Error("LocalAddr failed") + return + } + if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { + t.Errorf("got %v; expected a proper address with non-zero port number", la) + return + } +} + +func checkSimpleMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) { + la := c.LocalAddr() + if la == nil { + t.Error("LocalAddr failed") + return + } + if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { + t.Errorf("got %v; expected a proper address with non-zero port number", la) + return + } +} + +func availMulticastInterface(t *testing.T, flags Flags) (*Interface, error) { + var ifi *Interface + if flags != Flags(0) { + ift, err := Interfaces() + if err != nil { + t.Fatalf("Interfaces failed: %v", err) + } + for _, x := range ift { + if x.Flags&flags == flags { + ifi = &x + break + } + } + if ifi == nil { + return nil, errors.New("an appropriate multicast interface not found") + } + } + return ifi, nil +} + +func multicastRIBContains(t *testing.T, ip IP) bool { + ift, err := Interfaces() + if err != nil { + t.Fatalf("Interfaces failed: %v", err) + } + for _, ifi := range ift { + ifmat, err := ifi.MulticastAddrs() + if err != nil { + t.Fatalf("MulticastAddrs failed: %v", err) + } + for _, ifma := range ifmat { + if ifma.(*IPAddr).IP.Equal(ip) { + return true + } + } + } + return false +} diff --git a/src/pkg/net/multicast_test.go b/src/pkg/net/multicast_test.go deleted file mode 100644 index 67261b1ee..000000000 --- a/src/pkg/net/multicast_test.go +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "errors" - "os" - "runtime" - "syscall" - "testing" -) - -var multicastListenerTests = []struct { - net string - gaddr *UDPAddr - flags Flags - ipv6 bool // test with underlying AF_INET6 socket -}{ - // cf. RFC 4727: Experimental Values in IPv4, IPv6, ICMPv4, ICMPv6, UDP, and TCP Headers - - {"udp", &UDPAddr{IPv4(224, 0, 0, 254), 12345}, FlagUp | FlagLoopback, false}, - {"udp", &UDPAddr{IPv4(224, 0, 0, 254), 12345}, 0, false}, - {"udp", &UDPAddr{ParseIP("ff0e::114"), 12345}, FlagUp | FlagLoopback, true}, - {"udp", &UDPAddr{ParseIP("ff0e::114"), 12345}, 0, true}, - - {"udp4", &UDPAddr{IPv4(224, 0, 0, 254), 12345}, FlagUp | FlagLoopback, false}, - {"udp4", &UDPAddr{IPv4(224, 0, 0, 254), 12345}, 0, false}, - - {"udp6", &UDPAddr{ParseIP("ff01::114"), 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{ParseIP("ff01::114"), 12345}, 0, true}, - {"udp6", &UDPAddr{ParseIP("ff02::114"), 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{ParseIP("ff02::114"), 12345}, 0, true}, - {"udp6", &UDPAddr{ParseIP("ff04::114"), 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{ParseIP("ff04::114"), 12345}, 0, true}, - {"udp6", &UDPAddr{ParseIP("ff05::114"), 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{ParseIP("ff05::114"), 12345}, 0, true}, - {"udp6", &UDPAddr{ParseIP("ff08::114"), 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{ParseIP("ff08::114"), 12345}, 0, true}, - {"udp6", &UDPAddr{ParseIP("ff0e::114"), 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{ParseIP("ff0e::114"), 12345}, 0, true}, -} - -// TestMulticastListener tests both single and double listen to a test -// listener with same address family, same group address and same port. -func TestMulticastListener(t *testing.T) { - switch runtime.GOOS { - case "netbsd", "openbsd", "plan9", "windows": - t.Logf("skipping test on %q", runtime.GOOS) - return - case "linux": - if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" { - t.Logf("skipping test on %q/%q", runtime.GOOS, runtime.GOARCH) - return - } - } - - for _, tt := range multicastListenerTests { - if tt.ipv6 && (!supportsIPv6 || os.Getuid() != 0) { - continue - } - ifi, err := availMulticastInterface(t, tt.flags) - if err != nil { - continue - } - c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("First ListenMulticastUDP failed: %v", err) - } - checkMulticastListener(t, err, c1, tt.gaddr) - c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("Second ListenMulticastUDP failed: %v", err) - } - checkMulticastListener(t, err, c2, tt.gaddr) - c2.Close() - switch c1.fd.family { - case syscall.AF_INET: - testIPv4MulticastSocketOptions(t, c1.fd, ifi) - case syscall.AF_INET6: - testIPv6MulticastSocketOptions(t, c1.fd, ifi) - } - c1.Close() - } -} - -func TestSimpleMulticastListener(t *testing.T) { - switch runtime.GOOS { - case "plan9": - t.Logf("skipping test on %q", runtime.GOOS) - return - case "windows": - if testing.Short() || !*testExternal { - t.Logf("skipping test on windows to avoid firewall") - return - } - } - - for _, tt := range multicastListenerTests { - if tt.ipv6 { - continue - } - tt.flags = FlagUp | FlagMulticast // for windows testing - ifi, err := availMulticastInterface(t, tt.flags) - if err != nil { - continue - } - c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("First ListenMulticastUDP failed: %v", err) - } - checkSimpleMulticastListener(t, err, c1, tt.gaddr) - c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("Second ListenMulticastUDP failed: %v", err) - } - checkSimpleMulticastListener(t, err, c2, tt.gaddr) - c2.Close() - c1.Close() - } -} - -func checkMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) { - if !multicastRIBContains(t, gaddr.IP) { - t.Fatalf("%q not found in RIB", gaddr.String()) - } - if c.LocalAddr().String() != gaddr.String() { - t.Fatalf("LocalAddr returns %q, expected %q", c.LocalAddr().String(), gaddr.String()) - } -} - -func checkSimpleMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) { - if c.LocalAddr().String() != gaddr.String() { - t.Fatalf("LocalAddr returns %q, expected %q", c.LocalAddr().String(), gaddr.String()) - } -} - -func availMulticastInterface(t *testing.T, flags Flags) (*Interface, error) { - var ifi *Interface - if flags != Flags(0) { - ift, err := Interfaces() - if err != nil { - t.Fatalf("Interfaces failed: %v", err) - } - for _, x := range ift { - if x.Flags&flags == flags { - ifi = &x - break - } - } - if ifi == nil { - return nil, errors.New("an appropriate multicast interface not found") - } - } - return ifi, nil -} - -func multicastRIBContains(t *testing.T, ip IP) bool { - ift, err := Interfaces() - if err != nil { - t.Fatalf("Interfaces failed: %v", err) - } - for _, ifi := range ift { - ifmat, err := ifi.MulticastAddrs() - if err != nil { - t.Fatalf("MulticastAddrs failed: %v", err) - } - for _, ifma := range ifmat { - if ifma.(*IPAddr).IP.Equal(ip) { - return true - } - } - } - return false -} - -func testIPv4MulticastSocketOptions(t *testing.T, fd *netFD, ifi *Interface) { - _, err := ipv4MulticastInterface(fd) - if err != nil { - t.Fatalf("ipv4MulticastInterface failed: %v", err) - } - if ifi != nil { - err = setIPv4MulticastInterface(fd, ifi) - if err != nil { - t.Fatalf("setIPv4MulticastInterface failed: %v", err) - } - } - _, err = ipv4MulticastTTL(fd) - if err != nil { - t.Fatalf("ipv4MulticastTTL failed: %v", err) - } - err = setIPv4MulticastTTL(fd, 1) - if err != nil { - t.Fatalf("setIPv4MulticastTTL failed: %v", err) - } - _, err = ipv4MulticastLoopback(fd) - if err != nil { - t.Fatalf("ipv4MulticastLoopback failed: %v", err) - } - err = setIPv4MulticastLoopback(fd, false) - if err != nil { - t.Fatalf("setIPv4MulticastLoopback failed: %v", err) - } -} - -func testIPv6MulticastSocketOptions(t *testing.T, fd *netFD, ifi *Interface) { - _, err := ipv6MulticastInterface(fd) - if err != nil { - t.Fatalf("ipv6MulticastInterface failed: %v", err) - } - if ifi != nil { - err = setIPv6MulticastInterface(fd, ifi) - if err != nil { - t.Fatalf("setIPv6MulticastInterface failed: %v", err) - } - } - _, err = ipv6MulticastHopLimit(fd) - if err != nil { - t.Fatalf("ipv6MulticastHopLimit failed: %v", err) - } - err = setIPv6MulticastHopLimit(fd, 1) - if err != nil { - t.Fatalf("setIPv6MulticastHopLimit failed: %v", err) - } - _, err = ipv6MulticastLoopback(fd) - if err != nil { - t.Fatalf("ipv6MulticastLoopback failed: %v", err) - } - err = setIPv6MulticastLoopback(fd, false) - if err != nil { - t.Fatalf("setIPv6MulticastLoopback failed: %v", err) - } -} diff --git a/src/pkg/net/net.go b/src/pkg/net/net.go index 9ebcdbe99..72b2b646c 100644 --- a/src/pkg/net/net.go +++ b/src/pkg/net/net.go @@ -44,6 +44,10 @@ package net import ( "errors" + "io" + "os" + "sync" + "syscall" "time" ) @@ -103,6 +107,105 @@ type Conn interface { SetWriteDeadline(t time.Time) error } +type conn struct { + fd *netFD +} + +func (c *conn) ok() bool { return c != nil && c.fd != nil } + +// Implementation of the Conn interface. + +// Read implements the Conn Read method. +func (c *conn) Read(b []byte) (int, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + return c.fd.Read(b) +} + +// Write implements the Conn Write method. +func (c *conn) Write(b []byte) (int, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + return c.fd.Write(b) +} + +// Close closes the connection. +func (c *conn) Close() error { + if !c.ok() { + return syscall.EINVAL + } + return c.fd.Close() +} + +// LocalAddr returns the local network address. +func (c *conn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.laddr +} + +// RemoteAddr returns the remote network address. +func (c *conn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.raddr +} + +// SetDeadline implements the Conn SetDeadline method. +func (c *conn) SetDeadline(t time.Time) error { + if !c.ok() { + return syscall.EINVAL + } + return setDeadline(c.fd, t) +} + +// SetReadDeadline implements the Conn SetReadDeadline method. +func (c *conn) SetReadDeadline(t time.Time) error { + if !c.ok() { + return syscall.EINVAL + } + return setReadDeadline(c.fd, t) +} + +// SetWriteDeadline implements the Conn SetWriteDeadline method. +func (c *conn) SetWriteDeadline(t time.Time) error { + if !c.ok() { + return syscall.EINVAL + } + return setWriteDeadline(c.fd, t) +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *conn) SetReadBuffer(bytes int) error { + if !c.ok() { + return syscall.EINVAL + } + return setReadBuffer(c.fd, bytes) +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *conn) SetWriteBuffer(bytes int) error { + if !c.ok() { + return syscall.EINVAL + } + return setWriteBuffer(c.fd, bytes) +} + +// File sets the underlying os.File to blocking mode and returns a copy. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +// +// The returned os.File's file descriptor is different from the connection's. +// Attempting to change properties of the original using this duplicate +// may or may not have the desired effect. +func (c *conn) File() (f *os.File, err error) { return c.fd.dup() } + // An Error represents a network error. type Error interface { error @@ -173,11 +276,23 @@ type Listener interface { var errMissingAddress = errors.New("missing address") +// OpError is the error type usually returned by functions in the net +// package. It describes the operation, network type, and address of +// an error. type OpError struct { - Op string - Net string + // Op is the operation which caused the error, such as + // "read" or "write". + Op string + + // Net is the network type on which this error occurred, + // such as "tcp" or "udp6". + Net string + + // Addr is the network address on which this error occurred. Addr Addr - Err error + + // Err is the error that occurred during the operation. + Err error } func (e *OpError) Error() string { @@ -204,6 +319,8 @@ func (e *OpError) Temporary() bool { return ok && t.Temporary() } +var noDeadline = time.Time{} + type timeout interface { Timeout() bool } @@ -221,6 +338,8 @@ func (e *timeoutError) Temporary() bool { return true } var errTimeout error = &timeoutError{} +var errClosing = errors.New("use of closed network connection") + type AddrError struct { Err string Addr string @@ -262,3 +381,47 @@ func (e *DNSConfigError) Error() string { func (e *DNSConfigError) Timeout() bool { return false } func (e *DNSConfigError) Temporary() bool { return false } + +type writerOnly struct { + io.Writer +} + +// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't +// applicable. +func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) { + // Use wrapper to hide existing r.ReadFrom from io.Copy. + return io.Copy(writerOnly{w}, r) +} + +// deadline is an atomically-accessed number of nanoseconds since 1970 +// or 0, if no deadline is set. +type deadline struct { + sync.Mutex + val int64 +} + +func (d *deadline) expired() bool { + t := d.value() + return t > 0 && time.Now().UnixNano() >= t +} + +func (d *deadline) value() (v int64) { + d.Lock() + v = d.val + d.Unlock() + return +} + +func (d *deadline) set(v int64) { + d.Lock() + d.val = v + d.Unlock() +} + +func (d *deadline) setTime(t time.Time) { + if t.IsZero() { + d.set(0) + } else { + d.set(t.UnixNano()) + } +} diff --git a/src/pkg/net/net_test.go b/src/pkg/net/net_test.go index fd145e1d7..1a512a5b1 100644 --- a/src/pkg/net/net_test.go +++ b/src/pkg/net/net_test.go @@ -6,6 +6,8 @@ package net import ( "io" + "io/ioutil" + "os" "runtime" "testing" "time" @@ -13,18 +15,17 @@ import ( func TestShutdown(t *testing.T) { if runtime.GOOS == "plan9" { - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } - l, err := Listen("tcp", "127.0.0.1:0") + ln, err := Listen("tcp", "127.0.0.1:0") if err != nil { - if l, err = Listen("tcp6", "[::1]:0"); err != nil { + if ln, err = Listen("tcp6", "[::1]:0"); err != nil { t.Fatalf("ListenTCP on :0: %v", err) } } go func() { - c, err := l.Accept() + c, err := ln.Accept() if err != nil { t.Fatalf("Accept: %v", err) } @@ -37,7 +38,7 @@ func TestShutdown(t *testing.T) { c.Close() }() - c, err := Dial("tcp", l.Addr().String()) + c, err := Dial("tcp", ln.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) } @@ -58,8 +59,61 @@ func TestShutdown(t *testing.T) { } } +func TestShutdownUnix(t *testing.T) { + switch runtime.GOOS { + case "windows", "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + f, err := ioutil.TempFile("", "go_net_unixtest") + if err != nil { + t.Fatalf("TempFile: %s", err) + } + f.Close() + tmpname := f.Name() + os.Remove(tmpname) + ln, err := Listen("unix", tmpname) + if err != nil { + t.Fatalf("ListenUnix on %s: %s", tmpname, err) + } + defer os.Remove(tmpname) + + go func() { + c, err := ln.Accept() + if err != nil { + t.Fatalf("Accept: %v", err) + } + var buf [10]byte + n, err := c.Read(buf[:]) + if n != 0 || err != io.EOF { + t.Fatalf("server Read = %d, %v; want 0, io.EOF", n, err) + } + c.Write([]byte("response")) + c.Close() + }() + + c, err := Dial("unix", tmpname) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + + err = c.(*UnixConn).CloseWrite() + if err != nil { + t.Fatalf("CloseWrite: %v", err) + } + var buf [10]byte + n, err := c.Read(buf[:]) + if err != nil { + t.Fatalf("client Read: %d, %v", n, err) + } + got := string(buf[:n]) + if got != "response" { + t.Errorf("read = %q, want \"response\"", got) + } +} + func TestTCPListenClose(t *testing.T) { - l, err := Listen("tcp", "127.0.0.1:0") + ln, err := Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("Listen failed: %v", err) } @@ -67,11 +121,12 @@ func TestTCPListenClose(t *testing.T) { done := make(chan bool, 1) go func() { time.Sleep(100 * time.Millisecond) - l.Close() + ln.Close() }() go func() { - _, err = l.Accept() + c, err := ln.Accept() if err == nil { + c.Close() t.Error("Accept succeeded") } else { t.Logf("Accept timeout error: %s (any error is fine)", err) @@ -86,7 +141,11 @@ func TestTCPListenClose(t *testing.T) { } func TestUDPListenClose(t *testing.T) { - l, err := ListenPacket("udp", "127.0.0.1:0") + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + ln, err := ListenPacket("udp", "127.0.0.1:0") if err != nil { t.Fatalf("Listen failed: %v", err) } @@ -95,10 +154,10 @@ func TestUDPListenClose(t *testing.T) { done := make(chan bool, 1) go func() { time.Sleep(100 * time.Millisecond) - l.Close() + ln.Close() }() go func() { - _, _, err = l.ReadFrom(buf) + _, _, err = ln.ReadFrom(buf) if err == nil { t.Error("ReadFrom succeeded") } else { @@ -112,3 +171,46 @@ func TestUDPListenClose(t *testing.T) { t.Fatal("timeout waiting for UDP close") } } + +func TestTCPClose(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + l, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + read := func(r io.Reader) error { + var m [1]byte + _, err := r.Read(m[:]) + return err + } + + go func() { + c, err := Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + + go read(c) + + time.Sleep(10 * time.Millisecond) + c.Close() + }() + + c, err := l.Accept() + if err != nil { + t.Fatal(err) + } + defer c.Close() + + for err == nil { + err = read(c) + } + if err != nil && err != io.EOF { + t.Fatal(err) + } +} diff --git a/src/pkg/net/newpollserver.go b/src/pkg/net/newpollserver_unix.go index d34bb511f..618b5b10b 100644 --- a/src/pkg/net/newpollserver.go +++ b/src/pkg/net/newpollserver_unix.go @@ -13,8 +13,6 @@ import ( func newPollServer() (s *pollServer, err error) { s = new(pollServer) - s.cr = make(chan *netFD, 1) - s.cw = make(chan *netFD, 1) if s.pr, s.pw, err = os.Pipe(); err != nil { return nil, err } diff --git a/src/pkg/net/packetconn_test.go b/src/pkg/net/packetconn_test.go new file mode 100644 index 000000000..93c7a6472 --- /dev/null +++ b/src/pkg/net/packetconn_test.go @@ -0,0 +1,200 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements API tests across platforms and will never have a build +// tag. + +package net + +import ( + "os" + "runtime" + "strings" + "testing" + "time" +) + +var packetConnTests = []struct { + net string + addr1 string + addr2 string +}{ + {"udp", "127.0.0.1:0", "127.0.0.1:0"}, + {"ip:icmp", "127.0.0.1", "127.0.0.1"}, + {"unixgram", testUnixAddr(), testUnixAddr()}, +} + +func TestPacketConn(t *testing.T) { + closer := func(c PacketConn, net, addr1, addr2 string) { + c.Close() + switch net { + case "unixgram": + os.Remove(addr1) + os.Remove(addr2) + } + } + + for i, tt := range packetConnTests { + var wb []byte + netstr := strings.Split(tt.net, ":") + switch netstr[0] { + case "udp": + wb = []byte("UDP PACKETCONN TEST") + case "ip": + switch runtime.GOOS { + case "plan9": + continue + } + if os.Getuid() != 0 { + continue + } + var err error + wb, err = (&icmpMessage{ + Type: icmpv4EchoRequest, Code: 0, + Body: &icmpEcho{ + ID: os.Getpid() & 0xffff, Seq: i + 1, + Data: []byte("IP PACKETCONN TEST"), + }, + }).Marshal() + if err != nil { + t.Fatalf("icmpMessage.Marshal failed: %v", err) + } + case "unixgram": + switch runtime.GOOS { + case "plan9", "windows": + continue + } + wb = []byte("UNIXGRAM PACKETCONN TEST") + default: + continue + } + + c1, err := ListenPacket(tt.net, tt.addr1) + if err != nil { + t.Fatalf("ListenPacket failed: %v", err) + } + defer closer(c1, netstr[0], tt.addr1, tt.addr2) + c1.LocalAddr() + c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) + c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + c1.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + + c2, err := ListenPacket(tt.net, tt.addr2) + if err != nil { + t.Fatalf("ListenPacket failed: %v", err) + } + defer closer(c2, netstr[0], tt.addr1, tt.addr2) + c2.LocalAddr() + c2.SetDeadline(time.Now().Add(100 * time.Millisecond)) + c2.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + c2.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + + if _, err := c1.WriteTo(wb, c2.LocalAddr()); err != nil { + t.Fatalf("PacketConn.WriteTo failed: %v", err) + } + rb2 := make([]byte, 128) + if _, _, err := c2.ReadFrom(rb2); err != nil { + t.Fatalf("PacketConn.ReadFrom failed: %v", err) + } + if _, err := c2.WriteTo(wb, c1.LocalAddr()); err != nil { + t.Fatalf("PacketConn.WriteTo failed: %v", err) + } + rb1 := make([]byte, 128) + if _, _, err := c1.ReadFrom(rb1); err != nil { + t.Fatalf("PacketConn.ReadFrom failed: %v", err) + } + } +} + +func TestConnAndPacketConn(t *testing.T) { + closer := func(c PacketConn, net, addr1, addr2 string) { + c.Close() + switch net { + case "unixgram": + os.Remove(addr1) + os.Remove(addr2) + } + } + + for i, tt := range packetConnTests { + var wb []byte + netstr := strings.Split(tt.net, ":") + switch netstr[0] { + case "udp": + wb = []byte("UDP PACKETCONN TEST") + case "ip": + switch runtime.GOOS { + case "plan9": + continue + } + if os.Getuid() != 0 { + continue + } + var err error + wb, err = (&icmpMessage{ + Type: icmpv4EchoRequest, Code: 0, + Body: &icmpEcho{ + ID: os.Getpid() & 0xffff, Seq: i + 1, + Data: []byte("IP PACKETCONN TEST"), + }, + }).Marshal() + if err != nil { + t.Fatalf("icmpMessage.Marshal failed: %v", err) + } + case "unixgram": + switch runtime.GOOS { + case "plan9", "windows": + continue + } + wb = []byte("UNIXGRAM PACKETCONN TEST") + default: + continue + } + + c1, err := ListenPacket(tt.net, tt.addr1) + if err != nil { + t.Fatalf("ListenPacket failed: %v", err) + } + defer closer(c1, netstr[0], tt.addr1, tt.addr2) + c1.LocalAddr() + c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) + c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + c1.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + + c2, err := Dial(tt.net, c1.LocalAddr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c2.Close() + c2.LocalAddr() + c2.RemoteAddr() + c2.SetDeadline(time.Now().Add(100 * time.Millisecond)) + c2.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + c2.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + + if _, err := c2.Write(wb); err != nil { + t.Fatalf("Conn.Write failed: %v", err) + } + rb1 := make([]byte, 128) + if _, _, err := c1.ReadFrom(rb1); err != nil { + t.Fatalf("PacetConn.ReadFrom failed: %v", err) + } + var dst Addr + switch netstr[0] { + case "ip": + dst = &IPAddr{IP: IPv4(127, 0, 0, 1)} + case "unixgram": + continue + default: + dst = c2.LocalAddr() + } + if _, err := c1.WriteTo(wb, dst); err != nil { + t.Fatalf("PacketConn.WriteTo failed: %v", err) + } + rb2 := make([]byte, 128) + if _, err := c2.Read(rb2); err != nil { + t.Fatalf("Conn.Read failed: %v", err) + } + } +} diff --git a/src/pkg/net/parse_test.go b/src/pkg/net/parse_test.go index 30fda45df..9df0c534b 100644 --- a/src/pkg/net/parse_test.go +++ b/src/pkg/net/parse_test.go @@ -15,8 +15,7 @@ func TestReadLine(t *testing.T) { // /etc/services file does not exist on windows and Plan 9. switch runtime.GOOS { case "plan9", "windows": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } filename := "/etc/services" // a nice big file diff --git a/src/pkg/net/port.go b/src/pkg/net/port.go index 16780da11..c24f4ed5b 100644 --- a/src/pkg/net/port.go +++ b/src/pkg/net/port.go @@ -1,69 +1,24 @@ -// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2012 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd - -// Read system port mappings from /etc/services +// Network service port manipulations package net -import "sync" - -var services map[string]map[string]int -var servicesError error -var onceReadServices sync.Once - -func readServices() { - services = make(map[string]map[string]int) - var file *file - if file, servicesError = open("/etc/services"); servicesError != nil { - return - } - for line, ok := file.readLine(); ok; line, ok = file.readLine() { - // "http 80/tcp www www-http # World Wide Web HTTP" - if i := byteIndex(line, '#'); i >= 0 { - line = line[0:i] - } - f := getFields(line) - if len(f) < 2 { - continue - } - portnet := f[1] // "tcp/80" - port, j, ok := dtoi(portnet, 0) - if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' { - continue - } - netw := portnet[j+1:] // "tcp" - m, ok1 := services[netw] - if !ok1 { - m = make(map[string]int) - services[netw] = m - } - for i := 0; i < len(f); i++ { - if i != 1 { // f[1] was port/net - m[f[i]] = port - } +// parsePort parses port as a network service port number for both +// TCP and UDP. +func parsePort(net, port string) (int, error) { + p, i, ok := dtoi(port, 0) + if !ok || i != len(port) { + var err error + p, err = LookupPort(net, port) + if err != nil { + return 0, err } } - file.close() -} - -// goLookupPort is the native Go implementation of LookupPort. -func goLookupPort(network, service string) (port int, err error) { - onceReadServices.Do(readServices) - - switch network { - case "tcp4", "tcp6": - network = "tcp" - case "udp4", "udp6": - network = "udp" - } - - if m, ok := services[network]; ok { - if port, ok = m[service]; ok { - return - } + if p < 0 || p > 0xFFFF { + return 0, &AddrError{"invalid port", port} } - return 0, &AddrError{"unknown port", network + "/" + service} + return p, nil } diff --git a/src/pkg/net/port_test.go b/src/pkg/net/port_test.go index 329b169f3..9e8968f35 100644 --- a/src/pkg/net/port_test.go +++ b/src/pkg/net/port_test.go @@ -46,7 +46,7 @@ func TestLookupPort(t *testing.T) { for i := 0; i < len(porttests); i++ { tt := porttests[i] if port, err := LookupPort(tt.netw, tt.name); port != tt.port || (err == nil) != tt.ok { - t.Errorf("LookupPort(%q, %q) = %v, %s; want %v", + t.Errorf("LookupPort(%q, %q) = %v, %v; want %v", tt.netw, tt.name, port, err, tt.port) } } diff --git a/src/pkg/net/port_unix.go b/src/pkg/net/port_unix.go new file mode 100644 index 000000000..16780da11 --- /dev/null +++ b/src/pkg/net/port_unix.go @@ -0,0 +1,69 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd + +// Read system port mappings from /etc/services + +package net + +import "sync" + +var services map[string]map[string]int +var servicesError error +var onceReadServices sync.Once + +func readServices() { + services = make(map[string]map[string]int) + var file *file + if file, servicesError = open("/etc/services"); servicesError != nil { + return + } + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + // "http 80/tcp www www-http # World Wide Web HTTP" + if i := byteIndex(line, '#'); i >= 0 { + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 { + continue + } + portnet := f[1] // "tcp/80" + port, j, ok := dtoi(portnet, 0) + if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' { + continue + } + netw := portnet[j+1:] // "tcp" + m, ok1 := services[netw] + if !ok1 { + m = make(map[string]int) + services[netw] = m + } + for i := 0; i < len(f); i++ { + if i != 1 { // f[1] was port/net + m[f[i]] = port + } + } + } + file.close() +} + +// goLookupPort is the native Go implementation of LookupPort. +func goLookupPort(network, service string) (port int, err error) { + onceReadServices.Do(readServices) + + switch network { + case "tcp4", "tcp6": + network = "tcp" + case "udp4", "udp6": + network = "udp" + } + + if m, ok := services[network]; ok { + if port, ok = m[service]; ok { + return + } + } + return 0, &AddrError{"unknown port", network + "/" + service} +} diff --git a/src/pkg/net/protoconn_test.go b/src/pkg/net/protoconn_test.go new file mode 100644 index 000000000..2fe7d1d1f --- /dev/null +++ b/src/pkg/net/protoconn_test.go @@ -0,0 +1,358 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements API tests across platforms and will never have a build +// tag. + +package net + +import ( + "io/ioutil" + "os" + "runtime" + "testing" + "time" +) + +// testUnixAddr uses ioutil.TempFile to get a name that is unique. +func testUnixAddr() string { + f, err := ioutil.TempFile("", "nettest") + if err != nil { + panic(err) + } + addr := f.Name() + f.Close() + os.Remove(addr) + return addr +} + +var condFatalf = func() func(*testing.T, string, ...interface{}) { + // A few APIs are not implemented yet on both Plan 9 and Windows. + switch runtime.GOOS { + case "plan9", "windows": + return (*testing.T).Logf + } + return (*testing.T).Fatalf +}() + +func TestTCPListenerSpecificMethods(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + la, err := ResolveTCPAddr("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveTCPAddr failed: %v", err) + } + ln, err := ListenTCP("tcp4", la) + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + defer ln.Close() + ln.Addr() + ln.SetDeadline(time.Now().Add(30 * time.Nanosecond)) + + if c, err := ln.Accept(); err != nil { + if !err.(Error).Timeout() { + t.Fatalf("TCPListener.Accept failed: %v", err) + } + } else { + c.Close() + } + if c, err := ln.AcceptTCP(); err != nil { + if !err.(Error).Timeout() { + t.Fatalf("TCPListener.AcceptTCP failed: %v", err) + } + } else { + c.Close() + } + + if f, err := ln.File(); err != nil { + condFatalf(t, "TCPListener.File failed: %v", err) + } else { + f.Close() + } +} + +func TestTCPConnSpecificMethods(t *testing.T) { + la, err := ResolveTCPAddr("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveTCPAddr failed: %v", err) + } + ln, err := ListenTCP("tcp4", la) + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + defer ln.Close() + ln.Addr() + + done := make(chan int) + go transponder(t, ln, done) + + ra, err := ResolveTCPAddr("tcp4", ln.Addr().String()) + if err != nil { + t.Fatalf("ResolveTCPAddr failed: %v", err) + } + c, err := DialTCP("tcp4", nil, ra) + if err != nil { + t.Fatalf("DialTCP failed: %v", err) + } + defer c.Close() + c.SetKeepAlive(false) + c.SetLinger(0) + c.SetNoDelay(false) + c.LocalAddr() + c.RemoteAddr() + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) + + if _, err := c.Write([]byte("TCPCONN TEST")); err != nil { + t.Fatalf("TCPConn.Write failed: %v", err) + } + rb := make([]byte, 128) + if _, err := c.Read(rb); err != nil { + t.Fatalf("TCPConn.Read failed: %v", err) + } + + <-done +} + +func TestUDPConnSpecificMethods(t *testing.T) { + la, err := ResolveUDPAddr("udp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveUDPAddr failed: %v", err) + } + c, err := ListenUDP("udp4", la) + if err != nil { + t.Fatalf("ListenUDP failed: %v", err) + } + defer c.Close() + c.LocalAddr() + c.RemoteAddr() + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) + c.SetReadBuffer(2048) + c.SetWriteBuffer(2048) + + wb := []byte("UDPCONN TEST") + rb := make([]byte, 128) + if _, err := c.WriteToUDP(wb, c.LocalAddr().(*UDPAddr)); err != nil { + t.Fatalf("UDPConn.WriteToUDP failed: %v", err) + } + if _, _, err := c.ReadFromUDP(rb); err != nil { + t.Fatalf("UDPConn.ReadFromUDP failed: %v", err) + } + if _, _, err := c.WriteMsgUDP(wb, nil, c.LocalAddr().(*UDPAddr)); err != nil { + condFatalf(t, "UDPConn.WriteMsgUDP failed: %v", err) + } + if _, _, _, _, err := c.ReadMsgUDP(rb, nil); err != nil { + condFatalf(t, "UDPConn.ReadMsgUDP failed: %v", err) + } + + if f, err := c.File(); err != nil { + condFatalf(t, "UDPConn.File failed: %v", err) + } else { + f.Close() + } +} + +func TestIPConnSpecificMethods(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping read test on %q", runtime.GOOS) + } + if os.Getuid() != 0 { + t.Skipf("skipping test; must be root") + } + + la, err := ResolveIPAddr("ip4", "127.0.0.1") + if err != nil { + t.Fatalf("ResolveIPAddr failed: %v", err) + } + c, err := ListenIP("ip4:icmp", la) + if err != nil { + t.Fatalf("ListenIP failed: %v", err) + } + defer c.Close() + c.LocalAddr() + c.RemoteAddr() + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) + c.SetReadBuffer(2048) + c.SetWriteBuffer(2048) + + wb, err := (&icmpMessage{ + Type: icmpv4EchoRequest, Code: 0, + Body: &icmpEcho{ + ID: os.Getpid() & 0xffff, Seq: 1, + Data: []byte("IPCONN TEST "), + }, + }).Marshal() + if err != nil { + t.Fatalf("icmpMessage.Marshal failed: %v", err) + } + rb := make([]byte, 20+128) + if _, err := c.WriteToIP(wb, c.LocalAddr().(*IPAddr)); err != nil { + t.Fatalf("IPConn.WriteToIP failed: %v", err) + } + if _, _, err := c.ReadFromIP(rb); err != nil { + t.Fatalf("IPConn.ReadFromIP failed: %v", err) + } + if _, _, err := c.WriteMsgIP(wb, nil, c.LocalAddr().(*IPAddr)); err != nil { + condFatalf(t, "IPConn.WriteMsgIP failed: %v", err) + } + if _, _, _, _, err := c.ReadMsgIP(rb, nil); err != nil { + condFatalf(t, "IPConn.ReadMsgIP failed: %v", err) + } + + if f, err := c.File(); err != nil { + condFatalf(t, "IPConn.File failed: %v", err) + } else { + f.Close() + } +} + +func TestUnixListenerSpecificMethods(t *testing.T) { + switch runtime.GOOS { + case "plan9", "windows": + t.Skipf("skipping read test on %q", runtime.GOOS) + } + + addr := testUnixAddr() + la, err := ResolveUnixAddr("unix", addr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + ln, err := ListenUnix("unix", la) + if err != nil { + t.Fatalf("ListenUnix failed: %v", err) + } + defer ln.Close() + defer os.Remove(addr) + ln.Addr() + ln.SetDeadline(time.Now().Add(30 * time.Nanosecond)) + + if c, err := ln.Accept(); err != nil { + if !err.(Error).Timeout() { + t.Fatalf("UnixListener.Accept failed: %v", err) + } + } else { + c.Close() + } + if c, err := ln.AcceptUnix(); err != nil { + if !err.(Error).Timeout() { + t.Fatalf("UnixListener.AcceptUnix failed: %v", err) + } + } else { + c.Close() + } + + if f, err := ln.File(); err != nil { + t.Fatalf("UnixListener.File failed: %v", err) + } else { + f.Close() + } +} + +func TestUnixConnSpecificMethods(t *testing.T) { + switch runtime.GOOS { + case "plan9", "windows": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + addr1, addr2, addr3 := testUnixAddr(), testUnixAddr(), testUnixAddr() + + a1, err := ResolveUnixAddr("unixgram", addr1) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c1, err := DialUnix("unixgram", a1, nil) + if err != nil { + t.Fatalf("DialUnix failed: %v", err) + } + defer c1.Close() + defer os.Remove(addr1) + c1.LocalAddr() + c1.RemoteAddr() + c1.SetDeadline(time.Now().Add(someTimeout)) + c1.SetReadDeadline(time.Now().Add(someTimeout)) + c1.SetWriteDeadline(time.Now().Add(someTimeout)) + c1.SetReadBuffer(2048) + c1.SetWriteBuffer(2048) + + a2, err := ResolveUnixAddr("unixgram", addr2) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c2, err := DialUnix("unixgram", a2, nil) + if err != nil { + t.Fatalf("DialUnix failed: %v", err) + } + defer c2.Close() + defer os.Remove(addr2) + c2.LocalAddr() + c2.RemoteAddr() + c2.SetDeadline(time.Now().Add(someTimeout)) + c2.SetReadDeadline(time.Now().Add(someTimeout)) + c2.SetWriteDeadline(time.Now().Add(someTimeout)) + c2.SetReadBuffer(2048) + c2.SetWriteBuffer(2048) + + a3, err := ResolveUnixAddr("unixgram", addr3) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c3, err := ListenUnixgram("unixgram", a3) + if err != nil { + t.Fatalf("ListenUnixgram failed: %v", err) + } + defer c3.Close() + defer os.Remove(addr3) + c3.LocalAddr() + c3.RemoteAddr() + c3.SetDeadline(time.Now().Add(someTimeout)) + c3.SetReadDeadline(time.Now().Add(someTimeout)) + c3.SetWriteDeadline(time.Now().Add(someTimeout)) + c3.SetReadBuffer(2048) + c3.SetWriteBuffer(2048) + + wb := []byte("UNIXCONN TEST") + rb1 := make([]byte, 128) + rb2 := make([]byte, 128) + rb3 := make([]byte, 128) + if _, _, err := c1.WriteMsgUnix(wb, nil, a2); err != nil { + t.Fatalf("UnixConn.WriteMsgUnix failed: %v", err) + } + if _, _, _, _, err := c2.ReadMsgUnix(rb2, nil); err != nil { + t.Fatalf("UnixConn.ReadMsgUnix failed: %v", err) + } + if _, err := c2.WriteToUnix(wb, a1); err != nil { + t.Fatalf("UnixConn.WriteToUnix failed: %v", err) + } + if _, _, err := c1.ReadFromUnix(rb1); err != nil { + t.Fatalf("UnixConn.ReadFromUnix failed: %v", err) + } + if _, err := c3.WriteToUnix(wb, a1); err != nil { + t.Fatalf("UnixConn.WriteToUnix failed: %v", err) + } + if _, _, err := c1.ReadFromUnix(rb1); err != nil { + t.Fatalf("UnixConn.ReadFromUnix failed: %v", err) + } + if _, err := c2.WriteToUnix(wb, a3); err != nil { + t.Fatalf("UnixConn.WriteToUnix failed: %v", err) + } + if _, _, err := c3.ReadFromUnix(rb3); err != nil { + t.Fatalf("UnixConn.ReadFromUnix failed: %v", err) + } + + if f, err := c1.File(); err != nil { + t.Fatalf("UnixConn.File failed: %v", err) + } else { + f.Close() + } +} diff --git a/src/pkg/net/rpc/client.go b/src/pkg/net/rpc/client.go index db2da8e44..4b0c9c3bb 100644 --- a/src/pkg/net/rpc/client.go +++ b/src/pkg/net/rpc/client.go @@ -71,7 +71,7 @@ func (client *Client) send(call *Call) { // Register this call. client.mutex.Lock() - if client.shutdown { + if client.shutdown || client.closing { call.Error = ErrShutdown client.mutex.Unlock() call.done() @@ -88,10 +88,13 @@ func (client *Client) send(call *Call) { err := client.codec.WriteRequest(&client.request, call.Args) if err != nil { client.mutex.Lock() + call = client.pending[seq] delete(client.pending, seq) client.mutex.Unlock() - call.Error = err - call.done() + if call != nil { + call.Error = err + call.done() + } } } @@ -102,9 +105,6 @@ func (client *Client) input() { response = Response{} err = client.codec.ReadResponseHeader(&response) if err != nil { - if err == io.EOF && !client.closing { - err = io.ErrUnexpectedEOF - } break } seq := response.Seq @@ -113,12 +113,18 @@ func (client *Client) input() { delete(client.pending, seq) client.mutex.Unlock() - if response.Error == "" { - err = client.codec.ReadResponseBody(call.Reply) + switch { + case call == nil: + // We've got no pending call. That usually means that + // WriteRequest partially failed, and call was already + // removed; response is a server telling us about an + // error reading request body. We should still attempt + // to read error body, but there's no one to give it to. + err = client.codec.ReadResponseBody(nil) if err != nil { - call.Error = errors.New("reading body " + err.Error()) + err = errors.New("reading error body: " + err.Error()) } - } else { + case response.Error != "": // We've got an error response. Give this to the request; // any subsequent requests will get the ReadResponseBody // error if there is one. @@ -127,14 +133,27 @@ func (client *Client) input() { if err != nil { err = errors.New("reading error body: " + err.Error()) } + call.done() + default: + err = client.codec.ReadResponseBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() } - call.done() } // Terminate pending calls. client.sending.Lock() client.mutex.Lock() client.shutdown = true closing := client.closing + if err == io.EOF { + if closing { + err = ErrShutdown + } else { + err = io.ErrUnexpectedEOF + } + } for _, call := range client.pending { call.Error = err call.done() @@ -213,7 +232,7 @@ func DialHTTP(network, address string) (*Client, error) { return DialHTTPPath(network, address, DefaultRPCPath) } -// DialHTTPPath connects to an HTTP RPC server +// DialHTTPPath connects to an HTTP RPC server // at the specified network address and path. func DialHTTPPath(network, address, path string) (*Client, error) { var err error diff --git a/src/pkg/net/rpc/jsonrpc/all_test.go b/src/pkg/net/rpc/jsonrpc/all_test.go index e6c7441f0..3c7c4d48f 100644 --- a/src/pkg/net/rpc/jsonrpc/all_test.go +++ b/src/pkg/net/rpc/jsonrpc/all_test.go @@ -24,6 +24,12 @@ type Reply struct { type Arith int +type ArithAddResp struct { + Id interface{} `json:"id"` + Result Reply `json:"result"` + Error interface{} `json:"error"` +} + func (t *Arith) Add(args *Args, reply *Reply) error { reply.C = args.A + args.B return nil @@ -50,13 +56,39 @@ func init() { rpc.Register(new(Arith)) } -func TestServer(t *testing.T) { - type addResp struct { - Id interface{} `json:"id"` - Result Reply `json:"result"` - Error interface{} `json:"error"` +func TestServerNoParams(t *testing.T) { + cli, srv := net.Pipe() + defer cli.Close() + go ServeConn(srv) + dec := json.NewDecoder(cli) + + fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "123"}`) + var resp ArithAddResp + if err := dec.Decode(&resp); err != nil { + t.Fatalf("Decode after no params: %s", err) + } + if resp.Error == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestServerEmptyMessage(t *testing.T) { + cli, srv := net.Pipe() + defer cli.Close() + go ServeConn(srv) + dec := json.NewDecoder(cli) + + fmt.Fprintf(cli, "{}") + var resp ArithAddResp + if err := dec.Decode(&resp); err != nil { + t.Fatalf("Decode after empty: %s", err) } + if resp.Error == nil { + t.Fatalf("Expected error, got nil") + } +} +func TestServer(t *testing.T) { cli, srv := net.Pipe() defer cli.Close() go ServeConn(srv) @@ -65,7 +97,7 @@ func TestServer(t *testing.T) { // Send hand-coded requests to server, parse responses. for i := 0; i < 10; i++ { fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1) - var resp addResp + var resp ArithAddResp err := dec.Decode(&resp) if err != nil { t.Fatalf("Decode: %s", err) @@ -80,15 +112,6 @@ func TestServer(t *testing.T) { t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C) } } - - fmt.Fprintf(cli, "{}\n") - var resp addResp - if err := dec.Decode(&resp); err != nil { - t.Fatalf("Decode after empty: %s", err) - } - if resp.Error == nil { - t.Fatalf("Expected error, got nil") - } } func TestClient(t *testing.T) { @@ -108,7 +131,7 @@ func TestClient(t *testing.T) { t.Errorf("Add: expected no error but got string %q", err.Error()) } if reply.C != args.A+args.B { - t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B) } args = &Args{7, 8} @@ -118,7 +141,7 @@ func TestClient(t *testing.T) { t.Errorf("Mul: expected no error but got string %q", err.Error()) } if reply.C != args.A*args.B { - t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) + t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B) } // Out of order. @@ -133,7 +156,7 @@ func TestClient(t *testing.T) { t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) } if addReply.C != args.A+args.B { - t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) + t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B) } mulCall = <-mulCall.Done @@ -141,7 +164,7 @@ func TestClient(t *testing.T) { t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) } if mulReply.C != args.A*args.B { - t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) + t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B) } // Error test diff --git a/src/pkg/net/rpc/jsonrpc/server.go b/src/pkg/net/rpc/jsonrpc/server.go index 4c54553a7..5bc05fd0a 100644 --- a/src/pkg/net/rpc/jsonrpc/server.go +++ b/src/pkg/net/rpc/jsonrpc/server.go @@ -12,6 +12,8 @@ import ( "sync" ) +var errMissingParams = errors.New("jsonrpc: request body missing params") + type serverCodec struct { dec *json.Decoder // for reading JSON values enc *json.Encoder // for writing JSON values @@ -50,12 +52,8 @@ type serverRequest struct { func (r *serverRequest) reset() { r.Method = "" - if r.Params != nil { - *r.Params = (*r.Params)[0:0] - } - if r.Id != nil { - *r.Id = (*r.Id)[0:0] - } + r.Params = nil + r.Id = nil } type serverResponse struct { @@ -88,6 +86,9 @@ func (c *serverCodec) ReadRequestBody(x interface{}) error { if x == nil { return nil } + if c.req.Params == nil { + return errMissingParams + } // JSON params is array value. // RPC params is struct. // Unmarshal into array containing struct for now. diff --git a/src/pkg/net/rpc/server.go b/src/pkg/net/rpc/server.go index 1680e2f0d..e71b6fb1a 100644 --- a/src/pkg/net/rpc/server.go +++ b/src/pkg/net/rpc/server.go @@ -24,12 +24,13 @@ where T, T1 and T2 can be marshaled by encoding/gob. These requirements apply even if a different codec is used. - (In future, these requirements may soften for custom codecs.) + (In the future, these requirements may soften for custom codecs.) The method's first argument represents the arguments provided by the caller; the second argument represents the result parameters to be returned to the caller. The method's return value, if non-nil, is passed back as a string that the client - sees as if created by errors.New. + sees as if created by errors.New. If an error is returned, the reply parameter + will not be sent back to the client. The server may handle requests on a single connection by calling ServeConn. More typically it will create a network listener and call Accept or, for an HTTP @@ -111,7 +112,7 @@ // Asynchronous call quotient := new(Quotient) - divCall := client.Go("Arith.Divide", args, "ient, nil) + divCall := client.Go("Arith.Divide", args, quotient, nil) replyCall := <-divCall.Done // will be equal to divCall // check errors, print, etc. @@ -181,7 +182,7 @@ type Response struct { // Server represents an RPC Server. type Server struct { - mu sync.Mutex // protects the serviceMap + mu sync.RWMutex // protects the serviceMap serviceMap map[string]*service reqLock sync.Mutex // protects freeReq freeReq *Request @@ -218,15 +219,15 @@ func isExportedOrBuiltinType(t reflect.Type) bool { // - exported method // - two arguments, both pointers to exported structs // - one return value, of type error -// It returns an error if the receiver is not an exported type or has no -// suitable methods. +// It returns an error if the receiver is not an exported type or has +// no methods or unsuitable methods. It also logs the error using package log. // The client accesses each method using a string of the form "Type.Method", // where Type is the receiver's concrete type. func (server *Server) Register(rcvr interface{}) error { return server.register(rcvr, "", false) } -// RegisterName is like Register but uses the provided name for the type +// RegisterName is like Register but uses the provided name for the type // instead of the receiver's concrete type. func (server *Server) RegisterName(name string, rcvr interface{}) error { return server.register(rcvr, name, true) @@ -260,8 +261,30 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro s.method = make(map[string]*methodType) // Install the methods - for m := 0; m < s.typ.NumMethod(); m++ { - method := s.typ.Method(m) + s.method = suitableMethods(s.typ, true) + + if len(s.method) == 0 { + str := "" + // To help the user, see if a pointer receiver would work. + method := suitableMethods(reflect.PtrTo(s.typ), false) + if len(method) != 0 { + str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)" + } else { + str = "rpc.Register: type " + sname + " has no exported methods of suitable type" + } + log.Print(str) + return errors.New(str) + } + server.serviceMap[s.name] = s + return nil +} + +// suitableMethods returns suitable Rpc methods of typ, it will report +// error using log if reportErr is true. +func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { + methods := make(map[string]*methodType) + for m := 0; m < typ.NumMethod(); m++ { + method := typ.Method(m) mtype := method.Type mname := method.Name // Method must be exported. @@ -270,46 +293,51 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro } // Method needs three ins: receiver, *args, *reply. if mtype.NumIn() != 3 { - log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) + if reportErr { + log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) + } continue } // First arg need not be a pointer. argType := mtype.In(1) if !isExportedOrBuiltinType(argType) { - log.Println(mname, "argument type not exported:", argType) + if reportErr { + log.Println(mname, "argument type not exported:", argType) + } continue } // Second arg must be a pointer. replyType := mtype.In(2) if replyType.Kind() != reflect.Ptr { - log.Println("method", mname, "reply type not a pointer:", replyType) + if reportErr { + log.Println("method", mname, "reply type not a pointer:", replyType) + } continue } // Reply type must be exported. if !isExportedOrBuiltinType(replyType) { - log.Println("method", mname, "reply type not exported:", replyType) + if reportErr { + log.Println("method", mname, "reply type not exported:", replyType) + } continue } // Method needs one out. if mtype.NumOut() != 1 { - log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) + if reportErr { + log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) + } continue } // The return type of the method must be error. if returnType := mtype.Out(0); returnType != typeOfError { - log.Println("method", mname, "returns", returnType.String(), "not error") + if reportErr { + log.Println("method", mname, "returns", returnType.String(), "not error") + } continue } - s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType} + methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType} } - - if len(s.method) == 0 { - s := "rpc Register: type " + sname + " has no exported methods of suitable type" - log.Print(s) - return errors.New(s) - } - server.serviceMap[s.name] = s - return nil + return methods } // A value sent as a placeholder for the server's response value when the server @@ -538,9 +566,9 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt return } // Look up the request. - server.mu.Lock() + server.mu.RLock() service = server.serviceMap[serviceMethod[0]] - server.mu.Unlock() + server.mu.RUnlock() if service == nil { err = errors.New("rpc: can't find service " + req.ServiceMethod) return @@ -568,7 +596,7 @@ func (server *Server) Accept(lis net.Listener) { // Register publishes the receiver's methods in the DefaultServer. func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } -// RegisterName is like Register but uses the provided name for the type +// RegisterName is like Register but uses the provided name for the type // instead of the receiver's concrete type. func RegisterName(name string, rcvr interface{}) error { return DefaultServer.RegisterName(name, rcvr) @@ -611,7 +639,7 @@ func ServeRequest(codec ServerCodec) error { } // Accept accepts connections on the listener and serves requests -// to DefaultServer for each incoming connection. +// to DefaultServer for each incoming connection. // Accept blocks; the caller typically invokes it in a go statement. func Accept(lis net.Listener) { DefaultServer.Accept(lis) } diff --git a/src/pkg/net/rpc/server_test.go b/src/pkg/net/rpc/server_test.go index 62c7b1e60..8a1530623 100644 --- a/src/pkg/net/rpc/server_test.go +++ b/src/pkg/net/rpc/server_test.go @@ -349,6 +349,7 @@ func testServeRequest(t *testing.T, server *Server) { type ReplyNotPointer int type ArgNotPublic int type ReplyNotPublic int +type NeedsPtrType int type local struct{} func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error { @@ -363,19 +364,29 @@ func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error { return nil } +func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error { + return nil +} + // Check that registration handles lots of bad methods and a type with no suitable methods. func TestRegistrationError(t *testing.T) { err := Register(new(ReplyNotPointer)) if err == nil { - t.Errorf("expected error registering ReplyNotPointer") + t.Error("expected error registering ReplyNotPointer") } err = Register(new(ArgNotPublic)) if err == nil { - t.Errorf("expected error registering ArgNotPublic") + t.Error("expected error registering ArgNotPublic") } err = Register(new(ReplyNotPublic)) if err == nil { - t.Errorf("expected error registering ReplyNotPublic") + t.Error("expected error registering ReplyNotPublic") + } + err = Register(NeedsPtrType(0)) + if err == nil { + t.Error("expected error registering NeedsPtrType") + } else if !strings.Contains(err.Error(), "pointer") { + t.Error("expected hint when registering NeedsPtrType") } } @@ -434,7 +445,7 @@ func dialHTTP() (*Client, error) { return DialHTTP("tcp", httpServerAddr) } -func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { +func countMallocs(dial func() (*Client, error), t *testing.T) float64 { once.Do(startServer) client, err := dial() if err != nil { @@ -442,11 +453,7 @@ func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { } args := &Args{7, 8} reply := new(Reply) - memstats := new(runtime.MemStats) - runtime.ReadMemStats(memstats) - mallocs := 0 - memstats.Mallocs - const count = 100 - for i := 0; i < count; i++ { + return testing.AllocsPerRun(100, func() { err := client.Call("Arith.Add", args, reply) if err != nil { t.Errorf("Add: expected no error but got string %q", err.Error()) @@ -454,18 +461,15 @@ func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { if reply.C != args.A+args.B { t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) } - } - runtime.ReadMemStats(memstats) - mallocs += memstats.Mallocs - return mallocs / count + }) } func TestCountMallocs(t *testing.T) { - fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t)) + fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t)) } func TestCountMallocsOverHTTP(t *testing.T) { - fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t)) + fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t)) } type writeCrasher struct { @@ -499,6 +503,44 @@ func TestClientWriteError(t *testing.T) { w.done <- true } +func TestTCPClose(t *testing.T) { + once.Do(startServer) + + client, err := dialHTTP() + if err != nil { + t.Fatalf("dialing: %v", err) + } + defer client.Close() + + args := Args{17, 8} + var reply Reply + err = client.Call("Arith.Mul", args, &reply) + if err != nil { + t.Fatal("arith error:", err) + } + t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply) + if reply.C != args.A*args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B) + } +} + +func TestErrorAfterClientClose(t *testing.T) { + once.Do(startServer) + + client, err := dialHTTP() + if err != nil { + t.Fatalf("dialing: %v", err) + } + err = client.Close() + if err != nil { + t.Fatal("close error:", err) + } + err = client.Call("Arith.Add", &Args{7, 9}, new(Reply)) + if err != ErrShutdown { + t.Errorf("Forever: expected ErrShutdown got %v", err) + } +} + func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { b.StopTimer() once.Do(startServer) diff --git a/src/pkg/net/sendfile_freebsd.go b/src/pkg/net/sendfile_freebsd.go new file mode 100644 index 000000000..8008bc3b5 --- /dev/null +++ b/src/pkg/net/sendfile_freebsd.go @@ -0,0 +1,105 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "io" + "os" + "syscall" +) + +// maxSendfileSize is the largest chunk size we ask the kernel to copy +// at a time. +const maxSendfileSize int = 4 << 20 + +// sendFile copies the contents of r to c using the sendfile +// system call to minimize copies. +// +// if handled == true, sendFile returns the number of bytes copied and any +// non-EOF error. +// +// if handled == false, sendFile performed no work. +func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { + // FreeBSD uses 0 as the "until EOF" value. If you pass in more bytes than the + // file contains, it will loop back to the beginning ad nauseum until it's sent + // exactly the number of bytes told to. As such, we need to know exactly how many + // bytes to send. + var remain int64 = 0 + + lr, ok := r.(*io.LimitedReader) + if ok { + remain, r = lr.N, lr.R + if remain <= 0 { + return 0, nil, true + } + } + f, ok := r.(*os.File) + if !ok { + return 0, nil, false + } + + if remain == 0 { + fi, err := f.Stat() + if err != nil { + return 0, err, false + } + + remain = fi.Size() + } + + // The other quirk with FreeBSD's sendfile implementation is that it doesn't + // use the current position of the file -- if you pass it offset 0, it starts + // from offset 0. There's no way to tell it "start from current position", so + // we have to manage that explicitly. + pos, err := f.Seek(0, os.SEEK_CUR) + if err != nil { + return 0, err, false + } + + c.wio.Lock() + defer c.wio.Unlock() + if err := c.incref(false); err != nil { + return 0, err, true + } + defer c.decref() + + dst := c.sysfd + src := int(f.Fd()) + for remain > 0 { + n := maxSendfileSize + if int64(n) > remain { + n = int(remain) + } + pos1 := pos + n, err1 := syscall.Sendfile(dst, src, &pos1, n) + if n > 0 { + pos += int64(n) + written += int64(n) + remain -= int64(n) + } + if n == 0 && err1 == nil { + break + } + if err1 == syscall.EAGAIN { + if err1 = c.pollServer.WaitWrite(c); err1 == nil { + continue + } + } + if err1 == syscall.EINTR { + continue + } + if err1 != nil { + // This includes syscall.ENOSYS (no kernel + // support) and syscall.EINVAL (fd types which + // don't implement sendfile together) + err = &OpError{"sendfile", c.net, c.raddr, err1} + break + } + } + if lr != nil { + lr.N = remain + } + return written, err, written > 0 +} diff --git a/src/pkg/net/sendfile_linux.go b/src/pkg/net/sendfile_linux.go index a0d530362..3357e6538 100644 --- a/src/pkg/net/sendfile_linux.go +++ b/src/pkg/net/sendfile_linux.go @@ -58,8 +58,8 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { if n == 0 && err1 == nil { break } - if err1 == syscall.EAGAIN && c.wdeadline >= 0 { - if err1 = pollserver.WaitWrite(c); err1 == nil { + if err1 == syscall.EAGAIN { + if err1 = c.pollServer.WaitWrite(c); err1 == nil { continue } } diff --git a/src/pkg/net/sendfile_stub.go b/src/pkg/net/sendfile_stub.go index ff76ab9cf..3660849c1 100644 --- a/src/pkg/net/sendfile_stub.go +++ b/src/pkg/net/sendfile_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd netbsd openbsd +// +build darwin netbsd openbsd package net diff --git a/src/pkg/net/sendfile_windows.go b/src/pkg/net/sendfile_windows.go index f5a6d8804..2d64f2f5b 100644 --- a/src/pkg/net/sendfile_windows.go +++ b/src/pkg/net/sendfile_windows.go @@ -48,12 +48,12 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { return 0, nil, false } - c.wio.Lock() - defer c.wio.Unlock() if err := c.incref(false); err != nil { return 0, err, true } defer c.decref() + c.wio.Lock() + defer c.wio.Unlock() var o sendfileOp o.Init(c, 'w') diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go index 158b9477d..25c2be5a7 100644 --- a/src/pkg/net/server_test.go +++ b/src/pkg/net/server_test.go @@ -113,8 +113,7 @@ func TestStreamConnServer(t *testing.T) { case "tcp", "tcp4", "tcp6": _, port, err := SplitHostPort(taddr) if err != nil { - t.Errorf("SplitHostPort(%q) failed: %v", taddr, err) - return + t.Fatalf("SplitHostPort(%q) failed: %v", taddr, err) } taddr = tt.caddr + ":" + port } @@ -142,8 +141,7 @@ var seqpacketConnServerTests = []struct { func TestSeqpacketConnServer(t *testing.T) { if runtime.GOOS != "linux" { - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } for _, tt := range seqpacketConnServerTests { @@ -170,11 +168,11 @@ func TestSeqpacketConnServer(t *testing.T) { } func runStreamConnServer(t *testing.T, net, laddr string, listening chan<- string, done chan<- int) { + defer close(done) l, err := Listen(net, laddr) if err != nil { t.Errorf("Listen(%q, %q) failed: %v", net, laddr, err) listening <- "<nil>" - done <- 1 return } defer l.Close() @@ -189,13 +187,14 @@ func runStreamConnServer(t *testing.T, net, laddr string, listening chan<- strin } rw.Write(buf[0:n]) } - done <- 1 + close(done) } run: for { c, err := l.Accept() if err != nil { + t.Logf("Accept failed: %v", err) continue run } echodone := make(chan int) @@ -204,14 +203,12 @@ run: c.Close() break run } - done <- 1 } func runStreamConnClient(t *testing.T, net, taddr string, isEmpty bool) { c, err := Dial(net, taddr) if err != nil { - t.Errorf("Dial(%q, %q) failed: %v", net, taddr, err) - return + t.Fatalf("Dial(%q, %q) failed: %v", net, taddr, err) } defer c.Close() c.SetReadDeadline(time.Now().Add(1 * time.Second)) @@ -221,14 +218,12 @@ func runStreamConnClient(t *testing.T, net, taddr string, isEmpty bool) { wb = []byte("StreamConnClient by Dial\n") } if n, err := c.Write(wb); err != nil || n != len(wb) { - t.Errorf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb)) - return + t.Fatalf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb)) } rb := make([]byte, 1024) if n, err := c.Read(rb[0:]); err != nil || n != len(wb) { - t.Errorf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb)) - return + t.Fatalf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb)) } // Send explicit ending for unixpacket. @@ -334,8 +329,7 @@ func TestDatagramPacketConnServer(t *testing.T) { case "udp", "udp4", "udp6": _, port, err := SplitHostPort(taddr) if err != nil { - t.Errorf("SplitHostPort(%q) failed: %v", taddr, err) - return + t.Fatalf("SplitHostPort(%q) failed: %v", taddr, err) } taddr = tt.caddr + ":" + port tt.caddr += ":0" @@ -398,14 +392,12 @@ func runDatagramConnClient(t *testing.T, net, laddr, taddr string, isEmpty bool) case "udp", "udp4", "udp6": c, err = Dial(net, taddr) if err != nil { - t.Errorf("Dial(%q, %q) failed: %v", net, taddr, err) - return + t.Fatalf("Dial(%q, %q) failed: %v", net, taddr, err) } case "unixgram": c, err = DialUnix(net, &UnixAddr{laddr, net}, &UnixAddr{taddr, net}) if err != nil { - t.Errorf("DialUnix(%q, {%q, %q}) failed: %v", net, laddr, taddr, err) - return + t.Fatalf("DialUnix(%q, {%q, %q}) failed: %v", net, laddr, taddr, err) } } defer c.Close() @@ -416,14 +408,12 @@ func runDatagramConnClient(t *testing.T, net, laddr, taddr string, isEmpty bool) wb = []byte("DatagramConnClient by Dial\n") } if n, err := c.Write(wb[0:]); err != nil || n != len(wb) { - t.Errorf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb)) - return + t.Fatalf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb)) } rb := make([]byte, 1024) if n, err := c.Read(rb[0:]); err != nil || n != len(wb) { - t.Errorf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb)) - return + t.Fatalf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb)) } } @@ -434,20 +424,17 @@ func runDatagramPacketConnClient(t *testing.T, net, laddr, taddr string, isEmpty case "udp", "udp4", "udp6": ra, err = ResolveUDPAddr(net, taddr) if err != nil { - t.Errorf("ResolveUDPAddr(%q, %q) failed: %v", net, taddr, err) - return + t.Fatalf("ResolveUDPAddr(%q, %q) failed: %v", net, taddr, err) } case "unixgram": ra, err = ResolveUnixAddr(net, taddr) if err != nil { - t.Errorf("ResolveUxixAddr(%q, %q) failed: %v", net, taddr, err) - return + t.Fatalf("ResolveUxixAddr(%q, %q) failed: %v", net, taddr, err) } } c, err := ListenPacket(net, laddr) if err != nil { - t.Errorf("ListenPacket(%q, %q) faild: %v", net, laddr, err) - return + t.Fatalf("ListenPacket(%q, %q) faild: %v", net, laddr, err) } defer c.Close() c.SetReadDeadline(time.Now().Add(1 * time.Second)) @@ -457,13 +444,11 @@ func runDatagramPacketConnClient(t *testing.T, net, laddr, taddr string, isEmpty wb = []byte("DatagramPacketConnClient by ListenPacket\n") } if n, err := c.WriteTo(wb[0:], ra); err != nil || n != len(wb) { - t.Errorf("WriteTo(%v) failed: %v, %v; want %v, <nil>", ra, n, err, len(wb)) - return + t.Fatalf("WriteTo(%v) failed: %v, %v; want %v, <nil>", ra, n, err, len(wb)) } rb := make([]byte, 1024) if n, _, err := c.ReadFrom(rb[0:]); err != nil || n != len(wb) { - t.Errorf("ReadFrom failed: %v, %v; want %v, <nil>", n, err, len(wb)) - return + t.Fatalf("ReadFrom failed: %v, %v; want %v, <nil>", n, err, len(wb)) } } diff --git a/src/pkg/net/smtp/smtp.go b/src/pkg/net/smtp/smtp.go index 59f6449f0..4b9177877 100644 --- a/src/pkg/net/smtp/smtp.go +++ b/src/pkg/net/smtp/smtp.go @@ -13,6 +13,7 @@ package smtp import ( "crypto/tls" "encoding/base64" + "errors" "io" "net" "net/textproto" @@ -33,7 +34,10 @@ type Client struct { // map of supported extensions ext map[string]string // supported auth mechanisms - auth []string + auth []string + localName string // the name to use in HELO/EHLO + didHello bool // whether we've said HELO/EHLO + helloError error // the error from the hello } // Dial returns a new Client connected to an SMTP server at addr. @@ -55,12 +59,33 @@ func NewClient(conn net.Conn, host string) (*Client, error) { text.Close() return nil, err } - c := &Client{Text: text, conn: conn, serverName: host} - err = c.ehlo() - if err != nil { - err = c.helo() + c := &Client{Text: text, conn: conn, serverName: host, localName: "localhost"} + return c, nil +} + +// hello runs a hello exchange if needed. +func (c *Client) hello() error { + if !c.didHello { + c.didHello = true + err := c.ehlo() + if err != nil { + c.helloError = c.helo() + } + } + return c.helloError +} + +// Hello sends a HELO or EHLO to the server as the given host name. +// Calling this method is only necessary if the client needs control +// over the host name used. The client will introduce itself as "localhost" +// automatically otherwise. If Hello is called, it must be called before +// any of the other methods. +func (c *Client) Hello(localName string) error { + if c.didHello { + return errors.New("smtp: Hello called after other methods") } - return c, err + c.localName = localName + return c.hello() } // cmd is a convenience function that sends a command and returns the response @@ -79,14 +104,14 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s // server does not support ehlo. func (c *Client) helo() error { c.ext = nil - _, _, err := c.cmd(250, "HELO localhost") + _, _, err := c.cmd(250, "HELO %s", c.localName) return err } // ehlo sends the EHLO (extended hello) greeting to the server. It // should be the preferred greeting for servers that support it. func (c *Client) ehlo() error { - _, msg, err := c.cmd(250, "EHLO localhost") + _, msg, err := c.cmd(250, "EHLO %s", c.localName) if err != nil { return err } @@ -113,6 +138,9 @@ func (c *Client) ehlo() error { // StartTLS sends the STARTTLS command and encrypts all further communication. // Only servers that advertise the STARTTLS extension support this function. func (c *Client) StartTLS(config *tls.Config) error { + if err := c.hello(); err != nil { + return err + } _, _, err := c.cmd(220, "STARTTLS") if err != nil { return err @@ -128,6 +156,9 @@ func (c *Client) StartTLS(config *tls.Config) error { // does not necessarily indicate an invalid address. Many servers // will not verify addresses for security reasons. func (c *Client) Verify(addr string) error { + if err := c.hello(); err != nil { + return err + } _, _, err := c.cmd(250, "VRFY %s", addr) return err } @@ -136,6 +167,9 @@ func (c *Client) Verify(addr string) error { // A failed authentication closes the connection. // Only servers that advertise the AUTH extension support this function. func (c *Client) Auth(a Auth) error { + if err := c.hello(); err != nil { + return err + } encoding := base64.StdEncoding mech, resp, err := a.Start(&ServerInfo{c.serverName, c.tls, c.auth}) if err != nil { @@ -178,6 +212,9 @@ func (c *Client) Auth(a Auth) error { // parameter. // This initiates a mail transaction and is followed by one or more Rcpt calls. func (c *Client) Mail(from string) error { + if err := c.hello(); err != nil { + return err + } cmdStr := "MAIL FROM:<%s>" if c.ext != nil { if _, ok := c.ext["8BITMIME"]; ok { @@ -227,6 +264,9 @@ func SendMail(addr string, a Auth, from string, to []string, msg []byte) error { if err != nil { return err } + if err := c.hello(); err != nil { + return err + } if ok, _ := c.Extension("STARTTLS"); ok { if err = c.StartTLS(nil); err != nil { return err @@ -267,6 +307,9 @@ func SendMail(addr string, a Auth, from string, to []string, msg []byte) error { // Extension also returns a string that contains any parameters the // server specifies for the extension. func (c *Client) Extension(ext string) (bool, string) { + if err := c.hello(); err != nil { + return false, "" + } if c.ext == nil { return false, "" } @@ -278,12 +321,18 @@ func (c *Client) Extension(ext string) (bool, string) { // Reset sends the RSET command to the server, aborting the current mail // transaction. func (c *Client) Reset() error { + if err := c.hello(); err != nil { + return err + } _, _, err := c.cmd(250, "RSET") return err } // Quit sends the QUIT command and closes the connection to the server. func (c *Client) Quit() error { + if err := c.hello(); err != nil { + return err + } _, _, err := c.cmd(221, "QUIT") if err != nil { return err diff --git a/src/pkg/net/smtp/smtp_test.go b/src/pkg/net/smtp/smtp_test.go index c315d185c..8317428cb 100644 --- a/src/pkg/net/smtp/smtp_test.go +++ b/src/pkg/net/smtp/smtp_test.go @@ -69,14 +69,14 @@ func (f faker) SetReadDeadline(time.Time) error { return nil } func (f faker) SetWriteDeadline(time.Time) error { return nil } func TestBasic(t *testing.T) { - basicServer = strings.Join(strings.Split(basicServer, "\n"), "\r\n") - basicClient = strings.Join(strings.Split(basicClient, "\n"), "\r\n") + server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") var cmdbuf bytes.Buffer bcmdbuf := bufio.NewWriter(&cmdbuf) var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(basicServer)), bcmdbuf) - c := &Client{Text: textproto.NewConn(fake)} + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) + c := &Client{Text: textproto.NewConn(fake), localName: "localhost"} if err := c.helo(); err != nil { t.Fatalf("HELO failed: %s", err) @@ -88,6 +88,7 @@ func TestBasic(t *testing.T) { t.Fatalf("Second EHLO failed: %s", err) } + c.didHello = true if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" { t.Fatalf("Expected AUTH supported") } @@ -143,8 +144,8 @@ Goodbye.` bcmdbuf.Flush() actualcmds := cmdbuf.String() - if basicClient != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, basicClient) + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) } } @@ -187,8 +188,8 @@ QUIT ` func TestNewClient(t *testing.T) { - newClientServer = strings.Join(strings.Split(newClientServer, "\n"), "\r\n") - newClientClient = strings.Join(strings.Split(newClientClient, "\n"), "\r\n") + server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") + client := strings.Join(strings.Split(newClientClient, "\n"), "\r\n") var cmdbuf bytes.Buffer bcmdbuf := bufio.NewWriter(&cmdbuf) @@ -197,7 +198,7 @@ func TestNewClient(t *testing.T) { return cmdbuf.String() } var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(newClientServer)), bcmdbuf) + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) c, err := NewClient(fake, "fake.host") if err != nil { t.Fatalf("NewClient: %v\n(after %v)", err, out()) @@ -213,8 +214,8 @@ func TestNewClient(t *testing.T) { } actualcmds := out() - if newClientClient != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, newClientClient) + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) } } @@ -231,13 +232,13 @@ QUIT ` func TestNewClient2(t *testing.T) { - newClient2Server = strings.Join(strings.Split(newClient2Server, "\n"), "\r\n") - newClient2Client = strings.Join(strings.Split(newClient2Client, "\n"), "\r\n") + server := strings.Join(strings.Split(newClient2Server, "\n"), "\r\n") + client := strings.Join(strings.Split(newClient2Client, "\n"), "\r\n") var cmdbuf bytes.Buffer bcmdbuf := bufio.NewWriter(&cmdbuf) var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(newClient2Server)), bcmdbuf) + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) c, err := NewClient(fake, "fake.host") if err != nil { t.Fatalf("NewClient: %v", err) @@ -251,8 +252,8 @@ func TestNewClient2(t *testing.T) { bcmdbuf.Flush() actualcmds := cmdbuf.String() - if newClient2Client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, newClient2Client) + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) } } @@ -269,3 +270,199 @@ var newClient2Client = `EHLO localhost HELO localhost QUIT ` + +func TestHello(t *testing.T) { + + if len(helloServer) != len(helloClient) { + t.Fatalf("Hello server and client size mismatch") + } + + for i := 0; i < len(helloServer); i++ { + server := strings.Join(strings.Split(baseHelloServer+helloServer[i], "\n"), "\r\n") + client := strings.Join(strings.Split(baseHelloClient+helloClient[i], "\n"), "\r\n") + var cmdbuf bytes.Buffer + bcmdbuf := bufio.NewWriter(&cmdbuf) + var fake faker + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) + c, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("NewClient: %v", err) + } + c.localName = "customhost" + err = nil + + switch i { + case 0: + err = c.Hello("customhost") + case 1: + err = c.StartTLS(nil) + if err.Error() == "502 Not implemented" { + err = nil + } + case 2: + err = c.Verify("test@example.com") + case 3: + c.tls = true + c.serverName = "smtp.google.com" + err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com")) + case 4: + err = c.Mail("test@example.com") + case 5: + ok, _ := c.Extension("feature") + if ok { + t.Errorf("Expected FEATURE not to be supported") + } + case 6: + err = c.Reset() + case 7: + err = c.Quit() + case 8: + err = c.Verify("test@example.com") + if err != nil { + err = c.Hello("customhost") + if err != nil { + t.Errorf("Want error, got none") + } + } + default: + t.Fatalf("Unhandled command") + } + + if err != nil { + t.Errorf("Command %d failed: %v", i, err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + if client != actualcmds { + t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + } +} + +var baseHelloServer = `220 hello world +502 EH? +250-mx.google.com at your service +250 FEATURE +` + +var helloServer = []string{ + "", + "502 Not implemented\n", + "250 User is valid\n", + "235 Accepted\n", + "250 Sender ok\n", + "", + "250 Reset ok\n", + "221 Goodbye\n", + "250 Sender ok\n", +} + +var baseHelloClient = `EHLO customhost +HELO customhost +` + +var helloClient = []string{ + "", + "STARTTLS\n", + "VRFY test@example.com\n", + "AUTH PLAIN AHVzZXIAcGFzcw==\n", + "MAIL FROM:<test@example.com>\n", + "", + "RSET\n", + "QUIT\n", + "VRFY test@example.com\n", +} + +func TestSendMail(t *testing.T) { + server := strings.Join(strings.Split(sendMailServer, "\n"), "\r\n") + client := strings.Join(strings.Split(sendMailClient, "\n"), "\r\n") + var cmdbuf bytes.Buffer + bcmdbuf := bufio.NewWriter(&cmdbuf) + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to to create listener: %v", err) + } + defer l.Close() + + // prevent data race on bcmdbuf + var done = make(chan struct{}) + go func(data []string) { + + defer close(done) + + conn, err := l.Accept() + if err != nil { + t.Errorf("Accept error: %v", err) + return + } + defer conn.Close() + + tc := textproto.NewConn(conn) + for i := 0; i < len(data) && data[i] != ""; i++ { + tc.PrintfLine(data[i]) + for len(data[i]) >= 4 && data[i][3] == '-' { + i++ + tc.PrintfLine(data[i]) + } + if data[i] == "221 Goodbye" { + return + } + read := false + for !read || data[i] == "354 Go ahead" { + msg, err := tc.ReadLine() + bcmdbuf.Write([]byte(msg + "\r\n")) + read = true + if err != nil { + t.Errorf("Read error: %v", err) + return + } + if data[i] == "354 Go ahead" && msg == "." { + break + } + } + } + }(strings.Split(server, "\r\n")) + + err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com +To: other@example.com +Subject: SendMail test + +SendMail is working for me. +`, "\n", "\r\n", -1))) + + if err != nil { + t.Errorf("%v", err) + } + + <-done + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + if client != actualcmds { + t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } +} + +var sendMailServer = `220 hello world +502 EH? +250 mx.google.com at your service +250 Sender ok +250 Receiver ok +354 Go ahead +250 Data ok +221 Goodbye +` + +var sendMailClient = `EHLO localhost +HELO localhost +MAIL FROM:<test@example.com> +RCPT TO:<other@example.com> +DATA +From: test@example.com +To: other@example.com +Subject: SendMail test + +SendMail is working for me. +. +QUIT +` diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go deleted file mode 100644 index 3ae16054e..000000000 --- a/src/pkg/net/sock.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build darwin freebsd linux netbsd openbsd windows - -// Sockets - -package net - -import ( - "io" - "syscall" -) - -var listenerBacklog = maxListenerBacklog() - -// Generic socket creation. -func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { - // See ../syscall/exec.go for description of ForkLock. - syscall.ForkLock.RLock() - s, err := syscall.Socket(f, t, p) - if err != nil { - syscall.ForkLock.RUnlock() - return nil, err - } - syscall.CloseOnExec(s) - syscall.ForkLock.RUnlock() - - err = setDefaultSockopts(s, f, t, ipv6only) - if err != nil { - closesocket(s) - return nil, err - } - - var bla syscall.Sockaddr - if la != nil { - bla, err = listenerSockaddr(s, f, la, toAddr) - if err != nil { - closesocket(s) - return nil, err - } - err = syscall.Bind(s, bla) - if err != nil { - closesocket(s) - return nil, err - } - } - - if fd, err = newFD(s, f, t, net); err != nil { - closesocket(s) - return nil, err - } - - if ra != nil { - if err = fd.connect(ra); err != nil { - closesocket(s) - fd.Close() - return nil, err - } - fd.isConnected = true - } - - sa, _ := syscall.Getsockname(s) - var laddr Addr - if la != nil && bla != la { - laddr = toAddr(la) - } else { - laddr = toAddr(sa) - } - sa, _ = syscall.Getpeername(s) - raddr := toAddr(sa) - - fd.setAddr(laddr, raddr) - return fd, nil -} - -type writerOnly struct { - io.Writer -} - -// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't -// applicable. -func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) { - // Use wrapper to hide existing r.ReadFrom from io.Copy. - return io.Copy(writerOnly{w}, r) -} diff --git a/src/pkg/net/sock_bsd.go b/src/pkg/net/sock_bsd.go index 2607b04c7..3205f9404 100644 --- a/src/pkg/net/sock_bsd.go +++ b/src/pkg/net/sock_bsd.go @@ -4,8 +4,6 @@ // +build darwin freebsd netbsd openbsd -// Sockets for BSD variants - package net import ( @@ -31,32 +29,3 @@ func maxListenerBacklog() int { } return int(n) } - -func listenerSockaddr(s, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) { - a := toAddr(la) - if a == nil { - return la, nil - } - switch v := a.(type) { - case *TCPAddr, *UnixAddr: - err := setDefaultListenerSockopts(s) - if err != nil { - return nil, err - } - case *UDPAddr: - if v.IP.IsMulticast() { - err := setDefaultMulticastSockopts(s) - if err != nil { - return nil, err - } - switch f { - case syscall.AF_INET: - v.IP = IPv4zero - case syscall.AF_INET6: - v.IP = IPv6unspecified - } - return v.sockaddr(f) - } - } - return la, nil -} diff --git a/src/pkg/net/sock_cloexec.go b/src/pkg/net/sock_cloexec.go new file mode 100644 index 000000000..12d0f3488 --- /dev/null +++ b/src/pkg/net/sock_cloexec.go @@ -0,0 +1,69 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements sysSocket and accept for platforms that +// provide a fast path for setting SetNonblock and CloseOnExec. + +// +build linux + +package net + +import "syscall" + +// Wrapper around the socket system call that marks the returned file +// descriptor as nonblocking and close-on-exec. +func sysSocket(f, t, p int) (int, error) { + s, err := syscall.Socket(f, t|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, p) + // The SOCK_NONBLOCK and SOCK_CLOEXEC flags were introduced in + // Linux 2.6.27. If we get an EINVAL error, fall back to + // using socket without them. + if err == nil || err != syscall.EINVAL { + return s, err + } + + // See ../syscall/exec_unix.go for description of ForkLock. + syscall.ForkLock.RLock() + s, err = syscall.Socket(f, t, p) + if err == nil { + syscall.CloseOnExec(s) + } + syscall.ForkLock.RUnlock() + if err != nil { + return -1, err + } + if err = syscall.SetNonblock(s, true); err != nil { + syscall.Close(s) + return -1, err + } + return s, nil +} + +// Wrapper around the accept system call that marks the returned file +// descriptor as nonblocking and close-on-exec. +func accept(fd int) (int, syscall.Sockaddr, error) { + nfd, sa, err := syscall.Accept4(fd, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC) + // The accept4 system call was introduced in Linux 2.6.28. If + // we get an ENOSYS error, fall back to using accept. + if err == nil || err != syscall.ENOSYS { + return nfd, sa, err + } + + // See ../syscall/exec_unix.go for description of ForkLock. + // It is probably okay to hold the lock across syscall.Accept + // because we have put fd.sysfd into non-blocking mode. + // However, a call to the File method will put it back into + // blocking mode. We can't take that risk, so no use of ForkLock here. + nfd, sa, err = syscall.Accept(fd) + if err == nil { + syscall.CloseOnExec(nfd) + } + if err != nil { + return -1, nil, err + } + if err = syscall.SetNonblock(nfd, true); err != nil { + syscall.Close(nfd) + return -1, nil, err + } + return nfd, sa, nil +} diff --git a/src/pkg/net/sock_linux.go b/src/pkg/net/sock_linux.go index e509d9397..8bbd74ddc 100644 --- a/src/pkg/net/sock_linux.go +++ b/src/pkg/net/sock_linux.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Sockets for Linux - package net import "syscall" @@ -25,32 +23,3 @@ func maxListenerBacklog() int { } return n } - -func listenerSockaddr(s, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) { - a := toAddr(la) - if a == nil { - return la, nil - } - switch v := a.(type) { - case *TCPAddr, *UnixAddr: - err := setDefaultListenerSockopts(s) - if err != nil { - return nil, err - } - case *UDPAddr: - if v.IP.IsMulticast() { - err := setDefaultMulticastSockopts(s) - if err != nil { - return nil, err - } - switch f { - case syscall.AF_INET: - v.IP = IPv4zero - case syscall.AF_INET6: - v.IP = IPv6unspecified - } - return v.sockaddr(f) - } - } - return la, nil -} diff --git a/src/pkg/net/sock_posix.go b/src/pkg/net/sock_posix.go new file mode 100644 index 000000000..b50a892b1 --- /dev/null +++ b/src/pkg/net/sock_posix.go @@ -0,0 +1,67 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd windows + +package net + +import ( + "syscall" + "time" +) + +var listenerBacklog = maxListenerBacklog() + +// Generic POSIX socket creation. +func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, deadline time.Time, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { + s, err := sysSocket(f, t, p) + if err != nil { + return nil, err + } + + if err = setDefaultSockopts(s, f, t, ipv6only); err != nil { + closesocket(s) + return nil, err + } + + if ulsa != nil { + // We provide a socket that listens to a wildcard + // address with reusable UDP port when the given ulsa + // is an appropriate UDP multicast address prefix. + // This makes it possible for a single UDP listener + // to join multiple different group addresses, for + // multiple UDP listeners that listen on the same UDP + // port to join the same group address. + if ulsa, err = listenerSockaddr(s, f, ulsa, toAddr); err != nil { + closesocket(s) + return nil, err + } + if err = syscall.Bind(s, ulsa); err != nil { + closesocket(s) + return nil, err + } + } + + if fd, err = newFD(s, f, t, net); err != nil { + closesocket(s) + return nil, err + } + + if ursa != nil { + fd.wdeadline.setTime(deadline) + if err = fd.connect(ursa); err != nil { + closesocket(s) + return nil, err + } + fd.isConnected = true + fd.wdeadline.set(0) + } + + lsa, _ := syscall.Getsockname(s) + laddr := toAddr(lsa) + rsa, _ := syscall.Getpeername(s) + raddr := toAddr(rsa) + fd.setAddr(laddr, raddr) + return fd, nil +} diff --git a/src/pkg/net/sock_unix.go b/src/pkg/net/sock_unix.go new file mode 100644 index 000000000..b0d6d4900 --- /dev/null +++ b/src/pkg/net/sock_unix.go @@ -0,0 +1,36 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd + +package net + +import "syscall" + +func listenerSockaddr(s, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) { + a := toAddr(la) + if a == nil { + return la, nil + } + switch a := a.(type) { + case *TCPAddr, *UnixAddr: + if err := setDefaultListenerSockopts(s); err != nil { + return nil, err + } + case *UDPAddr: + if a.IP.IsMulticast() { + if err := setDefaultMulticastSockopts(s); err != nil { + return nil, err + } + switch f { + case syscall.AF_INET: + a.IP = IPv4zero + case syscall.AF_INET6: + a.IP = IPv6unspecified + } + return a.sockaddr(f) + } + } + return la, nil +} diff --git a/src/pkg/net/sock_windows.go b/src/pkg/net/sock_windows.go index cce6181c9..a77c48437 100644 --- a/src/pkg/net/sock_windows.go +++ b/src/pkg/net/sock_windows.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Sockets for Windows - package net import "syscall" @@ -18,26 +16,35 @@ func listenerSockaddr(s syscall.Handle, f int, la syscall.Sockaddr, toAddr func( if a == nil { return la, nil } - switch v := a.(type) { + switch a := a.(type) { case *TCPAddr, *UnixAddr: - err := setDefaultListenerSockopts(s) - if err != nil { + if err := setDefaultListenerSockopts(s); err != nil { return nil, err } case *UDPAddr: - if v.IP.IsMulticast() { - err := setDefaultMulticastSockopts(s) - if err != nil { + if a.IP.IsMulticast() { + if err := setDefaultMulticastSockopts(s); err != nil { return nil, err } switch f { case syscall.AF_INET: - v.IP = IPv4zero + a.IP = IPv4zero case syscall.AF_INET6: - v.IP = IPv6unspecified + a.IP = IPv6unspecified } - return v.sockaddr(f) + return a.sockaddr(f) } } return la, nil } + +func sysSocket(f, t, p int) (syscall.Handle, error) { + // See ../syscall/exec_unix.go for description of ForkLock. + syscall.ForkLock.RLock() + s, err := syscall.Socket(f, t, p) + if err == nil { + syscall.CloseOnExec(s) + } + syscall.ForkLock.RUnlock() + return s, err +} diff --git a/src/pkg/net/sockopt.go b/src/pkg/net/sockopt_posix.go index 0cd19266f..fe371fe0c 100644 --- a/src/pkg/net/sockopt.go +++ b/src/pkg/net/sockopt_posix.go @@ -119,45 +119,22 @@ func setWriteBuffer(fd *netFD, bytes int) error { return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)) } +// TODO(dfc) these unused error returns could be removed + func setReadDeadline(fd *netFD, t time.Time) error { - if t.IsZero() { - fd.rdeadline = 0 - } else { - fd.rdeadline = t.UnixNano() - } + fd.rdeadline.setTime(t) return nil } func setWriteDeadline(fd *netFD, t time.Time) error { - if t.IsZero() { - fd.wdeadline = 0 - } else { - fd.wdeadline = t.UnixNano() - } + fd.wdeadline.setTime(t) return nil } func setDeadline(fd *netFD, t time.Time) error { - if err := setReadDeadline(fd, t); err != nil { - return err - } - return setWriteDeadline(fd, t) -} - -func setReuseAddr(fd *netFD, reuse bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse))) -} - -func setDontRoute(fd *netFD, dontroute bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute))) + setReadDeadline(fd, t) + setWriteDeadline(fd, t) + return nil } func setKeepAlive(fd *netFD, keepalive bool) error { diff --git a/src/pkg/net/sockoptip.go b/src/pkg/net/sockoptip.go deleted file mode 100644 index 1fcad4018..000000000 --- a/src/pkg/net/sockoptip.go +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build darwin freebsd linux netbsd openbsd windows - -// IP-level socket options - -package net - -import ( - "os" - "syscall" -) - -func ipv4TOS(fd *netFD) (int, error) { - if err := fd.incref(false); err != nil { - return 0, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TOS) - if err != nil { - return 0, os.NewSyscallError("getsockopt", err) - } - return v, nil -} - -func setIPv4TOS(fd *netFD, v int) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TOS, v) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv4TTL(fd *netFD) (int, error) { - if err := fd.incref(false); err != nil { - return 0, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TTL) - if err != nil { - return 0, os.NewSyscallError("getsockopt", err) - } - return v, nil -} - -func setIPv4TTL(fd *netFD, v int) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TTL, v) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error { - mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}} - if err := setIPv4MreqToInterface(mreq, ifi); err != nil { - return err - } - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)) -} - -func leaveIPv4Group(fd *netFD, ifi *Interface, ip IP) error { - mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}} - if err := setIPv4MreqToInterface(mreq, ifi); err != nil { - return err - } - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_DROP_MEMBERSHIP, mreq)) -} - -func ipv6HopLimit(fd *netFD) (int, error) { - if err := fd.incref(false); err != nil { - return 0, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS) - if err != nil { - return 0, os.NewSyscallError("getsockopt", err) - } - return v, nil -} - -func setIPv6HopLimit(fd *netFD, v int) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, v) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv6MulticastInterface(fd *netFD) (*Interface, error) { - if err := fd.incref(false); err != nil { - return nil, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF) - if err != nil { - return nil, os.NewSyscallError("getsockopt", err) - } - if v == 0 { - return nil, nil - } - ifi, err := InterfaceByIndex(v) - if err != nil { - return nil, err - } - return ifi, nil -} - -func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error { - var v int - if ifi != nil { - v = ifi.Index - } - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv6MulticastHopLimit(fd *netFD) (int, error) { - if err := fd.incref(false); err != nil { - return 0, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS) - if err != nil { - return 0, os.NewSyscallError("getsockopt", err) - } - return v, nil -} - -func setIPv6MulticastHopLimit(fd *netFD, v int) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS, v) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv6MulticastLoopback(fd *netFD) (bool, error) { - if err := fd.incref(false); err != nil { - return false, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP) - if err != nil { - return false, os.NewSyscallError("getsockopt", err) - } - return v == 1, nil -} - -func setIPv6MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v)) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error { - mreq := &syscall.IPv6Mreq{} - copy(mreq.Multiaddr[:], ip) - if ifi != nil { - mreq.Interface = uint32(ifi.Index) - } - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq)) -} - -func leaveIPv6Group(fd *netFD, ifi *Interface, ip IP) error { - mreq := &syscall.IPv6Mreq{} - copy(mreq.Multiaddr[:], ip) - if ifi != nil { - mreq.Interface = uint32(ifi.Index) - } - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_LEAVE_GROUP, mreq)) -} diff --git a/src/pkg/net/sockoptip_bsd.go b/src/pkg/net/sockoptip_bsd.go index 19e2b142e..263f85521 100644 --- a/src/pkg/net/sockoptip_bsd.go +++ b/src/pkg/net/sockoptip_bsd.go @@ -4,8 +4,6 @@ // +build darwin freebsd netbsd openbsd -// IP-level socket options for BSD variants - package net import ( @@ -13,48 +11,30 @@ import ( "syscall" ) -func ipv4MulticastTTL(fd *netFD) (int, error) { - if err := fd.incref(false); err != nil { - return 0, err - } - defer fd.decref() - v, err := syscall.GetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL) +func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { + ip, err := interfaceToIPv4Addr(ifi) if err != nil { - return 0, os.NewSyscallError("getsockopt", err) + return os.NewSyscallError("setsockopt", err) } - return int(v), nil -} - -func setIPv4MulticastTTL(fd *netFD, v int) error { + var a [4]byte + copy(a[:], ip.To4()) if err := fd.incref(false); err != nil { return err } defer fd.decref() - err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, byte(v)) + err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, a) if err != nil { return os.NewSyscallError("setsockopt", err) } return nil } -func ipv6TrafficClass(fd *netFD) (int, error) { - if err := fd.incref(false); err != nil { - return 0, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS) - if err != nil { - return 0, os.NewSyscallError("getsockopt", err) - } - return v, nil -} - -func setIPv6TrafficClass(fd *netFD, v int) error { +func setIPv4MulticastLoopback(fd *netFD, v bool) error { if err := fd.incref(false); err != nil { return err } defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, v) + err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v))) if err != nil { return os.NewSyscallError("setsockopt", err) } diff --git a/src/pkg/net/sockoptip_darwin.go b/src/pkg/net/sockoptip_darwin.go deleted file mode 100644 index 52b237c4b..000000000 --- a/src/pkg/net/sockoptip_darwin.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// IP-level socket options for Darwin - -package net - -import ( - "os" - "syscall" -) - -func ipv4MulticastInterface(fd *netFD) (*Interface, error) { - if err := fd.incref(false); err != nil { - return nil, err - } - defer fd.decref() - a, err := syscall.GetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF) - if err != nil { - return nil, os.NewSyscallError("getsockopt", err) - } - return ipv4AddrToInterface(IPv4(a[0], a[1], a[2], a[3])) -} - -func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { - ip, err := interfaceToIPv4Addr(ifi) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - var x [4]byte - copy(x[:], ip.To4()) - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv4MulticastLoopback(fd *netFD) (bool, error) { - if err := fd.incref(false); err != nil { - return false, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP) - if err != nil { - return false, os.NewSyscallError("getsockopt", err) - } - return v == 1, nil -} - -func setIPv4MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v)) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv4ReceiveInterface(fd *netFD) (bool, error) { - if err := fd.incref(false); err != nil { - return false, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF) - if err != nil { - return false, os.NewSyscallError("getsockopt", err) - } - return v == 1, nil -} - -func setIPv4ReceiveInterface(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v)) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} diff --git a/src/pkg/net/sockoptip_freebsd.go b/src/pkg/net/sockoptip_freebsd.go deleted file mode 100644 index 4a3bc2e82..000000000 --- a/src/pkg/net/sockoptip_freebsd.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// IP-level socket options for FreeBSD - -package net - -import ( - "os" - "syscall" -) - -func ipv4MulticastInterface(fd *netFD) (*Interface, error) { - if err := fd.incref(false); err != nil { - return nil, err - } - defer fd.decref() - mreq, err := syscall.GetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF) - if err != nil { - return nil, os.NewSyscallError("getsockopt", err) - } - if int(mreq.Ifindex) == 0 { - return nil, nil - } - return InterfaceByIndex(int(mreq.Ifindex)) -} - -func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { - var v int32 - if ifi != nil { - v = int32(ifi.Index) - } - mreq := &syscall.IPMreqn{Ifindex: v} - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv4MulticastLoopback(fd *netFD) (bool, error) { - if err := fd.incref(false); err != nil { - return false, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP) - if err != nil { - return false, os.NewSyscallError("getsockopt", err) - } - return v == 1, nil -} - -func setIPv4MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v)) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv4ReceiveInterface(fd *netFD) (bool, error) { - if err := fd.incref(false); err != nil { - return false, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF) - if err != nil { - return false, os.NewSyscallError("getsockopt", err) - } - return v == 1, nil -} - -func setIPv4ReceiveInterface(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v)) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} diff --git a/src/pkg/net/sockoptip_linux.go b/src/pkg/net/sockoptip_linux.go index 169718f14..225fb0c4c 100644 --- a/src/pkg/net/sockoptip_linux.go +++ b/src/pkg/net/sockoptip_linux.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// IP-level socket options for Linux - package net import ( @@ -11,21 +9,6 @@ import ( "syscall" ) -func ipv4MulticastInterface(fd *netFD) (*Interface, error) { - if err := fd.incref(false); err != nil { - return nil, err - } - defer fd.decref() - mreq, err := syscall.GetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF) - if err != nil { - return nil, os.NewSyscallError("getsockopt", err) - } - if int(mreq.Ifindex) == 0 { - return nil, nil - } - return InterfaceByIndex(int(mreq.Ifindex)) -} - func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { var v int32 if ifi != nil { @@ -43,42 +26,6 @@ func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { return nil } -func ipv4MulticastTTL(fd *netFD) (int, error) { - if err := fd.incref(false); err != nil { - return 0, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL) - if err != nil { - return -1, os.NewSyscallError("getsockopt", err) - } - return v, nil -} - -func setIPv4MulticastTTL(fd *netFD, v int) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, v) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv4MulticastLoopback(fd *netFD) (bool, error) { - if err := fd.incref(false); err != nil { - return false, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP) - if err != nil { - return false, os.NewSyscallError("getsockopt", err) - } - return v == 1, nil -} - func setIPv4MulticastLoopback(fd *netFD, v bool) error { if err := fd.incref(false); err != nil { return err @@ -90,51 +37,3 @@ func setIPv4MulticastLoopback(fd *netFD, v bool) error { } return nil } - -func ipv4ReceiveInterface(fd *netFD) (bool, error) { - if err := fd.incref(false); err != nil { - return false, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_PKTINFO) - if err != nil { - return false, os.NewSyscallError("getsockopt", err) - } - return v == 1, nil -} - -func setIPv4ReceiveInterface(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_PKTINFO, boolint(v)) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv6TrafficClass(fd *netFD) (int, error) { - if err := fd.incref(false); err != nil { - return 0, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS) - if err != nil { - return 0, os.NewSyscallError("getsockopt", err) - } - return v, nil -} - -func setIPv6TrafficClass(fd *netFD, v int) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, v) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} diff --git a/src/pkg/net/sockoptip_netbsd.go b/src/pkg/net/sockoptip_netbsd.go deleted file mode 100644 index 446d92aa3..000000000 --- a/src/pkg/net/sockoptip_netbsd.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// IP-level socket options for NetBSD - -package net - -import "syscall" - -func ipv4MulticastInterface(fd *netFD) (*Interface, error) { - // TODO: Implement this - return nil, syscall.EAFNOSUPPORT -} - -func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { - // TODO: Implement this - return syscall.EAFNOSUPPORT -} - -func ipv4MulticastLoopback(fd *netFD) (bool, error) { - // TODO: Implement this - return false, syscall.EAFNOSUPPORT -} - -func setIPv4MulticastLoopback(fd *netFD, v bool) error { - // TODO: Implement this - return syscall.EAFNOSUPPORT -} - -func ipv4ReceiveInterface(fd *netFD) (bool, error) { - // TODO: Implement this - return false, syscall.EAFNOSUPPORT -} - -func setIPv4ReceiveInterface(fd *netFD, v bool) error { - // TODO: Implement this - return syscall.EAFNOSUPPORT -} diff --git a/src/pkg/net/sockoptip_openbsd.go b/src/pkg/net/sockoptip_openbsd.go deleted file mode 100644 index f3e42f1a9..000000000 --- a/src/pkg/net/sockoptip_openbsd.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// IP-level socket options for OpenBSD - -package net - -import ( - "os" - "syscall" -) - -func ipv4MulticastInterface(fd *netFD) (*Interface, error) { - if err := fd.incref(false); err != nil { - return nil, err - } - defer fd.decref() - a, err := syscall.GetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF) - if err != nil { - return nil, os.NewSyscallError("getsockopt", err) - } - return ipv4AddrToInterface(IPv4(a[0], a[1], a[2], a[3])) -} - -func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { - ip, err := interfaceToIPv4Addr(ifi) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - var x [4]byte - copy(x[:], ip.To4()) - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv4MulticastLoopback(fd *netFD) (bool, error) { - if err := fd.incref(false); err != nil { - return false, err - } - defer fd.decref() - v, err := syscall.GetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP) - if err != nil { - return false, os.NewSyscallError("getsockopt", err) - } - return v == 1, nil -} - -func setIPv4MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v))) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -func ipv4ReceiveInterface(fd *netFD) (bool, error) { - if err := fd.incref(false); err != nil { - return false, err - } - defer fd.decref() - v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF) - if err != nil { - return false, os.NewSyscallError("getsockopt", err) - } - return v == 1, nil -} - -func setIPv4ReceiveInterface(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v)) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} diff --git a/src/pkg/net/sockoptip_posix.go b/src/pkg/net/sockoptip_posix.go new file mode 100644 index 000000000..e4c56a0e4 --- /dev/null +++ b/src/pkg/net/sockoptip_posix.go @@ -0,0 +1,73 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd windows + +package net + +import ( + "os" + "syscall" +) + +func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error { + mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}} + if err := setIPv4MreqToInterface(mreq, ifi); err != nil { + return err + } + if err := fd.incref(false); err != nil { + return err + } + defer fd.decref() + err := syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error { + var v int + if ifi != nil { + v = ifi.Index + } + if err := fd.incref(false); err != nil { + return err + } + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func setIPv6MulticastLoopback(fd *netFD, v bool) error { + if err := fd.incref(false); err != nil { + return err + } + defer fd.decref() + err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v)) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} + +func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error { + mreq := &syscall.IPv6Mreq{} + copy(mreq.Multiaddr[:], ip) + if ifi != nil { + mreq.Interface = uint32(ifi.Index) + } + if err := fd.incref(false); err != nil { + return err + } + defer fd.decref() + err := syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq) + if err != nil { + return os.NewSyscallError("setsockopt", err) + } + return nil +} diff --git a/src/pkg/net/sockoptip_windows.go b/src/pkg/net/sockoptip_windows.go index b9db3334d..3e248441a 100644 --- a/src/pkg/net/sockoptip_windows.go +++ b/src/pkg/net/sockoptip_windows.go @@ -2,90 +2,41 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// IP-level socket options for Windows - package net import ( "os" "syscall" + "unsafe" ) -func ipv4MulticastInterface(fd *netFD) (*Interface, error) { - // TODO: Implement this - return nil, syscall.EWINDOWS -} - func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { ip, err := interfaceToIPv4Addr(ifi) if err != nil { return os.NewSyscallError("setsockopt", err) } - var x [4]byte - copy(x[:], ip.To4()) + var a [4]byte + copy(a[:], ip.To4()) if err := fd.incref(false); err != nil { return err } defer fd.decref() - err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x) + err = syscall.Setsockopt(fd.sysfd, int32(syscall.IPPROTO_IP), int32(syscall.IP_MULTICAST_IF), (*byte)(unsafe.Pointer(&a[0])), 4) if err != nil { return os.NewSyscallError("setsockopt", err) } return nil } -func ipv4MulticastTTL(fd *netFD) (int, error) { - // TODO: Implement this - return -1, syscall.EWINDOWS -} - -func setIPv4MulticastTTL(fd *netFD, v int) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, v) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil - -} - -func ipv4MulticastLoopback(fd *netFD) (bool, error) { - // TODO: Implement this - return false, syscall.EWINDOWS -} - func setIPv4MulticastLoopback(fd *netFD, v bool) error { if err := fd.incref(false); err != nil { return err } defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v)) + vv := int32(boolint(v)) + err := syscall.Setsockopt(fd.sysfd, int32(syscall.IPPROTO_IP), int32(syscall.IP_MULTICAST_LOOP), (*byte)(unsafe.Pointer(&vv)), 4) if err != nil { return os.NewSyscallError("setsockopt", err) } return nil - -} - -func ipv4ReceiveInterface(fd *netFD) (bool, error) { - // TODO: Implement this - return false, syscall.EWINDOWS -} - -func setIPv4ReceiveInterface(fd *netFD, v bool) error { - // TODO: Implement this - return syscall.EWINDOWS -} - -func ipv6TrafficClass(fd *netFD) (int, error) { - // TODO: Implement this - return 0, syscall.EWINDOWS -} - -func setIPv6TrafficClass(fd *netFD, v int) error { - // TODO: Implement this - return syscall.EWINDOWS } diff --git a/src/pkg/net/sys_cloexec.go b/src/pkg/net/sys_cloexec.go new file mode 100644 index 000000000..17e874908 --- /dev/null +++ b/src/pkg/net/sys_cloexec.go @@ -0,0 +1,54 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements sysSocket and accept for platforms that do not +// provide a fast path for setting SetNonblock and CloseOnExec. + +// +build darwin freebsd netbsd openbsd + +package net + +import "syscall" + +// Wrapper around the socket system call that marks the returned file +// descriptor as nonblocking and close-on-exec. +func sysSocket(f, t, p int) (int, error) { + // See ../syscall/exec_unix.go for description of ForkLock. + syscall.ForkLock.RLock() + s, err := syscall.Socket(f, t, p) + if err == nil { + syscall.CloseOnExec(s) + } + syscall.ForkLock.RUnlock() + if err != nil { + return -1, err + } + if err = syscall.SetNonblock(s, true); err != nil { + syscall.Close(s) + return -1, err + } + return s, nil +} + +// Wrapper around the accept system call that marks the returned file +// descriptor as nonblocking and close-on-exec. +func accept(fd int) (int, syscall.Sockaddr, error) { + // See ../syscall/exec_unix.go for description of ForkLock. + // It is probably okay to hold the lock across syscall.Accept + // because we have put fd.sysfd into non-blocking mode. + // However, a call to the File method will put it back into + // blocking mode. We can't take that risk, so no use of ForkLock here. + nfd, sa, err := syscall.Accept(fd) + if err == nil { + syscall.CloseOnExec(nfd) + } + if err != nil { + return -1, nil, err + } + if err = syscall.SetNonblock(nfd, true); err != nil { + syscall.Close(nfd) + return -1, nil, err + } + return nfd, sa, nil +} diff --git a/src/pkg/net/tcp_test.go b/src/pkg/net/tcp_test.go new file mode 100644 index 000000000..6c4485a94 --- /dev/null +++ b/src/pkg/net/tcp_test.go @@ -0,0 +1,206 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "reflect" + "runtime" + "testing" + "time" +) + +func BenchmarkTCP4OneShot(b *testing.B) { + benchmarkTCP(b, false, false, "127.0.0.1:0") +} + +func BenchmarkTCP4OneShotTimeout(b *testing.B) { + benchmarkTCP(b, false, true, "127.0.0.1:0") +} + +func BenchmarkTCP4Persistent(b *testing.B) { + benchmarkTCP(b, true, false, "127.0.0.1:0") +} + +func BenchmarkTCP4PersistentTimeout(b *testing.B) { + benchmarkTCP(b, true, true, "127.0.0.1:0") +} + +func BenchmarkTCP6OneShot(b *testing.B) { + if !supportsIPv6 { + b.Skip("ipv6 is not supported") + } + benchmarkTCP(b, false, false, "[::1]:0") +} + +func BenchmarkTCP6OneShotTimeout(b *testing.B) { + if !supportsIPv6 { + b.Skip("ipv6 is not supported") + } + benchmarkTCP(b, false, true, "[::1]:0") +} + +func BenchmarkTCP6Persistent(b *testing.B) { + if !supportsIPv6 { + b.Skip("ipv6 is not supported") + } + benchmarkTCP(b, true, false, "[::1]:0") +} + +func BenchmarkTCP6PersistentTimeout(b *testing.B) { + if !supportsIPv6 { + b.Skip("ipv6 is not supported") + } + benchmarkTCP(b, true, true, "[::1]:0") +} + +func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) { + const msgLen = 512 + conns := b.N + numConcurrent := runtime.GOMAXPROCS(-1) * 16 + msgs := 1 + if persistent { + conns = numConcurrent + msgs = b.N / conns + if msgs == 0 { + msgs = 1 + } + if conns > b.N { + conns = b.N + } + } + sendMsg := func(c Conn, buf []byte) bool { + n, err := c.Write(buf) + if n != len(buf) || err != nil { + b.Logf("Write failed: %v", err) + return false + } + return true + } + recvMsg := func(c Conn, buf []byte) bool { + for read := 0; read != len(buf); { + n, err := c.Read(buf) + read += n + if err != nil { + b.Logf("Read failed: %v", err) + return false + } + } + return true + } + ln, err := Listen("tcp", laddr) + if err != nil { + b.Fatalf("Listen failed: %v", err) + } + defer ln.Close() + // Acceptor. + go func() { + for { + c, err := ln.Accept() + if err != nil { + break + } + // Server connection. + go func(c Conn) { + defer c.Close() + if timeout { + c.SetDeadline(time.Now().Add(time.Hour)) // Not intended to fire. + } + var buf [msgLen]byte + for m := 0; m < msgs; m++ { + if !recvMsg(c, buf[:]) || !sendMsg(c, buf[:]) { + break + } + } + }(c) + } + }() + sem := make(chan bool, numConcurrent) + for i := 0; i < conns; i++ { + sem <- true + // Client connection. + go func() { + defer func() { + <-sem + }() + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + b.Logf("Dial failed: %v", err) + return + } + defer c.Close() + if timeout { + c.SetDeadline(time.Now().Add(time.Hour)) // Not intended to fire. + } + var buf [msgLen]byte + for m := 0; m < msgs; m++ { + if !sendMsg(c, buf[:]) || !recvMsg(c, buf[:]) { + break + } + } + }() + } + for i := 0; i < cap(sem); i++ { + sem <- true + } +} + +var resolveTCPAddrTests = []struct { + net string + litAddr string + addr *TCPAddr + err error +}{ + {"tcp", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, + {"tcp4", "127.0.0.1:65535", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil}, + + {"tcp", "[::1]:1", &TCPAddr{IP: ParseIP("::1"), Port: 1}, nil}, + {"tcp6", "[::1]:65534", &TCPAddr{IP: ParseIP("::1"), Port: 65534}, nil}, + + {"", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior + {"", "[::1]:0", &TCPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior + + {"http", "127.0.0.1:0", nil, UnknownNetworkError("http")}, +} + +func TestResolveTCPAddr(t *testing.T) { + for _, tt := range resolveTCPAddrTests { + addr, err := ResolveTCPAddr(tt.net, tt.litAddr) + if err != tt.err { + t.Fatalf("ResolveTCPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err) + } + if !reflect.DeepEqual(addr, tt.addr) { + t.Fatalf("got %#v; expected %#v", addr, tt.addr) + } + } +} + +var tcpListenerNameTests = []struct { + net string + laddr *TCPAddr +}{ + {"tcp4", &TCPAddr{IP: IPv4(127, 0, 0, 1)}}, + {"tcp4", &TCPAddr{}}, + {"tcp4", nil}, +} + +func TestTCPListenerName(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + + for _, tt := range tcpListenerNameTests { + ln, err := ListenTCP(tt.net, tt.laddr) + if err != nil { + t.Errorf("ListenTCP failed: %v", err) + return + } + defer ln.Close() + la := ln.Addr() + if a, ok := la.(*TCPAddr); !ok || a.Port == 0 { + t.Errorf("got %v; expected a proper address with non-zero port number", la) + return + } + } +} diff --git a/src/pkg/net/tcpsock.go b/src/pkg/net/tcpsock.go index 47fbf2919..d5158b22d 100644 --- a/src/pkg/net/tcpsock.go +++ b/src/pkg/net/tcpsock.go @@ -10,6 +10,7 @@ package net type TCPAddr struct { IP IP Port int + Zone string // IPv6 scoped addressing zone } // Network returns the address's network name, "tcp". @@ -28,9 +29,16 @@ func (a *TCPAddr) String() string { // "tcp4" or "tcp6". A literal IPv6 host address must be // enclosed in square brackets, as in "[::]:80". func ResolveTCPAddr(net, addr string) (*TCPAddr, error) { - ip, port, err := hostPortToIP(net, addr) + switch net { + case "tcp", "tcp4", "tcp6": + case "": // a hint wildcard for Go 1.0 undocumented behavior + net = "tcp" + default: + return nil, UnknownNetworkError(net) + } + a, err := resolveInternetAddr(net, addr, noDeadline) if err != nil { return nil, err } - return &TCPAddr{ip, port}, nil + return a.(*TCPAddr), nil } diff --git a/src/pkg/net/tcpsock_plan9.go b/src/pkg/net/tcpsock_plan9.go index 35f56966e..ed3664603 100644 --- a/src/pkg/net/tcpsock_plan9.go +++ b/src/pkg/net/tcpsock_plan9.go @@ -2,34 +2,30 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TCP for Plan 9 +// TCP sockets for Plan 9 package net import ( + "io" + "os" "syscall" "time" ) -// TCPConn is an implementation of the Conn interface -// for TCP network connections. +// TCPConn is an implementation of the Conn interface for TCP network +// connections. type TCPConn struct { - plan9Conn + conn } -// SetDeadline implements the Conn SetDeadline method. -func (c *TCPConn) SetDeadline(t time.Time) error { - return syscall.EPLAN9 +func newTCPConn(fd *netFD) *TCPConn { + return &TCPConn{conn{fd}} } -// SetReadDeadline implements the Conn SetReadDeadline method. -func (c *TCPConn) SetReadDeadline(t time.Time) error { - return syscall.EPLAN9 -} - -// SetWriteDeadline implements the Conn SetWriteDeadline method. -func (c *TCPConn) SetWriteDeadline(t time.Time) error { - return syscall.EPLAN9 +// ReadFrom implements the io.ReaderFrom ReadFrom method. +func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) { + return genericReadFrom(c, r) } // CloseRead shuts down the reading side of the TCP connection. @@ -38,7 +34,7 @@ func (c *TCPConn) CloseRead() error { if !c.ok() { return syscall.EINVAL } - return syscall.EPLAN9 + return c.fd.CloseRead() } // CloseWrite shuts down the writing side of the TCP connection. @@ -47,51 +43,142 @@ func (c *TCPConn) CloseWrite() error { if !c.ok() { return syscall.EINVAL } + return c.fd.CloseWrite() +} + +// SetLinger sets the behavior of Close() on a connection which still +// has data waiting to be sent or to be acknowledged. +// +// If sec < 0 (the default), Close returns immediately and the +// operating system finishes sending the data in the background. +// +// If sec == 0, Close returns immediately and the operating system +// discards any unsent or unacknowledged data. +// +// If sec > 0, Close blocks for at most sec seconds waiting for data +// to be sent and acknowledged. +func (c *TCPConn) SetLinger(sec int) error { + return syscall.EPLAN9 +} + +// SetKeepAlive sets whether the operating system should send +// keepalive messages on the connection. +func (c *TCPConn) SetKeepAlive(keepalive bool) error { + return syscall.EPLAN9 +} + +// SetNoDelay controls whether the operating system should delay +// packet transmission in hopes of sending fewer packets (Nagle's +// algorithm). The default is true (no delay), meaning that data is +// sent as soon as possible after a Write. +func (c *TCPConn) SetNoDelay(noDelay bool) error { return syscall.EPLAN9 } // DialTCP connects to the remote address raddr on the network net, -// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used -// as the local address for the connection. -func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err error) { +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is +// used as the local address for the connection. +func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) { + return dialTCP(net, laddr, raddr, noDeadline) +} + +func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) { + if !deadline.IsZero() { + panic("net.dialTCP: deadline not implemented on Plan 9") + } switch net { case "tcp", "tcp4", "tcp6": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{"dial", net, raddr, UnknownNetworkError(net)} } if raddr == nil { return nil, &OpError{"dial", net, nil, errMissingAddress} } - c1, err := dialPlan9(net, laddr, raddr) + fd, err := dialPlan9(net, laddr, raddr) if err != nil { - return + return nil, err } - return &TCPConn{*c1}, nil + return newTCPConn(fd), nil } -// TCPListener is a TCP network listener. -// Clients should typically use variables of type Listener -// instead of assuming TCP. +// TCPListener is a TCP network listener. Clients should typically +// use variables of type Listener instead of assuming TCP. type TCPListener struct { - plan9Listener + fd *netFD +} + +// AcceptTCP accepts the next incoming call and returns the new +// connection and the remote address. +func (l *TCPListener) AcceptTCP() (*TCPConn, error) { + if l == nil || l.fd == nil || l.fd.ctl == nil { + return nil, syscall.EINVAL + } + fd, err := l.fd.acceptPlan9() + if err != nil { + return nil, err + } + return newTCPConn(fd), nil +} + +// Accept implements the Accept method in the Listener interface; it +// waits for the next call and returns a generic Conn. +func (l *TCPListener) Accept() (Conn, error) { + if l == nil || l.fd == nil || l.fd.ctl == nil { + return nil, syscall.EINVAL + } + c, err := l.AcceptTCP() + if err != nil { + return nil, err + } + return c, nil } -// ListenTCP announces on the TCP address laddr and returns a TCP listener. -// Net must be "tcp", "tcp4", or "tcp6". -// If laddr has a port of 0, it means to listen on some available port. -// The caller can use l.Addr() to retrieve the chosen address. -func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err error) { +// Close stops listening on the TCP address. +// Already Accepted connections are not closed. +func (l *TCPListener) Close() error { + if l == nil || l.fd == nil || l.fd.ctl == nil { + return syscall.EINVAL + } + if _, err := l.fd.ctl.WriteString("hangup"); err != nil { + l.fd.ctl.Close() + return &OpError{"close", l.fd.ctl.Name(), l.fd.laddr, err} + } + return l.fd.ctl.Close() +} + +// Addr returns the listener's network address, a *TCPAddr. +func (l *TCPListener) Addr() Addr { return l.fd.laddr } + +// SetDeadline sets the deadline associated with the listener. +// A zero time value disables the deadline. +func (l *TCPListener) SetDeadline(t time.Time) error { + if l == nil || l.fd == nil || l.fd.ctl == nil { + return syscall.EINVAL + } + return setDeadline(l.fd, t) +} + +// File returns a copy of the underlying os.File, set to blocking +// mode. It is the caller's responsibility to close f when finished. +// Closing l does not affect f, and closing f does not affect l. +func (l *TCPListener) File() (f *os.File, err error) { return l.dup() } + +// ListenTCP announces on the TCP address laddr and returns a TCP +// listener. Net must be "tcp", "tcp4", or "tcp6". If laddr has a +// port of 0, it means to listen on some available port. The caller +// can use l.Addr() to retrieve the chosen address. +func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { switch net { case "tcp", "tcp4", "tcp6": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{"listen", net, laddr, UnknownNetworkError(net)} } if laddr == nil { - return nil, &OpError{"listen", net, nil, errMissingAddress} + laddr = &TCPAddr{} } - l1, err := listenPlan9(net, laddr) + fd, err := listenPlan9(net, laddr) if err != nil { - return + return nil, err } - return &TCPListener{*l1}, nil + return &TCPListener{fd}, nil } diff --git a/src/pkg/net/tcpsock_posix.go b/src/pkg/net/tcpsock_posix.go index e6b1937fb..bd5a2a287 100644 --- a/src/pkg/net/tcpsock_posix.go +++ b/src/pkg/net/tcpsock_posix.go @@ -23,14 +23,9 @@ import ( func sockaddrToTCP(sa syscall.Sockaddr) Addr { switch sa := sa.(type) { case *syscall.SockaddrInet4: - return &TCPAddr{sa.Addr[0:], sa.Port} + return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port} case *syscall.SockaddrInet6: - return &TCPAddr{sa.Addr[0:], sa.Port} - default: - if sa != nil { - // Diagnose when we will turn a non-nil sockaddr into a nil. - panic("unexpected type in sockaddrToTCP") - } + return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))} } return nil } @@ -53,7 +48,7 @@ func (a *TCPAddr) isWildcard() bool { } func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, error) { - return ipToSockaddr(family, a.IP, a.Port) + return ipToSockaddr(family, a.IP, a.Port, a.Zone) } func (a *TCPAddr) toAddr() sockaddr { @@ -66,27 +61,15 @@ func (a *TCPAddr) toAddr() sockaddr { // TCPConn is an implementation of the Conn interface // for TCP network connections. type TCPConn struct { - fd *netFD + conn } func newTCPConn(fd *netFD) *TCPConn { - c := &TCPConn{fd} + c := &TCPConn{conn{fd}} c.SetNoDelay(true) return c } -func (c *TCPConn) ok() bool { return c != nil && c.fd != nil } - -// Implementation of the Conn interface - see Conn for documentation. - -// Read implements the Conn Read method. -func (c *TCPConn) Read(b []byte) (n int, err error) { - if !c.ok() { - return 0, syscall.EINVAL - } - return c.fd.Read(b) -} - // ReadFrom implements the io.ReaderFrom ReadFrom method. func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) { if n, err, handled := sendFile(c.fd, r); handled { @@ -95,22 +78,6 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) { return genericReadFrom(c, r) } -// Write implements the Conn Write method. -func (c *TCPConn) Write(b []byte) (n int, err error) { - if !c.ok() { - return 0, syscall.EINVAL - } - return c.fd.Write(b) -} - -// Close closes the TCP connection. -func (c *TCPConn) Close() error { - if !c.ok() { - return syscall.EINVAL - } - return c.fd.Close() -} - // CloseRead shuts down the reading side of the TCP connection. // Most callers should just use Close. func (c *TCPConn) CloseRead() error { @@ -129,64 +96,6 @@ func (c *TCPConn) CloseWrite() error { return c.fd.CloseWrite() } -// LocalAddr returns the local network address, a *TCPAddr. -func (c *TCPConn) LocalAddr() Addr { - if !c.ok() { - return nil - } - return c.fd.laddr -} - -// RemoteAddr returns the remote network address, a *TCPAddr. -func (c *TCPConn) RemoteAddr() Addr { - if !c.ok() { - return nil - } - return c.fd.raddr -} - -// SetDeadline implements the Conn SetDeadline method. -func (c *TCPConn) SetDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setDeadline(c.fd, t) -} - -// SetReadDeadline implements the Conn SetReadDeadline method. -func (c *TCPConn) SetReadDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setReadDeadline(c.fd, t) -} - -// SetWriteDeadline implements the Conn SetWriteDeadline method. -func (c *TCPConn) SetWriteDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setWriteDeadline(c.fd, t) -} - -// SetReadBuffer sets the size of the operating system's -// receive buffer associated with the connection. -func (c *TCPConn) SetReadBuffer(bytes int) error { - if !c.ok() { - return syscall.EINVAL - } - return setReadBuffer(c.fd, bytes) -} - -// SetWriteBuffer sets the size of the operating system's -// transmit buffer associated with the connection. -func (c *TCPConn) SetWriteBuffer(bytes int) error { - if !c.ok() { - return syscall.EINVAL - } - return setWriteBuffer(c.fd, bytes) -} - // SetLinger sets the behavior of Close() on a connection // which still has data waiting to be sent or to be acknowledged. // @@ -225,20 +134,23 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error { return setNoDelay(c.fd, noDelay) } -// File returns a copy of the underlying os.File, set to blocking mode. -// It is the caller's responsibility to close f when finished. -// Closing c does not affect f, and closing f does not affect c. -func (c *TCPConn) File() (f *os.File, err error) { return c.fd.dup() } - // DialTCP connects to the remote address raddr on the network net, // which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used // as the local address for the connection. func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) { + switch net { + case "tcp", "tcp4", "tcp6": + default: + return nil, UnknownNetworkError(net) + } if raddr == nil { return nil, &OpError{"dial", net, nil, errMissingAddress} } + return dialTCP(net, laddr, raddr, noDeadline) +} - fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) +func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) { + fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) // TCP has a rarely used mechanism called a 'simultaneous connection' in // which Dial("tcp", addr1, addr2) run on the machine at addr1 can @@ -257,9 +169,18 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) { // use the result. See also: // http://golang.org/issue/2690 // http://stackoverflow.com/questions/4949858/ - for i := 0; i < 2 && err == nil && laddr == nil && selfConnect(fd); i++ { - fd.Close() - fd, err = internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) + // + // The opposite can also happen: if we ask the kernel to pick an appropriate + // originating local address, sometimes it picks one that is already in use. + // So if the error is EADDRNOTAVAIL, we have to try again too, just for + // a different reason. + // + // The kernel socket code is no doubt enjoying watching us squirm. + for i := 0; i < 2 && (laddr == nil || laddr.Port == 0) && (selfConnect(fd, err) || spuriousENOTAVAIL(err)); i++ { + if err == nil { + fd.Close() + } + fd, err = internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) } if err != nil { @@ -268,7 +189,12 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) { return newTCPConn(fd), nil } -func selfConnect(fd *netFD) bool { +func selfConnect(fd *netFD, err error) bool { + // If the connect failed, we clearly didn't connect to ourselves. + if err != nil { + return false + } + // The socket constructor can return an fd with raddr nil under certain // unknown conditions. The errors in the calls there to Getpeername // are discarded, but we can't catch the problem there because those @@ -285,6 +211,11 @@ func selfConnect(fd *netFD) bool { return l.Port == r.Port && l.IP.Equal(r.IP) } +func spuriousENOTAVAIL(err error) bool { + e, ok := err.(*OpError) + return ok && e.Err == syscall.EADDRNOTAVAIL +} + // TCPListener is a TCP network listener. // Clients should typically use variables of type Listener // instead of assuming TCP. @@ -292,29 +223,10 @@ type TCPListener struct { fd *netFD } -// ListenTCP announces on the TCP address laddr and returns a TCP listener. -// Net must be "tcp", "tcp4", or "tcp6". -// If laddr has a port of 0, it means to listen on some available port. -// The caller can use l.Addr() to retrieve the chosen address. -func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { - fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_STREAM, 0, "listen", sockaddrToTCP) - if err != nil { - return nil, err - } - err = syscall.Listen(fd.sysfd, listenerBacklog) - if err != nil { - closesocket(fd.sysfd) - return nil, &OpError{"listen", net, laddr, err} - } - l := new(TCPListener) - l.fd = fd - return l, nil -} - // AcceptTCP accepts the next incoming call and returns the new connection // and the remote address. func (l *TCPListener) AcceptTCP() (c *TCPConn, err error) { - if l == nil || l.fd == nil || l.fd.sysfd < 0 { + if l == nil || l.fd == nil { return nil, syscall.EINVAL } fd, err := l.fd.accept(sockaddrToTCP) @@ -359,3 +271,28 @@ func (l *TCPListener) SetDeadline(t time.Time) error { // It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. func (l *TCPListener) File() (f *os.File, err error) { return l.fd.dup() } + +// ListenTCP announces on the TCP address laddr and returns a TCP listener. +// Net must be "tcp", "tcp4", or "tcp6". +// If laddr has a port of 0, it means to listen on some available port. +// The caller can use l.Addr() to retrieve the chosen address. +func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { + switch net { + case "tcp", "tcp4", "tcp6": + default: + return nil, UnknownNetworkError(net) + } + if laddr == nil { + laddr = &TCPAddr{} + } + fd, err := internetSocket(net, laddr.toAddr(), nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", sockaddrToTCP) + if err != nil { + return nil, err + } + err = syscall.Listen(fd.sysfd, listenerBacklog) + if err != nil { + closesocket(fd.sysfd) + return nil, &OpError{"listen", net, laddr, err} + } + return &TCPListener{fd}, nil +} diff --git a/src/pkg/net/textproto/reader.go b/src/pkg/net/textproto/reader.go index 125feb3e8..b61bea862 100644 --- a/src/pkg/net/textproto/reader.go +++ b/src/pkg/net/textproto/reader.go @@ -128,6 +128,17 @@ func (r *Reader) readContinuedLineSlice() ([]byte, error) { return line, nil } + // Optimistically assume that we have started to buffer the next line + // and it starts with an ASCII letter (the next header key), so we can + // avoid copying that buffered data around in memory and skipping over + // non-existent whitespace. + if r.R.Buffered() > 1 { + peek, err := r.R.Peek(1) + if err == nil && isASCIILetter(peek[0]) { + return trim(line), nil + } + } + // ReadByte or the next readLineSlice will flush the read buffer; // copy the slice into buf. r.buf = append(r.buf[:0], trim(line)...) @@ -445,23 +456,25 @@ func (r *Reader) ReadDotLines() ([]string, error) { // } // func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { - m := make(MIMEHeader) + m := make(MIMEHeader, 4) for { kv, err := r.readContinuedLineSlice() if len(kv) == 0 { return m, err } - // Key ends at first colon; must not have spaces. + // Key ends at first colon; should not have spaces but + // they appear in the wild, violating specs, so we + // remove them if present. i := bytes.IndexByte(kv, ':') if i < 0 { return m, ProtocolError("malformed MIME header line: " + string(kv)) } - key := string(kv[0:i]) - if strings.Index(key, " ") >= 0 { - key = strings.TrimRight(key, " ") + endKey := i + for endKey > 0 && kv[endKey-1] == ' ' { + endKey-- } - key = CanonicalMIMEHeaderKey(key) + key := canonicalMIMEHeaderKey(kv[:endKey]) // Skip initial spaces in value. i++ // skip colon @@ -484,41 +497,107 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { // letter and any letter following a hyphen to upper case; // the rest are converted to lowercase. For example, the // canonical key for "accept-encoding" is "Accept-Encoding". +// MIME header keys are assumed to be ASCII only. func CanonicalMIMEHeaderKey(s string) string { // Quick check for canonical encoding. - needUpper := true + upper := true for i := 0; i < len(s); i++ { c := s[i] - if needUpper && 'a' <= c && c <= 'z' { - goto MustRewrite + if upper && 'a' <= c && c <= 'z' { + return canonicalMIMEHeaderKey([]byte(s)) } - if !needUpper && 'A' <= c && c <= 'Z' { - goto MustRewrite + if !upper && 'A' <= c && c <= 'Z' { + return canonicalMIMEHeaderKey([]byte(s)) } - needUpper = c == '-' + upper = c == '-' } return s +} + +const toLower = 'a' - 'A' -MustRewrite: - // Canonicalize: first letter upper case - // and upper case after each dash. - // (Host, User-Agent, If-Modified-Since). - // MIME headers are ASCII only, so no Unicode issues. - a := []byte(s) +// canonicalMIMEHeaderKey is like CanonicalMIMEHeaderKey but is +// allowed to mutate the provided byte slice before returning the +// string. +func canonicalMIMEHeaderKey(a []byte) string { + // Look for it in commonHeaders , so that we can avoid an + // allocation by sharing the strings among all users + // of textproto. If we don't find it, a has been canonicalized + // so just return string(a). upper := true - for i, v := range a { - if v == ' ' { + lo := 0 + hi := len(commonHeaders) + for i := 0; i < len(a); i++ { + // Canonicalize: first letter upper case + // and upper case after each dash. + // (Host, User-Agent, If-Modified-Since). + // MIME headers are ASCII only, so no Unicode issues. + if a[i] == ' ' { a[i] = '-' upper = true continue } - if upper && 'a' <= v && v <= 'z' { - a[i] = v + 'A' - 'a' + c := a[i] + if upper && 'a' <= c && c <= 'z' { + c -= toLower + } else if !upper && 'A' <= c && c <= 'Z' { + c += toLower } - if !upper && 'A' <= v && v <= 'Z' { - a[i] = v + 'a' - 'A' + a[i] = c + upper = c == '-' // for next time + + if lo < hi { + for lo < hi && (len(commonHeaders[lo]) <= i || commonHeaders[lo][i] < c) { + lo++ + } + for hi > lo && commonHeaders[hi-1][i] > c { + hi-- + } } - upper = v == '-' + } + if lo < hi && len(commonHeaders[lo]) == len(a) { + return commonHeaders[lo] } return string(a) } + +var commonHeaders = []string{ + "Accept", + "Accept-Charset", + "Accept-Encoding", + "Accept-Language", + "Accept-Ranges", + "Cache-Control", + "Cc", + "Connection", + "Content-Id", + "Content-Language", + "Content-Length", + "Content-Transfer-Encoding", + "Content-Type", + "Date", + "Dkim-Signature", + "Etag", + "Expires", + "From", + "Host", + "If-Modified-Since", + "If-None-Match", + "In-Reply-To", + "Last-Modified", + "Location", + "Message-Id", + "Mime-Version", + "Pragma", + "Received", + "Return-Path", + "Server", + "Set-Cookie", + "Subject", + "To", + "User-Agent", + "Via", + "X-Forwarded-For", + "X-Imforwards", + "X-Powered-By", +} diff --git a/src/pkg/net/textproto/reader_test.go b/src/pkg/net/textproto/reader_test.go index 7c5d16227..26987f611 100644 --- a/src/pkg/net/textproto/reader_test.go +++ b/src/pkg/net/textproto/reader_test.go @@ -6,6 +6,7 @@ package textproto import ( "bufio" + "bytes" "io" "reflect" "strings" @@ -23,6 +24,7 @@ var canonicalHeaderKeyTests = []canonicalHeaderKeyTest{ {"uSER-aGENT", "User-Agent"}, {"user-agent", "User-Agent"}, {"USER-AGENT", "User-Agent"}, + {"üser-agenT", "üser-Agent"}, // non-ASCII unchanged } func TestCanonicalMIMEHeaderKey(t *testing.T) { @@ -239,3 +241,95 @@ func TestRFC959Lines(t *testing.T) { } } } + +func TestCommonHeaders(t *testing.T) { + // need to disable the commonHeaders-based optimization + // during this check, or we'd not be testing anything + oldch := commonHeaders + commonHeaders = []string{} + defer func() { commonHeaders = oldch }() + + last := "" + for _, h := range oldch { + if last > h { + t.Errorf("%v is out of order", h) + } + if last == h { + t.Errorf("%v is duplicated", h) + } + if canon := CanonicalMIMEHeaderKey(h); h != canon { + t.Errorf("%v is not canonical", h) + } + last = h + } +} + +var clientHeaders = strings.Replace(`Host: golang.org +Connection: keep-alive +Cache-Control: max-age=0 +Accept: application/xml,application/xhtml+xml,text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5 +User-Agent: Mozilla/5.0 (X11; U; Linux x86_64; en-US) AppleWebKit/534.3 (KHTML, like Gecko) Chrome/6.0.472.63 Safari/534.3 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8,fr-CH;q=0.6 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 +COOKIE: __utma=000000000.0000000000.0000000000.0000000000.0000000000.00; __utmb=000000000.0.00.0000000000; __utmc=000000000; __utmz=000000000.0000000000.00.0.utmcsr=code.google.com|utmccn=(referral)|utmcmd=referral|utmcct=/p/go/issues/detail +Non-Interned: test + +`, "\n", "\r\n", -1) + +var serverHeaders = strings.Replace(`Content-Type: text/html; charset=utf-8 +Content-Encoding: gzip +Date: Thu, 27 Sep 2012 09:03:33 GMT +Server: Google Frontend +Cache-Control: private +Content-Length: 2298 +VIA: 1.1 proxy.example.com:80 (XXX/n.n.n-nnn) +Connection: Close +Non-Interned: test + +`, "\n", "\r\n", -1) + +func BenchmarkReadMIMEHeader(b *testing.B) { + var buf bytes.Buffer + br := bufio.NewReader(&buf) + r := NewReader(br) + for i := 0; i < b.N; i++ { + var want int + var find string + if (i & 1) == 1 { + buf.WriteString(clientHeaders) + want = 10 + find = "Cookie" + } else { + buf.WriteString(serverHeaders) + want = 9 + find = "Via" + } + h, err := r.ReadMIMEHeader() + if err != nil { + b.Fatal(err) + } + if len(h) != want { + b.Fatalf("wrong number of headers: got %d, want %d", len(h), want) + } + if _, ok := h[find]; !ok { + b.Fatalf("did not find key %s", find) + } + } +} + +func BenchmarkUncommon(b *testing.B) { + var buf bytes.Buffer + br := bufio.NewReader(&buf) + r := NewReader(br) + for i := 0; i < b.N; i++ { + buf.WriteString("uncommon-header-for-benchmark: foo\r\n\r\n") + h, err := r.ReadMIMEHeader() + if err != nil { + b.Fatal(err) + } + if _, ok := h["Uncommon-Header-For-Benchmark"]; !ok { + b.Fatal("Missing result header.") + } + } +} diff --git a/src/pkg/net/textproto/textproto.go b/src/pkg/net/textproto/textproto.go index ad5840cf7..eb6ced1c5 100644 --- a/src/pkg/net/textproto/textproto.go +++ b/src/pkg/net/textproto/textproto.go @@ -121,3 +121,34 @@ func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err error) { } return id, nil } + +// TrimString returns s without leading and trailing ASCII space. +func TrimString(s string) string { + for len(s) > 0 && isASCIISpace(s[0]) { + s = s[1:] + } + for len(s) > 0 && isASCIISpace(s[len(s)-1]) { + s = s[:len(s)-1] + } + return s +} + +// TrimBytes returns b without leading and trailing ASCII space. +func TrimBytes(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[0]) { + b = b[1:] + } + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] + } + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +func isASCIILetter(b byte) bool { + b |= 0x20 // make lower case + return 'a' <= b && b <= 'z' +} diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go index 672fb7241..0260efcc0 100644 --- a/src/pkg/net/timeout_test.go +++ b/src/pkg/net/timeout_test.go @@ -6,11 +6,187 @@ package net import ( "fmt" + "io" + "io/ioutil" "runtime" "testing" "time" ) +func isTimeout(err error) bool { + e, ok := err.(Error) + return ok && e.Timeout() +} + +type copyRes struct { + n int64 + err error + d time.Duration +} + +func TestAcceptTimeout(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t).(*TCPListener) + defer ln.Close() + ln.SetDeadline(time.Now().Add(-1 * time.Second)) + if _, err := ln.Accept(); !isTimeout(err) { + t.Fatalf("Accept: expected err %v, got %v", errTimeout, err) + } + if _, err := ln.Accept(); !isTimeout(err) { + t.Fatalf("Accept: expected err %v, got %v", errTimeout, err) + } + ln.SetDeadline(time.Now().Add(100 * time.Millisecond)) + if _, err := ln.Accept(); !isTimeout(err) { + t.Fatalf("Accept: expected err %v, got %v", errTimeout, err) + } + if _, err := ln.Accept(); !isTimeout(err) { + t.Fatalf("Accept: expected err %v, got %v", errTimeout, err) + } + ln.SetDeadline(noDeadline) + errc := make(chan error) + go func() { + _, err := ln.Accept() + errc <- err + }() + time.Sleep(100 * time.Millisecond) + select { + case err := <-errc: + t.Fatalf("Expected Accept() to not return, but it returned with %v\n", err) + default: + } + ln.Close() + switch nerr := <-errc; err := nerr.(type) { + case *OpError: + if err.Err != errClosing { + t.Fatalf("Accept: expected err %v, got %v", errClosing, err) + } + default: + if err != errClosing { + t.Fatalf("Accept: expected err %v, got %v", errClosing, err) + } + } +} + +func TestReadTimeout(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t) + defer ln.Close() + c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr)) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer c.Close() + c.SetDeadline(time.Now().Add(time.Hour)) + c.SetReadDeadline(time.Now().Add(-1 * time.Second)) + buf := make([]byte, 1) + if _, err = c.Read(buf); !isTimeout(err) { + t.Fatalf("Read: expected err %v, got %v", errTimeout, err) + } + if _, err = c.Read(buf); !isTimeout(err) { + t.Fatalf("Read: expected err %v, got %v", errTimeout, err) + } + c.SetDeadline(time.Now().Add(100 * time.Millisecond)) + if _, err = c.Read(buf); !isTimeout(err) { + t.Fatalf("Read: expected err %v, got %v", errTimeout, err) + } + if _, err = c.Read(buf); !isTimeout(err) { + t.Fatalf("Read: expected err %v, got %v", errTimeout, err) + } + c.SetReadDeadline(noDeadline) + c.SetWriteDeadline(time.Now().Add(-1 * time.Second)) + errc := make(chan error) + go func() { + _, err := c.Read(buf) + errc <- err + }() + time.Sleep(100 * time.Millisecond) + select { + case err := <-errc: + t.Fatalf("Expected Read() to not return, but it returned with %v\n", err) + default: + } + c.Close() + switch nerr := <-errc; err := nerr.(type) { + case *OpError: + if err.Err != errClosing { + t.Fatalf("Read: expected err %v, got %v", errClosing, err) + } + default: + if err != errClosing { + t.Fatalf("Read: expected err %v, got %v", errClosing, err) + } + } +} + +func TestWriteTimeout(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t) + defer ln.Close() + c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr)) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer c.Close() + c.SetDeadline(time.Now().Add(time.Hour)) + c.SetWriteDeadline(time.Now().Add(-1 * time.Second)) + buf := make([]byte, 4096) + writeUntilTimeout := func() { + for { + _, err := c.Write(buf) + if err != nil { + if isTimeout(err) { + return + } + t.Fatalf("Write: expected err %v, got %v", errTimeout, err) + } + } + } + writeUntilTimeout() + c.SetDeadline(time.Now().Add(10 * time.Millisecond)) + writeUntilTimeout() + writeUntilTimeout() + c.SetWriteDeadline(noDeadline) + c.SetReadDeadline(time.Now().Add(-1 * time.Second)) + errc := make(chan error) + go func() { + for { + _, err := c.Write(buf) + if err != nil { + errc <- err + } + } + }() + time.Sleep(100 * time.Millisecond) + select { + case err := <-errc: + t.Fatalf("Expected Write() to not return, but it returned with %v\n", err) + default: + } + c.Close() + switch nerr := <-errc; err := nerr.(type) { + case *OpError: + if err.Err != errClosing { + t.Fatalf("Write: expected err %v, got %v", errClosing, err) + } + default: + if err != errClosing { + t.Fatalf("Write: expected err %v, got %v", errClosing, err) + } + } +} + func testTimeout(t *testing.T, net, addr string, readFrom bool) { c, err := Dial(net, addr) if err != nil { @@ -59,8 +235,7 @@ func testTimeout(t *testing.T, net, addr string, readFrom bool) { func TestTimeoutUDP(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } // set up a listener that won't talk back @@ -77,8 +252,7 @@ func TestTimeoutUDP(t *testing.T) { func TestTimeoutTCP(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } // set up a listener that won't talk back @@ -94,8 +268,7 @@ func TestTimeoutTCP(t *testing.T) { func TestDeadlineReset(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } ln, err := Listen("tcp", "127.0.0.1:0") if err != nil { @@ -104,7 +277,7 @@ func TestDeadlineReset(t *testing.T) { defer ln.Close() tl := ln.(*TCPListener) tl.SetDeadline(time.Now().Add(1 * time.Minute)) - tl.SetDeadline(time.Time{}) // reset it + tl.SetDeadline(noDeadline) // reset it errc := make(chan error, 1) go func() { _, err := ln.Accept() @@ -119,3 +292,356 @@ func TestDeadlineReset(t *testing.T) { t.Errorf("unexpected return from Accept; err=%v", err) } } + +func TestTimeoutAccept(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + ln, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + tl := ln.(*TCPListener) + tl.SetDeadline(time.Now().Add(100 * time.Millisecond)) + errc := make(chan error, 1) + go func() { + _, err := ln.Accept() + errc <- err + }() + select { + case <-time.After(1 * time.Second): + // Accept shouldn't block indefinitely + t.Errorf("Accept didn't return in an expected time") + case <-errc: + // Pass. + } +} + +func TestReadWriteDeadline(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + if !canCancelIO { + t.Skip("skipping test on this system") + } + const ( + readTimeout = 50 * time.Millisecond + writeTimeout = 250 * time.Millisecond + ) + checkTimeout := func(command string, start time.Time, should time.Duration) { + is := time.Now().Sub(start) + d := is - should + if d < -30*time.Millisecond || !testing.Short() && 150*time.Millisecond < d { + t.Errorf("%s timeout test failed: is=%v should=%v\n", command, is, should) + } + } + + ln, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenTCP on :0: %v", err) + } + defer ln.Close() + + lnquit := make(chan bool) + + go func() { + c, err := ln.Accept() + if err != nil { + t.Fatalf("Accept: %v", err) + } + defer c.Close() + lnquit <- true + }() + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + + start := time.Now() + err = c.SetReadDeadline(start.Add(readTimeout)) + if err != nil { + t.Fatalf("SetReadDeadline: %v", err) + } + err = c.SetWriteDeadline(start.Add(writeTimeout)) + if err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + + quit := make(chan bool) + + go func() { + var buf [10]byte + _, err := c.Read(buf[:]) + if err == nil { + t.Errorf("Read should not succeed") + } + checkTimeout("Read", start, readTimeout) + quit <- true + }() + + go func() { + var buf [10000]byte + for { + _, err := c.Write(buf[:]) + if err != nil { + break + } + } + checkTimeout("Write", start, writeTimeout) + quit <- true + }() + + <-quit + <-quit + <-lnquit +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +func TestVariousDeadlines1Proc(t *testing.T) { + testVariousDeadlines(t, 1) +} + +func TestVariousDeadlines4Proc(t *testing.T) { + testVariousDeadlines(t, 4) +} + +func testVariousDeadlines(t *testing.T, maxProcs int) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + ln := newLocalListener(t) + defer ln.Close() + acceptc := make(chan error, 1) + + // The server, with no timeouts of its own, sending bytes to clients + // as fast as it can. + servec := make(chan copyRes) + go func() { + for { + c, err := ln.Accept() + if err != nil { + acceptc <- err + return + } + go func() { + t0 := time.Now() + n, err := io.Copy(c, neverEnding('a')) + d := time.Since(t0) + c.Close() + servec <- copyRes{n, err, d} + }() + } + }() + + for _, timeout := range []time.Duration{ + 1 * time.Nanosecond, + 2 * time.Nanosecond, + 5 * time.Nanosecond, + 50 * time.Nanosecond, + 100 * time.Nanosecond, + 200 * time.Nanosecond, + 500 * time.Nanosecond, + 750 * time.Nanosecond, + 1 * time.Microsecond, + 5 * time.Microsecond, + 25 * time.Microsecond, + 250 * time.Microsecond, + 500 * time.Microsecond, + 1 * time.Millisecond, + 5 * time.Millisecond, + 100 * time.Millisecond, + 250 * time.Millisecond, + 500 * time.Millisecond, + 1 * time.Second, + } { + numRuns := 3 + if testing.Short() { + numRuns = 1 + if timeout > 500*time.Microsecond { + continue + } + } + for run := 0; run < numRuns; run++ { + name := fmt.Sprintf("%v run %d/%d", timeout, run+1, numRuns) + t.Log(name) + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + clientc := make(chan copyRes) + go func() { + t0 := time.Now() + c.SetDeadline(t0.Add(timeout)) + n, err := io.Copy(ioutil.Discard, c) + d := time.Since(t0) + c.Close() + clientc <- copyRes{n, err, d} + }() + + const tooLong = 2000 * time.Millisecond + select { + case res := <-clientc: + if isTimeout(res.err) { + t.Logf("for %v, good client timeout after %v, reading %d bytes", name, res.d, res.n) + } else { + t.Fatalf("for %v: client Copy = %d, %v (want timeout)", name, res.n, res.err) + } + case <-time.After(tooLong): + t.Fatalf("for %v: timeout (%v) waiting for client to timeout (%v) reading", name, tooLong, timeout) + } + + select { + case res := <-servec: + t.Logf("for %v: server in %v wrote %d, %v", name, res.d, res.n, res.err) + case err := <-acceptc: + t.Fatalf("for %v: server Accept = %v", name, err) + case <-time.After(tooLong): + t.Fatalf("for %v, timeout waiting for server to finish writing", name) + } + } + } +} + +// TestReadDeadlineDataAvailable tests that read deadlines work, even +// if there's data ready to be read. +func TestReadDeadlineDataAvailable(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t) + defer ln.Close() + + servec := make(chan copyRes) + const msg = "data client shouldn't read, even though it it'll be waiting" + go func() { + c, err := ln.Accept() + if err != nil { + t.Fatalf("Accept: %v", err) + } + defer c.Close() + n, err := c.Write([]byte(msg)) + servec <- copyRes{n: int64(n), err: err} + }() + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + if res := <-servec; res.err != nil || res.n != int64(len(msg)) { + t.Fatalf("unexpected server Write: n=%d, err=%d; want n=%d, err=nil", res.n, res.err, len(msg)) + } + c.SetReadDeadline(time.Now().Add(-5 * time.Second)) // in the psat. + buf := make([]byte, len(msg)/2) + n, err := c.Read(buf) + if n > 0 || !isTimeout(err) { + t.Fatalf("client read = %d (%q) err=%v; want 0, timeout", n, buf[:n], err) + } +} + +// TestWriteDeadlineBufferAvailable tests that write deadlines work, even +// if there's buffer space available to write. +func TestWriteDeadlineBufferAvailable(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t) + defer ln.Close() + + servec := make(chan copyRes) + go func() { + c, err := ln.Accept() + if err != nil { + t.Fatalf("Accept: %v", err) + } + defer c.Close() + c.SetWriteDeadline(time.Now().Add(-5 * time.Second)) // in the past + n, err := c.Write([]byte{'x'}) + servec <- copyRes{n: int64(n), err: err} + }() + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + res := <-servec + if res.n != 0 { + t.Errorf("Write = %d; want 0", res.n) + } + if !isTimeout(res.err) { + t.Errorf("Write error = %v; want timeout", res.err) + } +} + +// TestProlongTimeout tests concurrent deadline modification. +// Known to cause data races in the past. +func TestProlongTimeout(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t) + defer ln.Close() + connected := make(chan bool) + go func() { + s, err := ln.Accept() + connected <- true + if err != nil { + t.Fatalf("ln.Accept: %v", err) + } + defer s.Close() + s.SetDeadline(time.Now().Add(time.Hour)) + go func() { + var buf [4096]byte + for { + _, err := s.Write(buf[:]) + if err != nil { + break + } + s.SetDeadline(time.Now().Add(time.Hour)) + } + }() + buf := make([]byte, 1) + for { + _, err := s.Read(buf) + if err != nil { + break + } + s.SetDeadline(time.Now().Add(time.Hour)) + } + }() + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("DialTCP: %v", err) + } + defer c.Close() + <-connected + for i := 0; i < 1024; i++ { + var buf [1]byte + c.Write(buf[:]) + } +} diff --git a/src/pkg/net/udp_test.go b/src/pkg/net/udp_test.go index f80d3b5a9..220422e13 100644 --- a/src/pkg/net/udp_test.go +++ b/src/pkg/net/udp_test.go @@ -5,15 +5,45 @@ package net import ( + "reflect" "runtime" "testing" ) +var resolveUDPAddrTests = []struct { + net string + litAddr string + addr *UDPAddr + err error +}{ + {"udp", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, + {"udp4", "127.0.0.1:65535", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil}, + + {"udp", "[::1]:1", &UDPAddr{IP: ParseIP("::1"), Port: 1}, nil}, + {"udp6", "[::1]:65534", &UDPAddr{IP: ParseIP("::1"), Port: 65534}, nil}, + + {"", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior + {"", "[::1]:0", &UDPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior + + {"sip", "127.0.0.1:0", nil, UnknownNetworkError("sip")}, +} + +func TestResolveUDPAddr(t *testing.T) { + for _, tt := range resolveUDPAddrTests { + addr, err := ResolveUDPAddr(tt.net, tt.litAddr) + if err != tt.err { + t.Fatalf("ResolveUDPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err) + } + if !reflect.DeepEqual(addr, tt.addr) { + t.Fatalf("got %#v; expected %#v", addr, tt.addr) + } + } +} + func TestWriteToUDP(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } l, err := ListenPacket("udp", "127.0.0.1:0") @@ -87,3 +117,32 @@ func testWriteToPacketConn(t *testing.T, raddr string) { t.Fatal("Write should fail") } } + +var udpConnLocalNameTests = []struct { + net string + laddr *UDPAddr +}{ + {"udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)}}, + {"udp4", &UDPAddr{}}, + {"udp4", nil}, +} + +func TestUDPConnLocalName(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + + for _, tt := range udpConnLocalNameTests { + c, err := ListenUDP(tt.net, tt.laddr) + if err != nil { + t.Errorf("ListenUDP failed: %v", err) + return + } + defer c.Close() + la := c.LocalAddr() + if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { + t.Errorf("got %v; expected a proper address with non-zero port number", la) + return + } + } +} diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go index b3520cf09..6e5e90268 100644 --- a/src/pkg/net/udpsock.go +++ b/src/pkg/net/udpsock.go @@ -6,10 +6,15 @@ package net +import "errors" + +var ErrWriteToConnected = errors.New("use of WriteTo with pre-connected UDP") + // UDPAddr represents the address of a UDP end point. type UDPAddr struct { IP IP Port int + Zone string // IPv6 scoped addressing zone } // Network returns the address's network name, "udp". @@ -28,9 +33,16 @@ func (a *UDPAddr) String() string { // "udp4" or "udp6". A literal IPv6 host address must be // enclosed in square brackets, as in "[::]:80". func ResolveUDPAddr(net, addr string) (*UDPAddr, error) { - ip, port, err := hostPortToIP(net, addr) + switch net { + case "udp", "udp4", "udp6": + case "": // a hint wildcard for Go 1.0 undocumented behavior + net = "udp" + default: + return nil, UnknownNetworkError(net) + } + a, err := resolveInternetAddr(net, addr, noDeadline) if err != nil { return nil, err } - return &UDPAddr{ip, port}, nil + return a.(*UDPAddr), nil } diff --git a/src/pkg/net/udpsock_plan9.go b/src/pkg/net/udpsock_plan9.go index 4f298a42f..2a7e3d19c 100644 --- a/src/pkg/net/udpsock_plan9.go +++ b/src/pkg/net/udpsock_plan9.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// UDP for Plan 9 +// UDP sockets for Plan 9 package net @@ -16,44 +16,26 @@ import ( // UDPConn is the implementation of the Conn and PacketConn // interfaces for UDP network connections. type UDPConn struct { - plan9Conn + conn } -// SetDeadline implements the Conn SetDeadline method. -func (c *UDPConn) SetDeadline(t time.Time) error { - return syscall.EPLAN9 +func newUDPConn(fd *netFD) *UDPConn { + return &UDPConn{conn{fd}} } -// SetReadDeadline implements the Conn SetReadDeadline method. -func (c *UDPConn) SetReadDeadline(t time.Time) error { - return syscall.EPLAN9 -} - -// SetWriteDeadline implements the Conn SetWriteDeadline method. -func (c *UDPConn) SetWriteDeadline(t time.Time) error { - return syscall.EPLAN9 -} - -// UDP-specific methods. - // ReadFromUDP reads a UDP packet from c, copying the payload into b. // It returns the number of bytes copied into b and the return address // that was on the packet. // -// ReadFromUDP can be made to time out and return an error with Timeout() == true -// after a fixed time limit; see SetDeadline and SetReadDeadline. +// ReadFromUDP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetReadDeadline. func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) { - if !c.ok() { + if !c.ok() || c.fd.data == nil { return 0, nil, syscall.EINVAL } - if c.data == nil { - c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0) - if err != nil { - return 0, nil, err - } - } buf := make([]byte, udpHeaderSize+len(b)) - m, err := c.data.Read(buf) + m, err := c.fd.data.Read(buf) if err != nil { return } @@ -64,62 +46,80 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) { h, buf := unmarshalUDPHeader(buf) n = copy(b, buf) - return n, &UDPAddr{h.raddr, int(h.rport)}, nil + return n, &UDPAddr{IP: h.raddr, Port: int(h.rport)}, nil } // ReadFrom implements the PacketConn ReadFrom method. -func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err error) { +func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) { if !c.ok() { return 0, nil, syscall.EINVAL } return c.ReadFromUDP(b) } -// WriteToUDP writes a UDP packet to addr via c, copying the payload from b. +// ReadMsgUDP reads a packet from c, copying the payload into b and +// the associdated out-of-band data into oob. It returns the number +// of bytes copied into b, the number of bytes copied into oob, the +// flags that were set on the packet and the source address of the +// packet. +func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) { + return 0, 0, 0, nil, syscall.EPLAN9 +} + +// WriteToUDP writes a UDP packet to addr via c, copying the payload +// from b. // -// WriteToUDP can be made to time out and return -// an error with Timeout() == true after a fixed time limit; -// see SetDeadline and SetWriteDeadline. -// On packet-oriented connections, write timeouts are rare. -func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err error) { - if !c.ok() { +// WriteToUDP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetWriteDeadline. On packet-oriented connections, write timeouts +// are rare. +func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { + if !c.ok() || c.fd.data == nil { return 0, syscall.EINVAL } - if c.data == nil { - c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0) - if err != nil { - return 0, err - } - } h := new(udpHeader) h.raddr = addr.IP.To16() - h.laddr = c.laddr.(*UDPAddr).IP.To16() + h.laddr = c.fd.laddr.(*UDPAddr).IP.To16() h.ifcaddr = IPv6zero // ignored (receive only) h.rport = uint16(addr.Port) - h.lport = uint16(c.laddr.(*UDPAddr).Port) + h.lport = uint16(c.fd.laddr.(*UDPAddr).Port) buf := make([]byte, udpHeaderSize+len(b)) i := copy(buf, h.Bytes()) copy(buf[i:], b) - return c.data.Write(buf) + return c.fd.data.Write(buf) } // WriteTo implements the PacketConn WriteTo method. -func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err error) { +func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) { if !c.ok() { return 0, syscall.EINVAL } a, ok := addr.(*UDPAddr) if !ok { - return 0, &OpError{"write", c.dir, addr, syscall.EINVAL} + return 0, &OpError{"write", c.fd.dir, addr, syscall.EINVAL} } return c.WriteToUDP(b, a) } +// WriteMsgUDP writes a packet to addr via c, copying the payload from +// b and the associated out-of-band data from oob. It returns the +// number of payload and out-of-band bytes written. +func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) { + return 0, 0, syscall.EPLAN9 +} + // DialUDP connects to the remote address raddr on the network net, -// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used -// as the local address for the connection. -func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err error) { +// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is +// used as the local address for the connection. +func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) { + return dialUDP(net, laddr, raddr, noDeadline) +} + +func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) { + if !deadline.IsZero() { + panic("net.dialUDP: deadline not implemented on Plan 9") + } switch net { case "udp", "udp4", "udp6": default: @@ -128,11 +128,11 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err error) { if raddr == nil { return nil, &OpError{"dial", net, nil, errMissingAddress} } - c1, err := dialPlan9(net, laddr, raddr) + fd, err := dialPlan9(net, laddr, raddr) if err != nil { - return + return nil, err } - return &UDPConn{*c1}, nil + return newUDPConn(fd), nil } const udpHeaderSize = 16*3 + 2*2 @@ -163,34 +163,38 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) { return h, b } -// ListenUDP listens for incoming UDP packets addressed to the -// local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send UDP -// packets with per-packet addressing. -func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err error) { +// ListenUDP listens for incoming UDP packets addressed to the local +// address laddr. The returned connection c's ReadFrom and WriteTo +// methods can be used to receive and send UDP packets with per-packet +// addressing. +func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { switch net { case "udp", "udp4", "udp6": default: return nil, UnknownNetworkError(net) } if laddr == nil { - return nil, &OpError{"listen", net, nil, errMissingAddress} + laddr = &UDPAddr{} } l, err := listenPlan9(net, laddr) if err != nil { - return + return nil, err } _, err = l.ctl.WriteString("headers") if err != nil { - return + return nil, err + } + l.data, err = os.OpenFile(l.dir+"/data", os.O_RDWR, 0) + if err != nil { + return nil, err } - return &UDPConn{*l.plan9Conn()}, nil + return newUDPConn(l.netFD()), nil } // ListenMulticastUDP listens for incoming multicast UDP packets -// addressed to the group address gaddr on ifi, which specifies -// the interface to join. ListenMulticastUDP uses default -// multicast interface if ifi is nil. +// addressed to the group address gaddr on ifi, which specifies the +// interface to join. ListenMulticastUDP uses default multicast +// interface if ifi is nil. func ListenMulticastUDP(net string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { return nil, syscall.EPLAN9 } diff --git a/src/pkg/net/udpsock_posix.go b/src/pkg/net/udpsock_posix.go index 9c6b6d393..385cd902e 100644 --- a/src/pkg/net/udpsock_posix.go +++ b/src/pkg/net/udpsock_posix.go @@ -4,25 +4,21 @@ // +build darwin freebsd linux netbsd openbsd windows -// UDP sockets +// UDP sockets for POSIX package net import ( - "errors" - "os" "syscall" "time" ) -var ErrWriteToConnected = errors.New("use of WriteTo with pre-connected UDP") - func sockaddrToUDP(sa syscall.Sockaddr) Addr { switch sa := sa.(type) { case *syscall.SockaddrInet4: - return &UDPAddr{sa.Addr[0:], sa.Port} + return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port} case *syscall.SockaddrInet6: - return &UDPAddr{sa.Addr[0:], sa.Port} + return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))} } return nil } @@ -45,7 +41,7 @@ func (a *UDPAddr) isWildcard() bool { } func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, error) { - return ipToSockaddr(family, a.IP, a.Port) + return ipToSockaddr(family, a.IP, a.Port, a.Zone) } func (a *UDPAddr) toAddr() sockaddr { @@ -58,98 +54,10 @@ func (a *UDPAddr) toAddr() sockaddr { // UDPConn is the implementation of the Conn and PacketConn // interfaces for UDP network connections. type UDPConn struct { - fd *netFD -} - -func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{fd} } - -func (c *UDPConn) ok() bool { return c != nil && c.fd != nil } - -// Implementation of the Conn interface - see Conn for documentation. - -// Read implements the Conn Read method. -func (c *UDPConn) Read(b []byte) (int, error) { - if !c.ok() { - return 0, syscall.EINVAL - } - return c.fd.Read(b) -} - -// Write implements the Conn Write method. -func (c *UDPConn) Write(b []byte) (int, error) { - if !c.ok() { - return 0, syscall.EINVAL - } - return c.fd.Write(b) -} - -// Close closes the UDP connection. -func (c *UDPConn) Close() error { - if !c.ok() { - return syscall.EINVAL - } - return c.fd.Close() -} - -// LocalAddr returns the local network address. -func (c *UDPConn) LocalAddr() Addr { - if !c.ok() { - return nil - } - return c.fd.laddr -} - -// RemoteAddr returns the remote network address, a *UDPAddr. -func (c *UDPConn) RemoteAddr() Addr { - if !c.ok() { - return nil - } - return c.fd.raddr -} - -// SetDeadline implements the Conn SetDeadline method. -func (c *UDPConn) SetDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setDeadline(c.fd, t) -} - -// SetReadDeadline implements the Conn SetReadDeadline method. -func (c *UDPConn) SetReadDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setReadDeadline(c.fd, t) -} - -// SetWriteDeadline implements the Conn SetWriteDeadline method. -func (c *UDPConn) SetWriteDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setWriteDeadline(c.fd, t) -} - -// SetReadBuffer sets the size of the operating system's -// receive buffer associated with the connection. -func (c *UDPConn) SetReadBuffer(bytes int) error { - if !c.ok() { - return syscall.EINVAL - } - return setReadBuffer(c.fd, bytes) -} - -// SetWriteBuffer sets the size of the operating system's -// transmit buffer associated with the connection. -func (c *UDPConn) SetWriteBuffer(bytes int) error { - if !c.ok() { - return syscall.EINVAL - } - return setWriteBuffer(c.fd, bytes) + conn } -// UDP-specific methods. +func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} } // ReadFromUDP reads a UDP packet from c, copying the payload into b. // It returns the number of bytes copied into b and the return address @@ -164,9 +72,9 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) { n, sa, err := c.fd.ReadFrom(b) switch sa := sa.(type) { case *syscall.SockaddrInet4: - addr = &UDPAddr{sa.Addr[0:], sa.Port} + addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port} case *syscall.SockaddrInet6: - addr = &UDPAddr{sa.Addr[0:], sa.Port} + addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))} } return } @@ -176,8 +84,28 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) { if !c.ok() { return 0, nil, syscall.EINVAL } - n, uaddr, err := c.ReadFromUDP(b) - return n, uaddr.toAddr(), err + n, addr, err := c.ReadFromUDP(b) + return n, addr.toAddr(), err +} + +// ReadMsgUDP reads a packet from c, copying the payload into b and +// the associdated out-of-band data into oob. It returns the number +// of bytes copied into b, the number of bytes copied into oob, the +// flags that were set on the packet and the source address of the +// packet. +func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) { + if !c.ok() { + return 0, 0, 0, nil, syscall.EINVAL + } + var sa syscall.Sockaddr + n, oobn, flags, sa, err = c.fd.ReadMsg(b, oob) + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port} + case *syscall.SockaddrInet6: + addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))} + } + return } // WriteToUDP writes a UDP packet to addr via c, copying the payload from b. @@ -212,15 +140,31 @@ func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) { return c.WriteToUDP(b, a) } -// File returns a copy of the underlying os.File, set to blocking mode. -// It is the caller's responsibility to close f when finished. -// Closing c does not affect f, and closing f does not affect c. -func (c *UDPConn) File() (f *os.File, err error) { return c.fd.dup() } +// WriteMsgUDP writes a packet to addr via c, copying the payload from +// b and the associated out-of-band data from oob. It returns the +// number of payload and out-of-band bytes written. +func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) { + if !c.ok() { + return 0, 0, syscall.EINVAL + } + if c.fd.isConnected { + return 0, 0, &OpError{"write", c.fd.net, addr, ErrWriteToConnected} + } + sa, err := addr.sockaddr(c.fd.family) + if err != nil { + return 0, 0, &OpError{"write", c.fd.net, addr, err} + } + return c.fd.WriteMsg(b, oob, sa) +} // DialUDP connects to the remote address raddr on the network net, // which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used // as the local address for the connection. func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) { + return dialUDP(net, laddr, raddr, noDeadline) +} + +func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) { switch net { case "udp", "udp4", "udp6": default: @@ -229,7 +173,7 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) { if raddr == nil { return nil, &OpError{"dial", net, nil, errMissingAddress} } - fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP) + fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP) if err != nil { return nil, err } @@ -247,9 +191,9 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { return nil, UnknownNetworkError(net) } if laddr == nil { - return nil, &OpError{"listen", net, nil, errMissingAddress} + laddr = &UDPAddr{} } - fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP) + fd, err := internetSocket(net, laddr.toAddr(), nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP) if err != nil { return nil, err } @@ -267,25 +211,22 @@ func ListenMulticastUDP(net string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, e return nil, UnknownNetworkError(net) } if gaddr == nil || gaddr.IP == nil { - return nil, &OpError{"listenmulticast", net, nil, errMissingAddress} + return nil, &OpError{"listen", net, nil, errMissingAddress} } - fd, err := internetSocket(net, gaddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP) + fd, err := internetSocket(net, gaddr.toAddr(), nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP) if err != nil { return nil, err } c := newUDPConn(fd) - ip4 := gaddr.IP.To4() - if ip4 != nil { - err := listenIPv4MulticastUDP(c, ifi, ip4) - if err != nil { + if ip4 := gaddr.IP.To4(); ip4 != nil { + if err := listenIPv4MulticastUDP(c, ifi, ip4); err != nil { c.Close() - return nil, err + return nil, &OpError{"listen", net, &IPAddr{IP: ip4}, err} } } else { - err := listenIPv6MulticastUDP(c, ifi, gaddr.IP) - if err != nil { + if err := listenIPv6MulticastUDP(c, ifi, gaddr.IP); err != nil { c.Close() - return nil, err + return nil, &OpError{"listen", net, &IPAddr{IP: gaddr.IP}, err} } } return c, nil @@ -293,17 +234,14 @@ func ListenMulticastUDP(net string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, e func listenIPv4MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error { if ifi != nil { - err := setIPv4MulticastInterface(c.fd, ifi) - if err != nil { + if err := setIPv4MulticastInterface(c.fd, ifi); err != nil { return err } } - err := setIPv4MulticastLoopback(c.fd, false) - if err != nil { + if err := setIPv4MulticastLoopback(c.fd, false); err != nil { return err } - err = joinIPv4GroupUDP(c, ifi, ip) - if err != nil { + if err := joinIPv4Group(c.fd, ifi, ip); err != nil { return err } return nil @@ -311,50 +249,15 @@ func listenIPv4MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error { func listenIPv6MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error { if ifi != nil { - err := setIPv6MulticastInterface(c.fd, ifi) - if err != nil { + if err := setIPv6MulticastInterface(c.fd, ifi); err != nil { return err } } - err := setIPv6MulticastLoopback(c.fd, false) - if err != nil { + if err := setIPv6MulticastLoopback(c.fd, false); err != nil { return err } - err = joinIPv6GroupUDP(c, ifi, ip) - if err != nil { + if err := joinIPv6Group(c.fd, ifi, ip); err != nil { return err } return nil } - -func joinIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) error { - err := joinIPv4Group(c.fd, ifi, ip) - if err != nil { - return &OpError{"joinipv4group", c.fd.net, &IPAddr{ip}, err} - } - return nil -} - -func leaveIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) error { - err := leaveIPv4Group(c.fd, ifi, ip) - if err != nil { - return &OpError{"leaveipv4group", c.fd.net, &IPAddr{ip}, err} - } - return nil -} - -func joinIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) error { - err := joinIPv6Group(c.fd, ifi, ip) - if err != nil { - return &OpError{"joinipv6group", c.fd.net, &IPAddr{ip}, err} - } - return nil -} - -func leaveIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) error { - err := leaveIPv6Group(c.fd, ifi, ip) - if err != nil { - return &OpError{"leaveipv6group", c.fd.net, &IPAddr{ip}, err} - } - return nil -} diff --git a/src/pkg/net/unicast_test.go b/src/pkg/net/unicast_posix_test.go index e5dd013db..a8855cab7 100644 --- a/src/pkg/net/unicast_test.go +++ b/src/pkg/net/unicast_posix_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// +build !plan9 + package net import ( @@ -44,8 +46,7 @@ var listenerTests = []struct { func TestTCPListener(t *testing.T) { switch runtime.GOOS { case "plan9", "windows": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } for _, tt := range listenerTests { @@ -59,13 +60,6 @@ func TestTCPListener(t *testing.T) { checkFirstListener(t, tt.net, tt.laddr+":"+port, l1) l2, err := Listen(tt.net, tt.laddr+":"+port) checkSecondListener(t, tt.net, tt.laddr+":"+port, err, l2) - fd := l1.(*TCPListener).fd - switch fd.family { - case syscall.AF_INET: - testIPv4UnicastSocketOptions(t, fd) - case syscall.AF_INET6: - testIPv6UnicastSocketOptions(t, fd) - } l1.Close() } } @@ -76,8 +70,7 @@ func TestTCPListener(t *testing.T) { func TestUDPListener(t *testing.T) { switch runtime.GOOS { case "plan9", "windows": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } toudpnet := func(net string) string { @@ -104,13 +97,6 @@ func TestUDPListener(t *testing.T) { checkFirstListener(t, tt.net, tt.laddr+":"+port, l1) l2, err := ListenPacket(tt.net, tt.laddr+":"+port) checkSecondListener(t, tt.net, tt.laddr+":"+port, err, l2) - fd := l1.(*UDPConn).fd - switch fd.family { - case syscall.AF_INET: - testIPv4UnicastSocketOptions(t, fd) - case syscall.AF_INET6: - testIPv6UnicastSocketOptions(t, fd) - } l1.Close() } } @@ -118,7 +104,7 @@ func TestUDPListener(t *testing.T) { func TestSimpleTCPListener(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Logf("skipping test on %q", runtime.GOOS) + t.Skipf("skipping test on %q", runtime.GOOS) return } @@ -140,7 +126,7 @@ func TestSimpleTCPListener(t *testing.T) { func TestSimpleUDPListener(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Logf("skipping test on %q", runtime.GOOS) + t.Skipf("skipping test on %q", runtime.GOOS) return } @@ -183,9 +169,9 @@ var dualStackListenerTests = []struct { // Test cases and expected results for the attemping 2nd listen on the same port // 1st listen 2nd listen darwin freebsd linux openbsd // ------------------------------------------------------------------------------------ - // "tcp" "" "tcp" "" - - - - - // "tcp" "" "tcp" "0.0.0.0" - - - - - // "tcp" "0.0.0.0" "tcp" "" - - - - + // "tcp" "" "tcp" "" - - - - + // "tcp" "" "tcp" "0.0.0.0" - - - - + // "tcp" "0.0.0.0" "tcp" "" - - - - // ------------------------------------------------------------------------------------ // "tcp" "" "tcp" "[::]" - - - ok // "tcp" "[::]" "tcp" "" - - - ok @@ -242,8 +228,7 @@ var dualStackListenerTests = []struct { func TestDualStackTCPListener(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } if !supportsIPv6 { return @@ -275,8 +260,7 @@ func TestDualStackTCPListener(t *testing.T) { func TestDualStackUDPListener(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } if !supportsIPv6 { return @@ -468,44 +452,6 @@ func checkDualStackAddrFamily(t *testing.T, net, laddr string, fd *netFD) { } } -func testIPv4UnicastSocketOptions(t *testing.T, fd *netFD) { - _, err := ipv4TOS(fd) - if err != nil { - t.Fatalf("ipv4TOS failed: %v", err) - } - err = setIPv4TOS(fd, 1) - if err != nil { - t.Fatalf("setIPv4TOS failed: %v", err) - } - _, err = ipv4TTL(fd) - if err != nil { - t.Fatalf("ipv4TTL failed: %v", err) - } - err = setIPv4TTL(fd, 1) - if err != nil { - t.Fatalf("setIPv4TTL failed: %v", err) - } -} - -func testIPv6UnicastSocketOptions(t *testing.T, fd *netFD) { - _, err := ipv6TrafficClass(fd) - if err != nil { - t.Fatalf("ipv6TrafficClass failed: %v", err) - } - err = setIPv6TrafficClass(fd, 1) - if err != nil { - t.Fatalf("setIPv6TrafficClass failed: %v", err) - } - _, err = ipv6HopLimit(fd) - if err != nil { - t.Fatalf("ipv6HopLimit failed: %v", err) - } - err = setIPv6HopLimit(fd, 1) - if err != nil { - t.Fatalf("setIPv6HopLimit failed: %v", err) - } -} - var prohibitionaryDialArgTests = []struct { net string addr string @@ -517,8 +463,7 @@ var prohibitionaryDialArgTests = []struct { func TestProhibitionaryDialArgs(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Logf("skipping test on %q", runtime.GOOS) - return + t.Skipf("skipping test on %q", runtime.GOOS) } // This test requires both IPv6 and IPv6 IPv4-mapping functionality. if !supportsIPv4map || testing.Short() || !*testExternal { @@ -536,3 +481,36 @@ func TestProhibitionaryDialArgs(t *testing.T) { } } } + +func TestWildWildcardListener(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + + defer func() { + if recover() != nil { + t.Fatalf("panicked") + } + }() + + if ln, err := Listen("tcp", ""); err == nil { + ln.Close() + } + if ln, err := ListenPacket("udp", ""); err == nil { + ln.Close() + } + if ln, err := ListenTCP("tcp", nil); err == nil { + ln.Close() + } + if ln, err := ListenUDP("udp", nil); err == nil { + ln.Close() + } + if ln, err := ListenIP("ip:icmp", nil); err == nil { + ln.Close() + } +} diff --git a/src/pkg/net/unix_test.go b/src/pkg/net/unix_test.go new file mode 100644 index 000000000..2eaabe86e --- /dev/null +++ b/src/pkg/net/unix_test.go @@ -0,0 +1,144 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9,!windows + +package net + +import ( + "bytes" + "os" + "reflect" + "runtime" + "syscall" + "testing" + "time" +) + +func TestReadUnixgramWithUnnamedSocket(t *testing.T) { + addr := testUnixAddr() + la, err := ResolveUnixAddr("unixgram", addr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c, err := ListenUnixgram("unixgram", la) + if err != nil { + t.Fatalf("ListenUnixgram failed: %v", err) + } + defer func() { + c.Close() + os.Remove(addr) + }() + + off := make(chan bool) + data := [5]byte{1, 2, 3, 4, 5} + + go func() { + defer func() { off <- true }() + s, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) + if err != nil { + t.Errorf("syscall.Socket failed: %v", err) + return + } + defer syscall.Close(s) + rsa := &syscall.SockaddrUnix{Name: addr} + if err := syscall.Sendto(s, data[:], 0, rsa); err != nil { + t.Errorf("syscall.Sendto failed: %v", err) + return + } + }() + + <-off + b := make([]byte, 64) + c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + n, from, err := c.ReadFrom(b) + if err != nil { + t.Errorf("UnixConn.ReadFrom failed: %v", err) + return + } + if from != nil { + t.Errorf("neighbor address is %v", from) + } + if !bytes.Equal(b[:n], data[:]) { + t.Errorf("got %v, want %v", b[:n], data[:]) + return + } +} + +func TestReadUnixgramWithZeroBytesBuffer(t *testing.T) { + // issue 4352: Recvfrom failed with "address family not + // supported by protocol family" if zero-length buffer provided + + addr := testUnixAddr() + la, err := ResolveUnixAddr("unixgram", addr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c, err := ListenUnixgram("unixgram", la) + if err != nil { + t.Fatalf("ListenUnixgram failed: %v", err) + } + defer func() { + c.Close() + os.Remove(addr) + }() + + off := make(chan bool) + go func() { + defer func() { off <- true }() + c, err := DialUnix("unixgram", nil, la) + if err != nil { + t.Errorf("DialUnix failed: %v", err) + return + } + defer c.Close() + if _, err := c.Write([]byte{1, 2, 3, 4, 5}); err != nil { + t.Errorf("UnixConn.Write failed: %v", err) + return + } + }() + + <-off + c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + var peer Addr + if _, peer, err = c.ReadFrom(nil); err != nil { + t.Errorf("UnixConn.ReadFrom failed: %v", err) + return + } + if peer != nil { + t.Errorf("peer adddress is %v", peer) + } +} + +func TestUnixAutobind(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("skipping: autobind is linux only") + } + + laddr := &UnixAddr{Name: "", Net: "unixgram"} + c1, err := ListenUnixgram("unixgram", laddr) + if err != nil { + t.Fatalf("ListenUnixgram failed: %v", err) + } + defer c1.Close() + + // retrieve the autobind address + autoAddr := c1.LocalAddr().(*UnixAddr) + if len(autoAddr.Name) <= 1 { + t.Fatalf("Invalid autobind address: %v", autoAddr) + } + if autoAddr.Name[0] != '@' { + t.Fatalf("Invalid autobind address: %v", autoAddr) + } + + c2, err := DialUnix("unixgram", nil, autoAddr) + if err != nil { + t.Fatalf("DialUnix failed: %v", err) + } + defer c2.Close() + + if !reflect.DeepEqual(c1.LocalAddr(), c2.RemoteAddr()) { + t.Fatalf("Expected autobind address %v, got %v", c1.LocalAddr(), c2.RemoteAddr()) + } +} diff --git a/src/pkg/net/unixsock_plan9.go b/src/pkg/net/unixsock_plan9.go index 7b4ae6bd1..00a0be5b0 100644 --- a/src/pkg/net/unixsock_plan9.go +++ b/src/pkg/net/unixsock_plan9.go @@ -7,100 +7,135 @@ package net import ( + "os" "syscall" "time" ) -// UnixConn is an implementation of the Conn interface -// for connections to Unix domain sockets. -type UnixConn bool - -// Implementation of the Conn interface - see Conn for documentation. +// UnixConn is an implementation of the Conn interface for connections +// to Unix domain sockets. +type UnixConn struct { + conn +} -// Read implements the Conn Read method. -func (c *UnixConn) Read(b []byte) (n int, err error) { - return 0, syscall.EPLAN9 +// ReadFromUnix reads a packet from c, copying the payload into b. It +// returns the number of bytes copied into b and the source address of +// the packet. +// +// ReadFromUnix can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetReadDeadline. +func (c *UnixConn) ReadFromUnix(b []byte) (int, *UnixAddr, error) { + return 0, nil, syscall.EPLAN9 } -// Write implements the Conn Write method. -func (c *UnixConn) Write(b []byte) (n int, err error) { - return 0, syscall.EPLAN9 +// ReadFrom implements the PacketConn ReadFrom method. +func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) { + return 0, nil, syscall.EPLAN9 } -// Close closes the Unix domain connection. -func (c *UnixConn) Close() error { - return syscall.EPLAN9 +// ReadMsgUnix reads a packet from c, copying the payload into b and +// the associated out-of-band data into oob. It returns the number of +// bytes copied into b, the number of bytes copied into oob, the flags +// that were set on the packet, and the source address of the packet. +func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) { + return 0, 0, 0, nil, syscall.EPLAN9 } -// LocalAddr returns the local network address, a *UnixAddr. -// Unlike in other protocols, LocalAddr is usually nil for dialed connections. -func (c *UnixConn) LocalAddr() Addr { - return nil +// WriteToUnix writes a packet to addr via c, copying the payload from b. +// +// WriteToUnix can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetWriteDeadline. On packet-oriented connections, write timeouts +// are rare. +func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (int, error) { + return 0, syscall.EPLAN9 } -// RemoteAddr returns the remote network address, a *UnixAddr. -// Unlike in other protocols, RemoteAddr is usually nil for connections -// accepted by a listener. -func (c *UnixConn) RemoteAddr() Addr { - return nil +// WriteTo implements the PacketConn WriteTo method. +func (c *UnixConn) WriteTo(b []byte, addr Addr) (int, error) { + return 0, syscall.EPLAN9 } -// SetDeadline implements the Conn SetDeadline method. -func (c *UnixConn) SetDeadline(t time.Time) error { - return syscall.EPLAN9 +// WriteMsgUnix writes a packet to addr via c, copying the payload +// from b and the associated out-of-band data from oob. It returns +// the number of payload and out-of-band bytes written. +func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) { + return 0, 0, syscall.EPLAN9 } -// SetReadDeadline implements the Conn SetReadDeadline method. -func (c *UnixConn) SetReadDeadline(t time.Time) error { +// CloseRead shuts down the reading side of the Unix domain connection. +// Most callers should just use Close. +func (c *UnixConn) CloseRead() error { return syscall.EPLAN9 } -// SetWriteDeadline implements the Conn SetWriteDeadline method. -func (c *UnixConn) SetWriteDeadline(t time.Time) error { +// CloseWrite shuts down the writing side of the Unix domain connection. +// Most callers should just use Close. +func (c *UnixConn) CloseWrite() error { return syscall.EPLAN9 } -// ReadFrom implements the PacketConn ReadFrom method. -func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err error) { - err = syscall.EPLAN9 - return +// DialUnix connects to the remote address raddr on the network net, +// which must be "unix", "unixgram" or "unixpacket". If laddr is not +// nil, it is used as the local address for the connection. +func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) { + return dialUnix(net, laddr, raddr, noDeadline) } -// WriteTo implements the PacketConn WriteTo method. -func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err error) { - err = syscall.EPLAN9 - return +func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) { + return nil, syscall.EPLAN9 } -// DialUnix connects to the remote address raddr on the network net, -// which must be "unix" or "unixgram". If laddr is not nil, it is used -// as the local address for the connection. -func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err error) { +// UnixListener is a Unix domain socket listener. Clients should +// typically use variables of type Listener instead of assuming Unix +// domain sockets. +type UnixListener struct{} + +// ListenUnix announces on the Unix domain socket laddr and returns a +// Unix listener. The network net must be "unix" or "unixpacket". +func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { return nil, syscall.EPLAN9 } -// UnixListener is a Unix domain socket listener. -// Clients should typically use variables of type Listener -// instead of assuming Unix domain sockets. -type UnixListener bool - -// ListenUnix announces on the Unix domain socket laddr and returns a Unix listener. -// Net must be "unix" (stream sockets). -func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err error) { +// AcceptUnix accepts the next incoming call and returns the new +// connection and the remote address. +func (l *UnixListener) AcceptUnix() (*UnixConn, error) { return nil, syscall.EPLAN9 } -// Accept implements the Accept method in the Listener interface; -// it waits for the next call and returns a generic Conn. -func (l *UnixListener) Accept() (c Conn, err error) { +// Accept implements the Accept method in the Listener interface; it +// waits for the next call and returns a generic Conn. +func (l *UnixListener) Accept() (Conn, error) { return nil, syscall.EPLAN9 } -// Close stops listening on the Unix address. -// Already accepted connections are not closed. +// Close stops listening on the Unix address. Already accepted +// connections are not closed. func (l *UnixListener) Close() error { return syscall.EPLAN9 } // Addr returns the listener's network address. func (l *UnixListener) Addr() Addr { return nil } + +// SetDeadline sets the deadline associated with the listener. +// A zero time value disables the deadline. +func (l *UnixListener) SetDeadline(t time.Time) error { + return syscall.EPLAN9 +} + +// File returns a copy of the underlying os.File, set to blocking +// mode. It is the caller's responsibility to close f when finished. +// Closing l does not affect f, and closing f does not affect l. +func (l *UnixListener) File() (*os.File, error) { + return nil, syscall.EPLAN9 +} + +// ListenUnixgram listens for incoming Unix datagram packets addressed +// to the local address laddr. The returned connection c's ReadFrom +// and WriteTo methods can be used to receive and send packets with +// per-packet addressing. The network net must be "unixgram". +func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) { + return nil, syscall.EPLAN9 +} diff --git a/src/pkg/net/unixsock_posix.go b/src/pkg/net/unixsock_posix.go index 57d784c71..6d6ce3f5e 100644 --- a/src/pkg/net/unixsock_posix.go +++ b/src/pkg/net/unixsock_posix.go @@ -9,29 +9,27 @@ package net import ( + "errors" "os" "syscall" "time" ) -func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err error) { +func unixSocket(net string, laddr, raddr *UnixAddr, mode string, deadline time.Time) (*netFD, error) { var sotype int switch net { - default: - return nil, UnknownNetworkError(net) case "unix": sotype = syscall.SOCK_STREAM case "unixgram": sotype = syscall.SOCK_DGRAM case "unixpacket": sotype = syscall.SOCK_SEQPACKET + default: + return nil, UnknownNetworkError(net) } var la, ra syscall.Sockaddr switch mode { - default: - panic("unixSocket mode " + mode) - case "dial": if laddr != nil { la = &syscall.SockaddrUnix{Name: laddr.Name} @@ -41,15 +39,10 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err } else if sotype != syscall.SOCK_DGRAM || laddr == nil { return nil, &OpError{Op: mode, Net: net, Err: errMissingAddress} } - case "listen": - if laddr == nil { - return nil, &OpError{mode, net, nil, errMissingAddress} - } la = &syscall.SockaddrUnix{Name: laddr.Name} - if raddr != nil { - return nil, &OpError{Op: mode, Net: net, Addr: raddr, Err: &AddrError{Err: "unexpected remote address", Addr: raddr.String()}} - } + default: + return nil, errors.New("unknown mode: " + mode) } f := sockaddrToUnix @@ -59,15 +52,16 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err f = sockaddrToUnixpacket } - fd, err = socket(net, syscall.AF_UNIX, sotype, 0, false, la, ra, f) + fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, la, ra, deadline, f) if err != nil { - goto Error + goto error } return fd, nil -Error: +error: addr := raddr - if mode == "listen" { + switch mode { + case "listen": addr = laddr } return nil, &OpError{Op: mode, Net: net, Addr: addr, Err: err} @@ -108,110 +102,21 @@ func sotypeToNet(sotype int) string { return "" } -// UnixConn is an implementation of the Conn interface -// for connections to Unix domain sockets. +// UnixConn is an implementation of the Conn interface for connections +// to Unix domain sockets. type UnixConn struct { - fd *netFD -} - -func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{fd} } - -func (c *UnixConn) ok() bool { return c != nil && c.fd != nil } - -// Implementation of the Conn interface - see Conn for documentation. - -// Read implements the Conn Read method. -func (c *UnixConn) Read(b []byte) (n int, err error) { - if !c.ok() { - return 0, syscall.EINVAL - } - return c.fd.Read(b) -} - -// Write implements the Conn Write method. -func (c *UnixConn) Write(b []byte) (n int, err error) { - if !c.ok() { - return 0, syscall.EINVAL - } - return c.fd.Write(b) -} - -// Close closes the Unix domain connection. -func (c *UnixConn) Close() error { - if !c.ok() { - return syscall.EINVAL - } - return c.fd.Close() -} - -// LocalAddr returns the local network address, a *UnixAddr. -// Unlike in other protocols, LocalAddr is usually nil for dialed connections. -func (c *UnixConn) LocalAddr() Addr { - if !c.ok() { - return nil - } - return c.fd.laddr + conn } -// RemoteAddr returns the remote network address, a *UnixAddr. -// Unlike in other protocols, RemoteAddr is usually nil for connections -// accepted by a listener. -func (c *UnixConn) RemoteAddr() Addr { - if !c.ok() { - return nil - } - return c.fd.raddr -} +func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} } -// SetDeadline implements the Conn SetDeadline method. -func (c *UnixConn) SetDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setDeadline(c.fd, t) -} - -// SetReadDeadline implements the Conn SetReadDeadline method. -func (c *UnixConn) SetReadDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setReadDeadline(c.fd, t) -} - -// SetWriteDeadline implements the Conn SetWriteDeadline method. -func (c *UnixConn) SetWriteDeadline(t time.Time) error { - if !c.ok() { - return syscall.EINVAL - } - return setWriteDeadline(c.fd, t) -} - -// SetReadBuffer sets the size of the operating system's -// receive buffer associated with the connection. -func (c *UnixConn) SetReadBuffer(bytes int) error { - if !c.ok() { - return syscall.EINVAL - } - return setReadBuffer(c.fd, bytes) -} - -// SetWriteBuffer sets the size of the operating system's -// transmit buffer associated with the connection. -func (c *UnixConn) SetWriteBuffer(bytes int) error { - if !c.ok() { - return syscall.EINVAL - } - return setWriteBuffer(c.fd, bytes) -} - -// ReadFromUnix reads a packet from c, copying the payload into b. -// It returns the number of bytes copied into b and the source address -// of the packet. +// ReadFromUnix reads a packet from c, copying the payload into b. It +// returns the number of bytes copied into b and the source address of +// the packet. // -// ReadFromUnix can be made to time out and return -// an error with Timeout() == true after a fixed time limit; -// see SetDeadline and SetReadDeadline. +// ReadFromUnix can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetReadDeadline. func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err error) { if !c.ok() { return 0, nil, syscall.EINVAL @@ -219,26 +124,46 @@ func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err error) { n, sa, err := c.fd.ReadFrom(b) switch sa := sa.(type) { case *syscall.SockaddrUnix: - addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)} + if sa.Name != "" { + addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)} + } } return } // ReadFrom implements the PacketConn ReadFrom method. -func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err error) { +func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) { if !c.ok() { return 0, nil, syscall.EINVAL } - n, uaddr, err := c.ReadFromUnix(b) - return n, uaddr.toAddr(), err + n, addr, err := c.ReadFromUnix(b) + return n, addr.toAddr(), err +} + +// ReadMsgUnix reads a packet from c, copying the payload into b and +// the associated out-of-band data into oob. It returns the number of +// bytes copied into b, the number of bytes copied into oob, the flags +// that were set on the packet, and the source address of the packet. +func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) { + if !c.ok() { + return 0, 0, 0, nil, syscall.EINVAL + } + n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob) + switch sa := sa.(type) { + case *syscall.SockaddrUnix: + if sa.Name != "" { + addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)} + } + } + return } // WriteToUnix writes a packet to addr via c, copying the payload from b. // -// WriteToUnix can be made to time out and return -// an error with Timeout() == true after a fixed time limit; -// see SetDeadline and SetWriteDeadline. -// On packet-oriented connections, write timeouts are rare. +// WriteToUnix can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetWriteDeadline. On packet-oriented connections, write timeouts +// are rare. func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err error) { if !c.ok() { return 0, syscall.EINVAL @@ -262,26 +187,9 @@ func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err error) { return c.WriteToUnix(b, a) } -// ReadMsgUnix reads a packet from c, copying the payload into b -// and the associated out-of-band data into oob. -// It returns the number of bytes copied into b, the number of -// bytes copied into oob, the flags that were set on the packet, -// and the source address of the packet. -func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) { - if !c.ok() { - return 0, 0, 0, nil, syscall.EINVAL - } - n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob) - switch sa := sa.(type) { - case *syscall.SockaddrUnix: - addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)} - } - return -} - -// WriteMsgUnix writes a packet to addr via c, copying the payload from b -// and the associated out-of-band data from oob. It returns the number -// of payload and out-of-band bytes written. +// WriteMsgUnix writes a packet to addr via c, copying the payload +// from b and the associated out-of-band data from oob. It returns +// the number of payload and out-of-band bytes written. func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) { if !c.ok() { return 0, 0, syscall.EINVAL @@ -296,40 +204,64 @@ func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err return c.fd.WriteMsg(b, oob, nil) } -// File returns a copy of the underlying os.File, set to blocking mode. -// It is the caller's responsibility to close f when finished. -// Closing c does not affect f, and closing f does not affect c. -func (c *UnixConn) File() (f *os.File, err error) { return c.fd.dup() } +// CloseRead shuts down the reading side of the Unix domain connection. +// Most callers should just use Close. +func (c *UnixConn) CloseRead() error { + if !c.ok() { + return syscall.EINVAL + } + return c.fd.CloseRead() +} + +// CloseWrite shuts down the writing side of the Unix domain connection. +// Most callers should just use Close. +func (c *UnixConn) CloseWrite() error { + if !c.ok() { + return syscall.EINVAL + } + return c.fd.CloseWrite() +} // DialUnix connects to the remote address raddr on the network net, -// which must be "unix" or "unixgram". If laddr is not nil, it is used -// as the local address for the connection. +// which must be "unix", "unixgram" or "unixpacket". If laddr is not +// nil, it is used as the local address for the connection. func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) { - fd, err := unixSocket(net, laddr, raddr, "dial") + return dialUnix(net, laddr, raddr, noDeadline) +} + +func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) { + switch net { + case "unix", "unixgram", "unixpacket": + default: + return nil, UnknownNetworkError(net) + } + fd, err := unixSocket(net, laddr, raddr, "dial", deadline) if err != nil { return nil, err } return newUnixConn(fd), nil } -// UnixListener is a Unix domain socket listener. -// Clients should typically use variables of type Listener -// instead of assuming Unix domain sockets. +// UnixListener is a Unix domain socket listener. Clients should +// typically use variables of type Listener instead of assuming Unix +// domain sockets. type UnixListener struct { fd *netFD path string } -// ListenUnix announces on the Unix domain socket laddr and returns a Unix listener. -// Net must be "unix" (stream sockets). +// ListenUnix announces on the Unix domain socket laddr and returns a +// Unix listener. The network net must be "unix" or "unixpacket". func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { - if net != "unix" && net != "unixgram" && net != "unixpacket" { + switch net { + case "unix", "unixpacket": + default: return nil, UnknownNetworkError(net) } - if laddr != nil { - laddr = &UnixAddr{laddr.Name, net} // make our own copy + if laddr == nil { + return nil, &OpError{"listen", net, nil, errMissingAddress} } - fd, err := unixSocket(net, laddr, nil, "listen") + fd, err := unixSocket(net, laddr, nil, "listen", noDeadline) if err != nil { return nil, err } @@ -341,8 +273,8 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { return &UnixListener{fd, laddr.Name}, nil } -// AcceptUnix accepts the next incoming call and returns the new connection -// and the remote address. +// AcceptUnix accepts the next incoming call and returns the new +// connection and the remote address. func (l *UnixListener) AcceptUnix() (*UnixConn, error) { if l == nil || l.fd == nil { return nil, syscall.EINVAL @@ -355,8 +287,8 @@ func (l *UnixListener) AcceptUnix() (*UnixConn, error) { return c, nil } -// Accept implements the Accept method in the Listener interface; -// it waits for the next call and returns a generic Conn. +// Accept implements the Accept method in the Listener interface; it +// waits for the next call and returns a generic Conn. func (l *UnixListener) Accept() (c Conn, err error) { c1, err := l.AcceptUnix() if err != nil { @@ -365,8 +297,8 @@ func (l *UnixListener) Accept() (c Conn, err error) { return c1, nil } -// Close stops listening on the Unix address. -// Already accepted connections are not closed. +// Close stops listening on the Unix address. Already accepted +// connections are not closed. func (l *UnixListener) Close() error { if l == nil || l.fd == nil { return syscall.EINVAL @@ -385,9 +317,7 @@ func (l *UnixListener) Close() error { if l.path[0] != '@' { syscall.Unlink(l.path) } - err := l.fd.Close() - l.fd = nil - return err + return l.fd.Close() } // Addr returns the listener's network address. @@ -402,16 +332,16 @@ func (l *UnixListener) SetDeadline(t time.Time) (err error) { return setDeadline(l.fd, t) } -// File returns a copy of the underlying os.File, set to blocking mode. -// It is the caller's responsibility to close f when finished. +// File returns a copy of the underlying os.File, set to blocking +// mode. It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. func (l *UnixListener) File() (f *os.File, err error) { return l.fd.dup() } -// ListenUnixgram listens for incoming Unix datagram packets addressed to the -// local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send UDP -// packets with per-packet addressing. The network net must be "unixgram". -func ListenUnixgram(net string, laddr *UnixAddr) (*UDPConn, error) { +// ListenUnixgram listens for incoming Unix datagram packets addressed +// to the local address laddr. The returned connection c's ReadFrom +// and WriteTo methods can be used to receive and send packets with +// per-packet addressing. The network net must be "unixgram". +func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) { switch net { case "unixgram": default: @@ -420,9 +350,9 @@ func ListenUnixgram(net string, laddr *UnixAddr) (*UDPConn, error) { if laddr == nil { return nil, &OpError{"listen", net, nil, errMissingAddress} } - fd, err := unixSocket(net, laddr, nil, "listen") + fd, err := unixSocket(net, laddr, nil, "listen", noDeadline) if err != nil { return nil, err } - return newUDPConn(fd), nil + return newUnixConn(fd), nil } diff --git a/src/pkg/net/url/url.go b/src/pkg/net/url/url.go index 17bf0d3a3..a39964ea1 100644 --- a/src/pkg/net/url/url.go +++ b/src/pkg/net/url/url.go @@ -7,7 +7,9 @@ package url import ( + "bytes" "errors" + "sort" "strconv" "strings" ) @@ -218,11 +220,18 @@ func escape(s string, mode encoding) string { // // scheme:opaque[?query][#fragment] // +// Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/. +// A consequence is that it is impossible to tell which slashes in the Path were +// slashes in the raw URL and which were %2f. This distinction is rarely important, +// but when it is a client must use other routines to parse the raw URL or construct +// the parsed URL. For example, an HTTP server can consult req.RequestURI, and +// an HTTP client can use URL{Host: "example.com", Opaque: "//example.com/Go%2f"} +// instead of URL{Host: "example.com", Path: "/Go/"}. type URL struct { Scheme string Opaque string // encoded opaque data User *Userinfo // username and password information - Host string + Host string // host or host:port Path string RawQuery string // encoded query values, without '?' Fragment string // fragment for references, without '#' @@ -359,11 +368,17 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) { } url = new(URL) + if rawurl == "*" { + url.Path = "*" + return + } + // Split off possible leading "http:", "mailto:", etc. // Cannot contain escaped characters. if url.Scheme, rest, err = getscheme(rawurl); err != nil { goto Error } + url.Scheme = strings.ToLower(url.Scheme) rest, url.RawQuery = split(rest, '?', true) @@ -379,7 +394,7 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) { } } - if (url.Scheme != "" || !viaRequest) && strings.HasPrefix(rest, "//") && !strings.HasPrefix(rest, "///") { + if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") { var authority string authority, rest = split(rest[2:], '/', false) url.User, url.Host, err = parseAuthority(authority) @@ -427,30 +442,35 @@ func parseAuthority(authority string) (user *Userinfo, host string, err error) { // String reassembles the URL into a valid URL string. func (u *URL) String() string { - // TODO: Rewrite to use bytes.Buffer - result := "" + var buf bytes.Buffer if u.Scheme != "" { - result += u.Scheme + ":" + buf.WriteString(u.Scheme) + buf.WriteByte(':') } if u.Opaque != "" { - result += u.Opaque + buf.WriteString(u.Opaque) } else { - if u.Host != "" || u.User != nil { - result += "//" + if u.Scheme != "" || u.Host != "" || u.User != nil { + buf.WriteString("//") if u := u.User; u != nil { - result += u.String() + "@" + buf.WriteString(u.String()) + buf.WriteByte('@') + } + if h := u.Host; h != "" { + buf.WriteString(h) } - result += u.Host } - result += escape(u.Path, encodePath) + buf.WriteString(escape(u.Path, encodePath)) } if u.RawQuery != "" { - result += "?" + u.RawQuery + buf.WriteByte('?') + buf.WriteString(u.RawQuery) } if u.Fragment != "" { - result += "#" + escape(u.Fragment, encodeFragment) + buf.WriteByte('#') + buf.WriteString(escape(u.Fragment, encodeFragment)) } - return result + return buf.String() } // Values maps a string key to a list of values. @@ -519,12 +539,16 @@ func parseQuery(m Values, query string) (err error) { } key, err1 := QueryUnescape(key) if err1 != nil { - err = err1 + if err == nil { + err = err1 + } continue } value, err1 = QueryUnescape(value) if err1 != nil { - err = err1 + if err == nil { + err = err1 + } continue } m[key] = append(m[key], value) @@ -538,14 +562,24 @@ func (v Values) Encode() string { if v == nil { return "" } - parts := make([]string, 0, len(v)) // will be large enough for most uses - for k, vs := range v { + var buf bytes.Buffer + keys := make([]string, 0, len(v)) + for k := range v { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + vs := v[k] prefix := QueryEscape(k) + "=" for _, v := range vs { - parts = append(parts, prefix+QueryEscape(v)) + if buf.Len() > 0 { + buf.WriteByte('&') + } + buf.WriteString(prefix) + buf.WriteString(QueryEscape(v)) } } - return strings.Join(parts, "&") + return buf.String() } // resolvePath applies special path segments from refs and applies @@ -556,23 +590,33 @@ func resolvePath(basepath string, refpath string) string { if len(base) == 0 { base = []string{""} } + + rm := true for idx, ref := range refs { switch { case ref == ".": - base[len(base)-1] = "" + if idx == 0 { + base[len(base)-1] = "" + rm = true + } else { + rm = false + } case ref == "..": newLen := len(base) - 1 if newLen < 1 { newLen = 1 } base = base[0:newLen] - base[len(base)-1] = "" + if rm { + base[len(base)-1] = "" + } default: if idx == 0 || base[len(base)-1] == "" { base[len(base)-1] = ref } else { base = append(base, ref) } + rm = false } } return strings.Join(base, "/") @@ -650,6 +694,10 @@ func (u *URL) RequestURI() string { if result == "" { result = "/" } + } else { + if strings.HasPrefix(result, "//") { + result = u.Scheme + ":" + result + } } if u.RawQuery != "" { result += "?" + u.RawQuery diff --git a/src/pkg/net/url/url_test.go b/src/pkg/net/url/url_test.go index 75e8abe4e..4c4f406c2 100644 --- a/src/pkg/net/url/url_test.go +++ b/src/pkg/net/url/url_test.go @@ -7,6 +7,7 @@ package url import ( "fmt" "reflect" + "strings" "testing" ) @@ -121,14 +122,14 @@ var urltests = []URLTest{ }, "http:%2f%2fwww.google.com/?q=go+language", }, - // non-authority + // non-authority with path { "mailto:/webmaster@golang.org", &URL{ Scheme: "mailto", Path: "/webmaster@golang.org", }, - "", + "mailto:///webmaster@golang.org", // unfortunate compromise }, // non-authority { @@ -241,6 +242,24 @@ var urltests = []URLTest{ }, "http://www.google.com/?q=go+language#foo&bar", }, + { + "file:///home/adg/rabbits", + &URL{ + Scheme: "file", + Host: "", + Path: "/home/adg/rabbits", + }, + "file:///home/adg/rabbits", + }, + // case-insensitive scheme + { + "MaIlTo:webmaster@golang.org", + &URL{ + Scheme: "mailto", + Opaque: "webmaster@golang.org", + }, + "mailto:webmaster@golang.org", + }, } // more useful string for debugging than fmt's struct printer @@ -270,13 +289,37 @@ func DoTest(t *testing.T, parse func(string) (*URL, error), name string, tests [ } } +func BenchmarkString(b *testing.B) { + b.StopTimer() + b.ReportAllocs() + for _, tt := range urltests { + u, err := Parse(tt.in) + if err != nil { + b.Errorf("Parse(%q) returned error %s", tt.in, err) + continue + } + if tt.roundtrip == "" { + continue + } + b.StartTimer() + var g string + for i := 0; i < b.N; i++ { + g = u.String() + } + b.StopTimer() + if w := tt.roundtrip; g != w { + b.Errorf("Parse(%q).String() == %q, want %q", tt.in, g, w) + } + } +} + func TestParse(t *testing.T) { DoTest(t, Parse, "Parse", urltests) } const pathThatLooksSchemeRelative = "//not.a.user@not.a.host/just/a/path" -var parseRequestUrlTests = []struct { +var parseRequestURLTests = []struct { url string expectedValid bool }{ @@ -288,10 +331,11 @@ var parseRequestUrlTests = []struct { {"//not.a.user@%66%6f%6f.com/just/a/path/also", true}, {"foo.html", false}, {"../dir/", false}, + {"*", true}, } func TestParseRequestURI(t *testing.T) { - for _, test := range parseRequestUrlTests { + for _, test := range parseRequestURLTests { _, err := ParseRequestURI(test.url) valid := err == nil if valid != test.expectedValid { @@ -453,20 +497,24 @@ func TestEscape(t *testing.T) { //} type EncodeQueryTest struct { - m Values - expected string - expected1 string + m Values + expected string } var encodeQueryTests = []EncodeQueryTest{ - {nil, "", ""}, - {Values{"q": {"puppies"}, "oe": {"utf8"}}, "q=puppies&oe=utf8", "oe=utf8&q=puppies"}, - {Values{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7", "q=dogs&q=%26&q=7"}, + {nil, ""}, + {Values{"q": {"puppies"}, "oe": {"utf8"}}, "oe=utf8&q=puppies"}, + {Values{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7"}, + {Values{ + "a": {"a1", "a2", "a3"}, + "b": {"b1", "b2", "b3"}, + "c": {"c1", "c2", "c3"}, + }, "a=a1&a=a2&a=a3&b=b1&b=b2&b=b3&c=c1&c=c2&c=c3"}, } func TestEncodeQuery(t *testing.T) { for _, tt := range encodeQueryTests { - if q := tt.m.Encode(); q != tt.expected && q != tt.expected1 { + if q := tt.m.Encode(); q != tt.expected { t.Errorf(`EncodeQuery(%+v) = %q, want %q`, tt.m, q, tt.expected) } } @@ -531,6 +579,15 @@ var resolveReferenceTests = []struct { {"http://foo.com/bar/baz", "../../../../../quux", "http://foo.com/quux"}, {"http://foo.com/bar", "..", "http://foo.com/"}, {"http://foo.com/bar/baz", "./..", "http://foo.com/"}, + // ".." in the middle (issue 3560) + {"http://foo.com/bar/baz", "quux/dotdot/../tail", "http://foo.com/bar/quux/tail"}, + {"http://foo.com/bar/baz", "quux/./dotdot/../tail", "http://foo.com/bar/quux/tail"}, + {"http://foo.com/bar/baz", "quux/./dotdot/.././tail", "http://foo.com/bar/quux/tail"}, + {"http://foo.com/bar/baz", "quux/./dotdot/./../tail", "http://foo.com/bar/quux/tail"}, + {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/././../../tail", "http://foo.com/bar/quux/tail"}, + {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/./.././../tail", "http://foo.com/bar/quux/tail"}, + {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/dotdot/./../../.././././tail", "http://foo.com/bar/quux/tail"}, + {"http://foo.com/bar/baz", "quux/./dotdot/../dotdot/../dot/./tail/..", "http://foo.com/bar/quux/dot"}, // "." and ".." in the base aren't special {"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/./dotdot/../baz"}, @@ -741,6 +798,24 @@ var requritests = []RequestURITest{ }, "/a%20b", }, + // golang.org/issue/4860 variant 1 + { + &URL{ + Scheme: "http", + Host: "example.com", + Opaque: "/%2F/%2F/", + }, + "/%2F/%2F/", + }, + // golang.org/issue/4860 variant 2 + { + &URL{ + Scheme: "http", + Host: "example.com", + Opaque: "//other.example.com/%2F/%2F/", + }, + "http://other.example.com/%2F/%2F/", + }, { &URL{ Scheme: "http", @@ -775,3 +850,13 @@ func TestRequestURI(t *testing.T) { } } } + +func TestParseFailure(t *testing.T) { + // Test that the first parse error is returned. + const url = "%gh&%ij" + _, err := ParseQuery(url) + errStr := fmt.Sprint(err) + if !strings.Contains(errStr, "%gh") { + t.Errorf(`ParseQuery(%q) returned error %q, want something containing %q"`, url, errStr, "%gh") + } +} |