summaryrefslogtreecommitdiff
path: root/src/pkg/net
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/net')
-rw-r--r--src/pkg/net/cgo_bsd.go4
-rw-r--r--src/pkg/net/cgo_linux.go8
-rw-r--r--src/pkg/net/cgo_netbsd.go14
-rw-r--r--src/pkg/net/cgo_openbsd.go14
-rw-r--r--src/pkg/net/cgo_unix.go10
-rw-r--r--src/pkg/net/conn_test.go114
-rw-r--r--src/pkg/net/dial.go243
-rw-r--r--src/pkg/net/dial_test.go114
-rw-r--r--src/pkg/net/dialgoogle_test.go24
-rw-r--r--src/pkg/net/dnsclient.go7
-rw-r--r--src/pkg/net/dnsclient_unix.go28
-rw-r--r--src/pkg/net/dnsconfig_unix.go (renamed from src/pkg/net/dnsconfig.go)0
-rw-r--r--src/pkg/net/dnsmsg.go2
-rw-r--r--src/pkg/net/example_test.go2
-rw-r--r--src/pkg/net/fd_bsd.go (renamed from src/pkg/net/fd_openbsd.go)9
-rw-r--r--src/pkg/net/fd_darwin.go7
-rw-r--r--src/pkg/net/fd_freebsd.go116
-rw-r--r--src/pkg/net/fd_linux.go7
-rw-r--r--src/pkg/net/fd_netbsd.go116
-rw-r--r--src/pkg/net/fd_plan9.go129
-rw-r--r--src/pkg/net/fd_posix_test.go57
-rw-r--r--src/pkg/net/fd_unix.go (renamed from src/pkg/net/fd.go)319
-rw-r--r--src/pkg/net/fd_unix_test.go58
-rw-r--r--src/pkg/net/fd_windows.go293
-rw-r--r--src/pkg/net/file_plan9.go131
-rw-r--r--src/pkg/net/file_test.go8
-rw-r--r--src/pkg/net/file_unix.go (renamed from src/pkg/net/file.go)40
-rw-r--r--src/pkg/net/http/cgi/child.go12
-rw-r--r--src/pkg/net/http/cgi/child_test.go24
-rw-r--r--src/pkg/net/http/cgi/host_test.go86
-rw-r--r--src/pkg/net/http/cgi/plan9_test.go18
-rw-r--r--src/pkg/net/http/cgi/posix_test.go21
-rwxr-xr-xsrc/pkg/net/http/cgi/testdata/test.cgi55
-rw-r--r--src/pkg/net/http/chunked.go59
-rw-r--r--src/pkg/net/http/chunked_test.go54
-rw-r--r--src/pkg/net/http/client.go178
-rw-r--r--src/pkg/net/http/client_test.go250
-rw-r--r--src/pkg/net/http/cookie.go9
-rw-r--r--src/pkg/net/http/cookie_test.go2
-rw-r--r--src/pkg/net/http/cookiejar/jar.go494
-rw-r--r--src/pkg/net/http/cookiejar/jar_test.go1267
-rw-r--r--src/pkg/net/http/cookiejar/punycode.go159
-rw-r--r--src/pkg/net/http/cookiejar/punycode_test.go161
-rw-r--r--src/pkg/net/http/example_test.go2
-rw-r--r--src/pkg/net/http/export_test.go23
-rw-r--r--src/pkg/net/http/filetransport_test.go10
-rw-r--r--src/pkg/net/http/fs.go196
-rw-r--r--src/pkg/net/http/fs_test.go439
-rw-r--r--src/pkg/net/http/header.go135
-rw-r--r--src/pkg/net/http/header_test.go123
-rw-r--r--src/pkg/net/http/httptest/example_test.go50
-rw-r--r--src/pkg/net/http/httptest/recorder.go24
-rw-r--r--src/pkg/net/http/httptest/recorder_test.go90
-rw-r--r--src/pkg/net/http/httptest/server.go73
-rw-r--r--src/pkg/net/http/httputil/chunked.go59
-rw-r--r--src/pkg/net/http/httputil/chunked_test.go54
-rw-r--r--src/pkg/net/http/httputil/dump.go4
-rw-r--r--src/pkg/net/http/httputil/reverseproxy.go58
-rw-r--r--src/pkg/net/http/httputil/reverseproxy_test.go84
-rw-r--r--src/pkg/net/http/jar.go19
-rw-r--r--src/pkg/net/http/lex.go206
-rw-r--r--src/pkg/net/http/lex_test.go65
-rw-r--r--src/pkg/net/http/npn_test.go118
-rw-r--r--src/pkg/net/http/pprof/pprof.go17
-rw-r--r--src/pkg/net/http/proxy_test.go4
-rw-r--r--src/pkg/net/http/range_test.go22
-rw-r--r--src/pkg/net/http/readrequest_test.go48
-rw-r--r--src/pkg/net/http/request.go247
-rw-r--r--src/pkg/net/http/request_test.go189
-rw-r--r--src/pkg/net/http/requestwrite_test.go63
-rw-r--r--src/pkg/net/http/response.go14
-rw-r--r--src/pkg/net/http/response_test.go162
-rw-r--r--src/pkg/net/http/responsewrite_test.go138
-rw-r--r--src/pkg/net/http/serve_test.go578
-rw-r--r--src/pkg/net/http/server.go873
-rw-r--r--src/pkg/net/http/server_test.go95
-rw-r--r--src/pkg/net/http/transfer.go100
-rw-r--r--src/pkg/net/http/transfer_test.go37
-rw-r--r--src/pkg/net/http/transport.go392
-rw-r--r--src/pkg/net/http/transport_test.go577
-rw-r--r--src/pkg/net/http/z_last_test.go60
-rw-r--r--src/pkg/net/interface.go16
-rw-r--r--src/pkg/net/interface_bsd.go104
-rw-r--r--src/pkg/net/interface_bsd_test.go52
-rw-r--r--src/pkg/net/interface_darwin.go35
-rw-r--r--src/pkg/net/interface_freebsd.go35
-rw-r--r--src/pkg/net/interface_linux.go109
-rw-r--r--src/pkg/net/interface_linux_test.go50
-rw-r--r--src/pkg/net/interface_netbsd.go10
-rw-r--r--src/pkg/net/interface_openbsd.go10
-rw-r--r--src/pkg/net/interface_stub.go17
-rw-r--r--src/pkg/net/interface_test.go116
-rw-r--r--src/pkg/net/interface_unix_test.go145
-rw-r--r--src/pkg/net/interface_windows.go28
-rw-r--r--src/pkg/net/ip.go11
-rw-r--r--src/pkg/net/ip_test.go123
-rw-r--r--src/pkg/net/ipraw_test.go413
-rw-r--r--src/pkg/net/iprawsock.go57
-rw-r--r--src/pkg/net/iprawsock_plan9.go97
-rw-r--r--src/pkg/net/iprawsock_posix.go165
-rw-r--r--src/pkg/net/ipsock.go180
-rw-r--r--src/pkg/net/ipsock_plan9.go214
-rw-r--r--src/pkg/net/ipsock_posix.go35
-rw-r--r--src/pkg/net/lookup.go (renamed from src/pkg/net/doc.go)46
-rw-r--r--src/pkg/net/lookup_plan9.go15
-rw-r--r--src/pkg/net/lookup_test.go36
-rw-r--r--src/pkg/net/lookup_unix.go13
-rw-r--r--src/pkg/net/lookup_windows.go155
-rw-r--r--src/pkg/net/mail/message.go15
-rw-r--r--src/pkg/net/mail/message_test.go17
-rw-r--r--src/pkg/net/multicast_posix_test.go180
-rw-r--r--src/pkg/net/multicast_test.go234
-rw-r--r--src/pkg/net/net.go169
-rw-r--r--src/pkg/net/net_test.go126
-rw-r--r--src/pkg/net/newpollserver_unix.go (renamed from src/pkg/net/newpollserver.go)2
-rw-r--r--src/pkg/net/packetconn_test.go200
-rw-r--r--src/pkg/net/parse_test.go3
-rw-r--r--src/pkg/net/port.go73
-rw-r--r--src/pkg/net/port_test.go2
-rw-r--r--src/pkg/net/port_unix.go69
-rw-r--r--src/pkg/net/protoconn_test.go358
-rw-r--r--src/pkg/net/rpc/client.go43
-rw-r--r--src/pkg/net/rpc/jsonrpc/all_test.go61
-rw-r--r--src/pkg/net/rpc/jsonrpc/server.go13
-rw-r--r--src/pkg/net/rpc/server.go84
-rw-r--r--src/pkg/net/rpc/server_test.go72
-rw-r--r--src/pkg/net/sendfile_freebsd.go105
-rw-r--r--src/pkg/net/sendfile_linux.go4
-rw-r--r--src/pkg/net/sendfile_stub.go2
-rw-r--r--src/pkg/net/sendfile_windows.go4
-rw-r--r--src/pkg/net/server_test.go51
-rw-r--r--src/pkg/net/smtp/smtp.go65
-rw-r--r--src/pkg/net/smtp/smtp_test.go229
-rw-r--r--src/pkg/net/sock.go87
-rw-r--r--src/pkg/net/sock_bsd.go31
-rw-r--r--src/pkg/net/sock_cloexec.go69
-rw-r--r--src/pkg/net/sock_linux.go31
-rw-r--r--src/pkg/net/sock_posix.go67
-rw-r--r--src/pkg/net/sock_unix.go36
-rw-r--r--src/pkg/net/sock_windows.go29
-rw-r--r--src/pkg/net/sockopt_posix.go (renamed from src/pkg/net/sockopt.go)37
-rw-r--r--src/pkg/net/sockoptip.go219
-rw-r--r--src/pkg/net/sockoptip_bsd.go36
-rw-r--r--src/pkg/net/sockoptip_darwin.go90
-rw-r--r--src/pkg/net/sockoptip_freebsd.go92
-rw-r--r--src/pkg/net/sockoptip_linux.go101
-rw-r--r--src/pkg/net/sockoptip_netbsd.go39
-rw-r--r--src/pkg/net/sockoptip_openbsd.go90
-rw-r--r--src/pkg/net/sockoptip_posix.go73
-rw-r--r--src/pkg/net/sockoptip_windows.go61
-rw-r--r--src/pkg/net/sys_cloexec.go54
-rw-r--r--src/pkg/net/tcp_test.go206
-rw-r--r--src/pkg/net/tcpsock.go12
-rw-r--r--src/pkg/net/tcpsock_plan9.go161
-rw-r--r--src/pkg/net/tcpsock_posix.go189
-rw-r--r--src/pkg/net/textproto/reader.go129
-rw-r--r--src/pkg/net/textproto/reader_test.go94
-rw-r--r--src/pkg/net/textproto/textproto.go31
-rw-r--r--src/pkg/net/timeout_test.go540
-rw-r--r--src/pkg/net/udp_test.go63
-rw-r--r--src/pkg/net/udpsock.go16
-rw-r--r--src/pkg/net/udpsock_plan9.go134
-rw-r--r--src/pkg/net/udpsock_posix.go229
-rw-r--r--src/pkg/net/unicast_posix_test.go (renamed from src/pkg/net/unicast_test.go)112
-rw-r--r--src/pkg/net/unix_test.go144
-rw-r--r--src/pkg/net/unixsock_plan9.go145
-rw-r--r--src/pkg/net/unixsock_posix.go286
-rw-r--r--src/pkg/net/url/url.go92
-rw-r--r--src/pkg/net/url/url_test.go107
169 files changed, 13730 insertions, 4841 deletions
diff --git a/src/pkg/net/cgo_bsd.go b/src/pkg/net/cgo_bsd.go
index 63750f7a3..3b38e3d83 100644
--- a/src/pkg/net/cgo_bsd.go
+++ b/src/pkg/net/cgo_bsd.go
@@ -11,6 +11,6 @@ package net
*/
import "C"
-func cgoAddrInfoMask() C.int {
- return C.AI_MASK
+func cgoAddrInfoFlags() C.int {
+ return (C.AI_CANONNAME | C.AI_V4MAPPED | C.AI_ALL) & C.AI_MASK
}
diff --git a/src/pkg/net/cgo_linux.go b/src/pkg/net/cgo_linux.go
index 8d4413d2d..f6cefa89a 100644
--- a/src/pkg/net/cgo_linux.go
+++ b/src/pkg/net/cgo_linux.go
@@ -9,6 +9,12 @@ package net
*/
import "C"
-func cgoAddrInfoMask() C.int {
+func cgoAddrInfoFlags() C.int {
+ // NOTE(rsc): In theory there are approximately balanced
+ // arguments for and against including AI_ADDRCONFIG
+ // in the flags (it includes IPv4 results only on IPv4 systems,
+ // and similarly for IPv6), but in practice setting it causes
+ // getaddrinfo to return the wrong canonical name on Linux.
+ // So definitely leave it out.
return C.AI_CANONNAME | C.AI_V4MAPPED | C.AI_ALL
}
diff --git a/src/pkg/net/cgo_netbsd.go b/src/pkg/net/cgo_netbsd.go
new file mode 100644
index 000000000..aeaf8e568
--- /dev/null
+++ b/src/pkg/net/cgo_netbsd.go
@@ -0,0 +1,14 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+/*
+#include <netdb.h>
+*/
+import "C"
+
+func cgoAddrInfoFlags() C.int {
+ return C.AI_CANONNAME
+}
diff --git a/src/pkg/net/cgo_openbsd.go b/src/pkg/net/cgo_openbsd.go
new file mode 100644
index 000000000..aeaf8e568
--- /dev/null
+++ b/src/pkg/net/cgo_openbsd.go
@@ -0,0 +1,14 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+/*
+#include <netdb.h>
+*/
+import "C"
+
+func cgoAddrInfoFlags() C.int {
+ return C.AI_CANONNAME
+}
diff --git a/src/pkg/net/cgo_unix.go b/src/pkg/net/cgo_unix.go
index 36a3f3d34..7476140eb 100644
--- a/src/pkg/net/cgo_unix.go
+++ b/src/pkg/net/cgo_unix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux
+// +build darwin freebsd linux netbsd openbsd
package net
@@ -81,13 +81,7 @@ func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err error, complet
var res *C.struct_addrinfo
var hints C.struct_addrinfo
- // NOTE(rsc): In theory there are approximately balanced
- // arguments for and against including AI_ADDRCONFIG
- // in the flags (it includes IPv4 results only on IPv4 systems,
- // and similarly for IPv6), but in practice setting it causes
- // getaddrinfo to return the wrong canonical name on Linux.
- // So definitely leave it out.
- hints.ai_flags = (C.AI_ALL | C.AI_V4MAPPED | C.AI_CANONNAME) & cgoAddrInfoMask()
+ hints.ai_flags = cgoAddrInfoFlags()
h := C.CString(name)
defer C.free(unsafe.Pointer(h))
diff --git a/src/pkg/net/conn_test.go b/src/pkg/net/conn_test.go
new file mode 100644
index 000000000..fdb90862f
--- /dev/null
+++ b/src/pkg/net/conn_test.go
@@ -0,0 +1,114 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements API tests across platforms and will never have a build
+// tag.
+
+package net
+
+import (
+ "os"
+ "runtime"
+ "testing"
+ "time"
+)
+
+var connTests = []struct {
+ net string
+ addr string
+}{
+ {"tcp", "127.0.0.1:0"},
+ {"unix", testUnixAddr()},
+ {"unixpacket", testUnixAddr()},
+}
+
+// someTimeout is used just to test that net.Conn implementations
+// don't explode when their SetFooDeadline methods are called.
+// It isn't actually used for testing timeouts.
+const someTimeout = 10 * time.Second
+
+func TestConnAndListener(t *testing.T) {
+ for _, tt := range connTests {
+ switch tt.net {
+ case "unix", "unixpacket":
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ continue
+ }
+ if tt.net == "unixpacket" && runtime.GOOS != "linux" {
+ continue
+ }
+ }
+
+ ln, err := Listen(tt.net, tt.addr)
+ if err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+ defer func(ln Listener, net, addr string) {
+ ln.Close()
+ switch net {
+ case "unix", "unixpacket":
+ os.Remove(addr)
+ }
+ }(ln, tt.net, tt.addr)
+ ln.Addr()
+
+ done := make(chan int)
+ go transponder(t, ln, done)
+
+ c, err := Dial(tt.net, ln.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial failed: %v", err)
+ }
+ defer c.Close()
+ c.LocalAddr()
+ c.RemoteAddr()
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+ if _, err := c.Write([]byte("CONN TEST")); err != nil {
+ t.Fatalf("Conn.Write failed: %v", err)
+ }
+ rb := make([]byte, 128)
+ if _, err := c.Read(rb); err != nil {
+ t.Fatalf("Conn.Read failed: %v", err)
+ }
+
+ <-done
+ }
+}
+
+func transponder(t *testing.T, ln Listener, done chan<- int) {
+ defer func() { done <- 1 }()
+
+ switch ln := ln.(type) {
+ case *TCPListener:
+ ln.SetDeadline(time.Now().Add(someTimeout))
+ case *UnixListener:
+ ln.SetDeadline(time.Now().Add(someTimeout))
+ }
+ c, err := ln.Accept()
+ if err != nil {
+ t.Errorf("Listener.Accept failed: %v", err)
+ return
+ }
+ defer c.Close()
+ c.LocalAddr()
+ c.RemoteAddr()
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+ b := make([]byte, 128)
+ n, err := c.Read(b)
+ if err != nil {
+ t.Errorf("Conn.Read failed: %v", err)
+ return
+ }
+ if _, err := c.Write(b[:n]); err != nil {
+ t.Errorf("Conn.Write failed: %v", err)
+ return
+ }
+}
diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go
index 10ca5faf7..22e1e7dd8 100644
--- a/src/pkg/net/dial.go
+++ b/src/pkg/net/dial.go
@@ -5,15 +5,91 @@
package net
import (
+ "errors"
"time"
)
-func parseDialNetwork(net string) (afnet string, proto int, err error) {
+// A DialOption modifies a DialOpt call.
+type DialOption interface {
+ dialOption()
+}
+
+var (
+ // TCP is a dial option to dial with TCP (over IPv4 or IPv6).
+ TCP = Network("tcp")
+
+ // UDP is a dial option to dial with UDP (over IPv4 or IPv6).
+ UDP = Network("udp")
+)
+
+// Network returns a DialOption to dial using the given network.
+//
+// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only),
+// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4"
+// (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and
+// "unixpacket".
+//
+// For IP networks, net must be "ip", "ip4" or "ip6" followed
+// by a colon and a protocol number or name, such as
+// "ipv4:1" or "ip6:ospf".
+func Network(net string) DialOption {
+ return dialNetwork(net)
+}
+
+type dialNetwork string
+
+func (dialNetwork) dialOption() {}
+
+// Deadline returns a DialOption to fail a dial that doesn't
+// complete before t.
+func Deadline(t time.Time) DialOption {
+ return dialDeadline(t)
+}
+
+// Timeout returns a DialOption to fail a dial that doesn't
+// complete within the provided duration.
+func Timeout(d time.Duration) DialOption {
+ return dialDeadline(time.Now().Add(d))
+}
+
+type dialDeadline time.Time
+
+func (dialDeadline) dialOption() {}
+
+type tcpFastOpen struct{}
+
+func (tcpFastOpen) dialOption() {}
+
+// TODO(bradfitz): implement this (golang.org/issue/4842) and unexport this.
+//
+// TCPFastTimeout returns an option to use TCP Fast Open (TFO) when
+// doing this dial. It is only valid for use with TCP connections.
+// Data sent over a TFO connection may be processed by the peer
+// multiple times, so should be used with caution.
+func todo_TCPFastTimeout() DialOption {
+ return tcpFastOpen{}
+}
+
+type localAddrOption struct {
+ la Addr
+}
+
+func (localAddrOption) dialOption() {}
+
+// LocalAddress returns a dial option to perform a dial with the
+// provided local address. The address must be of a compatible type
+// for the network being dialed.
+func LocalAddress(addr Addr) DialOption {
+ return localAddrOption{addr}
+}
+
+func parseNetwork(net string) (afnet string, proto int, err error) {
i := last(net, ':')
if i < 0 { // no colon
switch net {
case "tcp", "tcp4", "tcp6":
case "udp", "udp4", "udp6":
+ case "ip", "ip4", "ip6":
case "unix", "unixgram", "unixpacket":
default:
return "", 0, UnknownNetworkError(net)
@@ -36,40 +112,27 @@ func parseDialNetwork(net string) (afnet string, proto int, err error) {
return "", 0, UnknownNetworkError(net)
}
-func resolveNetAddr(op, net, addr string) (afnet string, a Addr, err error) {
- afnet, _, err = parseDialNetwork(net)
+func resolveAddr(op, net, addr string, deadline time.Time) (Addr, error) {
+ afnet, _, err := parseNetwork(net)
if err != nil {
- return "", nil, &OpError{op, net, nil, err}
+ return nil, &OpError{op, net, nil, err}
}
if op == "dial" && addr == "" {
- return "", nil, &OpError{op, net, nil, errMissingAddress}
+ return nil, &OpError{op, net, nil, errMissingAddress}
}
switch afnet {
- case "tcp", "tcp4", "tcp6":
- if addr != "" {
- a, err = ResolveTCPAddr(afnet, addr)
- }
- case "udp", "udp4", "udp6":
- if addr != "" {
- a, err = ResolveUDPAddr(afnet, addr)
- }
- case "ip", "ip4", "ip6":
- if addr != "" {
- a, err = ResolveIPAddr(afnet, addr)
- }
case "unix", "unixgram", "unixpacket":
- if addr != "" {
- a, err = ResolveUnixAddr(afnet, addr)
- }
+ return ResolveUnixAddr(afnet, addr)
}
- return
+ return resolveInternetAddr(afnet, addr, deadline)
}
// Dial connects to the address addr on the network net.
//
// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only),
// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4"
-// (IPv4-only), "ip6" (IPv6-only), "unix" and "unixpacket".
+// (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and
+// "unixpacket".
//
// For TCP and UDP networks, addresses have the form host:port.
// If host is a literal IPv6 address, it must be enclosed
@@ -81,7 +144,7 @@ func resolveNetAddr(op, net, addr string) (afnet string, a Addr, err error) {
// Dial("tcp", "google.com:80")
// Dial("tcp", "[de:ad:be:ef::ca:fe]:80")
//
-// For IP networks, addr must be "ip", "ip4" or "ip6" followed
+// For IP networks, net must be "ip", "ip4" or "ip6" followed
// by a colon and a protocol number or name.
//
// Examples:
@@ -89,25 +152,71 @@ func resolveNetAddr(op, net, addr string) (afnet string, a Addr, err error) {
// Dial("ip6:ospf", "::1")
//
func Dial(net, addr string) (Conn, error) {
- _, addri, err := resolveNetAddr("dial", net, addr)
+ return DialOpt(addr, dialNetwork(net))
+}
+
+func netFromOptions(opts []DialOption) string {
+ for _, opt := range opts {
+ if p, ok := opt.(dialNetwork); ok {
+ return string(p)
+ }
+ }
+ return "tcp"
+}
+
+func deadlineFromOptions(opts []DialOption) time.Time {
+ for _, opt := range opts {
+ if d, ok := opt.(dialDeadline); ok {
+ return time.Time(d)
+ }
+ }
+ return noDeadline
+}
+
+var noLocalAddr Addr // nil
+
+func localAddrFromOptions(opts []DialOption) Addr {
+ for _, opt := range opts {
+ if o, ok := opt.(localAddrOption); ok {
+ return o.la
+ }
+ }
+ return noLocalAddr
+}
+
+// DialOpt dials addr using the provided options.
+// If no options are provided, DialOpt(addr) is equivalent
+// to Dial("tcp", addr). See Dial for the syntax of addr.
+func DialOpt(addr string, opts ...DialOption) (Conn, error) {
+ net := netFromOptions(opts)
+ deadline := deadlineFromOptions(opts)
+ la := localAddrFromOptions(opts)
+ ra, err := resolveAddr("dial", net, addr, deadline)
if err != nil {
return nil, err
}
- return dialAddr(net, addr, addri)
+ return dial(net, addr, la, ra, deadline)
}
-func dialAddr(net, addr string, addri Addr) (c Conn, err error) {
- switch ra := addri.(type) {
+func dial(net, addr string, la, ra Addr, deadline time.Time) (c Conn, err error) {
+ if la != nil && la.Network() != ra.Network() {
+ return nil, &OpError{"dial", net, ra, errors.New("mismatched local addr type " + la.Network())}
+ }
+ switch ra := ra.(type) {
case *TCPAddr:
- c, err = DialTCP(net, nil, ra)
+ la, _ := la.(*TCPAddr)
+ c, err = dialTCP(net, la, ra, deadline)
case *UDPAddr:
- c, err = DialUDP(net, nil, ra)
+ la, _ := la.(*UDPAddr)
+ c, err = dialUDP(net, la, ra, deadline)
case *IPAddr:
- c, err = DialIP(net, nil, ra)
+ la, _ := la.(*IPAddr)
+ c, err = dialIP(net, la, ra, deadline)
case *UnixAddr:
- c, err = DialUnix(net, nil, ra)
+ la, _ := la.(*UnixAddr)
+ c, err = dialUnix(net, la, ra, deadline)
default:
- err = &OpError{"dial", net + " " + addr, nil, UnknownNetworkError(net)}
+ err = &OpError{"dial", net + " " + addr, ra, UnknownNetworkError(net)}
}
if err != nil {
return nil, err
@@ -118,10 +227,14 @@ func dialAddr(net, addr string, addri Addr) (c Conn, err error) {
// DialTimeout acts like Dial but takes a timeout.
// The timeout includes name resolution, if required.
func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
- // TODO(bradfitz): the timeout should be pushed down into the
- // net package's event loop, so on timeout to dead hosts we
- // don't have a goroutine sticking around for the default of
- // ~3 minutes.
+ return dialTimeout(net, addr, timeout)
+}
+
+// dialTimeoutRace is the old implementation of DialTimeout, still used
+// on operating systems where the deadline hasn't been pushed down
+// into the pollserver.
+// TODO: fix this on plan9.
+func dialTimeoutRace(net, addr string, timeout time.Duration) (Conn, error) {
t := time.NewTimer(timeout)
defer t.Stop()
type pair struct {
@@ -131,30 +244,30 @@ func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
ch := make(chan pair, 1)
resolvedAddr := make(chan Addr, 1)
go func() {
- _, addri, err := resolveNetAddr("dial", net, addr)
+ ra, err := resolveAddr("dial", net, addr, noDeadline)
if err != nil {
ch <- pair{nil, err}
return
}
- resolvedAddr <- addri // in case we need it for OpError
- c, err := dialAddr(net, addr, addri)
+ resolvedAddr <- ra // in case we need it for OpError
+ c, err := dial(net, addr, noLocalAddr, ra, noDeadline)
ch <- pair{c, err}
}()
select {
case <-t.C:
// Try to use the real Addr in our OpError, if we resolved it
// before the timeout. Otherwise we just use stringAddr.
- var addri Addr
+ var ra Addr
select {
case a := <-resolvedAddr:
- addri = a
+ ra = a
default:
- addri = &stringAddr{net, addr}
+ ra = &stringAddr{net, addr}
}
err := &OpError{
Op: "dial",
Net: net,
- Addr: addri,
+ Addr: ra,
Err: &timeoutError{},
}
return nil, err
@@ -173,24 +286,16 @@ func (a stringAddr) String() string { return a.addr }
// Listen announces on the local network address laddr.
// The network string net must be a stream-oriented network:
-// "tcp", "tcp4", "tcp6", or "unix", or "unixpacket".
+// "tcp", "tcp4", "tcp6", "unix" or "unixpacket".
func Listen(net, laddr string) (Listener, error) {
- afnet, a, err := resolveNetAddr("listen", net, laddr)
+ la, err := resolveAddr("listen", net, laddr, noDeadline)
if err != nil {
return nil, err
}
- switch afnet {
- case "tcp", "tcp4", "tcp6":
- var la *TCPAddr
- if a != nil {
- la = a.(*TCPAddr)
- }
+ switch la := la.(type) {
+ case *TCPAddr:
return ListenTCP(net, la)
- case "unix", "unixpacket":
- var la *UnixAddr
- if a != nil {
- la = a.(*UnixAddr)
- }
+ case *UnixAddr:
return ListenUnix(net, la)
}
return nil, UnknownNetworkError(net)
@@ -199,30 +304,18 @@ func Listen(net, laddr string) (Listener, error) {
// ListenPacket announces on the local network address laddr.
// The network string net must be a packet-oriented network:
// "udp", "udp4", "udp6", "ip", "ip4", "ip6" or "unixgram".
-func ListenPacket(net, addr string) (PacketConn, error) {
- afnet, a, err := resolveNetAddr("listen", net, addr)
+func ListenPacket(net, laddr string) (PacketConn, error) {
+ la, err := resolveAddr("listen", net, laddr, noDeadline)
if err != nil {
return nil, err
}
- switch afnet {
- case "udp", "udp4", "udp6":
- var la *UDPAddr
- if a != nil {
- la = a.(*UDPAddr)
- }
+ switch la := la.(type) {
+ case *UDPAddr:
return ListenUDP(net, la)
- case "ip", "ip4", "ip6":
- var la *IPAddr
- if a != nil {
- la = a.(*IPAddr)
- }
+ case *IPAddr:
return ListenIP(net, la)
- case "unixgram":
- var la *UnixAddr
- if a != nil {
- la = a.(*UnixAddr)
- }
- return DialUnix(net, la, nil)
+ case *UnixAddr:
+ return ListenUnixgram(net, la)
}
return nil, UnknownNetworkError(net)
}
diff --git a/src/pkg/net/dial_test.go b/src/pkg/net/dial_test.go
index 7212087fe..2303e8fa4 100644
--- a/src/pkg/net/dial_test.go
+++ b/src/pkg/net/dial_test.go
@@ -7,6 +7,9 @@ package net
import (
"flag"
"fmt"
+ "io"
+ "os"
+ "reflect"
"regexp"
"runtime"
"testing"
@@ -55,7 +58,7 @@ func TestDialTimeout(t *testing.T) {
// on our 386 builder, this Dial succeeds, connecting
// to an IIS web server somewhere. The data center
// or VM or firewall must be stealing the TCP connection.
- //
+ //
// IANA Service Name and Transport Protocol Port Number Registry
// <http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xml>
go func() {
@@ -72,8 +75,7 @@ func TestDialTimeout(t *testing.T) {
// by default. FreeBSD likely works, but is untested.
// TODO(rsc):
// The timeout never happens on Windows. Why? Issue 3016.
- t.Logf("skipping test on %q; untested.", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q; untested.", runtime.GOOS)
}
connected := 0
@@ -105,8 +107,7 @@ func TestDialTimeout(t *testing.T) {
func TestSelfConnect(t *testing.T) {
if runtime.GOOS == "windows" {
// TODO(brainman): do not know why it hangs.
- t.Logf("skipping known-broken test on windows")
- return
+ t.Skip("skipping known-broken test on windows")
}
// Test that Dial does not honor self-connects.
// See the comment in DialTCP.
@@ -130,7 +131,7 @@ func TestSelfConnect(t *testing.T) {
n = 1000
}
switch runtime.GOOS {
- case "darwin", "freebsd", "openbsd", "windows":
+ case "darwin", "freebsd", "netbsd", "openbsd", "plan9", "windows":
// Non-Linux systems take a long time to figure
// out that there is nothing listening on localhost.
n = 100
@@ -222,3 +223,104 @@ func TestDialError(t *testing.T) {
}
}
}
+
+var invalidDialAndListenArgTests = []struct {
+ net string
+ addr string
+ err error
+}{
+ {"foo", "bar", &OpError{Op: "dial", Net: "foo", Addr: nil, Err: UnknownNetworkError("foo")}},
+ {"baz", "", &OpError{Op: "listen", Net: "baz", Addr: nil, Err: UnknownNetworkError("baz")}},
+ {"tcp", "", &OpError{Op: "dial", Net: "tcp", Addr: nil, Err: errMissingAddress}},
+}
+
+func TestInvalidDialAndListenArgs(t *testing.T) {
+ for _, tt := range invalidDialAndListenArgTests {
+ var err error
+ switch tt.err.(*OpError).Op {
+ case "dial":
+ _, err = Dial(tt.net, tt.addr)
+ case "listen":
+ _, err = Listen(tt.net, tt.addr)
+ }
+ if !reflect.DeepEqual(tt.err, err) {
+ t.Fatalf("got %#v; expected %#v", err, tt.err)
+ }
+ }
+}
+
+func TestDialTimeoutFDLeak(t *testing.T) {
+ if runtime.GOOS != "linux" {
+ // TODO(bradfitz): test on other platforms
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ type connErr struct {
+ conn Conn
+ err error
+ }
+ dials := listenerBacklog + 100
+ // used to be listenerBacklog + 5, but was found to be unreliable, issue 4384.
+ maxGoodConnect := listenerBacklog + runtime.NumCPU()*10
+ resc := make(chan connErr)
+ for i := 0; i < dials; i++ {
+ go func() {
+ conn, err := DialTimeout("tcp", ln.Addr().String(), 500*time.Millisecond)
+ resc <- connErr{conn, err}
+ }()
+ }
+
+ var firstErr string
+ var ngood int
+ var toClose []io.Closer
+ for i := 0; i < dials; i++ {
+ ce := <-resc
+ if ce.err == nil {
+ ngood++
+ if ngood > maxGoodConnect {
+ t.Errorf("%d good connects; expected at most %d", ngood, maxGoodConnect)
+ }
+ toClose = append(toClose, ce.conn)
+ continue
+ }
+ err := ce.err
+ if firstErr == "" {
+ firstErr = err.Error()
+ } else if err.Error() != firstErr {
+ t.Fatalf("inconsistent error messages: first was %q, then later %q", firstErr, err)
+ }
+ }
+ for _, c := range toClose {
+ c.Close()
+ }
+ for i := 0; i < 100; i++ {
+ if got := numFD(); got < dials {
+ // Test passes.
+ return
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ if got := numFD(); got >= dials {
+ t.Errorf("num fds after %d timeouts = %d; want <%d", dials, got, dials)
+ }
+}
+
+func numFD() int {
+ if runtime.GOOS == "linux" {
+ f, err := os.Open("/proc/self/fd")
+ if err != nil {
+ panic(err)
+ }
+ defer f.Close()
+ names, err := f.Readdirnames(0)
+ if err != nil {
+ panic(err)
+ }
+ return len(names)
+ }
+ // All tests using this should be skipped anyway, but:
+ panic("numFDs not implemented on " + runtime.GOOS)
+}
diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go
index 03c449972..73a94f5bf 100644
--- a/src/pkg/net/dialgoogle_test.go
+++ b/src/pkg/net/dialgoogle_test.go
@@ -41,17 +41,6 @@ func doDial(t *testing.T, network, addr string) {
fd.Close()
}
-func TestLookupCNAME(t *testing.T) {
- if testing.Short() || !*testExternal {
- t.Logf("skipping test to avoid external network")
- return
- }
- cname, err := LookupCNAME("www.google.com")
- if !strings.HasSuffix(cname, ".l.google.com.") || err != nil {
- t.Errorf(`LookupCNAME("www.google.com.") = %q, %v, want "*.l.google.com.", nil`, cname, err)
- }
-}
-
var googleaddrsipv4 = []string{
"%d.%d.%d.%d:80",
"www.google.com:80",
@@ -67,8 +56,7 @@ var googleaddrsipv4 = []string{
func TestDialGoogleIPv4(t *testing.T) {
if testing.Short() || !*testExternal {
- t.Logf("skipping test to avoid external network")
- return
+ t.Skip("skipping test to avoid external network")
}
// Insert an actual IPv4 address for google.com
@@ -123,12 +111,14 @@ var googleaddrsipv6 = []string{
func TestDialGoogleIPv6(t *testing.T) {
if testing.Short() || !*testExternal {
- t.Logf("skipping test to avoid external network")
- return
+ t.Skip("skipping test to avoid external network")
}
// Only run tcp6 if the kernel will take it.
- if !*testIPv6 || !supportsIPv6 {
- return
+ if !supportsIPv6 {
+ t.Skip("skipping test; ipv6 is not supported")
+ }
+ if !*testIPv6 {
+ t.Skip("test disabled; use -ipv6 to enable")
}
// Insert an actual IPv6 address for ipv6.google.com
diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go
index e69cb3188..76b192645 100644
--- a/src/pkg/net/dnsclient.go
+++ b/src/pkg/net/dnsclient.go
@@ -183,7 +183,7 @@ func (s byPriorityWeight) Less(i, j int) bool {
}
// shuffleByWeight shuffles SRV records by weight using the algorithm
-// described in RFC 2782.
+// described in RFC 2782.
func (addrs byPriorityWeight) shuffleByWeight() {
sum := 0
for _, addr := range addrs {
@@ -244,3 +244,8 @@ func (s byPref) sort() {
}
sort.Sort(s)
}
+
+// An NS represents a single DNS NS record.
+type NS struct {
+ Host string
+}
diff --git a/src/pkg/net/dnsclient_unix.go b/src/pkg/net/dnsclient_unix.go
index 18c39360e..9e21bb4a0 100644
--- a/src/pkg/net/dnsclient_unix.go
+++ b/src/pkg/net/dnsclient_unix.go
@@ -237,24 +237,30 @@ func goLookupIP(name string) (addrs []IP, err error) {
}
var records []dnsRR
var cname string
- cname, records, err = lookup(name, dnsTypeA)
- if err != nil {
- return
- }
+ var err4, err6 error
+ cname, records, err4 = lookup(name, dnsTypeA)
addrs = convertRR_A(records)
if cname != "" {
name = cname
}
- _, records, err = lookup(name, dnsTypeAAAA)
- if err != nil && len(addrs) > 0 {
- // Ignore error because A lookup succeeded.
- err = nil
+ _, records, err6 = lookup(name, dnsTypeAAAA)
+ if err4 != nil && err6 == nil {
+ // Ignore A error because AAAA lookup succeeded.
+ err4 = nil
}
- if err != nil {
- return
+ if err6 != nil && len(addrs) > 0 {
+ // Ignore AAAA error because A lookup succeeded.
+ err6 = nil
}
+ if err4 != nil {
+ return nil, err4
+ }
+ if err6 != nil {
+ return nil, err6
+ }
+
addrs = append(addrs, convertRR_AAAA(records)...)
- return
+ return addrs, nil
}
// goLookupCNAME is the native Go implementation of LookupCNAME.
diff --git a/src/pkg/net/dnsconfig.go b/src/pkg/net/dnsconfig_unix.go
index bb46cc900..bb46cc900 100644
--- a/src/pkg/net/dnsconfig.go
+++ b/src/pkg/net/dnsconfig_unix.go
diff --git a/src/pkg/net/dnsmsg.go b/src/pkg/net/dnsmsg.go
index b6ebe1173..161afb2a5 100644
--- a/src/pkg/net/dnsmsg.go
+++ b/src/pkg/net/dnsmsg.go
@@ -618,7 +618,7 @@ func printStruct(any dnsStruct) string {
s += name + "="
switch tag {
case "ipv4":
- i := val.(uint32)
+ i := *val.(*uint32)
s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
case "ipv6":
i := val.([]byte)
diff --git a/src/pkg/net/example_test.go b/src/pkg/net/example_test.go
index 1a1c2edfe..eefe84fa7 100644
--- a/src/pkg/net/example_test.go
+++ b/src/pkg/net/example_test.go
@@ -17,7 +17,7 @@ func ExampleListener() {
log.Fatal(err)
}
for {
- // Wait for a connection.
+ // Wait for a connection.
conn, err := l.Accept()
if err != nil {
log.Fatal(err)
diff --git a/src/pkg/net/fd_openbsd.go b/src/pkg/net/fd_bsd.go
index 35d84c30e..8bb1ae538 100644
--- a/src/pkg/net/fd_openbsd.go
+++ b/src/pkg/net/fd_bsd.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build freebsd netbsd openbsd
+
// Waiting for FDs via kqueue/kevent.
package net
@@ -31,6 +33,8 @@ func newpollster() (p *pollster, err error) {
return p, nil
}
+// First return value is whether the pollServer should be woken up.
+// This version always returns false.
func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
// pollServer is locked.
@@ -62,7 +66,9 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
return false, nil
}
-func (p *pollster) DelFD(fd int, mode int) {
+// Return value is whether the pollServer should be woken up.
+// This version always returns false.
+func (p *pollster) DelFD(fd int, mode int) bool {
// pollServer is locked.
var kmode int
@@ -75,6 +81,7 @@ func (p *pollster) DelFD(fd int, mode int) {
// EV_DELETE - delete event from kqueue list
syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE)
syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
+ return false
}
func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) {
diff --git a/src/pkg/net/fd_darwin.go b/src/pkg/net/fd_darwin.go
index 3dd33edc2..382465ba6 100644
--- a/src/pkg/net/fd_darwin.go
+++ b/src/pkg/net/fd_darwin.go
@@ -32,6 +32,8 @@ func newpollster() (p *pollster, err error) {
return p, nil
}
+// First return value is whether the pollServer should be woken up.
+// This version always returns false.
func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
// pollServer is locked.
@@ -65,7 +67,9 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
return false, nil
}
-func (p *pollster) DelFD(fd int, mode int) {
+// Return value is whether the pollServer should be woken up.
+// This version always returns false.
+func (p *pollster) DelFD(fd int, mode int) bool {
// pollServer is locked.
var kmode int
@@ -80,6 +84,7 @@ func (p *pollster) DelFD(fd int, mode int) {
// rather than waiting for real event
syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE|syscall.EV_RECEIPT)
syscall.Kevent(p.kq, p.kbuf[0:], p.kbuf[0:], nil)
+ return false
}
func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) {
diff --git a/src/pkg/net/fd_freebsd.go b/src/pkg/net/fd_freebsd.go
deleted file mode 100644
index 35d84c30e..000000000
--- a/src/pkg/net/fd_freebsd.go
+++ /dev/null
@@ -1,116 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Waiting for FDs via kqueue/kevent.
-
-package net
-
-import (
- "os"
- "syscall"
-)
-
-type pollster struct {
- kq int
- eventbuf [10]syscall.Kevent_t
- events []syscall.Kevent_t
-
- // An event buffer for AddFD/DelFD.
- // Must hold pollServer lock.
- kbuf [1]syscall.Kevent_t
-}
-
-func newpollster() (p *pollster, err error) {
- p = new(pollster)
- if p.kq, err = syscall.Kqueue(); err != nil {
- return nil, os.NewSyscallError("kqueue", err)
- }
- syscall.CloseOnExec(p.kq)
- p.events = p.eventbuf[0:0]
- return p, nil
-}
-
-func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
- // pollServer is locked.
-
- var kmode int
- if mode == 'r' {
- kmode = syscall.EVFILT_READ
- } else {
- kmode = syscall.EVFILT_WRITE
- }
- ev := &p.kbuf[0]
- // EV_ADD - add event to kqueue list
- // EV_ONESHOT - delete the event the first time it triggers
- flags := syscall.EV_ADD
- if !repeat {
- flags |= syscall.EV_ONESHOT
- }
- syscall.SetKevent(ev, fd, kmode, flags)
-
- n, err := syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
- if err != nil {
- return false, os.NewSyscallError("kevent", err)
- }
- if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode {
- return false, os.NewSyscallError("kqueue phase error", err)
- }
- if ev.Data != 0 {
- return false, syscall.Errno(int(ev.Data))
- }
- return false, nil
-}
-
-func (p *pollster) DelFD(fd int, mode int) {
- // pollServer is locked.
-
- var kmode int
- if mode == 'r' {
- kmode = syscall.EVFILT_READ
- } else {
- kmode = syscall.EVFILT_WRITE
- }
- ev := &p.kbuf[0]
- // EV_DELETE - delete event from kqueue list
- syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE)
- syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
-}
-
-func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) {
- var t *syscall.Timespec
- for len(p.events) == 0 {
- if nsec > 0 {
- if t == nil {
- t = new(syscall.Timespec)
- }
- *t = syscall.NsecToTimespec(nsec)
- }
-
- s.Unlock()
- n, err := syscall.Kevent(p.kq, nil, p.eventbuf[:], t)
- s.Lock()
-
- if err != nil {
- if err == syscall.EINTR {
- continue
- }
- return -1, 0, os.NewSyscallError("kevent", err)
- }
- if n == 0 {
- return -1, 0, nil
- }
- p.events = p.eventbuf[:n]
- }
- ev := &p.events[0]
- p.events = p.events[1:]
- fd = int(ev.Ident)
- if ev.Filter == syscall.EVFILT_READ {
- mode = 'r'
- } else {
- mode = 'w'
- }
- return fd, mode, nil
-}
-
-func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) }
diff --git a/src/pkg/net/fd_linux.go b/src/pkg/net/fd_linux.go
index 085e42307..03679196d 100644
--- a/src/pkg/net/fd_linux.go
+++ b/src/pkg/net/fd_linux.go
@@ -51,6 +51,8 @@ func newpollster() (p *pollster, err error) {
return p, nil
}
+// First return value is whether the pollServer should be woken up.
+// This version always returns false.
func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
// pollServer is locked.
@@ -114,7 +116,9 @@ func (p *pollster) StopWaiting(fd int, bits uint) {
}
}
-func (p *pollster) DelFD(fd int, mode int) {
+// Return value is whether the pollServer should be woken up.
+// This version always returns false.
+func (p *pollster) DelFD(fd int, mode int) bool {
// pollServer is locked.
if mode == 'r' {
@@ -133,6 +137,7 @@ func (p *pollster) DelFD(fd int, mode int) {
i++
}
}
+ return false
}
func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) {
diff --git a/src/pkg/net/fd_netbsd.go b/src/pkg/net/fd_netbsd.go
deleted file mode 100644
index 35d84c30e..000000000
--- a/src/pkg/net/fd_netbsd.go
+++ /dev/null
@@ -1,116 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Waiting for FDs via kqueue/kevent.
-
-package net
-
-import (
- "os"
- "syscall"
-)
-
-type pollster struct {
- kq int
- eventbuf [10]syscall.Kevent_t
- events []syscall.Kevent_t
-
- // An event buffer for AddFD/DelFD.
- // Must hold pollServer lock.
- kbuf [1]syscall.Kevent_t
-}
-
-func newpollster() (p *pollster, err error) {
- p = new(pollster)
- if p.kq, err = syscall.Kqueue(); err != nil {
- return nil, os.NewSyscallError("kqueue", err)
- }
- syscall.CloseOnExec(p.kq)
- p.events = p.eventbuf[0:0]
- return p, nil
-}
-
-func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
- // pollServer is locked.
-
- var kmode int
- if mode == 'r' {
- kmode = syscall.EVFILT_READ
- } else {
- kmode = syscall.EVFILT_WRITE
- }
- ev := &p.kbuf[0]
- // EV_ADD - add event to kqueue list
- // EV_ONESHOT - delete the event the first time it triggers
- flags := syscall.EV_ADD
- if !repeat {
- flags |= syscall.EV_ONESHOT
- }
- syscall.SetKevent(ev, fd, kmode, flags)
-
- n, err := syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
- if err != nil {
- return false, os.NewSyscallError("kevent", err)
- }
- if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode {
- return false, os.NewSyscallError("kqueue phase error", err)
- }
- if ev.Data != 0 {
- return false, syscall.Errno(int(ev.Data))
- }
- return false, nil
-}
-
-func (p *pollster) DelFD(fd int, mode int) {
- // pollServer is locked.
-
- var kmode int
- if mode == 'r' {
- kmode = syscall.EVFILT_READ
- } else {
- kmode = syscall.EVFILT_WRITE
- }
- ev := &p.kbuf[0]
- // EV_DELETE - delete event from kqueue list
- syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE)
- syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
-}
-
-func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) {
- var t *syscall.Timespec
- for len(p.events) == 0 {
- if nsec > 0 {
- if t == nil {
- t = new(syscall.Timespec)
- }
- *t = syscall.NsecToTimespec(nsec)
- }
-
- s.Unlock()
- n, err := syscall.Kevent(p.kq, nil, p.eventbuf[:], t)
- s.Lock()
-
- if err != nil {
- if err == syscall.EINTR {
- continue
- }
- return -1, 0, os.NewSyscallError("kevent", err)
- }
- if n == 0 {
- return -1, 0, nil
- }
- p.events = p.eventbuf[:n]
- }
- ev := &p.events[0]
- p.events = p.events[1:]
- fd = int(ev.Ident)
- if ev.Filter == syscall.EVFILT_READ {
- mode = 'r'
- } else {
- mode = 'w'
- }
- return fd, mode, nil
-}
-
-func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) }
diff --git a/src/pkg/net/fd_plan9.go b/src/pkg/net/fd_plan9.go
new file mode 100644
index 000000000..169087999
--- /dev/null
+++ b/src/pkg/net/fd_plan9.go
@@ -0,0 +1,129 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "io"
+ "os"
+ "syscall"
+ "time"
+)
+
+// Network file descritor.
+type netFD struct {
+ proto, name, dir string
+ ctl, data *os.File
+ laddr, raddr Addr
+}
+
+var canCancelIO = true // used for testing current package
+
+func sysInit() {
+}
+
+func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+ // On plan9, use the relatively inefficient
+ // goroutine-racing implementation.
+ return dialTimeoutRace(net, addr, timeout)
+}
+
+func newFD(proto, name string, ctl, data *os.File, laddr, raddr Addr) *netFD {
+ return &netFD{proto, name, "/net/" + proto + "/" + name, ctl, data, laddr, raddr}
+}
+
+func (fd *netFD) ok() bool { return fd != nil && fd.ctl != nil }
+
+func (fd *netFD) Read(b []byte) (n int, err error) {
+ if !fd.ok() || fd.data == nil {
+ return 0, syscall.EINVAL
+ }
+ n, err = fd.data.Read(b)
+ if fd.proto == "udp" && err == io.EOF {
+ n = 0
+ err = nil
+ }
+ return
+}
+
+func (fd *netFD) Write(b []byte) (n int, err error) {
+ if !fd.ok() || fd.data == nil {
+ return 0, syscall.EINVAL
+ }
+ return fd.data.Write(b)
+}
+
+func (fd *netFD) CloseRead() error {
+ if !fd.ok() {
+ return syscall.EINVAL
+ }
+ return syscall.EPLAN9
+}
+
+func (fd *netFD) CloseWrite() error {
+ if !fd.ok() {
+ return syscall.EINVAL
+ }
+ return syscall.EPLAN9
+}
+
+func (fd *netFD) Close() error {
+ if !fd.ok() {
+ return syscall.EINVAL
+ }
+ err := fd.ctl.Close()
+ if fd.data != nil {
+ if err1 := fd.data.Close(); err1 != nil && err == nil {
+ err = err1
+ }
+ }
+ fd.ctl = nil
+ fd.data = nil
+ return err
+}
+
+// This method is only called via Conn.
+func (fd *netFD) dup() (*os.File, error) {
+ if !fd.ok() || fd.data == nil {
+ return nil, syscall.EINVAL
+ }
+ return fd.file(fd.data, fd.dir+"/data")
+}
+
+func (l *TCPListener) dup() (*os.File, error) {
+ if !l.fd.ok() {
+ return nil, syscall.EINVAL
+ }
+ return l.fd.file(l.fd.ctl, l.fd.dir+"/ctl")
+}
+
+func (fd *netFD) file(f *os.File, s string) (*os.File, error) {
+ syscall.ForkLock.RLock()
+ dfd, err := syscall.Dup(int(f.Fd()), -1)
+ syscall.ForkLock.RUnlock()
+ if err != nil {
+ return nil, &OpError{"dup", s, fd.laddr, err}
+ }
+ return os.NewFile(uintptr(dfd), s), nil
+}
+
+func setDeadline(fd *netFD, t time.Time) error {
+ return syscall.EPLAN9
+}
+
+func setReadDeadline(fd *netFD, t time.Time) error {
+ return syscall.EPLAN9
+}
+
+func setWriteDeadline(fd *netFD, t time.Time) error {
+ return syscall.EPLAN9
+}
+
+func setReadBuffer(fd *netFD, bytes int) error {
+ return syscall.EPLAN9
+}
+
+func setWriteBuffer(fd *netFD, bytes int) error {
+ return syscall.EPLAN9
+}
diff --git a/src/pkg/net/fd_posix_test.go b/src/pkg/net/fd_posix_test.go
new file mode 100644
index 000000000..8be0335d6
--- /dev/null
+++ b/src/pkg/net/fd_posix_test.go
@@ -0,0 +1,57 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd windows
+
+package net
+
+import (
+ "testing"
+ "time"
+)
+
+var deadlineSetTimeTests = []struct {
+ input time.Time
+ expected int64
+}{
+ {time.Time{}, 0},
+ {time.Date(2009, 11, 10, 23, 00, 00, 00, time.UTC), 1257894000000000000}, // 2009-11-10 23:00:00 +0000 UTC
+}
+
+func TestDeadlineSetTime(t *testing.T) {
+ for _, tt := range deadlineSetTimeTests {
+ var d deadline
+ d.setTime(tt.input)
+ actual := d.value()
+ expected := int64(0)
+ if !tt.input.IsZero() {
+ expected = tt.input.UnixNano()
+ }
+ if actual != expected {
+ t.Errorf("set/value failed: expected %v, actual %v", expected, actual)
+ }
+ }
+}
+
+var deadlineExpiredTests = []struct {
+ deadline time.Time
+ expired bool
+}{
+ // note, times are relative to the start of the test run, not
+ // the start of TestDeadlineExpired
+ {time.Now().Add(5 * time.Minute), false},
+ {time.Now().Add(-5 * time.Minute), true},
+ {time.Time{}, false}, // no deadline set
+}
+
+func TestDeadlineExpired(t *testing.T) {
+ for _, tt := range deadlineExpiredTests {
+ var d deadline
+ d.set(tt.deadline.UnixNano())
+ expired := d.expired()
+ if expired != tt.expired {
+ t.Errorf("expire failed: expected %v, actual %v", tt.expired, expired)
+ }
+ }
+}
diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd_unix.go
index 76c953b9b..0540df825 100644
--- a/src/pkg/net/fd.go
+++ b/src/pkg/net/fd_unix.go
@@ -7,9 +7,9 @@
package net
import (
- "errors"
"io"
"os"
+ "runtime"
"sync"
"syscall"
"time"
@@ -37,44 +37,24 @@ type netFD struct {
laddr Addr
raddr Addr
- // owned by client
- rdeadline int64
- rio sync.Mutex
- wdeadline int64
- wio sync.Mutex
+ // serialize access to Read and Write methods
+ rio, wio sync.Mutex
+
+ // read and write deadlines
+ rdeadline, wdeadline deadline
// owned by fd wait server
ncr, ncw int
+
+ // wait server
+ pollServer *pollServer
}
// A pollServer helps FDs determine when to retry a non-blocking
// read or write after they get EAGAIN. When an FD needs to wait,
-// send the fd on s.cr (for a read) or s.cw (for a write) to pass the
-// request to the poll server. Then receive on fd.cr/fd.cw.
+// call s.WaitRead() or s.WaitWrite() to pass the request to the poll server.
// When the pollServer finds that i/o on FD should be possible
-// again, it will send fd on fd.cr/fd.cw to wake any waiting processes.
-// This protocol is implemented as s.WaitRead() and s.WaitWrite().
-//
-// There is one subtlety: when sending on s.cr/s.cw, the
-// poll server is probably in a system call, waiting for an fd
-// to become ready. It's not looking at the request channels.
-// To resolve this, the poll server waits not just on the FDs it has
-// been given but also its own pipe. After sending on the
-// buffered channel s.cr/s.cw, WaitRead/WaitWrite writes a
-// byte to the pipe, causing the pollServer's poll system call to
-// return. In response to the pipe being readable, the pollServer
-// re-polls its request channels.
-//
-// Note that the ordering is "send request" and then "wake up server".
-// If the operations were reversed, there would be a race: the poll
-// server might wake up and look at the request channel, see that it
-// was empty, and go back to sleep, all before the requester managed
-// to send the request. Because the send must complete before the wakeup,
-// the request channel must be buffered. A buffer of size 1 is sufficient
-// for any request load. If many processes are trying to submit requests,
-// one will succeed, the pollServer will read the request, and then the
-// channel will be empty for the next process's request. A larger buffer
-// might help batch requests.
+// again, it will send on fd.cr/fd.cw to wake any waiting goroutines.
//
// To avoid races in closing, all fd operations are locked and
// refcounted. when netFD.Close() is called, it calls syscall.Shutdown
@@ -82,7 +62,6 @@ type netFD struct {
// will the fd be closed.
type pollServer struct {
- cr, cw chan *netFD // buffered >= 1
pr, pw *os.File
poll *pollster // low-level OS hooks
sync.Mutex // controls pending and deadline
@@ -103,11 +82,11 @@ func (s *pollServer) AddFD(fd *netFD, mode int) error {
key := intfd << 1
if mode == 'r' {
fd.ncr++
- t = fd.rdeadline
+ t = fd.rdeadline.value()
} else {
fd.ncw++
key++
- t = fd.wdeadline
+ t = fd.wdeadline.value()
}
s.pending[key] = fd
doWakeup := false
@@ -117,15 +96,11 @@ func (s *pollServer) AddFD(fd *netFD, mode int) error {
}
wake, err := s.poll.AddFD(intfd, mode, false)
+ s.Unlock()
if err != nil {
- panic("pollServer AddFD " + err.Error())
- }
- if wake {
- doWakeup = true
+ return &OpError{"addfd", fd.net, fd.laddr, err}
}
- s.Unlock()
-
- if doWakeup {
+ if wake || doWakeup {
s.Wakeup()
}
return nil
@@ -134,17 +109,24 @@ func (s *pollServer) AddFD(fd *netFD, mode int) error {
// Evict evicts fd from the pending list, unblocking
// any I/O running on fd. The caller must have locked
// pollserver.
-func (s *pollServer) Evict(fd *netFD) {
+// Return value is whether the pollServer should be woken up.
+func (s *pollServer) Evict(fd *netFD) bool {
+ doWakeup := false
if s.pending[fd.sysfd<<1] == fd {
s.WakeFD(fd, 'r', errClosing)
- s.poll.DelFD(fd.sysfd, 'r')
+ if s.poll.DelFD(fd.sysfd, 'r') {
+ doWakeup = true
+ }
delete(s.pending, fd.sysfd<<1)
}
if s.pending[fd.sysfd<<1|1] == fd {
s.WakeFD(fd, 'w', errClosing)
- s.poll.DelFD(fd.sysfd, 'w')
+ if s.poll.DelFD(fd.sysfd, 'w') {
+ doWakeup = true
+ }
delete(s.pending, fd.sysfd<<1|1)
}
+ return doWakeup
}
var wakeupbuf [1]byte
@@ -178,16 +160,12 @@ func (s *pollServer) WakeFD(fd *netFD, mode int, err error) {
}
}
-func (s *pollServer) Now() int64 {
- return time.Now().UnixNano()
-}
-
func (s *pollServer) CheckDeadlines() {
- now := s.Now()
+ now := time.Now().UnixNano()
// TODO(rsc): This will need to be handled more efficiently,
// probably with a heap indexed by wakeup time.
- var next_deadline int64
+ var nextDeadline int64
for key, fd := range s.pending {
var t int64
var mode int
@@ -197,27 +175,21 @@ func (s *pollServer) CheckDeadlines() {
mode = 'w'
}
if mode == 'r' {
- t = fd.rdeadline
+ t = fd.rdeadline.value()
} else {
- t = fd.wdeadline
+ t = fd.wdeadline.value()
}
if t > 0 {
if t <= now {
delete(s.pending, key)
- if mode == 'r' {
- s.poll.DelFD(fd.sysfd, mode)
- fd.rdeadline = -1
- } else {
- s.poll.DelFD(fd.sysfd, mode)
- fd.wdeadline = -1
- }
- s.WakeFD(fd, mode, nil)
- } else if next_deadline == 0 || t < next_deadline {
- next_deadline = t
+ s.poll.DelFD(fd.sysfd, mode)
+ s.WakeFD(fd, mode, errTimeout)
+ } else if nextDeadline == 0 || t < nextDeadline {
+ nextDeadline = t
}
}
}
- s.deadline = next_deadline
+ s.deadline = nextDeadline
}
func (s *pollServer) Run() {
@@ -225,15 +197,15 @@ func (s *pollServer) Run() {
s.Lock()
defer s.Unlock()
for {
- var t = s.deadline
- if t > 0 {
- t = t - s.Now()
- if t <= 0 {
+ var timeout int64 // nsec to wait for or 0 for none
+ if s.deadline > 0 {
+ timeout = s.deadline - time.Now().UnixNano()
+ if timeout <= 0 {
s.CheckDeadlines()
continue
}
}
- fd, mode, err := s.poll.WaitFD(s, t)
+ fd, mode, err := s.poll.WaitFD(s, timeout)
if err != nil {
print("pollServer WaitFD: ", err.Error(), "\n")
return
@@ -279,24 +251,56 @@ func (s *pollServer) WaitWrite(fd *netFD) error {
}
// Network FD methods.
-// All the network FDs use a single pollServer.
+// Spread network FDs over several pollServers.
+
+var pollMaxN int
+var pollservers []*pollServer
+var startServersOnce []func()
+
+var canCancelIO = true // used for testing current package
-var pollserver *pollServer
-var onceStartServer sync.Once
+func sysInit() {
+ pollMaxN = runtime.NumCPU()
+ if pollMaxN > 8 {
+ pollMaxN = 8 // No improvement then.
+ }
+ pollservers = make([]*pollServer, pollMaxN)
+ startServersOnce = make([]func(), pollMaxN)
+ for i := 0; i < pollMaxN; i++ {
+ k := i
+ once := new(sync.Once)
+ startServersOnce[i] = func() { once.Do(func() { startServer(k) }) }
+ }
+}
-func startServer() {
+func startServer(k int) {
p, err := newPollServer()
if err != nil {
- print("Start pollServer: ", err.Error(), "\n")
+ panic(err)
}
- pollserver = p
+ pollservers[k] = p
}
-func newFD(fd, family, sotype int, net string) (*netFD, error) {
- onceStartServer.Do(startServer)
- if err := syscall.SetNonblock(fd, true); err != nil {
+func server(fd int) *pollServer {
+ pollN := runtime.GOMAXPROCS(0)
+ if pollN > pollMaxN {
+ pollN = pollMaxN
+ }
+ k := fd % pollN
+ startServersOnce[k]()
+ return pollservers[k]
+}
+
+func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+ deadline := time.Now().Add(timeout)
+ ra, err := resolveAddr("dial", net, addr, deadline)
+ if err != nil {
return nil, err
}
+ return dial(net, addr, noLocalAddr, ra, deadline)
+}
+
+func newFD(fd, family, sotype int, net string) (*netFD, error) {
netfd := &netFD{
sysfd: fd,
family: family,
@@ -305,26 +309,31 @@ func newFD(fd, family, sotype int, net string) (*netFD, error) {
}
netfd.cr = make(chan error, 1)
netfd.cw = make(chan error, 1)
+ netfd.pollServer = server(fd)
return netfd, nil
}
func (fd *netFD) setAddr(laddr, raddr Addr) {
fd.laddr = laddr
fd.raddr = raddr
+ fd.sysfile = os.NewFile(uintptr(fd.sysfd), fd.net)
+}
+
+func (fd *netFD) name() string {
var ls, rs string
- if laddr != nil {
- ls = laddr.String()
+ if fd.laddr != nil {
+ ls = fd.laddr.String()
}
- if raddr != nil {
- rs = raddr.String()
+ if fd.raddr != nil {
+ rs = fd.raddr.String()
}
- fd.sysfile = os.NewFile(uintptr(fd.sysfd), fd.net+":"+ls+"->"+rs)
+ return fd.net + ":" + ls + "->" + rs
}
func (fd *netFD) connect(ra syscall.Sockaddr) error {
err := syscall.Connect(fd.sysfd, ra)
if err == syscall.EINPROGRESS {
- if err = pollserver.WaitWrite(fd); err != nil {
+ if err = fd.pollServer.WaitWrite(fd); err != nil {
return err
}
var e int
@@ -339,15 +348,10 @@ func (fd *netFD) connect(ra syscall.Sockaddr) error {
return err
}
-var errClosing = errors.New("use of closed network connection")
-
// Add a reference to this fd.
// If closing==true, pollserver must be locked; mark the fd as closing.
// Returns an error if the fd cannot be used.
func (fd *netFD) incref(closing bool) error {
- if fd == nil {
- return errClosing
- }
fd.sysmu.Lock()
if fd.closing {
fd.sysmu.Unlock()
@@ -364,9 +368,6 @@ func (fd *netFD) incref(closing bool) error {
// Remove a reference to this FD and close if we've been asked to do so (and
// there are no references left.
func (fd *netFD) decref() {
- if fd == nil {
- return
- }
fd.sysmu.Lock()
fd.sysref--
if fd.closing && fd.sysref == 0 && fd.sysfile != nil {
@@ -378,9 +379,9 @@ func (fd *netFD) decref() {
}
func (fd *netFD) Close() error {
- pollserver.Lock() // needed for both fd.incref(true) and pollserver.Evict
- defer pollserver.Unlock()
+ fd.pollServer.Lock() // needed for both fd.incref(true) and pollserver.Evict
if err := fd.incref(true); err != nil {
+ fd.pollServer.Unlock()
return err
}
// Unblock any I/O. Once it all unblocks and returns,
@@ -388,8 +389,12 @@ func (fd *netFD) Close() error {
// the final decref will close fd.sysfd. This should happen
// fairly quickly, since all the I/O is non-blocking, and any
// attempts to block in the pollserver will return errClosing.
- pollserver.Evict(fd)
+ doWakeup := fd.pollServer.Evict(fd)
+ fd.pollServer.Unlock()
fd.decref()
+ if doWakeup {
+ fd.pollServer.Wakeup()
+ }
return nil
}
@@ -421,20 +426,20 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
}
defer fd.decref()
for {
- n, err = syscall.Read(int(fd.sysfd), p)
- if err == syscall.EAGAIN {
+ if fd.rdeadline.expired() {
err = errTimeout
- if fd.rdeadline >= 0 {
- if err = pollserver.WaitRead(fd); err == nil {
- continue
- }
- }
+ break
}
+ n, err = syscall.Read(int(fd.sysfd), p)
if err != nil {
n = 0
- } else if n == 0 && err == nil && fd.sotype != syscall.SOCK_DGRAM {
- err = io.EOF
+ if err == syscall.EAGAIN {
+ if err = fd.pollServer.WaitRead(fd); err == nil {
+ continue
+ }
+ }
}
+ err = chkReadErr(n, err, fd)
break
}
if err != nil && err != io.EOF {
@@ -451,18 +456,20 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
}
defer fd.decref()
for {
- n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
- if err == syscall.EAGAIN {
+ if fd.rdeadline.expired() {
err = errTimeout
- if fd.rdeadline >= 0 {
- if err = pollserver.WaitRead(fd); err == nil {
- continue
- }
- }
+ break
}
+ n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
if err != nil {
n = 0
+ if err == syscall.EAGAIN {
+ if err = fd.pollServer.WaitRead(fd); err == nil {
+ continue
+ }
+ }
}
+ err = chkReadErr(n, err, fd)
break
}
if err != nil && err != io.EOF {
@@ -479,41 +486,47 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
}
defer fd.decref()
for {
- n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
- if err == syscall.EAGAIN {
+ if fd.rdeadline.expired() {
err = errTimeout
- if fd.rdeadline >= 0 {
- if err = pollserver.WaitRead(fd); err == nil {
+ break
+ }
+ n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
+ if err != nil {
+ // TODO(dfc) should n and oobn be set to 0
+ if err == syscall.EAGAIN {
+ if err = fd.pollServer.WaitRead(fd); err == nil {
continue
}
}
}
- if err == nil && n == 0 {
- err = io.EOF
- }
+ err = chkReadErr(n, err, fd)
break
}
if err != nil && err != io.EOF {
err = &OpError{"read", fd.net, fd.laddr, err}
- return
}
return
}
-func (fd *netFD) Write(p []byte) (int, error) {
+func chkReadErr(n int, err error, fd *netFD) error {
+ if n == 0 && err == nil && fd.sotype != syscall.SOCK_DGRAM && fd.sotype != syscall.SOCK_RAW {
+ return io.EOF
+ }
+ return err
+}
+
+func (fd *netFD) Write(p []byte) (nn int, err error) {
fd.wio.Lock()
defer fd.wio.Unlock()
if err := fd.incref(false); err != nil {
return 0, err
}
defer fd.decref()
- if fd.sysfile == nil {
- return 0, syscall.EINVAL
- }
-
- var err error
- nn := 0
for {
+ if fd.wdeadline.expired() {
+ err = errTimeout
+ break
+ }
var n int
n, err = syscall.Write(int(fd.sysfd), p[nn:])
if n > 0 {
@@ -523,11 +536,8 @@ func (fd *netFD) Write(p []byte) (int, error) {
break
}
if err == syscall.EAGAIN {
- err = errTimeout
- if fd.wdeadline >= 0 {
- if err = pollserver.WaitWrite(fd); err == nil {
- continue
- }
+ if err = fd.pollServer.WaitWrite(fd); err == nil {
+ continue
}
}
if err != nil {
@@ -553,13 +563,14 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
}
defer fd.decref()
for {
+ if fd.wdeadline.expired() {
+ err = errTimeout
+ break
+ }
err = syscall.Sendto(fd.sysfd, p, 0, sa)
if err == syscall.EAGAIN {
- err = errTimeout
- if fd.wdeadline >= 0 {
- if err = pollserver.WaitWrite(fd); err == nil {
- continue
- }
+ if err = fd.pollServer.WaitWrite(fd); err == nil {
+ continue
}
}
break
@@ -580,13 +591,14 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
}
defer fd.decref()
for {
+ if fd.wdeadline.expired() {
+ err = errTimeout
+ break
+ }
err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
if err == syscall.EAGAIN {
- err = errTimeout
- if fd.wdeadline >= 0 {
- if err = pollserver.WaitWrite(fd); err == nil {
- continue
- }
+ if err = fd.pollServer.WaitWrite(fd); err == nil {
+ continue
}
}
break
@@ -606,22 +618,14 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e
}
defer fd.decref()
- // See ../syscall/exec.go for description of ForkLock.
- // It is okay to hold the lock across syscall.Accept
- // because we have put fd.sysfd into non-blocking mode.
var s int
var rsa syscall.Sockaddr
for {
- syscall.ForkLock.RLock()
- s, rsa, err = syscall.Accept(fd.sysfd)
+ s, rsa, err = accept(fd.sysfd)
if err != nil {
- syscall.ForkLock.RUnlock()
if err == syscall.EAGAIN {
- err = errTimeout
- if fd.rdeadline >= 0 {
- if err = pollserver.WaitRead(fd); err == nil {
- continue
- }
+ if err = fd.pollServer.WaitRead(fd); err == nil {
+ continue
}
} else if err == syscall.ECONNABORTED {
// This means that a socket on the listen queue was closed
@@ -632,11 +636,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e
}
break
}
- syscall.CloseOnExec(s)
- syscall.ForkLock.RUnlock()
if netfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil {
- syscall.Close(s)
+ closesocket(s)
return nil, err
}
lsa, _ := syscall.Getsockname(netfd.sysfd)
@@ -645,17 +647,24 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e
}
func (fd *netFD) dup() (f *os.File, err error) {
+ syscall.ForkLock.RLock()
ns, err := syscall.Dup(fd.sysfd)
if err != nil {
+ syscall.ForkLock.RUnlock()
return nil, &OpError{"dup", fd.net, fd.laddr, err}
}
+ syscall.CloseOnExec(ns)
+ syscall.ForkLock.RUnlock()
// We want blocking mode for the new fd, hence the double negative.
+ // This also puts the old fd into blocking mode, meaning that
+ // I/O will block the thread instead of letting us use the epoll server.
+ // Everything will still work, just with more threads.
if err = syscall.SetNonblock(ns, false); err != nil {
return nil, &OpError{"setnonblock", fd.net, fd.laddr, err}
}
- return os.NewFile(uintptr(ns), fd.sysfile.Name()), nil
+ return os.NewFile(uintptr(ns), fd.name()), nil
}
func closesocket(s int) error {
diff --git a/src/pkg/net/fd_unix_test.go b/src/pkg/net/fd_unix_test.go
new file mode 100644
index 000000000..664ef1bf1
--- /dev/null
+++ b/src/pkg/net/fd_unix_test.go
@@ -0,0 +1,58 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd
+
+package net
+
+import (
+ "io"
+ "syscall"
+ "testing"
+)
+
+var chkReadErrTests = []struct {
+ n int
+ err error
+ fd *netFD
+ expected error
+}{
+
+ {100, nil, &netFD{sotype: syscall.SOCK_STREAM}, nil},
+ {100, io.EOF, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF},
+ {100, errClosing, &netFD{sotype: syscall.SOCK_STREAM}, errClosing},
+ {0, nil, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF},
+ {0, io.EOF, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF},
+ {0, errClosing, &netFD{sotype: syscall.SOCK_STREAM}, errClosing},
+
+ {100, nil, &netFD{sotype: syscall.SOCK_DGRAM}, nil},
+ {100, io.EOF, &netFD{sotype: syscall.SOCK_DGRAM}, io.EOF},
+ {100, errClosing, &netFD{sotype: syscall.SOCK_DGRAM}, errClosing},
+ {0, nil, &netFD{sotype: syscall.SOCK_DGRAM}, nil},
+ {0, io.EOF, &netFD{sotype: syscall.SOCK_DGRAM}, io.EOF},
+ {0, errClosing, &netFD{sotype: syscall.SOCK_DGRAM}, errClosing},
+
+ {100, nil, &netFD{sotype: syscall.SOCK_SEQPACKET}, nil},
+ {100, io.EOF, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF},
+ {100, errClosing, &netFD{sotype: syscall.SOCK_SEQPACKET}, errClosing},
+ {0, nil, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF},
+ {0, io.EOF, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF},
+ {0, errClosing, &netFD{sotype: syscall.SOCK_SEQPACKET}, errClosing},
+
+ {100, nil, &netFD{sotype: syscall.SOCK_RAW}, nil},
+ {100, io.EOF, &netFD{sotype: syscall.SOCK_RAW}, io.EOF},
+ {100, errClosing, &netFD{sotype: syscall.SOCK_RAW}, errClosing},
+ {0, nil, &netFD{sotype: syscall.SOCK_RAW}, nil},
+ {0, io.EOF, &netFD{sotype: syscall.SOCK_RAW}, io.EOF},
+ {0, errClosing, &netFD{sotype: syscall.SOCK_RAW}, errClosing},
+}
+
+func TestChkReadErr(t *testing.T) {
+ for _, tt := range chkReadErrTests {
+ actual := chkReadErr(tt.n, tt.err, tt.fd)
+ if actual != tt.expected {
+ t.Errorf("chkReadError(%v, %v, %v): expected %v, actual %v", tt.n, tt.err, tt.fd.sotype, tt.expected, actual)
+ }
+ }
+}
diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go
index 45f5c2d88..0e331b44d 100644
--- a/src/pkg/net/fd_windows.go
+++ b/src/pkg/net/fd_windows.go
@@ -17,19 +17,58 @@ import (
var initErr error
-func init() {
+// CancelIo Windows API cancels all outstanding IO for a particular
+// socket on current thread. To overcome that limitation, we run
+// special goroutine, locked to OS single thread, that both starts
+// and cancels IO. It means, there are 2 unavoidable thread switches
+// for every IO.
+// Some newer versions of Windows has new CancelIoEx API, that does
+// not have that limitation and can be used from any thread. This
+// package uses CancelIoEx API, if present, otherwise it fallback
+// to CancelIo.
+
+var canCancelIO bool // determines if CancelIoEx API is present
+
+func sysInit() {
var d syscall.WSAData
e := syscall.WSAStartup(uint32(0x202), &d)
if e != nil {
initErr = os.NewSyscallError("WSAStartup", e)
}
+ canCancelIO = syscall.LoadCancelIoEx() == nil
+ if syscall.LoadGetAddrInfo() == nil {
+ lookupPort = newLookupPort
+ lookupIP = newLookupIP
+ }
}
func closesocket(s syscall.Handle) error {
return syscall.Closesocket(s)
}
-// Interface for all io operations.
+func canUseConnectEx(net string) bool {
+ if net == "udp" || net == "udp4" || net == "udp6" {
+ // ConnectEx windows API does not support connectionless sockets.
+ return false
+ }
+ return syscall.LoadConnectEx() == nil
+}
+
+func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+ if !canUseConnectEx(net) {
+ // Use the relatively inefficient goroutine-racing
+ // implementation of DialTimeout.
+ return dialTimeoutRace(net, addr, timeout)
+ }
+ deadline := time.Now().Add(timeout)
+ ra, err := resolveAddr("dial", net, addr, deadline)
+ if err != nil {
+ return nil, err
+ }
+ return dial(net, addr, noLocalAddr, ra, deadline)
+}
+
+// Interface for all IO operations.
type anOpIface interface {
Op() *anOp
Name() string
@@ -42,7 +81,7 @@ type ioResult struct {
err error
}
-// anOp implements functionality common to all io operations.
+// anOp implements functionality common to all IO operations.
type anOp struct {
// Used by IOCP interface, it must be first field
// of the struct, as our code rely on it.
@@ -75,7 +114,7 @@ func (o *anOp) Op() *anOp {
return o
}
-// bufOp is used by io operations that read / write
+// bufOp is used by IO operations that read / write
// data from / to client buffer.
type bufOp struct {
anOp
@@ -92,7 +131,7 @@ func (o *bufOp) Init(fd *netFD, buf []byte, mode int) {
}
}
-// resultSrv will retrieve all io completion results from
+// resultSrv will retrieve all IO completion results from
// iocp and send them to the correspondent waiting client
// goroutine via channel supplied in the request.
type resultSrv struct {
@@ -107,7 +146,7 @@ func (s *resultSrv) Run() {
r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE)
switch {
case r.err == nil:
- // Dequeued successfully completed io packet.
+ // Dequeued successfully completed IO packet.
case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil:
// Wait has timed out (should not happen now, but might be used in the future).
panic("GetQueuedCompletionStatus timed out")
@@ -115,22 +154,23 @@ func (s *resultSrv) Run() {
// Failed to dequeue anything -> report the error.
panic("GetQueuedCompletionStatus failed " + r.err.Error())
default:
- // Dequeued failed io packet.
+ // Dequeued failed IO packet.
}
(*anOp)(unsafe.Pointer(o)).resultc <- r
}
}
-// ioSrv executes net io requests.
+// ioSrv executes net IO requests.
type ioSrv struct {
- submchan chan anOpIface // submit io requests
- canchan chan anOpIface // cancel io requests
+ submchan chan anOpIface // submit IO requests
+ canchan chan anOpIface // cancel IO requests
}
-// ProcessRemoteIO will execute submit io requests on behalf
+// ProcessRemoteIO will execute submit IO requests on behalf
// of other goroutines, all on a single os thread, so it can
// cancel them later. Results of all operations will be sent
// back to their requesters via channel supplied in request.
+// It is used only when the CancelIoEx API is unavailable.
func (s *ioSrv) ProcessRemoteIO() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
@@ -144,20 +184,30 @@ func (s *ioSrv) ProcessRemoteIO() {
}
}
-// ExecIO executes a single io operation. It either executes it
-// inline, or, if a deadline is employed, passes the request onto
+// ExecIO executes a single IO operation oi. It submits and cancels
+// IO in the current thread for systems where Windows CancelIoEx API
+// is available. Alternatively, it passes the request onto
// a special goroutine and waits for completion or cancels request.
// deadline is unix nanos.
func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) {
var err error
o := oi.Op()
+ // Calculate timeout delta.
+ var delta int64
if deadline != 0 {
+ delta = deadline - time.Now().UnixNano()
+ if delta <= 0 {
+ return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, errTimeout}
+ }
+ }
+ // Start IO.
+ if canCancelIO {
+ err = oi.Submit()
+ } else {
// Send request to a special dedicated thread,
- // so it can stop the io with CancelIO later.
+ // so it can stop the IO with CancelIO later.
s.submchan <- oi
err = <-o.errnoc
- } else {
- err = oi.Submit()
}
switch err {
case nil:
@@ -168,27 +218,46 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) {
default:
return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, err}
}
+ // Setup timer, if deadline is given.
+ var timer <-chan time.Time
+ if delta > 0 {
+ t := time.NewTimer(time.Duration(delta) * time.Nanosecond)
+ defer t.Stop()
+ timer = t.C
+ }
// Wait for our request to complete.
var r ioResult
- if deadline != 0 {
- dt := deadline - time.Now().UnixNano()
- if dt < 1 {
- dt = 1
- }
- timer := time.NewTimer(time.Duration(dt) * time.Nanosecond)
- defer timer.Stop()
- select {
- case r = <-o.resultc:
- case <-timer.C:
+ var cancelled, timeout bool
+ select {
+ case r = <-o.resultc:
+ case <-timer:
+ cancelled = true
+ timeout = true
+ case <-o.fd.closec:
+ cancelled = true
+ }
+ if cancelled {
+ // Cancel it.
+ if canCancelIO {
+ err := syscall.CancelIoEx(syscall.Handle(o.Op().fd.sysfd), &o.o)
+ // Assuming ERROR_NOT_FOUND is returned, if IO is completed.
+ if err != nil && err != syscall.ERROR_NOT_FOUND {
+ // TODO(brainman): maybe do something else, but panic.
+ panic(err)
+ }
+ } else {
s.canchan <- oi
<-o.errnoc
- r = <-o.resultc
- if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled
- r.err = syscall.EWOULDBLOCK
- }
}
- } else {
+ // Wait for IO to be canceled or complete successfully.
r = <-o.resultc
+ if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled
+ if timeout {
+ r.err = errTimeout
+ } else {
+ r.err = errClosing
+ }
+ }
}
if r.err != nil {
err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err}
@@ -211,9 +280,13 @@ func startServer() {
go resultsrv.Run()
iosrv = new(ioSrv)
- iosrv.submchan = make(chan anOpIface)
- iosrv.canchan = make(chan anOpIface)
- go iosrv.ProcessRemoteIO()
+ if !canCancelIO {
+ // Only CancelIo API is available. Lets start special goroutine
+ // locked to an OS thread, that both starts and cancels IO.
+ iosrv.submchan = make(chan anOpIface)
+ iosrv.canchan = make(chan anOpIface)
+ go iosrv.ProcessRemoteIO()
+ }
}
// Network file descriptor.
@@ -233,12 +306,13 @@ type netFD struct {
raddr Addr
resultc [2]chan ioResult // read/write completion results
errnoc [2]chan error // read/write submit or cancel operation errors
+ closec chan bool // used by Close to cancel pending IO
+
+ // serialize access to Read and Write methods
+ rio, wio sync.Mutex
- // owned by client
- rdeadline int64
- rio sync.Mutex
- wdeadline int64
- wio sync.Mutex
+ // read and write deadlines
+ rdeadline, wdeadline deadline
}
func allocFD(fd syscall.Handle, family, sotype int, net string) *netFD {
@@ -247,8 +321,8 @@ func allocFD(fd syscall.Handle, family, sotype int, net string) *netFD {
family: family,
sotype: sotype,
net: net,
+ closec: make(chan bool),
}
- runtime.SetFinalizer(netfd, (*netFD).Close)
return netfd
}
@@ -267,13 +341,52 @@ func newFD(fd syscall.Handle, family, proto int, net string) (*netFD, error) {
func (fd *netFD) setAddr(laddr, raddr Addr) {
fd.laddr = laddr
fd.raddr = raddr
+ runtime.SetFinalizer(fd, (*netFD).closesocket)
}
-func (fd *netFD) connect(ra syscall.Sockaddr) error {
- return syscall.Connect(fd.sysfd, ra)
+// Make new connection.
+
+type connectOp struct {
+ anOp
+ ra syscall.Sockaddr
+}
+
+func (o *connectOp) Submit() error {
+ return syscall.ConnectEx(o.fd.sysfd, o.ra, nil, 0, nil, &o.o)
}
-var errClosing = errors.New("use of closed network connection")
+func (o *connectOp) Name() string {
+ return "ConnectEx"
+}
+
+func (fd *netFD) connect(ra syscall.Sockaddr) error {
+ if !canUseConnectEx(fd.net) {
+ return syscall.Connect(fd.sysfd, ra)
+ }
+ // ConnectEx windows API requires an unconnected, previously bound socket.
+ var la syscall.Sockaddr
+ switch ra.(type) {
+ case *syscall.SockaddrInet4:
+ la = &syscall.SockaddrInet4{}
+ case *syscall.SockaddrInet6:
+ la = &syscall.SockaddrInet6{}
+ default:
+ panic("unexpected type in connect")
+ }
+ if err := syscall.Bind(fd.sysfd, la); err != nil {
+ return err
+ }
+ // Call ConnectEx API.
+ var o connectOp
+ o.Init(fd, 'w')
+ o.ra = ra
+ _, err := iosrv.ExecIO(&o, fd.wdeadline.value())
+ if err != nil {
+ return err
+ }
+ // Refresh socket properties.
+ return syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
+}
// Add a reference to this fd.
// If closing==true, mark the fd as closing.
@@ -299,24 +412,12 @@ func (fd *netFD) incref(closing bool) error {
// Remove a reference to this FD and close if we've been asked to do so (and
// there are no references left.
func (fd *netFD) decref() {
+ if fd == nil {
+ return
+ }
fd.sysmu.Lock()
fd.sysref--
- // NOTE(rsc): On Unix we check fd.sysref == 0 here before closing,
- // but on Windows we have no way to wake up the blocked I/O other
- // than closing the socket (or calling Shutdown, which breaks other
- // programs that might have a reference to the socket). So there is
- // a small race here that we might close fd.sysfd and then some other
- // goroutine might start a read of fd.sysfd (having read it before we
- // write InvalidHandle to it), which might refer to some other file
- // if the specific handle value gets reused. I think handle values on
- // Windows are not reused as aggressively as file descriptors on Unix,
- // so this might be tolerable.
- if fd.closing && fd.sysfd != syscall.InvalidHandle {
- // In case the user has set linger, switch to blocking mode so
- // the close blocks. As long as this doesn't happen often, we
- // can handle the extra OS processes. Otherwise we'll need to
- // use the resultsrv for Close too. Sigh.
- syscall.SetNonblock(fd.sysfd, false)
+ if fd.closing && fd.sysref == 0 && fd.sysfd != syscall.InvalidHandle {
closesocket(fd.sysfd)
fd.sysfd = syscall.InvalidHandle
// no need for a finalizer anymore
@@ -329,14 +430,22 @@ func (fd *netFD) Close() error {
if err := fd.incref(true); err != nil {
return err
}
- fd.decref()
+ defer fd.decref()
+ // unblock pending reader and writer
+ close(fd.closec)
+ // wait for both reader and writer to exit
+ fd.rio.Lock()
+ defer fd.rio.Unlock()
+ fd.wio.Lock()
+ defer fd.wio.Unlock()
return nil
}
func (fd *netFD) shutdown(how int) error {
- if fd == nil || fd.sysfd == syscall.InvalidHandle {
- return syscall.EINVAL
+ if err := fd.incref(false); err != nil {
+ return err
}
+ defer fd.decref()
err := syscall.Shutdown(fd.sysfd, how)
if err != nil {
return &OpError{"shutdown", fd.net, fd.laddr, err}
@@ -352,6 +461,10 @@ func (fd *netFD) CloseWrite() error {
return fd.shutdown(syscall.SHUT_WR)
}
+func (fd *netFD) closesocket() error {
+ return closesocket(fd.sysfd)
+}
+
// Read from network.
type readOp struct {
@@ -368,21 +481,15 @@ func (o *readOp) Name() string {
}
func (fd *netFD) Read(buf []byte) (int, error) {
- if fd == nil {
- return 0, syscall.EINVAL
- }
- fd.rio.Lock()
- defer fd.rio.Unlock()
if err := fd.incref(false); err != nil {
return 0, err
}
defer fd.decref()
- if fd.sysfd == syscall.InvalidHandle {
- return 0, syscall.EINVAL
- }
+ fd.rio.Lock()
+ defer fd.rio.Unlock()
var o readOp
o.Init(fd, buf, 'r')
- n, err := iosrv.ExecIO(&o, fd.rdeadline)
+ n, err := iosrv.ExecIO(&o, fd.rdeadline.value())
if err == nil && n == 0 {
err = io.EOF
}
@@ -407,22 +514,19 @@ func (o *readFromOp) Name() string {
}
func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) {
- if fd == nil {
- return 0, nil, syscall.EINVAL
- }
if len(buf) == 0 {
return 0, nil, nil
}
- fd.rio.Lock()
- defer fd.rio.Unlock()
if err := fd.incref(false); err != nil {
return 0, nil, err
}
defer fd.decref()
+ fd.rio.Lock()
+ defer fd.rio.Unlock()
var o readFromOp
o.Init(fd, buf, 'r')
o.rsan = int32(unsafe.Sizeof(o.rsa))
- n, err = iosrv.ExecIO(&o, fd.rdeadline)
+ n, err = iosrv.ExecIO(&o, fd.rdeadline.value())
if err != nil {
return 0, nil, err
}
@@ -446,18 +550,15 @@ func (o *writeOp) Name() string {
}
func (fd *netFD) Write(buf []byte) (int, error) {
- if fd == nil {
- return 0, syscall.EINVAL
- }
- fd.wio.Lock()
- defer fd.wio.Unlock()
if err := fd.incref(false); err != nil {
return 0, err
}
defer fd.decref()
+ fd.wio.Lock()
+ defer fd.wio.Unlock()
var o writeOp
o.Init(fd, buf, 'w')
- return iosrv.ExecIO(&o, fd.wdeadline)
+ return iosrv.ExecIO(&o, fd.wdeadline.value())
}
// WriteTo to network.
@@ -477,25 +578,19 @@ func (o *writeToOp) Name() string {
}
func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (int, error) {
- if fd == nil {
- return 0, syscall.EINVAL
- }
if len(buf) == 0 {
return 0, nil
}
- fd.wio.Lock()
- defer fd.wio.Unlock()
if err := fd.incref(false); err != nil {
return 0, err
}
defer fd.decref()
- if fd.sysfd == syscall.InvalidHandle {
- return 0, syscall.EINVAL
- }
+ fd.wio.Lock()
+ defer fd.wio.Unlock()
var o writeToOp
o.Init(fd, buf, 'w')
o.sa = sa
- return iosrv.ExecIO(&o, fd.wdeadline)
+ return iosrv.ExecIO(&o, fd.wdeadline.value())
}
// Accept new network connections.
@@ -524,19 +619,15 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) {
defer fd.decref()
// Get new socket.
- // See ../syscall/exec.go for description of ForkLock.
- syscall.ForkLock.RLock()
- s, err := syscall.Socket(fd.family, fd.sotype, 0)
+ s, err := sysSocket(fd.family, fd.sotype, 0)
if err != nil {
- syscall.ForkLock.RUnlock()
- return nil, err
+ return nil, &OpError{"socket", fd.net, fd.laddr, err}
}
- syscall.CloseOnExec(s)
- syscall.ForkLock.RUnlock()
// Associate our new socket with IOCP.
onceStartServer.Do(startServer)
if _, err := syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); err != nil {
+ closesocket(s)
return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, err}
}
@@ -544,7 +635,7 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) {
var o acceptOp
o.Init(fd, 'r')
o.newsock = s
- _, err = iosrv.ExecIO(&o, 0)
+ _, err = iosrv.ExecIO(&o, fd.rdeadline.value())
if err != nil {
closesocket(s)
return nil, err
@@ -554,7 +645,7 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) {
err = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
if err != nil {
closesocket(s)
- return nil, err
+ return nil, &OpError{"Setsockopt", fd.net, fd.laddr, err}
}
// Get local and peer addr out of AcceptEx buffer.
diff --git a/src/pkg/net/file_plan9.go b/src/pkg/net/file_plan9.go
index 04f7ee040..f6ee1c29e 100644
--- a/src/pkg/net/file_plan9.go
+++ b/src/pkg/net/file_plan9.go
@@ -5,24 +5,147 @@
package net
import (
+ "errors"
+ "io"
"os"
"syscall"
)
+func (fd *netFD) status(ln int) (string, error) {
+ if !fd.ok() {
+ return "", syscall.EINVAL
+ }
+
+ status, err := os.Open(fd.dir + "/status")
+ if err != nil {
+ return "", err
+ }
+ defer status.Close()
+ buf := make([]byte, ln)
+ n, err := io.ReadFull(status, buf[:])
+ if err != nil {
+ return "", err
+ }
+ return string(buf[:n]), nil
+}
+
+func newFileFD(f *os.File) (net *netFD, err error) {
+ var ctl *os.File
+ close := func(fd int) {
+ if err != nil {
+ syscall.Close(fd)
+ }
+ }
+
+ path, err := syscall.Fd2path(int(f.Fd()))
+ if err != nil {
+ return nil, os.NewSyscallError("fd2path", err)
+ }
+ comp := splitAtBytes(path, "/")
+ n := len(comp)
+ if n < 3 || comp[0] != "net" {
+ return nil, syscall.EPLAN9
+ }
+
+ name := comp[2]
+ switch file := comp[n-1]; file {
+ case "ctl", "clone":
+ syscall.ForkLock.RLock()
+ fd, err := syscall.Dup(int(f.Fd()), -1)
+ syscall.ForkLock.RUnlock()
+ if err != nil {
+ return nil, os.NewSyscallError("dup", err)
+ }
+ defer close(fd)
+
+ dir := "/net/" + comp[n-2]
+ ctl = os.NewFile(uintptr(fd), dir+"/"+file)
+ ctl.Seek(0, 0)
+ var buf [16]byte
+ n, err := ctl.Read(buf[:])
+ if err != nil {
+ return nil, err
+ }
+ name = string(buf[:n])
+ default:
+ if len(comp) < 4 {
+ return nil, errors.New("could not find control file for connection")
+ }
+ dir := "/net/" + comp[1] + "/" + name
+ ctl, err = os.OpenFile(dir+"/ctl", os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+ defer close(int(ctl.Fd()))
+ }
+ dir := "/net/" + comp[1] + "/" + name
+ laddr, err := readPlan9Addr(comp[1], dir+"/local")
+ if err != nil {
+ return nil, err
+ }
+ return newFD(comp[1], name, ctl, nil, laddr, nil), nil
+}
+
+func newFileConn(f *os.File) (c Conn, err error) {
+ fd, err := newFileFD(f)
+ if err != nil {
+ return nil, err
+ }
+ if !fd.ok() {
+ return nil, syscall.EINVAL
+ }
+
+ fd.data, err = os.OpenFile(fd.dir+"/data", os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ switch fd.laddr.(type) {
+ case *TCPAddr:
+ return newTCPConn(fd), nil
+ case *UDPAddr:
+ return newUDPConn(fd), nil
+ }
+ return nil, syscall.EPLAN9
+}
+
+func newFileListener(f *os.File) (l Listener, err error) {
+ fd, err := newFileFD(f)
+ if err != nil {
+ return nil, err
+ }
+ switch fd.laddr.(type) {
+ case *TCPAddr:
+ default:
+ return nil, syscall.EPLAN9
+ }
+
+ // check that file corresponds to a listener
+ s, err := fd.status(len("Listen"))
+ if err != nil {
+ return nil, err
+ }
+ if s != "Listen" {
+ return nil, errors.New("file does not represent a listener")
+ }
+
+ return &TCPListener{fd}, nil
+}
+
// FileConn returns a copy of the network connection corresponding to
// the open file f. It is the caller's responsibility to close f when
// finished. Closing c does not affect f, and closing f does not
// affect c.
func FileConn(f *os.File) (c Conn, err error) {
- return nil, syscall.EPLAN9
+ return newFileConn(f)
}
// FileListener returns a copy of the network listener corresponding
// to the open file f. It is the caller's responsibility to close l
-// when finished. Closing c does not affect l, and closing l does not
-// affect c.
+// when finished. Closing l does not affect f, and closing f does not
+// affect l.
func FileListener(f *os.File) (l Listener, err error) {
- return nil, syscall.EPLAN9
+ return newFileListener(f)
}
// FilePacketConn returns a copy of the packet network connection
diff --git a/src/pkg/net/file_test.go b/src/pkg/net/file_test.go
index 95c0b6699..acaf18851 100644
--- a/src/pkg/net/file_test.go
+++ b/src/pkg/net/file_test.go
@@ -89,9 +89,8 @@ var fileListenerTests = []struct {
func TestFileListener(t *testing.T) {
switch runtime.GOOS {
- case "plan9", "windows":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ case "windows":
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
for _, tt := range fileListenerTests {
@@ -181,8 +180,7 @@ var filePacketConnTests = []struct {
func TestFilePacketConn(t *testing.T) {
switch runtime.GOOS {
case "plan9", "windows":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
for _, tt := range filePacketConnTests {
diff --git a/src/pkg/net/file.go b/src/pkg/net/file_unix.go
index fc6c6fad8..4c8403e40 100644
--- a/src/pkg/net/file.go
+++ b/src/pkg/net/file_unix.go
@@ -12,52 +12,62 @@ import (
)
func newFileFD(f *os.File) (*netFD, error) {
+ syscall.ForkLock.RLock()
fd, err := syscall.Dup(int(f.Fd()))
if err != nil {
+ syscall.ForkLock.RUnlock()
return nil, os.NewSyscallError("dup", err)
}
+ syscall.CloseOnExec(fd)
+ syscall.ForkLock.RUnlock()
+ if err = syscall.SetNonblock(fd, true); err != nil {
+ closesocket(fd)
+ return nil, err
+ }
- proto, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
+ sotype, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
if err != nil {
+ closesocket(fd)
return nil, os.NewSyscallError("getsockopt", err)
}
family := syscall.AF_UNSPEC
toAddr := sockaddrToTCP
- sa, _ := syscall.Getsockname(fd)
- switch sa.(type) {
+ lsa, _ := syscall.Getsockname(fd)
+ switch lsa.(type) {
default:
closesocket(fd)
return nil, syscall.EINVAL
case *syscall.SockaddrInet4:
family = syscall.AF_INET
- if proto == syscall.SOCK_DGRAM {
+ if sotype == syscall.SOCK_DGRAM {
toAddr = sockaddrToUDP
- } else if proto == syscall.SOCK_RAW {
+ } else if sotype == syscall.SOCK_RAW {
toAddr = sockaddrToIP
}
case *syscall.SockaddrInet6:
family = syscall.AF_INET6
- if proto == syscall.SOCK_DGRAM {
+ if sotype == syscall.SOCK_DGRAM {
toAddr = sockaddrToUDP
- } else if proto == syscall.SOCK_RAW {
+ } else if sotype == syscall.SOCK_RAW {
toAddr = sockaddrToIP
}
case *syscall.SockaddrUnix:
family = syscall.AF_UNIX
toAddr = sockaddrToUnix
- if proto == syscall.SOCK_DGRAM {
+ if sotype == syscall.SOCK_DGRAM {
toAddr = sockaddrToUnixgram
- } else if proto == syscall.SOCK_SEQPACKET {
+ } else if sotype == syscall.SOCK_SEQPACKET {
toAddr = sockaddrToUnixpacket
}
}
- laddr := toAddr(sa)
- sa, _ = syscall.Getpeername(fd)
- raddr := toAddr(sa)
+ laddr := toAddr(lsa)
+ rsa, _ := syscall.Getpeername(fd)
+ raddr := toAddr(rsa)
- netfd, err := newFD(fd, family, proto, laddr.Network())
+ netfd, err := newFD(fd, family, sotype, laddr.Network())
if err != nil {
+ closesocket(fd)
return nil, err
}
netfd.setAddr(laddr, raddr)
@@ -78,10 +88,10 @@ func FileConn(f *os.File) (c Conn, err error) {
return newTCPConn(fd), nil
case *UDPAddr:
return newUDPConn(fd), nil
- case *UnixAddr:
- return newUnixConn(fd), nil
case *IPAddr:
return newIPConn(fd), nil
+ case *UnixAddr:
+ return newUnixConn(fd), nil
}
fd.Close()
return nil, syscall.EINVAL
diff --git a/src/pkg/net/http/cgi/child.go b/src/pkg/net/http/cgi/child.go
index 1ba7bec5f..100b8b777 100644
--- a/src/pkg/net/http/cgi/child.go
+++ b/src/pkg/net/http/cgi/child.go
@@ -91,10 +91,19 @@ func RequestFromMap(params map[string]string) (*http.Request, error) {
// TODO: cookies. parsing them isn't exported, though.
+ uriStr := params["REQUEST_URI"]
+ if uriStr == "" {
+ // Fallback to SCRIPT_NAME, PATH_INFO and QUERY_STRING.
+ uriStr = params["SCRIPT_NAME"] + params["PATH_INFO"]
+ s := params["QUERY_STRING"]
+ if s != "" {
+ uriStr += "?" + s
+ }
+ }
if r.Host != "" {
// Hostname is provided, so we can reasonably construct a URL,
// even if we have to assume 'http' for the scheme.
- rawurl := "http://" + r.Host + params["REQUEST_URI"]
+ rawurl := "http://" + r.Host + uriStr
url, err := url.Parse(rawurl)
if err != nil {
return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl)
@@ -104,7 +113,6 @@ func RequestFromMap(params map[string]string) (*http.Request, error) {
// Fallback logic if we don't have a Host header or the URL
// failed to parse
if r.URL == nil {
- uriStr := params["REQUEST_URI"]
url, err := url.Parse(uriStr)
if err != nil {
return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr)
diff --git a/src/pkg/net/http/cgi/child_test.go b/src/pkg/net/http/cgi/child_test.go
index ec53ab851..74e068014 100644
--- a/src/pkg/net/http/cgi/child_test.go
+++ b/src/pkg/net/http/cgi/child_test.go
@@ -82,6 +82,28 @@ func TestRequestWithoutHost(t *testing.T) {
t.Fatalf("unexpected nil URL")
}
if g, e := req.URL.String(), "/path?a=b"; e != g {
- t.Errorf("expected URL %q; got %q", e, g)
+ t.Errorf("URL = %q; want %q", g, e)
+ }
+}
+
+func TestRequestWithoutRequestURI(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "HTTP_HOST": "example.com",
+ "REQUEST_METHOD": "GET",
+ "SCRIPT_NAME": "/dir/scriptname",
+ "PATH_INFO": "/p1/p2",
+ "QUERY_STRING": "a=1&b=2",
+ "CONTENT_LENGTH": "123",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if req.URL == nil {
+ t.Fatalf("unexpected nil URL")
+ }
+ if g, e := req.URL.String(), "http://example.com/dir/scriptname/p1/p2?a=1&b=2"; e != g {
+ t.Errorf("URL = %q; want %q", g, e)
}
}
diff --git a/src/pkg/net/http/cgi/host_test.go b/src/pkg/net/http/cgi/host_test.go
index 4db3d850c..8c16e6897 100644
--- a/src/pkg/net/http/cgi/host_test.go
+++ b/src/pkg/net/http/cgi/host_test.go
@@ -19,7 +19,6 @@ import (
"runtime"
"strconv"
"strings"
- "syscall"
"testing"
"time"
)
@@ -63,17 +62,25 @@ readlines:
}
for key, expected := range expectedMap {
- if got := m[key]; got != expected {
+ got := m[key]
+ if key == "cwd" {
+ // For Windows. golang.org/issue/4645.
+ fi1, _ := os.Stat(got)
+ fi2, _ := os.Stat(expected)
+ if os.SameFile(fi1, fi2) {
+ got = expected
+ }
+ }
+ if got != expected {
t.Errorf("for key %q got %q; expected %q", key, got, expected)
}
}
return rw
}
-var cgiTested = false
-var cgiWorks bool
+var cgiTested, cgiWorks bool
-func skipTest(t *testing.T) bool {
+func check(t *testing.T) {
if !cgiTested {
cgiTested = true
cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil
@@ -81,16 +88,12 @@ func skipTest(t *testing.T) bool {
if !cgiWorks {
// No Perl on Windows, needed by test.cgi
// TODO: make the child process be Go, not Perl.
- t.Logf("Skipping test: test.cgi failed.")
- return true
+ t.Skip("Skipping test: test.cgi failed.")
}
- return false
}
func TestCGIBasicGet(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
@@ -124,9 +127,7 @@ func TestCGIBasicGet(t *testing.T) {
}
func TestCGIBasicGetAbsPath(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
pwd, err := os.Getwd()
if err != nil {
t.Fatalf("getwd error: %v", err)
@@ -144,9 +145,7 @@ func TestCGIBasicGetAbsPath(t *testing.T) {
}
func TestPathInfo(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
@@ -163,9 +162,7 @@ func TestPathInfo(t *testing.T) {
}
func TestPathInfoDirRoot(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/myscript/",
@@ -181,9 +178,7 @@ func TestPathInfoDirRoot(t *testing.T) {
}
func TestDupHeaders(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
}
@@ -203,9 +198,7 @@ func TestDupHeaders(t *testing.T) {
}
func TestPathInfoNoRoot(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "",
@@ -221,9 +214,7 @@ func TestPathInfoNoRoot(t *testing.T) {
}
func TestCGIBasicPost(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
postReq := `POST /test.cgi?a=b HTTP/1.0
Host: example.com
Content-Type: application/x-www-form-urlencoded
@@ -250,9 +241,7 @@ func chunk(s string) string {
// The CGI spec doesn't allow chunked requests.
func TestCGIPostChunked(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
postReq := `POST /test.cgi?a=b HTTP/1.1
Host: example.com
Content-Type: application/x-www-form-urlencoded
@@ -273,9 +262,7 @@ Transfer-Encoding: chunked
}
func TestRedirect(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
h := &Handler{
Path: "testdata/test.cgi",
Root: "/test.cgi",
@@ -290,9 +277,7 @@ func TestRedirect(t *testing.T) {
}
func TestInternalRedirect(t *testing.T) {
- if skipTest(t) {
- return
- }
+ check(t)
baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path)
fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr)
@@ -312,8 +297,9 @@ func TestInternalRedirect(t *testing.T) {
// TestCopyError tests that we kill the process if there's an error copying
// its output. (for example, from the client having gone away)
func TestCopyError(t *testing.T) {
- if skipTest(t) || runtime.GOOS == "windows" {
- return
+ check(t)
+ if runtime.GOOS == "windows" {
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
h := &Handler{
Path: "testdata/test.cgi",
@@ -353,11 +339,7 @@ func TestCopyError(t *testing.T) {
}
childRunning := func() bool {
- p, err := os.FindProcess(pid)
- if err != nil {
- return false
- }
- return p.Signal(syscall.Signal(0)) == nil
+ return isProcessRunning(t, pid)
}
if !childRunning() {
@@ -376,10 +358,10 @@ func TestCopyError(t *testing.T) {
}
func TestDirUnix(t *testing.T) {
- if skipTest(t) || runtime.GOOS == "windows" {
- return
+ check(t)
+ if runtime.GOOS == "windows" {
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
-
cwd, _ := os.Getwd()
h := &Handler{
Path: "testdata/test.cgi",
@@ -404,8 +386,8 @@ func TestDirUnix(t *testing.T) {
}
func TestDirWindows(t *testing.T) {
- if skipTest(t) || runtime.GOOS != "windows" {
- return
+ if runtime.GOOS != "windows" {
+ t.Skip("Skipping windows specific test.")
}
cgifile, _ := filepath.Abs("testdata/test.cgi")
@@ -414,7 +396,7 @@ func TestDirWindows(t *testing.T) {
var err error
perl, err = exec.LookPath("perl")
if err != nil {
- return
+ t.Skip("Skipping test: perl not found.")
}
perl, _ = filepath.Abs(perl)
@@ -456,7 +438,7 @@ func TestEnvOverride(t *testing.T) {
var err error
perl, err = exec.LookPath("perl")
if err != nil {
- return
+ t.Skipf("Skipping test: perl not found.")
}
perl, _ = filepath.Abs(perl)
diff --git a/src/pkg/net/http/cgi/plan9_test.go b/src/pkg/net/http/cgi/plan9_test.go
new file mode 100644
index 000000000..c8235831b
--- /dev/null
+++ b/src/pkg/net/http/cgi/plan9_test.go
@@ -0,0 +1,18 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build plan9
+
+package cgi
+
+import (
+ "os"
+ "strconv"
+ "testing"
+)
+
+func isProcessRunning(t *testing.T, pid int) bool {
+ _, err := os.Stat("/proc/" + strconv.Itoa(pid))
+ return err == nil
+}
diff --git a/src/pkg/net/http/cgi/posix_test.go b/src/pkg/net/http/cgi/posix_test.go
new file mode 100644
index 000000000..5ff9e7d5e
--- /dev/null
+++ b/src/pkg/net/http/cgi/posix_test.go
@@ -0,0 +1,21 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !plan9
+
+package cgi
+
+import (
+ "os"
+ "syscall"
+ "testing"
+)
+
+func isProcessRunning(t *testing.T, pid int) bool {
+ p, err := os.FindProcess(pid)
+ if err != nil {
+ return false
+ }
+ return p.Signal(syscall.Signal(0)) == nil
+}
diff --git a/src/pkg/net/http/cgi/testdata/test.cgi b/src/pkg/net/http/cgi/testdata/test.cgi
index b46b1330f..3214df6f0 100755
--- a/src/pkg/net/http/cgi/testdata/test.cgi
+++ b/src/pkg/net/http/cgi/testdata/test.cgi
@@ -8,6 +8,8 @@
use strict;
use Cwd;
+binmode STDOUT;
+
my $q = MiniCGI->new;
my $params = $q->Vars;
@@ -16,51 +18,44 @@ if ($params->{"loc"}) {
exit(0);
}
-my $NL = "\r\n";
-$NL = "\n" if $params->{mode} eq "NL";
-
-my $p = sub {
- print "$_[0]$NL";
-};
-
-# With carriage returns
-$p->("Content-Type: text/html");
-$p->("X-CGI-Pid: $$");
-$p->("X-Test-Header: X-Test-Value");
-$p->("");
+print "Content-Type: text/html\r\n";
+print "X-CGI-Pid: $$\r\n";
+print "X-Test-Header: X-Test-Value\r\n";
+print "\r\n";
if ($params->{"bigresponse"}) {
- for (1..1024) {
- print "A" x 1024, "\n";
+ # 17 MB, for OS X: golang.org/issue/4958
+ for (1..(17 * 1024)) {
+ print "A" x 1024, "\r\n";
}
exit 0;
}
-print "test=Hello CGI\n";
+print "test=Hello CGI\r\n";
foreach my $k (sort keys %$params) {
- print "param-$k=$params->{$k}\n";
+ print "param-$k=$params->{$k}\r\n";
}
foreach my $k (sort keys %ENV) {
- my $clean_env = $ENV{$k};
- $clean_env =~ s/[\n\r]//g;
- print "env-$k=$clean_env\n";
+ my $clean_env = $ENV{$k};
+ $clean_env =~ s/[\n\r]//g;
+ print "env-$k=$clean_env\r\n";
}
-# NOTE: don't call getcwd() for windows.
-# msys return /c/go/src/... not C:\go\...
-my $dir;
+# NOTE: msys perl returns /c/go/src/... not C:\go\....
+my $dir = getcwd();
if ($^O eq 'MSWin32' || $^O eq 'msys') {
- my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe';
- $cmd =~ s!\\!/!g;
- $dir = `$cmd /c cd`;
- chomp $dir;
-} else {
- $dir = getcwd();
+ if ($dir =~ /^.:/) {
+ $dir =~ s!/!\\!g;
+ } else {
+ my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe';
+ $cmd =~ s!\\!/!g;
+ $dir = `$cmd /c cd`;
+ chomp $dir;
+ }
}
-print "cwd=$dir\n";
-
+print "cwd=$dir\r\n";
# A minimal version of CGI.pm, for people without the perl-modules
# package installed. (CGI.pm used to be part of the Perl core, but
diff --git a/src/pkg/net/http/chunked.go b/src/pkg/net/http/chunked.go
index 60a478fd8..91db01724 100644
--- a/src/pkg/net/http/chunked.go
+++ b/src/pkg/net/http/chunked.go
@@ -11,10 +11,9 @@ package http
import (
"bufio"
- "bytes"
"errors"
+ "fmt"
"io"
- "strconv"
)
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
@@ -22,7 +21,7 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
var ErrLineTooLong = errors.New("header line too long")
// newChunkedReader returns a new chunkedReader that translates the data read from r
-// out of HTTP "chunked" format before returning it.
+// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// newChunkedReader is not needed by normal applications. The http package
@@ -39,16 +38,17 @@ type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
+ buf [2]byte
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
- var line string
+ var line []byte
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
- cr.n, cr.err = strconv.ParseUint(line, 16, 64)
+ cr.n, cr.err = parseHexUint(line)
if cr.err != nil {
return
}
@@ -74,9 +74,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
- b := make([]byte, 2)
- if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
- if b[0] != '\r' || b[1] != '\n' {
+ if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil {
+ if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
@@ -88,7 +87,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
-func readLineBytes(b *bufio.Reader) (p []byte, err error) {
+func readLine(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
@@ -102,20 +101,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
-
- // Chop off trailing white space.
- p = bytes.TrimRight(p, " \r\t\n")
-
- return p, nil
+ return trimTrailingWhitespace(p), nil
}
-// readLineBytes, but convert the bytes into a string.
-func readLine(b *bufio.Reader) (s string, err error) {
- p, e := readLineBytes(b)
- if e != nil {
- return "", e
+func trimTrailingWhitespace(b []byte) []byte {
+ for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+ b = b[:len(b)-1]
}
- return string(p), nil
+ return b
+}
+
+func isASCIISpace(b byte) bool {
+ return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP
@@ -147,9 +144,7 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
return 0, nil
}
- head := strconv.FormatInt(int64(len(data)), 16) + "\r\n"
-
- if _, err = io.WriteString(cw.Wire, head); err != nil {
+ if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil {
return 0, err
}
if n, err = cw.Wire.Write(data); err != nil {
@@ -168,3 +163,21 @@ func (cw *chunkedWriter) Close() error {
_, err := io.WriteString(cw.Wire, "0\r\n")
return err
}
+
+func parseHexUint(v []byte) (n uint64, err error) {
+ for _, b := range v {
+ n <<= 4
+ switch {
+ case '0' <= b && b <= '9':
+ b = b - '0'
+ case 'a' <= b && b <= 'f':
+ b = b - 'a' + 10
+ case 'A' <= b && b <= 'F':
+ b = b - 'A' + 10
+ default:
+ return 0, errors.New("invalid byte in chunk length")
+ }
+ n |= uint64(b)
+ }
+ return
+}
diff --git a/src/pkg/net/http/chunked_test.go b/src/pkg/net/http/chunked_test.go
index b77ee2ff2..0b18c7b55 100644
--- a/src/pkg/net/http/chunked_test.go
+++ b/src/pkg/net/http/chunked_test.go
@@ -9,7 +9,10 @@ package http
import (
"bytes"
+ "fmt"
+ "io"
"io/ioutil"
+ "runtime"
"testing"
)
@@ -37,3 +40,54 @@ func TestChunk(t *testing.T) {
t.Errorf("chunk reader read %q; want %q", g, e)
}
}
+
+func TestChunkReaderAllocs(t *testing.T) {
+ // temporarily set GOMAXPROCS to 1 as we are testing memory allocations
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+ var buf bytes.Buffer
+ w := newChunkedWriter(&buf)
+ a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
+ w.Write(a)
+ w.Write(b)
+ w.Write(c)
+ w.Close()
+
+ r := newChunkedReader(&buf)
+ readBuf := make([]byte, len(a)+len(b)+len(c)+1)
+
+ var ms runtime.MemStats
+ runtime.ReadMemStats(&ms)
+ m0 := ms.Mallocs
+
+ n, err := io.ReadFull(r, readBuf)
+
+ runtime.ReadMemStats(&ms)
+ mallocs := ms.Mallocs - m0
+ if mallocs > 1 {
+ t.Errorf("%d mallocs; want <= 1", mallocs)
+ }
+
+ if n != len(readBuf)-1 {
+ t.Errorf("read %d bytes; want %d", n, len(readBuf)-1)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("read error = %v; want ErrUnexpectedEOF", err)
+ }
+}
+
+func TestParseHexUint(t *testing.T) {
+ for i := uint64(0); i <= 1234; i++ {
+ line := []byte(fmt.Sprintf("%x", i))
+ got, err := parseHexUint(line)
+ if err != nil {
+ t.Fatalf("on %d: %v", i, err)
+ }
+ if got != i {
+ t.Errorf("for input %q = %d; want %d", line, got, i)
+ }
+ }
+ _, err := parseHexUint([]byte("bogus"))
+ if err == nil {
+ t.Error("expected error on bogus input")
+ }
+}
diff --git a/src/pkg/net/http/client.go b/src/pkg/net/http/client.go
index 54564e098..5ee0804c7 100644
--- a/src/pkg/net/http/client.go
+++ b/src/pkg/net/http/client.go
@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
// HTTP client. See RFC 2616.
-//
+//
// This is the high-level Client interface.
// The low-level implementation is in transport.go.
@@ -14,6 +14,7 @@ import (
"errors"
"fmt"
"io"
+ "log"
"net/url"
"strings"
)
@@ -32,17 +33,19 @@ type Client struct {
// CheckRedirect specifies the policy for handling redirects.
// If CheckRedirect is not nil, the client calls it before
- // following an HTTP redirect. The arguments req and via
- // are the upcoming request and the requests made already,
- // oldest first. If CheckRedirect returns an error, the client
- // returns that error instead of issue the Request req.
+ // following an HTTP redirect. The arguments req and via are
+ // the upcoming request and the requests made already, oldest
+ // first. If CheckRedirect returns an error, the Client's Get
+ // method returns both the previous Response and
+ // CheckRedirect's error (wrapped in a url.Error) instead of
+ // issuing the Request req.
//
// If CheckRedirect is nil, the Client uses its default policy,
// which is to stop after 10 consecutive requests.
CheckRedirect func(req *Request, via []*Request) error
- // Jar specifies the cookie jar.
- // If Jar is nil, cookies are not sent in requests and ignored
+ // Jar specifies the cookie jar.
+ // If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar CookieJar
}
@@ -84,10 +87,32 @@ type readClose struct {
io.Closer
}
+func (c *Client) send(req *Request) (*Response, error) {
+ if c.Jar != nil {
+ for _, cookie := range c.Jar.Cookies(req.URL) {
+ req.AddCookie(cookie)
+ }
+ }
+ resp, err := send(req, c.Transport)
+ if err != nil {
+ return nil, err
+ }
+ if c.Jar != nil {
+ if rc := resp.Cookies(); len(rc) > 0 {
+ c.Jar.SetCookies(req.URL, rc)
+ }
+ }
+ return resp, err
+}
+
// Do sends an HTTP request and returns an HTTP response, following
// policy (e.g. redirects, cookies, auth) as configured on the client.
//
-// A non-nil response always contains a non-nil resp.Body.
+// An error is returned if caused by client policy (such as
+// CheckRedirect), or if there was an HTTP protocol error.
+// A non-2xx response doesn't cause an error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
//
// Callers should close resp.Body when done reading from it. If
// resp.Body is not closed, the Client's underlying RoundTripper
@@ -97,12 +122,16 @@ type readClose struct {
// Generally Get, Post, or PostForm will be used instead of Do.
func (c *Client) Do(req *Request) (resp *Response, err error) {
if req.Method == "GET" || req.Method == "HEAD" {
- return c.doFollowingRedirects(req)
+ return c.doFollowingRedirects(req, shouldRedirectGet)
}
- return send(req, c.Transport)
+ if req.Method == "POST" || req.Method == "PUT" {
+ return c.doFollowingRedirects(req, shouldRedirectPost)
+ }
+ return c.send(req)
}
-// send issues an HTTP request. Caller should close resp.Body when done reading from it.
+// send issues an HTTP request.
+// Caller should close resp.Body when done reading from it.
func send(req *Request, t RoundTripper) (resp *Response, err error) {
if t == nil {
t = DefaultTransport
@@ -130,12 +159,19 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) {
if u := req.URL.User; u != nil {
req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(u.String())))
}
- return t.RoundTrip(req)
+ resp, err = t.RoundTrip(req)
+ if err != nil {
+ if resp != nil {
+ log.Printf("RoundTripper returned a response & error; ignoring response")
+ }
+ return nil, err
+ }
+ return resp, nil
}
// True if the specified HTTP status code is one for which the Get utility should
// automatically redirect.
-func shouldRedirect(statusCode int) bool {
+func shouldRedirectGet(statusCode int) bool {
switch statusCode {
case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect:
return true
@@ -143,6 +179,16 @@ func shouldRedirect(statusCode int) bool {
return false
}
+// True if the specified HTTP status code is one for which the Post utility should
+// automatically redirect.
+func shouldRedirectPost(statusCode int) bool {
+ switch statusCode {
+ case StatusFound, StatusSeeOther:
+ return true
+ }
+ return false
+}
+
// Get issues a GET to the specified URL. If the response is one of the following
// redirect codes, Get follows the redirect, up to a maximum of 10 redirects:
//
@@ -151,10 +197,15 @@ func shouldRedirect(statusCode int) bool {
// 303 (See Other)
// 307 (Temporary Redirect)
//
-// Caller should close r.Body when done reading from it.
+// An error is returned if there were too many redirects or if there
+// was an HTTP protocol error. A non-2xx response doesn't cause an
+// error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
//
// Get is a wrapper around DefaultClient.Get.
-func Get(url string) (r *Response, err error) {
+func Get(url string) (resp *Response, err error) {
return DefaultClient.Get(url)
}
@@ -167,18 +218,21 @@ func Get(url string) (r *Response, err error) {
// 303 (See Other)
// 307 (Temporary Redirect)
//
-// Caller should close r.Body when done reading from it.
-func (c *Client) Get(url string) (r *Response, err error) {
+// An error is returned if the Client's CheckRedirect function fails
+// or if there was an HTTP protocol error. A non-2xx response doesn't
+// cause an error.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+func (c *Client) Get(url string) (resp *Response, err error) {
req, err := NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
- return c.doFollowingRedirects(req)
+ return c.doFollowingRedirects(req, shouldRedirectGet)
}
-func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
- // TODO: if/when we add cookie support, the redirected request shouldn't
- // necessarily supply the same cookies as the original.
+func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) {
var base *url.URL
redirectChecker := c.CheckRedirect
if redirectChecker == nil {
@@ -190,17 +244,16 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
return nil, errors.New("http: nil Request.URL")
}
- jar := c.Jar
- if jar == nil {
- jar = blackHoleJar{}
- }
-
req := ireq
urlStr := "" // next relative or absolute URL to fetch (after first request)
+ redirectFailed := false
for redirect := 0; ; redirect++ {
if redirect != 0 {
req = new(Request)
req.Method = ireq.Method
+ if ireq.Method == "POST" || ireq.Method == "PUT" {
+ req.Method = "GET"
+ }
req.Header = make(Header)
req.URL, err = base.Parse(urlStr)
if err != nil {
@@ -215,26 +268,21 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
err = redirectChecker(req, via)
if err != nil {
+ redirectFailed = true
break
}
}
}
- for _, cookie := range jar.Cookies(req.URL) {
- req.AddCookie(cookie)
- }
urlStr = req.URL.String()
- if r, err = send(req, c.Transport); err != nil {
+ if resp, err = c.send(req); err != nil {
break
}
- if c := r.Cookies(); len(c) > 0 {
- jar.SetCookies(req.URL, c)
- }
- if shouldRedirect(r.StatusCode) {
- r.Body.Close()
- if urlStr = r.Header.Get("Location"); urlStr == "" {
- err = errors.New(fmt.Sprintf("%d response missing Location header", r.StatusCode))
+ if shouldRedirect(resp.StatusCode) {
+ resp.Body.Close()
+ if urlStr = resp.Header.Get("Location"); urlStr == "" {
+ err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode))
break
}
base = req.URL
@@ -245,12 +293,23 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) {
}
method := ireq.Method
- err = &url.Error{
+ urlErr := &url.Error{
Op: method[0:1] + strings.ToLower(method[1:]),
URL: urlStr,
Err: err,
}
- return
+
+ if redirectFailed {
+ // Special case for Go 1 compatibility: return both the response
+ // and an error if the CheckRedirect function failed.
+ // See http://golang.org/issue/3795
+ return resp, urlErr
+ }
+
+ if resp != nil {
+ resp.Body.Close()
+ }
+ return nil, urlErr
}
func defaultCheckRedirect(req *Request, via []*Request) error {
@@ -262,49 +321,42 @@ func defaultCheckRedirect(req *Request, via []*Request) error {
// Post issues a POST to the specified URL.
//
-// Caller should close r.Body when done reading from it.
+// Caller should close resp.Body when done reading from it.
//
// Post is a wrapper around DefaultClient.Post
-func Post(url string, bodyType string, body io.Reader) (r *Response, err error) {
+func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
return DefaultClient.Post(url, bodyType, body)
}
// Post issues a POST to the specified URL.
//
-// Caller should close r.Body when done reading from it.
-func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err error) {
+// Caller should close resp.Body when done reading from it.
+func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
req, err := NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", bodyType)
- if c.Jar != nil {
- for _, cookie := range c.Jar.Cookies(req.URL) {
- req.AddCookie(cookie)
- }
- }
- r, err = send(req, c.Transport)
- if err == nil && c.Jar != nil {
- c.Jar.SetCookies(req.URL, r.Cookies())
- }
- return r, err
+ return c.doFollowingRedirects(req, shouldRedirectPost)
}
-// PostForm issues a POST to the specified URL,
-// with data's keys and values urlencoded as the request body.
+// PostForm issues a POST to the specified URL, with data's keys and
+// values URL-encoded as the request body.
//
-// Caller should close r.Body when done reading from it.
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
//
// PostForm is a wrapper around DefaultClient.PostForm
-func PostForm(url string, data url.Values) (r *Response, err error) {
+func PostForm(url string, data url.Values) (resp *Response, err error) {
return DefaultClient.PostForm(url, data)
}
-// PostForm issues a POST to the specified URL,
+// PostForm issues a POST to the specified URL,
// with data's keys and values urlencoded as the request body.
//
-// Caller should close r.Body when done reading from it.
-func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) {
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) {
return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
@@ -318,7 +370,7 @@ func (c *Client) PostForm(url string, data url.Values) (r *Response, err error)
// 307 (Temporary Redirect)
//
// Head is a wrapper around DefaultClient.Head
-func Head(url string) (r *Response, err error) {
+func Head(url string) (resp *Response, err error) {
return DefaultClient.Head(url)
}
@@ -330,10 +382,10 @@ func Head(url string) (r *Response, err error) {
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
-func (c *Client) Head(url string) (r *Response, err error) {
+func (c *Client) Head(url string) (resp *Response, err error) {
req, err := NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
- return c.doFollowingRedirects(req)
+ return c.doFollowingRedirects(req, shouldRedirectGet)
}
diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go
index 9b4261b9f..88649bb16 100644
--- a/src/pkg/net/http/client_test.go
+++ b/src/pkg/net/http/client_test.go
@@ -7,7 +7,9 @@
package http_test
import (
+ "bytes"
"crypto/tls"
+ "crypto/x509"
"errors"
"fmt"
"io"
@@ -53,6 +55,7 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) {
}
func TestClient(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(robotsTxtHandler)
defer ts.Close()
@@ -70,6 +73,7 @@ func TestClient(t *testing.T) {
}
func TestClientHead(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(robotsTxtHandler)
defer ts.Close()
@@ -92,6 +96,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error)
}
func TestGetRequestFormat(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &recordingTransport{}
client := &Client{Transport: tr}
url := "http://dummy.faketld/"
@@ -108,6 +113,7 @@ func TestGetRequestFormat(t *testing.T) {
}
func TestPostRequestFormat(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &recordingTransport{}
client := &Client{Transport: tr}
@@ -134,6 +140,7 @@ func TestPostRequestFormat(t *testing.T) {
}
func TestPostFormRequestFormat(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &recordingTransport{}
client := &Client{Transport: tr}
@@ -175,6 +182,7 @@ func TestPostFormRequestFormat(t *testing.T) {
}
func TestRedirects(t *testing.T) {
+ defer checkLeakedTransports(t)
var ts *httptest.Server
ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
n, _ := strconv.Atoi(r.FormValue("n"))
@@ -218,6 +226,10 @@ func TestRedirects(t *testing.T) {
return checkErr
}}
res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ res.Body.Close()
finalUrl := res.Request.URL.String()
if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
t.Errorf("with custom client, expected error %q, got %q", e, g)
@@ -231,9 +243,63 @@ func TestRedirects(t *testing.T) {
checkErr = errors.New("no redirects allowed")
res, err = c.Get(ts.URL)
- finalUrl = res.Request.URL.String()
- if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g {
- t.Errorf("with redirects forbidden, expected error %q, got %q", e, g)
+ if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr {
+ t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err)
+ }
+ if res == nil {
+ t.Fatalf("Expected a non-nil Response on CheckRedirect failure (http://golang.org/issue/3795)")
+ }
+ res.Body.Close()
+ if res.Header.Get("Location") == "" {
+ t.Errorf("no Location header in Response")
+ }
+}
+
+func TestPostRedirects(t *testing.T) {
+ defer checkLeakedTransports(t)
+ var log struct {
+ sync.Mutex
+ bytes.Buffer
+ }
+ var ts *httptest.Server
+ ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ log.Lock()
+ fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI)
+ log.Unlock()
+ if v := r.URL.Query().Get("code"); v != "" {
+ code, _ := strconv.Atoi(v)
+ if code/100 == 3 {
+ w.Header().Set("Location", ts.URL)
+ }
+ w.WriteHeader(code)
+ }
+ }))
+ defer ts.Close()
+ tests := []struct {
+ suffix string
+ want int // response code
+ }{
+ {"/", 200},
+ {"/?code=301", 301},
+ {"/?code=302", 200},
+ {"/?code=303", 200},
+ {"/?code=404", 404},
+ }
+ for _, tt := range tests {
+ res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
+ }
+ }
+ log.Lock()
+ got := log.String()
+ log.Unlock()
+ want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 "
+ if got != want {
+ t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want)
}
}
@@ -279,6 +345,10 @@ func TestClientSendsCookieFromJar(t *testing.T) {
req, _ := NewRequest("GET", us, nil)
client.Do(req) // Note: doesn't hit network
matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ req, _ = NewRequest("POST", us, nil)
+ client.Do(req) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
}
// Just enough correctness for our redirect tests. Uses the URL.Host as the
@@ -291,6 +361,9 @@ type TestJar struct {
func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) {
j.m.Lock()
defer j.m.Unlock()
+ if j.perURL == nil {
+ j.perURL = make(map[string][]*Cookie)
+ }
j.perURL[u.Host] = cookies
}
@@ -301,6 +374,7 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie {
}
func TestRedirectCookiesOnRequest(t *testing.T) {
+ defer checkLeakedTransports(t)
var ts *httptest.Server
ts = httptest.NewServer(echoCookiesRedirectHandler)
defer ts.Close()
@@ -318,14 +392,20 @@ func TestRedirectCookiesOnRequest(t *testing.T) {
}
func TestRedirectCookiesJar(t *testing.T) {
+ defer checkLeakedTransports(t)
var ts *httptest.Server
ts = httptest.NewServer(echoCookiesRedirectHandler)
defer ts.Close()
- c := &Client{}
- c.Jar = &TestJar{perURL: make(map[string][]*Cookie)}
+ c := &Client{
+ Jar: new(TestJar),
+ }
u, _ := url.Parse(ts.URL)
c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
- resp, _ := c.Get(ts.URL)
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ resp.Body.Close()
matchReturnedCookies(t, expectedCookies, resp.Cookies())
}
@@ -348,7 +428,72 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) {
}
}
+func TestJarCalls(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ pathSuffix := r.RequestURI[1:]
+ if r.RequestURI == "/nosetcookie" {
+ return // dont set cookies for this path
+ }
+ SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix})
+ if r.RequestURI == "/" {
+ Redirect(w, r, "http://secondhost.fake/secondpath", 302)
+ }
+ }))
+ defer ts.Close()
+ jar := new(RecordingJar)
+ c := &Client{
+ Jar: jar,
+ Transport: &Transport{
+ Dial: func(_ string, _ string) (net.Conn, error) {
+ return net.Dial("tcp", ts.Listener.Addr().String())
+ },
+ },
+ }
+ _, err := c.Get("http://firsthost.fake/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.Get("http://firsthost.fake/nosetcookie")
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := jar.log.String()
+ want := `Cookies("http://firsthost.fake/")
+SetCookie("http://firsthost.fake/", [name=val])
+Cookies("http://secondhost.fake/secondpath")
+SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath])
+Cookies("http://firsthost.fake/nosetcookie")
+`
+ if got != want {
+ t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want)
+ }
+}
+
+// RecordingJar keeps a log of calls made to it, without
+// tracking any cookies.
+type RecordingJar struct {
+ mu sync.Mutex
+ log bytes.Buffer
+}
+
+func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) {
+ j.logf("SetCookie(%q, %v)\n", u, cookies)
+}
+
+func (j *RecordingJar) Cookies(u *url.URL) []*Cookie {
+ j.logf("Cookies(%q)\n", u)
+ return nil
+}
+
+func (j *RecordingJar) logf(format string, args ...interface{}) {
+ j.mu.Lock()
+ defer j.mu.Unlock()
+ fmt.Fprintf(&j.log, format, args...)
+}
+
func TestStreamingGet(t *testing.T) {
+ defer checkLeakedTransports(t)
say := make(chan string)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.(Flusher).Flush()
@@ -399,6 +544,7 @@ func (c *writeCountingConn) Write(p []byte) (int, error) {
// TestClientWrites verifies that client requests are buffered and we
// don't send a TCP packet per line of the http request + body.
func TestClientWrites(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
}))
defer ts.Close()
@@ -432,6 +578,7 @@ func TestClientWrites(t *testing.T) {
}
func TestClientInsecureTransport(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte("Hello"))
}))
@@ -446,15 +593,20 @@ func TestClientInsecureTransport(t *testing.T) {
InsecureSkipVerify: insecure,
},
}
+ defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
- _, err := c.Get(ts.URL)
+ res, err := c.Get(ts.URL)
if (err == nil) != insecure {
t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
}
+ if res != nil {
+ res.Body.Close()
+ }
}
}
func TestClientErrorWithRequestURI(t *testing.T) {
+ defer checkLeakedTransports(t)
req, _ := NewRequest("GET", "http://localhost:1234/", nil)
req.RequestURI = "/this/field/is/illegal/and/should/error/"
_, err := DefaultClient.Do(req)
@@ -465,3 +617,87 @@ func TestClientErrorWithRequestURI(t *testing.T) {
t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
}
}
+
+func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
+ certs := x509.NewCertPool()
+ for _, c := range ts.TLS.Certificates {
+ roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+ if err != nil {
+ t.Fatalf("error parsing server's root cert: %v", err)
+ }
+ for _, root := range roots {
+ certs.AddCert(root)
+ }
+ }
+ return &Transport{
+ TLSClientConfig: &tls.Config{RootCAs: certs},
+ }
+}
+
+func TestClientWithCorrectTLSServerName(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.TLS.ServerName != "127.0.0.1" {
+ t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName)
+ }
+ }))
+ defer ts.Close()
+
+ c := &Client{Transport: newTLSTransport(t, ts)}
+ if _, err := c.Get(ts.URL); err != nil {
+ t.Fatalf("expected successful TLS connection, got error: %v", err)
+ }
+}
+
+func TestClientWithIncorrectTLSServerName(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer ts.Close()
+
+ trans := newTLSTransport(t, ts)
+ trans.TLSClientConfig.ServerName = "badserver"
+ c := &Client{Transport: trans}
+ _, err := c.Get(ts.URL)
+ if err == nil {
+ t.Fatalf("expected an error")
+ }
+ if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
+ t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
+ }
+}
+
+// Verify Response.ContentLength is populated. http://golang.org/issue/4126
+func TestClientHeadContentLength(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if v := r.FormValue("cl"); v != "" {
+ w.Header().Set("Content-Length", v)
+ }
+ }))
+ defer ts.Close()
+ tests := []struct {
+ suffix string
+ want int64
+ }{
+ {"/?cl=1234", 1234},
+ {"/?cl=0", 0},
+ {"", -1},
+ }
+ for _, tt := range tests {
+ req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil)
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.ContentLength != tt.want {
+ t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
+ }
+ bs, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != 0 {
+ t.Errorf("Unexpected content: %q", bs)
+ }
+ }
+}
diff --git a/src/pkg/net/http/cookie.go b/src/pkg/net/http/cookie.go
index 2e30bbff1..155b09223 100644
--- a/src/pkg/net/http/cookie.go
+++ b/src/pkg/net/http/cookie.go
@@ -26,7 +26,7 @@ type Cookie struct {
Expires time.Time
RawExpires string
- // MaxAge=0 means no 'Max-Age' attribute specified.
+ // MaxAge=0 means no 'Max-Age' attribute specified.
// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
// MaxAge>0 means Max-Age attribute present and given in seconds
MaxAge int
@@ -258,10 +258,5 @@ func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool)
}
func isCookieNameValid(raw string) bool {
- for _, c := range raw {
- if !isToken(byte(c)) {
- return false
- }
- }
- return true
+ return strings.IndexFunc(raw, isNotToken) < 0
}
diff --git a/src/pkg/net/http/cookie_test.go b/src/pkg/net/http/cookie_test.go
index 1e9186a05..f84f73936 100644
--- a/src/pkg/net/http/cookie_test.go
+++ b/src/pkg/net/http/cookie_test.go
@@ -217,7 +217,7 @@ var readCookiesTests = []struct {
func TestReadCookies(t *testing.T) {
for i, tt := range readCookiesTests {
- for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input
+ for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input
c := readCookies(tt.Header, tt.Filter)
if !reflect.DeepEqual(c, tt.Cookies) {
t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies))
diff --git a/src/pkg/net/http/cookiejar/jar.go b/src/pkg/net/http/cookiejar/jar.go
new file mode 100644
index 000000000..5d1aeb87f
--- /dev/null
+++ b/src/pkg/net/http/cookiejar/jar.go
@@ -0,0 +1,494 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar.
+package cookiejar
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+)
+
+// PublicSuffixList provides the public suffix of a domain. For example:
+// - the public suffix of "example.com" is "com",
+// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and
+// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us".
+//
+// Implementations of PublicSuffixList must be safe for concurrent use by
+// multiple goroutines.
+//
+// An implementation that always returns "" is valid and may be useful for
+// testing but it is not secure: it means that the HTTP server for foo.com can
+// set a cookie for bar.com.
+type PublicSuffixList interface {
+ // PublicSuffix returns the public suffix of domain.
+ //
+ // TODO: specify which of the caller and callee is responsible for IP
+ // addresses, for leading and trailing dots, for case sensitivity, and
+ // for IDN/Punycode.
+ PublicSuffix(domain string) string
+
+ // String returns a description of the source of this public suffix
+ // list. The description will typically contain something like a time
+ // stamp or version number.
+ String() string
+}
+
+// Options are the options for creating a new Jar.
+type Options struct {
+ // PublicSuffixList is the public suffix list that determines whether
+ // an HTTP server can set a cookie for a domain.
+ //
+ // A nil value is valid and may be useful for testing but it is not
+ // secure: it means that the HTTP server for foo.co.uk can set a cookie
+ // for bar.co.uk.
+ PublicSuffixList PublicSuffixList
+}
+
+// Jar implements the http.CookieJar interface from the net/http package.
+type Jar struct {
+ psList PublicSuffixList
+
+ // mu locks the remaining fields.
+ mu sync.Mutex
+
+ // entries is a set of entries, keyed by their eTLD+1 and subkeyed by
+ // their name/domain/path.
+ entries map[string]map[string]entry
+
+ // nextSeqNum is the next sequence number assigned to a new cookie
+ // created SetCookies.
+ nextSeqNum uint64
+}
+
+// New returns a new cookie jar. A nil *Options is equivalent to a zero
+// Options.
+func New(o *Options) (*Jar, error) {
+ jar := &Jar{
+ entries: make(map[string]map[string]entry),
+ }
+ if o != nil {
+ jar.psList = o.PublicSuffixList
+ }
+ return jar, nil
+}
+
+// entry is the internal representation of a cookie.
+//
+// This struct type is not used outside of this package per se, but the exported
+// fields are those of RFC 6265.
+type entry struct {
+ Name string
+ Value string
+ Domain string
+ Path string
+ Secure bool
+ HttpOnly bool
+ Persistent bool
+ HostOnly bool
+ Expires time.Time
+ Creation time.Time
+ LastAccess time.Time
+
+ // seqNum is a sequence number so that Cookies returns cookies in a
+ // deterministic order, even for cookies that have equal Path length and
+ // equal Creation time. This simplifies testing.
+ seqNum uint64
+}
+
+// Id returns the domain;path;name triple of e as an id.
+func (e *entry) id() string {
+ return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
+}
+
+// shouldSend determines whether e's cookie qualifies to be included in a
+// request to host/path. It is the caller's responsibility to check if the
+// cookie is expired.
+func (e *entry) shouldSend(https bool, host, path string) bool {
+ return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure)
+}
+
+// domainMatch implements "domain-match" of RFC 6265 section 5.1.3.
+func (e *entry) domainMatch(host string) bool {
+ if e.Domain == host {
+ return true
+ }
+ return !e.HostOnly && hasDotSuffix(host, e.Domain)
+}
+
+// pathMatch implements "path-match" according to RFC 6265 section 5.1.4.
+func (e *entry) pathMatch(requestPath string) bool {
+ if requestPath == e.Path {
+ return true
+ }
+ if strings.HasPrefix(requestPath, e.Path) {
+ if e.Path[len(e.Path)-1] == '/' {
+ return true // The "/any/" matches "/any/path" case.
+ } else if requestPath[len(e.Path)] == '/' {
+ return true // The "/any" matches "/any/path" case.
+ }
+ }
+ return false
+}
+
+// hasDotSuffix returns whether s ends in "."+suffix.
+func hasDotSuffix(s, suffix string) bool {
+ return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix
+}
+
+// byPathLength is a []entry sort.Interface that sorts according to RFC 6265
+// section 5.4 point 2: by longest path and then by earliest creation time.
+type byPathLength []entry
+
+func (s byPathLength) Len() int { return len(s) }
+
+func (s byPathLength) Less(i, j int) bool {
+ if len(s[i].Path) != len(s[j].Path) {
+ return len(s[i].Path) > len(s[j].Path)
+ }
+ if !s[i].Creation.Equal(s[j].Creation) {
+ return s[i].Creation.Before(s[j].Creation)
+ }
+ return s[i].seqNum < s[j].seqNum
+}
+
+func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// Cookies implements the Cookies method of the http.CookieJar interface.
+//
+// It returns an empty slice if the URL's scheme is not HTTP or HTTPS.
+func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
+ return j.cookies(u, time.Now())
+}
+
+// cookies is like Cookies but takes the current time as a parameter.
+func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return cookies
+ }
+ host, err := canonicalHost(u.Host)
+ if err != nil {
+ return cookies
+ }
+ key := jarKey(host, j.psList)
+
+ j.mu.Lock()
+ defer j.mu.Unlock()
+
+ submap := j.entries[key]
+ if submap == nil {
+ return cookies
+ }
+
+ https := u.Scheme == "https"
+ path := u.Path
+ if path == "" {
+ path = "/"
+ }
+
+ modified := false
+ var selected []entry
+ for id, e := range submap {
+ if e.Persistent && !e.Expires.After(now) {
+ delete(submap, id)
+ modified = true
+ continue
+ }
+ if !e.shouldSend(https, host, path) {
+ continue
+ }
+ e.LastAccess = now
+ submap[id] = e
+ selected = append(selected, e)
+ modified = true
+ }
+ if modified {
+ if len(submap) == 0 {
+ delete(j.entries, key)
+ } else {
+ j.entries[key] = submap
+ }
+ }
+
+ sort.Sort(byPathLength(selected))
+ for _, e := range selected {
+ cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value})
+ }
+
+ return cookies
+}
+
+// SetCookies implements the SetCookies method of the http.CookieJar interface.
+//
+// It does nothing if the URL's scheme is not HTTP or HTTPS.
+func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
+ j.setCookies(u, cookies, time.Now())
+}
+
+// setCookies is like SetCookies but takes the current time as parameter.
+func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
+ if len(cookies) == 0 {
+ return
+ }
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return
+ }
+ host, err := canonicalHost(u.Host)
+ if err != nil {
+ return
+ }
+ key := jarKey(host, j.psList)
+ defPath := defaultPath(u.Path)
+
+ j.mu.Lock()
+ defer j.mu.Unlock()
+
+ submap := j.entries[key]
+
+ modified := false
+ for _, cookie := range cookies {
+ e, remove, err := j.newEntry(cookie, now, defPath, host)
+ if err != nil {
+ continue
+ }
+ id := e.id()
+ if remove {
+ if submap != nil {
+ if _, ok := submap[id]; ok {
+ delete(submap, id)
+ modified = true
+ }
+ }
+ continue
+ }
+ if submap == nil {
+ submap = make(map[string]entry)
+ }
+
+ if old, ok := submap[id]; ok {
+ e.Creation = old.Creation
+ e.seqNum = old.seqNum
+ } else {
+ e.Creation = now
+ e.seqNum = j.nextSeqNum
+ j.nextSeqNum++
+ }
+ e.LastAccess = now
+ submap[id] = e
+ modified = true
+ }
+
+ if modified {
+ if len(submap) == 0 {
+ delete(j.entries, key)
+ } else {
+ j.entries[key] = submap
+ }
+ }
+}
+
+// canonicalHost strips port from host if present and returns the canonicalized
+// host name.
+func canonicalHost(host string) (string, error) {
+ var err error
+ host = strings.ToLower(host)
+ if hasPort(host) {
+ host, _, err = net.SplitHostPort(host)
+ if err != nil {
+ return "", err
+ }
+ }
+ if strings.HasSuffix(host, ".") {
+ // Strip trailing dot from fully qualified domain names.
+ host = host[:len(host)-1]
+ }
+ return toASCII(host)
+}
+
+// hasPort returns whether host contains a port number. host may be a host
+// name, an IPv4 or an IPv6 address.
+func hasPort(host string) bool {
+ colons := strings.Count(host, ":")
+ if colons == 0 {
+ return false
+ }
+ if colons == 1 {
+ return true
+ }
+ return host[0] == '[' && strings.Contains(host, "]:")
+}
+
+// jarKey returns the key to use for a jar.
+func jarKey(host string, psl PublicSuffixList) string {
+ if isIP(host) {
+ return host
+ }
+
+ var i int
+ if psl == nil {
+ i = strings.LastIndex(host, ".")
+ if i == -1 {
+ return host
+ }
+ } else {
+ suffix := psl.PublicSuffix(host)
+ if suffix == host {
+ return host
+ }
+ i = len(host) - len(suffix)
+ if i <= 0 || host[i-1] != '.' {
+ // The provided public suffix list psl is broken.
+ // Storing cookies under host is a safe stopgap.
+ return host
+ }
+ }
+ prevDot := strings.LastIndex(host[:i-1], ".")
+ return host[prevDot+1:]
+}
+
+// isIP returns whether host is an IP address.
+func isIP(host string) bool {
+ return net.ParseIP(host) != nil
+}
+
+// defaultPath returns the directory part of an URL's path according to
+// RFC 6265 section 5.1.4.
+func defaultPath(path string) string {
+ if len(path) == 0 || path[0] != '/' {
+ return "/" // Path is empty or malformed.
+ }
+
+ i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1.
+ if i == 0 {
+ return "/" // Path has the form "/abc".
+ }
+ return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/".
+}
+
+// newEntry creates an entry from a http.Cookie c. now is the current time and
+// is compared to c.Expires to determine deletion of c. defPath and host are the
+// default-path and the canonical host name of the URL c was received from.
+//
+// remove is whether the jar should delete this cookie, as it has already
+// expired with respect to now. In this case, e may be incomplete, but it will
+// be valid to call e.id (which depends on e's Name, Domain and Path).
+//
+// A malformed c.Domain will result in an error.
+func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
+ e.Name = c.Name
+
+ if c.Path == "" || c.Path[0] != '/' {
+ e.Path = defPath
+ } else {
+ e.Path = c.Path
+ }
+
+ e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
+ if err != nil {
+ return e, false, err
+ }
+
+ // MaxAge takes precedence over Expires.
+ if c.MaxAge < 0 {
+ return e, true, nil
+ } else if c.MaxAge > 0 {
+ e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
+ e.Persistent = true
+ } else {
+ if c.Expires.IsZero() {
+ e.Expires = endOfTime
+ e.Persistent = false
+ } else {
+ if !c.Expires.After(now) {
+ return e, true, nil
+ }
+ e.Expires = c.Expires
+ e.Persistent = true
+ }
+ }
+
+ e.Value = c.Value
+ e.Secure = c.Secure
+ e.HttpOnly = c.HttpOnly
+
+ return e, false, nil
+}
+
+var (
+ errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
+ errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
+ errNoHostname = errors.New("cookiejar: no host name available (IP only)")
+)
+
+// endOfTime is the time when session (non-persistent) cookies expire.
+// This instant is representable in most date/time formats (not just
+// Go's time.Time) and should be far enough in the future.
+var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
+
+// domainAndType determines the cookie's domain and hostOnly attribute.
+func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
+ if domain == "" {
+ // No domain attribute in the SetCookie header indicates a
+ // host cookie.
+ return host, true, nil
+ }
+
+ if isIP(host) {
+ // According to RFC 6265 domain-matching includes not being
+ // an IP address.
+ // TODO: This might be relaxed as in common browsers.
+ return "", false, errNoHostname
+ }
+
+ // From here on: If the cookie is valid, it is a domain cookie (with
+ // the one exception of a public suffix below).
+ // See RFC 6265 section 5.2.3.
+ if domain[0] == '.' {
+ domain = domain[1:]
+ }
+
+ if len(domain) == 0 || domain[0] == '.' {
+ // Received either "Domain=." or "Domain=..some.thing",
+ // both are illegal.
+ return "", false, errMalformedDomain
+ }
+ domain = strings.ToLower(domain)
+
+ if domain[len(domain)-1] == '.' {
+ // We received stuff like "Domain=www.example.com.".
+ // Browsers do handle such stuff (actually differently) but
+ // RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in
+ // requiring a reject. 4.1.2.3 is not normative, but
+ // "Domain Matching" (5.1.3) and "Canonicalized Host Names"
+ // (5.1.2) are.
+ return "", false, errMalformedDomain
+ }
+
+ // See RFC 6265 section 5.3 #5.
+ if j.psList != nil {
+ if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
+ if host == domain {
+ // This is the one exception in which a cookie
+ // with a domain attribute is a host cookie.
+ return host, true, nil
+ }
+ return "", false, errIllegalDomain
+ }
+ }
+
+ // The domain must domain-match host: www.mycompany.com cannot
+ // set cookies for .ourcompetitors.com.
+ if host != domain && !hasDotSuffix(host, domain) {
+ return "", false, errIllegalDomain
+ }
+
+ return domain, false, nil
+}
diff --git a/src/pkg/net/http/cookiejar/jar_test.go b/src/pkg/net/http/cookiejar/jar_test.go
new file mode 100644
index 000000000..3aa601586
--- /dev/null
+++ b/src/pkg/net/http/cookiejar/jar_test.go
@@ -0,0 +1,1267 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+import (
+ "fmt"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+)
+
+// tNow is the synthetic current time used as now during testing.
+var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC)
+
+// testPSL implements PublicSuffixList with just two rules: "co.uk"
+// and the default rule "*".
+type testPSL struct{}
+
+func (testPSL) String() string {
+ return "testPSL"
+}
+func (testPSL) PublicSuffix(d string) string {
+ if d == "co.uk" || strings.HasSuffix(d, ".co.uk") {
+ return "co.uk"
+ }
+ return d[strings.LastIndex(d, ".")+1:]
+}
+
+// newTestJar creates an empty Jar with testPSL as the public suffix list.
+func newTestJar() *Jar {
+ jar, err := New(&Options{PublicSuffixList: testPSL{}})
+ if err != nil {
+ panic(err)
+ }
+ return jar
+}
+
+var hasDotSuffixTests = [...]struct {
+ s, suffix string
+}{
+ {"", ""},
+ {"", "."},
+ {"", "x"},
+ {".", ""},
+ {".", "."},
+ {".", ".."},
+ {".", "x"},
+ {".", "x."},
+ {".", ".x"},
+ {".", ".x."},
+ {"x", ""},
+ {"x", "."},
+ {"x", ".."},
+ {"x", "x"},
+ {"x", "x."},
+ {"x", ".x"},
+ {"x", ".x."},
+ {".x", ""},
+ {".x", "."},
+ {".x", ".."},
+ {".x", "x"},
+ {".x", "x."},
+ {".x", ".x"},
+ {".x", ".x."},
+ {"x.", ""},
+ {"x.", "."},
+ {"x.", ".."},
+ {"x.", "x"},
+ {"x.", "x."},
+ {"x.", ".x"},
+ {"x.", ".x."},
+ {"com", ""},
+ {"com", "m"},
+ {"com", "om"},
+ {"com", "com"},
+ {"com", ".com"},
+ {"com", "x.com"},
+ {"com", "xcom"},
+ {"com", "xorg"},
+ {"com", "org"},
+ {"com", "rg"},
+ {"foo.com", ""},
+ {"foo.com", "m"},
+ {"foo.com", "om"},
+ {"foo.com", "com"},
+ {"foo.com", ".com"},
+ {"foo.com", "o.com"},
+ {"foo.com", "oo.com"},
+ {"foo.com", "foo.com"},
+ {"foo.com", ".foo.com"},
+ {"foo.com", "x.foo.com"},
+ {"foo.com", "xfoo.com"},
+ {"foo.com", "xfoo.org"},
+ {"foo.com", "foo.org"},
+ {"foo.com", "oo.org"},
+ {"foo.com", "o.org"},
+ {"foo.com", ".org"},
+ {"foo.com", "org"},
+ {"foo.com", "rg"},
+}
+
+func TestHasDotSuffix(t *testing.T) {
+ for _, tc := range hasDotSuffixTests {
+ got := hasDotSuffix(tc.s, tc.suffix)
+ want := strings.HasSuffix(tc.s, "."+tc.suffix)
+ if got != want {
+ t.Errorf("s=%q, suffix=%q: got %v, want %v", tc.s, tc.suffix, got, want)
+ }
+ }
+}
+
+var canonicalHostTests = map[string]string{
+ "www.example.com": "www.example.com",
+ "WWW.EXAMPLE.COM": "www.example.com",
+ "wWw.eXAmple.CoM": "www.example.com",
+ "www.example.com:80": "www.example.com",
+ "192.168.0.10": "192.168.0.10",
+ "192.168.0.5:8080": "192.168.0.5",
+ "2001:4860:0:2001::68": "2001:4860:0:2001::68",
+ "[2001:4860:0:::68]:8080": "2001:4860:0:::68",
+ "www.bücher.de": "www.xn--bcher-kva.de",
+ "www.example.com.": "www.example.com",
+ "[bad.unmatched.bracket:": "error",
+}
+
+func TestCanonicalHost(t *testing.T) {
+ for h, want := range canonicalHostTests {
+ got, err := canonicalHost(h)
+ if want == "error" {
+ if err == nil {
+ t.Errorf("%q: got nil error, want non-nil", h)
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("%q: %v", h, err)
+ continue
+ }
+ if got != want {
+ t.Errorf("%q: got %q, want %q", h, got, want)
+ continue
+ }
+ }
+}
+
+var hasPortTests = map[string]bool{
+ "www.example.com": false,
+ "www.example.com:80": true,
+ "127.0.0.1": false,
+ "127.0.0.1:8080": true,
+ "2001:4860:0:2001::68": false,
+ "[2001::0:::68]:80": true,
+}
+
+func TestHasPort(t *testing.T) {
+ for host, want := range hasPortTests {
+ if got := hasPort(host); got != want {
+ t.Errorf("%q: got %t, want %t", host, got, want)
+ }
+ }
+}
+
+var jarKeyTests = map[string]string{
+ "foo.www.example.com": "example.com",
+ "www.example.com": "example.com",
+ "example.com": "example.com",
+ "com": "com",
+ "foo.www.bbc.co.uk": "bbc.co.uk",
+ "www.bbc.co.uk": "bbc.co.uk",
+ "bbc.co.uk": "bbc.co.uk",
+ "co.uk": "co.uk",
+ "uk": "uk",
+ "192.168.0.5": "192.168.0.5",
+}
+
+func TestJarKey(t *testing.T) {
+ for host, want := range jarKeyTests {
+ if got := jarKey(host, testPSL{}); got != want {
+ t.Errorf("%q: got %q, want %q", host, got, want)
+ }
+ }
+}
+
+var jarKeyNilPSLTests = map[string]string{
+ "foo.www.example.com": "example.com",
+ "www.example.com": "example.com",
+ "example.com": "example.com",
+ "com": "com",
+ "foo.www.bbc.co.uk": "co.uk",
+ "www.bbc.co.uk": "co.uk",
+ "bbc.co.uk": "co.uk",
+ "co.uk": "co.uk",
+ "uk": "uk",
+ "192.168.0.5": "192.168.0.5",
+}
+
+func TestJarKeyNilPSL(t *testing.T) {
+ for host, want := range jarKeyNilPSLTests {
+ if got := jarKey(host, nil); got != want {
+ t.Errorf("%q: got %q, want %q", host, got, want)
+ }
+ }
+}
+
+var isIPTests = map[string]bool{
+ "127.0.0.1": true,
+ "1.2.3.4": true,
+ "2001:4860:0:2001::68": true,
+ "example.com": false,
+ "1.1.1.300": false,
+ "www.foo.bar.net": false,
+ "123.foo.bar.net": false,
+}
+
+func TestIsIP(t *testing.T) {
+ for host, want := range isIPTests {
+ if got := isIP(host); got != want {
+ t.Errorf("%q: got %t, want %t", host, got, want)
+ }
+ }
+}
+
+var defaultPathTests = map[string]string{
+ "/": "/",
+ "/abc": "/",
+ "/abc/": "/abc",
+ "/abc/xyz": "/abc",
+ "/abc/xyz/": "/abc/xyz",
+ "/a/b/c.html": "/a/b",
+ "": "/",
+ "strange": "/",
+ "//": "/",
+ "/a//b": "/a/",
+ "/a/./b": "/a/.",
+ "/a/../b": "/a/..",
+}
+
+func TestDefaultPath(t *testing.T) {
+ for path, want := range defaultPathTests {
+ if got := defaultPath(path); got != want {
+ t.Errorf("%q: got %q, want %q", path, got, want)
+ }
+ }
+}
+
+var domainAndTypeTests = [...]struct {
+ host string // host Set-Cookie header was received from
+ domain string // domain attribute in Set-Cookie header
+ wantDomain string // expected domain of cookie
+ wantHostOnly bool // expected host-cookie flag
+ wantErr error // expected error
+}{
+ {"www.example.com", "", "www.example.com", true, nil},
+ {"127.0.0.1", "", "127.0.0.1", true, nil},
+ {"2001:4860:0:2001::68", "", "2001:4860:0:2001::68", true, nil},
+ {"www.example.com", "example.com", "example.com", false, nil},
+ {"www.example.com", ".example.com", "example.com", false, nil},
+ {"www.example.com", "www.example.com", "www.example.com", false, nil},
+ {"www.example.com", ".www.example.com", "www.example.com", false, nil},
+ {"foo.sso.example.com", "sso.example.com", "sso.example.com", false, nil},
+ {"bar.co.uk", "bar.co.uk", "bar.co.uk", false, nil},
+ {"foo.bar.co.uk", ".bar.co.uk", "bar.co.uk", false, nil},
+ {"127.0.0.1", "127.0.0.1", "", false, errNoHostname},
+ {"2001:4860:0:2001::68", "2001:4860:0:2001::68", "2001:4860:0:2001::68", false, errNoHostname},
+ {"www.example.com", ".", "", false, errMalformedDomain},
+ {"www.example.com", "..", "", false, errMalformedDomain},
+ {"www.example.com", "other.com", "", false, errIllegalDomain},
+ {"www.example.com", "com", "", false, errIllegalDomain},
+ {"www.example.com", ".com", "", false, errIllegalDomain},
+ {"foo.bar.co.uk", ".co.uk", "", false, errIllegalDomain},
+ {"127.www.0.0.1", "127.0.0.1", "", false, errIllegalDomain},
+ {"com", "", "com", true, nil},
+ {"com", "com", "com", true, nil},
+ {"com", ".com", "com", true, nil},
+ {"co.uk", "", "co.uk", true, nil},
+ {"co.uk", "co.uk", "co.uk", true, nil},
+ {"co.uk", ".co.uk", "co.uk", true, nil},
+}
+
+func TestDomainAndType(t *testing.T) {
+ jar := newTestJar()
+ for _, tc := range domainAndTypeTests {
+ domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain)
+ if err != tc.wantErr {
+ t.Errorf("%q/%q: got %q error, want %q",
+ tc.host, tc.domain, err, tc.wantErr)
+ continue
+ }
+ if err != nil {
+ continue
+ }
+ if domain != tc.wantDomain || hostOnly != tc.wantHostOnly {
+ t.Errorf("%q/%q: got %q/%t want %q/%t",
+ tc.host, tc.domain, domain, hostOnly,
+ tc.wantDomain, tc.wantHostOnly)
+ }
+ }
+}
+
+// expiresIn creates an expires attribute delta seconds from tNow.
+func expiresIn(delta int) string {
+ t := tNow.Add(time.Duration(delta) * time.Second)
+ return "expires=" + t.Format(time.RFC1123)
+}
+
+// mustParseURL parses s to an URL and panics on error.
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil || u.Scheme == "" || u.Host == "" {
+ panic(fmt.Sprintf("Unable to parse URL %s.", s))
+ }
+ return u
+}
+
+// jarTest encapsulates the following actions on a jar:
+// 1. Perform SetCookies with fromURL and the cookies from setCookies.
+// (Done at time tNow + 0 ms.)
+// 2. Check that the entries in the jar matches content.
+// (Done at time tNow + 1001 ms.)
+// 3. For each query in tests: Check that Cookies with toURL yields the
+// cookies in want.
+// (Query n done at tNow + (n+2)*1001 ms.)
+type jarTest struct {
+ description string // The description of what this test is supposed to test
+ fromURL string // The full URL of the request from which Set-Cookie headers where received
+ setCookies []string // All the cookies received from fromURL
+ content string // The whole (non-expired) content of the jar
+ queries []query // Queries to test the Jar.Cookies method
+}
+
+// query contains one test of the cookies returned from Jar.Cookies.
+type query struct {
+ toURL string // the URL in the Cookies call
+ want string // the expected list of cookies (order matters)
+}
+
+// run runs the jarTest.
+func (test jarTest) run(t *testing.T, jar *Jar) {
+ now := tNow
+
+ // Populate jar with cookies.
+ setCookies := make([]*http.Cookie, len(test.setCookies))
+ for i, cs := range test.setCookies {
+ cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies()
+ if len(cookies) != 1 {
+ panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies))
+ }
+ setCookies[i] = cookies[0]
+ }
+ jar.setCookies(mustParseURL(test.fromURL), setCookies, now)
+ now = now.Add(1001 * time.Millisecond)
+
+ // Serialize non-expired entries in the form "name1=val1 name2=val2".
+ var cs []string
+ for _, submap := range jar.entries {
+ for _, cookie := range submap {
+ if !cookie.Expires.After(now) {
+ continue
+ }
+ cs = append(cs, cookie.Name+"="+cookie.Value)
+ }
+ }
+ sort.Strings(cs)
+ got := strings.Join(cs, " ")
+
+ // Make sure jar content matches our expectations.
+ if got != test.content {
+ t.Errorf("Test %q Content\ngot %q\nwant %q",
+ test.description, got, test.content)
+ }
+
+ // Test different calls to Cookies.
+ for i, query := range test.queries {
+ now = now.Add(1001 * time.Millisecond)
+ var s []string
+ for _, c := range jar.cookies(mustParseURL(query.toURL), now) {
+ s = append(s, c.Name+"="+c.Value)
+ }
+ if got := strings.Join(s, " "); got != query.want {
+ t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want)
+ }
+ }
+}
+
+// basicsTests contains fundamental tests. Each jarTest has to be performed on
+// a fresh, empty Jar.
+var basicsTests = [...]jarTest{
+ {
+ "Retrieval of a plain host cookie.",
+ "http://www.host.test/",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", "A=a"},
+ {"http://www.host.test/", "A=a"},
+ {"http://www.host.test/some/path", "A=a"},
+ {"https://www.host.test", "A=a"},
+ {"https://www.host.test/", "A=a"},
+ {"https://www.host.test/some/path", "A=a"},
+ {"ftp://www.host.test", ""},
+ {"ftp://www.host.test/", ""},
+ {"ftp://www.host.test/some/path", ""},
+ {"http://www.other.org", ""},
+ {"http://sibling.host.test", ""},
+ {"http://deep.www.host.test", ""},
+ },
+ },
+ {
+ "Secure cookies are not returned to http.",
+ "http://www.host.test/",
+ []string{"A=a; secure"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some/path", ""},
+ {"https://www.host.test", "A=a"},
+ {"https://www.host.test/", "A=a"},
+ {"https://www.host.test/some/path", "A=a"},
+ },
+ },
+ {
+ "Explicit path.",
+ "http://www.host.test/",
+ []string{"A=a; path=/some/path"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #1: path is a directory.",
+ "http://www.host.test/some/path/",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #2: path is not a directory.",
+ "http://www.host.test/some/path/index.html",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #3: no path in URL at all.",
+ "http://www.host.test",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", "A=a"},
+ {"http://www.host.test/", "A=a"},
+ {"http://www.host.test/some/path", "A=a"},
+ },
+ },
+ {
+ "Cookies are sorted by path length.",
+ "http://www.host.test/",
+ []string{
+ "A=a; path=/foo/bar",
+ "B=b; path=/foo/bar/baz/qux",
+ "C=c; path=/foo/bar/baz",
+ "D=d; path=/foo"},
+ "A=a B=b C=c D=d",
+ []query{
+ {"http://www.host.test/foo/bar/baz/qux", "B=b C=c A=a D=d"},
+ {"http://www.host.test/foo/bar/baz/", "C=c A=a D=d"},
+ {"http://www.host.test/foo/bar", "A=a D=d"},
+ },
+ },
+ {
+ "Creation time determines sorting on same length paths.",
+ "http://www.host.test/",
+ []string{
+ "A=a; path=/foo/bar",
+ "X=x; path=/foo/bar",
+ "Y=y; path=/foo/bar/baz/qux",
+ "B=b; path=/foo/bar/baz/qux",
+ "C=c; path=/foo/bar/baz",
+ "W=w; path=/foo/bar/baz",
+ "Z=z; path=/foo",
+ "D=d; path=/foo"},
+ "A=a B=b C=c D=d W=w X=x Y=y Z=z",
+ []query{
+ {"http://www.host.test/foo/bar/baz/qux", "Y=y B=b C=c W=w A=a X=x Z=z D=d"},
+ {"http://www.host.test/foo/bar/baz/", "C=c W=w A=a X=x Z=z D=d"},
+ {"http://www.host.test/foo/bar", "A=a X=x Z=z D=d"},
+ },
+ },
+ {
+ "Sorting of same-name cookies.",
+ "http://www.host.test/",
+ []string{
+ "A=1; path=/",
+ "A=2; path=/path",
+ "A=3; path=/quux",
+ "A=4; path=/path/foo",
+ "A=5; domain=.host.test; path=/path",
+ "A=6; domain=.host.test; path=/quux",
+ "A=7; domain=.host.test; path=/path/foo",
+ },
+ "A=1 A=2 A=3 A=4 A=5 A=6 A=7",
+ []query{
+ {"http://www.host.test/path", "A=2 A=5 A=1"},
+ {"http://www.host.test/path/foo", "A=4 A=7 A=2 A=5 A=1"},
+ },
+ },
+ {
+ "Disallow domain cookie on public suffix.",
+ "http://www.bbc.co.uk",
+ []string{
+ "a=1",
+ "b=2; domain=co.uk",
+ },
+ "a=1",
+ []query{{"http://www.bbc.co.uk", "a=1"}},
+ },
+ {
+ "Host cookie on IP.",
+ "http://192.168.0.10",
+ []string{"a=1"},
+ "a=1",
+ []query{{"http://192.168.0.10", "a=1"}},
+ },
+ {
+ "Port is ignored #1.",
+ "http://www.host.test/",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://www.host.test:8080/", "a=1"},
+ },
+ },
+ {
+ "Port is ignored #2.",
+ "http://www.host.test:8080/",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://www.host.test:8080/", "a=1"},
+ {"http://www.host.test:1234/", "a=1"},
+ },
+ },
+}
+
+func TestBasics(t *testing.T) {
+ for _, test := range basicsTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
+
+// updateAndDeleteTests contains jarTests which must be performed on the same
+// Jar.
+var updateAndDeleteTests = [...]jarTest{
+ {
+ "Set initial cookies.",
+ "http://www.host.test",
+ []string{
+ "a=1",
+ "b=2; secure",
+ "c=3; httponly",
+ "d=4; secure; httponly"},
+ "a=1 b=2 c=3 d=4",
+ []query{
+ {"http://www.host.test", "a=1 c=3"},
+ {"https://www.host.test", "a=1 b=2 c=3 d=4"},
+ },
+ },
+ {
+ "Update value via http.",
+ "http://www.host.test",
+ []string{
+ "a=w",
+ "b=x; secure",
+ "c=y; httponly",
+ "d=z; secure; httponly"},
+ "a=w b=x c=y d=z",
+ []query{
+ {"http://www.host.test", "a=w c=y"},
+ {"https://www.host.test", "a=w b=x c=y d=z"},
+ },
+ },
+ {
+ "Clear Secure flag from a http.",
+ "http://www.host.test/",
+ []string{
+ "b=xx",
+ "d=zz; httponly"},
+ "a=w b=xx c=y d=zz",
+ []query{{"http://www.host.test", "a=w b=xx c=y d=zz"}},
+ },
+ {
+ "Delete all.",
+ "http://www.host.test/",
+ []string{
+ "a=1; max-Age=-1", // delete via MaxAge
+ "b=2; " + expiresIn(-10), // delete via Expires
+ "c=2; max-age=-1; " + expiresIn(-10), // delete via both
+ "d=4; max-age=-1; " + expiresIn(10)}, // MaxAge takes precedence
+ "",
+ []query{{"http://www.host.test", ""}},
+ },
+ {
+ "Refill #1.",
+ "http://www.host.test",
+ []string{
+ "A=1",
+ "A=2; path=/foo",
+ "A=3; domain=.host.test",
+ "A=4; path=/foo; domain=.host.test"},
+ "A=1 A=2 A=3 A=4",
+ []query{{"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}},
+ },
+ {
+ "Refill #2.",
+ "http://www.google.com",
+ []string{
+ "A=6",
+ "A=7; path=/foo",
+ "A=8; domain=.google.com",
+ "A=9; path=/foo; domain=.google.com"},
+ "A=1 A=2 A=3 A=4 A=6 A=7 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"},
+ {"http://www.google.com/foo", "A=7 A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A7.",
+ "http://www.google.com",
+ []string{"A=; path=/foo; max-age=-1"},
+ "A=1 A=2 A=3 A=4 A=6 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A4.",
+ "http://www.host.test",
+ []string{"A=; path=/foo; domain=host.test; max-age=-1"},
+ "A=1 A=2 A=3 A=6 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A6.",
+ "http://www.google.com",
+ []string{"A=; max-age=-1"},
+ "A=1 A=2 A=3 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "Delete A3.",
+ "http://www.host.test",
+ []string{"A=; domain=host.test; max-age=-1"},
+ "A=1 A=2 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "No cross-domain delete.",
+ "http://www.host.test",
+ []string{
+ "A=; domain=google.com; max-age=-1",
+ "A=; path=/foo; domain=google.com; max-age=-1"},
+ "A=1 A=2 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "Delete A8 and A9.",
+ "http://www.google.com",
+ []string{
+ "A=; domain=google.com; max-age=-1",
+ "A=; path=/foo; domain=google.com; max-age=-1"},
+ "A=1 A=2",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", ""},
+ },
+ },
+}
+
+func TestUpdateAndDelete(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range updateAndDeleteTests {
+ test.run(t, jar)
+ }
+}
+
+func TestExpiration(t *testing.T) {
+ jar := newTestJar()
+ jarTest{
+ "Expiration.",
+ "http://www.host.test",
+ []string{
+ "a=1",
+ "b=2; max-age=3",
+ "c=3; " + expiresIn(3),
+ "d=4; max-age=5",
+ "e=5; " + expiresIn(5),
+ "f=6; max-age=100",
+ },
+ "a=1 b=2 c=3 d=4 e=5 f=6", // executed at t0 + 1001 ms
+ []query{
+ {"http://www.host.test", "a=1 b=2 c=3 d=4 e=5 f=6"}, // t0 + 2002 ms
+ {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 3003 ms
+ {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 4004 ms
+ {"http://www.host.test", "a=1 f=6"}, // t0 + 5005 ms
+ {"http://www.host.test", "a=1 f=6"}, // t0 + 6006 ms
+ },
+ }.run(t, jar)
+}
+
+//
+// Tests derived from Chromium's cookie_store_unittest.h.
+//
+
+// See http://src.chromium.org/viewvc/chrome/trunk/src/net/cookies/cookie_store_unittest.h?revision=159685&content-type=text/plain
+// Some of the original tests are in a bad condition (e.g.
+// DomainWithTrailingDotTest) or are not RFC 6265 conforming (e.g.
+// TestNonDottedAndTLD #1 and #6) and have not been ported.
+
+// chromiumBasicsTests contains fundamental tests. Each jarTest has to be
+// performed on a fresh, empty Jar.
+var chromiumBasicsTests = [...]jarTest{
+ {
+ "DomainWithTrailingDotTest.",
+ "http://www.google.com/",
+ []string{
+ "a=1; domain=.www.google.com.",
+ "b=2; domain=.www.google.com.."},
+ "",
+ []query{
+ {"http://www.google.com", ""},
+ },
+ },
+ {
+ "ValidSubdomainTest #1.",
+ "http://a.b.c.d.com",
+ []string{
+ "a=1; domain=.a.b.c.d.com",
+ "b=2; domain=.b.c.d.com",
+ "c=3; domain=.c.d.com",
+ "d=4; domain=.d.com"},
+ "a=1 b=2 c=3 d=4",
+ []query{
+ {"http://a.b.c.d.com", "a=1 b=2 c=3 d=4"},
+ {"http://b.c.d.com", "b=2 c=3 d=4"},
+ {"http://c.d.com", "c=3 d=4"},
+ {"http://d.com", "d=4"},
+ },
+ },
+ {
+ "ValidSubdomainTest #2.",
+ "http://a.b.c.d.com",
+ []string{
+ "a=1; domain=.a.b.c.d.com",
+ "b=2; domain=.b.c.d.com",
+ "c=3; domain=.c.d.com",
+ "d=4; domain=.d.com",
+ "X=bcd; domain=.b.c.d.com",
+ "X=cd; domain=.c.d.com"},
+ "X=bcd X=cd a=1 b=2 c=3 d=4",
+ []query{
+ {"http://b.c.d.com", "b=2 c=3 d=4 X=bcd X=cd"},
+ {"http://c.d.com", "c=3 d=4 X=cd"},
+ },
+ },
+ {
+ "InvalidDomainTest #1.",
+ "http://foo.bar.com",
+ []string{
+ "a=1; domain=.yo.foo.bar.com",
+ "b=2; domain=.foo.com",
+ "c=3; domain=.bar.foo.com",
+ "d=4; domain=.foo.bar.com.net",
+ "e=5; domain=ar.com",
+ "f=6; domain=.",
+ "g=7; domain=/",
+ "h=8; domain=http://foo.bar.com",
+ "i=9; domain=..foo.bar.com",
+ "j=10; domain=..bar.com",
+ "k=11; domain=.foo.bar.com?blah",
+ "l=12; domain=.foo.bar.com/blah",
+ "m=12; domain=.foo.bar.com:80",
+ "n=14; domain=.foo.bar.com:",
+ "o=15; domain=.foo.bar.com#sup",
+ },
+ "", // Jar is empty.
+ []query{{"http://foo.bar.com", ""}},
+ },
+ {
+ "InvalidDomainTest #2.",
+ "http://foo.com.com",
+ []string{"a=1; domain=.foo.com.com.com"},
+ "",
+ []query{{"http://foo.bar.com", ""}},
+ },
+ {
+ "DomainWithoutLeadingDotTest #1.",
+ "http://manage.hosted.filefront.com",
+ []string{"a=1; domain=filefront.com"},
+ "a=1",
+ []query{{"http://www.filefront.com", "a=1"}},
+ },
+ {
+ "DomainWithoutLeadingDotTest #2.",
+ "http://www.google.com",
+ []string{"a=1; domain=www.google.com"},
+ "a=1",
+ []query{
+ {"http://www.google.com", "a=1"},
+ {"http://sub.www.google.com", "a=1"},
+ {"http://something-else.com", ""},
+ },
+ },
+ {
+ "CaseInsensitiveDomainTest.",
+ "http://www.google.com",
+ []string{
+ "a=1; domain=.GOOGLE.COM",
+ "b=2; domain=.www.gOOgLE.coM"},
+ "a=1 b=2",
+ []query{{"http://www.google.com", "a=1 b=2"}},
+ },
+ {
+ "TestIpAddress #1.",
+ "http://1.2.3.4/foo",
+ []string{"a=1; path=/"},
+ "a=1",
+ []query{{"http://1.2.3.4/foo", "a=1"}},
+ },
+ {
+ "TestIpAddress #2.",
+ "http://1.2.3.4/foo",
+ []string{
+ "a=1; domain=.1.2.3.4",
+ "b=2; domain=.3.4"},
+ "",
+ []query{{"http://1.2.3.4/foo", ""}},
+ },
+ {
+ "TestIpAddress #3.",
+ "http://1.2.3.4/foo",
+ []string{"a=1; domain=1.2.3.4"},
+ "",
+ []query{{"http://1.2.3.4/foo", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #2.",
+ "http://com./index.html",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://com./index.html", "a=1"},
+ {"http://no-cookies.com./index.html", ""},
+ },
+ },
+ {
+ "TestNonDottedAndTLD #3.",
+ "http://a.b",
+ []string{
+ "a=1; domain=.b",
+ "b=2; domain=b"},
+ "",
+ []query{{"http://bar.foo", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #4.",
+ "http://google.com",
+ []string{
+ "a=1; domain=.com",
+ "b=2; domain=com"},
+ "",
+ []query{{"http://google.com", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #5.",
+ "http://google.co.uk",
+ []string{
+ "a=1; domain=.co.uk",
+ "b=2; domain=.uk"},
+ "",
+ []query{
+ {"http://google.co.uk", ""},
+ {"http://else.co.com", ""},
+ {"http://else.uk", ""},
+ },
+ },
+ {
+ "TestHostEndsWithDot.",
+ "http://www.google.com",
+ []string{
+ "a=1",
+ "b=2; domain=.www.google.com."},
+ "a=1",
+ []query{{"http://www.google.com", "a=1"}},
+ },
+ {
+ "PathTest",
+ "http://www.google.izzle",
+ []string{"a=1; path=/wee"},
+ "a=1",
+ []query{
+ {"http://www.google.izzle/wee", "a=1"},
+ {"http://www.google.izzle/wee/", "a=1"},
+ {"http://www.google.izzle/wee/war", "a=1"},
+ {"http://www.google.izzle/wee/war/more/more", "a=1"},
+ {"http://www.google.izzle/weehee", ""},
+ {"http://www.google.izzle/", ""},
+ },
+ },
+}
+
+func TestChromiumBasics(t *testing.T) {
+ for _, test := range chromiumBasicsTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
+
+// chromiumDomainTests contains jarTests which must be executed all on the
+// same Jar.
+var chromiumDomainTests = [...]jarTest{
+ {
+ "Fill #1.",
+ "http://www.google.izzle",
+ []string{"A=B"},
+ "A=B",
+ []query{{"http://www.google.izzle", "A=B"}},
+ },
+ {
+ "Fill #2.",
+ "http://www.google.izzle",
+ []string{"C=D; domain=.google.izzle"},
+ "A=B C=D",
+ []query{{"http://www.google.izzle", "A=B C=D"}},
+ },
+ {
+ "Verify A is a host cookie and not accessible from subdomain.",
+ "http://unused.nil",
+ []string{},
+ "A=B C=D",
+ []query{{"http://foo.www.google.izzle", "C=D"}},
+ },
+ {
+ "Verify domain cookies are found on proper domain.",
+ "http://www.google.izzle",
+ []string{"E=F; domain=.www.google.izzle"},
+ "A=B C=D E=F",
+ []query{{"http://www.google.izzle", "A=B C=D E=F"}},
+ },
+ {
+ "Leading dots in domain attributes are optional.",
+ "http://www.google.izzle",
+ []string{"G=H; domain=www.google.izzle"},
+ "A=B C=D E=F G=H",
+ []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}},
+ },
+ {
+ "Verify domain enforcement works #1.",
+ "http://www.google.izzle",
+ []string{"K=L; domain=.bar.www.google.izzle"},
+ "A=B C=D E=F G=H",
+ []query{{"http://bar.www.google.izzle", "C=D E=F G=H"}},
+ },
+ {
+ "Verify domain enforcement works #2.",
+ "http://unused.nil",
+ []string{},
+ "A=B C=D E=F G=H",
+ []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}},
+ },
+}
+
+func TestChromiumDomain(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range chromiumDomainTests {
+ test.run(t, jar)
+ }
+
+}
+
+// chromiumDeletionTests must be performed all on the same Jar.
+var chromiumDeletionTests = [...]jarTest{
+ {
+ "Create session cookie a1.",
+ "http://www.google.com",
+ []string{"a=1"},
+ "a=1",
+ []query{{"http://www.google.com", "a=1"}},
+ },
+ {
+ "Delete sc a1 via MaxAge.",
+ "http://www.google.com",
+ []string{"a=1; max-age=-1"},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create session cookie b2.",
+ "http://www.google.com",
+ []string{"b=2"},
+ "b=2",
+ []query{{"http://www.google.com", "b=2"}},
+ },
+ {
+ "Delete sc b2 via Expires.",
+ "http://www.google.com",
+ []string{"b=2; " + expiresIn(-10)},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create persistent cookie c3.",
+ "http://www.google.com",
+ []string{"c=3; max-age=3600"},
+ "c=3",
+ []query{{"http://www.google.com", "c=3"}},
+ },
+ {
+ "Delete pc c3 via MaxAge.",
+ "http://www.google.com",
+ []string{"c=3; max-age=-1"},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create persistent cookie d4.",
+ "http://www.google.com",
+ []string{"d=4; max-age=3600"},
+ "d=4",
+ []query{{"http://www.google.com", "d=4"}},
+ },
+ {
+ "Delete pc d4 via Expires.",
+ "http://www.google.com",
+ []string{"d=4; " + expiresIn(-10)},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+}
+
+func TestChromiumDeletion(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range chromiumDeletionTests {
+ test.run(t, jar)
+ }
+}
+
+// domainHandlingTests tests and documents the rules for domain handling.
+// Each test must be performed on an empty new Jar.
+var domainHandlingTests = [...]jarTest{
+ {
+ "Host cookie",
+ "http://www.host.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", ""},
+ {"http://bar.host.test", ""},
+ {"http://foo.www.host.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie #1",
+ "http://www.host.test",
+ []string{"a=1; domain=host.test"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", "a=1"},
+ {"http://bar.host.test", "a=1"},
+ {"http://foo.www.host.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie #2",
+ "http://www.host.test",
+ []string{"a=1; domain=.host.test"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", "a=1"},
+ {"http://bar.host.test", "a=1"},
+ {"http://foo.www.host.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on IDNA domain #1",
+ "http://www.bücher.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", ""},
+ {"http://xn--bcher-kva.test", ""},
+ {"http://bar.bücher.test", ""},
+ {"http://bar.xn--bcher-kva.test", ""},
+ {"http://foo.www.bücher.test", ""},
+ {"http://foo.www.xn--bcher-kva.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on IDNA domain #2",
+ "http://www.xn--bcher-kva.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", ""},
+ {"http://xn--bcher-kva.test", ""},
+ {"http://bar.bücher.test", ""},
+ {"http://bar.xn--bcher-kva.test", ""},
+ {"http://foo.www.bücher.test", ""},
+ {"http://foo.www.xn--bcher-kva.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie on IDNA domain #1",
+ "http://www.bücher.test",
+ []string{"a=1; domain=xn--bcher-kva.test"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", "a=1"},
+ {"http://xn--bcher-kva.test", "a=1"},
+ {"http://bar.bücher.test", "a=1"},
+ {"http://bar.xn--bcher-kva.test", "a=1"},
+ {"http://foo.www.bücher.test", "a=1"},
+ {"http://foo.www.xn--bcher-kva.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie on IDNA domain #2",
+ "http://www.xn--bcher-kva.test",
+ []string{"a=1; domain=xn--bcher-kva.test"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", "a=1"},
+ {"http://xn--bcher-kva.test", "a=1"},
+ {"http://bar.bücher.test", "a=1"},
+ {"http://bar.xn--bcher-kva.test", "a=1"},
+ {"http://foo.www.bücher.test", "a=1"},
+ {"http://foo.www.xn--bcher-kva.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on TLD.",
+ "http://com",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://com", "a=1"},
+ {"http://any.com", ""},
+ {"http://any.test", ""},
+ },
+ },
+ {
+ "Domain cookie on TLD becomes a host cookie.",
+ "http://com",
+ []string{"a=1; domain=com"},
+ "a=1",
+ []query{
+ {"http://com", "a=1"},
+ {"http://any.com", ""},
+ {"http://any.test", ""},
+ },
+ },
+ {
+ "Host cookie on public suffix.",
+ "http://co.uk",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://co.uk", "a=1"},
+ {"http://uk", ""},
+ {"http://some.co.uk", ""},
+ {"http://foo.some.co.uk", ""},
+ {"http://any.uk", ""},
+ },
+ },
+ {
+ "Domain cookie on public suffix is ignored.",
+ "http://some.co.uk",
+ []string{"a=1; domain=co.uk"},
+ "",
+ []query{
+ {"http://co.uk", ""},
+ {"http://uk", ""},
+ {"http://some.co.uk", ""},
+ {"http://foo.some.co.uk", ""},
+ {"http://any.uk", ""},
+ },
+ },
+}
+
+func TestDomainHandling(t *testing.T) {
+ for _, test := range domainHandlingTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
diff --git a/src/pkg/net/http/cookiejar/punycode.go b/src/pkg/net/http/cookiejar/punycode.go
new file mode 100644
index 000000000..ea7ceb5ef
--- /dev/null
+++ b/src/pkg/net/http/cookiejar/punycode.go
@@ -0,0 +1,159 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+// This file implements the Punycode algorithm from RFC 3492.
+
+import (
+ "fmt"
+ "strings"
+ "unicode/utf8"
+)
+
+// These parameter values are specified in section 5.
+//
+// All computation is done with int32s, so that overflow behavior is identical
+// regardless of whether int is 32-bit or 64-bit.
+const (
+ base int32 = 36
+ damp int32 = 700
+ initialBias int32 = 72
+ initialN int32 = 128
+ skew int32 = 38
+ tmax int32 = 26
+ tmin int32 = 1
+)
+
+// encode encodes a string as specified in section 6.3 and prepends prefix to
+// the result.
+//
+// The "while h < length(input)" line in the specification becomes "for
+// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes.
+func encode(prefix, s string) (string, error) {
+ output := make([]byte, len(prefix), len(prefix)+1+2*len(s))
+ copy(output, prefix)
+ delta, n, bias := int32(0), initialN, initialBias
+ b, remaining := int32(0), int32(0)
+ for _, r := range s {
+ if r < 0x80 {
+ b++
+ output = append(output, byte(r))
+ } else {
+ remaining++
+ }
+ }
+ h := b
+ if b > 0 {
+ output = append(output, '-')
+ }
+ for remaining != 0 {
+ m := int32(0x7fffffff)
+ for _, r := range s {
+ if m > r && r >= n {
+ m = r
+ }
+ }
+ delta += (m - n) * (h + 1)
+ if delta < 0 {
+ return "", fmt.Errorf("cookiejar: invalid label %q", s)
+ }
+ n = m
+ for _, r := range s {
+ if r < n {
+ delta++
+ if delta < 0 {
+ return "", fmt.Errorf("cookiejar: invalid label %q", s)
+ }
+ continue
+ }
+ if r > n {
+ continue
+ }
+ q := delta
+ for k := base; ; k += base {
+ t := k - bias
+ if t < tmin {
+ t = tmin
+ } else if t > tmax {
+ t = tmax
+ }
+ if q < t {
+ break
+ }
+ output = append(output, encodeDigit(t+(q-t)%(base-t)))
+ q = (q - t) / (base - t)
+ }
+ output = append(output, encodeDigit(q))
+ bias = adapt(delta, h+1, h == b)
+ delta = 0
+ h++
+ remaining--
+ }
+ delta++
+ n++
+ }
+ return string(output), nil
+}
+
+func encodeDigit(digit int32) byte {
+ switch {
+ case 0 <= digit && digit < 26:
+ return byte(digit + 'a')
+ case 26 <= digit && digit < 36:
+ return byte(digit + ('0' - 26))
+ }
+ panic("cookiejar: internal error in punycode encoding")
+}
+
+// adapt is the bias adaptation function specified in section 6.1.
+func adapt(delta, numPoints int32, firstTime bool) int32 {
+ if firstTime {
+ delta /= damp
+ } else {
+ delta /= 2
+ }
+ delta += delta / numPoints
+ k := int32(0)
+ for delta > ((base-tmin)*tmax)/2 {
+ delta /= base - tmin
+ k += base
+ }
+ return k + (base-tmin+1)*delta/(delta+skew)
+}
+
+// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and
+// friends) and not Punycode (RFC 3492) per se.
+
+// acePrefix is the ASCII Compatible Encoding prefix.
+const acePrefix = "xn--"
+
+// toASCII converts a domain or domain label to its ASCII form. For example,
+// toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
+// toASCII("golang") is "golang".
+func toASCII(s string) (string, error) {
+ if ascii(s) {
+ return s, nil
+ }
+ labels := strings.Split(s, ".")
+ for i, label := range labels {
+ if !ascii(label) {
+ a, err := encode(acePrefix, label)
+ if err != nil {
+ return "", err
+ }
+ labels[i] = a
+ }
+ }
+ return strings.Join(labels, "."), nil
+}
+
+func ascii(s string) bool {
+ for i := 0; i < len(s); i++ {
+ if s[i] >= utf8.RuneSelf {
+ return false
+ }
+ }
+ return true
+}
diff --git a/src/pkg/net/http/cookiejar/punycode_test.go b/src/pkg/net/http/cookiejar/punycode_test.go
new file mode 100644
index 000000000..0301de14e
--- /dev/null
+++ b/src/pkg/net/http/cookiejar/punycode_test.go
@@ -0,0 +1,161 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+import (
+ "testing"
+)
+
+var punycodeTestCases = [...]struct {
+ s, encoded string
+}{
+ {"", ""},
+ {"-", "--"},
+ {"-a", "-a-"},
+ {"-a-", "-a--"},
+ {"a", "a-"},
+ {"a-", "a--"},
+ {"a-b", "a-b-"},
+ {"books", "books-"},
+ {"bücher", "bcher-kva"},
+ {"Hello世界", "Hello-ck1hg65u"},
+ {"ü", "tda"},
+ {"üý", "tdac"},
+
+ // The test cases below come from RFC 3492 section 7.1 with Errata 3026.
+ {
+ // (A) Arabic (Egyptian).
+ "\u0644\u064A\u0647\u0645\u0627\u0628\u062A\u0643\u0644" +
+ "\u0645\u0648\u0634\u0639\u0631\u0628\u064A\u061F",
+ "egbpdaj6bu4bxfgehfvwxn",
+ },
+ {
+ // (B) Chinese (simplified).
+ "\u4ED6\u4EEC\u4E3A\u4EC0\u4E48\u4E0D\u8BF4\u4E2D\u6587",
+ "ihqwcrb4cv8a8dqg056pqjye",
+ },
+ {
+ // (C) Chinese (traditional).
+ "\u4ED6\u5011\u7232\u4EC0\u9EBD\u4E0D\u8AAA\u4E2D\u6587",
+ "ihqwctvzc91f659drss3x8bo0yb",
+ },
+ {
+ // (D) Czech.
+ "\u0050\u0072\u006F\u010D\u0070\u0072\u006F\u0073\u0074" +
+ "\u011B\u006E\u0065\u006D\u006C\u0075\u0076\u00ED\u010D" +
+ "\u0065\u0073\u006B\u0079",
+ "Proprostnemluvesky-uyb24dma41a",
+ },
+ {
+ // (E) Hebrew.
+ "\u05DC\u05DE\u05D4\u05D4\u05DD\u05E4\u05E9\u05D5\u05D8" +
+ "\u05DC\u05D0\u05DE\u05D3\u05D1\u05E8\u05D9\u05DD\u05E2" +
+ "\u05D1\u05E8\u05D9\u05EA",
+ "4dbcagdahymbxekheh6e0a7fei0b",
+ },
+ {
+ // (F) Hindi (Devanagari).
+ "\u092F\u0939\u0932\u094B\u0917\u0939\u093F\u0928\u094D" +
+ "\u0926\u0940\u0915\u094D\u092F\u094B\u0902\u0928\u0939" +
+ "\u0940\u0902\u092C\u094B\u0932\u0938\u0915\u0924\u0947" +
+ "\u0939\u0948\u0902",
+ "i1baa7eci9glrd9b2ae1bj0hfcgg6iyaf8o0a1dig0cd",
+ },
+ {
+ // (G) Japanese (kanji and hiragana).
+ "\u306A\u305C\u307F\u3093\u306A\u65E5\u672C\u8A9E\u3092" +
+ "\u8A71\u3057\u3066\u304F\u308C\u306A\u3044\u306E\u304B",
+ "n8jok5ay5dzabd5bym9f0cm5685rrjetr6pdxa",
+ },
+ {
+ // (H) Korean (Hangul syllables).
+ "\uC138\uACC4\uC758\uBAA8\uB4E0\uC0AC\uB78C\uB4E4\uC774" +
+ "\uD55C\uAD6D\uC5B4\uB97C\uC774\uD574\uD55C\uB2E4\uBA74" +
+ "\uC5BC\uB9C8\uB098\uC88B\uC744\uAE4C",
+ "989aomsvi5e83db1d2a355cv1e0vak1dwrv93d5xbh15a0dt30a5j" +
+ "psd879ccm6fea98c",
+ },
+ {
+ // (I) Russian (Cyrillic).
+ "\u043F\u043E\u0447\u0435\u043C\u0443\u0436\u0435\u043E" +
+ "\u043D\u0438\u043D\u0435\u0433\u043E\u0432\u043E\u0440" +
+ "\u044F\u0442\u043F\u043E\u0440\u0443\u0441\u0441\u043A" +
+ "\u0438",
+ "b1abfaaepdrnnbgefbadotcwatmq2g4l",
+ },
+ {
+ // (J) Spanish.
+ "\u0050\u006F\u0072\u0071\u0075\u00E9\u006E\u006F\u0070" +
+ "\u0075\u0065\u0064\u0065\u006E\u0073\u0069\u006D\u0070" +
+ "\u006C\u0065\u006D\u0065\u006E\u0074\u0065\u0068\u0061" +
+ "\u0062\u006C\u0061\u0072\u0065\u006E\u0045\u0073\u0070" +
+ "\u0061\u00F1\u006F\u006C",
+ "PorqunopuedensimplementehablarenEspaol-fmd56a",
+ },
+ {
+ // (K) Vietnamese.
+ "\u0054\u1EA1\u0069\u0073\u0061\u006F\u0068\u1ECD\u006B" +
+ "\u0068\u00F4\u006E\u0067\u0074\u0068\u1EC3\u0063\u0068" +
+ "\u1EC9\u006E\u00F3\u0069\u0074\u0069\u1EBF\u006E\u0067" +
+ "\u0056\u0069\u1EC7\u0074",
+ "TisaohkhngthchnitingVit-kjcr8268qyxafd2f1b9g",
+ },
+ {
+ // (L) 3<nen>B<gumi><kinpachi><sensei>.
+ "\u0033\u5E74\u0042\u7D44\u91D1\u516B\u5148\u751F",
+ "3B-ww4c5e180e575a65lsy2b",
+ },
+ {
+ // (M) <amuro><namie>-with-SUPER-MONKEYS.
+ "\u5B89\u5BA4\u5948\u7F8E\u6075\u002D\u0077\u0069\u0074" +
+ "\u0068\u002D\u0053\u0055\u0050\u0045\u0052\u002D\u004D" +
+ "\u004F\u004E\u004B\u0045\u0059\u0053",
+ "-with-SUPER-MONKEYS-pc58ag80a8qai00g7n9n",
+ },
+ {
+ // (N) Hello-Another-Way-<sorezore><no><basho>.
+ "\u0048\u0065\u006C\u006C\u006F\u002D\u0041\u006E\u006F" +
+ "\u0074\u0068\u0065\u0072\u002D\u0057\u0061\u0079\u002D" +
+ "\u305D\u308C\u305E\u308C\u306E\u5834\u6240",
+ "Hello-Another-Way--fc4qua05auwb3674vfr0b",
+ },
+ {
+ // (O) <hitotsu><yane><no><shita>2.
+ "\u3072\u3068\u3064\u5C4B\u6839\u306E\u4E0B\u0032",
+ "2-u9tlzr9756bt3uc0v",
+ },
+ {
+ // (P) Maji<de>Koi<suru>5<byou><mae>
+ "\u004D\u0061\u006A\u0069\u3067\u004B\u006F\u0069\u3059" +
+ "\u308B\u0035\u79D2\u524D",
+ "MajiKoi5-783gue6qz075azm5e",
+ },
+ {
+ // (Q) <pafii>de<runba>
+ "\u30D1\u30D5\u30A3\u30FC\u0064\u0065\u30EB\u30F3\u30D0",
+ "de-jg4avhby1noc0d",
+ },
+ {
+ // (R) <sono><supiido><de>
+ "\u305D\u306E\u30B9\u30D4\u30FC\u30C9\u3067",
+ "d9juau41awczczp",
+ },
+ {
+ // (S) -> $1.00 <-
+ "\u002D\u003E\u0020\u0024\u0031\u002E\u0030\u0030\u0020" +
+ "\u003C\u002D",
+ "-> $1.00 <--",
+ },
+}
+
+func TestPunycode(t *testing.T) {
+ for _, tc := range punycodeTestCases {
+ if got, err := encode("", tc.s); err != nil {
+ t.Errorf(`encode("", %q): %v`, tc.s, err)
+ } else if got != tc.encoded {
+ t.Errorf(`encode("", %q): got %q, want %q`, tc.s, got, tc.encoded)
+ }
+ }
+}
diff --git a/src/pkg/net/http/example_test.go b/src/pkg/net/http/example_test.go
index ec814407d..22073eaf7 100644
--- a/src/pkg/net/http/example_test.go
+++ b/src/pkg/net/http/example_test.go
@@ -43,10 +43,10 @@ func ExampleGet() {
log.Fatal(err)
}
robots, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
if err != nil {
log.Fatal(err)
}
- res.Body.Close()
fmt.Printf("%s", robots)
}
diff --git a/src/pkg/net/http/export_test.go b/src/pkg/net/http/export_test.go
index 13640ca85..a7bca20a0 100644
--- a/src/pkg/net/http/export_test.go
+++ b/src/pkg/net/http/export_test.go
@@ -7,12 +7,25 @@
package http
-import "time"
+import (
+ "net"
+ "time"
+)
+
+func NewLoggingConn(baseName string, c net.Conn) net.Conn {
+ return newLoggingConn(baseName, c)
+}
+
+func (t *Transport) NumPendingRequestsForTesting() int {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ return len(t.reqConn)
+}
func (t *Transport) IdleConnKeysForTesting() (keys []string) {
keys = make([]string, 0)
- t.lk.Lock()
- defer t.lk.Unlock()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
if t.idleConn == nil {
return
}
@@ -23,8 +36,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) {
}
func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
- t.lk.Lock()
- defer t.lk.Unlock()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
if t.idleConn == nil {
return 0
}
diff --git a/src/pkg/net/http/filetransport_test.go b/src/pkg/net/http/filetransport_test.go
index 039926b53..6f1a537e2 100644
--- a/src/pkg/net/http/filetransport_test.go
+++ b/src/pkg/net/http/filetransport_test.go
@@ -2,11 +2,10 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package http_test
+package http
import (
"io/ioutil"
- "net/http"
"os"
"path/filepath"
"testing"
@@ -32,9 +31,9 @@ func TestFileTransport(t *testing.T) {
defer os.Remove(dname)
defer os.Remove(fname)
- tr := &http.Transport{}
- tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(dname)))
- c := &http.Client{Transport: tr}
+ tr := &Transport{}
+ tr.RegisterProtocol("file", NewFileTransport(Dir(dname)))
+ c := &Client{Transport: tr}
fooURLs := []string{"file:///foo.txt", "file://../foo.txt"}
for _, urlstr := range fooURLs {
@@ -62,4 +61,5 @@ func TestFileTransport(t *testing.T) {
if res.StatusCode != 404 {
t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode)
}
+ res.Body.Close()
}
diff --git a/src/pkg/net/http/fs.go b/src/pkg/net/http/fs.go
index f35dd32c3..b6bea0dfa 100644
--- a/src/pkg/net/http/fs.go
+++ b/src/pkg/net/http/fs.go
@@ -11,6 +11,8 @@ import (
"fmt"
"io"
"mime"
+ "mime/multipart"
+ "net/textproto"
"os"
"path"
"path/filepath"
@@ -26,7 +28,8 @@ import (
type Dir string
func (d Dir) Open(name string) (File, error) {
- if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 {
+ if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 ||
+ strings.Contains(name, "\x00") {
return nil, errors.New("http: invalid character in file path")
}
dir := string(d)
@@ -97,6 +100,9 @@ func dirList(w ResponseWriter, f File) {
// The content's Seek method must work: ServeContent uses
// a seek to the end of the content to determine its size.
//
+// If the caller has set w's ETag header, ServeContent uses it to
+// handle requests using If-Range and If-None-Match.
+//
// Note that *os.File implements the io.ReadSeeker interface.
func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) {
size, err := content.Seek(0, os.SEEK_END)
@@ -119,12 +125,17 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
if checkLastModified(w, r, modtime) {
return
}
+ rangeReq, done := checkETag(w, r)
+ if done {
+ return
+ }
code := StatusOK
// If Content-Type isn't set, use the file's extension to find it.
- if w.Header().Get("Content-Type") == "" {
- ctype := mime.TypeByExtension(filepath.Ext(name))
+ ctype := w.Header().Get("Content-Type")
+ if ctype == "" {
+ ctype = mime.TypeByExtension(filepath.Ext(name))
if ctype == "" {
// read a chunk to decide between utf-8 text and binary
var buf [1024]byte
@@ -141,18 +152,34 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
}
// handle Content-Range header.
- // TODO(adg): handle multiple ranges
sendSize := size
+ var sendContent io.Reader = content
if size >= 0 {
- ranges, err := parseRange(r.Header.Get("Range"), size)
- if err == nil && len(ranges) > 1 {
- err = errors.New("multiple ranges not supported")
- }
+ ranges, err := parseRange(rangeReq, size)
if err != nil {
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
return
}
- if len(ranges) == 1 {
+ if sumRangesSize(ranges) >= size {
+ // The total number of bytes in all the ranges
+ // is larger than the size of the file by
+ // itself, so this is probably an attack, or a
+ // dumb client. Ignore the range request.
+ ranges = nil
+ }
+ switch {
+ case len(ranges) == 1:
+ // RFC 2616, Section 14.16:
+ // "When an HTTP message includes the content of a single
+ // range (for example, a response to a request for a
+ // single range, or to a request for a set of ranges
+ // that overlap without any holes), this content is
+ // transmitted with a Content-Range header, and a
+ // Content-Length header showing the number of bytes
+ // actually transferred.
+ // ...
+ // A response to a request for a single range MUST NOT
+ // be sent using the multipart/byteranges media type."
ra := ranges[0]
if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
@@ -160,7 +187,41 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
}
sendSize = ra.length
code = StatusPartialContent
- w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, size))
+ w.Header().Set("Content-Range", ra.contentRange(size))
+ case len(ranges) > 1:
+ for _, ra := range ranges {
+ if ra.start > size {
+ Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ return
+ }
+ }
+ sendSize = rangesMIMESize(ranges, ctype, size)
+ code = StatusPartialContent
+
+ pr, pw := io.Pipe()
+ mw := multipart.NewWriter(pw)
+ w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary())
+ sendContent = pr
+ defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish.
+ go func() {
+ for _, ra := range ranges {
+ part, err := mw.CreatePart(ra.mimeHeader(ctype, size))
+ if err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ if _, err := io.CopyN(part, content, ra.length); err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ }
+ mw.Close()
+ pw.Close()
+ }()
}
w.Header().Set("Accept-Ranges", "bytes")
@@ -172,11 +233,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
w.WriteHeader(code)
if r.Method != "HEAD" {
- if sendSize == -1 {
- io.Copy(w, content)
- } else {
- io.CopyN(w, content, sendSize)
- }
+ io.CopyN(w, sendContent, sendSize)
}
}
@@ -190,6 +247,9 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool {
// The Date-Modified header truncates sub-second precision, so
// use mtime < t+1s instead of mtime <= t to check for unmodified.
if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) {
+ h := w.Header()
+ delete(h, "Content-Type")
+ delete(h, "Content-Length")
w.WriteHeader(StatusNotModified)
return true
}
@@ -197,6 +257,58 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool {
return false
}
+// checkETag implements If-None-Match and If-Range checks.
+// The ETag must have been previously set in the ResponseWriter's headers.
+//
+// The return value is the effective request "Range" header to use and
+// whether this request is now considered done.
+func checkETag(w ResponseWriter, r *Request) (rangeReq string, done bool) {
+ etag := w.Header().get("Etag")
+ rangeReq = r.Header.get("Range")
+
+ // Invalidate the range request if the entity doesn't match the one
+ // the client was expecting.
+ // "If-Range: version" means "ignore the Range: header unless version matches the
+ // current file."
+ // We only support ETag versions.
+ // The caller must have set the ETag on the response already.
+ if ir := r.Header.get("If-Range"); ir != "" && ir != etag {
+ // TODO(bradfitz): handle If-Range requests with Last-Modified
+ // times instead of ETags? I'd rather not, at least for
+ // now. That seems like a bug/compromise in the RFC 2616, and
+ // I've never heard of anybody caring about that (yet).
+ rangeReq = ""
+ }
+
+ if inm := r.Header.get("If-None-Match"); inm != "" {
+ // Must know ETag.
+ if etag == "" {
+ return rangeReq, false
+ }
+
+ // TODO(bradfitz): non-GET/HEAD requests require more work:
+ // sending a different status code on matches, and
+ // also can't use weak cache validators (those with a "W/
+ // prefix). But most users of ServeContent will be using
+ // it on GET or HEAD, so only support those for now.
+ if r.Method != "GET" && r.Method != "HEAD" {
+ return rangeReq, false
+ }
+
+ // TODO(bradfitz): deal with comma-separated or multiple-valued
+ // list of If-None-match values. For now just handle the common
+ // case of a single item.
+ if inm == etag || inm == "*" {
+ h := w.Header()
+ delete(h, "Content-Type")
+ delete(h, "Content-Length")
+ w.WriteHeader(StatusNotModified)
+ return "", true
+ }
+ }
+ return rangeReq, false
+}
+
// name is '/'-separated, not filepath.Separator.
func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) {
const indexPage = "/index.html"
@@ -243,9 +355,6 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
// use contents of index.html for directory, if present
if d.IsDir() {
- if checkLastModified(w, r, d.ModTime()) {
- return
- }
index := name + indexPage
ff, err := fs.Open(index)
if err == nil {
@@ -259,11 +368,16 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
}
}
+ // Still a directory? (we didn't find an index.html file)
if d.IsDir() {
+ if checkLastModified(w, r, d.ModTime()) {
+ return
+ }
dirList(w, f)
return
}
+ // serverContent will check modification time
serveContent(w, r, d.Name(), d.ModTime(), d.Size(), f)
}
@@ -312,6 +426,17 @@ type httpRange struct {
start, length int64
}
+func (r httpRange) contentRange(size int64) string {
+ return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size)
+}
+
+func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader {
+ return textproto.MIMEHeader{
+ "Content-Range": {r.contentRange(size)},
+ "Content-Type": {contentType},
+ }
+}
+
// parseRange parses a Range header string as per RFC 2616.
func parseRange(s string, size int64) ([]httpRange, error) {
if s == "" {
@@ -323,11 +448,15 @@ func parseRange(s string, size int64) ([]httpRange, error) {
}
var ranges []httpRange
for _, ra := range strings.Split(s[len(b):], ",") {
+ ra = strings.TrimSpace(ra)
+ if ra == "" {
+ continue
+ }
i := strings.Index(ra, "-")
if i < 0 {
return nil, errors.New("invalid range")
}
- start, end := ra[:i], ra[i+1:]
+ start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:])
var r httpRange
if start == "" {
// If no start is specified, end specifies the
@@ -365,3 +494,32 @@ func parseRange(s string, size int64) ([]httpRange, error) {
}
return ranges, nil
}
+
+// countingWriter counts how many bytes have been written to it.
+type countingWriter int64
+
+func (w *countingWriter) Write(p []byte) (n int, err error) {
+ *w += countingWriter(len(p))
+ return len(p), nil
+}
+
+// rangesMIMESize returns the nunber of bytes it takes to encode the
+// provided ranges as a multipart response.
+func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) {
+ var w countingWriter
+ mw := multipart.NewWriter(&w)
+ for _, ra := range ranges {
+ mw.CreatePart(ra.mimeHeader(contentType, contentSize))
+ encSize += ra.length
+ }
+ mw.Close()
+ encSize += int64(w)
+ return
+}
+
+func sumRangesSize(ranges []httpRange) (size int64) {
+ for _, ra := range ranges {
+ size += ra.length
+ }
+ return
+}
diff --git a/src/pkg/net/http/fs_test.go b/src/pkg/net/http/fs_test.go
index 5aa93ce58..0dd6d0df9 100644
--- a/src/pkg/net/http/fs_test.go
+++ b/src/pkg/net/http/fs_test.go
@@ -10,12 +10,15 @@ import (
"fmt"
"io"
"io/ioutil"
+ "mime"
+ "mime/multipart"
"net"
. "net/http"
"net/http/httptest"
"net/url"
"os"
"os/exec"
+ "path"
"path/filepath"
"regexp"
"runtime"
@@ -25,24 +28,33 @@ import (
)
const (
- testFile = "testdata/file"
- testFileLength = 11
+ testFile = "testdata/file"
+ testFileLen = 11
)
+type wantRange struct {
+ start, end int64 // range [start,end)
+}
+
var ServeFileRangeTests = []struct {
- start, end int
- r string
- code int
+ r string
+ code int
+ ranges []wantRange
}{
- {0, testFileLength, "", StatusOK},
- {0, 5, "0-4", StatusPartialContent},
- {2, testFileLength, "2-", StatusPartialContent},
- {testFileLength - 5, testFileLength, "-5", StatusPartialContent},
- {3, 8, "3-7", StatusPartialContent},
- {0, 0, "20-", StatusRequestedRangeNotSatisfiable},
+ {r: "", code: StatusOK},
+ {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}},
+ {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}},
+ {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}},
+ {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}},
+ {r: "bytes=20-", code: StatusRequestedRangeNotSatisfiable},
+ {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}},
+ {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}},
+ {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}},
+ {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request
}
func TestServeFile(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "testdata/file")
}))
@@ -65,33 +77,86 @@ func TestServeFile(t *testing.T) {
// straight GET
_, body := getBody(t, "straight get", req)
- if !equal(body, file) {
+ if !bytes.Equal(body, file) {
t.Fatalf("body mismatch: got %q, want %q", body, file)
}
// Range tests
- for i, rt := range ServeFileRangeTests {
- req.Header.Set("Range", "bytes="+rt.r)
- if rt.r == "" {
- req.Header["Range"] = nil
+Cases:
+ for _, rt := range ServeFileRangeTests {
+ if rt.r != "" {
+ req.Header.Set("Range", rt.r)
}
- r, body := getBody(t, fmt.Sprintf("test %d", i), req)
- if r.StatusCode != rt.code {
- t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, r.StatusCode, rt.code)
+ resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req)
+ if resp.StatusCode != rt.code {
+ t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code)
}
if rt.code == StatusRequestedRangeNotSatisfiable {
continue
}
- h := fmt.Sprintf("bytes %d-%d/%d", rt.start, rt.end-1, testFileLength)
- if rt.r == "" {
- h = ""
+ wantContentRange := ""
+ if len(rt.ranges) == 1 {
+ rng := rt.ranges[0]
+ wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+ }
+ cr := resp.Header.Get("Content-Range")
+ if cr != wantContentRange {
+ t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange)
}
- cr := r.Header.Get("Content-Range")
- if cr != h {
- t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h)
+ ct := resp.Header.Get("Content-Type")
+ if len(rt.ranges) == 1 {
+ rng := rt.ranges[0]
+ wantBody := file[rng.start:rng.end]
+ if !bytes.Equal(body, wantBody) {
+ t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+ }
+ if strings.HasPrefix(ct, "multipart/byteranges") {
+ t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r, ct)
+ }
}
- if !equal(body, file[rt.start:rt.end]) {
- t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end])
+ if len(rt.ranges) > 1 {
+ typ, params, err := mime.ParseMediaType(ct)
+ if err != nil {
+ t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err)
+ continue
+ }
+ if typ != "multipart/byteranges" {
+ t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r, typ)
+ continue
+ }
+ if params["boundary"] == "" {
+ t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct)
+ continue
+ }
+ if g, w := resp.ContentLength, int64(len(body)); g != w {
+ t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w)
+ continue
+ }
+ mr := multipart.NewReader(bytes.NewReader(body), params["boundary"])
+ for ri, rng := range rt.ranges {
+ part, err := mr.NextPart()
+ if err != nil {
+ t.Errorf("range=%q, reading part index %d: %v", rt.r, ri, err)
+ continue Cases
+ }
+ wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+ if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w {
+ t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w)
+ }
+ body, err := ioutil.ReadAll(part)
+ if err != nil {
+ t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err)
+ continue Cases
+ }
+ wantBody := file[rng.start:rng.end]
+ if !bytes.Equal(body, wantBody) {
+ t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+ }
+ }
+ _, err = mr.NextPart()
+ if err != io.EOF {
+ t.Errorf("range=%q; expected final error io.EOF; got %v", rt.r, err)
+ }
}
}
}
@@ -105,6 +170,7 @@ var fsRedirectTestData = []struct {
}
func TestFSRedirect(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir("."))))
defer ts.Close()
@@ -129,6 +195,7 @@ func (fs *testFileSystem) Open(name string) (File, error) {
}
func TestFileServerCleans(t *testing.T) {
+ defer checkLeakedTransports(t)
ch := make(chan string, 1)
fs := FileServer(&testFileSystem{func(name string) (File, error) {
ch <- name
@@ -160,6 +227,7 @@ func mustRemoveAll(dir string) {
}
func TestFileServerImplicitLeadingSlash(t *testing.T) {
+ defer checkLeakedTransports(t)
tempDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatalf("TempDir: %v", err)
@@ -193,8 +261,7 @@ func TestFileServerImplicitLeadingSlash(t *testing.T) {
func TestDirJoin(t *testing.T) {
wfi, err := os.Stat("/etc/hosts")
if err != nil {
- t.Logf("skipping test; no /etc/hosts file")
- return
+ t.Skip("skipping test; no /etc/hosts file")
}
test := func(d Dir, name string) {
f, err := d.Open(name)
@@ -239,6 +306,7 @@ func TestEmptyDirOpenCWD(t *testing.T) {
}
func TestServeFileContentType(t *testing.T) {
+ defer checkLeakedTransports(t)
const ctype = "icecream/chocolate"
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.FormValue("override") == "1" {
@@ -255,12 +323,14 @@ func TestServeFileContentType(t *testing.T) {
if h := resp.Header.Get("Content-Type"); h != want {
t.Errorf("Content-Type mismatch: got %q, want %q", h, want)
}
+ resp.Body.Close()
}
get("0", "text/plain; charset=utf-8")
get("1", ctype)
}
func TestServeFileMimeType(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "testdata/style.css")
}))
@@ -269,6 +339,7 @@ func TestServeFileMimeType(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ resp.Body.Close()
want := "text/css; charset=utf-8"
if h := resp.Header.Get("Content-Type"); h != want {
t.Errorf("Content-Type mismatch: got %q, want %q", h, want)
@@ -276,6 +347,7 @@ func TestServeFileMimeType(t *testing.T) {
}
func TestServeFileFromCWD(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "fs_test.go")
}))
@@ -284,12 +356,14 @@ func TestServeFileFromCWD(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ r.Body.Close()
if r.StatusCode != 200 {
t.Fatalf("expected 200 OK, got %s", r.Status)
}
}
func TestServeFileWithContentEncoding(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Encoding", "foo")
ServeFile(w, r, "testdata/file")
@@ -299,12 +373,14 @@ func TestServeFileWithContentEncoding(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ resp.Body.Close()
if g, e := resp.ContentLength, int64(-1); g != e {
t.Errorf("Content-Length mismatch: got %d, want %d", g, e)
}
}
func TestServeIndexHtml(t *testing.T) {
+ defer checkLeakedTransports(t)
const want = "index.html says hello\n"
ts := httptest.NewServer(FileServer(Dir(".")))
defer ts.Close()
@@ -325,64 +401,289 @@ func TestServeIndexHtml(t *testing.T) {
}
}
-func TestServeContent(t *testing.T) {
- type req struct {
- name string
- modtime time.Time
- content io.ReadSeeker
- }
- ch := make(chan req, 1)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- p := <-ch
- ServeContent(w, r, p.name, p.modtime, p.content)
- }))
+func TestFileServerZeroByte(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(FileServer(Dir(".")))
defer ts.Close()
- css, err := os.Open("testdata/style.css")
+ res, err := Get(ts.URL + "/..\x00")
if err != nil {
t.Fatal(err)
}
- defer css.Close()
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if res.StatusCode == 200 {
+ t.Errorf("got status 200; want an error. Body is:\n%s", string(b))
+ }
+}
+
+type fakeFileInfo struct {
+ dir bool
+ basename string
+ modtime time.Time
+ ents []*fakeFileInfo
+ contents string
+}
+
+func (f *fakeFileInfo) Name() string { return f.basename }
+func (f *fakeFileInfo) Sys() interface{} { return nil }
+func (f *fakeFileInfo) ModTime() time.Time { return f.modtime }
+func (f *fakeFileInfo) IsDir() bool { return f.dir }
+func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) }
+func (f *fakeFileInfo) Mode() os.FileMode {
+ if f.dir {
+ return 0755 | os.ModeDir
+ }
+ return 0644
+}
+
+type fakeFile struct {
+ io.ReadSeeker
+ fi *fakeFileInfo
+ path string // as opened
+}
+
+func (f *fakeFile) Close() error { return nil }
+func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil }
+func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
+ if !f.fi.dir {
+ return nil, os.ErrInvalid
+ }
+ var fis []os.FileInfo
+ for _, fi := range f.fi.ents {
+ fis = append(fis, fi)
+ }
+ return fis, nil
+}
+
+type fakeFS map[string]*fakeFileInfo
+
+func (fs fakeFS) Open(name string) (File, error) {
+ name = path.Clean(name)
+ f, ok := fs[name]
+ if !ok {
+ println("fake filesystem didn't find file", name)
+ return nil, os.ErrNotExist
+ }
+ return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil
+}
+
+func TestDirectoryIfNotModified(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const indexContents = "I am a fake index.html file"
+ fileMod := time.Unix(1000000000, 0).UTC()
+ fileModStr := fileMod.Format(TimeFormat)
+ dirMod := time.Unix(123, 0).UTC()
+ indexFile := &fakeFileInfo{
+ basename: "index.html",
+ modtime: fileMod,
+ contents: indexContents,
+ }
+ fs := fakeFS{
+ "/": &fakeFileInfo{
+ dir: true,
+ modtime: dirMod,
+ ents: []*fakeFileInfo{indexFile},
+ },
+ "/index.html": indexFile,
+ }
+
+ ts := httptest.NewServer(FileServer(fs))
+ defer ts.Close()
- ch <- req{"style.css", time.Time{}, css}
res, err := Get(ts.URL)
if err != nil {
t.Fatal(err)
}
- if g, e := res.Header.Get("Content-Type"), "text/css; charset=utf-8"; g != e {
- t.Errorf("style.css: content type = %q, want %q", g, e)
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
}
- if g := res.Header.Get("Last-Modified"); g != "" {
- t.Errorf("want empty Last-Modified; got %q", g)
+ if string(b) != indexContents {
+ t.Fatalf("Got body %q; want %q", b, indexContents)
}
+ res.Body.Close()
+
+ lastMod := res.Header.Get("Last-Modified")
+ if lastMod != fileModStr {
+ t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr)
+ }
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.Header.Set("If-Modified-Since", lastMod)
- fi, err := css.Stat()
+ res, err = DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
- ch <- req{"style.html", fi.ModTime(), css}
- res, err = Get(ts.URL)
+ if res.StatusCode != 304 {
+ t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode)
+ }
+ res.Body.Close()
+
+ // Advance the index.html file's modtime, but not the directory's.
+ indexFile.modtime = indexFile.modtime.Add(1 * time.Hour)
+
+ res, err = DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
- if g, e := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != e {
- t.Errorf("style.html: content type = %q, want %q", g, e)
+ if res.StatusCode != 200 {
+ t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res)
}
- if g := res.Header.Get("Last-Modified"); g == "" {
- t.Errorf("want non-empty last-modified")
+ res.Body.Close()
+}
+
+func mustStat(t *testing.T, fileName string) os.FileInfo {
+ fi, err := os.Stat(fileName)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return fi
+}
+
+func TestServeContent(t *testing.T) {
+ defer checkLeakedTransports(t)
+ type serveParam struct {
+ name string
+ modtime time.Time
+ content io.ReadSeeker
+ contentType string
+ etag string
+ }
+ servec := make(chan serveParam, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ p := <-servec
+ if p.etag != "" {
+ w.Header().Set("ETag", p.etag)
+ }
+ if p.contentType != "" {
+ w.Header().Set("Content-Type", p.contentType)
+ }
+ ServeContent(w, r, p.name, p.modtime, p.content)
+ }))
+ defer ts.Close()
+
+ type testCase struct {
+ file string
+ modtime time.Time
+ serveETag string // optional
+ serveContentType string // optional
+ reqHeader map[string]string
+ wantLastMod string
+ wantContentType string
+ wantStatus int
+ }
+ htmlModTime := mustStat(t, "testdata/index.html").ModTime()
+ tests := map[string]testCase{
+ "no_last_modified": {
+ file: "testdata/style.css",
+ wantContentType: "text/css; charset=utf-8",
+ wantStatus: 200,
+ },
+ "with_last_modified": {
+ file: "testdata/index.html",
+ wantContentType: "text/html; charset=utf-8",
+ modtime: htmlModTime,
+ wantLastMod: htmlModTime.UTC().Format(TimeFormat),
+ wantStatus: 200,
+ },
+ "not_modified_modtime": {
+ file: "testdata/style.css",
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat),
+ },
+ wantStatus: 304,
+ },
+ "not_modified_modtime_with_contenttype": {
+ file: "testdata/style.css",
+ serveContentType: "text/css", // explicit content type
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat),
+ },
+ wantStatus: 304,
+ },
+ "not_modified_etag": {
+ file: "testdata/style.css",
+ serveETag: `"foo"`,
+ reqHeader: map[string]string{
+ "If-None-Match": `"foo"`,
+ },
+ wantStatus: 304,
+ },
+ "range_good": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ },
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ // An If-Range resource for entity "A", but entity "B" is now current.
+ // The Range request should be ignored.
+ "range_no_match": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": `"B"`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ }
+ for testName, tt := range tests {
+ f, err := os.Open(tt.file)
+ if err != nil {
+ t.Fatalf("test %q: %v", testName, err)
+ }
+ defer f.Close()
+
+ servec <- serveParam{
+ name: filepath.Base(tt.file),
+ content: f,
+ modtime: tt.modtime,
+ etag: tt.serveETag,
+ contentType: tt.serveContentType,
+ }
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for k, v := range tt.reqHeader {
+ req.Header.Set(k, v)
+ }
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ io.Copy(ioutil.Discard, res.Body)
+ res.Body.Close()
+ if res.StatusCode != tt.wantStatus {
+ t.Errorf("test %q: status = %d; want %d", testName, res.StatusCode, tt.wantStatus)
+ }
+ if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e {
+ t.Errorf("test %q: content-type = %q, want %q", testName, g, e)
+ }
+ if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e {
+ t.Errorf("test %q: last-modified = %q, want %q", testName, g, e)
+ }
}
}
// verifies that sendfile is being used on Linux
func TestLinuxSendfile(t *testing.T) {
+ defer checkLeakedTransports(t)
if runtime.GOOS != "linux" {
- t.Logf("skipping; linux-only test")
- return
+ t.Skip("skipping; linux-only test")
}
- _, err := exec.LookPath("strace")
- if err != nil {
- t.Logf("skipping; strace not found in path")
- return
+ if _, err := exec.LookPath("strace"); err != nil {
+ t.Skip("skipping; strace not found in path")
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
@@ -401,10 +702,8 @@ func TestLinuxSendfile(t *testing.T) {
child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...)
child.Stdout = &buf
child.Stderr = &buf
- err = child.Start()
- if err != nil {
- t.Logf("skipping; failed to start straced child: %v", err)
- return
+ if err := child.Start(); err != nil {
+ t.Skipf("skipping; failed to start straced child: %v", err)
}
res, err := Get(fmt.Sprintf("http://%s/", ln.Addr()))
@@ -464,15 +763,3 @@ func TestLinuxSendfileChild(*testing.T) {
panic(err)
}
}
-
-func equal(a, b []byte) bool {
- if len(a) != len(b) {
- return false
- }
- for i := range a {
- if a[i] != b[i] {
- return false
- }
- }
- return true
-}
diff --git a/src/pkg/net/http/header.go b/src/pkg/net/http/header.go
index b107c312d..f479b7b4e 100644
--- a/src/pkg/net/http/header.go
+++ b/src/pkg/net/http/header.go
@@ -5,11 +5,11 @@
package http
import (
- "fmt"
"io"
"net/textproto"
"sort"
"strings"
+ "time"
)
// A Header represents the key-value pairs in an HTTP header.
@@ -36,6 +36,14 @@ func (h Header) Get(key string) string {
return textproto.MIMEHeader(h).Get(key)
}
+// get is like Get, but key must already be in CanonicalHeaderKey form.
+func (h Header) get(key string) string {
+ if v := h[key]; len(v) > 0 {
+ return v[0]
+ }
+ return ""
+}
+
// Del deletes the values associated with key.
func (h Header) Del(key string) {
textproto.MIMEHeader(h).Del(key)
@@ -46,24 +54,87 @@ func (h Header) Write(w io.Writer) error {
return h.WriteSubset(w, nil)
}
+func (h Header) clone() Header {
+ h2 := make(Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
+ }
+ return h2
+}
+
+var timeFormats = []string{
+ TimeFormat,
+ time.RFC850,
+ time.ANSIC,
+}
+
+// ParseTime parses a time header (such as the Date: header),
+// trying each of the three formats allowed by HTTP/1.1:
+// TimeFormat, time.RFC850, and time.ANSIC.
+func ParseTime(text string) (t time.Time, err error) {
+ for _, layout := range timeFormats {
+ t, err = time.Parse(layout, text)
+ if err == nil {
+ return
+ }
+ }
+ return
+}
+
var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ")
+type writeStringer interface {
+ WriteString(string) (int, error)
+}
+
+// stringWriter implements WriteString on a Writer.
+type stringWriter struct {
+ w io.Writer
+}
+
+func (w stringWriter) WriteString(s string) (n int, err error) {
+ return w.w.Write([]byte(s))
+}
+
+type keyValues struct {
+ key string
+ values []string
+}
+
+type byKey []keyValues
+
+func (s byKey) Len() int { return len(s) }
+func (s byKey) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+func (s byKey) Less(i, j int) bool { return s[i].key < s[j].key }
+
+func (h Header) sortedKeyValues(exclude map[string]bool) []keyValues {
+ kvs := make([]keyValues, 0, len(h))
+ for k, vv := range h {
+ if !exclude[k] {
+ kvs = append(kvs, keyValues{k, vv})
+ }
+ }
+ sort.Sort(byKey(kvs))
+ return kvs
+}
+
// WriteSubset writes a header in wire format.
// If exclude is not nil, keys where exclude[key] == true are not written.
func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
- keys := make([]string, 0, len(h))
- for k := range h {
- if exclude == nil || !exclude[k] {
- keys = append(keys, k)
- }
+ ws, ok := w.(writeStringer)
+ if !ok {
+ ws = stringWriter{w}
}
- sort.Strings(keys)
- for _, k := range keys {
- for _, v := range h[k] {
+ for _, kv := range h.sortedKeyValues(exclude) {
+ for _, v := range kv.values {
v = headerNewlineToSpace.Replace(v)
- v = strings.TrimSpace(v)
- if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil {
- return err
+ v = textproto.TrimString(v)
+ for _, s := range []string{kv.key, ": ", v, "\r\n"} {
+ if _, err := ws.WriteString(s); err != nil {
+ return err
+ }
}
}
}
@@ -76,3 +147,43 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
// the rest are converted to lowercase. For example, the
// canonical key for "accept-encoding" is "Accept-Encoding".
func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) }
+
+// hasToken returns whether token appears with v, ASCII
+// case-insensitive, with space or comma boundaries.
+// token must be all lowercase.
+// v may contain mixed cased.
+func hasToken(v, token string) bool {
+ if len(token) > len(v) || token == "" {
+ return false
+ }
+ if v == token {
+ return true
+ }
+ for sp := 0; sp <= len(v)-len(token); sp++ {
+ // Check that first character is good.
+ // The token is ASCII, so checking only a single byte
+ // is sufficient. We skip this potential starting
+ // position if both the first byte and its potential
+ // ASCII uppercase equivalent (b|0x20) don't match.
+ // False positives ('^' => '~') are caught by EqualFold.
+ if b := v[sp]; b != token[0] && b|0x20 != token[0] {
+ continue
+ }
+ // Check that start pos is on a valid token boundary.
+ if sp > 0 && !isTokenBoundary(v[sp-1]) {
+ continue
+ }
+ // Check that end pos is on a valid token boundary.
+ if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) {
+ continue
+ }
+ if strings.EqualFold(v[sp:sp+len(token)], token) {
+ return true
+ }
+ }
+ return false
+}
+
+func isTokenBoundary(b byte) bool {
+ return b == ' ' || b == ',' || b == '\t'
+}
diff --git a/src/pkg/net/http/header_test.go b/src/pkg/net/http/header_test.go
index ccdee8a97..2313b5549 100644
--- a/src/pkg/net/http/header_test.go
+++ b/src/pkg/net/http/header_test.go
@@ -7,6 +7,7 @@ package http
import (
"bytes"
"testing"
+ "time"
)
var headerWriteTests = []struct {
@@ -67,6 +68,24 @@ var headerWriteTests = []struct {
nil,
"Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n",
},
+ // Tests header sorting when over the insertion sort threshold side:
+ {
+ Header{
+ "k1": {"1a", "1b"},
+ "k2": {"2a", "2b"},
+ "k3": {"3a", "3b"},
+ "k4": {"4a", "4b"},
+ "k5": {"5a", "5b"},
+ "k6": {"6a", "6b"},
+ "k7": {"7a", "7b"},
+ "k8": {"8a", "8b"},
+ "k9": {"9a", "9b"},
+ },
+ map[string]bool{"k5": true},
+ "k1: 1a\r\nk1: 1b\r\nk2: 2a\r\nk2: 2b\r\nk3: 3a\r\nk3: 3b\r\n" +
+ "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" +
+ "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n",
+ },
}
func TestHeaderWrite(t *testing.T) {
@@ -79,3 +98,107 @@ func TestHeaderWrite(t *testing.T) {
buf.Reset()
}
}
+
+var parseTimeTests = []struct {
+ h Header
+ err bool
+}{
+ {Header{"Date": {""}}, true},
+ {Header{"Date": {"invalid"}}, true},
+ {Header{"Date": {"1994-11-06T08:49:37Z00:00"}}, true},
+ {Header{"Date": {"Sun, 06 Nov 1994 08:49:37 GMT"}}, false},
+ {Header{"Date": {"Sunday, 06-Nov-94 08:49:37 GMT"}}, false},
+ {Header{"Date": {"Sun Nov 6 08:49:37 1994"}}, false},
+}
+
+func TestParseTime(t *testing.T) {
+ expect := time.Date(1994, 11, 6, 8, 49, 37, 0, time.UTC)
+ for i, test := range parseTimeTests {
+ d, err := ParseTime(test.h.Get("Date"))
+ if err != nil {
+ if !test.err {
+ t.Errorf("#%d:\n got err: %v", i, err)
+ }
+ continue
+ }
+ if test.err {
+ t.Errorf("#%d:\n should err", i)
+ continue
+ }
+ if !expect.Equal(d) {
+ t.Errorf("#%d:\n got: %v\nwant: %v", i, d, expect)
+ }
+ }
+}
+
+type hasTokenTest struct {
+ header string
+ token string
+ want bool
+}
+
+var hasTokenTests = []hasTokenTest{
+ {"", "", false},
+ {"", "foo", false},
+ {"foo", "foo", true},
+ {"foo ", "foo", true},
+ {" foo", "foo", true},
+ {" foo ", "foo", true},
+ {"foo,bar", "foo", true},
+ {"bar,foo", "foo", true},
+ {"bar, foo", "foo", true},
+ {"bar,foo, baz", "foo", true},
+ {"bar, foo,baz", "foo", true},
+ {"bar,foo, baz", "foo", true},
+ {"bar, foo, baz", "foo", true},
+ {"FOO", "foo", true},
+ {"FOO ", "foo", true},
+ {" FOO", "foo", true},
+ {" FOO ", "foo", true},
+ {"FOO,BAR", "foo", true},
+ {"BAR,FOO", "foo", true},
+ {"BAR, FOO", "foo", true},
+ {"BAR,FOO, baz", "foo", true},
+ {"BAR, FOO,BAZ", "foo", true},
+ {"BAR,FOO, BAZ", "foo", true},
+ {"BAR, FOO, BAZ", "foo", true},
+ {"foobar", "foo", false},
+ {"barfoo ", "foo", false},
+}
+
+func TestHasToken(t *testing.T) {
+ for _, tt := range hasTokenTests {
+ if hasToken(tt.header, tt.token) != tt.want {
+ t.Errorf("hasToken(%q, %q) = %v; want %v", tt.header, tt.token, !tt.want, tt.want)
+ }
+ }
+}
+
+var testHeader = Header{
+ "Content-Length": {"123"},
+ "Content-Type": {"text/plain"},
+ "Date": {"some date at some time Z"},
+ "Server": {"Go http package"},
+}
+
+var buf bytes.Buffer
+
+func BenchmarkHeaderWriteSubset(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ buf.Reset()
+ testHeader.WriteSubset(&buf, nil)
+ }
+}
+
+func TestHeaderWriteSubsetMallocs(t *testing.T) {
+ n := testing.AllocsPerRun(100, func() {
+ buf.Reset()
+ testHeader.WriteSubset(&buf, nil)
+ })
+ if n > 1 {
+ // TODO(bradfitz,rsc): once we can sort without allocating,
+ // make this an error. See http://golang.org/issue/3761
+ // t.Errorf("got %v allocs, want <= %v", n, 1)
+ }
+}
diff --git a/src/pkg/net/http/httptest/example_test.go b/src/pkg/net/http/httptest/example_test.go
new file mode 100644
index 000000000..239470d97
--- /dev/null
+++ b/src/pkg/net/http/httptest/example_test.go
@@ -0,0 +1,50 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest_test
+
+import (
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "net/http/httptest"
+)
+
+func ExampleRecorder() {
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, "something failed", http.StatusInternalServerError)
+ }
+
+ req, err := http.NewRequest("GET", "http://example.com/foo", nil)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ w := httptest.NewRecorder()
+ handler(w, req)
+
+ fmt.Printf("%d - %s", w.Code, w.Body.String())
+ // Output: 500 - something failed
+}
+
+func ExampleServer() {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, client")
+ }))
+ defer ts.Close()
+
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ greeting, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", greeting)
+ // Output: Hello, client
+}
diff --git a/src/pkg/net/http/httptest/recorder.go b/src/pkg/net/http/httptest/recorder.go
index 9aa0d510b..5451f5423 100644
--- a/src/pkg/net/http/httptest/recorder.go
+++ b/src/pkg/net/http/httptest/recorder.go
@@ -17,6 +17,8 @@ type ResponseRecorder struct {
HeaderMap http.Header // the HTTP response headers
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
Flushed bool
+
+ wroteHeader bool
}
// NewRecorder returns an initialized ResponseRecorder.
@@ -24,6 +26,7 @@ func NewRecorder() *ResponseRecorder {
return &ResponseRecorder{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
+ Code: 200,
}
}
@@ -33,26 +36,37 @@ const DefaultRemoteAddr = "1.2.3.4"
// Header returns the response headers.
func (rw *ResponseRecorder) Header() http.Header {
- return rw.HeaderMap
+ m := rw.HeaderMap
+ if m == nil {
+ m = make(http.Header)
+ rw.HeaderMap = m
+ }
+ return m
}
// Write always succeeds and writes to rw.Body, if not nil.
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
+ if !rw.wroteHeader {
+ rw.WriteHeader(200)
+ }
if rw.Body != nil {
rw.Body.Write(buf)
}
- if rw.Code == 0 {
- rw.Code = http.StatusOK
- }
return len(buf), nil
}
// WriteHeader sets rw.Code.
func (rw *ResponseRecorder) WriteHeader(code int) {
- rw.Code = code
+ if !rw.wroteHeader {
+ rw.Code = code
+ }
+ rw.wroteHeader = true
}
// Flush sets rw.Flushed to true.
func (rw *ResponseRecorder) Flush() {
+ if !rw.wroteHeader {
+ rw.WriteHeader(200)
+ }
rw.Flushed = true
}
diff --git a/src/pkg/net/http/httptest/recorder_test.go b/src/pkg/net/http/httptest/recorder_test.go
new file mode 100644
index 000000000..2b563260c
--- /dev/null
+++ b/src/pkg/net/http/httptest/recorder_test.go
@@ -0,0 +1,90 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "fmt"
+ "net/http"
+ "testing"
+)
+
+func TestRecorder(t *testing.T) {
+ type checkFunc func(*ResponseRecorder) error
+ check := func(fns ...checkFunc) []checkFunc { return fns }
+
+ hasStatus := func(wantCode int) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Code != wantCode {
+ return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
+ }
+ return nil
+ }
+ }
+ hasContents := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Body.String() != want {
+ return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
+ }
+ return nil
+ }
+ }
+ hasFlush := func(want bool) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Flushed != want {
+ return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
+ }
+ return nil
+ }
+ }
+
+ tests := []struct {
+ name string
+ h func(w http.ResponseWriter, r *http.Request)
+ checks []checkFunc
+ }{
+ {
+ "200 default",
+ func(w http.ResponseWriter, r *http.Request) {},
+ check(hasStatus(200), hasContents("")),
+ },
+ {
+ "first code only",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ w.Write([]byte("hi"))
+ },
+ check(hasStatus(201), hasContents("hi")),
+ },
+ {
+ "write sends 200",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hi first"))
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ },
+ check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
+ },
+ {
+ "flush",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(http.Flusher).Flush() // also sends a 200
+ w.WriteHeader(201)
+ },
+ check(hasStatus(200), hasFlush(true)),
+ },
+ }
+ r, _ := http.NewRequest("GET", "http://foo.com/", nil)
+ for _, tt := range tests {
+ h := http.HandlerFunc(tt.h)
+ rec := NewRecorder()
+ h.ServeHTTP(rec, r)
+ for _, check := range tt.checks {
+ if err := check(rec); err != nil {
+ t.Errorf("%s: %v", tt.name, err)
+ }
+ }
+ }
+}
diff --git a/src/pkg/net/http/httptest/server.go b/src/pkg/net/http/httptest/server.go
index 57cf0c941..7f265552f 100644
--- a/src/pkg/net/http/httptest/server.go
+++ b/src/pkg/net/http/httptest/server.go
@@ -21,7 +21,11 @@ import (
type Server struct {
URL string // base URL of form http://ipaddr:port with no trailing slash
Listener net.Listener
- TLS *tls.Config // nil if not using using TLS
+
+ // TLS is the optional TLS configuration, populated with a new config
+ // after TLS is started. If set on an unstarted server before StartTLS
+ // is called, existing fields are copied into the new config.
+ TLS *tls.Config
// Config may be changed after calling NewUnstartedServer and
// before Start or StartTLS.
@@ -36,13 +40,16 @@ type Server struct {
// accepted.
type historyListener struct {
net.Listener
- history []net.Conn
+ sync.Mutex // protects history
+ history []net.Conn
}
func (hs *historyListener) Accept() (c net.Conn, err error) {
c, err = hs.Listener.Accept()
if err == nil {
+ hs.Lock()
hs.history = append(hs.history, c)
+ hs.Unlock()
}
return
}
@@ -96,7 +103,7 @@ func (s *Server) Start() {
if s.URL != "" {
panic("Server already started")
}
- s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)}
+ s.Listener = &historyListener{Listener: s.Listener}
s.URL = "http://" + s.Listener.Addr().String()
s.wrapHandler()
go s.Config.Serve(s.Listener)
@@ -116,13 +123,20 @@ func (s *Server) StartTLS() {
panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
}
- s.TLS = &tls.Config{
- NextProtos: []string{"http/1.1"},
- Certificates: []tls.Certificate{cert},
+ existingConfig := s.TLS
+ s.TLS = new(tls.Config)
+ if existingConfig != nil {
+ *s.TLS = *existingConfig
+ }
+ if s.TLS.NextProtos == nil {
+ s.TLS.NextProtos = []string{"http/1.1"}
+ }
+ if len(s.TLS.Certificates) == 0 {
+ s.TLS.Certificates = []tls.Certificate{cert}
}
tlsListener := tls.NewListener(s.Listener, s.TLS)
- s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)}
+ s.Listener = &historyListener{Listener: tlsListener}
s.URL = "https://" + s.Listener.Addr().String()
s.wrapHandler()
go s.Config.Serve(s.Listener)
@@ -152,6 +166,10 @@ func NewTLSServer(handler http.Handler) *Server {
func (s *Server) Close() {
s.Listener.Close()
s.wg.Wait()
+ s.CloseClientConnections()
+ if t, ok := http.DefaultTransport.(*http.Transport); ok {
+ t.CloseIdleConnections()
+ }
}
// CloseClientConnections closes any currently open HTTP connections
@@ -161,9 +179,11 @@ func (s *Server) CloseClientConnections() {
if !ok {
return
}
+ hl.Lock()
for _, conn := range hl.history {
conn.Close()
}
+ hl.Unlock()
}
// waitGroupHandler wraps a handler, incrementing and decrementing a
@@ -180,28 +200,29 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.h.ServeHTTP(w, r)
}
-// localhostCert is a PEM-encoded TLS cert with SAN DNS names
+// localhostCert is a PEM-encoded TLS cert with SAN IPs
// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
// of ASN.1 time).
+// generated from src/pkg/crypto/tls:
+// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
-MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
-DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7
-qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL
-8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw
-DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG
-CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7
-+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM=
------END CERTIFICATE-----
-`)
+MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD
+bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj
+bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa
+IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA
+AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud
+EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA
+AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk
+Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA==
+-----END CERTIFICATE-----`)
// localhostKey is the private key for localhostCert.
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
-MIIBPQIBAAJBALLgOZgBTI+kO6qAc3LysyKuJM7k+XqUqdgJHEH8gR5uytd1rO7v
-tG+VW/YKk3+XAIiCnK7a11apC/ItVEBegM8CAwEAAQJBAI5sxq7naeR9ahyqRkJi
-SIv2iMxLuPEHaezf5CYOPWjSjBPyVhyRevkhtqEjF/WkgL7C2nWpYHsUcBDBQVF0
-3KECIQDtEGB2ulnkZAahl3WuJziXGLB+p8Wgx7wzSM6bHu1c6QIhAMEp++CaS+SJ
-/TrU0zwY/fW4SvQeb49BPZUF3oqR8Xz3AiEA1rAJHBzBgdOQKdE3ksMUPcnvNJSN
-poCcELmz2clVXtkCIQCLytuLV38XHToTipR4yMl6O+6arzAjZ56uq7m7ZRV0TwIh
-AM65XAOw8Dsg9Kq78aYXiOEDc5DL0sbFUu/SlmRcCg93
------END RSA PRIVATE KEY-----
-`)
+MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0
+0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV
+NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d
+AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW
+MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD
+EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA
+1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE=
+-----END RSA PRIVATE KEY-----`)
diff --git a/src/pkg/net/http/httputil/chunked.go b/src/pkg/net/http/httputil/chunked.go
index 29eaf3475..b66d40951 100644
--- a/src/pkg/net/http/httputil/chunked.go
+++ b/src/pkg/net/http/httputil/chunked.go
@@ -13,10 +13,9 @@ package httputil
import (
"bufio"
- "bytes"
"errors"
+ "fmt"
"io"
- "strconv"
)
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
@@ -24,7 +23,7 @@ const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
var ErrLineTooLong = errors.New("header line too long")
// NewChunkedReader returns a new chunkedReader that translates the data read from r
-// out of HTTP "chunked" format before returning it.
+// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
@@ -41,16 +40,17 @@ type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
+ buf [2]byte
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
- var line string
+ var line []byte
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
- cr.n, cr.err = strconv.ParseUint(line, 16, 64)
+ cr.n, cr.err = parseHexUint(line)
if cr.err != nil {
return
}
@@ -76,9 +76,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
- b := make([]byte, 2)
- if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
- if b[0] != '\r' || b[1] != '\n' {
+ if _, cr.err = io.ReadFull(cr.r, cr.buf[:]); cr.err == nil {
+ if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
@@ -90,7 +89,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
-func readLineBytes(b *bufio.Reader) (p []byte, err error) {
+func readLine(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
@@ -104,20 +103,18 @@ func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
-
- // Chop off trailing white space.
- p = bytes.TrimRight(p, " \r\t\n")
-
- return p, nil
+ return trimTrailingWhitespace(p), nil
}
-// readLineBytes, but convert the bytes into a string.
-func readLine(b *bufio.Reader) (s string, err error) {
- p, e := readLineBytes(b)
- if e != nil {
- return "", e
+func trimTrailingWhitespace(b []byte) []byte {
+ for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+ b = b[:len(b)-1]
}
- return string(p), nil
+ return b
+}
+
+func isASCIISpace(b byte) bool {
+ return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
@@ -149,9 +146,7 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
return 0, nil
}
- head := strconv.FormatInt(int64(len(data)), 16) + "\r\n"
-
- if _, err = io.WriteString(cw.Wire, head); err != nil {
+ if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil {
return 0, err
}
if n, err = cw.Wire.Write(data); err != nil {
@@ -170,3 +165,21 @@ func (cw *chunkedWriter) Close() error {
_, err := io.WriteString(cw.Wire, "0\r\n")
return err
}
+
+func parseHexUint(v []byte) (n uint64, err error) {
+ for _, b := range v {
+ n <<= 4
+ switch {
+ case '0' <= b && b <= '9':
+ b = b - '0'
+ case 'a' <= b && b <= 'f':
+ b = b - 'a' + 10
+ case 'A' <= b && b <= 'F':
+ b = b - 'A' + 10
+ default:
+ return 0, errors.New("invalid byte in chunk length")
+ }
+ n |= uint64(b)
+ }
+ return
+}
diff --git a/src/pkg/net/http/httputil/chunked_test.go b/src/pkg/net/http/httputil/chunked_test.go
index 155a32bdf..a06bffad5 100644
--- a/src/pkg/net/http/httputil/chunked_test.go
+++ b/src/pkg/net/http/httputil/chunked_test.go
@@ -11,7 +11,10 @@ package httputil
import (
"bytes"
+ "fmt"
+ "io"
"io/ioutil"
+ "runtime"
"testing"
)
@@ -39,3 +42,54 @@ func TestChunk(t *testing.T) {
t.Errorf("chunk reader read %q; want %q", g, e)
}
}
+
+func TestChunkReaderAllocs(t *testing.T) {
+ // temporarily set GOMAXPROCS to 1 as we are testing memory allocations
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+ var buf bytes.Buffer
+ w := NewChunkedWriter(&buf)
+ a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
+ w.Write(a)
+ w.Write(b)
+ w.Write(c)
+ w.Close()
+
+ r := NewChunkedReader(&buf)
+ readBuf := make([]byte, len(a)+len(b)+len(c)+1)
+
+ var ms runtime.MemStats
+ runtime.ReadMemStats(&ms)
+ m0 := ms.Mallocs
+
+ n, err := io.ReadFull(r, readBuf)
+
+ runtime.ReadMemStats(&ms)
+ mallocs := ms.Mallocs - m0
+ if mallocs > 1 {
+ t.Errorf("%d mallocs; want <= 1", mallocs)
+ }
+
+ if n != len(readBuf)-1 {
+ t.Errorf("read %d bytes; want %d", n, len(readBuf)-1)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("read error = %v; want ErrUnexpectedEOF", err)
+ }
+}
+
+func TestParseHexUint(t *testing.T) {
+ for i := uint64(0); i <= 1234; i++ {
+ line := []byte(fmt.Sprintf("%x", i))
+ got, err := parseHexUint(line)
+ if err != nil {
+ t.Fatalf("on %d: %v", i, err)
+ }
+ if got != i {
+ t.Errorf("for input %q = %d; want %d", line, got, i)
+ }
+ }
+ _, err := parseHexUint([]byte("bogus"))
+ if err == nil {
+ t.Error("expected error on bogus input")
+ }
+}
diff --git a/src/pkg/net/http/httputil/dump.go b/src/pkg/net/http/httputil/dump.go
index 892ef4ede..0b0035661 100644
--- a/src/pkg/net/http/httputil/dump.go
+++ b/src/pkg/net/http/httputil/dump.go
@@ -75,7 +75,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
// Use the actual Transport code to record what we would send
// on the wire, but not using TCP. Use a Transport with a
- // customer dialer that returns a fake net.Conn that waits
+ // custom dialer that returns a fake net.Conn that waits
// for the full input (and recording it), and then responds
// with a dummy response.
var buf bytes.Buffer // records the output
@@ -89,7 +89,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
t := &http.Transport{
Dial: func(net, addr string) (net.Conn, error) {
- return &dumpConn{io.MultiWriter(pw, &buf), dr}, nil
+ return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
},
}
diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go
index 9c4bd6e09..134c45299 100644
--- a/src/pkg/net/http/httputil/reverseproxy.go
+++ b/src/pkg/net/http/httputil/reverseproxy.go
@@ -17,6 +17,10 @@ import (
"time"
)
+// onExitFlushLoop is a callback set by tests to detect the state of the
+// flushLoop() goroutine.
+var onExitFlushLoop func()
+
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
@@ -102,8 +106,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Header.Del("Connection")
}
- if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
- outreq.Header.Set("X-Forwarded-For", clientIp)
+ if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+ // If we aren't the first proxy retain prior
+ // X-Forwarded-For information as a comma+space
+ // separated list and fold multiple headers into one.
+ if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
+ clientIP = strings.Join(prior, ", ") + ", " + clientIP
+ }
+ outreq.Header.Set("X-Forwarded-For", clientIP)
}
res, err := transport.RoundTrip(outreq)
@@ -112,20 +122,29 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
return
}
+ defer res.Body.Close()
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
+ p.copyResponse(rw, res.Body)
+}
- if res.Body != nil {
- var dst io.Writer = rw
- if p.FlushInterval != 0 {
- if wf, ok := rw.(writeFlusher); ok {
- dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
+ if p.FlushInterval != 0 {
+ if wf, ok := dst.(writeFlusher); ok {
+ mlw := &maxLatencyWriter{
+ dst: wf,
+ latency: p.FlushInterval,
+ done: make(chan bool),
}
+ go mlw.flushLoop()
+ defer mlw.stop()
+ dst = mlw
}
- io.Copy(dst, res.Body)
}
+
+ io.Copy(dst, src)
}
type writeFlusher interface {
@@ -137,22 +156,14 @@ type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
- lk sync.Mutex // protects init of done, as well Write + Flush
+ lk sync.Mutex // protects Write + Flush
done chan bool
}
-func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+func (m *maxLatencyWriter) Write(p []byte) (int, error) {
m.lk.Lock()
defer m.lk.Unlock()
- if m.done == nil {
- m.done = make(chan bool)
- go m.flushLoop()
- }
- n, err = m.dst.Write(p)
- if err != nil {
- m.done <- true
- }
- return
+ return m.dst.Write(p)
}
func (m *maxLatencyWriter) flushLoop() {
@@ -160,13 +171,18 @@ func (m *maxLatencyWriter) flushLoop() {
defer t.Stop()
for {
select {
+ case <-m.done:
+ if onExitFlushLoop != nil {
+ onExitFlushLoop()
+ }
+ return
case <-t.C:
m.lk.Lock()
m.dst.Flush()
m.lk.Unlock()
- case <-m.done:
- return
}
}
panic("unreached")
}
+
+func (m *maxLatencyWriter) stop() { m.done <- true }
diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go
index 28e9c90ad..863927162 100644
--- a/src/pkg/net/http/httputil/reverseproxy_test.go
+++ b/src/pkg/net/http/httputil/reverseproxy_test.go
@@ -11,7 +11,9 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "strings"
"testing"
+ "time"
)
func TestReverseProxy(t *testing.T) {
@@ -70,6 +72,47 @@ func TestReverseProxy(t *testing.T) {
}
}
+func TestXForwardedFor(t *testing.T) {
+ const prevForwardedFor = "client ip"
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
+ t.Errorf("X-Forwarded-For didn't contain prior data")
+ }
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Header.Set("Connection", "close")
+ getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
+ getReq.Close = true
+ res, err := http.DefaultClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
var proxyQueryTests = []struct {
baseSuffix string // suffix to add to backend URL
reqSuffix string // suffix to add to frontend's request URL
@@ -107,3 +150,44 @@ func TestReverseProxyQuery(t *testing.T) {
frontend.Close()
}
}
+
+func TestReverseProxyFlushInterval(t *testing.T) {
+ const expected = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(expected))
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ done := make(chan bool)
+ onExitFlushLoop = func() { done <- true }
+ defer func() { onExitFlushLoop = nil }()
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
+ }
+
+ select {
+ case <-done:
+ // OK
+ case <-time.After(5 * time.Second):
+ t.Error("maxLatencyWriter flushLoop() never exited")
+ }
+}
diff --git a/src/pkg/net/http/jar.go b/src/pkg/net/http/jar.go
index 2c2caa251..5c3de0dad 100644
--- a/src/pkg/net/http/jar.go
+++ b/src/pkg/net/http/jar.go
@@ -8,23 +8,20 @@ import (
"net/url"
)
-// A CookieJar manages storage and use of cookies in HTTP requests.
+// A CookieJar manages storage and use of cookies in HTTP requests.
//
// Implementations of CookieJar must be safe for concurrent use by multiple
// goroutines.
+//
+// The net/http/cookiejar package provides a CookieJar implementation.
type CookieJar interface {
- // SetCookies handles the receipt of the cookies in a reply for the
- // given URL. It may or may not choose to save the cookies, depending
- // on the jar's policy and implementation.
+ // SetCookies handles the receipt of the cookies in a reply for the
+ // given URL. It may or may not choose to save the cookies, depending
+ // on the jar's policy and implementation.
SetCookies(u *url.URL, cookies []*Cookie)
// Cookies returns the cookies to send in a request for the given URL.
- // It is up to the implementation to honor the standard cookie use
- // restrictions such as in RFC 6265.
+ // It is up to the implementation to honor the standard cookie use
+ // restrictions such as in RFC 6265.
Cookies(u *url.URL) []*Cookie
}
-
-type blackHoleJar struct{}
-
-func (blackHoleJar) SetCookies(u *url.URL, cookies []*Cookie) {}
-func (blackHoleJar) Cookies(u *url.URL) []*Cookie { return nil }
diff --git a/src/pkg/net/http/lex.go b/src/pkg/net/http/lex.go
index ffb393ccf..cb33318f4 100644
--- a/src/pkg/net/http/lex.go
+++ b/src/pkg/net/http/lex.go
@@ -6,131 +6,91 @@ package http
// This file deals with lexical matters of HTTP
-func isSeparator(c byte) bool {
- switch c {
- case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t':
- return true
- }
- return false
+var isTokenTable = [127]bool{
+ '!': true,
+ '#': true,
+ '$': true,
+ '%': true,
+ '&': true,
+ '\'': true,
+ '*': true,
+ '+': true,
+ '-': true,
+ '.': true,
+ '0': true,
+ '1': true,
+ '2': true,
+ '3': true,
+ '4': true,
+ '5': true,
+ '6': true,
+ '7': true,
+ '8': true,
+ '9': true,
+ 'A': true,
+ 'B': true,
+ 'C': true,
+ 'D': true,
+ 'E': true,
+ 'F': true,
+ 'G': true,
+ 'H': true,
+ 'I': true,
+ 'J': true,
+ 'K': true,
+ 'L': true,
+ 'M': true,
+ 'N': true,
+ 'O': true,
+ 'P': true,
+ 'Q': true,
+ 'R': true,
+ 'S': true,
+ 'T': true,
+ 'U': true,
+ 'W': true,
+ 'V': true,
+ 'X': true,
+ 'Y': true,
+ 'Z': true,
+ '^': true,
+ '_': true,
+ '`': true,
+ 'a': true,
+ 'b': true,
+ 'c': true,
+ 'd': true,
+ 'e': true,
+ 'f': true,
+ 'g': true,
+ 'h': true,
+ 'i': true,
+ 'j': true,
+ 'k': true,
+ 'l': true,
+ 'm': true,
+ 'n': true,
+ 'o': true,
+ 'p': true,
+ 'q': true,
+ 'r': true,
+ 's': true,
+ 't': true,
+ 'u': true,
+ 'v': true,
+ 'w': true,
+ 'x': true,
+ 'y': true,
+ 'z': true,
+ '|': true,
+ '~': true,
}
-func isCtl(c byte) bool { return (0 <= c && c <= 31) || c == 127 }
-
-func isChar(c byte) bool { return 0 <= c && c <= 127 }
-
-func isAnyText(c byte) bool { return !isCtl(c) }
-
-func isQdText(c byte) bool { return isAnyText(c) && c != '"' }
-
-func isToken(c byte) bool { return isChar(c) && !isCtl(c) && !isSeparator(c) }
-
-// Valid escaped sequences are not specified in RFC 2616, so for now, we assume
-// that they coincide with the common sense ones used by GO. Malformed
-// characters should probably not be treated as errors by a robust (forgiving)
-// parser, so we replace them with the '?' character.
-func httpUnquotePair(b byte) byte {
- // skip the first byte, which should always be '\'
- switch b {
- case 'a':
- return '\a'
- case 'b':
- return '\b'
- case 'f':
- return '\f'
- case 'n':
- return '\n'
- case 'r':
- return '\r'
- case 't':
- return '\t'
- case 'v':
- return '\v'
- case '\\':
- return '\\'
- case '\'':
- return '\''
- case '"':
- return '"'
- }
- return '?'
-}
-
-// raw must begin with a valid quoted string. Only the first quoted string is
-// parsed and is unquoted in result. eaten is the number of bytes parsed, or -1
-// upon failure.
-func httpUnquote(raw []byte) (eaten int, result string) {
- buf := make([]byte, len(raw))
- if raw[0] != '"' {
- return -1, ""
- }
- eaten = 1
- j := 0 // # of bytes written in buf
- for i := 1; i < len(raw); i++ {
- switch b := raw[i]; b {
- case '"':
- eaten++
- buf = buf[0:j]
- return i + 1, string(buf)
- case '\\':
- if len(raw) < i+2 {
- return -1, ""
- }
- buf[j] = httpUnquotePair(raw[i+1])
- eaten += 2
- j++
- i++
- default:
- if isQdText(b) {
- buf[j] = b
- } else {
- buf[j] = '?'
- }
- eaten++
- j++
- }
- }
- return -1, ""
+func isToken(r rune) bool {
+ i := int(r)
+ return i < len(isTokenTable) && isTokenTable[i]
}
-// This is a best effort parse, so errors are not returned, instead not all of
-// the input string might be parsed. result is always non-nil.
-func httpSplitFieldValue(fv string) (eaten int, result []string) {
- result = make([]string, 0, len(fv))
- raw := []byte(fv)
- i := 0
- chunk := ""
- for i < len(raw) {
- b := raw[i]
- switch {
- case b == '"':
- eaten, unq := httpUnquote(raw[i:len(raw)])
- if eaten < 0 {
- return i, result
- } else {
- i += eaten
- chunk += unq
- }
- case isSeparator(b):
- if chunk != "" {
- result = result[0 : len(result)+1]
- result[len(result)-1] = chunk
- chunk = ""
- }
- i++
- case isToken(b):
- chunk += string(b)
- i++
- case b == '\n' || b == '\r':
- i++
- default:
- chunk += "?"
- i++
- }
- }
- if chunk != "" {
- result = result[0 : len(result)+1]
- result[len(result)-1] = chunk
- chunk = ""
- }
- return i, result
+func isNotToken(r rune) bool {
+ return !isToken(r)
}
diff --git a/src/pkg/net/http/lex_test.go b/src/pkg/net/http/lex_test.go
index 5386f7534..6d9d294f7 100644
--- a/src/pkg/net/http/lex_test.go
+++ b/src/pkg/net/http/lex_test.go
@@ -8,63 +8,24 @@ import (
"testing"
)
-type lexTest struct {
- Raw string
- Parsed int // # of parsed characters
- Result []string
-}
+func isChar(c rune) bool { return c <= 127 }
-var lexTests = []lexTest{
- {
- Raw: `"abc"def,:ghi`,
- Parsed: 13,
- Result: []string{"abcdef", "ghi"},
- },
- // My understanding of the RFC is that escape sequences outside of
- // quotes are not interpreted?
- {
- Raw: `"\t"\t"\t"`,
- Parsed: 10,
- Result: []string{"\t", "t\t"},
- },
- {
- Raw: `"\yab"\r\n`,
- Parsed: 10,
- Result: []string{"?ab", "r", "n"},
- },
- {
- Raw: "ab\f",
- Parsed: 3,
- Result: []string{"ab?"},
- },
- {
- Raw: "\"ab \" c,de f, gh, ij\n\t\r",
- Parsed: 23,
- Result: []string{"ab ", "c", "de", "f", "gh", "ij"},
- },
-}
+func isCtl(c rune) bool { return c <= 31 || c == 127 }
-func min(x, y int) int {
- if x <= y {
- return x
+func isSeparator(c rune) bool {
+ switch c {
+ case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t':
+ return true
}
- return y
+ return false
}
-func TestSplitFieldValue(t *testing.T) {
- for k, l := range lexTests {
- parsed, result := httpSplitFieldValue(l.Raw)
- if parsed != l.Parsed {
- t.Errorf("#%d: Parsed %d, expected %d", k, parsed, l.Parsed)
- }
- if len(result) != len(l.Result) {
- t.Errorf("#%d: Result len %d, expected %d", k, len(result), len(l.Result))
- }
- for i := 0; i < min(len(result), len(l.Result)); i++ {
- if result[i] != l.Result[i] {
- t.Errorf("#%d: %d-th entry mismatch. Have {%s}, expect {%s}",
- k, i, result[i], l.Result[i])
- }
+func TestIsToken(t *testing.T) {
+ for i := 0; i <= 130; i++ {
+ r := rune(i)
+ expected := isChar(r) && !isCtl(r) && !isSeparator(r)
+ if isToken(r) != expected {
+ t.Errorf("isToken(0x%x) = %v", r, !expected)
}
}
}
diff --git a/src/pkg/net/http/npn_test.go b/src/pkg/net/http/npn_test.go
new file mode 100644
index 000000000..98b8930d0
--- /dev/null
+++ b/src/pkg/net/http/npn_test.go
@@ -0,0 +1,118 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bufio"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "io/ioutil"
+ . "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestNextProtoUpgrade(t *testing.T) {
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
+ if r.TLS != nil {
+ w.Write([]byte(r.TLS.NegotiatedProtocol))
+ }
+ if r.RemoteAddr == "" {
+ t.Error("request with no RemoteAddr")
+ }
+ if r.Body == nil {
+ t.Errorf("request with nil Body")
+ }
+ }))
+ ts.TLS = &tls.Config{
+ NextProtos: []string{"unhandled-proto", "tls-0.9"},
+ }
+ ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){
+ "tls-0.9": handleTLSProtocol09,
+ }
+ ts.StartTLS()
+ defer ts.Close()
+
+ tr := newTLSTransport(t, ts)
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ // Normal request, without NPN.
+ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "path=/,proto="; string(body) != want {
+ t.Errorf("plain request = %q; want %q", body, want)
+ }
+ }
+
+ // Request to an advertised but unhandled NPN protocol.
+ // Server will hang up.
+ {
+ tr.CloseIdleConnections()
+ tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"}
+ _, err := c.Get(ts.URL)
+ if err == nil {
+ t.Errorf("expected error on unhandled-proto request")
+ }
+ }
+
+ // Request using the "tls-0.9" protocol, which we register here.
+ // It is HTTP/0.9 over TLS.
+ {
+ tlsConfig := newTLSTransport(t, ts).TLSClientConfig
+ tlsConfig.NextProtos = []string{"tls-0.9"}
+ conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.Write([]byte("GET /foo\n"))
+ body, err := ioutil.ReadAll(conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "path=/foo,proto=tls-0.9"; string(body) != want {
+ t.Errorf("plain request = %q; want %q", body, want)
+ }
+ }
+}
+
+// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the
+// TestNextProtoUpgrade test.
+func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) {
+ br := bufio.NewReader(conn)
+ line, err := br.ReadString('\n')
+ if err != nil {
+ return
+ }
+ line = strings.TrimSpace(line)
+ path := strings.TrimPrefix(line, "GET ")
+ if path == line {
+ return
+ }
+ req, _ := NewRequest("GET", path, nil)
+ req.Proto = "HTTP/0.9"
+ req.ProtoMajor = 0
+ req.ProtoMinor = 9
+ rw := &http09Writer{conn, make(Header)}
+ h.ServeHTTP(rw, req)
+}
+
+type http09Writer struct {
+ io.Writer
+ h Header
+}
+
+func (w http09Writer) Header() Header { return w.h }
+func (w http09Writer) WriteHeader(int) {} // no headers
diff --git a/src/pkg/net/http/pprof/pprof.go b/src/pkg/net/http/pprof/pprof.go
index 06fcde144..0c7548e3e 100644
--- a/src/pkg/net/http/pprof/pprof.go
+++ b/src/pkg/net/http/pprof/pprof.go
@@ -14,6 +14,14 @@
// To use pprof, link this package into your program:
// import _ "net/http/pprof"
//
+// If your application is not already running an http server, you
+// need to start one. Add "net/http" and "log" to your imports and
+// the following code to your main function:
+//
+// go func() {
+// log.Println(http.ListenAndServe("localhost:6060", nil))
+// }()
+//
// Then use the pprof tool to look at the heap profile:
//
// go tool pprof http://localhost:6060/debug/pprof/heap
@@ -22,9 +30,12 @@
//
// go tool pprof http://localhost:6060/debug/pprof/profile
//
-// Or to view all available profiles:
+// Or to look at the goroutine blocking profile:
+//
+// go tool pprof http://localhost:6060/debug/pprof/block
//
-// go tool pprof http://localhost:6060/debug/pprof/
+// To view all available profiles, open http://localhost:6060/debug/pprof/
+// in your browser.
//
// For a study of the facility in action, visit
//
@@ -161,7 +172,7 @@ func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// listing the available profiles.
func Index(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/debug/pprof/") {
- name := r.URL.Path[len("/debug/pprof/"):]
+ name := strings.TrimPrefix(r.URL.Path, "/debug/pprof/")
if name != "" {
handler(name).ServeHTTP(w, r)
return
diff --git a/src/pkg/net/http/proxy_test.go b/src/pkg/net/http/proxy_test.go
index 5ecffafac..449ccaeea 100644
--- a/src/pkg/net/http/proxy_test.go
+++ b/src/pkg/net/http/proxy_test.go
@@ -25,13 +25,13 @@ var UseProxyTests = []struct {
{"[::2]", true}, // not a loopback address
{"barbaz.net", false}, // match as .barbaz.net
- {"foobar.com", false}, // have a port but match
+ {"foobar.com", false}, // have a port but match
{"foofoobar.com", true}, // not match as a part of foobar.com
{"baz.com", true}, // not match as a part of barbaz.com
{"localhost.net", true}, // not match as suffix of address
{"local.localhost", true}, // not match as prefix as address
{"barbarbaz.net", true}, // not match because NO_PROXY have a '.'
- {"www.foobar.com", true}, // not match because NO_PROXY is not .foobar.com
+ {"www.foobar.com", false}, // match because NO_PROXY includes "foobar.com"
}
func TestUseProxy(t *testing.T) {
diff --git a/src/pkg/net/http/range_test.go b/src/pkg/net/http/range_test.go
index 5274a81fa..ef911af7b 100644
--- a/src/pkg/net/http/range_test.go
+++ b/src/pkg/net/http/range_test.go
@@ -14,15 +14,34 @@ var ParseRangeTests = []struct {
r []httpRange
}{
{"", 0, nil},
+ {"", 1000, nil},
{"foo", 0, nil},
{"bytes=", 0, nil},
+ {"bytes=7", 10, nil},
+ {"bytes= 7 ", 10, nil},
+ {"bytes=1-", 0, nil},
{"bytes=5-4", 10, nil},
{"bytes=0-2,5-4", 10, nil},
+ {"bytes=2-5,4-3", 10, nil},
+ {"bytes=--5,4--3", 10, nil},
+ {"bytes=A-", 10, nil},
+ {"bytes=A- ", 10, nil},
+ {"bytes=A-Z", 10, nil},
+ {"bytes= -Z", 10, nil},
+ {"bytes=5-Z", 10, nil},
+ {"bytes=Ran-dom, garbage", 10, nil},
+ {"bytes=0x01-0x02", 10, nil},
+ {"bytes= ", 10, nil},
+ {"bytes= , , , ", 10, nil},
+
{"bytes=0-9", 10, []httpRange{{0, 10}}},
{"bytes=0-", 10, []httpRange{{0, 10}}},
{"bytes=5-", 10, []httpRange{{5, 5}}},
{"bytes=0-20", 10, []httpRange{{0, 10}}},
{"bytes=15-,0-5", 10, nil},
+ {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}},
+ {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}},
+ {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}},
{"bytes=-5", 10, []httpRange{{5, 5}}},
{"bytes=-15", 10, []httpRange{{0, 10}}},
{"bytes=0-499", 10000, []httpRange{{0, 500}}},
@@ -32,6 +51,9 @@ var ParseRangeTests = []struct {
{"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}},
{"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}},
{"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}},
+
+ // Match Apache laxity:
+ {"bytes= 1 -2 , 4- 5, 7 - 8 , ,,", 11, []httpRange{{1, 2}, {4, 2}, {7, 2}}},
}
func TestParseRange(t *testing.T) {
diff --git a/src/pkg/net/http/readrequest_test.go b/src/pkg/net/http/readrequest_test.go
index 2e03c658a..ffdd6a892 100644
--- a/src/pkg/net/http/readrequest_test.go
+++ b/src/pkg/net/http/readrequest_test.go
@@ -247,6 +247,54 @@ var reqTests = []reqTest{
noTrailer,
noError,
},
+
+ // SSDP Notify request. golang.org/issue/3692
+ {
+ "NOTIFY * HTTP/1.1\r\nServer: foo\r\n\r\n",
+ &Request{
+ Method: "NOTIFY",
+ URL: &url.URL{
+ Path: "*",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Server": []string{"foo"},
+ },
+ Close: false,
+ ContentLength: 0,
+ RequestURI: "*",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
+
+ // OPTIONS request. Similar to golang.org/issue/3692
+ {
+ "OPTIONS * HTTP/1.1\r\nServer: foo\r\n\r\n",
+ &Request{
+ Method: "OPTIONS",
+ URL: &url.URL{
+ Path: "*",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Server": []string{"foo"},
+ },
+ Close: false,
+ ContentLength: 0,
+ RequestURI: "*",
+ },
+
+ noBody,
+ noTrailer,
+ noError,
+ },
}
func TestReadRequest(t *testing.T) {
diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go
index f5bc6eb91..217f35b48 100644
--- a/src/pkg/net/http/request.go
+++ b/src/pkg/net/http/request.go
@@ -19,6 +19,7 @@ import (
"mime/multipart"
"net/textproto"
"net/url"
+ "strconv"
"strings"
)
@@ -70,7 +71,13 @@ var reqWriteExcludeHeader = map[string]bool{
// or to be sent by a client.
type Request struct {
Method string // GET, POST, PUT, etc.
- URL *url.URL
+
+ // URL is created from the URI supplied on the Request-Line
+ // as stored in RequestURI.
+ //
+ // For most requests, fields other than Path and RawQuery
+ // will be empty. (See RFC 2616, Section 5.1.2)
+ URL *url.URL
// The protocol version for incoming requests.
// Outgoing requests always use HTTP/1.1.
@@ -123,6 +130,7 @@ type Request struct {
// The host on which the URL is sought.
// Per RFC 2616, this is either the value of the Host: header
// or the host name given in the URL itself.
+ // It may be of the form "host:port".
Host string
// Form contains the parsed form data, including both the URL
@@ -131,6 +139,12 @@ type Request struct {
// The HTTP client ignores Form and uses Body instead.
Form url.Values
+ // PostForm contains the parsed form data from POST or PUT
+ // body parameters.
+ // This field is only available after ParseForm is called.
+ // The HTTP client ignores PostForm and uses Body instead.
+ PostForm url.Values
+
// MultipartForm is the parsed multipart form, including file uploads.
// This field is only available after ParseMultipartForm is called.
// The HTTP client ignores MultipartForm and uses Body instead.
@@ -317,11 +331,20 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
}
// TODO(bradfitz): escape at least newlines in ruri?
- bw := bufio.NewWriter(w)
- fmt.Fprintf(bw, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
+ // Wrap the writer in a bufio Writer if it's not already buffered.
+ // Don't always call NewWriter, as that forces a bytes.Buffer
+ // and other small bufio Writers to have a minimum 4k buffer
+ // size.
+ var bw *bufio.Writer
+ if _, ok := w.(io.ByteWriter); !ok {
+ bw = bufio.NewWriter(w)
+ w = bw
+ }
+
+ fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
// Header lines
- fmt.Fprintf(bw, "Host: %s\r\n", host)
+ fmt.Fprintf(w, "Host: %s\r\n", host)
// Use the defaultUserAgent unless the Header contains one, which
// may be blank to not send the header.
@@ -332,7 +355,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
}
}
if userAgent != "" {
- fmt.Fprintf(bw, "User-Agent: %s\r\n", userAgent)
+ fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
}
// Process Body,ContentLength,Close,Trailer
@@ -340,65 +363,61 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
if err != nil {
return err
}
- err = tw.WriteHeader(bw)
+ err = tw.WriteHeader(w)
if err != nil {
return err
}
// TODO: split long values? (If so, should share code with Conn.Write)
- err = req.Header.WriteSubset(bw, reqWriteExcludeHeader)
+ err = req.Header.WriteSubset(w, reqWriteExcludeHeader)
if err != nil {
return err
}
if extraHeaders != nil {
- err = extraHeaders.Write(bw)
+ err = extraHeaders.Write(w)
if err != nil {
return err
}
}
- io.WriteString(bw, "\r\n")
+ io.WriteString(w, "\r\n")
// Write body and trailer
- err = tw.WriteBody(bw)
+ err = tw.WriteBody(w)
if err != nil {
return err
}
- return bw.Flush()
-}
-
-// Convert decimal at s[i:len(s)] to integer,
-// returning value, string position where the digits stopped,
-// and whether there was a valid number (digits, not too big).
-func atoi(s string, i int) (n, i1 int, ok bool) {
- const Big = 1000000
- if i >= len(s) || s[i] < '0' || s[i] > '9' {
- return 0, 0, false
- }
- n = 0
- for ; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
- n = n*10 + int(s[i]-'0')
- if n > Big {
- return 0, 0, false
- }
+ if bw != nil {
+ return bw.Flush()
}
- return n, i, true
+ return nil
}
// ParseHTTPVersion parses a HTTP version string.
// "HTTP/1.0" returns (1, 0, true).
func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
- if len(vers) < 5 || vers[0:5] != "HTTP/" {
+ const Big = 1000000 // arbitrary upper bound
+ switch vers {
+ case "HTTP/1.1":
+ return 1, 1, true
+ case "HTTP/1.0":
+ return 1, 0, true
+ }
+ if !strings.HasPrefix(vers, "HTTP/") {
return 0, 0, false
}
- major, i, ok := atoi(vers, 5)
- if !ok || i >= len(vers) || vers[i] != '.' {
+ dot := strings.Index(vers, ".")
+ if dot < 0 {
return 0, 0, false
}
- minor, i, ok = atoi(vers, i+1)
- if !ok || i != len(vers) {
+ major, err := strconv.Atoi(vers[5:dot])
+ if err != nil || major < 0 || major > Big {
+ return 0, 0, false
+ }
+ minor, err = strconv.Atoi(vers[dot+1:])
+ if err != nil || minor < 0 || minor > Big {
return 0, 0, false
}
return major, minor, true
@@ -426,10 +445,12 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
}
if body != nil {
switch v := body.(type) {
- case *strings.Reader:
- req.ContentLength = int64(v.Len())
case *bytes.Buffer:
req.ContentLength = int64(v.Len())
+ case *bytes.Reader:
+ req.ContentLength = int64(v.Len())
+ case *strings.Reader:
+ req.ContentLength = int64(v.Len())
}
}
@@ -513,9 +534,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) {
// the same. In the second case, any Host line is ignored.
req.Host = req.URL.Host
if req.Host == "" {
- req.Host = req.Header.Get("Host")
+ req.Host = req.Header.get("Host")
}
- req.Header.Del("Host")
+ delete(req.Header, "Host")
fixPragmaCacheControl(req.Header)
@@ -594,66 +615,97 @@ func (l *maxBytesReader) Close() error {
return l.r.Close()
}
-// ParseForm parses the raw query from the URL.
+func copyValues(dst, src url.Values) {
+ for k, vs := range src {
+ for _, value := range vs {
+ dst.Add(k, value)
+ }
+ }
+}
+
+func parsePostForm(r *Request) (vs url.Values, err error) {
+ if r.Body == nil {
+ err = errors.New("missing form body")
+ return
+ }
+ ct := r.Header.Get("Content-Type")
+ ct, _, err = mime.ParseMediaType(ct)
+ switch {
+ case ct == "application/x-www-form-urlencoded":
+ var reader io.Reader = r.Body
+ maxFormSize := int64(1<<63 - 1)
+ if _, ok := r.Body.(*maxBytesReader); !ok {
+ maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
+ reader = io.LimitReader(r.Body, maxFormSize+1)
+ }
+ b, e := ioutil.ReadAll(reader)
+ if e != nil {
+ if err == nil {
+ err = e
+ }
+ break
+ }
+ if int64(len(b)) > maxFormSize {
+ err = errors.New("http: POST too large")
+ return
+ }
+ vs, e = url.ParseQuery(string(b))
+ if err == nil {
+ err = e
+ }
+ case ct == "multipart/form-data":
+ // handled by ParseMultipartForm (which is calling us, or should be)
+ // TODO(bradfitz): there are too many possible
+ // orders to call too many functions here.
+ // Clean this up and write more tests.
+ // request_test.go contains the start of this,
+ // in TestRequestMultipartCallOrder.
+ }
+ return
+}
+
+// ParseForm parses the raw query from the URL and updates r.Form.
+//
+// For POST or PUT requests, it also parses the request body as a form and
+// put the results into both r.PostForm and r.Form.
+// POST and PUT body parameters take precedence over URL query string values
+// in r.Form.
//
-// For POST or PUT requests, it also parses the request body as a form.
// If the request Body's size has not already been limited by MaxBytesReader,
// the size is capped at 10MB.
//
// ParseMultipartForm calls ParseForm automatically.
// It is idempotent.
-func (r *Request) ParseForm() (err error) {
- if r.Form != nil {
- return
- }
- if r.URL != nil {
- r.Form, err = url.ParseQuery(r.URL.RawQuery)
+func (r *Request) ParseForm() error {
+ var err error
+ if r.PostForm == nil {
+ if r.Method == "POST" || r.Method == "PUT" {
+ r.PostForm, err = parsePostForm(r)
+ }
+ if r.PostForm == nil {
+ r.PostForm = make(url.Values)
+ }
}
- if r.Method == "POST" || r.Method == "PUT" {
- if r.Body == nil {
- return errors.New("missing form body")
+ if r.Form == nil {
+ if len(r.PostForm) > 0 {
+ r.Form = make(url.Values)
+ copyValues(r.Form, r.PostForm)
}
- ct := r.Header.Get("Content-Type")
- ct, _, err = mime.ParseMediaType(ct)
- switch {
- case ct == "application/x-www-form-urlencoded":
- var reader io.Reader = r.Body
- maxFormSize := int64(1<<63 - 1)
- if _, ok := r.Body.(*maxBytesReader); !ok {
- maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
- reader = io.LimitReader(r.Body, maxFormSize+1)
- }
- b, e := ioutil.ReadAll(reader)
- if e != nil {
- if err == nil {
- err = e
- }
- break
- }
- if int64(len(b)) > maxFormSize {
- return errors.New("http: POST too large")
- }
- var newValues url.Values
- newValues, e = url.ParseQuery(string(b))
+ var newValues url.Values
+ if r.URL != nil {
+ var e error
+ newValues, e = url.ParseQuery(r.URL.RawQuery)
if err == nil {
err = e
}
- if r.Form == nil {
- r.Form = make(url.Values)
- }
- // Copy values into r.Form. TODO: make this smoother.
- for k, vs := range newValues {
- for _, value := range vs {
- r.Form.Add(k, value)
- }
- }
- case ct == "multipart/form-data":
- // handled by ParseMultipartForm (which is calling us, or should be)
- // TODO(bradfitz): there are too many possible
- // orders to call too many functions here.
- // Clean this up and write more tests.
- // request_test.go contains the start of this,
- // in TestRequestMultipartCallOrder.
+ }
+ if newValues == nil {
+ newValues = make(url.Values)
+ }
+ if r.Form == nil {
+ r.Form = newValues
+ } else {
+ copyValues(r.Form, newValues)
}
}
return err
@@ -699,7 +751,9 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error {
}
// FormValue returns the first value for the named component of the query.
+// POST and PUT body parameters take precedence over URL query string values.
// FormValue calls ParseMultipartForm and ParseForm if necessary.
+// To access multiple values of the same key use ParseForm.
func (r *Request) FormValue(key string) string {
if r.Form == nil {
r.ParseMultipartForm(defaultMaxMemory)
@@ -710,6 +764,19 @@ func (r *Request) FormValue(key string) string {
return ""
}
+// PostFormValue returns the first value for the named component of the POST
+// or PUT request body. URL query parameters are ignored.
+// PostFormValue calls ParseMultipartForm and ParseForm if necessary.
+func (r *Request) PostFormValue(key string) string {
+ if r.PostForm == nil {
+ r.ParseMultipartForm(defaultMaxMemory)
+ }
+ if vs := r.PostForm[key]; len(vs) > 0 {
+ return vs[0]
+ }
+ return ""
+}
+
// FormFile returns the first file for the provided form key.
// FormFile calls ParseMultipartForm and ParseForm if necessary.
func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) {
@@ -732,12 +799,16 @@ func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, e
}
func (r *Request) expectsContinue() bool {
- return strings.ToLower(r.Header.Get("Expect")) == "100-continue"
+ return hasToken(r.Header.get("Expect"), "100-continue")
}
func (r *Request) wantsHttp10KeepAlive() bool {
if r.ProtoMajor != 1 || r.ProtoMinor != 0 {
return false
}
- return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "keep-alive")
+ return hasToken(r.Header.get("Connection"), "keep-alive")
+}
+
+func (r *Request) wantsClose() bool {
+ return hasToken(r.Header.get("Connection"), "close")
}
diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go
index 6e00b9bfd..00ad791de 100644
--- a/src/pkg/net/http/request_test.go
+++ b/src/pkg/net/http/request_test.go
@@ -30,8 +30,8 @@ func TestQuery(t *testing.T) {
}
func TestPostQuery(t *testing.T) {
- req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x",
- strings.NewReader("z=post&both=y"))
+ req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not",
+ strings.NewReader("z=post&both=y&prio=2&empty="))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
if q := req.FormValue("q"); q != "foo" {
@@ -40,8 +40,23 @@ func TestPostQuery(t *testing.T) {
if z := req.FormValue("z"); z != "post" {
t.Errorf(`req.FormValue("z") = %q, want "post"`, z)
}
- if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"x", "y"}) {
- t.Errorf(`req.FormValue("both") = %q, want ["x", "y"]`, both)
+ if bq, found := req.PostForm["q"]; found {
+ t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq)
+ }
+ if bz := req.PostFormValue("z"); bz != "post" {
+ t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz)
+ }
+ if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) {
+ t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs)
+ }
+ if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) {
+ t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both)
+ }
+ if prio := req.FormValue("prio"); prio != "2" {
+ t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio)
+ }
+ if empty := req.FormValue("empty"); empty != "" {
+ t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty)
}
}
@@ -76,6 +91,23 @@ func TestParseFormUnknownContentType(t *testing.T) {
}
}
+func TestParseFormInitializeOnError(t *testing.T) {
+ nilBody, _ := NewRequest("POST", "http://www.google.com/search?q=foo", nil)
+ tests := []*Request{
+ nilBody,
+ {Method: "GET", URL: nil},
+ }
+ for i, req := range tests {
+ err := req.ParseForm()
+ if req.Form == nil {
+ t.Errorf("%d. Form not initialized, error %v", i, err)
+ }
+ if req.PostForm == nil {
+ t.Errorf("%d. PostForm not initialized, error %v", i, err)
+ }
+ }
+}
+
func TestMultipartReader(t *testing.T) {
req := &Request{
Method: "POST",
@@ -129,7 +161,7 @@ func TestSetBasicAuth(t *testing.T) {
}
func TestMultipartRequest(t *testing.T) {
- // Test that we can read the values and files of a
+ // Test that we can read the values and files of a
// multipart request with FormValue and FormFile,
// and that ParseMultipartForm can be called multiple times.
req := newTestMultipartRequest(t)
@@ -196,6 +228,75 @@ func TestReadRequestErrors(t *testing.T) {
}
}
+func TestNewRequestHost(t *testing.T) {
+ req, err := NewRequest("GET", "http://localhost:1234/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if req.Host != "localhost:1234" {
+ t.Errorf("Host = %q; want localhost:1234", req.Host)
+ }
+}
+
+func TestNewRequestContentLength(t *testing.T) {
+ readByte := func(r io.Reader) io.Reader {
+ var b [1]byte
+ r.Read(b[:])
+ return r
+ }
+ tests := []struct {
+ r io.Reader
+ want int64
+ }{
+ {bytes.NewReader([]byte("123")), 3},
+ {bytes.NewBuffer([]byte("1234")), 4},
+ {strings.NewReader("12345"), 5},
+ // Not detected:
+ {struct{ io.Reader }{strings.NewReader("xyz")}, 0},
+ {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0},
+ {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0},
+ }
+ for _, tt := range tests {
+ req, err := NewRequest("POST", "http://localhost/", tt.r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if req.ContentLength != tt.want {
+ t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want)
+ }
+ }
+}
+
+type logWrites struct {
+ t *testing.T
+ dst *[]string
+}
+
+func (l logWrites) WriteByte(c byte) error {
+ l.t.Fatalf("unexpected WriteByte call")
+ return nil
+}
+
+func (l logWrites) Write(p []byte) (n int, err error) {
+ *l.dst = append(*l.dst, string(p))
+ return len(p), nil
+}
+
+func TestRequestWriteBufferedWriter(t *testing.T) {
+ got := []string{}
+ req, _ := NewRequest("GET", "http://foo.com/", nil)
+ req.Write(logWrites{t, &got})
+ want := []string{
+ "GET / HTTP/1.1\r\n",
+ "Host: foo.com\r\n",
+ "User-Agent: Go http package\r\n",
+ "\r\n",
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Writes = %q\n Want = %q", got, want)
+ }
+}
+
func testMissingFile(t *testing.T, req *Request) {
f, fh, err := req.FormFile("missing")
if f != nil {
@@ -300,3 +401,81 @@ Content-Disposition: form-data; name="textb"
` + textbValue + `
--MyBoundary--
`
+
+func benchmarkReadRequest(b *testing.B, request string) {
+ request = request + "\n" // final \n
+ request = strings.Replace(request, "\n", "\r\n", -1) // expand \n to \r\n
+ b.SetBytes(int64(len(request)))
+ r := bufio.NewReader(&infiniteReader{buf: []byte(request)})
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := ReadRequest(r)
+ if err != nil {
+ b.Fatalf("failed to read request: %v", err)
+ }
+ }
+}
+
+// infiniteReader satisfies Read requests as if the contents of buf
+// loop indefinitely.
+type infiniteReader struct {
+ buf []byte
+ offset int
+}
+
+func (r *infiniteReader) Read(b []byte) (int, error) {
+ n := copy(b, r.buf[r.offset:])
+ r.offset = (r.offset + n) % len(r.buf)
+ return n, nil
+}
+
+func BenchmarkReadRequestChrome(b *testing.B) {
+ // https://github.com/felixge/node-http-perf/blob/master/fixtures/get.http
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+Connection: keep-alive
+Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
+Accept-Encoding: gzip,deflate,sdch
+Accept-Language: en-US,en;q=0.8
+Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
+Cookie: __utma=1.1978842379.1323102373.1323102373.1323102373.1; EPi:NumberOfVisits=1,2012-02-28T13:42:18; CrmSession=5b707226b9563e1bc69084d07a107c98; plushContainerWidth=100%25; plushNoTopMenu=0; hudson_auto_refresh=false
+`)
+}
+
+func BenchmarkReadRequestCurl(b *testing.B) {
+ // curl http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+User-Agent: curl/7.27.0
+Host: localhost:8080
+Accept: */*
+`)
+}
+
+func BenchmarkReadRequestApachebench(b *testing.B) {
+ // ab -n 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.0
+Host: localhost:8080
+User-Agent: ApacheBench/2.3
+Accept: */*
+`)
+}
+
+func BenchmarkReadRequestSiege(b *testing.B) {
+ // siege -r 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+Accept: */*
+Accept-Encoding: gzip
+User-Agent: JoeDog/1.00 [en] (X11; I; Siege 2.70)
+Connection: keep-alive
+`)
+}
+
+func BenchmarkReadRequestWrk(b *testing.B) {
+ // wrk -t 1 -r 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+`)
+}
diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go
index fc3186f0c..bc637f18b 100644
--- a/src/pkg/net/http/requestwrite_test.go
+++ b/src/pkg/net/http/requestwrite_test.go
@@ -328,6 +328,69 @@ var reqWriteTests = []reqWriteTest{
"User-Agent: Go http package\r\n" +
"X-Foo: X-Bar\r\n\r\n",
},
+
+ // If no Request.Host and no Request.URL.Host, we send
+ // an empty Host header, and don't use
+ // Request.Header["Host"]. This is just testing that
+ // we don't change Go 1.0 behavior.
+ {
+ Req: Request{
+ Method: "GET",
+ Host: "",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Host": []string{"bad.example.com"},
+ },
+ },
+
+ WantWrite: "GET /search HTTP/1.1\r\n" +
+ "Host: \r\n" +
+ "User-Agent: Go http package\r\n\r\n",
+ },
+
+ // Opaque test #1 from golang.org/issue/4860
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Opaque: "/%2F/%2F/",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ },
+
+ WantWrite: "GET /%2F/%2F/ HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go http package\r\n\r\n",
+ },
+
+ // Opaque test #2 from golang.org/issue/4860
+ {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "x.google.com",
+ Opaque: "//y.google.com/%2F/%2F/",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ },
+
+ WantWrite: "GET http://y.google.com/%2F/%2F/ HTTP/1.1\r\n" +
+ "Host: x.google.com\r\n" +
+ "User-Agent: Go http package\r\n\r\n",
+ },
}
func TestRequestWrite(t *testing.T) {
diff --git a/src/pkg/net/http/response.go b/src/pkg/net/http/response.go
index 945ecd8a4..391ebbf6d 100644
--- a/src/pkg/net/http/response.go
+++ b/src/pkg/net/http/response.go
@@ -49,7 +49,7 @@ type Response struct {
Body io.ReadCloser
// ContentLength records the length of the associated content. The
- // value -1 indicates that the length is unknown. Unless RequestMethod
+ // value -1 indicates that the length is unknown. Unless Request.Method
// is "HEAD", values >= 0 indicate that the given number of bytes may
// be read from Body.
ContentLength int64
@@ -107,7 +107,6 @@ func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err error) {
resp = new(Response)
resp.Request = req
- resp.Request.Method = strings.ToUpper(resp.Request.Method)
// Parse the first line of the response.
line, err := tp.ReadLine()
@@ -179,7 +178,7 @@ func (r *Response) ProtoAtLeast(major, minor int) bool {
// StatusCode
// ProtoMajor
// ProtoMinor
-// RequestMethod
+// Request.Method
// TransferEncoding
// Trailer
// Body
@@ -188,11 +187,6 @@ func (r *Response) ProtoAtLeast(major, minor int) bool {
//
func (r *Response) Write(w io.Writer) error {
- // RequestMethod should be upper-case
- if r.Request != nil {
- r.Request.Method = strings.ToUpper(r.Request.Method)
- }
-
// Status line
text := r.Status
if text == "" {
@@ -204,9 +198,7 @@ func (r *Response) Write(w io.Writer) error {
}
protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor)
statusCode := strconv.Itoa(r.StatusCode) + " "
- if strings.HasPrefix(text, statusCode) {
- text = text[len(statusCode):]
- }
+ text = strings.TrimPrefix(text, statusCode)
io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n")
// Process Body,ContentLength,Close,Trailer
diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go
index 6eed4887d..2f5f77369 100644
--- a/src/pkg/net/http/response_test.go
+++ b/src/pkg/net/http/response_test.go
@@ -124,7 +124,7 @@ var respTests = []respTest{
// Chunked response without Content-Length.
{
- "HTTP/1.0 200 OK\r\n" +
+ "HTTP/1.1 200 OK\r\n" +
"Transfer-Encoding: chunked\r\n" +
"\r\n" +
"0a\r\n" +
@@ -137,12 +137,12 @@ var respTests = []respTest{
Response{
Status: "200 OK",
StatusCode: 200,
- Proto: "HTTP/1.0",
+ Proto: "HTTP/1.1",
ProtoMajor: 1,
- ProtoMinor: 0,
+ ProtoMinor: 1,
Request: dummyReq("GET"),
Header: Header{},
- Close: true,
+ Close: false,
ContentLength: -1,
TransferEncoding: []string{"chunked"},
},
@@ -152,24 +152,24 @@ var respTests = []respTest{
// Chunked response with Content-Length.
{
- "HTTP/1.0 200 OK\r\n" +
+ "HTTP/1.1 200 OK\r\n" +
"Transfer-Encoding: chunked\r\n" +
"Content-Length: 10\r\n" +
"\r\n" +
"0a\r\n" +
- "Body here\n" +
+ "Body here\n\r\n" +
"0\r\n" +
"\r\n",
Response{
Status: "200 OK",
StatusCode: 200,
- Proto: "HTTP/1.0",
+ Proto: "HTTP/1.1",
ProtoMajor: 1,
- ProtoMinor: 0,
+ ProtoMinor: 1,
Request: dummyReq("GET"),
Header: Header{},
- Close: true,
+ Close: false,
ContentLength: -1, // TODO(rsc): Fix?
TransferEncoding: []string{"chunked"},
},
@@ -177,23 +177,88 @@ var respTests = []respTest{
"Body here\n",
},
- // Chunked response in response to a HEAD request (the "chunked" should
- // be ignored, as HEAD responses never have bodies)
+ // Chunked response in response to a HEAD request
{
- "HTTP/1.0 200 OK\r\n" +
+ "HTTP/1.1 200 OK\r\n" +
"Transfer-Encoding: chunked\r\n" +
"\r\n",
Response{
- Status: "200 OK",
- StatusCode: 200,
- Proto: "HTTP/1.0",
- ProtoMajor: 1,
- ProtoMinor: 0,
- Request: dummyReq("HEAD"),
- Header: Header{},
- Close: true,
- ContentLength: 0,
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("HEAD"),
+ Header: Header{},
+ TransferEncoding: []string{"chunked"},
+ Close: false,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // Content-Length in response to a HEAD request
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Content-Length: 256\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("HEAD"),
+ Header: Header{"Content-Length": {"256"}},
+ TransferEncoding: nil,
+ Close: true,
+ ContentLength: 256,
+ },
+
+ "",
+ },
+
+ // Content-Length in response to a HEAD request with HTTP/1.1
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 256\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("HEAD"),
+ Header: Header{"Content-Length": {"256"}},
+ TransferEncoding: nil,
+ Close: false,
+ ContentLength: 256,
+ },
+
+ "",
+ },
+
+ // No Content-Length or Chunked in response to a HEAD request
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("HEAD"),
+ Header: Header{},
+ TransferEncoding: nil,
+ Close: true,
+ ContentLength: -1,
},
"",
@@ -259,16 +324,37 @@ var respTests = []respTest{
"",
},
+
+ // golang.org/issue/4767: don't special-case multipart/byteranges responses
+ {
+ `HTTP/1.1 206 Partial Content
+Connection: close
+Content-Type: multipart/byteranges; boundary=18a75608c8f47cef
+
+some body`,
+ Response{
+ Status: "206 Partial Content",
+ StatusCode: 206,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Type": []string{"multipart/byteranges; boundary=18a75608c8f47cef"},
+ },
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "some body",
+ },
}
func TestReadResponse(t *testing.T) {
- for i := range respTests {
- tt := &respTests[i]
- var braw bytes.Buffer
- braw.WriteString(tt.Raw)
- resp, err := ReadResponse(bufio.NewReader(&braw), tt.Resp.Request)
+ for i, tt := range respTests {
+ resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request)
if err != nil {
- t.Errorf("#%d: %s", i, err)
+ t.Errorf("#%d: %v", i, err)
continue
}
rbody := resp.Body
@@ -276,7 +362,11 @@ func TestReadResponse(t *testing.T) {
diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp)
var bout bytes.Buffer
if rbody != nil {
- io.Copy(&bout, rbody)
+ _, err = io.Copy(&bout, rbody)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
rbody.Close()
}
body := bout.String()
@@ -286,6 +376,22 @@ func TestReadResponse(t *testing.T) {
}
}
+func TestWriteResponse(t *testing.T) {
+ for i, tt := range respTests {
+ resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ bout := bytes.NewBuffer(nil)
+ err = resp.Write(bout)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ }
+}
+
var readResponseCloseInMiddleTests = []struct {
chunked, compressed bool
}{
diff --git a/src/pkg/net/http/responsewrite_test.go b/src/pkg/net/http/responsewrite_test.go
index f8e63acf4..5c10e2161 100644
--- a/src/pkg/net/http/responsewrite_test.go
+++ b/src/pkg/net/http/responsewrite_test.go
@@ -15,83 +15,83 @@ type respWriteTest struct {
Raw string
}
-var respWriteTests = []respWriteTest{
- // HTTP/1.0, identity coding; no trailer
- {
- Response{
- StatusCode: 503,
- ProtoMajor: 1,
- ProtoMinor: 0,
- Request: dummyReq("GET"),
- Header: Header{},
- Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
- ContentLength: 6,
- },
+func TestResponseWrite(t *testing.T) {
+ respWriteTests := []respWriteTest{
+ // HTTP/1.0, identity coding; no trailer
+ {
+ Response{
+ StatusCode: 503,
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ ContentLength: 6,
+ },
- "HTTP/1.0 503 Service Unavailable\r\n" +
- "Content-Length: 6\r\n\r\n" +
- "abcdef",
- },
- // Unchunked response without Content-Length.
- {
- Response{
- StatusCode: 200,
- ProtoMajor: 1,
- ProtoMinor: 0,
- Request: dummyReq("GET"),
- Header: Header{},
- Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
- ContentLength: -1,
+ "HTTP/1.0 503 Service Unavailable\r\n" +
+ "Content-Length: 6\r\n\r\n" +
+ "abcdef",
},
- "HTTP/1.0 200 OK\r\n" +
- "\r\n" +
- "abcdef",
- },
- // HTTP/1.1, chunked coding; empty trailer; close
- {
- Response{
- StatusCode: 200,
- ProtoMajor: 1,
- ProtoMinor: 1,
- Request: dummyReq("GET"),
- Header: Header{},
- Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
- ContentLength: 6,
- TransferEncoding: []string{"chunked"},
- Close: true,
+ // Unchunked response without Content-Length.
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ ContentLength: -1,
+ },
+ "HTTP/1.0 200 OK\r\n" +
+ "\r\n" +
+ "abcdef",
},
+ // HTTP/1.1, chunked coding; empty trailer; close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")),
+ ContentLength: 6,
+ TransferEncoding: []string{"chunked"},
+ Close: true,
+ },
- "HTTP/1.1 200 OK\r\n" +
- "Connection: close\r\n" +
- "Transfer-Encoding: chunked\r\n\r\n" +
- "6\r\nabcdef\r\n0\r\n\r\n",
- },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "6\r\nabcdef\r\n0\r\n\r\n",
+ },
- // Header value with a newline character (Issue 914).
- // Also tests removal of leading and trailing whitespace.
- {
- Response{
- StatusCode: 204,
- ProtoMajor: 1,
- ProtoMinor: 1,
- Request: dummyReq("GET"),
- Header: Header{
- "Foo": []string{" Bar\nBaz "},
+ // Header value with a newline character (Issue 914).
+ // Also tests removal of leading and trailing whitespace.
+ {
+ Response{
+ StatusCode: 204,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Foo": []string{" Bar\nBaz "},
+ },
+ Body: nil,
+ ContentLength: 0,
+ TransferEncoding: []string{"chunked"},
+ Close: true,
},
- Body: nil,
- ContentLength: 0,
- TransferEncoding: []string{"chunked"},
- Close: true,
- },
- "HTTP/1.1 204 No Content\r\n" +
- "Connection: close\r\n" +
- "Foo: Bar Baz\r\n" +
- "\r\n",
- },
-}
+ "HTTP/1.1 204 No Content\r\n" +
+ "Connection: close\r\n" +
+ "Foo: Bar Baz\r\n" +
+ "\r\n",
+ },
+ }
-func TestResponseWrite(t *testing.T) {
for i := range respWriteTests {
tt := &respWriteTests[i]
var braw bytes.Buffer
diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go
index b6a6b4c77..3300fef59 100644
--- a/src/pkg/net/http/serve_test.go
+++ b/src/pkg/net/http/serve_test.go
@@ -20,8 +20,13 @@ import (
"net/http/httputil"
"net/url"
"os"
+ "os/exec"
"reflect"
+ "runtime"
+ "strconv"
"strings"
+ "sync"
+ "sync/atomic"
"syscall"
"testing"
"time"
@@ -62,6 +67,7 @@ func (a dummyAddr) String() string {
type testConn struct {
readBuf bytes.Buffer
writeBuf bytes.Buffer
+ closec chan bool // if non-nil, send value to it on close
}
func (c *testConn) Read(b []byte) (int, error) {
@@ -73,6 +79,10 @@ func (c *testConn) Write(b []byte) (int, error) {
}
func (c *testConn) Close() error {
+ select {
+ case c.closec <- true:
+ default:
+ }
return nil
}
@@ -168,13 +178,18 @@ var vtests = []struct {
{"http://someHost.com/someDir/apage", "someHost.com/someDir"},
{"http://otherHost.com/someDir/apage", "someDir"},
{"http://otherHost.com/aDir/apage", "Default"},
+ // redirections for trees
+ {"http://localhost/someDir", "/someDir/"},
+ {"http://someHost.com/someDir", "/someDir/"},
}
func TestHostHandlers(t *testing.T) {
+ defer checkLeakedTransports(t)
+ mux := NewServeMux()
for _, h := range handlers {
- Handle(h.pattern, stringHandler(h.msg))
+ mux.Handle(h.pattern, stringHandler(h.msg))
}
- ts := httptest.NewServer(nil)
+ ts := httptest.NewServer(mux)
defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -199,9 +214,19 @@ func TestHostHandlers(t *testing.T) {
t.Errorf("reading response: %v", err)
continue
}
- s := r.Header.Get("Result")
- if s != vt.expected {
- t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ switch r.StatusCode {
+ case StatusOK:
+ s := r.Header.Get("Result")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ case StatusMovedPermanently:
+ s := r.Header.Get("Location")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ default:
+ t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
}
}
}
@@ -232,28 +257,22 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) {
}
func TestServerTimeouts(t *testing.T) {
- // TODO(bradfitz): convert this to use httptest.Server
- l, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- t.Fatalf("listen error: %v", err)
- }
- addr, _ := l.Addr().(*net.TCPAddr)
-
+ defer checkLeakedTransports(t)
reqNum := 0
- handler := HandlerFunc(func(res ResponseWriter, req *Request) {
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
reqNum++
fmt.Fprintf(res, "req=%d", reqNum)
- })
-
- server := &Server{Handler: handler, ReadTimeout: 250 * time.Millisecond, WriteTimeout: 250 * time.Millisecond}
- go server.Serve(l)
-
- url := fmt.Sprintf("http://%s/", addr)
+ }))
+ ts.Config.ReadTimeout = 250 * time.Millisecond
+ ts.Config.WriteTimeout = 250 * time.Millisecond
+ ts.Start()
+ defer ts.Close()
// Hit the HTTP server successfully.
tr := &Transport{DisableKeepAlives: true} // they interfere with this test
+ defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
- r, err := c.Get(url)
+ r, err := c.Get(ts.URL)
if err != nil {
t.Fatalf("http Get #1: %v", err)
}
@@ -266,13 +285,13 @@ func TestServerTimeouts(t *testing.T) {
// Slow client that should timeout.
t1 := time.Now()
- conn, err := net.Dial("tcp", addr.String())
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
buf := make([]byte, 1)
n, err := conn.Read(buf)
- latency := time.Now().Sub(t1)
+ latency := time.Since(t1)
if n != 0 || err != io.EOF {
t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
}
@@ -283,7 +302,7 @@ func TestServerTimeouts(t *testing.T) {
// Hit the HTTP server successfully again, verifying that the
// previous slow connection didn't run our handler. (that we
// get "req=2", not "req=3")
- r, err = Get(url)
+ r, err = Get(ts.URL)
if err != nil {
t.Fatalf("http Get #2: %v", err)
}
@@ -293,11 +312,87 @@ func TestServerTimeouts(t *testing.T) {
t.Errorf("Get #2 got %q, want %q", string(got), expected)
}
- l.Close()
+ if !testing.Short() {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+ go io.Copy(ioutil.Discard, conn)
+ for i := 0; i < 5; i++ {
+ _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ if err != nil {
+ t.Fatalf("on write %d: %v", i, err)
+ }
+ time.Sleep(ts.Config.ReadTimeout / 2)
+ }
+ }
+}
+
+// golang.org/issue/4741 -- setting only a write timeout that triggers
+// shouldn't cause a handler to block forever on reads (next HTTP
+// request) that will never happen.
+func TestOnlyWriteTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ var conn net.Conn
+ var afterTimeoutErrc = make(chan error, 1)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) {
+ buf := make([]byte, 512<<10)
+ _, err := w.Write(buf)
+ if err != nil {
+ t.Errorf("handler Write error: %v", err)
+ return
+ }
+ conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
+ _, err = w.Write(buf)
+ afterTimeoutErrc <- err
+ }))
+ ts.Listener = trackLastConnListener{ts.Listener, &conn}
+ ts.Start()
+ defer ts.Close()
+
+ tr := &Transport{DisableKeepAlives: false}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ errc := make(chan error)
+ go func() {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ errc <- err
+ return
+ }
+ _, err = io.Copy(ioutil.Discard, res.Body)
+ errc <- err
+ }()
+ select {
+ case err := <-errc:
+ if err == nil {
+ t.Errorf("expected an error from Get request")
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for Get error")
+ }
+ if err := <-afterTimeoutErrc; err == nil {
+ t.Error("expected write error after timeout")
+ }
+}
+
+// trackLastConnListener tracks the last net.Conn that was accepted.
+type trackLastConnListener struct {
+ net.Listener
+ last *net.Conn // destination
}
-// TestIdentityResponse verifies that a handler can unset
+func (l trackLastConnListener) Accept() (c net.Conn, err error) {
+ c, err = l.Listener.Accept()
+ *l.last = c
+ return
+}
+
+// TestIdentityResponse verifies that a handler can unset
func TestIdentityResponse(t *testing.T) {
+ defer checkLeakedTransports(t)
handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
rw.Header().Set("Content-Length", "3")
rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
@@ -343,10 +438,12 @@ func TestIdentityResponse(t *testing.T) {
// Verify that ErrContentLength is returned
url := ts.URL + "/?overwrite=1"
- _, err := Get(url)
+ res, err := Get(url)
if err != nil {
t.Fatalf("error with Get of %s: %v", url, err)
}
+ res.Body.Close()
+
// Verify that the connection is closed when the declared Content-Length
// is larger than what the handler wrote.
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -370,7 +467,8 @@ func TestIdentityResponse(t *testing.T) {
})
}
-func testTcpConnectionCloses(t *testing.T, req string, h Handler) {
+func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
+ defer checkLeakedTransports(t)
s := httptest.NewServer(h)
defer s.Close()
@@ -386,17 +484,18 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) {
}
r := bufio.NewReader(conn)
- _, err = ReadResponse(r, &Request{Method: "GET"})
+ res, err := ReadResponse(r, &Request{Method: "GET"})
if err != nil {
t.Fatal("ReadResponse error:", err)
}
- success := make(chan bool)
+ didReadAll := make(chan bool, 1)
go func() {
select {
case <-time.After(5 * time.Second):
- t.Fatal("body not closed after 5s")
- case <-success:
+ t.Error("body not closed after 5s")
+ return
+ case <-didReadAll:
}
}()
@@ -404,32 +503,43 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) {
if err != nil {
t.Fatal("read error:", err)
}
+ didReadAll <- true
- success <- true
+ if !res.Close {
+ t.Errorf("Response.Close = false; want true")
+ }
}
// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive.
func TestServeHTTP10Close(t *testing.T) {
- testTcpConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "testdata/file")
}))
}
+// TestClientCanClose verifies that clients can also force a connection to close.
+func TestClientCanClose(t *testing.T) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Nothing.
+ }))
+}
+
// TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close,
// even for HTTP/1.1 requests.
func TestHandlersCanSetConnectionClose11(t *testing.T) {
- testTcpConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Connection", "close")
}))
}
func TestHandlersCanSetConnectionClose10(t *testing.T) {
- testTcpConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Connection", "close")
}))
}
func TestSetsRemoteAddr(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "%s", r.RemoteAddr)
}))
@@ -450,11 +560,13 @@ func TestSetsRemoteAddr(t *testing.T) {
}
func TestChunkedResponseHeaders(t *testing.T) {
+ defer checkLeakedTransports(t)
log.SetOutput(ioutil.Discard) // is noisy otherwise
defer log.SetOutput(os.Stderr)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
+ w.(Flusher).Flush()
fmt.Fprintf(w, "I am a chunked response.")
}))
defer ts.Close()
@@ -463,6 +575,7 @@ func TestChunkedResponseHeaders(t *testing.T) {
if err != nil {
t.Fatalf("Get error: %v", err)
}
+ defer res.Body.Close()
if g, e := res.ContentLength, int64(-1); g != e {
t.Errorf("expected ContentLength of %d; got %d", e, g)
}
@@ -478,6 +591,7 @@ func TestChunkedResponseHeaders(t *testing.T) {
// chunking in their response headers and aren't allowed to produce
// output.
func Test304Responses(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.WriteHeader(StatusNotModified)
_, err := w.Write([]byte("illegal body"))
@@ -507,6 +621,7 @@ func Test304Responses(t *testing.T) {
// allowed to produce output, and don't set a Content-Type since
// the real type of the body data cannot be inferred.
func TestHeadResponses(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
_, err := w.Write([]byte("Ignored body"))
if err != ErrBodyNotAllowed {
@@ -541,6 +656,7 @@ func TestHeadResponses(t *testing.T) {
}
func TestTLSHandshakeTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
ts.Config.ReadTimeout = 250 * time.Millisecond
ts.StartTLS()
@@ -560,6 +676,7 @@ func TestTLSHandshakeTimeout(t *testing.T) {
}
func TestTLSServer(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.TLS != nil {
w.Header().Set("X-TLS-Set", "true")
@@ -642,6 +759,7 @@ var serverExpectTests = []serverExpectTest{
// Tests that the server responds to the "Expect" request header
// correctly.
func TestServerExpect(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
// Note using r.FormValue("readbody") because for POST
// requests that would read from r.Body, which we only
@@ -661,30 +779,51 @@ func TestServerExpect(t *testing.T) {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
- sendf := func(format string, args ...interface{}) {
- _, err := fmt.Fprintf(conn, format, args...)
- if err != nil {
- t.Fatalf("On test %#v, error writing %q: %v", test, format, err)
- }
- }
+
+ // Only send the body immediately if we're acting like an HTTP client
+ // that doesn't send 100-continue expectations.
+ writeBody := test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue"
+
go func() {
- sendf("POST /?readbody=%v HTTP/1.1\r\n"+
+ _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
"Connection: close\r\n"+
"Content-Length: %d\r\n"+
"Expect: %s\r\nHost: foo\r\n\r\n",
test.readBody, test.contentLength, test.expectation)
- if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" {
+ if err != nil {
+ t.Errorf("On test %#v, error writing request headers: %v", test, err)
+ return
+ }
+ if writeBody {
body := strings.Repeat("A", test.contentLength)
- sendf(body)
+ _, err = fmt.Fprint(conn, body)
+ if err != nil {
+ if !test.readBody {
+ // Server likely already hung up on us.
+ // See larger comment below.
+ t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
+ return
+ }
+ t.Errorf("On test %#v, error writing request body: %v", test, err)
+ }
}
}()
bufr := bufio.NewReader(conn)
line, err := bufr.ReadString('\n')
if err != nil {
- t.Fatalf("ReadString: %v", err)
+ if writeBody && !test.readBody {
+ // This is an acceptable failure due to a possible TCP race:
+ // We were still writing data and the server hung up on us. A TCP
+ // implementation may send a RST if our request body data was known
+ // to be lost, which may trigger our reads to fail.
+ // See RFC 1122 page 88.
+ t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
+ return
+ }
+ t.Fatalf("On test %#v, ReadString: %v", test, err)
}
if !strings.Contains(line, test.expectedResponse) {
- t.Errorf("for test %#v got first line=%q", test, line)
+ t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
}
}
@@ -714,6 +853,7 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) {
t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len())
}
rw.WriteHeader(200)
+ rw.(Flusher).Flush()
if g, e := conn.readBuf.Len(), 0; g != e {
t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
}
@@ -736,27 +876,28 @@ func TestServerUnreadRequestBodyLarge(t *testing.T) {
"Content-Length: %d\r\n"+
"\r\n", len(body))))
conn.readBuf.Write([]byte(body))
-
- done := make(chan bool)
+ conn.closec = make(chan bool, 1)
ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
- defer close(done)
if conn.readBuf.Len() < len(body)/2 {
t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
}
rw.WriteHeader(200)
+ rw.(Flusher).Flush()
if conn.readBuf.Len() < len(body)/2 {
t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
}
- if c := rw.Header().Get("Connection"); c != "close" {
- t.Errorf(`Connection header = %q; want "close"`, c)
- }
}))
- <-done
+ <-conn.closec
+
+ if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
+ t.Errorf("Expected a Connection: close header; got response: %s", res)
+ }
}
func TestTimeoutHandler(t *testing.T) {
+ defer checkLeakedTransports(t)
sendHi := make(chan bool, 1)
writeErrors := make(chan error, 1)
sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -831,6 +972,7 @@ func TestRedirectMunging(t *testing.T) {
// the previous request's body, which is not optimal for zero-lengthed bodies,
// as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF.
func TestZeroLengthPostAndResponse(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
all, err := ioutil.ReadAll(r.Body)
if err != nil {
@@ -868,15 +1010,20 @@ func TestZeroLengthPostAndResponse(t *testing.T) {
}
}
+func TestHandlerPanicNil(t *testing.T) {
+ testHandlerPanic(t, false, nil)
+}
+
func TestHandlerPanic(t *testing.T) {
- testHandlerPanic(t, false)
+ testHandlerPanic(t, false, "intentional death for testing")
}
func TestHandlerPanicWithHijack(t *testing.T) {
- testHandlerPanic(t, true)
+ testHandlerPanic(t, true, "intentional death for testing")
}
-func testHandlerPanic(t *testing.T, withHijack bool) {
+func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) {
+ defer checkLeakedTransports(t)
// Unlike the other tests that set the log output to ioutil.Discard
// to quiet the output, this test uses a pipe. The pipe serves three
// purposes:
@@ -896,6 +1043,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) {
pr, pw := io.Pipe()
log.SetOutput(pw)
defer log.SetOutput(os.Stderr)
+ defer pw.Close()
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if withHijack {
@@ -905,7 +1053,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) {
}
defer rwc.Close()
}
- panic("intentional death for testing")
+ panic(panicValue)
}))
defer ts.Close()
@@ -917,8 +1065,8 @@ func testHandlerPanic(t *testing.T, withHijack bool) {
buf := make([]byte, 4<<10)
_, err := pr.Read(buf)
pr.Close()
- if err != nil {
- t.Fatal(err)
+ if err != nil && err != io.EOF {
+ t.Error(err)
}
done <- true
}()
@@ -928,6 +1076,10 @@ func testHandlerPanic(t *testing.T, withHijack bool) {
t.Logf("expected an error")
}
+ if panicValue == nil {
+ return
+ }
+
select {
case <-done:
return
@@ -937,6 +1089,7 @@ func testHandlerPanic(t *testing.T, withHijack bool) {
}
func TestNoDate(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header()["Date"] = nil
}))
@@ -952,6 +1105,7 @@ func TestNoDate(t *testing.T) {
}
func TestStripPrefix(t *testing.T) {
+ defer checkLeakedTransports(t)
h := HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("X-Path", r.URL.Path)
})
@@ -965,6 +1119,7 @@ func TestStripPrefix(t *testing.T) {
if g, e := res.Header.Get("X-Path"), "/bar"; g != e {
t.Errorf("test 1: got %s, want %s", g, e)
}
+ res.Body.Close()
res, err = Get(ts.URL + "/bar")
if err != nil {
@@ -973,9 +1128,11 @@ func TestStripPrefix(t *testing.T) {
if g, e := res.StatusCode, 404; g != e {
t.Errorf("test 2: got status %v, want %v", g, e)
}
+ res.Body.Close()
}
func TestRequestLimit(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
t.Fatalf("didn't expect to get request in Handler")
}))
@@ -992,6 +1149,7 @@ func TestRequestLimit(t *testing.T) {
// we do support it (at least currently), so we expect a response below.
t.Fatalf("Do: %v", err)
}
+ defer res.Body.Close()
if res.StatusCode != 413 {
t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status)
}
@@ -1013,11 +1171,12 @@ type countReader struct {
func (cr countReader) Read(p []byte) (n int, err error) {
n, err = cr.r.Read(p)
- *cr.n += int64(n)
+ atomic.AddInt64(cr.n, int64(n))
return
}
func TestRequestBodyLimit(t *testing.T) {
+ defer checkLeakedTransports(t)
const limit = 1 << 20
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
r.Body = MaxBytesReader(w, r.Body, limit)
@@ -1031,8 +1190,8 @@ func TestRequestBodyLimit(t *testing.T) {
}))
defer ts.Close()
- nWritten := int64(0)
- req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), &nWritten}, limit*200))
+ nWritten := new(int64)
+ req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
// Send the POST, but don't care it succeeds or not. The
// remote side is going to reply and then close the TCP
@@ -1045,7 +1204,7 @@ func TestRequestBodyLimit(t *testing.T) {
// the remote side hung up on us before we wrote too much.
_, _ = DefaultClient.Do(req)
- if nWritten > limit*100 {
+ if atomic.LoadInt64(nWritten) > limit*100 {
t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
limit, nWritten)
}
@@ -1054,6 +1213,7 @@ func TestRequestBodyLimit(t *testing.T) {
// TestClientWriteShutdown tests that if the client shuts down the write
// side of their TCP connection, the server doesn't send a 400 Bad Request.
func TestClientWriteShutdown(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -1086,28 +1246,207 @@ func TestClientWriteShutdown(t *testing.T) {
// Tests that chunked server responses that write 1 byte at a time are
// buffered before chunk headers are added, not after chunk headers.
func TestServerBufferedChunking(t *testing.T) {
- if true {
- t.Logf("Skipping known broken test; see Issue 2357")
- return
- }
conn := new(testConn)
conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
- done := make(chan bool)
+ conn.closec = make(chan bool, 1)
ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
- defer close(done)
- rw.Header().Set("Content-Type", "text/plain") // prevent sniffing, which buffers
+ rw.(Flusher).Flush() // force the Header to be sent, in chunking mode, not counting the length
rw.Write([]byte{'x'})
rw.Write([]byte{'y'})
rw.Write([]byte{'z'})
}))
- <-done
+ <-conn.closec
if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
conn.writeBuf.Bytes())
}
}
+// Tests that the server flushes its response headers out when it's
+// ignoring the response body and waits a bit before forcefully
+// closing the TCP connection, causing the client to get a RST.
+// See http://golang.org/issue/3595
+func TestServerGracefulClose(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, "bye", StatusUnauthorized)
+ }))
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ const bodySize = 5 << 20
+ req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
+ for i := 0; i < bodySize; i++ {
+ req = append(req, 'x')
+ }
+ writeErr := make(chan error)
+ go func() {
+ _, err := conn.Write(req)
+ writeErr <- err
+ }()
+ br := bufio.NewReader(conn)
+ lineNum := 0
+ for {
+ line, err := br.ReadString('\n')
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatalf("ReadLine: %v", err)
+ }
+ lineNum++
+ if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
+ t.Errorf("Response line = %q; want a 401", line)
+ }
+ }
+ // Wait for write to finish. This is a broken pipe on both
+ // Darwin and Linux, but checking this isn't the point of
+ // the test.
+ <-writeErr
+}
+
+func TestCaseSensitiveMethod(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "get" {
+ t.Errorf(`Got method %q; want "get"`, r.Method)
+ }
+ }))
+ defer ts.Close()
+ req, _ := NewRequest("get", ts.URL, nil)
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ res.Body.Close()
+}
+
+// TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1
+// request (both keep-alive), when a Handler never writes any
+// response, the net/http package adds a "Content-Length: 0" response
+// header.
+func TestContentLengthZero(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {}))
+ defer ts.Close()
+
+ for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
+ if err != nil {
+ t.Fatalf("error writing: %v", err)
+ }
+ req, _ := NewRequest("GET", "/", nil)
+ res, err := ReadResponse(bufio.NewReader(conn), req)
+ if err != nil {
+ t.Fatalf("error reading response: %v", err)
+ }
+ if te := res.TransferEncoding; len(te) > 0 {
+ t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
+ }
+ if cl := res.ContentLength; cl != 0 {
+ t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
+ }
+ conn.Close()
+ }
+}
+
+func TestCloseNotifier(t *testing.T) {
+ gotReq := make(chan bool, 1)
+ sawClose := make(chan bool, 1)
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ gotReq <- true
+ cc := rw.(CloseNotifier).CloseNotify()
+ <-cc
+ sawClose <- true
+ }))
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ diec := make(chan bool)
+ go func() {
+ _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
+ if err != nil {
+ t.Fatal(err)
+ }
+ <-diec
+ conn.Close()
+ }()
+For:
+ for {
+ select {
+ case <-gotReq:
+ diec <- true
+ case <-sawClose:
+ break For
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout")
+ }
+ }
+ ts.Close()
+}
+
+func TestOptions(t *testing.T) {
+ uric := make(chan string, 2) // only expect 1, but leave space for 2
+ mux := NewServeMux()
+ mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
+ uric <- r.RequestURI
+ })
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ // An OPTIONS * request should succeed.
+ _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ br := bufio.NewReader(conn)
+ res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
+ }
+
+ // A GET * request on a ServeMux should fail.
+ _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err = ReadResponse(br, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 400 {
+ t.Errorf("Got non-400 response to GET *: %#v", res)
+ }
+
+ res, err = Get(ts.URL + "/second")
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if got := <-uric; got != "/second" {
+ t.Errorf("Handler saw request for %q; want /second", got)
+ }
+}
+
// goTimeout runs f, failing t if f takes more than ns to complete.
func goTimeout(t *testing.T, d time.Duration, f func()) {
ch := make(chan bool, 2)
@@ -1184,3 +1523,100 @@ func BenchmarkClientServer(b *testing.B) {
b.StopTimer()
}
+
+func BenchmarkClientServerParallel4(b *testing.B) {
+ benchmarkClientServerParallel(b, 4)
+}
+
+func BenchmarkClientServerParallel64(b *testing.B) {
+ benchmarkClientServerParallel(b, 64)
+}
+
+func benchmarkClientServerParallel(b *testing.B, conc int) {
+ b.StopTimer()
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ fmt.Fprintf(rw, "Hello world.\n")
+ }))
+ defer ts.Close()
+ b.StartTimer()
+
+ numProcs := runtime.GOMAXPROCS(-1) * conc
+ var wg sync.WaitGroup
+ wg.Add(numProcs)
+ n := int32(b.N)
+ for p := 0; p < numProcs; p++ {
+ go func() {
+ for atomic.AddInt32(&n, -1) >= 0 {
+ res, err := Get(ts.URL)
+ if err != nil {
+ b.Logf("Get: %v", err)
+ continue
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ b.Logf("ReadAll: %v", err)
+ continue
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ panic("Got body: " + body)
+ }
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+}
+
+// A benchmark for profiling the server without the HTTP client code.
+// The client code runs in a subprocess.
+//
+// For use like:
+// $ go test -c
+// $ ./http.test -test.run=XX -test.bench=BenchmarkServer -test.benchtime=15s -test.cpuprofile=http.prof
+// $ go tool pprof http.test http.prof
+// (pprof) web
+func BenchmarkServer(b *testing.B) {
+ // Child process mode;
+ if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
+ n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
+ if err != nil {
+ panic(err)
+ }
+ for i := 0; i < n; i++ {
+ res, err := Get(url)
+ if err != nil {
+ log.Panicf("Get: %v", err)
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ log.Panicf("ReadAll: %v", err)
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ log.Panicf("Got body: %q", body)
+ }
+ }
+ os.Exit(0)
+ return
+ }
+
+ var res = []byte("Hello world.\n")
+ b.StopTimer()
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Content-Type", "text/html; charset=utf-8")
+ rw.Write(res)
+ }))
+ defer ts.Close()
+ b.StartTimer()
+
+ cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer")
+ cmd.Env = append([]string{
+ fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
+ fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
+ }, os.Environ()...)
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ b.Errorf("Test failure: %v, with output: %s", err, out)
+ }
+}
diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go
index 0572b4ae3..b6ab78228 100644
--- a/src/pkg/net/http/server.go
+++ b/src/pkg/net/http/server.go
@@ -11,7 +11,6 @@ package http
import (
"bufio"
- "bytes"
"crypto/tls"
"errors"
"fmt"
@@ -21,7 +20,7 @@ import (
"net"
"net/url"
"path"
- "runtime/debug"
+ "runtime"
"strconv"
"strings"
"sync"
@@ -94,30 +93,188 @@ type Hijacker interface {
Hijack() (net.Conn, *bufio.ReadWriter, error)
}
+// The CloseNotifier interface is implemented by ResponseWriters which
+// allow detecting when the underlying connection has gone away.
+//
+// This mechanism can be used to cancel long operations on the server
+// if the client has disconnected before the response is ready.
+type CloseNotifier interface {
+ // CloseNotify returns a channel that receives a single value
+ // when the client connection has gone away.
+ CloseNotify() <-chan bool
+}
+
// A conn represents the server side of an HTTP connection.
type conn struct {
remoteAddr string // network address of remote side
server *Server // the Server on which the connection arrived
rwc net.Conn // i/o connection
- lr *io.LimitedReader // io.LimitReader(rwc)
- buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->rwc
- hijacked bool // connection has been hijacked by handler
+ sr switchReader // where the LimitReader reads from; usually the rwc
+ lr *io.LimitedReader // io.LimitReader(sr)
+ buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
tlsState *tls.ConnectionState // or nil when not using TLS
- body []byte
+
+ mu sync.Mutex // guards the following
+ clientGone bool // if client has disconnected mid-request
+ closeNotifyc chan bool // made lazily
+ hijackedv bool // connection has been hijacked by handler
+}
+
+func (c *conn) hijacked() bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.hijackedv
+}
+
+func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.hijackedv {
+ return nil, nil, ErrHijacked
+ }
+ if c.closeNotifyc != nil {
+ return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier")
+ }
+ c.hijackedv = true
+ rwc = c.rwc
+ buf = c.buf
+ c.rwc = nil
+ c.buf = nil
+ return
+}
+
+func (c *conn) closeNotify() <-chan bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closeNotifyc == nil {
+ c.closeNotifyc = make(chan bool)
+ if c.hijackedv {
+ // to obey the function signature, even though
+ // it'll never receive a value.
+ return c.closeNotifyc
+ }
+ pr, pw := io.Pipe()
+
+ readSource := c.sr.r
+ c.sr.Lock()
+ c.sr.r = pr
+ c.sr.Unlock()
+ go func() {
+ _, err := io.Copy(pw, readSource)
+ if err == nil {
+ err = io.EOF
+ }
+ pw.CloseWithError(err)
+ c.noteClientGone()
+ }()
+ }
+ return c.closeNotifyc
+}
+
+func (c *conn) noteClientGone() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closeNotifyc != nil && !c.clientGone {
+ c.closeNotifyc <- true
+ }
+ c.clientGone = true
+}
+
+type switchReader struct {
+ sync.Mutex
+ r io.Reader
+}
+
+func (sr *switchReader) Read(p []byte) (n int, err error) {
+ sr.Lock()
+ r := sr.r
+ sr.Unlock()
+ return r.Read(p)
+}
+
+// This should be >= 512 bytes for DetectContentType,
+// but otherwise it's somewhat arbitrary.
+const bufferBeforeChunkingSize = 2048
+
+// chunkWriter writes to a response's conn buffer, and is the writer
+// wrapped by the response.bufw buffered writer.
+//
+// chunkWriter also is responsible for finalizing the Header, including
+// conditionally setting the Content-Type and setting a Content-Length
+// in cases where the handler's final output is smaller than the buffer
+// size. It also conditionally adds chunk headers, when in chunking mode.
+//
+// See the comment above (*response).Write for the entire write flow.
+type chunkWriter struct {
+ res *response
+ header Header // a deep copy of r.Header, once WriteHeader is called
+ wroteHeader bool // whether the header's been sent
+
+ // set by the writeHeader method:
+ chunking bool // using chunked transfer encoding for reply body
+}
+
+var crlf = []byte("\r\n")
+
+func (cw *chunkWriter) Write(p []byte) (n int, err error) {
+ if !cw.wroteHeader {
+ cw.writeHeader(p)
+ }
+ if cw.chunking {
+ _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p))
+ if err != nil {
+ cw.res.conn.rwc.Close()
+ return
+ }
+ }
+ n, err = cw.res.conn.buf.Write(p)
+ if cw.chunking && err == nil {
+ _, err = cw.res.conn.buf.Write(crlf)
+ }
+ if err != nil {
+ cw.res.conn.rwc.Close()
+ }
+ return
+}
+
+func (cw *chunkWriter) flush() {
+ if !cw.wroteHeader {
+ cw.writeHeader(nil)
+ }
+ cw.res.conn.buf.Flush()
+}
+
+func (cw *chunkWriter) close() {
+ if !cw.wroteHeader {
+ cw.writeHeader(nil)
+ }
+ if cw.chunking {
+ // zero EOF chunk, trailer key/value pairs (currently
+ // unsupported in Go's server), followed by a blank
+ // line.
+ io.WriteString(cw.res.conn.buf, "0\r\n\r\n")
+ }
}
// A response represents the server side of an HTTP response.
type response struct {
conn *conn
req *Request // request for this response
- chunking bool // using chunked transfer encoding for reply body
- wroteHeader bool // reply header has been written
+ wroteHeader bool // reply header has been (logically) written
wroteContinue bool // 100 Continue response was written
- header Header // reply header parameters
- written int64 // number of bytes written in body
- contentLength int64 // explicitly-declared Content-Length; or -1
- status int // status code passed to WriteHeader
- needSniff bool // need to sniff to find Content-Type
+
+ w *bufio.Writer // buffers output in chunks to chunkWriter
+ cw *chunkWriter
+
+ // handlerHeader is the Header that Handlers get access to,
+ // which may be retained and mutated even after WriteHeader.
+ // handlerHeader is copied into cw.header at WriteHeader
+ // time, and privately mutated thereafter.
+ handlerHeader Header
+
+ written int64 // number of bytes written in body
+ contentLength int64 // explicitly-declared Content-Length; or -1
+ status int // status code passed to WriteHeader
// close connection after this reply. set on request and
// updated after response from handler if there's a
@@ -127,12 +284,14 @@ type response struct {
// requestBodyLimitHit is set by requestTooLarge when
// maxBytesReader hits its max size. It is checked in
- // WriteHeader, to make sure we don't consume the the
+ // WriteHeader, to make sure we don't consume the
// remaining request body to try to advance to the next HTTP
- // request. Instead, when this is set, we stop doing
+ // request. Instead, when this is set, we stop reading
// subsequent requests on this connection and stop reading
// input from it.
requestBodyLimitHit bool
+
+ handlerDone bool // set true when the handler exits
}
// requestTooLarge is called by maxBytesReader when too much input has
@@ -145,42 +304,68 @@ func (w *response) requestTooLarge() {
}
}
+// needsSniff returns whether a Content-Type still needs to be sniffed.
+func (w *response) needsSniff() bool {
+ return !w.cw.wroteHeader && w.handlerHeader.Get("Content-Type") == "" && w.written < sniffLen
+}
+
type writerOnly struct {
io.Writer
}
func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
- // Call WriteHeader before checking w.chunking if it hasn't
- // been called yet, since WriteHeader is what sets w.chunking.
if !w.wroteHeader {
w.WriteHeader(StatusOK)
}
- if !w.chunking && w.bodyAllowed() && !w.needSniff {
- w.Flush()
+
+ if w.needsSniff() {
+ n0, err := io.Copy(writerOnly{w}, io.LimitReader(src, sniffLen))
+ n += n0
+ if err != nil {
+ return n, err
+ }
+ }
+
+ w.w.Flush() // get rid of any previous writes
+ w.cw.flush() // make sure Header is written; flush data to rwc
+
+ // Now that cw has been flushed, its chunking field is guaranteed initialized.
+ if !w.cw.chunking && w.bodyAllowed() {
if rf, ok := w.conn.rwc.(io.ReaderFrom); ok {
- n, err = rf.ReadFrom(src)
- w.written += n
- return
+ n0, err := rf.ReadFrom(src)
+ n += n0
+ w.written += n0
+ return n, err
}
}
+
// Fall back to default io.Copy implementation.
// Use wrapper to hide w.ReadFrom from io.Copy.
- return io.Copy(writerOnly{w}, src)
+ n0, err := io.Copy(writerOnly{w}, src)
+ n += n0
+ return n, err
}
// noLimit is an effective infinite upper bound for io.LimitedReader
const noLimit int64 = (1 << 63) - 1
+// debugServerConnections controls whether all server connections are wrapped
+// with a verbose logging wrapper.
+const debugServerConnections = false
+
// Create new connection from rwc.
func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
c = new(conn)
c.remoteAddr = rwc.RemoteAddr().String()
c.server = srv
c.rwc = rwc
- c.body = make([]byte, sniffLen)
- c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader)
+ if debugServerConnections {
+ c.rwc = newLoggingConn("server", c.rwc)
+ }
+ c.sr = switchReader{r: c.rwc}
+ c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
br := bufio.NewReader(c.lr)
- bw := bufio.NewWriter(rwc)
+ bw := bufio.NewWriter(c.rwc)
c.buf = bufio.NewReadWriter(br, bw)
return c, nil
}
@@ -207,9 +392,9 @@ type expectContinueReader struct {
func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
if ecr.closed {
- return 0, errors.New("http: Read after Close on request Body")
+ return 0, ErrBodyReadAfterClose
}
- if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked {
+ if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() {
ecr.resp.wroteContinue = true
io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n")
ecr.resp.conn.buf.Flush()
@@ -232,9 +417,19 @@ var errTooLarge = errors.New("http: request too large")
// Read next request from connection.
func (c *conn) readRequest() (w *response, err error) {
- if c.hijacked {
+ if c.hijacked() {
return nil, ErrHijacked
}
+
+ if d := c.server.ReadTimeout; d != 0 {
+ c.rwc.SetReadDeadline(time.Now().Add(d))
+ }
+ if d := c.server.WriteTimeout; d != 0 {
+ defer func() {
+ c.rwc.SetWriteDeadline(time.Now().Add(d))
+ }()
+ }
+
c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
var req *Request
if req, err = ReadRequest(c.buf.Reader); err != nil {
@@ -248,17 +443,20 @@ func (c *conn) readRequest() (w *response, err error) {
req.RemoteAddr = c.remoteAddr
req.TLS = c.tlsState
- w = new(response)
- w.conn = c
- w.req = req
- w.header = make(Header)
- w.contentLength = -1
- c.body = c.body[:0]
+ w = &response{
+ conn: c,
+ req: req,
+ handlerHeader: make(Header),
+ contentLength: -1,
+ cw: new(chunkWriter),
+ }
+ w.cw.res = w
+ w.w = bufio.NewWriterSize(w.cw, bufferBeforeChunkingSize)
return w, nil
}
func (w *response) Header() Header {
- return w.header
+ return w.handlerHeader
}
// maxPostHandlerReadBytes is the max number of Request.Body bytes not
@@ -273,7 +471,7 @@ func (w *response) Header() Header {
const maxPostHandlerReadBytes = 256 << 10
func (w *response) WriteHeader(code int) {
- if w.conn.hijacked {
+ if w.conn.hijacked() {
log.Print("http: response.WriteHeader on hijacked connection")
return
}
@@ -284,31 +482,68 @@ func (w *response) WriteHeader(code int) {
w.wroteHeader = true
w.status = code
- // Check for a explicit (and valid) Content-Length header.
- var hasCL bool
- var contentLength int64
- if clenStr := w.header.Get("Content-Length"); clenStr != "" {
- var err error
- contentLength, err = strconv.ParseInt(clenStr, 10, 64)
- if err == nil {
- hasCL = true
+ w.cw.header = w.handlerHeader.clone()
+
+ if cl := w.cw.header.get("Content-Length"); cl != "" {
+ v, err := strconv.ParseInt(cl, 10, 64)
+ if err == nil && v >= 0 {
+ w.contentLength = v
} else {
- log.Printf("http: invalid Content-Length of %q sent", clenStr)
- w.header.Del("Content-Length")
+ log.Printf("http: invalid Content-Length of %q", cl)
+ w.cw.header.Del("Content-Length")
+ }
+ }
+}
+
+// writeHeader finalizes the header sent to the client and writes it
+// to cw.res.conn.buf.
+//
+// p is not written by writeHeader, but is the first chunk of the body
+// that will be written. It is sniffed for a Content-Type if none is
+// set explicitly. It's also used to set the Content-Length, if the
+// total body size was small and the handler has already finished
+// running.
+func (cw *chunkWriter) writeHeader(p []byte) {
+ if cw.wroteHeader {
+ return
+ }
+ cw.wroteHeader = true
+
+ w := cw.res
+ code := w.status
+ done := w.handlerDone
+
+ // If the handler is done but never sent a Content-Length
+ // response header and this is our first (and last) write, set
+ // it, even to zero. This helps HTTP/1.0 clients keep their
+ // "keep-alive" connections alive.
+ if done && cw.header.get("Content-Length") == "" && w.req.Method != "HEAD" {
+ w.contentLength = int64(len(p))
+ cw.header.Set("Content-Length", strconv.Itoa(len(p)))
+ }
+
+ // If this was an HTTP/1.0 request with keep-alive and we sent a
+ // Content-Length back, we can make this a keep-alive response ...
+ if w.req.wantsHttp10KeepAlive() {
+ sentLength := cw.header.get("Content-Length") != ""
+ if sentLength && cw.header.get("Connection") == "keep-alive" {
+ w.closeAfterReply = false
}
}
+ // Check for a explicit (and valid) Content-Length header.
+ hasCL := w.contentLength != -1
+
if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) {
- _, connectionHeaderSet := w.header["Connection"]
+ _, connectionHeaderSet := cw.header["Connection"]
if !connectionHeaderSet {
- w.header.Set("Connection", "keep-alive")
+ cw.header.Set("Connection", "keep-alive")
}
- } else if !w.req.ProtoAtLeast(1, 1) {
- // Client did not ask to keep connection alive.
+ } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() {
w.closeAfterReply = true
}
- if w.header.Get("Connection") == "close" {
+ if cw.header.get("Connection") == "close" {
w.closeAfterReply = true
}
@@ -322,7 +557,7 @@ func (w *response) WriteHeader(code int) {
n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1)
if n >= maxPostHandlerReadBytes {
w.requestTooLarge()
- w.header.Set("Connection", "close")
+ cw.header.Set("Connection", "close")
} else {
w.req.Body.Close()
}
@@ -332,64 +567,67 @@ func (w *response) WriteHeader(code int) {
if code == StatusNotModified {
// Must not have body.
for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} {
- if w.header.Get(header) != "" {
- // TODO: return an error if WriteHeader gets a return parameter
- // or set a flag on w to make future Writes() write an error page?
- // for now just log and drop the header.
- log.Printf("http: StatusNotModified response with header %q defined", header)
- w.header.Del(header)
+ // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers"
+ if cw.header.get(header) != "" {
+ cw.header.Del(header)
}
}
} else {
// If no content type, apply sniffing algorithm to body.
- if w.header.Get("Content-Type") == "" && w.req.Method != "HEAD" {
- w.needSniff = true
+ if cw.header.get("Content-Type") == "" && w.req.Method != "HEAD" {
+ cw.header.Set("Content-Type", DetectContentType(p))
}
}
- if _, ok := w.header["Date"]; !ok {
- w.Header().Set("Date", time.Now().UTC().Format(TimeFormat))
+ if _, ok := cw.header["Date"]; !ok {
+ cw.header.Set("Date", time.Now().UTC().Format(TimeFormat))
}
- te := w.header.Get("Transfer-Encoding")
+ te := cw.header.get("Transfer-Encoding")
hasTE := te != ""
if hasCL && hasTE && te != "identity" {
// TODO: return an error if WriteHeader gets a return parameter
// For now just ignore the Content-Length.
log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
- te, contentLength)
- w.header.Del("Content-Length")
+ te, w.contentLength)
+ cw.header.Del("Content-Length")
hasCL = false
}
if w.req.Method == "HEAD" || code == StatusNotModified {
// do nothing
+ } else if code == StatusNoContent {
+ cw.header.Del("Transfer-Encoding")
} else if hasCL {
- w.contentLength = contentLength
- w.header.Del("Transfer-Encoding")
+ cw.header.Del("Transfer-Encoding")
} else if w.req.ProtoAtLeast(1, 1) {
// HTTP/1.1 or greater: use chunked transfer encoding
// to avoid closing the connection at EOF.
// TODO: this blows away any custom or stacked Transfer-Encoding they
// might have set. Deal with that as need arises once we have a valid
// use case.
- w.chunking = true
- w.header.Set("Transfer-Encoding", "chunked")
+ cw.chunking = true
+ cw.header.Set("Transfer-Encoding", "chunked")
} else {
// HTTP version < 1.1: cannot do chunked transfer
// encoding and we don't know the Content-Length so
// signal EOF by closing connection.
w.closeAfterReply = true
- w.header.Del("Transfer-Encoding") // in case already set
+ cw.header.Del("Transfer-Encoding") // in case already set
}
// Cannot use Content-Length with non-identity Transfer-Encoding.
- if w.chunking {
- w.header.Del("Content-Length")
+ if cw.chunking {
+ cw.header.Del("Content-Length")
}
if !w.req.ProtoAtLeast(1, 0) {
return
}
+
+ if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") {
+ cw.header.Set("Connection", "close")
+ }
+
proto := "HTTP/1.0"
if w.req.ProtoAtLeast(1, 1) {
proto = "HTTP/1.1"
@@ -400,37 +638,8 @@ func (w *response) WriteHeader(code int) {
text = "status code " + codestring
}
io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n")
- w.header.Write(w.conn.buf)
-
- // If we need to sniff the body, leave the header open.
- // Otherwise, end it here.
- if !w.needSniff {
- io.WriteString(w.conn.buf, "\r\n")
- }
-}
-
-// sniff uses the first block of written data,
-// stored in w.conn.body, to decide the Content-Type
-// for the HTTP body.
-func (w *response) sniff() {
- if !w.needSniff {
- return
- }
- w.needSniff = false
-
- data := w.conn.body
- fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n\r\n", DetectContentType(data))
-
- if len(data) == 0 {
- return
- }
- if w.chunking {
- fmt.Fprintf(w.conn.buf, "%x\r\n", len(data))
- }
- _, err := w.conn.buf.Write(data)
- if w.chunking && err == nil {
- io.WriteString(w.conn.buf, "\r\n")
- }
+ cw.header.Write(w.conn.buf)
+ w.conn.buf.Write(crlf)
}
// bodyAllowed returns true if a Write is allowed for this response type.
@@ -442,8 +651,40 @@ func (w *response) bodyAllowed() bool {
return w.status != StatusNotModified && w.req.Method != "HEAD"
}
+// The Life Of A Write is like this:
+//
+// Handler starts. No header has been sent. The handler can either
+// write a header, or just start writing. Writing before sending a header
+// sends an implicity empty 200 OK header.
+//
+// If the handler didn't declare a Content-Length up front, we either
+// go into chunking mode or, if the handler finishes running before
+// the chunking buffer size, we compute a Content-Length and send that
+// in the header instead.
+//
+// Likewise, if the handler didn't set a Content-Type, we sniff that
+// from the initial chunk of output.
+//
+// The Writers are wired together like:
+//
+// 1. *response (the ResponseWriter) ->
+// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes
+// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type)
+// and which writes the chunk headers, if needed.
+// 4. conn.buf, a bufio.Writer of default (4kB) bytes
+// 5. the rwc, the net.Conn.
+//
+// TODO(bradfitz): short-circuit some of the buffering when the
+// initial header contains both a Content-Type and Content-Length.
+// Also short-circuit in (1) when the header's been sent and not in
+// chunking mode, writing directly to (4) instead, if (2) has no
+// buffered data. More generally, we could short-circuit from (1) to
+// (3) even in chunking mode if the write size from (1) is over some
+// threshold and nothing is in (2). The answer might be mostly making
+// bufferBeforeChunkingSize smaller and having bufio's fast-paths deal
+// with this instead.
func (w *response) Write(data []byte) (n int, err error) {
- if w.conn.hijacked {
+ if w.conn.hijacked() {
log.Print("http: response.Write on hijacked connection")
return 0, ErrHijacked
}
@@ -461,73 +702,20 @@ func (w *response) Write(data []byte) (n int, err error) {
if w.contentLength != -1 && w.written > w.contentLength {
return 0, ErrContentLength
}
-
- var m int
- if w.needSniff {
- // We need to sniff the beginning of the output to
- // determine the content type. Accumulate the
- // initial writes in w.conn.body.
- // Cap m so that append won't allocate.
- m = cap(w.conn.body) - len(w.conn.body)
- if m > len(data) {
- m = len(data)
- }
- w.conn.body = append(w.conn.body, data[:m]...)
- data = data[m:]
- if len(data) == 0 {
- // Copied everything into the buffer.
- // Wait for next write.
- return m, nil
- }
-
- // Filled the buffer; more data remains.
- // Sniff the content (flushes the buffer)
- // and then proceed with the remainder
- // of the data as a normal Write.
- // Calling sniff clears needSniff.
- w.sniff()
- }
-
- // TODO(rsc): if chunking happened after the buffering,
- // then there would be fewer chunk headers.
- // On the other hand, it would make hijacking more difficult.
- if w.chunking {
- fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) // TODO(rsc): use strconv not fmt
- }
- n, err = w.conn.buf.Write(data)
- if err == nil && w.chunking {
- if n != len(data) {
- err = io.ErrShortWrite
- }
- if err == nil {
- io.WriteString(w.conn.buf, "\r\n")
- }
- }
-
- return m + n, err
+ return w.w.Write(data)
}
func (w *response) finishRequest() {
- // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length
- // back, we can make this a keep-alive response ...
- if w.req.wantsHttp10KeepAlive() {
- sentLength := w.header.Get("Content-Length") != ""
- if sentLength && w.header.Get("Connection") == "keep-alive" {
- w.closeAfterReply = false
- }
- }
+ w.handlerDone = true
+
if !w.wroteHeader {
w.WriteHeader(StatusOK)
}
- if w.needSniff {
- w.sniff()
- }
- if w.chunking {
- io.WriteString(w.conn.buf, "0\r\n")
- // trailer key/value pairs, followed by blank line
- io.WriteString(w.conn.buf, "\r\n")
- }
+
+ w.w.Flush()
+ w.cw.close()
w.conn.buf.Flush()
+
// Close the body, unless we're about to close the whole TCP connection
// anyway.
if !w.closeAfterReply {
@@ -537,7 +725,7 @@ func (w *response) finishRequest() {
w.req.MultipartForm.RemoveAll()
}
- if w.contentLength != -1 && w.contentLength != w.written {
+ if w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written {
// Did not write enough. Avoid getting out of sync.
w.closeAfterReply = true
}
@@ -547,66 +735,114 @@ func (w *response) Flush() {
if !w.wroteHeader {
w.WriteHeader(StatusOK)
}
- w.sniff()
- w.conn.buf.Flush()
+ w.w.Flush()
+ w.cw.flush()
}
-// Close the connection.
-func (c *conn) close() {
+func (c *conn) finalFlush() {
if c.buf != nil {
c.buf.Flush()
c.buf = nil
}
+}
+
+// Close the connection.
+func (c *conn) close() {
+ c.finalFlush()
if c.rwc != nil {
c.rwc.Close()
c.rwc = nil
}
}
+// rstAvoidanceDelay is the amount of time we sleep after closing the
+// write side of a TCP connection before closing the entire socket.
+// By sleeping, we increase the chances that the client sees our FIN
+// and processes its final data before they process the subsequent RST
+// from closing a connection with known unread data.
+// This RST seems to occur mostly on BSD systems. (And Windows?)
+// This timeout is somewhat arbitrary (~latency around the planet).
+const rstAvoidanceDelay = 500 * time.Millisecond
+
+// closeWrite flushes any outstanding data and sends a FIN packet (if
+// client is connected via TCP), signalling that we're done. We then
+// pause for a bit, hoping the client processes it before `any
+// subsequent RST.
+//
+// See http://golang.org/issue/3595
+func (c *conn) closeWriteAndWait() {
+ c.finalFlush()
+ if tcp, ok := c.rwc.(*net.TCPConn); ok {
+ tcp.CloseWrite()
+ }
+ time.Sleep(rstAvoidanceDelay)
+}
+
+// validNPN returns whether the proto is not a blacklisted Next
+// Protocol Negotiation protocol. Empty and built-in protocol types
+// are blacklisted and can't be overridden with alternate
+// implementations.
+func validNPN(proto string) bool {
+ switch proto {
+ case "", "http/1.1", "http/1.0":
+ return false
+ }
+ return true
+}
+
// Serve a new connection.
func (c *conn) serve() {
defer func() {
- err := recover()
- if err == nil {
- return
+ if err := recover(); err != nil {
+ const size = 4096
+ buf := make([]byte, size)
+ buf = buf[:runtime.Stack(buf, false)]
+ log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
}
-
- var buf bytes.Buffer
- fmt.Fprintf(&buf, "http: panic serving %v: %v\n", c.remoteAddr, err)
- buf.Write(debug.Stack())
- log.Print(buf.String())
-
- if c.rwc != nil { // may be nil if connection hijacked
- c.rwc.Close()
+ if !c.hijacked() {
+ c.close()
}
}()
if tlsConn, ok := c.rwc.(*tls.Conn); ok {
+ if d := c.server.ReadTimeout; d != 0 {
+ c.rwc.SetReadDeadline(time.Now().Add(d))
+ }
+ if d := c.server.WriteTimeout; d != 0 {
+ c.rwc.SetWriteDeadline(time.Now().Add(d))
+ }
if err := tlsConn.Handshake(); err != nil {
- c.close()
return
}
c.tlsState = new(tls.ConnectionState)
*c.tlsState = tlsConn.ConnectionState()
+ if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) {
+ if fn := c.server.TLSNextProto[proto]; fn != nil {
+ h := initNPNRequest{tlsConn, serverHandler{c.server}}
+ fn(c.server, tlsConn, h)
+ }
+ return
+ }
}
for {
w, err := c.readRequest()
if err != nil {
- msg := "400 Bad Request"
if err == errTooLarge {
// Their HTTP client may or may not be
// able to read this if we're
// responding to them and hanging up
// while they're still writing their
// request. Undefined behavior.
- msg = "413 Request Entity Too Large"
+ io.WriteString(c.rwc, "HTTP/1.1 413 Request Entity Too Large\r\n\r\n")
+ c.closeWriteAndWait()
+ break
} else if err == io.EOF {
break // Don't reply
} else if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
break // Don't reply
}
- fmt.Fprintf(c.rwc, "HTTP/1.1 %s\r\n\r\n", msg)
+ io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\n\r\n")
break
}
@@ -624,59 +860,59 @@ func (c *conn) serve() {
break
}
req.Header.Del("Expect")
- } else if req.Header.Get("Expect") != "" {
- // TODO(bradfitz): let ServeHTTP handlers handle
- // requests with non-standard expectation[s]? Seems
- // theoretical at best, and doesn't fit into the
- // current ServeHTTP model anyway. We'd need to
- // make the ResponseWriter an optional
- // "ExpectReplier" interface or something.
- //
- // For now we'll just obey RFC 2616 14.20 which says
- // "If a server receives a request containing an
- // Expect field that includes an expectation-
- // extension that it does not support, it MUST
- // respond with a 417 (Expectation Failed) status."
- w.Header().Set("Connection", "close")
- w.WriteHeader(StatusExpectationFailed)
- w.finishRequest()
+ } else if req.Header.get("Expect") != "" {
+ w.sendExpectationFailed()
break
}
- handler := c.server.Handler
- if handler == nil {
- handler = DefaultServeMux
- }
-
// HTTP cannot have multiple simultaneous active requests.[*]
// Until the server replies to this request, it can't read another,
// so we might as well run the handler in this goroutine.
// [*] Not strictly true: HTTP pipelining. We could let them all process
// in parallel even if their responses need to be serialized.
- handler.ServeHTTP(w, w.req)
- if c.hijacked {
+ serverHandler{c.server}.ServeHTTP(w, w.req)
+ if c.hijacked() {
return
}
w.finishRequest()
if w.closeAfterReply {
+ if w.requestBodyLimitHit {
+ c.closeWriteAndWait()
+ }
break
}
}
- c.close()
+}
+
+func (w *response) sendExpectationFailed() {
+ // TODO(bradfitz): let ServeHTTP handlers handle
+ // requests with non-standard expectation[s]? Seems
+ // theoretical at best, and doesn't fit into the
+ // current ServeHTTP model anyway. We'd need to
+ // make the ResponseWriter an optional
+ // "ExpectReplier" interface or something.
+ //
+ // For now we'll just obey RFC 2616 14.20 which says
+ // "If a server receives a request containing an
+ // Expect field that includes an expectation-
+ // extension that it does not support, it MUST
+ // respond with a 417 (Expectation Failed) status."
+ w.Header().Set("Connection", "close")
+ w.WriteHeader(StatusExpectationFailed)
+ w.finishRequest()
}
// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
// and a Hijacker.
func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
- if w.conn.hijacked {
- return nil, nil, ErrHijacked
+ if w.wroteHeader {
+ w.cw.flush()
}
- w.conn.hijacked = true
- rwc = w.conn.rwc
- buf = w.conn.buf
- w.conn.rwc = nil
- w.conn.buf = nil
- return
+ return w.conn.hijack()
+}
+
+func (w *response) CloseNotify() <-chan bool {
+ return w.conn.closeNotify()
}
// The HandlerFunc type is an adapter to allow the use of
@@ -817,13 +1053,13 @@ func RedirectHandler(url string, code int) Handler {
// patterns and calls the handler for the pattern that
// most closely matches the URL.
//
-// Patterns named fixed, rooted paths, like "/favicon.ico",
+// Patterns name fixed, rooted paths, like "/favicon.ico",
// or rooted subtrees, like "/images/" (note the trailing slash).
// Longer patterns take precedence over shorter ones, so that
// if there are handlers registered for both "/images/"
// and "/images/thumbnails/", the latter handler will be
// called for paths beginning "/images/thumbnails/" and the
-// former will receiver requests for any other paths in the
+// former will receive requests for any other paths in the
// "/images/" subtree.
//
// Patterns may optionally begin with a host name, restricting matches to
@@ -836,13 +1072,15 @@ func RedirectHandler(url string, code int) Handler {
// redirecting any request containing . or .. elements to an
// equivalent .- and ..-free URL.
type ServeMux struct {
- mu sync.RWMutex
- m map[string]muxEntry
+ mu sync.RWMutex
+ m map[string]muxEntry
+ hosts bool // whether any patterns contain hostnames
}
type muxEntry struct {
explicit bool
h Handler
+ pattern string
}
// NewServeMux allocates and returns a new ServeMux.
@@ -883,8 +1121,7 @@ func cleanPath(p string) string {
// Find a handler on a handler map given a path string
// Most-specific (longest) pattern wins
-func (mux *ServeMux) match(path string) Handler {
- var h Handler
+func (mux *ServeMux) match(path string) (h Handler, pattern string) {
var n = 0
for k, v := range mux.m {
if !pathMatch(k, path) {
@@ -893,37 +1130,64 @@ func (mux *ServeMux) match(path string) Handler {
if h == nil || len(k) > n {
n = len(k)
h = v.h
+ pattern = v.pattern
+ }
+ }
+ return
+}
+
+// Handler returns the handler to use for the given request,
+// consulting r.Method, r.Host, and r.URL.Path. It always returns
+// a non-nil handler. If the path is not in its canonical form, the
+// handler will be an internally-generated handler that redirects
+// to the canonical path.
+//
+// Handler also returns the registered pattern that matches the
+// request or, in the case of internally-generated redirects,
+// the pattern that will match after following the redirect.
+//
+// If there is no registered handler that applies to the request,
+// Handler returns a ``page not found'' handler and an empty pattern.
+func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) {
+ if r.Method != "CONNECT" {
+ if p := cleanPath(r.URL.Path); p != r.URL.Path {
+ _, pattern = mux.handler(r.Host, p)
+ return RedirectHandler(p, StatusMovedPermanently), pattern
}
}
- return h
+
+ return mux.handler(r.Host, r.URL.Path)
}
-// handler returns the handler to use for the request r.
-func (mux *ServeMux) handler(r *Request) Handler {
+// handler is the main implementation of Handler.
+// The path is known to be in canonical form, except for CONNECT methods.
+func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) {
mux.mu.RLock()
defer mux.mu.RUnlock()
// Host-specific pattern takes precedence over generic ones
- h := mux.match(r.Host + r.URL.Path)
+ if mux.hosts {
+ h, pattern = mux.match(host + path)
+ }
if h == nil {
- h = mux.match(r.URL.Path)
+ h, pattern = mux.match(path)
}
if h == nil {
- h = NotFoundHandler()
+ h, pattern = NotFoundHandler(), ""
}
- return h
+ return
}
// ServeHTTP dispatches the request to the handler whose
// pattern most closely matches the request URL.
func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
- // Clean path to canonical form and redirect.
- if p := cleanPath(r.URL.Path); p != r.URL.Path {
- w.Header().Set("Location", p)
- w.WriteHeader(StatusMovedPermanently)
+ if r.RequestURI == "*" {
+ w.Header().Set("Connection", "close")
+ w.WriteHeader(StatusBadRequest)
return
}
- mux.handler(r).ServeHTTP(w, r)
+ h, _ := mux.Handler(r)
+ h.ServeHTTP(w, r)
}
// Handle registers the handler for the given pattern.
@@ -942,14 +1206,26 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) {
panic("http: multiple registrations for " + pattern)
}
- mux.m[pattern] = muxEntry{explicit: true, h: handler}
+ mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern}
+
+ if pattern[0] != '/' {
+ mux.hosts = true
+ }
// Helpful behavior:
// If pattern is /tree/, insert an implicit permanent redirect for /tree.
// It can be overridden by an explicit registration.
n := len(pattern)
if n > 0 && pattern[n-1] == '/' && !mux.m[pattern[0:n-1]].explicit {
- mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(pattern, StatusMovedPermanently)}
+ // If pattern contains a host name, strip it and use remaining
+ // path for redirect.
+ path := pattern
+ if pattern[0] != '/' {
+ // In pattern, at least the last character is a '/', so
+ // strings.Index can't be -1.
+ path = pattern[strings.Index(pattern, "/"):]
+ }
+ mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(path, StatusMovedPermanently), pattern: pattern}
}
}
@@ -971,7 +1247,7 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
}
// Serve accepts incoming HTTP connections on the listener l,
-// creating a new service thread for each. The service threads
+// creating a new service goroutine for each. The service goroutines
// read requests and then call handler to reply to them.
// Handler is typically nil, in which case the DefaultServeMux is used.
func Serve(l net.Listener, handler Handler) error {
@@ -987,6 +1263,32 @@ type Server struct {
WriteTimeout time.Duration // maximum duration before timing out write of the response
MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0
TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS
+
+ // TLSNextProto optionally specifies a function to take over
+ // ownership of the provided TLS connection when an NPN
+ // protocol upgrade has occured. The map key is the protocol
+ // name negotiated. The Handler argument should be used to
+ // handle HTTP requests and will initialize the Request's TLS
+ // and RemoteAddr if not already set. The connection is
+ // automatically closed when the function returns.
+ TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
+}
+
+// serverHandler delegates to either the server's Handler or
+// DefaultServeMux and also handles "OPTIONS *" requests.
+type serverHandler struct {
+ srv *Server
+}
+
+func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
+ handler := sh.srv.Handler
+ if handler == nil {
+ handler = DefaultServeMux
+ }
+ if req.RequestURI == "*" && req.Method == "OPTIONS" {
+ handler = globalOptionsHandler{}
+ }
+ handler.ServeHTTP(rw, req)
}
// ListenAndServe listens on the TCP network address srv.Addr and then
@@ -1005,7 +1307,7 @@ func (srv *Server) ListenAndServe() error {
}
// Serve accepts incoming connections on the Listener l, creating a
-// new service thread for each. The service threads read requests and
+// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
func (srv *Server) Serve(l net.Listener) error {
defer l.Close()
@@ -1029,12 +1331,6 @@ func (srv *Server) Serve(l net.Listener) error {
return e
}
tempDelay = 0
- if srv.ReadTimeout != 0 {
- rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
- }
- if srv.WriteTimeout != 0 {
- rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
- }
c, err := srv.newConn(rw)
if err != nil {
continue
@@ -1150,7 +1446,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
// TimeoutHandler returns a Handler that runs h with the given time limit.
//
// The new Handler calls h.ServeHTTP to handle each request, but if a
-// call runs for more than ns nanoseconds, the handler responds with
+// call runs for longer than its time limit, the handler responds with
// a 503 Service Unavailable error and the given message in its body.
// (If msg is empty, a suitable default message will be sent.)
// After such a timeout, writes by h to its ResponseWriter will return
@@ -1180,7 +1476,7 @@ func (h *timeoutHandler) errorBody() string {
}
func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
- done := make(chan bool)
+ done := make(chan bool, 1)
tw := &timeoutWriter{w: w}
go func() {
h.handler.ServeHTTP(tw, r)
@@ -1232,3 +1528,86 @@ func (tw *timeoutWriter) WriteHeader(code int) {
tw.mu.Unlock()
tw.w.WriteHeader(code)
}
+
+// globalOptionsHandler responds to "OPTIONS *" requests.
+type globalOptionsHandler struct{}
+
+func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "0")
+ if r.ContentLength != 0 {
+ // Read up to 4KB of OPTIONS body (as mentioned in the
+ // spec as being reserved for future use), but anything
+ // over that is considered a waste of server resources
+ // (or an attack) and we abort and close the connection,
+ // courtesy of MaxBytesReader's EOF behavior.
+ mb := MaxBytesReader(w, r.Body, 4<<10)
+ io.Copy(ioutil.Discard, mb)
+ }
+}
+
+// eofReader is a non-nil io.ReadCloser that always returns EOF.
+var eofReader = ioutil.NopCloser(strings.NewReader(""))
+
+// initNPNRequest is an HTTP handler that initializes certain
+// uninitialized fields in its *Request. Such partially-initialized
+// Requests come from NPN protocol handlers.
+type initNPNRequest struct {
+ c *tls.Conn
+ h serverHandler
+}
+
+func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) {
+ if req.TLS == nil {
+ req.TLS = &tls.ConnectionState{}
+ *req.TLS = h.c.ConnectionState()
+ }
+ if req.Body == nil {
+ req.Body = eofReader
+ }
+ if req.RemoteAddr == "" {
+ req.RemoteAddr = h.c.RemoteAddr().String()
+ }
+ h.h.ServeHTTP(rw, req)
+}
+
+// loggingConn is used for debugging.
+type loggingConn struct {
+ name string
+ net.Conn
+}
+
+var (
+ uniqNameMu sync.Mutex
+ uniqNameNext = make(map[string]int)
+)
+
+func newLoggingConn(baseName string, c net.Conn) net.Conn {
+ uniqNameMu.Lock()
+ defer uniqNameMu.Unlock()
+ uniqNameNext[baseName]++
+ return &loggingConn{
+ name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]),
+ Conn: c,
+ }
+}
+
+func (c *loggingConn) Write(p []byte) (n int, err error) {
+ log.Printf("%s.Write(%d) = ....", c.name, len(p))
+ n, err = c.Conn.Write(p)
+ log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err)
+ return
+}
+
+func (c *loggingConn) Read(p []byte) (n int, err error) {
+ log.Printf("%s.Read(%d) = ....", c.name, len(p))
+ n, err = c.Conn.Read(p)
+ log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err)
+ return
+}
+
+func (c *loggingConn) Close() (err error) {
+ log.Printf("%s.Close() = ...", c.name)
+ err = c.Conn.Close()
+ log.Printf("%s.Close() = %v", c.name, err)
+ return
+}
diff --git a/src/pkg/net/http/server_test.go b/src/pkg/net/http/server_test.go
new file mode 100644
index 000000000..8b4e8c6d6
--- /dev/null
+++ b/src/pkg/net/http/server_test.go
@@ -0,0 +1,95 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "net/url"
+ "testing"
+)
+
+var serveMuxRegister = []struct {
+ pattern string
+ h Handler
+}{
+ {"/dir/", serve(200)},
+ {"/search", serve(201)},
+ {"codesearch.google.com/search", serve(202)},
+ {"codesearch.google.com/", serve(203)},
+}
+
+// serve returns a handler that sends a response with the given code.
+func serve(code int) HandlerFunc {
+ return func(w ResponseWriter, r *Request) {
+ w.WriteHeader(code)
+ }
+}
+
+var serveMuxTests = []struct {
+ method string
+ host string
+ path string
+ code int
+ pattern string
+}{
+ {"GET", "google.com", "/", 404, ""},
+ {"GET", "google.com", "/dir", 301, "/dir/"},
+ {"GET", "google.com", "/dir/", 200, "/dir/"},
+ {"GET", "google.com", "/dir/file", 200, "/dir/"},
+ {"GET", "google.com", "/search", 201, "/search"},
+ {"GET", "google.com", "/search/", 404, ""},
+ {"GET", "google.com", "/search/foo", 404, ""},
+ {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
+ {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
+ {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
+ {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
+ {"GET", "images.google.com", "/search", 201, "/search"},
+ {"GET", "images.google.com", "/search/", 404, ""},
+ {"GET", "images.google.com", "/search/foo", 404, ""},
+ {"GET", "google.com", "/../search", 301, "/search"},
+ {"GET", "google.com", "/dir/..", 301, ""},
+ {"GET", "google.com", "/dir/..", 301, ""},
+ {"GET", "google.com", "/dir/./file", 301, "/dir/"},
+
+ // The /foo -> /foo/ redirect applies to CONNECT requests
+ // but the path canonicalization does not.
+ {"CONNECT", "google.com", "/dir", 301, "/dir/"},
+ {"CONNECT", "google.com", "/../search", 404, ""},
+ {"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
+ {"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
+ {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
+}
+
+func TestServeMuxHandler(t *testing.T) {
+ mux := NewServeMux()
+ for _, e := range serveMuxRegister {
+ mux.Handle(e.pattern, e.h)
+ }
+
+ for _, tt := range serveMuxTests {
+ r := &Request{
+ Method: tt.method,
+ Host: tt.host,
+ URL: &url.URL{
+ Path: tt.path,
+ },
+ }
+ h, pattern := mux.Handler(r)
+ cs := &codeSaver{h: Header{}}
+ h.ServeHTTP(cs, r)
+ if pattern != tt.pattern || cs.code != tt.code {
+ t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, cs.code, pattern, tt.code, tt.pattern)
+ }
+ }
+}
+
+// A codeSaver is a ResponseWriter that saves the code passed to WriteHeader.
+type codeSaver struct {
+ h Header
+ code int
+}
+
+func (cs *codeSaver) Header() Header { return cs.h }
+func (cs *codeSaver) Write(p []byte) (int, error) { return len(p), nil }
+func (cs *codeSaver) WriteHeader(code int) { cs.code = code }
diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go
index 9e9d84172..43c6023a3 100644
--- a/src/pkg/net/http/transfer.go
+++ b/src/pkg/net/http/transfer.go
@@ -87,10 +87,8 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) {
// Sanitize Body,ContentLength,TransferEncoding
if t.ResponseToHEAD {
t.Body = nil
- t.TransferEncoding = nil
- // ContentLength is expected to hold Content-Length
- if t.ContentLength < 0 {
- return nil, ErrMissingContentLength
+ if chunked(t.TransferEncoding) {
+ t.ContentLength = -1
}
} else {
if !atLeastHTTP11 || t.Body == nil {
@@ -122,9 +120,6 @@ func (t *transferWriter) shouldSendContentLength() bool {
if t.ContentLength > 0 {
return true
}
- if t.ResponseToHEAD {
- return true
- }
// Many servers expect a Content-Length for these methods
if t.Method == "POST" || t.Method == "PUT" {
return true
@@ -199,10 +194,11 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) {
ncopy, err = io.Copy(w, t.Body)
} else {
ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength))
- nextra, err := io.Copy(ioutil.Discard, t.Body)
if err != nil {
return err
}
+ var nextra int64
+ nextra, err = io.Copy(ioutil.Discard, t.Body)
ncopy += nextra
}
if err != nil {
@@ -213,7 +209,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) {
}
}
- if t.ContentLength != -1 && t.ContentLength != ncopy {
+ if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy {
return fmt.Errorf("http: Request.ContentLength=%d with Body length %d",
t.ContentLength, ncopy)
}
@@ -294,10 +290,19 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
return err
}
- t.ContentLength, err = fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
+ realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
if err != nil {
return err
}
+ if isResponse && t.RequestMethod == "HEAD" {
+ if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil {
+ return err
+ } else {
+ t.ContentLength = n
+ }
+ } else {
+ t.ContentLength = realLength
+ }
// Trailer
t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding)
@@ -310,7 +315,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
// See RFC2616, section 4.4.
switch msg.(type) {
case *Response:
- if t.ContentLength == -1 &&
+ if realLength == -1 &&
!chunked(t.TransferEncoding) &&
bodyAllowedForStatus(t.StatusCode) {
// Unbounded body.
@@ -322,12 +327,16 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
// or close connection when finished, since multipart is not supported yet
switch {
case chunked(t.TransferEncoding):
- t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
- case t.ContentLength >= 0:
+ if noBodyExpected(t.RequestMethod) {
+ t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close}
+ } else {
+ t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
+ }
+ case realLength >= 0:
// TODO: limit the Content-Length. This is an easy DoS vector.
- t.Body = &body{Reader: io.LimitReader(r, t.ContentLength), closing: t.Close}
+ t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close}
default:
- // t.ContentLength < 0, i.e. "Content-Length" not mentioned in header
+ // realLength < 0, i.e. "Content-Length" not mentioned in header
if t.Close {
// Close semantics (i.e. HTTP/1.0)
t.Body = &body{Reader: r, closing: t.Close}
@@ -371,12 +380,6 @@ func fixTransferEncoding(requestMethod string, header Header) ([]string, error)
delete(header, "Transfer-Encoding")
- // Head responses have no bodies, so the transfer encoding
- // should be ignored.
- if requestMethod == "HEAD" {
- return nil, nil
- }
-
encodings := strings.Split(raw[0], ",")
te := make([]string, 0, len(encodings))
// TODO: Even though we only support "identity" and "chunked"
@@ -432,11 +435,11 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
}
// Logic based on Content-Length
- cl := strings.TrimSpace(header.Get("Content-Length"))
+ cl := strings.TrimSpace(header.get("Content-Length"))
if cl != "" {
- n, err := strconv.ParseInt(cl, 10, 64)
- if err != nil || n < 0 {
- return -1, &badStringError{"bad Content-Length", cl}
+ n, err := parseContentLength(cl)
+ if err != nil {
+ return -1, err
}
return n, nil
} else {
@@ -451,13 +454,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
return 0, nil
}
- // Logic based on media type. The purpose of the following code is just
- // to detect whether the unsupported "multipart/byteranges" is being
- // used. A proper Content-Type parser is needed in the future.
- if strings.Contains(strings.ToLower(header.Get("Content-Type")), "multipart/byteranges") {
- return -1, ErrNotSupported
- }
-
// Body-EOF logic based on other methods (like closing, or chunked coding)
return -1, nil
}
@@ -469,14 +465,14 @@ func shouldClose(major, minor int, header Header) bool {
if major < 1 {
return true
} else if major == 1 && minor == 0 {
- if !strings.Contains(strings.ToLower(header.Get("Connection")), "keep-alive") {
+ if !strings.Contains(strings.ToLower(header.get("Connection")), "keep-alive") {
return true
}
return false
} else {
// TODO: Should split on commas, toss surrounding white space,
// and check each field.
- if strings.ToLower(header.Get("Connection")) == "close" {
+ if strings.ToLower(header.get("Connection")) == "close" {
header.Del("Connection")
return true
}
@@ -486,7 +482,7 @@ func shouldClose(major, minor int, header Header) bool {
// Parse the trailer header
func fixTrailer(header Header, te []string) (Header, error) {
- raw := header.Get("Trailer")
+ raw := header.get("Trailer")
if raw == "" {
return nil, nil
}
@@ -525,11 +521,11 @@ type body struct {
res *response // response writer for server requests, else nil
}
-// ErrBodyReadAfterClose is returned when reading a Request Body after
-// the body has been closed. This typically happens when the body is
+// ErrBodyReadAfterClose is returned when reading a Request or Response
+// Body after the body has been closed. This typically happens when the body is
// read after an HTTP Handler calls WriteHeader or Write on its
// ResponseWriter.
-var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed request Body")
+var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
func (b *body) Read(p []byte) (n int, err error) {
if b.closed {
@@ -567,14 +563,22 @@ func seeUpcomingDoubleCRLF(r *bufio.Reader) bool {
return false
}
+var errTrailerEOF = errors.New("http: unexpected EOF reading trailer")
+
func (b *body) readTrailer() error {
// The common case, since nobody uses trailers.
- buf, _ := b.r.Peek(2)
+ buf, err := b.r.Peek(2)
if bytes.Equal(buf, singleCRLF) {
b.r.ReadByte()
b.r.ReadByte()
return nil
}
+ if len(buf) < 2 {
+ return errTrailerEOF
+ }
+ if err != nil {
+ return err
+ }
// Make sure there's a header terminator coming up, to prevent
// a DoS with an unbounded size Trailer. It's not easy to
@@ -590,6 +594,9 @@ func (b *body) readTrailer() error {
hdr, err := textproto.NewReader(b.r).ReadMIMEHeader()
if err != nil {
+ if err == io.EOF {
+ return errTrailerEOF
+ }
return err
}
switch rr := b.hdr.(type) {
@@ -630,3 +637,18 @@ func (b *body) Close() error {
}
return nil
}
+
+// parseContentLength trims whitespace from s and returns -1 if no value
+// is set, or the value if it's >= 0.
+func parseContentLength(cl string) (int64, error) {
+ cl = strings.TrimSpace(cl)
+ if cl == "" {
+ return -1, nil
+ }
+ n, err := strconv.ParseInt(cl, 10, 64)
+ if err != nil || n < 0 {
+ return 0, &badStringError{"bad Content-Length", cl}
+ }
+ return n, nil
+
+}
diff --git a/src/pkg/net/http/transfer_test.go b/src/pkg/net/http/transfer_test.go
new file mode 100644
index 000000000..8627a374c
--- /dev/null
+++ b/src/pkg/net/http/transfer_test.go
@@ -0,0 +1,37 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "strings"
+ "testing"
+)
+
+func TestBodyReadBadTrailer(t *testing.T) {
+ b := &body{
+ Reader: strings.NewReader("foobar"),
+ hdr: true, // force reading the trailer
+ r: bufio.NewReader(strings.NewReader("")),
+ }
+ buf := make([]byte, 7)
+ n, err := b.Read(buf[:3])
+ got := string(buf[:n])
+ if got != "foo" || err != nil {
+ t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err)
+ }
+
+ n, err = b.Read(buf[:])
+ got = string(buf[:n])
+ if got != "bar" || err != nil {
+ t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err)
+ }
+
+ n, err = b.Read(buf[:])
+ got = string(buf[:n])
+ if err == nil {
+ t.Errorf("final Read was successful (%q), expected error from trailer read", got)
+ }
+}
diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go
index 6efe191eb..685d7d56c 100644
--- a/src/pkg/net/http/transport.go
+++ b/src/pkg/net/http/transport.go
@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
// HTTP client implementation. See RFC 2616.
-//
+//
// This is the low-level Transport implementation of RoundTripper.
// The high-level interface is in client.go.
@@ -24,13 +24,14 @@ import (
"os"
"strings"
"sync"
+ "time"
)
// DefaultTransport is the default implementation of Transport and is
-// used by DefaultClient. It establishes a new network connection for
-// each call to Do and uses HTTP proxies as directed by the
-// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy)
-// environment variables.
+// used by DefaultClient. It establishes network connections as needed
+// and caches them for reuse by subsequent calls. It uses HTTP proxies
+// as directed by the $HTTP_PROXY and $NO_PROXY (or $http_proxy and
+// $no_proxy) environment variables.
var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment}
// DefaultMaxIdleConnsPerHost is the default value of Transport's
@@ -41,8 +42,11 @@ const DefaultMaxIdleConnsPerHost = 2
// https, and http proxies (for either http or https with CONNECT).
// Transport can also cache connections for future re-use.
type Transport struct {
- lk sync.Mutex
+ idleMu sync.Mutex
idleConn map[string][]*persistConn
+ reqMu sync.Mutex
+ reqConn map[*Request]*persistConn
+ altMu sync.RWMutex
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
// TODO: tunable on global max cached connections
@@ -68,9 +72,15 @@ type Transport struct {
DisableCompression bool
// MaxIdleConnsPerHost, if non-zero, controls the maximum idle
- // (keep-alive) to keep to keep per-host. If zero,
+ // (keep-alive) to keep per-host. If zero,
// DefaultMaxIdleConnsPerHost is used.
MaxIdleConnsPerHost int
+
+ // ResponseHeaderTimeout, if non-zero, specifies the amount of
+ // time to wait for a server's response headers after fully
+ // writing the request (including its body, if any). This
+ // time does not include the time to read the response body.
+ ResponseHeaderTimeout time.Duration
}
// ProxyFromEnvironment returns the URL of the proxy to use for a
@@ -88,7 +98,7 @@ func ProxyFromEnvironment(req *Request) (*url.URL, error) {
return nil, nil
}
proxyURL, err := url.Parse(proxy)
- if err != nil || proxyURL.Scheme == "" {
+ if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") {
if u, err := url.Parse("http://" + proxy); err == nil {
proxyURL = u
err = nil
@@ -131,17 +141,20 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
return nil, errors.New("http: nil Request.Header")
}
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
- t.lk.Lock()
+ t.altMu.RLock()
var rt RoundTripper
if t.altProto != nil {
rt = t.altProto[req.URL.Scheme]
}
- t.lk.Unlock()
+ t.altMu.RUnlock()
if rt == nil {
return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
}
return rt.RoundTrip(req)
}
+ if req.URL.Host == "" {
+ return nil, errors.New("http: no Host in request URL")
+ }
treq := &transportRequest{Request: req}
cm, err := t.connectMethodForRequest(treq)
if err != nil {
@@ -170,8 +183,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
if scheme == "http" || scheme == "https" {
panic("protocol " + scheme + " already registered")
}
- t.lk.Lock()
- defer t.lk.Unlock()
+ t.altMu.Lock()
+ defer t.altMu.Unlock()
if t.altProto == nil {
t.altProto = make(map[string]RoundTripper)
}
@@ -186,17 +199,29 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
// a "keep-alive" state. It does not interrupt any connections currently
// in use.
func (t *Transport) CloseIdleConnections() {
- t.lk.Lock()
- defer t.lk.Unlock()
- if t.idleConn == nil {
+ t.idleMu.Lock()
+ m := t.idleConn
+ t.idleConn = nil
+ t.idleMu.Unlock()
+ if m == nil {
return
}
- for _, conns := range t.idleConn {
+ for _, conns := range m {
for _, pconn := range conns {
pconn.close()
}
}
- t.idleConn = make(map[string][]*persistConn)
+}
+
+// CancelRequest cancels an in-flight request by closing its
+// connection.
+func (t *Transport) CancelRequest(req *Request) {
+ t.reqMu.Lock()
+ pc := t.reqConn[req]
+ t.reqMu.Unlock()
+ if pc != nil {
+ pc.conn.Close()
+ }
}
//
@@ -242,8 +267,6 @@ func (cm *connectMethod) proxyAuth() string {
// If pconn is no longer needed or not in a good state, putIdleConn
// returns false.
func (t *Transport) putIdleConn(pconn *persistConn) bool {
- t.lk.Lock()
- defer t.lk.Unlock()
if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 {
pconn.close()
return false
@@ -256,21 +279,32 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
if max == 0 {
max = DefaultMaxIdleConnsPerHost
}
+ t.idleMu.Lock()
+ if t.idleConn == nil {
+ t.idleConn = make(map[string][]*persistConn)
+ }
if len(t.idleConn[key]) >= max {
+ t.idleMu.Unlock()
pconn.close()
return false
}
+ for _, exist := range t.idleConn[key] {
+ if exist == pconn {
+ log.Fatalf("dup idle pconn %p in freelist", pconn)
+ }
+ }
t.idleConn[key] = append(t.idleConn[key], pconn)
+ t.idleMu.Unlock()
return true
}
func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
- t.lk.Lock()
- defer t.lk.Unlock()
+ key := cm.String()
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
if t.idleConn == nil {
- t.idleConn = make(map[string][]*persistConn)
+ return nil
}
- key := cm.String()
for {
pconns, ok := t.idleConn[key]
if !ok {
@@ -289,7 +323,20 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
return
}
}
- return
+ panic("unreachable")
+}
+
+func (t *Transport) setReqConn(r *Request, pc *persistConn) {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ if t.reqConn == nil {
+ t.reqConn = make(map[*Request]*persistConn)
+ }
+ if pc != nil {
+ t.reqConn[r] = pc
+ } else {
+ delete(t.reqConn, r)
+ }
}
func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
@@ -323,6 +370,8 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
cacheKey: cm.String(),
conn: conn,
reqch: make(chan requestAndChan, 50),
+ writech: make(chan writeRequest, 50),
+ closech: make(chan struct{}),
}
switch {
@@ -365,7 +414,18 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
if cm.targetScheme == "https" {
// Initiate TLS and check remote host name against certificate.
- conn = tls.Client(conn, t.TLSClientConfig)
+ cfg := t.TLSClientConfig
+ if cfg == nil || cfg.ServerName == "" {
+ host := cm.tlsHost()
+ if cfg == nil {
+ cfg = &tls.Config{ServerName: host}
+ } else {
+ clone := *cfg // shallow clone
+ clone.ServerName = host
+ cfg = &clone
+ }
+ }
+ conn = tls.Client(conn, cfg)
if err = conn.(*tls.Conn).Handshake(); err != nil {
return nil, err
}
@@ -380,6 +440,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
pconn.br = bufio.NewReader(pconn.conn)
pconn.bw = bufio.NewWriter(pconn.conn)
go pconn.readLoop()
+ go pconn.writeLoop()
return pconn, nil
}
@@ -421,7 +482,15 @@ func useProxy(addr string) bool {
if hasPort(p) {
p = p[:strings.LastIndex(p, ":")]
}
- if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) {
+ if addr == p {
+ return false
+ }
+ if p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:]) {
+ // no_proxy ".foo.com" matches "bar.foo.com" or "foo.com"
+ return false
+ }
+ if p[0] != '.' && strings.HasSuffix(addr, p) && addr[len(addr)-len(p)-1] == '.' {
+ // no_proxy "foo.com" matches "bar.foo.com"
return false
}
}
@@ -484,25 +553,28 @@ type persistConn struct {
t *Transport
cacheKey string // its connectMethod.String()
conn net.Conn
+ closed bool // whether conn has been closed
br *bufio.Reader // from conn
bw *bufio.Writer // to conn
- reqch chan requestAndChan // written by roundTrip(); read by readLoop()
+ reqch chan requestAndChan // written by roundTrip; read by readLoop
+ writech chan writeRequest // written by roundTrip; read by writeLoop
+ closech chan struct{} // broadcast close when readLoop (TCP connection) closes
isProxy bool
+ lk sync.Mutex // guards following 3 fields
+ numExpectedResponses int
+ broken bool // an error has happened on this connection; marked broken so it's not reused.
// mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the
// original Request given to RoundTrip is not modified)
mutateHeaderFunc func(Header)
-
- lk sync.Mutex // guards numExpectedResponses and broken
- numExpectedResponses int
- broken bool // an error has happened on this connection; marked broken so it's not reused.
}
func (pc *persistConn) isBroken() bool {
pc.lk.Lock()
- defer pc.lk.Unlock()
- return pc.broken
+ b := pc.broken
+ pc.lk.Unlock()
+ return b
}
var remoteSideClosedFunc func(error) bool // or nil to use default
@@ -518,6 +590,7 @@ func remoteSideClosed(err error) bool {
}
func (pc *persistConn) readLoop() {
+ defer close(pc.closech)
alive := true
var lastbody io.ReadCloser // last response body, if any, read on this connection
@@ -544,12 +617,16 @@ func (pc *persistConn) readLoop() {
lastbody.Close() // assumed idempotent
lastbody = nil
}
- resp, err := ReadResponse(pc.br, rc.req)
+
+ var resp *Response
+ if err == nil {
+ resp, err = ReadResponse(pc.br, rc.req)
+ }
+ hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0
if err != nil {
pc.close()
} else {
- hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" {
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
@@ -569,31 +646,37 @@ func (pc *persistConn) readLoop() {
alive = false
}
- hasBody := resp != nil && resp.ContentLength != 0
var waitForBodyRead chan bool
- if alive {
- if hasBody {
- lastbody = resp.Body
- waitForBodyRead = make(chan bool)
- resp.Body.(*bodyEOFSignal).fn = func() {
- if !pc.t.putIdleConn(pc) {
- alive = false
- }
- waitForBodyRead <- true
+ if hasBody {
+ lastbody = resp.Body
+ waitForBodyRead = make(chan bool, 1)
+ resp.Body.(*bodyEOFSignal).fn = func(err error) {
+ alive1 := alive
+ if err != nil {
+ alive1 = false
}
- } else {
- // When there's no response body, we immediately
- // reuse the TCP connection (putIdleConn), but
- // we need to prevent ClientConn.Read from
- // closing the Response.Body on the next
- // loop, otherwise it might close the body
- // before the client code has had a chance to
- // read it (even though it'll just be 0, EOF).
- lastbody = nil
-
- if !pc.t.putIdleConn(pc) {
- alive = false
+ if alive1 && !pc.t.putIdleConn(pc) {
+ alive1 = false
+ }
+ if !alive1 || pc.isBroken() {
+ pc.close()
}
+ waitForBodyRead <- alive1
+ }
+ }
+
+ if alive && !hasBody {
+ // When there's no response body, we immediately
+ // reuse the TCP connection (putIdleConn), but
+ // we need to prevent ClientConn.Read from
+ // closing the Response.Body on the next
+ // loop, otherwise it might close the body
+ // before the client code has had a chance to
+ // read it (even though it'll just be 0, EOF).
+ lastbody = nil
+
+ if !pc.t.putIdleConn(pc) {
+ alive = false
}
}
@@ -602,7 +685,35 @@ func (pc *persistConn) readLoop() {
// Wait for the just-returned response body to be fully consumed
// before we race and peek on the underlying bufio reader.
if waitForBodyRead != nil {
- <-waitForBodyRead
+ alive = <-waitForBodyRead
+ }
+
+ pc.t.setReqConn(rc.req, nil)
+
+ if !alive {
+ pc.close()
+ }
+ }
+}
+
+func (pc *persistConn) writeLoop() {
+ for {
+ select {
+ case wr := <-pc.writech:
+ if pc.isBroken() {
+ wr.ch <- errors.New("http: can't write HTTP request on broken connection")
+ continue
+ }
+ err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra)
+ if err == nil {
+ err = pc.bw.Flush()
+ }
+ if err != nil {
+ pc.markBroken()
+ }
+ wr.ch <- err
+ case <-pc.closech:
+ return
}
}
}
@@ -622,9 +733,24 @@ type requestAndChan struct {
addedGzip bool
}
+// A writeRequest is sent by the readLoop's goroutine to the
+// writeLoop's goroutine to write a request while the read loop
+// concurrently waits on both the write response and the server's
+// reply.
+type writeRequest struct {
+ req *transportRequest
+ ch chan<- error
+}
+
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
- if pc.mutateHeaderFunc != nil {
- pc.mutateHeaderFunc(req.extraHeaders())
+ pc.t.setReqConn(req.Request, pc)
+ pc.lk.Lock()
+ pc.numExpectedResponses++
+ headerFn := pc.mutateHeaderFunc
+ pc.lk.Unlock()
+
+ if headerFn != nil {
+ headerFn(req.extraHeaders())
}
// Ask for a compressed version if the caller didn't set their
@@ -633,34 +759,84 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
// requested it.
requestedGzip := false
if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" {
- // Request gzip only, not deflate. Deflate is ambiguous and
+ // Request gzip only, not deflate. Deflate is ambiguous and
// not as universally supported anyway.
// See: http://www.gzip.org/zlib/zlib_faq.html#faq38
requestedGzip = true
req.extraHeaders().Set("Accept-Encoding", "gzip")
}
- pc.lk.Lock()
- pc.numExpectedResponses++
- pc.lk.Unlock()
+ // Write the request concurrently with waiting for a response,
+ // in case the server decides to reply before reading our full
+ // request body.
+ writeErrCh := make(chan error, 1)
+ pc.writech <- writeRequest{req, writeErrCh}
- err = req.Request.write(pc.bw, pc.isProxy, req.extra)
- if err != nil {
- pc.close()
- return
+ resc := make(chan responseAndError, 1)
+ pc.reqch <- requestAndChan{req.Request, resc, requestedGzip}
+
+ var re responseAndError
+ var pconnDeadCh = pc.closech
+ var failTicker <-chan time.Time
+ var respHeaderTimer <-chan time.Time
+WaitResponse:
+ for {
+ select {
+ case err := <-writeErrCh:
+ if err != nil {
+ re = responseAndError{nil, err}
+ pc.close()
+ break WaitResponse
+ }
+ if d := pc.t.ResponseHeaderTimeout; d > 0 {
+ respHeaderTimer = time.After(d)
+ }
+ case <-pconnDeadCh:
+ // The persist connection is dead. This shouldn't
+ // usually happen (only with Connection: close responses
+ // with no response bodies), but if it does happen it
+ // means either a) the remote server hung up on us
+ // prematurely, or b) the readLoop sent us a response &
+ // closed its closech at roughly the same time, and we
+ // selected this case first, in which case a response
+ // might still be coming soon.
+ //
+ // We can't avoid the select race in b) by using a unbuffered
+ // resc channel instead, because then goroutines can
+ // leak if we exit due to other errors.
+ pconnDeadCh = nil // avoid spinning
+ failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc
+ case <-failTicker:
+ re = responseAndError{err: errors.New("net/http: transport closed before response was received")}
+ break WaitResponse
+ case <-respHeaderTimer:
+ pc.close()
+ re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")}
+ break WaitResponse
+ case re = <-resc:
+ break WaitResponse
+ }
}
- pc.bw.Flush()
- ch := make(chan responseAndError, 1)
- pc.reqch <- requestAndChan{req.Request, ch, requestedGzip}
- re := <-ch
pc.lk.Lock()
pc.numExpectedResponses--
pc.lk.Unlock()
+ if re.err != nil {
+ pc.t.setReqConn(req.Request, nil)
+ }
return re.res, re.err
}
+// markBroken marks a connection as broken (so it's not reused).
+// It differs from close in that it doesn't close the underlying
+// connection for use when it's still being read.
+func (pc *persistConn) markBroken() {
+ pc.lk.Lock()
+ defer pc.lk.Unlock()
+ pc.broken = true
+}
+
func (pc *persistConn) close() {
pc.lk.Lock()
defer pc.lk.Unlock()
@@ -669,7 +845,10 @@ func (pc *persistConn) close() {
func (pc *persistConn) closeLocked() {
pc.broken = true
- pc.conn.Close()
+ if !pc.closed {
+ pc.conn.Close()
+ pc.closed = true
+ }
pc.mutateHeaderFunc = nil
}
@@ -687,43 +866,62 @@ func canonicalAddr(url *url.URL) string {
return addr
}
-func responseIsKeepAlive(res *Response) bool {
- // TODO: implement. for now just always shutting down the connection.
- return false
-}
-
// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most
-// once, right before the final Read() or Close() call returns, but after
-// EOF has been seen.
+// once, right before its final (error-producing) Read or Close call
+// returns.
type bodyEOFSignal struct {
- body io.ReadCloser
- fn func()
- isClosed bool
+ body io.ReadCloser
+ mu sync.Mutex // guards closed, rerr and fn
+ closed bool // whether Close has been called
+ rerr error // sticky Read error
+ fn func(error) // error will be nil on Read io.EOF
}
func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
- n, err = es.body.Read(p)
- if es.isClosed && n > 0 {
- panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725")
+ es.mu.Lock()
+ closed, rerr := es.closed, es.rerr
+ es.mu.Unlock()
+ if closed {
+ return 0, errors.New("http: read on closed response body")
+ }
+ if rerr != nil {
+ return 0, rerr
}
- if err == io.EOF && es.fn != nil {
- es.fn()
- es.fn = nil
+
+ n, err = es.body.Read(p)
+ if err != nil {
+ es.mu.Lock()
+ defer es.mu.Unlock()
+ if es.rerr == nil {
+ es.rerr = err
+ }
+ es.condfn(err)
}
return
}
-func (es *bodyEOFSignal) Close() (err error) {
- if es.isClosed {
+func (es *bodyEOFSignal) Close() error {
+ es.mu.Lock()
+ defer es.mu.Unlock()
+ if es.closed {
return nil
}
- es.isClosed = true
- err = es.body.Close()
- if err == nil && es.fn != nil {
- es.fn()
- es.fn = nil
+ es.closed = true
+ err := es.body.Close()
+ es.condfn(err)
+ return err
+}
+
+// caller must hold es.mu.
+func (es *bodyEOFSignal) condfn(err error) {
+ if es.fn == nil {
+ return
}
- return
+ if err == io.EOF {
+ err = nil
+ }
+ es.fn(err)
+ es.fn = nil
}
type readFirstCloseBoth struct {
diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go
index a9e401de5..68010e68b 100644
--- a/src/pkg/net/http/transport_test.go
+++ b/src/pkg/net/http/transport_test.go
@@ -13,6 +13,7 @@ import (
"fmt"
"io"
"io/ioutil"
+ "net"
. "net/http"
"net/http/httptest"
"net/url"
@@ -20,6 +21,7 @@ import (
"runtime"
"strconv"
"strings"
+ "sync"
"testing"
"time"
)
@@ -35,14 +37,78 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte(r.RemoteAddr))
})
+// testCloseConn is a net.Conn tracked by a testConnSet.
+type testCloseConn struct {
+ net.Conn
+ set *testConnSet
+}
+
+func (c *testCloseConn) Close() error {
+ c.set.remove(c)
+ return c.Conn.Close()
+}
+
+// testConnSet tracks a set of TCP connections and whether they've
+// been closed.
+type testConnSet struct {
+ t *testing.T
+ closed map[net.Conn]bool
+ list []net.Conn // in order created
+ mutex sync.Mutex
+}
+
+func (tcs *testConnSet) insert(c net.Conn) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+ tcs.closed[c] = false
+ tcs.list = append(tcs.list, c)
+}
+
+func (tcs *testConnSet) remove(c net.Conn) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+ tcs.closed[c] = true
+}
+
+// some tests use this to manage raw tcp connections for later inspection
+func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
+ connSet := &testConnSet{
+ t: t,
+ closed: make(map[net.Conn]bool),
+ }
+ dial := func(n, addr string) (net.Conn, error) {
+ c, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ tc := &testCloseConn{c, connSet}
+ connSet.insert(tc)
+ return tc, nil
+ }
+ return connSet, dial
+}
+
+func (tcs *testConnSet) check(t *testing.T) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+
+ for i, c := range tcs.list {
+ if !tcs.closed[c] {
+ t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
+ }
+ }
+}
+
// Two subsequent requests and verify their response is the same.
// The response from the server is our own IP:port
func TestTransportKeepAlives(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
for _, disableKeepAlive := range []bool{false, true} {
tr := &Transport{DisableKeepAlives: disableKeepAlive}
+ defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
fetch := func(n int) string {
@@ -69,11 +135,16 @@ func TestTransportKeepAlives(t *testing.T) {
}
func TestTransportConnectionCloseOnResponse(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
+ connSet, testDial := makeTestDial(t)
+
for _, connectionClose := range []bool{false, true} {
- tr := &Transport{}
+ tr := &Transport{
+ Dial: testDial,
+ }
c := &Client{Transport: tr}
fetch := func(n int) string {
@@ -92,8 +163,8 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
if err != nil {
t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
}
- body, err := ioutil.ReadAll(res.Body)
defer res.Body.Close()
+ body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
}
@@ -107,15 +178,24 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
connectionClose, bodiesDiffer, body1, body2)
}
+
+ tr.CloseIdleConnections()
}
+
+ connSet.check(t)
}
func TestTransportConnectionCloseOnRequest(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
+ connSet, testDial := makeTestDial(t)
+
for _, connectionClose := range []bool{false, true} {
- tr := &Transport{}
+ tr := &Transport{
+ Dial: testDial,
+ }
c := &Client{Transport: tr}
fetch := func(n int) string {
@@ -149,10 +229,15 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) {
t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
connectionClose, bodiesDiffer, body1, body2)
}
+
+ tr.CloseIdleConnections()
}
+
+ connSet.check(t)
}
func TestTransportIdleCacheKeys(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
@@ -185,6 +270,7 @@ func TestTransportIdleCacheKeys(t *testing.T) {
}
func TestTransportMaxPerHostIdleConns(t *testing.T) {
+ defer checkLeakedTransports(t)
resch := make(chan string)
gotReq := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -201,7 +287,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
c := &Client{Transport: tr}
// Start 3 outstanding requests and wait for the server to get them.
- // Their responses will hang until we we write to resch, though.
+ // Their responses will hang until we write to resch, though.
donech := make(chan bool)
doReq := func() {
resp, err := c.Get(ts.URL)
@@ -253,6 +339,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
}
func TestTransportServerClosingUnexpectedly(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
@@ -309,9 +396,9 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) {
// Test for http://golang.org/issue/2616 (appropriate issue number)
// This fails pretty reliably with GOMAXPROCS=100 or something high.
func TestStressSurpriseServerCloses(t *testing.T) {
+ defer checkLeakedTransports(t)
if testing.Short() {
- t.Logf("skipping test in short mode")
- return
+ t.Skip("skipping test in short mode")
}
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Length", "5")
@@ -365,6 +452,7 @@ func TestStressSurpriseServerCloses(t *testing.T) {
// TestTransportHeadResponses verifies that we deal with Content-Lengths
// with no bodies properly
func TestTransportHeadResponses(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "HEAD" {
panic("expected HEAD; got " + r.Method)
@@ -384,7 +472,7 @@ func TestTransportHeadResponses(t *testing.T) {
if e, g := "123", res.Header.Get("Content-Length"); e != g {
t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
}
- if e, g := int64(0), res.ContentLength; e != g {
+ if e, g := int64(123), res.ContentLength; e != g {
t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
}
}
@@ -393,6 +481,7 @@ func TestTransportHeadResponses(t *testing.T) {
// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
// on responses to HEAD requests.
func TestTransportHeadChunkedResponse(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "HEAD" {
panic("expected HEAD; got " + r.Method)
@@ -434,6 +523,7 @@ var roundTripTests = []struct {
// Test that the modification made to the Request by the RoundTripper is cleaned up
func TestRoundTripGzip(t *testing.T) {
+ defer checkLeakedTransports(t)
const responseBody = "test response body"
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
accept := req.Header.Get("Accept-Encoding")
@@ -490,6 +580,7 @@ func TestRoundTripGzip(t *testing.T) {
}
func TestTransportGzip(t *testing.T) {
+ defer checkLeakedTransports(t)
const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
const nRandBytes = 1024 * 1024
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
@@ -582,6 +673,7 @@ func TestTransportGzip(t *testing.T) {
}
func TestTransportProxy(t *testing.T) {
+ defer checkLeakedTransports(t)
ch := make(chan string, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- "real server"
@@ -610,6 +702,7 @@ func TestTransportProxy(t *testing.T) {
// but checks that we don't recurse forever, and checks that
// Content-Encoding is removed.
func TestTransportGzipRecursive(t *testing.T) {
+ defer checkLeakedTransports(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Encoding", "gzip")
w.Write(rgz)
@@ -636,6 +729,7 @@ func TestTransportGzipRecursive(t *testing.T) {
// tests that persistent goroutine connections shut down when no longer desired.
func TestTransportPersistConnLeak(t *testing.T) {
+ defer checkLeakedTransports(t)
gotReqCh := make(chan bool)
unblockCh := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -698,8 +792,49 @@ func TestTransportPersistConnLeak(t *testing.T) {
}
}
+// golang.org/issue/4531: Transport leaks goroutines when
+// request.ContentLength is explicitly short
+func TestTransportPersistConnLeakShortBody(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ n0 := runtime.NumGoroutine()
+ body := []byte("Hello")
+ for i := 0; i < 20; i++ {
+ req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = int64(len(body) - 2) // explicitly short
+ _, err = c.Do(req)
+ if err == nil {
+ t.Fatal("Expect an error from writing too long of a body.")
+ }
+ }
+ nhigh := runtime.NumGoroutine()
+ tr.CloseIdleConnections()
+ time.Sleep(50 * time.Millisecond)
+ runtime.GC()
+ nfinal := runtime.NumGoroutine()
+
+ growth := nfinal - n0
+
+ // We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
+ // Previously we were leaking one per numReq.
+ t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
+ if int(growth) > 5 {
+ t.Error("too many new goroutines")
+ }
+}
+
// This used to crash; http://golang.org/issue/3266
func TestTransportIdleConnCrash(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &Transport{}
c := &Client{Transport: tr}
@@ -724,6 +859,361 @@ func TestTransportIdleConnCrash(t *testing.T) {
<-didreq
}
+// Test that the transport doesn't close the TCP connection early,
+// before the response body has been read. This was a regression
+// which sadly lacked a triggering test. The large response body made
+// the old race easier to trigger.
+func TestIssue3644(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const numFoos = 5000
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ for i := 0; i < numFoos; i++ {
+ w.Write([]byte("foo "))
+ }
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ bs, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != numFoos*len("foo ") {
+ t.Errorf("unexpected response length")
+ }
+}
+
+// Test that a client receives a server's reply, even if the server doesn't read
+// the entire request body.
+func TestIssue3595(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const deniedMsg = "sorry, denied."
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, deniedMsg, StatusUnauthorized)
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
+ if err != nil {
+ t.Errorf("Post: %v", err)
+ return
+ }
+ got, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("Body ReadAll: %v", err)
+ }
+ if !strings.Contains(string(got), deniedMsg) {
+ t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
+ }
+}
+
+// From http://golang.org/issue/4454 ,
+// "client fails to handle requests with no body and chunked encoding"
+func TestChunkedNoContent(t *testing.T) {
+ defer checkLeakedTransports(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNoContent)
+ }))
+ defer ts.Close()
+
+ for _, closeBody := range []bool{true, false} {
+ c := &Client{Transport: &Transport{}}
+ const n = 4
+ for i := 1; i <= n; i++ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
+ } else {
+ if closeBody {
+ res.Body.Close()
+ }
+ }
+ }
+ }
+}
+
+func TestTransportConcurrency(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const maxProcs = 16
+ const numReqs = 500
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%v", r.FormValue("echo"))
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ reqs := make(chan string)
+ defer close(reqs)
+
+ var wg sync.WaitGroup
+ wg.Add(numReqs)
+ for i := 0; i < maxProcs*2; i++ {
+ go func() {
+ for req := range reqs {
+ res, err := c.Get(ts.URL + "/?echo=" + req)
+ if err != nil {
+ t.Errorf("error on req %s: %v", req, err)
+ wg.Done()
+ continue
+ }
+ all, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Errorf("read error on req %s: %v", req, err)
+ wg.Done()
+ continue
+ }
+ if string(all) != req {
+ t.Errorf("body of req %s = %q; want %q", req, all, req)
+ }
+ wg.Done()
+ res.Body.Close()
+ }
+ }()
+ }
+ for i := 0; i < numReqs; i++ {
+ reqs <- fmt.Sprintf("request-%d", i)
+ }
+ wg.Wait()
+}
+
+func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ ts := httptest.NewServer(mux)
+ timeout := 100 * time.Millisecond
+
+ client := &Client{
+ Transport: &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetDeadline(time.Now().Add(timeout))
+ if debug {
+ conn = NewLoggingConn("client", conn)
+ }
+ return conn, nil
+ },
+ DisableKeepAlives: true,
+ },
+ }
+
+ getFailed := false
+ nRuns := 5
+ if testing.Short() {
+ nRuns = 1
+ }
+ for i := 0; i < nRuns; i++ {
+ if debug {
+ println("run", i+1, "of", nRuns)
+ }
+ sres, err := client.Get(ts.URL + "/get")
+ if err != nil {
+ if !getFailed {
+ // Make the timeout longer, once.
+ getFailed = true
+ t.Logf("increasing timeout")
+ i--
+ timeout *= 10
+ continue
+ }
+ t.Errorf("Error issuing GET: %v", err)
+ break
+ }
+ _, err = io.Copy(ioutil.Discard, sres.Body)
+ if err == nil {
+ t.Errorf("Unexpected successful copy")
+ break
+ }
+ }
+ if debug {
+ println("tests complete; waiting for handlers to finish")
+ }
+ ts.Close()
+}
+
+func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
+ defer r.Body.Close()
+ io.Copy(ioutil.Discard, r.Body)
+ })
+ ts := httptest.NewServer(mux)
+ timeout := 100 * time.Millisecond
+
+ client := &Client{
+ Transport: &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetDeadline(time.Now().Add(timeout))
+ if debug {
+ conn = NewLoggingConn("client", conn)
+ }
+ return conn, nil
+ },
+ DisableKeepAlives: true,
+ },
+ }
+
+ getFailed := false
+ nRuns := 5
+ if testing.Short() {
+ nRuns = 1
+ }
+ for i := 0; i < nRuns; i++ {
+ if debug {
+ println("run", i+1, "of", nRuns)
+ }
+ sres, err := client.Get(ts.URL + "/get")
+ if err != nil {
+ if !getFailed {
+ // Make the timeout longer, once.
+ getFailed = true
+ t.Logf("increasing timeout")
+ i--
+ timeout *= 10
+ continue
+ }
+ t.Errorf("Error issuing GET: %v", err)
+ break
+ }
+ req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
+ _, err = client.Do(req)
+ if err == nil {
+ sres.Body.Close()
+ t.Errorf("Unexpected successful PUT")
+ break
+ }
+ sres.Body.Close()
+ }
+ if debug {
+ println("tests complete; waiting for handlers to finish")
+ }
+ ts.Close()
+}
+
+func TestTransportResponseHeaderTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ if testing.Short() {
+ t.Skip("skipping timeout test in -short mode")
+ }
+ mux := NewServeMux()
+ mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {})
+ mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
+ time.Sleep(2 * time.Second)
+ })
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ tr := &Transport{
+ ResponseHeaderTimeout: 500 * time.Millisecond,
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ tests := []struct {
+ path string
+ want int
+ wantErr string
+ }{
+ {path: "/fast", want: 200},
+ {path: "/slow", wantErr: "timeout awaiting response headers"},
+ {path: "/fast", want: 200},
+ }
+ for i, tt := range tests {
+ res, err := c.Get(ts.URL + tt.path)
+ if err != nil {
+ if strings.Contains(err.Error(), tt.wantErr) {
+ continue
+ }
+ t.Errorf("%d. unexpected error: %v", i, err)
+ continue
+ }
+ if tt.wantErr != "" {
+ t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
+ }
+ }
+}
+
+func TestTransportCancelRequest(t *testing.T) {
+ defer checkLeakedTransports(t)
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ unblockc := make(chan bool)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello")
+ w.(Flusher).Flush() // send headers and some body
+ <-unblockc
+ }))
+ defer ts.Close()
+ defer close(unblockc)
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ go func() {
+ time.Sleep(1 * time.Second)
+ tr.CancelRequest(req)
+ }()
+ t0 := time.Now()
+ body, err := ioutil.ReadAll(res.Body)
+ d := time.Since(t0)
+
+ if err == nil {
+ t.Error("expected an error reading the body")
+ }
+ if string(body) != "Hello" {
+ t.Errorf("Body = %q; want Hello", body)
+ }
+ if d < 500*time.Millisecond {
+ t.Errorf("expected ~1 second delay; got %v", d)
+ }
+ // Verify no outstanding requests after readLoop/writeLoop
+ // goroutines shut down.
+ for tries := 3; tries > 0; tries-- {
+ n := tr.NumPendingRequestsForTesting()
+ if n == 0 {
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ if tries == 1 {
+ t.Errorf("pending requests = %d; want 0", n)
+ }
+ }
+}
+
type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) {
@@ -737,6 +1227,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) {
}
func TestTransportAltProto(t *testing.T) {
+ defer checkLeakedTransports(t)
tr := &Transport{}
c := &Client{Transport: tr}
tr.RegisterProtocol("foo", fooProto{})
@@ -754,15 +1245,58 @@ func TestTransportAltProto(t *testing.T) {
}
}
-var proxyFromEnvTests = []struct {
+func TestTransportNoHost(t *testing.T) {
+ defer checkLeakedTransports(t)
+ tr := &Transport{}
+ _, err := tr.RoundTrip(&Request{
+ Header: make(Header),
+ URL: &url.URL{
+ Scheme: "http",
+ },
+ })
+ want := "http: no Host in request URL"
+ if got := fmt.Sprint(err); got != want {
+ t.Errorf("error = %v; want %q", err, want)
+ }
+}
+
+type proxyFromEnvTest struct {
+ req string // URL to fetch; blank means "http://example.com"
env string
- wanturl string
+ noenv string
+ want string
wanterr error
-}{
- {"127.0.0.1:8080", "http://127.0.0.1:8080", nil},
- {"http://127.0.0.1:8080", "http://127.0.0.1:8080", nil},
- {"https://127.0.0.1:8080", "https://127.0.0.1:8080", nil},
- {"", "<nil>", nil},
+}
+
+func (t proxyFromEnvTest) String() string {
+ var buf bytes.Buffer
+ if t.env != "" {
+ fmt.Fprintf(&buf, "http_proxy=%q", t.env)
+ }
+ if t.noenv != "" {
+ fmt.Fprintf(&buf, " no_proxy=%q", t.noenv)
+ }
+ req := "http://example.com"
+ if t.req != "" {
+ req = t.req
+ }
+ fmt.Fprintf(&buf, " req=%q", req)
+ return strings.TrimSpace(buf.String())
+}
+
+var proxyFromEnvTests = []proxyFromEnvTest{
+ {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
+ {env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
+ {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
+ {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
+ {want: "<nil>"},
+ {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+ {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
func TestProxyFromEnvironment(t *testing.T) {
@@ -770,16 +1304,21 @@ func TestProxyFromEnvironment(t *testing.T) {
os.Setenv("http_proxy", "")
os.Setenv("NO_PROXY", "")
os.Setenv("no_proxy", "")
- for i, tt := range proxyFromEnvTests {
+ for _, tt := range proxyFromEnvTests {
os.Setenv("HTTP_PROXY", tt.env)
- req, _ := NewRequest("GET", "http://example.com", nil)
+ os.Setenv("NO_PROXY", tt.noenv)
+ reqURL := tt.req
+ if reqURL == "" {
+ reqURL = "http://example.com"
+ }
+ req, _ := NewRequest("GET", reqURL, nil)
url, err := ProxyFromEnvironment(req)
if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
- t.Errorf("%d. got error = %q, want %q", i, g, e)
+ t.Errorf("%v: got error = %q, want %q", tt, g, e)
continue
}
- if got := fmt.Sprintf("%s", url); got != tt.wanturl {
- t.Errorf("%d. got URL = %q, want %q", i, url, tt.wanturl)
+ if got := fmt.Sprintf("%s", url); got != tt.want {
+ t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
}
}
}
diff --git a/src/pkg/net/http/z_last_test.go b/src/pkg/net/http/z_last_test.go
new file mode 100644
index 000000000..44095a8d9
--- /dev/null
+++ b/src/pkg/net/http/z_last_test.go
@@ -0,0 +1,60 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "net/http"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+)
+
+// Verify the other tests didn't leave any goroutines running.
+// This is in a file named z_last_test.go so it sorts at the end.
+func TestGoroutinesRunning(t *testing.T) {
+ n := runtime.NumGoroutine()
+ t.Logf("num goroutines = %d", n)
+ if n > 20 {
+ // Currently 14 on Linux (blocked in epoll_wait,
+ // waiting for on fds that are closed?), but give some
+ // slop for now.
+ buf := make([]byte, 1<<20)
+ buf = buf[:runtime.Stack(buf, true)]
+ t.Errorf("Too many goroutines:\n%s", buf)
+ }
+}
+
+func checkLeakedTransports(t *testing.T) {
+ http.DefaultTransport.(*http.Transport).CloseIdleConnections()
+ if testing.Short() {
+ return
+ }
+ buf := make([]byte, 1<<20)
+ var stacks string
+ var bad string
+ badSubstring := map[string]string{
+ ").readLoop(": "a Transport",
+ ").writeLoop(": "a Transport",
+ "created by net/http/httptest.(*Server).Start": "an httptest.Server",
+ "timeoutHandler": "a TimeoutHandler",
+ }
+ for i := 0; i < 4; i++ {
+ bad = ""
+ stacks = string(buf[:runtime.Stack(buf, true)])
+ for substr, what := range badSubstring {
+ if strings.Contains(stacks, substr) {
+ bad = what
+ }
+ }
+ if bad == "" {
+ return
+ }
+ // Bad stuff found, but goroutines might just still be
+ // shutting down, so give it some time.
+ time.Sleep(250 * time.Millisecond)
+ }
+ t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks)
+}
diff --git a/src/pkg/net/interface.go b/src/pkg/net/interface.go
index ee23570a9..0713e9cd6 100644
--- a/src/pkg/net/interface.go
+++ b/src/pkg/net/interface.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Network interface identification
-
package net
import "errors"
@@ -66,7 +64,7 @@ func (ifi *Interface) Addrs() ([]Addr, error) {
if ifi == nil {
return nil, errInvalidInterface
}
- return interfaceAddrTable(ifi.Index)
+ return interfaceAddrTable(ifi)
}
// MulticastAddrs returns multicast, joined group addresses for
@@ -75,7 +73,7 @@ func (ifi *Interface) MulticastAddrs() ([]Addr, error) {
if ifi == nil {
return nil, errInvalidInterface
}
- return interfaceMulticastAddrTable(ifi.Index)
+ return interfaceMulticastAddrTable(ifi)
}
// Interfaces returns a list of the system's network interfaces.
@@ -86,7 +84,7 @@ func Interfaces() ([]Interface, error) {
// InterfaceAddrs returns a list of the system's network interface
// addresses.
func InterfaceAddrs() ([]Addr, error) {
- return interfaceAddrTable(0)
+ return interfaceAddrTable(nil)
}
// InterfaceByIndex returns the interface specified by index.
@@ -98,8 +96,14 @@ func InterfaceByIndex(index int) (*Interface, error) {
if err != nil {
return nil, err
}
+ return interfaceByIndex(ift, index)
+}
+
+func interfaceByIndex(ift []Interface, index int) (*Interface, error) {
for _, ifi := range ift {
- return &ifi, nil
+ if index == ifi.Index {
+ return &ifi, nil
+ }
}
return nil, errNoSuchInterface
}
diff --git a/src/pkg/net/interface_bsd.go b/src/pkg/net/interface_bsd.go
index 7f090d8d4..f58065a85 100644
--- a/src/pkg/net/interface_bsd.go
+++ b/src/pkg/net/interface_bsd.go
@@ -4,8 +4,6 @@
// +build darwin freebsd netbsd openbsd
-// Network interface identification for BSD variants
-
package net
import (
@@ -22,57 +20,60 @@ func interfaceTable(ifindex int) ([]Interface, error) {
if err != nil {
return nil, os.NewSyscallError("route rib", err)
}
-
msgs, err := syscall.ParseRoutingMessage(tab)
if err != nil {
return nil, os.NewSyscallError("route message", err)
}
+ return parseInterfaceTable(ifindex, msgs)
+}
+func parseInterfaceTable(ifindex int, msgs []syscall.RoutingMessage) ([]Interface, error) {
var ift []Interface
+loop:
for _, m := range msgs {
- switch v := m.(type) {
+ switch m := m.(type) {
case *syscall.InterfaceMessage:
- if ifindex == 0 || ifindex == int(v.Header.Index) {
- ifi, err := newLink(v)
+ if ifindex == 0 || ifindex == int(m.Header.Index) {
+ ifi, err := newLink(m)
if err != nil {
return nil, err
}
- ift = append(ift, ifi...)
+ ift = append(ift, *ifi)
+ if ifindex == int(m.Header.Index) {
+ break loop
+ }
}
}
}
return ift, nil
}
-func newLink(m *syscall.InterfaceMessage) ([]Interface, error) {
+func newLink(m *syscall.InterfaceMessage) (*Interface, error) {
sas, err := syscall.ParseRoutingSockaddr(m)
if err != nil {
return nil, os.NewSyscallError("route sockaddr", err)
}
-
- var ift []Interface
- for _, s := range sas {
- switch v := s.(type) {
+ ifi := &Interface{Index: int(m.Header.Index), Flags: linkFlags(m.Header.Flags)}
+ for _, sa := range sas {
+ switch sa := sa.(type) {
case *syscall.SockaddrDatalink:
// NOTE: SockaddrDatalink.Data is minimum work area,
// can be larger.
- m.Data = m.Data[unsafe.Offsetof(v.Data):]
- ifi := Interface{Index: int(m.Header.Index), Flags: linkFlags(m.Header.Flags)}
+ m.Data = m.Data[unsafe.Offsetof(sa.Data):]
var name [syscall.IFNAMSIZ]byte
- for i := 0; i < int(v.Nlen); i++ {
+ for i := 0; i < int(sa.Nlen); i++ {
name[i] = byte(m.Data[i])
}
- ifi.Name = string(name[:v.Nlen])
+ ifi.Name = string(name[:sa.Nlen])
ifi.MTU = int(m.Header.Data.Mtu)
- addr := make([]byte, v.Alen)
- for i := 0; i < int(v.Alen); i++ {
- addr[i] = byte(m.Data[int(v.Nlen)+i])
+ addr := make([]byte, sa.Alen)
+ for i := 0; i < int(sa.Alen); i++ {
+ addr[i] = byte(m.Data[int(sa.Nlen)+i])
}
- ifi.HardwareAddr = addr[:v.Alen]
- ift = append(ift, ifi)
+ ifi.HardwareAddr = addr[:sa.Alen]
}
}
- return ift, nil
+ return ifi, nil
}
func linkFlags(rawFlags int32) Flags {
@@ -95,68 +96,87 @@ func linkFlags(rawFlags int32) Flags {
return f
}
-// If the ifindex is zero, interfaceAddrTable returns addresses
-// for all network interfaces. Otherwise it returns addresses
-// for a specific interface.
-func interfaceAddrTable(ifindex int) ([]Addr, error) {
- tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex)
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
+ index := 0
+ if ifi != nil {
+ index = ifi.Index
+ }
+ tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST, index)
if err != nil {
return nil, os.NewSyscallError("route rib", err)
}
-
msgs, err := syscall.ParseRoutingMessage(tab)
if err != nil {
return nil, os.NewSyscallError("route message", err)
}
-
+ var ift []Interface
+ if index == 0 {
+ ift, err = parseInterfaceTable(index, msgs)
+ if err != nil {
+ return nil, err
+ }
+ }
var ifat []Addr
for _, m := range msgs {
- switch v := m.(type) {
+ switch m := m.(type) {
case *syscall.InterfaceAddrMessage:
- if ifindex == 0 || ifindex == int(v.Header.Index) {
- ifa, err := newAddr(v)
+ if index == 0 || index == int(m.Header.Index) {
+ if index == 0 {
+ var err error
+ ifi, err = interfaceByIndex(ift, int(m.Header.Index))
+ if err != nil {
+ return nil, err
+ }
+ }
+ ifa, err := newAddr(ifi, m)
if err != nil {
return nil, err
}
- ifat = append(ifat, ifa)
+ if ifa != nil {
+ ifat = append(ifat, ifa)
+ }
}
}
}
return ifat, nil
}
-func newAddr(m *syscall.InterfaceAddrMessage) (Addr, error) {
+func newAddr(ifi *Interface, m *syscall.InterfaceAddrMessage) (Addr, error) {
sas, err := syscall.ParseRoutingSockaddr(m)
if err != nil {
return nil, os.NewSyscallError("route sockaddr", err)
}
-
ifa := &IPNet{}
- for i, s := range sas {
- switch v := s.(type) {
+ for i, sa := range sas {
+ switch sa := sa.(type) {
case *syscall.SockaddrInet4:
switch i {
case 0:
- ifa.Mask = IPv4Mask(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])
+ ifa.Mask = IPv4Mask(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])
case 1:
- ifa.IP = IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])
+ ifa.IP = IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])
}
case *syscall.SockaddrInet6:
switch i {
case 0:
ifa.Mask = make(IPMask, IPv6len)
- copy(ifa.Mask, v.Addr[:])
+ copy(ifa.Mask, sa.Addr[:])
case 1:
ifa.IP = make(IP, IPv6len)
- copy(ifa.IP, v.Addr[:])
+ copy(ifa.IP, sa.Addr[:])
// NOTE: KAME based IPv6 protcol stack usually embeds
// the interface index in the interface-local or link-
// local address as the kernel-internal form.
if ifa.IP.IsLinkLocalUnicast() {
- // remove embedded scope zone ID
+ ifa.Zone = ifi.Name
ifa.IP[2], ifa.IP[3] = 0, 0
}
}
+ default: // Sockaddrs contain syscall.SockaddrDatalink on NetBSD
+ return nil, nil
}
}
return ifa, nil
diff --git a/src/pkg/net/interface_bsd_test.go b/src/pkg/net/interface_bsd_test.go
new file mode 100644
index 000000000..aa1141903
--- /dev/null
+++ b/src/pkg/net/interface_bsd_test.go
@@ -0,0 +1,52 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd netbsd openbsd
+
+package net
+
+import (
+ "fmt"
+ "os/exec"
+)
+
+func (ti *testInterface) setBroadcast(suffix int) error {
+ ti.name = fmt.Sprintf("vlan%d", suffix)
+ xname, err := exec.LookPath("ifconfig")
+ if err != nil {
+ return err
+ }
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "create"},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "destroy"},
+ })
+ return nil
+}
+
+func (ti *testInterface) setPointToPoint(suffix int, local, remote string) error {
+ ti.name = fmt.Sprintf("gif%d", suffix)
+ ti.local = local
+ ti.remote = remote
+ xname, err := exec.LookPath("ifconfig")
+ if err != nil {
+ return err
+ }
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "create"},
+ })
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "inet", ti.local, ti.remote},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "destroy"},
+ })
+ return nil
+}
diff --git a/src/pkg/net/interface_darwin.go b/src/pkg/net/interface_darwin.go
index 0b5fb5fb9..83e483ba2 100644
--- a/src/pkg/net/interface_darwin.go
+++ b/src/pkg/net/interface_darwin.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Network interface identification for Darwin
-
package net
import (
@@ -11,26 +9,23 @@ import (
"syscall"
)
-// If the ifindex is zero, interfaceMulticastAddrTable returns
-// addresses for all network interfaces. Otherwise it returns
-// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
- tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST2, ifindex)
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST2, ifi.Index)
if err != nil {
return nil, os.NewSyscallError("route rib", err)
}
-
msgs, err := syscall.ParseRoutingMessage(tab)
if err != nil {
return nil, os.NewSyscallError("route message", err)
}
-
var ifmat []Addr
for _, m := range msgs {
- switch v := m.(type) {
+ switch m := m.(type) {
case *syscall.InterfaceMulticastAddrMessage:
- if ifindex == 0 || ifindex == int(v.Header.Index) {
- ifma, err := newMulticastAddr(v)
+ if ifi.Index == int(m.Header.Index) {
+ ifma, err := newMulticastAddr(ifi, m)
if err != nil {
return nil, err
}
@@ -41,27 +36,25 @@ func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
return ifmat, nil
}
-func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) {
+func newMulticastAddr(ifi *Interface, m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) {
sas, err := syscall.ParseRoutingSockaddr(m)
if err != nil {
return nil, os.NewSyscallError("route sockaddr", err)
}
-
var ifmat []Addr
- for _, s := range sas {
- switch v := s.(type) {
+ for _, sa := range sas {
+ switch sa := sa.(type) {
case *syscall.SockaddrInet4:
- ifma := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])}
+ ifma := &IPAddr{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])}
ifmat = append(ifmat, ifma.toAddr())
case *syscall.SockaddrInet6:
ifma := &IPAddr{IP: make(IP, IPv6len)}
- copy(ifma.IP, v.Addr[:])
+ copy(ifma.IP, sa.Addr[:])
// NOTE: KAME based IPv6 protcol stack usually embeds
// the interface index in the interface-local or link-
// local address as the kernel-internal form.
- if ifma.IP.IsInterfaceLocalMulticast() ||
- ifma.IP.IsLinkLocalMulticast() {
- // remove embedded scope zone ID
+ if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() {
+ ifma.Zone = ifi.Name
ifma.IP[2], ifma.IP[3] = 0, 0
}
ifmat = append(ifmat, ifma.toAddr())
diff --git a/src/pkg/net/interface_freebsd.go b/src/pkg/net/interface_freebsd.go
index 3cba28fc6..1bf5ae72b 100644
--- a/src/pkg/net/interface_freebsd.go
+++ b/src/pkg/net/interface_freebsd.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Network interface identification for FreeBSD
-
package net
import (
@@ -11,26 +9,23 @@ import (
"syscall"
)
-// If the ifindex is zero, interfaceMulticastAddrTable returns
-// addresses for all network interfaces. Otherwise it returns
-// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
- tab, err := syscall.RouteRIB(syscall.NET_RT_IFMALIST, ifindex)
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ tab, err := syscall.RouteRIB(syscall.NET_RT_IFMALIST, ifi.Index)
if err != nil {
return nil, os.NewSyscallError("route rib", err)
}
-
msgs, err := syscall.ParseRoutingMessage(tab)
if err != nil {
return nil, os.NewSyscallError("route message", err)
}
-
var ifmat []Addr
for _, m := range msgs {
- switch v := m.(type) {
+ switch m := m.(type) {
case *syscall.InterfaceMulticastAddrMessage:
- if ifindex == 0 || ifindex == int(v.Header.Index) {
- ifma, err := newMulticastAddr(v)
+ if ifi.Index == int(m.Header.Index) {
+ ifma, err := newMulticastAddr(ifi, m)
if err != nil {
return nil, err
}
@@ -41,27 +36,25 @@ func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
return ifmat, nil
}
-func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) {
+func newMulticastAddr(ifi *Interface, m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) {
sas, err := syscall.ParseRoutingSockaddr(m)
if err != nil {
return nil, os.NewSyscallError("route sockaddr", err)
}
-
var ifmat []Addr
- for _, s := range sas {
- switch v := s.(type) {
+ for _, sa := range sas {
+ switch sa := sa.(type) {
case *syscall.SockaddrInet4:
- ifma := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])}
+ ifma := &IPAddr{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])}
ifmat = append(ifmat, ifma.toAddr())
case *syscall.SockaddrInet6:
ifma := &IPAddr{IP: make(IP, IPv6len)}
- copy(ifma.IP, v.Addr[:])
+ copy(ifma.IP, sa.Addr[:])
// NOTE: KAME based IPv6 protcol stack usually embeds
// the interface index in the interface-local or link-
// local address as the kernel-internal form.
- if ifma.IP.IsInterfaceLocalMulticast() ||
- ifma.IP.IsLinkLocalMulticast() {
- // remove embedded scope zone ID
+ if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() {
+ ifma.Zone = ifi.Name
ifma.IP[2], ifma.IP[3] = 0, 0
}
ifmat = append(ifmat, ifma.toAddr())
diff --git a/src/pkg/net/interface_linux.go b/src/pkg/net/interface_linux.go
index 825b20227..e66daef06 100644
--- a/src/pkg/net/interface_linux.go
+++ b/src/pkg/net/interface_linux.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Network interface identification for Linux
-
package net
import (
@@ -20,17 +18,16 @@ func interfaceTable(ifindex int) ([]Interface, error) {
if err != nil {
return nil, os.NewSyscallError("netlink rib", err)
}
-
msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil {
return nil, os.NewSyscallError("netlink message", err)
}
-
var ift []Interface
+loop:
for _, m := range msgs {
switch m.Header.Type {
case syscall.NLMSG_DONE:
- goto done
+ break loop
case syscall.RTM_NEWLINK:
ifim := (*syscall.IfInfomsg)(unsafe.Pointer(&m.Data[0]))
if ifindex == 0 || ifindex == int(ifim.Index) {
@@ -38,17 +35,18 @@ func interfaceTable(ifindex int) ([]Interface, error) {
if err != nil {
return nil, os.NewSyscallError("netlink routeattr", err)
}
- ifi := newLink(ifim, attrs)
- ift = append(ift, ifi)
+ ift = append(ift, *newLink(ifim, attrs))
+ if ifindex == int(ifim.Index) {
+ break loop
+ }
}
}
}
-done:
return ift, nil
}
-func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) Interface {
- ifi := Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)}
+func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) *Interface {
+ ifi := &Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)}
for _, a := range attrs {
switch a.Attr.Type {
case syscall.IFLA_ADDRESS:
@@ -64,7 +62,7 @@ func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) Interfac
case syscall.IFLA_IFNAME:
ifi.Name = string(a.Value[:len(a.Value)-1])
case syscall.IFLA_MTU:
- ifi.MTU = int(uint32(a.Value[3])<<24 | uint32(a.Value[2])<<16 | uint32(a.Value[1])<<8 | uint32(a.Value[0]))
+ ifi.MTU = int(*(*uint32)(unsafe.Pointer(&a.Value[:4][0])))
}
}
return ifi
@@ -90,81 +88,87 @@ func linkFlags(rawFlags uint32) Flags {
return f
}
-// If the ifindex is zero, interfaceAddrTable returns addresses
-// for all network interfaces. Otherwise it returns addresses
-// for a specific interface.
-func interfaceAddrTable(ifindex int) ([]Addr, error) {
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
tab, err := syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC)
if err != nil {
return nil, os.NewSyscallError("netlink rib", err)
}
-
msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil {
return nil, os.NewSyscallError("netlink message", err)
}
-
- ifat, err := addrTable(msgs, ifindex)
+ var ift []Interface
+ if ifi == nil {
+ var err error
+ ift, err = interfaceTable(0)
+ if err != nil {
+ return nil, err
+ }
+ }
+ ifat, err := addrTable(ift, ifi, msgs)
if err != nil {
return nil, err
}
return ifat, nil
}
-func addrTable(msgs []syscall.NetlinkMessage, ifindex int) ([]Addr, error) {
+func addrTable(ift []Interface, ifi *Interface, msgs []syscall.NetlinkMessage) ([]Addr, error) {
var ifat []Addr
+loop:
for _, m := range msgs {
switch m.Header.Type {
case syscall.NLMSG_DONE:
- goto done
+ break loop
case syscall.RTM_NEWADDR:
ifam := (*syscall.IfAddrmsg)(unsafe.Pointer(&m.Data[0]))
- if ifindex == 0 || ifindex == int(ifam.Index) {
+ if len(ift) != 0 || ifi.Index == int(ifam.Index) {
+ if len(ift) != 0 {
+ var err error
+ ifi, err = interfaceByIndex(ift, int(ifam.Index))
+ if err != nil {
+ return nil, err
+ }
+ }
attrs, err := syscall.ParseNetlinkRouteAttr(&m)
if err != nil {
return nil, os.NewSyscallError("netlink routeattr", err)
}
- ifat = append(ifat, newAddr(attrs, int(ifam.Family), int(ifam.Prefixlen)))
+ ifa := newAddr(ifi, ifam, attrs)
+ if ifa != nil {
+ ifat = append(ifat, ifa)
+ }
}
}
}
-done:
return ifat, nil
}
-func newAddr(attrs []syscall.NetlinkRouteAttr, family, pfxlen int) Addr {
- ifa := &IPNet{}
+func newAddr(ifi *Interface, ifam *syscall.IfAddrmsg, attrs []syscall.NetlinkRouteAttr) Addr {
for _, a := range attrs {
- switch a.Attr.Type {
- case syscall.IFA_ADDRESS:
- switch family {
+ if ifi.Flags&FlagPointToPoint != 0 && a.Attr.Type == syscall.IFA_LOCAL ||
+ ifi.Flags&FlagPointToPoint == 0 && a.Attr.Type == syscall.IFA_ADDRESS {
+ switch ifam.Family {
case syscall.AF_INET:
- ifa.IP = IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3])
- ifa.Mask = CIDRMask(pfxlen, 8*IPv4len)
+ return &IPNet{IP: IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv4len)}
case syscall.AF_INET6:
- ifa.IP = make(IP, IPv6len)
+ ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv6len)}
copy(ifa.IP, a.Value[:])
- ifa.Mask = CIDRMask(pfxlen, 8*IPv6len)
+ if ifam.Scope == syscall.RT_SCOPE_HOST || ifam.Scope == syscall.RT_SCOPE_LINK {
+ ifa.Zone = ifi.Name
+ }
+ return ifa
}
}
}
- return ifa
+ return nil
}
-// If the ifindex is zero, interfaceMulticastAddrTable returns
-// addresses for all network interfaces. Otherwise it returns
-// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
- var (
- err error
- ifi *Interface
- )
- if ifindex > 0 {
- ifi, err = InterfaceByIndex(ifindex)
- if err != nil {
- return nil, err
- }
- }
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
ifmat4 := parseProcNetIGMP("/proc/net/igmp", ifi)
ifmat6 := parseProcNetIGMP6("/proc/net/igmp6", ifi)
return append(ifmat4, ifmat6...), nil
@@ -176,7 +180,6 @@ func parseProcNetIGMP(path string, ifi *Interface) []Addr {
return nil
}
defer fd.close()
-
var (
ifmat []Addr
name string
@@ -193,10 +196,14 @@ func parseProcNetIGMP(path string, ifi *Interface) []Addr {
name = f[1]
case len(f[0]) == 8:
if ifi == nil || name == ifi.Name {
+ // The Linux kernel puts the IP
+ // address in /proc/net/igmp in native
+ // endianness.
for i := 0; i+1 < len(f[0]); i += 2 {
b[i/2], _ = xtoi2(f[0][i:i+2], 0)
}
- ifma := IPAddr{IP: IPv4(b[3], b[2], b[1], b[0])}
+ i := *(*uint32)(unsafe.Pointer(&b[:4][0]))
+ ifma := IPAddr{IP: IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i))}
ifmat = append(ifmat, ifma.toAddr())
}
}
@@ -210,7 +217,6 @@ func parseProcNetIGMP6(path string, ifi *Interface) []Addr {
return nil
}
defer fd.close()
-
var ifmat []Addr
b := make([]byte, IPv6len)
for l, ok := fd.readLine(); ok; l, ok = fd.readLine() {
@@ -223,6 +229,9 @@ func parseProcNetIGMP6(path string, ifi *Interface) []Addr {
b[i/2], _ = xtoi2(f[2][i:i+2], 0)
}
ifma := IPAddr{IP: IP{b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]}}
+ if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() {
+ ifma.Zone = ifi.Name
+ }
ifmat = append(ifmat, ifma.toAddr())
}
}
diff --git a/src/pkg/net/interface_linux_test.go b/src/pkg/net/interface_linux_test.go
index f14d1fe06..085d3de9d 100644
--- a/src/pkg/net/interface_linux_test.go
+++ b/src/pkg/net/interface_linux_test.go
@@ -4,7 +4,55 @@
package net
-import "testing"
+import (
+ "fmt"
+ "os/exec"
+ "testing"
+)
+
+func (ti *testInterface) setBroadcast(suffix int) error {
+ ti.name = fmt.Sprintf("gotest%d", suffix)
+ xname, err := exec.LookPath("ip")
+ if err != nil {
+ return err
+ }
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "link", "add", ti.name, "type", "dummy"},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "link", "delete", ti.name, "type", "dummy"},
+ })
+ return nil
+}
+
+func (ti *testInterface) setPointToPoint(suffix int, local, remote string) error {
+ ti.name = fmt.Sprintf("gotest%d", suffix)
+ ti.local = local
+ ti.remote = remote
+ xname, err := exec.LookPath("ip")
+ if err != nil {
+ return err
+ }
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "tunnel", "add", ti.name, "mode", "gre", "local", local, "remote", remote},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "tunnel", "del", ti.name, "mode", "gre", "local", local, "remote", remote},
+ })
+ xname, err = exec.LookPath("ifconfig")
+ if err != nil {
+ return err
+ }
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "inet", local, "dstaddr", remote},
+ })
+ return nil
+}
const (
numOfTestIPv4MCAddrs = 14
diff --git a/src/pkg/net/interface_netbsd.go b/src/pkg/net/interface_netbsd.go
index 4150e9ad5..c9ce5a7ac 100644
--- a/src/pkg/net/interface_netbsd.go
+++ b/src/pkg/net/interface_netbsd.go
@@ -2,13 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Network interface identification for NetBSD
-
package net
-// If the ifindex is zero, interfaceMulticastAddrTable returns
-// addresses for all network interfaces. Otherwise it returns
-// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ // TODO(mikio): Implement this like other platforms.
return nil, nil
}
diff --git a/src/pkg/net/interface_openbsd.go b/src/pkg/net/interface_openbsd.go
index d8adb4676..c9ce5a7ac 100644
--- a/src/pkg/net/interface_openbsd.go
+++ b/src/pkg/net/interface_openbsd.go
@@ -2,13 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Network interface identification for OpenBSD
-
package net
-// If the ifindex is zero, interfaceMulticastAddrTable returns
-// addresses for all network interfaces. Otherwise it returns
-// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ // TODO(mikio): Implement this like other platforms.
return nil, nil
}
diff --git a/src/pkg/net/interface_stub.go b/src/pkg/net/interface_stub.go
index d4d7ce9c7..a4eb731da 100644
--- a/src/pkg/net/interface_stub.go
+++ b/src/pkg/net/interface_stub.go
@@ -4,8 +4,6 @@
// +build plan9
-// Network interface identification
-
package net
// If the ifindex is zero, interfaceTable returns mappings of all
@@ -15,16 +13,15 @@ func interfaceTable(ifindex int) ([]Interface, error) {
return nil, nil
}
-// If the ifindex is zero, interfaceAddrTable returns addresses
-// for all network interfaces. Otherwise it returns addresses
-// for a specific interface.
-func interfaceAddrTable(ifindex int) ([]Addr, error) {
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
return nil, nil
}
-// If the ifindex is zero, interfaceMulticastAddrTable returns
-// addresses for all network interfaces. Otherwise it returns
-// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
return nil, nil
}
diff --git a/src/pkg/net/interface_test.go b/src/pkg/net/interface_test.go
index 0a33bfdb5..7fb342818 100644
--- a/src/pkg/net/interface_test.go
+++ b/src/pkg/net/interface_test.go
@@ -5,18 +5,24 @@
package net
import (
- "bytes"
+ "reflect"
"testing"
)
-func sameInterface(i, j *Interface) bool {
- if i == nil || j == nil {
- return false
+// loopbackInterface returns an available logical network interface
+// for loopback tests. It returns nil if no suitable interface is
+// found.
+func loopbackInterface() *Interface {
+ ift, err := Interfaces()
+ if err != nil {
+ return nil
}
- if i.Index == j.Index && i.Name == j.Name && bytes.Equal(i.HardwareAddr, j.HardwareAddr) {
- return true
+ for _, ifi := range ift {
+ if ifi.Flags&FlagLoopback != 0 && ifi.Flags&FlagUp != 0 {
+ return &ifi
+ }
}
- return false
+ return nil
}
func TestInterfaces(t *testing.T) {
@@ -24,24 +30,24 @@ func TestInterfaces(t *testing.T) {
if err != nil {
t.Fatalf("Interfaces failed: %v", err)
}
- t.Logf("table: len/cap = %v/%v\n", len(ift), cap(ift))
+ t.Logf("table: len/cap = %v/%v", len(ift), cap(ift))
for _, ifi := range ift {
ifxi, err := InterfaceByIndex(ifi.Index)
if err != nil {
- t.Fatalf("InterfaceByIndex(%q) failed: %v", ifi.Index, err)
+ t.Fatalf("InterfaceByIndex(%v) failed: %v", ifi.Index, err)
}
- if !sameInterface(ifxi, &ifi) {
- t.Fatalf("InterfaceByIndex(%q) = %v, want %v", ifi.Index, *ifxi, ifi)
+ if !reflect.DeepEqual(ifxi, &ifi) {
+ t.Fatalf("InterfaceByIndex(%v) = %v, want %v", ifi.Index, ifxi, ifi)
}
ifxn, err := InterfaceByName(ifi.Name)
if err != nil {
t.Fatalf("InterfaceByName(%q) failed: %v", ifi.Name, err)
}
- if !sameInterface(ifxn, &ifi) {
- t.Fatalf("InterfaceByName(%q) = %v, want %v", ifi.Name, *ifxn, ifi)
+ if !reflect.DeepEqual(ifxn, &ifi) {
+ t.Fatalf("InterfaceByName(%q) = %v, want %v", ifi.Name, ifxn, ifi)
}
- t.Logf("%q: flags %q, ifindex %v, mtu %v\n", ifi.Name, ifi.Flags.String(), ifi.Index, ifi.MTU)
+ t.Logf("%q: flags %q, ifindex %v, mtu %v", ifi.Name, ifi.Flags.String(), ifi.Index, ifi.MTU)
t.Logf("\thardware address %q", ifi.HardwareAddr.String())
testInterfaceAddrs(t, &ifi)
testInterfaceMulticastAddrs(t, &ifi)
@@ -53,7 +59,7 @@ func TestInterfaceAddrs(t *testing.T) {
if err != nil {
t.Fatalf("InterfaceAddrs failed: %v", err)
}
- t.Logf("table: len/cap = %v/%v\n", len(ifat), cap(ifat))
+ t.Logf("table: len/cap = %v/%v", len(ifat), cap(ifat))
testAddrs(t, ifat)
}
@@ -75,9 +81,13 @@ func testInterfaceMulticastAddrs(t *testing.T, ifi *Interface) {
func testAddrs(t *testing.T, ifat []Addr) {
for _, ifa := range ifat {
- switch ifa.(type) {
+ switch v := ifa.(type) {
case *IPAddr, *IPNet:
- t.Logf("\tinterface address %q\n", ifa.String())
+ if v == nil {
+ t.Errorf("\tunexpected value: %v", ifa)
+ } else {
+ t.Logf("\tinterface address %q", ifa.String())
+ }
default:
t.Errorf("\tunexpected type: %T", ifa)
}
@@ -86,11 +96,79 @@ func testAddrs(t *testing.T, ifat []Addr) {
func testMulticastAddrs(t *testing.T, ifmat []Addr) {
for _, ifma := range ifmat {
- switch ifma.(type) {
+ switch v := ifma.(type) {
case *IPAddr:
- t.Logf("\tjoined group address %q\n", ifma.String())
+ if v == nil {
+ t.Errorf("\tunexpected value: %v", ifma)
+ } else {
+ t.Logf("\tjoined group address %q", ifma.String())
+ }
default:
t.Errorf("\tunexpected type: %T", ifma)
}
}
}
+
+func BenchmarkInterfaces(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ if _, err := Interfaces(); err != nil {
+ b.Fatalf("Interfaces failed: %v", err)
+ }
+ }
+}
+
+func BenchmarkInterfaceByIndex(b *testing.B) {
+ ifi := loopbackInterface()
+ if ifi == nil {
+ b.Skip("loopback interface not found")
+ }
+ for i := 0; i < b.N; i++ {
+ if _, err := InterfaceByIndex(ifi.Index); err != nil {
+ b.Fatalf("InterfaceByIndex failed: %v", err)
+ }
+ }
+}
+
+func BenchmarkInterfaceByName(b *testing.B) {
+ ifi := loopbackInterface()
+ if ifi == nil {
+ b.Skip("loopback interface not found")
+ }
+ for i := 0; i < b.N; i++ {
+ if _, err := InterfaceByName(ifi.Name); err != nil {
+ b.Fatalf("InterfaceByName failed: %v", err)
+ }
+ }
+}
+
+func BenchmarkInterfaceAddrs(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ if _, err := InterfaceAddrs(); err != nil {
+ b.Fatalf("InterfaceAddrs failed: %v", err)
+ }
+ }
+}
+
+func BenchmarkInterfacesAndAddrs(b *testing.B) {
+ ifi := loopbackInterface()
+ if ifi == nil {
+ b.Skip("loopback interface not found")
+ }
+ for i := 0; i < b.N; i++ {
+ if _, err := ifi.Addrs(); err != nil {
+ b.Fatalf("Interface.Addrs failed: %v", err)
+ }
+ }
+}
+
+func BenchmarkInterfacesAndMulticastAddrs(b *testing.B) {
+ ifi := loopbackInterface()
+ if ifi == nil {
+ b.Skip("loopback interface not found")
+ }
+ for i := 0; i < b.N; i++ {
+ if _, err := ifi.MulticastAddrs(); err != nil {
+ b.Fatalf("Interface.MulticastAddrs failed: %v", err)
+ }
+ }
+}
diff --git a/src/pkg/net/interface_unix_test.go b/src/pkg/net/interface_unix_test.go
new file mode 100644
index 000000000..6dbd6e6e7
--- /dev/null
+++ b/src/pkg/net/interface_unix_test.go
@@ -0,0 +1,145 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd
+
+package net
+
+import (
+ "os"
+ "os/exec"
+ "runtime"
+ "testing"
+ "time"
+)
+
+type testInterface struct {
+ name string
+ local string
+ remote string
+ setupCmds []*exec.Cmd
+ teardownCmds []*exec.Cmd
+}
+
+func (ti *testInterface) setup() error {
+ for _, cmd := range ti.setupCmds {
+ if err := cmd.Run(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (ti *testInterface) teardown() error {
+ for _, cmd := range ti.teardownCmds {
+ if err := cmd.Run(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func TestPointToPointInterface(t *testing.T) {
+ switch runtime.GOOS {
+ case "darwin":
+ t.Skipf("skipping read test on %q", runtime.GOOS)
+ }
+ if os.Getuid() != 0 {
+ t.Skip("skipping test; must be root")
+ }
+
+ local, remote := "169.254.0.1", "169.254.0.254"
+ ip := ParseIP(remote)
+ for i := 0; i < 3; i++ {
+ ti := &testInterface{}
+ if err := ti.setPointToPoint(5963+i, local, remote); err != nil {
+ t.Skipf("test requries external command: %v", err)
+ }
+ if err := ti.setup(); err != nil {
+ t.Fatalf("testInterface.setup failed: %v", err)
+ } else {
+ time.Sleep(3 * time.Millisecond)
+ }
+ ift, err := Interfaces()
+ if err != nil {
+ ti.teardown()
+ t.Fatalf("Interfaces failed: %v", err)
+ }
+ for _, ifi := range ift {
+ if ti.name == ifi.Name {
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ ti.teardown()
+ t.Fatalf("Interface.Addrs failed: %v", err)
+ }
+ for _, ifa := range ifat {
+ if ip.Equal(ifa.(*IPNet).IP) {
+ ti.teardown()
+ t.Fatalf("got %v; want %v", ip, local)
+ }
+ }
+ }
+ }
+ if err := ti.teardown(); err != nil {
+ t.Fatalf("testInterface.teardown failed: %v", err)
+ } else {
+ time.Sleep(3 * time.Millisecond)
+ }
+ }
+}
+
+func TestInterfaceArrivalAndDeparture(t *testing.T) {
+ if os.Getuid() != 0 {
+ t.Skip("skipping test; must be root")
+ }
+
+ for i := 0; i < 3; i++ {
+ ift1, err := Interfaces()
+ if err != nil {
+ t.Fatalf("Interfaces failed: %v", err)
+ }
+ ti := &testInterface{}
+ if err := ti.setBroadcast(5682 + i); err != nil {
+ t.Skipf("test requires external command: %v", err)
+ }
+ if err := ti.setup(); err != nil {
+ t.Fatalf("testInterface.setup failed: %v", err)
+ } else {
+ time.Sleep(3 * time.Millisecond)
+ }
+ ift2, err := Interfaces()
+ if err != nil {
+ ti.teardown()
+ t.Fatalf("Interfaces failed: %v", err)
+ }
+ if len(ift2) <= len(ift1) {
+ for _, ifi := range ift1 {
+ t.Logf("before: %v", ifi)
+ }
+ for _, ifi := range ift2 {
+ t.Logf("after: %v", ifi)
+ }
+ ti.teardown()
+ t.Fatalf("got %v; want gt %v", len(ift2), len(ift1))
+ }
+ if err := ti.teardown(); err != nil {
+ t.Fatalf("testInterface.teardown failed: %v", err)
+ } else {
+ time.Sleep(3 * time.Millisecond)
+ }
+ ift3, err := Interfaces()
+ if err != nil {
+ t.Fatalf("Interfaces failed: %v", err)
+ }
+ if len(ift3) >= len(ift2) {
+ for _, ifi := range ift2 {
+ t.Logf("before: %v", ifi)
+ }
+ for _, ifi := range ift3 {
+ t.Logf("after: %v", ifi)
+ }
+ t.Fatalf("got %v; want lt %v", len(ift3), len(ift2))
+ }
+ }
+}
diff --git a/src/pkg/net/interface_windows.go b/src/pkg/net/interface_windows.go
index 4368b3306..0759dc255 100644
--- a/src/pkg/net/interface_windows.go
+++ b/src/pkg/net/interface_windows.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Network interface identification for Windows
-
package net
import (
@@ -25,6 +23,9 @@ func getAdapterList() (*syscall.IpAdapterInfo, error) {
b := make([]byte, 1000)
l := uint32(len(b))
a := (*syscall.IpAdapterInfo)(unsafe.Pointer(&b[0]))
+ // TODO(mikio): GetAdaptersInfo returns IP_ADAPTER_INFO that
+ // contains IPv4 address list only. We should use another API
+ // for fetching IPv6 stuff from the kernel.
err := syscall.GetAdaptersInfo(a, &l)
if err == syscall.ERROR_BUFFER_OVERFLOW {
b = make([]byte, l)
@@ -38,7 +39,7 @@ func getAdapterList() (*syscall.IpAdapterInfo, error) {
}
func getInterfaceList() ([]syscall.InterfaceInfo, error) {
- s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
+ s, err := sysSocket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
if err != nil {
return nil, os.NewSyscallError("Socket", err)
}
@@ -126,10 +127,10 @@ func interfaceTable(ifindex int) ([]Interface, error) {
return ift, nil
}
-// If the ifindex is zero, interfaceAddrTable returns addresses
-// for all network interfaces. Otherwise it returns addresses
-// for a specific interface.
-func interfaceAddrTable(ifindex int) ([]Addr, error) {
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
ai, err := getAdapterList()
if err != nil {
return nil, err
@@ -138,11 +139,10 @@ func interfaceAddrTable(ifindex int) ([]Addr, error) {
var ifat []Addr
for ; ai != nil; ai = ai.Next {
index := ai.Index
- if ifindex == 0 || ifindex == int(index) {
+ if ifi == nil || ifi.Index == int(index) {
ipl := &ai.IpAddressList
for ; ipl != nil; ipl = ipl.Next {
- ifa := IPAddr{}
- ifa.IP = parseIPv4(bytePtrToString(&ipl.IpAddress.String[0]))
+ ifa := IPAddr{IP: parseIPv4(bytePtrToString(&ipl.IpAddress.String[0]))}
ifat = append(ifat, ifa.toAddr())
}
}
@@ -150,9 +150,9 @@ func interfaceAddrTable(ifindex int) ([]Addr, error) {
return ifat, nil
}
-// If the ifindex is zero, interfaceMulticastAddrTable returns
-// addresses for all network interfaces. Otherwise it returns
-// addresses for a specific interface.
-func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) {
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ // TODO(mikio): Implement this like other platforms.
return nil, nil
}
diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go
index 979d7acd5..d588e3a42 100644
--- a/src/pkg/net/ip.go
+++ b/src/pkg/net/ip.go
@@ -7,7 +7,7 @@
// IPv4 addresses are 4 bytes; IPv6 addresses are 16 bytes.
// An IPv4 address can be converted to an IPv6 address by
// adding a canonical prefix (10 zeros, 2 0xFFs).
-// This library accepts either size of byte array but always
+// This library accepts either size of byte slice but always
// returns 16-byte addresses.
package net
@@ -18,14 +18,14 @@ const (
IPv6len = 16
)
-// An IP is a single IP address, an array of bytes.
+// An IP is a single IP address, a slice of bytes.
// Functions in this package accept either 4-byte (IPv4)
-// or 16-byte (IPv6) arrays as input.
+// or 16-byte (IPv6) slices as input.
//
// Note that in this documentation, referring to an
// IP address as an IPv4 address or an IPv6 address
// is a semantic property of the address, not just the
-// length of the byte array: a 16-byte array can still
+// length of the byte slice: a 16-byte slice can still
// be an IPv4 address.
type IP []byte
@@ -36,6 +36,7 @@ type IPMask []byte
type IPNet struct {
IP IP // network number
Mask IPMask // network mask
+ Zone string // IPv6 scoped addressing zone
}
// IPv4 returns the IP address (in 16-byte form) of the
@@ -645,5 +646,5 @@ func ParseCIDR(s string) (IP, *IPNet, error) {
return nil, nil, &ParseError{"CIDR address", s}
}
m := CIDRMask(n, 8*iplen)
- return ip, &IPNet{ip.Mask(m), m}, nil
+ return ip, &IPNet{IP: ip.Mask(m), Mask: m}, nil
}
diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go
index df647ef73..f8b7f067f 100644
--- a/src/pkg/net/ip_test.go
+++ b/src/pkg/net/ip_test.go
@@ -114,23 +114,23 @@ var parsecidrtests = []struct {
net *IPNet
err error
}{
- {"135.104.0.0/32", IPv4(135, 104, 0, 0), &IPNet{IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 255)}, nil},
- {"0.0.0.0/24", IPv4(0, 0, 0, 0), &IPNet{IPv4(0, 0, 0, 0), IPv4Mask(255, 255, 255, 0)}, nil},
- {"135.104.0.0/24", IPv4(135, 104, 0, 0), &IPNet{IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 0)}, nil},
- {"135.104.0.1/32", IPv4(135, 104, 0, 1), &IPNet{IPv4(135, 104, 0, 1), IPv4Mask(255, 255, 255, 255)}, nil},
- {"135.104.0.1/24", IPv4(135, 104, 0, 1), &IPNet{IPv4(135, 104, 0, 0), IPv4Mask(255, 255, 255, 0)}, nil},
- {"::1/128", ParseIP("::1"), &IPNet{ParseIP("::1"), IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"))}, nil},
- {"abcd:2345::/127", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe"))}, nil},
- {"abcd:2345::/65", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff:8000::"))}, nil},
- {"abcd:2345::/64", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:ffff::"))}, nil},
- {"abcd:2345::/63", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:ffff:fffe::"))}, nil},
- {"abcd:2345::/33", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff:8000::"))}, nil},
- {"abcd:2345::/32", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2345::"), IPMask(ParseIP("ffff:ffff::"))}, nil},
- {"abcd:2344::/31", ParseIP("abcd:2344::"), &IPNet{ParseIP("abcd:2344::"), IPMask(ParseIP("ffff:fffe::"))}, nil},
- {"abcd:2300::/24", ParseIP("abcd:2300::"), &IPNet{ParseIP("abcd:2300::"), IPMask(ParseIP("ffff:ff00::"))}, nil},
- {"abcd:2345::/24", ParseIP("abcd:2345::"), &IPNet{ParseIP("abcd:2300::"), IPMask(ParseIP("ffff:ff00::"))}, nil},
- {"2001:DB8::/48", ParseIP("2001:DB8::"), &IPNet{ParseIP("2001:DB8::"), IPMask(ParseIP("ffff:ffff:ffff::"))}, nil},
- {"2001:DB8::1/48", ParseIP("2001:DB8::1"), &IPNet{ParseIP("2001:DB8::"), IPMask(ParseIP("ffff:ffff:ffff::"))}, nil},
+ {"135.104.0.0/32", IPv4(135, 104, 0, 0), &IPNet{IP: IPv4(135, 104, 0, 0), Mask: IPv4Mask(255, 255, 255, 255)}, nil},
+ {"0.0.0.0/24", IPv4(0, 0, 0, 0), &IPNet{IP: IPv4(0, 0, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)}, nil},
+ {"135.104.0.0/24", IPv4(135, 104, 0, 0), &IPNet{IP: IPv4(135, 104, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)}, nil},
+ {"135.104.0.1/32", IPv4(135, 104, 0, 1), &IPNet{IP: IPv4(135, 104, 0, 1), Mask: IPv4Mask(255, 255, 255, 255)}, nil},
+ {"135.104.0.1/24", IPv4(135, 104, 0, 1), &IPNet{IP: IPv4(135, 104, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)}, nil},
+ {"::1/128", ParseIP("::1"), &IPNet{IP: ParseIP("::1"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"))}, nil},
+ {"abcd:2345::/127", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe"))}, nil},
+ {"abcd:2345::/65", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:8000::"))}, nil},
+ {"abcd:2345::/64", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff::"))}, nil},
+ {"abcd:2345::/63", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:fffe::"))}, nil},
+ {"abcd:2345::/33", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:8000::"))}, nil},
+ {"abcd:2345::/32", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff::"))}, nil},
+ {"abcd:2344::/31", ParseIP("abcd:2344::"), &IPNet{IP: ParseIP("abcd:2344::"), Mask: IPMask(ParseIP("ffff:fffe::"))}, nil},
+ {"abcd:2300::/24", ParseIP("abcd:2300::"), &IPNet{IP: ParseIP("abcd:2300::"), Mask: IPMask(ParseIP("ffff:ff00::"))}, nil},
+ {"abcd:2345::/24", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2300::"), Mask: IPMask(ParseIP("ffff:ff00::"))}, nil},
+ {"2001:DB8::/48", ParseIP("2001:DB8::"), &IPNet{IP: ParseIP("2001:DB8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff::"))}, nil},
+ {"2001:DB8::1/48", ParseIP("2001:DB8::1"), &IPNet{IP: ParseIP("2001:DB8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff::"))}, nil},
{"192.168.1.1/255.255.255.0", nil, nil, &ParseError{"CIDR address", "192.168.1.1/255.255.255.0"}},
{"192.168.1.1/35", nil, nil, &ParseError{"CIDR address", "192.168.1.1/35"}},
{"2001:db8::1/-1", nil, nil, &ParseError{"CIDR address", "2001:db8::1/-1"}},
@@ -154,14 +154,14 @@ var ipnetcontainstests = []struct {
net *IPNet
ok bool
}{
- {IPv4(172, 16, 1, 1), &IPNet{IPv4(172, 16, 0, 0), CIDRMask(12, 32)}, true},
- {IPv4(172, 24, 0, 1), &IPNet{IPv4(172, 16, 0, 0), CIDRMask(13, 32)}, false},
- {IPv4(192, 168, 0, 3), &IPNet{IPv4(192, 168, 0, 0), IPv4Mask(0, 0, 255, 252)}, true},
- {IPv4(192, 168, 0, 4), &IPNet{IPv4(192, 168, 0, 0), IPv4Mask(0, 255, 0, 252)}, false},
- {ParseIP("2001:db8:1:2::1"), &IPNet{ParseIP("2001:db8:1::"), CIDRMask(47, 128)}, true},
- {ParseIP("2001:db8:1:2::1"), &IPNet{ParseIP("2001:db8:2::"), CIDRMask(47, 128)}, false},
- {ParseIP("2001:db8:1:2::1"), &IPNet{ParseIP("2001:db8:1::"), IPMask(ParseIP("ffff:0:ffff::"))}, true},
- {ParseIP("2001:db8:1:2::1"), &IPNet{ParseIP("2001:db8:1::"), IPMask(ParseIP("0:0:0:ffff::"))}, false},
+ {IPv4(172, 16, 1, 1), &IPNet{IP: IPv4(172, 16, 0, 0), Mask: CIDRMask(12, 32)}, true},
+ {IPv4(172, 24, 0, 1), &IPNet{IP: IPv4(172, 16, 0, 0), Mask: CIDRMask(13, 32)}, false},
+ {IPv4(192, 168, 0, 3), &IPNet{IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(0, 0, 255, 252)}, true},
+ {IPv4(192, 168, 0, 4), &IPNet{IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(0, 255, 0, 252)}, false},
+ {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:1::"), Mask: CIDRMask(47, 128)}, true},
+ {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:2::"), Mask: CIDRMask(47, 128)}, false},
+ {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:1::"), Mask: IPMask(ParseIP("ffff:0:ffff::"))}, true},
+ {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:1::"), Mask: IPMask(ParseIP("0:0:0:ffff::"))}, false},
}
func TestIPNetContains(t *testing.T) {
@@ -176,10 +176,10 @@ var ipnetstringtests = []struct {
in *IPNet
out string
}{
- {&IPNet{IPv4(192, 168, 1, 0), CIDRMask(26, 32)}, "192.168.1.0/26"},
- {&IPNet{IPv4(192, 168, 1, 0), IPv4Mask(255, 0, 255, 0)}, "192.168.1.0/ff00ff00"},
- {&IPNet{ParseIP("2001:db8::"), CIDRMask(55, 128)}, "2001:db8::/55"},
- {&IPNet{ParseIP("2001:db8::"), IPMask(ParseIP("8000:f123:0:cafe::"))}, "2001:db8::/8000f1230000cafe0000000000000000"},
+ {&IPNet{IP: IPv4(192, 168, 1, 0), Mask: CIDRMask(26, 32)}, "192.168.1.0/26"},
+ {&IPNet{IP: IPv4(192, 168, 1, 0), Mask: IPv4Mask(255, 0, 255, 0)}, "192.168.1.0/ff00ff00"},
+ {&IPNet{IP: ParseIP("2001:db8::"), Mask: CIDRMask(55, 128)}, "2001:db8::/55"},
+ {&IPNet{IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("8000:f123:0:cafe::"))}, "2001:db8::/8000f1230000cafe0000000000000000"},
}
func TestIPNetString(t *testing.T) {
@@ -233,27 +233,27 @@ var networknumberandmasktests = []struct {
in IPNet
out IPNet
}{
- {IPNet{v4addr, v4mask}, IPNet{v4addr, v4mask}},
- {IPNet{v4addr, v4mappedv6mask}, IPNet{v4addr, v4mask}},
- {IPNet{v4mappedv6addr, v4mappedv6mask}, IPNet{v4addr, v4mask}},
- {IPNet{v4mappedv6addr, v6mask}, IPNet{v4addr, v4maskzero}},
- {IPNet{v4addr, v6mask}, IPNet{v4addr, v4maskzero}},
- {IPNet{v6addr, v6mask}, IPNet{v6addr, v6mask}},
- {IPNet{v6addr, v4mappedv6mask}, IPNet{v6addr, v4mappedv6mask}},
- {in: IPNet{v6addr, v4mask}},
- {in: IPNet{v4addr, badmask}},
- {in: IPNet{v4mappedv6addr, badmask}},
- {in: IPNet{v6addr, badmask}},
- {in: IPNet{badaddr, v4mask}},
- {in: IPNet{badaddr, v4mappedv6mask}},
- {in: IPNet{badaddr, v6mask}},
- {in: IPNet{badaddr, badmask}},
+ {IPNet{IP: v4addr, Mask: v4mask}, IPNet{IP: v4addr, Mask: v4mask}},
+ {IPNet{IP: v4addr, Mask: v4mappedv6mask}, IPNet{IP: v4addr, Mask: v4mask}},
+ {IPNet{IP: v4mappedv6addr, Mask: v4mappedv6mask}, IPNet{IP: v4addr, Mask: v4mask}},
+ {IPNet{IP: v4mappedv6addr, Mask: v6mask}, IPNet{IP: v4addr, Mask: v4maskzero}},
+ {IPNet{IP: v4addr, Mask: v6mask}, IPNet{IP: v4addr, Mask: v4maskzero}},
+ {IPNet{IP: v6addr, Mask: v6mask}, IPNet{IP: v6addr, Mask: v6mask}},
+ {IPNet{IP: v6addr, Mask: v4mappedv6mask}, IPNet{IP: v6addr, Mask: v4mappedv6mask}},
+ {in: IPNet{IP: v6addr, Mask: v4mask}},
+ {in: IPNet{IP: v4addr, Mask: badmask}},
+ {in: IPNet{IP: v4mappedv6addr, Mask: badmask}},
+ {in: IPNet{IP: v6addr, Mask: badmask}},
+ {in: IPNet{IP: badaddr, Mask: v4mask}},
+ {in: IPNet{IP: badaddr, Mask: v4mappedv6mask}},
+ {in: IPNet{IP: badaddr, Mask: v6mask}},
+ {in: IPNet{IP: badaddr, Mask: badmask}},
}
func TestNetworkNumberAndMask(t *testing.T) {
for _, tt := range networknumberandmasktests {
ip, m := networkNumberAndMask(&tt.in)
- out := &IPNet{ip, m}
+ out := &IPNet{IP: ip, Mask: m}
if !reflect.DeepEqual(&tt.out, out) {
t.Errorf("networkNumberAndMask(%v) = %v; want %v", tt.in, out, &tt.out)
}
@@ -268,6 +268,29 @@ var splitjointests = []struct {
{"www.google.com", "80", "www.google.com:80"},
{"127.0.0.1", "1234", "127.0.0.1:1234"},
{"::1", "80", "[::1]:80"},
+ {"google.com", "https%foo", "google.com:https%foo"}, // Go 1.0 behavior
+ {"", "0", ":0"},
+ {"127.0.0.1", "", "127.0.0.1:"}, // Go 1.0 behaviour
+ {"www.google.com", "", "www.google.com:"}, // Go 1.0 behaviour
+}
+
+var splitfailuretests = []struct {
+ HostPort string
+ Err string
+}{
+ {"www.google.com", "missing port in address"},
+ {"127.0.0.1", "missing port in address"},
+ {"[::1]", "missing port in address"},
+ {"::1", "too many colons in address"},
+
+ // Test cases that didn't fail in Go 1.0
+ {"[foo:bar]", "missing port in address"},
+ {"[foo:bar]baz", "missing port in address"},
+ {"[foo]:[bar]:baz", "too many colons in address"},
+ {"[foo]bar:baz", "missing port in address"},
+ {"[foo]:[bar]baz", "unexpected '[' in address"},
+ {"foo[bar]:baz", "unexpected '[' in address"},
+ {"foo]bar:baz", "unexpected ']' in address"},
}
func TestSplitHostPort(t *testing.T) {
@@ -276,6 +299,16 @@ func TestSplitHostPort(t *testing.T) {
t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.Join, host, port, err, tt.Host, tt.Port)
}
}
+ for _, tt := range splitfailuretests {
+ if _, _, err := SplitHostPort(tt.HostPort); err == nil {
+ t.Errorf("SplitHostPort(%q) should have failed", tt.HostPort)
+ } else {
+ e := err.(*AddrError)
+ if e.Err != tt.Err {
+ t.Errorf("SplitHostPort(%q) = _, _, %q; want %q", tt.HostPort, e.Err, tt.Err)
+ }
+ }
+ }
}
func TestJoinHostPort(t *testing.T) {
diff --git a/src/pkg/net/ipraw_test.go b/src/pkg/net/ipraw_test.go
index 613620272..65defc7ea 100644
--- a/src/pkg/net/ipraw_test.go
+++ b/src/pkg/net/ipraw_test.go
@@ -2,205 +2,340 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !plan9
+
package net
import (
"bytes"
+ "errors"
"os"
- "syscall"
+ "reflect"
"testing"
"time"
)
-var icmpTests = []struct {
+var resolveIPAddrTests = []struct {
+ net string
+ litAddr string
+ addr *IPAddr
+ err error
+}{
+ {"ip", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil},
+ {"ip4", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil},
+ {"ip4:icmp", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil},
+
+ {"ip", "::1", &IPAddr{IP: ParseIP("::1")}, nil},
+ {"ip6", "::1", &IPAddr{IP: ParseIP("::1")}, nil},
+ {"ip6:icmp", "::1", &IPAddr{IP: ParseIP("::1")}, nil},
+
+ {"", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, // Go 1.0 behavior
+ {"", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, // Go 1.0 behavior
+
+ {"l2tp", "127.0.0.1", nil, UnknownNetworkError("l2tp")},
+ {"l2tp:gre", "127.0.0.1", nil, UnknownNetworkError("l2tp:gre")},
+ {"tcp", "1.2.3.4:123", nil, UnknownNetworkError("tcp")},
+}
+
+func TestResolveIPAddr(t *testing.T) {
+ for _, tt := range resolveIPAddrTests {
+ addr, err := ResolveIPAddr(tt.net, tt.litAddr)
+ if err != tt.err {
+ t.Fatalf("ResolveIPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err)
+ }
+ if !reflect.DeepEqual(addr, tt.addr) {
+ t.Fatalf("got %#v; expected %#v", addr, tt.addr)
+ }
+ }
+}
+
+var icmpEchoTests = []struct {
net string
laddr string
raddr string
- ipv6 bool // test with underlying AF_INET6 socket
}{
- {"ip4:icmp", "", "127.0.0.1", false},
- {"ip6:icmp", "", "::1", true},
+ {"ip4:icmp", "0.0.0.0", "127.0.0.1"},
+ {"ip6:ipv6-icmp", "::", "::1"},
}
-func TestICMP(t *testing.T) {
+func TestConnICMPEcho(t *testing.T) {
if os.Getuid() != 0 {
- t.Logf("test disabled; must be root")
- return
+ t.Skip("skipping test; must be root")
}
- seqnum := 61455
- for _, tt := range icmpTests {
- if tt.ipv6 && !supportsIPv6 {
+ for i, tt := range icmpEchoTests {
+ net, _, err := parseNetwork(tt.net)
+ if err != nil {
+ t.Fatalf("parseNetwork failed: %v", err)
+ }
+ if net == "ip6" && !supportsIPv6 {
continue
}
- id := os.Getpid() & 0xffff
- seqnum++
- echo := newICMPEchoRequest(tt.net, id, seqnum, 128, []byte("Go Go Gadget Ping!!!"))
- exchangeICMPEcho(t, tt.net, tt.laddr, tt.raddr, echo)
- }
-}
-
-func exchangeICMPEcho(t *testing.T, net, laddr, raddr string, echo []byte) {
- c, err := ListenPacket(net, laddr)
- if err != nil {
- t.Errorf("ListenPacket(%q, %q) failed: %v", net, laddr, err)
- return
- }
- c.SetDeadline(time.Now().Add(100 * time.Millisecond))
- defer c.Close()
-
- ra, err := ResolveIPAddr(net, raddr)
- if err != nil {
- t.Errorf("ResolveIPAddr(%q, %q) failed: %v", net, raddr, err)
- return
- }
- waitForReady := make(chan bool)
- go icmpEchoTransponder(t, net, raddr, waitForReady)
- <-waitForReady
-
- _, err = c.WriteTo(echo, ra)
- if err != nil {
- t.Errorf("WriteTo failed: %v", err)
- return
- }
+ c, err := Dial(tt.net, tt.raddr)
+ if err != nil {
+ t.Fatalf("Dial failed: %v", err)
+ }
+ c.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ defer c.Close()
- reply := make([]byte, 256)
- for {
- _, _, err := c.ReadFrom(reply)
+ typ := icmpv4EchoRequest
+ if net == "ip6" {
+ typ = icmpv6EchoRequest
+ }
+ xid, xseq := os.Getpid()&0xffff, i+1
+ b, err := (&icmpMessage{
+ Type: typ, Code: 0,
+ Body: &icmpEcho{
+ ID: xid, Seq: xseq,
+ Data: bytes.Repeat([]byte("Go Go Gadget Ping!!!"), 3),
+ },
+ }).Marshal()
if err != nil {
- t.Errorf("ReadFrom failed: %v", err)
- return
+ t.Fatalf("icmpMessage.Marshal failed: %v", err)
}
- switch c.(*IPConn).fd.family {
- case syscall.AF_INET:
- if reply[0] != ICMP4_ECHO_REPLY {
- continue
+ if _, err := c.Write(b); err != nil {
+ t.Fatalf("Conn.Write failed: %v", err)
+ }
+ var m *icmpMessage
+ for {
+ if _, err := c.Read(b); err != nil {
+ t.Fatalf("Conn.Read failed: %v", err)
+ }
+ if net == "ip4" {
+ b = ipv4Payload(b)
}
- case syscall.AF_INET6:
- if reply[0] != ICMP6_ECHO_REPLY {
+ if m, err = parseICMPMessage(b); err != nil {
+ t.Fatalf("parseICMPMessage failed: %v", err)
+ }
+ switch m.Type {
+ case icmpv4EchoRequest, icmpv6EchoRequest:
continue
}
+ break
}
- xid, xseqnum := parseICMPEchoReply(echo)
- rid, rseqnum := parseICMPEchoReply(reply)
- if rid != xid || rseqnum != xseqnum {
- t.Errorf("ID = %v, Seqnum = %v, want ID = %v, Seqnum = %v", rid, rseqnum, xid, xseqnum)
- return
+ switch p := m.Body.(type) {
+ case *icmpEcho:
+ if p.ID != xid || p.Seq != xseq {
+ t.Fatalf("got id=%v, seqnum=%v; expected id=%v, seqnum=%v", p.ID, p.Seq, xid, xseq)
+ }
+ default:
+ t.Fatalf("got type=%v, code=%v; expected type=%v, code=%v", m.Type, m.Code, typ, 0)
}
- break
}
}
-func icmpEchoTransponder(t *testing.T, net, raddr string, waitForReady chan bool) {
- c, err := Dial(net, raddr)
- if err != nil {
- waitForReady <- true
- t.Errorf("Dial(%q, %q) failed: %v", net, raddr, err)
- return
+func TestPacketConnICMPEcho(t *testing.T) {
+ if os.Getuid() != 0 {
+ t.Skip("skipping test; must be root")
}
- c.SetDeadline(time.Now().Add(100 * time.Millisecond))
- defer c.Close()
- waitForReady <- true
- echo := make([]byte, 256)
- var nr int
- for {
- nr, err = c.Read(echo)
+ for i, tt := range icmpEchoTests {
+ net, _, err := parseNetwork(tt.net)
if err != nil {
- t.Errorf("Read failed: %v", err)
- return
+ t.Fatalf("parseNetwork failed: %v", err)
}
- switch c.(*IPConn).fd.family {
- case syscall.AF_INET:
- if echo[0] != ICMP4_ECHO_REQUEST {
- continue
+ if net == "ip6" && !supportsIPv6 {
+ continue
+ }
+
+ c, err := ListenPacket(tt.net, tt.laddr)
+ if err != nil {
+ t.Fatalf("ListenPacket failed: %v", err)
+ }
+ c.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ defer c.Close()
+
+ ra, err := ResolveIPAddr(tt.net, tt.raddr)
+ if err != nil {
+ t.Fatalf("ResolveIPAddr failed: %v", err)
+ }
+ typ := icmpv4EchoRequest
+ if net == "ip6" {
+ typ = icmpv6EchoRequest
+ }
+ xid, xseq := os.Getpid()&0xffff, i+1
+ b, err := (&icmpMessage{
+ Type: typ, Code: 0,
+ Body: &icmpEcho{
+ ID: xid, Seq: xseq,
+ Data: bytes.Repeat([]byte("Go Go Gadget Ping!!!"), 3),
+ },
+ }).Marshal()
+ if err != nil {
+ t.Fatalf("icmpMessage.Marshal failed: %v", err)
+ }
+ if _, err := c.WriteTo(b, ra); err != nil {
+ t.Fatalf("PacketConn.WriteTo failed: %v", err)
+ }
+ var m *icmpMessage
+ for {
+ if _, _, err := c.ReadFrom(b); err != nil {
+ t.Fatalf("PacketConn.ReadFrom failed: %v", err)
}
- case syscall.AF_INET6:
- if echo[0] != ICMP6_ECHO_REQUEST {
+ // TODO: fix issue 3944
+ //if net == "ip4" {
+ // b = ipv4Payload(b)
+ //}
+ if m, err = parseICMPMessage(b); err != nil {
+ t.Fatalf("parseICMPMessage failed: %v", err)
+ }
+ switch m.Type {
+ case icmpv4EchoRequest, icmpv6EchoRequest:
continue
}
+ break
+ }
+ switch p := m.Body.(type) {
+ case *icmpEcho:
+ if p.ID != xid || p.Seq != xseq {
+ t.Fatalf("got id=%v, seqnum=%v; expected id=%v, seqnum=%v", p.ID, p.Seq, xid, xseq)
+ }
+ default:
+ t.Fatalf("got type=%v, code=%v; expected type=%v, code=%v", m.Type, m.Code, typ, 0)
}
- break
- }
-
- switch c.(*IPConn).fd.family {
- case syscall.AF_INET:
- echo[0] = ICMP4_ECHO_REPLY
- case syscall.AF_INET6:
- echo[0] = ICMP6_ECHO_REPLY
}
+}
- _, err = c.Write(echo[:nr])
- if err != nil {
- t.Errorf("Write failed: %v", err)
- return
+func ipv4Payload(b []byte) []byte {
+ if len(b) < 20 {
+ return b
}
+ hdrlen := int(b[0]&0x0f) << 2
+ return b[hdrlen:]
}
const (
- ICMP4_ECHO_REQUEST = 8
- ICMP4_ECHO_REPLY = 0
- ICMP6_ECHO_REQUEST = 128
- ICMP6_ECHO_REPLY = 129
+ icmpv4EchoRequest = 8
+ icmpv4EchoReply = 0
+ icmpv6EchoRequest = 128
+ icmpv6EchoReply = 129
)
-func newICMPEchoRequest(net string, id, seqnum, msglen int, filler []byte) []byte {
- afnet, _, _ := parseDialNetwork(net)
- switch afnet {
- case "ip4":
- return newICMPv4EchoRequest(id, seqnum, msglen, filler)
- case "ip6":
- return newICMPv6EchoRequest(id, seqnum, msglen, filler)
- }
- return nil
+// icmpMessage represents an ICMP message.
+type icmpMessage struct {
+ Type int // type
+ Code int // code
+ Checksum int // checksum
+ Body icmpMessageBody // body
}
-func newICMPv4EchoRequest(id, seqnum, msglen int, filler []byte) []byte {
- b := newICMPInfoMessage(id, seqnum, msglen, filler)
- b[0] = ICMP4_ECHO_REQUEST
+// icmpMessageBody represents an ICMP message body.
+type icmpMessageBody interface {
+ Len() int
+ Marshal() ([]byte, error)
+}
- // calculate ICMP checksum
- cklen := len(b)
+// Marshal returns the binary enconding of the ICMP echo request or
+// reply message m.
+func (m *icmpMessage) Marshal() ([]byte, error) {
+ b := []byte{byte(m.Type), byte(m.Code), 0, 0}
+ if m.Body != nil && m.Body.Len() != 0 {
+ mb, err := m.Body.Marshal()
+ if err != nil {
+ return nil, err
+ }
+ b = append(b, mb...)
+ }
+ switch m.Type {
+ case icmpv6EchoRequest, icmpv6EchoReply:
+ return b, nil
+ }
+ csumcv := len(b) - 1 // checksum coverage
s := uint32(0)
- for i := 0; i < cklen-1; i += 2 {
+ for i := 0; i < csumcv; i += 2 {
s += uint32(b[i+1])<<8 | uint32(b[i])
}
- if cklen&1 == 1 {
- s += uint32(b[cklen-1])
+ if csumcv&1 == 0 {
+ s += uint32(b[csumcv])
}
- s = (s >> 16) + (s & 0xffff)
- s = s + (s >> 16)
- // place checksum back in header; using ^= avoids the
- // assumption the checksum bytes are zero
- b[2] ^= uint8(^s & 0xff)
- b[3] ^= uint8(^s >> 8)
+ s = s>>16 + s&0xffff
+ s = s + s>>16
+ // Place checksum back in header; using ^= avoids the
+ // assumption the checksum bytes are zero.
+ b[2] ^= byte(^s & 0xff)
+ b[3] ^= byte(^s >> 8)
+ return b, nil
+}
- return b
+// parseICMPMessage parses b as an ICMP message.
+func parseICMPMessage(b []byte) (*icmpMessage, error) {
+ msglen := len(b)
+ if msglen < 4 {
+ return nil, errors.New("message too short")
+ }
+ m := &icmpMessage{Type: int(b[0]), Code: int(b[1]), Checksum: int(b[2])<<8 | int(b[3])}
+ if msglen > 4 {
+ var err error
+ switch m.Type {
+ case icmpv4EchoRequest, icmpv4EchoReply, icmpv6EchoRequest, icmpv6EchoReply:
+ m.Body, err = parseICMPEcho(b[4:])
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+ return m, nil
+}
+
+// imcpEcho represenets an ICMP echo request or reply message body.
+type icmpEcho struct {
+ ID int // identifier
+ Seq int // sequence number
+ Data []byte // data
+}
+
+func (p *icmpEcho) Len() int {
+ if p == nil {
+ return 0
+ }
+ return 4 + len(p.Data)
}
-func newICMPv6EchoRequest(id, seqnum, msglen int, filler []byte) []byte {
- b := newICMPInfoMessage(id, seqnum, msglen, filler)
- b[0] = ICMP6_ECHO_REQUEST
- return b
+// Marshal returns the binary enconding of the ICMP echo request or
+// reply message body p.
+func (p *icmpEcho) Marshal() ([]byte, error) {
+ b := make([]byte, 4+len(p.Data))
+ b[0], b[1] = byte(p.ID>>8), byte(p.ID&0xff)
+ b[2], b[3] = byte(p.Seq>>8), byte(p.Seq&0xff)
+ copy(b[4:], p.Data)
+ return b, nil
}
-func newICMPInfoMessage(id, seqnum, msglen int, filler []byte) []byte {
- b := make([]byte, msglen)
- copy(b[8:], bytes.Repeat(filler, (msglen-8)/len(filler)+1))
- b[0] = 0 // type
- b[1] = 0 // code
- b[2] = 0 // checksum
- b[3] = 0 // checksum
- b[4] = uint8(id >> 8) // identifier
- b[5] = uint8(id & 0xff) // identifier
- b[6] = uint8(seqnum >> 8) // sequence number
- b[7] = uint8(seqnum & 0xff) // sequence number
- return b
+// parseICMPEcho parses b as an ICMP echo request or reply message
+// body.
+func parseICMPEcho(b []byte) (*icmpEcho, error) {
+ bodylen := len(b)
+ p := &icmpEcho{ID: int(b[0])<<8 | int(b[1]), Seq: int(b[2])<<8 | int(b[3])}
+ if bodylen > 4 {
+ p.Data = make([]byte, bodylen-4)
+ copy(p.Data, b[4:])
+ }
+ return p, nil
}
-func parseICMPEchoReply(b []byte) (id, seqnum int) {
- id = int(b[4])<<8 | int(b[5])
- seqnum = int(b[6])<<8 | int(b[7])
- return
+var ipConnLocalNameTests = []struct {
+ net string
+ laddr *IPAddr
+}{
+ {"ip4:icmp", &IPAddr{IP: IPv4(127, 0, 0, 1)}},
+ {"ip4:icmp", &IPAddr{}},
+ {"ip4:icmp", nil},
+}
+
+func TestIPConnLocalName(t *testing.T) {
+ if os.Getuid() != 0 {
+ t.Skip("skipping test; must be root")
+ }
+
+ for _, tt := range ipConnLocalNameTests {
+ c, err := ListenIP(tt.net, tt.laddr)
+ if err != nil {
+ t.Fatalf("ListenIP failed: %v", err)
+ }
+ defer c.Close()
+ if la := c.LocalAddr(); la == nil {
+ t.Fatal("IPConn.LocalAddr failed")
+ }
+ }
}
diff --git a/src/pkg/net/iprawsock.go b/src/pkg/net/iprawsock.go
index b23213ee1..daccba366 100644
--- a/src/pkg/net/iprawsock.go
+++ b/src/pkg/net/iprawsock.go
@@ -2,13 +2,14 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// (Raw) IP sockets
+// Raw IP sockets
package net
-// IPAddr represents the address of a IP end point.
+// IPAddr represents the address of an IP end point.
type IPAddr struct {
- IP IP
+ IP IP
+ Zone string // IPv6 scoped addressing zone
}
// Network returns the address's network name, "ip".
@@ -21,45 +22,25 @@ func (a *IPAddr) String() string {
return a.IP.String()
}
-// ResolveIPAddr parses addr as a IP address and resolves domain
+// ResolveIPAddr parses addr as an IP address and resolves domain
// names to numeric addresses on the network net, which must be
-// "ip", "ip4" or "ip6". A literal IPv6 host address must be
-// enclosed in square brackets, as in "[::]".
+// "ip", "ip4" or "ip6".
func ResolveIPAddr(net, addr string) (*IPAddr, error) {
- ip, err := hostToIP(net, addr)
+ if net == "" { // a hint wildcard for Go 1.0 undocumented behavior
+ net = "ip"
+ }
+ afnet, _, err := parseNetwork(net)
if err != nil {
return nil, err
}
- return &IPAddr{ip}, nil
-}
-
-// Convert "host" into IP address.
-func hostToIP(net, host string) (ip IP, err error) {
- var addr IP
- // Try as an IP address.
- addr = ParseIP(host)
- if addr == nil {
- filter := anyaddr
- if net != "" && net[len(net)-1] == '4' {
- filter = ipv4only
- }
- if net != "" && net[len(net)-1] == '6' {
- filter = ipv6only
- }
- // Not an IP address. Try as a DNS name.
- addrs, err1 := LookupHost(host)
- if err1 != nil {
- err = err1
- goto Error
- }
- addr = firstFavoriteAddr(filter, addrs)
- if addr == nil {
- // should not happen
- err = &AddrError{"LookupHost returned no suitable address", addrs[0]}
- goto Error
- }
+ switch afnet {
+ case "ip", "ip4", "ip6":
+ default:
+ return nil, UnknownNetworkError(net)
+ }
+ a, err := resolveInternetAddr(afnet, addr, noDeadline)
+ if err != nil {
+ return nil, err
}
- return addr, nil
-Error:
- return nil, err
+ return a.(*IPAddr), nil
}
diff --git a/src/pkg/net/iprawsock_plan9.go b/src/pkg/net/iprawsock_plan9.go
index 43719fc99..88e3b2c60 100644
--- a/src/pkg/net/iprawsock_plan9.go
+++ b/src/pkg/net/iprawsock_plan9.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// (Raw) IP sockets stubs for Plan 9
+// Raw IP sockets for Plan 9
package net
@@ -11,55 +11,13 @@ import (
"time"
)
-// IPConn is the implementation of the Conn and PacketConn
-// interfaces for IP network connections.
-type IPConn bool
-
-// SetDeadline implements the Conn SetDeadline method.
-func (c *IPConn) SetDeadline(t time.Time) error {
- return syscall.EPLAN9
-}
-
-// SetReadDeadline implements the Conn SetReadDeadline method.
-func (c *IPConn) SetReadDeadline(t time.Time) error {
- return syscall.EPLAN9
-}
-
-// SetWriteDeadline implements the Conn SetWriteDeadline method.
-func (c *IPConn) SetWriteDeadline(t time.Time) error {
- return syscall.EPLAN9
-}
-
-// Implementation of the Conn interface - see Conn for documentation.
-
-// Read implements the Conn Read method.
-func (c *IPConn) Read(b []byte) (int, error) {
- return 0, syscall.EPLAN9
-}
-
-// Write implements the Conn Write method.
-func (c *IPConn) Write(b []byte) (int, error) {
- return 0, syscall.EPLAN9
-}
-
-// Close closes the IP connection.
-func (c *IPConn) Close() error {
- return syscall.EPLAN9
-}
-
-// LocalAddr returns the local network address.
-func (c *IPConn) LocalAddr() Addr {
- return nil
+// IPConn is the implementation of the Conn and PacketConn interfaces
+// for IP network connections.
+type IPConn struct {
+ conn
}
-// RemoteAddr returns the remote network address, a *IPAddr.
-func (c *IPConn) RemoteAddr() Addr {
- return nil
-}
-
-// IP-specific methods.
-
-// ReadFromIP reads a IP packet from c, copying the payload into b.
+// ReadFromIP reads an IP packet from c, copying the payload into b.
// It returns the number of bytes copied into b and the return address
// that was on the packet.
//
@@ -75,12 +33,21 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) {
return 0, nil, syscall.EPLAN9
}
-// WriteToIP writes a IP packet to addr via c, copying the payload from b.
+// ReadMsgIP reads a packet from c, copying the payload into b and the
+// associdated out-of-band data into oob. It returns the number of
+// bytes copied into b, the number of bytes copied into oob, the flags
+// that were set on the packet and the source address of the packet.
+func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) {
+ return 0, 0, 0, nil, syscall.EPLAN9
+}
+
+// WriteToIP writes an IP packet to addr via c, copying the payload
+// from b.
//
-// WriteToIP can be made to time out and return
-// an error with Timeout() == true after a fixed time limit;
-// see SetDeadline and SetWriteDeadline.
-// On packet-oriented connections, write timeouts are rare.
+// WriteToIP can be made to time out and return an error with
+// Timeout() == true after a fixed time limit; see SetDeadline and
+// SetWriteDeadline. On packet-oriented connections, write timeouts
+// are rare.
func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) {
return 0, syscall.EPLAN9
}
@@ -90,16 +57,28 @@ func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) {
return 0, syscall.EPLAN9
}
-// DialIP connects to the remote address raddr on the network protocol netProto,
-// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name.
+// WriteMsgIP writes a packet to addr via c, copying the payload from
+// b and the associated out-of-band data from oob. It returns the
+// number of payload and out-of-band bytes written.
+func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error) {
+ return 0, 0, syscall.EPLAN9
+}
+
+// DialIP connects to the remote address raddr on the network protocol
+// netProto, which must be "ip", "ip4", or "ip6" followed by a colon
+// and a protocol number or name.
func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
+ return dialIP(netProto, laddr, raddr, noDeadline)
+}
+
+func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) {
return nil, syscall.EPLAN9
}
-// ListenIP listens for incoming IP packets addressed to the
-// local address laddr. The returned connection c's ReadFrom
-// and WriteTo methods can be used to receive and send IP
-// packets with per-packet addressing.
+// ListenIP listens for incoming IP packets addressed to the local
+// address laddr. The returned connection c's ReadFrom and WriteTo
+// methods can be used to receive and send IP packets with per-packet
+// addressing.
func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
return nil, syscall.EPLAN9
}
diff --git a/src/pkg/net/iprawsock_posix.go b/src/pkg/net/iprawsock_posix.go
index 9fc7ecdb9..2ef4db19c 100644
--- a/src/pkg/net/iprawsock_posix.go
+++ b/src/pkg/net/iprawsock_posix.go
@@ -4,12 +4,11 @@
// +build darwin freebsd linux netbsd openbsd windows
-// (Raw) IP sockets
+// Raw IP sockets for POSIX
package net
import (
- "os"
"syscall"
"time"
)
@@ -17,9 +16,9 @@ import (
func sockaddrToIP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
- return &IPAddr{sa.Addr[0:]}
+ return &IPAddr{IP: sa.Addr[0:]}
case *syscall.SockaddrInet6:
- return &IPAddr{sa.Addr[0:]}
+ return &IPAddr{IP: sa.Addr[0:], Zone: zoneToString(int(sa.ZoneId))}
}
return nil
}
@@ -42,7 +41,7 @@ func (a *IPAddr) isWildcard() bool {
}
func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
- return ipToSockaddr(family, a.IP, 0)
+ return ipToSockaddr(family, a.IP, 0, a.Zone)
}
func (a *IPAddr) toAddr() sockaddr {
@@ -55,98 +54,12 @@ func (a *IPAddr) toAddr() sockaddr {
// IPConn is the implementation of the Conn and PacketConn
// interfaces for IP network connections.
type IPConn struct {
- fd *netFD
+ conn
}
-func newIPConn(fd *netFD) *IPConn { return &IPConn{fd} }
+func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
-func (c *IPConn) ok() bool { return c != nil && c.fd != nil }
-
-// Implementation of the Conn interface - see Conn for documentation.
-
-// Read implements the Conn Read method.
-func (c *IPConn) Read(b []byte) (int, error) {
- n, _, err := c.ReadFrom(b)
- return n, err
-}
-
-// Write implements the Conn Write method.
-func (c *IPConn) Write(b []byte) (int, error) {
- if !c.ok() {
- return 0, syscall.EINVAL
- }
- return c.fd.Write(b)
-}
-
-// Close closes the IP connection.
-func (c *IPConn) Close() error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return c.fd.Close()
-}
-
-// LocalAddr returns the local network address.
-func (c *IPConn) LocalAddr() Addr {
- if !c.ok() {
- return nil
- }
- return c.fd.laddr
-}
-
-// RemoteAddr returns the remote network address, a *IPAddr.
-func (c *IPConn) RemoteAddr() Addr {
- if !c.ok() {
- return nil
- }
- return c.fd.raddr
-}
-
-// SetDeadline implements the Conn SetDeadline method.
-func (c *IPConn) SetDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setDeadline(c.fd, t)
-}
-
-// SetReadDeadline implements the Conn SetReadDeadline method.
-func (c *IPConn) SetReadDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setReadDeadline(c.fd, t)
-}
-
-// SetWriteDeadline implements the Conn SetWriteDeadline method.
-func (c *IPConn) SetWriteDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setWriteDeadline(c.fd, t)
-}
-
-// SetReadBuffer sets the size of the operating system's
-// receive buffer associated with the connection.
-func (c *IPConn) SetReadBuffer(bytes int) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setReadBuffer(c.fd, bytes)
-}
-
-// SetWriteBuffer sets the size of the operating system's
-// transmit buffer associated with the connection.
-func (c *IPConn) SetWriteBuffer(bytes int) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setWriteBuffer(c.fd, bytes)
-}
-
-// IP-specific methods.
-
-// ReadFromIP reads a IP packet from c, copying the payload into b.
+// ReadFromIP reads an IP packet from c, copying the payload into b.
// It returns the number of bytes copied into b and the return address
// that was on the packet.
//
@@ -163,14 +76,14 @@ func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) {
n, sa, err := c.fd.ReadFrom(b)
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
- addr = &IPAddr{sa.Addr[0:]}
+ addr = &IPAddr{IP: sa.Addr[0:]}
if len(b) >= IPv4len { // discard ipv4 header
hsize := (int(b[0]) & 0xf) * 4
copy(b, b[hsize:])
n -= hsize
}
case *syscall.SockaddrInet6:
- addr = &IPAddr{sa.Addr[0:]}
+ addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneToString(int(sa.ZoneId))}
}
return n, addr, err
}
@@ -180,11 +93,30 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
}
- n, uaddr, err := c.ReadFromIP(b)
- return n, uaddr.toAddr(), err
+ n, addr, err := c.ReadFromIP(b)
+ return n, addr.toAddr(), err
+}
+
+// ReadMsgIP reads a packet from c, copying the payload into b and the
+// associdated out-of-band data into oob. It returns the number of
+// bytes copied into b, the number of bytes copied into oob, the flags
+// that were set on the packet and the source address of the packet.
+func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) {
+ if !c.ok() {
+ return 0, 0, 0, nil, syscall.EINVAL
+ }
+ var sa syscall.Sockaddr
+ n, oobn, flags, sa, err = c.fd.ReadMsg(b, oob)
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ addr = &IPAddr{IP: sa.Addr[0:]}
+ case *syscall.SockaddrInet6:
+ addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneToString(int(sa.ZoneId))}
+ }
+ return
}
-// WriteToIP writes a IP packet to addr via c, copying the payload from b.
+// WriteToIP writes an IP packet to addr via c, copying the payload from b.
//
// WriteToIP can be made to time out and return
// an error with Timeout() == true after a fixed time limit;
@@ -213,22 +145,40 @@ func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) {
return c.WriteToIP(b, a)
}
+// WriteMsgIP writes a packet to addr via c, copying the payload from
+// b and the associated out-of-band data from oob. It returns the
+// number of payload and out-of-band bytes written.
+func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error) {
+ if !c.ok() {
+ return 0, 0, syscall.EINVAL
+ }
+ sa, err := addr.sockaddr(c.fd.family)
+ if err != nil {
+ return 0, 0, &OpError{"write", c.fd.net, addr, err}
+ }
+ return c.fd.WriteMsg(b, oob, sa)
+}
+
// DialIP connects to the remote address raddr on the network protocol netProto,
// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name.
func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
- net, proto, err := parseDialNetwork(netProto)
+ return dialIP(netProto, laddr, raddr, noDeadline)
+}
+
+func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) {
+ net, proto, err := parseNetwork(netProto)
if err != nil {
return nil, err
}
switch net {
case "ip", "ip4", "ip6":
default:
- return nil, UnknownNetworkError(net)
+ return nil, UnknownNetworkError(netProto)
}
if raddr == nil {
return nil, &OpError{"dial", netProto, nil, errMissingAddress}
}
- fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_RAW, proto, "dial", sockaddrToIP)
+ fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_RAW, proto, "dial", sockaddrToIP)
if err != nil {
return nil, err
}
@@ -240,23 +190,18 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
// and WriteTo methods can be used to receive and send IP
// packets with per-packet addressing.
func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
- net, proto, err := parseDialNetwork(netProto)
+ net, proto, err := parseNetwork(netProto)
if err != nil {
return nil, err
}
switch net {
case "ip", "ip4", "ip6":
default:
- return nil, UnknownNetworkError(net)
+ return nil, UnknownNetworkError(netProto)
}
- fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_RAW, proto, "listen", sockaddrToIP)
+ fd, err := internetSocket(net, laddr.toAddr(), nil, noDeadline, syscall.SOCK_RAW, proto, "listen", sockaddrToIP)
if err != nil {
return nil, err
}
return newIPConn(fd), nil
}
-
-// File returns a copy of the underlying os.File, set to blocking mode.
-// It is the caller's responsibility to close f when finished.
-// Closing c does not affect f, and closing f does not affect c.
-func (c *IPConn) File() (f *os.File, err error) { return c.fd.dup() }
diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go
index bfbce18a4..1ef489289 100644
--- a/src/pkg/net/ipsock.go
+++ b/src/pkg/net/ipsock.go
@@ -2,11 +2,18 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// IP sockets
+// Internet protocol family sockets
package net
-var supportsIPv6, supportsIPv4map = probeIPv6Stack()
+import "time"
+
+var supportsIPv6, supportsIPv4map bool
+
+func init() {
+ sysInit()
+ supportsIPv6, supportsIPv4map = probeIPv6Stack()
+}
func firstFavoriteAddr(filter func(IP) IP, addrs []string) (addr IP) {
if filter == nil {
@@ -65,25 +72,67 @@ func (e InvalidAddrError) Temporary() bool { return false }
// "host:port" or "[host]:port" into host and port.
// The latter form must be used when host contains a colon.
func SplitHostPort(hostport string) (host, port string, err error) {
+ host, port, _, err = splitHostPort(hostport)
+ return
+}
+
+func splitHostPort(hostport string) (host, port, zone string, err error) {
+ j, k := 0, 0
+
// The port starts after the last colon.
i := last(hostport, ':')
if i < 0 {
- err = &AddrError{"missing port in address", hostport}
- return
+ goto missingPort
}
- host, port = hostport[0:i], hostport[i+1:]
-
- // Can put brackets around host ...
- if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' {
- host = host[1 : len(host)-1]
+ if hostport[0] == '[' {
+ // Expect the first ']' just before the last ':'.
+ end := byteIndex(hostport, ']')
+ if end < 0 {
+ err = &AddrError{"missing ']' in address", hostport}
+ return
+ }
+ switch end + 1 {
+ case len(hostport):
+ // There can't be a ':' behind the ']' now.
+ goto missingPort
+ case i:
+ // The expected result.
+ default:
+ // Either ']' isn't followed by a colon, or it is
+ // followed by a colon that is not the last one.
+ if hostport[end+1] == ':' {
+ goto tooManyColons
+ }
+ goto missingPort
+ }
+ host = hostport[1:end]
+ j, k = 1, end+1 // there can't be a '[' resp. ']' before these positions
} else {
- // ... but if there are no brackets, no colons.
+ host = hostport[:i]
+
if byteIndex(host, ':') >= 0 {
- err = &AddrError{"too many colons in address", hostport}
- return
+ goto tooManyColons
}
}
+ if byteIndex(hostport[j:], '[') >= 0 {
+ err = &AddrError{"unexpected '[' in address", hostport}
+ return
+ }
+ if byteIndex(hostport[k:], ']') >= 0 {
+ err = &AddrError{"unexpected ']' in address", hostport}
+ return
+ }
+
+ port = hostport[i+1:]
+ return
+
+missingPort:
+ err = &AddrError{"missing port in address", hostport}
+ return
+
+tooManyColons:
+ err = &AddrError{"too many colons in address", hostport}
return
}
@@ -97,49 +146,84 @@ func JoinHostPort(host, port string) string {
return host + ":" + port
}
-// Convert "host:port" into IP address and port.
-func hostPortToIP(net, hostport string) (ip IP, iport int, err error) {
- host, port, err := SplitHostPort(hostport)
- if err != nil {
- return nil, 0, err
- }
-
- var addr IP
- if host != "" {
- // Try as an IP address.
- addr = ParseIP(host)
- if addr == nil {
- var filter func(IP) IP
- if net != "" && net[len(net)-1] == '4' {
- filter = ipv4only
+func resolveInternetAddr(net, addr string, deadline time.Time) (Addr, error) {
+ var (
+ err error
+ host, port, zone string
+ portnum int
+ )
+ switch net {
+ case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
+ if addr != "" {
+ if host, port, zone, err = splitHostPort(addr); err != nil {
+ return nil, err
}
- if net != "" && net[len(net)-1] == '6' {
- filter = ipv6only
- }
- // Not an IP address. Try as a DNS name.
- addrs, err := LookupHost(host)
- if err != nil {
- return nil, 0, err
- }
- addr = firstFavoriteAddr(filter, addrs)
- if addr == nil {
- // should not happen
- return nil, 0, &AddrError{"LookupHost returned no suitable address", addrs[0]}
+ if portnum, err = parsePort(net, port); err != nil {
+ return nil, err
}
}
+ case "ip", "ip4", "ip6":
+ if addr != "" {
+ host = addr
+ }
+ default:
+ return nil, UnknownNetworkError(net)
}
-
- p, i, ok := dtoi(port, 0)
- if !ok || i != len(port) {
- p, err = LookupPort(net, port)
- if err != nil {
- return nil, 0, err
+ inetaddr := func(net string, ip IP, port int, zone string) Addr {
+ switch net {
+ case "tcp", "tcp4", "tcp6":
+ return &TCPAddr{IP: ip, Port: port, Zone: zone}
+ case "udp", "udp4", "udp6":
+ return &UDPAddr{IP: ip, Port: port, Zone: zone}
+ case "ip", "ip4", "ip6":
+ return &IPAddr{IP: ip, Zone: zone}
}
+ return nil
+ }
+ if host == "" {
+ return inetaddr(net, nil, portnum, zone), nil
}
- if p < 0 || p > 0xFFFF {
- return nil, 0, &AddrError{"invalid port", port}
+ // Try as an IP address.
+ if ip := ParseIP(host); ip != nil {
+ return inetaddr(net, ip, portnum, zone), nil
}
+ var filter func(IP) IP
+ if net != "" && net[len(net)-1] == '4' {
+ filter = ipv4only
+ }
+ if net != "" && net[len(net)-1] == '6' {
+ filter = ipv6only
+ }
+ // Try as a DNS name.
+ addrs, err := lookupHostDeadline(host, deadline)
+ if err != nil {
+ return nil, err
+ }
+ ip := firstFavoriteAddr(filter, addrs)
+ if ip == nil {
+ // should not happen
+ return nil, &AddrError{"LookupHost returned no suitable address", addrs[0]}
+ }
+ return inetaddr(net, ip, portnum, zone), nil
+}
- return addr, p, nil
+func zoneToString(zone int) string {
+ if zone == 0 {
+ return ""
+ }
+ if ifi, err := InterfaceByIndex(zone); err == nil {
+ return ifi.Name
+ }
+ return itod(uint(zone))
+}
+func zoneToInt(zone string) int {
+ if zone == "" {
+ return 0
+ }
+ if ifi, err := InterfaceByName(zone); err == nil {
+ return ifi.Index
+ }
+ n, _, _ := dtoi(zone, 0)
+ return n
}
diff --git a/src/pkg/net/ipsock_plan9.go b/src/pkg/net/ipsock_plan9.go
index eab0bf3e8..c7d542dab 100644
--- a/src/pkg/net/ipsock_plan9.go
+++ b/src/pkg/net/ipsock_plan9.go
@@ -2,21 +2,22 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// IP sockets stubs for Plan 9
+// Internet protocol family sockets for Plan 9
package net
import (
"errors"
- "io"
"os"
"syscall"
- "time"
)
-// probeIPv6Stack returns two boolean values. If the first boolean value is
-// true, kernel supports basic IPv6 functionality. If the second
-// boolean value is true, kernel supports IPv6 IPv4-mapping.
+// /sys/include/ape/sys/socket.h:/SOMAXCONN
+var listenerBacklog = 5
+
+// probeIPv6Stack returns two boolean values. If the first boolean
+// value is true, kernel supports basic IPv6 functionality. If the
+// second boolean value is true, kernel supports IPv6 IPv4-mapping.
func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) {
return false, false
}
@@ -48,6 +49,7 @@ func readPlan9Addr(proto, filename string) (addr Addr, err error) {
if err != nil {
return
}
+ defer f.Close()
n, err := f.Read(buf[:])
if err != nil {
return
@@ -58,110 +60,15 @@ func readPlan9Addr(proto, filename string) (addr Addr, err error) {
}
switch proto {
case "tcp":
- addr = &TCPAddr{ip, port}
+ addr = &TCPAddr{IP: ip, Port: port}
case "udp":
- addr = &UDPAddr{ip, port}
+ addr = &UDPAddr{IP: ip, Port: port}
default:
return nil, errors.New("unknown protocol " + proto)
}
return addr, nil
}
-type plan9Conn struct {
- proto, name, dir string
- ctl, data *os.File
- laddr, raddr Addr
-}
-
-func newPlan9Conn(proto, name string, ctl *os.File, laddr, raddr Addr) *plan9Conn {
- return &plan9Conn{proto, name, "/net/" + proto + "/" + name, ctl, nil, laddr, raddr}
-}
-
-func (c *plan9Conn) ok() bool { return c != nil && c.ctl != nil }
-
-// Implementation of the Conn interface - see Conn for documentation.
-
-// Read implements the Conn Read method.
-func (c *plan9Conn) Read(b []byte) (n int, err error) {
- if !c.ok() {
- return 0, syscall.EINVAL
- }
- if c.data == nil {
- c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0)
- if err != nil {
- return 0, err
- }
- }
- n, err = c.data.Read(b)
- if c.proto == "udp" && err == io.EOF {
- n = 0
- err = nil
- }
- return
-}
-
-// Write implements the Conn Write method.
-func (c *plan9Conn) Write(b []byte) (n int, err error) {
- if !c.ok() {
- return 0, syscall.EINVAL
- }
- if c.data == nil {
- c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0)
- if err != nil {
- return 0, err
- }
- }
- return c.data.Write(b)
-}
-
-// Close closes the connection.
-func (c *plan9Conn) Close() error {
- if !c.ok() {
- return syscall.EINVAL
- }
- err := c.ctl.Close()
- if err != nil {
- return err
- }
- if c.data != nil {
- err = c.data.Close()
- }
- c.ctl = nil
- c.data = nil
- return err
-}
-
-// LocalAddr returns the local network address.
-func (c *plan9Conn) LocalAddr() Addr {
- if !c.ok() {
- return nil
- }
- return c.laddr
-}
-
-// RemoteAddr returns the remote network address.
-func (c *plan9Conn) RemoteAddr() Addr {
- if !c.ok() {
- return nil
- }
- return c.raddr
-}
-
-// SetDeadline implements the Conn SetDeadline method.
-func (c *plan9Conn) SetDeadline(t time.Time) error {
- return syscall.EPLAN9
-}
-
-// SetReadDeadline implements the Conn SetReadDeadline method.
-func (c *plan9Conn) SetReadDeadline(t time.Time) error {
- return syscall.EPLAN9
-}
-
-// SetWriteDeadline implements the Conn SetWriteDeadline method.
-func (c *plan9Conn) SetWriteDeadline(t time.Time) error {
- return syscall.EPLAN9
-}
-
func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) {
var (
ip IP
@@ -192,98 +99,95 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string,
var buf [16]byte
n, err := f.Read(buf[:])
if err != nil {
+ f.Close()
return
}
return f, dest, proto, string(buf[:n]), nil
}
-func dialPlan9(net string, laddr, raddr Addr) (c *plan9Conn, err error) {
+func netErr(e error) {
+ oe, ok := e.(*OpError)
+ if !ok {
+ return
+ }
+ if pe, ok := oe.Err.(*os.PathError); ok {
+ if _, ok = pe.Err.(syscall.ErrorString); ok {
+ oe.Err = pe.Err
+ }
+ }
+}
+
+func dialPlan9(net string, laddr, raddr Addr) (fd *netFD, err error) {
+ defer func() { netErr(err) }()
f, dest, proto, name, err := startPlan9(net, raddr)
if err != nil {
- return
+ return nil, &OpError{"dial", net, raddr, err}
}
_, err = f.WriteString("connect " + dest)
if err != nil {
- return
+ f.Close()
+ return nil, &OpError{"dial", f.Name(), raddr, err}
}
- laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local")
+ data, err := os.OpenFile("/net/"+proto+"/"+name+"/data", os.O_RDWR, 0)
if err != nil {
- return
+ f.Close()
+ return nil, &OpError{"dial", net, raddr, err}
}
- raddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/remote")
+ laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local")
if err != nil {
- return
+ data.Close()
+ f.Close()
+ return nil, &OpError{"dial", proto, raddr, err}
}
- return newPlan9Conn(proto, name, f, laddr, raddr), nil
-}
-
-type plan9Listener struct {
- proto, name, dir string
- ctl *os.File
- laddr Addr
+ return newFD(proto, name, f, data, laddr, raddr), nil
}
-func listenPlan9(net string, laddr Addr) (l *plan9Listener, err error) {
+func listenPlan9(net string, laddr Addr) (fd *netFD, err error) {
+ defer func() { netErr(err) }()
f, dest, proto, name, err := startPlan9(net, laddr)
if err != nil {
- return
+ return nil, &OpError{"listen", net, laddr, err}
}
_, err = f.WriteString("announce " + dest)
if err != nil {
- return
+ f.Close()
+ return nil, &OpError{"announce", proto, laddr, err}
}
laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local")
if err != nil {
- return
+ f.Close()
+ return nil, &OpError{Op: "listen", Net: net, Err: err}
}
- l = new(plan9Listener)
- l.proto = proto
- l.name = name
- l.dir = "/net/" + proto + "/" + name
- l.ctl = f
- l.laddr = laddr
- return l, nil
+ return newFD(proto, name, f, nil, laddr, nil), nil
}
-func (l *plan9Listener) plan9Conn() *plan9Conn {
- return newPlan9Conn(l.proto, l.name, l.ctl, l.laddr, nil)
+func (l *netFD) netFD() *netFD {
+ return newFD(l.proto, l.name, l.ctl, l.data, l.laddr, l.raddr)
}
-func (l *plan9Listener) acceptPlan9() (c *plan9Conn, err error) {
+func (l *netFD) acceptPlan9() (fd *netFD, err error) {
+ defer func() { netErr(err) }()
f, err := os.Open(l.dir + "/listen")
if err != nil {
- return
+ return nil, &OpError{"accept", l.dir + "/listen", l.laddr, err}
}
var buf [16]byte
n, err := f.Read(buf[:])
if err != nil {
- return
+ f.Close()
+ return nil, &OpError{"accept", l.dir + "/listen", l.laddr, err}
}
name := string(buf[:n])
- laddr, err := readPlan9Addr(l.proto, l.dir+"/local")
- if err != nil {
- return
- }
- raddr, err := readPlan9Addr(l.proto, l.dir+"/remote")
+ data, err := os.OpenFile("/net/"+l.proto+"/"+name+"/data", os.O_RDWR, 0)
if err != nil {
- return
+ f.Close()
+ return nil, &OpError{"accept", l.proto, l.laddr, err}
}
- return newPlan9Conn(l.proto, name, f, laddr, raddr), nil
-}
-
-func (l *plan9Listener) Accept() (c Conn, err error) {
- c1, err := l.acceptPlan9()
+ raddr, err := readPlan9Addr(l.proto, "/net/"+l.proto+"/"+name+"/remote")
if err != nil {
- return
+ data.Close()
+ f.Close()
+ return nil, &OpError{"accept", l.proto, l.laddr, err}
}
- return c1, nil
+ return newFD(l.proto, name, f, data, l.laddr, raddr), nil
}
-
-func (l *plan9Listener) Close() error {
- if l == nil || l.ctl == nil {
- return syscall.EINVAL
- }
- return l.ctl.Close()
-}
-
-func (l *plan9Listener) Addr() Addr { return l.laddr }
diff --git a/src/pkg/net/ipsock_posix.go b/src/pkg/net/ipsock_posix.go
index ed313195c..4c37616ec 100644
--- a/src/pkg/net/ipsock_posix.go
+++ b/src/pkg/net/ipsock_posix.go
@@ -4,9 +4,14 @@
// +build darwin freebsd linux netbsd openbsd windows
+// Internet protocol family sockets for POSIX
+
package net
-import "syscall"
+import (
+ "syscall"
+ "time"
+)
// Should we try to use the IPv4 socket interface if we're
// only dealing with IPv4 sockets? As long as the host system
@@ -97,10 +102,13 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family
return syscall.AF_INET6, true
}
- if mode == "listen" && laddr.isWildcard() {
+ if mode == "listen" && (laddr == nil || laddr.isWildcard()) {
if supportsIPv4map {
return syscall.AF_INET6, false
}
+ if laddr == nil {
+ return syscall.AF_INET, false
+ }
return laddr.family(), false
}
@@ -122,7 +130,7 @@ type sockaddr interface {
sockaddr(family int) (syscall.Sockaddr, error)
}
-func internetSocket(net string, laddr, raddr sockaddr, sotype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
+func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
var la, ra syscall.Sockaddr
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
if laddr != nil {
@@ -135,7 +143,7 @@ func internetSocket(net string, laddr, raddr sockaddr, sotype, proto int, mode s
goto Error
}
}
- fd, err = socket(net, family, sotype, proto, ipv6only, la, ra, toAddr)
+ fd, err = socket(net, family, sotype, proto, ipv6only, la, ra, deadline, toAddr)
if err != nil {
goto Error
}
@@ -149,7 +157,7 @@ Error:
return nil, &OpError{mode, net, addr, err}
}
-func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, error) {
+func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
switch family {
case syscall.AF_INET:
if len(ip) == 0 {
@@ -158,12 +166,12 @@ func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, error) {
if ip = ip.To4(); ip == nil {
return nil, InvalidAddrError("non-IPv4 address")
}
- s := new(syscall.SockaddrInet4)
+ sa := new(syscall.SockaddrInet4)
for i := 0; i < IPv4len; i++ {
- s.Addr[i] = ip[i]
+ sa.Addr[i] = ip[i]
}
- s.Port = port
- return s, nil
+ sa.Port = port
+ return sa, nil
case syscall.AF_INET6:
if len(ip) == 0 {
ip = IPv6zero
@@ -177,12 +185,13 @@ func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, error) {
if ip = ip.To16(); ip == nil {
return nil, InvalidAddrError("non-IPv6 address")
}
- s := new(syscall.SockaddrInet6)
+ sa := new(syscall.SockaddrInet6)
for i := 0; i < IPv6len; i++ {
- s.Addr[i] = ip[i]
+ sa.Addr[i] = ip[i]
}
- s.Port = port
- return s, nil
+ sa.Port = port
+ sa.ZoneId = uint32(zoneToInt(zone))
+ return sa, nil
}
return nil, InvalidAddrError("unexpected socket family")
}
diff --git a/src/pkg/net/doc.go b/src/pkg/net/lookup.go
index 3a44e528e..bec93ec08 100644
--- a/src/pkg/net/doc.go
+++ b/src/pkg/net/lookup.go
@@ -4,12 +4,53 @@
package net
+import (
+ "time"
+)
+
// LookupHost looks up the given host using the local resolver.
// It returns an array of that host's addresses.
func LookupHost(host string) (addrs []string, err error) {
return lookupHost(host)
}
+func lookupHostDeadline(host string, deadline time.Time) (addrs []string, err error) {
+ if deadline.IsZero() {
+ return lookupHost(host)
+ }
+
+ // TODO(bradfitz): consider pushing the deadline down into the
+ // name resolution functions. But that involves fixing it for
+ // the native Go resolver, cgo, Windows, etc.
+ //
+ // In the meantime, just use a goroutine. Most users affected
+ // by http://golang.org/issue/2631 are due to TCP connections
+ // to unresponsive hosts, not DNS.
+ timeout := deadline.Sub(time.Now())
+ if timeout <= 0 {
+ err = errTimeout
+ return
+ }
+ t := time.NewTimer(timeout)
+ defer t.Stop()
+ type res struct {
+ addrs []string
+ err error
+ }
+ resc := make(chan res, 1)
+ go func() {
+ a, err := lookupHost(host)
+ resc <- res{a, err}
+ }()
+ select {
+ case <-t.C:
+ err = errTimeout
+ case r := <-resc:
+ addrs, err = r.addrs, r.err
+ }
+ return
+}
+
// LookupIP looks up host using the local resolver.
// It returns an array of that host's IPv4 and IPv6 addresses.
func LookupIP(host string) (addrs []IP, err error) {
@@ -47,6 +88,11 @@ func LookupMX(name string) (mx []*MX, err error) {
return lookupMX(name)
}
+// LookupNS returns the DNS NS records for the given domain name.
+func LookupNS(name string) (ns []*NS, err error) {
+ return lookupNS(name)
+}
+
// LookupTXT returns the DNS TXT records for the given domain name.
func LookupTXT(name string) (txt []string, err error) {
return lookupTXT(name)
diff --git a/src/pkg/net/lookup_plan9.go b/src/pkg/net/lookup_plan9.go
index 2c698304b..ae7cf7942 100644
--- a/src/pkg/net/lookup_plan9.go
+++ b/src/pkg/net/lookup_plan9.go
@@ -201,6 +201,21 @@ func lookupMX(name string) (mx []*MX, err error) {
return
}
+func lookupNS(name string) (ns []*NS, err error) {
+ lines, err := queryDNS(name, "ns")
+ if err != nil {
+ return
+ }
+ for _, line := range lines {
+ f := getFields(line)
+ if len(f) < 4 {
+ continue
+ }
+ ns = append(ns, &NS{f[3]})
+ }
+ return
+}
+
func lookupTXT(name string) (txt []string, err error) {
lines, err := queryDNS(name, "txt")
if err != nil {
diff --git a/src/pkg/net/lookup_test.go b/src/pkg/net/lookup_test.go
index 3a61dfb29..3355e4694 100644
--- a/src/pkg/net/lookup_test.go
+++ b/src/pkg/net/lookup_test.go
@@ -9,6 +9,7 @@ package net
import (
"flag"
+ "strings"
"testing"
)
@@ -16,8 +17,7 @@ var testExternal = flag.Bool("external", true, "allow use of external networks d
func TestGoogleSRV(t *testing.T) {
if testing.Short() || !*testExternal {
- t.Logf("skipping test to avoid external network")
- return
+ t.Skip("skipping test to avoid external network")
}
_, addrs, err := LookupSRV("xmpp-server", "tcp", "google.com")
if err != nil {
@@ -39,8 +39,7 @@ func TestGoogleSRV(t *testing.T) {
func TestGmailMX(t *testing.T) {
if testing.Short() || !*testExternal {
- t.Logf("skipping test to avoid external network")
- return
+ t.Skip("skipping test to avoid external network")
}
mx, err := LookupMX("gmail.com")
if err != nil {
@@ -51,10 +50,22 @@ func TestGmailMX(t *testing.T) {
}
}
+func TestGmailNS(t *testing.T) {
+ if testing.Short() || !*testExternal {
+ t.Skip("skipping test to avoid external network")
+ }
+ ns, err := LookupNS("gmail.com")
+ if err != nil {
+ t.Errorf("failed: %s", err)
+ }
+ if len(ns) == 0 {
+ t.Errorf("no results")
+ }
+}
+
func TestGmailTXT(t *testing.T) {
if testing.Short() || !*testExternal {
- t.Logf("skipping test to avoid external network")
- return
+ t.Skip("skipping test to avoid external network")
}
txt, err := LookupTXT("gmail.com")
if err != nil {
@@ -67,8 +78,7 @@ func TestGmailTXT(t *testing.T) {
func TestGoogleDNSAddr(t *testing.T) {
if testing.Short() || !*testExternal {
- t.Logf("skipping test to avoid external network")
- return
+ t.Skip("skipping test to avoid external network")
}
names, err := LookupAddr("8.8.8.8")
if err != nil {
@@ -79,6 +89,16 @@ func TestGoogleDNSAddr(t *testing.T) {
}
}
+func TestLookupIANACNAME(t *testing.T) {
+ if testing.Short() || !*testExternal {
+ t.Skip("skipping test to avoid external network")
+ }
+ cname, err := LookupCNAME("www.iana.org")
+ if !strings.HasSuffix(cname, ".icann.org.") || err != nil {
+ t.Errorf(`LookupCNAME("www.iana.org.") = %q, %v, want "*.icann.org.", nil`, cname, err)
+ }
+}
+
var revAddrTests = []struct {
Addr string
Reverse string
diff --git a/src/pkg/net/lookup_unix.go b/src/pkg/net/lookup_unix.go
index d500a1240..fa98eed5f 100644
--- a/src/pkg/net/lookup_unix.go
+++ b/src/pkg/net/lookup_unix.go
@@ -119,6 +119,19 @@ func lookupMX(name string) (mx []*MX, err error) {
return
}
+func lookupNS(name string) (ns []*NS, err error) {
+ _, records, err := lookup(name, dnsTypeNS)
+ if err != nil {
+ return
+ }
+ ns = make([]*NS, len(records))
+ for i, r := range records {
+ r := r.(*dnsRR_NS)
+ ns[i] = &NS{r.Ns}
+ }
+ return
+}
+
func lookupTXT(name string) (txt []string, err error) {
_, records, err := lookup(name, dnsTypeTXT)
if err != nil {
diff --git a/src/pkg/net/lookup_windows.go b/src/pkg/net/lookup_windows.go
index 99783e975..3b29724f2 100644
--- a/src/pkg/net/lookup_windows.go
+++ b/src/pkg/net/lookup_windows.go
@@ -6,21 +6,17 @@ package net
import (
"os"
- "sync"
+ "runtime"
"syscall"
"unsafe"
)
var (
- protoentLock sync.Mutex
- hostentLock sync.Mutex
- serventLock sync.Mutex
+ lookupPort = oldLookupPort
+ lookupIP = oldLookupIP
)
-// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
-func lookupProtocol(name string) (proto int, err error) {
- protoentLock.Lock()
- defer protoentLock.Unlock()
+func getprotobyname(name string) (proto int, err error) {
p, err := syscall.GetProtoByName(name)
if err != nil {
return 0, os.NewSyscallError("GetProtoByName", err)
@@ -28,6 +24,25 @@ func lookupProtocol(name string) (proto int, err error) {
return int(p.Proto), nil
}
+// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
+func lookupProtocol(name string) (proto int, err error) {
+ // GetProtoByName return value is stored in thread local storage.
+ // Start new os thread before the call to prevent races.
+ type result struct {
+ proto int
+ err error
+ }
+ ch := make(chan result)
+ go func() {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+ proto, err := getprotobyname(name)
+ ch <- result{proto: proto, err: err}
+ }()
+ r := <-ch
+ return r.proto, r.err
+}
+
func lookupHost(name string) (addrs []string, err error) {
ips, err := LookupIP(name)
if err != nil {
@@ -40,9 +55,7 @@ func lookupHost(name string) (addrs []string, err error) {
return
}
-func lookupIP(name string) (addrs []IP, err error) {
- hostentLock.Lock()
- defer hostentLock.Unlock()
+func gethostbyname(name string) (addrs []IP, err error) {
h, err := syscall.GetHostByName(name)
if err != nil {
return nil, os.NewSyscallError("GetHostByName", err)
@@ -56,20 +69,65 @@ func lookupIP(name string) (addrs []IP, err error) {
}
addrs = addrs[0:i]
default: // TODO(vcc): Implement non IPv4 address lookups.
- return nil, os.NewSyscallError("LookupHost", syscall.EWINDOWS)
+ return nil, os.NewSyscallError("LookupIP", syscall.EWINDOWS)
}
return addrs, nil
}
-func lookupPort(network, service string) (port int, err error) {
+func oldLookupIP(name string) (addrs []IP, err error) {
+ // GetHostByName return value is stored in thread local storage.
+ // Start new os thread before the call to prevent races.
+ type result struct {
+ addrs []IP
+ err error
+ }
+ ch := make(chan result)
+ go func() {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+ addrs, err := gethostbyname(name)
+ ch <- result{addrs: addrs, err: err}
+ }()
+ r := <-ch
+ return r.addrs, r.err
+}
+
+func newLookupIP(name string) (addrs []IP, err error) {
+ hints := syscall.AddrinfoW{
+ Family: syscall.AF_UNSPEC,
+ Socktype: syscall.SOCK_STREAM,
+ Protocol: syscall.IPPROTO_IP,
+ }
+ var result *syscall.AddrinfoW
+ e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
+ if e != nil {
+ return nil, os.NewSyscallError("GetAddrInfoW", e)
+ }
+ defer syscall.FreeAddrInfoW(result)
+ addrs = make([]IP, 0, 5)
+ for ; result != nil; result = result.Next {
+ addr := unsafe.Pointer(result.Addr)
+ switch result.Family {
+ case syscall.AF_INET:
+ a := (*syscall.RawSockaddrInet4)(addr).Addr
+ addrs = append(addrs, IPv4(a[0], a[1], a[2], a[3]))
+ case syscall.AF_INET6:
+ a := (*syscall.RawSockaddrInet6)(addr).Addr
+ addrs = append(addrs, IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]})
+ default:
+ return nil, os.NewSyscallError("LookupIP", syscall.EWINDOWS)
+ }
+ }
+ return addrs, nil
+}
+
+func getservbyname(network, service string) (port int, err error) {
switch network {
case "tcp4", "tcp6":
network = "tcp"
case "udp4", "udp6":
network = "udp"
}
- serventLock.Lock()
- defer serventLock.Unlock()
s, err := syscall.GetServByName(service, network)
if err != nil {
return 0, os.NewSyscallError("GetServByName", err)
@@ -77,6 +135,58 @@ func lookupPort(network, service string) (port int, err error) {
return int(syscall.Ntohs(s.Port)), nil
}
+func oldLookupPort(network, service string) (port int, err error) {
+ // GetServByName return value is stored in thread local storage.
+ // Start new os thread before the call to prevent races.
+ type result struct {
+ port int
+ err error
+ }
+ ch := make(chan result)
+ go func() {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+ port, err := getservbyname(network, service)
+ ch <- result{port: port, err: err}
+ }()
+ r := <-ch
+ return r.port, r.err
+}
+
+func newLookupPort(network, service string) (port int, err error) {
+ var stype int32
+ switch network {
+ case "tcp4", "tcp6":
+ stype = syscall.SOCK_STREAM
+ case "udp4", "udp6":
+ stype = syscall.SOCK_DGRAM
+ }
+ hints := syscall.AddrinfoW{
+ Family: syscall.AF_UNSPEC,
+ Socktype: stype,
+ Protocol: syscall.IPPROTO_IP,
+ }
+ var result *syscall.AddrinfoW
+ e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
+ if e != nil {
+ return 0, os.NewSyscallError("GetAddrInfoW", e)
+ }
+ defer syscall.FreeAddrInfoW(result)
+ if result == nil {
+ return 0, os.NewSyscallError("LookupPort", syscall.EINVAL)
+ }
+ addr := unsafe.Pointer(result.Addr)
+ switch result.Family {
+ case syscall.AF_INET:
+ a := (*syscall.RawSockaddrInet4)(addr)
+ return int(syscall.Ntohs(a.Port)), nil
+ case syscall.AF_INET6:
+ a := (*syscall.RawSockaddrInet6)(addr)
+ return int(syscall.Ntohs(a.Port)), nil
+ }
+ return 0, os.NewSyscallError("LookupPort", syscall.EINVAL)
+}
+
func lookupCNAME(name string) (cname string, err error) {
var r *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
@@ -129,6 +239,21 @@ func lookupMX(name string) (mx []*MX, err error) {
return mx, nil
}
+func lookupNS(name string) (ns []*NS, err error) {
+ var r *syscall.DNSRecord
+ e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
+ if e != nil {
+ return nil, os.NewSyscallError("LookupNS", e)
+ }
+ defer syscall.DnsRecordListFree(r, 1)
+ ns = make([]*NS, 0, 10)
+ for p := r; p != nil && p.Type == syscall.DNS_TYPE_NS; p = p.Next {
+ v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
+ ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."})
+ }
+ return ns, nil
+}
+
func lookupTXT(name string) (txt []string, err error) {
var r *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
diff --git a/src/pkg/net/mail/message.go b/src/pkg/net/mail/message.go
index b610ccf3f..96c796e78 100644
--- a/src/pkg/net/mail/message.go
+++ b/src/pkg/net/mail/message.go
@@ -47,7 +47,8 @@ type Message struct {
}
// ReadMessage reads a message from r.
-// The headers are parsed, and the body of the message will be reading from r.
+// The headers are parsed, and the body of the message will be available
+// for reading from r.
func ReadMessage(r io.Reader) (msg *Message, err error) {
tp := textproto.NewReader(bufio.NewReader(r))
@@ -126,7 +127,7 @@ func (h Header) AddressList(key string) ([]*Address, error) {
if hdr == "" {
return nil, ErrHeaderNotPresent
}
- return newAddrParser(hdr).parseAddressList()
+ return ParseAddressList(hdr)
}
// Address represents a single mail address.
@@ -137,6 +138,16 @@ type Address struct {
Address string // user@domain
}
+// Parses a single RFC 5322 address, e.g. "Barry Gibbs <bg@example.com>"
+func ParseAddress(address string) (*Address, error) {
+ return newAddrParser(address).parseAddress()
+}
+
+// ParseAddressList parses the given string as a list of addresses.
+func ParseAddressList(list string) ([]*Address, error) {
+ return newAddrParser(list).parseAddressList()
+}
+
// String formats the address as a valid RFC 5322 address.
// If the address's name contains non-ASCII characters
// the name will be rendered according to RFC 2047.
diff --git a/src/pkg/net/mail/message_test.go b/src/pkg/net/mail/message_test.go
index fd17eb414..2e746f4a7 100644
--- a/src/pkg/net/mail/message_test.go
+++ b/src/pkg/net/mail/message_test.go
@@ -227,13 +227,24 @@ func TestAddressParsing(t *testing.T) {
},
}
for _, test := range tests {
- addrs, err := newAddrParser(test.addrsStr).parseAddressList()
+ if len(test.exp) == 1 {
+ addr, err := ParseAddress(test.addrsStr)
+ if err != nil {
+ t.Errorf("Failed parsing (single) %q: %v", test.addrsStr, err)
+ continue
+ }
+ if !reflect.DeepEqual([]*Address{addr}, test.exp) {
+ t.Errorf("Parse (single) of %q: got %+v, want %+v", test.addrsStr, addr, test.exp)
+ }
+ }
+
+ addrs, err := ParseAddressList(test.addrsStr)
if err != nil {
- t.Errorf("Failed parsing %q: %v", test.addrsStr, err)
+ t.Errorf("Failed parsing (list) %q: %v", test.addrsStr, err)
continue
}
if !reflect.DeepEqual(addrs, test.exp) {
- t.Errorf("Parse of %q: got %+v, want %+v", test.addrsStr, addrs, test.exp)
+ t.Errorf("Parse (list) of %q: got %+v, want %+v", test.addrsStr, addrs, test.exp)
}
}
}
diff --git a/src/pkg/net/multicast_posix_test.go b/src/pkg/net/multicast_posix_test.go
new file mode 100644
index 000000000..ff1edaf83
--- /dev/null
+++ b/src/pkg/net/multicast_posix_test.go
@@ -0,0 +1,180 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !plan9
+
+package net
+
+import (
+ "errors"
+ "os"
+ "runtime"
+ "testing"
+)
+
+var multicastListenerTests = []struct {
+ net string
+ gaddr *UDPAddr
+ flags Flags
+ ipv6 bool // test with underlying AF_INET6 socket
+}{
+ // cf. RFC 4727: Experimental Values in IPv4, IPv6, ICMPv4, ICMPv6, UDP, and TCP Headers
+
+ {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, FlagUp | FlagLoopback, false},
+ {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, 0, false},
+ {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, FlagUp | FlagLoopback, true},
+ {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, 0, true},
+
+ {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, FlagUp | FlagLoopback, false},
+ {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, 0, false},
+
+ {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}, FlagUp | FlagLoopback, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}, 0, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}, FlagUp | FlagLoopback, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}, 0, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}, FlagUp | FlagLoopback, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}, 0, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}, FlagUp | FlagLoopback, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}, 0, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}, FlagUp | FlagLoopback, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}, 0, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, FlagUp | FlagLoopback, true},
+ {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, 0, true},
+}
+
+// TestMulticastListener tests both single and double listen to a test
+// listener with same address family, same group address and same port.
+func TestMulticastListener(t *testing.T) {
+ switch runtime.GOOS {
+ case "netbsd", "openbsd", "plan9", "solaris", "windows":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ case "linux":
+ if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" {
+ t.Skipf("skipping test on %q/%q", runtime.GOOS, runtime.GOARCH)
+ }
+ }
+
+ for _, tt := range multicastListenerTests {
+ if tt.ipv6 && (!*testIPv6 || !supportsIPv6 || os.Getuid() != 0) {
+ continue
+ }
+ ifi, err := availMulticastInterface(t, tt.flags)
+ if err != nil {
+ continue
+ }
+ c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr)
+ if err != nil {
+ t.Fatalf("First ListenMulticastUDP failed: %v", err)
+ }
+ checkMulticastListener(t, err, c1, tt.gaddr)
+ c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr)
+ if err != nil {
+ t.Fatalf("Second ListenMulticastUDP failed: %v", err)
+ }
+ checkMulticastListener(t, err, c2, tt.gaddr)
+ c2.Close()
+ c1.Close()
+ }
+}
+
+func TestSimpleMulticastListener(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ case "windows":
+ if testing.Short() || !*testExternal {
+ t.Skip("skipping test on windows to avoid firewall")
+ }
+ }
+
+ for _, tt := range multicastListenerTests {
+ if tt.ipv6 {
+ continue
+ }
+ tt.flags = FlagUp | FlagMulticast // for windows testing
+ ifi, err := availMulticastInterface(t, tt.flags)
+ if err != nil {
+ continue
+ }
+ c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr)
+ if err != nil {
+ t.Fatalf("First ListenMulticastUDP failed: %v", err)
+ }
+ checkSimpleMulticastListener(t, err, c1, tt.gaddr)
+ c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr)
+ if err != nil {
+ t.Fatalf("Second ListenMulticastUDP failed: %v", err)
+ }
+ checkSimpleMulticastListener(t, err, c2, tt.gaddr)
+ c2.Close()
+ c1.Close()
+ }
+}
+
+func checkMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) {
+ if !multicastRIBContains(t, gaddr.IP) {
+ t.Errorf("%q not found in RIB", gaddr.String())
+ return
+ }
+ la := c.LocalAddr()
+ if la == nil {
+ t.Error("LocalAddr failed")
+ return
+ }
+ if a, ok := la.(*UDPAddr); !ok || a.Port == 0 {
+ t.Errorf("got %v; expected a proper address with non-zero port number", la)
+ return
+ }
+}
+
+func checkSimpleMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) {
+ la := c.LocalAddr()
+ if la == nil {
+ t.Error("LocalAddr failed")
+ return
+ }
+ if a, ok := la.(*UDPAddr); !ok || a.Port == 0 {
+ t.Errorf("got %v; expected a proper address with non-zero port number", la)
+ return
+ }
+}
+
+func availMulticastInterface(t *testing.T, flags Flags) (*Interface, error) {
+ var ifi *Interface
+ if flags != Flags(0) {
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatalf("Interfaces failed: %v", err)
+ }
+ for _, x := range ift {
+ if x.Flags&flags == flags {
+ ifi = &x
+ break
+ }
+ }
+ if ifi == nil {
+ return nil, errors.New("an appropriate multicast interface not found")
+ }
+ }
+ return ifi, nil
+}
+
+func multicastRIBContains(t *testing.T, ip IP) bool {
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatalf("Interfaces failed: %v", err)
+ }
+ for _, ifi := range ift {
+ ifmat, err := ifi.MulticastAddrs()
+ if err != nil {
+ t.Fatalf("MulticastAddrs failed: %v", err)
+ }
+ for _, ifma := range ifmat {
+ if ifma.(*IPAddr).IP.Equal(ip) {
+ return true
+ }
+ }
+ }
+ return false
+}
diff --git a/src/pkg/net/multicast_test.go b/src/pkg/net/multicast_test.go
deleted file mode 100644
index 67261b1ee..000000000
--- a/src/pkg/net/multicast_test.go
+++ /dev/null
@@ -1,234 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package net
-
-import (
- "errors"
- "os"
- "runtime"
- "syscall"
- "testing"
-)
-
-var multicastListenerTests = []struct {
- net string
- gaddr *UDPAddr
- flags Flags
- ipv6 bool // test with underlying AF_INET6 socket
-}{
- // cf. RFC 4727: Experimental Values in IPv4, IPv6, ICMPv4, ICMPv6, UDP, and TCP Headers
-
- {"udp", &UDPAddr{IPv4(224, 0, 0, 254), 12345}, FlagUp | FlagLoopback, false},
- {"udp", &UDPAddr{IPv4(224, 0, 0, 254), 12345}, 0, false},
- {"udp", &UDPAddr{ParseIP("ff0e::114"), 12345}, FlagUp | FlagLoopback, true},
- {"udp", &UDPAddr{ParseIP("ff0e::114"), 12345}, 0, true},
-
- {"udp4", &UDPAddr{IPv4(224, 0, 0, 254), 12345}, FlagUp | FlagLoopback, false},
- {"udp4", &UDPAddr{IPv4(224, 0, 0, 254), 12345}, 0, false},
-
- {"udp6", &UDPAddr{ParseIP("ff01::114"), 12345}, FlagUp | FlagLoopback, true},
- {"udp6", &UDPAddr{ParseIP("ff01::114"), 12345}, 0, true},
- {"udp6", &UDPAddr{ParseIP("ff02::114"), 12345}, FlagUp | FlagLoopback, true},
- {"udp6", &UDPAddr{ParseIP("ff02::114"), 12345}, 0, true},
- {"udp6", &UDPAddr{ParseIP("ff04::114"), 12345}, FlagUp | FlagLoopback, true},
- {"udp6", &UDPAddr{ParseIP("ff04::114"), 12345}, 0, true},
- {"udp6", &UDPAddr{ParseIP("ff05::114"), 12345}, FlagUp | FlagLoopback, true},
- {"udp6", &UDPAddr{ParseIP("ff05::114"), 12345}, 0, true},
- {"udp6", &UDPAddr{ParseIP("ff08::114"), 12345}, FlagUp | FlagLoopback, true},
- {"udp6", &UDPAddr{ParseIP("ff08::114"), 12345}, 0, true},
- {"udp6", &UDPAddr{ParseIP("ff0e::114"), 12345}, FlagUp | FlagLoopback, true},
- {"udp6", &UDPAddr{ParseIP("ff0e::114"), 12345}, 0, true},
-}
-
-// TestMulticastListener tests both single and double listen to a test
-// listener with same address family, same group address and same port.
-func TestMulticastListener(t *testing.T) {
- switch runtime.GOOS {
- case "netbsd", "openbsd", "plan9", "windows":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
- case "linux":
- if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" {
- t.Logf("skipping test on %q/%q", runtime.GOOS, runtime.GOARCH)
- return
- }
- }
-
- for _, tt := range multicastListenerTests {
- if tt.ipv6 && (!supportsIPv6 || os.Getuid() != 0) {
- continue
- }
- ifi, err := availMulticastInterface(t, tt.flags)
- if err != nil {
- continue
- }
- c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr)
- if err != nil {
- t.Fatalf("First ListenMulticastUDP failed: %v", err)
- }
- checkMulticastListener(t, err, c1, tt.gaddr)
- c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr)
- if err != nil {
- t.Fatalf("Second ListenMulticastUDP failed: %v", err)
- }
- checkMulticastListener(t, err, c2, tt.gaddr)
- c2.Close()
- switch c1.fd.family {
- case syscall.AF_INET:
- testIPv4MulticastSocketOptions(t, c1.fd, ifi)
- case syscall.AF_INET6:
- testIPv6MulticastSocketOptions(t, c1.fd, ifi)
- }
- c1.Close()
- }
-}
-
-func TestSimpleMulticastListener(t *testing.T) {
- switch runtime.GOOS {
- case "plan9":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
- case "windows":
- if testing.Short() || !*testExternal {
- t.Logf("skipping test on windows to avoid firewall")
- return
- }
- }
-
- for _, tt := range multicastListenerTests {
- if tt.ipv6 {
- continue
- }
- tt.flags = FlagUp | FlagMulticast // for windows testing
- ifi, err := availMulticastInterface(t, tt.flags)
- if err != nil {
- continue
- }
- c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr)
- if err != nil {
- t.Fatalf("First ListenMulticastUDP failed: %v", err)
- }
- checkSimpleMulticastListener(t, err, c1, tt.gaddr)
- c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr)
- if err != nil {
- t.Fatalf("Second ListenMulticastUDP failed: %v", err)
- }
- checkSimpleMulticastListener(t, err, c2, tt.gaddr)
- c2.Close()
- c1.Close()
- }
-}
-
-func checkMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) {
- if !multicastRIBContains(t, gaddr.IP) {
- t.Fatalf("%q not found in RIB", gaddr.String())
- }
- if c.LocalAddr().String() != gaddr.String() {
- t.Fatalf("LocalAddr returns %q, expected %q", c.LocalAddr().String(), gaddr.String())
- }
-}
-
-func checkSimpleMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) {
- if c.LocalAddr().String() != gaddr.String() {
- t.Fatalf("LocalAddr returns %q, expected %q", c.LocalAddr().String(), gaddr.String())
- }
-}
-
-func availMulticastInterface(t *testing.T, flags Flags) (*Interface, error) {
- var ifi *Interface
- if flags != Flags(0) {
- ift, err := Interfaces()
- if err != nil {
- t.Fatalf("Interfaces failed: %v", err)
- }
- for _, x := range ift {
- if x.Flags&flags == flags {
- ifi = &x
- break
- }
- }
- if ifi == nil {
- return nil, errors.New("an appropriate multicast interface not found")
- }
- }
- return ifi, nil
-}
-
-func multicastRIBContains(t *testing.T, ip IP) bool {
- ift, err := Interfaces()
- if err != nil {
- t.Fatalf("Interfaces failed: %v", err)
- }
- for _, ifi := range ift {
- ifmat, err := ifi.MulticastAddrs()
- if err != nil {
- t.Fatalf("MulticastAddrs failed: %v", err)
- }
- for _, ifma := range ifmat {
- if ifma.(*IPAddr).IP.Equal(ip) {
- return true
- }
- }
- }
- return false
-}
-
-func testIPv4MulticastSocketOptions(t *testing.T, fd *netFD, ifi *Interface) {
- _, err := ipv4MulticastInterface(fd)
- if err != nil {
- t.Fatalf("ipv4MulticastInterface failed: %v", err)
- }
- if ifi != nil {
- err = setIPv4MulticastInterface(fd, ifi)
- if err != nil {
- t.Fatalf("setIPv4MulticastInterface failed: %v", err)
- }
- }
- _, err = ipv4MulticastTTL(fd)
- if err != nil {
- t.Fatalf("ipv4MulticastTTL failed: %v", err)
- }
- err = setIPv4MulticastTTL(fd, 1)
- if err != nil {
- t.Fatalf("setIPv4MulticastTTL failed: %v", err)
- }
- _, err = ipv4MulticastLoopback(fd)
- if err != nil {
- t.Fatalf("ipv4MulticastLoopback failed: %v", err)
- }
- err = setIPv4MulticastLoopback(fd, false)
- if err != nil {
- t.Fatalf("setIPv4MulticastLoopback failed: %v", err)
- }
-}
-
-func testIPv6MulticastSocketOptions(t *testing.T, fd *netFD, ifi *Interface) {
- _, err := ipv6MulticastInterface(fd)
- if err != nil {
- t.Fatalf("ipv6MulticastInterface failed: %v", err)
- }
- if ifi != nil {
- err = setIPv6MulticastInterface(fd, ifi)
- if err != nil {
- t.Fatalf("setIPv6MulticastInterface failed: %v", err)
- }
- }
- _, err = ipv6MulticastHopLimit(fd)
- if err != nil {
- t.Fatalf("ipv6MulticastHopLimit failed: %v", err)
- }
- err = setIPv6MulticastHopLimit(fd, 1)
- if err != nil {
- t.Fatalf("setIPv6MulticastHopLimit failed: %v", err)
- }
- _, err = ipv6MulticastLoopback(fd)
- if err != nil {
- t.Fatalf("ipv6MulticastLoopback failed: %v", err)
- }
- err = setIPv6MulticastLoopback(fd, false)
- if err != nil {
- t.Fatalf("setIPv6MulticastLoopback failed: %v", err)
- }
-}
diff --git a/src/pkg/net/net.go b/src/pkg/net/net.go
index 9ebcdbe99..72b2b646c 100644
--- a/src/pkg/net/net.go
+++ b/src/pkg/net/net.go
@@ -44,6 +44,10 @@ package net
import (
"errors"
+ "io"
+ "os"
+ "sync"
+ "syscall"
"time"
)
@@ -103,6 +107,105 @@ type Conn interface {
SetWriteDeadline(t time.Time) error
}
+type conn struct {
+ fd *netFD
+}
+
+func (c *conn) ok() bool { return c != nil && c.fd != nil }
+
+// Implementation of the Conn interface.
+
+// Read implements the Conn Read method.
+func (c *conn) Read(b []byte) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ return c.fd.Read(b)
+}
+
+// Write implements the Conn Write method.
+func (c *conn) Write(b []byte) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ return c.fd.Write(b)
+}
+
+// Close closes the connection.
+func (c *conn) Close() error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ return c.fd.Close()
+}
+
+// LocalAddr returns the local network address.
+func (c *conn) LocalAddr() Addr {
+ if !c.ok() {
+ return nil
+ }
+ return c.fd.laddr
+}
+
+// RemoteAddr returns the remote network address.
+func (c *conn) RemoteAddr() Addr {
+ if !c.ok() {
+ return nil
+ }
+ return c.fd.raddr
+}
+
+// SetDeadline implements the Conn SetDeadline method.
+func (c *conn) SetDeadline(t time.Time) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ return setDeadline(c.fd, t)
+}
+
+// SetReadDeadline implements the Conn SetReadDeadline method.
+func (c *conn) SetReadDeadline(t time.Time) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ return setReadDeadline(c.fd, t)
+}
+
+// SetWriteDeadline implements the Conn SetWriteDeadline method.
+func (c *conn) SetWriteDeadline(t time.Time) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ return setWriteDeadline(c.fd, t)
+}
+
+// SetReadBuffer sets the size of the operating system's
+// receive buffer associated with the connection.
+func (c *conn) SetReadBuffer(bytes int) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ return setReadBuffer(c.fd, bytes)
+}
+
+// SetWriteBuffer sets the size of the operating system's
+// transmit buffer associated with the connection.
+func (c *conn) SetWriteBuffer(bytes int) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ return setWriteBuffer(c.fd, bytes)
+}
+
+// File sets the underlying os.File to blocking mode and returns a copy.
+// It is the caller's responsibility to close f when finished.
+// Closing c does not affect f, and closing f does not affect c.
+//
+// The returned os.File's file descriptor is different from the connection's.
+// Attempting to change properties of the original using this duplicate
+// may or may not have the desired effect.
+func (c *conn) File() (f *os.File, err error) { return c.fd.dup() }
+
// An Error represents a network error.
type Error interface {
error
@@ -173,11 +276,23 @@ type Listener interface {
var errMissingAddress = errors.New("missing address")
+// OpError is the error type usually returned by functions in the net
+// package. It describes the operation, network type, and address of
+// an error.
type OpError struct {
- Op string
- Net string
+ // Op is the operation which caused the error, such as
+ // "read" or "write".
+ Op string
+
+ // Net is the network type on which this error occurred,
+ // such as "tcp" or "udp6".
+ Net string
+
+ // Addr is the network address on which this error occurred.
Addr Addr
- Err error
+
+ // Err is the error that occurred during the operation.
+ Err error
}
func (e *OpError) Error() string {
@@ -204,6 +319,8 @@ func (e *OpError) Temporary() bool {
return ok && t.Temporary()
}
+var noDeadline = time.Time{}
+
type timeout interface {
Timeout() bool
}
@@ -221,6 +338,8 @@ func (e *timeoutError) Temporary() bool { return true }
var errTimeout error = &timeoutError{}
+var errClosing = errors.New("use of closed network connection")
+
type AddrError struct {
Err string
Addr string
@@ -262,3 +381,47 @@ func (e *DNSConfigError) Error() string {
func (e *DNSConfigError) Timeout() bool { return false }
func (e *DNSConfigError) Temporary() bool { return false }
+
+type writerOnly struct {
+ io.Writer
+}
+
+// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't
+// applicable.
+func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
+ // Use wrapper to hide existing r.ReadFrom from io.Copy.
+ return io.Copy(writerOnly{w}, r)
+}
+
+// deadline is an atomically-accessed number of nanoseconds since 1970
+// or 0, if no deadline is set.
+type deadline struct {
+ sync.Mutex
+ val int64
+}
+
+func (d *deadline) expired() bool {
+ t := d.value()
+ return t > 0 && time.Now().UnixNano() >= t
+}
+
+func (d *deadline) value() (v int64) {
+ d.Lock()
+ v = d.val
+ d.Unlock()
+ return
+}
+
+func (d *deadline) set(v int64) {
+ d.Lock()
+ d.val = v
+ d.Unlock()
+}
+
+func (d *deadline) setTime(t time.Time) {
+ if t.IsZero() {
+ d.set(0)
+ } else {
+ d.set(t.UnixNano())
+ }
+}
diff --git a/src/pkg/net/net_test.go b/src/pkg/net/net_test.go
index fd145e1d7..1a512a5b1 100644
--- a/src/pkg/net/net_test.go
+++ b/src/pkg/net/net_test.go
@@ -6,6 +6,8 @@ package net
import (
"io"
+ "io/ioutil"
+ "os"
"runtime"
"testing"
"time"
@@ -13,18 +15,17 @@ import (
func TestShutdown(t *testing.T) {
if runtime.GOOS == "plan9" {
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
- l, err := Listen("tcp", "127.0.0.1:0")
+ ln, err := Listen("tcp", "127.0.0.1:0")
if err != nil {
- if l, err = Listen("tcp6", "[::1]:0"); err != nil {
+ if ln, err = Listen("tcp6", "[::1]:0"); err != nil {
t.Fatalf("ListenTCP on :0: %v", err)
}
}
go func() {
- c, err := l.Accept()
+ c, err := ln.Accept()
if err != nil {
t.Fatalf("Accept: %v", err)
}
@@ -37,7 +38,7 @@ func TestShutdown(t *testing.T) {
c.Close()
}()
- c, err := Dial("tcp", l.Addr().String())
+ c, err := Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
@@ -58,8 +59,61 @@ func TestShutdown(t *testing.T) {
}
}
+func TestShutdownUnix(t *testing.T) {
+ switch runtime.GOOS {
+ case "windows", "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+ f, err := ioutil.TempFile("", "go_net_unixtest")
+ if err != nil {
+ t.Fatalf("TempFile: %s", err)
+ }
+ f.Close()
+ tmpname := f.Name()
+ os.Remove(tmpname)
+ ln, err := Listen("unix", tmpname)
+ if err != nil {
+ t.Fatalf("ListenUnix on %s: %s", tmpname, err)
+ }
+ defer os.Remove(tmpname)
+
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("Accept: %v", err)
+ }
+ var buf [10]byte
+ n, err := c.Read(buf[:])
+ if n != 0 || err != io.EOF {
+ t.Fatalf("server Read = %d, %v; want 0, io.EOF", n, err)
+ }
+ c.Write([]byte("response"))
+ c.Close()
+ }()
+
+ c, err := Dial("unix", tmpname)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer c.Close()
+
+ err = c.(*UnixConn).CloseWrite()
+ if err != nil {
+ t.Fatalf("CloseWrite: %v", err)
+ }
+ var buf [10]byte
+ n, err := c.Read(buf[:])
+ if err != nil {
+ t.Fatalf("client Read: %d, %v", n, err)
+ }
+ got := string(buf[:n])
+ if got != "response" {
+ t.Errorf("read = %q, want \"response\"", got)
+ }
+}
+
func TestTCPListenClose(t *testing.T) {
- l, err := Listen("tcp", "127.0.0.1:0")
+ ln, err := Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen failed: %v", err)
}
@@ -67,11 +121,12 @@ func TestTCPListenClose(t *testing.T) {
done := make(chan bool, 1)
go func() {
time.Sleep(100 * time.Millisecond)
- l.Close()
+ ln.Close()
}()
go func() {
- _, err = l.Accept()
+ c, err := ln.Accept()
if err == nil {
+ c.Close()
t.Error("Accept succeeded")
} else {
t.Logf("Accept timeout error: %s (any error is fine)", err)
@@ -86,7 +141,11 @@ func TestTCPListenClose(t *testing.T) {
}
func TestUDPListenClose(t *testing.T) {
- l, err := ListenPacket("udp", "127.0.0.1:0")
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+ ln, err := ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen failed: %v", err)
}
@@ -95,10 +154,10 @@ func TestUDPListenClose(t *testing.T) {
done := make(chan bool, 1)
go func() {
time.Sleep(100 * time.Millisecond)
- l.Close()
+ ln.Close()
}()
go func() {
- _, _, err = l.ReadFrom(buf)
+ _, _, err = ln.ReadFrom(buf)
if err == nil {
t.Error("ReadFrom succeeded")
} else {
@@ -112,3 +171,46 @@ func TestUDPListenClose(t *testing.T) {
t.Fatal("timeout waiting for UDP close")
}
}
+
+func TestTCPClose(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+ l, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ read := func(r io.Reader) error {
+ var m [1]byte
+ _, err := r.Read(m[:])
+ return err
+ }
+
+ go func() {
+ c, err := Dial("tcp", l.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ go read(c)
+
+ time.Sleep(10 * time.Millisecond)
+ c.Close()
+ }()
+
+ c, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ for err == nil {
+ err = read(c)
+ }
+ if err != nil && err != io.EOF {
+ t.Fatal(err)
+ }
+}
diff --git a/src/pkg/net/newpollserver.go b/src/pkg/net/newpollserver_unix.go
index d34bb511f..618b5b10b 100644
--- a/src/pkg/net/newpollserver.go
+++ b/src/pkg/net/newpollserver_unix.go
@@ -13,8 +13,6 @@ import (
func newPollServer() (s *pollServer, err error) {
s = new(pollServer)
- s.cr = make(chan *netFD, 1)
- s.cw = make(chan *netFD, 1)
if s.pr, s.pw, err = os.Pipe(); err != nil {
return nil, err
}
diff --git a/src/pkg/net/packetconn_test.go b/src/pkg/net/packetconn_test.go
new file mode 100644
index 000000000..93c7a6472
--- /dev/null
+++ b/src/pkg/net/packetconn_test.go
@@ -0,0 +1,200 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements API tests across platforms and will never have a build
+// tag.
+
+package net
+
+import (
+ "os"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+)
+
+var packetConnTests = []struct {
+ net string
+ addr1 string
+ addr2 string
+}{
+ {"udp", "127.0.0.1:0", "127.0.0.1:0"},
+ {"ip:icmp", "127.0.0.1", "127.0.0.1"},
+ {"unixgram", testUnixAddr(), testUnixAddr()},
+}
+
+func TestPacketConn(t *testing.T) {
+ closer := func(c PacketConn, net, addr1, addr2 string) {
+ c.Close()
+ switch net {
+ case "unixgram":
+ os.Remove(addr1)
+ os.Remove(addr2)
+ }
+ }
+
+ for i, tt := range packetConnTests {
+ var wb []byte
+ netstr := strings.Split(tt.net, ":")
+ switch netstr[0] {
+ case "udp":
+ wb = []byte("UDP PACKETCONN TEST")
+ case "ip":
+ switch runtime.GOOS {
+ case "plan9":
+ continue
+ }
+ if os.Getuid() != 0 {
+ continue
+ }
+ var err error
+ wb, err = (&icmpMessage{
+ Type: icmpv4EchoRequest, Code: 0,
+ Body: &icmpEcho{
+ ID: os.Getpid() & 0xffff, Seq: i + 1,
+ Data: []byte("IP PACKETCONN TEST"),
+ },
+ }).Marshal()
+ if err != nil {
+ t.Fatalf("icmpMessage.Marshal failed: %v", err)
+ }
+ case "unixgram":
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ continue
+ }
+ wb = []byte("UNIXGRAM PACKETCONN TEST")
+ default:
+ continue
+ }
+
+ c1, err := ListenPacket(tt.net, tt.addr1)
+ if err != nil {
+ t.Fatalf("ListenPacket failed: %v", err)
+ }
+ defer closer(c1, netstr[0], tt.addr1, tt.addr2)
+ c1.LocalAddr()
+ c1.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ c1.SetWriteDeadline(time.Now().Add(100 * time.Millisecond))
+
+ c2, err := ListenPacket(tt.net, tt.addr2)
+ if err != nil {
+ t.Fatalf("ListenPacket failed: %v", err)
+ }
+ defer closer(c2, netstr[0], tt.addr1, tt.addr2)
+ c2.LocalAddr()
+ c2.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ c2.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ c2.SetWriteDeadline(time.Now().Add(100 * time.Millisecond))
+
+ if _, err := c1.WriteTo(wb, c2.LocalAddr()); err != nil {
+ t.Fatalf("PacketConn.WriteTo failed: %v", err)
+ }
+ rb2 := make([]byte, 128)
+ if _, _, err := c2.ReadFrom(rb2); err != nil {
+ t.Fatalf("PacketConn.ReadFrom failed: %v", err)
+ }
+ if _, err := c2.WriteTo(wb, c1.LocalAddr()); err != nil {
+ t.Fatalf("PacketConn.WriteTo failed: %v", err)
+ }
+ rb1 := make([]byte, 128)
+ if _, _, err := c1.ReadFrom(rb1); err != nil {
+ t.Fatalf("PacketConn.ReadFrom failed: %v", err)
+ }
+ }
+}
+
+func TestConnAndPacketConn(t *testing.T) {
+ closer := func(c PacketConn, net, addr1, addr2 string) {
+ c.Close()
+ switch net {
+ case "unixgram":
+ os.Remove(addr1)
+ os.Remove(addr2)
+ }
+ }
+
+ for i, tt := range packetConnTests {
+ var wb []byte
+ netstr := strings.Split(tt.net, ":")
+ switch netstr[0] {
+ case "udp":
+ wb = []byte("UDP PACKETCONN TEST")
+ case "ip":
+ switch runtime.GOOS {
+ case "plan9":
+ continue
+ }
+ if os.Getuid() != 0 {
+ continue
+ }
+ var err error
+ wb, err = (&icmpMessage{
+ Type: icmpv4EchoRequest, Code: 0,
+ Body: &icmpEcho{
+ ID: os.Getpid() & 0xffff, Seq: i + 1,
+ Data: []byte("IP PACKETCONN TEST"),
+ },
+ }).Marshal()
+ if err != nil {
+ t.Fatalf("icmpMessage.Marshal failed: %v", err)
+ }
+ case "unixgram":
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ continue
+ }
+ wb = []byte("UNIXGRAM PACKETCONN TEST")
+ default:
+ continue
+ }
+
+ c1, err := ListenPacket(tt.net, tt.addr1)
+ if err != nil {
+ t.Fatalf("ListenPacket failed: %v", err)
+ }
+ defer closer(c1, netstr[0], tt.addr1, tt.addr2)
+ c1.LocalAddr()
+ c1.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ c1.SetWriteDeadline(time.Now().Add(100 * time.Millisecond))
+
+ c2, err := Dial(tt.net, c1.LocalAddr().String())
+ if err != nil {
+ t.Fatalf("Dial failed: %v", err)
+ }
+ defer c2.Close()
+ c2.LocalAddr()
+ c2.RemoteAddr()
+ c2.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ c2.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ c2.SetWriteDeadline(time.Now().Add(100 * time.Millisecond))
+
+ if _, err := c2.Write(wb); err != nil {
+ t.Fatalf("Conn.Write failed: %v", err)
+ }
+ rb1 := make([]byte, 128)
+ if _, _, err := c1.ReadFrom(rb1); err != nil {
+ t.Fatalf("PacetConn.ReadFrom failed: %v", err)
+ }
+ var dst Addr
+ switch netstr[0] {
+ case "ip":
+ dst = &IPAddr{IP: IPv4(127, 0, 0, 1)}
+ case "unixgram":
+ continue
+ default:
+ dst = c2.LocalAddr()
+ }
+ if _, err := c1.WriteTo(wb, dst); err != nil {
+ t.Fatalf("PacketConn.WriteTo failed: %v", err)
+ }
+ rb2 := make([]byte, 128)
+ if _, err := c2.Read(rb2); err != nil {
+ t.Fatalf("Conn.Read failed: %v", err)
+ }
+ }
+}
diff --git a/src/pkg/net/parse_test.go b/src/pkg/net/parse_test.go
index 30fda45df..9df0c534b 100644
--- a/src/pkg/net/parse_test.go
+++ b/src/pkg/net/parse_test.go
@@ -15,8 +15,7 @@ func TestReadLine(t *testing.T) {
// /etc/services file does not exist on windows and Plan 9.
switch runtime.GOOS {
case "plan9", "windows":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
filename := "/etc/services" // a nice big file
diff --git a/src/pkg/net/port.go b/src/pkg/net/port.go
index 16780da11..c24f4ed5b 100644
--- a/src/pkg/net/port.go
+++ b/src/pkg/net/port.go
@@ -1,69 +1,24 @@
-// Copyright 2009 The Go Authors. All rights reserved.
+// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux netbsd openbsd
-
-// Read system port mappings from /etc/services
+// Network service port manipulations
package net
-import "sync"
-
-var services map[string]map[string]int
-var servicesError error
-var onceReadServices sync.Once
-
-func readServices() {
- services = make(map[string]map[string]int)
- var file *file
- if file, servicesError = open("/etc/services"); servicesError != nil {
- return
- }
- for line, ok := file.readLine(); ok; line, ok = file.readLine() {
- // "http 80/tcp www www-http # World Wide Web HTTP"
- if i := byteIndex(line, '#'); i >= 0 {
- line = line[0:i]
- }
- f := getFields(line)
- if len(f) < 2 {
- continue
- }
- portnet := f[1] // "tcp/80"
- port, j, ok := dtoi(portnet, 0)
- if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' {
- continue
- }
- netw := portnet[j+1:] // "tcp"
- m, ok1 := services[netw]
- if !ok1 {
- m = make(map[string]int)
- services[netw] = m
- }
- for i := 0; i < len(f); i++ {
- if i != 1 { // f[1] was port/net
- m[f[i]] = port
- }
+// parsePort parses port as a network service port number for both
+// TCP and UDP.
+func parsePort(net, port string) (int, error) {
+ p, i, ok := dtoi(port, 0)
+ if !ok || i != len(port) {
+ var err error
+ p, err = LookupPort(net, port)
+ if err != nil {
+ return 0, err
}
}
- file.close()
-}
-
-// goLookupPort is the native Go implementation of LookupPort.
-func goLookupPort(network, service string) (port int, err error) {
- onceReadServices.Do(readServices)
-
- switch network {
- case "tcp4", "tcp6":
- network = "tcp"
- case "udp4", "udp6":
- network = "udp"
- }
-
- if m, ok := services[network]; ok {
- if port, ok = m[service]; ok {
- return
- }
+ if p < 0 || p > 0xFFFF {
+ return 0, &AddrError{"invalid port", port}
}
- return 0, &AddrError{"unknown port", network + "/" + service}
+ return p, nil
}
diff --git a/src/pkg/net/port_test.go b/src/pkg/net/port_test.go
index 329b169f3..9e8968f35 100644
--- a/src/pkg/net/port_test.go
+++ b/src/pkg/net/port_test.go
@@ -46,7 +46,7 @@ func TestLookupPort(t *testing.T) {
for i := 0; i < len(porttests); i++ {
tt := porttests[i]
if port, err := LookupPort(tt.netw, tt.name); port != tt.port || (err == nil) != tt.ok {
- t.Errorf("LookupPort(%q, %q) = %v, %s; want %v",
+ t.Errorf("LookupPort(%q, %q) = %v, %v; want %v",
tt.netw, tt.name, port, err, tt.port)
}
}
diff --git a/src/pkg/net/port_unix.go b/src/pkg/net/port_unix.go
new file mode 100644
index 000000000..16780da11
--- /dev/null
+++ b/src/pkg/net/port_unix.go
@@ -0,0 +1,69 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd
+
+// Read system port mappings from /etc/services
+
+package net
+
+import "sync"
+
+var services map[string]map[string]int
+var servicesError error
+var onceReadServices sync.Once
+
+func readServices() {
+ services = make(map[string]map[string]int)
+ var file *file
+ if file, servicesError = open("/etc/services"); servicesError != nil {
+ return
+ }
+ for line, ok := file.readLine(); ok; line, ok = file.readLine() {
+ // "http 80/tcp www www-http # World Wide Web HTTP"
+ if i := byteIndex(line, '#'); i >= 0 {
+ line = line[0:i]
+ }
+ f := getFields(line)
+ if len(f) < 2 {
+ continue
+ }
+ portnet := f[1] // "tcp/80"
+ port, j, ok := dtoi(portnet, 0)
+ if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' {
+ continue
+ }
+ netw := portnet[j+1:] // "tcp"
+ m, ok1 := services[netw]
+ if !ok1 {
+ m = make(map[string]int)
+ services[netw] = m
+ }
+ for i := 0; i < len(f); i++ {
+ if i != 1 { // f[1] was port/net
+ m[f[i]] = port
+ }
+ }
+ }
+ file.close()
+}
+
+// goLookupPort is the native Go implementation of LookupPort.
+func goLookupPort(network, service string) (port int, err error) {
+ onceReadServices.Do(readServices)
+
+ switch network {
+ case "tcp4", "tcp6":
+ network = "tcp"
+ case "udp4", "udp6":
+ network = "udp"
+ }
+
+ if m, ok := services[network]; ok {
+ if port, ok = m[service]; ok {
+ return
+ }
+ }
+ return 0, &AddrError{"unknown port", network + "/" + service}
+}
diff --git a/src/pkg/net/protoconn_test.go b/src/pkg/net/protoconn_test.go
new file mode 100644
index 000000000..2fe7d1d1f
--- /dev/null
+++ b/src/pkg/net/protoconn_test.go
@@ -0,0 +1,358 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements API tests across platforms and will never have a build
+// tag.
+
+package net
+
+import (
+ "io/ioutil"
+ "os"
+ "runtime"
+ "testing"
+ "time"
+)
+
+// testUnixAddr uses ioutil.TempFile to get a name that is unique.
+func testUnixAddr() string {
+ f, err := ioutil.TempFile("", "nettest")
+ if err != nil {
+ panic(err)
+ }
+ addr := f.Name()
+ f.Close()
+ os.Remove(addr)
+ return addr
+}
+
+var condFatalf = func() func(*testing.T, string, ...interface{}) {
+ // A few APIs are not implemented yet on both Plan 9 and Windows.
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ return (*testing.T).Logf
+ }
+ return (*testing.T).Fatalf
+}()
+
+func TestTCPListenerSpecificMethods(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ la, err := ResolveTCPAddr("tcp4", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("ResolveTCPAddr failed: %v", err)
+ }
+ ln, err := ListenTCP("tcp4", la)
+ if err != nil {
+ t.Fatalf("ListenTCP failed: %v", err)
+ }
+ defer ln.Close()
+ ln.Addr()
+ ln.SetDeadline(time.Now().Add(30 * time.Nanosecond))
+
+ if c, err := ln.Accept(); err != nil {
+ if !err.(Error).Timeout() {
+ t.Fatalf("TCPListener.Accept failed: %v", err)
+ }
+ } else {
+ c.Close()
+ }
+ if c, err := ln.AcceptTCP(); err != nil {
+ if !err.(Error).Timeout() {
+ t.Fatalf("TCPListener.AcceptTCP failed: %v", err)
+ }
+ } else {
+ c.Close()
+ }
+
+ if f, err := ln.File(); err != nil {
+ condFatalf(t, "TCPListener.File failed: %v", err)
+ } else {
+ f.Close()
+ }
+}
+
+func TestTCPConnSpecificMethods(t *testing.T) {
+ la, err := ResolveTCPAddr("tcp4", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("ResolveTCPAddr failed: %v", err)
+ }
+ ln, err := ListenTCP("tcp4", la)
+ if err != nil {
+ t.Fatalf("ListenTCP failed: %v", err)
+ }
+ defer ln.Close()
+ ln.Addr()
+
+ done := make(chan int)
+ go transponder(t, ln, done)
+
+ ra, err := ResolveTCPAddr("tcp4", ln.Addr().String())
+ if err != nil {
+ t.Fatalf("ResolveTCPAddr failed: %v", err)
+ }
+ c, err := DialTCP("tcp4", nil, ra)
+ if err != nil {
+ t.Fatalf("DialTCP failed: %v", err)
+ }
+ defer c.Close()
+ c.SetKeepAlive(false)
+ c.SetLinger(0)
+ c.SetNoDelay(false)
+ c.LocalAddr()
+ c.RemoteAddr()
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+ if _, err := c.Write([]byte("TCPCONN TEST")); err != nil {
+ t.Fatalf("TCPConn.Write failed: %v", err)
+ }
+ rb := make([]byte, 128)
+ if _, err := c.Read(rb); err != nil {
+ t.Fatalf("TCPConn.Read failed: %v", err)
+ }
+
+ <-done
+}
+
+func TestUDPConnSpecificMethods(t *testing.T) {
+ la, err := ResolveUDPAddr("udp4", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("ResolveUDPAddr failed: %v", err)
+ }
+ c, err := ListenUDP("udp4", la)
+ if err != nil {
+ t.Fatalf("ListenUDP failed: %v", err)
+ }
+ defer c.Close()
+ c.LocalAddr()
+ c.RemoteAddr()
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+ c.SetReadBuffer(2048)
+ c.SetWriteBuffer(2048)
+
+ wb := []byte("UDPCONN TEST")
+ rb := make([]byte, 128)
+ if _, err := c.WriteToUDP(wb, c.LocalAddr().(*UDPAddr)); err != nil {
+ t.Fatalf("UDPConn.WriteToUDP failed: %v", err)
+ }
+ if _, _, err := c.ReadFromUDP(rb); err != nil {
+ t.Fatalf("UDPConn.ReadFromUDP failed: %v", err)
+ }
+ if _, _, err := c.WriteMsgUDP(wb, nil, c.LocalAddr().(*UDPAddr)); err != nil {
+ condFatalf(t, "UDPConn.WriteMsgUDP failed: %v", err)
+ }
+ if _, _, _, _, err := c.ReadMsgUDP(rb, nil); err != nil {
+ condFatalf(t, "UDPConn.ReadMsgUDP failed: %v", err)
+ }
+
+ if f, err := c.File(); err != nil {
+ condFatalf(t, "UDPConn.File failed: %v", err)
+ } else {
+ f.Close()
+ }
+}
+
+func TestIPConnSpecificMethods(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping read test on %q", runtime.GOOS)
+ }
+ if os.Getuid() != 0 {
+ t.Skipf("skipping test; must be root")
+ }
+
+ la, err := ResolveIPAddr("ip4", "127.0.0.1")
+ if err != nil {
+ t.Fatalf("ResolveIPAddr failed: %v", err)
+ }
+ c, err := ListenIP("ip4:icmp", la)
+ if err != nil {
+ t.Fatalf("ListenIP failed: %v", err)
+ }
+ defer c.Close()
+ c.LocalAddr()
+ c.RemoteAddr()
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+ c.SetReadBuffer(2048)
+ c.SetWriteBuffer(2048)
+
+ wb, err := (&icmpMessage{
+ Type: icmpv4EchoRequest, Code: 0,
+ Body: &icmpEcho{
+ ID: os.Getpid() & 0xffff, Seq: 1,
+ Data: []byte("IPCONN TEST "),
+ },
+ }).Marshal()
+ if err != nil {
+ t.Fatalf("icmpMessage.Marshal failed: %v", err)
+ }
+ rb := make([]byte, 20+128)
+ if _, err := c.WriteToIP(wb, c.LocalAddr().(*IPAddr)); err != nil {
+ t.Fatalf("IPConn.WriteToIP failed: %v", err)
+ }
+ if _, _, err := c.ReadFromIP(rb); err != nil {
+ t.Fatalf("IPConn.ReadFromIP failed: %v", err)
+ }
+ if _, _, err := c.WriteMsgIP(wb, nil, c.LocalAddr().(*IPAddr)); err != nil {
+ condFatalf(t, "IPConn.WriteMsgIP failed: %v", err)
+ }
+ if _, _, _, _, err := c.ReadMsgIP(rb, nil); err != nil {
+ condFatalf(t, "IPConn.ReadMsgIP failed: %v", err)
+ }
+
+ if f, err := c.File(); err != nil {
+ condFatalf(t, "IPConn.File failed: %v", err)
+ } else {
+ f.Close()
+ }
+}
+
+func TestUnixListenerSpecificMethods(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ t.Skipf("skipping read test on %q", runtime.GOOS)
+ }
+
+ addr := testUnixAddr()
+ la, err := ResolveUnixAddr("unix", addr)
+ if err != nil {
+ t.Fatalf("ResolveUnixAddr failed: %v", err)
+ }
+ ln, err := ListenUnix("unix", la)
+ if err != nil {
+ t.Fatalf("ListenUnix failed: %v", err)
+ }
+ defer ln.Close()
+ defer os.Remove(addr)
+ ln.Addr()
+ ln.SetDeadline(time.Now().Add(30 * time.Nanosecond))
+
+ if c, err := ln.Accept(); err != nil {
+ if !err.(Error).Timeout() {
+ t.Fatalf("UnixListener.Accept failed: %v", err)
+ }
+ } else {
+ c.Close()
+ }
+ if c, err := ln.AcceptUnix(); err != nil {
+ if !err.(Error).Timeout() {
+ t.Fatalf("UnixListener.AcceptUnix failed: %v", err)
+ }
+ } else {
+ c.Close()
+ }
+
+ if f, err := ln.File(); err != nil {
+ t.Fatalf("UnixListener.File failed: %v", err)
+ } else {
+ f.Close()
+ }
+}
+
+func TestUnixConnSpecificMethods(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ addr1, addr2, addr3 := testUnixAddr(), testUnixAddr(), testUnixAddr()
+
+ a1, err := ResolveUnixAddr("unixgram", addr1)
+ if err != nil {
+ t.Fatalf("ResolveUnixAddr failed: %v", err)
+ }
+ c1, err := DialUnix("unixgram", a1, nil)
+ if err != nil {
+ t.Fatalf("DialUnix failed: %v", err)
+ }
+ defer c1.Close()
+ defer os.Remove(addr1)
+ c1.LocalAddr()
+ c1.RemoteAddr()
+ c1.SetDeadline(time.Now().Add(someTimeout))
+ c1.SetReadDeadline(time.Now().Add(someTimeout))
+ c1.SetWriteDeadline(time.Now().Add(someTimeout))
+ c1.SetReadBuffer(2048)
+ c1.SetWriteBuffer(2048)
+
+ a2, err := ResolveUnixAddr("unixgram", addr2)
+ if err != nil {
+ t.Fatalf("ResolveUnixAddr failed: %v", err)
+ }
+ c2, err := DialUnix("unixgram", a2, nil)
+ if err != nil {
+ t.Fatalf("DialUnix failed: %v", err)
+ }
+ defer c2.Close()
+ defer os.Remove(addr2)
+ c2.LocalAddr()
+ c2.RemoteAddr()
+ c2.SetDeadline(time.Now().Add(someTimeout))
+ c2.SetReadDeadline(time.Now().Add(someTimeout))
+ c2.SetWriteDeadline(time.Now().Add(someTimeout))
+ c2.SetReadBuffer(2048)
+ c2.SetWriteBuffer(2048)
+
+ a3, err := ResolveUnixAddr("unixgram", addr3)
+ if err != nil {
+ t.Fatalf("ResolveUnixAddr failed: %v", err)
+ }
+ c3, err := ListenUnixgram("unixgram", a3)
+ if err != nil {
+ t.Fatalf("ListenUnixgram failed: %v", err)
+ }
+ defer c3.Close()
+ defer os.Remove(addr3)
+ c3.LocalAddr()
+ c3.RemoteAddr()
+ c3.SetDeadline(time.Now().Add(someTimeout))
+ c3.SetReadDeadline(time.Now().Add(someTimeout))
+ c3.SetWriteDeadline(time.Now().Add(someTimeout))
+ c3.SetReadBuffer(2048)
+ c3.SetWriteBuffer(2048)
+
+ wb := []byte("UNIXCONN TEST")
+ rb1 := make([]byte, 128)
+ rb2 := make([]byte, 128)
+ rb3 := make([]byte, 128)
+ if _, _, err := c1.WriteMsgUnix(wb, nil, a2); err != nil {
+ t.Fatalf("UnixConn.WriteMsgUnix failed: %v", err)
+ }
+ if _, _, _, _, err := c2.ReadMsgUnix(rb2, nil); err != nil {
+ t.Fatalf("UnixConn.ReadMsgUnix failed: %v", err)
+ }
+ if _, err := c2.WriteToUnix(wb, a1); err != nil {
+ t.Fatalf("UnixConn.WriteToUnix failed: %v", err)
+ }
+ if _, _, err := c1.ReadFromUnix(rb1); err != nil {
+ t.Fatalf("UnixConn.ReadFromUnix failed: %v", err)
+ }
+ if _, err := c3.WriteToUnix(wb, a1); err != nil {
+ t.Fatalf("UnixConn.WriteToUnix failed: %v", err)
+ }
+ if _, _, err := c1.ReadFromUnix(rb1); err != nil {
+ t.Fatalf("UnixConn.ReadFromUnix failed: %v", err)
+ }
+ if _, err := c2.WriteToUnix(wb, a3); err != nil {
+ t.Fatalf("UnixConn.WriteToUnix failed: %v", err)
+ }
+ if _, _, err := c3.ReadFromUnix(rb3); err != nil {
+ t.Fatalf("UnixConn.ReadFromUnix failed: %v", err)
+ }
+
+ if f, err := c1.File(); err != nil {
+ t.Fatalf("UnixConn.File failed: %v", err)
+ } else {
+ f.Close()
+ }
+}
diff --git a/src/pkg/net/rpc/client.go b/src/pkg/net/rpc/client.go
index db2da8e44..4b0c9c3bb 100644
--- a/src/pkg/net/rpc/client.go
+++ b/src/pkg/net/rpc/client.go
@@ -71,7 +71,7 @@ func (client *Client) send(call *Call) {
// Register this call.
client.mutex.Lock()
- if client.shutdown {
+ if client.shutdown || client.closing {
call.Error = ErrShutdown
client.mutex.Unlock()
call.done()
@@ -88,10 +88,13 @@ func (client *Client) send(call *Call) {
err := client.codec.WriteRequest(&client.request, call.Args)
if err != nil {
client.mutex.Lock()
+ call = client.pending[seq]
delete(client.pending, seq)
client.mutex.Unlock()
- call.Error = err
- call.done()
+ if call != nil {
+ call.Error = err
+ call.done()
+ }
}
}
@@ -102,9 +105,6 @@ func (client *Client) input() {
response = Response{}
err = client.codec.ReadResponseHeader(&response)
if err != nil {
- if err == io.EOF && !client.closing {
- err = io.ErrUnexpectedEOF
- }
break
}
seq := response.Seq
@@ -113,12 +113,18 @@ func (client *Client) input() {
delete(client.pending, seq)
client.mutex.Unlock()
- if response.Error == "" {
- err = client.codec.ReadResponseBody(call.Reply)
+ switch {
+ case call == nil:
+ // We've got no pending call. That usually means that
+ // WriteRequest partially failed, and call was already
+ // removed; response is a server telling us about an
+ // error reading request body. We should still attempt
+ // to read error body, but there's no one to give it to.
+ err = client.codec.ReadResponseBody(nil)
if err != nil {
- call.Error = errors.New("reading body " + err.Error())
+ err = errors.New("reading error body: " + err.Error())
}
- } else {
+ case response.Error != "":
// We've got an error response. Give this to the request;
// any subsequent requests will get the ReadResponseBody
// error if there is one.
@@ -127,14 +133,27 @@ func (client *Client) input() {
if err != nil {
err = errors.New("reading error body: " + err.Error())
}
+ call.done()
+ default:
+ err = client.codec.ReadResponseBody(call.Reply)
+ if err != nil {
+ call.Error = errors.New("reading body " + err.Error())
+ }
+ call.done()
}
- call.done()
}
// Terminate pending calls.
client.sending.Lock()
client.mutex.Lock()
client.shutdown = true
closing := client.closing
+ if err == io.EOF {
+ if closing {
+ err = ErrShutdown
+ } else {
+ err = io.ErrUnexpectedEOF
+ }
+ }
for _, call := range client.pending {
call.Error = err
call.done()
@@ -213,7 +232,7 @@ func DialHTTP(network, address string) (*Client, error) {
return DialHTTPPath(network, address, DefaultRPCPath)
}
-// DialHTTPPath connects to an HTTP RPC server
+// DialHTTPPath connects to an HTTP RPC server
// at the specified network address and path.
func DialHTTPPath(network, address, path string) (*Client, error) {
var err error
diff --git a/src/pkg/net/rpc/jsonrpc/all_test.go b/src/pkg/net/rpc/jsonrpc/all_test.go
index e6c7441f0..3c7c4d48f 100644
--- a/src/pkg/net/rpc/jsonrpc/all_test.go
+++ b/src/pkg/net/rpc/jsonrpc/all_test.go
@@ -24,6 +24,12 @@ type Reply struct {
type Arith int
+type ArithAddResp struct {
+ Id interface{} `json:"id"`
+ Result Reply `json:"result"`
+ Error interface{} `json:"error"`
+}
+
func (t *Arith) Add(args *Args, reply *Reply) error {
reply.C = args.A + args.B
return nil
@@ -50,13 +56,39 @@ func init() {
rpc.Register(new(Arith))
}
-func TestServer(t *testing.T) {
- type addResp struct {
- Id interface{} `json:"id"`
- Result Reply `json:"result"`
- Error interface{} `json:"error"`
+func TestServerNoParams(t *testing.T) {
+ cli, srv := net.Pipe()
+ defer cli.Close()
+ go ServeConn(srv)
+ dec := json.NewDecoder(cli)
+
+ fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "123"}`)
+ var resp ArithAddResp
+ if err := dec.Decode(&resp); err != nil {
+ t.Fatalf("Decode after no params: %s", err)
+ }
+ if resp.Error == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+}
+
+func TestServerEmptyMessage(t *testing.T) {
+ cli, srv := net.Pipe()
+ defer cli.Close()
+ go ServeConn(srv)
+ dec := json.NewDecoder(cli)
+
+ fmt.Fprintf(cli, "{}")
+ var resp ArithAddResp
+ if err := dec.Decode(&resp); err != nil {
+ t.Fatalf("Decode after empty: %s", err)
}
+ if resp.Error == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+}
+func TestServer(t *testing.T) {
cli, srv := net.Pipe()
defer cli.Close()
go ServeConn(srv)
@@ -65,7 +97,7 @@ func TestServer(t *testing.T) {
// Send hand-coded requests to server, parse responses.
for i := 0; i < 10; i++ {
fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
- var resp addResp
+ var resp ArithAddResp
err := dec.Decode(&resp)
if err != nil {
t.Fatalf("Decode: %s", err)
@@ -80,15 +112,6 @@ func TestServer(t *testing.T) {
t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
}
}
-
- fmt.Fprintf(cli, "{}\n")
- var resp addResp
- if err := dec.Decode(&resp); err != nil {
- t.Fatalf("Decode after empty: %s", err)
- }
- if resp.Error == nil {
- t.Fatalf("Expected error, got nil")
- }
}
func TestClient(t *testing.T) {
@@ -108,7 +131,7 @@ func TestClient(t *testing.T) {
t.Errorf("Add: expected no error but got string %q", err.Error())
}
if reply.C != args.A+args.B {
- t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B)
}
args = &Args{7, 8}
@@ -118,7 +141,7 @@ func TestClient(t *testing.T) {
t.Errorf("Mul: expected no error but got string %q", err.Error())
}
if reply.C != args.A*args.B {
- t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+ t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B)
}
// Out of order.
@@ -133,7 +156,7 @@ func TestClient(t *testing.T) {
t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
}
if addReply.C != args.A+args.B {
- t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
+ t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B)
}
mulCall = <-mulCall.Done
@@ -141,7 +164,7 @@ func TestClient(t *testing.T) {
t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
}
if mulReply.C != args.A*args.B {
- t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
+ t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B)
}
// Error test
diff --git a/src/pkg/net/rpc/jsonrpc/server.go b/src/pkg/net/rpc/jsonrpc/server.go
index 4c54553a7..5bc05fd0a 100644
--- a/src/pkg/net/rpc/jsonrpc/server.go
+++ b/src/pkg/net/rpc/jsonrpc/server.go
@@ -12,6 +12,8 @@ import (
"sync"
)
+var errMissingParams = errors.New("jsonrpc: request body missing params")
+
type serverCodec struct {
dec *json.Decoder // for reading JSON values
enc *json.Encoder // for writing JSON values
@@ -50,12 +52,8 @@ type serverRequest struct {
func (r *serverRequest) reset() {
r.Method = ""
- if r.Params != nil {
- *r.Params = (*r.Params)[0:0]
- }
- if r.Id != nil {
- *r.Id = (*r.Id)[0:0]
- }
+ r.Params = nil
+ r.Id = nil
}
type serverResponse struct {
@@ -88,6 +86,9 @@ func (c *serverCodec) ReadRequestBody(x interface{}) error {
if x == nil {
return nil
}
+ if c.req.Params == nil {
+ return errMissingParams
+ }
// JSON params is array value.
// RPC params is struct.
// Unmarshal into array containing struct for now.
diff --git a/src/pkg/net/rpc/server.go b/src/pkg/net/rpc/server.go
index 1680e2f0d..e71b6fb1a 100644
--- a/src/pkg/net/rpc/server.go
+++ b/src/pkg/net/rpc/server.go
@@ -24,12 +24,13 @@
where T, T1 and T2 can be marshaled by encoding/gob.
These requirements apply even if a different codec is used.
- (In future, these requirements may soften for custom codecs.)
+ (In the future, these requirements may soften for custom codecs.)
The method's first argument represents the arguments provided by the caller; the
second argument represents the result parameters to be returned to the caller.
The method's return value, if non-nil, is passed back as a string that the client
- sees as if created by errors.New.
+ sees as if created by errors.New. If an error is returned, the reply parameter
+ will not be sent back to the client.
The server may handle requests on a single connection by calling ServeConn. More
typically it will create a network listener and call Accept or, for an HTTP
@@ -111,7 +112,7 @@
// Asynchronous call
quotient := new(Quotient)
- divCall := client.Go("Arith.Divide", args, &quotient, nil)
+ divCall := client.Go("Arith.Divide", args, quotient, nil)
replyCall := <-divCall.Done // will be equal to divCall
// check errors, print, etc.
@@ -181,7 +182,7 @@ type Response struct {
// Server represents an RPC Server.
type Server struct {
- mu sync.Mutex // protects the serviceMap
+ mu sync.RWMutex // protects the serviceMap
serviceMap map[string]*service
reqLock sync.Mutex // protects freeReq
freeReq *Request
@@ -218,15 +219,15 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
// - exported method
// - two arguments, both pointers to exported structs
// - one return value, of type error
-// It returns an error if the receiver is not an exported type or has no
-// suitable methods.
+// It returns an error if the receiver is not an exported type or has
+// no methods or unsuitable methods. It also logs the error using package log.
// The client accesses each method using a string of the form "Type.Method",
// where Type is the receiver's concrete type.
func (server *Server) Register(rcvr interface{}) error {
return server.register(rcvr, "", false)
}
-// RegisterName is like Register but uses the provided name for the type
+// RegisterName is like Register but uses the provided name for the type
// instead of the receiver's concrete type.
func (server *Server) RegisterName(name string, rcvr interface{}) error {
return server.register(rcvr, name, true)
@@ -260,8 +261,30 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro
s.method = make(map[string]*methodType)
// Install the methods
- for m := 0; m < s.typ.NumMethod(); m++ {
- method := s.typ.Method(m)
+ s.method = suitableMethods(s.typ, true)
+
+ if len(s.method) == 0 {
+ str := ""
+ // To help the user, see if a pointer receiver would work.
+ method := suitableMethods(reflect.PtrTo(s.typ), false)
+ if len(method) != 0 {
+ str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
+ } else {
+ str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
+ }
+ log.Print(str)
+ return errors.New(str)
+ }
+ server.serviceMap[s.name] = s
+ return nil
+}
+
+// suitableMethods returns suitable Rpc methods of typ, it will report
+// error using log if reportErr is true.
+func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
+ methods := make(map[string]*methodType)
+ for m := 0; m < typ.NumMethod(); m++ {
+ method := typ.Method(m)
mtype := method.Type
mname := method.Name
// Method must be exported.
@@ -270,46 +293,51 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro
}
// Method needs three ins: receiver, *args, *reply.
if mtype.NumIn() != 3 {
- log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
+ if reportErr {
+ log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
+ }
continue
}
// First arg need not be a pointer.
argType := mtype.In(1)
if !isExportedOrBuiltinType(argType) {
- log.Println(mname, "argument type not exported:", argType)
+ if reportErr {
+ log.Println(mname, "argument type not exported:", argType)
+ }
continue
}
// Second arg must be a pointer.
replyType := mtype.In(2)
if replyType.Kind() != reflect.Ptr {
- log.Println("method", mname, "reply type not a pointer:", replyType)
+ if reportErr {
+ log.Println("method", mname, "reply type not a pointer:", replyType)
+ }
continue
}
// Reply type must be exported.
if !isExportedOrBuiltinType(replyType) {
- log.Println("method", mname, "reply type not exported:", replyType)
+ if reportErr {
+ log.Println("method", mname, "reply type not exported:", replyType)
+ }
continue
}
// Method needs one out.
if mtype.NumOut() != 1 {
- log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
+ if reportErr {
+ log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
+ }
continue
}
// The return type of the method must be error.
if returnType := mtype.Out(0); returnType != typeOfError {
- log.Println("method", mname, "returns", returnType.String(), "not error")
+ if reportErr {
+ log.Println("method", mname, "returns", returnType.String(), "not error")
+ }
continue
}
- s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
+ methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
}
-
- if len(s.method) == 0 {
- s := "rpc Register: type " + sname + " has no exported methods of suitable type"
- log.Print(s)
- return errors.New(s)
- }
- server.serviceMap[s.name] = s
- return nil
+ return methods
}
// A value sent as a placeholder for the server's response value when the server
@@ -538,9 +566,9 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt
return
}
// Look up the request.
- server.mu.Lock()
+ server.mu.RLock()
service = server.serviceMap[serviceMethod[0]]
- server.mu.Unlock()
+ server.mu.RUnlock()
if service == nil {
err = errors.New("rpc: can't find service " + req.ServiceMethod)
return
@@ -568,7 +596,7 @@ func (server *Server) Accept(lis net.Listener) {
// Register publishes the receiver's methods in the DefaultServer.
func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
-// RegisterName is like Register but uses the provided name for the type
+// RegisterName is like Register but uses the provided name for the type
// instead of the receiver's concrete type.
func RegisterName(name string, rcvr interface{}) error {
return DefaultServer.RegisterName(name, rcvr)
@@ -611,7 +639,7 @@ func ServeRequest(codec ServerCodec) error {
}
// Accept accepts connections on the listener and serves requests
-// to DefaultServer for each incoming connection.
+// to DefaultServer for each incoming connection.
// Accept blocks; the caller typically invokes it in a go statement.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
diff --git a/src/pkg/net/rpc/server_test.go b/src/pkg/net/rpc/server_test.go
index 62c7b1e60..8a1530623 100644
--- a/src/pkg/net/rpc/server_test.go
+++ b/src/pkg/net/rpc/server_test.go
@@ -349,6 +349,7 @@ func testServeRequest(t *testing.T, server *Server) {
type ReplyNotPointer int
type ArgNotPublic int
type ReplyNotPublic int
+type NeedsPtrType int
type local struct{}
func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
@@ -363,19 +364,29 @@ func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
return nil
}
+func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
+ return nil
+}
+
// Check that registration handles lots of bad methods and a type with no suitable methods.
func TestRegistrationError(t *testing.T) {
err := Register(new(ReplyNotPointer))
if err == nil {
- t.Errorf("expected error registering ReplyNotPointer")
+ t.Error("expected error registering ReplyNotPointer")
}
err = Register(new(ArgNotPublic))
if err == nil {
- t.Errorf("expected error registering ArgNotPublic")
+ t.Error("expected error registering ArgNotPublic")
}
err = Register(new(ReplyNotPublic))
if err == nil {
- t.Errorf("expected error registering ReplyNotPublic")
+ t.Error("expected error registering ReplyNotPublic")
+ }
+ err = Register(NeedsPtrType(0))
+ if err == nil {
+ t.Error("expected error registering NeedsPtrType")
+ } else if !strings.Contains(err.Error(), "pointer") {
+ t.Error("expected hint when registering NeedsPtrType")
}
}
@@ -434,7 +445,7 @@ func dialHTTP() (*Client, error) {
return DialHTTP("tcp", httpServerAddr)
}
-func countMallocs(dial func() (*Client, error), t *testing.T) uint64 {
+func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
once.Do(startServer)
client, err := dial()
if err != nil {
@@ -442,11 +453,7 @@ func countMallocs(dial func() (*Client, error), t *testing.T) uint64 {
}
args := &Args{7, 8}
reply := new(Reply)
- memstats := new(runtime.MemStats)
- runtime.ReadMemStats(memstats)
- mallocs := 0 - memstats.Mallocs
- const count = 100
- for i := 0; i < count; i++ {
+ return testing.AllocsPerRun(100, func() {
err := client.Call("Arith.Add", args, reply)
if err != nil {
t.Errorf("Add: expected no error but got string %q", err.Error())
@@ -454,18 +461,15 @@ func countMallocs(dial func() (*Client, error), t *testing.T) uint64 {
if reply.C != args.A+args.B {
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
}
- }
- runtime.ReadMemStats(memstats)
- mallocs += memstats.Mallocs
- return mallocs / count
+ })
}
func TestCountMallocs(t *testing.T) {
- fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t))
+ fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
}
func TestCountMallocsOverHTTP(t *testing.T) {
- fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t))
+ fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
}
type writeCrasher struct {
@@ -499,6 +503,44 @@ func TestClientWriteError(t *testing.T) {
w.done <- true
}
+func TestTCPClose(t *testing.T) {
+ once.Do(startServer)
+
+ client, err := dialHTTP()
+ if err != nil {
+ t.Fatalf("dialing: %v", err)
+ }
+ defer client.Close()
+
+ args := Args{17, 8}
+ var reply Reply
+ err = client.Call("Arith.Mul", args, &reply)
+ if err != nil {
+ t.Fatal("arith error:", err)
+ }
+ t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
+ if reply.C != args.A*args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
+ }
+}
+
+func TestErrorAfterClientClose(t *testing.T) {
+ once.Do(startServer)
+
+ client, err := dialHTTP()
+ if err != nil {
+ t.Fatalf("dialing: %v", err)
+ }
+ err = client.Close()
+ if err != nil {
+ t.Fatal("close error:", err)
+ }
+ err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
+ if err != ErrShutdown {
+ t.Errorf("Forever: expected ErrShutdown got %v", err)
+ }
+}
+
func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
b.StopTimer()
once.Do(startServer)
diff --git a/src/pkg/net/sendfile_freebsd.go b/src/pkg/net/sendfile_freebsd.go
new file mode 100644
index 000000000..8008bc3b5
--- /dev/null
+++ b/src/pkg/net/sendfile_freebsd.go
@@ -0,0 +1,105 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "io"
+ "os"
+ "syscall"
+)
+
+// maxSendfileSize is the largest chunk size we ask the kernel to copy
+// at a time.
+const maxSendfileSize int = 4 << 20
+
+// sendFile copies the contents of r to c using the sendfile
+// system call to minimize copies.
+//
+// if handled == true, sendFile returns the number of bytes copied and any
+// non-EOF error.
+//
+// if handled == false, sendFile performed no work.
+func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
+ // FreeBSD uses 0 as the "until EOF" value. If you pass in more bytes than the
+ // file contains, it will loop back to the beginning ad nauseum until it's sent
+ // exactly the number of bytes told to. As such, we need to know exactly how many
+ // bytes to send.
+ var remain int64 = 0
+
+ lr, ok := r.(*io.LimitedReader)
+ if ok {
+ remain, r = lr.N, lr.R
+ if remain <= 0 {
+ return 0, nil, true
+ }
+ }
+ f, ok := r.(*os.File)
+ if !ok {
+ return 0, nil, false
+ }
+
+ if remain == 0 {
+ fi, err := f.Stat()
+ if err != nil {
+ return 0, err, false
+ }
+
+ remain = fi.Size()
+ }
+
+ // The other quirk with FreeBSD's sendfile implementation is that it doesn't
+ // use the current position of the file -- if you pass it offset 0, it starts
+ // from offset 0. There's no way to tell it "start from current position", so
+ // we have to manage that explicitly.
+ pos, err := f.Seek(0, os.SEEK_CUR)
+ if err != nil {
+ return 0, err, false
+ }
+
+ c.wio.Lock()
+ defer c.wio.Unlock()
+ if err := c.incref(false); err != nil {
+ return 0, err, true
+ }
+ defer c.decref()
+
+ dst := c.sysfd
+ src := int(f.Fd())
+ for remain > 0 {
+ n := maxSendfileSize
+ if int64(n) > remain {
+ n = int(remain)
+ }
+ pos1 := pos
+ n, err1 := syscall.Sendfile(dst, src, &pos1, n)
+ if n > 0 {
+ pos += int64(n)
+ written += int64(n)
+ remain -= int64(n)
+ }
+ if n == 0 && err1 == nil {
+ break
+ }
+ if err1 == syscall.EAGAIN {
+ if err1 = c.pollServer.WaitWrite(c); err1 == nil {
+ continue
+ }
+ }
+ if err1 == syscall.EINTR {
+ continue
+ }
+ if err1 != nil {
+ // This includes syscall.ENOSYS (no kernel
+ // support) and syscall.EINVAL (fd types which
+ // don't implement sendfile together)
+ err = &OpError{"sendfile", c.net, c.raddr, err1}
+ break
+ }
+ }
+ if lr != nil {
+ lr.N = remain
+ }
+ return written, err, written > 0
+}
diff --git a/src/pkg/net/sendfile_linux.go b/src/pkg/net/sendfile_linux.go
index a0d530362..3357e6538 100644
--- a/src/pkg/net/sendfile_linux.go
+++ b/src/pkg/net/sendfile_linux.go
@@ -58,8 +58,8 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
if n == 0 && err1 == nil {
break
}
- if err1 == syscall.EAGAIN && c.wdeadline >= 0 {
- if err1 = pollserver.WaitWrite(c); err1 == nil {
+ if err1 == syscall.EAGAIN {
+ if err1 = c.pollServer.WaitWrite(c); err1 == nil {
continue
}
}
diff --git a/src/pkg/net/sendfile_stub.go b/src/pkg/net/sendfile_stub.go
index ff76ab9cf..3660849c1 100644
--- a/src/pkg/net/sendfile_stub.go
+++ b/src/pkg/net/sendfile_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd netbsd openbsd
+// +build darwin netbsd openbsd
package net
diff --git a/src/pkg/net/sendfile_windows.go b/src/pkg/net/sendfile_windows.go
index f5a6d8804..2d64f2f5b 100644
--- a/src/pkg/net/sendfile_windows.go
+++ b/src/pkg/net/sendfile_windows.go
@@ -48,12 +48,12 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
return 0, nil, false
}
- c.wio.Lock()
- defer c.wio.Unlock()
if err := c.incref(false); err != nil {
return 0, err, true
}
defer c.decref()
+ c.wio.Lock()
+ defer c.wio.Unlock()
var o sendfileOp
o.Init(c, 'w')
diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go
index 158b9477d..25c2be5a7 100644
--- a/src/pkg/net/server_test.go
+++ b/src/pkg/net/server_test.go
@@ -113,8 +113,7 @@ func TestStreamConnServer(t *testing.T) {
case "tcp", "tcp4", "tcp6":
_, port, err := SplitHostPort(taddr)
if err != nil {
- t.Errorf("SplitHostPort(%q) failed: %v", taddr, err)
- return
+ t.Fatalf("SplitHostPort(%q) failed: %v", taddr, err)
}
taddr = tt.caddr + ":" + port
}
@@ -142,8 +141,7 @@ var seqpacketConnServerTests = []struct {
func TestSeqpacketConnServer(t *testing.T) {
if runtime.GOOS != "linux" {
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
for _, tt := range seqpacketConnServerTests {
@@ -170,11 +168,11 @@ func TestSeqpacketConnServer(t *testing.T) {
}
func runStreamConnServer(t *testing.T, net, laddr string, listening chan<- string, done chan<- int) {
+ defer close(done)
l, err := Listen(net, laddr)
if err != nil {
t.Errorf("Listen(%q, %q) failed: %v", net, laddr, err)
listening <- "<nil>"
- done <- 1
return
}
defer l.Close()
@@ -189,13 +187,14 @@ func runStreamConnServer(t *testing.T, net, laddr string, listening chan<- strin
}
rw.Write(buf[0:n])
}
- done <- 1
+ close(done)
}
run:
for {
c, err := l.Accept()
if err != nil {
+ t.Logf("Accept failed: %v", err)
continue run
}
echodone := make(chan int)
@@ -204,14 +203,12 @@ run:
c.Close()
break run
}
- done <- 1
}
func runStreamConnClient(t *testing.T, net, taddr string, isEmpty bool) {
c, err := Dial(net, taddr)
if err != nil {
- t.Errorf("Dial(%q, %q) failed: %v", net, taddr, err)
- return
+ t.Fatalf("Dial(%q, %q) failed: %v", net, taddr, err)
}
defer c.Close()
c.SetReadDeadline(time.Now().Add(1 * time.Second))
@@ -221,14 +218,12 @@ func runStreamConnClient(t *testing.T, net, taddr string, isEmpty bool) {
wb = []byte("StreamConnClient by Dial\n")
}
if n, err := c.Write(wb); err != nil || n != len(wb) {
- t.Errorf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb))
- return
+ t.Fatalf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb))
}
rb := make([]byte, 1024)
if n, err := c.Read(rb[0:]); err != nil || n != len(wb) {
- t.Errorf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb))
- return
+ t.Fatalf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb))
}
// Send explicit ending for unixpacket.
@@ -334,8 +329,7 @@ func TestDatagramPacketConnServer(t *testing.T) {
case "udp", "udp4", "udp6":
_, port, err := SplitHostPort(taddr)
if err != nil {
- t.Errorf("SplitHostPort(%q) failed: %v", taddr, err)
- return
+ t.Fatalf("SplitHostPort(%q) failed: %v", taddr, err)
}
taddr = tt.caddr + ":" + port
tt.caddr += ":0"
@@ -398,14 +392,12 @@ func runDatagramConnClient(t *testing.T, net, laddr, taddr string, isEmpty bool)
case "udp", "udp4", "udp6":
c, err = Dial(net, taddr)
if err != nil {
- t.Errorf("Dial(%q, %q) failed: %v", net, taddr, err)
- return
+ t.Fatalf("Dial(%q, %q) failed: %v", net, taddr, err)
}
case "unixgram":
c, err = DialUnix(net, &UnixAddr{laddr, net}, &UnixAddr{taddr, net})
if err != nil {
- t.Errorf("DialUnix(%q, {%q, %q}) failed: %v", net, laddr, taddr, err)
- return
+ t.Fatalf("DialUnix(%q, {%q, %q}) failed: %v", net, laddr, taddr, err)
}
}
defer c.Close()
@@ -416,14 +408,12 @@ func runDatagramConnClient(t *testing.T, net, laddr, taddr string, isEmpty bool)
wb = []byte("DatagramConnClient by Dial\n")
}
if n, err := c.Write(wb[0:]); err != nil || n != len(wb) {
- t.Errorf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb))
- return
+ t.Fatalf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb))
}
rb := make([]byte, 1024)
if n, err := c.Read(rb[0:]); err != nil || n != len(wb) {
- t.Errorf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb))
- return
+ t.Fatalf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb))
}
}
@@ -434,20 +424,17 @@ func runDatagramPacketConnClient(t *testing.T, net, laddr, taddr string, isEmpty
case "udp", "udp4", "udp6":
ra, err = ResolveUDPAddr(net, taddr)
if err != nil {
- t.Errorf("ResolveUDPAddr(%q, %q) failed: %v", net, taddr, err)
- return
+ t.Fatalf("ResolveUDPAddr(%q, %q) failed: %v", net, taddr, err)
}
case "unixgram":
ra, err = ResolveUnixAddr(net, taddr)
if err != nil {
- t.Errorf("ResolveUxixAddr(%q, %q) failed: %v", net, taddr, err)
- return
+ t.Fatalf("ResolveUxixAddr(%q, %q) failed: %v", net, taddr, err)
}
}
c, err := ListenPacket(net, laddr)
if err != nil {
- t.Errorf("ListenPacket(%q, %q) faild: %v", net, laddr, err)
- return
+ t.Fatalf("ListenPacket(%q, %q) faild: %v", net, laddr, err)
}
defer c.Close()
c.SetReadDeadline(time.Now().Add(1 * time.Second))
@@ -457,13 +444,11 @@ func runDatagramPacketConnClient(t *testing.T, net, laddr, taddr string, isEmpty
wb = []byte("DatagramPacketConnClient by ListenPacket\n")
}
if n, err := c.WriteTo(wb[0:], ra); err != nil || n != len(wb) {
- t.Errorf("WriteTo(%v) failed: %v, %v; want %v, <nil>", ra, n, err, len(wb))
- return
+ t.Fatalf("WriteTo(%v) failed: %v, %v; want %v, <nil>", ra, n, err, len(wb))
}
rb := make([]byte, 1024)
if n, _, err := c.ReadFrom(rb[0:]); err != nil || n != len(wb) {
- t.Errorf("ReadFrom failed: %v, %v; want %v, <nil>", n, err, len(wb))
- return
+ t.Fatalf("ReadFrom failed: %v, %v; want %v, <nil>", n, err, len(wb))
}
}
diff --git a/src/pkg/net/smtp/smtp.go b/src/pkg/net/smtp/smtp.go
index 59f6449f0..4b9177877 100644
--- a/src/pkg/net/smtp/smtp.go
+++ b/src/pkg/net/smtp/smtp.go
@@ -13,6 +13,7 @@ package smtp
import (
"crypto/tls"
"encoding/base64"
+ "errors"
"io"
"net"
"net/textproto"
@@ -33,7 +34,10 @@ type Client struct {
// map of supported extensions
ext map[string]string
// supported auth mechanisms
- auth []string
+ auth []string
+ localName string // the name to use in HELO/EHLO
+ didHello bool // whether we've said HELO/EHLO
+ helloError error // the error from the hello
}
// Dial returns a new Client connected to an SMTP server at addr.
@@ -55,12 +59,33 @@ func NewClient(conn net.Conn, host string) (*Client, error) {
text.Close()
return nil, err
}
- c := &Client{Text: text, conn: conn, serverName: host}
- err = c.ehlo()
- if err != nil {
- err = c.helo()
+ c := &Client{Text: text, conn: conn, serverName: host, localName: "localhost"}
+ return c, nil
+}
+
+// hello runs a hello exchange if needed.
+func (c *Client) hello() error {
+ if !c.didHello {
+ c.didHello = true
+ err := c.ehlo()
+ if err != nil {
+ c.helloError = c.helo()
+ }
+ }
+ return c.helloError
+}
+
+// Hello sends a HELO or EHLO to the server as the given host name.
+// Calling this method is only necessary if the client needs control
+// over the host name used. The client will introduce itself as "localhost"
+// automatically otherwise. If Hello is called, it must be called before
+// any of the other methods.
+func (c *Client) Hello(localName string) error {
+ if c.didHello {
+ return errors.New("smtp: Hello called after other methods")
}
- return c, err
+ c.localName = localName
+ return c.hello()
}
// cmd is a convenience function that sends a command and returns the response
@@ -79,14 +104,14 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s
// server does not support ehlo.
func (c *Client) helo() error {
c.ext = nil
- _, _, err := c.cmd(250, "HELO localhost")
+ _, _, err := c.cmd(250, "HELO %s", c.localName)
return err
}
// ehlo sends the EHLO (extended hello) greeting to the server. It
// should be the preferred greeting for servers that support it.
func (c *Client) ehlo() error {
- _, msg, err := c.cmd(250, "EHLO localhost")
+ _, msg, err := c.cmd(250, "EHLO %s", c.localName)
if err != nil {
return err
}
@@ -113,6 +138,9 @@ func (c *Client) ehlo() error {
// StartTLS sends the STARTTLS command and encrypts all further communication.
// Only servers that advertise the STARTTLS extension support this function.
func (c *Client) StartTLS(config *tls.Config) error {
+ if err := c.hello(); err != nil {
+ return err
+ }
_, _, err := c.cmd(220, "STARTTLS")
if err != nil {
return err
@@ -128,6 +156,9 @@ func (c *Client) StartTLS(config *tls.Config) error {
// does not necessarily indicate an invalid address. Many servers
// will not verify addresses for security reasons.
func (c *Client) Verify(addr string) error {
+ if err := c.hello(); err != nil {
+ return err
+ }
_, _, err := c.cmd(250, "VRFY %s", addr)
return err
}
@@ -136,6 +167,9 @@ func (c *Client) Verify(addr string) error {
// A failed authentication closes the connection.
// Only servers that advertise the AUTH extension support this function.
func (c *Client) Auth(a Auth) error {
+ if err := c.hello(); err != nil {
+ return err
+ }
encoding := base64.StdEncoding
mech, resp, err := a.Start(&ServerInfo{c.serverName, c.tls, c.auth})
if err != nil {
@@ -178,6 +212,9 @@ func (c *Client) Auth(a Auth) error {
// parameter.
// This initiates a mail transaction and is followed by one or more Rcpt calls.
func (c *Client) Mail(from string) error {
+ if err := c.hello(); err != nil {
+ return err
+ }
cmdStr := "MAIL FROM:<%s>"
if c.ext != nil {
if _, ok := c.ext["8BITMIME"]; ok {
@@ -227,6 +264,9 @@ func SendMail(addr string, a Auth, from string, to []string, msg []byte) error {
if err != nil {
return err
}
+ if err := c.hello(); err != nil {
+ return err
+ }
if ok, _ := c.Extension("STARTTLS"); ok {
if err = c.StartTLS(nil); err != nil {
return err
@@ -267,6 +307,9 @@ func SendMail(addr string, a Auth, from string, to []string, msg []byte) error {
// Extension also returns a string that contains any parameters the
// server specifies for the extension.
func (c *Client) Extension(ext string) (bool, string) {
+ if err := c.hello(); err != nil {
+ return false, ""
+ }
if c.ext == nil {
return false, ""
}
@@ -278,12 +321,18 @@ func (c *Client) Extension(ext string) (bool, string) {
// Reset sends the RSET command to the server, aborting the current mail
// transaction.
func (c *Client) Reset() error {
+ if err := c.hello(); err != nil {
+ return err
+ }
_, _, err := c.cmd(250, "RSET")
return err
}
// Quit sends the QUIT command and closes the connection to the server.
func (c *Client) Quit() error {
+ if err := c.hello(); err != nil {
+ return err
+ }
_, _, err := c.cmd(221, "QUIT")
if err != nil {
return err
diff --git a/src/pkg/net/smtp/smtp_test.go b/src/pkg/net/smtp/smtp_test.go
index c315d185c..8317428cb 100644
--- a/src/pkg/net/smtp/smtp_test.go
+++ b/src/pkg/net/smtp/smtp_test.go
@@ -69,14 +69,14 @@ func (f faker) SetReadDeadline(time.Time) error { return nil }
func (f faker) SetWriteDeadline(time.Time) error { return nil }
func TestBasic(t *testing.T) {
- basicServer = strings.Join(strings.Split(basicServer, "\n"), "\r\n")
- basicClient = strings.Join(strings.Split(basicClient, "\n"), "\r\n")
+ server := strings.Join(strings.Split(basicServer, "\n"), "\r\n")
+ client := strings.Join(strings.Split(basicClient, "\n"), "\r\n")
var cmdbuf bytes.Buffer
bcmdbuf := bufio.NewWriter(&cmdbuf)
var fake faker
- fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(basicServer)), bcmdbuf)
- c := &Client{Text: textproto.NewConn(fake)}
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf)
+ c := &Client{Text: textproto.NewConn(fake), localName: "localhost"}
if err := c.helo(); err != nil {
t.Fatalf("HELO failed: %s", err)
@@ -88,6 +88,7 @@ func TestBasic(t *testing.T) {
t.Fatalf("Second EHLO failed: %s", err)
}
+ c.didHello = true
if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" {
t.Fatalf("Expected AUTH supported")
}
@@ -143,8 +144,8 @@ Goodbye.`
bcmdbuf.Flush()
actualcmds := cmdbuf.String()
- if basicClient != actualcmds {
- t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, basicClient)
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
}
}
@@ -187,8 +188,8 @@ QUIT
`
func TestNewClient(t *testing.T) {
- newClientServer = strings.Join(strings.Split(newClientServer, "\n"), "\r\n")
- newClientClient = strings.Join(strings.Split(newClientClient, "\n"), "\r\n")
+ server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n")
+ client := strings.Join(strings.Split(newClientClient, "\n"), "\r\n")
var cmdbuf bytes.Buffer
bcmdbuf := bufio.NewWriter(&cmdbuf)
@@ -197,7 +198,7 @@ func TestNewClient(t *testing.T) {
return cmdbuf.String()
}
var fake faker
- fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(newClientServer)), bcmdbuf)
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf)
c, err := NewClient(fake, "fake.host")
if err != nil {
t.Fatalf("NewClient: %v\n(after %v)", err, out())
@@ -213,8 +214,8 @@ func TestNewClient(t *testing.T) {
}
actualcmds := out()
- if newClientClient != actualcmds {
- t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, newClientClient)
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
}
}
@@ -231,13 +232,13 @@ QUIT
`
func TestNewClient2(t *testing.T) {
- newClient2Server = strings.Join(strings.Split(newClient2Server, "\n"), "\r\n")
- newClient2Client = strings.Join(strings.Split(newClient2Client, "\n"), "\r\n")
+ server := strings.Join(strings.Split(newClient2Server, "\n"), "\r\n")
+ client := strings.Join(strings.Split(newClient2Client, "\n"), "\r\n")
var cmdbuf bytes.Buffer
bcmdbuf := bufio.NewWriter(&cmdbuf)
var fake faker
- fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(newClient2Server)), bcmdbuf)
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf)
c, err := NewClient(fake, "fake.host")
if err != nil {
t.Fatalf("NewClient: %v", err)
@@ -251,8 +252,8 @@ func TestNewClient2(t *testing.T) {
bcmdbuf.Flush()
actualcmds := cmdbuf.String()
- if newClient2Client != actualcmds {
- t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, newClient2Client)
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
}
}
@@ -269,3 +270,199 @@ var newClient2Client = `EHLO localhost
HELO localhost
QUIT
`
+
+func TestHello(t *testing.T) {
+
+ if len(helloServer) != len(helloClient) {
+ t.Fatalf("Hello server and client size mismatch")
+ }
+
+ for i := 0; i < len(helloServer); i++ {
+ server := strings.Join(strings.Split(baseHelloServer+helloServer[i], "\n"), "\r\n")
+ client := strings.Join(strings.Split(baseHelloClient+helloClient[i], "\n"), "\r\n")
+ var cmdbuf bytes.Buffer
+ bcmdbuf := bufio.NewWriter(&cmdbuf)
+ var fake faker
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf)
+ c, err := NewClient(fake, "fake.host")
+ if err != nil {
+ t.Fatalf("NewClient: %v", err)
+ }
+ c.localName = "customhost"
+ err = nil
+
+ switch i {
+ case 0:
+ err = c.Hello("customhost")
+ case 1:
+ err = c.StartTLS(nil)
+ if err.Error() == "502 Not implemented" {
+ err = nil
+ }
+ case 2:
+ err = c.Verify("test@example.com")
+ case 3:
+ c.tls = true
+ c.serverName = "smtp.google.com"
+ err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com"))
+ case 4:
+ err = c.Mail("test@example.com")
+ case 5:
+ ok, _ := c.Extension("feature")
+ if ok {
+ t.Errorf("Expected FEATURE not to be supported")
+ }
+ case 6:
+ err = c.Reset()
+ case 7:
+ err = c.Quit()
+ case 8:
+ err = c.Verify("test@example.com")
+ if err != nil {
+ err = c.Hello("customhost")
+ if err != nil {
+ t.Errorf("Want error, got none")
+ }
+ }
+ default:
+ t.Fatalf("Unhandled command")
+ }
+
+ if err != nil {
+ t.Errorf("Command %d failed: %v", i, err)
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ if client != actualcmds {
+ t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+ }
+}
+
+var baseHelloServer = `220 hello world
+502 EH?
+250-mx.google.com at your service
+250 FEATURE
+`
+
+var helloServer = []string{
+ "",
+ "502 Not implemented\n",
+ "250 User is valid\n",
+ "235 Accepted\n",
+ "250 Sender ok\n",
+ "",
+ "250 Reset ok\n",
+ "221 Goodbye\n",
+ "250 Sender ok\n",
+}
+
+var baseHelloClient = `EHLO customhost
+HELO customhost
+`
+
+var helloClient = []string{
+ "",
+ "STARTTLS\n",
+ "VRFY test@example.com\n",
+ "AUTH PLAIN AHVzZXIAcGFzcw==\n",
+ "MAIL FROM:<test@example.com>\n",
+ "",
+ "RSET\n",
+ "QUIT\n",
+ "VRFY test@example.com\n",
+}
+
+func TestSendMail(t *testing.T) {
+ server := strings.Join(strings.Split(sendMailServer, "\n"), "\r\n")
+ client := strings.Join(strings.Split(sendMailClient, "\n"), "\r\n")
+ var cmdbuf bytes.Buffer
+ bcmdbuf := bufio.NewWriter(&cmdbuf)
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("Unable to to create listener: %v", err)
+ }
+ defer l.Close()
+
+ // prevent data race on bcmdbuf
+ var done = make(chan struct{})
+ go func(data []string) {
+
+ defer close(done)
+
+ conn, err := l.Accept()
+ if err != nil {
+ t.Errorf("Accept error: %v", err)
+ return
+ }
+ defer conn.Close()
+
+ tc := textproto.NewConn(conn)
+ for i := 0; i < len(data) && data[i] != ""; i++ {
+ tc.PrintfLine(data[i])
+ for len(data[i]) >= 4 && data[i][3] == '-' {
+ i++
+ tc.PrintfLine(data[i])
+ }
+ if data[i] == "221 Goodbye" {
+ return
+ }
+ read := false
+ for !read || data[i] == "354 Go ahead" {
+ msg, err := tc.ReadLine()
+ bcmdbuf.Write([]byte(msg + "\r\n"))
+ read = true
+ if err != nil {
+ t.Errorf("Read error: %v", err)
+ return
+ }
+ if data[i] == "354 Go ahead" && msg == "." {
+ break
+ }
+ }
+ }
+ }(strings.Split(server, "\r\n"))
+
+ err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com
+To: other@example.com
+Subject: SendMail test
+
+SendMail is working for me.
+`, "\n", "\r\n", -1)))
+
+ if err != nil {
+ t.Errorf("%v", err)
+ }
+
+ <-done
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ if client != actualcmds {
+ t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+}
+
+var sendMailServer = `220 hello world
+502 EH?
+250 mx.google.com at your service
+250 Sender ok
+250 Receiver ok
+354 Go ahead
+250 Data ok
+221 Goodbye
+`
+
+var sendMailClient = `EHLO localhost
+HELO localhost
+MAIL FROM:<test@example.com>
+RCPT TO:<other@example.com>
+DATA
+From: test@example.com
+To: other@example.com
+Subject: SendMail test
+
+SendMail is working for me.
+.
+QUIT
+`
diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go
deleted file mode 100644
index 3ae16054e..000000000
--- a/src/pkg/net/sock.go
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build darwin freebsd linux netbsd openbsd windows
-
-// Sockets
-
-package net
-
-import (
- "io"
- "syscall"
-)
-
-var listenerBacklog = maxListenerBacklog()
-
-// Generic socket creation.
-func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
- // See ../syscall/exec.go for description of ForkLock.
- syscall.ForkLock.RLock()
- s, err := syscall.Socket(f, t, p)
- if err != nil {
- syscall.ForkLock.RUnlock()
- return nil, err
- }
- syscall.CloseOnExec(s)
- syscall.ForkLock.RUnlock()
-
- err = setDefaultSockopts(s, f, t, ipv6only)
- if err != nil {
- closesocket(s)
- return nil, err
- }
-
- var bla syscall.Sockaddr
- if la != nil {
- bla, err = listenerSockaddr(s, f, la, toAddr)
- if err != nil {
- closesocket(s)
- return nil, err
- }
- err = syscall.Bind(s, bla)
- if err != nil {
- closesocket(s)
- return nil, err
- }
- }
-
- if fd, err = newFD(s, f, t, net); err != nil {
- closesocket(s)
- return nil, err
- }
-
- if ra != nil {
- if err = fd.connect(ra); err != nil {
- closesocket(s)
- fd.Close()
- return nil, err
- }
- fd.isConnected = true
- }
-
- sa, _ := syscall.Getsockname(s)
- var laddr Addr
- if la != nil && bla != la {
- laddr = toAddr(la)
- } else {
- laddr = toAddr(sa)
- }
- sa, _ = syscall.Getpeername(s)
- raddr := toAddr(sa)
-
- fd.setAddr(laddr, raddr)
- return fd, nil
-}
-
-type writerOnly struct {
- io.Writer
-}
-
-// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't
-// applicable.
-func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
- // Use wrapper to hide existing r.ReadFrom from io.Copy.
- return io.Copy(writerOnly{w}, r)
-}
diff --git a/src/pkg/net/sock_bsd.go b/src/pkg/net/sock_bsd.go
index 2607b04c7..3205f9404 100644
--- a/src/pkg/net/sock_bsd.go
+++ b/src/pkg/net/sock_bsd.go
@@ -4,8 +4,6 @@
// +build darwin freebsd netbsd openbsd
-// Sockets for BSD variants
-
package net
import (
@@ -31,32 +29,3 @@ func maxListenerBacklog() int {
}
return int(n)
}
-
-func listenerSockaddr(s, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) {
- a := toAddr(la)
- if a == nil {
- return la, nil
- }
- switch v := a.(type) {
- case *TCPAddr, *UnixAddr:
- err := setDefaultListenerSockopts(s)
- if err != nil {
- return nil, err
- }
- case *UDPAddr:
- if v.IP.IsMulticast() {
- err := setDefaultMulticastSockopts(s)
- if err != nil {
- return nil, err
- }
- switch f {
- case syscall.AF_INET:
- v.IP = IPv4zero
- case syscall.AF_INET6:
- v.IP = IPv6unspecified
- }
- return v.sockaddr(f)
- }
- }
- return la, nil
-}
diff --git a/src/pkg/net/sock_cloexec.go b/src/pkg/net/sock_cloexec.go
new file mode 100644
index 000000000..12d0f3488
--- /dev/null
+++ b/src/pkg/net/sock_cloexec.go
@@ -0,0 +1,69 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements sysSocket and accept for platforms that
+// provide a fast path for setting SetNonblock and CloseOnExec.
+
+// +build linux
+
+package net
+
+import "syscall"
+
+// Wrapper around the socket system call that marks the returned file
+// descriptor as nonblocking and close-on-exec.
+func sysSocket(f, t, p int) (int, error) {
+ s, err := syscall.Socket(f, t|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, p)
+ // The SOCK_NONBLOCK and SOCK_CLOEXEC flags were introduced in
+ // Linux 2.6.27. If we get an EINVAL error, fall back to
+ // using socket without them.
+ if err == nil || err != syscall.EINVAL {
+ return s, err
+ }
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err = syscall.Socket(f, t, p)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
+ if err != nil {
+ return -1, err
+ }
+ if err = syscall.SetNonblock(s, true); err != nil {
+ syscall.Close(s)
+ return -1, err
+ }
+ return s, nil
+}
+
+// Wrapper around the accept system call that marks the returned file
+// descriptor as nonblocking and close-on-exec.
+func accept(fd int) (int, syscall.Sockaddr, error) {
+ nfd, sa, err := syscall.Accept4(fd, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
+ // The accept4 system call was introduced in Linux 2.6.28. If
+ // we get an ENOSYS error, fall back to using accept.
+ if err == nil || err != syscall.ENOSYS {
+ return nfd, sa, err
+ }
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ // It is probably okay to hold the lock across syscall.Accept
+ // because we have put fd.sysfd into non-blocking mode.
+ // However, a call to the File method will put it back into
+ // blocking mode. We can't take that risk, so no use of ForkLock here.
+ nfd, sa, err = syscall.Accept(fd)
+ if err == nil {
+ syscall.CloseOnExec(nfd)
+ }
+ if err != nil {
+ return -1, nil, err
+ }
+ if err = syscall.SetNonblock(nfd, true); err != nil {
+ syscall.Close(nfd)
+ return -1, nil, err
+ }
+ return nfd, sa, nil
+}
diff --git a/src/pkg/net/sock_linux.go b/src/pkg/net/sock_linux.go
index e509d9397..8bbd74ddc 100644
--- a/src/pkg/net/sock_linux.go
+++ b/src/pkg/net/sock_linux.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Sockets for Linux
-
package net
import "syscall"
@@ -25,32 +23,3 @@ func maxListenerBacklog() int {
}
return n
}
-
-func listenerSockaddr(s, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) {
- a := toAddr(la)
- if a == nil {
- return la, nil
- }
- switch v := a.(type) {
- case *TCPAddr, *UnixAddr:
- err := setDefaultListenerSockopts(s)
- if err != nil {
- return nil, err
- }
- case *UDPAddr:
- if v.IP.IsMulticast() {
- err := setDefaultMulticastSockopts(s)
- if err != nil {
- return nil, err
- }
- switch f {
- case syscall.AF_INET:
- v.IP = IPv4zero
- case syscall.AF_INET6:
- v.IP = IPv6unspecified
- }
- return v.sockaddr(f)
- }
- }
- return la, nil
-}
diff --git a/src/pkg/net/sock_posix.go b/src/pkg/net/sock_posix.go
new file mode 100644
index 000000000..b50a892b1
--- /dev/null
+++ b/src/pkg/net/sock_posix.go
@@ -0,0 +1,67 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd windows
+
+package net
+
+import (
+ "syscall"
+ "time"
+)
+
+var listenerBacklog = maxListenerBacklog()
+
+// Generic POSIX socket creation.
+func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, deadline time.Time, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
+ s, err := sysSocket(f, t, p)
+ if err != nil {
+ return nil, err
+ }
+
+ if err = setDefaultSockopts(s, f, t, ipv6only); err != nil {
+ closesocket(s)
+ return nil, err
+ }
+
+ if ulsa != nil {
+ // We provide a socket that listens to a wildcard
+ // address with reusable UDP port when the given ulsa
+ // is an appropriate UDP multicast address prefix.
+ // This makes it possible for a single UDP listener
+ // to join multiple different group addresses, for
+ // multiple UDP listeners that listen on the same UDP
+ // port to join the same group address.
+ if ulsa, err = listenerSockaddr(s, f, ulsa, toAddr); err != nil {
+ closesocket(s)
+ return nil, err
+ }
+ if err = syscall.Bind(s, ulsa); err != nil {
+ closesocket(s)
+ return nil, err
+ }
+ }
+
+ if fd, err = newFD(s, f, t, net); err != nil {
+ closesocket(s)
+ return nil, err
+ }
+
+ if ursa != nil {
+ fd.wdeadline.setTime(deadline)
+ if err = fd.connect(ursa); err != nil {
+ closesocket(s)
+ return nil, err
+ }
+ fd.isConnected = true
+ fd.wdeadline.set(0)
+ }
+
+ lsa, _ := syscall.Getsockname(s)
+ laddr := toAddr(lsa)
+ rsa, _ := syscall.Getpeername(s)
+ raddr := toAddr(rsa)
+ fd.setAddr(laddr, raddr)
+ return fd, nil
+}
diff --git a/src/pkg/net/sock_unix.go b/src/pkg/net/sock_unix.go
new file mode 100644
index 000000000..b0d6d4900
--- /dev/null
+++ b/src/pkg/net/sock_unix.go
@@ -0,0 +1,36 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd
+
+package net
+
+import "syscall"
+
+func listenerSockaddr(s, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) {
+ a := toAddr(la)
+ if a == nil {
+ return la, nil
+ }
+ switch a := a.(type) {
+ case *TCPAddr, *UnixAddr:
+ if err := setDefaultListenerSockopts(s); err != nil {
+ return nil, err
+ }
+ case *UDPAddr:
+ if a.IP.IsMulticast() {
+ if err := setDefaultMulticastSockopts(s); err != nil {
+ return nil, err
+ }
+ switch f {
+ case syscall.AF_INET:
+ a.IP = IPv4zero
+ case syscall.AF_INET6:
+ a.IP = IPv6unspecified
+ }
+ return a.sockaddr(f)
+ }
+ }
+ return la, nil
+}
diff --git a/src/pkg/net/sock_windows.go b/src/pkg/net/sock_windows.go
index cce6181c9..a77c48437 100644
--- a/src/pkg/net/sock_windows.go
+++ b/src/pkg/net/sock_windows.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Sockets for Windows
-
package net
import "syscall"
@@ -18,26 +16,35 @@ func listenerSockaddr(s syscall.Handle, f int, la syscall.Sockaddr, toAddr func(
if a == nil {
return la, nil
}
- switch v := a.(type) {
+ switch a := a.(type) {
case *TCPAddr, *UnixAddr:
- err := setDefaultListenerSockopts(s)
- if err != nil {
+ if err := setDefaultListenerSockopts(s); err != nil {
return nil, err
}
case *UDPAddr:
- if v.IP.IsMulticast() {
- err := setDefaultMulticastSockopts(s)
- if err != nil {
+ if a.IP.IsMulticast() {
+ if err := setDefaultMulticastSockopts(s); err != nil {
return nil, err
}
switch f {
case syscall.AF_INET:
- v.IP = IPv4zero
+ a.IP = IPv4zero
case syscall.AF_INET6:
- v.IP = IPv6unspecified
+ a.IP = IPv6unspecified
}
- return v.sockaddr(f)
+ return a.sockaddr(f)
}
}
return la, nil
}
+
+func sysSocket(f, t, p int) (syscall.Handle, error) {
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err := syscall.Socket(f, t, p)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
+ return s, err
+}
diff --git a/src/pkg/net/sockopt.go b/src/pkg/net/sockopt_posix.go
index 0cd19266f..fe371fe0c 100644
--- a/src/pkg/net/sockopt.go
+++ b/src/pkg/net/sockopt_posix.go
@@ -119,45 +119,22 @@ func setWriteBuffer(fd *netFD, bytes int) error {
return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes))
}
+// TODO(dfc) these unused error returns could be removed
+
func setReadDeadline(fd *netFD, t time.Time) error {
- if t.IsZero() {
- fd.rdeadline = 0
- } else {
- fd.rdeadline = t.UnixNano()
- }
+ fd.rdeadline.setTime(t)
return nil
}
func setWriteDeadline(fd *netFD, t time.Time) error {
- if t.IsZero() {
- fd.wdeadline = 0
- } else {
- fd.wdeadline = t.UnixNano()
- }
+ fd.wdeadline.setTime(t)
return nil
}
func setDeadline(fd *netFD, t time.Time) error {
- if err := setReadDeadline(fd, t); err != nil {
- return err
- }
- return setWriteDeadline(fd, t)
-}
-
-func setReuseAddr(fd *netFD, reuse bool) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse)))
-}
-
-func setDontRoute(fd *netFD, dontroute bool) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute)))
+ setReadDeadline(fd, t)
+ setWriteDeadline(fd, t)
+ return nil
}
func setKeepAlive(fd *netFD, keepalive bool) error {
diff --git a/src/pkg/net/sockoptip.go b/src/pkg/net/sockoptip.go
deleted file mode 100644
index 1fcad4018..000000000
--- a/src/pkg/net/sockoptip.go
+++ /dev/null
@@ -1,219 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build darwin freebsd linux netbsd openbsd windows
-
-// IP-level socket options
-
-package net
-
-import (
- "os"
- "syscall"
-)
-
-func ipv4TOS(fd *netFD) (int, error) {
- if err := fd.incref(false); err != nil {
- return 0, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TOS)
- if err != nil {
- return 0, os.NewSyscallError("getsockopt", err)
- }
- return v, nil
-}
-
-func setIPv4TOS(fd *netFD, v int) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TOS, v)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv4TTL(fd *netFD) (int, error) {
- if err := fd.incref(false); err != nil {
- return 0, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TTL)
- if err != nil {
- return 0, os.NewSyscallError("getsockopt", err)
- }
- return v, nil
-}
-
-func setIPv4TTL(fd *netFD, v int) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TTL, v)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
- mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
- if err := setIPv4MreqToInterface(mreq, ifi); err != nil {
- return err
- }
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq))
-}
-
-func leaveIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
- mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
- if err := setIPv4MreqToInterface(mreq, ifi); err != nil {
- return err
- }
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_DROP_MEMBERSHIP, mreq))
-}
-
-func ipv6HopLimit(fd *netFD) (int, error) {
- if err := fd.incref(false); err != nil {
- return 0, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS)
- if err != nil {
- return 0, os.NewSyscallError("getsockopt", err)
- }
- return v, nil
-}
-
-func setIPv6HopLimit(fd *netFD, v int) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, v)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv6MulticastInterface(fd *netFD) (*Interface, error) {
- if err := fd.incref(false); err != nil {
- return nil, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF)
- if err != nil {
- return nil, os.NewSyscallError("getsockopt", err)
- }
- if v == 0 {
- return nil, nil
- }
- ifi, err := InterfaceByIndex(v)
- if err != nil {
- return nil, err
- }
- return ifi, nil
-}
-
-func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error {
- var v int
- if ifi != nil {
- v = ifi.Index
- }
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv6MulticastHopLimit(fd *netFD) (int, error) {
- if err := fd.incref(false); err != nil {
- return 0, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS)
- if err != nil {
- return 0, os.NewSyscallError("getsockopt", err)
- }
- return v, nil
-}
-
-func setIPv6MulticastHopLimit(fd *netFD, v int) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS, v)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv6MulticastLoopback(fd *netFD) (bool, error) {
- if err := fd.incref(false); err != nil {
- return false, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP)
- if err != nil {
- return false, os.NewSyscallError("getsockopt", err)
- }
- return v == 1, nil
-}
-
-func setIPv6MulticastLoopback(fd *netFD, v bool) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v))
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
- mreq := &syscall.IPv6Mreq{}
- copy(mreq.Multiaddr[:], ip)
- if ifi != nil {
- mreq.Interface = uint32(ifi.Index)
- }
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq))
-}
-
-func leaveIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
- mreq := &syscall.IPv6Mreq{}
- copy(mreq.Multiaddr[:], ip)
- if ifi != nil {
- mreq.Interface = uint32(ifi.Index)
- }
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_LEAVE_GROUP, mreq))
-}
diff --git a/src/pkg/net/sockoptip_bsd.go b/src/pkg/net/sockoptip_bsd.go
index 19e2b142e..263f85521 100644
--- a/src/pkg/net/sockoptip_bsd.go
+++ b/src/pkg/net/sockoptip_bsd.go
@@ -4,8 +4,6 @@
// +build darwin freebsd netbsd openbsd
-// IP-level socket options for BSD variants
-
package net
import (
@@ -13,48 +11,30 @@ import (
"syscall"
)
-func ipv4MulticastTTL(fd *netFD) (int, error) {
- if err := fd.incref(false); err != nil {
- return 0, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL)
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+ ip, err := interfaceToIPv4Addr(ifi)
if err != nil {
- return 0, os.NewSyscallError("getsockopt", err)
+ return os.NewSyscallError("setsockopt", err)
}
- return int(v), nil
-}
-
-func setIPv4MulticastTTL(fd *netFD, v int) error {
+ var a [4]byte
+ copy(a[:], ip.To4())
if err := fd.incref(false); err != nil {
return err
}
defer fd.decref()
- err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, byte(v))
+ err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, a)
if err != nil {
return os.NewSyscallError("setsockopt", err)
}
return nil
}
-func ipv6TrafficClass(fd *netFD) (int, error) {
- if err := fd.incref(false); err != nil {
- return 0, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS)
- if err != nil {
- return 0, os.NewSyscallError("getsockopt", err)
- }
- return v, nil
-}
-
-func setIPv6TrafficClass(fd *netFD, v int) error {
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
if err := fd.incref(false); err != nil {
return err
}
defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, v)
+ err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v)))
if err != nil {
return os.NewSyscallError("setsockopt", err)
}
diff --git a/src/pkg/net/sockoptip_darwin.go b/src/pkg/net/sockoptip_darwin.go
deleted file mode 100644
index 52b237c4b..000000000
--- a/src/pkg/net/sockoptip_darwin.go
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// IP-level socket options for Darwin
-
-package net
-
-import (
- "os"
- "syscall"
-)
-
-func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
- if err := fd.incref(false); err != nil {
- return nil, err
- }
- defer fd.decref()
- a, err := syscall.GetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
- if err != nil {
- return nil, os.NewSyscallError("getsockopt", err)
- }
- return ipv4AddrToInterface(IPv4(a[0], a[1], a[2], a[3]))
-}
-
-func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
- ip, err := interfaceToIPv4Addr(ifi)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- var x [4]byte
- copy(x[:], ip.To4())
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv4MulticastLoopback(fd *netFD) (bool, error) {
- if err := fd.incref(false); err != nil {
- return false, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
- if err != nil {
- return false, os.NewSyscallError("getsockopt", err)
- }
- return v == 1, nil
-}
-
-func setIPv4MulticastLoopback(fd *netFD, v bool) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv4ReceiveInterface(fd *netFD) (bool, error) {
- if err := fd.incref(false); err != nil {
- return false, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF)
- if err != nil {
- return false, os.NewSyscallError("getsockopt", err)
- }
- return v == 1, nil
-}
-
-func setIPv4ReceiveInterface(fd *netFD, v bool) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v))
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
diff --git a/src/pkg/net/sockoptip_freebsd.go b/src/pkg/net/sockoptip_freebsd.go
deleted file mode 100644
index 4a3bc2e82..000000000
--- a/src/pkg/net/sockoptip_freebsd.go
+++ /dev/null
@@ -1,92 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// IP-level socket options for FreeBSD
-
-package net
-
-import (
- "os"
- "syscall"
-)
-
-func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
- if err := fd.incref(false); err != nil {
- return nil, err
- }
- defer fd.decref()
- mreq, err := syscall.GetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
- if err != nil {
- return nil, os.NewSyscallError("getsockopt", err)
- }
- if int(mreq.Ifindex) == 0 {
- return nil, nil
- }
- return InterfaceByIndex(int(mreq.Ifindex))
-}
-
-func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
- var v int32
- if ifi != nil {
- v = int32(ifi.Index)
- }
- mreq := &syscall.IPMreqn{Ifindex: v}
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv4MulticastLoopback(fd *netFD) (bool, error) {
- if err := fd.incref(false); err != nil {
- return false, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
- if err != nil {
- return false, os.NewSyscallError("getsockopt", err)
- }
- return v == 1, nil
-}
-
-func setIPv4MulticastLoopback(fd *netFD, v bool) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv4ReceiveInterface(fd *netFD) (bool, error) {
- if err := fd.incref(false); err != nil {
- return false, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF)
- if err != nil {
- return false, os.NewSyscallError("getsockopt", err)
- }
- return v == 1, nil
-}
-
-func setIPv4ReceiveInterface(fd *netFD, v bool) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v))
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
diff --git a/src/pkg/net/sockoptip_linux.go b/src/pkg/net/sockoptip_linux.go
index 169718f14..225fb0c4c 100644
--- a/src/pkg/net/sockoptip_linux.go
+++ b/src/pkg/net/sockoptip_linux.go
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// IP-level socket options for Linux
-
package net
import (
@@ -11,21 +9,6 @@ import (
"syscall"
)
-func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
- if err := fd.incref(false); err != nil {
- return nil, err
- }
- defer fd.decref()
- mreq, err := syscall.GetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
- if err != nil {
- return nil, os.NewSyscallError("getsockopt", err)
- }
- if int(mreq.Ifindex) == 0 {
- return nil, nil
- }
- return InterfaceByIndex(int(mreq.Ifindex))
-}
-
func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
var v int32
if ifi != nil {
@@ -43,42 +26,6 @@ func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
return nil
}
-func ipv4MulticastTTL(fd *netFD) (int, error) {
- if err := fd.incref(false); err != nil {
- return 0, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL)
- if err != nil {
- return -1, os.NewSyscallError("getsockopt", err)
- }
- return v, nil
-}
-
-func setIPv4MulticastTTL(fd *netFD, v int) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, v)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv4MulticastLoopback(fd *netFD) (bool, error) {
- if err := fd.incref(false); err != nil {
- return false, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
- if err != nil {
- return false, os.NewSyscallError("getsockopt", err)
- }
- return v == 1, nil
-}
-
func setIPv4MulticastLoopback(fd *netFD, v bool) error {
if err := fd.incref(false); err != nil {
return err
@@ -90,51 +37,3 @@ func setIPv4MulticastLoopback(fd *netFD, v bool) error {
}
return nil
}
-
-func ipv4ReceiveInterface(fd *netFD) (bool, error) {
- if err := fd.incref(false); err != nil {
- return false, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_PKTINFO)
- if err != nil {
- return false, os.NewSyscallError("getsockopt", err)
- }
- return v == 1, nil
-}
-
-func setIPv4ReceiveInterface(fd *netFD, v bool) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_PKTINFO, boolint(v))
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv6TrafficClass(fd *netFD) (int, error) {
- if err := fd.incref(false); err != nil {
- return 0, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS)
- if err != nil {
- return 0, os.NewSyscallError("getsockopt", err)
- }
- return v, nil
-}
-
-func setIPv6TrafficClass(fd *netFD, v int) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, v)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
diff --git a/src/pkg/net/sockoptip_netbsd.go b/src/pkg/net/sockoptip_netbsd.go
deleted file mode 100644
index 446d92aa3..000000000
--- a/src/pkg/net/sockoptip_netbsd.go
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2012 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// IP-level socket options for NetBSD
-
-package net
-
-import "syscall"
-
-func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
- // TODO: Implement this
- return nil, syscall.EAFNOSUPPORT
-}
-
-func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
- // TODO: Implement this
- return syscall.EAFNOSUPPORT
-}
-
-func ipv4MulticastLoopback(fd *netFD) (bool, error) {
- // TODO: Implement this
- return false, syscall.EAFNOSUPPORT
-}
-
-func setIPv4MulticastLoopback(fd *netFD, v bool) error {
- // TODO: Implement this
- return syscall.EAFNOSUPPORT
-}
-
-func ipv4ReceiveInterface(fd *netFD) (bool, error) {
- // TODO: Implement this
- return false, syscall.EAFNOSUPPORT
-}
-
-func setIPv4ReceiveInterface(fd *netFD, v bool) error {
- // TODO: Implement this
- return syscall.EAFNOSUPPORT
-}
diff --git a/src/pkg/net/sockoptip_openbsd.go b/src/pkg/net/sockoptip_openbsd.go
deleted file mode 100644
index f3e42f1a9..000000000
--- a/src/pkg/net/sockoptip_openbsd.go
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// IP-level socket options for OpenBSD
-
-package net
-
-import (
- "os"
- "syscall"
-)
-
-func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
- if err := fd.incref(false); err != nil {
- return nil, err
- }
- defer fd.decref()
- a, err := syscall.GetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
- if err != nil {
- return nil, os.NewSyscallError("getsockopt", err)
- }
- return ipv4AddrToInterface(IPv4(a[0], a[1], a[2], a[3]))
-}
-
-func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
- ip, err := interfaceToIPv4Addr(ifi)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- var x [4]byte
- copy(x[:], ip.To4())
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv4MulticastLoopback(fd *netFD) (bool, error) {
- if err := fd.incref(false); err != nil {
- return false, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
- if err != nil {
- return false, os.NewSyscallError("getsockopt", err)
- }
- return v == 1, nil
-}
-
-func setIPv4MulticastLoopback(fd *netFD, v bool) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v)))
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
-
-func ipv4ReceiveInterface(fd *netFD) (bool, error) {
- if err := fd.incref(false); err != nil {
- return false, err
- }
- defer fd.decref()
- v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF)
- if err != nil {
- return false, os.NewSyscallError("getsockopt", err)
- }
- return v == 1, nil
-}
-
-func setIPv4ReceiveInterface(fd *netFD, v bool) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v))
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-}
diff --git a/src/pkg/net/sockoptip_posix.go b/src/pkg/net/sockoptip_posix.go
new file mode 100644
index 000000000..e4c56a0e4
--- /dev/null
+++ b/src/pkg/net/sockoptip_posix.go
@@ -0,0 +1,73 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd windows
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
+ mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
+ if err := setIPv4MreqToInterface(mreq, ifi); err != nil {
+ return err
+ }
+ if err := fd.incref(false); err != nil {
+ return err
+ }
+ defer fd.decref()
+ err := syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error {
+ var v int
+ if ifi != nil {
+ v = ifi.Index
+ }
+ if err := fd.incref(false); err != nil {
+ return err
+ }
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func setIPv6MulticastLoopback(fd *netFD, v bool) error {
+ if err := fd.incref(false); err != nil {
+ return err
+ }
+ defer fd.decref()
+ err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v))
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
+
+func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
+ mreq := &syscall.IPv6Mreq{}
+ copy(mreq.Multiaddr[:], ip)
+ if ifi != nil {
+ mreq.Interface = uint32(ifi.Index)
+ }
+ if err := fd.incref(false); err != nil {
+ return err
+ }
+ defer fd.decref()
+ err := syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ return nil
+}
diff --git a/src/pkg/net/sockoptip_windows.go b/src/pkg/net/sockoptip_windows.go
index b9db3334d..3e248441a 100644
--- a/src/pkg/net/sockoptip_windows.go
+++ b/src/pkg/net/sockoptip_windows.go
@@ -2,90 +2,41 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// IP-level socket options for Windows
-
package net
import (
"os"
"syscall"
+ "unsafe"
)
-func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
- // TODO: Implement this
- return nil, syscall.EWINDOWS
-}
-
func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
ip, err := interfaceToIPv4Addr(ifi)
if err != nil {
return os.NewSyscallError("setsockopt", err)
}
- var x [4]byte
- copy(x[:], ip.To4())
+ var a [4]byte
+ copy(a[:], ip.To4())
if err := fd.incref(false); err != nil {
return err
}
defer fd.decref()
- err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x)
+ err = syscall.Setsockopt(fd.sysfd, int32(syscall.IPPROTO_IP), int32(syscall.IP_MULTICAST_IF), (*byte)(unsafe.Pointer(&a[0])), 4)
if err != nil {
return os.NewSyscallError("setsockopt", err)
}
return nil
}
-func ipv4MulticastTTL(fd *netFD) (int, error) {
- // TODO: Implement this
- return -1, syscall.EWINDOWS
-}
-
-func setIPv4MulticastTTL(fd *netFD, v int) error {
- if err := fd.incref(false); err != nil {
- return err
- }
- defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, v)
- if err != nil {
- return os.NewSyscallError("setsockopt", err)
- }
- return nil
-
-}
-
-func ipv4MulticastLoopback(fd *netFD) (bool, error) {
- // TODO: Implement this
- return false, syscall.EWINDOWS
-}
-
func setIPv4MulticastLoopback(fd *netFD, v bool) error {
if err := fd.incref(false); err != nil {
return err
}
defer fd.decref()
- err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
+ vv := int32(boolint(v))
+ err := syscall.Setsockopt(fd.sysfd, int32(syscall.IPPROTO_IP), int32(syscall.IP_MULTICAST_LOOP), (*byte)(unsafe.Pointer(&vv)), 4)
if err != nil {
return os.NewSyscallError("setsockopt", err)
}
return nil
-
-}
-
-func ipv4ReceiveInterface(fd *netFD) (bool, error) {
- // TODO: Implement this
- return false, syscall.EWINDOWS
-}
-
-func setIPv4ReceiveInterface(fd *netFD, v bool) error {
- // TODO: Implement this
- return syscall.EWINDOWS
-}
-
-func ipv6TrafficClass(fd *netFD) (int, error) {
- // TODO: Implement this
- return 0, syscall.EWINDOWS
-}
-
-func setIPv6TrafficClass(fd *netFD, v int) error {
- // TODO: Implement this
- return syscall.EWINDOWS
}
diff --git a/src/pkg/net/sys_cloexec.go b/src/pkg/net/sys_cloexec.go
new file mode 100644
index 000000000..17e874908
--- /dev/null
+++ b/src/pkg/net/sys_cloexec.go
@@ -0,0 +1,54 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements sysSocket and accept for platforms that do not
+// provide a fast path for setting SetNonblock and CloseOnExec.
+
+// +build darwin freebsd netbsd openbsd
+
+package net
+
+import "syscall"
+
+// Wrapper around the socket system call that marks the returned file
+// descriptor as nonblocking and close-on-exec.
+func sysSocket(f, t, p int) (int, error) {
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err := syscall.Socket(f, t, p)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
+ if err != nil {
+ return -1, err
+ }
+ if err = syscall.SetNonblock(s, true); err != nil {
+ syscall.Close(s)
+ return -1, err
+ }
+ return s, nil
+}
+
+// Wrapper around the accept system call that marks the returned file
+// descriptor as nonblocking and close-on-exec.
+func accept(fd int) (int, syscall.Sockaddr, error) {
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ // It is probably okay to hold the lock across syscall.Accept
+ // because we have put fd.sysfd into non-blocking mode.
+ // However, a call to the File method will put it back into
+ // blocking mode. We can't take that risk, so no use of ForkLock here.
+ nfd, sa, err := syscall.Accept(fd)
+ if err == nil {
+ syscall.CloseOnExec(nfd)
+ }
+ if err != nil {
+ return -1, nil, err
+ }
+ if err = syscall.SetNonblock(nfd, true); err != nil {
+ syscall.Close(nfd)
+ return -1, nil, err
+ }
+ return nfd, sa, nil
+}
diff --git a/src/pkg/net/tcp_test.go b/src/pkg/net/tcp_test.go
new file mode 100644
index 000000000..6c4485a94
--- /dev/null
+++ b/src/pkg/net/tcp_test.go
@@ -0,0 +1,206 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "reflect"
+ "runtime"
+ "testing"
+ "time"
+)
+
+func BenchmarkTCP4OneShot(b *testing.B) {
+ benchmarkTCP(b, false, false, "127.0.0.1:0")
+}
+
+func BenchmarkTCP4OneShotTimeout(b *testing.B) {
+ benchmarkTCP(b, false, true, "127.0.0.1:0")
+}
+
+func BenchmarkTCP4Persistent(b *testing.B) {
+ benchmarkTCP(b, true, false, "127.0.0.1:0")
+}
+
+func BenchmarkTCP4PersistentTimeout(b *testing.B) {
+ benchmarkTCP(b, true, true, "127.0.0.1:0")
+}
+
+func BenchmarkTCP6OneShot(b *testing.B) {
+ if !supportsIPv6 {
+ b.Skip("ipv6 is not supported")
+ }
+ benchmarkTCP(b, false, false, "[::1]:0")
+}
+
+func BenchmarkTCP6OneShotTimeout(b *testing.B) {
+ if !supportsIPv6 {
+ b.Skip("ipv6 is not supported")
+ }
+ benchmarkTCP(b, false, true, "[::1]:0")
+}
+
+func BenchmarkTCP6Persistent(b *testing.B) {
+ if !supportsIPv6 {
+ b.Skip("ipv6 is not supported")
+ }
+ benchmarkTCP(b, true, false, "[::1]:0")
+}
+
+func BenchmarkTCP6PersistentTimeout(b *testing.B) {
+ if !supportsIPv6 {
+ b.Skip("ipv6 is not supported")
+ }
+ benchmarkTCP(b, true, true, "[::1]:0")
+}
+
+func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) {
+ const msgLen = 512
+ conns := b.N
+ numConcurrent := runtime.GOMAXPROCS(-1) * 16
+ msgs := 1
+ if persistent {
+ conns = numConcurrent
+ msgs = b.N / conns
+ if msgs == 0 {
+ msgs = 1
+ }
+ if conns > b.N {
+ conns = b.N
+ }
+ }
+ sendMsg := func(c Conn, buf []byte) bool {
+ n, err := c.Write(buf)
+ if n != len(buf) || err != nil {
+ b.Logf("Write failed: %v", err)
+ return false
+ }
+ return true
+ }
+ recvMsg := func(c Conn, buf []byte) bool {
+ for read := 0; read != len(buf); {
+ n, err := c.Read(buf)
+ read += n
+ if err != nil {
+ b.Logf("Read failed: %v", err)
+ return false
+ }
+ }
+ return true
+ }
+ ln, err := Listen("tcp", laddr)
+ if err != nil {
+ b.Fatalf("Listen failed: %v", err)
+ }
+ defer ln.Close()
+ // Acceptor.
+ go func() {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ break
+ }
+ // Server connection.
+ go func(c Conn) {
+ defer c.Close()
+ if timeout {
+ c.SetDeadline(time.Now().Add(time.Hour)) // Not intended to fire.
+ }
+ var buf [msgLen]byte
+ for m := 0; m < msgs; m++ {
+ if !recvMsg(c, buf[:]) || !sendMsg(c, buf[:]) {
+ break
+ }
+ }
+ }(c)
+ }
+ }()
+ sem := make(chan bool, numConcurrent)
+ for i := 0; i < conns; i++ {
+ sem <- true
+ // Client connection.
+ go func() {
+ defer func() {
+ <-sem
+ }()
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ b.Logf("Dial failed: %v", err)
+ return
+ }
+ defer c.Close()
+ if timeout {
+ c.SetDeadline(time.Now().Add(time.Hour)) // Not intended to fire.
+ }
+ var buf [msgLen]byte
+ for m := 0; m < msgs; m++ {
+ if !sendMsg(c, buf[:]) || !recvMsg(c, buf[:]) {
+ break
+ }
+ }
+ }()
+ }
+ for i := 0; i < cap(sem); i++ {
+ sem <- true
+ }
+}
+
+var resolveTCPAddrTests = []struct {
+ net string
+ litAddr string
+ addr *TCPAddr
+ err error
+}{
+ {"tcp", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil},
+ {"tcp4", "127.0.0.1:65535", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil},
+
+ {"tcp", "[::1]:1", &TCPAddr{IP: ParseIP("::1"), Port: 1}, nil},
+ {"tcp6", "[::1]:65534", &TCPAddr{IP: ParseIP("::1"), Port: 65534}, nil},
+
+ {"", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior
+ {"", "[::1]:0", &TCPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior
+
+ {"http", "127.0.0.1:0", nil, UnknownNetworkError("http")},
+}
+
+func TestResolveTCPAddr(t *testing.T) {
+ for _, tt := range resolveTCPAddrTests {
+ addr, err := ResolveTCPAddr(tt.net, tt.litAddr)
+ if err != tt.err {
+ t.Fatalf("ResolveTCPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err)
+ }
+ if !reflect.DeepEqual(addr, tt.addr) {
+ t.Fatalf("got %#v; expected %#v", addr, tt.addr)
+ }
+ }
+}
+
+var tcpListenerNameTests = []struct {
+ net string
+ laddr *TCPAddr
+}{
+ {"tcp4", &TCPAddr{IP: IPv4(127, 0, 0, 1)}},
+ {"tcp4", &TCPAddr{}},
+ {"tcp4", nil},
+}
+
+func TestTCPListenerName(t *testing.T) {
+ if testing.Short() || !*testExternal {
+ t.Skip("skipping test to avoid external network")
+ }
+
+ for _, tt := range tcpListenerNameTests {
+ ln, err := ListenTCP(tt.net, tt.laddr)
+ if err != nil {
+ t.Errorf("ListenTCP failed: %v", err)
+ return
+ }
+ defer ln.Close()
+ la := ln.Addr()
+ if a, ok := la.(*TCPAddr); !ok || a.Port == 0 {
+ t.Errorf("got %v; expected a proper address with non-zero port number", la)
+ return
+ }
+ }
+}
diff --git a/src/pkg/net/tcpsock.go b/src/pkg/net/tcpsock.go
index 47fbf2919..d5158b22d 100644
--- a/src/pkg/net/tcpsock.go
+++ b/src/pkg/net/tcpsock.go
@@ -10,6 +10,7 @@ package net
type TCPAddr struct {
IP IP
Port int
+ Zone string // IPv6 scoped addressing zone
}
// Network returns the address's network name, "tcp".
@@ -28,9 +29,16 @@ func (a *TCPAddr) String() string {
// "tcp4" or "tcp6". A literal IPv6 host address must be
// enclosed in square brackets, as in "[::]:80".
func ResolveTCPAddr(net, addr string) (*TCPAddr, error) {
- ip, port, err := hostPortToIP(net, addr)
+ switch net {
+ case "tcp", "tcp4", "tcp6":
+ case "": // a hint wildcard for Go 1.0 undocumented behavior
+ net = "tcp"
+ default:
+ return nil, UnknownNetworkError(net)
+ }
+ a, err := resolveInternetAddr(net, addr, noDeadline)
if err != nil {
return nil, err
}
- return &TCPAddr{ip, port}, nil
+ return a.(*TCPAddr), nil
}
diff --git a/src/pkg/net/tcpsock_plan9.go b/src/pkg/net/tcpsock_plan9.go
index 35f56966e..ed3664603 100644
--- a/src/pkg/net/tcpsock_plan9.go
+++ b/src/pkg/net/tcpsock_plan9.go
@@ -2,34 +2,30 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// TCP for Plan 9
+// TCP sockets for Plan 9
package net
import (
+ "io"
+ "os"
"syscall"
"time"
)
-// TCPConn is an implementation of the Conn interface
-// for TCP network connections.
+// TCPConn is an implementation of the Conn interface for TCP network
+// connections.
type TCPConn struct {
- plan9Conn
+ conn
}
-// SetDeadline implements the Conn SetDeadline method.
-func (c *TCPConn) SetDeadline(t time.Time) error {
- return syscall.EPLAN9
+func newTCPConn(fd *netFD) *TCPConn {
+ return &TCPConn{conn{fd}}
}
-// SetReadDeadline implements the Conn SetReadDeadline method.
-func (c *TCPConn) SetReadDeadline(t time.Time) error {
- return syscall.EPLAN9
-}
-
-// SetWriteDeadline implements the Conn SetWriteDeadline method.
-func (c *TCPConn) SetWriteDeadline(t time.Time) error {
- return syscall.EPLAN9
+// ReadFrom implements the io.ReaderFrom ReadFrom method.
+func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
+ return genericReadFrom(c, r)
}
// CloseRead shuts down the reading side of the TCP connection.
@@ -38,7 +34,7 @@ func (c *TCPConn) CloseRead() error {
if !c.ok() {
return syscall.EINVAL
}
- return syscall.EPLAN9
+ return c.fd.CloseRead()
}
// CloseWrite shuts down the writing side of the TCP connection.
@@ -47,51 +43,142 @@ func (c *TCPConn) CloseWrite() error {
if !c.ok() {
return syscall.EINVAL
}
+ return c.fd.CloseWrite()
+}
+
+// SetLinger sets the behavior of Close() on a connection which still
+// has data waiting to be sent or to be acknowledged.
+//
+// If sec < 0 (the default), Close returns immediately and the
+// operating system finishes sending the data in the background.
+//
+// If sec == 0, Close returns immediately and the operating system
+// discards any unsent or unacknowledged data.
+//
+// If sec > 0, Close blocks for at most sec seconds waiting for data
+// to be sent and acknowledged.
+func (c *TCPConn) SetLinger(sec int) error {
+ return syscall.EPLAN9
+}
+
+// SetKeepAlive sets whether the operating system should send
+// keepalive messages on the connection.
+func (c *TCPConn) SetKeepAlive(keepalive bool) error {
+ return syscall.EPLAN9
+}
+
+// SetNoDelay controls whether the operating system should delay
+// packet transmission in hopes of sending fewer packets (Nagle's
+// algorithm). The default is true (no delay), meaning that data is
+// sent as soon as possible after a Write.
+func (c *TCPConn) SetNoDelay(noDelay bool) error {
return syscall.EPLAN9
}
// DialTCP connects to the remote address raddr on the network net,
-// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
-// as the local address for the connection.
-func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err error) {
+// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is
+// used as the local address for the connection.
+func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ return dialTCP(net, laddr, raddr, noDeadline)
+}
+
+func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
+ if !deadline.IsZero() {
+ panic("net.dialTCP: deadline not implemented on Plan 9")
+ }
switch net {
case "tcp", "tcp4", "tcp6":
default:
- return nil, UnknownNetworkError(net)
+ return nil, &OpError{"dial", net, raddr, UnknownNetworkError(net)}
}
if raddr == nil {
return nil, &OpError{"dial", net, nil, errMissingAddress}
}
- c1, err := dialPlan9(net, laddr, raddr)
+ fd, err := dialPlan9(net, laddr, raddr)
if err != nil {
- return
+ return nil, err
}
- return &TCPConn{*c1}, nil
+ return newTCPConn(fd), nil
}
-// TCPListener is a TCP network listener.
-// Clients should typically use variables of type Listener
-// instead of assuming TCP.
+// TCPListener is a TCP network listener. Clients should typically
+// use variables of type Listener instead of assuming TCP.
type TCPListener struct {
- plan9Listener
+ fd *netFD
+}
+
+// AcceptTCP accepts the next incoming call and returns the new
+// connection and the remote address.
+func (l *TCPListener) AcceptTCP() (*TCPConn, error) {
+ if l == nil || l.fd == nil || l.fd.ctl == nil {
+ return nil, syscall.EINVAL
+ }
+ fd, err := l.fd.acceptPlan9()
+ if err != nil {
+ return nil, err
+ }
+ return newTCPConn(fd), nil
+}
+
+// Accept implements the Accept method in the Listener interface; it
+// waits for the next call and returns a generic Conn.
+func (l *TCPListener) Accept() (Conn, error) {
+ if l == nil || l.fd == nil || l.fd.ctl == nil {
+ return nil, syscall.EINVAL
+ }
+ c, err := l.AcceptTCP()
+ if err != nil {
+ return nil, err
+ }
+ return c, nil
}
-// ListenTCP announces on the TCP address laddr and returns a TCP listener.
-// Net must be "tcp", "tcp4", or "tcp6".
-// If laddr has a port of 0, it means to listen on some available port.
-// The caller can use l.Addr() to retrieve the chosen address.
-func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err error) {
+// Close stops listening on the TCP address.
+// Already Accepted connections are not closed.
+func (l *TCPListener) Close() error {
+ if l == nil || l.fd == nil || l.fd.ctl == nil {
+ return syscall.EINVAL
+ }
+ if _, err := l.fd.ctl.WriteString("hangup"); err != nil {
+ l.fd.ctl.Close()
+ return &OpError{"close", l.fd.ctl.Name(), l.fd.laddr, err}
+ }
+ return l.fd.ctl.Close()
+}
+
+// Addr returns the listener's network address, a *TCPAddr.
+func (l *TCPListener) Addr() Addr { return l.fd.laddr }
+
+// SetDeadline sets the deadline associated with the listener.
+// A zero time value disables the deadline.
+func (l *TCPListener) SetDeadline(t time.Time) error {
+ if l == nil || l.fd == nil || l.fd.ctl == nil {
+ return syscall.EINVAL
+ }
+ return setDeadline(l.fd, t)
+}
+
+// File returns a copy of the underlying os.File, set to blocking
+// mode. It is the caller's responsibility to close f when finished.
+// Closing l does not affect f, and closing f does not affect l.
+func (l *TCPListener) File() (f *os.File, err error) { return l.dup() }
+
+// ListenTCP announces on the TCP address laddr and returns a TCP
+// listener. Net must be "tcp", "tcp4", or "tcp6". If laddr has a
+// port of 0, it means to listen on some available port. The caller
+// can use l.Addr() to retrieve the chosen address.
+func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) {
switch net {
case "tcp", "tcp4", "tcp6":
default:
- return nil, UnknownNetworkError(net)
+ return nil, &OpError{"listen", net, laddr, UnknownNetworkError(net)}
}
if laddr == nil {
- return nil, &OpError{"listen", net, nil, errMissingAddress}
+ laddr = &TCPAddr{}
}
- l1, err := listenPlan9(net, laddr)
+ fd, err := listenPlan9(net, laddr)
if err != nil {
- return
+ return nil, err
}
- return &TCPListener{*l1}, nil
+ return &TCPListener{fd}, nil
}
diff --git a/src/pkg/net/tcpsock_posix.go b/src/pkg/net/tcpsock_posix.go
index e6b1937fb..bd5a2a287 100644
--- a/src/pkg/net/tcpsock_posix.go
+++ b/src/pkg/net/tcpsock_posix.go
@@ -23,14 +23,9 @@ import (
func sockaddrToTCP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
- return &TCPAddr{sa.Addr[0:], sa.Port}
+ return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port}
case *syscall.SockaddrInet6:
- return &TCPAddr{sa.Addr[0:], sa.Port}
- default:
- if sa != nil {
- // Diagnose when we will turn a non-nil sockaddr into a nil.
- panic("unexpected type in sockaddrToTCP")
- }
+ return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))}
}
return nil
}
@@ -53,7 +48,7 @@ func (a *TCPAddr) isWildcard() bool {
}
func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
- return ipToSockaddr(family, a.IP, a.Port)
+ return ipToSockaddr(family, a.IP, a.Port, a.Zone)
}
func (a *TCPAddr) toAddr() sockaddr {
@@ -66,27 +61,15 @@ func (a *TCPAddr) toAddr() sockaddr {
// TCPConn is an implementation of the Conn interface
// for TCP network connections.
type TCPConn struct {
- fd *netFD
+ conn
}
func newTCPConn(fd *netFD) *TCPConn {
- c := &TCPConn{fd}
+ c := &TCPConn{conn{fd}}
c.SetNoDelay(true)
return c
}
-func (c *TCPConn) ok() bool { return c != nil && c.fd != nil }
-
-// Implementation of the Conn interface - see Conn for documentation.
-
-// Read implements the Conn Read method.
-func (c *TCPConn) Read(b []byte) (n int, err error) {
- if !c.ok() {
- return 0, syscall.EINVAL
- }
- return c.fd.Read(b)
-}
-
// ReadFrom implements the io.ReaderFrom ReadFrom method.
func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
if n, err, handled := sendFile(c.fd, r); handled {
@@ -95,22 +78,6 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}
-// Write implements the Conn Write method.
-func (c *TCPConn) Write(b []byte) (n int, err error) {
- if !c.ok() {
- return 0, syscall.EINVAL
- }
- return c.fd.Write(b)
-}
-
-// Close closes the TCP connection.
-func (c *TCPConn) Close() error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return c.fd.Close()
-}
-
// CloseRead shuts down the reading side of the TCP connection.
// Most callers should just use Close.
func (c *TCPConn) CloseRead() error {
@@ -129,64 +96,6 @@ func (c *TCPConn) CloseWrite() error {
return c.fd.CloseWrite()
}
-// LocalAddr returns the local network address, a *TCPAddr.
-func (c *TCPConn) LocalAddr() Addr {
- if !c.ok() {
- return nil
- }
- return c.fd.laddr
-}
-
-// RemoteAddr returns the remote network address, a *TCPAddr.
-func (c *TCPConn) RemoteAddr() Addr {
- if !c.ok() {
- return nil
- }
- return c.fd.raddr
-}
-
-// SetDeadline implements the Conn SetDeadline method.
-func (c *TCPConn) SetDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setDeadline(c.fd, t)
-}
-
-// SetReadDeadline implements the Conn SetReadDeadline method.
-func (c *TCPConn) SetReadDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setReadDeadline(c.fd, t)
-}
-
-// SetWriteDeadline implements the Conn SetWriteDeadline method.
-func (c *TCPConn) SetWriteDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setWriteDeadline(c.fd, t)
-}
-
-// SetReadBuffer sets the size of the operating system's
-// receive buffer associated with the connection.
-func (c *TCPConn) SetReadBuffer(bytes int) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setReadBuffer(c.fd, bytes)
-}
-
-// SetWriteBuffer sets the size of the operating system's
-// transmit buffer associated with the connection.
-func (c *TCPConn) SetWriteBuffer(bytes int) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setWriteBuffer(c.fd, bytes)
-}
-
// SetLinger sets the behavior of Close() on a connection
// which still has data waiting to be sent or to be acknowledged.
//
@@ -225,20 +134,23 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error {
return setNoDelay(c.fd, noDelay)
}
-// File returns a copy of the underlying os.File, set to blocking mode.
-// It is the caller's responsibility to close f when finished.
-// Closing c does not affect f, and closing f does not affect c.
-func (c *TCPConn) File() (f *os.File, err error) { return c.fd.dup() }
-
// DialTCP connects to the remote address raddr on the network net,
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
// as the local address for the connection.
func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ switch net {
+ case "tcp", "tcp4", "tcp6":
+ default:
+ return nil, UnknownNetworkError(net)
+ }
if raddr == nil {
return nil, &OpError{"dial", net, nil, errMissingAddress}
}
+ return dialTCP(net, laddr, raddr, noDeadline)
+}
- fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP)
+func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
+ fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP)
// TCP has a rarely used mechanism called a 'simultaneous connection' in
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can
@@ -257,9 +169,18 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
// use the result. See also:
// http://golang.org/issue/2690
// http://stackoverflow.com/questions/4949858/
- for i := 0; i < 2 && err == nil && laddr == nil && selfConnect(fd); i++ {
- fd.Close()
- fd, err = internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP)
+ //
+ // The opposite can also happen: if we ask the kernel to pick an appropriate
+ // originating local address, sometimes it picks one that is already in use.
+ // So if the error is EADDRNOTAVAIL, we have to try again too, just for
+ // a different reason.
+ //
+ // The kernel socket code is no doubt enjoying watching us squirm.
+ for i := 0; i < 2 && (laddr == nil || laddr.Port == 0) && (selfConnect(fd, err) || spuriousENOTAVAIL(err)); i++ {
+ if err == nil {
+ fd.Close()
+ }
+ fd, err = internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP)
}
if err != nil {
@@ -268,7 +189,12 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
return newTCPConn(fd), nil
}
-func selfConnect(fd *netFD) bool {
+func selfConnect(fd *netFD, err error) bool {
+ // If the connect failed, we clearly didn't connect to ourselves.
+ if err != nil {
+ return false
+ }
+
// The socket constructor can return an fd with raddr nil under certain
// unknown conditions. The errors in the calls there to Getpeername
// are discarded, but we can't catch the problem there because those
@@ -285,6 +211,11 @@ func selfConnect(fd *netFD) bool {
return l.Port == r.Port && l.IP.Equal(r.IP)
}
+func spuriousENOTAVAIL(err error) bool {
+ e, ok := err.(*OpError)
+ return ok && e.Err == syscall.EADDRNOTAVAIL
+}
+
// TCPListener is a TCP network listener.
// Clients should typically use variables of type Listener
// instead of assuming TCP.
@@ -292,29 +223,10 @@ type TCPListener struct {
fd *netFD
}
-// ListenTCP announces on the TCP address laddr and returns a TCP listener.
-// Net must be "tcp", "tcp4", or "tcp6".
-// If laddr has a port of 0, it means to listen on some available port.
-// The caller can use l.Addr() to retrieve the chosen address.
-func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) {
- fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_STREAM, 0, "listen", sockaddrToTCP)
- if err != nil {
- return nil, err
- }
- err = syscall.Listen(fd.sysfd, listenerBacklog)
- if err != nil {
- closesocket(fd.sysfd)
- return nil, &OpError{"listen", net, laddr, err}
- }
- l := new(TCPListener)
- l.fd = fd
- return l, nil
-}
-
// AcceptTCP accepts the next incoming call and returns the new connection
// and the remote address.
func (l *TCPListener) AcceptTCP() (c *TCPConn, err error) {
- if l == nil || l.fd == nil || l.fd.sysfd < 0 {
+ if l == nil || l.fd == nil {
return nil, syscall.EINVAL
}
fd, err := l.fd.accept(sockaddrToTCP)
@@ -359,3 +271,28 @@ func (l *TCPListener) SetDeadline(t time.Time) error {
// It is the caller's responsibility to close f when finished.
// Closing l does not affect f, and closing f does not affect l.
func (l *TCPListener) File() (f *os.File, err error) { return l.fd.dup() }
+
+// ListenTCP announces on the TCP address laddr and returns a TCP listener.
+// Net must be "tcp", "tcp4", or "tcp6".
+// If laddr has a port of 0, it means to listen on some available port.
+// The caller can use l.Addr() to retrieve the chosen address.
+func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) {
+ switch net {
+ case "tcp", "tcp4", "tcp6":
+ default:
+ return nil, UnknownNetworkError(net)
+ }
+ if laddr == nil {
+ laddr = &TCPAddr{}
+ }
+ fd, err := internetSocket(net, laddr.toAddr(), nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", sockaddrToTCP)
+ if err != nil {
+ return nil, err
+ }
+ err = syscall.Listen(fd.sysfd, listenerBacklog)
+ if err != nil {
+ closesocket(fd.sysfd)
+ return nil, &OpError{"listen", net, laddr, err}
+ }
+ return &TCPListener{fd}, nil
+}
diff --git a/src/pkg/net/textproto/reader.go b/src/pkg/net/textproto/reader.go
index 125feb3e8..b61bea862 100644
--- a/src/pkg/net/textproto/reader.go
+++ b/src/pkg/net/textproto/reader.go
@@ -128,6 +128,17 @@ func (r *Reader) readContinuedLineSlice() ([]byte, error) {
return line, nil
}
+ // Optimistically assume that we have started to buffer the next line
+ // and it starts with an ASCII letter (the next header key), so we can
+ // avoid copying that buffered data around in memory and skipping over
+ // non-existent whitespace.
+ if r.R.Buffered() > 1 {
+ peek, err := r.R.Peek(1)
+ if err == nil && isASCIILetter(peek[0]) {
+ return trim(line), nil
+ }
+ }
+
// ReadByte or the next readLineSlice will flush the read buffer;
// copy the slice into buf.
r.buf = append(r.buf[:0], trim(line)...)
@@ -445,23 +456,25 @@ func (r *Reader) ReadDotLines() ([]string, error) {
// }
//
func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
- m := make(MIMEHeader)
+ m := make(MIMEHeader, 4)
for {
kv, err := r.readContinuedLineSlice()
if len(kv) == 0 {
return m, err
}
- // Key ends at first colon; must not have spaces.
+ // Key ends at first colon; should not have spaces but
+ // they appear in the wild, violating specs, so we
+ // remove them if present.
i := bytes.IndexByte(kv, ':')
if i < 0 {
return m, ProtocolError("malformed MIME header line: " + string(kv))
}
- key := string(kv[0:i])
- if strings.Index(key, " ") >= 0 {
- key = strings.TrimRight(key, " ")
+ endKey := i
+ for endKey > 0 && kv[endKey-1] == ' ' {
+ endKey--
}
- key = CanonicalMIMEHeaderKey(key)
+ key := canonicalMIMEHeaderKey(kv[:endKey])
// Skip initial spaces in value.
i++ // skip colon
@@ -484,41 +497,107 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
// letter and any letter following a hyphen to upper case;
// the rest are converted to lowercase. For example, the
// canonical key for "accept-encoding" is "Accept-Encoding".
+// MIME header keys are assumed to be ASCII only.
func CanonicalMIMEHeaderKey(s string) string {
// Quick check for canonical encoding.
- needUpper := true
+ upper := true
for i := 0; i < len(s); i++ {
c := s[i]
- if needUpper && 'a' <= c && c <= 'z' {
- goto MustRewrite
+ if upper && 'a' <= c && c <= 'z' {
+ return canonicalMIMEHeaderKey([]byte(s))
}
- if !needUpper && 'A' <= c && c <= 'Z' {
- goto MustRewrite
+ if !upper && 'A' <= c && c <= 'Z' {
+ return canonicalMIMEHeaderKey([]byte(s))
}
- needUpper = c == '-'
+ upper = c == '-'
}
return s
+}
+
+const toLower = 'a' - 'A'
-MustRewrite:
- // Canonicalize: first letter upper case
- // and upper case after each dash.
- // (Host, User-Agent, If-Modified-Since).
- // MIME headers are ASCII only, so no Unicode issues.
- a := []byte(s)
+// canonicalMIMEHeaderKey is like CanonicalMIMEHeaderKey but is
+// allowed to mutate the provided byte slice before returning the
+// string.
+func canonicalMIMEHeaderKey(a []byte) string {
+ // Look for it in commonHeaders , so that we can avoid an
+ // allocation by sharing the strings among all users
+ // of textproto. If we don't find it, a has been canonicalized
+ // so just return string(a).
upper := true
- for i, v := range a {
- if v == ' ' {
+ lo := 0
+ hi := len(commonHeaders)
+ for i := 0; i < len(a); i++ {
+ // Canonicalize: first letter upper case
+ // and upper case after each dash.
+ // (Host, User-Agent, If-Modified-Since).
+ // MIME headers are ASCII only, so no Unicode issues.
+ if a[i] == ' ' {
a[i] = '-'
upper = true
continue
}
- if upper && 'a' <= v && v <= 'z' {
- a[i] = v + 'A' - 'a'
+ c := a[i]
+ if upper && 'a' <= c && c <= 'z' {
+ c -= toLower
+ } else if !upper && 'A' <= c && c <= 'Z' {
+ c += toLower
}
- if !upper && 'A' <= v && v <= 'Z' {
- a[i] = v + 'a' - 'A'
+ a[i] = c
+ upper = c == '-' // for next time
+
+ if lo < hi {
+ for lo < hi && (len(commonHeaders[lo]) <= i || commonHeaders[lo][i] < c) {
+ lo++
+ }
+ for hi > lo && commonHeaders[hi-1][i] > c {
+ hi--
+ }
}
- upper = v == '-'
+ }
+ if lo < hi && len(commonHeaders[lo]) == len(a) {
+ return commonHeaders[lo]
}
return string(a)
}
+
+var commonHeaders = []string{
+ "Accept",
+ "Accept-Charset",
+ "Accept-Encoding",
+ "Accept-Language",
+ "Accept-Ranges",
+ "Cache-Control",
+ "Cc",
+ "Connection",
+ "Content-Id",
+ "Content-Language",
+ "Content-Length",
+ "Content-Transfer-Encoding",
+ "Content-Type",
+ "Date",
+ "Dkim-Signature",
+ "Etag",
+ "Expires",
+ "From",
+ "Host",
+ "If-Modified-Since",
+ "If-None-Match",
+ "In-Reply-To",
+ "Last-Modified",
+ "Location",
+ "Message-Id",
+ "Mime-Version",
+ "Pragma",
+ "Received",
+ "Return-Path",
+ "Server",
+ "Set-Cookie",
+ "Subject",
+ "To",
+ "User-Agent",
+ "Via",
+ "X-Forwarded-For",
+ "X-Imforwards",
+ "X-Powered-By",
+}
diff --git a/src/pkg/net/textproto/reader_test.go b/src/pkg/net/textproto/reader_test.go
index 7c5d16227..26987f611 100644
--- a/src/pkg/net/textproto/reader_test.go
+++ b/src/pkg/net/textproto/reader_test.go
@@ -6,6 +6,7 @@ package textproto
import (
"bufio"
+ "bytes"
"io"
"reflect"
"strings"
@@ -23,6 +24,7 @@ var canonicalHeaderKeyTests = []canonicalHeaderKeyTest{
{"uSER-aGENT", "User-Agent"},
{"user-agent", "User-Agent"},
{"USER-AGENT", "User-Agent"},
+ {"üser-agenT", "üser-Agent"}, // non-ASCII unchanged
}
func TestCanonicalMIMEHeaderKey(t *testing.T) {
@@ -239,3 +241,95 @@ func TestRFC959Lines(t *testing.T) {
}
}
}
+
+func TestCommonHeaders(t *testing.T) {
+ // need to disable the commonHeaders-based optimization
+ // during this check, or we'd not be testing anything
+ oldch := commonHeaders
+ commonHeaders = []string{}
+ defer func() { commonHeaders = oldch }()
+
+ last := ""
+ for _, h := range oldch {
+ if last > h {
+ t.Errorf("%v is out of order", h)
+ }
+ if last == h {
+ t.Errorf("%v is duplicated", h)
+ }
+ if canon := CanonicalMIMEHeaderKey(h); h != canon {
+ t.Errorf("%v is not canonical", h)
+ }
+ last = h
+ }
+}
+
+var clientHeaders = strings.Replace(`Host: golang.org
+Connection: keep-alive
+Cache-Control: max-age=0
+Accept: application/xml,application/xhtml+xml,text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5
+User-Agent: Mozilla/5.0 (X11; U; Linux x86_64; en-US) AppleWebKit/534.3 (KHTML, like Gecko) Chrome/6.0.472.63 Safari/534.3
+Accept-Encoding: gzip,deflate,sdch
+Accept-Language: en-US,en;q=0.8,fr-CH;q=0.6
+Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
+COOKIE: __utma=000000000.0000000000.0000000000.0000000000.0000000000.00; __utmb=000000000.0.00.0000000000; __utmc=000000000; __utmz=000000000.0000000000.00.0.utmcsr=code.google.com|utmccn=(referral)|utmcmd=referral|utmcct=/p/go/issues/detail
+Non-Interned: test
+
+`, "\n", "\r\n", -1)
+
+var serverHeaders = strings.Replace(`Content-Type: text/html; charset=utf-8
+Content-Encoding: gzip
+Date: Thu, 27 Sep 2012 09:03:33 GMT
+Server: Google Frontend
+Cache-Control: private
+Content-Length: 2298
+VIA: 1.1 proxy.example.com:80 (XXX/n.n.n-nnn)
+Connection: Close
+Non-Interned: test
+
+`, "\n", "\r\n", -1)
+
+func BenchmarkReadMIMEHeader(b *testing.B) {
+ var buf bytes.Buffer
+ br := bufio.NewReader(&buf)
+ r := NewReader(br)
+ for i := 0; i < b.N; i++ {
+ var want int
+ var find string
+ if (i & 1) == 1 {
+ buf.WriteString(clientHeaders)
+ want = 10
+ find = "Cookie"
+ } else {
+ buf.WriteString(serverHeaders)
+ want = 9
+ find = "Via"
+ }
+ h, err := r.ReadMIMEHeader()
+ if err != nil {
+ b.Fatal(err)
+ }
+ if len(h) != want {
+ b.Fatalf("wrong number of headers: got %d, want %d", len(h), want)
+ }
+ if _, ok := h[find]; !ok {
+ b.Fatalf("did not find key %s", find)
+ }
+ }
+}
+
+func BenchmarkUncommon(b *testing.B) {
+ var buf bytes.Buffer
+ br := bufio.NewReader(&buf)
+ r := NewReader(br)
+ for i := 0; i < b.N; i++ {
+ buf.WriteString("uncommon-header-for-benchmark: foo\r\n\r\n")
+ h, err := r.ReadMIMEHeader()
+ if err != nil {
+ b.Fatal(err)
+ }
+ if _, ok := h["Uncommon-Header-For-Benchmark"]; !ok {
+ b.Fatal("Missing result header.")
+ }
+ }
+}
diff --git a/src/pkg/net/textproto/textproto.go b/src/pkg/net/textproto/textproto.go
index ad5840cf7..eb6ced1c5 100644
--- a/src/pkg/net/textproto/textproto.go
+++ b/src/pkg/net/textproto/textproto.go
@@ -121,3 +121,34 @@ func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err error) {
}
return id, nil
}
+
+// TrimString returns s without leading and trailing ASCII space.
+func TrimString(s string) string {
+ for len(s) > 0 && isASCIISpace(s[0]) {
+ s = s[1:]
+ }
+ for len(s) > 0 && isASCIISpace(s[len(s)-1]) {
+ s = s[:len(s)-1]
+ }
+ return s
+}
+
+// TrimBytes returns b without leading and trailing ASCII space.
+func TrimBytes(b []byte) []byte {
+ for len(b) > 0 && isASCIISpace(b[0]) {
+ b = b[1:]
+ }
+ for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+ b = b[:len(b)-1]
+ }
+ return b
+}
+
+func isASCIISpace(b byte) bool {
+ return b == ' ' || b == '\t' || b == '\n' || b == '\r'
+}
+
+func isASCIILetter(b byte) bool {
+ b |= 0x20 // make lower case
+ return 'a' <= b && b <= 'z'
+}
diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go
index 672fb7241..0260efcc0 100644
--- a/src/pkg/net/timeout_test.go
+++ b/src/pkg/net/timeout_test.go
@@ -6,11 +6,187 @@ package net
import (
"fmt"
+ "io"
+ "io/ioutil"
"runtime"
"testing"
"time"
)
+func isTimeout(err error) bool {
+ e, ok := err.(Error)
+ return ok && e.Timeout()
+}
+
+type copyRes struct {
+ n int64
+ err error
+ d time.Duration
+}
+
+func TestAcceptTimeout(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t).(*TCPListener)
+ defer ln.Close()
+ ln.SetDeadline(time.Now().Add(-1 * time.Second))
+ if _, err := ln.Accept(); !isTimeout(err) {
+ t.Fatalf("Accept: expected err %v, got %v", errTimeout, err)
+ }
+ if _, err := ln.Accept(); !isTimeout(err) {
+ t.Fatalf("Accept: expected err %v, got %v", errTimeout, err)
+ }
+ ln.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ if _, err := ln.Accept(); !isTimeout(err) {
+ t.Fatalf("Accept: expected err %v, got %v", errTimeout, err)
+ }
+ if _, err := ln.Accept(); !isTimeout(err) {
+ t.Fatalf("Accept: expected err %v, got %v", errTimeout, err)
+ }
+ ln.SetDeadline(noDeadline)
+ errc := make(chan error)
+ go func() {
+ _, err := ln.Accept()
+ errc <- err
+ }()
+ time.Sleep(100 * time.Millisecond)
+ select {
+ case err := <-errc:
+ t.Fatalf("Expected Accept() to not return, but it returned with %v\n", err)
+ default:
+ }
+ ln.Close()
+ switch nerr := <-errc; err := nerr.(type) {
+ case *OpError:
+ if err.Err != errClosing {
+ t.Fatalf("Accept: expected err %v, got %v", errClosing, err)
+ }
+ default:
+ if err != errClosing {
+ t.Fatalf("Accept: expected err %v, got %v", errClosing, err)
+ }
+ }
+}
+
+func TestReadTimeout(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+ c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr))
+ if err != nil {
+ t.Fatalf("Connect: %v", err)
+ }
+ defer c.Close()
+ c.SetDeadline(time.Now().Add(time.Hour))
+ c.SetReadDeadline(time.Now().Add(-1 * time.Second))
+ buf := make([]byte, 1)
+ if _, err = c.Read(buf); !isTimeout(err) {
+ t.Fatalf("Read: expected err %v, got %v", errTimeout, err)
+ }
+ if _, err = c.Read(buf); !isTimeout(err) {
+ t.Fatalf("Read: expected err %v, got %v", errTimeout, err)
+ }
+ c.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ if _, err = c.Read(buf); !isTimeout(err) {
+ t.Fatalf("Read: expected err %v, got %v", errTimeout, err)
+ }
+ if _, err = c.Read(buf); !isTimeout(err) {
+ t.Fatalf("Read: expected err %v, got %v", errTimeout, err)
+ }
+ c.SetReadDeadline(noDeadline)
+ c.SetWriteDeadline(time.Now().Add(-1 * time.Second))
+ errc := make(chan error)
+ go func() {
+ _, err := c.Read(buf)
+ errc <- err
+ }()
+ time.Sleep(100 * time.Millisecond)
+ select {
+ case err := <-errc:
+ t.Fatalf("Expected Read() to not return, but it returned with %v\n", err)
+ default:
+ }
+ c.Close()
+ switch nerr := <-errc; err := nerr.(type) {
+ case *OpError:
+ if err.Err != errClosing {
+ t.Fatalf("Read: expected err %v, got %v", errClosing, err)
+ }
+ default:
+ if err != errClosing {
+ t.Fatalf("Read: expected err %v, got %v", errClosing, err)
+ }
+ }
+}
+
+func TestWriteTimeout(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+ c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr))
+ if err != nil {
+ t.Fatalf("Connect: %v", err)
+ }
+ defer c.Close()
+ c.SetDeadline(time.Now().Add(time.Hour))
+ c.SetWriteDeadline(time.Now().Add(-1 * time.Second))
+ buf := make([]byte, 4096)
+ writeUntilTimeout := func() {
+ for {
+ _, err := c.Write(buf)
+ if err != nil {
+ if isTimeout(err) {
+ return
+ }
+ t.Fatalf("Write: expected err %v, got %v", errTimeout, err)
+ }
+ }
+ }
+ writeUntilTimeout()
+ c.SetDeadline(time.Now().Add(10 * time.Millisecond))
+ writeUntilTimeout()
+ writeUntilTimeout()
+ c.SetWriteDeadline(noDeadline)
+ c.SetReadDeadline(time.Now().Add(-1 * time.Second))
+ errc := make(chan error)
+ go func() {
+ for {
+ _, err := c.Write(buf)
+ if err != nil {
+ errc <- err
+ }
+ }
+ }()
+ time.Sleep(100 * time.Millisecond)
+ select {
+ case err := <-errc:
+ t.Fatalf("Expected Write() to not return, but it returned with %v\n", err)
+ default:
+ }
+ c.Close()
+ switch nerr := <-errc; err := nerr.(type) {
+ case *OpError:
+ if err.Err != errClosing {
+ t.Fatalf("Write: expected err %v, got %v", errClosing, err)
+ }
+ default:
+ if err != errClosing {
+ t.Fatalf("Write: expected err %v, got %v", errClosing, err)
+ }
+ }
+}
+
func testTimeout(t *testing.T, net, addr string, readFrom bool) {
c, err := Dial(net, addr)
if err != nil {
@@ -59,8 +235,7 @@ func testTimeout(t *testing.T, net, addr string, readFrom bool) {
func TestTimeoutUDP(t *testing.T) {
switch runtime.GOOS {
case "plan9":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
// set up a listener that won't talk back
@@ -77,8 +252,7 @@ func TestTimeoutUDP(t *testing.T) {
func TestTimeoutTCP(t *testing.T) {
switch runtime.GOOS {
case "plan9":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
// set up a listener that won't talk back
@@ -94,8 +268,7 @@ func TestTimeoutTCP(t *testing.T) {
func TestDeadlineReset(t *testing.T) {
switch runtime.GOOS {
case "plan9":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
ln, err := Listen("tcp", "127.0.0.1:0")
if err != nil {
@@ -104,7 +277,7 @@ func TestDeadlineReset(t *testing.T) {
defer ln.Close()
tl := ln.(*TCPListener)
tl.SetDeadline(time.Now().Add(1 * time.Minute))
- tl.SetDeadline(time.Time{}) // reset it
+ tl.SetDeadline(noDeadline) // reset it
errc := make(chan error, 1)
go func() {
_, err := ln.Accept()
@@ -119,3 +292,356 @@ func TestDeadlineReset(t *testing.T) {
t.Errorf("unexpected return from Accept; err=%v", err)
}
}
+
+func TestTimeoutAccept(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+ ln, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ tl := ln.(*TCPListener)
+ tl.SetDeadline(time.Now().Add(100 * time.Millisecond))
+ errc := make(chan error, 1)
+ go func() {
+ _, err := ln.Accept()
+ errc <- err
+ }()
+ select {
+ case <-time.After(1 * time.Second):
+ // Accept shouldn't block indefinitely
+ t.Errorf("Accept didn't return in an expected time")
+ case <-errc:
+ // Pass.
+ }
+}
+
+func TestReadWriteDeadline(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ if !canCancelIO {
+ t.Skip("skipping test on this system")
+ }
+ const (
+ readTimeout = 50 * time.Millisecond
+ writeTimeout = 250 * time.Millisecond
+ )
+ checkTimeout := func(command string, start time.Time, should time.Duration) {
+ is := time.Now().Sub(start)
+ d := is - should
+ if d < -30*time.Millisecond || !testing.Short() && 150*time.Millisecond < d {
+ t.Errorf("%s timeout test failed: is=%v should=%v\n", command, is, should)
+ }
+ }
+
+ ln, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("ListenTCP on :0: %v", err)
+ }
+ defer ln.Close()
+
+ lnquit := make(chan bool)
+
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("Accept: %v", err)
+ }
+ defer c.Close()
+ lnquit <- true
+ }()
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer c.Close()
+
+ start := time.Now()
+ err = c.SetReadDeadline(start.Add(readTimeout))
+ if err != nil {
+ t.Fatalf("SetReadDeadline: %v", err)
+ }
+ err = c.SetWriteDeadline(start.Add(writeTimeout))
+ if err != nil {
+ t.Fatalf("SetWriteDeadline: %v", err)
+ }
+
+ quit := make(chan bool)
+
+ go func() {
+ var buf [10]byte
+ _, err := c.Read(buf[:])
+ if err == nil {
+ t.Errorf("Read should not succeed")
+ }
+ checkTimeout("Read", start, readTimeout)
+ quit <- true
+ }()
+
+ go func() {
+ var buf [10000]byte
+ for {
+ _, err := c.Write(buf[:])
+ if err != nil {
+ break
+ }
+ }
+ checkTimeout("Write", start, writeTimeout)
+ quit <- true
+ }()
+
+ <-quit
+ <-quit
+ <-lnquit
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+func TestVariousDeadlines1Proc(t *testing.T) {
+ testVariousDeadlines(t, 1)
+}
+
+func TestVariousDeadlines4Proc(t *testing.T) {
+ testVariousDeadlines(t, 4)
+}
+
+func testVariousDeadlines(t *testing.T, maxProcs int) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+ ln := newLocalListener(t)
+ defer ln.Close()
+ acceptc := make(chan error, 1)
+
+ // The server, with no timeouts of its own, sending bytes to clients
+ // as fast as it can.
+ servec := make(chan copyRes)
+ go func() {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ acceptc <- err
+ return
+ }
+ go func() {
+ t0 := time.Now()
+ n, err := io.Copy(c, neverEnding('a'))
+ d := time.Since(t0)
+ c.Close()
+ servec <- copyRes{n, err, d}
+ }()
+ }
+ }()
+
+ for _, timeout := range []time.Duration{
+ 1 * time.Nanosecond,
+ 2 * time.Nanosecond,
+ 5 * time.Nanosecond,
+ 50 * time.Nanosecond,
+ 100 * time.Nanosecond,
+ 200 * time.Nanosecond,
+ 500 * time.Nanosecond,
+ 750 * time.Nanosecond,
+ 1 * time.Microsecond,
+ 5 * time.Microsecond,
+ 25 * time.Microsecond,
+ 250 * time.Microsecond,
+ 500 * time.Microsecond,
+ 1 * time.Millisecond,
+ 5 * time.Millisecond,
+ 100 * time.Millisecond,
+ 250 * time.Millisecond,
+ 500 * time.Millisecond,
+ 1 * time.Second,
+ } {
+ numRuns := 3
+ if testing.Short() {
+ numRuns = 1
+ if timeout > 500*time.Microsecond {
+ continue
+ }
+ }
+ for run := 0; run < numRuns; run++ {
+ name := fmt.Sprintf("%v run %d/%d", timeout, run+1, numRuns)
+ t.Log(name)
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ clientc := make(chan copyRes)
+ go func() {
+ t0 := time.Now()
+ c.SetDeadline(t0.Add(timeout))
+ n, err := io.Copy(ioutil.Discard, c)
+ d := time.Since(t0)
+ c.Close()
+ clientc <- copyRes{n, err, d}
+ }()
+
+ const tooLong = 2000 * time.Millisecond
+ select {
+ case res := <-clientc:
+ if isTimeout(res.err) {
+ t.Logf("for %v, good client timeout after %v, reading %d bytes", name, res.d, res.n)
+ } else {
+ t.Fatalf("for %v: client Copy = %d, %v (want timeout)", name, res.n, res.err)
+ }
+ case <-time.After(tooLong):
+ t.Fatalf("for %v: timeout (%v) waiting for client to timeout (%v) reading", name, tooLong, timeout)
+ }
+
+ select {
+ case res := <-servec:
+ t.Logf("for %v: server in %v wrote %d, %v", name, res.d, res.n, res.err)
+ case err := <-acceptc:
+ t.Fatalf("for %v: server Accept = %v", name, err)
+ case <-time.After(tooLong):
+ t.Fatalf("for %v, timeout waiting for server to finish writing", name)
+ }
+ }
+ }
+}
+
+// TestReadDeadlineDataAvailable tests that read deadlines work, even
+// if there's data ready to be read.
+func TestReadDeadlineDataAvailable(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ servec := make(chan copyRes)
+ const msg = "data client shouldn't read, even though it it'll be waiting"
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("Accept: %v", err)
+ }
+ defer c.Close()
+ n, err := c.Write([]byte(msg))
+ servec <- copyRes{n: int64(n), err: err}
+ }()
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer c.Close()
+ if res := <-servec; res.err != nil || res.n != int64(len(msg)) {
+ t.Fatalf("unexpected server Write: n=%d, err=%d; want n=%d, err=nil", res.n, res.err, len(msg))
+ }
+ c.SetReadDeadline(time.Now().Add(-5 * time.Second)) // in the psat.
+ buf := make([]byte, len(msg)/2)
+ n, err := c.Read(buf)
+ if n > 0 || !isTimeout(err) {
+ t.Fatalf("client read = %d (%q) err=%v; want 0, timeout", n, buf[:n], err)
+ }
+}
+
+// TestWriteDeadlineBufferAvailable tests that write deadlines work, even
+// if there's buffer space available to write.
+func TestWriteDeadlineBufferAvailable(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ servec := make(chan copyRes)
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("Accept: %v", err)
+ }
+ defer c.Close()
+ c.SetWriteDeadline(time.Now().Add(-5 * time.Second)) // in the past
+ n, err := c.Write([]byte{'x'})
+ servec <- copyRes{n: int64(n), err: err}
+ }()
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer c.Close()
+ res := <-servec
+ if res.n != 0 {
+ t.Errorf("Write = %d; want 0", res.n)
+ }
+ if !isTimeout(res.err) {
+ t.Errorf("Write error = %v; want timeout", res.err)
+ }
+}
+
+// TestProlongTimeout tests concurrent deadline modification.
+// Known to cause data races in the past.
+func TestProlongTimeout(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+ connected := make(chan bool)
+ go func() {
+ s, err := ln.Accept()
+ connected <- true
+ if err != nil {
+ t.Fatalf("ln.Accept: %v", err)
+ }
+ defer s.Close()
+ s.SetDeadline(time.Now().Add(time.Hour))
+ go func() {
+ var buf [4096]byte
+ for {
+ _, err := s.Write(buf[:])
+ if err != nil {
+ break
+ }
+ s.SetDeadline(time.Now().Add(time.Hour))
+ }
+ }()
+ buf := make([]byte, 1)
+ for {
+ _, err := s.Read(buf)
+ if err != nil {
+ break
+ }
+ s.SetDeadline(time.Now().Add(time.Hour))
+ }
+ }()
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatalf("DialTCP: %v", err)
+ }
+ defer c.Close()
+ <-connected
+ for i := 0; i < 1024; i++ {
+ var buf [1]byte
+ c.Write(buf[:])
+ }
+}
diff --git a/src/pkg/net/udp_test.go b/src/pkg/net/udp_test.go
index f80d3b5a9..220422e13 100644
--- a/src/pkg/net/udp_test.go
+++ b/src/pkg/net/udp_test.go
@@ -5,15 +5,45 @@
package net
import (
+ "reflect"
"runtime"
"testing"
)
+var resolveUDPAddrTests = []struct {
+ net string
+ litAddr string
+ addr *UDPAddr
+ err error
+}{
+ {"udp", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil},
+ {"udp4", "127.0.0.1:65535", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil},
+
+ {"udp", "[::1]:1", &UDPAddr{IP: ParseIP("::1"), Port: 1}, nil},
+ {"udp6", "[::1]:65534", &UDPAddr{IP: ParseIP("::1"), Port: 65534}, nil},
+
+ {"", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior
+ {"", "[::1]:0", &UDPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior
+
+ {"sip", "127.0.0.1:0", nil, UnknownNetworkError("sip")},
+}
+
+func TestResolveUDPAddr(t *testing.T) {
+ for _, tt := range resolveUDPAddrTests {
+ addr, err := ResolveUDPAddr(tt.net, tt.litAddr)
+ if err != tt.err {
+ t.Fatalf("ResolveUDPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err)
+ }
+ if !reflect.DeepEqual(addr, tt.addr) {
+ t.Fatalf("got %#v; expected %#v", addr, tt.addr)
+ }
+ }
+}
+
func TestWriteToUDP(t *testing.T) {
switch runtime.GOOS {
case "plan9":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
l, err := ListenPacket("udp", "127.0.0.1:0")
@@ -87,3 +117,32 @@ func testWriteToPacketConn(t *testing.T, raddr string) {
t.Fatal("Write should fail")
}
}
+
+var udpConnLocalNameTests = []struct {
+ net string
+ laddr *UDPAddr
+}{
+ {"udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)}},
+ {"udp4", &UDPAddr{}},
+ {"udp4", nil},
+}
+
+func TestUDPConnLocalName(t *testing.T) {
+ if testing.Short() || !*testExternal {
+ t.Skip("skipping test to avoid external network")
+ }
+
+ for _, tt := range udpConnLocalNameTests {
+ c, err := ListenUDP(tt.net, tt.laddr)
+ if err != nil {
+ t.Errorf("ListenUDP failed: %v", err)
+ return
+ }
+ defer c.Close()
+ la := c.LocalAddr()
+ if a, ok := la.(*UDPAddr); !ok || a.Port == 0 {
+ t.Errorf("got %v; expected a proper address with non-zero port number", la)
+ return
+ }
+ }
+}
diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go
index b3520cf09..6e5e90268 100644
--- a/src/pkg/net/udpsock.go
+++ b/src/pkg/net/udpsock.go
@@ -6,10 +6,15 @@
package net
+import "errors"
+
+var ErrWriteToConnected = errors.New("use of WriteTo with pre-connected UDP")
+
// UDPAddr represents the address of a UDP end point.
type UDPAddr struct {
IP IP
Port int
+ Zone string // IPv6 scoped addressing zone
}
// Network returns the address's network name, "udp".
@@ -28,9 +33,16 @@ func (a *UDPAddr) String() string {
// "udp4" or "udp6". A literal IPv6 host address must be
// enclosed in square brackets, as in "[::]:80".
func ResolveUDPAddr(net, addr string) (*UDPAddr, error) {
- ip, port, err := hostPortToIP(net, addr)
+ switch net {
+ case "udp", "udp4", "udp6":
+ case "": // a hint wildcard for Go 1.0 undocumented behavior
+ net = "udp"
+ default:
+ return nil, UnknownNetworkError(net)
+ }
+ a, err := resolveInternetAddr(net, addr, noDeadline)
if err != nil {
return nil, err
}
- return &UDPAddr{ip, port}, nil
+ return a.(*UDPAddr), nil
}
diff --git a/src/pkg/net/udpsock_plan9.go b/src/pkg/net/udpsock_plan9.go
index 4f298a42f..2a7e3d19c 100644
--- a/src/pkg/net/udpsock_plan9.go
+++ b/src/pkg/net/udpsock_plan9.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// UDP for Plan 9
+// UDP sockets for Plan 9
package net
@@ -16,44 +16,26 @@ import (
// UDPConn is the implementation of the Conn and PacketConn
// interfaces for UDP network connections.
type UDPConn struct {
- plan9Conn
+ conn
}
-// SetDeadline implements the Conn SetDeadline method.
-func (c *UDPConn) SetDeadline(t time.Time) error {
- return syscall.EPLAN9
+func newUDPConn(fd *netFD) *UDPConn {
+ return &UDPConn{conn{fd}}
}
-// SetReadDeadline implements the Conn SetReadDeadline method.
-func (c *UDPConn) SetReadDeadline(t time.Time) error {
- return syscall.EPLAN9
-}
-
-// SetWriteDeadline implements the Conn SetWriteDeadline method.
-func (c *UDPConn) SetWriteDeadline(t time.Time) error {
- return syscall.EPLAN9
-}
-
-// UDP-specific methods.
-
// ReadFromUDP reads a UDP packet from c, copying the payload into b.
// It returns the number of bytes copied into b and the return address
// that was on the packet.
//
-// ReadFromUDP can be made to time out and return an error with Timeout() == true
-// after a fixed time limit; see SetDeadline and SetReadDeadline.
+// ReadFromUDP can be made to time out and return an error with
+// Timeout() == true after a fixed time limit; see SetDeadline and
+// SetReadDeadline.
func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) {
- if !c.ok() {
+ if !c.ok() || c.fd.data == nil {
return 0, nil, syscall.EINVAL
}
- if c.data == nil {
- c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0)
- if err != nil {
- return 0, nil, err
- }
- }
buf := make([]byte, udpHeaderSize+len(b))
- m, err := c.data.Read(buf)
+ m, err := c.fd.data.Read(buf)
if err != nil {
return
}
@@ -64,62 +46,80 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) {
h, buf := unmarshalUDPHeader(buf)
n = copy(b, buf)
- return n, &UDPAddr{h.raddr, int(h.rport)}, nil
+ return n, &UDPAddr{IP: h.raddr, Port: int(h.rport)}, nil
}
// ReadFrom implements the PacketConn ReadFrom method.
-func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err error) {
+func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
}
return c.ReadFromUDP(b)
}
-// WriteToUDP writes a UDP packet to addr via c, copying the payload from b.
+// ReadMsgUDP reads a packet from c, copying the payload into b and
+// the associdated out-of-band data into oob. It returns the number
+// of bytes copied into b, the number of bytes copied into oob, the
+// flags that were set on the packet and the source address of the
+// packet.
+func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) {
+ return 0, 0, 0, nil, syscall.EPLAN9
+}
+
+// WriteToUDP writes a UDP packet to addr via c, copying the payload
+// from b.
//
-// WriteToUDP can be made to time out and return
-// an error with Timeout() == true after a fixed time limit;
-// see SetDeadline and SetWriteDeadline.
-// On packet-oriented connections, write timeouts are rare.
-func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err error) {
- if !c.ok() {
+// WriteToUDP can be made to time out and return an error with
+// Timeout() == true after a fixed time limit; see SetDeadline and
+// SetWriteDeadline. On packet-oriented connections, write timeouts
+// are rare.
+func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) {
+ if !c.ok() || c.fd.data == nil {
return 0, syscall.EINVAL
}
- if c.data == nil {
- c.data, err = os.OpenFile(c.dir+"/data", os.O_RDWR, 0)
- if err != nil {
- return 0, err
- }
- }
h := new(udpHeader)
h.raddr = addr.IP.To16()
- h.laddr = c.laddr.(*UDPAddr).IP.To16()
+ h.laddr = c.fd.laddr.(*UDPAddr).IP.To16()
h.ifcaddr = IPv6zero // ignored (receive only)
h.rport = uint16(addr.Port)
- h.lport = uint16(c.laddr.(*UDPAddr).Port)
+ h.lport = uint16(c.fd.laddr.(*UDPAddr).Port)
buf := make([]byte, udpHeaderSize+len(b))
i := copy(buf, h.Bytes())
copy(buf[i:], b)
- return c.data.Write(buf)
+ return c.fd.data.Write(buf)
}
// WriteTo implements the PacketConn WriteTo method.
-func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err error) {
+func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
a, ok := addr.(*UDPAddr)
if !ok {
- return 0, &OpError{"write", c.dir, addr, syscall.EINVAL}
+ return 0, &OpError{"write", c.fd.dir, addr, syscall.EINVAL}
}
return c.WriteToUDP(b, a)
}
+// WriteMsgUDP writes a packet to addr via c, copying the payload from
+// b and the associated out-of-band data from oob. It returns the
+// number of payload and out-of-band bytes written.
+func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) {
+ return 0, 0, syscall.EPLAN9
+}
+
// DialUDP connects to the remote address raddr on the network net,
-// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used
-// as the local address for the connection.
-func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err error) {
+// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is
+// used as the local address for the connection.
+func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
+ return dialUDP(net, laddr, raddr, noDeadline)
+}
+
+func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) {
+ if !deadline.IsZero() {
+ panic("net.dialUDP: deadline not implemented on Plan 9")
+ }
switch net {
case "udp", "udp4", "udp6":
default:
@@ -128,11 +128,11 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err error) {
if raddr == nil {
return nil, &OpError{"dial", net, nil, errMissingAddress}
}
- c1, err := dialPlan9(net, laddr, raddr)
+ fd, err := dialPlan9(net, laddr, raddr)
if err != nil {
- return
+ return nil, err
}
- return &UDPConn{*c1}, nil
+ return newUDPConn(fd), nil
}
const udpHeaderSize = 16*3 + 2*2
@@ -163,34 +163,38 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) {
return h, b
}
-// ListenUDP listens for incoming UDP packets addressed to the
-// local address laddr. The returned connection c's ReadFrom
-// and WriteTo methods can be used to receive and send UDP
-// packets with per-packet addressing.
-func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err error) {
+// ListenUDP listens for incoming UDP packets addressed to the local
+// address laddr. The returned connection c's ReadFrom and WriteTo
+// methods can be used to receive and send UDP packets with per-packet
+// addressing.
+func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) {
switch net {
case "udp", "udp4", "udp6":
default:
return nil, UnknownNetworkError(net)
}
if laddr == nil {
- return nil, &OpError{"listen", net, nil, errMissingAddress}
+ laddr = &UDPAddr{}
}
l, err := listenPlan9(net, laddr)
if err != nil {
- return
+ return nil, err
}
_, err = l.ctl.WriteString("headers")
if err != nil {
- return
+ return nil, err
+ }
+ l.data, err = os.OpenFile(l.dir+"/data", os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
}
- return &UDPConn{*l.plan9Conn()}, nil
+ return newUDPConn(l.netFD()), nil
}
// ListenMulticastUDP listens for incoming multicast UDP packets
-// addressed to the group address gaddr on ifi, which specifies
-// the interface to join. ListenMulticastUDP uses default
-// multicast interface if ifi is nil.
+// addressed to the group address gaddr on ifi, which specifies the
+// interface to join. ListenMulticastUDP uses default multicast
+// interface if ifi is nil.
func ListenMulticastUDP(net string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
return nil, syscall.EPLAN9
}
diff --git a/src/pkg/net/udpsock_posix.go b/src/pkg/net/udpsock_posix.go
index 9c6b6d393..385cd902e 100644
--- a/src/pkg/net/udpsock_posix.go
+++ b/src/pkg/net/udpsock_posix.go
@@ -4,25 +4,21 @@
// +build darwin freebsd linux netbsd openbsd windows
-// UDP sockets
+// UDP sockets for POSIX
package net
import (
- "errors"
- "os"
"syscall"
"time"
)
-var ErrWriteToConnected = errors.New("use of WriteTo with pre-connected UDP")
-
func sockaddrToUDP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
- return &UDPAddr{sa.Addr[0:], sa.Port}
+ return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
case *syscall.SockaddrInet6:
- return &UDPAddr{sa.Addr[0:], sa.Port}
+ return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))}
}
return nil
}
@@ -45,7 +41,7 @@ func (a *UDPAddr) isWildcard() bool {
}
func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
- return ipToSockaddr(family, a.IP, a.Port)
+ return ipToSockaddr(family, a.IP, a.Port, a.Zone)
}
func (a *UDPAddr) toAddr() sockaddr {
@@ -58,98 +54,10 @@ func (a *UDPAddr) toAddr() sockaddr {
// UDPConn is the implementation of the Conn and PacketConn
// interfaces for UDP network connections.
type UDPConn struct {
- fd *netFD
-}
-
-func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{fd} }
-
-func (c *UDPConn) ok() bool { return c != nil && c.fd != nil }
-
-// Implementation of the Conn interface - see Conn for documentation.
-
-// Read implements the Conn Read method.
-func (c *UDPConn) Read(b []byte) (int, error) {
- if !c.ok() {
- return 0, syscall.EINVAL
- }
- return c.fd.Read(b)
-}
-
-// Write implements the Conn Write method.
-func (c *UDPConn) Write(b []byte) (int, error) {
- if !c.ok() {
- return 0, syscall.EINVAL
- }
- return c.fd.Write(b)
-}
-
-// Close closes the UDP connection.
-func (c *UDPConn) Close() error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return c.fd.Close()
-}
-
-// LocalAddr returns the local network address.
-func (c *UDPConn) LocalAddr() Addr {
- if !c.ok() {
- return nil
- }
- return c.fd.laddr
-}
-
-// RemoteAddr returns the remote network address, a *UDPAddr.
-func (c *UDPConn) RemoteAddr() Addr {
- if !c.ok() {
- return nil
- }
- return c.fd.raddr
-}
-
-// SetDeadline implements the Conn SetDeadline method.
-func (c *UDPConn) SetDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setDeadline(c.fd, t)
-}
-
-// SetReadDeadline implements the Conn SetReadDeadline method.
-func (c *UDPConn) SetReadDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setReadDeadline(c.fd, t)
-}
-
-// SetWriteDeadline implements the Conn SetWriteDeadline method.
-func (c *UDPConn) SetWriteDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setWriteDeadline(c.fd, t)
-}
-
-// SetReadBuffer sets the size of the operating system's
-// receive buffer associated with the connection.
-func (c *UDPConn) SetReadBuffer(bytes int) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setReadBuffer(c.fd, bytes)
-}
-
-// SetWriteBuffer sets the size of the operating system's
-// transmit buffer associated with the connection.
-func (c *UDPConn) SetWriteBuffer(bytes int) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setWriteBuffer(c.fd, bytes)
+ conn
}
-// UDP-specific methods.
+func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} }
// ReadFromUDP reads a UDP packet from c, copying the payload into b.
// It returns the number of bytes copied into b and the return address
@@ -164,9 +72,9 @@ func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) {
n, sa, err := c.fd.ReadFrom(b)
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
- addr = &UDPAddr{sa.Addr[0:], sa.Port}
+ addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
case *syscall.SockaddrInet6:
- addr = &UDPAddr{sa.Addr[0:], sa.Port}
+ addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))}
}
return
}
@@ -176,8 +84,28 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
}
- n, uaddr, err := c.ReadFromUDP(b)
- return n, uaddr.toAddr(), err
+ n, addr, err := c.ReadFromUDP(b)
+ return n, addr.toAddr(), err
+}
+
+// ReadMsgUDP reads a packet from c, copying the payload into b and
+// the associdated out-of-band data into oob. It returns the number
+// of bytes copied into b, the number of bytes copied into oob, the
+// flags that were set on the packet and the source address of the
+// packet.
+func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) {
+ if !c.ok() {
+ return 0, 0, 0, nil, syscall.EINVAL
+ }
+ var sa syscall.Sockaddr
+ n, oobn, flags, sa, err = c.fd.ReadMsg(b, oob)
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
+ case *syscall.SockaddrInet6:
+ addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))}
+ }
+ return
}
// WriteToUDP writes a UDP packet to addr via c, copying the payload from b.
@@ -212,15 +140,31 @@ func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) {
return c.WriteToUDP(b, a)
}
-// File returns a copy of the underlying os.File, set to blocking mode.
-// It is the caller's responsibility to close f when finished.
-// Closing c does not affect f, and closing f does not affect c.
-func (c *UDPConn) File() (f *os.File, err error) { return c.fd.dup() }
+// WriteMsgUDP writes a packet to addr via c, copying the payload from
+// b and the associated out-of-band data from oob. It returns the
+// number of payload and out-of-band bytes written.
+func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) {
+ if !c.ok() {
+ return 0, 0, syscall.EINVAL
+ }
+ if c.fd.isConnected {
+ return 0, 0, &OpError{"write", c.fd.net, addr, ErrWriteToConnected}
+ }
+ sa, err := addr.sockaddr(c.fd.family)
+ if err != nil {
+ return 0, 0, &OpError{"write", c.fd.net, addr, err}
+ }
+ return c.fd.WriteMsg(b, oob, sa)
+}
// DialUDP connects to the remote address raddr on the network net,
// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used
// as the local address for the connection.
func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
+ return dialUDP(net, laddr, raddr, noDeadline)
+}
+
+func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) {
switch net {
case "udp", "udp4", "udp6":
default:
@@ -229,7 +173,7 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
if raddr == nil {
return nil, &OpError{"dial", net, nil, errMissingAddress}
}
- fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP)
+ fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP)
if err != nil {
return nil, err
}
@@ -247,9 +191,9 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) {
return nil, UnknownNetworkError(net)
}
if laddr == nil {
- return nil, &OpError{"listen", net, nil, errMissingAddress}
+ laddr = &UDPAddr{}
}
- fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP)
+ fd, err := internetSocket(net, laddr.toAddr(), nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP)
if err != nil {
return nil, err
}
@@ -267,25 +211,22 @@ func ListenMulticastUDP(net string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, e
return nil, UnknownNetworkError(net)
}
if gaddr == nil || gaddr.IP == nil {
- return nil, &OpError{"listenmulticast", net, nil, errMissingAddress}
+ return nil, &OpError{"listen", net, nil, errMissingAddress}
}
- fd, err := internetSocket(net, gaddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP)
+ fd, err := internetSocket(net, gaddr.toAddr(), nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP)
if err != nil {
return nil, err
}
c := newUDPConn(fd)
- ip4 := gaddr.IP.To4()
- if ip4 != nil {
- err := listenIPv4MulticastUDP(c, ifi, ip4)
- if err != nil {
+ if ip4 := gaddr.IP.To4(); ip4 != nil {
+ if err := listenIPv4MulticastUDP(c, ifi, ip4); err != nil {
c.Close()
- return nil, err
+ return nil, &OpError{"listen", net, &IPAddr{IP: ip4}, err}
}
} else {
- err := listenIPv6MulticastUDP(c, ifi, gaddr.IP)
- if err != nil {
+ if err := listenIPv6MulticastUDP(c, ifi, gaddr.IP); err != nil {
c.Close()
- return nil, err
+ return nil, &OpError{"listen", net, &IPAddr{IP: gaddr.IP}, err}
}
}
return c, nil
@@ -293,17 +234,14 @@ func ListenMulticastUDP(net string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, e
func listenIPv4MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error {
if ifi != nil {
- err := setIPv4MulticastInterface(c.fd, ifi)
- if err != nil {
+ if err := setIPv4MulticastInterface(c.fd, ifi); err != nil {
return err
}
}
- err := setIPv4MulticastLoopback(c.fd, false)
- if err != nil {
+ if err := setIPv4MulticastLoopback(c.fd, false); err != nil {
return err
}
- err = joinIPv4GroupUDP(c, ifi, ip)
- if err != nil {
+ if err := joinIPv4Group(c.fd, ifi, ip); err != nil {
return err
}
return nil
@@ -311,50 +249,15 @@ func listenIPv4MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error {
func listenIPv6MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error {
if ifi != nil {
- err := setIPv6MulticastInterface(c.fd, ifi)
- if err != nil {
+ if err := setIPv6MulticastInterface(c.fd, ifi); err != nil {
return err
}
}
- err := setIPv6MulticastLoopback(c.fd, false)
- if err != nil {
+ if err := setIPv6MulticastLoopback(c.fd, false); err != nil {
return err
}
- err = joinIPv6GroupUDP(c, ifi, ip)
- if err != nil {
+ if err := joinIPv6Group(c.fd, ifi, ip); err != nil {
return err
}
return nil
}
-
-func joinIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
- err := joinIPv4Group(c.fd, ifi, ip)
- if err != nil {
- return &OpError{"joinipv4group", c.fd.net, &IPAddr{ip}, err}
- }
- return nil
-}
-
-func leaveIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
- err := leaveIPv4Group(c.fd, ifi, ip)
- if err != nil {
- return &OpError{"leaveipv4group", c.fd.net, &IPAddr{ip}, err}
- }
- return nil
-}
-
-func joinIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
- err := joinIPv6Group(c.fd, ifi, ip)
- if err != nil {
- return &OpError{"joinipv6group", c.fd.net, &IPAddr{ip}, err}
- }
- return nil
-}
-
-func leaveIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
- err := leaveIPv6Group(c.fd, ifi, ip)
- if err != nil {
- return &OpError{"leaveipv6group", c.fd.net, &IPAddr{ip}, err}
- }
- return nil
-}
diff --git a/src/pkg/net/unicast_test.go b/src/pkg/net/unicast_posix_test.go
index e5dd013db..a8855cab7 100644
--- a/src/pkg/net/unicast_test.go
+++ b/src/pkg/net/unicast_posix_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !plan9
+
package net
import (
@@ -44,8 +46,7 @@ var listenerTests = []struct {
func TestTCPListener(t *testing.T) {
switch runtime.GOOS {
case "plan9", "windows":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
for _, tt := range listenerTests {
@@ -59,13 +60,6 @@ func TestTCPListener(t *testing.T) {
checkFirstListener(t, tt.net, tt.laddr+":"+port, l1)
l2, err := Listen(tt.net, tt.laddr+":"+port)
checkSecondListener(t, tt.net, tt.laddr+":"+port, err, l2)
- fd := l1.(*TCPListener).fd
- switch fd.family {
- case syscall.AF_INET:
- testIPv4UnicastSocketOptions(t, fd)
- case syscall.AF_INET6:
- testIPv6UnicastSocketOptions(t, fd)
- }
l1.Close()
}
}
@@ -76,8 +70,7 @@ func TestTCPListener(t *testing.T) {
func TestUDPListener(t *testing.T) {
switch runtime.GOOS {
case "plan9", "windows":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
toudpnet := func(net string) string {
@@ -104,13 +97,6 @@ func TestUDPListener(t *testing.T) {
checkFirstListener(t, tt.net, tt.laddr+":"+port, l1)
l2, err := ListenPacket(tt.net, tt.laddr+":"+port)
checkSecondListener(t, tt.net, tt.laddr+":"+port, err, l2)
- fd := l1.(*UDPConn).fd
- switch fd.family {
- case syscall.AF_INET:
- testIPv4UnicastSocketOptions(t, fd)
- case syscall.AF_INET6:
- testIPv6UnicastSocketOptions(t, fd)
- }
l1.Close()
}
}
@@ -118,7 +104,7 @@ func TestUDPListener(t *testing.T) {
func TestSimpleTCPListener(t *testing.T) {
switch runtime.GOOS {
case "plan9":
- t.Logf("skipping test on %q", runtime.GOOS)
+ t.Skipf("skipping test on %q", runtime.GOOS)
return
}
@@ -140,7 +126,7 @@ func TestSimpleTCPListener(t *testing.T) {
func TestSimpleUDPListener(t *testing.T) {
switch runtime.GOOS {
case "plan9":
- t.Logf("skipping test on %q", runtime.GOOS)
+ t.Skipf("skipping test on %q", runtime.GOOS)
return
}
@@ -183,9 +169,9 @@ var dualStackListenerTests = []struct {
// Test cases and expected results for the attemping 2nd listen on the same port
// 1st listen 2nd listen darwin freebsd linux openbsd
// ------------------------------------------------------------------------------------
- // "tcp" "" "tcp" "" - - - -
- // "tcp" "" "tcp" "0.0.0.0" - - - -
- // "tcp" "0.0.0.0" "tcp" "" - - - -
+ // "tcp" "" "tcp" "" - - - -
+ // "tcp" "" "tcp" "0.0.0.0" - - - -
+ // "tcp" "0.0.0.0" "tcp" "" - - - -
// ------------------------------------------------------------------------------------
// "tcp" "" "tcp" "[::]" - - - ok
// "tcp" "[::]" "tcp" "" - - - ok
@@ -242,8 +228,7 @@ var dualStackListenerTests = []struct {
func TestDualStackTCPListener(t *testing.T) {
switch runtime.GOOS {
case "plan9":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
if !supportsIPv6 {
return
@@ -275,8 +260,7 @@ func TestDualStackTCPListener(t *testing.T) {
func TestDualStackUDPListener(t *testing.T) {
switch runtime.GOOS {
case "plan9":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
if !supportsIPv6 {
return
@@ -468,44 +452,6 @@ func checkDualStackAddrFamily(t *testing.T, net, laddr string, fd *netFD) {
}
}
-func testIPv4UnicastSocketOptions(t *testing.T, fd *netFD) {
- _, err := ipv4TOS(fd)
- if err != nil {
- t.Fatalf("ipv4TOS failed: %v", err)
- }
- err = setIPv4TOS(fd, 1)
- if err != nil {
- t.Fatalf("setIPv4TOS failed: %v", err)
- }
- _, err = ipv4TTL(fd)
- if err != nil {
- t.Fatalf("ipv4TTL failed: %v", err)
- }
- err = setIPv4TTL(fd, 1)
- if err != nil {
- t.Fatalf("setIPv4TTL failed: %v", err)
- }
-}
-
-func testIPv6UnicastSocketOptions(t *testing.T, fd *netFD) {
- _, err := ipv6TrafficClass(fd)
- if err != nil {
- t.Fatalf("ipv6TrafficClass failed: %v", err)
- }
- err = setIPv6TrafficClass(fd, 1)
- if err != nil {
- t.Fatalf("setIPv6TrafficClass failed: %v", err)
- }
- _, err = ipv6HopLimit(fd)
- if err != nil {
- t.Fatalf("ipv6HopLimit failed: %v", err)
- }
- err = setIPv6HopLimit(fd, 1)
- if err != nil {
- t.Fatalf("setIPv6HopLimit failed: %v", err)
- }
-}
-
var prohibitionaryDialArgTests = []struct {
net string
addr string
@@ -517,8 +463,7 @@ var prohibitionaryDialArgTests = []struct {
func TestProhibitionaryDialArgs(t *testing.T) {
switch runtime.GOOS {
case "plan9":
- t.Logf("skipping test on %q", runtime.GOOS)
- return
+ t.Skipf("skipping test on %q", runtime.GOOS)
}
// This test requires both IPv6 and IPv6 IPv4-mapping functionality.
if !supportsIPv4map || testing.Short() || !*testExternal {
@@ -536,3 +481,36 @@ func TestProhibitionaryDialArgs(t *testing.T) {
}
}
}
+
+func TestWildWildcardListener(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+
+ if testing.Short() || !*testExternal {
+ t.Skip("skipping test to avoid external network")
+ }
+
+ defer func() {
+ if recover() != nil {
+ t.Fatalf("panicked")
+ }
+ }()
+
+ if ln, err := Listen("tcp", ""); err == nil {
+ ln.Close()
+ }
+ if ln, err := ListenPacket("udp", ""); err == nil {
+ ln.Close()
+ }
+ if ln, err := ListenTCP("tcp", nil); err == nil {
+ ln.Close()
+ }
+ if ln, err := ListenUDP("udp", nil); err == nil {
+ ln.Close()
+ }
+ if ln, err := ListenIP("ip:icmp", nil); err == nil {
+ ln.Close()
+ }
+}
diff --git a/src/pkg/net/unix_test.go b/src/pkg/net/unix_test.go
new file mode 100644
index 000000000..2eaabe86e
--- /dev/null
+++ b/src/pkg/net/unix_test.go
@@ -0,0 +1,144 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !plan9,!windows
+
+package net
+
+import (
+ "bytes"
+ "os"
+ "reflect"
+ "runtime"
+ "syscall"
+ "testing"
+ "time"
+)
+
+func TestReadUnixgramWithUnnamedSocket(t *testing.T) {
+ addr := testUnixAddr()
+ la, err := ResolveUnixAddr("unixgram", addr)
+ if err != nil {
+ t.Fatalf("ResolveUnixAddr failed: %v", err)
+ }
+ c, err := ListenUnixgram("unixgram", la)
+ if err != nil {
+ t.Fatalf("ListenUnixgram failed: %v", err)
+ }
+ defer func() {
+ c.Close()
+ os.Remove(addr)
+ }()
+
+ off := make(chan bool)
+ data := [5]byte{1, 2, 3, 4, 5}
+
+ go func() {
+ defer func() { off <- true }()
+ s, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ t.Errorf("syscall.Socket failed: %v", err)
+ return
+ }
+ defer syscall.Close(s)
+ rsa := &syscall.SockaddrUnix{Name: addr}
+ if err := syscall.Sendto(s, data[:], 0, rsa); err != nil {
+ t.Errorf("syscall.Sendto failed: %v", err)
+ return
+ }
+ }()
+
+ <-off
+ b := make([]byte, 64)
+ c.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ n, from, err := c.ReadFrom(b)
+ if err != nil {
+ t.Errorf("UnixConn.ReadFrom failed: %v", err)
+ return
+ }
+ if from != nil {
+ t.Errorf("neighbor address is %v", from)
+ }
+ if !bytes.Equal(b[:n], data[:]) {
+ t.Errorf("got %v, want %v", b[:n], data[:])
+ return
+ }
+}
+
+func TestReadUnixgramWithZeroBytesBuffer(t *testing.T) {
+ // issue 4352: Recvfrom failed with "address family not
+ // supported by protocol family" if zero-length buffer provided
+
+ addr := testUnixAddr()
+ la, err := ResolveUnixAddr("unixgram", addr)
+ if err != nil {
+ t.Fatalf("ResolveUnixAddr failed: %v", err)
+ }
+ c, err := ListenUnixgram("unixgram", la)
+ if err != nil {
+ t.Fatalf("ListenUnixgram failed: %v", err)
+ }
+ defer func() {
+ c.Close()
+ os.Remove(addr)
+ }()
+
+ off := make(chan bool)
+ go func() {
+ defer func() { off <- true }()
+ c, err := DialUnix("unixgram", nil, la)
+ if err != nil {
+ t.Errorf("DialUnix failed: %v", err)
+ return
+ }
+ defer c.Close()
+ if _, err := c.Write([]byte{1, 2, 3, 4, 5}); err != nil {
+ t.Errorf("UnixConn.Write failed: %v", err)
+ return
+ }
+ }()
+
+ <-off
+ c.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ var peer Addr
+ if _, peer, err = c.ReadFrom(nil); err != nil {
+ t.Errorf("UnixConn.ReadFrom failed: %v", err)
+ return
+ }
+ if peer != nil {
+ t.Errorf("peer adddress is %v", peer)
+ }
+}
+
+func TestUnixAutobind(t *testing.T) {
+ if runtime.GOOS != "linux" {
+ t.Skip("skipping: autobind is linux only")
+ }
+
+ laddr := &UnixAddr{Name: "", Net: "unixgram"}
+ c1, err := ListenUnixgram("unixgram", laddr)
+ if err != nil {
+ t.Fatalf("ListenUnixgram failed: %v", err)
+ }
+ defer c1.Close()
+
+ // retrieve the autobind address
+ autoAddr := c1.LocalAddr().(*UnixAddr)
+ if len(autoAddr.Name) <= 1 {
+ t.Fatalf("Invalid autobind address: %v", autoAddr)
+ }
+ if autoAddr.Name[0] != '@' {
+ t.Fatalf("Invalid autobind address: %v", autoAddr)
+ }
+
+ c2, err := DialUnix("unixgram", nil, autoAddr)
+ if err != nil {
+ t.Fatalf("DialUnix failed: %v", err)
+ }
+ defer c2.Close()
+
+ if !reflect.DeepEqual(c1.LocalAddr(), c2.RemoteAddr()) {
+ t.Fatalf("Expected autobind address %v, got %v", c1.LocalAddr(), c2.RemoteAddr())
+ }
+}
diff --git a/src/pkg/net/unixsock_plan9.go b/src/pkg/net/unixsock_plan9.go
index 7b4ae6bd1..00a0be5b0 100644
--- a/src/pkg/net/unixsock_plan9.go
+++ b/src/pkg/net/unixsock_plan9.go
@@ -7,100 +7,135 @@
package net
import (
+ "os"
"syscall"
"time"
)
-// UnixConn is an implementation of the Conn interface
-// for connections to Unix domain sockets.
-type UnixConn bool
-
-// Implementation of the Conn interface - see Conn for documentation.
+// UnixConn is an implementation of the Conn interface for connections
+// to Unix domain sockets.
+type UnixConn struct {
+ conn
+}
-// Read implements the Conn Read method.
-func (c *UnixConn) Read(b []byte) (n int, err error) {
- return 0, syscall.EPLAN9
+// ReadFromUnix reads a packet from c, copying the payload into b. It
+// returns the number of bytes copied into b and the source address of
+// the packet.
+//
+// ReadFromUnix can be made to time out and return an error with
+// Timeout() == true after a fixed time limit; see SetDeadline and
+// SetReadDeadline.
+func (c *UnixConn) ReadFromUnix(b []byte) (int, *UnixAddr, error) {
+ return 0, nil, syscall.EPLAN9
}
-// Write implements the Conn Write method.
-func (c *UnixConn) Write(b []byte) (n int, err error) {
- return 0, syscall.EPLAN9
+// ReadFrom implements the PacketConn ReadFrom method.
+func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) {
+ return 0, nil, syscall.EPLAN9
}
-// Close closes the Unix domain connection.
-func (c *UnixConn) Close() error {
- return syscall.EPLAN9
+// ReadMsgUnix reads a packet from c, copying the payload into b and
+// the associated out-of-band data into oob. It returns the number of
+// bytes copied into b, the number of bytes copied into oob, the flags
+// that were set on the packet, and the source address of the packet.
+func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
+ return 0, 0, 0, nil, syscall.EPLAN9
}
-// LocalAddr returns the local network address, a *UnixAddr.
-// Unlike in other protocols, LocalAddr is usually nil for dialed connections.
-func (c *UnixConn) LocalAddr() Addr {
- return nil
+// WriteToUnix writes a packet to addr via c, copying the payload from b.
+//
+// WriteToUnix can be made to time out and return an error with
+// Timeout() == true after a fixed time limit; see SetDeadline and
+// SetWriteDeadline. On packet-oriented connections, write timeouts
+// are rare.
+func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (int, error) {
+ return 0, syscall.EPLAN9
}
-// RemoteAddr returns the remote network address, a *UnixAddr.
-// Unlike in other protocols, RemoteAddr is usually nil for connections
-// accepted by a listener.
-func (c *UnixConn) RemoteAddr() Addr {
- return nil
+// WriteTo implements the PacketConn WriteTo method.
+func (c *UnixConn) WriteTo(b []byte, addr Addr) (int, error) {
+ return 0, syscall.EPLAN9
}
-// SetDeadline implements the Conn SetDeadline method.
-func (c *UnixConn) SetDeadline(t time.Time) error {
- return syscall.EPLAN9
+// WriteMsgUnix writes a packet to addr via c, copying the payload
+// from b and the associated out-of-band data from oob. It returns
+// the number of payload and out-of-band bytes written.
+func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) {
+ return 0, 0, syscall.EPLAN9
}
-// SetReadDeadline implements the Conn SetReadDeadline method.
-func (c *UnixConn) SetReadDeadline(t time.Time) error {
+// CloseRead shuts down the reading side of the Unix domain connection.
+// Most callers should just use Close.
+func (c *UnixConn) CloseRead() error {
return syscall.EPLAN9
}
-// SetWriteDeadline implements the Conn SetWriteDeadline method.
-func (c *UnixConn) SetWriteDeadline(t time.Time) error {
+// CloseWrite shuts down the writing side of the Unix domain connection.
+// Most callers should just use Close.
+func (c *UnixConn) CloseWrite() error {
return syscall.EPLAN9
}
-// ReadFrom implements the PacketConn ReadFrom method.
-func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err error) {
- err = syscall.EPLAN9
- return
+// DialUnix connects to the remote address raddr on the network net,
+// which must be "unix", "unixgram" or "unixpacket". If laddr is not
+// nil, it is used as the local address for the connection.
+func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) {
+ return dialUnix(net, laddr, raddr, noDeadline)
}
-// WriteTo implements the PacketConn WriteTo method.
-func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err error) {
- err = syscall.EPLAN9
- return
+func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) {
+ return nil, syscall.EPLAN9
}
-// DialUnix connects to the remote address raddr on the network net,
-// which must be "unix" or "unixgram". If laddr is not nil, it is used
-// as the local address for the connection.
-func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err error) {
+// UnixListener is a Unix domain socket listener. Clients should
+// typically use variables of type Listener instead of assuming Unix
+// domain sockets.
+type UnixListener struct{}
+
+// ListenUnix announces on the Unix domain socket laddr and returns a
+// Unix listener. The network net must be "unix" or "unixpacket".
+func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) {
return nil, syscall.EPLAN9
}
-// UnixListener is a Unix domain socket listener.
-// Clients should typically use variables of type Listener
-// instead of assuming Unix domain sockets.
-type UnixListener bool
-
-// ListenUnix announces on the Unix domain socket laddr and returns a Unix listener.
-// Net must be "unix" (stream sockets).
-func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err error) {
+// AcceptUnix accepts the next incoming call and returns the new
+// connection and the remote address.
+func (l *UnixListener) AcceptUnix() (*UnixConn, error) {
return nil, syscall.EPLAN9
}
-// Accept implements the Accept method in the Listener interface;
-// it waits for the next call and returns a generic Conn.
-func (l *UnixListener) Accept() (c Conn, err error) {
+// Accept implements the Accept method in the Listener interface; it
+// waits for the next call and returns a generic Conn.
+func (l *UnixListener) Accept() (Conn, error) {
return nil, syscall.EPLAN9
}
-// Close stops listening on the Unix address.
-// Already accepted connections are not closed.
+// Close stops listening on the Unix address. Already accepted
+// connections are not closed.
func (l *UnixListener) Close() error {
return syscall.EPLAN9
}
// Addr returns the listener's network address.
func (l *UnixListener) Addr() Addr { return nil }
+
+// SetDeadline sets the deadline associated with the listener.
+// A zero time value disables the deadline.
+func (l *UnixListener) SetDeadline(t time.Time) error {
+ return syscall.EPLAN9
+}
+
+// File returns a copy of the underlying os.File, set to blocking
+// mode. It is the caller's responsibility to close f when finished.
+// Closing l does not affect f, and closing f does not affect l.
+func (l *UnixListener) File() (*os.File, error) {
+ return nil, syscall.EPLAN9
+}
+
+// ListenUnixgram listens for incoming Unix datagram packets addressed
+// to the local address laddr. The returned connection c's ReadFrom
+// and WriteTo methods can be used to receive and send packets with
+// per-packet addressing. The network net must be "unixgram".
+func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) {
+ return nil, syscall.EPLAN9
+}
diff --git a/src/pkg/net/unixsock_posix.go b/src/pkg/net/unixsock_posix.go
index 57d784c71..6d6ce3f5e 100644
--- a/src/pkg/net/unixsock_posix.go
+++ b/src/pkg/net/unixsock_posix.go
@@ -9,29 +9,27 @@
package net
import (
+ "errors"
"os"
"syscall"
"time"
)
-func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err error) {
+func unixSocket(net string, laddr, raddr *UnixAddr, mode string, deadline time.Time) (*netFD, error) {
var sotype int
switch net {
- default:
- return nil, UnknownNetworkError(net)
case "unix":
sotype = syscall.SOCK_STREAM
case "unixgram":
sotype = syscall.SOCK_DGRAM
case "unixpacket":
sotype = syscall.SOCK_SEQPACKET
+ default:
+ return nil, UnknownNetworkError(net)
}
var la, ra syscall.Sockaddr
switch mode {
- default:
- panic("unixSocket mode " + mode)
-
case "dial":
if laddr != nil {
la = &syscall.SockaddrUnix{Name: laddr.Name}
@@ -41,15 +39,10 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err
} else if sotype != syscall.SOCK_DGRAM || laddr == nil {
return nil, &OpError{Op: mode, Net: net, Err: errMissingAddress}
}
-
case "listen":
- if laddr == nil {
- return nil, &OpError{mode, net, nil, errMissingAddress}
- }
la = &syscall.SockaddrUnix{Name: laddr.Name}
- if raddr != nil {
- return nil, &OpError{Op: mode, Net: net, Addr: raddr, Err: &AddrError{Err: "unexpected remote address", Addr: raddr.String()}}
- }
+ default:
+ return nil, errors.New("unknown mode: " + mode)
}
f := sockaddrToUnix
@@ -59,15 +52,16 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err
f = sockaddrToUnixpacket
}
- fd, err = socket(net, syscall.AF_UNIX, sotype, 0, false, la, ra, f)
+ fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, la, ra, deadline, f)
if err != nil {
- goto Error
+ goto error
}
return fd, nil
-Error:
+error:
addr := raddr
- if mode == "listen" {
+ switch mode {
+ case "listen":
addr = laddr
}
return nil, &OpError{Op: mode, Net: net, Addr: addr, Err: err}
@@ -108,110 +102,21 @@ func sotypeToNet(sotype int) string {
return ""
}
-// UnixConn is an implementation of the Conn interface
-// for connections to Unix domain sockets.
+// UnixConn is an implementation of the Conn interface for connections
+// to Unix domain sockets.
type UnixConn struct {
- fd *netFD
-}
-
-func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{fd} }
-
-func (c *UnixConn) ok() bool { return c != nil && c.fd != nil }
-
-// Implementation of the Conn interface - see Conn for documentation.
-
-// Read implements the Conn Read method.
-func (c *UnixConn) Read(b []byte) (n int, err error) {
- if !c.ok() {
- return 0, syscall.EINVAL
- }
- return c.fd.Read(b)
-}
-
-// Write implements the Conn Write method.
-func (c *UnixConn) Write(b []byte) (n int, err error) {
- if !c.ok() {
- return 0, syscall.EINVAL
- }
- return c.fd.Write(b)
-}
-
-// Close closes the Unix domain connection.
-func (c *UnixConn) Close() error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return c.fd.Close()
-}
-
-// LocalAddr returns the local network address, a *UnixAddr.
-// Unlike in other protocols, LocalAddr is usually nil for dialed connections.
-func (c *UnixConn) LocalAddr() Addr {
- if !c.ok() {
- return nil
- }
- return c.fd.laddr
+ conn
}
-// RemoteAddr returns the remote network address, a *UnixAddr.
-// Unlike in other protocols, RemoteAddr is usually nil for connections
-// accepted by a listener.
-func (c *UnixConn) RemoteAddr() Addr {
- if !c.ok() {
- return nil
- }
- return c.fd.raddr
-}
+func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
-// SetDeadline implements the Conn SetDeadline method.
-func (c *UnixConn) SetDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setDeadline(c.fd, t)
-}
-
-// SetReadDeadline implements the Conn SetReadDeadline method.
-func (c *UnixConn) SetReadDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setReadDeadline(c.fd, t)
-}
-
-// SetWriteDeadline implements the Conn SetWriteDeadline method.
-func (c *UnixConn) SetWriteDeadline(t time.Time) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setWriteDeadline(c.fd, t)
-}
-
-// SetReadBuffer sets the size of the operating system's
-// receive buffer associated with the connection.
-func (c *UnixConn) SetReadBuffer(bytes int) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setReadBuffer(c.fd, bytes)
-}
-
-// SetWriteBuffer sets the size of the operating system's
-// transmit buffer associated with the connection.
-func (c *UnixConn) SetWriteBuffer(bytes int) error {
- if !c.ok() {
- return syscall.EINVAL
- }
- return setWriteBuffer(c.fd, bytes)
-}
-
-// ReadFromUnix reads a packet from c, copying the payload into b.
-// It returns the number of bytes copied into b and the source address
-// of the packet.
+// ReadFromUnix reads a packet from c, copying the payload into b. It
+// returns the number of bytes copied into b and the source address of
+// the packet.
//
-// ReadFromUnix can be made to time out and return
-// an error with Timeout() == true after a fixed time limit;
-// see SetDeadline and SetReadDeadline.
+// ReadFromUnix can be made to time out and return an error with
+// Timeout() == true after a fixed time limit; see SetDeadline and
+// SetReadDeadline.
func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
@@ -219,26 +124,46 @@ func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err error) {
n, sa, err := c.fd.ReadFrom(b)
switch sa := sa.(type) {
case *syscall.SockaddrUnix:
- addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)}
+ if sa.Name != "" {
+ addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)}
+ }
}
return
}
// ReadFrom implements the PacketConn ReadFrom method.
-func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err error) {
+func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
}
- n, uaddr, err := c.ReadFromUnix(b)
- return n, uaddr.toAddr(), err
+ n, addr, err := c.ReadFromUnix(b)
+ return n, addr.toAddr(), err
+}
+
+// ReadMsgUnix reads a packet from c, copying the payload into b and
+// the associated out-of-band data into oob. It returns the number of
+// bytes copied into b, the number of bytes copied into oob, the flags
+// that were set on the packet, and the source address of the packet.
+func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
+ if !c.ok() {
+ return 0, 0, 0, nil, syscall.EINVAL
+ }
+ n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob)
+ switch sa := sa.(type) {
+ case *syscall.SockaddrUnix:
+ if sa.Name != "" {
+ addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)}
+ }
+ }
+ return
}
// WriteToUnix writes a packet to addr via c, copying the payload from b.
//
-// WriteToUnix can be made to time out and return
-// an error with Timeout() == true after a fixed time limit;
-// see SetDeadline and SetWriteDeadline.
-// On packet-oriented connections, write timeouts are rare.
+// WriteToUnix can be made to time out and return an error with
+// Timeout() == true after a fixed time limit; see SetDeadline and
+// SetWriteDeadline. On packet-oriented connections, write timeouts
+// are rare.
func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err error) {
if !c.ok() {
return 0, syscall.EINVAL
@@ -262,26 +187,9 @@ func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err error) {
return c.WriteToUnix(b, a)
}
-// ReadMsgUnix reads a packet from c, copying the payload into b
-// and the associated out-of-band data into oob.
-// It returns the number of bytes copied into b, the number of
-// bytes copied into oob, the flags that were set on the packet,
-// and the source address of the packet.
-func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
- if !c.ok() {
- return 0, 0, 0, nil, syscall.EINVAL
- }
- n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob)
- switch sa := sa.(type) {
- case *syscall.SockaddrUnix:
- addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)}
- }
- return
-}
-
-// WriteMsgUnix writes a packet to addr via c, copying the payload from b
-// and the associated out-of-band data from oob. It returns the number
-// of payload and out-of-band bytes written.
+// WriteMsgUnix writes a packet to addr via c, copying the payload
+// from b and the associated out-of-band data from oob. It returns
+// the number of payload and out-of-band bytes written.
func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) {
if !c.ok() {
return 0, 0, syscall.EINVAL
@@ -296,40 +204,64 @@ func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err
return c.fd.WriteMsg(b, oob, nil)
}
-// File returns a copy of the underlying os.File, set to blocking mode.
-// It is the caller's responsibility to close f when finished.
-// Closing c does not affect f, and closing f does not affect c.
-func (c *UnixConn) File() (f *os.File, err error) { return c.fd.dup() }
+// CloseRead shuts down the reading side of the Unix domain connection.
+// Most callers should just use Close.
+func (c *UnixConn) CloseRead() error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ return c.fd.CloseRead()
+}
+
+// CloseWrite shuts down the writing side of the Unix domain connection.
+// Most callers should just use Close.
+func (c *UnixConn) CloseWrite() error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ return c.fd.CloseWrite()
+}
// DialUnix connects to the remote address raddr on the network net,
-// which must be "unix" or "unixgram". If laddr is not nil, it is used
-// as the local address for the connection.
+// which must be "unix", "unixgram" or "unixpacket". If laddr is not
+// nil, it is used as the local address for the connection.
func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) {
- fd, err := unixSocket(net, laddr, raddr, "dial")
+ return dialUnix(net, laddr, raddr, noDeadline)
+}
+
+func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) {
+ switch net {
+ case "unix", "unixgram", "unixpacket":
+ default:
+ return nil, UnknownNetworkError(net)
+ }
+ fd, err := unixSocket(net, laddr, raddr, "dial", deadline)
if err != nil {
return nil, err
}
return newUnixConn(fd), nil
}
-// UnixListener is a Unix domain socket listener.
-// Clients should typically use variables of type Listener
-// instead of assuming Unix domain sockets.
+// UnixListener is a Unix domain socket listener. Clients should
+// typically use variables of type Listener instead of assuming Unix
+// domain sockets.
type UnixListener struct {
fd *netFD
path string
}
-// ListenUnix announces on the Unix domain socket laddr and returns a Unix listener.
-// Net must be "unix" (stream sockets).
+// ListenUnix announces on the Unix domain socket laddr and returns a
+// Unix listener. The network net must be "unix" or "unixpacket".
func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) {
- if net != "unix" && net != "unixgram" && net != "unixpacket" {
+ switch net {
+ case "unix", "unixpacket":
+ default:
return nil, UnknownNetworkError(net)
}
- if laddr != nil {
- laddr = &UnixAddr{laddr.Name, net} // make our own copy
+ if laddr == nil {
+ return nil, &OpError{"listen", net, nil, errMissingAddress}
}
- fd, err := unixSocket(net, laddr, nil, "listen")
+ fd, err := unixSocket(net, laddr, nil, "listen", noDeadline)
if err != nil {
return nil, err
}
@@ -341,8 +273,8 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) {
return &UnixListener{fd, laddr.Name}, nil
}
-// AcceptUnix accepts the next incoming call and returns the new connection
-// and the remote address.
+// AcceptUnix accepts the next incoming call and returns the new
+// connection and the remote address.
func (l *UnixListener) AcceptUnix() (*UnixConn, error) {
if l == nil || l.fd == nil {
return nil, syscall.EINVAL
@@ -355,8 +287,8 @@ func (l *UnixListener) AcceptUnix() (*UnixConn, error) {
return c, nil
}
-// Accept implements the Accept method in the Listener interface;
-// it waits for the next call and returns a generic Conn.
+// Accept implements the Accept method in the Listener interface; it
+// waits for the next call and returns a generic Conn.
func (l *UnixListener) Accept() (c Conn, err error) {
c1, err := l.AcceptUnix()
if err != nil {
@@ -365,8 +297,8 @@ func (l *UnixListener) Accept() (c Conn, err error) {
return c1, nil
}
-// Close stops listening on the Unix address.
-// Already accepted connections are not closed.
+// Close stops listening on the Unix address. Already accepted
+// connections are not closed.
func (l *UnixListener) Close() error {
if l == nil || l.fd == nil {
return syscall.EINVAL
@@ -385,9 +317,7 @@ func (l *UnixListener) Close() error {
if l.path[0] != '@' {
syscall.Unlink(l.path)
}
- err := l.fd.Close()
- l.fd = nil
- return err
+ return l.fd.Close()
}
// Addr returns the listener's network address.
@@ -402,16 +332,16 @@ func (l *UnixListener) SetDeadline(t time.Time) (err error) {
return setDeadline(l.fd, t)
}
-// File returns a copy of the underlying os.File, set to blocking mode.
-// It is the caller's responsibility to close f when finished.
+// File returns a copy of the underlying os.File, set to blocking
+// mode. It is the caller's responsibility to close f when finished.
// Closing l does not affect f, and closing f does not affect l.
func (l *UnixListener) File() (f *os.File, err error) { return l.fd.dup() }
-// ListenUnixgram listens for incoming Unix datagram packets addressed to the
-// local address laddr. The returned connection c's ReadFrom
-// and WriteTo methods can be used to receive and send UDP
-// packets with per-packet addressing. The network net must be "unixgram".
-func ListenUnixgram(net string, laddr *UnixAddr) (*UDPConn, error) {
+// ListenUnixgram listens for incoming Unix datagram packets addressed
+// to the local address laddr. The returned connection c's ReadFrom
+// and WriteTo methods can be used to receive and send packets with
+// per-packet addressing. The network net must be "unixgram".
+func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) {
switch net {
case "unixgram":
default:
@@ -420,9 +350,9 @@ func ListenUnixgram(net string, laddr *UnixAddr) (*UDPConn, error) {
if laddr == nil {
return nil, &OpError{"listen", net, nil, errMissingAddress}
}
- fd, err := unixSocket(net, laddr, nil, "listen")
+ fd, err := unixSocket(net, laddr, nil, "listen", noDeadline)
if err != nil {
return nil, err
}
- return newUDPConn(fd), nil
+ return newUnixConn(fd), nil
}
diff --git a/src/pkg/net/url/url.go b/src/pkg/net/url/url.go
index 17bf0d3a3..a39964ea1 100644
--- a/src/pkg/net/url/url.go
+++ b/src/pkg/net/url/url.go
@@ -7,7 +7,9 @@
package url
import (
+ "bytes"
"errors"
+ "sort"
"strconv"
"strings"
)
@@ -218,11 +220,18 @@ func escape(s string, mode encoding) string {
//
// scheme:opaque[?query][#fragment]
//
+// Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/.
+// A consequence is that it is impossible to tell which slashes in the Path were
+// slashes in the raw URL and which were %2f. This distinction is rarely important,
+// but when it is a client must use other routines to parse the raw URL or construct
+// the parsed URL. For example, an HTTP server can consult req.RequestURI, and
+// an HTTP client can use URL{Host: "example.com", Opaque: "//example.com/Go%2f"}
+// instead of URL{Host: "example.com", Path: "/Go/"}.
type URL struct {
Scheme string
Opaque string // encoded opaque data
User *Userinfo // username and password information
- Host string
+ Host string // host or host:port
Path string
RawQuery string // encoded query values, without '?'
Fragment string // fragment for references, without '#'
@@ -359,11 +368,17 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) {
}
url = new(URL)
+ if rawurl == "*" {
+ url.Path = "*"
+ return
+ }
+
// Split off possible leading "http:", "mailto:", etc.
// Cannot contain escaped characters.
if url.Scheme, rest, err = getscheme(rawurl); err != nil {
goto Error
}
+ url.Scheme = strings.ToLower(url.Scheme)
rest, url.RawQuery = split(rest, '?', true)
@@ -379,7 +394,7 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) {
}
}
- if (url.Scheme != "" || !viaRequest) && strings.HasPrefix(rest, "//") && !strings.HasPrefix(rest, "///") {
+ if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") {
var authority string
authority, rest = split(rest[2:], '/', false)
url.User, url.Host, err = parseAuthority(authority)
@@ -427,30 +442,35 @@ func parseAuthority(authority string) (user *Userinfo, host string, err error) {
// String reassembles the URL into a valid URL string.
func (u *URL) String() string {
- // TODO: Rewrite to use bytes.Buffer
- result := ""
+ var buf bytes.Buffer
if u.Scheme != "" {
- result += u.Scheme + ":"
+ buf.WriteString(u.Scheme)
+ buf.WriteByte(':')
}
if u.Opaque != "" {
- result += u.Opaque
+ buf.WriteString(u.Opaque)
} else {
- if u.Host != "" || u.User != nil {
- result += "//"
+ if u.Scheme != "" || u.Host != "" || u.User != nil {
+ buf.WriteString("//")
if u := u.User; u != nil {
- result += u.String() + "@"
+ buf.WriteString(u.String())
+ buf.WriteByte('@')
+ }
+ if h := u.Host; h != "" {
+ buf.WriteString(h)
}
- result += u.Host
}
- result += escape(u.Path, encodePath)
+ buf.WriteString(escape(u.Path, encodePath))
}
if u.RawQuery != "" {
- result += "?" + u.RawQuery
+ buf.WriteByte('?')
+ buf.WriteString(u.RawQuery)
}
if u.Fragment != "" {
- result += "#" + escape(u.Fragment, encodeFragment)
+ buf.WriteByte('#')
+ buf.WriteString(escape(u.Fragment, encodeFragment))
}
- return result
+ return buf.String()
}
// Values maps a string key to a list of values.
@@ -519,12 +539,16 @@ func parseQuery(m Values, query string) (err error) {
}
key, err1 := QueryUnescape(key)
if err1 != nil {
- err = err1
+ if err == nil {
+ err = err1
+ }
continue
}
value, err1 = QueryUnescape(value)
if err1 != nil {
- err = err1
+ if err == nil {
+ err = err1
+ }
continue
}
m[key] = append(m[key], value)
@@ -538,14 +562,24 @@ func (v Values) Encode() string {
if v == nil {
return ""
}
- parts := make([]string, 0, len(v)) // will be large enough for most uses
- for k, vs := range v {
+ var buf bytes.Buffer
+ keys := make([]string, 0, len(v))
+ for k := range v {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ for _, k := range keys {
+ vs := v[k]
prefix := QueryEscape(k) + "="
for _, v := range vs {
- parts = append(parts, prefix+QueryEscape(v))
+ if buf.Len() > 0 {
+ buf.WriteByte('&')
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(QueryEscape(v))
}
}
- return strings.Join(parts, "&")
+ return buf.String()
}
// resolvePath applies special path segments from refs and applies
@@ -556,23 +590,33 @@ func resolvePath(basepath string, refpath string) string {
if len(base) == 0 {
base = []string{""}
}
+
+ rm := true
for idx, ref := range refs {
switch {
case ref == ".":
- base[len(base)-1] = ""
+ if idx == 0 {
+ base[len(base)-1] = ""
+ rm = true
+ } else {
+ rm = false
+ }
case ref == "..":
newLen := len(base) - 1
if newLen < 1 {
newLen = 1
}
base = base[0:newLen]
- base[len(base)-1] = ""
+ if rm {
+ base[len(base)-1] = ""
+ }
default:
if idx == 0 || base[len(base)-1] == "" {
base[len(base)-1] = ref
} else {
base = append(base, ref)
}
+ rm = false
}
}
return strings.Join(base, "/")
@@ -650,6 +694,10 @@ func (u *URL) RequestURI() string {
if result == "" {
result = "/"
}
+ } else {
+ if strings.HasPrefix(result, "//") {
+ result = u.Scheme + ":" + result
+ }
}
if u.RawQuery != "" {
result += "?" + u.RawQuery
diff --git a/src/pkg/net/url/url_test.go b/src/pkg/net/url/url_test.go
index 75e8abe4e..4c4f406c2 100644
--- a/src/pkg/net/url/url_test.go
+++ b/src/pkg/net/url/url_test.go
@@ -7,6 +7,7 @@ package url
import (
"fmt"
"reflect"
+ "strings"
"testing"
)
@@ -121,14 +122,14 @@ var urltests = []URLTest{
},
"http:%2f%2fwww.google.com/?q=go+language",
},
- // non-authority
+ // non-authority with path
{
"mailto:/webmaster@golang.org",
&URL{
Scheme: "mailto",
Path: "/webmaster@golang.org",
},
- "",
+ "mailto:///webmaster@golang.org", // unfortunate compromise
},
// non-authority
{
@@ -241,6 +242,24 @@ var urltests = []URLTest{
},
"http://www.google.com/?q=go+language#foo&bar",
},
+ {
+ "file:///home/adg/rabbits",
+ &URL{
+ Scheme: "file",
+ Host: "",
+ Path: "/home/adg/rabbits",
+ },
+ "file:///home/adg/rabbits",
+ },
+ // case-insensitive scheme
+ {
+ "MaIlTo:webmaster@golang.org",
+ &URL{
+ Scheme: "mailto",
+ Opaque: "webmaster@golang.org",
+ },
+ "mailto:webmaster@golang.org",
+ },
}
// more useful string for debugging than fmt's struct printer
@@ -270,13 +289,37 @@ func DoTest(t *testing.T, parse func(string) (*URL, error), name string, tests [
}
}
+func BenchmarkString(b *testing.B) {
+ b.StopTimer()
+ b.ReportAllocs()
+ for _, tt := range urltests {
+ u, err := Parse(tt.in)
+ if err != nil {
+ b.Errorf("Parse(%q) returned error %s", tt.in, err)
+ continue
+ }
+ if tt.roundtrip == "" {
+ continue
+ }
+ b.StartTimer()
+ var g string
+ for i := 0; i < b.N; i++ {
+ g = u.String()
+ }
+ b.StopTimer()
+ if w := tt.roundtrip; g != w {
+ b.Errorf("Parse(%q).String() == %q, want %q", tt.in, g, w)
+ }
+ }
+}
+
func TestParse(t *testing.T) {
DoTest(t, Parse, "Parse", urltests)
}
const pathThatLooksSchemeRelative = "//not.a.user@not.a.host/just/a/path"
-var parseRequestUrlTests = []struct {
+var parseRequestURLTests = []struct {
url string
expectedValid bool
}{
@@ -288,10 +331,11 @@ var parseRequestUrlTests = []struct {
{"//not.a.user@%66%6f%6f.com/just/a/path/also", true},
{"foo.html", false},
{"../dir/", false},
+ {"*", true},
}
func TestParseRequestURI(t *testing.T) {
- for _, test := range parseRequestUrlTests {
+ for _, test := range parseRequestURLTests {
_, err := ParseRequestURI(test.url)
valid := err == nil
if valid != test.expectedValid {
@@ -453,20 +497,24 @@ func TestEscape(t *testing.T) {
//}
type EncodeQueryTest struct {
- m Values
- expected string
- expected1 string
+ m Values
+ expected string
}
var encodeQueryTests = []EncodeQueryTest{
- {nil, "", ""},
- {Values{"q": {"puppies"}, "oe": {"utf8"}}, "q=puppies&oe=utf8", "oe=utf8&q=puppies"},
- {Values{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7", "q=dogs&q=%26&q=7"},
+ {nil, ""},
+ {Values{"q": {"puppies"}, "oe": {"utf8"}}, "oe=utf8&q=puppies"},
+ {Values{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7"},
+ {Values{
+ "a": {"a1", "a2", "a3"},
+ "b": {"b1", "b2", "b3"},
+ "c": {"c1", "c2", "c3"},
+ }, "a=a1&a=a2&a=a3&b=b1&b=b2&b=b3&c=c1&c=c2&c=c3"},
}
func TestEncodeQuery(t *testing.T) {
for _, tt := range encodeQueryTests {
- if q := tt.m.Encode(); q != tt.expected && q != tt.expected1 {
+ if q := tt.m.Encode(); q != tt.expected {
t.Errorf(`EncodeQuery(%+v) = %q, want %q`, tt.m, q, tt.expected)
}
}
@@ -531,6 +579,15 @@ var resolveReferenceTests = []struct {
{"http://foo.com/bar/baz", "../../../../../quux", "http://foo.com/quux"},
{"http://foo.com/bar", "..", "http://foo.com/"},
{"http://foo.com/bar/baz", "./..", "http://foo.com/"},
+ // ".." in the middle (issue 3560)
+ {"http://foo.com/bar/baz", "quux/dotdot/../tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/../tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/.././tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/./../tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/././../../tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/./.././../tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/dotdot/./../../.././././tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/../dotdot/../dot/./tail/..", "http://foo.com/bar/quux/dot"},
// "." and ".." in the base aren't special
{"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/./dotdot/../baz"},
@@ -741,6 +798,24 @@ var requritests = []RequestURITest{
},
"/a%20b",
},
+ // golang.org/issue/4860 variant 1
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Opaque: "/%2F/%2F/",
+ },
+ "/%2F/%2F/",
+ },
+ // golang.org/issue/4860 variant 2
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Opaque: "//other.example.com/%2F/%2F/",
+ },
+ "http://other.example.com/%2F/%2F/",
+ },
{
&URL{
Scheme: "http",
@@ -775,3 +850,13 @@ func TestRequestURI(t *testing.T) {
}
}
}
+
+func TestParseFailure(t *testing.T) {
+ // Test that the first parse error is returned.
+ const url = "%gh&%ij"
+ _, err := ParseQuery(url)
+ errStr := fmt.Sprint(err)
+ if !strings.Contains(errStr, "%gh") {
+ t.Errorf(`ParseQuery(%q) returned error %q, want something containing %q"`, url, errStr, "%gh")
+ }
+}