diff options
Diffstat (limited to 'src/pkg/net')
94 files changed, 3692 insertions, 2015 deletions
diff --git a/src/pkg/net/conn_test.go b/src/pkg/net/conn_test.go index fdb90862f..98bd69549 100644 --- a/src/pkg/net/conn_test.go +++ b/src/pkg/net/conn_test.go @@ -16,11 +16,11 @@ import ( var connTests = []struct { net string - addr string + addr func() string }{ - {"tcp", "127.0.0.1:0"}, - {"unix", testUnixAddr()}, - {"unixpacket", testUnixAddr()}, + {"tcp", func() string { return "127.0.0.1:0" }}, + {"unix", testUnixAddr}, + {"unixpacket", testUnixAddr}, } // someTimeout is used just to test that net.Conn implementations @@ -41,7 +41,8 @@ func TestConnAndListener(t *testing.T) { } } - ln, err := Listen(tt.net, tt.addr) + addr := tt.addr() + ln, err := Listen(tt.net, addr) if err != nil { t.Fatalf("Listen failed: %v", err) } @@ -51,7 +52,7 @@ func TestConnAndListener(t *testing.T) { case "unix", "unixpacket": os.Remove(addr) } - }(ln, tt.net, tt.addr) + }(ln, tt.net, addr) ln.Addr() done := make(chan int) diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index 22e1e7dd8..b18d28362 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -9,78 +9,48 @@ import ( "time" ) -// A DialOption modifies a DialOpt call. -type DialOption interface { - dialOption() -} - -var ( - // TCP is a dial option to dial with TCP (over IPv4 or IPv6). - TCP = Network("tcp") - - // UDP is a dial option to dial with UDP (over IPv4 or IPv6). - UDP = Network("udp") -) - -// Network returns a DialOption to dial using the given network. -// -// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), -// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" -// (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and -// "unixpacket". -// -// For IP networks, net must be "ip", "ip4" or "ip6" followed -// by a colon and a protocol number or name, such as -// "ipv4:1" or "ip6:ospf". -func Network(net string) DialOption { - return dialNetwork(net) -} - -type dialNetwork string - -func (dialNetwork) dialOption() {} - -// Deadline returns a DialOption to fail a dial that doesn't -// complete before t. -func Deadline(t time.Time) DialOption { - return dialDeadline(t) -} - -// Timeout returns a DialOption to fail a dial that doesn't -// complete within the provided duration. -func Timeout(d time.Duration) DialOption { - return dialDeadline(time.Now().Add(d)) -} - -type dialDeadline time.Time - -func (dialDeadline) dialOption() {} - -type tcpFastOpen struct{} - -func (tcpFastOpen) dialOption() {} - -// TODO(bradfitz): implement this (golang.org/issue/4842) and unexport this. +// A Dialer contains options for connecting to an address. // -// TCPFastTimeout returns an option to use TCP Fast Open (TFO) when -// doing this dial. It is only valid for use with TCP connections. -// Data sent over a TFO connection may be processed by the peer -// multiple times, so should be used with caution. -func todo_TCPFastTimeout() DialOption { - return tcpFastOpen{} -} - -type localAddrOption struct { - la Addr -} - -func (localAddrOption) dialOption() {} - -// LocalAddress returns a dial option to perform a dial with the -// provided local address. The address must be of a compatible type -// for the network being dialed. -func LocalAddress(addr Addr) DialOption { - return localAddrOption{addr} +// The zero value for each field is equivalent to dialing +// without that option. Dialing with the zero value of Dialer +// is therefore equivalent to just calling the Dial function. +type Dialer struct { + // Timeout is the maximum amount of time a dial will wait for + // a connect to complete. If Deadline is also set, it may fail + // earlier. + // + // The default is no timeout. + // + // With or without a timeout, the operating system may impose + // its own earlier timeout. For instance, TCP timeouts are + // often around 3 minutes. + Timeout time.Duration + + // Deadline is the absolute point in time after which dials + // will fail. If Timeout is set, it may fail earlier. + // Zero means no deadline, or dependent on the operating system + // as with the Timeout option. + Deadline time.Time + + // LocalAddr is the local address to use when dialing an + // address. The address must be of a compatible type for the + // network being dialed. + // If nil, a local address is automatically chosen. + LocalAddr Addr +} + +// Return either now+Timeout or Deadline, whichever comes first. +// Or zero, if neither is set. +func (d *Dialer) deadline() time.Time { + if d.Timeout == 0 { + return d.Deadline + } + timeoutDeadline := time.Now().Add(d.Timeout) + if d.Deadline.IsZero() || timeoutDeadline.Before(d.Deadline) { + return timeoutDeadline + } else { + return d.Deadline + } } func parseNetwork(net string) (afnet string, proto int, err error) { @@ -127,7 +97,7 @@ func resolveAddr(op, net, addr string, deadline time.Time) (Addr, error) { return resolveInternetAddr(afnet, addr, deadline) } -// Dial connects to the address addr on the network net. +// Dial connects to the address on the named network. // // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), // "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" @@ -135,67 +105,45 @@ func resolveAddr(op, net, addr string, deadline time.Time) (Addr, error) { // "unixpacket". // // For TCP and UDP networks, addresses have the form host:port. -// If host is a literal IPv6 address, it must be enclosed -// in square brackets. The functions JoinHostPort and SplitHostPort -// manipulate addresses in this form. +// If host is a literal IPv6 address or host name, it must be enclosed +// in square brackets as in "[::1]:80", "[ipv6-host]:http" or +// "[ipv6-host%zone]:80". +// The functions JoinHostPort and SplitHostPort manipulate addresses +// in this form. // // Examples: // Dial("tcp", "12.34.56.78:80") -// Dial("tcp", "google.com:80") -// Dial("tcp", "[de:ad:be:ef::ca:fe]:80") +// Dial("tcp", "google.com:http") +// Dial("tcp", "[2001:db8::1]:http") +// Dial("tcp", "[fe80::1%lo0]:80") // -// For IP networks, net must be "ip", "ip4" or "ip6" followed -// by a colon and a protocol number or name. +// For IP networks, the network must be "ip", "ip4" or "ip6" followed +// by a colon and a protocol number or name and the addr must be a +// literal IP address. // // Examples: // Dial("ip4:1", "127.0.0.1") // Dial("ip6:ospf", "::1") // -func Dial(net, addr string) (Conn, error) { - return DialOpt(addr, dialNetwork(net)) +// For Unix networks, the address must be a file system path. +func Dial(network, address string) (Conn, error) { + var d Dialer + return d.Dial(network, address) } -func netFromOptions(opts []DialOption) string { - for _, opt := range opts { - if p, ok := opt.(dialNetwork); ok { - return string(p) - } - } - return "tcp" -} - -func deadlineFromOptions(opts []DialOption) time.Time { - for _, opt := range opts { - if d, ok := opt.(dialDeadline); ok { - return time.Time(d) - } - } - return noDeadline -} - -var noLocalAddr Addr // nil - -func localAddrFromOptions(opts []DialOption) Addr { - for _, opt := range opts { - if o, ok := opt.(localAddrOption); ok { - return o.la - } - } - return noLocalAddr +// DialTimeout acts like Dial but takes a timeout. +// The timeout includes name resolution, if required. +func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { + d := Dialer{Timeout: timeout} + return d.Dial(network, address) } -// DialOpt dials addr using the provided options. -// If no options are provided, DialOpt(addr) is equivalent -// to Dial("tcp", addr). See Dial for the syntax of addr. -func DialOpt(addr string, opts ...DialOption) (Conn, error) { - net := netFromOptions(opts) - deadline := deadlineFromOptions(opts) - la := localAddrFromOptions(opts) - ra, err := resolveAddr("dial", net, addr, deadline) - if err != nil { - return nil, err - } - return dial(net, addr, la, ra, deadline) +// Dial connects to the address on the named network. +// +// See func Dial for a description of the network and address +// parameters. +func (d *Dialer) Dial(network, address string) (Conn, error) { + return resolveAndDial(network, address, d.LocalAddr, d.deadline()) } func dial(net, addr string, la, ra Addr, deadline time.Time) (c Conn, err error) { @@ -224,59 +172,6 @@ func dial(net, addr string, la, ra Addr, deadline time.Time) (c Conn, err error) return } -// DialTimeout acts like Dial but takes a timeout. -// The timeout includes name resolution, if required. -func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) { - return dialTimeout(net, addr, timeout) -} - -// dialTimeoutRace is the old implementation of DialTimeout, still used -// on operating systems where the deadline hasn't been pushed down -// into the pollserver. -// TODO: fix this on plan9. -func dialTimeoutRace(net, addr string, timeout time.Duration) (Conn, error) { - t := time.NewTimer(timeout) - defer t.Stop() - type pair struct { - Conn - error - } - ch := make(chan pair, 1) - resolvedAddr := make(chan Addr, 1) - go func() { - ra, err := resolveAddr("dial", net, addr, noDeadline) - if err != nil { - ch <- pair{nil, err} - return - } - resolvedAddr <- ra // in case we need it for OpError - c, err := dial(net, addr, noLocalAddr, ra, noDeadline) - ch <- pair{c, err} - }() - select { - case <-t.C: - // Try to use the real Addr in our OpError, if we resolved it - // before the timeout. Otherwise we just use stringAddr. - var ra Addr - select { - case a := <-resolvedAddr: - ra = a - default: - ra = &stringAddr{net, addr} - } - err := &OpError{ - Op: "dial", - Net: net, - Addr: ra, - Err: &timeoutError{}, - } - return nil, err - case p := <-ch: - return p.Conn, p.error - } - panic("unreachable") -} - type stringAddr struct { net, addr string } @@ -285,8 +180,9 @@ func (a stringAddr) Network() string { return a.net } func (a stringAddr) String() string { return a.addr } // Listen announces on the local network address laddr. -// The network string net must be a stream-oriented network: -// "tcp", "tcp4", "tcp6", "unix" or "unixpacket". +// The network net must be a stream-oriented network: "tcp", "tcp4", +// "tcp6", "unix" or "unixpacket". +// See Dial for the syntax of laddr. func Listen(net, laddr string) (Listener, error) { la, err := resolveAddr("listen", net, laddr, noDeadline) if err != nil { @@ -302,8 +198,9 @@ func Listen(net, laddr string) (Listener, error) { } // ListenPacket announces on the local network address laddr. -// The network string net must be a packet-oriented network: -// "udp", "udp4", "udp6", "ip", "ip4", "ip6" or "unixgram". +// The network net must be a packet-oriented network: "udp", "udp4", +// "udp6", "ip", "ip4", "ip6" or "unixgram". +// See Dial for the syntax of laddr. func ListenPacket(net, laddr string) (PacketConn, error) { la, err := resolveAddr("listen", net, laddr, noDeadline) if err != nil { diff --git a/src/pkg/net/dial_gen.go b/src/pkg/net/dial_gen.go new file mode 100644 index 000000000..19f868168 --- /dev/null +++ b/src/pkg/net/dial_gen.go @@ -0,0 +1,73 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows plan9 + +package net + +import ( + "time" +) + +var testingIssue5349 bool // used during tests + +// resolveAndDialChannel is the simple pure-Go implementation of +// resolveAndDial, still used on operating systems where the deadline +// hasn't been pushed down into the pollserver. (Plan 9 and some old +// versions of Windows) +func resolveAndDialChannel(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { + var timeout time.Duration + if !deadline.IsZero() { + timeout = deadline.Sub(time.Now()) + } + if timeout <= 0 { + ra, err := resolveAddr("dial", net, addr, noDeadline) + if err != nil { + return nil, err + } + return dial(net, addr, localAddr, ra, noDeadline) + } + t := time.NewTimer(timeout) + defer t.Stop() + type pair struct { + Conn + error + } + ch := make(chan pair, 1) + resolvedAddr := make(chan Addr, 1) + go func() { + if testingIssue5349 { + time.Sleep(time.Millisecond) + } + ra, err := resolveAddr("dial", net, addr, noDeadline) + if err != nil { + ch <- pair{nil, err} + return + } + resolvedAddr <- ra // in case we need it for OpError + c, err := dial(net, addr, localAddr, ra, noDeadline) + ch <- pair{c, err} + }() + select { + case <-t.C: + // Try to use the real Addr in our OpError, if we resolved it + // before the timeout. Otherwise we just use stringAddr. + var ra Addr + select { + case a := <-resolvedAddr: + ra = a + default: + ra = &stringAddr{net, addr} + } + err := &OpError{ + Op: "dial", + Net: net, + Addr: ra, + Err: &timeoutError{}, + } + return nil, err + case p := <-ch: + return p.Conn, p.error + } +} diff --git a/src/pkg/net/dial_gen_test.go b/src/pkg/net/dial_gen_test.go new file mode 100644 index 000000000..c857acd06 --- /dev/null +++ b/src/pkg/net/dial_gen_test.go @@ -0,0 +1,11 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows plan9 + +package net + +func init() { + testingIssue5349 = true +} diff --git a/src/pkg/net/dial_test.go b/src/pkg/net/dial_test.go index 2303e8fa4..e24fecc8d 100644 --- a/src/pkg/net/dial_test.go +++ b/src/pkg/net/dial_test.go @@ -28,12 +28,18 @@ func newLocalListener(t *testing.T) Listener { } func TestDialTimeout(t *testing.T) { + origBacklog := listenerBacklog + defer func() { + listenerBacklog = origBacklog + }() + listenerBacklog = 1 + ln := newLocalListener(t) defer ln.Close() errc := make(chan error) - numConns := listenerBacklog + 10 + numConns := listenerBacklog + 100 // TODO(bradfitz): It's hard to test this in a portable // way. This is unfortunate, but works for now. @@ -324,3 +330,80 @@ func numFD() int { // All tests using this should be skipped anyway, but: panic("numFDs not implemented on " + runtime.GOOS) } + +var testPoller = flag.Bool("poller", false, "platform supports runtime-integrated poller") + +// Assert that a failed Dial attempt does not leak +// runtime.PollDesc structures +func TestDialFailPDLeak(t *testing.T) { + if !*testPoller { + t.Skip("test disabled; use -poller to enable") + } + + const loops = 10 + const count = 20000 + var old runtime.MemStats // used by sysdelta + runtime.ReadMemStats(&old) + sysdelta := func() uint64 { + var new runtime.MemStats + runtime.ReadMemStats(&new) + delta := old.Sys - new.Sys + old = new + return delta + } + d := &Dialer{Timeout: time.Nanosecond} // don't bother TCP with handshaking + failcount := 0 + for i := 0; i < loops; i++ { + for i := 0; i < count; i++ { + conn, err := d.Dial("tcp", "127.0.0.1:1") + if err == nil { + t.Error("dial should not succeed") + conn.Close() + t.FailNow() + } + } + if delta := sysdelta(); delta > 0 { + failcount++ + } + // there are always some allocations on the first loop + if failcount > 3 { + t.Error("detected possible memory leak in runtime") + t.FailNow() + } + } +} + +func TestDialer(t *testing.T) { + ln, err := Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer ln.Close() + ch := make(chan error, 1) + go func() { + var err error + c, err := ln.Accept() + if err != nil { + ch <- fmt.Errorf("Accept failed: %v", err) + return + } + defer c.Close() + ch <- nil + }() + + laddr, err := ResolveTCPAddr("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveTCPAddr failed: %v", err) + } + d := &Dialer{LocalAddr: laddr} + c, err := d.Dial("tcp4", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c.Close() + c.Read(make([]byte, 1)) + err = <-ch + if err != nil { + t.Error(err) + } +} diff --git a/src/pkg/net/empty.c b/src/pkg/net/empty.c new file mode 100644 index 000000000..a515c2fe2 --- /dev/null +++ b/src/pkg/net/empty.c @@ -0,0 +1,8 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file is required to prevent compiler errors +// when the package built with CGO_ENABLED=0. +// Otherwise the compiler says: +// pkg/net/fd_poll_runtime.go:15: missing function body diff --git a/src/pkg/net/example_test.go b/src/pkg/net/example_test.go index eefe84fa7..6f2f9074c 100644 --- a/src/pkg/net/example_test.go +++ b/src/pkg/net/example_test.go @@ -16,6 +16,7 @@ func ExampleListener() { if err != nil { log.Fatal(err) } + defer l.Close() for { // Wait for a connection. conn, err := l.Accept() diff --git a/src/pkg/net/fd_darwin.go b/src/pkg/net/fd_darwin.go deleted file mode 100644 index 382465ba6..000000000 --- a/src/pkg/net/fd_darwin.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Waiting for FDs via kqueue/kevent. - -package net - -import ( - "errors" - "os" - "syscall" -) - -type pollster struct { - kq int - eventbuf [10]syscall.Kevent_t - events []syscall.Kevent_t - - // An event buffer for AddFD/DelFD. - // Must hold pollServer lock. - kbuf [1]syscall.Kevent_t -} - -func newpollster() (p *pollster, err error) { - p = new(pollster) - if p.kq, err = syscall.Kqueue(); err != nil { - return nil, os.NewSyscallError("kqueue", err) - } - syscall.CloseOnExec(p.kq) - p.events = p.eventbuf[0:0] - return p, nil -} - -// First return value is whether the pollServer should be woken up. -// This version always returns false. -func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { - // pollServer is locked. - - var kmode int - if mode == 'r' { - kmode = syscall.EVFILT_READ - } else { - kmode = syscall.EVFILT_WRITE - } - ev := &p.kbuf[0] - // EV_ADD - add event to kqueue list - // EV_RECEIPT - generate fake EV_ERROR as result of add, - // rather than waiting for real event - // EV_ONESHOT - delete the event the first time it triggers - flags := syscall.EV_ADD | syscall.EV_RECEIPT - if !repeat { - flags |= syscall.EV_ONESHOT - } - syscall.SetKevent(ev, fd, kmode, flags) - - n, err := syscall.Kevent(p.kq, p.kbuf[:], p.kbuf[:], nil) - if err != nil { - return false, os.NewSyscallError("kevent", err) - } - if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { - return false, errors.New("kqueue phase error") - } - if ev.Data != 0 { - return false, syscall.Errno(ev.Data) - } - return false, nil -} - -// Return value is whether the pollServer should be woken up. -// This version always returns false. -func (p *pollster) DelFD(fd int, mode int) bool { - // pollServer is locked. - - var kmode int - if mode == 'r' { - kmode = syscall.EVFILT_READ - } else { - kmode = syscall.EVFILT_WRITE - } - ev := &p.kbuf[0] - // EV_DELETE - delete event from kqueue list - // EV_RECEIPT - generate fake EV_ERROR as result of add, - // rather than waiting for real event - syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE|syscall.EV_RECEIPT) - syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil) - return false -} - -func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { - var t *syscall.Timespec - for len(p.events) == 0 { - if nsec > 0 { - if t == nil { - t = new(syscall.Timespec) - } - *t = syscall.NsecToTimespec(nsec) - } - - s.Unlock() - n, err := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) - s.Lock() - - if err != nil { - if err == syscall.EINTR { - continue - } - return -1, 0, os.NewSyscallError("kevent", nil) - } - if n == 0 { - return -1, 0, nil - } - p.events = p.eventbuf[:n] - } - ev := &p.events[0] - p.events = p.events[1:] - fd = int(ev.Ident) - if ev.Filter == syscall.EVFILT_READ { - mode = 'r' - } else { - mode = 'w' - } - return fd, mode, nil -} - -func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/src/pkg/net/fd_linux.go b/src/pkg/net/fd_linux.go deleted file mode 100644 index 03679196d..000000000 --- a/src/pkg/net/fd_linux.go +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Waiting for FDs via epoll(7). - -package net - -import ( - "os" - "syscall" -) - -const ( - readFlags = syscall.EPOLLIN | syscall.EPOLLRDHUP - writeFlags = syscall.EPOLLOUT -) - -type pollster struct { - epfd int - - // Events we're already waiting for - // Must hold pollServer lock - events map[int]uint32 - - // An event buffer for EpollWait. - // Used without a lock, may only be used by WaitFD. - waitEventBuf [10]syscall.EpollEvent - waitEvents []syscall.EpollEvent - - // An event buffer for EpollCtl, to avoid a malloc. - // Must hold pollServer lock. - ctlEvent syscall.EpollEvent -} - -func newpollster() (p *pollster, err error) { - p = new(pollster) - if p.epfd, err = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC); err != nil { - if err != syscall.ENOSYS { - return nil, os.NewSyscallError("epoll_create1", err) - } - // The arg to epoll_create is a hint to the kernel - // about the number of FDs we will care about. - // We don't know, and since 2.6.8 the kernel ignores it anyhow. - if p.epfd, err = syscall.EpollCreate(16); err != nil { - return nil, os.NewSyscallError("epoll_create", err) - } - syscall.CloseOnExec(p.epfd) - } - p.events = make(map[int]uint32) - return p, nil -} - -// First return value is whether the pollServer should be woken up. -// This version always returns false. -func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { - // pollServer is locked. - - var already bool - p.ctlEvent.Fd = int32(fd) - p.ctlEvent.Events, already = p.events[fd] - if !repeat { - p.ctlEvent.Events |= syscall.EPOLLONESHOT - } - if mode == 'r' { - p.ctlEvent.Events |= readFlags - } else { - p.ctlEvent.Events |= writeFlags - } - - var op int - if already { - op = syscall.EPOLL_CTL_MOD - } else { - op = syscall.EPOLL_CTL_ADD - } - if err := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); err != nil { - return false, os.NewSyscallError("epoll_ctl", err) - } - p.events[fd] = p.ctlEvent.Events - return false, nil -} - -func (p *pollster) StopWaiting(fd int, bits uint) { - // pollServer is locked. - - events, already := p.events[fd] - if !already { - // The fd returned by the kernel may have been - // cancelled already; return silently. - return - } - - // If syscall.EPOLLONESHOT is not set, the wait - // is a repeating wait, so don't change it. - if events&syscall.EPOLLONESHOT == 0 { - return - } - - // Disable the given bits. - // If we're still waiting for other events, modify the fd - // event in the kernel. Otherwise, delete it. - events &= ^uint32(bits) - if int32(events)&^syscall.EPOLLONESHOT != 0 { - p.ctlEvent.Fd = int32(fd) - p.ctlEvent.Events = events - if err := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); err != nil { - print("Epoll modify fd=", fd, ": ", err.Error(), "\n") - } - p.events[fd] = events - } else { - if err := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); err != nil { - print("Epoll delete fd=", fd, ": ", err.Error(), "\n") - } - delete(p.events, fd) - } -} - -// Return value is whether the pollServer should be woken up. -// This version always returns false. -func (p *pollster) DelFD(fd int, mode int) bool { - // pollServer is locked. - - if mode == 'r' { - p.StopWaiting(fd, readFlags) - } else { - p.StopWaiting(fd, writeFlags) - } - - // Discard any queued up events. - i := 0 - for i < len(p.waitEvents) { - if fd == int(p.waitEvents[i].Fd) { - copy(p.waitEvents[i:], p.waitEvents[i+1:]) - p.waitEvents = p.waitEvents[:len(p.waitEvents)-1] - } else { - i++ - } - } - return false -} - -func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { - for len(p.waitEvents) == 0 { - var msec int = -1 - if nsec > 0 { - msec = int((nsec + 1e6 - 1) / 1e6) - } - - s.Unlock() - n, err := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec) - s.Lock() - - if err != nil { - if err == syscall.EAGAIN || err == syscall.EINTR { - continue - } - return -1, 0, os.NewSyscallError("epoll_wait", err) - } - if n == 0 { - return -1, 0, nil - } - p.waitEvents = p.waitEventBuf[0:n] - } - - ev := &p.waitEvents[0] - p.waitEvents = p.waitEvents[1:] - - fd = int(ev.Fd) - - if ev.Events&writeFlags != 0 { - p.StopWaiting(fd, writeFlags) - return fd, 'w', nil - } - if ev.Events&readFlags != 0 { - p.StopWaiting(fd, readFlags) - return fd, 'r', nil - } - - // Other events are error conditions - wake whoever is waiting. - events, _ := p.events[fd] - if events&writeFlags != 0 { - p.StopWaiting(fd, writeFlags) - return fd, 'w', nil - } - p.StopWaiting(fd, readFlags) - return fd, 'r', nil -} - -func (p *pollster) Close() error { - return os.NewSyscallError("close", syscall.Close(p.epfd)) -} diff --git a/src/pkg/net/fd_plan9.go b/src/pkg/net/fd_plan9.go index 169087999..e9527a374 100644 --- a/src/pkg/net/fd_plan9.go +++ b/src/pkg/net/fd_plan9.go @@ -23,10 +23,10 @@ var canCancelIO = true // used for testing current package func sysInit() { } -func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { +func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { // On plan9, use the relatively inefficient // goroutine-racing implementation. - return dialTimeoutRace(net, addr, timeout) + return resolveAndDialChannel(net, addr, localAddr, deadline) } func newFD(proto, name string, ctl, data *os.File, laddr, raddr Addr) *netFD { diff --git a/src/pkg/net/fd_poll_runtime.go b/src/pkg/net/fd_poll_runtime.go new file mode 100644 index 000000000..e3b4f7e46 --- /dev/null +++ b/src/pkg/net/fd_poll_runtime.go @@ -0,0 +1,119 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin linux + +package net + +import ( + "sync" + "syscall" + "time" +) + +func runtime_pollServerInit() +func runtime_pollOpen(fd int) (uintptr, int) +func runtime_pollClose(ctx uintptr) +func runtime_pollWait(ctx uintptr, mode int) int +func runtime_pollReset(ctx uintptr, mode int) int +func runtime_pollSetDeadline(ctx uintptr, d int64, mode int) +func runtime_pollUnblock(ctx uintptr) + +var canCancelIO = true // used for testing current package + +type pollDesc struct { + runtimeCtx uintptr +} + +var serverInit sync.Once + +func sysInit() { +} + +func (pd *pollDesc) Init(fd *netFD) error { + serverInit.Do(runtime_pollServerInit) + ctx, errno := runtime_pollOpen(fd.sysfd) + if errno != 0 { + return syscall.Errno(errno) + } + pd.runtimeCtx = ctx + return nil +} + +func (pd *pollDesc) Close() { + runtime_pollClose(pd.runtimeCtx) +} + +func (pd *pollDesc) Lock() { +} + +func (pd *pollDesc) Unlock() { +} + +func (pd *pollDesc) Wakeup() { +} + +// Evict evicts fd from the pending list, unblocking any I/O running on fd. +// Return value is whether the pollServer should be woken up. +func (pd *pollDesc) Evict() bool { + runtime_pollUnblock(pd.runtimeCtx) + return false +} + +func (pd *pollDesc) PrepareRead() error { + res := runtime_pollReset(pd.runtimeCtx, 'r') + return convertErr(res) +} + +func (pd *pollDesc) PrepareWrite() error { + res := runtime_pollReset(pd.runtimeCtx, 'w') + return convertErr(res) +} + +func (pd *pollDesc) WaitRead() error { + res := runtime_pollWait(pd.runtimeCtx, 'r') + return convertErr(res) +} + +func (pd *pollDesc) WaitWrite() error { + res := runtime_pollWait(pd.runtimeCtx, 'w') + return convertErr(res) +} + +func convertErr(res int) error { + switch res { + case 0: + return nil + case 1: + return errClosing + case 2: + return errTimeout + } + panic("unreachable") +} + +func setReadDeadline(fd *netFD, t time.Time) error { + return setDeadlineImpl(fd, t, 'r') +} + +func setWriteDeadline(fd *netFD, t time.Time) error { + return setDeadlineImpl(fd, t, 'w') +} + +func setDeadline(fd *netFD, t time.Time) error { + return setDeadlineImpl(fd, t, 'r'+'w') +} + +func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { + d := t.UnixNano() + if t.IsZero() { + d = 0 + } + if err := fd.incref(false); err != nil { + return err + } + runtime_pollSetDeadline(fd.pd.runtimeCtx, d, mode) + fd.decref() + return nil +} diff --git a/src/pkg/net/fd_poll_unix.go b/src/pkg/net/fd_poll_unix.go new file mode 100644 index 000000000..307e577e9 --- /dev/null +++ b/src/pkg/net/fd_poll_unix.go @@ -0,0 +1,360 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build freebsd netbsd openbsd + +package net + +import ( + "os" + "runtime" + "sync" + "syscall" + "time" +) + +// A pollServer helps FDs determine when to retry a non-blocking +// read or write after they get EAGAIN. When an FD needs to wait, +// call s.WaitRead() or s.WaitWrite() to pass the request to the poll server. +// When the pollServer finds that i/o on FD should be possible +// again, it will send on fd.cr/fd.cw to wake any waiting goroutines. +// +// To avoid races in closing, all fd operations are locked and +// refcounted. when netFD.Close() is called, it calls syscall.Shutdown +// and sets a closing flag. Only when the last reference is removed +// will the fd be closed. + +type pollServer struct { + pr, pw *os.File + poll *pollster // low-level OS hooks + sync.Mutex // controls pending and deadline + pending map[int]*pollDesc + deadline int64 // next deadline (nsec since 1970) +} + +// A pollDesc contains netFD state related to pollServer. +type pollDesc struct { + // immutable after Init() + pollServer *pollServer + sysfd int + cr, cw chan error + + // mutable, protected by pollServer mutex + closing bool + ncr, ncw int + + // mutable, safe for concurrent access + rdeadline, wdeadline deadline +} + +func newPollServer() (s *pollServer, err error) { + s = new(pollServer) + if s.pr, s.pw, err = os.Pipe(); err != nil { + return nil, err + } + if err = syscall.SetNonblock(int(s.pr.Fd()), true); err != nil { + goto Errno + } + if err = syscall.SetNonblock(int(s.pw.Fd()), true); err != nil { + goto Errno + } + if s.poll, err = newpollster(); err != nil { + goto Error + } + if _, err = s.poll.AddFD(int(s.pr.Fd()), 'r', true); err != nil { + s.poll.Close() + goto Error + } + s.pending = make(map[int]*pollDesc) + go s.Run() + return s, nil + +Errno: + err = &os.PathError{ + Op: "setnonblock", + Path: s.pr.Name(), + Err: err, + } +Error: + s.pr.Close() + s.pw.Close() + return nil, err +} + +func (s *pollServer) AddFD(pd *pollDesc, mode int) error { + s.Lock() + intfd := pd.sysfd + if intfd < 0 || pd.closing { + // fd closed underfoot + s.Unlock() + return errClosing + } + + var t int64 + key := intfd << 1 + if mode == 'r' { + pd.ncr++ + t = pd.rdeadline.value() + } else { + pd.ncw++ + key++ + t = pd.wdeadline.value() + } + s.pending[key] = pd + doWakeup := false + if t > 0 && (s.deadline == 0 || t < s.deadline) { + s.deadline = t + doWakeup = true + } + + wake, err := s.poll.AddFD(intfd, mode, false) + s.Unlock() + if err != nil { + return err + } + if wake || doWakeup { + s.Wakeup() + } + return nil +} + +// Evict evicts pd from the pending list, unblocking +// any I/O running on pd. The caller must have locked +// pollserver. +// Return value is whether the pollServer should be woken up. +func (s *pollServer) Evict(pd *pollDesc) bool { + pd.closing = true + doWakeup := false + if s.pending[pd.sysfd<<1] == pd { + s.WakeFD(pd, 'r', errClosing) + if s.poll.DelFD(pd.sysfd, 'r') { + doWakeup = true + } + delete(s.pending, pd.sysfd<<1) + } + if s.pending[pd.sysfd<<1|1] == pd { + s.WakeFD(pd, 'w', errClosing) + if s.poll.DelFD(pd.sysfd, 'w') { + doWakeup = true + } + delete(s.pending, pd.sysfd<<1|1) + } + return doWakeup +} + +var wakeupbuf [1]byte + +func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) } + +func (s *pollServer) LookupFD(fd int, mode int) *pollDesc { + key := fd << 1 + if mode == 'w' { + key++ + } + netfd, ok := s.pending[key] + if !ok { + return nil + } + delete(s.pending, key) + return netfd +} + +func (s *pollServer) WakeFD(pd *pollDesc, mode int, err error) { + if mode == 'r' { + for pd.ncr > 0 { + pd.ncr-- + pd.cr <- err + } + } else { + for pd.ncw > 0 { + pd.ncw-- + pd.cw <- err + } + } +} + +func (s *pollServer) CheckDeadlines() { + now := time.Now().UnixNano() + // TODO(rsc): This will need to be handled more efficiently, + // probably with a heap indexed by wakeup time. + + var nextDeadline int64 + for key, pd := range s.pending { + var t int64 + var mode int + if key&1 == 0 { + mode = 'r' + } else { + mode = 'w' + } + if mode == 'r' { + t = pd.rdeadline.value() + } else { + t = pd.wdeadline.value() + } + if t > 0 { + if t <= now { + delete(s.pending, key) + s.poll.DelFD(pd.sysfd, mode) + s.WakeFD(pd, mode, errTimeout) + } else if nextDeadline == 0 || t < nextDeadline { + nextDeadline = t + } + } + } + s.deadline = nextDeadline +} + +func (s *pollServer) Run() { + var scratch [100]byte + s.Lock() + defer s.Unlock() + for { + var timeout int64 // nsec to wait for or 0 for none + if s.deadline > 0 { + timeout = s.deadline - time.Now().UnixNano() + if timeout <= 0 { + s.CheckDeadlines() + continue + } + } + fd, mode, err := s.poll.WaitFD(s, timeout) + if err != nil { + print("pollServer WaitFD: ", err.Error(), "\n") + return + } + if fd < 0 { + // Timeout happened. + s.CheckDeadlines() + continue + } + if fd == int(s.pr.Fd()) { + // Drain our wakeup pipe (we could loop here, + // but it's unlikely that there are more than + // len(scratch) wakeup calls). + s.pr.Read(scratch[0:]) + s.CheckDeadlines() + } else { + pd := s.LookupFD(fd, mode) + if pd == nil { + // This can happen because the WaitFD runs without + // holding s's lock, so there might be a pending wakeup + // for an fd that has been evicted. No harm done. + continue + } + s.WakeFD(pd, mode, nil) + } + } +} + +func (pd *pollDesc) Close() { +} + +func (pd *pollDesc) Lock() { + pd.pollServer.Lock() +} + +func (pd *pollDesc) Unlock() { + pd.pollServer.Unlock() +} + +func (pd *pollDesc) Wakeup() { + pd.pollServer.Wakeup() +} + +func (pd *pollDesc) PrepareRead() error { + if pd.rdeadline.expired() { + return errTimeout + } + return nil +} + +func (pd *pollDesc) PrepareWrite() error { + if pd.wdeadline.expired() { + return errTimeout + } + return nil +} + +func (pd *pollDesc) WaitRead() error { + err := pd.pollServer.AddFD(pd, 'r') + if err == nil { + err = <-pd.cr + } + return err +} + +func (pd *pollDesc) WaitWrite() error { + err := pd.pollServer.AddFD(pd, 'w') + if err == nil { + err = <-pd.cw + } + return err +} + +func (pd *pollDesc) Evict() bool { + return pd.pollServer.Evict(pd) +} + +// Spread network FDs over several pollServers. + +var pollMaxN int +var pollservers []*pollServer +var startServersOnce []func() + +var canCancelIO = true // used for testing current package + +func sysInit() { + pollMaxN = runtime.NumCPU() + if pollMaxN > 8 { + pollMaxN = 8 // No improvement then. + } + pollservers = make([]*pollServer, pollMaxN) + startServersOnce = make([]func(), pollMaxN) + for i := 0; i < pollMaxN; i++ { + k := i + once := new(sync.Once) + startServersOnce[i] = func() { once.Do(func() { startServer(k) }) } + } +} + +func startServer(k int) { + p, err := newPollServer() + if err != nil { + panic(err) + } + pollservers[k] = p +} + +func (pd *pollDesc) Init(fd *netFD) error { + pollN := runtime.GOMAXPROCS(0) + if pollN > pollMaxN { + pollN = pollMaxN + } + k := fd.sysfd % pollN + startServersOnce[k]() + pd.sysfd = fd.sysfd + pd.pollServer = pollservers[k] + pd.cr = make(chan error, 1) + pd.cw = make(chan error, 1) + return nil +} + +// TODO(dfc) these unused error returns could be removed + +func setReadDeadline(fd *netFD, t time.Time) error { + fd.pd.rdeadline.setTime(t) + return nil +} + +func setWriteDeadline(fd *netFD, t time.Time) error { + fd.pd.wdeadline.setTime(t) + return nil +} + +func setDeadline(fd *netFD, t time.Time) error { + setReadDeadline(fd, t) + setWriteDeadline(fd, t) + return nil +} diff --git a/src/pkg/net/fd_unix.go b/src/pkg/net/fd_unix.go index 0540df825..8c59bff98 100644 --- a/src/pkg/net/fd_unix.go +++ b/src/pkg/net/fd_unix.go @@ -9,7 +9,6 @@ package net import ( "io" "os" - "runtime" "sync" "syscall" "time" @@ -21,7 +20,7 @@ type netFD struct { sysmu sync.Mutex sysref int - // must lock both sysmu and pollserver to write + // must lock both sysmu and pollDesc to write // can lock either to read closing bool @@ -31,8 +30,6 @@ type netFD struct { sotype int isConnected bool sysfile *os.File - cr chan error - cw chan error net string laddr Addr raddr Addr @@ -40,264 +37,16 @@ type netFD struct { // serialize access to Read and Write methods rio, wio sync.Mutex - // read and write deadlines - rdeadline, wdeadline deadline - - // owned by fd wait server - ncr, ncw int - // wait server - pollServer *pollServer -} - -// A pollServer helps FDs determine when to retry a non-blocking -// read or write after they get EAGAIN. When an FD needs to wait, -// call s.WaitRead() or s.WaitWrite() to pass the request to the poll server. -// When the pollServer finds that i/o on FD should be possible -// again, it will send on fd.cr/fd.cw to wake any waiting goroutines. -// -// To avoid races in closing, all fd operations are locked and -// refcounted. when netFD.Close() is called, it calls syscall.Shutdown -// and sets a closing flag. Only when the last reference is removed -// will the fd be closed. - -type pollServer struct { - pr, pw *os.File - poll *pollster // low-level OS hooks - sync.Mutex // controls pending and deadline - pending map[int]*netFD - deadline int64 // next deadline (nsec since 1970) -} - -func (s *pollServer) AddFD(fd *netFD, mode int) error { - s.Lock() - intfd := fd.sysfd - if intfd < 0 || fd.closing { - // fd closed underfoot - s.Unlock() - return errClosing - } - - var t int64 - key := intfd << 1 - if mode == 'r' { - fd.ncr++ - t = fd.rdeadline.value() - } else { - fd.ncw++ - key++ - t = fd.wdeadline.value() - } - s.pending[key] = fd - doWakeup := false - if t > 0 && (s.deadline == 0 || t < s.deadline) { - s.deadline = t - doWakeup = true - } - - wake, err := s.poll.AddFD(intfd, mode, false) - s.Unlock() - if err != nil { - return &OpError{"addfd", fd.net, fd.laddr, err} - } - if wake || doWakeup { - s.Wakeup() - } - return nil -} - -// Evict evicts fd from the pending list, unblocking -// any I/O running on fd. The caller must have locked -// pollserver. -// Return value is whether the pollServer should be woken up. -func (s *pollServer) Evict(fd *netFD) bool { - doWakeup := false - if s.pending[fd.sysfd<<1] == fd { - s.WakeFD(fd, 'r', errClosing) - if s.poll.DelFD(fd.sysfd, 'r') { - doWakeup = true - } - delete(s.pending, fd.sysfd<<1) - } - if s.pending[fd.sysfd<<1|1] == fd { - s.WakeFD(fd, 'w', errClosing) - if s.poll.DelFD(fd.sysfd, 'w') { - doWakeup = true - } - delete(s.pending, fd.sysfd<<1|1) - } - return doWakeup -} - -var wakeupbuf [1]byte - -func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) } - -func (s *pollServer) LookupFD(fd int, mode int) *netFD { - key := fd << 1 - if mode == 'w' { - key++ - } - netfd, ok := s.pending[key] - if !ok { - return nil - } - delete(s.pending, key) - return netfd -} - -func (s *pollServer) WakeFD(fd *netFD, mode int, err error) { - if mode == 'r' { - for fd.ncr > 0 { - fd.ncr-- - fd.cr <- err - } - } else { - for fd.ncw > 0 { - fd.ncw-- - fd.cw <- err - } - } -} - -func (s *pollServer) CheckDeadlines() { - now := time.Now().UnixNano() - // TODO(rsc): This will need to be handled more efficiently, - // probably with a heap indexed by wakeup time. - - var nextDeadline int64 - for key, fd := range s.pending { - var t int64 - var mode int - if key&1 == 0 { - mode = 'r' - } else { - mode = 'w' - } - if mode == 'r' { - t = fd.rdeadline.value() - } else { - t = fd.wdeadline.value() - } - if t > 0 { - if t <= now { - delete(s.pending, key) - s.poll.DelFD(fd.sysfd, mode) - s.WakeFD(fd, mode, errTimeout) - } else if nextDeadline == 0 || t < nextDeadline { - nextDeadline = t - } - } - } - s.deadline = nextDeadline -} - -func (s *pollServer) Run() { - var scratch [100]byte - s.Lock() - defer s.Unlock() - for { - var timeout int64 // nsec to wait for or 0 for none - if s.deadline > 0 { - timeout = s.deadline - time.Now().UnixNano() - if timeout <= 0 { - s.CheckDeadlines() - continue - } - } - fd, mode, err := s.poll.WaitFD(s, timeout) - if err != nil { - print("pollServer WaitFD: ", err.Error(), "\n") - return - } - if fd < 0 { - // Timeout happened. - s.CheckDeadlines() - continue - } - if fd == int(s.pr.Fd()) { - // Drain our wakeup pipe (we could loop here, - // but it's unlikely that there are more than - // len(scratch) wakeup calls). - s.pr.Read(scratch[0:]) - s.CheckDeadlines() - } else { - netfd := s.LookupFD(fd, mode) - if netfd == nil { - // This can happen because the WaitFD runs without - // holding s's lock, so there might be a pending wakeup - // for an fd that has been evicted. No harm done. - continue - } - s.WakeFD(netfd, mode, nil) - } - } + pd pollDesc } -func (s *pollServer) WaitRead(fd *netFD) error { - err := s.AddFD(fd, 'r') - if err == nil { - err = <-fd.cr - } - return err -} - -func (s *pollServer) WaitWrite(fd *netFD) error { - err := s.AddFD(fd, 'w') - if err == nil { - err = <-fd.cw - } - return err -} - -// Network FD methods. -// Spread network FDs over several pollServers. - -var pollMaxN int -var pollservers []*pollServer -var startServersOnce []func() - -var canCancelIO = true // used for testing current package - -func sysInit() { - pollMaxN = runtime.NumCPU() - if pollMaxN > 8 { - pollMaxN = 8 // No improvement then. - } - pollservers = make([]*pollServer, pollMaxN) - startServersOnce = make([]func(), pollMaxN) - for i := 0; i < pollMaxN; i++ { - k := i - once := new(sync.Once) - startServersOnce[i] = func() { once.Do(func() { startServer(k) }) } - } -} - -func startServer(k int) { - p, err := newPollServer() - if err != nil { - panic(err) - } - pollservers[k] = p -} - -func server(fd int) *pollServer { - pollN := runtime.GOMAXPROCS(0) - if pollN > pollMaxN { - pollN = pollMaxN - } - k := fd % pollN - startServersOnce[k]() - return pollservers[k] -} - -func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { - deadline := time.Now().Add(timeout) +func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { ra, err := resolveAddr("dial", net, addr, deadline) if err != nil { return nil, err } - return dial(net, addr, noLocalAddr, ra, deadline) + return dial(net, addr, localAddr, ra, deadline) } func newFD(fd, family, sotype int, net string) (*netFD, error) { @@ -307,9 +56,9 @@ func newFD(fd, family, sotype int, net string) (*netFD, error) { sotype: sotype, net: net, } - netfd.cr = make(chan error, 1) - netfd.cw = make(chan error, 1) - netfd.pollServer = server(fd) + if err := netfd.pd.Init(netfd); err != nil { + return nil, err + } return netfd, nil } @@ -330,26 +79,29 @@ func (fd *netFD) name() string { return fd.net + ":" + ls + "->" + rs } -func (fd *netFD) connect(ra syscall.Sockaddr) error { - err := syscall.Connect(fd.sysfd, ra) - if err == syscall.EINPROGRESS { - if err = fd.pollServer.WaitWrite(fd); err != nil { - return err +func (fd *netFD) connect(la, ra syscall.Sockaddr) error { + fd.wio.Lock() + defer fd.wio.Unlock() + if err := fd.pd.PrepareWrite(); err != nil { + return err + } + for { + err := syscall.Connect(fd.sysfd, ra) + if err == nil || err == syscall.EISCONN { + break } - var e int - e, err = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) - if err != nil { - return os.NewSyscallError("getsockopt", err) + if err != syscall.EINPROGRESS && err != syscall.EALREADY && err != syscall.EINTR { + return err } - if e != 0 { - err = syscall.Errno(e) + if err = fd.pd.WaitWrite(); err != nil { + return err } } - return err + return nil } // Add a reference to this fd. -// If closing==true, pollserver must be locked; mark the fd as closing. +// If closing==true, pollDesc must be locked; mark the fd as closing. // Returns an error if the fd cannot be used. func (fd *netFD) incref(closing bool) error { fd.sysmu.Lock() @@ -370,30 +122,37 @@ func (fd *netFD) incref(closing bool) error { func (fd *netFD) decref() { fd.sysmu.Lock() fd.sysref-- - if fd.closing && fd.sysref == 0 && fd.sysfile != nil { - fd.sysfile.Close() - fd.sysfile = nil + if fd.closing && fd.sysref == 0 { + // Poller may want to unregister fd in readiness notification mechanism, + // so this must be executed before sysfile.Close(). + fd.pd.Close() + if fd.sysfile != nil { + fd.sysfile.Close() + fd.sysfile = nil + } else { + closesocket(fd.sysfd) + } fd.sysfd = -1 } fd.sysmu.Unlock() } func (fd *netFD) Close() error { - fd.pollServer.Lock() // needed for both fd.incref(true) and pollserver.Evict + fd.pd.Lock() // needed for both fd.incref(true) and pollDesc.Evict if err := fd.incref(true); err != nil { - fd.pollServer.Unlock() + fd.pd.Unlock() return err } // Unblock any I/O. Once it all unblocks and returns, // so that it cannot be referring to fd.sysfd anymore, // the final decref will close fd.sysfd. This should happen // fairly quickly, since all the I/O is non-blocking, and any - // attempts to block in the pollserver will return errClosing. - doWakeup := fd.pollServer.Evict(fd) - fd.pollServer.Unlock() + // attempts to block in the pollDesc will return errClosing. + doWakeup := fd.pd.Evict() + fd.pd.Unlock() fd.decref() if doWakeup { - fd.pollServer.Wakeup() + fd.pd.Wakeup() } return nil } @@ -425,16 +184,15 @@ func (fd *netFD) Read(p []byte) (n int, err error) { return 0, err } defer fd.decref() + if err := fd.pd.PrepareRead(); err != nil { + return 0, &OpError{"read", fd.net, fd.raddr, err} + } for { - if fd.rdeadline.expired() { - err = errTimeout - break - } n, err = syscall.Read(int(fd.sysfd), p) if err != nil { n = 0 if err == syscall.EAGAIN { - if err = fd.pollServer.WaitRead(fd); err == nil { + if err = fd.pd.WaitRead(); err == nil { continue } } @@ -455,16 +213,15 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { return 0, nil, err } defer fd.decref() + if err := fd.pd.PrepareRead(); err != nil { + return 0, nil, &OpError{"read", fd.net, fd.laddr, err} + } for { - if fd.rdeadline.expired() { - err = errTimeout - break - } n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0) if err != nil { n = 0 if err == syscall.EAGAIN { - if err = fd.pollServer.WaitRead(fd); err == nil { + if err = fd.pd.WaitRead(); err == nil { continue } } @@ -485,16 +242,15 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S return 0, 0, 0, nil, err } defer fd.decref() + if err := fd.pd.PrepareRead(); err != nil { + return 0, 0, 0, nil, &OpError{"read", fd.net, fd.laddr, err} + } for { - if fd.rdeadline.expired() { - err = errTimeout - break - } n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0) if err != nil { // TODO(dfc) should n and oobn be set to 0 if err == syscall.EAGAIN { - if err = fd.pollServer.WaitRead(fd); err == nil { + if err = fd.pd.WaitRead(); err == nil { continue } } @@ -522,11 +278,10 @@ func (fd *netFD) Write(p []byte) (nn int, err error) { return 0, err } defer fd.decref() + if err := fd.pd.PrepareWrite(); err != nil { + return 0, &OpError{"write", fd.net, fd.raddr, err} + } for { - if fd.wdeadline.expired() { - err = errTimeout - break - } var n int n, err = syscall.Write(int(fd.sysfd), p[nn:]) if n > 0 { @@ -536,7 +291,7 @@ func (fd *netFD) Write(p []byte) (nn int, err error) { break } if err == syscall.EAGAIN { - if err = fd.pollServer.WaitWrite(fd); err == nil { + if err = fd.pd.WaitWrite(); err == nil { continue } } @@ -562,14 +317,13 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { return 0, err } defer fd.decref() + if err := fd.pd.PrepareWrite(); err != nil { + return 0, &OpError{"write", fd.net, fd.raddr, err} + } for { - if fd.wdeadline.expired() { - err = errTimeout - break - } err = syscall.Sendto(fd.sysfd, p, 0, sa) if err == syscall.EAGAIN { - if err = fd.pollServer.WaitWrite(fd); err == nil { + if err = fd.pd.WaitWrite(); err == nil { continue } } @@ -590,14 +344,13 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob return 0, 0, err } defer fd.decref() + if err := fd.pd.PrepareWrite(); err != nil { + return 0, 0, &OpError{"write", fd.net, fd.raddr, err} + } for { - if fd.wdeadline.expired() { - err = errTimeout - break - } err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) if err == syscall.EAGAIN { - if err = fd.pollServer.WaitWrite(fd); err == nil { + if err = fd.pd.WaitWrite(); err == nil { continue } } @@ -613,6 +366,8 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob } func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err error) { + fd.rio.Lock() + defer fd.rio.Unlock() if err := fd.incref(false); err != nil { return nil, err } @@ -620,11 +375,14 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e var s int var rsa syscall.Sockaddr + if err = fd.pd.PrepareRead(); err != nil { + return nil, &OpError{"accept", fd.net, fd.laddr, err} + } for { s, rsa, err = accept(fd.sysfd) if err != nil { if err == syscall.EAGAIN { - if err = fd.pollServer.WaitRead(fd); err == nil { + if err = fd.pd.WaitRead(); err == nil { continue } } else if err == syscall.ECONNABORTED { diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index 0e331b44d..fefd174ba 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -54,18 +54,17 @@ func canUseConnectEx(net string) bool { return syscall.LoadConnectEx() == nil } -func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { +func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { if !canUseConnectEx(net) { // Use the relatively inefficient goroutine-racing // implementation of DialTimeout. - return dialTimeoutRace(net, addr, timeout) + return resolveAndDialChannel(net, addr, localAddr, deadline) } - deadline := time.Now().Add(timeout) ra, err := resolveAddr("dial", net, addr, deadline) if err != nil { return nil, err } - return dial(net, addr, noLocalAddr, ra, deadline) + return dial(net, addr, localAddr, ra, deadline) } // Interface for all IO operations. @@ -138,12 +137,18 @@ type resultSrv struct { iocp syscall.Handle } +func runtime_blockingSyscallHint() + func (s *resultSrv) Run() { var o *syscall.Overlapped var key uint32 var r ioResult for { - r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE) + r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, 0) + if r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil { + runtime_blockingSyscallHint() + r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE) + } switch { case r.err == nil: // Dequeued successfully completed IO packet. @@ -359,22 +364,23 @@ func (o *connectOp) Name() string { return "ConnectEx" } -func (fd *netFD) connect(ra syscall.Sockaddr) error { +func (fd *netFD) connect(la, ra syscall.Sockaddr) error { if !canUseConnectEx(fd.net) { return syscall.Connect(fd.sysfd, ra) } // ConnectEx windows API requires an unconnected, previously bound socket. - var la syscall.Sockaddr - switch ra.(type) { - case *syscall.SockaddrInet4: - la = &syscall.SockaddrInet4{} - case *syscall.SockaddrInet6: - la = &syscall.SockaddrInet6{} - default: - panic("unexpected type in connect") - } - if err := syscall.Bind(fd.sysfd, la); err != nil { - return err + if la == nil { + switch ra.(type) { + case *syscall.SockaddrInet4: + la = &syscall.SockaddrInet4{} + case *syscall.SockaddrInet6: + la = &syscall.SockaddrInet6{} + default: + panic("unexpected type in connect") + } + if err := syscall.Bind(fd.sysfd, la); err != nil { + return err + } } // Call ConnectEx API. var o connectOp diff --git a/src/pkg/net/file_windows.go b/src/pkg/net/file_windows.go index c50c32e21..ca2b9b226 100644 --- a/src/pkg/net/file_windows.go +++ b/src/pkg/net/file_windows.go @@ -9,16 +9,28 @@ import ( "syscall" ) +// FileConn returns a copy of the network connection corresponding to +// the open file f. It is the caller's responsibility to close f when +// finished. Closing c does not affect f, and closing f does not +// affect c. func FileConn(f *os.File) (c Conn, err error) { // TODO: Implement this return nil, os.NewSyscallError("FileConn", syscall.EWINDOWS) } +// FileListener returns a copy of the network listener corresponding +// to the open file f. It is the caller's responsibility to close l +// when finished. Closing l does not affect f, and closing f does not +// affect l. func FileListener(f *os.File) (l Listener, err error) { // TODO: Implement this return nil, os.NewSyscallError("FileListener", syscall.EWINDOWS) } +// FilePacketConn returns a copy of the packet network connection +// corresponding to the open file f. It is the caller's +// responsibility to close f when finished. Closing c does not affect +// f, and closing f does not affect c. func FilePacketConn(f *os.File) (c PacketConn, err error) { // TODO: Implement this return nil, os.NewSyscallError("FilePacketConn", syscall.EWINDOWS) diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go index 5ee0804c7..a34d47be1 100644 --- a/src/pkg/net/http/client.go +++ b/src/pkg/net/http/client.go @@ -19,12 +19,16 @@ import ( "strings" ) -// A Client is an HTTP client. Its zero value (DefaultClient) is a usable client -// that uses DefaultTransport. +// A Client is an HTTP client. Its zero value (DefaultClient) is a +// usable client that uses DefaultTransport. // -// The Client's Transport typically has internal state (cached -// TCP connections), so Clients should be reused instead of created as +// The Client's Transport typically has internal state (cached TCP +// connections), so Clients should be reused instead of created as // needed. Clients are safe for concurrent use by multiple goroutines. +// +// A Client is higher-level than a RoundTripper (such as Transport) +// and additionally handles HTTP details such as cookies and +// redirects. type Client struct { // Transport specifies the mechanism by which individual // HTTP requests are made. diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go index 88649bb16..73f1fe3c1 100644 --- a/src/pkg/net/http/client_test.go +++ b/src/pkg/net/http/client_test.go @@ -51,11 +51,10 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { return b, err } } - panic("unreachable") } func TestClient(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -73,7 +72,7 @@ func TestClient(t *testing.T) { } func TestClientHead(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -96,7 +95,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) } func TestGetRequestFormat(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} url := "http://dummy.faketld/" @@ -113,7 +112,7 @@ func TestGetRequestFormat(t *testing.T) { } func TestPostRequestFormat(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -140,7 +139,7 @@ func TestPostRequestFormat(t *testing.T) { } func TestPostFormRequestFormat(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -181,8 +180,8 @@ func TestPostFormRequestFormat(t *testing.T) { } } -func TestRedirects(t *testing.T) { - defer checkLeakedTransports(t) +func TestClientRedirects(t *testing.T) { + defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { n, _ := strconv.Atoi(r.FormValue("n")) @@ -256,7 +255,7 @@ func TestRedirects(t *testing.T) { } func TestPostRedirects(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) var log struct { sync.Mutex bytes.Buffer @@ -374,7 +373,7 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { } func TestRedirectCookiesOnRequest(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() @@ -392,7 +391,7 @@ func TestRedirectCookiesOnRequest(t *testing.T) { } func TestRedirectCookiesJar(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() @@ -429,7 +428,7 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { } func TestJarCalls(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { pathSuffix := r.RequestURI[1:] if r.RequestURI == "/nosetcookie" { @@ -493,7 +492,7 @@ func (j *RecordingJar) logf(format string, args ...interface{}) { } func TestStreamingGet(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) say := make(chan string) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() @@ -544,7 +543,7 @@ func (c *writeCountingConn) Write(p []byte) (int, error) { // TestClientWrites verifies that client requests are buffered and we // don't send a TCP packet per line of the http request + body. func TestClientWrites(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() @@ -578,7 +577,7 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) })) @@ -606,7 +605,7 @@ func TestClientInsecureTransport(t *testing.T) { } func TestClientErrorWithRequestURI(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) req, _ := NewRequest("GET", "http://localhost:1234/", nil) req.RequestURI = "/this/field/is/illegal/and/should/error/" _, err := DefaultClient.Do(req) @@ -635,7 +634,7 @@ func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { } func TestClientWithCorrectTLSServerName(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS.ServerName != "127.0.0.1" { t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName) @@ -650,7 +649,7 @@ func TestClientWithCorrectTLSServerName(t *testing.T) { } func TestClientWithIncorrectTLSServerName(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() @@ -668,7 +667,7 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { // Verify Response.ContentLength is populated. http://golang.org/issue/4126 func TestClientHeadContentLength(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if v := r.FormValue("cl"); v != "" { w.Header().Set("Content-Length", v) diff --git a/src/pkg/net/http/cookiejar/jar.go b/src/pkg/net/http/cookiejar/jar.go index 5d1aeb87f..5977d48b6 100644 --- a/src/pkg/net/http/cookiejar/jar.go +++ b/src/pkg/net/http/cookiejar/jar.go @@ -28,6 +28,9 @@ import ( // An implementation that always returns "" is valid and may be useful for // testing but it is not secure: it means that the HTTP server for foo.com can // set a cookie for bar.com. +// +// A public suffix list implementation is in the package +// code.google.com/p/go.net/publicsuffix. type PublicSuffixList interface { // PublicSuffix returns the public suffix of domain. // diff --git a/src/pkg/net/http/example_test.go b/src/pkg/net/http/example_test.go index 22073eaf7..bc60df7f2 100644 --- a/src/pkg/net/http/example_test.go +++ b/src/pkg/net/http/example_test.go @@ -51,6 +51,20 @@ func ExampleGet() { } func ExampleFileServer() { - // we use StripPrefix so that /tmpfiles/somefile will access /tmp/somefile + // Simple static webserver: + log.Fatal(http.ListenAndServe(":8080", http.FileServer(http.Dir("/usr/share/doc")))) +} + +func ExampleFileServer_stripPrefix() { + // To serve a directory on disk (/tmp) under an alternate URL + // path (/tmpfiles/), use StripPrefix to modify the request + // URL's path before the FileServer sees it: + http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) +} + +func ExampleStripPrefix() { + // To serve a directory on disk (/tmp) under an alternate URL + // path (/tmpfiles/), use StripPrefix to modify the request + // URL's path before the FileServer sees it: http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) } diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go index a7bca20a0..3fc245326 100644 --- a/src/pkg/net/http/export_test.go +++ b/src/pkg/net/http/export_test.go @@ -54,3 +54,5 @@ func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { } return &timeoutHandler{handler, f, ""} } + +var DefaultUserAgent = defaultUserAgent diff --git a/src/pkg/net/http/fcgi/child.go b/src/pkg/net/http/fcgi/child.go index c8b9a33c8..60b794e07 100644 --- a/src/pkg/net/http/fcgi/child.go +++ b/src/pkg/net/http/fcgi/child.go @@ -10,10 +10,12 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "net/http" "net/http/cgi" "os" + "strings" "time" ) @@ -152,20 +154,23 @@ func (c *child) serve() { var errCloseConn = errors.New("fcgi: connection should be closed") +var emptyBody = ioutil.NopCloser(strings.NewReader("")) + func (c *child) handleRecord(rec *record) error { req, ok := c.requests[rec.h.Id] if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { // The spec says to ignore unknown request IDs. return nil } - if ok && rec.h.Type == typeBeginRequest { - // The server is trying to begin a request with the same ID - // as an in-progress request. This is an error. - return errors.New("fcgi: received ID that is already in-flight") - } switch rec.h.Type { case typeBeginRequest: + if req != nil { + // The server is trying to begin a request with the same ID + // as an in-progress request. This is an error. + return errors.New("fcgi: received ID that is already in-flight") + } + var br beginRequest if err := br.read(rec.content()); err != nil { return err @@ -175,6 +180,7 @@ func (c *child) handleRecord(rec *record) error { return nil } c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) + return nil case typeParams: // NOTE(eds): Technically a key-value pair can straddle the boundary // between two packets. We buffer until we've received all parameters. @@ -183,6 +189,7 @@ func (c *child) handleRecord(rec *record) error { return nil } req.parseParams() + return nil case typeStdin: content := rec.content() if req.pw == nil { @@ -191,6 +198,8 @@ func (c *child) handleRecord(rec *record) error { // body could be an io.LimitReader, but it shouldn't matter // as long as both sides are behaving. body, req.pw = io.Pipe() + } else { + body = emptyBody } go c.serveRequest(req, body) } @@ -201,24 +210,29 @@ func (c *child) handleRecord(rec *record) error { } else if req.pw != nil { req.pw.Close() } + return nil case typeGetValues: values := map[string]string{"FCGI_MPXS_CONNS": "1"} c.conn.writePairs(typeGetValuesResult, 0, values) + return nil case typeData: // If the filter role is implemented, read the data stream here. + return nil case typeAbortRequest: + println("abort") delete(c.requests, rec.h.Id) c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) if !req.keepConn { // connection will close upon return return errCloseConn } + return nil default: b := make([]byte, 8) b[0] = byte(rec.h.Type) c.conn.writeRecord(typeUnknownType, 0, b) + return nil } - return nil } func (c *child) serveRequest(req *request, body io.ReadCloser) { @@ -232,11 +246,19 @@ func (c *child) serveRequest(req *request, body io.ReadCloser) { httpReq.Body = body c.handler.ServeHTTP(r, httpReq) } - if body != nil { - body.Close() - } r.Close() c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) + + // Consume the entire body, so the host isn't still writing to + // us when we close the socket below in the !keepConn case, + // otherwise we'd send a RST. (golang.org/issue/4183) + // TODO(bradfitz): also bound this copy in time. Or send + // some sort of abort request to the host, so the host + // can properly cut off the client sending all the data. + // For now just bound it a little and + io.CopyN(ioutil.Discard, body, 100<<20) + body.Close() + if !req.keepConn { c.conn.Close() } @@ -267,5 +289,4 @@ func Serve(l net.Listener, handler http.Handler) error { c := newChild(rw, handler) go c.serve() } - panic("unreachable") } diff --git a/src/pkg/net/http/fs_test.go b/src/pkg/net/http/fs_test.go index 0dd6d0df9..2c3737653 100644 --- a/src/pkg/net/http/fs_test.go +++ b/src/pkg/net/http/fs_test.go @@ -54,7 +54,7 @@ var ServeFileRangeTests = []struct { } func TestServeFile(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") })) @@ -170,7 +170,7 @@ var fsRedirectTestData = []struct { } func TestFSRedirect(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) defer ts.Close() @@ -195,7 +195,7 @@ func (fs *testFileSystem) Open(name string) (File, error) { } func TestFileServerCleans(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ch := make(chan string, 1) fs := FileServer(&testFileSystem{func(name string) (File, error) { ch <- name @@ -227,7 +227,7 @@ func mustRemoveAll(dir string) { } func TestFileServerImplicitLeadingSlash(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) tempDir, err := ioutil.TempDir("", "") if err != nil { t.Fatalf("TempDir: %v", err) @@ -306,7 +306,7 @@ func TestEmptyDirOpenCWD(t *testing.T) { } func TestServeFileContentType(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) const ctype = "icecream/chocolate" ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.FormValue("override") == "1" { @@ -330,7 +330,7 @@ func TestServeFileContentType(t *testing.T) { } func TestServeFileMimeType(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/style.css") })) @@ -347,7 +347,7 @@ func TestServeFileMimeType(t *testing.T) { } func TestServeFileFromCWD(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "fs_test.go") })) @@ -363,7 +363,7 @@ func TestServeFileFromCWD(t *testing.T) { } func TestServeFileWithContentEncoding(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "foo") ServeFile(w, r, "testdata/file") @@ -380,7 +380,7 @@ func TestServeFileWithContentEncoding(t *testing.T) { } func TestServeIndexHtml(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) const want = "index.html says hello\n" ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() @@ -402,7 +402,7 @@ func TestServeIndexHtml(t *testing.T) { } func TestFileServerZeroByte(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() @@ -471,7 +471,7 @@ func (fs fakeFS) Open(name string) (File, error) { } func TestDirectoryIfNotModified(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) const indexContents = "I am a fake index.html file" fileMod := time.Unix(1000000000, 0).UTC() fileModStr := fileMod.Format(TimeFormat) @@ -545,7 +545,7 @@ func mustStat(t *testing.T, fileName string) os.FileInfo { } func TestServeContent(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) type serveParam struct { name string modtime time.Time @@ -678,7 +678,7 @@ func TestServeContent(t *testing.T) { // verifies that sendfile is being used on Linux func TestLinuxSendfile(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) if runtime.GOOS != "linux" { t.Skip("skipping; linux-only test") } @@ -697,7 +697,7 @@ func TestLinuxSendfile(t *testing.T) { defer ln.Close() var buf bytes.Buffer - child := exec.Command("strace", "-f", os.Args[0], "-test.run=TestLinuxSendfileChild") + child := exec.Command("strace", "-f", "-q", "-e", "trace=sendfile,sendfile64", os.Args[0], "-test.run=TestLinuxSendfileChild") child.ExtraFiles = append(child.ExtraFiles, lnf) child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...) child.Stdout = &buf diff --git a/src/pkg/net/http/header.go b/src/pkg/net/http/header.go index f479b7b4e..6374237fb 100644 --- a/src/pkg/net/http/header.go +++ b/src/pkg/net/http/header.go @@ -103,21 +103,41 @@ type keyValues struct { values []string } -type byKey []keyValues - -func (s byKey) Len() int { return len(s) } -func (s byKey) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s byKey) Less(i, j int) bool { return s[i].key < s[j].key } - -func (h Header) sortedKeyValues(exclude map[string]bool) []keyValues { - kvs := make([]keyValues, 0, len(h)) +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +// TODO: convert this to a sync.Cache (issue 4720) +var headerSorterCache = make(chan *headerSorter, 8) + +// sortedKeyValues returns h's keys sorted in the returned kvs +// slice. The headerSorter used to sort is also returned, for possible +// return to headerSorterCache. +func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + select { + case hs = <-headerSorterCache: + default: + hs = new(headerSorter) + } + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] for k, vv := range h { if !exclude[k] { kvs = append(kvs, keyValues{k, vv}) } } - sort.Sort(byKey(kvs)) - return kvs + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs } // WriteSubset writes a header in wire format. @@ -127,7 +147,8 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { if !ok { ws = stringWriter{w} } - for _, kv := range h.sortedKeyValues(exclude) { + kvs, sorter := h.sortedKeyValues(exclude) + for _, kv := range kvs { for _, v := range kv.values { v = headerNewlineToSpace.Replace(v) v = textproto.TrimString(v) @@ -138,6 +159,10 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { } } } + select { + case headerSorterCache <- sorter: + default: + } return nil } diff --git a/src/pkg/net/http/header_test.go b/src/pkg/net/http/header_test.go index 2313b5549..a2b82a701 100644 --- a/src/pkg/net/http/header_test.go +++ b/src/pkg/net/http/header_test.go @@ -6,6 +6,7 @@ package http import ( "bytes" + "runtime" "testing" "time" ) @@ -178,7 +179,7 @@ var testHeader = Header{ "Content-Length": {"123"}, "Content-Type": {"text/plain"}, "Date": {"some date at some time Z"}, - "Server": {"Go http package"}, + "Server": {DefaultUserAgent}, } var buf bytes.Buffer @@ -192,13 +193,14 @@ func BenchmarkHeaderWriteSubset(b *testing.B) { } func TestHeaderWriteSubsetMallocs(t *testing.T) { + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } n := testing.AllocsPerRun(100, func() { buf.Reset() testHeader.WriteSubset(&buf, nil) }) - if n > 1 { - // TODO(bradfitz,rsc): once we can sort without allocating, - // make this an error. See http://golang.org/issue/3761 - // t.Errorf("got %v allocs, want <= %v", n, 1) + if n > 0 { + t.Errorf("mallocs = %d; want 0", n) } } diff --git a/src/pkg/net/http/httptest/example_test.go b/src/pkg/net/http/httptest/example_test.go index 239470d97..42a0ec953 100644 --- a/src/pkg/net/http/httptest/example_test.go +++ b/src/pkg/net/http/httptest/example_test.go @@ -12,7 +12,7 @@ import ( "net/http/httptest" ) -func ExampleRecorder() { +func ExampleResponseRecorder() { handler := func(w http.ResponseWriter, r *http.Request) { http.Error(w, "something failed", http.StatusInternalServerError) } diff --git a/src/pkg/net/http/httputil/dump_test.go b/src/pkg/net/http/httputil/dump_test.go index 5afe9ba74..3e87c27bc 100644 --- a/src/pkg/net/http/httputil/dump_test.go +++ b/src/pkg/net/http/httputil/dump_test.go @@ -68,7 +68,7 @@ var dumpTests = []dumpTest{ WantDumpOut: "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Accept-Encoding: gzip\r\n\r\n", }, @@ -80,7 +80,7 @@ var dumpTests = []dumpTest{ WantDumpOut: "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Accept-Encoding: gzip\r\n\r\n", }, } diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go index 134c45299..1990f64db 100644 --- a/src/pkg/net/http/httputil/reverseproxy.go +++ b/src/pkg/net/http/httputil/reverseproxy.go @@ -81,6 +81,19 @@ func copyHeader(dst, src http.Header) { } } +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport if transport == nil { @@ -96,14 +109,21 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.ProtoMinor = 1 outreq.Close = false - // Remove the connection header to the backend. We want a - // persistent connection, regardless of what the client sent - // to us. This is modifying the same underlying map from req - // (shallow copied above) so we only copy it if necessary. - if outreq.Header.Get("Connection") != "" { - outreq.Header = make(http.Header) - copyHeader(outreq.Header, req.Header) - outreq.Header.Del("Connection") + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. This + // is modifying the same underlying map from req (shallow + // copied above) so we only copy it if necessary. + copiedHeaders := false + for _, h := range hopHeaders { + if outreq.Header.Get(h) != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, req.Header) + copiedHeaders = true + } + outreq.Header.Del(h) + } } if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { @@ -182,7 +202,6 @@ func (m *maxLatencyWriter) flushLoop() { m.lk.Unlock() } } - panic("unreached") } func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go index 863927162..1c0444ec4 100644 --- a/src/pkg/net/http/httputil/reverseproxy_test.go +++ b/src/pkg/net/http/httputil/reverseproxy_test.go @@ -29,6 +29,9 @@ func TestReverseProxy(t *testing.T) { if c := r.Header.Get("Connection"); c != "" { t.Errorf("handler got Connection header value %q", c) } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got Upgrade header value %q", c) + } if g, e := r.Host, "some-name"; g != e { t.Errorf("backend got Host header %q, want %q", g, e) } @@ -49,6 +52,7 @@ func TestReverseProxy(t *testing.T) { getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq.Host = "some-name" getReq.Header.Set("Connection", "close") + getReq.Header.Set("Upgrade", "foo") getReq.Close = true res, err := http.DefaultClient.Do(getReq) if err != nil { diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go index 217f35b48..6d4569146 100644 --- a/src/pkg/net/http/request.go +++ b/src/pkg/net/http/request.go @@ -48,7 +48,7 @@ var ( ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"} ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"} ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"} - ErrMissingBoundary = &ProtocolError{"no multipart boundary param Content-Type"} + ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"} ) type badStringError struct { @@ -283,7 +283,7 @@ func valueOrDefault(value, def string) string { return def } -const defaultUserAgent = "Go http package" +const defaultUserAgent = "Go 1.1 package http" // Write writes an HTTP/1.1 request -- header and body -- in wire format. // This method consults the following fields of the request: @@ -467,10 +467,42 @@ func (r *Request) SetBasicAuth(username, password string) { r.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) } +// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. +func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { + s1 := strings.Index(line, " ") + s2 := strings.Index(line[s1+1:], " ") + if s1 < 0 || s2 < 0 { + return + } + s2 += s1 + 1 + return line[:s1], line[s1+1 : s2], line[s2+1:], true +} + +// TODO(bradfitz): use a sync.Cache when available +var textprotoReaderCache = make(chan *textproto.Reader, 4) + +func newTextprotoReader(br *bufio.Reader) *textproto.Reader { + select { + case r := <-textprotoReaderCache: + r.R = br + return r + default: + return textproto.NewReader(br) + } +} + +func putTextprotoReader(r *textproto.Reader) { + r.R = nil + select { + case textprotoReaderCache <- r: + default: + } +} + // ReadRequest reads and parses a request from b. func ReadRequest(b *bufio.Reader) (req *Request, err error) { - tp := textproto.NewReader(b) + tp := newTextprotoReader(b) req = new(Request) // First line: GET /index.html HTTP/1.0 @@ -479,18 +511,18 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { return nil, err } defer func() { + putTextprotoReader(tp) if err == io.EOF { err = io.ErrUnexpectedEOF } }() - var f []string - if f = strings.SplitN(s, " ", 3); len(f) < 3 { + var ok bool + req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s) + if !ok { return nil, &badStringError{"malformed HTTP request", s} } - req.Method, req.RequestURI, req.Proto = f[0], f[1], f[2] rawurl := req.RequestURI - var ok bool if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok { return nil, &badStringError{"malformed HTTP version", req.Proto} } diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go index 00ad791de..692485c49 100644 --- a/src/pkg/net/http/request_test.go +++ b/src/pkg/net/http/request_test.go @@ -267,6 +267,38 @@ func TestNewRequestContentLength(t *testing.T) { } } +var parseHTTPVersionTests = []struct { + vers string + major, minor int + ok bool +}{ + {"HTTP/0.9", 0, 9, true}, + {"HTTP/1.0", 1, 0, true}, + {"HTTP/1.1", 1, 1, true}, + {"HTTP/3.14", 3, 14, true}, + + {"HTTP", 0, 0, false}, + {"HTTP/one.one", 0, 0, false}, + {"HTTP/1.1/", 0, 0, false}, + {"HTTP/-1,0", 0, 0, false}, + {"HTTP/0,-1", 0, 0, false}, + {"HTTP/", 0, 0, false}, + {"HTTP/1,1", 0, 0, false}, +} + +func TestParseHTTPVersion(t *testing.T) { + for _, tt := range parseHTTPVersionTests { + major, minor, ok := ParseHTTPVersion(tt.vers) + if ok != tt.ok || major != tt.major || minor != tt.minor { + type version struct { + major, minor int + ok bool + } + t.Errorf("failed to parse %q, expected: %#v, got %#v", tt.vers, version{tt.major, tt.minor, tt.ok}, version{major, minor, ok}) + } + } +} + type logWrites struct { t *testing.T dst *[]string @@ -289,7 +321,7 @@ func TestRequestWriteBufferedWriter(t *testing.T) { want := []string{ "GET / HTTP/1.1\r\n", "Host: foo.com\r\n", - "User-Agent: Go http package\r\n", + "User-Agent: " + DefaultUserAgent + "\r\n", "\r\n", } if !reflect.DeepEqual(got, want) { diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go index bc637f18b..b27b1f7ce 100644 --- a/src/pkg/net/http/requestwrite_test.go +++ b/src/pkg/net/http/requestwrite_test.go @@ -93,13 +93,13 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), WantProxy: "GET http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), }, @@ -123,14 +123,14 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), @@ -156,7 +156,7 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Content-Length: 6\r\n" + "\r\n" + @@ -164,7 +164,7 @@ var reqWriteTests = []reqWriteTest{ WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Content-Length: 6\r\n" + "\r\n" + @@ -187,14 +187,14 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 6\r\n" + "\r\n" + "abcdef", WantProxy: "POST http://example.com/ HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 6\r\n" + "\r\n" + "abcdef", @@ -210,7 +210,7 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "\r\n", }, @@ -232,13 +232,13 @@ var reqWriteTests = []reqWriteTest{ // Also, nginx expects it for POST and PUT. WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 0\r\n" + "\r\n", WantProxy: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 0\r\n" + "\r\n", }, @@ -258,13 +258,13 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("x") + chunk(""), WantProxy: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("x") + chunk(""), }, @@ -325,7 +325,7 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /foo HTTP/1.1\r\n" + "Host: \r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "X-Foo: X-Bar\r\n\r\n", }, @@ -351,7 +351,7 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /search HTTP/1.1\r\n" + "Host: \r\n" + - "User-Agent: Go http package\r\n\r\n", + "User-Agent: Go 1.1 package http\r\n\r\n", }, // Opaque test #1 from golang.org/issue/4860 @@ -370,7 +370,7 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /%2F/%2F/ HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n\r\n", + "User-Agent: Go 1.1 package http\r\n\r\n", }, // Opaque test #2 from golang.org/issue/4860 @@ -389,7 +389,31 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET http://y.google.com/%2F/%2F/ HTTP/1.1\r\n" + "Host: x.google.com\r\n" + - "User-Agent: Go http package\r\n\r\n", + "User-Agent: Go 1.1 package http\r\n\r\n", + }, + + // Testing custom case in header keys. Issue 5022. + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "ALL-CAPS": {"x"}, + }, + }, + + WantWrite: "GET / HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "ALL-CAPS: x\r\n" + + "\r\n", }, } @@ -474,7 +498,7 @@ func TestRequestWriteClosesBody(t *testing.T) { } expected := "POST / HTTP/1.1\r\n" + "Host: foo.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + // TODO: currently we don't buffer before chunking, so we get a // single "m" chunk before the other chunks, as this was the 1-byte diff --git a/src/pkg/net/http/response.go b/src/pkg/net/http/response.go index 391ebbf6d..9a7e4e319 100644 --- a/src/pkg/net/http/response.go +++ b/src/pkg/net/http/response.go @@ -46,6 +46,9 @@ type Response struct { // The http Client and Transport guarantee that Body is always // non-nil, even on responses without a body or responses with // a zero-lengthed body. + // + // The Body is automatically dechunked if the server replied + // with a "chunked" Transfer-Encoding. Body io.ReadCloser // ContentLength records the length of the associated content. The diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go index 2f5f77369..02796e88b 100644 --- a/src/pkg/net/http/response_test.go +++ b/src/pkg/net/http/response_test.go @@ -112,8 +112,8 @@ var respTests = []respTest{ ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{ - "Connection": {"close"}, // TODO(rsc): Delete? - "Content-Length": {"10"}, // TODO(rsc): Delete? + "Connection": {"close"}, + "Content-Length": {"10"}, }, Close: true, ContentLength: 10, @@ -170,7 +170,7 @@ var respTests = []respTest{ Request: dummyReq("GET"), Header: Header{}, Close: false, - ContentLength: -1, // TODO(rsc): Fix? + ContentLength: -1, TransferEncoding: []string{"chunked"}, }, @@ -466,7 +466,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { if test.compressed { gzReader, err := gzip.NewReader(resp.Body) checkErr(err, "gzip.NewReader") - resp.Body = &readFirstCloseBoth{gzReader, resp.Body} + resp.Body = &readerAndCloser{gzReader, resp.Body} } rbuf := make([]byte, 2500) diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go index 3300fef59..d7b321597 100644 --- a/src/pkg/net/http/serve_test.go +++ b/src/pkg/net/http/serve_test.go @@ -10,6 +10,7 @@ import ( "bufio" "bytes" "crypto/tls" + "errors" "fmt" "io" "io/ioutil" @@ -64,10 +65,39 @@ func (a dummyAddr) String() string { return string(a) } +type noopConn struct{} + +func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") } +func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") } +func (noopConn) SetDeadline(t time.Time) error { return nil } +func (noopConn) SetReadDeadline(t time.Time) error { return nil } +func (noopConn) SetWriteDeadline(t time.Time) error { return nil } + +type rwTestConn struct { + io.Reader + io.Writer + noopConn + + closeFunc func() error // called if non-nil + closec chan bool // else, if non-nil, send value to it on close +} + +func (c *rwTestConn) Close() error { + if c.closeFunc != nil { + return c.closeFunc() + } + select { + case c.closec <- true: + default: + } + return nil +} + type testConn struct { readBuf bytes.Buffer writeBuf bytes.Buffer closec chan bool // if non-nil, send value to it on close + noopConn } func (c *testConn) Read(b []byte) (int, error) { @@ -86,26 +116,6 @@ func (c *testConn) Close() error { return nil } -func (c *testConn) LocalAddr() net.Addr { - return dummyAddr("local-addr") -} - -func (c *testConn) RemoteAddr() net.Addr { - return dummyAddr("remote-addr") -} - -func (c *testConn) SetDeadline(t time.Time) error { - return nil -} - -func (c *testConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *testConn) SetWriteDeadline(t time.Time) error { - return nil -} - func TestConsumingBodyOnNextConn(t *testing.T) { conn := new(testConn) for i := 0; i < 2; i++ { @@ -184,7 +194,7 @@ var vtests = []struct { } func TestHostHandlers(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) mux := NewServeMux() for _, h := range handlers { mux.Handle(h.pattern, stringHandler(h.msg)) @@ -257,7 +267,7 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) reqNum := 0 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ @@ -333,7 +343,7 @@ func TestServerTimeouts(t *testing.T) { // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never happen. func TestOnlyWriteTimeout(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) var conn net.Conn var afterTimeoutErrc = make(chan error, 1) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { @@ -392,7 +402,7 @@ func (l trackLastConnListener) Accept() (c net.Conn, err error) { // TestIdentityResponse verifies that a handler can unset func TestIdentityResponse(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") rw.Header().Set("Transfer-Encoding", req.FormValue("te")) @@ -468,7 +478,7 @@ func TestIdentityResponse(t *testing.T) { } func testTCPConnectionCloses(t *testing.T, req string, h Handler) { - defer checkLeakedTransports(t) + defer afterTest(t) s := httptest.NewServer(h) defer s.Close() @@ -539,7 +549,7 @@ func TestHandlersCanSetConnectionClose10(t *testing.T) { } func TestSetsRemoteAddr(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) @@ -560,7 +570,7 @@ func TestSetsRemoteAddr(t *testing.T) { } func TestChunkedResponseHeaders(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) log.SetOutput(ioutil.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) @@ -591,7 +601,7 @@ func TestChunkedResponseHeaders(t *testing.T) { // chunking in their response headers and aren't allowed to produce // output. func Test304Responses(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) _, err := w.Write([]byte("illegal body")) @@ -621,7 +631,7 @@ func Test304Responses(t *testing.T) { // allowed to produce output, and don't set a Content-Type since // the real type of the body data cannot be inferred. func TestHeadResponses(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("Ignored body")) if err != ErrBodyNotAllowed { @@ -656,7 +666,7 @@ func TestHeadResponses(t *testing.T) { } func TestTLSHandshakeTimeout(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) ts.Config.ReadTimeout = 250 * time.Millisecond ts.StartTLS() @@ -676,7 +686,7 @@ func TestTLSHandshakeTimeout(t *testing.T) { } func TestTLSServer(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { w.Header().Set("X-TLS-Set", "true") @@ -759,7 +769,7 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. func TestServerExpect(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST // requests that would read from r.Body, which we only @@ -897,7 +907,7 @@ func TestServerUnreadRequestBodyLarge(t *testing.T) { } func TestTimeoutHandler(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -972,7 +982,7 @@ func TestRedirectMunging(t *testing.T) { // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. func TestZeroLengthPostAndResponse(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := ioutil.ReadAll(r.Body) if err != nil { @@ -1023,7 +1033,7 @@ func TestHandlerPanicWithHijack(t *testing.T) { } func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { - defer checkLeakedTransports(t) + defer afterTest(t) // Unlike the other tests that set the log output to ioutil.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: @@ -1089,7 +1099,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { } func TestNoDate(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()["Date"] = nil })) @@ -1105,7 +1115,7 @@ func TestNoDate(t *testing.T) { } func TestStripPrefix(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) }) @@ -1132,7 +1142,7 @@ func TestStripPrefix(t *testing.T) { } func TestRequestLimit(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") })) @@ -1176,7 +1186,7 @@ func (cr countReader) Read(p []byte) (n int, err error) { } func TestRequestBodyLimit(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) const limit = 1 << 20 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) @@ -1213,7 +1223,7 @@ func TestRequestBodyLimit(t *testing.T) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. func TestClientWriteShutdown(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -1268,7 +1278,7 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. // See http://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, "bye", StatusUnauthorized) })) @@ -1311,7 +1321,7 @@ func TestServerGracefulClose(t *testing.T) { } func TestCaseSensitiveMethod(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "get" { t.Errorf(`Got method %q; want "get"`, r.Method) @@ -1360,6 +1370,7 @@ func TestContentLengthZero(t *testing.T) { } func TestCloseNotifier(t *testing.T) { + defer afterTest(t) gotReq := make(chan bool, 1) sawClose := make(chan bool, 1) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -1395,6 +1406,31 @@ For: ts.Close() } +func TestCloseNotifierChanLeak(t *testing.T) { + defer afterTest(t) + req := []byte(strings.Replace(`GET / HTTP/1.0 +Host: golang.org + +`, "\n", "\r\n", -1)) + for i := 0; i < 20; i++ { + var output bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &output, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + // Ignore the return value and never read from + // it, testing that we don't leak goroutines + // on the sending side: + _ = rw.(CloseNotifier).CloseNotify() + }) + go Serve(ln, handler) + <-conn.closec + } +} + func TestOptions(t *testing.T) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() @@ -1447,6 +1483,198 @@ func TestOptions(t *testing.T) { } } +// Tests regarding the ordering of Write, WriteHeader, Header, and +// Flush calls. In Go 1.0, rw.WriteHeader immediately flushed the +// (*response).header to the wire. In Go 1.1, the actual wire flush is +// delayed, so we could maybe tack on a Content-Length and better +// Content-Type after we see more (or all) of the output. To preserve +// compatibility with Go 1, we need to be careful to track which +// headers were live at the time of WriteHeader, so we write the same +// ones, even if the handler modifies them (~erroneously) after the +// first Write. +func TestHeaderToWire(t *testing.T) { + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org + +`, "\n", "\r\n", -1)) + + tests := []struct { + name string + handler func(ResponseWriter, *Request) + check func(output string) error + }{ + { + name: "write without Header", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("hello world")) + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Length:") { + return errors.New("no content-length") + } + if !strings.Contains(got, "Content-Type: text/plain") { + return errors.New("no content-length") + } + return nil + }, + }, + { + name: "Header mutation before write", + handler: func(rw ResponseWriter, r *Request) { + h := rw.Header() + h.Set("Content-Type", "some/type") + rw.Write([]byte("hello world")) + h.Set("Too-Late", "bogus") + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Length:") { + return errors.New("no content-length") + } + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-type") + } + if strings.Contains(got, "Too-Late") { + return errors.New("don't want too-late header") + } + return nil + }, + }, + { + name: "write then useless Header mutation", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("hello world")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + return nil + }, + }, + { + name: "flush then write", + handler: func(rw ResponseWriter, r *Request) { + rw.(Flusher).Flush() + rw.Write([]byte("post-flush")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if !strings.Contains(got, "Transfer-Encoding: chunked") { + return errors.New("not chunked") + } + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + return nil + }, + }, + { + name: "header then flush", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "some/type") + rw.(Flusher).Flush() + rw.Write([]byte("post-flush")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if !strings.Contains(got, "Transfer-Encoding: chunked") { + return errors.New("not chunked") + } + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-length") + } + return nil + }, + }, + { + name: "sniff-on-first-write content-type", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("<html><head></head><body>some html</body></html>")) + rw.Header().Set("Content-Type", "x/wrong") + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: text/html") { + return errors.New("wrong content-length; want html") + } + return nil + }, + }, + { + name: "explicit content-type wins", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "some/type") + rw.Write([]byte("<html><head></head><body>some html</body></html>")) + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-length; want html") + } + return nil + }, + }, + { + name: "empty handler", + handler: func(rw ResponseWriter, r *Request) { + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: text/plain") { + return errors.New("wrong content-length; want text/plain") + } + if !strings.Contains(got, "Content-Length: 0") { + return errors.New("want 0 content-length") + } + return nil + }, + }, + { + name: "only Header, no write", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Some-Header", "some-value") + }, + check: func(got string) error { + if !strings.Contains(got, "Some-Header") { + return errors.New("didn't get header") + } + return nil + }, + }, + { + name: "WriteHeader call", + handler: func(rw ResponseWriter, r *Request) { + rw.WriteHeader(404) + rw.Header().Set("Too-Late", "some-value") + }, + check: func(got string) error { + if !strings.Contains(got, "404") { + return errors.New("wrong status") + } + if strings.Contains(got, "Some-Header") { + return errors.New("shouldn't have seen Too-Late") + } + return nil + }, + }, + } + for _, tc := range tests { + var output bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &output, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + go Serve(ln, HandlerFunc(tc.handler)) + <-conn.closec + if err := tc.check(output.String()); err != nil { + t.Errorf("%s: %v\nGot response:\n%s", tc.name, err, output.Bytes()) + } + } +} + // goTimeout runs f, failing t if f takes more than ns to complete. func goTimeout(t *testing.T, d time.Duration, f func()) { ch := make(chan bool, 2) @@ -1620,3 +1848,179 @@ func BenchmarkServer(b *testing.B) { b.Errorf("Test failure: %v, with output: %s", err, out) } } + +func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) { + b.ReportAllocs() + req := []byte(strings.Replace(`GET / HTTP/1.0 +Host: golang.org +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 + +`, "\n", "\r\n", -1)) + res := []byte("Hello world!\n") + + conn := &testConn{ + // testConn.Close will not push into the channel + // if it's full. + closec: make(chan bool, 1), + } + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + }) + ln := new(oneConnListener) + for i := 0; i < b.N; i++ { + conn.readBuf.Reset() + conn.writeBuf.Reset() + conn.readBuf.Write(req) + ln.conn = conn + Serve(ln, handler) + <-conn.closec + } +} + +// repeatReader reads content count times, then EOFs. +type repeatReader struct { + content []byte + count int + off int +} + +func (r *repeatReader) Read(p []byte) (n int, err error) { + if r.count <= 0 { + return 0, io.EOF + } + n = copy(p, r.content[r.off:]) + r.off += n + if r.off == len(r.content) { + r.count-- + r.off = 0 + } + return +} + +func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) { + b.ReportAllocs() + + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 + +`, "\n", "\r\n", -1)) + res := []byte("Hello world!\n") + + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} + +// same as above, but representing the most simple possible request +// and handler. Notably: the handler does not call rw.Header(). +func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) { + b.ReportAllocs() + + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org + +`, "\n", "\r\n", -1)) + res := []byte("Hello world!\n") + + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + rw.Write(res) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} + +const someResponse = "<html>some response</html>" + +// A Response that's just no bigger than 2KB, the buffer-before-chunking threshold. +var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse)) + +// Both Content-Type and Content-Length set. Should be no buffering. +func BenchmarkServerHandlerTypeLen(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Length", strconv.Itoa(len(response))) + w.Write(response) + })) +} + +// A Content-Type is set, but no length. No sniffing, but will count the Content-Length. +func BenchmarkServerHandlerNoLen(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/html") + w.Write(response) + })) +} + +// A Content-Length is set, but the Content-Type will be sniffed. +func BenchmarkServerHandlerNoType(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", strconv.Itoa(len(response))) + w.Write(response) + })) +} + +// Neither a Content-Type or Content-Length, so sniffed and counted. +func BenchmarkServerHandlerNoHeader(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write(response) + })) +} + +func benchmarkHandler(b *testing.B, h Handler) { + b.ReportAllocs() + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org + +`, "\n", "\r\n", -1)) + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + h.ServeHTTP(rw, r) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go index b6ab78228..b25960705 100644 --- a/src/pkg/net/http/server.go +++ b/src/pkg/net/http/server.go @@ -4,9 +4,6 @@ // HTTP server. See RFC 2616. -// TODO(rsc): -// logging - package http import ( @@ -109,9 +106,11 @@ type conn struct { remoteAddr string // network address of remote side server *Server // the Server on which the connection arrived rwc net.Conn // i/o connection - sr switchReader // where the LimitReader reads from; usually the rwc + sr liveSwitchReader // where the LimitReader reads from; usually the rwc lr *io.LimitedReader // io.LimitReader(sr) buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc + bufswr *switchReader // the *switchReader io.Reader source of buf + bufsww *switchWriter // the *switchWriter io.Writer dest of buf tlsState *tls.ConnectionState // or nil when not using TLS mu sync.Mutex // guards the following @@ -147,7 +146,7 @@ func (c *conn) closeNotify() <-chan bool { c.mu.Lock() defer c.mu.Unlock() if c.closeNotifyc == nil { - c.closeNotifyc = make(chan bool) + c.closeNotifyc = make(chan bool, 1) if c.hijackedv { // to obey the function signature, even though // it'll never receive a value. @@ -180,12 +179,26 @@ func (c *conn) noteClientGone() { c.clientGone = true } +// A switchReader can have its Reader changed at runtime. +// It's not safe for concurrent Reads and switches. type switchReader struct { + io.Reader +} + +// A switchWriter can have its Writer changed at runtime. +// It's not safe for concurrent Writes and switches. +type switchWriter struct { + io.Writer +} + +// A liveSwitchReader is a switchReader that's safe for concurrent +// reads and switches, if its mutex is held. +type liveSwitchReader struct { sync.Mutex r io.Reader } -func (sr *switchReader) Read(p []byte) (n int, err error) { +func (sr *liveSwitchReader) Read(p []byte) (n int, err error) { sr.Lock() r := sr.r sr.Unlock() @@ -206,15 +219,28 @@ const bufferBeforeChunkingSize = 2048 // // See the comment above (*response).Write for the entire write flow. type chunkWriter struct { - res *response - header Header // a deep copy of r.Header, once WriteHeader is called - wroteHeader bool // whether the header's been sent + res *response + + // header is either nil or a deep clone of res.handlerHeader + // at the time of res.WriteHeader, if res.WriteHeader is + // called and extra buffering is being done to calculate + // Content-Type and/or Content-Length. + header Header + + // wroteHeader tells whether the header's been written to "the + // wire" (or rather: w.conn.buf). this is unlike + // (*response).wroteHeader, which tells only whether it was + // logically written. + wroteHeader bool // set by the writeHeader method: chunking bool // using chunked transfer encoding for reply body } -var crlf = []byte("\r\n") +var ( + crlf = []byte("\r\n") + colonSpace = []byte(": ") +) func (cw *chunkWriter) Write(p []byte) (n int, err error) { if !cw.wroteHeader { @@ -264,13 +290,15 @@ type response struct { wroteContinue bool // 100 Continue response was written w *bufio.Writer // buffers output in chunks to chunkWriter - cw *chunkWriter + cw chunkWriter + sw *switchWriter // of the bufio.Writer, for return to putBufioWriter // handlerHeader is the Header that Handlers get access to, // which may be retained and mutated even after WriteHeader. // handlerHeader is copied into cw.header at WriteHeader // time, and privately mutated thereafter. handlerHeader Header + calledHeader bool // handler accessed handlerHeader via Header written int64 // number of bytes written in body contentLength int64 // explicitly-declared Content-Length; or -1 @@ -362,14 +390,98 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { if debugServerConnections { c.rwc = newLoggingConn("server", c.rwc) } - c.sr = switchReader{r: c.rwc} + c.sr = liveSwitchReader{r: c.rwc} c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) - br := bufio.NewReader(c.lr) - bw := bufio.NewWriter(c.rwc) + br, sr := newBufioReader(c.lr) + bw, sw := newBufioWriterSize(c.rwc, 4<<10) c.buf = bufio.NewReadWriter(br, bw) + c.bufswr = sr + c.bufsww = sw return c, nil } +// TODO: remove this, if issue 5100 is fixed +type bufioReaderPair struct { + br *bufio.Reader + sr *switchReader // from which the bufio.Reader is reading +} + +// TODO: remove this, if issue 5100 is fixed +type bufioWriterPair struct { + bw *bufio.Writer + sw *switchWriter // to which the bufio.Writer is writing +} + +// TODO: use a sync.Cache instead +var ( + bufioReaderCache = make(chan bufioReaderPair, 4) + bufioWriterCache2k = make(chan bufioWriterPair, 4) + bufioWriterCache4k = make(chan bufioWriterPair, 4) +) + +func bufioWriterCache(size int) chan bufioWriterPair { + switch size { + case 2 << 10: + return bufioWriterCache2k + case 4 << 10: + return bufioWriterCache4k + } + return nil +} + +func newBufioReader(r io.Reader) (*bufio.Reader, *switchReader) { + select { + case p := <-bufioReaderCache: + p.sr.Reader = r + return p.br, p.sr + default: + sr := &switchReader{r} + return bufio.NewReader(sr), sr + } +} + +func putBufioReader(br *bufio.Reader, sr *switchReader) { + if n := br.Buffered(); n > 0 { + io.CopyN(ioutil.Discard, br, int64(n)) + } + br.Read(nil) // clears br.err + sr.Reader = nil + select { + case bufioReaderCache <- bufioReaderPair{br, sr}: + default: + } +} + +func newBufioWriterSize(w io.Writer, size int) (*bufio.Writer, *switchWriter) { + select { + case p := <-bufioWriterCache(size): + p.sw.Writer = w + return p.bw, p.sw + default: + sw := &switchWriter{w} + return bufio.NewWriterSize(sw, size), sw + } +} + +func putBufioWriter(bw *bufio.Writer, sw *switchWriter) { + if bw.Buffered() > 0 { + // It must have failed to flush to its target + // earlier. We can't reuse this bufio.Writer. + return + } + if err := bw.Flush(); err != nil { + // Its sticky error field is set, which is returned by + // Flush even when there's no data buffered. This + // bufio Writer is dead to us. Don't reuse it. + return + } + sw.Writer = nil + select { + case bufioWriterCache(bw.Available()) <- bufioWriterPair{bw, sw}: + default: + } +} + // DefaultMaxHeaderBytes is the maximum permitted size of the headers // in an HTTP request. // This can be overridden by setting Server.MaxHeaderBytes. @@ -448,14 +560,20 @@ func (c *conn) readRequest() (w *response, err error) { req: req, handlerHeader: make(Header), contentLength: -1, - cw: new(chunkWriter), } w.cw.res = w - w.w = bufio.NewWriterSize(w.cw, bufferBeforeChunkingSize) + w.w, w.sw = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize) return w, nil } func (w *response) Header() Header { + if w.cw.header == nil && w.wroteHeader && !w.cw.wroteHeader { + // Accessing the header between logically writing it + // and physically writing it means we need to allocate + // a clone to snapshot the logically written state. + w.cw.header = w.handlerHeader.clone() + } + w.calledHeader = true return w.handlerHeader } @@ -482,15 +600,48 @@ func (w *response) WriteHeader(code int) { w.wroteHeader = true w.status = code - w.cw.header = w.handlerHeader.clone() + if w.calledHeader && w.cw.header == nil { + w.cw.header = w.handlerHeader.clone() + } - if cl := w.cw.header.get("Content-Length"); cl != "" { + if cl := w.handlerHeader.get("Content-Length"); cl != "" { v, err := strconv.ParseInt(cl, 10, 64) if err == nil && v >= 0 { w.contentLength = v } else { log.Printf("http: invalid Content-Length of %q", cl) - w.cw.header.Del("Content-Length") + w.handlerHeader.Del("Content-Length") + } + } +} + +// extraHeader is the set of headers sometimes added by chunkWriter.writeHeader. +// This type is used to avoid extra allocations from cloning and/or populating +// the response Header map and all its 1-element slices. +type extraHeader struct { + contentType string + contentLength string + connection string + date string + transferEncoding string +} + +// Sorted the same as extraHeader.Write's loop. +var extraHeaderKeys = [][]byte{ + []byte("Content-Type"), []byte("Content-Length"), + []byte("Connection"), []byte("Date"), []byte("Transfer-Encoding"), +} + +// The value receiver, despite copying 5 strings to the stack, +// prevents an extra allocation. The escape analysis isn't smart +// enough to realize this doesn't mutate h. +func (h extraHeader) Write(w io.Writer) { + for i, v := range []string{h.contentType, h.contentLength, h.connection, h.date, h.transferEncoding} { + if v != "" { + w.Write(extraHeaderKeys[i]) + w.Write(colonSpace) + io.WriteString(w, v) + w.Write(crlf) } } } @@ -510,23 +661,47 @@ func (cw *chunkWriter) writeHeader(p []byte) { cw.wroteHeader = true w := cw.res - code := w.status - done := w.handlerDone + + // header is written out to w.conn.buf below. Depending on the + // state of the handler, we either own the map or not. If we + // don't own it, the exclude map is created lazily for + // WriteSubset to remove headers. The setHeader struct holds + // headers we need to add. + header := cw.header + owned := header != nil + if !owned { + header = w.handlerHeader + } + var excludeHeader map[string]bool + delHeader := func(key string) { + if owned { + header.Del(key) + return + } + if _, ok := header[key]; !ok { + return + } + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[key] = true + } + var setHeader extraHeader // If the handler is done but never sent a Content-Length // response header and this is our first (and last) write, set // it, even to zero. This helps HTTP/1.0 clients keep their // "keep-alive" connections alive. - if done && cw.header.get("Content-Length") == "" && w.req.Method != "HEAD" { + if w.handlerDone && header.get("Content-Length") == "" && w.req.Method != "HEAD" { w.contentLength = int64(len(p)) - cw.header.Set("Content-Length", strconv.Itoa(len(p))) + setHeader.contentLength = strconv.Itoa(len(p)) } // If this was an HTTP/1.0 request with keep-alive and we sent a // Content-Length back, we can make this a keep-alive response ... if w.req.wantsHttp10KeepAlive() { - sentLength := cw.header.get("Content-Length") != "" - if sentLength && cw.header.get("Connection") == "keep-alive" { + sentLength := header.get("Content-Length") != "" + if sentLength && header.get("Connection") == "keep-alive" { w.closeAfterReply = false } } @@ -535,15 +710,15 @@ func (cw *chunkWriter) writeHeader(p []byte) { hasCL := w.contentLength != -1 if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { - _, connectionHeaderSet := cw.header["Connection"] + _, connectionHeaderSet := header["Connection"] if !connectionHeaderSet { - cw.header.Set("Connection", "keep-alive") + setHeader.connection = "keep-alive" } } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() { w.closeAfterReply = true } - if cw.header.get("Connection") == "close" { + if header.get("Connection") == "close" { w.closeAfterReply = true } @@ -557,49 +732,49 @@ func (cw *chunkWriter) writeHeader(p []byte) { n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1) if n >= maxPostHandlerReadBytes { w.requestTooLarge() - cw.header.Set("Connection", "close") + delHeader("Connection") + setHeader.connection = "close" } else { w.req.Body.Close() } } } + code := w.status if code == StatusNotModified { // Must not have body. - for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { - // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" - if cw.header.get(header) != "" { - cw.header.Del(header) - } + // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" + for _, k := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { + delHeader(k) } } else { // If no content type, apply sniffing algorithm to body. - if cw.header.get("Content-Type") == "" && w.req.Method != "HEAD" { - cw.header.Set("Content-Type", DetectContentType(p)) + if header.get("Content-Type") == "" && w.req.Method != "HEAD" { + setHeader.contentType = DetectContentType(p) } } - if _, ok := cw.header["Date"]; !ok { - cw.header.Set("Date", time.Now().UTC().Format(TimeFormat)) + if _, ok := header["Date"]; !ok { + setHeader.date = time.Now().UTC().Format(TimeFormat) } - te := cw.header.get("Transfer-Encoding") + te := header.get("Transfer-Encoding") hasTE := te != "" if hasCL && hasTE && te != "identity" { // TODO: return an error if WriteHeader gets a return parameter // For now just ignore the Content-Length. log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", te, w.contentLength) - cw.header.Del("Content-Length") + delHeader("Content-Length") hasCL = false } if w.req.Method == "HEAD" || code == StatusNotModified { // do nothing } else if code == StatusNoContent { - cw.header.Del("Transfer-Encoding") + delHeader("Transfer-Encoding") } else if hasCL { - cw.header.Del("Transfer-Encoding") + delHeader("Transfer-Encoding") } else if w.req.ProtoAtLeast(1, 1) { // HTTP/1.1 or greater: use chunked transfer encoding // to avoid closing the connection at EOF. @@ -607,29 +782,63 @@ func (cw *chunkWriter) writeHeader(p []byte) { // might have set. Deal with that as need arises once we have a valid // use case. cw.chunking = true - cw.header.Set("Transfer-Encoding", "chunked") + setHeader.transferEncoding = "chunked" } else { // HTTP version < 1.1: cannot do chunked transfer // encoding and we don't know the Content-Length so // signal EOF by closing connection. w.closeAfterReply = true - cw.header.Del("Transfer-Encoding") // in case already set + delHeader("Transfer-Encoding") // in case already set } // Cannot use Content-Length with non-identity Transfer-Encoding. if cw.chunking { - cw.header.Del("Content-Length") + delHeader("Content-Length") } if !w.req.ProtoAtLeast(1, 0) { return } if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") { - cw.header.Set("Connection", "close") + delHeader("Connection") + setHeader.connection = "close" } + io.WriteString(w.conn.buf, statusLine(w.req, code)) + cw.header.WriteSubset(w.conn.buf, excludeHeader) + setHeader.Write(w.conn.buf) + w.conn.buf.Write(crlf) +} + +// statusLines is a cache of Status-Line strings, keyed by code (for +// HTTP/1.1) or negative code (for HTTP/1.0). This is faster than a +// map keyed by struct of two fields. This map's max size is bounded +// by 2*len(statusText), two protocol types for each known official +// status code in the statusText map. +var ( + statusMu sync.RWMutex + statusLines = make(map[int]string) +) + +// statusLine returns a response Status-Line (RFC 2616 Section 6.1) +// for the given request and response status code. +func statusLine(req *Request, code int) string { + // Fast path: + key := code + proto11 := req.ProtoAtLeast(1, 1) + if !proto11 { + key = -key + } + statusMu.RLock() + line, ok := statusLines[key] + statusMu.RUnlock() + if ok { + return line + } + + // Slow path: proto := "HTTP/1.0" - if w.req.ProtoAtLeast(1, 1) { + if proto11 { proto = "HTTP/1.1" } codestring := strconv.Itoa(code) @@ -637,9 +846,13 @@ func (cw *chunkWriter) writeHeader(p []byte) { if !ok { text = "status code " + codestring } - io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n") - cw.header.Write(w.conn.buf) - w.conn.buf.Write(crlf) + line = proto + " " + codestring + " " + text + "\r\n" + if ok { + statusMu.Lock() + defer statusMu.Unlock() + statusLines[key] = line + } + return line } // bodyAllowed returns true if a Write is allowed for this response type. @@ -655,7 +868,7 @@ func (w *response) bodyAllowed() bool { // // Handler starts. No header has been sent. The handler can either // write a header, or just start writing. Writing before sending a header -// sends an implicity empty 200 OK header. +// sends an implicitly empty 200 OK header. // // If the handler didn't declare a Content-Length up front, we either // go into chunking mode or, if the handler finishes running before @@ -713,6 +926,7 @@ func (w *response) finishRequest() { } w.w.Flush() + putBufioWriter(w.w, w.sw) w.cw.close() w.conn.buf.Flush() @@ -742,6 +956,15 @@ func (w *response) Flush() { func (c *conn) finalFlush() { if c.buf != nil { c.buf.Flush() + + // Steal the bufio.Reader (~4KB worth of memory) and its associated + // reader for a future connection. + putBufioReader(c.buf.Reader, c.bufswr) + + // Steal the bufio.Writer (~4KB worth of memory) and its associated + // writer for a future connection. + putBufioWriter(c.buf.Writer, c.bufsww) + c.buf = nil } } @@ -948,13 +1171,16 @@ func NotFoundHandler() Handler { return HandlerFunc(NotFound) } // request for a path that doesn't begin with prefix by // replying with an HTTP 404 not found error. func StripPrefix(prefix string, h Handler) Handler { + if prefix == "" { + return h + } return HandlerFunc(func(w ResponseWriter, r *Request) { - if !strings.HasPrefix(r.URL.Path, prefix) { + if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) { + r.URL.Path = p + h.ServeHTTP(w, r) + } else { NotFound(w, r) - return } - r.URL.Path = r.URL.Path[len(prefix):] - h.ServeHTTP(w, r) }) } @@ -996,9 +1222,9 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { } // clean up but preserve trailing slash - trailing := urlStr[len(urlStr)-1] == '/' + trailing := strings.HasSuffix(urlStr, "/") urlStr = path.Clean(urlStr) - if trailing && urlStr[len(urlStr)-1] != '/' { + if trailing && !strings.HasSuffix(urlStr, "/") { urlStr += "/" } urlStr += query @@ -1266,7 +1492,7 @@ type Server struct { // TLSNextProto optionally specifies a function to take over // ownership of the provided TLS connection when an NPN - // protocol upgrade has occured. The map key is the protocol + // protocol upgrade has occurred. The map key is the protocol // name negotiated. The Handler argument should be used to // handle HTTP requests and will initialize the Request's TLS // and RemoteAddr if not already set. The connection is @@ -1337,7 +1563,6 @@ func (srv *Server) Serve(l net.Listener) error { } go c.serve() } - panic("not reached") } // ListenAndServe listens on the TCP network address addr diff --git a/src/pkg/net/http/server_test.go b/src/pkg/net/http/server_test.go index 8b4e8c6d6..e8b69f76c 100644 --- a/src/pkg/net/http/server_test.go +++ b/src/pkg/net/http/server_test.go @@ -2,9 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http +package http_test import ( + . "net/http" + "net/http/httptest" "net/url" "testing" ) @@ -76,20 +78,27 @@ func TestServeMuxHandler(t *testing.T) { }, } h, pattern := mux.Handler(r) - cs := &codeSaver{h: Header{}} - h.ServeHTTP(cs, r) - if pattern != tt.pattern || cs.code != tt.code { - t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, cs.code, pattern, tt.code, tt.pattern) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, r) + if pattern != tt.pattern || rr.Code != tt.code { + t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern) } } } -// A codeSaver is a ResponseWriter that saves the code passed to WriteHeader. -type codeSaver struct { - h Header - code int +func TestServerRedirect(t *testing.T) { + // This used to crash. It's not valid input (bad path), but it + // shouldn't crash. + rr := httptest.NewRecorder() + req := &Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Path: "not-empty-but-no-leading-slash", // bogus + }, + } + Redirect(rr, req, "", 304) + if rr.Code != 304 { + t.Errorf("Code = %d; want 304", rr.Code) + } } - -func (cs *codeSaver) Header() Header { return cs.h } -func (cs *codeSaver) Write(p []byte) (int, error) { return len(p), nil } -func (cs *codeSaver) WriteHeader(code int) { cs.code = code } diff --git a/src/pkg/net/http/sniff_test.go b/src/pkg/net/http/sniff_test.go index 8ab72ac23..106d94ec1 100644 --- a/src/pkg/net/http/sniff_test.go +++ b/src/pkg/net/http/sniff_test.go @@ -54,6 +54,7 @@ func TestDetectContentType(t *testing.T) { } func TestServerContentType(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { i, _ := strconv.Atoi(r.FormValue("i")) tt := sniffTests[i] @@ -84,6 +85,8 @@ func TestServerContentType(t *testing.T) { } func TestContentTypeWithCopy(t *testing.T) { + defer afterTest(t) + const ( input = "\n<html>\n\t<head>\n" expected = "text/html; charset=utf-8" @@ -116,6 +119,7 @@ func TestContentTypeWithCopy(t *testing.T) { } func TestSniffWriteSize(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { size, _ := strconv.Atoi(r.FormValue("size")) written, err := io.WriteString(w, strings.Repeat("a", size)) @@ -133,6 +137,11 @@ func TestSniffWriteSize(t *testing.T) { if err != nil { t.Fatalf("size %d: %v", size, err) } - res.Body.Close() + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatalf("size %d: io.Copy of body = %v", size, err) + } + if err := res.Body.Close(); err != nil { + t.Fatalf("size %d: body Close = %v", size, err) + } } } diff --git a/src/pkg/net/http/status.go b/src/pkg/net/http/status.go index 5af0b77c4..d253bd5cb 100644 --- a/src/pkg/net/http/status.go +++ b/src/pkg/net/http/status.go @@ -51,6 +51,13 @@ const ( StatusServiceUnavailable = 503 StatusGatewayTimeout = 504 StatusHTTPVersionNotSupported = 505 + + // New HTTP status codes from RFC 6585. Not exported yet in Go 1.1. + // See discussion at https://codereview.appspot.com/7678043/ + statusPreconditionRequired = 428 + statusTooManyRequests = 429 + statusRequestHeaderFieldsTooLarge = 431 + statusNetworkAuthenticationRequired = 511 ) var statusText = map[int]string{ @@ -99,6 +106,11 @@ var statusText = map[int]string{ StatusServiceUnavailable: "Service Unavailable", StatusGatewayTimeout: "Gateway Timeout", StatusHTTPVersionNotSupported: "HTTP Version Not Supported", + + statusPreconditionRequired: "Precondition Required", + statusTooManyRequests: "Too Many Requests", + statusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large", + statusNetworkAuthenticationRequired: "Network Authentication Required", } // StatusText returns a text for the HTTP status code. It returns the empty diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go index 43c6023a3..53569bcc2 100644 --- a/src/pkg/net/http/transfer.go +++ b/src/pkg/net/http/transfer.go @@ -328,12 +328,13 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { switch { case chunked(t.TransferEncoding): if noBodyExpected(t.RequestMethod) { - t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close} + t.Body = &body{Reader: eofReader, closing: t.Close} } else { t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} } - case realLength >= 0: - // TODO: limit the Content-Length. This is an easy DoS vector. + case realLength == 0: + t.Body = &body{Reader: eofReader, closing: t.Close} + case realLength > 0: t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close} default: // realLength < 0, i.e. "Content-Length" not mentioned in header @@ -342,7 +343,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t.Body = &body{Reader: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) - t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close} + t.Body = &body{Reader: eofReader, closing: t.Close} } } @@ -612,30 +613,26 @@ func (b *body) Close() error { if b.closed { return nil } - defer func() { - b.closed = true - }() - if b.hdr == nil && b.closing { + var err error + switch { + case b.hdr == nil && b.closing: // no trailer and closing the connection next. // no point in reading to EOF. - return nil - } - - // In a server request, don't continue reading from the client - // if we've already hit the maximum body size set by the - // handler. If this is set, that also means the TCP connection - // is about to be closed, so getting to the next HTTP request - // in the stream is not necessary. - if b.res != nil && b.res.requestBodyLimitHit { - return nil - } - - // Fully consume the body, which will also lead to us reading - // the trailer headers after the body, if present. - if _, err := io.Copy(ioutil.Discard, b); err != nil { - return err + case b.res != nil && b.res.requestBodyLimitHit: + // In a server request, don't continue reading from the client + // if we've already hit the maximum body size set by the + // handler. If this is set, that also means the TCP connection + // is about to be closed, so getting to the next HTTP request + // in the stream is not necessary. + case b.Reader == eofReader: + // Nothing to read. No need to io.Copy from it. + default: + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + _, err = io.Copy(ioutil.Discard, b) } - return nil + b.closed = true + return err } // parseContentLength trims whitespace from s and returns -1 if no value diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index 685d7d56c..4cd0533ff 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -17,7 +17,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net" "net/url" @@ -42,16 +41,13 @@ const DefaultMaxIdleConnsPerHost = 2 // https, and http proxies (for either http or https with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { - idleMu sync.Mutex - idleConn map[string][]*persistConn - reqMu sync.Mutex - reqConn map[*Request]*persistConn - altMu sync.RWMutex - altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper - - // TODO: tunable on global max cached connections - // TODO: tunable on timeout on cached connections - // TODO: optional pipelining + idleMu sync.Mutex + idleConn map[string][]*persistConn + idleConnCh map[string]chan *persistConn + reqMu sync.Mutex + reqConn map[*Request]*persistConn + altMu sync.RWMutex + altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the @@ -62,13 +58,24 @@ type Transport struct { // Dial specifies the dial function for creating TCP // connections. // If Dial is nil, net.Dial is used. - Dial func(net, addr string) (c net.Conn, err error) + Dial func(network, addr string) (net.Conn, error) // TLSClientConfig specifies the TLS configuration to use with // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config - DisableKeepAlives bool + // DisableKeepAlives, if true, prevents re-use of TCP connections + // between different HTTP requests. + DisableKeepAlives bool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. DisableCompression bool // MaxIdleConnsPerHost, if non-zero, controls the maximum idle @@ -81,6 +88,9 @@ type Transport struct { // writing the request (including its body, if any). This // time does not include the time to read the response body. ResponseHeaderTimeout time.Duration + + // TODO: tunable on global max cached connections + // TODO: tunable on timeout on cached connections } // ProxyFromEnvironment returns the URL of the proxy to use for a @@ -133,6 +143,9 @@ func (tr *transportRequest) extraHeaders() Header { } // RoundTrip implements the RoundTripper interface. +// +// For higher-level HTTP client support (such as handling of cookies +// and redirects), see Get, Post, and the Client type. func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { if req.URL == nil { return nil, errors.New("http: nil Request.URL") @@ -280,6 +293,17 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { max = DefaultMaxIdleConnsPerHost } t.idleMu.Lock() + select { + case t.idleConnCh[key] <- pconn: + // We're done with this pconn and somebody else is + // currently waiting for a conn of this type (they're + // actively dialing, but this conn is ready + // first). Chrome calls this socket late binding. See + // https://insouciant.org/tech/connection-management-in-chromium/ + t.idleMu.Unlock() + return true + default: + } if t.idleConn == nil { t.idleConn = make(map[string][]*persistConn) } @@ -298,8 +322,23 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { return true } +func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn { + key := cm.key() + t.idleMu.Lock() + defer t.idleMu.Unlock() + if t.idleConnCh == nil { + t.idleConnCh = make(map[string]chan *persistConn) + } + ch, ok := t.idleConnCh[key] + if !ok { + ch = make(chan *persistConn) + t.idleConnCh[key] = ch + } + return ch +} + func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { - key := cm.String() + key := cm.key() t.idleMu.Lock() defer t.idleMu.Unlock() if t.idleConn == nil { @@ -323,7 +362,6 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { return } } - panic("unreachable") } func (t *Transport) setReqConn(r *Request, pc *persistConn) { @@ -355,6 +393,37 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { return pc, nil } + type dialRes struct { + pc *persistConn + err error + } + dialc := make(chan dialRes) + go func() { + pc, err := t.dialConn(cm) + dialc <- dialRes{pc, err} + }() + + idleConnCh := t.getIdleConnCh(cm) + select { + case v := <-dialc: + // Our dial finished. + return v.pc, v.err + case pc := <-idleConnCh: + // Another request finished first and its net.Conn + // became available before our dial. Or somebody + // else's dial that they didn't use. + // But our dial is still going, so give it away + // when it finishes: + go func() { + if v := <-dialc; v.err == nil { + t.putIdleConn(v.pc) + } + }() + return pc, nil + } +} + +func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) { conn, err := t.dial("tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { @@ -367,7 +436,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { pconn := &persistConn{ t: t, - cacheKey: cm.String(), + cacheKey: cm.key(), conn: conn, reqch: make(chan requestAndChan, 50), writech: make(chan writeRequest, 50), @@ -517,6 +586,10 @@ type connectMethod struct { targetAddr string // Not used if proxy + http targetScheme (4th example in table) } +func (ck *connectMethod) key() string { + return ck.String() // TODO: use a struct type instead +} + func (ck *connectMethod) String() string { proxyStr := "" targetAddr := ck.targetAddr @@ -592,7 +665,6 @@ func remoteSideClosed(err error) bool { func (pc *persistConn) readLoop() { defer close(pc.closech) alive := true - var lastbody io.ReadCloser // last response body, if any, read on this connection for alive { pb, err := pc.br.Peek(1) @@ -611,16 +683,17 @@ func (pc *persistConn) readLoop() { rc := <-pc.reqch - // Advance past the previous response's body, if the - // caller hasn't done so. - if lastbody != nil { - lastbody.Close() // assumed idempotent - lastbody = nil - } - var resp *Response if err == nil { resp, err = ReadResponse(pc.br, rc.req) + if err == nil && resp.StatusCode == 100 { + // Skip any 100-continue for now. + // TODO(bradfitz): if rc.req had "Expect: 100-continue", + // actually block the request body write and signal the + // writeLoop now to begin sending it. (Issue 2184) For now we + // eat it, since we're never expecting one. + resp, err = ReadResponse(pc.br, rc.req) + } } hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 @@ -636,20 +709,29 @@ func (pc *persistConn) readLoop() { pc.close() err = zerr } else { - resp.Body = &readFirstCloseBoth{&discardOnCloseReadCloser{gzReader}, resp.Body} + resp.Body = &readerAndCloser{gzReader, resp.Body} } } resp.Body = &bodyEOFSignal{body: resp.Body} } - if err != nil || resp.Close || rc.req.Close { + if err != nil || resp.Close || rc.req.Close || resp.StatusCode <= 199 { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. alive = false } var waitForBodyRead chan bool if hasBody { - lastbody = resp.Body - waitForBodyRead = make(chan bool, 1) + waitForBodyRead = make(chan bool, 2) + resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error { + // Sending false here sets alive to + // false and closes the connection + // below. + waitForBodyRead <- false + return nil + } resp.Body.(*bodyEOFSignal).fn = func(err error) { alive1 := alive if err != nil { @@ -666,15 +748,6 @@ func (pc *persistConn) readLoop() { } if alive && !hasBody { - // When there's no response body, we immediately - // reuse the TCP connection (putIdleConn), but - // we need to prevent ClientConn.Read from - // closing the Response.Body on the next - // loop, otherwise it might close the body - // before the client code has had a chance to - // read it (even though it'll just be 0, EOF). - lastbody = nil - if !pc.t.putIdleConn(pc) { alive = false } @@ -868,13 +941,16 @@ func canonicalAddr(url *url.URL) string { // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most // once, right before its final (error-producing) Read or Close call -// returns. +// returns. If earlyCloseFn is non-nil and Close is called before +// io.EOF is seen, earlyCloseFn is called instead of fn, and its +// return value is the return value from Close. type bodyEOFSignal struct { - body io.ReadCloser - mu sync.Mutex // guards closed, rerr and fn - closed bool // whether Close has been called - rerr error // sticky Read error - fn func(error) // error will be nil on Read io.EOF + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) // error will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen } func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { @@ -907,6 +983,9 @@ func (es *bodyEOFSignal) Close() error { return nil } es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } err := es.body.Close() es.condfn(err) return err @@ -924,28 +1003,7 @@ func (es *bodyEOFSignal) condfn(err error) { es.fn = nil } -type readFirstCloseBoth struct { - io.ReadCloser +type readerAndCloser struct { + io.Reader io.Closer } - -func (r *readFirstCloseBoth) Close() error { - if err := r.ReadCloser.Close(); err != nil { - r.Closer.Close() - return err - } - if err := r.Closer.Close(); err != nil { - return err - } - return nil -} - -// discardOnCloseReadCloser consumes all its input on Close. -type discardOnCloseReadCloser struct { - io.ReadCloser -} - -func (d *discardOnCloseReadCloser) Close() error { - io.Copy(ioutil.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed - return d.ReadCloser.Close() -} diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index 68010e68b..9f64a6e4b 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -7,6 +7,7 @@ package http_test import ( + "bufio" "bytes" "compress/gzip" "crypto/rand" @@ -102,7 +103,7 @@ func (tcs *testConnSet) check(t *testing.T) { // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port func TestTransportKeepAlives(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -135,7 +136,7 @@ func TestTransportKeepAlives(t *testing.T) { } func TestTransportConnectionCloseOnResponse(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -186,7 +187,7 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { } func TestTransportConnectionCloseOnRequest(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -237,7 +238,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { } func TestTransportIdleCacheKeys(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -270,7 +271,7 @@ func TestTransportIdleCacheKeys(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) resch := make(chan string) gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -339,7 +340,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportServerClosingUnexpectedly(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -396,7 +397,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { // Test for http://golang.org/issue/2616 (appropriate issue number) // This fails pretty reliably with GOMAXPROCS=100 or something high. func TestStressSurpriseServerCloses(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) if testing.Short() { t.Skip("skipping test in short mode") } @@ -452,7 +453,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { // TestTransportHeadResponses verifies that we deal with Content-Lengths // with no bodies properly func TestTransportHeadResponses(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) @@ -481,7 +482,7 @@ func TestTransportHeadResponses(t *testing.T) { // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding // on responses to HEAD requests. func TestTransportHeadChunkedResponse(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) @@ -523,7 +524,7 @@ var roundTripTests = []struct { // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) const responseBody = "test response body" ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") @@ -580,7 +581,7 @@ func TestRoundTripGzip(t *testing.T) { } func TestTransportGzip(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -673,7 +674,7 @@ func TestTransportGzip(t *testing.T) { } func TestTransportProxy(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ch := make(chan string, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ch <- "real server" @@ -702,7 +703,7 @@ func TestTransportProxy(t *testing.T) { // but checks that we don't recurse forever, and checks that // Content-Encoding is removed. func TestTransportGzipRecursive(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write(rgz) @@ -729,7 +730,7 @@ func TestTransportGzipRecursive(t *testing.T) { // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) gotReqCh := make(chan bool) unblockCh := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -795,7 +796,7 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() @@ -834,7 +835,7 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { // This used to crash; http://golang.org/issue/3266 func TestTransportIdleConnCrash(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) tr := &Transport{} c := &Client{Transport: tr} @@ -864,7 +865,7 @@ func TestTransportIdleConnCrash(t *testing.T) { // which sadly lacked a triggering test. The large response body made // the old race easier to trigger. func TestIssue3644(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) const numFoos = 5000 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") @@ -892,7 +893,7 @@ func TestIssue3644(t *testing.T) { // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. func TestIssue3595(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) const deniedMsg = "sorry, denied." ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, deniedMsg, StatusUnauthorized) @@ -917,7 +918,7 @@ func TestIssue3595(t *testing.T) { // From http://golang.org/issue/4454 , // "client fails to handle requests with no body and chunked encoding" func TestChunkedNoContent(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) })) @@ -940,21 +941,38 @@ func TestChunkedNoContent(t *testing.T) { } func TestTransportConcurrency(t *testing.T) { - defer checkLeakedTransports(t) - const maxProcs = 16 - const numReqs = 500 + defer afterTest(t) + maxProcs, numReqs := 16, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%v", r.FormValue("echo")) })) defer ts.Close() - tr := &Transport{} + + var wg sync.WaitGroup + wg.Add(numReqs) + + tr := &Transport{ + Dial: func(netw, addr string) (c net.Conn, err error) { + // Due to the Transport's "socket late + // binding" (see idleConnCh in transport.go), + // the numReqs HTTP requests below can finish + // with a dial still outstanding. So count + // our dials as work too so the leak checker + // doesn't complain at us. + wg.Add(1) + defer wg.Done() + return net.Dial(netw, addr) + }, + } + defer tr.CloseIdleConnections() c := &Client{Transport: tr} reqs := make(chan string) defer close(reqs) - var wg sync.WaitGroup - wg.Add(numReqs) for i := 0; i < maxProcs*2; i++ { go func() { for req := range reqs { @@ -973,8 +991,8 @@ func TestTransportConcurrency(t *testing.T) { if string(all) != req { t.Errorf("body of req %s = %q; want %q", req, all, req) } - wg.Done() res.Body.Close() + wg.Done() } }() } @@ -985,7 +1003,7 @@ func TestTransportConcurrency(t *testing.T) { } func TestIssue4191_InfiniteGetTimeout(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { @@ -1046,7 +1064,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { @@ -1114,7 +1132,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { } func TestTransportResponseHeaderTimeout(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) if testing.Short() { t.Skip("skipping timeout test in -short mode") } @@ -1161,7 +1179,7 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { } func TestTransportCancelRequest(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) if testing.Short() { t.Skip("skipping test in -short mode") } @@ -1214,6 +1232,70 @@ func TestTransportCancelRequest(t *testing.T) { } } +// golang.org/issue/3672 -- Client can't close HTTP stream +// Calling Close on a Response.Body used to just read until EOF. +// Now it actually closes the TCP connection. +func TestTransportCloseResponseBody(t *testing.T) { + defer afterTest(t) + writeErr := make(chan error, 1) + msg := []byte("young\n") + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + for { + _, err := w.Write(msg) + if err != nil { + writeErr <- err + return + } + w.(Flusher).Flush() + } + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + defer tr.CancelRequest(req) + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + const repeats = 3 + buf := make([]byte, len(msg)*repeats) + want := bytes.Repeat(msg, repeats) + + _, err = io.ReadFull(res.Body, buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, want) { + t.Errorf("read %q; want %q", buf, want) + } + didClose := make(chan error, 1) + go func() { + didClose <- res.Body.Close() + }() + select { + case err := <-didClose: + if err != nil { + t.Errorf("Close = %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for close") + } + select { + case err := <-writeErr: + if err == nil { + t.Errorf("expected non-nil write error") + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for write error") + } +} + type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) { @@ -1227,7 +1309,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) { } func TestTransportAltProto(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) tr := &Transport{} c := &Client{Transport: tr} tr.RegisterProtocol("foo", fooProto{}) @@ -1246,7 +1328,7 @@ func TestTransportAltProto(t *testing.T) { } func TestTransportNoHost(t *testing.T) { - defer checkLeakedTransports(t) + defer afterTest(t) tr := &Transport{} _, err := tr.RoundTrip(&Request{ Header: make(Header), @@ -1260,6 +1342,172 @@ func TestTransportNoHost(t *testing.T) { } } +func TestTransportSocketLateBinding(t *testing.T) { + defer afterTest(t) + + mux := NewServeMux() + fooGate := make(chan bool, 1) + mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { + w.Header().Set("foo-ipport", r.RemoteAddr) + w.(Flusher).Flush() + <-fooGate + }) + mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { + w.Header().Set("bar-ipport", r.RemoteAddr) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + dialGate := make(chan bool, 1) + tr := &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + <-dialGate + return net.Dial(n, addr) + }, + DisableKeepAlives: false, + } + defer tr.CloseIdleConnections() + c := &Client{ + Transport: tr, + } + + dialGate <- true // only allow one dial + fooRes, err := c.Get(ts.URL + "/foo") + if err != nil { + t.Fatal(err) + } + fooAddr := fooRes.Header.Get("foo-ipport") + if fooAddr == "" { + t.Fatal("No addr on /foo request") + } + time.AfterFunc(200*time.Millisecond, func() { + // let the foo response finish so we can use its + // connection for /bar + fooGate <- true + io.Copy(ioutil.Discard, fooRes.Body) + fooRes.Body.Close() + }) + + barRes, err := c.Get(ts.URL + "/bar") + if err != nil { + t.Fatal(err) + } + barAddr := barRes.Header.Get("bar-ipport") + if barAddr != fooAddr { + t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) + } + barRes.Body.Close() + dialGate <- true +} + +// Issue 2184 +func TestTransportReading100Continue(t *testing.T) { + defer afterTest(t) + + const numReqs = 5 + reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } + reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) } + + send100Response := func(w *io.PipeWriter, r *io.PipeReader) { + defer w.Close() + defer r.Close() + br := bufio.NewReader(r) + n := 0 + for { + n++ + req, err := ReadRequest(br) + if err == io.EOF { + return + } + if err != nil { + t.Error(err) + return + } + slurp, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("Server request body slurp: %v", err) + return + } + id := req.Header.Get("Request-Id") + resCode := req.Header.Get("X-Want-Response-Code") + if resCode == "" { + resCode = "100 Continue" + if string(slurp) != reqBody(n) { + t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) + } + } + body := fmt.Sprintf("Response number %d", n) + v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s +Date: Thu, 28 Feb 2013 17:55:41 GMT + +HTTP/1.1 200 OK +Content-Type: text/html +Echo-Request-Id: %s +Content-Length: %d + +%s`, resCode, id, len(body), body), "\n", "\r\n", -1)) + w.Write(v) + if id == reqID(numReqs) { + return + } + } + + } + + tr := &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + sr, sw := io.Pipe() // server read/write + cr, cw := io.Pipe() // client read/write + conn := &rwTestConn{ + Reader: cr, + Writer: sw, + closeFunc: func() error { + sw.Close() + cw.Close() + return nil + }, + } + go send100Response(cw, sr) + return conn, nil + }, + DisableKeepAlives: false, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + testResponse := func(req *Request, name string, wantCode int) { + res, err := c.Do(req) + if err != nil { + t.Fatalf("%s: Do: %v", name, err) + } + if res.StatusCode != wantCode { + t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) + } + if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { + t.Errorf("%s: response id %q != request id %q", name, idBack, id) + } + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("%s: Slurp error: %v", name, err) + } + } + + // Few 100 responses, making sure we're not off-by-one. + for i := 1; i <= numReqs; i++ { + req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) + req.Header.Set("Request-Id", reqID(i)) + testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) + } + + // And some other informational 1xx but non-100 responses, to test + // we return them but don't re-use the connection. + for i := 1; i <= numReqs; i++ { + req, _ := NewRequest("POST", "http://other.tld/", strings.NewReader(reqBody(i))) + req.Header.Set("X-Want-Response-Code", "123 Sesame Street") + testResponse(req, fmt.Sprintf("123, %d/%d", i, numReqs), 123) + } +} + type proxyFromEnvTest struct { req string // URL to fetch; blank means "http://example.com" env string diff --git a/src/pkg/net/http/z_last_test.go b/src/pkg/net/http/z_last_test.go index 44095a8d9..2161db736 100644 --- a/src/pkg/net/http/z_last_test.go +++ b/src/pkg/net/http/z_last_test.go @@ -7,43 +7,81 @@ package http_test import ( "net/http" "runtime" + "sort" "strings" "testing" "time" ) +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for _, g := range strings.Split(string(buf), "\n\n") { + sl := strings.SplitN(g, "\n", 2) + if len(sl) != 2 { + continue + } + stack := strings.TrimSpace(sl[1]) + if stack == "" || + strings.Contains(stack, "created by net.newPollServer") || + strings.Contains(stack, "created by net.startServer") || + strings.Contains(stack, "created by testing.RunTests") || + strings.Contains(stack, "closeWriteAndWait") || + strings.Contains(stack, "testing.Main(") || + // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, stack) + } + sort.Strings(gs) + return +} + // Verify the other tests didn't leave any goroutines running. // This is in a file named z_last_test.go so it sorts at the end. func TestGoroutinesRunning(t *testing.T) { - n := runtime.NumGoroutine() + if testing.Short() { + t.Skip("not counting goroutines for leakage in -short mode") + } + gs := interestingGoroutines() + + n := 0 + stackCount := make(map[string]int) + for _, g := range gs { + stackCount[g]++ + n++ + } + t.Logf("num goroutines = %d", n) - if n > 20 { - // Currently 14 on Linux (blocked in epoll_wait, - // waiting for on fds that are closed?), but give some - // slop for now. - buf := make([]byte, 1<<20) - buf = buf[:runtime.Stack(buf, true)] - t.Errorf("Too many goroutines:\n%s", buf) + if n > 0 { + t.Error("Too many goroutines.") + for stack, count := range stackCount { + t.Logf("%d instances of:\n%s", count, stack) + } } } -func checkLeakedTransports(t *testing.T) { +func afterTest(t *testing.T) { http.DefaultTransport.(*http.Transport).CloseIdleConnections() if testing.Short() { return } - buf := make([]byte, 1<<20) - var stacks string var bad string badSubstring := map[string]string{ ").readLoop(": "a Transport", ").writeLoop(": "a Transport", "created by net/http/httptest.(*Server).Start": "an httptest.Server", "timeoutHandler": "a TimeoutHandler", + "net.(*netFD).connect(": "a timing out dial", + ").noteClientGone(": "a closenotifier sender", } + var stacks string for i := 0; i < 4; i++ { bad = "" - stacks = string(buf[:runtime.Stack(buf, true)]) + stacks = strings.Join(interestingGoroutines(), "\n\n") for substr, what := range badSubstring { if strings.Contains(stacks, substr) { bad = what diff --git a/src/pkg/net/interface_bsd.go b/src/pkg/net/interface_bsd.go index f58065a85..716b60a97 100644 --- a/src/pkg/net/interface_bsd.go +++ b/src/pkg/net/interface_bsd.go @@ -171,7 +171,6 @@ func newAddr(ifi *Interface, m *syscall.InterfaceAddrMessage) (Addr, error) { // the interface index in the interface-local or link- // local address as the kernel-internal form. if ifa.IP.IsLinkLocalUnicast() { - ifa.Zone = ifi.Name ifa.IP[2], ifa.IP[3] = 0, 0 } } diff --git a/src/pkg/net/interface_darwin.go b/src/pkg/net/interface_darwin.go index 83e483ba2..ad0937db0 100644 --- a/src/pkg/net/interface_darwin.go +++ b/src/pkg/net/interface_darwin.go @@ -50,11 +50,10 @@ func newMulticastAddr(ifi *Interface, m *syscall.InterfaceMulticastAddrMessage) case *syscall.SockaddrInet6: ifma := &IPAddr{IP: make(IP, IPv6len)} copy(ifma.IP, sa.Addr[:]) - // NOTE: KAME based IPv6 protcol stack usually embeds + // NOTE: KAME based IPv6 protocol stack usually embeds // the interface index in the interface-local or link- // local address as the kernel-internal form. if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() { - ifma.Zone = ifi.Name ifma.IP[2], ifma.IP[3] = 0, 0 } ifmat = append(ifmat, ifma.toAddr()) diff --git a/src/pkg/net/interface_freebsd.go b/src/pkg/net/interface_freebsd.go index 1bf5ae72b..5df767910 100644 --- a/src/pkg/net/interface_freebsd.go +++ b/src/pkg/net/interface_freebsd.go @@ -50,11 +50,10 @@ func newMulticastAddr(ifi *Interface, m *syscall.InterfaceMulticastAddrMessage) case *syscall.SockaddrInet6: ifma := &IPAddr{IP: make(IP, IPv6len)} copy(ifma.IP, sa.Addr[:]) - // NOTE: KAME based IPv6 protcol stack usually embeds + // NOTE: KAME based IPv6 protocol stack usually embeds // the interface index in the interface-local or link- // local address as the kernel-internal form. if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() { - ifma.Zone = ifi.Name ifma.IP[2], ifma.IP[3] = 0, 0 } ifmat = append(ifmat, ifma.toAddr()) diff --git a/src/pkg/net/interface_linux.go b/src/pkg/net/interface_linux.go index e66daef06..1207c0f26 100644 --- a/src/pkg/net/interface_linux.go +++ b/src/pkg/net/interface_linux.go @@ -156,9 +156,6 @@ func newAddr(ifi *Interface, ifam *syscall.IfAddrmsg, attrs []syscall.NetlinkRou case syscall.AF_INET6: ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv6len)} copy(ifa.IP, a.Value[:]) - if ifam.Scope == syscall.RT_SCOPE_HOST || ifam.Scope == syscall.RT_SCOPE_LINK { - ifa.Zone = ifi.Name - } return ifa } } @@ -229,9 +226,6 @@ func parseProcNetIGMP6(path string, ifi *Interface) []Addr { b[i/2], _ = xtoi2(f[2][i:i+2], 0) } ifma := IPAddr{IP: IP{b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]}} - if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() { - ifma.Zone = ifi.Name - } ifmat = append(ifmat, ifma.toAddr()) } } diff --git a/src/pkg/net/interface_test.go b/src/pkg/net/interface_test.go index 7fb342818..e31894abf 100644 --- a/src/pkg/net/interface_test.go +++ b/src/pkg/net/interface_test.go @@ -25,6 +25,32 @@ func loopbackInterface() *Interface { return nil } +// ipv6LinkLocalUnicastAddr returns an IPv6 link-local unicast address +// on the given network interface for tests. It returns "" if no +// suitable address is found. +func ipv6LinkLocalUnicastAddr(ifi *Interface) string { + if ifi == nil { + return "" + } + ifat, err := ifi.Addrs() + if err != nil { + return "" + } + for _, ifa := range ifat { + switch ifa := ifa.(type) { + case *IPAddr: + if ifa.IP.To4() == nil && ifa.IP.IsLinkLocalUnicast() { + return ifa.IP.String() + } + case *IPNet: + if ifa.IP.To4() == nil && ifa.IP.IsLinkLocalUnicast() { + return ifa.IP.String() + } + } + } + return "" +} + func TestInterfaces(t *testing.T) { ift, err := Interfaces() if err != nil { @@ -81,9 +107,9 @@ func testInterfaceMulticastAddrs(t *testing.T, ifi *Interface) { func testAddrs(t *testing.T, ifat []Addr) { for _, ifa := range ifat { - switch v := ifa.(type) { + switch ifa := ifa.(type) { case *IPAddr, *IPNet: - if v == nil { + if ifa == nil { t.Errorf("\tunexpected value: %v", ifa) } else { t.Logf("\tinterface address %q", ifa.String()) @@ -96,9 +122,9 @@ func testAddrs(t *testing.T, ifat []Addr) { func testMulticastAddrs(t *testing.T, ifmat []Addr) { for _, ifma := range ifmat { - switch v := ifma.(type) { + switch ifma := ifma.(type) { case *IPAddr: - if v == nil { + if ifma == nil { t.Errorf("\tunexpected value: %v", ifma) } else { t.Logf("\tjoined group address %q", ifma.String()) diff --git a/src/pkg/net/interface_unix_test.go b/src/pkg/net/interface_unix_test.go index 6dbd6e6e7..0a453c095 100644 --- a/src/pkg/net/interface_unix_test.go +++ b/src/pkg/net/interface_unix_test.go @@ -41,8 +41,11 @@ func (ti *testInterface) teardown() error { } func TestPointToPointInterface(t *testing.T) { - switch runtime.GOOS { - case "darwin": + if testing.Short() { + t.Skip("skipping test in short mode") + } + switch { + case runtime.GOOS == "darwin": t.Skipf("skipping read test on %q", runtime.GOOS) } if os.Getuid() != 0 { @@ -90,6 +93,9 @@ func TestPointToPointInterface(t *testing.T) { } func TestInterfaceArrivalAndDeparture(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode") + } if os.Getuid() != 0 { t.Skip("skipping test; must be root") } diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go index d588e3a42..0e42da216 100644 --- a/src/pkg/net/ip.go +++ b/src/pkg/net/ip.go @@ -36,7 +36,6 @@ type IPMask []byte type IPNet struct { IP IP // network number Mask IPMask // network mask - Zone string // IPv6 scoped addressing zone } // IPv4 returns the IP address (in 16-byte form) of the @@ -223,7 +222,6 @@ func (ip IP) DefaultMask() IPMask { default: return classCMask } - return nil // not reached } func allFF(b []byte) bool { @@ -433,6 +431,9 @@ func (n *IPNet) Contains(ip IP) bool { return true } +// Network returns the address's network name, "ip+net". +func (n *IPNet) Network() string { return "ip+net" } + // String returns the CIDR notation of n like "192.168.100.1/24" // or "2001:DB8::/48" as defined in RFC 4632 and RFC 4291. // If the mask is not in the canonical form, it returns the @@ -451,9 +452,6 @@ func (n *IPNet) String() string { return nn.String() + "/" + itod(uint(l)) } -// Network returns the address's network name, "ip+net". -func (n *IPNet) Network() string { return "ip+net" } - // Parse IPv4 address (d.d.d.d). func parseIPv4(s string) IP { var p [IPv4len]byte @@ -485,26 +483,26 @@ func parseIPv4(s string) IP { return IPv4(p[0], p[1], p[2], p[3]) } -// Parse IPv6 address. Many forms. -// The basic form is a sequence of eight colon-separated -// 16-bit hex numbers separated by colons, -// as in 0123:4567:89ab:cdef:0123:4567:89ab:cdef. -// Two exceptions: -// * A run of zeros can be replaced with "::". -// * The last 32 bits can be in IPv4 form. -// Thus, ::ffff:1.2.3.4 is the IPv4 address 1.2.3.4. -func parseIPv6(s string) IP { - p := make(IP, IPv6len) +// parseIPv6 parses s as a literal IPv6 address described in RFC 4291 +// and RFC 5952. It can also parse a literal scoped IPv6 address with +// zone identifier which is described in RFC 4007 when zoneAllowed is +// true. +func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) { + ip = make(IP, IPv6len) ellipsis := -1 // position of ellipsis in p i := 0 // index in string s + if zoneAllowed { + s, zone = splitHostZone(s) + } + // Might have leading ellipsis if len(s) >= 2 && s[0] == ':' && s[1] == ':' { ellipsis = 0 i = 2 // Might be only ellipsis if i == len(s) { - return p + return ip, zone } } @@ -514,35 +512,35 @@ func parseIPv6(s string) IP { // Hex number. n, i1, ok := xtoi(s, i) if !ok || n > 0xFFFF { - return nil + return nil, zone } // If followed by dot, might be in trailing IPv4. if i1 < len(s) && s[i1] == '.' { if ellipsis < 0 && j != IPv6len-IPv4len { // Not the right place. - return nil + return nil, zone } if j+IPv4len > IPv6len { // Not enough room. - return nil + return nil, zone } - p4 := parseIPv4(s[i:]) - if p4 == nil { - return nil + ip4 := parseIPv4(s[i:]) + if ip4 == nil { + return nil, zone } - p[j] = p4[12] - p[j+1] = p4[13] - p[j+2] = p4[14] - p[j+3] = p4[15] + ip[j] = ip4[12] + ip[j+1] = ip4[13] + ip[j+2] = ip4[14] + ip[j+3] = ip4[15] i = len(s) j += IPv4len break } // Save this 16-bit chunk. - p[j] = byte(n >> 8) - p[j+1] = byte(n) + ip[j] = byte(n >> 8) + ip[j+1] = byte(n) j += 2 // Stop at end of string. @@ -553,14 +551,14 @@ func parseIPv6(s string) IP { // Otherwise must be followed by colon and more. if s[i] != ':' || i+1 == len(s) { - return nil + return nil, zone } i++ // Look for ellipsis. if s[i] == ':' { if ellipsis >= 0 { // already have one - return nil + return nil, zone } ellipsis = j if i++; i == len(s) { // can be at end @@ -571,23 +569,23 @@ func parseIPv6(s string) IP { // Must have used entire string. if i != len(s) { - return nil + return nil, zone } // If didn't parse enough, expand ellipsis. if j < IPv6len { if ellipsis < 0 { - return nil + return nil, zone } n := IPv6len - j for k := j - 1; k >= ellipsis; k-- { - p[k+n] = p[k] + ip[k+n] = ip[k] } for k := ellipsis + n - 1; k >= ellipsis; k-- { - p[k] = 0 + ip[k] = 0 } } - return p + return ip, zone } // A ParseError represents a malformed text string and the type of string that was expected. @@ -600,26 +598,17 @@ func (e *ParseError) Error() string { return "invalid " + e.Type + ": " + e.Text } -func parseIP(s string) IP { - if p := parseIPv4(s); p != nil { - return p - } - if p := parseIPv6(s); p != nil { - return p - } - return nil -} - // ParseIP parses s as an IP address, returning the result. // The string s can be in dotted decimal ("74.125.19.99") // or IPv6 ("2001:4860:0:2001::68") form. // If s is not a valid textual representation of an IP address, // ParseIP returns nil. func ParseIP(s string) IP { - if p := parseIPv4(s); p != nil { - return p + if ip := parseIPv4(s); ip != nil { + return ip } - return parseIPv6(s) + ip, _ := parseIPv6(s, false) + return ip } // ParseCIDR parses s as a CIDR notation IP address and mask, @@ -634,15 +623,15 @@ func ParseCIDR(s string) (IP, *IPNet, error) { if i < 0 { return nil, nil, &ParseError{"CIDR address", s} } - ipstr, maskstr := s[:i], s[i+1:] + addr, mask := s[:i], s[i+1:] iplen := IPv4len - ip := parseIPv4(ipstr) + ip := parseIPv4(addr) if ip == nil { iplen = IPv6len - ip = parseIPv6(ipstr) + ip, _ = parseIPv6(addr, false) } - n, i, ok := dtoi(maskstr, 0) - if ip == nil || !ok || i != len(maskstr) || n < 0 || n > 8*iplen { + n, i, ok := dtoi(mask, 0) + if ip == nil || !ok || i != len(mask) || n < 0 || n > 8*iplen { return nil, nil, &ParseError{"CIDR address", s} } m := CIDRMask(n, 8*iplen) diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go index f8b7f067f..16f30d446 100644 --- a/src/pkg/net/ip_test.go +++ b/src/pkg/net/ip_test.go @@ -5,23 +5,12 @@ package net import ( - "bytes" "reflect" "runtime" "testing" ) -func isEqual(a, b []byte) bool { - if a == nil && b == nil { - return true - } - if a == nil || b == nil { - return false - } - return bytes.Equal(a, b) -} - -var parseiptests = []struct { +var parseIPTests = []struct { in string out IP }{ @@ -33,22 +22,23 @@ var parseiptests = []struct { {"::ffff:127.0.0.1", IPv4(127, 0, 0, 1)}, {"2001:4860:0:2001::68", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}}, {"::ffff:4a7d:1363", IPv4(74, 125, 19, 99)}, + {"fe80::1%lo0", nil}, + {"fe80::1%911", nil}, {"", nil}, } func TestParseIP(t *testing.T) { - for _, tt := range parseiptests { - if out := ParseIP(tt.in); !isEqual(out, tt.out) { + for _, tt := range parseIPTests { + if out := ParseIP(tt.in); !reflect.DeepEqual(out, tt.out) { t.Errorf("ParseIP(%q) = %v, want %v", tt.in, out, tt.out) } } } -var ipstringtests = []struct { +var ipStringTests = []struct { in IP - out string + out string // see RFC 5952 }{ - // cf. RFC 5952 (A Recommendation for IPv6 Address Text Representation) {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0x1, 0x23, 0, 0x12, 0, 0x1}, "2001:db8::123:12:1"}, {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1}, "2001:db8::1"}, {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0x1, 0, 0, 0, 0x1, 0, 0, 0, 0x1}, "2001:db8:0:1:0:1:0:1"}, @@ -61,14 +51,14 @@ var ipstringtests = []struct { } func TestIPString(t *testing.T) { - for _, tt := range ipstringtests { + for _, tt := range ipStringTests { if out := tt.in.String(); out != tt.out { t.Errorf("IP.String(%v) = %q, want %q", tt.in, out, tt.out) } } } -var ipmasktests = []struct { +var ipMaskTests = []struct { in IP mask IPMask out IP @@ -82,14 +72,14 @@ var ipmasktests = []struct { } func TestIPMask(t *testing.T) { - for _, tt := range ipmasktests { + for _, tt := range ipMaskTests { if out := tt.in.Mask(tt.mask); out == nil || !tt.out.Equal(out) { t.Errorf("IP(%v).Mask(%v) = %v, want %v", tt.in, tt.mask, out, tt.out) } } } -var ipmaskstringtests = []struct { +var ipMaskStringTests = []struct { in IPMask out string }{ @@ -101,14 +91,14 @@ var ipmaskstringtests = []struct { } func TestIPMaskString(t *testing.T) { - for _, tt := range ipmaskstringtests { + for _, tt := range ipMaskStringTests { if out := tt.in.String(); out != tt.out { t.Errorf("IPMask.String(%v) = %q, want %q", tt.in, out, tt.out) } } } -var parsecidrtests = []struct { +var parseCIDRTests = []struct { in string ip IP net *IPNet @@ -138,18 +128,18 @@ var parsecidrtests = []struct { } func TestParseCIDR(t *testing.T) { - for _, tt := range parsecidrtests { + for _, tt := range parseCIDRTests { ip, net, err := ParseCIDR(tt.in) if !reflect.DeepEqual(err, tt.err) { t.Errorf("ParseCIDR(%q) = %v, %v; want %v, %v", tt.in, ip, net, tt.ip, tt.net) } - if err == nil && (!tt.ip.Equal(ip) || !tt.net.IP.Equal(net.IP) || !isEqual(net.Mask, tt.net.Mask)) { - t.Errorf("ParseCIDR(%q) = %v, {%v, %v}; want %v {%v, %v}", tt.in, ip, net.IP, net.Mask, tt.ip, tt.net.IP, tt.net.Mask) + if err == nil && (!tt.ip.Equal(ip) || !tt.net.IP.Equal(net.IP) || !reflect.DeepEqual(net.Mask, tt.net.Mask)) { + t.Errorf("ParseCIDR(%q) = %v, {%v, %v}; want %v, {%v, %v}", tt.in, ip, net.IP, net.Mask, tt.ip, tt.net.IP, tt.net.Mask) } } } -var ipnetcontainstests = []struct { +var ipNetContainsTests = []struct { ip IP net *IPNet ok bool @@ -165,14 +155,14 @@ var ipnetcontainstests = []struct { } func TestIPNetContains(t *testing.T) { - for _, tt := range ipnetcontainstests { + for _, tt := range ipNetContainsTests { if ok := tt.net.Contains(tt.ip); ok != tt.ok { t.Errorf("IPNet(%v).Contains(%v) = %v, want %v", tt.net, tt.ip, ok, tt.ok) } } } -var ipnetstringtests = []struct { +var ipNetStringTests = []struct { in *IPNet out string }{ @@ -183,14 +173,14 @@ var ipnetstringtests = []struct { } func TestIPNetString(t *testing.T) { - for _, tt := range ipnetstringtests { + for _, tt := range ipNetStringTests { if out := tt.in.String(); out != tt.out { t.Errorf("IPNet.String(%v) = %q, want %q", tt.in, out, tt.out) } } } -var cidrmasktests = []struct { +var cidrMaskTests = []struct { ones int bits int out IPMask @@ -210,8 +200,8 @@ var cidrmasktests = []struct { } func TestCIDRMask(t *testing.T) { - for _, tt := range cidrmasktests { - if out := CIDRMask(tt.ones, tt.bits); !isEqual(out, tt.out) { + for _, tt := range cidrMaskTests { + if out := CIDRMask(tt.ones, tt.bits); !reflect.DeepEqual(out, tt.out) { t.Errorf("CIDRMask(%v, %v) = %v, want %v", tt.ones, tt.bits, out, tt.out) } } @@ -229,7 +219,7 @@ var ( v4maskzero = IPMask{0, 0, 0, 0} ) -var networknumberandmasktests = []struct { +var networkNumberAndMaskTests = []struct { in IPNet out IPNet }{ @@ -251,75 +241,90 @@ var networknumberandmasktests = []struct { } func TestNetworkNumberAndMask(t *testing.T) { - for _, tt := range networknumberandmasktests { + for _, tt := range networkNumberAndMaskTests { ip, m := networkNumberAndMask(&tt.in) out := &IPNet{IP: ip, Mask: m} if !reflect.DeepEqual(&tt.out, out) { - t.Errorf("networkNumberAndMask(%v) = %v; want %v", tt.in, out, &tt.out) + t.Errorf("networkNumberAndMask(%v) = %v, want %v", tt.in, out, &tt.out) } } } -var splitjointests = []struct { - Host string - Port string - Join string +var splitJoinTests = []struct { + host string + port string + join string }{ {"www.google.com", "80", "www.google.com:80"}, {"127.0.0.1", "1234", "127.0.0.1:1234"}, {"::1", "80", "[::1]:80"}, - {"google.com", "https%foo", "google.com:https%foo"}, // Go 1.0 behavior + {"fe80::1%lo0", "80", "[fe80::1%lo0]:80"}, + {"localhost%lo0", "80", "[localhost%lo0]:80"}, {"", "0", ":0"}, - {"127.0.0.1", "", "127.0.0.1:"}, // Go 1.0 behaviour - {"www.google.com", "", "www.google.com:"}, // Go 1.0 behaviour + + {"google.com", "https%foo", "google.com:https%foo"}, // Go 1.0 behavior + {"127.0.0.1", "", "127.0.0.1:"}, // Go 1.0 behaviour + {"www.google.com", "", "www.google.com:"}, // Go 1.0 behaviour } -var splitfailuretests = []struct { - HostPort string - Err string +var splitFailureTests = []struct { + hostPort string + err string }{ {"www.google.com", "missing port in address"}, {"127.0.0.1", "missing port in address"}, {"[::1]", "missing port in address"}, + {"[fe80::1%lo0]", "missing port in address"}, + {"[localhost%lo0]", "missing port in address"}, + {"localhost%lo0", "missing port in address"}, + {"::1", "too many colons in address"}, + {"fe80::1%lo0", "too many colons in address"}, + {"fe80::1%lo0:80", "too many colons in address"}, + + {"localhost%lo0:80", "missing brackets in address"}, // Test cases that didn't fail in Go 1.0 + {"[foo:bar]", "missing port in address"}, {"[foo:bar]baz", "missing port in address"}, - {"[foo]:[bar]:baz", "too many colons in address"}, {"[foo]bar:baz", "missing port in address"}, + + {"[foo]:[bar]:baz", "too many colons in address"}, + {"[foo]:[bar]baz", "unexpected '[' in address"}, {"foo[bar]:baz", "unexpected '[' in address"}, + {"foo]bar:baz", "unexpected ']' in address"}, } func TestSplitHostPort(t *testing.T) { - for _, tt := range splitjointests { - if host, port, err := SplitHostPort(tt.Join); host != tt.Host || port != tt.Port || err != nil { - t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.Join, host, port, err, tt.Host, tt.Port) + for _, tt := range splitJoinTests { + if host, port, err := SplitHostPort(tt.join); host != tt.host || port != tt.port || err != nil { + t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.join, host, port, err, tt.host, tt.port) } } - for _, tt := range splitfailuretests { - if _, _, err := SplitHostPort(tt.HostPort); err == nil { - t.Errorf("SplitHostPort(%q) should have failed", tt.HostPort) + for _, tt := range splitFailureTests { + if _, _, err := SplitHostPort(tt.hostPort); err == nil { + t.Errorf("SplitHostPort(%q) should have failed", tt.hostPort) } else { e := err.(*AddrError) - if e.Err != tt.Err { - t.Errorf("SplitHostPort(%q) = _, _, %q; want %q", tt.HostPort, e.Err, tt.Err) + if e.Err != tt.err { + t.Errorf("SplitHostPort(%q) = _, _, %q; want %q", tt.hostPort, e.Err, tt.err) } } } } func TestJoinHostPort(t *testing.T) { - for _, tt := range splitjointests { - if join := JoinHostPort(tt.Host, tt.Port); join != tt.Join { - t.Errorf("JoinHostPort(%q, %q) = %q; want %q", tt.Host, tt.Port, join, tt.Join) + for _, tt := range splitJoinTests { + if join := JoinHostPort(tt.host, tt.port); join != tt.join { + t.Errorf("JoinHostPort(%q, %q) = %q; want %q", tt.host, tt.port, join, tt.join) } } } -var ipaftests = []struct { +var ipAddrFamilyTests = []struct { in IP af4 bool af6 bool @@ -342,7 +347,7 @@ var ipaftests = []struct { } func TestIPAddrFamily(t *testing.T) { - for _, tt := range ipaftests { + for _, tt := range ipAddrFamilyTests { if af := tt.in.To4() != nil; af != tt.af4 { t.Errorf("verifying IPv4 address family for %q = %v, want %v", tt.in, af, tt.af4) } @@ -352,7 +357,7 @@ func TestIPAddrFamily(t *testing.T) { } } -var ipscopetests = []struct { +var ipAddrScopeTests = []struct { scope func(IP) bool in IP ok bool @@ -393,7 +398,7 @@ func name(f interface{}) string { } func TestIPAddrScope(t *testing.T) { - for _, tt := range ipscopetests { + for _, tt := range ipAddrScopeTests { if ok := tt.scope(tt.in); ok != tt.ok { t.Errorf("%s(%q) = %v, want %v", name(tt.scope), tt.in, ok, tt.ok) } diff --git a/src/pkg/net/ipraw_test.go b/src/pkg/net/ipraw_test.go index 65defc7ea..12c199d1c 100644 --- a/src/pkg/net/ipraw_test.go +++ b/src/pkg/net/ipraw_test.go @@ -2,32 +2,36 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build !plan9 - package net import ( "bytes" "errors" + "fmt" "os" "reflect" "testing" "time" ) -var resolveIPAddrTests = []struct { +type resolveIPAddrTest struct { net string litAddr string addr *IPAddr err error -}{ +} + +var resolveIPAddrTests = []resolveIPAddrTest{ {"ip", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, {"ip4", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, {"ip4:icmp", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, {"ip", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, {"ip6", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, - {"ip6:icmp", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, + {"ip6:ipv6-icmp", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, + + {"ip", "::1%en0", &IPAddr{IP: ParseIP("::1"), Zone: "en0"}, nil}, + {"ip6", "::1%911", &IPAddr{IP: ParseIP("::1"), Zone: "911"}, nil}, {"", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, // Go 1.0 behavior {"", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, // Go 1.0 behavior @@ -37,13 +41,22 @@ var resolveIPAddrTests = []struct { {"tcp", "1.2.3.4:123", nil, UnknownNetworkError("tcp")}, } +func init() { + if ifi := loopbackInterface(); ifi != nil { + index := fmt.Sprintf("%v", ifi.Index) + resolveIPAddrTests = append(resolveIPAddrTests, []resolveIPAddrTest{ + {"ip6", "fe80::1%" + ifi.Name, &IPAddr{IP: ParseIP("fe80::1"), Zone: zoneToString(ifi.Index)}, nil}, + {"ip6", "fe80::1%" + index, &IPAddr{IP: ParseIP("fe80::1"), Zone: index}, nil}, + }...) + } +} + func TestResolveIPAddr(t *testing.T) { for _, tt := range resolveIPAddrTests { addr, err := ResolveIPAddr(tt.net, tt.litAddr) if err != tt.err { - t.Fatalf("ResolveIPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err) - } - if !reflect.DeepEqual(addr, tt.addr) { + condFatalf(t, "ResolveIPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err) + } else if !reflect.DeepEqual(addr, tt.addr) { t.Fatalf("got %#v; expected %#v", addr, tt.addr) } } @@ -339,3 +352,19 @@ func TestIPConnLocalName(t *testing.T) { } } } + +func TestIPConnRemoteName(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") + } + + raddr := &IPAddr{IP: IPv4(127, 0, 0, 10).To4()} + c, err := DialIP("ip:tcp", &IPAddr{IP: IPv4(127, 0, 0, 1)}, raddr) + if err != nil { + t.Fatalf("DialIP failed: %v", err) + } + defer c.Close() + if !reflect.DeepEqual(raddr, c.RemoteAddr()) { + t.Fatalf("got %#v, expected %#v", c.RemoteAddr(), raddr) + } +} diff --git a/src/pkg/net/iprawsock.go b/src/pkg/net/iprawsock.go index daccba366..0be94eb70 100644 --- a/src/pkg/net/iprawsock.go +++ b/src/pkg/net/iprawsock.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Raw IP sockets - package net // IPAddr represents the address of an IP end point. @@ -19,12 +17,15 @@ func (a *IPAddr) String() string { if a == nil { return "<nil>" } + if a.Zone != "" { + return a.IP.String() + "%" + a.Zone + } return a.IP.String() } -// ResolveIPAddr parses addr as an IP address and resolves domain -// names to numeric addresses on the network net, which must be -// "ip", "ip4" or "ip6". +// ResolveIPAddr parses addr as an IP address of the form "host" or +// "ipv6-host%zone" and resolves the domain name on the network net, +// which must be "ip", "ip4" or "ip6". func ResolveIPAddr(net, addr string) (*IPAddr, error) { if net == "" { // a hint wildcard for Go 1.0 undocumented behavior net = "ip" diff --git a/src/pkg/net/iprawsock_plan9.go b/src/pkg/net/iprawsock_plan9.go index 88e3b2c60..e62d116b8 100644 --- a/src/pkg/net/iprawsock_plan9.go +++ b/src/pkg/net/iprawsock_plan9.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Raw IP sockets for Plan 9 - package net import ( @@ -34,7 +32,7 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { } // ReadMsgIP reads a packet from c, copying the payload into b and the -// associdated out-of-band data into oob. It returns the number of +// associated out-of-band data into oob. It returns the number of // bytes copied into b, the number of bytes copied into oob, the flags // that were set on the packet and the source address of the packet. func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) { @@ -76,7 +74,7 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, } // ListenIP listens for incoming IP packets addressed to the local -// address laddr. The returned connection c's ReadFrom and WriteTo +// address laddr. The returned connection's ReadFrom and WriteTo // methods can be used to receive and send IP packets with per-packet // addressing. func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { diff --git a/src/pkg/net/iprawsock_posix.go b/src/pkg/net/iprawsock_posix.go index 2ef4db19c..caeeb4653 100644 --- a/src/pkg/net/iprawsock_posix.go +++ b/src/pkg/net/iprawsock_posix.go @@ -4,8 +4,6 @@ // +build darwin freebsd linux netbsd openbsd windows -// Raw IP sockets for POSIX - package net import ( @@ -51,8 +49,8 @@ func (a *IPAddr) toAddr() sockaddr { return a } -// IPConn is the implementation of the Conn and PacketConn -// interfaces for IP network connections. +// IPConn is the implementation of the Conn and PacketConn interfaces +// for IP network connections. type IPConn struct { conn } @@ -98,7 +96,7 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { } // ReadMsgIP reads a packet from c, copying the payload into b and the -// associdated out-of-band data into oob. It returns the number of +// associated out-of-band data into oob. It returns the number of // bytes copied into b, the number of bytes copied into oob, the flags // that were set on the packet and the source address of the packet. func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) { @@ -116,12 +114,13 @@ func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err return } -// WriteToIP writes an IP packet to addr via c, copying the payload from b. +// WriteToIP writes an IP packet to addr via c, copying the payload +// from b. // -// WriteToIP can be made to time out and return -// an error with Timeout() == true after a fixed time limit; -// see SetDeadline and SetWriteDeadline. -// On packet-oriented connections, write timeouts are rare. +// WriteToIP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetWriteDeadline. On packet-oriented connections, write timeouts +// are rare. func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -159,8 +158,9 @@ func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error return c.fd.WriteMsg(b, oob, sa) } -// DialIP connects to the remote address raddr on the network protocol netProto, -// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name. +// DialIP connects to the remote address raddr on the network protocol +// netProto, which must be "ip", "ip4", or "ip6" followed by a colon +// and a protocol number or name. func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { return dialIP(netProto, laddr, raddr, noDeadline) } @@ -185,10 +185,10 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, return newIPConn(fd), nil } -// ListenIP listens for incoming IP packets addressed to the -// local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send IP -// packets with per-packet addressing. +// ListenIP listens for incoming IP packets addressed to the local +// address laddr. The returned connection's ReadFrom and WriteTo +// methods can be used to receive and send IP packets with per-packet +// addressing. func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { net, proto, err := parseNetwork(netProto) if err != nil { diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go index 1ef489289..d93059587 100644 --- a/src/pkg/net/ipsock.go +++ b/src/pkg/net/ipsock.go @@ -68,15 +68,12 @@ func (e InvalidAddrError) Error() string { return string(e) } func (e InvalidAddrError) Timeout() bool { return false } func (e InvalidAddrError) Temporary() bool { return false } -// SplitHostPort splits a network address of the form -// "host:port" or "[host]:port" into host and port. -// The latter form must be used when host contains a colon. +// SplitHostPort splits a network address of the form "host:port", +// "[host]:port" or "[ipv6-host%zone]:port" into host or +// ipv6-host%zone and port. A literal address or host name for IPv6 +// must be enclosed in square brackets, as in "[::1]:80", +// "[ipv6-host]:http" or "[ipv6-host%zone]:80". func SplitHostPort(hostport string) (host, port string, err error) { - host, port, _, err = splitHostPort(hostport) - return -} - -func splitHostPort(hostport string) (host, port, zone string, err error) { j, k := 0, 0 // The port starts after the last colon. @@ -110,10 +107,12 @@ func splitHostPort(hostport string) (host, port, zone string, err error) { j, k = 1, end+1 // there can't be a '[' resp. ']' before these positions } else { host = hostport[:i] - if byteIndex(host, ':') >= 0 { goto tooManyColons } + if byteIndex(host, '%') >= 0 { + goto missingBrackets + } } if byteIndex(hostport[j:], '[') >= 0 { err = &AddrError{"unexpected '[' in address", hostport} @@ -134,13 +133,29 @@ missingPort: tooManyColons: err = &AddrError{"too many colons in address", hostport} return + +missingBrackets: + err = &AddrError{"missing brackets in address", hostport} + return +} + +func splitHostZone(s string) (host, zone string) { + // The IPv6 scoped addressing zone identifer starts after the + // last percent sign. + if i := last(s, '%'); i > 0 { + host, zone = s[:i], s[i+1:] + } else { + host = s + } + return } -// JoinHostPort combines host and port into a network address -// of the form "host:port" or, if host contains a colon, "[host]:port". +// JoinHostPort combines host and port into a network address of the +// form "host:port" or, if host contains a colon or a percent sign, +// "[host]:port". func JoinHostPort(host, port string) string { - // If host has colons, have to bracket it. - if byteIndex(host, ':') >= 0 { + // If host has colons or a percent sign, have to bracket it. + if byteIndex(host, ':') >= 0 || byteIndex(host, '%') >= 0 { return "[" + host + "]:" + port } return host + ":" + port @@ -155,7 +170,7 @@ func resolveInternetAddr(net, addr string, deadline time.Time) (Addr, error) { switch net { case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": if addr != "" { - if host, port, zone, err = splitHostPort(addr); err != nil { + if host, port, err = SplitHostPort(addr); err != nil { return nil, err } if portnum, err = parsePort(net, port); err != nil { @@ -184,21 +199,25 @@ func resolveInternetAddr(net, addr string, deadline time.Time) (Addr, error) { return inetaddr(net, nil, portnum, zone), nil } // Try as an IP address. - if ip := ParseIP(host); ip != nil { + if ip := parseIPv4(host); ip != nil { + return inetaddr(net, ip, portnum, zone), nil + } + if ip, zone := parseIPv6(host, true); ip != nil { return inetaddr(net, ip, portnum, zone), nil } + // Try as a domain name. + host, zone = splitHostZone(host) + addrs, err := lookupHostDeadline(host, deadline) + if err != nil { + return nil, err + } var filter func(IP) IP if net != "" && net[len(net)-1] == '4' { filter = ipv4only } - if net != "" && net[len(net)-1] == '6' { + if net != "" && net[len(net)-1] == '6' || zone != "" { filter = ipv6only } - // Try as a DNS name. - addrs, err := lookupHostDeadline(host, deadline) - if err != nil { - return nil, err - } ip := firstFavoriteAddr(filter, addrs) if ip == nil { // should not happen diff --git a/src/pkg/net/lookup_plan9.go b/src/pkg/net/lookup_plan9.go index ae7cf7942..94c553328 100644 --- a/src/pkg/net/lookup_plan9.go +++ b/src/pkg/net/lookup_plan9.go @@ -7,7 +7,6 @@ package net import ( "errors" "os" - "syscall" ) func query(filename, query string, bufSize int) (res []string, err error) { @@ -70,9 +69,26 @@ func queryDNS(addr string, typ string) (res []string, err error) { return query("/net/dns", addr+" "+typ, 1024) } +// lookupProtocol looks up IP protocol name and returns +// the corresponding protocol number. func lookupProtocol(name string) (proto int, err error) { - // TODO: Implement this - return 0, syscall.EPLAN9 + lines, err := query("/net/cs", "!protocol="+name, 128) + if err != nil { + return 0, err + } + unknownProtoError := errors.New("unknown IP protocol specified: " + name) + if len(lines) == 0 { + return 0, unknownProtoError + } + f := getFields(lines[0]) + if len(f) < 2 { + return 0, unknownProtoError + } + s := f[1] + if n, _, ok := dtoi(s, byteIndex(s, '=')+1); ok { + return n, nil + } + return 0, unknownProtoError } func lookupHost(host string) (addrs []string, err error) { diff --git a/src/pkg/net/multicast_posix_test.go b/src/pkg/net/multicast_posix_test.go deleted file mode 100644 index ff1edaf83..000000000 --- a/src/pkg/net/multicast_posix_test.go +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !plan9 - -package net - -import ( - "errors" - "os" - "runtime" - "testing" -) - -var multicastListenerTests = []struct { - net string - gaddr *UDPAddr - flags Flags - ipv6 bool // test with underlying AF_INET6 socket -}{ - // cf. RFC 4727: Experimental Values in IPv4, IPv6, ICMPv4, ICMPv6, UDP, and TCP Headers - - {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, FlagUp | FlagLoopback, false}, - {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, 0, false}, - {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, 0, true}, - - {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, FlagUp | FlagLoopback, false}, - {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, 0, false}, - - {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}, 0, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}, 0, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}, 0, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}, 0, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}, 0, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, 0, true}, -} - -// TestMulticastListener tests both single and double listen to a test -// listener with same address family, same group address and same port. -func TestMulticastListener(t *testing.T) { - switch runtime.GOOS { - case "netbsd", "openbsd", "plan9", "solaris", "windows": - t.Skipf("skipping test on %q", runtime.GOOS) - case "linux": - if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" { - t.Skipf("skipping test on %q/%q", runtime.GOOS, runtime.GOARCH) - } - } - - for _, tt := range multicastListenerTests { - if tt.ipv6 && (!*testIPv6 || !supportsIPv6 || os.Getuid() != 0) { - continue - } - ifi, err := availMulticastInterface(t, tt.flags) - if err != nil { - continue - } - c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("First ListenMulticastUDP failed: %v", err) - } - checkMulticastListener(t, err, c1, tt.gaddr) - c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("Second ListenMulticastUDP failed: %v", err) - } - checkMulticastListener(t, err, c2, tt.gaddr) - c2.Close() - c1.Close() - } -} - -func TestSimpleMulticastListener(t *testing.T) { - switch runtime.GOOS { - case "plan9": - t.Skipf("skipping test on %q", runtime.GOOS) - case "windows": - if testing.Short() || !*testExternal { - t.Skip("skipping test on windows to avoid firewall") - } - } - - for _, tt := range multicastListenerTests { - if tt.ipv6 { - continue - } - tt.flags = FlagUp | FlagMulticast // for windows testing - ifi, err := availMulticastInterface(t, tt.flags) - if err != nil { - continue - } - c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("First ListenMulticastUDP failed: %v", err) - } - checkSimpleMulticastListener(t, err, c1, tt.gaddr) - c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("Second ListenMulticastUDP failed: %v", err) - } - checkSimpleMulticastListener(t, err, c2, tt.gaddr) - c2.Close() - c1.Close() - } -} - -func checkMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) { - if !multicastRIBContains(t, gaddr.IP) { - t.Errorf("%q not found in RIB", gaddr.String()) - return - } - la := c.LocalAddr() - if la == nil { - t.Error("LocalAddr failed") - return - } - if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { - t.Errorf("got %v; expected a proper address with non-zero port number", la) - return - } -} - -func checkSimpleMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) { - la := c.LocalAddr() - if la == nil { - t.Error("LocalAddr failed") - return - } - if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { - t.Errorf("got %v; expected a proper address with non-zero port number", la) - return - } -} - -func availMulticastInterface(t *testing.T, flags Flags) (*Interface, error) { - var ifi *Interface - if flags != Flags(0) { - ift, err := Interfaces() - if err != nil { - t.Fatalf("Interfaces failed: %v", err) - } - for _, x := range ift { - if x.Flags&flags == flags { - ifi = &x - break - } - } - if ifi == nil { - return nil, errors.New("an appropriate multicast interface not found") - } - } - return ifi, nil -} - -func multicastRIBContains(t *testing.T, ip IP) bool { - ift, err := Interfaces() - if err != nil { - t.Fatalf("Interfaces failed: %v", err) - } - for _, ifi := range ift { - ifmat, err := ifi.MulticastAddrs() - if err != nil { - t.Fatalf("MulticastAddrs failed: %v", err) - } - for _, ifma := range ifmat { - if ifma.(*IPAddr).IP.Equal(ip) { - return true - } - } - } - return false -} diff --git a/src/pkg/net/multicast_test.go b/src/pkg/net/multicast_test.go new file mode 100644 index 000000000..8ff02a3c9 --- /dev/null +++ b/src/pkg/net/multicast_test.go @@ -0,0 +1,184 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "fmt" + "os" + "runtime" + "testing" +) + +var ipv4MulticastListenerTests = []struct { + net string + gaddr *UDPAddr // see RFC 4727 +}{ + {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}}, + + {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}}, +} + +// TestIPv4MulticastListener tests both single and double listen to a +// test listener with same address family, same group address and same +// port. +func TestIPv4MulticastListener(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + closer := func(cs []*UDPConn) { + for _, c := range cs { + if c != nil { + c.Close() + } + } + } + + for _, ifi := range []*Interface{loopbackInterface(), nil} { + // Note that multicast interface assignment by system + // is not recommended because it usually relies on + // routing stuff for finding out an appropriate + // nexthop containing both network and link layer + // adjacencies. + if ifi == nil && !*testExternal { + continue + } + for _, tt := range ipv4MulticastListenerTests { + var err error + cs := make([]*UDPConn, 2) + if cs[0], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil { + t.Fatalf("First ListenMulticastUDP on %v failed: %v", ifi, err) + } + if err := checkMulticastListener(cs[0], tt.gaddr.IP); err != nil { + closer(cs) + t.Fatal(err) + } + if cs[1], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil { + closer(cs) + t.Fatalf("Second ListenMulticastUDP on %v failed: %v", ifi, err) + } + if err := checkMulticastListener(cs[1], tt.gaddr.IP); err != nil { + closer(cs) + t.Fatal(err) + } + closer(cs) + } + } +} + +var ipv6MulticastListenerTests = []struct { + net string + gaddr *UDPAddr // see RFC 4727 +}{ + {"udp", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}}, + {"udp", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}}, + {"udp", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}}, + {"udp", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}}, + {"udp", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}}, + {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}}, + + {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}}, + {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}}, + {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}}, + {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}}, + {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}}, + {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}}, +} + +// TestIPv6MulticastListener tests both single and double listen to a +// test listener with same address family, same group address and same +// port. +func TestIPv6MulticastListener(t *testing.T) { + switch runtime.GOOS { + case "plan9", "solaris": + t.Skipf("skipping test on %q", runtime.GOOS) + } + if !supportsIPv6 { + t.Skip("ipv6 is not supported") + } + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") + } + + closer := func(cs []*UDPConn) { + for _, c := range cs { + if c != nil { + c.Close() + } + } + } + + for _, ifi := range []*Interface{loopbackInterface(), nil} { + // Note that multicast interface assignment by system + // is not recommended because it usually relies on + // routing stuff for finding out an appropriate + // nexthop containing both network and link layer + // adjacencies. + if ifi == nil && (!*testExternal || !*testIPv6) { + continue + } + for _, tt := range ipv6MulticastListenerTests { + var err error + cs := make([]*UDPConn, 2) + if cs[0], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil { + t.Fatalf("First ListenMulticastUDP on %v failed: %v", ifi, err) + } + if err := checkMulticastListener(cs[0], tt.gaddr.IP); err != nil { + closer(cs) + t.Fatal(err) + } + if cs[1], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil { + closer(cs) + t.Fatalf("Second ListenMulticastUDP on %v failed: %v", ifi, err) + } + if err := checkMulticastListener(cs[1], tt.gaddr.IP); err != nil { + closer(cs) + t.Fatal(err) + } + closer(cs) + } + } +} + +func checkMulticastListener(c *UDPConn, ip IP) error { + if ok, err := multicastRIBContains(ip); err != nil { + return err + } else if !ok { + return fmt.Errorf("%q not found in multicast RIB", ip.String()) + } + la := c.LocalAddr() + if la, ok := la.(*UDPAddr); !ok || la.Port == 0 { + return fmt.Errorf("got %v; expected a proper address with non-zero port number", la) + } + return nil +} + +func multicastRIBContains(ip IP) (bool, error) { + switch runtime.GOOS { + case "netbsd", "openbsd", "plan9", "solaris", "windows": + return true, nil // not implemented yet + case "linux": + if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" { + return true, nil // not implemented yet + } + } + ift, err := Interfaces() + if err != nil { + return false, err + } + for _, ifi := range ift { + ifmat, err := ifi.MulticastAddrs() + if err != nil { + return false, err + } + for _, ifma := range ifmat { + if ifma.(*IPAddr).IP.Equal(ip) { + return true, nil + } + } + } + return false, nil +} diff --git a/src/pkg/net/newpollserver_unix.go b/src/pkg/net/newpollserver_unix.go deleted file mode 100644 index 618b5b10b..000000000 --- a/src/pkg/net/newpollserver_unix.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build darwin freebsd linux netbsd openbsd - -package net - -import ( - "os" - "syscall" -) - -func newPollServer() (s *pollServer, err error) { - s = new(pollServer) - if s.pr, s.pw, err = os.Pipe(); err != nil { - return nil, err - } - if err = syscall.SetNonblock(int(s.pr.Fd()), true); err != nil { - goto Errno - } - if err = syscall.SetNonblock(int(s.pw.Fd()), true); err != nil { - goto Errno - } - if s.poll, err = newpollster(); err != nil { - goto Error - } - if _, err = s.poll.AddFD(int(s.pr.Fd()), 'r', true); err != nil { - s.poll.Close() - goto Error - } - s.pending = make(map[int]*netFD) - go s.Run() - return s, nil - -Errno: - err = &os.PathError{ - Op: "setnonblock", - Path: s.pr.Name(), - Err: err, - } -Error: - s.pr.Close() - s.pw.Close() - return nil, err -} diff --git a/src/pkg/net/packetconn_test.go b/src/pkg/net/packetconn_test.go index 93c7a6472..ec5dd710f 100644 --- a/src/pkg/net/packetconn_test.go +++ b/src/pkg/net/packetconn_test.go @@ -15,14 +15,20 @@ import ( "time" ) +func strfunc(s string) func() string { + return func() string { + return s + } +} + var packetConnTests = []struct { net string - addr1 string - addr2 string + addr1 func() string + addr2 func() string }{ - {"udp", "127.0.0.1:0", "127.0.0.1:0"}, - {"ip:icmp", "127.0.0.1", "127.0.0.1"}, - {"unixgram", testUnixAddr(), testUnixAddr()}, + {"udp", strfunc("127.0.0.1:0"), strfunc("127.0.0.1:0")}, + {"ip:icmp", strfunc("127.0.0.1"), strfunc("127.0.0.1")}, + {"unixgram", testUnixAddr, testUnixAddr}, } func TestPacketConn(t *testing.T) { @@ -70,21 +76,22 @@ func TestPacketConn(t *testing.T) { continue } - c1, err := ListenPacket(tt.net, tt.addr1) + addr1, addr2 := tt.addr1(), tt.addr2() + c1, err := ListenPacket(tt.net, addr1) if err != nil { t.Fatalf("ListenPacket failed: %v", err) } - defer closer(c1, netstr[0], tt.addr1, tt.addr2) + defer closer(c1, netstr[0], addr1, addr2) c1.LocalAddr() c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) c1.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) - c2, err := ListenPacket(tt.net, tt.addr2) + c2, err := ListenPacket(tt.net, addr2) if err != nil { t.Fatalf("ListenPacket failed: %v", err) } - defer closer(c2, netstr[0], tt.addr1, tt.addr2) + defer closer(c2, netstr[0], addr1, addr2) c2.LocalAddr() c2.SetDeadline(time.Now().Add(100 * time.Millisecond)) c2.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) @@ -152,11 +159,12 @@ func TestConnAndPacketConn(t *testing.T) { continue } - c1, err := ListenPacket(tt.net, tt.addr1) + addr1, addr2 := tt.addr1(), tt.addr2() + c1, err := ListenPacket(tt.net, addr1) if err != nil { t.Fatalf("ListenPacket failed: %v", err) } - defer closer(c1, netstr[0], tt.addr1, tt.addr2) + defer closer(c1, netstr[0], addr1, addr2) c1.LocalAddr() c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) diff --git a/src/pkg/net/protoconn_test.go b/src/pkg/net/protoconn_test.go index 2fe7d1d1f..b59925e01 100644 --- a/src/pkg/net/protoconn_test.go +++ b/src/pkg/net/protoconn_test.go @@ -15,9 +15,11 @@ import ( "time" ) -// testUnixAddr uses ioutil.TempFile to get a name that is unique. +// testUnixAddr uses ioutil.TempFile to get a name that is unique. It +// also uses /tmp directory in case it is prohibited to create UNIX +// sockets in TMPDIR. func testUnixAddr() string { - f, err := ioutil.TempFile("", "nettest") + f, err := ioutil.TempFile("/tmp", "nettest") if err != nil { panic(err) } @@ -163,7 +165,7 @@ func TestUDPConnSpecificMethods(t *testing.T) { func TestIPConnSpecificMethods(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Skipf("skipping read test on %q", runtime.GOOS) + t.Skipf("skipping test on %q", runtime.GOOS) } if os.Getuid() != 0 { t.Skipf("skipping test; must be root") @@ -220,7 +222,7 @@ func TestIPConnSpecificMethods(t *testing.T) { func TestUnixListenerSpecificMethods(t *testing.T) { switch runtime.GOOS { case "plan9", "windows": - t.Skipf("skipping read test on %q", runtime.GOOS) + t.Skipf("skipping test on %q", runtime.GOOS) } addr := testUnixAddr() diff --git a/src/pkg/net/rpc/jsonrpc/all_test.go b/src/pkg/net/rpc/jsonrpc/all_test.go index 3c7c4d48f..40d4b82d7 100644 --- a/src/pkg/net/rpc/jsonrpc/all_test.go +++ b/src/pkg/net/rpc/jsonrpc/all_test.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "net/rpc" "testing" @@ -185,6 +186,22 @@ func TestMalformedInput(t *testing.T) { ServeConn(srv) // must return, not loop } +func TestMalformedOutput(t *testing.T) { + cli, srv := net.Pipe() + go srv.Write([]byte(`{"id":0,"result":null,"error":null}`)) + go ioutil.ReadAll(srv) + + client := NewClient(cli) + defer client.Close() + + args := &Args{7, 8} + reply := new(Reply) + err := client.Call("Arith.Add", args, reply) + if err == nil { + t.Error("expected error") + } +} + func TestUnexpectedError(t *testing.T) { cli, srv := myPipe() go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error diff --git a/src/pkg/net/rpc/jsonrpc/client.go b/src/pkg/net/rpc/jsonrpc/client.go index 3fa8cbf08..2194f2125 100644 --- a/src/pkg/net/rpc/jsonrpc/client.go +++ b/src/pkg/net/rpc/jsonrpc/client.go @@ -83,7 +83,7 @@ func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error { r.Error = "" r.Seq = c.resp.Id - if c.resp.Error != nil { + if c.resp.Error != nil || c.resp.Result == nil { x, ok := c.resp.Error.(string) if !ok { return fmt.Errorf("invalid error %v", c.resp.Error) diff --git a/src/pkg/net/rpc/server_test.go b/src/pkg/net/rpc/server_test.go index 8a1530623..eb17210ab 100644 --- a/src/pkg/net/rpc/server_test.go +++ b/src/pkg/net/rpc/server_test.go @@ -399,12 +399,10 @@ func (WriteFailCodec) WriteRequest(*Request, interface{}) error { func (WriteFailCodec) ReadResponseHeader(*Response) error { select {} - panic("unreachable") } func (WriteFailCodec) ReadResponseBody(interface{}) error { select {} - panic("unreachable") } func (WriteFailCodec) Close() error { @@ -465,10 +463,16 @@ func countMallocs(dial func() (*Client, error), t *testing.T) float64 { } func TestCountMallocs(t *testing.T) { + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t)) } func TestCountMallocsOverHTTP(t *testing.T) { + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t)) } diff --git a/src/pkg/net/sendfile_freebsd.go b/src/pkg/net/sendfile_freebsd.go index 8008bc3b5..dc5b76755 100644 --- a/src/pkg/net/sendfile_freebsd.go +++ b/src/pkg/net/sendfile_freebsd.go @@ -83,7 +83,7 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { break } if err1 == syscall.EAGAIN { - if err1 = c.pollServer.WaitWrite(c); err1 == nil { + if err1 = c.pd.WaitWrite(); err1 == nil { continue } } diff --git a/src/pkg/net/sendfile_linux.go b/src/pkg/net/sendfile_linux.go index 3357e6538..6f1323b3d 100644 --- a/src/pkg/net/sendfile_linux.go +++ b/src/pkg/net/sendfile_linux.go @@ -59,7 +59,7 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { break } if err1 == syscall.EAGAIN { - if err1 = c.pollServer.WaitWrite(c); err1 == nil { + if err1 = c.pd.WaitWrite(); err1 == nil { continue } } diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go index 25c2be5a7..9194a8ec2 100644 --- a/src/pkg/net/server_test.go +++ b/src/pkg/net/server_test.go @@ -9,6 +9,7 @@ import ( "io" "os" "runtime" + "strconv" "testing" "time" ) @@ -41,6 +42,12 @@ func skipServerTest(net, unixsotype, addr string, ipv6, ipv4map, linuxonly bool) return false } +func tempfile(filename string) string { + // use /tmp in case it is prohibited to create + // UNIX sockets in TMPDIR + return "/tmp/" + filename + "." + strconv.Itoa(os.Getpid()) +} + var streamConnServerTests = []struct { snet string // server side saddr string @@ -86,7 +93,7 @@ var streamConnServerTests = []struct { {snet: "tcp6", saddr: "[::1]", cnet: "tcp6", caddr: "[::1]", ipv6: true}, - {snet: "unix", saddr: "/tmp/gotest1.net", cnet: "unix", caddr: "/tmp/gotest1.net.local"}, + {snet: "unix", saddr: tempfile("gotest1.net"), cnet: "unix", caddr: tempfile("gotest1.net.local")}, {snet: "unix", saddr: "@gotest2/net", cnet: "unix", caddr: "@gotest2/net.local", linux: true}, } @@ -135,7 +142,7 @@ var seqpacketConnServerTests = []struct { caddr string // client address empty bool // test with empty data }{ - {net: "unixpacket", saddr: "/tmp/gotest3.net", caddr: "/tmp/gotest3.net.local"}, + {net: "unixpacket", saddr: tempfile("/gotest3.net"), caddr: tempfile("gotest3.net.local")}, {net: "unixpacket", saddr: "@gotest4/net", caddr: "@gotest4/net.local"}, } @@ -294,10 +301,10 @@ var datagramPacketConnServerTests = []struct { {snet: "udp", saddr: "[::1]", cnet: "udp", caddr: "[::1]", ipv6: true, empty: true}, {snet: "udp", saddr: "[::1]", cnet: "udp", caddr: "[::1]", ipv6: true, dial: true, empty: true}, - {snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local"}, - {snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local", dial: true}, - {snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local", empty: true}, - {snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local", dial: true, empty: true}, + {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local")}, + {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local"), dial: true}, + {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local"), empty: true}, + {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local"), dial: true, empty: true}, {snet: "unixgram", saddr: "@gotest6/net", cnet: "unixgram", caddr: "@gotest6/net.local", linux: true}, } @@ -395,7 +402,7 @@ func runDatagramConnClient(t *testing.T, net, laddr, taddr string, isEmpty bool) t.Fatalf("Dial(%q, %q) failed: %v", net, taddr, err) } case "unixgram": - c, err = DialUnix(net, &UnixAddr{laddr, net}, &UnixAddr{taddr, net}) + c, err = DialUnix(net, &UnixAddr{Name: laddr, Net: net}, &UnixAddr{Name: taddr, Net: net}) if err != nil { t.Fatalf("DialUnix(%q, {%q, %q}) failed: %v", net, laddr, taddr, err) } diff --git a/src/pkg/net/smtp/auth.go b/src/pkg/net/smtp/auth.go index d401e3c21..3f1339ebc 100644 --- a/src/pkg/net/smtp/auth.go +++ b/src/pkg/net/smtp/auth.go @@ -54,7 +54,16 @@ func PlainAuth(identity, username, password, host string) Auth { func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { if !server.TLS { - return "", nil, errors.New("unencrypted connection") + advertised := false + for _, mechanism := range server.Auth { + if mechanism == "PLAIN" { + advertised = true + break + } + } + if !advertised { + return "", nil, errors.New("unencrypted connection") + } } if server.Name != a.host { return "", nil, errors.New("wrong host name") diff --git a/src/pkg/net/smtp/smtp_test.go b/src/pkg/net/smtp/smtp_test.go index 8317428cb..c190b32c0 100644 --- a/src/pkg/net/smtp/smtp_test.go +++ b/src/pkg/net/smtp/smtp_test.go @@ -57,6 +57,41 @@ testLoop: } } +func TestAuthPlain(t *testing.T) { + auth := PlainAuth("foo", "bar", "baz", "servername") + + tests := []struct { + server *ServerInfo + err string + }{ + { + server: &ServerInfo{Name: "servername", TLS: true}, + }, + { + // Okay; explicitly advertised by server. + server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + }, + { + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + err: "unencrypted connection", + }, + { + server: &ServerInfo{Name: "attacker", TLS: true}, + err: "wrong host name", + }, + } + for i, tt := range tests { + _, _, err := auth.Start(tt.server) + got := "" + if err != nil { + got = err.Error() + } + if got != tt.err { + t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + } + } +} + type faker struct { io.ReadWriter } diff --git a/src/pkg/net/sock_bsd.go b/src/pkg/net/sock_bsd.go index 3205f9404..d99349265 100644 --- a/src/pkg/net/sock_bsd.go +++ b/src/pkg/net/sock_bsd.go @@ -27,5 +27,11 @@ func maxListenerBacklog() int { if n == 0 || err != nil { return syscall.SOMAXCONN } + // FreeBSD stores the backlog in a uint16, as does Linux. + // Assume the other BSDs do too. Truncate number to avoid wrapping. + // See issue 5030. + if n > 1<<16-1 { + n = 1<<16 - 1 + } return int(n) } diff --git a/src/pkg/net/sock_cloexec.go b/src/pkg/net/sock_cloexec.go index 12d0f3488..3f22cd8f5 100644 --- a/src/pkg/net/sock_cloexec.go +++ b/src/pkg/net/sock_cloexec.go @@ -44,8 +44,8 @@ func sysSocket(f, t, p int) (int, error) { func accept(fd int) (int, syscall.Sockaddr, error) { nfd, sa, err := syscall.Accept4(fd, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC) // The accept4 system call was introduced in Linux 2.6.28. If - // we get an ENOSYS error, fall back to using accept. - if err == nil || err != syscall.ENOSYS { + // we get an ENOSYS or EINVAL error, fall back to using accept. + if err == nil || (err != syscall.ENOSYS && err != syscall.EINVAL) { return nfd, sa, err } diff --git a/src/pkg/net/sock_linux.go b/src/pkg/net/sock_linux.go index 8bbd74ddc..cc5ce153b 100644 --- a/src/pkg/net/sock_linux.go +++ b/src/pkg/net/sock_linux.go @@ -21,5 +21,11 @@ func maxListenerBacklog() int { if n == 0 || !ok { return syscall.SOMAXCONN } + // Linux stores the backlog in a uint16. + // Truncate number to avoid wrapping. + // See issue 5030. + if n > 1<<16-1 { + n = 1<<16 - 1 + } return n } diff --git a/src/pkg/net/sock_posix.go b/src/pkg/net/sock_posix.go index b50a892b1..be89c26db 100644 --- a/src/pkg/net/sock_posix.go +++ b/src/pkg/net/sock_posix.go @@ -25,7 +25,8 @@ func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, return nil, err } - if ulsa != nil { + // This socket is used by a listener. + if ulsa != nil && ursa == nil { // We provide a socket that listens to a wildcard // address with reusable UDP port when the given ulsa // is an appropriate UDP multicast address prefix. @@ -37,6 +38,9 @@ func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, closesocket(s) return nil, err } + } + + if ulsa != nil { if err = syscall.Bind(s, ulsa); err != nil { closesocket(s) return nil, err @@ -48,19 +52,27 @@ func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, return nil, err } + // This socket is used by a dialer. if ursa != nil { - fd.wdeadline.setTime(deadline) - if err = fd.connect(ursa); err != nil { - closesocket(s) + if !deadline.IsZero() { + setWriteDeadline(fd, deadline) + } + if err = fd.connect(ulsa, ursa); err != nil { + fd.Close() return nil, err } fd.isConnected = true - fd.wdeadline.set(0) + if !deadline.IsZero() { + setWriteDeadline(fd, time.Time{}) + } } lsa, _ := syscall.Getsockname(s) laddr := toAddr(lsa) rsa, _ := syscall.Getpeername(s) + if rsa == nil { + rsa = ursa + } raddr := toAddr(rsa) fd.setAddr(laddr, raddr) return fd, nil diff --git a/src/pkg/net/sock_windows.go b/src/pkg/net/sock_windows.go index a77c48437..41368d39e 100644 --- a/src/pkg/net/sock_windows.go +++ b/src/pkg/net/sock_windows.go @@ -8,6 +8,7 @@ import "syscall" func maxListenerBacklog() int { // TODO: Implement this + // NOTE: Never return a number bigger than 1<<16 - 1. See issue 5030. return syscall.SOMAXCONN } diff --git a/src/pkg/net/sockopt_posix.go b/src/pkg/net/sockopt_posix.go index fe371fe0c..1590f4e98 100644 --- a/src/pkg/net/sockopt_posix.go +++ b/src/pkg/net/sockopt_posix.go @@ -11,7 +11,6 @@ package net import ( "os" "syscall" - "time" ) // Boolean to int. @@ -119,24 +118,6 @@ func setWriteBuffer(fd *netFD, bytes int) error { return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)) } -// TODO(dfc) these unused error returns could be removed - -func setReadDeadline(fd *netFD, t time.Time) error { - fd.rdeadline.setTime(t) - return nil -} - -func setWriteDeadline(fd *netFD, t time.Time) error { - fd.wdeadline.setTime(t) - return nil -} - -func setDeadline(fd *netFD, t time.Time) error { - setReadDeadline(fd, t) - setWriteDeadline(fd, t) - return nil -} - func setKeepAlive(fd *netFD, keepalive bool) error { if err := fd.incref(false); err != nil { return err diff --git a/src/pkg/net/sockopt_windows.go b/src/pkg/net/sockopt_windows.go index 509b5963b..0861fe8f4 100644 --- a/src/pkg/net/sockopt_windows.go +++ b/src/pkg/net/sockopt_windows.go @@ -9,6 +9,7 @@ package net import ( "os" "syscall" + "time" ) func setDefaultSockopts(s syscall.Handle, f, t int, ipv6only bool) error { @@ -47,3 +48,21 @@ func setDefaultMulticastSockopts(s syscall.Handle) error { } return nil } + +// TODO(dfc) these unused error returns could be removed + +func setReadDeadline(fd *netFD, t time.Time) error { + fd.rdeadline.setTime(t) + return nil +} + +func setWriteDeadline(fd *netFD, t time.Time) error { + fd.wdeadline.setTime(t) + return nil +} + +func setDeadline(fd *netFD, t time.Time) error { + setReadDeadline(fd, t) + setWriteDeadline(fd, t) + return nil +} diff --git a/src/pkg/net/tcp_test.go b/src/pkg/net/tcp_test.go index 6c4485a94..a71b02b47 100644 --- a/src/pkg/net/tcp_test.go +++ b/src/pkg/net/tcp_test.go @@ -5,6 +5,7 @@ package net import ( + "fmt" "reflect" "runtime" "testing" @@ -146,24 +147,39 @@ func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) { } } -var resolveTCPAddrTests = []struct { +type resolveTCPAddrTest struct { net string litAddr string addr *TCPAddr err error -}{ +} + +var resolveTCPAddrTests = []resolveTCPAddrTest{ {"tcp", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, {"tcp4", "127.0.0.1:65535", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil}, {"tcp", "[::1]:1", &TCPAddr{IP: ParseIP("::1"), Port: 1}, nil}, {"tcp6", "[::1]:65534", &TCPAddr{IP: ParseIP("::1"), Port: 65534}, nil}, + {"tcp", "[::1%en0]:1", &TCPAddr{IP: ParseIP("::1"), Port: 1, Zone: "en0"}, nil}, + {"tcp6", "[::1%911]:2", &TCPAddr{IP: ParseIP("::1"), Port: 2, Zone: "911"}, nil}, + {"", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior {"", "[::1]:0", &TCPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior {"http", "127.0.0.1:0", nil, UnknownNetworkError("http")}, } +func init() { + if ifi := loopbackInterface(); ifi != nil { + index := fmt.Sprintf("%v", ifi.Index) + resolveTCPAddrTests = append(resolveTCPAddrTests, []resolveTCPAddrTest{ + {"tcp6", "[fe80::1%" + ifi.Name + "]:3", &TCPAddr{IP: ParseIP("fe80::1"), Port: 3, Zone: zoneToString(ifi.Index)}, nil}, + {"tcp6", "[fe80::1%" + index + "]:4", &TCPAddr{IP: ParseIP("fe80::1"), Port: 4, Zone: index}, nil}, + }...) + } +} + func TestResolveTCPAddr(t *testing.T) { for _, tt := range resolveTCPAddrTests { addr, err := ResolveTCPAddr(tt.net, tt.litAddr) @@ -193,14 +209,88 @@ func TestTCPListenerName(t *testing.T) { for _, tt := range tcpListenerNameTests { ln, err := ListenTCP(tt.net, tt.laddr) if err != nil { - t.Errorf("ListenTCP failed: %v", err) - return + t.Fatalf("ListenTCP failed: %v", err) } defer ln.Close() la := ln.Addr() if a, ok := la.(*TCPAddr); !ok || a.Port == 0 { - t.Errorf("got %v; expected a proper address with non-zero port number", la) - return + t.Fatalf("got %v; expected a proper address with non-zero port number", la) } } } + +func TestIPv6LinkLocalUnicastTCP(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + if !supportsIPv6 { + t.Skip("ipv6 is not supported") + } + ifi := loopbackInterface() + if ifi == nil { + t.Skip("loopback interface not found") + } + laddr := ipv6LinkLocalUnicastAddr(ifi) + if laddr == "" { + t.Skip("ipv6 unicast address on loopback not found") + } + + type test struct { + net, addr string + nameLookup bool + } + var tests = []test{ + {"tcp", "[" + laddr + "%" + ifi.Name + "]:0", false}, + {"tcp6", "[" + laddr + "%" + ifi.Name + "]:0", false}, + } + switch runtime.GOOS { + case "darwin", "freebsd", "opensbd", "netbsd": + tests = append(tests, []test{ + {"tcp", "[localhost%" + ifi.Name + "]:0", true}, + {"tcp6", "[localhost%" + ifi.Name + "]:0", true}, + }...) + case "linux": + tests = append(tests, []test{ + {"tcp", "[ip6-localhost%" + ifi.Name + "]:0", true}, + {"tcp6", "[ip6-localhost%" + ifi.Name + "]:0", true}, + }...) + } + for _, tt := range tests { + ln, err := Listen(tt.net, tt.addr) + if err != nil { + // It might return "LookupHost returned no + // suitable address" error on some platforms. + t.Logf("Listen failed: %v", err) + continue + } + defer ln.Close() + if la, ok := ln.Addr().(*TCPAddr); !ok || !tt.nameLookup && la.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", la) + } + + done := make(chan int) + go transponder(t, ln, done) + + c, err := Dial(tt.net, ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c.Close() + if la, ok := c.LocalAddr().(*TCPAddr); !ok || !tt.nameLookup && la.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", la) + } + if ra, ok := c.RemoteAddr().(*TCPAddr); !ok || !tt.nameLookup && ra.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", ra) + } + + if _, err := c.Write([]byte("TCP OVER IPV6 LINKLOCAL TEST")); err != nil { + t.Fatalf("Conn.Write failed: %v", err) + } + b := make([]byte, 32) + if _, err := c.Read(b); err != nil { + t.Fatalf("Conn.Read failed: %v", err) + } + + <-done + } +} diff --git a/src/pkg/net/tcpsock.go b/src/pkg/net/tcpsock.go index d5158b22d..4d9ebd214 100644 --- a/src/pkg/net/tcpsock.go +++ b/src/pkg/net/tcpsock.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TCP sockets - package net // TCPAddr represents the address of a TCP end point. @@ -20,14 +18,18 @@ func (a *TCPAddr) String() string { if a == nil { return "<nil>" } + if a.Zone != "" { + return JoinHostPort(a.IP.String()+"%"+a.Zone, itoa(a.Port)) + } return JoinHostPort(a.IP.String(), itoa(a.Port)) } -// ResolveTCPAddr parses addr as a TCP address of the form -// host:port and resolves domain names or port names to -// numeric addresses on the network net, which must be "tcp", -// "tcp4" or "tcp6". A literal IPv6 host address must be -// enclosed in square brackets, as in "[::]:80". +// ResolveTCPAddr parses addr as a TCP address of the form "host:port" +// or "[ipv6-host%zone]:port" and resolves a pair of domain name and +// port name on the network net, which must be "tcp", "tcp4" or +// "tcp6". A literal address or host name for IPv6 must be enclosed +// in square brackets, as in "[::1]:80", "[ipv6-host]:http" or +// "[ipv6-host%zone]:80". func ResolveTCPAddr(net, addr string) (*TCPAddr, error) { switch net { case "tcp", "tcp4", "tcp6": diff --git a/src/pkg/net/tcpsock_plan9.go b/src/pkg/net/tcpsock_plan9.go index ed3664603..48334fed7 100644 --- a/src/pkg/net/tcpsock_plan9.go +++ b/src/pkg/net/tcpsock_plan9.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TCP sockets for Plan 9 - package net import ( @@ -161,12 +159,16 @@ func (l *TCPListener) SetDeadline(t time.Time) error { // File returns a copy of the underlying os.File, set to blocking // mode. It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. +// +// The returned os.File's file descriptor is different from the +// connection's. Attempting to change properties of the original +// using this duplicate may or may not have the desired effect. func (l *TCPListener) File() (f *os.File, err error) { return l.dup() } // ListenTCP announces on the TCP address laddr and returns a TCP // listener. Net must be "tcp", "tcp4", or "tcp6". If laddr has a -// port of 0, it means to listen on some available port. The caller -// can use l.Addr() to retrieve the chosen address. +// port of 0, ListenTCP will choose an available port. The caller can +// use the Addr method of TCPListener to retrieve the chosen address. func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { switch net { case "tcp", "tcp4", "tcp6": diff --git a/src/pkg/net/tcpsock_posix.go b/src/pkg/net/tcpsock_posix.go index bd5a2a287..876edb101 100644 --- a/src/pkg/net/tcpsock_posix.go +++ b/src/pkg/net/tcpsock_posix.go @@ -4,8 +4,6 @@ // +build darwin freebsd linux netbsd openbsd windows -// TCP sockets - package net import ( @@ -58,8 +56,8 @@ func (a *TCPAddr) toAddr() sockaddr { return a } -// TCPConn is an implementation of the Conn interface -// for TCP network connections. +// TCPConn is an implementation of the Conn interface for TCP network +// connections. type TCPConn struct { conn } @@ -96,17 +94,17 @@ func (c *TCPConn) CloseWrite() error { return c.fd.CloseWrite() } -// SetLinger sets the behavior of Close() on a connection -// which still has data waiting to be sent or to be acknowledged. +// SetLinger sets the behavior of Close() on a connection which still +// has data waiting to be sent or to be acknowledged. // -// If sec < 0 (the default), Close returns immediately and -// the operating system finishes sending the data in the background. +// If sec < 0 (the default), Close returns immediately and the +// operating system finishes sending the data in the background. // // If sec == 0, Close returns immediately and the operating system // discards any unsent or unacknowledged data. // -// If sec > 0, Close blocks for at most sec seconds waiting for -// data to be sent and acknowledged. +// If sec > 0, Close blocks for at most sec seconds waiting for data +// to be sent and acknowledged. func (c *TCPConn) SetLinger(sec int) error { if !c.ok() { return syscall.EINVAL @@ -124,9 +122,9 @@ func (c *TCPConn) SetKeepAlive(keepalive bool) error { } // SetNoDelay controls whether the operating system should delay -// packet transmission in hopes of sending fewer packets -// (Nagle's algorithm). The default is true (no delay), meaning -// that data is sent as soon as possible after a Write. +// packet transmission in hopes of sending fewer packets (Nagle's +// algorithm). The default is true (no delay), meaning that data is +// sent as soon as possible after a Write. func (c *TCPConn) SetNoDelay(noDelay bool) error { if !c.ok() { return syscall.EINVAL @@ -135,8 +133,8 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error { } // DialTCP connects to the remote address raddr on the network net, -// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used -// as the local address for the connection. +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is +// used as the local address for the connection. func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) { switch net { case "tcp", "tcp4", "tcp6": @@ -216,16 +214,15 @@ func spuriousENOTAVAIL(err error) bool { return ok && e.Err == syscall.EADDRNOTAVAIL } -// TCPListener is a TCP network listener. -// Clients should typically use variables of type Listener -// instead of assuming TCP. +// TCPListener is a TCP network listener. Clients should typically +// use variables of type Listener instead of assuming TCP. type TCPListener struct { fd *netFD } -// AcceptTCP accepts the next incoming call and returns the new connection -// and the remote address. -func (l *TCPListener) AcceptTCP() (c *TCPConn, err error) { +// AcceptTCP accepts the next incoming call and returns the new +// connection and the remote address. +func (l *TCPListener) AcceptTCP() (*TCPConn, error) { if l == nil || l.fd == nil { return nil, syscall.EINVAL } @@ -236,14 +233,14 @@ func (l *TCPListener) AcceptTCP() (c *TCPConn, err error) { return newTCPConn(fd), nil } -// Accept implements the Accept method in the Listener interface; -// it waits for the next call and returns a generic Conn. -func (l *TCPListener) Accept() (c Conn, err error) { - c1, err := l.AcceptTCP() +// Accept implements the Accept method in the Listener interface; it +// waits for the next call and returns a generic Conn. +func (l *TCPListener) Accept() (Conn, error) { + c, err := l.AcceptTCP() if err != nil { return nil, err } - return c1, nil + return c, nil } // Close stops listening on the TCP address. @@ -267,15 +264,19 @@ func (l *TCPListener) SetDeadline(t time.Time) error { return setDeadline(l.fd, t) } -// File returns a copy of the underlying os.File, set to blocking mode. -// It is the caller's responsibility to close f when finished. +// File returns a copy of the underlying os.File, set to blocking +// mode. It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. +// +// The returned os.File's file descriptor is different from the +// connection's. Attempting to change properties of the original +// using this duplicate may or may not have the desired effect. func (l *TCPListener) File() (f *os.File, err error) { return l.fd.dup() } -// ListenTCP announces on the TCP address laddr and returns a TCP listener. -// Net must be "tcp", "tcp4", or "tcp6". -// If laddr has a port of 0, it means to listen on some available port. -// The caller can use l.Addr() to retrieve the chosen address. +// ListenTCP announces on the TCP address laddr and returns a TCP +// listener. Net must be "tcp", "tcp4", or "tcp6". If laddr has a +// port of 0, ListenTCP will choose an available port. The caller can +// use the Addr method of TCPListener to retrieve the chosen address. func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { switch net { case "tcp", "tcp4", "tcp6": @@ -291,7 +292,7 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { } err = syscall.Listen(fd.sysfd, listenerBacklog) if err != nil { - closesocket(fd.sysfd) + fd.Close() return nil, &OpError{"listen", net, laddr, err} } return &TCPListener{fd}, nil diff --git a/src/pkg/net/textproto/reader.go b/src/pkg/net/textproto/reader.go index b61bea862..5bd26ac8d 100644 --- a/src/pkg/net/textproto/reader.go +++ b/src/pkg/net/textproto/reader.go @@ -489,7 +489,6 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { return m, err } } - panic("unreachable") } // CanonicalMIMEHeaderKey returns the canonical format of the @@ -575,6 +574,7 @@ var commonHeaders = []string{ "Content-Length", "Content-Transfer-Encoding", "Content-Type", + "Cookie", "Date", "Dkim-Signature", "Etag", diff --git a/src/pkg/net/textproto/reader_test.go b/src/pkg/net/textproto/reader_test.go index 26987f611..f27042d4e 100644 --- a/src/pkg/net/textproto/reader_test.go +++ b/src/pkg/net/textproto/reader_test.go @@ -290,6 +290,7 @@ Non-Interned: test `, "\n", "\r\n", -1) func BenchmarkReadMIMEHeader(b *testing.B) { + b.ReportAllocs() var buf bytes.Buffer br := bufio.NewReader(&buf) r := NewReader(br) @@ -319,6 +320,7 @@ func BenchmarkReadMIMEHeader(b *testing.B) { } func BenchmarkUncommon(b *testing.B) { + b.ReportAllocs() var buf bytes.Buffer br := bufio.NewReader(&buf) r := NewReader(br) diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go index 0260efcc0..2e92147b8 100644 --- a/src/pkg/net/timeout_test.go +++ b/src/pkg/net/timeout_test.go @@ -532,7 +532,7 @@ func TestReadDeadlineDataAvailable(t *testing.T) { defer ln.Close() servec := make(chan copyRes) - const msg = "data client shouldn't read, even though it it'll be waiting" + const msg = "data client shouldn't read, even though it'll be waiting" go func() { c, err := ln.Accept() if err != nil { @@ -596,6 +596,64 @@ func TestWriteDeadlineBufferAvailable(t *testing.T) { } } +// TestAcceptDeadlineConnectionAvailable tests that accept deadlines work, even +// if there's incoming connections available. +func TestAcceptDeadlineConnectionAvailable(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t).(*TCPListener) + defer ln.Close() + + go func() { + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + var buf [1]byte + c.Read(buf[:]) // block until the connection or listener is closed + }() + time.Sleep(10 * time.Millisecond) + ln.SetDeadline(time.Now().Add(-5 * time.Second)) // in the past + c, err := ln.Accept() + if err == nil { + defer c.Close() + } + if !isTimeout(err) { + t.Fatalf("Accept: got %v; want timeout", err) + } +} + +// TestConnectDeadlineInThePast tests that connect deadlines work, even +// if the connection can be established w/o blocking. +func TestConnectDeadlineInThePast(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t).(*TCPListener) + defer ln.Close() + + go func() { + c, err := ln.Accept() + if err == nil { + defer c.Close() + } + }() + time.Sleep(10 * time.Millisecond) + c, err := DialTimeout("tcp", ln.Addr().String(), -5*time.Second) // in the past + if err == nil { + defer c.Close() + } + if !isTimeout(err) { + t.Fatalf("DialTimeout: got %v; want timeout", err) + } +} + // TestProlongTimeout tests concurrent deadline modification. // Known to cause data races in the past. func TestProlongTimeout(t *testing.T) { diff --git a/src/pkg/net/udp_test.go b/src/pkg/net/udp_test.go index 220422e13..4278f6dd4 100644 --- a/src/pkg/net/udp_test.go +++ b/src/pkg/net/udp_test.go @@ -5,29 +5,45 @@ package net import ( + "fmt" "reflect" "runtime" "testing" ) -var resolveUDPAddrTests = []struct { +type resolveUDPAddrTest struct { net string litAddr string addr *UDPAddr err error -}{ +} + +var resolveUDPAddrTests = []resolveUDPAddrTest{ {"udp", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, {"udp4", "127.0.0.1:65535", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil}, {"udp", "[::1]:1", &UDPAddr{IP: ParseIP("::1"), Port: 1}, nil}, {"udp6", "[::1]:65534", &UDPAddr{IP: ParseIP("::1"), Port: 65534}, nil}, + {"udp", "[::1%en0]:1", &UDPAddr{IP: ParseIP("::1"), Port: 1, Zone: "en0"}, nil}, + {"udp6", "[::1%911]:2", &UDPAddr{IP: ParseIP("::1"), Port: 2, Zone: "911"}, nil}, + {"", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior {"", "[::1]:0", &UDPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior {"sip", "127.0.0.1:0", nil, UnknownNetworkError("sip")}, } +func init() { + if ifi := loopbackInterface(); ifi != nil { + index := fmt.Sprintf("%v", ifi.Index) + resolveUDPAddrTests = append(resolveUDPAddrTests, []resolveUDPAddrTest{ + {"udp6", "[fe80::1%" + ifi.Name + "]:3", &UDPAddr{IP: ParseIP("fe80::1"), Port: 3, Zone: zoneToString(ifi.Index)}, nil}, + {"udp6", "[fe80::1%" + index + "]:4", &UDPAddr{IP: ParseIP("fe80::1"), Port: 4, Zone: index}, nil}, + }...) + } +} + func TestResolveUDPAddr(t *testing.T) { for _, tt := range resolveUDPAddrTests { addr, err := ResolveUDPAddr(tt.net, tt.litAddr) @@ -135,14 +151,125 @@ func TestUDPConnLocalName(t *testing.T) { for _, tt := range udpConnLocalNameTests { c, err := ListenUDP(tt.net, tt.laddr) if err != nil { - t.Errorf("ListenUDP failed: %v", err) - return + t.Fatalf("ListenUDP failed: %v", err) } defer c.Close() la := c.LocalAddr() if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { - t.Errorf("got %v; expected a proper address with non-zero port number", la) - return + t.Fatalf("got %v; expected a proper address with non-zero port number", la) + } + } +} + +func TestUDPConnLocalAndRemoteNames(t *testing.T) { + for _, laddr := range []string{"", "127.0.0.1:0"} { + c1, err := ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenUDP failed: %v", err) + } + defer c1.Close() + + var la *UDPAddr + if laddr != "" { + var err error + if la, err = ResolveUDPAddr("udp", laddr); err != nil { + t.Fatalf("ResolveUDPAddr failed: %v", err) + } + } + c2, err := DialUDP("udp", la, c1.LocalAddr().(*UDPAddr)) + if err != nil { + t.Fatalf("DialUDP failed: %v", err) + } + defer c2.Close() + + var connAddrs = [4]struct { + got Addr + ok bool + }{ + {c1.LocalAddr(), true}, + {c1.(*UDPConn).RemoteAddr(), false}, + {c2.LocalAddr(), true}, + {c2.RemoteAddr(), true}, + } + for _, ca := range connAddrs { + if a, ok := ca.got.(*UDPAddr); ok != ca.ok || ok && a.Port == 0 { + t.Fatalf("got %v; expected a proper address with non-zero port number", ca.got) + } + } + } +} + +func TestIPv6LinkLocalUnicastUDP(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + if !supportsIPv6 { + t.Skip("ipv6 is not supported") + } + ifi := loopbackInterface() + if ifi == nil { + t.Skip("loopback interface not found") + } + laddr := ipv6LinkLocalUnicastAddr(ifi) + if laddr == "" { + t.Skip("ipv6 unicast address on loopback not found") + } + + type test struct { + net, addr string + nameLookup bool + } + var tests = []test{ + {"udp", "[" + laddr + "%" + ifi.Name + "]:0", false}, + {"udp6", "[" + laddr + "%" + ifi.Name + "]:0", false}, + } + switch runtime.GOOS { + case "darwin", "freebsd", "openbsd", "netbsd": + tests = append(tests, []test{ + {"udp", "[localhost%" + ifi.Name + "]:0", true}, + {"udp6", "[localhost%" + ifi.Name + "]:0", true}, + }...) + case "linux": + tests = append(tests, []test{ + {"udp", "[ip6-localhost%" + ifi.Name + "]:0", true}, + {"udp6", "[ip6-localhost%" + ifi.Name + "]:0", true}, + }...) + } + for _, tt := range tests { + c1, err := ListenPacket(tt.net, tt.addr) + if err != nil { + // It might return "LookupHost returned no + // suitable address" error on some platforms. + t.Logf("ListenPacket failed: %v", err) + continue + } + defer c1.Close() + if la, ok := c1.LocalAddr().(*UDPAddr); !ok || !tt.nameLookup && la.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", la) + } + + c2, err := Dial(tt.net, c1.LocalAddr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c2.Close() + if la, ok := c2.LocalAddr().(*UDPAddr); !ok || !tt.nameLookup && la.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", la) + } + if ra, ok := c2.RemoteAddr().(*UDPAddr); !ok || !tt.nameLookup && ra.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", ra) + } + + if _, err := c2.Write([]byte("UDP OVER IPV6 LINKLOCAL TEST")); err != nil { + t.Fatalf("Conn.Write failed: %v", err) + } + b := make([]byte, 32) + if _, from, err := c1.ReadFrom(b); err != nil { + t.Fatalf("PacketConn.ReadFrom failed: %v", err) + } else { + if ra, ok := from.(*UDPAddr); !ok || !tt.nameLookup && ra.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", ra) + } } } } diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go index 6e5e90268..5ce7d6bea 100644 --- a/src/pkg/net/udpsock.go +++ b/src/pkg/net/udpsock.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// UDP sockets - package net import "errors" @@ -24,14 +22,18 @@ func (a *UDPAddr) String() string { if a == nil { return "<nil>" } + if a.Zone != "" { + return JoinHostPort(a.IP.String()+"%"+a.Zone, itoa(a.Port)) + } return JoinHostPort(a.IP.String(), itoa(a.Port)) } -// ResolveUDPAddr parses addr as a UDP address of the form -// host:port and resolves domain names or port names to -// numeric addresses on the network net, which must be "udp", -// "udp4" or "udp6". A literal IPv6 host address must be -// enclosed in square brackets, as in "[::]:80". +// ResolveUDPAddr parses addr as a UDP address of the form "host:port" +// or "[ipv6-host%zone]:port" and resolves a pair of domain name and +// port name on the network net, which must be "udp", "udp4" or +// "udp6". A literal address or host name for IPv6 must be enclosed +// in square brackets, as in "[::1]:80", "[ipv6-host]:http" or +// "[ipv6-host%zone]:80". func ResolveUDPAddr(net, addr string) (*UDPAddr, error) { switch net { case "udp", "udp4", "udp6": diff --git a/src/pkg/net/udpsock_plan9.go b/src/pkg/net/udpsock_plan9.go index 2a7e3d19c..12a348399 100644 --- a/src/pkg/net/udpsock_plan9.go +++ b/src/pkg/net/udpsock_plan9.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// UDP sockets for Plan 9 - package net import ( @@ -13,15 +11,13 @@ import ( "time" ) -// UDPConn is the implementation of the Conn and PacketConn -// interfaces for UDP network connections. +// UDPConn is the implementation of the Conn and PacketConn interfaces +// for UDP network connections. type UDPConn struct { conn } -func newUDPConn(fd *netFD) *UDPConn { - return &UDPConn{conn{fd}} -} +func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} } // ReadFromUDP reads a UDP packet from c, copying the payload into b. // It returns the number of bytes copied into b and the return address @@ -58,7 +54,7 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) { } // ReadMsgUDP reads a packet from c, copying the payload into b and -// the associdated out-of-band data into oob. It returns the number +// the associated out-of-band data into oob. It returns the number // of bytes copied into b, the number of bytes copied into oob, the // flags that were set on the packet and the source address of the // packet. @@ -164,7 +160,10 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) { } // ListenUDP listens for incoming UDP packets addressed to the local -// address laddr. The returned connection c's ReadFrom and WriteTo +// address laddr. Net must be "udp", "udp4", or "udp6". If laddr has +// a port of 0, ListenUDP will choose an available port. +// The LocalAddr method of the returned UDPConn can be used to +// discover the port. The returned connection's ReadFrom and WriteTo // methods can be used to receive and send UDP packets with per-packet // addressing. func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { diff --git a/src/pkg/net/udpsock_posix.go b/src/pkg/net/udpsock_posix.go index 385cd902e..b90cb030d 100644 --- a/src/pkg/net/udpsock_posix.go +++ b/src/pkg/net/udpsock_posix.go @@ -4,8 +4,6 @@ // +build darwin freebsd linux netbsd openbsd windows -// UDP sockets for POSIX - package net import ( @@ -51,8 +49,8 @@ func (a *UDPAddr) toAddr() sockaddr { return a } -// UDPConn is the implementation of the Conn and PacketConn -// interfaces for UDP network connections. +// UDPConn is the implementation of the Conn and PacketConn interfaces +// for UDP network connections. type UDPConn struct { conn } @@ -63,8 +61,9 @@ func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} } // It returns the number of bytes copied into b and the return address // that was on the packet. // -// ReadFromUDP can be made to time out and return an error with Timeout() == true -// after a fixed time limit; see SetDeadline and SetReadDeadline. +// ReadFromUDP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetReadDeadline. func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) { if !c.ok() { return 0, nil, syscall.EINVAL @@ -89,7 +88,7 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) { } // ReadMsgUDP reads a packet from c, copying the payload into b and -// the associdated out-of-band data into oob. It returns the number +// the associated out-of-band data into oob. It returns the number // of bytes copied into b, the number of bytes copied into oob, the // flags that were set on the packet and the source address of the // packet. @@ -108,12 +107,13 @@ func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, return } -// WriteToUDP writes a UDP packet to addr via c, copying the payload from b. +// WriteToUDP writes a UDP packet to addr via c, copying the payload +// from b. // -// WriteToUDP can be made to time out and return -// an error with Timeout() == true after a fixed time limit; -// see SetDeadline and SetWriteDeadline. -// On packet-oriented connections, write timeouts are rare. +// WriteToUDP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetWriteDeadline. On packet-oriented connections, write timeouts +// are rare. func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -158,8 +158,8 @@ func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err er } // DialUDP connects to the remote address raddr on the network net, -// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used -// as the local address for the connection. +// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is +// used as the local address for the connection. func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) { return dialUDP(net, laddr, raddr, noDeadline) } @@ -180,10 +180,13 @@ func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, e return newUDPConn(fd), nil } -// ListenUDP listens for incoming UDP packets addressed to the -// local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send UDP -// packets with per-packet addressing. +// ListenUDP listens for incoming UDP packets addressed to the local +// address laddr. Net must be "udp", "udp4", or "udp6". If laddr has +// a port of 0, ListenUDP will choose an available port. +// The LocalAddr method of the returned UDPConn can be used to +// discover the port. The returned connection's ReadFrom and WriteTo +// methods can be used to receive and send UDP packets with per-packet +// addressing. func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { switch net { case "udp", "udp4", "udp6": @@ -201,9 +204,9 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { } // ListenMulticastUDP listens for incoming multicast UDP packets -// addressed to the group address gaddr on ifi, which specifies -// the interface to join. ListenMulticastUDP uses default -// multicast interface if ifi is nil. +// addressed to the group address gaddr on ifi, which specifies the +// interface to join. ListenMulticastUDP uses default multicast +// interface if ifi is nil. func ListenMulticastUDP(net string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { switch net { case "udp", "udp4", "udp6": diff --git a/src/pkg/net/unicast_posix_test.go b/src/pkg/net/unicast_posix_test.go index a8855cab7..b0588f4e5 100644 --- a/src/pkg/net/unicast_posix_test.go +++ b/src/pkg/net/unicast_posix_test.go @@ -45,7 +45,7 @@ var listenerTests = []struct { // same port. func TestTCPListener(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "plan9": t.Skipf("skipping test on %q", runtime.GOOS) } @@ -69,65 +69,8 @@ func TestTCPListener(t *testing.T) { // same port. func TestUDPListener(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": - t.Skipf("skipping test on %q", runtime.GOOS) - } - - toudpnet := func(net string) string { - switch net { - case "tcp": - return "udp" - case "tcp4": - return "udp4" - case "tcp6": - return "udp6" - } - return "<nil>" - } - - for _, tt := range listenerTests { - if tt.wildcard && (testing.Short() || !*testExternal) { - continue - } - if tt.ipv6 && !supportsIPv6 { - continue - } - tt.net = toudpnet(tt.net) - l1, port := usableListenPacketPort(t, tt.net, tt.laddr) - checkFirstListener(t, tt.net, tt.laddr+":"+port, l1) - l2, err := ListenPacket(tt.net, tt.laddr+":"+port) - checkSecondListener(t, tt.net, tt.laddr+":"+port, err, l2) - l1.Close() - } -} - -func TestSimpleTCPListener(t *testing.T) { - switch runtime.GOOS { - case "plan9": - t.Skipf("skipping test on %q", runtime.GOOS) - return - } - - for _, tt := range listenerTests { - if tt.wildcard && (testing.Short() || !*testExternal) { - continue - } - if tt.ipv6 { - continue - } - l1, port := usableListenPort(t, tt.net, tt.laddr) - checkFirstListener(t, tt.net, tt.laddr+":"+port, l1) - l2, err := Listen(tt.net, tt.laddr+":"+port) - checkSecondListener(t, tt.net, tt.laddr+":"+port, err, l2) - l1.Close() - } -} - -func TestSimpleUDPListener(t *testing.T) { - switch runtime.GOOS { case "plan9": t.Skipf("skipping test on %q", runtime.GOOS) - return } toudpnet := func(net string) string { @@ -146,7 +89,7 @@ func TestSimpleUDPListener(t *testing.T) { if tt.wildcard && (testing.Short() || !*testExternal) { continue } - if tt.ipv6 { + if tt.ipv6 && !supportsIPv6 { continue } tt.net = toudpnet(tt.net) @@ -231,7 +174,7 @@ func TestDualStackTCPListener(t *testing.T) { t.Skipf("skipping test on %q", runtime.GOOS) } if !supportsIPv6 { - return + t.Skip("ipv6 is not supported") } for _, tt := range dualStackListenerTests { @@ -263,7 +206,7 @@ func TestDualStackUDPListener(t *testing.T) { t.Skipf("skipping test on %q", runtime.GOOS) } if !supportsIPv6 { - return + t.Skip("ipv6 is not supported") } toudpnet := func(net string) string { diff --git a/src/pkg/net/unix_test.go b/src/pkg/net/unix_test.go index 2eaabe86e..5e63e9d9d 100644 --- a/src/pkg/net/unix_test.go +++ b/src/pkg/net/unix_test.go @@ -33,7 +33,6 @@ func TestReadUnixgramWithUnnamedSocket(t *testing.T) { off := make(chan bool) data := [5]byte{1, 2, 3, 4, 5} - go func() { defer func() { off <- true }() s, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) @@ -54,15 +53,13 @@ func TestReadUnixgramWithUnnamedSocket(t *testing.T) { c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) n, from, err := c.ReadFrom(b) if err != nil { - t.Errorf("UnixConn.ReadFrom failed: %v", err) - return + t.Fatalf("UnixConn.ReadFrom failed: %v", err) } if from != nil { - t.Errorf("neighbor address is %v", from) + t.Fatalf("neighbor address is %v", from) } if !bytes.Equal(b[:n], data[:]) { - t.Errorf("got %v, want %v", b[:n], data[:]) - return + t.Fatalf("got %v, want %v", b[:n], data[:]) } } @@ -101,13 +98,12 @@ func TestReadUnixgramWithZeroBytesBuffer(t *testing.T) { <-off c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - var peer Addr - if _, peer, err = c.ReadFrom(nil); err != nil { - t.Errorf("UnixConn.ReadFrom failed: %v", err) - return + _, from, err := c.ReadFrom(nil) + if err != nil { + t.Fatalf("UnixConn.ReadFrom failed: %v", err) } - if peer != nil { - t.Errorf("peer adddress is %v", peer) + if from != nil { + t.Fatalf("neighbor address is %v", from) } } @@ -126,10 +122,10 @@ func TestUnixAutobind(t *testing.T) { // retrieve the autobind address autoAddr := c1.LocalAddr().(*UnixAddr) if len(autoAddr.Name) <= 1 { - t.Fatalf("Invalid autobind address: %v", autoAddr) + t.Fatalf("invalid autobind address: %v", autoAddr) } if autoAddr.Name[0] != '@' { - t.Fatalf("Invalid autobind address: %v", autoAddr) + t.Fatalf("invalid autobind address: %v", autoAddr) } c2, err := DialUnix("unixgram", nil, autoAddr) @@ -139,6 +135,112 @@ func TestUnixAutobind(t *testing.T) { defer c2.Close() if !reflect.DeepEqual(c1.LocalAddr(), c2.RemoteAddr()) { - t.Fatalf("Expected autobind address %v, got %v", c1.LocalAddr(), c2.RemoteAddr()) + t.Fatalf("expected autobind address %v, got %v", c1.LocalAddr(), c2.RemoteAddr()) + } +} + +func TestUnixConnLocalAndRemoteNames(t *testing.T) { + for _, laddr := range []string{"", testUnixAddr()} { + taddr := testUnixAddr() + ta, err := ResolveUnixAddr("unix", taddr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + ln, err := ListenUnix("unix", ta) + if err != nil { + t.Fatalf("ListenUnix failed: %v", err) + } + defer func() { + ln.Close() + os.Remove(taddr) + }() + + done := make(chan int) + go transponder(t, ln, done) + + la, err := ResolveUnixAddr("unix", laddr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c, err := DialUnix("unix", la, ta) + if err != nil { + t.Fatalf("DialUnix failed: %v", err) + } + defer func() { + c.Close() + if la != nil { + defer os.Remove(laddr) + } + }() + if _, err := c.Write([]byte("UNIXCONN LOCAL AND REMOTE NAME TEST")); err != nil { + t.Fatalf("UnixConn.Write failed: %v", err) + } + + if runtime.GOOS == "linux" && laddr == "" { + laddr = "@" // autobind feature + } + var connAddrs = [3]struct{ got, want Addr }{ + {ln.Addr(), ta}, + {c.LocalAddr(), &UnixAddr{Name: laddr, Net: "unix"}}, + {c.RemoteAddr(), ta}, + } + for _, ca := range connAddrs { + if !reflect.DeepEqual(ca.got, ca.want) { + t.Fatalf("got %#v, expected %#v", ca.got, ca.want) + } + } + + <-done + } +} + +func TestUnixgramConnLocalAndRemoteNames(t *testing.T) { + for _, laddr := range []string{"", testUnixAddr()} { + taddr := testUnixAddr() + ta, err := ResolveUnixAddr("unixgram", taddr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c1, err := ListenUnixgram("unixgram", ta) + if err != nil { + t.Fatalf("ListenUnixgram failed: %v", err) + } + defer func() { + c1.Close() + os.Remove(taddr) + }() + + var la *UnixAddr + if laddr != "" { + var err error + if la, err = ResolveUnixAddr("unixgram", laddr); err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + } + c2, err := DialUnix("unixgram", la, ta) + if err != nil { + t.Fatalf("DialUnix failed: %v", err) + } + defer func() { + c2.Close() + if la != nil { + defer os.Remove(laddr) + } + }() + + if runtime.GOOS == "linux" && laddr == "" { + laddr = "@" // autobind feature + } + var connAddrs = [4]struct{ got, want Addr }{ + {c1.LocalAddr(), ta}, + {c1.RemoteAddr(), nil}, + {c2.LocalAddr(), &UnixAddr{Name: laddr, Net: "unixgram"}}, + {c2.RemoteAddr(), ta}, + } + for _, ca := range connAddrs { + if !reflect.DeepEqual(ca.got, ca.want) { + t.Fatalf("got %#v, expected %#v", ca.got, ca.want) + } + } } } diff --git a/src/pkg/net/unixsock.go b/src/pkg/net/unixsock.go index ae0956958..21a19eca2 100644 --- a/src/pkg/net/unixsock.go +++ b/src/pkg/net/unixsock.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Unix domain sockets - package net // UnixAddr represents the address of a Unix domain socket end point. @@ -12,7 +10,8 @@ type UnixAddr struct { Net string } -// Network returns the address's network name, "unix" or "unixgram". +// Network returns the address's network name, "unix", "unixgram" or +// "unixpacket". func (a *UnixAddr) Network() string { return a.Net } @@ -36,11 +35,9 @@ func (a *UnixAddr) toAddr() Addr { // "unixpacket". func ResolveUnixAddr(net, addr string) (*UnixAddr, error) { switch net { - case "unix": - case "unixpacket": - case "unixgram": + case "unix", "unixgram", "unixpacket": + return &UnixAddr{Name: addr, Net: net}, nil default: return nil, UnknownNetworkError(net) } - return &UnixAddr{addr, net}, nil } diff --git a/src/pkg/net/unixsock_plan9.go b/src/pkg/net/unixsock_plan9.go index 00a0be5b0..8a1281fb1 100644 --- a/src/pkg/net/unixsock_plan9.go +++ b/src/pkg/net/unixsock_plan9.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Unix domain sockets stubs for Plan 9 - package net import ( @@ -128,14 +126,18 @@ func (l *UnixListener) SetDeadline(t time.Time) error { // File returns a copy of the underlying os.File, set to blocking // mode. It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. +// +// The returned os.File's file descriptor is different from the +// connection's. Attempting to change properties of the original +// using this duplicate may or may not have the desired effect. func (l *UnixListener) File() (*os.File, error) { return nil, syscall.EPLAN9 } // ListenUnixgram listens for incoming Unix datagram packets addressed -// to the local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send packets with -// per-packet addressing. The network net must be "unixgram". +// to the local address laddr. The network net must be "unixgram". +// The returned connection's ReadFrom and WriteTo methods can be used +// to receive and send packets with per-packet addressing. func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) { return nil, syscall.EPLAN9 } diff --git a/src/pkg/net/unixsock_posix.go b/src/pkg/net/unixsock_posix.go index 6d6ce3f5e..5db30df95 100644 --- a/src/pkg/net/unixsock_posix.go +++ b/src/pkg/net/unixsock_posix.go @@ -4,8 +4,6 @@ // +build darwin freebsd linux netbsd openbsd windows -// Unix domain sockets - package net import ( @@ -15,6 +13,13 @@ import ( "time" ) +func (a *UnixAddr) isUnnamed() bool { + if a == nil || a.Name == "" { + return true + } + return false +} + func unixSocket(net string, laddr, raddr *UnixAddr, mode string, deadline time.Time) (*netFD, error) { var sotype int switch net { @@ -31,12 +36,12 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string, deadline time.T var la, ra syscall.Sockaddr switch mode { case "dial": - if laddr != nil { + if !laddr.isUnnamed() { la = &syscall.SockaddrUnix{Name: laddr.Name} } if raddr != nil { ra = &syscall.SockaddrUnix{Name: raddr.Name} - } else if sotype != syscall.SOCK_DGRAM || laddr == nil { + } else if sotype != syscall.SOCK_DGRAM || laddr.isUnnamed() { return nil, &OpError{Op: mode, Net: net, Err: errMissingAddress} } case "listen": @@ -69,21 +74,21 @@ error: func sockaddrToUnix(sa syscall.Sockaddr) Addr { if s, ok := sa.(*syscall.SockaddrUnix); ok { - return &UnixAddr{s.Name, "unix"} + return &UnixAddr{Name: s.Name, Net: "unix"} } return nil } func sockaddrToUnixgram(sa syscall.Sockaddr) Addr { if s, ok := sa.(*syscall.SockaddrUnix); ok { - return &UnixAddr{s.Name, "unixgram"} + return &UnixAddr{Name: s.Name, Net: "unixgram"} } return nil } func sockaddrToUnixpacket(sa syscall.Sockaddr) Addr { if s, ok := sa.(*syscall.SockaddrUnix); ok { - return &UnixAddr{s.Name, "unixpacket"} + return &UnixAddr{Name: s.Name, Net: "unixpacket"} } return nil } @@ -92,14 +97,13 @@ func sotypeToNet(sotype int) string { switch sotype { case syscall.SOCK_STREAM: return "unix" - case syscall.SOCK_SEQPACKET: - return "unixpacket" case syscall.SOCK_DGRAM: return "unixgram" + case syscall.SOCK_SEQPACKET: + return "unixpacket" default: panic("sotypeToNet unknown socket type") } - return "" } // UnixConn is an implementation of the Conn interface for connections @@ -125,7 +129,7 @@ func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err error) { switch sa := sa.(type) { case *syscall.SockaddrUnix: if sa.Name != "" { - addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)} + addr = &UnixAddr{Name: sa.Name, Net: sotypeToNet(c.fd.sotype)} } } return @@ -152,7 +156,7 @@ func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAdd switch sa := sa.(type) { case *syscall.SockaddrUnix: if sa.Name != "" { - addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)} + addr = &UnixAddr{Name: sa.Name, Net: sotypeToNet(c.fd.sotype)} } } return @@ -267,7 +271,7 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { } err = syscall.Listen(fd.sysfd, listenerBacklog) if err != nil { - closesocket(fd.sysfd) + fd.Close() return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: err} } return &UnixListener{fd, laddr.Name}, nil @@ -335,12 +339,16 @@ func (l *UnixListener) SetDeadline(t time.Time) (err error) { // File returns a copy of the underlying os.File, set to blocking // mode. It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. +// +// The returned os.File's file descriptor is different from the +// connection's. Attempting to change properties of the original +// using this duplicate may or may not have the desired effect. func (l *UnixListener) File() (f *os.File, err error) { return l.fd.dup() } // ListenUnixgram listens for incoming Unix datagram packets addressed -// to the local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send packets with -// per-packet addressing. The network net must be "unixgram". +// to the local address laddr. The network net must be "unixgram". +// The returned connection's ReadFrom and WriteTo methods can be used +// to receive and send packets with per-packet addressing. func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) { switch net { case "unixgram": diff --git a/src/pkg/net/url/url.go b/src/pkg/net/url/url.go index a39964ea1..459dc473c 100644 --- a/src/pkg/net/url/url.go +++ b/src/pkg/net/url/url.go @@ -317,23 +317,22 @@ func getscheme(rawurl string) (scheme, path string, err error) { // Maybe s is of the form t c u. // If so, return t, c u (or t, u if cutc == true). // If not, return s, "". -func split(s string, c byte, cutc bool) (string, string) { - for i := 0; i < len(s); i++ { - if s[i] == c { - if cutc { - return s[0:i], s[i+1:] - } - return s[0:i], s[i:] - } +func split(s string, c string, cutc bool) (string, string) { + i := strings.Index(s, c) + if i < 0 { + return s, "" + } + if cutc { + return s[0:i], s[i+len(c):] } - return s, "" + return s[0:i], s[i:] } // Parse parses rawurl into a URL structure. // The rawurl may be relative or absolute. func Parse(rawurl string) (url *URL, err error) { // Cut off #frag - u, frag := split(rawurl, '#', true) + u, frag := split(rawurl, "#", true) if url, err = parse(u, false); err != nil { return nil, err } @@ -362,7 +361,7 @@ func ParseRequestURI(rawurl string) (url *URL, err error) { func parse(rawurl string, viaRequest bool) (url *URL, err error) { var rest string - if rawurl == "" { + if rawurl == "" && viaRequest { err = errors.New("empty url") goto Error } @@ -380,7 +379,7 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) { } url.Scheme = strings.ToLower(url.Scheme) - rest, url.RawQuery = split(rest, '?', true) + rest, url.RawQuery = split(rest, "?", true) if !strings.HasPrefix(rest, "/") { if url.Scheme != "" { @@ -396,7 +395,7 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) { if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") { var authority string - authority, rest = split(rest[2:], '/', false) + authority, rest = split(rest[2:], "/", false) url.User, url.Host, err = parseAuthority(authority) if err != nil { goto Error @@ -428,7 +427,7 @@ func parseAuthority(authority string) (user *Userinfo, host string, err error) { } user = User(userinfo) } else { - username, password := split(userinfo, ':', true) + username, password := split(userinfo, ":", true) if username, err = unescape(username, encodeUserPassword); err != nil { return } @@ -583,43 +582,39 @@ func (v Values) Encode() string { } // resolvePath applies special path segments from refs and applies -// them to base, per RFC 2396. -func resolvePath(basepath string, refpath string) string { - base := strings.Split(basepath, "/") - refs := strings.Split(refpath, "/") - if len(base) == 0 { - base = []string{""} +// them to base, per RFC 3986. +func resolvePath(base, ref string) string { + var full string + if ref == "" { + full = base + } else if ref[0] != '/' { + i := strings.LastIndex(base, "/") + full = base[:i+1] + ref + } else { + full = ref } - - rm := true - for idx, ref := range refs { - switch { - case ref == ".": - if idx == 0 { - base[len(base)-1] = "" - rm = true - } else { - rm = false - } - case ref == "..": - newLen := len(base) - 1 - if newLen < 1 { - newLen = 1 - } - base = base[0:newLen] - if rm { - base[len(base)-1] = "" + if full == "" { + return "" + } + var dst []string + src := strings.Split(full, "/") + for _, elem := range src { + switch elem { + case ".": + // drop + case "..": + if len(dst) > 0 { + dst = dst[:len(dst)-1] } default: - if idx == 0 || base[len(base)-1] == "" { - base[len(base)-1] = ref - } else { - base = append(base, ref) - } - rm = false + dst = append(dst, elem) } } - return strings.Join(base, "/") + if last := src[len(src)-1]; last == "." || last == ".." { + // Add final slash to the joined path. + dst = append(dst, "") + } + return "/" + strings.TrimLeft(strings.Join(dst, "/"), "/") } // IsAbs returns true if the URL is absolute. @@ -639,43 +634,39 @@ func (u *URL) Parse(ref string) (*URL, error) { } // ResolveReference resolves a URI reference to an absolute URI from -// an absolute base URI, per RFC 2396 Section 5.2. The URI reference +// an absolute base URI, per RFC 3986 Section 5.2. The URI reference // may be relative or absolute. ResolveReference always returns a new // URL instance, even if the returned URL is identical to either the // base or reference. If ref is an absolute URL, then ResolveReference // ignores base and returns a copy of ref. func (u *URL) ResolveReference(ref *URL) *URL { - if ref.IsAbs() { - url := *ref + url := *ref + if ref.Scheme == "" { + url.Scheme = u.Scheme + } + if ref.Scheme != "" || ref.Host != "" || ref.User != nil { + // The "absoluteURI" or "net_path" cases. + url.Path = resolvePath(ref.Path, "") return &url } - // relativeURI = ( net_path | abs_path | rel_path ) [ "?" query ] - url := *u - url.RawQuery = ref.RawQuery - url.Fragment = ref.Fragment if ref.Opaque != "" { - url.Opaque = ref.Opaque url.User = nil url.Host = "" url.Path = "" return &url } - if ref.Host != "" || ref.User != nil { - // The "net_path" case. - url.Host = ref.Host - url.User = ref.User - } - if strings.HasPrefix(ref.Path, "/") { - // The "abs_path" case. - url.Path = ref.Path - } else { - // The "rel_path" case. - path := resolvePath(u.Path, ref.Path) - if !strings.HasPrefix(path, "/") { - path = "/" + path + if ref.Path == "" { + if ref.RawQuery == "" { + url.RawQuery = u.RawQuery + if ref.Fragment == "" { + url.Fragment = u.Fragment + } } - url.Path = path } + // The "abs_path" or "rel_path" cases. + url.Host = u.Host + url.User = u.User + url.Path = resolvePath(u.Path, ref.Path) return &url } diff --git a/src/pkg/net/url/url_test.go b/src/pkg/net/url/url_test.go index 4c4f406c2..9d81289ce 100644 --- a/src/pkg/net/url/url_test.go +++ b/src/pkg/net/url/url_test.go @@ -523,18 +523,18 @@ func TestEncodeQuery(t *testing.T) { var resolvePathTests = []struct { base, ref, expected string }{ - {"a/b", ".", "a/"}, - {"a/b", "c", "a/c"}, - {"a/b", "..", ""}, - {"a/", "..", ""}, - {"a/", "../..", ""}, - {"a/b/c", "..", "a/"}, - {"a/b/c", "../d", "a/d"}, - {"a/b/c", ".././d", "a/d"}, - {"a/b", "./..", ""}, - {"a/./b", ".", "a/./"}, - {"a/../", ".", "a/../"}, - {"a/.././b", "c", "a/.././c"}, + {"a/b", ".", "/a/"}, + {"a/b", "c", "/a/c"}, + {"a/b", "..", "/"}, + {"a/", "..", "/"}, + {"a/", "../..", "/"}, + {"a/b/c", "..", "/a/"}, + {"a/b/c", "../d", "/a/d"}, + {"a/b/c", ".././d", "/a/d"}, + {"a/b", "./..", "/"}, + {"a/./b", ".", "/a/"}, + {"a/../", ".", "/"}, + {"a/.././b", "c", "/c"}, } func TestResolvePath(t *testing.T) { @@ -587,16 +587,71 @@ var resolveReferenceTests = []struct { {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/././../../tail", "http://foo.com/bar/quux/tail"}, {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/./.././../tail", "http://foo.com/bar/quux/tail"}, {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/dotdot/./../../.././././tail", "http://foo.com/bar/quux/tail"}, - {"http://foo.com/bar/baz", "quux/./dotdot/../dotdot/../dot/./tail/..", "http://foo.com/bar/quux/dot"}, + {"http://foo.com/bar/baz", "quux/./dotdot/../dotdot/../dot/./tail/..", "http://foo.com/bar/quux/dot/"}, - // "." and ".." in the base aren't special - {"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/./dotdot/../baz"}, + // Remove any dot-segments prior to forming the target URI. + // http://tools.ietf.org/html/rfc3986#section-5.2.4 + {"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/baz"}, // Triple dot isn't special {"http://foo.com/bar", "...", "http://foo.com/..."}, // Fragment {"http://foo.com/bar", ".#frag", "http://foo.com/#frag"}, + + // RFC 3986: Normal Examples + // http://tools.ietf.org/html/rfc3986#section-5.4.1 + {"http://a/b/c/d;p?q", "g:h", "g:h"}, + {"http://a/b/c/d;p?q", "g", "http://a/b/c/g"}, + {"http://a/b/c/d;p?q", "./g", "http://a/b/c/g"}, + {"http://a/b/c/d;p?q", "g/", "http://a/b/c/g/"}, + {"http://a/b/c/d;p?q", "/g", "http://a/g"}, + {"http://a/b/c/d;p?q", "//g", "http://g"}, + {"http://a/b/c/d;p?q", "?y", "http://a/b/c/d;p?y"}, + {"http://a/b/c/d;p?q", "g?y", "http://a/b/c/g?y"}, + {"http://a/b/c/d;p?q", "#s", "http://a/b/c/d;p?q#s"}, + {"http://a/b/c/d;p?q", "g#s", "http://a/b/c/g#s"}, + {"http://a/b/c/d;p?q", "g?y#s", "http://a/b/c/g?y#s"}, + {"http://a/b/c/d;p?q", ";x", "http://a/b/c/;x"}, + {"http://a/b/c/d;p?q", "g;x", "http://a/b/c/g;x"}, + {"http://a/b/c/d;p?q", "g;x?y#s", "http://a/b/c/g;x?y#s"}, + {"http://a/b/c/d;p?q", "", "http://a/b/c/d;p?q"}, + {"http://a/b/c/d;p?q", ".", "http://a/b/c/"}, + {"http://a/b/c/d;p?q", "./", "http://a/b/c/"}, + {"http://a/b/c/d;p?q", "..", "http://a/b/"}, + {"http://a/b/c/d;p?q", "../", "http://a/b/"}, + {"http://a/b/c/d;p?q", "../g", "http://a/b/g"}, + {"http://a/b/c/d;p?q", "../..", "http://a/"}, + {"http://a/b/c/d;p?q", "../../", "http://a/"}, + {"http://a/b/c/d;p?q", "../../g", "http://a/g"}, + + // RFC 3986: Abnormal Examples + // http://tools.ietf.org/html/rfc3986#section-5.4.2 + {"http://a/b/c/d;p?q", "../../../g", "http://a/g"}, + {"http://a/b/c/d;p?q", "../../../../g", "http://a/g"}, + {"http://a/b/c/d;p?q", "/./g", "http://a/g"}, + {"http://a/b/c/d;p?q", "/../g", "http://a/g"}, + {"http://a/b/c/d;p?q", "g.", "http://a/b/c/g."}, + {"http://a/b/c/d;p?q", ".g", "http://a/b/c/.g"}, + {"http://a/b/c/d;p?q", "g..", "http://a/b/c/g.."}, + {"http://a/b/c/d;p?q", "..g", "http://a/b/c/..g"}, + {"http://a/b/c/d;p?q", "./../g", "http://a/b/g"}, + {"http://a/b/c/d;p?q", "./g/.", "http://a/b/c/g/"}, + {"http://a/b/c/d;p?q", "g/./h", "http://a/b/c/g/h"}, + {"http://a/b/c/d;p?q", "g/../h", "http://a/b/c/h"}, + {"http://a/b/c/d;p?q", "g;x=1/./y", "http://a/b/c/g;x=1/y"}, + {"http://a/b/c/d;p?q", "g;x=1/../y", "http://a/b/c/y"}, + {"http://a/b/c/d;p?q", "g?y/./x", "http://a/b/c/g?y/./x"}, + {"http://a/b/c/d;p?q", "g?y/../x", "http://a/b/c/g?y/../x"}, + {"http://a/b/c/d;p?q", "g#s/./x", "http://a/b/c/g#s/./x"}, + {"http://a/b/c/d;p?q", "g#s/../x", "http://a/b/c/g#s/../x"}, + + // Extras. + {"https://a/b/c/d;p?q", "//g?q", "https://g?q"}, + {"https://a/b/c/d;p?q", "//g#s", "https://g#s"}, + {"https://a/b/c/d;p?q", "//g/d/e/f?y#s", "https://g/d/e/f?y#s"}, + {"https://a/b/c/d;p#s", "?y", "https://a/b/c/d;p?y"}, + {"https://a/b/c/d;p?q#s", "?y", "https://a/b/c/d;p?y"}, } func TestResolveReference(t *testing.T) { @@ -607,91 +662,44 @@ func TestResolveReference(t *testing.T) { } return u } + opaque := &URL{Scheme: "scheme", Opaque: "opaque"} for _, test := range resolveReferenceTests { base := mustParse(test.base) rel := mustParse(test.rel) url := base.ResolveReference(rel) - urlStr := url.String() - if urlStr != test.expected { - t.Errorf("Resolving %q + %q != %q; got %q", test.base, test.rel, test.expected, urlStr) + if url.String() != test.expected { + t.Errorf("URL(%q).ResolveReference(%q) == %q, got %q", test.base, test.rel, test.expected, url.String()) } - } - - // Test that new instances are returned. - base := mustParse("http://foo.com/") - abs := base.ResolveReference(mustParse(".")) - if base == abs { - t.Errorf("Expected no-op reference to return new URL instance.") - } - barRef := mustParse("http://bar.com/") - abs = base.ResolveReference(barRef) - if abs == barRef { - t.Errorf("Expected resolution of absolute reference to return new URL instance.") - } - - // Test the convenience wrapper too - base = mustParse("http://foo.com/path/one/") - abs, _ = base.Parse("../two") - expected := "http://foo.com/path/two" - if abs.String() != expected { - t.Errorf("Parse wrapper got %q; expected %q", abs.String(), expected) - } - _, err := base.Parse("") - if err == nil { - t.Errorf("Expected an error from Parse wrapper parsing an empty string.") - } - - // Ensure Opaque resets the URL. - base = mustParse("scheme://user@foo.com/bar") - abs = base.ResolveReference(&URL{Opaque: "opaque"}) - want := mustParse("scheme:opaque") - if *abs != *want { - t.Errorf("ResolveReference failed to resolve opaque URL: want %#v, got %#v", abs, want) - } -} - -func TestResolveReferenceOpaque(t *testing.T) { - mustParse := func(url string) *URL { - u, err := Parse(url) + // Ensure that new instances are returned. + if base == url { + t.Errorf("Expected URL.ResolveReference to return new URL instance.") + } + // Test the convenience wrapper too. + url, err := base.Parse(test.rel) if err != nil { - t.Fatalf("Expected URL to parse: %q, got error: %v", url, err) + t.Errorf("URL(%q).Parse(%q) failed: %v", test.base, test.rel, err) + } else if url.String() != test.expected { + t.Errorf("URL(%q).Parse(%q) == %q, got %q", test.base, test.rel, test.expected, url.String()) + } else if base == url { + // Ensure that new instances are returned for the wrapper too. + t.Errorf("Expected URL.Parse to return new URL instance.") } - return u - } - for _, test := range resolveReferenceTests { - base := mustParse(test.base) - rel := mustParse(test.rel) - url := base.ResolveReference(rel) - urlStr := url.String() - if urlStr != test.expected { - t.Errorf("Resolving %q + %q != %q; got %q", test.base, test.rel, test.expected, urlStr) + // Ensure Opaque resets the URL. + url = base.ResolveReference(opaque) + if *url != *opaque { + t.Errorf("ResolveReference failed to resolve opaque URL: want %#v, got %#v", url, opaque) + } + // Test the convenience wrapper with an opaque URL too. + url, err = base.Parse("scheme:opaque") + if err != nil { + t.Errorf(`URL(%q).Parse("scheme:opaque") failed: %v`, test.base, err) + } else if *url != *opaque { + t.Errorf("Parse failed to resolve opaque URL: want %#v, got %#v", url, opaque) + } else if base == url { + // Ensure that new instances are returned, again. + t.Errorf("Expected URL.Parse to return new URL instance.") } } - - // Test that new instances are returned. - base := mustParse("http://foo.com/") - abs := base.ResolveReference(mustParse(".")) - if base == abs { - t.Errorf("Expected no-op reference to return new URL instance.") - } - barRef := mustParse("http://bar.com/") - abs = base.ResolveReference(barRef) - if abs == barRef { - t.Errorf("Expected resolution of absolute reference to return new URL instance.") - } - - // Test the convenience wrapper too - base = mustParse("http://foo.com/path/one/") - abs, _ = base.Parse("../two") - expected := "http://foo.com/path/two" - if abs.String() != expected { - t.Errorf("Parse wrapper got %q; expected %q", abs.String(), expected) - } - _, err := base.Parse("") - if err == nil { - t.Errorf("Expected an error from Parse wrapper parsing an empty string.") - } - } func TestQueryValues(t *testing.T) { |